Commit ff19c482 by zlj

test

parent d1128852
...@@ -111,6 +111,7 @@ if not 'MASTER_ADDR' in os.environ: ...@@ -111,6 +111,7 @@ if not 'MASTER_ADDR' in os.environ:
os.environ["MASTER_ADDR"] = '192.168.2.107' os.environ["MASTER_ADDR"] = '192.168.2.107'
if not 'MASTER_PORT' in os.environ: if not 'MASTER_PORT' in os.environ:
os.environ["MASTER_PORT"] = '9337' os.environ["MASTER_PORT"] = '9337'
os.environ["NCCL_IB_DISABLE"]='1' os.environ["NCCL_IB_DISABLE"]='1'
os.environ['NCCL_SOCKET_IFNAME']=matching_interfaces[0] os.environ['NCCL_SOCKET_IFNAME']=matching_interfaces[0]
print('rank {}'.format(int(os.environ["LOCAL_RANK"]))) print('rank {}'.format(int(os.environ["LOCAL_RANK"])))
...@@ -262,7 +263,7 @@ def main(): ...@@ -262,7 +263,7 @@ def main():
train_neg_sampler = LocalNegativeSampling('triplet',amount = args.neg_samples,dst_node_list = graph.edge_index[1,mask].unique()) train_neg_sampler = LocalNegativeSampling('triplet',amount = args.neg_samples,dst_node_list = graph.edge_index[1,mask].unique())
else: else:
#train_neg_sampler = LocalNegativeSampling('triplet',amount = args.neg_samples,dst_node_list = full_dst.unique()) #train_neg_sampler = LocalNegativeSampling('triplet',amount = args.neg_samples,dst_node_list = full_dst.unique())
train_neg_sampler = LocalNegativeSampling('triplet',amount = args.neg_samples,dst_node_list = full_dst.unique(),local_mask=(DistIndex(graph.nids_mapper[full_dst.unique()].to('cpu')).part == dist.get_rank())) train_neg_sampler = LocalNegativeSampling('triplet',amount = args.neg_samples,dst_node_list = full_dst.unique(),local_mask=(DistIndex(graph.nids_mapper[full_dst.unique()].to('cpu')).part == dist.get_rank()),prob=args.probability)
print(train_neg_sampler.dst_node_list) print(train_neg_sampler.dst_node_list)
neg_sampler = LocalNegativeSampling('triplet',amount= neg_samples,dst_node_list = full_dst.unique(),seed=6773) neg_sampler = LocalNegativeSampling('triplet',amount= neg_samples,dst_node_list = full_dst.unique(),seed=6773)
...@@ -279,7 +280,7 @@ def main(): ...@@ -279,7 +280,7 @@ def main():
is_pipeline=True, is_pipeline=True,
use_local_feature = False, use_local_feature = False,
device = torch.device('cuda:{}'.format(local_rank)), device = torch.device('cuda:{}'.format(local_rank)),
probability=0.1,#train_cross_probability, probability=args.probability,
reversed = (gnn_param['arch'] == 'identity') reversed = (gnn_param['arch'] == 'identity')
) )
...@@ -457,12 +458,11 @@ def main(): ...@@ -457,12 +458,11 @@ def main():
optimizer = torch.optim.Adam(model.parameters(), lr=train_param['lr'],weight_decay=1e-4) optimizer = torch.optim.Adam(model.parameters(), lr=train_param['lr'],weight_decay=1e-4)
early_stopper = EarlyStopMonitor(max_round=args.patience) early_stopper = EarlyStopMonitor(max_round=args.patience)
MODEL_SAVE_PATH = f'../saved_models/{args.model}-{args.dataname}-{dist.get_world_size()}.pth' MODEL_SAVE_PATH = f'../saved_models/{args.model}-{args.dataname}-{dist.get_world_size()}.pth'
total_test_time = 0
epoch_cnt = 0 epoch_cnt = 0
test_ap_list = [] test_ap_list = []
val_list = [] val_list = []
loss_list = [] loss_list = []
def fetch_async():
trainloader.async_feature()
for e in range(train_param['epoch']): for e in range(train_param['epoch']):
model.module.memory_updater.empty_cache() model.module.memory_updater.empty_cache()
tt._zero() tt._zero()
...@@ -616,9 +616,15 @@ def main(): ...@@ -616,9 +616,15 @@ def main():
tt.weight_count_remote=0 tt.weight_count_remote=0
tt.ssim_cnt=0 tt.ssim_cnt=0
ap, auc = eval('val') ap, auc = eval('val')
torch.cuda.synchronize()
t_test = time.time()
test_ap,test_auc = eval('test') test_ap,test_auc = eval('test')
torch.cuda.synchronize()
t_test = time.time() - t_test
total_test_time += t_test
test_ap_list.append((test_ap,test_auc)) test_ap_list.append((test_ap,test_auc))
early_stop = early_stopper.early_stop_check(ap) early_stopper.early_stop_check(ap)
early_stop = False
trainloader.local_node = 0 trainloader.local_node = 0
trainloader.remote_node = 0 trainloader.remote_node = 0
trainloader.local_edge = 0 trainloader.local_edge = 0
...@@ -647,7 +653,7 @@ def main(): ...@@ -647,7 +653,7 @@ def main():
break break
else: else:
print('\ttrain loss:{:.4f} train ap:{:4f} val ap:{:4f} val auc:{:4f} test ap {:4f} test auc{:4f}\n'.format(total_loss,train_ap, ap, auc,test_ap,test_auc)) print('\ttrain loss:{:.4f} train ap:{:4f} val ap:{:4f} val auc:{:4f} test ap {:4f} test auc{:4f}\n'.format(total_loss,train_ap, ap, auc,test_ap,test_auc))
print('\ttotal time:{:.2f}s prep time:{:.2f}s\n'.format(time.time()-epoch_start_time, time_prep)) print('\ttotal time:{:.2f}s prep time:{:.2f}s\n test time {:.2f}'.format(time.time()-epoch_start_time, time_prep),t_test)
torch.save(model.module.state_dict(), get_checkpoint_path(e)) torch.save(model.module.state_dict(), get_checkpoint_path(e))
if args.model == 'TGN': if args.model == 'TGN':
print('weight {} {}\n'.format(tt.weight_count_local,tt.weight_count_remote)) print('weight {} {}\n'.format(tt.weight_count_local,tt.weight_count_remote))
...@@ -665,7 +671,7 @@ def main(): ...@@ -665,7 +671,7 @@ def main():
print('best test AP:{:4f} test auc{:4f}'.format(*test_ap_list[early_stopper.best_epoch])) print('best test AP:{:4f} test auc{:4f}'.format(*test_ap_list[early_stopper.best_epoch]))
val_list = torch.tensor(val_list) val_list = torch.tensor(val_list)
loss_list = torch.tensor(loss_list) loss_list = torch.tensor(loss_list)
print('test_dataset {} avg_time {} \n'.format(test_data.edges.shape[1],avg_time/epoch_cnt)) print('test_dataset {} avg_time {} test time {}\n'.format(test_data.edges.shape[1],avg_time/epoch_cnt,total_test_time/epoch_cnt))
torch.save(model.module.state_dict(), MODEL_SAVE_PATH) torch.save(model.module.state_dict(), MODEL_SAVE_PATH)
ctx.shutdown() ctx.shutdown()
......
...@@ -464,8 +464,8 @@ class AsyncMemeoryUpdater(torch.nn.Module): ...@@ -464,8 +464,8 @@ class AsyncMemeoryUpdater(torch.nn.Module):
change = 2*change - 1 change = 2*change - 1
self.filter.update(shared_ind,change) self.filter.update(shared_ind,change)
#print(transition_dense) #print(transition_dense)
print(torch.cosine_similarity(updated_memory[mask],b.srcdata['his_mem'][mask]).sum()/torch.sum(mask)) #print(torch.cosine_similarity(updated_memory[mask],b.srcdata['his_mem'][mask]).sum()/torch.sum(mask))
print(self.gamma) #print(self.gamma)
self.pre_mem = b.srcdata['his_mem'] self.pre_mem = b.srcdata['his_mem']
self.last_updated_ts = b.srcdata['ts'].detach().clone() self.last_updated_ts = b.srcdata['ts'].detach().clone()
self.last_updated_memory = updated_memory.detach().clone() self.last_updated_memory = updated_memory.detach().clone()
......
...@@ -21,7 +21,8 @@ class LocalNegativeSampling(NegativeSampling): ...@@ -21,7 +21,8 @@ class LocalNegativeSampling(NegativeSampling):
src_node_list: torch.Tensor = None, src_node_list: torch.Tensor = None,
dst_node_list: torch.Tensor = None, dst_node_list: torch.Tensor = None,
local_mask = None, local_mask = None,
seed = None seed = None,
prob = None
): ):
super(LocalNegativeSampling,self).__init__(mode,amount,unique=unique) super(LocalNegativeSampling,self).__init__(mode,amount,unique=unique)
self.src_node_list = src_node_list.to('cpu') if src_node_list is not None else None self.src_node_list = src_node_list.to('cpu') if src_node_list is not None else None
...@@ -37,6 +38,7 @@ class LocalNegativeSampling(NegativeSampling): ...@@ -37,6 +38,7 @@ class LocalNegativeSampling(NegativeSampling):
self.local_mask = local_mask self.local_mask = local_mask
if self.local_mask is not None: if self.local_mask is not None:
self.local_dst = dst_node_list[local_mask] self.local_dst = dst_node_list[local_mask]
self.prob = prob
#self.rdm.manual_seed(42) #self.rdm.manual_seed(42)
#print('dst_nde_list {}\n'.format(dst_node_list)) #print('dst_nde_list {}\n'.format(dst_node_list))
def is_binary(self) -> bool: def is_binary(self) -> bool:
...@@ -61,7 +63,7 @@ class LocalNegativeSampling(NegativeSampling): ...@@ -61,7 +63,7 @@ class LocalNegativeSampling(NegativeSampling):
p = torch.rand(size=(num_samples,)) p = torch.rand(size=(num_samples,))
sr = self.dst_node_list[torch.randint(len(self.dst_node_list), (num_samples, ),generator=self.rdm)] sr = self.dst_node_list[torch.randint(len(self.dst_node_list), (num_samples, ),generator=self.rdm)]
sl = self.local_dst[torch.randint(len(self.local_dst), (num_samples, ),generator=self.rdm)] sl = self.local_dst[torch.randint(len(self.local_dst), (num_samples, ),generator=self.rdm)]
s=torch.where(p<=1,sl,sr) s=torch.where(p<=self.prob,sl,sr)
return sr return sr
else: else:
s = torch.randint(len(self.dst_node_list), (num_samples, ),generator=self.rdm) s = torch.randint(len(self.dst_node_list), (num_samples, ),generator=self.rdm)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment