Commit 572f53f7 by zlj

add adaptive ajusement

parent 1119dd9f
...@@ -3,23 +3,23 @@ ...@@ -3,23 +3,23 @@
# 定义数组变量 # 定义数组变量
seed=$1 seed=$1
addr="192.168.1.107" addr="192.168.1.107"
partition_params=("ours" ) partition_params=("ours")
#"metis" "ldg" "random") #"metis" "ldg" "random")
#("ours" "metis" "ldg" "random") #("ours" "metis" "ldg" "random")
partitions="8" partitions="4"
node_per="4" node_per="4"
nnodes="2" nnodes="1"
node_rank="0" node_rank="0"
probability_params=("0.1") probability_params=("0.1")
sample_type_params=("recent") sample_type_params=("boundery_recent_decay")
#sample_type_params=("recent" "boundery_recent_decay") #"boundery_recent_uniform") #sample_type_params=("recent" "boundery_recent_decay") #"boundery_recent_uniform")
#memory_type=("all_update" "p2p" "all_reduce" "historical" "local") #memory_type=("all_update" "p2p" "all_reduce" "historical" "local")
memory_type=("all_update") memory_type=("historical")
#"historical") #"historical")
#memory_type=("local" "all_update" "historical" "all_reduce") #memory_type=("local" "all_update" "historical" "all_reduce")
shared_memory_ssim=("0.3") shared_memory_ssim=("0.3")
#data_param=("WIKI" "REDDIT" "LASTFM" "WikiTalk") #data_param=("WIKI" "REDDIT" "LASTFM" "WikiTalk")
data_param=("LASTFM" "WikiTalk" "StackOverflow" "GDELT") data_param=("WIKI" "LASTFM")
#"GDELT") #"GDELT")
#data_param=("WIKI" "REDDIT" "LASTFM" "DGraphFin" "WikiTalk" "StackOverflow") #data_param=("WIKI" "REDDIT" "LASTFM" "DGraphFin" "WikiTalk" "StackOverflow")
#data_param=("WIKI" "REDDIT" "LASTFM" "WikiTalk" "StackOverflow") #data_param=("WIKI" "REDDIT" "LASTFM" "WikiTalk" "StackOverflow")
...@@ -32,9 +32,9 @@ data_param=("LASTFM" "WikiTalk" "StackOverflow" "GDELT") ...@@ -32,9 +32,9 @@ data_param=("LASTFM" "WikiTalk" "StackOverflow" "GDELT")
#seed=(( RANDOM % 1000000 + 1 )) #seed=(( RANDOM % 1000000 + 1 ))
mkdir -p all_"$seed" mkdir -p all_"$seed"
for data in "${data_param[@]}"; do for data in "${data_param[@]}"; do
model="APAN_large" model="TGN_large"
if [ "$data" = "WIKI" ] || [ "$data" = "REDDIT" ] || [ "$data" = "LASTFM" ]; then if [ "$data" = "WIKI" ] || [ "$data" = "REDDIT" ] || [ "$data" = "LASTFM" ]; then
model="APAN" model="TGN"
fi fi
#model="APAN" #model="APAN"
mkdir all_"$seed"/"$data" mkdir all_"$seed"/"$data"
...@@ -165,9 +165,9 @@ for data in "${data_param[@]}"; do ...@@ -165,9 +165,9 @@ for data in "${data_param[@]}"; do
done done
for data in "${data_param[@]}"; do for data in "${data_param[@]}"; do
model="TGN_large" model="APAN_large"
if [ "$data" = "WIKI" ] || [ "$data" = "REDDIT" ] || [ "$data" = "LASTFM" ]; then if [ "$data" = "WIKI" ] || [ "$data" = "REDDIT" ] || [ "$data" = "LASTFM" ]; then
model="TGN" model="APAN"
#continue #continue
fi fi
#model="APAN" #model="APAN"
......
...@@ -18,7 +18,7 @@ from pathlib import Path ...@@ -18,7 +18,7 @@ from pathlib import Path
from pathlib import Path from pathlib import Path
from starrygl.module.utils import parse_config from starrygl.module.utils import AdaParameter, parse_config
from starrygl.sample.cache.fetch_cache import FetchFeatureCache from starrygl.sample.cache.fetch_cache import FetchFeatureCache
from starrygl.sample.graph_core import DataSet, DistributedGraphStore, TemporalNeighborSampleGraph from starrygl.sample.graph_core import DataSet, DistributedGraphStore, TemporalNeighborSampleGraph
from starrygl.module.utils import parse_config, EarlyStopMonitor from starrygl.module.utils import parse_config, EarlyStopMonitor
...@@ -232,18 +232,22 @@ def main(): ...@@ -232,18 +232,22 @@ def main():
print(policy) print(policy)
if policy == 'recent': if policy == 'recent':
policy_train = args.sample_type#'boundery_recent_decay' policy_train = args.sample_type#'boundery_recent_decay'
if policy_train == 'recent':
ada_param = AdaParameter(init_alpha=args.shared_memory_ssim, init_beta=args.probability,min_beta=1)
else:
ada_param = AdaParameter(init_alpha=args.shared_memory_ssim, init_beta=args.probability)
else: else:
policy_train = policy policy_train = policy
if memory_param['type'] != 'none': if memory_param['type'] != 'none':
mailbox = SharedMailBox(graph.ids.shape[0], memory_param, dim_edge_feat = graph.efeat.shape[1] if graph.efeat is not None else 0, mailbox = SharedMailBox(graph.ids.shape[0], memory_param, dim_edge_feat = graph.efeat.shape[1] if graph.efeat is not None else 0,
shared_nodes_index=graph.shared_nids_list[ctx.memory_group_rank],device = torch.device('cuda:{}'.format(local_rank)), shared_nodes_index=graph.shared_nids_list[ctx.memory_group_rank],device = torch.device('cuda:{}'.format(local_rank)),
start_historical=(args.memory_type=='historical'), start_historical=(args.memory_type=='historical'),
shared_ssim=args.shared_memory_ssim) ada_param = ada_param)
else: else:
mailbox = None mailbox = None
sampler = NeighborSampler(num_nodes=graph.num_nodes, num_layers=num_layers, fanout=fanout,graph_data=sample_graph, workers=10,policy = policy_train, graph_name = "train",local_part=dist.get_rank(),edge_part=DistIndex(graph.eids_mapper).part,node_part=DistIndex(graph.nids_mapper).part,probability=args.probability,no_neg=no_neg) sampler = NeighborSampler(num_nodes=graph.num_nodes, num_layers=num_layers, fanout=fanout,graph_data=sample_graph, workers=10,policy = policy_train, graph_name = "train",local_part=dist.get_rank(),edge_part=DistIndex(graph.eids_mapper).part,node_part=DistIndex(graph.nids_mapper).part,ada_param=ada_param,no_neg=no_neg)
eval_sampler = NeighborSampler(num_nodes=graph.num_nodes, num_layers=num_layers, fanout=fanout,graph_data=eval_sample_graph, workers=10,policy = policy_train, graph_name = "eval",local_part=dist.get_rank(),edge_part=DistIndex(graph.eids_mapper).part,node_part=DistIndex(graph.nids_mapper).part,probability=args.probability,no_neg=no_neg) eval_sampler = NeighborSampler(num_nodes=graph.num_nodes, num_layers=num_layers, fanout=fanout,graph_data=eval_sample_graph, workers=10,policy = policy_train, graph_name = "eval",local_part=dist.get_rank(),edge_part=DistIndex(graph.eids_mapper).part,node_part=DistIndex(graph.nids_mapper).part,ada_param=ada_param,no_neg=no_neg)
train_data = torch.masked_select(graph.edge_index,train_mask.to(graph.edge_index.device)).reshape(2,-1) train_data = torch.masked_select(graph.edge_index,train_mask.to(graph.edge_index.device)).reshape(2,-1)
train_ts = torch.masked_select(graph.ts,train_mask.to(graph.edge_index.device)) train_ts = torch.masked_select(graph.ts,train_mask.to(graph.edge_index.device))
...@@ -276,15 +280,15 @@ def main(): ...@@ -276,15 +280,15 @@ def main():
else: else:
#train_neg_sampler = LocalNegativeSampling('triplet',amount = args.neg_samples,dst_node_list = full_dst.unique()) #train_neg_sampler = LocalNegativeSampling('triplet',amount = args.neg_samples,dst_node_list = full_dst.unique())
train_neg_sampler = LocalNegativeSampling('triplet',amount = args.neg_samples,dst_node_list = full_dst.unique(),local_mask=(DistIndex(graph.nids_mapper[full_dst.unique()].to('cpu')).part == dist.get_rank()),prob=args.probability) train_neg_sampler = LocalNegativeSampling('triplet',amount = args.neg_samples,dst_node_list = full_dst.unique(),local_mask=(DistIndex(graph.nids_mapper[full_dst.unique()].to('cpu')).part == dist.get_rank()),ada_param=ada_param)
remote_ratio = train_neg_sampler.local_dst.shape[0] / train_neg_sampler.dst_node_list.shape[0] #remote_ratio = train_neg_sampler.local_dst.shape[0] / train_neg_sampler.dst_node_list.shape[0]
#train_ratio_pos = (1 - args.probability) + args.probability * remote_ratio #train_ratio_pos = (1 - args.probability) + args.probability * remote_ratio
#train_ratio_neg = args.probability * (1-remote_ratio) #train_ratio_neg = args.probability * (1-remote_ratio)
train_ratio_pos = 1.0/(1-args.probability+ args.probability * remote_ratio) if ((args.probability <1) & (args.probability > 0)) else 1 #train_ratio_pos = 1.0/(1-args.probability+ args.probability * remote_ratio) if ((args.probability <1) & (args.probability > 0)) else 1
train_ratio_neg = 1.0/(args.probability*remote_ratio) if ((args.probability <1) & (args.probability > 0)) else 1 #train_ratio_neg = 1.0/(args.probability*remote_ratio) if ((args.probability <1) & (args.probability > 0)) else 1
print(train_neg_sampler.dst_node_list) print(train_neg_sampler.dst_node_list)
neg_sampler = LocalNegativeSampling('triplet',amount= neg_samples,dst_node_list = full_dst.unique(),seed=args.seed) neg_sampler = LocalNegativeSampling('triplet',amount= neg_samples,dst_node_list = full_dst.unique(),seed=args.seed)
trainloader = DistributedDataLoader(graph,eval_train_data,sampler = sampler, trainloader = DistributedDataLoader(graph,eval_train_data,sampler = sampler,
sampler_fn = SAMPLE_TYPE.SAMPLE_FROM_TEMPORAL_EDGES, sampler_fn = SAMPLE_TYPE.SAMPLE_FROM_TEMPORAL_EDGES,
neg_sampler=train_neg_sampler, neg_sampler=train_neg_sampler,
...@@ -299,7 +303,8 @@ def main(): ...@@ -299,7 +303,8 @@ def main():
use_local_feature = False, use_local_feature = False,
device = torch.device('cuda:{}'.format(local_rank)), device = torch.device('cuda:{}'.format(local_rank)),
probability=args.probability, probability=args.probability,
reversed = (gnn_param['arch'] == 'identity') reversed = (gnn_param['arch'] == 'identity'),
ada_param=ada_param
) )
eval_trainloader = DistributedDataLoader(graph,eval_train_data,sampler = eval_sampler, eval_trainloader = DistributedDataLoader(graph,eval_train_data,sampler = eval_sampler,
...@@ -350,10 +355,10 @@ def main(): ...@@ -350,10 +355,10 @@ def main():
print('dim_node {} dim_edge {}\n'.format(gnn_dim_node,gnn_dim_edge)) print('dim_node {} dim_edge {}\n'.format(gnn_dim_node,gnn_dim_edge))
avg_time = 0 avg_time = 0
if use_cuda: if use_cuda:
model = GeneralModel(gnn_dim_node, gnn_dim_edge, sample_param, memory_param, gnn_param, train_param,graph.ids.shape[0],mailbox).cuda() model = GeneralModel(gnn_dim_node, gnn_dim_edge, sample_param, memory_param, gnn_param, train_param,graph.ids.shape[0],mailbox,ada_param=ada_param).cuda()
device = torch.device('cuda') device = torch.device('cuda')
else: else:
model = GeneralModel(gnn_dim_node, gnn_dim_edge, sample_param, memory_param, gnn_param, train_param,graph.ids.shape[0],mailbox) model = GeneralModel(gnn_dim_node, gnn_dim_edge, sample_param, memory_param, gnn_param, train_param,graph.ids.shape[0],mailbox,ada_param=ada_param)
device = torch.device('cpu') device = torch.device('cpu')
model = DDP(model,find_unused_parameters=True) model = DDP(model,find_unused_parameters=True)
def count_parameters(model): def count_parameters(model):
...@@ -365,6 +370,7 @@ def main(): ...@@ -365,6 +370,7 @@ def main():
def eval(mode='val'): def eval(mode='val'):
model.eval() model.eval()
ada_param.eval()
aps = list() aps = list()
aucs_mrrs = list() aucs_mrrs = list()
if mode == 'val': if mode == 'val':
...@@ -383,7 +389,7 @@ def main(): ...@@ -383,7 +389,7 @@ def main():
signal = torch.tensor([0],dtype = int,device = device) signal = torch.tensor([0],dtype = int,device = device)
batch_cnt = 0 batch_cnt = 0
for roots,mfgs,metadata in loader: for roots,mfgs,metadata in loader:
print(batch_cnt) #print(batch_cnt)
batch_cnt = batch_cnt+1 batch_cnt = batch_cnt+1
""" """
if ctx.memory_group == 0: if ctx.memory_group == 0:
...@@ -486,6 +492,7 @@ def main(): ...@@ -486,6 +492,7 @@ def main():
val_list = [] val_list = []
loss_list = [] loss_list = []
for e in range(train_param['epoch']): for e in range(train_param['epoch']):
ada_param.train()
model.module.memory_updater.empty_cache() model.module.memory_updater.empty_cache()
tt._zero() tt._zero()
torch.cuda.synchronize() torch.cuda.synchronize()
...@@ -517,6 +524,7 @@ def main(): ...@@ -517,6 +524,7 @@ def main():
start = time_count.start_gpu() start = time_count.start_gpu()
for roots,mfgs,metadata in trainloader: for roots,mfgs,metadata in trainloader:
end = time_count.elapsed_event(start) end = time_count.elapsed_event(start)
ada_param.last_start_event_memory_update = ada_param.start_event()
#print('time {}'.format(end)) #print('time {}'.format(end))
#print('rank is {} batch max ts is {} batch min ts is {}'.format(dist.get_rank(),roots.ts.min(),roots.ts.max())) #print('rank is {} batch max ts is {} batch min ts is {}'.format(dist.get_rank(),roots.ts.min(),roots.ts.max()))
b_cnt = b_cnt + 1 b_cnt = b_cnt + 1
...@@ -556,7 +564,7 @@ def main(): ...@@ -556,7 +564,7 @@ def main():
#print(time_count.elapsed_event(t2)) #print(time_count.elapsed_event(t2))
loss = creterion(pred_pos, torch.ones_like(pred_pos)) loss = creterion(pred_pos, torch.ones_like(pred_pos))
if args.local_neg_sample is False: if args.local_neg_sample is False:
weight = torch.where(DistIndex(mfgs[0][0].srcdata['ID'][metadata['dst_neg_index']]).part == torch.distributed.get_rank(),ones*train_ratio_pos,ones*train_ratio_neg).reshape(-1,1) weight = torch.where(DistIndex(mfgs[0][0].srcdata['ID'][metadata['dst_neg_index']]).part == torch.distributed.get_rank(),ones*train_neg_sampler.train_ratio_pos,ones*train_neg_sampler.train_ratio_neg).reshape(-1,1)
neg_creterion = torch.nn.BCEWithLogitsLoss(weight) neg_creterion = torch.nn.BCEWithLogitsLoss(weight)
loss += neg_creterion(pred_neg, torch.zeros_like(pred_neg)) loss += neg_creterion(pred_neg, torch.zeros_like(pred_neg))
else: else:
...@@ -573,10 +581,14 @@ def main(): ...@@ -573,10 +581,14 @@ def main():
#y_true = torch.cat([torch.ones(pred_pos.size(0)), torch.zeros(pred_neg.size(0))], dim=0) #y_true = torch.cat([torch.ones(pred_pos.size(0)), torch.zeros(pred_neg.size(0))], dim=0)
#train_aps.append(average_precision_score(y_true, y_pred.detach().numpy())) #train_aps.append(average_precision_score(y_true, y_pred.detach().numpy()))
#torch.cuda.synchronize() #torch.cuda.synchronize()
mailbox.update_shared() mailbox.update_shared()
mailbox.update_p2p_mem() mailbox.update_p2p_mem()
mailbox.update_p2p_mail() mailbox.update_p2p_mail()
start = time_count.start_gpu() start = time_count.start_gpu()
ada_param.update_gnn_aggregate_time(ada_param.last_start_event_gnn_aggregate)
ada_param.update_parameter()
#torch.cuda.empty_cache() #torch.cuda.empty_cache()
""" """
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment