Commit a4c4cadb by Wenjie Huang

with bugs

parent 06d10ed3
from .a2a import all_to_all, with_gloo, with_nccl
# from .cache import EmbeddingCache
# from .gather import Gather
# from .route import Route, GatherWork
from .route import Route, RouteWorkBase
from .cache import MessageCache, NodeProbe
...@@ -30,12 +30,14 @@ class Works:
def all_to_all(
output_tensor_list: List[Tensor],
input_tensor_list: List[Tensor],
group: Optional[Any] = None,
) -> Works:
assert len(output_tensor_list) == len(input_tensor_list)
if with_nccl():
work = dist.all_to_all(
output_tensor_list=output_tensor_list,
input_tensor_list=input_tensor_list,
group=group,
async_op=True,
)
return Works(work)
...@@ -48,8 +50,8 @@ def all_to_all(
send_i = (rank + i) % world_size
recv_i = (rank - i + world_size) % world_size
send_w = dist.isend(input_tensor_list[send_i], send_i, group=group)
recv_w = dist.irecv(output_tensor_list[recv_i], recv_i, group=group)
works.push(recv_w, send_w)
output_tensor_list[rank][:] = input_tensor_list[rank]
......
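# Illustrative usage sketch (not part of this commit): one way to drive the
# all_to_all wrapper above. The package path "starrygl.core.a2a" and the launch
# details are assumptions, not something this diff defines; torch.distributed is
# assumed to be initialized with a gloo or nccl backend.
import torch
import torch.distributed as dist
from starrygl.core.a2a import all_to_all  # assumed module path

def exchange_blocks(local_blocks):
    # local_blocks: one tensor per peer rank; each receive buffer must match the
    # shape of what that peer sends back.
    world_size = dist.get_world_size()
    recv_blocks = [torch.empty_like(local_blocks[i]) for i in range(world_size)]
    work = all_to_all(recv_blocks, local_blocks)  # returns immediately
    work.wait()                                   # block until every send/recv completes
    return recv_blocks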
import torch
from torch import Tensor
from typing import *
from multiprocessing.pool import ThreadPool
from .event import Event
class AsyncCopyWorkBase:
def __init__(self) -> None:
self._events: Optional[Tuple[Event, Event]] = None
def wait(self):
raise NotImplementedError
def get(self) -> Tensor:
raise NotImplementedError
def has_events(self) -> bool:
return self._events is not None
def set_events(self, start, end):
self._events = (start, end)
return self
def time_used(self) -> float:
if self._events is None:
raise RuntimeError("not found events")
start, end = self._events
return start.elapsed_time(end)
class AsyncPushWork(AsyncCopyWorkBase):
def __init__(self, data: Tensor, index: Tensor, values: Tensor) -> None:
super().__init__()
assert data.device == index.device
self.set_events(
Event(use_cuda=index.is_cuda),
Event(use_cuda=index.is_cuda),
)
self._events[0].record()
data.index_copy_(0, index, values)
self._events[1].record()
def wait(self):
pass
def get(self):
pass
class AsyncPullWork(AsyncCopyWorkBase):
def __init__(self, data: Tensor, index: Tensor) -> None:
super().__init__()
assert data.device == index.device
self.set_events(
Event(use_cuda=index.is_cuda),
Event(use_cuda=index.is_cuda),
)
self._events[0].record()
self._val = data.index_select(0, index)
self._events[1].record()
def wait(self):
pass
def get(self) -> Tensor:
return self._val
class AsyncOffloadWork(AsyncCopyWorkBase):
def __init__(self, handle) -> None:
super().__init__()
self._handle = handle
def wait(self):
if self._events is not None:
self._events[1].wait(torch.cuda.current_stream())
self._handle.wait()
def get(self) -> Tensor:
if self._events is not None:
self._events[1].wait(torch.cuda.current_stream())
return self._handle.get()
class AsyncCopyExecutor:
def __init__(self) -> None:
self._stream = torch.cuda.Stream()
self._executor = ThreadPool(processes=1)
@torch.no_grad()
def async_pull(self, data: Tensor, index: Tensor) -> AsyncCopyWorkBase:
# Ideally all of the code below would run in a single dedicated worker thread: that makes timing easier and keeps the copy fully asynchronous.
if data.device != index.device:
assert not data.is_cuda
assert index.is_cuda
start = Event(index.is_cuda)
end = Event(index.is_cuda)
stream: Optional[torch.cuda.Stream] = self._stream
def run():
start.wait(stream)
with torch.cuda.stream(stream):
idx = index.to(data.device)
dst = torch.zeros(
size=(index.size(0),) + data.shape[1:],
dtype=data.dtype,
device=index.device,
)
dst.copy_(data[idx])
end.record()
return dst
start.record()
handle = self._executor.apply_async(run)
return AsyncOffloadWork(handle).set_events(start, end)
else:
# data and index in the same device
return AsyncPullWork(data, index)
@torch.no_grad()
def async_push(self, data: Tensor, index: Tensor, values: Tensor) -> AsyncCopyWorkBase:
assert index.device == values.device
if data.device != index.device:
assert not data.is_cuda
assert index.is_cuda
start = Event(index.is_cuda)
end = Event(index.is_cuda)
stream: Optional[torch.cuda.Stream] = self._stream
def run():
start.wait(stream)
with torch.cuda.stream(stream):
idx = index.to(data.device)
val = values.to(data.device)
data[idx] = val
end.record()
start.record()
handle = self._executor.apply_async(run)
return AsyncOffloadWork(handle).set_events(start, end)
else:
return AsyncPushWork(data, index, values)
_THREAD_EXEC: Optional[AsyncCopyExecutor] = None
def get_executor() -> AsyncCopyExecutor:
global _THREAD_EXEC
if _THREAD_EXEC is None:
_THREAD_EXEC = AsyncCopyExecutor()
return _THREAD_EXEC
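# Illustrative usage sketch (not part of this commit): pulling rows of a host-resident
# feature table onto the GPU through the shared AsyncCopyExecutor. Variable names and
# shapes are invented for the example; the module path is an assumption.
import torch
from starrygl.core.acopy import get_executor  # assumed module path

def prefetch_rows(cpu_table: torch.Tensor, gpu_index: torch.Tensor) -> torch.Tensor:
    # cpu_table lives on the CPU and gpu_index on the GPU, which selects the offload
    # branch of async_pull above (the copy runs on a side stream in a worker thread).
    work = get_executor().async_pull(cpu_table, gpu_index)
    # ... overlap other computation here ...
    return work.get()  # synchronizes with the side stream and returns the gathered rows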
import torch
import torch.nn as nn
from torch.utils import hooks
from torch import Tensor
from typing import *
from .route import Route, RouteWorkBase
from .shrink import ShrinkData
from .acopy import get_executor as get_acopy_executor, AsyncCopyWorkBase
from .route import get_executor as get_route_executor
class ProbeLayer:
def __init__(self, layer_id, probe_obj) -> None:
self.layer_id: int = layer_id
self.probe_obj: NodeProbe = probe_obj
self.last_w = torch.ones(self.probe_obj.num_nodes, dtype=torch.float32)
self._val_hook_handle: Optional[hooks.RemovableHandle] = None
def to(self, device):
self.last_w = self.last_w.to(device)
return self
def remove_val_hook(self):
if self._val_hook_handle is not None:
self._val_hook_handle.remove()
self._val_hook_handle = None
def register_val_hook(self, val: Tensor, idx: Optional[Tensor]):
assert self._val_hook_handle is None, "cannot call register_val_hook() twice"
def hook(grad: Tensor):
from starrygl.utils.printer import main_print
import time
self.probe_obj.update_sample_w(self.last_w, grad, idx)
self.remove_val_hook()
self._backward_sample(False)
main_print(self.layer_id, grad.size(0), idx.size(0))
self._val_hook_handle = val.register_hook(hook)
def warmup_sample(self):
assert self._is_last_layer()
for probe_layer in self.probe_obj.layers[::-1]:
probe_layer._backward_sample()
# max_norm = max(probe_layer.last_w.max().item(), 1.0)
# probe_layer.last_w[probe_layer.last_w == 1.0] = max_norm
def _backward_sample(self, x: bool = True):
dst_idx = self._collect_prev_dst_idx()
if dst_idx is None:
return
for cache in self.probe_obj.hook_caches:
next_cache_layer = self._next_cache_layer(cache)
shrink = next_cache_layer.set_shrink_data(dst_idx)
src_idx = shrink.src_idx
work = get_acopy_executor().async_pull(next_cache_layer.data, src_idx)
next_cache_layer.sync_pull_work(work)
if not self._is_first_layer() and x:
next_cache_layer.sync_backward_sample_work(src_idx)
def _collect_prev_dst_idx(self) -> Optional[Tensor]:
if len(self.probe_obj.hook_caches) == 0:
return None
if self._is_last_layer():
return self._wrapped_sample(self.probe_obj.num_samples, None)
if len(self.probe_obj.hook_caches) == 1:
for cache in self.probe_obj.hook_caches:
prev_cache_layer = self._prev_cache_layer(cache)
dst_idx = prev_cache_layer.sync_backward_sample_work()
else:
tmp = torch.zeros(self.probe_obj.num_nodes, dtype=torch.bool, device=self.last_w.device)
for cache in self.probe_obj.hook_caches:
prev_cache_layer = self._prev_cache_layer(cache)
prev_idx = prev_cache_layer.sync_backward_sample_work()
if prev_idx is not None:
tmp.index_fill_(0, prev_idx, 1)
dst_idx = torch.where(tmp)[0]
if dst_idx.size(0) == 0:
dst_idx = None
if dst_idx is None:
return None
return self._wrapped_sample(self.probe_obj.num_samples, dst_idx)
def _prev_cache_layer(self, cache):
cache: MessageCache = cache
i = self.layer_id + 1
if i < self.probe_obj.num_layers:
return cache.layers[i]
return None
def _next_cache_layer(self, cache):
cache: MessageCache = cache
return cache.layers[self.layer_id]
def _is_last_layer(self):
return self.layer_id + 1 >= self.probe_obj.num_layers
def _is_first_layer(self):
return self.layer_id == 0
def _wrapped_sample(self, k: int, idx: Optional[Tensor]) -> Optional[Tensor]:
w = self.last_w
if idx is None:
return self.probe_obj.sample(w, k) if 0 < k and k < self.probe_obj.num_nodes else None
if 0 < k and k < idx.size(0):
t = self.probe_obj.sample(w[idx], k)
return idx[t]
else:
return idx
class NodeProbe:
def __init__(self,
num_nodes: int,
num_layers: int,
num_samples: int = 0,
p: str = "fro",
dim: int = -1,
beta: float = 1.0,
) -> None:
super().__init__()
self.num_nodes = num_nodes
self.num_layers = num_layers
self.num_samples = num_samples
self.p = p
self.dim = dim
self.beta = beta
self.layers = [ProbeLayer(i, self) for i in range(num_layers)]
self.hook_caches: Set[MessageCache] = set()
def to(self, device):
for i in range(self.num_layers):
self.layers[i].to(device)
return self
def assign_message_cache(self, cache):
cache: MessageCache = cache
assert self.num_layers == cache.num_layers
self.hook_caches.add(cache)
def update_sample_w(self, last_w: Tensor, grad: Tensor, idx: Optional[Tensor]):
val_norm = grad.norm(p=self.p, dim=self.dim)
if idx is None:
last_w[:] = val_norm
else:
if self.beta != 1.0:
last_w.mul_(self.beta)
last_w[idx] = val_norm
def sample(self, w: Tensor, k: int) -> Optional[Tensor]:
w = w / w.sum()
return torch.multinomial(w, num_samples=k, replacement=False)
def apply(self, i: int, val: Tensor, idx: Optional[Tensor]) -> Tensor:
self.layers[i].register_val_hook(val, idx)
return val
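# Illustrative sketch (not part of this commit) of how NodeProbe.apply is meant to be
# wired into a layer-wise forward pass: it registers a gradient hook on each layer's
# output, and the hook updates the per-node sampling weights during backward. The
# model/loss scaffolding below is invented for the example.
import torch

def probed_step(probe, layer_outputs, loss_fn):
    # probe: NodeProbe(num_nodes, num_layers); layer_outputs: one tensor per layer,
    # each requiring grad and covering the node indices passed as idx.
    for i, val in enumerate(layer_outputs):
        idx = torch.arange(val.size(0), device=val.device)  # nodes present in this layer
        layer_outputs[i] = probe.apply(i, val, idx)
    loss = loss_fn(layer_outputs[-1])
    loss.backward()  # the registered hooks fire here and call update_sample_w(...)
    return loss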
class CacheLayer:
def __init__(self, layer_id, cache_obj, num_features) -> None:
self.layer_id: int = layer_id
self.cache_obj: MessageCache = cache_obj
self.data: Tensor = torch.randn(
size=(self.cache_obj.src_size, num_features),
dtype=torch.float32,
device=self.cache_obj.cache_device,
)
self.state: Dict[str, Any] = {}
self.shrink: Optional[ShrinkData] = None
self._async_push_work: Optional[AsyncCopyWorkBase] = None
self._async_pull_work: Optional[AsyncCopyWorkBase] = None
self._backward_sample_work: Optional[RouteWorkBase] = None
def to(self, device):
if self.shrink is not None:
self.shrink = self.shrink.to(device)
return self
def cache_data_to(self):
self.data = self.data.to(self.cache_obj.cache_device)
return self
def set_shrink_data(self, dst_idx: Optional[Tensor]) -> Optional[ShrinkData]:
self.shrink = self.cache_obj.new_shrink_data(dst_idx)
return self.shrink
def get_shrink_data(self) -> Optional[ShrinkData]:
return self.shrink
# def set_backward_sample_work(self, src_idx: Optional[Tensor]):
# if src_idx is None:
# return
# assert self._backward_sample_work is None
# self._backward_sample_work = self.cache_obj.route.backward_a2a(src_idx, src_idx)
# def pop_backward_sample_work(self) -> Optional[RouteWorkBase]:
# work = self._backward_sample_work
# self._backward_sample_work = None
# return work
def sync_backward_sample_work(self, src_idx: Optional[Tensor] = None) -> Optional[Tensor]:
dst_idx = None
if self._backward_sample_work is not None:
_, dst_idx = self._backward_sample_work.get()
if src_idx is not None:
# self._backward_sample_work = self.cache_obj.route.backward_a2a(src_idx, src_idx)
work = get_route_executor().async_backward_a2a(src_idx, src_idx, self.cache_obj.route)
self._backward_sample_work = work
else:
self._backward_sample_work = None
return dst_idx
def sync_push_work(self, work: Optional[AsyncCopyWorkBase] = None):
if self._async_push_work is not None:
self._async_push_work.get()
self._async_push_work = work
def sync_pull_work(self, work: Optional[AsyncCopyWorkBase] = None) -> Optional[Tensor]:
out = None
if self._async_pull_work is not None:
out = self._async_pull_work.get()
self._async_pull_work = work
return out
class MessageCache:
def __init__(self,
src_ids: Tensor,
dst_ids: Tensor,
edge_index: Tensor,
num_features: Union[int, torch.Size],
num_layers: int,
cache_device: Union[str, torch.device, None],
bipartite: bool = False,
) -> None:
self.src_size = src_ids.size(0)
self.dst_size = dst_ids.size(0)
self.edge_index = edge_index
self.num_layers = num_layers
if cache_device is None:
cache_device = edge_index.device
self.cache_device = cache_device
self.bipartite = bipartite
self.route = Route(dst_ids, src_ids, bipartite=bipartite)
self.layers = [CacheLayer(i, self, num_features) for i in range(num_layers)]
def to(self, device):
self.edge_index = self.edge_index.to(device)
self.route = self.route.to(device)
for i in range(self.num_layers):
self.layers[i].to(device)
return self
def cached_data_to(self, device):
self.cache_device = torch.device(device)
for i in range(self.num_layers):
self.layers[i].cache_data_to()
return self
# @property
# def is_offload(self) -> bool:
# return self.edge_index.device != self.cache_device
def new_shrink_data(self, dst_idx: Optional[Tensor]) -> Optional[ShrinkData]:
if dst_idx is None:
return None
return ShrinkData(
src_size=self.src_size,
dst_size=self.dst_size,
dst_idx=dst_idx,
edge_index=self.edge_index,
bipartite=self.bipartite,
)
# def clear_shrink_data(self):
# for i in range(self.num_layers):
# self.layers[i].shrink = None
def replace_layer_data(self, i: int, data: Tensor):
layer = self.layers[i]
assert layer.data.size(0) == data.size(0)
layer.data = data
return self
def update_cache(self,
i: int,
val: Tensor,
idx: Optional[Tensor],
async_op: bool = False,
) -> Tuple[Tensor, Optional[Tensor]]:
layer: CacheLayer = self.layers[i]
src_val, src_idx = self.route.apply(val, idx, layer.state, async_op=async_op)
# full graph
if idx is None:
return src_val, None
shrink = layer.get_shrink_data()
if shrink is None: # this may be problematic during initialization
data = torch.empty_like(layer.data, device=val.device).copy_(layer.data)
data.index_copy_(0, src_idx, src_val)
return data, src_idx
else:
push_work = get_acopy_executor().async_push(layer.data, src_idx, src_val)
layer.sync_push_work(push_work)
data = layer.sync_pull_work()
sval, sidx = shrink.shrink_src_val_and_idx(src_val, src_idx)
data.index_copy_(0, sidx, sval)
return data, sidx
# First push the new embeddings into the cache asynchronously,
# while asynchronously downloading the cached rows,
# then mix the local and remote caches and return the result.
def fetch_pull_tensor(self, i: int) -> Optional[Tensor]:
return self.layers[i].sync_pull_work()
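# Illustrative sketch (not part of this commit): building a MessageCache and running
# the update/fetch cycle described above for one layer. The package path and the
# surrounding partition/feature setup are assumptions.
import torch
from starrygl.core import MessageCache  # assumed package path (see the __init__ above)

def build_cache(src_ids, dst_ids, edge_index, feat_dim, num_layers, device):
    cache = MessageCache(
        src_ids=src_ids, dst_ids=dst_ids, edge_index=edge_index,
        num_features=feat_dim, num_layers=num_layers,
        cache_device="cpu",  # keep the cached embeddings offloaded on the host
    )
    return cache.to(device)

def layer_step(cache, i, dst_val, dst_idx):
    # Push this partition's fresh embeddings, pull cached rows for the shrunk
    # neighborhood, and receive the mixed tensor plus the indices it covers.
    src_val, src_idx = cache.update_cache(i, dst_val, dst_idx, async_op=False)
    return src_val, src_idx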
# def _update_cache_impl(self,
# i: int,
# dst_val: Tensor,
# dst_idx: Optional[Tensor],
# route: Route,
# async_op: bool,
# ):
# # communcation
# state = self.layers[i].state
# src_val, src_idx = route.apply(dst_val, dst_idx, state, async_op=async_op)
# # push latest embeddings
# data = self.layers[i].data
# if src_idx is None:
# data[:] = src_val
# else:
# data[src_idx] = src_val
# # get previous generated shrink data
# shr: ShrinkData = self.layers[i].shrink
# if shr is None:
# return src_val, src_idx
# else:
# # pull latest embeddings
# return data[shr.src_idx], shr.src_idx
# def _update_cache_offload(self,
# i: int,
# dst_val: Tensor,
# dst_idx: Optional[Tensor],
# route: Route,
# async_op: bool,
# ):
# raise NotImplementedError
# def _compute_cached_data_size(self) -> torch.Size:
# num_features = self.num_features
# if num_features is None:
# cached_data_size = torch.Size([self.src_size,])
# else:
# cached_data_size = torch.Size([self.src_size, num_features])
# return cached_data_size
# def _compute_cpu_buf_size(self) -> torch.Size:
# num_features = self.num_features
# if num_features is None:
# cpu_buf_size = torch.Size([2**32,])
# else:
# cpu_buf_size = torch.Size([2**32 // num_features, num_features])
# return cpu_buf_size
import torch
import time
from torch import Tensor
from typing import *
class Event:
def __init__(self, use_cuda: bool = True) -> None:
self._use_cuda = use_cuda
if use_cuda:
self._event = torch.cuda.Event(enable_timing=True)
else:
self._time: Optional[float] = None
def record(self):
if self._use_cuda:
self._event.record()
else:
self._time = time.time()
def wait(self, stream: Optional[torch.cuda.Stream] = None):
if not self._use_cuda:
return
self._event.wait(stream)
def elapsed_time(self, other) -> float:
if self._use_cuda:
return self._event.elapsed_time(other._event)
else:
return (other._time - self._time) * 1000.0
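# Illustrative sketch (not part of this commit): timing a region with the Event class
# above, which records CUDA events on GPU and falls back to wall-clock time on CPU.
import torch

def timed_matmul(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    start, end = Event(use_cuda=a.is_cuda), Event(use_cuda=a.is_cuda)
    start.record()
    out = a @ b
    end.record()
    if a.is_cuda:
        torch.cuda.synchronize()  # make sure both events have completed before reading
    print("matmul took %.3f ms" % start.elapsed_time(end))
    return out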
# import torch
# import torch.nn as nn
# import torch.autograd as autograd
# from torch import Tensor
# from contextlib import contextmanager
# from typing import *
# from .route import Route, GatherWork
# # from .cache import CachedEmbeddings
# class GatherContext:
# def __init__(self,
# this,
# route: Route,
# async_op: bool,
# ) -> None:
# self.this = this
# self.route = route
# self.async_op = async_op
# class Gather(nn.Module):
# def __init__(self,
# num_nodes: int,
# num_features: Optional[int] = None,
# beta: float = 1.0,
# ) -> None:
# super().__init__()
# self.beta = beta
# if num_features is None:
# self.register_buffer("last_embd", torch.zeros(num_nodes))
# else:
# self.register_buffer("last_embd", torch.zeros(num_nodes, num_features))
# self.last_fw_work: Optional[GatherWork] = None
# self.last_bw_work: Optional[GatherWork] = None
# self.reset_parameters()
# def forward(self,
# val: Tensor,
# idx: Optional[Tensor],
# route: Route,
# async_op: bool = False,
# ) -> Tuple[Tensor, Tensor]:
# with self._manager(route=route, async_op=async_op):
# return GatherFunction.apply(val, idx)
# def reset_parameters(self):
# last_embd = self.get_buffer("last_embd")
# nn.init.normal_(last_embd, mean=0, std=1)
# def fuse_embeddings(self,
# val: Tensor,
# idx: Optional[Tensor] = None,
# inplace: bool = False,
# ) -> Tensor:
# last_embd = self.get_buffer("last_embd")
# return GatherFuseFunction.apply(val, idx, last_embd, self.beta, inplace, self.training)
# @contextmanager
# def _manager(self, route: Route, async_op: bool):
# global _global_gather_context
# stacked = _global_gather_context
# try:
# _global_gather_context = GatherContext(
# this=self,
# route=route,
# async_op=async_op,
# )
# yield _global_gather_context
# finally:
# _global_gather_context = stacked
# class GatherFunction(autograd.Function):
# @staticmethod
# def forward(
# ctx: autograd.function.FunctionCtx,
# val: Tensor,
# idx: Optional[Tensor] = None,
# ):
# gather_ctx = _last_global_gather_context()
# this: Gather = gather_ctx.this
# route: Route = gather_ctx.route
# async_op: bool = gather_ctx.async_op
# return_idx: bool = idx is not None
# current_work = route.gather_forward(val, idx, async_op=async_op, return_idx=return_idx)
# if async_op:
# work = this.last_fw_work or current_work
# this.last_fw_work = current_work
# else:
# work = current_work
# recv_val, recv_idx = work.get()
# recv_val = recv_val if recv_idx is None else recv_val[recv_idx]
# recv_val = this.fuse_embeddings(recv_val, recv_idx, inplace=True)
# if this.training:
# ctx.save_for_backward(idx, recv_idx)
# ctx.this = this
# ctx.route = route
# ctx.async_op = async_op
# return recv_val, recv_idx
# @staticmethod
# def backward(
# ctx: autograd.function.FunctionCtx,
# val_grad: Tensor,
# idx_grad: Optional[Tensor],
# ):
# this: Gather = ctx.this
# route: Route = ctx.route
# async_op: bool = ctx.async_op
# with torch.no_grad():
# recv_idx, idx_grad = ctx.saved_tensors
# if idx_grad is not None:
# val_grad = val_grad[idx_grad]
# current_work = route.gather_backward(val_grad, idx_grad, async_op=async_op, return_idx=False)
# if async_op:
# work = this.last_bw_work or current_work
# this.last_bw_work = current_work
# else:
# work = current_work
# recv_val = work.get_val()
# if recv_idx is not None:
# recv_val = recv_val[recv_idx]
# return recv_val, None
# class GatherFuseFunction(autograd.Function):
# @staticmethod
# def forward(
# ctx: autograd.function.FunctionCtx,
# val: Tensor,
# idx: Optional[Tensor],
# last_embd: Tensor,
# beta: float,
# inplace: bool,
# training: bool,
# ):
# if not inplace:
# last_embd = last_embd.clone()
# ctx.beta = beta
# if idx is None:
# assert val.size(0) == last_embd.size(0)
# if beta != 1.0:
# last_embd.mul_(1 - beta).add_(val * beta)
# else:
# last_embd[:] = (val)
# else:
# assert val.size(0) == idx.size(0)
# if beta != 1.0:
# last_embd[idx] = last_embd[idx] * (1 - beta) + val * beta
# else:
# last_embd[idx] = val
# if training:
# ctx.beta = beta
# ctx.save_for_backward(idx)
# return last_embd
# @staticmethod
# def backward(
# ctx: autograd.function.FunctionCtx,
# grad: Tensor,
# ):
# beta: float = ctx.beta
# idx, = ctx.saved_tensors
# if idx is not None:
# grad = grad[idx]
# if beta != 1.0:
# grad = grad * beta
# return grad, None, None, None, None, None
# #### private functions
# _global_gather_context: Optional[GatherContext] = None
# def _last_global_gather_context() -> GatherContext:
# global _global_gather_context
# assert _global_gather_context is not None
# return _global_gather_context
import torch
import torch.autograd as autograd
import torch.distributed as dist
from multiprocessing.pool import ThreadPool
from torch import Tensor
from typing import *
from contextlib import contextmanager
from .a2a import all_to_all, Works, with_nccl
from .event import Event
# class GatherWork:
# def __init__(self,
# values_buffer: Tensor,
# active_buffer: Optional[Tensor],
# recv_val_idx: List[Tensor],
# recv_val_dat: List[Tensor],
# works_list: List[Works],
# ) -> None:
# self._waited = False
# self._values_buffer = values_buffer
# self._active_buffer = active_buffer
# self._recv_val_idx = recv_val_idx
# self._recv_val_dat = recv_val_dat
# self._works_list = works_list
# self._cached_idx: Optional[Tensor] = None
# def _wait(self) -> None:
# assert not self._waited
# for w in self._works_list:
# if w is None:
# continue
# w.wait()
# if self._active_buffer is None:
# for idx, dat in zip(self._recv_val_idx, self._recv_val_dat):
# self._values_buffer[idx] += dat
# else:
# for idx, dat in zip(self._recv_val_idx, self._recv_val_dat):
# self._values_buffer[idx] += dat
# self._active_buffer[idx] = True
# self._active_buffer = torch.where(self._active_buffer)[0]
# self._works_list = []
# self._recv_val_dat = []
# self._recv_val_idx = []
# self._waited = True
# def get_val(self) -> Tensor:
# if not self._waited:
# self._wait()
# return self._values_buffer
# def get_idx(self) -> Optional[Tensor]:
# if not self._waited:
# self._wait()
# return self._active_buffer
# def get(self) -> Tuple[Tensor, Optional[Tensor]]:
# return self.get_val(), self.get_idx()
class RouteWorkBase:
def __init__(self) -> None:
self._events: Optional[Tuple[Event, Event]] = None
def wait(self) -> None:
raise NotImplementedError
def get(self) -> Tuple[Tensor, Optional[Tensor]]:
raise NotImplementedError
def has_events(self) -> bool:
return self._events is not None
def set_events(self, start, end):
self._events = (start, end)
return self
def time_used(self) -> float:
if self._events is None:
raise RuntimeError("not found events")
start, end = self._events
return start.elapsed_time(end)
class RouteWork(RouteWorkBase):
def __init__(self,
buf: Union[Tensor, int],
recv_val_idx: List[Tensor],
recv_val_dat: List[Tensor],
works_list: List[Works],
) -> None:
super().__init__()
assert len(recv_val_idx) != 0
assert len(recv_val_idx) == len(recv_val_dat)
self._buf = buf
self._recv_val_idx = recv_val_idx
self._recv_val_dat = recv_val_dat
self._works_list = works_list
self._val: Optional[Tensor] = None
self._idx: Optional[Tensor] = None
def wait(self) -> None:
if self._val is not None:
return
for w in self._works_list:
if w is None:
continue
w.wait()
ikw = dict(dtype=self._recv_val_idx[0].dtype, device=self._recv_val_idx[0].device)
fkw = dict(dtype=self._recv_val_dat[0].dtype, device=self._recv_val_dat[0].device)
if isinstance(self._buf, Tensor):
for idx in self._recv_val_idx:
self._buf[idx] = True
self._idx = torch.where(self._buf)[0]
imp = torch.empty(self._buf.size(0), **ikw).fill_((2**62-1)*2+1)
imp[self._idx] = torch.arange(self._idx.size(0), **ikw)
s = self._idx.shape[:1] + self._recv_val_dat[0].shape[1:]
self._val = torch.zeros(s, **fkw)
for idx, dat in zip(self._recv_val_idx, self._recv_val_dat):
idx = imp[idx]
self._val[idx] += dat
else:
s = (self._buf,) + self._recv_val_dat[0].shape[1:]
self._val = torch.zeros(s, **fkw)
for idx, dat in zip(self._recv_val_idx, self._recv_val_dat):
self._val[idx] += dat
self._buf = None
self._recv_val_dat = None
self._recv_val_idx = None
self._works_list = None
if self._events is not None:
self._events[1].record()
def get(self) -> Tuple[Tensor, Optional[Tensor]]:
if self._val is None:
self.wait()
return self._val, self._idx
class RouteExecutorWork(RouteWorkBase):
def __init__(self, handle) -> None:
super().__init__()
self._handle = handle
self._events: Optional[Tuple[Event, Event]] = None
def wait(self) -> None:
if self._events is not None:
self._events[1].wait(torch.cuda.current_stream())
return self._handle.wait()
def get(self) -> Tuple[Tensor, Optional[Tensor]]:
if self._events is not None:
self._events[1].wait(torch.cuda.current_stream())
return self._handle.get()
class RouteExecutor:
def __init__(self) -> None:
self._stream = torch.cuda.Stream()
self._executor = ThreadPool(processes=1)
self._group = dist.new_group()
def async_forward_a2a(self, val: Tensor, idx: Optional[Tensor], route) -> RouteWorkBase:
return self._async_a2a_impl(route.forward_a2a, val, idx)
def async_backward_a2a(self, val: Tensor, idx: Optional[Tensor], route) -> RouteWorkBase:
return self._async_a2a_impl(route.backward_a2a, val, idx)
def _async_a2a_impl(self, func, val: Tensor, idx: Optional[Tensor]) -> RouteWorkBase:
start = Event(use_cuda=val.is_cuda)
end = Event(use_cuda=val.is_cuda)
stream: Optional[torch.cuda.Stream] = self._stream
def run():
start.wait(stream)
with torch.cuda.stream(stream):
ret = func(val, idx, group=self._group).get()
end.record()
return ret
start.record()
handle = self._executor.apply_async(run)
return RouteExecutorWork(handle).set_events(start, end)
# def apply_async(self, func, val, idx, async_op) -> RouteWorkBase:
# start = Event(val.is_cuda)
# end = Event(val.is_cuda)
# if not async_op:
# start.record()
# return func(val, idx).set_events(start, end)
# stream: Optional[torch.cuda.Stream] = self._stream
# def run():
# start.wait(stream)
# with torch.cuda.stream(stream):
# ret = func(val, idx, group=self._group).get()
# end.record()
# return ret
# start.record()
# handle = self._executor.apply_async(run)
# return RouteExecutorWork(handle).set_events(start, end)
class RouteCtx:
def __init__(self, route, state, async_op) -> None:
self.route = route
self.state = state
self.async_op = async_op
class Route:
def __init__(self,
src_ids: Tensor,
dst_ids: Tensor,
bipartite: bool = False,
):
assert src_ids.dtype == torch.long
assert dst_ids.dtype == torch.long
assert src_ids.dim() == 1
assert dst_ids.dim() == 1
if not bipartite:
# the data must be node-partitioned, i.e. the src_ids of different partitions are disjoint
assert src_ids.size(0) <= dst_ids.size(0)
assert (dst_ids[:src_ids.size(0)] == src_ids).all()
rank = dist.get_rank()
world_size = dist.get_world_size()
...@@ -135,146 +307,347 @@ class Route:
self.backward_routes.append(bw_route)
# send each fw_route to the partition that owns its src_ids to build the final routing table
self.forward_routes = _fix_fw_routes(self.forward_routes)
# in the homogeneous-graph case, add a self-loop route for every node
if not bipartite:
rank_ind = torch.arange(src_ids.size(0), **ikw)
rank_route = torch.vstack([rank_ind, rank_ind])
self.forward_routes[rank] = rank_route
self.backward_routes[rank] = rank_route
self.total_time_used: float = 0.0
dist.barrier()
def to(self, device):
self.forward_routes = [ro.to(device) for ro in self.forward_routes]
self.backward_routes = [ro.to(device) for ro in self.backward_routes]
return self
def _a2a_impl(self,
val: Tensor,
idx: Optional[Tensor],
send_buf_size: int,
recv_buf_size: int,
send_routes: List[Tensor],
recv_routes: List[Tensor],
group: Optional[Any],
) -> RouteWork:
bkw = dict(dtype=torch.bool, device=val.device)
ikw = dict(dtype=torch.long, device=val.device)
fkw = dict(dtype=val.dtype, device=val.device)
if idx is not None:
msk = torch.zeros(send_buf_size, **bkw).index_fill_(0, idx, 1)
send_routes = [ro[:,msk[ro[0]]] for ro in send_routes]
send_sizes = torch.tensor([ro.size(1) for ro in send_routes], **ikw)
recv_sizes = torch.zeros_like(send_sizes)
dist.all_to_all_single(recv_sizes, send_sizes, group=group)
recv_sizes = recv_sizes.tolist()
if idx is None:
send_val_dat, recv_val_dat = [], []
recv_val_idx = [ro[0] for ro in recv_routes]
for i, ro in enumerate(recv_routes):
s = ro.size(1)
c = (s,) + val.shape[1:]
send_val_dat.append(val[send_routes[i][0]])
recv_val_dat.append(torch.zeros(c, **fkw))
works_list = [
all_to_all(recv_val_dat, send_val_dat, group=group),
]
return RouteWork(
buf=recv_buf_size,
recv_val_idx=recv_val_idx,
recv_val_dat=recv_val_dat,
works_list=works_list,
)
else:
imp = torch.empty(send_buf_size, **ikw).fill_((2**62-1)*2+1)
imp[idx] = torch.arange(idx.size(0), **ikw)
send_val_idx, recv_val_idx = [], []
send_val_dat, recv_val_dat = [], []
for i, s in enumerate(recv_sizes):
c = (s,) + val.shape[1:]
send_val_idx.append(send_routes[i][1])
recv_val_idx.append(torch.zeros(s, **ikw))
send_index = imp[send_routes[i][0]]
send_val_dat.append(val[send_index])
recv_val_dat.append(torch.zeros(c, **fkw))
works_list = [
all_to_all(recv_val_idx, send_val_idx, group=group),
all_to_all(recv_val_dat, send_val_dat, group=group),
]
recv_buf = torch.zeros(recv_buf_size, **bkw)
return RouteWork(
buf=recv_buf,
recv_val_idx=recv_val_idx,
recv_val_dat=recv_val_dat,
works_list=works_list,
)
def forward_a2a(self,
val: Tensor,
idx: Optional[Tensor] = None,
group: Optional[Any] = None,
) -> RouteWork:
if idx is None:
assert val.size(0) == self.src_size
else:
assert val.size(0) == idx.size(0)
if group is None:
start = Event(use_cuda=val.is_cuda)
end = Event(use_cuda=val.is_cuda)
start.record()
work = self._a2a_impl(
val=val, idx=idx,
send_buf_size=self.src_size,
recv_buf_size=self.dst_size,
send_routes=self.forward_routes,
recv_routes=self.backward_routes,
group=group,
)
if group is None:
return work.set_events(start, end)
return work
def backward_a2a(self,
val: Tensor,
idx: Optional[Tensor] = None,
group: Optional[Any] = None,
) -> RouteWork:
if idx is None:
assert val.size(0) == self.dst_size
else:
assert val.size(0) == idx.size(0)
if group is None:
start = Event(use_cuda=val.is_cuda)
end = Event(use_cuda=val.is_cuda)
start.record()
work = self._a2a_impl(
val=val, idx=idx,
send_buf_size=self.dst_size,
recv_buf_size=self.src_size,
send_routes=self.backward_routes,
recv_routes=self.forward_routes,
group=group,
)
if group is None:
return work.set_events(start, end)
return work
def apply(self,
val: Tensor,
idx: Optional[Tensor],
state: Dict[str, Any],
async_op: bool = False,
) -> Tuple[Tensor, Optional[Tensor]]:
with self._a2a_manager(state, async_op):
return RouteFunction.apply(val, idx)
@contextmanager
def _a2a_manager(self, state, async_op) -> RouteCtx:
global _global_route_context
stacked_ctx = _global_route_context
try:
_global_route_context = RouteCtx(self, state, async_op)
yield _global_route_context
finally:
_global_route_context = stacked_ctx
# def _gather_impl(self,
# values_buffer: Tensor,
# active_buffer: Optional[Tensor],
# values: Tensor,
# active: Optional[Tensor],
# send_routes: List[Tensor],
# recv_routes: List[Tensor],
# async_op: bool,
# ):
# ikw = dict(dtype=torch.long, device=values.device)
# fkw = dict(dtype=values.dtype, device=values.device)
# # When active is given, only the values flagged by active are sent to the neighboring partitions; otherwise all values are sent.
# if active is not None:
# # 计算新的路由表
# active_routes = [ro[:,active[ro[0]]] for ro in send_routes]
# # 计算接收缓冲区大小
# scatter_sizes = [ro.size(1) for ro in active_routes]
# scatter_sizes = torch.tensor(scatter_sizes, **ikw)
# gather_sizes = torch.zeros_like(scatter_sizes)
# dist.all_to_all_single(gather_sizes, scatter_sizes)
# gather_sizes = gather_sizes.tolist()
# def async_run():
# if active is not None:
# send_val_idx, recv_val_idx = [], []
# send_val_dat, recv_val_dat = [], []
# for i, s in enumerate(gather_sizes):
# c = (s,) + values.shape[1:]
# send_val_idx.append(active_routes[i][1])
# recv_val_idx.append(torch.zeros(s, **ikw))
# send_val_dat.append(values[active_routes[i][0]])
# recv_val_dat.append(torch.zeros(c, **fkw))
# works_list = [
# all_to_all(recv_val_idx, send_val_idx),
# all_to_all(recv_val_dat, send_val_dat),
# ]
# return GatherWork(
# values_buffer=values_buffer,
# active_buffer=active_buffer,
# recv_val_idx=recv_val_idx,
# recv_val_dat=recv_val_dat,
# works_list=works_list,
# )
# else:
# send_val_dat, recv_val_dat = [], []
# recv_val_idx = [ro[0] for ro in recv_routes]
# for i, r in enumerate(recv_routes):
# s = r.size(1)
# c = (s,) + values.shape[1:]
# send_val_dat.append(values[send_routes[i][0]])
# recv_val_dat.append(torch.zeros(c, **fkw))
# works_list = [
# all_to_all(recv_val_dat, send_val_dat),
# ]
# return GatherWork(
# values_buffer=values_buffer,
# active_buffer=active_buffer,
# recv_val_idx=recv_val_idx,
# recv_val_dat=recv_val_dat,
# works_list=works_list,
# )
# if async_op and with_nccl():
# stream = get_stream()
# stream.wait_stream(torch.cuda.current_stream())
# with torch.cuda.stream(stream):
# return async_run()
# else:
# return async_run()
# def gather_forward(self,
# val: Tensor,
# idx: Optional[Tensor],
# async_op: bool = False,
# return_idx: bool = True,
# ):
# val, idx = _spread_val_idx(val, idx, self.src_size)
# bkw = dict(dtype=torch.bool, device=val.device)
# fkw = dict(dtype=val.dtype, device=val.device)
# val_buf = torch.zeros((self.dst_size,) + val.shape[1:], **fkw)
# idx_buf = torch.zeros(self.dst_size, **bkw) if return_idx else None
# return self._gather_impl(
# values_buffer=val_buf,
# active_buffer=idx_buf,
# values=val,
# active=idx,
# send_routes=self.forward_routes,
# recv_routes=self.backward_routes,
# async_op=async_op,
# )
# def gather_backward(self,
# val: Tensor,
# idx: Optional[Tensor],
# async_op: bool = False,
# return_idx: bool = True,
# ):
# val, idx = _spread_val_idx(val, idx, self.dst_size)
# bkw = dict(dtype=torch.bool, device=val.device)
# fkw = dict(dtype=val.dtype, device=val.device)
# val_buf = torch.zeros((self.src_size,) + val.shape[1:], **fkw)
# idx_buf = torch.zeros(self.src_size, **bkw) if return_idx else None
# return self._gather_impl(
# values_buffer=val_buf,
# active_buffer=idx_buf,
# values=val,
# active=idx,
# send_routes=self.backward_routes,
# recv_routes=self.forward_routes,
# async_op=async_op,
# )
class RouteFunction(autograd.Function):
@staticmethod
def forward(
ctx: autograd.function.FunctionCtx,
val: Tensor,
idx: Optional[Tensor],
):
route_ctx = get_global_route_context()
route: Route = route_ctx.route
state: Dict[str, Any] = route_ctx.state
async_op: bool = route_ctx.async_op
# cur_work = get_executor().apply_async(route.forward_a2a, val, idx, async_op)
cur_work = get_executor().async_forward_a2a(val, idx, route)
if async_op:
# cur_work = get_executor().async_forward_a2a(val, idx, route)
work = state.get("__route_fw_work__") or cur_work
state["__route_fw_work__"] = cur_work
else:
# work = route.forward_a2a(val, idx)
work = cur_work
val, idx = work.get()
if work.has_events():
route.total_time_used += work.time_used()
ctx.route = route
ctx.state = state
ctx.async_op = async_op
ctx.save_for_backward(idx)
return val, idx
@staticmethod
def backward(
ctx: autograd.function.FunctionCtx,
grad: Tensor,
_: None,
):
route: Route = ctx.route
state: Dict[str, Any] = ctx.state
async_op: bool = ctx.async_op
with torch.no_grad():
idx, = ctx.saved_tensors
# cur_work = get_executor().apply_async(route.backward_a2a, grad, idx, async_op)
cur_work = get_executor().async_backward_a2a(grad, idx, route)
if async_op:
# cur_work = get_executor().async_backward_a2a(grad, idx, route)
work = state.get("__route_bw_work__") or cur_work
state["__route_bw_work__"] = cur_work
else:
# work = route.backward_a2a(grad, idx)
work = cur_work
val, idx = work.get()
if work.has_events():
route.total_time_used += work.time_used()
return val, idx
#### private functions
...@@ -288,7 +661,7 @@ def _all_reduce_num_nodes(
dist.all_reduce(max_ids, op=dist.ReduceOp.MAX)
return max_ids.item() + 1
def _fix_fw_routes(tensors: List[Tensor]) -> List[Tensor]:
rank = dist.get_rank()
world_size = dist.get_world_size()
...@@ -305,19 +678,52 @@ def _p2p_recv(tensors: List[Tensor]) -> List[Tensor]:
all_to_all(new_tensors, tensors).wait()
return new_tensors
# def _spread_val_idx(val: Tensor, idx: Optional[Tensor], size: int) -> Tuple[Tensor, Optional[Tensor]]:
# if idx is None:
# assert val.size(0) == size
# return val, None
# else:
# assert val.size(0) == idx.size(0)
# x = torch.zeros(
# size=(size,) + val.shape[1:],
# dtype=val.dtype,
# device=val.device,
# )
# x[idx] = val
# y = torch.zeros(
# size=(size,),
# dtype=torch.bool,
# device=idx.device,
# ).index_fill_(0, idx, 1)
# return x, y
_global_route_context: Optional[RouteCtx] = None
def get_global_route_context() -> RouteCtx:
global _global_route_context
assert _global_route_context is not None
return _global_route_context
# create a dedicated stream for asynchronous Gather communication
# _STREAM: Optional[torch.cuda.Stream] = None
# def get_stream() -> torch.cuda.Stream:
# global _STREAM
# if _STREAM is None:
# _STREAM = torch.cuda.Stream()
# return _STREAM
# @contextmanager
# def stream_manager(enable: bool = True) -> Optional[torch.cuda.Stream]:
# if enable and with_nccl():
# stream = get_stream()
# stream.wait_stream(torch.cuda.current_stream())
# with torch.cuda.stream(stream):
# yield stream
# else:
# yield
_THREAD_EXEC: Optional[RouteExecutor] = None
def get_executor() -> RouteExecutor:
global _THREAD_EXEC
if _THREAD_EXEC is None:
_THREAD_EXEC = RouteExecutor()
return _THREAD_EXEC
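# Illustrative sketch (not part of this commit): exchanging per-node values across
# partitions with Route.apply, which goes through RouteFunction so that gradients are
# routed back along the reverse direction. Distributed/CUDA initialization and the
# package path are assumptions.
import torch
from starrygl.core import Route  # assumed package path (see the __init__ above)

def exchange(route: Route, val: torch.Tensor, state: dict, async_op: bool = False):
    # With idx=None, val must have one row per node on the route's sending side;
    # the returned tensor is laid out over the receiving side. Reusing the same
    # `state` dict every iteration lets async_op=True overlap communication with
    # the previous step's computation (at the cost of one-step-stale values).
    out_val, out_idx = route.apply(val, None, state, async_op=async_op)
    return out_val, out_idx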
import torch
from torch import Tensor
from typing import *
class ShrinkData:
no_idx: int = (2**62-1)*2+1
def __init__(self,
src_size: int,
dst_size: int,
dst_idx: Tensor,
edge_index: Tensor,
bipartite: bool = False,
) -> None:
device = dst_idx.device
tmp = torch.empty(max(src_size, dst_size), dtype=torch.bool, device=device)
tmp.fill_(0)
tmp.index_fill_(0, dst_idx, 1)
edge_idx = torch.where(tmp[edge_index[1]])[0]
edge_index = edge_index[:, edge_idx]
if bipartite:
tmp.fill_(0)
tmp.index_fill_(0, edge_index[0], 1)
src_idx = torch.where(tmp)[0]
imp = torch.empty(max(src_size, dst_size), dtype=torch.long, device=device)
imp[dst_idx] = torch.arange(dst_idx.size(0), dtype=torch.long, device=device)
dst = imp[edge_index[1]]
imp.fill_(self.no_idx)
imp[src_idx] = torch.arange(src_idx.size(0), dtype=torch.long, device=device)
src = imp[edge_index[0]]
edge_index = torch.vstack([src, dst])
else:
tmp.index_fill_(0, edge_index[0], 1)
tmp.index_fill_(0, dst_idx, 0)
src_idx = torch.cat([dst_idx, torch.where(tmp)[0]], dim=0)
imp = torch.empty(max(src_size, dst_size), dtype=torch.long, device=device)
imp.fill_(self.no_idx)
imp[src_idx] = torch.arange(src_idx.size(0), dtype=torch.long, device=device)
edge_index = imp[edge_index.flatten()].view_as(edge_index)
self.src_idx = src_idx
self.dst_idx = dst_idx
self.edge_idx = edge_idx
self.edge_index = edge_index
self._src_imp = imp[:src_size]
def to(self, device):
self.src_idx = self.src_idx.to(device)
self.dst_idx = self.dst_idx.to(device)
self.edge_idx = self.edge_idx.to(device)
self.edge_index = self.edge_index.to(device)
self._src_imp = self._src_imp.to(device)
return self
def shrink_src_val_and_idx(self, val: Tensor, idx: Tensor) -> Tuple[Tensor, Tensor]:
idx = self._src_imp[idx]
m = (idx != self.no_idx)
return val[m], idx[m]
@property
def src_size(self) -> int:
return self.src_idx.size(0)
@property
def dst_size(self) -> int:
return self.dst_idx.size(0)
@property
def edge_size(self) -> int:
return self.edge_idx.size(0)
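# Illustrative sketch (not part of this commit): restricting a partition's local graph
# to a sampled set of destination nodes with ShrinkData. The module path is assumed,
# and the shapes are invented for the example.
import torch
from starrygl.core.shrink import ShrinkData  # assumed module path

def shrink_graph(edge_index: torch.Tensor, src_size: int, dst_size: int, dst_idx: torch.Tensor):
    sd = ShrinkData(src_size=src_size, dst_size=dst_size,
                    dst_idx=dst_idx, edge_index=edge_index, bipartite=False)
    # sd.edge_index is re-indexed into the compact node ranges of the shrunk graph,
    # sd.src_idx lists the surviving source nodes (dst nodes first in the
    # non-bipartite case), and sd.edge_idx selects the surviving edges.
    return sd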
import torch # import torch
import torch.nn as nn # import torch.nn as nn
import torch.autograd as autograd # import torch.autograd as autograd
from torch import Tensor # from torch import Tensor
from contextlib import contextmanager # from contextlib import contextmanager
from typing import * # from typing import *
class StraightContext:
def __init__(self, this) -> None: # class StraightContext:
self.this = this # def __init__(self, this, g) -> None:
# self.this = this
class Straight(nn.Module): # self.g = g
def __init__(self,
num_nodes: int, # class Straight(nn.Module):
p: int = 2, # def __init__(self,
dim: int = -1, # num_nodes: int,
beta: float = 1.0, # num_samples: int,
) -> None: # norm_kwargs: Optional[Dict[str, Any]] = None,
super().__init__() # beta: float = 1.0,
self.num_nodes = num_nodes # prev: Optional[List[Any]] = None,
self.norm_p = p # ) -> None:
self.norm_dim = dim # super().__init__()
self.norm_beta = beta # assert num_samples <= num_nodes
self.register_buffer("grad_norm", torch.ones(num_nodes)) # self.num_nodes = num_nodes
# self.num_samples = num_samples
# self.norm_kwargs = norm_kwargs or dict(p=2, dim=-1)
# self.beta = beta
# self.prev = prev
self.reset_parameters() # self.register_buffer("last_w", torch.ones(num_nodes))
# self._next_idx = None
# self._next_shrink_helper = None
# self.reset_parameters()
def reset_parameters(self): # def reset_parameters(self):
grad_norm = self.get_buffer("grad_norm") # last_w = self.get_buffer("last_w")
nn.init.constant_(grad_norm, 1.0) # nn.init.constant_(last_w, 1.0)
def forward(self, values: Tensor, active: Optional[Tensor] = None) -> Tensor: # def forward(self,
with self._manager(self): # val: Tensor,
return StraightFunction.apply(values, active) # idx: Optional[Tensor],
# g,
# ) -> Tensor:
# with self._manager(self, g):
# return StraightFunction.apply(val, idx)
def multinomial(self, # def pop_next_shrink_helper(self) -> Tuple[Optional[Tensor], Any]:
num_samples: int, # if not self.training:
replacement: bool = False, # return None, None
) -> Tensor:
w = self.get_buffer("grad_norm")
if num_samples <= 0:
return torch.arange(self.num_nodes, dtype=torch.long, device=w.device)
# print(w) # next_idx = self._next_idx
w = w / w.sum() # self._next_idx = None
return torch.multinomial(w, num_samples=num_samples, replacement=replacement)
# next_sh = self._next_shrink_helper
# self._next_shrink_helper = None
# return next_idx, next_sh
# def _sample_next(self) -> Tensor:
# w = self.get_buffer("last_w")
# if self._next_idx is None:
# if self.num_samples < w.size(0):
# self._next_idx = self.sample_impl(w)
# else:
# self._next_idx = torch.arange(w.size(0), dtype=torch.long, device=w.device)
# elif self.num_samples < self._next_idx.size(0):
# idx = self.sample_impl(w[self._next_idx])
# self._next_idx = self._next_idx[idx]
# return self._next_idx
def multinomial_mask(self, # def sample_impl(self, w: Tensor) -> Tensor:
num_samples: int, # w = w / w.sum()
replacement: bool = False, # return torch.multinomial(w, num_samples=self.num_samples, replacement=False)
) -> Tensor:
device = self.get_buffer("grad_norm").device # # def multinomial(self,
if num_samples <= 0: # # num_samples: int,
return torch.ones(self.num_nodes, dtype=torch.bool, device=device) # # replacement: bool = False,
# # ) -> Tensor:
# # w = self.get_buffer("last_w")
# # if num_samples <= 0:
# # return torch.arange(self.num_nodes, dtype=torch.long, device=w.device)
w = self.multinomial(num_samples, replacement) # # w = w / w.sum()
m = torch.zeros(self.num_nodes, dtype=torch.bool, device=device) # # return torch.multinomial(w, num_samples=num_samples, replacement=replacement)
m[w] = True
return m
@contextmanager # # def multinomial_mask(self,
def _manager(self, this): # # num_samples: int,
global _global_straight_context # # replacement: bool = False,
stacked = _global_straight_context # # ) -> Tensor:
# # w = self.get_buffer("last_w")
try: # # if num_samples <= 0:
_global_straight_context = StraightContext(this=this) # # return torch.ones(self.num_nodes, dtype=torch.bool, device=w.device)
yield _global_straight_context
finally: # # w = self.multinomial(num_samples, replacement)
_global_straight_context = stacked # # m = torch.zeros(self.num_nodes, dtype=torch.bool, device=w.device)
# # m[w] = True
# # return m
# @contextmanager
# def _manager(self, this, g):
# global _global_straight_context
# stacked = _global_straight_context
class StraightFunction(autograd.Function): # try:
@staticmethod # _global_straight_context = StraightContext(this=this, g=g)
def forward( # yield _global_straight_context
ctx: autograd.function.FunctionCtx, # finally:
values: Tensor, # _global_straight_context = stacked
active: Optional[Tensor],
):
stx = _last_global_straight_context() # class StraightFunction(autograd.Function):
this: Straight = stx.this # @staticmethod
if this.training: # def forward(
ctx.this = this # ctx: autograd.function.FunctionCtx,
ctx.save_for_backward(active) # val: Tensor,
return values # idx: Optional[Tensor],
# ):
@staticmethod # from ..graph import DistGraph
def backward(
ctx: autograd.function.FunctionCtx, # stx = _last_global_straight_context()
grad: Tensor, # this: Straight = stx.this
): # g: DistGraph = stx.g
this: Straight = ctx.this
active, = ctx.saved_tensors # last_w = this.get_buffer("last_w")
# if idx is None:
# assert val.size(0) == last_w.size(0)
# else:
# assert val.size(0) == idx.size(0)
# if this.training:
# ctx.this = this
# ctx.g = g
# ctx.save_for_backward(idx)
# return val
# @staticmethod
# def backward(
# ctx: autograd.function.FunctionCtx,
# grad: Tensor,
# ):
# from ..graph import DistGraph
# this: Straight = ctx.this
# g: DistGraph = ctx.g
# idx, = ctx.saved_tensors
# last_w = this.get_buffer("last_w")
# if this.beta != 1.0:
# last_w.mul_(this.beta)
# norm = grad.norm(**this.norm_kwargs)
# if idx is None:
# last_w[:] = norm
# else:
# last_w[idx] = norm
grad_norm = this.get_buffer("grad_norm") # if this.prev is not None or this._next_idx is None:
if this.norm_beta != 1.0: # from ..nn.convs.utils import ShrinkHelper
grad_norm.mul_(this.norm_beta) # dst_idx = this._sample_next()
if active is None: # if this.prev is not None and this._next_idx is not None:
x = grad.detach() # this._next_shrink_helper = ShrinkHelper(g, dst_idx)
x = x.norm(p=this.norm_p, dim=this.norm_dim)
grad_norm[:] = x # src_idx = this._next_shrink_helper.src_idx
else: # work = g.route.gather_backward(
if grad.size(0) == grad_norm.size(0): # src_idx, src_idx, async_op=False, return_idx=True)
x = grad.detach()[active]
else: # prev_dst_idx = work.get_idx()
x = grad.detach() # for p in this.prev:
x = x.norm(p=this.norm_p, dim=this.norm_dim) # assert isinstance(p, Straight)
grad_norm[active] = x # p._next_idx = prev_dst_idx
return grad, None # return grad, None
#### private functions # #### private functions
_global_straight_context: Optional[StraightContext] = None # _global_straight_context: Optional[StraightContext] = None
def _last_global_straight_context() -> StraightContext: # def _last_global_straight_context() -> StraightContext:
global _global_straight_context # global _global_straight_context
assert _global_straight_context is not None # assert _global_straight_context is not None
return _global_straight_context # return _global_straight_context
if __name__ == "__main__": # if __name__ == "__main__":
s = Straight(3, beta=1.1) # s = Straight(3, beta=1.1)
x = torch.rand(3, 10).requires_grad_() # x = torch.rand(3, 10).requires_grad_()
m = torch.tensor([0, 1, 0], dtype=torch.bool) # m = torch.tensor([0, 1, 0], dtype=torch.bool)
s(x).sum().backward() # s(x).sum().backward()
print(s.grad_norm) # print(s.grad_norm)
print(s.multinomial(2)) # print(s.multinomial(2))
s(x, m).sum().backward() # s(x, m).sum().backward()
print(s.grad_norm) # print(s.grad_norm)
print(s.multinomial(2)) # print(s.multinomial(2))
s(x[m], m).sum().backward() # s(x[m], m).sum().backward()
print(s.grad_norm) # print(s.grad_norm)
print(s.multinomial(2)) # print(s.multinomial(2))
print(s.multinomial_mask(2)) # print(s.multinomial_mask(2))
\ No newline at end of file \ No newline at end of file
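# Illustrative sketch (assumed semantics, not the repository's API) of the idea the
# Straight module above implements: an identity autograd.Function that records
# per-node gradient norms on the backward pass, which are then used as multinomial
# sampling weights. All names here are made up for the example.
import torch
from torch import Tensor, autograd

class _RecordGradNorm(autograd.Function):
    @staticmethod
    def forward(ctx, values: Tensor, grad_norm: Tensor) -> Tensor:
        ctx.grad_norm = grad_norm                           # buffer filled during backward
        return values

    @staticmethod
    def backward(ctx, grad: Tensor):
        ctx.grad_norm[:] = grad.detach().norm(p=2, dim=-1)  # per-node L2 norm of the gradient
        return grad, None                                   # identity gradient, none for the buffer

if __name__ == "__main__":
    grad_norm = torch.ones(5)
    x = torch.rand(5, 8, requires_grad=True)
    y = _RecordGradNorm.apply(x, grad_norm)
    (y * torch.rand_like(y)).sum().backward()
    weights = grad_norm / grad_norm.sum()
    print(torch.multinomial(weights, num_samples=3))        # nodes with larger gradients are favored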
...@@ -5,26 +5,23 @@ from torch import Tensor ...@@ -5,26 +5,23 @@ from torch import Tensor
from typing import * from typing import *
from contextlib import contextmanager from contextlib import contextmanager
import torch_sparse # import torch_sparse
from .ndata import NData from .ndata import NData
from .edata import EData from .edata import EData
from .utils import init_local_edge_index from .utils import init_local_edge_index
from ..core import Route from ..core import MessageCache, Route
class DistGraph: class DistGraph:
def __init__(self, def __init__(self,
ids: Tensor, ids: Tensor,
edge_index: Tensor, edge_index: Tensor,
args: Dict[str, Any] = {}, num_features: int,
device: Union[torch.device, str, int, None] = None, num_layers: int,
cache_device: str = "cpu",
**args: Dict[str, Any],
): ):
# graph's target device
if device is None:
device = ids.device
self.device = torch.device(device)
# build local_edge_index # build local_edge_index
dst_ids = ids dst_ids = ids
src_ids, local_edge_index = init_local_edge_index( src_ids, local_edge_index = init_local_edge_index(
...@@ -34,8 +31,15 @@ class DistGraph: ...@@ -34,8 +31,15 @@ class DistGraph:
self._src_ids = src_ids self._src_ids = src_ids
self._dst_ids = dst_ids self._dst_ids = dst_ids
self._local_edge_index = local_edge_index self._message_cache = MessageCache(
self._local_edge_ptr: Optional[Tensor] = None src_ids=src_ids,
dst_ids=dst_ids,
edge_index=local_edge_index,
num_features=num_features,
num_layers=num_layers,
cache_device=cache_device,
bipartite=False,
)
# node's attributes # node's attributes
self.ndata = NData( self.ndata = NData(
...@@ -49,24 +53,30 @@ class DistGraph: ...@@ -49,24 +53,30 @@ class DistGraph:
) )
# graph's attributes # graph's attributes
self.args = args self.args = dict(args)
def to(self, device):
self._message_cache.to(device)
return self
def cache_data_to(self, device):
self._message_cache.cached_data_to(device)
@property
def cache(self) -> MessageCache:
return self._message_cache
# route table @property
self.route = Route(dst_ids, src_ids) def route(self) -> Route:
return self._message_cache.route
@property @property
def edge_index(self) -> Tensor: def device(self) -> torch.device:
if self._local_edge_index.device != self.device: return self.edge_index.device
self._local_edge_index = self._local_edge_index.to(self.device)
return self._local_edge_index
@property @property
def edge_ptr(self) -> Optional[Tensor]: def edge_index(self) -> Tensor:
if self._local_edge_ptr is None: return self._message_cache.edge_index
return None
if self._local_edge_ptr.device != self.device:
self._local_edge_ptr = self._local_edge_ptr.to(self.device)
return self._local_edge_ptr
@property @property
def src_ids(self) -> Tensor: def src_ids(self) -> Tensor:
...@@ -88,16 +98,6 @@ class DistGraph: ...@@ -88,16 +98,6 @@ class DistGraph:
def dst_size(self) -> int: def dst_size(self) -> int:
return self._dst_ids.size(0) return self._dst_ids.size(0)
@torch.no_grad()
def gather(self, values: Tensor, direct: str = "dst_to_src") -> Tensor:
if direct == "dst_to_src":
work = self.route.gather_forward(values, None, async_op=False, return_active=False)
elif direct == "src_to_dst":
work = self.route.gather_backward(values, None, async_op=False, return_active=False)
else:
raise ValueError(f"direct must be 'src_to_dst' or 'dst_to_src', but got '{direct}'")
return work.get_values()
@contextmanager @contextmanager
def scoped_manager(self): def scoped_manager(self):
stacked_ndata = self.ndata stacked_ndata = self.ndata
......
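# Illustrative sketch of the scoped_manager pattern used by DistGraph above: a
# context manager temporarily shadows the node-data container and restores it on
# exit, so a layer can assign g.ndata["x"] without leaking state between calls.
# TinyGraph is a made-up stand-in, not the real NData-backed implementation.
from contextlib import contextmanager

class TinyGraph:
    def __init__(self):
        self.ndata = {}

    @contextmanager
    def scoped_manager(self):
        stacked = self.ndata
        self.ndata = dict(stacked)      # scratch scope backed by the previous one
        try:
            yield self
        finally:
            self.ndata = stacked        # restore the previous scope

if __name__ == "__main__":
    g = TinyGraph()
    g.ndata["x"] = "original"
    with g.scoped_manager():
        g.ndata["x"] = "temporary"
        print(g.ndata["x"])             # temporary
    print(g.ndata["x"])                 # original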
...@@ -6,6 +6,7 @@ from typing import * ...@@ -6,6 +6,7 @@ from typing import *
def init_local_edge_index( def init_local_edge_index(
dst_ids: Tensor, dst_ids: Tensor,
edge_index: Tensor, edge_index: Tensor,
bipartite: bool = False,
) -> Tuple[Tensor, Tensor]: ) -> Tuple[Tensor, Tensor]:
max_ids = calc_max_ids(dst_ids, edge_index) max_ids = calc_max_ids(dst_ids, edge_index)
ikw = dict(dtype=torch.long, device=dst_ids.device) ikw = dict(dtype=torch.long, device=dst_ids.device)
...@@ -17,17 +18,26 @@ def init_local_edge_index( ...@@ -17,17 +18,26 @@ def init_local_edge_index(
if not (xmp != 0x01).all(): if not (xmp != 0x01).all():
raise RuntimeError(f"must be vertex-cut partition graph") raise RuntimeError(f"must be vertex-cut partition graph")
# src_ids 等于 [dst_ids, edge_index[0] except dst_ids] if bipartite:
xmp.fill_(0) src_ids = edge_index[0].unique()
xmp[edge_index[0]] = 1 else:
xmp[dst_ids] = 0 # 假设是同构图
src_ids = torch.cat([dst_ids, torch.where(xmp > 0)[0]], dim=-1) # src_ids 等于 [dst_ids, edge_index[0] except dst_ids]
xmp.fill_(0)
xmp[edge_index[0]] = 1
xmp[dst_ids] = 0
src_ids = torch.cat([dst_ids, torch.where(xmp > 0)[0]], dim=-1)
# 计算局部索引 # 计算局部索引
xmp.fill_((2**62-1)*2+1) xmp.fill_((2**62-1)*2+1)
xmp[src_ids] = torch.arange(src_ids.size(0), **ikw) xmp[src_ids] = torch.arange(src_ids.size(0), **ikw)
local_edge_index = xmp[edge_index.flatten()].view_as(edge_index) src = xmp[edge_index[0]]
xmp.fill_((2**62-1)*2+1)
xmp[dst_ids] = torch.arange(dst_ids.size(0), **ikw)
dst = xmp[edge_index[1]]
local_edge_index = torch.vstack([src, dst])
return src_ids, local_edge_index return src_ids, local_edge_index
def calc_max_ids(*ids: Tensor) -> int: def calc_max_ids(*ids: Tensor) -> int:
......
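# Minimal sketch (with made-up ids) of the global-to-local remapping trick used in
# init_local_edge_index above: a lookup table maps global node ids to positions in
# src_ids / dst_ids, so the global edge_index can be rewritten with local indices.
import torch

dst_ids = torch.tensor([2, 5, 7])                       # global ids owned locally
edge_index = torch.tensor([[5, 7, 9], [2, 2, 5]])       # global (src, dst) pairs

# src_ids = [dst_ids, other source ids appearing in edge_index]
mask = torch.zeros(10, dtype=torch.bool)
mask[edge_index[0]] = True
mask[dst_ids] = False
src_ids = torch.cat([dst_ids, torch.where(mask)[0]])    # tensor([2, 5, 7, 9])

lookup = torch.full((10,), -1, dtype=torch.long)        # -1 marks "not present"
lookup[src_ids] = torch.arange(src_ids.size(0))
src = lookup[edge_index[0]]

lookup.fill_(-1)
lookup[dst_ids] = torch.arange(dst_ids.size(0))
dst = lookup[edge_index[1]]

local_edge_index = torch.vstack([src, dst])
print(local_edge_index)                                 # tensor([[1, 2, 3], [0, 0, 1]])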
from .convs import GCNConv, GATConv, GINConv from .convs import GCNConv, GATConv, GINConv
from .convs import ShrinkGCNConv, ShrinkGATConv, ShrinkGINConv # from .convs import ShrinkGCNConv, ShrinkGATConv, ShrinkGINConv
from .basic_gnn import BasicGNN, BasicLayerOptions, BasicInputOptions, BasicJKOptions, BasicStraightOptions # from .convs import ShrinkHelper
# from .basic_gnn import BasicGNN, BasicLayerOptions, BasicInputOptions, BasicStraightOptions
class GCN(BasicGNN): # class ShrinkGCN(BasicGNN):
def init_conv(self, in_channels: int, out_channels: int, **kwargs): # def init_conv(self, in_channels: int, out_channels: int, **kwargs):
return ShrinkGCNConv(in_channels, out_channels, **kwargs) # return ShrinkGCNConv(in_channels, out_channels, **kwargs)
class GAT(BasicGNN): # class ShrinkGAT(BasicGNN):
def init_conv(self, in_channels: int, out_channels: int, **kwargs): # def init_conv(self, in_channels: int, out_channels: int, **kwargs):
return ShrinkGATConv(in_channels, out_channels, **kwargs) # return ShrinkGATConv(in_channels, out_channels, **kwargs)
class GIN(BasicGNN): # class ShrinkGIN(BasicGNN):
def init_conv(self, in_channels: int, out_channels: int, **kwargs): # def init_conv(self, in_channels: int, out_channels: int, **kwargs):
return ShrinkGINConv(in_channels, out_channels, **kwargs) # return ShrinkGINConv(in_channels, out_channels, **kwargs)
\ No newline at end of file
import torch # import torch
import torch.nn as nn # import torch.nn as nn
import torch.nn.functional as F # import torch.nn.functional as F
from torch import Tensor # from torch import Tensor
from typing import * # from typing import *
import copy # import copy
import inspect # import inspect
from torch_geometric.nn import JumpingKnowledge # from torch_geometric.nn import JumpingKnowledge
from torch_geometric.nn.resolver import activation_resolver, normalization_resolver # from torch_geometric.nn.resolver import activation_resolver, normalization_resolver
from dataclasses import dataclass # from dataclasses import dataclass
from ..graph.distgraph import DistGraph # from ..core.cache import NodeProbe
from ..core.gather import Gather # from ..graph.distgraph import DistGraph
from ..core.straight import Straight # # from ..core.gather import Gather
# # from ..core.straight import Straight
@dataclass # @dataclass
class BasicInputOptions: # class BasicInputOptions:
weight: bool = True # weight: bool = True
bias: bool = True # bias: bool = True
gather: bool = True # gather: bool = True
gather_first: bool = False # gather_first: bool = False
dropout: float = 0.0 # dropout: float = 0.0
act: Optional[str] = None # act: Optional[str] = None
act_kwargs: Optional[Dict[str, Any]] = None # act_kwargs: Optional[Dict[str, Any]] = None
act_first: bool = False # act_first: bool = False
norm: Optional[str] = None # norm: Optional[str] = None
norm_kwargs: Optional[Dict[str, Any]] = None # norm_kwargs: Optional[Dict[str, Any]] = None
straight_enabled: bool = True # straight_enabled: bool = True
# straight_num_samples: Optional[int] = None
@dataclass # @dataclass
class BasicLayerOptions: # class BasicLayerOptions:
in_channels: int # in_channels: int
hidden_channels: int # hidden_channels: int
num_layers: int # num_layers: int
out_channels: Optional[int] = None # out_channels: Optional[int] = None
gather_beta: float = 1.0 # gather_beta: float = 1.0
dropout: float = 0.0 # dropout: float = 0.0
act: Optional[str] = "relu" # act: Optional[str] = "relu"
act_kwargs: Optional[Dict[str, Any]] = None # act_kwargs: Optional[Dict[str, Any]] = None
act_first: bool = False # act_first: bool = False
norm: Optional[str] = None # norm: Optional[str] = None
norm_kwargs: Optional[Dict[str, Any]] = None # norm_kwargs: Optional[Dict[str, Any]] = None
# jk_mode: Optional[str] = None
@dataclass # @dataclass
class BasicJKOptions: # class BasicStraightOptions:
jk_mode: Optional[str] = None # enabled: bool = False
# num_samples: Optional[int] = None
# beta: float = 1.0
@dataclass # class BasicGNN(nn.Module):
class BasicStraightOptions: # def __init__(
enabled: bool = False # self,
p: int = 2 # g: DistGraph,
beta: float = 1.0 # layer_options: BasicLayerOptions,
# input_options: BasicInputOptions = BasicInputOptions(),
class BasicGNN(nn.Module): # straight_options: BasicStraightOptions = BasicStraightOptions(),
def __init__( # **kwargs,
self, # ):
g: DistGraph, # super().__init__()
layer_options: BasicLayerOptions,
input_options: BasicInputOptions = BasicInputOptions(),
jk_options: BasicJKOptions = BasicJKOptions(),
straight_options: BasicStraightOptions = BasicStraightOptions(),
**kwargs,
):
super().__init__()
self.in_channels = in_channels = layer_options.in_channels # num_samples = straight_options.num_samples or g.dst_size
self.hidden_channels = hidden_channels = layer_options.hidden_channels # prev_straight = None
self.num_layers = num_layers = layer_options.num_layers
# self.in_channels = in_channels = layer_options.in_channels
# self.hidden_channels = hidden_channels = layer_options.hidden_channels
# self.num_layers = num_layers = layer_options.num_layers
# default out_channels is hidden_channels # # default out_channels is hidden_channels
self.out_channels = out_channels = hidden_channels \ # self.out_channels = out_channels = hidden_channels \
if layer_options.out_channels is None else layer_options.out_channels # if layer_options.out_channels is None else layer_options.out_channels
self.gather_beta = gather_beta = layer_options.gather_beta # self.gather_beta = gather_beta = layer_options.gather_beta
self.layer_options = layer_options # self.layer_options = layer_options
self.input_options = input_options # self.input_options = input_options
self.jk_options = jk_options # self.straight_options = straight_options
self.straight_options = straight_options
# initialize input layer # # initialize input layer
if input_options.weight: # if input_options.weight:
self.lin_x = nn.Linear(in_channels, hidden_channels, bias=input_options.bias) # self.lin_x = nn.Linear(in_channels, hidden_channels, bias=input_options.bias)
in_channels = hidden_channels # in_channels = hidden_channels
if straight_options.enabled and input_options.straight_enabled: # if straight_options.enabled and input_options.straight_enabled and not input_options.gather_first:
if input_options.gather_first: # sns = input_options.straight_num_samples or g.dst_size
self.straight_x = Straight( # self.straight_x = Straight(g.dst_size, sns, beta=straight_options.beta, prev=prev_straight)
g.src_size, p=straight_options.p, beta=straight_options.beta) # prev_straight = [self.straight_x]
else:
self.straight_x = Straight(
g.dst_size, p=straight_options.p, beta=straight_options.beta)
if input_options.gather: # if input_options.gather:
self.gather_x = Gather(g.src_size, in_channels, beta=gather_beta) # self.gather_x = Gather(g.src_size, in_channels, beta=gather_beta)
if input_options.act is not None: # if input_options.act is not None:
self.act_x = activation_resolver( # self.act_x = activation_resolver(
input_options.act, **(input_options.act_kwargs or {})) # input_options.act, **(input_options.act_kwargs or {}))
if input_options.norm is not None: # if input_options.norm is not None:
self.norm_x = normalization_resolver( # self.norm_x = normalization_resolver(
input_options.norm, in_channels, **(input_options.norm_kwargs or {})) # input_options.norm, in_channels, **(input_options.norm_kwargs or {}))
# initialize activation layers # # initialize activation layers
if layer_options.act is not None: # if layer_options.act is not None:
self.acts = nn.ModuleList() # self.acts = nn.ModuleList()
for _ in range(num_layers - 1): # for _ in range(num_layers - 1):
self.acts.append(activation_resolver( # self.acts.append(activation_resolver(
layer_options.act, **(layer_options.act_kwargs or {}))) # layer_options.act, **(layer_options.act_kwargs or {})))
if jk_options.jk_mode is not None: # if layer_options.jk_mode is not None:
self.acts.append(activation_resolver( # self.acts.append(activation_resolver(
layer_options.act, **(layer_options.act_kwargs or {}))) # layer_options.act, **(layer_options.act_kwargs or {})))
# initialize normalization layers # # initialize normalization layers
if layer_options.norm is not None: # if layer_options.norm is not None:
self.norms = nn.ModuleList() # self.norms = nn.ModuleList()
for _ in range(num_layers - 1): # for _ in range(num_layers - 1):
self.norms.append(normalization_resolver( # self.norms.append(normalization_resolver(
layer_options.norm, hidden_channels, **(layer_options.norm_kwargs or {}))) # layer_options.norm, hidden_channels, **(layer_options.norm_kwargs or {})))
if jk_options.jk_mode is not None: # if layer_options.jk_mode is not None:
self.norms.append(normalization_resolver( # self.norms.append(normalization_resolver(
layer_options.norm, hidden_channels, **(layer_options.norm_kwargs or {}))) # layer_options.norm, hidden_channels, **(layer_options.norm_kwargs or {})))
# initialize straight layers # # initialize straight layers
if straight_options.enabled: # if straight_options.enabled:
self.straights = nn.ModuleList() # self.straights = nn.ModuleList()
for _ in range(num_layers - 1): # for _ in range(num_layers):
self.straights.append(Straight( # prev_straight = [Straight(g.dst_size, num_samples, beta=straight_options.beta, prev=prev_straight)]
g.dst_size, p=straight_options.p, beta=straight_options.beta)) # self.straights.append(prev_straight[0])
if jk_options.jk_mode is not None: # # if layer_options.jk_mode is not None:
self.straights.append(Straight( # # prev_straight = [Straight(g.dst_size, num_samples, beta=straight_options.beta, prev=prev_straight)]
g.dst_size, p=straight_options.p, beta=straight_options.beta)) # # self.straights.append(prev_straight[0])
# initialize gather and conv layers # # initialize gather and conv layers
self.convs = nn.ModuleList() # self.convs = nn.ModuleList()
self.gathers = nn.ModuleList() # self.gathers = nn.ModuleList()
for _ in range(num_layers - 1): # for _ in range(num_layers - 1):
self.convs.append( # self.convs.append(
self.init_conv(in_channels, hidden_channels, **kwargs)) # self.init_conv(in_channels, hidden_channels, **kwargs))
self.gathers.append(Gather(g.src_size, hidden_channels, beta=gather_beta)) # self.gathers.append(Gather(g.src_size, hidden_channels, beta=gather_beta))
in_channels = hidden_channels # in_channels = hidden_channels
if jk_options.jk_mode is None: # if layer_options.jk_mode is None:
self.convs.append( # self.convs.append(
self.init_conv(in_channels, out_channels, **kwargs)) # self.init_conv(in_channels, out_channels, **kwargs))
self.gathers.append(Gather(g.dst_size, out_channels)) # only fuse embeddings # self.gathers.append(Gather(g.dst_size, out_channels)) # only fuse embeddings
else: # else:
self.convs.append( # self.convs.append(
self.init_conv(in_channels, hidden_channels, **kwargs)) # self.init_conv(in_channels, hidden_channels, **kwargs))
self.gathers.append(Gather(g.dst_size, hidden_channels, beta=gather_beta)) # only fuse embeddings # self.gathers.append(Gather(g.dst_size, hidden_channels, beta=gather_beta)) # only fuse embeddings
if jk_options.jk_mode != "last": # if layer_options.jk_mode != "last":
self.jk = JumpingKnowledge(jk_options.jk_mode, hidden_channels, num_layers) # self.jk = JumpingKnowledge(layer_options.jk_mode, hidden_channels, num_layers)
if jk_options.jk_mode == "cat": # if layer_options.jk_mode == "cat":
jk_channels = num_layers * hidden_channels # jk_channels = num_layers * hidden_channels
else: # else:
jk_channels = hidden_channels # jk_channels = hidden_channels
self.lin_jk = nn.Linear(jk_channels, out_channels) # self.lin_jk = nn.Linear(jk_channels, out_channels)
self.reset_parameters() # self.reset_parameters()
def init_conv(self, # def init_conv(self,
in_channels: int, # in_channels: int,
out_channels: int, # out_channels: int,
**kwargs # **kwargs
) -> nn.Module: # ) -> nn.Module:
raise NotImplementedError # raise NotImplementedError
def reset_parameters(self): # def reset_parameters(self):
if hasattr(self, "lin_x"): # if hasattr(self, "lin_x"):
self.lin_x.reset_parameters() # self.lin_x.reset_parameters()
if hasattr(self, "straight_x"): # if hasattr(self, "straight_x"):
self.straight_x.reset_parameters() # self.straight_x.reset_parameters()
if hasattr(self, "gather_x"): # if hasattr(self, "gather_x"):
self.gather_x.reset_parameters() # self.gather_x.reset_parameters()
if hasattr(self, "norm_x"): # if hasattr(self, "norm_x"):
self.norm_x.reset_parameters() # self.norm_x.reset_parameters()
if hasattr(self, "norms"): # if hasattr(self, "norms"):
for norm in self.norms: # for norm in self.norms:
norm.reset_parameters() # norm.reset_parameters()
if hasattr(self, "straights"): # if hasattr(self, "straights"):
for straight in self.straights: # for straight in self.straights:
straight.reset_parameters() # straight.reset_parameters()
for conv in self.convs: # for conv in self.convs:
conv.reset_parameters() # conv.reset_parameters()
for gather in self.gathers: # for gather in self.gathers:
gather.reset_parameters() # gather.reset_parameters()
if hasattr(self, "jk"): # if hasattr(self, "jk"):
self.jk.reset_parameters() # self.jk.reset_parameters()
if hasattr(self, "lin_jk"): # if hasattr(self, "lin_jk"):
self.lin_jk.reset_parameters() # self.lin_jk.reset_parameters()
def forward( # def forward(
self, # self,
g: DistGraph, # g: DistGraph,
): # ):
# from ..utils.printer import main_print # from ..utils.printer import main_print, sync_print
x = g.ndata["x"]
# main_print(f"features_x: {x.size()}")
sample_k = g.args.get("sample_k", 0) # x = g.ndata["x"]
async_op = g.args.get("async_op", False) # async_op = g.args.get("async_op", False)
# input layer # # input layer
if hasattr(self, "gather_x") and self.input_options.gather_first: # if hasattr(self, "straight_x"):
x, m = self.gather_x(x, None, g.route, async_op=async_op) # dst_idx, _ = self.straight_x.pop_next_shrink_helper()
# else:
# dst_idx = None
# if dst_idx is not None:
# x = x[dst_idx]
# if hasattr(self, "gather_x") and self.input_options.gather_first:
# x, _ = self.gather_x(x, dst_idx, g.route, async_op=async_op)
if self.input_options.dropout != 0.0: # if self.input_options.dropout != 0.0:
x = F.dropout(x, p=self.input_options.dropout, training=self.training) # x = F.dropout(x, p=self.input_options.dropout, training=self.training)
# main_print(f"x_x: {x.size()} {x.requires_grad}")
if hasattr(self, "lin_x"):
# main_print("enter lin_x")
x = self.lin_x(x)
# main_print(f"lin_x: {x.size()} {x.requires_grad}")
if hasattr(self, "act_x") and self.input_options.act_first:
x = self.act_x(x)
if hasattr(self, "norm_x"):
x = self.norm_x(x)
if hasattr(self, "act_x") and not self.input_options.act_first:
x = self.act_x(x)
# main_print(f"act_x: {x.size()} {x.requires_grad}") # if hasattr(self, "lin_x"):
# x = self.lin_x(x)
# straight sampler # if hasattr(self, "act_x") and self.input_options.act_first:
if hasattr(self, "straight_x"): # x = self.act_x(x)
x = self.straight_x(x) # if hasattr(self, "norm_x"):
m = self.straight_x.multinomial_mask(sample_k) # x = self.norm_x(x)
else: # if hasattr(self, "act_x") and not self.input_options.act_first:
m = None # x = self.act_x(x)
# main_print(f"straight_x: {x.size()} {'' if m is None else m.size()} {x.requires_grad}")
# gather features # # straight sampler
if hasattr(self, "gather_x") and not self.input_options.gather_first: # if hasattr(self, "straight_x"):
x, m = self.gather_x(x, m, g.route, async_op=async_op) # # sync_print(f"{x.size()} - {'' if dst_idx is None else dst_idx.size()}")
# x = self.straight_x(x, dst_idx, g)
# main_print(f"rg {0}: {x.requires_grad}") # # gather features
# if hasattr(self, "gather_x") and not self.input_options.gather_first:
# x, _ = self.gather_x(x, dst_idx, g.route, async_op=async_op)
# conv layers # # conv layers
xs: List[Tensor] = [] # xs: List[Tensor] = []
for i in range(self.num_layers): # for i in range(self.num_layers):
if self.layer_options.dropout != 0.0: # if self.layer_options.dropout != 0.0:
x = F.dropout(x, p=self.layer_options.dropout, training=self.training) # x = F.dropout(x, p=self.layer_options.dropout, training=self.training)
with g.scoped_manager(): # if hasattr(self, "straights"):
g.ndata["x"] = x # straight: Straight = self.straights[i]
if m is not None: # dst_idx, sh = straight.pop_next_shrink_helper()
g.ndata["m"] = m
x, m = self.convs[i](g) # with g.scoped_manager():
# g.ndata["x"] = x
# main_print(f"rg {i+1}: {x.requires_grad}") # x = self.convs[i](g, sh=sh, dst_idx=dst_idx)
if i == self.num_layers - 1 and not hasattr(self, "jk"): # if i == self.num_layers - 1 and not hasattr(self, "jk"):
x = self.gathers[i].fuse_embeddings(x, m, inplace=True) # x = self.gathers[i].fuse_embeddings(x, dst_idx, inplace=True)
break # break
# if hasattr(self, "acts") and self.layer_options.act_first:
# x = self.acts[i](x)
# if hasattr(self, "norms"):
# x = self.norms[i](x)
# if hasattr(self, "acts") and not self.layer_options.act_first:
# x = self.acts[i](x)
if hasattr(self, "acts") and self.layer_options.act_first: # if hasattr(self, "straights"):
x = self.acts[i](x) # x = self.straights[i](x, dst_idx, g)
if hasattr(self, "norms"):
x = self.norms[i](x)
if hasattr(self, "acts") and not self.layer_options.act_first:
x = self.acts[i](x)
if hasattr(self, "straights"):
x = self.straights[i](x, m)
if i == self.num_layers - 1: # if i == self.num_layers - 1:
x = self.gathers[i].fuse_embeddings(x, m, inplace=True) # x = self.gathers[i].fuse_embeddings(x, dst_idx, inplace=True)
else: # else:
x, m = self.gathers[i](x, m, g.route, async_op=async_op) # x, _ = self.gathers[i](x, dst_idx, g.route, async_op=async_op)
if hasattr(self, "jk"): # if hasattr(self, "jk"):
xs.append(x[:g.dst_size]) # xs.append(x[:g.dst_size])
x = self.jk(xs) if hasattr(self, "jk") else x # x = self.jk(xs) if hasattr(self, "jk") else x
x = self.lin_jk(x) if hasattr(self, "lin_jk") else x # x = self.lin_jk(x) if hasattr(self, "lin_jk") else x
# main_print(f"out: {x.size()}") # # sync_print(f"out: {x.size()}")
return x # return x
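# Hedged sketch of the per-layer ordering the BasicGNN forward loop above encodes
# (dropout -> conv -> act/norm, with an act_first switch and no act/norm after the
# output layer). This is a toy dense stand-in, not the distributed version: "conv"
# is just nn.Linear here and all option names are illustrative.
import torch
import torch.nn as nn
import torch.nn.functional as F

class ToyGNN(nn.Module):
    def __init__(self, in_dim, hid_dim, out_dim, num_layers, dropout=0.5, act_first=False):
        super().__init__()
        dims = [in_dim] + [hid_dim] * (num_layers - 1) + [out_dim]
        self.convs = nn.ModuleList(nn.Linear(a, b) for a, b in zip(dims[:-1], dims[1:]))
        self.norms = nn.ModuleList(nn.LayerNorm(d) for d in dims[1:-1])
        self.dropout = dropout
        self.act_first = act_first

    def forward(self, x):
        for i, conv in enumerate(self.convs):
            x = F.dropout(x, p=self.dropout, training=self.training)
            x = conv(x)
            if i == len(self.convs) - 1:
                break                       # output layer: no activation / normalization
            if self.act_first:
                x = F.relu(x)
            x = self.norms[i](x)
            if not self.act_first:
                x = F.relu(x)
        return x

if __name__ == "__main__":
    model = ToyGNN(16, 32, 4, num_layers=3)
    print(model(torch.rand(8, 16)).shape)   # torch.Size([8, 4])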
from .gcn_conv import GCNConv from .gcn_conv import GCNConv
from .gat_conv import GATConv from .gat_conv import GATConv
from .gin_conv import GINConv from .gin_conv import GINConv
from .shrink_gcn_conv import ShrinkGCNConv # from .shrink_gcn_conv import ShrinkGCNConv
from .shrink_gat_conv import ShrinkGATConv # from .shrink_gat_conv import ShrinkGATConv
from .shrink_gin_conv import ShrinkGINConv # from .shrink_gin_conv import ShrinkGINConv
# from .utils import ShrinkHelper
\ No newline at end of file
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.utils import softmax
from torch_scatter import scatter_sum
from torch import Tensor
from typing import *
from starrygl.graph import DistGraph
class GATConv(nn.Module):
def __init__(self,
in_channels: int,
out_channels: int,
heads: int = 1,
concat: bool = False,
negative_slope: float = 0.2,
dropout: float = 0.0,
edge_dim: Optional[int] = None,
bias: bool = True,
**kwargs
) -> None:
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.heads = heads
self.concat = concat
self.negative_slope = negative_slope
self.dropout = dropout
self.edge_dim = edge_dim
self.weight = nn.Parameter(torch.Tensor(in_channels, heads * out_channels))
self.att_src = nn.Parameter(torch.Tensor(1, heads, out_channels))
self.att_dst = nn.Parameter(torch.Tensor(1, heads, out_channels))
if edge_dim is not None:
self.lin_edge = nn.Parameter(torch.Tensor(edge_dim, heads * out_channels))
self.att_edge = nn.Parameter(torch.Tensor(1, heads, out_channels))
if bias and concat:
self.bias = nn.Parameter(torch.Tensor(heads * out_channels))
elif bias and not concat:
self.bias = nn.Parameter(torch.Tensor(out_channels))
else:
self.bias = None
self.reset_parameters()
def reset_parameters(self):
nn.init.xavier_normal_(self.weight)
nn.init.xavier_normal_(self.att_src)
nn.init.xavier_normal_(self.att_dst)
if self.edge_dim is not None:
nn.init.xavier_normal_(self.lin_edge)
nn.init.xavier_normal_(self.att_edge)
if self.bias is not None:
nn.init.zeros_(self.bias)
def forward(self, g, x: Tensor, edge_attr: Optional[Tensor] = None):
H, C = self.heads, self.out_channels
edge_index = g.edge_index
x = (x @ self.weight).view(-1, H, C)
alpha_j = (x * self.att_src).sum(dim=-1)
alpha_j = alpha_j[edge_index[0]]
alpha_i = (x * self.att_dst).sum(dim=-1)
alpha_i = alpha_i[edge_index[1]]
if self.edge_dim is not None:
if edge_attr.dim() == 1:
edge_attr = edge_attr.view(-1, 1)
e = (edge_attr @ self.lin_edge).view(-1, H, C)
alpha_e = (e * self.att_edge).sum(dim=-1)
alpha = alpha_i + alpha_j + alpha_e
else:
alpha = alpha_i + alpha_j
alpha = F.leaky_relu(alpha, self.negative_slope)
alpha = softmax(
src=alpha,
index=edge_index[1],
num_nodes=g.dst_size,
)
alpha = F.dropout(alpha, p=self.dropout, training=self.training)
x = x[edge_index[0]] * alpha.view(-1, H, 1)
x = scatter_sum(x, edge_index[1], dim=0, dim_size=g.dst_size)
if self.concat:
x = x.view(-1, H * C)
else:
x = x.mean(dim=1)
if self.bias is not None:
x += self.bias
return x
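# Small, self-contained illustration (plain PyTorch, no torch_scatter) of the
# edge-softmax + scatter-sum aggregation GATConv.forward performs: attention logits
# are normalized per destination node, then weighted messages are summed. The toy
# graph and feature values are made up for the example.
import torch

edge_index = torch.tensor([[0, 1, 2, 2], [0, 0, 0, 1]])    # (src, dst), 2 dst nodes
alpha = torch.tensor([0.5, 1.0, 2.0, 0.3])
dst_size = 2

# softmax over edges grouped by destination (max-subtraction omitted for brevity)
exp = alpha.exp()
denom = torch.zeros(dst_size).index_add_(0, edge_index[1], exp)
alpha = exp / denom[edge_index[1]]

x = torch.arange(3, dtype=torch.float).view(3, 1)           # toy source features
msg = x[edge_index[0]] * alpha.view(-1, 1)
out = torch.zeros(dst_size, 1).index_add_(0, edge_index[1], msg)
print(alpha, out, sep="\n")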
...@@ -10,7 +10,7 @@ from typing import * ...@@ -10,7 +10,7 @@ from typing import *
from starrygl.graph import DistGraph from starrygl.graph import DistGraph
from .gat_conv import GATConv from .gat_conv import GATConv
from .utils import ShrinkData from .utils import ShrinkHelper
...@@ -38,29 +38,36 @@ class ShrinkGATConv(GATConv): ...@@ -38,29 +38,36 @@ class ShrinkGATConv(GATConv):
**kwargs **kwargs
) )
def forward(self, g: DistGraph) -> Tuple[Tensor, Tensor]: def forward(self,
m = g.ndata.get("m") g: DistGraph,
if m is None: sh: Optional[ShrinkHelper] = None,
return super().forward(g), None dst_idx: Optional[Tensor] = None,
) -> Tensor:
if sh is None and dst_idx is None:
return super().forward(g)
if sh is None:
sh = ShrinkHelper(g, dst_idx)
H, C = self.heads, self.out_channels H, C = self.heads, self.out_channels
x = g.ndata["x"] x = g.ndata["x"]
s = ShrinkData(g, m) src_x = x[sh.src_idx]
x = x[s.src_mask] dst_x = x[sh.dst_idx]
edge_index = s.edge_index edge_index = sh.edges
x = (x @ self.weight).view(-1, H, C) src_x = (src_x @ self.weight).view(-1, H, C)
dst_x = (dst_x @ self.weight).view(-1, H, C)
alpha_j = (x * self.att_src).sum(dim=-1) alpha_j = (src_x * self.att_src).sum(dim=-1)
alpha_j = alpha_j[edge_index[0]] alpha_j = alpha_j[edge_index[0]]
alpha_i = (x * self.att_dst).sum(dim=-1) alpha_i = (dst_x * self.att_dst).sum(dim=-1)
alpha_i = alpha_i[edge_index[1]] alpha_i = alpha_i[edge_index[1]]
if self.edge_dim is not None: if self.edge_dim is not None:
edge_attr = g.edata["edge_attr"] edge_attr = g.edata["edge_attr"]
edge_attr = edge_attr[s.edge_mask] edge_attr = edge_attr[sh.edge_idx]
if edge_attr.dim() == 1: if edge_attr.dim() == 1:
edge_attr = edge_attr.view(-1, 1) edge_attr = edge_attr.view(-1, 1)
...@@ -73,13 +80,13 @@ class ShrinkGATConv(GATConv): ...@@ -73,13 +80,13 @@ class ShrinkGATConv(GATConv):
alpha = softmax( alpha = softmax(
src=alpha, src=alpha,
index=edge_index[1], index=edge_index[1],
num_nodes=s.dst_size, num_nodes=sh.dst_size,
) )
alpha = F.dropout(alpha, p=self.dropout, training=self.training) alpha = F.dropout(alpha, p=self.dropout, training=self.training)
x = x[edge_index[0]] * alpha.view(-1, H, 1) x = x[edge_index[0]] * alpha.view(-1, H, 1)
x = scatter_sum(x, edge_index[1], dim=0, dim_size=s.dst_size) x = scatter_sum(x, edge_index[1], dim=0, dim_size=sh.dst_size)
if self.concat: if self.concat:
x = x.view(-1, H * C) x = x.view(-1, H * C)
...@@ -88,5 +95,5 @@ class ShrinkGATConv(GATConv): ...@@ -88,5 +95,5 @@ class ShrinkGATConv(GATConv):
if self.bias is not None: if self.bias is not None:
x += self.bias x += self.bias
return x, s.dst_mask return x
...@@ -8,7 +8,7 @@ from typing import * ...@@ -8,7 +8,7 @@ from typing import *
from starrygl.graph import DistGraph from starrygl.graph import DistGraph
from .gcn_conv import GCNConv from .gcn_conv import GCNConv
from .utils import ShrinkData from .utils import ShrinkHelper
class ShrinkGCNConv(GCNConv): class ShrinkGCNConv(GCNConv):
def __init__(self, def __init__(self,
...@@ -24,24 +24,28 @@ class ShrinkGCNConv(GCNConv): ...@@ -24,24 +24,28 @@ class ShrinkGCNConv(GCNConv):
**kwargs **kwargs
) )
def forward(self, g: DistGraph) -> Tuple[Tensor, Tensor]: def forward(self,
m = g.ndata.get("m") g: DistGraph,
if m is None: sh: Optional[ShrinkHelper] = None,
return super().forward(g), None dst_idx: Optional[Tensor] = None,
) -> Tensor:
if sh is None and dst_idx is None:
return super().forward(g)
if sh is None:
sh = ShrinkHelper(g, dst_idx)
x = g.ndata["x"] x = g.ndata["x"]
gcn_norm = g.edata["gcn_norm"].view(-1, 1) gcn_norm = g.edata["gcn_norm"].view(-1, 1)
s = ShrinkData(g, m) x = x[sh.src_idx]
gcn_norm = gcn_norm[sh.edge_idx]
x = x[s.src_mask] edge_index = sh.edges
gcn_norm = gcn_norm[s.edge_mask]
edge_index = s.edge_index
x = x @ self.weight x = x @ self.weight
x = x[edge_index[0]] * gcn_norm x = x[edge_index[0]] * gcn_norm
x = scatter_sum(x, edge_index[1], dim=0, dim_size=s.dst_size) x = scatter_sum(x, edge_index[1], dim=0, dim_size=sh.dst_size)
if self.bias is not None: if self.bias is not None:
x += self.bias x += self.bias
return x, s.dst_mask return x
...@@ -8,7 +8,7 @@ from typing import * ...@@ -8,7 +8,7 @@ from typing import *
from starrygl.graph import DistGraph from starrygl.graph import DistGraph
from .gin_conv import GINConv from .gin_conv import GINConv
from .utils import ShrinkData from .utils import ShrinkHelper
...@@ -30,17 +30,25 @@ class ShrinkGINConv(GINConv): ...@@ -30,17 +30,25 @@ class ShrinkGINConv(GINConv):
**kwargs **kwargs
) )
def forward(self, g: DistGraph) -> Tuple[Tensor, Tensor]: def forward(self,
m = g.ndata.get("m") g: DistGraph,
if m is None: sh: Optional[ShrinkHelper] = None,
return super().forward(g), None dst_idx: Optional[Tensor] = None,
) -> Tensor:
if sh is None and dst_idx is None:
return super().forward(g)
if sh is None:
sh = ShrinkHelper(g, dst_idx)
x = g.ndata["x"] x = g.ndata["x"]
s = ShrinkData(g, m)
edge_index = s.edge_index
src_x = x[sh.src_idx]
dst_x = x[sh.dst_idx]
edge_index = sh.edges
z = x[s.src_mask][edge_index[0]] z = scatter_sum(src_x[edge_index[0]], index=edge_index[1], dim=0, dim_size=sh.dst_size)
z = scatter_sum(z, edge_index[1], dim=0, dim_size=s.dst_size) x = z + (1 + self.eps) * dst_x
x = z + (1 + self.eps) * x[s.dst_mask] return self.nn(x)
return self.nn(x), s.dst_mask \ No newline at end of file
\ No newline at end of file
import torch # import torch
from torch import Tensor # from torch import Tensor
from typing import * # from typing import *
from starrygl.graph import DistGraph
class ShrinkData: # class ShrinkHelper:
def __init__(self, g: DistGraph, m: Tensor) -> None: # def __init__(self, g, dst_idx: Tensor) -> None:
edge_index = g.edge_index # from starrygl.graph import DistGraph
# g: DistGraph = g
# src_to_dst edge's mask
m = m[edge_index[0]] # self.device = dst_idx.device
# dst's mask # dst_m = torch.zeros(g.dst_size, dtype=torch.bool, device=self.device)
m = edge_index[1][m] # dst_m.index_fill_(0, dst_idx, 1)
dst_m: Tensor = torch.zeros(g.dst_size, dtype=torch.bool, device=m.device).index_fill_(0, m, 1)
self.dst_mask = dst_m
# dst_to_src edge's mask # edge_idx = torch.where(dst_m[g.edge_index[1]])[0]
m = dst_m[edge_index[1]] # edge_index = g.edge_index[:, edge_idx]
self.edge_mask = m
# src's mask # src_idx = edge_index[0]
edge_index = edge_index[:,m] # src_m = torch.zeros(g.src_size, dtype=torch.bool, device=self.device)
src_m: Tensor = torch.zeros(g.src_size, dtype=torch.bool, device=m.device).index_fill_(0, edge_index[0], 1) # src_m.index_fill_(0, src_idx, 1)
self.src_mask = src_m # src_idx = torch.where(src_m)[0]
self._src_size = src_m.count_nonzero() # imp = torch.empty(max(g.src_size, g.dst_size), dtype=torch.long, device=self.device)
self._dst_size = dst_m.count_nonzero()
# print(f"{dst_m.size()}_{self._dst_size}") # imp[src_idx] = torch.arange(src_idx.size(0), dtype=torch.long, device=self.device)
# src = imp[edge_index[0]]
# edge_index # imp[dst_idx] = torch.arange(dst_idx.size(0), dtype=torch.long, device=self.device)
imp = torch.empty(g.src_size, dtype=torch.long, device=m.device) # dst = imp[edge_index[1]]
imp[src_m] = torch.arange(self._src_size, dtype=torch.long, device=m.device)
src = imp[edge_index[0]]
imp = torch.empty(g.dst_size, dtype=torch.long, device=m.device) # self.src_idx = src_idx
imp[dst_m] = torch.arange(self._dst_size, dtype=torch.long, device=m.device) # self.dst_idx = dst_idx
dst = imp[edge_index[1]] # self.edge_idx = edge_idx
# self.edges = torch.vstack([src, dst])
self.edge_index = torch.vstack([src, dst]) # @property
# def src_size(self) -> int:
@property # return self.src_idx.size(0)
def src_size(self) -> int:
return self._src_size.item()
@property # @property
def dst_size(self) -> int: # def dst_size(self) -> int:
return self._dst_size.item() # return self.dst_idx.size(0)
\ No newline at end of file
\ No newline at end of file
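# Toy walk-through (assumed semantics, matching the commented ShrinkHelper above):
# given a subset of destination nodes, keep only the incident edges, collect the
# source nodes they touch, and relabel both endpoints into compact local indices.
import torch

src_size, dst_size = 6, 4
edge_index = torch.tensor([[0, 1, 2, 3, 4, 5],
                           [0, 0, 1, 2, 3, 3]])              # local (src, dst)
dst_idx = torch.tensor([0, 3])                               # destinations to keep

dst_m = torch.zeros(dst_size, dtype=torch.bool)
dst_m[dst_idx] = True
edge_idx = torch.where(dst_m[edge_index[1]])[0]              # edges into kept dsts
sub_edges = edge_index[:, edge_idx]

src_idx = sub_edges[0].unique()                              # sources still needed

remap = torch.full((src_size,), -1, dtype=torch.long)
remap[src_idx] = torch.arange(src_idx.size(0))
src = remap[sub_edges[0]]
remap = torch.full((dst_size,), -1, dtype=torch.long)
remap[dst_idx] = torch.arange(dst_idx.size(0))
dst = remap[sub_edges[1]]

print(edge_idx)                     # tensor([0, 1, 4, 5])
print(torch.vstack([src, dst]))     # tensor([[0, 1, 2, 3], [0, 0, 1, 1]])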
...@@ -7,6 +7,7 @@ import os ...@@ -7,6 +7,7 @@ import os
from .degree import compute_degree, compute_gcn_norm from .degree import compute_degree, compute_gcn_norm
from .sync_bn import SyncBatchNorm from .sync_bn import SyncBatchNorm
from ..core import with_gloo, with_nccl from ..core import with_gloo, with_nccl
from ..core.route import get_executor
def convert_parallel_model( def convert_parallel_model(
net: nn.Module, net: nn.Module,
...@@ -20,7 +21,7 @@ def convert_parallel_model( ...@@ -20,7 +21,7 @@ def convert_parallel_model(
for name, buffer in net.named_buffers(): for name, buffer in net.named_buffers():
if name.endswith("last_embd"): if name.endswith("last_embd"):
continue continue
if name.endswith("grad_norm"): if name.endswith("last_w"):
continue continue
dist.broadcast(buffer, src=0) dist.broadcast(buffer, src=0)
return net return net
...@@ -37,4 +38,6 @@ def init_process_group(backend: str = "gloo") -> torch.device: ...@@ -37,4 +38,6 @@ def init_process_group(backend: str = "gloo") -> torch.device:
torch.cuda.set_device(device) torch.cuda.set_device(device)
else: else:
device = torch.device("cpu") device = torch.device("cpu")
get_executor() # initialize route executor
return device return device
\ No newline at end of file
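# Hedged sketch of the convert_parallel_model idea above: broadcast module buffers
# from rank 0 while skipping buffers (such as the sampling weights) that must stay
# rank-local. Runs standalone with world_size=1 over gloo; the helper name and the
# skipped suffix are illustrative assumptions, not the repository's API.
import os
import torch
import torch.nn as nn
import torch.distributed as dist

def broadcast_buffers_except(net: nn.Module, skip_suffixes=("grad_norm",)):
    for name, buf in net.named_buffers():
        if name.endswith(skip_suffixes):
            continue                        # keep rank-local statistics untouched
        dist.broadcast(buf, src=0)

if __name__ == "__main__":
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("gloo", rank=0, world_size=1)
    net = nn.BatchNorm1d(4)                 # has running_mean / running_var buffers
    broadcast_buffers_except(net)
    dist.destroy_process_group()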
...@@ -6,7 +6,7 @@ from typing import * ...@@ -6,7 +6,7 @@ from typing import *
from torch_scatter import scatter_sum from torch_scatter import scatter_sum
from ..graph.distgraph import DistGraph from ..graph.distgraph import DistGraph
from ..core.gather import Gather # from ..core.gather import Gather
def compute_degree(g: DistGraph) -> Tuple[Tensor, Tensor]: def compute_degree(g: DistGraph) -> Tuple[Tensor, Tensor]:
......
...@@ -22,7 +22,7 @@ def train_epoch( ...@@ -22,7 +22,7 @@ def train_epoch(
if mask is not None: if mask is not None:
pred = pred[mask] pred = pred[mask]
targ = targ[mask] targ = targ[mask]
loss: Tensor = criterion(pred, targ) loss: Tensor = criterion(pred, targ)
opt.zero_grad() opt.zero_grad()
......
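# Minimal sketch of the masked-loss step in train_epoch above: restrict predictions
# and targets to the boolean mask (e.g. the training split) before the criterion.
# Values are random placeholders for the example.
import torch
import torch.nn.functional as F

pred = torch.randn(6, 3)                        # logits for 6 nodes, 3 classes
targ = torch.randint(0, 3, (6,))
mask = torch.tensor([1, 0, 1, 1, 0, 1], dtype=torch.bool)

loss = F.cross_entropy(pred[mask], targ[mask])
print(loss)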
...@@ -13,4 +13,4 @@ def main_print(*args, **kwargs): ...@@ -13,4 +13,4 @@ def main_print(*args, **kwargs):
rank = dist.get_rank() rank = dist.get_rank()
if rank == 0: if rank == 0:
print(*args, **kwargs) print(*args, **kwargs)
dist.barrier() # dist.barrier()
\ No newline at end of file \ No newline at end of file