Commit df36000d by zlj

fix train_remote_pos and train_remote_neg delay

parent 7ae21c05
......@@ -10,7 +10,7 @@ partitions="8"
node_per="4"
nnodes="2"
node_rank="0"
probability_params=("0.1")
probability_params=("1")
sample_type_params=("boundery_recent_decay")
#sample_type_params=("recent" "boundery_recent_decay") #"boundery_recent_uniform")
#memory_type=("all_update" "p2p" "all_reduce" "historical" "local")
......
......@@ -563,22 +563,22 @@ def main():
ones = torch.ones(metadata['dst_neg_index'].shape[0],device = model.device,dtype=torch.float)
pred_pos, pred_neg = model(mfgs,metadata,neg_samples=args.neg_samples,async_param = param)
ada_param.update_gnn_aggregate_time(ada_param.last_start_event_gnn_aggregate)
if len(trainloader.result_queue) > 0:
batch_data,dist_nid,dist_eid,edge_feat,node_feat0 = trainloader.result_queue[0]
edge_feat[1].wait()
node_feat0[1].wait()
if ada_param is not None:
ada_param.update_fetch_time(ada_param.last_start_event_fetch)
ada_param.update_parameter()
#print(time_count.elapsed_event(t2))
loss = creterion(pred_pos, torch.ones_like(pred_pos))
if args.local_neg_sample is False:
weight = torch.where(DistIndex(mfgs[0][0].srcdata['ID'][metadata['dst_neg_index']]).part == torch.distributed.get_rank(),ones*train_neg_sampler.train_ratio_pos,ones*train_neg_sampler.train_ratio_neg).reshape(-1,1)
weight = torch.where(DistIndex(mfgs[0][0].srcdata['ID'][metadata['dst_neg_index']]).part == torch.distributed.get_rank(),ones*metadata['train_ratio_pos'],ones*metadata['train_ratio_neg']).reshape(-1,1)
neg_creterion = torch.nn.BCEWithLogitsLoss(weight)
loss += neg_creterion(pred_neg, torch.zeros_like(pred_neg))
else:
loss += creterion(pred_neg, torch.zeros_like(pred_neg))
total_loss += float(loss.item())
if len(trainloader.result_queue) > 0:
_,_,_,edge_feat,node_feat0 = trainloader.result_queue[0]
edge_feat[1].wait()
node_feat0[1].wait()
if ada_param is not None:
ada_param.update_fetch_time(ada_param.last_start_event_fetch)
ada_param.update_parameter()
#mailbox.handle_last_async()
#trainloader.async_feature()
#torch.cuda.synchronize()
......
......@@ -2,6 +2,8 @@ import yaml
import numpy as np
import torch
import math
from starrygl.distributed.context import DistributedContext
def parse_config(f):
conf = yaml.safe_load(open(f, 'r'))
sample_param = conf['sampling'][0]
......@@ -135,7 +137,7 @@ class AdaParameter:
self.count_gnn_aggregate = 0
def update_parameter(self):
print('beta is {} alpha is {}\n'.format(self.beta,self.alpha))
#print('beta is {} alpha is {}\n'.format(self.beta,self.alpha))
#if self.count_fetch == 0 or self.count_memory_sync == 0 or self.count_memory_update == 0 or self.count_gnn_aggregate == 0:
# return
if self.end_event_fetch is None or self.end_event_memory_sync is None or self.end_event_memory_update is None or self.end_event_gnn_aggregate is None:
......@@ -160,8 +162,15 @@ class AdaParameter:
self.alpha = (2-math.pow((2-self.alpha),average_memory_update_time/average_memory_sync_time * (1 + self.wait_threshold)))
self.beta = max(min(self.beta, self.max_beta),self.min_beta)
self.alpha = max(min(self.alpha, self.max_alpha),self.min_alpha)
print('gnn aggregate {} fetch {} memory sync {} memory update {}'.format(average_gnn_aggregate,average_fetch,average_memory_sync_time,average_memory_update_time))
print('beta is {} alpha is {}\n'.format(self.beta,self.alpha))
ctx = DistributedContext.get_default_context()
beta_comm=torch.tensor([self.beta])
torch.distributed.all_reduce(beta_comm,group=ctx.gloo_group)
self.beta = beta_comm[0].item()
alpha_comm=torch.tensor([self.alpha])
torch.distributed.all_reduce(alpha_comm,group=ctx.gloo_group)
self.alpha = alpha_comm[0].item()
#print('gnn aggregate {} fetch {} memory sync {} memory update {}'.format(average_gnn_aggregate,average_fetch,average_memory_sync_time,average_memory_update_time))
#print('beta is {} alpha is {}\n'.format(self.beta,self.alpha))
#self.reset_time()
#log(2-a1) = log(2-a2) * t1/t2 * (1 + wait_threshold)
#2-a1 = (2-a2)^(t1/t2 * (1 + wait_threshold))
......
......@@ -332,6 +332,9 @@ class NeighborSampler(BaseSampler):
metadata['seed_ts'] = seed_ts
metadata['src_pos_index']=src_pos_index
metadata['dst_pos_index']=dst_pos_index
if 'train_ratio_pos' in neg_sampling.__dict__:
metadata['train_ratio_pos'] = neg_sampling.train_ratio_pos
metadata['train_ratio_neg'] = neg_sampling.train_ratio_neg
if neg_sampling is not None :
metadata['dst_neg_index'] = dst_neg_index
if neg_sampling.is_triplet() or neg_sampling.is_tgbtriplet():
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment