Commit 7ae21c05 by zlj

fix time count

parent 9a31f5e3
...@@ -562,6 +562,14 @@ def main(): ...@@ -562,6 +562,14 @@ def main():
optimizer.zero_grad() optimizer.zero_grad()
ones = torch.ones(metadata['dst_neg_index'].shape[0],device = model.device,dtype=torch.float) ones = torch.ones(metadata['dst_neg_index'].shape[0],device = model.device,dtype=torch.float)
pred_pos, pred_neg = model(mfgs,metadata,neg_samples=args.neg_samples,async_param = param) pred_pos, pred_neg = model(mfgs,metadata,neg_samples=args.neg_samples,async_param = param)
ada_param.update_gnn_aggregate_time(ada_param.last_start_event_gnn_aggregate)
if len(trainloader.result_queue) > 0:
batch_data,dist_nid,dist_eid,edge_feat,node_feat0 = trainloader.result_queue[0]
edge_feat[1].wait()
node_feat0[1].wait()
if ada_param is not None:
ada_param.update_fetch_time(ada_param.last_start_event_fetch)
ada_param.update_parameter()
#print(time_count.elapsed_event(t2)) #print(time_count.elapsed_event(t2))
loss = creterion(pred_pos, torch.ones_like(pred_pos)) loss = creterion(pred_pos, torch.ones_like(pred_pos))
if args.local_neg_sample is False: if args.local_neg_sample is False:
...@@ -588,7 +596,7 @@ def main(): ...@@ -588,7 +596,7 @@ def main():
mailbox.update_p2p_mem() mailbox.update_p2p_mem()
mailbox.update_p2p_mail() mailbox.update_p2p_mail()
start = time_count.start_gpu() start = time_count.start_gpu()
ada_param.update_gnn_aggregate_time(ada_param.last_start_event_gnn_aggregate)
#ada_param.update_parameter() #ada_param.update_parameter()
#torch.cuda.empty_cache() #torch.cuda.empty_cache()
......
...@@ -141,9 +141,9 @@ class AdaParameter: ...@@ -141,9 +141,9 @@ class AdaParameter:
if self.end_event_fetch is None or self.end_event_memory_sync is None or self.end_event_memory_update is None or self.end_event_gnn_aggregate is None: if self.end_event_fetch is None or self.end_event_memory_sync is None or self.end_event_memory_update is None or self.end_event_gnn_aggregate is None:
return return
self.end_event_fetch[1].synchronize() self.end_event_fetch[1].synchronize()
self.end_event_memory_sync[1].synchronize() #self.end_event_memory_sync[1].synchronize()
self.end_event_memory_update[1].synchronize() #self.end_event_memory_update[1].synchronize()
self.end_event_gnn_aggregate[1].synchronize() #self.end_event_gnn_aggregate[1].synchronize()
self.average_fetch += self.end_event_fetch[0].elapsed_time(self.end_event_fetch[1]) self.average_fetch += self.end_event_fetch[0].elapsed_time(self.end_event_fetch[1])
self.count_fetch += 1 self.count_fetch += 1
self.average_memory_sync += self.end_event_memory_sync[0].elapsed_time(self.end_event_memory_sync[1]) self.average_memory_sync += self.end_event_memory_sync[0].elapsed_time(self.end_event_memory_sync[1])
......
...@@ -392,9 +392,7 @@ class DistributedDataLoader: ...@@ -392,9 +392,7 @@ class DistributedDataLoader:
node_feat0 = node_feat0[0] node_feat0 = node_feat0[0]
node_feat = None node_feat = None
mem = self.mailbox.unpack(node_feat0,mailbox = True) mem = self.mailbox.unpack(node_feat0,mailbox = True)
if self.ada_param is not None:
self.ada_param.update_fetch_time(self.ada_param.last_start_event_fetch)
self.ada_param.update_parameter()
#print(node_feat.shape,edge_feat.shape,mem[0].shape) #print(node_feat.shape,edge_feat.shape,mem[0].shape)
#node_feat[1].wait() #node_feat[1].wait()
#node_feat = node_feat[0] #node_feat = node_feat[0]
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment