Commit ab4e56d0 by xxx


parent ef367556
sampling:
  - layer: 1
    neighbor:
      - 20
    strategy: 'recent'
    history: 1
    no_neg: True
memory:
  - type: 'node'
    dim_time: 100
    deliver_to: 'neighbors'
    mail_combine: 'last'
    memory_update: 'transformer'
    historical_fix: False
    async: True
    attention_head: 2
    mailbox_size: 10
    combine_node_feature: False
    dim_out: 100
gnn:
  - arch: 'identity'
train:
  - epoch: 50
    batch_size: 3000
    lr: 0.0002
    dropout: 0.1
    att_dropout: 0.1
    # all_on_gpu: True
\ No newline at end of file
sampling:
  - strategy: 'identity'
    history: 1
memory:
  - type: 'node'
    dim_time: 100
    deliver_to: 'self'
    mail_combine: 'last'
    memory_update: 'rnn'
    historical_fix: False
    async: True
    mailbox_size: 1
    combine_node_feature: True
    dim_out: 100
gnn:
  - arch: 'identity'
    use_src_emb: False
    use_dst_emb: False
    time_transform: 'JODIE'
train:
  - epoch: 50
    batch_size: 3000
    lr: 0.0002
    dropout: 0.1
    all_on_gpu: True
\ No newline at end of file
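Both config files follow the same four-section layout (sampling / memory / gnn / train), each section holding a one-element list of parameter dicts. A minimal sketch of loading one of them, assuming PyYAML is available; 'config/example.yml' is a hypothetical placeholder path:

import yaml

with open('config/example.yml', 'r') as f:
    cfg = yaml.safe_load(f)

# Each top-level key maps to a one-element list of parameter dicts.
sample_param = cfg['sampling'][0]
memory_param = cfg['memory'][0]
train_param = cfg['train'][0]
print(memory_param['memory_update'], train_param['lr'])  # e.g. 'rnn' 0.0002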
@@ -504,7 +504,9 @@ def main():
 local_edge_comm = []
 remote_edge_comm = []
 b_cnt = 0
+start = time_count.start_gpu()
 for roots,mfgs,metadata in trainloader:
+    t1 = time_count.elapsed_event(start)
     #print('rank is {} batch max ts is {} batch min ts is {}'.format(dist.get_rank(),roots.ts.min(),roots.ts.max()))
     b_cnt = b_cnt + 1
     #local_access.append(trainloader.local_node)
@@ -557,6 +559,7 @@ def main():
 #torch.cuda.synchronize()
 mailbox.update_shared()
 mailbox.update_p2p()
+start = time_count.start_gpu()
 #torch.cuda.empty_cache()
 """
...
import matplotlib.pyplot as plt
import numpy as np
import torch
# Read the result files
probability_values = [0.1]#[0.1,0.05,0.01,0]
data_values = ['WikiTalk'] # datasets read from the result files
seed = ['12357']#,'12347','63377','53473','54763']
partition = 'ours_shared'
# Read the data from the output files; the path pattern is:
#all/"$data"/"$partitions"-ours_shared-0.01-"$mem"-"$ssim"-"$sample".out
partitions=4
topk=0.01
mem='historical-0.3'#'historical'
model0='APAN'
def average(l):
    return sum(l) / len(l)

for data in data_values:
    ap_list = []
    comm_list = []
    for sd in seed:
        for p in probability_values:
            if data == 'WIKI' or data == 'LASTFM':
                model = model0
            else:
                model = model0 + '_large'
            if p == 1:
                file = 'all_{}/{}/{}/{}-{}-{}-{}-recent.out'.format(sd, data, model, partitions, partition, topk, mem)
            else:
                file = 'all_{}/{}/{}/{}-{}-{}-{}-boundery_recent_decay-{}.out'.format(sd, data, model, partitions, partition, topk, mem, p)
            prefix = "val ap:"
            max_val_ap = 0
            test_ap = 0
            # Use a distinct handle name so the `file` path variable is not shadowed.
            with open(file, 'r') as f:
                for line in f:
                    if line.find('Epoch 50:') != -1:
                        break
                    if line.find(prefix) != -1:
                        pos = line.find(prefix) + len(prefix)
                        posr = line.find(' ', pos)
                        #print(line[pos:posr])
                        val_ap = float(line[pos:posr])
                        pos = line.find("test ap ") + len("test ap ")
                        posr = line.find(' ', pos)
                        #print(line[pos:posr])
                        _test_ap = float(line[pos:posr])
                        # Keep the test AP recorded at the best validation AP so far.
                        if val_ap > max_val_ap:
                            max_val_ap = val_ap
                            test_ap = _test_ap
            ap_list.append(test_ap)
    print('data {} model {} ap: {}'.format(data, model, ap_list))
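For reference, the parser above assumes each epoch's log line carries both metrics, space-separated, on one line; a hypothetical line it would accept (the actual log format is not shown in this diff):

val ap:0.9812 test ap 0.9743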
@@ -472,7 +472,7 @@ def main():
     cos = torch.nn.CosineSimilarity(dim=0)
     return cos(normalize(x1),normalize(x2)).sum()/x1.size(dim=0)
 creterion = torch.nn.BCEWithLogitsLoss()
-optimizer = torch.optim.Adam(model.parameters(), lr=train_param['lr'],weight_decay=1e-4)
+optimizer = torch.optim.Adam(model.parameters(), lr=train_param['lr'])#,weight_decay=1e-4)
 early_stopper = EarlyStopMonitor(max_round=args.patience)
 MODEL_SAVE_PATH = f'../saved_models/{args.model}-{args.dataname}-{dist.get_world_size()}.pth'
 total_test_time = 0
@@ -481,7 +481,7 @@ def main():
 val_list = []
 loss_list = []
 for e in range(train_param['epoch']):
-    # model.module.memory_updater.empty_cache()
+    model.module.memory_updater.empty_cache()
     tt._zero()
     torch.cuda.synchronize()
     epoch_start_time = time.time()
@@ -509,7 +509,10 @@ def main():
 local_edge_comm = []
 remote_edge_comm = []
 b_cnt = 0
+start = time_count.start_gpu()
 for roots,mfgs,metadata in trainloader:
+    end = time_count.elapsed_event(start)
+    #print('time {}'.format(end))
     #print('rank is {} batch max ts is {} batch min ts is {}'.format(dist.get_rank(),roots.ts.min(),roots.ts.max()))
     b_cnt = b_cnt + 1
     #local_access.append(trainloader.local_node)
@@ -525,7 +528,7 @@ def main():
 # sum_remote_edge_comm +=remote_edge_comm[b_cnt-1]
 #sum_local_comm +=local_comm[b_cnt-1]
 #sum_remote_comm +=remote_comm[b_cnt-1]
+t1 = time_count.start_gpu()
 if mailbox is not None:
     if(graph.efeat.device.type != 'cpu'):
         edge_feats = graph.get_local_efeat(graph.eids_mapper[roots.eids.to('cpu')]).to('cuda')
@@ -539,12 +542,14 @@ def main():
         param = (update_mail,src,dst,ts,edge_feats,trainloader.async_feature)
     else:
         param = None
+#print(time_count.elapsed_event(t1))
 model.train()
+t2 = time_count.start_gpu()
 optimizer.zero_grad()
 ones = torch.ones(metadata['dst_neg_index'].shape[0],device = model.device,dtype=torch.float)
 weight = torch.where(DistIndex(mfgs[0][0].srcdata['ID'][metadata['dst_neg_index']]).part == torch.distributed.get_rank(),ones*train_ratio_pos,ones*train_ratio_neg).reshape(-1,1)
 pred_pos, pred_neg = model(mfgs,metadata,neg_samples=args.neg_samples,async_param = param)
+#print(time_count.elapsed_event(t2))
 loss = creterion(pred_pos, torch.ones_like(pred_pos))
 neg_creterion = torch.nn.BCEWithLogitsLoss(weight)
 loss += neg_creterion(pred_neg, torch.zeros_like(pred_neg))
@@ -563,6 +568,7 @@ def main():
 mailbox.update_shared()
 mailbox.update_p2p_mem()
 mailbox.update_p2p_mail()
+start = time_count.start_gpu()
 #torch.cuda.empty_cache()
 """
...
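The loss in the hunk above combines plain BCE on the positive predictions with a per-sample-weighted BCE on the negatives (the weights rebalance locally vs. remotely partitioned negatives). A toy, self-contained sketch of that pattern with made-up tensors:

import torch

pred_pos = torch.randn(8, 1, requires_grad=True)  # logits for positive edges
pred_neg = torch.randn(8, 1, requires_grad=True)  # logits for sampled negatives
weight = torch.full((8, 1), 0.5)  # stand-in for the train_ratio_pos/neg reweighting

criterion = torch.nn.BCEWithLogitsLoss()
loss = criterion(pred_pos, torch.ones_like(pred_pos))
neg_criterion = torch.nn.BCEWithLogitsLoss(weight=weight)  # per-element weights
loss = loss + neg_criterion(pred_neg, torch.zeros_like(pred_neg))
loss.backward()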
@@ -33,21 +33,23 @@ class time_count:
     def start_gpu():
         # Uncomment for better breakdown timings
         #torch.cuda.synchronize()
-        start_event = torch.cuda.Event(enable_timing=True)
-        end_event = torch.cuda.Event(enable_timing=True)
-        start_event.record()
-        return start_event,end_event
+        #start_event = torch.cuda.Event(enable_timing=True)
+        #end_event = torch.cuda.Event(enable_timing=True)
+        #start_event.record()
+        #return start_event,end_event
+        return 0,0
     @staticmethod
     def start():
-        return time.perf_counter(),0
+        # return time.perf_counter(),0
+        return 0,0
     @staticmethod
     def elapsed_event(start_event):
-        #if isinstance(start_event,tuple):
-        #    start_event,end_event = start_event
-        #    end_event.record()
-        #    end_event.synchronize()
-        #    return start_event.elapsed_time(end_event)
-        #else:
-        #    torch.cuda.synchronize()
-        #    return time.perf_counter() - start_event
+        # if isinstance(start_event,tuple):
+        #     start_event,end_event = start_event
+        #     end_event.record()
+        #     end_event.synchronize()
+        #     return start_event.elapsed_time(end_event)
+        # else:
+        #     torch.cuda.synchronize()
+        #     return time.perf_counter() - start_event
         return 0
...
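For context, the branch now commented out in start_gpu/elapsed_event implemented the standard torch.cuda.Event timing pattern; with everything stubbed to return 0, the timing hooks in main() become no-ops. A sketch of that standard pattern (plain PyTorch API, not this repo's exact code):

import torch

start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)

start_event.record()             # mark the start on the current CUDA stream
# ... GPU work to be timed ...
end_event.record()               # mark the end
end_event.synchronize()          # wait until the recorded events complete
elapsed_ms = start_event.elapsed_time(end_event)  # elapsed time in milliseconds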
@@ -187,13 +187,16 @@ class SharedMailBox():
 #print('root is {} {} {} {}\n'.format(root,root.shape,root.max(),block.edges()[0].shape))
 #pos_index = torch.arange(root.shape[0],device=root.device,dtype=root.dtype)
 pos,idx = torch_scatter.scatter_max(mail_ts,root,0)
-mail = torch.cat([mail, mail[idx]],dim=0)
-mail_ts = torch.cat([mail_ts, mail_ts[idx]], dim=0)
+#print(block.number_of_edges())
+mail = torch.cat([mail, mail[idx[block.edges()[0].long()]]],dim=0)
+mail_ts = torch.cat([mail_ts, mail_ts[idx[block.edges()[0].long()]]], dim=0)
 #print('pos is {} {}\n'.format(pos,block.edges()[0].long()))
 #mail = torch.cat([mail, mail[pos[block.edges()[0].long()]]],dim=0)
 #mail_ts = torch.cat([mail_ts, mail_ts[pos[block.edges()[0].long()]]], dim=0)
 #print(root,block.edges()[1].long())
 index = torch.cat([index,block.dstdata['ID'][block.edges()[1].long()]],dim=0)
+#print(index)
 #mail = torch.cat([mail, mail[block.edges()[0].long()]], dim=0)
 #mail_ts = torch.cat([mail_ts, mail_ts[block.edges()[0].long()]], dim=0)
 #index = torch.cat([index,block.dstdata['ID'][block.edges()[1].long()]],dim=0)
...
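A toy sketch of the scatter_max call at the heart of this hunk: for each destination node it returns both the max mail timestamp and the row index of the mail that achieved it, so mail[idx] gathers the most recent mail per node (assumes the torch_scatter package; tensors are made up):

import torch
import torch_scatter

mail_ts = torch.tensor([1.0, 5.0, 3.0, 2.0])  # timestamp of each mail
root = torch.tensor([0, 0, 1, 1])             # destination node of each mail
pos, idx = torch_scatter.scatter_max(mail_ts, root, 0)
print(pos)  # tensor([5., 3.]) -> latest timestamp per node
print(idx)  # tensor([1, 2])   -> row of the latest mail; mail[idx] selects it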