Commit a3cc8ba3 by zlj

fix config

parent 81de4b74
@@ -21,6 +21,6 @@ gnn:
 train:
   - epoch: 50
     batch_size: 3000
-    lr: 0.0004
+    lr: 0.0016
     dropout: 0.1
     all_on_gpu: True
\ No newline at end of file
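
The functional change above quadruples the learning rate (0.0004 → 0.0016). A minimal sketch of how the resulting block parses, assuming the file is ordinary YAML read with PyYAML (the `gnn:` nesting and the loader call are illustrative, not taken from the repository); note that `train:` holds a one-element list, so the hyperparameters sit under index 0:

```python
import yaml  # PyYAML; an assumed loader, not necessarily what the repo uses

# The post-commit hunk, reconstructed as a standalone YAML document.
config_text = """
gnn:
  train:
    - epoch: 50
      batch_size: 3000
      lr: 0.0016
      dropout: 0.1
      all_on_gpu: True
"""

config = yaml.safe_load(config_text)
train_cfg = config["gnn"]["train"][0]  # 'train' is a list; take its single entry
assert train_cfg["lr"] == 0.0016
assert train_cfg["all_on_gpu"] is True
```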
@@ -19,7 +19,7 @@ memory_type=("historical")
 #memory_type=("local" "all_update" "historical" "all_reduce")
 shared_memory_ssim=("0.3")
 #data_param=("WIKI" "REDDIT" "LASTFM" "WikiTalk")
-data_param=("LASTFM" "WikiTalk" "StackOverflow" "GDELT")
+data_param=("StackOverflow" "GDELT")
 # "StackOverflow" "GDELT")
 #"GDELT")
 #data_param=("WIKI" "REDDIT" "LASTFM" "DGraphFin" "WikiTalk" "StackOverflow")
@@ -167,12 +167,13 @@ class AdaParameter:
         #print(self.alpha)
         self.beta = max(min(self.beta, self.max_beta),self.min_beta)
         self.alpha = max(min(self.alpha, self.max_alpha),self.min_alpha)
         #print(self.count_fetch,self.count_memory_update,self.count_gnn_aggregate,self.count_memory_sync)
+        #print(self.beta,self.alpha)
         ctx = DistributedContext.get_default_context()
-        beta_comm=torch.tensor([self.beta])
+        beta_comm=torch.tensor([self.beta],dtype=torch.float)
         torch.distributed.all_reduce(beta_comm,group=ctx.gloo_group)
         self.beta = beta_comm[0].item()/ctx.world_size
-        alpha_comm=torch.tensor([self.alpha])
+        alpha_comm=torch.tensor([self.alpha],dtype=torch.float)
         torch.distributed.all_reduce(alpha_comm,group=ctx.gloo_group)
         self.alpha = alpha_comm[0].item()/ctx.world_size
         #print('gnn aggregate {} fetch {} memory sync {} memory update {}'.format(average_gnn_aggregate,average_fetch,average_memory_sync_time,average_memory_update_time))
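
The Python change pins an explicit dtype on the tensors that enter `all_reduce`. The surrounding pattern averages `beta` and `alpha` across ranks: each rank contributes its local value, the gloo all-reduce sums them in place, and dividing by `world_size` yields the mean. Collectives require every rank to pass a tensor of the same dtype and shape, and `torch.tensor` infers its dtype from the input (a NumPy float64 scalar yields a float64 tensor, a plain Python float yields float32), so pinning `dtype=torch.float` removes that source of mismatch. A minimal runnable sketch of the pattern, assuming a single-process gloo group in place of `ctx.gloo_group` (`DistributedContext` lives in the repository and is not reproduced here):

```python
import os
import torch
import torch.distributed as dist

# Single-process gloo group so the sketch runs standalone (assumption:
# stands in for the repo's ctx.gloo_group).
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

beta = 0.7  # local per-rank value; in practice it may arrive as NumPy float64
# Pin float32 so every rank contributes the same dtype; without the explicit
# dtype, torch.tensor infers it from the input and ranks could disagree.
beta_comm = torch.tensor([beta], dtype=torch.float)
dist.all_reduce(beta_comm)                          # in-place sum over all ranks
beta = beta_comm[0].item() / dist.get_world_size()  # sum -> mean
print(beta)                                         # 0.7 with world_size == 1

dist.destroy_process_group()
```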