Commit f4af3231 by Wenjie Huang

rpc based

parent 3f94b191
@@ -159,4 +159,7 @@ cython_debug/
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
cora/
\ No newline at end of file
cora/
/test_*
/*.ipynb
/s.py
\ No newline at end of file
import torch
import torch.nn as nn
import torch.distributed as dist
import torch.distributed.rpc as rpc
from torch.futures import Future
from torch import Tensor
from typing import *
from starrygl.parallel import all_gather_remote_objects
from .route import Route
from threading import Lock
# removed by this commit: the old per-layer TensorBuffers container
class TensorBuffers(nn.Module):
    def __init__(self, src_size: int, channels: Tuple[Union[int, Tuple[int]]]) -> None:
        super().__init__()
        self.src_size = src_size
        self.num_layers = len(channels)
        for i, s in enumerate(channels):
            s = (s,) if isinstance(s, int) else s
            self.register_buffer(f"data_{i}", torch.zeros(src_size, *s), persistent=False)

    def get(self, i: int, index: Optional[Tensor]) -> Tensor:
        i = self._idx(i)
        if index is None:
            return self.get_buffer(f"data_{i}")
        else:
            return self.get_buffer(f"data_{i}")[index]

    def set(self, i: int, index: Optional[Tensor], value: Tensor):
        i = self._idx(i)
        if index is None:
            self.get_buffer(f"data_{i}")[:,...] = value
        else:
            self.get_buffer(f"data_{i}")[index] = value

    def add(self, i: int, index: Optional[Tensor], value: Tensor):
        i = self._idx(i)
        if index is None:
            self.get_buffer(f"data_{i}")[:,...] += value
        else:
            self.get_buffer(f"data_{i}")[index] += value

    def zero_grad(self):
        for name, grad in self.named_buffers("data_", recurse=False):
            grad.zero_()

    def _idx(self, i: int) -> int:
        assert -self.num_layers < i and i < self.num_layers
        return (self.num_layers + i) % self.num_layers

# added by this commit: an RPC-backed buffer whose rows are partitioned across ranks via Route
class TensorBuffer(nn.Module):
    def __init__(self,
        channels: int,
        num_nodes: int,
        route: Route,
    ) -> None:
        super().__init__()
        self.channels = channels
        self.num_nodes = num_nodes
        self.route = route
        self.local_lock = Lock()
        self.register_buffer("_data", torch.zeros(num_nodes, channels), persistent=False)
        self.rrefs = all_gather_remote_objects(self)

    @property
    def data(self) -> Tensor:
        return self.get_buffer("_data")

    @property
    def device(self) -> torch.device:
        return self.data.device

    # local_* methods operate on this rank's shard; lock=True re-enters the method with the buffer lock held
    def local_get(self, index: Optional[Tensor] = None, lock: bool = True) -> Tensor:
        if lock:
            with self.local_lock:
                return self.local_get(index, lock=False)
        if index is None:
            return self.data
        else:
            return self.data[index]

    def local_set(self, value: Tensor, index: Optional[Tensor] = None, lock: bool = True):
        if lock:
            with self.local_lock:
                return self.local_set(value, index, lock=False)
        if index is None:
            self.data.copy_(value)
        else:
            # value = value.to(self.device)
            self.data[index] = value

    def local_add(self, value: Tensor, index: Optional[Tensor] = None, lock: bool = True):
        if lock:
            with self.local_lock:
                return self.local_add(value, index, lock=False)
        if index is None:
            self.data.add_(value)
        else:
            # value = value.to(self.device)
            self.data[index] += value

    def local_cls(self, index: Optional[Tensor] = None, lock: bool = True):
        if lock:
            with self.local_lock:
                return self.local_cls(index, lock=False)
        if index is None:
            self.data.zero_()
        else:
            self.data[index] = 0

    # remote_* methods run the corresponding local_* method on the destination rank via RPC
    def remote_get(self, dst: int, index: Tensor, lock: bool = True):
        return TensorBuffer._remote_call(TensorBuffer.local_get, self.rrefs[dst], index=index, lock=lock)

    def remote_set(self, dst: int, value: Tensor, index: Tensor, lock: bool = True):
        return TensorBuffer._remote_call(TensorBuffer.local_set, self.rrefs[dst], value, index=index, lock=lock)

    def remote_add(self, dst: int, value: Tensor, index: Tensor, lock: bool = True):
        return TensorBuffer._remote_call(TensorBuffer.local_add, self.rrefs[dst], value, index=index, lock=lock)

    # all_remote_* methods split `index` by owner rank (via Route.parts_iter) and fan out one RPC per rank
    def all_remote_get(self, index: Tensor, lock: bool = True):
        def cb0(idx):
            def f(x: torch.futures.Future[Tensor]):
                return x.value(), idx
            return f
        def cb1(buf):
            def f(xs: torch.futures.Future[List[torch.futures.Future]]) -> Tensor:
                for x in xs.value():
                    dat, idx = x.value()
                    # print(dat.size(), idx.size())
                    buf[idx] += dat
                return buf
            return f
        futs = []
        for i, (idx, remote_idx) in enumerate(self.route.parts_iter(index)):
            futs.append(self.remote_get(i, remote_idx, lock=lock).then(cb0(idx)))
        futs = torch.futures.collect_all(futs)
        buf = torch.zeros(index.size(0), self.channels, dtype=self.data.dtype, device=self.data.device)
        return futs.then(cb1(buf))

    def all_remote_set(self, value: Tensor, index: Tensor, lock: bool = True):
        futs = []
        for i, (idx, remote_idx) in enumerate(self.route.parts_iter(index)):
            futs.append(self.remote_set(i, value[idx], remote_idx, lock=lock))
        return torch.futures.collect_all(futs)

    def all_remote_add(self, value: Tensor, index: Tensor, lock: bool = True):
        futs = []
        for i, (idx, remote_idx) in enumerate(self.route.parts_iter(index)):
            futs.append(self.remote_add(i, value[idx], remote_idx, lock=lock))
        return torch.futures.collect_all(futs)

    # refresh the whole local shard by pulling every row from its owning rank
    def broadcast(self, barrier: bool = True):
        if barrier:
            dist.barrier()
        index = torch.arange(self.num_nodes, dtype=torch.long, device=self.data.device)
        data = self.all_remote_get(index, lock=True).wait()
        self.local_set(data, lock=True)
        if barrier:
            dist.barrier()

    # def remote_get(self, dst: int, i: int, index: Optional[Tensor] = None, global_index: bool = False, async_op: bool = False) -> Union[Tensor, Future]:
    #     return TensorBuffer._remote_call(async_op, TensorBuffer.local_get, self.rrefs[dst], i, index, global_index = global_index)
    # def remote_set(self, dst: int, i: int, value: Tensor, index: Optional[Tensor], global_index: bool = False, async_op: bool = False) -> Optional[Future]:
    #     return TensorBuffer._remote_call(async_op, TensorBuffer.local_set, self.rrefs[dst], i, value, index, global_index = global_index)
    # def remote_add(self, dst: int, i: int, value: Tensor, index: Optional[Tensor] = None, global_index: bool = False, async_op: bool = False) -> Optional[Future]:
    #     return TensorBuffer._remote_call(async_op, TensorBuffer.local_add, self.rrefs[dst], i, value, index, global_index = global_index)
    # def async_scatter_fw_set(self, i: int, value: Tensor, index: Optional[Tensor] = None) -> Tuple[Future]:
    #     futures: List[Future] = []
    #     for dst in range(self.world_size):
    #         val, ind = self.router.fw_value_index(dst, value, index)
    #         futures.append(self.remote_set(dst, i, val, ind, global_index=True, async_op=True))
    #     return tuple(futures)
    # def async_scatter_fw_add(self, i: int, value: Tensor, index: Optional[Tensor] = None) -> Tuple[Future]:
    #     futures: List[Future] = []
    #     for dst in range(self.world_size):
    #         val, ind = self.router.fw_value_index(dst, value, index)
    #         futures.append(self.remote_add(dst, i, val, ind, global_index=True, async_op=True))
    #     return tuple(futures)
    # def async_scatter_bw_set(self, i: int, value: Tensor, index: Optional[Tensor] = None) -> Tuple[Future]:
    #     futures: List[Future] = []
    #     for dst in range(self.world_size):
    #         val, ind = self.router.bw_value_index(dst, value, index)
    #         futures.append(self.remote_set(dst, i, val, ind, global_index=True, async_op=True))
    #     return tuple(futures)
    # def async_scatter_bw_add(self, i: int, value: Tensor, index: Optional[Tensor] = None) -> Tuple[Future]:
    #     futures: List[Future] = []
    #     for dst in range(self.world_size):
    #         val, ind = self.router.bw_value_index(dst, value, index)
    #         futures.append(self.remote_add(dst, i, val, ind, global_index=True, async_op=True))
    #     return tuple(futures)
    # def _idx_data(self, i: int) -> Tuple[int, Tensor]:
    #     assert -self.num_layers < i and i < self.num_layers
    #     i = (self.num_layers + i) % self.num_layers
    #     return i, self.get_buffer(f"data{i}")

    @staticmethod
    def _remote_call(method, rref: rpc.RRef, *args, **kwargs):
        args = (method, rref) + args
        return rpc.rpc_async(rref.owner(), TensorBuffer._method_call, args=args, kwargs=kwargs)

    # executed on the owning rank: translate the global ids in `index` to that rank's local ids, then dispatch
    @staticmethod
    def _method_call(method, rref: rpc.RRef, *args, **kwargs):
        self: TensorBuffer = rref.local_value()
        index = kwargs["index"]
        kwargs["index"] = self.route.to_local_ids(index)
        return method(self, *args, **kwargs)
\ No newline at end of file
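For illustration, a minimal sketch of how Route and TensorBuffer compose on a single rank (not part of the commit; it assumes init_process_group() has already brought up both the process group and RPC, and that dst_ids, edge_index and grad are placeholders for this rank's partition data):

    # hypothetical driver code, for illustration only
    route, local_edge_index = Route.from_edge_index(dst_ids, edge_index)
    buf = TensorBuffer(channels=64, num_nodes=route.src_size, route=route)

    # pull this rank's src-node rows from the partitions that own them
    index = torch.arange(route.src_size, dtype=torch.long, device=buf.device)
    feats = buf.all_remote_get(index, lock=True).wait()

    # push partial results (e.g. accumulated gradients, shaped [index.size(0), 64]) back to the owners
    buf.all_remote_add(grad, index, lock=True).wait()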
import torch
from contextlib import contextmanager
from torch import Tensor
from typing import *
class RouteContext:
    def __init__(self) -> None:
        self._futs: List[torch.futures.Future] = []

    def synchronize(self):
        for fut in self._futs:
            fut.wait()
        self._futs = []

    def add_futures(self, *futs):
        for fut in futs:
            assert isinstance(fut, torch.futures.Future)
            self._futs.append(fut)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        if exc_type is not None:
            return False  # propagate the original exception without waiting on the pending futures
        self.synchronize()
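For illustration, a sketch of the intended usage (buffer, h and index are placeholders for whatever the caller holds; all_remote_add returns a torch.futures.Future, which is collected here and waited on when the block exits):

    with RouteContext() as ctx:
        fut = buffer.all_remote_add(h, index, lock=True)
        ctx.add_futures(fut)
        # ... overlap local compute here; ctx.synchronize() runs on exit ...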
import torch
import torch.nn as nn
import torch.distributed as dist
import torch.distributed.rpc as rpc
from torch import Tensor
from typing import *
from starrygl.parallel import all_gather_remote_objects
from .utils import init_local_edge_index
class Route(nn.Module):
    def __init__(self,
        src_ids: Tensor,
        dst_size: int,
    ) -> None:
        super().__init__()
        self.register_buffer("_src_ids", src_ids, persistent=False)
        self.dst_size = dst_size

        self._init_nids_mapper()
        self._init_part_mapper()

    @staticmethod
    def from_edge_index(dst_ids: Tensor, edge_index: Tensor):
        src_ids, local_edge_index = init_local_edge_index(dst_ids, edge_index)
        return Route(src_ids, dst_ids.size(0)), local_edge_index

    @property
    def src_ids(self) -> Tensor:
        return self.get_buffer("_src_ids")

    @property
    def src_size(self) -> int:
        return self.src_ids.size(0)

    @property
    def dst_ids(self) -> Tensor:
        return self.src_ids[:self.dst_size]

    @property
    def ext_ids(self) -> Tensor:
        return self.src_ids[self.dst_size:]

    @property
    def ext_size(self) -> int:
        return self.src_size - self.dst_size

    def parts_iter(self, local_ids: Tensor) -> Iterator[Tuple[Tensor, Tensor]]:
        world_size = dist.get_world_size()
        part_mapper = self.part_mapper[local_ids]
        for i in range(world_size):
            # part_ids = local_ids[part_mapper == i]
            part_ids = torch.where(part_mapper == i)[0]
            glob_ids = self.src_ids[part_ids]
            yield part_ids, glob_ids

    def to_local_ids(self, ids: Tensor) -> Tensor:
        return self.nids_mapper[ids]

    def _init_nids_mapper(self):
        num_nodes: int = self.src_ids.max().item() + 1
        device: torch.device = self.src_ids.device

        # (2**62-1)*2+1 == 2**63-1, the largest int64, used as an "unmapped" sentinel
        mapper = torch.empty(num_nodes, dtype=torch.long, device=device).fill_((2**62-1)*2+1)
        mapper[self.src_ids] = torch.arange(self.src_ids.size(0), dtype=torch.long, device=device)
        self.register_buffer("nids_mapper", mapper, persistent=False)

    def _init_part_mapper(self):
        device: torch.device = self.src_ids.device
        nids_mapper = self.get_buffer("nids_mapper")

        mapper = torch.empty(self.src_size, dtype=torch.int32, device=device).fill_(-1)
        for i, dst_ids in enumerate(all_gather_remote_objects(self.dst_ids)):
            dst_ids: Tensor = dst_ids.to_here().to(device)
            dst_ids = dst_ids[dst_ids < nids_mapper.size(0)]

            dst_local_inds = nids_mapper[dst_ids]
            dst_local_mask = dst_local_inds != ((2**62-1)*2+1)

            dst_local_inds = dst_local_inds[dst_local_mask]
            mapper[dst_local_inds] = i
        assert (mapper >= 0).all()
        self.register_buffer("part_mapper", mapper, persistent=False)
# class RouteTable(nn.Module):
# def __init__(self,
# src_ids: Tensor,
# dst_size: int,
# ) -> None:
# super().__init__()
# self.register_buffer("src_ids", src_ids)
# self.src_size: int = src_ids.size(0)
# self.dst_size = dst_size
# assert self.src_size >= self.dst_size
# self._init_mapper()
# rank, world_size = rank_world_size()
# rrefs = all_gather_remote_objects(self)
# gather_futures: List[torch.futures.Future] = []
# for i in range(world_size):
# rref = rrefs[i]
# fut = rpc.rpc_async(rref.owner(), RouteTable._get_dst_ids, args=(rref,))
# gather_futures.append(fut)
# max_src_ids: int = src_ids.max().item()
# smp = torch.empty(max_src_ids + 1, dtype=torch.long, device=src_ids.device).fill_((2**62-1)*2+1)
# smp[src_ids] = torch.arange(src_ids.size(0), dtype=smp.dtype, device=smp.device)
# self.fw_masker = RouteMasker(self.dst_size, world_size)
# self.bw_masker = RouteMasker(self.src_size, world_size)
# dist.barrier()
# scatter_futures: List[torch.futures.Future] = []
# for i in range(world_size):
# fut = gather_futures[i]
# s_ids: Tensor = src_ids
# d_ids: Tensor = fut.wait()
# num_ids: int = max(s_ids.max().item(), d_ids.max().item()) + 1
# imp = torch.zeros(num_ids, dtype=torch.long, device=self._get_device())
# imp[s_ids] += 1
# imp[d_ids] += 1
# ind = torch.where(imp > 1)[0]
# imp.fill_((2**62-1)*2+1)
# imp[d_ids] = torch.arange(d_ids.size(0), dtype=imp.dtype, device=imp.device)
# s_ind = smp[ind]
# d_ind = imp[ind]
# rref = rrefs[i]
# fut = rpc.rpc_async(rref.owner(), RouteTable._set_fw_mask, args=(rref, rank, d_ind))
# scatter_futures.append(fut)
# bw_mask = torch.zeros(self.src_size, dtype=torch.bool).index_fill_(0, s_ind, 1)
# self.bw_masker.set_mask(i, bw_mask)
# for fut in scatter_futures:
# fut.wait()
# dist.barrier()
# # def fw_index(self, dst: int, index: Tensor) -> Tensor:
# # mask = self.fw_masker.select(dst, index)
# # return self.get_global_index(index[mask])
# # def bw_index(self, dst: int, index: Tensor) -> Tensor:
# # mask = self.bw_masker.select(dst, index)
# # return self.get_global_index(index[mask])
# def fw_value_index(self, dst: int, value: Tensor, index: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]:
# if index is None:
# assert value.size(0) == self.dst_size
# mask = self.fw_masker.select(dst)
# return value[mask], self.get_buffer("src_ids")[:self.dst_size][mask]
# else:
# assert value.size(0) == index.size(0)
# mask = self.fw_masker.select(dst, index)
# value, index = value[mask], index[mask]
# return value, self.get_global_index(index)
# def bw_value_index(self, dst: int, value: Tensor, index: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]:
# if index is None:
# assert value.size(0) == self.src_size
# mask = self.bw_masker.select(dst)
# return value[mask], self.get_buffer("src_ids")[mask]
# else:
# assert value.size(0) == index.size(0)
# mask = self.bw_masker.select(dst, index)
# value, index = value[mask], index[mask]
# return value, self.get_global_index(index)
# def get_global_index(self, index: Tensor) -> Tensor:
# return self.get_buffer("src_ids")[index]
# def get_local_index(self, index: Tensor) -> Tensor:
# return self.get_buffer("mapper")[index]
# @staticmethod
# def _get_dst_ids(rref: rpc.RRef):
# self: RouteTable = rref.local_value()
# src_ids = self.get_buffer("src_ids")
# return src_ids[:self.dst_size]
# @staticmethod
# def _set_fw_mask(rref: rpc.RRef, dst: int, fw_ind: Tensor):
# self: RouteTable = rref.local_value()
# fw_mask = torch.zeros(self.dst_size, dtype=torch.bool).index_fill_(0, fw_ind, 1)
# self.fw_masker.set_mask(dst, fw_mask)
# def _get_device(self):
# return self.get_buffer("src_ids").device
# def _init_mapper(self):
# src_ids = self.get_buffer("src_ids")
# num_nodes: int = src_ids.max().item() + 1
# mapper = torch.empty(num_nodes, dtype=torch.long, device=src_ids.device).fill_((2**62-1)*2+1)
# mapper[src_ids] = torch.arange(src_ids.size(0), dtype=torch.long)
# self.register_buffer("mapper", mapper)
# class RouteMasker(nn.Module):
# def __init__(self,
# num_nodes: int,
# world_size: int,
# ) -> None:
# super().__init__()
# m = (world_size + 7) // 8
# self.num_nodes = num_nodes
# self.world_size = world_size
# self.register_buffer("data", torch.zeros(m, num_nodes, dtype=torch.uint8))
# def forward(self, i: int, index: Optional[Tensor] = None) -> Tensor:
# return self.select(i, index)
# def select(self, i: int, index: Optional[Tensor] = None) -> Tensor:
# i, data = self._idx_data(i)
# k, r = i // 8, i % 8
# if index is None:
# mask = data[k].bitwise_right_shift(r).bitwise_and_(1)
# else:
# mask = data[k][index].bitwise_right_shift_(r).bitwise_and_(1)
# return mask.type(dtype=torch.bool)
# def set_mask(self, i: int, mask: Tensor) -> Tensor:
# assert mask.size(0) == self.num_nodes
# i, data = self._idx_data(i)
# k, r = i // 8, i % 8
# data[k] &= ~(1<<r)
# data[k] |= mask.type(torch.uint8).bitwise_left_shift_(r)
# def _idx_data(self, i: int) -> Tuple[int, Tensor]:
# assert -self.world_size < i and i < self.world_size
# i = (i + self.world_size) % self.world_size
# return i, self.get_buffer("data")
\ No newline at end of file
@@ -50,11 +50,6 @@ def calc_max_ids(*ids: Tensor) -> int:
x = [t.max().item() if t.numel() > 0 else 0 for t in ids]
return max(*x)
def collect_feat0(src_ids: Tensor, dst_ids: Tensor, feat0: Tensor):
device = get_compute_device()
route = Route(dst_ids.to(device), src_ids.to(device))
return route.forward_a2a(feat0.to(device))[0].to(feat0.device)
def local_partition_fn(dst_size: Tensor, edge_index: Tensor, num_parts: int) -> Tensor:
edge_index = edge_index[:, edge_index[0] < dst_size]
return metis_partition(edge_index, dst_size, num_parts)[0]
\ No newline at end of file
# previously active, fully commented out by this commit:
# import torch
# import torch.nn as nn
# import torch.distributed as dist
# from torch import Tensor
# from typing import *
# from starrygl.loader import BatchHandle

# class BaseLayer(nn.Module):
#     def __init__(self) -> None:
#         super().__init__()

#     def forward(self, x: Tensor, edge_index: Tensor, edge_attr: Optional[Tensor] = None) -> Tensor:
#         return x

#     def update_forward(self, handle: BatchHandle, edge_index: Tensor, edge_attr: Optional[Tensor] = None):
#         x = handle.fetch_feat()
#         with torch.no_grad():
#             x = self.forward(x, edge_index, edge_attr)
#         handle.update_feat(x)

#     def block_backward(self, handle: BatchHandle, edge_index: Tensor, edge_attr: Optional[Tensor] = None):
#         x = handle.fetch_feat().requires_grad_()
#         g = handle.fetch_grad()
#         self.forward(x, edge_index, edge_attr).backward(g)
#         handle.accumulate_grad(x.grad)
#         x.grad = None

#     def all_reduce_grad(self):
#         for p in self.parameters():
#             if p.grad is not None:
#                 dist.all_reduce(p.grad, op=dist.ReduceOp.SUM)

# class BaseModel(nn.Module):
#     def __init__(self,
#         num_features: int,
#         layers: List[int],
#         prev_layer: bool = False,
#         post_layer: bool = False,
#     ) -> None:
#         super().__init__()

#     def init_prev_layer(self) -> Tensor:
#         pass

#     def init_post_layer(self) -> Tensor:
#         pass

#     def init_conv_layer(self) -> Tensor:
#         pass
\ No newline at end of file
import torch
import torch.nn as nn
import torch.distributed as dist
import torch.distributed.rpc as rpc
import os
from typing import *
@@ -15,29 +16,36 @@ def convert_parallel_model(
net = SyncBatchNorm.convert_sync_batchnorm(net)
net = nn.parallel.DistributedDataParallel(net,
find_unused_parameters=find_unused_parameters,
# broadcast_buffers=False,
)
# for name, buffer in net.named_buffers():
# if name.endswith("last_embd"):
# continue
# if name.endswith("last_w"):
# continue
# dist.broadcast(buffer, src=0)
return net
def init_process_group(backend: str = "gloo") -> torch.device:
    rank = int(os.getenv("RANK") or os.getenv("OMPI_COMM_WORLD_RANK"))
    world_size = int(os.getenv("WORLD_SIZE") or os.getenv("OMPI_COMM_WORLD_SIZE"))

    dist.init_process_group(
        backend=backend,
        init_method=ccl_init_method(),
        rank=rank, world_size=world_size,
    )

    rpc_backend_options = rpc.TensorPipeRpcBackendOptions()
    rpc_backend_options.init_method = rpc_init_method()
    for i in range(world_size):
        rpc_backend_options.set_device_map(f"worker{i}", {rank: i})
    rpc.init_rpc(
        name=f"worker{rank}",
        rank=rank, world_size=world_size,
        rpc_backend_options=rpc_backend_options,
    )

    local_rank = os.getenv("LOCAL_RANK") or os.getenv("OMPI_COMM_WORLD_LOCAL_RANK")
    if local_rank is not None:
        local_rank = int(local_rank)
    # removed by this commit: dist.init_process_group(backend)
    if backend == "nccl" or backend == "mpi":
        # removed by this commit:
        #     if local_rank is None:
        #         device = torch.device(f"cuda:{dist.get_rank()}")
        #     else:
        #         device = torch.device(f"cuda:{local_rank}")
        device = torch.device(f"cuda:{local_rank or rank}")
        torch.cuda.set_device(device)
    else:
        device = torch.device("cpu")
@@ -46,7 +54,48 @@ def init_process_group(backend: str = "gloo") -> torch.device:
_COMPUTE_DEVICE = device
return device
def rank_world_size() -> Tuple[int, int]:
return dist.get_rank(), dist.get_world_size()
def get_worker_info(rank: Optional[int] = None) -> rpc.WorkerInfo:
rank = dist.get_rank() if rank is None else rank
return rpc.get_worker_info(f"worker{rank}")
_COMPUTE_DEVICE = torch.device("cpu")
def get_compute_device() -> torch.device:
global _COMPUTE_DEVICE
return _COMPUTE_DEVICE
_TEMP_AG_REMOTE_OBJECT = None
def _remote_object():
global _TEMP_AG_REMOTE_OBJECT
return _TEMP_AG_REMOTE_OBJECT
def all_gather_remote_objects(obj: Any) -> List[rpc.RRef]:
global _TEMP_AG_REMOTE_OBJECT
_TEMP_AG_REMOTE_OBJECT = rpc.RRef(obj)
dist.barrier()
world_size = dist.get_world_size()
futs: List[torch.futures.Future] = []
for i in range(world_size):
info = get_worker_info(i)
futs.append(rpc.rpc_async(info, _remote_object))
rrefs: List[rpc.RRef] = []
for f in futs:
f.wait()
rrefs.append(f.value())
dist.barrier()
_TEMP_AG_REMOTE_OBJECT = None
return rrefs
def ccl_init_method() -> str:
master_addr = os.environ["MASTER_ADDR"]
master_port = int(os.environ["MASTER_PORT"])
return f"tcp://{master_addr}:{master_port}"
def rpc_init_method() -> str:
master_addr = os.environ["MASTER_ADDR"]
master_port = int(os.environ["MASTER_PORT"])
return f"tcp://{master_addr}:{master_port+1}"
\ No newline at end of file
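For reference, a minimal driver exercising these helpers (a sketch only, not part of the commit; any torchrun-style launcher that exports RANK, WORLD_SIZE, LOCAL_RANK, MASTER_ADDR and MASTER_PORT satisfies the environment lookups above, e.g. `torchrun --nproc_per_node=2 demo.py`, where demo.py is a hypothetical file name):

    import torch
    import torch.distributed.rpc as rpc
    from starrygl.parallel import init_process_group, all_gather_remote_objects

    if __name__ == "__main__":
        # mirrors the example later in this commit; also starts RPC on MASTER_PORT + 1
        device = init_process_group(backend="nccl")
        # every rank now holds one RRef per peer and can target it with rpc.rpc_async(rref.owner(), ...)
        rrefs = all_gather_remote_objects(torch.zeros(1))
        rpc.shutdown()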
# previously active, fully commented out by this commit:
# import torch
# from torch import Tensor
# from typing import *
# from torch_scatter import scatter_sum
# from starrygl.core.route import Route

# def compute_in_degree(edge_index: Tensor, route: Route) -> Tensor:
#     dst_size = route.src_size
#     x = torch.ones(edge_index.size(1), dtype=torch.long, device=edge_index.device)
#     in_deg = scatter_sum(x, edge_index[1], dim=0, dim_size=dst_size)
#     in_deg, _ = route.forward_a2a(in_deg)
#     return in_deg

# def compute_out_degree(edge_index: Tensor, route: Route) -> Tensor:
#     src_size = route.dst_size
#     x = torch.ones(edge_index.size(1), dtype=torch.long, device=edge_index.device)
#     out_deg = scatter_sum(x, edge_index[0], dim=0, dim_size=src_size)
#     out_deg, _ = route.backward_a2a(out_deg)
#     out_deg, _ = route.forward_a2a(out_deg)
#     return out_deg

# def compute_gcn_norm(edge_index: Tensor, route: Route) -> Tensor:
#     in_deg = compute_in_degree(edge_index, route)
#     out_deg = compute_out_degree(edge_index, route)
#     a = in_deg[edge_index[0]].pow(-0.5)
#     b = out_deg[edge_index[0]].pow(-0.5)
#     x = a * b
#     x[x.isinf()] = 0.0
#     x[x.isnan()] = 0.0
#     return x
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
import torch.distributed.rpc as rpc
from torch import Tensor
from typing import *
import os
import time
import psutil
from starrygl.nn import *
from starrygl.graph import DistGraph
from torch_sparse import SparseTensor
from starrygl.loader import NodeLoader, NodeHandle, TensorBuffer, RouteContext
from starrygl.parallel import init_process_group, convert_parallel_model
from starrygl.parallel import compute_gcn_norm, SyncBatchNorm, with_nccl
from starrygl.utils import partition_load, main_print, sync_print
from starrygl.utils import train_epoch, eval_epoch, partition_load, main_print
class SimpleGNNConv(nn.Module):
    def __init__(self,
        in_channels: int,
        out_channels: int,
    ) -> None:
        super().__init__()
        self.linear = nn.Linear(in_channels, out_channels)
        self.norm = nn.LayerNorm(out_channels)

    def forward(self, x: Tensor, adj_t: SparseTensor) -> Tensor:
        x = self.linear(x)
        x = adj_t @ x
        return self.norm(x)

class SimpleGNN(nn.Module):
    def __init__(self,
        in_channels: int,
        hidden_channels: int,
        out_channels: int,
    ) -> None:
        super().__init__()
        self.conv1 = SimpleGNNConv(in_channels, hidden_channels)
        self.conv2 = SimpleGNNConv(hidden_channels, hidden_channels)
        self.fc_out = nn.Linear(hidden_channels, out_channels)

    def forward(self, handle: NodeHandle, buffers: List[TensorBuffer]) -> Tensor:
        futs = [
            handle.get_src_feats(buffers[0]),
            handle.get_ext_feats(buffers[1]),
        ]
        with RouteContext() as ctx:
            x = futs[0].wait()
            x = self.conv1(x, handle.adj_t)
            x, f = handle.push_and_pull(x, futs[1], buffers[1])
            x = self.conv2(x, handle.adj_t)
            ctx.add_futures(f)  # all outstanding futures must complete once the current batch's forward pass is done
            x = self.fc_out(x)
        return x
if __name__ == "__main__":
    # start the distributed process group and assign the compute device
    device = init_process_group(backend="nccl")

    # load the dataset
    pdata = partition_load("./cora", algo="metis")
    loader = NodeLoader(pdata.ids, pdata.edge_index, device)

    # create the historical embedding buffers
    hidden_size = 64
    buffers: List[TensorBuffer] = [
        TensorBuffer(pdata.num_features, loader.src_size, loader.route),
        TensorBuffer(hidden_size, loader.src_size, loader.route),
    ]

    # set the initial node features and pre-synchronize them to the other partitions
    buffers[0].data[:loader.dst_size] = pdata.x
    buffers[0].broadcast()

    # build the model
    net = SimpleGNN(pdata.num_features, hidden_size, pdata.num_classes).to(device)
    net = convert_parallel_model(net)
    opt = torch.optim.Adam(net.parameters(), lr=1e-3)

    # training loop
    for ep in range(1, 100+1):
        epoch_loss = 0.0
        net.train()
        for handle in loader.iter(128):
            fut_m = handle.get_dst_feats(pdata.train_mask)
            fut_y = handle.get_dst_feats(pdata.y)
            h = net(handle, buffers)

            train_mask = fut_m.wait()
            logits = h[train_mask]
            if logits.size(0) > 0:
                y = fut_y.wait()[train_mask]
                loss = nn.CrossEntropyLoss()(logits, y)

                opt.zero_grad()
                loss.backward()
                opt.step()
                epoch_loss += loss.item()
        main_print(ep, epoch_loss)
    rpc.shutdown()
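For illustration, a validation pass could reuse the same loader/handle pattern before the rpc.shutdown() call above (a sketch only, built from the APIs already used in the training loop; the accuracy here is per-rank and is not reduced across partitions):

    net.eval()
    correct = total = 0
    with torch.no_grad():
        for handle in loader.iter(128):
            val_mask = handle.get_dst_feats(pdata.val_mask).wait()
            y = handle.get_dst_feats(pdata.y).wait()[val_mask]
            pred = net(handle, buffers)[val_mask].argmax(dim=-1)
            correct += (pred == y).sum().item()
            total += y.numel()
    main_print("val accuracy:", correct / max(total, 1))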
# removed by this commit: the previous version of this example (GCN over DistGraph, trained with train_epoch/eval_epoch)
    pdata = partition_load("./cora", algo="metis").to(device)
    g = DistGraph(ids=pdata.ids, edge_index=pdata.edge_index)

    # g.args["async_op"] = True
    g.args["sample_k"] = 20

    g.edata["gcn_norm"] = compute_gcn_norm(g)
    g.ndata["x"] = pdata.x
    g.ndata["y"] = pdata.y

    # define the GAT graph neural network model
    net = GCN(
        g=g,
        layer_options=BasicLayerOptions(
            in_channels=pdata.num_features,
            hidden_channels=64,
            num_layers=2,
            out_channels=pdata.num_classes,
            norm="batchnorm",
        ),
        input_options=BasicInputOptions(
            straight_enabled=True,
        ),
        jk_options=BasicJKOptions(
            jk_mode=None,
        ),
        straight_options=BasicStraightOptions(
            enabled=True,
        ),
    ).to(device)

    # convert to the distributed parallel version
    net = convert_parallel_model(net)

    # define the optimizer
    opt = torch.optim.Adam(net.parameters(), lr=0.01, weight_decay=5e-4)

    avg_mem = 0.0
    avg_dur = 0.0
    avg_num = 0

    # start training
    best_val_acc = best_test_acc = 0
    for ep in range(1, 10+1):
        time_start = time.time()
        train_loss, train_acc = train_epoch(net, opt, g, pdata.train_mask)
        val_loss, val_acc = eval_epoch(net, g, pdata.val_mask)
        test_loss, test_acc = eval_epoch(net, g, pdata.test_mask)
        # val_loss, val_acc = train_loss, train_acc
        # test_loss, test_acc = train_loss, train_acc
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            best_test_acc = test_acc

        duration = time.time() - time_start
        if with_nccl():
            cur_mem = torch.cuda.memory_reserved()
        else:
            cur_mem = psutil.Process(os.getpid()).memory_info().rss
        cur_mem_mb = round(cur_mem / 1024**2)
        if ep > 1:
            avg_mem += cur_mem
            avg_dur += duration
            avg_num += 1

        main_print(
            f"ep: {ep}, mem: {cur_mem_mb}MiB, duration: {duration:.2f}s, "
            f"loss: [{train_loss:.4f}/{val_loss:.4f}/{test_loss:.6f}], "
            f"accuracy: [{train_acc:.4f}/{val_acc:.4f}/{test_acc:.4f}], "
            f"best_accuracy: {best_test_acc:.4f}")
    avg_mem = round(avg_mem / avg_num / 1024**2)
    avg_dur = avg_dur / avg_num
    main_print(f"average memory: {avg_mem}MiB, average duration: {avg_dur:.2f}s")

# added by this commit: an older variant of that example is kept at the end of the new file as comments
# import torch
# import torch.nn as nn
# from torch import Tensor
# from typing import *
# import os
# import time
# import psutil
# from starrygl.nn import *
# from starrygl.graph import DistGraph
# from starrygl.parallel import init_process_group, convert_parallel_model
# from starrygl.parallel import compute_gcn_norm, SyncBatchNorm, with_nccl
# from starrygl.utils import train_epoch, eval_epoch, partition_load, main_print, sync_print

# if __name__ == "__main__":
#     # start the distributed process group and assign the compute device
#     device = init_process_group(backend="nccl")

#     # load the dataset
#     pdata = partition_load("./cora", algo="metis").to(device)
#     g = DistGraph(ids=pdata.ids, edge_index=pdata.edge_index)

#     # g.args["async_op"] = True
#     # g.args["num_samples"] = 20

#     g.edata["gcn_norm"] = compute_gcn_norm(g)
#     g.ndata["x"] = pdata.x
#     g.ndata["y"] = pdata.y

#     # define the GAT graph neural network model
#     net = ShrinkGCN(
#         g=g,
#         layer_options=BasicLayerOptions(
#             in_channels=pdata.num_features,
#             hidden_channels=64,
#             num_layers=3,
#             out_channels=pdata.num_classes,
#             norm="batchnorm",
#         ),
#         input_options=BasicInputOptions(
#             straight_enabled=True,
#             straight_num_samples = 200,
#         ),
#         straight_options=BasicStraightOptions(
#             enabled=True,
#             num_samples = 20,
#             # beta=1.1,
#         ),
#     ).to(device)

#     # convert to the distributed parallel version
#     net = convert_parallel_model(net)

#     # define the optimizer
#     opt = torch.optim.Adam(net.parameters(), lr=0.01, weight_decay=5e-4)

#     avg_mem = 0.0
#     avg_dur = 0.0
#     avg_num = 0

#     # start training
#     best_val_acc = best_test_acc = 0
#     for ep in range(1, 50+1):
#         time_start = time.time()
#         train_loss, train_acc = train_epoch(net, opt, g, pdata.train_mask)
#         val_loss, val_acc = eval_epoch(net, g, pdata.val_mask)
#         test_loss, test_acc = eval_epoch(net, g, pdata.test_mask)
#         # val_loss, val_acc = train_loss, train_acc
#         # test_loss, test_acc = train_loss, train_acc
#         if val_acc > best_val_acc:
#             best_val_acc = val_acc
#             best_test_acc = test_acc

#         duration = time.time() - time_start
#         if with_nccl():
#             cur_mem = torch.cuda.memory_reserved()
#         else:
#             cur_mem = psutil.Process(os.getpid()).memory_info().rss
#         cur_mem_mb = round(cur_mem / 1024**2)
#         if ep > 1:
#             avg_mem += cur_mem
#             avg_dur += duration
#             avg_num += 1

#         main_print(
#             f"ep: {ep}, mem: {cur_mem_mb}MiB, duration: {duration:.2f}s, "
#             f"loss: [{train_loss:.4f}/{val_loss:.4f}/{test_loss:.6f}], "
#             f"accuracy: [{train_acc:.4f}/{val_acc:.4f}/{test_acc:.4f}], "
#             f"best_accuracy: {best_test_acc:.4f}")
#     avg_mem = round(avg_mem / avg_num / 1024**2)
#     avg_dur = avg_dur / avg_num
#     main_print(f"average memory: {avg_mem}MiB, average duration: {avg_dur:.2f}s")
\ No newline at end of file