Commit 44d3b857 by Wenjie Huang

Merge remote-tracking branch 'origin/develop-v2'

parents 9ca0dca0 6cb3e218
install.sh install.sh
a.out
third_party
install.sh
# Byte-compiled / optimized / DLL files # Byte-compiled / optimized / DLL files
__pycache__/ __pycache__/
*.py[cod] *.py[cod]
......
...@@ -155,6 +155,8 @@ endif() ...@@ -155,6 +155,8 @@ endif()
# add libsampler.so # add libsampler.so
set(SAMLPER_NAME "${PROJECT_NAME}_sampler") set(SAMLPER_NAME "${PROJECT_NAME}_sampler")
set(BOOST_INCLUDE_DIRS "${CMAKE_SOURCE_DIR}/third_party/boost_1_83_0")
include_directories(${BOOST_INCLUDE_DIRS})
file(GLOB_RECURSE SAMPLER_SRCS "csrc/sampler/*.cpp") file(GLOB_RECURSE SAMPLER_SRCS "csrc/sampler/*.cpp")
add_library(${SAMLPER_NAME} SHARED ${SAMPLER_SRCS}) add_library(${SAMLPER_NAME} SHARED ${SAMPLER_SRCS})
......
sampling:
- layer: 1
neighbor:
- 10
strategy: 'recent'
prop_time: False
history: 1
duration: 0
num_thread: 32
memory:
- type: 'node'
dim_time: 100
deliver_to: 'self'
mail_combine: 'last'
memory_update: 'gru'
mailbox_size: 1
combine_node_feature: True
dim_out: 100
gnn:
- arch: 'transformer_attention'
layer: 1
att_head: 2
dim_time: 100
dim_out: 100
train:
- epoch: 5
#batch_size: 100
# reorder: 16
lr: 0.0001
dropout: 0.2
att_dropout: 0.2
all_on_gpu: True
\ No newline at end of file
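The YAML above nests each top-level section (sampling, memory, gnn, train) inside a one-element list, so consumers index into the list before reading keys. A minimal loading sketch, assuming PyYAML is available and the file is saved as config/example.yml (hypothetical path):

import yaml

with open("config/example.yml") as f:
    cfg = yaml.safe_load(f)

# each section is a list with a single dict, hence the [0]
sample_cfg = cfg["sampling"][0]
print(sample_cfg["neighbor"], sample_cfg["strategy"])   # [10] 'recent'
print(cfg["gnn"][0]["arch"], cfg["train"][0]["lr"])     # 'transformer_attention' 0.0001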
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
#include <sampler.h> #include <sampler.h>
#include <output.h> #include <output.h>
#include <neighbors.h> #include <neighbors.h>
#include <temporal_utils.h>
/*------------Python Bind--------------------------------------------------------------*/ /*------------Python Bind--------------------------------------------------------------*/
...@@ -16,7 +17,14 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) ...@@ -16,7 +17,14 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
py::return_value_policy::reference) py::return_value_policy::reference)
.def("divide_nodes_to_part", .def("divide_nodes_to_part",
&divide_nodes_to_part, &divide_nodes_to_part,
py::return_value_policy::reference); py::return_value_policy::reference)
.def("sparse_get_index",
&sparse_get_index,
py::return_value_policy::reference)
.def("get_norm_temporal",
&get_norm_temporal,
py::return_value_policy::reference
);
py::class_<TemporalGraphBlock>(m, "TemporalGraphBlock") py::class_<TemporalGraphBlock>(m, "TemporalGraphBlock")
.def(py::init<vector<NodeIDType> &, vector<NodeIDType> &, .def(py::init<vector<NodeIDType> &, vector<NodeIDType> &,
...@@ -79,4 +87,5 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) ...@@ -79,4 +87,5 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
.def("neighbor_sample_from_nodes", &ParallelSampler::neighbor_sample_from_nodes) .def("neighbor_sample_from_nodes", &ParallelSampler::neighbor_sample_from_nodes)
.def("reset", &ParallelSampler::reset) .def("reset", &ParallelSampler::reset)
.def("get_ret", [](const ParallelSampler &ps) { return ps.ret; }); .def("get_ret", [](const ParallelSampler &ps) { return ps.ret; });
} }
\ No newline at end of file
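A hedged sketch of inspecting the compiled extension from Python; the import path starrygl.lib.libstarrygl_sampler is an assumption based on the CMake target name and may differ in the actual build:

import importlib
import torch  # import torch first so the extension can resolve libtorch symbols

ext = importlib.import_module("starrygl.lib.libstarrygl_sampler")
print([name for name in dir(ext) if not name.startswith("_")])
# expected to include: ParallelSampler, TemporalGraphBlock,
# divide_nodes_to_part, sparse_get_index, get_norm_temporal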
...@@ -105,13 +105,13 @@ void ParallelSampler :: neighbor_sample_from_nodes_static_layer(th::Tensor nodes ...@@ -105,13 +105,13 @@ void ParallelSampler :: neighbor_sample_from_nodes_static_layer(th::Tensor nodes
// uniform_int_distribution<> u(0, tnb.deg[node]-1); // uniform_int_distribution<> u(0, tnb.deg[node]-1);
// while(temp_s.size()!=fanout && temp_s.size()<tnb.neighbors_set[node].size()){ // while(temp_s.size()!=fanout && temp_s.size()<tnb.neighbors_set[node].size()){
for(int i=0;i<fanout;i++){ for(int i=0;i<fanout;i++){
//loop to pick fanout neighbors //loop to pick fanout neighbors
NodeIDType indice; NodeIDType indice;
if(policy == "weighted"){//考虑边权重信 if(policy == "weighted"){//���DZ�Ȩ����Ϣ
const vector<WeightType>& ew = tnb.edge_weight[node]; const vector<WeightType>& ew = tnb.edge_weight[node];
indice = sample_multinomial(ew, e); indice = sample_multinomial(ew, e);
} }
else if(policy == "uniform"){//均匀采样 else if(policy == "uniform"){//���Ȳ���
// indice = u(e); // indice = u(e);
indice = rand_r(&loc_seed) % (nei.size()); indice = rand_r(&loc_seed) % (nei.size());
} }
...@@ -119,7 +119,7 @@ void ParallelSampler :: neighbor_sample_from_nodes_static_layer(th::Tensor nodes ...@@ -119,7 +119,7 @@ void ParallelSampler :: neighbor_sample_from_nodes_static_layer(th::Tensor nodes
auto chosen_e_iter = edge.begin() + indice; auto chosen_e_iter = edge.begin() + indice;
if(part_unique){ if(part_unique){
auto rst = temp_s.insert(*chosen_n_iter); auto rst = temp_s.insert(*chosen_n_iter);
if(rst.second){ //not a duplicate if(rst.second){ //not a duplicate
eid_threads[tid].emplace_back(*chosen_e_iter); eid_threads[tid].emplace_back(*chosen_e_iter);
node_s_threads[tid].insert(*chosen_n_iter); node_s_threads[tid].insert(*chosen_n_iter);
if(!tnb.neighbors_set.empty() && temp_s.size()<fanout && temp_s.size()<tnb.neighbors_set[node].size()) fanout++; if(!tnb.neighbors_set.empty() && temp_s.size()<fanout && temp_s.size()<tnb.neighbors_set[node].size()) fanout++;
...@@ -229,7 +229,7 @@ void ParallelSampler :: neighbor_sample_from_nodes_with_before_layer( ...@@ -229,7 +229,7 @@ void ParallelSampler :: neighbor_sample_from_nodes_with_before_layer(
} }
} }
else{ else{
//if the candidate neighbor edges exceed the fanout, randomly pick fanout neighbors //if the candidate neighbor edges exceed the fanout, randomly pick fanout neighbors
tgb_i[tid].src_index.insert(tgb_i[tid].src_index.end(), fanout, i); tgb_i[tid].src_index.insert(tgb_i[tid].src_index.end(), fanout, i);
uniform_int_distribution<> u(0, end_index-1); uniform_int_distribution<> u(0, end_index-1);
//cout<<end_index<<endl; //cout<<end_index<<endl;
......
#include <pybind11/pybind11.h> #pragma once
#include <pybind11/numpy.h>
#include <pybind11/stl.h>
#include <torch/extension.h> #include <torch/extension.h>
#include <parallel_hashmap/phmap.h> #include <parallel_hashmap/phmap.h>
#include <cstring> #include <cstring>
......
...@@ -44,14 +44,14 @@ def load_feat(d, rand_de=0, rand_dn=0): ...@@ -44,14 +44,14 @@ def load_feat(d, rand_de=0, rand_dn=0):
data_name = args.data_name data_name = args.data_name
g = np.load('../tgl_main/DATA/'+data_name+'/ext_full.npz') g = np.load('/mnt/nfs/fzz/TGL-DATA/'+data_name+'/ext_full.npz')
df = pd.read_csv('../tgl_main/DATA/'+data_name+'/edges.csv') df = pd.read_csv('/mnt/nfs/fzz/TGL-DATA/'+data_name+'/edges.csv')
if os.path.exists('../tgl_main/DATA/'+data_name+'/node_features.pt'): if os.path.exists('/mnt/nfs/fzz/TGL-DATA/'+data_name+'/node_features.pt'):
n_feat = torch.load('../tgl_main/DATA/'+data_name+'/node_features.pt') n_feat = torch.load('/mnt/nfs/fzz/TGL-DATA/'+data_name+'/node_features.pt')
else: else:
n_feat = None n_feat = None
if os.path.exists('../tgl_main/DATA/'+data_name+'/edge_features.pt'): if os.path.exists('/mnt/nfs/fzz/TGL-DATA/'+data_name+'/edge_features.pt'):
e_feat = torch.load('../tgl_main/DATA/'+data_name+'/edge_features.pt') e_feat = torch.load('/mnt/nfs/fzz/TGL-DATA/'+data_name+'/edge_features.pt')
else: else:
e_feat = None e_feat = None
......
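The hunk above replaces the relative dataset paths with an absolute NFS path. A small sketch (not the repository's code) of the same loading pattern with the root made configurable through an environment variable; TGL_DATA_ROOT is a hypothetical variable name:

import os
import numpy as np
import pandas as pd
import torch

DATA_ROOT = os.environ.get("TGL_DATA_ROOT", "/mnt/nfs/fzz/TGL-DATA")

def load_dataset(data_name):
    base = os.path.join(DATA_ROOT, data_name)
    g = np.load(os.path.join(base, "ext_full.npz"))
    df = pd.read_csv(os.path.join(base, "edges.csv"))
    n_path = os.path.join(base, "node_features.pt")
    e_path = os.path.join(base, "edge_features.pt")
    n_feat = torch.load(n_path) if os.path.exists(n_path) else None
    e_feat = torch.load(e_path) if os.path.exists(e_path) else None
    return g, df, n_feat, e_feat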
#include <cuda_runtime.h>
#include <cstdio>
int main()
{
int count = 0;
if (cudaSuccess != cudaGetDeviceCount(&count)) return -1;
if (count == 0) return -1;
for (int device = 0; device < count; ++device)
{
cudaDeviceProp prop;
if (cudaSuccess == cudaGetDeviceProperties(&prop, device))
std::printf("%d.%d ", prop.major, prop.minor);
}
return 0;
}
#include <cuda.h>
#include <cstdio>
int main() {
printf("%d.%d", CUDA_VERSION / 1000, (CUDA_VERSION / 10) % 100);
return 0;
}
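The two probe programs above report the GPU compute capability and the CUDA toolkit version for the build scripts. A hedged Python analogue using PyTorch's CUDA bindings instead of the CUDA runtime API:

import torch

if torch.cuda.is_available():
    for device in range(torch.cuda.device_count()):
        major, minor = torch.cuda.get_device_capability(device)
        print(f"{major}.{minor}", end=" ")   # compute capability, e.g. "8.0"
    print()
print(torch.version.cuda)                    # toolkit version torch was built with, e.g. "12.1"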
...@@ -3,9 +3,9 @@ ...@@ -3,9 +3,9 @@
mkdir -p build && cd build mkdir -p build && cd build
cmake .. \ cmake .. \
-DCMAKE_EXPORT_COMPILE_COMMANDS=ON \ -DCMAKE_EXPORT_COMPILE_COMMANDS=ON \
-DCMAKE_PREFIX_PATH="/home/hwj/.miniconda3/envs/sgl/lib/python3.10/site-packages" \ -DCMAKE_PREFIX_PATH="/home/zlj/.miniconda3/envs/dgnn/lib/python3.8/site-packages" \
-DPython3_ROOT_DIR="/home/hwj/.miniconda3/envs/sgl" \ -DPython3_ROOT_DIR="/home/zlj/.miniconda3/envs/dgnn" \
-DCUDA_TOOLKIT_ROOT_DIR="/home/hwj/.local/cuda-11.8" \ -DCUDA_TOOLKIT_ROOT_DIR="/home/zlj/local/cuda-12.2" \
&& make -j32 \ && make -j32 \
&& rm -rf ../starrygl/lib \ && rm -rf ../starrygl/lib \
&& mkdir ../starrygl/lib \ && mkdir ../starrygl/lib \
......
import torch import torch
from torch import Tensor from torch import Tensor
...@@ -23,6 +31,9 @@ __all__ = [ ...@@ -23,6 +31,9 @@ __all__ = [
Strings = Sequence[str] Strings = Sequence[str]
OptStrings = Optional[Strings] OptStrings = Optional[Strings]
Strings = Sequence[str]
OptStrings = Optional[Strings]
class GraphData: class GraphData:
def __init__(self, def __init__(self,
edge_indices: Union[Tensor, Dict[Tuple[str, str, str], Tensor]], edge_indices: Union[Tensor, Dict[Tuple[str, str, str], Tensor]],
...@@ -178,6 +189,13 @@ class GraphData: ...@@ -178,6 +189,13 @@ class GraphData:
algorithm: str = "metis", algorithm: str = "metis",
) -> 'GraphData': ) -> 'GraphData':
p = Path(root).expanduser().resolve() / f"{algorithm}_{num_parts}" / f"{part_id:03d}" p = Path(root).expanduser().resolve() / f"{algorithm}_{num_parts}" / f"{part_id:03d}"
def load_partition(
root: str,
part_id: int,
num_parts: int,
algorithm: str = "metis",
) -> 'GraphData':
p = Path(root).expanduser().resolve() / f"{algorithm}_{num_parts}" / f"{part_id:03d}"
return torch.load(p.__str__()) return torch.load(p.__str__())
def save_partition(self, def save_partition(self,
...@@ -307,7 +325,6 @@ class GraphData: ...@@ -307,7 +325,6 @@ class GraphData:
logging.info(f"saving partition data: {i+1}/{num_parts}") logging.info(f"saving partition data: {i+1}/{num_parts}")
torch.save(g, (base_path / f"{i:03d}").__str__()) torch.save(g, (base_path / f"{i:03d}").__str__())
class MetaData: class MetaData:
def __init__(self) -> None: def __init__(self) -> None:
......
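A hedged usage sketch of load_partition as defined above; the import path and the "./partitions" root are assumptions, and part_id is taken from the current rank:

import torch.distributed as dist
from starrygl.graph import GraphData   # hypothetical import path; adjust to where GraphData lives

g = GraphData.load_partition("./partitions",
                             part_id=dist.get_rank(),
                             num_parts=dist.get_world_size(),
                             algorithm="metis")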
...@@ -133,7 +133,6 @@ def all_to_all_s( ...@@ -133,7 +133,6 @@ def all_to_all_s(
output_sizes = [t-s for s, t in zip(output_rowptr, output_rowptr[1:])] output_sizes = [t-s for s, t in zip(output_rowptr, output_rowptr[1:])]
input_sizes = [t-s for s, t in zip(input_rowptr, input_rowptr[1:])] input_sizes = [t-s for s, t in zip(input_rowptr, input_rowptr[1:])]
return dist.all_to_all_single( return dist.all_to_all_single(
output=output_tensor, output=output_tensor,
input=input_tensor, input=input_tensor,
......
...@@ -58,7 +58,6 @@ class DistributedContext: ...@@ -58,7 +58,6 @@ class DistributedContext:
master_port = int(os.environ["MASTER_PORT"]) master_port = int(os.environ["MASTER_PORT"])
ccl_init_url = f"tcp://{master_addr}:{master_port}" ccl_init_url = f"tcp://{master_addr}:{master_port}"
rpc_init_url = f"tcp://{master_addr}:{master_port + 1}" rpc_init_url = f"tcp://{master_addr}:{master_port + 1}"
ctx = DistributedContext( ctx = DistributedContext(
backend=backend, backend=backend,
ccl_init_method=ccl_init_url, ccl_init_method=ccl_init_url,
...@@ -69,6 +68,7 @@ class DistributedContext: ...@@ -69,6 +68,7 @@ class DistributedContext:
use_gpu=use_gpu, use_gpu=use_gpu,
rpc_gpu=rpc_gpu, rpc_gpu=rpc_gpu,
) )
_set_default_dist_context(ctx) _set_default_dist_context(ctx)
return ctx return ctx
......
...@@ -16,7 +16,11 @@ class TensorAccessor: ...@@ -16,7 +16,11 @@ class TensorAccessor:
self._data = data self._data = data
self._ctx = DistributedContext.get_default_context() self._ctx = DistributedContext.get_default_context()
self._rref = rpc.RRef(data) if self._ctx._use_rpc is True:
self._rref = rpc.RRef(data)
else:
self._rref = None
self.stream = torch.cuda.Stream()
@property @property
def data(self): def data(self):
...@@ -29,6 +33,12 @@ class TensorAccessor: ...@@ -29,6 +33,12 @@ class TensorAccessor:
@property @property
def ctx(self): def ctx(self):
return self._ctx return self._ctx
@staticmethod
@rpc.functions.async_execution
def _index_selet(self):
fut = torch.futures.Future()
fut.set_result(None)
return fut
def all_gather_rrefs(self) -> List[rpc.RRef]: def all_gather_rrefs(self) -> List[rpc.RRef]:
return self.ctx.all_gather_remote_objects(self.rref) return self.ctx.all_gather_remote_objects(self.rref)
...@@ -134,18 +144,29 @@ class DistIndex: ...@@ -134,18 +144,29 @@ class DistIndex:
class DistributedTensor: class DistributedTensor:
def __init__(self, data: Tensor) -> None: def __init__(self, data: Tensor) -> None:
self.accessor = TensorAccessor(data) self.accessor = TensorAccessor(data)
self.rrefs = self.accessor.all_gather_rrefs() if self.accessor.rref is not None:
self.rrefs = self.accessor.all_gather_rrefs()
local_sizes = []
for rref in self.rrefs: local_sizes = []
n = self.ctx.remote_call(Tensor.size, rref, dim=0).wait() for rref in self.rrefs:
local_sizes.append(n) n = self.ctx.remote_call(Tensor.size, rref, dim=0).wait()
self._num_nodes: int = sum(local_sizes) local_sizes.append(n)
self._num_part_nodes: Tuple[int,...] = tuple(int(s) for s in local_sizes) self._num_nodes: int = sum(local_sizes)
self._num_part_nodes: Tuple[int,...] = tuple(int(s) for s in local_sizes)
else:
self.rrefs = None
self._num_nodes: int = dist.get_world_size()
self._num_part_nodes:List = [torch.tensor(data.size(0),device = data.device) for _ in range(self._num_nodes)]
dist.all_gather(self._num_part_nodes,torch.tensor(data.size(0),device = data.device))
self._num_nodes = sum(self._num_part_nodes)
self._part_id: int = self.accessor.ctx.rank self._part_id: int = self.accessor.ctx.rank
self._num_parts: int = self.accessor.ctx.world_size self._num_parts: int = self.accessor.ctx.world_size
@property
def shape(self):
return self.accessor.data.shape
@property @property
def dtype(self): def dtype(self):
return self.accessor.data.dtype return self.accessor.data.dtype
...@@ -180,10 +201,10 @@ class DistributedTensor: ...@@ -180,10 +201,10 @@ class DistributedTensor:
def ctx(self): def ctx(self):
return self.accessor.ctx return self.accessor.ctx
def all_to_all_ind2ptr(self, dist_index: Union[Tensor, DistIndex]) -> Dict[str, Union[List[int], Tensor]]: def all_to_all_ind2ptr(self, dist_index: Union[Tensor, DistIndex],group = None) -> Dict[str, Union[List[int], Tensor]]:
if isinstance(dist_index, Tensor): if isinstance(dist_index, Tensor):
dist_index = DistIndex(dist_index) dist_index = DistIndex(dist_index)
send_ptr = torch.ops.torch_sparse.ind2ptr(dist_index.part, self.num_parts()) send_ptr = torch.ops.torch_sparse.ind2ptr(dist_index.part, self.num_parts)
send_sizes = send_ptr[1:] - send_ptr[:-1] send_sizes = send_ptr[1:] - send_ptr[:-1]
recv_sizes = torch.empty_like(send_sizes) recv_sizes = torch.empty_like(send_sizes)
...@@ -195,8 +216,8 @@ class DistributedTensor: ...@@ -195,8 +216,8 @@ class DistributedTensor:
send_ptr = send_ptr.tolist() send_ptr = send_ptr.tolist()
recv_ptr = recv_ptr.tolist() recv_ptr = recv_ptr.tolist()
recv_ind = torch.full((recv_ptr[-1],), (2**62-1)*2+1, dtype=dist_index.dtype, device=dist_index.device) recv_ind = torch.full((recv_ptr[-1],), (2**62-1)*2+1, dtype=dist_index.dtype, device=self.device)
all_to_all_s(recv_ind, dist_index.loc, send_ptr, recv_ptr) all_to_all_s(recv_ind, dist_index.loc, recv_ptr, send_ptr,group=group)
return { return {
"send_ptr": send_ptr, "send_ptr": send_ptr,
...@@ -209,6 +230,7 @@ class DistributedTensor: ...@@ -209,6 +230,7 @@ class DistributedTensor:
send_ptr: Optional[List[int]] = None, send_ptr: Optional[List[int]] = None,
recv_ptr: Optional[List[int]] = None, recv_ptr: Optional[List[int]] = None,
recv_ind: Optional[List[int]] = None, recv_ind: Optional[List[int]] = None,
group = None
) -> Tensor: ) -> Tensor:
if dist_index is not None: if dist_index is not None:
dist_dict = self.all_to_all_ind2ptr(dist_index) dist_dict = self.all_to_all_ind2ptr(dist_index)
...@@ -217,8 +239,8 @@ class DistributedTensor: ...@@ -217,8 +239,8 @@ class DistributedTensor:
recv_ind = dist_dict["recv_ind"] recv_ind = dist_dict["recv_ind"]
data = self.accessor.data[recv_ind] data = self.accessor.data[recv_ind]
recv = torch.empty(send_ptr[-1], *data.shape[1:], dtype=data.dtype, device=data.device) recv = torch.empty(send_ptr[-1], *data.shape[1:], dtype=data.dtype, device=self.device)
all_to_all_s(recv, data, send_ptr, recv_ptr) all_to_all_s(recv, data, send_ptr, recv_ptr,group=group)
return recv return recv
def all_to_all_set(self, def all_to_all_set(self,
...@@ -227,6 +249,7 @@ class DistributedTensor: ...@@ -227,6 +249,7 @@ class DistributedTensor:
send_ptr: Optional[List[int]] = None, send_ptr: Optional[List[int]] = None,
recv_ptr: Optional[List[int]] = None, recv_ptr: Optional[List[int]] = None,
recv_ind: Optional[List[int]] = None, recv_ind: Optional[List[int]] = None,
group = None
): ):
if dist_index is not None: if dist_index is not None:
dist_dict = self.all_to_all_ind2ptr(dist_index) dist_dict = self.all_to_all_ind2ptr(dist_index)
...@@ -235,7 +258,7 @@ class DistributedTensor: ...@@ -235,7 +258,7 @@ class DistributedTensor:
recv_ind = dist_dict["recv_ind"] recv_ind = dist_dict["recv_ind"]
recv = torch.empty(recv_ptr[-1], *data.shape[1:], dtype=data.dtype, device=data.device) recv = torch.empty(recv_ptr[-1], *data.shape[1:], dtype=data.dtype, device=data.device)
all_to_all_s(recv, data, recv_ptr, send_ptr) all_to_all_s(recv, data, recv_ptr, send_ptr,group=group)
self.accessor.data.index_copy_(0, recv_ind, recv) self.accessor.data.index_copy_(0, recv_ind, recv)
def index_select(self, dist_index: Union[Tensor, DistIndex]): def index_select(self, dist_index: Union[Tensor, DistIndex]):
......
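The all_to_all_ind2ptr helper above turns per-element partition ids into CSR-style send pointers via torch_sparse's ind2ptr. A minimal self-contained sketch of that step, assuming the requested indices are already grouped by partition:

import torch

part = torch.tensor([0, 0, 0, 1, 1, 3])   # partition id of each requested element, grouped
num_parts = 4
send_ptr = torch.cat([torch.zeros(1, dtype=torch.long),
                      torch.bincount(part, minlength=num_parts).cumsum(0)])
print(send_ptr.tolist())                  # [0, 3, 5, 5, 6]
send_sizes = send_ptr[1:] - send_ptr[:-1]
print(send_sizes.tolist())                # [3, 2, 0, 1]: elements destined for each partition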
...@@ -32,8 +32,8 @@ class EdgePredictor(torch.nn.Module): ...@@ -32,8 +32,8 @@ class EdgePredictor(torch.nn.Module):
def forward(self, h, neg_samples=1): def forward(self, h, neg_samples=1):
num_edge = h.shape[0] // (neg_samples + 2) num_edge = h.shape[0] // (neg_samples + 2)
h_src = self.src_fc(h[num_edge:2 * num_edge])#self.src_fc(h[:num_edge]) h_src = self.src_fc(h[:num_edge])
h_pos_dst = self.dst_fc(h[:num_edge]) # h_pos_dst = self.dst_fc(h[num_edge:num_edge*2]) #
h_neg_src = self.src_fc(h[2 * num_edge:]) h_neg_src = self.src_fc(h[2 * num_edge:])
h_pos_edge = torch.nn.functional.relu(h_src + h_pos_dst) h_pos_edge = torch.nn.functional.relu(h_src + h_pos_dst)
h_neg_edge = torch.nn.functional.relu(h_neg_src + h_pos_dst.tile(neg_samples, 1)) h_neg_edge = torch.nn.functional.relu(h_neg_src + h_pos_dst.tile(neg_samples, 1))
......
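The hunk above swaps which slice of h feeds src_fc and dst_fc. A worked shape example of the new layout (linear layers omitted), assuming h stacks [sources | positive destinations | negative destinations] in that order:

import torch

neg_samples, num_edge, dim = 1, 4, 8
h = torch.randn(num_edge * (neg_samples + 2), dim)    # 12 rows

h_src     = h[:num_edge]                 # rows 0..3  : sources
h_pos_dst = h[num_edge:2 * num_edge]     # rows 4..7  : positive destinations
h_neg     = h[2 * num_edge:]             # rows 8..11 : negative destinations

pos_edge = torch.relu(h_src + h_pos_dst)                       # (4, 8)
neg_edge = torch.relu(h_neg + h_pos_dst.tile(neg_samples, 1))  # (4, 8)
print(pos_edge.shape, neg_edge.shape)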
...@@ -21,7 +21,7 @@ class GeneralModel(torch.nn.Module): ...@@ -21,7 +21,7 @@ class GeneralModel(torch.nn.Module):
self.train_param = train_param self.train_param = train_param
if memory_param['type'] == 'node': if memory_param['type'] == 'node':
if memory_param['memory_update'] == 'gru': if memory_param['memory_update'] == 'gru':
self.memory_updater = RNNMemeoryUpdater(memory_param, 2 * memory_param['dim_out'] + dim_edge, memory_param['dim_out'], memory_param['dim_time'], dim_node) self.memory_updater = GRUMemeoryUpdater(memory_param, 2 * memory_param['dim_out'] + dim_edge, memory_param['dim_out'], memory_param['dim_time'], dim_node)
elif memory_param['memory_update'] == 'rnn': elif memory_param['memory_update'] == 'rnn':
self.memory_updater = RNNMemeoryUpdater(memory_param, 2 * memory_param['dim_out'] + dim_edge, memory_param['dim_out'], memory_param['dim_time'], dim_node) self.memory_updater = RNNMemeoryUpdater(memory_param, 2 * memory_param['dim_out'] + dim_edge, memory_param['dim_out'], memory_param['dim_time'], dim_node)
elif memory_param['memory_update'] == 'transformer': elif memory_param['memory_update'] == 'transformer':
......
...@@ -3,10 +3,13 @@ import torch ...@@ -3,10 +3,13 @@ import torch
import torch.distributed as dist import torch.distributed as dist
from starrygl.distributed.utils import DistributedTensor from starrygl.distributed.utils import DistributedTensor
from starrygl.module.memorys import MailBox from starrygl.module.memorys import MailBox
from starrygl.sample.cache.fetch_cache import FetchFeatureCache
from starrygl.sample.graph_core import DataSet from starrygl.sample.graph_core import DataSet
from starrygl.sample.graph_core import GraphData from starrygl.sample.graph_core import DistributedGraphStore
from starrygl.sample.sample_core.base import BaseSampler, NegativeSampling from starrygl.sample.sample_core.base import BaseSampler, NegativeSampling
import dgl import dgl
from starrygl.sample.stream_manager import PipelineManager, getPipelineManger
""" """
Input arguments unchanged; the return values become: Input arguments unchanged; the return values become:
sample_from_nodes sample_from_nodes
...@@ -42,8 +45,7 @@ def prepare_input(node_feat, edge_feat, mem_embedding,mfgs,dist_nid,dist_eid): ...@@ -42,8 +45,7 @@ def prepare_input(node_feat, edge_feat, mem_embedding,mfgs,dist_nid,dist_eid):
#print(idx.shape[0],b.srcdata['mem_ts'].shape) #print(idx.shape[0],b.srcdata['mem_ts'].shape)
return mfgs return mfgs
def to_block(graph: GraphData, data, sample_out, mailbox:MailBox = None,device = torch.device('cuda')): def to_block(graph: DistributedGraphStore, data, sample_out, mailbox:MailBox = None,device = torch.device('cuda'),group = None):
if len(sample_out) > 1: if len(sample_out) > 1:
sample_out,metadata = sample_out sample_out,metadata = sample_out
else: else:
...@@ -75,27 +77,39 @@ def to_block(graph: GraphData, data, sample_out, mailbox:MailBox = None,device = ...@@ -75,27 +77,39 @@ def to_block(graph: GraphData, data, sample_out, mailbox:MailBox = None,device =
nid_tensor = torch.cat([root_node,src_node],dim = 0) nid_tensor = torch.cat([root_node,src_node],dim = 0)
dist_nid = nid_mapper[nid_tensor].to(device) dist_nid = nid_mapper[nid_tensor].to(device)
dist_nid,nid_inv = dist_nid.unique(return_inverse = True) dist_nid,nid_inv = dist_nid.unique(return_inverse = True)
if isinstance(graph.edge_attr,DistributedTensor):
local_index, input_split,output_split = graph.edge_attr.gather_select_index(dist_eid) fetchCache = FetchFeatureCache.getFetchCache()
edge_feat = graph.edge_attr.scatter_data(local_index,input_split=input_split,out_split=output_split) if fetchCache is None:
else: if isinstance(graph.edge_attr,DistributedTensor):
edge_feat = graph._get_edge_attr(dist_eid) ind_dict = graph.edge_attr.all_to_all_ind2ptr(dist_eid,group = group)
local_index = None edge_feat = graph.edge_attr.all_to_all_get(group = group,**ind_dict)
if isinstance(graph.x,DistributedTensor):
local_index, input_split,output_split = graph.x.gather_select_index(dist_nid)
node_feat = graph.x.scatter_data(local_index,input_split=input_split,out_split=output_split)
else:
node_feat = graph._get_node_attr(dist_nid)
if mailbox is not None:
if torch.distributed.get_world_size() > 1:
if node_feat is None:
local_index, input_split,output_split = mailbox.node_memory.gather_select_index(dist_nid)
mem = mailbox.gather_memory(local_index,input_split,output_split)
else: else:
mem = mailbox.get_memory(dist_nid) edge_feat = graph._get_edge_attr(dist_eid)
ind_dict = None
if isinstance(graph.x,DistributedTensor):
ind_dict = graph.x.all_to_all_ind2ptr(dist_nid,group = group)
node_feat = graph.x.all_to_all_get(group = group,**ind_dict)
else:
node_feat = graph._get_node_attr(dist_nid)
if mailbox is not None:
if torch.distributed.get_world_size() > 1:
if node_feat is None:
ind_dict = mailbox.node_memory.all_to_all_ind2ptr(dist_nid,group = group)
mem = mailbox.gather_memory(**ind_dict)
else:
mem = mailbox.get_memory(dist_nid)
else:
mem = None
else: else:
mem = None raw_nid = torch.empty_like(dist_nid)
raw_eid = torch.empty_like(dist_eid)
nid_tensor = nid_tensor.to(device)
eid_tensor = eid_tensor.to(device)
raw_nid[nid_inv] = nid_tensor
raw_eid[eid_inv] = eid_tensor
node_feat,edge_feat,mem = fetchCache.fetch_feature(raw_nid,
dist_nid,raw_eid,
dist_eid)
def build_block(): def build_block():
mfgs = list() mfgs = list()
col = torch.arange(0,root_len,device = device) col = torch.arange(0,root_len,device = device)
...@@ -110,23 +124,12 @@ def to_block(graph: GraphData, data, sample_out, mailbox:MailBox = None,device = ...@@ -110,23 +124,12 @@ def to_block(graph: GraphData, data, sample_out, mailbox:MailBox = None,device =
device = device) device = device)
idx = nid_inv[0:row_len + elen] idx = nid_inv[0:row_len + elen]
e_idx = eid_inv[col_len:col_len+elen] e_idx = eid_inv[col_len:col_len+elen]
b.srcdata['ID'] = idx#dist_nid[idx] b.srcdata['ID'] = idx
if sample_out[r].delta_ts().shape[0] > 0: if sample_out[r].delta_ts().shape[0] > 0:
b.edata['dt'] = sample_out[r].delta_ts().to(device) b.edata['dt'] = sample_out[r].delta_ts().to(device)
if src_ts is not None: if src_ts is not None:
b.srcdata['ts'] = src_ts[0:row_len + eid_len[r]] b.srcdata['ts'] = src_ts[0:row_len + eid_len[r]]
b.edata['ID'] = e_idx#dist_eid[e_idx] b.edata['ID'] = e_idx
#if edge_feat is not None:
# b.edata['f'] = edge_feat[e_idx]
#if r == len(eid_len)-1:
# if node_feat is not None:
# b.srcdata['h'] = node_feat[idx]
# if mem_embedding is not None:
# node_memory,node_memory_ts,mailbox,mailbox_ts = mem_embedding
# b.srcdata['mem'] = node_memory[idx]
# b.srcdata['mem_ts'] = node_memory_ts[idx]
# b.srcdata['mem_input'] = mailbox[idx].cuda().reshape(b.srcdata['ID'].shape[0], -1)
# b.srcdata['mail_ts'] = mailbox_ts[idx]
col = row col = row
col_len += eid_len[r] col_len += eid_len[r]
row_len += eid_len[r] row_len += eid_len[r]
...@@ -134,39 +137,28 @@ def to_block(graph: GraphData, data, sample_out, mailbox:MailBox = None,device = ...@@ -134,39 +137,28 @@ def to_block(graph: GraphData, data, sample_out, mailbox:MailBox = None,device =
mfgs = list(map(list, zip(*[iter(mfgs)]))) mfgs = list(map(list, zip(*[iter(mfgs)])))
mfgs.reverse() mfgs.reverse()
return data,mfgs,metadata return data,mfgs,metadata
data,mfgs,metadata = build_block() data,mfgs,metadata = build_block()
#if dist.get_world_size() > 1:
# if(node_feat is None):
# node_feat = torch.futures.Future()
# node_feat.set_result(None)
# if(edge_feat is None):
# edge_feat = torch.futures.Future()
# edge_feat.set_result(None)
# if(mem is None):
# mem = torch.futures.Future()
# mem.set_result(None)
# def callback(fs,mfgs,dist_nid,dist_eid):
# node_feat,edge_feat,mem_embedding = fs.value()
# node_feat = node_feat.value()
# edge_feat = edge_feat.value()
# mem_embedding = mem_embedding.value()
# return prepare_input(node_feat,edge_feat,mem_embedding,mfgs,dist_nid,dist_eid)
# cal = lambda fut: callback(fs=fut,mfgs = mfgs,dist_nid = dist_nid,dist_eid =dist_eid)
# return data,torch.futures.collect_all([node_feat,edge_feat,mem]).then(cal),metadata
#else:
mfgs = prepare_input(node_feat,edge_feat,mem,mfgs,dist_nid,dist_eid) mfgs = prepare_input(node_feat,edge_feat,mem,mfgs,dist_nid,dist_eid)
#return build_block(node_feat,edge_feat,mem)#data,mfgs,metadata #return build_block(node_feat,edge_feat,mem)#data,mfgs,metadata
return data,mfgs,metadata return (data,mfgs,metadata)
def graph_sample(graph, sampler:BaseSampler, def graph_sample(graph, sampler:BaseSampler,
sample_fn, data, sample_fn, data,
neg_sampling = None, neg_sampling = None,
mailbox = None, mailbox = None,
device = torch.device('cuda')): device = torch.device('cuda'),
async_op = False):
out = sample_fn(sampler,data,neg_sampling) out = sample_fn(sampler,data,neg_sampling)
return to_block(graph,data,out,mailbox,device) if async_op == False:
return to_block(graph,data,out,mailbox,device)
else:
manger = getPipelineManger()
future = manger.submit('lookup',to_block,{'graph':graph,'data':data,\
'sample_out':out,\
'mailbox':mailbox,\
'device':device})
return future
def sample_from_nodes(sampler:BaseSampler, data:DataSet, **kwargs): def sample_from_nodes(sampler:BaseSampler, data:DataSet, **kwargs):
out = sampler.sample_from_nodes(nodes=data.nodes.reshape(-1)) out = sampler.sample_from_nodes(nodes=data.nodes.reshape(-1))
......
from typing import Optional, Sequence, Union
import torch
from starrygl.distributed.utils import DistributedTensor
from starrygl.sample.cache.cache import Cache
class LRUCache(Cache):
"""
Least-recently-used (LRU) cache
"""
def __init__(self, cache_ratio: int,
num_cache:int,
cache_data: Sequence[DistributedTensor],
use_local:bool = False,
pinned_buffers_shape: Sequence[torch.Size] = None,
is_update_cache = False
):
super(LRUCache, self).__init__(cache_ratio,num_cache,
cache_data,use_local,
pinned_buffers_shape,
is_update_cache)
self.name = 'lru'
self.now_cache_count = 0
self.cache_count = torch.zeros(
self.capacity, dtype=torch.int32, device=torch.device('cuda'))
self.is_update_cache = True
def update_cache(self, cached_index: torch.Tensor,
uncached_index: torch.Tensor,
uncached_feature: Sequence[torch.Tensor]):
if len(uncached_index) > self.capacity:
num_to_cache = self.capacity
else:
num_to_cache = len(uncached_index)
node_id_to_cache = uncached_index[:num_to_cache].to(torch.int32)
self.now_cache_count -= 1
self.cache_count[cached_index] = 0
# get the k node id with the least water level
removing_cache_index = torch.topk(
self.cache_count, k=num_to_cache, largest=False).indices.to(torch.int32)
removing_node_id = self.cache_index_to_id[removing_cache_index]
# update cache attributes
for buffer,data in zip(self.buffers,uncached_feature):
buffer[removing_cache_index] = data[:num_to_cache].reshape(-1,*buffer.shape[1:])
self.cache_count[removing_cache_index] = 0
self.cache_validate[removing_node_id] = False
self.cache_validate[node_id_to_cache] = True
self.cache_map[removing_node_id] = -1
self.cache_map[node_id_to_cache] = removing_cache_index
self.cache_index_to_id[removing_cache_index] = node_id_to_cache
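A small self-contained sketch of the eviction rule in update_cache above: the k slots with the lowest usage count are selected with topk(largest=False) and overwritten:

import torch

cache_count = torch.tensor([5, 0, 2, 7, 1, 0, 9, 3], dtype=torch.int32)
k = 3
evict_slots = torch.topk(cache_count, k=k, largest=False).indices
print(evict_slots.tolist())   # the three least-used slots, e.g. [1, 5, 4]
cache_count[evict_slots] = 0  # freshly cached entries restart at count zero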
from typing import Callable, List, Optional, Sequence, Union
import numpy as np
import torch
from starrygl.distributed.utils import DistributedTensor
class Cache:
def __init__(self, cache_ratio: int,
num_cache:int,
cache_data: Sequence[DistributedTensor],
use_local:bool = False,
pinned_buffers_shape: Sequence[torch.Size] = None,
is_update_cache = False
):
print(len(cache_data),cache_data)
assert torch.cuda.is_available() == True
self.use_local = use_local
self.is_update_cache = is_update_cache
self.device = torch.device('cuda')
self.use_remote = torch.distributed.get_world_size()>1
assert not (self.use_local is False and self.use_remote is False),\
"the data is on the cuda and no need remote cache"
self.cache_ratio = cache_ratio
self.num_cache = num_cache
self.capacity = int(self.num_cache * cache_ratio)
self.update_stream = torch.cuda.Stream()
self.buffers = []
self.pinned_buffers = []
for data in cache_data:
self.buffers.append(
torch.zeros(self.capacity,*data.shape[1:],
dtype = data.dtype,device = torch.device('cuda'))
)
self.cache_validate = torch.zeros(
num_cache, dtype=torch.bool, device=self.device)
# maps node id -> index
self.cache_map = torch.zeros(
num_cache, dtype=torch.int32, device=self.device) - 1
# maps index -> node id
self.cache_index_to_id = torch.zeros(
num_cache,dtype=torch.int32, device=self.device) -1
self.hit_sum = 0
self.hit_ = 0
def init_cache(self,ind:torch.Tensor,data:Sequence[torch.Tensor]):
pos = torch.arange(ind.shape[0],device = 'cuda',dtype = ind.dtype)
self.cache_map[ind] = pos.to(torch.int32).to('cuda')
self.cache_index_to_id[pos] = ind.to(torch.int32).to('cuda')
for data,buffer in zip(data,self.buffers):
buffer[:ind.shape[0],] = data
self.cache_validate[ind] = True
def update_cache(self, cached_index: torch.Tensor,
uncached_index: torch.Tensor,
uncached_data: Sequence[torch.Tensor]):
raise NotImplementedError
def fetch_data(self,ind:Optional[torch.Tensor] = None,
uncached_source_fn: Callable = None, source_index:torch.Tensor = None):
self.hit_sum += ind.shape[0]
assert isinstance(ind, torch.Tensor)
cache_mask = self.cache_validate[ind]
uncached_mask = ~cache_mask
self.hit_ += torch.sum(cache_mask)
cached_data = []
cached_index = self.cache_map[ind[cache_mask]]
if uncached_mask.sum() > 0:
uncached_id = ind[uncached_mask]
source_index = source_index[uncached_mask]
uncached_feature = uncached_source_fn(source_index)
if isinstance(uncached_feature,torch.Tensor):
uncached_feature = [uncached_feature]
else:
uncached_id = None
uncached_feature = [None for _ in range(len(self.buffers))]
for data,uncached_data in zip(self.buffers,uncached_feature):
nfeature = torch.zeros(
len(ind), *data.shape[1:], dtype=data.dtype,device=self.device)
nfeature[cache_mask,:] = data[cached_index]
if uncached_id is not None:
nfeature[uncached_mask] = uncached_data.reshape(-1,*data.shape[1:])
cached_data.append(nfeature)
if self.is_update_cache and uncached_mask.sum() > 0:
self.update_cache(cached_index=cached_index,
uncached_index=uncached_id,
uncached_feature=uncached_feature)
return nfeature
def invalidate(self,ind):
self.cache_validate[ind] = False
from typing import Optional
from starrygl.distributed.context import DistributedContext
from starrygl.distributed.utils import DistIndex, DistributedTensor
from starrygl.sample.cache import LRU_cache
from starrygl.sample.cache.cache import Cache
from starrygl.sample.cache.static_cache import StaticCache
from starrygl.sample.cache.utils import pre_sample
from starrygl.sample.graph_core import DistributedGraphStore
from starrygl.sample.memory.shared_mailbox import SharedMailBox
import torch
_FetchCache = None
class FetchFeatureCache:
@staticmethod
def create_fetch_cache(num_nodes: int, num_edges: int,
edge_cache_ratio: int, node_cache_ratio: int,
graph: DistributedGraphStore,
mailbox:SharedMailBox = None,
policy = 'lru'):
global _FetchCache
_FetchCache = FetchFeatureCache(num_nodes, num_edges,
edge_cache_ratio, node_cache_ratio,
graph,mailbox,policy)
@staticmethod
def getFetchCache():
global _FetchCache
return _FetchCache
def __init__(self, num_nodes: int, num_edges: int,
edge_cache_ratio: int, node_cache_ratio: int,
graph: DistributedGraphStore,
mailbox:SharedMailBox = None,
policy = 'lru'
):
if policy == 'lru':
init_fn = LRU_cache.LRUCache
elif policy == 'static':
init_fn = StaticCache
self.ctx = DistributedContext.get_default_context()
if graph.x is not None:
self.node_cache:Cache = init_fn(node_cache_ratio,num_nodes,
[graph.x],use_local=graph.uvm_node)
else:
self.node_cache = None
if graph.edge_attr is not None:
self.edge_cache:Cache = init_fn(edge_cache_ratio,num_edges,
[graph.edge_attr],use_local = graph.uvm_edge)
else:
self.edge_cache = None
if mailbox is not None:
self.mailbox_cache:Cache = init_fn(node_cache_ratio,num_nodes,
[mailbox.node_memory,
mailbox.node_memory_ts.accessor.data.reshape(-1,1),
mailbox.mailbox,
mailbox.mailbox_ts],
use_local = mailbox.uvm)
else:
self.mailbox_cache = None
self.graph = graph
self.mailbox = mailbox
global FetchCache
FetchCache = self
def fetch_feature(self, nid: Optional[torch.Tensor] = None, dist_nid = None,
eid: Optional[torch.Tensor] = None, dist_eid = None
):
nfeat = None
mem = None
efeat = None
if self.node_cache is not None and nid is not None:
nfeat = torch.zeros(nid.shape[0],
self.node_cache.buffers[0].shape[1],
dtype = self.node_cache.buffers[0].dtype,
device = torch.device('cuda')
)
if self.node_cache.use_local is False:
local_mask = (DistIndex(dist_nid).part == torch.distributed.get_rank())
local_id = dist_nid[local_mask]
nfeat[local_mask] = self.graph.x.accessor.data[DistIndex(local_id).loc]
remote_mask = ~local_mask
if remote_mask.sum() > 0:
remote_id = nid[remote_mask]
source_id = dist_nid[remote_mask]
nfeat[remote_mask] = self.node_cache.fetch_data(remote_id,\
self.graph._get_node_attr,source_id)[0]
else:
nfeat = self.node_cache.fetch_data(nid,
self.graph._get_node_attr,dist_nid)[0]
if self.mailbox_cache is not None and nid is not None:
memory = torch.zeros(nid.shape[0],
self.mailbox_cache.buffers[0].shape[1],
dtype = self.mailbox_cache.buffers[0].dtype,
device = torch.device('cuda')
)
memory_ts = torch.zeros(nid.shape[0],
dtype = self.mailbox_cache.buffers[1].dtype,
device = torch.device('cuda')
)
mailbox = torch.zeros(nid.shape[0],
*self.mailbox_cache.buffers[2].shape[1:],
dtype = self.mailbox_cache.buffers[2].dtype,
device = torch.device('cuda')
)
mailbox_ts = torch.zeros(nid.shape[0],
*self.mailbox_cache.buffers[3].shape[1:],
dtype = self.mailbox_cache.buffers[3].dtype,
device = torch.device('cuda')
)
if self.mailbox_cache.use_local is False:
if self.node_cache is None:
local_mask = (DistIndex(dist_nid).part == torch.distributed.get_rank())
local_id = dist_nid[local_mask]
remote_mask = ~local_mask
remote_id = nid[remote_mask]
source_id = dist_nid[remote_mask]
mem = self.mailbox.gather_memory(local_id)
memory[local_mask],memory_ts[local_mask],mailbox[local_mask],mailbox_ts[local_mask]= mem
if remote_mask.sum() > 0:
mem = self.mailbox_cache.fetch_data(remote_id,\
self.mailbox.gather_memory,source_id)
memory[remote_mask] = mem[0]
memory_ts[remote_mask] = mem[1].reshape(-1)
mailbox[remote_mask] = mem[2]
mailbox_ts[remote_mask] = mem[3]
mem = memory,memory_ts,mailbox,mailbox_ts
else:
mem = self.mailbox_cache.fetch_data(nid,mailbox.gather_memory,dist_nid)
if self.edge_cache is not None and eid is not None:
efeat = torch.zeros(eid.shape[0],
self.edge_cache.buffers[0].shape[1],
dtype = self.edge_cache.buffers[0].dtype,
device = torch.device('cuda')
)
if self.edge_cache.use_local is False:
local_mask = (DistIndex(dist_eid).part == torch.distributed.get_rank())
local_id = dist_eid[local_mask]
efeat[local_mask] = self.graph.edge_attr.accessor.data[DistIndex(local_id).loc]
remote_mask = ~local_mask
if remote_mask.sum() > 0:
remote_id = eid[remote_mask]
source_id = dist_eid[remote_mask]
efeat[remote_mask] = self.edge_cache.fetch_data(remote_id,\
self.graph._get_edge_attr,source_id)[0]
else:
efeat = self.node_cache.fetch_data(eid,
self.graph._get_edge_attr,dist_eid)[0]
return nfeat,efeat,mem
def init_cache_with_presample(self,dataloader, num_epoch:int = 10):
node_size = self.node_cache.capacity if self.node_cache is not None else 0
edge_size = self.edge_cache.capacity if self.edge_cache is not None else 0
node_counts,edge_counts = pre_sample(dataloader=dataloader,
num_epoch=num_epoch,
node_size = node_size,
edge_size = edge_size)
if node_size != 0:
if self.node_cache.use_local is False:
dist_mask = DistIndex(self.graph.nids_mapper).part == torch.distributed.get_rank()
dist_mask = ~dist_mask
node_counts = node_counts[dist_mask]
_,nid = node_counts.topk(node_size)
if self.node_cache.use_local is False:
nid = dist_mask.nonzero()[nid]
dist_nid = self.graph.nids_mapper[nid].unique()
node_feature = self.graph._get_node_attr(dist_nid.to(self.graph.x.device))
_nid = nid.reshape(-1)
self.node_cache.init_cache(_nid,node_feature)
print('finish node init')
if edge_size != 0:
if self.edge_cache.use_local is False:
dist_mask = DistIndex(self.graph.eids_mapper).part == torch.distributed.get_rank()
dist_mask = ~dist_mask
edge_counts = edge_counts[dist_mask]
_,eid = edge_counts.topk(edge_size)
if self.edge_cache.use_local is False:
eid_ = dist_mask.nonzero()[eid]
else:
eid_ = eid
dist_eid = self.graph.eids_mapper[eid_].unique()
edge_feature = self.graph._get_edge_attr(dist_eid.to(self.graph.edge_attr.device))
eid_ = eid_.reshape(-1)
self.edge_cache.init_cache(eid_,edge_feature)
print('finish edge init')
\ No newline at end of file
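A hedged usage sketch of the cache above; graph, mailbox and dataloader are placeholders for an existing DistributedGraphStore, SharedMailBox and DistributedDataLoader, and the cache ratios are illustrative:

from starrygl.sample.cache.fetch_cache import FetchFeatureCache

FetchFeatureCache.create_fetch_cache(num_nodes=graph.num_nodes,
                                     num_edges=graph.num_edges,
                                     edge_cache_ratio=0.1,
                                     node_cache_ratio=0.1,
                                     graph=graph,
                                     mailbox=mailbox,
                                     policy='lru')
cache = FetchFeatureCache.getFetchCache()
cache.init_cache_with_presample(dataloader, num_epoch=3)   # optional warm-up of the cache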
from typing import Optional, Sequence, Union
import torch
from starrygl.distributed.utils import DistributedTensor
from starrygl.sample.cache.cache import Cache
class StaticCache(Cache):
def __init__(self, cache_ratio: int,
num_cache:int,
cache_data: Sequence[DistributedTensor],
use_local:bool = False,
pinned_buffers_shape: Sequence[torch.Size] = None,
is_update_cache = False
):
super(StaticCache, self).__init__(cache_ratio,num_cache,
cache_data,use_local,
pinned_buffers_shape,
is_update_cache)
self.name = 'static'
self.now_cache_count = 0
self.cache_count = torch.zeros(
self.capacity, dtype=torch.int32, device=torch.device('cuda'))
self.is_update_cache = False
import torch
# do not add an import for dataloader here
def pre_sample(dataloader, num_epoch:int,node_size:int,edge_size:int):
nodes_counts = torch.zeros(dataloader.graph.num_nodes,dtype = torch.long)
edges_counts = torch.zeros(dataloader.graph.num_edges,dtype = torch.long)
print(nodes_counts.shape,edges_counts.shape)
sampler = dataloader.sampler
neg_sampling = dataloader.neg_sampler
sample_fn = dataloader.sampler_fn
graph = dataloader.graph
for _ in range(num_epoch):
dataloader.__iter__()
while dataloader.recv_idxs < dataloader.expected_idx:
dataloader.recv_idxs += 1
data = dataloader._next_data()
out = sample_fn(sampler,data,neg_sampling)
if(len(out)>0):
sample_out,metadata = out
else:
sample_out = out
eid = [ret.eid() for ret in sample_out]
eid_tensor = torch.cat(eid,dim = 0)
src_node = graph.sample_graph['edge_index'][0,eid_tensor*2].to(graph.nids_mapper.device)
dst_node = graph.sample_graph['edge_index'][1,eid_tensor*2].to(graph.nids_mapper.device)
eid_tensor = torch.unique(eid_tensor)
nid_tensor = torch.unique(torch.cat((src_node,dst_node)))
edges_counts[eid_tensor] += 1
nodes_counts[nid_tensor] += 1
return nodes_counts,edges_counts
...@@ -42,6 +42,7 @@ class DistributedDataLoader: ...@@ -42,6 +42,7 @@ class DistributedDataLoader:
train = False, train = False,
queue_size = 10, queue_size = 10,
mailbox = None, mailbox = None,
is_pipeline = False,
**kwargs **kwargs
): ):
assert sampler is not None assert sampler is not None
...@@ -63,11 +64,13 @@ class DistributedDataLoader: ...@@ -63,11 +64,13 @@ class DistributedDataLoader:
self.dataset = dataset self.dataset = dataset
self.mailbox = mailbox self.mailbox = mailbox
self.device = device self.device = device
self.is_pipeline = is_pipeline
if train is True: if train is True:
self._get_expected_idx(self.dataset.len) self._get_expected_idx(self.dataset.len)
else: else:
self._get_expected_idx(self.dataset.len,op = dist.ReduceOp.MAX) self._get_expected_idx(self.dataset.len,op = dist.ReduceOp.MAX)
#self.expected_idx = int(math.ceil(self.dataset.len/self.batch_size)) #self.expected_idx = int(math.ceil(self.dataset.len/self.batch_size))
torch.distributed.barrier()
def __iter__(self): def __iter__(self):
if self.chunk_size is None: if self.chunk_size is None:
...@@ -134,7 +137,7 @@ class DistributedDataLoader: ...@@ -134,7 +137,7 @@ class DistributedDataLoader:
return next_data return next_data
def __next__(self): def __next__(self):
if(dist.get_world_size() > 0): if self.is_pipeline is False:
if self.recv_idxs < self.expected_idx: if self.recv_idxs < self.expected_idx:
data = self._next_data() data = self._next_data()
batch_data = graph_sample(self.graph, batch_data = graph_sample(self.graph,
...@@ -148,47 +151,36 @@ class DistributedDataLoader: ...@@ -148,47 +151,36 @@ class DistributedDataLoader:
return batch_data return batch_data
else : else :
raise StopIteration raise StopIteration
else : else:
num_reqs = min(self.queue_size - self.num_pending,self.expected_idx - self.submitted) if self.recv_idxs == 0:
for _ in range(num_reqs): data = self._next_data()
if len(self.result_queue)>0: batch_data = graph_sample(self.graph,
result = self.result_queue[0]
if result[1].done() == True:
self.recv_idxs += 1
self.num_pending -= 1
self.result_queue.popleft()
return result[0],result[1].value(),result[2]
else:
next_data = self._next_data()
assert next_data is not None
batch_data = graph_sample(self.graph,
self.sampler, self.sampler,
self.sampler_fn, self.sampler_fn,
next_data,self.neg_sampler, data,self.neg_sampler,
self.mailbox, self.mailbox,
self.device) self.device)
batch_data[1].wait() self.recv_idxs += 1
self.submitted = self.submitted + 1 else:
self.num_pending = self.num_pending + 1 if(self.recv_idxs < self.expected_idx):
self.recv_idxs += 1 assert len(self.result_queue) > 0
self.num_pending -= 1 result= self.result_queue[0]
return batch_data[0],batch_data[1].value(),batch_data[2]
self.result_queue.append(batch_data)
self.submitted = self.submitted + 1
self.num_pending = self.num_pending + 1
while(self.recv_idxs < self.expected_idx):
assert len(self.result_queue) > 0
result= self.result_queue[0]
if result[1].done() == True:
self.recv_idxs += 1
self.num_pending -= 1
self.result_queue.popleft() self.result_queue.popleft()
return result[0],result[1].value(),result[2] batch_data = result.result()
self.recv_idxs += 1
else:
assert self.num_pending == 0 raise StopIteration
raise StopIteration if(self.recv_idxs+1<=self.expected_idx):
data = self._next_data()
next_batch = graph_sample(self.graph,
self.sampler,
self.sampler_fn,
data,self.neg_sampler,
self.mailbox,
self.device,
async_op=True)
self.result_queue.append(next_batch)
return batch_data
......
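A hedged sketch of driving the pipelined loader added above; the full constructor signature is not shown in this diff, so the argument names are inferred from the attributes it uses and may not match exactly:

import torch
# DistributedDataLoader import path omitted (not shown in this diff)

loader = DistributedDataLoader(graph=graph,
                               dataset=train_set,
                               sampler=sampler,
                               neg_sampler=neg_sampler,
                               mailbox=mailbox,
                               train=True,
                               queue_size=10,
                               is_pipeline=True,   # sample batch i+1 while batch i is consumed
                               device=torch.device('cuda'))

for data, mfgs, metadata in loader:
    pass   # train on the current batch while the next one is prepared asynchronously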
from starrygl.distributed.context import DistributedContext
from starrygl.distributed.utils import DistIndex, DistributedTensor from starrygl.distributed.utils import DistIndex, DistributedTensor
from starrygl.sample.graph_core.utils import build_mapper from starrygl.sample.graph_core.utils import build_mapper
import os.path as osp import os.path as osp
import torch import torch
import torch.distributed as dist import torch.distributed as dist
from torch_geometric.data import Data from torch_geometric.data import Data
class GraphData():
def __init__(self, pdata, device = torch.device('cuda'), all_on_gpu = False): from starrygl.utils.uvm import cudaMemoryAdvise, uvm_advise, uvm_empty, uvm_prefetch, uvm_share
class DistributedGraphStore:
def __init__(self, pdata, device = torch.device('cuda'),
uvm_node = False,
uvm_edge = False):
self.device = device self.device = device
self.ids = pdata.ids.to(device) self.ids = pdata.ids.to(device)
self.eids = pdata.eids
self.edge_index = pdata.edge_index.to(device) self.edge_index = pdata.edge_index.to(device)
if hasattr(pdata,'edge_ts'): if hasattr(pdata,'edge_ts'):
self.edge_ts = pdata.edge_ts.to(device).to(torch.float) self.edge_ts = pdata.edge_ts.to(device).to(torch.float)
else: else:
self.edge_ts = None self.edge_ts = None
self.sample_graph = pdata.sample_graph self.sample_graph = pdata.sample_graph
self.nids_mapper = build_mapper(nids=pdata.ids.to(device)).data.to('cpu') self.nids_mapper = build_mapper(nids=pdata.ids.to(device)).dist.to('cpu')
self.eids_mapper = build_mapper(nids=pdata.eids.to(device)).data.to('cpu') self.eids_mapper = build_mapper(nids=pdata.eids.to(device)).dist.to('cpu')
if all_on_gpu: torch.cuda.empty_cache()
self.nids_mapper = self.nids_mapper.to(device)
self.eids_mapper = self.eids_mapper.to(device)
self.num_nodes = self.nids_mapper.data.shape[0] self.num_nodes = self.nids_mapper.data.shape[0]
self.num_edges = self.eids_mapper.data.shape[0]
world_size = dist.get_world_size() world_size = dist.get_world_size()
self.uvm_node = uvm_node
self.uvm_edge = uvm_edge
if hasattr(pdata,'x') and pdata.x is not None: if hasattr(pdata,'x') and pdata.x is not None:
pdata.x = pdata.x.to(torch.float)
if uvm_node == False :
x = pdata.x.to(self.device)
else:
if self.device.type == 'cuda':
x = uvm_empty(*pdata.x.size(),
dtype=pdata.x.dtype,
device=ctx.device)
uvm_share(x,device = ctx.device)
uvm_advise(x,cudaMemoryAdvise.cudaMemAdviseSetAccessedBy)
uvm_prefetch(x)
if world_size > 1: if world_size > 1:
self.x = DistributedTensor(pdata.x.to(self.device).to(torch.float)) self.x = DistributedTensor(pdata.x.to(self.device).to(torch.float))
else: else:
self.x = pdata.x.to(device).to(torch.float) self.x = x
else: else:
self.x = None self.x = None
if hasattr(pdata,'edge_attr') and pdata.edge_attr is not None: if hasattr(pdata,'edge_attr') and pdata.edge_attr is not None:
ctx = DistributedContext.get_default_context()
pdata.edge_attr = pdata.edge_attr.to(torch.float)
if uvm_edge == False :
edge_attr = pdata.edge_attr.to(self.device)
else:
if self.device.type == 'cuda':
edge_attr = uvm_empty(*pdata.edge_attr.size(),
dtype=pdata.edge_attr.dtype,
device=ctx.device)
uvm_share(edge_attr,device = ctx.device)
uvm_advise(edge_attr,cudaMemoryAdvise.cudaMemAdviseSetAccessedBy)
uvm_prefetch(edge_attr)
if world_size > 1: if world_size > 1:
self.edge_attr = DistributedTensor(pdata.edge_attr.to('cpu').to(torch.float)) self.edge_attr = DistributedTensor(edge_attr)
else: else:
self.edge_attr = pdata.edge_attr.to('cpu').to(torch.float) self.edge_attr = edge_attr
else: else:
self.edge_attr = None self.edge_attr = None
def _get_node_attr(self,ids): def _get_node_attr(self,ids,asyncOp = False):
if self.x is None: if self.x is None:
return None return None
elif dist.get_world_size() == 1: elif dist.get_world_size() == 1:
return self.x[ids] return self.x[ids]
else: else:
if self.x.rrefs is None or asyncOp is False:
ids = self.x.all_to_all_ind2ptr(ids)
return self.x.all_to_all_get(**ids)
return self.x.index_select(ids) return self.x.index_select(ids)
def _get_edge_attr(self,ids,):
def _get_edge_attr(self,ids,asyncOp = False):
if self.edge_attr is None: if self.edge_attr is None:
return None return None
elif dist.get_world_size() == 1: elif dist.get_world_size() == 1:
return self.edge_attr[ids] return self.edge_attr[ids]
else: else:
if self.edge_attr.rrefs is None or asyncOp is False:
ids = self.edge_attr.all_to_all_ind2ptr(ids)
return self.edge_attr.all_to_all_get(**ids)
return self.edge_attr.index_select(ids) return self.edge_attr.index_select(ids)
def _get_dist_index(self,ind,mapper):
return mapper[ind.to(mapper.device)]
class DataSet: class DataSet:
def __init__(self,nodes = None, def __init__(self,nodes = None,
...@@ -106,7 +149,7 @@ class DataSet: ...@@ -106,7 +149,7 @@ class DataSet:
setattr(d,k,v[indx]) setattr(d,k,v[indx])
return d return d
class TemporalGraphData(): class TemporalGraphData(DistributedGraphStore):
def __init__(self,pdata,device): def __init__(self,pdata,device):
super(TemporalGraphData,self).__init__(pdata,device) super(TemporalGraphData,self).__init__(pdata,device)
def _set_temporal_batch_cache(self,size,pin_size): def _set_temporal_batch_cache(self,size,pin_size):
...@@ -117,7 +160,7 @@ class TemporalGraphData(): ...@@ -117,7 +160,7 @@ class TemporalGraphData():
class TemporalNeighborSampleGraph(GraphData): class TemporalNeighborSampleGraph(DistributedGraphStore):
def __init__(self, sample_graph=None, mode='full', eids_mapper=None): def __init__(self, sample_graph=None, mode='full', eids_mapper=None):
self.edge_index = sample_graph['edge_index'] self.edge_index = sample_graph['edge_index']
self.num_edges = self.edge_index.shape[1] self.num_edges = self.edge_index.shape[1]
......
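A hedged construction sketch for the renamed DistributedGraphStore; pdata is assumed to be a previously loaded partition, and the UVM flags mirror the new constructor arguments:

import torch
from starrygl.sample.graph_core import DistributedGraphStore

graph = DistributedGraphStore(pdata,
                              device=torch.device('cuda'),
                              uvm_node=False,   # keep node features in ordinary device memory
                              uvm_edge=True)    # place edge features in UVM-managed memory
node_feat = graph._get_node_attr(dist_nid)      # dist_nid: placeholder tensor of distributed node ids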
...@@ -22,4 +22,6 @@ def build_mapper(nids): ...@@ -22,4 +22,6 @@ def build_mapper(nids):
ind_mp[iid] = torch.arange(all_ids[i].shape[0],**ikw) ind_mp[iid] = torch.arange(all_ids[i].shape[0],**ikw)
return DistIndex(ind_mp,part_mp) return DistIndex(ind_mp,part_mp)
\ No newline at end of file def get_validate_graph(self,graph):
pass
\ No newline at end of file
from typing import Union
from typing import List from typing import List
from typing import Optional
import torch import torch
from torch.distributed import rpc from torch.distributed import rpc
import torch_scatter import torch_scatter
...@@ -6,12 +8,15 @@ from starrygl.distributed.context import DistributedContext ...@@ -6,12 +8,15 @@ from starrygl.distributed.context import DistributedContext
from starrygl.distributed.utils import DistIndex, DistributedTensor from starrygl.distributed.utils import DistIndex, DistributedTensor
import torch.distributed as dist import torch.distributed as dist
from starrygl.utils.uvm import cudaMemoryAdvise, uvm_advise, uvm_empty, uvm_prefetch, uvm_share
class SharedMailBox(): class SharedMailBox():
def __init__(self, def __init__(self,
num_nodes, num_nodes,
memory_param, memory_param,
dim_edge_feat, dim_edge_feat,
device = torch.device('cuda')): device = torch.device('cuda'),
uvm = False):
self.device = device self.device = device
self.num_nodes = num_nodes self.num_nodes = num_nodes
self.num_parts = dist.get_world_size() self.num_parts = dist.get_world_size()
...@@ -19,19 +24,41 @@ class SharedMailBox(): ...@@ -19,19 +24,41 @@ class SharedMailBox():
raise NotImplementedError raise NotImplementedError
self.memory_param = memory_param self.memory_param = memory_param
self.memory_size = memory_param['dim_out'] self.memory_size = memory_param['dim_out']
assert not (device.type =='cpu' and uvm is True),\
'set uvm must set device on cuda'
memory_device = device
if device.type == 'cuda' and uvm is True:
memory_device = torch.device('cpu')
node_memory = torch.zeros(( node_memory = torch.zeros((
self.num_nodes, memory_param['dim_out']), self.num_nodes, memory_param['dim_out']),
dtype=torch.float32,device =self.device) dtype=torch.float32,device =memory_device)
node_memory_ts = torch.zeros(self.num_nodes, node_memory_ts = torch.zeros(self.num_nodes,
dtype=torch.float32, dtype=torch.float32,
device = self.device) device = self.device)
mailbox = torch.zeros(self.num_nodes, mailbox = torch.zeros(self.num_nodes,
memory_param['mailbox_size'], memory_param['mailbox_size'],
2 * memory_param['dim_out'] + dim_edge_feat, 2 * memory_param['dim_out'] + dim_edge_feat,
device = self.device, dtype=torch.float32) device = memory_device, dtype=torch.float32)
mailbox_ts = torch.zeros((self.num_nodes, mailbox_ts = torch.zeros((self.num_nodes,
memory_param['mailbox_size']), memory_param['mailbox_size']),
dtype=torch.float32,device = self.device) dtype=torch.float32,device = self.device)
self.uvm = uvm
if uvm is True:
ctx = DistributedContext.get_default_context()
node_memory = uvm_empty(*node_memory.shape,
dtype=node_memory.dtype,
device=ctx.device)
uvm_share(node_memory,device = ctx.device)
uvm_advise(node_memory,cudaMemoryAdvise.cudaMemAdviseSetAccessedBy)
uvm_prefetch(node_memory)
mailbox = uvm_empty(*mailbox.shape,
dtype=mailbox.dtype,
device=ctx.device)
uvm_share(mailbox,device = ctx.device)
uvm_advise(mailbox,cudaMemoryAdvise.cudaMemAdviseSetAccessedBy)
uvm_prefetch(mailbox)
self.node_memory = DistributedTensor(node_memory) self.node_memory = DistributedTensor(node_memory)
self.node_memory_ts = DistributedTensor(node_memory_ts) self.node_memory_ts = DistributedTensor(node_memory_ts)
self.mailbox = DistributedTensor(mailbox) self.mailbox = DistributedTensor(mailbox)
...@@ -40,8 +67,9 @@ class SharedMailBox(): ...@@ -40,8 +67,9 @@ class SharedMailBox():
dtype=torch.long, dtype=torch.long,
device = self.device) device = self.device)
self._ctx = DistributedContext.get_default_context() self._ctx = DistributedContext.get_default_context()
self.rref = rpc.RRef(self) if self._ctx._use_rpc == True:
self.rrefs = self._ctx.all_gather_remote_objects(self.rref) self.rref = rpc.RRef(self)
self.rrefs = self._ctx.all_gather_remote_objects(self.rref)
self.partptr = torch.tensor([ ((i & 0xFFFF)<<48) for i in range(self.num_parts+1) ],device = device) self.partptr = torch.tensor([ ((i & 0xFFFF)<<48) for i in range(self.num_parts+1) ],device = device)
...@@ -118,7 +146,7 @@ class SharedMailBox(): ...@@ -118,7 +146,7 @@ class SharedMailBox():
def set_mailbox_all_to_all(self,index,memory, def set_mailbox_all_to_all(self,index,memory,
memory_ts,mail,mail_ts, memory_ts,mail,mail_ts,
reduce_Op = None): reduce_Op = None,group = None):
#futs: List[torch.futures.Future] = [] #futs: List[torch.futures.Future] = []
if self.num_parts == 1: if self.num_parts == 1:
dist_index = DistIndex(index) dist_index = DistIndex(index)
...@@ -132,7 +160,7 @@ class SharedMailBox(): ...@@ -132,7 +160,7 @@ class SharedMailBox():
device = self.device) device = self.device)
indic = torch.searchsorted(index,self.partptr,right=False) indic = torch.searchsorted(index,self.partptr,right=False)
scatter_len_list = indic[1:] - indic[0:-1] scatter_len_list = indic[1:] - indic[0:-1]
torch.distributed.all_to_all_single(gather_len_list,scatter_len_list) torch.distributed.all_to_all_single(gather_len_list,scatter_len_list,group = group)
input_split = scatter_len_list.tolist() input_split = scatter_len_list.tolist()
output_split = gather_len_list.tolist() output_split = gather_len_list.tolist()
gather_id_list = torch.empty( gather_id_list = torch.empty(
...@@ -143,7 +171,7 @@ class SharedMailBox(): ...@@ -143,7 +171,7 @@ class SharedMailBox():
output_split = gather_len_list.tolist() output_split = gather_len_list.tolist()
torch.distributed.all_to_all_single( torch.distributed.all_to_all_single(
gather_id_list,index,output_split_sizes=output_split, gather_id_list,index,output_split_sizes=output_split,
input_split_sizes=input_split) input_split_sizes=input_split,group = group)
index = gather_id_list index = gather_id_list
gather_memory = torch.empty( gather_memory = torch.empty(
[gather_len_list.sum(),memory.shape[1]], [gather_len_list.sum(),memory.shape[1]],
...@@ -160,22 +188,81 @@ class SharedMailBox(): ...@@ -160,22 +188,81 @@ class SharedMailBox():
torch.distributed.all_to_all_single( torch.distributed.all_to_all_single(
gather_memory,memory, gather_memory,memory,
output_split_sizes=output_split, output_split_sizes=output_split,
input_split_sizes=input_split) input_split_sizes=input_split,group = group)
torch.distributed.all_to_all_single( torch.distributed.all_to_all_single(
gather_memory_ts,memory_ts, gather_memory_ts,memory_ts,
output_split_sizes=output_split, output_split_sizes=output_split,
input_split_sizes=input_split) input_split_sizes=input_split,group = group)
torch.distributed.all_to_all_single( torch.distributed.all_to_all_single(
gather_mail,mail, gather_mail,mail,
output_split_sizes=output_split, output_split_sizes=output_split,
input_split_sizes=input_split) input_split_sizes=input_split,group = group)
torch.distributed.all_to_all_single( torch.distributed.all_to_all_single(
gather_mail_ts,mail_ts, gather_mail_ts,mail_ts,
output_split_sizes=output_split, output_split_sizes=output_split,
input_split_sizes=input_split) input_split_sizes=input_split,group = group)
self.set_mailbox_local(DistIndex(index).loc,gather_mail,gather_mail_ts,Reduce_Op = reduce_Op) self.set_mailbox_local(DistIndex(index).loc,gather_mail,gather_mail_ts,Reduce_Op = reduce_Op)
self.set_memory_local(DistIndex(index).loc,gather_memory,gather_memory_ts, Reduce_Op = reduce_Op) self.set_memory_local(DistIndex(index).loc,gather_memory,gather_memory_ts, Reduce_Op = reduce_Op)
def set_mailbox_all_to_all(self,index,memory,
memory_ts,mail,mail_ts,
reduce_Op = None,group = None):
#futs: List[torch.futures.Future] = []
if self.num_parts == 1:
dist_index = DistIndex(index)
part_idx = dist_index.part
index = dist_index.loc
self.set_mailbox_local(index,mail,mail_ts)
self.set_memory_local(index,memory,memory_ts)
else:
gather_len_list = torch.empty([self.num_parts],
dtype = int,
device = self.device)
indic = torch.searchsorted(index,self.partptr,right=False)
scatter_len_list = indic[1:] - indic[0:-1]
torch.distributed.all_to_all_single(gather_len_list,scatter_len_list,group = group)
input_split = scatter_len_list.tolist()
output_split = gather_len_list.tolist()
gather_id_list = torch.empty(
[gather_len_list.sum()],
dtype = torch.long,
device = self.device)
input_split = scatter_len_list.tolist()
output_split = gather_len_list.tolist()
torch.distributed.all_to_all_single(
gather_id_list,index,output_split_sizes=output_split,
input_split_sizes=input_split,group = group)
index = gather_id_list
gather_memory = torch.empty(
[gather_len_list.sum(),memory.shape[1]],
dtype = memory.dtype,device = self.device)
gather_memory_ts = torch.empty(
[gather_len_list.sum()],
dtype = memory_ts.dtype,device = self.device)
gather_mail = torch.empty(
[gather_len_list.sum(),mail.shape[1]],
dtype = mail.dtype,device = self.device)
gather_mail_ts = torch.empty(
[gather_len_list.sum()],
dtype = mail_ts.dtype,device = self.device)
torch.distributed.all_to_all_single(
gather_memory,memory,
output_split_sizes=output_split,
input_split_sizes=input_split,group = group)
torch.distributed.all_to_all_single(
gather_memory_ts,memory_ts,
output_split_sizes=output_split,
input_split_sizes=input_split,group = group)
torch.distributed.all_to_all_single(
gather_mail,mail,
output_split_sizes=output_split,
input_split_sizes=input_split,group = group)
torch.distributed.all_to_all_single(
gather_mail_ts,mail_ts,
output_split_sizes=output_split,
input_split_sizes=input_split,group = group)
self.set_mailbox_local(DistIndex(index).loc,gather_mail,gather_mail_ts,Reduce_Op = reduce_Op)
self.set_memory_local(DistIndex(index).loc,gather_memory,gather_memory_ts, Reduce_Op = reduce_Op)
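The exchange above hinges on how the per-partition split sizes are derived: partptr holds the smallest dist-index owned by each partition (partition id in the top 16 bits), so torch.searchsorted over a sorted index tensor yields the partition boundaries, and their differences are the send counts; each rank then exchanges those counts once with all_to_all_single so it knows the receive sizes before moving the payload tensors. A minimal single-process sketch of just that step (not library code; the bit layout is inferred from the partptr construction above):

    import torch

    num_parts = 4
    partptr = torch.tensor([(i & 0xFFFF) << 48 for i in range(num_parts + 1)])

    # fake dist indices: owning partition in bits 48..63, local id in the low bits
    parts = torch.tensor([0, 0, 1, 1, 1, 3])
    local = torch.tensor([5, 9, 0, 2, 7, 1])
    index = (parts << 48) | local            # sorted, as searchsorted requires
    indic = torch.searchsorted(index, partptr, right=False)
    scatter_len = indic[1:] - indic[:-1]     # entries destined for each partition
    print(scatter_len.tolist())              # [2, 3, 0, 1]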
def get_update_mail(self,dist_indx_mapper, def get_update_mail(self,dist_indx_mapper,
src,dst,ts,edge_feats, src,dst,ts,edge_feats,
...@@ -220,6 +307,8 @@ class SharedMailBox(): ...@@ -220,6 +307,8 @@ class SharedMailBox():
self.node_memory_ts.accessor.data[index],\ self.node_memory_ts.accessor.data[index],\
self.mailbox.accessor.data[index],\ self.mailbox.accessor.data[index],\
self.mailbox_ts.accessor.data[index] self.mailbox_ts.accessor.data[index]
elif self.node_memory.rrefs is None:
return self.gather_memory(dist_index = index)
else: else:
memory = self.node_memory.index_select(index) memory = self.node_memory.index_select(index)
memory_ts = self.node_memory_ts.index_select(index) memory_ts = self.node_memory_ts.index_select(index)
...@@ -234,9 +323,25 @@ class SharedMailBox(): ...@@ -234,9 +323,25 @@ class SharedMailBox():
#print(memory.shape[0]) #print(memory.shape[0])
return memory,memory_ts,mail,mail_ts return memory,memory_ts,mail,mail_ts
return torch.futures.collect_all([memory,memory_ts,mail,mail_ts]).then(callback) return torch.futures.collect_all([memory,memory_ts,mail,mail_ts]).then(callback)
def gather_memory(self,index,input_split,out_split):
return self.node_memory.scatter_data(index,input_split,out_split),\ def gather_memory(
self.node_memory_ts.scatter_data(index,input_split,out_split),\ self,
self.mailbox.scatter_data(index,input_split,out_split),\ dist_index: Union[torch.Tensor, DistIndex, None] = None,
self.mailbox_ts.scatter_data(index,input_split,out_split) send_ptr: Optional[List[int]] = None,
recv_ptr: Optional[List[int]] = None,
recv_ind: Optional[List[int]] = None,
group = None
):
if dist_index is None:
return self.node_memory.all_to_all_get(dist_index,send_ptr,recv_ptr,recv_ind,group),\
self.node_memory_ts.all_to_all_get(dist_index,send_ptr,recv_ptr,recv_ind,group),\
self.mailbox.all_to_all_get(dist_index,send_ptr,recv_ptr,recv_ind,group),\
self.mailbox_ts.all_to_all_get(dist_index,send_ptr,recv_ptr,recv_ind,group)
else:
ids = self.node_memory.all_to_all_ind2ptr(dist_index)
return self.node_memory.all_to_all_get(**ids,group = group),\
self.node_memory_ts.all_to_all_get(**ids,group = group),\
self.mailbox.all_to_all_get(**ids,group = group),\
self.mailbox_ts.all_to_all_get(**ids,group = group)
...@@ -2,15 +2,16 @@ import parser ...@@ -2,15 +2,16 @@ import parser
from torch_sparse import SparseTensor from torch_sparse import SparseTensor
from torch_geometric.data import Data from torch_geometric.data import Data
from torch_geometric.utils import degree from torch_geometric.utils import degree
from .torch_utils import get_norm_temporal
import os.path as osp import os.path as osp
import os import os
import shutil import shutil
import torch import torch
import torch.utils.data import torch.utils.data
import metis
import networkx as nx
import torch.distributed as dist import torch.distributed as dist
from starrygl.lib.libstarrygl_sampler import get_norm_temporal
from starrygl.utils.partition import mt_metis_partition
def partition_load(root: str, algo: str = "metis") -> Data: def partition_load(root: str, algo: str = "metis") -> Data:
...@@ -22,6 +23,7 @@ def partition_load(root: str, algo: str = "metis") -> Data: ...@@ -22,6 +23,7 @@ def partition_load(root: str, algo: str = "metis") -> Data:
def partition_save(root: str, data: Data, num_parts: int, def partition_save(root: str, data: Data, num_parts: int,
algo: str = "metis", algo: str = "metis",
node_weight = None,
edge_weight_dict=None): edge_weight_dict=None):
root = osp.abspath(root) root = osp.abspath(root)
if osp.exists(root) and not osp.isdir(root): if osp.exists(root) and not osp.isdir(root):
...@@ -46,6 +48,7 @@ def partition_save(root: str, data: Data, num_parts: int, ...@@ -46,6 +48,7 @@ def partition_save(root: str, data: Data, num_parts: int,
if algo == 'metis_for_tgnn': if algo == 'metis_for_tgnn':
for i, pdata in enumerate(partition_data_for_tgnn( for i, pdata in enumerate(partition_data_for_tgnn(
data, num_parts, algo, verbose=True, data, num_parts, algo, verbose=True,
node_weight = node_weight,
edge_weight_dict=edge_weight_dict)): edge_weight_dict=edge_weight_dict)):
print(f"saving partition data: {i+1}/{num_parts}") print(f"saving partition data: {i+1}/{num_parts}")
fn = osp.join(path, f"{i:03d}") fn = osp.join(path, f"{i:03d}")
...@@ -153,30 +156,41 @@ def _nopart(edge_index: torch.LongTensor, num_nodes: int): ...@@ -153,30 +156,41 @@ def _nopart(edge_index: torch.LongTensor, num_nodes: int):
def metis_for_tgnn(edge_index_dict: dict, def metis_for_tgnn(edge_index_dict: dict,
num_nodes: int, num_nodes: int,
num_parts: int, num_parts: int,
node_weight = None,
edge_weight_dict=None): edge_weight_dict=None):
if num_parts <= 1: if num_parts <= 1:
return _nopart(edge_index_dict, num_nodes) return _nopart(edge_index_dict, num_nodes)
G = nx.Graph() edge_list = []
G.add_nodes_from(torch.arange(0, num_nodes).tolist()) weight_list = []
value, counts = torch.unique(edge_index_dict['edata'][1, :].view(-1), for i,key in enumerate(edge_index_dict):
return_counts=True)
nodes = torch.tensor(list(G.adj.keys()))
for i in range(value.shape[0]):
if (value[i].item() in G.nodes):
G.nodes[int(value[i].item())]['weight'] = counts[i]
G.nodes[int(value[i].item())]['ones'] = 1
G.graph['node_weight_attr'] = ['weight', 'ones']
for i, key in enumerate(edge_index_dict):
v = edge_index_dict[key] v = edge_index_dict[key]
edges = torch.cat((v, (torch.ones(v.shape[1], dtype=torch.long) * edge_list.append(v)
edge_weight_dict[key]).unsqueeze(0)), dim=0) weight_list.append(torch.ones(v.shape[1])*edge_weight_dict[key])
# w = edges.T edge_index = torch.cat(edge_list,dim = 1)
G.add_weighted_edges_from((edges.T).tolist()) edge_weight = torch.cat(weight_list,dim = 0)
G.graph['edge_weight_attr'] = 'weight' node_parts = mt_metis_partition(edge_index,num_nodes,num_parts,node_weight,edge_weight)
cuts, part = metis.part_graph(G, num_parts)
node_parts = torch.zeros(num_nodes, dtype=torch.long)
node_parts[nodes] = torch.tensor(part)
return node_parts return node_parts
#G = nx.Graph()
#G.add_nodes_from(torch.arange(0, num_nodes).tolist())
#value, counts = torch.unique(edge_index_dict['edata'][1, :].view(-1),
# return_counts=True)
#nodes = torch.tensor(list(G.adj.keys()))
#for i in range(value.shape[0]):
# if (value[i].item() in G.nodes):
# G.nodes[int(value[i].item())]['weight'] = counts[i]
# G.nodes[int(value[i].item())]['ones'] = 1
#G.graph['node_weight_attr'] = ['weight', 'ones']
#for i, key in enumerate(edge_index_dict):
# v = edge_index_dict[key]
# edges = torch.cat((v, (torch.ones(v.shape[1], dtype=torch.long) *
# edge_weight_dict[key]).unsqueeze(0)), dim=0)
# # w = edges.T
# G.add_weighted_edges_from((edges.T).tolist())
#G.graph['edge_weight_attr'] = 'weight'
#cuts, part = metis.part_graph(G, num_parts)
#node_parts = torch.zeros(num_nodes, dtype=torch.long)
#node_parts[nodes] = torch.tensor(part)
#return node_parts
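For reference, a small runnable sketch of the tensors the rewritten metis_for_tgnn now hands to mt_metis_partition: every edge type in edge_index_dict is concatenated into one edge_index, and each type's scalar weight from edge_weight_dict is broadcast over its edges. The dictionary keys and weights below are placeholders, and the commented call only mirrors the call site above; the real signature of mt_metis_partition is not shown in this diff.

    import torch

    edge_index_dict = {
        'pos_edge': torch.tensor([[0, 1, 2], [1, 2, 3]]),
        'edata':    torch.tensor([[3, 0], [0, 2]]),
    }
    edge_weight_dict = {'pos_edge': 1, 'edata': 4}
    node_weight = None                        # optional per-node weights, as above

    edge_list, weight_list = [], []
    for key, v in edge_index_dict.items():
        edge_list.append(v)
        weight_list.append(torch.ones(v.shape[1]) * edge_weight_dict[key])

    edge_index = torch.cat(edge_list, dim=1)   # shape [2, 5]
    edge_weight = torch.cat(weight_list, dim=0)
    print(edge_index.shape, edge_weight.tolist())
    # node_parts = mt_metis_partition(edge_index, 4, 2, node_weight, edge_weight)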
""" """
...@@ -187,6 +201,7 @@ weight: edge-partition weights for the different workloads ...@@ -187,6 +201,7 @@ weight: edge-partition weights for the different workloads
def partition_data_for_tgnn(data: Data, num_parts: int, algo: str, def partition_data_for_tgnn(data: Data, num_parts: int, algo: str,
verbose: bool = False, verbose: bool = False,
node_weight: torch.Tensor = None,
edge_weight_dict: dict = None): edge_weight_dict: dict = None):
if algo == "metis_for_tgnn": if algo == "metis_for_tgnn":
part_fn = metis_for_tgnn part_fn = metis_for_tgnn
...@@ -200,6 +215,7 @@ def partition_data_for_tgnn(data: Data, num_parts: int, algo: str, ...@@ -200,6 +215,7 @@ def partition_data_for_tgnn(data: Data, num_parts: int, algo: str,
if verbose: if verbose:
print(f"running partition algorithm: {algo}") print(f"running partition algorithm: {algo}")
node_parts = part_fn(edge_index_dict, num_nodes, num_parts, node_parts = part_fn(edge_index_dict, num_nodes, num_parts,
node_weight,
edge_weight_dict) edge_weight_dict)
edge_parts = node_parts[data.edge_index[1, :]] edge_parts = node_parts[data.edge_index[1, :]]
eids = torch.arange(num_edges, dtype=torch.long) eids = torch.arange(num_edges, dtype=torch.long)
......
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CppExtension
setup(
name='torch_utils',
ext_modules=[
CppExtension(
name='torch_utils',
sources=['./torch_utils.cpp'],
extra_compile_args=['-fopenmp','-Xlinker',' -export-dynamic'],
include_dirs=["../sample_core"],
),
],
cmdclass={
'build_ext': BuildExtension
})#
#setup(
# name='cpu_cache_manager',
# ext_modules=[
# CppExtension(
# name='cpu_cache_manager',
# sources=['cpu_cache_manager.cpp'],
# extra_compile_args=['-fopenmp','-Xlinker',' -export-dynamic'],
# include_dirs=["./"],
# ),
# ],
# cmdclass={
# 'build_ext': BuildExtension
# })#
#
\ No newline at end of file
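A quick way to check that the torch_utils extension defined by this setup.py built correctly. The exported symbols depend on torch_utils.cpp, which is not part of this diff, so the probe below only lists whatever the module actually exposes; it is not an API reference.

    # Build first with `python setup.py build_ext --inplace` (or pip install -e .)
    try:
        import torch_utils  # module name taken from the setup() call above
        print([n for n in dir(torch_utils) if not n.startswith('_')])
    except ImportError as err:
        print("torch_utils extension not built yet:", err)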
from concurrent.futures import ThreadPoolExecutor, thread
from multiprocessing import Process
from multiprocessing.pool import ThreadPool
from typing import Deque
import torch import torch
class WorkStreamEvent: import torch.distributed as dist
def __init__(self): stage = ['train_stream','write_memory','write_mail','lookup']
self.train_stream = torch.cuda.Stream()
self.write_memory_stream = torch.cuda.Stream()
self.fetch_stream = torch.cuda.Stream() class PipelineManager:
self.write_mail_stream = torch.cuda.Stream() def __init__(self,num_tasks = 10):
self.event = None self.stream_set = {}
event = WorkStreamEvent() self.dist_set = {}
def get_event(): self.args_queue = {}
return event.event self.thread_pool = ThreadPoolExecutor(num_tasks)
for k in stage:
self.stream_set[k] = torch.cuda.Stream()
self.dist_set[k] = dist.new_group()
self.args_queue[k] = Deque()
def submit(self,state,func,kwargs):
future = self.thread_pool.submit(self.run, state,func,kwargs)
return future
def run(self,state,func,kwargs):
with torch.cuda.stream(self.stream_set[state]):
return func(**kwargs,group = self.dist_set[state])
manger = None
def getPipelineManger():
global manger
if manger is None:
manger = PipelineManager()
return manger
\ No newline at end of file
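A guarded usage sketch for the PipelineManager above: submit() runs the callable inside the stage's CUDA stream and passes that stage's process group as group, which matches how set_mailbox_all_to_all now accepts a group argument. The stub function and tensor are placeholders; the sketch only runs when a process group is initialized and a GPU is available.

    import torch
    import torch.distributed as dist

    def write_memory_stub(payload, group=None):
        # placeholder for e.g. mailbox.set_mailbox_all_to_all(..., group=group)
        return payload * 2

    if dist.is_initialized() and torch.cuda.is_available():
        from starrygl.sample.stream_manager import getPipelineManger
        manager = getPipelineManger()
        fut = manager.submit('write_memory', write_memory_stub,
                             {'payload': torch.ones(4, device='cuda')})
        print(fut.result())
    else:
        print("needs an initialized process group and CUDA; skipping")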
from concurrent.futures import ThreadPoolExecutor
def my_function():
    # do some work...
    print('hi')
    return 'result'
# create the thread pool
executor = ThreadPoolExecutor()
# submit the task and get a Future object
future = executor.submit(my_function)
# block until the task finishes and fetch its result
result = future.result()
# print the result
print(result) # 'result'
nohup python train_tgnn.py --dataname TaoBao --world_size 4 --rank 0 > taobao4.out &
nohup python train_tgnn.py --dataname TaoBao --world_size 4 --rank 1 > taobao4.out &
nohup python train_tgnn.py --dataname TaoBao --world_size 4 --rank 2 > taobao4.out &
nohup python train_tgnn.py --dataname TaoBao --world_size 4 --rank 3 > taobao4.out &
nohup python train_tgnn.py --dataname GDELT --world_size 4 --rank 0 > GDELT4.out &
nohup python train_tgnn.py --dataname GDELT --world_size 4 --rank 1 > GDELT4.out &
nohup python train_tgnn.py --dataname GDELT --world_size 4 --rank 2 > GDELT4.out &
nohup python train_tgnn.py --dataname GDELT --world_size 4 --rank 3 > GDELT4.out &
nohup python train_tgnn.py --dataname ML25M --world_size 2 --rank 0 > ML25M2.out &
nohup python train_tgnn.py --dataname ML25M --world_size 2 --rank 1 > ML25M2.out &
nohup python train_tgnn.py --dataname TaoBao --world_size 2 --rank 0 > TaoBao2.out &
nohup python train_tgnn.py --dataname TaoBao --world_size 2 --rank 1 > TaoBao2.out &
nohup python train_tgnn.py --dataname ML25M --world_size 1 --rank 0 > ML25M1.out &
nohup python train_tgnn.py --dataname TaoBao --world_size 1 --rank 0 > TaoBao1.out &
nohup python train_tgnn.py --dataname GDELT --world_size 1 --rank 0 > GDELT1.out &
\ No newline at end of file
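Note that each multi-rank run above redirects every rank to the same .out file, and the training script in this commit reads LOCAL_RANK, MASTER_ADDR and MASTER_PORT from the environment instead of setting them itself. A hypothetical launcher sketch (not part of the repo) that gives each rank its own log and environment; the script name and CLI flags are copied from the commands above, everything else is an assumption.

    import os
    import subprocess

    world_size = 4
    procs = []
    for rank in range(world_size):
        env = dict(os.environ,
                   RANK=str(rank), WORLD_SIZE=str(world_size), LOCAL_RANK=str(rank),
                   MASTER_ADDR="127.0.0.1", MASTER_PORT="9337")
        log = open(f"taobao{world_size}_rank{rank}.out", "w")
        procs.append((subprocess.Popen(
            ["python", "train_tgnn.py", "--dataname", "TaoBao",
             "--world_size", str(world_size), "--rank", str(rank)],
            env=env, stdout=log, stderr=subprocess.STDOUT), log))
    for p, log in procs:
        p.wait()
        log.close()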
...@@ -7,7 +7,8 @@ from starrygl.distributed.utils import DistIndex ...@@ -7,7 +7,8 @@ from starrygl.distributed.utils import DistIndex
from starrygl.module.modules import GeneralModel from starrygl.module.modules import GeneralModel
from starrygl.module.utils import parse_config from starrygl.module.utils import parse_config
from starrygl.sample.graph_core import DataSet, GraphData, TemporalNeighborSampleGraph from starrygl.sample.cache.fetch_cache import FetchFeatureCache
from starrygl.sample.graph_core import DataSet, DistributedGraphStore, TemporalNeighborSampleGraph
from starrygl.sample.memory.shared_mailbox import SharedMailBox from starrygl.sample.memory.shared_mailbox import SharedMailBox
from starrygl.sample.sample_core.base import NegativeSampling from starrygl.sample.sample_core.base import NegativeSampling
from starrygl.sample.sample_core.neighbor_sampler import NeighborSampler from starrygl.sample.sample_core.neighbor_sampler import NeighborSampler
...@@ -25,32 +26,18 @@ import os ...@@ -25,32 +26,18 @@ import os
from starrygl.sample.data_loader import DistributedDataLoader from starrygl.sample.data_loader import DistributedDataLoader
from starrygl.sample.batch_data import SAMPLE_TYPE from starrygl.sample.batch_data import SAMPLE_TYPE
""" from starrygl.sample.stream_manager import getPipelineManger
test command
python test.py --world_size 2 --rank 0
--world_size', default=4, type=int, metavar='W',
help='number of workers')
parser.add_argument('--rank', default=0, type=int, metavar='W',
help='rank of the worker')
parser.add_argument('--log_interval', type=int, default=10, metavar='N',
help='interval between training status logs')
parser.add_argument('--gamma', type=float, default=0.99, metavar='G',
help='how much to value future rewards')
parser.add_argument('--seed', type=int, default=1, metavar='S',
help='random seed for reproducibility')
parser.add_argument('--num_sampler', type=int, default=10, metavar='S',
help='number of samplers')
parser.add_argument('--queue_size', type=int, default=10, metavar='S',
help='sampler queue size')
"""
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description="RPC Reinforcement Learning Example", description="RPC Reinforcement Learning Example",
formatter_class=argparse.ArgumentDefaultsHelpFormatter, formatter_class=argparse.ArgumentDefaultsHelpFormatter,
) )
parser.add_argument('--rank', default=0, type=str, metavar='W', parser.add_argument('--rank', default=0, type=int, metavar='W',
help='rank of the worker') help='rank of the worker')
parser.add_argument('--world_size', default=1, type=int, metavar='W', parser.add_argument('--world_size', default=1, type=int, metavar='W',
help='number of workers') help='number of workers')
parser.add_argument('--dataname', default=1, type=str, metavar='W',
help='name of the dataset')
args = parser.parse_args() args = parser.parse_args()
from sklearn.metrics import average_precision_score, roc_auc_score from sklearn.metrics import average_precision_score, roc_auc_score
import torch import torch
...@@ -60,11 +47,12 @@ import dgl ...@@ -60,11 +47,12 @@ import dgl
import numpy as np import numpy as np
from sklearn.metrics import average_precision_score, roc_auc_score from sklearn.metrics import average_precision_score, roc_auc_score
from torch.nn.parallel import DistributedDataParallel as DDP from torch.nn.parallel import DistributedDataParallel as DDP
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.rank) #os.environ['CUDA_VISIBLE_DEVICES'] = str(args.rank)
os.environ["RANK"] = str(args.rank) #os.environ["RANK"] = str(args.rank)
os.environ["WORLD_SIZE"] = str(args.world_size) #os.environ["WORLD_SIZE"] = str(args.world_size)
os.environ["LOCAL_RANK"] = str(0) #os.environ["LOCAL_RANK"] = str(0)
os.environ["MASTER_ADDR"] = '127.0.0.1' torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
os.environ["MASTER_ADDR"] = '10.214.211.187'
os.environ["MASTER_PORT"] = '9337' os.environ["MASTER_PORT"] = '9337'
def seed_everything(seed=42): def seed_everything(seed=42):
random.seed(seed) random.seed(seed)
...@@ -76,18 +64,19 @@ def seed_everything(seed=42): ...@@ -76,18 +64,19 @@ def seed_everything(seed=42):
seed_everything(1234) seed_everything(1234)
def main(): def main():
print('main')
use_cuda = True use_cuda = True
sample_param, memory_param, gnn_param, train_param = parse_config('./config/TGN.yml') sample_param, memory_param, gnn_param, train_param = parse_config('./config/TGN.yml')
torch.set_num_threads(12) torch.set_num_threads(12)
ctx = DistributedContext.init(backend="nccl", use_gpu=True) ctx = DistributedContext.init(backend="nccl", use_gpu=True)
device_id = torch.cuda.current_device() device_id = torch.cuda.current_device()
print('use cuda on',device_id) print('use cuda on',device_id)
pdata = partition_load("./dataset/here/GDELT", algo="metis_for_tgnn") pdata = partition_load("/mnt/data/part_data/dataset/here/{}".format(args.dataname), algo="metis_for_tgnn")
graph = GraphData(pdata = pdata) graph = DistributedGraphStore(pdata = pdata,uvm_edge = False,uvm_node = False)
sample_graph = TemporalNeighborSampleGraph(sample_graph = pdata.sample_graph,mode = 'full') sample_graph = TemporalNeighborSampleGraph(sample_graph = pdata.sample_graph,mode = 'full')
mailbox = SharedMailBox(pdata.ids.shape[0], memory_param, dim_edge_feat = pdata.edge_attr.shape[1] if pdata.edge_attr is not None else 0) mailbox = SharedMailBox(pdata.ids.shape[0], memory_param, dim_edge_feat = pdata.edge_attr.shape[1] if pdata.edge_attr is not None else 0)
sampler = NeighborSampler(num_nodes=graph.num_nodes, num_layers=1, fanout=[10],graph_data=sample_graph, workers=10,policy = 'recent',graph_name = "wiki_train") sampler = NeighborSampler(num_nodes=graph.num_nodes, num_layers=1, fanout=[10],graph_data=sample_graph, workers=15,policy = 'recent',graph_name = "wiki_train")
train_data = torch.masked_select(graph.edge_index,pdata.train_mask.to(graph.edge_index.device)).reshape(2,-1) train_data = torch.masked_select(graph.edge_index,pdata.train_mask.to(graph.edge_index.device)).reshape(2,-1)
train_ts = torch.masked_select(graph.edge_ts,pdata.train_mask.to(graph.edge_index.device)) train_ts = torch.masked_select(graph.edge_ts,pdata.train_mask.to(graph.edge_index.device))
val_data = torch.masked_select(graph.edge_index,pdata.val_mask.to(graph.edge_index.device)).reshape(2,-1) val_data = torch.masked_select(graph.edge_index,pdata.val_mask.to(graph.edge_index.device)).reshape(2,-1)
...@@ -96,8 +85,16 @@ def main(): ...@@ -96,8 +85,16 @@ def main():
test_ts = torch.masked_select(graph.edge_ts,pdata.test_mask.to(graph.edge_index.device)) test_ts = torch.masked_select(graph.edge_ts,pdata.test_mask.to(graph.edge_index.device))
print(train_data.shape[1],val_data.shape[1],test_data.shape[1]) print(train_data.shape[1],val_data.shape[1],test_data.shape[1])
train_data = DataSet(edges = train_data,ts =train_ts,eids = torch.nonzero(pdata.train_mask).view(-1)) train_data = DataSet(edges = train_data,ts =train_ts,eids = torch.nonzero(pdata.train_mask).view(-1))
#if dist.get_rank() == 0:
test_data = DataSet(edges = test_data,ts =test_ts,eids = torch.nonzero(pdata.test_mask).view(-1)) test_data = DataSet(edges = test_data,ts =test_ts,eids = torch.nonzero(pdata.test_mask).view(-1))
val_data = DataSet(edges = val_data,ts = val_ts,eids = torch.nonzero(pdata.val_mask).view(-1)) val_data = DataSet(edges = val_data,ts = val_ts,eids = torch.nonzero(pdata.val_mask).view(-1))
#else:
#test_data = torch.tensor([[],[]],device = graph.edge_index.device,detype = graph.edge_index.#dtype)
#val_data = torch.tensor([[],[]],device = graph.edge_index.device,detype = graph.edge_index.dtype)
#test_ts = torch.tensor([[],[]],device = graph.ts.device,detype = graph.ts.dtype)
#val_ts = torch.tensor([[],[]],device = graph.ts.device,detype = graph.ts.dtype)
#test_data = DataSet(edges = test_data,ts =test_ts,eids = torch.tensor([],dtype = torch.long,#device = torch.cuda))
#val_data = DataSet(edges = val_data,ts = val_ts,eids = torch.tensor([],dtype = torch.long,device #= torch.cuda))
#train_neg_sampler = PreNegativeSampling('triplet',torch.masked_select(pdata.edge_index['pos_edge'],graph.data.train_mask).reshape(2,-1)) #train_neg_sampler = PreNegativeSampling('triplet',torch.masked_select(pdata.edge_index['pos_edge'],graph.data.train_mask).reshape(2,-1))
neg_sampler = NegativeSampling('triplet') neg_sampler = NegativeSampling('triplet')
trainloader = DistributedDataLoader(graph,train_data,sampler = sampler, trainloader = DistributedDataLoader(graph,train_data,sampler = sampler,
...@@ -109,7 +106,8 @@ def main(): ...@@ -109,7 +106,8 @@ def main():
chunk_size = None, chunk_size = None,
train=True, train=True,
queue_size = 1000, queue_size = 1000,
mailbox = mailbox) mailbox = mailbox,
)
testloader = DistributedDataLoader(graph,test_data,sampler = sampler, testloader = DistributedDataLoader(graph,test_data,sampler = sampler,
sampler_fn = SAMPLE_TYPE.SAMPLE_FROM_TEMPORAL_EDGES, sampler_fn = SAMPLE_TYPE.SAMPLE_FROM_TEMPORAL_EDGES,
neg_sampler=neg_sampler, neg_sampler=neg_sampler,
...@@ -130,8 +128,12 @@ def main(): ...@@ -130,8 +128,12 @@ def main():
train=False, train=False,
queue_size = 100, queue_size = 100,
mailbox = mailbox) mailbox = mailbox)
#FetchFeatureCache.create_fetch_cache(graph.num_nodes,graph.eids_mapper.shape[0],0.1,0.1,graph,mailbox,policy = 'static')
#cache = FetchFeatureCache.getFetchCache()
#cache.init_cache_with_presample(trainloader,3)
gnn_dim_node = 0 if graph.x is None else pdata.x.shape[1] gnn_dim_node = 0 if graph.x is None else pdata.x.shape[1]
gnn_dim_edge = 0 if graph.edge_attr is None else pdata.edge_attr.shape[1] gnn_dim_edge = 0 if graph.edge_attr is None else pdata.edge_attr.shape[1]
print(gnn_dim_node,gnn_dim_edge)
avg_time = 0 avg_time = 0
if use_cuda: if use_cuda:
model = GeneralModel(gnn_dim_node, gnn_dim_edge, sample_param, memory_param, gnn_param, train_param).cuda() model = GeneralModel(gnn_dim_node, gnn_dim_edge, sample_param, memory_param, gnn_param, train_param).cuda()
...@@ -159,40 +161,42 @@ def main(): ...@@ -159,40 +161,42 @@ def main():
with torch.no_grad(): with torch.no_grad():
total_loss = 0 total_loss = 0
signal = torch.tensor([0],dtype = int,device = device) signal = torch.tensor([0],dtype = int,device = device)
with torch.cuda.stream(train_stream):
for roots,mfgs,metadata in loader: for roots,mfgs,metadata in loader:
pred_pos, pred_neg = model(mfgs,metadata) pred_pos, pred_neg = model(mfgs,metadata)
total_loss += creterion(pred_pos, torch.ones_like(pred_pos)) total_loss += creterion(pred_pos, torch.ones_like(pred_pos))
total_loss += creterion(pred_neg, torch.zeros_like(pred_neg)) total_loss += creterion(pred_neg, torch.zeros_like(pred_neg))
y_pred = torch.cat([pred_pos, pred_neg], dim=0).sigmoid().cpu() y_pred = torch.cat([pred_pos, pred_neg], dim=0).sigmoid().cpu()
y_true = torch.cat([torch.ones(pred_pos.size(0)), torch.zeros(pred_neg.size(0))], dim=0) y_true = torch.cat([torch.ones(pred_pos.size(0)), torch.zeros(pred_neg.size(0))], dim=0)
aps.append(average_precision_score(y_true, y_pred.detach().numpy())) aps.append(average_precision_score(y_true, y_pred.detach().numpy()))
aucs_mrrs.append(roc_auc_score(y_true, y_pred)) aucs_mrrs.append(roc_auc_score(y_true, y_pred))
if mailbox is not None: if mailbox is not None:
src = metadata['src_pos_index'] src = metadata['src_pos_index']
dst = metadata['dst_pos_index'] dst = metadata['dst_pos_index']
ts = roots.ts ts = roots.ts
if(graph.edge_attr.device == torch.device('cpu')): if graph.edge_attr is None:
edge_feats = graph.edge_attr[roots.eids.to('cpu')].to('cuda') if graph.edge_attr is not None else None edge_feats = None
else: elif(graph.edge_attr.device == torch.device('cpu')):
edge_feats = graph.edge_attr[roots.eids] if graph.edge_attr is not None else None edge_feats = graph.edge_attr[roots.eids.to('cpu')].to('cuda')
dist_index_mapper = mfgs[0][0].srcdata['ID'] else:
root_index = torch.cat((src,dst)) edge_feats = graph.edge_attr[roots.eids]
last_updated_nid = model.module.memory_updater.last_updated_nid[root_index] dist_index_mapper = mfgs[0][0].srcdata['ID']
last_updated_memory = model.module.memory_updater.last_updated_memory[root_index] root_index = torch.cat((src,dst))
last_updated_ts=model.module.memory_updater.last_updated_ts[root_index] last_updated_nid = model.module.memory_updater.last_updated_nid[root_index]
index, memory, memory_ts = mailbox.get_update_memory(last_updated_nid, last_updated_memory = model.module.memory_updater.last_updated_memory[root_index]
last_updated_memory, last_updated_ts=model.module.memory_updater.last_updated_ts[root_index]
last_updated_ts) index, memory, memory_ts = mailbox.get_update_memory(last_updated_nid,
last_updated_memory,
index, mail, mail_ts = mailbox.get_update_mail(dist_index_mapper, last_updated_ts)
src,dst,ts,edge_feats, #
model.module.memory_updater.last_updated_memory, index, mail, mail_ts = mailbox.get_update_mail(dist_index_mapper,
) src,dst,ts,edge_feats,
mailbox.set_mailbox_all_to_all(index,memory,memory_ts,mail,mail_ts,reduce_Op = 'max') model.module.memory_updater.last_updated_memory,
)
mailbox.set_mailbox_all_to_all(index,memory,memory_ts,mail,mail_ts,reduce_Op = 'max')
#ap = float(torch.tensor(aps).mean()) #ap = float(torch.tensor(aps).mean())
#if neg_samples > 1: #if neg_samples > 1:
...@@ -207,7 +211,6 @@ def main(): ...@@ -207,7 +211,6 @@ def main():
ap = float(torch.tensor(apc).mean()) ap = float(torch.tensor(apc).mean())
auc_mrr = float(torch.tensor(auc_mrr).mean()) auc_mrr = float(torch.tensor(auc_mrr).mean())
return ap, auc_mrr return ap, auc_mrr
creterion = torch.nn.BCEWithLogitsLoss() creterion = torch.nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=train_param['lr']) optimizer = torch.optim.Adam(model.parameters(), lr=train_param['lr'])
...@@ -242,15 +245,20 @@ def main(): ...@@ -242,15 +245,20 @@ def main():
y_pred = torch.cat([pred_pos, pred_neg], dim=0).sigmoid().cpu() y_pred = torch.cat([pred_pos, pred_neg], dim=0).sigmoid().cpu()
y_true = torch.cat([torch.ones(pred_pos.size(0)), torch.zeros(pred_neg.size(0))], dim=0) y_true = torch.cat([torch.ones(pred_pos.size(0)), torch.zeros(pred_neg.size(0))], dim=0)
train_aps.append(average_precision_score(y_true, y_pred.detach().numpy())) train_aps.append(average_precision_score(y_true, y_pred.detach().numpy()))
#start_event = torch.cuda.Event(enable_timing=True)
#end_event = torch.cuda.Event(enable_timing=True)
#start_event.record()
if mailbox is not None: if mailbox is not None:
src = metadata['src_pos_index'] src = metadata['src_pos_index']
dst = metadata['dst_pos_index'] dst = metadata['dst_pos_index']
ts = roots.ts ts = roots.ts
if(graph.edge_attr.device == torch.device('cpu')): if graph.edge_attr is None:
edge_feats = graph.edge_attr[roots.eids.to('cpu')].to('cuda') if graph.edge_attr is not None else None edge_feats = None
elif(graph.edge_attr.device == torch.device('cpu')):
edge_feats = graph.edge_attr[roots.eids.to('cpu')].to('cuda')
else: else:
edge_feats = graph.edge_attr[roots.eids] if graph.edge_attr is not None else None edge_feats = graph.edge_attr[roots.eids]
dist_index_mapper = mfgs[0][0].srcdata['ID'] dist_index_mapper = mfgs[0][0].srcdata['ID']
root_index = torch.cat((src,dst)) root_index = torch.cat((src,dst))
last_updated_nid = model.module.memory_updater.last_updated_nid[root_index] last_updated_nid = model.module.memory_updater.last_updated_nid[root_index]
...@@ -263,8 +271,11 @@ def main(): ...@@ -263,8 +271,11 @@ def main():
src,dst,ts,edge_feats, src,dst,ts,edge_feats,
model.module.memory_updater.last_updated_memory, model.module.memory_updater.last_updated_memory,
) )
mailbox.set_mailbox_all_to_all(index,memory,memory_ts,mail,mail_ts,reduce_Op = 'max') mailbox.set_mailbox_all_to_all(index,memory,memory_ts,mail,mail_ts,reduce_Op = 'max')
#end_event.record()
#torch.cuda.synchronize()
#write_back_time += start_event.elapsed_time(end_event)/1000
torch.cuda.synchronize() torch.cuda.synchronize()
time_prep = time.time() - epoch_start_time time_prep = time.time() - epoch_start_time
...@@ -272,10 +283,14 @@ def main(): ...@@ -272,10 +283,14 @@ def main():
train_ap = float(torch.tensor(train_aps).mean()) train_ap = float(torch.tensor(train_aps).mean())
ap = 0 ap = 0
auc = 0 auc = 0
#if cache.edge_cache is not None:
# print('hit {}'.format(cache.edge_cache.hit_/ cache.edge_cache.hit_sum))
#if cache.node_cache is not None:
# print('hit {}'.format(cache.node_cache.hit_/ cache.node_cache.hit_sum))
ap, auc = eval('val') ap, auc = eval('val')
print('\ttrain loss:{:.4f} train ap:{:4f} val ap:{:4f} val auc:{:4f}'.format(total_loss,train_ap, ap, auc)) print('\ttrain loss:{:.4f} train ap:{:4f} val ap:{:4f} val auc:{:4f}'.format(total_loss,train_ap, ap, auc))
print('\ttotal time:{:.2f}s prep time:{:.2f}s'.format(time.time()-epoch_start_time, time_prep)) print('\ttotal time:{:.2f}s prep time:{:.2f}s'.format(time.time()-epoch_start_time, time_prep))
#print('\t fetch time:{:.2f}s write back time:{:.2f}s'.format(fetch_time,write_back_time))
model.eval() model.eval()
if mailbox is not None: if mailbox is not None:
mailbox.reset() mailbox.reset()
......