Commit 572f53f7 by zlj

add adaptive adjustment

parent 1119dd9f
......@@ -3,23 +3,23 @@
# 定义数组变量
seed=$1
addr="192.168.1.107"
partition_params=("ours" )
partition_params=("ours")
#"metis" "ldg" "random")
#("ours" "metis" "ldg" "random")
partitions="8"
partitions="4"
node_per="4"
nnodes="2"
nnodes="1"
node_rank="0"
probability_params=("0.1")
sample_type_params=("recent")
sample_type_params=("boundery_recent_decay")
#sample_type_params=("recent" "boundery_recent_decay") #"boundery_recent_uniform")
#memory_type=("all_update" "p2p" "all_reduce" "historical" "local")
memory_type=("all_update")
memory_type=("historical")
#"historical")
#memory_type=("local" "all_update" "historical" "all_reduce")
shared_memory_ssim=("0.3")
#data_param=("WIKI" "REDDIT" "LASTFM" "WikiTalk")
data_param=("LASTFM" "WikiTalk" "StackOverflow" "GDELT")
data_param=("WIKI" "LASTFM")
#"GDELT")
#data_param=("WIKI" "REDDIT" "LASTFM" "DGraphFin" "WikiTalk" "StackOverflow")
#data_param=("WIKI" "REDDIT" "LASTFM" "WikiTalk" "StackOverflow")
......@@ -32,9 +32,9 @@ data_param=("LASTFM" "WikiTalk" "StackOverflow" "GDELT")
#seed=(( RANDOM % 1000000 + 1 ))
mkdir -p all_"$seed"
for data in "${data_param[@]}"; do
model="APAN_large"
model="TGN_large"
if [ "$data" = "WIKI" ] || [ "$data" = "REDDIT" ] || [ "$data" = "LASTFM" ]; then
model="APAN"
model="TGN"
fi
#model="APAN"
mkdir all_"$seed"/"$data"
......@@ -165,9 +165,9 @@ for data in "${data_param[@]}"; do
done
for data in "${data_param[@]}"; do
model="TGN_large"
model="APAN_large"
if [ "$data" = "WIKI" ] || [ "$data" = "REDDIT" ] || [ "$data" = "LASTFM" ]; then
model="TGN"
model="APAN"
#continue
fi
#model="APAN"
......
......@@ -18,7 +18,7 @@ from pathlib import Path
from pathlib import Path
from starrygl.module.utils import parse_config
from starrygl.module.utils import AdaParameter, parse_config
from starrygl.sample.cache.fetch_cache import FetchFeatureCache
from starrygl.sample.graph_core import DataSet, DistributedGraphStore, TemporalNeighborSampleGraph
from starrygl.module.utils import parse_config, EarlyStopMonitor
......@@ -232,18 +232,22 @@ def main():
print(policy)
if policy == 'recent':
policy_train = args.sample_type#'boundery_recent_decay'
if policy_train == 'recent':
ada_param = AdaParameter(init_alpha=args.shared_memory_ssim, init_beta=args.probability,min_beta=1)
else:
ada_param = AdaParameter(init_alpha=args.shared_memory_ssim, init_beta=args.probability)
else:
policy_train = policy
if memory_param['type'] != 'none':
mailbox = SharedMailBox(graph.ids.shape[0], memory_param, dim_edge_feat = graph.efeat.shape[1] if graph.efeat is not None else 0,
shared_nodes_index=graph.shared_nids_list[ctx.memory_group_rank],device = torch.device('cuda:{}'.format(local_rank)),
start_historical=(args.memory_type=='historical'),
shared_ssim=args.shared_memory_ssim)
ada_param = ada_param)
else:
mailbox = None
sampler = NeighborSampler(num_nodes=graph.num_nodes, num_layers=num_layers, fanout=fanout,graph_data=sample_graph, workers=10,policy = policy_train, graph_name = "train",local_part=dist.get_rank(),edge_part=DistIndex(graph.eids_mapper).part,node_part=DistIndex(graph.nids_mapper).part,probability=args.probability,no_neg=no_neg)
eval_sampler = NeighborSampler(num_nodes=graph.num_nodes, num_layers=num_layers, fanout=fanout,graph_data=eval_sample_graph, workers=10,policy = policy_train, graph_name = "eval",local_part=dist.get_rank(),edge_part=DistIndex(graph.eids_mapper).part,node_part=DistIndex(graph.nids_mapper).part,probability=args.probability,no_neg=no_neg)
sampler = NeighborSampler(num_nodes=graph.num_nodes, num_layers=num_layers, fanout=fanout,graph_data=sample_graph, workers=10,policy = policy_train, graph_name = "train",local_part=dist.get_rank(),edge_part=DistIndex(graph.eids_mapper).part,node_part=DistIndex(graph.nids_mapper).part,ada_param=ada_param,no_neg=no_neg)
eval_sampler = NeighborSampler(num_nodes=graph.num_nodes, num_layers=num_layers, fanout=fanout,graph_data=eval_sample_graph, workers=10,policy = policy_train, graph_name = "eval",local_part=dist.get_rank(),edge_part=DistIndex(graph.eids_mapper).part,node_part=DistIndex(graph.nids_mapper).part,ada_param=ada_param,no_neg=no_neg)
train_data = torch.masked_select(graph.edge_index,train_mask.to(graph.edge_index.device)).reshape(2,-1)
train_ts = torch.masked_select(graph.ts,train_mask.to(graph.edge_index.device))
......@@ -276,15 +280,15 @@ def main():
else:
#train_neg_sampler = LocalNegativeSampling('triplet',amount = args.neg_samples,dst_node_list = full_dst.unique())
train_neg_sampler = LocalNegativeSampling('triplet',amount = args.neg_samples,dst_node_list = full_dst.unique(),local_mask=(DistIndex(graph.nids_mapper[full_dst.unique()].to('cpu')).part == dist.get_rank()),prob=args.probability)
remote_ratio = train_neg_sampler.local_dst.shape[0] / train_neg_sampler.dst_node_list.shape[0]
train_neg_sampler = LocalNegativeSampling('triplet',amount = args.neg_samples,dst_node_list = full_dst.unique(),local_mask=(DistIndex(graph.nids_mapper[full_dst.unique()].to('cpu')).part == dist.get_rank()),ada_param=ada_param)
#remote_ratio = train_neg_sampler.local_dst.shape[0] / train_neg_sampler.dst_node_list.shape[0]
#train_ratio_pos = (1 - args.probability) + args.probability * remote_ratio
#train_ratio_neg = args.probability * (1-remote_ratio)
train_ratio_pos = 1.0/(1-args.probability+ args.probability * remote_ratio) if ((args.probability <1) & (args.probability > 0)) else 1
train_ratio_neg = 1.0/(args.probability*remote_ratio) if ((args.probability <1) & (args.probability > 0)) else 1
#train_ratio_pos = 1.0/(1-args.probability+ args.probability * remote_ratio) if ((args.probability <1) & (args.probability > 0)) else 1
#train_ratio_neg = 1.0/(args.probability*remote_ratio) if ((args.probability <1) & (args.probability > 0)) else 1
print(train_neg_sampler.dst_node_list)
neg_sampler = LocalNegativeSampling('triplet',amount= neg_samples,dst_node_list = full_dst.unique(),seed=args.seed)
trainloader = DistributedDataLoader(graph,eval_train_data,sampler = sampler,
sampler_fn = SAMPLE_TYPE.SAMPLE_FROM_TEMPORAL_EDGES,
neg_sampler=train_neg_sampler,
......@@ -299,7 +303,8 @@ def main():
use_local_feature = False,
device = torch.device('cuda:{}'.format(local_rank)),
probability=args.probability,
reversed = (gnn_param['arch'] == 'identity')
reversed = (gnn_param['arch'] == 'identity'),
ada_param=ada_param
)
eval_trainloader = DistributedDataLoader(graph,eval_train_data,sampler = eval_sampler,
......@@ -350,10 +355,10 @@ def main():
print('dim_node {} dim_edge {}\n'.format(gnn_dim_node,gnn_dim_edge))
avg_time = 0
if use_cuda:
model = GeneralModel(gnn_dim_node, gnn_dim_edge, sample_param, memory_param, gnn_param, train_param,graph.ids.shape[0],mailbox).cuda()
model = GeneralModel(gnn_dim_node, gnn_dim_edge, sample_param, memory_param, gnn_param, train_param,graph.ids.shape[0],mailbox,ada_param=ada_param).cuda()
device = torch.device('cuda')
else:
model = GeneralModel(gnn_dim_node, gnn_dim_edge, sample_param, memory_param, gnn_param, train_param,graph.ids.shape[0],mailbox)
model = GeneralModel(gnn_dim_node, gnn_dim_edge, sample_param, memory_param, gnn_param, train_param,graph.ids.shape[0],mailbox,ada_param=ada_param)
device = torch.device('cpu')
model = DDP(model,find_unused_parameters=True)
def count_parameters(model):
......@@ -365,6 +370,7 @@ def main():
def eval(mode='val'):
model.eval()
ada_param.eval()
aps = list()
aucs_mrrs = list()
if mode == 'val':
......@@ -383,7 +389,7 @@ def main():
signal = torch.tensor([0],dtype = int,device = device)
batch_cnt = 0
for roots,mfgs,metadata in loader:
print(batch_cnt)
#print(batch_cnt)
batch_cnt = batch_cnt+1
"""
if ctx.memory_group == 0:
......@@ -486,6 +492,7 @@ def main():
val_list = []
loss_list = []
for e in range(train_param['epoch']):
ada_param.train()
model.module.memory_updater.empty_cache()
tt._zero()
torch.cuda.synchronize()
......@@ -517,6 +524,7 @@ def main():
start = time_count.start_gpu()
for roots,mfgs,metadata in trainloader:
end = time_count.elapsed_event(start)
ada_param.last_start_event_memory_update = ada_param.start_event()
#print('time {}'.format(end))
#print('rank is {} batch max ts is {} batch min ts is {}'.format(dist.get_rank(),roots.ts.min(),roots.ts.max()))
b_cnt = b_cnt + 1
......@@ -556,7 +564,7 @@ def main():
#print(time_count.elapsed_event(t2))
loss = creterion(pred_pos, torch.ones_like(pred_pos))
if args.local_neg_sample is False:
weight = torch.where(DistIndex(mfgs[0][0].srcdata['ID'][metadata['dst_neg_index']]).part == torch.distributed.get_rank(),ones*train_ratio_pos,ones*train_ratio_neg).reshape(-1,1)
weight = torch.where(DistIndex(mfgs[0][0].srcdata['ID'][metadata['dst_neg_index']]).part == torch.distributed.get_rank(),ones*train_neg_sampler.train_ratio_pos,ones*train_neg_sampler.train_ratio_neg).reshape(-1,1)
neg_creterion = torch.nn.BCEWithLogitsLoss(weight)
loss += neg_creterion(pred_neg, torch.zeros_like(pred_neg))
else:
......@@ -573,10 +581,14 @@ def main():
#y_true = torch.cat([torch.ones(pred_pos.size(0)), torch.zeros(pred_neg.size(0))], dim=0)
#train_aps.append(average_precision_score(y_true, y_pred.detach().numpy()))
#torch.cuda.synchronize()
mailbox.update_shared()
mailbox.update_p2p_mem()
mailbox.update_p2p_mail()
start = time_count.start_gpu()
ada_param.update_gnn_aggregate_time(ada_param.last_start_event_gnn_aggregate)
ada_param.update_parameter()
#torch.cuda.empty_cache()
"""
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment