Commit 99e5b95a by xxx

fix bugs and add APAN model

parent b305c21a
......@@ -3,10 +3,7 @@ sampling:
neighbor:
- 10
strategy: 'recent'
prop_time: False
history: 1
duration: 0
num_thread: 32
no_neg: True
memory:
- type: 'node'
......
sampling:
- no_sample: True
- strategy: 'identity'
history: 1
memory:
- type: 'node'
dim_time: 100
......
......@@ -3,10 +3,7 @@ sampling:
neighbor:
- 10
strategy: 'recent'
prop_time: False
history: 1
duration: 0
num_thread: 32
memory:
- type: 'node'
dim_time: 100
......
bash test_all.sh 13357 > 13357.out
wait
bash test_all.sh 12347 > 12347.out
wait
bash test_all.sh 63377 > 63377.out
......
import matplotlib.pyplot as plt
import numpy as np
import torch
# Read the output files
ssim_values = [0, 0.1, 0.2, 0.3, 0.4, 2] # assumed ssim parameter values
probability_values = [1,0.1,0.05,0.01,0]
data_values = ['WIKI','LASTFM','WikiTalk','DGraphFin'] # datasets whose results are read from file
seed = ['13357','12347','63377','53473','54763']
partition = 'ours'
# Read the data from the files (assumed to be stored as per-seed .out files)
#all/"$data"/"$partitions"-ours_shared-0.01-"$mem"-"$ssim"-"$sample".out
partitions=4
topk=0
mem='all_update'#'historical'
model='TGN'
for sd in seed :
for data in data_values:
ap_list = []
comm_list = []
for p in probability_values:
if data == 'WIKI' or data =='LASTFM':
model = 'TGN'
else:
model = 'TGN_large'
if p == 1:
file = 'all_{}/{}/{}/{}-{}-{}-{}-recent.out'.format(sd,data,model,partitions,partition,topk,mem)
else:
file = 'all_{}/{}/{}/{}-{}-{}-{}-boundery_recent_decay-{}.out'.format(sd,data,model,partitions,partition,topk,mem,p)
prefix = "val ap:"
max_val_ap = 0
test_ap = 0
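# Scan the training log: keep track of the best validation AP seen so far and record the
# test AP reported alongside it (model selection by validation AP).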
with open(file, 'r') as f:
for line in f:
if line.find(prefix)!=-1:
pos = line.find(prefix)+len(prefix)
posr = line.find(' ',pos)
#print(line[pos:posr])
val_ap = float(line[pos:posr])
pos = line.find("test ap ")+len("test ap ")
posr = line.find(' ',pos)
#print(line[pos:posr])
_test_ap = float(line[pos:posr])
if(val_ap>max_val_ap):
max_val_ap = val_ap
test_ap = _test_ap
ap_list.append(test_ap)
print('data {} seed {} ap: {}'.format(data,sd,ap_list))
# prefix = 'best test AP:'
# cnt = 0
# sum = 0
# with open(file, 'r') as file:
# for line in file:
# if line.startswith(prefix):
# ap = float(line.lstrip(prefix).split(' ')[0])
# pos = line.find('remote node number tensor')
# if(pos!=-1):
# posr = line.find(']',pos+2+len('remote node number tensor'),)
# #print(line,line[pos+2+len('remote node number tensor'):posr])
# comm = int(line[pos+2+len('remote node number tensor'):posr])
# #print()
# sum = sum+comm
# cnt = cnt+1
# #print(comm)
# ap_list.append(ap)
# comm_list.append(sum/cnt*4)
# # Plot the bar chart
# print('{} TestAP={}\n'.format(data,ap_list))
# bar_width = 0.4
# #shared comm tensor
# # Set the bar positions
# bars = range(len(probability_values))
# # Plot the bar chart
# plt.bar([b for b in bars], ap_list, width=bar_width)
# # Plot the bar chart
# plt.ylim([0.9,1])
# plt.xticks([b for b in bars], probability_values)
# plt.xlabel('probability')
# plt.ylabel('Test AP')
# plt.title('{}({} partitions)'.format(data,partitions))
# plt.savefig('boundary_AP_{}_{}_{}.png'.format(data,partitions,model))
# plt.clf()
# print(comm_list)
# plt.bar([b for b in bars], comm_list, width=bar_width)
# # Plot the bar chart
# plt.xticks([b for b in bars], probability_values)
# plt.xlabel('probability')
# plt.ylabel('Communication volume')
# plt.title('{}({} partitions)'.format(data,partitions))
# plt.savefig('boundary_comm_{}_{}_{}.png'.format(data,partitions,model))
# plt.clf()
# if partition == 'ours_shared':
# partition0 = 'ours'
# else:
# partition0=partition
# for p in probability_values:
# file = '{}/{}/test_{}_{}_{}_0_boundery_recent_uniform_{}_all_update_2.pt'.format(data,model,partition0,topk,partitions,float(p))
# val_ap = torch.tensor(torch.load(file))[:,0]
# epoch = torch.arange(val_ap.shape[0])
# # Plot the curve
# plt.plot(epoch,val_ap, label='probability={}'.format(p))
# plt.xlabel('Epoch')
# plt.ylabel('Val AP')
# plt.title('{}({} partitions)'.format(data,partitions))
# # plt.grid(True)
# plt.legend()
# plt.savefig('{}_{}_{}_boundary_Convergence_rate.png'.format(data,partitions,model))
# plt.clf()
......@@ -4,15 +4,15 @@ import torch
# Read the output files
ssim_values = [0, 0.1, 0.2, 0.3, 0.4, 2] # assumed ssim parameter values
probability_values = [1,0.1,0.05,0.01,0]
data_values = ['WIKI','LASTFM','WikiTalk','DGraphFin'] # datasets whose results are read from file
seed = ['13357','12347','63377','53473',' 54763']
data_values = ['WIKI','WikiTalk'] # datasets whose results are read from file
seed = ['13357','12347','63377','53473','54763']
partition = 'ours_shared'
# Read the data from the files (assumed to be stored as per-seed .out files)
#all/"$data"/"$partitions"-ours_shared-0.01-"$mem"-"$ssim"-"$sample".out
partitions=4
topk=0.01
mem='all_update'#'historical'
mem='local'#'historical'
model='TGN'
for sd in seed :
for data in data_values:
......
......@@ -6,19 +6,20 @@ addr="192.168.1.107"
partition_params=("ours" )
#"metis" "ldg" "random")
#("ours" "metis" "ldg" "random")
partitions="4"
partitions="8"
node_per="4"
nnodes="1"
node_rank="0"
probability_params=("0.1" "0" "0.05" "0.01")
sample_type_params=("boundery_recent_decay" "recent")
nnodes="2"
node_rank="1"
probability_params=("0.1" "0.05")
sample_type_params=("boundery_recent_decay")
#sample_type_params=("recent" "boundery_recent_decay") #"boundery_recent_uniform")
#memory_type=("all_update" "p2p" "all_reduce" "historical" "local")
memory_type=( "all_update")
memory_type=( "all_update" "historical" "local")
#memory_type=("local" "all_update" "historical" "all_reduce")
shared_memory_ssim=("0.3" "0.7")
#data_param=("WIKI" "REDDIT" "LASTFM" "WikiTalk")
data_param=("WIKI" "LASTFM" "WikiTalk" "DGraphFin")
data_param=("StackOverflow")
#data_param=("WIKI" "REDDIT" "LASTFM" "DGraphFin" "WikiTalk" "StackOverflow")
#data_param=("WIKI" "REDDIT" "LASTFM" "WikiTalk" "StackOverflow")
#data_param=("REDDIT" "WikiTalk")
......@@ -30,9 +31,9 @@ data_param=("WIKI" "LASTFM" "WikiTalk" "DGraphFin")
#seed=(( RANDOM % 1000000 + 1 ))
mkdir -p all_"$seed"
for data in "${data_param[@]}"; do
model="TGN_large"
model="JODIE"
if [ "$data" = "WIKI" ] || [ "$data" = "REDDIT" ] || [ "$data" = "LASTFM" ]; then
model="TGN"
model="JODIE"
fi
#model="APAN"
mkdir all_"$seed"/"$data"
......@@ -57,8 +58,8 @@ for data in "${data_param[@]}"; do
wait
fi
else
torchrun --nnodes "$nnodes" --node_rank "$node_rank" --nproc-per-node "$node_per" --master-addr "$addr" --master-port 9445 train_boundery.py --dataname "$data" --mode "$model" --partition "$partition" --topk 0 --sample_type "$sample" --memory_type "$mem" --seed "$seed" > all_"$seed"/"$data"/"$model"/"$partitions"-"$partition"-0-"$mem"-"$sample".out &
wait
# torchrun --nnodes "$nnodes" --node_rank "$node_rank" --nproc-per-node "$node_per" --master-addr "$addr" --master-port 9445 train_boundery.py --dataname "$data" --mode "$model" --partition "$partition" --topk 0 --sample_type "$sample" --memory_type "$mem" --seed "$seed" > all_"$seed"/"$data"/"$model"/"$partitions"-"$partition"-0-"$mem"-"$sample".out &
# wait
if [ "$partition" = "ours" ] && [ "$mem" != "all_local" ]; then
torchrun --nnodes "$nnodes" --node_rank "$node_rank" --nproc-per-node "$node_per" --master-addr "$addr" --master-port 9445 train_boundery.py --dataname "$data" --mode "$model" --partition "$partition" --topk 0.1 --sample_type "$sample" --memory_type "$mem" --seed "$seed" > all_"$seed"/"$data"/"$model"/"$partitions"-ours_shared-0.01-"$mem"-"$sample".out &
wait
......@@ -80,13 +81,13 @@ for data in "${data_param[@]}"; do
torchrun --nnodes "$nnodes" --node_rank "$node_rank" --nproc-per-node "$node_per" --master-addr "$addr" --master-port 9445 train_boundery.py --dataname "$data" --mode "$model" --partition "$partition" --topk 0.1 --sample_type "$sample" --probability "$pro" --memory_type "$mem" --seed "$seed" > all_"$seed"/"$data"/"$model"/"$partitions"-ours_shared-0.01-"$mem"-"$sample"-"$pro".out&
wait
fi
else
torchrun --nnodes "$nnodes" --node_rank "$node_rank" --nproc-per-node "$node_per" --master-addr "$addr" --master-port 9445 train_boundery.py --dataname "$data" --mode "$model" --partition "$partition" --topk 0 --sample_type "$sample" --probability "$pro" --memory_type "$mem" --seed "$seed" > all_"$seed"/"$data"/"$model"/"$partitions"-"$partition"-0-"$mem"-"$sample"-"$pro".out &
wait
if [ "$partition" = "ours" ] && [ "$mem" != "all_local" ]; then
torchrun --nnodes "$nnodes" --node_rank "$node_rank" --nproc-per-node "$node_per" --master-addr "$addr" --master-port 9445 train_boundery.py --dataname "$data" --mode "$model" --partition "$partition" --topk 0.1 --sample_type "$sample" --probability "$pro" --memory_type "$mem" --seed "$seed" > all_"$seed"/"$data"/"$model"/"$partitions"-ours_shared-0.01-"$mem"-"$sample"-"$pro".out &
wait
fi
#else
# torchrun --nnodes "$nnodes" --node_rank "$node_rank" --nproc-per-node "$node_per" --master-addr "$addr" --master-port 9445 train_boundery.py --dataname "$data" --mode "$model" --partition "$partition" --topk 0 --sample_type "$sample" --probability "$pro" --memory_type "$mem" --seed "$seed" > all_"$seed"/"$data"/"$model"/"$partitions"-"$partition"-0-"$mem"-"$sample"-"$pro".out &
# wait
# if [ "$partition" = "ours" ] && [ "$mem" != "all_local" ]; then
# torchrun --nnodes "$nnodes" --node_rank "$node_rank" --nproc-per-node "$node_per" --master-addr "$addr" --master-port 9445 train_boundery.py --dataname "$data" --mode "$model" --partition "$partition" --topk 0.1 --sample_type "$sample" --probability "$pro" --memory_type "$mem" --seed "$seed" > all_"$seed"/"$data"/"$model"/"$partitions"-ours_shared-0.01-"$mem"-"$sample"-"$pro".out &
# wait
# fi
fi
done
done
......
......@@ -126,7 +126,7 @@ def seed_everything(seed=42):
torch.cuda.manual_seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
seed_everything(args.seed)
total_next_batch = 0
total_forward = 0
total_count_score = 0
......@@ -267,9 +267,15 @@ def main():
if args.local_neg_sample:
print('dst len {} origin len {}'.format(graph.edge_index[1,mask].unique().shape[0],full_dst.unique().shape[0]))
train_neg_sampler = LocalNegativeSampling('triplet',amount = args.neg_samples,dst_node_list = graph.edge_index[1,mask].unique())
else:
#train_neg_sampler = LocalNegativeSampling('triplet',amount = args.neg_samples,dst_node_list = full_dst.unique())
train_neg_sampler = LocalNegativeSampling('triplet',amount = args.neg_samples,dst_node_list = full_dst.unique(),local_mask=(DistIndex(graph.nids_mapper[full_dst.unique()].to('cpu')).part == dist.get_rank()),prob=args.probability)
remote_ratio = train_neg_sampler.local_dst.shape[0] / train_neg_sampler.dst_node_list.shape[0]
#train_ratio_pos = (1 - args.probability) + args.probability * remote_ratio
#train_ratio_neg = args.probability * (1-remote_ratio)
train_ratio_pos = 1.0/(1-args.probability+ args.probability * remote_ratio) if ((args.probability <1) & (args.probability > 0)) else 1
train_ratio_neg = 1.0/(args.probability*remote_ratio) if ((args.probability <1) & (args.probability > 0)) else 1
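# Inverse-probability weights for the biased negative sampler above: train_ratio_pos is later
# applied to negatives whose destination is stored on the local partition, train_ratio_neg to
# the others (see the per-sample `weight` built in the training loop below). With
# args.probability equal to 0 or 1 both weights fall back to 1.
# Illustrative numbers (assumed, not taken from the logs): probability = 0.1 and
# remote_ratio = 0.5 give train_ratio_pos = 1/0.95 ≈ 1.05 and train_ratio_neg = 1/0.05 = 20.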
print(train_neg_sampler.dst_node_list)
neg_sampler = LocalNegativeSampling('triplet',amount= neg_samples,dst_node_list = full_dst.unique(),seed=args.seed)
......@@ -338,10 +344,10 @@ def main():
print('dim_node {} dim_edge {}\n'.format(gnn_dim_node,gnn_dim_edge))
avg_time = 0
if use_cuda:
model = GeneralModel(gnn_dim_node, gnn_dim_edge, sample_param, memory_param, gnn_param, train_param,graph.ids.shape[0],mailbox).cuda()
model = GeneralModel(gnn_dim_node, gnn_dim_edge, sample_param, memory_param, gnn_param, train_param,graph.ids.shape[0],mailbox,train_ratio=(train_ratio_pos,train_ratio_neg)).cuda()
device = torch.device('cuda')
else:
model = GeneralModel(gnn_dim_node, gnn_dim_edge, sample_param, memory_param, gnn_param, train_param,graph.ids.shape[0],mailbox)
model = GeneralModel(gnn_dim_node, gnn_dim_edge, sample_param, memory_param, gnn_param, train_param,graph.ids.shape[0],mailbox,train_ratio=(train_ratio_pos,train_ratio_neg))
device = torch.device('cpu')
model = DDP(model,find_unused_parameters=True)
def count_parameters(model):
......@@ -531,9 +537,12 @@ def main():
model.train()
optimizer.zero_grad()
ones = torch.ones(metadata['dst_neg_index'].shape[0],device = model.device,dtype=torch.float)
weight = torch.where(DistIndex(mfgs[0][0].srcdata['ID'][metadata['dst_neg_index']]).part == torch.distributed.get_rank(),ones*train_ratio_pos,ones*train_ratio_neg).reshape(-1,1)
pred_pos, pred_neg = model(mfgs,metadata,neg_samples=args.neg_samples,async_param = param)
loss = creterion(pred_pos, torch.ones_like(pred_pos))
loss += creterion(pred_neg, torch.zeros_like(pred_neg))
neg_creterion = torch.nn.BCEWithLogitsLoss(weight)
loss += neg_creterion(pred_neg, torch.zeros_like(pred_neg))
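# The positive term keeps the plain BCE criterion; the negative term is re-created each batch
# as BCEWithLogitsLoss(weight) so that every negative sample carries its own weight
# (train_ratio_pos if its destination is local to this rank, train_ratio_neg otherwise).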
total_loss += float(loss.item())
#mailbox.handle_last_async()
#trainloader.async_feature()
......@@ -663,9 +672,9 @@ def main():
pass
# print('weight {} {}\n'.format(tt.weight_count_local,tt.weight_count_remote))
# print('ssim {} {}\n'.format(tt.ssim_local/tt.ssim_cnt,tt.ssim_remote/tt.ssim_cnt))
torch.save(val_list,'all_args.seed/{}/{}/val_{}_{}_{}_{}_{}_{}_{}_{}.pt'.format(args.dataname,args.model,args.partition,args.topk,dist.get_world_size(),dist.get_rank(),args.sample_type,args.probability,args.memory_type,args.shared_memory_ssim))
torch.save(loss_list,'all_args.seed/{}/{}/loss_{}_{}_{}_{}_{}_{}_{}_{}.pt'.format(args.dataname,args.model,args.partition,args.topk,dist.get_world_size(),dist.get_rank(),args.sample_type,args.probability,args.memory_type,args.shared_memory_ssim))
torch.save(test_ap_list,'all_args.seed/{}/{}/test_{}_{}_{}_{}_{}_{}_{}_{}.pt'.format(args.dataname,args.model,args.partition,args.topk,dist.get_world_size(),dist.get_rank(),args.sample_type,args.probability,args.memory_type,args.shared_memory_ssim))
torch.save(val_list,'all_{}/{}/{}/val_{}_{}_{}_{}_{}_{}_{}_{}.pt'.format(args.seed,args.dataname,args.model,args.partition,args.topk,dist.get_world_size(),dist.get_rank(),args.sample_type,args.probability,args.memory_type,args.shared_memory_ssim))
torch.save(loss_list,'all_{}/{}/{}/loss_{}_{}_{}_{}_{}_{}_{}_{}.pt'.format(args.seed,args.dataname,args.model,args.partition,args.topk,dist.get_world_size(),dist.get_rank(),args.sample_type,args.probability,args.memory_type,args.shared_memory_ssim))
torch.save(test_ap_list,'all_{}/{}/{}/test_{}_{}_{}_{}_{}_{}_{}_{}.pt'.format(args.seed,args.dataname,args.model,args.partition,args.topk,dist.get_world_size(),dist.get_rank(),args.sample_type,args.probability,args.memory_type,args.shared_memory_ssim))
print(avg_time)
if not early_stop:
......
......@@ -2,7 +2,7 @@
# TaoBao was run on 4 GPUs
# Define array variables
seed=$1
addr="192.168.1.107"
addr="192.168.1.106"
partition_params=("ours" )
#"metis" "ldg" "random")
#("ours" "metis" "ldg" "random")
......@@ -10,13 +10,13 @@ partitions="8"
node_per="4"
nnodes="2"
node_rank="0"
probability_params=("0.1" "0.01" "0.05")
probability_params=("0.1")
sample_type_params=("boundery_recent_decay")
#sample_type_params=("recent" "boundery_recent_decay") #"boundery_recent_uniform")
#memory_type=("all_update" "p2p" "all_reduce" "historical" "local")
memory_type=( "all_update")
memory_type=( "historical")
#memory_type=("local" "all_update" "historical" "all_reduce")
shared_memory_ssim=("0.3" "0.7")
shared_memory_ssim=("0.3")
#data_param=("WIKI" "REDDIT" "LASTFM" "WikiTalk")
<<<<<<< HEAD
data_param=("GDELT")
......@@ -34,9 +34,9 @@ data_param=("LASTFM")
#seed=(( RANDOM % 1000000 + 1 ))
mkdir -p all_"$seed"
for data in "${data_param[@]}"; do
model="TGN_large"
model="APAN"
if [ "$data" = "WIKI" ] || [ "$data" = "REDDIT" ] || [ "$data" = "LASTFM" ]; then
model="TGN"
model="APAN"
fi
#model="APAN"
mkdir all_"$seed"/"$data"
......
LOCAL RANK 2, RANK2
LOCAL RANK 1, RANK1
LOCAL RANK 3, RANK3
LOCAL RANK 0, RANK0
LOCAL RANK 2, RANK2
LOCAL RANK 3, RANK3
in
in
in
in
local rank is 1 world_size is 4 memory group is 0 memory rank is 1 memory group size is 4
local rank is 0 world_size is 4 memory group is 0 memory rank is 0 memory group size is 4
local rank is 3 world_size is 4 memory group is 0 memory rank is 3 memory group size is 4
[0, 1, 2, 3]
[0, 1, 2, 3][0, 1, 2, 3]
local rank is 2 world_size is 4 memory group is 0 memory rank is 2 memory group size is 4
[0, 1, 2, 3]
use cuda on 0
use cuda on 3
use cuda on 2
use cuda on 1
LOCAL RANK 0, RANK0
use cuda on 0
994790
get_neighbors consume: 6.18508s
Epoch 0:
train loss:12236.7578 train ap:0.986976 val ap:0.934674 val auc:0.946284
total time:630.37s prep time:545.79s
fetch time:0.00s write back time:0.00s
Epoch 1:
train loss:11833.1818 train ap:0.987815 val ap:0.960581 val auc:0.965728
total time:628.44s prep time:542.56s
fetch time:0.00s write back time:0.00s
Epoch 2:
train loss:11622.9559 train ap:0.988244 val ap:0.956752 val auc:0.963083
total time:622.89s prep time:538.77s
fetch time:0.00s write back time:0.00s
Epoch 3:
train loss:11679.1400 train ap:0.988072 val ap:0.929351 val auc:0.943797
total time:681.88s prep time:569.50s
fetch time:0.00s write back time:0.00s
Epoch 4:
train loss:11676.1710 train ap:0.988098 val ap:0.936353 val auc:0.948531
total time:849.98s prep time:741.47s
fetch time:0.00s write back time:0.00s
Epoch 5:
train loss:11745.6001 train ap:0.987897 val ap:0.950828 val auc:0.958958
total time:862.77s prep time:750.90s
fetch time:0.00s write back time:0.00s
Epoch 6:
Early stopping at epoch 6
Loading the best model at epoch 1
0.9248434901237488 0.929413378238678
0.8653780221939087 0.861071765422821
test AP:0.847958 test AUC:0.837159
test_dataset 6647176 avg_time 87.00003329753876
......@@ -228,16 +228,19 @@ def main():
num_layers = sample_param['layer'] if 'layer' in sample_param else 1
fanout = sample_param['neighbor'] if 'neighbor' in sample_param else [10]
policy = sample_param['strategy'] if 'strategy' in sample_param else 'recent'
policy_train = args.sample_type#'boundery_recent_decay'
no_neg = sample_param['no_neg'] if 'no_neg' in sample_param else False
if policy != 'recent':
policy_train = args.sample_type#'boundery_recent_decay'
else:
policy_train = policy
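# no_neg is read from the sampling section of the config (cf. the `no_neg: True` entry added
# in the YAML above) and is passed on to both NeighborSampler instances below; the configured
# strategy is only replaced by args.sample_type when it is not plain 'recent'.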
if memory_param['type'] != 'none':
mailbox = SharedMailBox(graph.ids.shape[0], memory_param, dim_edge_feat = graph.efeat.shape[1] if graph.efeat is not None else 0,
shared_nodes_index=graph.shared_nids_list[ctx.memory_group_rank],device = torch.device('cuda:{}'.format(local_rank)),cache_route = cache_route,shared_ssim=args.shared_memory_ssim)
shared_nodes_index=graph.shared_nids_list[ctx.memory_group_rank],device = torch.device('cuda:{}'.format(local_rank)),shared_ssim=args.shared_memory_ssim)
else:
mailbox = None
sampler = NeighborSampler(num_nodes=graph.num_nodes, num_layers=num_layers, fanout=fanout,graph_data=sample_graph, workers=10,policy = policy_train, graph_name = "train",local_part=dist.get_rank(),edge_part=DistIndex(graph.eids_mapper).part,node_part=DistIndex(graph.nids_mapper).part,probability=args.probability)
eval_sampler = NeighborSampler(num_nodes=graph.num_nodes, num_layers=num_layers, fanout=fanout,graph_data=eval_sample_graph, workers=10,policy = policy_train, graph_name = "eval",local_part=dist.get_rank(),edge_part=DistIndex(graph.eids_mapper).part,node_part=DistIndex(graph.nids_mapper).part,probability=args.probability)
sampler = NeighborSampler(num_nodes=graph.num_nodes, num_layers=num_layers, fanout=fanout,graph_data=sample_graph, workers=10,policy = policy_train, graph_name = "train",local_part=dist.get_rank(),edge_part=DistIndex(graph.eids_mapper).part,node_part=DistIndex(graph.nids_mapper).part,probability=args.probability,no_neg=no_neg)
eval_sampler = NeighborSampler(num_nodes=graph.num_nodes, num_layers=num_layers, fanout=fanout,graph_data=eval_sample_graph, workers=10,policy = policy_train, graph_name = "eval",local_part=dist.get_rank(),edge_part=DistIndex(graph.eids_mapper).part,node_part=DistIndex(graph.nids_mapper).part,probability=args.probability,no_neg=no_neg)
train_data = torch.masked_select(graph.edge_index,train_mask.to(graph.edge_index.device)).reshape(2,-1)
train_ts = torch.masked_select(graph.ts,train_mask.to(graph.edge_index.device))
......@@ -272,8 +275,10 @@ def main():
#train_neg_sampler = LocalNegativeSampling('triplet',amount = args.neg_samples,dst_node_list = full_dst.unique())
train_neg_sampler = LocalNegativeSampling('triplet',amount = args.neg_samples,dst_node_list = full_dst.unique(),local_mask=(DistIndex(graph.nids_mapper[full_dst.unique()].to('cpu')).part == dist.get_rank()),prob=args.probability)
remote_ratio = train_neg_sampler.local_dst.shape[0] / train_neg_sampler.dst_node_list.shape[0]
train_ratio_pos = (1 - args.probability) + args.probability * remote_ratio
train_ratio_neg = args.probability * (1-remote_ratio)
#train_ratio_pos = (1 - args.probability) + args.probability * remote_ratio
#train_ratio_neg = args.probability * (1-remote_ratio)
train_ratio_pos = 1.0/(1-args.probability+ args.probability * remote_ratio) if ((args.probability <1) & (args.probability > 0)) else 1
train_ratio_neg = 1.0/(args.probability*remote_ratio) if ((args.probability <1) & (args.probability > 0)) else 1
print(train_neg_sampler.dst_node_list)
neg_sampler = LocalNegativeSampling('triplet',amount= neg_samples,dst_node_list = full_dst.unique(),seed=args.seed)
......@@ -402,7 +407,8 @@ def main():
aps.append(average_precision_score(y_true, y_pred.detach().numpy()))
aucs_mrrs.append(roc_auc_score(y_true, y_pred))
mailbox.update_shared()
mailbox.update_p2p()
mailbox.update_p2p_mem()
mailbox.update_p2p_mail()
"""
if mailbox is not None:
src = metadata['src_pos_index']
......@@ -536,7 +542,7 @@ def main():
optimizer.zero_grad()
ones = torch.ones(metadata['dst_neg_index'].shape[0],device = model.device,dtype=torch.float)
weight = torch.where(DistIndex(mfgs[0][0].srcdata['ID'][metadata['dst_neg_index']]).part == torch.distributed.get_rank(),ones/train_ratio_pos,ones/train_ratio_neg).reshape(-1,1)
weight = torch.where(DistIndex(mfgs[0][0].srcdata['ID'][metadata['dst_neg_index']]).part == torch.distributed.get_rank(),ones*train_ratio_pos,ones*train_ratio_neg).reshape(-1,1)
pred_pos, pred_neg = model(mfgs,metadata,neg_samples=args.neg_samples,async_param = param)
loss = creterion(pred_pos, torch.ones_like(pred_pos))
neg_creterion = torch.nn.BCEWithLogitsLoss(weight)
......@@ -554,7 +560,8 @@ def main():
#train_aps.append(average_precision_score(y_true, y_pred.detach().numpy()))
#torch.cuda.synchronize()
mailbox.update_shared()
mailbox.update_p2p()
mailbox.update_p2p_mem()
mailbox.update_p2p_mail()
#torch.cuda.empty_cache()
"""
......
......@@ -18,7 +18,7 @@ class CachePushRoute():
def __init__(self,local_num, local_drop_edge, dist_index):
ctx = DistributedContext.get_default_context()
dist_drop = dist_index[local_drop_edge]
id_mask = ~(DistIndex(dist_drop) == ctx.memory_group_rank)
id_mask = not (DistIndex(dist_drop) == ctx.memory_group_rank)
remote_id = local_drop_edge[id_mask]
route = torch.zeros(local_num,dtype=torch.long)
src_mask = DistIndex(dist_drop[0,:]).part == ctx.memory_group_rank
......
......@@ -352,31 +352,60 @@ class TransformerMemoryUpdater(torch.nn.Module):
rst = self.dropout(rst)
rst = torch.nn.functional.relu(rst)
return rst
"""
(self,index,memory,memory_ts,
mail_index,mail,mail_ts,
reduce_Op=None,
async_op=True,
set_p2p=False,
mode=None,
filter=None,
wait_submit=True,
spread_mail=True,
update_cross_mm = False)
"""
class AsyncMemeoryUpdater(torch.nn.Module):
def all_update_func(self,index,memory,memory_ts,mail,mail_ts,nxt_fetch_func):
self.mailbox.set_memory_all_reduce(index,memory,memory_ts,mail,mail_ts,reduce_Op = 'max', async_op = False,filter=None,mode='all_reduce',set_remote=True)
if nxt_fetch_func is not None:
nxt_fetch_func()
def p2p_func(self,index,memory,memory_ts,mail,mail_ts,nxt_fetch_func):
self.mailbox.handle_last_async()
submit_to_queue = False
if nxt_fetch_func is not None:
nxt_fetch_func()
submit_to_queue = True
self.mailbox.set_memory_all_reduce(index,memory,memory_ts,mail,mail_ts,reduce_Op = 'max', async_op = True,filter=None,mode=None,set_remote=True,submit = submit_to_queue)
def all_reduce_func(self,index,memory,memory_ts,mail,mail_ts,nxt_fetch_func):
self.mailbox.set_memory_all_reduce(index,memory,memory_ts,mail,mail_ts,reduce_Op = 'max', async_op = False,filter=None,mode='all_reduce',set_remote=False)
def all_update_func(self,index,memory,memory_ts,mail_index,mail,mail_ts,nxt_fetch_func,spread_mail=False):
self.mailbox.set_memory_all_reduce(
index,memory,memory_ts,
mail_index,mail,mail_ts,
reduce_Op='max',async_op=False,
mode='all_reduce',
wait_submit=False,spread_mail=spread_mail,
update_cross_mm=True,
)
#self.mailbox.set_memory_all_reduce(index,memory,memory_ts,mail,mail_ts,reduce_Op = 'max', async_op = False,filter=None,mode='all_reduce',set_remote=True)
if nxt_fetch_func is not None:
nxt_fetch_func()
def historical_func(self,index,memory,memory_ts,mail,mail_ts,nxt_fetch_func):
# def p2p_func(self,index,memory,memory_ts,mail,mail_ts,nxt_fetch_func,mail_index = None):
# self.mailbox.handle_last_async()
# submit_to_queue = False
# if nxt_fetch_func is not None:
# nxt_fetch_func()
# submit_to_queue = True
# self.mailbox.set_memory_all_reduce(index,memory,memory_ts,mail,mail_ts,reduce_Op = 'max', async_op = True,filter=None,mode=None,set_remote=True,submit = submit_to_queue)
# def all_reduce_func(self,index,memory,memory_ts,mail,mail_ts,nxt_fetch_func,mail_index = None):
# self.mailbox.set_memory_all_reduce(index,memory,memory_ts,mail,mail_ts,reduce_Op = 'max', async_op = False,filter=None,mode='all_reduce',set_remote=False)
# if nxt_fetch_func is not None:
# nxt_fetch_func()
def historical_func(self,index,memory,memory_ts,mail_index,mail,mail_ts,nxt_fetch_func,spread_mail=False):
self.mailbox.sychronize_shared()
self.mailbox.handle_last_async()
submit_to_queue = False
if nxt_fetch_func is not None:
nxt_fetch_func()
submit_to_queue = True
self.mailbox.set_memory_all_reduce(index,memory,memory_ts,mail,mail_ts,reduce_Op = 'max', async_op = False,filter=None,mode='historical',set_remote=False,submit = submit_to_queue)
def local_func(self,index,memory,memory_ts,mail,mail_ts,nxt_fetch_func):
self.mailbox.set_memory_all_reduce(
index,memory,memory_ts,
mail_index,mail,mail_ts,
reduce_Op='max',async_op=True,
mode='historical',
wait_submit=submit_to_queue,spread_mail=spread_mail,
update_cross_mm=False,
)
def local_func(self,index,memory,memory_ts,mail_index,mail,mail_ts,nxt_fetch_func,spread_mail=False):
if nxt_fetch_func is not None:
nxt_fetch_func()
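# Update-policy summary (as read from the functions above):
# - all_update_func: synchronous 'all_reduce' propagation including cross-partition memory
#   (update_cross_mm=True), then invokes nxt_fetch_func.
# - historical_func: synchronizes the shared historical cache and any pending async
#   communication, then queues an asynchronous 'historical' update without cross-partition
#   memory (update_cross_mm=False).
# - local_func: no remote propagation; it only invokes nxt_fetch_func.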
def transformer_updater(self,b):
......@@ -473,7 +502,7 @@ class AsyncMemeoryUpdater(torch.nn.Module):
None)
#print(index.shape[0])
if param[0]:
index, mail, mail_ts = self.mailbox.get_update_mail(
index0, mail, mail_ts = self.mailbox.get_update_mail(
b.srcdata['ID'],src,dst,ts,edge_feats,
self.last_updated_memory,
None,False,False,block=b
......@@ -483,9 +512,12 @@ class AsyncMemeoryUpdater(torch.nn.Module):
self.mailbox.mon.add(index,self.mailbox.node_memory.accessor.data[index],memory)
##print(index.shape,memory.shape,memory_ts.shape,mail.shape,mail_ts.shape)
local_mask = (DistIndex(index).part==torch.distributed.get_rank())
self.mailbox.set_mailbox_local(DistIndex(index[local_mask]).loc,mail[local_mask],mail_ts[local_mask],Reduce_Op = 'max')
self.mailbox.set_memory_local(DistIndex(index[local_mask]).loc,memory[local_mask],memory_ts[local_mask], Reduce_Op = 'max')
self.update_hunk(index,memory,memory_ts,mail,mail_ts,nxt_fetch_func)
local_mask_mail = (DistIndex(index0).part==torch.distributed.get_rank())
#self.mailbox.set_mailbox_local(DistIndex(index0[local_mask_mail]).loc,mail[local_mask_mail],mail_ts[local_mask_mail],Reduce_Op = 'max')
#self.mailbox.set_memory_local(DistIndex(index[local_mask]).loc,memory[local_mask],memory_ts[local_mask], Reduce_Op = 'max')
is_deliver=(self.mailbox.deliver_to == 'neighbors')
self.update_hunk(index,memory,memory_ts,index0,mail,mail_ts,nxt_fetch_func,spread_mail= is_deliver)
if self.memory_param['combine_node_feature'] and self.dim_node_feat > 0:
if self.dim_node_feat == self.dim_hid:
......
......@@ -58,7 +58,7 @@ class SharedMailBox():
ts_dtye = torch.float32,
uvm = False,
use_pin = False,
cache_route = None,
start_historical = False,
shared_ssim = 2):
ctx = distributed.context._get_default_dist_context()
self.device = device
......@@ -106,13 +106,16 @@ class SharedMailBox():
device=torch.device('cuda:{}'.format(ctx.local_rank)))
print(self.shared_nodes_index)
if cache_route is not None:
if start_historical is not None:
self.historical_cache = historical_cache.HistoricalCache(self.shared_nodes_index,0,self.node_memory.shape[1],self.node_memory.dtype,self.node_memory.device,threshold=shared_ssim)
self._mem_pin = {}
self._mail_pin = {}
self.use_pin = use_pin
self.last_memory_sync = None
self.last_job = None
self.last_mail_sync = None
self.next_wait_memory_job=None
self.next_wait_mail_job=None
self.next_wait_gather_memory_job=None
self.mon = MemoryMoniter()
......@@ -124,6 +127,7 @@ class SharedMailBox():
self.next_mail_pos.zero_()
self.historical_cache.empty()
self.last_memory_sync = None
self.last_mail_sync = None
def set_memory_local(self,index,source,source_ts,Reduce_Op = None):
......@@ -212,11 +216,8 @@ class SharedMailBox():
mem = torch.cat((mail,mail_ts.view(-1,1)),dim = 1)
else:
mem = torch.cat((memory,memory_ts.view(-1,1)),dim = 1)
#if index is None:
# mem = torch.cat((memory,memory_ts.view(-1,1),mail,mail_ts.view(-1,1)),dim = 1)
#else:
# mem = torch.cat((memory,memory_ts.view(-1,1),mail,mail_ts.view(-1,1),index.to(torch.float32).view(-1,1)),dim = 1)
return mem
def unpack(self,mem,mailbox = False):
if mem.shape[1] == self.node_memory.shape[1] + 1 or mem.shape[1] == self.mailbox.shape[2] + 1 :
mail = mem[:,: -1]
......@@ -234,8 +235,10 @@ class SharedMailBox():
mail = mem[:,self.node_memory.shape[1]+1:mem.shape[1]-self.mailbox_ts.shape[1]].reshape(mem.shape[0],self.mailbox.shape[1],-1)
mail_ts = mem[:,mem.shape[1]-self.mailbox_ts.shape[1]:]
return memory,memory_ts,mail,mail_ts
def handle_last_async(self,reduce_Op = None):
"""
sychronization last async all-to-all M&M communication task
"""
def handle_last_memory(self,reduce_Op=None,):
if self.last_memory_sync is not None:
gather_id_list,handle0,gather_memory,handle1 = self.last_memory_sync
self.last_memory_sync = None
......@@ -249,56 +252,167 @@ class SharedMailBox():
else:
gather_memory,gather_memory_ts = self.unpack(gather_memory)
self.set_memory_local(DistIndex(gather_id_list).loc,gather_memory,gather_memory_ts, Reduce_Op = reduce_Op)
def handle_last_mail(self,reduce_Op=None,):
if self.last_mail_sync is not None:
gather_id_list,handle0,gather_memory,handle1 = self.last_mail_sync
self.last_mail_sync = None
handle0.wait()
handle1.wait()
if isinstance(gather_memory,list):
gather_memory = torch.cat(gather_memory,dim = 0)
gather_memory,gather_memory_ts = self.unpack(gather_memory)
self.set_mailbox_local(DistIndex(gather_id_list).loc,gather_memory,gather_memory_ts, Reduce_Op = reduce_Op)
def handle_last_async(self,reduce_Op = None):
self.handle_last_memory(reduce_Op)
self.handle_last_mail(reduce_Op)
"""
sychronization last async all-gather memory communication task
"""
def sychronize_shared(self):
out=self.historical_cache.synchronize_shared_update()
if out is not None:
shared_index,shared_data,shared_ts,mail,mail_ts = out
shared_index,shared_data,shared_ts = out
index = self.shared_nodes_index[shared_index]
self.node_memory.accessor.data[index] = shared_data
self.node_memory_ts.accessor.data[index] = shared_ts
self.mailbox.accessor.data[index, torch.max(self.next_mail_pos[index]-1,torch.tensor([0],device=mail.device))] = mail
self.mailbox_ts.accessor.data[index, torch.max(self.next_mail_pos[index]-1,torch.tensor([0],device=mail_ts.device))] = mail_ts
mask= (shared_ts > self.node_memory_ts.accessor.data[index])
self.node_memory.accessor.data[index][mask] = shared_data[mask]
self.node_memory_ts.accessor.data[index][mask] = shared_ts[mask]
def update_shared(self):
ctx = DistributedContext.get_default_context()
if self.last_job is not None:
shared_list,mem,shared_id_list,shared_memory_ind = self.last_job
self.last_job = None
if self.next_wait_gather_memory_job is not None:
shared_list,mem,shared_id_list,shared_memory_ind = self.next_wait_gather_memory_job
self.next_wait_gather_memory_job = None
handle0 = dist.all_gather(shared_list,mem,group=ctx.memory_nccl_group,async_op=True)
handle1 = dist.all_gather(shared_id_list,shared_memory_ind,group=ctx.memory_nccl_group,async_op=True)
self.historical_cache.add_shared_to_queue(handle0,handle1,shared_id_list,shared_list)
def update_p2p(self):
if self.last_job is None:
def update_p2p_mem(self):
if self.next_wait_memory_job is None:
return
index,gather_id_list,mem,gather_memory,input_split,output_split,group,async_op = self.last_job
self.last_job = None
index,gather_id_list,mem,gather_memory,input_split,output_split,group,async_op = self.next_wait_memory_job
self.next_wait_memory_job = None
handle0 = torch.distributed.all_to_all_single(
gather_id_list,index,output_split_sizes=output_split,
input_split_sizes=input_split,group = group,async_op=async_op)
#print(input_split,output_split)
handle1 = torch.distributed.all_to_all_single(
gather_memory,mem,
output_split_sizes=output_split,
input_split_sizes=input_split,group = group,async_op=async_op)
self.last_memory_sync = (gather_id_list,handle0,gather_memory,handle1)
def set_memory_all_reduce(self,index,memory,memory_ts,mail,mail_ts,reduce_Op = None,async_op = True,set_remote = False,mode=None,filter=None,submit = True):
def update_p2p_mail(self):
if self.next_wait_mail_job is None:
return
index,gather_id_list,mem,gather_memory,input_split,output_split,group,async_op = self.next_wait_mail_job
self.next_wait_mail_job = None
handle0 = torch.distributed.all_to_all_single(
gather_id_list,index,output_split_sizes=output_split,
input_split_sizes=input_split,group = group,async_op=async_op)
handle1 = torch.distributed.all_to_all_single(
gather_memory,mem,
output_split_sizes=output_split,
input_split_sizes=input_split,group = group,async_op=async_op)
self.last_mail_sync = (gather_id_list,handle0,gather_memory,handle1)
"""
submit: take the task request into queue wait for start
"""
def build_all_to_all_route(self,index,mm,mm_ts,is_no_redundant):
ctx = DistributedContext.get_default_context()
#print(DistIndex(index).part)
if set_remote is False:
gather_len_list = torch.empty([self.num_parts],dtype=int,device=self.device)
ind = torch.ops.torch_sparse.ind2ptr(DistIndex(index).part,self.num_parts)
scatter_len_list = ind[1:] - ind[0:-1]
torch.distributed.all_to_all_single(gather_len_list,scatter_len_list,group = ctx.memory_nccl_group)
input_split = scatter_len_list.tolist()
output_split = gather_len_list.tolist()
if is_no_redundant:
gather_ts = torch.empty(
[gather_len_list.sum()],
dtype = mm_ts.dtype,device = self.device
)
gather_id = torch.empty(
[gather_len_list.sum()],
dtype = index.dtype,device = self.device
)
torch.distributed.all_to_all_single(gather_id,index,
output_split_sizes=output_split,
input_split_sizes = input_split,
group = ctx.memory_nccl_group,
)
torch.distributed.all_to_all_single(gather_ts,mm_ts,
output_split_sizes=output_split,
input_split_sizes = input_split,
group = ctx.memory_nccl_group,
)
unq_id,inv = gather_id.unique(return_inverse=True)
max_ts,pos = torch_scatter.scatter_max(gather_ts,inv,dim = 0)
is_used = torch.zeros(gather_ts.shape,device=gather_ts.device,dtype=torch.bool)
is_used[pos] = True
send_mm_to_dst = torch.zeros([scatter_len_list.sum().item()],device = is_used.device,dtype=torch.bool)
torch.distributed.all_to_all_single(send_mm_to_dst,is_used,
output_split_sizes = input_split,
input_split_sizes = output_split,
group = ctx.memory_nccl_group,
)
index = index[send_mm_to_dst]
mm = mm[send_mm_to_dst]
mm_ts = mm_ts[send_mm_to_dst]
gather_len_list = torch.empty([self.num_parts],dtype=int,device=self.device)
ind = torch.ops.torch_sparse.ind2ptr(DistIndex(index).part,self.num_parts)
scatter_len_list = ind[1:] - ind[0:-1]
torch.distributed.all_to_all_single(gather_len_list,scatter_len_list,group = ctx.memory_nccl_group)
input_split = scatter_len_list.tolist()
output_split = gather_len_list.tolist()
gather_id = gather_id[is_used]
else:
gather_id = torch.empty([gather_len_list.sum()],dtype = index.dtype,device = self.device)
gather_memory = torch.empty(
[gather_len_list.sum(),mm.shape[1]],
dtype = mm.dtype,device=self.device
)
return index,gather_id,mm,gather_memory,input_split,output_split
def build_all_to_all_async_task(self,index,mm,mm_ts,is_no_redundant=False,is_async=True):
ctx = DistributedContext.get_default_context()
p2p_async_info = self.build_all_to_all_route(index,mm,mm_ts,is_no_redundant)
if mm.shape[1] != self.mailbox.shape[-1] + 1:
self.next_wait_memory_job = (*p2p_async_info,ctx.memory_nccl_group,is_async)
else:
self.next_wait_mail_job = (*p2p_async_info,ctx.memory_nccl_group,is_async)
def set_memory_all_reduce(self,index,memory,memory_ts,
mail_index,mail,mail_ts,
reduce_Op=None,
async_op=True,
mode=None,
wait_submit=True,
spread_mail=True,
update_cross_mm = False):
if self.num_parts == 1:
return
if not spread_mail and not update_cross_mm:
pass
# self.set_mailbox_local(DistIndex(index).loc,mail,mail_ts,Reduce_Op = reduce_Op)
# self.set_memory_local(DistIndex(index).loc,memory,memory_ts, Reduce_Op = reduce_Op)
else:
#print(index,memory,memory_ts)
if async_op == True and self.num_parts > 1:
#local_mask = (DistIndex(index).loc == dist.get_rank())
self.set_mailbox_all_to_all(index,memory,memory_ts,mail,mail_ts,reduce_Op='max',async_op=async_op,submit=submit)
if spread_mail:
mm = torch.cat((mail,mail_ts.reshape(-1,1)),dim=1)
mm_ts = mail_ts
self.build_all_to_all_async_task(mail_index,mm,mm_ts,is_async=True,is_no_redundant=True)
if update_cross_mm:
mm = torch.cat((memory,memory_ts.reshape(-1,1)),dim=1)
mm_ts = memory_ts
self.build_all_to_all_async_task(index,mm,mm_ts,is_async=True,is_no_redundant=True)
else:
self.set_mailbox_all_to_all(index,memory,memory_ts,mail,mail_ts,reduce_Op='max',async_op = False)
if update_cross_mm:
mm = self.pack(memory,memory_ts,mail,mail_ts,index)
mm_ts = memory_ts
self.build_all_to_all_async_task(index,mm,mm_ts,is_async=True,is_no_redundant=True)
if async_op is False:
self.update_p2p_mail()
self.update_p2p_mem()
self.handle_last_async()
ctx = DistributedContext.get_default_context()
if self.shared_nodes_index is not None and (mode == 'all_reduce' or mode == 'historical'):
shared_memory_ind = self.is_shared_mask[torch.min(DistIndex(index).loc,torch.tensor([self.num_nodes-1],device=torch.device('cuda')))]
shared_memory_ind = self.is_shared_mask[torch.min(DistIndex(index).loc,torch.tensor([self.num_nodes-1],device=index.device))]
mask = ((shared_memory_ind>-1)&(DistIndex(index).part==ctx.memory_group_rank))
shared_memory_ind = shared_memory_ind[mask]
shared_memory = memory[mask]
......@@ -306,20 +420,16 @@ class SharedMailBox():
shared_mail = mail[mask]
shared_mail_ts = mail_ts[mask]
if mode == 'historical':
#print(shared_memory_ind)
update_index = self.historical_cache.historical_check(shared_memory_ind,shared_memory,shared_memory_ts)
#print(update_index.sum(),shared_memory_ind.shape)
shared_memory_ind = shared_memory_ind[update_index]
shared_memory = shared_memory[update_index]
shared_memory_ts = shared_memory_ts[update_index]
shared_mail = shared_mail[update_index]
shared_mail_ts = shared_mail_ts[update_index]
#print(shared_memory_ind)
#mem = self.pack(memory=shared_memory,memory_ts=shared_memory_ts,mail=shared_mail,mail_ts=shared_mail_ts,index=shared_memory_ind,mode=mode)
#mem = self.pack(memory=shared_mail,memory_ts=shared_mail_ts,index=shared_memory_ind,mode=mode)
mem = self.pack(memory=shared_memory,memory_ts=shared_memory_ts,index=shared_memory_ind,mode=mode)
else:
mem = self.pack(memory=shared_memory,memory_ts=shared_memory_ts,mail=shared_mail,mail_ts = shared_mail_ts,index=shared_memory_ind,mode=mode)
self.tot_shared_count += shared_memory_ind.shape[0]
mem = self.pack(memory=shared_memory,memory_ts=shared_memory_ts,mail=shared_mail,mail_ts = shared_mail_ts,index=shared_memory_ind,mode=mode)
broadcast_len = torch.empty([1],device = mem.device,dtype = torch.int)
broadcast_len[0] = shared_memory_ind.shape[0]
shared_len = [torch.empty([1],device = mem.device,dtype = torch.int) for _ in range(ctx.memory_group_size)]
......@@ -327,53 +437,34 @@ class SharedMailBox():
shared_list = [torch.empty([l.item(),mem.shape[1]],device = mem.device,dtype=mem.dtype) for l in shared_len]
shared_id_list = [torch.empty([l.item()],device = shared_memory_ind.device,dtype=shared_memory_ind.dtype) for l in shared_len]
if mode == 'all_reduce':
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start.record()
if async_op == True:
handle0 = dist.all_gather(shared_list,mem,group=ctx.memory_nccl_group)
handle1 = dist.all_gather(shared_id_list,shared_memory_ind,group=ctx.memory_nccl_group)
self.last_memory_sync = shared_id_list,handle0,shared_id_list,handle1
else:
dist.all_gather(shared_list,mem,group=ctx.memory_nccl_group)
dist.all_gather(shared_id_list,shared_memory_ind,group=ctx.memory_nccl_group)
mem = torch.cat(shared_list,dim = 0)
shared_index = torch.cat(shared_id_list)
#id = shared_index.sort()
#print(mem[id],shared_index[id])
#,shared_memory,shared_memory_ts,
#shared_memory,shared_memory_ts = self.unpack(mem)
shared_memory,shared_memory_ts,shared_mail,shared_mail_ts = self.unpack(mem)
unq_index,inv = torch.unique(shared_index,return_inverse = True)
#print(inv.shape,Reduce_score.shape)
max_ts,idx = torch_scatter.scatter_max(shared_memory_ts,inv,0)
#min_ts,_ = torch_scatter.scatter_min(shared_mail_ts,inv,0)
shared_memory = shared_memory[idx]
shared_memory_ts = shared_memory_ts[idx]
shared_mail_ts = shared_mail_ts[idx]
shared_mail = shared_mail[idx]
#shared_mail_ts = torch_scatter.scatter_mean(shared_mail_ts,inv,0)
#shared_mail = torch_scatter.scatter_mean(shared_mail,inv,0)
shared_index = unq_index
self.set_memory_local(self.shared_nodes_index[shared_index],shared_memory,shared_memory_ts)
self.historical_cache.local_historical_data[shared_index] = shared_memory
self.historical_cache.local_ts[shared_index] = shared_memory_ts
self.set_mailbox_local(self.shared_nodes_index[shared_index],shared_mail,shared_mail_ts)
#if async_op == True:
# handle0 = dist.all_gather(shared_list,mem,group=ctx.memory_nccl_group)
# handle1 = dist.all_gather(shared_id_list,shared_memory_ind,group=ctx.memory_nccl_group)
# self.last_memory_sync = shared_id_list,handle0,shared_id_list,handle1
#else:
dist.all_gather(shared_list,mem,group=ctx.memory_nccl_group)
dist.all_gather(shared_id_list,shared_memory_ind,group=ctx.memory_nccl_group)
mem = torch.cat(shared_list,dim = 0)
shared_index = torch.cat(shared_id_list)
shared_memory,shared_memory_ts,shared_mail,shared_mail_ts = self.unpack(mem)
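# The all_gather above yields one entry per rank for every reported shared node; keep, for
# each node, the copy with the newest memory timestamp (scatter_max over the inverse index)
# and write that winner back into local memory, the historical cache and the mailbox.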
unq_index,inv = torch.unique(shared_index,return_inverse = True)
max_ts,idx = torch_scatter.scatter_max(shared_memory_ts,inv,0)
shared_memory = shared_memory[idx]
shared_memory_ts = shared_memory_ts[idx]
shared_mail_ts = shared_mail_ts[idx]
shared_mail = shared_mail[idx]
shared_index = unq_index
self.set_memory_local(self.shared_nodes_index[shared_index],shared_memory,shared_memory_ts)
self.historical_cache.local_historical_data[shared_index] = shared_memory
self.historical_cache.local_ts[shared_index] = shared_memory_ts
self.set_mailbox_local(self.shared_nodes_index[shared_index],shared_mail,shared_mail_ts)
else:
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start.record()
#self.historical_cache.synchronize_shared_update(filter)
#mem[:,:-1] = mem[:,:-1] + filter.get_incretment(shared_memory_ind)
#shared_list = [torch.empty([l.item(),mem.shape[1]],device = mem.device,dtype=mem.dtype) for l in shared_len]
#handle0 = dist.all_gather(shared_list,mem,group=ctx.memory_nccl_group,async_op=True)
#shared_id_list = [torch.empty([l.item()],device = shared_memory_ind.device,dtype=shared_memory_ind.dtype) for l in shared_len]
#handle1 = dist.all_gather(shared_id_list,shared_memory_ind,group=ctx.memory_nccl_group,async_op=True)
self.last_job = (shared_list,mem,shared_id_list,shared_memory_ind)
if ~submit:
self.update_shared()
self.next_wait_gather_memory_job = (shared_list,mem,shared_id_list,shared_memory_ind)
if not wait_submit:
self.update_shared()
self.update_p2p_mail()
self.update_p2p_mem()
#self.historical_cache.add_shared_to_queue(handle0,handle1,shared_id_list,shared_list)
"""
shared_memory = self.node_memory.accessor.data[self.shared_nodes_index]
......@@ -394,107 +485,54 @@ class SharedMailBox():
def set_mailbox_all_to_all_empty(self,index,memory,
memory_ts,mail,mail_ts,
reduce_Op = None,group = None):
if self.num_parts == 1:
dist_index = DistIndex(index)
part_idx = dist_index.part
index = dist_index.loc
self.set_mailbox_local(index,mail,mail_ts)
self.set_memory_local(index,memory,memory_ts)
else:
gather_len_list = torch.empty([self.num_parts],
dtype = int,
device = self.device)
#indic = torch.searchsorted(index,self.partptr,right=False)
indic = torch.ops.torch_sparse.ind2ptr(DistIndex(index).part, self.num_parts)
scatter_len_list = indic[1:] - indic[0:-1]
torch.distributed.all_to_all_single(gather_len_list,scatter_len_list,group = group)
input_split = scatter_len_list.tolist()
output_split = gather_len_list.tolist()
gather_id_list = torch.empty(
[gather_len_list.sum()],
dtype = torch.long,
device = self.device)
input_split = scatter_len_list.tolist()
output_split = gather_len_list.tolist()
torch.distributed.all_to_all_single(
gather_id_list,index,output_split_sizes=output_split,
input_split_sizes=input_split,group = group)
index = gather_id_list
gather_memory = torch.empty(
[gather_len_list.sum(),memory.shape[1]],
dtype = memory.dtype,device = self.device)
gather_memory_ts = torch.empty(
[gather_len_list.sum()],
dtype = memory_ts.dtype,device = self.device)
gather_mail = torch.empty(
[gather_len_list.sum(),mail.shape[1]],
dtype = mail.dtype,device = self.device)
gather_mail_ts = torch.empty(
[gather_len_list.sum()],
dtype = mail_ts.dtype,device = self.device)
torch.distributed.all_to_all_single(
gather_memory,memory,
output_split_sizes=output_split,
input_split_sizes=input_split,group = group,async_op = True)
torch.distributed.all_to_all_single(
gather_memory_ts,memory_ts,
output_split_sizes=output_split,
input_split_sizes=input_split,group = group)
torch.distributed.all_to_all_single(
gather_mail,mail,
output_split_sizes=output_split,
input_split_sizes=input_split,group = group)
torch.distributed.all_to_all_single(
gather_mail_ts,mail_ts,
output_split_sizes=output_split,
input_split_sizes=input_split,group = group)
pass
def set_mailbox_all_to_all(self,index,memory,
memory_ts,mail,mail_ts,
reduce_Op = None,group = None,async_op=False,submit = True):
#futs: List[torch.futures.Future] = []
if self.num_parts == 1:
dist_index = DistIndex(index)
part_idx = dist_index.part
index = dist_index.loc
self.set_mailbox_local(index,mail,mail_ts)
self.set_memory_local(index,memory,memory_ts)
else:
self.tot_comm_count += (DistIndex(index).part != dist.get_rank()).sum()
gather_len_list = torch.empty([self.num_parts],
dtype = int,
device = self.device)
indic = torch.ops.torch_sparse.ind2ptr(DistIndex(index).part, self.num_parts)
scatter_len_list = indic[1:] - indic[0:-1]
torch.distributed.all_to_all_single(gather_len_list,scatter_len_list,group = group)
input_split = scatter_len_list.tolist()
output_split = gather_len_list.tolist()
mem = self.pack(memory,memory_ts,mail,mail_ts)
gather_memory = torch.empty(
[gather_len_list.sum(),mem.shape[1]],
dtype = memory.dtype,device = self.device)
gather_id_list = torch.empty([gather_len_list.sum()],dtype = torch.long,device = self.device)
input_split = scatter_len_list.tolist()
output_split = gather_len_list.tolist()
if async_op == True:
self.last_job = index,gather_id_list,mem,gather_memory,input_split,output_split,group,async_op
if ~submit:
self.update_p2p()
else:
torch.distributed.all_to_all_single(
gather_id_list,index,output_split_sizes=output_split,
input_split_sizes=input_split,group = group,async_op=async_op)
torch.distributed.all_to_all_single(
gather_memory,mem,
output_split_sizes=output_split,
input_split_sizes=input_split,group = group)
if gather_memory.shape[1] > self.node_memory.shape[1] + 1:
gather_memory,gather_memory_ts,gather_mail,gather_mail_ts = self.unpack(gather_memory)
self.set_mailbox_local(DistIndex(gather_id_list).loc,gather_mail,gather_mail_ts,Reduce_Op = reduce_Op)
else:
gather_memory,gather_memory_ts = self.unpack(gather_memory)
self.set_memory_local(DistIndex(gather_id_list).loc,gather_memory,gather_memory_ts, Reduce_Op = reduce_Op)
pass
# #futs: List[torch.futures.Future] = []
# if self.num_parts == 1:
# dist_index = DistIndex(index)
# index = dist_index.loc
# self.set_mailbox_local(index,mail,mail_ts)
# self.set_memory_local(index,memory,memory_ts)
# else:
# self.tot_comm_count += (DistIndex(index).part != dist.get_rank()).sum()
# gather_len_list = torch.empty([self.num_parts],
# dtype = int,
# device = self.device)
# indic = torch.ops.torch_sparse.ind2ptr(DistIndex(index).part, self.num_parts)
# scatter_len_list = indic[1:] - indic[0:-1]
# torch.distributed.all_to_all_single(gather_len_list,scatter_len_list,group = group)
# input_split = scatter_len_list.tolist()
# output_split = gather_len_list.tolist()
# mem = self.pack(memory,memory_ts,mail,mail_ts)
# gather_memory = torch.empty(
# [gather_len_list.sum(),mem.shape[1]],
# dtype = memory.dtype,device = self.device)
# gather_id_list = torch.empty([gather_len_list.sum()],dtype = torch.long,device = self.device)
# input_split = scatter_len_list.tolist()
# output_split = gather_len_list.tolist()
# if async_op == True:
# self.last_job = index,gather_id_list,mem,gather_memory,input_split,output_split,group,async_op
# if not submit:
# self.update_p2p()
# else:
# torch.distributed.all_to_all_single(
# gather_id_list,index,output_split_sizes=output_split,
# input_split_sizes=input_split,group = group,async_op=async_op)
# torch.distributed.all_to_all_single(
# gather_memory,mem,
# output_split_sizes=output_split,
# input_split_sizes=input_split,group = group)
# if gather_memory.shape[1] > self.node_memory.shape[1] + 1:
# gather_memory,gather_memory_ts,gather_mail,gather_mail_ts = self.unpack(gather_memory)
# self.set_mailbox_local(DistIndex(gather_id_list).loc,gather_mail,gather_mail_ts,Reduce_Op = reduce_Op)
# else:
# gather_memory,gather_memory_ts = self.unpack(gather_memory)
# self.set_memory_local(DistIndex(gather_id_list).loc,gather_memory,gather_memory_ts, Reduce_Op = reduce_Op)
def gather_memory(
self,
......
......@@ -58,7 +58,7 @@ class SharedMailBox():
ts_dtye = torch.float32,
uvm = False,
use_pin = False,
cache_route = None,
start_historical = False,
shared_ssim = 2):
ctx = distributed.context._get_default_dist_context()
self.device = device
......@@ -105,14 +105,16 @@ class SharedMailBox():
self.is_shared_mask[shared_nodes_index] = torch.arange(self.shared_nodes_index.shape[0],dtype=torch.int,
device=torch.device('cuda:{}'.format(ctx.local_rank)))
print(self.shared_nodes_index)
if cache_route is not None:
if start_historical is not None:
self.historical_cache = historical_cache.HistoricalCache(self.shared_nodes_index,0,self.node_memory.shape[1],self.node_memory.dtype,self.node_memory.device,threshold=shared_ssim)
self._mem_pin = {}
self._mail_pin = {}
self.use_pin = use_pin
self.last_memory_sync = None
self.last_job = None
self.last_mail_sync = None
self.next_wait_memory_job=None
self.next_wait_mail_job=None
self.next_wait_gather_memory_job=None
self.mon = MemoryMoniter()
......@@ -124,6 +126,7 @@ class SharedMailBox():
self.next_mail_pos.zero_()
self.historical_cache.empty()
self.last_memory_sync = None
self.last_mail_sync = None
def set_memory_local(self,index,source,source_ts,Reduce_Op = None):
......@@ -212,11 +215,8 @@ class SharedMailBox():
mem = torch.cat((mail,mail_ts.view(-1,1)),dim = 1)
else:
mem = torch.cat((memory,memory_ts.view(-1,1)),dim = 1)
#if index is None:
# mem = torch.cat((memory,memory_ts.view(-1,1),mail,mail_ts.view(-1,1)),dim = 1)
#else:
# mem = torch.cat((memory,memory_ts.view(-1,1),mail,mail_ts.view(-1,1),index.to(torch.float32).view(-1,1)),dim = 1)
return mem
def unpack(self,mem,mailbox = False):
if mem.shape[1] == self.node_memory.shape[1] + 1 or mem.shape[1] == self.mailbox.shape[2] + 1 :
mail = mem[:,: -1]
......@@ -234,8 +234,10 @@ class SharedMailBox():
mail = mem[:,self.node_memory.shape[1]+1:mem.shape[1]-self.mailbox_ts.shape[1]].reshape(mem.shape[0],self.mailbox.shape[1],-1)
mail_ts = mem[:,mem.shape[1]-self.mailbox_ts.shape[1]:]
return memory,memory_ts,mail,mail_ts
def handle_last_async(self,reduce_Op = None):
"""
sychronization last async all-to-all M&M communication task
"""
def handle_last_memory(self,reduce_Op=None,):
if self.last_memory_sync is not None:
gather_id_list,handle0,gather_memory,handle1 = self.last_memory_sync
self.last_memory_sync = None
......@@ -249,7 +251,23 @@ class SharedMailBox():
else:
gather_memory,gather_memory_ts = self.unpack(gather_memory)
self.set_memory_local(DistIndex(gather_id_list).loc,gather_memory,gather_memory_ts, Reduce_Op = reduce_Op)
def handle_last_mail(self,reduce_Op=None,):
if self.last_mail_sync is not None:
gather_id_list,handle0,gather_memory,handle1 = self.last_mail_sync
self.last_mail_sync = None
handle0.wait()
handle1.wait()
if isinstance(gather_memory,list):
gather_memory = torch.cat(gather_memory,dim = 0)
gather_memory,gather_memory_ts = self.unpack(gather_memory)
self.set_mailbox_local(DistIndex(gather_id_list).loc,gather_memory,gather_memory_ts, Reduce_Op = reduce_Op)
def handle_last_async(self,reduce_Op = None):
self.handle_last_memory(reduce_Op)
self.handle_last_mail(reduce_Op)
"""
sychronization last async all-gather memory communication task
"""
def sychronize_shared(self):
out=self.historical_cache.synchronize_shared_update()
if out is not None:
......@@ -258,71 +276,167 @@ class SharedMailBox():
mask= (shared_ts > self.node_memory_ts.accessor.data[index])
self.node_memory.accessor.data[index][mask] = shared_data[mask]
self.node_memory_ts.accessor.data[index][mask] = shared_ts[mask]
#self.mailbox.accessor.data[index, torch.max(self.next_mail_pos[index]-1,torch.tensor([0],device=mail.device))] = mail
#self.mailbox_ts.accessor.data[index, torch.max(self.next_mail_pos[index]-1,torch.tensor([0],device=mail_ts.device))] = mail_ts
def update_shared(self):
ctx = DistributedContext.get_default_context()
if self.last_job is not None:
shared_list,mem,shared_id_list,shared_memory_ind = self.last_job
self.last_job = None
if self.next_wait_gather_memory_job is not None:
shared_list,mem,shared_id_list,shared_memory_ind = self.next_wait_gather_memory_job
self.next_wait_gather_memory_job = None
handle0 = dist.all_gather(shared_list,mem,group=ctx.memory_nccl_group,async_op=True)
handle1 = dist.all_gather(shared_id_list,shared_memory_ind,group=ctx.memory_nccl_group,async_op=True)
self.historical_cache.add_shared_to_queue(handle0,handle1,shared_id_list,shared_list)
def update_p2p(self):
if self.last_job is None:
def update_p2p_mem(self):
if self.next_wait_memory_job is None:
return
index,gather_id_list,mem,gather_memory,input_split,output_split,group,async_op = self.last_job
self.last_job = None
index,gather_id_list,mem,gather_memory,input_split,output_split,group,async_op = self.next_wait_memory_job
self.next_wait_memory_job = None
handle0 = torch.distributed.all_to_all_single(
gather_id_list,index,output_split_sizes=output_split,
input_split_sizes=input_split,group = group,async_op=async_op)
#print(input_split,output_split)
handle1 = torch.distributed.all_to_all_single(
gather_memory,mem,
output_split_sizes=output_split,
input_split_sizes=input_split,group = group,async_op=async_op)
self.last_memory_sync = (gather_id_list,handle0,gather_memory,handle1)
def update_p2p_mail(self):
if self.next_wait_mail_job is None:
return
index,gather_id_list,mem,gather_memory,input_split,output_split,group,async_op = self.next_wait_mail_job
self.next_wait_mail_job = None
handle0 = torch.distributed.all_to_all_single(
gather_id_list,index,output_split_sizes=output_split,
input_split_sizes=input_split,group = group,async_op=async_op)
handle1 = torch.distributed.all_to_all_single(
gather_memory,mem,
output_split_sizes=output_split,
input_split_sizes=input_split,group = group,async_op=async_op)
self.last_mail_sync = (gather_id_list,handle0,gather_memory,handle1)
"""
submit: take the task request into queue wait for start
"""
def build_all_to_all_route(self,index,mm,mm_ts,is_no_redundant):
ctx = DistributedContext.get_default_context()
gather_len_list = torch.empty([self.num_parts],dtype=int,device=self.device)
ind = torch.ops.torch_sparse.ind2ptr(DistIndex(index).part,self.num_parts)
scatter_len_list = ind[1:] - ind[0:-1]
torch.distributed.all_to_all_single(gather_len_list,scatter_len_list,group = ctx.memory_nccl_group)
input_split = scatter_len_list.tolist()
output_split = gather_len_list.tolist()
if is_no_redundant:
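# Redundancy elimination: exchange only ids and timestamps first, let each receiver keep,
# per destination node, the single incoming update with the newest timestamp (scatter_max),
# send that selection mask back through a reversed all_to_all, then rebuild the route so
# only the winning entries are shipped in the later payload exchange
# (launched by update_p2p_mem / update_p2p_mail).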
gather_ts = torch.empty(
[gather_len_list.sum()],
dtype = mm_ts.dtype,device = self.device
)
gather_id = torch.empty(
[gather_len_list.sum()],
dtype = index.dtype,device = self.device
)
torch.distributed.all_to_all_single(gather_id,index,
output_split_sizes=output_split,
input_split_sizes = input_split,
group = ctx.memory_nccl_group,
)
torch.distributed.all_to_all_single(gather_ts,mm_ts,
output_split_sizes=output_split,
input_split_sizes = input_split,
group = ctx.memory_nccl_group,
)
unq_id,inv = gather_id.unique(return_inverse=True)
max_ts,pos = torch_scatter.scatter_max(gather_ts,inv,dim = 0)
is_used = torch.zeros(gather_ts.shape,device=gather_ts.device,dtype=torch.int8)
is_used[pos] = 1
send_mm_to_dst = torch.zeros([scatter_len_list.sum().item()],device = is_used.device,dtype=torch.int8)
torch.distributed.all_to_all_single(send_mm_to_dst,is_used,
output_split_sizes = input_split,
input_split_sizes = output_split,
group = ctx.memory_nccl_group,
)
index = index[send_mm_to_dst>0]
mm = mm[send_mm_to_dst>0]
mm_ts = mm_ts[send_mm_to_dst>0]
gather_len_list = torch.empty([self.num_parts],dtype=int,device=self.device)
ind = torch.ops.torch_sparse.ind2ptr(DistIndex(index).part,self.num_parts)
scatter_len_list = ind[1:] - ind[0:-1]
torch.distributed.all_to_all_single(gather_len_list,scatter_len_list,group = ctx.memory_nccl_group)
input_split = scatter_len_list.tolist()
output_split = gather_len_list.tolist()
gather_id = gather_id[is_used>0]
else:
gather_id = torch.empty([gather_len_list.sum()],dtype = index.dtype,device = self.device)
gather_memory = torch.empty(
[gather_len_list.sum(),mm.shape[1]],
dtype = mm.dtype,device=self.device
)
return index,gather_id,mm,gather_memory,input_split,output_split
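    # Routing sketch for the exchange above (hedged; assumes DistIndex(index).part is
    # grouped by destination partition, as ind2ptr requires): the CSR-style pointer
    # produced by ind2ptr yields per-partition send counts, a first all_to_all_single
    # on the counts tells each rank how much it will receive, and the two count lists
    # become input_split_sizes / output_split_sizes for the payload exchange.
    # Illustrative numbers for 3 partitions:
    #   scatter_len_list = [4, 1, 3]   -> input_split  = [4, 1, 3]  (what this rank sends)
    #   gather_len_list  = [2, 5, 2]   -> output_split = [2, 5, 2]  (what this rank receives)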
    def build_all_to_all_async_task(self,index,mm,mm_ts,is_no_redundant=False,is_async=True):
        ctx = DistributedContext.get_default_context()
        #print(DistIndex(index).part)
        p2p_async_info = self.build_all_to_all_route(index,mm,mm_ts,is_no_redundant)
        if mm.shape[1] != self.mailbox.shape[-1] + 1:
            self.next_wait_memory_job = (*p2p_async_info,ctx.memory_nccl_group,is_async)
        else:
            self.next_wait_mail_job = (*p2p_async_info,ctx.memory_nccl_group,is_async)
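    # Dispatch note (inferred from the width check above): a mail payload is built as
    # torch.cat((mail, mail_ts.reshape(-1,1)), dim=1), so its width is exactly
    # self.mailbox.shape[-1] + 1 and it is queued as next_wait_mail_job; any other
    # width is treated as a memory payload and queued as next_wait_memory_job.
    # Illustrative shapes (assumed dims): mailbox dim 172 -> mail payload [N, 173];
    # memory dim 100 -> memory payload [N, 101].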
def set_memory_all_reduce(self,index,memory,memory_ts,
mail_index,mail,mail_ts,
reduce_Op=None,
async_op=True,
mode=None,
wait_submit=True,
spread_mail=True,
update_cross_mm = False,
is_no_redundant = True):
if self.num_parts == 1:
return
if not spread_mail and not update_cross_mm:
pass
# self.set_mailbox_local(DistIndex(index).loc,mail,mail_ts,Reduce_Op = reduce_Op)
# self.set_memory_local(DistIndex(index).loc,memory,memory_ts, Reduce_Op = reduce_Op)
else:
#print(index,memory,memory_ts)
if async_op == True and self.num_parts > 1:
#local_mask = (DistIndex(index).loc == dist.get_rank())
if spread_mail:
mm = torch.cat((mail,mail_ts.reshape(-1,1)),dim=1)
mm_ts = mail_ts
self.build_all_to_all_async_task(mail_index,mm,mm_ts,is_async=True,is_no_redundant=False)
if update_cross_mm:
mm = torch.cat((memory,memory_ts.reshape(-1,1)),dim=1)
mm_ts = memory_ts
self.build_all_to_all_async_task(index,mm,mm_ts,is_async=True,is_no_redundant=False)
else:
self.set_mailbox_all_to_all(index,memory,memory_ts,mail,mail_ts,reduce_Op='max',async_op = False)
if update_cross_mm:
mm = self.pack(memory,memory_ts,mail,mail_ts,index)
mm_ts = memory_ts
self.build_all_to_all_async_task(index,mm,mm_ts,is_async=True,is_no_redundant=False)
if async_op is False:
self.update_p2p_mail()
self.update_p2p_mem()
self.handle_last_async()
ctx = DistributedContext.get_default_context()
if self.shared_nodes_index is not None and (mode == 'all_reduce' or mode == 'historical'):
            shared_memory_ind = self.is_shared_mask[torch.min(DistIndex(index).loc,torch.tensor([self.num_nodes-1],device=index.device))]
mask = ((shared_memory_ind>-1)&(DistIndex(index).part==ctx.memory_group_rank))
shared_memory_ind = shared_memory_ind[mask]
shared_memory = memory[mask]
shared_memory_ts = memory_ts[mask]
if spread_mail:
shared_mail_indx = self.is_shared_mask[torch.min(DistIndex(mail_index).loc,torch.tensor([self.num_nodes-1],device=index.device))]
mask = ((shared_mail_indx>-1)&(DistIndex(shared_mail_indx).part==ctx.memory_group_rank))
shared_mail_indx = shared_mail_indx[mask]
else:
shared_mail_indx = shared_memory_ind
shared_mail = mail[mask]
shared_mail_ts = mail_ts[mask]
if mode == 'historical':
#print(shared_memory_ind)
update_index = self.historical_cache.historical_check(shared_memory_ind,shared_memory,shared_memory_ts)
#print(update_index.sum(),shared_memory_ind.shape)
shared_memory_ind = shared_memory_ind[update_index]
shared_memory = shared_memory[update_index]
shared_memory_ts = shared_memory_ts[update_index]
shared_mail = shared_mail[update_index]
shared_mail_ts = shared_mail_ts[update_index]
#print(shared_memory_ind)
#mem = self.pack(memory=shared_memory,memory_ts=shared_memory_ts,mail=shared_mail,mail_ts=shared_mail_ts,index=shared_memory_ind,mode=mode)
#mem = self.pack(memory=shared_mail,memory_ts=shared_mail_ts,index=shared_memory_ind,mode=mode)
mem = self.pack(memory=shared_memory,memory_ts=shared_memory_ts,index=shared_memory_ind,mode=mode)
            else:
                if not spread_mail:
                    mem = self.pack(memory=shared_memory,memory_ts=shared_memory_ts,mail=shared_mail,mail_ts=shared_mail_ts,index=shared_memory_ind,mode=mode)
                else:
                    mem = self.pack(memory=shared_memory,memory_ts=shared_memory_ts,index=shared_memory_ind,mode=mode)
self.tot_shared_count += shared_memory_ind.shape[0]
broadcast_len = torch.empty([1],device = mem.device,dtype = torch.int)
broadcast_len[0] = shared_memory_ind.shape[0]
......@@ -331,54 +445,61 @@ class SharedMailBox():
shared_list = [torch.empty([l.item(),mem.shape[1]],device = mem.device,dtype=mem.dtype) for l in shared_len]
shared_id_list = [torch.empty([l.item()],device = shared_memory_ind.device,dtype=shared_memory_ind.dtype) for l in shared_len]
if mode == 'all_reduce':
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start.record()
                dist.all_gather(shared_list,mem,group=ctx.memory_nccl_group)
                dist.all_gather(shared_id_list,shared_memory_ind,group=ctx.memory_nccl_group)
mem = torch.cat(shared_list,dim = 0)
shared_index = torch.cat(shared_id_list)
if spread_mail:
shared_memory,shared_memory_ts = self.unpack(mem)
unq_index,inv = torch.unique(shared_index,return_inverse = True)
max_ts,idx = torch_scatter.scatter_max(shared_memory_ts,inv,0)
shared_memory = shared_memory[idx]
shared_memory_ts = shared_memory_ts[idx]
shared_index = unq_index
self.historical_cache.local_historical_data[shared_index] = shared_memory
self.historical_cache.local_ts[shared_index] = shared_memory_ts
broadcast_len = torch.empty([1],device = mem.device,dtype = torch.int)
broadcast_len[0] = shared_mail_indx.shape[0]
shared_len = [torch.empty([1],device = mail.device,dtype = torch.int) for _ in range(ctx.memory_group_size)]
dist.all_gather(shared_len,broadcast_len,group = ctx.memory_nccl_group)
mail = self.pack(memory=shared_mail,memory_ts=shared_mail_ts,index=shared_mail_indx,mode=mode)
shared_mail_list = [torch.empty([l.item(),mail.shape[1]],device = mail.device,dtype=mail.dtype) for l in shared_len]
shared_mail_id_list = [torch.empty([l.item()],device = shared_mail_indx.device,dtype=shared_mail_indx.dtype) for l in shared_len]
#print(mail.shape)
dist.all_gather(shared_mail_list,mail,group=ctx.memory_nccl_group)
dist.all_gather(shared_mail_id_list,shared_mail_indx,group=ctx.memory_nccl_group)
shared_mail_indx = torch.cat(shared_mail_id_list,dim=0)
mail = torch.cat(shared_mail_list,dim=0)
shared_mail,shared_mail_ts = self.unpack(mail)
unq_index,inv = torch.unique(shared_mail_indx,return_inverse = True)
max_ts,idx = torch_scatter.scatter_max(shared_mail_ts,inv,0)
shared_mail= shared_mail[idx]
shared_mail_ts = shared_mail_ts[idx]
shared_mail_indx = unq_index
else:
dist.all_gather(shared_list,mem,group=ctx.memory_nccl_group)
dist.all_gather(shared_id_list,shared_memory_ind,group=ctx.memory_nccl_group)
mem = torch.cat(shared_list,dim = 0)
shared_index = torch.cat(shared_id_list)
#id = shared_index.sort()
#print(mem[id],shared_index[id])
#,shared_memory,shared_memory_ts,
#shared_memory,shared_memory_ts = self.unpack(mem)
shared_memory,shared_memory_ts,shared_mail,shared_mail_ts = self.unpack(mem)
#print(shared_memory_ts,shared_mail_ts)
unq_index,inv = torch.unique(shared_index,return_inverse = True)
#print(inv.shape,Reduce_score.shape)
max_ts,idx = torch_scatter.scatter_max(shared_memory_ts,inv,0)
#min_ts,_ = torch_scatter.scatter_min(shared_mail_ts,inv,0)
                shared_memory = shared_memory[idx]
shared_memory_ts = shared_memory_ts[idx]
shared_mail_ts = shared_mail_ts[idx]
shared_mail = shared_mail[idx]
#shared_mail_ts = torch_scatter.scatter_mean(shared_mail_ts,inv,0)
#shared_mail = torch_scatter.scatter_mean(shared_mail,inv,0)
shared_index = unq_index
                shared_mail_indx = unq_index
                self.set_memory_local(self.shared_nodes_index[shared_index],shared_memory,shared_memory_ts)
                self.historical_cache.local_historical_data[shared_index] = shared_memory
                self.historical_cache.local_ts[shared_index] = shared_memory_ts
                self.set_mailbox_local(self.shared_nodes_index[shared_mail_indx],shared_mail,shared_mail_ts)
else:
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start.record()
#self.historical_cache.synchronize_shared_update(filter)
#mem[:,:-1] = mem[:,:-1] + filter.get_incretment(shared_memory_ind)
#shared_list = [torch.empty([l.item(),mem.shape[1]],device = mem.device,dtype=mem.dtype) for l in shared_len]
#handle0 = dist.all_gather(shared_list,mem,group=ctx.memory_nccl_group,async_op=True)
#shared_id_list = [torch.empty([l.item()],device = shared_memory_ind.device,dtype=shared_memory_ind.dtype) for l in shared_len]
#handle1 = dist.all_gather(shared_id_list,shared_memory_ind,group=ctx.memory_nccl_group,async_op=True)
self.next_wait_gather_memory_job = (shared_list,mem,shared_id_list,shared_memory_ind)
if not wait_submit:
self.update_shared()
self.update_p2p_mail()
self.update_p2p_mem()
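            # Deduplication idiom used in the all_reduce branch above: after the
            # all_gather a shared node can appear once per rank, so only the copy with
            # the largest timestamp is kept via torch_scatter.scatter_max.
            # Minimal sketch with hypothetical values:
            #
            #   ids = torch.tensor([7, 3, 7]); ts = torch.tensor([5., 9., 8.])
            #   unq, inv = ids.unique(return_inverse=True)               # unq=[3,7], inv=[1,0,1]
            #   max_ts, pos = torch_scatter.scatter_max(ts, inv, dim=0)  # max_ts=[9.,8.], pos=[1,2]
            #   latest_rows = payload[pos]                               # one row per unique id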
#self.historical_cache.add_shared_to_queue(handle0,handle1,shared_id_list,shared_list)
"""
shared_memory = self.node_memory.accessor.data[self.shared_nodes_index]
......@@ -399,107 +520,54 @@ class SharedMailBox():
def set_mailbox_all_to_all_empty(self,index,memory,
memory_ts,mail,mail_ts,
reduce_Op = None,group = None):
if self.num_parts == 1:
dist_index = DistIndex(index)
part_idx = dist_index.part
index = dist_index.loc
self.set_mailbox_local(index,mail,mail_ts)
self.set_memory_local(index,memory,memory_ts)
else:
gather_len_list = torch.empty([self.num_parts],
dtype = int,
device = self.device)
#indic = torch.searchsorted(index,self.partptr,right=False)
indic = torch.ops.torch_sparse.ind2ptr(DistIndex(index).part, self.num_parts)
scatter_len_list = indic[1:] - indic[0:-1]
torch.distributed.all_to_all_single(gather_len_list,scatter_len_list,group = group)
input_split = scatter_len_list.tolist()
output_split = gather_len_list.tolist()
gather_id_list = torch.empty(
[gather_len_list.sum()],
dtype = torch.long,
device = self.device)
input_split = scatter_len_list.tolist()
output_split = gather_len_list.tolist()
torch.distributed.all_to_all_single(
gather_id_list,index,output_split_sizes=output_split,
input_split_sizes=input_split,group = group)
index = gather_id_list
gather_memory = torch.empty(
[gather_len_list.sum(),memory.shape[1]],
dtype = memory.dtype,device = self.device)
gather_memory_ts = torch.empty(
[gather_len_list.sum()],
dtype = memory_ts.dtype,device = self.device)
gather_mail = torch.empty(
[gather_len_list.sum(),mail.shape[1]],
dtype = mail.dtype,device = self.device)
gather_mail_ts = torch.empty(
[gather_len_list.sum()],
dtype = mail_ts.dtype,device = self.device)
torch.distributed.all_to_all_single(
gather_memory,memory,
output_split_sizes=output_split,
input_split_sizes=input_split,group = group,async_op = True)
torch.distributed.all_to_all_single(
gather_memory_ts,memory_ts,
output_split_sizes=output_split,
input_split_sizes=input_split,group = group)
torch.distributed.all_to_all_single(
gather_mail,mail,
output_split_sizes=output_split,
input_split_sizes=input_split,group = group)
torch.distributed.all_to_all_single(
gather_mail_ts,mail_ts,
output_split_sizes=output_split,
input_split_sizes=input_split,group = group)
pass
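    # Packet layout (inferred, not guaranteed): pack() appears to concatenate its
    # arguments column-wise, roughly
    #   mem = torch.cat([memory, memory_ts.view(-1,1), mail, mail_ts.view(-1,1)], dim=1)
    # which is why set_mailbox_all_to_all below can distinguish a full packet from a
    # memory-only packet by checking gather_memory.shape[1] > self.node_memory.shape[1] + 1.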
def set_mailbox_all_to_all(self,index,memory,
memory_ts,mail,mail_ts,
reduce_Op = None,group = None,async_op=False,submit = True):
#futs: List[torch.futures.Future] = []
if self.num_parts == 1:
dist_index = DistIndex(index)
part_idx = dist_index.part
index = dist_index.loc
self.set_mailbox_local(index,mail,mail_ts)
self.set_memory_local(index,memory,memory_ts)
else:
self.tot_comm_count += (DistIndex(index).part != dist.get_rank()).sum()
gather_len_list = torch.empty([self.num_parts],
dtype = int,
device = self.device)
indic = torch.ops.torch_sparse.ind2ptr(DistIndex(index).part, self.num_parts)
scatter_len_list = indic[1:] - indic[0:-1]
torch.distributed.all_to_all_single(gather_len_list,scatter_len_list,group = group)
input_split = scatter_len_list.tolist()
output_split = gather_len_list.tolist()
mem = self.pack(memory,memory_ts,mail,mail_ts)
gather_memory = torch.empty(
[gather_len_list.sum(),mem.shape[1]],
dtype = memory.dtype,device = self.device)
gather_id_list = torch.empty([gather_len_list.sum()],dtype = torch.long,device = self.device)
input_split = scatter_len_list.tolist()
output_split = gather_len_list.tolist()
if async_op == True:
self.last_job = index,gather_id_list,mem,gather_memory,input_split,output_split,group,async_op
                if not submit:
                    self.update_p2p()
else:
torch.distributed.all_to_all_single(
gather_id_list,index,output_split_sizes=output_split,
input_split_sizes=input_split,group = group,async_op=async_op)
torch.distributed.all_to_all_single(
gather_memory,mem,
output_split_sizes=output_split,
input_split_sizes=input_split,group = group)
if gather_memory.shape[1] > self.node_memory.shape[1] + 1:
gather_memory,gather_memory_ts,gather_mail,gather_mail_ts = self.unpack(gather_memory)
self.set_mailbox_local(DistIndex(gather_id_list).loc,gather_mail,gather_mail_ts,Reduce_Op = reduce_Op)
else:
gather_memory,gather_memory_ts = self.unpack(gather_memory)
self.set_memory_local(DistIndex(gather_id_list).loc,gather_memory,gather_memory_ts, Reduce_Op = reduce_Op)
pass
# #futs: List[torch.futures.Future] = []
# if self.num_parts == 1:
# dist_index = DistIndex(index)
# index = dist_index.loc
# self.set_mailbox_local(index,mail,mail_ts)
# self.set_memory_local(index,memory,memory_ts)
# else:
# self.tot_comm_count += (DistIndex(index).part != dist.get_rank()).sum()
# gather_len_list = torch.empty([self.num_parts],
# dtype = int,
# device = self.device)
# indic = torch.ops.torch_sparse.ind2ptr(DistIndex(index).part, self.num_parts)
# scatter_len_list = indic[1:] - indic[0:-1]
# torch.distributed.all_to_all_single(gather_len_list,scatter_len_list,group = group)
# input_split = scatter_len_list.tolist()
# output_split = gather_len_list.tolist()
# mem = self.pack(memory,memory_ts,mail,mail_ts)
# gather_memory = torch.empty(
# [gather_len_list.sum(),mem.shape[1]],
# dtype = memory.dtype,device = self.device)
# gather_id_list = torch.empty([gather_len_list.sum()],dtype = torch.long,device = self.device)
# input_split = scatter_len_list.tolist()
# output_split = gather_len_list.tolist()
# if async_op == True:
# self.last_job = index,gather_id_list,mem,gather_memory,input_split,output_split,group,async_op
# if not submit:
# self.update_p2p()
# else:
# torch.distributed.all_to_all_single(
# gather_id_list,index,output_split_sizes=output_split,
# input_split_sizes=input_split,group = group,async_op=async_op)
# torch.distributed.all_to_all_single(
# gather_memory,mem,
# output_split_sizes=output_split,
# input_split_sizes=input_split,group = group)
# if gather_memory.shape[1] > self.node_memory.shape[1] + 1:
# gather_memory,gather_memory_ts,gather_mail,gather_mail_ts = self.unpack(gather_memory)
# self.set_mailbox_local(DistIndex(gather_id_list).loc,gather_mail,gather_mail_ts,Reduce_Op = reduce_Op)
# else:
# gather_memory,gather_memory_ts = self.unpack(gather_memory)
# self.set_memory_local(DistIndex(gather_id_list).loc,gather_memory,gather_memory_ts, Reduce_Op = reduce_Op)
def gather_memory(
self,
......
......@@ -97,6 +97,7 @@ class NeighborSampler(BaseSampler):
node_part = None,
edge_part = None,
probability = 1,
no_neg = False,
) -> None:
r"""__init__
Args:
......@@ -122,7 +123,7 @@ class NeighborSampler(BaseSampler):
self.is_distinct = is_distinct
assert graph_name is not None
self.graph_name = graph_name
self.no_neg = no_neg
if(tnb is None):
if(graph_data.edge_ts is not None):
timestamp,ind = graph_data.edge_ts.sort()
......@@ -314,7 +315,10 @@ class NeighborSampler(BaseSampler):
else:
seed, inverse_seed = seed.unique(return_inverse=True)
"""
        if self.no_neg:
            out,metadata = self.sample_from_nodes(seed[:seed.shape[0]//3*2], seed_ts[:seed.shape[0]//3*2], is_unique=False)
        else:
            out,metadata = self.sample_from_nodes(seed, seed_ts, is_unique=False)
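        # Assumed seed layout (consistent with the index ranges built below):
        # seed = [src_pos | dst_pos | dst_neg], each block of length num_pos, so
        # seed.shape[0]//3*2 keeps only the positive endpoints and skips neighbor
        # sampling for the negative samples when no_neg is enabled.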
src_pos_index = torch.arange(0,num_pos,dtype= torch.long,device=out_device)
dst_pos_index = torch.arange(num_pos,2*num_pos,dtype= torch.long,device=out_device)
if neg_sampling.is_triplet() or neg_sampling.is_tgbtriplet():
......