Commit 911d3eb8 by zlj

fix gamma by sigmoid

parent 29c53410
......@@ -20,7 +20,7 @@ memory:
gnn:
- arch: 'identity'
train:
- epoch: 3
- epoch: 50
batch_size: 3000
lr: 0.0002
dropout: 0.1
......
......@@ -10,23 +10,22 @@ partitions="8"
node_per="4"
nnodes="2"
node_rank="0"
probability_params=("0.6" "1")
sample_type_params=("boundery_recent_decay" "recent")
probability_params=("0.1")
sample_type_params=("boundery_recent_decay")
#"boundery_recent_decay")
#sample_type_params=("recent" "boundery_recent_decay") #"boundery_recent_uniform")
#memory_type=("all_update" "p2p" "all_reduce" "historical" "local")
memory_type=("historical")
#"historical")
#memory_type=("local" "all_update" "historical" "all_reduce")
shared_memory_ssim=("0.3")
#("0" "0.1" "0.3" "0.5" "0.7" "0.9" "1.3" "1.5" "1.7" "2")
shared_memory_ssim=("0.1" "0.3" "0.5" "0.7" "0.9" "1.3" "1.5" "1.7" "2")
#"historical")
#memory_type=("local" "all_update" "historical" "all_reduce")
#data_param=("WIKI" "REDDIT" "LASTFM" "WikiTalk")
neighbor_num=( "10" "20")
neighbor="10"
topk_list=("0.1" "0.02" "0.04" "0.06" "0.08" "0.1" "0.2" "0.3")
data_param=("GDELT")
topk_list=("0.1")
data_param=("LASTFM" "StackOverflow" "GDELT")
#"GDELT")
#data_param=("WIKI" "REDDIT" "LASTFM" "DGraphFin" "WikiTalk" "StackOverflow")
#data_param=("WIKI" "REDDIT" "LASTFM" "WikiTalk" "StackOverflow")
......
......@@ -206,7 +206,7 @@ def main():
if(args.dataname=='GDELT'):
train_param['epoch'] = 10
#if(args.probability > 0.005):
train_param['epoch'] = 1
#train_param['epoch'] = 1
#torch.autograd.set_detect_anomaly(True)
# 确保 CUDA 可用
if torch.cuda.is_available():
......
......@@ -406,11 +406,11 @@ class AsyncMemeoryUpdater(torch.nn.Module):
wait_submit=submit_to_queue,spread_mail=spread_mail,
update_cross_mm=False,
)
#self.mailbox.update_shared()
#self.mailbox.update_p2p_mem()
#self.mailbox.update_p2p_mail()
#self.mailbox.sychronize_shared()
#self.mailbox.handle_last_async()
self.mailbox.update_shared()
self.mailbox.update_p2p_mem()
self.mailbox.update_p2p_mail()
self.mailbox.sychronize_shared()
self.mailbox.handle_last_async()
if nxt_fetch_func is not None:
nxt_fetch_func()
......@@ -482,6 +482,7 @@ class AsyncMemeoryUpdater(torch.nn.Module):
transition_dense = b.srcdata['his_mem'][mask] + self.filter.get_incretment(shared_ind)
#print(transition_dense.shape)
#print('check historical {} {} {} {}'.format(b.srcdata['his_mem'][mask],transition_dense,torch.isnan(transition_dense).sum(),torch.isnan(b.srcdata['his_mem'][mask]).sum()))
if not (transition_dense.max().item() == 0):
transition_dense -= transition_dense.min()
transition_dense /= transition_dense.max()
......@@ -494,7 +495,9 @@ class AsyncMemeoryUpdater(torch.nn.Module):
#upd0[mask] = self.ceil_updater(his_mem, b.srcdata['his_mem'][mask])
#updated_memory = torch.where(mask.unsqueeze(1),self.gamma*updated_memory0 + (1-self.gamma)*(b.srcdata['his_mem'])
# ,updated_memory0)
updated_memory = torch.where(mask.unsqueeze(1),self.gamma*updated_memory0 + (1-self.gamma)*(upd0),updated_memory0)
updated_memory = torch.where(mask.unsqueeze(1),torch.sigmoid(self.gamma)*updated_memory0 + (1-torch.sigmoid(self.gamma))*(upd0),updated_memory0)
#print('check {} {} {} {}'.format(self.gamma,updated_memory,torch.isnan(updated_memory).sum(),torch.isnan(updated_memory0).sum()))
#print(updated_memory,isnan)
with torch.no_grad():
if self.mode == 'historical':
change = updated_memory[mask] - b.srcdata['his_mem'][mask]
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment