Commit 3f94b191 by Wenjie Huang

add 1-order dataloader

parent a4c4cadb
from .a2a import all_to_all, with_gloo, with_nccl
# from .a2a import all_to_all, with_gloo, with_nccl
# from .cache import EmbeddingCache
# from .gather import Gather
# from .route import Route, GatherWork
from .route import Route, RouteWorkBase
from .cache import MessageCache, NodeProbe
\ No newline at end of file
# from .cache import MessageCache, NodeProbe
\ No newline at end of file
@@ -4,58 +4,95 @@ import torch.distributed as dist
from torch import Tensor
from typing import *
def with_nccl() -> bool:
return dist.get_backend() == "nccl"
def with_gloo() -> bool:
return dist.get_backend() == "gloo"
class Works:
def __init__(self, *works) -> None:
self._works = list(works)
self._waited = False
def wait(self) -> None:
assert not self._waited
for w in self._works:
if w is None:
continue
w.wait()
self._waited = True
def push(self, *works) -> None:
assert not self._waited
self._works.extend(works)
def all_to_all(
output_tensor_list: List[Tensor],
input_tensor_list: List[Tensor],
group: Optional[Any] = None,
) -> Works:
):
assert len(output_tensor_list) == len(input_tensor_list)
if with_nccl():
work = dist.all_to_all(
backend = dist.get_backend()
if backend == "nccl":
dist.all_to_all(
output_tensor_list=output_tensor_list,
input_tensor_list=input_tensor_list,
group=group,
)
elif backend == "mpi":
dist.all_to_all(
output_tensor_list=output_tensor_list,
input_tensor_list=input_tensor_list,
group=group,
async_op=True,
)
return Works(work)
elif with_gloo():
else:
assert backend == "gloo"
rank = dist.get_rank()
world_size = dist.get_world_size()
works = Works()
p2p_op_list: List[dist.P2POp] = []
for i in range(1, world_size):
send_i = (rank + i) % world_size
recv_i = (rank - i + world_size) % world_size
send_w = dist.isend(input_tensor_list[send_i], send_i, group=group)
recv_w = dist.irecv(output_tensor_list[recv_i], recv_i, group=group)
works.push(recv_w, send_w)
p2p_op_list.extend([
dist.P2POp(dist.isend, input_tensor_list[send_i], send_i, group=group),
dist.P2POp(dist.irecv, output_tensor_list[recv_i], recv_i, group=group),
])
for req in dist.batch_isend_irecv(p2p_op_list):
    req.wait()  # wait for the batched sends/receives before the receive buffers are read
output_tensor_list[rank][:] = input_tensor_list[rank]
return works
else:
backend = dist.get_backend()
raise RuntimeError(f"unsupported backend: {backend}")
# def with_nccl() -> bool:
# return dist.get_backend() == "nccl"
# def with_gloo() -> bool:
# return dist.get_backend() == "gloo"
# class Works:
# def __init__(self, *works) -> None:
# self._works = list(works)
# self._waited = False
# def wait(self) -> None:
# assert not self._waited
# for w in self._works:
# if w is None:
# continue
# w.wait()
# self._waited = True
# def push(self, *works) -> None:
# assert not self._waited
# self._works.extend(works)
# def all_to_all(
# output_tensor_list: List[Tensor],
# input_tensor_list: List[Tensor],
# group: Optional[Any] = None,
# ) -> Works:
# assert len(output_tensor_list) == len(input_tensor_list)
# if with_nccl():
# work = dist.all_to_all(
# output_tensor_list=output_tensor_list,
# input_tensor_list=input_tensor_list,
# group=group,
# async_op=True,
# )
# return Works(work)
# elif with_gloo():
# rank = dist.get_rank()
# world_size = dist.get_world_size()
# works = Works()
# for i in range(1, world_size):
# send_i = (rank + i) % world_size
# recv_i = (rank - i + world_size) % world_size
# send_w = dist.isend(input_tensor_list[send_i], send_i, group=group)
# recv_w = dist.irecv(output_tensor_list[recv_i], recv_i, group=group)
# works.push(recv_w, send_w)
# output_tensor_list[rank][:] = input_tensor_list[rank]
# return works
# else:
# backend = dist.get_backend()
# raise RuntimeError(f"unsupported backend: {backend}")
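# A minimal, illustrative sketch of driving the all_to_all wrapper above under
# the gloo backend. The two-process spawn scaffolding, address and port are
# assumptions for the example only: each rank sends a tensor filled with its
# own rank id to every peer and receives one tensor back from every peer.
import os
import torch
import torch.multiprocessing as mp

def _a2a_demo(rank: int, world_size: int):
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29500"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)
    send = [torch.full((4,), float(rank)) for _ in range(world_size)]
    recv = [torch.zeros(4) for _ in range(world_size)]
    all_to_all(recv, send)  # gloo path: batch_isend_irecv to peers + local copy
    assert all(recv[i].eq(float(i)).all() for i in range(world_size))
    dist.destroy_process_group()

if __name__ == "__main__":
    mp.spawn(_a2a_demo, args=(2,), nprocs=2)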
import torch
# import torch
from torch import Tensor
from typing import *
# from torch import Tensor
# from typing import *
from multiprocessing.pool import ThreadPool
from .event import Event
# from multiprocessing.pool import ThreadPool
# from .event import Event
class AsyncCopyWorkBase:
def __init__(self) -> None:
self._events: Optional[Tuple[Event, Event]] = None
# class AsyncCopyWorkBase:
# def __init__(self) -> None:
# self._events: Optional[Tuple[Event, Event]] = None
def wait(self):
raise NotImplementedError
# def wait(self):
# raise NotImplementedError
def get(self) -> Tensor:
raise NotImplementedError
# def get(self) -> Tensor:
# raise NotImplementedError
def has_events(self) -> bool:
return self._events is not None
# def has_events(self) -> bool:
# return self._events is not None
def set_events(self, start, end):
self._events = (start, end)
return self
# def set_events(self, start, end):
# self._events = (start, end)
# return self
def time_used(self) -> float:
if self._events is None:
raise RuntimeError("not found events")
start, end = self._events
return start.elapsed_time(end)
# def time_used(self) -> float:
# if self._events is None:
# raise RuntimeError("not found events")
# start, end = self._events
# return start.elapsed_time(end)
class AsyncPushWork(AsyncCopyWorkBase):
def __init__(self, data: Tensor, index: Tensor, values: Tensor) -> None:
super().__init__()
assert data.device == index.device
self.set_events(
Event(use_cuda=index.is_cuda),
Event(use_cuda=index.is_cuda),
)
self._events[0].record()
data.index_copy_(0, index, values)
self._events[1].record()
# class AsyncPushWork(AsyncCopyWorkBase):
# def __init__(self, data: Tensor, index: Tensor, values: Tensor) -> None:
# super().__init__()
# assert data.device == index.device
# self.set_events(
# Event(use_cuda=index.is_cuda),
# Event(use_cuda=index.is_cuda),
# )
# self._events[0].record()
# data.index_copy_(0, index, values)
# self._events[1].record()
def wait(self):
pass
# def wait(self):
# pass
def get(self):
pass
# def get(self):
# pass
class AsyncPullWork(AsyncCopyWorkBase):
def __init__(self, data: Tensor, index: Tensor) -> None:
super().__init__()
assert data.device == index.device
self.set_events(
Event(use_cuda=index.is_cuda),
Event(use_cuda=index.is_cuda),
)
self._events[0].record()
self._val = data.index_select(0, index)
self._events[1].record()
# class AsyncPullWork(AsyncCopyWorkBase):
# def __init__(self, data: Tensor, index: Tensor) -> None:
# super().__init__()
# assert data.device == index.device
# self.set_events(
# Event(use_cuda=index.is_cuda),
# Event(use_cuda=index.is_cuda),
# )
# self._events[0].record()
# self._val = data.index_select(0, index)
# self._events[1].record()
def wait(self):
pass
# def wait(self):
# pass
def get(self) -> Tensor:
return self._val
# def get(self) -> Tensor:
# return self._val
class AsyncOffloadWork(AsyncCopyWorkBase):
def __init__(self, handle) -> None:
super().__init__()
self._handle = handle
# class AsyncOffloadWork(AsyncCopyWorkBase):
# def __init__(self, handle) -> None:
# super().__init__()
# self._handle = handle
def wait(self):
if self._events is not None:
self._events[1].wait(torch.cuda.current_stream())
self._handle.wait()
# def wait(self):
# if self._events is not None:
# self._events[1].wait(torch.cuda.current_stream())
# self._handle.wait()
def get(self) -> Tensor:
if self._events is not None:
self._events[1].wait(torch.cuda.current_stream())
return self._handle.get()
# def get(self) -> Tensor:
# if self._events is not None:
# self._events[1].wait(torch.cuda.current_stream())
# return self._handle.get()
class AsyncCopyExecutor:
def __init__(self) -> None:
self._stream = torch.cuda.Stream()
self._executor = ThreadPool(processes=1)
# class AsyncCopyExecutor:
# def __init__(self) -> None:
# self._stream = torch.cuda.Stream()
# self._executor = ThreadPool(processes=1)
@torch.no_grad()
def async_pull(self, data: Tensor, index: Tensor) -> AsyncCopyWorkBase:
# ideally all of this would run in a dedicated thread: it makes timing easier and keeps the copy asynchronous
if data.device != index.device:
assert not data.is_cuda
assert index.is_cuda
# @torch.no_grad()
# def async_pull(self, data: Tensor, index: Tensor) -> AsyncCopyWorkBase:
# # ideally all of this would run in a dedicated thread: it makes timing easier and keeps the copy asynchronous
# if data.device != index.device:
# assert not data.is_cuda
# assert index.is_cuda
start = Event(index.is_cuda)
end = Event(index.is_cuda)
stream: Optional[torch.cuda.Stream] = self._stream
def run():
start.wait(stream)
with torch.cuda.stream(stream):
idx = index.to(data.device)
dst = torch.zeros(
size=(index.size(0),) + data.shape[1:],
dtype=data.dtype,
device=index.device,
)
dst.copy_(data[idx])
end.record()
return dst
start.record()
handle = self._executor.apply_async(run)
return AsyncOffloadWork(handle).set_events(start, end)
else:
# data and index in the same device
return AsyncPullWork(data, index)
# start = Event(index.is_cuda)
# end = Event(index.is_cuda)
# stream: Optional[torch.cuda.Stream] = self._stream
# def run():
# start.wait(stream)
# with torch.cuda.stream(stream):
# idx = index.to(data.device)
# dst = torch.zeros(
# size=(index.size(0),) + data.shape[1:],
# dtype=data.dtype,
# device=index.device,
# )
# dst.copy_(data[idx])
# end.record()
# return dst
# start.record()
# handle = self._executor.apply_async(run)
# return AsyncOffloadWork(handle).set_events(start, end)
# else:
# # data and index in the same device
# return AsyncPullWork(data, index)
@torch.no_grad()
def async_push(self, data: Tensor, index: Tensor, values: Tensor) -> AsyncCopyWorkBase:
assert index.device == values.device
if data.device != index.device:
assert not data.is_cuda
assert index.is_cuda
start = Event(index.is_cuda)
end = Event(index.is_cuda)
stream: Optional[torch.cuda.Stream] = self._stream
def run():
start.wait(stream)
with torch.cuda.stream(stream):
idx = index.to(data.device)
val = values.to(data.device)
data[idx] = val
end.record()
start.record()
handle = self._executor.apply_async(run)
return AsyncOffloadWork(handle).set_events(start, end)
else:
return AsyncPushWork(data, index, values)
# @torch.no_grad()
# def async_push(self, data: Tensor, index: Tensor, values: Tensor) -> AsyncCopyWorkBase:
# assert index.device == values.device
# if data.device != index.device:
# assert not data.is_cuda
# assert index.is_cuda
# start = Event(index.is_cuda)
# end = Event(index.is_cuda)
# stream: Optional[torch.cuda.Stream] = self._stream
# def run():
# start.wait(stream)
# with torch.cuda.stream(stream):
# idx = index.to(data.device)
# val = values.to(data.device)
# data[idx] = val
# end.record()
# start.record()
# handle = self._executor.apply_async(run)
# return AsyncOffloadWork(handle).set_events(start, end)
# else:
# return AsyncPushWork(data, index, values)
_THREAD_EXEC: Optional[AsyncCopyExecutor] = None
def get_executor() -> AsyncCopyExecutor:
global _THREAD_EXEC
if _THREAD_EXEC is None:
_THREAD_EXEC = AsyncCopyExecutor()
return _THREAD_EXEC
# _THREAD_EXEC: Optional[AsyncCopyExecutor] = None
# def get_executor() -> AsyncCopyExecutor:
# global _THREAD_EXEC
# if _THREAD_EXEC is None:
# _THREAD_EXEC = AsyncCopyExecutor()
# return _THREAD_EXEC
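# Usage sketch for the asynchronous copy executor above. It assumes a CUDA
# device (the executor allocates a torch.cuda.Stream); the table size, row
# count and names below are illustrative. async_pull gathers rows of a
# host-resident table onto the GPU in a worker thread while the caller keeps
# computing, and get() blocks until the rows are ready.
import torch

if torch.cuda.is_available():
    feat = torch.randn(1000, 64).pin_memory()             # host-side feature table
    idx = torch.randint(0, 1000, (128,), device="cuda")   # rows needed on the GPU
    work = get_executor().async_pull(feat, idx)            # returns an AsyncOffloadWork
    other = torch.randn(128, 64, device="cuda").relu()     # overlap other computation
    rows = work.get()                                       # (128, 64) tensor on cuda
    print(rows.shape, other.shape)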
import torch
import torch.nn as nn
# import torch
# import torch.nn as nn
from torch.utils import hooks
# from torch.utils import hooks
from torch import Tensor
from typing import *
# from torch import Tensor
# from typing import *
from .route import Route, RouteWorkBase
from .shrink import ShrinkData
from .acopy import get_executor as get_acopy_executor, AsyncCopyWorkBase
from .route import get_executor as get_route_executor
# from .route import Route, RouteWorkBase
# from .shrink import ShrinkData
# from .acopy import get_executor as get_acopy_executor, AsyncCopyWorkBase
# from .route import get_route_executor
class ProbeLayer:
def __init__(self, layer_id, probe_obj) -> None:
self.layer_id: int = layer_id
self.probe_obj: NodeProbe = probe_obj
self.last_w = torch.ones(self.probe_obj.num_nodes, dtype=torch.float32)
self._val_hook_handle: Optional[hooks.RemovableHandle] = None
# class ProbeLayer:
# def __init__(self, layer_id, probe_obj) -> None:
# self.layer_id: int = layer_id
# self.probe_obj: NodeProbe = probe_obj
# self.last_w = torch.ones(self.probe_obj.num_nodes, dtype=torch.float32)
# self._val_hook_handle: Optional[hooks.RemovableHandle] = None
def to(self, device):
self.last_w = self.last_w.to(device)
return self
# def to(self, device):
# self.last_w = self.last_w.to(device)
# return self
def remove_val_hook(self):
if self._val_hook_handle is not None:
self._val_hook_handle.remove()
self._val_hook_handle = None
# def remove_val_hook(self):
# if self._val_hook_handle is not None:
# self._val_hook_handle.remove()
# self._val_hook_handle = None
def register_val_hook(self, val: Tensor, idx: Optional[Tensor]):
assert self._val_hook_handle is None, "cannot call register_val_hook() twice"
# def register_val_hook(self, val: Tensor, idx: Optional[Tensor]):
# assert self._val_hook_handle is None, "cannot call register_val_hook() twice"
def hook(grad: Tensor):
from starrygl.utils.printer import main_print
import time
self.probe_obj.update_sample_w(self.last_w, grad, idx)
self.remove_val_hook()
self._backward_sample(False)
main_print(self.layer_id, grad.size(0), idx.size(0))
# def hook(grad: Tensor):
# from starrygl.utils.printer import main_print
# import time
# self.probe_obj.update_sample_w(self.last_w, grad, idx)
# self.remove_val_hook()
# self._backward_sample(False)
# main_print(self.layer_id, grad.size(0), idx.size(0))
self._val_hook_handle = val.register_hook(hook)  # keep the handle on _val_hook_handle so remove_val_hook() can release it
# self._val_hook_handle = val.register_hook(hook)  # keep the handle on _val_hook_handle so remove_val_hook() can release it
def warmup_sample(self):
assert self._is_last_layer()
for probe_layer in self.probe_obj.layers[::-1]:
probe_layer._backward_sample()
# def warmup_sample(self):
# assert self._is_last_layer()
# for probe_layer in self.probe_obj.layers[::-1]:
# probe_layer._backward_sample()
# max_norm = max(probe_layer.last_w.max().item(), 1.0)
# probe_layer.last_w[probe_layer.last_w == 1.0] = max_norm
# # max_norm = max(probe_layer.last_w.max().item(), 1.0)
# # probe_layer.last_w[probe_layer.last_w == 1.0] = max_norm
def _backward_sample(self, x: bool = True):
dst_idx = self._collect_prev_dst_idx()
if dst_idx is None:
return
# def _backward_sample(self, x: bool = True):
# dst_idx = self._collect_prev_dst_idx()
# if dst_idx is None:
# return
for cache in self.probe_obj.hook_caches:
next_cache_layer = self._next_cache_layer(cache)
shrink = next_cache_layer.set_shrink_data(dst_idx)
# for cache in self.probe_obj.hook_caches:
# next_cache_layer = self._next_cache_layer(cache)
# shrink = next_cache_layer.set_shrink_data(dst_idx)
src_idx = shrink.src_idx
# src_idx = shrink.src_idx
work = get_acopy_executor().async_pull(next_cache_layer.data, src_idx)
next_cache_layer.sync_pull_work(work)
# work = get_acopy_executor().async_pull(next_cache_layer.data, src_idx)
# next_cache_layer.sync_pull_work(work)
if not self._is_first_layer() and x:
next_cache_layer.sync_backward_sample_work(src_idx)
# if not self._is_first_layer() and x:
# next_cache_layer.sync_backward_sample_work(src_idx)
def _collect_prev_dst_idx(self) -> Optional[Tensor]:
if len(self.probe_obj.hook_caches) == 0:
return None
# def _collect_prev_dst_idx(self) -> Optional[Tensor]:
# if len(self.probe_obj.hook_caches) == 0:
# return None
if self._is_last_layer():
return self._wrapped_sample(self.probe_obj.num_samples, None)
# if self._is_last_layer():
# return self._wrapped_sample(self.probe_obj.num_samples, None)
if len(self.probe_obj.hook_caches) == 1:
for cache in self.probe_obj.hook_caches:
prev_cache_layer = self._prev_cache_layer(cache)
dst_idx = prev_cache_layer.sync_backward_sample_work()
else:
tmp = torch.zeros(self.probe_obj.num_nodes, dtype=torch.bool, device=self.last_w.device)
for cache in self.probe_obj.hook_caches:
prev_cache_layer = self._prev_cache_layer(cache)
prev_idx = prev_cache_layer.sync_backward_sample_work()
if prev_idx is not None:
tmp.index_fill_(0, prev_idx, 1)
dst_idx = torch.where(tmp)[0]
if dst_idx.size(0) == 0:
dst_idx = None
# if len(self.probe_obj.hook_caches) == 1:
# for cache in self.probe_obj.hook_caches:
# prev_cache_layer = self._prev_cache_layer(cache)
# dst_idx = prev_cache_layer.sync_backward_sample_work()
# else:
# tmp = torch.zeros(self.probe_obj.num_nodes, dtype=torch.bool, device=self.last_w.device)
# for cache in self.probe_obj.hook_caches:
# prev_cache_layer = self._prev_cache_layer(cache)
# prev_idx = prev_cache_layer.sync_backward_sample_work()
# if prev_idx is not None:
# tmp.index_fill_(0, prev_idx, 1)
# dst_idx = torch.where(tmp)[0]
# if dst_idx.size(0) == 0:
# dst_idx = None
if dst_idx is None:
return None
return self._wrapped_sample(self.probe_obj.num_samples, dst_idx)
# if dst_idx is None:
# return None
# return self._wrapped_sample(self.probe_obj.num_samples, dst_idx)
def _prev_cache_layer(self, cache):
cache: MessageCache = cache
i = self.layer_id + 1
if i < self.probe_obj.num_layers:
return cache.layers[i]
return None
# def _prev_cache_layer(self, cache):
# cache: MessageCache = cache
# i = self.layer_id + 1
# if i < self.probe_obj.num_layers:
# return cache.layers[i]
# return None
def _next_cache_layer(self, cache):
cache: MessageCache = cache
return cache.layers[self.layer_id]
# def _next_cache_layer(self, cache):
# cache: MessageCache = cache
# return cache.layers[self.layer_id]
def _is_last_layer(self):
return self.layer_id + 1 >= self.probe_obj.num_layers
# def _is_last_layer(self):
# return self.layer_id + 1 >= self.probe_obj.num_layers
def _is_first_layer(self):
return self.layer_id == 0
# def _is_first_layer(self):
# return self.layer_id == 0
def _wrapped_sample(self, k: int, idx: Optional[Tensor]) -> Optional[Tensor]:
w = self.last_w
if idx is None:
return self.probe_obj.sample(w, k) if 0 < k and k < self.probe_obj.num_nodes else None
if 0 < k and k < idx.size(0):
t = self.probe_obj.sample(w[idx], k)
return idx[t]
else:
return idx
# def _wrapped_sample(self, k: int, idx: Optional[Tensor]) -> Optional[Tensor]:
# w = self.last_w
# if idx is None:
# return self.probe_obj.sample(w, k) if 0 < k and k < self.probe_obj.num_nodes else None
# if 0 < k and k < idx.size(0):
# t = self.probe_obj.sample(w[idx], k)
# return idx[t]
# else:
# return idx
class NodeProbe:
def __init__(self,
num_nodes: int,
num_layers: int,
num_samples: int = 0,
p: str = "fro",
dim: int = -1,
beta: float = 1.0,
) -> None:
super().__init__()
self.num_nodes = num_nodes
self.num_layers = num_layers
self.num_samples = num_samples
self.p = p
self.dim = dim
self.beta = beta
# class NodeProbe:
# def __init__(self,
# num_nodes: int,
# num_layers: int,
# num_samples: int = 0,
# p: str = "fro",
# dim: int = -1,
# beta: float = 1.0,
# ) -> None:
# super().__init__()
# self.num_nodes = num_nodes
# self.num_layers = num_layers
# self.num_samples = num_samples
# self.p = p
# self.dim = dim
# self.beta = beta
self.layers = [ProbeLayer(i, self) for i in range(num_layers)]
self.hook_caches: Set[MessageCache] = set()
# self.layers = [ProbeLayer(i, self) for i in range(num_layers)]
# self.hook_caches: Set[MessageCache] = set()
def to(self, device):
for i in range(self.num_layers):
self.layers[i].to(device)
return self
# def to(self, device):
# for i in range(self.num_layers):
# self.layers[i].to(device)
# return self
def assign_message_cache(self, cache):
cache: MessageCache = cache
assert self.num_layers == cache.num_layers
self.hook_caches.add(cache)
# def assign_message_cache(self, cache):
# cache: MessageCache = cache
# assert self.num_layers == cache.num_layers
# self.hook_caches.add(cache)
def update_sample_w(self, last_w: Tensor, grad: Tensor, idx: Optional[Tensor]):
val_norm = grad.norm(p=self.p, dim=self.dim)
if idx is None:
last_w[:] = val_norm
else:
if self.beta != 1.0:
last_w.mul_(self.beta)
last_w[idx] = val_norm
# def update_sample_w(self, last_w: Tensor, grad: Tensor, idx: Optional[Tensor]):
# val_norm = grad.norm(p=self.p, dim=self.dim)
# if idx is None:
# last_w[:] = val_norm
# else:
# if self.beta != 1.0:
# last_w.mul_(self.beta)
# last_w[idx] = val_norm
def sample(self, w: Tensor, k: int) -> Optional[Tensor]:
w = w / w.sum()
return torch.multinomial(w, num_samples=k, replacement=False)
# def sample(self, w: Tensor, k: int) -> Optional[Tensor]:
# w = w / w.sum()
# return torch.multinomial(w, num_samples=k, replacement=False)
def apply(self, i: int, val: Tensor, idx: Optional[Tensor]) -> Tensor:
self.layers[i].register_val_hook(val, idx)
return val
# def apply(self, i: int, val: Tensor, idx: Optional[Tensor]) -> Tensor:
# self.layers[i].register_val_hook(val, idx)
# return val
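# A small illustrative sketch of the sampling rule NodeProbe implements above:
# per-node weights track the most recent gradient norm, and k nodes are drawn
# without replacement with probability proportional to those weights. The
# sizes and the random gradient below are made-up example inputs.
import torch

num_nodes, k = 8, 3
last_w = torch.ones(num_nodes)           # warm start: uniform weights
grad = torch.randn(num_nodes, 16)        # stand-in for the gradient w.r.t. node embeddings
last_w[:] = grad.norm(p="fro", dim=-1)   # what update_sample_w does when idx is None
w = last_w / last_w.sum()
picked = torch.multinomial(w, num_samples=k, replacement=False)
print(picked)                            # indices of the k sampled nodes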
class CacheLayer:
def __init__(self, layer_id, cache_obj, num_features) -> None:
self.layer_id: int = layer_id
self.cache_obj: MessageCache = cache_obj
self.data: Tensor = torch.randn(
size=(self.cache_obj.src_size, num_features),
dtype=torch.float32,
device=self.cache_obj.cache_device,
)
self.state: Dict[str, Any] = {}
self.shrink: Optional[ShrinkData] = None
self._async_push_work: Optional[AsyncCopyWorkBase] = None
self._async_pull_work: Optional[AsyncCopyWorkBase] = None
self._backward_sample_work: Optional[RouteWorkBase] = None
# class CacheLayer:
# def __init__(self, layer_id, cache_obj, num_features) -> None:
# self.layer_id: int = layer_id
# self.cache_obj: MessageCache = cache_obj
# self.data: Tensor = torch.randn(
# size=(self.cache_obj.src_size, num_features),
# dtype=torch.float32,
# device=self.cache_obj.cache_device,
# )
# self.state: Dict[str, Any] = {}
# self.shrink: Optional[ShrinkData] = None
# self._async_push_work: Optional[AsyncCopyWorkBase] = None
# self._async_pull_work: Optional[AsyncCopyWorkBase] = None
# self._backward_sample_work: Optional[RouteWorkBase] = None
def to(self, device):
if self.shrink is not None:
self.shrink = self.shrink.to(device)
return self
# def to(self, device):
# if self.shrink is not None:
# self.shrink = self.shrink.to(device)
# return self
def cache_data_to(self):
self.data = self.data.to(self.cache_obj.cache_device)
return self
# def cache_data_to(self):
# self.data = self.data.to(self.cache_obj.cache_device)
# return self
def set_shrink_data(self, dst_idx: Optional[Tensor]) -> Optional[ShrinkData]:
self.shrink = self.cache_obj.new_shrink_data(dst_idx)
return self.shrink
# def set_shrink_data(self, dst_idx: Optional[Tensor]) -> Optional[ShrinkData]:
# self.shrink = self.cache_obj.new_shrink_data(dst_idx)
# return self.shrink
def get_shrink_data(self) -> Optional[ShrinkData]:
return self.shrink
# def get_shrink_data(self) -> Optional[ShrinkData]:
# return self.shrink
# def set_backward_sample_work(self, src_idx: Optional[Tensor]):
# if src_idx is None:
# return
# assert self._backward_sample_work is None
# self._backward_sample_work = self.cache_obj.route.backward_a2a(src_idx, src_idx)
# # def set_backward_sample_work(self, src_idx: Optional[Tensor]):
# # if src_idx is None:
# # return
# # assert self._backward_sample_work is None
# # self._backward_sample_work = self.cache_obj.route.backward_a2a(src_idx, src_idx)
# def pop_backward_sample_work(self) -> Optional[RouteWorkBase]:
# work = self._backward_sample_work
# self._backward_sample_work = None
# return work
# # def pop_backward_sample_work(self) -> Optional[RouteWorkBase]:
# # work = self._backward_sample_work
# # self._backward_sample_work = None
# # return work
def sync_backward_sample_work(self, src_idx: Optional[Tensor] = None) -> Optional[Tensor]:
dst_idx = None
if self._backward_sample_work is not None:
_, dst_idx = self._backward_sample_work.get()
# def sync_backward_sample_work(self, src_idx: Optional[Tensor] = None) -> Optional[Tensor]:
# dst_idx = None
# if self._backward_sample_work is not None:
# _, dst_idx = self._backward_sample_work.get()
if src_idx is not None:
# self._backward_sample_work = self.cache_obj.route.backward_a2a(src_idx, src_idx)
work = get_route_executor().async_backward_a2a(src_idx, src_idx, self.cache_obj.route)
self._backward_sample_work = work
else:
self._backward_sample_work = None
return dst_idx
# if src_idx is not None:
# # self._backward_sample_work = self.cache_obj.route.backward_a2a(src_idx, src_idx)
# work = get_route_executor().async_backward_a2a(src_idx, src_idx, self.cache_obj.route)
# self._backward_sample_work = work
# else:
# self._backward_sample_work = None
# return dst_idx
def sync_push_work(self, work: Optional[AsyncCopyWorkBase] = None):
if self._async_push_work is not None:
self._async_push_work.get()
self._async_push_work = work
# def sync_push_work(self, work: Optional[AsyncCopyWorkBase] = None):
# if self._async_push_work is not None:
# self._async_push_work.get()
# self._async_push_work = work
def sync_pull_work(self, work: Optional[AsyncCopyWorkBase] = None) -> Optional[Tensor]:
out = None
if self._async_pull_work is not None:
out = self._async_pull_work.get()
self._async_pull_work = work
return out
# def sync_pull_work(self, work: Optional[AsyncCopyWorkBase] = None) -> Optional[Tensor]:
# out = None
# if self._async_pull_work is not None:
# out = self._async_pull_work.get()
# self._async_pull_work = work
# return out
class MessageCache:
def __init__(self,
src_ids: Tensor,
dst_ids: Tensor,
edge_index: Tensor,
num_features: Union[int, torch.Size],
num_layers: int,
cache_device: Union[str, torch.device, None],
bipartite: bool = False,
) -> None:
self.src_size = src_ids.size(0)
self.dst_size = dst_ids.size(0)
self.edge_index = edge_index
self.num_layers = num_layers
# class MessageCache:
# def __init__(self,
# src_ids: Tensor,
# dst_ids: Tensor,
# edge_index: Tensor,
# num_features: Union[int, torch.Size],
# num_layers: int,
# cache_device: Union[str, torch.device, None],
# bipartite: bool = False,
# ) -> None:
# self.src_size = src_ids.size(0)
# self.dst_size = dst_ids.size(0)
# self.edge_index = edge_index
# self.num_layers = num_layers
if cache_device is None:
cache_device = edge_index.device
self.cache_device = cache_device
self.bipartite = bipartite
# if cache_device is None:
# cache_device = edge_index.device
# self.cache_device = cache_device
# self.bipartite = bipartite
self.route = Route(dst_ids, src_ids, bipartite=bipartite)
self.layers = [CacheLayer(i, self, num_features) for i in range(num_layers)]
# self.route = Route(dst_ids, src_ids, bipartite=bipartite)
# self.layers = [CacheLayer(i, self, num_features) for i in range(num_layers)]
def to(self, device):
self.edge_index = self.edge_index.to(device)
self.route = self.route.to(device)
for i in range(self.num_layers):
self.layers[i].to(device)
return self
# def to(self, device):
# self.edge_index = self.edge_index.to(device)
# self.route = self.route.to(device)
# for i in range(self.num_layers):
# self.layers[i].to(device)
# return self
def cached_data_to(self, device):
self.cache_device = torch.device(device)
for i in range(self.num_layers):
self.layers[i].cache_data_to()
return self
# def cached_data_to(self, device):
# self.cache_device = torch.device(device)
# for i in range(self.num_layers):
# self.layers[i].cache_data_to()
# return self
# @property
# def is_offload(self) -> bool:
# return self.edge_index.device != self.cache_device
# # @property
# # def is_offload(self) -> bool:
# # return self.edge_index.device != self.cache_device
def new_shrink_data(self, dst_idx: Optional[Tensor]) -> Optional[ShrinkData]:
if dst_idx is None:
return None
return ShrinkData(
src_size=self.src_size,
dst_size=self.dst_size,
dst_idx=dst_idx,
edge_index=self.edge_index,
bipartite=self.bipartite,
)
# def new_shrink_data(self, dst_idx: Optional[Tensor]) -> Optional[ShrinkData]:
# if dst_idx is None:
# return None
# return ShrinkData(
# src_size=self.src_size,
# dst_size=self.dst_size,
# dst_idx=dst_idx,
# edge_index=self.edge_index,
# bipartite=self.bipartite,
# )
# def clear_shrink_data(self):
# for i in range(self.num_layers):
# self.layers[i].shrink = None
# # def clear_shrink_data(self):
# # for i in range(self.num_layers):
# # self.layers[i].shrink = None
def replace_layer_data(self, i: int, data: Tensor):
layer = self.layers[i]
assert layer.data.size(0) == data.size(0)
# def replace_layer_data(self, i: int, data: Tensor):
# layer = self.layers[i]
# assert layer.data.size(0) == data.size(0)
layer.data = data
return self
# layer.data = data
# return self
def update_cache(self,
i: int,
val: Tensor,
idx: Optional[Tensor],
async_op: bool = False,
) -> Tuple[Tensor, Optional[Tensor]]:
layer: CacheLayer = self.layers[i]
src_val, src_idx = self.route.apply(val, idx, layer.state, async_op=async_op)
# def update_cache(self,
# i: int,
# val: Tensor,
# idx: Optional[Tensor],
# async_op: bool = False,
# ) -> Tuple[Tensor, Optional[Tensor]]:
# layer: CacheLayer = self.layers[i]
# src_val, src_idx = self.route.apply(val, idx, layer.state, async_op=async_op)
# full graph
if idx is None:
return src_val, None
# # full graph
# if idx is None:
# return src_val, None
shrink = layer.get_shrink_data()
if shrink is None: # this may be problematic at initialization time
data = torch.empty_like(layer.data, device=val.device).copy_(layer.data)
data.index_copy_(0, src_idx, src_val)
return data, src_idx
else:
push_work = get_acopy_executor().async_push(layer.data, src_idx, src_val)
layer.sync_push_work(push_work)
# shrink = layer.get_shrink_data()
# if shrink is None: # this may be problematic at initialization time
# data = torch.empty_like(layer.data, device=val.device).copy_(layer.data)
# data.index_copy_(0, src_idx, src_val)
# return data, src_idx
# else:
# push_work = get_acopy_executor().async_push(layer.data, src_idx, src_val)
# layer.sync_push_work(push_work)
data = layer.sync_pull_work()
sval, sidx = shrink.shrink_src_val_and_idx(src_val, src_idx)
data.index_copy_(0, sidx, sval)
return data, sidx
# data = layer.sync_pull_work()
# sval, sidx = shrink.shrink_src_val_and_idx(src_val, src_idx)
# data.index_copy_(0, sidx, sval)
# return data, sidx
# first, asynchronously push the update into the cache
# at the same time, asynchronously pull (download) from the cache
# then mix the local and remote cached values and return the result
# # first, asynchronously push the update into the cache
# # at the same time, asynchronously pull (download) from the cache
# # then mix the local and remote cached values and return the result
def fetch_pull_tensor(self, i: int) -> Optional[Tensor]:
return self.layers[i].sync_pull_work()
# def fetch_pull_tensor(self, i: int) -> Optional[Tensor]:
# return self.layers[i].sync_pull_work()
# def _update_cache_impl(self,
# i: int,
# dst_val: Tensor,
# dst_idx: Optional[Tensor],
# route: Route,
# async_op: bool,
# ):
# # communication
# state = self.layers[i].state
# src_val, src_idx = route.apply(dst_val, dst_idx, state, async_op=async_op)
# # def _update_cache_impl(self,
# # i: int,
# # dst_val: Tensor,
# # dst_idx: Optional[Tensor],
# # route: Route,
# # async_op: bool,
# # ):
# # # communication
# # state = self.layers[i].state
# # src_val, src_idx = route.apply(dst_val, dst_idx, state, async_op=async_op)
# # push latest embeddings
# data = self.layers[i].data
# if src_idx is None:
# data[:] = src_val
# else:
# data[src_idx] = src_val
# # # push latest embeddings
# # data = self.layers[i].data
# # if src_idx is None:
# # data[:] = src_val
# # else:
# # data[src_idx] = src_val
# # get previous generated shrink data
# shr: ShrinkData = self.layers[i].shrink
# if shr is None:
# return src_val, src_idx
# else:
# # pull latest embeddings
# return data[shr.src_idx], shr.src_idx
# # # get previous generated shrink data
# # shr: ShrinkData = self.layers[i].shrink
# # if shr is None:
# # return src_val, src_idx
# # else:
# # # pull latest embeddings
# # return data[shr.src_idx], shr.src_idx
# def _update_cache_offload(self,
# i: int,
# dst_val: Tensor,
# dst_idx: Optional[Tensor],
# route: Route,
# async_op: bool,
# ):
# raise NotImplementedError
# # def _update_cache_offload(self,
# # i: int,
# # dst_val: Tensor,
# # dst_idx: Optional[Tensor],
# # route: Route,
# # async_op: bool,
# # ):
# # raise NotImplementedError
# def _compute_cached_data_size(self) -> torch.Size:
# num_features = self.num_features
# if num_features is None:
# cached_data_size = torch.Size([self.src_size,])
# else:
# cached_data_size = torch.Size([self.src_size, num_features])
# return cached_data_size
# # def _compute_cached_data_size(self) -> torch.Size:
# # num_features = self.num_features
# # if num_features is None:
# # cached_data_size = torch.Size([self.src_size,])
# # else:
# # cached_data_size = torch.Size([self.src_size, num_features])
# # return cached_data_size
# def _compute_cpu_buf_size(self) -> torch.Size:
# num_features = self.num_features
# if num_features is None:
# cpu_buf_size = torch.Size([2**32,])
# else:
# cpu_buf_size = torch.Size([2**32 // num_features, num_features])
# return cpu_buf_size
\ No newline at end of file
# # def _compute_cpu_buf_size(self) -> torch.Size:
# # num_features = self.num_features
# # if num_features is None:
# # cpu_buf_size = torch.Size([2**32,])
# # else:
# # cpu_buf_size = torch.Size([2**32 // num_features, num_features])
# # return cpu_buf_size
\ No newline at end of file
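# A conceptual, single-process sketch of the cache-mixing step described in
# MessageCache.update_cache above: cached rows for the shrunken src index set
# are pulled first, then freshly received rows overwrite their slots before
# the mixed tensor is returned. All shapes and indices here are illustrative.
import torch

cache = torch.zeros(6, 4)                # per-layer cached src embeddings
src_idx = torch.tensor([0, 2, 3, 5])     # shrunken src index set for this layer
data = cache.index_select(0, src_idx)    # stands in for the async pull result

sidx = torch.tensor([1, 3])              # positions within src_idx that just arrived
sval = torch.ones(2, 4)                  # fresh embeddings from remote partitions
data.index_copy_(0, sidx, sval)          # remote rows replace the stale cached rows
print(data)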
# import torch
# from torch import Tensor
# from typing import *
# from multiprocessing.pool import ThreadPool
# from .event import Event
# class AsyncCopyWorkBase:
# def __init__(self) -> None:
# self._events: Optional[Tuple[Event, Event]] = None
# def wait(self):
# raise NotImplementedError
# def get(self) -> Tensor:
# raise NotImplementedError
# def has_events(self) -> bool:
# return self._events is not None
# def set_events(self, start, end):
# self._events = (start, end)
# return self
# def time_used(self) -> float:
# if self._events is None:
# raise RuntimeError("not found events")
# start, end = self._events
# return start.elapsed_time(end)
# class AsyncPushWork(AsyncCopyWorkBase):
# def __init__(self, data: Tensor, index: Tensor, values: Tensor) -> None:
# super().__init__()
# assert data.device == index.device
# self.set_events(
# Event(use_cuda=index.is_cuda),
# Event(use_cuda=index.is_cuda),
# )
# self._events[0].record()
# data.index_copy_(0, index, values)
# self._events[1].record()
# def wait(self):
# pass
# def get(self):
# pass
# class AsyncPullWork(AsyncCopyWorkBase):
# def __init__(self, data: Tensor, index: Tensor) -> None:
# super().__init__()
# assert data.device == index.device
# self.set_events(
# Event(use_cuda=index.is_cuda),
# Event(use_cuda=index.is_cuda),
# )
# self._events[0].record()
# self._val = data.index_select(0, index)
# self._events[1].record()
# def wait(self):
# pass
# def get(self) -> Tensor:
# return self._val
# class AsyncOffloadWork(AsyncCopyWorkBase):
# def __init__(self, handle) -> None:
# super().__init__()
# self._handle = handle
# def wait(self):
# if self._events is not None:
# self._events[1].wait(torch.cuda.current_stream())
# self._handle.wait()
# def get(self) -> Tensor:
# if self._events is not None:
# self._events[1].wait(torch.cuda.current_stream())
# return self._handle.get()
# class AsyncCopyExecutor:
# def __init__(self) -> None:
# self._stream = torch.cuda.Stream()
# self._executor = ThreadPool(processes=1)
# @torch.no_grad()
# def async_pull(self, data: Tensor, index: Tensor) -> AsyncCopyWorkBase:
# # ideally all of this would run in a dedicated thread: it makes timing easier and keeps the copy asynchronous
# if data.device != index.device:
# assert not data.is_cuda
# assert index.is_cuda
# start = Event(index.is_cuda)
# end = Event(index.is_cuda)
# stream: Optional[torch.cuda.Stream] = self._stream
# def run():
# start.wait(stream)
# with torch.cuda.stream(stream):
# idx = index.to(data.device)
# dst = torch.zeros(
# size=(index.size(0),) + data.shape[1:],
# dtype=data.dtype,
# device=index.device,
# )
# dst.copy_(data[idx])
# end.record()
# return dst
# start.record()
# handle = self._executor.apply_async(run)
# return AsyncOffloadWork(handle).set_events(start, end)
# else:
# # data and index in the same device
# return AsyncPullWork(data, index)
# @torch.no_grad()
# def async_push(self, data: Tensor, index: Tensor, values: Tensor) -> AsyncCopyWorkBase:
# assert index.device == values.device
# if data.device != index.device:
# assert not data.is_cuda
# assert index.is_cuda
# start = Event(index.is_cuda)
# end = Event(index.is_cuda)
# stream: Optional[torch.cuda.Stream] = self._stream
# def run():
# start.wait(stream)
# with torch.cuda.stream(stream):
# idx = index.to(data.device)
# val = values.to(data.device)
# data[idx] = val
# end.record()
# start.record()
# handle = self._executor.apply_async(run)
# return AsyncOffloadWork(handle).set_events(start, end)
# else:
# return AsyncPushWork(data, index, values)
# class Lache:
# def __init__(self,
# cache_size: int,
# data: Tensor,
# ) -> None:
# assert not data.is_cuda, "data must be in CPU Memory"
# cache_size = (cache_size,) + data.shape[1:]
# self.fdata = data
# self.cache = torch.zeros(cache_size, dtype=data.dtype)
# self.no_idx: int = (2**62-1)*2+1
# self.cached_idx = torch.empty(cache_size, dtype=torch.long).fill_(self.no_idx)
# self.read_count = torch.zeros(data.size(0), dtype=torch.long)
# def to(self, device):
# self.cache = self.cache.to(device)
# self.cached_idx = self.cached_idx.to(device)
# self.read_count = self.read_count.to(device)
# return self
# def _push_impl(self, value: Tensor, index: Tensor):
# # self.read_count[index] += 1
# imp = torch.zeros_like(self.read_count)
# def _pull_impl(self, index: Tensor):
# assert index.device == self.cache.device
# self.read_count[index] += 1
# s = index.shape[:1] + self.cache.shape[1:]
# x = torch.empty(s, dtype=self.cache.dtype, device=self.cache.device)
# cache_mask = torch.zeros_like(self.read_count, dtype=torch.bool).index_fill_(0, self.cached_idx, 1)[index]
# cache_index = index[cache_mask]
# x[cache_mask] =
# no_cache_index = index[~cache_mask]
# rt_index = index[~lc_mask].to(self.fdata.device)
# rt_data = self.fdata[rt_index].to(index.device)
# _THREAD_EXEC: Optional[AsyncCopyExecutor] = None
# def get_executor() -> AsyncCopyExecutor:
# global _THREAD_EXEC
# if _THREAD_EXEC is None:
# _THREAD_EXEC = AsyncCopyExecutor()
# return _THREAD_EXEC
@@ -8,62 +8,11 @@ from torch import Tensor
from typing import *
from contextlib import contextmanager
from .a2a import all_to_all, Works, with_nccl
from .a2a import all_to_all
from .event import Event
# class GatherWork:
# def __init__(self,
# values_buffer: Tensor,
# active_buffer: Optional[Tensor],
# recv_val_idx: List[Tensor],
# recv_val_dat: List[Tensor],
# works_list: List[Works],
# ) -> None:
# self._waited = False
# self._values_buffer = values_buffer
# self._active_buffer = active_buffer
# self._recv_val_idx = recv_val_idx
# self._recv_val_dat = recv_val_dat
# self._works_list = works_list
# self._cached_idx: Optional[Tensor] = None
# def _wait(self) -> None:
# assert not self._waited
# for w in self._works_list:
# if w is None:
# continue
# w.wait()
# if self._active_buffer is None:
# for idx, dat in zip(self._recv_val_idx, self._recv_val_dat):
# self._values_buffer[idx] += dat
# else:
# for idx, dat in zip(self._recv_val_idx, self._recv_val_dat):
# self._values_buffer[idx] += dat
# self._active_buffer[idx] = True
# self._active_buffer = torch.where(self._active_buffer)[0]
# self._works_list = []
# self._recv_val_dat = []
# self._recv_val_idx = []
# self._waited = True
# def get_val(self) -> Tensor:
# if not self._waited:
# self._wait()
# return self._values_buffer
# def get_idx(self) -> Optional[Tensor]:
# if not self._waited:
# self._wait()
# return self._active_buffer
# def get(self) -> Tuple[Tensor, Optional[Tensor]]:
# return self.get_val(), self.get_idx()
class RouteWorkBase:
def __init__(self) -> None:
self._events: Optional[Tuple[Event, Event]] = None
@@ -87,69 +36,6 @@ class RouteWorkBase:
start, end = self._events
return start.elapsed_time(end)
class RouteWork(RouteWorkBase):
def __init__(self,
buf: Union[Tensor, int],
recv_val_idx: List[Tensor],
recv_val_dat: List[Tensor],
works_list: List[Works],
) -> None:
super().__init__()
assert len(recv_val_idx) != 0
assert len(recv_val_idx) == len(recv_val_dat)
self._buf = buf
self._recv_val_idx = recv_val_idx
self._recv_val_dat = recv_val_dat
self._works_list = works_list
self._val: Optional[Tensor] = None
self._idx: Optional[Tensor] = None
def wait(self) -> None:
if self._val is not None:
return
for w in self._works_list:
if w is None:
continue
w.wait()
ikw = dict(dtype=self._recv_val_idx[0].dtype, device=self._recv_val_idx[0].device)
fkw = dict(dtype=self._recv_val_dat[0].dtype, device=self._recv_val_dat[0].device)
if isinstance(self._buf, Tensor):
for idx in self._recv_val_idx:
self._buf[idx] = True
self._idx = torch.where(self._buf)[0]
imp = torch.empty(self._buf.size(0), **ikw).fill_((2**62-1)*2+1)
imp[self._idx] = torch.arange(self._idx.size(0), **ikw)
s = self._idx.shape[:1] + self._recv_val_dat[0].shape[1:]
self._val = torch.zeros(s, **fkw)
for idx, dat in zip(self._recv_val_idx, self._recv_val_dat):
idx = imp[idx]
self._val[idx] += dat
else:
s = (self._buf,) + self._recv_val_dat[0].shape[1:]
self._val = torch.zeros(s, **fkw)
for idx, dat in zip(self._recv_val_idx, self._recv_val_dat):
self._val[idx] += dat
self._buf = None
self._recv_val_dat = None
self._recv_val_idx = None
self._works_list = None
if self._events is not None:
self._events[1].record()
def get(self) -> Tuple[Tensor, Optional[Tensor]]:
if self._val is None:
self.wait()
return self._val, self._idx
class RouteExecutorWork(RouteWorkBase):
def __init__(self, handle) -> None:
@@ -167,18 +53,23 @@ class RouteExecutorWork(RouteWorkBase):
self._events[1].wait(torch.cuda.current_stream())
return self._handle.get()
class RouteExecutor:
def __init__(self) -> None:
self._stream = torch.cuda.Stream()
if torch.cuda.is_available():
self._stream = torch.cuda.Stream()
else:
self._stream = None
self._executor = ThreadPool(processes=1)
self._group = dist.new_group()
def async_forward_a2a(self, val: Tensor, idx: Optional[Tensor], route) -> RouteWorkBase:
return self._async_a2a_impl(route.forward_a2a, val, idx)
def async_backward_a2a(self, val: Tensor, idx: Optional[Tensor], route) -> RouteWorkBase:
return self._async_a2a_impl(route.backward_a2a, val, idx)
@torch.no_grad()
def _async_a2a_impl(self, func, val: Tensor, idx: Optional[Tensor]) -> RouteWorkBase:
start = Event(use_cuda=val.is_cuda)
end = Event(use_cuda=val.is_cuda)
@@ -187,43 +78,20 @@ class RouteExecutor:
def run():
start.wait(stream)
with torch.cuda.stream(stream):
ret = func(val, idx, group=self._group).get()
ret = func(val, idx, group=self._group)
end.record()
return ret
start.record()
handle = self._executor.apply_async(run)
return RouteExecutorWork(handle).set_events(start, end)
# def apply_async(self, func, val, idx, async_op) -> RouteWorkBase:
# start = Event(val.is_cuda)
# end = Event(val.is_cuda)
# if not async_op:
# start.record()
# return func(val, idx).set_events(start, end)
# stream: Optional[torch.cuda.Stream] = self._stream
# def run():
# start.wait(stream)
# with torch.cuda.stream(stream):
# ret = func(val, idx, group=self._group).get()
# end.record()
# return ret
# start.record()
# handle = self._executor.apply_async(run)
# return RouteExecutorWork(handle).set_events(start, end)
class RouteCtx:
def __init__(self, route, state, async_op) -> None:
self.route = route
self.state = state
self.async_op = async_op
class Route:
def __init__(self,
src_ids: Tensor,
dst_ids: Tensor,
bipartite: bool = False,
group: Any = None,
):
assert src_ids.dtype == torch.long
assert dst_ids.dtype == torch.long
@@ -249,10 +117,10 @@ class Route:
# get the number of sender-side (src) nodes on each partition
all_src_sizes: List[torch.Size] = [None] * world_size
dist.all_gather_object(all_src_sizes, src_ids.size())
dist.all_gather_object(all_src_sizes, src_ids.size(), group=group)
# get the total number of nodes
num_nodes = _all_reduce_num_nodes(src_ids, dst_ids)
num_nodes = _all_reduce_num_nodes(src_ids, dst_ids, group=group)
# mapping from dst_ids nodes to local indices
imp = torch.empty(num_nodes, **ikw).fill_((2**62-1)*2+1)
@@ -268,7 +136,7 @@
s = all_src_sizes[i]
all_src_ids[i] = torch.empty(s, **ikw)
all_src_get[i] = dist.broadcast(
all_src_ids[i], src=i, async_op=True)
all_src_ids[i], src=i, async_op=True, group=group)
self.forward_routes: List[Tensor] = []
self.backward_routes: List[Tensor] = []
@@ -307,7 +175,7 @@
self.backward_routes.append(bw_route)
# send fw_route to the partitions that own the src_ids to build the final routing tables
self.forward_routes = _fix_fw_routes(self.forward_routes)
self.forward_routes = _fix_fw_routes(self.forward_routes, group=group)
# in the homogeneous (non-bipartite) case, add a self-loop for every node
if not bipartite:
@@ -315,8 +183,6 @@
rank_route = torch.vstack([rank_ind, rank_ind])
self.forward_routes[rank] = rank_route
self.backward_routes[rank] = rank_route
self.total_time_used: float = 0.0
dist.barrier()
def to(self, device):
@@ -325,300 +191,188 @@
return self
def _a2a_impl(self,
val: Tensor,
idx: Optional[Tensor],
send_buf_size: int,
recv_buf_size: int,
value: Tensor,
index: Optional[Tensor],
smask: Union[Tensor, int],
rmask: Union[Tensor, int],
send_routes: List[Tensor],
recv_routes: List[Tensor],
group: Optional[Any],
) -> RouteWork:
bkw = dict(dtype=torch.bool, device=val.device)
ikw = dict(dtype=torch.long, device=val.device)
fkw = dict(dtype=val.dtype, device=val.device)
if idx is not None:
msk = torch.zeros(send_buf_size, **bkw).index_fill_(0, idx, 1)
send_routes = [ro[:,msk[ro[0]]] for ro in send_routes]
) -> Tuple[Tensor, Optional[Tensor]]:
bkw = dict(dtype=torch.bool, device=value.device)
ikw = dict(dtype=torch.long, device=value.device)
fkw = dict(dtype=value.dtype, device=value.device)
send_sizes = torch.tensor([ro.size(1) for ro in send_routes], **ikw)
recv_sizes = torch.zeros_like(send_sizes)
dist.all_to_all_single(recv_sizes, send_sizes, group=group)
recv_sizes = recv_sizes.tolist()
send_buf_size = smask.size(0) if isinstance(smask, Tensor) else smask
recv_buf_size = rmask.size(0) if isinstance(rmask, Tensor) else rmask
if idx is None:
send_val_dat, recv_val_dat = [], []
if index is None:
send_val_idx = [ro[0] for ro in send_routes]
recv_val_idx = [ro[0] for ro in recv_routes]
for i, ro in enumerate(recv_routes):
s = ro.size(1)
c = (s,) + val.shape[1:]
send_val_dat.append(val[send_routes[i][0]])
recv_val_dat.append(torch.zeros(c, **fkw))
works_list = [
all_to_all(recv_val_dat, send_val_dat, group=group),
]
return RouteWork(
buf=recv_buf_size,
recv_val_idx=recv_val_idx,
recv_val_dat=recv_val_dat,
works_list=works_list,
)
send_val_dat, recv_val_dat = [], []
for sidx, ridx in zip(send_val_idx, recv_val_idx):
s = (ridx.size(0),) + value.shape[1:]
send_val_dat.append(value[sidx])
recv_val_dat.append(torch.zeros(s, **fkw))
all_to_all(recv_val_dat, send_val_dat, group=group)
s = (recv_buf_size,) + value.shape[1:]
recv_value = torch.zeros(s, **fkw)
for idx, dat in zip(recv_val_idx, recv_val_dat):
recv_value[idx] += dat
return recv_value, None
else:
assert value.size(0) == index.size(0)
if not isinstance(smask, Tensor):
smask = torch.zeros(send_buf_size, **bkw).index_fill_(0, index, 1)
send_routes = [ro[:,smask[ro[0]]] for ro in send_routes]
send_sizes = torch.tensor([ro.size(1) for ro in send_routes], **ikw)
recv_sizes = torch.zeros_like(send_sizes)
dist.all_to_all_single(recv_sizes, send_sizes, group=group)
recv_sizes = recv_sizes.tolist()
imp = torch.empty(send_buf_size, **ikw).fill_((2**62-1)*2+1)
imp[idx] = torch.arange(idx.size(0), **ikw)
imp[index] = torch.arange(index.size(0), **ikw)
send_val_idx, recv_val_idx = [], []
send_val_dat, recv_val_dat = [], []
for i, s in enumerate(recv_sizes):
c = (s,) + val.shape[1:]
c = (s,) + value.shape[1:]
send_val_idx.append(send_routes[i][1])
recv_val_idx.append(torch.zeros(s, **ikw))
send_index = imp[send_routes[i][0]]
send_val_dat.append(val[send_index])
send_val_dat.append(value[send_index])
recv_val_dat.append(torch.zeros(c, **fkw))
works_list = [
all_to_all(recv_val_idx, send_val_idx, group=group),
all_to_all(recv_val_dat, send_val_dat, group=group),
]
all_to_all(recv_val_idx, send_val_idx, group=group)
all_to_all(recv_val_dat, send_val_dat, group=group)
recv_buf = torch.zeros(recv_buf_size, **bkw)
return RouteWork(
buf=recv_buf,
recv_val_idx=recv_val_idx,
recv_val_dat=recv_val_dat,
works_list=works_list,
)
rmask = rmask if isinstance(rmask, Tensor) else torch.zeros(recv_buf_size, **bkw)
for idx in recv_val_idx:
rmask[idx] = True
recv_index = torch.where(rmask)[0]
imp = torch.empty(recv_buf_size, **ikw).fill_((2**62-1)*2+1)
imp[recv_index] = torch.arange(recv_index.size(0), **ikw)
s = recv_index.shape[:1] + value.shape[1:]
recv_value = torch.zeros(s, **fkw)
for idx, dat in zip(recv_val_idx, recv_val_dat):
recv_value[imp[idx]] += dat
return recv_value, recv_index
def forward_a2a(self,
val: Tensor,
idx: Optional[Tensor] = None,
value: Tensor,
index: Optional[Tensor] = None,
group: Optional[Any] = None,
) -> RouteWork:
if idx is None:
assert val.size(0) == self.src_size
) -> Tuple[Tensor, Optional[Tensor]]:
if index is None or index.dtype == torch.long:
return self._a2a_impl(
value=value,
index=index,
smask=self.src_size,
rmask=self.dst_size,
send_routes=self.forward_routes,
recv_routes=self.backward_routes,
group=group,
)
else:
assert val.size(0) == idx.size(0)
if group is None:
start = Event(use_cuda=val.is_cuda)
end = Event(use_cuda=val.is_cuda)
start.record()
work = self._a2a_impl(
val=val, idx=idx,
send_buf_size=self.src_size,
recv_buf_size=self.dst_size,
send_routes=self.forward_routes,
recv_routes=self.backward_routes,
group=group,
)
if group is None:
return work.set_events(start, end)
return work
smask, index = index, torch.where(index)[0]
assert smask.size(0) == self.src_size
rmask = torch.zeros(self.dst_size, dtype=smask.dtype, device=smask.device)
value, index = self._a2a_impl(
value=value,
index=index,
smask=smask,
rmask=rmask,
send_routes=self.forward_routes,
recv_routes=self.backward_routes,
group=group,
)
return value, rmask
def backward_a2a(self,
val: Tensor,
idx: Optional[Tensor] = None,
value: Tensor,
index: Optional[Tensor] = None,
group: Optional[Any] = None,
) -> RouteWork:
if idx is None:
assert val.size(0) == self.dst_size
) -> Tuple[Tensor, Optional[Tensor]]:
if index is None or index.dtype == torch.long:
return self._a2a_impl(
value=value,
index=index,
smask=self.dst_size,
rmask=self.src_size,
send_routes=self.backward_routes,
recv_routes=self.forward_routes,
group=group,
)
else:
assert val.size(0) == idx.size(0)
if group is None:
start = Event(use_cuda=val.is_cuda)
end = Event(use_cuda=val.is_cuda)
start.record()
work = self._a2a_impl(
val=val, idx=idx,
send_buf_size=self.dst_size,
recv_buf_size=self.src_size,
send_routes=self.backward_routes,
recv_routes=self.forward_routes,
group=group,
)
if group is None:
return work.set_events(start, end)
return work
smask, index = index, torch.where(index)[0]
assert smask.size(0) == self.dst_size
rmask = torch.zeros(self.src_size, dtype=smask.dtype, device=smask.device)
value, index = self._a2a_impl(
value=value,
index=index,
smask=smask,
rmask=rmask,
send_routes=self.backward_routes,
recv_routes=self.forward_routes,
group=group,
)
return value, rmask
def apply(self,
val: Tensor,
idx: Optional[Tensor],
state: Dict[str, Any],
async_op: bool = False,
value: Tensor,
index: Optional[Tensor] = None,
) -> Tuple[Tensor, Optional[Tensor]]:
with self._a2a_manager(state, async_op):
return RouteFunction.apply(val, idx)
with self._a2a_manager():
return RouteFunction.apply(value, index)
@contextmanager
def _a2a_manager(self, state, async_op) -> RouteCtx:
def _a2a_manager(self):
global _global_route_context
stacked_ctx = _global_route_context
try:
_global_route_context = RouteCtx(self, state, async_op)
_global_route_context = RouteCtx(self)
yield _global_route_context
finally:
_global_route_context = stacked_ctx
# def _gather_impl(self,
# values_buffer: Tensor,
# active_buffer: Optional[Tensor],
# values: Tensor,
# active: Optional[Tensor],
# send_routes: List[Tensor],
# recv_routes: List[Tensor],
# async_op: bool,
# ):
# ikw = dict(dtype=torch.long, device=values.device)
# fkw = dict(dtype=values.dtype, device=values.device)
# # when active is given, only the values flagged by active are sent to neighbor partitions; otherwise all values are sent
# if active is not None:
# # build the new routing table
# active_routes = [ro[:,active[ro[0]]] for ro in send_routes]
# # compute the receive buffer sizes
# scatter_sizes = [ro.size(1) for ro in active_routes]
# scatter_sizes = torch.tensor(scatter_sizes, **ikw)
# gather_sizes = torch.zeros_like(scatter_sizes)
# dist.all_to_all_single(gather_sizes, scatter_sizes)
# gather_sizes = gather_sizes.tolist()
# def async_run():
# if active is not None:
# send_val_idx, recv_val_idx = [], []
# send_val_dat, recv_val_dat = [], []
# for i, s in enumerate(gather_sizes):
# c = (s,) + values.shape[1:]
# send_val_idx.append(active_routes[i][1])
# recv_val_idx.append(torch.zeros(s, **ikw))
# send_val_dat.append(values[active_routes[i][0]])
# recv_val_dat.append(torch.zeros(c, **fkw))
# works_list = [
# all_to_all(recv_val_idx, send_val_idx),
# all_to_all(recv_val_dat, send_val_dat),
# ]
# return GatherWork(
# values_buffer=values_buffer,
# active_buffer=active_buffer,
# recv_val_idx=recv_val_idx,
# recv_val_dat=recv_val_dat,
# works_list=works_list,
# )
# else:
# send_val_dat, recv_val_dat = [], []
# recv_val_idx = [ro[0] for ro in recv_routes]
# for i, r in enumerate(recv_routes):
# s = r.size(1)
# c = (s,) + values.shape[1:]
# send_val_dat.append(values[send_routes[i][0]])
# recv_val_dat.append(torch.zeros(c, **fkw))
# works_list = [
# all_to_all(recv_val_dat, send_val_dat),
# ]
# return GatherWork(
# values_buffer=values_buffer,
# active_buffer=active_buffer,
# recv_val_idx=recv_val_idx,
# recv_val_dat=recv_val_dat,
# works_list=works_list,
# )
# if async_op and with_nccl():
# stream = get_stream()
# stream.wait_stream(torch.cuda.current_stream())
# with torch.cuda.stream(stream):
# return async_run()
# else:
# return async_run()
# def gather_forward(self,
# val: Tensor,
# idx: Optional[Tensor],
# async_op: bool = False,
# return_idx: bool = True,
# ):
# val, idx = _spread_val_idx(val, idx, self.src_size)
# bkw = dict(dtype=torch.bool, device=val.device)
# fkw = dict(dtype=val.dtype, device=val.device)
# val_buf = torch.zeros((self.dst_size,) + val.shape[1:], **fkw)
# idx_buf = torch.zeros(self.dst_size, **bkw) if return_idx else None
# return self._gather_impl(
# values_buffer=val_buf,
# active_buffer=idx_buf,
# values=val,
# active=idx,
# send_routes=self.forward_routes,
# recv_routes=self.backward_routes,
# async_op=async_op,
# )
# def gather_backward(self,
# val: Tensor,
# idx: Optional[Tensor],
# async_op: bool = False,
# return_idx: bool = True,
# ):
# val, idx = _spread_val_idx(val, idx, self.dst_size)
# bkw = dict(dtype=torch.bool, device=val.device)
# fkw = dict(dtype=val.dtype, device=val.device)
# val_buf = torch.zeros((self.src_size,) + val.shape[1:], **fkw)
# idx_buf = torch.zeros(self.src_size, **bkw) if return_idx else None
# return self._gather_impl(
# values_buffer=val_buf,
# active_buffer=idx_buf,
# values=val,
# active=idx,
# send_routes=self.backward_routes,
# recv_routes=self.forward_routes,
# async_op=async_op,
# )
class RouteCtx:
def __init__(self, route) -> None:
self.route = route
class RouteFunction(autograd.Function):
@staticmethod
def forward(
ctx: autograd.function.FunctionCtx,
val: Tensor,
idx: Optional[Tensor],
value: Tensor,
index: Optional[Tensor],
):
route_ctx = get_global_route_context()
route: Route = route_ctx.route
state: Dict[str, Any] = route_ctx.state
async_op: bool = route_ctx.async_op
# cur_work = get_executor().apply_async(route.forward_a2a, val, idx, async_op)
cur_work = get_executor().async_forward_a2a(val, idx, route)
if async_op:
# cur_work = get_executor().async_forward_a2a(val, idx, route)
work = state.get("__route_fw_work__") or cur_work
state["__route_fw_work__"] = cur_work
else:
# work = route.forward_a2a(val, idx)
work = cur_work
val, idx = work.get()
if work.has_events():
route.total_time_used += work.time_used()
value = value.detach()
if index is not None:
index = index.detach()
ctx.route = route
ctx.state = state
ctx.async_op = async_op
ctx.save_for_backward(idx)
work = get_route_executor().async_forward_a2a(value, index, route)
recv_value, recv_index = work.get()
# if work.has_events():
# route.total_time_used += work.time_used()
return val, idx
ctx.route = route
ctx.save_for_backward(recv_index)
return recv_value, recv_index
@staticmethod
def backward(
@@ -627,47 +381,38 @@ class RouteFunction(autograd.Function):
_: None,
):
route: Route = ctx.route
state: Dict[str, Any] = ctx.state
async_op: bool = ctx.async_op
index, = ctx.saved_tensors
grad = grad.detach()
with torch.no_grad():
idx, = ctx.saved_tensors
# cur_work = get_executor().apply_async(route.backward_a2a, grad, idx, async_op)
cur_work = get_executor().async_backward_a2a(grad, idx, route)
if async_op:
# cur_work = get_executor().async_backward_a2a(grad, idx, route)
work = state.get("__route_bw_work__") or cur_work
state["__route_bw_work__"] = cur_work
else:
# work = route.backward_a2a(grad, idx)
work = cur_work
val, idx = work.get()
if work.has_events():
route.total_time_used += work.time_used()
work = get_route_executor().async_backward_a2a(grad, index, route)
recv_grad, recv_index = work.get()
# if work.has_events():
# route.total_time_used += work.time_used()
return recv_grad, recv_index
return val, idx
#### private functions
def _all_reduce_num_nodes(
src_ids: Tensor,
dst_ids: Tensor,
group: Any = None,
) -> int:
max_ids = torch.zeros(1, dtype=src_ids.dtype, device=src_ids.device)
max_ids = max_ids.max(src_ids.max()) if src_ids.numel() > 0 else max_ids
max_ids = max_ids.max(dst_ids.max()) if dst_ids.numel() > 0 else max_ids
dist.all_reduce(max_ids, op=dist.ReduceOp.MAX)
dist.all_reduce(max_ids, op=dist.ReduceOp.MAX, group=group)
return max_ids.item() + 1
def _fix_fw_routes(tensors: List[Tensor]) -> List[Tensor]:
def _fix_fw_routes(tensors: List[Tensor], group: Any = None) -> List[Tensor]:
rank = dist.get_rank()
world_size = dist.get_world_size()
tensor_sizes: List[torch.Size] = [t.size() for t in tensors]
all_tensor_sizes: List[List[torch.Size]] = [None] * world_size
dist.all_gather_object(all_tensor_sizes, tensor_sizes)
dist.all_gather_object(all_tensor_sizes, tensor_sizes, group=group)
new_tensors: List[Tensor] = []
for i in range(world_size):
......@@ -675,27 +420,9 @@ def _fix_fw_routes(tensors: List[Tensor]) -> List[Tensor]:
t = torch.zeros(s).type_as(tensors[i])
new_tensors.append(t)
all_to_all(new_tensors, tensors).wait()
all_to_all(new_tensors, tensors, group=group)
return new_tensors
# def _spread_val_idx(val: Tensor, idx: Optional[Tensor], size: int) -> Tuple[Tensor, Optional[Tensor]]:
# if idx is None:
# assert val.size(0) == size
# return val, None
# else:
# assert val.size(0) == idx.size(0)
# x = torch.zeros(
# size=(size,) + val.shape[1:],
# dtype=val.dtype,
# device=val.device,
# )
# x[idx] = val
# y = torch.zeros(
# size=(size,),
# dtype=torch.bool,
# device=idx.device,
# ).index_fill_(0, idx, 1)
# return x, y
_global_route_context: Optional[RouteCtx] = None
def get_global_route_context() -> RouteCtx:
......@@ -703,26 +430,9 @@ def get_global_route_context() -> RouteCtx:
assert _global_route_context is not None
return _global_route_context
# Create a new stream dedicated to asynchronous Gather communication
# _STREAM: Optional[torch.cuda.Stream] = None
# def get_stream() -> torch.cuda.Stream:
# global _STREAM
# if _STREAM is None:
# _STREAM = torch.cuda.Stream()
# return _STREAM
# @contextmanager
# def stream_manager(enable: bool = True) -> Optional[torch.cuda.Stream]:
# if enable and with_nccl():
# stream = get_stream()
# stream.wait_stream(torch.cuda.current_stream())
# with torch.cuda.stream(stream):
# yield stream
# else:
# yield
_THREAD_EXEC: Optional[RouteExecutor] = None
def get_executor() -> RouteExecutor:
def get_route_executor() -> RouteExecutor:
global _THREAD_EXEC
if _THREAD_EXEC is None:
_THREAD_EXEC = RouteExecutor()
......
import torch
# import torch
from torch import Tensor
from typing import *
# from torch import Tensor
# from typing import *
class ShrinkData:
no_idx: int = (2**62-1)*2+1
# class ShrinkData:
# no_idx: int = (2**62-1)*2+1
def __init__(self,
src_size: int,
dst_size: int,
dst_idx: Tensor,
edge_index: Tensor,
bipartite: bool = False,
) -> None:
device = dst_idx.device
tmp = torch.empty(max(src_size, dst_size), dtype=torch.bool, device=device)
# def __init__(self,
# src_size: int,
# dst_size: int,
# dst_idx: Tensor,
# edge_index: Tensor,
# bipartite: bool = False,
# ) -> None:
# device = dst_idx.device
# tmp = torch.empty(max(src_size, dst_size), dtype=torch.bool, device=device)
tmp.fill_(0)
tmp.index_fill_(0, dst_idx, 1)
# tmp.fill_(0)
# tmp.index_fill_(0, dst_idx, 1)
edge_idx = torch.where(tmp[edge_index[1]])[0]
edge_index = edge_index[:, edge_idx]
# edge_idx = torch.where(tmp[edge_index[1]])[0]
# edge_index = edge_index[:, edge_idx]
if bipartite:
tmp.fill_(0)
tmp.index_fill_(0, edge_index[0], 1)
src_idx = torch.where(tmp)[0]
# if bipartite:
# tmp.fill_(0)
# tmp.index_fill_(0, edge_index[0], 1)
# src_idx = torch.where(tmp)[0]
imp = torch.empty(max(src_size, dst_size), dtype=torch.long, device=device)
imp[dst_idx] = torch.arange(dst_idx.size(0), dtype=torch.long, device=device)
dst = imp[edge_index[1]]
# imp = torch.empty(max(src_size, dst_size), dtype=torch.long, device=device)
# imp[dst_idx] = torch.arange(dst_idx.size(0), dtype=torch.long, device=device)
# dst = imp[edge_index[1]]
imp.fill_(self.no_idx)
imp[src_idx] = torch.arange(src_idx.size(0), dtype=torch.long, device=device)
src = imp[edge_index[0]]
edge_index = torch.vstack([src, dst])
else:
tmp.index_fill_(0, edge_index[0], 1)
tmp.index_fill_(0, dst_idx, 0)
src_idx = torch.cat([dst_idx, torch.where(tmp)[0]], dim=0)
# imp.fill_(self.no_idx)
# imp[src_idx] = torch.arange(src_idx.size(0), dtype=torch.long, device=device)
# src = imp[edge_index[0]]
# edge_index = torch.vstack([src, dst])
# else:
# tmp.index_fill_(0, edge_index[0], 1)
# tmp.index_fill_(0, dst_idx, 0)
# src_idx = torch.cat([dst_idx, torch.where(tmp)[0]], dim=0)
imp = torch.empty(max(src_size, dst_size), dtype=torch.long, device=device)
imp.fill_(self.no_idx)
imp[src_idx] = torch.arange(src_idx.size(0), dtype=torch.long, device=device)
edge_index = imp[edge_index.flatten()].view_as(edge_index)
# imp = torch.empty(max(src_size, dst_size), dtype=torch.long, device=device)
# imp.fill_(self.no_idx)
# imp[src_idx] = torch.arange(src_idx.size(0), dtype=torch.long, device=device)
# edge_index = imp[edge_index.flatten()].view_as(edge_index)
self.src_idx = src_idx
self.dst_idx = dst_idx
self.edge_idx = edge_idx
self.edge_index = edge_index
self._src_imp = imp[:src_size]
# self.src_idx = src_idx
# self.dst_idx = dst_idx
# self.edge_idx = edge_idx
# self.edge_index = edge_index
# self._src_imp = imp[:src_size]
def to(self, device):
self.src_idx = self.src_idx.to(device)
self.dst_idx = self.dst_idx.to(device)
self.edge_idx = self.edge_idx.to(device)
self.edge_index = self.edge_index.to(device)
self._src_imp = self._src_imp.to(device)
return self
# def to(self, device):
# self.src_idx = self.src_idx.to(device)
# self.dst_idx = self.dst_idx.to(device)
# self.edge_idx = self.edge_idx.to(device)
# self.edge_index = self.edge_index.to(device)
# self._src_imp = self._src_imp.to(device)
# return self
def shrink_src_val_and_idx(self, val: Tensor, idx: Tensor) -> Tuple[Tensor, Tensor]:
idx = self._src_imp[idx]
m = (idx != self.no_idx)
return val[m], idx[m]
# def shrink_src_val_and_idx(self, val: Tensor, idx: Tensor) -> Tuple[Tensor, Tensor]:
# idx = self._src_imp[idx]
# m = (idx != self.no_idx)
# return val[m], idx[m]
@property
def src_size(self) -> int:
return self.src_idx.size(0)
# @property
# def src_size(self) -> int:
# return self.src_idx.size(0)
@property
def dst_size(self) -> int:
return self.dst_idx.size(0)
# @property
# def dst_size(self) -> int:
# return self.dst_idx.size(0)
@property
def edge_size(self) -> int:
return self.edge_idx.size(0)
# @property
# def edge_size(self) -> int:
# return self.edge_idx.size(0)
from .distgraph import DistGraph
\ No newline at end of file
# from .distgraph import DistGraph
\ No newline at end of file
import torch
import torch.nn as nn
from torch import Tensor
from typing import *
from contextlib import contextmanager
# import torch_sparse
from .ndata import NData
from .edata import EData
from .utils import init_local_edge_index
from ..core import MessageCache, Route
class DistGraph:
def __init__(self,
ids: Tensor,
edge_index: Tensor,
num_features: int,
num_layers: int,
cache_device: str = "cpu",
**args: Dict[str, Any],
):
# build local_edge_index
dst_ids = ids
src_ids, local_edge_index = init_local_edge_index(
dst_ids=dst_ids,
edge_index=edge_index,
)
# import torch
# import torch.nn as nn
# from torch import Tensor
# from typing import *
# from contextlib import contextmanager
# # import torch_sparse
# from .ndata import NData
# from .edata import EData
# from .utils import init_local_edge_index
# from ..core import MessageCache, Route
# class DistGraph:
# def __init__(self,
# ids: Tensor,
# edge_index: Tensor,
# num_features: int,
# num_layers: int,
# cache_device: str = "cpu",
# **args: Dict[str, Any],
# ):
# # build local_edge_index
# dst_ids = ids
# src_ids, local_edge_index = init_local_edge_index(
# dst_ids=dst_ids,
# edge_index=edge_index,
# )
self._src_ids = src_ids
self._dst_ids = dst_ids
self._message_cache = MessageCache(
src_ids=src_ids,
dst_ids=dst_ids,
edge_index=local_edge_index,
num_features=num_features,
num_layers=num_layers,
cache_device=cache_device,
bipartite=False,
)
# node's attributes
self.ndata = NData(
src_size=src_ids.size(0),
dst_size=dst_ids.size(0),
)
# self._src_ids = src_ids
# self._dst_ids = dst_ids
# self._message_cache = MessageCache(
# src_ids=src_ids,
# dst_ids=dst_ids,
# edge_index=local_edge_index,
# num_features=num_features,
# num_layers=num_layers,
# cache_device=cache_device,
# bipartite=False,
# )
# # node's attributes
# self.ndata = NData(
# src_size=src_ids.size(0),
# dst_size=dst_ids.size(0),
# )
# edge's attributes
self.edata = EData(
edge_size=local_edge_index.size(1),
)
# # edge's attributes
# self.edata = EData(
# edge_size=local_edge_index.size(1),
# )
# graph's attributes
self.args = dict(args)
# # graph's attributes
# self.args = dict(args)
def to(self, device):
self._message_cache.to(device)
return self
# def to(self, device):
# self._message_cache.to(device)
# return self
def cache_data_to(self, device):
self._message_cache.cached_data_to(device)
# def cache_data_to(self, device):
# self._message_cache.cached_data_to(device)
@property
def cache(self) -> MessageCache:
return self._message_cache
# @property
# def cache(self) -> MessageCache:
# return self._message_cache
@property
def route(self) -> Route:
return self._message_cache.route
# @property
# def route(self) -> Route:
# return self._message_cache.route
@property
def device(self) -> torch.device:
return self.edge_index.device
# @property
# def device(self) -> torch.device:
# return self.edge_index.device
@property
def edge_index(self) -> Tensor:
return self._message_cache.edge_index
# @property
# def edge_index(self) -> Tensor:
# return self._message_cache.edge_index
@property
def src_ids(self) -> Tensor:
if self._src_ids.device != self.device:
self._src_ids = self._src_ids.to(self.device)
return self._src_ids
# @property
# def src_ids(self) -> Tensor:
# if self._src_ids.device != self.device:
# self._src_ids = self._src_ids.to(self.device)
# return self._src_ids
@property
def src_size(self) -> int:
return self._src_ids.size(0)
# @property
# def src_size(self) -> int:
# return self._src_ids.size(0)
@property
def dst_ids(self) -> Tensor:
if self._dst_ids.device != self.device:
self._dst_ids = self._dst_ids.to(self.device)
return self._dst_ids
# @property
# def dst_ids(self) -> Tensor:
# if self._dst_ids.device != self.device:
# self._dst_ids = self._dst_ids.to(self.device)
# return self._dst_ids
@property
def dst_size(self) -> int:
return self._dst_ids.size(0)
# @property
# def dst_size(self) -> int:
# return self._dst_ids.size(0)
@contextmanager
def scoped_manager(self):
stacked_ndata = self.ndata
stacked_edata = self.edata
try:
self.ndata = NData(
p=stacked_ndata,
)
self.edata = EData(
p=stacked_edata,
)
yield self
finally:
self.ndata = stacked_ndata
self.edata = stacked_edata
# @contextmanager
# def scoped_manager(self):
# stacked_ndata = self.ndata
# stacked_edata = self.edata
# try:
# self.ndata = NData(
# p=stacked_ndata,
# )
# self.edata = EData(
# p=stacked_edata,
# )
# yield self
# finally:
# self.ndata = stacked_ndata
# self.edata = stacked_edata
# def permute_edge_(self):
# perm = self._local_edge_index[1].argsort()
# self._local_edge_index = self._local_edge_index[:,perm]
# self._local_edge_ptr = torch.ops.torch_sparse.ind2ptr(self._local_edge_index[1], self.dst_size)
# self.edata.permute_(perm)
# # def permute_edge_(self):
# # perm = self._local_edge_index[1].argsort()
# # self._local_edge_index = self._local_edge_index[:,perm]
# # self._local_edge_ptr = torch.ops.torch_sparse.ind2ptr(self._local_edge_index[1], self.dst_size)
# # self.edata.permute_(perm)
# @torch.no_grad()
# def shrink(self,
# src_mask: Optional[Tensor] = None,
# dst_mask: Optional[Tensor] = None,
# ):
# edge_index = self.edge_index
# device = edge_index.device
# ikw = dict(dtype=torch.long, device=device)
# bkw = dict(dtype=torch.bool, device=device)
# if src_mask is None and dst_mask is None:
# return self
# else:
# if dst_mask is None:
# dst_mask = torch.zeros(self.ndata.dst_size, **bkw)
# else:
# assert dst_mask.size(0) == self.ndata.dst_size
# dst_mask = dst_mask.clone()
# if src_mask is not None:
# m = src_mask[edge_index[0]]
# m = edge_index[1][m]
# dst_mask[m] = True
# edge_mask = dst_mask[edge_index[1]]
# edge_index = edge_index[:,edge_mask]
# src_mask = torch.zeros(self.ndata.src_size, **bkw)
# src_mask[:self.ndata.dst_size] = dst_mask
# src_mask[edge_index[0]] = True
# dst_mask = src_mask[:self.ndata.dst_size]
# # Renumber edge_index
# imp = torch.empty(self.ndata.src_size, **ikw).fill_((2**62-1)*2+1)
# idx = torch.where(src_mask)[0]
# imp[idx] = torch.arange(idx.size(0), **ikw)
# edge_index = imp[edge_index.flatten()].view_as(edge_index)
# # assert dst_mask.count_nonzero() > edge_index[1].max().item()
# ndata = NData(
# src_size=None,
# dst_size=None,
# p=self.ndata,
# src_mask=src_mask,
# dst_mask=dst_mask,
# )
# edata = EData(
# edge_size=None,
# p=self.edata,
# edge_mask=edge_mask,
# )
# edata[EID] = edge_index
# return DistGraph(self.args, ndata, edata, self.route)
# class ShrinkGraph:
# def __init__(self,
# g: DistGraph,
# src_mask: Optional[Tensor] = None,
# dst_mask: Optional[Tensor] = None,
# ) -> None:
# device = g.edge_index.device
# ikw = dict(dtype=torch.long, device=device)
# bkw = dict(dtype=torch.bool, device=device)
# if src_mask is None and dst_mask is None:
# # If neither src_mask nor dst_mask is specified,
# # ShrinkGraph is an exact replica of DistGraph
# self.src_ids = g.sr
# self.dst_size = g.dst_size
# self.src_mask = torch.ones(self.src_size, **bkw)
# self.dst_mask_size = g.dst_size
# self.edge_index = g.edge_index
# self.pgraph = g
# else:
# if src_mask is not None:
# tmp_mask = torch.zeros(g.dst_size, **bkw)
# # @torch.no_grad()
# # def shrink(self,
# # src_mask: Optional[Tensor] = None,
# # dst_mask: Optional[Tensor] = None,
# # ):
# # edge_index = self.edge_index
# # device = edge_index.device
# # ikw = dict(dtype=torch.long, device=device)
# # bkw = dict(dtype=torch.bool, device=device)
# # if src_mask is None and dst_mask is None:
# # return self
# # else:
# # if dst_mask is None:
# # dst_mask = torch.zeros(self.ndata.dst_size, **bkw)
# # else:
# # assert dst_mask.size(0) == self.ndata.dst_size
# # dst_mask = dst_mask.clone()
# # if src_mask is not None:
# # m = src_mask[edge_index[0]]
# # m = edge_index[1][m]
# # dst_mask[m] = True
# # edge_mask = dst_mask[edge_index[1]]
# # edge_index = edge_index[:,edge_mask]
# # src_mask = torch.zeros(self.ndata.src_size, **bkw)
# # src_mask[:self.ndata.dst_size] = dst_mask
# # src_mask[edge_index[0]] = True
# # dst_mask = src_mask[:self.ndata.dst_size]
# # # Renumber edge_index
# # imp = torch.empty(self.ndata.src_size, **ikw).fill_((2**62-1)*2+1)
# # idx = torch.where(src_mask)[0]
# # imp[idx] = torch.arange(idx.size(0), **ikw)
# # edge_index = imp[edge_index.flatten()].view_as(edge_index)
# # # assert dst_mask.count_nonzero() > edge_index[1].max().item()
# # ndata = NData(
# # src_size=None,
# # dst_size=None,
# # p=self.ndata,
# # src_mask=src_mask,
# # dst_mask=dst_mask,
# # )
# # edata = EData(
# # edge_size=None,
# # p=self.edata,
# # edge_mask=edge_mask,
# # )
# # edata[EID] = edge_index
# # return DistGraph(self.args, ndata, edata, self.route)
# # class ShrinkGraph:
# # def __init__(self,
# # g: DistGraph,
# # src_mask: Optional[Tensor] = None,
# # dst_mask: Optional[Tensor] = None,
# # ) -> None:
# # device = g.edge_index.device
# # ikw = dict(dtype=torch.long, device=device)
# # bkw = dict(dtype=torch.bool, device=device)
# # if src_mask is None and dst_mask is None:
# # # If neither src_mask nor dst_mask is specified,
# # # ShrinkGraph is an exact replica of DistGraph
# # self.src_ids = g.sr
# # self.dst_size = g.dst_size
# # self.src_mask = torch.ones(self.src_size, **bkw)
# # self.dst_mask_size = g.dst_size
# # self.edge_index = g.edge_index
# # self.pgraph = g
# # else:
# # if src_mask is not None:
# # tmp_mask = torch.zeros(g.dst_size, **bkw)
# # Compute the directly activated edges
# m = src_mask[g.edge_index[0]]
# m = g.edge_index[1][m]
# # # Compute the directly activated edges
# # m = src_mask[g.edge_index[0]]
# # m = g.edge_index[1][m]
# # Compute the directly activated dst_ids
# tmp_mask[m] = True
# # # Compute the directly activated dst_ids
# # tmp_mask[m] = True
# # Merge with the already activated dst_ids
# if dst_mask is not None:
# tmp_mask |= dst_mask
# dst_mask = tmp_mask
# # # Merge with the already activated dst_ids
# # if dst_mask is not None:
# # tmp_mask |= dst_mask
# # dst_mask = tmp_mask
# # Compute the indirectly activated edges
# edge_mask = dst_mask[g.edge_index[1]]
# edge_index = g.edge_index[:,edge_mask]
# # Compute the indirectly activated src_ids
# src_mask = torch.zeros(g.src_size, **bkw)
# src_mask[edge_index[0]] = True
# src_mask[:g.dst_size] |= dst_mask
# self.src_ids = g.src_ids[src_mask]
# self.dst_size = src_mask[:g.dst_size].count_nonzero().item()
# self.src_mask = src_mask
# self.dst_mask_size = g.dst_size
# # Renumber edge_index
# imp = torch.empty(g.src_size, **ikw).fill_((2**62-1)*2+1)
# idx = torch.where(src_mask)[0]
# imp[idx] = torch.arange(idx.size(0), **ikw)
# edge_index = imp[edge_index.flatten()].view_as(edge_index)
# self.edge_index = edge_index
# self.pgraph = g
# self.ndata = ShrinkData(self.dst_mask, g.ndata)
# self.edata = ShrinkData(self.edge_mask, g.edata)
\ No newline at end of file
# # # Compute the indirectly activated edges
# # edge_mask = dst_mask[g.edge_index[1]]
# # edge_index = g.edge_index[:,edge_mask]
# # # Compute the indirectly activated src_ids
# # src_mask = torch.zeros(g.src_size, **bkw)
# # src_mask[edge_index[0]] = True
# # src_mask[:g.dst_size] |= dst_mask
# # self.src_ids = g.src_ids[src_mask]
# # self.dst_size = src_mask[:g.dst_size].count_nonzero().item()
# # self.src_mask = src_mask
# # self.dst_mask_size = g.dst_size
# # # Renumber edge_index
# # imp = torch.empty(g.src_size, **ikw).fill_((2**62-1)*2+1)
# # idx = torch.where(src_mask)[0]
# # imp[idx] = torch.arange(idx.size(0), **ikw)
# # edge_index = imp[edge_index.flatten()].view_as(edge_index)
# # self.edge_index = edge_index
# # self.pgraph = g
# # self.ndata = ShrinkData(self.dst_mask, g.ndata)
# # self.edata = ShrinkData(self.edge_mask, g.edata)
\ No newline at end of file
import torch
# import torch
from torch import Tensor
from typing import *
# from torch import Tensor
# from typing import *
class EData:
def __init__(self,
edge_size: Optional[int] = None,
p = None,
) -> None:
if p is None:
assert edge_size is not None
self.edge_size = edge_size
else:
assert edge_size is None
self.edge_size = p.edge_size
# class EData:
# def __init__(self,
# edge_size: Optional[int] = None,
# p = None,
# ) -> None:
# if p is None:
# assert edge_size is not None
# self.edge_size = edge_size
# else:
# assert edge_size is None
# self.edge_size = p.edge_size
self.prev_data = p
self.data: Dict[str, Tensor] = {}
# self.prev_data = p
# self.data: Dict[str, Tensor] = {}
def __getitem__(self, name: str) -> Tensor:
t = self.get(name)
if t is None:
raise ValueError(f"not found '{name}' in data")
return t
# def __getitem__(self, name: str) -> Tensor:
# t = self.get(name)
# if t is None:
# raise ValueError(f"not found '{name}' in data")
# return t
def __setitem__(self, name: str, tensor: Tensor):
if not isinstance(tensor, Tensor):
raise ValueError(f"the second parameter's type must be Tensor")
# def __setitem__(self, name: str, tensor: Tensor):
# if not isinstance(tensor, Tensor):
# raise ValueError(f"the second parameter's type must be Tensor")
if tensor.size(0) == self.edge_size:
self.data[name] = tensor
else:
raise ValueError(f"tensor's shape must match the edge_size")
# if tensor.size(0) == self.edge_size:
# self.data[name] = tensor
# else:
# raise ValueError(f"tensor's shape must match the edge_size")
def __delitem__(self, name: str) -> None:
self.pop(name)
# def __delitem__(self, name: str) -> None:
# self.pop(name)
def get(self, name: str) -> Optional[Tensor]:
p, t = self, None
while p is not None:
t = p.data.get(name)
if t is not None:
break
p = p.prev_data
return t
# def get(self, name: str) -> Optional[Tensor]:
# p, t = self, None
# while p is not None:
# t = p.data.get(name)
# if t is not None:
# break
# p = p.prev_data
# return t
def pop(self, name: str) -> Tensor:
return self.data.pop(name)
# def pop(self, name: str) -> Tensor:
# return self.data.pop(name)
def permute_(self, perm: Tensor):
p = self
while p is not None:
for key in list(p.data.keys()):
val = p.data.get(key)
p.data[key] = val[perm]
p = p.prev_data
# def permute_(self, perm: Tensor):
# p = self
# while p is not None:
# for key in list(p.data.keys()):
# val = p.data.get(key)
# p.data[key] = val[perm]
# p = p.prev_data
\ No newline at end of file
import torch
# import torch
from torch import Tensor
from typing import *
# from torch import Tensor
# from typing import *
class NData:
def __init__(self,
src_size: Optional[int] = None,
dst_size: Optional[int] = None,
p = None,
) -> None:
if p is None:
assert src_size is not None
assert dst_size is not None
self.src_size = src_size
self.dst_size = dst_size
else:
assert src_size is None
assert dst_size is None
self.src_size = p.src_size
self.dst_size = p.dst_size
# class NData:
# def __init__(self,
# src_size: Optional[int] = None,
# dst_size: Optional[int] = None,
# p = None,
# ) -> None:
# if p is None:
# assert src_size is not None
# assert dst_size is not None
# self.src_size = src_size
# self.dst_size = dst_size
# else:
# assert src_size is None
# assert dst_size is None
# self.src_size = p.src_size
# self.dst_size = p.dst_size
self.prev_data = p
self.data: Dict[str, Tensor] = {}
# self.prev_data = p
# self.data: Dict[str, Tensor] = {}
def __getitem__(self, name: str) -> Tensor:
t = self.get(name)
if t is None:
raise ValueError(f"not found '{name}' in data")
return t
# def __getitem__(self, name: str) -> Tensor:
# t = self.get(name)
# if t is None:
# raise ValueError(f"not found '{name}' in data")
# return t
def __setitem__(self, name: str, tensor: Tensor):
if not isinstance(tensor, Tensor):
raise ValueError("the second parameter's type must be Tensor")
if tensor.size(0) == self.src_size:
self.data[name] = tensor
elif tensor.size(0) == self.dst_size:
self.data[name] = tensor
else:
raise ValueError("tensor's shape must match the src_size or dst_size")
# def __setitem__(self, name: str, tensor: Tensor):
# if not isinstance(tensor, Tensor):
# raise ValueError("the second parameter's type must be Tensor")
# if tensor.size(0) == self.src_size:
# self.data[name] = tensor
# elif tensor.size(0) == self.dst_size:
# self.data[name] = tensor
# else:
# raise ValueError("tensor's shape must match the src_size or dst_size")
def __delitem__(self, name: str) -> None:
self.pop(name)
# def __delitem__(self, name: str) -> None:
# self.pop(name)
def get(self, name: str) -> Optional[Tensor]:
p, t = self, None
while p is not None:
t = p.data.get(name)
if t is not None:
break
p = p.prev_data
return t
# def get(self, name: str) -> Optional[Tensor]:
# p, t = self, None
# while p is not None:
# t = p.data.get(name)
# if t is not None:
# break
# p = p.prev_data
# return t
def pop(self, name: str) -> Tensor:
return self.data.pop(name)
# def pop(self, name: str) -> Tensor:
# return self.data.pop(name)
def permute_(self, perm: Tensor):
p = self
while p is not None:
for key in list(p.data.keys()):
val = p.data.get(key)
p.data[key] = val[perm]
p = p.prev_data
# def permute_(self, perm: Tensor):
# p = self
# while p is not None:
# for key in list(p.data.keys()):
# val = p.data.get(key)
# p.data[key] = val[perm]
# p = p.prev_data
def get_type(self, key: Union[str, Tensor]):
if isinstance(key, Tensor):
if key.size(0) == self.src_size:
return "src"
elif key.size(0) == self.dst_size:
return "dst"
else:
raise RuntimeError
t = self.__getitem__(key)
if t.size(0) == self.src_size:
return "src"
elif t.size(0) == self.dst_size:
return "dst"
else:
raise RuntimeError
\ No newline at end of file
# def get_type(self, key: Union[str, Tensor]):
# if isinstance(key, Tensor):
# if key.size(0) == self.src_size:
# return "src"
# elif key.size(0) == self.dst_size:
# return "dst"
# else:
# raise RuntimeError
# t = self.__getitem__(key)
# if t.size(0) == self.src_size:
# return "src"
# elif t.size(0) == self.dst_size:
# return "dst"
# else:
# raise RuntimeError
\ No newline at end of file
import torch
# import torch
from torch import Tensor
from typing import *
# from torch import Tensor
# from typing import *
def init_local_edge_index(
dst_ids: Tensor,
edge_index: Tensor,
bipartite: bool = False,
) -> Tuple[Tensor, Tensor]:
max_ids = calc_max_ids(dst_ids, edge_index)
ikw = dict(dtype=torch.long, device=dst_ids.device)
xmp = torch.zeros(max_ids + 1, **ikw)
# def init_local_edge_index(
# dst_ids: Tensor,
# edge_index: Tensor,
# bipartite: bool = False,
# ) -> Tuple[Tensor, Tensor]:
# max_ids = calc_max_ids(dst_ids, edge_index)
# ikw = dict(dtype=torch.long, device=dst_ids.device)
# xmp = torch.zeros(max_ids + 1, **ikw)
# Check that this is a vertex-cut partition and that every edge is assigned to the partition owning its destination vertex
xmp[edge_index[1].unique()] += 0b01
xmp[dst_ids.unique()] += 0b10
if not (xmp != 0x01).all():
raise RuntimeError(f"must be vertex-cut partition graph")
# # Check that this is a vertex-cut partition and that every edge is assigned to the partition owning its destination vertex
# xmp[edge_index[1].unique()] += 0b01
# xmp[dst_ids.unique()] += 0b10
# if not (xmp != 0x01).all():
# raise RuntimeError(f"must be vertex-cut partition graph")
if bipartite:
src_ids = edge_index[0].unique()
else:
# Assume a homogeneous graph
# src_ids equals [dst_ids, edge_index[0] except dst_ids]
xmp.fill_(0)
xmp[edge_index[0]] = 1
xmp[dst_ids] = 0
src_ids = torch.cat([dst_ids, torch.where(xmp > 0)[0]], dim=-1)
# if bipartite:
# src_ids = edge_index[0].unique()
# else:
# # Assume a homogeneous graph
# # src_ids equals [dst_ids, edge_index[0] except dst_ids]
# xmp.fill_(0)
# xmp[edge_index[0]] = 1
# xmp[dst_ids] = 0
# src_ids = torch.cat([dst_ids, torch.where(xmp > 0)[0]], dim=-1)
# Compute local indices
xmp.fill_((2**62-1)*2+1)
xmp[src_ids] = torch.arange(src_ids.size(0), **ikw)
src = xmp[edge_index[0]]
# # Compute local indices
# xmp.fill_((2**62-1)*2+1)
# xmp[src_ids] = torch.arange(src_ids.size(0), **ikw)
# src = xmp[edge_index[0]]
xmp.fill_((2**62-1)*2+1)
xmp[dst_ids] = torch.arange(dst_ids.size(0), **ikw)
dst = xmp[edge_index[1]]
# xmp.fill_((2**62-1)*2+1)
# xmp[dst_ids] = torch.arange(dst_ids.size(0), **ikw)
# dst = xmp[edge_index[1]]
local_edge_index = torch.vstack([src, dst])
return src_ids, local_edge_index
# local_edge_index = torch.vstack([src, dst])
# return src_ids, local_edge_index
def calc_max_ids(*ids: Tensor) -> int:
x = [t.max().item() if t.numel() > 0 else 0 for t in ids]
return max(*x)
\ No newline at end of file
# def calc_max_ids(*ids: Tensor) -> int:
# x = [t.max().item() if t.numel() > 0 else 0 for t in ids]
# return max(*x)
\ No newline at end of file
import torch
import torch.nn as nn
from torch import Tensor
from typing import *
from .buffer import TensorBuffers
class BatchHandle:
def __init__(self,
src_ids: Tensor,
dst_size: int,
feat_buffer: TensorBuffers,
grad_buffer: TensorBuffers,
with_feat0: bool = False,
layer_id: Optional[int] = None,
target_device: Any = None,
) -> None:
self.src_ids = src_ids
self.dst_size = dst_size
self.feat_buffer = feat_buffer
self.grad_buffer = grad_buffer
self.with_feat0 = with_feat0
self.layer_id = layer_id
if target_device is None:
self.target_device = src_ids.device
else:
self.target_device = torch.device(target_device)
if with_feat0:
_ = self.feat0
@property
def dst_ids(self) -> Tensor:
return self.src_ids[:self.dst_size]
@property
def src_size(self) -> int:
return self.src_ids.size(0)
@property
def device(self) -> torch.device:
return self.src_ids.device
@property
def feat0(self) -> Tensor:
if not hasattr(self, "_feat0"):
self._feat0 = self.feat_buffer.get(0, self.src_ids).to(self.target_device)
return self._feat0
def fetch_feat(self, layer_id: Optional[int] = None) -> Tensor:
layer_id = int(self.layer_id if layer_id is None else layer_id)
return self.feat_buffer.get(layer_id, self.src_ids).to(self.target_device)
def update_feat(self, x: Tensor, layer_id: Optional[int] = None):
assert x.size(0) == self.dst_size
layer_id = int(self.layer_id if layer_id is None else layer_id)
self.feat_buffer.set(layer_id + 1, self.dst_ids, x.detach().to(self.device))
def push_and_pull(self, x: Tensor, layer_id: Optional[int] = None) -> Tensor:
assert x.size(0) == self.dst_size
layer_id = int(self.layer_id if layer_id is None else layer_id)
self.feat_buffer.set(layer_id + 1, self.dst_ids, x.detach().to(self.device))
o = self.feat_buffer.get(layer_id + 1, self.src_ids[self.dst_size:]).to(x.device)
return torch.cat([x, o], dim=0)
def fetch_grad(self, layer_id: Optional[int] = None) -> Tensor:
layer_id = int(self.layer_id if layer_id is None else layer_id)
return self.grad_buffer.get(layer_id + 1, self.dst_ids).to(self.target_device)
def accumulate_grad(self, x: Tensor, layer_id: Optional[int] = None):
assert x.size(0) == self.src_size
layer_id = int(self.layer_id if layer_id is None else layer_id)
self.grad_buffer.add(layer_id, self.src_ids, x.detach().to(self.device))
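# --- illustrative sketch (editor's addition, not part of this commit) ---
# push_and_pull above writes the freshly computed dst features into layer l+1
# of the feature buffer and reads the halo rows src_ids[dst_size:] back, so
# the returned tensor again covers every src node of the batch.  BatchHandle
# and TensorBuffers are the classes defined/imported in this file; the toy
# sizes are made up.
def _push_and_pull_example():
    feat_buf = TensorBuffers(src_size=6, channels=(8, 8))
    grad_buf = TensorBuffers(src_size=6, channels=(8, 8))
    handle = BatchHandle(
        src_ids=torch.arange(6),   # first dst_size entries are the local dst nodes
        dst_size=4,
        feat_buffer=feat_buf,
        grad_buffer=grad_buf,
        layer_id=0,
    )
    out = handle.push_and_pull(torch.randn(4, 8))   # layer-0 output for the 4 dst nodes
    assert out.shape == (6, 8)                      # dst rows come from x, halo rows from the buffer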
class DataLoader:
def __init__(self,
node_parts: Tensor,
feat_buffer: TensorBuffers,
grad_buffer: TensorBuffers,
edge_index: Tensor,
edge_attr: Optional[Tensor] = None,
edge_time: Optional[Tensor] = None,
node_time: Optional[Tensor] = None,
) -> None:
self.src_size = feat_buffer.src_size
self.dst_size = node_parts.size(0)
assert node_parts.size(0) <= self.src_size
assert grad_buffer.src_size == self.src_size
num_parts = node_parts.max().item() + 1
cluster, node_perm = node_parts.sort(dim=0)
node_ptr: Tensor = torch.ops.torch_sparse.ind2ptr(cluster, num_parts)
edge_parts = node_parts[edge_index[1]]
cluster, edge_perm = edge_parts.sort()
edge_ptr: Tensor = torch.ops.torch_sparse.ind2ptr(cluster, num_parts)
self.num_parts = num_parts
self.node_ptr = node_ptr
self.edge_ptr = edge_ptr
self.node_perm = node_perm
self.edge_perm = edge_perm
self.edge_index = edge_index[:,edge_perm]
if edge_attr is not None:
self.edge_attr = edge_attr[edge_perm]
if node_time is not None:
self.node_time = node_time[node_perm]
if edge_time is not None:
self.edge_time = edge_time[edge_perm]
self.feat_buffer = feat_buffer
self.grad_buffer = grad_buffer
def iter(self, batch_size: int = 1, layer_id: Optional[int] = None, seed: int = 0, filter: Callable[[Tensor], Tensor] = None, device = None):
rnd = torch.Generator()
if seed != 0:
rnd.manual_seed(seed)
sampled = torch.randperm(self.num_parts, generator=rnd, dtype=torch.long)
s = 0
imp = torch.empty(self.src_size, dtype=torch.long)
while s < sampled.size(0):
t = min(s + batch_size, sampled.size(0))
dst_ids = []
edge_index = []
edge_attr = []
imp.zero_()
batch_parts = sampled[s:t].tolist()
s = t  # advance once per batch, not once per sampled partition
for i in batch_parts:
a, b = self.node_ptr[i:i+2].tolist()
nidx = self.node_perm[a:b]
if hasattr(self, "node_time") and filter is not None:
node_mask = filter(self.node_time[a:b])
if not node_mask.any():
continue
nidx = nidx[node_mask]
else:
node_mask = None
a, b = self.edge_ptr[i:i+2].tolist()
eidx = self.edge_index[:, a:b]
if hasattr(self, "edge_time") and filter is not None:
if node_mask is None:
edge_mask = filter(self.edge_time[a:b])
else:
imp[nidx] = i+1
edge_mask = (imp[eidx[1]] == i+1)
edge_mask &= filter(self.edge_time[a:b])
if not edge_mask.any():
continue
eidx = eidx[:, edge_mask]
else:
edge_mask = None
dst_ids.append(nidx)
edge_index.append(eidx)
if hasattr(self, "edge_attr"):
attr = self.edge_attr[a:b]
if edge_mask is None:
edge_attr.append(attr)
else:
edge_attr.append(attr[edge_mask])
if len(dst_ids) == 0 or len(edge_index) == 0:
continue
dst_ids = torch.cat(dst_ids, dim=-1)
edge_index = torch.cat(edge_index, dim=-1)
if hasattr(self, "edge_attr"):
edge_attr = torch.cat(edge_attr, dim=0)
else:
edge_attr = None
imp.zero_()
imp.index_fill_(0, edge_index[0], 1).index_fill_(0, dst_ids, 0)
src_ids = torch.cat([dst_ids, torch.where(imp > 0)[0]], dim=-1)
assert (src_ids[:dst_ids.size(0)] == dst_ids).all()
imp.fill_((2**62-1)*2+1)
imp[src_ids] = torch.arange(src_ids.size(0), dtype=torch.long)
edge_index = imp[edge_index.flatten()].view_as(edge_index)
handle = BatchHandle(
src_ids, dst_ids.size(0),
self.feat_buffer, self.grad_buffer,
with_feat0 = (layer_id is None),
layer_id = layer_id,
target_device = device,
)
edge_index = edge_index.to(device)
edge_attr = edge_attr.to(device) if edge_attr is not None else None
yield handle, edge_index, edge_attr
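# --- illustrative usage sketch (editor's addition, not part of this commit) ---
# How the 1-order DataLoader above is meant to be driven: iter() yields one
# BatchHandle per group of partitions together with the re-indexed local
# edge_index / edge_attr.  All sizes below are toy values, and torch_sparse
# must be installed because ind2ptr is used internally.
def _dataloader_example():
    src_size, dst_size, num_parts = 100, 80, 8
    edge_index = torch.randint(0, dst_size, (2, 500))        # toy local edges
    edge_attr = torch.randn(500, 4)                           # toy edge features
    node_parts = torch.randint(0, num_parts, (dst_size,))     # toy partition assignment
    feat_buf = TensorBuffers(src_size, channels=(16, 16, 16))
    grad_buf = TensorBuffers(src_size, channels=(16, 16, 16))
    loader = DataLoader(node_parts, feat_buf, grad_buf, edge_index, edge_attr=edge_attr)
    for handle, eidx, eattr in loader.iter(batch_size=2, layer_id=0):
        x = handle.fetch_feat()                  # features of every src node in the batch
        handle.update_feat(x[:handle.dst_size])  # write a (dummy) layer output back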
import torch
import torch.nn as nn
from torch import Tensor
from typing import *
class TensorBuffers(nn.Module):
def __init__(self, src_size: int, channels: Tuple[Union[int, Tuple[int]]]) -> None:
super().__init__()
self.src_size = src_size
self.num_layers = len(channels)
for i, s in enumerate(channels):
s = (s,) if isinstance(s, int) else s
self.register_buffer(f"data_{i}", torch.zeros(src_size, *s), persistent=False)
def get(self, i: int, index: Optional[Tensor]) -> Tensor:
i = self._idx(i)
if index is None:
return self.get_buffer(f"data_{i}")
else:
return self.get_buffer(f"data_{i}")[index]
def set(self, i: int, index: Optional[Tensor], value: Tensor):
i = self._idx(i)
if index is None:
self.get_buffer(f"data_{i}")[:,...] = value
else:
self.get_buffer(f"data_{i}")[index] = value
def add(self, i: int, index: Optional[Tensor], value: Tensor):
i = self._idx(i)
if index is None:
self.get_buffer(f"data_{i}")[:,...] += value
else:
self.get_buffer(f"data_{i}")[index] += value
def zero_grad(self):
for name, grad in self.named_buffers("data_", recurse=False):
grad.zero_()
def _idx(self, i: int) -> int:
assert -self.num_layers < i and i < self.num_layers
return (self.num_layers + i) % self.num_layers
\ No newline at end of file
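# --- illustrative sketch (editor's addition, not part of this commit) ---
# TensorBuffers above keeps one zero-initialised tensor per layer, addressed
# by a layer index (negative indices wrap around via _idx) and an optional
# row index.  Each channel entry may be an int or a full shape tuple.
def _tensor_buffers_example():
    buf = TensorBuffers(src_size=5, channels=(4, (2, 3)))  # layer 0: (5, 4), layer 1: (5, 2, 3)
    rows = torch.tensor([0, 2])
    buf.set(0, rows, torch.ones(2, 4))           # overwrite two rows of layer 0
    buf.add(0, rows, torch.ones(2, 4))           # accumulate into the same rows
    assert buf.get(0, rows).sum().item() == 16.0
    assert buf.get(-1, None).shape == (5, 2, 3)  # -1 addresses the last layer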
import torch
from torch import Tensor
from typing import *
from starrygl.parallel import get_compute_device
from starrygl.core.route import Route
from starrygl.utils.partition import metis_partition
def init_local_edge_index(
dst_ids: Tensor,
edge_index: Tensor,
bipartite: bool = False,
) -> Tuple[Tensor, Tensor]:
max_ids = calc_max_ids(dst_ids, edge_index)
ikw = dict(dtype=torch.long, device=dst_ids.device)
xmp = torch.zeros(max_ids + 1, **ikw)
# Check that this is a vertex-cut partition and that every edge is assigned to the partition owning its destination vertex
xmp[edge_index[1].unique()] += 0b01
xmp[dst_ids.unique()] += 0b10
if not (xmp != 0x01).all():
raise RuntimeError(f"must be vertex-cut partition graph")
if bipartite:
src_ids = edge_index[0].unique()
else:
# Assume a homogeneous graph
# src_ids equals [dst_ids, edge_index[0] except dst_ids]
xmp.fill_(0)
xmp[edge_index[0]] = 1
xmp[dst_ids] = 0
src_ids = torch.cat([dst_ids, torch.where(xmp > 0)[0]], dim=-1)
# Compute local indices
xmp.fill_((2**62-1)*2+1)
xmp[src_ids] = torch.arange(src_ids.size(0), **ikw)
src = xmp[edge_index[0]]
xmp.fill_((2**62-1)*2+1)
xmp[dst_ids] = torch.arange(dst_ids.size(0), **ikw)
dst = xmp[edge_index[1]]
local_edge_index = torch.vstack([src, dst])
return src_ids, local_edge_index
def calc_max_ids(*ids: Tensor) -> int:
x = [t.max().item() if t.numel() > 0 else 0 for t in ids]
return max(*x)
def collect_feat0(src_ids: Tensor, dst_ids: Tensor, feat0: Tensor):
device = get_compute_device()
route = Route(dst_ids.to(device), src_ids.to(device))
return route.forward_a2a(feat0.to(device))[0].to(feat0.device)
def local_partition_fn(dst_size: Tensor, edge_index: Tensor, num_parts: int) -> Tensor:
edge_index = edge_index[:, edge_index[0] < dst_size]
return metis_partition(edge_index, dst_size, num_parts)[0]
\ No newline at end of file
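# --- illustrative sketch (editor's addition, not part of this commit) ---
# A rough picture of how the helpers above can be combined on each rank:
# partition the local destination nodes, route the initial features to every
# source node, stage them as layer-0 input, and hand everything to the
# 1-order DataLoader.  The DataLoader import path and the feature-buffer
# layout are assumptions, and the process group must already be initialised.
def build_local_loader_sketch(dst_ids, src_ids, edge_index, feat0, feat_buf, grad_buf, num_parts=8):
    from starrygl.loader import DataLoader  # assumed import path
    node_parts = local_partition_fn(dst_ids.size(0), edge_index, num_parts)
    src_feat0 = collect_feat0(src_ids, dst_ids, feat0)  # layer-0 features for all src nodes
    feat_buf.set(0, None, src_feat0)                    # stage them in the feature buffer
    return DataLoader(node_parts, feat_buf, grad_buf, edge_index)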
import torch
import torch.nn as nn
import torch.distributed as dist
from torch import Tensor
from typing import *
from starrygl.loader import BatchHandle
class BaseLayer(nn.Module):
def __init__(self) -> None:
super().__init__()
def forward(self, x: Tensor, edge_index: Tensor, edge_attr: Optional[Tensor] = None) -> Tensor:
return x
def update_forward(self, handle: BatchHandle, edge_index: Tensor, edge_attr: Optional[Tensor] = None):
x = handle.fetch_feat()
with torch.no_grad():
x = self.forward(x, edge_index, edge_attr)
handle.update_feat(x)
def block_backward(self, handle: BatchHandle, edge_index: Tensor, edge_attr: Optional[Tensor] = None):
x = handle.fetch_feat().requires_grad_()
g = handle.fetch_grad()
self.forward(x, edge_index, edge_attr).backward(g)
handle.accumulate_grad(x.grad)
x.grad = None
def all_reduce_grad(self):
for p in self.parameters():
if p.grad is not None:
dist.all_reduce(p.grad, op=dist.ReduceOp.SUM)
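# --- schematic training step (editor's addition, not part of this commit) ---
# The layer-wise pattern the methods above support: a forward sweep refreshes
# the staged features level by level, a backward sweep per layer replays the
# computation with gradients and accumulates them into the grad buffers, and
# all_reduce_grad then synchronises the parameter gradients.  `make_loader` is
# a hypothetical callable returning a DataLoader configured for a given layer;
# it is not an API of this repository.
def train_sweep_sketch(layers: List[BaseLayer], make_loader):
    for l, layer in enumerate(layers):                         # forward sweep
        for handle, edge_index, edge_attr in make_loader(l).iter(layer_id=l):
            layer.update_forward(handle, edge_index, edge_attr)
    for l in reversed(range(len(layers))):                     # backward sweep
        for handle, edge_index, edge_attr in make_loader(l).iter(layer_id=l):
            layers[l].block_backward(handle, edge_index, edge_attr)
        layers[l].all_reduce_grad()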
class BaseModel(nn.Module):
def __init__(self,
num_features: int,
layers: List[int],
prev_layer: bool = False,
post_layer: bool = False,
) -> None:
super().__init__()
def init_prev_layer(self) -> Tensor:
pass
def init_post_layer(self) -> Tensor:
pass
def init_conv_layer(self) -> Tensor:
pass
\ No newline at end of file
import torch
import torch.nn as nn
import torch.nn.functional as F
# import torch
# import torch.nn as nn
# import torch.nn.functional as F
from torch_geometric.utils import softmax
from torch_scatter import scatter_sum
# from torch_geometric.utils import softmax
# from torch_scatter import scatter_sum
from torch import Tensor
from typing import *
# from torch import Tensor
# from typing import *
from starrygl.graph import DistGraph
from .gat_conv import GATConv
from .utils import ShrinkHelper
# from starrygl.graph import DistGraph
# from .gat_conv import GATConv
# from .utils import ShrinkHelper
class ShrinkGATConv(GATConv):
def __init__(self,
in_channels: int,
out_channels: int,
heads: int = 1,
concat: bool = False,
negative_slope: float = 0.2,
dropout: float = 0.0,
edge_dim: Optional[int] = None,
bias: bool = True,
**kwargs
) -> None:
super().__init__(
in_channels=in_channels,
out_channels=out_channels,
heads=heads,
concat=concat,
negative_slope=negative_slope,
dropout=dropout,
edge_dim=edge_dim,
bias=bias,
**kwargs
)
# class ShrinkGATConv(GATConv):
# def __init__(self,
# in_channels: int,
# out_channels: int,
# heads: int = 1,
# concat: bool = False,
# negative_slope: float = 0.2,
# dropout: float = 0.0,
# edge_dim: Optional[int] = None,
# bias: bool = True,
# **kwargs
# ) -> None:
# super().__init__(
# in_channels=in_channels,
# out_channels=out_channels,
# heads=heads,
# concat=concat,
# negative_slope=negative_slope,
# dropout=dropout,
# edge_dim=edge_dim,
# bias=bias,
# **kwargs
# )
def forward(self,
g: DistGraph,
sh: Optional[ShrinkHelper] = None,
dst_idx: Optional[Tensor] = None,
) -> Tensor:
if sh is None and dst_idx is None:
return super().forward(g)
# def forward(self,
# g: DistGraph,
# sh: Optional[ShrinkHelper] = None,
# dst_idx: Optional[Tensor] = None,
# ) -> Tensor:
# if sh is None and dst_idx is None:
# return super().forward(g)
if sh is None:
sh = ShrinkHelper(g, dst_idx)
# if sh is None:
# sh = ShrinkHelper(g, dst_idx)
H, C = self.heads, self.out_channels
x = g.ndata["x"]
# H, C = self.heads, self.out_channels
# x = g.ndata["x"]
src_x = x[sh.src_idx]
dst_x = x[sh.dst_idx]
edge_index = sh.edges
# src_x = x[sh.src_idx]
# dst_x = x[sh.dst_idx]
# edge_index = sh.edges
src_x = (src_x @ self.weight).view(-1, H, C)
dst_x = (dst_x @ self.weight).view(-1, H, C)
# src_x = (src_x @ self.weight).view(-1, H, C)
# dst_x = (dst_x @ self.weight).view(-1, H, C)
alpha_j = (src_x * self.att_src).sum(dim=-1)
alpha_j = alpha_j[edge_index[0]]
# alpha_i = (src_x * self.att_src).sum(dim=-1)
# alpha_j = alpha_j[edge_index[0]]
alpha_i = (dst_x * self.att_dst).sum(dim=-1)
alpha_i = alpha_i[edge_index[1]]
# alpha_i = (dst_x * self.att_dst).sum(dim=-1)
# alpha_i = alpha_i[edge_index[1]]
if self.edge_dim is not None:
edge_attr = g.edata["edge_attr"]
edge_attr = edge_attr[sh.edge_idx]
if edge_attr.dim() == 1:
edge_attr = edge_attr.view(-1, 1)
# if self.edge_dim is not None:
# edge_attr = g.edata["edge_attr"]
# edge_attr = edge_attr[sh.edge_idx]
# if edge_attr.dim() == 1:
# edge_attr = edge_attr.view(-1, 1)
e = (edge_attr @ self.lin_edge).view(-1, H, C)
alpha_e = (e * self.att_edge).sum(dim=-1)
alpha = alpha_i + alpha_j + alpha_e
else:
alpha = alpha_i + alpha_j
alpha = F.leaky_relu(alpha, self.negative_slope)
alpha = softmax(
src=alpha,
index=edge_index[1],
num_nodes=sh.dst_size,
)
# e = (edge_attr @ self.lin_edge).view(-1, H, C)
# alpha_e = (e * self.att_edge).sum(dim=-1)
# alpha = alpha_i + alpha_j + alpha_e
# else:
# alpha = alpha_i + alpha_j
# alpha = F.leaky_relu(alpha, self.negative_slope)
# alpha = softmax(
# src=alpha,
# index=edge_index[1],
# num_nodes=sh.dst_size,
# )
alpha = F.dropout(alpha, p=self.dropout, training=self.training)
# alpha = F.dropout(alpha, p=self.dropout, training=self.training)
x = src_x[edge_index[0]] * alpha.view(-1, H, 1)
x = scatter_sum(x, edge_index[1], dim=0, dim_size=sh.dst_size)
# x = x[edge_index[0]] * alpha.view(-1, H, 1)
# x = scatter_sum(x, edge_index[1], dim=0, dim_size=sh.dst_size)
if self.concat:
x = x.view(-1, H * C)
else:
x = x.mean(dim=1)
# if self.concat:
# x = x.view(-1, H * C)
# else:
# x = x.mean(dim=1)
if self.bias is not None:
x += self.bias
return x
# if self.bias is not None:
# x += self.bias
# return x
import torch
import torch.nn as nn
# import torch
# import torch.nn as nn
from torch_scatter import scatter_sum
# from torch_scatter import scatter_sum
from torch import Tensor
from typing import *
# from torch import Tensor
# from typing import *
from starrygl.graph import DistGraph
from .gcn_conv import GCNConv
from .utils import ShrinkHelper
# from starrygl.graph import DistGraph
# from .gcn_conv import GCNConv
# from .utils import ShrinkHelper
class ShrinkGCNConv(GCNConv):
def __init__(self,
in_channels: int,
out_channels: int,
bias: bool = True,
**kwargs
) -> None:
super().__init__(
in_channels=in_channels,
out_channels=out_channels,
bias=bias,
**kwargs
)
# class ShrinkGCNConv(GCNConv):
# def __init__(self,
# in_channels: int,
# out_channels: int,
# bias: bool = True,
# **kwargs
# ) -> None:
# super().__init__(
# in_channels=in_channels,
# out_channels=out_channels,
# bias=bias,
# **kwargs
# )
def forward(self,
g: DistGraph,
sh: Optional[ShrinkHelper] = None,
dst_idx: Optional[Tensor] = None,
) -> Tensor:
if sh is None and dst_idx is None:
return super().forward(g)
# def forward(self,
# g: DistGraph,
# sh: Optional[ShrinkHelper] = None,
# dst_idx: Optional[Tensor] = None,
# ) -> Tensor:
# if sh is None and dst_idx is None:
# return super().forward(g)
if sh is None:
sh = ShrinkHelper(g, dst_idx)
# if sh is None:
# sh = ShrinkHelper(g, dst_idx)
x = g.ndata["x"]
gcn_norm = g.edata["gcn_norm"].view(-1, 1)
# x = g.ndata["x"]
# gcn_norm = g.edata["gcn_norm"].view(-1, 1)
x = x[sh.src_idx]
gcn_norm = gcn_norm[sh.edge_idx]
edge_index = sh.edges
# x = x[sh.src_idx]
# gcn_norm = gcn_norm[sh.edge_idx]
# edge_index = sh.edges
x = x @ self.weight
x = x[edge_index[0]] * gcn_norm
x = scatter_sum(x, edge_index[1], dim=0, dim_size=sh.dst_size)
if self.bias is not None:
x += self.bias
return x
# x = x @ self.weight
# x = x[edge_index[0]] * gcn_norm
# x = scatter_sum(x, edge_index[1], dim=0, dim_size=sh.dst_size)
# if self.bias is not None:
# x += self.bias
# return x
import torch
import torch.nn as nn
from torch_scatter import scatter_sum
from torch import Tensor
from typing import *
from starrygl.graph import DistGraph
from .gin_conv import GINConv
from .utils import ShrinkHelper
class ShrinkGINConv(GINConv):
def __init__(self,
in_channels: int,
out_channels: int,
mlp_channels: Optional[int] = None,
eps: float = 0,
train_eps: bool = False,
**kwargs
) -> None:
super().__init__(
in_channels=in_channels,
out_channels=out_channels,
mlp_channels=mlp_channels,
eps=eps,
train_eps=train_eps,
**kwargs
)
def forward(self,
g: DistGraph,
sh: Optional[ShrinkHelper] = None,
dst_idx: Optional[Tensor] = None,
) -> Tensor:
if sh is None and dst_idx is None:
return super().forward(g)
# import torch
# import torch.nn as nn
# from torch_scatter import scatter_sum
# from torch import Tensor
# from typing import *
# from starrygl.graph import DistGraph
# from .gin_conv import GINConv
# from .utils import ShrinkHelper
# class ShrinkGINConv(GINConv):
# def __init__(self,
# in_channels: int,
# out_channels: int,
# mlp_channels: Optional[int] = None,
# eps: float = 0,
# train_eps: bool = False,
# **kwargs
# ) -> None:
# super().__init__(
# in_channels=in_channels,
# out_channels=out_channels,
# mlp_channels=mlp_channels,
# eps=eps,
# train_eps=train_eps,
# **kwargs
# )
# def forward(self,
# g: DistGraph,
# sh: Optional[ShrinkHelper] = None,
# dst_idx: Optional[Tensor] = None,
# ) -> Tensor:
# if sh is None and dst_idx is None:
# return super().forward(g)
if sh is None:
sh = ShrinkHelper(g, dst_idx)
# if sh is None:
# sh = ShrinkHelper(g, dst_idx)
x = g.ndata["x"]
# x = g.ndata["x"]
src_x = x[sh.src_idx]
dst_x = x[sh.dst_idx]
edge_index = sh.edges
# src_x = x[sh.src_idx]
# dst_x = x[sh.dst_idx]
# edge_index = sh.edges
z = scatter_sum(src_x[edge_index[0]], index=edge_index[1], dim=0, dim_size=sh.dst_size)
x = z + (1 + self.eps) * dst_x
return self.nn(x)
\ No newline at end of file
# z = scatter_sum(src_x[edge_index[0]], index=edge_index[1], dim=0, dim_size=sh.dst_size)
# x = z + (1 + self.eps) * dst_x
# return self.nn(x)
\ No newline at end of file
......@@ -3,11 +3,10 @@ import torch.nn as nn
import torch.distributed as dist
import os
from typing import *
from .degree import compute_degree, compute_gcn_norm
from .degree import compute_in_degree, compute_out_degree, compute_gcn_norm
from .sync_bn import SyncBatchNorm
from ..core import with_gloo, with_nccl
from ..core.route import get_executor
def convert_parallel_model(
net: nn.Module,
......@@ -16,28 +15,38 @@ def convert_parallel_model(
net = SyncBatchNorm.convert_sync_batchnorm(net)
net = nn.parallel.DistributedDataParallel(net,
find_unused_parameters=find_unused_parameters,
broadcast_buffers=False,
# broadcast_buffers=False,
)
for name, buffer in net.named_buffers():
if name.endswith("last_embd"):
continue
if name.endswith("last_w"):
continue
dist.broadcast(buffer, src=0)
# for name, buffer in net.named_buffers():
# if name.endswith("last_embd"):
# continue
# if name.endswith("last_w"):
# continue
# dist.broadcast(buffer, src=0)
return net
def init_process_group(backend: str = "gloo") -> torch.device:
local_rank = os.getenv("LOCAL_RANK") or os.getenv("OMPI_COMM_WORLD_LOCAL_RANK")
if local_rank is not None:
local_rank = int(local_rank)
dist.init_process_group(backend)
if backend == "nccl":
local_rank = os.getenv("LOCAL_RANK")
if backend == "nccl" or backend == "mpi":
if local_rank is None:
device = torch.device(f"cuda:{dist.get_rank()}")
else:
local_rank = int(local_rank)
device = torch.device(f"cuda:{local_rank}")
torch.cuda.set_device(device)
else:
device = torch.device("cpu")
get_executor() # initialize route executor
return device
\ No newline at end of file
global _COMPUTE_DEVICE
_COMPUTE_DEVICE = device
return device
_COMPUTE_DEVICE = torch.device("cpu")
def get_compute_device() -> torch.device:
global _COMPUTE_DEVICE
return _COMPUTE_DEVICE
\ No newline at end of file
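# --- illustrative launch sketch (editor's addition, not part of this commit) ---
# Intended use of the helpers above under torchrun (one process per GPU, so
# LOCAL_RANK is set).  The model is a placeholder, and leaving the remaining
# convert_parallel_model arguments at their defaults is an assumption here.
def _launch_sketch():
    device = init_process_group("nccl")      # picks cuda:<local_rank> and records it
    net = nn.Linear(16, 4).to(device)        # placeholder model
    net = convert_parallel_model(net)
    assert get_compute_device() == device
    return net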
......@@ -4,29 +4,30 @@ from torch import Tensor
from typing import *
from torch_scatter import scatter_sum
from starrygl.core.route import Route
from ..graph.distgraph import DistGraph
# from ..core.gather import Gather
def compute_degree(g: DistGraph) -> Tuple[Tensor, Tensor]:
edge_index = g.edge_index
x = torch.ones(
edge_index.size(1),
dtype=torch.long,
device=edge_index.device,
)
in_degree = scatter_sum(x, edge_index[1], dim=0, dim_size=g.dst_size) # (g.dst_size,)
out_degree = scatter_sum(x, edge_index[0], dim=0, dim_size=g.src_size)
out_degree = g.gather(out_degree, direct="src_to_dst") # (g.dst_size,)
def compute_in_degree(edge_index: Tensor, route: Route) -> Tensor:
dst_size = route.src_size
x = torch.ones(edge_index.size(1), dtype=torch.long, device=edge_index.device)
in_degree = g.gather(in_degree, direct="dst_to_src") # (g.src_size,)
out_degree = g.gather(out_degree, direct="dst_to_src") # (g.src_size,)
return in_degree, out_degree
in_deg = scatter_sum(x, edge_index[1], dim=0, dim_size=dst_size)
in_deg, _ = route.forward_a2a(in_deg)
return in_deg
def compute_out_degree(edge_index: Tensor, route: Route) -> Tensor:
src_size = route.dst_size
x = torch.ones(edge_index.size(1), dtype=torch.long, device=edge_index.device)
out_deg = scatter_sum(x, edge_index[0], dim=0, dim_size=src_size)
out_deg, _ = route.backward_a2a(out_deg)
out_deg, _ = route.forward_a2a(out_deg)
return out_deg
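# --- illustrative sketch (editor's addition, not part of this commit) ---
# The local half of compute_in_degree: every local edge contributes 1 to its
# destination node; the Route calls above then make the per-partition counts
# globally consistent, which is omitted in this single-process snippet.  The
# toy edge_index is made up.
def _local_in_degree_example():
    eidx = torch.tensor([[0, 1, 1, 2],
                         [1, 0, 2, 2]])
    ones = torch.ones(eidx.size(1), dtype=torch.long)
    local_in_deg = scatter_sum(ones, eidx[1], dim=0, dim_size=3)
    assert local_in_deg.tolist() == [1, 1, 2]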
def compute_gcn_norm(g: DistGraph) -> Tensor:
edge_index = g.edge_index
in_deg, out_deg = compute_degree(g)
def compute_gcn_norm(edge_index: Tensor, route: Route) -> Tensor:
in_deg = compute_in_degree(edge_index, route)
out_deg = compute_out_degree(edge_index, route)
a = in_deg[edge_index[0]].pow(-0.5)
b = out_deg[edge_index[0]].pow(-0.5)
......
from .functional import train_epoch, eval_epoch
# from .functional import train_epoch, eval_epoch
from .partition import partition_load, partition_save, partition_data
from .printer import sync_print, main_print
from .metrics import all_reduce_loss, accuracy
\ No newline at end of file
import torch
import torch.nn as nn
# import torch
# import torch.nn as nn
from torch import Tensor
from ..graph import DistGraph
# from torch import Tensor
# from ..graph import DistGraph
from typing import *
from .metrics import *
# from typing import *
# from .metrics import *
def train_epoch(
model: nn.Module,
opt: torch.optim.Optimizer,
g: DistGraph,
mask: Optional[Tensor] = None,
) -> float:
model.train()
criterion = nn.CrossEntropyLoss()
# def train_epoch(
# model: nn.Module,
# opt: torch.optim.Optimizer,
# g: DistGraph,
# mask: Optional[Tensor] = None,
# ) -> float:
# model.train()
# criterion = nn.CrossEntropyLoss()
pred: Tensor = model(g)
targ: Tensor = g.ndata["y"]
# pred: Tensor = model(g)
# targ: Tensor = g.ndata["y"]
if mask is not None:
pred = pred[mask]
targ = targ[mask]
# if mask is not None:
# pred = pred[mask]
# targ = targ[mask]
loss: Tensor = criterion(pred, targ)
# loss: Tensor = criterion(pred, targ)
opt.zero_grad()
loss.backward()
opt.step()
# opt.zero_grad()
# loss.backward()
# opt.step()
with torch.no_grad():
train_loss = all_reduce_loss(loss, targ.size(0))
train_acc = accuracy(pred.argmax(dim=-1), targ)
return train_loss, train_acc
# with torch.no_grad():
# train_loss = all_reduce_loss(loss, targ.size(0))
# train_acc = accuracy(pred.argmax(dim=-1), targ)
# return train_loss, train_acc
@torch.no_grad()
def eval_epoch(
model: nn.Module,
g: DistGraph,
mask: Optional[Tensor] = None,
) -> Tuple[float, float]:
model.eval()
criterion = nn.CrossEntropyLoss()
# @torch.no_grad()
# def eval_epoch(
# model: nn.Module,
# g: DistGraph,
# mask: Optional[Tensor] = None,
# ) -> Tuple[float, float]:
# model.eval()
# criterion = nn.CrossEntropyLoss()
pred: Tensor = model(g)
targ: Tensor = g.ndata["y"]
# pred: Tensor = model(g)
# targ: Tensor = g.ndata["y"]
if mask is not None:
pred = pred[mask]
targ = targ[mask]
# if mask is not None:
# pred = pred[mask]
# targ = targ[mask]
loss = criterion(pred, targ)
# loss = criterion(pred, targ)
eval_loss = all_reduce_loss(loss, targ.size(0))
eval_acc = accuracy(pred.argmax(dim=-1), targ)
return eval_loss, eval_acc
\ No newline at end of file
# eval_loss = all_reduce_loss(loss, targ.size(0))
# eval_acc = accuracy(pred.argmax(dim=-1), targ)
# return eval_loss, eval_acc
\ No newline at end of file