Commit 39172101 by zlj

try all-to-all operator

parent 8a16ee73
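The commit title refers to an all-to-all collective, but the hunks shown below only touch logging, synchronization, and the launch script, so the operator itself does not appear here. For reference, a minimal sketch of PyTorch's all-to-all exchange is given; it assumes an already initialized torch.distributed process group, and the function name exchange_chunks is illustrative rather than taken from this repository:

import torch
import torch.distributed as dist

def exchange_chunks(local_chunks: torch.Tensor) -> torch.Tensor:
    # local_chunks holds one equally sized chunk per rank (dim 0 == world size);
    # chunk i is sent to rank i and one chunk is received back from every peer.
    world_size = dist.get_world_size()
    assert local_chunks.size(0) == world_size
    received = torch.empty_like(local_chunks)
    dist.all_to_all_single(received, local_chunks)
    return received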
@@ -83,7 +83,7 @@ for data in "${data_param[@]}"; do
 wait
 fi
 else
-torchrun --nnodes "$nnodes" --node_rank "$node_rank" --nproc-per-node "$node_per" --master-addr "$addr" --master-port 9445 train_boundery.py --dataname "$data" --mode "$model" --partition "$partition" --topk 0 --sample_type "$sample" --probability "$pro" --memory_type "$mem" --seed "$seed" > all_"$seed"/"$data"/"$model"/"$partitions"-"$partition"-0-"$mem"-"$sample"-"$pro".out &
+#torchrun --nnodes "$nnodes" --node_rank "$node_rank" --nproc-per-node "$node_per" --master-addr "$addr" --master-port 9445 train_boundery.py --dataname "$data" --mode "$model" --partition "$partition" --topk 0 --sample_type "$sample" --probability "$pro" --memory_type "$mem" --seed "$seed" > all_"$seed"/"$data"/"$model"/"$partitions"-"$partition"-0-"$mem"-"$sample"-"$pro".out &
 wait
 if [ "$partition" = "ours" ] && [ "$mem" != "all_local" ]; then
 torchrun --nnodes "$nnodes" --node_rank "$node_rank" --nproc-per-node "$node_per" --master-addr "$addr" --master-port 9445 train_boundery.py --dataname "$data" --mode "$model" --partition "$partition" --topk 0.1 --sample_type "$sample" --probability "$pro" --memory_type "$mem" --seed "$seed" > all_"$seed"/"$data"/"$model"/"$partitions"-ours_shared-0.01-"$mem"-"$sample"-"$pro".out &
...
@@ -383,7 +383,7 @@ def main():
 signal = torch.tensor([0],dtype = int,device = device)
 batch_cnt = 0
 for roots,mfgs,metadata in loader:
-print(batch_cnt)
+print('val batch {}\n'.format(batch_cnt))
 batch_cnt = batch_cnt+1
 """
 if ctx.memory_group == 0:
@@ -407,14 +407,18 @@ def main():
 param = (update_mail,src,dst,ts,edge_feats,loader.async_feature)
 else:
 param = None
+print('finish stage 2\n')
 pred_pos, pred_neg = model(mfgs,metadata,neg_samples=args.neg_samples,async_param = param)
+print('finish stage 3\n')
 y_pred = torch.cat([pred_pos, pred_neg], dim=0).sigmoid().cpu()
 y_true = torch.cat([torch.ones(pred_pos.size(0)), torch.zeros(pred_neg.size(0))], dim=0)
 aps.append(average_precision_score(y_true, y_pred.detach().numpy()))
 aucs_mrrs.append(roc_auc_score(y_true, y_pred))
+print('finish stage 4\n')
 mailbox.update_shared()
 mailbox.update_p2p_mem()
 mailbox.update_p2p_mail()
+print('finish {}\n'.format(batch_cnt))
 """
 if mailbox is not None:
 src = metadata['src_pos_index']
@@ -453,6 +457,8 @@ def main():
 mailbox.set_memory_all_reduce(index,memory,memory_ts,mail,mail_ts,reduce_Op = 'max', async_op = False,filter=None,set_remote=True,mode='all_reduce',submit=False)
 mailbox.sychronize_shared()
 """
+torch.cuda.synchronize()
+dist.barrier()
 ap = torch.empty([1])
 auc_mrr = torch.empty([1])
 #if(ctx.memory_group==0):
@@ -624,6 +630,7 @@ def main():
 train_ap = float(torch.tensor(train_aps).mean())
 print('\ttrain time:{:.2f}s\n'.format(time_prep))
 print(trainloader.local_node)
+dist.barrier()
 local_node=torch.tensor([trainloader.local_node])
 remote_node=torch.tensor([trainloader.remote_node])
 local_edge=torch.tensor([trainloader.local_edge])
@@ -639,6 +646,7 @@ def main():
 print('local node number {} remote node number {} local edge {} remote edge{}\n'.format(local_node,remote_node,local_edge,remote_edge))
 print(' comm local node number {} remote node number {} local edge {} remote edge{}\n'.format(sum_local_comm,sum_remote_comm,sum_local_edge_comm,sum_remote_edge_comm))
 print('memory comm {} shared comm {}\n'.format(tot_comm_count,tot_shared_count))
+dist.barrier()
 #if(e==0):
 # torch.save((local_access,remote_access,local_edge_access,remote_edge_access,local_comm,remote_comm,local_edge_comm,remote_edge_comm),'all_args.seed/{}/{}/comm/comm_{}_{}_{}_{}_{}_{}_{}_{}.pt'.format(args.dataname,args.model,args.partition,args.topk,dist.get_world_size(),dist.get_rank(),args.sample_type,args.probability,args.memory_type,args.shared_memory_ssim))
 ap = 0
@@ -650,16 +658,17 @@ def main():
 tt.ssim_cnt=0
 ap, auc = eval('val')
 print('finish val')
-#torch.cuda.synchronize()
+torch.cuda.synchronize()
+dist.barrier()
 print('start')
 t_test = time.time()
 print('test')
 test_ap,test_auc = 0,0#eval('test')
-#torch.cuda.synchronize()
+torch.cuda.synchronize()
+dist.barrier()
 t_test = time.time() - t_test
 total_test_time += t_test
 test_ap_list.append((test_ap,test_auc))
-early_stopper.early_stop_check(ap)
 early_stop = False
 trainloader.local_node = 0
 trainloader.remote_node = 0
@@ -691,6 +700,8 @@ def main():
 print('\ttrain loss:{:.4f} train ap:{:4f} val ap:{:4f} val auc:{:4f} test ap {:4f} test auc{:4f}\n'.format(total_loss,train_ap, ap, auc,test_ap,test_auc))
 print('\ttotal time:{:.2f}s prep time:{:.2f}s\n test time {:.2f}'.format(time.time()-epoch_start_time, time_prep,t_test))
 torch.save(model.module.state_dict(), get_checkpoint_path(e))
+torch.cuda.synchronize()
+dist.barrier()
 if args.model == 'TGN':
 pass
 # print('weight {} {}\n'.format(tt.weight_count_local,tt.weight_count_remote))
...
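The recurring addition in this commit is the pair torch.cuda.synchronize() followed by dist.barrier() after validation, testing, and checkpointing. A minimal sketch of that pattern, assuming one CUDA device per rank and an initialized default process group (the name sync_all_ranks is illustrative, not from the diff):

import torch
import torch.distributed as dist

def sync_all_ranks() -> None:
    # Wait for this rank's outstanding GPU kernels to finish ...
    torch.cuda.synchronize()
    # ... then block until every rank in the default group reaches this point.
    dist.barrier()

Synchronizing the device before the barrier means the wall-clock timings printed afterwards reflect completed GPU work rather than kernels that are still queued.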