Commit a1d8044f by zhlj

fix alpha policy

parent eacb2444
......@@ -19,7 +19,7 @@ memory_type=("historical")
#memory_type=("local" "all_update" "historical" "all_reduce")
shared_memory_ssim=("0.3")
#data_param=("WIKI" "REDDIT" "LASTFM" "WikiTalk")
data_param=("LASTFM" "WikiTalk" "StackOverflow" "GDELT")
data_param=("WikiTalk" "StackOverflow" "GDELT")
#"GDELT")
#data_param=("WIKI" "REDDIT" "LASTFM" "DGraphFin" "WikiTalk" "StackOverflow")
#data_param=("WIKI" "REDDIT" "LASTFM" "WikiTalk" "StackOverflow")
......@@ -31,71 +31,71 @@ data_param=("LASTFM" "WikiTalk" "StackOverflow" "GDELT")
#seed=(( RANDOM % 1000000 + 1 ))
mkdir -p all_"$seed"
for data in "${data_param[@]}"; do
model="TGN_large"
if [ "$data" = "WIKI" ] || [ "$data" = "REDDIT" ] || [ "$data" = "LASTFM" ]; then
model="TGN"
fi
#model="APAN"
mkdir all_"$seed"/"$data"
mkdir all_"$seed"/"$data"/"$model"
mkdir all_"$seed"/"$data"/"$model"/comm
#torchrun --nnodes "$nnodes" --node_rank 0 --nproc-per-node 1 --master-addr "$addr" --master-port 9445 train_boundery.py --dataname "$data" --mode "$model" --partition ours --memory_type local --sample_type recent --topk 0 --seed "$seed" > all_"$seed"/"$data"/"$model"/1.out &
wait
for partition in "${partition_params[@]}"; do
for sample in "${sample_type_params[@]}"; do
if [ "$sample" = "recent" ]; then
for mem in "${memory_type[@]}"; do
if [ "$mem" = "historical" ]; then
for ssim in "${shared_memory_ssim[@]}"; do
if [ "$partition" = "ours" ]; then
torchrun --nnodes "$nnodes" --node_rank "$node_rank" --nproc-per-node "$node_per" --master-addr "$addr" --master-port 9445 train_boundery.py --dataname "$data" --mode "$model" --partition "$partition" --topk 0.1 --sample_type "$sample" --memory_type "$mem" --shared_memory_ssim "$ssim" --seed "$seed" > all_"$seed"/"$data"/"$model"/"$partitions"-ours_shared-0.01-"$mem"-"$ssim"-"$sample".out &
wait
fi
done
elif [ "$mem" = "all_reduce" ]; then
if [ "$partition" = "ours" ]; then
torchrun --nnodes "$nnodes" --node_rank "$node_rank" --nproc-per-node "$node_per" --master-addr "$addr" --master-port 9445 train_boundery.py --dataname "$data" --mode "$model" --partition "$partition" --topk 0.1 --sample_type "$sample" --memory_type "$mem" --seed "$seed" > all_"$seed"/"$data"/"$model"/"$partitions"-ours_shared-0.01-"$mem"-"$sample".out &
wait
fi
else
torchrun --nnodes "$nnodes" --node_rank "$node_rank" --nproc-per-node "$node_per" --master-addr "$addr" --master-port 9445 train_boundery.py --dataname "$data" --mode "$model" --partition "$partition" --topk 0 --sample_type "$sample" --memory_type "$mem" --seed "$seed" > all_"$seed"/"$data"/"$model"/"$partitions"-"$partition"-0-"$mem"-"$sample".out &
wait
#if [ "$partition" = "ours" ] && [ "$mem" != "all_local" ]; then
# torchrun --nnodes "$nnodes" --node_rank "$node_rank" --nproc-per-node "$node_per" --master-addr "$addr" --master-port 9445 train_boundery.py --dataname "$data" --mode "$model" --partition "$partition" --topk 0.1 --sample_type "$sample" --memory_type "$mem" --seed "$seed" > all_"$seed"/"$data"/"$model"/"$partitions"-ours_shared-0.01-"$mem"-"$sample".out &
# wait
#fi
fi
done
else
for pro in "${probability_params[@]}"; do
for mem in "${memory_type[@]}"; do
if [ "$mem" = "historical" ]; then
for ssim in "${shared_memory_ssim[@]}"; do
if [ "$partition" = "ours" ]; then
torchrun --nnodes "$nnodes" --node_rank "$node_rank" --nproc-per-node "$node_per" --master-addr "$addr" --master-port 9445 train_boundery.py --dataname "$data" --mode "$model" --partition "$partition" --topk 0.1 --sample_type "$sample" --probability "$pro" --memory_type "$mem" --shared_memory_ssim "$ssim" --seed "$seed" > all_"$seed"/"$data"/"$model"/"$partitions"-ours_shared-0.01-"$mem"-"$ssim"-"$sample"-"$pro".out &
wait
fi
done
elif [ "$mem" = "all_reduce" ]; then
if [ "$partition" = "ours"]; then
torchrun --nnodes "$nnodes" --node_rank "$node_rank" --nproc-per-node "$node_per" --master-addr "$addr" --master-port 9445 train_boundery.py --dataname "$data" --mode "$model" --partition "$partition" --topk 0.1 --sample_type "$sample" --probability "$pro" --memory_type "$mem" --seed "$seed" > all_"$seed"/"$data"/"$model"/"$partitions"-ours_shared-0.01-"$mem"-"$sample"-"$pro".out&
wait
fi
else
torchrun --nnodes "$nnodes" --node_rank "$node_rank" --nproc-per-node "$node_per" --master-addr "$addr" --master-port 9445 train_boundery.py --dataname "$data" --mode "$model" --partition "$partition" --topk 0 --sample_type "$sample" --probability "$pro" --memory_type "$mem" --seed "$seed" > all_"$seed"/"$data"/"$model"/"$partitions"-"$partition"-0-"$mem"-"$sample"-"$pro".out &
wait
if [ "$partition" = "ours" ] && [ "$mem" != "all_local" ]; then
torchrun --nnodes "$nnodes" --node_rank "$node_rank" --nproc-per-node "$node_per" --master-addr "$addr" --master-port 9445 train_boundery.py --dataname "$data" --mode "$model" --partition "$partition" --topk 0.1 --sample_type "$sample" --probability "$pro" --memory_type "$mem" --seed "$seed" > all_"$seed"/"$data"/"$model"/"$partitions"-ours_shared-0.01-"$mem"-"$sample"-"$pro".out &
wait
fi
fi
done
done
fi
done
done
done
# for data in "${data_param[@]}"; do
# model="TGN_large"
# if [ "$data" = "WIKI" ] || [ "$data" = "REDDIT" ] || [ "$data" = "LASTFM" ]; then
# model="TGN"
# fi
# #model="APAN"
# mkdir all_"$seed"/"$data"
# mkdir all_"$seed"/"$data"/"$model"
# mkdir all_"$seed"/"$data"/"$model"/comm
# #torchrun --nnodes "$nnodes" --node_rank 0 --nproc-per-node 1 --master-addr "$addr" --master-port 9445 train_boundery.py --dataname "$data" --mode "$model" --partition ours --memory_type local --sample_type recent --topk 0 --seed "$seed" > all_"$seed"/"$data"/"$model"/1.out &
# wait
# for partition in "${partition_params[@]}"; do
# for sample in "${sample_type_params[@]}"; do
# if [ "$sample" = "recent" ]; then
# for mem in "${memory_type[@]}"; do
# if [ "$mem" = "historical" ]; then
# for ssim in "${shared_memory_ssim[@]}"; do
# if [ "$partition" = "ours" ]; then
# torchrun --nnodes "$nnodes" --node_rank "$node_rank" --nproc-per-node "$node_per" --master-addr "$addr" --master-port 9445 train_boundery.py --dataname "$data" --mode "$model" --partition "$partition" --topk 0.1 --sample_type "$sample" --memory_type "$mem" --shared_memory_ssim "$ssim" --seed "$seed" > all_"$seed"/"$data"/"$model"/"$partitions"-ours_shared-0.01-"$mem"-"$ssim"-"$sample".out &
# wait
# fi
# done
# elif [ "$mem" = "all_reduce" ]; then
# if [ "$partition" = "ours" ]; then
# torchrun --nnodes "$nnodes" --node_rank "$node_rank" --nproc-per-node "$node_per" --master-addr "$addr" --master-port 9445 train_boundery.py --dataname "$data" --mode "$model" --partition "$partition" --topk 0.1 --sample_type "$sample" --memory_type "$mem" --seed "$seed" > all_"$seed"/"$data"/"$model"/"$partitions"-ours_shared-0.01-"$mem"-"$sample".out &
# wait
# fi
# else
# torchrun --nnodes "$nnodes" --node_rank "$node_rank" --nproc-per-node "$node_per" --master-addr "$addr" --master-port 9445 train_boundery.py --dataname "$data" --mode "$model" --partition "$partition" --topk 0 --sample_type "$sample" --memory_type "$mem" --seed "$seed" > all_"$seed"/"$data"/"$model"/"$partitions"-"$partition"-0-"$mem"-"$sample".out &
# wait
# #if [ "$partition" = "ours" ] && [ "$mem" != "all_local" ]; then
# # torchrun --nnodes "$nnodes" --node_rank "$node_rank" --nproc-per-node "$node_per" --master-addr "$addr" --master-port 9445 train_boundery.py --dataname "$data" --mode "$model" --partition "$partition" --topk 0.1 --sample_type "$sample" --memory_type "$mem" --seed "$seed" > all_"$seed"/"$data"/"$model"/"$partitions"-ours_shared-0.01-"$mem"-"$sample".out &
# wait
# #fi
# fi
# done
# else
# for pro in "${probability_params[@]}"; do
# for mem in "${memory_type[@]}"; do
# if [ "$mem" = "historical" ]; then
# for ssim in "${shared_memory_ssim[@]}"; do
# if [ "$partition" = "ours" ]; then
# torchrun --nnodes "$nnodes" --node_rank "$node_rank" --nproc-per-node "$node_per" --master-addr "$addr" --master-port 9445 train_boundery.py --dataname "$data" --mode "$model" --partition "$partition" --topk 0.1 --sample_type "$sample" --probability "$pro" --memory_type "$mem" --shared_memory_ssim "$ssim" --seed "$seed" > all_"$seed"/"$data"/"$model"/"$partitions"-ours_shared-0.01-"$mem"-"$ssim"-"$sample"-"$pro".out &
# wait
# fi
# done
# elif [ "$mem" = "all_reduce" ]; then
# if [ "$partition" = "ours"]; then
# torchrun --nnodes "$nnodes" --node_rank "$node_rank" --nproc-per-node "$node_per" --master-addr "$addr" --master-port 9445 train_boundery.py --dataname "$data" --mode "$model" --partition "$partition" --topk 0.1 --sample_type "$sample" --probability "$pro" --memory_type "$mem" --seed "$seed" > all_"$seed"/"$data"/"$model"/"$partitions"-ours_shared-0.01-"$mem"-"$sample"-"$pro".out&
# wait
# fi
# else
# torchrun --nnodes "$nnodes" --node_rank "$node_rank" --nproc-per-node "$node_per" --master-addr "$addr" --master-port 9445 train_boundery.py --dataname "$data" --mode "$model" --partition "$partition" --topk 0 --sample_type "$sample" --probability "$pro" --memory_type "$mem" --seed "$seed" > all_"$seed"/"$data"/"$model"/"$partitions"-"$partition"-0-"$mem"-"$sample"-"$pro".out &
# wait
# if [ "$partition" = "ours" ] && [ "$mem" != "all_local" ]; then
# torchrun --nnodes "$nnodes" --node_rank "$node_rank" --nproc-per-node "$node_per" --master-addr "$addr" --master-port 9445 train_boundery.py --dataname "$data" --mode "$model" --partition "$partition" --topk 0.1 --sample_type "$sample" --probability "$pro" --memory_type "$mem" --seed "$seed" > all_"$seed"/"$data"/"$model"/"$partitions"-ours_shared-0.01-"$mem"-"$sample"-"$pro".out &
# wait
# fi
# fi
# done
# done
# fi
# done
# done
# done
for data in "${data_param[@]}"; do
model="JODIE_large"
......
......@@ -4,6 +4,7 @@ import profile
import sys
import psutil
from os.path import abspath, join, dirname
from torch.cuda.amp import autocast
current_path = os.path.dirname(os.path.abspath(__file__))
parent_path = os.path.abspath(os.path.join(current_path, os.pardir))
sys.path.append(parent_path)
......@@ -189,6 +190,7 @@ def query():
"total_update_memory":total_update_memory,
"total_remote_update":total_remote_update,}
def main():
torch.backends.cudnn.benchmark = True
#torch.autograd.set_detect_anomaly(True)
print('LOCAL RANK {}, RANK{}'.format(os.environ["LOCAL_RANK"],os.environ["RANK"]))
use_cuda = True
......@@ -282,11 +284,7 @@ def main():
else:
#train_neg_sampler = LocalNegativeSampling('triplet',amount = args.neg_samples,dst_node_list = full_dst.unique())
train_neg_sampler = LocalNegativeSampling('triplet',amount = args.neg_samples,dst_node_list = full_dst.unique(),local_mask=(DistIndex(graph.nids_mapper[full_dst.unique()].to('cpu')).part == dist.get_rank()),ada_param=ada_param)
#remote_ratio = train_neg_sampler.local_dst.shape[0] / train_neg_sampler.dst_node_list.shape[0]
#train_ratio_pos = (1 - args.probability) + args.probability * remote_ratio
#train_ratio_neg = args.probability * (1-remote_ratio)
#train_ratio_pos = 1.0/(1-args.probability+ args.probability * remote_ratio) if ((args.probability <1) & (args.probability > 0)) else 1
#train_ratio_neg = 1.0/(args.probability*remote_ratio) if ((args.probability <1) & (args.probability > 0)) else 1
print(train_neg_sampler.dst_node_list)
neg_sampler = LocalNegativeSampling('triplet',amount= neg_samples,dst_node_list = full_dst.unique(),seed=args.seed)
......@@ -392,19 +390,9 @@ def main():
for roots,mfgs,metadata in loader:
#print(batch_cnt)
batch_cnt = batch_cnt+1
"""
if ctx.memory_group == 0:
pred_pos, pred_neg = model(mfgs,metadata,neg_samples=neg_samples)
#print('check {}\n'.format(model.module.memory_updater.last_updated_nid))
y_pred = torch.cat([pred_pos, pred_neg], dim=0).sigmoid().cpu()
y_true = torch.cat([torch.ones(pred_pos.size(0)), torch.zeros(pred_neg.size(0))], dim=0)
aps.append(average_precision_score(y_true, y_pred.detach().numpy()))
aucs_mrrs.append(roc_auc_score(y_true, y_pred))
"""
if mailbox is not None:
if(graph.efeat.device.type != 'cpu'):
edge_feats = graph.get_local_efeat(graph.eids_mapper[roots.eids.to('cpu')]).to('cuda')
#edge_feats = graph.get_dist_efeat(graph.eids_mapper[roots.eids.to('cpu')].to('cuda'),is_sorted = False) #graph.efeat[roots.eids.to('cpu')].to('cuda')
else:
edge_feats = graph.get_local_efeat(graph.eids_mapper[roots.eids.to('cpu')])
src = metadata['src_pos_index']
......@@ -422,44 +410,6 @@ def main():
mailbox.update_shared()
mailbox.update_p2p_mem()
mailbox.update_p2p_mail()
"""
if mailbox is not None:
src = metadata['src_pos_index']
dst = metadata['dst_pos_index']
ts = roots.ts
if graph.efeat is None:
edge_feats = None
elif(graph.efeat.device.type != 'cpu'):
edge_feats = graph.get_local_efeat(graph.eids_mapper[roots.eids.to('cpu')]).to('cuda')
#edge_feats = graph.get_dist_efeat(graph.eids_mapper[roots.eids.to('cpu')].to('cuda'),is_sorted = False)#graph.efeat[roots.eids.to('cpu')].to('cuda')
else:
edge_feats = graph.get_local_efeat(graph.eids_mapper[roots.eids.to('cpu')])
#edge_feats = graph.get_dist_efeat(graph.eids_mapper[roots.eids.to('cpu')],is_sorted=False)#graph.efeat[roots.eids]
#print(mfgs[0][0].srcdata['ID'])
dist_index_mapper = mfgs[0][0].srcdata['ID']
root_index = torch.cat((src,dst))
#print('{} {} {}'.format((~(dist_index_mapper==model.module.memory_updater.last_updated_nid)).nonzero(),model.module.memory_updater.last_updated_nid,dist_index_mapper))
last_updated_nid = model.module.memory_updater.last_updated_nid[root_index]
last_updated_memory = model.module.memory_updater.last_updated_memory[root_index]
last_updated_ts=model.module.memory_updater.last_updated_ts[root_index]
#print('root shape {} unique {} {}\n'.format(root_index.shape,dist_index_mapper[root_index].unique().shape,last_updated_nid.unique().shape))
index, memory, memory_ts = mailbox.get_update_memory(last_updated_nid,
last_updated_memory,
last_updated_ts,
model.module.embedding)
#print('index {} {}\n'.format(index.shape,dist_index_mapper[torch.cat((src,dst))].unique().shape))
index, mail, mail_ts = mailbox.get_update_mail(dist_index_mapper,
src,dst,ts,edge_feats,
model.module.memory_updater.last_updated_memory,
model.module.embedding,use_src_emb,use_dst_emb,
)
if memory_param['historical_fix'] == True:
mailbox.set_memory_all_reduce(index,memory,memory_ts,mail,mail_ts,reduce_Op = 'max', async_op = False,filter=model.module.memory_updater.filter,set_remote=True,mode='historical')
else:
mailbox.set_memory_all_reduce(index,memory,memory_ts,mail,mail_ts,reduce_Op = 'max', async_op = False,filter=None,set_remote=True,mode='all_reduce',submit=False)
mailbox.sychronize_shared()
"""
ap = torch.empty([1])
auc_mrr = torch.empty([1])
#if(ctx.memory_group==0):
......@@ -485,6 +435,7 @@ def main():
return cos(normalize(x1),normalize(x2)).sum()/x1.size(dim=0)
creterion = torch.nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=train_param['lr'])#,weight_decay=1e-4)
scaler = torch.cuda.amp.GradScaler()
early_stopper = EarlyStopMonitor(max_round=args.patience)
MODEL_SAVE_PATH = f'../saved_models/{args.model}-{args.dataname}-{dist.get_world_size()}.pth'
total_test_time = 0
......@@ -513,40 +464,22 @@ def main():
sum_remote_comm = 0
sum_local_edge_comm = 0
sum_remote_edge_comm = 0
local_access = []
remote_access = []
local_comm = []
remote_comm = []
local_edge_access = []
remote_edge_access = []
local_edge_comm = []
remote_edge_comm = []
b_cnt = 0
start = time_count.start_gpu()
total = 0
for roots,mfgs,metadata in trainloader:
end = time_count.elapsed_event(start)
ada_param.last_start_event_memory_update = ada_param.start_event()
#print('time {}'.format(end))
#print('rank is {} batch max ts is {} batch min ts is {}'.format(dist.get_rank(),roots.ts.min(),roots.ts.max()))
total += end
print('batch {} time {} {}\n'.format(b_cnt,end,total))
b_cnt = b_cnt + 1
#local_access.append(trainloader.local_node)
#remote_access.append(trainloader.remote_node)
#local_edge_access.append(trainloader.local_edge)
#remote_edge_access.append(trainloader.remote_edge)
#local_comm.append((DistIndex(mfgs[0][0].srcdata['ID']).part == dist.get_rank()).sum().item())
#remote_comm.append((DistIndex(mfgs[0][0].srcdata['ID']).part != dist.get_rank()).sum().item())
#if 'ID' in mfgs[0][0].edata:
# local_edge_comm.append((DistIndex(mfgs[0][0].edata['ID']).part == dist.get_rank()).sum().item())
# remote_edge_comm.append((DistIndex(mfgs[0][0].edata['ID']).part != dist.get_rank()).sum().item())
# sum_local_edge_comm +=local_edge_comm[b_cnt-1]
# sum_remote_edge_comm +=remote_edge_comm[b_cnt-1]
#sum_local_comm +=local_comm[b_cnt-1]
#sum_remote_comm +=remote_comm[b_cnt-1]
t1 = time_count.start_gpu()
if mailbox is not None:
if(graph.efeat.device.type != 'cpu'):
edge_feats = graph.get_local_efeat(graph.eids_mapper[roots.eids.to('cpu')]).to('cuda')
#edge_feats = graph.get_dist_efeat(graph.eids_mapper[roots.eids.to('cpu')].to('cuda'),is_sorted = False)#graph.efeat[roots.eids.to('cpu')].to('cuda')
edge_feats = graph.get_local_efeat(graph.eids_mapper[roots.eids.to('cpu')]).to('cuda',non_blocking=True)
else:
edge_feats = graph.get_local_efeat(graph.eids_mapper[roots.eids.to('cpu')])
src = metadata['src_pos_index']
......@@ -561,6 +494,7 @@ def main():
t2 = time_count.start_gpu()
optimizer.zero_grad()
ones = torch.ones(metadata['dst_neg_index'].shape[0],device = model.device,dtype=torch.float)
#with autocast():
pred_pos, pred_neg = model(mfgs,metadata,neg_samples=args.neg_samples,async_param = param)
ada_param.update_gnn_aggregate_time(ada_param.last_start_event_gnn_aggregate)
#print(time_count.elapsed_event(t2))
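# Illustrative only: the diff imports autocast and creates a GradScaler but keeps
# the autocast context commented out. A minimal sketch of the standard
# torch.cuda.amp training step (names mirror this file; the wrapper function
# itself is not part of the repo) would look like:
def amp_train_step(model, creterion, optimizer, scaler, mfgs, metadata, neg_samples):
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():
        pred_pos, pred_neg = model(mfgs, metadata, neg_samples=neg_samples)
        loss = creterion(pred_pos, torch.ones_like(pred_pos))
        loss = loss + creterion(pred_neg, torch.zeros_like(pred_neg))
    scaler.scale(loss).backward()   # scale the loss to avoid fp16 gradient underflow
    scaler.step(optimizer)          # unscales gradients, skips the step on inf/nan
    scaler.update()
    return float(loss.item())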
......@@ -574,71 +508,23 @@ def main():
total_loss += float(loss.item())
if len(trainloader.result_queue) > 0:
_,_,_,edge_feat,node_feat0 = trainloader.result_queue[0]
if ada_param is not None:
ada_param.update_gnn_aggregate_time(ada_param.last_start_event_gnn_aggregate)
edge_feat[1].wait()
node_feat0[1].wait()
if ada_param is not None:
ada_param.update_fetch_time(ada_param.last_start_event_fetch)
ada_param.update_parameter()
#mailbox.handle_last_async()
#trainloader.async_feature()
#torch.cuda.synchronize()
loss.backward()
optimizer.step()
#torch.cuda.synchronize()
## train aps
#y_pred = torch.cat([pred_pos, pred_neg], dim=0).sigmoid().cpu()
#y_true = torch.cat([torch.ones(pred_pos.size(0)), torch.zeros(pred_neg.size(0))], dim=0)
#train_aps.append(average_precision_score(y_true, y_pred.detach().numpy()))
#torch.cuda.synchronize()
mailbox.update_shared()
mailbox.update_p2p_mem()
mailbox.update_p2p_mail()
ada_param.last_start_event_memory_update = ada_param.start_event()
start = time_count.start_gpu()
#ada_param.update_parameter()
#torch.cuda.empty_cache()
"""
if mailbox is not None:
#src = metadata['src_pos_index']
#dst = metadata['dst_pos_index']
#ts = roots.ts
#if graph.efeat is None:
# edge_feats = None
#elif(graph.efeat.device.type != 'cpu'):
# edge_feats = graph.get_local_efeat(graph.eids_mapper[roots.eids.to('cpu')]).to('cuda')
#edge_feats = graph.get_dist_efeat(graph.eids_mapper[roots.eids.to('cpu')].to('cuda'),is_sorted = False)#graph.efeat[roots.eids.to('cpu')].to('cuda')
#else:
# edge_feats = graph.get_local_efeat(graph.eids_mapper[roots.eids.to('cpu')])
#edge_feats = graph.get_dist_efeat(graph.eids_mapper[roots.eids.to('cpu')],is_sorted=False)#graph.efeat[roots.eids]
#print(mfgs[0][0].srcdata['ID'])
dist_index_mapper = mfgs[0][0].srcdata['ID']
root_index = torch.cat((src,dst))
#print('{} {} {}'.format((~(dist_index_mapper==model.module.memory_updater.last_updated_nid)).nonzero(),model.module.memory_updater.last_updated_nid,dist_index_mapper))
last_updated_nid = model.module.memory_updater.last_updated_nid[root_index]
last_updated_memory = model.module.memory_updater.last_updated_memory[root_index]
last_updated_ts=model.module.memory_updater.last_updated_ts[root_index]
#print('root shape {} unique {} {}\n'.format(root_index.shape,dist_index_mapper[root_index].unique().shape,last_updated_nid.unique().shape))
index, memory, memory_ts = mailbox.get_update_memory(last_updated_nid,
last_updated_memory,
last_updated_ts,
model.module.embedding)
#print('index {} {}\n'.format(index.shape,dist_index_mapper[torch.cat((src,dst))].unique().shape))
index, mail, mail_ts = mailbox.get_update_mail(dist_index_mapper,
src,dst,ts,edge_feats,
model.module.memory_updater.last_updated_memory,
model.module.embedding,use_src_emb,use_dst_emb,
)
t7 = time.time()
if memory_param['historical'] == True:
mailbox.set_memory_all_reduce(index,memory,memory_ts,mail,mail_ts,reduce_Op = 'max', async_op = False,filter=model.module.memory_updater.filter,set_remote=True,mode='historical')
else:
mailbox.set_memory_all_reduce(index,memory,memory_ts,mail,mail_ts,reduce_Op = 'max', async_op = False,filter=None,set_remote=True,mode='all_reduce')
"""
torch.cuda.synchronize()
dist.barrier()
time_prep = time.time() - epoch_start_time
......
......@@ -21,7 +21,7 @@ class Filter(nn.Module):
"""
# Treat filter as parameter so that it is saved and loaded together with the model
self.count = torch.zeros((self.n_nodes),1).to(self.device)
self.incretment = torch.zeros((self.n_nodes, self.memory_dimension)).to(self.device)
self.incretment = torch.zeros((self.n_nodes, self.memory_dimension),dtype=torch.float32).to(self.device)
def get_count(self, node_idxs):
......@@ -37,7 +37,7 @@ class Filter(nn.Module):
def update(self, node_idxs, incret):
self.count[node_idxs, :] = self.count[node_idxs, :] + 1
self.incretment[node_idxs, :] = self.incretment[node_idxs, :] + incret
self.incretment[node_idxs, :] = (self.incretment[node_idxs, :] + incret).to(self.incretment.dtype)
def clear(self):
self.count.zero_()
......
......@@ -393,6 +393,7 @@ class AsyncMemeoryUpdater(torch.nn.Module):
# if nxt_fetch_func is not None:
# nxt_fetch_func()
def historical_func(self,index,memory,memory_ts,mail_index,mail,mail_ts,nxt_fetch_func,spread_mail=False):
self.ada_param.update_memory_update_time(self.ada_param.last_start_event_memory_update)
self.mailbox.sychronize_shared()
self.mailbox.handle_last_async()
#if self.ada_param.training is True:
......@@ -409,7 +410,7 @@ class AsyncMemeoryUpdater(torch.nn.Module):
wait_submit=submit_to_queue,spread_mail=spread_mail,
update_cross_mm=False,
)
self.ada_param.update_memory_update_time(self.ada_param.last_start_event_memory_update)
if nxt_fetch_func is not None:
nxt_fetch_func()
......@@ -422,7 +423,9 @@ class AsyncMemeoryUpdater(torch.nn.Module):
if self.dim_time > 0:
#print(b.srcdata['ts'].shape,b.srcdata['mem_ts'].shape)
time_feat = self.time_enc(b.srcdata['ts'] - b.srcdata['mem_ts'])
#print(b.srcdata['mem_input'].dtype,b.srcdata['mem_ts'].dtype,b.srcdata['ts'].dtype,time_feat.dtype)
b.srcdata['mem_input'] = torch.cat([b.srcdata['mem_input'], time_feat], dim=1)
return self.ceil_updater(b.srcdata['mem_input'], b.srcdata['mem'])
def __init__(self, memory_param, dim_in, dim_hid, dim_time, dim_node_feat,updater,mode = None,mailbox = None,train_param=None, ada_param = None):
super(AsyncMemeoryUpdater, self).__init__()
......@@ -468,32 +471,23 @@ class AsyncMemeoryUpdater(torch.nn.Module):
self.gamma = 1
def forward(self, mfg, param = None):
for b in mfg:
#print(b.srcdata['ID'].shape[0])
updated_memory0 = self.updater(b)
mask = DistIndex(b.srcdata['ID']).is_shared
#incr = updated_memory[mask] - b.srcdata['mem'][mask]
#time_feat = self.time_enc(b.srcdata['ts'][mask].reshape(-1,1) - b.srcdata['his_ts'][mask].reshape(-1,1))
#his_mem = torch.cat((mail_input[mask],time_feat),dim = 1)
with torch.no_grad():
upd0 = torch.zeros_like(updated_memory0)
#print(upd0.dtype)
if self.mode == 'historical':
shared_ind = self.mailbox.is_shared_mask[DistIndex(b.srcdata['ID'][mask]).loc]
transition_dense = b.srcdata['his_mem'][mask] + self.filter.get_incretment(shared_ind)
#print(transition_dense.shape)
if not (transition_dense.max().item() == 0):
transition_dense -= transition_dense.min()
transition_dense /= transition_dense.max()
transition_dense = 2*transition_dense - 1
upd0[mask] = transition_dense#b.srcdata['his_mem'][mask] + transition_dense
#print(self.gamma)
#print('tran {} {} {}\n'.format(transition_dense.max().item(),upd0[mask].max().item(),b.srcdata['his_mem'][mask].max().item()))
upd0[mask] = transition_dense.to(upd0.dtype)#b.srcdata['his_mem'][mask] + transition_dense
updated_memory = torch.where(mask.unsqueeze(1),torch.sigmoid(self.gamma)*updated_memory0 + (1-torch.sigmoid(self.gamma))*(upd0),updated_memory0)
else:
upd0[mask] = updated_memory0[mask]
#upd0[mask] = self.ceil_updater(his_mem, b.srcdata['his_mem'][mask])
#updated_memory = torch.where(mask.unsqueeze(1),self.gamma*updated_memory0 + (1-self.gamma)*(b.srcdata['his_mem'])
# ,updated_memory0)
updated_memory = torch.where(mask.unsqueeze(1),torch.sigmoid(self.gamma)*updated_memory0 + (1-torch.sigmoid(self.gamma))*(upd0),updated_memory0)
#upd0[mask] = updated_memory0[mask]
updated_memory = updated_memory0
with torch.no_grad():
if self.mode == 'historical':
change = updated_memory[mask] - b.srcdata['his_mem'][mask]
......@@ -518,12 +512,11 @@ class AsyncMemeoryUpdater(torch.nn.Module):
None,False,False,block=b
)
#print(index.shape[0])
if torch.distributed.get_world_size() == 0:
self.mailbox.mon.add(index,self.mailbox.node_memory.accessor.data[index],memory)
#if torch.distributed.get_world_size() == 0:
# self.mailbox.mon.add(index,self.mailbox.node_memory.accessor.data[index],memory)
##print(index.shape,memory.shape,memory_ts.shape,mail.shape,mail_ts.shape)
local_mask = (DistIndex(index).part==torch.distributed.get_rank())
local_mask_mail = (DistIndex(index0).part==torch.distributed.get_rank())
self.mailbox.set_mailbox_local(DistIndex(index0[local_mask_mail]).loc,mail[local_mask_mail],mail_ts[local_mask_mail],Reduce_Op = 'max')
self.mailbox.set_memory_local(DistIndex(index[local_mask]).loc,memory[local_mask],memory_ts[local_mask], Reduce_Op = 'max')
is_deliver=(self.mailbox.deliver_to == 'neighbors')
......
......@@ -159,9 +159,11 @@ class AdaParameter:
self.beta = self.beta * average_gnn_aggregate/average_fetch * (1 + self.wait_threshold)
average_memory_sync_time = self.average_memory_sync/self.count_memory_sync
average_memory_update_time = self.average_memory_update/self.count_memory_update
self.alpha = self.alpha-math.log(average_memory_update_time/average_memory_sync_time * (1 + self.wait_threshold))
self.alpha = self.alpha - math.log(average_memory_update_time*(1+self.wait_threshold)) + math.log(average_memory_sync_time)
print(self.alpha)
self.beta = max(min(self.beta, self.max_beta),self.min_beta)
self.alpha = max(min(self.alpha, self.max_alpha),self.min_alpha)
ctx = DistributedContext.get_default_context()
beta_comm=torch.tensor([self.beta])
torch.distributed.all_reduce(beta_comm,group=ctx.gloo_group)
......@@ -169,8 +171,8 @@ class AdaParameter:
alpha_comm=torch.tensor([self.alpha])
torch.distributed.all_reduce(alpha_comm,group=ctx.gloo_group)
self.alpha = alpha_comm[0].item()/ctx.world_size
#print('gnn aggregate {} fetch {} memory sync {} memory update {}'.format(average_gnn_aggregate,average_fetch,average_memory_sync_time,average_memory_update_time))
#print('beta is {} alpha is {}\n'.format(self.beta,self.alpha))
print('gnn aggregate {} fetch {} memory sync {} memory update {}'.format(average_gnn_aggregate,average_fetch,average_memory_sync_time,average_memory_update_time))
print('beta is {} alpha is {}\n'.format(self.beta,self.alpha))
#self.reset_time()
#log(2-a1) = log(2-a2) * t1/t2 * (1 + wait_threshold)
#2-a1 = (2-a2)^(t1/t2 * (1 + wait_threshold))
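# For reference, a self-contained sketch of the adaptive policy applied above
# (simplified; the cross-rank all_reduce averaging of alpha and beta is omitted,
# and this helper function is illustrative, not repo code):
import math

def update_alpha_beta(alpha, beta, avg_fetch, avg_gnn_aggregate,
                      avg_memory_sync, avg_memory_update,
                      wait_threshold, min_alpha, max_alpha, min_beta, max_beta):
    # grow beta when GNN aggregation dominates feature fetching
    beta = beta * avg_gnn_aggregate / avg_fetch * (1 + wait_threshold)
    # shrink alpha when memory updates take longer than memory synchronization
    alpha = alpha - math.log(avg_memory_update * (1 + wait_threshold)) + math.log(avg_memory_sync)
    # clamp both knobs to their configured ranges
    beta = max(min(beta, max_beta), min_beta)
    alpha = max(min(alpha, max_alpha), min_alpha)
    return alpha, beta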
......
......@@ -186,26 +186,7 @@ def to_block(graph,data, sample_out,device = torch.device('cuda'),unique = True)
sample_out,metadata = sample_out
else:
metadata = None
#to_block(metadata['src_pos_index'],metadata['dst_pos_index'],metadata['dst_neg_index'],
# metadata['seed'],metadata['seed_ts'],graph.nids_mapper,graph.eids_mapper,#device.type if "cpu" else str(device.index))
#root_len = len(metadata.pop('seed'))
#eid_inv = metadata.pop('eid_inv').clone()
#print('data {} {}\n'.format(data.edges,data.ts))
#first_block_id = metadata.pop('first_block_id').clone()
#print('first_block_id {}\n'.format(first_block_id))
#block_node_list = metadata.pop('block_node_list').clone()
#print('block_node_list {}\n'.format(block_node_list))
#unq_id = metadata.pop('unq_id').clone()
#print('unq id {}'.format(unq_id))
#dist_nid = metadata.pop('dist_nid').clone().to(device)
#dist_eid = metadata.pop('dist_eid').clone().to(device)
#print('dist nid {} dist eid {}\n'.format(dist_nid,dist_eid))
#print('block node list edge {} {}'.format(
# graph.ids[DistIndex(dist_nid[block_node_list[0,#unq_id]]).loc.to('cpu')],block_node_list[1,unq_id]
eid_len = [ret.eid().shape[0] for ret in sample_out ]
# print(sample_out)
t0 = time.time()
eid = [ret.eid() for ret in sample_out]
dst = [ret.sample_nodes() for ret in sample_out]
dst_ts = [ret.sample_nodes_ts() for ret in sample_out]
......@@ -237,13 +218,7 @@ def to_block(graph,data, sample_out,device = torch.device('cuda'),unique = True)
metadata[k] = metadata[k].to(device)
nid_tensor = torch.cat([root_node,src_node],dim = 0)
dist_nid = nid_mapper[nid_tensor].to(device)
#print(CountComm.origin_local,CountComm.origin_remote)
#for p in range(dist.get_world_size()):
# print((DistIndex(dist_nid).part == p).sum().item())
#CountComm.origin_local = (DistIndex(dist_nid).part == dist.get_rank()).sum().item()
#CountComm.origin_remote =(DistIndex(dist_nid).part != dist.get_rank()).sum().item()
dist_nid,nid_inv = dist_nid.unique(return_inverse=True)
#print('nid_tensor {} \n nid {}\n'.format(nid_tensor,dist_nid))
"""
Deduplicate nodes that share the same id and timestamp, and obtain the corresponding index.
......@@ -258,13 +233,11 @@ def to_block(graph,data, sample_out,device = torch.device('cuda'),unique = True)
first_block_id = torch.empty(first_index.shape[0],device=unq_id.device,dtype=unq_id.dtype)
first_block_id[first_index] = torch.arange(first_index.shape[0],device=first_index.device,dtype=first_index.dtype)
first_block_id = first_block_id[unq_id].contiguous()
block_node_list = block_node_list[:,first_index]
block_node_list = block_node_list[:,first_index].contiguous()
#print('first block id {}\n unq id {} \n block_node_list {}\n'.format(first_block_id,unq_id,block_node_list))
for k in metadata:
if isinstance(metadata[k],torch.Tensor):
#print('{}:{}\n'.format(k,metadata[k]))
metadata[k] = first_block_id[metadata[k]]
#print('{}:{}\n'.format(k,metadata[k]))
t2 = time.time()
def build_block():
......@@ -290,9 +263,6 @@ def to_block(graph,data, sample_out,device = torch.device('cuda'),unique = True)
if sample_out[r].delta_ts().shape[0] > 0:
b.edata['dt'] = sample_out[r].delta_ts().to(device)
b.srcdata['ts'] = block_node_list[1,b.srcnodes()].to(torch.float)
#weight = sample_out[r].sample_weight()
#if(weight.shape[0] > 0):
# b.edata['weight'] = 1/torch.clamp(sample_out[r].sample_weight(),0.0001).to(b.device)
b.edata['__ID'] = e_idx
col = row
col_len += eid_len[r]
......@@ -348,9 +318,6 @@ def to_reversed_block(graph,data, sample_out,device = torch.device('cuda'),uniqu
if identity is False:
assert len(sample_out) == 1
ret = sample_out[0]
eid_len = ret.eid().shape[0]
t0 = time.time()
dst_ts = ret.sample_nodes_ts().to(device)
dst = nid_mapper[ret.sample_nodes()].to(device)
dist_eid = torch.tensor([],dtype=torch.long,device=device)
src_index = ret.src_index().to(device)
......
......@@ -33,11 +33,11 @@ class time_count:
def start_gpu():
# CUDA events enabled for detailed breakdown timings
#torch.cuda.synchronize()
#start_event = torch.cuda.Event(enable_timing=True)
#end_event = torch.cuda.Event(enable_timing=True)
#start_event.record()
#return start_event,end_event
return 0,0
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
start_event.record()
return start_event,end_event
#return 0,0
@staticmethod
def start():
# return time.perf_counter(),0
......
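# Hedged, standalone illustration of the CUDA-event timing pattern that the
# re-enabled start_gpu() supports (standard torch.cuda.Event API only; this
# helper is not part of the repo):
import torch

def time_gpu_section(fn):
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    start_event.record()
    fn()                               # the GPU work being timed
    end_event.record()
    torch.cuda.synchronize()           # events complete asynchronously
    return start_event.elapsed_time(end_event)   # elapsed time in milliseconds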
......@@ -259,25 +259,12 @@ class DistributedDataLoader:
pass
batch_data,dist_nid,dist_eid = self.result_queue[0].result()
b = batch_data[1][0][0]
self.remote_node += (DistIndex(dist_nid).part != dist.get_rank()).sum().item()
self.local_node += (DistIndex(dist_nid).part == dist.get_rank()).sum().item()
self.remote_edge += (DistIndex(dist_eid).part != dist.get_rank()).sum().item()
self.local_edge += (DistIndex(dist_eid).part == dist.get_rank()).sum().item()
#self.remote_root += (DistIndex(dist_nid[b.srcdata['__ID'][:self.batch_size*2]]).part != dist.get_rank()).sum()
#self.local_root += (DistIndex(dist_nid[b.srcdata['__ID'][:self.batch_size*2]]).part == dist.get_rank()).sum()
#torch.cuda.synchronize(stream)
#start = torch.cuda.Event(enable_timing=True)
#end = torch.cuda.Event(enable_timing=True)
#start.record()
stream.synchronize()
#end.record()
#end.synchronize()
#print(start.elapsed_time(end))
self.result_queue.popleft()
#start = torch.cuda.Event(enable_timing=True)
#end = torch.cuda.Event(enable_timing=True)
#start.record()
nind,ndata = get_node_all_to_all_route(self.graph,self.mailbox,dist_nid,out_device=self.device)
eind,edata = get_edge_all_to_all_route(self.graph,dist_eid,out_device=self.device)
......@@ -287,15 +274,14 @@ class DistributedDataLoader:
else:
node_feat = None
if eind is not None:
if eind or self.reversed is False:
edge_feat = DistributedTensor.all_to_all_get_data(edata,send_ptr=eind['send_ptr'],recv_ptr=eind['recv_ptr'],is_async=True)
else:
edge_feat = None
if self.ada_param is not None:
self.ada_param.last_start_event_fetch = self.ada_param.start_event()
t3 = time.time()
self.result_queue.append((batch_data,dist_nid,dist_eid,edge_feat,node_feat))
self.submit()
if self.ada_param is not None:
self.ada_param.last_start_event_fetch = self.ada_param.start_event()
@torch.no_grad()
def __next__(self):
......@@ -319,9 +305,6 @@ class DistributedDataLoader:
)
root,mfgs,metadata = batch_data
t_sample = tt.elapsed_event(t0)
tt.time_sample_and_build+=t_sample
t1 = tt.start_gpu()
edge_feat = get_edge_feature_by_dist(self.graph,dist_eid,is_local,out_device=self.device)
node_feat,mem = get_node_feature_by_dist(self.graph,self.mailbox,dist_nid, is_local,out_device=self.device)
prepare_input(node_feat,edge_feat,mem,mfgs,dist_nid,dist_eid)
......@@ -358,10 +341,7 @@ class DistributedDataLoader:
batch_data[1][0][0].srcdata['his_mem'] = batch_data[1][0][0].srcdata['mem'].clone()
batch_data[1][0][0].srcdata['his_ts'] = batch_data[1][0][0].srcdata['mail_ts'].clone()
#if(self.mailbox is not None and self.mailbox.historical_cache is not None):
# id = batch_data[1][0][0].srcdata['ID']
# mask = DistIndex(id).is_shared
#batch_data[1][0][0].srcdata['mem'][mask] = self.mailbox.historical_cache.local_historical_data[DistIndex(id).loc[mask]]
self.recv_idxs += 1
else:
......@@ -392,22 +372,7 @@ class DistributedDataLoader:
node_feat0 = node_feat0[0]
node_feat = None
mem = self.mailbox.unpack(node_feat0,mailbox = True)
#print(node_feat.shape,edge_feat.shape,mem[0].shape)
#node_feat[1].wait()
#node_feat = node_feat[0]
##for i in range(len(mem)):
## mem[i][1].wait()
#mem[0][1].wait()
#mem[1][1].wait()
#mem[2][1].wait()
#mem[3][1].wait()
t1 = time.time()
#mem = (mem[0][0],mem[1][0],mem[2][0],mem[3][0])
#node_feat,mem = get_node_feature_by_dist(self.graph,self.mailbox, dist_nid,is_local,out_device=self.device)
t1 = time.time()
#if self.ada_param is not None:
# self.ada_param.update_fetch_time(self.ada_param.last_start_event_fetch)
else:
batch_data,dist_nid,dist_eid = self.result_queue[0].result()
stream.synchronize()
......@@ -429,172 +394,6 @@ class DistributedDataLoader:
raise StopIteration
global executor
if(len(self.result_queue)==0):
#if(self.recv_idxs+1<=self.expected_idx):
self.submit()
"""
graph_sample(
graph=self.graph,
sampler=self.sampler,
sample_fn=self.sampler_fn,
data=data,neg_sampling=self.neg_sampler,
mailbox=self.mailbox,
device=torch.cuda.current_device(),
local_embedding=self.local_embedding,
async_op=True)
"""
#print('dataloader {} '.format(batch_data[1][0][0].srcdata['h'].shape))
return batch_data
import os
class PreBatchDataLoader(DistributedDataLoader):
def __init__(
self,
graph,
dataname = None,
dataset = None,
batch_size = None,
drop_last = False,
device: torch.device = torch.device('cuda'),
mode = 'train',
mailbox = None,
is_pipeline = False,
edge_classification = False,
use_local_feature = False,
train_neg_samples = 1,
neg_sets = 32,
pre_batch_path = '/mnt/data/part_data/starrytgl/minibatches/',
**kwargs
):
super().__init__(graph=graph,dataset=dataset,batch_size=batch_size,drop_last=drop_last,device=device,mailbox=mailbox,use_local_feature=use_local_feature)
self.edge_classification = edge_classification
if edge_classification:
train_neg_samples = 0
neg_sets = 0
ctx = DistributedContext.get_default_context()
if mode == 'train':
self.path = '{}/{}/{}/{}_{}/'.format(pre_batch_path,dataname,ctx.memory_group_rank,train_neg_samples, neg_sets)
self.neg_sets = neg_sets
self.mode = mode
#self.tot_length = len([fn for fn in os.listdir(self.path) if fn.startswith('{}_pos'.format(mode))])
self.idx = 0
self.is_pipeline = is_pipeline
self.rng = np.random.default_rng()
self.train_neg_samples = train_neg_samples
self.rng = np.random.default_rng()
self.stream = torch.cuda.Stream()
def __iter__(self):
self.idx = 0
ctx = DistributedContext.get_default_context()
super().__iter__()
self.init_prefetch(ctx.memory_group)
return self
def __next__(self):
if self.thread_idx == self.expected_idx:
raise StopIteration
else:
self.prefetch_next()
ret = self.get_fetched()
return ret
def get(self, idx, num_neg=1, offset=0):
# used for partial validation
idx = idx * self.minibatch_parallelism + offset
with open('{}{}_pos_{}.pkl'.format(self.path, self.mode, idx), 'rb') as f:
pos_mfg = pickle.load(f)
neg_mfgs = list()
for _ in range(num_neg):
if self.mode == 'train':
chosen_set = '_' + str(self.rng.integers(self.neg_sets))
else:
chosen_set = ''
if not self.edge_classification:
with open('{}{}_neg_{}{}.pkl'.format(self.path, self.mode, idx, chosen_set), 'rb') as f:
neg_mfgs.append(pickle.load(f))
return pos_mfg, neg_mfgs
def _load_and_slice_mfgs(self, roots,pos_idx, neg_idxs ):
# only usable in training set
t_s = time.time()
with open('{}{}_pos_{}.pkl'.format(self.path, self.mode, pos_idx), 'rb') as f:
pos_mfg = pickle.load(f)
if not self.edge_classification:
with open('{}{}_neg_{}_{}.pkl'.format(self.path, self.mode, pos_idx, neg_idxs), 'rb') as f:
neg_mfg = pickle.load(f)
with torch.cuda.stream(self.stream):
mfg = combine_mfgs(self.graph,pos_mfg, neg_mfg,to_device = self.device)
self.prefetched_mfgs[0] = (roots,*mfg)
#prepare_input(mfg, self.node_feats, self.edge_feats)
#prepare_input_tgl(mfg[0],self.graph,self.mailbox,use_local=((self.mode=='train') & (self.use_local_feature==True)),out_device=self.device)
#self.prefetched_mfgs[0] = (roots,*mfg)
else:
mfg = pos_mfg
mfg.combined = False
#prepare_input_tgl(mfg[0],self.graph,self.mailbox,use_local=((self.mode=='train') & (self.use_local_feature==True)),out_device=self.device)
#self.prefetched_mfgs[0] = (roots,*pos_mfg)
return
def init_prefetch(self, idx, num_neg=1, offset=0, prefetch_interval=1, rank=None, memory_read_buffer=None, mail_read_buffer=None, read_1idx_buffer=None, read_status=None):
# initialize and prefetch the first minibatches with all zero node memory and mails
self.prefetch_interval = prefetch_interval
self.prefetch_offset = offset
self.prefetch_num_neg = num_neg
self.next_prefetch_idx = idx + prefetch_interval
self.fetched_mfgs = [None]
self.thread_idx = 0
self.prefetched_mfgs = [None] * num_neg if not self.edge_classification else [None]
self.prefetch_threads = None
idx = 0 if idx < 0 else idx
if not self.edge_classification:
neg_idxs = self.rng.integers(self.neg_sets)
else:
neg_idxs = None
for _ in range(idx+1):
data = super()._next_data()
self.prefetch_thread = threading.Thread(target=self._load_and_slice_mfgs, args=(data,idx, neg_idxs))
self.prefetch_thread.start()
return
def prefetch_next(self):
# put current prefetched to fetched and start next prefetch
t_s = time.time()
self.prefetch_thread.join()
mfg = self.prefetched_mfgs[0]
prepare_input_tgl(mfg[1],self.graph,self.mailbox,use_local=((self.mode=='train') & (self.use_local_feature==True)),out_device=self.device)
self.fetched_mfgs[0] = mfg
# print('\twait for previous thread time {}ms'.format((time.time() - t_s) * 1000))
if self.next_prefetch_idx != -1:
if self.next_prefetch_idx >= self.expected_idx:
self.next_prefetch_idx = self.expected_idx - 1
self.prefetched_mfgs = [None] * self.prefetch_num_neg if not self.edge_classification else [None]
self.prefetch_threads = None
pos_idx = self.next_prefetch_idx
if not self.edge_classification:
neg_idxs = self.rng.integers(self.neg_sets)
else:
neg_idxs = None
root = super()._next_data()
self.prefetch_thread = threading.Thread(target=self._load_and_slice_mfgs, args=(root,pos_idx, neg_idxs))
self.prefetch_thread.start()
if self.next_prefetch_idx != self.expected_idx - 1:
self.next_prefetch_idx += self.prefetch_interval
else:
self.next_prefetch_idx = -1
def get_fetched(self):
ret = self.fetched_mfgs[0]
self.thread_idx += 1
return ret
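# Hedged, minimal illustration of the prefetch pattern PreBatchDataLoader relies
# on: a background thread deserializes the next minibatch while the current one
# is being trained on. Simplified and not this repo's code (load_fn is assumed
# to load one pickled minibatch by index):
import threading
import queue

def prefetch_iter(load_fn, num_batches):
    q = queue.Queue(maxsize=1)          # keep at most one batch ahead
    def worker():
        for i in range(num_batches):
            q.put(load_fn(i))           # blocking load/deserialize of batch i
        q.put(None)                     # sentinel: no more batches
    threading.Thread(target=worker, daemon=True).start()
    while (item := q.get()) is not None:
        yield item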