Commit 2fd762b0 by zhlj

fix default probability

parent 3814b57f
...@@ -9,7 +9,7 @@ partitions="4" ...@@ -9,7 +9,7 @@ partitions="4"
node_per="4" node_per="4"
nnodes="1" nnodes="1"
node_rank="0" node_rank="0"
probability_params=("0" "0.1" "0.05" "0.01") probability_params=("0.1" "0" "0.05" "0.01")
sample_type_params=("boundery_recent_decay" "recent" ) sample_type_params=("boundery_recent_decay" "recent" )
#sample_type_params=("recent" "boundery_recent_decay") #"boundery_recent_uniform") #sample_type_params=("recent" "boundery_recent_decay") #"boundery_recent_uniform")
#memory_type=("all_update" "p2p" "all_reduce" "historical" "local") #memory_type=("all_update" "p2p" "all_reduce" "historical" "local")
...@@ -17,7 +17,7 @@ memory_type=( "historical" "local" "all_update") ...@@ -17,7 +17,7 @@ memory_type=( "historical" "local" "all_update")
#memory_type=("local" "all_update" "historical" "all_reduce") #memory_type=("local" "all_update" "historical" "all_reduce")
shared_memory_ssim=("0.3" "0.7") shared_memory_ssim=("0.3" "0.7")
#data_param=("WIKI" "REDDIT" "LASTFM" "WikiTalk") #data_param=("WIKI" "REDDIT" "LASTFM" "WikiTalk")
data_param=( "TaoBao" "StackOverflow" "GDELT") data_param=( "LASTFM" "TaoBao" "StackOverflow" "GDELT")
#data_param=("WIKI" "REDDIT" "LASTFM" "DGraphFin" "WikiTalk" "StackOverflow") #data_param=("WIKI" "REDDIT" "LASTFM" "DGraphFin" "WikiTalk" "StackOverflow")
#data_param=("WIKI" "REDDIT" "LASTFM" "WikiTalk" "StackOverflow") #data_param=("WIKI" "REDDIT" "LASTFM" "WikiTalk" "StackOverflow")
#data_param=("REDDIT" "WikiTalk") #data_param=("REDDIT" "WikiTalk")
......
_1007 8卡都是新跑的
07上_1007是LASTFM 的4卡结果
05上_1007是WIKI的4卡结果
...@@ -63,11 +63,11 @@ parser.add_argument('--partition', default='part', type=str, metavar='W', ...@@ -63,11 +63,11 @@ parser.add_argument('--partition', default='part', type=str, metavar='W',
help='name of model') help='name of model')
parser.add_argument('--topk', default='0', type=str, metavar='W', parser.add_argument('--topk', default='0', type=str, metavar='W',
help='name of model') help='name of model')
parser.add_argument('--probability', default=True, type=float, metavar='W', parser.add_argument('--probability', default=1, type=float, metavar='W',
help='name of model') help='name of model')
parser.add_argument('--sample_type', default='recent', type=str, metavar='W', parser.add_argument('--sample_type', default='recent', type=str, metavar='W',
help='name of model') help='name of model')
parser.add_argument('--local_neg_sample', default=True, type=bool, metavar='W', parser.add_argument('--local_neg_sample', default=False, type=bool, metavar='W',
help='name of model') help='name of model')
parser.add_argument('--shared_memory_ssim', default=2, type=float, metavar='W', parser.add_argument('--shared_memory_ssim', default=2, type=float, metavar='W',
help='name of model') help='name of model')
...@@ -656,8 +656,9 @@ def main(): ...@@ -656,8 +656,9 @@ def main():
print('\ttotal time:{:.2f}s prep time:{:.2f}s\n test time {:.2f}'.format(time.time()-epoch_start_time, time_prep,t_test)) print('\ttotal time:{:.2f}s prep time:{:.2f}s\n test time {:.2f}'.format(time.time()-epoch_start_time, time_prep,t_test))
torch.save(model.module.state_dict(), get_checkpoint_path(e)) torch.save(model.module.state_dict(), get_checkpoint_path(e))
if args.model == 'TGN': if args.model == 'TGN':
print('weight {} {}\n'.format(tt.weight_count_local,tt.weight_count_remote)) pass
print('ssim {} {}\n'.format(tt.ssim_local/tt.ssim_cnt,tt.ssim_remote/tt.ssim_cnt)) # print('weight {} {}\n'.format(tt.weight_count_local,tt.weight_count_remote))
# print('ssim {} {}\n'.format(tt.ssim_local/tt.ssim_cnt,tt.ssim_remote/tt.ssim_cnt))
torch.save(val_list,'all/{}/{}/val_{}_{}_{}_{}_{}_{}_{}_{}.pt'.format(args.dataname,args.model,args.partition,args.topk,dist.get_world_size(),dist.get_rank(),args.sample_type,args.probability,args.memory_type,args.shared_memory_ssim)) torch.save(val_list,'all/{}/{}/val_{}_{}_{}_{}_{}_{}_{}_{}.pt'.format(args.dataname,args.model,args.partition,args.topk,dist.get_world_size(),dist.get_rank(),args.sample_type,args.probability,args.memory_type,args.shared_memory_ssim))
torch.save(loss_list,'all/{}/{}/loss_{}_{}_{}_{}_{}_{}_{}_{}.pt'.format(args.dataname,args.model,args.partition,args.topk,dist.get_world_size(),dist.get_rank(),args.sample_type,args.probability,args.memory_type,args.shared_memory_ssim)) torch.save(loss_list,'all/{}/{}/loss_{}_{}_{}_{}_{}_{}_{}_{}.pt'.format(args.dataname,args.model,args.partition,args.topk,dist.get_world_size(),dist.get_rank(),args.sample_type,args.probability,args.memory_type,args.shared_memory_ssim))
torch.save(test_ap_list,'all/{}/{}/test_{}_{}_{}_{}_{}_{}_{}_{}.pt'.format(args.dataname,args.model,args.partition,args.topk,dist.get_world_size(),dist.get_rank(),args.sample_type,args.probability,args.memory_type,args.shared_memory_ssim)) torch.save(test_ap_list,'all/{}/{}/test_{}_{}_{}_{}_{}_{}_{}_{}.pt'.format(args.dataname,args.model,args.partition,args.topk,dist.get_world_size(),dist.get_rank(),args.sample_type,args.probability,args.memory_type,args.shared_memory_ssim))
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment