Commit 4f3d6846 by zhljJoan

fix for test

parent 5a9ddff8
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
#跑了4卡的TaoBao #跑了4卡的TaoBao
# 定义数组变量 # 定义数组变量
seed=$1 seed=$1
addr="192.168.1.105" addr="192.168.1.107"
partition_params=("ours") partition_params=("ours")
partition="ours" partition="ours"
#"metis" "ldg" "random") #"metis" "ldg" "random")
...@@ -10,12 +10,12 @@ partition="ours" ...@@ -10,12 +10,12 @@ partition="ours"
partitions="8" partitions="8"
node_per="4" node_per="4"
nnodes="2" nnodes="2"
node_rank="0" node_rank="1"
sample_type_params=("recent" "recent" "recent" "recent" "boundery_recent_decay") sample_type_params=("recent" "recent" "boundery_recent_decay")
probability_params=("0.1") probability_params=("0.1")
neg_policy=("all" "local" "local" "local" "local") neg_policy=("all" "all" "all")
topk_list=("0" "0" "0.1" "0.1" "0.1") topk_list=("0.1" "0.1" "0.1")
memory_type=("all_update" "all_update" "all_update" "all_update" "historical") memory_type=("all_update" "historical" "historical")
shared_memory_ssim=("0.3") shared_memory_ssim=("0.3")
ssim="0.3" ssim="0.3"
neighbor_num=( "10" "20") neighbor_num=( "10" "20")
......
...@@ -495,7 +495,10 @@ class AsyncMemeoryUpdater(torch.nn.Module): ...@@ -495,7 +495,10 @@ class AsyncMemeoryUpdater(torch.nn.Module):
#upd0[mask] = self.ceil_updater(his_mem, b.srcdata['his_mem'][mask]) #upd0[mask] = self.ceil_updater(his_mem, b.srcdata['his_mem'][mask])
#updated_memory = torch.where(mask.unsqueeze(1),self.gamma*updated_memory0 + (1-self.gamma)*(b.srcdata['his_mem']) #updated_memory = torch.where(mask.unsqueeze(1),self.gamma*updated_memory0 + (1-self.gamma)*(b.srcdata['his_mem'])
# ,updated_memory0) # ,updated_memory0)
updated_memory = torch.where(mask.unsqueeze(1),torch.sigmoid(self.gamma)*updated_memory0 + (1-torch.sigmoid(self.gamma))*(upd0),updated_memory0) if self.mode == 'historical':
updated_memory = torch.where(mask.unsqueeze(1),torch.sigmoid(self.gamma)*updated_memory0 + (1-torch.sigmoid(self.gamma))*(upd0),updated_memory0)
else:
updated_memory = updated_memory0
#print('check {} {} {} {}'.format(self.gamma,updated_memory,torch.isnan(updated_memory).sum(),torch.isnan(updated_memory0).sum())) #print('check {} {} {} {}'.format(self.gamma,updated_memory,torch.isnan(updated_memory).sum(),torch.isnan(updated_memory0).sum()))
#print(updated_memory,isnan) #print(updated_memory,isnan)
with torch.no_grad(): with torch.no_grad():
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment