Commit a3cc8ba3 by zlj

fix config

parent 81de4b74
...@@ -21,6 +21,6 @@ gnn: ...@@ -21,6 +21,6 @@ gnn:
train: train:
- epoch: 50 - epoch: 50
batch_size: 3000 batch_size: 3000
lr: 0.0004 lr: 0.0016
dropout: 0.1 dropout: 0.1
all_on_gpu: True all_on_gpu: True
\ No newline at end of file
...@@ -19,7 +19,7 @@ memory_type=("historical") ...@@ -19,7 +19,7 @@ memory_type=("historical")
#memory_type=("local" "all_update" "historical" "all_reduce") #memory_type=("local" "all_update" "historical" "all_reduce")
shared_memory_ssim=("0.3") shared_memory_ssim=("0.3")
#data_param=("WIKI" "REDDIT" "LASTFM" "WikiTalk") #data_param=("WIKI" "REDDIT" "LASTFM" "WikiTalk")
data_param=("LASTFM" "WikiTalk" "StackOverflow" "GDELT") data_param=("StackOverflow" "GDELT")
# "StackOverflow" "GDELT") # "StackOverflow" "GDELT")
#"GDELT") #"GDELT")
#data_param=("WIKI" "REDDIT" "LASTFM" "DGraphFin" "WikiTalk" "StackOverflow") #data_param=("WIKI" "REDDIT" "LASTFM" "DGraphFin" "WikiTalk" "StackOverflow")
......
...@@ -167,12 +167,13 @@ class AdaParameter: ...@@ -167,12 +167,13 @@ class AdaParameter:
#print(self.alpha) #print(self.alpha)
self.beta = max(min(self.beta, self.max_beta),self.min_beta) self.beta = max(min(self.beta, self.max_beta),self.min_beta)
self.alpha = max(min(self.alpha, self.max_alpha),self.min_alpha) self.alpha = max(min(self.alpha, self.max_alpha),self.min_alpha)
#print(self.count_fetch,self.count_memory_update,self.count_gnn_aggregate,self.count_memory_sync)
#print(self.beta,self.alpha) #print(self.beta,self.alpha)
ctx = DistributedContext.get_default_context() ctx = DistributedContext.get_default_context()
beta_comm=torch.tensor([self.beta]) beta_comm=torch.tensor([self.beta],dtype=torch.float)
torch.distributed.all_reduce(beta_comm,group=ctx.gloo_group) torch.distributed.all_reduce(beta_comm,group=ctx.gloo_group)
self.beta = beta_comm[0].item()/ctx.world_size self.beta = beta_comm[0].item()/ctx.world_size
alpha_comm=torch.tensor([self.alpha]) alpha_comm=torch.tensor([self.alpha],dtype=torch.float)
torch.distributed.all_reduce(alpha_comm,group=ctx.gloo_group) torch.distributed.all_reduce(alpha_comm,group=ctx.gloo_group)
self.alpha = alpha_comm[0].item()/ctx.world_size self.alpha = alpha_comm[0].item()/ctx.world_size
#print('gnn aggregate {} fetch {} memory sync {} memory update {}'.format(average_gnn_aggregate,average_fetch,average_memory_sync_time,average_memory_update_time)) #print('gnn aggregate {} fetch {} memory sync {} memory update {}'.format(average_gnn_aggregate,average_fetch,average_memory_sync_time,average_memory_update_time))
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment