Commit 09fbaa3f by zlj

no update when training is false

parent 82d55ff8
@@ -21,6 +21,6 @@ gnn:
 train:
 - epoch: 50
   batch_size: 3000
-  lr: 0.0002
+  lr: 0.0004
   dropout: 0.1
   all_on_gpu: True
\ No newline at end of file
scp -r examples/MemShare_v1/all_12357/LASTFM/TGN/4-ours-0-all_update-recent.out gpu05:/home/zlj/BTS_MTGNN_paper/rebuttal/baseline/TGL-dist/LASTFM_TGN_4.out
scp -r examples/MemShare_v1/all_12357/WikiTalk/TGN_large/4-ours-0-all_update-recent.out gpu05:/home/zlj/BTS_MTGNN_paper/rebuttal/baseline/TGL-dist/WikiTalk_TGN_4.out
scp -r examples/MemShare_v1/all_12357/StackOverflow/TGN_large/4-ours-0-all_update-recent.out gpu05:/home/zlj/BTS_MTGNN_paper/rebuttal/baseline/TGL-dist/StackOverflow_TGN_4.out
scp -r examples/MemShare_v1/all_12357/GDELT/TGN_large/4-ours-0-all_update-recent.out gpu05:/home/zlj/BTS_MTGNN_paper/rebuttal/baseline/TGL-dist/GDELT_TGN_4.out
scp -r examples/MemShare_v1/all_12357/LASTFM/APAN/4-ours-0-all_update-recent.out gpu05:/home/zlj/BTS_MTGNN_paper/rebuttal/baseline/TGL-dist/LASTFM_APAN_4.out
scp -r examples/MemShare_v1/all_12357/LASTFM/JODIE/4-ours-0-all_update-recent.out gpu05:/home/zlj/BTS_MTGNN_paper/rebuttal/baseline/TGL-dist/LASTFM_JODIE_4.out
scp -r examples/MemShare_v1/all_12357/WikiTalk/APAN_large/4-ours-0-all_update-recent.out gpu05:/home/zlj/BTS_MTGNN_paper/rebuttal/baseline/TGL-dist/WikiTalk_APAN_4.out
scp -r examples/MemShare_v1/all_12357/WikiTalk/JODIE_large/4-ours-0-all_update-recent.out gpu05:/home/zlj/BTS_MTGNN_paper/rebuttal/baseline/TGL-dist/WikiTalk_JODIE_4.out
scp -r examples/MemShare_v1/all_12357/StackOverflow/APAN_large/4-ours-0-all_update-recent.out gpu05:/home/zlj/BTS_MTGNN_paper/rebuttal/baseline/TGL-dist/StackOverflow_APAN_4.out
scp -r examples/MemShare_v1/all_12357/StackOverflow/JODIE_large/4-ours-0-all_update-recent.out gpu05:/home/zlj/BTS_MTGNN_paper/rebuttal/baseline/TGL-dist/StackOverflow_JODIE_4.out
scp -r examples/MemShare_v1/all_12357/GDELT/APAN_large/4-ours-0-all_update-recent.out gpu05:/home/zlj/BTS_MTGNN_paper/rebuttal/baseline/TGL-dist/GDELT_APAN_4.out
scp -r examples/MemShare_v1/all_12357/GDELT/JODIE_large/4-ours-0-all_update-recent.out gpu05:/home/zlj/BTS_MTGNN_paper/rebuttal/baseline/TGL-dist/GDELT_JODIE_4.out
\ No newline at end of file
@@ -80,7 +80,7 @@ class AdaParameter:
         start_event.record()
         return start_event
     def update_fetch_time(self,start_event):
-        if start_event is None or self.training:
+        if start_event is None or self.training == False:
             return
         end_event = torch.cuda.Event(enable_timing=True)
         end_event.record()
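
The hunk above is the fix named in the commit message: fetch-timing bookkeeping should be skipped when the model is not training. A minimal standalone sketch of that guard pattern, using hypothetical names (TimingTracker, fetch_time_ms) rather than the repository's actual class:

import torch

class TimingTracker:
    """Accumulates fetch time with CUDA events, but only while training."""
    def __init__(self):
        self.training = True
        self.fetch_time_ms = 0.0

    def start_event(self):
        # Record a start event only when timing is wanted and CUDA is available.
        if not self.training or not torch.cuda.is_available():
            return None
        ev = torch.cuda.Event(enable_timing=True)
        ev.record()
        return ev

    def update_fetch_time(self, start_event):
        # Mirrors the corrected check: bail out in eval mode or with no start event.
        if start_event is None or not self.training:
            return
        end_event = torch.cuda.Event(enable_timing=True)
        end_event.record()
        end_event.synchronize()
        self.fetch_time_ms += start_event.elapsed_time(end_event)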
@@ -140,8 +140,10 @@ class AdaParameter:
         #print('beta is {} alpha is {}\n'.format(self.beta,self.alpha))
         #if self.count_fetch == 0 or self.count_memory_sync == 0 or self.count_memory_update == 0 or self.count_gnn_aggregate == 0:
         #    return
+        #print(self.training)
         if self.training == False:
             return
+        #print('{} {} {} {} '.format(self.end_event_fetch,self.end_event_memory_sync,self.end_event_memory_update,self.end_event_gnn_aggregate))
         if self.end_event_fetch is None or self.end_event_memory_sync is None or self.end_event_memory_update is None or self.end_event_gnn_aggregate is None:
             return
         self.end_event_fetch[1].synchronize()
@@ -165,7 +167,7 @@ class AdaParameter:
         #print(self.alpha)
         self.beta = max(min(self.beta, self.max_beta),self.min_beta)
         self.alpha = max(min(self.alpha, self.max_alpha),self.min_alpha)
-        print(self.beta,self.alpha)
+        #print(self.beta,self.alpha)
         ctx = DistributedContext.get_default_context()
         beta_comm=torch.tensor([self.beta])
         torch.distributed.all_reduce(beta_comm,group=ctx.gloo_group)
@@ -173,8 +175,8 @@ class AdaParameter:
         alpha_comm=torch.tensor([self.alpha])
         torch.distributed.all_reduce(alpha_comm,group=ctx.gloo_group)
         self.alpha = alpha_comm[0].item()/ctx.world_size
-        print('gnn aggregate {} fetch {} memory sync {} memory update {}'.format(average_gnn_aggregate,average_fetch,average_memory_sync_time,average_memory_update_time))
-        print('beta is {} alpha is {}\n'.format(self.beta,self.alpha))
+        #print('gnn aggregate {} fetch {} memory sync {} memory update {}'.format(average_gnn_aggregate,average_fetch,average_memory_sync_time,average_memory_update_time))
+        #print('beta is {} alpha is {}\n'.format(self.beta,self.alpha))
         #self.reset_time()
         #log(2-a1 ) = log(2-a2) * t1/t2 * (1 + wait_threshold)
         #2-a1 = 2-a2 ^(t1/t2 * (1 + wait_threshold))
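
The two hunks above also silence the per-update prints around the step where alpha and beta are clamped to their bounds and then averaged across workers. A hedged sketch of that clamp-and-average step, assuming a plain torch.distributed process group in place of the repository's DistributedContext/gloo_group helper:

import torch
import torch.distributed as dist

def sync_adaptive_params(alpha, beta, bounds, group=None):
    """Clamp local alpha/beta to their bounds, then average them over all ranks."""
    min_alpha, max_alpha, min_beta, max_beta = bounds
    beta = max(min(beta, max_beta), min_beta)
    alpha = max(min(alpha, max_alpha), min_alpha)
    world_size = dist.get_world_size(group)
    # all_reduce sums in place (default op is SUM); divide by world size for the mean.
    beta_comm = torch.tensor([beta])
    dist.all_reduce(beta_comm, group=group)
    beta = beta_comm[0].item() / world_size
    alpha_comm = torch.tensor([alpha])
    dist.all_reduce(alpha_comm, group=group)
    alpha = alpha_comm[0].item() / world_size
    return alpha, beta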
......
@@ -281,6 +281,7 @@ class DistributedDataLoader:
             self.result_queue.append((batch_data,dist_nid,dist_eid,edge_feat,node_feat))
             self.submit()
             if self.ada_param is not None:
                 self.ada_param.last_start_event_fetch = self.ada_param.start_event()
     @torch.no_grad()
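
For context, the loader-side lines above pair with update_fetch_time in AdaParameter: a start event is recorded when a prefetched batch is queued, and the elapsed fetch time is folded in when that batch is later consumed. A rough sketch of the consumer side, with a hypothetical next_batch method that is not part of the diff:

    def next_batch(self):
        # Hypothetical consumer (not in the diff): pop the oldest prefetched batch
        # and, if an adaptive-parameter controller is attached, account for the
        # time this fetch took using the event recorded at submit time.
        batch = self.result_queue.pop(0)
        if self.ada_param is not None:
            self.ada_param.update_fetch_time(self.ada_param.last_start_event_fetch)
        return batch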
......