Commit 21c40ac0 by xxx

fix some bugs in batch_data.py and get_update_mail

parent 023440c3
import torch
import logging
__version__ = "0.1.0"
try:
from .lib import libstarrygl as ops
except Exception as e:
logging.error(e)
logging.error("unable to import libstarrygl.so, some features may not be available.")
try:
from .lib import libstarrygl_sampler as sampler_ops
except Exception as e:
logging.error(e)
logging.error("unable to import libstarrygl_sampler.so, some features may not be available.")
import torch
from torch import Tensor
from contextlib import contextmanager
from typing import *
__all__ = [
"ABCStream",
"ABCEvent",
"phony_tensor",
"new_stream",
"current_stream",
"default_stream",
"use_stream",
"use_device",
"wait_stream",
"wait_event",
"record_stream",
]
class CPUStreamType:
def __init__(self) -> None:
self._device = torch.device("cpu")
@property
def device(self):
return self._device
def __call__(self):
return self
class CPUEventType:
def __init__(self) -> None:
self._device = torch.device("cpu")
@property
def device(self):
return self._device
def __call__(self):
return self
CPUStream = CPUStreamType()
ABCStream = Union[torch.cuda.Stream, CPUStreamType]
CPUEvent = CPUEventType()
ABCEvent = Union[torch.cuda.Event, CPUEventType]
def new_stream(device: Any) -> ABCStream:
device = torch.device(device)
if device.type != "cuda":
return CPUStream()
return torch.cuda.Stream(device)
_phonies: Dict[Tuple[torch.device, bool], Tensor] = {}
def phony_tensor(device: Any, requires_grad: bool = True):
device = torch.device(device)
key = (device, requires_grad)
if key not in _phonies:
with use_stream(default_stream(device)):
_phonies[key] = torch.empty(
0, device=device,
requires_grad=requires_grad,
)
return _phonies[key]
def current_stream(device: Any) -> ABCStream:
device = torch.device(device)
if device.type != "cuda":
return CPUStream()
return torch.cuda.current_stream(device)
def default_stream(device: Any) -> ABCStream:
device = torch.device(device)
if device.type != "cuda":
return CPUStream()
return torch.cuda.default_stream(device)
@contextmanager
def use_stream(stream: ABCStream, fence_event: bool = False):
if isinstance(stream, CPUStreamType):
if fence_event:
event = CPUEvent()
yield event
else:
yield
return
with torch.cuda.stream(stream):
if fence_event:
event = torch.cuda.Event()
yield event
event.record()
else:
yield
@contextmanager
def use_device(device: Any):
device = torch.device(device)
if device.type != "cuda":
yield
return
with torch.cuda.device(device):
yield
def wait_stream(source: ABCStream, target: ABCStream):
if isinstance(target, CPUStreamType):
return
if isinstance(source, CPUStreamType):
target.synchronize()
else:
source.wait_stream(target)
def wait_event(source: ABCStream, target: ABCEvent):
if isinstance(target, CPUEventType):
return
if isinstance(source, CPUStreamType):
target.synchronize()
else:
source.wait_event(target)
def record_stream(tensor: Tensor, stream: ABCStream):
if isinstance(stream, CPUStreamType):
return
storage = tensor.untyped_storage()
tensor = tensor.new_empty(0).set_(storage)
tensor.record_stream(stream)
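# Hedged usage sketch (added for illustration; not part of the original module). It shows
# how the helpers above compose; the device string is an assumption, and everything
# degrades to the CPUStream/CPUEvent singletons on a CPU-only machine.
def _example_stream_usage(device: Any = "cuda:0" if torch.cuda.is_available() else "cpu") -> Tensor:
    s = new_stream(device)                      # torch.cuda.Stream or the CPUStream singleton
    x = torch.ones(4, device=device)
    with use_stream(s, fence_event=True) as ev: # work below is enqueued on s (no-op on CPU)
        y = x * 2
    wait_stream(current_stream(device), s)      # current stream waits for s to finish
    record_stream(y, current_stream(device))    # keep y alive for the consuming stream
    return y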
from .graph import *
from .utils import *
import torch
from torch import Tensor
from typing import *
__all__ = [
"init_vc_edge_index",
]
def init_vc_edge_index(
dst_ids: Tensor,
edge_index: Tensor,
bipartite: bool = True,
) -> Tuple[Tensor, Tensor]:
ikw = dict(dtype=torch.long, device=dst_ids.device)
local_num_nodes = torch.zeros(1, **ikw)
if dst_ids.numel() > 0:
local_num_nodes = dst_ids.max().max(local_num_nodes)
if edge_index.numel() > 0:
local_num_nodes = edge_index.max().max(local_num_nodes)
local_num_nodes = local_num_nodes.item() + 1
xmp: Tensor = torch.zeros(local_num_nodes, **ikw)
xmp[edge_index[1].unique()] += 0b01
xmp[dst_ids.unique()] += 0b10
    if not (xmp != 0b01).all():
        raise RuntimeError("the input graph must be a vertex-cut partition")
if bipartite:
src_ids = edge_index[0].unique()
else:
xmp.fill_(0)
xmp[edge_index[0]] = 1
xmp[dst_ids] = 0
src_ids = torch.cat([dst_ids, torch.where(xmp > 0)[0]], dim=-1)
xmp.fill_((2**62-1)*2+1)
xmp[src_ids] = torch.arange(src_ids.size(0), **ikw)
src = xmp[edge_index[0]]
xmp.fill_((2**62-1)*2+1)
xmp[dst_ids] = torch.arange(dst_ids.size(0), **ikw)
dst = xmp[edge_index[1]]
local_edge_index = torch.vstack([src, dst])
return src_ids, local_edge_index
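# Minimal worked example (added; the concrete ids are made up). With bipartite=True the
# destinations in edge_index must all be owned dst_ids, and the returned local_edge_index
# re-indexes sources into positions of src_ids and destinations into positions of dst_ids.
def _example_init_vc_edge_index() -> Tuple[Tensor, Tensor]:
    dst_ids = torch.tensor([2, 5, 7])            # nodes owned by this partition
    edge_index = torch.tensor([[9, 2, 5],        # global source ids
                               [2, 5, 7]])       # destinations, all inside dst_ids
    src_ids, local_edge_index = init_vc_edge_index(dst_ids, edge_index, bipartite=True)
    # src_ids == tensor([2, 5, 9]); local_edge_index == tensor([[2, 0, 1], [0, 1, 2]])
    return src_ids, local_edge_index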
import torch
import torch.nn as nn
import torch.distributed as dist
import torch.distributed.rpc as rpc
import os
from torch import Tensor
from typing import *
from .context import DistributedContext
from .utils import (
TensorAccessor,
DistributedTensor,
DistIndex,
)
def init_distributed_context(backend: str = "gloo") -> DistributedContext:
return DistributedContext.init(backend)
import torch
import torch.distributed as dist
from torch import Tensor
from typing import *
__all__ = [
"all_to_all_v",
"all_to_all_s",
"BatchWork",
"batch_send",
"batch_recv",
]
class BatchWork:
def __init__(self,
works: Optional[List[Any]],
buffer_tensor_list: Optional[List[Tuple[Tensor, Optional[Tensor]]]],
step: int = 1,
) -> None:
if works is None:
self._step = None
self._works = None
self._buffer_tensor_list = None
else:
if buffer_tensor_list:
assert len(works) // step == len(buffer_tensor_list)
self._step = step
self._works = works
self._buffer_tensor_list = buffer_tensor_list
def wait(self):
if self._works is None:
return
for i, w in enumerate(self._works):
if w is not None:
w.wait()
if (i + 1) % self._step != 0:
continue
if self._buffer_tensor_list:
out, buf = self._buffer_tensor_list[i // self._step]
if buf is not None:
out.copy_(buf)
self._step = None
self._works = None
self._buffer_tensor_list = None
def all_to_all_v(
output_tensor_list: List[Tensor],
input_tensor_list: List[Tensor],
group: Optional[Any] = None,
async_op: bool = False,
):
rank = dist.get_rank(group)
world_size = dist.get_world_size(group)
assert len(output_tensor_list) == world_size
assert len(input_tensor_list) == world_size
backend = dist.get_backend(group)
    if backend == "nccl" or backend == "mpi":
        work = dist.all_to_all(
            output_tensor_list=output_tensor_list,
            input_tensor_list=input_tensor_list,
            group=group,
            async_op=async_op,
        )
        return BatchWork([work], None) if async_op else None
    else:
        assert backend == "gloo", "backend must be nccl, mpi or gloo"
p2p_op_works = []
buffer_tensor_list = []
for i in range(1, world_size):
send_i = (rank + i) % world_size
recv_i = (rank - i + world_size) % world_size
send_t = input_tensor_list[send_i]
recv_t = output_tensor_list[recv_i]
if send_t.is_cuda:
send_t = send_t.cpu()
if recv_t.is_cuda:
recv_b = torch.empty_like(recv_t, device="cpu")
buffer_tensor_list.append((recv_t, recv_b))
else:
recv_b = recv_t
buffer_tensor_list.append((recv_t, None))
p2p_op_works.extend([
dist.isend(send_t, send_i, group=group),
dist.irecv(recv_b, recv_i, group=group),
])
work = BatchWork(p2p_op_works, buffer_tensor_list, 2)
output_tensor_list[rank].copy_(input_tensor_list[rank])
if async_op:
return work
work.wait()
def all_to_all_s(
output_tensor: Tensor,
input_tensor: Tensor,
output_rowptr: List[int],
input_rowptr: List[int],
group: Optional[Any] = None,
async_op: bool = False,
):
# rank = dist.get_rank(group)
world_size = dist.get_world_size(group)
assert len(output_rowptr) == len(input_rowptr)
assert len(output_rowptr) == world_size + 1
output_sizes = [t-s for s, t in zip(output_rowptr, output_rowptr[1:])]
input_sizes = [t-s for s, t in zip(input_rowptr, input_rowptr[1:])]
return dist.all_to_all_single(
output=output_tensor,
input=input_tensor,
output_split_sizes=output_sizes,
input_split_sizes=input_sizes,
group=group,
async_op=async_op,
)
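# Illustrative sketch (added): the rowptr arguments are prefix sums over rows, so
# input_rowptr=[0, 3, 5] sends rows 0:3 to rank 0 and rows 3:5 to rank 1. The output_rowptr
# must mirror what the peers actually send; the shapes below are assumptions for two ranks.
def _example_all_to_all_s(group: Optional[Any] = None) -> Tensor:
    rank = dist.get_rank(group)
    assert dist.get_world_size(group) == 2, "toy example assumes exactly two ranks"
    # every rank sends rows 0:3 to rank 0 and rows 3:5 to rank 1
    inp = torch.arange(5 * 4, dtype=torch.float32).view(5, 4) + rank
    input_rowptr = [0, 3, 5]
    # hence rank 0 receives 3 + 3 rows and rank 1 receives 2 + 2 rows
    output_rowptr = [0, 3, 6] if rank == 0 else [0, 2, 4]
    out = torch.empty(output_rowptr[-1], 4)
    all_to_all_s(out, inp, output_rowptr, input_rowptr, group=group)
    return out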
def batch_send(
*tensors: Tensor,
dst: int,
group: Any = None,
async_op: bool = False,
):
if len(tensors) == 0:
return BatchWork(None, None)
if group is None:
group = dist.GroupMember.WORLD
# tensors = tuple(t.data for t in tensors)
backend = dist.get_backend(group)
dst = dist.get_global_rank(group, dst)
if async_op:
works = []
for t in tensors:
if backend == "gloo" and t.is_cuda:
t = t.cpu()
works.append(dist.isend(t, dst=dst, group=group))
return BatchWork(works, None)
else:
for t in tensors:
if backend == "gloo" and t.is_cuda:
t = t.cpu()
dist.send(t, dst=dst, group=group)
def batch_recv(
*tensors: Tensor,
src: int,
group: Any = None,
async_op: bool = False,
):
if len(tensors) == 0:
return BatchWork(None, None)
if group is None:
group = dist.GroupMember.WORLD
# tensors = tuple(t.data for t in tensors)
backend = dist.get_backend(group)
src = dist.get_global_rank(group, src)
if async_op:
works = []
output_tensor_list = []
for t in tensors:
if backend == "gloo" and t.is_cuda:
b = torch.empty_like(t, device="cpu")
works.append(dist.irecv(b, src=src, group=group))
else:
b = None
works.append(dist.irecv(t, src=src, group=group))
output_tensor_list.append((t, b))
return BatchWork(works, output_tensor_list, 1)
else:
for t in tensors:
if backend == "gloo" and t.is_cuda:
b = torch.empty_like(t, device="cpu")
dist.recv(b, src=src, group=group)
t.copy_(b)
else:
dist.recv(t, src=src, group=group)
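# Hedged sketch (added): pairing batch_send with batch_recv between two ranks. The rank
# numbers and tensor shapes are assumptions; both sides must agree on the count, order,
# shapes and dtypes of the tensors, because only raw isend/irecv calls are issued.
def _example_batch_p2p(rank: int) -> Tuple[Tensor, Tensor]:
    a, b = torch.zeros(4), torch.zeros(2, 3)
    if rank == 0:
        work = batch_send(torch.ones(4), torch.full((2, 3), 2.0), dst=1, async_op=True)
    else:
        work = batch_recv(a, b, src=0, async_op=True)
    work.wait()   # for gloo + CUDA tensors this also copies the CPU staging buffers back
    return a, b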
import torch
import torch.nn as nn
import torch.distributed as dist
import torch.distributed.rpc as rpc
import os
from torch import Tensor
from typing import *
import socket
from contextlib import contextmanager
import logging
from .rpc import *
__all__ = [
"DistributedContext",
]
class DistributedContext:
"""Global context manager for distributed training
"""
@staticmethod
def init(
backend: str,
use_rpc: bool = False,
use_gpu: Optional[bool] = None,
rpc_gpu: Optional[bool] = None,
) -> 'DistributedContext':
if DistributedContext.is_initialized():
raise RuntimeError("not allowed to call init method twice.")
rank = int(os.getenv("RANK") or os.getenv("OMPI_COMM_WORLD_RANK"))
world_size = int(os.getenv("WORLD_SIZE") or os.getenv("OMPI_COMM_WORLD_SIZE"))
local_rank = os.getenv("LOCAL_RANK") or os.getenv("OMPI_COMM_WORLD_LOCAL_RANK")
if local_rank is None:
logging.warning(f"LOCAL_RANK has not been set, using the default value 0.")
os.environ["LOCAL_RANK"] = local_rank = "0"
local_rank = int(local_rank)
backend = backend.lower()
if use_gpu is None:
use_gpu = False
if backend == "nccl" or backend == "mpi":
use_gpu = True
else:
use_gpu = bool(use_gpu)
if rpc_gpu is None:
rpc_gpu = use_gpu
else:
rpc_gpu = bool(rpc_gpu)
master_addr = os.environ["MASTER_ADDR"]
master_port = int(os.environ["MASTER_PORT"])
ccl_init_url = f"tcp://{master_addr}:{master_port}"
rpc_init_url = f"tcp://{master_addr}:{master_port + 1}"
ctx = DistributedContext(
backend=backend,
ccl_init_method=ccl_init_url,
rpc_init_method=rpc_init_url,
rank=rank, world_size=world_size,
local_rank=local_rank,
use_rpc=use_rpc,
use_gpu=use_gpu,
rpc_gpu=rpc_gpu,
)
_set_default_dist_context(ctx)
return ctx
@staticmethod
def get_default_context() -> 'DistributedContext':
ctx = _get_default_dist_context()
if ctx is None:
raise RuntimeError("please call the init method first.")
return ctx
@staticmethod
def is_initialized() -> bool:
return _get_default_dist_context() is not None
def __init__(self,
backend: str,
ccl_init_method: str,
rpc_init_method: str,
rank: int, world_size: int,
local_rank: int,
use_rpc: bool,
use_gpu: bool,
rpc_gpu: bool,
) -> None:
if use_gpu:
device = torch.device(f"cuda:{local_rank}")
else:
device = torch.device("cpu")
dist.init_process_group(
backend=backend,
init_method=ccl_init_method,
rank=rank, world_size=world_size,
)
rpc_backend_options = rpc.TensorPipeRpcBackendOptions()
rpc_backend_options.init_method = rpc_init_method
if use_gpu and rpc_gpu:
device_maps = torch.zeros(world_size, dtype=torch.long, device=device)
device_maps[rank] = local_rank
dist.all_reduce(device_maps, op=dist.ReduceOp.SUM)
for i, dev in enumerate(device_maps.tolist()):
rpc_backend_options.set_device_map(
to=f"worker{i}",
device_map={local_rank: dev},
)
if use_rpc:
rpc.init_rpc(
name=f"worker{rank}",
rank=rank, world_size=world_size,
rpc_backend_options=rpc_backend_options,
)
self._use_rpc = use_rpc
self._local_rank = local_rank
self._compute_device = device
self._hostname = socket.gethostname()
if self.device.type == "cuda":
torch.cuda.set_device(self.device)
rank_to_host = [None] * self.world_size
dist.all_gather_object(rank_to_host, (self.hostname, self.local_rank))
self._rank_to_host: Tuple[Tuple[str, int], ...] = tuple(rank_to_host)
        host_names = sorted({h for h, _ in self.rank_to_host})
        self._host_index: Dict[str, int] = {h: i for i, h in enumerate(host_names)}
self.__temp_ag_remote_object: Optional[rpc.RRef] = None
def shutdown(self):
if self._use_rpc:
rpc.shutdown()
@property
def rank(self) -> int:
return dist.get_rank()
@property
def world_size(self) -> int:
return dist.get_world_size()
@property
def local_rank(self) -> int:
return self._local_rank
@property
def hostname(self) -> str:
return self._hostname
@property
def rank_to_host(self):
return self._rank_to_host
@property
def host_index(self):
return self._host_index
@property
def device(self) -> torch.device:
return self._compute_device
def get_default_group(self):
# return dist.distributed_c10d._get_default_group()
return dist.GroupMember.WORLD
def get_default_store(self):
return dist.distributed_c10d._get_default_store()
def get_ranks_by_host(self, hostname: Optional[str] = None) -> Tuple[int,...]:
if hostname is None:
hostname = self.hostname
ranks: List[int] = []
for i, (h, r) in enumerate(self.rank_to_host):
if h == hostname:
ranks.append(i)
ranks.sort()
return tuple(ranks)
def get_ranks_by_local(self, local_rank: Optional[int] = None) -> Tuple[int,...]:
if local_rank is None:
local_rank = self.local_rank
ranks: List[Tuple[int, str]] = []
for i, (h, r) in enumerate(self.rank_to_host):
if r == local_rank:
ranks.append((i, h))
ranks.sort(key=lambda x: self.host_index[x[1]])
return tuple(i for i, h in ranks)
def get_hybrid_matrix(self) -> Tensor:
hosts = sorted(self.host_index.items(), key=lambda x: x[1])
matrix = []
for h, _ in hosts:
rs = self.get_ranks_by_host(h)
matrix.append(rs)
return torch.tensor(matrix, dtype=torch.long, device="cpu")
def new_hybrid_subgroups(self,
matrix: Optional[Tensor] = None,
backend: Any = None,
) -> Tuple[Any, Any]:
if matrix is None:
matrix = self.get_hybrid_matrix()
assert matrix.dim() == 2
row_group = None
col_group = None
for row in matrix.tolist():
if self.rank in row:
row_group = dist.new_group(
row, backend=backend,
use_local_synchronization=True,
)
break
for col in matrix.t().tolist():
if self.rank in col:
col_group = dist.new_group(
col, backend=backend,
use_local_synchronization=True,
)
break
assert row_group is not None
assert col_group is not None
return row_group, col_group
def get_worker_info(self, rank: Optional[int] = None) -> rpc.WorkerInfo:
rank = dist.get_rank() if rank is None else rank
return rpc.get_worker_info(f"worker{rank}")
def remote_call(self, method, rref: rpc.RRef, *args, **kwargs):
return rpc_remote_call(method, rref, *args, **kwargs)
def remote_void_call(self, method, rref: rpc.RRef, *args, **kwargs):
return rpc_remote_void_call(method, rref, *args, **kwargs)
def remote_exec(self, method, rref: rpc.RRef, *args, **kwargs):
return rpc_remote_exec(method, rref, *args, **kwargs)
@contextmanager
def use_stream(self, stream: torch.cuda.Stream, with_event: bool = True):
event = torch.cuda.Event() if with_event else None
stream.wait_stream(torch.cuda.current_stream(self.device))
with torch.cuda.stream(stream):
yield event
if with_event:
event.record()
def all_gather_remote_objects(self, obj: Any) -> List[rpc.RRef]:
if not isinstance(obj, rpc.RRef):
obj = rpc.RRef(obj)
self.__temp_ag_remote_object = obj
dist.barrier()
futs: List[torch.futures.Future] = []
for i in range(self.world_size):
info = rpc.get_worker_info(f"worker{i}")
futs.append(rpc.rpc_async(info, DistributedContext._remote_object))
rrefs: List[rpc.RRef] = []
for f in futs:
f.wait()
rrefs.append(f.value())
dist.barrier()
self.__temp_ag_remote_object = None
return rrefs
@staticmethod
def _remote_object():
ctx = DistributedContext.get_default_context()
return ctx.__temp_ag_remote_object
def sync_print(self, *args, **kwargs):
for i in range(self.world_size):
if i == self.rank:
print(f"rank {self.rank}:", *args, **kwargs)
dist.barrier()
def main_print(self, *args, **kwargs):
if self.rank == 0:
print(*args, **kwargs)
dist.barrier()
_DEFAULT_DIST_CONTEXT: Optional['DistributedContext'] = None
def _get_default_dist_context():
global _DEFAULT_DIST_CONTEXT
return _DEFAULT_DIST_CONTEXT
def _set_default_dist_context(ctx):
global _DEFAULT_DIST_CONTEXT
_DEFAULT_DIST_CONTEXT = ctx
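# Hedged usage sketch (added; not part of the original file): a typical entry point when
# launched with torchrun, which provides RANK/WORLD_SIZE/LOCAL_RANK/MASTER_ADDR/MASTER_PORT.
# The backend choice and tensor size are assumptions.
def _example_context_usage():
    ctx = DistributedContext.init(backend="gloo", use_rpc=False)
    x = torch.ones(8, device=ctx.device)
    dist.all_reduce(x)                            # x now holds world_size on every rank
    ctx.sync_print("reduced value:", x[0].item())
    ctx.shutdown()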
import torch
import torch.distributed.rpc as rpc
from torch import Tensor
from typing import *
__all__ = [
"rpc_remote_call",
"rpc_remote_void_call",
"rpc_remote_exec"
]
def rpc_remote_call(method, rref: rpc.RRef, *args, **kwargs):
args = (method, rref) + args
return rpc.rpc_async(rref.owner(), rpc_method_call, args=args, kwargs=kwargs)
def rpc_method_call(method, rref: rpc.RRef, *args, **kwargs):
self = rref.local_value()
return method(self, *args, **kwargs)
def rpc_remote_void_call(method, rref: rpc.RRef, *args, **kwargs):
args = (method, rref) + args
return rpc.rpc_async(rref.owner(), rpc_method_void_call, args=args, kwargs=kwargs)
def rpc_method_void_call(method, rref: rpc.RRef, *args, **kwargs):
self = rref.local_value()
method(self, *args, **kwargs) # return None
def rpc_remote_exec(method, rref: rpc.RRef, *args, **kwargs):
args = (method, rref) + args
return rpc.rpc_async(rref.owner(), rpc_method_exec, args=args, kwargs=kwargs)
@rpc.functions.async_execution
def rpc_method_exec(method, rref: rpc.RRef, *args, **kwargs):
self = rref.local_value()
return method(self, *args, **kwargs)
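# Minimal sketch (added; the _Counter class is made up): rpc_remote_call runs a bound
# method on the worker that owns the RRef and returns a torch.futures.Future.
class _Counter:
    def __init__(self) -> None:
        self.value = 0

    def add(self, delta: int) -> int:
        self.value += delta
        return self.value

def _example_remote_call(counter_rref: rpc.RRef) -> int:
    fut = rpc_remote_call(_Counter.add, counter_rref, 3)   # executes on the RRef's owner
    return fut.wait()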
import random
import pandas as pd
import numpy as np
import os
import torch
from torch_geometric.data import Data
from starrygl.sample.graph_core import DataSet, DistributedGraphStore
def get_link_prediction_data(data_name: str, val_ratio, test_ratio):
"""
    generate data for the link prediction task (inductive & transductive settings)
    :param data_name: str, dataset name
    :param val_ratio: float, validation data ratio
    :param test_ratio: float, test data ratio
    :return: full_data, a PyG Data object carrying the padded node/edge features,
        edge timestamps, labels, train/val/test masks and the auxiliary sampling graph
"""
# Load data and train val test split
#graph_df = pd.read_csv('/mnt/nfs/fzz/TGL-DATA/'+data_name+'/edges.csv')
#if os.path.exists('/mnt/nfs/fzz/TGL-DATA/'+data_name+'/node_features.pt'):
# n_feat = torch.load('/mnt/nfs/fzz/TGL-DATA/'+data_name+'/node_features.pt')
#else:
# n_feat = None
#if os.path.exists('/mnt/nfs/fzz/TGL-DATA/'+data_name+'/edge_features.pt'):
# e_feat = torch.load('/mnt/nfs/fzz/TGL-DATA/'+data_name+'/edge_features.pt')
#else:
# e_feat = None
#
## get the timestamp of validate and test set
#src_node_ids = torch.from_numpy(np.array(graph_df.src.values)).long()
#dst_node_ids = torch.from_numpy(np.array(graph_df.dst.values)).long()
#node_interact_times = torch.from_numpy(np.array(graph_df.time.values)).long()
#
#train_mask = (torch.from_numpy(np.array(graph_df.ext_roll.values)) == 0)
#test_mask = (torch.from_numpy(np.array(graph_df.ext_roll.values)) == 1)
#val_mask = (torch.from_numpy(np.array(graph_df.ext_roll.values)) == 2)
# the setting of seed follows previous works
graph_df = pd.read_csv('./processed_data/{}/ml_{}.csv'.format(data_name, data_name))
edge_raw_features = np.load('./processed_data/{}/ml_{}.npy'.format(data_name, data_name))
node_raw_features = np.load('./processed_data/{}/ml_{}_node.npy'.format(data_name, data_name))
NODE_FEAT_DIM = EDGE_FEAT_DIM = 172
assert NODE_FEAT_DIM >= node_raw_features.shape[1], f'Node feature dimension in dataset {data_name} is bigger than {NODE_FEAT_DIM}!'
assert EDGE_FEAT_DIM >= edge_raw_features.shape[1], f'Edge feature dimension in dataset {data_name} is bigger than {EDGE_FEAT_DIM}!'
# padding the features of edges and nodes to the same dimension (172 for all the datasets)
if node_raw_features.shape[1] < NODE_FEAT_DIM:
node_zero_padding = np.zeros((node_raw_features.shape[0], NODE_FEAT_DIM - node_raw_features.shape[1]))
node_raw_features = np.concatenate([node_raw_features, node_zero_padding], axis=1)
if edge_raw_features.shape[1] < EDGE_FEAT_DIM:
edge_zero_padding = np.zeros((edge_raw_features.shape[0], EDGE_FEAT_DIM - edge_raw_features.shape[1]))
edge_raw_features = np.concatenate([edge_raw_features, edge_zero_padding], axis=1)
e_feat = edge_raw_features
n_feat = torch.from_numpy(node_raw_features.astype(np.float32))
e_feat = torch.from_numpy(edge_raw_features.astype(np.float32))
assert NODE_FEAT_DIM == node_raw_features.shape[1] and EDGE_FEAT_DIM == edge_raw_features.shape[1], 'Unaligned feature dimensions after feature padding!'
# get the timestamp of validate and test set
val_time, test_time = list(np.quantile(graph_df.ts, [(1 - val_ratio - test_ratio), (1 - test_ratio)]))
src_node_ids = torch.from_numpy(graph_df.u.values.astype(np.longlong))
dst_node_ids = torch.from_numpy(graph_df.i.values.astype(np.longlong))
node_interact_times = torch.from_numpy(graph_df.ts.values.astype(np.float32))
#edge_ids = torch.from_numpy(graph_df.idx.values.astype(np.longlong))
labels = torch.from_numpy(graph_df.label.values)
unique_node_ids = torch.cat((src_node_ids,dst_node_ids)).unique()
train_mask = node_interact_times <= val_time
val_mask = ((node_interact_times > val_time)&(node_interact_times <= test_time))
test_mask = (node_interact_times > test_time)
torch.manual_seed(2020)
train_node_set = torch.cat((src_node_ids[train_mask],dst_node_ids[train_mask])).unique()
    test_node_set = set(src_node_ids[node_interact_times > val_time].tolist()).union(
        set(dst_node_ids[node_interact_times > val_time].tolist()))
    new_test_node_set = set(random.sample(sorted(test_node_set), int(0.1 * unique_node_ids.shape[0])))
new_test_source_mask = graph_df.u.map(lambda x: x in new_test_node_set).values
new_test_destination_mask = graph_df.i.map(lambda x: x in new_test_node_set).values
# mask, which is true for edges with both destination and source not being new test nodes (because we want to remove all edges involving any new test node)
    observed_edges_mask = torch.from_numpy(np.logical_and(~new_test_source_mask, ~new_test_destination_mask))
train_mask = (train_mask & observed_edges_mask)
mask = torch.isin(unique_node_ids,train_node_set,invert = True)
new_node_set = unique_node_ids[mask]
edge_contains_new_node_mask = (torch.isin(src_node_ids,new_node_set) | torch.isin(dst_node_ids,new_node_set))
new_node_val_mask = (val_mask & edge_contains_new_node_mask)
new_node_test_mask = (test_mask & edge_contains_new_node_mask)
full_data = Data()
full_data.edge_index = torch.stack((src_node_ids,dst_node_ids))
sample_graph = {}
sample_src = torch.cat([src_node_ids.view(-1, 1), dst_node_ids.view(-1, 1)], dim=1)\
.reshape(1, -1)
sample_dst = torch.cat([dst_node_ids.view(-1, 1), src_node_ids.view(-1, 1)], dim=1)\
.reshape(1, -1)
sample_ts = torch.cat([node_interact_times.view(-1, 1), node_interact_times.view(-1, 1)], dim=1).reshape(-1)
sample_eid = torch.arange(full_data.edge_index.shape[1]).view(-1, 1).repeat(1, 2).reshape(-1)
sample_graph['edge_index'] = torch.cat([sample_src, sample_dst], dim=0)
sample_graph['ts'] = sample_ts
sample_graph['eids'] = sample_eid
sample_graph['train_mask'] = train_mask
sample_graph['val_mask'] = val_mask
    sample_graph['test_mask'] = test_mask
sample_graph['new_node_val_mask'] = new_node_val_mask
sample_graph['new_node_test_mask'] = new_node_test_mask
print(unique_node_ids.max().item(),unique_node_ids.shape[0])
full_data.num_nodes = int(unique_node_ids.max().item())+1
full_data.num_edges = node_interact_times.shape[0]
full_data.sample_graph = sample_graph
full_data.x = n_feat
full_data.edge_attr = e_feat
full_data.y = labels
full_data.edge_ts = node_interact_times
full_data.train_mask = train_mask
full_data.val_mask = val_mask
full_data.test_mask = test_mask
full_data.new_node_val_mask = new_node_val_mask
full_data.new_node_test_mask = new_node_test_mask
return full_data
#full_graph = DistributedGraphStore(full_data, device, uvm_node, uvm_edge)
#train_data = torch.masked_select(full_data.edge_index,train_mask.to(device)).reshape(2,-1)
#train_ts = torch.masked_select(full_data.edge_ts,train_mask.to(device))
#val_data = torch.masked_select(full_data.edge_index,val_mask.to(device)).reshape(2,-1)
#val_ts = torch.masked_select(full_data.edge_ts,val_mask.to(device))
#test_data = torch.masked_select(full_data.edge_index,test_mask.to(device)).reshape(2,-1)
#test_ts = torch.masked_select(full_data.edge_ts,test_mask.to(device))
##print(train_data.shape[1],val_data.shape[1],test_data.shape[1])
#train_data = DataSet(edges = train_data,ts =train_ts,eids = torch.nonzero(train_mask).view(-1))
#test_data = DataSet(edges = test_data,ts =test_ts,eids = torch.nonzero(test_mask).view(-1))
#val_data = DataSet(edges = val_data,ts = val_ts,eids = torch.nonzero(val_mask).view(-1))
#new_node_val_data = torch.masked_select(full_data.edge_index,new_node_val_mask.to(device)).reshape(2,-1)
#new_node_val_ts = torch.masked_select(full_data.edge_ts,new_node_val_mask.to(device))
#new_node_test_data = torch.masked_select(full_data.edge_index,new_node_test_mask.to(device)).reshape(2,-1)
#new_node_test_ts = torch.masked_select(full_data.edge_ts,new_node_test_mask.to(device))
#return full_data, train_data, val_data, test_data, new_node_val_data, new_node_test_data
import torch
from sklearn.metrics import average_precision_score, roc_auc_score
def get_link_prediction_metrics(predicts: torch.Tensor, labels: torch.Tensor):
"""
get metrics for the link prediction task
:param predicts: Tensor, shape (num_samples, )
:param labels: Tensor, shape (num_samples, )
:return:
dictionary of metrics {'metric_name_1': metric_1, ...}
"""
predicts = predicts.cpu().detach().numpy()
labels = labels.cpu().numpy()
average_precision = average_precision_score(y_true=labels, y_score=predicts)
roc_auc = roc_auc_score(y_true=labels, y_score=predicts)
return {'average_precision': average_precision, 'roc_auc': roc_auc}
def get_node_classification_metrics(predicts: torch.Tensor, labels: torch.Tensor):
"""
get metrics for the node classification task
:param predicts: Tensor, shape (num_samples, )
:param labels: Tensor, shape (num_samples, )
:return:
dictionary of metrics {'metric_name_1': metric_1, ...}
"""
predicts = predicts.cpu().detach().numpy()
labels = labels.cpu().numpy()
roc_auc = roc_auc_score(y_true=labels, y_score=predicts)
return {'roc_auc': roc_auc}
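# Tiny illustrative example (added): both helpers expect 1-D score/label tensors.
def _example_metrics():
    predicts = torch.tensor([0.9, 0.2, 0.7, 0.4])
    labels = torch.tensor([1.0, 0.0, 1.0, 0.0])
    return get_link_prediction_metrics(predicts, labels)
    # -> {'average_precision': 1.0, 'roc_auc': 1.0} for this perfectly separable toy case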
import torch
import dgl
from os.path import abspath, join, dirname
import sys
sys.path.insert(0, join(abspath(dirname(__file__))))
from layers import *
from memorys import *
class GeneralModel(torch.nn.Module):
def __init__(self, dim_node, dim_edge, sample_param, memory_param, gnn_param, train_param, combined=False):
super(GeneralModel, self).__init__()
self.dim_node = dim_node
self.dim_node_input = dim_node
self.dim_edge = dim_edge
self.sample_param = sample_param
self.memory_param = memory_param
        if 'dim_out' not in gnn_param:
gnn_param['dim_out'] = memory_param['dim_out']
self.gnn_param = gnn_param
self.train_param = train_param
if memory_param['type'] == 'node':
if memory_param['memory_update'] == 'gru':
self.memory_updater = GRUMemeoryUpdater(memory_param, 2 * memory_param['dim_out'] + dim_edge, memory_param['dim_out'], memory_param['dim_time'], dim_node)
elif memory_param['memory_update'] == 'rnn':
self.memory_updater = RNNMemeoryUpdater(memory_param, 2 * memory_param['dim_out'] + dim_edge, memory_param['dim_out'], memory_param['dim_time'], dim_node)
elif memory_param['memory_update'] == 'transformer':
self.memory_updater = TransformerMemoryUpdater(memory_param, 2 * memory_param['dim_out'] + dim_edge, memory_param['dim_out'], memory_param['dim_time'], train_param)
else:
raise NotImplementedError
self.dim_node_input = memory_param['dim_out']
self.layers = torch.nn.ModuleDict()
if gnn_param['arch'] == 'transformer_attention':
for h in range(sample_param['history']):
self.layers['l0h' + str(h)] = TransfomerAttentionLayer(self.dim_node_input, dim_edge, gnn_param['dim_time'], gnn_param['att_head'], train_param['dropout'], train_param['att_dropout'], gnn_param['dim_out'], combined=combined)
for l in range(1, gnn_param['layer']):
for h in range(sample_param['history']):
self.layers['l' + str(l) + 'h' + str(h)] = TransfomerAttentionLayer(gnn_param['dim_out'], dim_edge, gnn_param['dim_time'], gnn_param['att_head'], train_param['dropout'], train_param['att_dropout'], gnn_param['dim_out'], combined=False)
elif gnn_param['arch'] == 'identity':
self.gnn_param['layer'] = 1
for h in range(sample_param['history']):
self.layers['l0h' + str(h)] = IdentityNormLayer(self.dim_node_input)
if 'time_transform' in gnn_param and gnn_param['time_transform'] == 'JODIE':
self.layers['l0h' + str(h) + 't'] = JODIETimeEmbedding(gnn_param['dim_out'])
else:
raise NotImplementedError
self.edge_predictor = EdgePredictor(gnn_param['dim_out'])
if 'combine' in gnn_param and gnn_param['combine'] == 'rnn':
self.combiner = torch.nn.RNN(gnn_param['dim_out'], gnn_param['dim_out'])
def forward(self, mfgs, metadata = None,neg_samples=1):
if self.memory_param['type'] == 'node':
self.memory_updater(mfgs[0])
out = list()
for l in range(self.gnn_param['layer']):
for h in range(self.sample_param['history']):
rst = self.layers['l' + str(l) + 'h' + str(h)](mfgs[l][h])
if 'time_transform' in self.gnn_param and self.gnn_param['time_transform'] == 'JODIE':
rst = self.layers['l0h' + str(h) + 't'](rst, mfgs[l][h].srcdata['mem_ts'], mfgs[l][h].srcdata['ts'])
if l != self.gnn_param['layer'] - 1:
mfgs[l + 1][h].srcdata['h'] = rst
else:
out.append(rst)
if self.sample_param['history'] == 1:
out = out[0]
else:
out = torch.stack(out, dim=0)
out = self.combiner(out)[0][-1, :, :]
        # metadata: the ids need to be recorded during the earlier deduplication step
if self.gnn_param['use_src_emb'] or self.gnn_param['use_dst_emb']:
self.embedding = out.detach().clone()
else:
self.embedding = None
if metadata is not None:
#out = torch.cat((out[metadata['dst_pos_pos']],out[metadata['src_id_pos']],out[metadata['dst_neg_pos']]),0)
if self.gnn_param['dyrep']:
out = self.memory_updater.last_updated_memory
out = torch.cat((out[metadata['src_pos_index']],out[metadata['dst_pos_index']],out[metadata['src_neg_index']]),0)
return self.edge_predictor(out, neg_samples=neg_samples)
def get_emb(self, mfgs):
if self.memory_param['type'] == 'node':
self.memory_updater(mfgs[0])
out = list()
for l in range(self.gnn_param['layer']):
for h in range(self.sample_param['history']):
rst = self.layers['l' + str(l) + 'h' + str(h)](mfgs[l][h])
if 'time_transform' in self.gnn_param and self.gnn_param['time_transform'] == 'JODIE':
rst = self.layers['l0h' + str(h) + 't'](rst, mfgs[l][h].srcdata['mem_ts'], mfgs[l][h].srcdata['ts'])
if l != self.gnn_param['layer'] - 1:
mfgs[l + 1][h].srcdata['h'] = rst
else:
out.append(rst)
if self.sample_param['history'] == 1:
out = out[0]
else:
out = torch.stack(out, dim=0)
out = self.combiner(out)[0][-1, :, :]
return out
class NodeClassificationModel(torch.nn.Module):
def __init__(self, dim_in, dim_hid, num_class):
super(NodeClassificationModel, self).__init__()
self.fc1 = torch.nn.Linear(dim_in, dim_hid)
self.fc2 = torch.nn.Linear(dim_hid, num_class)
def forward(self, x):
x = self.fc1(x)
x = torch.nn.functional.relu(x)
x = self.fc2(x)
return x
import yaml
import numpy as np
def parse_config(f):
    with open(f, 'r') as file:
        conf = yaml.safe_load(file)
sample_param = conf['sampling'][0]
memory_param = conf['memory'][0]
gnn_param = conf['gnn'][0]
train_param = conf['train'][0]
return sample_param, memory_param, gnn_param, train_param
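# Hedged sketch (added): parse_config expects a YAML file whose four top-level keys each
# hold a list whose first entry is a dict of options. The field names below are
# assumptions inferred from how GeneralModel consumes the returned dicts.
#
#   sampling:
#     - history: 1
#   memory:
#     - type: node
#       memory_update: gru
#       dim_out: 100
#       dim_time: 100
#   gnn:
#     - arch: transformer_attention
#       layer: 1
#       att_head: 2
#       dim_time: 100
#       dim_out: 100
#   train:
#     - dropout: 0.1
#       att_dropout: 0.1
#
# sample_param, memory_param, gnn_param, train_param = parse_config("config.yml")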
class EarlyStopMonitor(object):
def __init__(self, max_round=3, higher_better=True, tolerance=1e-10):
self.max_round = max_round
self.num_round = 0
self.epoch_count = 0
self.best_epoch = 0
self.last_best = None
self.higher_better = higher_better
self.tolerance = tolerance
def early_stop_check(self, curr_val):
if not self.higher_better:
curr_val *= -1
if self.last_best is None:
self.last_best = curr_val
elif (curr_val - self.last_best) / np.abs(self.last_best) > self.tolerance:
self.last_best = curr_val
self.num_round = 0
self.best_epoch = self.epoch_count
else:
self.num_round += 1
self.epoch_count += 1
return self.num_round >= self.max_round
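# Hedged usage sketch (added): the monitor counts epochs whose relative improvement over
# the best value so far stays below `tolerance`, and signals a stop after `max_round` of them.
def _example_early_stopping(val_scores):
    monitor = EarlyStopMonitor(max_round=3, higher_better=True)
    for score in val_scores:
        if monitor.early_stop_check(score):
            break
    return monitor.best_epoch   # epoch index of the best validation score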
from .route import *
from .timeline import SequencePipe
from .layerpipe import LayerPipe, LayerDetach
from .sparse import *
import torch
import torch.nn as nn
import torch.autograd as autograd
from torch import Tensor
from typing import *
from abc import ABC, abstractmethod
from contextlib import contextmanager
from .route import Route, RouteWork
from .timeline.utils import vector_backward
from .utils import *
__all__ = [
"LayerPipe",
"LayerDetach",
]
class LayerPipe(ABC):
def __init__(self) -> None:
self._layer_id: Optional[int] = None
self._snapshot_id: Optional[int] = None
self._rts: List[LayerPipeRuntime] = []
@property
def layer_id(self) -> int:
assert self._layer_id is not None
return self._layer_id
@property
def snapshot_id(self) -> int:
assert self._snapshot_id is not None
return self._snapshot_id
def apply(self,
num_layers: int,
num_snapshots: int,
) -> Sequence[Sequence[Tensor]]:
runtime = LayerPipeRuntime(num_layers, num_snapshots, self)
self._rts.append(runtime)
return runtime.forward()
def backward(self):
for runtime in self._rts:
runtime.backward()
self._rts.clear()
def all_reduce(self, async_op: bool = False):
works = []
for _, net in self.get_model():
ws = all_reduce_gradients(net, async_op=async_op)
if async_op:
works.extend(ws)
ws = all_reduce_buffers(net, async_op=async_op)
if async_op:
works.extend(ws)
if async_op:
return ws
def to(self, device: Any):
for _, net in self.get_model():
net.to(device)
return self
def get_model(self) -> Sequence[Tuple[str, nn.Module]]:
models = []
for key in dir(self):
if key in {"layer_id", "snapshot_id"}:
continue
val = getattr(self, key)
if isinstance(val, nn.Module):
models.append((key, val))
return tuple(models)
def parameters(self):
params: List[nn.Parameter] = []
for name, m in self.get_model():
params.extend(m.parameters())
return params
def register_route(self, *xs: Tensor):
for t in xs:
t.requires_route = True
@abstractmethod
def get_route(self) -> Route:
raise NotImplementedError
@abstractmethod
def layer_inputs(self,
inputs: Optional[Sequence[Tensor]] = None,
) -> Sequence[Tensor]:
raise NotImplementedError
@abstractmethod
def layer_forward(self,
inputs: Sequence[Tensor],
) -> Sequence[Tensor]:
raise NotImplementedError
@contextmanager
def _switch_layer(self,
layer_id: int,
snapshot_id: int,
):
saved_layer_id = self._layer_id
saved_snapshot_id = self._snapshot_id
self._layer_id = layer_id
self._snapshot_id = snapshot_id
try:
yield
finally:
self._layer_id = saved_layer_id
self._snapshot_id = saved_snapshot_id
class LayerPipeRuntime:
def __init__(self,
num_layers: int,
num_snapshots: int,
program: LayerPipe,
) -> None:
self.num_layers = num_layers
self.num_snapshots = num_snapshots
self.program = program
self.ready_bw: Dict[Any, Union[LayerDetach, LayerRoute]] = {}
def forward(self) -> Sequence[Sequence[Tensor]]:
for op, layer_i, snap_i in ForwardFootprint(self.num_layers, self.num_snapshots):
if op == "sync":
xs = self.ready_bw[(layer_i - 1, snap_i, 1)].values() if layer_i > 0 else None
with self.program._switch_layer(layer_i, snap_i):
xs = self.program.layer_inputs(xs)
route = self.program.get_route()
self.ready_bw[(layer_i, snap_i, 0)] = LayerRoute(route, *xs)
elif op == "comp":
xs = self.ready_bw[(layer_i, snap_i, 0)].values()
with self.program._switch_layer(layer_i, snap_i):
xs = self.program.layer_forward(xs)
self.ready_bw[(layer_i, snap_i, 1)] = LayerDetach(*xs)
xs = []
for snap_i in range(self.num_snapshots):
layer_i = self.num_layers - 1
xs.append(self.ready_bw[(layer_i, snap_i, 1)].values())
return xs
def backward(self):
for op, layer_i, snap_i in BackwardFootprint(self.num_layers, self.num_snapshots):
if op == "sync":
self.ready_bw[(layer_i, snap_i, 0)].backward()
elif op == "comp":
if layer_i + 1 < self.num_layers:
self.ready_bw.pop((layer_i + 1, snap_i, 0)).wait_gradients()
self.ready_bw.pop((layer_i, snap_i, 1)).backward()
for snap_i in range(self.num_snapshots):
self.ready_bw.pop((0, snap_i, 0)).wait_gradients()
assert len(self.ready_bw) == 0
class LayerDetach:
def __init__(self,
*inputs: Tensor,
) -> None:
outputs = tuple(t.detach() for t in inputs)
for s, t in zip(inputs, outputs):
t.requires_grad_(s.requires_grad)
self._inputs = inputs
self._outputs = outputs
def values(self) -> Sequence[Tensor]:
return tuple(self._outputs)
def backward(self) -> None:
vec_loss, vec_grad = [], []
for s, t in zip(self._inputs, self._outputs):
g, t.grad = t.grad, None
if not s.requires_grad:
continue
vec_loss.append(s)
vec_grad.append(g)
vector_backward(vec_loss, vec_grad)
class LayerRoute:
def __init__(self,
route: Route,
*inputs: Tensor,
) -> None:
self._route = route
self._works: Optional[List[Union[Tensor, RouteWork]]] = []
for t in inputs:
r = t.requires_route if hasattr(t, "requires_route") else False
if r:
self._works.append(self._route.fw_tensor(t, async_op=True))
else:
self._works.append(t.detach())
self._inputs = inputs
self._outputs: Optional[List[Tensor]] = None
def values(self) -> Sequence[Tensor]:
if self._outputs is None:
works, self._works = self._works, None
assert works is not None
outputs = []
for s, t in zip(self._inputs, works):
if isinstance(t, RouteWork):
t = t.wait()
t = t.requires_grad_(s.requires_grad)
outputs.append(t)
self._outputs = outputs
return self._outputs
def backward(self):
assert self._works is None
assert self._outputs is not None
works = []
for s, t in zip(self._inputs, self._outputs):
g, t.grad = t.grad, None
rs = s.requires_route if hasattr(s, "requires_route") else False
rg = s.requires_grad
if rg and rs:
works.append(self._route.bw_tensor(g, async_op=True))
elif rg:
works.append(g)
else:
works.append(None)
self._works = works
self._outputs = None
def wait_gradients(self):
if self._works is None:
return
works, self._works = self._works, None
vec_loss, vec_grad = [], []
for t, g in zip(self._inputs, works):
if isinstance(g, RouteWork):
g = g.wait()
if not t.requires_grad:
continue
vec_loss.append(t)
vec_grad.append(g)
vector_backward(vec_loss, vec_grad)
class ForwardFootprint:
def __init__(self,
num_layers: int,
num_snapshots: int,
) -> None:
self._num_layers = num_layers
self._num_snapshots = num_snapshots
def __iter__(self):
if self._num_layers <= 0 or self._num_snapshots <= 0:
return
# starting
if self._num_snapshots > 1:
yield "sync", 0, 0
yield "sync", 0, 1
elif self._num_snapshots > 0:
yield "sync", 0, 0
for i in range(0, self._num_snapshots, 2):
for l in range(self._num_layers):
# snapshot i
yield "comp", l, i
if l + 1 < self._num_layers:
yield "sync", l + 1, i
elif i + 2 < self._num_snapshots:
yield "sync", 0, i + 2
# snapshot i + 1
if i + 1 >= self._num_snapshots:
continue
yield "comp", l, i + 1
if l + 1 < self._num_layers:
yield "sync", l + 1, i + 1
elif i + 3 < self._num_snapshots:
yield "sync", 0, i + 3
class BackwardFootprint:
def __init__(self,
num_layers: int,
num_snapshots: int,
) -> None:
self._num_layers = num_layers
self._num_snapshots = num_snapshots
def __iter__(self):
if self._num_layers <= 0 or self._num_snapshots <= 0:
return
for i in range(0, self._num_snapshots, 2):
for j in range(self._num_layers):
l = self._num_layers - j - 1
# snapshot i
yield "comp", l, i
yield "sync", l, i
# snapshot i + 1
if i + 1 >= self._num_snapshots:
continue
yield "comp", l, i + 1
yield "sync", l, i + 1
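# Illustrative sketch (added): the footprints only emit a schedule of (op, layer, snapshot)
# tuples; listing one makes the interleaved sync/compute order visible.
def _example_forward_schedule():
    return list(ForwardFootprint(num_layers=2, num_snapshots=2))
    # -> [('sync', 0, 0), ('sync', 0, 1),
    #     ('comp', 0, 0), ('sync', 1, 0), ('comp', 0, 1), ('sync', 1, 1),
    #     ('comp', 1, 0), ('comp', 1, 1)]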
from typing import Any
import torch
import torch.distributed as dist
import torch.autograd as autograd
from torch import Tensor
from typing import *
from torch_sparse import SparseTensor
__all__ = [
"SparseBlocks",
]
class SparseBlocks:
@staticmethod
def from_raw_indices(
dst_ids: Tensor,
edge_index: Tensor,
src_ids: Optional[Tensor] = None,
edge_attr: Optional[Tensor] = None,
group: Any = None,
) -> 'SparseBlocks':
assert edge_index.dim() == 2 and edge_index.size(0) == 2
if src_ids is None:
src_ids = dst_ids
src_ids, src_ptr = SparseBlocks.__fetch_ids_sizes(src_ids, group=group)
adj_ts = SparseBlocks.__remap_adj_t(dst_ids, edge_index, src_ids, src_ptr, edge_attr)
return SparseBlocks(adj_ts, group=group)
def __init__(self, adj_ts: List[SparseTensor], group: Any) -> None:
self._adj_ts = adj_ts
self._group = group
def adj_t(self, i: int) -> SparseTensor:
return self._adj_ts[i]
@property
def group(self):
return self._group
@property
def part_id(self) -> int:
return dist.get_rank(self._group)
@property
def num_parts(self) -> int:
return dist.get_world_size(self._group)
def apply(self, x: Tensor) -> Tensor:
return SparseBlockMM.apply(self, x)
@staticmethod
def __fetch_ids_sizes(local_ids: Tensor, group: Any):
assert local_ids.dim() == 1
if group is None:
group = dist.GroupMember.WORLD
rank = dist.get_rank(group)
world_size = dist.get_world_size(group)
ikw = dict(dtype=torch.long, device=local_ids.device)
all_lens: List[int] = [None] * world_size
dist.all_gather_object(all_lens, local_ids.numel(), group=group)
# all reduce num_nodes
num_nodes = torch.zeros(1, **ikw)
if local_ids.numel() > 0:
num_nodes = local_ids.max().max(num_nodes)
dist.all_reduce(num_nodes, op=dist.ReduceOp.MAX, group=group)
num_nodes: int = num_nodes.item() + 1
# async fetch remote ids
all_ids: List[Tensor] = [None] * world_size
all_get = [None] * world_size
def async_fetch(i: int):
if i == rank:
all_ids[i] = local_ids
else:
all_ids[i] = torch.empty(all_lens[i], **ikw)
src = dist.get_global_rank(group, i)
all_get[i] = dist.broadcast(
all_ids[i], src=src, async_op=True, group=group
)
imp: Tensor = torch.full((num_nodes,), (2**62-1)*2+1, **ikw)
offset: int = 0
for i in range(world_size):
if i == 0:
async_fetch(i)
if i + 1 < world_size:
async_fetch(i + 1)
all_get[i].wait()
ids = all_ids[i]
            assert (imp[ids] == (2**62-1)*2+1).all(), "ids from different partitions must be disjoint."
imp[ids] = torch.arange(offset, offset + all_lens[i], **ikw)
offset += all_lens[i]
        assert (imp != (2**62-1)*2+1).all(), "some node ids are not covered by any partition."
ids = torch.cat(all_ids, dim=0)
ptr: List[int] = [0]
for s in all_lens:
ptr.append(ptr[-1] + s)
return ids, ptr
@staticmethod
def __remap_adj_t(
dst_ids: Tensor,
edge_index: Tensor,
src_ids: Tensor,
src_ptr: List[int],
edge_attr: Optional[Tensor],
) -> List[SparseTensor]:
ikw = dict(dtype=torch.long, device=dst_ids.device)
imp: Tensor = torch.full((dst_ids.max().item()+1,), (2**62-1)*2+1, **ikw)
imp[dst_ids] = torch.arange(dst_ids.numel(), **ikw)
dst = imp[edge_index[1]]
assert (dst != (2**62-1)*2+1).all()
imp: Tensor = torch.full((src_ids.max().item()+1,), (2**62-1)*2+1, **ikw)
imp[src_ids] = torch.arange(src_ids.numel(), **ikw)
src = imp[edge_index[0]]
assert (src != (2**62-1)*2+1).all()
edge_index = torch.vstack([src, dst])
adj = SparseTensor.from_edge_index(
edge_index=edge_index,
edge_attr=edge_attr,
sparse_sizes=(src_ids.numel(), dst_ids.numel()),
)
adj_ts: List[SparseTensor] = []
for s, t in zip(src_ptr, src_ptr[1:]):
adj_ts.append(adj[s:t].t())
return adj_ts
class SparseBlockMM(autograd.Function):
@staticmethod
def forward(
ctx: autograd.function.FunctionCtx,
sp: SparseBlocks,
x: Tensor,
):
part_id = sp.part_id
num_parts = sp.num_parts
group = sp.group
if group is None:
group = dist.GroupMember.WORLD
def async_fetch(i: int):
n = sp.adj_t(i).sparse_size(1)
if i == part_id:
h = x.clone()
else:
h = torch.empty(n, *x.shape[1:], dtype=x.dtype, device=x.device)
src = dist.get_global_rank(group, i)
return dist.broadcast(h, src=src, group=sp.group, async_op=True)
last_work = None
out = None
for i in range(num_parts):
if i == 0:
work = async_fetch(0)
else:
work = last_work
if i + 1 < sp.num_parts:
last_work = async_fetch(i + 1)
work.wait()
h, = work.result()
if out is None:
out = sp.adj_t(i) @ h
else:
out += sp.adj_t(i) @ h
ctx.saved_sp = sp
return out
@staticmethod
def backward(
ctx: autograd.function.FunctionCtx,
grad: Tensor,
):
sp: SparseBlocks = ctx.saved_sp
part_id = sp.part_id
num_parts = sp.num_parts
group = sp.group
if group is None:
group = dist.GroupMember.WORLD
def async_reduce(i: int, g: Tensor):
dst = dist.get_global_rank(group, i)
return dist.reduce(
g, dst=dst, op=dist.ReduceOp.SUM,
group=sp.group, async_op=True,
)
out = None
last_work = None
for i in range(num_parts):
g = sp.adj_t(i).t() @ grad
if i > 0:
last_work.wait()
last_work = async_reduce(i, g)
if i == part_id:
out = g
if last_work is not None:
last_work.wait()
return None, out
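# Hedged usage sketch (added): building the distributed adjacency blocks from the locally
# owned dst_ids/edge_index and applying one distributed SpMM. The call is collective over
# the (default) group; the row count of x must match this rank's src ids.
def _example_sparse_blocks(dst_ids: Tensor, edge_index: Tensor, x: Tensor) -> Tensor:
    sp = SparseBlocks.from_raw_indices(dst_ids, edge_index)   # src_ids defaults to dst_ids
    return sp.apply(x)                                        # differentiable w.r.t. x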
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.autograd as autograd
import torch.distributed as dist
from torch import Tensor
from typing import *
__all__ = [
"SyncBatchNorm",
]
class SyncBatchNorm(nn.Module):
def __init__(self,
num_features: Union[int, torch.Size],
eps: float = 1e-5,
momentum: float = 0.1,
) -> None:
super().__init__()
self.register_buffer("running_mean", torch.zeros(num_features))
self.register_buffer("running_var", torch.ones(num_features))
self.weight = nn.Parameter(torch.ones(num_features))
self.bias = nn.Parameter(torch.zeros(num_features))
self.num_features = num_features
self.eps = eps
self.momentum = momentum
def forward(self, x: Tensor) -> Tensor:
return SyncBatchNormFunction.apply(
x,
self.running_mean, self.running_var,
self.weight, self.bias,
self.training, self.momentum, self.eps,
)
def reset_parameters(self):
self.running_mean.data.fill_(0.0)
self.running_var.data.fill_(1.0)
self.weight.data.fill_(1.0)
self.bias.data.fill_(0.0)
@classmethod
def convert_sync_batchnorm(cls, net: nn.Module) -> nn.Module:
if isinstance(net, nn.modules.batchnorm._BatchNorm):
new_net = SyncBatchNorm(
num_features=net.num_features,
eps=net.eps,
momentum=net.momentum,
).to(net.weight.device)
new_net.weight.data[:] = net.weight.data
new_net.bias.data[:] = net.bias.data
new_net.get_buffer("running_mean").data[:] = net.get_buffer("running_mean").data
new_net.get_buffer("running_var").data[:] = net.get_buffer("running_var").data
return new_net
else:
for name, child in list(net.named_children()):
net.add_module(name, cls.convert_sync_batchnorm(child))
return net
class SyncBatchNormFunction(autograd.Function):
@staticmethod
def forward(
ctx: autograd.function.FunctionCtx,
x: Tensor,
running_mean: Tensor,
running_var: Tensor,
weight: Tensor,
bias: Tensor,
training: bool,
momentum: float,
eps: float,
):
if not training:
mean = running_mean
var = running_var
else:
ws = torch.zeros(1, dtype=x.dtype, device=x.device) + x.size(0)
ws_req = dist.all_reduce(ws, op=dist.ReduceOp.SUM, async_op=True)
if x.size(0) > 0:
sum_x = x.sum(dim=0)
else:
sum_x = torch.zeros(
                    size=x.shape[1:],
dtype=x.dtype,
device=x.device,
)
sum_x_req = dist.all_reduce(sum_x, op=dist.ReduceOp.SUM, async_op=True)
if x.size(0) > 0:
sum_x2 = (x**2).sum(dim=0)
else:
sum_x2 = torch.zeros(
                    size=x.shape[1:],
dtype=x.dtype,
device=x.device,
)
sum_x2_req = dist.all_reduce(sum_x2, op=dist.ReduceOp.SUM, async_op=True)
ws_req.wait()
whole_size = ws.item()
sum_x_req.wait()
mean = sum_x / whole_size
sum_x2_req.wait()
# var = sum_x2 / whole_size - mean ** 2
var = (sum_x2 - mean * sum_x) / whole_size
running_mean.mul_(1 - momentum).add_(mean * momentum)
running_var.mul_(1 - momentum).add_(var * momentum)
std = torch.sqrt(var + eps)
x_hat = (x - mean) / std
if training:
ctx.save_for_backward(x_hat, weight, std)
ctx.whole_size = whole_size
return x_hat * weight + bias
@staticmethod
def backward(
ctx: autograd.function.FunctionCtx,
grad: Tensor,
):
x_hat, weight, std = ctx.saved_tensors
dbias = grad.sum(dim=0)
dbias_req = dist.all_reduce(dbias, op=dist.ReduceOp.SUM, async_op=True)
dweight = (grad * x_hat).sum(dim=0)
dweight_req = dist.all_reduce(dweight, op=dist.ReduceOp.SUM, async_op=True)
dbias_req.wait()
dweight_req.wait()
n = ctx.whole_size
dx = (weight / n) / std * (n * grad - dbias - x_hat * dweight)
return dx, None, None, dweight, dbias, None, None, None
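# Hedged usage sketch (added): after init_process_group, the class method swaps every
# torch BatchNorm layer of an existing model for this all-reduce based SyncBatchNorm.
def _example_convert_sync_bn() -> nn.Module:
    model = nn.Sequential(nn.Linear(16, 32), nn.BatchNorm1d(32), nn.ReLU())
    return SyncBatchNorm.convert_sync_batchnorm(model)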
from .pipe import SequencePipe
import torch
from torch import Tensor
from typing import *
def vector_backward(
vec_loss: Sequence[Tensor],
vec_grad: Sequence[Tensor],
):
loss: List[Tensor] = []
grad: List[Tensor] = []
for x, g in zip(vec_loss, vec_grad):
if g is None:
continue
if not x.requires_grad:
continue
loss.append(x.flatten())
grad.append(g.flatten())
if loss:
loss = torch.cat(loss, dim=0)
grad = torch.cat(grad, dim=0)
loss.backward(grad)
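# Illustrative sketch (added): vector_backward flattens and concatenates the surviving
# (output, grad) pairs so a single autograd call back-propagates all of them at once.
def _example_vector_backward():
    a = torch.ones(3, requires_grad=True)
    b = torch.ones(2, requires_grad=True)
    vector_backward([a * 2, b * 3], [torch.ones(3), torch.ones(2)])
    return a.grad, b.grad   # -> tensor([2., 2., 2.]), tensor([3., 3.])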
import torch
import torch.nn as nn
import torch.distributed as dist
from torch import Tensor
from typing import *
from collections import defaultdict
__all__ = [
"all_reduce_gradients",
"all_reduce_buffers",
]
# def all_reduce_gradients(net: nn.Module, op = dist.ReduceOp.SUM, group = None, async_op: bool = False):
# works = []
# for p in net.parameters():
# if p.grad is None:
# p.grad = torch.zeros_like(p.data)
# w = dist.all_reduce(p.grad, op=op, group=group, async_op=async_op)
# works.append(w)
# if async_op:
# return works
# def all_reduce_buffers(net: nn.Module, op = dist.ReduceOp.AVG, group = None, async_op: bool = False):
# works = []
# for b in net.buffers():
# w = dist.all_reduce(b.data, op=op, group=group, async_op=async_op)
# works.append(w)
# if async_op:
# return works
def all_reduce_gradients(net: nn.Module, op = dist.ReduceOp.SUM, group = None, async_op: bool = False):
device = None
works = []
if op is None:
return works
typed_numel = defaultdict(lambda: 0)
for p in net.parameters():
typed_numel[p.dtype] += p.numel()
device = p.device
if device is None:
return works
typed_tensors: Dict[torch.dtype, Tensor] = {}
for t, n in typed_numel.items():
typed_tensors[t] = torch.zeros(n, dtype=t, device=device)
typed_offset = defaultdict(lambda: 0)
for p in net.parameters():
s = typed_offset[p.dtype]
t = s + p.numel()
typed_offset[p.dtype] = t
if p.grad is not None:
typed_tensors[p.dtype][s:t] = p.grad.flatten()
storage = typed_tensors[p.dtype].untyped_storage()
g = torch.empty(0, dtype=p.dtype, device=device)
p.grad = g.set_(storage, s, p.size(), default_stride(*p.size()))
for t in typed_tensors.values():
w = dist.all_reduce(t, op=op, group=group, async_op=async_op)
if async_op:
works.append(w)
return works
def all_reduce_buffers(net: nn.Module, op = dist.ReduceOp.AVG, group = None, async_op: bool = False):
device = None
works = []
if op is None:
return works
typed_numel = defaultdict(lambda: 0)
for p in net.buffers():
typed_numel[p.dtype] += p.numel()
device = p.device
if device is None:
return works
typed_tensors: Dict[torch.dtype, Tensor] = {}
for t, n in typed_numel.items():
        typed_tensors[t] = torch.zeros(n, dtype=t, device=device)
typed_offset = defaultdict(lambda: 0)
for p in net.buffers():
s = typed_offset[p.dtype]
t = s + p.numel()
typed_offset[p.dtype] = t
typed_tensors[p.dtype][s:t] = p.flatten()
storage = typed_tensors[p.dtype].untyped_storage()
p.set_(storage, s, p.size(), default_stride(*p.size()))
for t in typed_tensors.values():
w = dist.all_reduce(t, op=op, group=group, async_op=async_op)
if async_op:
works.append(w)
return works
def default_stride(*size: int) -> Tuple[int,...]:
dims = len(size)
stride = [1] * dims
for i in range(1, dims):
k = dims - i
stride[k - 1] = stride[k] * size[k]
return tuple(stride)
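# Tiny illustrative example (added): default_stride reproduces the row-major strides a
# freshly allocated contiguous tensor of the given size would have, which is what the
# flattened all-reduce buffers above rely on when re-viewing parameters and buffers.
def _example_default_stride():
    assert default_stride(2, 3, 4) == (12, 4, 1)
    assert default_stride(5) == (1,)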
from typing import List, Tuple
import torch
import torch.distributed as dist
from starrygl.distributed.utils import DistributedTensor
from starrygl.module.memorys import MailBox
from starrygl.sample.cache.fetch_cache import FetchFeatureCache
from starrygl.sample.graph_core import DataSet
from starrygl.sample.graph_core import DistributedGraphStore
from starrygl.sample.sample_core.base import BaseSampler, NegativeSampling
import dgl
from starrygl.sample.stream_manager import PipelineManager, getPipelineManger
"""
The input arguments stay the same; the outputs become:
sample_from_nodes
    node: list[tensor, tensor, tensor, ...]
    eid: list[tensor, tensor, tensor, ...]
    src_index: list[tensor, tensor, tensor, ...]
sample_from_edges:
    node
    eid: list[tensor, tensor, tensor, ...]
    src_index: list[tensor, tensor, tensor, ...]
    delta_ts: list[tensor, tensor, tensor, ...]
    metadata
"""
def prepare_input(node_feat, edge_feat, mem_embedding,mfgs,dist_nid,dist_eid):
for i,mfg in enumerate(mfgs):
for b in mfg:
e_idx = b.edata['ID']
idx = b.srcdata['ID']
b.edata['ID'] = dist_eid[e_idx]
b.srcdata['ID'] = dist_nid[idx]
#print(b.edata['ID'],b.edata['dt'],b.srcdata['ID'])
if edge_feat is not None:
b.edata['f'] = edge_feat[e_idx]
if i == 0:
if node_feat is not None:
b.srcdata['h'] = node_feat[idx]
if mem_embedding is not None:
node_memory,node_memory_ts,mailbox,mailbox_ts = mem_embedding
b.srcdata['mem'] = node_memory[idx]
b.srcdata['mem_ts'] = node_memory_ts[idx]
b.srcdata['mem_input'] = mailbox[idx].reshape(b.srcdata['ID'].shape[0], -1)
b.srcdata['mail_ts'] = mailbox_ts[idx]
#print(idx.shape[0],b.srcdata['mem_ts'].shape)
return mfgs
def to_block(graph: DistributedGraphStore, data, sample_out, mailbox:MailBox = None,device = torch.device('cuda'),group = None):
if len(sample_out) > 1:
sample_out,metadata = sample_out
else:
metadata = None
eid = [ret.eid() for ret in sample_out]
eid_len = [e.shape[0] for e in eid ]
#print(len(sample_out),eid,eid_len)
eid_mapper: torch.Tensor = graph.eids_mapper
nid_mapper: torch.Tensor = graph.nids_mapper
eid_tensor = torch.cat(eid,dim = 0).to(eid_mapper.device)
#print(eid_tensor)
dist_eid = eid_mapper[eid_tensor>>1].to(device)
dist_eid,eid_inv = dist_eid.unique(return_inverse=True)
src_node = graph.sample_graph['edge_index'][0,eid_tensor].to(graph.nids_mapper.device)
#print(src_node,graph.sample_graph['edge_index'][1,eid_tensor])
src_ts = None
if metadata is None:
root_node = data.nodes.to(graph.nids_mapper.device)
root_len = root_node.shape[0]
if hasattr(data,'ts'):
src_ts = torch.cat([data.ts,
graph.sample_graph['ts'][eid_tensor].to(device)])
elif 'seed' in metadata:
root_node = metadata.pop('seed').to(graph.nids_mapper.device)
root_len = root_node.shape[0]
if 'seed_ts' in metadata:
src_ts = torch.cat([metadata.pop('seed_ts').to(device),\
graph.sample_graph['ts'][eid_tensor].to(device)])
for k in metadata:
metadata[k] = metadata[k].to(device)
#print(src_ts,root_node)
nid_tensor = torch.cat([root_node,src_node],dim = 0)
#sprint(nid_tensor)
dist_nid = nid_mapper[nid_tensor].to(device)
dist_nid,nid_inv = dist_nid.unique(return_inverse = True)
fetchCache = FetchFeatureCache.getFetchCache()
if fetchCache is None:
if isinstance(graph.edge_attr,DistributedTensor):
ind_dict = graph.edge_attr.all_to_all_ind2ptr(dist_eid,group = group)
edge_feat = graph.edge_attr.all_to_all_get(group = group,**ind_dict)
else:
edge_feat = graph._get_edge_attr(dist_eid)
ind_dict = None
if isinstance(graph.x,DistributedTensor):
ind_dict = graph.x.all_to_all_ind2ptr(dist_nid,group = group)
node_feat = graph.x.all_to_all_get(group = group,**ind_dict)
else:
node_feat = graph._get_node_attr(dist_nid)
if mailbox is not None:
if torch.distributed.get_world_size() > 1:
if node_feat is None:
ind_dict = mailbox.node_memory.all_to_all_ind2ptr(dist_nid,group = group)
mem = mailbox.gather_memory(**ind_dict)
else:
mem = mailbox.get_memory(dist_nid)
else:
mem = None
else:
raw_nid = torch.empty_like(dist_nid)
raw_eid = torch.empty_like(dist_eid)
nid_tensor = nid_tensor.to(device)
eid_tensor = eid_tensor.to(device)
raw_nid[nid_inv] = nid_tensor
raw_eid[eid_inv] = (eid_tensor>>1)
node_feat,edge_feat,mem = fetchCache.fetch_feature(raw_nid,
dist_nid,raw_eid,
dist_eid)
def build_block():
mfgs = list()
col = torch.arange(0,root_len,device = device)
col_len = 0
row_len = root_len
for r in range(len(eid_len)):
elen = eid_len[r]
row = torch.arange(row_len,row_len+elen,device = device)
#print(row,col[sample_out[r].src_index()])
b = dgl.create_block((row,col[sample_out[r].src_index().to(device)]),
num_src_nodes = row_len + elen,
num_dst_nodes = row_len,
device = device)
idx = nid_inv[0:row_len + elen]
e_idx = eid_inv[col_len:col_len+elen]
#print(idx,e_idx)
b.srcdata['ID'] = idx
if sample_out[r].delta_ts().shape[0] > 0:
b.edata['dt'] = sample_out[r].delta_ts().to(device)
if src_ts is not None:
b.srcdata['ts'] = src_ts[0:row_len + eid_len[r]]
b.edata['ID'] = e_idx
#print(b.all_edges)
#print(dist_nid[b.srcdata['ID']],dist_nid[b.srcdata['ID'][col[sample_out[r].src_index().to(device)]]])
#print(b.edata['dt'],b.srcdata['ts'])
col = row
col_len += eid_len[r]
row_len += eid_len[r]
mfgs.append(b)
mfgs = list(map(list, zip(*[iter(mfgs)])))
mfgs.reverse()
return data,mfgs,metadata
data,mfgs,metadata = build_block()
mfgs = prepare_input(node_feat,edge_feat,mem,mfgs,dist_nid,dist_eid)
#return build_block(node_feat,edge_feat,mem)#data,mfgs,metadata
return (data,mfgs,metadata)
def graph_sample(graph, sampler:BaseSampler,
sample_fn, data,
neg_sampling = None,
mailbox = None,
device = torch.device('cuda'),
async_op = False):
out = sample_fn(sampler,data,neg_sampling)
if async_op == False:
return to_block(graph,data,out,mailbox,device)
else:
manager = getPipelineManger()
future = manager.submit('lookup',to_block,{'graph':graph,'data':data,\
'sample_out':out,\
'mailbox':mailbox,\
'device':device})
return future
def sample_from_nodes(sampler:BaseSampler, data:DataSet, **kwargs):
out = sampler.sample_from_nodes(nodes=data.nodes.reshape(-1))
#out.metadata = None
return out
def sample_from_edges(sampler:BaseSampler,
data:DataSet,
neg_sampling:NegativeSampling = None):
edge_label = data.labels if hasattr(data,'labels') else None
out = sampler.sample_from_edges(edges = data.edges,
neg_sampling=neg_sampling)
return out
def sample_from_temporal_nodes(sampler:BaseSampler,data:DataSet,
**kwargs):
out = sampler.sample_from_nodes(nodes=data.nodes.reshape(-1),
ts = data.ts.reshape(-1))
#out.metadata = None
return out
def sample_from_temporal_edges(sampler:BaseSampler, data:DataSet,
neg_sampling: NegativeSampling = None):
edge_label = data.labels if hasattr(data,'labels') else None
out = sampler.sample_from_edges(edges=data.edges.to('cpu'),
ets=data.ts.to('cpu'),
neg_sampling = neg_sampling
)
return out
class SAMPLE_TYPE:
SAMPLE_FROM_NODES = sample_from_nodes
SAMPLE_FROM_EDGES = sample_from_edges
SAMPLE_FROM_TEMPORAL_NODES = sample_from_temporal_nodes
SAMPLE_FROM_TEMPORAL_EDGES = sample_from_temporal_edges
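# A minimal usage sketch (not run here; `graph`, `sampler`, `batch` and `neg_sampler` are assumed
# to be constructed as in the DistributedDataLoader docstring further below):
# data, mfgs, metadata = graph_sample(graph, sampler,
#                                     SAMPLE_TYPE.SAMPLE_FROM_TEMPORAL_EDGES,
#                                     batch, neg_sampler,
#                                     mailbox=None, device=torch.device('cuda'))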
\ No newline at end of file
from typing import Optional, Sequence, Union
import torch
from starrygl.distributed.utils import DistributedTensor
from starrygl.sample.cache.cache import Cache
class LRUCache(Cache):
"""
Least-recently-used (LRU) cache
"""
def __init__(self, cache_ratio: float,
num_cache:int,
cache_data: Sequence[DistributedTensor],
use_local:bool = False,
pinned_buffers_shape: Sequence[torch.Size] = None,
is_update_cache = False
):
super(LRUCache, self).__init__(cache_ratio,num_cache,
cache_data,use_local,
pinned_buffers_shape,
is_update_cache)
self.name = 'lru'
self.now_cache_count = 0
self.cache_count = torch.zeros(
self.capacity, dtype=torch.int32, device=torch.device('cuda'))
self.is_update_cache = True
def update_cache(self, cached_index: torch.Tensor,
uncached_index: torch.Tensor,
uncached_feature: Sequence[torch.Tensor]):
if len(uncached_index) > self.capacity:
num_to_cache = self.capacity
else:
num_to_cache = len(uncached_index)
node_id_to_cache = uncached_index[:num_to_cache].to(torch.int32)
self.now_cache_count -= 1
self.cache_count[cached_index] = 0
# select the k cache slots with the lowest access count (i.e. least recently used)
removing_cache_index = torch.topk(
self.cache_count, k=num_to_cache, largest=False).indices.to(torch.int32)
removing_node_id = self.cache_index_to_id[removing_cache_index]
# update cache attributes
for buffer,data in zip(self.buffers,uncached_feature):
buffer[removing_cache_index] = data[:num_to_cache].reshape(-1,*buffer.shape[1:])
self.cache_count[removing_cache_index] = 0
self.cache_validate[removing_node_id] = False
self.cache_validate[node_id_to_cache] = True
self.cache_map[removing_node_id] = -1
self.cache_map[node_id_to_cache] = removing_cache_index
self.cache_index_to_id[removing_cache_index] = node_id_to_cache
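# A minimal usage sketch (assumptions: CUDA is available, the process group is initialized and
# `feat` is a DistributedTensor of node features; not run here):
# cache = LRUCache(cache_ratio=0.1, num_cache=num_nodes, cache_data=[feat])
# cache.init_cache(hot_node_ids, [hot_node_feat])
# update_cache() itself is invoked from Cache.fetch_data() whenever uncached ids are requested.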
from typing import Callable, List, Optional, Sequence, Union
import numpy as np
import torch
from starrygl.distributed.utils import DistributedTensor
class Cache:
def __init__(self, cache_ratio: float,
num_cache:int,
cache_data: Sequence[DistributedTensor],
use_local:bool = False,
pinned_buffers_shape: Sequence[torch.Size] = None,
is_update_cache = False
):
print(len(cache_data),cache_data)
assert torch.cuda.is_available() == True
self.use_local = use_local
self.is_update_cache = is_update_cache
self.device = torch.device('cuda')
self.use_remote = torch.distributed.get_world_size()>1
assert not (self.use_local is False and self.use_remote is False),\
"the data is already local on CUDA and there is no remote partition, so no cache is needed"
self.cache_ratio = cache_ratio
self.num_cache = num_cache
self.capacity = int(self.num_cache * cache_ratio)
self.update_stream = torch.cuda.Stream()
self.buffers = []
self.pinned_buffers = []
for data in cache_data:
self.buffers.append(
torch.zeros(self.capacity,*data.shape[1:],
dtype = data.dtype,device = torch.device('cuda'))
)
self.cache_validate = torch.zeros(
num_cache, dtype=torch.bool, device=self.device)
# maps node id -> index
self.cache_map = torch.zeros(
num_cache, dtype=torch.int32, device=self.device) - 1
# maps index -> node id
self.cache_index_to_id = torch.zeros(
num_cache,dtype=torch.int32, device=self.device) -1
self.hit_sum = 0
self.hit_ = 0
def init_cache(self,ind:torch.Tensor,data:Sequence[torch.Tensor]):
pos = torch.arange(ind.shape[0],device = 'cuda',dtype = ind.dtype)
self.cache_map[ind] = pos.to(torch.int32).to('cuda')
self.cache_index_to_id[pos] = ind.to(torch.int32).to('cuda')
for data,buffer in zip(data,self.buffers):
buffer[:ind.shape[0],] = data
self.cache_validate[ind] = True
def update_cache(self, cached_index: torch.Tensor,
uncached_index: torch.Tensor,
uncached_data: Sequence[torch.Tensor]):
raise NotImplementedError
def fetch_data(self,ind:Optional[torch.Tensor] = None,
uncached_source_fn: Callable = None, source_index:torch.Tensor = None):
self.hit_sum += ind.shape[0]
assert isinstance(ind, torch.Tensor)
cache_mask = self.cache_validate[ind]
uncached_mask = ~cache_mask
self.hit_ += torch.sum(cache_mask)
cached_data = []
cached_index = self.cache_map[ind[cache_mask]]
if uncached_mask.sum() > 0:
uncached_id = ind[uncached_mask]
source_index = source_index[uncached_mask]
uncached_feature = uncached_source_fn(source_index)
if isinstance(uncached_feature,torch.Tensor):
uncached_feature = [uncached_feature]
else:
uncached_id = None
uncached_feature = [None for _ in range(len(self.buffers))]
for data,uncached_data in zip(self.buffers,uncached_feature):
nfeature = torch.zeros(
len(ind), *data.shape[1:], dtype=data.dtype,device=self.device)
nfeature[cache_mask,:] = data[cached_index]
if uncached_id is not None:
nfeature[uncached_mask] = uncached_data.reshape(-1,*data.shape[1:])
cached_data.append(nfeature)
if self.is_update_cache and uncached_mask.sum() > 0:
self.update_cache(cached_index=cached_index,
uncached_index=uncached_id,
uncached_feature=uncached_feature)
return nfeature
def invalidate(self,ind):
self.cache_validate[ind] = False
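# A minimal fetch_data sketch (hypothetical `remote_fetch`: any callable that returns the features
# of the requested source indices; not run here):
# feats = cache.fetch_data(ind=local_ids,
#                          uncached_source_fn=remote_fetch,
#                          source_index=dist_ids)
# hit_rate = cache.hit_ / max(cache.hit_sum, 1)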
from typing import Optional, Sequence, Union
import torch
from starrygl.distributed.utils import DistributedTensor
from starrygl.sample.cache.cache import Cache
class StaticCache(Cache):
def __init__(self, cache_ratio: float,
num_cache:int,
cache_data: Sequence[DistributedTensor],
use_local:bool = False,
pinned_buffers_shape: Sequence[torch.Size] = None,
is_update_cache = False
):
super(StaticCache, self).__init__(cache_ratio,num_cache,
cache_data,use_local,
pinned_buffers_shape,
is_update_cache)
self.name = 'static'
self.now_cache_count = 0
self.cache_count = torch.zeros(
self.capacity, dtype=torch.int32, device=torch.device('cuda'))
self.is_update_cache = False
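# Note: with is_update_cache = False the static cache is filled once via init_cache()
# (for example with the hottest ids found by pre_sample() below) and never replaced afterwards.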
import torch
# do not add an import of the dataloader here
def pre_sample(dataloader, num_epoch:int,node_size:int,edge_size:int):
nodes_counts = torch.zeros(dataloader.graph.num_nodes,dtype = torch.long)
edges_counts = torch.zeros(dataloader.graph.num_edges,dtype = torch.long)
print(nodes_counts.shape,edges_counts.shape)
sampler = dataloader.sampler
neg_sampling = dataloader.neg_sampler
sample_fn = dataloader.sampler_fn
graph = dataloader.graph
for _ in range(num_epoch):
dataloader.__iter__()
while dataloader.recv_idxs < dataloader.expected_idx:
dataloader.recv_idxs += 1
data = dataloader._next_data()
out = sample_fn(sampler,data,neg_sampling)
if(len(out)>0):
sample_out,metadata = out
else:
sample_out = out
eid = [ret.eid() for ret in sample_out]
eid_tensor = torch.cat(eid,dim = 0)
src_node = graph.sample_graph['edge_index'][0,eid_tensor*2].to(graph.nids_mapper.device)
dst_node = graph.sample_graph['edge_index'][1,eid_tensor*2].to(graph.nids_mapper.device)
eid_tensor = torch.unique(eid_tensor)
nid_tensor = torch.unique(torch.cat((src_node,dst_node)))
edges_counts[eid_tensor] += 1
nodes_counts[nid_tensor] += 1
return nodes_counts,edges_counts
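# A minimal usage sketch (assumed workflow, not run here): the access counts can be used to pick
# the hottest ids before building a cache, e.g.
# node_counts, edge_counts = pre_sample(trainloader, num_epoch=1, node_size=0, edge_size=0)
# hot_nodes = node_counts.topk(k=int(0.1 * node_counts.numel())).indices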
from collections import deque
from enum import Enum
import queue
import torch
import sys
from os.path import abspath, join, dirname
import numpy as np
from starrygl.sample.batch_data import graph_sample
from starrygl.sample.sample_core.PreNegSampling import PreNegativeSampling
sys.path.insert(0, join(abspath(dirname(__file__))))
from typing import Deque, Optional
import torch.distributed as dist
from torch_geometric.data import Data
import os.path as osp
import math
class DistributedDataLoader:
'''
We will perform feature fetching in the data loader.
You can simply define a data loader for use, and starrygl will assist in fetching node or edge features:
Args:
graph: distributed graph store
data: the graph data
sampler: a parallel sampler like `NeighborSampler` above
sampler_fn: sample type
neg_sampler: negative sampler
batch_size: batch size
mailbox: APAN's mailbox and TGN's memory implemented by starrygl
Examples:
.. code-block:: python
import torch
from starrygl.sample.data_loader import DistributedDataLoader
from starrygl.sample.part_utils.partition_tgnn import partition_load
from starrygl.sample.graph_core import DataSet, DistributedGraphStore, TemporalNeighborSampleGraph
from starrygl.sample.memory.shared_mailbox import SharedMailBox
from starrygl.sample.sample_core.neighbor_sampler import NeighborSampler
from starrygl.sample.sample_core.base import NegativeSampling
from starrygl.sample.batch_data import SAMPLE_TYPE
pdata = partition_load("PATH/{}".format(dataname), algo="metis_for_tgnn")
graph = DistributedGraphStore(pdata = pdata, uvm_edge = False, uvm_node = False)
sample_graph = TemporalNeighborSampleGraph(sample_graph = pdata.sample_graph,mode = 'full')
mailbox = SharedMailBox(pdata.ids.shape[0], memory_param, dim_edge_feat=pdata.edge_attr.shape[1] if pdata.edge_attr is not None else 0)
sampler = NeighborSampler(num_nodes=graph.num_nodes, num_layers=1, fanout=[10], graph_data=sample_graph, workers=15,policy = 'recent',graph_name = "wiki_train")
neg_sampler = NegativeSampling('triplet')
train_data = torch.masked_select(graph.edge_index, pdata.train_mask.to(graph.edge_index.device)).reshape(2, -1)
trainloader = DistributedDataLoader(graph, train_data, sampler=sampler, sampler_fn=SAMPLE_TYPE.SAMPLE_FROM_TEMPORAL_EDGES,neg_sampler=neg_sampler, batch_size=1000, shuffle=False, drop_last=True, chunk_size = None,train=True, mailbox=mailbox)
In the data loader, we call `graph_sample`, sourced from `starrygl.sample.batch_data`.
The `to_block` function inside `graph_sample` implements the feature fetching.
If no cache is used, node or edge features are fetched directly from the graph data;
otherwise `fetch_data` is called to fetch them through the cache.
'''
def __init__(
self,
graph,
dataset = None,
sampler = None,
sampler_fn = None,
neg_sampler = None,
batch_size: Optional[int]=None,
drop_last = False,
device: torch.device = torch.device('cuda'),
shuffle:bool = True,
chunk_size = None,
train = False,
queue_size = 10,
mailbox = None,
is_pipeline = False,
**kwargs
):
assert sampler is not None
self.chunk_size = chunk_size
self.batch_size = batch_size
self.queue_size = queue_size
self.num_pending = 0
self.current_pos = 0
self.recv_idxs = 0
self.drop_last = drop_last
self.result_queue = deque(maxlen = self.queue_size)
self.shuffle = shuffle
self.is_closed = False
self.sampler = sampler
self.sampler_fn = sampler_fn
self.neg_sampler = neg_sampler
self.graph = graph
self.shuffle=shuffle
self.dataset = dataset
self.mailbox = mailbox
self.device = device
self.is_pipeline = is_pipeline
if train is True:
self._get_expected_idx(self.dataset.len)
else:
self._get_expected_idx(self.dataset.len,op = dist.ReduceOp.MAX)
#self.expected_idx = int(math.ceil(self.dataset.len/self.batch_size))
torch.distributed.barrier()
def __iter__(self):
if self.chunk_size is None:
if self.shuffle:
self.input_dataset = self.dataset.shuffle()
else:
self.input_dataset = self.dataset
self.recv_idxs = 0
self.current_pos = 0
self.num_pending = 0
self.submitted = 0
else:
self.input_dataset = self.dataset
self.recv_idxs = 0
self.num_pending = 0
self.submitted = 0
if dist.get_rank() == 0:
self.current_pos = int(
math.floor(
np.random.uniform(0,self.batch_size/self.chunk_size)
)*self.chunk_size
)
else:
self.current_pos = 0
current_pos = torch.tensor([self.current_pos],dtype = torch.long,device=self.device)
dist.broadcast(current_pos, src = 0)
self.current_pos = int(current_pos.item())
self._get_expected_idx(self.dataset.len-self.current_pos)
if self.neg_sampler is not None \
and isinstance(self.neg_sampler,PreNegativeSampling):
self.neg_sampler.set_next_pos(self.current_pos)
return self
def _get_expected_idx(self,data_size,op = dist.ReduceOp.MIN):
world_size = dist.get_world_size()
self.expected_idx = data_size // self.batch_size if self.drop_last is True else int(math.ceil(data_size/self.batch_size))
if dist.get_world_size() > 1:
num_epochs = torch.tensor([self.expected_idx],dtype = torch.long,device=self.device)
print(num_epochs)
dist.all_reduce(num_epochs, op=op)
self.expected_idx = int(num_epochs.item())
def _next_data(self):
if self.current_pos >= self.dataset.len:
return self.input_dataset._get_empty()
if self.current_pos + self.batch_size > self.input_dataset.len:
if self.drop_last:
return None
else:
next_data = self.input_dataset.get_next(
slice(self.current_pos,None,None)
)
self.current_pos = 0
else:
next_data = self.input_dataset.get_next(
slice(self.current_pos,self.current_pos + self.batch_size,None)
)
self.current_pos += self.batch_size
return next_data
def __next__(self):
if self.is_pipeline is False:
if self.recv_idxs < self.expected_idx:
data = self._next_data()
batch_data = graph_sample(self.graph,
self.sampler,
self.sampler_fn,
data,self.neg_sampler,
self.mailbox,
self.device)
self.recv_idxs += 1
assert batch_data is not None
torch.cuda.synchronize()
return batch_data
else :
raise StopIteration
else:
if self.recv_idxs == 0:
data = self._next_data()
batch_data = graph_sample(self.graph,
self.sampler,
self.sampler_fn,
data,self.neg_sampler,
self.mailbox,
self.device)
self.recv_idxs += 1
else:
if(self.recv_idxs < self.expected_idx):
assert len(self.result_queue) > 0
result= self.result_queue[0]
self.result_queue.popleft()
batch_data = result.result()
self.recv_idxs += 1
else:
raise StopIteration
if(self.recv_idxs+1<=self.expected_idx):
data = self._next_data()
next_batch = graph_sample(self.graph,
self.sampler,
self.sampler_fn,
data,self.neg_sampler,
self.mailbox,
self.device,
async_op=True)
self.result_queue.append(next_batch)
return batch_data
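# A minimal iteration sketch (assumes `trainloader` was built as in the class docstring above):
# for batch in trainloader:
#     data, mfgs, metadata = batch
#     # mfgs is a list of per-layer DGL blocks whose features were attached by prepare_input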
import starrygl
from starrygl.distributed.context import DistributedContext
from starrygl.distributed.utils import DistIndex, DistributedTensor
from starrygl.sample.graph_core.utils import build_mapper
import os.path as osp
import torch
import torch.distributed as dist
from torch_geometric.data import Data
from starrygl.utils.uvm import *
class DistributedGraphStore:
'''
Initializes the DistributedGraphStore with distributed graph data.
Args:
pdata: Graph data object containing ids, eids, edge_index, edge_ts, sample_graph, x, and edge_attr.
device: Device to which tensors are moved (default is 'cuda').
uvm_node: If True, enables Unified Virtual Memory (UVM) for node data.
uvm_edge: If True, enables Unified Virtual Memory (UVM) for edge data.
'''
def __init__(self, pdata, device = torch.device('cuda'),
uvm_node = False,
uvm_edge = False):
self.device = device
self.ids = pdata.ids.to(device)
self.eids = pdata.eids
self.edge_index = pdata.edge_index.to(device)
if hasattr(pdata,'edge_ts'):
self.edge_ts = pdata.edge_ts.to(device).to(torch.float)
else:
self.edge_ts = None
self.sample_graph = pdata.sample_graph
self.nids_mapper = build_mapper(nids=pdata.ids.to(device)).dist.to('cpu')
self.eids_mapper = build_mapper(nids=pdata.eids.to(device)).dist.to('cpu')
torch.cuda.empty_cache()
self.num_nodes = self.nids_mapper.data.shape[0]
self.num_edges = self.eids_mapper.data.shape[0]
world_size = dist.get_world_size()
self.uvm_node = uvm_node
self.uvm_edge = uvm_edge
if hasattr(pdata,'x') and pdata.x is not None:
ctx = DistributedContext.get_default_context()
pdata.x = pdata.x.to(torch.float)
if uvm_node == False :
x = pdata.x.to(self.device)
else:
if self.device.type == 'cuda':
x = uvm_empty(*pdata.x.size(),
dtype=pdata.x.dtype,
device=ctx.device)
uvm_share(x,device = ctx.device)
uvm_advise(x,cudaMemoryAdvise.cudaMemAdviseSetAccessedBy)
uvm_prefetch(x)
if world_size > 1:
self.x = DistributedTensor(pdata.x.to(self.device).to(torch.float))
else:
self.x = x
else:
self.x = None
if hasattr(pdata,'edge_attr') and pdata.edge_attr is not None:
ctx = DistributedContext.get_default_context()
pdata.edge_attr = pdata.edge_attr.to(torch.float)
if uvm_edge == False :
edge_attr = pdata.edge_attr.to(self.device)
else:
if self.device.type == 'cuda':
edge_attr = uvm_empty(*pdata.edge_attr.size(),
dtype=pdata.edge_attr.dtype,
device=ctx.device)
edge_attr = uvm_share(edge_attr,device = torch.device('cpu'))
edge_attr.copy_(pdata.edge_attr)
edge_attr = uvm_share(edge_attr,device = ctx.device)
uvm_advise(edge_attr,cudaMemoryAdvise.cudaMemAdviseSetAccessedBy)
uvm_prefetch(edge_attr)
if world_size > 1:
self.edge_attr = DistributedTensor(edge_attr)
else:
self.edge_attr = edge_attr
else:
self.edge_attr = None
def _get_node_attr(self,ids,asyncOp = False):
'''
Retrieves node attributes for the specified node IDs.
Args:
ids: Node IDs for which to retrieve attributes.
asyncOp: If True, performs asynchronous operation for distributed data.
'''
if self.x is None:
return None
elif dist.get_world_size() == 1:
return self.x[ids]
else:
if self.x.rrefs is None or asyncOp is False:
ids = self.x.all_to_all_ind2ptr(ids)
return self.x.all_to_all_get(**ids)
return self.x.index_select(ids)
def _get_edge_attr(self,ids,asyncOp = False):
'''
Retrieves edge attributes for the specified edge IDs.
Args:
ids: Edge IDs for which to retrieve attributes.
asyncOp: If True, performs asynchronous operation for distributed data.
'''
if self.edge_attr is None:
return None
elif dist.get_world_size() == 1:
return self.edge_attr[ids]
else:
if self.edge_attr.rrefs is None or asyncOp is False:
ids = self.edge_attr.all_to_all_ind2ptr(ids)
return self.edge_attr.all_to_all_get(**ids)
return self.edge_attr.index_select(ids)
def _get_dist_index(self,ind,mapper):
'''
Retrieves the distributed index for the specified local index using the provided mapper.
Args:
ind: Local index for which to retrieve the distributed index.
mapper: Mapper providing the distributed index.
'''
return mapper[ind.to(mapper.device)]
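# A minimal feature-fetch sketch (not run here; assumes `graph` is a DistributedGraphStore and
# `local_nids`/`local_eids` are tensors of local node/edge ids):
# dist_nid = graph._get_dist_index(local_nids, graph.nids_mapper)
# node_feat = graph._get_node_attr(dist_nid)
# edge_feat = graph._get_edge_attr(graph._get_dist_index(local_eids, graph.eids_mapper))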
class DataSet:
'''
Args:
nodes: Tensor representing nodes. If not None, it is moved to the specified device.
edges: Tensor representing edges. If not None, it is moved to the specified device.
labels: Optional parameter for labels.
ts: Tensor representing timestamps. If not None, it is moved to the specified device.
device: Device to which tensors are moved (default is 'cuda').
'''
def __init__(self,nodes = None,
edges = None,
labels = None,
ts = None,
device = torch.device('cuda'),**kwargs):
if nodes is not None:
self.nodes = nodes.to(device)
if edges is not None:
self.edges = edges.to(device)
if ts is not None:
self.ts = ts.to(device)
if labels is not None:
self.labels = labels
self.len = self.nodes.shape[0] if nodes is not None else self.edges.shape[1]
for k, v in kwargs.items():
assert isinstance(v,torch.Tensor) and v.shape[0]==self.len
setattr(self, k, v.to(device))
def _get_empty(self):
'''
Creates an empty dataset with the same device and data types as the current instance.
'''
nodes = torch.empty([0],dtype = self.nodes.dtype,device= self.nodes.device) if hasattr(self,'nodes') else None
edges = torch.empty([2,0],dtype = self.edges.dtype,device= self.edges.device) if hasattr(self,'edges') else None
d = DataSet(nodes,edges)
for k,v in self.__dict__.items():
if k == 'edges' or k=='nodes' or k == 'len':
continue
else:
setattr(d,k,torch.empty([]))
return d
#@staticmethod
def get_next(self,indx):
'''
Retrieves the next dataset based on the provided index.
Args:
indx: Index specifying the dataset to retrieve.
'''
nodes = self.nodes[indx] if hasattr(self,'nodes') else None
edges = self.edges[:,indx] if hasattr(self,'edges') else None
d = DataSet(nodes,edges)
for k,v in self.__dict__.items():
if k == 'edges' or k=='nodes' or k == 'len':
continue
else:
setattr(d,k,v[indx])
return d
#@staticmethod
def shuffle(self):
'''
Shuffles the dataset and returns a new dataset with the same attributes.
'''
indx = torch.randperm(self.len)
nodes = self.nodes[indx] if hasattr(self,'nodes') else None
edges = self.edges[:,indx] if hasattr(self,'edges') else None
d = DataSet(nodes,edges)
for k,v in self.__dict__.items():
if k == 'edges' or k=='nodes' or k == 'len':
continue
else:
setattr(d,k,v[indx])
return d
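# A minimal DataSet sketch (toy tensors; assumes a CUDA device since DataSet moves data to 'cuda'
# by default; not run here):
# ds = DataSet(edges=torch.tensor([[0, 1, 2], [1, 2, 0]]),
#              ts=torch.tensor([1., 2., 3.]), labels=torch.ones(3))
# first_two = ds.get_next(slice(0, 2))   # edges[:, 0:2] together with the matching ts/labels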
class TemporalGraphData(DistributedGraphStore):
def __init__(self,pdata,device):
super(TemporalGraphData,self).__init__(pdata,device)
def _set_temporal_batch_cache(self,size,pin_size):
pass
def _load_feature_to_cuda(self,ids):
pass
class TemporalNeighborSampleGraph(DistributedGraphStore):
'''
Args:
sample_graph: A dictionary containing graph structure information, including 'edge_index', 'ts' (edge timestamp), and 'eids' (edge identifiers).
mode: Specifies the dataset mode ('train', 'val', 'test', or 'full').
eids_mapper: Optional parameter for edge identifiers mapping.
'''
def __init__(self, sample_graph=None, mode='full', eids_mapper=None):
self.edge_index = sample_graph['edge_index']
self.num_edges = self.edge_index.shape[1]
if 'ts' in sample_graph:
self.edge_ts = sample_graph['ts']
else:
self.edge_ts = None
self.eid = torch.arange(self.num_edges,dtype = torch.long, device = sample_graph['eids'].device)
#sample_graph['eids']
if mode == 'train':
mask = sample_graph['train_mask']
if mode == 'val':
mask = sample_graph['val_mask']
if mode == 'test':
mask = sample_graph['test_mask']
if mode != 'full':
self.edge_index = self.edge_index[:, mask]
self.edge_ts = self.edge_ts[mask]
self.eid = self.eid[mask]
from starrygl.distributed.context import DistributedContext
from typing import *
import torch.distributed as dist
import torch
from starrygl.distributed.utils import DistIndex
def build_mapper(nids):
rank = dist.get_rank()
world_size = dist.get_world_size()
dst_len = nids.size(0)
ikw = dict(dtype=torch.long, device=nids.device)
num_nodes = torch.zeros(1, **ikw)
num_nodes[0] = dst_len
dist.all_reduce(num_nodes, op=dist.ReduceOp.SUM)
all_ids: List[torch.Tensor] = [None] * world_size
dist.all_gather_object(all_ids,nids)
part_mp = torch.empty(int(num_nodes.item()),**ikw)
ind_mp = torch.empty(int(num_nodes.item()),**ikw)
for i in range(world_size):
iid = all_ids[i]
part_mp[iid] = i
ind_mp[iid] = torch.arange(all_ids[i].shape[0],**ikw)
return DistIndex(ind_mp,part_mp)
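# A minimal sketch (assumes an initialized process group; `nids` holds this rank's global ids):
# mapper = build_mapper(nids)   # DistIndex: for every global id, its owning rank and local index
# dist_ids = mapper.dist        # the packed encoding consumed by DistributedGraphStore above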
def get_validate_graph(self,graph):
pass
\ No newline at end of file
import sys
from os.path import abspath, join, dirname
sys.path.insert(0, join(abspath(dirname(__file__))))
from torch import Tensor
import torch
from base import NegativeSampling
from base import NegativeSamplingMode
from typing import Any, List, Optional, Tuple, Union
class EvaluateNegativeSampling(NegativeSampling):
def __init__(
self,
mode: Union[NegativeSamplingMode, str],
src_node_ids: torch.Tensor,
dst_node_ids: torch.Tensor,
interact_times: torch.Tensor = None,
last_observed_time: float = None,
negative_sample_strategy: str = 'random',
seed: int = None
):
super(EvaluateNegativeSampling,self).__init__(mode)
self.seed = seed
self.negative_sample_strategy = negative_sample_strategy
self.src_node_ids = src_node_ids
self.dst_node_ids = dst_node_ids
self.interact_times = interact_times
self.unique_src_nodes_id = src_node_ids.unique()
self.unique_dst_nodes_id = dst_node_ids.unique()
self.src_id_mapper = torch.zeros(int(self.unique_src_nodes_id[-1])+1, dtype=torch.long)
self.dst_id_mapper = torch.zeros(int(self.unique_dst_nodes_id[-1])+1, dtype=torch.long)
self.src_id_mapper[self.unique_src_nodes_id] = torch.arange(self.unique_src_nodes_id.shape[0])
self.dst_id_mapper[self.unique_dst_nodes_id] = torch.arange(self.unique_dst_nodes_id.shape[0])
self.unique_interact_times = self.interact_times.unique()
self.earliest_time = self.unique_interact_times.min().item()
self.last_observed_time = last_observed_time
if self.negative_sample_strategy == 'inductive':
# set of observed edges
self.observed_edges = self.get_unique_edges_between_start_end_time(self.earliest_time, self.last_observed_time)
if self.seed is not None:
self.random_state = torch.Generator()
self.random_state.manual_seed(seed)
else:
self.random_state = torch.Generator()
def get_unique_edges_between_start_end_time(self, start_time: float, end_time: float):
selected_mask = ((self.interact_times >= start_time) & (self.interact_times <= end_time))
# return the (source, destination) node pairs observed in the selected time interval as a (2, E) tensor
return torch.stack((self.src_node_ids[selected_mask],self.dst_node_ids[selected_mask]),dim = 0)
def sample(self, num_samples: int, num_nodes: Optional[int] = None, batch_src_node_ids: Optional[torch.Tensor] = None,
batch_dst_node_ids: Optional[torch.Tensor] = None, current_batch_start_time: Optional[torch.Tensor] = None,
current_batch_end_time: Optional[torch.Tensor] = None) -> Tensor:
if self.negative_sample_strategy == 'random':
negative_src_node_ids, negative_dst_node_ids = self.random_sample(size=num_samples)
elif self.negative_sample_strategy == 'historical':
negative_src_node_ids, negative_dst_node_ids = self.historical_sample(size=num_samples, batch_src_node_ids=batch_src_node_ids,
batch_dst_node_ids=batch_dst_node_ids,
current_batch_start_time=current_batch_start_time,
current_batch_end_time=current_batch_end_time)
elif self.negative_sample_strategy == 'inductive':
negative_src_node_ids, negative_dst_node_ids = self.inductive_sample(size=num_samples, batch_src_node_ids=batch_src_node_ids,
batch_dst_node_ids=batch_dst_node_ids,
current_batch_start_time=current_batch_start_time,
current_batch_end_time=current_batch_end_time)
else:
raise ValueError(f'Not implemented error for negative_sample_strategy {self.negative_sample_strategy}!')
return negative_src_node_ids, negative_dst_node_ids
def random_sample(self, size: int):
if self.seed is None:
random_sample_edge_src_node_indices = torch.randint(0, len(self.unique_src_nodes_id), (size,))
random_sample_edge_dst_node_indices = torch.randint(0, len(self.unique_dst_nodes_id), (size,))
else:
random_sample_edge_src_node_indices = torch.randint(0, len(self.unique_src_nodes_id), (size,), generator = self.random_state)
random_sample_edge_dst_node_indices = torch.randint(0, len(self.unique_dst_nodes_id), (size,), generator = self.random_state)
return self.unique_src_nodes_id[random_sample_edge_src_node_indices], self.unique_dst_nodes_id[random_sample_edge_dst_node_indices]
def random_sample_with_collision_check(self, size: int, batch_src_nodes_id:torch.Tensor, batch_dst_nodes_id:torch.Tensor):
batch_edge = torch.stack((batch_src_nodes_id,batch_dst_nodes_id))
batch_src_index = self.src_id_mapper[batch_src_nodes_id]
batch_dst_index = self.dst_id_mapper[batch_dst_nodes_id]
return_edge = torch.tensor([[],[]],dtype = torch.long)
while(True):
src_ = torch.randint(0, len(self.unique_src_nodes_id), (size*2,))
dst_ = torch.randint(0, len(self.unique_dst_nodes_id), (size*2,))
edge = torch.stack((src_,dst_))
sample_id = src_*self.unique_dst_nodes_id.shape[0] + dst_
batch_id = batch_src_index * self.unique_dst_nodes_id.shape[0] + batch_dst_index
mask = torch.isin(sample_id,batch_id,invert = True)
edge = edge[:,mask]
if(edge.shape[1] >= size):
return_edge = torch.cat((return_edge,edge[:,:size]),1)
break
else:
return_edge = torch.cat((return_edge,edge),1)
size = size - edge.shape[1]
return return_edge
def historical_sample(self, size: int, batch_src_nodes_id: torch.Tensor, batch_dst_nodes_id: torch.Tensor,
current_batch_start_time: float, current_batch_end_time: float):
assert self.seed is not None
historical_edges = self.get_unique_edges_between_start_end_time(start_time=self.earliest_time, end_time=current_batch_start_time)
current_batch_edges = self.get_unique_edges_between_start_end_time(start_time=current_batch_start_time, end_time=current_batch_end_time)
uni,ids = torch.cat((current_batch_edges, historical_edges), dim = 1).unique(dim = 1, return_inverse = True)
mask = torch.zeros(uni.shape[1],dtype = bool)
mask[ids[:current_batch_edges.shape[1]]] = True
mask = (~mask)
unique_historical_edges = uni[:,mask]
if size > unique_historical_edges.shape[1]:
num_random_sample_edges = size - unique_historical_edges.shape[1]
random_sample_edge = self.random_sample_with_collision_check(size=num_random_sample_edges,batch_src_nodes_id=batch_src_nodes_id,
batch_dst_nodes_id=batch_dst_nodes_id)
sample_edges = torch.cat((unique_historical_edges,random_sample_edge),dim = 1)
else:
historical_sample_edge_node_indices = torch.randperm(unique_historical_edges.shape[1],generator=self.random_state)
sample_edges = unique_historical_edges[:,historical_sample_edge_node_indices[:size]]
return sample_edges
def inductive_sample(self, size: int, batch_src_node_ids: torch.Tensor, batch_dst_node_ids: torch.Tensor,
current_batch_start_time: float, current_batch_end_time: float):
assert self.seed is not None
historical_edges = self.get_unique_edges_between_start_end_time(start_time=self.earliest_time, end_time=current_batch_start_time)
current_batch_edges = self.get_unique_edges_between_start_end_time(start_time=current_batch_start_time, end_time=current_batch_end_time)
uni,ids = torch.cat((self.observed_edges,current_batch_edges, historical_edges), dim = 1).unique(dim = 1, return_inverse = True)
mask = torch.zeros(uni.shape[1],dtype = bool)
mask[ids[:current_batch_edges.shape[1]+historical_edges.shape[1]]] = True
mask = (~mask)
unique_inductive_edges = uni[:,mask]
if size > unique_inductive_edges.shape[1]:
num_random_sample_edges = size - unique_inductive_edges.shape[1]
random_sample_edge = self.random_sample_with_collision_check(size=num_random_sample_edges,
batch_src_nodes_id=batch_src_node_ids,
batch_dst_nodes_id=batch_dst_node_ids)
sample_edges = torch.cat((unique_inductive_edges,random_sample_edge),dim = 1)
else:
inductive_sample_edge_node_indices = torch.randperm(unique_inductive_edges.shape[1],generator=self.random_state)
sample_edges = unique_inductive_edges[:, inductive_sample_edge_node_indices[:size]]
return sample_edges
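# A minimal usage sketch (assumed tensors src_ids/dst_ids/ts; not run here):
# ens = EvaluateNegativeSampling('triplet', src_ids, dst_ids, interact_times=ts,
#                                negative_sample_strategy='random', seed=0)
# neg_src, neg_dst = ens.sample(num_samples=batch_size)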
import sys
from os.path import abspath, join, dirname
sys.path.insert(0, join(abspath(dirname(__file__))))
from torch import Tensor
import torch
from base import NegativeSampling
from base import NegativeSamplingMode
from typing import Any, List, Optional, Tuple, Union
class PreNegativeSampling(NegativeSampling):
r"""The negative sampling configuration of a
:class:`~torch_geometric.sampler.BaseSampler` when calling
:meth:`~torch_geometric.sampler.BaseSampler.sample_from_edges`.
Args:
mode (str): The negative sampling mode
(:obj:`"binary"` or :obj:`"triplet"`).
If set to :obj:`"binary"`, will randomly sample negative links
from the graph.
If set to :obj:`"triplet"`, will randomly sample negative
destination nodes for each positive source node.
amount (int or float, optional): The ratio of sampled negative edges to
the number of positive edges. (default: :obj:`1`)
weight (torch.Tensor, optional): A node-level vector determining the
sampling of nodes. Does not necessarily need to sum up to one.
If not given, negative nodes will be sampled uniformly.
(default: :obj:`None`)
"""
def __init__(
self,
mode: Union[NegativeSamplingMode, str],
neg_sample_list: torch.Tensor
):
super(PreNegativeSampling,self).__init__(mode)
self.neg_sample_list = neg_sample_list
self.next_pos = 0
def set_next_pos(self,pos):
self.next_pos = pos
def sample(self, num_samples: int,
num_nodes: Optional[int] = None) -> Tensor:
r"""Generates :obj:`num_samples` negative samples."""
if num_nodes is None:
raise ValueError(
f"Cannot sample negatives in '{self.__class__.__name__}' "
f"without passing the 'num_nodes' argument")
neg_sample_out = self.neg_sample_list[
self.next_pos:self.next_pos+num_samples,:].reshape(-1)
self.next_pos = self.next_pos + num_samples
return neg_sample_out
#return torch.from_numpy(np.random.randint(num_nodes, size=num_samples))
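# A minimal usage sketch (assumes `neg_list` is a precomputed (N, k) tensor of negative ids):
# neg_sampler = PreNegativeSampling('triplet', neg_sample_list=neg_list)
# neg_sampler.set_next_pos(0)
# negs = neg_sampler.sample(num_samples=batch_size, num_nodes=num_nodes)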
import os.path as osp
import torch
class GraphData():
def __init__(self, path):
assert path is not None and osp.exists(path),'path does not exist'
id,edge_index,data,partptr =torch.load(path)
# current partition index
self.partition_id = id
# total number of partitions
self.partitions = partptr.numel() - 1
# full graph structure data
self.num_nodes = partptr[self.partitions]
self.num_edges = edge_index[0].numel()
self.edge_index = edge_index
# data of this partition (feature vectors and subgraph structure), stored as a PyG Data object
self.data = data
# partition mapping
self.partptr = partptr
self.eid = [i for i in range(self.num_edges)]
def __init__(self, id, edge_index, data, partptr, timestamp=None):
# current partition index
self.partition_id = id
# total number of partitions
self.partitions = partptr.numel() - 1
# full graph structure data
self.num_nodes = partptr[self.partitions]
if edge_index is not None:
self.num_edges = edge_index[0].numel()
self.edge_index = edge_index
self.edge_ts = timestamp
# data of this partition (feature vectors and subgraph structure), stored as a PyG Data object
self.data = data
# partition mapping
self.partptr = partptr
# edge id
self.eid = torch.tensor([i for i in range(0, self.num_edges)])
def select_attr(self,index):
return torch.index_select(self.data.x,0,index)
# return the partition that a global node id belongs to
def get_part_num(self):
return self.data.x.size()[0]
def select_attr(self,index):
return torch.index_select(self.data.x,0,index)
def select_y(self,index):
return torch.index_select(self.data.y,0,index)
# return the partition that a global node id belongs to
def get_localId_by_partitionId(self,id,index):
#print(index)
if(id == -1 or id == 0):
return index
else:
return torch.add(index,-self.partptr[id])
def get_globalId_by_partitionId(self,id,index):
if(id == -1 or id == 0):
return index
else:
return torch.add(index,self.partptr[id])
def get_node_num(self):
return self.num_nodes
def localId_to_globalId(self,id,partitionId:int = -1):
'''
Map a node id within partition partitionId to its global id
'''
if partitionId == -1:
partitionId = self.partition_id
assert id >=self.partptr[self.partition_id] and id < self.partptr[self.partition_id+1]
ids_before = 0
if self.partition_id>0:
ids_before = self.partptr[self.partition_id-1]
return id+ids_before
def get_partitionId_by_globalId(self,id):
'''
Get the partition index corresponding to a global id
'''
partitionId = -1
assert id>=0 and id<self.num_nodes,'id out of range'
for i in range(self.partitions):
if id>=self.partptr[i] and id<self.partptr[i+1]:
partitionId = i
break
assert partitionId>=0, 'no partition contains this id'
return partitionId
def get_nodes_by_partitionId(self,id):
'''
Return the number of nodes in the partition given by partitionId
'''
assert id>=0 and id<self.partitions,'invalid partitionId'
return (int)(self.partptr[id+1]-self.partptr[id])
def __repr__(self):
return (f'{self.__class__.__name__}(\n'
f' partition_id={self.partition_id}\n'
f' data={self.data},\n'
f' global_info('
f'num_nodes={self.num_nodes},'
f' num_edges={self.num_edges},'
f' num_parts={self.partitions},'
f' edge_index=[2,{self.edge_index[0].numel()}])\n'
f')')
import os.path as osp
import torch
class GraphData():
def __init__(self, path):
assert path is not None and osp.exists(path),'path does not exist'
id,edge_index,data,partptr =torch.load(path)
# current partition index
self.partition_id = id
# total number of partitions
self.partitions = partptr.numel() - 1
# full graph structure data
self.num_nodes = partptr[self.partitions]
self.num_edges = edge_index[0].numel()
self.edge_index = edge_index
# data of this partition (feature vectors and subgraph structure), stored as a PyG Data object
self.data = data
# partition mapping
self.partptr = partptr
self.eid = [i for i in range(self.num_edges)]
def __init__(self, id, edge_index, data, partptr, timestamp=None):
# current partition index
self.partition_id = id
# total number of partitions
self.partitions = partptr.numel() - 1
# full graph structure data
self.num_nodes = partptr[self.partitions]
if edge_index is not None:
self.num_edges = edge_index[0].numel()
self.edge_index = edge_index
self.edge_ts = timestamp
# data of this partition (feature vectors and subgraph structure), stored as a PyG Data object
self.data = data
# partition mapping
self.partptr = partptr
# edge id
self.eid = torch.tensor([i for i in range(0, self.num_edges)])
def select_attr(self,index):
return torch.index_select(self.data.x,0,index)
# return the partition that a global node id belongs to
def get_part_num(self):
return self.data.x.size()[0]
def select_attr(self,index):
return torch.index_select(self.data.x,0,index)
def select_y(self,index):
return torch.index_select(self.data.y,0,index)
# return the partition that a global node id belongs to
def get_localId_by_partitionId(self,id,index):
#print(index)
if(id == -1 or id == 0):
return index
else:
return torch.add(index,-self.partptr[id])
def get_globalId_by_partitionId(self,id,index):
if(id == -1 or id == 0):
return index
else:
return torch.add(index,self.partptr[id])
def get_node_num(self):
return self.num_nodes
def localId_to_globalId(self,id,partitionId:int = -1):
'''
Map a node id within partition partitionId to its global id
'''
if partitionId == -1:
partitionId = self.partition_id
assert id >=self.partptr[self.partition_id] and id < self.partptr[self.partition_id+1]
ids_before = 0
if self.partition_id>0:
ids_before = self.partptr[self.partition_id-1]
return id+ids_before
def get_partitionId_by_globalId(self,id):
'''
Get the partition index corresponding to a global id
'''
partitionId = -1
assert id>=0 and id<self.num_nodes,'id out of range'
for i in range(self.partitions):
if id>=self.partptr[i] and id<self.partptr[i+1]:
partitionId = i
break
assert partitionId>=0, 'no partition contains this id'
return partitionId
def get_nodes_by_partitionId(self,id):
'''
Return the number of nodes in the partition given by partitionId
'''
assert id>=0 and id<self.partitions,'invalid partitionId'
return (int)(self.partptr[id+1]-self.partptr[id])
def __repr__(self):
return (f'{self.__class__.__name__}(\n'
f' partition_id={self.partition_id}\n'
f' data={self.data},\n'
f' global_info('
f'num_nodes={self.num_nodes},'
f' num_edges={self.num_edges},'
f' num_parts={self.partitions},'
f' edge_index=[2,{self.edge_index[0].numel()}])\n'
f')')
import torch
from torch import Tensor
from enum import Enum
import math
from abc import ABC
from typing import Any, List, Optional, Tuple, Union
import numpy as np
class SampleType(Enum):
Whole = 0
Inner = 1
Outer =2
class NegativeSamplingMode(Enum):
# 'binary': Randomly sample negative edges in the graph.
binary = 'binary'
# 'triplet': Randomly sample negative destination nodes for each positive
# source node.
triplet = 'triplet'
class NegativeSampling:
r"""The negative sampling configuration of a
:class:`~torch_geometric.sampler.BaseSampler` when calling
:meth:`~torch_geometric.sampler.BaseSampler.sample_from_edges`.
Args:
mode (str): The negative sampling mode
(:obj:`"binary"` or :obj:`"triplet"`).
If set to :obj:`"binary"`, will randomly sample negative links
from the graph.
If set to :obj:`"triplet"`, will randomly sample negative
destination nodes for each positive source node.
amount (int or float, optional): The ratio of sampled negative edges to
the number of positive edges. (default: :obj:`1`)
weight (torch.Tensor, optional): A node-level vector determining the
sampling of nodes. Does not necessarily need to sum up to one.
If not given, negative nodes will be sampled uniformly.
(default: :obj:`None`)
"""
mode: NegativeSamplingMode
amount: Union[int, float] = 1
weight: Optional[Tensor] = None
unique: bool
def __init__(
self,
mode: Union[NegativeSamplingMode, str],
amount: Union[int, float] = 1,
weight: Optional[Tensor] = None,
unique: bool = False
):
self.mode = NegativeSamplingMode(mode)
self.amount = amount
self.weight = weight
self.unique = unique
if self.amount <= 0:
raise ValueError(f"The attribute 'amount' needs to be positive "
f"for '{self.__class__.__name__}' "
f"(got {self.amount})")
if self.is_triplet():
if self.amount != math.ceil(self.amount):
raise ValueError(f"The attribute 'amount' needs to be an "
f"integer for '{self.__class__.__name__}' "
f"with 'triplet' negative sampling "
f"(got {self.amount}).")
self.amount = math.ceil(self.amount)
def is_binary(self) -> bool:
return self.mode == NegativeSamplingMode.binary
def is_triplet(self) -> bool:
return self.mode == NegativeSamplingMode.triplet
def sample(self, num_samples: int,
num_nodes: Optional[int] = None) -> Tensor:
r"""Generates :obj:`num_samples` negative samples."""
if self.weight is None:
if num_nodes is None:
raise ValueError(
f"Cannot sample negatives in '{self.__class__.__name__}' "
f"without passing the 'num_nodes' argument")
return torch.randint(num_nodes, (num_samples, ))
#return torch.from_numpy(np.random.randint(num_nodes, size=num_samples))
if num_nodes is not None and self.weight.numel() != num_nodes:
raise ValueError(
f"The 'weight' attribute in '{self.__class__.__name__}' "
f"needs to match the number of nodes {num_nodes} "
f"(got {self.weight.numel()})")
return torch.multinomial(self.weight, num_samples, replacement=True)
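# A minimal usage sketch (not run here):
# neg = NegativeSampling('triplet', amount=1)
# neg_dst = neg.sample(num_samples=1000, num_nodes=num_nodes)   # 1000 uniformly sampled node ids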
class SampleOutput:
node: Optional[torch.Tensor] = None
edge_index_list: Optional[List[torch.Tensor]] = None
eid_list: Optional[List[torch.Tensor]] = None
delta_ts_list: Optional[List[torch.Tensor]] = None
metadata: Optional[Any] = None
class BaseSampler(ABC):
r"""An abstract base class that initializes a graph sampler and provides
:meth:`_sample_one_layer_from_nodes`
:meth:`_sample_one_layer_from_nodes_parallel`
:meth:`sample_from_nodes` routines.
"""
def sample_from_nodes(
self,
nodes: torch.Tensor,
with_outer_sample: SampleType,
**kwargs
) -> Tuple[torch.Tensor, list]:
r"""Performs mutilayer sampling from the nodes specified in: nodes
The specific number of layers is determined by parameter: num_layers
returning a sampled subgraph in the specified output format: Tuple[torch.Tensor, list].
Args:
nodes: the list of seed nodes index
with_outer_sample: 0-sample in the whole graph structure; 1-sample one-hop outer nodes; 2-cross-partition sampling
**kwargs: other kwargs
Returns:
sampled_nodes: the nodes sampled
sampled_edge_index_list: the edges sampled
"""
raise NotImplementedError
def sample_from_edges(
self,
edges: torch.Tensor,
with_outer_sample: SampleType,
edge_label: Optional[torch.Tensor] = None,
neg_sampling: Optional[NegativeSampling] = None
) -> Tuple[torch.Tensor, list]:
r"""Performs sampling from the edges specified in :obj:`index`,
returning a sampled subgraph in the specified output format.
Args:
edges: the list of seed edges index
with_outer_sample: 0-sample in the whole graph structure; 1-sample one-hop outer nodes; 2-cross-partition sampling
edge_label: the label for the seed edges.
neg_sampling: The negative sampling configuration
Returns:
sampled_nodes: the nodes sampled
sampled_edge_index_list: the edges sampled
metadata: other information
"""
raise NotImplementedError
# def _sample_one_layer_from_nodes(
# self,
# nodes:torch.Tensor,
# **kwargs
# ) -> Tuple[torch.Tensor, torch.Tensor]:
# r"""Performs sampling from the nodes specified in: nodes,
# returning a sampled subgraph in the specified output format: Tuple[torch.Tensor, torch.Tensor].
# Args:
# nodes: the list of seed nodes index
# **kwargs: other kwargs
# Returns:
# sampled_nodes: the nodes sampled
# sampled_edge_index: the edges sampled
# """
# raise NotImplementedError
# def _sample_one_layer_from_nodes_parallel(
# self,
# nodes: torch.Tensor,
# **kwargs
# ) -> Tuple[torch.Tensor, torch.Tensor]:
# r"""Performs sampling paralleled from the nodes specified in: nodes,
# returning a sampled subgraph in the specified output format: Tuple[torch.Tensor, torch.Tensor].
# Args:
# nodes: the list of seed nodes index
# **kwargs: other kwargs
# Returns:
# sampled_nodes: the nodes sampled
# sampled_edge_index: the edges sampled
# """
# raise NotImplementedError
import torch
import time
from Utils import GraphData
seed = 10 # any integer can be used as the seed
torch.manual_seed(seed)
num_nodes1 = 10
fanout1 = [2]
edge_index1 = torch.tensor([[1, 5, 7, 9, 2, 4, 6, 7, 8, 0, 1, 6, 2, 0, 1, 3, 5, 8, 9, 7, 4, 8, 2, 3, 5, 8],
[0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 3, 3, 4, 4, 4, 4, 4, 5, 6, 6, 7, 7, 8, 9]])
edge_ts = torch.tensor([1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 4, 4, 4, 4, 4, 5, 5, 5, 6, 6, 6, 6]).double()
edge_weight1 = torch.tensor([2, 1, 2, 1, 8, 6, 3, 1, 1, 1, 1, 5, 1, 1, 2, 1, 1, 1, 1, 5, 1, 2, 2, 2, 1, 1]).double()
edge_weight1 = None
g_data = GraphData(id=0, edge_index=edge_index1, timestamp=edge_ts, data=None, partptr=torch.tensor([0, num_nodes1]))
from neighbor_sampler import NeighborSampler, SampleType
pre = time.time()
# from neighbor_sampler import get_neighbors, update_edge_weight
# row, col = edge_index1
# tnb = get_neighbors(row.contiguous(), col.contiguous(), num_nodes1, edge_weight1)
# print("tnb.neighbors:", tnb.neighbors)
# print("tnb.deg:", tnb.deg)
# print("tnb.weight:", tnb.edge_weight)
sampler = NeighborSampler(num_nodes1,
num_layers=1,
fanout=fanout1,
edge_weight=edge_weight1,
graph_data=g_data,
workers=2,
graph_name='a',
is_distinct = 0,
policy="recent")
end = time.time()
print("init time:", end-pre)
print("tnb.neighbors:", sampler.tnb.neighbors)
print("tnb.deg:", sampler.tnb.deg)
print("tnb.ts:", sampler.tnb.timestamp)
print("tnb.weight:", sampler.tnb.edge_weight)
# update_edge_row = row
# update_edge_col = col
# update_edge_w = torch.DoubleTensor([i for i in range(edge_weight1.size(0))])
# print('tnb.edge_weight:', tnb.edge_weight)
# print('begin update')
# pre = time.time()
# update_edge_weight(tnb, update_edge_row.contiguous(), update_edge_col.contiguous(), update_edge_w.contiguous())
# end = time.time()
# print("update time:", end-pre)
# print('update_edge_row:', update_edge_row)
# print('update_edge_col:', update_edge_col)
# print('tnb.edge_weight:', tnb.edge_weight)
pre = time.time()
out = sampler.sample_from_nodes(torch.tensor([6,7]),
with_outer_sample=SampleType.Whole,
ts=torch.tensor([9, 9]))
end = time.time()
# print('node:', out.node)
# print('edge_index_list:', out.edge_index_list)
# print('eid_list:', out.eid_list)
# print('eid_ts_list:', out.eid_ts_list)
print("sample time:", end-pre)
print("tot_time", out[0].tot_time)
print("sam_time", out[0].sample_time)
print("sam_edge", out[0].sample_edge_num)
print('eid_list:', out[0].eid)
print('delta_ts_list:', out[0].delta_ts)
print((out[0].sample_nodes<10000).sum())
print('node:', out[0].sample_nodes)
print('node_ts:', out[0].sample_nodes_ts)
\ No newline at end of file
import torch
import time
from .Utils import GraphData
def test():
seed = 10 # any integer can be used as the seed
torch.manual_seed(seed)
num_nodes1 = 10
fanout1 = [2,2] # index 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25
edge_index1 = torch.tensor([[1, 5, 7, 9, 2, 4, 6, 7, 8, 0, 1, 6, 2, 0, 1, 3, 5, 8, 9, 7, 4, 8, 2, 3, 5, 8],
[0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 3, 3, 4, 4, 4, 4, 4, 5, 6, 6, 7, 7, 8, 9]])
edge_weight1 = torch.tensor([2, 1, 2, 1, 8, 6, 3, 1, 1, 1, 1, 5, 1, 1, 2, 1, 1, 1, 1, 5, 1, 2, 2, 2, 1, 1]).double()
src,dst = edge_index1
row = torch.cat([src, dst])
col = torch.cat([dst, src])
edge_index1 = torch.stack([row, col])
g_data = GraphData(id=0, edge_index=edge_index1, data=None, partptr=torch.tensor([0, num_nodes1]))
edge_weight1 = None
# g_data.eid=None
from .neighbor_sampler import NeighborSampler, SampleType
pre = time.time()
sampler = NeighborSampler(num_nodes1,
num_layers=2,
fanout=fanout1,
edge_weight=edge_weight1,
graph_data=g_data,
workers=2,
graph_name='a',
policy="uniform")
end = time.time()
print("init time:", end-pre)
print("tnb.neighbors:", sampler.tnb.neighbors)
print("tnb.eid:", sampler.tnb.eid)
print("tnb.deg:", sampler.tnb.deg)
print("tnb.weight:", sampler.tnb.edge_weight)
# row,col = edge_index1
# update_edge_row = row
# update_edge_col = col
# update_edge_w = torch.FloatTensor([i for i in range(edge_weight1.size(0))])
# print('tnb.edge_weight:', sampler.tnb.edge_weight)
# print('begin update')
# pre = time.time()
# sampler.tnb.update_edge_weight(sampler.tnb, update_edge_row.contiguous(), update_edge_col.contiguous(), update_edge_w.contiguous())
# end = time.time()
# print("update time:", end-pre)
# print('update_edge_row:', update_edge_row)
# print('update_edge_col:', update_edge_col)
# print('tnb.edge_weight:', sampler.tnb.edge_weight)
pre = time.time()
out = sampler.sample_from_nodes(torch.tensor([1,2]),
with_outer_sample=SampleType.Whole)# sampler.sample_from_nodes(torch.masked_select(torch.arange(g.num_nodes),node_data['train_mask']))
end = time.time()
print('node1:\t', out[0].sample_nodes().tolist())
print('eid1:\t', out[0].eid().tolist())
print('edge1:\t', edge_index1[:, out[0].eid()].tolist())
print('node2:\t', out[1].sample_nodes().tolist())
print('eid2:\t', out[1].eid().tolist())
print('edge2:\t', edge_index1[:, out[1].eid()].tolist())
print("sample time:", end-pre)
if __name__ == "__main__":
test()
\ No newline at end of file
import torch
from ogb.nodeproppred import PygNodePropPredDataset
from torch_geometric import datasets
import time
from Utils import GraphData
def load_ogb_dataset(name, data_path):
dataset = PygNodePropPredDataset(name=name, root=data_path)
split_idx = dataset.get_idx_split()
g = dataset[0]
n_node = g.num_nodes
node_data={}
node_data['train_mask'] = torch.zeros(n_node, dtype=torch.bool)
node_data['val_mask'] = torch.zeros(n_node, dtype=torch.bool)
node_data['test_mask'] = torch.zeros(n_node, dtype=torch.bool)
node_data['train_mask'][split_idx["train"]] = True
node_data['val_mask'][split_idx["valid"]] = True
node_data['test_mask'][split_idx["test"]] = True
return g, node_data
g, node_data = load_ogb_dataset('ogbn-products', "/home/zlj/hzq/code/gnn/dataset/")
print(g)
# for worker in [1,2,3,4,5,6,7,8,9,10,20,30]:
# import random
# timestamp = [random.randint(1, 5) for i in range(0, g.num_edges)]
# timestamp = torch.FloatTensor(timestamp)
print('begin load')
pre = time.time()
timestamp = torch.load('/home/zlj/hzq/code/gnn/my_sampler/TemporalSample/timestamp.my')
tnb = torch.load("tnb_before.my")
end = time.time()
print("load time:", end-pre)
row, col = g.edge_index
edge_weight=None
g_data = GraphData(id=1, edge_index=g.edge_index, timestamp=timestamp, data=g, partptr=torch.tensor([0, g.num_nodes//4, g.num_nodes//4*2, g.num_nodes//4*3, g.num_nodes]))
from neighbor_sampler import NeighborSampler, SampleType, get_neighbors
# print('begin tnb')
# pre = time.time()
# tnb = get_neighbors(row.contiguous(), col.contiguous(), g.num_nodes, 0, g_data.eid, edge_weight, timestamp)
# end = time.time()
# print("init tnb time:", end-pre)
# torch.save(tnb, "tnb_before.my")
pre = time.time()
sampler = NeighborSampler(g.num_nodes,
tnb=tnb,
num_layers=2,
fanout=[100,100],
graph_data=g_data,
workers=10,
policy="uniform",
is_root_ts=0,
graph_name='a')
end = time.time()
print("init time:", end-pre)
# from torch_geometric.sampler import NeighborSampler, NumNeighbors, NodeSamplerInput, SamplerOutput
# pre = time.time()
# num_nei = NumNeighbors([100, 100])
# node_idx = NodeSamplerInput(input_id=None, node=torch.tensor(range(g.num_nodes//4, g.num_nodes//4+600000)))# (input_id=None, node=torch.masked_select(torch.arange(g.num_nodes),node_data['train_mask']))
# sampler = NeighborSampler(g, num_nei)
# end = time.time()
# print("init time:", end-pre)
ts = torch.tensor([i%5+1 for i in range(0, 600000)])
pre = time.time()
out = sampler.sample_from_nodes(torch.tensor(range(g.num_nodes//4, g.num_nodes//4+600000)),
ts=ts,
with_outer_sample=SampleType.Inner)# sampler.sample_from_nodes(torch.masked_select(torch.arange(g.num_nodes),node_data['train_mask']))
# out = sampler.sample_from_nodes(node_idx)
# node = out.node
# edge = [out.row, out.col]
end = time.time()
print('node:', out.node)
print('edge_index_list:', out.edge_index_list)
print('eid_list:', out.eid_list)
print('eid_ts_list:', out.eid_ts_list)
print("sample time", end-pre)
\ No newline at end of file
import torch
from ogb.nodeproppred import PygNodePropPredDataset
from torch_geometric import datasets
import time
from .Utils import GraphData
def load_ogb_dataset(name, data_path):
dataset = PygNodePropPredDataset(name=name, root=data_path)
split_idx = dataset.get_idx_split()
g = dataset[0]
n_node = g.num_nodes
node_data={}
node_data['train_mask'] = torch.zeros(n_node, dtype=torch.bool)
node_data['val_mask'] = torch.zeros(n_node, dtype=torch.bool)
node_data['test_mask'] = torch.zeros(n_node, dtype=torch.bool)
node_data['train_mask'][split_idx["train"]] = True
node_data['val_mask'][split_idx["valid"]] = True
node_data['test_mask'][split_idx["test"]] = True
return g, node_data
def test():
g, node_data = load_ogb_dataset('ogbn-products', "/home/zlj/hzq/code/gnn/dataset/")
print(g)
# for worker in [1,2,3,4,5,6,7,8,9,10,20,30]:
g_data = GraphData(id=1, edge_index=g.edge_index, data=g, partptr=torch.tensor([0, g.num_nodes//4, g.num_nodes//4*2, g.num_nodes//4*3, g.num_nodes]))
row, col = g.edge_index
# edge_weight = torch.ones(g.num_edges).float()
# indices = [x for x in range(0, g.num_edges, 5)]
# edge_weight[indices] = 2.0
# g_data.eid = None
edge_weight = None
timestamp = None
from .neighbor_sampler import NeighborSampler, SampleType
from .neighbor_sampler import get_neighbors
update_edge_row = row
update_edge_col = col
update_edge_w = torch.DoubleTensor([i for i in range(g.num_edges)])
# print('begin update')
# pre = time.time()
# # update_edge_weight(tnb, update_edge_row.contiguous(), update_edge_col.contiguous(), update_edge_w.contiguous())
# end = time.time()
# print("update time:", end-pre)
print('begin tnb')
pre = time.time()
tnb = get_neighbors("a",
row.contiguous(),
col.contiguous(),
g.num_nodes, 0,
g_data.eid,
edge_weight,
timestamp)
end = time.time()
print("init tnb time:", end-pre)
# torch.save(tnb, "/home/zlj/hzq/code/gnn/my_sampler/MergeSample/tnb_static.my")
# print('begin load')
# pre = time.time()
# tnb = torch.load("/home/zlj/hzq/code/gnn/my_sampler/MergeSample/tnb_static.my")
# end = time.time()
# print("load time:", end-pre)
print('begin init')
pre = time.time()
sampler = NeighborSampler(g.num_nodes,
tnb = tnb,
num_layers=2,
fanout=[100,100],
graph_data=g_data,
workers=10,
graph_name='a',
policy="uniform")
end = time.time()
print("init time:", end-pre)
# from torch_geometric.sampler import NeighborSampler, NumNeighbors, NodeSamplerInput, SamplerOutput
# pre = time.time()
# num_nei = NumNeighbors([100, 100])
# node_idx = NodeSamplerInput(input_id=None, node=torch.tensor(range(g.num_nodes//4, g.num_nodes//4+600000)))# (input_id=None, node=torch.masked_select(torch.arange(g.num_nodes),node_data['train_mask']))
# sampler = NeighborSampler(g, num_nei)
# end = time.time()
# print("init time:", end-pre)
pre = time.time()
out = sampler.sample_from_nodes(torch.tensor(range(g.num_nodes//4, g.num_nodes//4+600000)))# sampler.sample_from_nodes(torch.masked_select(torch.arange(g.num_nodes),node_data['train_mask']))
# out = sampler.sample_from_nodes(node_idx)
# node = out.node
# edge = [out.row, out.col]
end = time.time()
print('node1:\t', out[0].sample_nodes())
print('eid1:\t', out[0].eid())
print('edge1:\t', g.edge_index[:, out[0].eid()])
print('node2:\t', out[1].sample_nodes())
print('eid2:\t', out[1].eid())
print('edge2:\t', g.edge_index[:, out[1].eid()])
print("sample time", end-pre)
if __name__ == "__main__":
test()
\ No newline at end of file
import torch
import torch.multiprocessing as mp
from typing import Optional, Tuple
from base import BaseSampler, NegativeSampling, SampleOutput
from neighbor_sampler import NeighborSampler, SampleType
class RandomWalkSampler(BaseSampler):
def __init__(
self,
num_nodes: int,
num_layers: int,
graph_data,
workers = 1,
tnb = None,
is_distinct = 0,
policy = "uniform",
edge_weight: Optional[torch.Tensor] = None,
graph_name = None
) -> None:
r"""__init__
Args:
num_nodes: the num of all nodes in the graph
num_layers: the num of layers to be sampled
workers: the number of threads, default value is 1
tnb: neighbor information table
is_distinct: 1-need distinct, 0-don't need distinct
policy: "uniform" or "recent" or "weighted"
is_root_ts: 1-base on root's ts, 0-base on parent node's ts
edge_weight: the initial weights of edges
graph_name: the name of graph
"""
super().__init__()
self.sampler = NeighborSampler(
num_nodes=num_nodes,
tnb=tnb,
num_layers=num_layers,
fanout=[1 for i in range(num_layers)],
graph_data=graph_data,
edge_weight = edge_weight,
workers=workers,
policy=policy,
graph_name=graph_name,
is_distinct = is_distinct
)
self.num_layers = num_layers
def _get_sample_info(self):
return self.num_nodes,self.num_layers,self.fanout,self.workers
def _get_sample_options(self):
return {"is_distinct" : self.is_distinct,
"policy" : self.policy,
"with_eid" : self.tnb.with_eid,
"weighted" : self.tnb.weighted,
"with_timestamp" : self.tnb.with_timestamp}
def insert_edges_with_timestamp(
self,
edge_index : torch.Tensor,
eid : torch.Tensor,
timestamp : torch.Tensor,
edge_weight : Optional[torch.Tensor] = None):
row, col = edge_index
# update the number of nodes and the tnb (temporal neighbor table)
self.num_nodes = self.tnb.update_neighbors_with_time(
row.contiguous(),
col.contiguous(),
timestamp.contiguous(),
eid.contiguous(),
self.is_distinct,
edge_weight.contiguous())
def update_edges_weight(
self,
edge_index : torch.Tensor,
eid : torch.Tensor,
edge_weight : Optional[torch.Tensor] = None):
row, col = edge_index
# update the edge weight information in the tnb
if self.tnb.with_eid:
self.tnb.update_edge_weight(
eid.contiguous(),
col.contiguous(),
edge_weight.contiguous()
)
else:
self.tnb.update_edge_weight(
row.contiguous(),
col.contiguous(),
edge_weight.contiguous()
)
def update_nodes_weight(
self,
nid : torch.Tensor,
node_weight : Optional[torch.Tensor] = None):
# update the node weight information in the tnb
self.tnb.update_node_weight(
nid.contiguous(),
node_weight.contiguous()
)
def update_all_node_weight(
self,
node_weight : torch.Tensor):
        # update the weights of all nodes stored in tnb
self.tnb.update_all_node_weight(node_weight.contiguous())
def sample_from_nodes(
self,
nodes: torch.Tensor,
with_outer_sample: SampleType,
ts: Optional[torch.Tensor] = None
) -> SampleOutput:
r"""Performs mutilayer sampling from the nodes specified in: nodes
The specific number of layers is determined by parameter: num_layers
returning a sampled subgraph in the specified output format: Tuple[torch.Tensor, list].
Args:
nodes: the list of seed nodes index
with_outer_sample: 0-sample in whole graph structure; 1-sample onehop outer nodel; 2-cross partition sampling
Returns:
sampled_nodes: the node sampled
sampled_edge_index: the edge sampled
"""
return self.sampler.sample_from_nodes(nodes, ts, with_outer_sample)
def sample_from_edges(
self,
edges: torch.Tensor,
ets: Optional[torch.Tensor] = None,
neg_sampling: Optional[NegativeSampling] = None,
with_outer_sample: SampleType = SampleType.Whole
) -> SampleOutput:
r"""Performs sampling from the edges specified in :obj:`index`,
returning a sampled subgraph in the specified output format.
Args:
edges: the list of seed edges index
with_outer_sample: 0-sample in whole graph structure; 1-sample onehop outer nodel; 2-cross partition sampling
edge_label: the label for the seed edges.
neg_sampling: The negative sampling configuration
Returns:
sampled_nodes: the nodes sampled
sampled_edge_index_list: the edges sampled
"""
return self.sampler.sample_from_edges(edges, ets, neg_sampling, with_outer_sample)
if __name__=="__main__":
edge_index1 = torch.tensor([[0, 1, 1, 1, 1, 2, 2, 2, 2, 4, 4, 4, 5], # , 3, 3
[1, 0, 2, 0, 4, 1, 3, 0, 3, 3, 5, 0, 2]])# , 2, 5
timeStamp=torch.FloatTensor([1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4])
edge_weight1 = None
num_nodes1 = 6
num_neighbors = 2
# Run the neighbor sampling
from Utils import GraphData
g_data = GraphData(id=0, edge_index=edge_index1, timestamp=timeStamp, data=None, partptr=torch.tensor([0, num_nodes1]))
# Run the random walk sampling
sampler=RandomWalkSampler(num_nodes=num_nodes1,
num_layers=3,
edge_weight=edge_weight1,
graph_data=g_data,
graph_name='a',
workers=4,
                              is_distinct = 0)
out = sampler.sample_from_nodes(torch.tensor([1,2]),
with_outer_sample=SampleType.Whole,
ts=torch.tensor([1, 2]))
# out = sampler.sample_from_edges(torch.tensor([[1,2],[4,0]]),
# with_outer_sample=SampleType.Whole,
# ets = torch.tensor([1, 2]))
# Print the result
print('node:', out.node)
print('edge_index_list:', out.edge_index_list)
print('eid_list:', out.eid_list)
print('eid_ts_list:', out.eid_ts_list)
print('metadata: ', out.metadata)
import argparse
import random
import pandas as pd
import numpy as np
import torch
import time
from tqdm import tqdm
from .Utils import GraphData
class NegLinkSampler:
def __init__(self, num_nodes):
self.num_nodes = num_nodes
def sample(self, n):
return np.random.randint(self.num_nodes, size=n)
class NegLinkInductiveSampler:
def __init__(self, nodes):
self.nodes = list(nodes)
def sample(self, n):
return np.random.choice(self.nodes, size=n)
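# Both negative samplers draw random destination ids to form negative edges for link
# prediction: NegLinkSampler samples uniformly over all node ids, while the inductive
# variant restricts the draw to a fixed node set. Illustrative sketch (hypothetical sizes):
#   neg = NegLinkSampler(num_nodes=10000)
#   fake_dst = neg.sample(600)  # one negative destination per positive edge in a batch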
def load_reddit_dataset():
df = pd.read_csv('/mnt/data/hzq/DATA/{}/edges.csv'.format("REDDIT"))
num_nodes = max(int(df['src'].max()), int(df['dst'].max())) + 1
src = torch.tensor(df['src'].to_numpy(dtype=int))
dst = torch.tensor(df['dst'].to_numpy(dtype=int))
edge_index = torch.stack([src, dst])
timestamp = torch.tensor(df['time']).float()
g = GraphData(0, edge_index, timestamp=timestamp, data=None, partptr=torch.tensor([0, num_nodes]))
return g, df
def load_gdelt_dataset():
df = pd.read_csv('/mnt/data/hzq/DATA/{}/edges.csv'.format("GDELT"))
num_nodes = max(int(df['src'].max()), int(df['dst'].max())) + 1
src = torch.tensor(df['src'].to_numpy(dtype=int))
dst = torch.tensor(df['dst'].to_numpy(dtype=int))
edge_index = torch.stack([src, dst])
timestamp = torch.tensor(df['time']).float()
g = GraphData(0, edge_index, timestamp=timestamp, data=None, partptr=torch.tensor([0, num_nodes]))
return g, df
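# Both loaders assume an edges.csv with at least 'src', 'dst' and 'time' columns; the
# absolute paths above are machine-specific and need to be adapted to the local setup.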
def test():
parser=argparse.ArgumentParser()
parser.add_argument('--data', type=str, help='dataset name',default="REDDIT")
parser.add_argument('--config', type=str, help='path to config file',default="/home/zlj/hzq/project/code/TGL/config/TGN.yml")
    parser.add_argument('--batch_size', type=int, default=600, help='batch size')
parser.add_argument('--num_thread', type=int, default=64, help='number of thread')
args=parser.parse_args()
dataset = "gdelt"#"reddit"#"gdelt"
seed=10
    torch.manual_seed(seed)  # set the random seed for the CPU
    torch.cuda.manual_seed(seed)  # set the random seed for the current GPU
    torch.cuda.manual_seed_all(seed)  # set the random seed for all GPUs (multi-GPU)
np.random.seed(seed) # Numpy module.
random.seed(seed) # Python random module.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
g_data, df = load_gdelt_dataset()
print(g_data)
# for worker in [1,2,3,4,5,6,7,8,9,10,20,30]:
# import random
# timestamp = [random.randint(1, 5) for i in range(0, g.num_edges)]
# timestamp = torch.FloatTensor(timestamp)
# print('begin load')
# pre = time.time()
# # timestamp = torch.load('/home/zlj/hzq/code/gnn/my_sampler/TemporalSample/timestamp.my')
# tnb = torch.load("tnb_reddit_before.my")
# end = time.time()
# print("load time:", end-pre)
# row, col = g.edge_index
edge_weight=None
# g_data = GraphData(id=1, edge_index=g.edge_index, timestamp=timestamp, data=g, partptr=torch.tensor([0, g.num_nodes//4, g.num_nodes//4*2, g.num_nodes//4*3, g.num_nodes]))
from .neighbor_sampler import NeighborSampler, SampleType, get_neighbors
print('begin tnb')
row, col = g_data.edge_index
    row, col = torch.cat([row, col]), torch.cat([col, row])  # symmetrize; evaluate both before reassigning row
eid = torch.cat([g_data.eid, g_data.eid])
timestamp = torch.cat([g_data.edge_ts, g_data.edge_ts])
timestamp,ind = timestamp.sort()
timestamp = timestamp.float().contiguous()
eid = eid[ind].contiguous()
row = row[ind]
col = col[ind]
print(row, col)
g2 = GraphData(0, torch.stack([row, col]), timestamp=timestamp, data=None, partptr=torch.tensor([0, max(int(df['src'].max()), int(df['dst'].max())) + 1]))
print(g2)
pre = time.time()
tnb = get_neighbors(dataset, row.contiguous(), col.contiguous(), g_data.num_nodes, 0, eid, edge_weight, timestamp)
end = time.time()
print("init tnb time:", end-pre)
# torch.save(tnb, "tnb_{}_before.my".format(dataset), pickle_protocol=4)
pre = time.time()
sampler = NeighborSampler(g_data.num_nodes,
tnb=tnb,
num_layers=1,
fanout=[10],
graph_data=g_data,
workers=32,
policy="recent",
graph_name='a')
end = time.time()
print("init time:", end-pre)
# neg_link_sampler = NegLinkSampler(g_data.num_nodes)
from .base import NegativeSampling, NegativeSamplingMode
neg_link_sampler = NegativeSampling(NegativeSamplingMode.triplet)
# from torch_geometric.sampler import NeighborSampler, NumNeighbors, NodeSamplerInput, SamplerOutput
# pre = time.time()
# num_nei = NumNeighbors([100, 100])
# node_idx = NodeSamplerInput(input_id=None, node=torch.tensor(range(g.num_nodes//4, g.num_nodes//4+600000)))# (input_id=None, node=torch.masked_select(torch.arange(g.num_nodes),node_data['train_mask']))
# sampler = NeighborSampler(g, num_nei)
# end = time.time()
# print("init time:", end-pre)
out = []
tot_time = 0
sam_time = 0
sam_edge = 0
pre = time.time()
min_than_ten = 0
min_than_ten_sum = 0
seed_node_sum = 0
for _, rows in tqdm(df.groupby(df.index // args.batch_size), total=len(df) // args.batch_size):
# root_nodes = torch.tensor(np.concatenate([rows.src.values, rows.dst.values, neg_link_sampler.sample(len(rows))])).long()
# ts = torch.tensor(np.concatenate([rows.time.values, rows.time.values, rows.time.values]).astype(np.float32))
# outi = sampler.sample_from_nodes(root_nodes, ts=ts)
edges = torch.tensor(np.stack([rows.src.values, rows.dst.values])).long()
outi, meta = sampler.sample_from_edges(edges=edges, ets=torch.tensor(rows.time.values).float(), neg_sampling=neg_link_sampler)
# min_than_ten += (torch.tensor(tnb.deg)[meta['seed']]<10).sum()
# min_than_ten_sum += ((torch.tensor(tnb.deg)[meta['seed']])[torch.tensor(tnb.deg)[meta['seed']]<10]).sum()
# seed_node_sum += meta['seed'].size(0)
tot_time += outi[0].tot_time
sam_time += outi[0].sample_time
# print(outi[0].sample_edge_num)
sam_edge += outi[0].sample_edge_num
# out.append(outi)
end = time.time()
# print("row", out[23][0].row())
print("sample time", end-pre)
print("tot_time", tot_time)
print("sam_time", sam_time)
print("sam_edge", sam_edge)
# print('eid_list:', out[23][0].eid())
# print('delta_ts_list:', out[10][0].delta_ts)
# print('node:', out[23][0].sample_nodes())
# print('node_ts:', out[23][0].sample_nodes_ts)
# print('eid_list:', out[23][1].eid)
# print('node:', out[23][1].sample_nodes)
# print('node_ts:', out[23][1].sample_nodes_ts)
# print('edge_index_list:', out[0][0].edge_index)
# print("min_than_ten", min_than_ten)
# print("min_than_ten_sum", min_than_ten_sum)
# print("seed_node_sum", seed_node_sum)
# print("predict edge_num", (seed_node_sum-min_than_ten)*9+min_than_ten_sum)
    print('throughput : {:.4f}'.format(sam_edge/(end-pre)))
if __name__ == "__main__":
test()
\ No newline at end of file
import argparse
import random
import pandas as pd
import numpy as np
import torch
import time
from tqdm import tqdm
from .Utils import GraphData
def load_reddit_dataset():
df = pd.read_csv('/mnt/data/hzq/DATA/{}/edges.csv'.format("REDDIT"))
num_nodes = max(int(df['src'].max()), int(df['dst'].max())) + 1
src = torch.tensor(df['src'].to_numpy(dtype=int))
dst = torch.tensor(df['dst'].to_numpy(dtype=int))
row = torch.cat([src, dst])
col = torch.cat([dst, src])
edge_index = torch.stack([row, col])
timestamp = torch.tensor(df['time']).float()
g = GraphData(0, edge_index, timestamp=None, data=None, partptr=torch.tensor([0, num_nodes]))
return g
def load_gdelt_dataset():
df = pd.read_csv('/mnt/data/hzq/DATA/{}/edges.csv'.format("GDELT"))
num_nodes = max(int(df['src'].max()), int(df['dst'].max())) + 1
src = torch.tensor(df['src'].to_numpy(dtype=int))
dst = torch.tensor(df['dst'].to_numpy(dtype=int))
row = torch.cat([src, dst])
col = torch.cat([dst, src])
edge_index = torch.stack([row, col])
timestamp = torch.tensor(df['time']).float()
g = GraphData(0, edge_index, timestamp=None, data=None, partptr=torch.tensor([0, num_nodes]))
return g
def load_ogb_dataset():
from ogb.nodeproppred import PygNodePropPredDataset
dataset = PygNodePropPredDataset(name='ogbn-products', root="/home/zlj/hzq/code/gnn/dataset/")
split_idx = dataset.get_idx_split()
g = dataset[0]
n_node = g.num_nodes
node_data={}
node_data['train_mask'] = torch.zeros(n_node, dtype=torch.bool)
node_data['val_mask'] = torch.zeros(n_node, dtype=torch.bool)
node_data['test_mask'] = torch.zeros(n_node, dtype=torch.bool)
node_data['train_mask'][split_idx["train"]] = True
node_data['val_mask'][split_idx["valid"]] = True
node_data['test_mask'][split_idx["test"]] = True
src, dst = g.edge_index
row = torch.cat([src, dst])
col = torch.cat([dst, src])
edge_index = torch.stack([row, col])
g = GraphData(id=0, edge_index=edge_index, data=g, partptr=torch.tensor([0, g.num_nodes]))
return g # , node_data
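# ogbn-products is a static graph, so the loader symmetrizes the edge list by appending
# the reversed pairs and attaches no timestamps; the train/val/test masks built above are
# currently not returned (see the commented-out node_data in the return statement).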
def test():
parser=argparse.ArgumentParser()
parser.add_argument('--data', type=str, help='dataset name',default="REDDIT")
# parser.add_argument('--config', type=str, help='path to config file',default="/home/zlj/hzq/project/code/TGL/config/TGN.yml")
    parser.add_argument('--batch_size', type=int, default=600, help='batch size')
# parser.add_argument('--num_thread', type=int, default=64, help='number of thread')
args=parser.parse_args()
seed=10
    torch.manual_seed(seed)  # set the random seed for the CPU
    torch.cuda.manual_seed(seed)  # set the random seed for the current GPU
    torch.cuda.manual_seed_all(seed)  # set the random seed for all GPUs (multi-GPU)
np.random.seed(seed) # Numpy module.
random.seed(seed) # Python random module.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
g_data = load_ogb_dataset()
print(g_data)
from .neighbor_sampler import NeighborSampler, get_neighbors
# print('begin tnb')
# row, col = g_data.edge_index
# pre = time.time()
# tnb = get_neighbors("a",
# row.contiguous(),
# col.contiguous(),
# g_data.num_nodes, 0,
# g_data.eid,
# None,
# None)
# end = time.time()
# print("init tnb time:", end-pre)
pre = time.time()
sampler = NeighborSampler(g_data.num_nodes,
num_layers=2,
fanout=[10,10],
graph_data=g_data,
workers=32,
policy="uniform",
graph_name='a',
is_distinct=0)
end = time.time()
print("init time:", end-pre)
print(sampler.tnb.deg[0:100])
n_list = []
    # batch the node ids, keeping the (possibly shorter) final batch
    num_nodes = g_data.num_nodes.item()
    for start in range(0, num_nodes, args.batch_size):
        n_list.append(range(start, min(start + args.batch_size, num_nodes)))
# print(n_list)
out = []
tot_time = 0
sam_time = 0
sam_edge = 0
sam_node = 0
pre = time.time()
for i, nodes in tqdm(enumerate(n_list), total=g_data.num_nodes.item() // args.batch_size):
# for nodes in n_list:
root_nodes = torch.tensor(nodes).long()
outi = sampler.sample_from_nodes(root_nodes)
sam_node += outi[0].sample_nodes().size(0)
sam_node += outi[1].sample_nodes().size(0)
sam_edge += outi[0].sample_edge_num
end = time.time()
print("sample time", end-pre)
print("sam_edge", sam_edge)
print("sam_node", sam_node)
    print('edge throughput : {:.4f}'.format(sam_edge/(end-pre)))
    print('node throughput : {:.4f}'.format(sam_node/(end-pre)))
if __name__ == "__main__":
test()
\ No newline at end of file
from collections import deque
from concurrent.futures import ThreadPoolExecutor
import torch
import torch.distributed as dist
stage = ['train_stream','write_memory','write_mail','lookup']
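# Each stage gets its own CUDA stream and its own process group, so the kernels and
# collectives issued by different pipeline stages are not forced to serialize on the
# default stream or the default group (assumed intent; not asserted by the code itself).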
class PipelineManager:
def __init__(self,num_tasks = 10):
self.stream_set = {}
self.dist_set = {}
self.args_queue = {}
self.thread_pool = ThreadPoolExecutor(num_tasks)
for k in stage:
self.stream_set[k] = torch.cuda.Stream()
self.dist_set[k] = dist.new_group()
            self.args_queue[k] = deque()
def submit(self,state,func,kwargs):
future = self.thread_pool.submit(self.run, state,func,kwargs)
return future
def run(self,state,func,kwargs):
with torch.cuda.stream(self.stream_set[state]):
return func(**kwargs,group = self.dist_set[state])
manger = None
def getPipelineManger():
global manger
    if manger is None:
manger = PipelineManager()
return manger
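# Illustrative usage sketch (assumes torch.distributed is already initialized and that
# `some_stage_fn` accepts the `group` keyword injected by PipelineManager.run):
#   manager = getPipelineManger()
#   future = manager.submit('lookup', some_stage_fn, {"tensor": payload})
#   result = future.result()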
\ No newline at end of file
......@@ -149,6 +149,7 @@ class TransfomerAttentionLayer(torch.nn.Module):
#Q = self.w_q(torch.cat([b.srcdata['h'][:b.num_dst_nodes()], zero_time_feat], dim=1))[b.edges()[1]]
#K = self.w_k(torch.cat([b.srcdata['h'][b.num_dst_nodes():], b.edata['f'], time_feat], dim=1))
#V = self.w_v(torch.cat([b.srcdata['h'][b.num_dst_nodes():], b.edata['f'], time_feat], dim=1))
Q = torch.reshape(Q, (Q.shape[0], self.num_head, -1))
K = torch.reshape(K, (K.shape[0], self.num_head, -1))
V = torch.reshape(V, (V.shape[0], self.num_head, -1))
......
......@@ -58,7 +58,6 @@ class GeneralModel(torch.nn.Module):
rst = self.layers['l' + str(l) + 'h' + str(h)](mfgs[l][h])
if 'time_transform' in self.gnn_param and self.gnn_param['time_transform'] == 'JODIE':
rst = self.layers['l0h' + str(h) + 't'](rst, mfgs[l][h].srcdata['mem_ts'], mfgs[l][h].srcdata['ts'])
print(rst,mfgs[l][h])
if l != self.gnn_param['layer'] - 1:
mfgs[l + 1][h].srcdata['h'] = rst
else:
......
......@@ -31,7 +31,6 @@ def prepare_input(node_feat, edge_feat, mem_embedding,mfgs,dist_nid,dist_eid):
idx = b.srcdata['ID']
b.edata['ID'] = dist_eid[e_idx]
b.srcdata['ID'] = dist_nid[idx]
#print(b.edata['ID'],b.edata['dt'],b.srcdata['ID'])
if edge_feat is not None:
b.edata['f'] = edge_feat[e_idx]
if i == 0:
......@@ -51,21 +50,19 @@ def to_block(graph: DistributedGraphStore, data, sample_out, mailbox:MailBox = N
sample_out,metadata = sample_out
else:
metadata = None
# print(sample_out)
eid = [ret.eid() for ret in sample_out]
eid_len = [e.shape[0] for e in eid ]
#print(len(sample_out),eid,eid_len)
eid_mapper: torch.Tensor = graph.eids_mapper
nid_mapper: torch.Tensor = graph.nids_mapper
eid_tensor = torch.cat(eid,dim = 0).to(eid_mapper.device)
#print(eid_tensor)
dist_eid = eid_mapper[eid_tensor>>1].to(device)
dist_eid = graph.sample_graph['dist_eid'][eid_tensor].to(device)#eid_mapper[eid_tensor].to(device)
dist_eid,eid_inv = dist_eid.unique(return_inverse=True)
src_node = graph.sample_graph['edge_index'][0,eid_tensor].to(graph.nids_mapper.device)
#print(src_node,graph.sample_graph['edge_index'][1,eid_tensor])
src_ts = None
if metadata is None:
root_node = data.nodes.to(graph.nids_mapper.device)
root_len = root_node.shape[0]
root_node = data.nodes.to(graph.nidst_eid_mapper.device)
root_len = [root_node.shape[0]]
if hasattr(data,'ts'):
src_ts = torch.cat([data.ts,
graph.sample_graph['ts'][eid_tensor].to(device)])
......@@ -77,44 +74,30 @@ def to_block(graph: DistributedGraphStore, data, sample_out, mailbox:MailBox = N
graph.sample_graph['ts'][eid_tensor].to(device)])
for k in metadata:
metadata[k] = metadata[k].to(device)
#print(src_ts,root_node)
nid_tensor = torch.cat([root_node,src_node],dim = 0)
#sprint(nid_tensor)
dist_nid = nid_mapper[nid_tensor].to(device)
dist_nid,nid_inv = dist_nid.unique(return_inverse = True)
fetchCache = FetchFeatureCache.getFetchCache()
if fetchCache is None:
if isinstance(graph.edge_attr,DistributedTensor):
ind_dict = graph.edge_attr.all_to_all_ind2ptr(dist_eid,group = group)
edge_feat = graph.edge_attr.all_to_all_get(group = group,**ind_dict)
else:
edge_feat = graph._get_edge_attr(dist_eid)
ind_dict = None
if isinstance(graph.x,DistributedTensor):
ind_dict = graph.x.all_to_all_ind2ptr(dist_nid,group = group)
node_feat = graph.x.all_to_all_get(group = group,**ind_dict)
else:
node_feat = graph._get_node_attr(dist_nid)
if mailbox is not None:
if torch.distributed.get_world_size() > 1:
if node_feat is None:
ind_dict = mailbox.node_memory.all_to_all_ind2ptr(dist_nid,group = group)
mem = mailbox.gather_memory(**ind_dict)
else:
mem = mailbox.get_memory(dist_nid)
if isinstance(graph.edge_attr,DistributedTensor):
ind_dict = graph.edge_attr.all_to_all_ind2ptr(dist_eid,group = group)
edge_feat = graph.edge_attr.all_to_all_get(group = group,**ind_dict)
else:
edge_feat = graph._get_edge_attr(dist_eid)
ind_dict = None
if isinstance(graph.x,DistributedTensor):
ind_dict = graph.x.all_to_all_ind2ptr(dist_nid,group = group)
node_feat = graph.x.all_to_all_get(group = group,**ind_dict)
else:
node_feat = graph._get_node_attr(dist_nid)
if mailbox is not None:
if torch.distributed.get_world_size() > 1:
if node_feat is None:
ind_dict = mailbox.node_memory.all_to_all_ind2ptr(dist_nid,group = group)
mem = mailbox.gather_memory(**ind_dict)
else:
mem = None
mem = mailbox.get_memory(dist_nid)
else:
raw_nid = torch.empty_like(dist_nid)
raw_eid = torch.empty_like(dist_eid)
nid_tensor = nid_tensor.to(device)
eid_tensor = eid_tensor.to(device)
raw_nid[nid_inv] = nid_tensor
raw_eid[eid_inv] = (eid_tensor>>1)
node_feat,edge_feat,mem = fetchCache.fetch_feature(raw_nid,
dist_nid,raw_eid,
dist_eid)
mem = None
def build_block():
mfgs = list()
col = torch.arange(0,root_len,device = device)
......@@ -123,23 +106,18 @@ def to_block(graph: DistributedGraphStore, data, sample_out, mailbox:MailBox = N
for r in range(len(eid_len)):
elen = eid_len[r]
row = torch.arange(row_len,row_len+elen,device = device)
#print(row,col[sample_out[r].src_index()])
b = dgl.create_block((row,col[sample_out[r].src_index().to(device)]),
num_src_nodes = row_len + elen,
num_dst_nodes = row_len,
device = device)
idx = nid_inv[0:row_len + elen]
e_idx = eid_inv[col_len:col_len+elen]
#print(idx,e_idx)
b.srcdata['ID'] = idx
if sample_out[r].delta_ts().shape[0] > 0:
b.edata['dt'] = sample_out[r].delta_ts().to(device)
if src_ts is not None:
b.srcdata['ts'] = src_ts[0:row_len + elen]
b.srcdata['ts'] = src_ts[0:row_len + eid_len[r]]
b.edata['ID'] = e_idx
#print(b.all_edges)
#print(dist_nid[b.srcdata['ID']],dist_nid[b.srcdata['ID'][col[sample_out[r].src_index().to(device)]]])
#print(b.edata['dt'],b.srcdata['ts'])
col = row
col_len += eid_len[r]
row_len += eid_len[r]
......@@ -152,7 +130,6 @@ def to_block(graph: DistributedGraphStore, data, sample_out, mailbox:MailBox = N
#return build_block(node_feat,edge_feat,mem)#data,mfgs,metadata
return (data,mfgs,metadata)
def graph_sample(graph, sampler:BaseSampler,
sample_fn, data,
neg_sampling = None,
......
......@@ -37,6 +37,7 @@ class DistributedGraphStore:
self.sample_graph = pdata.sample_graph
self.nids_mapper = build_mapper(nids=pdata.ids.to(device)).dist.to('cpu')
self.eids_mapper = build_mapper(nids=pdata.eids.to(device)).dist.to('cpu')
self.sample_graph['dist_eid'] = self.eids_mapper[pdata.sample_graph['eids']]
torch.cuda.empty_cache()
self.num_nodes = self.nids_mapper.data.shape[0]
......@@ -255,7 +256,7 @@ class TemporalNeighborSampleGraph(DistributedGraphStore):
self.edge_ts = sample_graph['ts']
else:
self.edge_ts = None
self.eid = torch.arange(self.num_edges,dtype = torch.long, device = sample_graph['eids'].device)
self.eid = sample_graph['eids']#torch.arange(self.num_edges,dtype = torch.long, device = sample_graph['eids'].device)
#sample_graph['eids']
if mode == 'train':
mask = sample_graph['train_mask']
......
......@@ -56,11 +56,11 @@ import dgl
import numpy as np
from sklearn.metrics import average_precision_score, roc_auc_score
from torch.nn.parallel import DistributedDataParallel as DDP
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.rank)
os.environ["RANK"] = str(args.rank)
os.environ["WORLD_SIZE"] = str(args.world_size)
os.environ["LOCAL_RANK"] = str(0)
#torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
#os.environ['CUDA_VISIBLE_DEVICES'] = str(args.rank)
#os.environ["RANK"] = str(args.rank)
#os.environ["WORLD_SIZE"] = str(args.world_size)
#os.environ["LOCAL_RANK"] = str(0)
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
os.environ["MASTER_ADDR"] = '127.0.0.1'
os.environ["MASTER_PORT"] = '9437'
def seed_everything(seed=42):
......@@ -71,7 +71,7 @@ def seed_everything(seed=42):
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
seed_everything(1234)
seed_everything(34)
def main():
print('main')
use_cuda = True
......@@ -80,9 +80,9 @@ def main():
torch.set_num_threads(int(40/torch.distributed.get_world_size()))
device_id = torch.cuda.current_device()
print('use cuda on',device_id)
pdata = partition_load("/mnt/data/part_data/here/{}".format(args.dataname), algo="metis_for_tgnn")
pdata = partition_load("/mnt/data/part_data/v2/here/{}".format(args.dataname), algo="metis_for_tgnn")
graph = DistributedGraphStore(pdata = pdata)
print(graph.num_nodes)
Path("./saved_models/").mkdir(parents=True, exist_ok=True)
Path("./saved_checkpoints/").mkdir(parents=True, exist_ok=True)
get_checkpoint_path = lambda \
......@@ -100,18 +100,15 @@ def main():
fanout = sample_param['neighbor'] if 'neighbor' in sample_param else [10]
policy = sample_param['strategy'] if 'strategy' in sample_param else 'recent'
sampler = NeighborSampler(num_nodes=graph.num_nodes, num_layers=num_layers, fanout=fanout,graph_data=sample_graph, workers=int(40/torch.distributed.get_world_size()),policy = policy, graph_name = "wiki_train")
print(sample_graph.eid.shape,sample_graph.edge_index.shape,sample_graph.edge_ts.shape,graph.edge_attr)
train_data = torch.masked_select(graph.edge_index,pdata.train_mask.to(graph.edge_index.device)).reshape(2,-1)
train_ts = torch.masked_select(graph.edge_ts,pdata.train_mask.to(graph.edge_index.device))
test_data = torch.masked_select(graph.edge_index,pdata.val_mask.to(graph.edge_index.device)).reshape(2,-1)
test_ts = torch.masked_select(graph.edge_ts,pdata.val_mask.to(graph.edge_index.device))
val_data = torch.masked_select(graph.edge_index,pdata.test_mask.to(graph.edge_index.device)).reshape(2,-1)
val_ts = torch.masked_select(graph.edge_ts,pdata.test_mask.to(graph.edge_index.device))
print(train_ts.max(),val_ts.max(),test_ts.max())
#print(train_data.shape[1],val_data.shape[1],test_data.shape[1])
test_data = torch.masked_select(graph.edge_index,pdata.test_mask.to(graph.edge_index.device)).reshape(2,-1)
test_ts = torch.masked_select(graph.edge_ts,pdata.test_mask.to(graph.edge_index.device))
val_data = torch.masked_select(graph.edge_index,pdata.val_mask.to(graph.edge_index.device)).reshape(2,-1)
val_ts = torch.masked_select(graph.edge_ts,pdata.val_mask.to(graph.edge_index.device))
train_data = DataSet(edges = train_data,ts =train_ts,eids = torch.nonzero(pdata.train_mask).view(-1))
test_data = DataSet(edges = test_data,ts =test_ts,eids = torch.nonzero(pdata.val_mask).view(-1))
val_data = DataSet(edges = val_data,ts = val_ts,eids = torch.nonzero(pdata.test_mask).view(-1))
test_data = DataSet(edges = test_data,ts =test_ts,eids = torch.nonzero(pdata.test_mask).view(-1))
val_data = DataSet(edges = val_data,ts = val_ts,eids = torch.nonzero(pdata.val_mask).view(-1))
neg_sampler = NegativeSampling('triplet')
trainloader = DistributedDataLoader(graph,train_data,sampler = sampler,
sampler_fn = SAMPLE_TYPE.SAMPLE_FROM_TEMPORAL_EDGES,
......@@ -148,7 +145,6 @@ def main():
gnn_dim_node = 0 if graph.x is None else pdata.x.shape[1]
gnn_dim_edge = 0 if graph.edge_attr is None else pdata.edge_attr.shape[1]
print(gnn_dim_node,gnn_dim_edge)
avg_time = 0
if use_cuda:
model = GeneralModel(gnn_dim_node, gnn_dim_edge, sample_param, memory_param, gnn_param, train_param).cuda()
......@@ -299,10 +295,6 @@ def main():
train_ap = float(torch.tensor(train_aps).mean())
ap = 0
auc = 0
#if cache.edge_cache is not None:
# print('hit {}'.format(cache.edge_cache.hit_/ cache.edge_cache.hit_sum))
#if cache.node_cache is not None:
# print('hit {}'.format(cache.node_cache.hit_/ cache.node_cache.hit_sum))
ap, auc = eval('val')
early_stop = early_stopper.early_stop_check(ap)
if early_stop:
......@@ -315,8 +307,12 @@ def main():
print('\ttrain loss:{:.4f} train ap:{:4f} val ap:{:4f} val auc:{:4f}\n'.format(total_loss,train_ap, ap, auc))
print('\ttotal time:{:.2f}s prep time:{:.2f}s\n'.format(time.time()-epoch_start_time, time_prep))
print('\t fetch time:{:.2f}s write back time:{:.2f}s\n'.format(fetch_time,write_back_time))
#torch.save(model.state_dict(), get_checkpoint_path(e))
torch.save(model.state_dict(), get_checkpoint_path(e))
if not early_stop:
print(f"Loading the best model at epoch {early_stopper.best_epoch}")
best_model_path = get_checkpoint_path(early_stopper.best_epoch)
model.load_state_dict(torch.load(best_model_path))
model.eval()
if mailbox is not None:
mailbox.reset()
......
from .partition import *
from .uvm import *
\ No newline at end of file
import torch
import starrygl
from torch import Tensor
from torch_sparse import SparseTensor
from typing import *
__all__ = [
"metis_partition",
"mt_metis_partition",
"random_partition",
"pyg_metis_partition"
]
def _nopart(edge_index: Tensor, num_nodes: int):
return torch.zeros(num_nodes).type_as(edge_index)
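# _nopart is the degenerate fallback: with a single partition every node is assigned
# partition id 0, returned with the same integer dtype/device as edge_index.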
def metis_partition(
edge_index: Tensor,
num_nodes: int,
num_parts: int,
node_weight: Optional[Tensor] = None,
edge_weight: Optional[Tensor] = None,
node_sizes: Optional[Tensor] = None,
recursive: bool = False,
min_edge_cut: bool = False,
) -> Tensor:
if num_parts <= 1:
return _nopart(edge_index, num_nodes)
adj_t = SparseTensor.from_edge_index(edge_index, edge_weight, sparse_sizes=(num_nodes, num_nodes))
rowptr, col, value = adj_t.coalesce().to_symmetric().csr()
node_parts = starrygl.ops.metis_partition(
rowptr, col, value, node_weight, node_sizes, num_parts, recursive, min_edge_cut)
return node_parts
def pyg_metis_partition(
edge_index: Tensor,
num_nodes: int,
num_parts: int,
) -> Tensor:
if num_parts <= 1:
return _nopart(edge_index, num_nodes)
adj_t = SparseTensor.from_edge_index(edge_index, sparse_sizes=(num_nodes, num_nodes))
rowptr, col, _ = adj_t.coalesce().to_symmetric().csr()
node_parts = torch.ops.torch_sparse.partition(rowptr, col, None, num_parts, num_parts < 8)
return node_parts
def mt_metis_partition(
edge_index: Tensor,
num_nodes: int,
num_parts: int,
node_weight: Optional[Tensor] = None,
edge_weight: Optional[Tensor] = None,
num_workers: int = 8,
recursive: bool = False,
) -> Tensor:
if num_parts <= 1:
return _nopart(edge_index, num_nodes)
adj_t = SparseTensor.from_edge_index(edge_index, edge_weight, sparse_sizes=(num_nodes, num_nodes))
rowptr, col, value = adj_t.coalesce().to_symmetric().csr()
node_parts = starrygl.ops.mt_metis_partition(
rowptr, col, value, node_weight, num_parts, num_workers, recursive)
return node_parts
def random_partition(edge_index: Tensor, num_nodes: int, num_parts: int) -> Tensor:
if num_parts <= 1:
return _nopart(edge_index, num_nodes)
return torch.randint(num_parts, size=(num_nodes,)).type_as(edge_index)
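# Illustrative usage sketch (hypothetical toy graph; pyg_metis_partition only relies on
# torch_sparse, while metis_partition/mt_metis_partition call into starrygl.ops):
#   edge_index = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 0]])
#   parts = pyg_metis_partition(edge_index, num_nodes=4, num_parts=2)
#   # parts is a LongTensor of length 4 holding each node's partition id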