Commit 869842ae by zhlj

fix special setting for GDELT

parent 2fd762b0
@@ -196,7 +196,10 @@ def main():
     ctx = DistributedContext.init(backend="nccl", use_gpu=True,memory_group_num=1,cache_use_rpc=True)
     torch.set_num_threads(10)
     device_id = torch.cuda.current_device()
-    graph,full_sampler_graph,train_mask,val_mask,test_mask,full_train_mask,cache_route = load_from_speed(args.dataname,seed=123457,top=args.topk,sampler_graph_add_rev=True, feature_device=torch.device('cuda:{}'.format(ctx.local_rank)),partition=args.partition)#torch.device('cpu'))
+    if args.dataname == 'GDELT' and dist.get_world_size() <= 4:
+        graph,full_sampler_graph,train_mask,val_mask,test_mask,full_train_mask,cache_route = load_from_speed(args.dataname,seed=123457,top=args.topk,sampler_graph_add_rev=True, feature_device=torch.device('cpu'),partition=args.partition)
+    else:
+        graph,full_sampler_graph,train_mask,val_mask,test_mask,full_train_mask,cache_route = load_from_speed(args.dataname,seed=123457,top=args.topk,sampler_graph_add_rev=True, feature_device=torch.device('cuda:{}'.format(ctx.local_rank)),partition=args.partition)
     if(args.dataname=='GDELT'):
         train_param['epoch'] = 10
     #torch.autograd.set_detect_anomaly(True)
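For readers skimming the hunk above: the new branch only decides where the loaded node/edge features live. When training GDELT on at most four ranks the features stay in host memory, otherwise they go to the local GPU; the commit itself does not state the reason, but the likely motivation is that GDELT's feature table is too large to cache per GPU on small runs. A minimal sketch of that selection logic, using a hypothetical helper name pick_feature_device (not part of the repository), could look like:

import torch
import torch.distributed as dist

def pick_feature_device(dataname: str, local_rank: int) -> torch.device:
    # Mirror the condition added above: GDELT on <= 4 ranks keeps features on CPU,
    # every other configuration places them on this rank's GPU.
    if dataname == 'GDELT' and dist.get_world_size() <= 4:
        return torch.device('cpu')
    return torch.device('cuda:{}'.format(local_rank))

# Usage inside main(), after DistributedContext.init() has set up the process group:
#   feature_device = pick_feature_device(args.dataname, ctx.local_rank)
#   ... = load_from_speed(..., feature_device=feature_device, partition=args.partition)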
@@ -257,10 +257,10 @@ class DistributedDataLoader:
                 pass
             batch_data,dist_nid,dist_eid = self.result_queue[0].result()
             b = batch_data[1][0][0]
-            #self.remote_node += (DistIndex(dist_nid).part != dist.get_rank()).sum()
-            #self.local_node += (DistIndex(dist_nid).part == dist.get_rank()).sum()
-            #self.remote_edge += (DistIndex(dist_eid).part != dist.get_rank()).sum()
-            #self.local_edge += (DistIndex(dist_eid).part == dist.get_rank()).sum()
+            self.remote_node += (DistIndex(dist_nid).part != dist.get_rank()).sum().item()
+            self.local_node += (DistIndex(dist_nid).part == dist.get_rank()).sum().item()
+            self.remote_edge += (DistIndex(dist_eid).part != dist.get_rank()).sum().item()
+            self.local_edge += (DistIndex(dist_eid).part == dist.get_rank()).sum().item()
             #self.remote_root += (DistIndex(dist_nid[b.srcdata['__ID'][:self.batch_size*2]]).part != dist.get_rank()).sum()
             #self.local_root += (DistIndex(dist_nid[b.srcdata['__ID'][:self.batch_size*2]]).part == dist.get_rank()).sum()
             #torch.cuda.synchronize(stream)
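The second hunk enables the node/edge locality counters that were previously commented out and appends .item() to each sum. Each comparison yields a boolean tensor, and .sum() on it returns a 0-dim tensor on the same device, so without .item() the counters would accumulate tensors (keeping them alive on the GPU) rather than plain Python integers. A standalone illustration with made-up values, not taken from the loader:

import torch

# Stand-in for DistIndex(dist_nid).part: the owning partition of each node.
parts = torch.tensor([0, 0, 1, 1, 2])
rank = 1

as_tensor = (parts != rank).sum()      # tensor(3): 0-dim tensor on parts' device
as_int = (parts != rank).sum().item()  # 3: plain Python int, cheap to accumulate

print(type(as_tensor).__name__, type(as_int).__name__)  # Tensor int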