Commit c0862f4e by zlj

found bugs in historical memory

parent 5ac6e955
......@@ -19,8 +19,8 @@ gnn:
use_dst_emb: False
time_transform: 'JODIE'
train:
- epoch: 250
- epoch: 100
batch_size: 1000
lr: 0.0002
lr: 0.0004
dropout: 0.1
all_on_gpu: True
\ No newline at end of file
......@@ -190,7 +190,7 @@ def query():
"total_update_memory":total_update_memory,
"total_remote_update":total_remote_update,}
def main():
torch.backends.cudnn.benchmark = True
#torch.backends.cudnn.benchmark = True
#torch.autograd.set_detect_anomaly(True)
print('LOCAL RANK {}, RANK{}'.format(os.environ["LOCAL_RANK"],os.environ["RANK"]))
use_cuda = True
......@@ -434,7 +434,7 @@ def main():
cos = torch.nn.CosineSimilarity(dim=0)
return cos(normalize(x1),normalize(x2)).sum()/x1.size(dim=0)
creterion = torch.nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=train_param['lr'])#,weight_decay=1e-4)
optimizer = torch.optim.Adam(model.parameters(), lr=train_param['lr'],weight_decay=1e-4)
scaler = torch.cuda.amp.GradScaler()
early_stopper = EarlyStopMonitor(max_round=args.patience)
MODEL_SAVE_PATH = f'../saved_models/{args.model}-{args.dataname}-{dist.get_world_size()}.pth'
......@@ -476,7 +476,7 @@ def main():
t1 = time_count.start_gpu()
if mailbox is not None:
if(graph.efeat.device.type != 'cpu'):
edge_feats = graph.get_local_efeat(graph.eids_mapper[roots.eids.to('cpu')]).to('cuda',non_blocking=True)
edge_feats = graph.get_local_efeat(graph.eids_mapper[roots.eids.to('cpu')]).to('cuda')#,non_blocking=True)
else:
edge_feats = graph.get_local_efeat(graph.eids_mapper[roots.eids.to('cpu')])
src = metadata['src_pos_index']
......@@ -509,7 +509,7 @@ def main():
edge_feat[1].wait()
node_feat0[1].wait()
if ada_param is not None:
if ada_param is not None and e < 3:
ada_param.update_fetch_time(ada_param.last_start_event_fetch)
ada_param.update_parameter()
......@@ -597,6 +597,7 @@ def main():
else:
print('\ttrain loss:{:.4f} train ap:{:4f} val ap:{:4f} val auc:{:4f} test ap {:4f} test auc{:4f}\n'.format(total_loss,train_ap, ap, auc,test_ap,test_auc))
print('\ttotal time:{:.2f}s prep time:{:.2f}s\n test time {:.2f}'.format(time.time()-epoch_start_time, time_prep,t_test))
print('alpha is {} beta is {}'.format(ada_param.alpha,ada_param.beta))
torch.save(model.module.state_dict(), get_checkpoint_path(e))
if args.model == 'TGN':
pass
......
......@@ -21,7 +21,7 @@ class Filter(nn.Module):
"""
# Treat filter as parameter so that it is saved and loaded together with the model
self.count = torch.zeros((self.n_nodes),1).to(self.device)
self.incretment = torch.zeros((self.n_nodes, self.memory_dimension),dtype=torch.float32).to(self.device)
self.incretment = torch.zeros((self.n_nodes, self.memory_dimension)).to(self.device)
def get_count(self, node_idxs):
......
......@@ -483,11 +483,13 @@ class AsyncMemeoryUpdater(torch.nn.Module):
transition_dense -= transition_dense.min()
transition_dense /= transition_dense.max()
transition_dense = 2*transition_dense - 1
upd0[mask] = transition_dense.to(upd0.dtype)#b.srcdata['his_mem'][mask] + transition_dense
updated_memory = torch.where(mask.unsqueeze(1),torch.sigmoid(self.gamma)*updated_memory0 + (1-torch.sigmoid(self.gamma))*(upd0),updated_memory0)
upd0[mask] = transition_dense#b.srcdata['his_mem'][mask] + transition_dense
#updated_memory = torch.where(mask.unsqueeze(1),torch.sigmoid(self.gamma)*updated_memory0 + (1-torch.sigmoid(self.gamma))*(upd0),updated_memory0)
#updated_memory = torch.where(mask.unsqueeze(1),self.gamma*updated_memory0 + (1-self.gamma)*(upd0),updated_memory0)
else:
#upd0[mask] = updated_memory0[mask]
updated_memory = updated_memory0
upd0[mask] = updated_memory0[mask]
#updated_memory = updated_memory0
updated_memory = torch.where(mask.unsqueeze(1),torch.sigmoid(self.gamma)*updated_memory0 + (1-torch.sigmoid(self.gamma))*(upd0),updated_memory0)
with torch.no_grad():
if self.mode == 'historical':
change = updated_memory[mask] - b.srcdata['his_mem'][mask]
......
......@@ -72,6 +72,7 @@ class LocalNegativeSampling(NegativeSampling):
self.train_ratio_pos = 1.0/(1-prob+ prob * remote_ratio) if ((prob<1) & (prob > 0)) else 1
self.train_ratio_neg = 1.0/(prob*remote_ratio) if ((prob <1) & (prob > 0)) else 1
p = torch.rand(size=(num_samples,))
#print(prob,remote_ratio,self.train_ratio_pos,self.train_ratio_neg)
sr = self.dst_node_list[torch.randint(len(self.dst_node_list), (num_samples, ),generator=self.rdm)]
sl = self.local_dst[torch.randint(len(self.local_dst), (num_samples, ),generator=self.rdm)]
s=torch.where(p<=prob,sr,sl)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment