Commit f4af3231 by Wenjie Huang

rpc based

parent 3f94b191
...@@ -160,3 +160,6 @@ cython_debug/
#.idea/
cora/
/test_*
/*.ipynb
/s.py
\ No newline at end of file
import torch
import torch.nn as nn
import torch.distributed as dist
import torch.distributed.rpc as rpc
from torch.futures import Future
from torch import Tensor
from typing import *
from starrygl.parallel import all_gather_remote_objects
from .route import Route
from threading import Lock
class TensorBuffer(nn.Module):
    def __init__(self,
        channels: int,
        num_nodes: int,
        route: Route,
    ) -> None:
        super().__init__()
        self.channels = channels
        self.num_nodes = num_nodes
        self.route = route

        self.local_lock = Lock()
        self.register_buffer("_data", torch.zeros(num_nodes, channels), persistent=False)
        self.rrefs = all_gather_remote_objects(self)

    @property
    def data(self) -> Tensor:
        return self.get_buffer("_data")

    @property
    def device(self) -> torch.device:
        return self.data.device

    def local_get(self, index: Optional[Tensor] = None, lock: bool = True) -> Tensor:
        if lock:
            with self.local_lock:
                return self.local_get(index, lock=False)
        if index is None:
            return self.data
        else:
            return self.data[index]

    def local_set(self, value: Tensor, index: Optional[Tensor] = None, lock: bool = True):
        if lock:
            with self.local_lock:
                return self.local_set(value, index, lock=False)
        if index is None:
            self.data.copy_(value)
        else:
            # value = value.to(self.device)
            self.data[index] = value

    def local_add(self, value: Tensor, index: Optional[Tensor] = None, lock: bool = True):
        if lock:
            with self.local_lock:
                return self.local_add(value, index, lock=False)
        if index is None:
            self.data.add_(value)
        else:
            # value = value.to(self.device)
            self.data[index] += value
def local_cls(self, index: Optional[Tensor] = None, lock: bool = True):
if lock:
with self.local_lock:
return self.local_cls(index, lock=False)
if index is None:
self.data.zero_()
else:
self.data[index] = 0
def remote_get(self, dst: int, index: Tensor, lock: bool = True):
return TensorBuffer._remote_call(TensorBuffer.local_get, self.rrefs[dst], index=index, lock=lock)
def remote_set(self, dst: int, value: Tensor, index: Tensor, lock: bool = True):
return TensorBuffer._remote_call(TensorBuffer.local_set, self.rrefs[dst], value, index=index, lock=lock)
def remote_add(self, dst: int, value: Tensor, index: Tensor, lock: bool = True):
return TensorBuffer._remote_call(TensorBuffer.local_add, self.rrefs[dst], value, index=index, lock=lock)
def all_remote_get(self, index: Tensor, lock: bool = True):
def cb0(idx):
def f(x: torch.futures.Future[Tensor]):
return x.value(), idx
return f
def cb1(buf):
def f(xs: torch.futures.Future[List[torch.futures.Future]]) -> Tensor:
for x in xs.value():
dat, idx = x.value()
# print(dat.size(), idx.size())
buf[idx] += dat
return buf
return f
futs = []
for i, (idx, remote_idx) in enumerate(self.route.parts_iter(index)):
futs.append(self.remote_get(i, remote_idx, lock=lock).then(cb0(idx)))
futs = torch.futures.collect_all(futs)
buf = torch.zeros(index.size(0), self.channels, dtype=self.data.dtype, device=self.data.device)
return futs.then(cb1(buf))
def all_remote_set(self, value: Tensor, index: Tensor, lock: bool = True):
futs = []
for i, (idx, remote_idx) in enumerate(self.route.parts_iter(index)):
futs.append(self.remote_set(i, value[idx], remote_idx, lock=lock))
return torch.futures.collect_all(futs)
def all_remote_add(self, value: Tensor, index: Tensor, lock: bool = True):
futs = []
for i, (idx, remote_idx) in enumerate(self.route.parts_iter(index)):
futs.append(self.remote_add(i, value[idx], remote_idx, lock=lock))
return torch.futures.collect_all(futs)
def broadcast(self, barrier: bool = True):
if barrier:
dist.barrier()
index = torch.arange(self.num_nodes, dtype=torch.long, device=self.data.device)
data = self.all_remote_get(index, lock=True).wait()
self.local_set(data, lock=True)
if barrier:
dist.barrier()
# def remote_get(self, dst: int, i: int, index: Optional[Tensor] = None, global_index: bool = False, async_op: bool = False) -> Union[Tensor, Future]:
# return TensorBuffer._remote_call(async_op, TensorBuffer.local_get, self.rrefs[dst], i, index, global_index = global_index)
# def remote_set(self, dst: int, i: int, value: Tensor, index: Optional[Tensor], global_index: bool = False, async_op: bool = False) -> Optional[Future]:
# return TensorBuffer._remote_call(async_op, TensorBuffer.local_set, self.rrefs[dst], i, value, index, global_index = global_index)
# def remote_add(self, dst: int, i: int, value: Tensor, index: Optional[Tensor] = None, global_index: bool = False, async_op: bool = False) -> Optional[Future]:
# return TensorBuffer._remote_call(async_op, TensorBuffer.local_add, self.rrefs[dst], i, value, index, global_index = global_index)
# def async_scatter_fw_set(self, i: int, value: Tensor, index: Optional[Tensor] = None) -> Tuple[Future]:
# futures: List[Future] = []
# for dst in range(self.world_size):
# val, ind = self.router.fw_value_index(dst, value, index)
# futures.append(self.remote_set(dst, i, val, ind, global_index=True, async_op=True))
# return tuple(futures)
# def async_scatter_fw_add(self, i: int, value: Tensor, index: Optional[Tensor] = None) -> Tuple[Future]:
# futures: List[Future] = []
# for dst in range(self.world_size):
# val, ind = self.router.fw_value_index(dst, value, index)
# futures.append(self.remote_add(dst, i, val, ind, global_index=True, async_op=True))
# return tuple(futures)
# def async_scatter_bw_set(self, i: int, value: Tensor, index: Optional[Tensor] = None) -> Tuple[Future]:
# futures: List[Future] = []
# for dst in range(self.world_size):
# val, ind = self.router.bw_value_index(dst, value, index)
# futures.append(self.remote_set(dst, i, val, ind, global_index=True, async_op=True))
# return tuple(futures)
# def async_scatter_bw_add(self, i: int, value: Tensor, index: Optional[Tensor] = None) -> Tuple[Future]:
# futures: List[Future] = []
# for dst in range(self.world_size):
# val, ind = self.router.bw_value_index(dst, value, index)
# futures.append(self.remote_add(dst, i, val, ind, global_index=True, async_op=True))
# return tuple(futures)
# def _idx_data(self, i: int) -> Tuple[int, Tensor]:
# assert -self.num_layers < i and i < self.num_layers
# i = (self.num_layers + i) % self.num_layers
# return i, self.get_buffer(f"data{i}")
    @staticmethod
    def _remote_call(method, rref: rpc.RRef, *args, **kwargs):
        args = (method, rref) + args
        return rpc.rpc_async(rref.owner(), TensorBuffer._method_call, args=args, kwargs=kwargs)

    @staticmethod
    def _method_call(method, rref: rpc.RRef, *args, **kwargs):
        self: TensorBuffer = rref.local_value()
        index = kwargs["index"]
        kwargs["index"] = self.route.to_local_ids(index)
        return method(self, *args, **kwargs)
\ No newline at end of file
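The callback chain in `all_remote_get` above (tag each per-partition future with the local positions it fills, then merge all partial results in a `collect_all` callback) can be reproduced locally without RPC. The sketch below is illustrative only: the two in-memory "partitions", `fake_remote_get`, and the hard-coded index split are stand-ins for what `Route.parts_iter` and `remote_get` provide in the real class; only `torch.futures` is used as-is.

```python
import torch

channels = 4
part_data = [torch.full((3, channels), 1.0), torch.full((3, channels), 2.0)]

def fake_remote_get(part: int, remote_idx: torch.Tensor) -> torch.futures.Future:
    # stand-in for remote_get(): returns an already-completed future
    fut = torch.futures.Future()
    fut.set_result(part_data[part][remote_idx])
    return fut

def cb0(idx: torch.Tensor):
    # tag the remote payload with the local positions it belongs to
    def f(x: torch.futures.Future):
        return x.value(), idx
    return f

def cb1(buf: torch.Tensor):
    # scatter every (payload, positions) pair into one output buffer
    def f(xs: torch.futures.Future) -> torch.Tensor:
        for x in xs.value():
            dat, idx = x.value()
            buf[idx] += dat
        return buf
    return f

index = torch.arange(5)
# pretend rows 0, 2, 4 live on partition 0 and rows 1, 3 on partition 1
parts = [(torch.tensor([0, 2, 4]), torch.tensor([0, 1, 2])),
         (torch.tensor([1, 3]), torch.tensor([0, 1]))]

futs = [fake_remote_get(i, remote_idx).then(cb0(idx))
        for i, (idx, remote_idx) in enumerate(parts)]
buf = torch.zeros(index.size(0), channels)
out = torch.futures.collect_all(futs).then(cb1(buf)).wait()
print(out)  # rows 0/2/4 filled from partition 0, rows 1/3 from partition 1
```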
import torch
from contextlib import contextmanager
from torch import Tensor
from typing import *
class RouteContext:
def __init__(self) -> None:
self._futs: List[torch.futures.Future] = []
def synchronize(self):
for fut in self._futs:
fut.wait()
self._futs = []
def add_futures(self, *futs):
for fut in futs:
assert isinstance(fut, torch.futures.Future)
self._futs.append(fut)
def __enter__(self):
return self
def __exit__(self, exc_type, exc_value, traceback):
if exc_type is not None:
raise exc_type(exc_value)
self.synchronize()
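A minimal usage sketch of the `RouteContext` defined above, assuming the class is in scope; the pre-completed future is a placeholder for what an asynchronous `remote_set`/`remote_add` call would return.

```python
import torch

def fake_async_write() -> torch.futures.Future:
    # placeholder for an asynchronous remote write; a real RPC future
    # would complete later, this one is done immediately
    fut = torch.futures.Future()
    fut.set_result(None)
    return fut

with RouteContext() as ctx:
    f = fake_async_write()   # kick off the write without blocking
    ctx.add_futures(f)       # __exit__ calls synchronize(), which waits on it
```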
import torch
import torch.nn as nn
import torch.distributed as dist
import torch.distributed.rpc as rpc
from torch import Tensor
from typing import *
from starrygl.parallel import all_gather_remote_objects
from .utils import init_local_edge_index
class Route(nn.Module):
def __init__(self,
src_ids: Tensor,
dst_size: int,
) -> None:
super().__init__()
self.register_buffer("_src_ids", src_ids, persistent=False)
self.dst_size = dst_size
self._init_nids_mapper()
self._init_part_mapper()
@staticmethod
def from_edge_index(dst_ids: Tensor, edge_index: Tensor):
src_ids, local_edge_index = init_local_edge_index(dst_ids, edge_index)
return Route(src_ids, dst_ids.size(0)), local_edge_index
@property
def src_ids(self) -> Tensor:
return self.get_buffer("_src_ids")
@property
def src_size(self) -> int:
return self.src_ids.size(0)
@property
def dst_ids(self) -> Tensor:
return self.src_ids[:self.dst_size]
@property
def ext_ids(self) -> Tensor:
return self.src_ids[self.dst_size:]
@property
def ext_size(self) -> int:
return self.src_size - self.dst_size
def parts_iter(self, local_ids: Tensor) -> Iterator[Tuple[Tensor, Tensor]]:
world_size = dist.get_world_size()
part_mapper = self.part_mapper[local_ids]
for i in range(world_size):
# part_ids = local_ids[part_mapper == i]
part_ids = torch.where(part_mapper == i)[0]
glob_ids = self.src_ids[part_ids]
yield part_ids, glob_ids
def to_local_ids(self, ids: Tensor) -> Tensor:
return self.nids_mapper[ids]
def _init_nids_mapper(self):
num_nodes: int = self.src_ids.max().item() + 1
device: torch.device = self.src_ids.device
mapper = torch.empty(num_nodes, dtype=torch.long, device=device).fill_((2**62-1)*2+1)
mapper[self.src_ids] = torch.arange(self.src_ids.size(0), dtype=torch.long, device=device)
self.register_buffer("nids_mapper", mapper, persistent=False)
def _init_part_mapper(self):
device: torch.device = self.src_ids.device
nids_mapper = self.get_buffer("nids_mapper")
mapper = torch.empty(self.src_size, dtype=torch.int32, device=device).fill_(-1)
for i, dst_ids in enumerate(all_gather_remote_objects(self.dst_ids)):
dst_ids: Tensor = dst_ids.to_here().to(device)
dst_ids = dst_ids[dst_ids < nids_mapper.size(0)]
dst_local_inds = nids_mapper[dst_ids]
dst_local_mask = dst_local_inds != ((2**62-1)*2+1)
dst_local_inds = dst_local_inds[dst_local_mask]
mapper[dst_local_inds] = i
assert (mapper >= 0).all()
self.register_buffer("part_mapper", mapper, persistent=False)
# class RouteTable(nn.Module):
# def __init__(self,
# src_ids: Tensor,
# dst_size: int,
# ) -> None:
# super().__init__()
# self.register_buffer("src_ids", src_ids)
# self.src_size: int = src_ids.size(0)
# self.dst_size = dst_size
# assert self.src_size >= self.dst_size
# self._init_mapper()
# rank, world_size = rank_world_size()
# rrefs = all_gather_remote_objects(self)
# gather_futures: List[torch.futures.Future] = []
# for i in range(world_size):
# rref = rrefs[i]
# fut = rpc.rpc_async(rref.owner(), RouteTable._get_dst_ids, args=(rref,))
# gather_futures.append(fut)
# max_src_ids: int = src_ids.max().item()
# smp = torch.empty(max_src_ids + 1, dtype=torch.long, device=src_ids.device).fill_((2**62-1)*2+1)
# smp[src_ids] = torch.arange(src_ids.size(0), dtype=smp.dtype, device=smp.device)
# self.fw_masker = RouteMasker(self.dst_size, world_size)
# self.bw_masker = RouteMasker(self.src_size, world_size)
# dist.barrier()
# scatter_futures: List[torch.futures.Future] = []
# for i in range(world_size):
# fut = gather_futures[i]
# s_ids: Tensor = src_ids
# d_ids: Tensor = fut.wait()
# num_ids: int = max(s_ids.max().item(), d_ids.max().item()) + 1
# imp = torch.zeros(num_ids, dtype=torch.long, device=self._get_device())
# imp[s_ids] += 1
# imp[d_ids] += 1
# ind = torch.where(imp > 1)[0]
# imp.fill_((2**62-1)*2+1)
# imp[d_ids] = torch.arange(d_ids.size(0), dtype=imp.dtype, device=imp.device)
# s_ind = smp[ind]
# d_ind = imp[ind]
# rref = rrefs[i]
# fut = rpc.rpc_async(rref.owner(), RouteTable._set_fw_mask, args=(rref, rank, d_ind))
# scatter_futures.append(fut)
# bw_mask = torch.zeros(self.src_size, dtype=torch.bool).index_fill_(0, s_ind, 1)
# self.bw_masker.set_mask(i, bw_mask)
# for fut in scatter_futures:
# fut.wait()
# dist.barrier()
# # def fw_index(self, dst: int, index: Tensor) -> Tensor:
# # mask = self.fw_masker.select(dst, index)
# # return self.get_global_index(index[mask])
# # def bw_index(self, dst: int, index: Tensor) -> Tensor:
# # mask = self.bw_masker.select(dst, index)
# # return self.get_global_index(index[mask])
# def fw_value_index(self, dst: int, value: Tensor, index: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]:
# if index is None:
# assert value.size(0) == self.dst_size
# mask = self.fw_masker.select(dst)
# return value[mask], self.get_buffer("src_ids")[:self.dst_size][mask]
# else:
# assert value.size(0) == index.size(0)
# mask = self.fw_masker.select(dst, index)
# value, index = value[mask], index[mask]
# return value, self.get_global_index(index)
# def bw_value_index(self, dst: int, value: Tensor, index: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]:
# if index is None:
# assert value.size(0) == self.src_size
# mask = self.bw_masker.select(dst)
# return value[mask], self.get_buffer("src_ids")[mask]
# else:
# assert value.size(0) == index.size(0)
# mask = self.bw_masker.select(dst, index)
# value, index = value[mask], index[mask]
# return value, self.get_global_index(index)
# def get_global_index(self, index: Tensor) -> Tensor:
# return self.get_buffer("src_ids")[index]
# def get_local_index(self, index: Tensor) -> Tensor:
# return self.get_buffer("mapper")[index]
# @staticmethod
# def _get_dst_ids(rref: rpc.RRef):
# self: RouteTable = rref.local_value()
# src_ids = self.get_buffer("src_ids")
# return src_ids[:self.dst_size]
# @staticmethod
# def _set_fw_mask(rref: rpc.RRef, dst: int, fw_ind: Tensor):
# self: RouteTable = rref.local_value()
# fw_mask = torch.zeros(self.dst_size, dtype=torch.bool).index_fill_(0, fw_ind, 1)
# self.fw_masker.set_mask(dst, fw_mask)
# def _get_device(self):
# return self.get_buffer("src_ids").device
# def _init_mapper(self):
# src_ids = self.get_buffer("src_ids")
# num_nodes: int = src_ids.max().item() + 1
# mapper = torch.empty(num_nodes, dtype=torch.long, device=src_ids.device).fill_((2**62-1)*2+1)
# mapper[src_ids] = torch.arange(src_ids.size(0), dtype=torch.long)
# self.register_buffer("mapper", mapper)
# class RouteMasker(nn.Module):
# def __init__(self,
# num_nodes: int,
# world_size: int,
# ) -> None:
# super().__init__()
# m = (world_size + 7) // 8
# self.num_nodes = num_nodes
# self.world_size = world_size
# self.register_buffer("data", torch.zeros(m, num_nodes, dtype=torch.uint8))
# def forward(self, i: int, index: Optional[Tensor] = None) -> Tensor:
# return self.select(i, index)
# def select(self, i: int, index: Optional[Tensor] = None) -> Tensor:
# i, data = self._idx_data(i)
# k, r = i // 8, i % 8
# if index is None:
# mask = data[k].bitwise_right_shift(r).bitwise_and_(1)
# else:
# mask = data[k][index].bitwise_right_shift_(r).bitwise_and_(1)
# return mask.type(dtype=torch.bool)
# def set_mask(self, i: int, mask: Tensor) -> Tensor:
# assert mask.size(0) == self.num_nodes
# i, data = self._idx_data(i)
# k, r = i // 8, i % 8
# data[k] &= ~(1<<r)
# data[k] |= mask.type(torch.uint8).bitwise_left_shift_(r)
# def _idx_data(self, i: int) -> Tuple[int, Tensor]:
# assert -self.world_size < i and i < self.world_size
# i = (i + self.world_size) % self.world_size
# return i, self.get_buffer("data")
\ No newline at end of file
...@@ -50,11 +50,6 @@ def calc_max_ids(*ids: Tensor) -> int:
    x = [t.max().item() if t.numel() > 0 else 0 for t in ids]
    return max(*x)
def collect_feat0(src_ids: Tensor, dst_ids: Tensor, feat0: Tensor):
device = get_compute_device()
route = Route(dst_ids.to(device), src_ids.to(device))
return route.forward_a2a(feat0.to(device))[0].to(feat0.device)
def local_partition_fn(dst_size: Tensor, edge_index: Tensor, num_parts: int) -> Tensor:
    edge_index = edge_index[:, edge_index[0] < dst_size]
    return metis_partition(edge_index, dst_size, num_parts)[0]
\ No newline at end of file
# import torch
# import torch.nn as nn
# import torch.distributed as dist

# from torch import Tensor
# from typing import *

# from starrygl.loader import BatchHandle


# class BaseLayer(nn.Module):
#     def __init__(self) -> None:
#         super().__init__()

#     def forward(self, x: Tensor, edge_index: Tensor, edge_attr: Optional[Tensor] = None) -> Tensor:
#         return x

#     def update_forward(self, handle: BatchHandle, edge_index: Tensor, edge_attr: Optional[Tensor] = None):
#         x = handle.fetch_feat()
#         with torch.no_grad():
#             x = self.forward(x, edge_index, edge_attr)
#         handle.update_feat(x)

#     def block_backward(self, handle: BatchHandle, edge_index: Tensor, edge_attr: Optional[Tensor] = None):
#         x = handle.fetch_feat().requires_grad_()
#         g = handle.fetch_grad()
#         self.forward(x, edge_index, edge_attr).backward(g)
#         handle.accumulate_grad(x.grad)
#         x.grad = None

#     def all_reduce_grad(self):
#         for p in self.parameters():
#             if p.grad is not None:
#                 dist.all_reduce(p.grad, op=dist.ReduceOp.SUM)


# class BaseModel(nn.Module):
#     def __init__(self,
#         num_features: int,
#         layers: List[int],
#         prev_layer: bool = False,
#         post_layer: bool = False,
#     ) -> None:
#         super().__init__()

#     def init_prev_layer(self) -> Tensor:
#         pass

#     def init_post_layer(self) -> Tensor:
#         pass

#     def init_conv_layer(self) -> Tensor:
#         pass
\ No newline at end of file
import torch
import torch.nn as nn
import torch.distributed as dist
import torch.distributed.rpc as rpc

import os
from typing import *
...@@ -15,29 +16,36 @@ def convert_parallel_model(
    net = SyncBatchNorm.convert_sync_batchnorm(net)
    net = nn.parallel.DistributedDataParallel(net,
        find_unused_parameters=find_unused_parameters,
        # broadcast_buffers=False,
    )
# for name, buffer in net.named_buffers():
# if name.endswith("last_embd"):
# continue
# if name.endswith("last_w"):
# continue
# dist.broadcast(buffer, src=0)
    return net

def init_process_group(backend: str = "gloo") -> torch.device:
rank = int(os.getenv("RANK") or os.getenv("OMPI_COMM_WORLD_RANK"))
world_size = int(os.getenv("WORLD_SIZE") or os.getenv("OMPI_COMM_WORLD_SIZE"))
dist.init_process_group(
backend=backend,
init_method=ccl_init_method(),
rank=rank, world_size=world_size,
)
rpc_backend_options = rpc.TensorPipeRpcBackendOptions()
rpc_backend_options.init_method = rpc_init_method()
for i in range(world_size):
rpc_backend_options.set_device_map(f"worker{i}", {rank: i})
rpc.init_rpc(
name=f"worker{rank}",
rank=rank, world_size=world_size,
rpc_backend_options=rpc_backend_options,
)
    local_rank = os.getenv("LOCAL_RANK") or os.getenv("OMPI_COMM_WORLD_LOCAL_RANK")
    if local_rank is not None:
        local_rank = int(local_rank)

    if backend == "nccl" or backend == "mpi":
        device = torch.device(f"cuda:{local_rank or rank}")
        torch.cuda.set_device(device)
    else:
        device = torch.device("cpu")
...@@ -46,7 +54,48 @@ def init_process_group(backend: str = "gloo") -> torch.device:
    _COMPUTE_DEVICE = device
    return device
def rank_world_size() -> Tuple[int, int]:
return dist.get_rank(), dist.get_world_size()
def get_worker_info(rank: Optional[int] = None) -> rpc.WorkerInfo:
rank = dist.get_rank() if rank is None else rank
return rpc.get_worker_info(f"worker{rank}")
_COMPUTE_DEVICE = torch.device("cpu")

def get_compute_device() -> torch.device:
    global _COMPUTE_DEVICE
    return _COMPUTE_DEVICE
_TEMP_AG_REMOTE_OBJECT = None
def _remote_object():
global _TEMP_AG_REMOTE_OBJECT
return _TEMP_AG_REMOTE_OBJECT
def all_gather_remote_objects(obj: Any) -> List[rpc.RRef]:
global _TEMP_AG_REMOTE_OBJECT
_TEMP_AG_REMOTE_OBJECT = rpc.RRef(obj)
dist.barrier()
world_size = dist.get_world_size()
futs: List[torch.futures.Future] = []
for i in range(world_size):
info = get_worker_info(i)
futs.append(rpc.rpc_async(info, _remote_object))
rrefs: List[rpc.RRef] = []
for f in futs:
f.wait()
rrefs.append(f.value())
dist.barrier()
_TEMP_AG_REMOTE_OBJECT = None
return rrefs
def ccl_init_method() -> str:
master_addr = os.environ["MASTER_ADDR"]
master_port = int(os.environ["MASTER_PORT"])
return f"tcp://{master_addr}:{master_port}"
def rpc_init_method() -> str:
master_addr = os.environ["MASTER_ADDR"]
master_port = int(os.environ["MASTER_PORT"])
return f"tcp://{master_addr}:{master_port+1}"
\ No newline at end of file
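The initialization above keeps two rendezvous endpoints: the collective group binds to `MASTER_ADDR:MASTER_PORT` while RPC binds to the same address on `MASTER_PORT+1`, and each RPC worker is registered as `worker{rank}`. The following single-process sketch (world_size=1, gloo backend, placeholder address and ports) mirrors that pattern end to end; in a real multi-worker run the rank, world size, and rendezvous variables would come from the launcher rather than being hard-coded here.

```python
import os
import torch
import torch.distributed as dist
import torch.distributed.rpc as rpc

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
rank, world_size = 0, 1

# collective process group on MASTER_PORT
dist.init_process_group(
    backend="gloo",
    init_method=f"tcp://{os.environ['MASTER_ADDR']}:{os.environ['MASTER_PORT']}",
    rank=rank, world_size=world_size,
)

# RPC on MASTER_PORT + 1, worker named f"worker{rank}"
opts = rpc.TensorPipeRpcBackendOptions()
opts.init_method = f"tcp://{os.environ['MASTER_ADDR']}:{int(os.environ['MASTER_PORT']) + 1}"
rpc.init_rpc(f"worker{rank}", rank=rank, world_size=world_size,
             rpc_backend_options=opts)

# the all_gather_remote_objects idiom boils down to: wrap a local object in
# an RRef and let every peer fetch it; with one worker we fetch our own.
obj = torch.zeros(3)
rref = rpc.RRef(obj)
print(rref.local_value())

rpc.shutdown()
dist.destroy_process_group()
```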
# import torch
# from torch import Tensor
# from typing import *

# from torch_scatter import scatter_sum
# from starrygl.core.route import Route


# def compute_in_degree(edge_index: Tensor, route: Route) -> Tensor:
#     dst_size = route.src_size
#     x = torch.ones(edge_index.size(1), dtype=torch.long, device=edge_index.device)
#     in_deg = scatter_sum(x, edge_index[1], dim=0, dim_size=dst_size)
#     in_deg, _ = route.forward_a2a(in_deg)
#     return in_deg

# def compute_out_degree(edge_index: Tensor, route: Route) -> Tensor:
#     src_size = route.dst_size
#     x = torch.ones(edge_index.size(1), dtype=torch.long, device=edge_index.device)
#     out_deg = scatter_sum(x, edge_index[0], dim=0, dim_size=src_size)
#     out_deg, _ = route.backward_a2a(out_deg)
#     out_deg, _ = route.forward_a2a(out_deg)
#     return out_deg

# def compute_gcn_norm(edge_index: Tensor, route: Route) -> Tensor:
#     in_deg = compute_in_degree(edge_index, route)
#     out_deg = compute_out_degree(edge_index, route)
#     a = in_deg[edge_index[0]].pow(-0.5)
#     b = out_deg[edge_index[0]].pow(-0.5)
#     x = a * b
#     x[x.isinf()] = 0.0
#     x[x.isnan()] = 0.0
#     return x
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
import torch.distributed.rpc as rpc

from torch import Tensor
from typing import *

from torch_sparse import SparseTensor

from starrygl.loader import NodeLoader, NodeHandle, TensorBuffer, RouteContext
from starrygl.parallel import init_process_group, convert_parallel_model
from starrygl.utils import partition_load, main_print, sync_print


class SimpleGNNConv(nn.Module):
    def __init__(self,
        in_channels: int,
        out_channels: int,
    ) -> None:
        super().__init__()
        self.linear = nn.Linear(in_channels, out_channels)
        self.norm = nn.LayerNorm(out_channels)

    def forward(self, x: Tensor, adj_t: SparseTensor) -> Tensor:
        x = self.linear(x)
        x = adj_t @ x
        return self.norm(x)


class SimpleGNN(nn.Module):
    def __init__(self,
        in_channels: int,
        hidden_channels: int,
        out_channels: int,
    ) -> None:
        super().__init__()
        self.conv1 = SimpleGNNConv(in_channels, hidden_channels)
        self.conv2 = SimpleGNNConv(hidden_channels, hidden_channels)
        self.fc_out = nn.Linear(hidden_channels, out_channels)

    def forward(self, handle: NodeHandle, buffers: List[TensorBuffer]) -> Tensor:
        futs = [
            handle.get_src_feats(buffers[0]),
            handle.get_ext_feats(buffers[1]),
        ]
        with RouteContext() as ctx:
            x = futs[0].wait()
            x = self.conv1(x, handle.adj_t)
            x, f = handle.push_and_pull(x, futs[1], buffers[1])
            x = self.conv2(x, handle.adj_t)
            ctx.add_futures(f)  # all outstanding futures are waited on once this batch finishes
        x = self.fc_out(x)
        return x


if __name__ == "__main__":
    # start the distributed process group and pick the compute device
    device = init_process_group(backend="nccl")

    # load the dataset
    pdata = partition_load("./cora", algo="metis")
    loader = NodeLoader(pdata.ids, pdata.edge_index, device)

    # create the history embedding buffers
    hidden_size = 64
    buffers: List[TensorBuffer] = [
        TensorBuffer(pdata.num_features, loader.src_size, loader.route),
        TensorBuffer(hidden_size, loader.src_size, loader.route),
    ]

    # set the initial node features and pre-synchronize them to the other partitions
    buffers[0].data[:loader.dst_size] = pdata.x
    buffers[0].broadcast()

    # build the model
    net = SimpleGNN(pdata.num_features, hidden_size, pdata.num_classes).to(device)
    # convert it into the distributed parallel version
    net = convert_parallel_model(net)

    opt = torch.optim.Adam(net.parameters(), lr=1e-3)

    # training loop
    for ep in range(1, 100+1):
        epoch_loss = 0.0
        net.train()
        for handle in loader.iter(128):
            fut_m = handle.get_dst_feats(pdata.train_mask)
            fut_y = handle.get_dst_feats(pdata.y)

            h = net(handle, buffers)

            train_mask = fut_m.wait()
            logits = h[train_mask]
            if logits.size(0) > 0:
                y = fut_y.wait()[train_mask]
                loss = nn.CrossEntropyLoss()(logits, y)

                opt.zero_grad()
                loss.backward()
                opt.step()
                epoch_loss += loss.item()
        main_print(ep, epoch_loss)
    rpc.shutdown()


# import torch
# import torch.nn as nn
# from torch import Tensor
# from typing import *
# import os
# import time
# import psutil
# from starrygl.nn import *
# from starrygl.graph import DistGraph
# from starrygl.parallel import init_process_group, convert_parallel_model
# from starrygl.parallel import compute_gcn_norm, SyncBatchNorm, with_nccl
# from starrygl.utils import train_epoch, eval_epoch, partition_load, main_print, sync_print
# if __name__ == "__main__":
# # start the distributed process group and pick the compute device
# device = init_process_group(backend="nccl")
# # load the dataset
# pdata = partition_load("./cora", algo="metis").to(device)
# g = DistGraph(ids=pdata.ids, edge_index=pdata.edge_index)
# # g.args["async_op"] = True
# # g.args["num_samples"] = 20
# g.edata["gcn_norm"] = compute_gcn_norm(g)
# g.ndata["x"] = pdata.x
# g.ndata["y"] = pdata.y
# # define the GAT graph neural network model
# net = ShrinkGCN(
# g=g,
# layer_options=BasicLayerOptions(
# in_channels=pdata.num_features,
# hidden_channels=64,
# num_layers=3,
# out_channels=pdata.num_classes,
# norm="batchnorm",
# ),
# input_options=BasicInputOptions(
# straight_enabled=True,
# straight_num_samples = 200,
# ),
# straight_options=BasicStraightOptions(
# enabled=True,
# num_samples = 20,
# # beta=1.1,
# ),
# ).to(device)
# # convert it into the distributed parallel version
# net = convert_parallel_model(net)
# # define the optimizer
# opt = torch.optim.Adam(net.parameters(), lr=0.01, weight_decay=5e-4)
# avg_mem = 0.0
# avg_dur = 0.0
# avg_num = 0
# # start training
# best_val_acc = best_test_acc = 0
# for ep in range(1, 50+1):
# time_start = time.time()
# train_loss, train_acc = train_epoch(net, opt, g, pdata.train_mask)
# val_loss, val_acc = eval_epoch(net, g, pdata.val_mask)
# test_loss, test_acc = eval_epoch(net, g, pdata.test_mask)
# # val_loss, val_acc = train_loss, train_acc
# # test_loss, test_acc = train_loss, train_acc
# if val_acc > best_val_acc:
# best_val_acc = val_acc
# best_test_acc = test_acc
# duration = time.time() - time_start
# if with_nccl():
# cur_mem = torch.cuda.memory_reserved()
# else:
# cur_mem = psutil.Process(os.getpid()).memory_info().rss
# cur_mem_mb = round(cur_mem / 1024**2)
# if ep > 1:
# avg_mem += cur_mem
# avg_dur += duration
# avg_num += 1
# main_print(
# f"ep: {ep}, mem: {cur_mem_mb}MiB, duration: {duration:.2f}s, "
# f"loss: [{train_loss:.4f}/{val_loss:.4f}/{test_loss:.6f}], "
# f"accuracy: [{train_acc:.4f}/{val_acc:.4f}/{test_acc:.4f}], "
# f"best_accuracy: {best_test_acc:.4f}")
# avg_mem = round(avg_mem / avg_num / 1024**2)
# avg_dur = avg_dur / avg_num
# main_print(f"average memory: {avg_mem}MiB, average duration: {avg_dur:.2f}s")
\ No newline at end of file
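As a usage note: the rewritten `init_process_group` reads RANK, WORLD_SIZE, MASTER_ADDR and MASTER_PORT (or their OpenMPI equivalents) from the environment, so the example script above is presumably meant to be started through a launcher such as torchrun or mpirun, which set exactly those variables; a plain single-process `python` invocation would fail at the `int(os.getenv(...))` calls.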