Commit a225e572 by zlj

add seed run test

parent 0dc877d4
bash test_all.sh 13357 > 13357.out
wait
bash test_all.sh 12347 > 12347.out
wait
bash test_all.sh 63377 > 63377.out
wait
bash test_all.sh 53473 > 53473.out
wait
bash test_all.sh 54763 > 54763.out
wait
\ No newline at end of file
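The new script above runs test_all.sh once per seed, writing each run's output to <seed>.out; since none of the commands is backgrounded with &, the runs already execute sequentially and the wait calls have no effect. A minimal sketch of driving the same sweep from Python (the script name test_all.sh and the seed list come from the commit; everything else here is an assumption):

```python
# Sketch: run the per-seed sweep from Python instead of a shell script.
# Assumes test_all.sh takes the seed as its first positional argument,
# exactly as in the commands above.
import subprocess

SEEDS = [13357, 12347, 63377, 53473, 54763]

for seed in SEEDS:
    with open(f"{seed}.out", "w") as out:
        # Runs sequentially; check=True stops the sweep if one run fails.
        subprocess.run(["bash", "test_all.sh", str(seed)], stdout=out, check=True)
```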
@@ -77,6 +77,8 @@ parser.add_argument('--eval_neg_samples', default=1, type=int, metavar='W',
 help='name of model')
 parser.add_argument('--memory_type', default='all_update', type=str, metavar='W',
 help='name of model')
+parser.add_argument('--seed', default=6773, type=int, metavar='W',
+help='name of model')
 #boundery_recent_uniform boundery_recent_decay
 args = parser.parse_args()
 if args.memory_type == 'all_local' or args.topk != '0':
@@ -186,7 +188,6 @@ def query():
 "total_update_mail" :total_update_mail ,
 "total_update_memory":total_update_memory,
 "total_remote_update":total_remote_update,}
-seed_everything(34)
 def main():
 #torch.autograd.set_detect_anomaly(True)
 print('LOCAL RANK {}, RANK{}'.format(os.environ["LOCAL_RANK"],os.environ["RANK"]))
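The hunk above removes the hard-coded seed_everything(34), presumably so the seed can be controlled per run via the new --seed flag. For context, a typical seed_everything helper looks roughly like the sketch below; the actual implementation in this repository is not shown in the diff and may differ:

```python
import os
import random

import numpy as np
import torch


def seed_everything(seed: int) -> None:
    """Seed the common RNG sources for a reproducible run (sketch)."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # no-op if CUDA is unavailable
    os.environ["PYTHONHASHSEED"] = str(seed)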
@@ -270,7 +271,7 @@ def main():
 #train_neg_sampler = LocalNegativeSampling('triplet',amount = args.neg_samples,dst_node_list = full_dst.unique())
 train_neg_sampler = LocalNegativeSampling('triplet',amount = args.neg_samples,dst_node_list = full_dst.unique(),local_mask=(DistIndex(graph.nids_mapper[full_dst.unique()].to('cpu')).part == dist.get_rank()),prob=args.probability)
 print(train_neg_sampler.dst_node_list)
-neg_sampler = LocalNegativeSampling('triplet',amount= neg_samples,dst_node_list = full_dst.unique(),seed=6773)
+neg_sampler = LocalNegativeSampling('triplet',amount= neg_samples,dst_node_list = full_dst.unique(),seed=args.seed)
 trainloader = DistributedDataLoader(graph,eval_train_data,sampler = sampler,
 sampler_fn = SAMPLE_TYPE.SAMPLE_FROM_TEMPORAL_EDGES,
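Here the evaluation negative sampler's seed changes from the fixed 6773 to args.seed, so each run of the seed sweep draws its own (but per-seed reproducible) negatives. LocalNegativeSampling's internals are not shown in the diff; the general effect of a fixed sampler seed can be illustrated with a plain torch.Generator (a standalone sketch, not this project's code):

```python
import torch


def sample_negatives(dst_node_list: torch.Tensor, amount: int, seed: int) -> torch.Tensor:
    # A seeded generator makes the drawn negatives reproducible for a given seed,
    # while different seeds (as swept by test_all.sh) give different draws.
    gen = torch.Generator().manual_seed(seed)
    idx = torch.randint(0, dst_node_list.numel(), (amount,), generator=gen)
    return dst_node_list[idx]


# Same seed -> identical negatives across runs; a different seed -> different negatives.
nodes = torch.arange(100)
assert torch.equal(sample_negatives(nodes, 5, 6773), sample_negatives(nodes, 5, 6773))
```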
@@ -610,7 +611,7 @@ def main():
 print(' comm local node number {} remote node number {} local edge {} remote edge{}\n'.format(sum_local_comm,sum_remote_comm,sum_local_edge_comm,sum_remote_edge_comm))
 print('memory comm {} shared comm {}\n'.format(tot_comm_count,tot_shared_count))
 #if(e==0):
-# torch.save((local_access,remote_access,local_edge_access,remote_edge_access,local_comm,remote_comm,local_edge_comm,remote_edge_comm),'all/{}/{}/comm/comm_{}_{}_{}_{}_{}_{}_{}_{}.pt'.format(args.dataname,args.model,args.partition,args.topk,dist.get_world_size(),dist.get_rank(),args.sample_type,args.probability,args.memory_type,args.shared_memory_ssim))
+# torch.save((local_access,remote_access,local_edge_access,remote_edge_access,local_comm,remote_comm,local_edge_comm,remote_edge_comm),'all_args.seed/{}/{}/comm/comm_{}_{}_{}_{}_{}_{}_{}_{}.pt'.format(args.dataname,args.model,args.partition,args.topk,dist.get_world_size(),dist.get_rank(),args.sample_type,args.probability,args.memory_type,args.shared_memory_ssim))
 ap = 0
 auc = 0
 tt.ssim_remote=0
@@ -662,9 +663,9 @@ def main():
 pass
 # print('weight {} {}\n'.format(tt.weight_count_local,tt.weight_count_remote))
 # print('ssim {} {}\n'.format(tt.ssim_local/tt.ssim_cnt,tt.ssim_remote/tt.ssim_cnt))
-torch.save(val_list,'all/{}/{}/val_{}_{}_{}_{}_{}_{}_{}_{}.pt'.format(args.dataname,args.model,args.partition,args.topk,dist.get_world_size(),dist.get_rank(),args.sample_type,args.probability,args.memory_type,args.shared_memory_ssim))
-torch.save(loss_list,'all/{}/{}/loss_{}_{}_{}_{}_{}_{}_{}_{}.pt'.format(args.dataname,args.model,args.partition,args.topk,dist.get_world_size(),dist.get_rank(),args.sample_type,args.probability,args.memory_type,args.shared_memory_ssim))
-torch.save(test_ap_list,'all/{}/{}/test_{}_{}_{}_{}_{}_{}_{}_{}.pt'.format(args.dataname,args.model,args.partition,args.topk,dist.get_world_size(),dist.get_rank(),args.sample_type,args.probability,args.memory_type,args.shared_memory_ssim))
+torch.save(val_list,'all_args.seed/{}/{}/val_{}_{}_{}_{}_{}_{}_{}_{}.pt'.format(args.dataname,args.model,args.partition,args.topk,dist.get_world_size(),dist.get_rank(),args.sample_type,args.probability,args.memory_type,args.shared_memory_ssim))
+torch.save(loss_list,'all_args.seed/{}/{}/loss_{}_{}_{}_{}_{}_{}_{}_{}.pt'.format(args.dataname,args.model,args.partition,args.topk,dist.get_world_size(),dist.get_rank(),args.sample_type,args.probability,args.memory_type,args.shared_memory_ssim))
+torch.save(test_ap_list,'all_args.seed/{}/{}/test_{}_{}_{}_{}_{}_{}_{}_{}.pt'.format(args.dataname,args.model,args.partition,args.topk,dist.get_world_size(),dist.get_rank(),args.sample_type,args.probability,args.memory_type,args.shared_memory_ssim))
 print(avg_time)
 if not early_stop:
...
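Note that 'all_args.seed' in the new output paths is a literal directory name: args.seed is never formatted into these strings, so all seeds in the sweep still write to the same directory. If the intent is one result directory per seed value, a hedged sketch of formatting the seed into the path (the directory layout here is an assumption, not this project's code) would be:

```python
import os

import torch


def save_metrics(val_list, args, world_size, rank):
    # Sketch: put the seed value itself in the directory name, e.g. all_6773/...,
    # so the five runs in the sweep do not overwrite each other's results.
    out_dir = 'all_{}/{}/{}'.format(args.seed, args.dataname, args.model)
    os.makedirs(out_dir, exist_ok=True)
    fname = 'val_{}_{}_{}_{}_{}_{}_{}_{}.pt'.format(
        args.partition, args.topk, world_size, rank,
        args.sample_type, args.probability, args.memory_type, args.shared_memory_ssim)
    torch.save(val_list, os.path.join(out_dir, fname))
```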
@@ -18,9 +18,9 @@ class MemoryMoniter:
 #self.memory_ssim.append(self.ssim(pre_memory,now_memory,method = 'cos'))
 #self.nid_list.append(nid)
 def draw(self,degree,data,model,e):
-torch.save(self.nid_list,'all/{}/{}/memorynid_{}.pt'.format(data,model,e))
-torch.save(self.memorychange,'all/{}/{}/memoryF_{}.pt'.format(data,model,e))
-torch.save(self.memory_ssim,'all/{}/{}/memcos_{}.pt'.format(data,model,e))
+torch.save(self.nid_list,'all_args.seed/{}/{}/memorynid_{}.pt'.format(data,model,e))
+torch.save(self.memorychange,'all_args.seed/{}/{}/memoryF_{}.pt'.format(data,model,e))
+torch.save(self.memory_ssim,'all_args.seed/{}/{}/memcos_{}.pt'.format(data,model,e))
 # path = './memory/{}/'.format(data)
 # if not os.path.exists(path):
...