Commit ff19c482 by zlj

test

parent d1128852
......@@ -111,6 +111,7 @@ if not 'MASTER_ADDR' in os.environ:
os.environ["MASTER_ADDR"] = '192.168.2.107'
if not 'MASTER_PORT' in os.environ:
os.environ["MASTER_PORT"] = '9337'
os.environ["NCCL_IB_DISABLE"]='1'
os.environ['NCCL_SOCKET_IFNAME']=matching_interfaces[0]
print('rank {}'.format(int(os.environ["LOCAL_RANK"])))
......@@ -262,7 +263,7 @@ def main():
train_neg_sampler = LocalNegativeSampling('triplet',amount = args.neg_samples,dst_node_list = graph.edge_index[1,mask].unique())
else:
#train_neg_sampler = LocalNegativeSampling('triplet',amount = args.neg_samples,dst_node_list = full_dst.unique())
train_neg_sampler = LocalNegativeSampling('triplet',amount = args.neg_samples,dst_node_list = full_dst.unique(),local_mask=(DistIndex(graph.nids_mapper[full_dst.unique()].to('cpu')).part == dist.get_rank()))
train_neg_sampler = LocalNegativeSampling('triplet',amount = args.neg_samples,dst_node_list = full_dst.unique(),local_mask=(DistIndex(graph.nids_mapper[full_dst.unique()].to('cpu')).part == dist.get_rank()),prob=args.probability)
print(train_neg_sampler.dst_node_list)
neg_sampler = LocalNegativeSampling('triplet',amount= neg_samples,dst_node_list = full_dst.unique(),seed=6773)
......@@ -279,7 +280,7 @@ def main():
is_pipeline=True,
use_local_feature = False,
device = torch.device('cuda:{}'.format(local_rank)),
probability=0.1,#train_cross_probability,
probability=args.probability,
reversed = (gnn_param['arch'] == 'identity')
)
......@@ -457,12 +458,11 @@ def main():
optimizer = torch.optim.Adam(model.parameters(), lr=train_param['lr'],weight_decay=1e-4)
early_stopper = EarlyStopMonitor(max_round=args.patience)
MODEL_SAVE_PATH = f'../saved_models/{args.model}-{args.dataname}-{dist.get_world_size()}.pth'
total_test_time = 0
epoch_cnt = 0
test_ap_list = []
val_list = []
loss_list = []
def fetch_async():
trainloader.async_feature()
for e in range(train_param['epoch']):
model.module.memory_updater.empty_cache()
tt._zero()
......@@ -616,9 +616,15 @@ def main():
tt.weight_count_remote=0
tt.ssim_cnt=0
ap, auc = eval('val')
torch.cuda.synchronize()
t_test = time.time()
test_ap,test_auc = eval('test')
torch.cuda.synchronize()
t_test = time.time() - t_test
total_test_time += t_test
test_ap_list.append((test_ap,test_auc))
early_stop = early_stopper.early_stop_check(ap)
early_stopper.early_stop_check(ap)
early_stop = False
trainloader.local_node = 0
trainloader.remote_node = 0
trainloader.local_edge = 0
......@@ -647,7 +653,7 @@ def main():
break
else:
print('\ttrain loss:{:.4f} train ap:{:4f} val ap:{:4f} val auc:{:4f} test ap {:4f} test auc{:4f}\n'.format(total_loss,train_ap, ap, auc,test_ap,test_auc))
print('\ttotal time:{:.2f}s prep time:{:.2f}s\n'.format(time.time()-epoch_start_time, time_prep))
# The template has three replacement fields, so t_test must be an argument to
# format() (not a second argument to print()); otherwise format() raises IndexError.
print('\ttotal time:{:.2f}s prep time:{:.2f}s\n test time {:.2f}'.format(time.time()-epoch_start_time, time_prep, t_test))
torch.save(model.module.state_dict(), get_checkpoint_path(e))
if args.model == 'TGN':
print('weight {} {}\n'.format(tt.weight_count_local,tt.weight_count_remote))
......@@ -665,7 +671,7 @@ def main():
print('best test AP:{:4f} test auc{:4f}'.format(*test_ap_list[early_stopper.best_epoch]))
val_list = torch.tensor(val_list)
loss_list = torch.tensor(loss_list)
print('test_dataset {} avg_time {} \n'.format(test_data.edges.shape[1],avg_time/epoch_cnt))
print('test_dataset {} avg_time {} test time {}\n'.format(test_data.edges.shape[1],avg_time/epoch_cnt,total_test_time/epoch_cnt))
torch.save(model.module.state_dict(), MODEL_SAVE_PATH)
ctx.shutdown()
......
......@@ -464,8 +464,8 @@ class AsyncMemeoryUpdater(torch.nn.Module):
change = 2*change - 1
self.filter.update(shared_ind,change)
#print(transition_dense)
print(torch.cosine_similarity(updated_memory[mask],b.srcdata['his_mem'][mask]).sum()/torch.sum(mask))
print(self.gamma)
#print(torch.cosine_similarity(updated_memory[mask],b.srcdata['his_mem'][mask]).sum()/torch.sum(mask))
#print(self.gamma)
self.pre_mem = b.srcdata['his_mem']
self.last_updated_ts = b.srcdata['ts'].detach().clone()
self.last_updated_memory = updated_memory.detach().clone()
......
......@@ -21,7 +21,8 @@ class LocalNegativeSampling(NegativeSampling):
src_node_list: torch.Tensor = None,
dst_node_list: torch.Tensor = None,
local_mask = None,
seed = None
seed = None,
prob = None
):
super(LocalNegativeSampling,self).__init__(mode,amount,unique=unique)
self.src_node_list = src_node_list.to('cpu') if src_node_list is not None else None
......@@ -37,6 +38,7 @@ class LocalNegativeSampling(NegativeSampling):
self.local_mask = local_mask
if self.local_mask is not None:
self.local_dst = dst_node_list[local_mask]
self.prob = prob
#self.rdm.manual_seed(42)
#print('dst_nde_list {}\n'.format(dst_node_list))
def is_binary(self) -> bool:
......@@ -61,7 +63,7 @@ class LocalNegativeSampling(NegativeSampling):
p = torch.rand(size=(num_samples,))
sr = self.dst_node_list[torch.randint(len(self.dst_node_list), (num_samples, ),generator=self.rdm)]
sl = self.local_dst[torch.randint(len(self.local_dst), (num_samples, ),generator=self.rdm)]
s=torch.where(p<=1,sl,sr)
s=torch.where(p<=self.prob,sl,sr)
return sr
else:
s = torch.randint(len(self.dst_node_list), (num_samples, ),generator=self.rdm)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment