Commit a70d1b3b by zlj

add sample module

parent abce8f53
def foo():
    a = 1
    def fa():
        nonlocal a  # without this, a += 1 would make a local to fa and raise UnboundLocalError
        print(a)
        a += 1
        print(a)
    def fb(a):
        def apply():
            print(a)
        return apply
    fc = lambda: print(a)
    return fa, fb(a), fc

fa, fb, fc = foo()
fa()
fb()
fc()
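
# Expected output of the demo above, given the nonlocal fix:
#   fa() prints 1 then 2 (fa mutates the enclosing a)
#   fb() prints 1 (apply closed over fb's parameter, bound when fb(a) ran with a == 1)
#   fc() prints 2 (the lambda reads the enclosing a, which fa has incremented)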
import argparse
import os
import sys
from os.path import abspath, join, dirname
from starrygl.distributed.context import DistributedContext
from starrygl.distributed.utils import DistIndex
from starrygl.module.modules import GeneralModel
from starrygl.module.utils import parse_config
from starrygl.sample.graph_core import DataSet, GraphData, TemporalNeighborSampleGraph
from starrygl.sample.memory.shared_mailbox import SharedMailBox
from starrygl.sample.sample_core.base import NegativeSampling
from starrygl.sample.sample_core.neighbor_sampler import NeighborSampler
from starrygl.sample.part_utils.partition_tgnn import partition_load
import torch
import time
import torch.nn.functional as F
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import init_process_group, destroy_process_group
from starrygl.sample.data_loader import DistributedDataLoader
from starrygl.sample.batch_data import SAMPLE_TYPE
"""
test command
python test.py --world_size 2 --rank 0
--world_size', default=4, type=int, metavar='W',
help='number of workers')
parser.add_argument('--rank', default=0, type=int, metavar='W',
help='rank of the worker')
parser.add_argument('--log_interval', type=int, default=10, metavar='N',
help='interval between training status logs')
parser.add_argument('--gamma', type=float, default=0.99, metavar='G',
help='how much to value future rewards')
parser.add_argument('--seed', type=int, default=1, metavar='S',
help='random seed for reproducibility')
parser.add_argument('--num_sampler', type=int, default=10, metavar='S',
help='number of samplers')
parser.add_argument('--queue_size', type=int, default=10, metavar='S',
help='sampler queue size')
"""
parser = argparse.ArgumentParser(
    description="RPC Reinforcement Learning Example",
    formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument('--rank', default=0, type=int, metavar='W',
                    help='rank of the worker')
parser.add_argument('--world_size', default=1, type=int, metavar='W',
                    help='number of workers')
args = parser.parse_args()
from sklearn.metrics import average_precision_score, roc_auc_score
import random
import dgl
import numpy as np
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.rank)
os.environ["RANK"] = str(args.rank)
os.environ["WORLD_SIZE"] = str(args.world_size)
os.environ["LOCAL_RANK"] = str(0)
os.environ["MASTER_ADDR"] = '127.0.0.1'
os.environ["MASTER_PORT"] = '9337'

def seed_everything(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

seed_everything(1234)
def main():
    use_cuda = True
    sample_param, memory_param, gnn_param, train_param = parse_config('./config/TGN.yml')
    torch.set_num_threads(12)
    ctx = DistributedContext.init(backend="nccl", use_gpu=True)
    device_id = torch.cuda.current_device()
    print('use cuda on', device_id)
    pdata = partition_load("./dataset/here/WIKI", algo="metis_for_tgnn")
    graph = GraphData(pdata=pdata)
    #dist.barrier()
    #for i in range(100):
    #    print(i)
    dist.barrier()
    # the owning partition id lives in the top 16 bits of each mapped edge id
    idx = ((graph.eids_mapper >> 48).int() & 0xFFFF)
    print((idx == 0).nonzero().shape, (idx == 1).nonzero().shape)
    t1 = time.time()
    """
    fut = []
    for i in range(1000):
        #print(i)
        out = graph.edge_attr.index_select(graph.eids_mapper[(idx == 0) | (idx == 1)].to('cuda'))
        fut.append(out)
        #out.wait()
        #out.value()
        if i > 0 and i % 100 == 0:
            f = torch.futures.collect_all(fut)
            f.wait()
            f.value()
            fut = []
    """
    partptr = torch.tensor([((i & 0xFFFF) << 48) for i in range(3)], device='cuda')
    for i in range(1000):
        if i % 100 == 0:
            idx = graph.eids_mapper.to('cuda')
            idx, inv = idx.unique(return_inverse=True)
            ind = torch.searchsorted(idx, partptr, right=False)
            counts = ind[1:] - ind[:-1]  # how many of our ids each partition owns
            gatherlen = torch.empty([2], dtype=torch.long, device='cuda')
            dist.all_to_all_single(gatherlen, counts)
            query_idx = torch.empty([gatherlen.sum()], dtype=torch.long, device='cuda')
            input_s = list(counts)
            output_s = list(gatherlen)
            dist.all_to_all_single(query_idx, idx, output_s, input_s)
            input_f = graph.edge_attr.accessor.data[DistIndex(query_idx).loc]
            f = torch.empty([idx.shape[0], graph.edge_attr.accessor.data.shape[1]], dtype=torch.float, device='cuda')
            dist.all_to_all_single(f, input_f, input_s, output_s)
    torch.cuda.synchronize()
    t2 = time.time() - t1
    print(t2)
    #dist.barrier()
    ctx.shutdown()

if __name__ == "__main__":
    main()
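
# The bit arithmetic above ((x >> 48) & 0xFFFF) assumes a dist-index layout with
# the owning partition id in the top 16 bits and the local offset in the low 48.
# A minimal self-contained sketch of that packing; the helper names are
# illustrative, not part of the starrygl API:
import torch

def pack_dist_index(part_id: torch.Tensor, local_offset: torch.Tensor) -> torch.Tensor:
    return ((part_id & 0xFFFF) << 48) | (local_offset & ((1 << 48) - 1))

def unpack_dist_index(dist_index: torch.Tensor):
    return (dist_index >> 48) & 0xFFFF, dist_index & ((1 << 48) - 1)

ids = pack_dist_index(torch.tensor([0, 1, 1]), torch.tensor([7, 0, 3]))
print(unpack_dist_index(ids))  # (tensor([0, 1, 1]), tensor([7, 0, 3]))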
@@ -118,6 +118,8 @@ partition_save('./dataset/here/'+data_name, data, 1, 'metis_for_tgnn',
                 edge_weight_dict=edge_weight_dict)
 partition_save('./dataset/here/'+data_name, data, 2, 'metis_for_tgnn',
                 edge_weight_dict=edge_weight_dict)
+partition_save('./dataset/here/'+data_name, data, 4, 'metis_for_tgnn',
+                edge_weight_dict=edge_weight_dict)
 #
 # partition_save('./dataset/here/'+data_name, data, 4, 'metis_for_tgnn',
 #                 edge_weight_dict=edge_weight_dict )
...
ERROR:root:unable to import libstarrygl.so, some features may not be available.
the number of nodes in graph is 1980, the number of edges in graph is 1293103
directory '/home/zlj/starrygl/dataset/here/LASTFM/metis_for_tgnn_1' not empty and cleared
running partition algorithm: metis_for_tgnn
saving partition data: 1/1
running partition algorithm: metis_for_tgnn
saving partition data: 1/2
saving partition data: 2/2
creating directory '/home/zlj/starrygl/dataset/here/LASTFM/metis_for_tgnn_4'
running partition algorithm: metis_for_tgnn
saving partition data: 1/4
saving partition data: 2/4
saving partition data: 3/4
saving partition data: 4/4
@@ -15,6 +15,7 @@ class TensorAccessor:
         self._data = data
         self._ctx = DistributedContext.get_default_context()
         self._rref = rpc.RRef(data)
+        self._rref.confirmed_by_owner
 
     @property
     def data(self):
@@ -28,6 +29,20 @@ class TensorAccessor:
     def ctx(self):
         return self._ctx
 
+    def all_gather_index(self, index, input_split) -> Tensor:
+        out_split = torch.empty_like(input_split)
+        torch.distributed.all_to_all_single(out_split, input_split)
+        input_split = list(input_split)
+        output = torch.empty([out_split.sum()], dtype=index.dtype, device=index.device)
+        out_split = list(out_split)
+        torch.distributed.all_to_all_single(output, index, out_split, input_split)
+        return output, out_split, input_split
+
+    def all_gather_data(self, index, input_split, out_split):
+        output = torch.empty([int(Tensor(out_split).sum().item()), *self._data.shape[1:]], dtype=self._data.dtype, device='cuda')  # self._data.device
+        torch.distributed.all_to_all_single(output, self.data[index.to(self.data.device)].to('cuda'), output_split_sizes=out_split, input_split_sizes=input_split)
+        return output
+
     def all_gather_rrefs(self) -> List[rpc.RRef]:
         return self.ctx.all_gather_remote_objects(self.rref)
@@ -101,7 +116,6 @@ class DistributedTensor:
     def __init__(self, data: Tensor) -> None:
         self.accessor = TensorAccessor(data)
         self.rrefs = self.accessor.all_gather_rrefs()
         # self.num_parts = len(self.rrefs)
         local_sizes = []
@@ -110,6 +124,7 @@ class DistributedTensor:
             local_sizes.append(n)
         self.num_nodes = DistInt(local_sizes)
         self.num_parts = DistInt([1] * len(self.rrefs))
+        self.distptr = torch.tensor([((part_ids & 0xFFFF) << 48) for part_ids in range(self.num_parts() + 1)], device='cuda')  # data.device
 
     @property
@@ -130,6 +145,17 @@ class DistributedTensor:
     def ctx(self):
         return self.accessor.ctx
 
+    def gather_select_index(self, dist_index: Union[Tensor, DistIndex]):
+        if isinstance(dist_index, Tensor):
+            dist_index = DistIndex(dist_index)
+        data = dist_index.data
+        posptr = torch.searchsorted(data, self.distptr, right=False)
+        input_split = posptr[1:] - posptr[:-1]
+        return self.accessor.all_gather_index(DistIndex(data).loc, input_split)
+
+    def scatter_data(self, local_index, input_split, out_split):
+        return self.accessor.all_gather_data(local_index, input_split=input_split, out_split=out_split)
+
     def index_select(self, dist_index: Union[Tensor, DistIndex]):
         if isinstance(dist_index, Tensor):
             dist_index = DistIndex(dist_index)
@@ -139,6 +165,10 @@ class DistributedTensor:
         futs: List[torch.futures.Future] = []
         for i in range(self.num_parts()):
+            #if i != torch.distributed.get_rank():
+            #    continue
+            #f = torch.futures.Future()
+            #f.set_result(self.accessor.data[index[part_idx == i]])
             f = self.accessor.async_index_select(0, index[part_idx == i], self.rrefs[i])
             futs.append(f)
@@ -150,6 +180,7 @@ class DistributedTensor:
             result = torch.empty(
                 part_idx.size(0), *t.shape[1:], dtype=t.dtype, device=t.device,
             )
             result[part_idx == i] = t
         return result
     return torch.futures.collect_all(futs).then(callback)
...
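
# The all_gather_index/all_gather_data pair above implements a pull-based
# feature exchange. A hedged, self-contained sketch of the same three-step
# pattern (illustrative names, not the starrygl API; assumes an initialized
# process group in which every rank participates):
import torch
import torch.distributed as dist

def pull_features(sorted_ids, split_sizes, local_feat):
    # sorted_ids: local offsets this rank requests, grouped by owning rank
    # split_sizes: 1-D long tensor, how many of sorted_ids go to each rank
    recv_sizes = torch.empty_like(split_sizes)
    dist.all_to_all_single(recv_sizes, split_sizes)                    # 1) trade counts
    queries = torch.empty(int(recv_sizes.sum()), dtype=sorted_ids.dtype)
    dist.all_to_all_single(queries, sorted_ids,
                           recv_sizes.tolist(), split_sizes.tolist())  # 2) trade ids
    rows = local_feat[queries]                                         # answer from local shard
    out = torch.empty(int(split_sizes.sum()), local_feat.shape[1],
                      dtype=local_feat.dtype)
    dist.all_to_all_single(out, rows,
                           split_sizes.tolist(), recv_sizes.tolist())  # 3) return rows
    return out  # row i answers sorted_ids[i]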
 from typing import List, Tuple
 import torch
 import torch.distributed as dist
+from starrygl.distributed.utils import DistributedTensor
+from starrygl.module.memorys import MailBox
 from starrygl.sample.graph_core import DataSet
 from starrygl.sample.graph_core import GraphData
 from starrygl.sample.sample_core.base import BaseSampler, NegativeSampling
@@ -40,7 +42,7 @@ def prepare_input(node_feat, edge_feat, mem_embedding, mfgs, dist_nid, dist_eid):
         #print(idx.shape[0], b.srcdata['mem_ts'].shape)
     return mfgs
 
-def to_block(graph: GraphData, data, sample_out, mailbox = None, device = torch.device('cuda')):
+def to_block(graph: GraphData, data, sample_out, mailbox: MailBox = None, device = torch.device('cuda')):
     if len(sample_out) > 1:
         sample_out, metadata = sample_out
@@ -55,7 +57,6 @@ def to_block(graph: GraphData, data, sample_out, mailbox = None, device = torch.d
     dist_eid, eid_inv = dist_eid.unique(return_inverse=True)
     src_node = graph.sample_graph['edge_index'][0, eid_tensor*2].to(graph.nids_mapper.device)
     src_ts = None
-    edge_feat = graph._get_edge_attr(dist_eid)
     if metadata is None:
         root_node = data.nodes.to(graph.nids_mapper.device)
         root_len = [root_node.shape[0]]
@@ -74,9 +75,24 @@ def to_block(graph: GraphData, data, sample_out, mailbox = None, device = torch.d
     nid_tensor = torch.cat([root_node, src_node], dim = 0)
     dist_nid = nid_mapper[nid_tensor].to(device)
     dist_nid, nid_inv = dist_nid.unique(return_inverse = True)
+    if isinstance(graph.edge_attr, DistributedTensor):
+        local_index, input_split, output_split = graph.edge_attr.gather_select_index(dist_eid)
+        edge_feat = graph.edge_attr.scatter_data(local_index, input_split=input_split, out_split=output_split)
+    else:
+        edge_feat = graph._get_edge_attr(dist_eid)
+        local_index = None
+    if isinstance(graph.x, DistributedTensor):
+        local_index, input_split, output_split = graph.x.gather_select_index(dist_nid)
+        node_feat = graph.x.scatter_data(local_index, input_split=input_split, out_split=output_split)
+    else:
+        node_feat = graph._get_node_attr(dist_nid)
     if mailbox is not None:
-        mem = mailbox._get_memory(dist_nid)
+        if torch.distributed.get_world_size() > 1:
+            if node_feat is None:
+                local_index, input_split, output_split = mailbox.node_memory.gather_select_index(dist_nid)
+            mem = mailbox.gather_memory(local_index, input_split, output_split)
+        else:
+            mem = mailbox.get_memory(dist_nid)
     else:
         mem = None
@@ -120,25 +136,25 @@ def to_block(graph: GraphData, data, sample_out, mailbox = None, device = torch.d
         return data, mfgs, metadata
     data, mfgs, metadata = build_block()
-    if dist.get_world_size() > 1:
-        if(node_feat is None):
-            node_feat = torch.futures.Future()
-            node_feat.set_result(None)
-        if(edge_feat is None):
-            edge_feat = torch.futures.Future()
-            edge_feat.set_result(None)
-        if(mem is None):
-            mem = torch.futures.Future()
-            mem.set_result(None)
-        def callback(fs, mfgs, dist_nid, dist_eid):
-            node_feat, edge_feat, mem_embedding = fs.value()
-            node_feat = node_feat.value()
-            edge_feat = edge_feat.value()
-            mem_embedding = mem_embedding.value()
-            return prepare_input(node_feat, edge_feat, mem_embedding, mfgs, dist_nid, dist_eid)
-        cal = lambda fut: callback(fs=fut, mfgs = mfgs, dist_nid = dist_nid, dist_eid = dist_eid)
-        return data, torch.futures.collect_all([node_feat, edge_feat, mem]).then(cal), metadata
-    else:
+    #if dist.get_world_size() > 1:
+    #    if(node_feat is None):
+    #        node_feat = torch.futures.Future()
+    #        node_feat.set_result(None)
+    #    if(edge_feat is None):
+    #        edge_feat = torch.futures.Future()
+    #        edge_feat.set_result(None)
+    #    if(mem is None):
+    #        mem = torch.futures.Future()
+    #        mem.set_result(None)
+    #    def callback(fs, mfgs, dist_nid, dist_eid):
+    #        node_feat, edge_feat, mem_embedding = fs.value()
+    #        node_feat = node_feat.value()
+    #        edge_feat = edge_feat.value()
+    #        mem_embedding = mem_embedding.value()
+    #        return prepare_input(node_feat, edge_feat, mem_embedding, mfgs, dist_nid, dist_eid)
+    #    cal = lambda fut: callback(fs=fut, mfgs = mfgs, dist_nid = dist_nid, dist_eid = dist_eid)
+    #    return data, torch.futures.collect_all([node_feat, edge_feat, mem]).then(cal), metadata
+    #else:
     mfgs = prepare_input(node_feat, edge_feat, mem, mfgs, dist_nid, dist_eid)
     #return build_block(node_feat,edge_feat,mem)#data,mfgs,metadata
     return data, mfgs, metadata
...
@@ -66,7 +66,8 @@ class DistributedDataLoader:
         if train is True:
             self._get_expected_idx(self.dataset.len)
         else:
-            self.expected_idx = int(math.ceil(self.dataset.len/self.batch_size))
+            self._get_expected_idx(self.dataset.len, op = dist.ReduceOp.MAX)
+            #self.expected_idx = int(math.ceil(self.dataset.len/self.batch_size))
 
     def __iter__(self):
         if self.chunk_size is None:
@@ -102,18 +103,19 @@ class DistributedDataLoader:
             self.neg_sampler.set_next_pos(self.current_pos)
         return self
 
-    def _get_expected_idx(self, data_size):
+    def _get_expected_idx(self, data_size, op = dist.ReduceOp.MIN):
         world_size = dist.get_world_size()
         self.expected_idx = data_size // self.batch_size if self.drop_last is True else int(math.ceil(data_size/self.batch_size))
         if dist.get_world_size() > 1:
             num_epochs = torch.tensor([self.expected_idx], dtype = torch.long, device=self.device)
-            dist.all_reduce(num_epochs, op=dist.ReduceOp.MIN)
+            print(num_epochs)
+            dist.all_reduce(num_epochs, op=op)
             self.expected_idx = int(num_epochs.item())
 
     def _next_data(self):
         if self.current_pos >= self.dataset.len:
-            return None
+            return self.input_dataset._get_empty()
         if self.current_pos + self.batch_size > self.input_dataset.len:
             if self.drop_last:
@@ -132,7 +134,7 @@ class DistributedDataLoader:
         return next_data
 
     def __next__(self):
-        if(dist.get_world_size() == 1):
+        if(dist.get_world_size() > 0):
             if self.recv_idxs < self.expected_idx:
                 data = self._next_data()
                 batch_data = graph_sample(self.graph,
@@ -165,6 +167,12 @@ class DistributedDataLoader:
                                           next_data, self.neg_sampler,
                                           self.mailbox,
                                           self.device)
+                batch_data[1].wait()
+                self.submitted = self.submitted + 1
+                self.num_pending = self.num_pending + 1
+                self.recv_idxs += 1
+                self.num_pending -= 1
+                return batch_data[0], batch_data[1].value(), batch_data[2]
                 self.result_queue.append(batch_data)
                 self.submitted = self.submitted + 1
                 self.num_pending = self.num_pending + 1
...
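
# Why the loader all_reduces its batch count (the change above): with uneven
# shards, ranks disagree on how many batches they can produce, and a rank that
# stops early would hang the collectives. MIN aligns training on the shortest
# shard; MAX keeps eval alive until the longest shard finishes, with exhausted
# ranks emitting empty batches via _get_empty(). A hedged standalone sketch:
import math
import torch
import torch.distributed as dist

def aligned_num_batches(local_len: int, batch_size: int, op) -> int:
    n = torch.tensor([int(math.ceil(local_len / batch_size))])
    if dist.is_initialized() and dist.get_world_size() > 1:
        dist.all_reduce(n, op=op)
    return int(n.item())

# e.g. shards of 950 and 1200 edges with batch_size=100:
#   aligned_num_batches(..., op=dist.ReduceOp.MIN) -> 10 on both ranks (train)
#   aligned_num_batches(..., op=dist.ReduceOp.MAX) -> 12 on both ranks (eval)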
@@ -23,16 +23,16 @@ class GraphData():
         world_size = dist.get_world_size()
         if hasattr(pdata, 'x') and pdata.x is not None:
             if world_size > 1:
-                self.x = DistributedTensor(pdata.x.to(self.device))
+                self.x = DistributedTensor(pdata.x.to(self.device).to(torch.float))
             else:
                 self.x = pdata.x.to(device).to(torch.float)
         else:
             self.x = None
         if hasattr(pdata, 'edge_attr') and pdata.edge_attr is not None:
             if world_size > 1:
-                self.edge_attr = DistributedTensor(pdata.edge_attr.to(self.device))
+                self.edge_attr = DistributedTensor(pdata.edge_attr.to('cpu').to(torch.float))
             else:
-                self.edge_attr = pdata.edge_attr.to('cuda').to(torch.float)
+                self.edge_attr = pdata.edge_attr.to('cpu').to(torch.float)
         else:
             self.edge_attr = None
@@ -43,14 +43,15 @@ class GraphData():
             return self.x[ids]
         else:
             return self.x.index_select(ids)
 
-    def _get_edge_attr(self,ids):
+    def _get_edge_attr(self,ids,):
         if self.edge_attr is None:
             return None
         elif dist.get_world_size() == 1:
-            return self.edge_attr[ids.to('cpu')].to('cuda')
+            return self.edge_attr[ids]
         else:
             return self.edge_attr.index_select(ids)
 
 class DataSet:
     def __init__(self, nodes = None,
                  edges = None,
@@ -69,6 +70,16 @@ class DataSet:
         for k, v in kwargs.items():
             assert isinstance(v, torch.Tensor) and v.shape[0] == self.len
             setattr(self, k, v.to(device))
+
+    def _get_empty(self):
+        nodes = torch.empty([], dtype=self.nodes.dtype, device=self.nodes.device) if hasattr(self, 'nodes') else None
+        edges = torch.empty([2, 0], dtype=self.edges.dtype, device=self.edges.device) if hasattr(self, 'edges') else None
+        d = DataSet(nodes, edges)
+        for k, v in self.__dict__.items():
+            if k == 'edges' or k == 'nodes' or k == 'len':
+                continue
+            else:
+                setattr(d, k, torch.empty([]))
+        return d
 
     #@staticmethod
     def get_next(self, indx):
...
@@ -214,7 +214,7 @@ class SharedMailBox():
         return index, memory, ts
 
-    def _get_memory(self, index):
+    def get_memory(self, index):
         if self.num_parts == 1:
             return self.node_memory.accessor.data[index],\
                    self.node_memory_ts.accessor.data[index],\
@@ -234,6 +234,9 @@ class SharedMailBox():
             #print(memory.shape[0])
             return memory, memory_ts, mail, mail_ts
         return torch.futures.collect_all([memory, memory_ts, mail, mail_ts]).then(callback)
+
+    def gather_memory(self, index, input_split, out_split):
+        return self.node_memory.scatter_data(index, input_split, out_split),\
+               self.node_memory_ts.scatter_data(index, input_split, out_split),\
+               self.mailbox.scatter_data(index, input_split, out_split),\
+               self.mailbox_ts.scatter_data(index, input_split, out_split)
@@ -5,3 +5,7 @@ class WorkStreamEvent:
         self.write_memory_stream = torch.cuda.Stream()
         self.fetch_stream = torch.cuda.Stream()
         self.write_mail_stream = torch.cuda.Stream()
+        self.event = None
+event = WorkStreamEvent()
+def get_event():
+    return event.event
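
# Hedged sketch of how the module-level event added above can order two CUDA
# streams; the producer/consumer split and the _Event stand-in are illustrative,
# not code from this commit:
import torch

class _Event:  # stands in for WorkStreamEvent
    def __init__(self):
        self.fetch_stream = torch.cuda.Stream()
        self.event = None

event = _Event()

def get_event():
    return event.event

def producer(t: torch.Tensor):
    with torch.cuda.stream(event.fetch_stream):
        t.add_(1)                        # async work on the fetch stream
        event.event = torch.cuda.Event()
        event.event.record()             # publish a completion marker

def consumer():
    ev = get_event()
    if ev is not None:
        torch.cuda.current_stream().wait_event(ev)  # order before consuming t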
@@ -82,8 +82,9 @@ def main():
     ctx = DistributedContext.init(backend="nccl", use_gpu=True)
     device_id = torch.cuda.current_device()
     print('use cuda on', device_id)
-    pdata = partition_load("./dataset/here/WIKI", algo="metis_for_tgnn")
+    pdata = partition_load("./dataset/here/GDELT", algo="metis_for_tgnn")
     graph = GraphData(pdata = pdata)
     sample_graph = TemporalNeighborSampleGraph(sample_graph = pdata.sample_graph, mode = 'full')
     mailbox = SharedMailBox(pdata.ids.shape[0], memory_param, dim_edge_feat = pdata.edge_attr.shape[1] if pdata.edge_attr is not None else 0)
     sampler = NeighborSampler(num_nodes=graph.num_nodes, num_layers=1, fanout=[10], graph_data=sample_graph, workers=10, policy = 'recent', graph_name = "wiki_train")
@@ -102,17 +103,17 @@ def main():
     trainloader = DistributedDataLoader(graph, train_data, sampler = sampler,
                                         sampler_fn = SAMPLE_TYPE.SAMPLE_FROM_TEMPORAL_EDGES,
                                         neg_sampler=neg_sampler,
-                                        batch_size = 2000,
+                                        batch_size = 1000,
                                         shuffle=False,
                                         drop_last=True,
                                         chunk_size = None,
                                         train=True,
-                                        queue_size = 100,
+                                        queue_size = 1000,
                                         mailbox = mailbox)
     testloader = DistributedDataLoader(graph, test_data, sampler = sampler,
                                        sampler_fn = SAMPLE_TYPE.SAMPLE_FROM_TEMPORAL_EDGES,
                                        neg_sampler=neg_sampler,
-                                       batch_size = 2000,
+                                       batch_size = 1000,
                                        shuffle=False,
                                        drop_last=False,
                                        chunk_size = None,
@@ -122,7 +123,7 @@ def main():
     valloader = DistributedDataLoader(graph, val_data, sampler = sampler,
                                       sampler_fn = SAMPLE_TYPE.SAMPLE_FROM_TEMPORAL_EDGES,
                                       neg_sampler=neg_sampler,
-                                      batch_size = 2000,
+                                      batch_size = 1000,
                                       shuffle=False,
                                       drop_last=False,
                                       chunk_size = None,
@@ -158,9 +159,9 @@ def main():
     with torch.no_grad():
         total_loss = 0
         signal = torch.tensor([0], dtype = int, device = device)
         for roots, mfgs, metadata in loader:
-            signal[0] = 0
-            dist.all_reduce(signal, async_op=False)
             pred_pos, pred_neg = model(mfgs, metadata)
             total_loss += creterion(pred_pos, torch.ones_like(pred_pos))
             total_loss += creterion(pred_neg, torch.zeros_like(pred_neg))
@@ -190,32 +191,28 @@ def main():
                 src, dst, ts, edge_feats,
                 model.module.memory_updater.last_updated_memory,
                 )
-            #mailbox.set_mailbox_all_to_all(index,memory,memory_ts,mail,mail_ts,reduce_Op = 'max')
-        if mode == 'val':
-            val_losses.append(float(total_loss))
-        while(signal[0].item() != dist.get_world_size()):
-            signal[0] = 1
-            dist.all_reduce(signal, async_op=False)
-            if(signal[0].item() == dist.get_world_size()):
-                break
-            if mailbox is not None:
-                mailbox.set_mailbox_all_to_all(torch.tensor([], device = device).reshape(-1),
-                                               torch.tensor([], device = device).reshape(-1, mailbox.memory_size),
-                                               torch.tensor([], device = device).reshape(-1),
-                                               torch.tensor([], device = device).reshape(-1, mailbox.mailbox.accessor.data.size(2)),
-                                               torch.tensor([], device = device).reshape(-1),
-                                               reduce_Op = 'max')
-        ap = float(torch.tensor(aps).mean())
-        if neg_samples > 1:
-            auc_mrr = float(torch.cat(aucs_mrrs).mean())
-        else:
-            auc_mrr = float(torch.tensor(aucs_mrrs).mean())
+            mailbox.set_mailbox_all_to_all(index, memory, memory_ts, mail, mail_ts, reduce_Op = 'max')
+        #ap = float(torch.tensor(aps).mean())
+        #if neg_samples > 1:
+        #    auc_mrr = float(torch.cat(aucs_mrrs).mean())
+        #else:
+        #    auc_mrr = float(torch.tensor(aucs_mrrs).mean())
+        world_size = dist.get_world_size()
+        apc = torch.empty([loader.expected_idx*world_size], dtype = torch.float, device='cuda')
+        auc_mrr = torch.empty([loader.expected_idx*world_size], dtype = torch.float, device = 'cuda')
+        dist.all_gather_into_tensor(apc, torch.tensor(aps, device ='cuda', dtype=torch.float))
+        dist.all_gather_into_tensor(auc_mrr, torch.tensor(aucs_mrrs, device ='cuda', dtype=torch.float))
+        ap = float(apc.mean())
+        auc_mrr = float(auc_mrr.mean())
         return ap, auc_mrr
 
     creterion = torch.nn.BCEWithLogitsLoss()
     optimizer = torch.optim.Adam(model.parameters(), lr=train_param['lr'])
     for e in range(train_param['epoch']):
+        torch.cuda.synchronize()
         epoch_start_time = time.time()
         train_aps = list()
         print('Epoch {:d}:'.format(e))
@@ -229,14 +226,18 @@ def main():
         model.module.memory_updater.last_updated_ts = None
         for roots, mfgs, metadata in trainloader:
             t_prep_s = time.time()
+            optimizer.zero_grad()
             with torch.cuda.stream(train_stream):
-                optimizer.zero_grad()
                 pred_pos, pred_neg = model(mfgs, metadata)
                 loss = creterion(pred_pos, torch.ones_like(pred_pos))
                 loss += creterion(pred_neg, torch.zeros_like(pred_neg))
                 total_loss += float(loss)
                 loss.backward()
                 optimizer.step()
+                #torch.cuda.synchronize()
                 t_prep_s = time.time()
                 y_pred = torch.cat([pred_pos, pred_neg], dim=0).sigmoid().cpu()
                 y_true = torch.cat([torch.ones(pred_pos.size(0)), torch.zeros(pred_neg.size(0))], dim=0)
@@ -262,7 +263,9 @@ def main():
                     src, dst, ts, edge_feats,
                     model.module.memory_updater.last_updated_memory,
                     )
-                #mailbox.set_mailbox_all_to_all(index,memory,memory_ts,mail,mail_ts,reduce_Op = 'max')
+                mailbox.set_mailbox_all_to_all(index, memory, memory_ts, mail, mail_ts, reduce_Op = 'max')
                 torch.cuda.synchronize()
         time_prep = time.time() - epoch_start_time
         avg_time += time.time() - epoch_start_time
...