Commit 44d3b857 by Wenjie Huang

Merge remote-tracking branch 'origin/develop-v2'

parents 9ca0dca0 6cb3e218
install.sh
a.out
third_party
install.sh
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
......
......@@ -155,6 +155,8 @@ endif()
# add libsampler.so
set(SAMPLER_NAME "${PROJECT_NAME}_sampler")
set(BOOST_INCLUDE_DIRS "${CMAKE_SOURCE_DIR}/third_party/boost_1_83_0")
include_directories(${BOOST_INCLUDE_DIRS})
file(GLOB_RECURSE SAMPLER_SRCS "csrc/sampler/*.cpp")
add_library(${SAMPLER_NAME} SHARED ${SAMPLER_SRCS})
......
sampling:
- layer: 1
neighbor:
- 10
strategy: 'recent'
prop_time: False
history: 1
duration: 0
num_thread: 32
memory:
- type: 'node'
dim_time: 100
deliver_to: 'self'
mail_combine: 'last'
memory_update: 'gru'
mailbox_size: 1
combine_node_feature: True
dim_out: 100
gnn:
- arch: 'transformer_attention'
layer: 1
att_head: 2
dim_time: 100
dim_out: 100
train:
- epoch: 5
#batch_size: 100
# reorder: 16
lr: 0.0001
dropout: 0.2
att_dropout: 0.2
all_on_gpu: True
\ No newline at end of file
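For reference, the four top-level keys above are the four parameter dicts the training code later unpacks (sample_param, memory_param, gnn_param, train_param). A minimal sketch of that split, assuming PyYAML; the repo's own parse_config may differ:

import yaml

def parse_config_sketch(path):
    # each top-level key ('sampling', 'memory', 'gnn', 'train') holds a
    # one-element list of settings in the config above
    with open(path) as f:
        conf = yaml.safe_load(f)
    return (conf['sampling'][0], conf['memory'][0],
            conf['gnn'][0], conf['train'][0])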
......@@ -2,6 +2,7 @@
#include <sampler.h>
#include <output.h>
#include <neighbors.h>
#include <temporal_utils.h>
/*------------Python Bind--------------------------------------------------------------*/
......@@ -16,7 +17,14 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
py::return_value_policy::reference)
.def("divide_nodes_to_part",
&divide_nodes_to_part,
py::return_value_policy::reference);
py::return_value_policy::reference)
.def("sparse_get_index",
&sparse_get_index,
py::return_value_policy::reference)
.def("get_norm_temporal",
&get_norm_temporal,
py::return_value_policy::reference
);
py::class_<TemporalGraphBlock>(m, "TemporalGraphBlock")
.def(py::init<vector<NodeIDType> &, vector<NodeIDType> &,
......@@ -79,4 +87,5 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
.def("neighbor_sample_from_nodes", &ParallelSampler::neighbor_sample_from_nodes)
.def("reset", &ParallelSampler::reset)
.def("get_ret", [](const ParallelSampler &ps) { return ps.ret; });
}
\ No newline at end of file
......@@ -105,13 +105,13 @@ void ParallelSampler :: neighbor_sample_from_nodes_static_layer(th::Tensor nodes
// uniform_int_distribution<> u(0, tnb.deg[node]-1);
// while(temp_s.size()!=fanout && temp_s.size()<tnb.neighbors_set[node].size()){
for(int i=0;i<fanout;i++){
// select fanout neighbors in a loop
NodeIDType indice;
if(policy == "weighted"){//考虑边权重信
if(policy == "weighted"){//���DZ�Ȩ����Ϣ
const vector<WeightType>& ew = tnb.edge_weight[node];
indice = sample_multinomial(ew, e);
}
else if(policy == "uniform"){//均匀采样
else if(policy == "uniform"){//���Ȳ���
// indice = u(e);
indice = rand_r(&loc_seed) % (nei.size());
}
......@@ -119,7 +119,7 @@ void ParallelSampler :: neighbor_sample_from_nodes_static_layer(th::Tensor nodes
auto chosen_e_iter = edge.begin() + indice;
if(part_unique){
auto rst = temp_s.insert(*chosen_n_iter);
if(rst.second){ // not a duplicate
eid_threads[tid].emplace_back(*chosen_e_iter);
node_s_threads[tid].insert(*chosen_n_iter);
if(!tnb.neighbors_set.empty() && temp_s.size()<fanout && temp_s.size()<tnb.neighbors_set[node].size()) fanout++;
......@@ -229,7 +229,7 @@ void ParallelSampler :: neighbor_sample_from_nodes_with_before_layer(
}
}
else{
// if there are more candidate neighbor edges than the fanout, randomly select fanout neighbors
tgb_i[tid].src_index.insert(tgb_i[tid].src_index.end(), fanout, i);
uniform_int_distribution<> u(0, end_index-1);
//cout<<end_index<<endl;
......
#include <pybind11/pybind11.h>
#include <pybind11/numpy.h>
#include <pybind11/stl.h>
#pragma once
#include <torch/extension.h>
#include <parallel_hashmap/phmap.h>
#include <cstring>
......
......@@ -44,14 +44,14 @@ def load_feat(d, rand_de=0, rand_dn=0):
data_name = args.data_name
g = np.load('../tgl_main/DATA/'+data_name+'/ext_full.npz')
df = pd.read_csv('../tgl_main/DATA/'+data_name+'/edges.csv')
if os.path.exists('../tgl_main/DATA/'+data_name+'/node_features.pt'):
n_feat = torch.load('../tgl_main/DATA/'+data_name+'/node_features.pt')
g = np.load('/mnt/nfs/fzz/TGL-DATA/'+data_name+'/ext_full.npz')
df = pd.read_csv('/mnt/nfs/fzz/TGL-DATA/'+data_name+'/edges.csv')
if os.path.exists('/mnt/nfs/fzz/TGL-DATA/'+data_name+'/node_features.pt'):
n_feat = torch.load('/mnt/nfs/fzz/TGL-DATA/'+data_name+'/node_features.pt')
else:
n_feat = None
if os.path.exists('../tgl_main/DATA/'+data_name+'/edge_features.pt'):
e_feat = torch.load('../tgl_main/DATA/'+data_name+'/edge_features.pt')
if os.path.exists('/mnt/nfs/fzz/TGL-DATA/'+data_name+'/edge_features.pt'):
e_feat = torch.load('/mnt/nfs/fzz/TGL-DATA/'+data_name+'/edge_features.pt')
else:
e_feat = None
......
#include <cuda_runtime.h>
#include <cstdio>
int main()
{
int count = 0;
if (cudaSuccess != cudaGetDeviceCount(&count)) return -1;
if (count == 0) return -1;
for (int device = 0; device < count; ++device)
{
cudaDeviceProp prop;
if (cudaSuccess == cudaGetDeviceProperties(&prop, device))
std::printf("%d.%d ", prop.major, prop.minor);
}
return 0;
}
#include <cuda.h>
#include <cstdio>
int main() {
printf("%d.%d", CUDA_VERSION / 1000, (CUDA_VERSION / 10) % 100);
return 0;
}
......@@ -3,9 +3,9 @@
mkdir -p build && cd build
cmake .. \
-DCMAKE_EXPORT_COMPILE_COMMANDS=ON \
-DCMAKE_PREFIX_PATH="/home/hwj/.miniconda3/envs/sgl/lib/python3.10/site-packages" \
-DPython3_ROOT_DIR="/home/hwj/.miniconda3/envs/sgl" \
-DCUDA_TOOLKIT_ROOT_DIR="/home/hwj/.local/cuda-11.8" \
-DCMAKE_PREFIX_PATH="/home/zlj/.miniconda3/envs/dgnn/lib/python3.8/site-packages" \
-DPython3_ROOT_DIR="/home/zlj/.miniconda3/envs/dgnn" \
-DCUDA_TOOLKIT_ROOT_DIR="/home/zlj/local/cuda-12.2" \
&& make -j32 \
&& rm -rf ../starrygl/lib \
&& mkdir ../starrygl/lib \
......
import torch
from torch import Tensor
......@@ -23,6 +31,9 @@ __all__ = [
Strings = Sequence[str]
OptStrings = Optional[Strings]
Strings = Sequence[str]
OptStrings = Optional[Strings]
class GraphData:
def __init__(self,
edge_indices: Union[Tensor, Dict[Tuple[str, str, str], Tensor]],
......@@ -178,6 +189,13 @@ class GraphData:
algorithm: str = "metis",
) -> 'GraphData':
p = Path(root).expanduser().resolve() / f"{algorithm}_{num_parts}" / f"{part_id:03d}"
def load_partition(
root: str,
part_id: int,
num_parts: int,
algorithm: str = "metis",
) -> 'GraphData':
p = Path(root).expanduser().resolve() / f"{algorithm}_{num_parts}" / f"{part_id:03d}"
return torch.load(p.__str__())
def save_partition(self,
......@@ -307,7 +325,6 @@ class GraphData:
logging.info(f"saving partition data: {i+1}/{num_parts}")
torch.save(g, (base_path / f"{i:03d}").__str__())
class MetaData:
def __init__(self) -> None:
......
......@@ -133,7 +133,6 @@ def all_to_all_s(
output_sizes = [t-s for s, t in zip(output_rowptr, output_rowptr[1:])]
input_sizes = [t-s for s, t in zip(input_rowptr, input_rowptr[1:])]
return dist.all_to_all_single(
output=output_tensor,
input=input_tensor,
......
......@@ -58,7 +58,6 @@ class DistributedContext:
master_port = int(os.environ["MASTER_PORT"])
ccl_init_url = f"tcp://{master_addr}:{master_port}"
rpc_init_url = f"tcp://{master_addr}:{master_port + 1}"
ctx = DistributedContext(
backend=backend,
ccl_init_method=ccl_init_url,
......@@ -69,6 +68,7 @@ class DistributedContext:
use_gpu=use_gpu,
rpc_gpu=rpc_gpu,
)
_set_default_dist_context(ctx)
return ctx
......
......@@ -16,7 +16,11 @@ class TensorAccessor:
self._data = data
self._ctx = DistributedContext.get_default_context()
self._rref = rpc.RRef(data)
if self._ctx._use_rpc is True:
self._rref = rpc.RRef(data)
else:
self._rref = None
self.stream = torch.cuda.Stream()
@property
def data(self):
......@@ -29,6 +33,12 @@ class TensorAccessor:
@property
def ctx(self):
return self._ctx
@staticmethod
@rpc.functions.async_execution
def _index_select(self):
fut = torch.futures.Future()
fut.set_result(None)
return fut
def all_gather_rrefs(self) -> List[rpc.RRef]:
return self.ctx.all_gather_remote_objects(self.rref)
......@@ -134,18 +144,29 @@ class DistIndex:
class DistributedTensor:
def __init__(self, data: Tensor) -> None:
self.accessor = TensorAccessor(data)
self.rrefs = self.accessor.all_gather_rrefs()
local_sizes = []
for rref in self.rrefs:
n = self.ctx.remote_call(Tensor.size, rref, dim=0).wait()
local_sizes.append(n)
self._num_nodes: int = sum(local_sizes)
self._num_part_nodes: Tuple[int,...] = tuple(int(s) for s in local_sizes)
if self.accessor.rref is not None:
self.rrefs = self.accessor.all_gather_rrefs()
local_sizes = []
for rref in self.rrefs:
n = self.ctx.remote_call(Tensor.size, rref, dim=0).wait()
local_sizes.append(n)
self._num_nodes: int = sum(local_sizes)
self._num_part_nodes: Tuple[int,...] = tuple(int(s) for s in local_sizes)
else:
self.rrefs = None
self._num_nodes: int = dist.get_world_size()
self._num_part_nodes:List = [torch.tensor(data.size(0),device = data.device) for _ in range(self._num_nodes)]
dist.all_gather(self._num_part_nodes,torch.tensor(data.size(0),device = data.device))
self._num_nodes = sum(self._num_part_nodes)
self._part_id: int = self.accessor.ctx.rank
self._num_parts: int = self.accessor.ctx.world_size
@property
def shape(self):
return self.accessor.data.shape
@property
def dtype(self):
return self.accessor.data.dtype
......@@ -180,10 +201,10 @@ class DistributedTensor:
def ctx(self):
return self.accessor.ctx
def all_to_all_ind2ptr(self, dist_index: Union[Tensor, DistIndex]) -> Dict[str, Union[List[int], Tensor]]:
def all_to_all_ind2ptr(self, dist_index: Union[Tensor, DistIndex],group = None) -> Dict[str, Union[List[int], Tensor]]:
if isinstance(dist_index, Tensor):
dist_index = DistIndex(dist_index)
send_ptr = torch.ops.torch_sparse.ind2ptr(dist_index.part, self.num_parts())
send_ptr = torch.ops.torch_sparse.ind2ptr(dist_index.part, self.num_parts)
send_sizes = send_ptr[1:] - send_ptr[:-1]
recv_sizes = torch.empty_like(send_sizes)
......@@ -195,8 +216,8 @@ class DistributedTensor:
send_ptr = send_ptr.tolist()
recv_ptr = recv_ptr.tolist()
recv_ind = torch.full((recv_ptr[-1],), (2**62-1)*2+1, dtype=dist_index.dtype, device=dist_index.device)
all_to_all_s(recv_ind, dist_index.loc, send_ptr, recv_ptr)
recv_ind = torch.full((recv_ptr[-1],), (2**62-1)*2+1, dtype=dist_index.dtype, device=self.device)
all_to_all_s(recv_ind, dist_index.loc, recv_ptr, send_ptr,group=group)
return {
"send_ptr": send_ptr,
......@@ -209,6 +230,7 @@ class DistributedTensor:
send_ptr: Optional[List[int]] = None,
recv_ptr: Optional[List[int]] = None,
recv_ind: Optional[List[int]] = None,
group = None
) -> Tensor:
if dist_index is not None:
dist_dict = self.all_to_all_ind2ptr(dist_index)
......@@ -217,8 +239,8 @@ class DistributedTensor:
recv_ind = dist_dict["recv_ind"]
data = self.accessor.data[recv_ind]
recv = torch.empty(send_ptr[-1], *data.shape[1:], dtype=data.dtype, device=data.device)
all_to_all_s(recv, data, send_ptr, recv_ptr)
recv = torch.empty(send_ptr[-1], *data.shape[1:], dtype=data.dtype, device=self.device)
all_to_all_s(recv, data, send_ptr, recv_ptr,group=group)
return recv
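The send_ptr/recv_ptr bookkeeping above is the rowptr form produced by torch_sparse's ind2ptr: consecutive differences give per-partition counts. A standalone toy of that conversion, assuming the index is already sorted by owning partition (pure torch, names illustrative):

import torch

def ind2ptr_sketch(part: torch.Tensor, num_parts: int) -> torch.Tensor:
    # rowptr whose consecutive differences are the per-partition counts
    counts = torch.bincount(part, minlength=num_parts)
    ptr = torch.zeros(num_parts + 1, dtype=torch.long)
    ptr[1:] = counts.cumsum(0)
    return ptr

part = torch.tensor([0, 0, 1, 2, 2, 2])  # owner of each requested index
print(ind2ptr_sketch(part, 3))           # tensor([0, 2, 3, 6])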
def all_to_all_set(self,
......@@ -227,6 +249,7 @@ class DistributedTensor:
send_ptr: Optional[List[int]] = None,
recv_ptr: Optional[List[int]] = None,
recv_ind: Optional[List[int]] = None,
group = None
):
if dist_index is not None:
dist_dict = self.all_to_all_ind2ptr(dist_index)
......@@ -235,7 +258,7 @@ class DistributedTensor:
recv_ind = dist_dict["recv_ind"]
recv = torch.empty(recv_ptr[-1], *data.shape[1:], dtype=data.dtype, device=data.device)
all_to_all_s(recv, data, recv_ptr, send_ptr)
all_to_all_s(recv, data, recv_ptr, send_ptr,group=group)
self.accessor.data.index_copy_(0, recv_ind, recv)
def index_select(self, dist_index: Union[Tensor, DistIndex]):
......
......@@ -32,8 +32,8 @@ class EdgePredictor(torch.nn.Module):
def forward(self, h, neg_samples=1):
num_edge = h.shape[0] // (neg_samples + 2)
h_src = self.src_fc(h[num_edge:2 * num_edge])#self.src_fc(h[:num_edge])
h_pos_dst = self.dst_fc(h[:num_edge]) #
h_src = self.src_fc(h[:num_edge])
h_pos_dst = self.dst_fc(h[num_edge:num_edge*2]) #
h_neg_src = self.src_fc(h[2 * num_edge:])
h_pos_edge = torch.nn.functional.relu(h_src + h_pos_dst)
h_neg_edge = torch.nn.functional.relu(h_neg_src + h_pos_dst.tile(neg_samples, 1))
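The fix above swaps the src/pos-dst slices back into order: the batch stacks num_edge source rows, then num_edge positive-destination rows, then neg_samples*num_edge negative rows. A toy of that layout (shapes hypothetical):

import torch

neg_samples, num_edge, dim = 1, 4, 8
h = torch.randn((neg_samples + 2) * num_edge, dim)
h_src     = h[:num_edge]               # source embeddings
h_pos_dst = h[num_edge:2 * num_edge]   # positive destinations
h_neg     = h[2 * num_edge:]           # negatives, paired with tiled h_pos_dst
assert h_neg.shape[0] == neg_samples * num_edge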
......
......@@ -21,7 +21,7 @@ class GeneralModel(torch.nn.Module):
self.train_param = train_param
if memory_param['type'] == 'node':
if memory_param['memory_update'] == 'gru':
self.memory_updater = RNNMemeoryUpdater(memory_param, 2 * memory_param['dim_out'] + dim_edge, memory_param['dim_out'], memory_param['dim_time'], dim_node)
self.memory_updater = GRUMemeoryUpdater(memory_param, 2 * memory_param['dim_out'] + dim_edge, memory_param['dim_out'], memory_param['dim_time'], dim_node)
elif memory_param['memory_update'] == 'rnn':
self.memory_updater = RNNMemeoryUpdater(memory_param, 2 * memory_param['dim_out'] + dim_edge, memory_param['dim_out'], memory_param['dim_time'], dim_node)
elif memory_param['memory_update'] == 'transformer':
......
......@@ -3,10 +3,13 @@ import torch
import torch.distributed as dist
from starrygl.distributed.utils import DistributedTensor
from starrygl.module.memorys import MailBox
from starrygl.sample.cache.fetch_cache import FetchFeatureCache
from starrygl.sample.graph_core import DataSet
from starrygl.sample.graph_core import GraphData
from starrygl.sample.graph_core import DistributedGraphStore
from starrygl.sample.sample_core.base import BaseSampler, NegativeSampling
import dgl
from starrygl.sample.stream_manager import PipelineManager, getPipelineManger
"""
Input parameters unchanged; the return value becomes:
sample_from_nodes
......@@ -42,8 +45,7 @@ def prepare_input(node_feat, edge_feat, mem_embedding,mfgs,dist_nid,dist_eid):
#print(idx.shape[0],b.srcdata['mem_ts'].shape)
return mfgs
def to_block(graph: GraphData, data, sample_out, mailbox:MailBox = None,device = torch.device('cuda')):
def to_block(graph: DistributedGraphStore, data, sample_out, mailbox:MailBox = None,device = torch.device('cuda'),group = None):
if len(sample_out) > 1:
sample_out,metadata = sample_out
else:
......@@ -75,27 +77,39 @@ def to_block(graph: GraphData, data, sample_out, mailbox:MailBox = None,device =
nid_tensor = torch.cat([root_node,src_node],dim = 0)
dist_nid = nid_mapper[nid_tensor].to(device)
dist_nid,nid_inv = dist_nid.unique(return_inverse = True)
if isinstance(graph.edge_attr,DistributedTensor):
local_index, input_split,output_split = graph.edge_attr.gather_select_index(dist_eid)
edge_feat = graph.edge_attr.scatter_data(local_index,input_split=input_split,out_split=output_split)
else:
edge_feat = graph._get_edge_attr(dist_eid)
local_index = None
if isinstance(graph.x,DistributedTensor):
local_index, input_split,output_split = graph.x.gather_select_index(dist_nid)
node_feat = graph.x.scatter_data(local_index,input_split=input_split,out_split=output_split)
else:
node_feat = graph._get_node_attr(dist_nid)
if mailbox is not None:
if torch.distributed.get_world_size() > 1:
if node_feat is None:
local_index, input_split,output_split = mailbox.node_memory.gather_select_index(dist_nid)
mem = mailbox.gather_memory(local_index,input_split,output_split)
fetchCache = FetchFeatureCache.getFetchCache()
if fetchCache is None:
if isinstance(graph.edge_attr,DistributedTensor):
ind_dict = graph.edge_attr.all_to_all_ind2ptr(dist_eid,group = group)
edge_feat = graph.edge_attr.all_to_all_get(group = group,**ind_dict)
else:
mem = mailbox.get_memory(dist_nid)
edge_feat = graph._get_edge_attr(dist_eid)
ind_dict = None
if isinstance(graph.x,DistributedTensor):
ind_dict = graph.x.all_to_all_ind2ptr(dist_nid,group = group)
node_feat = graph.x.all_to_all_get(group = group,**ind_dict)
else:
node_feat = graph._get_node_attr(dist_nid)
if mailbox is not None:
if torch.distributed.get_world_size() > 1:
if node_feat is None:
ind_dict = mailbox.node_memory.all_to_all_ind2ptr(dist_nid,group = group)
mem = mailbox.gather_memory(**ind_dict)
else:
mem = mailbox.get_memory(dist_nid)
else:
mem = None
else:
mem = None
raw_nid = torch.empty_like(dist_nid)
raw_eid = torch.empty_like(dist_eid)
nid_tensor = nid_tensor.to(device)
eid_tensor = eid_tensor.to(device)
raw_nid[nid_inv] = nid_tensor
raw_eid[eid_inv] = eid_tensor
node_feat,edge_feat,mem = fetchCache.fetch_feature(raw_nid,
dist_nid,raw_eid,
dist_eid)
def build_block():
mfgs = list()
col = torch.arange(0,root_len,device = device)
......@@ -110,23 +124,12 @@ def to_block(graph: GraphData, data, sample_out, mailbox:MailBox = None,device =
device = device)
idx = nid_inv[0:row_len + elen]
e_idx = eid_inv[col_len:col_len+elen]
b.srcdata['ID'] = idx#dist_nid[idx]
b.srcdata['ID'] = idx
if sample_out[r].delta_ts().shape[0] > 0:
b.edata['dt'] = sample_out[r].delta_ts().to(device)
if src_ts is not None:
b.srcdata['ts'] = src_ts[0:row_len + eid_len[r]]
b.edata['ID'] = e_idx#dist_eid[e_idx]
#if edge_feat is not None:
# b.edata['f'] = edge_feat[e_idx]
#if r == len(eid_len)-1:
# if node_feat is not None:
# b.srcdata['h'] = node_feat[idx]
# if mem_embedding is not None:
# node_memory,node_memory_ts,mailbox,mailbox_ts = mem_embedding
# b.srcdata['mem'] = node_memory[idx]
# b.srcdata['mem_ts'] = node_memory_ts[idx]
# b.srcdata['mem_input'] = mailbox[idx].cuda().reshape(b.srcdata['ID'].shape[0], -1)
# b.srcdata['mail_ts'] = mailbox_ts[idx]
b.edata['ID'] = e_idx
col = row
col_len += eid_len[r]
row_len += eid_len[r]
......@@ -134,39 +137,28 @@ def to_block(graph: GraphData, data, sample_out, mailbox:MailBox = None,device =
mfgs = list(map(list, zip(*[iter(mfgs)])))
mfgs.reverse()
return data,mfgs,metadata
data,mfgs,metadata = build_block()
#if dist.get_world_size() > 1:
# if(node_feat is None):
# node_feat = torch.futures.Future()
# node_feat.set_result(None)
# if(edge_feat is None):
# edge_feat = torch.futures.Future()
# edge_feat.set_result(None)
# if(mem is None):
# mem = torch.futures.Future()
# mem.set_result(None)
# def callback(fs,mfgs,dist_nid,dist_eid):
# node_feat,edge_feat,mem_embedding = fs.value()
# node_feat = node_feat.value()
# edge_feat = edge_feat.value()
# mem_embedding = mem_embedding.value()
# return prepare_input(node_feat,edge_feat,mem_embedding,mfgs,dist_nid,dist_eid)
# cal = lambda fut: callback(fs=fut,mfgs = mfgs,dist_nid = dist_nid,dist_eid =dist_eid)
# return data,torch.futures.collect_all([node_feat,edge_feat,mem]).then(cal),metadata
#else:
mfgs = prepare_input(node_feat,edge_feat,mem,mfgs,dist_nid,dist_eid)
#return build_block(node_feat,edge_feat,mem)#data,mfgs,metadata
return data,mfgs,metadata
return (data,mfgs,metadata)
def graph_sample(graph, sampler:BaseSampler,
sample_fn, data,
neg_sampling = None,
mailbox = None,
device = torch.device('cuda')):
device = torch.device('cuda'),
async_op = False):
out = sample_fn(sampler,data,neg_sampling)
return to_block(graph,data,out,mailbox,device)
if async_op == False:
return to_block(graph,data,out,mailbox,device)
else:
manger = getPipelineManger()
future = manger.submit('lookup',to_block,{'graph':graph,'data':data,\
'sample_out':out,\
'mailbox':mailbox,\
'device':device})
return future
def sample_from_nodes(sampler:BaseSampler, data:DataSet, **kwargs):
out = sampler.sample_from_nodes(nodes=data.nodes.reshape(-1))
......
from typing import Optional, Sequence, Union
import torch
from starrygl.distributed.utils import DistributedTensor
from starrygl.sample.cache.cache import Cache
class LRUCache(Cache):
"""
Least-recently-used (LRU) cache
"""
def __init__(self, cache_ratio: int,
num_cache:int,
cache_data: Sequence[DistributedTensor],
use_local:bool = False,
pinned_buffers_shape: Sequence[torch.Size] = None,
is_update_cache = False
):
super(LRUCache, self).__init__(cache_ratio,num_cache,
cache_data,use_local,
pinned_buffers_shape,
is_update_cache)
self.name = 'lru'
self.now_cache_count = 0
self.cache_count = torch.zeros(
self.capacity, dtype=torch.int32, device=torch.device('cuda'))
self.is_update_cache = True
def update_cache(self, cached_index: torch.Tensor,
uncached_index: torch.Tensor,
uncached_feature: Sequence[torch.Tensor]):
if len(uncached_index) > self.capacity:
num_to_cache = self.capacity
else:
num_to_cache = len(uncached_index)
node_id_to_cache = uncached_index[:num_to_cache].to(torch.int32)
self.now_cache_count -= 1
self.cache_count[cached_index] = 0
# get the k node id with the least water level
removing_cache_index = torch.topk(
self.cache_count, k=num_to_cache, largest=False).indices.to(torch.int32)
removing_node_id = self.cache_index_to_id[removing_cache_index]
# update cache attributes
for buffer,data in zip(self.buffers,uncached_feature):
buffer[removing_cache_index] = data[:num_to_cache].reshape(-1,*buffer.shape[1:])
self.cache_count[removing_cache_index] = 0
self.cache_validate[removing_node_id] = False
self.cache_validate[node_id_to_cache] = True
self.cache_map[removing_node_id] = -1
self.cache_map[node_id_to_cache] = removing_cache_index
self.cache_index_to_id[removing_cache_index] = node_id_to_cache
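The eviction rule above picks the k slots with the smallest counters via topk(largest=False) and overwrites them. A standalone toy of just that step (values illustrative, not the class API):

import torch

cache_count = torch.tensor([5, 0, 3, 1])  # per-slot use counters
new_ids = torch.tensor([42, 99])          # ids to bring into the cache
k = new_ids.numel()
evict_slots = torch.topk(cache_count, k=k, largest=False).indices
print(evict_slots)  # tensor([1, 3]): the coldest slots get overwritten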
from typing import Callable, List, Optional, Sequence, Union
import numpy as np
import torch
from starrygl.distributed.utils import DistributedTensor
class Cache:
def __init__(self, cache_ratio: int,
num_cache:int,
cache_data: Sequence[DistributedTensor],
use_local:bool = False,
pinned_buffers_shape: Sequence[torch.Size] = None,
is_update_cache = False
):
print(len(cache_data),cache_data)
assert torch.cuda.is_available() == True
self.use_local = use_local
self.is_update_cache = is_update_cache
self.device = torch.device('cuda')
self.use_remote = torch.distributed.get_world_size()>1
assert not (self.use_local is False and self.use_remote is False),\
"the data is on the cuda and no need remote cache"
self.cache_ratio = cache_ratio
self.num_cache = num_cache
self.capacity = int(self.num_cache * cache_ratio)
self.update_stream = torch.cuda.Stream()
self.buffers = []
self.pinned_buffers = []
for data in cache_data:
self.buffers.append(
torch.zeros(self.capacity,*data.shape[1:],
dtype = data.dtype,device = torch.device('cuda'))
)
self.cache_validate = torch.zeros(
num_cache, dtype=torch.bool, device=self.device)
# maps node id -> index
self.cache_map = torch.zeros(
num_cache, dtype=torch.int32, device=self.device) - 1
# maps index -> node id
self.cache_index_to_id = torch.zeros(
num_cache,dtype=torch.int32, device=self.device) -1
self.hit_sum = 0
self.hit_ = 0
def init_cache(self,ind:torch.Tensor,data:Sequence[torch.Tensor]):
pos = torch.arange(ind.shape[0],device = 'cuda',dtype = ind.dtype)
self.cache_map[ind] = pos.to(torch.int32).to('cuda')
self.cache_index_to_id[pos] = ind.to(torch.int32).to('cuda')
for data,buffer in zip(data,self.buffers):
buffer[:ind.shape[0],] = data
self.cache_validate[ind] = True
def update_cache(self, cached_index: torch.Tensor,
uncached_index: torch.Tensor,
uncached_data: Sequence[torch.Tensor]):
raise NotImplementedError
def fetch_data(self,ind:Optional[torch.Tensor] = None,
uncached_source_fn: Callable = None, source_index:torch.Tensor = None):
self.hit_sum += ind.shape[0]
assert isinstance(ind, torch.Tensor)
cache_mask = self.cache_validate[ind]
uncached_mask = ~cache_mask
self.hit_ += torch.sum(cache_mask)
cached_data = []
cached_index = self.cache_map[ind[cache_mask]]
if uncached_mask.sum() > 0:
uncached_id = ind[uncached_mask]
source_index = source_index[uncached_mask]
uncached_feature = uncached_source_fn(source_index)
if isinstance(uncached_feature,torch.Tensor):
uncached_feature = [uncached_feature]
else:
uncached_id = None
uncached_feature = [None for _ in range(len(self.buffers))]
for data,uncached_data in zip(self.buffers,uncached_feature):
nfeature = torch.zeros(
len(ind), *data.shape[1:], dtype=data.dtype,device=self.device)
nfeature[cache_mask,:] = data[cached_index]
if uncached_id is not None:
nfeature[uncached_mask] = uncached_data.reshape(-1,*data.shape[1:])
cached_data.append(nfeature)
if self.is_update_cache and uncached_mask.sum() > 0:
self.update_cache(cached_index=cached_index,
uncached_index=uncached_id,
uncached_feature=uncached_feature)
        return cached_data  # list of fetched buffers; callers index [0] or unpack four entries
def invalidate(self,ind):
self.cache_validate[ind] = False
from typing import Optional
from starrygl.distributed.context import DistributedContext
from starrygl.distributed.utils import DistIndex, DistributedTensor
from starrygl.sample.cache import LRU_cache
from starrygl.sample.cache.cache import Cache
from starrygl.sample.cache.static_cache import StaticCache
from starrygl.sample.cache.utils import pre_sample
from starrygl.sample.graph_core import DistributedGraphStore
from starrygl.sample.memory.shared_mailbox import SharedMailBox
import torch
_FetchCache = None
class FetchFeatureCache:
@staticmethod
def create_fetch_cache(num_nodes: int, num_edges: int,
edge_cache_ratio: int, node_cache_ratio: int,
graph: DistributedGraphStore,
mailbox:SharedMailBox = None,
policy = 'lru'):
global _FetchCache
_FetchCache = FetchFeatureCache(num_nodes, num_edges,
edge_cache_ratio, node_cache_ratio,
graph,mailbox,policy)
@staticmethod
def getFetchCache():
global _FetchCache
return _FetchCache
def __init__(self, num_nodes: int, num_edges: int,
edge_cache_ratio: int, node_cache_ratio: int,
graph: DistributedGraphStore,
mailbox:SharedMailBox = None,
policy = 'lru'
):
if policy == 'lru':
init_fn = LRU_cache.LRUCache
elif policy == 'static':
init_fn = StaticCache
self.ctx = DistributedContext.get_default_context()
if graph.x is not None:
self.node_cache:Cache = init_fn(node_cache_ratio,num_nodes,
[graph.x],use_local=graph.uvm_node)
else:
self.node_cache = None
if graph.edge_attr is not None:
self.edge_cache:Cache = init_fn(edge_cache_ratio,num_edges,
[graph.edge_attr],use_local = graph.uvm_edge)
else:
self.edge_cache = None
if mailbox is not None:
self.mailbox_cache:Cache = init_fn(node_cache_ratio,num_nodes,
[mailbox.node_memory,
mailbox.node_memory_ts.accessor.data.reshape(-1,1),
mailbox.mailbox,
mailbox.mailbox_ts],
use_local = mailbox.uvm)
else:
self.mailbox_cache = None
self.graph = graph
self.mailbox = mailbox
        global _FetchCache
        _FetchCache = self
def fetch_feature(self, nid: Optional[torch.Tensor] = None, dist_nid = None,
eid: Optional[torch.Tensor] = None, dist_eid = None
):
nfeat = None
mem = None
efeat = None
if self.node_cache is not None and nid is not None:
nfeat = torch.zeros(nid.shape[0],
self.node_cache.buffers[0].shape[1],
dtype = self.node_cache.buffers[0].dtype,
device = torch.device('cuda')
)
if self.node_cache.use_local is False:
local_mask = (DistIndex(dist_nid).part == torch.distributed.get_rank())
local_id = dist_nid[local_mask]
nfeat[local_mask] = self.graph.x.accessor.data[DistIndex(local_id).loc]
remote_mask = ~local_mask
if remote_mask.sum() > 0:
remote_id = nid[remote_mask]
source_id = dist_nid[remote_mask]
nfeat[remote_mask] = self.node_cache.fetch_data(remote_id,\
self.graph._get_node_attr,source_id)[0]
else:
nfeat = self.node_cache.fetch_data(nid,
self.graph._get_node_attr,dist_nid)[0]
if self.mailbox_cache is not None and nid is not None:
memory = torch.zeros(nid.shape[0],
self.mailbox_cache.buffers[0].shape[1],
dtype = self.mailbox_cache.buffers[0].dtype,
device = torch.device('cuda')
)
memory_ts = torch.zeros(nid.shape[0],
dtype = self.mailbox_cache.buffers[1].dtype,
device = torch.device('cuda')
)
mailbox = torch.zeros(nid.shape[0],
*self.mailbox_cache.buffers[2].shape[1:],
dtype = self.mailbox_cache.buffers[2].dtype,
device = torch.device('cuda')
)
mailbox_ts = torch.zeros(nid.shape[0],
*self.mailbox_cache.buffers[3].shape[1:],
dtype = self.mailbox_cache.buffers[3].dtype,
device = torch.device('cuda')
)
if self.mailbox_cache.use_local is False:
if self.node_cache is None:
local_mask = (DistIndex(dist_nid).part == torch.distributed.get_rank())
local_id = dist_nid[local_mask]
remote_mask = ~local_mask
remote_id = nid[remote_mask]
source_id = dist_nid[remote_mask]
mem = self.mailbox.gather_memory(local_id)
memory[local_mask],memory_ts[local_mask],mailbox[local_mask],mailbox_ts[local_mask]= mem
if remote_mask.sum() > 0:
mem = self.mailbox_cache.fetch_data(remote_id,\
self.mailbox.gather_memory,source_id)
memory[remote_mask] = mem[0]
memory_ts[remote_mask] = mem[1].reshape(-1)
mailbox[remote_mask] = mem[2]
mailbox_ts[remote_mask] = mem[3]
mem = memory,memory_ts,mailbox,mailbox_ts
else:
mem = self.mailbox_cache.fetch_data(nid,self.mailbox.gather_memory,dist_nid)
if self.edge_cache is not None and eid is not None:
efeat = torch.zeros(eid.shape[0],
self.edge_cache.buffers[0].shape[1],
dtype = self.edge_cache.buffers[0].dtype,
device = torch.device('cuda')
)
if self.edge_cache.use_local is False:
local_mask = (DistIndex(dist_eid).part == torch.distributed.get_rank())
local_id = dist_eid[local_mask]
efeat[local_mask] = self.graph.edge_attr.accessor.data[DistIndex(local_id).loc]
remote_mask = ~local_mask
if remote_mask.sum() > 0:
remote_id = eid[remote_mask]
source_id = dist_eid[remote_mask]
efeat[remote_mask] = self.edge_cache.fetch_data(remote_id,\
self.graph._get_edge_attr,source_id)[0]
else:
efeat = self.edge_cache.fetch_data(eid,
self.graph._get_edge_attr,dist_eid)[0]
return nfeat,efeat,mem
def init_cache_with_presample(self,dataloader, num_epoch:int = 10):
node_size = self.node_cache.capacity if self.node_cache is not None else 0
edge_size = self.edge_cache.capacity if self.edge_cache is not None else 0
node_counts,edge_counts = pre_sample(dataloader=dataloader,
num_epoch=num_epoch,
node_size = node_size,
edge_size = edge_size)
if node_size != 0:
if self.node_cache.use_local is False:
dist_mask = DistIndex(self.graph.nids_mapper).part == torch.distributed.get_rank()
dist_mask = ~dist_mask
node_counts = node_counts[dist_mask]
_,nid = node_counts.topk(node_size)
if self.node_cache.use_local is False:
nid = dist_mask.nonzero()[nid]
dist_nid = self.graph.nids_mapper[nid].unique()
node_feature = self.graph._get_node_attr(dist_nid.to(self.graph.x.device))
_nid = nid.reshape(-1)
self.node_cache.init_cache(_nid,node_feature)
print('finish node init')
if edge_size != 0:
if self.edge_cache.use_local is False:
dist_mask = DistIndex(self.graph.eids_mapper).part == torch.distributed.get_rank()
dist_mask = ~dist_mask
edge_counts = edge_counts[dist_mask]
_,eid = edge_counts.topk(edge_size)
if self.edge_cache.use_local is False:
eid_ = dist_mask.nonzero()[eid]
else:
eid_ = eid
dist_eid = self.graph.eids_mapper[eid_].unique()
edge_feature = self.graph._get_edge_attr(dist_eid.to(self.graph.edge_attr.device))
eid_ = eid_.reshape(-1)
self.edge_cache.init_cache(eid_,edge_feature)
print('finish edge init')
\ No newline at end of file
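fetch_feature above repeatedly splits each request into locally owned rows (read directly from the accessor) and remote rows (routed through the cache). A toy of that mask split, assuming a DistIndex-style part field (values illustrative):

import torch

rank = 0
part = torch.tensor([0, 1, 0, 2])      # owning rank of each requested row
local_mask = part == rank
remote_mask = ~local_mask
print(local_mask.nonzero().view(-1))   # rows read from local storage
print(remote_mask.nonzero().view(-1))  # rows served via the cache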
from typing import Optional, Sequence, Union
import torch
from starrygl.distributed.utils import DistributedTensor
from starrygl.sample.cache.cache import Cache
class StaticCache(Cache):
def __init__(self, cache_ratio: int,
num_cache:int,
cache_data: Sequence[DistributedTensor],
use_local:bool = False,
pinned_buffers_shape: Sequence[torch.Size] = None,
is_update_cache = False
):
super(StaticCache, self).__init__(cache_ratio,num_cache,
cache_data,use_local,
pinned_buffers_shape,
is_update_cache)
self.name = 'static'
self.now_cache_count = 0
self.cache_count = torch.zeros(
self.capacity, dtype=torch.int32, device=torch.device('cuda'))
self.is_update_cache = False
import torch
# do not add an import for dataloader here
def pre_sample(dataloader, num_epoch:int,node_size:int,edge_size:int):
nodes_counts = torch.zeros(dataloader.graph.num_nodes,dtype = torch.long)
edges_counts = torch.zeros(dataloader.graph.num_edges,dtype = torch.long)
print(nodes_counts.shape,edges_counts.shape)
sampler = dataloader.sampler
neg_sampling = dataloader.neg_sampler
sample_fn = dataloader.sampler_fn
graph = dataloader.graph
for _ in range(num_epoch):
dataloader.__iter__()
while dataloader.recv_idxs < dataloader.expected_idx:
dataloader.recv_idxs += 1
data = dataloader._next_data()
out = sample_fn(sampler,data,neg_sampling)
if(len(out)>1):
sample_out,metadata = out
else:
sample_out = out
eid = [ret.eid() for ret in sample_out]
eid_tensor = torch.cat(eid,dim = 0)
src_node = graph.sample_graph['edge_index'][0,eid_tensor*2].to(graph.nids_mapper.device)
dst_node = graph.sample_graph['edge_index'][1,eid_tensor*2].to(graph.nids_mapper.device)
eid_tensor = torch.unique(eid_tensor)
nid_tensor = torch.unique(torch.cat((src_node,dst_node)))
edges_counts[eid_tensor] += 1
nodes_counts[nid_tensor] += 1
return nodes_counts,edges_counts
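pre_sample above warms the cache by counting, over a few dry-run epochs, how often each node/edge id is touched, then the caller keeps the top-k hottest ids. A minimal standalone toy of that counting idea (fake sampler output, names hypothetical):

import torch

num_items, cache_size = 10, 3
counts = torch.zeros(num_items, dtype=torch.long)
for batch in [[1, 2, 2, 7], [2, 7, 7, 9]]:   # fake per-batch sampled ids
    counts[torch.tensor(batch).unique()] += 1
_, hot = counts.topk(cache_size)
print(hot)  # the ids most worth pre-loading into the cache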
......@@ -42,6 +42,7 @@ class DistributedDataLoader:
train = False,
queue_size = 10,
mailbox = None,
is_pipeline = False,
**kwargs
):
assert sampler is not None
......@@ -63,11 +64,13 @@ class DistributedDataLoader:
self.dataset = dataset
self.mailbox = mailbox
self.device = device
self.is_pipeline = is_pipeline
if train is True:
self._get_expected_idx(self.dataset.len)
else:
self._get_expected_idx(self.dataset.len,op = dist.ReduceOp.MAX)
#self.expected_idx = int(math.ceil(self.dataset.len/self.batch_size))
torch.distributed.barrier()
def __iter__(self):
if self.chunk_size is None:
......@@ -134,7 +137,7 @@ class DistributedDataLoader:
return next_data
def __next__(self):
if(dist.get_world_size() > 0):
if self.is_pipeline is False:
if self.recv_idxs < self.expected_idx:
data = self._next_data()
batch_data = graph_sample(self.graph,
......@@ -148,47 +151,36 @@ class DistributedDataLoader:
return batch_data
else :
raise StopIteration
else :
num_reqs = min(self.queue_size - self.num_pending,self.expected_idx - self.submitted)
for _ in range(num_reqs):
if len(self.result_queue)>0:
result = self.result_queue[0]
if result[1].done() == True:
self.recv_idxs += 1
self.num_pending -= 1
self.result_queue.popleft()
return result[0],result[1].value(),result[2]
else:
next_data = self._next_data()
assert next_data is not None
batch_data = graph_sample(self.graph,
else:
if self.recv_idxs == 0:
data = self._next_data()
batch_data = graph_sample(self.graph,
self.sampler,
self.sampler_fn,
next_data,self.neg_sampler,
data,self.neg_sampler,
self.mailbox,
self.device)
batch_data[1].wait()
self.submitted = self.submitted + 1
self.num_pending = self.num_pending + 1
self.recv_idxs += 1
self.num_pending -= 1
return batch_data[0],batch_data[1].value(),batch_data[2]
self.result_queue.append(batch_data)
self.submitted = self.submitted + 1
self.num_pending = self.num_pending + 1
while(self.recv_idxs < self.expected_idx):
assert len(self.result_queue) > 0
result= self.result_queue[0]
if result[1].done() == True:
self.recv_idxs += 1
self.num_pending -= 1
self.recv_idxs += 1
else:
if(self.recv_idxs < self.expected_idx):
assert len(self.result_queue) > 0
result= self.result_queue[0]
self.result_queue.popleft()
return result[0],result[1].value(),result[2]
assert self.num_pending == 0
raise StopIteration
batch_data = result.result()
self.recv_idxs += 1
else:
raise StopIteration
if(self.recv_idxs+1<=self.expected_idx):
data = self._next_data()
next_batch = graph_sample(self.graph,
self.sampler,
self.sampler_fn,
data,self.neg_sampler,
self.mailbox,
self.device,
async_op=True)
self.result_queue.append(next_batch)
return batch_data
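The pipelined __next__ above keeps one batch in flight: it returns the batch whose future is done and immediately submits sampling for the next one. A minimal standalone toy of that one-deep prefetch pattern (names hypothetical):

from concurrent.futures import ThreadPoolExecutor

def make_batch(i):
    return f"batch-{i}"  # stand-in for sample + to_block

pool = ThreadPoolExecutor(max_workers=1)
pending = pool.submit(make_batch, 0)      # kick off the first batch
for i in range(1, 4):
    current = pending.result()            # wait for the batch prepared earlier
    pending = pool.submit(make_batch, i)  # overlap preparing the next one
    print(current)
print(pending.result())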
......
from starrygl.distributed.context import DistributedContext
from starrygl.distributed.utils import DistIndex, DistributedTensor
from starrygl.sample.graph_core.utils import build_mapper
import os.path as osp
import torch
import torch.distributed as dist
from torch_geometric.data import Data
class GraphData():
def __init__(self, pdata, device = torch.device('cuda'), all_on_gpu = False):
from starrygl.utils.uvm import cudaMemoryAdvise, uvm_advise, uvm_empty, uvm_prefetch, uvm_share
class DistributedGraphStore:
def __init__(self, pdata, device = torch.device('cuda'),
uvm_node = False,
uvm_edge = False):
self.device = device
self.ids = pdata.ids.to(device)
self.eids = pdata.eids
self.edge_index = pdata.edge_index.to(device)
if hasattr(pdata,'edge_ts'):
self.edge_ts = pdata.edge_ts.to(device).to(torch.float)
else:
self.edge_ts = None
self.sample_graph = pdata.sample_graph
self.nids_mapper = build_mapper(nids=pdata.ids.to(device)).data.to('cpu')
self.eids_mapper = build_mapper(nids=pdata.eids.to(device)).data.to('cpu')
if all_on_gpu:
self.nids_mapper = self.nids_mapper.to(device)
self.eids_mapper = self.eids_mapper.to(device)
self.nids_mapper = build_mapper(nids=pdata.ids.to(device)).dist.to('cpu')
self.eids_mapper = build_mapper(nids=pdata.eids.to(device)).dist.to('cpu')
torch.cuda.empty_cache()
self.num_nodes = self.nids_mapper.data.shape[0]
self.num_edges = self.eids_mapper.data.shape[0]
world_size = dist.get_world_size()
self.uvm_node = uvm_node
self.uvm_edge = uvm_edge
if hasattr(pdata,'x') and pdata.x is not None:
pdata.x = pdata.x.to(torch.float)
if uvm_node == False :
x = pdata.x.to(self.device)
else:
if self.device.type == 'cuda':
x = uvm_empty(*pdata.x.size(),
dtype=pdata.x.dtype,
device=ctx.device)
uvm_share(x,device = ctx.device)
uvm_advise(x,cudaMemoryAdvise.cudaMemAdviseSetAccessedBy)
uvm_prefetch(x)
if world_size > 1:
self.x = DistributedTensor(pdata.x.to(self.device).to(torch.float))
else:
self.x = pdata.x.to(device).to(torch.float)
self.x = x
else:
self.x = None
if hasattr(pdata,'edge_attr') and pdata.edge_attr is not None:
ctx = DistributedContext.get_default_context()
pdata.edge_attr = pdata.edge_attr.to(torch.float)
if uvm_edge == False :
edge_attr = pdata.edge_attr.to(self.device)
else:
if self.device.type == 'cuda':
edge_attr = uvm_empty(*pdata.edge_attr.size(),
dtype=pdata.edge_attr.dtype,
device=ctx.device)
uvm_share(edge_attr,device = ctx.device)
uvm_advise(edge_attr,cudaMemoryAdvise.cudaMemAdviseSetAccessedBy)
uvm_prefetch(edge_attr)
if world_size > 1:
self.edge_attr = DistributedTensor(pdata.edge_attr.to('cpu').to(torch.float))
self.edge_attr = DistributedTensor(edge_attr)
else:
self.edge_attr = pdata.edge_attr.to('cpu').to(torch.float)
self.edge_attr = edge_attr
else:
self.edge_attr = None
def _get_node_attr(self,ids):
def _get_node_attr(self,ids,asyncOp = False):
if self.x is None:
return None
elif dist.get_world_size() == 1:
return self.x[ids]
else:
if self.x.rrefs is None or asyncOp is False:
ids = self.x.all_to_all_ind2ptr(ids)
return self.x.all_to_all_get(**ids)
return self.x.index_select(ids)
def _get_edge_attr(self,ids,):
def _get_edge_attr(self,ids,asyncOp = False):
if self.edge_attr is None:
return None
elif dist.get_world_size() == 1:
return self.edge_attr[ids]
else:
if self.edge_attr.rrefs is None or asyncOp is False:
ids = self.edge_attr.all_to_all_ind2ptr(ids)
return self.edge_attr.all_to_all_get(**ids)
return self.edge_attr.index_select(ids)
def _get_dist_index(self,ind,mapper):
return mapper[ind.to(mapper.device)]
class DataSet:
def __init__(self,nodes = None,
......@@ -106,7 +149,7 @@ class DataSet:
setattr(d,k,v[indx])
return d
class TemporalGraphData():
class TemporalGraphData(DistributedGraphStore):
def __init__(self,pdata,device):
super(TemporalGraphData,self).__init__(pdata,device)
def _set_temporal_batch_cache(self,size,pin_size):
......@@ -117,7 +160,7 @@ class TemporalGraphData():
class TemporalNeighborSampleGraph(GraphData):
class TemporalNeighborSampleGraph(DistributedGraphStore):
def __init__(self, sample_graph=None, mode='full', eids_mapper=None):
self.edge_index = sample_graph['edge_index']
self.num_edges = self.edge_index.shape[1]
......
......@@ -22,4 +22,6 @@ def build_mapper(nids):
ind_mp[iid] = torch.arange(all_ids[i].shape[0],**ikw)
return DistIndex(ind_mp,part_mp)
\ No newline at end of file
def get_validate_graph(self,graph):
pass
\ No newline at end of file
from typing import Union
from typing import List
from typing import Optional
import torch
from torch.distributed import rpc
import torch_scatter
......@@ -6,12 +8,15 @@ from starrygl.distributed.context import DistributedContext
from starrygl.distributed.utils import DistIndex, DistributedTensor
import torch.distributed as dist
from starrygl.utils.uvm import cudaMemoryAdvise, uvm_advise, uvm_empty, uvm_prefetch, uvm_share
class SharedMailBox():
def __init__(self,
num_nodes,
memory_param,
dim_edge_feat,
device = torch.device('cuda')):
device = torch.device('cuda'),
uvm = False):
self.device = device
self.num_nodes = num_nodes
self.num_parts = dist.get_world_size()
......@@ -19,19 +24,41 @@ class SharedMailBox():
raise NotImplementedError
self.memory_param = memory_param
self.memory_size = memory_param['dim_out']
assert not (device.type =='cpu' and uvm is True),\
'set uvm must set device on cuda'
memory_device = device
if device.type == 'cuda' and uvm is True:
memory_device = torch.device('cpu')
node_memory = torch.zeros((
self.num_nodes, memory_param['dim_out']),
dtype=torch.float32,device =self.device)
dtype=torch.float32,device =memory_device)
node_memory_ts = torch.zeros(self.num_nodes,
dtype=torch.float32,
device = self.device)
mailbox = torch.zeros(self.num_nodes,
memory_param['mailbox_size'],
2 * memory_param['dim_out'] + dim_edge_feat,
device = self.device, dtype=torch.float32)
device = memory_device, dtype=torch.float32)
mailbox_ts = torch.zeros((self.num_nodes,
memory_param['mailbox_size']),
dtype=torch.float32,device = self.device)
dtype=torch.float32,device = self.device)
self.uvm = uvm
if uvm is True:
ctx = DistributedContext.get_default_context()
node_memory = uvm_empty(*node_memory.shape,
dtype=node_memory.dtype,
device=ctx.device)
uvm_share(node_memory,device = ctx.device)
uvm_advise(node_memory,cudaMemoryAdvise.cudaMemAdviseSetAccessedBy)
uvm_prefetch(node_memory)
mailbox = uvm_empty(*mailbox.shape,
dtype=mailbox.dtype,
device=ctx.device)
uvm_share(mailbox,device = ctx.device)
uvm_advise(mailbox,cudaMemoryAdvise.cudaMemAdviseSetAccessedBy)
uvm_prefetch(mailbox)
self.node_memory = DistributedTensor(node_memory)
self.node_memory_ts = DistributedTensor(node_memory_ts)
self.mailbox = DistributedTensor(mailbox)
......@@ -40,8 +67,9 @@ class SharedMailBox():
dtype=torch.long,
device = self.device)
self._ctx = DistributedContext.get_default_context()
self.rref = rpc.RRef(self)
self.rrefs = self._ctx.all_gather_remote_objects(self.rref)
if self._ctx._use_rpc == True:
self.rref = rpc.RRef(self)
self.rrefs = self._ctx.all_gather_remote_objects(self.rref)
self.partptr = torch.tensor([ ((i & 0xFFFF)<<48) for i in range(self.num_parts+1) ],device = device)
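The partptr table above shifts the partition id into the top 16 bits of a 64-bit index, the packing that DistIndex's .part/.loc split relies on throughout. A toy of that encoding (illustrative sketch, not the repo's exact implementation):

import torch

def pack(part: torch.Tensor, loc: torch.Tensor) -> torch.Tensor:
    # high 16 bits: owning partition; low 48 bits: local offset
    return (part.to(torch.long) << 48) | loc.to(torch.long)

idx = pack(torch.tensor([0, 1, 2]), torch.tensor([7, 7, 7]))
print(idx >> 48)              # partition ids: tensor([0, 1, 2])
print(idx & ((1 << 48) - 1))  # local offsets: tensor([7, 7, 7])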
......@@ -118,7 +146,7 @@ class SharedMailBox():
def set_mailbox_all_to_all(self,index,memory,
memory_ts,mail,mail_ts,
reduce_Op = None):
reduce_Op = None,group = None):
#futs: List[torch.futures.Future] = []
if self.num_parts == 1:
dist_index = DistIndex(index)
......@@ -132,7 +160,7 @@ class SharedMailBox():
device = self.device)
indic = torch.searchsorted(index,self.partptr,right=False)
scatter_len_list = indic[1:] - indic[0:-1]
torch.distributed.all_to_all_single(gather_len_list,scatter_len_list)
torch.distributed.all_to_all_single(gather_len_list,scatter_len_list,group = group)
input_split = scatter_len_list.tolist()
output_split = gather_len_list.tolist()
gather_id_list = torch.empty(
......@@ -143,7 +171,7 @@ class SharedMailBox():
output_split = gather_len_list.tolist()
torch.distributed.all_to_all_single(
gather_id_list,index,output_split_sizes=output_split,
input_split_sizes=input_split)
input_split_sizes=input_split,group = group)
index = gather_id_list
gather_memory = torch.empty(
[gather_len_list.sum(),memory.shape[1]],
......@@ -160,22 +188,81 @@ class SharedMailBox():
torch.distributed.all_to_all_single(
gather_memory,memory,
output_split_sizes=output_split,
input_split_sizes=input_split)
input_split_sizes=input_split,group = group)
torch.distributed.all_to_all_single(
gather_memory_ts,memory_ts,
output_split_sizes=output_split,
input_split_sizes=input_split)
input_split_sizes=input_split,group = group)
torch.distributed.all_to_all_single(
gather_mail,mail,
output_split_sizes=output_split,
input_split_sizes=input_split)
input_split_sizes=input_split,group = group)
torch.distributed.all_to_all_single(
gather_mail_ts,mail_ts,
output_split_sizes=output_split,
input_split_sizes=input_split)
input_split_sizes=input_split,group = group)
self.set_mailbox_local(DistIndex(index).loc,gather_mail,gather_mail_ts,Reduce_Op = reduce_Op)
self.set_memory_local(DistIndex(index).loc,gather_memory,gather_memory_ts, Reduce_Op = reduce_Op)
def set_mailbox_all_to_all(self,index,memory,
memory_ts,mail,mail_ts,
reduce_Op = None,group = None):
#futs: List[torch.futures.Future] = []
if self.num_parts == 1:
dist_index = DistIndex(index)
part_idx = dist_index.part
index = dist_index.loc
self.set_mailbox_local(index,mail,mail_ts)
self.set_memory_local(index,memory,memory_ts)
else:
gather_len_list = torch.empty([self.num_parts],
dtype = int,
device = self.device)
indic = torch.searchsorted(index,self.partptr,right=False)
scatter_len_list = indic[1:] - indic[0:-1]
torch.distributed.all_to_all_single(gather_len_list,scatter_len_list,group = group)
input_split = scatter_len_list.tolist()
output_split = gather_len_list.tolist()
gather_id_list = torch.empty(
[gather_len_list.sum()],
dtype = torch.long,
device = self.device)
input_split = scatter_len_list.tolist()
output_split = gather_len_list.tolist()
torch.distributed.all_to_all_single(
gather_id_list,index,output_split_sizes=output_split,
input_split_sizes=input_split,group = group)
index = gather_id_list
gather_memory = torch.empty(
[gather_len_list.sum(),memory.shape[1]],
dtype = memory.dtype,device = self.device)
gather_memory_ts = torch.empty(
[gather_len_list.sum()],
dtype = memory_ts.dtype,device = self.device)
gather_mail = torch.empty(
[gather_len_list.sum(),mail.shape[1]],
dtype = mail.dtype,device = self.device)
gather_mail_ts = torch.empty(
[gather_len_list.sum()],
dtype = mail_ts.dtype,device = self.device)
torch.distributed.all_to_all_single(
gather_memory,memory,
output_split_sizes=output_split,
input_split_sizes=input_split,group = group)
torch.distributed.all_to_all_single(
gather_memory_ts,memory_ts,
output_split_sizes=output_split,
input_split_sizes=input_split,group = group)
torch.distributed.all_to_all_single(
gather_mail,mail,
output_split_sizes=output_split,
input_split_sizes=input_split,group = group)
torch.distributed.all_to_all_single(
gather_mail_ts,mail_ts,
output_split_sizes=output_split,
input_split_sizes=input_split,group = group)
self.set_mailbox_local(DistIndex(index).loc,gather_mail,gather_mail_ts,Reduce_Op = reduce_Op)
self.set_memory_local(DistIndex(index).loc,gather_memory,gather_memory_ts, Reduce_Op = reduce_Op)
def get_update_mail(self,dist_indx_mapper,
src,dst,ts,edge_feats,
......@@ -220,6 +307,8 @@ class SharedMailBox():
self.node_memory_ts.accessor.data[index],\
self.mailbox.accessor.data[index],\
self.mailbox_ts.accessor.data[index]
elif self.node_memory.rrefs is None:
return self.gather_memory(dist_index = index)
else:
memory = self.node_memory.index_select(index)
memory_ts = self.node_memory_ts.index_select(index)
......@@ -234,9 +323,25 @@ class SharedMailBox():
#print(memory.shape[0])
return memory,memory_ts,mail,mail_ts
return torch.futures.collect_all([memory,memory_ts,mail,mail_ts]).then(callback)
def gather_memory(self,index,input_split,out_split):
return self.node_memory.scatter_data(index,input_split,out_split),\
self.node_memory_ts.scatter_data(index,input_split,out_split),\
self.mailbox.scatter_data(index,input_split,out_split),\
self.mailbox_ts.scatter_data(index,input_split,out_split)
def gather_memory(
self,
dist_index: Union[torch.Tensor, DistIndex, None] = None,
send_ptr: Optional[List[int]] = None,
recv_ptr: Optional[List[int]] = None,
recv_ind: Optional[List[int]] = None,
group = None
):
if dist_index is None:
return self.node_memory.all_to_all_get(dist_index,send_ptr,recv_ptr,recv_ind,group),\
self.node_memory_ts.all_to_all_get(dist_index,send_ptr,recv_ptr,recv_ind,group),\
self.mailbox.all_to_all_get(dist_index,send_ptr,recv_ptr,recv_ind,group),\
self.mailbox_ts.all_to_all_get(dist_index,send_ptr,recv_ptr,recv_ind,group)
else:
ids = self.node_memory.all_to_all_ind2ptr(dist_index)
return self.node_memory.all_to_all_get(**ids,group = group),\
self.node_memory_ts.all_to_all_get(**ids,group = group),\
self.mailbox.all_to_all_get(**ids,group = group),\
self.mailbox_ts.all_to_all_get(**ids,group = group)
......@@ -2,15 +2,16 @@ import parser
from torch_sparse import SparseTensor
from torch_geometric.data import Data
from torch_geometric.utils import degree
from .torch_utils import get_norm_temporal
import os.path as osp
import os
import shutil
import torch
import torch.utils.data
import metis
import networkx as nx
import torch.distributed as dist
from starrygl.lib.libstarrygl_sampler import get_norm_temporal
from starrygl.utils.partition import mt_metis_partition
def partition_load(root: str, algo: str = "metis") -> Data:
......@@ -22,6 +23,7 @@ def partition_load(root: str, algo: str = "metis") -> Data:
def partition_save(root: str, data: Data, num_parts: int,
algo: str = "metis",
node_weight = None,
edge_weight_dict=None):
root = osp.abspath(root)
if osp.exists(root) and not osp.isdir(root):
......@@ -46,6 +48,7 @@ def partition_save(root: str, data: Data, num_parts: int,
if algo == 'metis_for_tgnn':
for i, pdata in enumerate(partition_data_for_tgnn(
data, num_parts, algo, verbose=True,
node_weight = node_weight,
edge_weight_dict=edge_weight_dict)):
print(f"saving partition data: {i+1}/{num_parts}")
fn = osp.join(path, f"{i:03d}")
......@@ -153,30 +156,41 @@ def _nopart(edge_index: torch.LongTensor, num_nodes: int):
def metis_for_tgnn(edge_index_dict: dict,
num_nodes: int,
num_parts: int,
node_weight = None,
edge_weight_dict=None):
if num_parts <= 1:
return _nopart(edge_index_dict, num_nodes)
G = nx.Graph()
G.add_nodes_from(torch.arange(0, num_nodes).tolist())
value, counts = torch.unique(edge_index_dict['edata'][1, :].view(-1),
return_counts=True)
nodes = torch.tensor(list(G.adj.keys()))
for i in range(value.shape[0]):
if (value[i].item() in G.nodes):
G.nodes[int(value[i].item())]['weight'] = counts[i]
G.nodes[int(value[i].item())]['ones'] = 1
G.graph['node_weight_attr'] = ['weight', 'ones']
for i, key in enumerate(edge_index_dict):
edge_list = []
weight_list = []
for i,key in enumerate(edge_index_dict):
v = edge_index_dict[key]
edges = torch.cat((v, (torch.ones(v.shape[1], dtype=torch.long) *
edge_weight_dict[key]).unsqueeze(0)), dim=0)
# w = edges.T
G.add_weighted_edges_from((edges.T).tolist())
G.graph['edge_weight_attr'] = 'weight'
cuts, part = metis.part_graph(G, num_parts)
node_parts = torch.zeros(num_nodes, dtype=torch.long)
node_parts[nodes] = torch.tensor(part)
edge_list.append(v)
weight_list.append(torch.ones(v.shape[1])*edge_weight_dict[key])
edge_index = torch.cat(edge_list,dim = 1)
edge_weight = torch.cat(weight_list,dim = 0)
node_parts = mt_metis_partition(edge_index,num_nodes,num_parts,node_weight,edge_weight)
return node_parts
#G = nx.Graph()
#G.add_nodes_from(torch.arange(0, num_nodes).tolist())
#value, counts = torch.unique(edge_index_dict['edata'][1, :].view(-1),
# return_counts=True)
#nodes = torch.tensor(list(G.adj.keys()))
#for i in range(value.shape[0]):
# if (value[i].item() in G.nodes):
# G.nodes[int(value[i].item())]['weight'] = counts[i]
# G.nodes[int(value[i].item())]['ones'] = 1
#G.graph['node_weight_attr'] = ['weight', 'ones']
#for i, key in enumerate(edge_index_dict):
# v = edge_index_dict[key]
# edges = torch.cat((v, (torch.ones(v.shape[1], dtype=torch.long) *
# edge_weight_dict[key]).unsqueeze(0)), dim=0)
# # w = edges.T
# G.add_weighted_edges_from((edges.T).tolist())
#G.graph['edge_weight_attr'] = 'weight'
#cuts, part = metis.part_graph(G, num_parts)
#node_parts = torch.zeros(num_nodes, dtype=torch.long)
#node_parts[nodes] = torch.tensor(part)
#return node_parts
"""
......@@ -187,6 +201,7 @@ weight: edge-partition weights for the various workloads
def partition_data_for_tgnn(data: Data, num_parts: int, algo: str,
verbose: bool = False,
node_weight: torch.Tensor = None,
edge_weight_dict: dict = None):
if algo == "metis_for_tgnn":
part_fn = metis_for_tgnn
......@@ -200,6 +215,7 @@ def partition_data_for_tgnn(data: Data, num_parts: int, algo: str,
if verbose:
print(f"running partition algorithm: {algo}")
node_parts = part_fn(edge_index_dict, num_nodes, num_parts,
node_weight,
edge_weight_dict)
edge_parts = node_parts[data.edge_index[1, :]]
eids = torch.arange(num_edges, dtype=torch.long)
......
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CppExtension
setup(
name='torch_utils',
ext_modules=[
CppExtension(
name='torch_utils',
sources=['./torch_utils.cpp'],
extra_compile_args=['-fopenmp','-Xlinker',' -export-dynamic'],
include_dirs=["../sample_core"],
),
],
cmdclass={
'build_ext': BuildExtension
})#
#setup(
# name='cpu_cache_manager',
# ext_modules=[
# CppExtension(
# name='cpu_cache_manager',
# sources=['cpu_cache_manager.cpp'],
# extra_compile_args=['-fopenmp','-Xlinker',' -export-dynamic'],
# include_dirs=["./"],
# ),
# ],
# cmdclass={
# 'build_ext': BuildExtension
# })#
#
\ No newline at end of file
from concurrent.futures import ThreadPoolExecutor, thread
from multiprocessing import Process
from multiprocessing.pool import ThreadPool
from typing import Deque
import torch
class WorkStreamEvent:
def __init__(self):
self.train_stream = torch.cuda.Stream()
self.write_memory_stream = torch.cuda.Stream()
self.fetch_stream = torch.cuda.Stream()
self.write_mail_stream = torch.cuda.Stream()
self.event = None
event = WorkStreamEvent()
def get_event():
return event.event
import torch.distributed as dist
stage = ['train_stream','write_memory','write_mail','lookup']
class PipelineManager:
def __init__(self,num_tasks = 10):
self.stream_set = {}
self.dist_set = {}
self.args_queue = {}
self.thread_pool = ThreadPoolExecutor(num_tasks)
for k in stage:
self.stream_set[k] = torch.cuda.Stream()
self.dist_set[k] = dist.new_group()
self.args_queue[k] = Deque()
def submit(self,state,func,kwargs):
future = self.thread_pool.submit(self.run, state,func,kwargs)
return future
def run(self,state,func,kwargs):
with torch.cuda.stream(self.stream_set[state]):
return func(**kwargs,group = self.dist_set[state])
manger = None
def getPipelineManger():
global manger
if manger == None:
manger = PipelineManager()
return manger
\ No newline at end of file
from concurrent.futures import ThreadPoolExecutor
def my_function():
    # perform some work...
    print('hi')
    return 'result'
# create a thread pool
executor = ThreadPoolExecutor()
# submit the task and get a Future object
future = executor.submit(my_function)
# wait for the task to finish and get the result
# (Future.result() blocks until done; concurrent.futures.Future has no wait() method)
result = future.result()
# print the result
print(result)  # 'result'
nohup python train_tgnn.py --dataname TaoBao --world_size 4 --rank 0 > taobao4.out &
nohup python train_tgnn.py --dataname TaoBao --world_size 4 --rank 1 > taobao4.out &
nohup python train_tgnn.py --dataname TaoBao --world_size 4 --rank 2 > taobao4.out &
nohup python train_tgnn.py --dataname TaoBao --world_size 4 --rank 3 > taobao4.out &
nohup python train_tgnn.py --dataname GDELT --world_size 4 --rank 0 > GDELT4.out &
nohup python train_tgnn.py --dataname GDELT --world_size 4 --rank 1 > GDELT4.out &
nohup python train_tgnn.py --dataname GDELT --world_size 4 --rank 2 > GDELT4.out &
nohup python train_tgnn.py --dataname GDELT --world_size 4 --rank 3 > GDELT4.out &
nohup python train_tgnn.py --dataname ML25M --world_size 2 --rank 0 > ML25M2.out &
nohup python train_tgnn.py --dataname ML25M --world_size 2 --rank 1 > ML25M2.out &
nohup python train_tgnn.py --dataname TaoBao --world_size 2 --rank 0 > TaoBao2.out &
nohup python train_tgnn.py --dataname TaoBao --world_size 2 --rank 1 > TaoBao2.out &
nohup python train_tgnn.py --dataname ML25M --world_size 1 --rank 0 > ML25M1.out &
nohup python train_tgnn.py --dataname TaoBao --world_size 1 --rank 0 > TaoBao1.out &
nohup python train_tgnn.py --dataname GDELT --world_size 1 --rank 0 > GDELT1.out &
\ No newline at end of file
......@@ -7,7 +7,8 @@ from starrygl.distributed.utils import DistIndex
from starrygl.module.modules import GeneralModel
from starrygl.module.utils import parse_config
from starrygl.sample.graph_core import DataSet, GraphData, TemporalNeighborSampleGraph
from starrygl.sample.cache.fetch_cache import FetchFeatureCache
from starrygl.sample.graph_core import DataSet, DistributedGraphStore, TemporalNeighborSampleGraph
from starrygl.sample.memory.shared_mailbox import SharedMailBox
from starrygl.sample.sample_core.base import NegativeSampling
from starrygl.sample.sample_core.neighbor_sampler import NeighborSampler
......@@ -25,32 +26,18 @@ import os
from starrygl.sample.data_loader import DistributedDataLoader
from starrygl.sample.batch_data import SAMPLE_TYPE
"""
test command
python test.py --world_size 2 --rank 0
--world_size', default=4, type=int, metavar='W',
help='number of workers')
parser.add_argument('--rank', default=0, type=int, metavar='W',
help='rank of the worker')
parser.add_argument('--log_interval', type=int, default=10, metavar='N',
help='interval between training status logs')
parser.add_argument('--gamma', type=float, default=0.99, metavar='G',
help='how much to value future rewards')
parser.add_argument('--seed', type=int, default=1, metavar='S',
help='random seed for reproducibility')
parser.add_argument('--num_sampler', type=int, default=10, metavar='S',
help='number of samplers')
parser.add_argument('--queue_size', type=int, default=10, metavar='S',
help='sampler queue size')
"""
from starrygl.sample.stream_manager import getPipelineManger
parser = argparse.ArgumentParser(
description="RPC Reinforcement Learning Example",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument('--rank', default=0, type=str, metavar='W',
parser.add_argument('--rank', default=0, type=int, metavar='W',
                    help='rank of the worker')
parser.add_argument('--world_size', default=1, type=int, metavar='W',
                    help='number of workers')
parser.add_argument('--dataname', default=1, type=str, metavar='W',
                    help='name of dataset')
args = parser.parse_args()
from sklearn.metrics import average_precision_score, roc_auc_score
import torch
......@@ -60,11 +47,12 @@ import dgl
import numpy as np
from sklearn.metrics import average_precision_score, roc_auc_score
from torch.nn.parallel import DistributedDataParallel as DDP
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.rank)
os.environ["RANK"] = str(args.rank)
os.environ["WORLD_SIZE"] = str(args.world_size)
os.environ["LOCAL_RANK"] = str(0)
os.environ["MASTER_ADDR"] = '127.0.0.1'
#os.environ['CUDA_VISIBLE_DEVICES'] = str(args.rank)
#os.environ["RANK"] = str(args.rank)
#os.environ["WORLD_SIZE"] = str(args.world_size)
#os.environ["LOCAL_RANK"] = str(0)
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
os.environ["MASTER_ADDR"] = '10.214.211.187'
os.environ["MASTER_PORT"] = '9337'
def seed_everything(seed=42):
random.seed(seed)
......@@ -76,18 +64,19 @@ def seed_everything(seed=42):
seed_everything(1234)
def main():
print('main')
use_cuda = True
sample_param, memory_param, gnn_param, train_param = parse_config('./config/TGN.yml')
torch.set_num_threads(12)
ctx = DistributedContext.init(backend="nccl", use_gpu=True)
device_id = torch.cuda.current_device()
print('use cuda on',device_id)
pdata = partition_load("./dataset/here/GDELT", algo="metis_for_tgnn")
graph = GraphData(pdata = pdata)
pdata = partition_load("/mnt/data/part_data/dataset/here/{}".format(args.dataname), algo="metis_for_tgnn")
graph = DistributedGraphStore(pdata = pdata,uvm_edge = False,uvm_node = False)
sample_graph = TemporalNeighborSampleGraph(sample_graph = pdata.sample_graph,mode = 'full')
mailbox = SharedMailBox(pdata.ids.shape[0], memory_param, dim_edge_feat = pdata.edge_attr.shape[1] if pdata.edge_attr is not None else 0)
sampler = NeighborSampler(num_nodes=graph.num_nodes, num_layers=1, fanout=[10],graph_data=sample_graph, workers=10,policy = 'recent',graph_name = "wiki_train")
sampler = NeighborSampler(num_nodes=graph.num_nodes, num_layers=1, fanout=[10],graph_data=sample_graph, workers=15,policy = 'recent',graph_name = "wiki_train")
train_data = torch.masked_select(graph.edge_index,pdata.train_mask.to(graph.edge_index.device)).reshape(2,-1)
train_ts = torch.masked_select(graph.edge_ts,pdata.train_mask.to(graph.edge_index.device))
val_data = torch.masked_select(graph.edge_index,pdata.val_mask.to(graph.edge_index.device)).reshape(2,-1)
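A note on the masked_select pattern above: the [E]-sized mask broadcasts over the [2, E] edge_index, and the selection comes back flattened in row-major order (row 0's picks first, then row 1's), which is why reshape(2, -1) recovers the src and dst rows. A standalone check:

import torch

edge_index = torch.tensor([[0, 1, 2], [3, 4, 5]])  # [2, E]
mask = torch.tensor([True, False, True])           # [E]
# Broadcasted row-major selection, then restore the two rows.
pairs = torch.masked_select(edge_index, mask).reshape(2, -1)
print(pairs)  # tensor([[0, 2], [3, 5]])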
......@@ -96,8 +85,16 @@ def main():
test_ts = torch.masked_select(graph.edge_ts,pdata.test_mask.to(graph.edge_index.device))
print(train_data.shape[1],val_data.shape[1],test_data.shape[1])
train_data = DataSet(edges = train_data,ts =train_ts,eids = torch.nonzero(pdata.train_mask).view(-1))
    #if dist.get_rank() == 0:
    test_data = DataSet(edges = test_data,ts =test_ts,eids = torch.nonzero(pdata.test_mask).view(-1))
    val_data = DataSet(edges = val_data,ts = val_ts,eids = torch.nonzero(pdata.val_mask).view(-1))
    #else:
    #test_data = torch.tensor([[],[]], device=graph.edge_index.device, dtype=graph.edge_index.dtype)
    #val_data = torch.tensor([[],[]], device=graph.edge_index.device, dtype=graph.edge_index.dtype)
    #test_ts = torch.tensor([[],[]], device=graph.ts.device, dtype=graph.ts.dtype)
    #val_ts = torch.tensor([[],[]], device=graph.ts.device, dtype=graph.ts.dtype)
    #test_data = DataSet(edges=test_data, ts=test_ts, eids=torch.tensor([], dtype=torch.long, device='cuda'))
    #val_data = DataSet(edges=val_data, ts=val_ts, eids=torch.tensor([], dtype=torch.long, device='cuda'))
#train_neg_sampler = PreNegativeSampling('triplet',torch.masked_select(pdata.edge_index['pos_edge'],graph.data.train_mask).reshape(2,-1))
neg_sampler = NegativeSampling('triplet')
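For intuition, 'triplet' negative sampling can be read as corrupting the destination of each positive edge with a uniformly random node; the sketch below is illustrative only and not the NegativeSampling implementation:

import torch

# Illustrative sketch: one random negative destination per positive edge.
def triplet_negatives(dst, num_nodes):
    return torch.randint(0, num_nodes, dst.shape, device=dst.device)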
trainloader = DistributedDataLoader(graph,train_data,sampler = sampler,
......@@ -109,7 +106,8 @@ def main():
chunk_size = None,
train=True,
queue_size = 1000,
mailbox = mailbox)
mailbox = mailbox,
)
testloader = DistributedDataLoader(graph,test_data,sampler = sampler,
sampler_fn = SAMPLE_TYPE.SAMPLE_FROM_TEMPORAL_EDGES,
neg_sampler=neg_sampler,
......@@ -130,8 +128,12 @@ def main():
train=False,
queue_size = 100,
mailbox = mailbox)
#FetchFeatureCache.create_fetch_cache(graph.num_nodes,graph.eids_mapper.shape[0],0.1,0.1,graph,mailbox,policy = 'static')
#cache = FetchFeatureCache.getFetchCache()
#cache.init_cache_with_presample(trainloader,3)
gnn_dim_node = 0 if graph.x is None else pdata.x.shape[1]
gnn_dim_edge = 0 if graph.edge_attr is None else pdata.edge_attr.shape[1]
print(gnn_dim_node,gnn_dim_edge)
avg_time = 0
if use_cuda:
model = GeneralModel(gnn_dim_node, gnn_dim_edge, sample_param, memory_param, gnn_param, train_param).cuda()
......@@ -159,40 +161,42 @@ def main():
with torch.no_grad():
total_loss = 0
signal = torch.tensor([0],dtype = int,device = device)
for roots,mfgs,metadata in loader:
pred_pos, pred_neg = model(mfgs,metadata)
total_loss += creterion(pred_pos, torch.ones_like(pred_pos))
total_loss += creterion(pred_neg, torch.zeros_like(pred_neg))
y_pred = torch.cat([pred_pos, pred_neg], dim=0).sigmoid().cpu()
y_true = torch.cat([torch.ones(pred_pos.size(0)), torch.zeros(pred_neg.size(0))], dim=0)
aps.append(average_precision_score(y_true, y_pred.detach().numpy()))
aucs_mrrs.append(roc_auc_score(y_true, y_pred))
if mailbox is not None:
src = metadata['src_pos_index']
dst = metadata['dst_pos_index']
ts = roots.ts
if(graph.edge_attr.device == torch.device('cpu')):
edge_feats = graph.edge_attr[roots.eids.to('cpu')].to('cuda') if graph.edge_attr is not None else None
else:
edge_feats = graph.edge_attr[roots.eids] if graph.edge_attr is not None else None
dist_index_mapper = mfgs[0][0].srcdata['ID']
root_index = torch.cat((src,dst))
last_updated_nid = model.module.memory_updater.last_updated_nid[root_index]
last_updated_memory = model.module.memory_updater.last_updated_memory[root_index]
last_updated_ts=model.module.memory_updater.last_updated_ts[root_index]
index, memory, memory_ts = mailbox.get_update_memory(last_updated_nid,
last_updated_memory,
last_updated_ts)
index, mail, mail_ts = mailbox.get_update_mail(dist_index_mapper,
src,dst,ts,edge_feats,
model.module.memory_updater.last_updated_memory,
)
mailbox.set_mailbox_all_to_all(index,memory,memory_ts,mail,mail_ts,reduce_Op = 'max')
with torch.cuda.stream(train_stream):
for roots,mfgs,metadata in loader:
pred_pos, pred_neg = model(mfgs,metadata)
total_loss += creterion(pred_pos, torch.ones_like(pred_pos))
total_loss += creterion(pred_neg, torch.zeros_like(pred_neg))
y_pred = torch.cat([pred_pos, pred_neg], dim=0).sigmoid().cpu()
y_true = torch.cat([torch.ones(pred_pos.size(0)), torch.zeros(pred_neg.size(0))], dim=0)
aps.append(average_precision_score(y_true, y_pred.detach().numpy()))
aucs_mrrs.append(roc_auc_score(y_true, y_pred))
if mailbox is not None:
src = metadata['src_pos_index']
dst = metadata['dst_pos_index']
ts = roots.ts
if graph.edge_attr is None:
edge_feats = None
elif(graph.edge_attr.device == torch.device('cpu')):
edge_feats = graph.edge_attr[roots.eids.to('cpu')].to('cuda')
else:
edge_feats = graph.edge_attr[roots.eids]
dist_index_mapper = mfgs[0][0].srcdata['ID']
root_index = torch.cat((src,dst))
last_updated_nid = model.module.memory_updater.last_updated_nid[root_index]
last_updated_memory = model.module.memory_updater.last_updated_memory[root_index]
last_updated_ts=model.module.memory_updater.last_updated_ts[root_index]
index, memory, memory_ts = mailbox.get_update_memory(last_updated_nid,
last_updated_memory,
last_updated_ts)
#
index, mail, mail_ts = mailbox.get_update_mail(dist_index_mapper,
src,dst,ts,edge_feats,
model.module.memory_updater.last_updated_memory,
)
mailbox.set_mailbox_all_to_all(index,memory,memory_ts,mail,mail_ts,reduce_Op = 'max')
#ap = float(torch.tensor(aps).mean())
#if neg_samples > 1:
......@@ -207,7 +211,6 @@ def main():
ap = float(torch.tensor(apc).mean())
auc_mrr = float(torch.tensor(auc_mrr).mean())
return ap, auc_mrr
creterion = torch.nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=train_param['lr'])
......@@ -242,15 +245,20 @@ def main():
y_pred = torch.cat([pred_pos, pred_neg], dim=0).sigmoid().cpu()
y_true = torch.cat([torch.ones(pred_pos.size(0)), torch.zeros(pred_neg.size(0))], dim=0)
train_aps.append(average_precision_score(y_true, y_pred.detach().numpy()))
#start_event = torch.cuda.Event(enable_timing=True)
#end_event = torch.cuda.Event(enable_timing=True)
#start_event.record()
if mailbox is not None:
src = metadata['src_pos_index']
dst = metadata['dst_pos_index']
ts = roots.ts
if(graph.edge_attr.device == torch.device('cpu')):
edge_feats = graph.edge_attr[roots.eids.to('cpu')].to('cuda') if graph.edge_attr is not None else None
if graph.edge_attr is None:
edge_feats = None
elif(graph.edge_attr.device == torch.device('cpu')):
edge_feats = graph.edge_attr[roots.eids.to('cpu')].to('cuda')
else:
edge_feats = graph.edge_attr[roots.eids] if graph.edge_attr is not None else None
edge_feats = graph.edge_attr[roots.eids]
dist_index_mapper = mfgs[0][0].srcdata['ID']
root_index = torch.cat((src,dst))
last_updated_nid = model.module.memory_updater.last_updated_nid[root_index]
......@@ -263,8 +271,11 @@ def main():
src,dst,ts,edge_feats,
model.module.memory_updater.last_updated_memory,
)
mailbox.set_mailbox_all_to_all(index,memory,memory_ts,mail,mail_ts,reduce_Op = 'max')
#end_event.record()
#torch.cuda.synchronize()
#write_back_time += start_event.elapsed_time(end_event)/1000
torch.cuda.synchronize()
time_prep = time.time() - epoch_start_time
......@@ -272,10 +283,14 @@ def main():
train_ap = float(torch.tensor(train_aps).mean())
ap = 0
auc = 0
#if cache.edge_cache is not None:
# print('hit {}'.format(cache.edge_cache.hit_/ cache.edge_cache.hit_sum))
#if cache.node_cache is not None:
# print('hit {}'.format(cache.node_cache.hit_/ cache.node_cache.hit_sum))
ap, auc = eval('val')
        print('\ttrain loss:{:.4f} train ap:{:.4f} val ap:{:.4f} val auc:{:.4f}'.format(total_loss,train_ap, ap, auc))
print('\ttotal time:{:.2f}s prep time:{:.2f}s'.format(time.time()-epoch_start_time, time_prep))
#print('\t fetch time:{:.2f}s write back time:{:.2f}s'.format(fetch_time,write_back_time))
model.eval()
if mailbox is not None:
mailbox.reset()
......