Commit 9a31f5e3 by zlj

fix time count: defer CUDA event synchronization to update_parameter

parent d43a5181
@@ -6,9 +6,9 @@ addr="192.168.1.107"
 partition_params=("ours")
 #"metis" "ldg" "random")
 #("ours" "metis" "ldg" "random")
-partitions="4"
+partitions="8"
 node_per="4"
-nnodes="1"
+nnodes="2"
 node_rank="0"
 probability_params=("0.1")
 sample_type_params=("boundery_recent_decay")
@@ -19,7 +19,7 @@ memory_type=("historical")
 #memory_type=("local" "all_update" "historical" "all_reduce")
 shared_memory_ssim=("0.3")
 #data_param=("WIKI" "REDDIT" "LASTFM" "WikiTalk")
-data_param=("LASTFM" "WikiTalk" "StackOverflow")
+data_param=("WikiTalk")
 #"GDELT")
 #data_param=("WIKI" "REDDIT" "LASTFM" "DGraphFin" "WikiTalk" "StackOverflow")
 #data_param=("WIKI" "REDDIT" "LASTFM" "WikiTalk" "StackOverflow")
......
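The run-script changes above scale the job out: 8 partitions over 2 nodes with 4 workers per node, and narrow the dataset sweep to WikiTalk. In both the old and new settings the numbers suggest one partition per worker; a small sanity-check sketch of that inferred relationship (the relationship itself is an assumption, not stated in the script):

```python
# Inferred layout check for the run configuration above: in both the old
# (4 = 1 * 4) and new (8 = 2 * 4) settings, partitions == nnodes * node_per.
nnodes = 2      # machines participating in the run
node_per = 4    # worker processes per machine
partitions = 8  # graph partitions produced by the "ours" partitioner

assert partitions == nnodes * node_per, "expected one partition per worker (assumed)"
```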
@@ -39,7 +39,7 @@ class EarlyStopMonitor(object):
         return self.num_round >= self.max_round
 class AdaParameter:
-    def __init__(self, wait_threshold=0.1, init_beta = 0.1 ,init_alpha = 0.1, min_beta = 0.01, max_beta = 1, min_alpha = 0.1, max_alpha = 1):
+    def __init__(self, wait_threshold=0.1, init_beta = 0.1 ,init_alpha = 0.1, min_beta = 0.01, max_beta = 0.8, min_alpha = 0.1, max_alpha = 1):
         self.wait_threshold = wait_threshold
         self.beta = init_beta
         self.alpha = init_alpha
@@ -58,6 +58,11 @@ class AdaParameter:
         self.last_start_event_memory_update = None
         self.last_start_event_gnn_aggregate = None
+        self.end_event_fetch = None
+        self.end_event_memory_sync = None
+        self.end_event_memory_update = None
+        self.end_event_gnn_aggregate = None
         self.min_beta = min_beta
         self.max_beta = max_beta
         self.max_alpha = max_alpha
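This hunk tightens the β clamp (max_beta drops from 1 to 0.8) and gives the class somewhere to stash the deferred CUDA event pairs. A minimal construction sketch, assuming `AdaParameter` is importable from this repo's utility module (the import path is a guess; the keyword arguments come from the signature above):

```python
# Hypothetical import path; adjust to wherever AdaParameter is defined in this repo.
from util import AdaParameter

# beta is adapted at runtime and presumably clamped to [min_beta, max_beta],
# which this commit tightens to [0.01, 0.8]; alpha stays within [0.1, 1.0].
ada_param = AdaParameter(wait_threshold=0.1,
                         init_beta=0.1, min_beta=0.01, max_beta=0.8,
                         init_alpha=0.1, min_alpha=0.1, max_alpha=1)
```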
@@ -77,30 +82,35 @@ class AdaParameter:
             return
         end_event = torch.cuda.Event(enable_timing=True)
         end_event.record()
-        end_event.synchronize()
-        elapsed_time_ms = start_event.elapsed_time(end_event)
-        self.average_fetch += elapsed_time_ms
-        self.count_fetch += 1
+        #end_event.synchronize()
+        #elapsed_time_ms = start_event.elapsed_time(end_event)
+        #self.average_fetch += elapsed_time_ms
+        #self.count_fetch += 1
+        self.end_event_fetch = (self.last_start_event_fetch,end_event)
     def update_memory_sync_time(self,start_event):
         if start_event is None:
             return
         end_event = torch.cuda.Event(enable_timing=True)
         end_event.record()
-        end_event.synchronize()
-        elapsed_time_ms = start_event.elapsed_time(end_event)
-        self.average_memory_sync += elapsed_time_ms
-        self.count_memory_sync += 1
+        #end_event.synchronize()
+        #elapsed_time_ms = start_event.elapsed_time(end_event)
+        #self.average_memory_sync += elapsed_time_ms
+        #self.count_memory_sync += 1
+        self.end_event_memory_sync = (self.last_start_event_memory_sync,end_event)
     def update_memory_update_time(self,start_event):
         if start_event is None:
             return
         end_event = torch.cuda.Event(enable_timing=True)
         end_event.record()
-        end_event.synchronize()
-        elapsed_time_ms = start_event.elapsed_time(end_event)
-        self.average_memory_update += elapsed_time_ms
-        self.count_memory_update += 1
+        #end_event.synchronize()
+        #elapsed_time_ms = start_event.elapsed_time(end_event)
+        #self.average_memory_update += elapsed_time_ms
+        #self.count_memory_update += 1
+        self.end_event_memory_update = (self.last_start_event_memory_update,end_event)
     def update_gnn_aggregate_time(self,start_event):
@@ -108,10 +118,11 @@ class AdaParameter:
             return
         end_event = torch.cuda.Event(enable_timing=True)
         end_event.record()
-        end_event.synchronize()
-        elapsed_time_ms = start_event.elapsed_time(end_event)
-        self.average_gnn_aggregate += elapsed_time_ms
-        self.count_gnn_aggregate += 1
+        #end_event.synchronize()
+        #elapsed_time_ms = start_event.elapsed_time(end_event)
+        #self.average_gnn_aggregate += elapsed_time_ms
+        #self.count_gnn_aggregate += 1
+        self.end_event_gnn_aggregate = (self.last_start_event_gnn_aggregate,end_event)
     def reset_time(self):
         self.average_fetch = 0
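All four timers now follow the same pattern: record the end event immediately, but postpone the blocking `synchronize()` and `elapsed_time()` calls, keeping only the (start, end) event pair. A standalone sketch of that deferred-timing idea, independent of this repo's classes (class and method names here are illustrative):

```python
import torch

class DeferredTimer:
    """Record CUDA event pairs now, resolve elapsed times later (a sketch)."""
    def __init__(self):
        self.pending = []    # list of (start_event, end_event) pairs
        self.total_ms = 0.0
        self.count = 0

    def start(self):
        start = torch.cuda.Event(enable_timing=True)
        start.record()
        return start

    def stop(self, start_event):
        # Record the end event but do NOT synchronize here; that would block
        # the host until the GPU reaches this point in the stream.
        end = torch.cuda.Event(enable_timing=True)
        end.record()
        self.pending.append((start_event, end))

    def resolve(self):
        # Called rarely (e.g. once per adaptation step): only now do we wait
        # for the recorded events and accumulate elapsed times.
        for start, end in self.pending:
            end.synchronize()
            self.total_ms += start.elapsed_time(end)
            self.count += 1
        self.pending.clear()
        return self.total_ms / max(self.count, 1)
```

Deferring the synchronize avoids stalling the host once per timed region; the only blocking wait happens when the accumulated times are actually needed.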
@@ -125,8 +136,22 @@ class AdaParameter:
     def update_parameter(self):
         print('beta is {} alpha is {}\n'.format(self.beta,self.alpha))
-        if self.count_fetch == 0 or self.count_memory_sync == 0 or self.count_memory_update == 0 or self.count_gnn_aggregate == 0:
+        #if self.count_fetch == 0 or self.count_memory_sync == 0 or self.count_memory_update == 0 or self.count_gnn_aggregate == 0:
+        #    return
+        if self.end_event_fetch is None or self.end_event_memory_sync is None or self.end_event_memory_update is None or self.end_event_gnn_aggregate is None:
             return
+        self.end_event_fetch[1].synchronize()
+        self.end_event_memory_sync[1].synchronize()
+        self.end_event_memory_update[1].synchronize()
+        self.end_event_gnn_aggregate[1].synchronize()
+        self.average_fetch += self.end_event_fetch[0].elapsed_time(self.end_event_fetch[1])
+        self.count_fetch += 1
+        self.average_memory_sync += self.end_event_memory_sync[0].elapsed_time(self.end_event_memory_sync[1])
+        self.count_memory_sync += 1
+        self.average_memory_update += self.end_event_memory_update[0].elapsed_time(self.end_event_memory_update[1])
+        self.count_memory_update += 1
+        self.average_gnn_aggregate += self.end_event_gnn_aggregate[0].elapsed_time(self.end_event_gnn_aggregate[1])
+        self.count_gnn_aggregate += 1
         average_gnn_aggregate = self.average_gnn_aggregate/self.count_gnn_aggregate
         average_fetch = self.average_fetch/self.count_fetch
         self.beta = self.beta * average_gnn_aggregate/average_fetch * (1 + self.wait_threshold)
......
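`update_parameter` now resolves the stored event pairs in one place and then rescales β from the ratio of mean GNN-aggregation time to mean fetch time: β ← β · (avg_aggregate / avg_fetch) · (1 + wait_threshold). A worked example with made-up timings (the final clamp to [min_beta, max_beta] is assumed from the constructor; it is not visible in this hunk):

```python
# Worked example of the beta update shown above (timings are made up).
wait_threshold = 0.1
beta = 0.1
average_gnn_aggregate = 12.0   # ms, mean GNN aggregation time
average_fetch = 30.0           # ms, mean remote-feature fetch time

beta = beta * average_gnn_aggregate / average_fetch * (1 + wait_threshold)
print(beta)  # 0.1 * (12/30) * 1.1 = 0.044

# Presumably clamped afterwards using the constructor's bounds (0.01 / 0.8
# after this commit), e.g.: beta = min(max(beta, 0.01), 0.8)
```

When fetches are slow relative to aggregation, the ratio shrinks β; when aggregation dominates, β grows, up to the tightened cap of 0.8.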
@@ -394,6 +394,7 @@ class DistributedDataLoader:
             mem = self.mailbox.unpack(node_feat0,mailbox = True)
             if self.ada_param is not None:
                 self.ada_param.update_fetch_time(self.ada_param.last_start_event_fetch)
+                self.ada_param.update_parameter()
             #print(node_feat.shape,edge_feat.shape,mem[0].shape)
             #node_feat[1].wait()
             #node_feat = node_feat[0]
......
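In the data loader, the added `update_parameter()` call means α and β are re-tuned right after each fetch is timed, rather than waiting for an outer loop to do it. A simplified sketch of the assumed call order (the loop structure, fetch call, and any name not shown in the diff are assumptions):

```python
import torch

def fetch_step(loader):
    """Assumed per-batch flow around the lines changed above (a sketch)."""
    ada = loader.ada_param
    if ada is not None:
        # Assumed: a start event is recorded before the remote feature fetch.
        ada.last_start_event_fetch = torch.cuda.Event(enable_timing=True)
        ada.last_start_event_fetch.record()

    node_feat0 = loader.fetch_remote(...)            # hypothetical fetch call
    mem = loader.mailbox.unpack(node_feat0, mailbox=True)

    if ada is not None:
        # Records the end event and stores the (start, end) pair without blocking.
        ada.update_fetch_time(ada.last_start_event_fetch)
        # New in this commit: synchronize all pending pairs and adapt alpha/beta.
        ada.update_parameter()
    return mem
```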