Commit 54424543 by zlj

change the torch_utils dependency path

parent 8768fe6f
set(CMAKE_SOURCE_DIR "/mnt/data/")
cmake_minimum_required(VERSION 3.15)
project(starrygl_ops VERSION 0.1)
......@@ -45,7 +46,7 @@ if(WITH_METIS)
add_definitions(-DWITH_METIS)
set(GKLIB_DIR "${CMAKE_SOURCE_DIR}/third_party/GKlib")
set(METIS_DIR "${CMAKE_SOURCE_DIR}/third_party/METIS")
message(${METIS_DIR})
set(GKLIB_INCLUDE_DIRS "${GKLIB_DIR}/include")
file(GLOB_RECURSE GKLIB_LIBRARIES "${GKLIB_DIR}/lib/lib*.a")
......@@ -104,10 +105,10 @@ endif()
# add libsampler.so
set(SAMLPER_NAME "${PROJECT_NAME}_sampler")
set(BOOST_INCLUDE_DIRS "${CMAKE_SOURCE_DIR}/third_party/boost_1_83_0")
include_directories(${BOOST_INCLUDE_DIRS})
file(GLOB_RECURSE SAMPLER_SRCS "csrc/sampler/*.cpp")
add_library(${SAMLPER_NAME} SHARED ${SAMPLER_SRCS})
target_include_directories(${SAMLPER_NAME} PRIVATE "csrc/sampler/include")
target_compile_options(${SAMLPER_NAME} PRIVATE -O3)
......
......@@ -23,7 +23,7 @@ gnn:
dim_time: 100
dim_out: 100
train:
- epoch: 10
- epoch: 5
#batch_size: 100
# reorder: 16
lr: 0.0001
......
......@@ -2,6 +2,7 @@
#include <sampler.h>
#include <output.h>
#include <neighbors.h>
#include <temporal_utils.h>
/*------------Python Bind--------------------------------------------------------------*/
......@@ -16,7 +17,14 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
py::return_value_policy::reference)
.def("divide_nodes_to_part",
&divide_nodes_to_part,
py::return_value_policy::reference);
py::return_value_policy::reference)
.def("sparse_get_index",
&sparse_get_index,
py::return_value_policy::reference)
.def("get_norm_temporal",
&get_norm_temporal,
py::return_value_policy::reference
);
py::class_<TemporalGraphBlock>(m, "TemporalGraphBlock")
.def(py::init<vector<NodeIDType> &, vector<NodeIDType> &,
......@@ -79,4 +87,5 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
.def("neighbor_sample_from_nodes", &ParallelSampler::neighbor_sample_from_nodes)
.def("reset", &ParallelSampler::reset)
.def("get_ret", [](const ParallelSampler &ps) { return ps.ret; });
}
\ No newline at end of file
#include <pybind11/pybind11.h>
#include <pybind11/numpy.h>
#include <pybind11/stl.h>
#pragma once
#include <torch/extension.h>
#include <parallel_hashmap/phmap.h>
#include <cstring>
......
......@@ -44,14 +44,14 @@ def load_feat(d, rand_de=0, rand_dn=0):
data_name = args.data_name
g = np.load('../tgl_main/DATA/'+data_name+'/ext_full.npz')
df = pd.read_csv('../tgl_main/DATA/'+data_name+'/edges.csv')
if os.path.exists('../tgl_main/DATA/'+data_name+'/node_features.pt'):
n_feat = torch.load('../tgl_main/DATA/'+data_name+'/node_features.pt')
g = np.load('/mnt/nfs/fzz/TGL-DATA/'+data_name+'/ext_full.npz')
df = pd.read_csv('/mnt/nfs/fzz/TGL-DATA/'+data_name+'/edges.csv')
if os.path.exists('/mnt/nfs/fzz/TGL-DATA/'+data_name+'/node_features.pt'):
n_feat = torch.load('/mnt/nfs/fzz/TGL-DATA/'+data_name+'/node_features.pt')
else:
n_feat = None
if os.path.exists('../tgl_main/DATA/'+data_name+'/edge_features.pt'):
e_feat = torch.load('../tgl_main/DATA/'+data_name+'/edge_features.pt')
if os.path.exists('/mnt/nfs/fzz/TGL-DATA/'+data_name+'/edge_features.pt'):
e_feat = torch.load('/mnt/nfs/fzz/TGL-DATA/'+data_name+'/edge_features.pt')
else:
e_feat = None
......
#include <cuda_runtime.h>
#include <cstdio>
int main()
{
int count = 0;
if (cudaSuccess != cudaGetDeviceCount(&count)) return -1;
if (count == 0) return -1;
for (int device = 0; device < count; ++device)
{
cudaDeviceProp prop;
if (cudaSuccess == cudaGetDeviceProperties(&prop, device))
std::printf("%d.%d ", prop.major, prop.minor);
}
return 0;
}
#include <cuda.h>
#include <cstdio>
int main() {
printf("%d.%d", CUDA_VERSION / 1000, (CUDA_VERSION / 10) % 100);
return 0;
}
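
The two probe programs above are the kind of helpers a CMake configure step typically compiles and runs (e.g. via try_run) to detect the GPUs' compute capability and the installed CUDA toolkit version. A rough Python equivalent, assuming a CUDA-enabled PyTorch build (illustration only, not part of this commit):

import torch

# Hedged sketch: the same two queries through PyTorch instead of the raw CUDA runtime/driver API.
if torch.cuda.is_available():
    for device in range(torch.cuda.device_count()):
        major, minor = torch.cuda.get_device_capability(device)
        print(f"{major}.{minor}", end=" ")    # compute capability, e.g. "8.6"
    print()
    print(torch.version.cuda)                 # CUDA version torch was built against, e.g. "12.2"
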
......@@ -3,9 +3,9 @@
mkdir -p build && cd build
cmake .. \
-DCMAKE_EXPORT_COMPILE_COMMANDS=ON \
-DCMAKE_PREFIX_PATH="/home/hwj/.miniconda3/envs/sgl/lib/python3.10/site-packages" \
-DPython3_ROOT_DIR="/home/hwj/.miniconda3/envs/sgl" \
-DCUDA_TOOLKIT_ROOT_DIR="/home/hwj/.local/cuda-11.8" \
-DCMAKE_PREFIX_PATH="/home/zlj/.miniconda3/envs/dgnn/lib/python3.10/site-packages" \
-DPython3_ROOT_DIR="/home/zlj/.miniconda3/envs/dgnn" \
-DCUDA_TOOLKIT_ROOT_DIR="/home/zlj/local/cuda-12.2" \
&& make -j32 \
&& rm -rf ../starrygl/lib \
&& mkdir ../starrygl/lib \
......
......@@ -133,7 +133,6 @@ def all_to_all_s(
output_sizes = [t-s for s, t in zip(output_rowptr, output_rowptr[1:])]
input_sizes = [t-s for s, t in zip(input_rowptr, input_rowptr[1:])]
return dist.all_to_all_single(
output=output_tensor,
input=input_tensor,
......
......@@ -17,6 +17,7 @@ class TensorAccessor:
self._data = data
self._ctx = DistributedContext.get_default_context()
self._rref = rpc.RRef(data)
self.stream = torch.cuda.Stream()
@property
def data(self):
......@@ -29,6 +30,12 @@ class TensorAccessor:
@property
def ctx(self):
return self._ctx
@staticmethod
@rpc.functions.async_execution
def _index_select(self):
fut = torch.futures.Future()
fut.set_result(None)
return fut
def all_gather_rrefs(self) -> List[rpc.RRef]:
return self.ctx.all_gather_remote_objects(self.rref)
......@@ -180,7 +187,7 @@ class DistributedTensor:
def ctx(self):
return self.accessor.ctx
def all_to_all_ind2ptr(self, dist_index: Union[Tensor, DistIndex]) -> Dict[str, Union[List[int], Tensor]]:
def all_to_all_ind2ptr(self, dist_index: Union[Tensor, DistIndex],group = None) -> Dict[str, Union[List[int], Tensor]]:
if isinstance(dist_index, Tensor):
dist_index = DistIndex(dist_index)
send_ptr = torch.ops.torch_sparse.ind2ptr(dist_index.part, self.num_parts())
......@@ -196,7 +203,7 @@ class DistributedTensor:
recv_ptr = recv_ptr.tolist()
recv_ind = torch.full((recv_ptr[-1],), (2**62-1)*2+1, dtype=dist_index.dtype, device=dist_index.device)
all_to_all_s(recv_ind, dist_index.loc, send_ptr, recv_ptr)
all_to_all_s(recv_ind, dist_index.loc, recv_ptr, send_ptr,group=group)
return {
"send_ptr": send_ptr,
......@@ -209,6 +216,7 @@ class DistributedTensor:
send_ptr: Optional[List[int]] = None,
recv_ptr: Optional[List[int]] = None,
recv_ind: Optional[List[int]] = None,
group = None
) -> Tensor:
if dist_index is not None:
dist_dict = self.all_to_all_ind2ptr(dist_index)
......@@ -218,7 +226,7 @@ class DistributedTensor:
data = self.accessor.data[recv_ind]
recv = torch.empty(send_ptr[-1], *data.shape[1:], dtype=data.dtype, device=data.device)
all_to_all_s(recv, data, send_ptr, recv_ptr)
all_to_all_s(recv, data, send_ptr, recv_ptr,group=group)
return recv
def all_to_all_set(self,
......@@ -227,6 +235,7 @@ class DistributedTensor:
send_ptr: Optional[List[int]] = None,
recv_ptr: Optional[List[int]] = None,
recv_ind: Optional[List[int]] = None,
group = None
):
if dist_index is not None:
dist_dict = self.all_to_all_ind2ptr(dist_index)
......@@ -235,7 +244,7 @@ class DistributedTensor:
recv_ind = dist_dict["recv_ind"]
recv = torch.empty(recv_ptr[-1], *data.shape[1:], dtype=data.dtype, device=data.device)
all_to_all_s(recv, data, recv_ptr, send_ptr)
all_to_all_s(recv, data, recv_ptr, send_ptr,group=group)
self.accessor.data.index_copy_(0, recv_ind, recv)
def index_select(self, dist_index: Union[Tensor, DistIndex]):
......
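
The hunks above thread an optional communication group through DistributedTensor's all-to-all helpers. A minimal usage sketch of the two-step lookup, assuming feat is a DistributedTensor and dist_eid holds DistIndex-encoded ids (the call shape mirrors the new to_block call sites below):

import torch.distributed as dist

group = dist.new_group()                                   # optional dedicated communicator
ind_dict = feat.all_to_all_ind2ptr(dist_eid, group=group)  # exchange ids, build send/recv pointers
edge_feat = feat.all_to_all_get(group=group, **ind_dict)   # fetch the rows owned by remote parts
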
......@@ -7,6 +7,8 @@ from starrygl.sample.graph_core import DataSet
from starrygl.sample.graph_core import GraphData
from starrygl.sample.sample_core.base import BaseSampler, NegativeSampling
import dgl
from starrygl.sample.stream_manager import PipelineManager, getPipelineManger
"""
Input arguments are unchanged; the return values become:
sample_from_nodes
......@@ -42,7 +44,7 @@ def prepare_input(node_feat, edge_feat, mem_embedding,mfgs,dist_nid,dist_eid):
#print(idx.shape[0],b.srcdata['mem_ts'].shape)
return mfgs
def to_block(graph: GraphData, data, sample_out, mailbox:MailBox = None,device = torch.device('cuda')):
def to_block(graph: GraphData, data, sample_out, mailbox:MailBox = None,device = torch.device('cuda'),group = None):
if len(sample_out) > 1:
sample_out,metadata = sample_out
......@@ -76,26 +78,25 @@ def to_block(graph: GraphData, data, sample_out, mailbox:MailBox = None,device =
dist_nid = nid_mapper[nid_tensor].to(device)
dist_nid,nid_inv = dist_nid.unique(return_inverse = True)
if isinstance(graph.edge_attr,DistributedTensor):
local_index, input_split,output_split = graph.edge_attr.gather_select_index(dist_eid)
edge_feat = graph.edge_attr.scatter_data(local_index,input_split=input_split,out_split=output_split)
ind_dict = graph.edge_attr.all_to_all_ind2ptr(dist_eid,group = group)
edge_feat = graph.edge_attr.all_to_all_get(group = group,**ind_dict)
else:
edge_feat = graph._get_edge_attr(dist_eid)
local_index = None
ind_dict = None
if isinstance(graph.x,DistributedTensor):
local_index, input_split,output_split = graph.x.gather_select_index(dist_nid)
node_feat = graph.x.scatter_data(local_index,input_split=input_split,out_split=output_split)
ind_dict = graph.x.all_to_all_ind2ptr(dist_nid,group = group)
node_feat = graph.x.all_to_all_get(group = group,**ind_dict)
else:
node_feat = graph._get_node_attr(dist_nid)
if mailbox is not None:
if torch.distributed.get_world_size() > 1:
if node_feat is None:
local_index, input_split,output_split = mailbox.node_memory.gather_select_index(dist_nid)
mem = mailbox.gather_memory(local_index,input_split,output_split)
ind_dict = mailbox.node_memory.all_to_all_ind2ptr(dist_nid,group = group)
mem = mailbox.gather_memory(**ind_dict)
else:
mem = mailbox.get_memory(dist_nid)
else:
mem = None
def build_block():
mfgs = list()
col = torch.arange(0,root_len,device = device)
......@@ -110,23 +111,12 @@ def to_block(graph: GraphData, data, sample_out, mailbox:MailBox = None,device =
device = device)
idx = nid_inv[0:row_len + elen]
e_idx = eid_inv[col_len:col_len+elen]
b.srcdata['ID'] = idx#dist_nid[idx]
b.srcdata['ID'] = idx
if sample_out[r].delta_ts().shape[0] > 0:
b.edata['dt'] = sample_out[r].delta_ts().to(device)
if src_ts is not None:
b.srcdata['ts'] = src_ts[0:row_len + eid_len[r]]
b.edata['ID'] = e_idx#dist_eid[e_idx]
#if edge_feat is not None:
# b.edata['f'] = edge_feat[e_idx]
#if r == len(eid_len)-1:
# if node_feat is not None:
# b.srcdata['h'] = node_feat[idx]
# if mem_embedding is not None:
# node_memory,node_memory_ts,mailbox,mailbox_ts = mem_embedding
# b.srcdata['mem'] = node_memory[idx]
# b.srcdata['mem_ts'] = node_memory_ts[idx]
# b.srcdata['mem_input'] = mailbox[idx].cuda().reshape(b.srcdata['ID'].shape[0], -1)
# b.srcdata['mail_ts'] = mailbox_ts[idx]
b.edata['ID'] = e_idx
col = row
col_len += eid_len[r]
row_len += eid_len[r]
......@@ -134,27 +124,7 @@ def to_block(graph: GraphData, data, sample_out, mailbox:MailBox = None,device =
mfgs = list(map(list, zip(*[iter(mfgs)])))
mfgs.reverse()
return data,mfgs,metadata
data,mfgs,metadata = build_block()
#if dist.get_world_size() > 1:
# if(node_feat is None):
# node_feat = torch.futures.Future()
# node_feat.set_result(None)
# if(edge_feat is None):
# edge_feat = torch.futures.Future()
# edge_feat.set_result(None)
# if(mem is None):
# mem = torch.futures.Future()
# mem.set_result(None)
# def callback(fs,mfgs,dist_nid,dist_eid):
# node_feat,edge_feat,mem_embedding = fs.value()
# node_feat = node_feat.value()
# edge_feat = edge_feat.value()
# mem_embedding = mem_embedding.value()
# return prepare_input(node_feat,edge_feat,mem_embedding,mfgs,dist_nid,dist_eid)
# cal = lambda fut: callback(fs=fut,mfgs = mfgs,dist_nid = dist_nid,dist_eid =dist_eid)
# return data,torch.futures.collect_all([node_feat,edge_feat,mem]).then(cal),metadata
#else:
mfgs = prepare_input(node_feat,edge_feat,mem,mfgs,dist_nid,dist_eid)
#return build_block(node_feat,edge_feat,mem)#data,mfgs,metadata
return data,mfgs,metadata
......@@ -164,9 +134,15 @@ def graph_sample(graph, sampler:BaseSampler,
sample_fn, data,
neg_sampling = None,
mailbox = None,
device = torch.device('cuda')):
device = torch.device('cuda'),
async_op = False):
out = sample_fn(sampler,data,neg_sampling)
if async_op == False:
return to_block(graph,data,out,mailbox,device)
else:
manger = getPipelineManger()
future = manger.submit('lookup',to_block,graph,data,out,mailbox,device)
return future
def sample_from_nodes(sampler:BaseSampler, data:DataSet, **kwargs):
out = sampler.sample_from_nodes(nodes=data.nodes.reshape(-1))
......
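
With the new async_op flag, graph_sample can hand to_block off to the pipeline manager and return a future instead of a finished block. A hedged sketch of the asynchronous path (graph, sampler, sample_fn and data are assumed to be prepared as elsewhere in the repository):

import torch

future = graph_sample(graph, sampler, sample_fn, data,
                      neg_sampling=None, mailbox=None,
                      device=torch.device('cuda'), async_op=True)
data, mfgs, metadata = future.result()   # blocks until the 'lookup' stage has built the block
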
#include <omp.h>
#include <torch/extension.h>
#include <time.h>
#include <random>
#include <phmap.h>
#include <boost/thread/mutex.hpp>
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#define MTX boost::mutex
using namespace std;
namespace py = pybind11;
namespace th = torch;
typedef int64_t NodeIDType;
typedef int64_t EdgeIDType;
typedef float WeightType;
typedef float TimeStampType;
#define EXTRAARGS , phmap::priv::hash_default_hash<K>, \
phmap::priv::hash_default_eq<K>, \
std::allocator<K>, 4, MTX
template <class K, class V>
using HashM = phmap::parallel_flat_hash_map<K, V EXTRAARGS>;
void init_cache_buffer(long long size){
}
at::Storage set_cache_buffer(at::Storage ind){
}
// Update/replace the cached data according to the given policy and return the corresponding indices
at::Storage update_buffer_index(string policy, at::Storage ind, u){
}
at::Storage get_buffer_index(at::Storage ind){
}
PYBIND11_MODULE(sample_cores, m)
{
    m.def("get_neighbors",
          &get_neighbors,
          py::return_value_policy::reference);
}
\ No newline at end of file
......@@ -148,7 +148,7 @@ class DistributedDataLoader:
return batch_data
else :
raise StopIteration
else :
if self.queue_size > 0 :
num_reqs = min(self.queue_size - self.num_pending,self.expected_idx - self.submitted)
for _ in range(num_reqs):
if len(self.result_queue)>0:
......
......@@ -4,7 +4,7 @@ import os.path as osp
import torch
import torch.distributed as dist
from torch_geometric.data import Data
class GraphData():
class GraphStore():
def __init__(self, pdata, device = torch.device('cuda'), all_on_gpu = False):
self.device = device
self.ids = pdata.ids.to(device)
......@@ -14,8 +14,8 @@ class GraphData():
else:
self.edge_ts = None
self.sample_graph = pdata.sample_graph
self.nids_mapper = build_mapper(nids=pdata.ids.to(device)).data.to('cpu')
self.eids_mapper = build_mapper(nids=pdata.eids.to(device)).data.to('cpu')
self.nids_mapper = build_mapper(nids=pdata.ids.to(device)).dist.to('cpu')
self.eids_mapper = build_mapper(nids=pdata.eids.to(device)).dist.to('cpu')
if all_on_gpu:
self.nids_mapper = self.nids_mapper.to(device)
self.eids_mapper = self.eids_mapper.to(device)
......@@ -30,9 +30,9 @@ class GraphData():
self.x = None
if hasattr(pdata,'edge_attr') and pdata.edge_attr is not None:
if world_size > 1:
self.edge_attr = DistributedTensor(pdata.edge_attr.to('cpu').to(torch.float))
self.edge_attr = DistributedTensor(pdata.edge_attr.to('cuda').to(torch.float))
else:
self.edge_attr = pdata.edge_attr.to('cpu').to(torch.float)
self.edge_attr = pdata.edge_attr.to('cuda').to(torch.float)
else:
self.edge_attr = None
......@@ -106,7 +106,7 @@ class DataSet:
setattr(d,k,v[indx])
return d
class TemporalGraphData():
class TemporalGraphData(GraphStore):
def __init__(self,pdata,device):
super(TemporalGraphData,self).__init__(pdata,device)
def _set_temporal_batch_cache(self,size,pin_size):
......@@ -117,7 +117,7 @@ class TemporalGraphData():
class TemporalNeighborSampleGraph(GraphData):
class TemporalNeighborSampleGraph(GraphStore):
def __init__(self, sample_graph=None, mode='full', eids_mapper=None):
self.edge_index = sample_graph['edge_index']
self.num_edges = self.edge_index.shape[1]
......
from typing import Union
from typing import List
from typing import Optional
import torch
from torch.distributed import rpc
import torch_scatter
......@@ -118,7 +120,7 @@ class SharedMailBox():
def set_mailbox_all_to_all(self,index,memory,
memory_ts,mail,mail_ts,
reduce_Op = None):
reduce_Op = None,group = None):
#futs: List[torch.futures.Future] = []
if self.num_parts == 1:
dist_index = DistIndex(index)
......@@ -132,7 +134,7 @@ class SharedMailBox():
device = self.device)
indic = torch.searchsorted(index,self.partptr,right=False)
scatter_len_list = indic[1:] - indic[0:-1]
torch.distributed.all_to_all_single(gather_len_list,scatter_len_list)
torch.distributed.all_to_all_single(gather_len_list,scatter_len_list,group = group)
input_split = scatter_len_list.tolist()
output_split = gather_len_list.tolist()
gather_id_list = torch.empty(
......@@ -143,7 +145,7 @@ class SharedMailBox():
output_split = gather_len_list.tolist()
torch.distributed.all_to_all_single(
gather_id_list,index,output_split_sizes=output_split,
input_split_sizes=input_split)
input_split_sizes=input_split,group = group)
index = gather_id_list
gather_memory = torch.empty(
[gather_len_list.sum(),memory.shape[1]],
......@@ -160,22 +162,81 @@ class SharedMailBox():
torch.distributed.all_to_all_single(
gather_memory,memory,
output_split_sizes=output_split,
input_split_sizes=input_split)
input_split_sizes=input_split,group = group)
torch.distributed.all_to_all_single(
gather_memory_ts,memory_ts,
output_split_sizes=output_split,
input_split_sizes=input_split)
input_split_sizes=input_split,group = group)
torch.distributed.all_to_all_single(
gather_mail,mail,
output_split_sizes=output_split,
input_split_sizes=input_split)
input_split_sizes=input_split,group = group)
torch.distributed.all_to_all_single(
gather_mail_ts,mail_ts,
output_split_sizes=output_split,
input_split_sizes=input_split)
input_split_sizes=input_split,group = group)
self.set_mailbox_local(DistIndex(index).loc,gather_mail,gather_mail_ts,Reduce_Op = reduce_Op)
self.set_memory_local(DistIndex(index).loc,gather_memory,gather_memory_ts, Reduce_Op = reduce_Op)
def set_mailbox_all_to_all(self,index,memory,
memory_ts,mail,mail_ts,
reduce_Op = None,group = None):
#futs: List[torch.futures.Future] = []
if self.num_parts == 1:
dist_index = DistIndex(index)
part_idx = dist_index.part
index = dist_index.loc
self.set_mailbox_local(index,mail,mail_ts)
self.set_memory_local(index,memory,memory_ts)
else:
gather_len_list = torch.empty([self.num_parts],
dtype = int,
device = self.device)
indic = torch.searchsorted(index,self.partptr,right=False)
scatter_len_list = indic[1:] - indic[0:-1]
torch.distributed.all_to_all_single(gather_len_list,scatter_len_list,group = group)
input_split = scatter_len_list.tolist()
output_split = gather_len_list.tolist()
gather_id_list = torch.empty(
[gather_len_list.sum()],
dtype = torch.long,
device = self.device)
input_split = scatter_len_list.tolist()
output_split = gather_len_list.tolist()
torch.distributed.all_to_all_single(
gather_id_list,index,output_split_sizes=output_split,
input_split_sizes=input_split,group = group)
index = gather_id_list
gather_memory = torch.empty(
[gather_len_list.sum(),memory.shape[1]],
dtype = memory.dtype,device = self.device)
gather_memory_ts = torch.empty(
[gather_len_list.sum()],
dtype = memory_ts.dtype,device = self.device)
gather_mail = torch.empty(
[gather_len_list.sum(),mail.shape[1]],
dtype = mail.dtype,device = self.device)
gather_mail_ts = torch.empty(
[gather_len_list.sum()],
dtype = mail_ts.dtype,device = self.device)
torch.distributed.all_to_all_single(
gather_memory,memory,
output_split_sizes=output_split,
input_split_sizes=input_split,group = group)
torch.distributed.all_to_all_single(
gather_memory_ts,memory_ts,
output_split_sizes=output_split,
input_split_sizes=input_split,group = group)
torch.distributed.all_to_all_single(
gather_mail,mail,
output_split_sizes=output_split,
input_split_sizes=input_split,group = group)
torch.distributed.all_to_all_single(
gather_mail_ts,mail_ts,
output_split_sizes=output_split,
input_split_sizes=input_split,group = group)
self.set_mailbox_local(DistIndex(index).loc,gather_mail,gather_mail_ts,Reduce_Op = reduce_Op)
self.set_memory_local(DistIndex(index).loc,gather_memory,gather_memory_ts, Reduce_Op = reduce_Op)
def get_update_mail(self,dist_indx_mapper,
src,dst,ts,edge_feats,
......@@ -234,9 +295,16 @@ class SharedMailBox():
#print(memory.shape[0])
return memory,memory_ts,mail,mail_ts
return torch.futures.collect_all([memory,memory_ts,mail,mail_ts]).then(callback)
def gather_memory(self,index,input_split,out_split):
return self.node_memory.scatter_data(index,input_split,out_split),\
self.node_memory_ts.scatter_data(index,input_split,out_split),\
self.mailbox.scatter_data(index,input_split,out_split),\
self.mailbox_ts.scatter_data(index,input_split,out_split)
def gather_memory(
self,
dist_index: Union[torch.Tensor, DistIndex, None] = None,
send_ptr: Optional[List[int]] = None,
recv_ptr: Optional[List[int]] = None,
recv_ind: Optional[List[int]] = None,
group = None
):
return self.node_memory.all_to_all_get(dist_index,send_ptr,recv_ptr,recv_ind,group),\
self.node_memory_ts.all_to_all_get(dist_index,send_ptr,recv_ptr,recv_ind,group),\
self.mailbox.all_to_all_get(dist_index,send_ptr,recv_ptr,recv_ind,group),\
self.mailbox_ts.all_to_all_get(dist_index,send_ptr,recv_ptr,recv_ind,group)
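
gather_memory now takes the same keyword dictionary that all_to_all_ind2ptr produces, so the mailbox lookup reduces to two calls. A hedged sketch, with mailbox, dist_nid and group assumed as in the to_block hunks above:

ind_dict = mailbox.node_memory.all_to_all_ind2ptr(dist_nid, group=group)
mem, mem_ts, mail, mail_ts = mailbox.gather_memory(**ind_dict)   # memory, memory ts, mail, mail ts
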
......@@ -2,7 +2,8 @@ import parser
from torch_sparse import SparseTensor
from torch_geometric.data import Data
from torch_geometric.utils import degree
from .torch_utils import get_norm_temporal
from starrygl.lib.libstarrygl_ops import metis_partition
from starrygl.lib.libstarrygl_ops_sampler import get_norm_temporal
import os.path as osp
import os
import shutil
......
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CppExtension
setup(
name='torch_utils',
ext_modules=[
CppExtension(
name='torch_utils',
sources=['./torch_utils.cpp'],
extra_compile_args=['-fopenmp','-Xlinker',' -export-dynamic'],
include_dirs=["../sample_core"],
),
],
cmdclass={
'build_ext': BuildExtension
})#
#setup(
# name='cpu_cache_manager',
# ext_modules=[
# CppExtension(
# name='cpu_cache_manager',
# sources=['cpu_cache_manager.cpp'],
# extra_compile_args=['-fopenmp','-Xlinker',' -export-dynamic'],
# include_dirs=["./"],
# ),
# ],
# cmdclass={
# 'build_ext': BuildExtension
# })#
#
\ No newline at end of file
from concurrent.futures import ThreadPoolExecutor, thread
from multiprocessing import Process
from multiprocessing.pool import ThreadPool
from collections import deque
import torch
class WorkStreamEvent:
def __init__(self):
self.train_stream = torch.cuda.Stream()
self.write_memory_stream = torch.cuda.Stream()
self.fetch_stream = torch.cuda.Stream()
self.write_mail_stream = torch.cuda.Stream()
self.event = None
event = WorkStreamEvent()
def get_event():
return event.event
import torch.distributed as dist
stage = ['train_stream','write_memory','write_mail','lookup']
class PipelineManager:
def __init__(self,num_tasks = 10):
self.stream_set = {}
self.dist_set = {}
self.args_queue = {}
self.thread_pool = ThreadPoolExecutor(num_tasks)
for k in stage:
self.stream_set[k] = torch.cuda.Stream()
self.dist_set[k] = dist.new_group()
self.args_queue[k] = deque()
    def submit(self,state,func,*args):
        future = self.thread_pool.submit(self.run, state, func, *args)
        return future
    def run(self,state,func,*args):
        with torch.cuda.stream(self.stream_set[state]):
            return func(*args, group = self.dist_set[state])
manger = None
def getPipelineManger():
global manger
if manger is None:
manger = PipelineManager()
return manger
\ No newline at end of file
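
stream_manager.py is new in this commit: each named stage gets its own CUDA stream and communication group, and the manager injects that group into the submitted callable as a keyword argument. A small hedged usage sketch (assumes torch.distributed is already initialised and a GPU is available; work is a hypothetical function):

from starrygl.sample.stream_manager import getPipelineManger

def work(x, group=None):                      # submitted callables must accept the injected group kwarg
    return x * 2

manager = getPipelineManger()
future = manager.submit('lookup', work, 21)   # runs under the 'lookup' stream and group
print(future.result())                        # -> 42
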
nohup python train_tgnn.py --dataname TaoBao --world_size 4 --rank 0 > taobao4.out &
nohup python train_tgnn.py --dataname TaoBao --world_size 4 --rank 1 > taobao4.out &
nohup python train_tgnn.py --dataname TaoBao --world_size 4 --rank 2 > taobao4.out &
nohup python train_tgnn.py --dataname TaoBao --world_size 4 --rank 3 > taobao4.out &
nohup python train_tgnn.py --dataname GDELT --world_size 4 --rank 0 > GDELT4.out &
nohup python train_tgnn.py --dataname GDELT --world_size 4 --rank 1 > GDELT4.out &
nohup python train_tgnn.py --dataname GDELT --world_size 4 --rank 2 > GDELT4.out &
nohup python train_tgnn.py --dataname GDELT --world_size 4 --rank 3 > GDELT4.out &
nohup python train_tgnn.py --dataname ML25M --world_size 2 --rank 0 > ML25M2.out &
nohup python train_tgnn.py --dataname ML25M --world_size 2 --rank 1 > ML25M2.out &
nohup python train_tgnn.py --dataname TaoBao --world_size 2 --rank 0 > TaoBao2.out &
nohup python train_tgnn.py --dataname TaoBao --world_size 2 --rank 1 > TaoBao2.out &
nohup python train_tgnn.py --dataname ML25M --world_size 1 --rank 0 > ML25M1.out &
nohup python train_tgnn.py --dataname TaoBao --world_size 1 --rank 0 > TaoBao1.out &
nohup python train_tgnn.py --dataname GDELT --world_size 1 --rank 0 > GDELT1.out &
\ No newline at end of file
......@@ -25,24 +25,7 @@ import os
from starrygl.sample.data_loader import DistributedDataLoader
from starrygl.sample.batch_data import SAMPLE_TYPE
"""
test command
python test.py --world_size 2 --rank 0
--world_size', default=4, type=int, metavar='W',
help='number of workers')
parser.add_argument('--rank', default=0, type=int, metavar='W',
help='rank of the worker')
parser.add_argument('--log_interval', type=int, default=10, metavar='N',
help='interval between training status logs')
parser.add_argument('--gamma', type=float, default=0.99, metavar='G',
help='how much to value future rewards')
parser.add_argument('--seed', type=int, default=1, metavar='S',
help='random seed for reproducibility')
parser.add_argument('--num_sampler', type=int, default=10, metavar='S',
help='number of samplers')
parser.add_argument('--queue_size', type=int, default=10, metavar='S',
help='sampler queue size')
"""
parser = argparse.ArgumentParser(
description="RPC Reinforcement Learning Example",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
......@@ -51,6 +34,8 @@ parser.add_argument('--rank', default=0, type=str, metavar='W',
help='name of dataset')
parser.add_argument('--world_size', default=1, type=int, metavar='W',
help='number of negative samples')
parser.add_argument('--dataname', default=1, type=str, metavar='W',
help='number of negative samples')
args = parser.parse_args()
from sklearn.metrics import average_precision_score, roc_auc_score
import torch
......@@ -82,7 +67,7 @@ def main():
ctx = DistributedContext.init(backend="nccl", use_gpu=True)
device_id = torch.cuda.current_device()
print('use cuda on',device_id)
pdata = partition_load("./dataset/here/GDELT", algo="metis_for_tgnn")
pdata = partition_load("./dataset/here/{}".format(args.dataname), algo="metis_for_tgnn")
graph = GraphData(pdata = pdata)
sample_graph = TemporalNeighborSampleGraph(sample_graph = pdata.sample_graph,mode = 'full')
......@@ -174,10 +159,12 @@ def main():
dst = metadata['dst_pos_index']
ts = roots.ts
if(graph.edge_attr.device == torch.device('cpu')):
edge_feats = graph.edge_attr[roots.eids.to('cpu')].to('cuda') if graph.edge_attr is not None else None
if graph.edge_attr is None:
edge_feats = None
elif(graph.edge_attr.device == torch.device('cpu')):
edge_feats = graph.edge_attr[roots.eids.to('cpu')].to('cuda')
else:
edge_feats = graph.edge_attr[roots.eids] if graph.edge_attr is not None else None
edge_feats = graph.edge_attr[roots.eids]
dist_index_mapper = mfgs[0][0].srcdata['ID']
root_index = torch.cat((src,dst))
last_updated_nid = model.module.memory_updater.last_updated_nid[root_index]
......@@ -247,10 +234,12 @@ def main():
dst = metadata['dst_pos_index']
ts = roots.ts
if(graph.edge_attr.device == torch.device('cpu')):
edge_feats = graph.edge_attr[roots.eids.to('cpu')].to('cuda') if graph.edge_attr is not None else None
if graph.edge_attr is None:
edge_feats = None
elif(graph.edge_attr.device == torch.device('cpu')):
edge_feats = graph.edge_attr[roots.eids.to('cpu')].to('cuda')
else:
edge_feats = graph.edge_attr[roots.eids] if graph.edge_attr is not None else None
edge_feats = graph.edge_attr[roots.eids]
dist_index_mapper = mfgs[0][0].srcdata['ID']
root_index = torch.cat((src,dst))
last_updated_nid = model.module.memory_updater.last_updated_nid[root_index]
......