Commit 3f94b191 by Wenjie Huang

add 1-order dataloader

parent a4c4cadb
from .a2a import all_to_all, with_gloo, with_nccl
# from .a2a import all_to_all, with_gloo, with_nccl
# from .cache import EmbeddingCache
# from .gather import Gather
# from .route import Route, GatherWork
from .route import Route, RouteWorkBase
from .cache import MessageCache, NodeProbe
\ No newline at end of file
# from .cache import MessageCache, NodeProbe
\ No newline at end of file
@@ -4,58 +4,95 @@ import torch.distributed as dist
from torch import Tensor
from typing import *
def with_nccl() -> bool:
return dist.get_backend() == "nccl"
def with_gloo() -> bool:
return dist.get_backend() == "gloo"
class Works:
def __init__(self, *works) -> None:
self._works = list(works)
self._waited = False
def wait(self) -> None:
assert not self._waited
for w in self._works:
if w is None:
continue
w.wait()
self._waited = True
def push(self, *works) -> None:
assert not self._waited
self._works.extend(works)
def all_to_all(
output_tensor_list: List[Tensor],
input_tensor_list: List[Tensor],
group: Optional[Any] = None,
) -> Works:
):
assert len(output_tensor_list) == len(input_tensor_list)
if with_nccl():
work = dist.all_to_all(
backend = dist.get_backend()
if backend == "nccl":
dist.all_to_all(
output_tensor_list=output_tensor_list,
input_tensor_list=input_tensor_list,
group=group,
)
elif backend == "mpi":
dist.all_to_all(
output_tensor_list=output_tensor_list,
input_tensor_list=input_tensor_list,
group=group,
async_op=True,
).wait()  # the async MPI collective returns a Work handle; wait so the call completes before returning
return Works(work)
elif with_gloo():
else:
assert backend == "gloo"
rank = dist.get_rank()
world_size = dist.get_world_size()
works = Works()
p2p_op_list: List[dist.P2POp] = []
for i in range(1, world_size):
send_i = (rank + i) % world_size
recv_i = (rank - i + world_size) % world_size
send_w = dist.isend(input_tensor_list[send_i], send_i, group=group)
recv_w = dist.irecv(output_tensor_list[recv_i], recv_i, group=group)
works.push(recv_w, send_w)
p2p_op_list.extend([
dist.P2POp(dist.isend, input_tensor_list[send_i], send_i, group=group),
dist.P2POp(dist.irecv, output_tensor_list[recv_i], recv_i, group=group),
])
for w in dist.batch_isend_irecv(p2p_op_list): w.wait()  # block until every posted isend/irecv has completed
output_tensor_list[rank][:] = input_tensor_list[rank]
return works
else:
backend = dist.get_backend()
raise RuntimeError(f"unsupported backend: {backend}")
# def with_nccl() -> bool:
# return dist.get_backend() == "nccl"
# def with_gloo() -> bool:
# return dist.get_backend() == "gloo"
# class Works:
# def __init__(self, *works) -> None:
# self._works = list(works)
# self._waited = False
# def wait(self) -> None:
# assert not self._waited
# for w in self._works:
# if w is None:
# continue
# w.wait()
# self._waited = True
# def push(self, *works) -> None:
# assert not self._waited
# self._works.extend(works)
# def all_to_all(
# output_tensor_list: List[Tensor],
# input_tensor_list: List[Tensor],
# group: Optional[Any] = None,
# ) -> Works:
# assert len(output_tensor_list) == len(input_tensor_list)
# if with_nccl():
# work = dist.all_to_all(
# output_tensor_list=output_tensor_list,
# input_tensor_list=input_tensor_list,
# group=group,
# async_op=True,
# )
# return Works(work)
# elif with_gloo():
# rank = dist.get_rank()
# world_size = dist.get_world_size()
# works = Works()
# for i in range(1, world_size):
# send_i = (rank + i) % world_size
# recv_i = (rank - i + world_size) % world_size
# send_w = dist.isend(input_tensor_list[send_i], send_i, group=group)
# recv_w = dist.irecv(output_tensor_list[recv_i], recv_i, group=group)
# works.push(recv_w, send_w)
# output_tensor_list[rank][:] = input_tensor_list[rank]
# return works
# else:
# backend = dist.get_backend()
# raise RuntimeError(f"unsupported backend: {backend}")
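
For illustration only (not part of this commit): a minimal sketch of how the new synchronous all_to_all could be exercised with the gloo backend once the process group is up, e.g. under torchrun. The import path, demo function name and tensor shapes are assumptions.

import torch
import torch.distributed as dist

from starrygl.core.a2a import all_to_all  # import path assumed from the package layout above

def demo_all_to_all():
    # every rank sends a length-4 tensor filled with its own rank id to every peer
    rank, world_size = dist.get_rank(), dist.get_world_size()
    send = [torch.full((4,), float(rank)) for _ in range(world_size)]
    recv = [torch.empty(4) for _ in range(world_size)]
    all_to_all(recv, send)
    # recv[i] now holds the tensor contributed by rank i
    assert all(recv[i].eq(float(i)).all() for i in range(world_size))
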
import torch
# import torch
from torch import Tensor
from typing import *
# from torch import Tensor
# from typing import *
from multiprocessing.pool import ThreadPool
from .event import Event
# from multiprocessing.pool import ThreadPool
# from .event import Event
class AsyncCopyWorkBase:
def __init__(self) -> None:
self._events: Optional[Tuple[Event, Event]] = None
# class AsyncCopyWorkBase:
# def __init__(self) -> None:
# self._events: Optional[Tuple[Event, Event]] = None
def wait(self):
raise NotImplementedError
# def wait(self):
# raise NotImplementedError
def get(self) -> Tensor:
raise NotImplementedError
# def get(self) -> Tensor:
# raise NotImplementedError
def has_events(self) -> bool:
return self._events is not None
# def has_events(self) -> bool:
# return self._events is not None
def set_events(self, start, end):
self._events = (start, end)
return self
# def set_events(self, start, end):
# self._events = (start, end)
# return self
def time_used(self) -> float:
if self._events is None:
raise RuntimeError("not found events")
start, end = self._events
return start.elapsed_time(end)
# def time_used(self) -> float:
# if self._events is None:
# raise RuntimeError("not found events")
# start, end = self._events
# return start.elapsed_time(end)
class AsyncPushWork(AsyncCopyWorkBase):
def __init__(self, data: Tensor, index: Tensor, values: Tensor) -> None:
super().__init__()
assert data.device == index.device
self.set_events(
Event(use_cuda=index.is_cuda),
Event(use_cuda=index.is_cuda),
)
self._events[0].record()
data.index_copy_(0, index, values)
self._events[1].record()
# class AsyncPushWork(AsyncCopyWorkBase):
# def __init__(self, data: Tensor, index: Tensor, values: Tensor) -> None:
# super().__init__()
# assert data.device == index.device
# self.set_events(
# Event(use_cuda=index.is_cuda),
# Event(use_cuda=index.is_cuda),
# )
# self._events[0].record()
# data.index_copy_(0, index, values)
# self._events[1].record()
def wait(self):
pass
# def wait(self):
# pass
def get(self):
pass
# def get(self):
# pass
class AsyncPullWork(AsyncCopyWorkBase):
def __init__(self, data: Tensor, index: Tensor) -> None:
super().__init__()
assert data.device == index.device
self.set_events(
Event(use_cuda=index.is_cuda),
Event(use_cuda=index.is_cuda),
)
self._events[0].record()
self._val = data.index_select(0, index)
self._events[1].record()
# class AsyncPullWork(AsyncCopyWorkBase):
# def __init__(self, data: Tensor, index: Tensor) -> None:
# super().__init__()
# assert data.device == index.device
# self.set_events(
# Event(use_cuda=index.is_cuda),
# Event(use_cuda=index.is_cuda),
# )
# self._events[0].record()
# self._val = data.index_select(0, index)
# self._events[1].record()
def wait(self):
pass
# def wait(self):
# pass
def get(self) -> Tensor:
return self._val
# def get(self) -> Tensor:
# return self._val
class AsyncOffloadWork(AsyncCopyWorkBase):
def __init__(self, handle) -> None:
super().__init__()
self._handle = handle
# class AsyncOffloadWork(AsyncCopyWorkBase):
# def __init__(self, handle) -> None:
# super().__init__()
# self._handle = handle
def wait(self):
if self._events is not None:
self._events[1].wait(torch.cuda.current_stream())
self._handle.wait()
# def wait(self):
# if self._events is not None:
# self._events[1].wait(torch.cuda.current_stream())
# self._handle.wait()
def get(self) -> Tensor:
if self._events is not None:
self._events[1].wait(torch.cuda.current_stream())
return self._handle.get()
# def get(self) -> Tensor:
# if self._events is not None:
# self._events[1].wait(torch.cuda.current_stream())
# return self._handle.get()
class AsyncCopyExecutor:
def __init__(self) -> None:
self._stream = torch.cuda.Stream()
self._executor = ThreadPool(processes=1)
# class AsyncCopyExecutor:
# def __init__(self) -> None:
# self._stream = torch.cuda.Stream()
# self._executor = ThreadPool(processes=1)
@torch.no_grad()
def async_pull(self, data: Tensor, index: Tensor) -> AsyncCopyWorkBase:
# ideally all of this code would live in its own thread: that makes timing easier and keeps the copy asynchronous
if data.device != index.device:
assert not data.is_cuda
assert index.is_cuda
# @torch.no_grad()
# def async_pull(self, data: Tensor, index: Tensor) -> AsyncCopyWorkBase:
# # ideally all of this code would live in its own thread: that makes timing easier and keeps the copy asynchronous
# if data.device != index.device:
# assert not data.is_cuda
# assert index.is_cuda
start = Event(index.is_cuda)
end = Event(index.is_cuda)
stream: Optional[torch.cuda.Stream] = self._stream
def run():
start.wait(stream)
with torch.cuda.stream(stream):
idx = index.to(data.device)
dst = torch.zeros(
size=(index.size(0),) + data.shape[1:],
dtype=data.dtype,
device=index.device,
)
dst.copy_(data[idx])
end.record()
return dst
start.record()
handle = self._executor.apply_async(run)
return AsyncOffloadWork(handle).set_events(start, end)
else:
# data and index in the same device
return AsyncPullWork(data, index)
# start = Event(index.is_cuda)
# end = Event(index.is_cuda)
# stream: Optional[torch.cuda.Stream] = self._stream
# def run():
# start.wait(stream)
# with torch.cuda.stream(stream):
# idx = index.to(data.device)
# dst = torch.zeros(
# size=(index.size(0),) + data.shape[1:],
# dtype=data.dtype,
# device=index.device,
# )
# dst.copy_(data[idx])
# end.record()
# return dst
# start.record()
# handle = self._executor.apply_async(run)
# return AsyncOffloadWork(handle).set_events(start, end)
# else:
# # data and index in the same device
# return AsyncPullWork(data, index)
@torch.no_grad()
def async_push(self, data: Tensor, index: Tensor, values: Tensor) -> AsyncCopyWorkBase:
assert index.device == values.device
if data.device != index.device:
assert not data.is_cuda
assert index.is_cuda
start = Event(index.is_cuda)
end = Event(index.is_cuda)
stream: Optional[torch.cuda.Stream] = self._stream
def run():
start.wait(stream)
with torch.cuda.stream(stream):
idx = index.to(data.device)
val = values.to(data.device)
data[idx] = val
end.record()
start.record()
handle = self._executor.apply_async(run)
return AsyncOffloadWork(handle).set_events(start, end)
else:
return AsyncPushWork(data, index, values)
# @torch.no_grad()
# def async_push(self, data: Tensor, index: Tensor, values: Tensor) -> AsyncCopyWorkBase:
# assert index.device == values.device
# if data.device != index.device:
# assert not data.is_cuda
# assert index.is_cuda
# start = Event(index.is_cuda)
# end = Event(index.is_cuda)
# stream: Optional[torch.cuda.Stream] = self._stream
# def run():
# start.wait(stream)
# with torch.cuda.stream(stream):
# idx = index.to(data.device)
# val = values.to(data.device)
# data[idx] = val
# end.record()
# start.record()
# handle = self._executor.apply_async(run)
# return AsyncOffloadWork(handle).set_events(start, end)
# else:
# return AsyncPushWork(data, index, values)
_THREAD_EXEC: Optional[AsyncCopyExecutor] = None
def get_executor() -> AsyncCopyExecutor:
global _THREAD_EXEC
if _THREAD_EXEC is None:
_THREAD_EXEC = AsyncCopyExecutor()
return _THREAD_EXEC
# _THREAD_EXEC: Optional[AsyncCopyExecutor] = None
# def get_executor() -> AsyncCopyExecutor:
# global _THREAD_EXEC
# if _THREAD_EXEC is None:
# _THREAD_EXEC = AsyncCopyExecutor()
# return _THREAD_EXEC
# import torch
# from torch import Tensor
# from typing import *
# from multiprocessing.pool import ThreadPool
# from .event import Event
# class AsyncCopyWorkBase:
# def __init__(self) -> None:
# self._events: Optional[Tuple[Event, Event]] = None
# def wait(self):
# raise NotImplementedError
# def get(self) -> Tensor:
# raise NotImplementedError
# def has_events(self) -> bool:
# return self._events is not None
# def set_events(self, start, end):
# self._events = (start, end)
# return self
# def time_used(self) -> float:
# if self._events is None:
# raise RuntimeError("not found events")
# start, end = self._events
# return start.elapsed_time(end)
# class AsyncPushWork(AsyncCopyWorkBase):
# def __init__(self, data: Tensor, index: Tensor, values: Tensor) -> None:
# super().__init__()
# assert data.device == index.device
# self.set_events(
# Event(use_cuda=index.is_cuda),
# Event(use_cuda=index.is_cuda),
# )
# self._events[0].record()
# data.index_copy_(0, index, values)
# self._events[1].record()
# def wait(self):
# pass
# def get(self):
# pass
# class AsyncPullWork(AsyncCopyWorkBase):
# def __init__(self, data: Tensor, index: Tensor) -> None:
# super().__init__()
# assert data.device == index.device
# self.set_events(
# Event(use_cuda=index.is_cuda),
# Event(use_cuda=index.is_cuda),
# )
# self._events[0].record()
# self._val = data.index_select(0, index)
# self._events[1].record()
# def wait(self):
# pass
# def get(self) -> Tensor:
# return self._val
# class AsyncOffloadWork(AsyncCopyWorkBase):
# def __init__(self, handle) -> None:
# super().__init__()
# self._handle = handle
# def wait(self):
# if self._events is not None:
# self._events[1].wait(torch.cuda.current_stream())
# self._handle.wait()
# def get(self) -> Tensor:
# if self._events is not None:
# self._events[1].wait(torch.cuda.current_stream())
# return self._handle.get()
# class AsyncCopyExecutor:
# def __init__(self) -> None:
# self._stream = torch.cuda.Stream()
# self._executor = ThreadPool(processes=1)
# @torch.no_grad()
# def async_pull(self, data: Tensor, index: Tensor) -> AsyncCopyWorkBase:
# # ideally all of this code would live in its own thread: that makes timing easier and keeps the copy asynchronous
# if data.device != index.device:
# assert not data.is_cuda
# assert index.is_cuda
# start = Event(index.is_cuda)
# end = Event(index.is_cuda)
# stream: Optional[torch.cuda.Stream] = self._stream
# def run():
# start.wait(stream)
# with torch.cuda.stream(stream):
# idx = index.to(data.device)
# dst = torch.zeros(
# size=(index.size(0),) + data.shape[1:],
# dtype=data.dtype,
# device=index.device,
# )
# dst.copy_(data[idx])
# end.record()
# return dst
# start.record()
# handle = self._executor.apply_async(run)
# return AsyncOffloadWork(handle).set_events(start, end)
# else:
# # data and index in the same device
# return AsyncPullWork(data, index)
# @torch.no_grad()
# def async_push(self, data: Tensor, index: Tensor, values: Tensor) -> AsyncCopyWorkBase:
# assert index.device == values.device
# if data.device != index.device:
# assert not data.is_cuda
# assert index.is_cuda
# start = Event(index.is_cuda)
# end = Event(index.is_cuda)
# stream: Optional[torch.cuda.Stream] = self._stream
# def run():
# start.wait(stream)
# with torch.cuda.stream(stream):
# idx = index.to(data.device)
# val = values.to(data.device)
# data[idx] = val
# end.record()
# start.record()
# handle = self._executor.apply_async(run)
# return AsyncOffloadWork(handle).set_events(start, end)
# else:
# return AsyncPushWork(data, index, values)
# class Lache:
# def __init__(self,
# cache_size: int,
# data: Tensor,
# ) -> None:
# assert not data.is_cuda, "data must be in CPU Memory"
# cache_size = (cache_size,) + data.shape[1:]
# self.fdata = data
# self.cache = torch.zeros(cache_size, dtype=data.dtype)
# self.no_idx: int = (2**62-1)*2+1
# self.cached_idx = torch.empty(cache_size, dtype=torch.long).fill_(self.no_idx)
# self.read_count = torch.zeros(data.size(0), dtype=torch.long)
# def to(self, device):
# self.cache = self.cache.to(device)
# self.cached_idx = self.cached_idx.to(device)
# self.read_count = self.read_count.to(device)
# return self
# def _push_impl(self, value: Tensor, index: Tensor):
# # self.read_count[index] += 1
# imp = torch.zeros_like(self.read_count)
# def _pull_impl(self, index: Tensor):
# assert index.device == self.cache.device
# self.read_count[index] += 1
# s = index.shape[:1] + self.cache.shape[1:]
# x = torch.empty(s, dtype=self.cache.dtype, device=self.cache.device)
# cache_mask = torch.zeros_like(self.read_count, dtype=torch.bool).index_fill_(0, self.cached_idx, 1)[index]
# cache_index = index[cache_mask]
# x[cache_mask] =
# no_cache_index = index[~cache_mask]
# rt_index = index[~lc_mask].to(self.fdata.device)
# rt_data = self.fdata[rt_index].to(index.device)
# _THREAD_EXEC: Optional[AsyncCopyExecutor] = None
# def get_executor() -> AsyncCopyExecutor:
# global _THREAD_EXEC
# if _THREAD_EXEC is None:
# _THREAD_EXEC = AsyncCopyExecutor()
# return _THREAD_EXEC
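
For illustration only: a sketch of pulling rows of a CPU-resident feature table onto the GPU and pushing results back through the thread-pool executor above. Assumes a CUDA device is available; the table size, index count and the relu step are arbitrary demo values.

import torch

def demo_async_copy():
    data = torch.randn(1000, 64)                          # feature table kept in host memory
    index = torch.randint(0, 1000, (128,), device="cuda")
    work = get_executor().async_pull(data, index)         # gather runs on a side stream/thread
    # ... other GPU work can overlap here ...
    feats = work.get()                                    # (128, 64) tensor on the CUDA device
    get_executor().async_push(data, index, feats.relu()).wait()  # write updated rows back
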
import torch
# import torch
from torch import Tensor
from typing import *
# from torch import Tensor
# from typing import *
class ShrinkData:
no_idx: int = (2**62-1)*2+1
# class ShrinkData:
# no_idx: int = (2**62-1)*2+1
def __init__(self,
src_size: int,
dst_size: int,
dst_idx: Tensor,
edge_index: Tensor,
bipartite: bool = False,
) -> None:
device = dst_idx.device
tmp = torch.empty(max(src_size, dst_size), dtype=torch.bool, device=device)
# def __init__(self,
# src_size: int,
# dst_size: int,
# dst_idx: Tensor,
# edge_index: Tensor,
# bipartite: bool = False,
# ) -> None:
# device = dst_idx.device
# tmp = torch.empty(max(src_size, dst_size), dtype=torch.bool, device=device)
tmp.fill_(0)
tmp.index_fill_(0, dst_idx, 1)
# tmp.fill_(0)
# tmp.index_fill_(0, dst_idx, 1)
edge_idx = torch.where(tmp[edge_index[1]])[0]
edge_index = edge_index[:, edge_idx]
# edge_idx = torch.where(tmp[edge_index[1]])[0]
# edge_index = edge_index[:, edge_idx]
if bipartite:
tmp.fill_(0)
tmp.index_fill_(0, edge_index[0], 1)
src_idx = torch.where(tmp)[0]
# if bipartite:
# tmp.fill_(0)
# tmp.index_fill_(0, edge_index[0], 1)
# src_idx = torch.where(tmp)[0]
imp = torch.empty(max(src_size, dst_size), dtype=torch.long, device=device)
imp[dst_idx] = torch.arange(dst_idx.size(0), dtype=torch.long, device=device)
dst = imp[edge_index[1]]
# imp = torch.empty(max(src_size, dst_size), dtype=torch.long, device=device)
# imp[dst_idx] = torch.arange(dst_idx.size(0), dtype=torch.long, device=device)
# dst = imp[edge_index[1]]
imp.fill_(self.no_idx)
imp[src_idx] = torch.arange(src_idx.size(0), dtype=torch.long, device=device)
src = imp[edge_index[0]]
edge_index = torch.vstack([src, dst])
else:
tmp.index_fill_(0, edge_index[0], 1)
tmp.index_fill_(0, dst_idx, 0)
src_idx = torch.cat([dst_idx, torch.where(tmp)[0]], dim=0)
# imp.fill_(self.no_idx)
# imp[src_idx] = torch.arange(src_idx.size(0), dtype=torch.long, device=device)
# src = imp[edge_index[0]]
# edge_index = torch.vstack([src, dst])
# else:
# tmp.index_fill_(0, edge_index[0], 1)
# tmp.index_fill_(0, dst_idx, 0)
# src_idx = torch.cat([dst_idx, torch.where(tmp)[0]], dim=0)
imp = torch.empty(max(src_size, dst_size), dtype=torch.long, device=device)
imp.fill_(self.no_idx)
imp[src_idx] = torch.arange(src_idx.size(0), dtype=torch.long, device=device)
edge_index = imp[edge_index.flatten()].view_as(edge_index)
# imp = torch.empty(max(src_size, dst_size), dtype=torch.long, device=device)
# imp.fill_(self.no_idx)
# imp[src_idx] = torch.arange(src_idx.size(0), dtype=torch.long, device=device)
# edge_index = imp[edge_index.flatten()].view_as(edge_index)
self.src_idx = src_idx
self.dst_idx = dst_idx
self.edge_idx = edge_idx
self.edge_index = edge_index
self._src_imp = imp[:src_size]
# self.src_idx = src_idx
# self.dst_idx = dst_idx
# self.edge_idx = edge_idx
# self.edge_index = edge_index
# self._src_imp = imp[:src_size]
def to(self, device):
self.src_idx = self.src_idx.to(device)
self.dst_idx = self.dst_idx.to(device)
self.edge_idx = self.edge_idx.to(device)
self.edge_index = self.edge_index.to(device)
self._src_imp = self._src_imp.to(device)
return self
# def to(self, device):
# self.src_idx = self.src_idx.to(device)
# self.dst_idx = self.dst_idx.to(device)
# self.edge_idx = self.edge_idx.to(device)
# self.edge_index = self.edge_index.to(device)
# self._src_imp = self._src_imp.to(device)
# return self
def shrink_src_val_and_idx(self, val: Tensor, idx: Tensor) -> Tuple[Tensor, Tensor]:
idx = self._src_imp[idx]
m = (idx != self.no_idx)
return val[m], idx[m]
# def shrink_src_val_and_idx(self, val: Tensor, idx: Tensor) -> Tuple[Tensor, Tensor]:
# idx = self._src_imp[idx]
# m = (idx != self.no_idx)
# return val[m], idx[m]
@property
def src_size(self) -> int:
return self.src_idx.size(0)
# @property
# def src_size(self) -> int:
# return self.src_idx.size(0)
@property
def dst_size(self) -> int:
return self.dst_idx.size(0)
# @property
# def dst_size(self) -> int:
# return self.dst_idx.size(0)
@property
def edge_size(self) -> int:
return self.edge_idx.size(0)
# @property
# def edge_size(self) -> int:
# return self.edge_idx.size(0)
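
A small worked example (illustrative, not from the commit) of ShrinkData on a homogeneous 6-node graph: edges whose destination is outside dst_idx are dropped and the surviving endpoints are renumbered into a compact local id space. The printed values follow from the implementation above.

import torch

def demo_shrink_data():
    edge_index = torch.tensor([[1, 2, 4, 5],
                               [0, 0, 3, 2]])
    dst_idx = torch.tensor([0, 3])
    sd = ShrinkData(src_size=6, dst_size=6, dst_idx=dst_idx,
                    edge_index=edge_index, bipartite=False)
    print(sd.edge_idx)     # tensor([0, 1, 2])        the edge 5 -> 2 is dropped
    print(sd.src_idx)      # tensor([0, 3, 1, 2, 4])  dst_idx first, then the extra sources
    print(sd.edge_index)   # tensor([[2, 3, 4],
                           #         [0, 0, 1]])      endpoints in local ids
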
from .distgraph import DistGraph
\ No newline at end of file
# from .distgraph import DistGraph
\ No newline at end of file
import torch
# import torch
from torch import Tensor
from typing import *
# from torch import Tensor
# from typing import *
class EData:
def __init__(self,
edge_size: Optional[int] = None,
p = None,
) -> None:
if p is None:
assert edge_size is not None
self.edge_size = edge_size
else:
assert edge_size is None
self.edge_size = p.edge_size
# class EData:
# def __init__(self,
# edge_size: Optional[int] = None,
# p = None,
# ) -> None:
# if p is None:
# assert edge_size is not None
# self.edge_size = edge_size
# else:
# assert edge_size is None
# self.edge_size = p.edge_size
self.prev_data = p
self.data: Dict[str, Tensor] = {}
# self.prev_data = p
# self.data: Dict[str, Tensor] = {}
def __getitem__(self, name: str) -> Tensor:
t = self.get(name)
if t is None:
raise ValueError(f"not found '{name}' in data")
return t
# def __getitem__(self, name: str) -> Tensor:
# t = self.get(name)
# if t is None:
# raise ValueError(f"not found '{name}' in data")
# return t
def __setitem__(self, name: str, tensor: Tensor):
if not isinstance(tensor, Tensor):
raise ValueError(f"the second parameter's type must be Tensor")
# def __setitem__(self, name: str, tensor: Tensor):
# if not isinstance(tensor, Tensor):
# raise ValueError(f"the second parameter's type must be Tensor")
if tensor.size(0) == self.edge_size:
self.data[name] = tensor
else:
raise ValueError(f"tensor's shape must match the edge_size")
# if tensor.size(0) == self.edge_size:
# self.data[name] = tensor
# else:
# raise ValueError(f"tensor's shape must match the edge_size")
def __delitem__(self, name: str) -> None:
self.pop(name)
# def __delitem__(self, name: str) -> None:
# self.pop(name)
def get(self, name: str) -> Optional[Tensor]:
p, t = self, None
while p is not None:
t = p.data.get(name)
if t is not None:
break
p = p.prev_data
return t
# def get(self, name: str) -> Optional[Tensor]:
# p, t = self, None
# while p is not None:
# t = p.data.get(name)
# if t is not None:
# break
# p = p.prev_data
# return t
def pop(self, name: str) -> Tensor:
return self.data.pop(name)
# def pop(self, name: str) -> Tensor:
# return self.data.pop(name)
def permute_(self, perm: Tensor):
p = self
while p is not None:
for key in list(p.data.keys()):
val = p.data.get(key)
p.data[key] = val[perm]
p = p.prev_data
# def permute_(self, perm: Tensor):
# p = self
# while p is not None:
# for key in list(p.data.keys()):
# val = p.data.get(key)
# p.data[key] = val[perm]
# p = p.prev_data
\ No newline at end of file
import torch
# import torch
from torch import Tensor
from typing import *
# from torch import Tensor
# from typing import *
class NData:
def __init__(self,
src_size: Optional[int] = None,
dst_size: Optional[int] = None,
p = None,
) -> None:
if p is None:
assert src_size is not None
assert dst_size is not None
self.src_size = src_size
self.dst_size = dst_size
else:
assert src_size is None
assert dst_size is None
self.src_size = p.src_size
self.dst_size = p.dst_size
# class NData:
# def __init__(self,
# src_size: Optional[int] = None,
# dst_size: Optional[int] = None,
# p = None,
# ) -> None:
# if p is None:
# assert src_size is not None
# assert dst_size is not None
# self.src_size = src_size
# self.dst_size = dst_size
# else:
# assert src_size is None
# assert dst_size is None
# self.src_size = p.src_size
# self.dst_size = p.dst_size
self.prev_data = p
self.data: Dict[str, Tensor] = {}
# self.prev_data = p
# self.data: Dict[str, Tensor] = {}
def __getitem__(self, name: str) -> Tensor:
t = self.get(name)
if t is None:
raise ValueError(f"not found '{name}' in data")
return t
# def __getitem__(self, name: str) -> Tensor:
# t = self.get(name)
# if t is None:
# raise ValueError(f"not found '{name}' in data")
# return t
def __setitem__(self, name: str, tensor: Tensor):
if not isinstance(tensor, Tensor):
raise ValueError("the second parameter's type must be Tensor")
if tensor.size(0) == self.src_size:
self.data[name] = tensor
elif tensor.size(0) == self.dst_size:
self.data[name] = tensor
else:
raise ValueError("tensor's shape must match the src_size or dst_size")
# def __setitem__(self, name: str, tensor: Tensor):
# if not isinstance(tensor, Tensor):
# raise ValueError("the second parameter's type must be Tensor")
# if tensor.size(0) == self.src_size:
# self.data[name] = tensor
# elif tensor.size(0) == self.dst_size:
# self.data[name] = tensor
# else:
# raise ValueError("tensor's shape must match the src_size or dst_size")
def __delitem__(self, name: str) -> None:
self.pop(name)
# def __delitem__(self, name: str) -> None:
# self.pop(name)
def get(self, name: str) -> Optional[Tensor]:
p, t = self, None
while p is not None:
t = p.data.get(name)
if t is not None:
break
p = p.prev_data
return t
# def get(self, name: str) -> Optional[Tensor]:
# p, t = self, None
# while p is not None:
# t = p.data.get(name)
# if t is not None:
# break
# p = p.prev_data
# return t
def pop(self, name: str) -> Tensor:
return self.data.pop(name)
# def pop(self, name: str) -> Tensor:
# return self.data.pop(name)
def permute_(self, perm: Tensor):
p = self
while p is not None:
for key in list(p.data.keys()):
val = p.data.get(key)
p.data[key] = val[perm]
p = p.prev_data
# def permute_(self, perm: Tensor):
# p = self
# while p is not None:
# for key in list(p.data.keys()):
# val = p.data.get(key)
# p.data[key] = val[perm]
# p = p.prev_data
def get_type(self, key: Union[str, Tensor]):
if isinstance(key, Tensor):
if key.size(0) == self.src_size:
return "src"
elif key.size(0) == self.dst_size:
return "dst"
else:
raise RuntimeError
t = self.__getitem__(key)
if t.size(0) == self.src_size:
return "src"
elif t.size(0) == self.dst_size:
return "dst"
else:
raise RuntimeError
\ No newline at end of file
# def get_type(self, key: Union[str, Tensor]):
# if isinstance(key, Tensor):
# if key.size(0) == self.src_size:
# return "src"
# elif key.size(0) == self.dst_size:
# return "dst"
# else:
# raise RuntimeError
# t = self.__getitem__(key)
# if t.size(0) == self.src_size:
# return "src"
# elif t.size(0) == self.dst_size:
# return "dst"
# else:
# raise RuntimeError
\ No newline at end of file
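
For illustration only: how the EData / NData containers above are meant to be used. Tensors are validated against the edge size or the src/dst sizes, and lookups fall back to the parent container; the sizes below are demo values.

import torch

def demo_graph_data():
    nd = NData(src_size=10, dst_size=4)
    nd["x"] = torch.randn(10, 16)          # accepted: first dim matches src_size
    nd["y"] = torch.randint(0, 3, (4,))    # accepted: first dim matches dst_size
    assert nd.get_type("x") == "src" and nd.get_type("y") == "dst"

    ed0 = EData(edge_size=8)
    ed0["w"] = torch.rand(8)
    ed1 = EData(p=ed0)                     # child inherits edge_size and sees parent tensors
    assert ed1["w"].shape == (8,)
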
import torch
# import torch
from torch import Tensor
from typing import *
# from torch import Tensor
# from typing import *
def init_local_edge_index(
dst_ids: Tensor,
edge_index: Tensor,
bipartite: bool = False,
) -> Tuple[Tensor, Tensor]:
max_ids = calc_max_ids(dst_ids, edge_index)
ikw = dict(dtype=torch.long, device=dst_ids.device)
xmp = torch.zeros(max_ids + 1, **ikw)
# def init_local_edge_index(
# dst_ids: Tensor,
# edge_index: Tensor,
# bipartite: bool = False,
# ) -> Tuple[Tensor, Tensor]:
# max_ids = calc_max_ids(dst_ids, edge_index)
# ikw = dict(dtype=torch.long, device=dst_ids.device)
# xmp = torch.zeros(max_ids + 1, **ikw)
# check that this is a vertex-cut partition, i.e. every edge is assigned to the partition owning its destination vertex
xmp[edge_index[1].unique()] += 0b01
xmp[dst_ids.unique()] += 0b10
if not (xmp != 0x01).all():
raise RuntimeError(f"must be vertex-cut partition graph")
# # check that this is a vertex-cut partition, i.e. every edge is assigned to the partition owning its destination vertex
# xmp[edge_index[1].unique()] += 0b01
# xmp[dst_ids.unique()] += 0b10
# if not (xmp != 0x01).all():
# raise RuntimeError(f"must be vertex-cut partition graph")
if bipartite:
src_ids = edge_index[0].unique()
else:
# assume a homogeneous graph
# src_ids equals [dst_ids, edge_index[0] except dst_ids]
xmp.fill_(0)
xmp[edge_index[0]] = 1
xmp[dst_ids] = 0
src_ids = torch.cat([dst_ids, torch.where(xmp > 0)[0]], dim=-1)
# if bipartite:
# src_ids = edge_index[0].unique()
# else:
# # assume a homogeneous graph
# # src_ids equals [dst_ids, edge_index[0] except dst_ids]
# xmp.fill_(0)
# xmp[edge_index[0]] = 1
# xmp[dst_ids] = 0
# src_ids = torch.cat([dst_ids, torch.where(xmp > 0)[0]], dim=-1)
# compute the local indices
xmp.fill_((2**62-1)*2+1)
xmp[src_ids] = torch.arange(src_ids.size(0), **ikw)
src = xmp[edge_index[0]]
# # compute the local indices
# xmp.fill_((2**62-1)*2+1)
# xmp[src_ids] = torch.arange(src_ids.size(0), **ikw)
# src = xmp[edge_index[0]]
xmp.fill_((2**62-1)*2+1)
xmp[dst_ids] = torch.arange(dst_ids.size(0), **ikw)
dst = xmp[edge_index[1]]
# xmp.fill_((2**62-1)*2+1)
# xmp[dst_ids] = torch.arange(dst_ids.size(0), **ikw)
# dst = xmp[edge_index[1]]
local_edge_index = torch.vstack([src, dst])
return src_ids, local_edge_index
# local_edge_index = torch.vstack([src, dst])
# return src_ids, local_edge_index
def calc_max_ids(*ids: Tensor) -> int:
x = [t.max().item() if t.numel() > 0 else 0 for t in ids]
return max(*x)
\ No newline at end of file
# def calc_max_ids(*ids: Tensor) -> int:
# x = [t.max().item() if t.numel() > 0 else 0 for t in ids]
# return max(*x)
\ No newline at end of file
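
A tiny worked example (illustrative) of init_local_edge_index: global destination ids plus a vertex-cut edge list go in, and the halo-extended src_ids together with the locally re-indexed edge_index come out.

import torch

def demo_local_edge_index():
    dst_ids = torch.tensor([2, 5])                  # destinations owned by this partition
    edge_index = torch.tensor([[7, 2, 5],           # every edge ends at an owned destination
                               [2, 5, 2]])
    src_ids, local_edge_index = init_local_edge_index(dst_ids, edge_index)
    print(src_ids)            # tensor([2, 5, 7])   dst_ids first, then the halo source 7
    print(local_edge_index)   # tensor([[2, 0, 1],
                              #         [0, 1, 0]]) endpoints in local numbering
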
import torch
import torch.nn as nn
from torch import Tensor
from typing import *
from .buffer import TensorBuffers
class BatchHandle:
def __init__(self,
src_ids: Tensor,
dst_size: int,
feat_buffer: TensorBuffers,
grad_buffer: TensorBuffers,
with_feat0: bool = False,
layer_id: Optional[int] = None,
target_device: Any = None,
) -> None:
self.src_ids = src_ids
self.dst_size = dst_size
self.feat_buffer = feat_buffer
self.grad_buffer = grad_buffer
self.with_feat0 = with_feat0
self.layer_id = layer_id
if target_device is None:
self.target_device = src_ids.device
else:
self.target_device = torch.device(target_device)
if with_feat0:
_ = self.feat0
@property
def dst_ids(self) -> Tensor:
return self.src_ids[:self.dst_size]
@property
def src_size(self) -> int:
return self.src_ids.size(0)
@property
def device(self) -> torch.device:
return self.src_ids.device
@property
def feat0(self) -> Tensor:
if not hasattr(self, "_feat0"):
self._feat0 = self.feat_buffer.get(0, self.src_ids).to(self.target_device)
return self._feat0
def fetch_feat(self, layer_id: Optional[int] = None) -> Tensor:
layer_id = int(self.layer_id if layer_id is None else layer_id)
return self.feat_buffer.get(layer_id, self.src_ids).to(self.target_device)
def update_feat(self, x: Tensor, layer_id: Optional[int] = None):
assert x.size(0) == self.dst_size
layer_id = int(self.layer_id if layer_id is None else layer_id)
self.feat_buffer.set(layer_id + 1, self.dst_ids, x.detach().to(self.device))
def push_and_pull(self, x: Tensor, layer_id: Optional[int] = None) -> Tensor:
assert x.size(0) == self.dst_size
layer_id = int(self.layer_id if layer_id is None else layer_id)
self.feat_buffer.set(layer_id + 1, self.dst_ids, x.detach().to(self.device))
o = self.feat_buffer.get(layer_id + 1, self.src_ids[self.dst_size:]).to(x.device)
return torch.cat([x, o], dim=0)
def fetch_grad(self, layer_id: Optional[int] = None) -> Tensor:
layer_id = int(self.layer_id if layer_id is None else layer_id)
return self.grad_buffer.get(layer_id + 1, self.dst_ids).to(self.target_device)
def accumulate_grad(self, x: Tensor, layer_id: Optional[int] = None):
assert x.size(0) == self.src_size
layer_id = int(self.layer_id if layer_id is None else layer_id)
self.grad_buffer.add(layer_id, self.src_ids, x.detach().to(self.device))
class DataLoader:
def __init__(self,
node_parts: Tensor,
feat_buffer: TensorBuffers,
grad_buffer: TensorBuffers,
edge_index: Tensor,
edge_attr: Optional[Tensor] = None,
edge_time: Optional[Tensor] = None,
node_time: Optional[Tensor] = None,
) -> None:
self.src_size = feat_buffer.src_size
self.dst_size = node_parts.size(0)
assert node_parts.size(0) <= self.src_size
assert grad_buffer.src_size == self.src_size
num_parts = node_parts.max().item() + 1
cluster, node_perm = node_parts.sort(dim=0)
node_ptr: Tensor = torch.ops.torch_sparse.ind2ptr(cluster, num_parts)
edge_parts = node_parts[edge_index[1]]
cluster, edge_perm = edge_parts.sort()
edge_ptr: Tensor = torch.ops.torch_sparse.ind2ptr(cluster, num_parts)
self.num_parts = num_parts
self.node_ptr = node_ptr
self.edge_ptr = edge_ptr
self.node_perm = node_perm
self.edge_perm = edge_perm
self.edge_index = edge_index[:,edge_perm]
if edge_attr is not None:
self.edge_attr = edge_attr[edge_perm]
if node_time is not None:
self.node_time = node_time[node_perm]
if edge_time is not None:
self.edge_time = edge_time[edge_perm]
self.feat_buffer = feat_buffer
self.grad_buffer = grad_buffer
def iter(self, batch_size: int = 1, layer_id: Optional[int] = None, seed: int = 0, filter: Optional[Callable[[Tensor], Tensor]] = None, device = None):
rnd = torch.Generator()
if seed != 0:
rnd.manual_seed(seed)
sampled = torch.randperm(self.num_parts, generator=rnd, dtype=torch.long)
s = 0
imp = torch.empty(self.src_size, dtype=torch.long)
while s < sampled.size(0):
t = min(s + batch_size, sampled.size(0))
dst_ids = []
edge_index = []
edge_attr = []
imp.zero_()
part_ids = sampled[s:t].tolist()
s = t  # advance to the next batch of partitions before any early continue below
for i in part_ids:
a, b = self.node_ptr[i:i+2].tolist()
nidx = self.node_perm[a:b]
if hasattr(self, "node_time") and filter is not None:
node_mask = filter(self.node_time[a:b])
if not node_mask.any():
continue
nidx = nidx[node_mask]
else:
node_mask = None
a, b = self.edge_ptr[i:i+2].tolist()
eidx = self.edge_index[:, a:b]
if hasattr(self, "edge_time") and filter is not None:
if node_mask is None:
edge_mask = filter(self.edge_time[a:b])
else:
imp[nidx] = i+1
edge_mask = (imp[eidx[1]] == i+1)
edge_mask &= filter(self.edge_time[a:b])
if not edge_mask.any():
continue
eidx = eidx[:, edge_mask]
else:
edge_mask = None
dst_ids.append(nidx)
edge_index.append(eidx)
if hasattr(self, "edge_attr"):
attr = self.edge_attr[a:b]
if edge_mask is None:
edge_attr.append(attr)
else:
edge_attr.append(attr[edge_mask])
if len(dst_ids) == 0 or len(edge_index) == 0:
continue
dst_ids = torch.cat(dst_ids, dim=-1)
edge_index = torch.cat(edge_index, dim=-1)
if hasattr(self, "edge_attr"):
edge_attr = torch.cat(edge_attr, dim=0)
else:
edge_attr = None
imp.zero_()
imp.index_fill_(0, edge_index[0], 1).index_fill_(0, dst_ids, 0)
src_ids = torch.cat([dst_ids, torch.where(imp > 0)[0]], dim=-1)
assert (src_ids[:dst_ids.size(0)] == dst_ids).all()
imp.fill_((2**62-1)*2+1)
imp[src_ids] = torch.arange(src_ids.size(0), dtype=torch.long)
edge_index = imp[edge_index.flatten()].view_as(edge_index)
handle = BatchHandle(
src_ids, dst_ids.size(0),
self.feat_buffer, self.grad_buffer,
with_feat0 = (layer_id is None),
layer_id = layer_id,
target_device = device,
)
edge_index = edge_index.to(device)
edge_attr = None if edge_attr is None else edge_attr.to(device)
yield handle, edge_index, edge_attr
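
For illustration only: a minimal sketch of driving the new 1-order DataLoader for a single layer. It assumes torch_sparse is installed (DataLoader uses its ind2ptr op), that DataLoader and TensorBuffers are importable from starrygl.loader (path assumed), and the sizes are toy values; a real model would replace the zero placeholder with a conv layer's output.

import torch

from starrygl.loader import DataLoader, TensorBuffers  # import path assumed

def demo_dataloader():
    num_nodes, feat_dim, hidden = 8, 16, 32
    edge_index = torch.randint(0, num_nodes, (2, 40))
    node_parts = torch.randint(0, 4, (num_nodes,))        # 4 metis-style clusters
    feat_buffer = TensorBuffers(num_nodes, (feat_dim, hidden, hidden))
    grad_buffer = TensorBuffers(num_nodes, (feat_dim, hidden, hidden))
    loader = DataLoader(node_parts, feat_buffer, grad_buffer, edge_index)

    for handle, eidx, eattr in loader.iter(batch_size=2, layer_id=0, device="cpu"):
        x = handle.fetch_feat()                           # (src_size, feat_dim) inputs
        h = torch.zeros(handle.dst_size, hidden)          # stand-in for a conv layer output
        handle.update_feat(h)                             # store layer-1 activations
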
import torch
import torch.nn as nn
from torch import Tensor
from typing import *
class TensorBuffers(nn.Module):
def __init__(self, src_size: int, channels: Tuple[Union[int, Tuple[int]]]) -> None:
super().__init__()
self.src_size = src_size
self.num_layers = len(channels)
for i, s in enumerate(channels):
s = (s,) if isinstance(s, int) else s
self.register_buffer(f"data_{i}", torch.zeros(src_size, *s), persistent=False)
def get(self, i: int, index: Optional[Tensor]) -> Tensor:
i = self._idx(i)
if index is None:
return self.get_buffer(f"data_{i}")
else:
return self.get_buffer(f"data_{i}")[index]
def set(self, i: int, index: Optional[Tensor], value: Tensor):
i = self._idx(i)
if index is None:
self.get_buffer(f"data_{i}")[:,...] = value
else:
self.get_buffer(f"data_{i}")[index] = value
def add(self, i: int, index: Optional[Tensor], value: Tensor):
i = self._idx(i)
if index is None:
self.get_buffer(f"data_{i}")[:,...] += value
else:
self.get_buffer(f"data_{i}")[index] += value
def zero_grad(self):
for name, grad in self.named_buffers(recurse=False):  # the first argument of named_buffers is a name prefix, not a filter; zero every registered buffer
grad.zero_()
def _idx(self, i: int) -> int:
assert -self.num_layers < i and i < self.num_layers
return (self.num_layers + i) % self.num_layers
\ No newline at end of file
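
For illustration only: the per-layer buffer semantics of TensorBuffers as defined above; negative layer indices wrap around and zero_grad clears every registered buffer. Sizes are demo values.

import torch

def demo_tensor_buffers():
    buf = TensorBuffers(src_size=100, channels=(16, 32, 32))  # registers data_0 .. data_2
    idx = torch.tensor([3, 7, 42])
    buf.set(0, idx, torch.randn(3, 16))     # scatter rows into layer 0
    rows = buf.get(0, idx)                  # gather them back, shape (3, 16)
    buf.add(-1, idx, torch.ones(3, 32))     # -1 wraps around to the last buffer (layer 2)
    buf.zero_grad()                         # reset every buffer to zero
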
import torch
from torch import Tensor
from typing import *
from starrygl.parallel import get_compute_device
from starrygl.core.route import Route
from starrygl.utils.partition import metis_partition
def init_local_edge_index(
dst_ids: Tensor,
edge_index: Tensor,
bipartite: bool = False,
) -> Tuple[Tensor, Tensor]:
max_ids = calc_max_ids(dst_ids, edge_index)
ikw = dict(dtype=torch.long, device=dst_ids.device)
xmp = torch.zeros(max_ids + 1, **ikw)
# check that this is a vertex-cut partition, i.e. every edge is assigned to the partition owning its destination vertex
xmp[edge_index[1].unique()] += 0b01
xmp[dst_ids.unique()] += 0b10
if not (xmp != 0x01).all():
raise RuntimeError(f"must be vertex-cut partition graph")
if bipartite:
src_ids = edge_index[0].unique()
else:
# assume a homogeneous graph
# src_ids equals [dst_ids, edge_index[0] except dst_ids]
xmp.fill_(0)
xmp[edge_index[0]] = 1
xmp[dst_ids] = 0
src_ids = torch.cat([dst_ids, torch.where(xmp > 0)[0]], dim=-1)
# compute the local indices
xmp.fill_((2**62-1)*2+1)
xmp[src_ids] = torch.arange(src_ids.size(0), **ikw)
src = xmp[edge_index[0]]
xmp.fill_((2**62-1)*2+1)
xmp[dst_ids] = torch.arange(dst_ids.size(0), **ikw)
dst = xmp[edge_index[1]]
local_edge_index = torch.vstack([src, dst])
return src_ids, local_edge_index
def calc_max_ids(*ids: Tensor) -> int:
x = [t.max().item() if t.numel() > 0 else 0 for t in ids]
return max(*x)
def collect_feat0(src_ids: Tensor, dst_ids: Tensor, feat0: Tensor):
device = get_compute_device()
route = Route(dst_ids.to(device), src_ids.to(device))
return route.forward_a2a(feat0.to(device))[0].to(feat0.device)
def local_partition_fn(dst_size: Tensor, edge_index: Tensor, num_parts: int) -> Tensor:
edge_index = edge_index[:, edge_index[0] < dst_size]
return metis_partition(edge_index, dst_size, num_parts)[0]
\ No newline at end of file
import torch
import torch.nn as nn
import torch.distributed as dist
from torch import Tensor
from typing import *
from starrygl.loader import BatchHandle
class BaseLayer(nn.Module):
def __init__(self) -> None:
super().__init__()
def forward(self, x: Tensor, edge_index: Tensor, edge_attr: Optional[Tensor] = None) -> Tensor:
return x
def update_forward(self, handle: BatchHandle, edge_index: Tensor, edge_attr: Optional[Tensor] = None):
x = handle.fetch_feat()
with torch.no_grad():
x = self.forward(x, edge_index, edge_attr)
handle.update_feat(x)
def block_backward(self, handle: BatchHandle, edge_index: Tensor, edge_attr: Optional[Tensor] = None):
x = handle.fetch_feat().requires_grad_()
g = handle.fetch_grad()
self.forward(x, edge_index, edge_attr).backward(g)
handle.accumulate_grad(x.grad)
x.grad = None
def all_reduce_grad(self):
for p in self.parameters():
if p.grad is not None:
dist.all_reduce(p.grad, op=dist.ReduceOp.SUM)
class BaseModel(nn.Module):
def __init__(self,
num_features: int,
layers: List[int],
prev_layer: bool = False,
post_layer: bool = False,
) -> None:
super().__init__()
def init_prev_layer(self) -> Tensor:
pass
def init_post_layer(self) -> Tensor:
pass
def init_conv_layer(self) -> Tensor:
pass
\ No newline at end of file
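
For illustration only: one possible way the pieces above might compose into a layer-wise pass, assuming layers is a list of BaseLayer subclasses and loader is the DataLoader defined earlier; the actual training loop is not part of this commit.

def demo_layerwise_pass(layers, loader):
    # forward: layer i reads buffered features of layer i and writes layer i+1
    for i, layer in enumerate(layers):
        for handle, eidx, eattr in loader.iter(batch_size=1, layer_id=i, device="cpu"):
            layer.update_forward(handle, eidx, eattr)
    # backward: walk layers in reverse, pushing gradients through the grad buffers
    for i, layer in reversed(list(enumerate(layers))):
        for handle, eidx, eattr in loader.iter(batch_size=1, layer_id=i, device="cpu"):
            layer.block_backward(handle, eidx, eattr)
        layer.all_reduce_grad()
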
import torch
import torch.nn as nn
import torch.nn.functional as F
# import torch
# import torch.nn as nn
# import torch.nn.functional as F
from torch_geometric.utils import softmax
from torch_scatter import scatter_sum
# from torch_geometric.utils import softmax
# from torch_scatter import scatter_sum
from torch import Tensor
from typing import *
# from torch import Tensor
# from typing import *
from starrygl.graph import DistGraph
from .gat_conv import GATConv
from .utils import ShrinkHelper
# from starrygl.graph import DistGraph
# from .gat_conv import GATConv
# from .utils import ShrinkHelper
class ShrinkGATConv(GATConv):
def __init__(self,
in_channels: int,
out_channels: int,
heads: int = 1,
concat: bool = False,
negative_slope: float = 0.2,
dropout: float = 0.0,
edge_dim: Optional[int] = None,
bias: bool = True,
**kwargs
) -> None:
super().__init__(
in_channels=in_channels,
out_channels=out_channels,
heads=heads,
concat=concat,
negative_slope=negative_slope,
dropout=dropout,
edge_dim=edge_dim,
bias=bias,
**kwargs
)
# class ShrinkGATConv(GATConv):
# def __init__(self,
# in_channels: int,
# out_channels: int,
# heads: int = 1,
# concat: bool = False,
# negative_slope: float = 0.2,
# dropout: float = 0.0,
# edge_dim: Optional[int] = None,
# bias: bool = True,
# **kwargs
# ) -> None:
# super().__init__(
# in_channels=in_channels,
# out_channels=out_channels,
# heads=heads,
# concat=concat,
# negative_slope=negative_slope,
# dropout=dropout,
# edge_dim=edge_dim,
# bias=bias,
# **kwargs
# )
def forward(self,
g: DistGraph,
sh: Optional[ShrinkHelper] = None,
dst_idx: Optional[Tensor] = None,
) -> Tensor:
if sh is None and dst_idx is None:
return super().forward(g)
# def forward(self,
# g: DistGraph,
# sh: Optional[ShrinkHelper] = None,
# dst_idx: Optional[Tensor] = None,
# ) -> Tensor:
# if sh is None and dst_idx is None:
# return super().forward(g)
if sh is None:
sh = ShrinkHelper(g, dst_idx)
# if sh is None:
# sh = ShrinkHelper(g, dst_idx)
H, C = self.heads, self.out_channels
x = g.ndata["x"]
# H, C = self.heads, self.out_channels
# x = g.ndata["x"]
src_x = x[sh.src_idx]
dst_x = x[sh.dst_idx]
edge_index = sh.edges
# src_x = x[sh.src_idx]
# dst_x = x[sh.dst_idx]
# edge_index = sh.edges
src_x = (src_x @ self.weight).view(-1, H, C)
dst_x = (dst_x @ self.weight).view(-1, H, C)
# src_x = (src_x @ self.weight).view(-1, H, C)
# dst_x = (dst_x @ self.weight).view(-1, H, C)
alpha_j = (src_x * self.att_src).sum(dim=-1)
alpha_j = alpha_j[edge_index[0]]
# alpha_i = (src_x * self.att_src).sum(dim=-1)
# alpha_j = alpha_j[edge_index[0]]
alpha_i = (dst_x * self.att_dst).sum(dim=-1)
alpha_i = alpha_i[edge_index[1]]
# alpha_i = (dst_x * self.att_dst).sum(dim=-1)
# alpha_i = alpha_i[edge_index[1]]
if self.edge_dim is not None:
edge_attr = g.edata["edge_attr"]
edge_attr = edge_attr[sh.edge_idx]
if edge_attr.dim() == 1:
edge_attr = edge_attr.view(-1, 1)
# if self.edge_dim is not None:
# edge_attr = g.edata["edge_attr"]
# edge_attr = edge_attr[sh.edge_idx]
# if edge_attr.dim() == 1:
# edge_attr = edge_attr.view(-1, 1)
e = (edge_attr @ self.lin_edge).view(-1, H, C)
alpha_e = (e * self.att_edge).sum(dim=-1)
alpha = alpha_i + alpha_j + alpha_e
else:
alpha = alpha_i + alpha_j
alpha = F.leaky_relu(alpha, self.negative_slope)
alpha = softmax(
src=alpha,
index=edge_index[1],
num_nodes=sh.dst_size,
)
# e = (edge_attr @ self.lin_edge).view(-1, H, C)
# alpha_e = (e * self.att_edge).sum(dim=-1)
# alpha = alpha_i + alpha_j + alpha_e
# else:
# alpha = alpha_i + alpha_j
# alpha = F.leaky_relu(alpha, self.negative_slope)
# alpha = softmax(
# src=alpha,
# index=edge_index[1],
# num_nodes=sh.dst_size,
# )
alpha = F.dropout(alpha, p=self.dropout, training=self.training)
# alpha = F.dropout(alpha, p=self.dropout, training=self.training)
x = src_x[edge_index[0]] * alpha.view(-1, H, 1)
x = scatter_sum(x, edge_index[1], dim=0, dim_size=sh.dst_size)
# x = x[edge_index[0]] * alpha.view(-1, H, 1)
# x = scatter_sum(x, edge_index[1], dim=0, dim_size=sh.dst_size)
if self.concat:
x = x.view(-1, H * C)
else:
x = x.mean(dim=1)
# if self.concat:
# x = x.view(-1, H * C)
# else:
# x = x.mean(dim=1)
if self.bias is not None:
x += self.bias
return x
# if self.bias is not None:
# x += self.bias
# return x
import torch
import torch.nn as nn
# import torch
# import torch.nn as nn
from torch_scatter import scatter_sum
# from torch_scatter import scatter_sum
from torch import Tensor
from typing import *
# from torch import Tensor
# from typing import *
from starrygl.graph import DistGraph
from .gcn_conv import GCNConv
from .utils import ShrinkHelper
# from starrygl.graph import DistGraph
# from .gcn_conv import GCNConv
# from .utils import ShrinkHelper
class ShrinkGCNConv(GCNConv):
def __init__(self,
in_channels: int,
out_channels: int,
bias: bool = True,
**kwargs
) -> None:
super().__init__(
in_channels=in_channels,
out_channels=out_channels,
bias=bias,
**kwargs
)
# class ShrinkGCNConv(GCNConv):
# def __init__(self,
# in_channels: int,
# out_channels: int,
# bias: bool = True,
# **kwargs
# ) -> None:
# super().__init__(
# in_channels=in_channels,
# out_channels=out_channels,
# bias=bias,
# **kwargs
# )
def forward(self,
g: DistGraph,
sh: Optional[ShrinkHelper] = None,
dst_idx: Optional[Tensor] = None,
) -> Tensor:
if sh is None and dst_idx is None:
return super().forward(g)
# def forward(self,
# g: DistGraph,
# sh: Optional[ShrinkHelper] = None,
# dst_idx: Optional[Tensor] = None,
# ) -> Tensor:
# if sh is None and dst_idx is None:
# return super().forward(g)
if sh is None:
sh = ShrinkHelper(g, dst_idx)
# if sh is None:
# sh = ShrinkHelper(g, dst_idx)
x = g.ndata["x"]
gcn_norm = g.edata["gcn_norm"].view(-1, 1)
# x = g.ndata["x"]
# gcn_norm = g.edata["gcn_norm"].view(-1, 1)
x = x[sh.src_idx]
gcn_norm = gcn_norm[sh.edge_idx]
edge_index = sh.edges
# x = x[sh.src_idx]
# gcn_norm = gcn_norm[sh.edge_idx]
# edge_index = sh.edges
x = x @ self.weight
x = x[edge_index[0]] * gcn_norm
x = scatter_sum(x, edge_index[1], dim=0, dim_size=sh.dst_size)
if self.bias is not None:
x += self.bias
return x
# x = x @ self.weight
# x = x[edge_index[0]] * gcn_norm
# x = scatter_sum(x, edge_index[1], dim=0, dim_size=sh.dst_size)
# if self.bias is not None:
# x += self.bias
# return x
import torch
import torch.nn as nn
from torch_scatter import scatter_sum
from torch import Tensor
from typing import *
from starrygl.graph import DistGraph
from .gin_conv import GINConv
from .utils import ShrinkHelper
class ShrinkGINConv(GINConv):
def __init__(self,
in_channels: int,
out_channels: int,
mlp_channels: Optional[int] = None,
eps: float = 0,
train_eps: bool = False,
**kwargs
) -> None:
super().__init__(
in_channels=in_channels,
out_channels=out_channels,
mlp_channels=mlp_channels,
eps=eps,
train_eps=train_eps,
**kwargs
)
def forward(self,
g: DistGraph,
sh: Optional[ShrinkHelper] = None,
dst_idx: Optional[Tensor] = None,
) -> Tensor:
if sh is None and dst_idx is None:
return super().forward(g)
# import torch
# import torch.nn as nn
# from torch_scatter import scatter_sum
# from torch import Tensor
# from typing import *
# from starrygl.graph import DistGraph
# from .gin_conv import GINConv
# from .utils import ShrinkHelper
# class ShrinkGINConv(GINConv):
# def __init__(self,
# in_channels: int,
# out_channels: int,
# mlp_channels: Optional[int] = None,
# eps: float = 0,
# train_eps: bool = False,
# **kwargs
# ) -> None:
# super().__init__(
# in_channels=in_channels,
# out_channels=out_channels,
# mlp_channels=mlp_channels,
# eps=eps,
# train_eps=train_eps,
# **kwargs
# )
# def forward(self,
# g: DistGraph,
# sh: Optional[ShrinkHelper] = None,
# dst_idx: Optional[Tensor] = None,
# ) -> Tensor:
# if sh is None and dst_idx is None:
# return super().forward(g)
if sh is None:
sh = ShrinkHelper(g, dst_idx)
# if sh is None:
# sh = ShrinkHelper(g, dst_idx)
x = g.ndata["x"]
# x = g.ndata["x"]
# sh = ShrinkHelper(g, dst_idx)
src_x = x[sh.src_idx]
dst_x = x[sh.dst_idx]
edge_index = sh.edges
# src_x = x[sh.src_idx]
# dst_x = x[sh.dst_idx]
# edge_index = sh.edges
z = scatter_sum(src_x[edge_index[0]], index=edge_index[1], dim=0, dim_size=sh.dst_size)
x = z + (1 + self.eps) * dst_x
return self.nn(x)
\ No newline at end of file
# z = scatter_sum(src_x[edge_index[0]], index=edge_index[1], dim=0, dim_size=sh.dst_size)
# x = z + (1 + self.eps) * dst_x
# return self.nn(x)
\ No newline at end of file
@@ -3,11 +3,10 @@ import torch.nn as nn
import torch.distributed as dist
import os
from typing import *
from .degree import compute_degree, compute_gcn_norm
from .degree import compute_in_degree, compute_out_degree, compute_gcn_norm
from .sync_bn import SyncBatchNorm
from ..core import with_gloo, with_nccl
from ..core.route import get_executor
def convert_parallel_model(
net: nn.Module,
@@ -16,28 +15,38 @@ def convert_parallel_model(
net = SyncBatchNorm.convert_sync_batchnorm(net)
net = nn.parallel.DistributedDataParallel(net,
find_unused_parameters=find_unused_parameters,
broadcast_buffers=False,
# broadcast_buffers=False,
)
for name, buffer in net.named_buffers():
if name.endswith("last_embd"):
continue
if name.endswith("last_w"):
continue
dist.broadcast(buffer, src=0)
# for name, buffer in net.named_buffers():
# if name.endswith("last_embd"):
# continue
# if name.endswith("last_w"):
# continue
# dist.broadcast(buffer, src=0)
return net
def init_process_group(backend: str = "gloo") -> torch.device:
local_rank = os.getenv("LOCAL_RANK") or os.getenv("OMPI_COMM_WORLD_LOCAL_RANK")
if local_rank is not None:
local_rank = int(local_rank)
dist.init_process_group(backend)
if backend == "nccl":
local_rank = os.getenv("LOCAL_RANK")
if backend == "nccl" or backend == "mpi":
if local_rank is None:
device = torch.device(f"cuda:{dist.get_rank()}")
else:
local_rank = int(local_rank)
device = torch.device(f"cuda:{local_rank}")
torch.cuda.set_device(device)
else:
device = torch.device("cpu")
get_executor() # initialize route executor
return device
\ No newline at end of file
global _COMPUTE_DEVICE
_COMPUTE_DEVICE = device
return device
_COMPUTE_DEVICE = torch.device("cpu")
def get_compute_device() -> torch.device:
global _COMPUTE_DEVICE
return _COMPUTE_DEVICE
\ No newline at end of file
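
For illustration only: bootstrapping a worker with the new init_process_group / get_compute_device pair and wrapping a model, assuming the script is launched by torchrun and a CUDA build of PyTorch (get_executor() creates a CUDA stream). The model is a throwaway example.

import torch.nn as nn

def demo_setup():
    device = init_process_group(backend="gloo")             # cpu device; nccl/mpi select a GPU
    model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 8)).to(device)
    model = convert_parallel_model(model)                    # SyncBatchNorm conversion + DDP wrap
    assert get_compute_device() == device
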
@@ -4,29 +4,30 @@ from torch import Tensor
from typing import *
from torch_scatter import scatter_sum
from starrygl.core.route import Route
from ..graph.distgraph import DistGraph
# from ..core.gather import Gather
def compute_degree(g: DistGraph) -> Tuple[Tensor, Tensor]:
edge_index = g.edge_index
x = torch.ones(
edge_index.size(1),
dtype=torch.long,
device=edge_index.device,
)
in_degree = scatter_sum(x, edge_index[1], dim=0, dim_size=g.dst_size) # (g.dst_size,)
out_degree = scatter_sum(x, edge_index[0], dim=0, dim_size=g.src_size)
out_degree = g.gather(out_degree, direct="src_to_dst") # (g.dst_size,)
def compute_in_degree(edge_index: Tensor, route: Route) -> Tensor:
dst_size = route.src_size
x = torch.ones(edge_index.size(1), dtype=torch.long, device=edge_index.device)
in_degree = g.gather(in_degree, direct="dst_to_src") # (g.src_size,)
out_degree = g.gather(out_degree, direct="dst_to_src") # (g.src_size,)
return in_degree, out_degree
in_deg = scatter_sum(x, edge_index[1], dim=0, dim_size=dst_size)
in_deg, _ = route.forward_a2a(in_deg)
return in_deg
def compute_out_degree(edge_index: Tensor, route: Route) -> Tensor:
src_size = route.dst_size
x = torch.ones(edge_index.size(1), dtype=torch.long, device=edge_index.device)
out_deg = scatter_sum(x, edge_index[0], dim=0, dim_size=src_size)
out_deg, _ = route.backward_a2a(out_deg)
out_deg, _ = route.forward_a2a(out_deg)
return out_deg
def compute_gcn_norm(g: DistGraph) -> Tensor:
edge_index = g.edge_index
in_deg, out_deg = compute_degree(g)
def compute_gcn_norm(edge_index: Tensor, route: Route) -> Tensor:
in_deg = compute_in_degree(edge_index, route)
out_deg = compute_out_degree(edge_index, route)
a = in_deg[edge_index[0]].pow(-0.5)
b = out_deg[edge_index[0]].pow(-0.5)
......
from .functional import train_epoch, eval_epoch
# from .functional import train_epoch, eval_epoch
from .partition import partition_load, partition_save, partition_data
from .printer import sync_print, main_print
from .metrics import all_reduce_loss, accuracy
\ No newline at end of file
import torch
import torch.nn as nn
# import torch
# import torch.nn as nn
from torch import Tensor
from ..graph import DistGraph
# from torch import Tensor
# from ..graph import DistGraph
from typing import *
from .metrics import *
# from typing import *
# from .metrics import *
def train_epoch(
model: nn.Module,
opt: torch.optim.Optimizer,
g: DistGraph,
mask: Optional[Tensor] = None,
) -> float:
model.train()
criterion = nn.CrossEntropyLoss()
# def train_epoch(
# model: nn.Module,
# opt: torch.optim.Optimizer,
# g: DistGraph,
# mask: Optional[Tensor] = None,
# ) -> float:
# model.train()
# criterion = nn.CrossEntropyLoss()
pred: Tensor = model(g)
targ: Tensor = g.ndata["y"]
# pred: Tensor = model(g)
# targ: Tensor = g.ndata["y"]
if mask is not None:
pred = pred[mask]
targ = targ[mask]
# if mask is not None:
# pred = pred[mask]
# targ = targ[mask]
loss: Tensor = criterion(pred, targ)
# loss: Tensor = criterion(pred, targ)
opt.zero_grad()
loss.backward()
opt.step()
# opt.zero_grad()
# loss.backward()
# opt.step()
with torch.no_grad():
train_loss = all_reduce_loss(loss, targ.size(0))
train_acc = accuracy(pred.argmax(dim=-1), targ)
return train_loss, train_acc
# with torch.no_grad():
# train_loss = all_reduce_loss(loss, targ.size(0))
# train_acc = accuracy(pred.argmax(dim=-1), targ)
# return train_loss, train_acc
@torch.no_grad()
def eval_epoch(
model: nn.Module,
g: DistGraph,
mask: Optional[Tensor] = None,
) -> Tuple[float, float]:
model.eval()
criterion = nn.CrossEntropyLoss()
# @torch.no_grad()
# def eval_epoch(
# model: nn.Module,
# g: DistGraph,
# mask: Optional[Tensor] = None,
# ) -> Tuple[float, float]:
# model.eval()
# criterion = nn.CrossEntropyLoss()
pred: Tensor = model(g)
targ: Tensor = g.ndata["y"]
# pred: Tensor = model(g)
# targ: Tensor = g.ndata["y"]
if mask is not None:
pred = pred[mask]
targ = targ[mask]
# if mask is not None:
# pred = pred[mask]
# targ = targ[mask]
loss = criterion(pred, targ)
# loss = criterion(pred, targ)
eval_loss = all_reduce_loss(loss, targ.size(0))
eval_acc = accuracy(pred.argmax(dim=-1), targ)
return eval_loss, eval_acc
\ No newline at end of file
# eval_loss = all_reduce_loss(loss, targ.size(0))
# eval_acc = accuracy(pred.argmax(dim=-1), targ)
# return eval_loss, eval_acc
\ No newline at end of file