Commit e713643f by zhlj

update memory

parent e0711e14
@@ -9,6 +9,7 @@ __pycache__/
*.so
# Distribution / packaging
examples/all*
.Python
examples/all*
build/
......
@@ -14,6 +14,8 @@ memory:
deliver_to: 'neighbors'
mail_combine: 'last'
memory_update: 'transformer'
historical_fix: False
async: True
attention_head: 2
mailbox_size: 10
combine_node_feature: False
@@ -22,7 +24,7 @@ gnn:
- arch: 'identity'
train:
- epoch: 100
-batch_size: 600
+batch_size: 1000
lr: 0.0001
dropout: 0.1
att_dropout: 0.1
......
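For orientation, the two new memory keys (historical_fix, async) are read on the training side like any other entry of the memory section. A minimal sketch, not part of the commit, assuming PyYAML and a config laid out as above; the file name is a placeholder:

import yaml

# Hedged sketch: reading the new memory flags from a config shaped like the one above.
with open('config/TGN.yml') as f:          # placeholder path, not from the commit
    cfg = yaml.safe_load(f)
memory_param = cfg['memory'][0] if isinstance(cfg['memory'], list) else cfg['memory']
historical_fix = memory_param.get('historical_fix', False)   # regularise/freeze the historical memory
use_async = memory_param.get('async', True)                  # asynchronous memory updates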
@@ -7,6 +7,8 @@ memory:
deliver_to: 'self'
mail_combine: 'last'
memory_update: 'rnn'
historical_fix: False
async: True
mailbox_size: 1
combine_node_feature: True
dim_out: 100
@@ -16,8 +18,8 @@ gnn:
use_dst_emb: False
time_transform: 'JODIE'
train:
-- epoch: 100
+- epoch: 250
batch_size: 1000
-lr: 0.0001
+lr: 0.0002
dropout: 0.1
all_on_gpu: True
\ No newline at end of file
@@ -27,7 +27,7 @@ gnn:
dim_time: 100
dim_out: 100
train:
-- epoch: 100
+- epoch: 200
batch_size: 1000
# reorder: 16
lr: 0.0004
......
@@ -4,11 +4,11 @@ import torch
# Read the file contents
ssim_values = [0, 0.1, 0.2, 0.3, 0.4, 2] # assume these are your ssim parameter values
probability_values = [1,0.5,0.1,0.05,0.01,0]
-data_values = ['WIKI'] # holds the data read from the files
+data_values = ['WikiTalk','StackOverflow'] # holds the data read from the files
partition = 'ours_shared'
# Read the data from files; assume the data is stored in data.txt
#all/"$data"/"$partitions"-ours_shared-0.01-"$mem"-"$ssim"-"$sample".out
-partitions=4
+partitions=8
topk=0.01
mem='all_update'#'historical'
for data in data_values:
@@ -35,11 +35,12 @@ for data in data_values:
# Draw the bar chart
print('{} TestAP={}\n'.format(data,ap_list))
bar_width = 0.4
#shared comm tensor
# Set the bar positions
-bars = range(len(ssim_values))
+bars = range(len(probability_values))
# Draw the bar chart
plt.bar([b for b in bars], ap_list, width=bar_width)
@@ -49,7 +50,7 @@ for data in data_values:
plt.xlabel('probability')
plt.ylabel('Test AP')
plt.title('{}({} partitions)'.format(data,partitions))
-plt.savefig('boundary_AP_{}.png'.format(data))
+plt.savefig('boundary_AP_{}_{}.png'.format(data,partitions))
plt.clf()
plt.bar([b for b in bars], comm_list, width=bar_width)
@@ -58,7 +59,7 @@ for data in data_values:
plt.xlabel('probability')
plt.ylabel('Communication volume')
plt.title('{}({} partitions)'.format(data,partitions))
-plt.savefig('boundary_comm_{}.png'.format(data))
+plt.savefig('boundary_comm_{}_{}.png'.format(data,partitions))
plt.clf()
if partition == 'ours_shared':
@@ -76,5 +77,5 @@ for data in data_values:
plt.title('{}({} partitions)'.format(data,partitions))
# plt.grid(True)
plt.legend()
-plt.savefig('{}_boundary_Convergence_rate.png'.format(data))
+plt.savefig('{}_{}_boundary_Convergence_rate.png'.format(data,partitions))
plt.clf()
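The plotting scripts in this commit scrape each run's .out log for the final test AP and the shared-communication counter. A minimal sketch of that parsing step, not part of the commit; the exact log format ("best test AP:0.98", "shared comm tensor (12345)") is an assumption inferred from the slicing in the script above:

# Hedged sketch: scrape one .out log for the metrics these plots use.
def parse_run_log(path):
    ap, comm = None, None
    with open(path, 'r') as f:
        for line in f:
            if line.startswith('best test AP:'):
                ap = float(line[len('best test AP:'):])
            pos = line.find('shared comm tensor')
            if pos != -1:
                # same slice as the script; assumes the counter format shown above
                comm = int(line[pos + 2 + len('shared comm tensor'):len(line) - 3])
    return ap, comm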
@@ -2,12 +2,13 @@ import matplotlib.pyplot as plt
import numpy as np
import torch
# Read the file contents
-ssim_values = [0, 0.5, 1.0, 1.5, 2] # assume these are your ssim parameter values
+ssim_values = [-1,0,0.3,0.7,2] # assume these are your ssim parameter values
-data_values = ['WIKI','WikiTalk','REDDIT','LASTFM','DGraphFin'] # holds the data read from the files
+data_values = ['WIKI','LASTFM','WikiTalk','REDDIT','LASTFM','DGraphFin'] # holds the data read from the files
partition = 'ours_shared'
# Read the data from files; assume the data is stored in data.txt
#all/"$data"/"$partitions"-ours_shared-0.01-"$mem"-"$ssim"-"$sample".out
partitions=4
model = 'JODIE'
topk=0.01
mem='historical'
for data in data_values:
@@ -15,9 +16,11 @@ for data in data_values:
comm_list = []
for ssim in ssim_values:
if ssim == 2:
-file = '{}/{}-{}-{}-local-recent.out'.format(data,partitions,partition,topk)
+file = '{}/{}/{}-{}-{}-local-recent.out'.format(data,model,partitions,partition,topk)
elif ssim == -1:
file = '{}/{}/{}-{}-{}-all_update-recent.out'.format(data,model,partitions,partition,topk)
else:
-file = '{}/{}-{}-{}-{}-{}-recent.out'.format(data,partitions,partition,topk,mem,ssim)
+file = '{}/{}/{}-{}-{}-{}-{}-recent.out'.format(data,model,partitions,partition,topk,mem,ssim)
prefix = 'best test AP:'
with open(file, 'r') as file:
for line in file:
@@ -26,6 +29,7 @@ for data in data_values:
pos = line.find('shared comm tensor')
if(pos!=-1):
comm = int(line[pos+2+len('shared comm tensor'):len(line)-3])
print(ap)
ap_list.append(ap)
comm_list.append(comm)
print('{} TestAP={}\n'.format(data,ap_list))
@@ -33,7 +37,7 @@ for data in data_values:
# Draw the bar chart
bar_width = 0.4
#shared comm tensor
print('{} TestAP={}\n'.format(data,ap_list))
# Set the bar positions
bars = range(len(ssim_values))
@@ -43,8 +47,10 @@ for data in data_values:
plt.xticks([b for b in bars], ssim_values)
plt.xlabel('SSIM threshold Values')
plt.ylabel('Test AP')
#if(data=='WIKI'):
# plt.ylim([0.97,1])
plt.title('{}({} partitions)'.format(data,partitions))
-plt.savefig('ssim_{}_{}.png'.format(data,partitions))
+plt.savefig('ssim_{}_{}_{}.png'.format(data,partitions,model))
plt.clf()
plt.bar([b for b in bars], comm_list, width=bar_width)
@@ -53,7 +59,7 @@ for data in data_values:
plt.xlabel('SSIM threshold Values')
plt.ylabel('Communication volume')
plt.title('{}({} partitions)'.format(data,partitions))
-plt.savefig('ssim_comm_{}_{}.png'.format(data,partitions))
+plt.savefig('ssim_comm_{}_{}_{}.png'.format(data,partitions,model))
plt.clf()
if partition == 'ours_shared':
@@ -62,18 +68,28 @@ for data in data_values:
partition0=partition
for ssim in ssim_values:
if ssim == 2:
-file = '{}/val_{}_{}_{}_0_recent_0.1_local_2.pt'.format(data,partition0,topk,partitions,)
+file = '{}/{}/test_{}_{}_{}_0_recent_0.1_local_2.pt'.format(data,model,partition0,topk,partitions,)
elif ssim == -1:
file = '{}/{}/test_{}_{}_{}_0_recent_0.1_all_update_2.pt'.format(data,model,partition0,topk,partitions,)
else:
-file = '{}/val_{}_{}_{}_0_recent_0.1_{}_{}.pt'.format(data,partition0,topk,partitions,mem,float(ssim))
+file = '{}/{}/test_{}_{}_{}_0_recent_0.1_{}_{}.pt'.format(data,model,partition0,topk,partitions,mem,float(ssim))
-val_ap = torch.tensor(torch.load(file))
+val_ap = torch.tensor(torch.load(file))[:,0]
print(val_ap)
epoch = torch.arange(val_ap.shape[0])
# Plot the curve
-print(val_ap)
+#print(val_ap)
-plt.plot(epoch,val_ap, label='ssim={}'.format(ssim))
+if ssim == -1:
plt.plot(epoch,val_ap, label='all-update')
elif ssim == 2:
plt.plot(epoch,val_ap, label='local')
else:
plt.plot(epoch,val_ap, label='ssim = {}'.format(ssim))
if(data=='WIKI'):
plt.ylim([0.85,0.90])
plt.xlabel('Epoch')
plt.ylabel('Val AP')
plt.title('{}({} partitions)'.format(data,partitions))
# plt.grid(True)
plt.legend()
-plt.savefig('{}_{}_ssim_Convergence_rate.png'.format(data,partitions))
+plt.savefig('{}_{}_{}_ssim_Convergence_rate.png'.format(data,partitions,model))
plt.clf()
#!/bin/bash
# Define array variables
-addr="192.168.1.107"
+addr="192.168.1.105"
-partition_params=("dis_tgl" "ours" "metis" "random")
+partition_params=("ours" )
#"metis" "ldg" "random")
#("ours" "metis" "ldg" "random")
partitions="4"
@@ -12,14 +12,14 @@ node_rank="0"
probability_params=("1" "0.5" "0.1" "0.05" "0.01" "0")
#sample_type_params=("recent" "boundery_recent_decay" "boundery_recent_uniform")
#sample_type_params=("recent" "boundery_recent_decay") #"boundery_recent_uniform")
-sample_type_params=("recent")
+sample_type_params=("recent" "boundery_recent_decay" "boundery_recent_uniform")
#memory_type=("all_update" "p2p" "all_reduce" "historical" "local")
-memory_type=("all_update" "local" "historical")
+memory_type=("all_update")
#"historical" "all_update") #"local" "historical")
#memory_type=("local" "all_update" "historical" "all_reduce")
-shared_memory_ssim=("0" "0.5" "1.0" "1.5")
+shared_memory_ssim=("0" "0.3" "0.5" "0.7")
#data_param=("WIKI" "REDDIT" "LASTFM" "WikiTalk")
-data_param=("LASTFM")
+data_param=("WIKI" "LASTFM")
#data_param=("WIKI" "REDDIT" "LASTFM" "DGraphFin" "WikiTalk" "StackOverflow")
#data_param=("WIKI" "REDDIT" "LASTFM" "WikiTalk" "StackOverflow")
#data_param=("REDDIT" "WikiTalk")
@@ -28,14 +28,84 @@ mkdir -p all
# Iterate over the arrays and run the commands
for data in "${data_param[@]}"; do
-model="TGN_large"
+model="TGN"
if [ "$data" = "WIKI" ] || [ "$data" = "REDDIT" ] || [ "$data" = "LASTFM" ]; then
model="TGN"
fi
#model="APAN"
mkdir all/"$data"
mkdir all/"$data"/"$model"
mkdir all/"$data"/"$model"/comm
#torchrun --nnodes "$nnodes" --node_rank 0 --nproc-per-node 1 --master-addr "$addr" --master-port 9445 train_boundery.py --dataname "$data" --mode "$model" --partition ours --memory_type local --sample_type recent --topk 0 > all/"$data"/"$model"/1.out &
wait
for partition in "${partition_params[@]}"; do
for sample in "${sample_type_params[@]}"; do
if [ "$sample" = "recent" ]; then
for mem in "${memory_type[@]}"; do
if [ "$mem" = "historical" ]; then
for ssim in "${shared_memory_ssim[@]}"; do
if [ "$partition" = "ours" ]; then
torchrun --nnodes "$nnodes" --node_rank "$node_rank" --nproc-per-node "$node_per" --master-addr "$addr" --master-port 9445 train_boundery.py --dataname "$data" --mode "$model" --partition "$partition" --topk 0.01 --sample_type "$sample" --memory_type "$mem" --shared_memory_ssim "$ssim" > all/"$data"/"$model"/"$partitions"-ours_shared-0.01-"$mem"-"$ssim"-"$sample".out &
wait
fi
done
elif [ "$mem" = "all_reduce" ]; then
if [ "$partition" = "ours" ]; then
torchrun --nnodes "$nnodes" --node_rank "$node_rank" --nproc-per-node "$node_per" --master-addr "$addr" --master-port 9445 train_boundery.py --dataname "$data" --mode "$model" --partition "$partition" --topk 0.01 --sample_type "$sample" --memory_type "$mem" > all/"$data"/"$model"/"$partitions"-ours_shared-0.01-"$mem"-"$sample".out &
wait
fi
else
torchrun --nnodes "$nnodes" --node_rank "$node_rank" --nproc-per-node "$node_per" --master-addr "$addr" --master-port 9445 train_boundery.py --dataname "$data" --mode "$model" --partition "$partition" --topk 0 --sample_type "$sample" --memory_type "$mem" > all/"$data"/"$model"/"$partitions"-"$partition"-0-"$mem"-"$sample".out &
wait
if [ "$partition" = "ours" ] && [ "$mem" != "all_local" ]; then
torchrun --nnodes "$nnodes" --node_rank "$node_rank" --nproc-per-node "$node_per" --master-addr "$addr" --master-port 9445 train_boundery.py --dataname "$data" --mode "$model" --partition "$partition" --topk 0.01 --sample_type "$sample" --memory_type "$mem" > all/"$data"/"$model"/"$partitions"-ours_shared-0.01-"$mem"-"$sample".out &
wait
fi
fi
done
else
for pro in "${probability_params[@]}"; do
for mem in "${memory_type[@]}"; do
if [ "$mem" = "historical" ]; then
continue
# for ssim in "${shared_memory_ssim[@]}"; do
# if [ "$partition" = "ours" ]; then
# torchrun --nnodes "$nnodes" --node_rank "$node_rank" --nproc-per-node "$node_per" --master-addr "$addr" --master-port 9445 train_boundery.py --dataname "$data" --mode "$model" --partition "$partition" --topk 0.01 --sample_type "$sample" --probability "$pro" --memory_type "$mem" --shared_memory_ssim "$ssim" > all/"$data"/"$partitions"-ours_shared-0.01"$mem"-"$ssim"-"$sample"-"$pro".out &
# wait
# fi
# done
elif [ "$mem" = "all_reduce" ]; then
if [ "$partition" = "ours"]; then
torchrun --nnodes "$nnodes" --node_rank "$node_rank" --nproc-per-node "$node_per" --master-addr "$addr" --master-port 9445 train_boundery.py --dataname "$data" --mode "$model" --partition "$partition" --topk 0.01 --sample_type "$sample" --probability "$pro" --memory_type "$mem" > all/"$data"/"$model"/"$partitions"-ours_shared-0.01-"$mem"-"$sample"-"$pro".out&
wait
fi
else
torchrun --nnodes "$nnodes" --node_rank "$node_rank" --nproc-per-node "$node_per" --master-addr "$addr" --master-port 9445 train_boundery.py --dataname "$data" --mode "$model" --partition "$partition" --topk 0 --sample_type "$sample" --probability "$pro" --memory_type "$mem" > all/"$data"/"$model"/"$partitions"-"$partition"-0-"$mem"-"$sample"-"$pro".out &
wait
if [ "$partition" = "ours" ] && [ "$mem" != "all_local" ]; then
torchrun --nnodes "$nnodes" --node_rank "$node_rank" --nproc-per-node "$node_per" --master-addr "$addr" --master-port 9445 train_boundery.py --dataname "$data" --mode "$model" --partition "$partition" --topk 0.01 --sample_type "$sample" --probability "$pro" --memory_type "$mem" > all/"$data"/"$model"/"$partitions"-ours_shared-0.01-"$mem"-"$sample"-"$pro".out &
wait
fi
fi
done
done
fi
done
done
done
for data in "${data_param[@]}"; do
model="JODILE"
if [ "$data" = "WIKI" ] || [ "$data" = "REDDIT" ] || [ "$data" = "LASTFM" ]; then
model="JODIE"
fi
#model="APAN"
mkdir all/"$data" mkdir all/"$data"
mkdir all/"$data"/comm mkdir all/"$data"/"$model"
#torchrun --nnodes "$nnodes" --node_rank 0 --nproc-per-node 1 --master-addr "$addr" --master-port 9445 train_boundery.py --dataname "$data" --mode TGN_large --partition ours --memory_type local --sample_type recent --topk 0 > all/"$data"/1.out & mkdir all/"$data"/"$model"/comm
#torchrun --nnodes "$nnodes" --node_rank 0 --nproc-per-node 1 --master-addr "$addr" --master-port 9445 train_boundery.py --dataname "$data" --mode "$model" --partition ours --memory_type local --sample_type recent --topk 0 > all/"$data"/"$model"/1.out &
wait wait
for partition in "${partition_params[@]}"; do for partition in "${partition_params[@]}"; do
for sample in "${sample_type_params[@]}"; do for sample in "${sample_type_params[@]}"; do
@@ -44,20 +114,20 @@ for data in "${data_param[@]}"; do
if [ "$mem" = "historical" ]; then
for ssim in "${shared_memory_ssim[@]}"; do
if [ "$partition" = "ours" ]; then
-torchrun --nnodes "$nnodes" --node_rank "$node_rank" --nproc-per-node "$node_per" --master-addr "$addr" --master-port 9445 train_boundery.py --dataname "$data" --mode "$model" --partition "$partition" --topk 0.01 --sample_type "$sample" --memory_type "$mem" --shared_memory_ssim "$ssim" > all/"$data"/"$partitions"-ours_shared-0.01-"$mem"-"$ssim"-"$sample".out &
+torchrun --nnodes "$nnodes" --node_rank "$node_rank" --nproc-per-node "$node_per" --master-addr "$addr" --master-port 9445 train_boundery.py --dataname "$data" --mode "$model" --partition "$partition" --topk 0.01 --sample_type "$sample" --memory_type "$mem" --shared_memory_ssim "$ssim" > all/"$data"/"$model"/"$partitions"-ours_shared-0.01-"$mem"-"$ssim"-"$sample".out &
wait
fi
done
elif [ "$mem" = "all_reduce" ]; then
if [ "$partition" = "ours" ]; then
-torchrun --nnodes "$nnodes" --node_rank "$node_rank" --nproc-per-node "$node_per" --master-addr "$addr" --master-port 9445 train_boundery.py --dataname "$data" --mode "$model" --partition "$partition" --topk 0.01 --sample_type "$sample" --memory_type "$mem" > all/"$data"/"$partitions"-ours_shared-0.01-"$mem"-"$sample".out &
+torchrun --nnodes "$nnodes" --node_rank "$node_rank" --nproc-per-node "$node_per" --master-addr "$addr" --master-port 9445 train_boundery.py --dataname "$data" --mode "$model" --partition "$partition" --topk 0.01 --sample_type "$sample" --memory_type "$mem" > all/"$data"/"$model"/"$partitions"-ours_shared-0.01-"$mem"-"$sample".out &
wait
fi
else
-torchrun --nnodes "$nnodes" --node_rank "$node_rank" --nproc-per-node "$node_per" --master-addr "$addr" --master-port 9445 train_boundery.py --dataname "$data" --mode "$model" --partition "$partition" --topk 0 --sample_type "$sample" --memory_type "$mem" > all/"$data"/"$partitions"-"$partition"-0-"$mem"-"$sample".out &
+torchrun --nnodes "$nnodes" --node_rank "$node_rank" --nproc-per-node "$node_per" --master-addr "$addr" --master-port 9445 train_boundery.py --dataname "$data" --mode "$model" --partition "$partition" --topk 0 --sample_type "$sample" --memory_type "$mem" > all/"$data"/"$model"/"$partitions"-"$partition"-0-"$mem"-"$sample".out &
wait
if [ "$partition" = "ours" ] && [ "$mem" != "all_local" ]; then
-torchrun --nnodes "$nnodes" --node_rank "$node_rank" --nproc-per-node "$node_per" --master-addr "$addr" --master-port 9445 train_boundery.py --dataname "$data" --mode "$model" --partition "$partition" --topk 0.01 --sample_type "$sample" --memory_type "$mem" > all/"$data"/"$partitions"-ours_shared-0.01-"$mem"-"$sample".out &
+torchrun --nnodes "$nnodes" --node_rank "$node_rank" --nproc-per-node "$node_per" --master-addr "$addr" --master-port 9445 train_boundery.py --dataname "$data" --mode "$model" --partition "$partition" --topk 0.01 --sample_type "$sample" --memory_type "$mem" > all/"$data"/"$model"/"$partitions"-ours_shared-0.01-"$mem"-"$sample".out &
wait
fi
fi
@@ -75,14 +145,14 @@ for data in "${data_param[@]}"; do
# done
elif [ "$mem" = "all_reduce" ]; then
if [ "$partition" = "ours"]; then
-torchrun --nnodes "$nnodes" --node_rank "$node_rank" --nproc-per-node "$node_per" --master-addr "$addr" --master-port 9445 train_boundery.py --dataname "$data" --mode "$model" --partition "$partition" --topk 0.01 --sample_type "$sample" --probability "$pro" --memory_type "$mem" > all/"$data"/"$partitions"-ours_shared-0.01-"$mem"-"$sample"-"$pro".out&
+torchrun --nnodes "$nnodes" --node_rank "$node_rank" --nproc-per-node "$node_per" --master-addr "$addr" --master-port 9445 train_boundery.py --dataname "$data" --mode "$model" --partition "$partition" --topk 0.01 --sample_type "$sample" --probability "$pro" --memory_type "$mem" > all/"$data"/"$model"/"$partitions"-ours_shared-0.01-"$mem"-"$sample"-"$pro".out&
wait
fi
else
-torchrun --nnodes "$nnodes" --node_rank "$node_rank" --nproc-per-node "$node_per" --master-addr "$addr" --master-port 9445 train_boundery.py --dataname "$data" --mode "$model" --partition "$partition" --topk 0 --sample_type "$sample" --probability "$pro" --memory_type "$mem" > all/"$data"/"$partitions"-"$partition"-0-"$mem"-"$sample"-"$pro".out &
+torchrun --nnodes "$nnodes" --node_rank "$node_rank" --nproc-per-node "$node_per" --master-addr "$addr" --master-port 9445 train_boundery.py --dataname "$data" --mode "$model" --partition "$partition" --topk 0 --sample_type "$sample" --probability "$pro" --memory_type "$mem" > all/"$data"/"$model"/"$partitions"-"$partition"-0-"$mem"-"$sample"-"$pro".out &
wait
if [ "$partition" = "ours" ] && [ "$mem" != "all_local" ]; then
-torchrun --nnodes "$nnodes" --node_rank "$node_rank" --nproc-per-node "$node_per" --master-addr "$addr" --master-port 9445 train_boundery.py --dataname "$data" --mode "$model" --partition "$partition" --topk 0.01 --sample_type "$sample" --probability "$pro" --memory_type "$mem" > all/"$data"/"$partitions"-ours_shared-0.01-"$mem"-"$sample"-"$pro".out &
+torchrun --nnodes "$nnodes" --node_rank "$node_rank" --nproc-per-node "$node_per" --master-addr "$addr" --master-port 9445 train_boundery.py --dataname "$data" --mode "$model" --partition "$partition" --topk 0.01 --sample_type "$sample" --probability "$pro" --memory_type "$mem" > all/"$data"/"$model"/"$partitions"-ours_shared-0.01-"$mem"-"$sample"-"$pro".out &
wait
fi
fi
......
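The main change in this script is that every output path now includes the model name ("all/$data/$model/..."), which is what the updated plotting scripts expect. A hedged sketch, not part of the commit, that reconstructs the naming convention used by the redirections above (the helper name and defaults are illustrative):

# Hedged sketch: the .out path layout produced by the script above, e.g.
#   all/<data>/<model>/<partitions>-ours_shared-0.01-<mem>-<ssim>-<sample>.out
def out_path(data, model, partitions, mem, sample, ssim=None, topk='0.01'):
    parts = [str(partitions), 'ours_shared', str(topk), mem]
    if ssim is not None:                 # only the 'historical' runs carry an ssim field
        parts.append(str(ssim))
    parts.append(sample)
    return 'all/{}/{}/{}.out'.format(data, model, '-'.join(parts))

# e.g. out_path('WIKI', 'TGN', 4, 'historical', 'recent', ssim='0.3')
# -> 'all/WIKI/TGN/4-ours_shared-0.01-historical-0.3-recent.out'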
@@ -67,7 +67,7 @@ parser.add_argument('--probability', default=0.1, type=float, metavar='W',
help='name of model')
parser.add_argument('--sample_type', default='recent', type=str, metavar='W',
help='name of model')
-parser.add_argument('--local_neg_sample', default=True, type=bool, metavar='W',
+parser.add_argument('--local_neg_sample', default=False, type=bool, metavar='W',
help='name of model')
parser.add_argument('--shared_memory_ssim', default=2, type=float, metavar='W',
help='name of model')
@@ -196,7 +196,7 @@ def main():
torch.set_num_threads(10)
device_id = torch.cuda.current_device()
graph,full_sampler_graph,train_mask,val_mask,test_mask,full_train_mask,cache_route = load_from_speed(args.dataname,seed=123457,top=args.topk,sampler_graph_add_rev=True, feature_device=torch.device('cuda:{}'.format(ctx.local_rank)),partition=args.partition)#torch.device('cpu'))
torch.autograd.set_detect_anomaly(True)
# Make sure CUDA is available
if torch.cuda.is_available():
print("Total GPU memory: ", torch.cuda.get_device_properties(0).total_memory/1024**3)
@@ -230,7 +230,7 @@ def main():
mailbox = None
sampler = NeighborSampler(num_nodes=graph.num_nodes, num_layers=num_layers, fanout=fanout,graph_data=sample_graph, workers=10,policy = policy_train, graph_name = "train",local_part=dist.get_rank(),edge_part=DistIndex(graph.eids_mapper).part,node_part=DistIndex(graph.nids_mapper).part,probability=args.probability)
-eval_sampler = NeighborSampler(num_nodes=graph.num_nodes, num_layers=num_layers, fanout=fanout,graph_data=eval_sample_graph, workers=10,policy = 'recent', graph_name = "eval",local_part=dist.get_rank(),edge_part=DistIndex(graph.eids_mapper).part,node_part=DistIndex(graph.nids_mapper).part,probability=args.probability)
+eval_sampler = NeighborSampler(num_nodes=graph.num_nodes, num_layers=num_layers, fanout=fanout,graph_data=eval_sample_graph, workers=10,policy = policy_train, graph_name = "eval",local_part=dist.get_rank(),edge_part=DistIndex(graph.eids_mapper).part,node_part=DistIndex(graph.nids_mapper).part,probability=args.probability)
train_data = torch.masked_select(graph.edge_index,train_mask.to(graph.edge_index.device)).reshape(2,-1)
train_ts = torch.masked_select(graph.ts,train_mask.to(graph.edge_index.device))
@@ -262,8 +262,9 @@ def main():
train_neg_sampler = LocalNegativeSampling('triplet',amount = args.neg_samples,dst_node_list = graph.edge_index[1,mask].unique())
else:
train_neg_sampler = LocalNegativeSampling('triplet',amount = args.neg_samples,dst_node_list = full_dst.unique())
#train_neg_sampler = LocalNegativeSampling('triplet',amount = args.neg_samples,dst_node_list = full_dst.unique(),local_mask=(DistIndex(graph.nids_mapper[full_dst.unique()].to('cpu')).part == dist.get_rank()))
print(train_neg_sampler.dst_node_list)
-neg_sampler = LocalNegativeSampling('triplet',amount= neg_samples,dst_node_list = full_dst.unique(),seed=12357)
+neg_sampler = LocalNegativeSampling('triplet',amount= neg_samples,dst_node_list = full_dst.unique(),seed=6773)
trainloader = DistributedDataLoader(graph,eval_train_data,sampler = sampler,
sampler_fn = SAMPLE_TYPE.SAMPLE_FROM_TEMPORAL_EDGES,
@@ -278,7 +279,8 @@ def main():
is_pipeline=True,
use_local_feature = False,
device = torch.device('cuda:{}'.format(local_rank)),
-probability=train_cross_probability
+probability=1,#train_cross_probability,
reversed = (gnn_param['arch'] == 'identity')
)
eval_trainloader = DistributedDataLoader(graph,eval_train_data,sampler = eval_sampler,
@@ -291,24 +293,26 @@ def main():
mode='eval_train',
queue_size = 100,
mailbox = mailbox,
-device = torch.device('cuda:{}'.format(local_rank))
+device = torch.device('cuda:{}'.format(local_rank)),
reversed = (gnn_param['arch']=='identity')
)
testloader = DistributedDataLoader(graph,test_data,sampler = eval_sampler,
sampler_fn = SAMPLE_TYPE.SAMPLE_FROM_TEMPORAL_EDGES,
neg_sampler=neg_sampler,
-batch_size = train_param['batch_size'],
+batch_size = train_param['batch_size']*dist.get_world_size(),
shuffle=False,
drop_last=False,
chunk_size = None,
mode='test',
queue_size = 100,
mailbox = mailbox,
-device = torch.device('cuda:{}'.format(local_rank))
+device = torch.device('cuda:{}'.format(local_rank)),
reversed = (gnn_param['arch']=='identity')
)
valloader = DistributedDataLoader(graph,val_data,sampler = eval_sampler,
sampler_fn = SAMPLE_TYPE.SAMPLE_FROM_TEMPORAL_EDGES,
neg_sampler=neg_sampler,
-batch_size = train_param['batch_size'],
+batch_size = train_param['batch_size']*dist.get_world_size(),
shuffle=False,
drop_last=False,
chunk_size = None,
@@ -316,13 +320,15 @@ def main():
mode='val',
queue_size = 100,
mailbox = mailbox,
-device = torch.device('cuda:{}'.format(local_rank))
+device = torch.device('cuda:{}'.format(local_rank)),
reversed = (gnn_param['arch']=='identity')
)
print('init dataloader')
gnn_dim_node = 0 if graph.nfeat is None else graph.nfeat.shape[1]
gnn_dim_edge = 0 if graph.efeat is None else graph.efeat.shape[1]
print('dim_node {} dim_edge {}\n'.format(gnn_dim_node,gnn_dim_edge))
avg_time = 0
if use_cuda:
model = GeneralModel(gnn_dim_node, gnn_dim_edge, sample_param, memory_param, gnn_param, train_param,graph.ids.shape[0],mailbox).cuda()
@@ -337,6 +343,7 @@ def main():
print(f'The model has {count_parameters(model):,} trainable parameters')
train_stream = torch.cuda.Stream()
def eval(mode='val'):
model.eval()
aps = list()
@@ -356,7 +363,7 @@ def main():
total_loss = 0
signal = torch.tensor([0],dtype = int,device = device)
for roots,mfgs,metadata in loader:
"""
if ctx.memory_group == 0:
pred_pos, pred_neg = model(mfgs,metadata,neg_samples=neg_samples)
#print('check {}\n'.format(model.module.memory_updater.last_updated_nid))
@@ -364,7 +371,28 @@ def main():
y_true = torch.cat([torch.ones(pred_pos.size(0)), torch.zeros(pred_neg.size(0))], dim=0)
aps.append(average_precision_score(y_true, y_pred.detach().numpy()))
aucs_mrrs.append(roc_auc_score(y_true, y_pred))
"""
if mailbox is not None:
if(graph.efeat.device.type != 'cpu'):
edge_feats = graph.get_local_efeat(graph.eids_mapper[roots.eids.to('cpu')]).to('cuda')
#edge_feats = graph.get_dist_efeat(graph.eids_mapper[roots.eids.to('cpu')].to('cuda'),is_sorted = False) #graph.efeat[roots.eids.to('cpu')].to('cuda')
else:
edge_feats = graph.get_local_efeat(graph.eids_mapper[roots.eids.to('cpu')])
src = metadata['src_pos_index']
dst = metadata['dst_pos_index']
ts = roots.ts
update_mail = True
param = (update_mail,src,dst,ts,edge_feats,loader.async_feature)
else:
param = None
pred_pos, pred_neg = model(mfgs,metadata,neg_samples=args.neg_samples,async_param = param)
y_pred = torch.cat([pred_pos, pred_neg], dim=0).sigmoid().cpu()
y_true = torch.cat([torch.ones(pred_pos.size(0)), torch.zeros(pred_neg.size(0))], dim=0)
aps.append(average_precision_score(y_true, y_pred.detach().numpy()))
aucs_mrrs.append(roc_auc_score(y_true, y_pred))
mailbox.update_shared()
mailbox.update_p2p()
"""
if mailbox is not None:
src = metadata['src_pos_index']
dst = metadata['dst_pos_index']
@@ -399,7 +427,9 @@ def main():
if memory_param['historical_fix'] == True:
mailbox.set_memory_all_reduce(index,memory,memory_ts,mail,mail_ts,reduce_Op = 'max', async_op = False,filter=model.module.memory_updater.filter,set_remote=True,mode='historical')
else:
-mailbox.set_memory_all_reduce(index,memory,memory_ts,mail,mail_ts,reduce_Op = 'max', async_op = False,filter=None,set_remote=True,mode='all_reduce')
+mailbox.set_memory_all_reduce(index,memory,memory_ts,mail,mail_ts,reduce_Op = 'max', async_op = False,filter=None,set_remote=True,mode='all_reduce',submit=False)
mailbox.sychronize_shared()
"""
ap = torch.empty([1])
auc_mrr = torch.empty([1])
if(ctx.memory_group==0):
@@ -471,12 +501,14 @@ def main():
remote_edge_access.append(trainloader.remote_edge)
local_comm.append((DistIndex(mfgs[0][0].srcdata['ID']).part == dist.get_rank()).sum().item())
remote_comm.append((DistIndex(mfgs[0][0].srcdata['ID']).part != dist.get_rank()).sum().item())
-local_edge_comm.append((DistIndex(mfgs[0][0].edata['ID']).part == dist.get_rank()).sum().item())
-remote_edge_comm.append((DistIndex(mfgs[0][0].edata['ID']).part != dist.get_rank()).sum().item())
+if 'ID' in mfgs[0][0].edata:
+local_edge_comm.append((DistIndex(mfgs[0][0].edata['ID']).part == dist.get_rank()).sum().item())
+remote_edge_comm.append((DistIndex(mfgs[0][0].edata['ID']).part != dist.get_rank()).sum().item())
+sum_local_edge_comm +=local_edge_comm[b_cnt-1]
+sum_remote_edge_comm +=remote_edge_comm[b_cnt-1]
sum_local_comm +=local_comm[b_cnt-1]
sum_remote_comm +=remote_comm[b_cnt-1]
-sum_local_edge_comm +=local_edge_comm[b_cnt-1]
-sum_remote_edge_comm +=remote_edge_comm[b_cnt-1]
if mailbox is not None:
if(graph.efeat.device.type != 'cpu'):
edge_feats = graph.get_local_efeat(graph.eids_mapper[roots.eids.to('cpu')]).to('cuda')
@@ -487,7 +519,7 @@ def main():
dst = metadata['dst_pos_index']
ts = roots.ts
update_mail = True
-param = (update_mail,src,dst,ts,edge_feats,fetch_async)
+param = (update_mail,src,dst,ts,edge_feats,trainloader.async_feature)
else:
param = None
model.train()
@@ -497,6 +529,7 @@ def main():
if memory_param['historical_fix'] == True:
loss = creterion(pred_pos, torch.ones_like(pred_pos)) + 0.1*inner_prod(model.module.memory_updater.update_memory,model.module.memory_updater.prev_memory)
else:
#loss = creterion(pred_pos, torch.ones_like(pred_pos)) + 0.1*inner_prod(model.module.memory_updater.last_updated_memory,model.module.memory_updater.pre_mem)
loss = creterion(pred_pos, torch.ones_like(pred_pos))
loss += creterion(pred_neg, torch.zeros_like(pred_neg))
total_loss += float(loss)
@@ -574,9 +607,14 @@ def main():
print(' comm local node number {} remote node number {} local edge {} remote edge{}\n'.format(sum_local_comm,sum_remote_comm,sum_local_edge_comm,sum_remote_edge_comm))
print('memory comm {} shared comm {}\n'.format(tot_comm_count,tot_shared_count))
if(e==0):
-torch.save((local_access,remote_access,local_edge_access,remote_edge_access,local_comm,remote_comm,local_edge_comm,remote_edge_comm),'all/{}/comm/comm_{}_{}_{}_{}_{}_{}_{}_{}.pt'.format(args.dataname,args.partition,args.topk,dist.get_world_size(),dist.get_rank(),args.sample_type,args.probability,args.memory_type,args.shared_memory_ssim))
+torch.save((local_access,remote_access,local_edge_access,remote_edge_access,local_comm,remote_comm,local_edge_comm,remote_edge_comm),'all/{}/{}/comm/comm_{}_{}_{}_{}_{}_{}_{}_{}.pt'.format(args.dataname,args.model,args.partition,args.topk,dist.get_world_size(),dist.get_rank(),args.sample_type,args.probability,args.memory_type,args.shared_memory_ssim))
ap = 0
auc = 0
tt.ssim_remote=0
tt.ssim_local=0
tt.weight_count_local=0
tt.weight_count_remote=0
tt.ssim_cnt=0
ap, auc = eval('val')
test_ap,test_auc = eval('test')
test_ap_list.append((test_ap,test_auc))
@@ -593,12 +631,13 @@ def main():
counts = counts.to('cpu')
node_degree[value] = counts
if dist.get_world_size()==1:
-mailbox.mon.draw(node_degree,args.dataname,e)
+mailbox.mon.draw(node_degree,args.dataname,args.model,e)
mailbox.mon.set_zero()
#mailbox.mon.draw(node_degree,args.dataname,e)
#mailbox.mon.set_zero()
loss_list.append(total_loss)
val_list.append(ap)
if early_stop:
dist.barrier()
print("Early stopping at epoch {:d}\n".format(e))
@@ -610,8 +649,12 @@ def main():
print('\ttrain loss:{:.4f} train ap:{:4f} val ap:{:4f} val auc:{:4f} test ap {:4f} test auc{:4f}\n'.format(total_loss,train_ap, ap, auc,test_ap,test_auc))
print('\ttotal time:{:.2f}s prep time:{:.2f}s\n'.format(time.time()-epoch_start_time, time_prep))
torch.save(model.module.state_dict(), get_checkpoint_path(e))
-torch.save(val_list,'all/{}/val_{}_{}_{}_{}_{}_{}_{}_{}.pt'.format(args.dataname,args.partition,args.topk,dist.get_world_size(),dist.get_rank(),args.sample_type,args.probability,args.memory_type,args.shared_memory_ssim))
-torch.save(loss_list,'all/{}/loss_{}_{}_{}_{}_{}_{}_{}_{}.pt'.format(args.dataname,args.partition,args.topk,dist.get_world_size(),dist.get_rank(),args.sample_type,args.probability,args.memory_type,args.shared_memory_ssim))
+if args.model == 'TGN':
+print('weight {} {}\n'.format(tt.weight_count_local,tt.weight_count_remote))
+print('ssim {} {}\n'.format(tt.ssim_local/tt.ssim_cnt,tt.ssim_remote/tt.ssim_cnt))
+torch.save(val_list,'all/{}/{}/val_{}_{}_{}_{}_{}_{}_{}_{}.pt'.format(args.dataname,args.model,args.partition,args.topk,dist.get_world_size(),dist.get_rank(),args.sample_type,args.probability,args.memory_type,args.shared_memory_ssim))
+torch.save(loss_list,'all/{}/{}/loss_{}_{}_{}_{}_{}_{}_{}_{}.pt'.format(args.dataname,args.model,args.partition,args.topk,dist.get_world_size(),dist.get_rank(),args.sample_type,args.probability,args.memory_type,args.shared_memory_ssim))
+torch.save(test_ap_list,'all/{}/{}/test_{}_{}_{}_{}_{}_{}_{}_{}.pt'.format(args.dataname,args.model,args.partition,args.topk,dist.get_world_size(),dist.get_rank(),args.sample_type,args.probability,args.memory_type,args.shared_memory_ssim))
print(avg_time)
if not early_stop:
......
@@ -41,7 +41,7 @@ pinn_memory = {}
class HistoricalCache:
-def __init__(self,cache_index,layer,shape,dtype,device,threshold = 3,time_threshold = None, times_threshold = 5, use_rpc = True, num_threshold = 0):
+def __init__(self,cache_index,layer,shape,dtype,device,threshold = 3,time_threshold = None, times_threshold = 10, use_rpc = True, num_threshold = 0):
#self.cache_index = cache_index
self.layer = layer
print(shape)
@@ -88,12 +88,17 @@ class HistoricalCache:
return torch.sum((x -y)**2,dim = 1)
def historical_check(self,index,new_data,ts):
if self.time_threshold is not None:
-return (self.ssim(new_data,self.local_historical_data[index]) > self.threshold | (ts - self.local_ts[index] > self.time_threshold | self.loss_count[index] > self.times_threshold))
+mask = (self.ssim(new_data,self.local_historical_data[index]) > self.threshold | (ts - self.local_ts[index] > self.time_threshold | self.loss_count[index] > self.times_threshold))
self.loss_count[index][~mask] += 1
self.loss_count[index][mask] = 0
else:
#print('{} {} {} {} \n'.format(index,self.ssim(new_data,self.local_historical_data[index]),new_data,self.local_historical_data[index]))
#print(new_data,self.local_historical_data[index])
#print(self.ssim(new_data,self.local_historical_data[index]) < self.threshold, (self.loss_count[index] > self.times_threshold))
-return (self.ssim(new_data,self.local_historical_data[index]) > self.threshold) | (self.loss_count[index] > self.times_threshold)
+mask = (self.ssim(new_data,self.local_historical_data[index]) > self.threshold) | (self.loss_count[index] > self.times_threshold)
self.loss_count[index][~mask] += 1
self.loss_count[index][mask] = 0
return mask
def read_synchronize(self):
torch.cuda.synchronize(get_stream_set(self.layer))
......
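The reworked check now also maintains a per-node staleness counter instead of only returning the comparison. A self-contained toy sketch of that logic, not part of the commit (the time-threshold branch is omitted and the names are simplified stand-ins for the HistoricalCache fields):

# Hedged sketch: a node's shared ("historical") memory is refreshed when its new
# embedding has drifted far enough from the cached one, or when it has been
# skipped too many times in a row.
import torch

def historical_check(new_data, cached, loss_count, threshold=3.0, times_threshold=10):
    drift = torch.sum((new_data - cached) ** 2, dim=1)   # same distance as HistoricalCache.ssim
    mask = (drift > threshold) | (loss_count > times_threshold)
    loss_count[~mask] += 1   # skipped again: remember how long this row has been stale
    loss_count[mask] = 0     # refreshed: reset the staleness counter
    return mask              # True = push this row to the shared/historical cache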
@@ -2,6 +2,8 @@ from os.path import abspath, join, dirname
import os
import sys
from os.path import abspath, join, dirname
from starrygl.distributed.utils import DistIndex
sys.path.insert(0, join(abspath(dirname(__file__))))
import torch
import dgl
@@ -172,7 +174,7 @@ class MixerMLP(torch.nn.Module):
self.block_padding(b)
#return x
class TransfomerAttentionLayer(torch.nn.Module):
def __init__(self, dim_node_feat, dim_edge_feat, dim_time, num_head, dropout, att_dropout, dim_out, combined=False):
@@ -286,9 +288,22 @@ class TransfomerAttentionLayer(torch.nn.Module):
#att = dgl.ops.e_div_v(b,att_e_sub_max,torch.clamp_min(dgl.ops.copy_e_sum(b,att_e_sub_max),1))
att = dgl.ops.edge_softmax(b, self.att_act(torch.sum(Q*K, dim=2)))
att = self.att_dropout(att)
tt.weight_count_remote+=torch.sum(att[DistIndex(b.srcdata['ID']).part[b.edges()[0]]!=torch.distributed.get_rank()]**2)
tt.weight_count_local+=torch.sum(att[DistIndex(b.srcdata['ID']).part[b.edges()[0]]==torch.distributed.get_rank()]**2)
V = torch.reshape(V*att[:, :, None], (V.shape[0], -1))
V_local = V.clone()
V_remote = V.clone()
V_local[DistIndex(b.srcdata['ID']).part[b.edges()[0]]!=torch.distributed.get_rank()] = 0
V_remote[DistIndex(b.srcdata['ID']).part[b.edges()[0]]==torch.distributed.get_rank()] = 0
b.edata['v'] = V
b.edata['v0'] = V_local
b.edata['v1'] = V_remote
b.update_all(dgl.function.copy_e('v0', 'm0'), dgl.function.sum('m0', 'h0'))
b.update_all(dgl.function.copy_e('v1', 'm1'), dgl.function.sum('m1', 'h1'))
b.update_all(dgl.function.copy_e('v', 'm'), dgl.function.sum('m', 'h'))
tt.ssim_local+=torch.sum(torch.cosine_similarity(b.dstdata['h'],b.dstdata['h0']))
tt.ssim_remote+=torch.sum(torch.cosine_similarity(b.dstdata['h'],b.dstdata['h1']))
tt.ssim_cnt += b.num_dst_nodes()
#print('dst {}'.format(b.dstdata['h']))
#b.srcdata['v'] = torch.cat([torch.zeros((b.num_dst_nodes(), V.shape[1]), device=torch.device('cuda:0')), V], dim=0)
#b.update_all(dgl.function.copy_u('v', 'm'), dgl.function.sum('m', 'h'))
......
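The lines added to the attention layer are instrumentation: they accumulate the squared attention mass placed on local versus remote neighbors (tt.weight_count_local / tt.weight_count_remote) and compare the local-only and remote-only aggregates against the full aggregate. A hedged sketch of what is being measured, not part of the commit and with simplified inputs:

# Hedged sketch: per-layer comparison of the full message aggregate with the
# aggregates built only from locally-partitioned or only from remote neighbors.
import torch

def local_remote_similarity(h_full, h_local, h_remote):
    # h_* : [num_dst_nodes, dim] aggregates (full, local-only, remote-only)
    sim_local = torch.cosine_similarity(h_full, h_local).sum()
    sim_remote = torch.cosine_similarity(h_full, h_remote).sum()
    return sim_local, sim_remote   # the commit divides these by the running dst-node count (tt.ssim_cnt)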
@@ -331,6 +331,7 @@ class TransformerMemoryUpdater(torch.nn.Module):
self.mlp = torch.nn.Linear(dim_out, dim_out)
self.dropout = torch.nn.Dropout(train_param['dropout'])
self.att_dropout = torch.nn.Dropout(train_param['att_dropout'])
def forward(self, b, param = None):
Q = self.w_q(b.srcdata['mem']).reshape((b.num_src_nodes(), self.att_h, -1))
@@ -396,7 +397,7 @@ class AsyncMemeoryUpdater(torch.nn.Module):
self.ceil_updater = updater(memory_param, dim_in, dim_hid, dim_time, train_param)
self.updater = self.transformer_updater
else:
-self.updater = updater(dim_in + dim_time, dim_hid)
+self.ceil_updater = updater(dim_in + dim_time, dim_hid)
self.updater = self.rnn_updater
self.last_updated_memory = None
self.last_updated_ts = None
@@ -419,13 +420,30 @@ class AsyncMemeoryUpdater(torch.nn.Module):
self.update_hunk = self.historical_func
elif self.mode == 'local' or self.mode=='all_local':
self.update_hunk = self.local_func
if self.mode == 'historical':
self.gamma = torch.nn.Parameter(torch.tensor([0.9]),
requires_grad=True)
else:
self.gamma = 1
def forward(self, mfg, param = None):
for b in mfg:
-updated_memory = self.updater(b)
+mail_input = b.srcdata['mem_input']
updated_memory0 = self.updater(b)
mask = DistIndex(b.srcdata['ID']).is_shared
#incr = updated_memory[mask] - b.srcdata['mem'][mask]
time_feat = self.time_enc(b.srcdata['ts'][mask].reshape(-1,1) - b.srcdata['his_ts'][mask].reshape(-1,1))
his_mem = torch.cat((mail_input[mask],time_feat),dim = 1)
upd0 = torch.zeros_like(updated_memory0)
upd0[mask] = self.ceil_updater(his_mem, b.srcdata['his_mem'][mask])
#updated_memory = torch.where(mask.unsqueeze(1),self.gamma*updated_memory0 + (1-self.gamma)*(b.srcdata['his_mem']),updated_memory0)
updated_memory = torch.where(mask.unsqueeze(1),self.gamma*updated_memory0 + (1-self.gamma)*(upd0),updated_memory0)
print(torch.cosine_similarity(updated_memory[mask],b.srcdata['his_mem'][mask]).sum()/torch.sum(mask))
print(self.gamma)
self.pre_mem = b.srcdata['his_mem']
self.last_updated_ts = b.srcdata['ts'].detach().clone()
self.last_updated_memory = updated_memory.detach().clone()
self.last_updated_nid = b.srcdata['ID'].detach().clone()
with torch.no_grad():
if param is not None:
_,src,dst,ts,edge_feats,nxt_fetch_func = param
@@ -434,12 +452,14 @@ class AsyncMemeoryUpdater(torch.nn.Module):
self.last_updated_memory[indx],
self.last_updated_ts[indx],
None)
#print(index.shape[0])
if param[0]:
index, mail, mail_ts = self.mailbox.get_update_mail(
b.srcdata['ID'],src,dst,ts,edge_feats,
self.last_updated_memory,
-None,False,False,
+None,False,False,block=b
)
#print(index.shape[0])
if torch.distributed.get_world_size() == 0:
self.mailbox.mon.add(index,self.mailbox.node_memory.accessor.data[index],memory)
##print(index.shape,memory.shape,memory_ts.shape,mail.shape,mail_ts.shape)
@@ -447,14 +467,14 @@ class AsyncMemeoryUpdater(torch.nn.Module):
self.mailbox.set_mailbox_local(DistIndex(index[local_mask]).loc,mail[local_mask],mail_ts[local_mask],Reduce_Op = 'max')
self.mailbox.set_memory_local(DistIndex(index[local_mask]).loc,memory[local_mask],memory_ts[local_mask], Reduce_Op = 'max')
self.update_hunk(index,memory,memory_ts,mail,mail_ts,nxt_fetch_func)
if self.memory_param['combine_node_feature'] and self.dim_node_feat > 0:
if self.dim_node_feat == self.dim_hid:
-b.srcdata['h'] += memory
+b.srcdata['h'] += updated_memory
else:
-b.srcdata['h'] = memory + self.node_feat_map(b.srcdata['h'])
+b.srcdata['h'] = updated_memory + self.node_feat_map(b.srcdata['h'])
else:
-b.srcdata['h'] = memory
+b.srcdata['h'] = updated_memory
def empty_cache(self):
pass
@@ -86,7 +86,7 @@ class GeneralModel(torch.nn.Module):
# else:
# self.memory_updater = HistoricalMemeoryUpdater(memory_param, 2 * memory_param['dim_out'] + dim_edge, memory_param['dim_out'], memory_param['dim_time'], dim_node,updater=updater,learnable=True,num_nodes=num_nodes)
elif memory_param['memory_update'] == 'transformer':
-self.memory_updater = TransformerMemoryUpdater
+updater = TransformerMemoryUpdater
self.memory_updater = AsyncMemeoryUpdater(memory_param, 2 * memory_param['dim_out'] + dim_edge, memory_param['dim_out'], memory_param['dim_time'], dim_node, updater=updater, mailbox=mailbox, mode = memory_param['mode'],train_param=train_param)
else:
raise NotImplementedError
......
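The core of the "update memory" change is in AsyncMemeoryUpdater.forward: for nodes whose memory is shared across partitions, the freshly computed memory is blended with an update derived from the historical (shared) copy, using a learnable weight gamma (initialised to 0.9). A hedged sketch of that blend, not part of the commit; the names are simplified stand-ins for the real block fields:

# Hedged sketch: the gamma-weighted update applied to shared ("historical") nodes.
import torch

def blend_shared_memory(updated_memory0, his_update, is_shared, gamma):
    # updated_memory0 : memory freshly produced by the RNN/transformer updater
    # his_update      : update computed from the historical (shared) memory copy
    # is_shared       : bool mask, True for nodes replicated on other partitions
    # gamma           : learnable scalar weight (0.9 at initialisation in the commit)
    blended = gamma * updated_memory0 + (1 - gamma) * his_update
    return torch.where(is_shared.unsqueeze(1), blended, updated_memory0)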
@@ -123,7 +123,7 @@ def get_node_all_to_all_route(graph:DistributedGraphStore,mailbox:SharedMailBox,
ind_dict = graph.nfeat.all_to_all_ind2ptr(query_nid_feature,group = group)
memory,memory_ts,mail,mail_ts = mailbox.gather_local_memory(ind_dict['recv_ind'],compute_device=out_device)
memory_ts = memory_ts.reshape(-1,1)
-mail_ts = mail_ts.reshape(-1,1)
+mail_ts = mail_ts.reshape(mail_ts.shape[0],-1)
mail = mail.reshape(mail.shape[0],-1)
data.append(torch.cat((memory,memory_ts,mail,mail_ts),dim = 1))
if ind_dict is not None:
@@ -144,14 +144,14 @@ def get_edge_all_to_all_route(graph:DistributedGraphStore,query_eid_feature,out_
def prepare_input(node_feat, edge_feat, mem_embedding,mfgs,dist_nid=None,dist_eid=None):
for i,mfg in enumerate(mfgs):
for b in mfg:
-e_idx = b.edata['__ID']
idx = b.srcdata['__ID']
-if dist_eid is not None:
+if '__ID' in b.edata:
+e_idx = b.edata['__ID']
b.edata['ID'] = dist_eid[e_idx]
+if edge_feat is not None:
+b.edata['f'] = edge_feat[e_idx]
if dist_nid is not None:
b.srcdata['ID'] = dist_nid[idx]
-if edge_feat is not None:
-b.edata['f'] = edge_feat[e_idx]
if i == 0:
if node_feat is not None:
b.srcdata['h'] = node_feat[idx]
...@@ -336,15 +336,96 @@ def to_block(graph,data, sample_out,device = torch.device('cuda'),unique = True) ...@@ -336,15 +336,96 @@ def to_block(graph,data, sample_out,device = torch.device('cuda'),unique = True)
data,mfgs,metadata = build_block() data,mfgs,metadata = build_block()
return (data,mfgs,metadata),dist_nid,dist_eid return (data,mfgs,metadata),dist_nid,dist_eid
def to_reversed_block(graph,data, sample_out,device = torch.device('cuda'),unique = True,identity=False):
if len(sample_out) > 1:
sample_out,metadata = sample_out
else:
metadata = None
nid_mapper: torch.Tensor = graph.nids_mapper
if identity is False:
assert len(sample_out) == 1
ret = sample_out[0]
eid_len = ret.eid().shape[0]
t0 = time.time()
dst_ts = ret.sample_nodes_ts().to(device)
dst = nid_mapper[ret.sample_nodes()].to(device)
dist_eid = torch.tensor([],dtype=torch.long,device=device)
src_index = ret.src_index().to(device)
else:
src_index = torch.tensor([],dtype=torch.long,device=device)
dst = torch.tensor([],dtype=torch.long,device=device)
dist_eid = torch.tensor([],dtype=torch.long,device=device)
if metadata is None:
root_node = data.nodes.to(graph.nids_mapper.device)
root_len = [root_node.shape[0]]
root_ts = data.ts.to(device)
elif 'seed' in metadata:
root_node = metadata.pop('seed').to(graph.nids_mapper.device)
root_len = root_node.shape[0]
if 'seed_ts' in metadata:
root_ts = metadata.pop('seed_ts').to(device)
for k in metadata:
if isinstance(metadata[k],torch.Tensor):
metadata[k] = metadata[k].to(device)
src_node = root_node
src_ts = root_ts
nid_tensor = torch.cat([root_node],dim = 0)
dist_nid = nid_mapper[nid_tensor].to(device)
CountComm.origin_local = (DistIndex(dist_nid).part == dist.get_rank()).sum().item()
CountComm.origin_remote =(DistIndex(dist_nid).part != dist.get_rank()).sum().item()
dist_nid,nid_inv = dist_nid.unique(return_inverse=True)
"""
Deduplicate nodes that share the same id and timestamp and take the index of the first occurrence.
"""
block_node_list,unq_id = torch.stack((nid_inv.to(torch.float64),src_ts.to(torch.float64))).unique(dim = 1,return_inverse=True)
first_index,_ = torch_scatter.scatter_min(torch.arange(unq_id.shape[0],device=unq_id.device,dtype=unq_id.dtype),unq_id)
first_mask = torch.zeros(unq_id.shape[0],device = unq_id.device,dtype=torch.bool)
first_mask[first_index] = True
first_index = unq_id[first_mask]
first_block_id = torch.empty(first_index.shape[0],device=unq_id.device,dtype=unq_id.dtype)
first_block_id[first_index] = torch.arange(first_index.shape[0],device=first_index.device,dtype=first_index.dtype)
first_block_id = first_block_id[unq_id].contiguous()
block_node_list = block_node_list[:,first_index]
for k in metadata:
if isinstance(metadata[k],torch.Tensor):
metadata[k] = first_block_id[metadata[k]]
t2 = time.time()
def build_block():
mfgs = list()
col_len = 0
row_len = root_len
col = first_block_id[:row_len]
max_row = col.max().item()+1
b = dgl.create_block((col[src_index].to(device),
torch.arange(dst.shape[0],device=device,dtype=torch.long)),num_src_nodes=first_block_id.max().item()+1,
num_dst_nodes=dst.shape[0])
idx = block_node_list[0,b.srcnodes()].to(torch.long)
b.srcdata['__ID'] = idx
b.srcdata['ts'] = block_node_list[1,b.srcnodes()].to(torch.float)
b.dstdata['ID'] = dst
mfgs.append(b)
mfgs = list(map(list, zip(*[iter(mfgs)])))
mfgs.reverse()
return data,mfgs,metadata
data,mfgs,metadata = build_block()
return (data,mfgs,metadata),dist_nid,dist_eid
import concurrent.futures import concurrent.futures
executor = concurrent.futures.ThreadPoolExecutor(max_workers=1) executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
def graph_sample(graph,sampler,sample_fn,data,neg_sampling = None,out_device = torch.device('cuda'),nid_mapper = None,eid_mapper=None): def graph_sample(graph,sampler,sample_fn,data,neg_sampling = None,out_device = torch.device('cuda'),nid_mapper = None,eid_mapper=None,reversed=False):
t_s = time.time() t_s = time.time()
param = {'is_unique':False,'nid_mapper':nid_mapper,'eid_mapper':eid_mapper,'out_device':out_device} param = {'is_unique':False,'nid_mapper':nid_mapper,'eid_mapper':eid_mapper,'out_device':out_device}
out = sample_fn(sampler,data,neg_sampling,**param) out = sample_fn(sampler,data,neg_sampling,**param)
out,dist_nid,dist_eid = to_block(graph,data,out,out_device) if reversed is False:
out,dist_nid,dist_eid = to_block(graph,data,out,out_device,reversed)
else:
out,dist_nid,dist_eid = to_reversed_block(graph,data,out,out_device,reversed)
t_e = time.time() t_e = time.time()
#print(t_e-t_s) #print(t_e-t_s)
return out,dist_nid,dist_eid return out,dist_nid,dist_eid
......
...@@ -12,6 +12,13 @@ class time_count: ...@@ -12,6 +12,13 @@ class time_count:
time_sample_and_build = 0 time_sample_and_build = 0
time_memory_fetch = 0 time_memory_fetch = 0
weight_count_remote = 0
weight_count_local = 0
ssim_remote = 0
ssim_cnt = 0
ssim_local = 0
ssim_cnt = 0
@staticmethod @staticmethod
def _zero(): def _zero():
time_count.time_forward = 0 time_count.time_forward = 0
...@@ -34,8 +41,9 @@ class time_count: ...@@ -34,8 +41,9 @@ class time_count:
def start(): def start():
return time.perf_counter(),0 return time.perf_counter(),0
@staticmethod @staticmethod
def elapsed_event(start_event,end_event): def elapsed_event(start_event):
if start_event.isinstance(torch.cuda.Event): if isinstance(start_event,tuple):
start_event,end_event = start_event
end_event.record() end_event.record()
end_event.synchronize() end_event.synchronize()
return start_event.elapsed_time(end_event) return start_event.elapsed_time(end_event)
...@@ -51,4 +59,6 @@ class time_count: ...@@ -51,4 +59,6 @@ class time_count:
time_count.time_local_update, time_count.time_local_update,
time_count.time_memory_sync, time_count.time_memory_sync,
time_count.time_sample_and_build, time_count.time_sample_and_build,
time_count.time_memory_fetch )) time_count.time_memory_fetch ))
\ No newline at end of file
\ No newline at end of file
...@@ -100,8 +100,10 @@ class DistributedDataLoader: ...@@ -100,8 +100,10 @@ class DistributedDataLoader:
cache_mask = None, cache_mask = None,
use_local_feature = True, use_local_feature = True,
probability = 1, probability = 1,
reversed = False,
**kwargs **kwargs
): ):
self.reversed = reversed
self.use_local_feature = use_local_feature self.use_local_feature = use_local_feature
self.local_embedding = local_embedding self.local_embedding = local_embedding
self.chunk_size = chunk_size self.chunk_size = chunk_size
...@@ -223,7 +225,7 @@ class DistributedDataLoader: ...@@ -223,7 +225,7 @@ class DistributedDataLoader:
if self.mode=='train' and self.probability < 1: if self.mode=='train' and self.probability < 1:
mask = ((DistIndex(self.graph.nids_mapper[next_data.edges[0,:].to('cpu')]).part == dist.get_rank())&(DistIndex(self.graph.nids_mapper[next_data.edges[1,:].to('cpu')]).part == dist.get_rank())) mask = ((DistIndex(self.graph.nids_mapper[next_data.edges[0,:].to('cpu')]).part == dist.get_rank())&(DistIndex(self.graph.nids_mapper[next_data.edges[1,:].to('cpu')]).part == dist.get_rank()))
if self.probability > 0: if self.probability > 0:
mask[~mask] = (torch.rand((~mask)) < self.probability) mask[~mask] = (torch.rand((~mask).sum().item()) < self.probability)
next_data = next_data[mask.to(next_data.device)] next_data = next_data[mask.to(next_data.device)]
self.submitted = self.submitted + 1 self.submitted = self.submitted + 1
return next_data return next_data
...@@ -239,13 +241,14 @@ class DistributedDataLoader: ...@@ -239,13 +241,14 @@ class DistributedDataLoader:
self.device, self.device,
nid_mapper = self.graph.nids_mapper, nid_mapper = self.graph.nids_mapper,
eid_mapper = self.graph.eids_mapper, eid_mapper = self.graph.eids_mapper,
reversed = self.reversed
) )
self.result_queue.append((fut)) self.result_queue.append((fut))
@torch.no_grad() @torch.no_grad()
def async_feature(self): def async_feature(self):
if(self.recv_idxs >= self.expected_idx): if(self.recv_idxs >= self.expected_idx or self.is_pipeline == False):
return return
is_local = (self.is_train & self.use_local_feature) is_local = (self.is_train & self.use_local_feature)
if(is_local): if(is_local):
...@@ -303,7 +306,9 @@ class DistributedDataLoader: ...@@ -303,7 +306,9 @@ class DistributedDataLoader:
data,self.neg_sampler, data,self.neg_sampler,
self.device, self.device,
nid_mapper = self.graph.nids_mapper, nid_mapper = self.graph.nids_mapper,
eid_mapper = self.graph.eids_mapper) eid_mapper = self.graph.eids_mapper,
reversed = self.reversed
)
root,mfgs,metadata = batch_data root,mfgs,metadata = batch_data
t_sample = tt.elapsed_event(t0) t_sample = tt.elapsed_event(t0)
...@@ -312,6 +317,12 @@ class DistributedDataLoader: ...@@ -312,6 +317,12 @@ class DistributedDataLoader:
edge_feat = get_edge_feature_by_dist(self.graph,dist_eid,is_local,out_device=self.device) edge_feat = get_edge_feature_by_dist(self.graph,dist_eid,is_local,out_device=self.device)
node_feat,mem = get_node_feature_by_dist(self.graph,self.mailbox,dist_nid, is_local,out_device=self.device) node_feat,mem = get_node_feature_by_dist(self.graph,self.mailbox,dist_nid, is_local,out_device=self.device)
prepare_input(node_feat,edge_feat,mem,mfgs,dist_nid,dist_eid) prepare_input(node_feat,edge_feat,mem,mfgs,dist_nid,dist_eid)
batch_data[1][0][0].srcdata['his_mem'] = batch_data[1][0][0].srcdata['mem'].clone()
batch_data[1][0][0].srcdata['his_ts'] = batch_data[1][0][0].srcdata['mail_ts'].clone()
mask = DistIndex(batch_data[1][0][0].srcdata['ID']).is_shared
indx = self.mailbox.is_shared_mask[DistIndex(batch_data[1][0][0].srcdata['ID']).loc[mask]]
batch_data[1][0][0].srcdata['his_mem'][mask] = self.mailbox.historical_cache.local_historical_data[indx]
batch_data[1][0][0].srcdata['his_ts'][mask] = self.mailbox.historical_cache.local_ts[indx].reshape(-1,1)
t_fetch = tt.elapsed_event(t1) t_fetch = tt.elapsed_event(t1)
tt.time_memory_fetch += t_fetch tt.time_memory_fetch += t_fetch
#if(self.mailbox is not None and self.mailbox.historical_cache is not None): #if(self.mailbox is not None and self.mailbox.historical_cache is not None):
...@@ -335,10 +346,15 @@ class DistributedDataLoader: ...@@ -335,10 +346,15 @@ class DistributedDataLoader:
data,self.neg_sampler, data,self.neg_sampler,
self.device, self.device,
nid_mapper = self.graph.nids_mapper, nid_mapper = self.graph.nids_mapper,
eid_mapper = self.graph.eids_mapper) eid_mapper = self.graph.eids_mapper,
reversed = self.reversed
)
edge_feat = get_edge_feature_by_dist(self.graph,dist_eid,is_local,out_device=self.device) edge_feat = get_edge_feature_by_dist(self.graph,dist_eid,is_local,out_device=self.device)
node_feat,mem = get_node_feature_by_dist(self.graph,self.mailbox,dist_nid, is_local,out_device=self.device) node_feat,mem = get_node_feature_by_dist(self.graph,self.mailbox,dist_nid, is_local,out_device=self.device)
prepare_input(node_feat,edge_feat,mem,batch_data[1],dist_nid,dist_eid) prepare_input(node_feat,edge_feat,mem,batch_data[1],dist_nid,dist_eid)
batch_data[1][0][0].srcdata['his_mem'] = batch_data[1][0][0].srcdata['mem'].clone()
batch_data[1][0][0].srcdata['his_ts'] = batch_data[1][0][0].srcdata['mail_ts'].clone()
#if(self.mailbox is not None and self.mailbox.historical_cache is not None): #if(self.mailbox is not None and self.mailbox.historical_cache is not None):
# id = batch_data[1][0][0].srcdata['ID'] # id = batch_data[1][0][0].srcdata['ID']
# mask = DistIndex(id).is_shared # mask = DistIndex(id).is_shared
...@@ -363,14 +379,14 @@ class DistributedDataLoader: ...@@ -363,14 +379,14 @@ class DistributedDataLoader:
node_feat0 = node_feat0[0] node_feat0 = node_feat0[0]
node_feat = node_feat0[:,:self.graph.nfeat.shape[1]] node_feat = node_feat0[:,:self.graph.nfeat.shape[1]]
if self.graph.nfeat.shape[1] < node_feat0.shape[1]: if self.graph.nfeat.shape[1] < node_feat0.shape[1]:
mem = self.mailbox.unpack(node_feat0[:,self.graph.nfeat.shape[1]:]) mem = self.mailbox.unpack(node_feat0[:,self.graph.nfeat.shape[1]:],mailbox = True)
else: else:
mem = None mem = None
elif self.mailbox is not None: elif self.mailbox is not None:
node_feat0[1].wait() node_feat0[1].wait()
node_feat0 = node_feat0[0] node_feat0 = node_feat0[0]
node_feat = None node_feat = None
mem = self.mailbox.unpack(node_feat0) mem = self.mailbox.unpack(node_feat0,mailbox = True)
#print(node_feat.shape,edge_feat.shape,mem[0].shape) #print(node_feat.shape,edge_feat.shape,mem[0].shape)
#node_feat[1].wait() #node_feat[1].wait()
#node_feat = node_feat[0] #node_feat = node_feat[0]
...@@ -394,7 +410,11 @@ class DistributedDataLoader: ...@@ -394,7 +410,11 @@ class DistributedDataLoader:
if(self.mailbox is not None and self.mailbox.historical_cache is not None): if(self.mailbox is not None and self.mailbox.historical_cache is not None):
id = batch_data[1][0][0].srcdata['ID'] id = batch_data[1][0][0].srcdata['ID']
mask = DistIndex(id).is_shared mask = DistIndex(id).is_shared
batch_data[1][0][0].srcdata['mem'][mask] = self.mailbox.node_memory.accessor.data[DistIndex(id).loc[mask]] batch_data[1][0][0].srcdata['his_mem'] = batch_data[1][0][0].srcdata['mem'].clone()
batch_data[1][0][0].srcdata['his_ts'] = batch_data[1][0][0].srcdata['mail_ts'].clone()
indx = self.mailbox.is_shared_mask[DistIndex(batch_data[1][0][0].srcdata['ID']).loc[mask]]
batch_data[1][0][0].srcdata['his_mem'][mask] = self.mailbox.historical_cache.local_historical_data[indx]
batch_data[1][0][0].srcdata['his_ts'][mask] = self.mailbox.historical_cache.local_ts[indx].reshape(-1,1)#self.mailbox.node_memory.accessor.data[DistIndex(id).loc[mask]]
#his_mem = torch.clone(batch_data[1][0][0].srcdata['mem']) #his_mem = torch.clone(batch_data[1][0][0].srcdata['mem'])
#his_ts = self.mailbox.historical_cache.local_ts[DistIndex(id[mask]).loc] #his_ts = self.mailbox.historical_cache.local_ts[DistIndex(id[mask]).loc]
#maxer = his_ts>batch_data[1][0][0].srcdata['mem_ts'][mask] #maxer = his_ts>batch_data[1][0][0].srcdata['mem_ts'][mask]
......
import starrygl
from typing import Union
from typing import List
from typing import Optional
import torch
from torch.distributed import rpc
import torch_scatter
from starrygl import distributed
from starrygl.distributed.context import DistributedContext
from starrygl.distributed.utils import DistIndex, DistributedTensor
import torch.distributed as dist
from starrygl.module import historical_cache
from starrygl.sample.graph_core.utils import _get_pin
from starrygl.sample.memory.change import MemoryMoniter
#from starrygl.utils.uvm import cudaMemoryAdvise
class SharedMailBox():
'''
We first define our mailbox, including our definitions of the mailbox and the memory:
.. code-block:: python
from starrygl.sample.memory.shared_mailbox import SharedMailBox
mailbox = SharedMailBox(num_nodes=num_nodes, memory_param=memory_param, dim_edge_feat=dim_edge_feat)
Args:
num_nodes (int): number of nodes
memory_param (dict): the memory parameters in the yaml file; refer to TGL
dim_edge_feat (int): the dim of edge feature
device (torch.device): the device used to store MailBox
uvm (bool): whether to use CUDA unified virtual memory (UVM) to store the mailbox
Examples:
.. code-block:: python
from starrygl.sample.part_utils.partition_tgnn import partition_load
from starrygl.sample.memory.shared_mailbox import SharedMailBox
pdata = partition_load("PATH/{}".format(dataname), algo="metis_for_tgnn")
mailbox = SharedMailBox(pdata.ids.shape[0], memory_param, dim_edge_feat=pdata.edge_attr.shape[1] if pdata.edge_attr is not None else 0)
We then hand the mailbox over to the data loader, as in the example above, so that the relevant memory/mailbox can be loaded directly during training.
During training, `get_update_memory`/`get_update_mail` are called repeatedly to update the corresponding storage, following the idea of TGN; a minimal sketch of this update path is shown below.
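The following sketch illustrates one such update step; the names ``src``, ``dst``, ``ts``, ``edge_feats``, ``upd_memory``, ``upd_memory_ts`` and ``dist_index_mapper`` are assumptions taken from a surrounding training loop and are not defined here:
.. code-block:: python

    # deliver_to='self': both calls reduce over the same unique node index
    nodes = torch.cat([src, dst])
    index, mail, mail_ts = mailbox.get_update_mail(
        dist_index_mapper, src, dst, ts, edge_feats, upd_memory)
    index, mem, mem_ts = mailbox.get_update_memory(
        dist_index_mapper[nodes], upd_memory[nodes], upd_memory_ts[nodes], None)
    mailbox.set_mailbox_all_to_all(index, mem, mem_ts, mail, mail_ts,
                                   reduce_Op='max')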
'''
def __init__(self,
num_nodes,
memory_param,
dim_edge_feat,
shared_nodes_index = None,
device = torch.device('cuda'),
ts_dtye = torch.float32,
uvm = False,
use_pin = False,
cache_route = None,
shared_ssim = 2):
ctx = distributed.context._get_default_dist_context()
self.device = device
self.num_nodes = num_nodes
self.num_parts = dist.get_world_size()
if memory_param['type'] != 'node':
raise NotImplementedError
self.memory_param = memory_param
self.memory_size = memory_param['dim_out']
assert not (device.type =='cpu' and uvm is True),\
'uvm requires the device to be cuda'
memory_device = device
if device.type == 'cuda' and uvm is True:
memory_device = torch.device('cpu')
node_memory = torch.zeros((
self.num_nodes, memory_param['dim_out']),
dtype=torch.float32,device =memory_device)
node_memory_ts = torch.zeros(self.num_nodes,
dtype=ts_dtye,
device = self.device)
mailbox = torch.zeros(self.num_nodes,
memory_param['mailbox_size'],
2 * memory_param['dim_out'] + dim_edge_feat,
device = memory_device, dtype=torch.float32)
mailbox_ts = torch.zeros((self.num_nodes,
memory_param['mailbox_size']),
dtype=ts_dtye,device = self.device)
self.mailbox_shape = len(mailbox[0,:].reshape(-1))
self.node_memory = DistributedTensor(node_memory)
self.node_memory_ts = DistributedTensor(node_memory_ts)
self.mailbox = DistributedTensor(mailbox)
self.mailbox_ts = DistributedTensor(mailbox_ts)
self.next_mail_pos = torch.zeros((self.num_nodes),
dtype=torch.long,
device = self.device)
self.tot_comm_count = 0
self.tot_shared_count = 0
self.shared_nodes_index = None
self.deliver_to = memory_param['deliver_to'] if 'deliver_to' in memory_param else 'self'
if shared_nodes_index is not None:
self.shared_nodes_index = shared_nodes_index.to('cuda:{}'.format(ctx.local_rank))
self.is_shared_mask = torch.zeros(self.num_nodes,dtype=torch.int,device=torch.device('cuda:{}'.format(ctx.local_rank)))-1
self.is_shared_mask[shared_nodes_index] = torch.arange(self.shared_nodes_index.shape[0],dtype=torch.int,
device=torch.device('cuda:{}'.format(ctx.local_rank)))
print(self.shared_nodes_index)
if cache_route is not None:
self.historical_cache = historical_cache.HistoricalCache(self.shared_nodes_index,0,self.node_memory.shape[1],self.node_memory.dtype,self.node_memory.device,threshold=shared_ssim)
self._mem_pin = {}
self._mail_pin = {}
self.use_pin = use_pin
self.last_memory_sync = None
self.last_job = None
self.mon = MemoryMoniter()
def reset(self):
self.node_memory.accessor.data.zero_()
self.node_memory_ts.accessor.data.zero_()
self.mailbox.accessor.data.zero_()
self.mailbox_ts.accessor.data.zero_()
self.next_mail_pos.zero_()
self.historical_cache.empty()
self.last_memory_sync = None
def set_memory_local(self,index,source,source_ts,Reduce_Op = None):
if Reduce_Op == 'max' and self.num_parts > 1:
unq_id,inv = index.unique(return_inverse = True)
max_ts,id = torch_scatter.scatter_max(source_ts,inv,dim=0)
source_ts = max_ts
source = source[id]
index = unq_id
self.node_memory.accessor.data[index] = source
self.node_memory_ts.accessor.data[index] = source_ts.float()
def set_mailbox_local(self,index,source,source_ts,Reduce_Op = None):
if Reduce_Op == 'max' and self.num_parts > 1:
unq_id,inv = index.unique(return_inverse = True)
max_ts,id = torch_scatter.scatter_max(source_ts,inv,dim=0)
source_ts = max_ts
source = source[id]
index = unq_id
self.mailbox_ts.accessor.data[index, self.next_mail_pos[index]] = source_ts
self.mailbox.accessor.data[index, self.next_mail_pos[index]] = source
if self.memory_param['mailbox_size'] > 1:
self.next_mail_pos[index] = torch.remainder(
self.next_mail_pos[index] + 1,
self.memory_param['mailbox_size'])
def get_update_mail(self,dist_indx_mapper,
src,dst,ts,edge_feats,
memory,embedding=None,use_src_emb=False,use_dst_emb=False,
block = None,Reduce_score=None,):
if edge_feats is not None:
edge_feats = edge_feats.to(self.device).to(self.mailbox.dtype)
src = src.to(self.device)
dst = dst.to(self.device)
index = torch.cat([src, dst]).reshape(-1)
index = dist_indx_mapper[index]
mem_src = memory[src]
mem_dst = memory[dst]
if embedding is not None:
emb_src = embedding[src]
emb_dst = embedding[dst]
src_mail = torch.cat([emb_src if use_src_emb else mem_src, emb_dst if use_dst_emb else mem_dst], dim=1)
dst_mail = torch.cat([emb_dst if use_src_emb else mem_dst, emb_src if use_dst_emb else mem_src], dim=1)
if edge_feats is not None:
src_mail = torch.cat([src_mail, edge_feats], dim=1)
dst_mail = torch.cat([dst_mail, edge_feats], dim=1)
mail = torch.cat([src_mail, dst_mail], dim=0)
mail_ts = torch.cat((ts,ts),-1).to(self.device).to(self.mailbox_ts.dtype)
#print(mail_ts)
#print(self.deliver_to)
if self.deliver_to == 'neighbors':
assert block is not None and Reduce_score is None
mail = torch.cat([mail, mail[block.edges()[0].long()]], dim=0)
mail_ts = torch.cat([mail_ts, mail_ts[block.edges()[0].long()]], dim=0)
index = torch.cat([index,block.dstdata['ID'][block.edges()[1].long()]],dim=0)
if Reduce_score is not None:
Reduce_score = torch.cat((Reduce_score,Reduce_score),-1).to(self.device)
if Reduce_score is None:
unq_index,inv = torch.unique(index,return_inverse = True)
max_ts,idx = torch_scatter.scatter_max(mail_ts,inv,0)
mail_ts = max_ts
mail = mail[idx]
index = unq_index
else:
unq_index,inv = torch.unique(index,return_inverse = True)
#print(inv.shape,Reduce_score.shape)
max_score,idx = torch_scatter.scatter_max(Reduce_score,inv,0)
mail_ts = mail_ts[idx]
mail = mail[idx]
index = unq_index
return index,mail,mail_ts
def get_update_memory(self,index,memory,memory_ts,embedding):
unq_index,inv = torch.unique(index,return_inverse = True)
max_ts,idx = torch_scatter.scatter_max(memory_ts,inv,0)
ts = max_ts
index = unq_index
memory = memory[idx]
return index,memory,ts
def pack(self,memory=None,memory_ts=None,mail=None,mail_ts=None,index = None,mode=None):
if memory is not None and mail is not None:
mem = torch.cat((memory,memory_ts.view(-1,1),mail,mail_ts.view(-1,1)),dim = 1)
elif mail is not None:
mem = torch.cat((mail,mail_ts.view(-1,1)),dim = 1)
else:
mem = torch.cat((memory,memory_ts.view(-1,1)),dim = 1)
#if index is None:
# mem = torch.cat((memory,memory_ts.view(-1,1),mail,mail_ts.view(-1,1)),dim = 1)
#else:
# mem = torch.cat((memory,memory_ts.view(-1,1),mail,mail_ts.view(-1,1),index.to(torch.float32).view(-1,1)),dim = 1)
return mem
def unpack(self,mem,mailbox = False):
if mem.shape[1] == self.node_memory.shape[1] + 1 or mem.shape[1] == self.mailbox.shape[2] + 1 :
mail = mem[:,: -1]
mail_ts = mem[:,-1].view(-1)
return mail,mail_ts
elif mailbox is False:
memory = mem[:,:self.node_memory.shape[1]]
memory_ts = mem[:,self.node_memory.shape[1]].view(-1)
mail = mem[:,self.node_memory.shape[1]+1:-1]
mail_ts = mem[:,-1].view(-1)
return memory,memory_ts,mail,mail_ts
else:
memory = mem[:,:self.node_memory.shape[1]]
memory_ts = mem[:,self.node_memory.shape[1]].view(-1)
mail = mem[:,self.node_memory.shape[1]+1:mem.shape[1]-self.mailbox_ts.shape[1]].reshape(mem.shape[0],self.mailbox.shape[1],-1)
mail_ts = mem[:,mem.shape[1]-self.mailbox_ts.shape[1]:]
return memory,memory_ts,mail,mail_ts
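# Packed row layout produced by pack() and consumed by unpack() (a sketch,
# assuming mailbox_size == 1; widths follow the constructor):
#   [ memory (dim_out) | memory_ts (1) | mail (2*dim_out + dim_edge_feat) | mail_ts (1) ]
# unpack() splits a row back by these column offsets; the mailbox=True branch
# additionally reshapes the mail part to (mailbox_size, -1) when the whole
# mailbox is packed.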
def handle_last_async(self,reduce_Op = None):
if self.last_memory_sync is not None:
gather_id_list,handle0,gather_memory,handle1 = self.last_memory_sync
self.last_memory_sync = None
handle0.wait()
handle1.wait()
if isinstance(gather_memory,list):
gather_memory = torch.cat(gather_memory,dim = 0)
if gather_memory.shape[1] > self.node_memory.shape[1] + 1:
gather_memory,gather_memory_ts,gather_mail,gather_mail_ts = self.unpack(gather_memory)
self.set_mailbox_local(DistIndex(gather_id_list).loc,gather_mail,gather_mail_ts,Reduce_Op = reduce_Op)
else:
gather_memory,gather_memory_ts = self.unpack(gather_memory)
self.set_memory_local(DistIndex(gather_id_list).loc,gather_memory,gather_memory_ts, Reduce_Op = reduce_Op)
def sychronize_shared(self):
out=self.historical_cache.synchronize_shared_update()
if out is not None:
shared_index,shared_data,shared_ts,mail,mail_ts = out
index = self.shared_nodes_index[shared_index]
self.node_memory.accessor.data[index] = shared_data
self.node_memory_ts.accessor.data[index] = shared_ts
self.mailbox.accessor.data[index, torch.max(self.next_mail_pos[index]-1,torch.tensor([0],device=mail.device))] = mail
self.mailbox_ts.accessor.data[index, torch.max(self.next_mail_pos[index]-1,torch.tensor([0],device=mail_ts.device))] = mail_ts
def update_shared(self):
ctx = DistributedContext.get_default_context()
if self.last_job is not None:
shared_list,mem,shared_id_list,shared_memory_ind = self.last_job
self.last_job = None
handle0 = dist.all_gather(shared_list,mem,group=ctx.memory_nccl_group,async_op=True)
handle1 = dist.all_gather(shared_id_list,shared_memory_ind,group=ctx.memory_nccl_group,async_op=True)
self.historical_cache.add_shared_to_queue(handle0,handle1,shared_id_list,shared_list)
def update_p2p(self):
if self.last_job is None:
return
index,gather_id_list,mem,gather_memory,input_split,output_split,group,async_op = self.last_job
self.last_job = None
handle0 = torch.distributed.all_to_all_single(
gather_id_list,index,output_split_sizes=output_split,
input_split_sizes=input_split,group = group,async_op=async_op)
#print(input_split,output_split)
handle1 = torch.distributed.all_to_all_single(
gather_memory,mem,
output_split_sizes=output_split,
input_split_sizes=input_split,group = group,async_op=async_op)
self.last_memory_sync = (gather_id_list,handle0,gather_memory,handle1)
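# Deferred synchronization lifecycle (as implemented in this class):
#   1. set_mailbox_all_to_all(..., async_op=True) with submit=True only stores
#      the pending exchange in self.last_job;
#   2. update_p2p() (or update_shared() for the shared-node path) launches the
#      asynchronous collectives and records the handles in self.last_memory_sync
#      (or queues them in the historical cache);
#   3. handle_last_async() waits on the handles and writes the gathered
#      memory/mail back into local storage.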
def set_memory_all_reduce(self,index,memory,memory_ts,mail,mail_ts,reduce_Op = None,async_op = True,set_remote = False,mode=None,filter=None,submit = True):
ctx = DistributedContext.get_default_context()
#print(DistIndex(index).part)
if set_remote is False:
pass
# self.set_mailbox_local(DistIndex(index).loc,mail,mail_ts,Reduce_Op = reduce_Op)
# self.set_memory_local(DistIndex(index).loc,memory,memory_ts, Reduce_Op = reduce_Op)
else:
#print(index,memory,memory_ts)
if async_op == True and self.num_parts > 1:
#local_mask = (DistIndex(index).loc == dist.get_rank())
self.set_mailbox_all_to_all(index,memory,memory_ts,mail,mail_ts,reduce_Op='max',async_op=async_op,submit=submit)
else:
self.set_mailbox_all_to_all(index,memory,memory_ts,mail,mail_ts,reduce_Op='max',async_op = False)
if self.shared_nodes_index is not None and (mode == 'all_reduce' or mode == 'historical'):
shared_memory_ind = self.is_shared_mask[torch.min(DistIndex(index).loc,torch.tensor([self.num_nodes-1],device=torch.device('cuda')))]
mask = ((shared_memory_ind>-1)&(DistIndex(index).part==ctx.memory_group_rank))
shared_memory_ind = shared_memory_ind[mask]
shared_memory = memory[mask]
shared_memory_ts = memory_ts[mask]
shared_mail = mail[mask]
shared_mail_ts = mail_ts[mask]
if mode == 'historical':
#print(shared_memory_ind)
update_index = self.historical_cache.historical_check(shared_memory_ind,shared_memory,shared_memory_ts)
#print(update_index.sum(),shared_memory_ind.shape)
shared_memory_ind = shared_memory_ind[update_index]
shared_memory = shared_memory[update_index]
shared_memory_ts = shared_memory_ts[update_index]
shared_mail = shared_mail[update_index]
shared_mail_ts = shared_mail_ts[update_index]
#print(shared_memory_ind)
#mem = self.pack(memory=shared_memory,memory_ts=shared_memory_ts,mail=shared_mail,mail_ts=shared_mail_ts,index=shared_memory_ind,mode=mode)
#mem = self.pack(memory=shared_mail,memory_ts=shared_mail_ts,index=shared_memory_ind,mode=mode)
self.tot_shared_count += shared_memory_ind.shape[0]
mem = self.pack(memory=shared_memory,memory_ts=shared_memory_ts,mail=shared_mail,mail_ts = shared_mail_ts,index=shared_memory_ind,mode=mode)
broadcast_len = torch.empty([1],device = mem.device,dtype = torch.int)
broadcast_len[0] = shared_memory_ind.shape[0]
shared_len = [torch.empty([1],device = mem.device,dtype = torch.int) for _ in range(ctx.memory_group_size)]
dist.all_gather(shared_len,broadcast_len,group = ctx.memory_nccl_group)
shared_list = [torch.empty([l.item(),mem.shape[1]],device = mem.device,dtype=mem.dtype) for l in shared_len]
shared_id_list = [torch.empty([l.item()],device = shared_memory_ind.device,dtype=shared_memory_ind.dtype) for l in shared_len]
if mode == 'all_reduce':
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start.record()
if async_op == True:
handle0 = dist.all_gather(shared_list,mem,group=ctx.memory_nccl_group,async_op=True)
handle1 = dist.all_gather(shared_id_list,shared_memory_ind,group=ctx.memory_nccl_group,async_op=True)
self.last_memory_sync = shared_id_list,handle0,shared_list,handle1
else:
dist.all_gather(shared_list,mem,group=ctx.memory_nccl_group)
dist.all_gather(shared_id_list,shared_memory_ind,group=ctx.memory_nccl_group)
mem = torch.cat(shared_list,dim = 0)
shared_index = torch.cat(shared_id_list)
#id = shared_index.sort()
#print(mem[id],shared_index[id])
#,shared_memory,shared_memory_ts,
#shared_memory,shared_memory_ts = self.unpack(mem)
shared_memory,shared_memory_ts,shared_mail,shared_mail_ts = self.unpack(mem)
unq_index,inv = torch.unique(shared_index,return_inverse = True)
#print(inv.shape,Reduce_score.shape)
max_ts,idx = torch_scatter.scatter_max(shared_memory_ts,inv,0)
#min_ts,_ = torch_scatter.scatter_min(shared_mail_ts,inv,0)
shared_memory = shared_memory[idx]
shared_memory_ts = shared_memory_ts[idx]
shared_mail_ts = shared_mail_ts[idx]
shared_mail = shared_mail[idx]
#shared_mail_ts = torch_scatter.scatter_mean(shared_mail_ts,inv,0)
#shared_mail = torch_scatter.scatter_mean(shared_mail,inv,0)
shared_index = unq_index
self.set_memory_local(self.shared_nodes_index[shared_index],shared_memory,shared_memory_ts)
self.historical_cache.local_historical_data[shared_index] = shared_memory
self.historical_cache.local_ts[shared_index] = shared_memory_ts
self.set_mailbox_local(self.shared_nodes_index[shared_index],shared_mail,shared_mail_ts)
else:
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start.record()
#self.historical_cache.synchronize_shared_update(filter)
#mem[:,:-1] = mem[:,:-1] + filter.get_incretment(shared_memory_ind)
#shared_list = [torch.empty([l.item(),mem.shape[1]],device = mem.device,dtype=mem.dtype) for l in shared_len]
#handle0 = dist.all_gather(shared_list,mem,group=ctx.memory_nccl_group,async_op=True)
#shared_id_list = [torch.empty([l.item()],device = shared_memory_ind.device,dtype=shared_memory_ind.dtype) for l in shared_len]
#handle1 = dist.all_gather(shared_id_list,shared_memory_ind,group=ctx.memory_nccl_group,async_op=True)
self.last_job = (shared_list,mem,shared_id_list,shared_memory_ind)
if not submit:
self.update_shared()
#self.historical_cache.add_shared_to_queue(handle0,handle1,shared_id_list,shared_list)
"""
shared_memory = self.node_memory.accessor.data[self.shared_nodes_index]
shared_memory_ts = self.node_memory_ts.accessor.data[self.shared_nodes_index]
shared_mail = self.mailbox.accessor.data[self.shared_nodes_index]
shared_mail_ts = self.mailbox_ts.accessor.data[self.shared_nodes_index]
torch.distributed.all_reduce(shared_memory,group = ctx.memory_nccl_group,async_op = async_op,op=dist.ReduceOp.SUM)
torch.distributed.all_reduce(shared_memory_ts,group = ctx.memory_nccl_group,async_op = async_op,op=dist.ReduceOp.SUM)
torch.distributed.all_reduce(shared_mail,group = ctx.memory_nccl_group,async_op = async_op,op=dist.ReduceOp.SUM)
torch.distributed.all_reduce(shared_mail_ts,group = ctx.memory_nccl_group,async_op = async_op,op=dist.ReduceOp.SUM)
self.node_memory.accessor.data[self.shared_nodes_index]=shared_memory/ctx.memory_group_size
self.node_memory_ts.accessor.data[self.shared_nodes_index]=shared_memory_ts/ctx.memory_group_size
self.mailbox.accessor.data[self.shared_nodes_index]=shared_mail/ctx.memory_group_size
self.mailbox_ts.accessor.data[self.shared_nodes_index]=shared_mail_ts/ctx.memory_group_size
"""
def set_mailbox_all_to_all_empty(self,index,memory,
memory_ts,mail,mail_ts,
reduce_Op = None,group = None):
if self.num_parts == 1:
dist_index = DistIndex(index)
part_idx = dist_index.part
index = dist_index.loc
self.set_mailbox_local(index,mail,mail_ts)
self.set_memory_local(index,memory,memory_ts)
else:
gather_len_list = torch.empty([self.num_parts],
dtype = int,
device = self.device)
#indic = torch.searchsorted(index,self.partptr,right=False)
indic = torch.ops.torch_sparse.ind2ptr(DistIndex(index).part, self.num_parts)
scatter_len_list = indic[1:] - indic[0:-1]
torch.distributed.all_to_all_single(gather_len_list,scatter_len_list,group = group)
input_split = scatter_len_list.tolist()
output_split = gather_len_list.tolist()
gather_id_list = torch.empty(
[gather_len_list.sum()],
dtype = torch.long,
device = self.device)
input_split = scatter_len_list.tolist()
output_split = gather_len_list.tolist()
torch.distributed.all_to_all_single(
gather_id_list,index,output_split_sizes=output_split,
input_split_sizes=input_split,group = group)
index = gather_id_list
gather_memory = torch.empty(
[gather_len_list.sum(),memory.shape[1]],
dtype = memory.dtype,device = self.device)
gather_memory_ts = torch.empty(
[gather_len_list.sum()],
dtype = memory_ts.dtype,device = self.device)
gather_mail = torch.empty(
[gather_len_list.sum(),mail.shape[1]],
dtype = mail.dtype,device = self.device)
gather_mail_ts = torch.empty(
[gather_len_list.sum()],
dtype = mail_ts.dtype,device = self.device)
torch.distributed.all_to_all_single(
gather_memory,memory,
output_split_sizes=output_split,
input_split_sizes=input_split,group = group,async_op = True)
torch.distributed.all_to_all_single(
gather_memory_ts,memory_ts,
output_split_sizes=output_split,
input_split_sizes=input_split,group = group)
torch.distributed.all_to_all_single(
gather_mail,mail,
output_split_sizes=output_split,
input_split_sizes=input_split,group = group)
torch.distributed.all_to_all_single(
gather_mail_ts,mail_ts,
output_split_sizes=output_split,
input_split_sizes=input_split,group = group)
def set_mailbox_all_to_all(self,index,memory,
memory_ts,mail,mail_ts,
reduce_Op = None,group = None,async_op=False,submit = True):
#futs: List[torch.futures.Future] = []
if self.num_parts == 1:
dist_index = DistIndex(index)
part_idx = dist_index.part
index = dist_index.loc
self.set_mailbox_local(index,mail,mail_ts)
self.set_memory_local(index,memory,memory_ts)
else:
self.tot_comm_count += (DistIndex(index).part != dist.get_rank()).sum()
gather_len_list = torch.empty([self.num_parts],
dtype = int,
device = self.device)
indic = torch.ops.torch_sparse.ind2ptr(DistIndex(index).part, self.num_parts)
scatter_len_list = indic[1:] - indic[0:-1]
torch.distributed.all_to_all_single(gather_len_list,scatter_len_list,group = group)
input_split = scatter_len_list.tolist()
output_split = gather_len_list.tolist()
mem = self.pack(memory,memory_ts,mail,mail_ts)
gather_memory = torch.empty(
[gather_len_list.sum(),mem.shape[1]],
dtype = memory.dtype,device = self.device)
gather_id_list = torch.empty([gather_len_list.sum()],dtype = torch.long,device = self.device)
input_split = scatter_len_list.tolist()
output_split = gather_len_list.tolist()
if async_op == True:
self.last_job = index,gather_id_list,mem,gather_memory,input_split,output_split,group,async_op
if not submit:
self.update_p2p()
else:
torch.distributed.all_to_all_single(
gather_id_list,index,output_split_sizes=output_split,
input_split_sizes=input_split,group = group,async_op=async_op)
torch.distributed.all_to_all_single(
gather_memory,mem,
output_split_sizes=output_split,
input_split_sizes=input_split,group = group)
if gather_memory.shape[1] > self.node_memory.shape[1] + 1:
gather_memory,gather_memory_ts,gather_mail,gather_mail_ts = self.unpack(gather_memory)
self.set_mailbox_local(DistIndex(gather_id_list).loc,gather_mail,gather_mail_ts,Reduce_Op = reduce_Op)
else:
gather_memory,gather_memory_ts = self.unpack(gather_memory)
self.set_memory_local(DistIndex(gather_id_list).loc,gather_memory,gather_memory_ts, Reduce_Op = reduce_Op)
def gather_memory(
self,
dist_index: Union[torch.Tensor, DistIndex, None] = None,
send_ptr: Optional[List[int]] = None,
recv_ptr: Optional[List[int]] = None,
recv_ind: Optional[List[int]] = None,
group = None,is_async=False,
):
if dist_index is None:
return self.node_memory.all_to_all_get(dist_index,send_ptr,recv_ptr,recv_ind,group=group,is_async=is_async),\
self.node_memory_ts.all_to_all_get(dist_index,send_ptr,recv_ptr,recv_ind,group=group,is_async=is_async),\
self.mailbox.all_to_all_get(dist_index,send_ptr,recv_ptr,recv_ind,group=group,is_async=is_async),\
self.mailbox_ts.all_to_all_get(dist_index,send_ptr,recv_ptr,recv_ind,group=group,is_async=is_async)
else:
ids = self.node_memory.all_to_all_ind2ptr(dist_index)
return self.node_memory.all_to_all_get(**ids,group = group,is_async=is_async),\
self.node_memory_ts.all_to_all_get(**ids,group = group,is_async=is_async),\
self.mailbox.all_to_all_get(**ids,group = group,is_async=is_async),\
self.mailbox_ts.all_to_all_get(**ids,group = group,is_async=is_async)
def gather_local_memory(
self,
dist_index: Union[torch.Tensor, DistIndex, None] = None,
compute_device = torch.device('cuda')
):
local_index = DistIndex(dist_index).loc
rows = local_index.shape[0]
#print(local_index.max(),local_index.min(),self.num_nodes)
if self.node_memory.device.type == 'cpu' and self.use_pin:
mem_pin = self._get_mem_pin(0,rows)
mail_pin = self._get_mail_pin(0,rows)
torch.index_select(self.node_memory.accessor.data,0,local_index,out=mem_pin)
torch.index_select(self.mailbox.accessor.data,0,local_index,out=mail_pin)
return mem_pin.to(compute_device,non_blocking=True),self.node_memory_ts[local_index.to('cpu')].to(compute_device),\
mail_pin.to(compute_device,non_blocking=True),self.mailbox_ts[local_index.to('cpu')].to(compute_device)
return self.node_memory.accessor.data[local_index.to(self.device)].to(compute_device),\
self.node_memory_ts.accessor.data[local_index.to(self.device)].to(compute_device),\
self.mailbox.accessor.data[local_index.to(self.device)].to(compute_device),\
self.mailbox_ts.accessor.data[local_index.to(self.device)].to(compute_device)
def _get_mem_pin(self, layer: int, rows: int) -> torch.Tensor:
return _get_pin(self._mem_pin, layer, rows, self.node_memory.shape[1:])
def _get_mail_pin(self, layer: int, rows: int) -> torch.Tensor:
return _get_pin(self._mail_pin, layer, rows, self.mailbox.shape[1:])
"""
def set_memory_async(self,index,source,source_ts):
dist_index = DistIndex(index)
part_idx = dist_index.part
index = dist_index.loc
futs: List[torch.futures.Future] = []
if self.num_parts == 1:
self.set_memory_local(index,source,source_ts)
for i in range(self.num_parts):
fut = self.ctx.remote_call(
SharedMailBox.set_memory_local,
self.rrefs[i],
index[part_idx == i],
source[part_idx == i],
source_ts[part_idx == i])
futs.append(fut)
return torch.futures.collect_all(futs)
def add_to_mailbox_async(self,index,source,source_ts):
dist_index = DistIndex(index)
part_idx = dist_index.part
index = dist_index.loc
futs: List[torch.futures.Future] = []
if self.num_parts == 1:
self.set_mailbox_local(index,source,source_ts)
else:
for i in range(self.num_parts):
fut = self.ctx.remote_call(
SharedMailBox.set_mailbox_local,
self.rrefs[i],
index[part_idx == i],
source[part_idx == i],
source_ts[part_idx == i])
futs.append(fut)
return torch.futures.collect_all(futs)
def set_mailbox_all_to_all(self,index,memory,
memory_ts,mail,mail_ts,
reduce_Op = None,group = None):
#futs: List[torch.futures.Future] = []
if self.num_parts == 1:
dist_index = DistIndex(index)
part_idx = dist_index.part
index = dist_index.loc
self.set_mailbox_local(index,mail,mail_ts)
self.set_memory_local(index,memory,memory_ts)
else:
gather_len_list = torch.empty([self.num_parts],
dtype = int,
device = self.device)
indic = torch.searchsorted(index,self.partptr,right=False)
scatter_len_list = indic[1:] - indic[0:-1]
torch.distributed.all_to_all_single(gather_len_list,scatter_len_list,group = group)
input_split = scatter_len_list.tolist()
output_split = gather_len_list.tolist()
gather_id_list = torch.empty(
[gather_len_list.sum()],
dtype = torch.long,
device = self.device)
input_split = scatter_len_list.tolist()
output_split = gather_len_list.tolist()
torch.distributed.all_to_all_single(
gather_id_list,index,output_split_sizes=output_split,
input_split_sizes=input_split,group = group)
index = gather_id_list
gather_memory = torch.empty(
[gather_len_list.sum(),memory.shape[1]],
dtype = memory.dtype,device = self.device)
gather_memory_ts = torch.empty(
[gather_len_list.sum()],
dtype = memory_ts.dtype,device = self.device)
gather_mail = torch.empty(
[gather_len_list.sum(),mail.shape[1]],
dtype = mail.dtype,device = self.device)
gather_mail_ts = torch.empty(
[gather_len_list.sum()],
dtype = mail_ts.dtype,device = self.device)
torch.distributed.all_to_all_single(
gather_memory,memory,
output_split_sizes=output_split,
input_split_sizes=input_split,group = group)
torch.distributed.all_to_all_single(
gather_memory_ts,memory_ts,
output_split_sizes=output_split,
input_split_sizes=input_split,group = group)
torch.distributed.all_to_all_single(
gather_mail,mail,
output_split_sizes=output_split,
input_split_sizes=input_split,group = group)
torch.distributed.all_to_all_single(
gather_mail_ts,mail_ts,
output_split_sizes=output_split,
input_split_sizes=input_split,group = group)
self.set_mailbox_local(DistIndex(index).loc,gather_mail,gather_mail_ts,Reduce_Op = reduce_Op)
self.set_memory_local(DistIndex(index).loc,gather_memory,gather_memory_ts, Reduce_Op = reduce_Op)
def set_mailbox_all_to_all(self,index,memory,
memory_ts,mail,mail_ts,
reduce_Op = None,group = None):
#futs: List[torch.futures.Future] = []
if self.num_parts == 1:
dist_index = DistIndex(index)
index = dist_index.loc
self.set_mailbox_local(index,mail,mail_ts)
self.set_memory_local(index,memory,memory_ts)
else:
gather_len_list = torch.empty([self.num_parts],
dtype = int,
device = self.device)
indic = torch.searchsorted(index,self.partptr,right=False)
scatter_len_list = indic[1:] - indic[0:-1]
torch.distributed.all_to_all_single(gather_len_list,scatter_len_list,group = group)
input_split = scatter_len_list.tolist()
output_split = gather_len_list.tolist()
gather_id_list = torch.empty(
[gather_len_list.sum()],
dtype = torch.long,
device = self.device)
input_split = scatter_len_list.tolist()
output_split = gather_len_list.tolist()
torch.distributed.all_to_all_single(
gather_id_list,index,output_split_sizes=output_split,
input_split_sizes=input_split,group = group)
index = gather_id_list
gather_memory = torch.empty(
[gather_len_list.sum(),memory.shape[1]],
dtype = memory.dtype,device = self.device)
gather_memory_ts = torch.empty(
[gather_len_list.sum()],
dtype = memory_ts.dtype,device = self.device)
gather_mail = torch.empty(
[gather_len_list.sum(),mail.shape[1]],
dtype = mail.dtype,device = self.device)
gather_mail_ts = torch.empty(
[gather_len_list.sum()],
dtype = mail_ts.dtype,device = self.device)
torch.distributed.all_to_all_single(
gather_memory,memory,
output_split_sizes=output_split,
input_split_sizes=input_split,group = group)
torch.distributed.all_to_all_single(
gather_memory_ts,memory_ts,
output_split_sizes=output_split,
input_split_sizes=input_split,group = group)
torch.distributed.all_to_all_single(
gather_mail,mail,
output_split_sizes=output_split,
input_split_sizes=input_split,group = group)
torch.distributed.all_to_all_single(
gather_mail_ts,mail_ts,
output_split_sizes=output_split,
input_split_sizes=input_split,group = group)
self.set_mailbox_local(DistIndex(index).loc,gather_mail,gather_mail_ts,Reduce_Op = reduce_Op)
self.set_memory_local(DistIndex(index).loc,gather_memory,gather_memory_ts, Reduce_Op = reduce_Op)
def get_update_mail(self,dist_indx_mapper,
src,dst,ts,edge_feats,
memory,embedding=None,use_src_emb=False,use_dst_emb=False,
remote_src = None, remote_dst = None,Reduce_score=None):
if edge_feats is not None:
edge_feats = edge_feats.to(self.device).to(self.mailbox.dtype)
src = src.to(self.device)
dst = dst.to(self.device)
index = torch.cat([src, dst]).reshape(-1)
index = dist_indx_mapper[index]
mem_src = memory[src]
mem_dst = memory[dst]
if embedding is not None:
emb_src = embedding[src]
emb_dst = embedding[dst]
if remote_src is None:
src_mail = torch.cat([emb_src if use_src_emb else mem_src, emb_dst if use_dst_emb else mem_dst], dim=1)
dst_mail = torch.cat([emb_dst if use_src_emb else mem_dst, emb_src if use_dst_emb else mem_src], dim=1)
if edge_feats is not None:
src_mail = torch.cat([src_mail, edge_feats], dim=1)
dst_mail = torch.cat([dst_mail, edge_feats], dim=1)
mail = torch.cat([src_mail, dst_mail], dim=0)
mail_ts = torch.cat((ts,ts),-1).to(self.device).to(self.mailbox_ts.dtype)
if Reduce_score is not None:
Reduce_score = torch.cat((Reduce_score,Reduce_score),-1).to(self.device)
else:
src_mail = torch.cat([emb_src if use_src_emb else mem_src, remote_dst], dim=1)
dst_mail = torch.cat([emb_dst if use_src_emb else mem_dst, remote_src], dim=1)
if edge_feats is not None:
src_mail = torch.cat([src_mail, edge_feats[:src_mail.shape[0]]], dim=1)
dst_mail = torch.cat([dst_mail, edge_feats[src_mail.shape[0]:]], dim=1)
mail = torch.cat([src_mail, dst_mail], dim=0)
mail_ts = ts.to(self.device).to(self.mailbox_ts.dtype)
#.reshape(-1, src_mail.shape[1])
if Reduce_score is None:
unq_index,inv = torch.unique(index,return_inverse = True)
max_ts,idx = torch_scatter.scatter_max(mail_ts,inv,0)
mail_ts = max_ts
mail = mail[idx]
index = unq_index
else:
unq_index,inv = torch.unique(index,return_inverse = True)
#print(inv.shape,Reduce_score.shape)
max_score,idx = torch_scatter.scatter_max(Reduce_score,inv,0)
mail_ts = mail_ts[idx]
mail = mail[idx]
index = unq_index
return index,mail,mail_ts
def get_update_memory(self,index,memory,memory_ts):
unq_index,inv = torch.unique(index,return_inverse = True)
max_ts,idx = torch_scatter.scatter_max(memory_ts,inv,0)
ts = max_ts
memory = memory[idx]
index = unq_index
return index,memory,ts
def get_memory(self,index,local = False):
if self.num_parts == 1 or local is True:
return self.node_memory.accessor.data[index],\
self.node_memory_ts.accessor.data[index],\
self.mailbox.accessor.data[index],\
self.mailbox_ts.accessor.data[index]
elif self.node_memory.rrefs is None:
return self.gather_memory(dist_index = index)
else:
memory = self.node_memory.index_select(index)
memory_ts = self.node_memory_ts.index_select(index)
mail = self.mailbox.index_select(index)
mail_ts = self.mailbox_ts.index_select(index)
def callback(fs):
memory,memory_ts,mail,mail_ts = fs.value()
memory = memory.value()
memory_ts = memory_ts.value()
mail = mail.value()
mail_ts = mail_ts.value()
#print(memory.shape[0])
return memory,memory_ts,mail,mail_ts
return torch.futures.collect_all([memory,memory_ts,mail,mail_ts]).then(callback)
def gather_memory(
self,
dist_index: Union[torch.Tensor, DistIndex, None] = None,
send_ptr: Optional[List[int]] = None,
recv_ptr: Optional[List[int]] = None,
recv_ind: Optional[List[int]] = None,
group = None
):
if dist_index is None:
return self.node_memory.all_to_all_get(dist_index,send_ptr,recv_ptr,recv_ind,group),\
self.node_memory_ts.all_to_all_get(dist_index,send_ptr,recv_ptr,recv_ind,group),\
self.mailbox.all_to_all_get(dist_index,send_ptr,recv_ptr,recv_ind,group),\
self.mailbox_ts.all_to_all_get(dist_index,send_ptr,recv_ptr,recv_ind,group)
else:
ids = self.node_memory.all_to_all_ind2ptr(dist_index)
return self.node_memory.all_to_all_get(**ids,group = group),\
self.node_memory_ts.all_to_all_get(**ids,group = group),\
self.mailbox.all_to_all_get(**ids,group = group),\
self.mailbox_ts.all_to_all_get(**ids,group = group)
"""
...@@ -17,10 +17,10 @@ class MemoryMoniter: ...@@ -17,10 +17,10 @@ class MemoryMoniter:
self.memorychange.append(self.ssim(pre_memory,now_memory,method = 'F')) self.memorychange.append(self.ssim(pre_memory,now_memory,method = 'F'))
self.memory_ssim.append(self.ssim(pre_memory,now_memory,method = 'cos')) self.memory_ssim.append(self.ssim(pre_memory,now_memory,method = 'cos'))
self.nid_list.append(nid) self.nid_list.append(nid)
def draw(self,degree,data,e): def draw(self,degree,data,model,e):
torch.save(self.nid_list,'all/{}/memorynid_{}.pt'.format(data,e)) torch.save(self.nid_list,'all/{}/{}/memorynid_{}.pt'.format(data,model,e))
torch.save(self.memorychange,'all/{}/memoryF_{}.pt'.format(data,e)) torch.save(self.memorychange,'all/{}/{}/memoryF_{}.pt'.format(data,model,e))
torch.save(self.memory_ssim,'all/{}/memcos_{}.pt'.format(data,e)) torch.save(self.memory_ssim,'all/{}/{}/memcos_{}.pt'.format(data,model,e))
# path = './memory/{}/'.format(data) # path = './memory/{}/'.format(data)
# if not os.path.exists(path): # if not os.path.exists(path):
......
...@@ -98,6 +98,7 @@ class SharedMailBox(): ...@@ -98,6 +98,7 @@ class SharedMailBox():
self.tot_comm_count = 0 self.tot_comm_count = 0
self.tot_shared_count = 0 self.tot_shared_count = 0
self.shared_nodes_index = None self.shared_nodes_index = None
self.deliver_to = memory_param['deliver_to'] if 'deliver_to' in memory_param else 'self'
if shared_nodes_index is not None: if shared_nodes_index is not None:
self.shared_nodes_index = shared_nodes_index.to('cuda:{}'.format(ctx.local_rank)) self.shared_nodes_index = shared_nodes_index.to('cuda:{}'.format(ctx.local_rank))
self.is_shared_mask = torch.zeros(self.num_nodes,dtype=torch.int,device=torch.device('cuda:{}'.format(ctx.local_rank)))-1 self.is_shared_mask = torch.zeros(self.num_nodes,dtype=torch.int,device=torch.device('cuda:{}'.format(ctx.local_rank)))-1
...@@ -152,7 +153,7 @@ class SharedMailBox(): ...@@ -152,7 +153,7 @@ class SharedMailBox():
def get_update_mail(self,dist_indx_mapper, def get_update_mail(self,dist_indx_mapper,
src,dst,ts,edge_feats, src,dst,ts,edge_feats,
memory,embedding=None,use_src_emb=False,use_dst_emb=False, memory,embedding=None,use_src_emb=False,use_dst_emb=False,
deliver_to='self',block = None,Reduce_score=None,): block = None,Reduce_score=None,):
if edge_feats is not None: if edge_feats is not None:
edge_feats = edge_feats.to(self.device).to(self.mailbox.dtype) edge_feats = edge_feats.to(self.device).to(self.mailbox.dtype)
src = src.to(self.device) src = src.to(self.device)
...@@ -172,11 +173,13 @@ class SharedMailBox(): ...@@ -172,11 +173,13 @@ class SharedMailBox():
mail = torch.cat([src_mail, dst_mail], dim=0) mail = torch.cat([src_mail, dst_mail], dim=0)
mail_ts = torch.cat((ts,ts),-1).to(self.device).to(self.mailbox_ts.dtype) mail_ts = torch.cat((ts,ts),-1).to(self.device).to(self.mailbox_ts.dtype)
#print(mail_ts) #print(mail_ts)
if deliver_to == 'neighbor': #print(self.deliver_to)
assert block is None and Reduce_score is not None if self.deliver_to == 'neighbors':
mail = torch.cat([mail, mail[block.edges()[1].long()]], dim=0)
mail_ts = torch.cat([mail_ts, mail_ts[block.edges()[1].long()]], dim=0) assert block is not None and Reduce_score is None
index = torch.cat([index,dist_indx_mapper[block.edges()[0].long()]],dim=0) mail = torch.cat([mail, mail[block.edges()[0].long()]], dim=0)
mail_ts = torch.cat([mail_ts, mail_ts[block.edges()[0].long()]], dim=0)
index = torch.cat([index,block.dstdata['ID'][block.edges()[1].long()]],dim=0)
if Reduce_score is not None: if Reduce_score is not None:
Reduce_score = torch.cat((Reduce_score,Reduce_score),-1).to(self.device) Reduce_score = torch.cat((Reduce_score,Reduce_score),-1).to(self.device)
if Reduce_score is None: if Reduce_score is None:
...@@ -214,18 +217,23 @@ class SharedMailBox(): ...@@ -214,18 +217,23 @@ class SharedMailBox():
#else: #else:
# mem = torch.cat((memory,memory_ts.view(-1,1),mail,mail_ts.view(-1,1),index.to(torch.float32).view(-1,1)),dim = 1) # mem = torch.cat((memory,memory_ts.view(-1,1),mail,mail_ts.view(-1,1),index.to(torch.float32).view(-1,1)),dim = 1)
return mem return mem
def unpack(self,mem): def unpack(self,mem,mailbox = False):
if mem.shape[1] == self.node_memory.shape[1] + 1 or mem.shape[1] == self.mailbox.shape[2] + 1 : if mem.shape[1] == self.node_memory.shape[1] + 1 or mem.shape[1] == self.mailbox.shape[2] + 1 :
mail = mem[:,: -1] mail = mem[:,: -1]
mail_ts = mem[:,-1].view(-1) mail_ts = mem[:,-1].view(-1)
return mail,mail_ts return mail,mail_ts
else: elif mailbox is False:
memory = mem[:,:self.node_memory.shape[1]] memory = mem[:,:self.node_memory.shape[1]]
memory_ts = mem[:,self.node_memory.shape[1]].view(-1) memory_ts = mem[:,self.node_memory.shape[1]].view(-1)
mail = mem[:,self.node_memory.shape[1]+1:-1] mail = mem[:,self.node_memory.shape[1]+1:-1]
mail_ts = mem[:,-1].view(-1) mail_ts = mem[:,-1].view(-1)
return memory,memory_ts,mail,mail_ts return memory,memory_ts,mail,mail_ts
else:
memory = mem[:,:self.node_memory.shape[1]]
memory_ts = mem[:,self.node_memory.shape[1]].view(-1)
mail = mem[:,self.node_memory.shape[1]+1:mem.shape[1]-self.mailbox_ts.shape[1]].reshape(mem.shape[0],self.mailbox.shape[1],-1)
mail_ts = mem[:,mem.shape[1]-self.mailbox_ts.shape[1]:]
return memory,memory_ts,mail,mail_ts
def handle_last_async(self,reduce_Op = None): def handle_last_async(self,reduce_Op = None):
if self.last_memory_sync is not None: if self.last_memory_sync is not None:
...@@ -247,10 +255,11 @@ class SharedMailBox(): ...@@ -247,10 +255,11 @@ class SharedMailBox():
if out is not None: if out is not None:
shared_index,shared_data,shared_ts,mail,mail_ts = out shared_index,shared_data,shared_ts,mail,mail_ts = out
index = self.shared_nodes_index[shared_index] index = self.shared_nodes_index[shared_index]
self.node_memory.accessor.data[index] = shared_data mask = (shared_ts > self.node_memory_ts.accessor.data[index])
self.node_memory_ts.accessor.data[index] = shared_ts self.node_memory.accessor.data[index[mask]] = shared_data[mask]
self.mailbox.accessor.data[index, torch.max(self.next_mail_pos[index]-1,torch.tensor([0],device=mail.device))] = mail self.node_memory_ts.accessor.data[index[mask]] = shared_ts[mask]
self.mailbox_ts.accessor.data[index, torch.max(self.next_mail_pos[index]-1,torch.tensor([0],device=mail_ts.device))] = mail_ts #self.mailbox.accessor.data[index, torch.max(self.next_mail_pos[index]-1,torch.tensor([0],device=mail.device))] = mail
#self.mailbox_ts.accessor.data[index, torch.max(self.next_mail_pos[index]-1,torch.tensor([0],device=mail_ts.device))] = mail_ts
def update_shared(self): def update_shared(self):
ctx = DistributedContext.get_default_context() ctx = DistributedContext.get_default_context()
if self.last_job is not None: if self.last_job is not None:
...@@ -273,6 +282,7 @@ class SharedMailBox(): ...@@ -273,6 +282,7 @@ class SharedMailBox():
output_split_sizes=output_split, output_split_sizes=output_split,
input_split_sizes=input_split,group = group,async_op=async_op) input_split_sizes=input_split,group = group,async_op=async_op)
self.last_memory_sync = (gather_id_list,handle0,gather_memory,handle1) self.last_memory_sync = (gather_id_list,handle0,gather_memory,handle1)
def set_memory_all_reduce(self,index,memory,memory_ts,mail,mail_ts,reduce_Op = None,async_op = True,set_remote = False,mode=None,filter=None,submit = True): def set_memory_all_reduce(self,index,memory,memory_ts,mail,mail_ts,reduce_Op = None,async_op = True,set_remote = False,mode=None,filter=None,submit = True):
ctx = DistributedContext.get_default_context() ctx = DistributedContext.get_default_context()
#print(DistIndex(index).part) #print(DistIndex(index).part)
...@@ -337,6 +347,7 @@ class SharedMailBox(): ...@@ -337,6 +347,7 @@ class SharedMailBox():
#,shared_memory,shared_memory_ts, #,shared_memory,shared_memory_ts,
#shared_memory,shared_memory_ts = self.unpack(mem) #shared_memory,shared_memory_ts = self.unpack(mem)
shared_memory,shared_memory_ts,shared_mail,shared_mail_ts = self.unpack(mem) shared_memory,shared_memory_ts,shared_mail,shared_mail_ts = self.unpack(mem)
#print(shared_memory_ts,shared_mail_ts)
unq_index,inv = torch.unique(shared_index,return_inverse = True) unq_index,inv = torch.unique(shared_index,return_inverse = True)
#print(inv.shape,Reduce_score.shape) #print(inv.shape,Reduce_score.shape)
max_ts,idx = torch_scatter.scatter_max(shared_memory_ts,inv,0) max_ts,idx = torch_scatter.scatter_max(shared_memory_ts,inv,0)
......
...@@ -20,19 +20,23 @@ class LocalNegativeSampling(NegativeSampling): ...@@ -20,19 +20,23 @@ class LocalNegativeSampling(NegativeSampling):
unique: bool = False, unique: bool = False,
src_node_list: torch.Tensor = None, src_node_list: torch.Tensor = None,
dst_node_list: torch.Tensor = None, dst_node_list: torch.Tensor = None,
seed = False local_mask = None,
seed = None
): ):
super(LocalNegativeSampling,self).__init__(mode,amount,unique=unique) super(LocalNegativeSampling,self).__init__(mode,amount,unique=unique)
self.src_node_list = src_node_list.to('cpu') if src_node_list is not None else None self.src_node_list = src_node_list.to('cpu') if src_node_list is not None else None
self.dst_node_list = dst_node_list.to('cpu') if dst_node_list is not None else None self.dst_node_list = dst_node_list.to('cpu') if dst_node_list is not None else None
self.rdm = torch.Generator() self.rdm = torch.Generator()
if seed is True: if seed is not None:
random.seed(seed) random.seed(seed)
seed = random.randint(0,100000) seed = random.randint(0,100000)
print('seed is',seed) print('seed is',seed)
ctx = DistributedContext.get_default_context() ctx = DistributedContext.get_default_context()
self.rdm.manual_seed(seed^ctx.rank) self.rdm.manual_seed(seed^ctx.rank)
self.rdm = torch.Generator() self.rdm = torch.Generator()
self.local_mask = local_mask
if self.local_mask is not None:
self.local_dst = dst_node_list[local_mask]
#self.rdm.manual_seed(42) #self.rdm.manual_seed(42)
#print('dst_nde_list {}\n'.format(dst_node_list)) #print('dst_nde_list {}\n'.format(dst_node_list))
def is_binary(self) -> bool: def is_binary(self) -> bool:
...@@ -53,6 +57,12 @@ class LocalNegativeSampling(NegativeSampling): ...@@ -53,6 +57,12 @@ class LocalNegativeSampling(NegativeSampling):
else: else:
if self.dst_node_list is None: if self.dst_node_list is None:
return torch.randint(num_nodes, (num_samples, ),generator=self.rdm) return torch.randint(num_nodes, (num_samples, ),generator=self.rdm)
elif self.local_mask is not None:
p = torch.rand(size=(num_samples,))
sr = torch.randint(len(self.dst_node_list), (num_samples, ),generator=self.rdm)
sl = torch.randint(len(self.local_dst), (num_samples, ),generator=self.rdm)
return torch.where(p<0.9,sl,sr)
else: else:
return self.dst_node_list[torch.randint(len(self.dst_node_list), (num_samples, ),generator=self.rdm)] s = torch.randint(len(self.dst_node_list), (num_samples, ),generator=self.rdm)
return self.dst_node_list[s]
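In the new local_mask branch above, sl and sr are positions into local_dst and dst_node_list; mirroring the else branch, the intent appears to be to return node IDs drawn 90% from the local destination list and 10% from the global one. A minimal sketch of that mixed sampling with the indexing made explicit (function and parameter names are illustrative, not the class's API):

    import torch

    def sample_mixed_negatives(dst_node_list, local_dst, num_samples, rdm, p_local=0.9):
        # One coin flip per sample: pick a local negative with probability p_local.
        p = torch.rand(num_samples)
        global_pick = dst_node_list[
            torch.randint(len(dst_node_list), (num_samples,), generator=rdm)]
        local_pick = local_dst[
            torch.randint(len(local_dst), (num_samples,), generator=rdm)]
        return torch.where(p < p_local, local_pick, global_pick)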
...@@ -227,29 +227,14 @@ class NeighborSampler(BaseSampler): ...@@ -227,29 +227,14 @@ class NeighborSampler(BaseSampler):
sampled_nodes: the node sampled sampled_nodes: the node sampled
sampled_edge_index_list: the edge sampled sampled_edge_index_list: the edge sampled
""" """
if(ts is None): if self.policy != 'identity':
self.part_unique = True
self.p_sampler.neighbor_sample_from_nodes(nodes.contiguous(), None, self.part_unique)
ret = self.p_sampler.get_ret()
return ret
else:
self.p_sampler.neighbor_sample_from_nodes(nodes.contiguous(), ts.contiguous(), None) self.p_sampler.neighbor_sample_from_nodes(nodes.contiguous(), ts.contiguous(), None)
ret = self.p_sampler.get_ret() ret = self.p_sampler.get_ret()
metadata = {} else:
if is_unique: ret = None
self.p_sampler.sample_unique( metadata = {}
nodes,ts.float(),nid_mapper, #print(nodes.shape[0],ret[0].src_index().max(),ret[0].src_index().min())
eid_mapper,'cpu' if out_device.type == 'cpu' else str(out_device.index)) return ret,metadata
metadata = {
'eid_inv':self.p_sampler.eid_inv,
'first_block_id':self.p_sampler.first_block_id,
'block_node_list':self.p_sampler.block_node_list,
'unq_id':self.p_sampler.unq_id,
'dist_nid':self.p_sampler.dist_nid,
'dist_eid':self.p_sampler.dist_eid,
}
#print(nodes.shape[0],ret[0].src_index().max(),ret[0].src_index().min())
return ret,metadata
def sample_from_edges( def sample_from_edges(
self, self,
...@@ -329,13 +314,7 @@ class NeighborSampler(BaseSampler): ...@@ -329,13 +314,7 @@ class NeighborSampler(BaseSampler):
else: else:
seed, inverse_seed = seed.unique(return_inverse=True) seed, inverse_seed = seed.unique(return_inverse=True)
""" """
metadata = {} out,metadata = self.sample_from_nodes(seed, seed_ts, is_unique=False)
if is_unique:
out,metadata = self.sample_from_nodes(seed, seed_ts, is_unique=is_unique,eid_mapper=eid_mapper,nid_mapper=nid_mapper,out_device=out_device)
first_block_id = self.p_sampler.first_block_id
else:
#print('is unique')
out,metadata = self.sample_from_nodes(seed, seed_ts, is_unique=False)
src_pos_index = torch.arange(0,num_pos,dtype= torch.long,device=out_device) src_pos_index = torch.arange(0,num_pos,dtype= torch.long,device=out_device)
dst_pos_index = torch.arange(num_pos,2*num_pos,dtype= torch.long,device=out_device) dst_pos_index = torch.arange(num_pos,2*num_pos,dtype= torch.long,device=out_device)
if neg_sampling.is_triplet() or neg_sampling.is_tgbtriplet(): if neg_sampling.is_triplet() or neg_sampling.is_tgbtriplet():
...@@ -344,12 +323,6 @@ class NeighborSampler(BaseSampler): ...@@ -344,12 +323,6 @@ class NeighborSampler(BaseSampler):
else: else:
src_neg_index = torch.arange(2*num_pos,3*num_pos,dtype= torch.long,device=out_device) src_neg_index = torch.arange(2*num_pos,3*num_pos,dtype= torch.long,device=out_device)
dst_neg_index=torch.arange(3*num_pos,seed.shape[0],dtype= torch.long,device=out_device) dst_neg_index=torch.arange(3*num_pos,seed.shape[0],dtype= torch.long,device=out_device)
#if is_unique:
# src_pos_index = first_block_id[src_pos_index].contiguous()
# dst_pos_index = first_block_id[dst_pos_index].contiguous()
# dst_neg_index = first_block_id[dst_neg_index].contiguous()
# src_neg_index = first_block_id[src_neg_index].contiguous()
metadata['seed'] = seed metadata['seed'] = seed
metadata['seed_ts'] = seed_ts metadata['seed_ts'] = seed_ts
metadata['src_pos_index']=src_pos_index metadata['src_pos_index']=src_pos_index
......
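The sample_from_edges hunk relies on the seed tensor being laid out as positive sources, positive destinations, negative sources, then negative destinations, so the metadata index ranges are plain arange slices over that layout. A minimal sketch under that assumption (names are illustrative):

    import torch

    def build_seed_layout(pos_src, pos_dst, neg_src, neg_dst):
        num_pos = pos_src.numel()
        # Concatenate in the order the index ranges below assume.
        seed = torch.cat([pos_src, pos_dst, neg_src, neg_dst])
        metadata = {
            'src_pos_index': torch.arange(0, num_pos),
            'dst_pos_index': torch.arange(num_pos, 2 * num_pos),
            'src_neg_index': torch.arange(2 * num_pos, 3 * num_pos),
            'dst_neg_index': torch.arange(3 * num_pos, seed.numel()),
        }
        return seed, metadata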