Commit f4af3231 by Wenjie Huang

rpc based

parent 3f94b191
...@@ -159,4 +159,7 @@ cython_debug/ ...@@ -159,4 +159,7 @@ cython_debug/
# option (not recommended) you can uncomment the following to ignore the entire idea folder. # option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/ #.idea/
cora/ cora/
\ No newline at end of file /test_*
/*.ipynb
/s.py
\ No newline at end of file
...@@ -4,201 +4,351 @@ import torch.nn as nn ...@@ -4,201 +4,351 @@ import torch.nn as nn
from torch import Tensor from torch import Tensor
from typing import * from typing import *
from .buffer import TensorBuffers from torch_sparse import SparseTensor
from torch_scatter import scatter_min
from multiprocessing.pool import ThreadPool
from .buffer import TensorBuffer
from .route import Route
from .context import RouteContext
# class BatchHandle:
# def __init__(self,
# src_ids: Tensor,
# dst_size: int,
# feat_buffer: TensorBuffers,
# grad_buffer: TensorBuffers,
# with_feat0: bool = False,
# layer_id: Optional[int] = None,
# target_device: Any = None,
# ) -> None:
# self.src_ids = src_ids
# self.dst_size = dst_size
# self.feat_buffer = feat_buffer
# self.grad_buffer = grad_buffer
# self.with_feat0 = with_feat0
# self.layer_id = layer_id
# if target_device is None:
# self.target_device = src_ids.device
# else:
# self.target_device = torch.device(target_device)
# if with_feat0:
# _ = self.feat0
# @property
# def dst_ids(self) -> Tensor:
# return self.src_ids[:self.dst_size]
# @property
# def src_size(self) -> int:
# return self.src_ids.size(0)
# @property
# def device(self) -> torch.device:
# return self.src_ids.device
# @property
# def feat0(self) -> Tensor:
# if not hasattr(self, "_feat0"):
# self._feat0 = self.feat_buffer.get(0, self.src_ids).to(self.target_device)
# return self._feat0
# def fetch_feat(self, layer_id: Optional[int] = None) -> Tensor:
# layer_id = int(self.layer_id if layer_id is None else layer_id)
# return self.feat_buffer.get(layer_id, self.src_ids).to(self.target_device)
# def update_feat(self, x: Tensor, layer_id: Optional[int] = None):
# assert x.size(0) == self.dst_size
# layer_id = int(self.layer_id if layer_id is None else layer_id)
# self.feat_buffer.set(layer_id + 1, self.dst_ids, x.detach().to(self.device))
# def push_and_pull(self, x: Tensor, layer_id: Optional[int] = None) -> Tensor:
# assert x.size(0) == self.dst_size
# layer_id = int(self.layer_id if layer_id is None else layer_id)
# self.feat_buffer.set(layer_id + 1, self.dst_ids, x.detach().to(self.device))
# o = self.feat_buffer.get(layer_id + 1, self.src_ids[self.dst_size:]).to(x.device)
# return torch.cat([x, o], dim=0)
# def fetch_grad(self, layer_id: Optional[int] = None) -> Tensor:
# layer_id = int(self.layer_id if layer_id is None else layer_id)
# return self.grad_buffer.get(layer_id + 1, self.dst_ids).to(self.target_device)
# def accumulate_grad(self, x: Tensor, layer_id: Optional[int] = None):
# assert x.size(0) == self.src_size
# layer_id = int(self.layer_id if layer_id is None else layer_id)
# self.grad_buffer.add(layer_id, self.src_ids, x.detach().to(self.device))
class BatchHandle: class NodeHandle:
def __init__(self, def __init__(self,
src_ids: Tensor, src_ids: Tensor,
dst_size: int, dst_size: int,
feat_buffer: TensorBuffers, edge_ids: Tensor,
grad_buffer: TensorBuffers, adj_t: SparseTensor,
with_feat0: bool = False, executor: ThreadPool,
layer_id: Optional[int] = None, device: Any,
target_device: Any = None, stream: Optional[torch.cuda.Stream] = None,
edge_time: Optional[Tensor] = None,
node_time: Optional[Tensor] = None,
) -> None: ) -> None:
self.src_ids = src_ids self._src_ids = src_ids
self.dst_size = dst_size self._dst_size = dst_size
self.feat_buffer = feat_buffer self._edge_ids = edge_ids
self.grad_buffer = grad_buffer
self.with_feat0 = with_feat0 self._device = torch.device(device)
self.layer_id = layer_id if self._device != torch.device("cpu") and stream is None:
self._stream = torch.cuda.Stream(self._device)
if target_device is None:
self.target_device = src_ids.device
else: else:
self.target_device = torch.device(target_device) self._stream = stream
self._executor = executor
if with_feat0: self._adj_t = adj_t.to(self.device)
_ = self.feat0 self._node_time = node_time
self._edge_time = edge_time
def get_src_feats(self, data: Union[Tensor, TensorBuffer]):
return self.async_select(data, self.src_ids)
def get_dst_feats(self, data: Union[Tensor, TensorBuffer]):
return self.async_select(data, self.dst_ids)
def get_ext_feats(self, data: Union[Tensor, TensorBuffer]):
return self.async_select(data, self.ext_ids)
def get_edge_feats(self, data: Union[Tensor, TensorBuffer]):
return self.async_select(data, self.edge_ids)
def get_src_time(self):
return self.async_select(self._node_time, self.src_ids)
def get_dst_time(self):
return self.async_select(self._node_time, self.dst_ids)
def get_ext_time(self):
return self.async_select(self._node_time, self.ext_ids)
def get_edge_time(self) -> Tensor:
return self.async_select(self._edge_time, self.edge_ids)
def push_and_pull(self, value: Tensor, ext_fut: torch.futures.Future[Tensor], data: Union[Tensor, TensorBuffer]):
fut = self.async_update(data, value, self.dst_ids)
ext_value = ext_fut.wait()
x = torch.cat([value, ext_value], dim=0)
return x, fut
def async_select(self, src: Union[Tensor, TensorBuffer], idx: Tensor) -> torch.futures.Future[Tensor]:
fut = torch.futures.Future(devices=[self._device])
def run():
try:
with torch.no_grad():
with torch.cuda.stream(self._stream):
index = idx.to(src.device)
if isinstance(src, TensorBuffer):
val = src.local_get(index, lock=True)
else:
val = src.index_select(0, index)
val = val.to(self._device)
fut.set_result(val)
except Exception as e:
fut.set_exception(e)
self._executor.apply_async(run)
return fut
def async_update(self, src: Union[Tensor, TensorBuffer], val: Tensor, idx: Tensor, ops: str = "mov") -> torch.futures.Future:
assert ops in ["mov", "add"]
fut = torch.futures.Future(devices=[self._device])
def run():
try:
with torch.no_grad():
with torch.cuda.stream(self._stream):
value = val.to(src.device)
index = idx.to(src.device)
if isinstance(src, TensorBuffer):
if ops == "mov":
src.all_remote_set(value, index, lock=True)
elif ops == "add":
src.all_remote_add(value, index, lock=True)
else:
if ops == "mov":
src[index] = value
elif ops == "add":
src[index] += value
fut.set_result(None)
except Exception as e:
fut.set_exception(e)
self._executor.apply_async(run)
return fut
@property @property
def dst_ids(self) -> Tensor: def adj_t(self) -> SparseTensor:
return self.src_ids[:self.dst_size] return self._adj_t
@property
def src_ids(self) -> Tensor:
return self._src_ids
@property @property
def src_size(self) -> int: def src_size(self) -> int:
return self.src_ids.size(0) return self._src_ids.size(0)
@property @property
def device(self) -> torch.device: def dst_ids(self) -> Tensor:
return self.src_ids.device return self._src_ids[:self._dst_size]
@property @property
def feat0(self) -> Tensor: def dst_size(self) -> int:
if not hasattr(self, "_feat0"): return self._dst_size
self._feat0 = self.feat_buffer.get(0, self.src_ids).to(self.target_device)
return self._feat0 @property
def ext_ids(self) -> Tensor:
def fetch_feat(self, layer_id: Optional[int] = None) -> Tensor: return self._src_ids[self._dst_size:]
layer_id = int(self.layer_id if layer_id is None else layer_id)
return self.feat_buffer.get(layer_id, self.src_ids).to(self.target_device) @property
def ext_size(self) -> int:
def update_feat(self, x: Tensor, layer_id: Optional[int] = None): return self.src_size - self.dst_size
assert x.size(0) == self.dst_size
layer_id = int(self.layer_id if layer_id is None else layer_id) @property
self.feat_buffer.set(layer_id + 1, self.dst_ids, x.detach().to(self.device)) def edge_ids(self) -> Tensor:
return self._edge_ids
def push_and_pull(self, x: Tensor, layer_id: Optional[int] = None) -> Tensor:
assert x.size(0) == self.dst_size @property
layer_id = int(self.layer_id if layer_id is None else layer_id) def edge_size(self) -> Tensor:
self.feat_buffer.set(layer_id + 1, self.dst_ids, x.detach().to(self.device)) return self._edge_ids.size(0)
o = self.feat_buffer.get(layer_id + 1, self.src_ids[self.dst_size:]).to(x.device)
return torch.cat([x, o], dim=0) @property
def device(self) -> torch.device:
def fetch_grad(self, layer_id: Optional[int] = None) -> Tensor: return self._device
layer_id = int(self.layer_id if layer_id is None else layer_id)
return self.grad_buffer.get(layer_id + 1, self.dst_ids).to(self.target_device) class NodeLoader:
def accumulate_grad(self, x: Tensor, layer_id: Optional[int] = None):
assert x.size(0) == self.src_size
layer_id = int(self.layer_id if layer_id is None else layer_id)
self.grad_buffer.add(layer_id, self.src_ids, x.detach().to(self.device))
class DataLoader:
def __init__(self, def __init__(self,
node_parts: Tensor, global_ids: Tensor,
feat_buffer: TensorBuffers, global_edges: Tensor,
grad_buffer: TensorBuffers, device: Any,
edge_index: Tensor,
edge_attr: Optional[Tensor] = None,
edge_time: Optional[Tensor] = None, edge_time: Optional[Tensor] = None,
node_time: Optional[Tensor] = None, node_time: Optional[Tensor] = None,
num_threads: int = 1,
) -> None: ) -> None:
self.src_size = feat_buffer.src_size self.route, edge_index = Route.from_edge_index(global_ids, global_edges)
self.dst_size = node_parts.size(0)
assert node_parts.size(0) <= self.src_size
assert grad_buffer.src_size == self.src_size
num_parts = node_parts.max().item() + 1 if node_time is None and edge_time is not None:
cluster, node_perm = node_parts.sort(dim=0) node_time = scatter_min(edge_time[edge_index[0]], edge_index[1], dim=0, dim_size=self.dst_size)
node_ptr: Tensor = torch.ops.torch_sparse.ind2ptr(cluster, num_parts)
node_ids = torch.arange(self.dst_size).type_as(global_ids)
if node_time is not None:
perm = node_time.argsort()
node_ids = node_ids[perm]
edge_parts = node_parts[edge_index[1]] self.node_ids = node_ids
cluster, edge_perm = edge_parts.sort() self.node_time = node_time
edge_ptr: Tensor = torch.ops.torch_sparse.ind2ptr(cluster, num_parts)
self.num_parts = num_parts edge_ids = torch.arange(edge_index.size(1)).type_as(global_ids)
self.node_ptr = node_ptr perm = edge_index[1].argsort()
self.edge_ptr = edge_ptr
self.node_perm = node_perm
self.edge_perm = edge_perm
self.edge_index = edge_index[:,edge_perm]
if edge_attr is not None: edge_ids = edge_ids[perm]
self.edge_attr = edge_attr[edge_perm] edge_index = edge_index[:,perm]
if node_time is not None: self.edge_ids = edge_ids
self.node_time = node_time[node_perm] self.edge_index = edge_index
self.edge_time = edge_time
# self.rowptr = torch.ops.torch_sparse.ind2ptr(self.edge_index[1], self.dst_size)
if edge_time is not None: self.device = torch.device(device)
self.edge_time = edge_time[edge_perm] if self.device != torch.device("cpu"):
self.stream = torch.cuda.Stream(self.device)
else:
self.stream = None
self.executor = ThreadPool(num_threads)
@property
def src_size(self) -> int:
return self.route.src_size
@property
def dst_size(self) -> int:
return self.route.dst_size
# def _select_nodes(self, start_time: int, end_time: int) -> Tuple[Optional[Tensor], Optional[Tensor]]:
# if self.node_time is None:
# if self.edge_time is None:
# raise RuntimeError("neither node_time nor edge_time exists!")
# else:
# edge_mask = (start_time <= self.edge_time) & (self.edge_time < end_time)
# idx = self.edge_index[1, edge_mask]
# node_mask = torch.zeros(self.dst_size, dtype=torch.bool, device=idx.device).index_fill_(0, idx, 1)
# return torch.where(node_mask)[0], None
# else:
# time_range = torch.tensor([start_time, end_time]).type_as(self.node_time)
# s, t = torch.searchsorted(self.node_time, time_range).tolist()
# return self.node_ids[s:t], self.node_time[s:t]
def _get_node_handle(self, node_ids: Tensor, edge_ids: Tensor, edge_index: Tensor) -> NodeHandle:
num_node: int = torch.max(node_ids.max(), edge_index.max()).item() + 1
imp = torch.zeros(num_node, dtype=torch.bool, device=node_ids.device)
imp.index_fill_(0, edge_index[0], 1)
imp.index_fill_(0, node_ids, 0)
src_ids = torch.cat([node_ids, torch.where(imp)[0]], dim=0)
src_size = src_ids.size(0)
dst_size = node_ids.size(0)
self.feat_buffer = feat_buffer imp = torch.zeros(num_node, dtype=torch.long, device=node_ids.device).fill_((2**62-1)*2+1)
self.grad_buffer = grad_buffer imp[src_ids] = torch.arange(src_size, dtype=torch.long, device=node_ids.device)
edge_index = imp[edge_index.flatten()].view_as(edge_index)
perm = edge_index[1].argsort()
edge_ids = edge_ids[perm]
edge_index = edge_index[:,perm]
rowptr = torch.ops.torch_sparse.ind2ptr(edge_index[1], dst_size)
adj_t = SparseTensor(rowptr=rowptr, col=edge_index[0], sparse_sizes=(dst_size, src_size))
return NodeHandle(
src_ids=src_ids,
dst_size=dst_size,
edge_ids=edge_ids,
adj_t=adj_t,
executor=self.executor,
device=self.device,
stream=self.stream,
edge_time=self.edge_time,
node_time=self.node_time,
)
def iter(self, batch_size: int = 1, layer_id: Optional[int] = None, seed: int = 0, filter: Callable[[Tensor], Tensor] = None, device = None): # very very very slow !!!
rnd = torch.Generator() def iter(self, batch_size: int, wind: bool = False, start_time: int = 0, end_time: int = 9223372036854775807, seed: int = 0):
rand_gen = torch.Generator()
if seed != 0: if seed != 0:
rnd.manual_seed(seed) rand_gen.manual_seed(seed)
sampled = torch.randperm(self.num_parts, generator=rnd, dtype=torch.long)
s = 0 if not wind:
imp = torch.empty(self.src_size, dtype=torch.long) sampled = torch.randperm(self.dst_size, generator=rand_gen, dtype=torch.long)
while s < sampled.size(0): for s in range(0, sampled.size(0), batch_size):
t = min(s + batch_size, sampled.size(0)) t = min(s + batch_size, sampled.size(0))
node_ids = sampled[s:t]
dst_ids = [] imp = torch.zeros(self.dst_size, dtype=torch.bool, device=node_ids.device)
edge_index = [] imp.index_fill_(0, node_ids, 1)
edge_attr = [] edge_mask = imp[self.edge_index[1]]
edge_ids = self.edge_ids[edge_mask]
imp.zero_() edge_index = self.edge_index[:,edge_mask]
for i in sampled[s:t].tolist(): yield self._get_node_handle(node_ids, edge_ids, edge_index)
s += batch_size else:
assert self.node_time is not None
a, b = self.node_ptr[i:i+2].tolist() node_mask = (start_time <= self.node_time) & (self.node_time < end_time)
nidx = self.node_perm[a:b] all_node_ids = torch.where(node_mask)[0]
if hasattr(self, "node_time") and filter is not None:
node_mask = filter(self.node_time[a:b]) sampled = torch.randperm(node_ids.size(0), generator=rand_gen, dtype=torch.long)
if not node_mask.any(): all_node_ids = all_node_ids[sampled]
continue
nidx = nidx[node_mask] for s in range(0, sampled.size(0), batch_size):
else: t = min(s + batch_size, sampled.size(0))
node_mask = None node_ids = all_node_ids[s:t]
a, b = self.edge_ptr[i:i+2].tolist()
eidx = self.edge_index[:, a:b]
if hasattr(self, "edge_time") and filter is not None:
if node_mask is None:
edge_mask = filter(self.edge_time[a:b])
else:
imp[nidx] = i+1
edge_mask = (imp[eidx[1]] == i+1)
edge_mask &= filter(self.edge_time[a:b])
if not edge_mask.any():
continue
eidx = eidx[:, edge_mask]
else:
edge_mask = None
dst_ids.append(nidx) imp = torch.zeros(self.dst_size, dtype=torch.bool, device=node_ids.device)
edge_index.append(eidx) imp.index_fill_(0, node_ids, 1)
edge_mask = imp[self.edge_index[1]]
if hasattr(self, "edge_attr"): if self.edge_time is not None:
attr = self.edge_attr[a:b] edge_time = self.edge_time[self.edge_ids]
if edge_mask is None: edge_mask &= (start_time <= edge_time) & (edge_time < end_time)
edge_attr.append(attr) edge_ids = self.edge_ids[edge_mask]
else: edge_index = self.edge_index[:,edge_mask]
edge_attr.append(attr[edge_mask]) yield self._get_node_handle(node_ids, edge_ids, edge_index)
if len(dst_ids) == 0 or len(edge_index) == 0:
continue
dst_ids = torch.cat(dst_ids, dim=-1)
edge_index = torch.cat(edge_index, dim=-1)
if hasattr(self, "edge_attr"):
edge_attr = torch.cat(edge_attr, dim=0)
else:
edge_attr = None
imp.zero_()
imp.index_fill_(0, edge_index[0], 1).index_fill_(0, dst_ids, 0)
src_ids = torch.cat([dst_ids, torch.where(imp > 0)[0]], dim=-1)
assert (src_ids[:dst_ids.size(0)] == dst_ids).all()
imp.fill_((2**62-1)*2+1)
imp[src_ids] = torch.arange(src_ids.size(0), dtype=torch.long)
edge_index = imp[edge_index.flatten()].view_as(edge_index)
handle = BatchHandle(
src_ids, dst_ids.size(0),
self.feat_buffer, self.grad_buffer,
with_feat0 = (layer_id is None),
layer_id = layer_id,
target_device = device,
)
edge_index = edge_index.to(device)
edge_attr = edge_attr.to(device)
yield handle, edge_index, edge_attr
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.distributed as dist
import torch.distributed.rpc as rpc
from torch.futures import Future
from torch import Tensor from torch import Tensor
from typing import * from typing import *
from starrygl.parallel import all_gather_remote_objects
from .route import Route
from threading import Lock
class TensorBuffers(nn.Module):
def __init__(self, src_size: int, channels: Tuple[Union[int, Tuple[int]]]) -> None:
class TensorBuffer(nn.Module):
def __init__(self,
channels: int,
num_nodes: int,
route: Route,
) -> None:
super().__init__() super().__init__()
self.src_size = src_size self.channels = channels
self.num_layers = len(channels) self.num_nodes = num_nodes
for i, s in enumerate(channels): self.route = route
s = (s,) if isinstance(s, int) else s
self.register_buffer(f"data_{i}", torch.zeros(src_size, *s), persistent=False) self.local_lock = Lock()
self.register_buffer("_data", torch.zeros(num_nodes, channels), persistent=False)
self.rrefs = all_gather_remote_objects(self)
@property
def data(self) -> Tensor:
return self.get_buffer("_data")
def get(self, i: int, index: Optional[Tensor]) -> Tensor: @property
i = self._idx(i) def device(self) -> torch.device:
return self.data.device
def local_get(self, index: Optional[Tensor] = None, lock: bool = True) -> Tensor:
if lock:
with self.local_lock:
return self.local_get(index, lock=False)
if index is None: if index is None:
return self.get_buffer(f"data_{i}") return self.data
else: else:
return self.get_buffer(f"data_{i}")[index] return self.data[index]
def set(self, i: int, index: Optional[Tensor], value: Tensor): def local_set(self, value: Tensor, index: Optional[Tensor] = None, lock: bool = True):
i = self._idx(i) if lock:
with self.local_lock:
return self.local_set(value, index, lock=False)
if index is None: if index is None:
self.get_buffer(f"data_{i}")[:,...] = value self.data.copy_(value)
else: else:
self.get_buffer(f"data_{i}")[index] = value # value = value.to(self.device)
self.data[index] = value
def local_add(self, value: Tensor, index: Optional[Tensor] = None, lock: bool = True):
if lock:
with self.local_lock:
return self.local_add(value, index, lock=False)
if index is None:
self.data.add_(value)
else:
# value = value.to(self.device)
self.data[index] += value
def add(self, i: int, index: Optional[Tensor], value: Tensor): def local_cls(self, index: Optional[Tensor] = None, lock: bool = True):
i = self._idx(i) if lock:
with self.local_lock:
return self.local_cls(index, lock=False)
if index is None: if index is None:
self.get_buffer(f"data_{i}")[:,...] += value self.data.zero_()
else: else:
self.get_buffer(f"data_{i}")[index] += value self.data[index] = 0
def remote_get(self, dst: int, index: Tensor, lock: bool = True):
return TensorBuffer._remote_call(TensorBuffer.local_get, self.rrefs[dst], index=index, lock=lock)
def remote_set(self, dst: int, value: Tensor, index: Tensor, lock: bool = True):
return TensorBuffer._remote_call(TensorBuffer.local_set, self.rrefs[dst], value, index=index, lock=lock)
def remote_add(self, dst: int, value: Tensor, index: Tensor, lock: bool = True):
return TensorBuffer._remote_call(TensorBuffer.local_add, self.rrefs[dst], value, index=index, lock=lock)
def all_remote_get(self, index: Tensor, lock: bool = True):
def cb0(idx):
def f(x: torch.futures.Future[Tensor]):
return x.value(), idx
return f
def cb1(buf):
def f(xs: torch.futures.Future[List[torch.futures.Future]]) -> Tensor:
for x in xs.value():
dat, idx = x.value()
# print(dat.size(), idx.size())
buf[idx] += dat
return buf
return f
futs = []
for i, (idx, remote_idx) in enumerate(self.route.parts_iter(index)):
futs.append(self.remote_get(i, remote_idx, lock=lock).then(cb0(idx)))
futs = torch.futures.collect_all(futs)
buf = torch.zeros(index.size(0), self.channels, dtype=self.data.dtype, device=self.data.device)
return futs.then(cb1(buf))
def all_remote_set(self, value: Tensor, index: Tensor, lock: bool = True):
futs = []
for i, (idx, remote_idx) in enumerate(self.route.parts_iter(index)):
futs.append(self.remote_set(i, value[idx], remote_idx, lock=lock))
return torch.futures.collect_all(futs)
def all_remote_add(self, value: Tensor, index: Tensor, lock: bool = True):
futs = []
for i, (idx, remote_idx) in enumerate(self.route.parts_iter(index)):
futs.append(self.remote_add(i, value[idx], remote_idx, lock=lock))
return torch.futures.collect_all(futs)
def broadcast(self, barrier: bool = True):
if barrier:
dist.barrier()
index = torch.arange(self.num_nodes, dtype=torch.long, device=self.data.device)
data = self.all_remote_get(index, lock=True).wait()
self.local_set(data, lock=True)
if barrier:
dist.barrier()
# def remote_get(self, dst: int, i: int, index: Optional[Tensor] = None, global_index: bool = False, async_op: bool = False) -> Union[Tensor, Future]:
# return TensorBuffer._remote_call(async_op, TensorBuffer.local_get, self.rrefs[dst], i, index, global_index = global_index)
# def remote_set(self, dst: int, i: int, value: Tensor, index: Optional[Tensor], global_index: bool = False, async_op: bool = False) -> Optional[Future]:
# return TensorBuffer._remote_call(async_op, TensorBuffer.local_set, self.rrefs[dst], i, value, index, global_index = global_index)
# def remote_add(self, dst: int, i: int, value: Tensor, index: Optional[Tensor] = None, global_index: bool = False, async_op: bool = False) -> Optional[Future]:
# return TensorBuffer._remote_call(async_op, TensorBuffer.local_add, self.rrefs[dst], i, value, index, global_index = global_index)
# def async_scatter_fw_set(self, i: int, value: Tensor, index: Optional[Tensor] = None) -> Tuple[Future]:
# futures: List[Future] = []
# for dst in range(self.world_size):
# val, ind = self.router.fw_value_index(dst, value, index)
# futures.append(self.remote_set(dst, i, val, ind, global_index=True, async_op=True))
# return tuple(futures)
# def async_scatter_fw_add(self, i: int, value: Tensor, index: Optional[Tensor] = None) -> Tuple[Future]:
# futures: List[Future] = []
# for dst in range(self.world_size):
# val, ind = self.router.fw_value_index(dst, value, index)
# futures.append(self.remote_add(dst, i, val, ind, global_index=True, async_op=True))
# return tuple(futures)
# def async_scatter_bw_set(self, i: int, value: Tensor, index: Optional[Tensor] = None) -> Tuple[Future]:
# futures: List[Future] = []
# for dst in range(self.world_size):
# val, ind = self.router.bw_value_index(dst, value, index)
# futures.append(self.remote_set(dst, i, val, ind, global_index=True, async_op=True))
# return tuple(futures)
# def async_scatter_bw_add(self, i: int, value: Tensor, index: Optional[Tensor] = None) -> Tuple[Future]:
# futures: List[Future] = []
# for dst in range(self.world_size):
# val, ind = self.router.bw_value_index(dst, value, index)
# futures.append(self.remote_add(dst, i, val, ind, global_index=True, async_op=True))
# return tuple(futures)
# def _idx_data(self, i: int) -> Tuple[int, Tensor]:
# assert -self.num_layers < i and i < self.num_layers
# i = (self.num_layers + i) % self.num_layers
# return i, self.get_buffer(f"data{i}")
def zero_grad(self): @staticmethod
for name, grad in self.named_buffers("data_", recurse=False): def _remote_call(method, rref: rpc.RRef, *args, **kwargs):
grad.zero_() args = (method, rref) + args
return rpc.rpc_async(rref.owner(), TensorBuffer._method_call, args=args, kwargs=kwargs)
def _idx(self, i: int) -> int: @staticmethod
assert -self.num_layers < i and i < self.num_layers def _method_call(method, rref: rpc.RRef, *args, **kwargs):
return (self.num_layers + i) % self.num_layers self: TensorBuffer = rref.local_value()
index = kwargs["index"]
kwargs["index"] = self.route.to_local_ids(index)
return method(self, *args, **kwargs)
\ No newline at end of file
import torch
from contextlib import contextmanager
from torch import Tensor
from typing import *
class RouteContext:
    """Context manager that collects asynchronous RPC futures and waits for
    all of them to complete when the ``with`` block exits successfully.

    Typical use: push/pull operations return ``torch.futures.Future`` objects;
    registering them here guarantees they are synchronized before leaving the
    communication scope.
    """

    def __init__(self) -> None:
        # Futures pending synchronization, in registration order.
        self._futs: List[torch.futures.Future] = []

    def synchronize(self):
        """Block until every registered future completes, then clear the list.

        ``Future.wait()`` re-raises any exception stored in a future, so a
        failed RPC surfaces here rather than being silently dropped.
        """
        for fut in self._futs:
            fut.wait()
        self._futs = []

    def add_futures(self, *futs):
        """Register one or more futures to be awaited by :meth:`synchronize`.

        Raises:
            AssertionError: if any argument is not a ``torch.futures.Future``.
        """
        for fut in futs:
            assert isinstance(fut, torch.futures.Future)
            self._futs.append(fut)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        if exc_type is not None:
            # Propagate the original exception unchanged. The previous
            # ``raise exc_type(exc_value)`` built a brand-new exception whose
            # sole argument was the old exception object, losing the original
            # traceback and failing outright for exception classes that do not
            # accept a single positional argument. Returning a falsy value
            # lets Python re-raise the in-flight exception as-is.
            return False
        self.synchronize()
import torch
import torch.nn as nn
import torch.distributed as dist
import torch.distributed.rpc as rpc
from torch import Tensor
from typing import *
from starrygl.parallel import all_gather_remote_objects
from .utils import init_local_edge_index
class Route(nn.Module):
    """Routing table for a graph partition.

    Holds the local partition's node ids (``src_ids``: the first ``dst_size``
    entries are nodes owned by this rank, the remainder are halo/external
    nodes) plus two lookup buffers built at construction time:

    - ``nids_mapper``: global node id -> local index into ``src_ids``
    - ``part_mapper``: local index -> owning rank (int32)

    Requires an initialized ``torch.distributed`` process group and RPC
    (uses ``all_gather_remote_objects``).
    """
    def __init__(self,
        src_ids: Tensor,
        dst_size: int,
    ) -> None:
        super().__init__()
        # Non-persistent buffer: moves with .to()/device changes but is not
        # saved in state_dict.
        self.register_buffer("_src_ids", src_ids, persistent=False)
        self.dst_size = dst_size
        self._init_nids_mapper()
        self._init_part_mapper()

    @staticmethod
    def from_edge_index(dst_ids: Tensor, edge_index: Tensor):
        """Build a Route and a locally-indexed edge_index from global ids.

        Returns:
            (Route, Tensor): the route for this partition and ``edge_index``
            remapped to local node indices (via ``init_local_edge_index``).
        """
        src_ids, local_edge_index = init_local_edge_index(dst_ids, edge_index)
        return Route(src_ids, dst_ids.size(0)), local_edge_index

    @property
    def src_ids(self) -> Tensor:
        # All node ids visible to this partition (owned + halo).
        return self.get_buffer("_src_ids")

    @property
    def src_size(self) -> int:
        return self.src_ids.size(0)

    @property
    def dst_ids(self) -> Tensor:
        # Ids of nodes owned by this partition (prefix of src_ids).
        return self.src_ids[:self.dst_size]

    @property
    def ext_ids(self) -> Tensor:
        # Ids of halo/external nodes (suffix of src_ids).
        return self.src_ids[self.dst_size:]

    @property
    def ext_size(self) -> int:
        return self.src_size - self.dst_size

    def parts_iter(self, local_ids: Tensor) -> Iterator[Tuple[Tensor, Tensor]]:
        """Yield, for each rank i, the positions within ``local_ids`` owned by
        rank i and the corresponding global ids.

        Yields:
            (part_ids, glob_ids): ``part_ids`` are *positions into local_ids*
            (from ``torch.where``), ``glob_ids`` the matching global node ids.
        """
        world_size = dist.get_world_size()
        part_mapper = self.part_mapper[local_ids]
        for i in range(world_size):
            # part_ids = local_ids[part_mapper == i]
            part_ids = torch.where(part_mapper == i)[0]
            # NOTE(review): ``part_ids`` are positions within ``local_ids``,
            # yet they index ``self.src_ids`` directly below. This is only
            # consistent when ``local_ids`` is the identity arange (as in
            # TensorBuffer.broadcast); verify callers never pass an arbitrary
            # subset here, or index via ``local_ids[part_ids]`` instead.
            glob_ids = self.src_ids[part_ids]
            yield part_ids, glob_ids

    def to_local_ids(self, ids: Tensor) -> Tensor:
        """Map global node ids to local indices (unknown ids map to the
        sentinel value 2**63 - 1 stored in ``nids_mapper``)."""
        return self.nids_mapper[ids]

    def _init_nids_mapper(self):
        # Dense global->local table sized by the max global id present.
        num_nodes: int = self.src_ids.max().item() + 1
        device: torch.device = self.src_ids.device
        # (2**62-1)*2+1 == 2**63 - 1: max int64, used as "not present" marker.
        mapper = torch.empty(num_nodes, dtype=torch.long, device=device).fill_((2**62-1)*2+1)
        mapper[self.src_ids] = torch.arange(self.src_ids.size(0), dtype=torch.long, device=device)
        self.register_buffer("nids_mapper", mapper, persistent=False)

    def _init_part_mapper(self):
        # For every local node, record which rank owns it by gathering each
        # rank's dst_ids over RPC and marking the matching local slots.
        device: torch.device = self.src_ids.device
        nids_mapper = self.get_buffer("nids_mapper")
        mapper = torch.empty(self.src_size, dtype=torch.int32, device=device).fill_(-1)
        for i, dst_ids in enumerate(all_gather_remote_objects(self.dst_ids)):
            dst_ids: Tensor = dst_ids.to_here().to(device)
            # Drop remote ids outside our global-id table before lookup.
            dst_ids = dst_ids[dst_ids < nids_mapper.size(0)]
            dst_local_inds = nids_mapper[dst_ids]
            # Keep only ids actually present in this partition.
            dst_local_mask = dst_local_inds != ((2**62-1)*2+1)
            dst_local_inds = dst_local_inds[dst_local_mask]
            mapper[dst_local_inds] = i
        # Every local node must have exactly one owner.
        assert (mapper >= 0).all()
        self.register_buffer("part_mapper", mapper, persistent=False)
# class RouteTable(nn.Module):
# def __init__(self,
# src_ids: Tensor,
# dst_size: int,
# ) -> None:
# super().__init__()
# self.register_buffer("src_ids", src_ids)
# self.src_size: int = src_ids.size(0)
# self.dst_size = dst_size
# assert self.src_size >= self.dst_size
# self._init_mapper()
# rank, world_size = rank_world_size()
# rrefs = all_gather_remote_objects(self)
# gather_futures: List[torch.futures.Future] = []
# for i in range(world_size):
# rref = rrefs[i]
# fut = rpc.rpc_async(rref.owner(), RouteTable._get_dst_ids, args=(rref,))
# gather_futures.append(fut)
# max_src_ids: int = src_ids.max().item()
# smp = torch.empty(max_src_ids + 1, dtype=torch.long, device=src_ids.device).fill_((2**62-1)*2+1)
# smp[src_ids] = torch.arange(src_ids.size(0), dtype=smp.dtype, device=smp.device)
# self.fw_masker = RouteMasker(self.dst_size, world_size)
# self.bw_masker = RouteMasker(self.src_size, world_size)
# dist.barrier()
# scatter_futures: List[torch.futures.Future] = []
# for i in range(world_size):
# fut = gather_futures[i]
# s_ids: Tensor = src_ids
# d_ids: Tensor = fut.wait()
# num_ids: int = max(s_ids.max().item(), d_ids.max().item()) + 1
# imp = torch.zeros(num_ids, dtype=torch.long, device=self._get_device())
# imp[s_ids] += 1
# imp[d_ids] += 1
# ind = torch.where(imp > 1)[0]
# imp.fill_((2**62-1)*2+1)
# imp[d_ids] = torch.arange(d_ids.size(0), dtype=imp.dtype, device=imp.device)
# s_ind = smp[ind]
# d_ind = imp[ind]
# rref = rrefs[i]
# fut = rpc.rpc_async(rref.owner(), RouteTable._set_fw_mask, args=(rref, rank, d_ind))
# scatter_futures.append(fut)
# bw_mask = torch.zeros(self.src_size, dtype=torch.bool).index_fill_(0, s_ind, 1)
# self.bw_masker.set_mask(i, bw_mask)
# for fut in scatter_futures:
# fut.wait()
# dist.barrier()
# # def fw_index(self, dst: int, index: Tensor) -> Tensor:
# # mask = self.fw_masker.select(dst, index)
# # return self.get_global_index(index[mask])
# # def bw_index(self, dst: int, index: Tensor) -> Tensor:
# # mask = self.bw_masker.select(dst, index)
# # return self.get_global_index(index[mask])
# def fw_value_index(self, dst: int, value: Tensor, index: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]:
# if index is None:
# assert value.size(0) == self.dst_size
# mask = self.fw_masker.select(dst)
# return value[mask], self.get_buffer("src_ids")[:self.dst_size][mask]
# else:
# assert value.size(0) == index.size(0)
# mask = self.fw_masker.select(dst, index)
# value, index = value[mask], index[mask]
# return value, self.get_global_index(index)
# def bw_value_index(self, dst: int, value: Tensor, index: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]:
# if index is None:
# assert value.size(0) == self.src_size
# mask = self.bw_masker.select(dst)
# return value[mask], self.get_buffer("src_ids")[mask]
# else:
# assert value.size(0) == index.size(0)
# mask = self.bw_masker.select(dst, index)
# value, index = value[mask], index[mask]
# return value, self.get_global_index(index)
# def get_global_index(self, index: Tensor) -> Tensor:
# return self.get_buffer("src_ids")[index]
# def get_local_index(self, index: Tensor) -> Tensor:
# return self.get_buffer("mapper")[index]
# @staticmethod
# def _get_dst_ids(rref: rpc.RRef):
# self: RouteTable = rref.local_value()
# src_ids = self.get_buffer("src_ids")
# return src_ids[:self.dst_size]
# @staticmethod
# def _set_fw_mask(rref: rpc.RRef, dst: int, fw_ind: Tensor):
# self: RouteTable = rref.local_value()
# fw_mask = torch.zeros(self.dst_size, dtype=torch.bool).index_fill_(0, fw_ind, 1)
# self.fw_masker.set_mask(dst, fw_mask)
# def _get_device(self):
# return self.get_buffer("src_ids").device
# def _init_mapper(self):
# src_ids = self.get_buffer("src_ids")
# num_nodes: int = src_ids.max().item() + 1
# mapper = torch.empty(num_nodes, dtype=torch.long, device=src_ids.device).fill_((2**62-1)*2+1)
# mapper[src_ids] = torch.arange(src_ids.size(0), dtype=torch.long)
# self.register_buffer("mapper", mapper)
# class RouteMasker(nn.Module):
# def __init__(self,
# num_nodes: int,
# world_size: int,
# ) -> None:
# super().__init__()
# m = (world_size + 7) // 8
# self.num_nodes = num_nodes
# self.world_size = world_size
# self.register_buffer("data", torch.zeros(m, num_nodes, dtype=torch.uint8))
# def forward(self, i: int, index: Optional[Tensor] = None) -> Tensor:
# return self.select(i, index)
# def select(self, i: int, index: Optional[Tensor] = None) -> Tensor:
# i, data = self._idx_data(i)
# k, r = i // 8, i % 8
# if index is None:
# mask = data[k].bitwise_right_shift(r).bitwise_and_(1)
# else:
# mask = data[k][index].bitwise_right_shift_(r).bitwise_and_(1)
# return mask.type(dtype=torch.bool)
# def set_mask(self, i: int, mask: Tensor) -> Tensor:
# assert mask.size(0) == self.num_nodes
# i, data = self._idx_data(i)
# k, r = i // 8, i % 8
# data[k] &= ~(1<<r)
# data[k] |= mask.type(torch.uint8).bitwise_left_shift_(r)
# def _idx_data(self, i: int) -> Tuple[int, Tensor]:
# assert -self.world_size < i and i < self.world_size
# i = (i + self.world_size) % self.world_size
# return i, self.get_buffer("data")
\ No newline at end of file
...@@ -50,11 +50,6 @@ def calc_max_ids(*ids: Tensor) -> int: ...@@ -50,11 +50,6 @@ def calc_max_ids(*ids: Tensor) -> int:
x = [t.max().item() if t.numel() > 0 else 0 for t in ids] x = [t.max().item() if t.numel() > 0 else 0 for t in ids]
return max(*x) return max(*x)
def collect_feat0(src_ids: Tensor, dst_ids: Tensor, feat0: Tensor):
    """Gather this partition's layer-0 (input) features via a route exchange.

    A temporary Route is built on the compute device and one forward
    all-to-all is performed; the gathered features are moved back to the
    device that *feat0* originally lived on.
    """
    compute_dev = get_compute_device()
    route = Route(dst_ids.to(compute_dev), src_ids.to(compute_dev))
    gathered, _ = route.forward_a2a(feat0.to(compute_dev))
    return gathered.to(feat0.device)
def local_partition_fn(dst_size: Tensor, edge_index: Tensor, num_parts: int) -> Tensor: def local_partition_fn(dst_size: Tensor, edge_index: Tensor, num_parts: int) -> Tensor:
edge_index = edge_index[:, edge_index[0] < dst_size] edge_index = edge_index[:, edge_index[0] < dst_size]
return metis_partition(edge_index, dst_size, num_parts)[0] return metis_partition(edge_index, dst_size, num_parts)[0]
\ No newline at end of file
import torch # import torch
import torch.nn as nn # import torch.nn as nn
import torch.distributed as dist # import torch.distributed as dist
from torch import Tensor # from torch import Tensor
from typing import * # from typing import *
from starrygl.loader import BatchHandle # from starrygl.loader import BatchHandle
class BaseLayer(nn.Module): # class BaseLayer(nn.Module):
def __init__(self) -> None: # def __init__(self) -> None:
super().__init__() # super().__init__()
def forward(self, x: Tensor, edge_index: Tensor, edge_attr: Optional[Tensor] = None) -> Tensor: # def forward(self, x: Tensor, edge_index: Tensor, edge_attr: Optional[Tensor] = None) -> Tensor:
return x # return x
def update_forward(self, handle: BatchHandle, edge_index: Tensor, edge_attr: Optional[Tensor] = None): # def update_forward(self, handle: BatchHandle, edge_index: Tensor, edge_attr: Optional[Tensor] = None):
x = handle.fetch_feat() # x = handle.fetch_feat()
with torch.no_grad(): # with torch.no_grad():
x = self.forward(x, edge_index, edge_attr) # x = self.forward(x, edge_index, edge_attr)
handle.update_feat(x) # handle.update_feat(x)
def block_backward(self, handle: BatchHandle, edge_index: Tensor, edge_attr: Optional[Tensor] = None): # def block_backward(self, handle: BatchHandle, edge_index: Tensor, edge_attr: Optional[Tensor] = None):
x = handle.fetch_feat().requires_grad_() # x = handle.fetch_feat().requires_grad_()
g = handle.fetch_grad() # g = handle.fetch_grad()
self.forward(x, edge_index, edge_attr).backward(g) # self.forward(x, edge_index, edge_attr).backward(g)
handle.accumulate_grad(x.grad) # handle.accumulate_grad(x.grad)
x.grad = None # x.grad = None
def all_reduce_grad(self): # def all_reduce_grad(self):
for p in self.parameters(): # for p in self.parameters():
if p.grad is not None: # if p.grad is not None:
dist.all_reduce(p.grad, op=dist.ReduceOp.SUM) # dist.all_reduce(p.grad, op=dist.ReduceOp.SUM)
class BaseModel(nn.Module): # class BaseModel(nn.Module):
def __init__(self, # def __init__(self,
num_features: int, # num_features: int,
layers: List[int], # layers: List[int],
prev_layer: bool = False, # prev_layer: bool = False,
post_layer: bool = False, # post_layer: bool = False,
) -> None: # ) -> None:
super().__init__() # super().__init__()
def init_prev_layer(self) -> Tensor: # def init_prev_layer(self) -> Tensor:
pass # pass
def init_post_layer(self) -> Tensor: # def init_post_layer(self) -> Tensor:
pass # pass
def init_conv_layer(self) -> Tensor: # def init_conv_layer(self) -> Tensor:
pass # pass
\ No newline at end of file \ No newline at end of file
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.distributed as dist import torch.distributed as dist
import torch.distributed.rpc as rpc
import os import os
from typing import * from typing import *
...@@ -15,29 +16,36 @@ def convert_parallel_model( ...@@ -15,29 +16,36 @@ def convert_parallel_model(
net = SyncBatchNorm.convert_sync_batchnorm(net) net = SyncBatchNorm.convert_sync_batchnorm(net)
net = nn.parallel.DistributedDataParallel(net, net = nn.parallel.DistributedDataParallel(net,
find_unused_parameters=find_unused_parameters, find_unused_parameters=find_unused_parameters,
# broadcast_buffers=False,
) )
# for name, buffer in net.named_buffers():
# if name.endswith("last_embd"):
# continue
# if name.endswith("last_w"):
# continue
# dist.broadcast(buffer, src=0)
return net return net
def init_process_group(backend: str = "gloo") -> torch.device: def init_process_group(backend: str = "gloo") -> torch.device:
rank = int(os.getenv("RANK") or os.getenv("OMPI_COMM_WORLD_RANK"))
world_size = int(os.getenv("WORLD_SIZE") or os.getenv("OMPI_COMM_WORLD_SIZE"))
dist.init_process_group(
backend=backend,
init_method=ccl_init_method(),
rank=rank, world_size=world_size,
)
rpc_backend_options = rpc.TensorPipeRpcBackendOptions()
rpc_backend_options.init_method = rpc_init_method()
for i in range(world_size):
rpc_backend_options.set_device_map(f"worker{i}", {rank: i})
rpc.init_rpc(
name=f"worker{rank}",
rank=rank, world_size=world_size,
rpc_backend_options=rpc_backend_options,
)
local_rank = os.getenv("LOCAL_RANK") or os.getenv("OMPI_COMM_WORLD_LOCAL_RANK") local_rank = os.getenv("LOCAL_RANK") or os.getenv("OMPI_COMM_WORLD_LOCAL_RANK")
if local_rank is not None: if local_rank is not None:
local_rank = int(local_rank) local_rank = int(local_rank)
dist.init_process_group(backend)
if backend == "nccl" or backend == "mpi": if backend == "nccl" or backend == "mpi":
if local_rank is None: device = torch.device(f"cuda:{local_rank or rank}")
device = torch.device(f"cuda:{dist.get_rank()}")
else:
device = torch.device(f"cuda:{local_rank}")
torch.cuda.set_device(device) torch.cuda.set_device(device)
else: else:
device = torch.device("cpu") device = torch.device("cpu")
...@@ -46,7 +54,48 @@ def init_process_group(backend: str = "gloo") -> torch.device: ...@@ -46,7 +54,48 @@ def init_process_group(backend: str = "gloo") -> torch.device:
_COMPUTE_DEVICE = device _COMPUTE_DEVICE = device
return device return device
def rank_world_size() -> Tuple[int, int]:
    """Return ``(rank, world_size)`` of the default distributed process group."""
    rank = dist.get_rank()
    world = dist.get_world_size()
    return rank, world
def get_worker_info(rank: Optional[int] = None) -> rpc.WorkerInfo:
    """Resolve the RPC ``WorkerInfo`` for *rank* (defaults to this process).

    Workers are registered under the naming scheme ``worker{rank}``
    (see the matching ``rpc.init_rpc`` call in this module).
    """
    if rank is None:
        rank = dist.get_rank()
    return rpc.get_worker_info(f"worker{rank}")
_COMPUTE_DEVICE = torch.device("cpu") _COMPUTE_DEVICE = torch.device("cpu")
def get_compute_device() -> torch.device: def get_compute_device() -> torch.device:
global _COMPUTE_DEVICE global _COMPUTE_DEVICE
return _COMPUTE_DEVICE return _COMPUTE_DEVICE
\ No newline at end of file
# Staging slot that peers read over RPC during all_gather_remote_objects.
_TEMP_AG_REMOTE_OBJECT = None

def _remote_object():
    """RPC target: hand back whatever RRef this process has currently staged."""
    # Read-only access — no `global` declaration needed.
    return _TEMP_AG_REMOTE_OBJECT
def all_gather_remote_objects(obj: Any) -> List[rpc.RRef]:
    """Exchange RRefs to every rank's *obj* across the whole process group.

    Each rank stages an RRef to its local object in a module-level slot,
    then asks every rank (including itself) for that slot over RPC. The
    two barriers keep the staging slot alive until all peers have read it.
    """
    global _TEMP_AG_REMOTE_OBJECT
    _TEMP_AG_REMOTE_OBJECT = rpc.RRef(obj)
    dist.barrier()  # ensure every rank has staged its RRef before anyone reads

    futures = [
        rpc.rpc_async(get_worker_info(i), _remote_object)
        for i in range(dist.get_world_size())
    ]
    rrefs: List[rpc.RRef] = []
    for fut in futures:
        fut.wait()
        rrefs.append(fut.value())

    dist.barrier()  # all peers are done reading; safe to clear the slot
    _TEMP_AG_REMOTE_OBJECT = None
    return rrefs
def ccl_init_method() -> str:
    """Build the ``tcp://`` init URL for the collective (c10d) process group.

    Reads ``MASTER_ADDR`` and ``MASTER_PORT`` from the environment;
    raises KeyError if either is unset.
    """
    host = os.environ["MASTER_ADDR"]
    port = int(os.environ["MASTER_PORT"])
    return "tcp://{}:{}".format(host, port)
def rpc_init_method() -> str:
    """Build the ``tcp://`` init URL for the RPC layer.

    Uses the same ``MASTER_ADDR`` as the collective group but shifts the
    port by +1 so both rendezvous points can coexist on one host.
    """
    host = os.environ["MASTER_ADDR"]
    port = int(os.environ["MASTER_PORT"]) + 1
    return f"tcp://{host}:{port}"
\ No newline at end of file
import torch # import torch
from torch import Tensor # from torch import Tensor
from typing import * # from typing import *
from torch_scatter import scatter_sum # from torch_scatter import scatter_sum
from starrygl.core.route import Route # from starrygl.core.route import Route
def compute_in_degree(edge_index: Tensor, route: Route) -> Tensor: # def compute_in_degree(edge_index: Tensor, route: Route) -> Tensor:
dst_size = route.src_size # dst_size = route.src_size
x = torch.ones(edge_index.size(1), dtype=torch.long, device=edge_index.device) # x = torch.ones(edge_index.size(1), dtype=torch.long, device=edge_index.device)
in_deg = scatter_sum(x, edge_index[1], dim=0, dim_size=dst_size) # in_deg = scatter_sum(x, edge_index[1], dim=0, dim_size=dst_size)
in_deg, _ = route.forward_a2a(in_deg) # in_deg, _ = route.forward_a2a(in_deg)
return in_deg # return in_deg
def compute_out_degree(edge_index: Tensor, route: Route) -> Tensor: # def compute_out_degree(edge_index: Tensor, route: Route) -> Tensor:
src_size = route.dst_size # src_size = route.dst_size
x = torch.ones(edge_index.size(1), dtype=torch.long, device=edge_index.device) # x = torch.ones(edge_index.size(1), dtype=torch.long, device=edge_index.device)
out_deg = scatter_sum(x, edge_index[0], dim=0, dim_size=src_size) # out_deg = scatter_sum(x, edge_index[0], dim=0, dim_size=src_size)
out_deg, _ = route.backward_a2a(out_deg) # out_deg, _ = route.backward_a2a(out_deg)
out_deg, _ = route.forward_a2a(out_deg) # out_deg, _ = route.forward_a2a(out_deg)
return out_deg # return out_deg
def compute_gcn_norm(edge_index: Tensor, route: Route) -> Tensor: # def compute_gcn_norm(edge_index: Tensor, route: Route) -> Tensor:
in_deg = compute_in_degree(edge_index, route) # in_deg = compute_in_degree(edge_index, route)
out_deg = compute_out_degree(edge_index, route) # out_deg = compute_out_degree(edge_index, route)
a = in_deg[edge_index[0]].pow(-0.5) # a = in_deg[edge_index[0]].pow(-0.5)
b = out_deg[edge_index[0]].pow(-0.5) # b = out_deg[edge_index[0]].pow(-0.5)
x = a * b # x = a * b
x[x.isinf()] = 0.0 # x[x.isinf()] = 0.0
x[x.isnan()] = 0.0 # x[x.isnan()] = 0.0
return x # return x
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
import torch.distributed.rpc as rpc
from torch import Tensor from torch import Tensor
from typing import * from typing import *
import os from torch_sparse import SparseTensor
import time
import psutil
from starrygl.nn import *
from starrygl.graph import DistGraph
from starrygl.loader import NodeLoader, NodeHandle, TensorBuffer, RouteContext
from starrygl.parallel import init_process_group, convert_parallel_model from starrygl.parallel import init_process_group, convert_parallel_model
from starrygl.parallel import compute_gcn_norm, SyncBatchNorm, with_nccl from starrygl.utils import partition_load, main_print, sync_print
from starrygl.utils import train_epoch, eval_epoch, partition_load, main_print class SimpleGNNConv(nn.Module):
def __init__(self, in_channels: int, out_channels: int) -> None:
    """One conv step: dense projection, sparse aggregation, then LayerNorm.

    Args:
        in_channels: size of the incoming node features.
        out_channels: size of the produced node features.
    """
    super().__init__()
    self.linear = nn.Linear(in_channels, out_channels)
    self.norm = nn.LayerNorm(out_channels)
def forward(self, x: Tensor, adj_t: SparseTensor) -> Tensor:
    """Project node features, aggregate them along ``adj_t``, and normalize."""
    h = self.linear(x)
    h = adj_t @ h  # sparse-dense matmul: neighborhood aggregation
    return self.norm(h)
class SimpleGNN(nn.Module):
    """Two-layer demo GNN: conv -> push/pull hidden cache -> conv -> linear head.

    The forward pass overlaps communication with computation: feature
    fetches are issued asynchronously up front and awaited only when the
    results are actually needed.
    """
    def __init__(self,
        in_channels: int,
        hidden_channels: int,
        out_channels: int,
    ) -> None:
        super().__init__()
        self.conv1 = SimpleGNNConv(in_channels, hidden_channels)
        self.conv2 = SimpleGNNConv(hidden_channels, hidden_channels)
        self.fc_out = nn.Linear(hidden_channels, out_channels)
    def forward(self, handle: NodeHandle, buffers: List[TensorBuffer]) -> Tensor:
        # Kick off both async fetches before any compute.
        # NOTE(review): presumably buffers[0] holds layer-0 input features and
        # buffers[1] the hidden-layer history cache — confirm against the caller.
        futs = [
            handle.get_src_feats(buffers[0]),
            handle.get_ext_feats(buffers[1]),
        ]
        with RouteContext() as ctx:
            x = futs[0].wait()
            x = self.conv1(x, handle.adj_t)
            # Push fresh hidden states into the cache and pull remote ones;
            # returns the combined features plus outstanding comm futures.
            x, f = handle.push_and_pull(x, futs[1], buffers[1])
            x = self.conv2(x, handle.adj_t)
            ctx.add_futures(f)  # wait for all futures once this batch's inference is done
        x = self.fc_out(x)
        return x
if __name__ == "__main__": if __name__ == "__main__":
# 启动分布式进程组,并分配计算设备 # 启动分布式进程组,并分配计算设备
device = init_process_group(backend="nccl") device = init_process_group(backend="nccl")
# 加载数据集 # 加载数据集
pdata = partition_load("./cora", algo="metis").to(device) pdata = partition_load("./cora", algo="metis")
g = DistGraph(ids=pdata.ids, edge_index=pdata.edge_index) loader = NodeLoader(pdata.ids, pdata.edge_index, device)
# 创建历史缓存
hidden_size = 64
buffers: List[TensorBuffer] = [
TensorBuffer(pdata.num_features, loader.src_size, loader.route),
TensorBuffer(hidden_size, loader.src_size, loader.route),
]
# 设置节点初始特征,并预同步到其它分区
buffers[0].data[:loader.dst_size] = pdata.x
buffers[0].broadcast()
# g.args["async_op"] = True # 创建模型
g.args["sample_k"] = 20 net = SimpleGNN(pdata.num_features, hidden_size, pdata.num_classes).to(device)
net = convert_parallel_model(net)
opt = torch.optim.Adam(net.parameters(), lr=1e-3)
# 训练阶段
for ep in range(1, 100+1):
epoch_loss = 0.0
net.train()
for handle in loader.iter(128):
fut_m = handle.get_dst_feats(pdata.train_mask)
fut_y = handle.get_dst_feats(pdata.y)
h = net(handle, buffers)
train_mask = fut_m.wait()
logits = h[train_mask]
if logits.size(0) > 0:
y = fut_y.wait()[train_mask]
loss = nn.CrossEntropyLoss()(logits, y)
opt.zero_grad()
loss.backward()
opt.step()
epoch_loss += loss.item()
main_print(ep, epoch_loss)
rpc.shutdown()
# import torch
# import torch.nn as nn
# from torch import Tensor
# from typing import *
# import os
# import time
# import psutil
# from starrygl.nn import *
# from starrygl.graph import DistGraph
# from starrygl.parallel import init_process_group, convert_parallel_model
# from starrygl.parallel import compute_gcn_norm, SyncBatchNorm, with_nccl
# from starrygl.utils import train_epoch, eval_epoch, partition_load, main_print, sync_print
# if __name__ == "__main__":
# # 启动分布式进程组,并分配计算设备
# device = init_process_group(backend="nccl")
g.edata["gcn_norm"] = compute_gcn_norm(g) # # 加载数据集
g.ndata["x"] = pdata.x # pdata = partition_load("./cora", algo="metis").to(device)
g.ndata["y"] = pdata.y # g = DistGraph(ids=pdata.ids, edge_index=pdata.edge_index)
# 定义GAT图神经网络模型 # # g.args["async_op"] = True
net = GCN( # # g.args["num_samples"] = 20
g=g,
layer_options=BasicLayerOptions(
in_channels=pdata.num_features,
hidden_channels=64,
num_layers=2,
out_channels=pdata.num_classes,
norm="batchnorm",
),
input_options=BasicInputOptions(
straight_enabled=True,
),
jk_options=BasicJKOptions(
jk_mode=None,
),
straight_options=BasicStraightOptions(
enabled=True,
),
).to(device)
# 转换成分布式并行版本 # g.edata["gcn_norm"] = compute_gcn_norm(g)
net = convert_parallel_model(net) # g.ndata["x"] = pdata.x
# g.ndata["y"] = pdata.y
# # 定义GAT图神经网络模型
# net = ShrinkGCN(
# g=g,
# layer_options=BasicLayerOptions(
# in_channels=pdata.num_features,
# hidden_channels=64,
# num_layers=3,
# out_channels=pdata.num_classes,
# norm="batchnorm",
# ),
# input_options=BasicInputOptions(
# straight_enabled=True,
# straight_num_samples = 200,
# ),
# straight_options=BasicStraightOptions(
# enabled=True,
# num_samples = 20,
# # beta=1.1,
# ),
# ).to(device)
# # 转换成分布式并行版本
# net = convert_parallel_model(net)
# 定义优化器 # # 定义优化器
opt = torch.optim.Adam(net.parameters(), lr=0.01, weight_decay=5e-4) # opt = torch.optim.Adam(net.parameters(), lr=0.01, weight_decay=5e-4)
avg_mem = 0.0 # avg_mem = 0.0
avg_dur = 0.0 # avg_dur = 0.0
avg_num = 0 # avg_num = 0
# 开始训练 # # 开始训练
best_val_acc = best_test_acc = 0 # best_val_acc = best_test_acc = 0
for ep in range(1, 10+1): # for ep in range(1, 50+1):
time_start = time.time() # time_start = time.time()
train_loss, train_acc = train_epoch(net, opt, g, pdata.train_mask) # train_loss, train_acc = train_epoch(net, opt, g, pdata.train_mask)
val_loss, val_acc = eval_epoch(net, g, pdata.val_mask) # val_loss, val_acc = eval_epoch(net, g, pdata.val_mask)
test_loss, test_acc = eval_epoch(net, g, pdata.test_mask) # test_loss, test_acc = eval_epoch(net, g, pdata.test_mask)
# val_loss, val_acc = train_loss, train_acc # # val_loss, val_acc = train_loss, train_acc
# test_loss, test_acc = train_loss, train_acc # # test_loss, test_acc = train_loss, train_acc
if val_acc > best_val_acc: # if val_acc > best_val_acc:
best_val_acc = val_acc # best_val_acc = val_acc
best_test_acc = test_acc # best_test_acc = test_acc
duration = time.time() - time_start # duration = time.time() - time_start
if with_nccl(): # if with_nccl():
cur_mem = torch.cuda.memory_reserved() # cur_mem = torch.cuda.memory_reserved()
else: # else:
cur_mem = psutil.Process(os.getpid()).memory_info().rss # cur_mem = psutil.Process(os.getpid()).memory_info().rss
cur_mem_mb = round(cur_mem / 1024**2) # cur_mem_mb = round(cur_mem / 1024**2)
if ep > 1: # if ep > 1:
avg_mem += cur_mem # avg_mem += cur_mem
avg_dur += duration # avg_dur += duration
avg_num += 1 # avg_num += 1
main_print( # main_print(
f"ep: {ep}, mem: {cur_mem_mb}MiB, duration: {duration:.2f}s, " # f"ep: {ep}, mem: {cur_mem_mb}MiB, duration: {duration:.2f}s, "
f"loss: [{train_loss:.4f}/{val_loss:.4f}/{test_loss:.6f}], " # f"loss: [{train_loss:.4f}/{val_loss:.4f}/{test_loss:.6f}], "
f"accuracy: [{train_acc:.4f}/{val_acc:.4f}/{test_acc:.4f}], " # f"accuracy: [{train_acc:.4f}/{val_acc:.4f}/{test_acc:.4f}], "
f"best_accuracy: {best_test_acc:.4f}") # f"best_accuracy: {best_test_acc:.4f}")
avg_mem = round(avg_mem / avg_num / 1024**2) # avg_mem = round(avg_mem / avg_num / 1024**2)
avg_dur = avg_dur / avg_num # avg_dur = avg_dur / avg_num
main_print(f"average memory: {avg_mem}MiB, average duration: {avg_dur:.2f}s") # main_print(f"average memory: {avg_mem}MiB, average duration: {avg_dur:.2f}s")
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment