Commit 4927a9e0 by zhlj
parents d1ea5b9d 99e5b95a
......@@ -3,10 +3,7 @@ sampling:
neighbor:
- 10
strategy: 'recent'
prop_time: False
history: 1
duration: 0
num_thread: 32
no_neg: True
memory:
- type: 'node'
......
sampling:
- no_sample: True
- strategy: 'identity'
history: 1
memory:
- type: 'node'
dim_time: 100
......
......@@ -3,10 +3,7 @@ sampling:
neighbor:
- 10
strategy: 'recent'
prop_time: False
history: 1
duration: 0
num_thread: 32
memory:
- type: 'node'
dim_time: 100
......
bash test_all.sh 13357 > 13357.out
wait
bash test_all.sh 12347 > 12347.out
wait
bash test_all.sh 63377 > 63377.out
......
import matplotlib.pyplot as plt
import numpy as np
import torch
# 读取文件内容
ssim_values = [0, 0.1, 0.2, 0.3, 0.4, 2] # 假设这是你的 ssim 参数值
probability_values = [1,0.1,0.05,0.01,0]
data_values = ['WIKI','LASTFM','WikiTalk','DGraphFin'] # 存储从文件中读取的数据
seed = ['13357','12347','63377','53473',' 54763']
partition = 'ours'
# 从文件中读取数据,假设数据存储在文件 data.txt 中
#all/"$data"/"$partitions"-ours_shared-0.01-"$mem"-"$ssim"-"$sample".out
partitions=4
topk=0
mem='all_update'#'historical'
model='TGN'
for sd in seed :
for data in data_values:
ap_list = []
comm_list = []
for p in probability_values:
if data == 'WIKI' or data =='LASTFM':
model = 'TGN'
else:
model = 'TGN_large'
if p == 1:
file = 'all_{}/{}/{}/{}-{}-{}-{}-recent.out'.format(sd,data,model,partitions,partition,topk,mem)
else:
file = 'all_{}/{}/{}/{}-{}-{}-{}-boundery_recent_decay-{}.out'.format(sd,data,model,partitions,partition,topk,mem,p)
prefix = "val ap:"
max_val_ap = 0
test_ap = 0
with open(file, 'r') as file:
for line in file:
if line.find(prefix)!=-1:
pos = line.find(prefix)+len(prefix)
posr = line.find(' ',pos)
#print(line[pos:posr])
val_ap = float(line[pos:posr])
pos = line.find("test ap ")+len("test ap ")
posr = line.find(' ',pos)
#print(line[pos:posr])
_test_ap = float(line[pos:posr])
if(val_ap>max_val_ap):
max_val_ap = val_ap
test_ap = _test_ap
ap_list.append(test_ap)
print('data {} seed {} ap: {}'.format(data,sd,ap_list))
# prefix = 'best test AP:'
# cnt = 0
# sum = 0
# with open(file, 'r') as file:
# for line in file:
# if line.startswith(prefix):
# ap = float(line.lstrip(prefix).split(' ')[0])
# pos = line.find('remote node number tensor')
# if(pos!=-1):
# posr = line.find(']',pos+2+len('remote node number tensor'),)
# #print(line,line[pos+2+len('remote node number tensor'):posr])
# comm = int(line[pos+2+len('remote node number tensor'):posr])
# #print()
# sum = sum+comm
# cnt = cnt+1
# #print(comm)
# ap_list.append(ap)
# comm_list.append(sum/cnt*4)
# # 绘制柱状图
# print('{} TestAP={}\n'.format(data,ap_list))
# bar_width = 0.4
# #shared comm tensor
# # 设置柱状图的位置
# bars = range(len(probability_values))
# # 绘制柱状图
# plt.bar([b for b in bars], ap_list, width=bar_width)
# # 绘制柱状图
# plt.ylim([0.9,1])
# plt.xticks([b for b in bars], probability_values)
# plt.xlabel('probability')
# plt.ylabel('Test AP')
# plt.title('{}({} partitions)'.format(data,partitions))
# plt.savefig('boundary_AP_{}_{}_{}.png'.format(data,partitions,model))
# plt.clf()
# print(comm_list)
# plt.bar([b for b in bars], comm_list, width=bar_width)
# # 绘制柱状图
# plt.xticks([b for b in bars], probability_values)
# plt.xlabel('probability')
# plt.ylabel('Communication volume')
# plt.title('{}({} partitions)'.format(data,partitions))
# plt.savefig('boundary_comm_{}_{}_{}.png'.format(data,partitions,model))
# plt.clf()
# if partition == 'ours_shared':
# partition0 = 'ours'
# else:
# partition0=partition
# for p in probability_values:
# file = '{}/{}/test_{}_{}_{}_0_boundery_recent_uniform_{}_all_update_2.pt'.format(data,model,partition0,topk,partitions,float(p))
# val_ap = torch.tensor(torch.load(file))[:,0]
# epoch = torch.arange(val_ap.shape[0])
# #绘制曲线图
# plt.plot(epoch,val_ap, label='probability={}'.format(p))
# plt.xlabel('Epoch')
# plt.ylabel('Val AP')
# plt.title('{}({} partitions)'.format(data,partitions))
# # plt.grid(True)
# plt.legend()
# plt.savefig('{}_{}_{}_boundary_Convergence_rate.png'.format(data,partitions,model))
# plt.clf()
......@@ -4,15 +4,15 @@ import torch
# 读取文件内容
ssim_values = [0, 0.1, 0.2, 0.3, 0.4, 2] # 假设这是你的 ssim 参数值
probability_values = [1,0.1,0.05,0.01,0]
data_values = ['WIKI','LASTFM','WikiTalk','DGraphFin'] # 存储从文件中读取的数据
seed = ['13357','12347','63377','53473',' 54763']
data_values = ['WIKI','WikiTalk'] # 存储从文件中读取的数据
seed = ['13357','12347','63377','53473','54763']
partition = 'ours_shared'
# 从文件中读取数据,假设数据存储在文件 data.txt 中
#all/"$data"/"$partitions"-ours_shared-0.01-"$mem"-"$ssim"-"$sample".out
partitions=4
topk=0.01
mem='all_update'#'historical'
mem='local'#'historical'
model='TGN'
for sd in seed :
for data in data_values:
......
......@@ -6,19 +6,20 @@ addr="192.168.1.107"
partition_params=("ours" )
#"metis" "ldg" "random")
#("ours" "metis" "ldg" "random")
partitions="4"
partitions="8"
node_per="4"
nnodes="1"
node_rank="0"
probability_params=("0.1" "0" "0.05" "0.01")
sample_type_params=("boundery_recent_decay" "recent")
nnodes="2"
node_rank="1"
probability_params=("0.1" "0.05")
sample_type_params=("boundery_recent_decay")
#sample_type_params=("recent" "boundery_recent_decay") #"boundery_recent_uniform")
#memory_type=("all_update" "p2p" "all_reduce" "historical" "local")
memory_type=( "all_update")
memory_type=( "all_update" "historical" "local")
#memory_type=("local" "all_update" "historical" "all_reduce")
shared_memory_ssim=("0.3" "0.7")
#data_param=("WIKI" "REDDIT" "LASTFM" "WikiTalk")
data_param=("WIKI" "LASTFM" "WikiTalk" "DGraphFin")
data_param=("StackOverflow")
#data_param=("WIKI" "REDDIT" "LASTFM" "DGraphFin" "WikiTalk" "StackOverflow")
#data_param=("WIKI" "REDDIT" "LASTFM" "WikiTalk" "StackOverflow")
#data_param=("REDDIT" "WikiTalk")
......@@ -30,9 +31,9 @@ data_param=("WIKI" "LASTFM" "WikiTalk" "DGraphFin")
#seed=(( RANDOM % 1000000 + 1 ))
mkdir -p all_"$seed"
for data in "${data_param[@]}"; do
model="TGN_large"
model="JODIE"
if [ "$data" = "WIKI" ] || [ "$data" = "REDDIT" ] || [ "$data" = "LASTFM" ]; then
model="TGN"
model="JODIE"
fi
#model="APAN"
mkdir all_"$seed"/"$data"
......@@ -57,8 +58,8 @@ for data in "${data_param[@]}"; do
wait
fi
else
torchrun --nnodes "$nnodes" --node_rank "$node_rank" --nproc-per-node "$node_per" --master-addr "$addr" --master-port 9445 train_boundery.py --dataname "$data" --mode "$model" --partition "$partition" --topk 0 --sample_type "$sample" --memory_type "$mem" --seed "$seed" > all_"$seed"/"$data"/"$model"/"$partitions"-"$partition"-0-"$mem"-"$sample".out &
wait
# torchrun --nnodes "$nnodes" --node_rank "$node_rank" --nproc-per-node "$node_per" --master-addr "$addr" --master-port 9445 train_boundery.py --dataname "$data" --mode "$model" --partition "$partition" --topk 0 --sample_type "$sample" --memory_type "$mem" --seed "$seed" > all_"$seed"/"$data"/"$model"/"$partitions"-"$partition"-0-"$mem"-"$sample".out &
# wait
if [ "$partition" = "ours" ] && [ "$mem" != "all_local" ]; then
torchrun --nnodes "$nnodes" --node_rank "$node_rank" --nproc-per-node "$node_per" --master-addr "$addr" --master-port 9445 train_boundery.py --dataname "$data" --mode "$model" --partition "$partition" --topk 0.1 --sample_type "$sample" --memory_type "$mem" --seed "$seed" > all_"$seed"/"$data"/"$model"/"$partitions"-ours_shared-0.01-"$mem"-"$sample".out &
wait
......@@ -80,13 +81,13 @@ for data in "${data_param[@]}"; do
torchrun --nnodes "$nnodes" --node_rank "$node_rank" --nproc-per-node "$node_per" --master-addr "$addr" --master-port 9445 train_boundery.py --dataname "$data" --mode "$model" --partition "$partition" --topk 0.1 --sample_type "$sample" --probability "$pro" --memory_type "$mem" --seed "$seed" > all_"$seed"/"$data"/"$model"/"$partitions"-ours_shared-0.01-"$mem"-"$sample"-"$pro".out&
wait
fi
else
torchrun --nnodes "$nnodes" --node_rank "$node_rank" --nproc-per-node "$node_per" --master-addr "$addr" --master-port 9445 train_boundery.py --dataname "$data" --mode "$model" --partition "$partition" --topk 0 --sample_type "$sample" --probability "$pro" --memory_type "$mem" --seed "$seed" > all_"$seed"/"$data"/"$model"/"$partitions"-"$partition"-0-"$mem"-"$sample"-"$pro".out &
wait
if [ "$partition" = "ours" ] && [ "$mem" != "all_local" ]; then
torchrun --nnodes "$nnodes" --node_rank "$node_rank" --nproc-per-node "$node_per" --master-addr "$addr" --master-port 9445 train_boundery.py --dataname "$data" --mode "$model" --partition "$partition" --topk 0.1 --sample_type "$sample" --probability "$pro" --memory_type "$mem" --seed "$seed" > all_"$seed"/"$data"/"$model"/"$partitions"-ours_shared-0.01-"$mem"-"$sample"-"$pro".out &
wait
fi
#else
# torchrun --nnodes "$nnodes" --node_rank "$node_rank" --nproc-per-node "$node_per" --master-addr "$addr" --master-port 9445 train_boundery.py --dataname "$data" --mode "$model" --partition "$partition" --topk 0 --sample_type "$sample" --probability "$pro" --memory_type "$mem" --seed "$seed" > all_"$seed"/"$data"/"$model"/"$partitions"-"$partition"-0-"$mem"-"$sample"-"$pro".out &
# wait
# if [ "$partition" = "ours" ] && [ "$mem" != "all_local" ]; then
# torchrun --nnodes "$nnodes" --node_rank "$node_rank" --nproc-per-node "$node_per" --master-addr "$addr" --master-port 9445 train_boundery.py --dataname "$data" --mode "$model" --partition "$partition" --topk 0.1 --sample_type "$sample" --probability "$pro" --memory_type "$mem" --seed "$seed" > all_"$seed"/"$data"/"$model"/"$partitions"-ours_shared-0.01-"$mem"-"$sample"-"$pro".out &
# wait
# fi
fi
done
done
......
......@@ -126,7 +126,7 @@ def seed_everything(seed=42):
torch.cuda.manual_seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
seed_everything(args.seed)
total_next_batch = 0
total_forward = 0
total_count_score = 0
......@@ -267,9 +267,15 @@ def main():
if args.local_neg_sample:
print('dst len {} origin len {}'.format(graph.edge_index[1,mask].unique().shape[0],full_dst.unique().shape[0]))
train_neg_sampler = LocalNegativeSampling('triplet',amount = args.neg_samples,dst_node_list = graph.edge_index[1,mask].unique())
else:
#train_neg_sampler = LocalNegativeSampling('triplet',amount = args.neg_samples,dst_node_list = full_dst.unique())
train_neg_sampler = LocalNegativeSampling('triplet',amount = args.neg_samples,dst_node_list = full_dst.unique(),local_mask=(DistIndex(graph.nids_mapper[full_dst.unique()].to('cpu')).part == dist.get_rank()),prob=args.probability)
remote_ratio = train_neg_sampler.local_dst.shape[0] / train_neg_sampler.dst_node_list.shape[0]
#train_ratio_pos = (1 - args.probability) + args.probability * remote_ratio
#train_ratio_neg = args.probability * (1-remote_ratio)
train_ratio_pos = 1.0/(1-args.probability+ args.probability * remote_ratio) if ((args.probability <1) & (args.probability > 0)) else 1
train_ratio_neg = 1.0/(args.probability*remote_ratio) if ((args.probability <1) & (args.probability > 0)) else 1
print(train_neg_sampler.dst_node_list)
neg_sampler = LocalNegativeSampling('triplet',amount= neg_samples,dst_node_list = full_dst.unique(),seed=args.seed)
......@@ -338,10 +344,10 @@ def main():
print('dim_node {} dim_edge {}\n'.format(gnn_dim_node,gnn_dim_edge))
avg_time = 0
if use_cuda:
model = GeneralModel(gnn_dim_node, gnn_dim_edge, sample_param, memory_param, gnn_param, train_param,graph.ids.shape[0],mailbox).cuda()
model = GeneralModel(gnn_dim_node, gnn_dim_edge, sample_param, memory_param, gnn_param, train_param,graph.ids.shape[0],mailbox,train_ratio=(train_ratio_pos,train_ratio_neg)).cuda()
device = torch.device('cuda')
else:
model = GeneralModel(gnn_dim_node, gnn_dim_edge, sample_param, memory_param, gnn_param, train_param,graph.ids.shape[0],mailbox)
model = GeneralModel(gnn_dim_node, gnn_dim_edge, sample_param, memory_param, gnn_param, train_param,graph.ids.shape[0],mailbox,train_ratio=(train_ratio_pos,train_ratio_neg))
device = torch.device('cpu')
model = DDP(model,find_unused_parameters=True)
def count_parameters(model):
......@@ -531,9 +537,12 @@ def main():
model.train()
optimizer.zero_grad()
ones = torch.ones(metadata['dst_neg_index'].shape[0],device = model.device,dtype=torch.float)
weight = torch.where(DistIndex(mfgs[0][0].srcdata['ID'][metadata['dst_neg_index']]).part == torch.distributed.get_rank(),ones*train_ratio_pos,ones*train_ratio_neg).reshape(-1,1)
pred_pos, pred_neg = model(mfgs,metadata,neg_samples=args.neg_samples,async_param = param)
loss = creterion(pred_pos, torch.ones_like(pred_pos))
loss += creterion(pred_neg, torch.zeros_like(pred_neg))
neg_creterion = torch.nn.BCEWithLogitsLoss(weight)
loss += neg_creterion(pred_neg, torch.zeros_like(pred_neg))
total_loss += float(loss.item())
#mailbox.handle_last_async()
#trainloader.async_feature()
......@@ -663,9 +672,9 @@ def main():
pass
# print('weight {} {}\n'.format(tt.weight_count_local,tt.weight_count_remote))
# print('ssim {} {}\n'.format(tt.ssim_local/tt.ssim_cnt,tt.ssim_remote/tt.ssim_cnt))
torch.save(val_list,'all_args.seed/{}/{}/val_{}_{}_{}_{}_{}_{}_{}_{}.pt'.format(args.dataname,args.model,args.partition,args.topk,dist.get_world_size(),dist.get_rank(),args.sample_type,args.probability,args.memory_type,args.shared_memory_ssim))
torch.save(loss_list,'all_args.seed/{}/{}/loss_{}_{}_{}_{}_{}_{}_{}_{}.pt'.format(args.dataname,args.model,args.partition,args.topk,dist.get_world_size(),dist.get_rank(),args.sample_type,args.probability,args.memory_type,args.shared_memory_ssim))
torch.save(test_ap_list,'all_args.seed/{}/{}/test_{}_{}_{}_{}_{}_{}_{}_{}.pt'.format(args.dataname,args.model,args.partition,args.topk,dist.get_world_size(),dist.get_rank(),args.sample_type,args.probability,args.memory_type,args.shared_memory_ssim))
torch.save(val_list,'all_{}/{}/{}/val_{}_{}_{}_{}_{}_{}_{}_{}.pt'.format(args.seed,args.dataname,args.model,args.partition,args.topk,dist.get_world_size(),dist.get_rank(),args.sample_type,args.probability,args.memory_type,args.shared_memory_ssim))
torch.save(loss_list,'all_{}/{}/{}/loss_{}_{}_{}_{}_{}_{}_{}_{}.pt'.format(args.seed,args.dataname,args.model,args.partition,args.topk,dist.get_world_size(),dist.get_rank(),args.sample_type,args.probability,args.memory_type,args.shared_memory_ssim))
torch.save(test_ap_list,'all_{}/{}/{}/test_{}_{}_{}_{}_{}_{}_{}_{}.pt'.format(args.seed,args.dataname,args.model,args.partition,args.topk,dist.get_world_size(),dist.get_rank(),args.sample_type,args.probability,args.memory_type,args.shared_memory_ssim))
print(avg_time)
if not early_stop:
......
......@@ -2,7 +2,7 @@
#跑了4卡的TaoBao
# 定义数组变量
seed=$1
addr="192.168.1.107"
addr="192.168.1.106"
partition_params=("ours" )
#"metis" "ldg" "random")
#("ours" "metis" "ldg" "random")
......@@ -11,13 +11,13 @@ node_per="4"
nnodes="2"
<<<<<<< Updated upstream
node_rank="0"
probability_params=("0.1" "0.01" "0.05")
probability_params=("0.1")
sample_type_params=("boundery_recent_decay")
#sample_type_params=("recent" "boundery_recent_decay") #"boundery_recent_uniform")
#memory_type=("all_update" "p2p" "all_reduce" "historical" "local")
memory_type=( "all_update")
memory_type=( "historical")
#memory_type=("local" "all_update" "historical" "all_reduce")
shared_memory_ssim=("0.3" "0.7")
shared_memory_ssim=("0.3")
#data_param=("WIKI" "REDDIT" "LASTFM" "WikiTalk")
<<<<<<< HEAD
data_param=("GDELT")
......@@ -48,9 +48,9 @@ data_param=("WIKI" "LASTFM" "WikiTalk" "StackOverflow" "GDELT" "TaoBao")
#seed=(( RANDOM % 1000000 + 1 ))
mkdir -p all_"$seed"
for data in "${data_param[@]}"; do
model="TGN_large"
model="APAN"
if [ "$data" = "WIKI" ] || [ "$data" = "REDDIT" ] || [ "$data" = "LASTFM" ]; then
model="TGN"
model="APAN"
fi
#model="APAN"
mkdir all_"$seed"/"$data"
......
LOCAL RANK 2, RANK2
LOCAL RANK 1, RANK1
LOCAL RANK 3, RANK3
LOCAL RANK 0, RANK0
LOCAL RANK 2, RANK2
LOCAL RANK 3, RANK3
in
in
in
in
local rank is 1 world_size is 4 memory group is 0 memory rank is 1 memory group size is 4
local rank is 0 world_size is 4 memory group is 0 memory rank is 0 memory group size is 4
local rank is 3 world_size is 4 memory group is 0 memory rank is 3 memory group size is 4
[0, 1, 2, 3]
[0, 1, 2, 3][0, 1, 2, 3]
local rank is 2 world_size is 4 memory group is 0 memory rank is 2 memory group size is 4
[0, 1, 2, 3]
use cuda on 0
use cuda on 3
use cuda on 2
use cuda on 1
LOCAL RANK 0, RANK0
use cuda on 0
994790
get_neighbors consume: 6.18508s
Epoch 0:
train loss:12236.7578 train ap:0.986976 val ap:0.934674 val auc:0.946284
total time:630.37s prep time:545.79s
fetch time:0.00s write back time:0.00s
Epoch 1:
train loss:11833.1818 train ap:0.987815 val ap:0.960581 val auc:0.965728
total time:628.44s prep time:542.56s
fetch time:0.00s write back time:0.00s
Epoch 2:
train loss:11622.9559 train ap:0.988244 val ap:0.956752 val auc:0.963083
total time:622.89s prep time:538.77s
fetch time:0.00s write back time:0.00s
Epoch 3:
train loss:11679.1400 train ap:0.988072 val ap:0.929351 val auc:0.943797
total time:681.88s prep time:569.50s
fetch time:0.00s write back time:0.00s
Epoch 4:
train loss:11676.1710 train ap:0.988098 val ap:0.936353 val auc:0.948531
total time:849.98s prep time:741.47s
fetch time:0.00s write back time:0.00s
Epoch 5:
train loss:11745.6001 train ap:0.987897 val ap:0.950828 val auc:0.958958
total time:862.77s prep time:750.90s
fetch time:0.00s write back time:0.00s
Epoch 6:
Early stopping at epoch 6
Loading the best model at epoch 1
0.9248434901237488 0.929413378238678
0.8653780221939087 0.861071765422821
test AP:0.847958 test AUC:0.837159
test_dataset 6647176 avg_time 87.00003329753876
......@@ -228,16 +228,19 @@ def main():
num_layers = sample_param['layer'] if 'layer' in sample_param else 1
fanout = sample_param['neighbor'] if 'neighbor' in sample_param else [10]
policy = sample_param['strategy'] if 'strategy' in sample_param else 'recent'
no_neg = sample_param['no_neg'] if 'no_neg' in sample_param else False
if policy != 'recent':
policy_train = args.sample_type#'boundery_recent_decay'
else:
policy_train = policy
if memory_param['type'] != 'none':
mailbox = SharedMailBox(graph.ids.shape[0], memory_param, dim_edge_feat = graph.efeat.shape[1] if graph.efeat is not None else 0,
shared_nodes_index=graph.shared_nids_list[ctx.memory_group_rank],device = torch.device('cuda:{}'.format(local_rank)),cache_route = cache_route,shared_ssim=args.shared_memory_ssim)
shared_nodes_index=graph.shared_nids_list[ctx.memory_group_rank],device = torch.device('cuda:{}'.format(local_rank)),shared_ssim=args.shared_memory_ssim)
else:
mailbox = None
sampler = NeighborSampler(num_nodes=graph.num_nodes, num_layers=num_layers, fanout=fanout,graph_data=sample_graph, workers=10,policy = policy_train, graph_name = "train",local_part=dist.get_rank(),edge_part=DistIndex(graph.eids_mapper).part,node_part=DistIndex(graph.nids_mapper).part,probability=args.probability)
eval_sampler = NeighborSampler(num_nodes=graph.num_nodes, num_layers=num_layers, fanout=fanout,graph_data=eval_sample_graph, workers=10,policy = policy_train, graph_name = "eval",local_part=dist.get_rank(),edge_part=DistIndex(graph.eids_mapper).part,node_part=DistIndex(graph.nids_mapper).part,probability=args.probability)
sampler = NeighborSampler(num_nodes=graph.num_nodes, num_layers=num_layers, fanout=fanout,graph_data=sample_graph, workers=10,policy = policy_train, graph_name = "train",local_part=dist.get_rank(),edge_part=DistIndex(graph.eids_mapper).part,node_part=DistIndex(graph.nids_mapper).part,probability=args.probability,no_neg=no_neg)
eval_sampler = NeighborSampler(num_nodes=graph.num_nodes, num_layers=num_layers, fanout=fanout,graph_data=eval_sample_graph, workers=10,policy = policy_train, graph_name = "eval",local_part=dist.get_rank(),edge_part=DistIndex(graph.eids_mapper).part,node_part=DistIndex(graph.nids_mapper).part,probability=args.probability,no_neg=no_neg)
train_data = torch.masked_select(graph.edge_index,train_mask.to(graph.edge_index.device)).reshape(2,-1)
train_ts = torch.masked_select(graph.ts,train_mask.to(graph.edge_index.device))
......@@ -272,8 +275,10 @@ def main():
#train_neg_sampler = LocalNegativeSampling('triplet',amount = args.neg_samples,dst_node_list = full_dst.unique())
train_neg_sampler = LocalNegativeSampling('triplet',amount = args.neg_samples,dst_node_list = full_dst.unique(),local_mask=(DistIndex(graph.nids_mapper[full_dst.unique()].to('cpu')).part == dist.get_rank()),prob=args.probability)
remote_ratio = train_neg_sampler.local_dst.shape[0] / train_neg_sampler.dst_node_list.shape[0]
train_ratio_pos = (1 - args.probability) + args.probability * remote_ratio
train_ratio_neg = args.probability * (1-remote_ratio)
#train_ratio_pos = (1 - args.probability) + args.probability * remote_ratio
#train_ratio_neg = args.probability * (1-remote_ratio)
train_ratio_pos = 1.0/(1-args.probability+ args.probability * remote_ratio) if ((args.probability <1) & (args.probability > 0)) else 1
train_ratio_neg = 1.0/(args.probability*remote_ratio) if ((args.probability <1) & (args.probability > 0)) else 1
print(train_neg_sampler.dst_node_list)
neg_sampler = LocalNegativeSampling('triplet',amount= neg_samples,dst_node_list = full_dst.unique(),seed=args.seed)
......@@ -402,7 +407,8 @@ def main():
aps.append(average_precision_score(y_true, y_pred.detach().numpy()))
aucs_mrrs.append(roc_auc_score(y_true, y_pred))
mailbox.update_shared()
mailbox.update_p2p()
mailbox.update_p2p_mem()
mailbox.update_p2p_mail()
"""
if mailbox is not None:
src = metadata['src_pos_index']
......@@ -536,7 +542,7 @@ def main():
optimizer.zero_grad()
ones = torch.ones(metadata['dst_neg_index'].shape[0],device = model.device,dtype=torch.float)
weight = torch.where(DistIndex(mfgs[0][0].srcdata['ID'][metadata['dst_neg_index']]).part == torch.distributed.get_rank(),ones/train_ratio_pos,ones/train_ratio_neg).reshape(-1,1)
weight = torch.where(DistIndex(mfgs[0][0].srcdata['ID'][metadata['dst_neg_index']]).part == torch.distributed.get_rank(),ones*train_ratio_pos,ones*train_ratio_neg).reshape(-1,1)
pred_pos, pred_neg = model(mfgs,metadata,neg_samples=args.neg_samples,async_param = param)
loss = creterion(pred_pos, torch.ones_like(pred_pos))
neg_creterion = torch.nn.BCEWithLogitsLoss(weight)
......@@ -554,7 +560,8 @@ def main():
#train_aps.append(average_precision_score(y_true, y_pred.detach().numpy()))
#torch.cuda.synchronize()
mailbox.update_shared()
mailbox.update_p2p()
mailbox.update_p2p_mem()
mailbox.update_p2p_mail()
#torch.cuda.empty_cache()
"""
......
......@@ -18,7 +18,7 @@ class CachePushRoute():
def __init__(self,local_num, local_drop_edge, dist_index):
ctx = DistributedContext.get_default_context()
dist_drop = dist_index[local_drop_edge]
id_mask = ~(DistIndex(dist_drop) == ctx.memory_group_rank)
id_mask = not (DistIndex(dist_drop) == ctx.memory_group_rank)
remote_id = local_drop_edge[id_mask]
route = torch.zeros(local_num,dtype=torch.long)
src_mask = DistIndex(dist_drop[0,:]).part == ctx.memory_group_rank
......
......@@ -352,31 +352,60 @@ class TransformerMemoryUpdater(torch.nn.Module):
rst = self.dropout(rst)
rst = torch.nn.functional.relu(rst)
return rst
"""
(self,index,memory,memory_ts,
mail_index,mail,mail_ts,
reduce_Op=None,
async_op=True,
set_p2p=False,
mode=None,
filter=None,
wait_submit=True,
spread_mail=True,
update_cross_mm = False)
"""
class AsyncMemeoryUpdater(torch.nn.Module):
def all_update_func(self,index,memory,memory_ts,mail,mail_ts,nxt_fetch_func):
self.mailbox.set_memory_all_reduce(index,memory,memory_ts,mail,mail_ts,reduce_Op = 'max', async_op = False,filter=None,mode='all_reduce',set_remote=True)
if nxt_fetch_func is not None:
nxt_fetch_func()
def p2p_func(self,index,memory,memory_ts,mail,mail_ts,nxt_fetch_func):
self.mailbox.handle_last_async()
submit_to_queue = False
if nxt_fetch_func is not None:
nxt_fetch_func()
submit_to_queue = True
self.mailbox.set_memory_all_reduce(index,memory,memory_ts,mail,mail_ts,reduce_Op = 'max', async_op = True,filter=None,mode=None,set_remote=True,submit = submit_to_queue)
def all_reduce_func(self,index,memory,memory_ts,mail,mail_ts,nxt_fetch_func):
self.mailbox.set_memory_all_reduce(index,memory,memory_ts,mail,mail_ts,reduce_Op = 'max', async_op = False,filter=None,mode='all_reduce',set_remote=False)
def all_update_func(self,index,memory,memory_ts,mail_index,mail,mail_ts,nxt_fetch_func,spread_mail=False):
self.mailbox.set_memory_all_reduce(
index,memory,memory_ts,
mail_index,mail,mail_ts,
reduce_Op='max',async_op=False,
mode='all_reduce',
wait_submit=False,spread_mail=spread_mail,
update_cross_mm=True,
)
#self.mailbox.set_memory_all_reduce(index,memory,memory_ts,mail,mail_ts,reduce_Op = 'max', async_op = False,filter=None,mode='all_reduce',set_remote=True)
if nxt_fetch_func is not None:
nxt_fetch_func()
def historical_func(self,index,memory,memory_ts,mail,mail_ts,nxt_fetch_func):
# def p2p_func(self,index,memory,memory_ts,mail,mail_ts,nxt_fetch_func,mail_index = None):
# self.mailbox.handle_last_async()
# submit_to_queue = False
# if nxt_fetch_func is not None:
# nxt_fetch_func()
# submit_to_queue = True
# self.mailbox.set_memory_all_reduce(index,memory,memory_ts,mail,mail_ts,reduce_Op = 'max', async_op = True,filter=None,mode=None,set_remote=True,submit = submit_to_queue)
# def all_reduce_func(self,index,memory,memory_ts,mail,mail_ts,nxt_fetch_func,mail_index = None):
# self.mailbox.set_memory_all_reduce(index,memory,memory_ts,mail,mail_ts,reduce_Op = 'max', async_op = False,filter=None,mode='all_reduce',set_remote=False)
# if nxt_fetch_func is not None:
# nxt_fetch_func()
def historical_func(self,index,memory,memory_ts,mail_index,mail,mail_ts,nxt_fetch_func,spread_mail=False):
self.mailbox.sychronize_shared()
self.mailbox.handle_last_async()
submit_to_queue = False
if nxt_fetch_func is not None:
nxt_fetch_func()
submit_to_queue = True
self.mailbox.set_memory_all_reduce(index,memory,memory_ts,mail,mail_ts,reduce_Op = 'max', async_op = False,filter=None,mode='historical',set_remote=False,submit = submit_to_queue)
def local_func(self,index,memory,memory_ts,mail,mail_ts,nxt_fetch_func):
self.mailbox.set_memory_all_reduce(
index,memory,memory_ts,
mail_index,mail,mail_ts,
reduce_Op='max',async_op=True,
mode='historical',
wait_submit=submit_to_queue,spread_mail=spread_mail,
update_cross_mm=False,
)
def local_func(self,index,memory,memory_ts,mail_index,mail,mail_ts,nxt_fetch_func,spread_mail=False):
if nxt_fetch_func is not None:
nxt_fetch_func()
def transformer_updater(self,b):
......@@ -473,7 +502,7 @@ class AsyncMemeoryUpdater(torch.nn.Module):
None)
#print(index.shape[0])
if param[0]:
index, mail, mail_ts = self.mailbox.get_update_mail(
index0, mail, mail_ts = self.mailbox.get_update_mail(
b.srcdata['ID'],src,dst,ts,edge_feats,
self.last_updated_memory,
None,False,False,block=b
......@@ -483,9 +512,12 @@ class AsyncMemeoryUpdater(torch.nn.Module):
self.mailbox.mon.add(index,self.mailbox.node_memory.accessor.data[index],memory)
##print(index.shape,memory.shape,memory_ts.shape,mail.shape,mail_ts.shape)
local_mask = (DistIndex(index).part==torch.distributed.get_rank())
self.mailbox.set_mailbox_local(DistIndex(index[local_mask]).loc,mail[local_mask],mail_ts[local_mask],Reduce_Op = 'max')
self.mailbox.set_memory_local(DistIndex(index[local_mask]).loc,memory[local_mask],memory_ts[local_mask], Reduce_Op = 'max')
self.update_hunk(index,memory,memory_ts,mail,mail_ts,nxt_fetch_func)
local_mask_mail = (DistIndex(index0).part==torch.distributed.get_rank())
#self.mailbox.set_mailbox_local(DistIndex(index0[local_mask_mail]).loc,mail[local_mask_mail],mail_ts[local_mask_mail],Reduce_Op = 'max')
#self.mailbox.set_memory_local(DistIndex(index[local_mask]).loc,memory[local_mask],memory_ts[local_mask], Reduce_Op = 'max')
is_deliver=(self.mailbox.deliver_to == 'neighbors')
self.update_hunk(index,memory,memory_ts,index0,mail,mail_ts,nxt_fetch_func,spread_mail= is_deliver)
if self.memory_param['combine_node_feature'] and self.dim_node_feat > 0:
if self.dim_node_feat == self.dim_hid:
......
......@@ -97,6 +97,7 @@ class NeighborSampler(BaseSampler):
node_part = None,
edge_part = None,
probability = 1,
no_neg = False,
) -> None:
r"""__init__
Args:
......@@ -122,7 +123,7 @@ class NeighborSampler(BaseSampler):
self.is_distinct = is_distinct
assert graph_name is not None
self.graph_name = graph_name
self.no_neg = no_neg
if(tnb is None):
if(graph_data.edge_ts is not None):
timestamp,ind = graph_data.edge_ts.sort()
......@@ -314,7 +315,10 @@ class NeighborSampler(BaseSampler):
else:
seed, inverse_seed = seed.unique(return_inverse=True)
"""
out,metadata = self.sample_from_nodes(seed, seed_ts, is_unique=False)
if self.no_neg:
out,metadata = self.sample_from_nodes(seed[:seed.shape[0]//3*2], seed_ts[:seed.shape[0]//3*2], is_unique=False)
else:
out,metadata = self.sample_from_nodes(seed[:seed.shape[0]], seed_ts[:seed.shape[0]//3*2], is_unique=False)
src_pos_index = torch.arange(0,num_pos,dtype= torch.long,device=out_device)
dst_pos_index = torch.arange(num_pos,2*num_pos,dtype= torch.long,device=out_device)
if neg_sampling.is_triplet() or neg_sampling.is_tgbtriplet():
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment