Commit 21c40ac0 by xxx

fix some bugs in batch_data.py and get_update_mail

parent 023440c3
import torch
import logging
__version__ = "0.1.0"
try:
from .lib import libstarrygl as ops
except Exception as e:
logging.error(e)
logging.error("unable to import libstarrygl.so, some features may not be available.")
try:
from .lib import libstarrygl_sampler as sampler_ops
except Exception as e:
logging.error(e)
logging.error("unable to import libstarrygl_sampler.so, some features may not be available.")
import torch
from torch import Tensor
from contextlib import contextmanager
from typing import *
__all__ = [
"ABCStream",
"ABCEvent",
"phony_tensor",
"new_stream",
"current_stream",
"default_stream",
"use_stream",
"use_device",
"wait_stream",
"wait_event",
"record_stream",
]
class CPUStreamType:
def __init__(self) -> None:
self._device = torch.device("cpu")
@property
def device(self):
return self._device
def __call__(self):
return self
class CPUEventType:
def __init__(self) -> None:
self._device = torch.device("cpu")
@property
def device(self):
return self._device
def __call__(self):
return self
CPUStream = CPUStreamType()
ABCStream = Union[torch.cuda.Stream, CPUStreamType]
CPUEvent = CPUEventType()
ABCEvent = Union[torch.cuda.Event, CPUEventType]
def new_stream(device: Any) -> ABCStream:
device = torch.device(device)
if device.type != "cuda":
return CPUStream()
return torch.cuda.Stream(device)
_phonies: Dict[Tuple[torch.device, bool], Tensor] = {}
def phony_tensor(device: Any, requires_grad: bool = True):
device = torch.device(device)
key = (device, requires_grad)
if key not in _phonies:
with use_stream(default_stream(device)):
_phonies[key] = torch.empty(
0, device=device,
requires_grad=requires_grad,
)
return _phonies[key]
def current_stream(device: Any) -> ABCStream:
device = torch.device(device)
if device.type != "cuda":
return CPUStream()
return torch.cuda.current_stream(device)
def default_stream(device: Any) -> ABCStream:
device = torch.device(device)
if device.type != "cuda":
return CPUStream()
return torch.cuda.default_stream(device)
@contextmanager
def use_stream(stream: ABCStream, fence_event: bool = False):
if isinstance(stream, CPUStreamType):
if fence_event:
event = CPUEvent()
yield event
else:
yield
return
with torch.cuda.stream(stream):
if fence_event:
event = torch.cuda.Event()
yield event
event.record()
else:
yield
@contextmanager
def use_device(device: Any):
device = torch.device(device)
if device.type != "cuda":
yield
return
with torch.cuda.device(device):
yield
def wait_stream(source: ABCStream, target: ABCStream):
if isinstance(target, CPUStreamType):
return
if isinstance(source, CPUStreamType):
target.synchronize()
else:
source.wait_stream(target)
def wait_event(source: ABCStream, target: ABCEvent):
if isinstance(target, CPUEventType):
return
if isinstance(source, CPUStreamType):
target.synchronize()
else:
source.wait_event(target)
def record_stream(tensor: Tensor, stream: ABCStream):
if isinstance(stream, CPUStreamType):
return
storage = tensor.untyped_storage()
tensor = tensor.new_empty(0).set_(storage)
tensor.record_stream(stream)
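# Illustrative usage sketch of the helpers above (comment only, not executed on import).
# It assumes a CUDA device and a hypothetical pinned host tensor `payload`; on CPU-only
# builds the same calls degrade to no-ops via CPUStream.
#
#   device = torch.device("cuda:0")
#   copy_stream = new_stream(device)
#   with use_stream(copy_stream):
#       staged = payload.to(device, non_blocking=True)
#       record_stream(staged, copy_stream)
#   # make the current stream wait for the copy stream before consuming `staged`
#   wait_stream(current_stream(device), copy_stream)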
from .graph import *
from .utils import *
import torch
from torch import Tensor
from typing import *
import shutil
from pathlib import Path
from torch_sparse import SparseTensor
from starrygl.utils.partition import *
from starrygl.parallel.route import Route
from starrygl.parallel.sparse import SparseBlocks
from .utils import init_vc_edge_index
import logging
__all__ = [
"GraphData",
]
Strings = Sequence[str]
OptStrings = Optional[Strings]
class GraphData:
def __init__(self,
edge_indices: Union[Tensor, Dict[Tuple[str, str, str], Tensor]],
num_nodes: Union[int, Dict[str, int]],
) -> None:
if isinstance(edge_indices, Tensor):
self._heterogeneous = False
edge_indices = {("#", "@", "#"): edge_indices}
num_nodes = {"#": int(num_nodes)}
else:
self._heterogeneous = True
self._num_nodes: Dict[str, int] = {}
self._node_data: Dict[str, 'NodeData'] = {}
for ntype, num in num_nodes.items():
ntype, num = str(ntype), int(num)
self._num_nodes[ntype] = num
self._node_data[ntype] = NodeData(ntype, num)
self._edge_indices: Dict[Tuple[str, str, str], Tensor] = {}
self._edge_data: Dict[Tuple[str, str, str], 'EdgeData'] = {}
for (es, et, ed), edge_index in edge_indices.items():
assert isinstance(edge_index, Tensor), f"edge_index must be a tensor, got {type(edge_index)}"
assert edge_index.dim() == 2 and edge_index.size(0) == 2
es, et, ed = str(es), str(et), str(ed)
assert es in self._num_nodes, f"unknown node type '{es}', should be one of {list(self._num_nodes.keys())}."
assert ed in self._num_nodes, f"unknown node type '{ed}', should be one of {list(self._num_nodes.keys())}."
etype = (es, et, ed)
self._edge_indices[etype] = edge_index
self._edge_data[etype] = EdgeData(etype, edge_index.size(1))
self._meta = MetaData()
def meta(self) -> 'MetaData':
return self._meta
def node(self, node_type: Optional[str] = None) -> 'NodeData':
if len(self._node_data) == 1:
for data in self._node_data.values():
return data
return self._node_data[node_type]
def edge(self, edge_type: Optional[Tuple[str, str, str]] = None) -> 'EdgeData':
if len(self._edge_data) == 1:
for data in self._edge_data.values():
return data
return self._edge_data[edge_type]
def edge_index(self, edge_type: Optional[Tuple[str, str, str]] = None) -> Tensor:
if len(self._edge_indices) == 1:
for data in self._edge_indices.values():
return data
return self._edge_indices[edge_type]
def node_types(self) -> List[str]:
return list(self._node_data.keys())
def edge_types(self) -> List[Tuple[str, str, str]]:
return list(self._edge_data.keys())
def to_route(self, group: Any = None) -> Route:
src_ids = self.node("src")["raw_ids"]
dst_ids = self.node("dst")["raw_ids"]
return Route.from_raw_indices(src_ids, dst_ids, group=group)
def to_sparse(self, key: Optional[str] = None, group: Any = None) -> SparseBlocks:
src_ids = self.node("src")["raw_ids"]
dst_ids = self.node("dst")["raw_ids"]
edge_index = self.edge_index()
edge_index = torch.vstack([
src_ids[edge_index[0]],
dst_ids[edge_index[1]],
])
edge_attr = None if key is None else self.edge()[key]
return SparseBlocks.from_raw_indices(dst_ids, edge_index, edge_attr=edge_attr, group=group)
@property
def is_heterogeneous(self) -> bool:
return self._heterogeneous
def to(self, device: Any) -> 'GraphData':
self._meta.to(device)
for ndata in self._node_data.values():
ndata.to(device)
for edata in self._edge_data.values():
edata.to(device)
self._edge_indices = {k:v.to(device) for k,v in self._edge_indices.items()}
return self
@staticmethod
def from_bipartite(
edge_index: Tensor,
num_src_nodes: Optional[int] = None,
num_dst_nodes: Optional[int] = None,
raw_src_ids: Optional[Tensor] = None,
raw_dst_ids: Optional[Tensor] = None,
) -> 'GraphData':
if num_src_nodes is None:
num_src_nodes = raw_src_ids.numel()
if num_dst_nodes is None:
num_dst_nodes = raw_dst_ids.numel()
g = GraphData(
edge_indices={
("src", "@", "dst"): edge_index,
},
num_nodes={
"src": num_src_nodes,
"dst": num_dst_nodes,
}
)
if raw_src_ids is not None:
g.node("src")["raw_ids"] = raw_src_ids
if raw_dst_ids is not None:
g.node("dst")["raw_ids"] = raw_dst_ids
return g
@staticmethod
def from_pyg_data(data) -> 'GraphData':
from torch_geometric.data import Data
assert isinstance(data, Data), "data must be an instance of torch_geometric.data.Data"
g = GraphData(data.edge_index, data.num_nodes)
for key, val in data:
if key == "edge_index":
continue
elif isinstance(val, Tensor):
if val.size(0) == data.num_nodes:
g.node()[key] = val
elif val.size(0) == data.num_edges:
g.edge()[key] = val
elif isinstance(val, SparseTensor):
logging.warning(f"found sparse matrix {key}, but ignored.")
else:
g.meta()[key] = val
return g
@staticmethod
def load_partition(
root: str,
part_id: int,
num_parts: int,
algorithm: str = "metis",
) -> 'GraphData':
p = Path(root).expanduser().resolve() / f"{algorithm}_{num_parts}" / f"{part_id:03d}"
return torch.load(p.__str__())
def save_partition(self,
root: str,
num_parts: int,
node_weight: Optional[str] = None,
edge_weight: Optional[str] = None,
include_node_attrs: Optional[Sequence[str]] = None,
include_edge_attrs: Optional[Sequence[str]] = None,
include_meta_attrs: Optional[Sequence[str]] = None,
ignore_node_attrs: Optional[Sequence[str]] = None,
ignore_edge_attrs: Optional[Sequence[str]] = None,
ignore_meta_attrs: Optional[Sequence[str]] = None,
algorithm: str = "metis",
partition_kwargs = None,
):
assert not self.is_heterogeneous, "only homogeneous graphs are supported"
num_nodes: int = self.node().num_nodes
edge_index: Tensor = self.edge_index()
logging.info(f"running partition aglorithm: {algorithm}")
partition_kwargs = partition_kwargs or {}
not_self_loop = (edge_index[0] != edge_index[1])
if node_weight is not None:
node_weight = self.node()[node_weight]
if edge_weight is not None:
edge_weight = self.edge()[edge_weight]
edge_weight = edge_weight[not_self_loop]
if algorithm == "metis":
node_parts = metis_partition(
edge_index[:,not_self_loop],
num_nodes, num_parts,
node_weight=node_weight,
edge_weight=edge_weight,
**partition_kwargs,
)
elif algorithm == "mt-metis":
node_parts = mt_metis_partition(
edge_index[:,not_self_loop],
num_nodes, num_parts,
node_weight=node_weight,
edge_weight=edge_weight,
**partition_kwargs,
)
elif algorithm == "random":
node_parts = random_partition(
edge_index[:,not_self_loop],
num_nodes, num_parts,
**partition_kwargs,
)
elif algorithm == "pyg-metis":
node_parts = pyg_metis_partition(
edge_index[:,not_self_loop],
num_nodes, num_parts,
)
else:
raise ValueError(f"unknown partition algorithm: {algorithm}")
root_path = Path(root).expanduser().resolve()
base_path = root_path / f"{algorithm}_{num_parts}"
if base_path.exists():
logging.warning(f"directory '{base_path.__str__()}' exists, and will be removed.")
shutil.rmtree(base_path.__str__())
base_path.mkdir(parents=True)
if include_node_attrs is None:
include_node_attrs = self.node().keys()
if include_edge_attrs is None:
include_edge_attrs = self.edge().keys()
if include_meta_attrs is None:
include_meta_attrs = self.meta().keys()
if ignore_node_attrs is None:
ignore_node_attrs = set()
else:
ignore_node_attrs = set(ignore_node_attrs)
if ignore_edge_attrs is None:
ignore_edge_attrs = set()
else:
ignore_edge_attrs = set(ignore_edge_attrs)
if ignore_meta_attrs is None:
ignore_meta_attrs = set()
else:
ignore_meta_attrs = set(ignore_meta_attrs)
for i in range(num_parts):
npart_mask = node_parts == i
epart_mask = npart_mask[edge_index[1]]
raw_dst_ids: Tensor = torch.where(npart_mask)[0]
local_edges = edge_index[:, epart_mask]
raw_src_ids, local_edges = init_vc_edge_index(
raw_dst_ids, local_edges, bipartite=True,
)
g = GraphData.from_bipartite(
local_edges,
raw_src_ids=raw_src_ids,
raw_dst_ids=raw_dst_ids,
)
for key in include_node_attrs:
if key in ignore_node_attrs:
continue
g.node("dst")[key] = self.node()[key][npart_mask]
for key in include_edge_attrs:
if key in ignore_edge_attrs:
continue
g.edge()[key] = self.edge()[key][epart_mask]
for key in include_meta_attrs:
if key in ignore_meta_attrs:
continue
g.meta()[key] = self.meta()[key]
logging.info(f"saving partition data: {i+1}/{num_parts}")
torch.save(g, (base_path / f"{i:03d}").__str__())
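# Illustrative partition round-trip (a sketch; the tiny graph below is made up).
# save_partition writes one file per part under <root>/<algorithm>_<num_parts>/ and
# load_partition reads a single part back.
#
#   edge_index = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 0]])
#   g = GraphData(edge_index, num_nodes=4)
#   g.save_partition("./parts", num_parts=2, algorithm="random")
#   part0 = GraphData.load_partition("./parts", part_id=0, num_parts=2, algorithm="random")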
class MetaData:
def __init__(self) -> None:
self._data: Dict[str, Any] = {}
def keys(self) -> List[str]:
return list(self._data.keys())
def __getitem__(self, key: str) -> Any:
return self._data[key]
def __setitem__(self, key: str, val: Any):
assert isinstance(key, str)
self._data[key] = val
def pop(self, key: str) -> Tensor:
if key in self._data:
return self._data.pop(key)
def to(self, device: Any) -> 'MetaData':
for k in self.keys():
v = self._data[k]
if isinstance(v, Tensor):
self._data[k] = v.to(device)
return self
class NodeData:
def __init__(self,
node_type: str,
num_nodes: int,
) -> None:
self._node_type = str(node_type)
self._num_nodes = int(num_nodes)
self._data: Dict[str, Tensor] = {}
@property
def node_type(self) -> str:
return self._node_type
@property
def num_nodes(self) -> int:
return self._num_nodes
def keys(self) -> List[str]:
return list(self._data.keys())
def __getitem__(self, key: str) -> Tensor:
return self._data[key]
def __setitem__(self, key: str, val: Tensor):
assert isinstance(key, str)
assert val.size(0) == self._num_nodes
self._data[key] = val
def pop(self, key: str) -> Tensor:
if key in self._data:
return self._data.pop(key)
def to(self, device: Any) -> 'NodeData':
self._data = {k:v.to(device) for k,v in self._data.items()}
return self
class EdgeData:
def __init__(self,
edge_type: Tuple[str, str, str],
num_edges: int,
) -> None:
self._edge_type = tuple(str(t) for t in edge_type)
self._num_edges = num_edges
assert len(self._edge_type) == 3
self._data: Dict[str, Tensor] = {}
@property
def edge_type(self) -> Tuple[str, str, str]:
return self._edge_type
@property
def num_edges(self) -> int:
return self._num_edges
def keys(self) -> List[str]:
return list(self._data.keys())
def __getitem__(self, key: str) -> Tensor:
return self._data[key]
def __setitem__(self, key: str, val: Tensor):
assert isinstance(key, str)
assert val.size(0) == self._num_edges
self._data[key] = val
def pop(self, key: str) -> Tensor:
if key in self._data:
return self._data.pop(key)
def to(self, device: Any) -> 'EdgeData':
self._data = {k:v.to(device) for k,v in self._data.items()}
return self
import torch
from torch import Tensor
from typing import *
__all__ = [
"init_vc_edge_index",
]
def init_vc_edge_index(
dst_ids: Tensor,
edge_index: Tensor,
bipartite: bool = True,
) -> Tuple[Tensor, Tensor]:
ikw = dict(dtype=torch.long, device=dst_ids.device)
local_num_nodes = torch.zeros(1, **ikw)
if dst_ids.numel() > 0:
local_num_nodes = dst_ids.max().max(local_num_nodes)
if edge_index.numel() > 0:
local_num_nodes = edge_index.max().max(local_num_nodes)
local_num_nodes = local_num_nodes.item() + 1
xmp: Tensor = torch.zeros(local_num_nodes, **ikw)
xmp[edge_index[1].unique()] += 0b01
xmp[dst_ids.unique()] += 0b10
if not (xmp != 0x01).all():
raise RuntimeError(f"must be vertex-cut partition graph")
if bipartite:
src_ids = edge_index[0].unique()
else:
xmp.fill_(0)
xmp[edge_index[0]] = 1
xmp[dst_ids] = 0
src_ids = torch.cat([dst_ids, torch.where(xmp > 0)[0]], dim=-1)
xmp.fill_((2**62-1)*2+1)
xmp[src_ids] = torch.arange(src_ids.size(0), **ikw)
src = xmp[edge_index[0]]
xmp.fill_((2**62-1)*2+1)
xmp[dst_ids] = torch.arange(dst_ids.size(0), **ikw)
dst = xmp[edge_index[1]]
local_edge_index = torch.vstack([src, dst])
return src_ids, local_edge_index
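# Worked example (illustrative): with dst_ids = tensor([2, 5]) and global edges
# edge_index = tensor([[7, 2], [2, 5]]) (row 0 = sources, row 1 = destinations),
# the bipartite branch returns the unique sources as src_ids and relabels the edges
# so that sources index into src_ids and destinations index into dst_ids:
#
#   src_ids, local_edge_index = init_vc_edge_index(torch.tensor([2, 5]), torch.tensor([[7, 2], [2, 5]]))
#   src_ids          # tensor([2, 7])
#   local_edge_index # tensor([[1, 0], [0, 1]])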
import torch
import torch.nn as nn
import torch.distributed as dist
import torch.distributed.rpc as rpc
import os
from torch import Tensor
from typing import *
from .context import DistributedContext
from .utils import (
TensorAccessor,
DistributedTensor,
DistIndex,
)
def init_distributed_context(backend: str = "gloo") -> DistributedContext:
return DistributedContext.init(backend)
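# Minimal usage sketch (assumes the process was launched by torchrun or another launcher
# that sets RANK, WORLD_SIZE, MASTER_ADDR and MASTER_PORT in the environment):
#
#   ctx = init_distributed_context(backend="gloo")
#   ctx.main_print("world size:", ctx.world_size)
#   ctx.shutdown()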
import torch
import torch.distributed as dist
from torch import Tensor
from typing import *
__all__ = [
"all_to_all_v",
"all_to_all_s",
"BatchWork",
"batch_send",
"batch_recv",
]
class BatchWork:
def __init__(self,
works: Optional[List[Any]],
buffer_tensor_list: Optional[List[Tuple[Tensor, Optional[Tensor]]]],
step: int = 1,
) -> None:
if works is None:
self._step = None
self._works = None
self._buffer_tensor_list = None
else:
if buffer_tensor_list:
assert len(works) // step == len(buffer_tensor_list)
self._step = step
self._works = works
self._buffer_tensor_list = buffer_tensor_list
def wait(self):
if self._works is None:
return
for i, w in enumerate(self._works):
if w is not None:
w.wait()
if (i + 1) % self._step != 0:
continue
if self._buffer_tensor_list:
out, buf = self._buffer_tensor_list[i // self._step]
if buf is not None:
out.copy_(buf)
self._step = None
self._works = None
self._buffer_tensor_list = None
def all_to_all_v(
output_tensor_list: List[Tensor],
input_tensor_list: List[Tensor],
group: Optional[Any] = None,
async_op: bool = False,
):
rank = dist.get_rank(group)
world_size = dist.get_world_size(group)
assert len(output_tensor_list) == world_size
assert len(input_tensor_list) == world_size
backend = dist.get_backend(group)
if backend == "nccl":
work = dist.all_to_all(
output_tensor_list=output_tensor_list,
input_tensor_list=input_tensor_list,
group=group,
async_op=async_op,
)
return BatchWork([work], None) if async_op else None
elif backend == "mpi":
work = dist.all_to_all(
output_tensor_list=output_tensor_list,
input_tensor_list=input_tensor_list,
group=group,
async_op=async_op,
)
return BatchWork([work], None) if async_op else None
else:
assert backend == "gloo", f"backend must be nccl, mpi or gloo"
p2p_op_works = []
buffer_tensor_list = []
for i in range(1, world_size):
send_i = (rank + i) % world_size
recv_i = (rank - i + world_size) % world_size
send_t = input_tensor_list[send_i]
recv_t = output_tensor_list[recv_i]
if send_t.is_cuda:
send_t = send_t.cpu()
if recv_t.is_cuda:
recv_b = torch.empty_like(recv_t, device="cpu")
buffer_tensor_list.append((recv_t, recv_b))
else:
recv_b = recv_t
buffer_tensor_list.append((recv_t, None))
p2p_op_works.extend([
dist.isend(send_t, send_i, group=group),
dist.irecv(recv_b, recv_i, group=group),
])
work = BatchWork(p2p_op_works, buffer_tensor_list, 2)
output_tensor_list[rank].copy_(input_tensor_list[rank])
if async_op:
return work
work.wait()
def all_to_all_s(
output_tensor: Tensor,
input_tensor: Tensor,
output_rowptr: List[int],
input_rowptr: List[int],
group: Optional[Any] = None,
async_op: bool = False,
):
# rank = dist.get_rank(group)
world_size = dist.get_world_size(group)
assert len(output_rowptr) == len(input_rowptr)
assert len(output_rowptr) == world_size + 1
output_sizes = [t-s for s, t in zip(output_rowptr, output_rowptr[1:])]
input_sizes = [t-s for s, t in zip(input_rowptr, input_rowptr[1:])]
return dist.all_to_all_single(
output=output_tensor,
input=input_tensor,
output_split_sizes=output_sizes,
input_split_sizes=input_sizes,
group=group,
async_op=async_op,
)
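# Illustrative call (a sketch; the sizes and tensors are made up, and a backend that
# supports all_to_all_single, e.g. nccl or mpi, is assumed). With world_size = 2,
# output_rowptr = [0, 3, 5] means this rank receives 3 rows from rank 0 and 2 rows from
# rank 1, while input_rowptr = [0, 2, 6] means it sends 2 rows to rank 0 and 4 to rank 1:
#
#   out = torch.empty(5, 16)
#   inp = torch.randn(6, 16)
#   all_to_all_s(out, inp, output_rowptr=[0, 3, 5], input_rowptr=[0, 2, 6])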
def batch_send(
*tensors: Tensor,
dst: int,
group: Any = None,
async_op: bool = False,
):
if len(tensors) == 0:
return BatchWork(None, None)
if group is None:
group = dist.GroupMember.WORLD
# tensors = tuple(t.data for t in tensors)
backend = dist.get_backend(group)
dst = dist.get_global_rank(group, dst)
if async_op:
works = []
for t in tensors:
if backend == "gloo" and t.is_cuda:
t = t.cpu()
works.append(dist.isend(t, dst=dst, group=group))
return BatchWork(works, None)
else:
for t in tensors:
if backend == "gloo" and t.is_cuda:
t = t.cpu()
dist.send(t, dst=dst, group=group)
def batch_recv(
*tensors: Tensor,
src: int,
group: Any = None,
async_op: bool = False,
):
if len(tensors) == 0:
return BatchWork(None, None)
if group is None:
group = dist.GroupMember.WORLD
# tensors = tuple(t.data for t in tensors)
backend = dist.get_backend(group)
src = dist.get_global_rank(group, src)
if async_op:
works = []
output_tensor_list = []
for t in tensors:
if backend == "gloo" and t.is_cuda:
b = torch.empty_like(t, device="cpu")
works.append(dist.irecv(b, src=src, group=group))
else:
b = None
works.append(dist.irecv(t, src=src, group=group))
output_tensor_list.append((t, b))
return BatchWork(works, output_tensor_list, 1)
else:
for t in tensors:
if backend == "gloo" and t.is_cuda:
b = torch.empty_like(t, device="cpu")
dist.recv(b, src=src, group=group)
t.copy_(b)
else:
dist.recv(t, src=src, group=group)
import torch
import torch.nn as nn
import torch.distributed as dist
import torch.distributed.rpc as rpc
import os
from torch import Tensor
from typing import *
import socket
from contextlib import contextmanager
import logging
from .rpc import *
__all__ = [
"DistributedContext",
]
class DistributedContext:
"""Global context manager for distributed training
"""
@staticmethod
def init(
backend: str,
use_rpc: bool = False,
use_gpu: Optional[bool] = None,
rpc_gpu: Optional[bool] = None,
) -> 'DistributedContext':
if DistributedContext.is_initialized():
raise RuntimeError("not allowed to call init method twice.")
rank = int(os.getenv("RANK") or os.getenv("OMPI_COMM_WORLD_RANK"))
world_size = int(os.getenv("WORLD_SIZE") or os.getenv("OMPI_COMM_WORLD_SIZE"))
local_rank = os.getenv("LOCAL_RANK") or os.getenv("OMPI_COMM_WORLD_LOCAL_RANK")
if local_rank is None:
logging.warning(f"LOCAL_RANK has not been set, using the default value 0.")
os.environ["LOCAL_RANK"] = local_rank = "0"
local_rank = int(local_rank)
backend = backend.lower()
if use_gpu is None:
use_gpu = False
if backend == "nccl" or backend == "mpi":
use_gpu = True
else:
use_gpu = bool(use_gpu)
if rpc_gpu is None:
rpc_gpu = use_gpu
else:
rpc_gpu = bool(rpc_gpu)
master_addr = os.environ["MASTER_ADDR"]
master_port = int(os.environ["MASTER_PORT"])
ccl_init_url = f"tcp://{master_addr}:{master_port}"
rpc_init_url = f"tcp://{master_addr}:{master_port + 1}"
ctx = DistributedContext(
backend=backend,
ccl_init_method=ccl_init_url,
rpc_init_method=rpc_init_url,
rank=rank, world_size=world_size,
local_rank=local_rank,
use_rpc=use_rpc,
use_gpu=use_gpu,
rpc_gpu=rpc_gpu,
)
_set_default_dist_context(ctx)
return ctx
@staticmethod
def get_default_context() -> 'DistributedContext':
ctx = _get_default_dist_context()
if ctx is None:
raise RuntimeError("please call the init method first.")
return ctx
@staticmethod
def is_initialized() -> bool:
return _get_default_dist_context() is not None
def __init__(self,
backend: str,
ccl_init_method: str,
rpc_init_method: str,
rank: int, world_size: int,
local_rank: int,
use_rpc: bool,
use_gpu: bool,
rpc_gpu: bool,
) -> None:
if use_gpu:
device = torch.device(f"cuda:{local_rank}")
else:
device = torch.device("cpu")
dist.init_process_group(
backend=backend,
init_method=ccl_init_method,
rank=rank, world_size=world_size,
)
rpc_backend_options = rpc.TensorPipeRpcBackendOptions()
rpc_backend_options.init_method = rpc_init_method
if use_gpu and rpc_gpu:
device_maps = torch.zeros(world_size, dtype=torch.long, device=device)
device_maps[rank] = local_rank
dist.all_reduce(device_maps, op=dist.ReduceOp.SUM)
for i, dev in enumerate(device_maps.tolist()):
rpc_backend_options.set_device_map(
to=f"worker{i}",
device_map={local_rank: dev},
)
if use_rpc:
rpc.init_rpc(
name=f"worker{rank}",
rank=rank, world_size=world_size,
rpc_backend_options=rpc_backend_options,
)
self._use_rpc = use_rpc
self._local_rank = local_rank
self._compute_device = device
self._hostname = socket.gethostname()
if self.device.type == "cuda":
torch.cuda.set_device(self.device)
rank_to_host = [None] * self.world_size
dist.all_gather_object(rank_to_host, (self.hostname, self.local_rank))
self._rank_to_host: Tuple[Tuple[str, int], ...] = tuple(rank_to_host)
host_names = sorted({h for h, _ in self.rank_to_host})
self._host_index: Dict[str, int] = {h: i for i, h in enumerate(host_names)}
self.__temp_ag_remote_object: Optional[rpc.RRef] = None
def shutdown(self):
if self._use_rpc:
rpc.shutdown()
@property
def rank(self) -> int:
return dist.get_rank()
@property
def world_size(self) -> int:
return dist.get_world_size()
@property
def local_rank(self) -> int:
return self._local_rank
@property
def hostname(self) -> str:
return self._hostname
@property
def rank_to_host(self):
return self._rank_to_host
@property
def host_index(self):
return self._host_index
@property
def device(self) -> torch.device:
return self._compute_device
def get_default_group(self):
# return dist.distributed_c10d._get_default_group()
return dist.GroupMember.WORLD
def get_default_store(self):
return dist.distributed_c10d._get_default_store()
def get_ranks_by_host(self, hostname: Optional[str] = None) -> Tuple[int,...]:
if hostname is None:
hostname = self.hostname
ranks: List[int] = []
for i, (h, r) in enumerate(self.rank_to_host):
if h == hostname:
ranks.append(i)
ranks.sort()
return tuple(ranks)
def get_ranks_by_local(self, local_rank: Optional[int] = None) -> Tuple[int,...]:
if local_rank is None:
local_rank = self.local_rank
ranks: List[Tuple[int, str]] = []
for i, (h, r) in enumerate(self.rank_to_host):
if r == local_rank:
ranks.append((i, h))
ranks.sort(key=lambda x: self.host_index[x[1]])
return tuple(i for i, h in ranks)
def get_hybrid_matrix(self) -> Tensor:
hosts = sorted(self.host_index.items(), key=lambda x: x[1])
matrix = []
for h, _ in hosts:
rs = self.get_ranks_by_host(h)
matrix.append(rs)
return torch.tensor(matrix, dtype=torch.long, device="cpu")
def new_hybrid_subgroups(self,
matrix: Optional[Tensor] = None,
backend: Any = None,
) -> Tuple[Any, Any]:
if matrix is None:
matrix = self.get_hybrid_matrix()
assert matrix.dim() == 2
row_group = None
col_group = None
for row in matrix.tolist():
if self.rank in row:
row_group = dist.new_group(
row, backend=backend,
use_local_synchronization=True,
)
break
for col in matrix.t().tolist():
if self.rank in col:
col_group = dist.new_group(
col, backend=backend,
use_local_synchronization=True,
)
break
assert row_group is not None
assert col_group is not None
return row_group, col_group
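# Illustrative shape of the hybrid matrix (made-up ranks): with 2 hosts and 2 ranks per
# host, get_hybrid_matrix() returns
#   tensor([[0, 1],
#           [2, 3]])
# and new_hybrid_subgroups() builds one intra-host group per row ({0, 1}, {2, 3}) and one
# inter-host group per column ({0, 2}, {1, 3}), returning the pair that this rank belongs to.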
def get_worker_info(self, rank: Optional[int] = None) -> rpc.WorkerInfo:
rank = dist.get_rank() if rank is None else rank
return rpc.get_worker_info(f"worker{rank}")
def remote_call(self, method, rref: rpc.RRef, *args, **kwargs):
return rpc_remote_call(method, rref, *args, **kwargs)
def remote_void_call(self, method, rref: rpc.RRef, *args, **kwargs):
return rpc_remote_void_call(method, rref, *args, **kwargs)
def remote_exec(self, method, rref: rpc.RRef, *args, **kwargs):
return rpc_remote_exec(method, rref, *args, **kwargs)
@contextmanager
def use_stream(self, stream: torch.cuda.Stream, with_event: bool = True):
event = torch.cuda.Event() if with_event else None
stream.wait_stream(torch.cuda.current_stream(self.device))
with torch.cuda.stream(stream):
yield event
if with_event:
# record on the side stream itself, not on whatever stream is current after exiting
event.record(stream)
def all_gather_remote_objects(self, obj: Any) -> List[rpc.RRef]:
if not isinstance(obj, rpc.RRef):
obj = rpc.RRef(obj)
self.__temp_ag_remote_object = obj
dist.barrier()
futs: List[torch.futures.Future] = []
for i in range(self.world_size):
info = rpc.get_worker_info(f"worker{i}")
futs.append(rpc.rpc_async(info, DistributedContext._remote_object))
rrefs: List[rpc.RRef] = []
for f in futs:
f.wait()
rrefs.append(f.value())
dist.barrier()
self.__temp_ag_remote_object = None
return rrefs
@staticmethod
def _remote_object():
ctx = DistributedContext.get_default_context()
return ctx.__temp_ag_remote_object
def sync_print(self, *args, **kwargs):
for i in range(self.world_size):
if i == self.rank:
print(f"rank {self.rank}:", *args, **kwargs)
dist.barrier()
def main_print(self, *args, **kwargs):
if self.rank == 0:
print(*args, **kwargs)
dist.barrier()
_DEFAULT_DIST_CONTEXT: Optional['DistributedContext'] = None
def _get_default_dist_context():
global _DEFAULT_DIST_CONTEXT
return _DEFAULT_DIST_CONTEXT
def _set_default_dist_context(ctx):
global _DEFAULT_DIST_CONTEXT
_DEFAULT_DIST_CONTEXT = ctx
import torch
import torch.distributed.rpc as rpc
from torch import Tensor
from typing import *
__all__ = [
"rpc_remote_call",
"rpc_remote_void_call",
"rpc_remote_exec"
]
def rpc_remote_call(method, rref: rpc.RRef, *args, **kwargs):
args = (method, rref) + args
return rpc.rpc_async(rref.owner(), rpc_method_call, args=args, kwargs=kwargs)
def rpc_method_call(method, rref: rpc.RRef, *args, **kwargs):
self = rref.local_value()
return method(self, *args, **kwargs)
def rpc_remote_void_call(method, rref: rpc.RRef, *args, **kwargs):
args = (method, rref) + args
return rpc.rpc_async(rref.owner(), rpc_method_void_call, args=args, kwargs=kwargs)
def rpc_method_void_call(method, rref: rpc.RRef, *args, **kwargs):
self = rref.local_value()
method(self, *args, **kwargs) # return None
def rpc_remote_exec(method, rref: rpc.RRef, *args, **kwargs):
args = (method, rref) + args
return rpc.rpc_async(rref.owner(), rpc_method_exec, args=args, kwargs=kwargs)
@rpc.functions.async_execution
def rpc_method_exec(method, rref: rpc.RRef, *args, **kwargs):
self = rref.local_value()
return method(self, *args, **kwargs)
import torch
import torch.distributed as dist
import torch.distributed.rpc as rpc
from torch import Tensor
from torch.types import Number
from typing import *
from torch_sparse import SparseTensor
from .cclib import all_to_all_s
class TensorAccessor:
def __init__(self, data: Tensor) -> None:
from .context import DistributedContext
self._data = data
self._ctx = DistributedContext.get_default_context()
if self._ctx._use_rpc:
self._rref = rpc.RRef(data)
else:
self._rref = None
# only allocate a side CUDA stream when CUDA is actually available
self.stream = torch.cuda.Stream() if torch.cuda.is_available() else None
@property
def data(self):
return self._data
@property
def rref(self):
return self._rref
@property
def ctx(self):
return self._ctx
@staticmethod
@rpc.functions.async_execution
def _index_selet(self):
# NOTE: this appears to be an unused placeholder (note the typo in its name); the real
# implementation is the _index_select staticmethod defined below
fut = torch.futures.Future()
fut.set_result(None)
return fut
def all_gather_rrefs(self) -> List[rpc.RRef]:
return self.ctx.all_gather_remote_objects(self.rref)
def async_index_select(self, dim: int, index: Tensor, rref: Optional[rpc.RRef] = None):
if rref is None:
rref = self.rref
return self.ctx.remote_exec(TensorAccessor._index_select, rref, dim=dim, index=index)
def async_index_copy_(self, dim: int, index: Tensor, source: Tensor, rref: Optional[rpc.RRef] = None):
if rref is None:
rref = self.rref
return self.ctx.remote_exec(TensorAccessor._index_copy_, rref, dim=dim, index=index, source=source)
def async_index_add_(self, dim: int, index: Tensor, source: Tensor, rref: Optional[rpc.RRef] = None):
if rref is None:
rref = self.rref
return self.ctx.remote_exec(TensorAccessor._index_add_, rref, dim=dim, index=index, source=source)
@staticmethod
def _index_select(data: Tensor, dim: int, index: Tensor):
stream = TensorAccessor.get_stream()
with torch.cuda.stream(stream):
data = data.index_select(dim, index)
fut = torch.futures.Future()
fut.set_result(data)
return fut
@staticmethod
def _index_copy_(data: Tensor, dim: int, index: Tensor, source: Tensor):
stream = TensorAccessor.get_stream()
with torch.cuda.stream(stream):
data.index_copy_(dim, index, source)
fut = torch.futures.Future()
fut.set_result(None)
return fut
@staticmethod
def _index_add_(data: Tensor, dim: int, index: Tensor, source: Tensor):
stream = TensorAccessor.get_stream()
with torch.cuda.stream(stream):
data.index_add_(dim, index, source)
fut = torch.futures.Future()
fut.set_result(None)
return fut
@staticmethod
def get_stream() -> Optional[torch.cuda.Stream]:
global _TENSOR_ACCESSOR_STREAM
if not torch.cuda.is_available():
return None
if _TENSOR_ACCESSOR_STREAM is None:
_TENSOR_ACCESSOR_STREAM = torch.cuda.Stream()
return _TENSOR_ACCESSOR_STREAM
_TENSOR_ACCESSOR_STREAM: Optional[torch.cuda.Stream] = None
class DistInt:
def __init__(self, sizes: List[int]) -> None:
self._data = tuple([int(t) for t in sizes])
self._total = sum(self._data)
def __getitem__(self, idx: int) -> int:
return self._data[idx]
def __call__(self) -> int:
return self._total
class DistIndex:
def __init__(self, index: Tensor, part_ids: Optional[Tensor] = None) -> None:
if part_ids is None:
self._data = index.long()
else:
index, part_ids = index.long(), part_ids.long()
self._data = (index & 0xFFFFFFFFFFFF) | ((part_ids & 0xFFFF) << 48)
@property
def loc(self) -> Tensor:
return self._data & 0xFFFFFFFFFFFF
@property
def part(self) -> Tensor:
return (self._data >> 48) & 0xFFFF
@property
def dist(self) -> Tensor:
return self._data
@property
def dtype(self):
return self._data.dtype
@property
def device(self):
return self._data.device
def to(self, device) -> 'DistIndex':
return DistIndex(self._data.to(device))
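# Bit layout used by DistIndex (illustrative): the low 48 bits hold the local row index
# and the high 16 bits hold the owning partition id, so
#
#   di = DistIndex(torch.tensor([3]), part_ids=torch.tensor([2]))
#   di.loc   # tensor([3])
#   di.part  # tensor([2])
#   di.dist  # tensor([3 | (2 << 48)])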
class DistributedTensor:
def __init__(self, data: Tensor) -> None:
self.accessor = TensorAccessor(data)
if self.accessor.rref is not None:
self.rrefs = self.accessor.all_gather_rrefs()
local_sizes = []
for rref in self.rrefs:
n = self.ctx.remote_call(Tensor.size, rref, dim=0).wait()
local_sizes.append(n)
self._num_nodes: int = sum(local_sizes)
self._num_part_nodes: Tuple[int,...] = tuple(int(s) for s in local_sizes)
else:
self.rrefs = None
world_size = dist.get_world_size()
local_size = torch.tensor(data.size(0), device=data.device)
all_sizes = [torch.zeros_like(local_size) for _ in range(world_size)]
dist.all_gather(all_sizes, local_size)
self._num_part_nodes: Tuple[int, ...] = tuple(int(s.item()) for s in all_sizes)
self._num_nodes: int = sum(self._num_part_nodes)
self._part_id: int = self.accessor.ctx.rank
self._num_parts: int = self.accessor.ctx.world_size
@property
def shape(self):
return self.accessor.data.shape
@property
def dtype(self):
return self.accessor.data.dtype
@property
def device(self):
return self.accessor.data.device
@property
def num_nodes(self) -> int:
return self._num_nodes
@property
def num_part_nodes(self) -> Tuple[int, ...]:
return self._num_part_nodes
@property
def part_id(self) -> int:
return self._part_id
@property
def num_parts(self) -> int:
return self._num_parts
def to(self,device):
return self.accessor.data.to(device)
def __getitem__(self,index):
return self.accessor.data[index]
@property
def ctx(self):
return self.accessor.ctx
def all_to_all_ind2ptr(self, dist_index: Union[Tensor, DistIndex],group = None) -> Dict[str, Union[List[int], Tensor]]:
if isinstance(dist_index, Tensor):
dist_index = DistIndex(dist_index)
send_ptr = torch.ops.torch_sparse.ind2ptr(dist_index.part, self.num_parts)
send_sizes = send_ptr[1:] - send_ptr[:-1]
recv_sizes = torch.empty_like(send_sizes)
dist.all_to_all_single(recv_sizes, send_sizes, group=group)
recv_ptr = torch.zeros(recv_sizes.numel() + 1).type_as(recv_sizes)
recv_ptr[1:] = recv_sizes.cumsum(dim=0)
send_ptr = send_ptr.tolist()
recv_ptr = recv_ptr.tolist()
recv_ind = torch.full((recv_ptr[-1],), (2**62-1)*2+1, dtype=dist_index.dtype, device=self.device)
all_to_all_s(recv_ind, dist_index.loc, recv_ptr, send_ptr,group=group)
return {
"send_ptr": send_ptr,
"recv_ptr": recv_ptr,
"recv_ind": recv_ind,
}
def all_to_all_get(self,
dist_index: Union[Tensor, DistIndex, None] = None,
send_ptr: Optional[List[int]] = None,
recv_ptr: Optional[List[int]] = None,
recv_ind: Optional[Tensor] = None,
group = None
) -> Tensor:
if dist_index is not None:
dist_dict = self.all_to_all_ind2ptr(dist_index)
send_ptr = dist_dict["send_ptr"]
recv_ptr = dist_dict["recv_ptr"]
recv_ind = dist_dict["recv_ind"]
data = self.accessor.data[recv_ind]
recv = torch.empty(send_ptr[-1], *data.shape[1:], dtype=data.dtype, device=self.device)
all_to_all_s(recv, data, send_ptr, recv_ptr,group=group)
return recv
def all_to_all_set(self,
data: Tensor,
dist_index: Union[Tensor, DistIndex, None] = None,
send_ptr: Optional[List[int]] = None,
recv_ptr: Optional[List[int]] = None,
recv_ind: Optional[Tensor] = None,
group = None
):
if dist_index is not None:
dist_dict = self.all_to_all_ind2ptr(dist_index)
send_ptr = dist_dict["send_ptr"]
recv_ptr = dist_dict["recv_ptr"]
recv_ind = dist_dict["recv_ind"]
recv = torch.empty(recv_ptr[-1], *data.shape[1:], dtype=data.dtype, device=data.device)
all_to_all_s(recv, data, recv_ptr, send_ptr,group=group)
self.accessor.data.index_copy_(0, recv_ind, recv)
def index_select(self, dist_index: Union[Tensor, DistIndex]):
if isinstance(dist_index, Tensor):
dist_index = DistIndex(dist_index)
part_idx = dist_index.part
index = dist_index.loc
futs: List[torch.futures.Future] = []
for i in range(self.num_parts):
f = self.accessor.async_index_select(0, index[part_idx == i], self.rrefs[i])
futs.append(f)
def callback(fs: torch.futures.Future[List[torch.futures.Future]]) -> Tensor:
result: Optional[Tensor] = None
for i, f in enumerate(fs.value()):
t: Tensor = f.value()
if result is None:
result = torch.empty(
part_idx.size(0), *t.shape[1:], dtype=t.dtype, device=t.device,
)
result[part_idx == i] = t
return result
return torch.futures.collect_all(futs).then(callback)
def index_copy_(self, dist_index: Union[Tensor, DistIndex], source: Tensor):
if isinstance(dist_index, Tensor):
dist_index = DistIndex(dist_index)
part_idx = dist_index.part
index = dist_index.loc
futs: List[torch.futures.Future] = []
for i in range(self.num_parts):
mask = part_idx == i
f = self.accessor.async_index_copy_(0, index[mask], source[mask], self.rrefs[i])
futs.append(f)
return torch.futures.collect_all(futs)
def index_add_(self, dist_index: Union[Tensor, DistIndex], source: Tensor):
if isinstance(dist_index, Tensor):
dist_index = DistIndex(dist_index)
part_idx = dist_index.part
index = dist_index.loc
futs: List[torch.futures.Future] = []
for i in range(self.num_parts):
mask = part_idx == i
f = self.accessor.async_index_add_(0, index[mask], source[mask], self.rrefs[i])
futs.append(f)
return torch.futures.collect_all(futs)
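# Illustrative remote read (a sketch; requires DistributedContext.init(..., use_rpc=True)
# so that per-partition RRefs exist, and `local_features` stands for this rank's local shard):
#
#   dt = DistributedTensor(local_features)
#   fut = dt.index_select(dist_index)   # dist_index packs (partition id, local row) pairs
#   gathered = fut.wait()               # rows arrive in the order of dist_index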
import random
import pandas as pd
import numpy as np
import os
import torch
from torch_geometric.data import Data
from starrygl.sample.graph_core import DataSet, DistributedGraphStore
def get_link_prediction_data(data_name: str, val_ratio, test_ratio):
"""
generate data for link prediction task (inductive & transductive settings)
:param dataset_name: str, dataset name
:param val_ratio: float, validation data ratio
:param test_ratio: float, test data ratio
:return: full_data, a torch_geometric Data object carrying the node/edge features, the
temporal edge list, the sampled graph dict and the train/val/test (and new-node) masks
"""
# Load data and train val test split
#graph_df = pd.read_csv('/mnt/nfs/fzz/TGL-DATA/'+data_name+'/edges.csv')
#if os.path.exists('/mnt/nfs/fzz/TGL-DATA/'+data_name+'/node_features.pt'):
# n_feat = torch.load('/mnt/nfs/fzz/TGL-DATA/'+data_name+'/node_features.pt')
#else:
# n_feat = None
#if os.path.exists('/mnt/nfs/fzz/TGL-DATA/'+data_name+'/edge_features.pt'):
# e_feat = torch.load('/mnt/nfs/fzz/TGL-DATA/'+data_name+'/edge_features.pt')
#else:
# e_feat = None
#
## get the timestamp of validate and test set
#src_node_ids = torch.from_numpy(np.array(graph_df.src.values)).long()
#dst_node_ids = torch.from_numpy(np.array(graph_df.dst.values)).long()
#node_interact_times = torch.from_numpy(np.array(graph_df.time.values)).long()
#
#train_mask = (torch.from_numpy(np.array(graph_df.ext_roll.values)) == 0)
#test_mask = (torch.from_numpy(np.array(graph_df.ext_roll.values)) == 1)
#val_mask = (torch.from_numpy(np.array(graph_df.ext_roll.values)) == 2)
# the setting of seed follows previous works
graph_df = pd.read_csv('./processed_data/{}/ml_{}.csv'.format(data_name, data_name))
edge_raw_features = np.load('./processed_data/{}/ml_{}.npy'.format(data_name, data_name))
node_raw_features = np.load('./processed_data/{}/ml_{}_node.npy'.format(data_name, data_name))
NODE_FEAT_DIM = EDGE_FEAT_DIM = 172
assert NODE_FEAT_DIM >= node_raw_features.shape[1], f'Node feature dimension in dataset {data_name} is bigger than {NODE_FEAT_DIM}!'
assert EDGE_FEAT_DIM >= edge_raw_features.shape[1], f'Edge feature dimension in dataset {data_name} is bigger than {EDGE_FEAT_DIM}!'
# padding the features of edges and nodes to the same dimension (172 for all the datasets)
if node_raw_features.shape[1] < NODE_FEAT_DIM:
node_zero_padding = np.zeros((node_raw_features.shape[0], NODE_FEAT_DIM - node_raw_features.shape[1]))
node_raw_features = np.concatenate([node_raw_features, node_zero_padding], axis=1)
if edge_raw_features.shape[1] < EDGE_FEAT_DIM:
edge_zero_padding = np.zeros((edge_raw_features.shape[0], EDGE_FEAT_DIM - edge_raw_features.shape[1]))
edge_raw_features = np.concatenate([edge_raw_features, edge_zero_padding], axis=1)
n_feat = torch.from_numpy(node_raw_features.astype(np.float32))
e_feat = torch.from_numpy(edge_raw_features.astype(np.float32))
assert NODE_FEAT_DIM == node_raw_features.shape[1] and EDGE_FEAT_DIM == edge_raw_features.shape[1], 'Unaligned feature dimensions after feature padding!'
# get the timestamp of validate and test set
val_time, test_time = list(np.quantile(graph_df.ts, [(1 - val_ratio - test_ratio), (1 - test_ratio)]))
src_node_ids = torch.from_numpy(graph_df.u.values.astype(np.longlong))
dst_node_ids = torch.from_numpy(graph_df.i.values.astype(np.longlong))
node_interact_times = torch.from_numpy(graph_df.ts.values.astype(np.float32))
#edge_ids = torch.from_numpy(graph_df.idx.values.astype(np.longlong))
labels = torch.from_numpy(graph_df.label.values)
unique_node_ids = torch.cat((src_node_ids,dst_node_ids)).unique()
train_mask = node_interact_times <= val_time
val_mask = ((node_interact_times > val_time)&(node_interact_times <= test_time))
test_mask = (node_interact_times > test_time)
torch.manual_seed(2020)
# random.sample below draws from Python's `random` module, so seed it explicitly as well
random.seed(2020)
train_node_set = torch.cat((src_node_ids[train_mask],dst_node_ids[train_mask])).unique()
# collect node ids as plain Python ints so that membership checks against dataframe values work
test_node_set = set(src_node_ids[node_interact_times > val_time].tolist()).union(set(dst_node_ids[node_interact_times > val_time].tolist()))
# random.sample requires a sequence, not a set
new_test_node_set = set(random.sample(sorted(test_node_set), int(0.1 * unique_node_ids.shape[0])))
new_test_source_mask = graph_df.u.map(lambda x: x in new_test_node_set).values
new_test_destination_mask = graph_df.i.map(lambda x: x in new_test_node_set).values
# mask, which is true for edges with both destination and source not being new test nodes (because we want to remove all edges involving any new test node)
observed_edges_mask = torch.from_numpy(np.logical_and(~new_test_source_mask, ~new_test_destination_mask))
train_mask = (train_mask & observed_edges_mask)
mask = torch.isin(unique_node_ids,train_node_set,invert = True)
new_node_set = unique_node_ids[mask]
edge_contains_new_node_mask = (torch.isin(src_node_ids,new_node_set) | torch.isin(dst_node_ids,new_node_set))
new_node_val_mask = (val_mask & edge_contains_new_node_mask)
new_node_test_mask = (test_mask & edge_contains_new_node_mask)
full_data = Data()
full_data.edge_index = torch.stack((src_node_ids,dst_node_ids))
sample_graph = {}
sample_src = torch.cat([src_node_ids.view(-1, 1), dst_node_ids.view(-1, 1)], dim=1)\
.reshape(1, -1)
sample_dst = torch.cat([dst_node_ids.view(-1, 1), src_node_ids.view(-1, 1)], dim=1)\
.reshape(1, -1)
sample_ts = torch.cat([node_interact_times.view(-1, 1), node_interact_times.view(-1, 1)], dim=1).reshape(-1)
sample_eid = torch.arange(full_data.edge_index.shape[1]).view(-1, 1).repeat(1, 2).reshape(-1)
sample_graph['edge_index'] = torch.cat([sample_src, sample_dst], dim=0)
sample_graph['ts'] = sample_ts
sample_graph['eids'] = sample_eid
sample_graph['train_mask'] = train_mask
sample_graph['val_mask'] = val_mask
sample_graph['test_mask'] = test_mask
sample_graph['new_node_val_mask'] = new_node_val_mask
sample_graph['new_node_test_mask'] = new_node_test_mask
print(unique_node_ids.max().item(),unique_node_ids.shape[0])
full_data.num_nodes = int(unique_node_ids.max().item())+1
full_data.num_edges = node_interact_times.shape[0]
full_data.sample_graph = sample_graph
full_data.x = n_feat
full_data.edge_attr = e_feat
full_data.y = labels
full_data.edge_ts = node_interact_times
full_data.train_mask = train_mask
full_data.val_mask = val_mask
full_data.test_mask = test_mask
full_data.new_node_val_mask = new_node_val_mask
full_data.new_node_test_mask = new_node_test_mask
return full_data
#full_graph = DistributedGraphStore(full_data, device, uvm_node, uvm_edge)
#train_data = torch.masked_select(full_data.edge_index,train_mask.to(device)).reshape(2,-1)
#train_ts = torch.masked_select(full_data.edge_ts,train_mask.to(device))
#val_data = torch.masked_select(full_data.edge_index,val_mask.to(device)).reshape(2,-1)
#val_ts = torch.masked_select(full_data.edge_ts,val_mask.to(device))
#test_data = torch.masked_select(full_data.edge_index,test_mask.to(device)).reshape(2,-1)
#test_ts = torch.masked_select(full_data.edge_ts,test_mask.to(device))
##print(train_data.shape[1],val_data.shape[1],test_data.shape[1])
#train_data = DataSet(edges = train_data,ts =train_ts,eids = torch.nonzero(train_mask).view(-1))
#test_data = DataSet(edges = test_data,ts =test_ts,eids = torch.nonzero(test_mask).view(-1))
#val_data = DataSet(edges = val_data,ts = val_ts,eids = torch.nonzero(val_mask).view(-1))
#new_node_val_data = torch.masked_select(full_data.edge_index,new_node_val_mask.to(device)).reshape(2,-1)
#new_node_val_ts = torch.masked_select(full_data.edge_ts,new_node_val_mask.to(device))
#new_node_test_data = torch.masked_select(full_data.edge_index,new_node_test_mask.to(device)).reshape(2,-1)
#new_node_test_ts = torch.masked_select(full_data.edge_ts,new_node_test_mask.to(device))
#return full_data, train_data, val_data, test_data, new_node_val_data, new_node_test_data
import torch
from sklearn.metrics import average_precision_score, roc_auc_score
def get_link_prediction_metrics(predicts: torch.Tensor, labels: torch.Tensor):
"""
get metrics for the link prediction task
:param predicts: Tensor, shape (num_samples, )
:param labels: Tensor, shape (num_samples, )
:return:
dictionary of metrics {'metric_name_1': metric_1, ...}
"""
predicts = predicts.cpu().detach().numpy()
labels = labels.cpu().numpy()
average_precision = average_precision_score(y_true=labels, y_score=predicts)
roc_auc = roc_auc_score(y_true=labels, y_score=predicts)
return {'average_precision': average_precision, 'roc_auc': roc_auc}
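# Illustrative call (made-up scores): both arguments are 1-D tensors of equal length,
# with labels in {0, 1} and predicts being probabilities.
#
#   get_link_prediction_metrics(predicts=torch.tensor([0.9, 0.2, 0.7, 0.4]),
#                               labels=torch.tensor([1.0, 0.0, 1.0, 0.0]))
#   # -> {'average_precision': 1.0, 'roc_auc': 1.0}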
def get_node_classification_metrics(predicts: torch.Tensor, labels: torch.Tensor):
"""
get metrics for the node classification task
:param predicts: Tensor, shape (num_samples, )
:param labels: Tensor, shape (num_samples, )
:return:
dictionary of metrics {'metric_name_1': metric_1, ...}
"""
predicts = predicts.cpu().detach().numpy()
labels = labels.cpu().numpy()
roc_auc = roc_auc_score(y_true=labels, y_score=predicts)
return {'roc_auc': roc_auc}
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from tqdm import tqdm
import numpy as np
import logging
import time
import argparse
import os
import json
from models.EdgeBank import edge_bank_link_prediction
from starrygl.evaluation.metrics import get_link_prediction_metrics, get_node_classification_metrics
from utils.utils import set_random_seed
from starrygl.sample.sample_core import EvaluateNegativeSampling
from utils.DataLoader import Data
def evaluate_model_link_prediction(model_name: str, model: nn.Module, neighbor_sampler: 'NeighborSampler', evaluate_idx_data_loader: DataLoader,
evaluate_neg_edge_sampler: 'NegativeEdgeSampler', evaluate_data: Data, loss_func: nn.Module,
num_neighbors: int = 20, time_gap: int = 2000):
"""
evaluate models on the link prediction task
:param model_name: str, name of the model
:param model: nn.Module, the model to be evaluated
:param neighbor_sampler: NeighborSampler, neighbor sampler
:param evaluate_idx_data_loader: DataLoader, evaluate index data loader
:param evaluate_neg_edge_sampler: NegativeEdgeSampler, evaluate negative edge sampler
:param evaluate_data: Data, data to be evaluated
:param loss_func: nn.Module, loss function
:param num_neighbors: int, number of neighbors to sample for each node
:param time_gap: int, time gap for neighbors to compute node features
:return:
"""
# Ensures the random sampler uses a fixed seed for evaluation (i.e. we always sample the same negatives for validation / test set)
assert evaluate_neg_edge_sampler.seed is not None
evaluate_neg_edge_sampler.reset_random_state()
if model_name in ['DyRep', 'TGAT', 'TGN', 'CAWN', 'TCL', 'GraphMixer', 'DyGFormer']:
# evaluation phase use all the graph information
model[0].set_neighbor_sampler(neighbor_sampler)
model.eval()
with torch.no_grad():
# store evaluate losses and metrics
evaluate_losses, evaluate_metrics = [], []
evaluate_idx_data_loader_tqdm = tqdm(evaluate_idx_data_loader, ncols=120)
for batch_idx, evaluate_data_indices in enumerate(evaluate_idx_data_loader_tqdm):
evaluate_data_indices = evaluate_data_indices.numpy()
batch_src_node_ids, batch_dst_node_ids, batch_node_interact_times, batch_edge_ids = \
evaluate_data.src_node_ids[evaluate_data_indices], evaluate_data.dst_node_ids[evaluate_data_indices], \
evaluate_data.node_interact_times[evaluate_data_indices], evaluate_data.edge_ids[evaluate_data_indices]
if evaluate_neg_edge_sampler.negative_sample_strategy != 'random':
batch_neg_src_node_ids, batch_neg_dst_node_ids = evaluate_neg_edge_sampler.sample(size=len(batch_src_node_ids),
batch_src_node_ids=batch_src_node_ids,
batch_dst_node_ids=batch_dst_node_ids,
current_batch_start_time=batch_node_interact_times[0],
current_batch_end_time=batch_node_interact_times[-1])
else:
_, batch_neg_dst_node_ids = evaluate_neg_edge_sampler.sample(size=len(batch_src_node_ids))
batch_neg_src_node_ids = batch_src_node_ids
# we need to compute for positive and negative edges respectively, because the new sampling strategy (for evaluation) allows the negative source nodes to be
# different from the source nodes, this is different from previous works that just replace destination nodes with negative destination nodes
if model_name in ['TGAT', 'CAWN', 'TCL']:
# get temporal embedding of source and destination nodes
# two Tensors, with shape (batch_size, node_feat_dim)
batch_src_node_embeddings, batch_dst_node_embeddings = \
model[0].compute_src_dst_node_temporal_embeddings(src_node_ids=batch_src_node_ids,
dst_node_ids=batch_dst_node_ids,
node_interact_times=batch_node_interact_times,
num_neighbors=num_neighbors)
# get temporal embedding of negative source and negative destination nodes
# two Tensors, with shape (batch_size, node_feat_dim)
batch_neg_src_node_embeddings, batch_neg_dst_node_embeddings = \
model[0].compute_src_dst_node_temporal_embeddings(src_node_ids=batch_neg_src_node_ids,
dst_node_ids=batch_neg_dst_node_ids,
node_interact_times=batch_node_interact_times,
num_neighbors=num_neighbors)
elif model_name in ['JODIE', 'DyRep', 'TGN']:
# note that negative nodes do not change the memories while the positive nodes change the memories,
# we need to first compute the embeddings of negative nodes for memory-based models
# get temporal embedding of negative source and negative destination nodes
# two Tensors, with shape (batch_size, node_feat_dim)
batch_neg_src_node_embeddings, batch_neg_dst_node_embeddings = \
model[0].compute_src_dst_node_temporal_embeddings(src_node_ids=batch_neg_src_node_ids,
dst_node_ids=batch_neg_dst_node_ids,
node_interact_times=batch_node_interact_times,
edge_ids=None,
edges_are_positive=False,
num_neighbors=num_neighbors)
# get temporal embedding of source and destination nodes
# two Tensors, with shape (batch_size, node_feat_dim)
batch_src_node_embeddings, batch_dst_node_embeddings = \
model[0].compute_src_dst_node_temporal_embeddings(src_node_ids=batch_src_node_ids,
dst_node_ids=batch_dst_node_ids,
node_interact_times=batch_node_interact_times,
edge_ids=batch_edge_ids,
edges_are_positive=True,
num_neighbors=num_neighbors)
elif model_name in ['GraphMixer']:
# get temporal embedding of source and destination nodes
# two Tensors, with shape (batch_size, node_feat_dim)
batch_src_node_embeddings, batch_dst_node_embeddings = \
model[0].compute_src_dst_node_temporal_embeddings(src_node_ids=batch_src_node_ids,
dst_node_ids=batch_dst_node_ids,
node_interact_times=batch_node_interact_times,
num_neighbors=num_neighbors,
time_gap=time_gap)
# get temporal embedding of negative source and negative destination nodes
# two Tensors, with shape (batch_size, node_feat_dim)
batch_neg_src_node_embeddings, batch_neg_dst_node_embeddings = \
model[0].compute_src_dst_node_temporal_embeddings(src_node_ids=batch_neg_src_node_ids,
dst_node_ids=batch_neg_dst_node_ids,
node_interact_times=batch_node_interact_times,
num_neighbors=num_neighbors,
time_gap=time_gap)
elif model_name in ['DyGFormer']:
# get temporal embedding of source and destination nodes
# two Tensors, with shape (batch_size, node_feat_dim)
batch_src_node_embeddings, batch_dst_node_embeddings = \
model[0].compute_src_dst_node_temporal_embeddings(src_node_ids=batch_src_node_ids,
dst_node_ids=batch_dst_node_ids,
node_interact_times=batch_node_interact_times)
# get temporal embedding of negative source and negative destination nodes
# two Tensors, with shape (batch_size, node_feat_dim)
batch_neg_src_node_embeddings, batch_neg_dst_node_embeddings = \
model[0].compute_src_dst_node_temporal_embeddings(src_node_ids=batch_neg_src_node_ids,
dst_node_ids=batch_neg_dst_node_ids,
node_interact_times=batch_node_interact_times)
else:
raise ValueError(f"Wrong value for model_name {model_name}!")
# get positive and negative probabilities, shape (batch_size, )
positive_probabilities = model[1](input_1=batch_src_node_embeddings, input_2=batch_dst_node_embeddings).squeeze(dim=-1).sigmoid()
negative_probabilities = model[1](input_1=batch_neg_src_node_embeddings, input_2=batch_neg_dst_node_embeddings).squeeze(dim=-1).sigmoid()
predicts = torch.cat([positive_probabilities, negative_probabilities], dim=0)
labels = torch.cat([torch.ones_like(positive_probabilities), torch.zeros_like(negative_probabilities)], dim=0)
loss = loss_func(input=predicts, target=labels)
evaluate_losses.append(loss.item())
evaluate_metrics.append(get_link_prediction_metrics(predicts=predicts, labels=labels))
evaluate_idx_data_loader_tqdm.set_description(f'evaluate for the {batch_idx + 1}-th batch, evaluate loss: {loss.item()}')
return evaluate_losses, evaluate_metrics
def evaluate_edge_bank_link_prediction(args: argparse.Namespace, train_data: Data, val_data: Data, test_idx_data_loader: DataLoader,
test_neg_edge_sampler: 'NegativeEdgeSampler', test_data: Data):
"""
evaluate the EdgeBank model for link prediction
:param args: argparse.Namespace, configuration
:param train_data: Data, train data
:param val_data: Data, validation data
:param test_idx_data_loader: DataLoader, test index data loader
:param test_neg_edge_sampler: NegativeEdgeSampler, test negative edge sampler
:param test_data: Data, test data
:return:
"""
# generate the train_validation split of the data: needed for constructing the memory for EdgeBank
train_val_data = Data(src_node_ids=np.concatenate([train_data.src_node_ids, val_data.src_node_ids]),
dst_node_ids=np.concatenate([train_data.dst_node_ids, val_data.dst_node_ids]),
node_interact_times=np.concatenate([train_data.node_interact_times, val_data.node_interact_times]),
edge_ids=np.concatenate([train_data.edge_ids, val_data.edge_ids]),
labels=np.concatenate([train_data.labels, val_data.labels]))
test_metric_all_runs = []
for run in range(args.num_runs):
set_random_seed(seed=run)
args.seed = run
args.save_result_name = f'{args.negative_sample_strategy}_negative_sampling_{args.model_name}_seed{args.seed}'
# set up logger
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)
os.makedirs(f"./logs/{args.model_name}/{args.dataset_name}/{args.save_result_name}/", exist_ok=True)
# create file handler that logs debug and higher level messages
fh = logging.FileHandler(f"./logs/{args.model_name}/{args.dataset_name}/{args.save_result_name}/{str(time.time())}.log")
fh.setLevel(logging.DEBUG)
# create console handler with a higher log level
ch = logging.StreamHandler()
ch.setLevel(logging.WARNING)
# create formatter and add it to the handlers
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
fh.setFormatter(formatter)
ch.setFormatter(formatter)
# add the handlers to logger
logger.addHandler(fh)
logger.addHandler(ch)
run_start_time = time.time()
logger.info(f"********** Run {run + 1} starts. **********")
logger.info(f'configuration is {args}')
loss_func = nn.BCELoss()
# evaluate EdgeBank
logger.info(f'get final performance on dataset {args.dataset_name}...')
# Ensures the random sampler uses a fixed seed for evaluation (i.e. we always sample the same negatives for validation / test set)
assert test_neg_edge_sampler.seed is not None
test_neg_edge_sampler.reset_random_state()
test_losses, test_metrics = [], []
test_idx_data_loader_tqdm = tqdm(test_idx_data_loader, ncols=120)
for batch_idx, test_data_indices in enumerate(test_idx_data_loader_tqdm):
test_data_indices = test_data_indices.numpy()
batch_src_node_ids, batch_dst_node_ids, batch_node_interact_times = \
test_data.src_node_ids[test_data_indices], test_data.dst_node_ids[test_data_indices], \
test_data.node_interact_times[test_data_indices]
if test_neg_edge_sampler.negative_sample_strategy != 'random':
batch_neg_src_node_ids, batch_neg_dst_node_ids = test_neg_edge_sampler.sample(size=len(batch_src_node_ids),
batch_src_node_ids=batch_src_node_ids,
batch_dst_node_ids=batch_dst_node_ids,
current_batch_start_time=batch_node_interact_times[0],
current_batch_end_time=batch_node_interact_times[-1])
else:
_, batch_neg_dst_node_ids = test_neg_edge_sampler.sample(size=len(batch_src_node_ids))
batch_neg_src_node_ids = batch_src_node_ids
positive_edges = (batch_src_node_ids, batch_dst_node_ids)
negative_edges = (batch_neg_src_node_ids, batch_neg_dst_node_ids)
# incorporate the testing data before the current batch to history_data, which is similar to memory-based models
history_data = Data(src_node_ids=np.concatenate([train_val_data.src_node_ids, test_data.src_node_ids[: test_data_indices[0]]]),
dst_node_ids=np.concatenate([train_val_data.dst_node_ids, test_data.dst_node_ids[: test_data_indices[0]]]),
node_interact_times=np.concatenate([train_val_data.node_interact_times, test_data.node_interact_times[: test_data_indices[0]]]),
edge_ids=np.concatenate([train_val_data.edge_ids, test_data.edge_ids[: test_data_indices[0]]]),
labels=np.concatenate([train_val_data.labels, test_data.labels[: test_data_indices[0]]]))
# perform link prediction for EdgeBank
positive_probabilities, negative_probabilities = edge_bank_link_prediction(history_data=history_data,
positive_edges=positive_edges,
negative_edges=negative_edges,
edge_bank_memory_mode=args.edge_bank_memory_mode,
time_window_mode=args.time_window_mode,
time_window_proportion=args.test_ratio)
predicts = torch.from_numpy(np.concatenate([positive_probabilities, negative_probabilities])).float()
labels = torch.cat([torch.ones(len(positive_probabilities)), torch.zeros(len(negative_probabilities))], dim=0)
loss = loss_func(input=predicts, target=labels)
test_losses.append(loss.item())
test_metrics.append(get_link_prediction_metrics(predicts=predicts, labels=labels))
test_idx_data_loader_tqdm.set_description(f'test for the {batch_idx + 1}-th batch, test loss: {loss.item()}')
# store the evaluation metrics at the current run
test_metric_dict = {}
logger.info(f'test loss: {np.mean(test_losses):.4f}')
for metric_name in test_metrics[0].keys():
average_test_metric = np.mean([test_metric[metric_name] for test_metric in test_metrics])
logger.info(f'test {metric_name}, {average_test_metric:.4f}')
test_metric_dict[metric_name] = average_test_metric
single_run_time = time.time() - run_start_time
logger.info(f'Run {run + 1} cost {single_run_time:.2f} seconds.')
test_metric_all_runs.append(test_metric_dict)
# avoid the overlap of logs
if run < args.num_runs - 1:
logger.removeHandler(fh)
logger.removeHandler(ch)
# save model result
result_json = {
"test metrics": {metric_name: f'{test_metric_dict[metric_name]:.4f}'for metric_name in test_metric_dict}
}
result_json = json.dumps(result_json, indent=4)
save_result_folder = f"./saved_results/{args.model_name}/{args.dataset_name}"
os.makedirs(save_result_folder, exist_ok=True)
save_result_path = os.path.join(save_result_folder, f"{args.save_result_name}.json")
with open(save_result_path, 'w') as file:
file.write(result_json)
logger.info(f'save negative sampling results at {save_result_path}')
# store the average metrics at the log of the last run
logger.info(f'metrics over {args.num_runs} runs:')
for metric_name in test_metric_all_runs[0].keys():
logger.info(f'test {metric_name}, {[test_metric_single_run[metric_name] for test_metric_single_run in test_metric_all_runs]}')
logger.info(f'average test {metric_name}, {np.mean([test_metric_single_run[metric_name] for test_metric_single_run in test_metric_all_runs]):.4f} '
f'± {np.std([test_metric_single_run[metric_name] for test_metric_single_run in test_metric_all_runs], ddof=1):.4f}')
from os.path import abspath, join, dirname
import os
import sys
sys.path.insert(0, join(abspath(dirname(__file__))))
import torch
import dgl
import math
import numpy as np
class TimeEncode(torch.nn.Module):
def __init__(self, dim):
super(TimeEncode, self).__init__()
self.dim = dim
self.w = torch.nn.Linear(1, dim)
self.w.weight = torch.nn.Parameter((torch.from_numpy(1 / 10 ** np.linspace(0, 9, dim, dtype=np.float32))).reshape(dim, -1))
self.w.bias = torch.nn.Parameter(torch.zeros(dim))
def forward(self, t):
output = torch.cos(self.w(t.float().reshape((-1, 1))))
return output
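# A minimal usage sketch of TimeEncode, stated as an assumption rather than taken from the
# surrounding training code: each scalar time delta is mapped to a `dim`-dimensional cosine
# feature by the fixed-frequency linear layer above.
#
#   enc = TimeEncode(dim=100)
#   dt = torch.tensor([0.0, 1.5, 300.0])    # time deltas, shape (N,)
#   feat = enc(dt)                          # shape (3, 100), values in [-1, 1]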
class EdgePredictor(torch.nn.Module):
def __init__(self, dim_in):
super(EdgePredictor, self).__init__()
self.dim_in = dim_in
self.src_fc = torch.nn.Linear(dim_in, dim_in)
self.dst_fc = torch.nn.Linear(dim_in, dim_in)
self.out_fc = torch.nn.Linear(dim_in, 1)
def forward(self, h, neg_samples=1):
num_edge = h.shape[0] // (neg_samples + 2)
h_src = self.src_fc(h[:num_edge])
h_pos_dst = self.dst_fc(h[num_edge:num_edge*2])
h_neg_dst = self.dst_fc(h[2 * num_edge:])
h_pos_edge = torch.nn.functional.relu(h_src + h_pos_dst)
h_neg_edge = torch.nn.functional.relu(h_src.tile(neg_samples, 1) + h_neg_dst)  # tile the source embeddings so the shapes match when neg_samples > 1
return self.out_fc(h_pos_edge), self.out_fc(h_neg_edge)
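# Hedged usage sketch for EdgePredictor: forward() assumes the embedding tensor is stacked
# as [source | positive destination | negative destination] along dim 0, which is exactly
# how it is sliced above. The sizes below are placeholders.
#
#   pred = EdgePredictor(dim_in=100)
#   num_edge, neg_samples = 32, 1
#   h = torch.randn(num_edge * (neg_samples + 2), 100)
#   pos_score, neg_score = pred(h, neg_samples=neg_samples)
#   # pos_score: (32, 1), neg_score: (32 * neg_samples, 1); both are raw logits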
class TransfomerAttentionLayer(torch.nn.Module):
def __init__(self, dim_node_feat, dim_edge_feat, dim_time, num_head, dropout, att_dropout, dim_out, combined=False):
super(TransfomerAttentionLayer, self).__init__()
self.num_head = num_head
self.dim_node_feat = dim_node_feat
self.dim_edge_feat = dim_edge_feat
self.dim_time = dim_time
self.dim_out = dim_out
self.dropout = torch.nn.Dropout(dropout)
self.att_dropout = torch.nn.Dropout(att_dropout)
self.att_act = torch.nn.LeakyReLU(0.2)
self.combined = combined
if dim_time > 0:
self.time_enc = TimeEncode(dim_time)
if combined:
if dim_node_feat > 0:
self.w_q_n = torch.nn.Linear(dim_node_feat, dim_out)
self.w_k_n = torch.nn.Linear(dim_node_feat, dim_out)
self.w_v_n = torch.nn.Linear(dim_node_feat, dim_out)
if dim_edge_feat > 0:
self.w_k_e = torch.nn.Linear(dim_edge_feat, dim_out)
self.w_v_e = torch.nn.Linear(dim_edge_feat, dim_out)
if dim_time > 0:
self.w_q_t = torch.nn.Linear(dim_time, dim_out)
self.w_k_t = torch.nn.Linear(dim_time, dim_out)
self.w_v_t = torch.nn.Linear(dim_time, dim_out)
else:
if dim_node_feat + dim_time > 0:
self.w_q = torch.nn.Linear(dim_node_feat + dim_time, dim_out)
self.w_k = torch.nn.Linear(dim_node_feat + dim_edge_feat + dim_time, dim_out)
self.w_v = torch.nn.Linear(dim_node_feat + dim_edge_feat + dim_time, dim_out)
self.w_out = torch.nn.Linear(dim_node_feat + dim_out, dim_out)
self.layer_norm = torch.nn.LayerNorm(dim_out)
def forward(self, b):
assert(self.dim_time + self.dim_node_feat + self.dim_edge_feat > 0)
self.device = b.device
if b.num_edges() == 0:
return torch.zeros((b.num_dst_nodes(), self.dim_out), device=self.device)
if self.dim_time > 0:
time_feat = self.time_enc(b.edata['dt'])
zero_time_feat = self.time_enc(torch.zeros(b.num_dst_nodes(), dtype=torch.float32, device=self.device))
if self.combined:
Q = torch.zeros((b.num_edges(), self.dim_out), device=self.device)
K = torch.zeros((b.num_edges(), self.dim_out), device=self.device)
V = torch.zeros((b.num_edges(), self.dim_out), device=self.device)
if self.dim_node_feat > 0:
Q += self.w_q_n(b.srcdata['h'][:b.num_dst_nodes()])[b.edges()[1]]
K += self.w_k_n(b.srcdata['h'][b.num_dst_nodes():])[b.edges()[0] - b.num_dst_nodes()]
V += self.w_v_n(b.srcdata['h'][b.num_dst_nodes():])[b.edges()[0] - b.num_dst_nodes()]
if self.dim_edge_feat > 0:
K += self.w_k_e(b.edata['f'])
V += self.w_v_e(b.edata['f'])
if self.dim_time > 0:
Q += self.w_q_t(zero_time_feat)[b.edges()[1]]
K += self.w_k_t(time_feat)
V += self.w_v_t(time_feat)
Q = torch.reshape(Q, (Q.shape[0], self.num_head, -1))
K = torch.reshape(K, (K.shape[0], self.num_head, -1))
V = torch.reshape(V, (V.shape[0], self.num_head, -1))
att = dgl.ops.edge_softmax(b, self.att_act(torch.sum(Q*K, dim=2)))
att = self.att_dropout(att)
V = torch.reshape(V*att[:, :, None], (V.shape[0], -1))
b.edata['v'] = V
b.update_all(dgl.function.copy_e('v', 'm'), dgl.function.sum('m', 'h'))
else:
if self.dim_time == 0 and self.dim_node_feat == 0:
Q = torch.ones((b.num_edges(), self.dim_out), device=self.device)
K = self.w_k(b.edata['f'])
V = self.w_v(b.edata['f'])
elif self.dim_time == 0 and self.dim_edge_feat == 0:
Q = self.w_q(b.srcdata['h'][:b.num_dst_nodes()])[b.edges()[1]]
K = self.w_k(b.srcdata['h'][b.edges()[0]])
V = self.w_v(b.srcdata['h'][b.edges()[0]])
elif self.dim_time == 0:
Q = self.w_q(b.srcdata['h'][:b.num_dst_nodes()])[b.edges()[1]]
K = self.w_k(torch.cat([b.srcdata['h'][b.edges()[0]], b.edata['f']], dim=1))
V = self.w_v(torch.cat([b.srcdata['h'][b.edges()[0]], b.edata['f']], dim=1))
#K = self.w_k(torch.cat([b.srcdata['h'][b.num_dst_nodes():], b.edata['f']], dim=1))
#V = self.w_v(torch.cat([b.srcdata['h'][b.num_dst_nodes():], b.edata['f']], dim=1))
elif self.dim_node_feat == 0 and self.dim_edge_feat == 0:
Q = self.w_q(zero_time_feat)[b.edges()[1]]
K = self.w_k(time_feat)
V = self.w_v(time_feat)
elif self.dim_node_feat == 0:
Q = self.w_q(zero_time_feat)[b.edges()[1]]
K = self.w_k(torch.cat([b.edata['f'], time_feat], dim=1))
V = self.w_v(torch.cat([b.edata['f'], time_feat], dim=1))
elif self.dim_edge_feat == 0:
Q = self.w_q(torch.cat([b.srcdata['h'][:b.num_dst_nodes()], zero_time_feat], dim=1))[b.edges()[1]]
K = self.w_k(torch.cat([b.srcdata['h'][b.edges()[0]], time_feat], dim=1))
V = self.w_v(torch.cat([b.srcdata['h'][b.edges()[0]], time_feat], dim=1))
#K = self.w_k(torch.cat([b.srcdata['h'][b.num_dst_nodes():], time_feat], dim=1))
#V = self.w_v(torch.cat([b.srcdata['h'][b.num_dst_nodes():], time_feat], dim=1))
else:
Q = self.w_q(torch.cat([b.srcdata['h'][:b.num_dst_nodes()], zero_time_feat], dim=1))[b.edges()[1]]
K = self.w_k(torch.cat([b.srcdata['h'][b.edges()[0]], b.edata['f'], time_feat], dim=1))
V = self.w_v(torch.cat([b.srcdata['h'][b.edges()[0]], b.edata['f'], time_feat], dim=1))
#Q = self.w_q(torch.cat([b.srcdata['h'][:b.num_dst_nodes()], zero_time_feat], dim=1))[b.edges()[1]]
#K = self.w_k(torch.cat([b.srcdata['h'][b.num_dst_nodes():], b.edata['f'], time_feat], dim=1))
#V = self.w_v(torch.cat([b.srcdata['h'][b.num_dst_nodes():], b.edata['f'], time_feat], dim=1))
Q = torch.reshape(Q, (Q.shape[0], self.num_head, -1))
K = torch.reshape(K, (K.shape[0], self.num_head, -1))
V = torch.reshape(V, (V.shape[0], self.num_head, -1))
att = dgl.ops.edge_softmax(b, self.att_act(torch.sum(Q*K, dim=2)))
att = self.att_dropout(att)
V = torch.reshape(V*att[:, :, None], (V.shape[0], -1))
b.edata['v'] = V
b.update_all(dgl.function.copy_e('v', 'm'), dgl.function.sum('m', 'h'))
#b.srcdata['v'] = torch.cat([torch.zeros((b.num_dst_nodes(), V.shape[1]), device=torch.device('cuda:0')), V], dim=0)
#b.update_all(dgl.function.copy_u('v', 'm'), dgl.function.sum('m', 'h'))
if self.dim_node_feat != 0:
rst = torch.cat([b.dstdata['h'], b.srcdata['h'][:b.num_dst_nodes()]], dim=1)
else:
rst = b.dstdata['h']
rst = self.w_out(rst)
rst = torch.nn.functional.relu(self.dropout(rst))
return self.layer_norm(rst)
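# Assumed input contract of TransfomerAttentionLayer, inferred from the branches above:
# `b` is a DGL message-flow block whose srcdata['h'] holds node features (destination
# nodes stored first), edata['f'] holds edge features, and edata['dt'] holds per-edge
# time deltas fed to TimeEncode. Attention scores are normalized per destination node
# with dgl.ops.edge_softmax and the weighted messages are summed into dstdata['h'].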
class IdentityNormLayer(torch.nn.Module):
def __init__(self, dim_out):
super(IdentityNormLayer, self).__init__()
self.norm = torch.nn.LayerNorm(dim_out)
def forward(self, b):
return self.norm(b.srcdata['h'])
class JODIETimeEmbedding(torch.nn.Module):
def __init__(self, dim_out):
super(JODIETimeEmbedding, self).__init__()
self.dim_out = dim_out
class NormalLinear(torch.nn.Linear):
# From Jodie code
def reset_parameters(self):
stdv = 1. / math.sqrt(self.weight.size(1))
self.weight.data.normal_(0, stdv)
if self.bias is not None:
self.bias.data.normal_(0, stdv)
self.time_emb = NormalLinear(1, dim_out)
def forward(self, h, mem_ts, ts):
time_diff = (ts - mem_ts) / (ts + 1)
rst = h * (1 + self.time_emb(time_diff.unsqueeze(1)))
return rst
\ No newline at end of file
from os.path import abspath, join, dirname
import os
import sys
sys.path.insert(0, join(abspath(dirname(__file__))))
import torch
import dgl
from layers import TimeEncode
from torch_scatter import scatter
class MailBox():
def __init__(self, memory_param, num_nodes, dim_edge_feat, _node_memory=None, _node_memory_ts=None,_mailbox=None, _mailbox_ts=None, _next_mail_pos=None, _update_mail_pos=None):
self.memory_param = memory_param
self.dim_edge_feat = dim_edge_feat
if memory_param['type'] != 'node':
raise NotImplementedError
self.node_memory = torch.zeros((num_nodes, memory_param['dim_out']), dtype=torch.float32) if _node_memory is None else _node_memory
self.node_memory_ts = torch.zeros(num_nodes, dtype=torch.float32) if _node_memory_ts is None else _node_memory_ts
self.mailbox = torch.zeros((num_nodes, memory_param['mailbox_size'], 2 * memory_param['dim_out'] + dim_edge_feat), dtype=torch.float32) if _mailbox is None else _mailbox
self.mailbox_ts = torch.zeros((num_nodes, memory_param['mailbox_size']), dtype=torch.float32) if _mailbox_ts is None else _mailbox_ts
self.next_mail_pos = torch.zeros((num_nodes), dtype=torch.long) if _next_mail_pos is None else _next_mail_pos
self.update_mail_pos = _update_mail_pos
self.device = torch.device('cpu')
def reset(self):
self.node_memory.fill_(0)
self.node_memory_ts.fill_(0)
self.mailbox.fill_(0)
self.mailbox_ts.fill_(0)
self.next_mail_pos.fill_(0)
def move_to_gpu(self):
self.node_memory = self.node_memory.cuda()
self.node_memory_ts = self.node_memory_ts.cuda()
self.mailbox = self.mailbox.cuda()
self.mailbox_ts = self.mailbox_ts.cuda()
self.next_mail_pos = self.next_mail_pos.cuda()
self.device = torch.device('cuda:0')
def allocate_pinned_memory_buffers(self, sample_param, batch_size):
limit = int(batch_size * 3.3)
if 'neighbor' in sample_param:
for i in sample_param['neighbor']:
limit *= i + 1
self.pinned_node_memory_buffs = list()
self.pinned_node_memory_ts_buffs = list()
self.pinned_mailbox_buffs = list()
self.pinned_mailbox_ts_buffs = list()
for _ in range(sample_param['history']):
self.pinned_node_memory_buffs.append(torch.zeros((limit, self.node_memory.shape[1]), pin_memory=True))
self.pinned_node_memory_ts_buffs.append(torch.zeros((limit,), pin_memory=True))
self.pinned_mailbox_buffs.append(torch.zeros((limit, self.mailbox.shape[1], self.mailbox.shape[2]), pin_memory=True))
self.pinned_mailbox_ts_buffs.append(torch.zeros((limit, self.mailbox_ts.shape[1]), pin_memory=True))
def prep_input_mails(self, mfg, global_id_list=None, use_pinned_buffers=False):
for i, b in enumerate(mfg):
# map this block's source node IDs to rows of the memory / mailbox tensors
if global_id_list is not None:
idx = global_id_list[b.srcdata['ID']]
else:
idx = b.srcdata['ID']
if use_pinned_buffers:
idx = idx.cpu().long()
torch.index_select(self.node_memory, 0, idx, out=self.pinned_node_memory_buffs[i][:idx.shape[0]])
b.srcdata['mem'] = self.pinned_node_memory_buffs[i][:idx.shape[0]].cuda(non_blocking=True)
torch.index_select(self.node_memory_ts,0, idx, out=self.pinned_node_memory_ts_buffs[i][:idx.shape[0]])
b.srcdata['mem_ts'] = self.pinned_node_memory_ts_buffs[i][:idx.shape[0]].cuda(non_blocking=True)
torch.index_select(self.mailbox, 0, idx, out=self.pinned_mailbox_buffs[i][:idx.shape[0]])
b.srcdata['mem_input'] = self.pinned_mailbox_buffs[i][:idx.shape[0]].reshape(b.srcdata['ID'].shape[0], -1).cuda(non_blocking=True)
torch.index_select(self.mailbox_ts, 0, idx, out=self.pinned_mailbox_ts_buffs[i][:idx.shape[0]])
b.srcdata['mail_ts'] = self.pinned_mailbox_ts_buffs[i][:idx.shape[0]].cuda(non_blocking=True)
else:
b.srcdata['mem'] = self.node_memory[idx].cuda()
b.srcdata['mem_ts'] = self.node_memory_ts[idx].cuda()
b.srcdata['mem_input'] = self.mailbox[idx].cuda().reshape(b.srcdata['ID'].shape[0], -1)
b.srcdata['mail_ts'] = self.mailbox_ts[idx].cuda()
def update_memory(self, nid, memory, root_nodes, ts, global_node_list = None,neg_samples=1):
if nid is None:
return
num_true_src_dst = root_nodes.shape[0] // (neg_samples + 2) * 2
with torch.no_grad():
nid = nid[:num_true_src_dst].to(self.device)
memory = memory[:num_true_src_dst].to(self.device)
ts = ts[:num_true_src_dst].to(self.device)
self.node_memory[nid.long()] = memory
self.node_memory_ts[nid.long()] = ts
def update_mailbox(self, nid, memory, root_nodes, ts, edge_feats, block, global_node_list = None, neg_samples=1):
with torch.no_grad():
num_true_edges = root_nodes.shape[0] // (neg_samples + 2)
memory = memory.to(self.device)
if edge_feats is not None:
edge_feats = edge_feats.to(self.device)
if block is not None:
block = block.to(self.device)
# TGN/JODIE
if self.memory_param['deliver_to'] == 'self':
src = torch.from_numpy(root_nodes[:num_true_edges]).to(self.device)
dst = torch.from_numpy(root_nodes[num_true_edges:num_true_edges * 2]).to(self.device)
mem_src = memory[:num_true_edges]
mem_dst = memory[num_true_edges:num_true_edges * 2]
if self.dim_edge_feat > 0:
src_mail = torch.cat([mem_src, mem_dst, edge_feats], dim=1)
dst_mail = torch.cat([mem_dst, mem_src, edge_feats], dim=1)
else:
src_mail = torch.cat([mem_src, mem_dst], dim=1)
dst_mail = torch.cat([mem_dst, mem_src], dim=1)
mail = torch.cat([src_mail, dst_mail], dim=1).reshape(-1, src_mail.shape[1])
nid = torch.cat([src.unsqueeze(1), dst.unsqueeze(1)], dim=1).reshape(-1)
mail_ts = torch.from_numpy(ts[:num_true_edges * 2]).to(self.device)
# find unique nid to update mailbox
uni, inv = torch.unique(nid, return_inverse=True)
perm = torch.arange(inv.size(0), dtype=inv.dtype, device=inv.device)
perm = inv.new_empty(uni.size(0)).scatter_(0, inv, perm)
nid = nid[perm]
mail = mail[perm]
mail_ts = mail_ts[perm]
if self.memory_param['mail_combine'] == 'last':
self.mailbox[nid.long(), self.next_mail_pos[nid.long()]] = mail
self.mailbox_ts[nid.long(), self.next_mail_pos[nid.long()]] = mail_ts
if self.memory_param['mailbox_size'] > 1:
self.next_mail_pos[nid.long()] = torch.remainder(self.next_mail_pos[nid.long()] + 1, self.memory_param['mailbox_size'])
# APAN
elif self.memory_param['deliver_to'] == 'neighbors':
mem_src = memory[:num_true_edges]
mem_dst = memory[num_true_edges:num_true_edges * 2]
if self.dim_edge_feat > 0:
src_mail = torch.cat([mem_src, mem_dst, edge_feats], dim=1)
dst_mail = torch.cat([mem_dst, mem_src, edge_feats], dim=1)
else:
src_mail = torch.cat([mem_src, mem_dst], dim=1)
dst_mail = torch.cat([mem_dst, mem_src], dim=1)
mail = torch.cat([src_mail, dst_mail], dim=0)
mail = torch.cat([mail, mail[block.edges()[0].long()]], dim=0)
mail_ts = torch.from_numpy(ts[:num_true_edges * 2]).to(self.device)
mail_ts = torch.cat([mail_ts, mail_ts[block.edges()[0].long()]], dim=0)
if self.memory_param['mail_combine'] == 'mean':
(nid, idx) = torch.unique(block.dstdata['ID'], return_inverse=True)
mail = scatter(mail, idx, reduce='mean', dim=0)
mail_ts = scatter(mail_ts, idx, reduce='mean')
self.mailbox[nid.long(), self.next_mail_pos[nid.long()]] = mail
self.mailbox_ts[nid.long(), self.next_mail_pos[nid.long()]] = mail_ts
elif self.memory_param['mail_combine'] == 'last':
nid = block.dstdata['ID']
# find unique nid to update mailbox
uni, inv = torch.unique(nid, return_inverse=True)
perm = torch.arange(inv.size(0), dtype=inv.dtype, device=inv.device)
perm = inv.new_empty(uni.size(0)).scatter_(0, inv, perm)
nid = nid[perm]
mail = mail[perm]
mail_ts = mail_ts[perm]
self.mailbox[nid.long(), self.next_mail_pos[nid.long()]] = mail
self.mailbox_ts[nid.long(), self.next_mail_pos[nid.long()]] = mail_ts
else:
raise NotImplementedError
if self.memory_param['mailbox_size'] > 1:
if self.update_mail_pos is None:
self.next_mail_pos[nid.long()] = torch.remainder(self.next_mail_pos[nid.long()] + 1, self.memory_param['mailbox_size'])
else:
self.update_mail_pos[nid.long()] = 1
else:
raise NotImplementedError
def update_next_mail_pos(self):
if self.update_mail_pos is not None:
nid = torch.where(self.update_mail_pos == 1)[0]
self.next_mail_pos[nid] = torch.remainder(self.next_mail_pos[nid] + 1, self.memory_param['mailbox_size'])
self.update_mail_pos.fill_(0)
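# Hedged sketch of the per-batch MailBox flow; `loader`, `model`, `root_nodes`, `ts`,
# `eid` and `edge_feats` are assumed to come from the surrounding TGL-style pipeline
# and are not defined here.
#
#   mailbox = MailBox(memory_param, num_nodes, dim_edge_feat)
#   mailbox.move_to_gpu()
#   for mfgs, root_nodes, ts, eid in loader:
#       mailbox.prep_input_mails(mfgs[0])        # attach mem / mem_ts / mem_input / mail_ts
#       pred_pos, pred_neg = model(mfgs)         # the model's memory updater consumes them
#       # ... loss, backward, optimizer step ...
#       upd = model.memory_updater
#       mail_e = edge_feats[eid] if edge_feats is not None else None
#       mailbox.update_memory(upd.last_updated_nid, upd.last_updated_memory, root_nodes, upd.last_updated_ts)
#       mailbox.update_mailbox(upd.last_updated_nid, upd.last_updated_memory, root_nodes, ts, mail_e, mfgs[0][0])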
class GRUMemeoryUpdater(torch.nn.Module):
def __init__(self, memory_param, dim_in, dim_hid, dim_time, dim_node_feat):
super(GRUMemeoryUpdater, self).__init__()
self.dim_hid = dim_hid
self.dim_node_feat = dim_node_feat
self.memory_param = memory_param
self.dim_time = dim_time
self.updater = torch.nn.GRUCell(dim_in + dim_time, dim_hid)
self.last_updated_memory = None
self.last_updated_ts = None
self.last_updated_nid = None
self.delta_memory = 0
if dim_time > 0:
self.time_enc = TimeEncode(dim_time)
if memory_param['combine_node_feature']:
if dim_node_feat > 0 and dim_node_feat != dim_hid:
self.node_feat_map = torch.nn.Linear(dim_node_feat, dim_hid)
def forward(self, mfg):
for b in mfg:
if self.dim_time > 0:
time_feat = self.time_enc(b.srcdata['ts'] - b.srcdata['mem_ts'])
b.srcdata['mem_input'] = torch.cat([b.srcdata['mem_input'], time_feat], dim=1)
updated_memory = self.updater(b.srcdata['mem_input'], b.srcdata['mem'])
self.last_updated_ts = b.srcdata['ts'].detach().clone()
self.last_updated_memory = updated_memory.detach().clone()
self.last_updated_nid = b.srcdata['ID'].detach().clone()
if self.memory_param['combine_node_feature']:
if self.dim_node_feat > 0:
if self.dim_node_feat == self.dim_hid:
b.srcdata['h'] += updated_memory
else:
b.srcdata['h'] = updated_memory + self.node_feat_map(b.srcdata['h'])
else:
b.srcdata['h'] = updated_memory
class RNNMemeoryUpdater(torch.nn.Module):
def __init__(self, memory_param, dim_in, dim_hid, dim_time, dim_node_feat):
super(RNNMemeoryUpdater, self).__init__()
self.dim_hid = dim_hid
self.dim_node_feat = dim_node_feat
self.memory_param = memory_param
self.dim_time = dim_time
self.updater = torch.nn.RNNCell(dim_in + dim_time, dim_hid)
self.last_updated_memory = None
self.last_updated_ts = None
self.last_updated_nid = None
self.delta_memory = 0
if dim_time > 0:
self.time_enc = TimeEncode(dim_time)
if memory_param['combine_node_feature']:
if dim_node_feat > 0 and dim_node_feat != dim_hid:
self.node_feat_map = torch.nn.Linear(dim_node_feat, dim_hid)
def forward(self, mfg):
for b in mfg:
if self.dim_time > 0:
#print(b.srcdata['ts'].shape,b.srcdata['mem_ts'].shape)
time_feat = self.time_enc(b.srcdata['ts'] - b.srcdata['mem_ts'])
b.srcdata['mem_input'] = torch.cat([b.srcdata['mem_input'], time_feat], dim=1)
updated_memory = self.updater(b.srcdata['mem_input'], b.srcdata['mem'])
self.last_updated_ts = b.srcdata['ts'].detach().clone()
self.last_updated_memory = updated_memory.detach().clone()
self.last_updated_nid = b.srcdata['ID'].detach().clone()
if self.memory_param['combine_node_feature']:
if self.dim_node_feat > 0:
if self.dim_node_feat == self.dim_hid:
b.srcdata['h'] += updated_memory
else:
b.srcdata['h'] = updated_memory + self.node_feat_map(b.srcdata['h'])
else:
b.srcdata['h'] = updated_memory
class TransformerMemoryUpdater(torch.nn.Module):
def __init__(self, memory_param, dim_in, dim_out, dim_time, train_param):
super(TransformerMemoryUpdater, self).__init__()
self.memory_param = memory_param
self.dim_time = dim_time
self.att_h = memory_param['attention_head']
if dim_time > 0:
self.time_enc = TimeEncode(dim_time)
self.w_q = torch.nn.Linear(dim_out, dim_out)
self.w_k = torch.nn.Linear(dim_in + dim_time, dim_out)
self.w_v = torch.nn.Linear(dim_in + dim_time, dim_out)
self.att_act = torch.nn.LeakyReLU(0.2)
self.layer_norm = torch.nn.LayerNorm(dim_out)
self.mlp = torch.nn.Linear(dim_out, dim_out)
self.dropout = torch.nn.Dropout(train_param['dropout'])
self.att_dropout = torch.nn.Dropout(train_param['att_dropout'])
self.last_updated_memory = None
self.last_updated_ts = None
self.last_updated_nid = None
def forward(self, mfg):
for b in mfg:
Q = self.w_q(b.srcdata['mem']).reshape((b.num_src_nodes(), self.att_h, -1))
mails = b.srcdata['mem_input'].reshape((b.num_src_nodes(), self.memory_param['mailbox_size'], -1))
if self.dim_time > 0:
time_feat = self.time_enc(b.srcdata['ts'][:, None] - b.srcdata['mail_ts']).reshape((b.num_src_nodes(), self.memory_param['mailbox_size'], -1))
mails = torch.cat([mails, time_feat], dim=2)
K = self.w_k(mails).reshape((b.num_src_nodes(), self.memory_param['mailbox_size'], self.att_h, -1))
V = self.w_v(mails).reshape((b.num_src_nodes(), self.memory_param['mailbox_size'], self.att_h, -1))
att = self.att_act((Q[:,None,:,:]*K).sum(dim=3))
att = torch.nn.functional.softmax(att, dim=1)
att = self.att_dropout(att)
rst = (att[:,:,:,None]*V).sum(dim=1)
rst = rst.reshape((rst.shape[0], -1))
rst += b.srcdata['mem']
rst = self.layer_norm(rst)
rst = self.mlp(rst)
rst = self.dropout(rst)
rst = torch.nn.functional.relu(rst)
b.srcdata['h'] = rst
self.last_updated_memory = rst.detach().clone()
self.last_updated_nid = b.srcdata['ID'].detach().clone()
self.last_updated_ts = b.srcdata['ts'].detach().clone()
import torch
import dgl
from os.path import abspath, join, dirname
import sys
sys.path.insert(0, join(abspath(dirname(__file__))))
from layers import *
from memorys import *
class GeneralModel(torch.nn.Module):
def __init__(self, dim_node, dim_edge, sample_param, memory_param, gnn_param, train_param, combined=False):
super(GeneralModel, self).__init__()
self.dim_node = dim_node
self.dim_node_input = dim_node
self.dim_edge = dim_edge
self.sample_param = sample_param
self.memory_param = memory_param
if 'dim_out' not in gnn_param:
gnn_param['dim_out'] = memory_param['dim_out']
self.gnn_param = gnn_param
self.train_param = train_param
if memory_param['type'] == 'node':
if memory_param['memory_update'] == 'gru':
self.memory_updater = GRUMemeoryUpdater(memory_param, 2 * memory_param['dim_out'] + dim_edge, memory_param['dim_out'], memory_param['dim_time'], dim_node)
elif memory_param['memory_update'] == 'rnn':
self.memory_updater = RNNMemeoryUpdater(memory_param, 2 * memory_param['dim_out'] + dim_edge, memory_param['dim_out'], memory_param['dim_time'], dim_node)
elif memory_param['memory_update'] == 'transformer':
self.memory_updater = TransformerMemoryUpdater(memory_param, 2 * memory_param['dim_out'] + dim_edge, memory_param['dim_out'], memory_param['dim_time'], train_param)
else:
raise NotImplementedError
self.dim_node_input = memory_param['dim_out']
self.layers = torch.nn.ModuleDict()
if gnn_param['arch'] == 'transformer_attention':
for h in range(sample_param['history']):
self.layers['l0h' + str(h)] = TransfomerAttentionLayer(self.dim_node_input, dim_edge, gnn_param['dim_time'], gnn_param['att_head'], train_param['dropout'], train_param['att_dropout'], gnn_param['dim_out'], combined=combined)
for l in range(1, gnn_param['layer']):
for h in range(sample_param['history']):
self.layers['l' + str(l) + 'h' + str(h)] = TransfomerAttentionLayer(gnn_param['dim_out'], dim_edge, gnn_param['dim_time'], gnn_param['att_head'], train_param['dropout'], train_param['att_dropout'], gnn_param['dim_out'], combined=False)
elif gnn_param['arch'] == 'identity':
self.gnn_param['layer'] = 1
for h in range(sample_param['history']):
self.layers['l0h' + str(h)] = IdentityNormLayer(self.dim_node_input)
if 'time_transform' in gnn_param and gnn_param['time_transform'] == 'JODIE':
self.layers['l0h' + str(h) + 't'] = JODIETimeEmbedding(gnn_param['dim_out'])
else:
raise NotImplementedError
self.edge_predictor = EdgePredictor(gnn_param['dim_out'])
if 'combine' in gnn_param and gnn_param['combine'] == 'rnn':
self.combiner = torch.nn.RNN(gnn_param['dim_out'], gnn_param['dim_out'])
def forward(self, mfgs, metadata = None,neg_samples=1):
if self.memory_param['type'] == 'node':
self.memory_updater(mfgs[0])
out = list()
for l in range(self.gnn_param['layer']):
for h in range(self.sample_param['history']):
rst = self.layers['l' + str(l) + 'h' + str(h)](mfgs[l][h])
if 'time_transform' in self.gnn_param and self.gnn_param['time_transform'] == 'JODIE':
rst = self.layers['l0h' + str(h) + 't'](rst, mfgs[l][h].srcdata['mem_ts'], mfgs[l][h].srcdata['ts'])
if l != self.gnn_param['layer'] - 1:
mfgs[l + 1][h].srcdata['h'] = rst
else:
out.append(rst)
if self.sample_param['history'] == 1:
out = out[0]
else:
out = torch.stack(out, dim=0)
out = self.combiner(out)[0][-1, :, :]
# metadata: the ids need to be recorded during the earlier deduplication step
if self.gnn_param['use_src_emb'] or self.gnn_param['use_dst_emb']:
self.embedding = out.detach().clone()
else:
self.embedding = None
if metadata is not None:
#out = torch.cat((out[metadata['dst_pos_pos']],out[metadata['src_id_pos']],out[metadata['dst_neg_pos']]),0)
if self.gnn_param['dyrep']:
out = self.memory_updater.last_updated_memory
out = torch.cat((out[metadata['src_pos_index']],out[metadata['dst_pos_index']],out[metadata['src_neg_index']]),0)
return self.edge_predictor(out, neg_samples=neg_samples)
def get_emb(self, mfgs):
if self.memory_param['type'] == 'node':
self.memory_updater(mfgs[0])
out = list()
for l in range(self.gnn_param['layer']):
for h in range(self.sample_param['history']):
rst = self.layers['l' + str(l) + 'h' + str(h)](mfgs[l][h])
if 'time_transform' in self.gnn_param and self.gnn_param['time_transform'] == 'JODIE':
rst = self.layers['l0h' + str(h) + 't'](rst, mfgs[l][h].srcdata['mem_ts'], mfgs[l][h].srcdata['ts'])
if l != self.gnn_param['layer'] - 1:
mfgs[l + 1][h].srcdata['h'] = rst
else:
out.append(rst)
if self.sample_param['history'] == 1:
out = out[0]
else:
out = torch.stack(out, dim=0)
out = self.combiner(out)[0][-1, :, :]
return out
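# Hedged construction sketch for GeneralModel; the YAML path, the feature dimensions and
# `mfgs` are placeholders, and parse_config is the helper defined further below in this commit.
#
#   sample_param, memory_param, gnn_param, train_param = parse_config('path/to/config.yml')
#   model = GeneralModel(dim_node, dim_edge, sample_param, memory_param, gnn_param, train_param).cuda()
#   pred_pos, pred_neg = model(mfgs, metadata=None, neg_samples=1)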
class NodeClassificationModel(torch.nn.Module):
def __init__(self, dim_in, dim_hid, num_class):
super(NodeClassificationModel, self).__init__()
self.fc1 = torch.nn.Linear(dim_in, dim_hid)
self.fc2 = torch.nn.Linear(dim_hid, num_class)
def forward(self, x):
x = self.fc1(x)
x = torch.nn.functional.relu(x)
x = self.fc2(x)
return x
\ No newline at end of file
import yaml
import numpy as np
def parse_config(f):
conf = yaml.safe_load(open(f, 'r'))
sample_param = conf['sampling'][0]
memory_param = conf['memory'][0]
gnn_param = conf['gnn'][0]
train_param = conf['train'][0]
return sample_param, memory_param, gnn_param, train_param
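# Minimal config sketch accepted by parse_config, stated as an illustration rather than a
# full schema: each top-level key must map to a list and only its first entry is used. The
# nested keys shown are the ones read elsewhere in this commit.
#
#   sampling:
#     - history: 1
#       neighbor: [10]
#   memory:
#     - type: 'node'
#       dim_out: 100
#       dim_time: 100
#       mailbox_size: 1
#       memory_update: 'gru'
#       deliver_to: 'self'
#       mail_combine: 'last'
#       combine_node_feature: True
#   gnn:
#     - arch: 'transformer_attention'
#       layer: 1
#       att_head: 2
#       dim_time: 100
#       dim_out: 100
#       use_src_emb: False
#       use_dst_emb: False
#       dyrep: False
#   train:
#     - dropout: 0.1
#       att_dropout: 0.1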
class EarlyStopMonitor(object):
def __init__(self, max_round=3, higher_better=True, tolerance=1e-10):
self.max_round = max_round
self.num_round = 0
self.epoch_count = 0
self.best_epoch = 0
self.last_best = None
self.higher_better = higher_better
self.tolerance = tolerance
def early_stop_check(self, curr_val):
if not self.higher_better:
curr_val *= -1
if self.last_best is None:
self.last_best = curr_val
elif (curr_val - self.last_best) / np.abs(self.last_best) > self.tolerance:
self.last_best = curr_val
self.num_round = 0
self.best_epoch = self.epoch_count
else:
self.num_round += 1
self.epoch_count += 1
return self.num_round >= self.max_round
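# Hedged usage sketch: early_stop_check() is called once per epoch with the validation
# metric and returns True after `max_round` consecutive epochs without improvement.
# `num_epochs` and `evaluate` are placeholders.
#
#   monitor = EarlyStopMonitor(max_round=3, higher_better=True)
#   for epoch in range(num_epochs):
#       val_ap = evaluate(...)
#       if monitor.early_stop_check(val_ap):
#           break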
\ No newline at end of file
from .route import *
from .timeline import SequencePipe
from .layerpipe import LayerPipe, LayerDetach
from .sparse import *
\ No newline at end of file
import torch
import torch.nn as nn
import torch.autograd as autograd
from torch import Tensor
from typing import *
from abc import ABC, abstractmethod
from contextlib import contextmanager
from .route import Route, RouteWork
from .timeline.utils import vector_backward
from .utils import *
__all__ = [
"LayerPipe",
"LayerDetach",
]
class LayerPipe(ABC):
def __init__(self) -> None:
self._layer_id: Optional[int] = None
self._snapshot_id: Optional[int] = None
self._rts: List[LayerPipeRuntime] = []
@property
def layer_id(self) -> int:
assert self._layer_id is not None
return self._layer_id
@property
def snapshot_id(self) -> int:
assert self._snapshot_id is not None
return self._snapshot_id
def apply(self,
num_layers: int,
num_snapshots: int,
) -> Sequence[Sequence[Tensor]]:
runtime = LayerPipeRuntime(num_layers, num_snapshots, self)
self._rts.append(runtime)
return runtime.forward()
def backward(self):
for runtime in self._rts:
runtime.backward()
self._rts.clear()
def all_reduce(self, async_op: bool = False):
works = []
for _, net in self.get_model():
ws = all_reduce_gradients(net, async_op=async_op)
if async_op:
works.extend(ws)
ws = all_reduce_buffers(net, async_op=async_op)
if async_op:
works.extend(ws)
if async_op:
return works
def to(self, device: Any):
for _, net in self.get_model():
net.to(device)
return self
def get_model(self) -> Sequence[Tuple[str, nn.Module]]:
models = []
for key in dir(self):
if key in {"layer_id", "snapshot_id"}:
continue
val = getattr(self, key)
if isinstance(val, nn.Module):
models.append((key, val))
return tuple(models)
def parameters(self):
params: List[nn.Parameter] = []
for name, m in self.get_model():
params.extend(m.parameters())
return params
def register_route(self, *xs: Tensor):
for t in xs:
t.requires_route = True
@abstractmethod
def get_route(self) -> Route:
raise NotImplementedError
@abstractmethod
def layer_inputs(self,
inputs: Optional[Sequence[Tensor]] = None,
) -> Sequence[Tensor]:
raise NotImplementedError
@abstractmethod
def layer_forward(self,
inputs: Sequence[Tensor],
) -> Sequence[Tensor]:
raise NotImplementedError
@contextmanager
def _switch_layer(self,
layer_id: int,
snapshot_id: int,
):
saved_layer_id = self._layer_id
saved_snapshot_id = self._snapshot_id
self._layer_id = layer_id
self._snapshot_id = snapshot_id
try:
yield
finally:
self._layer_id = saved_layer_id
self._snapshot_id = saved_snapshot_id
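# Hedged sketch of a LayerPipe subclass; the route, per-snapshot features and GNN layers
# are placeholders, and only the three abstract hooks required by LayerPipeRuntime are shown.
#
#   class MyPipe(LayerPipe):
#       def __init__(self, route, feats, convs):
#           super().__init__()
#           self.route = route           # starrygl Route for cross-partition exchange
#           self.feats = feats           # list of per-snapshot input features
#           self.convs = convs           # nn.ModuleList with one module per layer
#       def get_route(self) -> Route:
#           return self.route
#       def layer_inputs(self, inputs=None):
#           x = self.feats[self.snapshot_id] if inputs is None else inputs[0]
#           self.register_route(x)       # mark x to be routed before the next layer
#           return (x,)
#       def layer_forward(self, inputs):
#           return (self.convs[self.layer_id](inputs[0]),)
#
#   outputs = MyPipe(route, feats, convs).apply(num_layers=len(convs), num_snapshots=len(feats))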
class LayerPipeRuntime:
def __init__(self,
num_layers: int,
num_snapshots: int,
program: LayerPipe,
) -> None:
self.num_layers = num_layers
self.num_snapshots = num_snapshots
self.program = program
self.ready_bw: Dict[Any, Union[LayerDetach, LayerRoute]] = {}
def forward(self) -> Sequence[Sequence[Tensor]]:
for op, layer_i, snap_i in ForwardFootprint(self.num_layers, self.num_snapshots):
if op == "sync":
xs = self.ready_bw[(layer_i - 1, snap_i, 1)].values() if layer_i > 0 else None
with self.program._switch_layer(layer_i, snap_i):
xs = self.program.layer_inputs(xs)
route = self.program.get_route()
self.ready_bw[(layer_i, snap_i, 0)] = LayerRoute(route, *xs)
elif op == "comp":
xs = self.ready_bw[(layer_i, snap_i, 0)].values()
with self.program._switch_layer(layer_i, snap_i):
xs = self.program.layer_forward(xs)
self.ready_bw[(layer_i, snap_i, 1)] = LayerDetach(*xs)
xs = []
for snap_i in range(self.num_snapshots):
layer_i = self.num_layers - 1
xs.append(self.ready_bw[(layer_i, snap_i, 1)].values())
return xs
def backward(self):
for op, layer_i, snap_i in BackwardFootprint(self.num_layers, self.num_snapshots):
if op == "sync":
self.ready_bw[(layer_i, snap_i, 0)].backward()
elif op == "comp":
if layer_i + 1 < self.num_layers:
self.ready_bw.pop((layer_i + 1, snap_i, 0)).wait_gradients()
self.ready_bw.pop((layer_i, snap_i, 1)).backward()
for snap_i in range(self.num_snapshots):
self.ready_bw.pop((0, snap_i, 0)).wait_gradients()
assert len(self.ready_bw) == 0
class LayerDetach:
def __init__(self,
*inputs: Tensor,
) -> None:
outputs = tuple(t.detach() for t in inputs)
for s, t in zip(inputs, outputs):
t.requires_grad_(s.requires_grad)
self._inputs = inputs
self._outputs = outputs
def values(self) -> Sequence[Tensor]:
return tuple(self._outputs)
def backward(self) -> None:
vec_loss, vec_grad = [], []
for s, t in zip(self._inputs, self._outputs):
g, t.grad = t.grad, None
if not s.requires_grad:
continue
vec_loss.append(s)
vec_grad.append(g)
vector_backward(vec_loss, vec_grad)
class LayerRoute:
def __init__(self,
route: Route,
*inputs: Tensor,
) -> None:
self._route = route
self._works: Optional[List[Union[Tensor, RouteWork]]] = []
for t in inputs:
r = t.requires_route if hasattr(t, "requires_route") else False
if r:
self._works.append(self._route.fw_tensor(t, async_op=True))
else:
self._works.append(t.detach())
self._inputs = inputs
self._outputs: Optional[List[Tensor]] = None
def values(self) -> Sequence[Tensor]:
if self._outputs is None:
works, self._works = self._works, None
assert works is not None
outputs = []
for s, t in zip(self._inputs, works):
if isinstance(t, RouteWork):
t = t.wait()
t = t.requires_grad_(s.requires_grad)
outputs.append(t)
self._outputs = outputs
return self._outputs
def backward(self):
assert self._works is None
assert self._outputs is not None
works = []
for s, t in zip(self._inputs, self._outputs):
g, t.grad = t.grad, None
rs = s.requires_route if hasattr(s, "requires_route") else False
rg = s.requires_grad
if rg and rs:
works.append(self._route.bw_tensor(g, async_op=True))
elif rg:
works.append(g)
else:
works.append(None)
self._works = works
self._outputs = None
def wait_gradients(self):
if self._works is None:
return
works, self._works = self._works, None
vec_loss, vec_grad = [], []
for t, g in zip(self._inputs, works):
if isinstance(g, RouteWork):
g = g.wait()
if not t.requires_grad:
continue
vec_loss.append(t)
vec_grad.append(g)
vector_backward(vec_loss, vec_grad)
class ForwardFootprint:
def __init__(self,
num_layers: int,
num_snapshots: int,
) -> None:
self._num_layers = num_layers
self._num_snapshots = num_snapshots
def __iter__(self):
if self._num_layers <= 0 or self._num_snapshots <= 0:
return
# starting
if self._num_snapshots > 1:
yield "sync", 0, 0
yield "sync", 0, 1
elif self._num_snapshots > 0:
yield "sync", 0, 0
for i in range(0, self._num_snapshots, 2):
for l in range(self._num_layers):
# snapshot i
yield "comp", l, i
if l + 1 < self._num_layers:
yield "sync", l + 1, i
elif i + 2 < self._num_snapshots:
yield "sync", 0, i + 2
# snapshot i + 1
if i + 1 >= self._num_snapshots:
continue
yield "comp", l, i + 1
if l + 1 < self._num_layers:
yield "sync", l + 1, i + 1
elif i + 3 < self._num_snapshots:
yield "sync", 0, i + 3
class BackwardFootprint:
def __init__(self,
num_layers: int,
num_snapshots: int,
) -> None:
self._num_layers = num_layers
self._num_snapshots = num_snapshots
def __iter__(self):
if self._num_layers <= 0 or self._num_snapshots <= 0:
return
for i in range(0, self._num_snapshots, 2):
for j in range(self._num_layers):
l = self._num_layers - j - 1
# snapshot i
yield "comp", l, i
yield "sync", l, i
# snapshot i + 1
if i + 1 >= self._num_snapshots:
continue
yield "comp", l, i + 1
yield "sync", l, i + 1
import torch
import torch.autograd as autograd
import torch.distributed as dist
from torch import Tensor
from typing import *
from starrygl.distributed.cclib import all_to_all_s, all_to_all_v
__all__ = [
"Route",
"RouteWork",
"RouteWorkCache",
"RouteAlltoAll",
]
class Route:
@staticmethod
def from_raw_indices(
src_ids: Tensor,
dst_ids: Tensor,
bipartite: bool = True,
group: Any = None,
) -> 'Route':
if group is None:
group = dist.GroupMember.WORLD
fw_tables, bw_tables = Route._build_route_tables(
src_ids=src_ids, dst_ids=dst_ids,
bipartite=bipartite, group=group,
)
return Route(
src_len=src_ids.size(0),
dst_len=dst_ids.size(0),
**Route.__tables_to_indptr(fw_tables, bw_tables),
group=group,
)
def filter(self,
dst_mask: Optional[Tensor] = None,
src_mask: Optional[Tensor] = None,
remap: bool = False,
):
if dst_mask is None:
if src_mask is None:
raise ValueError("please provide at least one parameter.")
else:
assert src_mask.dtype == torch.bool
assert src_mask.numel() == self.src_len
dst_mask = self.bw_tensor(src_mask.long()) != 0
else:
assert dst_mask.dtype == torch.bool
assert dst_mask.numel() == self.dst_len
tmp_src_mask = self.fw_tensor(dst_mask.long()) != 0
if src_mask is None:
src_mask = tmp_src_mask
else:
tmp_src_mask &= src_mask
src_mask = tmp_src_mask
dst_mask = self.bw_tensor(src_mask.long()) != 0
fw_ptr, fw_ind = Route.__filter_ind_and_ptr(self._fw_ptr, self._fw_ind, dst_mask)
bw_ptr, bw_ind = Route.__filter_ind_and_ptr(self._bw_ptr, self._bw_ind, src_mask)
route = Route(
src_len=self.src_len,
dst_len=self.dst_len,
fw_ptr=fw_ptr, fw_ind=fw_ind,
bw_ptr=bw_ptr, bw_ind=bw_ind,
group=self.group,
)
if remap:
fw_ind, dst_len = Route.__remap_ind(route._fw_ind, dst_mask)
bw_ind, src_len = Route.__remap_ind(route._bw_ind, src_mask)
route = Route(
src_len=src_len,
dst_len=dst_len,
fw_ptr=route._fw_ptr, fw_ind=fw_ind,
bw_ptr=route._bw_ptr, bw_ind=bw_ind,
group=route.group,
)
return dst_mask, src_mask, route
def rev(self):
return Route(
src_len=self.dst_len,
dst_len=self.src_len,
fw_ptr=self._bw_ptr, fw_ind=self._bw_ind,
bw_ptr=self._fw_ptr, bw_ind=self._fw_ind,
group=self.group,
)
def __init__(self,
src_len: int, dst_len: int,
fw_ptr: List[int], fw_ind: Tensor,
bw_ptr: List[int], bw_ind: Tensor,
group: Any,
) -> None:
assert len(fw_ptr) == len(bw_ptr)
self._src_len = src_len
self._dst_len = dst_len
self._fw_ptr = tuple(fw_ptr)
self._fw_ind = fw_ind
self._bw_ptr = tuple(bw_ptr)
self._bw_ind = bw_ind
self._group = group
@property
def group(self):
return self._group
@property
def src_len(self):
return self._src_len
@property
def dst_len(self):
return self._dst_len
@property
def part_id(self) -> int:
return dist.get_rank(self.group)
@property
def num_parts(self) -> int:
return dist.get_world_size(self.group)
def to(self, device: Any):
self._fw_ind = self._fw_ind.to(device)
self._bw_ind = self._bw_ind.to(device)
return self
# def fw_table(self, i: int):
# return self._fw_ind[self._fw_ptr[i]:self._fw_ptr[i+1]]
# def bw_table(self, i: int):
# return self._bw_ind[self._bw_ptr[i]:self._bw_ptr[i+1]]
def apply(self,
data: Tensor,
cache: Optional['RouteWorkCache'] = None,
cache_key: Optional[str] = None,
) -> Tensor:
return RouteAlltoAll.apply(data, self, cache, cache_key)
@torch.no_grad()
def fw_tensor(self, data: Tensor, async_op: bool = False):
assert data.size(0) == self.dst_len
output_tensor = torch.empty(
self._bw_ind.numel(), *data.shape[1:],
dtype=data.dtype, device=data.device,
)
work = all_to_all_s(
output_tensor, data[self._fw_ind],
self._bw_ptr, self._fw_ptr,
group=self.group,
async_op=async_op,
)
work = RouteWork(
work if async_op else None,
self._bw_ptr, self._bw_ind,
self.src_len, output_tensor,
)
return work if async_op else work.wait()
@torch.no_grad()
def bw_tensor(self, data: Tensor, async_op: bool = False):
assert data.size(0) == self.src_len
output_tensor = torch.empty(
self._fw_ind.numel(), *data.shape[1:],
dtype=data.dtype, device=data.device,
)
work = all_to_all_s(
output_tensor, data[self._bw_ind],
self._fw_ptr, self._bw_ptr,
group=self.group,
async_op=async_op,
)
work = RouteWork(
work if async_op else None,
self._fw_ptr, self._fw_ind,
self.dst_len, output_tensor,
)
return work if async_op else work.wait()
@torch.no_grad()
def get_src_part_ids(self) -> Tensor:
input_tensor = torch.full_like(self._fw_ind, self.part_id)
output_tensor = torch.empty_like(self._bw_ind)
all_to_all_s(
output_tensor, input_tensor,
self._bw_ptr, self._fw_ptr,
group=self.group,
)
out = torch.full(
(self.src_len,), 2**16-1,
dtype=self._bw_ind.dtype,
device=self._bw_ind.device,
)
for s, t in zip(self._bw_ptr, self._bw_ptr[1:]):
ind = self._bw_ind[s:t]
assert (out[ind] == 2**16-1).all(), f"some vertices exist in more than one partition"
out[ind] = output_tensor[s:t] & 0xFF
return out
@staticmethod
def _build_route_tables(
src_ids: Tensor,
dst_ids: Tensor,
bipartite: bool,
group: Any,
) -> Tuple[List[Tensor], List[Tensor]]:
assert src_ids.dtype == torch.long
assert dst_ids.dtype == torch.long
assert src_ids.dim() == 1
assert dst_ids.dim() == 1
src_len = src_ids.size(0)
dst_len = dst_ids.size(0)
if not bipartite:
assert dst_len <= src_len
assert (src_ids[:dst_len] == dst_ids).all()
rank = dist.get_rank(group)
world_size = dist.get_world_size(group)
ikw = dict(dtype=torch.long, device=dst_ids.device)
all_dst_lens: List[int] = [None] * world_size
dist.all_gather_object(all_dst_lens, dst_len, group=group)
# all_reduce number of nodes
num_nodes = torch.zeros(1, **ikw)
if src_ids.numel() > 0:
num_nodes = src_ids.max().max(num_nodes)
if dst_ids.numel() > 0:
num_nodes = dst_ids.max().max(num_nodes)
dist.all_reduce(num_nodes, op=dist.ReduceOp.MAX, group=group)
num_nodes = num_nodes.item() + 1
# src_ids -> local index
smp: Tensor = torch.empty(num_nodes, **ikw).fill_((2**62-1)*2+1)
smp[src_ids] = torch.arange(src_len, **ikw)
# async fetch dst_ids from other partitions
all_dst_ids: List[Tensor] = [None] * world_size
all_dst_get = [None] * world_size
def async_fetch(i: int):
if i == rank:
all_dst_ids[i] = dst_ids
else:
all_dst_ids[i] = torch.empty(all_dst_lens[i], **ikw)
src_rank = dist.get_global_rank(group, i)
all_dst_get[i] = dist.broadcast(
all_dst_ids[i], src=src_rank,
async_op=True, group=group,
)
fw_tables: List[Tensor] = []
bw_tables: List[Tensor] = []
xmp = torch.empty_like(smp)
for i in range(world_size):
# prefetch dst_ids
if i == 0:
async_fetch(i)
if i + 1 < world_size:
async_fetch(i + 1)
all_dst_get[i].wait()
ids = all_dst_ids[i]
xmp.fill_(0)
xmp[ids] += 1
xmp[src_ids] += 1
ind = torch.where(xmp > 1)[0]
# dst_ids -> local index
xmp.fill_((2**62-1)*2+1)
xmp[ids] = torch.arange(ids.size(0), **ikw)
# remap src_ids and dst_ids
src_ind = smp[ind]
dst_ind = xmp[ind]
fw_tables.append(dst_ind)
bw_tables.append(src_ind)
fw_tables = Route.__backward_fw_tables(fw_tables, group=group)
# add self-loops if not bipartite graph
if not bipartite:
rank_ind = torch.arange(dst_len, **ikw)
fw_tables[rank] = bw_tables[rank] = rank_ind
return fw_tables, bw_tables
@staticmethod
def __filter_ind_and_ptr(ptr: List[int], ind: Tensor, mask: Tensor) -> Tuple[List[int], Tensor]:
m = mask[ind]
new_ptr: List[int] = [0]
new_ind: List[Tensor] = []
for s, t in zip(ptr, ptr[1:]):
new_ind.append(ind[s:t][m[s:t]])
new_ptr.append(new_ptr[-1] + new_ind[-1].numel())
return new_ptr, torch.cat(new_ind, dim=0)
@staticmethod
def __remap_ind(ind: Tensor, mask: Tensor) -> Tuple[Tensor, int]:
idx = torch.where(mask)[0]
imp = torch.full((mask.numel(),), (2**62-1)*2+1, dtype=ind.dtype, device=ind.device)
imp[idx] = torch.arange(idx.numel(), dtype=ind.dtype, device=ind.device)
return imp[ind], idx.numel()
# n: int = mask.count_nonzero().item()
# imp = torch.full((mask.numel(),), (2**62-1)*2+1, dtype=ind.dtype, device=ind.device)
# imp[mask] = torch.arange(n, dtype=ind.dtype, device=ind.device)
# return ind, int(n)
@staticmethod
def __backward_fw_tables(
fw_tables: List[Tensor],
group: Any,
) -> List[Tensor]:
rank = dist.get_rank(group)
world_size = dist.get_world_size(group)
send_sizes = [t.size() for t in fw_tables]
recv_sizes = [None] * world_size
dist.all_gather_object(recv_sizes, send_sizes, group=group)
recv_sizes = [s[rank] for s in recv_sizes]
fixed_tables = []
for s, t in zip(recv_sizes, fw_tables):
t = torch.empty(*s, dtype=t.dtype, device=t.device)
fixed_tables.append(t)
all_to_all_v(fixed_tables, fw_tables, group=group)
return fixed_tables
@staticmethod
def __tables_to_indptr(
fw_tables: List[Tensor],
bw_tables: List[Tensor],
):
fw_ptr: List[int] = [0]
for t in fw_tables:
assert t.dim() == 1
fw_ptr.append(fw_ptr[-1] + t.numel())
fw_ind = torch.cat(fw_tables, dim=0)
bw_ptr: List[int] = [0]
for t in bw_tables:
assert t.dim() == 1
bw_ptr.append(bw_ptr[-1] + t.numel())
bw_ind = torch.cat(bw_tables, dim=0)
return {
"fw_ptr": fw_ptr, "bw_ptr": bw_ptr,
"fw_ind": fw_ind, "bw_ind": bw_ind,
}
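# Hedged usage sketch for Route; src_ids / dst_ids are this partition's local vertex id
# tensors and x is a dst-ordered feature tensor, all placeholders. RouteWorkCache is
# defined below.
#
#   route = Route.from_raw_indices(src_ids, dst_ids)     # defaults to the WORLD group
#   y = route.apply(x)                   # x: (dst_len, ...) -> y: (src_len, ...), differentiable
#   cache = RouteWorkCache()
#   y = route.apply(x, cache, "layer0")  # caching returns the previously issued transfer
#                                        # and keeps the new one in flight across calls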
class RouteWork:
def __init__(self,
work: Any,
ptr: List[int],
ind: Tensor,
len: int,
recv_t: Tensor,
) -> None:
self._work = work
self._ptr = ptr
self._ind = ind
self._len = len
self._recv_t = recv_t
if self._work is None:
self._reduce()
def _reduce(self):
out = torch.zeros(
self._len, *self._recv_t.shape[1:],
dtype=self._recv_t.dtype,
device=self._recv_t.device,
)
for s, t in zip(self._ptr, self._ptr[1:]):
ind = self._ind[s:t]
out[ind] += self._recv_t[s:t]
self._work = None
self._ptr = None
self._ind = None
self._len = None
self._recv_t = out
def wait(self) -> Tensor:
if self._work is None:
return self._recv_t
self._work.wait()
self._reduce()
return self._recv_t
class RouteWorkCache:
def __init__(self,
enable_fw: bool = True,
enable_bw: bool = True,
) -> None:
self.enable_fw = enable_fw
self.enable_bw = enable_bw
self._cached_works: Dict[str, RouteWork] = {}
def enable_fw_(self, enable: bool = True):
self.enable_fw = enable
return self
def enable_bw_(self, enable: bool = True):
self.enable_bw = enable
return self
def wait(self):
for work in self._cached_works.values():
work.wait()
def clear(self):
self._cached_works.clear()
def get_and_set(self,
key: str,
work: RouteWork,
bw: bool = False,
) -> Optional[RouteWork]:
if bw and self.enable_bw:
key = key + "_bw"
elif not bw and self.enable_fw:
key = key + "_fw"
else:
return work
t = self._cached_works.get(key, work)
self._cached_works[key] = work
return t
class RouteAlltoAll(autograd.Function):
@staticmethod
def forward(
ctx: autograd.function.FunctionCtx,
x: Tensor,
route: Route,
cache: Optional[RouteWorkCache],
cache_key: Optional[str],
):
ctx.saved_route = route
ctx.saved_cache = cache
ctx.saved_cache_key = cache_key
if cache is None or cache_key is None:
return route.fw_tensor(x)
else:
work = route.fw_tensor(x, async_op=True)
work = cache.get_and_set(cache_key, work, bw=False)
return work.wait()
@staticmethod
def backward(
ctx: autograd.function.FunctionCtx,
grad: Tensor,
) -> Tensor:
route: Route = ctx.saved_route
cache: Optional[RouteWorkCache] = ctx.saved_cache
cache_key: Optional[str] = ctx.saved_cache_key
if cache is None or cache_key is None:
return route.bw_tensor(grad), None, None, None
else:
work = route.bw_tensor(grad, async_op=True)
work = cache.get_and_set(cache_key, work, bw=True)
return work.wait(), None, None, None
\ No newline at end of file
from typing import Any
import torch
import torch.distributed as dist
import torch.autograd as autograd
from torch import Tensor
from typing import *
from torch_sparse import SparseTensor
__all__ = [
"SparseBlocks",
]
class SparseBlocks:
@staticmethod
def from_raw_indices(
dst_ids: Tensor,
edge_index: Tensor,
src_ids: Optional[Tensor] = None,
edge_attr: Optional[Tensor] = None,
group: Any = None,
) -> 'SparseBlocks':
assert edge_index.dim() == 2 and edge_index.size(0) == 2
if src_ids is None:
src_ids = dst_ids
src_ids, src_ptr = SparseBlocks.__fetch_ids_sizes(src_ids, group=group)
adj_ts = SparseBlocks.__remap_adj_t(dst_ids, edge_index, src_ids, src_ptr, edge_attr)
return SparseBlocks(adj_ts, group=group)
def __init__(self, adj_ts: List[SparseTensor], group: Any) -> None:
self._adj_ts = adj_ts
self._group = group
def adj_t(self, i: int) -> SparseTensor:
return self._adj_ts[i]
@property
def group(self):
return self._group
@property
def part_id(self) -> int:
return dist.get_rank(self._group)
@property
def num_parts(self) -> int:
return dist.get_world_size(self._group)
def apply(self, x: Tensor) -> Tensor:
return SparseBlockMM.apply(self, x)
@staticmethod
def __fetch_ids_sizes(local_ids: Tensor, group: Any):
assert local_ids.dim() == 1
if group is None:
group = dist.GroupMember.WORLD
rank = dist.get_rank(group)
world_size = dist.get_world_size(group)
ikw = dict(dtype=torch.long, device=local_ids.device)
all_lens: List[int] = [None] * world_size
dist.all_gather_object(all_lens, local_ids.numel(), group=group)
# all reduce num_nodes
num_nodes = torch.zeros(1, **ikw)
if local_ids.numel() > 0:
num_nodes = local_ids.max().max(num_nodes)
dist.all_reduce(num_nodes, op=dist.ReduceOp.MAX, group=group)
num_nodes: int = num_nodes.item() + 1
# async fetch remote ids
all_ids: List[Tensor] = [None] * world_size
all_get = [None] * world_size
def async_fetch(i: int):
if i == rank:
all_ids[i] = local_ids
else:
all_ids[i] = torch.empty(all_lens[i], **ikw)
src = dist.get_global_rank(group, i)
all_get[i] = dist.broadcast(
all_ids[i], src=src, async_op=True, group=group
)
imp: Tensor = torch.full((num_nodes,), (2**62-1)*2+1, **ikw)
offset: int = 0
for i in range(world_size):
if i == 0:
async_fetch(i)
if i + 1 < world_size:
async_fetch(i + 1)
all_get[i].wait()
ids = all_ids[i]
assert (imp[ids] == (2**62-1)*2+1).all(), "all ids must be orthogonal."
imp[ids] = torch.arange(offset, offset + all_lens[i], **ikw)
offset += all_lens[i]
assert (imp != (2**62-1)*2+1).all(), "some points that do not exist."
ids = torch.cat(all_ids, dim=0)
ptr: List[int] = [0]
for s in all_lens:
ptr.append(ptr[-1] + s)
return ids, ptr
@staticmethod
def __remap_adj_t(
dst_ids: Tensor,
edge_index: Tensor,
src_ids: Tensor,
src_ptr: List[int],
edge_attr: Optional[Tensor],
) -> List[SparseTensor]:
ikw = dict(dtype=torch.long, device=dst_ids.device)
imp: Tensor = torch.full((dst_ids.max().item()+1,), (2**62-1)*2+1, **ikw)
imp[dst_ids] = torch.arange(dst_ids.numel(), **ikw)
dst = imp[edge_index[1]]
assert (dst != (2**62-1)*2+1).all()
imp: Tensor = torch.full((src_ids.max().item()+1,), (2**62-1)*2+1, **ikw)
imp[src_ids] = torch.arange(src_ids.numel(), **ikw)
src = imp[edge_index[0]]
assert (src != (2**62-1)*2+1).all()
edge_index = torch.vstack([src, dst])
adj = SparseTensor.from_edge_index(
edge_index=edge_index,
edge_attr=edge_attr,
sparse_sizes=(src_ids.numel(), dst_ids.numel()),
)
adj_ts: List[SparseTensor] = []
for s, t in zip(src_ptr, src_ptr[1:]):
adj_ts.append(adj[s:t].t())
return adj_ts
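# Hedged usage sketch for SparseBlocks in the non-bipartite case; dst_ids, edge_index and
# x are this partition's local shard and feature matrix, all placeholders.
#
#   sp = SparseBlocks.from_raw_indices(dst_ids, edge_index)   # defaults to the WORLD group
#   out = sp.apply(x)   # distributed SpMM: x and out both have dst_ids.numel() rows
#   # apply() is differentiable; the backward pass reduces gradients back to the owning rank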
class SparseBlockMM(autograd.Function):
@staticmethod
def forward(
ctx: autograd.function.FunctionCtx,
sp: SparseBlocks,
x: Tensor,
):
part_id = sp.part_id
num_parts = sp.num_parts
group = sp.group
if group is None:
group = dist.GroupMember.WORLD
def async_fetch(i: int):
n = sp.adj_t(i).sparse_size(1)
if i == part_id:
h = x.clone()
else:
h = torch.empty(n, *x.shape[1:], dtype=x.dtype, device=x.device)
src = dist.get_global_rank(group, i)
return dist.broadcast(h, src=src, group=sp.group, async_op=True)
last_work = None
out = None
for i in range(num_parts):
if i == 0:
work = async_fetch(0)
else:
work = last_work
if i + 1 < sp.num_parts:
last_work = async_fetch(i + 1)
work.wait()
h, = work.result()
if out is None:
out = sp.adj_t(i) @ h
else:
out += sp.adj_t(i) @ h
ctx.saved_sp = sp
return out
@staticmethod
def backward(
ctx: autograd.function.FunctionCtx,
grad: Tensor,
):
sp: SparseBlocks = ctx.saved_sp
part_id = sp.part_id
num_parts = sp.num_parts
group = sp.group
if group is None:
group = dist.GroupMember.WORLD
def async_reduce(i: int, g: Tensor):
dst = dist.get_global_rank(group, i)
return dist.reduce(
g, dst=dst, op=dist.ReduceOp.SUM,
group=sp.group, async_op=True,
)
out = None
last_work = None
for i in range(num_parts):
g = sp.adj_t(i).t() @ grad
if i > 0:
last_work.wait()
last_work = async_reduce(i, g)
if i == part_id:
out = g
if last_work is not None:
last_work.wait()
return None, out
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.autograd as autograd
import torch.distributed as dist
from torch import Tensor
from typing import *
__all__ = [
"SyncBatchNorm",
]
class SyncBatchNorm(nn.Module):
def __init__(self,
num_features: Union[int, torch.Size],
eps: float = 1e-5,
momentum: float = 0.1,
) -> None:
super().__init__()
self.register_buffer("running_mean", torch.zeros(num_features))
self.register_buffer("running_var", torch.ones(num_features))
self.weight = nn.Parameter(torch.ones(num_features))
self.bias = nn.Parameter(torch.zeros(num_features))
self.num_features = num_features
self.eps = eps
self.momentum = momentum
def forward(self, x: Tensor) -> Tensor:
return SyncBatchNormFunction.apply(
x,
self.running_mean, self.running_var,
self.weight, self.bias,
self.training, self.momentum, self.eps,
)
def reset_parameters(self):
self.running_mean.data.fill_(0.0)
self.running_var.data.fill_(1.0)
self.weight.data.fill_(1.0)
self.bias.data.fill_(0.0)
@classmethod
def convert_sync_batchnorm(cls, net: nn.Module) -> nn.Module:
if isinstance(net, nn.modules.batchnorm._BatchNorm):
new_net = SyncBatchNorm(
num_features=net.num_features,
eps=net.eps,
momentum=net.momentum,
).to(net.weight.device)
new_net.weight.data[:] = net.weight.data
new_net.bias.data[:] = net.bias.data
new_net.get_buffer("running_mean").data[:] = net.get_buffer("running_mean").data
new_net.get_buffer("running_var").data[:] = net.get_buffer("running_var").data
return new_net
else:
for name, child in list(net.named_children()):
net.add_module(name, cls.convert_sync_batchnorm(child))
return net
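# Hedged usage sketch: convert_sync_batchnorm swaps every torch BatchNorm layer in a model
# for this distributed SyncBatchNorm, copying weights and running statistics; the model
# below is a placeholder.
#
#   net = nn.Sequential(nn.Linear(16, 32), nn.BatchNorm1d(32), nn.ReLU())
#   net = SyncBatchNorm.convert_sync_batchnorm(net)
#   # batch statistics are now all-reduced over the default process group in forward()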
class SyncBatchNormFunction(autograd.Function):
@staticmethod
def forward(
ctx: autograd.function.FunctionCtx,
x: Tensor,
running_mean: Tensor,
running_var: Tensor,
weight: Tensor,
bias: Tensor,
training: bool,
momentum: float,
eps: float,
):
if not training:
mean = running_mean
var = running_var
else:
ws = torch.zeros(1, dtype=x.dtype, device=x.device) + x.size(0)
ws_req = dist.all_reduce(ws, op=dist.ReduceOp.SUM, async_op=True)
if x.size(0) > 0:
sum_x = x.sum(dim=0)
else:
sum_x = torch.zeros(
size=(1,) + x.shape[1:],
dtype=x.dtype,
device=x.device,
)
sum_x_req = dist.all_reduce(sum_x, op=dist.ReduceOp.SUM, async_op=True)
if x.size(0) > 0:
sum_x2 = (x**2).sum(dim=0)
else:
sum_x2 = torch.zeros(
size=(1,) + x.shape[1:],
dtype=x.dtype,
device=x.device,
)
sum_x2_req = dist.all_reduce(sum_x2, op=dist.ReduceOp.SUM, async_op=True)
ws_req.wait()
whole_size = ws.item()
sum_x_req.wait()
mean = sum_x / whole_size
sum_x2_req.wait()
# var = sum_x2 / whole_size - mean ** 2
var = (sum_x2 - mean * sum_x) / whole_size
running_mean.mul_(1 - momentum).add_(mean * momentum)
running_var.mul_(1 - momentum).add_(var * momentum)
std = torch.sqrt(var + eps)
x_hat = (x - mean) / std
if training:
ctx.save_for_backward(x_hat, weight, std)
ctx.whole_size = whole_size
return x_hat * weight + bias
@staticmethod
def backward(
ctx: autograd.function.FunctionCtx,
grad: Tensor,
):
x_hat, weight, std = ctx.saved_tensors
dbias = grad.sum(dim=0)
dbias_req = dist.all_reduce(dbias, op=dist.ReduceOp.SUM, async_op=True)
dweight = (grad * x_hat).sum(dim=0)
dweight_req = dist.all_reduce(dweight, op=dist.ReduceOp.SUM, async_op=True)
dbias_req.wait()
dweight_req.wait()
n = ctx.whole_size
dx = (weight / n) / std * (n * grad - dbias - x_hat * dweight)
return dx, None, None, dweight, dbias, None, None, None
from .pipe import SequencePipe
\ No newline at end of file
import torch
import torch.nn as nn
import torch.autograd as autograd
import torch.distributed as dist
from torch import Tensor
from typing import *
from abc import ABC, abstractmethod
from contextlib import contextmanager
from .sync import VirtualMotions, VirtualForward, BatchSync
from .utils import vector_backward
from starrygl.parallel.utils import *
class SequencePipe(ABC):
def __init__(self) -> None:
super().__init__()
self._pos_begin = 0
self._pos_end = 0
@abstractmethod
def get_group(self) -> Any:
raise NotImplementedError
@abstractmethod
def get_init_states(self) -> Union[Tensor, Sequence[Tensor]]:
raise NotImplementedError
@abstractmethod
def forward(self,
inputs: Sequence[Tensor],
states: Sequence[Tensor],
) -> Tuple[
Sequence[Tensor],
Sequence[Tensor],
]:
raise NotImplementedError
def loss_fn(self,
inputs: Sequence[Tensor],
labels: Sequence[Tensor],
) -> Tensor:
raise NotImplementedError
def get_ranks(self) -> Sequence[int]:
world_size = dist.get_world_size(self.get_group())
return tuple(range(world_size))
def get_model(self) -> Sequence[Tuple[str, nn.Module]]:
models = []
for key in dir(self):
val = getattr(self, key)
if isinstance(val, nn.Module):
models.append((key, val))
return tuple(models)
def parameters(self):
params: List[nn.Parameter] = []
for name, m in self.get_model():
params.extend(m.parameters())
return params
def to(self, device: Any):
for _, net in self.get_model():
net.to(device)
return self
def all_reduce(self, async_op: bool = False):
works = []
for name, net in self.get_model():
ws = all_reduce_gradients(net, async_op=async_op)
if async_op:
works.extend(ws)
ws = all_reduce_buffers(net, async_op=async_op)
if async_op:
works.extend(ws)
if async_op:
return ws
def apply(self, bs: int, *inputs: Tensor) -> Sequence[Tensor]:
runtime = SequencePipeRuntime(bs, self)
return SequencePipeFunction.apply(runtime, *inputs)
def fast_backward(self,
bs: int,
inputs: Sequence[Tensor],
labels: Sequence[Tensor],
) -> Optional[Tensor]:
runtime = SequencePipeRuntime(bs, self, use_fast_backward=True)
inputs_grads = runtime.backward(inputs, labels)
vector_backward(inputs, inputs_grads)
return runtime.acc_loss
@property
def begin(self) -> int:
return self._pos_begin
@property
def end(self) -> int:
return self._pos_end
@property
def batch_size(self) -> int:
return self._pos_end - self._pos_begin
@contextmanager
def _switch_batch(self, begin: int, end: int):
saved_begin = self._pos_begin
saved_end = self._pos_end
self._pos_begin = begin
self._pos_end = end
try:
yield
finally:
self._pos_begin = saved_begin
self._pos_end = saved_end
class SequencePipeRuntime:
def __init__(self,
micro_batch_size: int,
program: SequencePipe,
use_fast_backward: bool = False,
) -> None:
self.micro_batch_size = micro_batch_size
self.program = program
self.use_fast_backward = use_fast_backward
self.acc_loss = None
self.group = program.get_group()
self.ranks = program.get_ranks()
self.index = self.ranks.index(dist.get_rank(self.group))
self._last_work = None
def detach_inputs(self, inputs: Sequence[Tensor]) -> Sequence[Tensor]:
detach = []
for t in inputs:
assert t.size(0) == inputs[0].size(0), "The first dimension of all tensors must be the same."
detach.append(t.detach())
return detach
def get_begin_end(self, i: int, n: int) -> Tuple[int, int]:
begin = i * self.micro_batch_size
end = min(n, begin + self.micro_batch_size)
return begin, end
def get_batch_sync(self, tensors: Sequence[Tensor], device: Any) -> BatchSync:
return BatchSync(
*tensors,
seq_index=self.index,
seq_ranks=self.ranks,
group=self.group, device=device,
)
def forward(self, inputs: Sequence[Tensor]) -> Sequence[Tensor]:
detach = self.detach_inputs(inputs)
N = inputs[0].size(0)
S = (N + self.micro_batch_size - 1) // self.micro_batch_size
motion = VirtualForward(self.index, len(self.ranks), S, batch_vsz=3)
outputs = None
ready_recv: Dict[int, BatchSync] = {}
ready_send: Dict[int, BatchSync] = {}
while not motion.finished:
for op, i in motion.step_comp():
begin, end = self.get_begin_end(i, N)
if op == "forward":
batch_inputs = self.get_batch_inputs(begin, end, detach)
if self.index > 0:
batch_states = ready_recv.pop(i).wait_state()
else:
batch_states = self.get_batch_states(begin, end)
with self.program._switch_batch(begin, end):
batch_outputs, batch_states = \
self.program.forward(batch_inputs, batch_states)
if self.index + 1 < len(self.ranks):
ready_send[i] = self.get_batch_sync(
batch_states,
device=detach[0].device,
)
del batch_inputs, batch_states
if outputs is None:
outputs = []
for t in batch_outputs:
t = torch.empty(N, *t.shape[1:], dtype=t.dtype, device=t.device)
outputs.append(t)
outputs = tuple(outputs)
for t, b in zip(outputs, batch_outputs):
t[begin:end] = b.detach()
del batch_outputs
for op, type, i in motion.step_sync():
assert type == "state"
begin, end = self.get_begin_end(i, N)
if op == "send":
ready_send.pop(i).send_state()
elif op == "recv":
ready_recv[i] = self.get_batch_sync(
self.get_batch_states(begin, end),
device=detach[0].device,
)
ready_recv[i].recv_state()
assert not ready_recv
assert not ready_send
return outputs
def backward(self,
inputs: Sequence[Tensor],
gradients: Sequence[Tensor],
) -> Sequence[Tensor]:
detach = self.detach_inputs(inputs)
detach_grads = self.detach_inputs(gradients)
N = inputs[0].size(0)
S = (N + self.micro_batch_size - 1) // self.micro_batch_size
motions = VirtualMotions(self.index, len(self.ranks), S, batch_vsz=3)
ready_recv_s: Dict[int, BatchSync] = {}
ready_recv_g: Dict[int, BatchSync] = {}
ready_send_s: Dict[int, BatchSync] = {}
ready_send_g: Dict[int, BatchSync] = {}
ready_bw_cmp = {}
input_grads = [None] * len(detach)
while not motions.finished:
for op, i in motions.step_comp():
begin, end = self.get_begin_end(i, N)
if op == "forward":
batch_inputs = self.get_batch_inputs(begin, end, detach, inputs)
if self.index > 0:
ready_send_g[i] = ready_recv_s.pop(i)
batch_states = ready_send_g[i].wait_state()
else:
batch_states = self.get_batch_states(begin, end)
with self.program._switch_batch(begin, end):
batch_outputs, batch_states = self.program.forward(batch_inputs, batch_states)
if self.index + 1 < len(self.ranks):
ready_send_s[i] = self.get_batch_sync(
batch_states,
device=detach[0].device,
)
ready_recv_g[i] = ready_send_s[i]
ready_bw_cmp[i] = (batch_inputs, batch_outputs, batch_states)
del batch_inputs, batch_outputs, batch_states
elif op == "backward":
batch_output_grads = self.get_batch_inputs(begin, end, detach_grads)
batch_inputs, batch_outputs, batch_states = ready_bw_cmp.pop(i)
if self.use_fast_backward:
with self.program._switch_batch(begin, end):
vec_loss = [self.program.loss_fn(batch_outputs, batch_output_grads)]
vec_grad = [torch.ones_like(vec_loss[0])]
if self.acc_loss is None:
self.acc_loss = vec_loss[0].detach()
else:
self.acc_loss += vec_loss[0].detach()
else:
vec_loss = list(batch_outputs)
vec_grad = list(batch_output_grads)
del batch_outputs, batch_output_grads
if self.index + 1 < len(self.ranks):
batch_state_grads = ready_recv_g.pop(i).wait_grads()
vec_loss.extend(batch_states)
vec_grad.extend(batch_state_grads)
del batch_state_grads
del batch_states
vector_backward(vec_loss, vec_grad)
for i, t in enumerate(batch_inputs):
g, t.grad = t.grad, None
if g is None:
continue
if input_grads[i] is None:
input_grads[i] = torch.zeros(N, *g.shape[1:], dtype=g.dtype, device=g.device)
input_grads[i][begin:end] = g
del batch_inputs
for op, type, i in motions.step_sync():
begin, end = self.get_begin_end(i, N)
if op == "send":
if type == "state":
ready_send_s.pop(i).send_state()
elif type == "grads":
ready_send_g.pop(i).send_grads()
elif op == "recv":
if type == "state":
ready_recv_s[i] = self.get_batch_sync(
self.get_batch_states(begin, end),
device=detach[0].device,
)
ready_recv_s[i].recv_state()
elif type == "grads":
ready_recv_g[i].recv_grads()
assert not ready_recv_s
assert not ready_recv_g
assert not ready_send_s
assert not ready_send_g
assert not ready_bw_cmp
return input_grads
def get_batch_inputs(self,
begin: int, end: int,
detach: Sequence[Tensor],
inputs: Sequence[Tensor] = None,
) -> Sequence[Tensor]:
batch = []
for i, t in enumerate(detach):
assert not t.requires_grad
t = t[begin:end]
if inputs and inputs[i].requires_grad:
t.requires_grad_()
t.retain_grad()
batch.append(t)
return batch
def get_batch_states(self,
begin: int, end: int,
) -> Sequence[Tensor]:
states = []
for s in self.program.get_init_states():
s = s.unsqueeze(0).broadcast_to(
end - begin, *s.size(),
).contiguous()
states.append(s)
return states
class SequencePipeFunction(autograd.Function):
@staticmethod
def forward(
ctx: autograd.function.FunctionCtx,
runtime: SequencePipeRuntime,
*inputs: Tensor,
):
ctx.save_for_backward(*inputs)
ctx.saved_runtime = runtime
return runtime.forward(inputs)
@staticmethod
def backward(
ctx: autograd.function.FunctionCtx,
*grads: Tensor,
):
inputs: Sequence[Tensor] = ctx.saved_tensors
runtime: SequencePipeRuntime = ctx.saved_runtime
with torch.enable_grad():
input_grads = runtime.backward(inputs, grads)
return None, *input_grads
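# Hedged sketch (not part of the original module): a minimal SequencePipe
# subclass showing the methods a user is expected to provide. The GRU cell,
# the feature sizes and the use of the default process group are illustrative
# assumptions, not the library's prescribed setup.
class _ExampleGRUPipe(SequencePipe):
    def __init__(self, input_dim: int = 8, hidden_dim: int = 16) -> None:
        super().__init__()
        self.cell = nn.GRUCell(input_dim, hidden_dim)
        self.hidden_dim = hidden_dim

    def get_group(self) -> Any:
        # pipeline stages communicate over the default process group
        return dist.GroupMember.WORLD

    def get_init_states(self) -> Sequence[Tensor]:
        # one zero state vector per sequence element; broadcast per micro-batch
        return (torch.zeros(self.hidden_dim),)

    def forward(self, inputs, states):
        x, = inputs
        h, = states
        h = self.cell(x, h)
        return (h,), (h,)

# Typical driver (illustrative): out, = _ExampleGRUPipe().to(device).apply(32, x)
# where x has shape (N, input_dim) on every pipeline rank.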
\ No newline at end of file
import torch
import torch.distributed as dist
from torch import Tensor
from typing import *
class BatchSync:
def __init__(self,
*state: Tensor,
seq_index: int,
seq_ranks: Optional[List[int]] = None,
group: Any = None,
device: Any = None,
) -> None:
self._state = state
self._grads = [None] * len(self._state)
self._rgrad = torch.tensor(
[t.requires_grad for t in self._state],
dtype=torch.bool, device=device,
)
self._seq_index = int(seq_index)
if group is None:
group = dist.GroupMember.WORLD
self._group = group
self._device = torch.device(device)
if seq_ranks is None:
group_size = dist.get_world_size(group)
seq_ranks = range(group_size)
self._seq_ranks: Tuple[int,...] = tuple(seq_ranks)
self._works = []
def zip_for_backward(self, *grads: Optional[Tensor]):
assert len(grads) == len(self._state)
self._grads = grads
vec_loss, vec_grad = [], []
for s, g in zip(self._state, self._grads):
if s.requires_grad:
vec_loss.append(s)
vec_grad.append(g)
return vec_loss, vec_grad
def wait_state(self) -> Sequence[Tensor]:
for w in self._works:
w.wait()
self._works.clear()
rgrad = self._rgrad.tolist()
for r, t in zip(rgrad, self._state):
assert t.is_leaf
t.requires_grad_(r)
return self._state
def wait_grads(self) -> Sequence[Tensor]:
for w in self._works:
w.wait()
self._works.clear()
assert self._grads is not None
return self._grads
def send_state(self):
if not self._state:
return
if self._seq_index + 1 >= len(self._seq_ranks):
return
dst = self._seq_ranks[self._seq_index + 1]
dst = dist.get_global_rank(self._group, dst)
dist.isend(self._rgrad, dst=dst, group=self._group)
for t in self._state:
dist.isend(t, dst=dst, group=self._group)
def send_grads(self):
if not self._state:
return
if self._seq_index <= 0:
return
rgrad = self._rgrad.tolist()
dst = self._seq_ranks[self._seq_index - 1]
dst = dist.get_global_rank(self._group, dst)
for r, t in zip(rgrad, self._state):
if not r:
continue
g, t.grad = t.grad, None
if g is None:
g = torch.zeros_like(t)
dist.isend(g, dst=dst, group=self._group)
def recv_state(self):
if not self._state:
return
if self._seq_index <= 0:
return
src = self._seq_ranks[self._seq_index - 1]
src = dist.get_global_rank(self._group, src)
self._works.append(
dist.irecv(self._rgrad, src=src, group=self._group)
)
for t in self._state:
self._works.append(
dist.irecv(t, src=src, group=self._group)
)
def recv_grads(self):
if not self._state:
return
if self._seq_index + 1 >= len(self._seq_ranks):
return
rgrad = self._rgrad.tolist()
src = self._seq_ranks[self._seq_index + 1]
src = dist.get_global_rank(self._group, src)
for i, (r, t) in enumerate(zip(rgrad, self._state)):
if not r:
self._grads[i] = None
continue
if self._grads[i] is None:
self._grads[i] = torch.empty_like(t)
self._works.append(
dist.irecv(self._grads[i], src=src, group=self._group)
)
class VirtualForward:
def __init__(self,
index: int,
num_ranks: int,
max_count: int,
batch_vsz: int = 2,
) -> None:
assert batch_vsz > 0
self._max_count = max_count
self._bs = batch_vsz
vmax_count = (max_count + batch_vsz - 1) // batch_vsz
self._motions: List[ForwardGenerator] = []
for _ in range(batch_vsz):
self._motions.append(
ForwardGenerator(index, num_ranks, vmax_count)
)
self._step_count = 0
@property
def finished(self):
return self._motions[self._step_count].finished
def step_comp(self):
for op, i in self._motions[self._step_count].step_comp():
k = i * self._bs + self._step_count
if k < self._max_count:
yield op, k
def step_sync(self):
for op, d, i in self._motions[self._step_count].step_sync():
k = i * self._bs + self._step_count
if k < self._max_count:
yield op, d, k
self._step_count += 1
self._step_count %= self._bs
class ForwardGenerator:
def __init__(self,
index: int,
num_ranks: int,
max_count: int,
) -> None:
self._index = index
self._num_ranks = num_ranks
self._dst_fp = ForwardFootprint(index+1, num_ranks, max_count)
self._fp = ForwardFootprint(index, num_ranks, max_count)
self._dst_fp.step()
_, op, i = self._fp.step()
self._last_action = op, i
self._finished = False
@property
def finished(self):
t = self._dst_fp.finished
k = self._fp.finished
op, _ = self._last_action
return t and k and not op
def step_comp(self):
if self.finished:
return
op, i = self._last_action
self._last_action = None, -1
if op == "forward":
yield "forward", i
def step_sync(self):
if self.finished:
return
_, dst_op, dst_i = self._dst_fp.step()
_, op, i = self._fp.step()
self._last_action = op, i
if dst_op == "forward":
yield "send", "state", dst_i
if op == "forward" and self._index > 0:
yield "recv", "state", i
class ForwardFootprint:
def __init__(self,
index: int,
num_ranks: int,
max_count: int,
) -> None:
self._index = index
self._num_ranks = num_ranks
self._max_count = max_count
self._fw_batch_id = 0
self._count = 0
if index < 0 or index >= num_ranks:
self._finished = True
else:
self._finished = False
@property
def finished(self):
return self._finished
def step(self) -> Tuple[int, Optional[str], int]:
if self._finished:
return (self._count, None, -1)
ret = (self._count, "nop", -1)
if self._count == self._index + self._fw_batch_id:
ret = (self._count, "forward", self._fw_batch_id)
self._fw_batch_id += 1
if self._fw_batch_id >= self._max_count:
self._finished = True
self._count += 1
return ret
class VirtualMotions:
def __init__(self,
index: int,
num_ranks: int,
max_count: int,
batch_vsz: int = 2,
) -> None:
assert batch_vsz > 0
self._max_count = max_count
self._bs = batch_vsz
vmax_count = (max_count + batch_vsz - 1) // batch_vsz
self._motions: List[MotionGenerator] = []
for _ in range(batch_vsz):
self._motions.append(
MotionGenerator(index, num_ranks, vmax_count)
)
self._step_count = 0
@property
def finished(self):
return self._motions[self._step_count].finished
def step_comp(self):
for op, i in self._motions[self._step_count].step_comp():
k = i * self._bs + self._step_count
if k < self._max_count:
yield op, k
def step_sync(self):
for op, d, i in self._motions[self._step_count].step_sync():
k = i * self._bs + self._step_count
if k < self._max_count:
yield op, d, k
self._step_count += 1
self._step_count %= self._bs
class MotionGenerator:
def __init__(self,
index: int,
num_ranks: int,
max_count: int,
) -> None:
self._index = index
self._num_ranks = num_ranks
self._src_fp = F1B1Footprint(index-1, num_ranks, max_count)
self._dst_fp = F1B1Footprint(index+1, num_ranks, max_count)
self._fp = F1B1Footprint(index, num_ranks, max_count)
self._src_fp.step()
self._dst_fp.step()
_, op, i = self._fp.step()
self._last_action = op, i
self._finished = False
@property
def finished(self):
s = self._src_fp.finished
t = self._dst_fp.finished
k = self._fp.finished
op, _ = self._last_action
return s and t and k and not op
def step_comp(self):
if self.finished:
return
op, i = self._last_action
self._last_action = None, -1
if op == "forward":
yield "forward", i
elif op == "backward":
yield "backward", i
def step_sync(self):
if self.finished:
return
_, src_op, src_i = self._src_fp.step()
_, dst_op, dst_i = self._dst_fp.step()
_, op, i = self._fp.step()
self._last_action = op, i
if op == "backward" and \
self._index + 1 < self._num_ranks:
yield "recv", "grads", i
if src_op == "backward":
assert dst_op != "forward"
yield "send", "grads", src_i
elif dst_op == "forward":
assert src_op != "backward"
yield "send", "state", dst_i
if op == "forward" and self._index > 0:
yield "recv", "state", i
class F1B1Footprint:
def __init__(self,
index: int,
num_ranks: int,
max_count: int,
) -> None:
self._index = index
self._num_ranks = num_ranks
self._max_count = max_count
self._bw_offset = 2 * self._num_ranks - self._index - 1
self._fw_batch_id = 0
self._bw_batch_id = 0
self._count = 0
if index < 0 or index >= num_ranks:
self._finished = True
else:
self._finished = False
@property
def finished(self):
return self._finished
def step(self) -> Tuple[int, Optional[str], int]:
if self._finished:
return (self._count, None, -1)
ret = (self._count, "nop", -1)
if self._count >= self._bw_offset + 2 * self._bw_batch_id:
ret = (self._count, "backward", self._bw_batch_id)
self._bw_batch_id += 1
elif self._fw_batch_id < self._max_count:
if self._count >= self._index + 2 * self._fw_batch_id:
ret = (self._count, "forward", self._fw_batch_id)
self._fw_batch_id += 1
if self._bw_batch_id >= self._max_count:
self._finished = True
self._count += 1
return ret
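# Hedged sketch (not part of the original module): materializes the 1F1B
# schedule that F1B1Footprint produces and that MotionGenerator walks above,
# for a tiny pipeline. The rank and micro-batch counts are illustrative.
def _example_f1b1_schedule(num_ranks: int = 2, max_count: int = 3):
    footprints = [F1B1Footprint(i, num_ranks, max_count) for i in range(num_ranks)]
    schedule = []
    while not all(fp.finished for fp in footprints):
        # one column per rank; entries are "forward", "backward", "nop" or None
        schedule.append([fp.step()[1] for fp in footprints])
    # e.g. [['forward', 'nop'], ['nop', 'forward'], ['forward', 'backward'], ...]
    return schedule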
import torch
from torch import Tensor
from typing import *
def vector_backward(
vec_loss: Sequence[Tensor],
vec_grad: Sequence[Tensor],
):
loss: List[Tensor] = []
grad: List[Tensor] = []
for x, g in zip(vec_loss, vec_grad):
if g is None:
continue
if not x.requires_grad:
continue
loss.append(x.flatten())
grad.append(g.flatten())
if loss:
loss = torch.cat(loss, dim=0)
grad = torch.cat(grad, dim=0)
loss.backward(grad)
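# Hedged sketch (not part of the original module): vector_backward propagates
# per-tensor gradients through several outputs with a single backward call.
# The shapes and values are illustrative.
def _example_vector_backward():
    w = torch.ones(3, requires_grad=True)
    y1 = w * 2.0             # dy1/dw = 2
    y2 = (w * w).sum()       # dy2/dw = 2w
    vector_backward([y1, y2], [torch.ones_like(y1), torch.ones_like(y2)])
    return w.grad            # 2 + 2*w -> tensor([4., 4., 4.])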
\ No newline at end of file
import torch
import torch.nn as nn
import torch.distributed as dist
from torch import Tensor
from typing import *
from collections import defaultdict
__all__ = [
"all_reduce_gradients",
"all_reduce_buffers",
]
# def all_reduce_gradients(net: nn.Module, op = dist.ReduceOp.SUM, group = None, async_op: bool = False):
# works = []
# for p in net.parameters():
# if p.grad is None:
# p.grad = torch.zeros_like(p.data)
# w = dist.all_reduce(p.grad, op=op, group=group, async_op=async_op)
# works.append(w)
# if async_op:
# return works
# def all_reduce_buffers(net: nn.Module, op = dist.ReduceOp.AVG, group = None, async_op: bool = False):
# works = []
# for b in net.buffers():
# w = dist.all_reduce(b.data, op=op, group=group, async_op=async_op)
# works.append(w)
# if async_op:
# return works
def all_reduce_gradients(net: nn.Module, op = dist.ReduceOp.SUM, group = None, async_op: bool = False):
device = None
works = []
if op is None:
return works
typed_numel = defaultdict(lambda: 0)
for p in net.parameters():
typed_numel[p.dtype] += p.numel()
device = p.device
if device is None:
return works
typed_tensors: Dict[torch.dtype, Tensor] = {}
for t, n in typed_numel.items():
typed_tensors[t] = torch.zeros(n, dtype=t, device=device)
typed_offset = defaultdict(lambda: 0)
for p in net.parameters():
s = typed_offset[p.dtype]
t = s + p.numel()
typed_offset[p.dtype] = t
if p.grad is not None:
typed_tensors[p.dtype][s:t] = p.grad.flatten()
storage = typed_tensors[p.dtype].untyped_storage()
g = torch.empty(0, dtype=p.dtype, device=device)
p.grad = g.set_(storage, s, p.size(), default_stride(*p.size()))
for t in typed_tensors.values():
w = dist.all_reduce(t, op=op, group=group, async_op=async_op)
if async_op:
works.append(w)
return works
def all_reduce_buffers(net: nn.Module, op = dist.ReduceOp.AVG, group = None, async_op: bool = False):
device = None
works = []
if op is None:
return works
typed_numel = defaultdict(lambda: 0)
for p in net.buffers():
typed_numel[p.dtype] += p.numel()
device = p.device
if device is None:
return works
typed_tensors: Dict[torch.dtype, Tensor] = {}
for t, n in typed_numel.items():
        typed_tensors[t] = torch.zeros(n, dtype=t, device=device)
typed_offset = defaultdict(lambda: 0)
for p in net.buffers():
s = typed_offset[p.dtype]
t = s + p.numel()
typed_offset[p.dtype] = t
typed_tensors[p.dtype][s:t] = p.flatten()
storage = typed_tensors[p.dtype].untyped_storage()
p.set_(storage, s, p.size(), default_stride(*p.size()))
for t in typed_tensors.values():
w = dist.all_reduce(t, op=op, group=group, async_op=async_op)
if async_op:
works.append(w)
return works
def default_stride(*size: int) -> Tuple[int,...]:
dims = len(size)
stride = [1] * dims
for i in range(1, dims):
k = dims - i
stride[k - 1] = stride[k] * size[k]
return tuple(stride)
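# Hedged sketch (not part of the original module): default_stride reproduces
# the contiguous (row-major) strides PyTorch assigns to a freshly allocated
# tensor, which the flat-buffer gradient/buffer views above rely on.
def _example_default_stride():
    assert default_stride(2, 3, 4) == (12, 4, 1)
    assert default_stride(2, 3, 4) == tuple(torch.empty(2, 3, 4).stride())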
from typing import List, Tuple
import torch
import torch.distributed as dist
from starrygl.distributed.utils import DistributedTensor
from starrygl.module.memorys import MailBox
from starrygl.sample.cache.fetch_cache import FetchFeatureCache
from starrygl.sample.graph_core import DataSet
from starrygl.sample.graph_core import DistributedGraphStore
from starrygl.sample.sample_core.base import BaseSampler, NegativeSampling
import dgl
from starrygl.sample.stream_manager import PipelineManager, getPipelineManger
"""
入参不变,出参变为:
sample_from_nodes
node: list[tensor,tensor, tensor...]
eid: list[tensor,tensor, tensor...]
src_index: list[tensor,tensor, tensor...]
sample_from_edges:
node
eid: list[tensor,tensor, tensor...]
src_index: list[tensor,tensor, tensor...]
delta_ts: list[tensor,tensor, tensor...]
metadata
"""
def prepare_input(node_feat, edge_feat, mem_embedding,mfgs,dist_nid,dist_eid):
for i,mfg in enumerate(mfgs):
for b in mfg:
e_idx = b.edata['ID']
idx = b.srcdata['ID']
b.edata['ID'] = dist_eid[e_idx]
b.srcdata['ID'] = dist_nid[idx]
#print(b.edata['ID'],b.edata['dt'],b.srcdata['ID'])
if edge_feat is not None:
b.edata['f'] = edge_feat[e_idx]
if i == 0:
if node_feat is not None:
b.srcdata['h'] = node_feat[idx]
if mem_embedding is not None:
node_memory,node_memory_ts,mailbox,mailbox_ts = mem_embedding
b.srcdata['mem'] = node_memory[idx]
b.srcdata['mem_ts'] = node_memory_ts[idx]
b.srcdata['mem_input'] = mailbox[idx].reshape(b.srcdata['ID'].shape[0], -1)
b.srcdata['mail_ts'] = mailbox_ts[idx]
#print(idx.shape[0],b.srcdata['mem_ts'].shape)
return mfgs
def to_block(graph: DistributedGraphStore, data, sample_out, mailbox:MailBox = None,device = torch.device('cuda'),group = None):
if len(sample_out) > 1:
sample_out,metadata = sample_out
else:
metadata = None
eid = [ret.eid() for ret in sample_out]
eid_len = [e.shape[0] for e in eid ]
#print(len(sample_out),eid,eid_len)
eid_mapper: torch.Tensor = graph.eids_mapper
nid_mapper: torch.Tensor = graph.nids_mapper
eid_tensor = torch.cat(eid,dim = 0).to(eid_mapper.device)
#print(eid_tensor)
dist_eid = eid_mapper[eid_tensor>>1].to(device)
dist_eid,eid_inv = dist_eid.unique(return_inverse=True)
src_node = graph.sample_graph['edge_index'][0,eid_tensor].to(graph.nids_mapper.device)
#print(src_node,graph.sample_graph['edge_index'][1,eid_tensor])
src_ts = None
if metadata is None:
root_node = data.nodes.to(graph.nids_mapper.device)
root_len = root_node.shape[0]
if hasattr(data,'ts'):
src_ts = torch.cat([data.ts,
graph.sample_graph['ts'][eid_tensor].to(device)])
elif 'seed' in metadata:
root_node = metadata.pop('seed').to(graph.nids_mapper.device)
root_len = root_node.shape[0]
if 'seed_ts' in metadata:
src_ts = torch.cat([metadata.pop('seed_ts').to(device),\
graph.sample_graph['ts'][eid_tensor].to(device)])
for k in metadata:
metadata[k] = metadata[k].to(device)
#print(src_ts,root_node)
nid_tensor = torch.cat([root_node,src_node],dim = 0)
#sprint(nid_tensor)
dist_nid = nid_mapper[nid_tensor].to(device)
dist_nid,nid_inv = dist_nid.unique(return_inverse = True)
fetchCache = FetchFeatureCache.getFetchCache()
if fetchCache is None:
if isinstance(graph.edge_attr,DistributedTensor):
ind_dict = graph.edge_attr.all_to_all_ind2ptr(dist_eid,group = group)
edge_feat = graph.edge_attr.all_to_all_get(group = group,**ind_dict)
else:
edge_feat = graph._get_edge_attr(dist_eid)
ind_dict = None
if isinstance(graph.x,DistributedTensor):
ind_dict = graph.x.all_to_all_ind2ptr(dist_nid,group = group)
node_feat = graph.x.all_to_all_get(group = group,**ind_dict)
else:
node_feat = graph._get_node_attr(dist_nid)
if mailbox is not None:
if torch.distributed.get_world_size() > 1:
if node_feat is None:
ind_dict = mailbox.node_memory.all_to_all_ind2ptr(dist_nid,group = group)
mem = mailbox.gather_memory(**ind_dict)
else:
mem = mailbox.get_memory(dist_nid)
else:
mem = None
else:
raw_nid = torch.empty_like(dist_nid)
raw_eid = torch.empty_like(dist_eid)
nid_tensor = nid_tensor.to(device)
eid_tensor = eid_tensor.to(device)
raw_nid[nid_inv] = nid_tensor
raw_eid[eid_inv] = (eid_tensor>>1)
node_feat,edge_feat,mem = fetchCache.fetch_feature(raw_nid,
dist_nid,raw_eid,
dist_eid)
def build_block():
mfgs = list()
col = torch.arange(0,root_len,device = device)
col_len = 0
row_len = root_len
for r in range(len(eid_len)):
elen = eid_len[r]
row = torch.arange(row_len,row_len+elen,device = device)
#print(row,col[sample_out[r].src_index()])
b = dgl.create_block((row,col[sample_out[r].src_index().to(device)]),
num_src_nodes = row_len + elen,
num_dst_nodes = row_len,
device = device)
idx = nid_inv[0:row_len + elen]
e_idx = eid_inv[col_len:col_len+elen]
#print(idx,e_idx)
b.srcdata['ID'] = idx
if sample_out[r].delta_ts().shape[0] > 0:
b.edata['dt'] = sample_out[r].delta_ts().to(device)
if src_ts is not None:
b.srcdata['ts'] = src_ts[0:row_len + eid_len[r]]
b.edata['ID'] = e_idx
#print(b.all_edges)
#print(dist_nid[b.srcdata['ID']],dist_nid[b.srcdata['ID'][col[sample_out[r].src_index().to(device)]]])
#print(b.edata['dt'],b.srcdata['ts'])
col = row
col_len += eid_len[r]
row_len += eid_len[r]
mfgs.append(b)
mfgs = list(map(list, zip(*[iter(mfgs)])))
mfgs.reverse()
return data,mfgs,metadata
data,mfgs,metadata = build_block()
mfgs = prepare_input(node_feat,edge_feat,mem,mfgs,dist_nid,dist_eid)
#return build_block(node_feat,edge_feat,mem)#data,mfgs,metadata
return (data,mfgs,metadata)
def graph_sample(graph, sampler:BaseSampler,
sample_fn, data,
neg_sampling = None,
mailbox = None,
device = torch.device('cuda'),
async_op = False):
out = sample_fn(sampler,data,neg_sampling)
if async_op == False:
return to_block(graph,data,out,mailbox,device)
else:
manger = getPipelineManger()
future = manger.submit('lookup',to_block,{'graph':graph,'data':data,\
'sample_out':out,\
'mailbox':mailbox,\
'device':device})
return future
def sample_from_nodes(sampler:BaseSampler, data:DataSet, **kwargs):
out = sampler.sample_from_nodes(nodes=data.nodes.reshape(-1))
#out.metadata = None
return out
def sample_from_edges(sampler:BaseSampler,
data:DataSet,
neg_sampling:NegativeSampling = None):
edge_label = data.labels if hasattr(data,'labels') else None
out = sampler.sample_from_edges(edges = data.edges,
neg_sampling=neg_sampling)
return out
def sample_from_temporal_nodes(sampler:BaseSampler,data:DataSet,
**kwargs):
out = sampler.sample_from_nodes(nodes=data.nodes.reshape(-1),
ts = data.ts.reshape(-1))
#out.metadata = None
return out
def sample_from_temporal_edges(sampler:BaseSampler, data:DataSet,
neg_sampling: NegativeSampling = None):
edge_label = data.labels if hasattr(data,'labels') else None
out = sampler.sample_from_edges(edges=data.edges.to('cpu'),
ets=data.ts.to('cpu'),
neg_sampling = neg_sampling
)
return out
class SAMPLE_TYPE:
    SAMPLE_FROM_NODES = sample_from_nodes
    SAMPLE_FROM_EDGES = sample_from_edges
    SAMPLE_FROM_TEMPORAL_NODES = sample_from_temporal_nodes
    SAMPLE_FROM_TEMPORAL_EDGES = sample_from_temporal_edges
\ No newline at end of file
from typing import Optional, Sequence, Union
import torch
from starrygl.distributed.utils import DistributedTensor
from starrygl.sample.cache.cache import Cache
class LRUCache(Cache):
"""
Least-recently-used (LRU) cache
"""
def __init__(self, cache_ratio: int,
num_cache:int,
cache_data: Sequence[DistributedTensor],
use_local:bool = False,
pinned_buffers_shape: Sequence[torch.Size] = None,
is_update_cache = False
):
super(LRUCache, self).__init__(cache_ratio,num_cache,
cache_data,use_local,
pinned_buffers_shape,
is_update_cache)
self.name = 'lru'
self.now_cache_count = 0
self.cache_count = torch.zeros(
self.capacity, dtype=torch.int32, device=torch.device('cuda'))
self.is_update_cache = True
def update_cache(self, cached_index: torch.Tensor,
uncached_index: torch.Tensor,
uncached_feature: Sequence[torch.Tensor]):
if len(uncached_index) > self.capacity:
num_to_cache = self.capacity
else:
num_to_cache = len(uncached_index)
node_id_to_cache = uncached_index[:num_to_cache].to(torch.int32)
self.now_cache_count -= 1
self.cache_count[cached_index] = 0
        # pick the k cache slots with the smallest counters as eviction candidates
removing_cache_index = torch.topk(
self.cache_count, k=num_to_cache, largest=False).indices.to(torch.int32)
removing_node_id = self.cache_index_to_id[removing_cache_index]
# update cache attributes
for buffer,data in zip(self.buffers,uncached_feature):
buffer[removing_cache_index] = data[:num_to_cache].reshape(-1,*buffer.shape[1:])
self.cache_count[removing_cache_index] = 0
self.cache_validate[removing_node_id] = False
self.cache_validate[node_id_to_cache] = True
self.cache_map[removing_node_id] = -1
self.cache_map[node_id_to_cache] = removing_cache_index
self.cache_index_to_id[removing_cache_index] = node_id_to_cache
from typing import Callable, List, Optional, Sequence, Union
import numpy as np
import torch
from starrygl.distributed.utils import DistributedTensor
class Cache:
def __init__(self, cache_ratio: int,
num_cache:int,
cache_data: Sequence[DistributedTensor],
use_local:bool = False,
pinned_buffers_shape: Sequence[torch.Size] = None,
is_update_cache = False
):
print(len(cache_data),cache_data)
assert torch.cuda.is_available() == True
self.use_local = use_local
self.is_update_cache = is_update_cache
self.device = torch.device('cuda')
self.use_remote = torch.distributed.get_world_size()>1
        assert not (self.use_local is False and self.use_remote is False),\
            "data is already local on CUDA and there is no remote partition, so no cache is needed"
self.cache_ratio = cache_ratio
self.num_cache = num_cache
self.capacity = int(self.num_cache * cache_ratio)
self.update_stream = torch.cuda.Stream()
self.buffers = []
self.pinned_buffers = []
for data in cache_data:
self.buffers.append(
torch.zeros(self.capacity,*data.shape[1:],
dtype = data.dtype,device = torch.device('cuda'))
)
self.cache_validate = torch.zeros(
num_cache, dtype=torch.bool, device=self.device)
# maps node id -> index
self.cache_map = torch.zeros(
num_cache, dtype=torch.int32, device=self.device) - 1
# maps index -> node id
self.cache_index_to_id = torch.zeros(
num_cache,dtype=torch.int32, device=self.device) -1
self.hit_sum = 0
self.hit_ = 0
def init_cache(self,ind:torch.Tensor,data:Sequence[torch.Tensor]):
pos = torch.arange(ind.shape[0],device = 'cuda',dtype = ind.dtype)
self.cache_map[ind] = pos.to(torch.int32).to('cuda')
self.cache_index_to_id[pos] = ind.to(torch.int32).to('cuda')
for data,buffer in zip(data,self.buffers):
buffer[:ind.shape[0],] = data
self.cache_validate[ind] = True
def update_cache(self, cached_index: torch.Tensor,
uncached_index: torch.Tensor,
uncached_data: Sequence[torch.Tensor]):
raise NotImplementedError
def fetch_data(self,ind:Optional[torch.Tensor] = None,
uncached_source_fn: Callable = None, source_index:torch.Tensor = None):
self.hit_sum += ind.shape[0]
assert isinstance(ind, torch.Tensor)
cache_mask = self.cache_validate[ind]
uncached_mask = ~cache_mask
self.hit_ += torch.sum(cache_mask)
cached_data = []
cached_index = self.cache_map[ind[cache_mask]]
if uncached_mask.sum() > 0:
uncached_id = ind[uncached_mask]
source_index = source_index[uncached_mask]
uncached_feature = uncached_source_fn(source_index)
if isinstance(uncached_feature,torch.Tensor):
uncached_feature = [uncached_feature]
else:
uncached_id = None
uncached_feature = [None for _ in range(len(self.buffers))]
for data,uncached_data in zip(self.buffers,uncached_feature):
nfeature = torch.zeros(
len(ind), *data.shape[1:], dtype=data.dtype,device=self.device)
nfeature[cache_mask,:] = data[cached_index]
if uncached_id is not None:
nfeature[uncached_mask] = uncached_data.reshape(-1,*data.shape[1:])
cached_data.append(nfeature)
if self.is_update_cache and uncached_mask.sum() > 0:
self.update_cache(cached_index=cached_index,
uncached_index=uncached_id,
uncached_feature=uncached_feature)
        return cached_data
def invalidate(self,ind):
self.cache_validate[ind] = False
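# Hedged sketch (not part of the original module): a stripped-down, CPU-only
# analogue of `fetch_data` above showing the cached/uncached split. The
# `full_table` tensor stands in for the remote source that `uncached_source_fn`
# would normally query.
def _example_cache_lookup():
    full_table = torch.arange(10, dtype=torch.float).unsqueeze(1)  # "remote" features
    cache_validate = torch.zeros(10, dtype=torch.bool)
    cache_map = torch.full((10,), -1, dtype=torch.long)
    buffer = torch.zeros(4, 1)
    # pretend node ids 0..3 are cached in slots 0..3
    cache_validate[:4] = True
    cache_map[:4] = torch.arange(4)
    buffer[:] = full_table[:4]
    ind = torch.tensor([1, 2, 7, 9])
    cached_mask = cache_validate[ind]
    out = torch.zeros(len(ind), 1)
    out[cached_mask] = buffer[cache_map[ind[cached_mask]]]    # cache hits
    out[~cached_mask] = full_table[ind[~cached_mask]]         # misses: fetch from source
    return out                                                # [[1.], [2.], [7.], [9.]]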
from typing import Optional
from starrygl.distributed.context import DistributedContext
from starrygl.distributed.utils import DistIndex, DistributedTensor
from starrygl.sample.cache import LRU_cache
from starrygl.sample.cache.cache import Cache
from starrygl.sample.cache.static_cache import StaticCache
from starrygl.sample.cache.utils import pre_sample
from starrygl.sample.graph_core import DistributedGraphStore
from starrygl.sample.memory.shared_mailbox import SharedMailBox
import torch
_FetchCache = None
class FetchFeatureCache:
@staticmethod
def create_fetch_cache(num_nodes: int, num_edges: int,
edge_cache_ratio: int, node_cache_ratio: int,
graph: DistributedGraphStore,
mailbox:SharedMailBox = None,
policy = 'lru'):
"""
method to create a fetch cache instance.
Args:
num_nodes: Total number of nodes in the graph.
num_edges: Total number of edges in the graph.
            edge_cache_ratio: The fraction of edges to be cached.
            node_cache_ratio: The fraction of nodes to be cached.
            graph: Distributed graph store.
            mailbox: mailbox used for storing node memory and messages.
policy: Caching policy, either 'lru' or 'static'.
"""
global _FetchCache
_FetchCache = FetchFeatureCache(num_nodes, num_edges,
edge_cache_ratio, node_cache_ratio,
graph,mailbox,policy)
@staticmethod
def getFetchCache():
"""
method to get the existing fetch cache instance.
Returns:
FetchFeatureCache: The existing fetch cache instance.
"""
global _FetchCache
return _FetchCache
def __init__(self, num_nodes: int, num_edges: int,
edge_cache_ratio: int, node_cache_ratio: int,
graph: DistributedGraphStore,
mailbox:SharedMailBox = None,
policy = 'lru'
):
"""
Initializes the FetchFeatureCache instance.
Args:
num_nodes: Total number of nodes in the graph.
num_edges: Total number of edges in the graph.
            edge_cache_ratio: The fraction of edges to be cached.
            node_cache_ratio: The fraction of nodes to be cached.
            graph: Distributed graph store.
            mailbox: mailbox used for storing node memory and messages.
policy: Caching policy, either 'lru' or 'static'.
"""
if policy == 'lru':
init_fn = LRU_cache.LRUCache
elif policy == 'static':
init_fn = StaticCache
self.ctx = DistributedContext.get_default_context()
if graph.x is not None:
self.node_cache:Cache = init_fn(node_cache_ratio,num_nodes,
[graph.x],use_local=graph.uvm_node)
else:
self.node_cache = None
if graph.edge_attr is not None:
self.edge_cache:Cache = init_fn(edge_cache_ratio,num_edges,
[graph.edge_attr],use_local = graph.uvm_edge)
else:
self.edge_cache = None
if mailbox is not None:
self.mailbox_cache:Cache = init_fn(node_cache_ratio,num_nodes,
[mailbox.node_memory,
mailbox.node_memory_ts.accessor.data.reshape(-1,1),
mailbox.mailbox,
mailbox.mailbox_ts],
use_local = mailbox.uvm)
else:
self.mailbox_cache = None
self.graph = graph
self.mailbox = mailbox
        global _FetchCache
        _FetchCache = self
def fetch_feature(self, nid: Optional[torch.Tensor] = None, dist_nid = None,
eid: Optional[torch.Tensor] = None, dist_eid = None
):
"""
Fetches node and edge features along with mailbox memory.
Args:
nid: Node indices to fetch features for.
dist_nid: The remote communication corresponding to nid.
eid: Edge indices to fetch features for.
dist_eid: The remote communication corresponding to eid.
"""
nfeat = None
mem = None
efeat = None
if self.node_cache is not None and nid is not None:
nfeat = torch.zeros(nid.shape[0],
self.node_cache.buffers[0].shape[1],
dtype = self.node_cache.buffers[0].dtype,
device = torch.device('cuda')
)
if self.node_cache.use_local is False:
local_mask = (DistIndex(dist_nid).part == torch.distributed.get_rank())
local_id = dist_nid[local_mask]
nfeat[local_mask] = self.graph.x.accessor.data[DistIndex(local_id).loc]
remote_mask = ~local_mask
if remote_mask.sum() > 0:
remote_id = nid[remote_mask]
source_id = dist_nid[remote_mask]
nfeat[remote_mask] = self.node_cache.fetch_data(remote_id,\
self.graph._get_node_attr,source_id)[0]
else:
nfeat = self.node_cache.fetch_data(nid,
self.graph._get_node_attr,dist_nid)[0]
if self.mailbox_cache is not None and nid is not None:
memory = torch.zeros(nid.shape[0],
self.mailbox_cache.buffers[0].shape[1],
dtype = self.mailbox_cache.buffers[0].dtype,
device = torch.device('cuda')
)
memory_ts = torch.zeros(nid.shape[0],
dtype = self.mailbox_cache.buffers[1].dtype,
device = torch.device('cuda')
)
mailbox = torch.zeros(nid.shape[0],
*self.mailbox_cache.buffers[2].shape[1:],
dtype = self.mailbox_cache.buffers[2].dtype,
device = torch.device('cuda')
)
mailbox_ts = torch.zeros(nid.shape[0],
*self.mailbox_cache.buffers[3].shape[1:],
dtype = self.mailbox_cache.buffers[3].dtype,
device = torch.device('cuda')
)
if self.mailbox_cache.use_local is False:
if self.node_cache is None:
local_mask = (DistIndex(dist_nid).part == torch.distributed.get_rank())
local_id = dist_nid[local_mask]
remote_mask = ~local_mask
remote_id = nid[remote_mask]
source_id = dist_nid[remote_mask]
mem = self.mailbox.gather_memory(local_id)
memory[local_mask],memory_ts[local_mask],mailbox[local_mask],mailbox_ts[local_mask]= mem
if remote_mask.sum() > 0:
mem = self.mailbox_cache.fetch_data(remote_id,\
self.mailbox.gather_memory,source_id)
memory[remote_mask] = mem[0]
memory_ts[remote_mask] = mem[1].reshape(-1)
mailbox[remote_mask] = mem[2]
mailbox_ts[remote_mask] = mem[3]
mem = memory,memory_ts,mailbox,mailbox_ts
else:
                mem = self.mailbox_cache.fetch_data(nid, self.mailbox.gather_memory, dist_nid)
if self.edge_cache is not None and eid is not None:
efeat = torch.zeros(eid.shape[0],
self.edge_cache.buffers[0].shape[1],
dtype = self.edge_cache.buffers[0].dtype,
device = torch.device('cuda')
)
if self.edge_cache.use_local is False:
local_mask = (DistIndex(dist_eid).part == torch.distributed.get_rank())
local_id = dist_eid[local_mask]
efeat[local_mask] = self.graph.edge_attr.accessor.data[DistIndex(local_id).loc]
remote_mask = ~local_mask
if remote_mask.sum() > 0:
remote_id = eid[remote_mask]
source_id = dist_eid[remote_mask]
efeat[remote_mask] = self.edge_cache.fetch_data(remote_id,\
self.graph._get_edge_attr,source_id)[0]
else:
                efeat = self.edge_cache.fetch_data(eid,
                        self.graph._get_edge_attr, dist_eid)[0]
return nfeat,efeat,mem
def init_cache_with_presample(self,dataloader, num_epoch:int = 10):
"""
Initializes the cache with pre-sampled data from the provided dataloader.
Args:
dataloader: The data loader we implement, containing the graph data.
num_epoch: Number of epochs to pre-sample the data.
"""
node_size = self.node_cache.capacity if self.node_cache is not None else 0
edge_size = self.edge_cache.capacity if self.edge_cache is not None else 0
node_counts,edge_counts = pre_sample(dataloader=dataloader,
num_epoch=num_epoch,
node_size = node_size,
edge_size = edge_size)
if node_size != 0:
if self.node_cache.use_local is False:
dist_mask = DistIndex(self.graph.nids_mapper).part == torch.distributed.get_rank()
dist_mask = ~dist_mask
node_counts = node_counts[dist_mask]
_,nid = node_counts.topk(node_size)
if self.node_cache.use_local is False:
nid = dist_mask.nonzero()[nid]
dist_nid = self.graph.nids_mapper[nid].unique()
node_feature = self.graph._get_node_attr(dist_nid.to(self.graph.x.device))
_nid = nid.reshape(-1)
self.node_cache.init_cache(_nid,node_feature)
print('finish node init')
if edge_size != 0:
if self.edge_cache.use_local is False:
dist_mask = DistIndex(self.graph.eids_mapper).part == torch.distributed.get_rank()
dist_mask = ~dist_mask
edge_counts = edge_counts[dist_mask]
_,eid = edge_counts.topk(edge_size)
if self.edge_cache.use_local is False:
eid_ = dist_mask.nonzero()[eid]
else:
eid_ = eid
dist_eid = self.graph.eids_mapper[eid_].unique()
edge_feature = self.graph._get_edge_attr(dist_eid.to(self.graph.edge_attr.device))
eid_ = eid_.reshape(-1)
self.edge_cache.init_cache(eid_,edge_feature)
print('finish edge init')
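# Hedged usage sketch (not part of the original module): wiring the feature
# cache into a training script. `graph`, `mailbox` and `trainloader` are
# assumed to be built as in the DistributedDataLoader example; the cache
# ratios and pre-sampling epoch count are illustrative.
def _example_setup_fetch_cache(graph, mailbox, trainloader):
    FetchFeatureCache.create_fetch_cache(
        num_nodes=graph.num_nodes, num_edges=graph.num_edges,
        edge_cache_ratio=0.1, node_cache_ratio=0.1,
        graph=graph, mailbox=mailbox, policy='lru',
    )
    cache = FetchFeatureCache.getFetchCache()
    # optionally warm the cache with the most frequently sampled nodes/edges
    cache.init_cache_with_presample(trainloader, num_epoch=3)
    return cache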
\ No newline at end of file
from typing import Optional, Sequence, Union
import torch
from starrygl.distributed.utils import DistributedTensor
from starrygl.sample.cache.cache import Cache
class StaticCache(Cache):
def __init__(self, cache_ratio: int,
num_cache:int,
cache_data: Sequence[DistributedTensor],
use_local:bool = False,
pinned_buffers_shape: Sequence[torch.Size] = None,
is_update_cache = False
):
super(StaticCache, self).__init__(cache_ratio,num_cache,
cache_data,use_local,
pinned_buffers_shape,
is_update_cache)
self.name = 'static'
self.now_cache_count = 0
self.cache_count = torch.zeros(
self.capacity, dtype=torch.int32, device=torch.device('cuda'))
self.is_update_cache = False
import torch
# do not add an import for the dataloader here
def pre_sample(dataloader, num_epoch:int,node_size:int,edge_size:int):
nodes_counts = torch.zeros(dataloader.graph.num_nodes,dtype = torch.long)
edges_counts = torch.zeros(dataloader.graph.num_edges,dtype = torch.long)
print(nodes_counts.shape,edges_counts.shape)
sampler = dataloader.sampler
neg_sampling = dataloader.neg_sampler
sample_fn = dataloader.sampler_fn
graph = dataloader.graph
for _ in range(num_epoch):
dataloader.__iter__()
while dataloader.recv_idxs < dataloader.expected_idx:
dataloader.recv_idxs += 1
data = dataloader._next_data()
out = sample_fn(sampler,data,neg_sampling)
if(len(out)>0):
sample_out,metadata = out
else:
sample_out = out
eid = [ret.eid() for ret in sample_out]
eid_tensor = torch.cat(eid,dim = 0)
src_node = graph.sample_graph['edge_index'][0,eid_tensor*2].to(graph.nids_mapper.device)
dst_node = graph.sample_graph['edge_index'][1,eid_tensor*2].to(graph.nids_mapper.device)
eid_tensor = torch.unique(eid_tensor)
nid_tensor = torch.unique(torch.cat((src_node,dst_node)))
edges_counts[eid_tensor] += 1
nodes_counts[nid_tensor] += 1
return nodes_counts,edges_counts
from collections import deque
from enum import Enum
import queue
import torch
import sys
from os.path import abspath, join, dirname
import numpy as np
from starrygl.sample.batch_data import graph_sample
from starrygl.sample.sample_core.PreNegSampling import PreNegativeSampling
sys.path.insert(0, join(abspath(dirname(__file__))))
from typing import Deque, Optional
import torch.distributed as dist
from torch_geometric.data import Data
import os.path as osp
import math
class DistributedDataLoader:
'''
    We perform feature fetching inside the data loader.
    You can simply define a data loader, and starrygl will assist in fetching node or edge features:
Args:
graph: distributed graph store
data: the graph data
sampler: a parallel sampler like `NeighborSampler` above
sampler_fn: sample type
neg_sampler: negative sampler
batch_size: batch size
mailbox: APAN's mailbox and TGN's memory implemented by starrygl
Examples:
.. code-block:: python
import torch
from starrygl.sample.data_loader import DistributedDataLoader
from starrygl.sample.part_utils.partition_tgnn import partition_load
from starrygl.sample.graph_core import DataSet, DistributedGraphStore, TemporalNeighborSampleGraph
from starrygl.sample.memory.shared_mailbox import SharedMailBox
from starrygl.sample.sample_core.neighbor_sampler import NeighborSampler
from starrygl.sample.sample_core.base import NegativeSampling
from starrygl.sample.batch_data import SAMPLE_TYPE
pdata = partition_load("PATH/{}".format(dataname), algo="metis_for_tgnn")
graph = DistributedGraphStore(pdata = pdata, uvm_edge = False, uvm_node = False)
sample_graph = TemporalNeighborSampleGraph(sample_graph = pdata.sample_graph,mode = 'full')
            mailbox = SharedMailBox(pdata.ids.shape[0], memory_param, dim_edge_feat=pdata.edge_attr.shape[1] if pdata.edge_attr is not None else 0)
sampler = NeighborSampler(num_nodes=graph.num_nodes, num_layers=1, fanout=[10], graph_data=sample_graph, workers=15,policy = 'recent',graph_name = "wiki_train")
neg_sampler = NegativeSampling('triplet')
            train_data = torch.masked_select(graph.edge_index, pdata.train_mask.to(graph.edge_index.device)).reshape(2, -1)
            trainloader = DistributedDataLoader(graph, train_data, sampler=sampler, sampler_fn=SAMPLE_TYPE.SAMPLE_FROM_TEMPORAL_EDGES, neg_sampler=neg_sampler, batch_size=1000, shuffle=False, drop_last=True, chunk_size=None, train=True, mailbox=mailbox)
    Inside the data loader we call `graph_sample`, sourced from `starrygl.sample.batch_data`,
    and the `to_block` function invoked by `graph_sample` performs the feature fetching.
    If no cache is used, node or edge features are fetched directly from the graph data;
    otherwise `fetch_data` is called to do the fetching.
'''
def __init__(
self,
graph,
dataset = None,
sampler = None,
sampler_fn = None,
neg_sampler = None,
batch_size: Optional[int]=None,
drop_last = False,
device: torch.device = torch.device('cuda'),
shuffle:bool = True,
chunk_size = None,
train = False,
queue_size = 10,
mailbox = None,
is_pipeline = False,
**kwargs
):
assert sampler is not None
self.chunk_size = chunk_size
self.batch_size = batch_size
self.queue_size = queue_size
self.num_pending = 0
self.current_pos = 0
self.recv_idxs = 0
self.drop_last = drop_last
self.result_queue = deque(maxlen = self.queue_size)
self.shuffle = shuffle
self.is_closed = False
self.sampler = sampler
self.sampler_fn = sampler_fn
self.neg_sampler = neg_sampler
self.graph = graph
self.shuffle=shuffle
self.dataset = dataset
self.mailbox = mailbox
self.device = device
self.is_pipeline = is_pipeline
if train is True:
self._get_expected_idx(self.dataset.len)
else:
self._get_expected_idx(self.dataset.len,op = dist.ReduceOp.MAX)
#self.expected_idx = int(math.ceil(self.dataset.len/self.batch_size))
torch.distributed.barrier()
def __iter__(self):
if self.chunk_size is None:
if self.shuffle:
self.input_dataset = self.dataset.shuffle()
else:
self.input_dataset = self.dataset
self.recv_idxs = 0
self.current_pos = 0
self.num_pending = 0
self.submitted = 0
else:
self.input_dataset = self.dataset
self.recv_idxs = 0
self.num_pending = 0
self.submitted = 0
            if dist.get_rank() == 0:
self.current_pos = int(
math.floor(
np.random.uniform(0,self.batch_size/self.chunk_size)
)*self.chunk_size
)
else:
self.current_pos = 0
current_pos = torch.tensor([self.current_pos],dtype = torch.long,device=self.device)
dist.broadcast(current_pos, src = 0)
self.current_pos = int(current_pos.item())
self._get_expected_idx(self.dataset.len-self.current_pos)
if self.neg_sampler is not None \
and isinstance(self.neg_sampler,PreNegativeSampling):
self.neg_sampler.set_next_pos(self.current_pos)
return self
def _get_expected_idx(self,data_size,op = dist.ReduceOp.MIN):
world_size = dist.get_world_size()
self.expected_idx = data_size // self.batch_size if self.drop_last is True else int(math.ceil(data_size/self.batch_size))
if dist.get_world_size() > 1:
num_epochs = torch.tensor([self.expected_idx],dtype = torch.long,device=self.device)
print(num_epochs)
dist.all_reduce(num_epochs, op=op)
self.expected_idx = int(num_epochs.item())
def _next_data(self):
if self.current_pos >= self.dataset.len:
return self.input_dataset._get_empty()
if self.current_pos + self.batch_size > self.input_dataset.len:
if self.drop_last:
return None
else:
next_data = self.input_dataset.get_next(
slice(self.current_pos,None,None)
)
self.current_pos = 0
else:
next_data = self.input_dataset.get_next(
slice(self.current_pos,self.current_pos + self.batch_size,None)
)
self.current_pos += self.batch_size
return next_data
def __next__(self):
if self.is_pipeline is False:
if self.recv_idxs < self.expected_idx:
data = self._next_data()
batch_data = graph_sample(self.graph,
self.sampler,
self.sampler_fn,
data,self.neg_sampler,
self.mailbox,
self.device)
self.recv_idxs += 1
assert batch_data is not None
torch.cuda.synchronize()
return batch_data
else :
raise StopIteration
else:
if self.recv_idxs == 0:
data = self._next_data()
batch_data = graph_sample(self.graph,
self.sampler,
self.sampler_fn,
data,self.neg_sampler,
self.mailbox,
self.device)
self.recv_idxs += 1
else:
if(self.recv_idxs < self.expected_idx):
assert len(self.result_queue) > 0
result= self.result_queue[0]
self.result_queue.popleft()
batch_data = result.result()
self.recv_idxs += 1
else:
raise StopIteration
if(self.recv_idxs+1<=self.expected_idx):
data = self._next_data()
next_batch = graph_sample(self.graph,
self.sampler,
self.sampler_fn,
data,self.neg_sampler,
self.mailbox,
self.device,
async_op=True)
self.result_queue.append(next_batch)
return batch_data
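# Hedged usage sketch (not part of the original module): iterating the loader
# built in the class docstring above. Each batch is the (data, mfgs, metadata)
# triple produced by `to_block`; the actual training-step body is omitted.
def _example_train_epoch(trainloader):
    for data, mfgs, metadata in trainloader:
        # mfgs is a list of per-layer DGL blocks with features already attached
        # by `prepare_input`; run the model / update memory and mailbox here.
        pass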
import starrygl
from starrygl.distributed.context import DistributedContext
from starrygl.distributed.utils import DistIndex, DistributedTensor
from starrygl.sample.graph_core.utils import build_mapper
import os.path as osp
import torch
import torch.distributed as dist
from torch_geometric.data import Data
from starrygl.utils.uvm import *
class DistributedGraphStore:
'''
Initializes the DistributedGraphStore with distributed graph data.
Args:
pdata: Graph data object containing ids, eids, edge_index, edge_ts, sample_graph, x, and edge_attr.
device: Device to which tensors are moved (default is 'cuda').
uvm_node: If True, enables Unified Virtual Memory (UVM) for node data.
uvm_edge: If True, enables Unified Virtual Memory (UVM) for edge data.
'''
def __init__(self, pdata, device = torch.device('cuda'),
uvm_node = False,
uvm_edge = False):
self.device = device
self.ids = pdata.ids.to(device)
self.eids = pdata.eids
self.edge_index = pdata.edge_index.to(device)
if hasattr(pdata,'edge_ts'):
self.edge_ts = pdata.edge_ts.to(device).to(torch.float)
else:
self.edge_ts = None
self.sample_graph = pdata.sample_graph
self.nids_mapper = build_mapper(nids=pdata.ids.to(device)).dist.to('cpu')
self.eids_mapper = build_mapper(nids=pdata.eids.to(device)).dist.to('cpu')
torch.cuda.empty_cache()
self.num_nodes = self.nids_mapper.data.shape[0]
self.num_edges = self.eids_mapper.data.shape[0]
world_size = dist.get_world_size()
self.uvm_node = uvm_node
self.uvm_edge = uvm_edge
if hasattr(pdata,'x') and pdata.x is not None:
ctx = DistributedContext.get_default_context()
pdata.x = pdata.x.to(torch.float)
if uvm_node == False :
x = pdata.x.to(self.device)
else:
if self.device.type == 'cuda':
                    x = uvm_empty(*pdata.x.size(),
                                  dtype=pdata.x.dtype,
                                  device=ctx.device)
                    x = uvm_share(x, device=torch.device('cpu'))
                    x.copy_(pdata.x)
                    x = uvm_share(x, device=ctx.device)
                    uvm_advise(x, cudaMemoryAdvise.cudaMemAdviseSetAccessedBy)
                    uvm_prefetch(x)
            if world_size > 1:
                self.x = DistributedTensor(x)
else:
self.x = x
else:
self.x = None
if hasattr(pdata,'edge_attr') and pdata.edge_attr is not None:
ctx = DistributedContext.get_default_context()
pdata.edge_attr = pdata.edge_attr.to(torch.float)
if uvm_edge == False :
edge_attr = pdata.edge_attr.to(self.device)
else:
if self.device.type == 'cuda':
edge_attr = uvm_empty(*pdata.edge_attr.size(),
dtype=pdata.edge_attr.dtype,
device=ctx.device)
edge_attr = uvm_share(edge_attr,device = torch.device('cpu'))
edge_attr.copy_(pdata.edge_attr)
edge_attr = uvm_share(edge_attr,device = ctx.device)
uvm_advise(edge_attr,cudaMemoryAdvise.cudaMemAdviseSetAccessedBy)
uvm_prefetch(edge_attr)
if world_size > 1:
self.edge_attr = DistributedTensor(edge_attr)
else:
self.edge_attr = edge_attr
else:
self.edge_attr = None
def _get_node_attr(self,ids,asyncOp = False):
'''
Retrieves node attributes for the specified node IDs.
Args:
ids: Node IDs for which to retrieve attributes.
asyncOp: If True, performs asynchronous operation for distributed data.
'''
if self.x is None:
return None
elif dist.get_world_size() == 1:
return self.x[ids]
else:
if self.x.rrefs is None or asyncOp is False:
ids = self.x.all_to_all_ind2ptr(ids)
return self.x.all_to_all_get(**ids)
return self.x.index_select(ids)
def _get_edge_attr(self,ids,asyncOp = False):
'''
Retrieves edge attributes for the specified edge IDs.
Args:
ids: Edge IDs for which to retrieve attributes.
asyncOp: If True, performs asynchronous operation for distributed data.
'''
if self.edge_attr is None:
return None
elif dist.get_world_size() == 1:
return self.edge_attr[ids]
else:
if self.edge_attr.rrefs is None or asyncOp is False:
ids = self.edge_attr.all_to_all_ind2ptr(ids)
return self.edge_attr.all_to_all_get(**ids)
return self.edge_attr.index_select(ids)
def _get_dist_index(self,ind,mapper):
'''
Retrieves the distributed index for the specified local index using the provided mapper.
Args:
ind: Local index for which to retrieve the distributed index.
mapper: Mapper providing the distributed index.
'''
return mapper[ind.to(mapper.device)]
class DataSet:
'''
Args:
nodes: Tensor representing nodes. If not None, it is moved to the specified device.
edges: Tensor representing edges. If not None, it is moved to the specified device.
labels: Optional parameter for labels.
ts: Tensor representing timestamps. If not None, it is moved to the specified device.
device: Device to which tensors are moved (default is 'cuda').
'''
def __init__(self,nodes = None,
edges = None,
labels = None,
ts = None,
device = torch.device('cuda'),**kwargs):
if nodes is not None:
self.nodes = nodes.to(device)
if edges is not None:
self.edges = edges.to(device)
if ts is not None:
self.ts = ts.to(device)
if labels is not None:
self.labels = labels
self.len = self.nodes.shape[0] if nodes is not None else self.edges.shape[1]
for k, v in kwargs.items():
assert isinstance(v,torch.Tensor) and v.shape[0]==self.len
setattr(self, k, v.to(device))
def _get_empty(self):
'''
Creates an empty dataset with the same device and data types as the current instance.
'''
        nodes = torch.empty([0], dtype=self.nodes.dtype, device=self.nodes.device) if hasattr(self,'nodes') else None
        edges = torch.empty([2,0], dtype=self.edges.dtype, device=self.edges.device) if hasattr(self,'edges') else None
d = DataSet(nodes,edges)
for k,v in self.__dict__.items():
if k == 'edges' or k=='nodes' or k == 'len':
continue
else:
setattr(d,k,torch.empty([]))
return d
#@staticmethod
def get_next(self,indx):
'''
Retrieves the next dataset based on the provided index.
Args:
indx: Index specifying the dataset to retrieve.
'''
nodes = self.nodes[indx] if hasattr(self,'nodes') else None
edges = self.edges[:,indx] if hasattr(self,'edges') else None
d = DataSet(nodes,edges)
for k,v in self.__dict__.items():
if k == 'edges' or k=='nodes' or k == 'len':
continue
else:
setattr(d,k,v[indx])
return d
#@staticmethod
def shuffle(self):
'''
Shuffles the dataset and returns a new dataset with the same attributes.
'''
indx = torch.randperm(self.len)
nodes = self.nodes[indx] if hasattr(self,'nodes') else None
edges = self.edges[:,indx] if hasattr(self,'edges') else None
d = DataSet(nodes,edges)
for k,v in self.__dict__.items():
if k == 'edges' or k=='nodes' or k == 'len':
continue
else:
setattr(d,k,v[indx])
return d
class TemporalGraphData(DistributedGraphStore):
def __init__(self,pdata,device):
        super(TemporalGraphData, self).__init__(pdata, device)
def _set_temporal_batch_cache(self,size,pin_size):
pass
def _load_feature_to_cuda(self,ids):
pass
class TemporalNeighborSampleGraph(DistributedGraphStore):
'''
Args:
sample_graph: A dictionary containing graph structure information, including 'edge_index', 'ts' (edge timestamp), and 'eids' (edge identifiers).
mode: Specifies the dataset mode ('train', 'val', 'test', or 'full').
eids_mapper: Optional parameter for edge identifiers mapping.
'''
def __init__(self, sample_graph=None, mode='full', eids_mapper=None):
self.edge_index = sample_graph['edge_index']
self.num_edges = self.edge_index.shape[1]
if 'ts' in sample_graph:
self.edge_ts = sample_graph['ts']
else:
self.edge_ts = None
self.eid = torch.arange(self.num_edges,dtype = torch.long, device = sample_graph['eids'].device)
#sample_graph['eids']
if mode == 'train':
mask = sample_graph['train_mask']
if mode == 'val':
mask = sample_graph['val_mask']
if mode == 'test':
mask = sample_graph['test_mask']
if mode != 'full':
self.edge_index = self.edge_index[:, mask]
self.edge_ts = self.edge_ts[mask]
self.eid = self.eid[mask]
from starrygl.distributed.context import DistributedContext
from typing import *
import torch.distributed as dist
import torch
from starrygl.distributed.utils import DistIndex
def build_mapper(nids):
rank = dist.get_rank()
world_size = dist.get_world_size()
dst_len = nids.size(0)
ikw = dict(dtype=torch.long, device=nids.device)
num_nodes = torch.zeros(1, **ikw)
num_nodes[0] = dst_len
dist.all_reduce(num_nodes, op=dist.ReduceOp.SUM)
all_ids: List[torch.Tensor] = [None] * world_size
dist.all_gather_object(all_ids,nids)
part_mp = torch.empty(num_nodes,**ikw)
ind_mp = torch.empty(num_nodes,**ikw)
for i in range(world_size):
iid = all_ids[i]
part_mp[iid] = i
ind_mp[iid] = torch.arange(all_ids[i].shape[0],**ikw)
return DistIndex(ind_mp,part_mp)
def get_validate_graph(self,graph):
pass
\ No newline at end of file
import starrygl
from typing import Union
from typing import List
from typing import Optional
import torch
from torch.distributed import rpc
import torch_scatter
from starrygl.distributed.context import DistributedContext
from starrygl.distributed.utils import DistIndex, DistributedTensor
import torch.distributed as dist
#from starrygl.utils.uvm import cudaMemoryAdvise
class SharedMailBox():
'''
    We first define our mailbox, including the definitions of the mailbox and the memory:
.. code-block:: python
from starrygl.sample.memory.shared_mailbox import SharedMailBox
mailbox = SharedMailBox(num_nodes=num_nodes, memory_param=memory_param, dim_edge_feat=dim_edge_feat)
Args:
num_nodes (int): number of nodes
        memory_param (dict): the memory parameters in the yaml file; refer to TGL
        dim_edge_feat (int): the dim of edge feature
        device (torch.device): the device used to store MailBox
        uvm (bool): whether to store the memory and mailbox in unified virtual memory (UVM)
Examples:
.. code-block:: python
from starrygl.sample.part_utils.partition_tgnn import partition_load
from starrygl.sample.memory.shared_mailbox import SharedMailBox
pdata = partition_load("PATH/{}".format(dataname), algo="metis_for_tgnn")
mailbox = SharedMailBox(pdata.ids.shape[0], memory_param, dim_edge_feat=pdata.edge_attr.shape[1] if pdata.edge_attr is not None else 0)
We then hand the mailbox over to the data loader, as in the example above, so that the relevant memory and mails can be loaded directly during training.
During training, the `get_update_memory`/`get_update_mail` functions are called constantly to update the corresponding storage, which follows the idea of TGN.
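A rough sketch of how these calls can fit together in one training step is shown below; the lower-case names (``dist_index_mapper``, ``src``, ``dst``, ``ts``, ``edge_feats``, ``upd_memory``, ``upd_memory_ts``) are placeholders for the current batch's tensors, not fixed APIs:
.. code-block:: python
# build one mail per event endpoint, keeping only the latest mail per node
index, mail, mail_ts = mailbox.get_update_mail(
dist_index_mapper, src, dst, ts, edge_feats, upd_memory)
# deduplicate the updated memory the same way
index, memory, memory_ts = mailbox.get_update_memory(index, upd_memory, upd_memory_ts)
# push memory and mails back to their owning partitions
mailbox.set_mailbox_all_to_all(index, memory, memory_ts, mail, mail_ts, reduce_Op='max')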
'''
def __init__(self,
num_nodes,
memory_param,
dim_edge_feat,
device = torch.device('cuda'),
uvm = False):
self.device = device
self.num_nodes = num_nodes
self.num_parts = dist.get_world_size()
if memory_param['type'] != 'node':
raise NotImplementedError
self.memory_param = memory_param
self.memory_size = memory_param['dim_out']
assert not (device.type =='cpu' and uvm is True),\
'using uvm requires the device to be cuda'
memory_device = device
if device.type == 'cuda' and uvm is True:
memory_device = torch.device('cpu')
node_memory = torch.zeros((
self.num_nodes, memory_param['dim_out']),
dtype=torch.float32,device =memory_device)
node_memory_ts = torch.zeros(self.num_nodes,
dtype=torch.float32,
device = self.device)
mailbox = torch.zeros(self.num_nodes,
memory_param['mailbox_size'],
2 * memory_param['dim_out'] + dim_edge_feat,
device = memory_device, dtype=torch.float32)
mailbox_ts = torch.zeros((self.num_nodes,
memory_param['mailbox_size']),
dtype=torch.float32,device = self.device)
self.uvm = uvm
if uvm is True:
ctx = DistributedContext.get_default_context()
node_memory = starrygl.utils.uvm.uvm_empty(*node_memory.shape,
dtype=node_memory.dtype,
device=ctx.device)
starrygl.utils.uvm.uvm_share(node_memory,device = ctx.device)
starrygl.utils.uvm.uvm_advise(node_memory,starrygl.utils.uvm.cudaMemoryAdvise.cudaMemAdviseSetAccessedBy)
starrygl.utils.uvm.uvm_prefetch(node_memory)
mailbox = starrygl.utils.uvm.uvm_empty(*mailbox.shape,
dtype=mailbox.dtype,
device=ctx.device)
starrygl.utils.uvm.uvm_share(mailbox,device = ctx.device)
starrygl.utils.uvm.uvm_advise(mailbox,starrygl.utils.uvm.cudaMemoryAdvise.cudaMemAdviseSetAccessedBy)
starrygl.utils.uvm.uvm_prefetch(mailbox)
self.node_memory = DistributedTensor(node_memory)
self.node_memory_ts = DistributedTensor(node_memory_ts)
self.mailbox = DistributedTensor(mailbox)
self.mailbox_ts = DistributedTensor(mailbox_ts)
self.next_mail_pos = torch.zeros((self.num_nodes),
dtype=torch.long,
device = self.device)
self._ctx = DistributedContext.get_default_context()
if self._ctx._use_rpc:
self.rref = rpc.RRef(self)
self.rrefs = self._ctx.all_gather_remote_objects(self.rref)
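# partition boundaries in DistIndex encoding (partition id packed into the high bits);
# used by searchsorted in set_mailbox_all_to_all to split indices by owning partition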
self.partptr = torch.tensor([ ((i & 0xFFFF)<<48) for i in range(self.num_parts+1) ],device = device)
def reset(self):
self.node_memory.accessor.data.zero_()
self.node_memory_ts.accessor.data.zero_()
self.mailbox.accessor.data.zero_()
self.mailbox_ts.accessor.data.zero_()
self.next_mail_pos.zero_()
def set_memory_local(self,index,source,source_ts,Reduce_Op = None):
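# with Reduce_Op == 'max', duplicate indices keep only the update carrying the largest timestamp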
if Reduce_Op == 'max' and self.num_parts > 1:
unq_id,inv = index.unique(return_inverse = True)
max_ts,id = torch_scatter.scatter_max(source_ts,inv,dim=0)
source_ts = max_ts
source = source[id]
index = unq_id
self.node_memory.accessor.data[index] = source
self.node_memory_ts.accessor.data[index] = source_ts
def set_mailbox_local(self,index,source,source_ts,Reduce_Op = None):
if Reduce_Op == 'max' and self.num_parts > 1:
unq_id,inv = index.unique(return_inverse = True)
max_ts,id = torch_scatter.scatter_max(source_ts,inv,dim=0)
source_ts = max_ts
source = source[id]
index = unq_id
self.mailbox_ts.accessor.data[index, self.next_mail_pos[index]] = source_ts
self.mailbox.accessor.data[index, self.next_mail_pos[index]] = source
if self.memory_param['mailbox_size'] > 1:
self.next_mail_pos[index] = torch.remainder(
self.next_mail_pos[index] + 1,
self.memory_param['mailbox_size'])
def set_memory_async(self,index,source,source_ts):
dist_index = DistIndex(index)
part_idx = dist_index.part
index = dist_index.loc
futs: List[torch.futures.Future] = []
if self.num_parts == 1:
self.set_memory_local(index,source,source_ts)
else:
for i in range(self.num_parts):
fut = self._ctx.remote_call(
SharedMailBox.set_memory_local,
self.rrefs[i],
index[part_idx == i],
source[part_idx == i],
source_ts[part_idx == i])
futs.append(fut)
return torch.futures.collect_all(futs)
def add_to_mailbox_async(self,index,source,source_ts):
dist_index = DistIndex(index)
part_idx = dist_index.part
index = dist_index.loc
futs: List[torch.futures.Future] = []
if self.num_parts == 1:
self.set_mailbox_local(index,source,source_ts)
else:
for i in range(self.num_parts):
fut = self._ctx.remote_call(
SharedMailBox.set_mailbox_local,
self.rrefs[i],
index[part_idx == i],
source[part_idx == i],
source_ts[part_idx == i])
futs.append(fut)
return torch.futures.collect_all(futs)
def set_mailbox_all_to_all(self,index,memory,
memory_ts,mail,mail_ts,
reduce_Op = None,group = None):
#futs: List[torch.futures.Future] = []
if self.num_parts == 1:
dist_index = DistIndex(index)
index = dist_index.loc
self.set_mailbox_local(index,mail,mail_ts)
self.set_memory_local(index,memory,memory_ts)
else:
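# exchange ids, memory and mails across partitions with all_to_all, then apply the received updates locally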
gather_len_list = torch.empty([self.num_parts],
dtype = int,
device = self.device)
indic = torch.searchsorted(index,self.partptr,right=False)
scatter_len_list = indic[1:] - indic[0:-1]
torch.distributed.all_to_all_single(gather_len_list,scatter_len_list,group = group)
input_split = scatter_len_list.tolist()
output_split = gather_len_list.tolist()
gather_id_list = torch.empty(
[gather_len_list.sum()],
dtype = torch.long,
device = self.device)
torch.distributed.all_to_all_single(
gather_id_list,index,output_split_sizes=output_split,
input_split_sizes=input_split,group = group)
index = gather_id_list
gather_memory = torch.empty(
[gather_len_list.sum(),memory.shape[1]],
dtype = memory.dtype,device = self.device)
gather_memory_ts = torch.empty(
[gather_len_list.sum()],
dtype = memory_ts.dtype,device = self.device)
gather_mail = torch.empty(
[gather_len_list.sum(),mail.shape[1]],
dtype = mail.dtype,device = self.device)
gather_mail_ts = torch.empty(
[gather_len_list.sum()],
dtype = mail_ts.dtype,device = self.device)
torch.distributed.all_to_all_single(
gather_memory,memory,
output_split_sizes=output_split,
input_split_sizes=input_split,group = group)
torch.distributed.all_to_all_single(
gather_memory_ts,memory_ts,
output_split_sizes=output_split,
input_split_sizes=input_split,group = group)
torch.distributed.all_to_all_single(
gather_mail,mail,
output_split_sizes=output_split,
input_split_sizes=input_split,group = group)
torch.distributed.all_to_all_single(
gather_mail_ts,mail_ts,
output_split_sizes=output_split,
input_split_sizes=input_split,group = group)
self.set_mailbox_local(DistIndex(index).loc,gather_mail,gather_mail_ts,Reduce_Op = reduce_Op)
self.set_memory_local(DistIndex(index).loc,gather_memory,gather_memory_ts, Reduce_Op = reduce_Op)
def get_update_mail(self,dist_indx_mapper,
src,dst,ts,edge_feats,
memory,embedding=None,use_src_emb=False,use_dst_emb=False):
if edge_feats is not None:
edge_feats = edge_feats.to(self.device).to(self.mailbox.dtype)
src = src.to(self.device)
dst = dst.to(self.device)
index = torch.cat([src, dst]).reshape(-1)
index = dist_indx_mapper[index]
mem_src = memory[src]
mem_dst = memory[dst]
if embedding is not None:
emb_src = embedding[src]
emb_dst = embedding[dst]
src_mail = torch.cat([emb_src if use_src_emb else mem_src, emb_dst if use_dst_emb else mem_dst], dim=1)
dst_mail = torch.cat([emb_dst if use_src_emb else mem_dst, emb_src if use_dst_emb else mem_src], dim=1)
if edge_feats is not None:
src_mail = torch.cat([src_mail, edge_feats], dim=1)
dst_mail = torch.cat([dst_mail, edge_feats], dim=1)
mail = torch.cat([src_mail, dst_mail], dim=0)#.reshape(-1, src_mail.shape[1])
mail_ts = torch.cat((ts,ts),-1).to(self.device).to(self.mailbox_ts.dtype)
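# deduplicate destination ids: for each node keep only the mail with the largest timestamp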
unq_index,inv = torch.unique(index,return_inverse = True)
max_ts,idx = torch_scatter.scatter_max(mail_ts,inv,0)
mail_ts = max_ts
mail = mail[idx]
index = unq_index
return index,mail,mail_ts
def get_update_memory(self,index,memory,memory_ts):
unq_index,inv = torch.unique(index,return_inverse = True)
max_ts,idx = torch_scatter.scatter_max(memory_ts,inv,0)
ts = max_ts
memory = memory[idx]
index = unq_index
return index,memory,ts
def get_memory(self,index):
if self.num_parts == 1:
return self.node_memory.accessor.data[index],\
self.node_memory_ts.accessor.data[index],\
self.mailbox.accessor.data[index],\
self.mailbox_ts.accessor.data[index]
elif self.node_memory.rrefs is None:
return self.gather_memory(dist_index = index)
else:
memory = self.node_memory.index_select(index)
memory_ts = self.node_memory_ts.index_select(index)
mail = self.mailbox.index_select(index)
mail_ts = self.mailbox_ts.index_select(index)
def callback(fs):
memory,memory_ts,mail,mail_ts = fs.value()
memory = memory.value()
memory_ts = memory_ts.value()
mail = mail.value()
mail_ts = mail_ts.value()
#print(memory.shape[0])
return memory,memory_ts,mail,mail_ts
return torch.futures.collect_all([memory,memory_ts,mail,mail_ts]).then(callback)
def gather_memory(
self,
dist_index: Union[torch.Tensor, DistIndex, None] = None,
send_ptr: Optional[List[int]] = None,
recv_ptr: Optional[List[int]] = None,
recv_ind: Optional[List[int]] = None,
group = None
):
if dist_index is None:
return self.node_memory.all_to_all_get(dist_index,send_ptr,recv_ptr,recv_ind,group),\
self.node_memory_ts.all_to_all_get(dist_index,send_ptr,recv_ptr,recv_ind,group),\
self.mailbox.all_to_all_get(dist_index,send_ptr,recv_ptr,recv_ind,group),\
self.mailbox_ts.all_to_all_get(dist_index,send_ptr,recv_ptr,recv_ind,group)
else:
ids = self.node_memory.all_to_all_ind2ptr(dist_index)
return self.node_memory.all_to_all_get(**ids,group = group),\
self.node_memory_ts.all_to_all_get(**ids,group = group),\
self.mailbox.all_to_all_get(**ids,group = group),\
self.mailbox_ts.all_to_all_get(**ids,group = group)
import parser
from torch_sparse import SparseTensor
from torch_geometric.data import Data
from torch_geometric.utils import degree
import starrygl
import os.path as osp
import os
import shutil
import torch
import torch.utils.data
import metis
import networkx as nx
import torch.distributed as dist
def partition_load(root: str, algo: str = "metis") -> Data:
rank = dist.get_rank()
world_size = dist.get_world_size()
fn = osp.join(root, f"{algo}_{world_size}", f"{rank:03d}")
return torch.load(fn)
def partition_save(root: str, data: Data, num_parts: int,
algo: str = "metis",
edge_weight_dict=None):
root = osp.abspath(root)
if osp.exists(root) and not osp.isdir(root):
raise ValueError(f"path '{root}' should be a directory")
path = osp.join(root, f"{algo}_{num_parts}")
if osp.exists(path) and not osp.isdir(path):
raise ValueError(f"path '{path}' should be a directory")
if osp.exists(path) and os.listdir(path):
print(f"directory '{path}' not empty and cleared")
for p in os.listdir(path):
p = osp.join(path, p)
if osp.isdir(p):
shutil.rmtree(osp.join(path, p))
else:
os.remove(p)
if not osp.exists(path):
print(f"creating directory '{path}'")
os.makedirs(path)
if algo == 'metis_for_tgnn':
for i, pdata in enumerate(partition_data_for_tgnn(
data, num_parts, algo, verbose=True,
edge_weight_dict=edge_weight_dict)):
print(f"saving partition data: {i+1}/{num_parts}")
fn = osp.join(path, f"{i:03d}")
torch.save(pdata, fn)
else:
for i, pdata in enumerate(partition_data_for_gnn(data, num_parts,
algo, verbose=True)):
print(f"saving partition data: {i+1}/{num_parts}")
fn = osp.join(path, f"{i:03d}")
torch.save(pdata, fn)
def partition_data_for_gnn(data: Data, num_parts: int,
algo: str, verbose: bool = False):
if algo == "metis":
part_fn = metis_partition
else:
raise ValueError(f"invalid algorithm: {algo}")
num_nodes = data.num_nodes
num_edges = data.num_edges
edge_index = data.edge_index
if verbose:
print(f"running partition algorithm: {algo}")
node_parts, edge_parts = part_fn(edge_index, num_nodes, num_parts)
if verbose:
print("computing GCN normalized factor")
gcn_norm = compute_gcn_norm(edge_index, num_nodes)
if data.y.dtype == torch.long:
if verbose:
print("compute num_classes")
num_classes = data.y.max().item() + 1
else:
num_classes = None
eids = torch.zeros(num_edges, dtype=torch.long)
offset = 0
edgeptr = torch.zeros(num_parts+1, dtype=eids.dtype)
for i in range(num_parts):
epart_i = torch.where(edge_parts == i)[0]
eids[epart_i] = torch.arange(epart_i.shape[0]) + offset
offset += epart_i.shape[0]
edgeptr[i+1] = offset
data.eids = eids
data.sample_graph.sample_eids = eids[data.sample_graph.sample_eid]
nids = torch.zeros(num_nodes, dtype=torch.long)
offset = 0
partptr = torch.zeros(num_parts+1, dtype=nids.dtype)
for i in range(num_parts):
npart_i = torch.where(node_parts == i)[0]
nids[npart_i] = torch.arange(npart_i.shape[0]) + offset
offset += npart_i.shape[0]
partptr[i+1] = offset
data.edge_index = nids[data.edge_index]
data.sample_graph.edge_index = nids[data.sample_graph.edge_index]
for i in range(num_parts):
npart_i = torch.where(node_parts == i)[0]
epart_i = torch.where(edge_parts == i)[0]
npart = npart_i
epart = edge_index[:, epart_i]
pdata = {
"ids": npart,
"edge_index": epart,
"gcn_norm": gcn_norm[epart_i],
"sample_graph": data.sample_graph,
"partptr": partptr,
"edgeptr": edgeptr
}
if num_classes is not None:
pdata["num_classes"] = num_classes
for key, val in data:
if key == "edge_index" or key == "sample_graph":
continue
if isinstance(val, torch.Tensor):
if val.size(0) == num_nodes:
pdata[key] = val[npart_i]
elif val.size(0) == num_edges:
pdata[key] = val[epart_i]
# else:
# pdata[key] = val
elif isinstance(val, SparseTensor):
pass
else:
pdata[key] = val
pdata = Data(**pdata)
yield pdata
def _nopart(edge_index: torch.LongTensor, num_nodes: int):
node_parts = torch.zeros(num_nodes, dtype=torch.long)
if isinstance(edge_index, torch.Tensor):
edge_parts = torch.zeros(edge_index.size(1), dtype=torch.long)
return node_parts, edge_parts
return node_parts
def metis_for_tgnn(edge_index_dict: dict,
num_nodes: int,
num_parts: int,
edge_weight_dict=None):
if num_parts <= 1:
return _nopart(edge_index_dict, num_nodes)
G = nx.Graph()
G.add_nodes_from(torch.arange(0, num_nodes).tolist())
value, counts = torch.unique(edge_index_dict['edata'][1, :].view(-1),
return_counts=True)
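# the in-degree of each destination node in 'edata' becomes its METIS node weight;
# a constant 'ones' weight is added as well so that node counts stay balanced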
nodes = torch.tensor(list(G.adj.keys()))
for i in range(value.shape[0]):
if (value[i].item() in G.nodes):
G.nodes[int(value[i].item())]['weight'] = counts[i]
G.nodes[int(value[i].item())]['ones'] = 1
G.graph['node_weight_attr'] = ['weight', 'ones']
edges = []
for i, key in enumerate(edge_index_dict):
v = edge_index_dict[key]
edge = torch.cat((v, (torch.ones(v.shape[1], dtype=torch.long) *
edge_weight_dict[key]).unsqueeze(0)), dim=0)
edges.append(edge)
# w = edges.T
edges = torch.cat(edges,dim = 1)
G.add_weighted_edges_from((edges.T).tolist())
G.graph['edge_weight_attr'] = 'weight'
cuts, part = metis.part_graph(G, num_parts)
node_parts = torch.zeros(num_nodes, dtype=torch.long)
node_parts[nodes] = torch.tensor(part)
return node_parts
"""
weight: 各种工作负载边划分权重
按照点均衡划分
"""
def partition_data_for_tgnn(data: Data, num_parts: int, algo: str,
verbose: bool = False,
edge_weight_dict: dict = None):
if algo == "metis_for_tgnn":
part_fn = metis_for_tgnn
else:
raise ValueError(f"invalid algorithm: {algo}")
num_nodes = data.num_nodes
num_edges = data.num_edges
edge_index_dict = data.edge_index_dict
tgnn_norm = compute_temporal_norm(data.edge_index, data.edge_ts, num_nodes)
if verbose:
print(f"running partition algorithm: {algo}")
node_parts = part_fn(edge_index_dict, num_nodes, num_parts,
edge_weight_dict)
edge_parts = node_parts[data.edge_index[1, :]]
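# each edge is assigned to the partition that owns its destination node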
eids = torch.arange(num_edges, dtype=torch.long)
data.eids = eids
data.sample_graph['eids'] = eids[data.sample_graph['eids']]
if data.y.dtype == torch.long:
if verbose:
print("compute num_classes")
num_classes = data.y.max().item() + 1
else:
num_classes = None
for i in range(num_parts):
npart_i = torch.where(node_parts == i)[0]
epart_i = torch.where(edge_parts == i)[0]
pdata = {
"ids": npart_i,
"tgnn_norm": tgnn_norm,
"edge_index": data.edge_index[:, epart_i],
"sample_graph": data.sample_graph
}
if num_classes is not None:
pdata["num_classes"] = num_classes
for key, val in data:
if key == "edge_index" or key == "edge_index_dict" \
or key == "sample_graph":
continue
if isinstance(val, torch.Tensor):
if val.size(0) == num_nodes:
pdata[key] = val[npart_i]
elif val.size(0) == num_edges:
pdata[key] = val[epart_i]
# else:
# pdata[key] = val
elif isinstance(val, SparseTensor):
pass
else:
pdata[key] = val
pdata = Data(**pdata)
yield pdata
def metis_partition(edge_index, num_nodes: int, num_parts: int):
if num_parts <= 1:
return _nopart(edge_index, num_nodes)
G = nx.Graph()
G.add_nodes_from(torch.arange(0, num_nodes).tolist())
G.add_edges_from(edge_index.T.tolist())
nodes = torch.tensor(list(G.adj.keys()))
cuts, part = metis.part_graph(G, num_parts)
node_parts = torch.zeros(num_nodes, dtype=torch.long)
node_parts[nodes] = torch.tensor(part)
edge_parts = node_parts[edge_index[1]]
return node_parts, edge_parts
def metis_partition_bydegree(edge_index, num_nodes: int, num_parts: int):
if num_parts <= 1:
return _nopart(edge_index, num_nodes)
G = nx.Graph()
G.add_nodes_from(torch.arange(0, num_nodes).tolist())
G.add_edges_from(edge_index.T.tolist())
value, counts = torch.unique(edge_index[1, :].view(-1), return_counts=True)
nodes = torch.tensor(list(G.adj.keys()))
for i in range(value.shape[0]):
if (value[i].item() in G.nodes):
G.nodes[int(value[i].item())]['weight'] = counts[i]
G.graph['node_weight_attr'] = 'weight'
cuts, part = metis.part_graph(G, num_parts)
node_parts = torch.zeros(num_nodes, dtype=torch.long)
node_parts[nodes] = torch.tensor(part)
edge_parts = node_parts[edge_index[1]]
return node_parts, edge_parts
def compute_gcn_norm(edge_index: torch.LongTensor, num_nodes: int):
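# per-edge symmetric GCN normalization: deg(src)^-1/2 * deg(dst)^-1/2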
deg_j = degree(edge_index[0], num_nodes).pow(-0.5)
deg_i = degree(edge_index[1], num_nodes).pow(-0.5)
deg_i[deg_i.isinf() | deg_i.isnan()] = 0.0
deg_j[deg_j.isinf() | deg_j.isnan()] = 0.0
return deg_j[edge_index[0]] * deg_i[edge_index[1]]
def compute_temporal_norm(edge_index: torch.LongTensor,
timestamp: torch.FloatTensor,
num_nodes: int):
srcavg, srcvar, dstavg, dstvar = starrygl.sampler_ops.get_norm_temporal(edge_index[0, :],
edge_index[1, :],
timestamp, num_nodes)
return srcavg, srcvar, dstavg, dstvar
import sys
from os.path import abspath, join, dirname
sys.path.insert(0, join(abspath(dirname(__file__))))
from torch import Tensor
import torch
from base import NegativeSampling
from base import NegativeSamplingMode
from typing import Any, List, Optional, Tuple, Union
class EvaluateNegativeSampling(NegativeSampling):
def __init__(
self,
mode: Union[NegativeSamplingMode, str],
src_node_ids: torch.Tensor,
dst_node_ids: torch.Tensor,
interact_times: torch.Tensor = None,
last_observed_time: float = None,
negative_sample_strategy: str = 'random',
seed: int = None
):
super(EvaluateNegativeSampling,self).__init__(mode)
self.seed = seed
self.negative_sample_strategy = negative_sample_strategy
self.src_node_ids = src_node_ids
self.dst_node_ids = dst_node_ids
self.interact_times = interact_times
self.unique_src_nodes_id = src_node_ids.unique()
self.unique_dst_nodes_id = dst_node_ids.unique()
self.src_id_mapper = torch.zeros(int(self.unique_src_nodes_id[-1]) + 1, dtype=torch.long)
self.dst_id_mapper = torch.zeros(int(self.unique_dst_nodes_id[-1]) + 1, dtype=torch.long)
self.src_id_mapper[self.unique_src_nodes_id] = torch.arange(self.unique_src_nodes_id.shape[0])
self.dst_id_mapper[self.unique_dst_nodes_id] = torch.arange(self.unique_dst_nodes_id.shape[0])
self.unique_interact_times = self.interact_times.unique()
self.earliest_time = self.unique_interact_times.min().item()
self.last_observed_time = last_observed_time
if self.negative_sample_strategy == 'inductive':
# set of observed edges
self.observed_edges = self.get_unique_edges_between_start_end_time(self.earliest_time, self.last_observed_time)
if self.seed is not None:
self.random_state = torch.Generator()
self.random_state.manual_seed(seed)
else:
self.random_state = torch.Generator()
def get_unique_edges_between_start_end_time(self, start_time: float, end_time: float):
selected_mask = ((self.interact_times >= start_time) & (self.interact_times <= end_time))
# return the (src, dst) edges observed in the selected time interval as a [2, E] tensor
return torch.stack((self.src_node_ids[selected_mask], self.dst_node_ids[selected_mask]), dim=0)
def sample(self, num_samples: int, num_nodes: Optional[int] = None, batch_src_node_ids: Optional[torch.Tensor] = None,
batch_dst_node_ids: Optional[torch.Tensor] = None, current_batch_start_time: Optional[torch.Tensor] = None,
current_batch_end_time: Optional[torch.Tensor] = None) -> Tensor:
if self.negative_sample_strategy == 'random':
negative_src_node_ids, negative_dst_node_ids = self.random_sample(size=num_samples)
elif self.negative_sample_strategy == 'historical':
negative_src_node_ids, negative_dst_node_ids = self.historical_sample(size=num_samples, batch_src_node_ids=batch_src_node_ids,
batch_dst_node_ids=batch_dst_node_ids,
current_batch_start_time=current_batch_start_time,
current_batch_end_time=current_batch_end_time)
elif self.negative_sample_strategy == 'inductive':
negative_src_node_ids, negative_dst_node_ids = self.inductive_sample(size=num_samples, batch_src_node_ids=batch_src_node_ids,
batch_dst_node_ids=batch_dst_node_ids,
current_batch_start_time=current_batch_start_time,
current_batch_end_time=current_batch_end_time)
else:
raise ValueError(f'Not implemented error for negative_sample_strategy {self.negative_sample_strategy}!')
return negative_src_node_ids, negative_dst_node_ids
def random_sample(self, size: int):
if self.seed is None:
random_sample_edge_src_node_indices = torch.randint(0, len(self.unique_src_nodes_id), (size,))
random_sample_edge_dst_node_indices = torch.randint(0, len(self.unique_dst_nodes_id), (size,))
else:
random_sample_edge_src_node_indices = torch.randint(0, len(self.unique_src_nodes_id), (size,), generator=self.random_state)
random_sample_edge_dst_node_indices = torch.randint(0, len(self.unique_dst_nodes_id), (size,), generator=self.random_state)
return self.unique_src_nodes_id[random_sample_edge_src_node_indices], self.unique_dst_nodes_id[random_sample_edge_dst_node_indices]
def random_sample_with_collision_check(self, size: int, batch_src_nodes_id:torch.Tensor, batch_dst_nodes_id:torch.Tensor):
batch_edge = torch.stack((batch_src_nodes_id,batch_dst_nodes_id))
batch_src_index = self.src_id_mapper[batch_src_nodes_id]
batch_dst_index = self.dst_id_mapper[batch_dst_nodes_id]
return_edge = torch.empty((2, 0), dtype=torch.long)
while(True):
src_ = torch.randint(0, len(self.unique_src_nodes_id), (size*2,))
dst_ = torch.randint(0, len(self.unique_dst_nodes_id), (size*2,))
edge = torch.stack((src_,dst_))
sample_id = src_*self.unique_dst_nodes_id.shape[0] + dst_
batch_id = batch_src_index * self.unique_dst_nodes_id.shape[0] + batch_dst_index
mask = torch.isin(sample_id,batch_id,invert = True)
edge = edge[:,mask]
if(edge.shape[1] >= size):
return_edge = torch.cat((return_edge,edge[:,:size]),1)
break
else:
return_edge = torch.cat((return_edge,edge),1)
size = size - edge.shape[1]
return return_edge
def historical_sample(self, size: int, batch_src_nodes_id: torch.Tensor, batch_dst_nodes_id: torch.Tensor,
current_batch_start_time: float, current_batch_end_time: float):
assert self.seed is not None
historical_edges = self.get_unique_edges_between_start_end_time(start_time=self.earliest_time, end_time=current_batch_start_time)
current_batch_edges = self.get_unique_edges_between_start_end_time(start_time=current_batch_start_time, end_time=current_batch_end_time)
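# candidate negatives are the historical edges that do not reappear in the current batch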
uni,ids = torch.cat((current_batch_edges, historical_edges), dim = 1).unique(dim = 1, return_inverse = True)
mask = torch.zeros(uni.shape[1],dtype = bool)
mask[ids[:current_batch_edges.shape[1]]] = True
mask = (~mask)
unique_historical_edges = uni[:,mask]
if size > unique_historical_edges.shape[1]:
num_random_sample_edges = size - unique_historical_edges.shape[1]
random_sample_edge = self.random_sample_with_collision_check(size=num_random_sample_edges, batch_src_nodes_id=batch_src_nodes_id,
batch_dst_nodes_id=batch_dst_nodes_id)
sample_edges = torch.cat((unique_historical_edges,random_sample_edge),dim = 1)
else:
historical_sample_edge_node_indices = torch.randperm(unique_historical_edges.shape[1],generator=self.random_state)
sample_edges = unique_historical_edges[:,historical_sample_edge_node_indices[:size]]
return sample_edges
def inductive_sample(self, size: int, batch_src_node_ids: torch.Tensor, batch_dst_node_ids: torch.Tensor,
current_batch_start_time: float, current_batch_end_time: float):
assert self.seed is not None
historical_edges = self.get_unique_edges_between_start_end_time(start_time=self.earliest_time, end_time=current_batch_start_time)
current_batch_edges = self.get_unique_edges_between_start_end_time(start_time=current_batch_start_time, end_time=current_batch_end_time)
uni,ids = torch.cat((self.observed_edges,current_batch_edges, historical_edges), dim = 1).unique(dim = 1, return_inverse = True)
mask = torch.zeros(uni.shape[1],dtype = bool)
mask[ids[:current_batch_edges.shape[1]+historical_edges.shape[1]]] = True
mask = (~mask)
unique_inductive_edges = uni[:,mask]
if size > unique_inductive_edges.shape[1]:
num_random_sample_edges = size - unique_inductive_edges.shape[1]
random_sample_edge = self.random_sample_with_collision_check(size=num_random_sample_edges,
batch_src_nodes_id=batch_src_node_ids,
batch_dst_nodes_id=batch_dst_node_ids)
sample_edges = torch.cat((unique_inductive_edges,random_sample_edge),dim = 1)
else:
inductive_sample_edge_node_indices = torch.randperm(unique_inductive_edges.shape[1],generator=self.random_state)
sample_edges = unique_inductive_edges[:, inductive_sample_edge_node_indices[:size]]
return sample_edges
import sys
from os.path import abspath, join, dirname
sys.path.insert(0, join(abspath(dirname(__file__))))
from torch import Tensor
import torch
from base import NegativeSampling
from base import NegativeSamplingMode
from typing import Any, List, Optional, Tuple, Union
class PreNegativeSampling(NegativeSampling):
r"""The negative sampling configuration of a
:class:`~torch_geometric.sampler.BaseSampler` when calling
:meth:`~torch_geometric.sampler.BaseSampler.sample_from_edges`.
Args:
mode (str): The negative sampling mode
(:obj:`"binary"` or :obj:`"triplet"`).
If set to :obj:`"binary"`, will randomly sample negative links
from the graph.
If set to :obj:`"triplet"`, will randomly sample negative
destination nodes for each positive source node.
neg_sample_list (torch.Tensor): a pre-generated table of negative
samples that :meth:`sample` consumes sequentially.
"""
def __init__(
self,
mode: Union[NegativeSamplingMode, str],
neg_sample_list: torch.Tensor
):
super(PreNegativeSampling,self).__init__(mode)
self.neg_sample_list = neg_sample_list
self.next_pos = 0
def set_next_pos(self,pos):
self.next_pos = pos
def sample(self, num_samples: int,
num_nodes: Optional[int] = None) -> Tensor:
r"""Generates :obj:`num_samples` negative samples."""
if num_nodes is None:
raise ValueError(
f"Cannot sample negatives in '{self.__class__.__name__}' "
f"without passing the 'num_nodes' argument")
neg_sample_out = self.neg_sample_list[
self.next_pos:self.next_pos+num_samples,:].reshape(-1)
self.next_pos = self.next_pos + num_samples
return neg_sample_out
#return torch.from_numpy(np.random.randint(num_nodes, size=num_samples))
import os.path as osp
import torch
class GraphData():
def __init__(self, path):
assert path is not None and osp.exists(path),'path does not exist'
id,edge_index,data,partptr =torch.load(path)
# index of the current partition
self.partition_id = id
# total number of partitions
self.partitions = partptr.numel() - 1
# global graph structure
self.num_nodes = partptr[self.partitions]
self.num_edges = edge_index[0].numel()
self.edge_index = edge_index
# data belonging to this partition (feature vectors and subgraph structure), a PyG Data object
self.data = data
# partition mapping
self.partptr = partptr
self.eid = [i for i in range(self.num_edges)]
def __init__(self, id, edge_index, data, partptr, timestamp=None):
# index of the current partition
self.partition_id = id
# total number of partitions
self.partitions = partptr.numel() - 1
# global graph structure
self.num_nodes = partptr[self.partitions]
if edge_index is not None:
self.num_edges = edge_index[0].numel()
self.edge_index = edge_index
self.edge_ts = timestamp
# data belonging to this partition (feature vectors and subgraph structure), a PyG Data object
self.data = data
# partition mapping
self.partptr = partptr
# edge id
self.eid = torch.tensor([i for i in range(0, self.num_edges)])
def select_attr(self,index):
return torch.index_select(self.data.x,0,index)
# number of nodes stored in this partition (rows of the feature matrix)
def get_part_num(self):
return self.data.x.size()[0]
def select_attr(self,index):
return torch.index_select(self.data.x,0,index)
def select_y(self,index):
return torch.index_select(self.data.y,0,index)
# map global node ids to local ids within partition `id`
def get_localId_by_partitionId(self,id,index):
#print(index)
if(id == -1 or id == 0):
return index
else:
return torch.add(index,-self.partptr[id])
def get_globalId_by_partitionId(self,id,index):
if(id == -1 or id == 0):
return index
else:
return torch.add(index,self.partptr[id])
def get_node_num(self):
return self.num_nodes
def localId_to_globalId(self,id,partitionId:int = -1):
'''
Map a node id inside partition `partitionId` to its global id.
'''
if partitionId == -1:
partitionId = self.partition_id
assert id >=self.partptr[self.partition_id] and id < self.partptr[self.partition_id+1]
ids_before = 0
if self.partition_id>0:
ids_before = self.partptr[self.partition_id-1]
return id+ids_before
def get_partitionId_by_globalId(self,id):
'''
Get the partition index corresponding to a global node id.
'''
partitionId = -1
assert id>=0 and id<self.num_nodes,'id out of range'
for i in range(self.partitions):
if id>=self.partptr[i] and id<self.partptr[i+1]:
partitionId = i
break
assert partitionId>=0, 'no partition found for this id'
return partitionId
def get_nodes_by_partitionId(self,id):
'''
Return the number of nodes in the partition with index `id`.
'''
assert id>=0 and id<self.partitions,'invalid partitionId'
return (int)(self.partptr[id+1]-self.partptr[id])
def __repr__(self):
return (f'{self.__class__.__name__}(\n'
f' partition_id={self.partition_id}\n'
f' data={self.data},\n'
f' global_info('
f'num_nodes={self.num_nodes},'
f' num_edges={self.num_edges},'
f' num_parts={self.partitions},'
f' edge_index=[2,{self.edge_index[0].numel()}])\n'
f')')
import torch
from torch import Tensor
from enum import Enum
import math
from abc import ABC
from typing import Any, List, Optional, Tuple, Union
import numpy as np
class SampleType(Enum):
Whole = 0
Inner = 1
Outer =2
class NegativeSamplingMode(Enum):
# 'binary': Randomly sample negative edges in the graph.
binary = 'binary'
# 'triplet': Randomly sample negative destination nodes for each positive
# source node.
triplet = 'triplet'
class NegativeSampling:
r"""The negative sampling configuration of a
:class:`~torch_geometric.sampler.BaseSampler` when calling
:meth:`~torch_geometric.sampler.BaseSampler.sample_from_edges`.
Args:
mode (str): The negative sampling mode
(:obj:`"binary"` or :obj:`"triplet"`).
If set to :obj:`"binary"`, will randomly sample negative links
from the graph.
If set to :obj:`"triplet"`, will randomly sample negative
destination nodes for each positive source node.
amount (int or float, optional): The ratio of sampled negative edges to
the number of positive edges. (default: :obj:`1`)
weight (torch.Tensor, optional): A node-level vector determining the
sampling of nodes. Does not necessarily need to sum to one.
If not given, negative nodes will be sampled uniformly.
(default: :obj:`None`)
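A minimal usage sketch (``num_pos`` and ``num_nodes`` are placeholder integers); triplet mode draws :obj:`amount` negative destination nodes per positive source node:
.. code-block:: python
neg = NegativeSampling(mode='triplet', amount=1)
neg_dst = neg.sample(num_samples=num_pos, num_nodes=num_nodes)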
"""
mode: NegativeSamplingMode
amount: Union[int, float] = 1
weight: Optional[Tensor] = None
unique: bool
def __init__(
self,
mode: Union[NegativeSamplingMode, str],
amount: Union[int, float] = 1,
weight: Optional[Tensor] = None,
unique: bool = False
):
self.mode = NegativeSamplingMode(mode)
self.amount = amount
self.weight = weight
self.unique = unique
if self.amount <= 0:
raise ValueError(f"The attribute 'amount' needs to be positive "
f"for '{self.__class__.__name__}' "
f"(got {self.amount})")
if self.is_triplet():
if self.amount != math.ceil(self.amount):
raise ValueError(f"The attribute 'amount' needs to be an "
f"integer for '{self.__class__.__name__}' "
f"with 'triplet' negative sampling "
f"(got {self.amount}).")
self.amount = math.ceil(self.amount)
def is_binary(self) -> bool:
return self.mode == NegativeSamplingMode.binary
def is_triplet(self) -> bool:
return self.mode == NegativeSamplingMode.triplet
def sample(self, num_samples: int,
num_nodes: Optional[int] = None) -> Tensor:
r"""Generates :obj:`num_samples` negative samples."""
if self.weight is None:
if num_nodes is None:
raise ValueError(
f"Cannot sample negatives in '{self.__class__.__name__}' "
f"without passing the 'num_nodes' argument")
return torch.randint(num_nodes, (num_samples, ))
#return torch.from_numpy(np.random.randint(num_nodes, size=num_samples))
if num_nodes is not None and self.weight.numel() != num_nodes:
raise ValueError(
f"The 'weight' attribute in '{self.__class__.__name__}' "
f"needs to match the number of nodes {num_nodes} "
f"(got {self.weight.numel()})")
return torch.multinomial(self.weight, num_samples, replacement=True)
class SampleOutput:
node: Optional[torch.Tensor] = None
edge_index_list: Optional[List[torch.Tensor]] = None
eid_list: Optional[List[torch.Tensor]] = None
delta_ts_list: Optional[List[torch.Tensor]] = None
metadata: Optional[Any] = None
class BaseSampler(ABC):
r"""An abstract base class that initializes a graph sampler and provides
:meth:`_sample_one_layer_from_nodes`
:meth:`_sample_one_layer_from_nodes_parallel`
:meth:`sample_from_nodes` routines.
"""
def sample_from_nodes(
self,
nodes: torch.Tensor,
with_outer_sample: SampleType,
**kwargs
) -> Tuple[torch.Tensor, list]:
r"""Performs mutilayer sampling from the nodes specified in: nodes
The specific number of layers is determined by parameter: num_layers
returning a sampled subgraph in the specified output format: Tuple[torch.Tensor, list].
Args:
nodes: the list of seed nodes index
with_outer_sample: 0-sample in whole graph structure; 1-sample onehop outer nodel; 2-cross partition sampling
**kwargs: other kwargs
Returns:
sampled_nodes: the nodes sampled
sampled_edge_index_list: the edges sampled
"""
raise NotImplementedError
def sample_from_edges(
self,
edges: torch.Tensor,
with_outer_sample: SampleType,
edge_label: Optional[torch.Tensor] = None,
neg_sampling: Optional[NegativeSampling] = None
) -> Tuple[torch.Tensor, list]:
r"""Performs sampling from the edges specified in :obj:`index`,
returning a sampled subgraph in the specified output format.
Args:
edges: the list of seed edge indices
with_outer_sample: 0-sample in the whole graph structure; 1-sample one-hop outer nodes; 2-cross-partition sampling
edge_label: the label for the seed edges.
neg_sampling: The negative sampling configuration
Returns:
sampled_nodes: the nodes sampled
sampled_edge_index_list: the edges sampled
metadata: other information
"""
raise NotImplementedError
# def _sample_one_layer_from_nodes(
# self,
# nodes:torch.Tensor,
# **kwargs
# ) -> Tuple[torch.Tensor, torch.Tensor]:
# r"""Performs sampling from the nodes specified in: nodes,
# returning a sampled subgraph in the specified output format: Tuple[torch.Tensor, torch.Tensor].
# Args:
# nodes: the list of seed nodes index
# **kwargs: other kwargs
# Returns:
# sampled_nodes: the nodes sampled
# sampled_edge_index: the edges sampled
# """
# raise NotImplementedError
# def _sample_one_layer_from_nodes_parallel(
# self,
# nodes: torch.Tensor,
# **kwargs
# ) -> Tuple[torch.Tensor, torch.Tensor]:
# r"""Performs sampling paralleled from the nodes specified in: nodes,
# returning a sampled subgraph in the specified output format: Tuple[torch.Tensor, torch.Tensor].
# Args:
# nodes: the list of seed nodes index
# **kwargs: other kwargs
# Returns:
# sampled_nodes: the nodes sampled
# sampled_edge_index: the edges sampled
# """
# raise NotImplementedError
import torch
import time
from Utils import GraphData
seed = 10 # any integer can be used as the seed
torch.manual_seed(seed)
num_nodes1 = 10
fanout1 = [2]
edge_index1 = torch.tensor([[1, 5, 7, 9, 2, 4, 6, 7, 8, 0, 1, 6, 2, 0, 1, 3, 5, 8, 9, 7, 4, 8, 2, 3, 5, 8],
[0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 3, 3, 4, 4, 4, 4, 4, 5, 6, 6, 7, 7, 8, 9]])
edge_ts = torch.tensor([1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 4, 4, 4, 4, 4, 5, 5, 5, 6, 6, 6, 6]).double()
edge_weight1 = torch.tensor([2, 1, 2, 1, 8, 6, 3, 1, 1, 1, 1, 5, 1, 1, 2, 1, 1, 1, 1, 5, 1, 2, 2, 2, 1, 1]).double()
edge_weight1 = None
g_data = GraphData(id=0, edge_index=edge_index1, timestamp=edge_ts, data=None, partptr=torch.tensor([0, num_nodes1]))
from neighbor_sampler import NeighborSampler, SampleType
pre = time.time()
# from neighbor_sampler import get_neighbors, update_edge_weight
# row, col = edge_index1
# tnb = get_neighbors(row.contiguous(), col.contiguous(), num_nodes1, edge_weight1)
# print("tnb.neighbors:", tnb.neighbors)
# print("tnb.deg:", tnb.deg)
# print("tnb.weight:", tnb.edge_weight)
sampler = NeighborSampler(num_nodes1,
num_layers=1,
fanout=fanout1,
edge_weight=edge_weight1,
graph_data=g_data,
workers=2,
graph_name='a',
is_distinct = 0,
policy="recent")
end = time.time()
print("init time:", end-pre)
print("tnb.neighbors:", sampler.tnb.neighbors)
print("tnb.deg:", sampler.tnb.deg)
print("tnb.ts:", sampler.tnb.timestamp)
print("tnb.weight:", sampler.tnb.edge_weight)
# update_edge_row = row
# update_edge_col = col
# update_edge_w = torch.DoubleTensor([i for i in range(edge_weight1.size(0))])
# print('tnb.edge_weight:', tnb.edge_weight)
# print('begin update')
# pre = time.time()
# update_edge_weight(tnb, update_edge_row.contiguous(), update_edge_col.contiguous(), update_edge_w.contiguous())
# end = time.time()
# print("update time:", end-pre)
# print('update_edge_row:', update_edge_row)
# print('update_edge_col:', update_edge_col)
# print('tnb.edge_weight:', tnb.edge_weight)
pre = time.time()
out = sampler.sample_from_nodes(torch.tensor([6,7]),
with_outer_sample=SampleType.Whole,
ts=torch.tensor([9, 9]))
end = time.time()
# print('node:', out.node)
# print('edge_index_list:', out.edge_index_list)
# print('eid_list:', out.eid_list)
# print('eid_ts_list:', out.eid_ts_list)
print("sample time:", end-pre)
print("tot_time", out[0].tot_time)
print("sam_time", out[0].sample_time)
print("sam_edge", out[0].sample_edge_num)
print('eid_list:', out[0].eid)
print('delta_ts_list:', out[0].delta_ts)
print((out[0].sample_nodes<10000).sum())
print('node:', out[0].sample_nodes)
print('node_ts:', out[0].sample_nodes_ts)
\ No newline at end of file
import torch
import time
from .Utils import GraphData
def test():
seed = 10 # any integer can be used as the seed
torch.manual_seed(seed)
num_nodes1 = 10
fanout1 = [2,2] # index 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25
edge_index1 = torch.tensor([[1, 5, 7, 9, 2, 4, 6, 7, 8, 0, 1, 6, 2, 0, 1, 3, 5, 8, 9, 7, 4, 8, 2, 3, 5, 8],
[0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 3, 3, 4, 4, 4, 4, 4, 5, 6, 6, 7, 7, 8, 9]])
edge_weight1 = torch.tensor([2, 1, 2, 1, 8, 6, 3, 1, 1, 1, 1, 5, 1, 1, 2, 1, 1, 1, 1, 5, 1, 2, 2, 2, 1, 1]).double()
src,dst = edge_index1
row = torch.cat([src, dst])
col = torch.cat([dst, src])
edge_index1 = torch.stack([row, col])
g_data = GraphData(id=0, edge_index=edge_index1, data=None, partptr=torch.tensor([0, num_nodes1]))
edge_weight1 = None
# g_data.eid=None
from .neighbor_sampler import NeighborSampler, SampleType
pre = time.time()
sampler = NeighborSampler(num_nodes1,
num_layers=2,
fanout=fanout1,
edge_weight=edge_weight1,
graph_data=g_data,
workers=2,
graph_name='a',
policy="uniform")
end = time.time()
print("init time:", end-pre)
print("tnb.neighbors:", sampler.tnb.neighbors)
print("tnb.eid:", sampler.tnb.eid)
print("tnb.deg:", sampler.tnb.deg)
print("tnb.weight:", sampler.tnb.edge_weight)
# row,col = edge_index1
# update_edge_row = row
# update_edge_col = col
# update_edge_w = torch.FloatTensor([i for i in range(edge_weight1.size(0))])
# print('tnb.edge_weight:', sampler.tnb.edge_weight)
# print('begin update')
# pre = time.time()
# sampler.tnb.update_edge_weight(sampler.tnb, update_edge_row.contiguous(), update_edge_col.contiguous(), update_edge_w.contiguous())
# end = time.time()
# print("update time:", end-pre)
# print('update_edge_row:', update_edge_row)
# print('update_edge_col:', update_edge_col)
# print('tnb.edge_weight:', sampler.tnb.edge_weight)
pre = time.time()
out = sampler.sample_from_nodes(torch.tensor([1,2]),
with_outer_sample=SampleType.Whole)# sampler.sample_from_nodes(torch.masked_select(torch.arange(g.num_nodes),node_data['train_mask']))
end = time.time()
print('node1:\t', out[0].sample_nodes().tolist())
print('eid1:\t', out[0].eid().tolist())
print('edge1:\t', edge_index1[:, out[0].eid()].tolist())
print('node2:\t', out[1].sample_nodes().tolist())
print('eid2:\t', out[1].eid().tolist())
print('edge2:\t', edge_index1[:, out[1].eid()].tolist())
print("sample time:", end-pre)
if __name__ == "__main__":
test()
\ No newline at end of file
import starrygl
import sys
from os.path import abspath, join, dirname
sys.path.insert(0, join(abspath(dirname(__file__))))
import math
import torch
import torch.multiprocessing as mp
from typing import Optional, Tuple
from .base import BaseSampler, NegativeSampling, SampleOutput, SampleType
# from sample_cores import ParallelSampler, get_neighbors, heads_unique
from torch.distributed.rpc import rpc_async
class NeighborSampler(BaseSampler):
r'''
Parallel sampling is crucial for scaling model training to large amounts of data. Due to the large scale and complexity of graph data, traditional serial sampling may waste significant computing and storage resources. Parallel sampling improves sampling efficiency and overall computational speed by sampling from multiple nodes or neighbors simultaneously.
This helps to accelerate the training and inference process of the model, making it more scalable and practical when dealing with large-scale graph data.
Our parallel sampling adopts a hybrid CPU/GPU approach: the entire graph structure is stored on the CPU, the graph structure is sampled on the CPU, and the result is then uploaded to the GPU. Each trainer has a separate sampler for parallel training.
We have encapsulated the functions for parallel sampling, and you can easily use them in the following ways:
.. code-block:: python
# First, you need to import the Python packages
from starrygl.sample.sample_core.neighbor_sampler import NeighborSampler
# Then, you can use our parallel sampler
sampler = NeighborSampler(num_nodes=num_nodes, num_layers=num_layers, fanout=fanout, graph_data=graph_data,
workers=workers, is_distinct = is_distinct, policy = policy, edge_weight= edge_weight, graph_name = graph_name)
Args:
num_nodes (int): the num of all nodes in the graph
num_layers (int): the num of layers to be sampled
fanout (list): the list of max neighbors' number chosen for each layer
graph_data (:class: starrygl.sample.sample_core.neighbor_sampler): the graph data you want to sample
workers (int): the number of threads, default value is 1
is_distinct (bool): 1 - deduplicate multi-edges, 0 - keep multi-edges
policy (str): "uniform" or "recent" or "weighted"
edge_weight (torch.Tensor, Optional): the initial weights of the edges
graph_name (str): the name of the graph; graph_data should provide edge_index or (neighbors, deg)
Examples:
.. code-block:: python
from starrygl.sample.part_utils.partition_tgnn import partition_load
from starrygl.sample.graph_core import DataSet, DistributedGraphStore, TemporalNeighborSampleGraph
from starrygl.sample.sample_core.neighbor_sampler import NeighborSampler
pdata = partition_load("PATH/{}".format(dataname), algo="metis_for_tgnn")
graph = DistributedGraphStore(pdata = pdata,uvm_edge = False,uvm_node = False)
sample_graph = TemporalNeighborSampleGraph(sample_graph = pdata.sample_graph,mode = 'full')
sampler = NeighborSampler(num_nodes=graph.num_nodes, num_layers=1, fanout=[10],
graph_data=sample_graph, workers=15, policy = 'recent', graph_name = "wiki_train")
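A sampling call on a batch of seed nodes then looks like the following sketch (``seed_nodes`` and ``seed_ts`` are placeholder tensors holding node ids and their query timestamps):
.. code-block:: python
out = sampler.sample_from_nodes(seed_nodes, ts=seed_ts)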
If you want to directly call parallel sampling functions, use the following methods:
.. code-block:: python
# the parameter meaning is the same as the `Args` above
from starrygl.lib.libstarrygl_sampler import ParallelSampler, get_neighbors
# get neighbor infomation table,row and col come from graph_data.edge_index=(row, col)
tnb = get_neighbors(graph_name, row.contiguous(), col.contiguous(), num_nodes, is_distinct, graph_data. eid, edge_weight, timestamp)
# call parallel sampler
p_sampler = ParallelSampler(self.tnb, num_nodes, graph_data.num_edges, workers, fanout, num_layers, policy)
For complete usage and more details, please refer to `~starrygl.sample.sample_core.neighbor_sampler`
'''
def __init__(
self,
num_nodes: int,
num_layers: int,
fanout: list,
graph_data,
workers = 1,
tnb = None,
is_distinct = 0,
policy = "uniform",
edge_weight: Optional[torch.Tensor] = None,
graph_name = None
) -> None:
r"""__init__
Args:
num_nodes: the num of all nodes in the graph
num_layers: the num of layers to be sampled
fanout: the list of max neighbors' number chosen for each layer
workers: the number of threads, default value is 1
tnb: the neighbor information table
is_distinct: 1 - deduplicate multi-edges, 0 - keep multi-edges
policy: "uniform" or "recent" or "weighted"
edge_weight: the initial weights of edges
graph_name: the name of the graph;
graph_data should provide edge_index or (neighbors, deg)
"""
super().__init__()
self.num_layers = num_layers
# the number of worker threads should not exceed torch's default number of OMP threads
self.workers = workers # min(workers, torch.get_num_threads())
self.fanout = fanout
self.num_nodes = num_nodes
self.graph_data=graph_data
self.policy = policy
self.is_distinct = is_distinct
assert graph_name is not None
self.graph_name = graph_name
if(tnb is None):
if(graph_data.edge_ts is not None):
timestamp,ind = graph_data.edge_ts.sort()
timestamp = timestamp.float().contiguous()
eid = graph_data.eid[ind].contiguous()
row, col = graph_data.edge_index[:,ind]
else:
eid = graph_data.eid
timestamp = None
row, col = graph_data.edge_index
if(edge_weight is not None):
edge_weight = edge_weight.float().contiguous()
self.tnb = starrygl.sampler_ops.get_neighbors(graph_name, row.contiguous(), col.contiguous(), num_nodes, is_distinct, eid, edge_weight, timestamp)
else:
assert tnb is not None
self.tnb = tnb
self.p_sampler = starrygl.sampler_ops.ParallelSampler(self.tnb, num_nodes, graph_data.num_edges, workers,
fanout, num_layers, policy)
def _get_sample_info(self):
return self.num_nodes,self.num_layers,self.fanout,self.workers
def _get_sample_options(self):
return {"is_distinct" : self.is_distinct,
"policy" : self.policy,
"with_eid" : self.tnb.with_eid,
"weighted" : self.tnb.weighted,
"with_timestamp" : self.tnb.with_timestamp}
def insert_edges_with_timestamp(
self,
edge_index : torch.Tensor,
eid : torch.Tensor,
timestamp : torch.Tensor,
edge_weight : Optional[torch.Tensor] = None):
row, col = edge_index
# update the node count and the neighbor table (tnb)
self.num_nodes = self.tnb.update_neighbors_with_time(
row.contiguous(),
col.contiguous(),
timestamp.contiguous(),
eid.contiguous(),
self.is_distinct,
edge_weight.contiguous())
def update_edges_weight(
self,
edge_index : torch.Tensor,
eid : torch.Tensor,
edge_weight : Optional[torch.Tensor] = None):
row, col = edge_index
# update the edge weight information in tnb
if self.tnb.with_eid:
self.tnb.update_edge_weight(
eid.contiguous(),
col.contiguous(),
edge_weight.contiguous()
)
else:
self.tnb.update_edge_weight(
row.contiguous(),
col.contiguous(),
edge_weight.contiguous()
)
def update_nodes_weight(
self,
nid : torch.Tensor,
node_weight : Optional[torch.Tensor] = None):
# update the node weight information in tnb
self.tnb.update_node_weight(
nid.contiguous(),
node_weight.contiguous()
)
def update_all_node_weight(
self,
node_weight : torch.Tensor):
# update all node weight information in tnb
self.tnb.update_all_node_weight(node_weight.contiguous())
def sample_from_nodes(
self,
nodes: torch.Tensor,
ts: Optional[torch.Tensor] = None,
with_outer_sample: SampleType = SampleType.Whole
) -> SampleOutput:
r"""Performs mutilayer sampling from the nodes specified in: nodes
The specific number of layers is determined by parameter: num_layers
returning a sampled subgraph in the specified output format: Tuple[torch.Tensor, list].
Args:
nodes: the list of seed nodes index,
ts: the timestamp of nodes, optional,
with_outer_sample: 0-sample in whole graph structure; 1-sample onehop outer nodel; 2-cross partition sampling
fanout_index: optional. Specify the index to fanout
Returns:
sampled_nodes: the node sampled
sampled_edge_index_list: the edge sampled
"""
if(ts is None):
self.part_unique = True
self.p_sampler.neighbor_sample_from_nodes(nodes.contiguous(), None, self.part_unique)
ret = self.p_sampler.get_ret()
return ret
else:
self.p_sampler.neighbor_sample_from_nodes(nodes.contiguous(), ts.float().contiguous(), None)
ret = self.p_sampler.get_ret()
return ret
def sample_from_edges(
self,
edges: torch.Tensor,
ets: Optional[torch.Tensor] = None,
neg_sampling: Optional[NegativeSampling] = None,
with_outer_sample: SampleType = SampleType.Whole
) -> SampleOutput:
r"""Performs sampling from the edges specified in :obj:`index`,
returning a sampled subgraph in the specified output format.
Args:
edges: the list of seed edge indices
with_outer_sample: 0-sample in the whole graph structure; 1-sample one-hop outer nodes; 2-cross-partition sampling
ets: the timestamps of the edges, optional
neg_sampling: The negative sampling configuration
Returns:
sampled_edge_index_list: the edges sampled
sampled_eid_list: the ids of the sampled edges
sampled_delta_ts_list: the delta times of the sampled edges
metadata: other information
"""
src, dst = edges
num_pos = src.numel()
num_neg = 0
with_timestap = ets is not None
seed_ts = None
if neg_sampling is not None:
num_neg = math.ceil(num_pos * neg_sampling.amount)
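# binary mode resamples both endpoints of each negative edge; triplet mode only resamples the source side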
if neg_sampling.is_binary():
src_neg = neg_sampling.sample(num_neg, self.num_nodes)
dst_neg = neg_sampling.sample(num_neg, self.num_nodes)
seed = torch.cat([src, dst, src_neg, dst_neg], dim=0)
if with_timestap: # handle timestamps
seed_ts = torch.cat([ets, ets, ets, ets], dim=0)
if neg_sampling.is_triplet():
src_neg = neg_sampling.sample(num_neg, self.num_nodes)
seed = torch.cat([src, dst, src_neg], dim=0)
if with_timestap: # handle timestamps
seed_ts = torch.cat([ets, ets, ets], dim=0)
#if neg_sampling.is_evaluate():
#src,dst = neg_sampling.sample(num_samples=)
else:
seed = torch.cat([src, dst], dim=0)
if with_timestap: # handle timestamps
seed_ts = torch.cat([ets, ets], dim=0)
# deduplicate the seeds (negative samples included)
if neg_sampling is not None and neg_sampling.unique:
if with_timestap: # handle timestamps
pair, inverse_seed= torch.unique(torch.stack([seed, seed_ts],0), return_inverse=True, dim=1)
seed, seed_ts = pair
seed = seed.long()
else:
seed, inverse_seed = seed.unique(return_inverse=True)
out = self.sample_from_nodes(seed, seed_ts, with_outer_sample)
if neg_sampling is None or (not neg_sampling.unique):
if with_timestap:
return out, {'seed':seed,'seed_ts':seed_ts,
'src_pos_index':torch.arange(0,num_pos),
'dst_pos_index':torch.arange(num_pos,2*num_pos),
'src_neg_index':torch.arange(2*num_pos,3*num_pos)}
else:
return out, {'seed':seed,
'src_pos_index':slice(0,num_pos),
'dst_pos_index':slice(num_pos,2*num_pos),
'src_neg_index':slice(2*num_pos,3*num_pos)}
metadata = {}
if neg_sampling.is_binary():
src_pos_index = inverse_seed[:num_pos]
dst_pos_index = inverse_seed[num_pos:2 * num_pos]
src_neg_index = inverse_seed[2 * num_pos:3 * num_pos]
src_neg_index = src_neg_index.view(num_pos, -1).squeeze(-1)
dst_neg_index = inverse_seed[3 * num_pos:]
dst_neg_index = dst_neg_index.view(num_pos, -1).squeeze(-1)
metadata = {'seed':seed, 'src_neg_index':src_neg_index, 'dst_pos_index':dst_pos_index, 'dst_neg_index':dst_neg_index}
if with_timestap:
metadata['seed_ts'] = seed_ts
elif neg_sampling.is_triplet():
src_pos_index = inverse_seed[:num_pos]
dst_pos_index = inverse_seed[num_pos:2 * num_pos]
src_neg_index = inverse_seed[2 * num_pos:]
src_neg_index = src_neg_index.view(num_pos, -1).squeeze(-1)
# src_pos_index is the index of the src nodes within seed
# dst_pos_index is the index of the dst_pos nodes within seed
# src_neg_index is the index of the src_neg nodes within seed
metadata = {'seed':seed, 'src_pos_index':src_pos_index, 'src_neg_index':src_neg_index, 'dst_pos_index':dst_pos_index}
if with_timestap:
metadata['seed_ts'] = seed_ts
# the front of sampled_nodes holds the sampling start points of the original sequence, i.e. the deduplicated seed
return out, metadata
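# A hedged usage sketch (not part of the original API): how the metadata returned by
# sample_from_edges is typically consumed downstream in the triplet case. `node_embed`
# is a hypothetical tensor of per-node embeddings whose rows follow the order of the
# sampled nodes (the deduplicated seed comes first, as noted above); the index names
# are the keys of the triplet metadata dict.
def split_triplet_embeddings(node_embed, metadata):
    # gather anchor / positive / negative rows using the index tensors in metadata
    src_pos = node_embed[metadata['src_pos_index']]   # anchor sources
    dst_pos = node_embed[metadata['dst_pos_index']]   # positive destinations
    src_neg = node_embed[metadata['src_neg_index']]   # sampled negative sources
    return src_pos, dst_pos, src_neg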
if __name__=="__main__":
# edge_index1 = torch.tensor([[0, 1, 1, 1, 2, 2, 2, 4, 4, 4, 5], # , 3, 3
# [1, 0, 2, 4, 1, 3, 0, 3, 5, 0, 2]])# , 2, 5
edge_index1 = torch.tensor([[0, 1, 1, 1, 1, 2, 2, 2, 2, 4, 4, 4, 5], # , 3, 3
[1, 0, 2, 0, 4, 1, 3, 0, 3, 3, 5, 0, 2]])# , 2, 5
edge_weight1 = None
timeStamp=torch.FloatTensor([1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4])
num_nodes1 = 6
num_neighbors = 2
# Run the neighbor sampling
from Utils import GraphData
g_data = GraphData(id=0, edge_index=edge_index1, timestamp=timeStamp, data=None, partptr=torch.tensor([0, num_nodes1]))
sampler = NeighborSampler(num_nodes=num_nodes1,
num_layers=3,
fanout=[2, 1, 1],
edge_weight=edge_weight1,
graph_data=g_data,
graph_name='a',
workers=4,
is_distinct = 0)
out = sampler.sample_from_nodes(torch.tensor([1,2]),
ts=torch.tensor([1, 2]),
with_outer_sample=SampleType.Whole)
# out = sampler.sample_from_edges(torch.tensor([[1,2],[4,0]]),
# with_outer_sample=SampleType.Whole,
# ets = torch.tensor([1, 2]))
# Print the result
print('node:', out.node)
print('edge_index_list:', out.edge_index_list)
print('eid_list:', out.eid_list)
print('delta_ts_list:', out.delta_ts_list)
print('metadata: ', out.metadata)
import torch
from ogb.nodeproppred import PygNodePropPredDataset
from torch_geometric import datasets
import time
from Utils import GraphData
def load_ogb_dataset(name, data_path):
dataset = PygNodePropPredDataset(name=name, root=data_path)
split_idx = dataset.get_idx_split()
g = dataset[0]
n_node = g.num_nodes
node_data={}
node_data['train_mask'] = torch.zeros(n_node, dtype=torch.bool)
node_data['val_mask'] = torch.zeros(n_node, dtype=torch.bool)
node_data['test_mask'] = torch.zeros(n_node, dtype=torch.bool)
node_data['train_mask'][split_idx["train"]] = True
node_data['val_mask'][split_idx["valid"]] = True
node_data['test_mask'][split_idx["test"]] = True
return g, node_data
g, node_data = load_ogb_dataset('ogbn-products', "/home/zlj/hzq/code/gnn/dataset/")
print(g)
# for worker in [1,2,3,4,5,6,7,8,9,10,20,30]:
# import random
# timestamp = [random.randint(1, 5) for i in range(0, g.num_edges)]
# timestamp = torch.FloatTensor(timestamp)
print('begin load')
pre = time.time()
timestamp = torch.load('/home/zlj/hzq/code/gnn/my_sampler/TemporalSample/timestamp.my')
tnb = torch.load("tnb_before.my")
end = time.time()
print("load time:", end-pre)
row, col = g.edge_index
edge_weight=None
g_data = GraphData(id=1, edge_index=g.edge_index, timestamp=timestamp, data=g, partptr=torch.tensor([0, g.num_nodes//4, g.num_nodes//4*2, g.num_nodes//4*3, g.num_nodes]))
from neighbor_sampler import NeighborSampler, SampleType, get_neighbors
# print('begin tnb')
# pre = time.time()
# tnb = get_neighbors(row.contiguous(), col.contiguous(), g.num_nodes, 0, g_data.eid, edge_weight, timestamp)
# end = time.time()
# print("init tnb time:", end-pre)
# torch.save(tnb, "tnb_before.my")
pre = time.time()
sampler = NeighborSampler(g.num_nodes,
tnb=tnb,
num_layers=2,
fanout=[100,100],
graph_data=g_data,
workers=10,
policy="uniform",
is_root_ts=0,
graph_name='a')
end = time.time()
print("init time:", end-pre)
# from torch_geometric.sampler import NeighborSampler, NumNeighbors, NodeSamplerInput, SamplerOutput
# pre = time.time()
# num_nei = NumNeighbors([100, 100])
# node_idx = NodeSamplerInput(input_id=None, node=torch.tensor(range(g.num_nodes//4, g.num_nodes//4+600000)))# (input_id=None, node=torch.masked_select(torch.arange(g.num_nodes),node_data['train_mask']))
# sampler = NeighborSampler(g, num_nei)
# end = time.time()
# print("init time:", end-pre)
ts = torch.tensor([i%5+1 for i in range(0, 600000)])
pre = time.time()
out = sampler.sample_from_nodes(torch.tensor(range(g.num_nodes//4, g.num_nodes//4+600000)),
ts=ts,
with_outer_sample=SampleType.Inner)# sampler.sample_from_nodes(torch.masked_select(torch.arange(g.num_nodes),node_data['train_mask']))
# out = sampler.sample_from_nodes(node_idx)
# node = out.node
# edge = [out.row, out.col]
end = time.time()
print('node:', out.node)
print('edge_index_list:', out.edge_index_list)
print('eid_list:', out.eid_list)
print('eid_ts_list:', out.eid_ts_list)
print("sample time", end-pre)
\ No newline at end of file
import torch
from ogb.nodeproppred import PygNodePropPredDataset
from torch_geometric import datasets
import time
from .Utils import GraphData
def load_ogb_dataset(name, data_path):
dataset = PygNodePropPredDataset(name=name, root=data_path)
split_idx = dataset.get_idx_split()
g = dataset[0]
n_node = g.num_nodes
node_data={}
node_data['train_mask'] = torch.zeros(n_node, dtype=torch.bool)
node_data['val_mask'] = torch.zeros(n_node, dtype=torch.bool)
node_data['test_mask'] = torch.zeros(n_node, dtype=torch.bool)
node_data['train_mask'][split_idx["train"]] = True
node_data['val_mask'][split_idx["valid"]] = True
node_data['test_mask'][split_idx["test"]] = True
return g, node_data
def test():
g, node_data = load_ogb_dataset('ogbn-products', "/home/zlj/hzq/code/gnn/dataset/")
print(g)
# for worker in [1,2,3,4,5,6,7,8,9,10,20,30]:
g_data = GraphData(id=1, edge_index=g.edge_index, data=g, partptr=torch.tensor([0, g.num_nodes//4, g.num_nodes//4*2, g.num_nodes//4*3, g.num_nodes]))
row, col = g.edge_index
# edge_weight = torch.ones(g.num_edges).float()
# indices = [x for x in range(0, g.num_edges, 5)]
# edge_weight[indices] = 2.0
# g_data.eid = None
edge_weight = None
timestamp = None
from .neighbor_sampler import NeighborSampler, SampleType
from .neighbor_sampler import get_neighbors
update_edge_row = row
update_edge_col = col
update_edge_w = torch.DoubleTensor([i for i in range(g.num_edges)])
# print('begin update')
# pre = time.time()
# # update_edge_weight(tnb, update_edge_row.contiguous(), update_edge_col.contiguous(), update_edge_w.contiguous())
# end = time.time()
# print("update time:", end-pre)
print('begin tnb')
pre = time.time()
tnb = get_neighbors("a",
row.contiguous(),
col.contiguous(),
g.num_nodes, 0,
g_data.eid,
edge_weight,
timestamp)
end = time.time()
print("init tnb time:", end-pre)
# torch.save(tnb, "/home/zlj/hzq/code/gnn/my_sampler/MergeSample/tnb_static.my")
# print('begin load')
# pre = time.time()
# tnb = torch.load("/home/zlj/hzq/code/gnn/my_sampler/MergeSample/tnb_static.my")
# end = time.time()
# print("load time:", end-pre)
print('begin init')
pre = time.time()
sampler = NeighborSampler(g.num_nodes,
tnb = tnb,
num_layers=2,
fanout=[100,100],
graph_data=g_data,
workers=10,
graph_name='a',
policy="uniform")
end = time.time()
print("init time:", end-pre)
# from torch_geometric.sampler import NeighborSampler, NumNeighbors, NodeSamplerInput, SamplerOutput
# pre = time.time()
# num_nei = NumNeighbors([100, 100])
# node_idx = NodeSamplerInput(input_id=None, node=torch.tensor(range(g.num_nodes//4, g.num_nodes//4+600000)))# (input_id=None, node=torch.masked_select(torch.arange(g.num_nodes),node_data['train_mask']))
# sampler = NeighborSampler(g, num_nei)
# end = time.time()
# print("init time:", end-pre)
pre = time.time()
out = sampler.sample_from_nodes(torch.tensor(range(g.num_nodes//4, g.num_nodes//4+600000)))# sampler.sample_from_nodes(torch.masked_select(torch.arange(g.num_nodes),node_data['train_mask']))
# out = sampler.sample_from_nodes(node_idx)
# node = out.node
# edge = [out.row, out.col]
end = time.time()
print('node1:\t', out[0].sample_nodes())
print('eid1:\t', out[0].eid())
print('edge1:\t', g.edge_index[:, out[0].eid()])
print('node2:\t', out[1].sample_nodes())
print('eid2:\t', out[1].eid())
print('edge2:\t', g.edge_index[:, out[1].eid()])
print("sample time", end-pre)
if __name__ == "__main__":
test()
\ No newline at end of file
import torch
import torch.multiprocessing as mp
from typing import Optional, Tuple
from base import BaseSampler, NegativeSampling, SampleOutput
from neighbor_sampler import NeighborSampler, SampleType
class RandomWalkSampler(BaseSampler):
def __init__(
self,
num_nodes: int,
num_layers: int,
graph_data,
workers = 1,
tnb = None,
is_distinct = 0,
policy = "uniform",
edge_weight: Optional[torch.Tensor] = None,
graph_name = None
) -> None:
r"""__init__
Args:
num_nodes: the num of all nodes in the graph
num_layers: the num of layers to be sampled
workers: the number of threads, default value is 1
tnb: neighbor information table
is_distinct: 1-need distinct, 0-don't need distinct
policy: "uniform" or "recent" or "weighted"
edge_weight: the initial weights of edges
graph_name: the name of graph
"""
super().__init__()
self.sampler = NeighborSampler(
num_nodes=num_nodes,
tnb=tnb,
num_layers=num_layers,
fanout=[1 for i in range(num_layers)],
graph_data=graph_data,
edge_weight = edge_weight,
workers=workers,
policy=policy,
graph_name=graph_name,
is_distinct = is_distinct
)
self.num_layers = num_layers
def _get_sample_info(self):
return self.num_nodes,self.num_layers,self.fanout,self.workers
def _get_sample_options(self):
return {"is_distinct" : self.is_distinct,
"policy" : self.policy,
"with_eid" : self.tnb.with_eid,
"weighted" : self.tnb.weighted,
"with_timestamp" : self.tnb.with_timestamp}
def insert_edges_with_timestamp(
self,
edge_index : torch.Tensor,
eid : torch.Tensor,
timestamp : torch.Tensor,
edge_weight : Optional[torch.Tensor] = None):
row, col = edge_index
# update the node count and the tnb (temporal neighbor table)
self.num_nodes = self.tnb.update_neighbors_with_time(
row.contiguous(),
col.contiguous(),
timestamp.contiguous(),
eid.contiguous(),
self.is_distinct,
edge_weight.contiguous())
def update_edges_weight(
self,
edge_index : torch.Tensor,
eid : torch.Tensor,
edge_weight : Optional[torch.Tensor] = None):
row, col = edge_index
# update the edge weight information in the tnb
if self.tnb.with_eid:
self.tnb.update_edge_weight(
eid.contiguous(),
col.contiguous(),
edge_weight.contiguous()
)
else:
self.tnb.update_edge_weight(
row.contiguous(),
col.contiguous(),
edge_weight.contiguous()
)
def update_nodes_weight(
self,
nid : torch.Tensor,
node_weight : Optional[torch.Tensor] = None):
# update the node weight information in the tnb
self.tnb.update_node_weight(
nid.contiguous(),
node_weight.contiguous()
)
def update_all_node_weight(
self,
node_weight : torch.Tensor):
# update the weights of all nodes in the tnb
self.tnb.update_all_node_weight(node_weight.contiguous())
def sample_from_nodes(
self,
nodes: torch.Tensor,
with_outer_sample: SampleType,
ts: Optional[torch.Tensor] = None
) -> SampleOutput:
r"""Performs mutilayer sampling from the nodes specified in: nodes
The specific number of layers is determined by parameter: num_layers
returning a sampled subgraph in the specified output format: Tuple[torch.Tensor, list].
Args:
nodes: the list of seed nodes index
with_outer_sample: 0-sample in whole graph structure; 1-sample onehop outer nodel; 2-cross partition sampling
Returns:
sampled_nodes: the node sampled
sampled_edge_index: the edge sampled
"""
return self.sampler.sample_from_nodes(nodes, ts, with_outer_sample)
def sample_from_edges(
self,
edges: torch.Tensor,
ets: Optional[torch.Tensor] = None,
neg_sampling: Optional[NegativeSampling] = None,
with_outer_sample: SampleType = SampleType.Whole
) -> SampleOutput:
r"""Performs sampling from the edges specified in :obj:`index`,
returning a sampled subgraph in the specified output format.
Args:
edges: the indices of the seed edges
with_outer_sample: 0-sample within the whole graph structure; 1-sample one-hop outer nodes; 2-cross-partition sampling
ets: the timestamps of the seed edges, optional
neg_sampling: the negative sampling configuration
Returns:
sampled_nodes: the sampled nodes
sampled_edge_index_list: the sampled edges
"""
return self.sampler.sample_from_edges(edges, ets, neg_sampling, with_outer_sample)
if __name__=="__main__":
edge_index1 = torch.tensor([[0, 1, 1, 1, 1, 2, 2, 2, 2, 4, 4, 4, 5], # , 3, 3
[1, 0, 2, 0, 4, 1, 3, 0, 3, 3, 5, 0, 2]])# , 2, 5
timeStamp=torch.FloatTensor([1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4])
edge_weight1 = None
num_nodes1 = 6
num_neighbors = 2
# Run the neighbor sampling
from Utils import GraphData
g_data = GraphData(id=0, edge_index=edge_index1, timestamp=timeStamp, data=None, partptr=torch.tensor([0, num_nodes1]))
# Run the random walk sampling
sampler=RandomWalkSampler(num_nodes=num_nodes1,
num_layers=3,
edge_weight=edge_weight1,
graph_data=g_data,
graph_name='a',
workers=4,
is_distinct = 0)
out = sampler.sample_from_nodes(torch.tensor([1,2]),
with_outer_sample=SampleType.Whole,
ts=torch.tensor([1, 2]))
# out = sampler.sample_from_edges(torch.tensor([[1,2],[4,0]]),
# with_outer_sample=SampleType.Whole,
# ets = torch.tensor([1, 2]))
# Print the result
print('node:', out.node)
print('edge_index_list:', out.edge_index_list)
print('eid_list:', out.eid_list)
print('eid_ts_list:', out.eid_ts_list)
print('metadata: ', out.metadata)
import argparse
import random
import pandas as pd
import numpy as np
import torch
import time
from tqdm import tqdm
from .Utils import GraphData
class NegLinkSampler:
def __init__(self, num_nodes):
self.num_nodes = num_nodes
def sample(self, n):
return np.random.randint(self.num_nodes, size=n)
class NegLinkInductiveSampler:
def __init__(self, nodes):
self.nodes = list(nodes)
def sample(self, n):
return np.random.choice(self.nodes, size=n)
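# A minimal, hedged usage sketch for the two negative link samplers above; the node
# count and batch size below are made-up illustration values.
def _neg_sampler_example(num_nodes=10000, batch=600):
    # uniform negatives drawn over the full node id range
    dst_neg = NegLinkSampler(num_nodes).sample(batch)
    # inductive variant: negatives restricted to an explicit node set
    dst_neg_ind = NegLinkInductiveSampler(range(num_nodes // 2)).sample(batch)
    return dst_neg, dst_neg_ind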
def load_reddit_dataset():
df = pd.read_csv('/mnt/data/hzq/DATA/{}/edges.csv'.format("REDDIT"))
num_nodes = max(int(df['src'].max()), int(df['dst'].max())) + 1
src = torch.tensor(df['src'].to_numpy(dtype=int))
dst = torch.tensor(df['dst'].to_numpy(dtype=int))
edge_index = torch.stack([src, dst])
timestamp = torch.tensor(df['time']).float()
g = GraphData(0, edge_index, timestamp=timestamp, data=None, partptr=torch.tensor([0, num_nodes]))
return g, df
def load_gdelt_dataset():
df = pd.read_csv('/mnt/data/hzq/DATA/{}/edges.csv'.format("GDELT"))
num_nodes = max(int(df['src'].max()), int(df['dst'].max())) + 1
src = torch.tensor(df['src'].to_numpy(dtype=int))
dst = torch.tensor(df['dst'].to_numpy(dtype=int))
edge_index = torch.stack([src, dst])
timestamp = torch.tensor(df['time']).float()
g = GraphData(0, edge_index, timestamp=timestamp, data=None, partptr=torch.tensor([0, num_nodes]))
return g, df
def test():
parser=argparse.ArgumentParser()
parser.add_argument('--data', type=str, help='dataset name',default="REDDIT")
parser.add_argument('--config', type=str, help='path to config file',default="/home/zlj/hzq/project/code/TGL/config/TGN.yml")
parser.add_argument('--batch_size', type=int, default=600, help='path to config file')
parser.add_argument('--num_thread', type=int, default=64, help='number of thread')
args=parser.parse_args()
dataset = "gdelt"#"reddit"#"gdelt"
seed=10
torch.manual_seed(seed) # set the random seed for the CPU
torch.cuda.manual_seed(seed) # set the random seed for the current GPU
torch.cuda.manual_seed_all(seed) # set the random seed for all GPUs when using multi-GPU
np.random.seed(seed) # Numpy module.
random.seed(seed) # Python random module.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
g_data, df = load_gdelt_dataset()
print(g_data)
# for worker in [1,2,3,4,5,6,7,8,9,10,20,30]:
# import random
# timestamp = [random.randint(1, 5) for i in range(0, g.num_edges)]
# timestamp = torch.FloatTensor(timestamp)
# print('begin load')
# pre = time.time()
# # timestamp = torch.load('/home/zlj/hzq/code/gnn/my_sampler/TemporalSample/timestamp.my')
# tnb = torch.load("tnb_reddit_before.my")
# end = time.time()
# print("load time:", end-pre)
# row, col = g.edge_index
edge_weight=None
# g_data = GraphData(id=1, edge_index=g.edge_index, timestamp=timestamp, data=g, partptr=torch.tensor([0, g.num_nodes//4, g.num_nodes//4*2, g.num_nodes//4*3, g.num_nodes]))
from .neighbor_sampler import NeighborSampler, SampleType, get_neighbors
print('begin tnb')
row, col = g_data.edge_index
row, col = torch.cat([row, col]), torch.cat([col, row])  # add reverse edges without clobbering row first
eid = torch.cat([g_data.eid, g_data.eid])
timestamp = torch.cat([g_data.edge_ts, g_data.edge_ts])
timestamp,ind = timestamp.sort()
timestamp = timestamp.float().contiguous()
eid = eid[ind].contiguous()
row = row[ind]
col = col[ind]
print(row, col)
g2 = GraphData(0, torch.stack([row, col]), timestamp=timestamp, data=None, partptr=torch.tensor([0, max(int(df['src'].max()), int(df['dst'].max())) + 1]))
print(g2)
pre = time.time()
tnb = get_neighbors(dataset, row.contiguous(), col.contiguous(), g_data.num_nodes, 0, eid, edge_weight, timestamp)
end = time.time()
print("init tnb time:", end-pre)
# torch.save(tnb, "tnb_{}_before.my".format(dataset), pickle_protocol=4)
pre = time.time()
sampler = NeighborSampler(g_data.num_nodes,
tnb=tnb,
num_layers=1,
fanout=[10],
graph_data=g_data,
workers=32,
policy="recent",
graph_name='a')
end = time.time()
print("init time:", end-pre)
# neg_link_sampler = NegLinkSampler(g_data.num_nodes)
from .base import NegativeSampling, NegativeSamplingMode
neg_link_sampler = NegativeSampling(NegativeSamplingMode.triplet)
# from torch_geometric.sampler import NeighborSampler, NumNeighbors, NodeSamplerInput, SamplerOutput
# pre = time.time()
# num_nei = NumNeighbors([100, 100])
# node_idx = NodeSamplerInput(input_id=None, node=torch.tensor(range(g.num_nodes//4, g.num_nodes//4+600000)))# (input_id=None, node=torch.masked_select(torch.arange(g.num_nodes),node_data['train_mask']))
# sampler = NeighborSampler(g, num_nei)
# end = time.time()
# print("init time:", end-pre)
out = []
tot_time = 0
sam_time = 0
sam_edge = 0
pre = time.time()
min_than_ten = 0
min_than_ten_sum = 0
seed_node_sum = 0
for _, rows in tqdm(df.groupby(df.index // args.batch_size), total=len(df) // args.batch_size):
# root_nodes = torch.tensor(np.concatenate([rows.src.values, rows.dst.values, neg_link_sampler.sample(len(rows))])).long()
# ts = torch.tensor(np.concatenate([rows.time.values, rows.time.values, rows.time.values]).astype(np.float32))
# outi = sampler.sample_from_nodes(root_nodes, ts=ts)
edges = torch.tensor(np.stack([rows.src.values, rows.dst.values])).long()
outi, meta = sampler.sample_from_edges(edges=edges, ets=torch.tensor(rows.time.values).float(), neg_sampling=neg_link_sampler)
# min_than_ten += (torch.tensor(tnb.deg)[meta['seed']]<10).sum()
# min_than_ten_sum += ((torch.tensor(tnb.deg)[meta['seed']])[torch.tensor(tnb.deg)[meta['seed']]<10]).sum()
# seed_node_sum += meta['seed'].size(0)
tot_time += outi[0].tot_time
sam_time += outi[0].sample_time
# print(outi[0].sample_edge_num)
sam_edge += outi[0].sample_edge_num
# out.append(outi)
end = time.time()
# print("row", out[23][0].row())
print("sample time", end-pre)
print("tot_time", tot_time)
print("sam_time", sam_time)
print("sam_edge", sam_edge)
# print('eid_list:', out[23][0].eid())
# print('delta_ts_list:', out[10][0].delta_ts)
# print('node:', out[23][0].sample_nodes())
# print('node_ts:', out[23][0].sample_nodes_ts)
# print('eid_list:', out[23][1].eid)
# print('node:', out[23][1].sample_nodes)
# print('node_ts:', out[23][1].sample_nodes_ts)
# print('edge_index_list:', out[0][0].edge_index)
# print("min_than_ten", min_than_ten)
# print("min_than_ten_sum", min_than_ten_sum)
# print("seed_node_sum", seed_node_sum)
# print("predict edge_num", (seed_node_sum-min_than_ten)*9+min_than_ten_sum)
print('throughput : {:.4f}'.format(sam_edge/(end-pre)))
if __name__ == "__main__":
test()
\ No newline at end of file
import argparse
import random
import pandas as pd
import numpy as np
import torch
import time
from tqdm import tqdm
from .Utils import GraphData
def load_reddit_dataset():
df = pd.read_csv('/mnt/data/hzq/DATA/{}/edges.csv'.format("REDDIT"))
num_nodes = max(int(df['src'].max()), int(df['dst'].max())) + 1
src = torch.tensor(df['src'].to_numpy(dtype=int))
dst = torch.tensor(df['dst'].to_numpy(dtype=int))
row = torch.cat([src, dst])
col = torch.cat([dst, src])
edge_index = torch.stack([row, col])
timestamp = torch.tensor(df['time']).float()
g = GraphData(0, edge_index, timestamp=None, data=None, partptr=torch.tensor([0, num_nodes]))
return g
def load_gdelt_dataset():
df = pd.read_csv('/mnt/data/hzq/DATA/{}/edges.csv'.format("GDELT"))
num_nodes = max(int(df['src'].max()), int(df['dst'].max())) + 1
src = torch.tensor(df['src'].to_numpy(dtype=int))
dst = torch.tensor(df['dst'].to_numpy(dtype=int))
row = torch.cat([src, dst])
col = torch.cat([dst, src])
edge_index = torch.stack([row, col])
timestamp = torch.tensor(df['time']).float()
g = GraphData(0, edge_index, timestamp=None, data=None, partptr=torch.tensor([0, num_nodes]))
return g
def load_ogb_dataset():
from ogb.nodeproppred import PygNodePropPredDataset
dataset = PygNodePropPredDataset(name='ogbn-products', root="/home/zlj/hzq/code/gnn/dataset/")
split_idx = dataset.get_idx_split()
g = dataset[0]
n_node = g.num_nodes
node_data={}
node_data['train_mask'] = torch.zeros(n_node, dtype=torch.bool)
node_data['val_mask'] = torch.zeros(n_node, dtype=torch.bool)
node_data['test_mask'] = torch.zeros(n_node, dtype=torch.bool)
node_data['train_mask'][split_idx["train"]] = True
node_data['val_mask'][split_idx["valid"]] = True
node_data['test_mask'][split_idx["test"]] = True
src, dst = g.edge_index
row = torch.cat([src, dst])
col = torch.cat([dst, src])
edge_index = torch.stack([row, col])
g = GraphData(id=0, edge_index=edge_index, data=g, partptr=torch.tensor([0, g.num_nodes]))
return g # , node_data
def test():
parser=argparse.ArgumentParser()
parser.add_argument('--data', type=str, help='dataset name',default="REDDIT")
# parser.add_argument('--config', type=str, help='path to config file',default="/home/zlj/hzq/project/code/TGL/config/TGN.yml")
parser.add_argument('--batch_size', type=int, default=600, help='path to config file')
# parser.add_argument('--num_thread', type=int, default=64, help='number of thread')
args=parser.parse_args()
seed=10
torch.manual_seed(seed) # set the random seed for the CPU
torch.cuda.manual_seed(seed) # set the random seed for the current GPU
torch.cuda.manual_seed_all(seed) # set the random seed for all GPUs when using multi-GPU
np.random.seed(seed) # Numpy module.
random.seed(seed) # Python random module.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
g_data = load_ogb_dataset()
print(g_data)
from .neighbor_sampler import NeighborSampler, get_neighbors
# print('begin tnb')
# row, col = g_data.edge_index
# pre = time.time()
# tnb = get_neighbors("a",
# row.contiguous(),
# col.contiguous(),
# g_data.num_nodes, 0,
# g_data.eid,
# None,
# None)
# end = time.time()
# print("init tnb time:", end-pre)
pre = time.time()
sampler = NeighborSampler(g_data.num_nodes,
num_layers=2,
fanout=[10,10],
graph_data=g_data,
workers=32,
policy="uniform",
graph_name='a',
is_distinct=0)
end = time.time()
print("init time:", end-pre)
print(sampler.tnb.deg[0:100])
n_list = []
for i in range(g_data.num_nodes.item() // args.batch_size):
if (i + 1) * args.batch_size < g_data.num_nodes.item():
n_list.append(range(i * args.batch_size, (i + 1) * args.batch_size))
else:
n_list.append(range(i * args.batch_size, g_data.num_nodes.item()))
# print(n_list)
out = []
tot_time = 0
sam_time = 0
sam_edge = 0
sam_node = 0
pre = time.time()
for i, nodes in tqdm(enumerate(n_list), total=g_data.num_nodes.item() // args.batch_size):
# for nodes in n_list:
root_nodes = torch.tensor(nodes).long()
outi = sampler.sample_from_nodes(root_nodes)
sam_node += outi[0].sample_nodes().size(0)
sam_node += outi[1].sample_nodes().size(0)
sam_edge += outi[0].sample_edge_num
end = time.time()
print("sample time", end-pre)
print("sam_edge", sam_edge)
print("sam_node", sam_node)
print('edge throughput : {:.4f}'.format(sam_edge/(end-pre)))
print('node throughput : {:.4f}'.format(sam_node/(end-pre)))
if __name__ == "__main__":
test()
\ No newline at end of file
from concurrent.futures import ThreadPoolExecutor
from multiprocessing import Process
from multiprocessing.pool import ThreadPool
from collections import deque
import torch
import torch.distributed as dist
stage = ['train_stream','write_memory','write_mail','lookup']
class PipelineManager:
def __init__(self,num_tasks = 10):
self.stream_set = {}
self.dist_set = {}
self.args_queue = {}
self.thread_pool = ThreadPoolExecutor(num_tasks)
for k in stage:
self.stream_set[k] = torch.cuda.Stream()
self.dist_set[k] = dist.new_group()
self.args_queue[k] = deque()
def submit(self,state,func,kwargs):
future = self.thread_pool.submit(self.run, state,func,kwargs)
return future
def run(self,state,func,kwargs):
with torch.cuda.stream(self.stream_set[state]):
return func(**kwargs,group = self.dist_set[state])
manger = None
def getPipelineManger():
global manger
if manger is None:
manger = PipelineManager()
return manger
\ No newline at end of file
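# A hedged usage sketch for the PipelineManager above. It assumes torch.distributed is
# already initialised (the constructor calls dist.new_group()) and that CUDA is
# available for the per-stage streams; `dummy_write_mail` is a made-up task function
# used only for illustration.
def _pipeline_example():
    def dummy_write_mail(payload, group=None):
        # stand-in for a real write_mail stage; `group` is injected by run()
        return payload
    mgr = getPipelineManger()
    fut = mgr.submit('write_mail', dummy_write_mail, {'payload': 42})
    return fut.result()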
......@@ -149,6 +149,7 @@ class TransfomerAttentionLayer(torch.nn.Module):
#Q = self.w_q(torch.cat([b.srcdata['h'][:b.num_dst_nodes()], zero_time_feat], dim=1))[b.edges()[1]]
#K = self.w_k(torch.cat([b.srcdata['h'][b.num_dst_nodes():], b.edata['f'], time_feat], dim=1))
#V = self.w_v(torch.cat([b.srcdata['h'][b.num_dst_nodes():], b.edata['f'], time_feat], dim=1))
Q = torch.reshape(Q, (Q.shape[0], self.num_head, -1))
K = torch.reshape(K, (K.shape[0], self.num_head, -1))
V = torch.reshape(V, (V.shape[0], self.num_head, -1))
......
......@@ -58,7 +58,6 @@ class GeneralModel(torch.nn.Module):
rst = self.layers['l' + str(l) + 'h' + str(h)](mfgs[l][h])
if 'time_transform' in self.gnn_param and self.gnn_param['time_transform'] == 'JODIE':
rst = self.layers['l0h' + str(h) + 't'](rst, mfgs[l][h].srcdata['mem_ts'], mfgs[l][h].srcdata['ts'])
print(rst,mfgs[l][h])
if l != self.gnn_param['layer'] - 1:
mfgs[l + 1][h].srcdata['h'] = rst
else:
......
......@@ -31,7 +31,6 @@ def prepare_input(node_feat, edge_feat, mem_embedding,mfgs,dist_nid,dist_eid):
idx = b.srcdata['ID']
b.edata['ID'] = dist_eid[e_idx]
b.srcdata['ID'] = dist_nid[idx]
#print(b.edata['ID'],b.edata['dt'],b.srcdata['ID'])
if edge_feat is not None:
b.edata['f'] = edge_feat[e_idx]
if i == 0:
......@@ -51,21 +50,19 @@ def to_block(graph: DistributedGraphStore, data, sample_out, mailbox:MailBox = N
sample_out,metadata = sample_out
else:
metadata = None
# print(sample_out)
eid = [ret.eid() for ret in sample_out]
eid_len = [e.shape[0] for e in eid ]
#print(len(sample_out),eid,eid_len)
eid_mapper: torch.Tensor = graph.eids_mapper
nid_mapper: torch.Tensor = graph.nids_mapper
eid_tensor = torch.cat(eid,dim = 0).to(eid_mapper.device)
#print(eid_tensor)
dist_eid = eid_mapper[eid_tensor>>1].to(device)
dist_eid = graph.sample_graph['dist_eid'][eid_tensor].to(device)#eid_mapper[eid_tensor].to(device)
dist_eid,eid_inv = dist_eid.unique(return_inverse=True)
src_node = graph.sample_graph['edge_index'][0,eid_tensor].to(graph.nids_mapper.device)
#print(src_node,graph.sample_graph['edge_index'][1,eid_tensor])
src_ts = None
if metadata is None:
root_node = data.nodes.to(graph.nids_mapper.device)
root_len = root_node.shape[0]
root_node = data.nodes.to(graph.nidst_eid_mapper.device)
root_len = [root_node.shape[0]]
if hasattr(data,'ts'):
src_ts = torch.cat([data.ts,
graph.sample_graph['ts'][eid_tensor].to(device)])
......@@ -77,44 +74,30 @@ def to_block(graph: DistributedGraphStore, data, sample_out, mailbox:MailBox = N
graph.sample_graph['ts'][eid_tensor].to(device)])
for k in metadata:
metadata[k] = metadata[k].to(device)
#print(src_ts,root_node)
nid_tensor = torch.cat([root_node,src_node],dim = 0)
#sprint(nid_tensor)
dist_nid = nid_mapper[nid_tensor].to(device)
dist_nid,nid_inv = dist_nid.unique(return_inverse = True)
fetchCache = FetchFeatureCache.getFetchCache()
if fetchCache is None:
if isinstance(graph.edge_attr,DistributedTensor):
ind_dict = graph.edge_attr.all_to_all_ind2ptr(dist_eid,group = group)
edge_feat = graph.edge_attr.all_to_all_get(group = group,**ind_dict)
else:
edge_feat = graph._get_edge_attr(dist_eid)
ind_dict = None
if isinstance(graph.x,DistributedTensor):
ind_dict = graph.x.all_to_all_ind2ptr(dist_nid,group = group)
node_feat = graph.x.all_to_all_get(group = group,**ind_dict)
else:
node_feat = graph._get_node_attr(dist_nid)
if mailbox is not None:
if torch.distributed.get_world_size() > 1:
if node_feat is None:
ind_dict = mailbox.node_memory.all_to_all_ind2ptr(dist_nid,group = group)
mem = mailbox.gather_memory(**ind_dict)
else:
mem = mailbox.get_memory(dist_nid)
if isinstance(graph.edge_attr,DistributedTensor):
ind_dict = graph.edge_attr.all_to_all_ind2ptr(dist_eid,group = group)
edge_feat = graph.edge_attr.all_to_all_get(group = group,**ind_dict)
else:
edge_feat = graph._get_edge_attr(dist_eid)
ind_dict = None
if isinstance(graph.x,DistributedTensor):
ind_dict = graph.x.all_to_all_ind2ptr(dist_nid,group = group)
node_feat = graph.x.all_to_all_get(group = group,**ind_dict)
else:
node_feat = graph._get_node_attr(dist_nid)
if mailbox is not None:
if torch.distributed.get_world_size() > 1:
if node_feat is None:
ind_dict = mailbox.node_memory.all_to_all_ind2ptr(dist_nid,group = group)
mem = mailbox.gather_memory(**ind_dict)
else:
mem = None
mem = mailbox.get_memory(dist_nid)
else:
raw_nid = torch.empty_like(dist_nid)
raw_eid = torch.empty_like(dist_eid)
nid_tensor = nid_tensor.to(device)
eid_tensor = eid_tensor.to(device)
raw_nid[nid_inv] = nid_tensor
raw_eid[eid_inv] = (eid_tensor>>1)
node_feat,edge_feat,mem = fetchCache.fetch_feature(raw_nid,
dist_nid,raw_eid,
dist_eid)
mem = None
def build_block():
mfgs = list()
col = torch.arange(0,root_len,device = device)
......@@ -123,23 +106,18 @@ def to_block(graph: DistributedGraphStore, data, sample_out, mailbox:MailBox = N
for r in range(len(eid_len)):
elen = eid_len[r]
row = torch.arange(row_len,row_len+elen,device = device)
#print(row,col[sample_out[r].src_index()])
b = dgl.create_block((row,col[sample_out[r].src_index().to(device)]),
num_src_nodes = row_len + elen,
num_dst_nodes = row_len,
device = device)
idx = nid_inv[0:row_len + elen]
e_idx = eid_inv[col_len:col_len+elen]
#print(idx,e_idx)
b.srcdata['ID'] = idx
if sample_out[r].delta_ts().shape[0] > 0:
b.edata['dt'] = sample_out[r].delta_ts().to(device)
if src_ts is not None:
b.srcdata['ts'] = src_ts[0:row_len + elen]
b.srcdata['ts'] = src_ts[0:row_len + eid_len[r]]
b.edata['ID'] = e_idx
#print(b.all_edges)
#print(dist_nid[b.srcdata['ID']],dist_nid[b.srcdata['ID'][col[sample_out[r].src_index().to(device)]]])
#print(b.edata['dt'],b.srcdata['ts'])
col = row
col_len += eid_len[r]
row_len += eid_len[r]
......@@ -152,7 +130,6 @@ def to_block(graph: DistributedGraphStore, data, sample_out, mailbox:MailBox = N
#return build_block(node_feat,edge_feat,mem)#data,mfgs,metadata
return (data,mfgs,metadata)
def graph_sample(graph, sampler:BaseSampler,
sample_fn, data,
neg_sampling = None,
......
......@@ -37,6 +37,7 @@ class DistributedGraphStore:
self.sample_graph = pdata.sample_graph
self.nids_mapper = build_mapper(nids=pdata.ids.to(device)).dist.to('cpu')
self.eids_mapper = build_mapper(nids=pdata.eids.to(device)).dist.to('cpu')
self.sample_graph['dist_eid'] = self.eids_mapper[pdata.sample_graph['eids']]
torch.cuda.empty_cache()
self.num_nodes = self.nids_mapper.data.shape[0]
......@@ -255,7 +256,7 @@ class TemporalNeighborSampleGraph(DistributedGraphStore):
self.edge_ts = sample_graph['ts']
else:
self.edge_ts = None
self.eid = torch.arange(self.num_edges,dtype = torch.long, device = sample_graph['eids'].device)
self.eid = sample_graph['eids']#torch.arange(self.num_edges,dtype = torch.long, device = sample_graph['eids'].device)
#sample_graph['eids']
if mode == 'train':
mask = sample_graph['train_mask']
......
......@@ -56,11 +56,11 @@ import dgl
import numpy as np
from sklearn.metrics import average_precision_score, roc_auc_score
from torch.nn.parallel import DistributedDataParallel as DDP
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.rank)
os.environ["RANK"] = str(args.rank)
os.environ["WORLD_SIZE"] = str(args.world_size)
os.environ["LOCAL_RANK"] = str(0)
#torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
#os.environ['CUDA_VISIBLE_DEVICES'] = str(args.rank)
#os.environ["RANK"] = str(args.rank)
#os.environ["WORLD_SIZE"] = str(args.world_size)
#os.environ["LOCAL_RANK"] = str(0)
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
os.environ["MASTER_ADDR"] = '127.0.0.1'
os.environ["MASTER_PORT"] = '9437'
def seed_everything(seed=42):
......@@ -71,7 +71,7 @@ def seed_everything(seed=42):
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
seed_everything(1234)
seed_everything(34)
def main():
print('main')
use_cuda = True
......@@ -80,9 +80,9 @@ def main():
torch.set_num_threads(int(40/torch.distributed.get_world_size()))
device_id = torch.cuda.current_device()
print('use cuda on',device_id)
pdata = partition_load("/mnt/data/part_data/here/{}".format(args.dataname), algo="metis_for_tgnn")
pdata = partition_load("/mnt/data/part_data/v2/here/{}".format(args.dataname), algo="metis_for_tgnn")
graph = DistributedGraphStore(pdata = pdata)
print(graph.num_nodes)
Path("./saved_models/").mkdir(parents=True, exist_ok=True)
Path("./saved_checkpoints/").mkdir(parents=True, exist_ok=True)
get_checkpoint_path = lambda \
......@@ -100,18 +100,15 @@ def main():
fanout = sample_param['neighbor'] if 'neighbor' in sample_param else [10]
policy = sample_param['strategy'] if 'strategy' in sample_param else 'recent'
sampler = NeighborSampler(num_nodes=graph.num_nodes, num_layers=num_layers, fanout=fanout,graph_data=sample_graph, workers=int(40/torch.distributed.get_world_size()),policy = policy, graph_name = "wiki_train")
print(sample_graph.eid.shape,sample_graph.edge_index.shape,sample_graph.edge_ts.shape,graph.edge_attr)
train_data = torch.masked_select(graph.edge_index,pdata.train_mask.to(graph.edge_index.device)).reshape(2,-1)
train_ts = torch.masked_select(graph.edge_ts,pdata.train_mask.to(graph.edge_index.device))
test_data = torch.masked_select(graph.edge_index,pdata.val_mask.to(graph.edge_index.device)).reshape(2,-1)
test_ts = torch.masked_select(graph.edge_ts,pdata.val_mask.to(graph.edge_index.device))
val_data = torch.masked_select(graph.edge_index,pdata.test_mask.to(graph.edge_index.device)).reshape(2,-1)
val_ts = torch.masked_select(graph.edge_ts,pdata.test_mask.to(graph.edge_index.device))
print(train_ts.max(),val_ts.max(),test_ts.max())
#print(train_data.shape[1],val_data.shape[1],test_data.shape[1])
test_data = torch.masked_select(graph.edge_index,pdata.test_mask.to(graph.edge_index.device)).reshape(2,-1)
test_ts = torch.masked_select(graph.edge_ts,pdata.test_mask.to(graph.edge_index.device))
val_data = torch.masked_select(graph.edge_index,pdata.val_mask.to(graph.edge_index.device)).reshape(2,-1)
val_ts = torch.masked_select(graph.edge_ts,pdata.val_mask.to(graph.edge_index.device))
train_data = DataSet(edges = train_data,ts =train_ts,eids = torch.nonzero(pdata.train_mask).view(-1))
test_data = DataSet(edges = test_data,ts =test_ts,eids = torch.nonzero(pdata.val_mask).view(-1))
val_data = DataSet(edges = val_data,ts = val_ts,eids = torch.nonzero(pdata.test_mask).view(-1))
test_data = DataSet(edges = test_data,ts =test_ts,eids = torch.nonzero(pdata.test_mask).view(-1))
val_data = DataSet(edges = val_data,ts = val_ts,eids = torch.nonzero(pdata.val_mask).view(-1))
neg_sampler = NegativeSampling('triplet')
trainloader = DistributedDataLoader(graph,train_data,sampler = sampler,
sampler_fn = SAMPLE_TYPE.SAMPLE_FROM_TEMPORAL_EDGES,
......@@ -148,7 +145,6 @@ def main():
gnn_dim_node = 0 if graph.x is None else pdata.x.shape[1]
gnn_dim_edge = 0 if graph.edge_attr is None else pdata.edge_attr.shape[1]
print(gnn_dim_node,gnn_dim_edge)
avg_time = 0
if use_cuda:
model = GeneralModel(gnn_dim_node, gnn_dim_edge, sample_param, memory_param, gnn_param, train_param).cuda()
......@@ -299,10 +295,6 @@ def main():
train_ap = float(torch.tensor(train_aps).mean())
ap = 0
auc = 0
#if cache.edge_cache is not None:
# print('hit {}'.format(cache.edge_cache.hit_/ cache.edge_cache.hit_sum))
#if cache.node_cache is not None:
# print('hit {}'.format(cache.node_cache.hit_/ cache.node_cache.hit_sum))
ap, auc = eval('val')
early_stop = early_stopper.early_stop_check(ap)
if early_stop:
......@@ -315,8 +307,12 @@ def main():
print('\ttrain loss:{:.4f} train ap:{:4f} val ap:{:4f} val auc:{:4f}\n'.format(total_loss,train_ap, ap, auc))
print('\ttotal time:{:.2f}s prep time:{:.2f}s\n'.format(time.time()-epoch_start_time, time_prep))
print('\t fetch time:{:.2f}s write back time:{:.2f}s\n'.format(fetch_time,write_back_time))
#torch.save(model.state_dict(), get_checkpoint_path(e))
torch.save(model.state_dict(), get_checkpoint_path(e))
if not early_stop:
print(f"Loading the best model at epoch {early_stopper.best_epoch}")
best_model_path = get_checkpoint_path(early_stopper.best_epoch)
model.load_state_dict(torch.load(best_model_path))
model.eval()
if mailbox is not None:
mailbox.reset()
......
from .partition import *
from .uvm import *
\ No newline at end of file
import torch
import starrygl
from torch import Tensor
from torch_sparse import SparseTensor
from typing import *
__all__ = [
"metis_partition",
"mt_metis_partition",
"random_partition",
"pyg_metis_partition"
]
def _nopart(edge_index: Tensor, num_nodes: int):
return torch.zeros(num_nodes).type_as(edge_index)
def metis_partition(
edge_index: Tensor,
num_nodes: int,
num_parts: int,
node_weight: Optional[Tensor] = None,
edge_weight: Optional[Tensor] = None,
node_sizes: Optional[Tensor] = None,
recursive: bool = False,
min_edge_cut: bool = False,
) -> Tensor:
if num_parts <= 1:
return _nopart(edge_index, num_nodes)
adj_t = SparseTensor.from_edge_index(edge_index, edge_weight, sparse_sizes=(num_nodes, num_nodes))
rowptr, col, value = adj_t.coalesce().to_symmetric().csr()
node_parts = starrygl.ops.metis_partition(
rowptr, col, value, node_weight, node_sizes, num_parts, recursive, min_edge_cut)
return node_parts
def pyg_metis_partition(
edge_index: Tensor,
num_nodes: int,
num_parts: int,
) -> Tensor:
if num_parts <= 1:
return _nopart(edge_index, num_nodes)
adj_t = SparseTensor.from_edge_index(edge_index, sparse_sizes=(num_nodes, num_nodes))
rowptr, col, _ = adj_t.coalesce().to_symmetric().csr()
node_parts = torch.ops.torch_sparse.partition(rowptr, col, None, num_parts, num_parts < 8)
return node_parts
def mt_metis_partition(
edge_index: Tensor,
num_nodes: int,
num_parts: int,
node_weight: Optional[Tensor] = None,
edge_weight: Optional[Tensor] = None,
num_workers: int = 8,
recursive: bool = False,
) -> Tensor:
if num_parts <= 1:
return _nopart(edge_index, num_nodes)
adj_t = SparseTensor.from_edge_index(edge_index, edge_weight, sparse_sizes=(num_nodes, num_nodes))
rowptr, col, value = adj_t.coalesce().to_symmetric().csr()
node_parts = starrygl.ops.mt_metis_partition(
rowptr, col, value, node_weight, num_parts, num_workers, recursive)
return node_parts
def random_partition(edge_index: Tensor, num_nodes: int, num_parts: int) -> Tensor:
if num_parts <= 1:
return _nopart(edge_index, num_nodes)
return torch.randint(num_parts, size=(num_nodes,)).type_as(edge_index)
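# A hedged usage sketch for the partition helpers above on a toy graph.
# random_partition needs only PyTorch; pyg_metis_partition additionally requires a
# torch-sparse build with METIS support, and the METIS/mt-METIS variants need the
# compiled starrygl.ops extension. The edge index below is made up for illustration.
def _partition_example():
    edge_index = torch.tensor([[0, 1, 2, 3, 4, 5],
                               [1, 2, 3, 4, 5, 0]])
    num_nodes, num_parts = 6, 2
    rand_parts = random_partition(edge_index, num_nodes, num_parts)
    metis_parts = pyg_metis_partition(edge_index, num_nodes, num_parts)
    return rand_parts, metis_parts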
import torch
import starrygl
from torch import Tensor
from enum import Enum
from typing import *
__all__ = [
"uvm_empty",
"uvm_share",
"uvm_advise",
"uvm_prefetch",
"cudaMemoryAdvise",
]
def uvm_empty(*sizes: int, dtype: torch.dtype, device: Any):
sizes = torch.Size(sizes)
device = torch.device(device)
assert device.type == "cuda" \
and device.index is not None, "device must be cuda:x"
size_bytes = torch.Size(sizes).numel() * dtype.itemsize
# default strides
dims = len(sizes)
strides = [1] * dims
for i in range(1, dims):
strides[dims-i-1] = strides[dims-i] * sizes[dims-i]
strides = torch.Size(strides)
storage = starrygl.ops.uvm_storage_new(size_bytes, device.index)
return torch.empty(0, dtype=dtype, device=device).set_(storage, 0, sizes, strides)
def uvm_share(x: Tensor, device: Any):
device = torch.device(device)
if device.type == "cpu":
storage = starrygl.ops.uvm_storage_to_cpu(x.untyped_storage())
else:
assert device.type == "cuda" \
and device.index is not None, "device must be cuda:x or cpu"
storage = starrygl.ops.uvm_storage_to_cuda(x.untyped_storage(), device.index)
return torch.empty(0, dtype=x.dtype, device=device) \
.set_(storage, x.storage_offset(), x.size(), x.stride())
class cudaMemoryAdvise(Enum):
cudaMemAdviseSetAccessedBy = starrygl.ops.cudaMemoryAdvise.cudaMemAdviseSetAccessedBy
cudaMemAdviseUnsetAccessedBy = starrygl.ops.cudaMemoryAdvise.cudaMemAdviseUnsetAccessedBy
cudaMemAdviseSetPreferredLocation = starrygl.ops.cudaMemoryAdvise.cudaMemAdviseSetPreferredLocation
cudaMemAdviseUnsetPreferredLocation = starrygl.ops.cudaMemoryAdvise.cudaMemAdviseUnsetPreferredLocation
cudaMemAdviseSetReadMostly = starrygl.ops.cudaMemoryAdvise.cudaMemAdviseSetReadMostly
cudaMemAdviseUnsetReadMostly = starrygl.ops.cudaMemoryAdvise.cudaMemAdviseUnsetReadMostly
def uvm_advise(x: Tensor, advise: cudaMemoryAdvise):
assert isinstance(advise, cudaMemoryAdvise)
advise = starrygl.ops.cudaMemoryAdvise(advise.value)
starrygl.ops.uvm_storage_advise(x.untyped_storage(), advise)
def uvm_prefetch(x: Tensor):
starrygl.ops.uvm_storage_prefetch(x.untyped_storage())
\ No newline at end of file
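# A hedged usage sketch for the UVM helpers above. It assumes the compiled
# starrygl.ops extension is importable and a CUDA device is present; the tensor shape
# and the advise flag are arbitrary illustration values.
def _uvm_example():
    x = uvm_empty(1024, 128, dtype=torch.float32, device="cuda:0")
    # ask the driver to keep the pages accessible from the GPU, then prefetch them
    uvm_advise(x, cudaMemoryAdvise.cudaMemAdviseSetAccessedBy)
    uvm_prefetch(x)
    # view the same unified storage as a CPU tensor
    x_cpu = uvm_share(x, device="cpu")
    return x, x_cpu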