Commit a4c4cadb by Wenjie Huang

with bugs

parent 06d10ed3
from .a2a import all_to_all, with_gloo, with_nccl
# from .cache import EmbeddingCache
from .gather import Gather
from .route import Route, GatherWork
\ No newline at end of file
# from .gather import Gather
# from .route import Route, GatherWork
from .route import Route, RouteWorkBase
from .cache import MessageCache, NodeProbe
\ No newline at end of file
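# With this change the package re-exports the route/cache primitives (Route,
# RouteWorkBase, MessageCache, NodeProbe) instead of the old Gather/EmbeddingCache
# pair. A typical downstream import would look like the line below; the package
# path "starrygl.core" is only an assumption inferred from the starrygl.* imports
# used elsewhere in this commit, not something the diff states.
#
#   from starrygl.core import Route, RouteWorkBase, MessageCache, NodeProbe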
......@@ -30,12 +30,14 @@ class Works:
def all_to_all(
output_tensor_list: List[Tensor],
input_tensor_list: List[Tensor],
group: Optional[Any] = None,
) -> Works:
assert len(output_tensor_list) == len(input_tensor_list)
if with_nccl():
work = dist.all_to_all(
output_tensor_list=output_tensor_list,
input_tensor_list=input_tensor_list,
group=group,
async_op=True,
)
return Works(work)
......@@ -48,8 +50,8 @@ def all_to_all(
send_i = (rank + i) % world_size
recv_i = (rank - i + world_size) % world_size
send_w = dist.isend(input_tensor_list[send_i], send_i)
recv_w = dist.irecv(output_tensor_list[recv_i], recv_i)
send_w = dist.isend(input_tensor_list[send_i], send_i, group=group)
recv_w = dist.irecv(output_tensor_list[recv_i], recv_i, group=group)
works.push(recv_w, send_w)
output_tensor_list[rank][:] = input_tensor_list[rank]
......
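# Usage sketch for the all_to_all wrapper above (illustrative, not part of this
# commit): it assumes the default process group is already initialized and that
# each list holds exactly one tensor per peer rank. Works.wait() blocks until
# both the NCCL path and the gloo isend/irecv fallback have finished.
def _example_all_to_all():
    import torch
    import torch.distributed as dist
    rank, world_size = dist.get_rank(), dist.get_world_size()
    # every rank sends a tensor filled with its own rank id to every peer
    send = [torch.full((4,), float(rank)) for _ in range(world_size)]
    recv = [torch.zeros(4) for _ in range(world_size)]
    all_to_all(recv, send).wait()
    # recv[j] is now filled with float(j), i.e. the payload sent by rank j
    return recv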
import torch
from torch import Tensor
from typing import *
from multiprocessing.pool import ThreadPool
from .event import Event
class AsyncCopyWorkBase:
def __init__(self) -> None:
self._events: Optional[Tuple[Event, Event]] = None
def wait(self):
raise NotImplementedError
def get(self) -> Tensor:
raise NotImplementedError
def has_events(self) -> bool:
return self._events is not None
def set_events(self, start, end):
self._events = (start, end)
return self
def time_used(self) -> float:
if self._events is None:
raise RuntimeError("not found events")
start, end = self._events
return start.elapsed_time(end)
class AsyncPushWork(AsyncCopyWorkBase):
def __init__(self, data: Tensor, index: Tensor, values: Tensor) -> None:
super().__init__()
assert data.device == index.device
self.set_events(
Event(use_cuda=index.is_cuda),
Event(use_cuda=index.is_cuda),
)
self._events[0].record()
data.index_copy_(0, index, values)
self._events[1].record()
def wait(self):
pass
def get(self):
pass
class AsyncPullWork(AsyncCopyWorkBase):
def __init__(self, data: Tensor, index: Tensor) -> None:
super().__init__()
assert data.device == index.device
self.set_events(
Event(use_cuda=index.is_cuda),
Event(use_cuda=index.is_cuda),
)
self._events[0].record()
self._val = data.index_select(0, index)
self._events[1].record()
def wait(self):
pass
def get(self) -> Tensor:
return self._val
class AsyncOffloadWork(AsyncCopyWorkBase):
def __init__(self, handle) -> None:
super().__init__()
self._handle = handle
def wait(self):
if self._events is not None:
self._events[1].wait(torch.cuda.current_stream())
self._handle.wait()
def get(self) -> Tensor:
if self._events is not None:
self._events[1].wait(torch.cuda.current_stream())
return self._handle.get()
class AsyncCopyExecutor:
def __init__(self) -> None:
self._stream = torch.cuda.Stream()
self._executor = ThreadPool(processes=1)
@torch.no_grad()
def async_pull(self, data: Tensor, index: Tensor) -> AsyncCopyWorkBase:
# Ideally all of this should run on a dedicated thread: it makes timing easier and keeps the copy asynchronous.
if data.device != index.device:
assert not data.is_cuda
assert index.is_cuda
start = Event(index.is_cuda)
end = Event(index.is_cuda)
stream: Optional[torch.cuda.Stream] = self._stream
def run():
start.wait(stream)
with torch.cuda.stream(stream):
idx = index.to(data.device)
dst = torch.zeros(
size=(index.size(0),) + data.shape[1:],
dtype=data.dtype,
device=index.device,
)
dst.copy_(data[idx])
end.record()
return dst
start.record()
handle = self._executor.apply_async(run)
return AsyncOffloadWork(handle).set_events(start, end)
else:
# data and index in the same device
return AsyncPullWork(data, index)
@torch.no_grad()
def async_push(self, data: Tensor, index: Tensor, values: Tensor) -> AsyncCopyWorkBase:
assert index.device == values.device
if data.device != index.device:
assert not data.is_cuda
assert index.is_cuda
start = Event(index.is_cuda)
end = Event(index.is_cuda)
stream: Optional[torch.cuda.Stream] = self._stream
def run():
start.wait(stream)
with torch.cuda.stream(stream):
idx = index.to(data.device)
val = values.to(data.device)
data[idx] = val
end.record()
start.record()
handle = self._executor.apply_async(run)
return AsyncOffloadWork(handle).set_events(start, end)
else:
return AsyncPushWork(data, index, values)
_THREAD_EXEC: Optional[AsyncCopyExecutor] = None
def get_executor() -> AsyncCopyExecutor:
global _THREAD_EXEC
if _THREAD_EXEC is None:
_THREAD_EXEC = AsyncCopyExecutor()
return _THREAD_EXEC
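# Usage sketch for the async copy executor (illustrative; assumes a CUDA device,
# since the offload path records CUDA events and uses a side stream). The cached
# table stays on the CPU while the index lives on the GPU.
def _example_async_copy():
    data = torch.randn(1000, 16)                      # CPU-resident cache
    index = torch.randint(0, 1000, (128,), device="cuda")
    pull = get_executor().async_pull(data, index)     # offloaded to the worker thread
    rows = pull.get()                                 # (128, 16) tensor on the GPU
    push = get_executor().async_push(data, index, rows + 1.0)
    push.wait()                                       # data[index] now updated on the CPU
    return rows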
import torch
import torch.nn as nn
from torch.utils import hooks
from torch import Tensor
from typing import *
class EmbeddingCache:
def __init__(self) -> None:
self.values: Optional[Tensor] = None
self.active: Optional[Tensor] = None
from .route import Route, RouteWorkBase
from .shrink import ShrinkData
from .acopy import get_executor as get_acopy_executor, AsyncCopyWorkBase
from .route import get_executor as get_route_executor
def get_values(self) -> Optional[Tensor]:
return self.values
def get_active(self) -> Optional[Tensor]:
return self.active
class ProbeLayer:
def __init__(self, layer_id, probe_obj) -> None:
self.layer_id: int = layer_id
self.probe_obj: NodeProbe = probe_obj
self.last_w = torch.ones(self.probe_obj.num_nodes, dtype=torch.float32)
self._val_hook_handle: Optional[hooks.RemovableHandle] = None
def get(self) -> Tuple[Optional[Tensor], Optional[Tensor]]:
return self.get_values(), self.get_active()
def to(self, device):
self.last_w = self.last_w.to(device)
return self
def set(self, values: Tensor, active: Optional[Tensor] = None):
values = values.detach()
active = active.detach()
if active is None:
if self.values is None:
self.values = values.clone()
else:
self.values[:] = values
self.active = None
else:
if self.values is None:
self.values = torch.zeros(
size=active.shape[:1] + values.shape[1:],
dtype=values.dtype,
device=values.device,
)
def remove_val_hook(self):
if self._val_hook_handle is not None:
self._val_hook_handle.remove()
self._val_hook_handle = None
def register_val_hook(self, val: Tensor, idx: Optional[Tensor]):
assert self._val_hook_handle is None, "cannot call register_val_hook() twice"
def hook(grad: Tensor):
from starrygl.utils.printer import main_print
import time
self.probe_obj.update_sample_w(self.last_w, grad, idx)
self.remove_val_hook()
self._backward_sample(False)
main_print(self.layer_id, grad.size(0), 0 if idx is None else idx.size(0))
self._val_hook_handle = val.register_hook(hook)
def warmup_sample(self):
assert self._is_last_layer()
for probe_layer in self.probe_obj.layers[::-1]:
probe_layer._backward_sample()
# max_norm = max(probe_layer.last_w.max().item(), 1.0)
# probe_layer.last_w[probe_layer.last_w == 1.0] = max_norm
def _backward_sample(self, x: bool = True):
dst_idx = self._collect_prev_dst_idx()
if dst_idx is None:
return
for cache in self.probe_obj.hook_caches:
next_cache_layer = self._next_cache_layer(cache)
shrink = next_cache_layer.set_shrink_data(dst_idx)
src_idx = shrink.src_idx
work = get_acopy_executor().async_pull(next_cache_layer.data, src_idx)
next_cache_layer.sync_pull_work(work)
if not self._is_first_layer() and x:
next_cache_layer.sync_backward_sample_work(src_idx)
def _collect_prev_dst_idx(self) -> Optional[Tensor]:
if len(self.probe_obj.hook_caches) == 0:
return None
if values.size(0) == active.size(0):
self.values[active] = values[active]
if self._is_last_layer():
return self._wrapped_sample(self.probe_obj.num_samples, None)
if len(self.probe_obj.hook_caches) == 1:
for cache in self.probe_obj.hook_caches:
prev_cache_layer = self._prev_cache_layer(cache)
dst_idx = prev_cache_layer.sync_backward_sample_work()
else:
self.values[active] = values
tmp = torch.zeros(self.probe_obj.num_nodes, dtype=torch.bool, device=self.last_w.device)
for cache in self.probe_obj.hook_caches:
prev_cache_layer = self._prev_cache_layer(cache)
prev_idx = prev_cache_layer.sync_backward_sample_work()
if prev_idx is not None:
tmp.index_fill_(0, prev_idx, 1)
dst_idx = torch.where(tmp)[0]
if dst_idx.size(0) == 0:
dst_idx = None
if dst_idx is None:
return None
return self._wrapped_sample(self.probe_obj.num_samples, dst_idx)
def _prev_cache_layer(self, cache):
cache: MessageCache = cache
i = self.layer_id + 1
if i < self.probe_obj.num_layers:
return cache.layers[i]
return None
def _next_cache_layer(self, cache):
cache: MessageCache = cache
return cache.layers[self.layer_id]
if self.active is None:
self.active = active.clone()
def _is_last_layer(self):
return self.layer_id + 1 >= self.probe_obj.num_layers
def _is_first_layer(self):
return self.layer_id == 0
def _wrapped_sample(self, k: int, idx: Optional[Tensor]) -> Optional[Tensor]:
w = self.last_w
if idx is None:
return self.probe_obj.sample(w, k) if 0 < k and k < self.probe_obj.num_nodes else None
if 0 < k and k < idx.size(0):
t = self.probe_obj.sample(w[idx], k)
return idx[t]
else:
self.active[:] = active
return idx
class NodeProbe:
def __init__(self,
num_nodes: int,
num_layers: int,
num_samples: int = 0,
p: str = "fro",
dim: int = -1,
beta: float = 1.0,
) -> None:
super().__init__()
self.num_nodes = num_nodes
self.num_layers = num_layers
self.num_samples = num_samples
self.p = p
self.dim = dim
self.beta = beta
self.layers = [ProbeLayer(i, self) for i in range(num_layers)]
self.hook_caches: Set[MessageCache] = set()
def fuse(self, values: Tensor, active: Optional[Tensor] = None) -> Tensor:
if active is None:
return values
def to(self, device):
for i in range(self.num_layers):
self.layers[i].to(device)
return self
def assign_message_cache(self, cache):
cache: MessageCache = cache
assert self.num_layers == cache.num_layers
self.hook_caches.add(cache)
def update_sample_w(self, last_w: Tensor, grad: Tensor, idx: Optional[Tensor]):
val_norm = grad.norm(p=self.p, dim=self.dim)
if idx is None:
last_w[:] = val_norm
else:
if self.values is None:
x = torch.zeros(
size=active.shape[:1] + values.shape[1:],
dtype=values.dtype,
device=values.device,
if self.beta != 1.0:
last_w.mul_(self.beta)
last_w[idx] = val_norm
def sample(self, w: Tensor, k: int) -> Optional[Tensor]:
w = w / w.sum()
return torch.multinomial(w, num_samples=k, replacement=False)
def apply(self, i: int, val: Tensor, idx: Optional[Tensor]) -> Tensor:
self.layers[i].register_val_hook(val, idx)
return val
class CacheLayer:
def __init__(self, layer_id, cache_obj, num_features) -> None:
self.layer_id: int = layer_id
self.cache_obj: MessageCache = cache_obj
self.data: Tensor = torch.randn(
size=(self.cache_obj.src_size, num_features),
dtype=torch.float32,
device=self.cache_obj.cache_device,
)
self.state: Dict[str, Any] = {}
self.shrink: Optional[ShrinkData] = None
self._async_push_work: Optional[AsyncCopyWorkBase] = None
self._async_pull_work: Optional[AsyncCopyWorkBase] = None
self._backward_sample_work: Optional[RouteWorkBase] = None
def to(self, device):
if self.shrink is not None:
self.shrink = self.shrink.to(device)
return self
def cache_data_to(self):
self.data = self.data.to(self.cache_obj.cache_device)
return self
def set_shrink_data(self, dst_idx: Optional[Tensor]) -> Optional[ShrinkData]:
self.shrink = self.cache_obj.new_shrink_data(dst_idx)
return self.shrink
def get_shrink_data(self) -> Optional[ShrinkData]:
return self.shrink
# def set_backward_sample_work(self, src_idx: Optional[Tensor]):
# if src_idx is None:
# return
# assert self._backward_sample_work is None
# self._backward_sample_work = self.cache_obj.route.backward_a2a(src_idx, src_idx)
# def pop_backward_sample_work(self) -> Optional[RouteWorkBase]:
# work = self._backward_sample_work
# self._backward_sample_work = None
# return work
def sync_backward_sample_work(self, src_idx: Optional[Tensor] = None) -> Optional[Tensor]:
dst_idx = None
if self._backward_sample_work is not None:
_, dst_idx = self._backward_sample_work.get()
if src_idx is not None:
# self._backward_sample_work = self.cache_obj.route.backward_a2a(src_idx, src_idx)
work = get_route_executor().async_backward_a2a(src_idx, src_idx, self.cache_obj.route)
self._backward_sample_work = work
else:
x = self.values.clone()
self._backward_sample_work = None
return dst_idx
def sync_push_work(self, work: Optional[AsyncCopyWorkBase] = None):
if self._async_push_work is not None:
self._async_push_work.get()
self._async_push_work = work
def sync_pull_work(self, work: Optional[AsyncCopyWorkBase] = None) -> Optional[Tensor]:
out = None
if self._async_pull_work is not None:
out = self._async_pull_work.get()
self._async_pull_work = work
return out
if values.size(0) == active.size(0):
x[active] = values[active]
class MessageCache:
def __init__(self,
src_ids: Tensor,
dst_ids: Tensor,
edge_index: Tensor,
num_features: Union[int, torch.Size],
num_layers: int,
cache_device: Union[str, torch.device, None],
bipartite: bool = False,
) -> None:
self.src_size = src_ids.size(0)
self.dst_size = dst_ids.size(0)
self.edge_index = edge_index
self.num_layers = num_layers
if cache_device is None:
cache_device = edge_index.device
self.cache_device = cache_device
self.bipartite = bipartite
self.route = Route(dst_ids, src_ids, bipartite=bipartite)
self.layers = [CacheLayer(i, self, num_features) for i in range(num_layers)]
def to(self, device):
self.edge_index = self.edge_index.to(device)
self.route = self.route.to(device)
for i in range(self.num_layers):
self.layers[i].to(device)
return self
def cached_data_to(self, device):
self.cache_device = torch.device(device)
for i in range(self.num_layers):
self.layers[i].cache_data_to()
return self
# @property
# def is_offload(self) -> bool:
# return self.edge_index.device != self.cache_device
def new_shrink_data(self, dst_idx: Optional[Tensor]) -> Optional[ShrinkData]:
if dst_idx is None:
return None
return ShrinkData(
src_size=self.src_size,
dst_size=self.dst_size,
dst_idx=dst_idx,
edge_index=self.edge_index,
bipartite=self.bipartite,
)
# def clear_shrink_data(self):
# for i in range(self.num_layers):
# self.layers[i].shrink = None
def replace_layer_data(self, i: int, data: Tensor):
layer = self.layers[i]
assert layer.data.size(0) == data.size(0)
layer.data = data
return self
def update_cache(self,
i: int,
val: Tensor,
idx: Optional[Tensor],
async_op: bool = False,
) -> Tuple[Tensor, Optional[Tensor]]:
layer: CacheLayer = self.layers[i]
src_val, src_idx = self.route.apply(val, idx, layer.state, async_op=async_op)
# full graph
if idx is None:
return src_val, None
shrink = layer.get_shrink_data()
if shrink is None: # this may be problematic during initialization
data = torch.empty_like(layer.data, device=val.device).copy_(layer.data)
data.index_copy_(0, src_idx, src_val)
return data, src_idx
else:
x[active] = values
return x
push_work = get_acopy_executor().async_push(layer.data, src_idx, src_val)
layer.sync_push_work(push_work)
data = layer.sync_pull_work()
sval, sidx = shrink.shrink_src_val_and_idx(src_val, src_idx)
data.index_copy_(0, sidx, sval)
return data, sidx
# first push the update into the cache asynchronously
# while asynchronously pulling the cached values
# then merge the local and remote caches and return the result
def fetch_pull_tensor(self, i: int) -> Optional[Tensor]:
return self.layers[i].sync_pull_work()
# def _update_cache_impl(self,
# i: int,
# dst_val: Tensor,
# dst_idx: Optional[Tensor],
# route: Route,
# async_op: bool,
# ):
# # communication
# state = self.layers[i].state
# src_val, src_idx = route.apply(dst_val, dst_idx, state, async_op=async_op)
# # push latest embeddings
# data = self.layers[i].data
# if src_idx is None:
# data[:] = src_val
# else:
# data[src_idx] = src_val
# # get previously generated shrink data
# shr: ShrinkData = self.layers[i].shrink
# if shr is None:
# return src_val, src_idx
# else:
# # pull latest embeddings
# return data[shr.src_idx], shr.src_idx
# def _update_cache_offload(self,
# i: int,
# dst_val: Tensor,
# dst_idx: Optional[Tensor],
# route: Route,
# async_op: bool,
# ):
# raise NotImplementedError
# def _compute_cached_data_size(self) -> torch.Size:
# num_features = self.num_features
# if num_features is None:
# cached_data_size = torch.Size([self.src_size,])
# else:
# cached_data_size = torch.Size([self.src_size, num_features])
# return cached_data_size
# def _compute_cpu_buf_size(self) -> torch.Size:
# num_features = self.num_features
# if num_features is None:
# cpu_buf_size = torch.Size([2**32,])
# else:
# cpu_buf_size = torch.Size([2**32 // num_features, num_features])
# return cpu_buf_size
\ No newline at end of file
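# Wiring sketch for MessageCache and NodeProbe (illustrative; assumes an
# initialized process group because Route issues collectives, and a partition
# described by src_ids / dst_ids and a local edge_index in local ids).
def _example_cache_setup(src_ids, dst_ids, edge_index):
    cache = MessageCache(
        src_ids=src_ids, dst_ids=dst_ids, edge_index=edge_index,
        num_features=32, num_layers=2, cache_device="cpu",
    )
    probe = NodeProbe(num_nodes=dst_ids.size(0), num_layers=2, num_samples=64)
    probe.assign_message_cache(cache)  # lets backward sampling drive this cache
    return cache, probe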
import torch
import time
from torch import Tensor
from typing import *
class Event:
def __init__(self, use_cuda: bool = True) -> None:
self._use_cuda = use_cuda
if use_cuda:
self._event = torch.cuda.Event(enable_timing=True)
else:
self._time: Optional[float] = None
def record(self):
if self._use_cuda:
self._event.record()
else:
self._time = time.time()
def wait(self, stream: Optional[torch.cuda.Stream] = None):
    if not self._use_cuda:
        return
    self._event.wait(stream)
def elapsed_time(self, other) -> float:
if self._use_cuda:
return self._event.elapsed_time(other._event)
else:
return (other._time - self._time) * 1000.0
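# Timing sketch for Event (illustrative): the same interface covers wall-clock
# timing on CPU and cuda.Event timing on GPU, both reported in milliseconds.
def _example_event_timing(x: Tensor) -> float:
    start, end = Event(use_cuda=x.is_cuda), Event(use_cuda=x.is_cuda)
    start.record()
    y = (x * 2.0).sum()           # some work to measure
    end.record()
    if x.is_cuda:
        torch.cuda.synchronize()  # CUDA events must complete before being read
    return start.elapsed_time(end)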
import torch
import torch.nn as nn
import torch.autograd as autograd
from torch import Tensor
from contextlib import contextmanager
from typing import *
from .route import Route, GatherWork
from .cache import EmbeddingCache
class GatherContext:
def __init__(self,
this,
route: Route,
async_op: bool,
) -> None:
self.this = this
self.route = route
self.async_op = async_op
class Gather(nn.Module):
def __init__(self,
num_nodes: int,
num_features: Optional[int] = None,
beta: float = 1.0,
) -> None:
super().__init__()
self.num_nodes = num_nodes
self.num_features = num_features
self.beta = beta
if num_features is None:
self.register_buffer("last_embd", torch.zeros(num_nodes))
else:
self.register_buffer("last_embd", torch.zeros(num_nodes, num_features))
self.last_fw_work: Optional[GatherWork] = None
self.last_bw_work: Optional[GatherWork] = None
self.reset_parameters()
def forward(self,
values: Tensor,
active: Optional[Tensor],
route: Route,
async_op: bool = False,
) -> Tuple[Tensor, Tensor]:
with self._gather_manager(route=route, async_op=async_op):
return GatherFunction.apply(values, active)
def reset_parameters(self):
last_embd = self.get_buffer("last_embd")
nn.init.normal_(last_embd, mean=0, std=1)
def fuse_embeddings(self,
values: Tensor,
active: Optional[Tensor] = None,
inplace: bool = False,
) -> Tensor:
last_embd = self.get_buffer("last_embd")
return GatherFuseFunction.apply(values, active, last_embd, self.beta, inplace, self.training)
# if not inplace:
# last_embd = last_embd.clone()
# if active is None:
# if self.beta != 1.0:
# values = values * self.beta + last_embd * (1 - self.beta)
# last_embd[:] = values
# else:
# if values.size(0) == active.size(0):
# values = values[active]
# if self.beta != 1.0:
# values = values * self.beta + last_embd[active] * (1 - self.beta)
# last_embd[active] = values
# return last_embd
@contextmanager
def _gather_manager(self, route: Route, async_op: bool):
global _global_gather_context
stacked = _global_gather_context
try:
_global_gather_context = GatherContext(
this=self,
route=route,
async_op=async_op,
)
yield _global_gather_context
finally:
_global_gather_context = stacked
class GatherFunction(autograd.Function):
@staticmethod
def forward(
ctx: autograd.function.FunctionCtx,
values: Tensor,
active: Optional[Tensor] = None,
):
gather_ctx = _last_global_gather_context()
this: Gather = gather_ctx.this
route: Route = gather_ctx.route
async_op: bool = gather_ctx.async_op
current_work = route.gather_forward(values, active, async_op=async_op)
if async_op:
work = this.last_fw_work or current_work
this.last_fw_work = current_work
else:
work = current_work
recv_values, recv_active = work.get()
recv_values = this.fuse_embeddings(
values=recv_values,
active=recv_active,
inplace=True,
)
if this.training:
# if the input values were shrunk, the holes must be removed when computing the gradient
if active is not None and values.size(0) < active.size(0):
ctx.shrink_grad = True
ctx.save_for_backward(active, recv_active)
else:
ctx.shrink_grad = False
ctx.save_for_backward(recv_active)
ctx.this = this
ctx.route = route
ctx.async_op = async_op
return recv_values, recv_active
@staticmethod
def backward(
ctx: autograd.function.FunctionCtx,
grad_values: Tensor,
grad_active: Optional[Tensor],
):
this: Gather = ctx.this
route: Route = ctx.route
async_op: bool = ctx.async_op
shrink_grad: bool = ctx.shrink_grad
with torch.no_grad():
# # backward activations flow along the reverse direction of the forward pass
if shrink_grad:
recv_active, grad_active = ctx.saved_tensors
else:
grad_active, = ctx.saved_tensors
current_work = route.gather_backward(grad_values, grad_active, async_op=async_op)
if async_op:
work = this.last_bw_work or current_work
this.last_bw_work = current_work
else:
work = current_work
recv_values = work.get_values()
if shrink_grad:
recv_values = recv_values[recv_active]
return recv_values, None
class GatherFuseFunction(autograd.Function):
@staticmethod
def forward(
ctx: autograd.function.FunctionCtx,
values: Tensor,
active: Optional[Tensor],
last_embd: Tensor,
beta: float,
inplace: bool,
training: bool,
):
if not inplace:
last_embd = last_embd.clone()
ctx.beta = beta
if active is None:
if beta != 1.0:
values = values * beta + last_embd * (1 - beta)
last_embd[:] = values
ctx.shrink_grad = False
else:
if values.size(0) == active.size(0):
values = values[active]
ctx.shrink_grad = False
else:
if training:
ctx.save_for_backward(active)
ctx.shrink_grad = True
if beta != 1.0:
values = values * beta + last_embd[active] * (1 - beta)
last_embd[active] = values
return last_embd
@staticmethod
def backward(
ctx: autograd.function.FunctionCtx,
grad: Tensor,
):
if ctx.shrink_grad:
active, = ctx.saved_tensors
return grad[active] * ctx.beta, None, None, None, None, None
else:
return grad * ctx.beta, None, None, None, None, None
#### private functions
_global_gather_context: Optional[GatherContext] = None
def _last_global_gather_context() -> GatherContext:
global _global_gather_context
assert _global_gather_context is not None
return _global_gather_context
# import torch
# import torch.nn as nn
# import torch.autograd as autograd
# from torch import Tensor
# from contextlib import contextmanager
# from typing import *
# from .route import Route, GatherWork
# # from .cache import CachedEmbeddings
# class GatherContext:
# def __init__(self,
# this,
# route: Route,
# async_op: bool,
# ) -> None:
# self.this = this
# self.route = route
# self.async_op = async_op
# class Gather(nn.Module):
# def __init__(self,
# num_nodes: int,
# num_features: Optional[int] = None,
# beta: float = 1.0,
# ) -> None:
# super().__init__()
# self.beta = beta
# if num_features is None:
# self.register_buffer("last_embd", torch.zeros(num_nodes))
# else:
# self.register_buffer("last_embd", torch.zeros(num_nodes, num_features))
# self.last_fw_work: Optional[GatherWork] = None
# self.last_bw_work: Optional[GatherWork] = None
# self.reset_parameters()
# def forward(self,
# val: Tensor,
# idx: Optional[Tensor],
# route: Route,
# async_op: bool = False,
# ) -> Tuple[Tensor, Tensor]:
# with self._manager(route=route, async_op=async_op):
# return GatherFunction.apply(val, idx)
# def reset_parameters(self):
# last_embd = self.get_buffer("last_embd")
# nn.init.normal_(last_embd, mean=0, std=1)
# def fuse_embeddings(self,
# val: Tensor,
# idx: Optional[Tensor] = None,
# inplace: bool = False,
# ) -> Tensor:
# last_embd = self.get_buffer("last_embd")
# return GatherFuseFunction.apply(val, idx, last_embd, self.beta, inplace, self.training)
# @contextmanager
# def _manager(self, route: Route, async_op: bool):
# global _global_gather_context
# stacked = _global_gather_context
# try:
# _global_gather_context = GatherContext(
# this=self,
# route=route,
# async_op=async_op,
# )
# yield _global_gather_context
# finally:
# _global_gather_context = stacked
# class GatherFunction(autograd.Function):
# @staticmethod
# def forward(
# ctx: autograd.function.FunctionCtx,
# val: Tensor,
# idx: Optional[Tensor] = None,
# ):
# gather_ctx = _last_global_gather_context()
# this: Gather = gather_ctx.this
# route: Route = gather_ctx.route
# async_op: bool = gather_ctx.async_op
# return_idx: bool = idx is not None
# current_work = route.gather_forward(val, idx, async_op=async_op, return_idx=return_idx)
# if async_op:
# work = this.last_fw_work or current_work
# this.last_fw_work = current_work
# else:
# work = current_work
# recv_val, recv_idx = work.get()
# recv_val = recv_val if recv_idx is None else recv_val[recv_idx]
# recv_val = this.fuse_embeddings(recv_val, recv_idx, inplace=True)
# if this.training:
# ctx.save_for_backward(idx, recv_idx)
# ctx.this = this
# ctx.route = route
# ctx.async_op = async_op
# return recv_val, recv_idx
# @staticmethod
# def backward(
# ctx: autograd.function.FunctionCtx,
# val_grad: Tensor,
# idx_grad: Optional[Tensor],
# ):
# this: Gather = ctx.this
# route: Route = ctx.route
# async_op: bool = ctx.async_op
# with torch.no_grad():
# recv_idx, idx_grad = ctx.saved_tensors
# if idx_grad is not None:
# val_grad = val_grad[idx_grad]
# current_work = route.gather_backward(val_grad, idx_grad, async_op=async_op, return_idx=False)
# if async_op:
# work = this.last_bw_work or current_work
# this.last_bw_work = current_work
# else:
# work = current_work
# recv_val = work.get_val()
# if recv_idx is not None:
# recv_val = recv_val[recv_idx]
# return recv_val, None
# class GatherFuseFunction(autograd.Function):
# @staticmethod
# def forward(
# ctx: autograd.function.FunctionCtx,
# val: Tensor,
# idx: Optional[Tensor],
# last_embd: Tensor,
# beta: float,
# inplace: bool,
# training: bool,
# ):
# if not inplace:
# last_embd = last_embd.clone()
# ctx.beta = beta
# if idx is None:
# assert val.size(0) == last_embd.size(0)
# if beta != 1.0:
# last_embd.mul_(1 - beta).add_(val * beta)
# else:
# last_embd[:] = (val)
# else:
# assert val.size(0) == idx.size(0)
# if beta != 1.0:
# last_embd[idx] = last_embd[idx] * (1 - beta) + val * beta
# else:
# last_embd[idx] = val
# if training:
# ctx.beta = beta
# ctx.save_for_backward(idx)
# return last_embd
# @staticmethod
# def backward(
# ctx: autograd.function.FunctionCtx,
# grad: Tensor,
# ):
# beta: float = ctx.beta
# idx, = ctx.saved_tensors
# if idx is not None:
# grad = grad[idx]
# if beta != 1.0:
# grad = grad * beta
# return grad, None, None, None, None, None
# #### private functions
# _global_gather_context: Optional[GatherContext] = None
# def _last_global_gather_context() -> GatherContext:
# global _global_gather_context
# assert _global_gather_context is not None
# return _global_gather_context
import torch
import torch.autograd as autograd
import torch.distributed as dist
from multiprocessing.pool import ThreadPool
from torch import Tensor
from typing import *
from contextlib import contextmanager
from .a2a import all_to_all, Works, with_nccl
from .event import Event
# class GatherWork:
# def __init__(self,
# values_buffer: Tensor,
# active_buffer: Optional[Tensor],
# recv_val_idx: List[Tensor],
# recv_val_dat: List[Tensor],
# works_list: List[Works],
# ) -> None:
# self._waited = False
# self._values_buffer = values_buffer
# self._active_buffer = active_buffer
# self._recv_val_idx = recv_val_idx
# self._recv_val_dat = recv_val_dat
# self._works_list = works_list
# self._cached_idx: Optional[Tensor] = None
# def _wait(self) -> None:
# assert not self._waited
# for w in self._works_list:
# if w is None:
# continue
# w.wait()
# if self._active_buffer is None:
# for idx, dat in zip(self._recv_val_idx, self._recv_val_dat):
# self._values_buffer[idx] += dat
# else:
# for idx, dat in zip(self._recv_val_idx, self._recv_val_dat):
# self._values_buffer[idx] += dat
# self._active_buffer[idx] = True
# self._active_buffer = torch.where(self._active_buffer)[0]
# self._works_list = []
# self._recv_val_dat = []
# self._recv_val_idx = []
# self._waited = True
# def get_val(self) -> Tensor:
# if not self._waited:
# self._wait()
# return self._values_buffer
# def get_idx(self) -> Optional[Tensor]:
# if not self._waited:
# self._wait()
# return self._active_buffer
# def get(self) -> Tuple[Tensor, Optional[Tensor]]:
# return self.get_val(), self.get_idx()
class RouteWorkBase:
def __init__(self) -> None:
self._events: Optional[Tuple[Event, Event]] = None
def wait(self) -> None:
raise NotImplementedError
def get(self) -> Tuple[Tensor, Optional[Tensor]]:
raise NotImplementedError
def has_events(self) -> bool:
return self._events is not None
def set_events(self, start, end):
self._events = (start, end)
return self
def time_used(self) -> float:
if self._events is None:
raise RuntimeError("not found events")
start, end = self._events
return start.elapsed_time(end)
class GatherWork:
class RouteWork(RouteWorkBase):
def __init__(self,
values_buffer: Tensor,
active_buffer: Optional[Tensor],
buf: Union[Tensor, int],
recv_val_idx: List[Tensor],
recv_val_dat: List[Tensor],
works_list: List[Works],
) -> None:
self._waited = False
self._values_buffer = values_buffer
self._active_buffer = active_buffer
super().__init__()
assert len(recv_val_idx) != 0
assert len(recv_val_idx) == len(recv_val_dat)
self._buf = buf
self._recv_val_idx = recv_val_idx
self._recv_val_dat = recv_val_dat
self._works_list = works_list
def _wait(self) -> None:
assert not self._waited
self._val: Optional[Tensor] = None
self._idx: Optional[Tensor] = None
def wait(self) -> None:
if self._val is not None:
return
for w in self._works_list:
if w is None:
continue
w.wait()
if self._active_buffer is None:
ikw = dict(dtype=self._recv_val_idx[0].dtype, device=self._recv_val_idx[0].device)
fkw = dict(dtype=self._recv_val_dat[0].dtype, device=self._recv_val_dat[0].device)
if isinstance(self._buf, Tensor):
for idx in self._recv_val_idx:
self._buf[idx] = True
self._idx = torch.where(self._buf)[0]
imp = torch.empty(self._buf.size(0), **ikw).fill_((2**62-1)*2+1)
imp[self._idx] = torch.arange(self._idx.size(0), **ikw)
s = self._idx.shape[:1] + self._recv_val_dat[0].shape[1:]
self._val = torch.zeros(s, **fkw)
for idx, dat in zip(self._recv_val_idx, self._recv_val_dat):
self._values_buffer[idx] += dat
idx = imp[idx]
self._val[idx] += dat
else:
s = (self._buf,) + self._recv_val_dat[0].shape[1:]
self._val = torch.zeros(s, **fkw)
for idx, dat in zip(self._recv_val_idx, self._recv_val_dat):
self._values_buffer[idx] += dat
self._active_buffer[idx] = True
self._waited = True
self._val[idx] += dat
self._buf = None
self._recv_val_dat = None
self._recv_val_idx = None
self._works_list = None
def get_values(self) -> Tensor:
if not self._waited:
self._wait()
return self._values_buffer
if self._events is not None:
self._events[1].record()
def get_active(self) -> Optional[Tensor]:
if not self._waited:
self._wait()
return self._active_buffer
def get(self) -> Tuple[Tensor, Optional[Tensor]]:
if self._val is None:
self.wait()
return self._val, self._idx
class RouteExecutorWork(RouteWorkBase):
def __init__(self, handle) -> None:
super().__init__()
self._handle = handle
self._events: Optional[Tuple[Event, Event]] = None
def wait(self) -> None:
if self._events is not None:
self._events[1].wait(torch.cuda.current_stream())
return self._handle.wait()
def get(self) -> Tuple[Tensor, Optional[Tensor]]:
return self.get_values(), self.get_active()
if self._events is not None:
self._events[1].wait(torch.cuda.current_stream())
return self._handle.get()
class RouteExecutor:
def __init__(self) -> None:
self._stream = torch.cuda.Stream()
self._executor = ThreadPool(processes=1)
self._group = dist.new_group()
def async_forward_a2a(self, val: Tensor, idx: Optional[Tensor], route) -> RouteWorkBase:
return self._async_a2a_impl(route.forward_a2a, val, idx)
def async_backward_a2a(self, val: Tensor, idx: Optional[Tensor], route) -> RouteWorkBase:
return self._async_a2a_impl(route.backward_a2a, val, idx)
def _async_a2a_impl(self, func, val: Tensor, idx: Optional[Tensor]) -> RouteWorkBase:
start = Event(use_cuda=val.is_cuda)
end = Event(use_cuda=val.is_cuda)
stream: Optional[torch.cuda.Stream] = self._stream
def run():
start.wait(stream)
with torch.cuda.stream(stream):
ret = func(val, idx, group=self._group).get()
end.record()
return ret
start.record()
handle = self._executor.apply_async(run)
return RouteExecutorWork(handle).set_events(start, end)
# def apply_async(self, func, val, idx, async_op) -> RouteWorkBase:
# start = Event(val.is_cuda)
# end = Event(val.is_cuda)
# if not async_op:
# start.record()
# return func(val, idx).set_events(start, end)
# stream: Optional[torch.cuda.Stream] = self._stream
# def run():
# start.wait(stream)
# with torch.cuda.stream(stream):
# ret = func(val, idx, group=self._group).get()
# end.record()
# return ret
# start.record()
# handle = self._executor.apply_async(run)
# return RouteExecutorWork(handle).set_events(start, end)
class RouteCtx:
def __init__(self, route, state, async_op) -> None:
self.route = route
self.state = state
self.async_op = async_op
class Route:
def __init__(self,
src_ids: Tensor,
dst_ids: Tensor,
bipartite: bool = False,
):
assert src_ids.dtype == torch.long
assert dst_ids.dtype == torch.long
assert src_ids.dim() == 1
assert dst_ids.dim() == 1
if not bipartite:
# requires a vertex-cut partition, i.e. the src_ids of all partitions are pairwise disjoint
assert src_ids.size(0) <= dst_ids.size(0)
assert (dst_ids[:src_ids.size(0)] == src_ids).all()
rank = dist.get_rank()
world_size = dist.get_world_size()
......@@ -135,146 +307,347 @@ class Route:
self.backward_routes.append(bw_route)
# send fw_route to the partitions that own src_ids to build the final routing table
self.forward_routes = _p2p_recv(self.forward_routes)
self.forward_routes = _fix_fw_routes(self.forward_routes)
# if the homogeneous-graph condition holds, add a self-loop route for every vertex
if src_ids.size(0) <= dst_ids.size(0):
t = dst_ids[:src_ids.size(0)] == src_ids
if t.all():
if not bipartite:
rank_ind = torch.arange(src_ids.size(0), **ikw)
rank_route = torch.vstack([rank_ind, rank_ind])
self.forward_routes[rank] = rank_route
self.backward_routes[rank] = rank_route
self.total_time_used: float = 0.0
dist.barrier()
def _gather_impl(self,
values_buffer: Tensor,
active_buffer: Optional[Tensor],
values: Tensor,
active: Optional[Tensor],
def to(self, device):
self.forward_routes = [ro.to(device) for ro in self.forward_routes]
self.backward_routes = [ro.to(device) for ro in self.backward_routes]
return self
def _a2a_impl(self,
val: Tensor,
idx: Optional[Tensor],
send_buf_size: int,
recv_buf_size: int,
send_routes: List[Tensor],
recv_routes: List[Tensor],
async_op: bool,
):
ikw = dict(dtype=torch.long, device=values.device)
fkw = dict(dtype=values.dtype, device=values.device)
# when active is given, only the values flagged by active are sent to neighbor partitions; otherwise all values are sent
if active is not None:
if values.size(0) < active.size(0):
values = _spread_values(values, active)
assert values.size(0) == active.size(0)
# compute the new routing tables
active_routes = [ro[:,active[ro[0]]] for ro in send_routes]
# compute the receive buffer sizes
scatter_sizes = [ro.size(1) for ro in active_routes]
scatter_sizes = torch.tensor(scatter_sizes, **ikw)
gather_sizes = torch.zeros_like(scatter_sizes)
dist.all_to_all_single(gather_sizes, scatter_sizes)
gather_sizes = gather_sizes.tolist()
def async_run():
if active is not None:
send_val_idx, recv_val_idx = [], []
group: Optional[Any],
) -> RouteWork:
bkw = dict(dtype=torch.bool, device=val.device)
ikw = dict(dtype=torch.long, device=val.device)
fkw = dict(dtype=val.dtype, device=val.device)
if idx is not None:
msk = torch.zeros(send_buf_size, **bkw).index_fill_(0, idx, 1)
send_routes = [ro[:,msk[ro[0]]] for ro in send_routes]
send_sizes = torch.tensor([ro.size(1) for ro in send_routes], **ikw)
recv_sizes = torch.zeros_like(send_sizes)
dist.all_to_all_single(recv_sizes, send_sizes, group=group)
recv_sizes = recv_sizes.tolist()
if idx is None:
send_val_dat, recv_val_dat = [], []
for i, s in enumerate(gather_sizes):
c = (s,) + values.shape[1:]
send_val_idx.append(active_routes[i][1])
recv_val_idx.append(torch.zeros(s, **ikw))
send_val_dat.append(values[active_routes[i][0]])
recv_val_idx = [ro[0] for ro in recv_routes]
for i, ro in enumerate(recv_routes):
s = ro.size(1)
c = (s,) + val.shape[1:]
send_val_dat.append(val[send_routes[i][0]])
recv_val_dat.append(torch.zeros(c, **fkw))
works_list = [
all_to_all(recv_val_idx, send_val_idx),
all_to_all(recv_val_dat, send_val_dat),
all_to_all(recv_val_dat, send_val_dat, group=group),
]
return GatherWork(
values_buffer=values_buffer,
active_buffer=active_buffer,
return RouteWork(
buf=recv_buf_size,
recv_val_idx=recv_val_idx,
recv_val_dat=recv_val_dat,
works_list=works_list,
)
else:
imp = torch.empty(send_buf_size, **ikw).fill_((2**62-1)*2+1)
imp[idx] = torch.arange(idx.size(0), **ikw)
send_val_idx, recv_val_idx = [], []
send_val_dat, recv_val_dat = [], []
recv_val_idx = [ro[0] for ro in recv_routes]
for i, r in enumerate(recv_routes):
s = r.size(1)
c = (s,) + values.shape[1:]
send_val_dat.append(values[send_routes[i][0]])
for i, s in enumerate(recv_sizes):
c = (s,) + val.shape[1:]
send_val_idx.append(send_routes[i][1])
recv_val_idx.append(torch.zeros(s, **ikw))
send_index = imp[send_routes[i][0]]
send_val_dat.append(val[send_index])
recv_val_dat.append(torch.zeros(c, **fkw))
works_list = [
all_to_all(recv_val_dat, send_val_dat),
all_to_all(recv_val_idx, send_val_idx, group=group),
all_to_all(recv_val_dat, send_val_dat, group=group),
]
return GatherWork(
values_buffer=values_buffer,
active_buffer=active_buffer,
recv_buf = torch.zeros(recv_buf_size, **bkw)
return RouteWork(
buf=recv_buf,
recv_val_idx=recv_val_idx,
recv_val_dat=recv_val_dat,
works_list=works_list,
)
if async_op and with_nccl():
stream = get_stream()
stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(stream):
return async_run()
else:
return async_run()
def gather_forward(self,
values: Tensor,
active: Optional[Tensor],
async_op: bool = False,
return_active: bool = True,
):
bkw = dict(dtype=torch.bool, device=values.device)
fkw = dict(dtype=values.dtype, device=values.device)
s = (self.dst_size,) + values.shape[1:]
values_buffer = torch.zeros(s, **fkw)
if return_active:
active_buffer = torch.zeros(self.dst_size, **bkw)
def forward_a2a(self,
val: Tensor,
idx: Optional[Tensor] = None,
group: Optional[Any] = None,
) -> RouteWork:
if idx is None:
assert val.size(0) == self.src_size
else:
active_buffer = None
return self._gather_impl(
values_buffer=values_buffer,
active_buffer=active_buffer,
values=values,
active=active,
assert val.size(0) == idx.size(0)
if group is None:
start = Event(use_cuda=val.is_cuda)
end = Event(use_cuda=val.is_cuda)
start.record()
work = self._a2a_impl(
val=val, idx=idx,
send_buf_size=self.src_size,
recv_buf_size=self.dst_size,
send_routes=self.forward_routes,
recv_routes=self.backward_routes,
async_op=async_op,
group=group,
)
def gather_backward(self,
values: Tensor,
active: Optional[Tensor],
if group is None:
return work.set_events(start, end)
return work
def backward_a2a(self,
val: Tensor,
idx: Optional[Tensor] = None,
group: Optional[Any] = None,
) -> RouteWork:
if idx is None:
assert val.size(0) == self.dst_size
else:
assert val.size(0) == idx.size(0)
if group is None:
start = Event(use_cuda=val.is_cuda)
end = Event(use_cuda=val.is_cuda)
start.record()
work = self._a2a_impl(
val=val, idx=idx,
send_buf_size=self.dst_size,
recv_buf_size=self.src_size,
send_routes=self.backward_routes,
recv_routes=self.forward_routes,
group=group,
)
if group is None:
return work.set_events(start, end)
return work
def apply(self,
val: Tensor,
idx: Optional[Tensor],
state: Dict[str, Any],
async_op: bool = False,
return_active: bool = True,
) -> Tuple[Tensor, Optional[Tensor]]:
with self._a2a_manager(state, async_op):
return RouteFunction.apply(val, idx)
@contextmanager
def _a2a_manager(self, state, async_op) -> RouteCtx:
global _global_route_context
stacked_ctx = _global_route_context
try:
_global_route_context = RouteCtx(self, state, async_op)
yield _global_route_context
finally:
_global_route_context = stacked_ctx
# def _gather_impl(self,
# values_buffer: Tensor,
# active_buffer: Optional[Tensor],
# values: Tensor,
# active: Optional[Tensor],
# send_routes: List[Tensor],
# recv_routes: List[Tensor],
# async_op: bool,
# ):
# ikw = dict(dtype=torch.long, device=values.device)
# fkw = dict(dtype=values.dtype, device=values.device)
# # when active is given, only the values flagged by active are sent to neighbor partitions; otherwise all values are sent
# if active is not None:
# # compute the new routing tables
# active_routes = [ro[:,active[ro[0]]] for ro in send_routes]
# # compute the receive buffer sizes
# scatter_sizes = [ro.size(1) for ro in active_routes]
# scatter_sizes = torch.tensor(scatter_sizes, **ikw)
# gather_sizes = torch.zeros_like(scatter_sizes)
# dist.all_to_all_single(gather_sizes, scatter_sizes)
# gather_sizes = gather_sizes.tolist()
# def async_run():
# if active is not None:
# send_val_idx, recv_val_idx = [], []
# send_val_dat, recv_val_dat = [], []
# for i, s in enumerate(gather_sizes):
# c = (s,) + values.shape[1:]
# send_val_idx.append(active_routes[i][1])
# recv_val_idx.append(torch.zeros(s, **ikw))
# send_val_dat.append(values[active_routes[i][0]])
# recv_val_dat.append(torch.zeros(c, **fkw))
# works_list = [
# all_to_all(recv_val_idx, send_val_idx),
# all_to_all(recv_val_dat, send_val_dat),
# ]
# return GatherWork(
# values_buffer=values_buffer,
# active_buffer=active_buffer,
# recv_val_idx=recv_val_idx,
# recv_val_dat=recv_val_dat,
# works_list=works_list,
# )
# else:
# send_val_dat, recv_val_dat = [], []
# recv_val_idx = [ro[0] for ro in recv_routes]
# for i, r in enumerate(recv_routes):
# s = r.size(1)
# c = (s,) + values.shape[1:]
# send_val_dat.append(values[send_routes[i][0]])
# recv_val_dat.append(torch.zeros(c, **fkw))
# works_list = [
# all_to_all(recv_val_dat, send_val_dat),
# ]
# return GatherWork(
# values_buffer=values_buffer,
# active_buffer=active_buffer,
# recv_val_idx=recv_val_idx,
# recv_val_dat=recv_val_dat,
# works_list=works_list,
# )
# if async_op and with_nccl():
# stream = get_stream()
# stream.wait_stream(torch.cuda.current_stream())
# with torch.cuda.stream(stream):
# return async_run()
# else:
# return async_run()
# def gather_forward(self,
# val: Tensor,
# idx: Optional[Tensor],
# async_op: bool = False,
# return_idx: bool = True,
# ):
# val, idx = _spread_val_idx(val, idx, self.src_size)
# bkw = dict(dtype=torch.bool, device=val.device)
# fkw = dict(dtype=val.dtype, device=val.device)
# val_buf = torch.zeros((self.dst_size,) + val.shape[1:], **fkw)
# idx_buf = torch.zeros(self.dst_size, **bkw) if return_idx else None
# return self._gather_impl(
# values_buffer=val_buf,
# active_buffer=idx_buf,
# values=val,
# active=idx,
# send_routes=self.forward_routes,
# recv_routes=self.backward_routes,
# async_op=async_op,
# )
# def gather_backward(self,
# val: Tensor,
# idx: Optional[Tensor],
# async_op: bool = False,
# return_idx: bool = True,
# ):
# val, idx = _spread_val_idx(val, idx, self.dst_size)
# bkw = dict(dtype=torch.bool, device=val.device)
# fkw = dict(dtype=val.dtype, device=val.device)
# val_buf = torch.zeros((self.src_size,) + val.shape[1:], **fkw)
# idx_buf = torch.zeros(self.src_size, **bkw) if return_idx else None
# return self._gather_impl(
# values_buffer=val_buf,
# active_buffer=idx_buf,
# values=val,
# active=idx,
# send_routes=self.backward_routes,
# recv_routes=self.forward_routes,
# async_op=async_op,
# )
class RouteFunction(autograd.Function):
@staticmethod
def forward(
ctx: autograd.function.FunctionCtx,
val: Tensor,
idx: Optional[Tensor],
):
bkw = dict(dtype=torch.bool, device=values.device)
fkw = dict(dtype=values.dtype, device=values.device)
route_ctx = get_global_route_context()
route: Route = route_ctx.route
state: Dict[str, Any] = route_ctx.state
async_op: bool = route_ctx.async_op
# cur_work = get_executor().apply_async(route.forward_a2a, val, idx, async_op)
cur_work = get_executor().async_forward_a2a(val, idx, route)
if async_op:
# cur_work = get_executor().async_forward_a2a(val, idx, route)
work = state.get("__route_fw_work__") or cur_work
state["__route_fw_work__"] = cur_work
else:
# work = route.forward_a2a(val, idx)
work = cur_work
val, idx = work.get()
if work.has_events():
route.total_time_used += work.time_used()
ctx.route = route
ctx.state = state
ctx.async_op = async_op
ctx.save_for_backward(idx)
s = (self.src_size,) + values.shape[1:]
values_buffer = torch.zeros(s, **fkw)
if return_active:
active_buffer = torch.zeros(self.src_size, **bkw)
return val, idx
@staticmethod
def backward(
ctx: autograd.function.FunctionCtx,
grad: Tensor,
_: None,
):
route: Route = ctx.route
state: Dict[str, Any] = ctx.state
async_op: bool = ctx.async_op
with torch.no_grad():
idx, = ctx.saved_tensors
# cur_work = get_executor().apply_async(route.backward_a2a, grad, idx, async_op)
cur_work = get_executor().async_backward_a2a(grad, idx, route)
if async_op:
# cur_work = get_executor().async_backward_a2a(grad, idx, route)
work = state.get("__route_bw_work__") or cur_work
state["__route_bw_work__"] = cur_work
else:
active_buffer = None
# work = route.backward_a2a(grad, idx)
work = cur_work
return self._gather_impl(
values_buffer=values_buffer,
active_buffer=active_buffer,
values=values,
active=active,
send_routes=self.backward_routes,
recv_routes=self.forward_routes,
async_op=async_op,
)
val, idx = work.get()
if work.has_events():
route.total_time_used += work.time_used()
return val, idx
#### private functions
......@@ -288,7 +661,7 @@ def _all_reduce_num_nodes(
dist.all_reduce(max_ids, op=dist.ReduceOp.MAX)
return max_ids.item() + 1
def _p2p_recv(tensors: List[Tensor]) -> List[Tensor]:
def _fix_fw_routes(tensors: List[Tensor]) -> List[Tensor]:
rank = dist.get_rank()
world_size = dist.get_world_size()
......@@ -305,19 +678,52 @@ def _p2p_recv(tensors: List[Tensor]) -> List[Tensor]:
all_to_all(new_tensors, tensors).wait()
return new_tensors
def _spread_values(values: Tensor, active: Tensor) -> Tensor:
new_values = torch.zeros(
size=active.shape[:1] + values.shape[1:],
dtype=values.dtype,
device=values.device,
)
new_values[active] = values
return new_values
# def _spread_val_idx(val: Tensor, idx: Optional[Tensor], size: int) -> Tuple[Tensor, Optional[Tensor]]:
# if idx is None:
# assert val.size(0) == size
# return val, None
# else:
# assert val.size(0) == idx.size(0)
# x = torch.zeros(
# size=(size,) + val.shape[1:],
# dtype=val.dtype,
# device=val.device,
# )
# x[idx] = val
# y = torch.zeros(
# size=(size,),
# dtype=torch.bool,
# device=idx.device,
# ).index_fill_(0, idx, 1)
# return x, y
_global_route_context: Optional[RouteCtx] = None
def get_global_route_context() -> RouteCtx:
global _global_route_context
assert _global_route_context is not None
return _global_route_context
# create a new stream dedicated to asynchronous Gather communication
_STREAM: Optional[torch.cuda.Stream] = None
def get_stream() -> torch.cuda.Stream:
global _STREAM
if _STREAM is None:
_STREAM = torch.cuda.Stream()
return _STREAM
\ No newline at end of file
# _STREAM: Optional[torch.cuda.Stream] = None
# def get_stream() -> torch.cuda.Stream:
# global _STREAM
# if _STREAM is None:
# _STREAM = torch.cuda.Stream()
# return _STREAM
# @contextmanager
# def stream_manager(enable: bool = True) -> Optional[torch.cuda.Stream]:
# if enable and with_nccl():
# stream = get_stream()
# stream.wait_stream(torch.cuda.current_stream())
# with torch.cuda.stream(stream):
# yield stream
# else:
# yield
_THREAD_EXEC: Optional[RouteExecutor] = None
def get_executor() -> RouteExecutor:
global _THREAD_EXEC
if _THREAD_EXEC is None:
_THREAD_EXEC = RouteExecutor()
return _THREAD_EXEC
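# Usage sketch for Route.apply (illustrative; assumes an initialized process
# group and a CUDA device, because RouteExecutor allocates a side stream).
# `state` is a per-layer dict: with async_op=True the call returns the result
# of the previous step's exchange and enqueues the current one, overlapping
# communication with computation across training iterations.
def _example_route_apply(route: Route, val: Tensor, idx: Optional[Tensor]):
    state: Dict[str, Any] = {}
    out_val, out_idx = route.apply(val, idx, state, async_op=False)
    return out_val, out_idx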
import torch
from torch import Tensor
from typing import *
class ShrinkData:
no_idx: int = (2**62-1)*2+1
def __init__(self,
src_size: int,
dst_size: int,
dst_idx: Tensor,
edge_index: Tensor,
bipartite: bool = False,
) -> None:
device = dst_idx.device
tmp = torch.empty(max(src_size, dst_size), dtype=torch.bool, device=device)
tmp.fill_(0)
tmp.index_fill_(0, dst_idx, 1)
edge_idx = torch.where(tmp[edge_index[1]])[0]
edge_index = edge_index[:, edge_idx]
if bipartite:
tmp.fill_(0)
tmp.index_fill_(0, edge_index[0], 1)
src_idx = torch.where(tmp)[0]
imp = torch.empty(max(src_size, dst_size), dtype=torch.long, device=device)
imp[dst_idx] = torch.arange(dst_idx.size(0), dtype=torch.long, device=device)
dst = imp[edge_index[1]]
imp.fill_(self.no_idx)
imp[src_idx] = torch.arange(src_idx.size(0), dtype=torch.long, device=device)
src = imp[edge_index[0]]
edge_index = torch.vstack([src, dst])
else:
tmp.index_fill_(0, edge_index[0], 1)
tmp.index_fill_(0, dst_idx, 0)
src_idx = torch.cat([dst_idx, torch.where(tmp)[0]], dim=0)
imp = torch.empty(max(src_size, dst_size), dtype=torch.long, device=device)
imp.fill_(self.no_idx)
imp[src_idx] = torch.arange(src_idx.size(0), dtype=torch.long, device=device)
edge_index = imp[edge_index.flatten()].view_as(edge_index)
self.src_idx = src_idx
self.dst_idx = dst_idx
self.edge_idx = edge_idx
self.edge_index = edge_index
self._src_imp = imp[:src_size]
def to(self, device):
self.src_idx = self.src_idx.to(device)
self.dst_idx = self.dst_idx.to(device)
self.edge_idx = self.edge_idx.to(device)
self.edge_index = self.edge_index.to(device)
self._src_imp = self._src_imp.to(device)
return self
def shrink_src_val_and_idx(self, val: Tensor, idx: Tensor) -> Tuple[Tensor, Tensor]:
idx = self._src_imp[idx]
m = (idx != self.no_idx)
return val[m], idx[m]
@property
def src_size(self) -> int:
return self.src_idx.size(0)
@property
def dst_size(self) -> int:
return self.dst_idx.size(0)
@property
def edge_size(self) -> int:
return self.edge_idx.size(0)
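# Tiny worked example for ShrinkData (illustrative): on a 4-node graph we keep
# only destinations {0, 2}; edges whose destination was dropped disappear, and
# the surviving src/dst ids are renumbered into the shrunk index space.
def _example_shrink():
    edge_index = torch.tensor([[0, 1, 2, 3],
                               [1, 2, 2, 0]])   # COO [src; dst], 4 edges
    dst_idx = torch.tensor([0, 2])              # keep destinations 0 and 2
    sd = ShrinkData(src_size=4, dst_size=4, dst_idx=dst_idx,
                    edge_index=edge_index, bipartite=False)
    # sd.edge_idx selects the 3 surviving edges; sd.edge_index holds the same
    # edges re-expressed with local (shrunk) indices.
    return sd.src_idx, sd.edge_index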
import torch
import torch.nn as nn
import torch.autograd as autograd
from torch import Tensor
from contextlib import contextmanager
from typing import *
class StraightContext:
def __init__(self, this) -> None:
self.this = this
class Straight(nn.Module):
def __init__(self,
num_nodes: int,
p: int = 2,
dim: int = -1,
beta: float = 1.0,
) -> None:
super().__init__()
self.num_nodes = num_nodes
self.norm_p = p
self.norm_dim = dim
self.norm_beta = beta
self.register_buffer("grad_norm", torch.ones(num_nodes))
self.reset_parameters()
def reset_parameters(self):
grad_norm = self.get_buffer("grad_norm")
nn.init.constant_(grad_norm, 1.0)
def forward(self, values: Tensor, active: Optional[Tensor] = None) -> Tensor:
with self._manager(self):
return StraightFunction.apply(values, active)
def multinomial(self,
num_samples: int,
replacement: bool = False,
) -> Tensor:
w = self.get_buffer("grad_norm")
if num_samples <= 0:
return torch.arange(self.num_nodes, dtype=torch.long, device=w.device)
# print(w)
w = w / w.sum()
return torch.multinomial(w, num_samples=num_samples, replacement=replacement)
def multinomial_mask(self,
num_samples: int,
replacement: bool = False,
) -> Tensor:
device = self.get_buffer("grad_norm").device
if num_samples <= 0:
return torch.ones(self.num_nodes, dtype=torch.bool, device=device)
w = self.multinomial(num_samples, replacement)
m = torch.zeros(self.num_nodes, dtype=torch.bool, device=device)
m[w] = True
return m
@contextmanager
def _manager(self, this):
global _global_straight_context
stacked = _global_straight_context
try:
_global_straight_context = StraightContext(this=this)
yield _global_straight_context
finally:
_global_straight_context = stacked
class StraightFunction(autograd.Function):
@staticmethod
def forward(
ctx: autograd.function.FunctionCtx,
values: Tensor,
active: Optional[Tensor],
):
stx = _last_global_straight_context()
this: Straight = stx.this
if this.training:
ctx.this = this
ctx.save_for_backward(active)
return values
@staticmethod
def backward(
ctx: autograd.function.FunctionCtx,
grad: Tensor,
):
this: Straight = ctx.this
active, = ctx.saved_tensors
grad_norm = this.get_buffer("grad_norm")
if this.norm_beta != 1.0:
grad_norm.mul_(this.norm_beta)
if active is None:
x = grad.detach()
x = x.norm(p=this.norm_p, dim=this.norm_dim)
grad_norm[:] = x
else:
if grad.size(0) == grad_norm.size(0):
x = grad.detach()[active]
else:
x = grad.detach()
x = x.norm(p=this.norm_p, dim=this.norm_dim)
grad_norm[active] = x
return grad, None
#### private functions
_global_straight_context: Optional[StraightContext] = None
def _last_global_straight_context() -> StraightContext:
global _global_straight_context
assert _global_straight_context is not None
return _global_straight_context
if __name__ == "__main__":
s = Straight(3, beta=1.1)
x = torch.rand(3, 10).requires_grad_()
m = torch.tensor([0, 1, 0], dtype=torch.bool)
s(x).sum().backward()
print(s.grad_norm)
print(s.multinomial(2))
s(x, m).sum().backward()
print(s.grad_norm)
print(s.multinomial(2))
s(x[m], m).sum().backward()
print(s.grad_norm)
print(s.multinomial(2))
print(s.multinomial_mask(2))
\ No newline at end of file
# import torch
# import torch.nn as nn
# import torch.autograd as autograd
# from torch import Tensor
# from contextlib import contextmanager
# from typing import *
# class StraightContext:
# def __init__(self, this, g) -> None:
# self.this = this
# self.g = g
# class Straight(nn.Module):
# def __init__(self,
# num_nodes: int,
# num_samples: int,
# norm_kwargs: Optional[Dict[str, Any]] = None,
# beta: float = 1.0,
# prev: Optional[List[Any]] = None,
# ) -> None:
# super().__init__()
# assert num_samples <= num_nodes
# self.num_nodes = num_nodes
# self.num_samples = num_samples
# self.norm_kwargs = norm_kwargs or dict(p=2, dim=-1)
# self.beta = beta
# self.prev = prev
# self.register_buffer("last_w", torch.ones(num_nodes))
# self._next_idx = None
# self._next_shrink_helper = None
# self.reset_parameters()
# def reset_parameters(self):
# last_w = self.get_buffer("last_w")
# nn.init.constant_(last_w, 1.0)
# def forward(self,
# val: Tensor,
# idx: Optional[Tensor],
# g,
# ) -> Tensor:
# with self._manager(self, g):
# return StraightFunction.apply(val, idx)
# def pop_next_shrink_helper(self) -> Tuple[Optional[Tensor], Any]:
# if not self.training:
# return None, None
# next_idx = self._next_idx
# self._next_idx = None
# next_sh = self._next_shrink_helper
# self._next_shrink_helper = None
# return next_idx, next_sh
# def _sample_next(self) -> Tensor:
# w = self.get_buffer("last_w")
# if self._next_idx is None:
# if self.num_samples < w.size(0):
# self._next_idx = self.sample_impl(w)
# else:
# self._next_idx = torch.arange(w.size(0), dtype=torch.long, device=w.device)
# elif self.num_samples < self._next_idx.size(0):
# idx = self.sample_impl(w[self._next_idx])
# self._next_idx = self._next_idx[idx]
# return self._next_idx
# def sample_impl(self, w: Tensor) -> Tensor:
# w = w / w.sum()
# return torch.multinomial(w, num_samples=self.num_samples, replacement=False)
# # def multinomial(self,
# # num_samples: int,
# # replacement: bool = False,
# # ) -> Tensor:
# # w = self.get_buffer("last_w")
# # if num_samples <= 0:
# # return torch.arange(self.num_nodes, dtype=torch.long, device=w.device)
# # w = w / w.sum()
# # return torch.multinomial(w, num_samples=num_samples, replacement=replacement)
# # def multinomial_mask(self,
# # num_samples: int,
# # replacement: bool = False,
# # ) -> Tensor:
# # w = self.get_buffer("last_w")
# # if num_samples <= 0:
# # return torch.ones(self.num_nodes, dtype=torch.bool, device=w.device)
# # w = self.multinomial(num_samples, replacement)
# # m = torch.zeros(self.num_nodes, dtype=torch.bool, device=w.device)
# # m[w] = True
# # return m
# @contextmanager
# def _manager(self, this, g):
# global _global_straight_context
# stacked = _global_straight_context
# try:
# _global_straight_context = StraightContext(this=this, g=g)
# yield _global_straight_context
# finally:
# _global_straight_context = stacked
# class StraightFunction(autograd.Function):
# @staticmethod
# def forward(
# ctx: autograd.function.FunctionCtx,
# val: Tensor,
# idx: Optional[Tensor],
# ):
# from ..graph import DistGraph
# stx = _last_global_straight_context()
# this: Straight = stx.this
# g: DistGraph = stx.g
# last_w = this.get_buffer("last_w")
# if idx is None:
# assert val.size(0) == last_w.size(0)
# else:
# assert val.size(0) == idx.size(0)
# if this.training:
# ctx.this = this
# ctx.g = g
# ctx.save_for_backward(idx)
# return val
# @staticmethod
# def backward(
# ctx: autograd.function.FunctionCtx,
# grad: Tensor,
# ):
# from ..graph import DistGraph
# this: Straight = ctx.this
# g: DistGraph = ctx.g
# idx, = ctx.saved_tensors
# last_w = this.get_buffer("last_w")
# if this.beta != 1.0:
# last_w.mul_(this.beta)
# norm = grad.norm(**this.norm_kwargs)
# if idx is None:
# last_w[:] = norm
# else:
# last_w[idx] = norm
# if this.prev is not None or this._next_idx is None:
# from ..nn.convs.utils import ShrinkHelper
# dst_idx = this._sample_next()
# if this.prev is not None and this._next_idx is not None:
# this._next_shrink_helper = ShrinkHelper(g, dst_idx)
# src_idx = this._next_shrink_helper.src_idx
# work = g.route.gather_backward(
# src_idx, src_idx, async_op=False, return_idx=True)
# prev_dst_idx = work.get_idx()
# for p in this.prev:
# assert isinstance(p, Straight)
# p._next_idx = prev_dst_idx
# return grad, None
# #### private functions
# _global_straight_context: Optional[StraightContext] = None
# def _last_global_straight_context() -> StraightContext:
# global _global_straight_context
# assert _global_straight_context is not None
# return _global_straight_context
# if __name__ == "__main__":
# s = Straight(3, beta=1.1)
# x = torch.rand(3, 10).requires_grad_()
# m = torch.tensor([0, 1, 0], dtype=torch.bool)
# s(x).sum().backward()
# print(s.grad_norm)
# print(s.multinomial(2))
# s(x, m).sum().backward()
# print(s.grad_norm)
# print(s.multinomial(2))
# s(x[m], m).sum().backward()
# print(s.grad_norm)
# print(s.multinomial(2))
# print(s.multinomial_mask(2))
\ No newline at end of file
......@@ -5,26 +5,23 @@ from torch import Tensor
from typing import *
from contextlib import contextmanager
import torch_sparse
# import torch_sparse
from .ndata import NData
from .edata import EData
from .utils import init_local_edge_index
from ..core import Route
from ..core import MessageCache, Route
class DistGraph:
def __init__(self,
ids: Tensor,
edge_index: Tensor,
args: Dict[str, Any] = {},
device: Union[torch.device, str, int, None] = None,
num_features: int,
num_layers: int,
cache_device: str = "cpu",
**args: Dict[str, Any],
):
# graph's target device
if device is None:
device = ids.device
self.device = torch.device(device)
# build local_edge_index
dst_ids = ids
src_ids, local_edge_index = init_local_edge_index(
......@@ -34,8 +31,15 @@ class DistGraph:
self._src_ids = src_ids
self._dst_ids = dst_ids
self._local_edge_index = local_edge_index
self._local_edge_ptr: Optional[Tensor] = None
self._message_cache = MessageCache(
src_ids=src_ids,
dst_ids=dst_ids,
edge_index=local_edge_index,
num_features=num_features,
num_layers=num_layers,
cache_device=cache_device,
bipartite=False,
)
# node's attributes
self.ndata = NData(
......@@ -49,24 +53,30 @@ class DistGraph:
)
# graph's attributes
self.args = args
self.args = dict(args)
# route table
self.route = Route(dst_ids, src_ids)
def to(self, device):
self._message_cache.to(device)
return self
def cache_data_to(self, device):
self._message_cache.cached_data_to(device)
@property
def edge_index(self) -> Tensor:
if self._local_edge_index.device != self.device:
self._local_edge_index = self._local_edge_index.to(self.device)
return self._local_edge_index
def cache(self) -> MessageCache:
return self._message_cache
@property
def edge_ptr(self) -> Optional[Tensor]:
if self._local_edge_ptr is None:
return None
if self._local_edge_ptr.device != self.device:
self._local_edge_ptr = self._local_edge_ptr.to(self.device)
return self._local_edge_ptr
def route(self) -> Route:
return self._message_cache.route
@property
def device(self) -> torch.device:
return self.edge_index.device
@property
def edge_index(self) -> Tensor:
return self._message_cache.edge_index
@property
def src_ids(self) -> Tensor:
......@@ -88,16 +98,6 @@ class DistGraph:
def dst_size(self) -> int:
return self._dst_ids.size(0)
@torch.no_grad()
def gather(self, values: Tensor, direct: str = "dst_to_src") -> Tensor:
if direct == "dst_to_src":
work = self.route.gather_forward(values, None, async_op=False, return_active=False)
elif direct == "src_to_dst":
work = self.route.gather_backward(values, None, async_op=False, return_active=False)
else:
raise ValueError(f"direct must be 'src_to_dst' or 'dst_to_src', but got '{direct}'")
return work.get_values()
@contextmanager
def scoped_manager(self):
stacked_ndata = self.ndata
......
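# A minimal construction sketch (illustrative only; names such as dst_ids,
# part_edge_index and in_channels are placeholders, not values from this repo):
#
#   g = DistGraph(
#       ids=dst_ids,                  # this partition's destination vertex ids
#       edge_index=part_edge_index,   # edges incident to this partition
#       num_features=in_channels,
#       num_layers=num_layers,
#       cache_device="cpu",           # where MessageCache keeps cached features
#       sample_k=0,                   # extra keyword options end up in g.args
#   )
#   g = g.to(device)                  # moves the message cache to the target device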
......@@ -6,6 +6,7 @@ from typing import *
def init_local_edge_index(
dst_ids: Tensor,
edge_index: Tensor,
bipartite: bool = False,
) -> Tuple[Tensor, Tensor]:
max_ids = calc_max_ids(dst_ids, edge_index)
ikw = dict(dtype=torch.long, device=dst_ids.device)
......@@ -17,6 +18,10 @@ def init_local_edge_index(
if not (xmp != 0x01).all():
raise RuntimeError(f"must be vertex-cut partition graph")
if bipartite:
src_ids = edge_index[0].unique()
else:
# assume a homogeneous graph
# src_ids equals [dst_ids, edge_index[0] except dst_ids]
xmp.fill_(0)
xmp[edge_index[0]] = 1
......@@ -26,8 +31,13 @@ def init_local_edge_index(
# 计算局部索引
xmp.fill_((2**62-1)*2+1)
xmp[src_ids] = torch.arange(src_ids.size(0), **ikw)
local_edge_index = xmp[edge_index.flatten()].view_as(edge_index)
src = xmp[edge_index[0]]
xmp.fill_((2**62-1)*2+1)
xmp[dst_ids] = torch.arange(dst_ids.size(0), **ikw)
dst = xmp[edge_index[1]]
local_edge_index = torch.vstack([src, dst])
return src_ids, local_edge_index
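# Worked example of the remapping above (a sketch, following the comment that
# src_ids = [dst_ids, remaining source ids]): with dst_ids = [3, 7] and global
# edges (9 -> 3) and (3 -> 7), src_ids becomes [3, 7, 9] and
# local_edge_index = [[2, 0], [0, 1]], i.e. row 0 indexes into src_ids and
# row 1 indexes into dst_ids.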
def calc_max_ids(*ids: Tensor) -> int:
......
from .convs import GCNConv, GATConv, GINConv
from .convs import ShrinkGCNConv, ShrinkGATConv, ShrinkGINConv
from .basic_gnn import BasicGNN, BasicLayerOptions, BasicInputOptions, BasicJKOptions, BasicStraightOptions
# from .convs import ShrinkGCNConv, ShrinkGATConv, ShrinkGINConv
# from .convs import ShrinkHelper
# from .basic_gnn import BasicGNN, BasicLayerOptions, BasicInputOptions, BasicStraightOptions
class GCN(BasicGNN):
def init_conv(self, in_channels: int, out_channels: int, **kwargs):
return GCNConv(in_channels, out_channels, **kwargs)
# class ShrinkGCN(BasicGNN):
# def init_conv(self, in_channels: int, out_channels: int, **kwargs):
# return ShrinkGCNConv(in_channels, out_channels, **kwargs)
class GAT(BasicGNN):
def init_conv(self, in_channels: int, out_channels: int, **kwargs):
return GATConv(in_channels, out_channels, **kwargs)
# class ShrinkGAT(BasicGNN):
# def init_conv(self, in_channels: int, out_channels: int, **kwargs):
# return ShrinkGATConv(in_channels, out_channels, **kwargs)
class GIN(BasicGNN):
def init_conv(self, in_channels: int, out_channels: int, **kwargs):
return GINConv(in_channels, out_channels, **kwargs)
# class ShrinkGIN(BasicGNN):
# def init_conv(self, in_channels: int, out_channels: int, **kwargs):
# return ShrinkGINConv(in_channels, out_channels, **kwargs)
\ No newline at end of file
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from typing import *
import copy
import inspect
from torch_geometric.nn import JumpingKnowledge
from torch_geometric.nn.resolver import activation_resolver, normalization_resolver
from dataclasses import dataclass
from ..graph.distgraph import DistGraph
from ..core.gather import Gather
from ..core.straight import Straight
@dataclass
class BasicInputOptions:
weight: bool = True
bias: bool = True
gather: bool = True
gather_first: bool = False
dropout: float = 0.0
act: Optional[str] = None
act_kwargs: Optional[Dict[str, Any]] = None
act_first: bool = False
norm: Optional[str] = None
norm_kwargs: Optional[Dict[str, Any]] = None
straight_enabled: bool = True
@dataclass
class BasicLayerOptions:
in_channels: int
hidden_channels: int
num_layers: int
out_channels: Optional[int] = None
gather_beta: float = 1.0
dropout: float = 0.0
act: Optional[str] = "relu"
act_kwargs: Optional[Dict[str, Any]] = None
act_first: bool = False
norm: Optional[str] = None
norm_kwargs: Optional[Dict[str, Any]] = None
@dataclass
class BasicJKOptions:
jk_mode: Optional[str] = None
@dataclass
class BasicStraightOptions:
enabled: bool = False
p: int = 2
beta: float = 1.0
class BasicGNN(nn.Module):
def __init__(
self,
g: DistGraph,
layer_options: BasicLayerOptions,
input_options: BasicInputOptions = BasicInputOptions(),
jk_options: BasicJKOptions = BasicJKOptions(),
straight_options: BasicStraightOptions = BasicStraightOptions(),
**kwargs,
):
super().__init__()
self.in_channels = in_channels = layer_options.in_channels
self.hidden_channels = hidden_channels = layer_options.hidden_channels
self.num_layers = num_layers = layer_options.num_layers
# default out_channels is hidden_channels
self.out_channels = out_channels = hidden_channels \
if layer_options.out_channels is None else layer_options.out_channels
self.gather_beta = gather_beta = layer_options.gather_beta
self.layer_options = layer_options
self.input_options = input_options
self.jk_options = jk_options
self.straight_options = straight_options
# initialize input layer
if input_options.weight:
self.lin_x = nn.Linear(in_channels, hidden_channels, bias=input_options.bias)
in_channels = hidden_channels
if straight_options.enabled and input_options.straight_enabled:
if input_options.gather_first:
self.straight_x = Straight(
g.src_size, p=straight_options.p, beta=straight_options.beta)
else:
self.straight_x = Straight(
g.dst_size, p=straight_options.p, beta=straight_options.beta)
if input_options.gather:
self.gather_x = Gather(g.src_size, in_channels, beta=gather_beta)
if input_options.act is not None:
self.act_x = activation_resolver(
input_options.act, **(input_options.act_kwargs or {}))
if input_options.norm is not None:
self.norm_x = normalization_resolver(
input_options.norm, in_channels, **(input_options.norm_kwargs or {}))
# initialize activation layers
if layer_options.act is not None:
self.acts = nn.ModuleList()
for _ in range(num_layers - 1):
self.acts.append(activation_resolver(
layer_options.act, **(layer_options.act_kwargs or {})))
if jk_options.jk_mode is not None:
self.acts.append(activation_resolver(
layer_options.act, **(layer_options.act_kwargs or {})))
# initialize normalization layers
if layer_options.norm is not None:
self.norms = nn.ModuleList()
for _ in range(num_layers - 1):
self.norms.append(normalization_resolver(
layer_options.norm, hidden_channels, **(layer_options.norm_kwargs or {})))
if jk_options.jk_mode is not None:
self.norms.append(normalization_resolver(
layer_options.norm, hidden_channels, **(layer_options.norm_kwargs or {})))
# initialize straight layers
if straight_options.enabled:
self.straights = nn.ModuleList()
for _ in range(num_layers - 1):
self.straights.append(Straight(
g.dst_size, p=straight_options.p, beta=straight_options.beta))
if jk_options.jk_mode is not None:
self.straights.append(Straight(
g.dst_size, p=straight_options.p, beta=straight_options.beta))
# initialize gather and conv layers
self.convs = nn.ModuleList()
self.gathers = nn.ModuleList()
for _ in range(num_layers - 1):
self.convs.append(
self.init_conv(in_channels, hidden_channels, **kwargs))
self.gathers.append(Gather(g.src_size, hidden_channels, beta=gather_beta))
in_channels = hidden_channels
if jk_options.jk_mode is None:
self.convs.append(
self.init_conv(in_channels, out_channels, **kwargs))
self.gathers.append(Gather(g.dst_size, out_channels)) # only fuse embeddings
else:
self.convs.append(
self.init_conv(in_channels, hidden_channels, **kwargs))
self.gathers.append(Gather(g.dst_size, hidden_channels, beta=gather_beta)) # only fuse embeddings
if jk_options.jk_mode != "last":
self.jk = JumpingKnowledge(jk_options.jk_mode, hidden_channels, num_layers)
if jk_options.jk_mode == "cat":
jk_channels = num_layers * hidden_channels
else:
jk_channels = hidden_channels
self.lin_jk = nn.Linear(jk_channels, out_channels)
self.reset_parameters()
def init_conv(self,
in_channels: int,
out_channels: int,
**kwargs
) -> nn.Module:
raise NotImplementedError
def reset_parameters(self):
if hasattr(self, "lin_x"):
self.lin_x.reset_parameters()
if hasattr(self, "straight_x"):
self.straight_x.reset_parameters()
if hasattr(self, "gather_x"):
self.gather_x.reset_parameters()
if hasattr(self, "norm_x"):
self.norm_x.reset_parameters()
if hasattr(self, "norms"):
for norm in self.norms:
norm.reset_parameters()
if hasattr(self, "straights"):
for straight in self.straights:
straight.reset_parameters()
for conv in self.convs:
conv.reset_parameters()
for gather in self.gathers:
gather.reset_parameters()
if hasattr(self, "jk"):
self.jk.reset_parameters()
if hasattr(self, "lin_jk"):
self.lin_jk.reset_parameters()
def forward(
self,
g: DistGraph,
):
# from ..utils.printer import main_print
x = g.ndata["x"]
# main_print(f"features_x: {x.size()}")
sample_k = g.args.get("sample_k", 0)
async_op = g.args.get("async_op", False)
# input layer
if hasattr(self, "gather_x") and self.input_options.gather_first:
x, m = self.gather_x(x, None, g.route, async_op=async_op)
if self.input_options.dropout != 0.0:
x = F.dropout(x, p=self.input_options.dropout, training=self.training)
# main_print(f"x_x: {x.size()} {x.requires_grad}")
if hasattr(self, "lin_x"):
# main_print("enter lin_x")
x = self.lin_x(x)
# main_print(f"lin_x: {x.size()} {x.requires_grad}")
if hasattr(self, "act_x") and self.input_options.act_first:
x = self.act_x(x)
if hasattr(self, "norm_x"):
x = self.norm_x(x)
if hasattr(self, "act_x") and not self.input_options.act_first:
x = self.act_x(x)
# main_print(f"act_x: {x.size()} {x.requires_grad}")
# straight sampler
if hasattr(self, "straight_x"):
x = self.straight_x(x)
m = self.straight_x.multinomial_mask(sample_k)
else:
m = None
# main_print(f"straight_x: {x.size()} {'' if m is None else m.size()} {x.requires_grad}")
# gather features
if hasattr(self, "gather_x") and not self.input_options.gather_first:
x, m = self.gather_x(x, m, g.route, async_op=async_op)
# main_print(f"rg {0}: {x.requires_grad}")
# conv layers
xs: List[Tensor] = []
for i in range(self.num_layers):
if self.layer_options.dropout != 0.0:
x = F.dropout(x, p=self.layer_options.dropout, training=self.training)
with g.scoped_manager():
g.ndata["x"] = x
if m is not None:
g.ndata["m"] = m
x, m = self.convs[i](g)
# main_print(f"rg {i+1}: {x.requires_grad}")
if i == self.num_layers - 1 and not hasattr(self, "jk"):
x = self.gathers[i].fuse_embeddings(x, m, inplace=True)
break
if hasattr(self, "acts") and self.layer_options.act_first:
x = self.acts[i](x)
if hasattr(self, "norms"):
x = self.norms[i](x)
if hasattr(self, "acts") and not self.layer_options.act_first:
x = self.acts[i](x)
if hasattr(self, "straights"):
x = self.straights[i](x, m)
if i == self.num_layers - 1:
x = self.gathers[i].fuse_embeddings(x, m, inplace=True)
else:
x, m = self.gathers[i](x, m, g.route, async_op=async_op)
if hasattr(self, "jk"):
xs.append(x[:g.dst_size])
x = self.jk(xs) if hasattr(self, "jk") else x
x = self.lin_jk(x) if hasattr(self, "lin_jk") else x
# main_print(f"out: {x.size()}")
return x
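# A rough usage sketch (illustrative only; `g` is a DistGraph built elsewhere
# and every option value below is a placeholder, not a recommended default):
#
#   model = GCN(
#       g,
#       layer_options=BasicLayerOptions(
#           in_channels=128,
#           hidden_channels=64,
#           num_layers=3,
#           out_channels=16,
#       ),
#       input_options=BasicInputOptions(gather=True),
#       straight_options=BasicStraightOptions(enabled=False),
#   )
#   out = model(g)   # per-destination-vertex outputs for this partition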
# import torch
# import torch.nn as nn
# import torch.nn.functional as F
# from torch import Tensor
# from typing import *
# import copy
# import inspect
# from torch_geometric.nn import JumpingKnowledge
# from torch_geometric.nn.resolver import activation_resolver, normalization_resolver
# from dataclasses import dataclass
# from ..core.cache import NodeProbe
# from ..graph.distgraph import DistGraph
# # from ..core.gather import Gather
# # from ..core.straight import Straight
# @dataclass
# class BasicInputOptions:
# weight: bool = True
# bias: bool = True
# gather: bool = True
# gather_first: bool = False
# dropout: float = 0.0
# act: Optional[str] = None
# act_kwargs: Optional[Dict[str, Any]] = None
# act_first: bool = False
# norm: Optional[str] = None
# norm_kwargs: Optional[Dict[str, Any]] = None
# straight_enabled: bool = True
# straight_num_samples: Optional[int] = None
# @dataclass
# class BasicLayerOptions:
# in_channels: int
# hidden_channels: int
# num_layers: int
# out_channels: Optional[int] = None
# gather_beta: float = 1.0
# dropout: float = 0.0
# act: Optional[str] = "relu"
# act_kwargs: Optional[Dict[str, Any]] = None
# act_first: bool = False
# norm: Optional[str] = None
# norm_kwargs: Optional[Dict[str, Any]] = None
# jk_mode: Optional[str] = None
# @dataclass
# class BasicStraightOptions:
# enabled: bool = False
# num_samples: Optional[int] = None
# beta: float = 1.0
# class BasicGNN(nn.Module):
# def __init__(
# self,
# g: DistGraph,
# layer_options: BasicLayerOptions,
# input_options: BasicInputOptions = BasicInputOptions(),
# straight_options: BasicStraightOptions = BasicStraightOptions(),
# **kwargs,
# ):
# super().__init__()
# num_samples = straight_options.num_samples or g.dst_size
# prev_straight = None
# self.in_channels = in_channels = layer_options.in_channels
# self.hidden_channels = hidden_channels = layer_options.hidden_channels
# self.num_layers = num_layers = layer_options.num_layers
# # default out_channels is hidden_channels
# self.out_channels = out_channels = hidden_channels \
# if layer_options.out_channels is None else layer_options.out_channels
# self.gather_beta = gather_beta = layer_options.gather_beta
# self.layer_options = layer_options
# self.input_options = input_options
# self.straight_options = straight_options
# # initialize input layer
# if input_options.weight:
# self.lin_x = nn.Linear(in_channels, hidden_channels, bias=input_options.bias)
# in_channels = hidden_channels
# if straight_options.enabled and input_options.straight_enabled and not input_options.gather_first:
# sns = input_options.straight_num_samples or g.dst_size
# self.straight_x = Straight(g.dst_size, sns, beta=straight_options.beta, prev=prev_straight)
# prev_straight = [self.straight_x]
# if input_options.gather:
# self.gather_x = Gather(g.src_size, in_channels, beta=gather_beta)
# if input_options.act is not None:
# self.act_x = activation_resolver(
# input_options.act, **(input_options.act_kwargs or {}))
# if input_options.norm is not None:
# self.norm_x = normalization_resolver(
# input_options.norm, in_channels, **(input_options.norm_kwargs or {}))
# # initialize activation layers
# if layer_options.act is not None:
# self.acts = nn.ModuleList()
# for _ in range(num_layers - 1):
# self.acts.append(activation_resolver(
# layer_options.act, **(layer_options.act_kwargs or {})))
# if layer_options.jk_mode is not None:
# self.acts.append(activation_resolver(
# layer_options.act, **(layer_options.act_kwargs or {})))
# # initialize normalization layers
# if layer_options.norm is not None:
# self.norms = nn.ModuleList()
# for _ in range(num_layers - 1):
# self.norms.append(normalization_resolver(
# layer_options.norm, hidden_channels, **(layer_options.norm_kwargs or {})))
# if layer_options.jk_mode is not None:
# self.norms.append(normalization_resolver(
# layer_options.norm, hidden_channels, **(layer_options.norm_kwargs or {})))
# # initialize straight layers
# if straight_options.enabled:
# self.straights = nn.ModuleList()
# for _ in range(num_layers):
# prev_straight = [Straight(g.dst_size, num_samples, beta=straight_options.beta, prev=prev_straight)]
# self.straights.append(prev_straight[0])
# # if layer_options.jk_mode is not None:
# # prev_straight = [Straight(g.dst_size, num_samples, beta=straight_options.beta, prev=prev_straight)]
# # self.straights.append(prev_straight[0])
# # initialize gather and conv layers
# self.convs = nn.ModuleList()
# self.gathers = nn.ModuleList()
# for _ in range(num_layers - 1):
# self.convs.append(
# self.init_conv(in_channels, hidden_channels, **kwargs))
# self.gathers.append(Gather(g.src_size, hidden_channels, beta=gather_beta))
# in_channels = hidden_channels
# if layer_options.jk_mode is None:
# self.convs.append(
# self.init_conv(in_channels, out_channels, **kwargs))
# self.gathers.append(Gather(g.dst_size, out_channels)) # only fuse embeddings
# else:
# self.convs.append(
# self.init_conv(in_channels, hidden_channels, **kwargs))
# self.gathers.append(Gather(g.dst_size, hidden_channels, beta=gather_beta)) # only fuse embeddings
# if layer_options.jk_mode != "last":
# self.jk = JumpingKnowledge(layer_options.jk_mode, hidden_channels, num_layers)
# if layer_options.jk_mode == "cat":
# jk_channels = num_layers * hidden_channels
# else:
# jk_channels = hidden_channels
# self.lin_jk = nn.Linear(jk_channels, out_channels)
# self.reset_parameters()
# def init_conv(self,
# in_channels: int,
# out_channels: int,
# **kwargs
# ) -> nn.Module:
# raise NotImplementedError
# def reset_parameters(self):
# if hasattr(self, "lin_x"):
# self.lin_x.reset_parameters()
# if hasattr(self, "straight_x"):
# self.straight_x.reset_parameters()
# if hasattr(self, "gather_x"):
# self.gather_x.reset_parameters()
# if hasattr(self, "norm_x"):
# self.norm_x.reset_parameters()
# if hasattr(self, "norms"):
# for norm in self.norms:
# norm.reset_parameters()
# if hasattr(self, "straights"):
# for straight in self.straights:
# straight.reset_parameters()
# for conv in self.convs:
# conv.reset_parameters()
# for gather in self.gathers:
# gather.reset_parameters()
# if hasattr(self, "jk"):
# self.jk.reset_parameters()
# if hasattr(self, "lin_jk"):
# self.lin_jk.reset_parameters()
# def forward(
# self,
# g: DistGraph,
# ):
# from ..utils.printer import main_print, sync_print
# x = g.ndata["x"]
# async_op = g.args.get("async_op", False)
# # input layer
# if hasattr(self, "straight_x"):
# dst_idx, _ = self.straight_x.pop_next_shrink_helper()
# else:
# dst_idx = None
# if dst_idx is not None:
# x = x[dst_idx]
# if hasattr(self, "gather_x") and self.input_options.gather_first:
# x, _ = self.gather_x(x, dst_idx, g.route, async_op=async_op)
# if self.input_options.dropout != 0.0:
# x = F.dropout(x, p=self.input_options.dropout, training=self.training)
# if hasattr(self, "lin_x"):
# x = self.lin_x(x)
# if hasattr(self, "act_x") and self.input_options.act_first:
# x = self.act_x(x)
# if hasattr(self, "norm_x"):
# x = self.norm_x(x)
# if hasattr(self, "act_x") and not self.input_options.act_first:
# x = self.act_x(x)
# # straight sampler
# if hasattr(self, "straight_x"):
# # sync_print(f"{x.size()} - {'' if dst_idx is None else dst_idx.size()}")
# x = self.straight_x(x, dst_idx, g)
# # gather features
# if hasattr(self, "gather_x") and not self.input_options.gather_first:
# x, _ = self.gather_x(x, dst_idx, g.route, async_op=async_op)
# # conv layers
# xs: List[Tensor] = []
# for i in range(self.num_layers):
# if self.layer_options.dropout != 0.0:
# x = F.dropout(x, p=self.layer_options.dropout, training=self.training)
# if hasattr(self, "straights"):
# straight: Straight = self.straights[i]
# dst_idx, sh = straight.pop_next_shrink_helper()
# with g.scoped_manager():
# g.ndata["x"] = x
# x = self.convs[i](g, sh=sh, dst_idx=dst_idx)
# if i == self.num_layers - 1 and not hasattr(self, "jk"):
# x = self.gathers[i].fuse_embeddings(x, dst_idx, inplace=True)
# break
# if hasattr(self, "acts") and self.layer_options.act_first:
# x = self.acts[i](x)
# if hasattr(self, "norms"):
# x = self.norms[i](x)
# if hasattr(self, "acts") and not self.layer_options.act_first:
# x = self.acts[i](x)
# if hasattr(self, "straights"):
# x = self.straights[i](x, dst_idx, g)
# if i == self.num_layers - 1:
# x = self.gathers[i].fuse_embeddings(x, dst_idx, inplace=True)
# else:
# x, _ = self.gathers[i](x, dst_idx, g.route, async_op=async_op)
# if hasattr(self, "jk"):
# xs.append(x[:g.dst_size])
# x = self.jk(xs) if hasattr(self, "jk") else x
# x = self.lin_jk(x) if hasattr(self, "lin_jk") else x
# # sync_print(f"out: {x.size()}")
# return x
from .gcn_conv import GCNConv
from .gat_conv import GATConv
from .gin_conv import GINConv
from .shrink_gcn_conv import ShrinkGCNConv
from .shrink_gat_conv import ShrinkGATConv
from .shrink_gin_conv import ShrinkGINConv
# from .shrink_gcn_conv import ShrinkGCNConv
# from .shrink_gat_conv import ShrinkGATConv
# from .shrink_gin_conv import ShrinkGINConv
# from .utils import ShrinkHelper
\ No newline at end of file
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.utils import softmax
from torch_scatter import scatter_sum
from torch import Tensor
from typing import *
from starrygl.graph import DistGraph
class GATConv(nn.Module):
def __init__(self,
in_channels: int,
out_channels: int,
heads: int = 1,
concat: bool = False,
negative_slope: float = 0.2,
dropout: float = 0.0,
edge_dim: Optional[int] = None,
bias: bool = True,
**kwargs
) -> None:
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.heads = heads
self.concat = concat
self.negative_slope = negative_slope
self.dropout = dropout
self.edge_dim = edge_dim
self.weight = nn.Parameter(torch.Tensor(in_channels, heads * out_channels))
self.att_src = nn.Parameter(torch.Tensor(1, heads, out_channels))
self.att_dst = nn.Parameter(torch.Tensor(1, heads, out_channels))
if edge_dim is not None:
self.lin_edge = nn.Parameter(torch.Tensor(edge_dim, heads * out_channels))
self.att_edge = nn.Parameter(torch.Tensor(1, heads, out_channels))
if bias and concat:
self.bias = nn.Parameter(torch.Tensor(heads * out_channels))
elif bias and not concat:
self.bias = nn.Parameter(torch.Tensor(out_channels))
else:
self.bias = None
self.reset_parameters()
def reset_parameters(self):
nn.init.xavier_normal_(self.weight)
nn.init.xavier_normal_(self.att_src)
nn.init.xavier_normal_(self.att_dst)
if self.edge_dim is not None:
nn.init.xavier_normal_(self.lin_edge)
nn.init.xavier_normal_(self.att_edge)
if self.bias is not None:
nn.init.zeros_(self.bias)
def forward(self, g, x: Tensor, edge_attr: Optional[Tensor] = None):
H, C = self.heads, self.out_channels
edge_index = g.edge_index
x = (x @ self.weight).view(-1, H, C)
alpha_j = (x * self.att_src).sum(dim=-1)
alpha_j = alpha_j[edge_index[0]]
alpha_i = (x * self.att_dst).sum(dim=-1)
alpha_i = alpha_i[edge_index[1]]
if self.edge_dim is not None:
if edge_attr.dim() == 1:
edge_attr = edge_attr.view(-1, 1)
e = (edge_attr @ self.lin_edge).view(-1, H, C)
alpha_e = (e * self.att_edge).sum(dim=-1)
alpha = alpha_i + alpha_j + alpha_e
else:
alpha = alpha_i + alpha_j
alpha = F.leaky_relu(alpha, self.negative_slope)
alpha = softmax(
src=alpha,
index=edge_index[1],
num_nodes=g.dst_size,
)
alpha = F.dropout(alpha, p=self.dropout, training=self.training)
x = x[edge_index[0]] * alpha.view(-1, H, 1)
x = scatter_sum(x, edge_index[1], dim=0, dim_size=g.dst_size)
if self.concat:
x = x.view(-1, H * C)
else:
x = x.mean(dim=1)
if self.bias is not None:
x += self.bias
return x
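# In short, the forward pass above computes the usual GAT attention per edge:
#   alpha(u, v) = softmax_v( LeakyReLU( a_src . (W x_u) + a_dst . (W x_v) [+ a_edge . (W_e e_uv)] ) )
# with the softmax normalized over each destination's incoming edges
# (index=edge_index[1]), followed by an attention-weighted scatter_sum of the
# projected source features into the destination nodes.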
......@@ -10,7 +10,7 @@ from typing import *
from starrygl.graph import DistGraph
from .gat_conv import GATConv
from .utils import ShrinkData
from .utils import ShrinkHelper
......@@ -38,29 +38,36 @@ class ShrinkGATConv(GATConv):
**kwargs
)
def forward(self, g: DistGraph) -> Tuple[Tensor, Tensor]:
m = g.ndata.get("m")
if m is None:
return super().forward(g), None
def forward(self,
g: DistGraph,
sh: Optional[ShrinkHelper] = None,
dst_idx: Optional[Tensor] = None,
) -> Tensor:
if sh is None and dst_idx is None:
return super().forward(g)
if sh is None:
sh = ShrinkHelper(g, dst_idx)
H, C = self.heads, self.out_channels
x = g.ndata["x"]
s = ShrinkData(g, m)
x = x[s.src_mask]
edge_index = s.edge_index
src_x = x[sh.src_idx]
dst_x = x[sh.dst_idx]
edge_index = sh.edges
x = (x @ self.weight).view(-1, H, C)
src_x = (src_x @ self.weight).view(-1, H, C)
dst_x = (dst_x @ self.weight).view(-1, H, C)
alpha_j = (x * self.att_src).sum(dim=-1)
alpha_j = (src_x * self.att_src).sum(dim=-1)
alpha_j = alpha_j[edge_index[0]]
alpha_i = (x * self.att_dst).sum(dim=-1)
alpha_i = (dst_x * self.att_dst).sum(dim=-1)
alpha_i = alpha_i[edge_index[1]]
if self.edge_dim is not None:
edge_attr = g.edata["edge_attr"]
edge_attr = edge_attr[s.edge_mask]
edge_attr = edge_attr[sh.edge_idx]
if edge_attr.dim() == 1:
edge_attr = edge_attr.view(-1, 1)
......@@ -73,13 +80,13 @@ class ShrinkGATConv(GATConv):
alpha = softmax(
src=alpha,
index=edge_index[1],
num_nodes=s.dst_size,
num_nodes=sh.dst_size,
)
alpha = F.dropout(alpha, p=self.dropout, training=self.training)
x = x[edge_index[0]] * alpha.view(-1, H, 1)
x = scatter_sum(x, edge_index[1], dim=0, dim_size=s.dst_size)
x = scatter_sum(x, edge_index[1], dim=0, dim_size=sh.dst_size)
if self.concat:
x = x.view(-1, H * C)
......@@ -88,5 +95,5 @@ class ShrinkGATConv(GATConv):
if self.bias is not None:
x += self.bias
return x, s.dst_mask
return x
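# Note: this shrink path only relies on the helper exposing src_idx, dst_idx,
# edge_idx, edges and dst_size (the compacted node/edge index sets for the
# sampled destinations); the attention computation itself is unchanged and
# simply runs on the shrunken subgraph.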
......@@ -8,7 +8,7 @@ from typing import *
from starrygl.graph import DistGraph
from .gcn_conv import GCNConv
from .utils import ShrinkData
from .utils import ShrinkHelper
class ShrinkGCNConv(GCNConv):
def __init__(self,
......@@ -24,24 +24,28 @@ class ShrinkGCNConv(GCNConv):
**kwargs
)
def forward(self, g: DistGraph) -> Tuple[Tensor, Tensor]:
m = g.ndata.get("m")
if m is None:
return super().forward(g), None
def forward(self,
g: DistGraph,
sh: Optional[ShrinkHelper] = None,
dst_idx: Optional[Tensor] = None,
) -> Tensor:
if sh is None and dst_idx is None:
return super().forward(g)
if sh is None:
sh = ShrinkHelper(g, dst_idx)
x = g.ndata["x"]
gcn_norm = g.edata["gcn_norm"].view(-1, 1)
s = ShrinkData(g, m)
x = x[s.src_mask]
gcn_norm = gcn_norm[s.edge_mask]
edge_index = s.edge_index
x = x[sh.src_idx]
gcn_norm = gcn_norm[sh.edge_idx]
edge_index = sh.edges
x = x @ self.weight
x = x[edge_index[0]] * gcn_norm
x = scatter_sum(x, edge_index[1], dim=0, dim_size=s.dst_size)
x = scatter_sum(x, edge_index[1], dim=0, dim_size=sh.dst_size)
if self.bias is not None:
x += self.bias
return x, s.dst_mask
return x
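# Equivalently, the shrunken propagation above computes, for each sampled
# destination v:
#   h_v = sum over edges (u -> v) in sh.edges of gcn_norm(u, v) * (W x_u) + b
# where u and v are local indices into sh.src_idx and sh.dst_idx respectively.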
......@@ -8,7 +8,7 @@ from typing import *
from starrygl.graph import DistGraph
from .gin_conv import GINConv
from .utils import ShrinkData
from .utils import ShrinkHelper
......@@ -30,17 +30,25 @@ class ShrinkGINConv(GINConv):
**kwargs
)
def forward(self, g: DistGraph) -> Tuple[Tensor, Tensor]:
m = g.ndata.get("m")
if m is None:
return super().forward(g), None
def forward(self,
g: DistGraph,
sh: Optional[ShrinkHelper] = None,
dst_idx: Optional[Tensor] = None,
) -> Tensor:
if sh is None and dst_idx is None:
return super().forward(g)
if sh is None:
sh = ShrinkHelper(g, dst_idx)
x = g.ndata["x"]
s = ShrinkData(g, m)
edge_index = s.edge_index
src_x = x[sh.src_idx]
dst_x = x[sh.dst_idx]
edge_index = sh.edges
z = x[s.src_mask][edge_index[0]]
z = scatter_sum(z, edge_index[1], dim=0, dim_size=s.dst_size)
x = z + (1 + self.eps) * x[s.dst_mask]
return self.nn(x), s.dst_mask
\ No newline at end of file
z = scatter_sum(src_x[edge_index[0]], index=edge_index[1], dim=0, dim_size=sh.dst_size)
x = z + (1 + self.eps) * dst_x
return self.nn(x)
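# Equivalently, the shrunken GIN update above computes, for each sampled
# destination v:
#   h_v = MLP( (1 + eps) * x_v + sum of x_u over in-neighbors u of v )
# restricted to the edges in sh.edges.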
\ No newline at end of file
import torch
# import torch
from torch import Tensor
from typing import *
# from torch import Tensor
# from typing import *
from starrygl.graph import DistGraph
class ShrinkData:
def __init__(self, g: DistGraph, m: Tensor) -> None:
edge_index = g.edge_index
# class ShrinkHelper:
# def __init__(self, g, dst_idx: Tensor) -> None:
# from starrygl.graph import DistGraph
# g: DistGraph = g
# src_to_dst edge's mask
m = m[edge_index[0]]
# self.device = dst_idx.device
# dst's mask
m = edge_index[1][m]
dst_m: Tensor = torch.zeros(g.dst_size, dtype=torch.bool, device=m.device).index_fill_(0, m, 1)
self.dst_mask = dst_m
# dst_m = torch.zeros(g.dst_size, dtype=torch.bool, device=self.device)
# dst_m.index_fill_(0, dst_idx, 1)
# dst_to_src edge's mask
m = dst_m[edge_index[1]]
self.edge_mask = m
# edge_idx = torch.where(dst_m[g.edge_index[1]])[0]
# edge_index = g.edge_index[:, edge_idx]
# src's mask
edge_index = edge_index[:,m]
src_m: Tensor = torch.zeros(g.src_size, dtype=torch.bool, device=m.device).index_fill_(0, edge_index[0], 1)
self.src_mask = src_m
# src_idx = edge_index[0]
# src_m = torch.zeros(g.src_size, dtype=torch.bool, device=self.device)
# src_m.index_fill_(0, src_idx, 1)
# src_idx = torch.where(src_m)[0]
self._src_size = src_m.count_nonzero()
self._dst_size = dst_m.count_nonzero()
# imp = torch.empty(max(g.src_size, g.dst_size), dtype=torch.long, device=self.device)
# print(f"{dst_m.size()}_{self._dst_size}")
# imp[src_idx] = torch.arange(src_idx.size(0), dtype=torch.long, device=self.device)
# src = imp[edge_index[0]]
# edge_index
imp = torch.empty(g.src_size, dtype=torch.long, device=m.device)
imp[src_m] = torch.arange(self._src_size, dtype=torch.long, device=m.device)
src = imp[edge_index[0]]
# imp[dst_idx] = torch.arange(dst_idx.size(0), dtype=torch.long, device=self.device)
# dst = imp[edge_index[1]]
imp = torch.empty(g.dst_size, dtype=torch.long, device=m.device)
imp[dst_m] = torch.arange(self._dst_size, dtype=torch.long, device=m.device)
dst = imp[edge_index[1]]
# self.src_idx = src_idx
# self.dst_idx = dst_idx
# self.edge_idx = edge_idx
# self.edges = torch.vstack([src, dst])
self.edge_index = torch.vstack([src, dst])
@property
def src_size(self) -> int:
return self._src_size.item()
@property
def dst_size(self) -> int:
return self._dst_size.item()
# @property
# def src_size(self) -> int:
# return self.src_idx.size(0)
# @property
# def dst_size(self) -> int:
# return self.dst_idx.size(0)
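# Summary of the flow above: starting from a node mask m, ShrinkData
# (1) marks every destination that receives an edge from a masked source,
# (2) keeps exactly the edges pointing into those destinations (edge_mask),
# (3) marks the sources of the surviving edges (src_mask), and
# (4) renumbers the surviving edges into the compacted [0, src_size) and
#     [0, dst_size) index spaces, stored in self.edge_index.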
\ No newline at end of file
......@@ -7,6 +7,7 @@ import os
from .degree import compute_degree, compute_gcn_norm
from .sync_bn import SyncBatchNorm
from ..core import with_gloo, with_nccl
from ..core.route import get_executor
def convert_parallel_model(
net: nn.Module,
......@@ -20,7 +21,7 @@ def convert_parallel_model(
for name, buffer in net.named_buffers():
if name.endswith("last_embd"):
continue
if name.endswith("grad_norm"):
if name.endswith("last_w"):
continue
dist.broadcast(buffer, src=0)
return net
......@@ -37,4 +38,6 @@ def init_process_group(backend: str = "gloo") -> torch.device:
torch.cuda.set_device(device)
else:
device = torch.device("cpu")
get_executor() # initialize route executor
return device
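# Typical setup order implied by the helpers above (a sketch; the backend
# string is a placeholder and any further convert_parallel_model arguments
# are omitted):
#
#   device = init_process_group(backend="nccl")
#   net = convert_parallel_model(net)   # broadcasts parameters/buffers,
#                                       # skipping last_embd / last_w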
\ No newline at end of file
......@@ -6,7 +6,7 @@ from typing import *
from torch_scatter import scatter_sum
from ..graph.distgraph import DistGraph
from ..core.gather import Gather
# from ..core.gather import Gather
def compute_degree(g: DistGraph) -> Tuple[Tensor, Tensor]:
......
......@@ -13,4 +13,4 @@ def main_print(*args, **kwargs):
rank = dist.get_rank()
if rank == 0:
print(*args, **kwargs)
dist.barrier()
\ No newline at end of file
# dist.barrier()
\ No newline at end of file