Commit ab4e56d0 by xxx

1 parent ef367556
sampling:
  - layer: 1
    neighbor:
      - 20
    strategy: 'recent'
    history: 1
    no_neg: True
memory:
  - type: 'node'
    dim_time: 100
    deliver_to: 'neighbors'
    mail_combine: 'last'
    memory_update: 'transformer'
    historical_fix: False
    async: True
    attention_head: 2
    mailbox_size: 10
    combine_node_feature: False
    dim_out: 100
gnn:
  - arch: 'identity'
train:
  - epoch: 50
    batch_size: 3000
    lr: 0.0002
    dropout: 0.1
    att_dropout: 0.1
    # all_on_gpu: True
\ No newline at end of file
sampling:
  - strategy: 'identity'
    history: 1
memory:
  - type: 'node'
    dim_time: 100
    deliver_to: 'self'
    mail_combine: 'last'
    memory_update: 'rnn'
    historical_fix: False
    async: True
    mailbox_size: 1
    combine_node_feature: True
    dim_out: 100
gnn:
  - arch: 'identity'
    use_src_emb: False
    use_dst_emb: False
    time_transform: 'JODIE'
train:
  - epoch: 50
    batch_size: 3000
    lr: 0.0002
    dropout: 0.1
    all_on_gpu: True
\ No newline at end of file
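These two blocks are TGL-style model configs: the first is APAN-like (mail delivered to neighbors, transformer memory updater, mailbox size 10), the second JODIE-like (identity sampler, RNN memory updater, JODIE time transform). As a minimal sketch of how such a file could be consumed, assuming PyYAML and the section layout shown above (the helper name load_config is illustrative, not from the repository):

import yaml  # assumes PyYAML is installed

def load_config(path):
    # Each top-level section ('sampling', 'memory', 'gnn', 'train') holds a
    # one-element list of parameter dicts in this config layout.
    with open(path, 'r') as f:
        cfg = yaml.safe_load(f)
    return cfg['sampling'][0], cfg['memory'][0], cfg['gnn'][0], cfg['train'][0]

# Hypothetical usage:
# sample_param, memory_param, gnn_param, train_param = load_config('APAN_large.yml')
# print(train_param['lr'], memory_param['memory_update'])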
@@ -504,7 +504,9 @@ def main():
local_edge_comm = []
remote_edge_comm = []
b_cnt = 0
start = time_count.start_gpu()
for roots,mfgs,metadata in trainloader:
t1 = time_count.elapsed_event(start)
#print('rank is {} batch max ts is {} batch min ts is {}'.format(dist.get_rank(),roots.ts.min(),roots.ts.max()))
b_cnt = b_cnt + 1
#local_access.append(trainloader.local_node)
@@ -557,6 +559,7 @@ def main():
#torch.cuda.synchronize()
mailbox.update_shared()
mailbox.update_p2p()
start = time_count.start_gpu()
#torch.cuda.empty_cache()
"""
......
import matplotlib.pyplot as plt
import numpy as np
import torch

# Read the result files and collect test AP values.
probability_values = [0.1]  # [0.1,0.05,0.01,0]
data_values = ['WikiTalk']  # datasets whose logs are read
seed = ['12357']  # ,'12347','63377','53473','54763']
partition = 'ours_shared'
# Log files follow the naming used by the launcher script, e.g.
# all/"$data"/"$partitions"-ours_shared-0.01-"$mem"-"$ssim"-"$sample".out
partitions = 4
topk = 0.01
mem = 'historical-0.3'  # 'historical'
model0 = 'APAN'

def average(l):
    return sum(l) / len(l)

for data in data_values:
    ap_list = []
    comm_list = []
    for sd in seed:
        for p in probability_values:
            if data == 'WIKI' or data == 'LASTFM':
                model = model0
            else:
                model = model0 + '_large'
            if p == 1:
                file = 'all_{}/{}/{}/{}-{}-{}-{}-recent.out'.format(sd, data, model, partitions, partition, topk, mem)
            else:
                file = 'all_{}/{}/{}/{}-{}-{}-{}-boundery_recent_decay-{}.out'.format(sd, data, model, partitions, partition, topk, mem, p)
            prefix = "val ap:"
            max_val_ap = 0
            test_ap = 0
            # Keep the test AP of the epoch with the highest validation AP.
            with open(file, 'r') as f:
                for line in f:
                    if line.find('Epoch 50:') != -1:
                        break
                    if line.find(prefix) != -1:
                        pos = line.find(prefix) + len(prefix)
                        posr = line.find(' ', pos)
                        #print(line[pos:posr])
                        val_ap = float(line[pos:posr])
                        pos = line.find("test ap ") + len("test ap ")
                        posr = line.find(' ', pos)
                        #print(line[pos:posr])
                        _test_ap = float(line[pos:posr])
                        if val_ap > max_val_ap:
                            max_val_ap = val_ap
                            test_ap = _test_ap
            ap_list.append(test_ap)
    print('data {} model {} ap: {}'.format(data, model, ap_list))
@@ -2,7 +2,7 @@
# runs TaoBao on 4 GPUs
# define the array variables
seed=$1
-addr="192.168.1.105"
+addr="192.168.1.106"
partition_params=("ours" )
#"metis" "ldg" "random")
#("ours" "metis" "ldg" "random")
@@ -20,6 +20,7 @@ memory_type=("historical")
shared_memory_ssim=("0.3")
#data_param=("WIKI" "REDDIT" "LASTFM" "WikiTalk")
data_param=("WIKI" "LASTFM" "WikiTalk" "StackOverflow" "GDELT")
#"GDELT")
#data_param=("WIKI" "REDDIT" "LASTFM" "DGraphFin" "WikiTalk" "StackOverflow") #data_param=("WIKI" "REDDIT" "LASTFM" "DGraphFin" "WikiTalk" "StackOverflow")
#data_param=("WIKI" "REDDIT" "LASTFM" "WikiTalk" "StackOverflow") #data_param=("WIKI" "REDDIT" "LASTFM" "WikiTalk" "StackOverflow")
#data_param=("REDDIT" "WikiTalk") #data_param=("REDDIT" "WikiTalk")
...@@ -31,9 +32,9 @@ data_param=("WIKI" "LASTFM" "WikiTalk" "StackOverflow" "GDELT") ...@@ -31,9 +32,9 @@ data_param=("WIKI" "LASTFM" "WikiTalk" "StackOverflow" "GDELT")
#seed=(( RANDOM % 1000000 + 1 )) #seed=(( RANDOM % 1000000 + 1 ))
mkdir -p all_"$seed" mkdir -p all_"$seed"
for data in "${data_param[@]}"; do for data in "${data_param[@]}"; do
model="JODIE" model="APAN_large"
if [ "$data" = "WIKI" ] || [ "$data" = "REDDIT" ] || [ "$data" = "LASTFM" ]; then if [ "$data" = "WIKI" ] || [ "$data" = "REDDIT" ] || [ "$data" = "LASTFM" ]; then
model="JODIE" model="APAN"
fi fi
#model="APAN" #model="APAN"
mkdir all_"$seed"/"$data" mkdir all_"$seed"/"$data"
...@@ -96,8 +97,139 @@ for data in "${data_param[@]}"; do ...@@ -96,8 +97,139 @@ for data in "${data_param[@]}"; do
done done
done done
for data in "${data_param[@]}"; do
model="JODIE_large"
if [ "$data" = "WIKI" ] || [ "$data" = "REDDIT" ] || [ "$data" = "LASTFM" ]; then
model="JODIE"
#continue
fi
#model="APAN"
mkdir all_"$seed"/"$data"
mkdir all_"$seed"/"$data"/"$model"
mkdir all_"$seed"/"$data"/"$model"/comm
#torchrun --nnodes "$nnodes" --node_rank 0 --nproc-per-node 1 --master-addr "$addr" --master-port 9445 train_boundery.py --dataname "$data" --mode "$model" --partition ours --memory_type local --sample_type recent --topk 0 --seed "$seed" > all_"$seed"/"$data"/"$model"/1.out &
wait
for partition in "${partition_params[@]}"; do
for sample in "${sample_type_params[@]}"; do
if [ "$sample" = "recent" ]; then
for mem in "${memory_type[@]}"; do
if [ "$mem" = "historical" ]; then
for ssim in "${shared_memory_ssim[@]}"; do
if [ "$partition" = "ours" ]; then
torchrun --nnodes "$nnodes" --node_rank "$node_rank" --nproc-per-node "$node_per" --master-addr "$addr" --master-port 9445 train_boundery.py --dataname "$data" --mode "$model" --partition "$partition" --topk 0.1 --sample_type "$sample" --memory_type "$mem" --shared_memory_ssim "$ssim" --seed "$seed" > all_"$seed"/"$data"/"$model"/"$partitions"-ours_shared-0.01-"$mem"-"$ssim"-"$sample".out &
wait
fi
done
elif [ "$mem" = "all_reduce" ]; then
if [ "$partition" = "ours" ]; then
torchrun --nnodes "$nnodes" --node_rank "$node_rank" --nproc-per-node "$node_per" --master-addr "$addr" --master-port 9445 train_boundery.py --dataname "$data" --mode "$model" --partition "$partition" --topk 0.1 --sample_type "$sample" --memory_type "$mem" --seed "$seed" > all_"$seed"/"$data"/"$model"/"$partitions"-ours_shared-0.01-"$mem"-"$sample".out &
wait
fi
else
torchrun --nnodes "$nnodes" --node_rank "$node_rank" --nproc-per-node "$node_per" --master-addr "$addr" --master-port 9445 train_boundery.py --dataname "$data" --mode "$model" --partition "$partition" --topk 0 --sample_type "$sample" --memory_type "$mem" --seed "$seed" > all_"$seed"/"$data"/"$model"/"$partitions"-"$partition"-0-"$mem"-"$sample".out &
wait
if [ "$partition" = "ours" ] && [ "$mem" != "all_local" ]; then
torchrun --nnodes "$nnodes" --node_rank "$node_rank" --nproc-per-node "$node_per" --master-addr "$addr" --master-port 9445 train_boundery.py --dataname "$data" --mode "$model" --partition "$partition" --topk 0.1 --sample_type "$sample" --memory_type "$mem" --seed "$seed" > all_"$seed"/"$data"/"$model"/"$partitions"-ours_shared-0.01-"$mem"-"$sample".out &
wait
fi
fi
done
else
for pro in "${probability_params[@]}"; do
for mem in "${memory_type[@]}"; do
if [ "$mem" = "historical" ]; then
for ssim in "${shared_memory_ssim[@]}"; do
if [ "$partition" = "ours" ]; then
torchrun --nnodes "$nnodes" --node_rank "$node_rank" --nproc-per-node "$node_per" --master-addr "$addr" --master-port 9445 train_boundery.py --dataname "$data" --mode "$model" --partition "$partition" --topk 0.1 --sample_type "$sample" --probability "$pro" --memory_type "$mem" --shared_memory_ssim "$ssim" --seed "$seed" > all_"$seed"/"$data"/"$model"/"$partitions"-ours_shared-0.01-"$mem"-"$ssim"-"$sample"-"$pro".out &
wait
fi
done
elif [ "$mem" = "all_reduce" ]; then
if [ "$partition" = "ours"]; then
torchrun --nnodes "$nnodes" --node_rank "$node_rank" --nproc-per-node "$node_per" --master-addr "$addr" --master-port 9445 train_boundery.py --dataname "$data" --mode "$model" --partition "$partition" --topk 0.1 --sample_type "$sample" --probability "$pro" --memory_type "$mem" --seed "$seed" > all_"$seed"/"$data"/"$model"/"$partitions"-ours_shared-0.01-"$mem"-"$sample"-"$pro".out&
wait
fi
else
#torchrun --nnodes "$nnodes" --node_rank "$node_rank" --nproc-per-node "$node_per" --master-addr "$addr" --master-port 9445 train_boundery.py --dataname "$data" --mode "$model" --partition "$partition" --topk 0 --sample_type "$sample" --probability "$pro" --memory_type "$mem" --seed "$seed" > all_"$seed"/"$data"/"$model"/"$partitions"-"$partition"-0-"$mem"-"$sample"-"$pro".out &
wait
if [ "$partition" = "ours" ] && [ "$mem" != "all_local" ]; then
torchrun --nnodes "$nnodes" --node_rank "$node_rank" --nproc-per-node "$node_per" --master-addr "$addr" --master-port 9445 train_boundery.py --dataname "$data" --mode "$model" --partition "$partition" --topk 0.1 --sample_type "$sample" --probability "$pro" --memory_type "$mem" --seed "$seed" > all_"$seed"/"$data"/"$model"/"$partitions"-ours_shared-0.01-"$mem"-"$sample"-"$pro".out &
wait
fi
fi
done
done
fi
done
done
done
data_param=("StackOverflow" "GDELT")
for data in "${data_param[@]}"; do
model="TGN"
if [ "$data" = "WIKI" ] || [ "$data" = "REDDIT" ] || [ "$data" = "LASTFM" ]; then
model="TGN_large"
#continue
fi
#model="APAN"
mkdir all_"$seed"/"$data"
mkdir all_"$seed"/"$data"/"$model"
mkdir all_"$seed"/"$data"/"$model"/comm
#torchrun --nnodes "$nnodes" --node_rank 0 --nproc-per-node 1 --master-addr "$addr" --master-port 9445 train_boundery.py --dataname "$data" --mode "$model" --partition ours --memory_type local --sample_type recent --topk 0 --seed "$seed" > all_"$seed"/"$data"/"$model"/1.out &
wait
for partition in "${partition_params[@]}"; do
for sample in "${sample_type_params[@]}"; do
if [ "$sample" = "recent" ]; then
for mem in "${memory_type[@]}"; do
if [ "$mem" = "historical" ]; then
for ssim in "${shared_memory_ssim[@]}"; do
if [ "$partition" = "ours" ]; then
torchrun --nnodes "$nnodes" --node_rank "$node_rank" --nproc-per-node "$node_per" --master-addr "$addr" --master-port 9445 train_boundery.py --dataname "$data" --mode "$model" --partition "$partition" --topk 0.1 --sample_type "$sample" --memory_type "$mem" --shared_memory_ssim "$ssim" --seed "$seed" > all_"$seed"/"$data"/"$model"/"$partitions"-ours_shared-0.01-"$mem"-"$ssim"-"$sample".out &
wait
fi
done
elif [ "$mem" = "all_reduce" ]; then
if [ "$partition" = "ours" ]; then
torchrun --nnodes "$nnodes" --node_rank "$node_rank" --nproc-per-node "$node_per" --master-addr "$addr" --master-port 9445 train_boundery.py --dataname "$data" --mode "$model" --partition "$partition" --topk 0.1 --sample_type "$sample" --memory_type "$mem" --seed "$seed" > all_"$seed"/"$data"/"$model"/"$partitions"-ours_shared-0.01-"$mem"-"$sample".out &
wait
fi
else
torchrun --nnodes "$nnodes" --node_rank "$node_rank" --nproc-per-node "$node_per" --master-addr "$addr" --master-port 9445 train_boundery.py --dataname "$data" --mode "$model" --partition "$partition" --topk 0 --sample_type "$sample" --memory_type "$mem" --seed "$seed" > all_"$seed"/"$data"/"$model"/"$partitions"-"$partition"-0-"$mem"-"$sample".out &
wait
if [ "$partition" = "ours" ] && [ "$mem" != "all_local" ]; then
torchrun --nnodes "$nnodes" --node_rank "$node_rank" --nproc-per-node "$node_per" --master-addr "$addr" --master-port 9445 train_boundery.py --dataname "$data" --mode "$model" --partition "$partition" --topk 0.1 --sample_type "$sample" --memory_type "$mem" --seed "$seed" > all_"$seed"/"$data"/"$model"/"$partitions"-ours_shared-0.01-"$mem"-"$sample".out &
wait
fi
fi
done
else
for pro in "${probability_params[@]}"; do
for mem in "${memory_type[@]}"; do
if [ "$mem" = "historical" ]; then
for ssim in "${shared_memory_ssim[@]}"; do
if [ "$partition" = "ours" ]; then
torchrun --nnodes "$nnodes" --node_rank "$node_rank" --nproc-per-node "$node_per" --master-addr "$addr" --master-port 9445 train_boundery.py --dataname "$data" --mode "$model" --partition "$partition" --topk 0.1 --sample_type "$sample" --probability "$pro" --memory_type "$mem" --shared_memory_ssim "$ssim" --seed "$seed" > all_"$seed"/"$data"/"$model"/"$partitions"-ours_shared-0.01-"$mem"-"$ssim"-"$sample"-"$pro".out &
wait
fi
done
elif [ "$mem" = "all_reduce" ]; then
if [ "$partition" = "ours"]; then
torchrun --nnodes "$nnodes" --node_rank "$node_rank" --nproc-per-node "$node_per" --master-addr "$addr" --master-port 9445 train_boundery.py --dataname "$data" --mode "$model" --partition "$partition" --topk 0.1 --sample_type "$sample" --probability "$pro" --memory_type "$mem" --seed "$seed" > all_"$seed"/"$data"/"$model"/"$partitions"-ours_shared-0.01-"$mem"-"$sample"-"$pro".out&
wait
fi
else
#torchrun --nnodes "$nnodes" --node_rank "$node_rank" --nproc-per-node "$node_per" --master-addr "$addr" --master-port 9445 train_boundery.py --dataname "$data" --mode "$model" --partition "$partition" --topk 0 --sample_type "$sample" --probability "$pro" --memory_type "$mem" --seed "$seed" > all_"$seed"/"$data"/"$model"/"$partitions"-"$partition"-0-"$mem"-"$sample"-"$pro".out &
wait
if [ "$partition" = "ours" ] && [ "$mem" != "all_local" ]; then
torchrun --nnodes "$nnodes" --node_rank "$node_rank" --nproc-per-node "$node_per" --master-addr "$addr" --master-port 9445 train_boundery.py --dataname "$data" --mode "$model" --partition "$partition" --topk 0.1 --sample_type "$sample" --probability "$pro" --memory_type "$mem" --seed "$seed" > all_"$seed"/"$data"/"$model"/"$partitions"-ours_shared-0.01-"$mem"-"$sample"-"$pro".out &
wait
fi
fi
done
done
fi
done
done
done
# for data in "${data_param[@]}"; do
# model="JODILE"
# if [ "$data" = "WIKI" ] || [ "$data" = "REDDIT" ] || [ "$data" = "LASTFM" ]; then
......
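For cross-reference with the parsing script above: each torchrun invocation redirects its log to a path of the form all_<seed>/<data>/<model>/<partitions>-<partition label>-<mem>[-<ssim>]-<sample>[-<pro>].out. A minimal sketch that reconstructs those paths (the helper name make_log_path is illustrative; the hard-coded 'ours_shared-0.01' label mirrors the redirect strings used above rather than the --topk value actually passed):

def make_log_path(seed, data, model, partitions, mem, sample,
                  ssim=None, pro=None, label='ours_shared-0.01'):
    # Mirror the "> all_$seed/$data/$model/..." redirect targets used above.
    parts = [str(partitions), label, mem]
    if ssim is not None:
        parts.append(str(ssim))
    parts.append(sample)
    if pro is not None:
        parts.append(str(pro))
    return 'all_{}/{}/{}/{}.out'.format(seed, data, model, '-'.join(parts))

# Hypothetical example matching the decayed-boundary runs:
# make_log_path('12357', 'WikiTalk', 'APAN_large', 4, 'historical-0.3',
#               'boundery_recent_decay', pro=0.1)
# -> 'all_12357/WikiTalk/APAN_large/4-ours_shared-0.01-historical-0.3-boundery_recent_decay-0.1.out'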
@@ -472,7 +472,7 @@ def main():
cos = torch.nn.CosineSimilarity(dim=0)
return cos(normalize(x1),normalize(x2)).sum()/x1.size(dim=0)
creterion = torch.nn.BCEWithLogitsLoss()
-optimizer = torch.optim.Adam(model.parameters(), lr=train_param['lr'],weight_decay=1e-4)
+optimizer = torch.optim.Adam(model.parameters(), lr=train_param['lr'])#,weight_decay=1e-4)
early_stopper = EarlyStopMonitor(max_round=args.patience)
MODEL_SAVE_PATH = f'../saved_models/{args.model}-{args.dataname}-{dist.get_world_size()}.pth'
total_test_time = 0
@@ -481,7 +481,7 @@ def main():
val_list = []
loss_list = []
for e in range(train_param['epoch']):
-# model.module.memory_updater.empty_cache()
+model.module.memory_updater.empty_cache()
tt._zero()
torch.cuda.synchronize()
epoch_start_time = time.time()
@@ -509,7 +509,10 @@ def main():
local_edge_comm = []
remote_edge_comm = []
b_cnt = 0
start = time_count.start_gpu()
for roots,mfgs,metadata in trainloader:
end = time_count.elapsed_event(start)
#print('time {}'.format(end))
#print('rank is {} batch max ts is {} batch min ts is {}'.format(dist.get_rank(),roots.ts.min(),roots.ts.max()))
b_cnt = b_cnt + 1
#local_access.append(trainloader.local_node)
@@ -525,7 +528,7 @@ def main():
# sum_remote_edge_comm +=remote_edge_comm[b_cnt-1]
#sum_local_comm +=local_comm[b_cnt-1]
#sum_remote_comm +=remote_comm[b_cnt-1]
t1 = time_count.start_gpu()
if mailbox is not None:
if(graph.efeat.device.type != 'cpu'):
edge_feats = graph.get_local_efeat(graph.eids_mapper[roots.eids.to('cpu')]).to('cuda')
@@ -539,12 +542,14 @@ def main():
param = (update_mail,src,dst,ts,edge_feats,trainloader.async_feature)
else:
param = None
#print(time_count.elapsed_event(t1))
model.train()
t2 = time_count.start_gpu()
optimizer.zero_grad()
ones = torch.ones(metadata['dst_neg_index'].shape[0],device = model.device,dtype=torch.float)
weight = torch.where(DistIndex(mfgs[0][0].srcdata['ID'][metadata['dst_neg_index']]).part == torch.distributed.get_rank(),ones*train_ratio_pos,ones*train_ratio_neg).reshape(-1,1)
pred_pos, pred_neg = model(mfgs,metadata,neg_samples=args.neg_samples,async_param = param)
#print(time_count.elapsed_event(t2))
loss = creterion(pred_pos, torch.ones_like(pred_pos))
neg_creterion = torch.nn.BCEWithLogitsLoss(weight)
loss += neg_creterion(pred_neg, torch.zeros_like(pred_neg))
@@ -563,6 +568,7 @@ def main():
mailbox.update_shared()
mailbox.update_p2p_mem()
mailbox.update_p2p_mail()
start = time_count.start_gpu()
#torch.cuda.empty_cache()
"""
......
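In the training loop above, negative samples are re-weighted by whether the sampled destination lives on the local partition (train_ratio_pos vs train_ratio_neg) before the BCE loss is applied. A self-contained sketch of that weighting, assuming per-sample weights passed to BCEWithLogitsLoss; the tensors and ratios here are toy stand-ins for the model outputs:

import torch

def weighted_link_loss(pred_pos, pred_neg, neg_is_local, train_ratio_pos, train_ratio_neg):
    # Positive edges: plain BCE-with-logits against a target of 1.
    loss = torch.nn.BCEWithLogitsLoss()(pred_pos, torch.ones_like(pred_pos))
    # Negative samples: weight each one depending on whether its destination
    # node is stored on the local partition or on a remote one.
    ones = torch.ones_like(pred_neg)
    weight = torch.where(neg_is_local, ones * train_ratio_pos, ones * train_ratio_neg)
    loss = loss + torch.nn.BCEWithLogitsLoss(weight=weight)(pred_neg, torch.zeros_like(pred_neg))
    return loss

# Toy example: 4 negative logits, two of them resident on the local partition.
# pred_pos = torch.randn(4, 1); pred_neg = torch.randn(4, 1)
# neg_is_local = torch.tensor([[True], [False], [True], [False]])
# weighted_link_loss(pred_pos, pred_neg, neg_is_local, 1.2, 0.8)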
@@ -33,21 +33,23 @@ class time_count:
def start_gpu():
# Uncomment for better breakdown timings
#torch.cuda.synchronize()
-start_event = torch.cuda.Event(enable_timing=True)
-end_event = torch.cuda.Event(enable_timing=True)
-start_event.record()
-return start_event,end_event
+#start_event = torch.cuda.Event(enable_timing=True)
+#end_event = torch.cuda.Event(enable_timing=True)
+#start_event.record()
+#return start_event,end_event
+return 0,0
@staticmethod
def start():
-return time.perf_counter(),0
+# return time.perf_counter(),0
+return 0,0
@staticmethod
def elapsed_event(start_event):
# if isinstance(start_event,tuple):
# start_event,end_event = start_event
# end_event.record()
# end_event.synchronize()
# return start_event.elapsed_time(end_event)
# else:
# torch.cuda.synchronize()
# return time.perf_counter() - start_event
return 0
......
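This change stubs out time_count.start_gpu() and elapsed_event() so they return zeros; the commented lines preserve the original CUDA-event based timing. For reference, a minimal sketch of that event-timing pattern (roughly what re-enabling it would look like; requires a CUDA device, and elapsed_time reports milliseconds):

import torch

def start_gpu():
    # Record a CUDA event at the start of the region to be timed.
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    start_event.record()
    return start_event, end_event

def elapsed_event(events):
    # Record the end event, wait for the GPU to reach it, and return the
    # elapsed time in milliseconds.
    start_event, end_event = events
    end_event.record()
    end_event.synchronize()
    return start_event.elapsed_time(end_event)

# Hypothetical usage around a forward pass:
# t = start_gpu(); out = model(batch); print('forward: {:.2f} ms'.format(elapsed_event(t)))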
@@ -187,13 +187,16 @@ class SharedMailBox():
#print('root is {} {} {} {}\n'.format(root,root.shape,root.max(),block.edges()[0].shape))
#pos_index = torch.arange(root.shape[0],device=root.device,dtype=root.dtype)
pos,idx = torch_scatter.scatter_max(mail_ts,root,0)
-mail = torch.cat([mail, mail[idx]],dim=0)
-mail_ts = torch.cat([mail_ts, mail_ts[idx]], dim=0)
+#print(block.number_of_edges())
+mail = torch.cat([mail, mail[idx[block.edges()[0].long()]]],dim=0)
+mail_ts = torch.cat([mail_ts, mail_ts[idx[block.edges()[0].long()]]], dim=0)
#print('pos is {} {}\n'.format(pos,block.edges()[0].long()))
#mail = torch.cat([mail, mail[pos[block.edges()[0].long()]]],dim=0)
#mail_ts = torch.cat([mail_ts, mail_ts[pos[block.edges()[0].long()]]], dim=0)
#print(root,block.edges()[1].long())
index = torch.cat([index,block.dstdata['ID'][block.edges()[1].long()]],dim=0)
#print(index)
#mail = torch.cat([mail, mail[block.edges()[0].long()]], dim=0)
#mail_ts = torch.cat([mail_ts, mail_ts[block.edges()[0].long()]], dim=0)
#index = torch.cat([index,block.dstdata['ID'][block.edges()[1].long()]],dim=0)
......
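The change above selects, per root node, the mail with the latest timestamp via torch_scatter.scatter_max and then gathers it along the block's source edges (idx[block.edges()[0]]) rather than along idx directly. A small standalone illustration of the scatter_max step, assuming torch_scatter is installed; the tensors are toy data:

import torch
import torch_scatter

# Three mails addressed to two root nodes (ids 0 and 1).
mail_ts = torch.tensor([5.0, 9.0, 7.0])
root = torch.tensor([0, 0, 1])
mail = torch.tensor([[1.0], [2.0], [3.0]])

# scatter_max returns, for each root id, the maximum timestamp and the index of
# the mail that achieved it; that index then gathers the latest mail per root.
max_ts, idx = torch_scatter.scatter_max(mail_ts, root, dim=0)
latest_mail = mail[idx]
# max_ts      -> tensor([9., 7.])
# latest_mail -> tensor([[2.], [3.]])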