Commit 4f3d6846 by zhljJoan

fix for test

parent 5a9ddff8
......@@ -2,7 +2,7 @@
#跑了4卡的TaoBao
# 定义数组变量
seed=$1
addr="192.168.1.105"
addr="192.168.1.107"
partition_params=("ours")
partition="ours"
#"metis" "ldg" "random")
......@@ -10,12 +10,12 @@ partition="ours"
partitions="8"
node_per="4"
nnodes="2"
node_rank="0"
sample_type_params=("recent" "recent" "recent" "recent" "boundery_recent_decay")
node_rank="1"
sample_type_params=("recent" "recent" "boundery_recent_decay")
probability_params=("0.1")
neg_policy=("all" "local" "local" "local" "local")
topk_list=("0" "0" "0.1" "0.1" "0.1")
memory_type=("all_update" "all_update" "all_update" "all_update" "historical")
neg_policy=("all" "all" "all")
topk_list=("0.1" "0.1" "0.1")
memory_type=("all_update" "historical" "historical")
shared_memory_ssim=("0.3")
ssim="0.3"
neighbor_num=( "10" "20")
......
......@@ -495,7 +495,10 @@ class AsyncMemeoryUpdater(torch.nn.Module):
#upd0[mask] = self.ceil_updater(his_mem, b.srcdata['his_mem'][mask])
#updated_memory = torch.where(mask.unsqueeze(1),self.gamma*updated_memory0 + (1-self.gamma)*(b.srcdata['his_mem'])
# ,updated_memory0)
if self.mode == 'historical':
updated_memory = torch.where(mask.unsqueeze(1),torch.sigmoid(self.gamma)*updated_memory0 + (1-torch.sigmoid(self.gamma))*(upd0),updated_memory0)
else:
updated_memory = updated_memory0
#print('check {} {} {} {}'.format(self.gamma,updated_memory,torch.isnan(updated_memory).sum(),torch.isnan(updated_memory0).sum()))
#print(updated_memory,isnan)
with torch.no_grad():
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment