Commit df36000d by zlj

fix train_remote_pos and train_remote_neg delay

parent 7ae21c05
...@@ -10,7 +10,7 @@ partitions="8" ...@@ -10,7 +10,7 @@ partitions="8"
node_per="4" node_per="4"
nnodes="2" nnodes="2"
node_rank="0" node_rank="0"
probability_params=("0.1") probability_params=("1")
sample_type_params=("boundery_recent_decay") sample_type_params=("boundery_recent_decay")
#sample_type_params=("recent" "boundery_recent_decay") #"boundery_recent_uniform") #sample_type_params=("recent" "boundery_recent_decay") #"boundery_recent_uniform")
#memory_type=("all_update" "p2p" "all_reduce" "historical" "local") #memory_type=("all_update" "p2p" "all_reduce" "historical" "local")
......
...@@ -563,22 +563,22 @@ def main(): ...@@ -563,22 +563,22 @@ def main():
ones = torch.ones(metadata['dst_neg_index'].shape[0],device = model.device,dtype=torch.float) ones = torch.ones(metadata['dst_neg_index'].shape[0],device = model.device,dtype=torch.float)
pred_pos, pred_neg = model(mfgs,metadata,neg_samples=args.neg_samples,async_param = param) pred_pos, pred_neg = model(mfgs,metadata,neg_samples=args.neg_samples,async_param = param)
ada_param.update_gnn_aggregate_time(ada_param.last_start_event_gnn_aggregate) ada_param.update_gnn_aggregate_time(ada_param.last_start_event_gnn_aggregate)
if len(trainloader.result_queue) > 0:
batch_data,dist_nid,dist_eid,edge_feat,node_feat0 = trainloader.result_queue[0]
edge_feat[1].wait()
node_feat0[1].wait()
if ada_param is not None:
ada_param.update_fetch_time(ada_param.last_start_event_fetch)
ada_param.update_parameter()
#print(time_count.elapsed_event(t2)) #print(time_count.elapsed_event(t2))
loss = creterion(pred_pos, torch.ones_like(pred_pos)) loss = creterion(pred_pos, torch.ones_like(pred_pos))
if args.local_neg_sample is False: if args.local_neg_sample is False:
weight = torch.where(DistIndex(mfgs[0][0].srcdata['ID'][metadata['dst_neg_index']]).part == torch.distributed.get_rank(),ones*train_neg_sampler.train_ratio_pos,ones*train_neg_sampler.train_ratio_neg).reshape(-1,1) weight = torch.where(DistIndex(mfgs[0][0].srcdata['ID'][metadata['dst_neg_index']]).part == torch.distributed.get_rank(),ones*metadata['train_ratio_pos'],ones*metadata['train_ratio_neg']).reshape(-1,1)
neg_creterion = torch.nn.BCEWithLogitsLoss(weight) neg_creterion = torch.nn.BCEWithLogitsLoss(weight)
loss += neg_creterion(pred_neg, torch.zeros_like(pred_neg)) loss += neg_creterion(pred_neg, torch.zeros_like(pred_neg))
else: else:
loss += creterion(pred_neg, torch.zeros_like(pred_neg)) loss += creterion(pred_neg, torch.zeros_like(pred_neg))
total_loss += float(loss.item()) total_loss += float(loss.item())
if len(trainloader.result_queue) > 0:
_,_,_,edge_feat,node_feat0 = trainloader.result_queue[0]
edge_feat[1].wait()
node_feat0[1].wait()
if ada_param is not None:
ada_param.update_fetch_time(ada_param.last_start_event_fetch)
ada_param.update_parameter()
#mailbox.handle_last_async() #mailbox.handle_last_async()
#trainloader.async_feature() #trainloader.async_feature()
#torch.cuda.synchronize() #torch.cuda.synchronize()
......
...@@ -2,6 +2,8 @@ import yaml ...@@ -2,6 +2,8 @@ import yaml
import numpy as np import numpy as np
import torch import torch
import math import math
from starrygl.distributed.context import DistributedContext
def parse_config(f): def parse_config(f):
conf = yaml.safe_load(open(f, 'r')) conf = yaml.safe_load(open(f, 'r'))
sample_param = conf['sampling'][0] sample_param = conf['sampling'][0]
...@@ -135,7 +137,7 @@ class AdaParameter: ...@@ -135,7 +137,7 @@ class AdaParameter:
self.count_gnn_aggregate = 0 self.count_gnn_aggregate = 0
def update_parameter(self): def update_parameter(self):
print('beta is {} alpha is {}\n'.format(self.beta,self.alpha)) #print('beta is {} alpha is {}\n'.format(self.beta,self.alpha))
#if self.count_fetch == 0 or self.count_memory_sync == 0 or self.count_memory_update == 0 or self.count_gnn_aggregate == 0: #if self.count_fetch == 0 or self.count_memory_sync == 0 or self.count_memory_update == 0 or self.count_gnn_aggregate == 0:
# return # return
if self.end_event_fetch is None or self.end_event_memory_sync is None or self.end_event_memory_update is None or self.end_event_gnn_aggregate is None: if self.end_event_fetch is None or self.end_event_memory_sync is None or self.end_event_memory_update is None or self.end_event_gnn_aggregate is None:
...@@ -160,8 +162,15 @@ class AdaParameter: ...@@ -160,8 +162,15 @@ class AdaParameter:
self.alpha = (2-math.pow((2-self.alpha),average_memory_update_time/average_memory_sync_time * (1 + self.wait_threshold))) self.alpha = (2-math.pow((2-self.alpha),average_memory_update_time/average_memory_sync_time * (1 + self.wait_threshold)))
self.beta = max(min(self.beta, self.max_beta),self.min_beta) self.beta = max(min(self.beta, self.max_beta),self.min_beta)
self.alpha = max(min(self.alpha, self.max_alpha),self.min_alpha) self.alpha = max(min(self.alpha, self.max_alpha),self.min_alpha)
print('gnn aggregate {} fetch {} memory sync {} memory update {}'.format(average_gnn_aggregate,average_fetch,average_memory_sync_time,average_memory_update_time)) ctx = DistributedContext.get_default_context()
print('beta is {} alpha is {}\n'.format(self.beta,self.alpha)) beta_comm=torch.tensor([self.beta])
torch.distributed.all_reduce(beta_comm,group=ctx.gloo_group)
self.beta = beta_comm[0].item()
alpha_comm=torch.tensor([self.alpha])
torch.distributed.all_reduce(alpha_comm,group=ctx.gloo_group)
self.alpha = alpha_comm[0].item()
#print('gnn aggregate {} fetch {} memory sync {} memory update {}'.format(average_gnn_aggregate,average_fetch,average_memory_sync_time,average_memory_update_time))
#print('beta is {} alpha is {}\n'.format(self.beta,self.alpha))
#self.reset_time() #self.reset_time()
#log(2-a1 ) = log(2-a2) * t1/t2 * (1 + wait_threshold) #log(2-a1 ) = log(2-a2) * t1/t2 * (1 + wait_threshold)
#2-a1 = 2-a2 ^(t1/t2 * (1 + wait_threshold)) #2-a1 = 2-a2 ^(t1/t2 * (1 + wait_threshold))
......
...@@ -332,6 +332,9 @@ class NeighborSampler(BaseSampler): ...@@ -332,6 +332,9 @@ class NeighborSampler(BaseSampler):
metadata['seed_ts'] = seed_ts metadata['seed_ts'] = seed_ts
metadata['src_pos_index']=src_pos_index metadata['src_pos_index']=src_pos_index
metadata['dst_pos_index']=dst_pos_index metadata['dst_pos_index']=dst_pos_index
if 'train_ratio_pos' in neg_sampling.__dict__:
metadata['train_ratio_pos'] = neg_sampling.train_ratio_pos
metadata['train_ratio_neg'] = neg_sampling.train_ratio_neg
if neg_sampling is not None : if neg_sampling is not None :
metadata['dst_neg_index'] = dst_neg_index metadata['dst_neg_index'] = dst_neg_index
if neg_sampling.is_triplet() or neg_sampling.is_tgbtriplet(): if neg_sampling.is_triplet() or neg_sampling.is_tgbtriplet():
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment