Commit 99e5b95a by xxx

fix bugs and add APAN model

parent b305c21a
...@@ -3,10 +3,7 @@ sampling: ...@@ -3,10 +3,7 @@ sampling:
neighbor: neighbor:
- 10 - 10
strategy: 'recent' strategy: 'recent'
prop_time: False
history: 1 history: 1
duration: 0
num_thread: 32
no_neg: True no_neg: True
memory: memory:
- type: 'node' - type: 'node'
......
sampling: sampling:
- no_sample: True - strategy: 'identity'
history: 1 history: 1
memory: memory:
- type: 'node' - type: 'node'
dim_time: 100 dim_time: 100
......
...@@ -3,10 +3,7 @@ sampling: ...@@ -3,10 +3,7 @@ sampling:
neighbor: neighbor:
- 10 - 10
strategy: 'recent' strategy: 'recent'
prop_time: False
history: 1 history: 1
duration: 0
num_thread: 32
memory: memory:
- type: 'node' - type: 'node'
dim_time: 100 dim_time: 100
......
bash test_all.sh 13357 > 13357.out
wait
bash test_all.sh 12347 > 12347.out bash test_all.sh 12347 > 12347.out
wait wait
bash test_all.sh 63377 > 63377.out bash test_all.sh 63377 > 63377.out
......
import matplotlib.pyplot as plt
import numpy as np
import torch
# 读取文件内容
ssim_values = [0, 0.1, 0.2, 0.3, 0.4, 2] # 假设这是你的 ssim 参数值
probability_values = [1,0.1,0.05,0.01,0]
data_values = ['WIKI','LASTFM','WikiTalk','DGraphFin'] # 存储从文件中读取的数据
seed = ['13357','12347','63377','53473',' 54763']
partition = 'ours'
# 从文件中读取数据,假设数据存储在文件 data.txt 中
#all/"$data"/"$partitions"-ours_shared-0.01-"$mem"-"$ssim"-"$sample".out
partitions=4
topk=0
mem='all_update'#'historical'
model='TGN'
for sd in seed :
for data in data_values:
ap_list = []
comm_list = []
for p in probability_values:
if data == 'WIKI' or data =='LASTFM':
model = 'TGN'
else:
model = 'TGN_large'
if p == 1:
file = 'all_{}/{}/{}/{}-{}-{}-{}-recent.out'.format(sd,data,model,partitions,partition,topk,mem)
else:
file = 'all_{}/{}/{}/{}-{}-{}-{}-boundery_recent_decay-{}.out'.format(sd,data,model,partitions,partition,topk,mem,p)
prefix = "val ap:"
max_val_ap = 0
test_ap = 0
with open(file, 'r') as file:
for line in file:
if line.find(prefix)!=-1:
pos = line.find(prefix)+len(prefix)
posr = line.find(' ',pos)
#print(line[pos:posr])
val_ap = float(line[pos:posr])
pos = line.find("test ap ")+len("test ap ")
posr = line.find(' ',pos)
#print(line[pos:posr])
_test_ap = float(line[pos:posr])
if(val_ap>max_val_ap):
max_val_ap = val_ap
test_ap = _test_ap
ap_list.append(test_ap)
print('data {} seed {} ap: {}'.format(data,sd,ap_list))
# prefix = 'best test AP:'
# cnt = 0
# sum = 0
# with open(file, 'r') as file:
# for line in file:
# if line.startswith(prefix):
# ap = float(line.lstrip(prefix).split(' ')[0])
# pos = line.find('remote node number tensor')
# if(pos!=-1):
# posr = line.find(']',pos+2+len('remote node number tensor'),)
# #print(line,line[pos+2+len('remote node number tensor'):posr])
# comm = int(line[pos+2+len('remote node number tensor'):posr])
# #print()
# sum = sum+comm
# cnt = cnt+1
# #print(comm)
# ap_list.append(ap)
# comm_list.append(sum/cnt*4)
# # 绘制柱状图
# print('{} TestAP={}\n'.format(data,ap_list))
# bar_width = 0.4
# #shared comm tensor
# # 设置柱状图的位置
# bars = range(len(probability_values))
# # 绘制柱状图
# plt.bar([b for b in bars], ap_list, width=bar_width)
# # 绘制柱状图
# plt.ylim([0.9,1])
# plt.xticks([b for b in bars], probability_values)
# plt.xlabel('probability')
# plt.ylabel('Test AP')
# plt.title('{}({} partitions)'.format(data,partitions))
# plt.savefig('boundary_AP_{}_{}_{}.png'.format(data,partitions,model))
# plt.clf()
# print(comm_list)
# plt.bar([b for b in bars], comm_list, width=bar_width)
# # 绘制柱状图
# plt.xticks([b for b in bars], probability_values)
# plt.xlabel('probability')
# plt.ylabel('Communication volume')
# plt.title('{}({} partitions)'.format(data,partitions))
# plt.savefig('boundary_comm_{}_{}_{}.png'.format(data,partitions,model))
# plt.clf()
# if partition == 'ours_shared':
# partition0 = 'ours'
# else:
# partition0=partition
# for p in probability_values:
# file = '{}/{}/test_{}_{}_{}_0_boundery_recent_uniform_{}_all_update_2.pt'.format(data,model,partition0,topk,partitions,float(p))
# val_ap = torch.tensor(torch.load(file))[:,0]
# epoch = torch.arange(val_ap.shape[0])
# #绘制曲线图
# plt.plot(epoch,val_ap, label='probability={}'.format(p))
# plt.xlabel('Epoch')
# plt.ylabel('Val AP')
# plt.title('{}({} partitions)'.format(data,partitions))
# # plt.grid(True)
# plt.legend()
# plt.savefig('{}_{}_{}_boundary_Convergence_rate.png'.format(data,partitions,model))
# plt.clf()
...@@ -4,15 +4,15 @@ import torch ...@@ -4,15 +4,15 @@ import torch
# 读取文件内容 # 读取文件内容
ssim_values = [0, 0.1, 0.2, 0.3, 0.4, 2] # 假设这是你的 ssim 参数值 ssim_values = [0, 0.1, 0.2, 0.3, 0.4, 2] # 假设这是你的 ssim 参数值
probability_values = [1,0.1,0.05,0.01,0] probability_values = [1,0.1,0.05,0.01,0]
data_values = ['WIKI','LASTFM','WikiTalk','DGraphFin'] # 存储从文件中读取的数据 data_values = ['WIKI','WikiTalk'] # 存储从文件中读取的数据
seed = ['13357','12347','63377','53473',' 54763'] seed = ['13357','12347','63377','53473','54763']
partition = 'ours_shared' partition = 'ours_shared'
# 从文件中读取数据,假设数据存储在文件 data.txt 中 # 从文件中读取数据,假设数据存储在文件 data.txt 中
#all/"$data"/"$partitions"-ours_shared-0.01-"$mem"-"$ssim"-"$sample".out #all/"$data"/"$partitions"-ours_shared-0.01-"$mem"-"$ssim"-"$sample".out
partitions=4 partitions=4
topk=0.01 topk=0.01
mem='all_update'#'historical' mem='local'#'historical'
model='TGN' model='TGN'
for sd in seed : for sd in seed :
for data in data_values: for data in data_values:
......
...@@ -6,19 +6,20 @@ addr="192.168.1.107" ...@@ -6,19 +6,20 @@ addr="192.168.1.107"
partition_params=("ours" ) partition_params=("ours" )
#"metis" "ldg" "random") #"metis" "ldg" "random")
#("ours" "metis" "ldg" "random") #("ours" "metis" "ldg" "random")
partitions="4" partitions="8"
node_per="4" node_per="4"
nnodes="1" nnodes="2"
node_rank="0" node_rank="1"
probability_params=("0.1" "0" "0.05" "0.01") probability_params=("0.1" "0.05")
sample_type_params=("boundery_recent_decay" "recent") sample_type_params=("boundery_recent_decay")
#sample_type_params=("recent" "boundery_recent_decay") #"boundery_recent_uniform") #sample_type_params=("recent" "boundery_recent_decay") #"boundery_recent_uniform")
#memory_type=("all_update" "p2p" "all_reduce" "historical" "local") #memory_type=("all_update" "p2p" "all_reduce" "historical" "local")
memory_type=( "all_update")
memory_type=( "all_update" "historical" "local")
#memory_type=("local" "all_update" "historical" "all_reduce") #memory_type=("local" "all_update" "historical" "all_reduce")
shared_memory_ssim=("0.3" "0.7") shared_memory_ssim=("0.3" "0.7")
#data_param=("WIKI" "REDDIT" "LASTFM" "WikiTalk") #data_param=("WIKI" "REDDIT" "LASTFM" "WikiTalk")
data_param=("WIKI" "LASTFM" "WikiTalk" "DGraphFin") data_param=("StackOverflow")
#data_param=("WIKI" "REDDIT" "LASTFM" "DGraphFin" "WikiTalk" "StackOverflow") #data_param=("WIKI" "REDDIT" "LASTFM" "DGraphFin" "WikiTalk" "StackOverflow")
#data_param=("WIKI" "REDDIT" "LASTFM" "WikiTalk" "StackOverflow") #data_param=("WIKI" "REDDIT" "LASTFM" "WikiTalk" "StackOverflow")
#data_param=("REDDIT" "WikiTalk") #data_param=("REDDIT" "WikiTalk")
...@@ -30,9 +31,9 @@ data_param=("WIKI" "LASTFM" "WikiTalk" "DGraphFin") ...@@ -30,9 +31,9 @@ data_param=("WIKI" "LASTFM" "WikiTalk" "DGraphFin")
#seed=(( RANDOM % 1000000 + 1 )) #seed=(( RANDOM % 1000000 + 1 ))
mkdir -p all_"$seed" mkdir -p all_"$seed"
for data in "${data_param[@]}"; do for data in "${data_param[@]}"; do
model="TGN_large" model="JODIE"
if [ "$data" = "WIKI" ] || [ "$data" = "REDDIT" ] || [ "$data" = "LASTFM" ]; then if [ "$data" = "WIKI" ] || [ "$data" = "REDDIT" ] || [ "$data" = "LASTFM" ]; then
model="TGN" model="JODIE"
fi fi
#model="APAN" #model="APAN"
mkdir all_"$seed"/"$data" mkdir all_"$seed"/"$data"
...@@ -57,8 +58,8 @@ for data in "${data_param[@]}"; do ...@@ -57,8 +58,8 @@ for data in "${data_param[@]}"; do
wait wait
fi fi
else else
torchrun --nnodes "$nnodes" --node_rank "$node_rank" --nproc-per-node "$node_per" --master-addr "$addr" --master-port 9445 train_boundery.py --dataname "$data" --mode "$model" --partition "$partition" --topk 0 --sample_type "$sample" --memory_type "$mem" --seed "$seed" > all_"$seed"/"$data"/"$model"/"$partitions"-"$partition"-0-"$mem"-"$sample".out & # torchrun --nnodes "$nnodes" --node_rank "$node_rank" --nproc-per-node "$node_per" --master-addr "$addr" --master-port 9445 train_boundery.py --dataname "$data" --mode "$model" --partition "$partition" --topk 0 --sample_type "$sample" --memory_type "$mem" --seed "$seed" > all_"$seed"/"$data"/"$model"/"$partitions"-"$partition"-0-"$mem"-"$sample".out &
wait # wait
if [ "$partition" = "ours" ] && [ "$mem" != "all_local" ]; then if [ "$partition" = "ours" ] && [ "$mem" != "all_local" ]; then
torchrun --nnodes "$nnodes" --node_rank "$node_rank" --nproc-per-node "$node_per" --master-addr "$addr" --master-port 9445 train_boundery.py --dataname "$data" --mode "$model" --partition "$partition" --topk 0.1 --sample_type "$sample" --memory_type "$mem" --seed "$seed" > all_"$seed"/"$data"/"$model"/"$partitions"-ours_shared-0.01-"$mem"-"$sample".out & torchrun --nnodes "$nnodes" --node_rank "$node_rank" --nproc-per-node "$node_per" --master-addr "$addr" --master-port 9445 train_boundery.py --dataname "$data" --mode "$model" --partition "$partition" --topk 0.1 --sample_type "$sample" --memory_type "$mem" --seed "$seed" > all_"$seed"/"$data"/"$model"/"$partitions"-ours_shared-0.01-"$mem"-"$sample".out &
wait wait
...@@ -80,13 +81,13 @@ for data in "${data_param[@]}"; do ...@@ -80,13 +81,13 @@ for data in "${data_param[@]}"; do
torchrun --nnodes "$nnodes" --node_rank "$node_rank" --nproc-per-node "$node_per" --master-addr "$addr" --master-port 9445 train_boundery.py --dataname "$data" --mode "$model" --partition "$partition" --topk 0.1 --sample_type "$sample" --probability "$pro" --memory_type "$mem" --seed "$seed" > all_"$seed"/"$data"/"$model"/"$partitions"-ours_shared-0.01-"$mem"-"$sample"-"$pro".out& torchrun --nnodes "$nnodes" --node_rank "$node_rank" --nproc-per-node "$node_per" --master-addr "$addr" --master-port 9445 train_boundery.py --dataname "$data" --mode "$model" --partition "$partition" --topk 0.1 --sample_type "$sample" --probability "$pro" --memory_type "$mem" --seed "$seed" > all_"$seed"/"$data"/"$model"/"$partitions"-ours_shared-0.01-"$mem"-"$sample"-"$pro".out&
wait wait
fi fi
else #else
torchrun --nnodes "$nnodes" --node_rank "$node_rank" --nproc-per-node "$node_per" --master-addr "$addr" --master-port 9445 train_boundery.py --dataname "$data" --mode "$model" --partition "$partition" --topk 0 --sample_type "$sample" --probability "$pro" --memory_type "$mem" --seed "$seed" > all_"$seed"/"$data"/"$model"/"$partitions"-"$partition"-0-"$mem"-"$sample"-"$pro".out & # torchrun --nnodes "$nnodes" --node_rank "$node_rank" --nproc-per-node "$node_per" --master-addr "$addr" --master-port 9445 train_boundery.py --dataname "$data" --mode "$model" --partition "$partition" --topk 0 --sample_type "$sample" --probability "$pro" --memory_type "$mem" --seed "$seed" > all_"$seed"/"$data"/"$model"/"$partitions"-"$partition"-0-"$mem"-"$sample"-"$pro".out &
wait # wait
if [ "$partition" = "ours" ] && [ "$mem" != "all_local" ]; then # if [ "$partition" = "ours" ] && [ "$mem" != "all_local" ]; then
torchrun --nnodes "$nnodes" --node_rank "$node_rank" --nproc-per-node "$node_per" --master-addr "$addr" --master-port 9445 train_boundery.py --dataname "$data" --mode "$model" --partition "$partition" --topk 0.1 --sample_type "$sample" --probability "$pro" --memory_type "$mem" --seed "$seed" > all_"$seed"/"$data"/"$model"/"$partitions"-ours_shared-0.01-"$mem"-"$sample"-"$pro".out & # torchrun --nnodes "$nnodes" --node_rank "$node_rank" --nproc-per-node "$node_per" --master-addr "$addr" --master-port 9445 train_boundery.py --dataname "$data" --mode "$model" --partition "$partition" --topk 0.1 --sample_type "$sample" --probability "$pro" --memory_type "$mem" --seed "$seed" > all_"$seed"/"$data"/"$model"/"$partitions"-ours_shared-0.01-"$mem"-"$sample"-"$pro".out &
wait # wait
fi # fi
fi fi
done done
done done
......
...@@ -126,7 +126,7 @@ def seed_everything(seed=42): ...@@ -126,7 +126,7 @@ def seed_everything(seed=42):
torch.cuda.manual_seed(seed) torch.cuda.manual_seed(seed)
torch.backends.cudnn.deterministic = True torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False torch.backends.cudnn.benchmark = False
seed_everything(args.seed)
total_next_batch = 0 total_next_batch = 0
total_forward = 0 total_forward = 0
total_count_score = 0 total_count_score = 0
...@@ -267,9 +267,15 @@ def main(): ...@@ -267,9 +267,15 @@ def main():
if args.local_neg_sample: if args.local_neg_sample:
print('dst len {} origin len {}'.format(graph.edge_index[1,mask].unique().shape[0],full_dst.unique().shape[0])) print('dst len {} origin len {}'.format(graph.edge_index[1,mask].unique().shape[0],full_dst.unique().shape[0]))
train_neg_sampler = LocalNegativeSampling('triplet',amount = args.neg_samples,dst_node_list = graph.edge_index[1,mask].unique()) train_neg_sampler = LocalNegativeSampling('triplet',amount = args.neg_samples,dst_node_list = graph.edge_index[1,mask].unique())
else: else:
#train_neg_sampler = LocalNegativeSampling('triplet',amount = args.neg_samples,dst_node_list = full_dst.unique()) #train_neg_sampler = LocalNegativeSampling('triplet',amount = args.neg_samples,dst_node_list = full_dst.unique())
train_neg_sampler = LocalNegativeSampling('triplet',amount = args.neg_samples,dst_node_list = full_dst.unique(),local_mask=(DistIndex(graph.nids_mapper[full_dst.unique()].to('cpu')).part == dist.get_rank()),prob=args.probability) train_neg_sampler = LocalNegativeSampling('triplet',amount = args.neg_samples,dst_node_list = full_dst.unique(),local_mask=(DistIndex(graph.nids_mapper[full_dst.unique()].to('cpu')).part == dist.get_rank()),prob=args.probability)
remote_ratio = train_neg_sampler.local_dst.shape[0] / train_neg_sampler.dst_node_list.shape[0]
#train_ratio_pos = (1 - args.probability) + args.probability * remote_ratio
#train_ratio_neg = args.probability * (1-remote_ratio)
train_ratio_pos = 1.0/(1-args.probability+ args.probability * remote_ratio) if ((args.probability <1) & (args.probability > 0)) else 1
train_ratio_neg = 1.0/(args.probability*remote_ratio) if ((args.probability <1) & (args.probability > 0)) else 1
print(train_neg_sampler.dst_node_list) print(train_neg_sampler.dst_node_list)
neg_sampler = LocalNegativeSampling('triplet',amount= neg_samples,dst_node_list = full_dst.unique(),seed=args.seed) neg_sampler = LocalNegativeSampling('triplet',amount= neg_samples,dst_node_list = full_dst.unique(),seed=args.seed)
...@@ -338,10 +344,10 @@ def main(): ...@@ -338,10 +344,10 @@ def main():
print('dim_node {} dim_edge {}\n'.format(gnn_dim_node,gnn_dim_edge)) print('dim_node {} dim_edge {}\n'.format(gnn_dim_node,gnn_dim_edge))
avg_time = 0 avg_time = 0
if use_cuda: if use_cuda:
model = GeneralModel(gnn_dim_node, gnn_dim_edge, sample_param, memory_param, gnn_param, train_param,graph.ids.shape[0],mailbox).cuda() model = GeneralModel(gnn_dim_node, gnn_dim_edge, sample_param, memory_param, gnn_param, train_param,graph.ids.shape[0],mailbox,train_ratio=(train_ratio_pos,train_ratio_neg)).cuda()
device = torch.device('cuda') device = torch.device('cuda')
else: else:
model = GeneralModel(gnn_dim_node, gnn_dim_edge, sample_param, memory_param, gnn_param, train_param,graph.ids.shape[0],mailbox) model = GeneralModel(gnn_dim_node, gnn_dim_edge, sample_param, memory_param, gnn_param, train_param,graph.ids.shape[0],mailbox,train_ratio=(train_ratio_pos,train_ratio_neg))
device = torch.device('cpu') device = torch.device('cpu')
model = DDP(model,find_unused_parameters=True) model = DDP(model,find_unused_parameters=True)
def count_parameters(model): def count_parameters(model):
...@@ -531,9 +537,12 @@ def main(): ...@@ -531,9 +537,12 @@ def main():
model.train() model.train()
optimizer.zero_grad() optimizer.zero_grad()
ones = torch.ones(metadata['dst_neg_index'].shape[0],device = model.device,dtype=torch.float)
weight = torch.where(DistIndex(mfgs[0][0].srcdata['ID'][metadata['dst_neg_index']]).part == torch.distributed.get_rank(),ones*train_ratio_pos,ones*train_ratio_neg).reshape(-1,1)
pred_pos, pred_neg = model(mfgs,metadata,neg_samples=args.neg_samples,async_param = param) pred_pos, pred_neg = model(mfgs,metadata,neg_samples=args.neg_samples,async_param = param)
loss = creterion(pred_pos, torch.ones_like(pred_pos)) loss = creterion(pred_pos, torch.ones_like(pred_pos))
loss += creterion(pred_neg, torch.zeros_like(pred_neg)) neg_creterion = torch.nn.BCEWithLogitsLoss(weight)
loss += neg_creterion(pred_neg, torch.zeros_like(pred_neg))
total_loss += float(loss.item()) total_loss += float(loss.item())
#mailbox.handle_last_async() #mailbox.handle_last_async()
#trainloader.async_feature() #trainloader.async_feature()
...@@ -663,9 +672,9 @@ def main(): ...@@ -663,9 +672,9 @@ def main():
pass pass
# print('weight {} {}\n'.format(tt.weight_count_local,tt.weight_count_remote)) # print('weight {} {}\n'.format(tt.weight_count_local,tt.weight_count_remote))
# print('ssim {} {}\n'.format(tt.ssim_local/tt.ssim_cnt,tt.ssim_remote/tt.ssim_cnt)) # print('ssim {} {}\n'.format(tt.ssim_local/tt.ssim_cnt,tt.ssim_remote/tt.ssim_cnt))
torch.save(val_list,'all_args.seed/{}/{}/val_{}_{}_{}_{}_{}_{}_{}_{}.pt'.format(args.dataname,args.model,args.partition,args.topk,dist.get_world_size(),dist.get_rank(),args.sample_type,args.probability,args.memory_type,args.shared_memory_ssim)) torch.save(val_list,'all_{}/{}/{}/val_{}_{}_{}_{}_{}_{}_{}_{}.pt'.format(args.seed,args.dataname,args.model,args.partition,args.topk,dist.get_world_size(),dist.get_rank(),args.sample_type,args.probability,args.memory_type,args.shared_memory_ssim))
torch.save(loss_list,'all_args.seed/{}/{}/loss_{}_{}_{}_{}_{}_{}_{}_{}.pt'.format(args.dataname,args.model,args.partition,args.topk,dist.get_world_size(),dist.get_rank(),args.sample_type,args.probability,args.memory_type,args.shared_memory_ssim)) torch.save(loss_list,'all_{}/{}/{}/loss_{}_{}_{}_{}_{}_{}_{}_{}.pt'.format(args.seed,args.dataname,args.model,args.partition,args.topk,dist.get_world_size(),dist.get_rank(),args.sample_type,args.probability,args.memory_type,args.shared_memory_ssim))
torch.save(test_ap_list,'all_args.seed/{}/{}/test_{}_{}_{}_{}_{}_{}_{}_{}.pt'.format(args.dataname,args.model,args.partition,args.topk,dist.get_world_size(),dist.get_rank(),args.sample_type,args.probability,args.memory_type,args.shared_memory_ssim)) torch.save(test_ap_list,'all_{}/{}/{}/test_{}_{}_{}_{}_{}_{}_{}_{}.pt'.format(args.seed,args.dataname,args.model,args.partition,args.topk,dist.get_world_size(),dist.get_rank(),args.sample_type,args.probability,args.memory_type,args.shared_memory_ssim))
print(avg_time) print(avg_time)
if not early_stop: if not early_stop:
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
#跑了4卡的TaoBao #跑了4卡的TaoBao
# 定义数组变量 # 定义数组变量
seed=$1 seed=$1
addr="192.168.1.107" addr="192.168.1.106"
partition_params=("ours" ) partition_params=("ours" )
#"metis" "ldg" "random") #"metis" "ldg" "random")
#("ours" "metis" "ldg" "random") #("ours" "metis" "ldg" "random")
...@@ -10,13 +10,13 @@ partitions="8" ...@@ -10,13 +10,13 @@ partitions="8"
node_per="4" node_per="4"
nnodes="2" nnodes="2"
node_rank="0" node_rank="0"
probability_params=("0.1" "0.01" "0.05") probability_params=("0.1")
sample_type_params=("boundery_recent_decay") sample_type_params=("boundery_recent_decay")
#sample_type_params=("recent" "boundery_recent_decay") #"boundery_recent_uniform") #sample_type_params=("recent" "boundery_recent_decay") #"boundery_recent_uniform")
#memory_type=("all_update" "p2p" "all_reduce" "historical" "local") #memory_type=("all_update" "p2p" "all_reduce" "historical" "local")
memory_type=( "all_update") memory_type=( "historical")
#memory_type=("local" "all_update" "historical" "all_reduce") #memory_type=("local" "all_update" "historical" "all_reduce")
shared_memory_ssim=("0.3" "0.7") shared_memory_ssim=("0.3")
#data_param=("WIKI" "REDDIT" "LASTFM" "WikiTalk") #data_param=("WIKI" "REDDIT" "LASTFM" "WikiTalk")
<<<<<<< HEAD <<<<<<< HEAD
data_param=("GDELT") data_param=("GDELT")
...@@ -34,9 +34,9 @@ data_param=("LASTFM") ...@@ -34,9 +34,9 @@ data_param=("LASTFM")
#seed=(( RANDOM % 1000000 + 1 )) #seed=(( RANDOM % 1000000 + 1 ))
mkdir -p all_"$seed" mkdir -p all_"$seed"
for data in "${data_param[@]}"; do for data in "${data_param[@]}"; do
model="TGN_large" model="APAN"
if [ "$data" = "WIKI" ] || [ "$data" = "REDDIT" ] || [ "$data" = "LASTFM" ]; then if [ "$data" = "WIKI" ] || [ "$data" = "REDDIT" ] || [ "$data" = "LASTFM" ]; then
model="TGN" model="APAN"
fi fi
#model="APAN" #model="APAN"
mkdir all_"$seed"/"$data" mkdir all_"$seed"/"$data"
......
LOCAL RANK 2, RANK2
LOCAL RANK 1, RANK1 LOCAL RANK 1, RANK1
LOCAL RANK 3, RANK3
LOCAL RANK 0, RANK0 LOCAL RANK 0, RANK0
LOCAL RANK 2, RANK2 LOCAL RANK 3, RANK3
in
in
in
in
local rank is 1 world_size is 4 memory group is 0 memory rank is 1 memory group size is 4
local rank is 0 world_size is 4 memory group is 0 memory rank is 0 memory group size is 4
local rank is 3 world_size is 4 memory group is 0 memory rank is 3 memory group size is 4
[0, 1, 2, 3]
[0, 1, 2, 3][0, 1, 2, 3]
local rank is 2 world_size is 4 memory group is 0 memory rank is 2 memory group size is 4
[0, 1, 2, 3]
use cuda on 0
use cuda on 3
use cuda on 2
use cuda on 1
LOCAL RANK 0, RANK0
use cuda on 0
994790
get_neighbors consume: 6.18508s
Epoch 0:
train loss:12236.7578 train ap:0.986976 val ap:0.934674 val auc:0.946284
total time:630.37s prep time:545.79s
fetch time:0.00s write back time:0.00s
Epoch 1:
train loss:11833.1818 train ap:0.987815 val ap:0.960581 val auc:0.965728
total time:628.44s prep time:542.56s
fetch time:0.00s write back time:0.00s
Epoch 2:
train loss:11622.9559 train ap:0.988244 val ap:0.956752 val auc:0.963083
total time:622.89s prep time:538.77s
fetch time:0.00s write back time:0.00s
Epoch 3:
train loss:11679.1400 train ap:0.988072 val ap:0.929351 val auc:0.943797
total time:681.88s prep time:569.50s
fetch time:0.00s write back time:0.00s
Epoch 4:
train loss:11676.1710 train ap:0.988098 val ap:0.936353 val auc:0.948531
total time:849.98s prep time:741.47s
fetch time:0.00s write back time:0.00s
Epoch 5:
train loss:11745.6001 train ap:0.987897 val ap:0.950828 val auc:0.958958
total time:862.77s prep time:750.90s
fetch time:0.00s write back time:0.00s
Epoch 6:
Early stopping at epoch 6
Loading the best model at epoch 1
0.9248434901237488 0.929413378238678
0.8653780221939087 0.861071765422821
test AP:0.847958 test AUC:0.837159
test_dataset 6647176 avg_time 87.00003329753876
...@@ -228,16 +228,19 @@ def main(): ...@@ -228,16 +228,19 @@ def main():
num_layers = sample_param['layer'] if 'layer' in sample_param else 1 num_layers = sample_param['layer'] if 'layer' in sample_param else 1
fanout = sample_param['neighbor'] if 'neighbor' in sample_param else [10] fanout = sample_param['neighbor'] if 'neighbor' in sample_param else [10]
policy = sample_param['strategy'] if 'strategy' in sample_param else 'recent' policy = sample_param['strategy'] if 'strategy' in sample_param else 'recent'
no_neg = sample_param['no_neg'] if 'no_neg' in sample_param else False
if policy != 'recent':
policy_train = args.sample_type#'boundery_recent_decay' policy_train = args.sample_type#'boundery_recent_decay'
else:
policy_train = policy
if memory_param['type'] != 'none': if memory_param['type'] != 'none':
mailbox = SharedMailBox(graph.ids.shape[0], memory_param, dim_edge_feat = graph.efeat.shape[1] if graph.efeat is not None else 0, mailbox = SharedMailBox(graph.ids.shape[0], memory_param, dim_edge_feat = graph.efeat.shape[1] if graph.efeat is not None else 0,
shared_nodes_index=graph.shared_nids_list[ctx.memory_group_rank],device = torch.device('cuda:{}'.format(local_rank)),cache_route = cache_route,shared_ssim=args.shared_memory_ssim) shared_nodes_index=graph.shared_nids_list[ctx.memory_group_rank],device = torch.device('cuda:{}'.format(local_rank)),shared_ssim=args.shared_memory_ssim)
else: else:
mailbox = None mailbox = None
sampler = NeighborSampler(num_nodes=graph.num_nodes, num_layers=num_layers, fanout=fanout,graph_data=sample_graph, workers=10,policy = policy_train, graph_name = "train",local_part=dist.get_rank(),edge_part=DistIndex(graph.eids_mapper).part,node_part=DistIndex(graph.nids_mapper).part,probability=args.probability) sampler = NeighborSampler(num_nodes=graph.num_nodes, num_layers=num_layers, fanout=fanout,graph_data=sample_graph, workers=10,policy = policy_train, graph_name = "train",local_part=dist.get_rank(),edge_part=DistIndex(graph.eids_mapper).part,node_part=DistIndex(graph.nids_mapper).part,probability=args.probability,no_neg=no_neg)
eval_sampler = NeighborSampler(num_nodes=graph.num_nodes, num_layers=num_layers, fanout=fanout,graph_data=eval_sample_graph, workers=10,policy = policy_train, graph_name = "eval",local_part=dist.get_rank(),edge_part=DistIndex(graph.eids_mapper).part,node_part=DistIndex(graph.nids_mapper).part,probability=args.probability) eval_sampler = NeighborSampler(num_nodes=graph.num_nodes, num_layers=num_layers, fanout=fanout,graph_data=eval_sample_graph, workers=10,policy = policy_train, graph_name = "eval",local_part=dist.get_rank(),edge_part=DistIndex(graph.eids_mapper).part,node_part=DistIndex(graph.nids_mapper).part,probability=args.probability,no_neg=no_neg)
train_data = torch.masked_select(graph.edge_index,train_mask.to(graph.edge_index.device)).reshape(2,-1) train_data = torch.masked_select(graph.edge_index,train_mask.to(graph.edge_index.device)).reshape(2,-1)
train_ts = torch.masked_select(graph.ts,train_mask.to(graph.edge_index.device)) train_ts = torch.masked_select(graph.ts,train_mask.to(graph.edge_index.device))
...@@ -272,8 +275,10 @@ def main(): ...@@ -272,8 +275,10 @@ def main():
#train_neg_sampler = LocalNegativeSampling('triplet',amount = args.neg_samples,dst_node_list = full_dst.unique()) #train_neg_sampler = LocalNegativeSampling('triplet',amount = args.neg_samples,dst_node_list = full_dst.unique())
train_neg_sampler = LocalNegativeSampling('triplet',amount = args.neg_samples,dst_node_list = full_dst.unique(),local_mask=(DistIndex(graph.nids_mapper[full_dst.unique()].to('cpu')).part == dist.get_rank()),prob=args.probability) train_neg_sampler = LocalNegativeSampling('triplet',amount = args.neg_samples,dst_node_list = full_dst.unique(),local_mask=(DistIndex(graph.nids_mapper[full_dst.unique()].to('cpu')).part == dist.get_rank()),prob=args.probability)
remote_ratio = train_neg_sampler.local_dst.shape[0] / train_neg_sampler.dst_node_list.shape[0] remote_ratio = train_neg_sampler.local_dst.shape[0] / train_neg_sampler.dst_node_list.shape[0]
train_ratio_pos = (1 - args.probability) + args.probability * remote_ratio #train_ratio_pos = (1 - args.probability) + args.probability * remote_ratio
train_ratio_neg = args.probability * (1-remote_ratio) #train_ratio_neg = args.probability * (1-remote_ratio)
train_ratio_pos = 1.0/(1-args.probability+ args.probability * remote_ratio) if ((args.probability <1) & (args.probability > 0)) else 1
train_ratio_neg = 1.0/(args.probability*remote_ratio) if ((args.probability <1) & (args.probability > 0)) else 1
print(train_neg_sampler.dst_node_list) print(train_neg_sampler.dst_node_list)
neg_sampler = LocalNegativeSampling('triplet',amount= neg_samples,dst_node_list = full_dst.unique(),seed=args.seed) neg_sampler = LocalNegativeSampling('triplet',amount= neg_samples,dst_node_list = full_dst.unique(),seed=args.seed)
...@@ -402,7 +407,8 @@ def main(): ...@@ -402,7 +407,8 @@ def main():
aps.append(average_precision_score(y_true, y_pred.detach().numpy())) aps.append(average_precision_score(y_true, y_pred.detach().numpy()))
aucs_mrrs.append(roc_auc_score(y_true, y_pred)) aucs_mrrs.append(roc_auc_score(y_true, y_pred))
mailbox.update_shared() mailbox.update_shared()
mailbox.update_p2p() mailbox.update_p2p_mem()
mailbox.update_p2p_mail()
""" """
if mailbox is not None: if mailbox is not None:
src = metadata['src_pos_index'] src = metadata['src_pos_index']
...@@ -536,7 +542,7 @@ def main(): ...@@ -536,7 +542,7 @@ def main():
optimizer.zero_grad() optimizer.zero_grad()
ones = torch.ones(metadata['dst_neg_index'].shape[0],device = model.device,dtype=torch.float) ones = torch.ones(metadata['dst_neg_index'].shape[0],device = model.device,dtype=torch.float)
weight = torch.where(DistIndex(mfgs[0][0].srcdata['ID'][metadata['dst_neg_index']]).part == torch.distributed.get_rank(),ones/train_ratio_pos,ones/train_ratio_neg).reshape(-1,1) weight = torch.where(DistIndex(mfgs[0][0].srcdata['ID'][metadata['dst_neg_index']]).part == torch.distributed.get_rank(),ones*train_ratio_pos,ones*train_ratio_neg).reshape(-1,1)
pred_pos, pred_neg = model(mfgs,metadata,neg_samples=args.neg_samples,async_param = param) pred_pos, pred_neg = model(mfgs,metadata,neg_samples=args.neg_samples,async_param = param)
loss = creterion(pred_pos, torch.ones_like(pred_pos)) loss = creterion(pred_pos, torch.ones_like(pred_pos))
neg_creterion = torch.nn.BCEWithLogitsLoss(weight) neg_creterion = torch.nn.BCEWithLogitsLoss(weight)
...@@ -554,7 +560,8 @@ def main(): ...@@ -554,7 +560,8 @@ def main():
#train_aps.append(average_precision_score(y_true, y_pred.detach().numpy())) #train_aps.append(average_precision_score(y_true, y_pred.detach().numpy()))
#torch.cuda.synchronize() #torch.cuda.synchronize()
mailbox.update_shared() mailbox.update_shared()
mailbox.update_p2p() mailbox.update_p2p_mem()
mailbox.update_p2p_mail()
#torch.cuda.empty_cache() #torch.cuda.empty_cache()
""" """
......
...@@ -18,7 +18,7 @@ class CachePushRoute(): ...@@ -18,7 +18,7 @@ class CachePushRoute():
def __init__(self,local_num, local_drop_edge, dist_index): def __init__(self,local_num, local_drop_edge, dist_index):
ctx = DistributedContext.get_default_context() ctx = DistributedContext.get_default_context()
dist_drop = dist_index[local_drop_edge] dist_drop = dist_index[local_drop_edge]
id_mask = ~(DistIndex(dist_drop) == ctx.memory_group_rank) id_mask = not (DistIndex(dist_drop) == ctx.memory_group_rank)
remote_id = local_drop_edge[id_mask] remote_id = local_drop_edge[id_mask]
route = torch.zeros(local_num,dtype=torch.long) route = torch.zeros(local_num,dtype=torch.long)
src_mask = DistIndex(dist_drop[0,:]).part == ctx.memory_group_rank src_mask = DistIndex(dist_drop[0,:]).part == ctx.memory_group_rank
......
...@@ -352,31 +352,60 @@ class TransformerMemoryUpdater(torch.nn.Module): ...@@ -352,31 +352,60 @@ class TransformerMemoryUpdater(torch.nn.Module):
rst = self.dropout(rst) rst = self.dropout(rst)
rst = torch.nn.functional.relu(rst) rst = torch.nn.functional.relu(rst)
return rst return rst
"""
(self,index,memory,memory_ts,
mail_index,mail,mail_ts,
reduce_Op=None,
async_op=True,
set_p2p=False,
mode=None,
filter=None,
wait_submit=True,
spread_mail=True,
update_cross_mm = False)
"""
class AsyncMemeoryUpdater(torch.nn.Module): class AsyncMemeoryUpdater(torch.nn.Module):
def all_update_func(self,index,memory,memory_ts,mail,mail_ts,nxt_fetch_func):
self.mailbox.set_memory_all_reduce(index,memory,memory_ts,mail,mail_ts,reduce_Op = 'max', async_op = False,filter=None,mode='all_reduce',set_remote=True) def all_update_func(self,index,memory,memory_ts,mail_index,mail,mail_ts,nxt_fetch_func,spread_mail=False):
if nxt_fetch_func is not None: self.mailbox.set_memory_all_reduce(
nxt_fetch_func() index,memory,memory_ts,
def p2p_func(self,index,memory,memory_ts,mail,mail_ts,nxt_fetch_func): mail_index,mail,mail_ts,
self.mailbox.handle_last_async() reduce_Op='max',async_op=False,
submit_to_queue = False mode='all_reduce',
if nxt_fetch_func is not None: wait_submit=False,spread_mail=spread_mail,
nxt_fetch_func() update_cross_mm=True,
submit_to_queue = True )
self.mailbox.set_memory_all_reduce(index,memory,memory_ts,mail,mail_ts,reduce_Op = 'max', async_op = True,filter=None,mode=None,set_remote=True,submit = submit_to_queue) #self.mailbox.set_memory_all_reduce(index,memory,memory_ts,mail,mail_ts,reduce_Op = 'max', async_op = False,filter=None,mode='all_reduce',set_remote=True)
def all_reduce_func(self,index,memory,memory_ts,mail,mail_ts,nxt_fetch_func):
self.mailbox.set_memory_all_reduce(index,memory,memory_ts,mail,mail_ts,reduce_Op = 'max', async_op = False,filter=None,mode='all_reduce',set_remote=False)
if nxt_fetch_func is not None: if nxt_fetch_func is not None:
nxt_fetch_func() nxt_fetch_func()
def historical_func(self,index,memory,memory_ts,mail,mail_ts,nxt_fetch_func): # def p2p_func(self,index,memory,memory_ts,mail,mail_ts,nxt_fetch_func,mail_index = None):
# self.mailbox.handle_last_async()
# submit_to_queue = False
# if nxt_fetch_func is not None:
# nxt_fetch_func()
# submit_to_queue = True
# self.mailbox.set_memory_all_reduce(index,memory,memory_ts,mail,mail_ts,reduce_Op = 'max', async_op = True,filter=None,mode=None,set_remote=True,submit = submit_to_queue)
# def all_reduce_func(self,index,memory,memory_ts,mail,mail_ts,nxt_fetch_func,mail_index = None):
# self.mailbox.set_memory_all_reduce(index,memory,memory_ts,mail,mail_ts,reduce_Op = 'max', async_op = False,filter=None,mode='all_reduce',set_remote=False)
# if nxt_fetch_func is not None:
# nxt_fetch_func()
def historical_func(self,index,memory,memory_ts,mail_index,mail,mail_ts,nxt_fetch_func,spread_mail=False):
self.mailbox.sychronize_shared() self.mailbox.sychronize_shared()
self.mailbox.handle_last_async()
submit_to_queue = False submit_to_queue = False
if nxt_fetch_func is not None: if nxt_fetch_func is not None:
nxt_fetch_func() nxt_fetch_func()
submit_to_queue = True submit_to_queue = True
self.mailbox.set_memory_all_reduce(index,memory,memory_ts,mail,mail_ts,reduce_Op = 'max', async_op = False,filter=None,mode='historical',set_remote=False,submit = submit_to_queue) self.mailbox.set_memory_all_reduce(
def local_func(self,index,memory,memory_ts,mail,mail_ts,nxt_fetch_func): index,memory,memory_ts,
mail_index,mail,mail_ts,
reduce_Op='max',async_op=True,
mode='historical',
wait_submit=submit_to_queue,spread_mail=spread_mail,
update_cross_mm=False,
)
def local_func(self,index,memory,memory_ts,mail_index,mail,mail_ts,nxt_fetch_func,spread_mail=False):
if nxt_fetch_func is not None: if nxt_fetch_func is not None:
nxt_fetch_func() nxt_fetch_func()
def transformer_updater(self,b): def transformer_updater(self,b):
...@@ -473,7 +502,7 @@ class AsyncMemeoryUpdater(torch.nn.Module): ...@@ -473,7 +502,7 @@ class AsyncMemeoryUpdater(torch.nn.Module):
None) None)
#print(index.shape[0]) #print(index.shape[0])
if param[0]: if param[0]:
index, mail, mail_ts = self.mailbox.get_update_mail( index0, mail, mail_ts = self.mailbox.get_update_mail(
b.srcdata['ID'],src,dst,ts,edge_feats, b.srcdata['ID'],src,dst,ts,edge_feats,
self.last_updated_memory, self.last_updated_memory,
None,False,False,block=b None,False,False,block=b
...@@ -483,9 +512,12 @@ class AsyncMemeoryUpdater(torch.nn.Module): ...@@ -483,9 +512,12 @@ class AsyncMemeoryUpdater(torch.nn.Module):
self.mailbox.mon.add(index,self.mailbox.node_memory.accessor.data[index],memory) self.mailbox.mon.add(index,self.mailbox.node_memory.accessor.data[index],memory)
##print(index.shape,memory.shape,memory_ts.shape,mail.shape,mail_ts.shape) ##print(index.shape,memory.shape,memory_ts.shape,mail.shape,mail_ts.shape)
local_mask = (DistIndex(index).part==torch.distributed.get_rank()) local_mask = (DistIndex(index).part==torch.distributed.get_rank())
self.mailbox.set_mailbox_local(DistIndex(index[local_mask]).loc,mail[local_mask],mail_ts[local_mask],Reduce_Op = 'max') local_mask_mail = (DistIndex(index0).part==torch.distributed.get_rank())
self.mailbox.set_memory_local(DistIndex(index[local_mask]).loc,memory[local_mask],memory_ts[local_mask], Reduce_Op = 'max')
self.update_hunk(index,memory,memory_ts,mail,mail_ts,nxt_fetch_func) #self.mailbox.set_mailbox_local(DistIndex(index0[local_mask_mail]).loc,mail[local_mask_mail],mail_ts[local_mask_mail],Reduce_Op = 'max')
#self.mailbox.set_memory_local(DistIndex(index[local_mask]).loc,memory[local_mask],memory_ts[local_mask], Reduce_Op = 'max')
is_deliver=(self.mailbox.deliver_to == 'neighbors')
self.update_hunk(index,memory,memory_ts,index0,mail,mail_ts,nxt_fetch_func,spread_mail= is_deliver)
if self.memory_param['combine_node_feature'] and self.dim_node_feat > 0: if self.memory_param['combine_node_feature'] and self.dim_node_feat > 0:
if self.dim_node_feat == self.dim_hid: if self.dim_node_feat == self.dim_hid:
......
...@@ -97,6 +97,7 @@ class NeighborSampler(BaseSampler): ...@@ -97,6 +97,7 @@ class NeighborSampler(BaseSampler):
node_part = None, node_part = None,
edge_part = None, edge_part = None,
probability = 1, probability = 1,
no_neg = False,
) -> None: ) -> None:
r"""__init__ r"""__init__
Args: Args:
...@@ -122,7 +123,7 @@ class NeighborSampler(BaseSampler): ...@@ -122,7 +123,7 @@ class NeighborSampler(BaseSampler):
self.is_distinct = is_distinct self.is_distinct = is_distinct
assert graph_name is not None assert graph_name is not None
self.graph_name = graph_name self.graph_name = graph_name
self.no_neg = no_neg
if(tnb is None): if(tnb is None):
if(graph_data.edge_ts is not None): if(graph_data.edge_ts is not None):
timestamp,ind = graph_data.edge_ts.sort() timestamp,ind = graph_data.edge_ts.sort()
...@@ -314,7 +315,10 @@ class NeighborSampler(BaseSampler): ...@@ -314,7 +315,10 @@ class NeighborSampler(BaseSampler):
else: else:
seed, inverse_seed = seed.unique(return_inverse=True) seed, inverse_seed = seed.unique(return_inverse=True)
""" """
out,metadata = self.sample_from_nodes(seed, seed_ts, is_unique=False) if self.no_neg:
out,metadata = self.sample_from_nodes(seed[:seed.shape[0]//3*2], seed_ts[:seed.shape[0]//3*2], is_unique=False)
else:
out,metadata = self.sample_from_nodes(seed[:seed.shape[0]], seed_ts[:seed.shape[0]//3*2], is_unique=False)
src_pos_index = torch.arange(0,num_pos,dtype= torch.long,device=out_device) src_pos_index = torch.arange(0,num_pos,dtype= torch.long,device=out_device)
dst_pos_index = torch.arange(num_pos,2*num_pos,dtype= torch.long,device=out_device) dst_pos_index = torch.arange(num_pos,2*num_pos,dtype= torch.long,device=out_device)
if neg_sampling.is_triplet() or neg_sampling.is_tgbtriplet(): if neg_sampling.is_triplet() or neg_sampling.is_tgbtriplet():
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment