Commit a4c4cadb by Wenjie Huang

with bugs

parent 06d10ed3
from .a2a import all_to_all, with_gloo, with_nccl
# from .cache import EmbeddingCache
# from .gather import Gather
# from .route import Route, GatherWork
from .route import Route, RouteWorkBase
from .cache import MessageCache, NodeProbe
@@ -30,12 +30,14 @@ class Works:
def all_to_all(
output_tensor_list: List[Tensor],
input_tensor_list: List[Tensor],
group: Optional[Any] = None,
) -> Works:
assert len(output_tensor_list) == len(input_tensor_list)
if with_nccl():
work = dist.all_to_all(
output_tensor_list=output_tensor_list,
input_tensor_list=input_tensor_list,
group=group,
async_op=True,
)
return Works(work)
@@ -48,8 +50,8 @@ def all_to_all(
send_i = (rank + i) % world_size
recv_i = (rank - i + world_size) % world_size
send_w = dist.isend(input_tensor_list[send_i], send_i, group=group)
recv_w = dist.irecv(output_tensor_list[recv_i], recv_i, group=group)
works.push(recv_w, send_w)
output_tensor_list[rank][:] = input_tensor_list[rank]
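# Illustrative usage sketch (not from this commit): exchanging one tensor per rank
# through the all_to_all wrapper above. Assumes torch.distributed is already
# initialized; `group` is optional and defaults to the world group.
#
# world_size = dist.get_world_size()
# send = [torch.full((4,), float(dist.get_rank())) for _ in range(world_size)]
# recv = [torch.zeros(4) for _ in range(world_size)]
# all_to_all(recv, send).wait()   # afterwards recv[i] holds rank i's tensor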
import torch
from torch import Tensor
from typing import *
from multiprocessing.pool import ThreadPool
from .event import Event
class AsyncCopyWorkBase:
def __init__(self) -> None:
self._events: Optional[Tuple[Event, Event]] = None
def wait(self):
raise NotImplementedError
def get(self) -> Tensor:
raise NotImplementedError
def has_events(self) -> bool:
return self._events is not None
def set_events(self, start, end):
self._events = (start, end)
return self
def time_used(self) -> float:
if self._events is None:
raise RuntimeError("not found events")
start, end = self._events
return start.elapsed_time(end)
class AsyncPushWork(AsyncCopyWorkBase):
def __init__(self, data: Tensor, index: Tensor, values: Tensor) -> None:
super().__init__()
assert data.device == index.device
self.set_events(
Event(use_cuda=index.is_cuda),
Event(use_cuda=index.is_cuda),
)
self._events[0].record()
data.index_copy_(0, index, values)
self._events[1].record()
def wait(self):
pass
def get(self):
pass
class AsyncPullWork(AsyncCopyWorkBase):
def __init__(self, data: Tensor, index: Tensor) -> None:
super().__init__()
assert data.device == index.device
self.set_events(
Event(use_cuda=index.is_cuda),
Event(use_cuda=index.is_cuda),
)
self._events[0].record()
self._val = data.index_select(0, index)
self._events[1].record()
def wait(self):
pass
def get(self) -> Tensor:
return self._val
class AsyncOffloadWork(AsyncCopyWorkBase):
def __init__(self, handle) -> None:
super().__init__()
self._handle = handle
def wait(self):
if self._events is not None:
self._events[1].wait(torch.cuda.current_stream())
self._handle.wait()
def get(self) -> Tensor:
if self._events is not None:
self._events[1].wait(torch.cuda.current_stream())
return self._handle.get()
class AsyncCopyExecutor:
def __init__(self) -> None:
self._stream = torch.cuda.Stream()
self._executor = ThreadPool(processes=1)
@torch.no_grad()
def async_pull(self, data: Tensor, index: Tensor) -> AsyncCopyWorkBase:
# ideally all of this code would run in a dedicated thread: easier to time, and easier to make asynchronous
if data.device != index.device:
assert not data.is_cuda
assert index.is_cuda
start = Event(index.is_cuda)
end = Event(index.is_cuda)
stream: Optional[torch.cuda.Stream] = self._stream
def run():
start.wait(stream)
with torch.cuda.stream(stream):
idx = index.to(data.device)
dst = torch.zeros(
size=(index.size(0),) + data.shape[1:],
dtype=data.dtype,
device=index.device,
)
dst.copy_(data[idx])
end.record()
return dst
start.record()
handle = self._executor.apply_async(run)
return AsyncOffloadWork(handle).set_events(start, end)
else:
# data and index in the same device
return AsyncPullWork(data, index)
@torch.no_grad()
def async_push(self, data: Tensor, index: Tensor, values: Tensor) -> AsyncCopyWorkBase:
assert index.device == values.device
if data.device != index.device:
assert not data.is_cuda
assert index.is_cuda
start = Event(index.is_cuda)
end = Event(index.is_cuda)
stream: Optional[torch.cuda.Stream] = self._stream
def run():
start.wait(stream)
with torch.cuda.stream(stream):
idx = index.to(data.device)
val = values.to(data.device)
data[idx] = val
end.record()
start.record()
handle = self._executor.apply_async(run)
return AsyncOffloadWork(handle).set_events(start, end)
else:
return AsyncPushWork(data, index, values)
_THREAD_EXEC: Optional[AsyncCopyExecutor] = None
def get_executor() -> AsyncCopyExecutor:
global _THREAD_EXEC
if _THREAD_EXEC is None:
_THREAD_EXEC = AsyncCopyExecutor()
return _THREAD_EXEC
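# Illustrative usage sketch (not from this commit): overlapping a host-to-GPU
# embedding pull with other work via the shared executor. Assumes CUDA is
# available; `cpu_table` and `gpu_index` are made-up names.
#
# cpu_table = torch.randn(1000, 64)                        # cached rows on CPU
# gpu_index = torch.randint(0, 1000, (128,), device="cuda")
# work = get_executor().async_pull(cpu_table, gpu_index)
# ...                                                      # other GPU work here
# rows = work.get()                                        # (128, 64) rows on CUDA
# print(work.time_used(), "ms")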
import torch
import torch.nn as nn
from torch.utils import hooks
from torch import Tensor
from typing import *
from .route import Route, RouteWorkBase
from .shrink import ShrinkData
from .acopy import get_executor as get_acopy_executor, AsyncCopyWorkBase
from .route import get_executor as get_route_executor
class ProbeLayer:
def __init__(self, layer_id, probe_obj) -> None:
self.layer_id: int = layer_id
self.probe_obj: NodeProbe = probe_obj
self.last_w = torch.ones(self.probe_obj.num_nodes, dtype=torch.float32)
self._val_hook_handle: Optional[hooks.RemovableHandle] = None
def to(self, device):
self.last_w = self.last_w.to(device)
return self
def remove_val_hook(self):
if self._val_hook_handle is not None:
self._val_hook_handle.remove()
self._val_hook_handle = None
def register_val_hook(self, val: Tensor, idx: Optional[Tensor]):
assert self._val_hook_handle is None, "cannot call register_val_hook() twice"
def hook(grad: Tensor):
from starrygl.utils.printer import main_print
import time
self.probe_obj.update_sample_w(self.last_w, grad, idx)
self.remove_val_hook()
self._backward_sample(False)
main_print(self.layer_id, grad.size(0), idx.size(0))
self._val_hook_handle = val.register_hook(hook)
def warmup_sample(self):
assert self._is_last_layer()
for probe_layer in self.probe_obj.layers[::-1]:
probe_layer._backward_sample()
# max_norm = max(probe_layer.last_w.max().item(), 1.0)
# probe_layer.last_w[probe_layer.last_w == 1.0] = max_norm
def _backward_sample(self, x: bool = True):
dst_idx = self._collect_prev_dst_idx()
if dst_idx is None:
return
for cache in self.probe_obj.hook_caches:
next_cache_layer = self._next_cache_layer(cache)
shrink = next_cache_layer.set_shrink_data(dst_idx)
src_idx = shrink.src_idx
work = get_acopy_executor().async_pull(next_cache_layer.data, src_idx)
next_cache_layer.sync_pull_work(work)
if not self._is_first_layer() and x:
next_cache_layer.sync_backward_sample_work(src_idx)
def _collect_prev_dst_idx(self) -> Optional[Tensor]:
if len(self.probe_obj.hook_caches) == 0:
return None
if self._is_last_layer():
return self._wrapped_sample(self.probe_obj.num_samples, None)
if len(self.probe_obj.hook_caches) == 1:
for cache in self.probe_obj.hook_caches:
prev_cache_layer = self._prev_cache_layer(cache)
dst_idx = prev_cache_layer.sync_backward_sample_work()
else:
tmp = torch.zeros(self.probe_obj.num_nodes, dtype=torch.bool, device=self.last_w.device)
for cache in self.probe_obj.hook_caches:
prev_cache_layer = self._prev_cache_layer(cache)
prev_idx = prev_cache_layer.sync_backward_sample_work()
if prev_idx is not None:
tmp.index_fill_(0, prev_idx, 1)
dst_idx = torch.where(tmp)[0]
if dst_idx.size(0) == 0:
dst_idx = None
if dst_idx is None:
return None
return self._wrapped_sample(self.probe_obj.num_samples, dst_idx)
def _prev_cache_layer(self, cache):
cache: MessageCache = cache
i = self.layer_id + 1
if i < self.probe_obj.num_layers:
return cache.layers[i]
return None
def _next_cache_layer(self, cache):
cache: MessageCache = cache
return cache.layers[self.layer_id]
def _is_last_layer(self):
return self.layer_id + 1 >= self.probe_obj.num_layers
def _is_first_layer(self):
return self.layer_id == 0
def _wrapped_sample(self, k: int, idx: Optional[Tensor]) -> Optional[Tensor]:
w = self.last_w
if idx is None:
return self.probe_obj.sample(w, k) if 0 < k and k < self.probe_obj.num_nodes else None
if 0 < k and k < idx.size(0):
t = self.probe_obj.sample(w[idx], k)
return idx[t]
else:
return idx
class NodeProbe:
def __init__(self,
num_nodes: int,
num_layers: int,
num_samples: int = 0,
p: str = "fro",
dim: int = -1,
beta: float = 1.0,
) -> None:
super().__init__()
self.num_nodes = num_nodes
self.num_layers = num_layers
self.num_samples = num_samples
self.p = p
self.dim = dim
self.beta = beta
self.layers = [ProbeLayer(i, self) for i in range(num_layers)]
self.hook_caches: Set[MessageCache] = set()
def to(self, device):
for i in range(self.num_layers):
self.layers[i].to(device)
return self
def assign_message_cache(self, cache):
cache: MessageCache = cache
assert self.num_layers == cache.num_layers
self.hook_caches.add(cache)
def update_sample_w(self, last_w: Tensor, grad: Tensor, idx: Optional[Tensor]):
val_norm = grad.norm(p=self.p, dim=self.dim)
if idx is None:
last_w[:] = val_norm
else:
if self.beta != 1.0:
last_w.mul_(self.beta)
last_w[idx] = val_norm
def sample(self, w: Tensor, k: int) -> Optional[Tensor]:
w = w / w.sum()
return torch.multinomial(w, num_samples=k, replacement=False)
def apply(self, i: int, val: Tensor, idx: Optional[Tensor]) -> Tensor:
self.layers[i].register_val_hook(val, idx)
return val
class CacheLayer:
def __init__(self, layer_id, cache_obj, num_features) -> None:
self.layer_id: int = layer_id
self.cache_obj: MessageCache = cache_obj
self.data: Tensor = torch.randn(
size=(self.cache_obj.src_size, num_features),
dtype=torch.float32,
device=self.cache_obj.cache_device,
)
self.state: Dict[str, Any] = {}
self.shrink: Optional[ShrinkData] = None
self._async_push_work: Optional[AsyncCopyWorkBase] = None
self._async_pull_work: Optional[AsyncCopyWorkBase] = None
self._backward_sample_work: Optional[RouteWorkBase] = None
def to(self, device):
if self.shrink is not None:
self.shrink = self.shrink.to(device)
return self
def cache_data_to(self):
self.data = self.data.to(self.cache_obj.cache_device)
return self
def set_shrink_data(self, dst_idx: Optional[Tensor]) -> Optional[ShrinkData]:
self.shrink = self.cache_obj.new_shrink_data(dst_idx)
return self.shrink
def get_shrink_data(self) -> Optional[ShrinkData]:
return self.shrink
# def set_backward_sample_work(self, src_idx: Optional[Tensor]):
# if src_idx is None:
# return
# assert self._backward_sample_work is None
# self._backward_sample_work = self.cache_obj.route.backward_a2a(src_idx, src_idx)
# def pop_backward_sample_work(self) -> Optional[RouteWorkBase]:
# work = self._backward_sample_work
# self._backward_sample_work = None
# return work
def sync_backward_sample_work(self, src_idx: Optional[Tensor] = None) -> Optional[Tensor]:
dst_idx = None
if self._backward_sample_work is not None:
_, dst_idx = self._backward_sample_work.get()
if src_idx is not None:
# self._backward_sample_work = self.cache_obj.route.backward_a2a(src_idx, src_idx)
work = get_route_executor().async_backward_a2a(src_idx, src_idx, self.cache_obj.route)
self._backward_sample_work = work
else:
self._backward_sample_work = None
return dst_idx
def sync_push_work(self, work: Optional[AsyncCopyWorkBase] = None):
if self._async_push_work is not None:
self._async_push_work.get()
self._async_push_work = work
def sync_pull_work(self, work: Optional[AsyncCopyWorkBase] = None) -> Optional[Tensor]:
out = None
if self._async_pull_work is not None:
out = self._async_pull_work.get()
self._async_pull_work = work
return out
class MessageCache:
def __init__(self,
src_ids: Tensor,
dst_ids: Tensor,
edge_index: Tensor,
num_features: Union[int, torch.Size],
num_layers: int,
cache_device: Union[str, torch.device, None],
bipartite: bool = False,
) -> None:
self.src_size = src_ids.size(0)
self.dst_size = dst_ids.size(0)
self.edge_index = edge_index
self.num_layers = num_layers
if cache_device is None:
cache_device = edge_index.device
self.cache_device = cache_device
self.bipartite = bipartite
self.route = Route(dst_ids, src_ids, bipartite=bipartite)
self.layers = [CacheLayer(i, self, num_features) for i in range(num_layers)]
def to(self, device):
self.edge_index = self.edge_index.to(device)
self.route = self.route.to(device)
for i in range(self.num_layers):
self.layers[i].to(device)
return self
def cached_data_to(self, device):
self.cache_device = torch.device(device)
for i in range(self.num_layers):
self.layers[i].cache_data_to()
return self
# @property
# def is_offload(self) -> bool:
# return self.edge_index.device != self.cache_device
def new_shrink_data(self, dst_idx: Optional[Tensor]) -> Optional[ShrinkData]:
if dst_idx is None:
return None
return ShrinkData(
src_size=self.src_size,
dst_size=self.dst_size,
dst_idx=dst_idx,
edge_index=self.edge_index,
bipartite=self.bipartite,
)
# def clear_shrink_data(self):
# for i in range(self.num_layers):
# self.layers[i].shrink = None
def replace_layer_data(self, i: int, data: Tensor):
layer = self.layers[i]
assert layer.data.size(0) == data.size(0)
layer.data = data
return self
def update_cache(self,
i: int,
val: Tensor,
idx: Optional[Tensor],
async_op: bool = False,
) -> Tuple[Tensor, Optional[Tensor]]:
layer: CacheLayer = self.layers[i]
src_val, src_idx = self.route.apply(val, idx, layer.state, async_op=async_op)
# full graph
if idx is None:
return src_val, None
shrink = layer.get_shrink_data()
if shrink is None: # there may be a problem here, at initialization time
data = torch.empty_like(layer.data, device=val.device).copy_(layer.data)
data.index_copy_(0, src_idx, src_val)
return data, src_idx
else:
push_work = get_acopy_executor().async_push(layer.data, src_idx, src_val)
layer.sync_push_work(push_work)
data = layer.sync_pull_work()
sval, sidx = shrink.shrink_src_val_and_idx(src_val, src_idx)
data.index_copy_(0, sidx, sval)
return data, sidx
# first, update the cache asynchronously
# meanwhile, download the cached values asynchronously
# finally, merge the local and remote caches and return the result
def fetch_pull_tensor(self, i: int) -> Optional[Tensor]:
return self.layers[i].sync_pull_work()
# def _update_cache_impl(self,
# i: int,
# dst_val: Tensor,
# dst_idx: Optional[Tensor],
# route: Route,
# async_op: bool,
# ):
# # communcation
# state = self.layers[i].state
# src_val, src_idx = route.apply(dst_val, dst_idx, state, async_op=async_op)
# # push latest embeddings
# data = self.layers[i].data
# if src_idx is None:
# data[:] = src_val
# else:
# data[src_idx] = src_val
# # get previous generated shrink data
# shr: ShrinkData = self.layers[i].shrink
# if shr is None:
# return src_val, src_idx
# else:
# # pull latest embeddings
# return data[shr.src_idx], shr.src_idx
# def _update_cache_offload(self,
# i: int,
# dst_val: Tensor,
# dst_idx: Optional[Tensor],
# route: Route,
# async_op: bool,
# ):
# raise NotImplementedError
# def _compute_cached_data_size(self) -> torch.Size:
# num_features = self.num_features
# if num_features is None:
# cached_data_size = torch.Size([self.src_size,])
# else:
# cached_data_size = torch.Size([self.src_size, num_features])
# return cached_data_size
# def _compute_cpu_buf_size(self) -> torch.Size:
# num_features = self.num_features
# if num_features is None:
# cpu_buf_size = torch.Size([2**32,])
# else:
# cpu_buf_size = torch.Size([2**32 // num_features, num_features])
# return cpu_buf_size
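# Illustrative wiring sketch (not from this commit): one way MessageCache and
# NodeProbe appear intended to cooperate in a layer-wise forward pass. The
# tensors (src_ids, dst_ids, edge_index, h, idx) and sizes are hypothetical.
#
# cache = MessageCache(src_ids, dst_ids, edge_index,
#                      num_features=64, num_layers=2, cache_device="cpu")
# probe = NodeProbe(cache.dst_size, num_layers=2, num_samples=1024)
# probe.assign_message_cache(cache)
# for i in range(cache.num_layers):
#     h = probe.apply(i, h, idx)              # registers the gradient hook
#     h, idx = cache.update_cache(i, h, idx)  # exchange + cache update
#     # ... message passing on the shrunken graph goes here ...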
import torch
import time
from torch import Tensor
from typing import *
class Event:
def __init__(self, use_cuda: bool = True) -> None:
self._use_cuda = use_cuda
if use_cuda:
self._event = torch.cuda.Event(enable_timing=True)
else:
self._time: Optional[float] = None
def record(self):
if self._use_cuda:
self._event.record()
else:
self._time = time.time()
def wait(self, stream: Optional[torch.cuda.Stream] = None):
if not self._use_cuda:
return
self._event.wait(stream)
def elapsed_time(self, other) -> float:
if self._use_cuda:
return self._event.elapsed_time(other._event)
else:
return (other._time - self._time) * 1000.0
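# Illustrative usage sketch (not from this commit): timing a region of work with
# the Event wrapper on either CPU or GPU.
#
# use_cuda = torch.cuda.is_available()
# start, end = Event(use_cuda), Event(use_cuda)
# start.record()
# run_workload()                  # hypothetical workload
# end.record()
# if use_cuda:
#     torch.cuda.synchronize()    # CUDA events must complete before reading
# print(start.elapsed_time(end), "ms")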
import torch # import torch
import torch.nn as nn # import torch.nn as nn
import torch.autograd as autograd # import torch.autograd as autograd
from torch import Tensor # from torch import Tensor
from contextlib import contextmanager # from contextlib import contextmanager
from typing import * # from typing import *
from .route import Route, GatherWork # from .route import Route, GatherWork
from .cache import EmbeddingCache # # from .cache import CachedEmbeddings
class GatherContext: # class GatherContext:
def __init__(self, # def __init__(self,
this, # this,
route: Route, # route: Route,
async_op: bool, # async_op: bool,
) -> None: # ) -> None:
self.this = this # self.this = this
self.route = route # self.route = route
self.async_op = async_op # self.async_op = async_op
class Gather(nn.Module): # class Gather(nn.Module):
def __init__(self, # def __init__(self,
num_nodes: int, # num_nodes: int,
num_features: Optional[int] = None, # num_features: Optional[int] = None,
beta: float = 1.0, # beta: float = 1.0,
) -> None: # ) -> None:
super().__init__() # super().__init__()
self.num_nodes = num_nodes # self.beta = beta
self.num_features = num_features # if num_features is None:
self.beta = beta # self.register_buffer("last_embd", torch.zeros(num_nodes))
if num_features is None: # else:
self.register_buffer("last_embd", torch.zeros(num_nodes)) # self.register_buffer("last_embd", torch.zeros(num_nodes, num_features))
else:
self.register_buffer("last_embd", torch.zeros(num_nodes, num_features)) # self.last_fw_work: Optional[GatherWork] = None
self.last_fw_work: Optional[GatherWork] = None # self.last_bw_work: Optional[GatherWork] = None
self.last_bw_work: Optional[GatherWork] = None # self.reset_parameters()
self.reset_parameters()
# def forward(self,
def forward(self, # val: Tensor,
values: Tensor, # idx: Optional[Tensor],
active: Optional[Tensor], # route: Route,
route: Route, # async_op: bool = False,
async_op: bool = False, # ) -> Tuple[Tensor, Tensor]:
) -> Tuple[Tensor, Tensor]: # with self._manager(route=route, async_op=async_op):
with self._gather_manager(route=route, async_op=async_op): # return GatherFunction.apply(val, idx)
return GatherFunction.apply(values, active)
# def reset_parameters(self):
def reset_parameters(self): # last_embd = self.get_buffer("last_embd")
last_embd = self.get_buffer("last_embd") # nn.init.normal_(last_embd, mean=0, std=1)
nn.init.normal_(last_embd, mean=0, std=1)
# def fuse_embeddings(self,
def fuse_embeddings(self, # val: Tensor,
values: Tensor, # idx: Optional[Tensor] = None,
active: Optional[Tensor] = None, # inplace: bool = False,
inplace: bool = False, # ) -> Tensor:
) -> Tensor: # last_embd = self.get_buffer("last_embd")
last_embd = self.get_buffer("last_embd") # return GatherFuseFunction.apply(val, idx, last_embd, self.beta, inplace, self.training)
return GatherFuseFunction.apply(values, active, last_embd, self.beta, inplace, self.training)
# if not inplace: # @contextmanager
# last_embd = last_embd.clone() # def _manager(self, route: Route, async_op: bool):
# global _global_gather_context
# if active is None: # stacked = _global_gather_context
# if self.beta != 1.0:
# values = values * self.beta + last_embd * (1 - self.beta) # try:
# last_embd[:] = values # _global_gather_context = GatherContext(
# else: # this=self,
# if values.size(0) == active.size(0): # route=route,
# values = values[active] # async_op=async_op,
# if self.beta != 1.0: # )
# values = values * self.beta + last_embd[active] * (1 - self.beta) # yield _global_gather_context
# last_embd[active] = values # finally:
# return last_embd # _global_gather_context = stacked
@contextmanager # class GatherFunction(autograd.Function):
def _gather_manager(self, route: Route, async_op: bool): # @staticmethod
global _global_gather_context # def forward(
stacked = _global_gather_context # ctx: autograd.function.FunctionCtx,
# val: Tensor,
try: # idx: Optional[Tensor] = None,
_global_gather_context = GatherContext( # ):
this=self, # gather_ctx = _last_global_gather_context()
route=route, # this: Gather = gather_ctx.this
async_op=async_op, # route: Route = gather_ctx.route
) # async_op: bool = gather_ctx.async_op
yield _global_gather_context # return_idx: bool = idx is not None
finally:
_global_gather_context = stacked # current_work = route.gather_forward(val, idx, async_op=async_op, return_idx=return_idx)
# if async_op:
class GatherFunction(autograd.Function): # work = this.last_fw_work or current_work
@staticmethod # this.last_fw_work = current_work
def forward( # else:
ctx: autograd.function.FunctionCtx, # work = current_work
values: Tensor,
active: Optional[Tensor] = None, # recv_val, recv_idx = work.get()
): # recv_val = recv_val if recv_idx is None else recv_val[recv_idx]
gather_ctx = _last_global_gather_context() # recv_val = this.fuse_embeddings(recv_val, recv_idx, inplace=True)
this: Gather = gather_ctx.this
route: Route = gather_ctx.route # if this.training:
async_op: bool = gather_ctx.async_op # ctx.save_for_backward(idx, recv_idx)
# ctx.this = this
current_work = route.gather_forward(values, active, async_op=async_op) # ctx.route = route
if async_op: # ctx.async_op = async_op
work = this.last_fw_work or current_work # return recv_val, recv_idx
this.last_fw_work = current_work
else: # @staticmethod
work = current_work # def backward(
# ctx: autograd.function.FunctionCtx,
recv_values, recv_active = work.get() # val_grad: Tensor,
recv_values = this.fuse_embeddings( # idx_grad: Optional[Tensor],
values=recv_values, # ):
active=recv_active, # this: Gather = ctx.this
inplace=True, # route: Route = ctx.route
) # async_op: bool = ctx.async_op
if this.training: # with torch.no_grad():
# 如果输入的values是收缩过的,求解梯度的时候需要去除空洞 # recv_idx, idx_grad = ctx.saved_tensors
if active is not None and values.size(0) < active.size(0): # if idx_grad is not None:
ctx.shrink_grad = True # val_grad = val_grad[idx_grad]
ctx.save_for_backward(active, recv_active)
else: # current_work = route.gather_backward(val_grad, idx_grad, async_op=async_op, return_idx=False)
ctx.shrink_grad = False # if async_op:
ctx.save_for_backward(recv_active) # work = this.last_bw_work or current_work
ctx.this = this # this.last_bw_work = current_work
ctx.route = route # else:
ctx.async_op = async_op # work = current_work
return recv_values, recv_active
# recv_val = work.get_val()
@staticmethod # if recv_idx is not None:
def backward( # recv_val = recv_val[recv_idx]
ctx: autograd.function.FunctionCtx, # return recv_val, None
grad_values: Tensor,
grad_active: Optional[Tensor], # class GatherFuseFunction(autograd.Function):
): # @staticmethod
this: Gather = ctx.this # def forward(
route: Route = ctx.route # ctx: autograd.function.FunctionCtx,
async_op: bool = ctx.async_op # val: Tensor,
shrink_grad: bool = ctx.shrink_grad # idx: Optional[Tensor],
# last_embd: Tensor,
with torch.no_grad(): # beta: float,
# # 反向传播激活值是沿着前向传播的反方向进行 # inplace: bool,
if shrink_grad: # training: bool,
recv_active, grad_active = ctx.saved_tensors # ):
else: # if not inplace:
grad_active, = ctx.saved_tensors # last_embd = last_embd.clone()
# ctx.beta = beta
current_work = route.gather_backward(grad_values, grad_active, async_op=async_op)
if async_op: # if idx is None:
work = this.last_bw_work or current_work # assert val.size(0) == last_embd.size(0)
this.last_bw_work = current_work # if beta != 1.0:
else: # last_embd.mul_(1 - beta).add_(val * beta)
work = current_work # else:
# last_embd[:] = (val)
recv_values = work.get_values() # else:
# assert val.size(0) == idx.size(0)
if shrink_grad: # if beta != 1.0:
recv_values = recv_values[recv_active] # last_embd[idx] = last_embd[idx] * (1 - beta) + val * beta
return recv_values, None # else:
# last_embd[idx] = val
class GatherFuseFunction(autograd.Function): # if training:
@staticmethod # ctx.beta = beta
def forward( # ctx.save_for_backward(idx)
ctx: autograd.function.FunctionCtx, # return last_embd
values: Tensor,
active: Optional[Tensor], # @staticmethod
last_embd: Tensor, # def backward(
beta: float, # ctx: autograd.function.FunctionCtx,
inplace: bool, # grad: Tensor,
training: bool, # ):
): # beta: float = ctx.beta
if not inplace: # idx, = ctx.saved_tensors
last_embd = last_embd.clone()
ctx.beta = beta # if idx is not None:
# grad = grad[idx]
if active is None:
if beta != 1.0: # if beta != 1.0:
values = values * beta + last_embd * (1 - beta) # grad = grad * beta
last_embd[:] = values
ctx.shrink_grad = False # return grad, None, None, None, None, None
else:
if values.size(0) == active.size(0): # #### private functions
values = values[active] # _global_gather_context: Optional[GatherContext] = None
ctx.shrink_grad = False # def _last_global_gather_context() -> GatherContext:
else: # global _global_gather_context
if training: # assert _global_gather_context is not None
ctx.save_for_backward(active) # return _global_gather_context
ctx.shrink_grad = True
if beta != 1.0:
values = values * beta + last_embd[active] * (1 - beta)
last_embd[active] = values
return last_embd
@staticmethod
def backward(
ctx: autograd.function.FunctionCtx,
grad: Tensor,
):
if ctx.shrink_grad:
active, = ctx.saved_tensors
return grad[active] * ctx.beta, None, None, None, None, None
else:
return grad * ctx.beta, None, None, None, None, None
#### private functions
_global_gather_context: Optional[GatherContext] = None
def _last_global_gather_context() -> GatherContext:
global _global_gather_context
assert _global_gather_context is not None
return _global_gather_context
import torch
import torch.autograd as autograd
import torch.distributed as dist
from multiprocessing.pool import ThreadPool
from torch import Tensor
from typing import *
from contextlib import contextmanager
from .a2a import all_to_all, Works, with_nccl
from .event import Event
# class GatherWork:
# def __init__(self,
# values_buffer: Tensor,
# active_buffer: Optional[Tensor],
# recv_val_idx: List[Tensor],
# recv_val_dat: List[Tensor],
# works_list: List[Works],
# ) -> None:
# self._waited = False
# self._values_buffer = values_buffer
# self._active_buffer = active_buffer
# self._recv_val_idx = recv_val_idx
# self._recv_val_dat = recv_val_dat
# self._works_list = works_list
# self._cached_idx: Optional[Tensor] = None
# def _wait(self) -> None:
# assert not self._waited
# for w in self._works_list:
# if w is None:
# continue
# w.wait()
# if self._active_buffer is None:
# for idx, dat in zip(self._recv_val_idx, self._recv_val_dat):
# self._values_buffer[idx] += dat
# else:
# for idx, dat in zip(self._recv_val_idx, self._recv_val_dat):
# self._values_buffer[idx] += dat
# self._active_buffer[idx] = True
# self._active_buffer = torch.where(self._active_buffer)[0]
# self._works_list = []
# self._recv_val_dat = []
# self._recv_val_idx = []
# self._waited = True
# def get_val(self) -> Tensor:
# if not self._waited:
# self._wait()
# return self._values_buffer
# def get_idx(self) -> Optional[Tensor]:
# if not self._waited:
# self._wait()
# return self._active_buffer
# def get(self) -> Tuple[Tensor, Optional[Tensor]]:
# return self.get_val(), self.get_idx()
class RouteWorkBase:
def __init__(self) -> None:
self._events: Optional[Tuple[Event, Event]] = None
def wait(self) -> None:
raise NotImplementedError
def get(self) -> Tuple[Tensor, Optional[Tensor]]:
raise NotImplementedError
def has_events(self) -> bool:
return self._events is not None
def set_events(self, start, end):
self._events = (start, end)
return self
def time_used(self) -> float:
if self._events is None:
raise RuntimeError("not found events")
start, end = self._events
return start.elapsed_time(end)
class RouteWork(RouteWorkBase):
def __init__(self,
buf: Union[Tensor, int],
recv_val_idx: List[Tensor],
recv_val_dat: List[Tensor],
works_list: List[Works],
) -> None:
super().__init__()
assert len(recv_val_idx) != 0
assert len(recv_val_idx) == len(recv_val_dat)
self._buf = buf
self._recv_val_idx = recv_val_idx
self._recv_val_dat = recv_val_dat
self._works_list = works_list
self._val: Optional[Tensor] = None
self._idx: Optional[Tensor] = None
def wait(self) -> None:
if self._val is not None:
return
for w in self._works_list:
if w is None:
continue
w.wait()
ikw = dict(dtype=self._recv_val_idx[0].dtype, device=self._recv_val_idx[0].device)
fkw = dict(dtype=self._recv_val_dat[0].dtype, device=self._recv_val_dat[0].device)
if isinstance(self._buf, Tensor):
for idx in self._recv_val_idx:
self._buf[idx] = True
self._idx = torch.where(self._buf)[0]
imp = torch.empty(self._buf.size(0), **ikw).fill_((2**62-1)*2+1)
imp[self._idx] = torch.arange(self._idx.size(0), **ikw)
s = self._idx.shape[:1] + self._recv_val_dat[0].shape[1:]
self._val = torch.zeros(s, **fkw)
for idx, dat in zip(self._recv_val_idx, self._recv_val_dat):
idx = imp[idx]
self._val[idx] += dat
else:
s = (self._buf,) + self._recv_val_dat[0].shape[1:]
self._val = torch.zeros(s, **fkw)
for idx, dat in zip(self._recv_val_idx, self._recv_val_dat):
self._val[idx] += dat
self._buf = None
self._recv_val_dat = None
self._recv_val_idx = None
self._works_list = None
if self._events is not None:
self._events[1].record()
def get(self) -> Tuple[Tensor, Optional[Tensor]]:
if self._val is None:
self.wait()
return self._val, self._idx
class RouteExecutorWork(RouteWorkBase):
def __init__(self, handle) -> None:
super().__init__()
self._handle = handle
self._events: Optional[Tuple[Event, Event]] = None
def wait(self) -> None:
if self._events is not None:
self._events[1].wait(torch.cuda.current_stream())
return self._handle.wait()
def get(self) -> Tuple[Tensor, Optional[Tensor]]:
if self._events is not None:
self._events[1].wait(torch.cuda.current_stream())
return self._handle.get()
class RouteExecutor:
def __init__(self) -> None:
self._stream = torch.cuda.Stream()
self._executor = ThreadPool(processes=1)
self._group = dist.new_group()
def async_forward_a2a(self, val: Tensor, idx: Optional[Tensor], route) -> RouteWorkBase:
return self._async_a2a_impl(route.forward_a2a, val, idx)
def async_backward_a2a(self, val: Tensor, idx: Optional[Tensor], route) -> RouteWorkBase:
return self._async_a2a_impl(route.backward_a2a, val, idx)
def _async_a2a_impl(self, func, val: Tensor, idx: Optional[Tensor]) -> RouteWorkBase:
start = Event(use_cuda=val.is_cuda)
end = Event(use_cuda=val.is_cuda)
stream: Optional[torch.cuda.Stream] = self._stream
def run():
start.wait(stream)
with torch.cuda.stream(stream):
ret = func(val, idx, group=self._group).get()
end.record()
return ret
start.record()
handle = self._executor.apply_async(run)
return RouteExecutorWork(handle).set_events(start, end)
# def apply_async(self, func, val, idx, async_op) -> RouteWorkBase:
# start = Event(val.is_cuda)
# end = Event(val.is_cuda)
# if not async_op:
# start.record()
# return func(val, idx).set_events(start, end)
# stream: Optional[torch.cuda.Stream] = self._stream
# def run():
# start.wait(stream)
# with torch.cuda.stream(stream):
# ret = func(val, idx, group=self._group).get()
# end.record()
# return ret
# start.record()
# handle = self._executor.apply_async(run)
# return RouteExecutorWork(handle).set_events(start, end)
class RouteCtx:
def __init__(self, route, state, async_op) -> None:
self.route = route
self.state = state
self.async_op = async_op
class Route:
def __init__(self,
src_ids: Tensor,
dst_ids: Tensor,
bipartite: bool = False,
):
assert src_ids.dtype == torch.long
assert dst_ids.dtype == torch.long
assert src_ids.dim() == 1
assert dst_ids.dim() == 1
if not bipartite:
# the data must be vertex-partitioned: the src_ids of all partitions are pairwise disjoint
assert src_ids.size(0) <= dst_ids.size(0)
assert (dst_ids[:src_ids.size(0)] == src_ids).all()
rank = dist.get_rank()
world_size = dist.get_world_size()
@@ -135,146 +307,347 @@ class Route:
self.backward_routes.append(bw_route)
# send each fw_route to the partition owning its src_ids to build the final routing table
self.forward_routes = _fix_fw_routes(self.forward_routes)
# if the homogeneous-graph condition holds, add a self-loop route for every local node
if not bipartite:
rank_ind = torch.arange(src_ids.size(0), **ikw)
rank_route = torch.vstack([rank_ind, rank_ind])
self.forward_routes[rank] = rank_route
self.backward_routes[rank] = rank_route
self.total_time_used: float = 0.0
dist.barrier()
def to(self, device):
self.forward_routes = [ro.to(device) for ro in self.forward_routes]
self.backward_routes = [ro.to(device) for ro in self.backward_routes]
return self
def _a2a_impl(self,
val: Tensor,
idx: Optional[Tensor],
send_buf_size: int,
recv_buf_size: int,
send_routes: List[Tensor],
recv_routes: List[Tensor],
group: Optional[Any],
) -> RouteWork:
bkw = dict(dtype=torch.bool, device=val.device)
ikw = dict(dtype=torch.long, device=val.device)
fkw = dict(dtype=val.dtype, device=val.device)
# when idx is given, only the marked values are sent to neighbour partitions; otherwise all values are sent
if idx is not None:
msk = torch.zeros(send_buf_size, **bkw).index_fill_(0, idx, 1)
send_routes = [ro[:,msk[ro[0]]] for ro in send_routes]
send_sizes = torch.tensor([ro.size(1) for ro in send_routes], **ikw)
recv_sizes = torch.zeros_like(send_sizes)
dist.all_to_all_single(recv_sizes, send_sizes, group=group)
recv_sizes = recv_sizes.tolist()
if idx is None:
send_val_dat, recv_val_dat = [], []
recv_val_idx = [ro[0] for ro in recv_routes]
for i, ro in enumerate(recv_routes):
s = ro.size(1)
c = (s,) + val.shape[1:]
send_val_dat.append(val[send_routes[i][0]])
recv_val_dat.append(torch.zeros(c, **fkw))
works_list = [
all_to_all(recv_val_dat, send_val_dat, group=group),
]
return RouteWork(
buf=recv_buf_size,
recv_val_idx=recv_val_idx,
recv_val_dat=recv_val_dat,
works_list=works_list,
)
else:
imp = torch.empty(send_buf_size, **ikw).fill_((2**62-1)*2+1)
imp[idx] = torch.arange(idx.size(0), **ikw)
send_val_idx, recv_val_idx = [], []
send_val_dat, recv_val_dat = [], []
for i, s in enumerate(recv_sizes):
c = (s,) + val.shape[1:]
send_val_idx.append(send_routes[i][1])
recv_val_idx.append(torch.zeros(s, **ikw))
send_index = imp[send_routes[i][0]]
send_val_dat.append(val[send_index])
recv_val_dat.append(torch.zeros(c, **fkw))
works_list = [
all_to_all(recv_val_idx, send_val_idx, group=group),
all_to_all(recv_val_dat, send_val_dat, group=group),
]
recv_buf = torch.zeros(recv_buf_size, **bkw)
return RouteWork(
buf=recv_buf,
recv_val_idx=recv_val_idx,
recv_val_dat=recv_val_dat,
works_list=works_list,
)
def forward_a2a(self,
val: Tensor,
idx: Optional[Tensor] = None,
group: Optional[Any] = None,
) -> RouteWork:
if idx is None:
assert val.size(0) == self.src_size
else:
assert val.size(0) == idx.size(0)
if group is None:
start = Event(use_cuda=val.is_cuda)
end = Event(use_cuda=val.is_cuda)
start.record()
work = self._a2a_impl(
val=val, idx=idx,
send_buf_size=self.src_size,
recv_buf_size=self.dst_size,
send_routes=self.forward_routes,
recv_routes=self.backward_routes,
group=group,
)
if group is None:
return work.set_events(start, end)
return work
def backward_a2a(self,
val: Tensor,
idx: Optional[Tensor] = None,
group: Optional[Any] = None,
) -> RouteWork:
if idx is None:
assert val.size(0) == self.dst_size
else:
assert val.size(0) == idx.size(0)
if group is None:
start = Event(use_cuda=val.is_cuda)
end = Event(use_cuda=val.is_cuda)
start.record()
work = self._a2a_impl(
val=val, idx=idx,
send_buf_size=self.dst_size,
recv_buf_size=self.src_size,
send_routes=self.backward_routes,
recv_routes=self.forward_routes,
group=group,
)
if group is None:
return work.set_events(start, end)
return work
def apply(self,
val: Tensor,
idx: Optional[Tensor],
state: Dict[str, Any],
async_op: bool = False,
) -> Tuple[Tensor, Optional[Tensor]]:
with self._a2a_manager(state, async_op):
return RouteFunction.apply(val, idx)
@contextmanager
def _a2a_manager(self, state, async_op) -> RouteCtx:
global _global_route_context
stacked_ctx = _global_route_context
try:
_global_route_context = RouteCtx(self, state, async_op)
yield _global_route_context
finally:
_global_route_context = stacked_ctx
# def _gather_impl(self,
# values_buffer: Tensor,
# active_buffer: Optional[Tensor],
# values: Tensor,
# active: Optional[Tensor],
# send_routes: List[Tensor],
# recv_routes: List[Tensor],
# async_op: bool,
# ):
# ikw = dict(dtype=torch.long, device=values.device)
# fkw = dict(dtype=values.dtype, device=values.device)
# # 当定义active时,只传输active标记的values到邻居分区,否则传输所有values
# if active is not None:
# # 计算新的路由表
# active_routes = [ro[:,active[ro[0]]] for ro in send_routes]
# # 计算接收缓冲区大小
# scatter_sizes = [ro.size(1) for ro in active_routes]
# scatter_sizes = torch.tensor(scatter_sizes, **ikw)
# gather_sizes = torch.zeros_like(scatter_sizes)
# dist.all_to_all_single(gather_sizes, scatter_sizes)
# gather_sizes = gather_sizes.tolist()
# def async_run():
# if active is not None:
# send_val_idx, recv_val_idx = [], []
# send_val_dat, recv_val_dat = [], []
# for i, s in enumerate(gather_sizes):
# c = (s,) + values.shape[1:]
# send_val_idx.append(active_routes[i][1])
# recv_val_idx.append(torch.zeros(s, **ikw))
# send_val_dat.append(values[active_routes[i][0]])
# recv_val_dat.append(torch.zeros(c, **fkw))
# works_list = [
# all_to_all(recv_val_idx, send_val_idx),
# all_to_all(recv_val_dat, send_val_dat),
# ]
# return GatherWork(
# values_buffer=values_buffer,
# active_buffer=active_buffer,
# recv_val_idx=recv_val_idx,
# recv_val_dat=recv_val_dat,
# works_list=works_list,
# )
# else:
# send_val_dat, recv_val_dat = [], []
# recv_val_idx = [ro[0] for ro in recv_routes]
# for i, r in enumerate(recv_routes):
# s = r.size(1)
# c = (s,) + values.shape[1:]
# send_val_dat.append(values[send_routes[i][0]])
# recv_val_dat.append(torch.zeros(c, **fkw))
# works_list = [
# all_to_all(recv_val_dat, send_val_dat),
# ]
# return GatherWork(
# values_buffer=values_buffer,
# active_buffer=active_buffer,
# recv_val_idx=recv_val_idx,
# recv_val_dat=recv_val_dat,
# works_list=works_list,
# )
# if async_op and with_nccl():
# stream = get_stream()
# stream.wait_stream(torch.cuda.current_stream())
# with torch.cuda.stream(stream):
# return async_run()
# else:
# return async_run()
# def gather_forward(self,
# val: Tensor,
# idx: Optional[Tensor],
# async_op: bool = False,
# return_idx: bool = True,
# ):
# val, idx = _spread_val_idx(val, idx, self.src_size)
# bkw = dict(dtype=torch.bool, device=val.device)
# fkw = dict(dtype=val.dtype, device=val.device)
# val_buf = torch.zeros((self.dst_size,) + val.shape[1:], **fkw)
# idx_buf = torch.zeros(self.dst_size, **bkw) if return_idx else None
# return self._gather_impl(
# values_buffer=val_buf,
# active_buffer=idx_buf,
# values=val,
# active=idx,
# send_routes=self.forward_routes,
# recv_routes=self.backward_routes,
# async_op=async_op,
# )
# def gather_backward(self,
# val: Tensor,
# idx: Optional[Tensor],
# async_op: bool = False,
# return_idx: bool = True,
# ):
# val, idx = _spread_val_idx(val, idx, self.dst_size)
# bkw = dict(dtype=torch.bool, device=val.device)
# fkw = dict(dtype=val.dtype, device=val.device)
# val_buf = torch.zeros((self.src_size,) + val.shape[1:], **fkw)
# idx_buf = torch.zeros(self.src_size, **bkw) if return_idx else None
# return self._gather_impl(
# values_buffer=val_buf,
# active_buffer=idx_buf,
# values=val,
# active=idx,
# send_routes=self.backward_routes,
# recv_routes=self.forward_routes,
# async_op=async_op,
# )
class RouteFunction(autograd.Function):
@staticmethod
def forward(
ctx: autograd.function.FunctionCtx,
val: Tensor,
idx: Optional[Tensor],
):
route_ctx = get_global_route_context()
route: Route = route_ctx.route
state: Dict[str, Any] = route_ctx.state
async_op: bool = route_ctx.async_op
# cur_work = get_executor().apply_async(route.forward_a2a, val, idx, async_op)
cur_work = get_executor().async_forward_a2a(val, idx, route)
if async_op:
# cur_work = get_executor().async_forward_a2a(val, idx, route)
work = state.get("__route_fw_work__") or cur_work
state["__route_fw_work__"] = cur_work
else:
# work = route.forward_a2a(val, idx)
work = cur_work
val, idx = work.get()
if work.has_events():
route.total_time_used += work.time_used()
ctx.route = route
ctx.state = state
ctx.async_op = async_op
ctx.save_for_backward(idx)
return val, idx
@staticmethod
def backward(
ctx: autograd.function.FunctionCtx,
grad: Tensor,
_: None,
):
route: Route = ctx.route
state: Dict[str, Any] = ctx.state
async_op: bool = ctx.async_op
with torch.no_grad():
idx, = ctx.saved_tensors
# cur_work = get_executor().apply_async(route.backward_a2a, grad, idx, async_op)
cur_work = get_executor().async_backward_a2a(grad, idx, route)
if async_op:
# cur_work = get_executor().async_backward_a2a(grad, idx, route)
work = state.get("__route_bw_work__") or cur_work
state["__route_bw_work__"] = cur_work
else:
# work = route.backward_a2a(grad, idx)
work = cur_work
val, idx = work.get()
if work.has_events():
route.total_time_used += work.time_used()
return val, idx
#### private functions
@@ -288,7 +661,7 @@ def _all_reduce_num_nodes(
dist.all_reduce(max_ids, op=dist.ReduceOp.MAX)
return max_ids.item() + 1
def _fix_fw_routes(tensors: List[Tensor]) -> List[Tensor]:
rank = dist.get_rank()
world_size = dist.get_world_size()
@@ -305,19 +678,52 @@ def _p2p_recv(tensors: List[Tensor]) -> List[Tensor]:
all_to_all(new_tensors, tensors).wait()
return new_tensors
# def _spread_val_idx(val: Tensor, idx: Optional[Tensor], size: int) -> Tuple[Tensor, Optional[Tensor]]:
# if idx is None:
# assert val.size(0) == size
# return val, None
# else:
# assert val.size(0) == idx.size(0)
# x = torch.zeros(
# size=(size,) + val.shape[1:],
# dtype=val.dtype,
# device=val.device,
# )
# x[idx] = val
# y = torch.zeros(
# size=(size,),
# dtype=torch.bool,
# device=idx.device,
# ).index_fill_(0, idx, 1)
# return x, y
_global_route_context: Optional[RouteCtx] = None
def get_global_route_context() -> RouteCtx:
global _global_route_context
assert _global_route_context is not None
return _global_route_context
# a dedicated stream for asynchronous Gather communication
# _STREAM: Optional[torch.cuda.Stream] = None
# def get_stream() -> torch.cuda.Stream:
# global _STREAM
# if _STREAM is None:
# _STREAM = torch.cuda.Stream()
# return _STREAM
# @contextmanager
# def stream_manager(enable: bool = True) -> Optional[torch.cuda.Stream]:
# if enable and with_nccl():
# stream = get_stream()
# stream.wait_stream(torch.cuda.current_stream())
# with torch.cuda.stream(stream):
# yield stream
# else:
# yield
_THREAD_EXEC: Optional[RouteExecutor] = None
def get_executor() -> RouteExecutor:
global _THREAD_EXEC
if _THREAD_EXEC is None:
_THREAD_EXEC = RouteExecutor()
return _THREAD_EXEC
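# Illustrative usage sketch (not from this commit): a differentiable partition
# exchange through Route.apply inside a training step. `src_ids`, `dst_ids` and
# `local_feat` are hypothetical; `state` carries the per-layer async work.
#
# route = Route(src_ids, dst_ids)
# state: Dict[str, Any] = {}
# local_feat = torch.randn(route.src_size, 64, requires_grad=True)
# dst_val, dst_idx = route.apply(local_feat, None, state, async_op=False)
# dst_val.sum().backward()        # gradients return through backward_a2a
# print(route.total_time_used, "ms spent in communication so far")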
import torch
from torch import Tensor
from typing import *
class ShrinkData:
no_idx: int = (2**62-1)*2+1
def __init__(self,
src_size: int,
dst_size: int,
dst_idx: Tensor,
edge_index: Tensor,
bipartite: bool = False,
) -> None:
device = dst_idx.device
tmp = torch.empty(max(src_size, dst_size), dtype=torch.bool, device=device)
tmp.fill_(0)
tmp.index_fill_(0, dst_idx, 1)
edge_idx = torch.where(tmp[edge_index[1]])[0]
edge_index = edge_index[:, edge_idx]
if bipartite:
tmp.fill_(0)
tmp.index_fill_(0, edge_index[0], 1)
src_idx = torch.where(tmp)[0]
imp = torch.empty(max(src_size, dst_size), dtype=torch.long, device=device)
imp[dst_idx] = torch.arange(dst_idx.size(0), dtype=torch.long, device=device)
dst = imp[edge_index[1]]
imp.fill_(self.no_idx)
imp[src_idx] = torch.arange(src_idx.size(0), dtype=torch.long, device=device)
src = imp[edge_index[0]]
edge_index = torch.vstack([src, dst])
else:
tmp.index_fill_(0, edge_index[0], 1)
tmp.index_fill_(0, dst_idx, 0)
src_idx = torch.cat([dst_idx, torch.where(tmp)[0]], dim=0)
imp = torch.empty(max(src_size, dst_size), dtype=torch.long, device=device)
imp.fill_(self.no_idx)
imp[src_idx] = torch.arange(src_idx.size(0), dtype=torch.long, device=device)
edge_index = imp[edge_index.flatten()].view_as(edge_index)
self.src_idx = src_idx
self.dst_idx = dst_idx
self.edge_idx = edge_idx
self.edge_index = edge_index
self._src_imp = imp[:src_size]
def to(self, device):
self.src_idx = self.src_idx.to(device)
self.dst_idx = self.dst_idx.to(device)
self.edge_idx = self.edge_idx.to(device)
self.edge_index = self.edge_index.to(device)
self._src_imp = self._src_imp.to(device)
return self
def shrink_src_val_and_idx(self, val: Tensor, idx: Tensor) -> Tuple[Tensor, Tensor]:
idx = self._src_imp[idx]
m = (idx != self.no_idx)
return val[m], idx[m]
@property
def src_size(self) -> int:
return self.src_idx.size(0)
@property
def dst_size(self) -> int:
return self.dst_idx.size(0)
@property
def edge_size(self) -> int:
return self.edge_idx.size(0)
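# Illustrative usage sketch (not from this commit): restricting a partition's
# local graph to a sampled set of destination nodes. A tiny made-up graph.
#
# edge_index = torch.tensor([[0, 1, 2, 3],
#                            [1, 2, 3, 0]])
# dst_idx = torch.tensor([1, 3])                 # keep only these dst nodes
# sd = ShrinkData(src_size=4, dst_size=4, dst_idx=dst_idx,
#                 edge_index=edge_index, bipartite=False)
# sd.edge_index                                  # edges re-indexed into the
# sd.src_idx, sd.dst_idx                         # shrunken src/dst id spaces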
import torch # import torch
import torch.nn as nn # import torch.nn as nn
import torch.autograd as autograd # import torch.autograd as autograd
from torch import Tensor # from torch import Tensor
from contextlib import contextmanager # from contextlib import contextmanager
from typing import * # from typing import *
class StraightContext:
def __init__(self, this) -> None: # class StraightContext:
self.this = this # def __init__(self, this, g) -> None:
# self.this = this
class Straight(nn.Module): # self.g = g
def __init__(self,
num_nodes: int, # class Straight(nn.Module):
p: int = 2, # def __init__(self,
dim: int = -1, # num_nodes: int,
beta: float = 1.0, # num_samples: int,
) -> None: # norm_kwargs: Optional[Dict[str, Any]] = None,
super().__init__() # beta: float = 1.0,
self.num_nodes = num_nodes # prev: Optional[List[Any]] = None,
self.norm_p = p # ) -> None:
self.norm_dim = dim # super().__init__()
self.norm_beta = beta # assert num_samples <= num_nodes
self.register_buffer("grad_norm", torch.ones(num_nodes)) # self.num_nodes = num_nodes
# self.num_samples = num_samples
self.reset_parameters() # self.norm_kwargs = norm_kwargs or dict(p=2, dim=-1)
# self.beta = beta
def reset_parameters(self): # self.prev = prev
grad_norm = self.get_buffer("grad_norm")
nn.init.constant_(grad_norm, 1.0) # self.register_buffer("last_w", torch.ones(num_nodes))
def forward(self, values: Tensor, active: Optional[Tensor] = None) -> Tensor: # self._next_idx = None
with self._manager(self): # self._next_shrink_helper = None
return StraightFunction.apply(values, active)
# self.reset_parameters()
def multinomial(self,
num_samples: int, # def reset_parameters(self):
replacement: bool = False, # last_w = self.get_buffer("last_w")
) -> Tensor: # nn.init.constant_(last_w, 1.0)
w = self.get_buffer("grad_norm")
if num_samples <= 0: # def forward(self,
return torch.arange(self.num_nodes, dtype=torch.long, device=w.device) # val: Tensor,
# idx: Optional[Tensor],
# print(w) # g,
w = w / w.sum() # ) -> Tensor:
return torch.multinomial(w, num_samples=num_samples, replacement=replacement) # with self._manager(self, g):
# return StraightFunction.apply(val, idx)
def multinomial_mask(self,
num_samples: int, # def pop_next_shrink_helper(self) -> Tuple[Optional[Tensor], Any]:
replacement: bool = False, # if not self.training:
) -> Tensor: # return None, None
device = self.get_buffer("grad_norm").device
if num_samples <= 0: # next_idx = self._next_idx
return torch.ones(self.num_nodes, dtype=torch.bool, device=device) # self._next_idx = None
w = self.multinomial(num_samples, replacement) # next_sh = self._next_shrink_helper
m = torch.zeros(self.num_nodes, dtype=torch.bool, device=device) # self._next_shrink_helper = None
m[w] = True
return m # return next_idx, next_sh
@contextmanager # def _sample_next(self) -> Tensor:
def _manager(self, this): # w = self.get_buffer("last_w")
global _global_straight_context # if self._next_idx is None:
stacked = _global_straight_context # if self.num_samples < w.size(0):
# self._next_idx = self.sample_impl(w)
try: # else:
_global_straight_context = StraightContext(this=this) # self._next_idx = torch.arange(w.size(0), dtype=torch.long, device=w.device)
yield _global_straight_context # elif self.num_samples < self._next_idx.size(0):
finally: # idx = self.sample_impl(w[self._next_idx])
_global_straight_context = stacked # self._next_idx = self._next_idx[idx]
# return self._next_idx
class StraightFunction(autograd.Function): # def sample_impl(self, w: Tensor) -> Tensor:
@staticmethod # w = w / w.sum()
def forward( # return torch.multinomial(w, num_samples=self.num_samples, replacement=False)
ctx: autograd.function.FunctionCtx,
values: Tensor, # # def multinomial(self,
active: Optional[Tensor], # # num_samples: int,
): # # replacement: bool = False,
stx = _last_global_straight_context() # # ) -> Tensor:
this: Straight = stx.this # # w = self.get_buffer("last_w")
if this.training: # # if num_samples <= 0:
ctx.this = this # # return torch.arange(self.num_nodes, dtype=torch.long, device=w.device)
ctx.save_for_backward(active)
return values # # w = w / w.sum()
# # return torch.multinomial(w, num_samples=num_samples, replacement=replacement)
@staticmethod
def backward( # # def multinomial_mask(self,
ctx: autograd.function.FunctionCtx, # # num_samples: int,
grad: Tensor, # # replacement: bool = False,
): # # ) -> Tensor:
this: Straight = ctx.this # # w = self.get_buffer("last_w")
active, = ctx.saved_tensors # # if num_samples <= 0:
# # return torch.ones(self.num_nodes, dtype=torch.bool, device=w.device)
grad_norm = this.get_buffer("grad_norm")
if this.norm_beta != 1.0: # # w = self.multinomial(num_samples, replacement)
grad_norm.mul_(this.norm_beta) # # m = torch.zeros(self.num_nodes, dtype=torch.bool, device=w.device)
# # m[w] = True
if active is None: # # return m
x = grad.detach()
x = x.norm(p=this.norm_p, dim=this.norm_dim) # @contextmanager
grad_norm[:] = x # def _manager(self, this, g):
else: # global _global_straight_context
if grad.size(0) == grad_norm.size(0): # stacked = _global_straight_context
x = grad.detach()[active]
else: # try:
x = grad.detach() # _global_straight_context = StraightContext(this=this, g=g)
x = x.norm(p=this.norm_p, dim=this.norm_dim) # yield _global_straight_context
grad_norm[active] = x # finally:
return grad, None # _global_straight_context = stacked
#### private functions
_global_straight_context: Optional[StraightContext] = None # class StraightFunction(autograd.Function):
def _last_global_straight_context() -> StraightContext: # @staticmethod
global _global_straight_context # def forward(
assert _global_straight_context is not None # ctx: autograd.function.FunctionCtx,
return _global_straight_context # val: Tensor,
# idx: Optional[Tensor],
# ):
if __name__ == "__main__": # from ..graph import DistGraph
s = Straight(3, beta=1.1)
x = torch.rand(3, 10).requires_grad_() # stx = _last_global_straight_context()
m = torch.tensor([0, 1, 0], dtype=torch.bool) # this: Straight = stx.this
# g: DistGraph = stx.g
s(x).sum().backward()
print(s.grad_norm) # last_w = this.get_buffer("last_w")
print(s.multinomial(2)) # if idx is None:
# assert val.size(0) == last_w.size(0)
s(x, m).sum().backward() # else:
print(s.grad_norm) # assert val.size(0) == idx.size(0)
print(s.multinomial(2))
# if this.training:
s(x[m], m).sum().backward() # ctx.this = this
print(s.grad_norm) # ctx.g = g
print(s.multinomial(2)) # ctx.save_for_backward(idx)
# return val
print(s.multinomial_mask(2))
\ No newline at end of file # @staticmethod
# def backward(
# ctx: autograd.function.FunctionCtx,
# grad: Tensor,
# ):
# from ..graph import DistGraph
# this: Straight = ctx.this
# g: DistGraph = ctx.g
# idx, = ctx.saved_tensors
# last_w = this.get_buffer("last_w")
# if this.beta != 1.0:
# last_w.mul_(this.beta)
# norm = grad.norm(**this.norm_kwargs)
# if idx is None:
# last_w[:] = norm
# else:
# last_w[idx] = norm
# if this.prev is not None or this._next_idx is None:
# from ..nn.convs.utils import ShrinkHelper
# dst_idx = this._sample_next()
# if this.prev is not None and this._next_idx is not None:
# this._next_shrink_helper = ShrinkHelper(g, dst_idx)
# src_idx = this._next_shrink_helper.src_idx
# work = g.route.gather_backward(
# src_idx, src_idx, async_op=False, return_idx=True)
# prev_dst_idx = work.get_idx()
# for p in this.prev:
# assert isinstance(p, Straight)
# p._next_idx = prev_dst_idx
# return grad, None
# #### private functions
# _global_straight_context: Optional[StraightContext] = None
# def _last_global_straight_context() -> StraightContext:
# global _global_straight_context
# assert _global_straight_context is not None
# return _global_straight_context
# if __name__ == "__main__":
# s = Straight(3, beta=1.1)
# x = torch.rand(3, 10).requires_grad_()
# m = torch.tensor([0, 1, 0], dtype=torch.bool)
# s(x).sum().backward()
# print(s.grad_norm)
# print(s.multinomial(2))
# s(x, m).sum().backward()
# print(s.grad_norm)
# print(s.multinomial(2))
# s(x[m], m).sum().backward()
# print(s.grad_norm)
# print(s.multinomial(2))
# print(s.multinomial_mask(2))
\ No newline at end of file
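For context, the `Straight` buffer logic being replaced above samples nodes in proportion to their stored gradient norms. A minimal self-contained sketch of that sampling scheme, plain PyTorch only, with illustrative names that are not part of this repository:
import torch
def sample_by_grad_norm(grad_norm: torch.Tensor, num_samples: int, replacement: bool = False):
    # pick node indices with probability proportional to their gradient norm
    if num_samples <= 0 or num_samples >= grad_norm.numel():
        idx = torch.arange(grad_norm.numel(), device=grad_norm.device)
    else:
        w = grad_norm / grad_norm.sum()       # normalize into a probability vector
        idx = torch.multinomial(w, num_samples=num_samples, replacement=replacement)
    mask = torch.zeros(grad_norm.numel(), dtype=torch.bool, device=grad_norm.device)
    mask[idx] = True                          # boolean-mask form, mirroring multinomial_mask
    return idx, mask
idx, mask = sample_by_grad_norm(torch.tensor([0.1, 2.0, 0.5]), num_samples=2)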
...@@ -5,26 +5,23 @@ from torch import Tensor ...@@ -5,26 +5,23 @@ from torch import Tensor
from typing import * from typing import *
from contextlib import contextmanager from contextlib import contextmanager
import torch_sparse # import torch_sparse
from .ndata import NData from .ndata import NData
from .edata import EData from .edata import EData
from .utils import init_local_edge_index from .utils import init_local_edge_index
from ..core import Route from ..core import MessageCache, Route
class DistGraph: class DistGraph:
def __init__(self, def __init__(self,
ids: Tensor, ids: Tensor,
edge_index: Tensor, edge_index: Tensor,
args: Dict[str, Any] = {}, num_features: int,
device: Union[torch.device, str, int, None] = None, num_layers: int,
cache_device: str = "cpu",
**args: Dict[str, Any],
): ):
# graph's target device
if device is None:
device = ids.device
self.device = torch.device(device)
# build local_edge_index # build local_edge_index
dst_ids = ids dst_ids = ids
src_ids, local_edge_index = init_local_edge_index( src_ids, local_edge_index = init_local_edge_index(
...@@ -34,8 +31,15 @@ class DistGraph: ...@@ -34,8 +31,15 @@ class DistGraph:
self._src_ids = src_ids self._src_ids = src_ids
self._dst_ids = dst_ids self._dst_ids = dst_ids
self._local_edge_index = local_edge_index self._message_cache = MessageCache(
self._local_edge_ptr: Optional[Tensor] = None src_ids=src_ids,
dst_ids=dst_ids,
edge_index=local_edge_index,
num_features=num_features,
num_layers=num_layers,
cache_device=cache_device,
bipartite=False,
)
# node's attributes # node's attributes
self.ndata = NData( self.ndata = NData(
...@@ -49,24 +53,30 @@ class DistGraph: ...@@ -49,24 +53,30 @@ class DistGraph:
) )
# graph's attributes # graph's attributes
self.args = args self.args = dict(args)
# route table def to(self, device):
self.route = Route(dst_ids, src_ids) self._message_cache.to(device)
return self
def cache_data_to(self, device):
self._message_cache.cached_data_to(device)
@property @property
def edge_index(self) -> Tensor: def cache(self) -> MessageCache:
if self._local_edge_index.device != self.device: return self._message_cache
self._local_edge_index = self._local_edge_index.to(self.device)
return self._local_edge_index
@property @property
def edge_ptr(self) -> Optional[Tensor]: def route(self) -> Route:
if self._local_edge_ptr is None: return self._message_cache.route
return None
if self._local_edge_ptr.device != self.device: @property
self._local_edge_ptr = self._local_edge_ptr.to(self.device) def device(self) -> torch.device:
return self._local_edge_ptr return self.edge_index.device
@property
def edge_index(self) -> Tensor:
return self._message_cache.edge_index
@property @property
def src_ids(self) -> Tensor: def src_ids(self) -> Tensor:
...@@ -88,16 +98,6 @@ class DistGraph: ...@@ -88,16 +98,6 @@ class DistGraph:
def dst_size(self) -> int: def dst_size(self) -> int:
return self._dst_ids.size(0) return self._dst_ids.size(0)
@torch.no_grad()
def gather(self, values: Tensor, direct: str = "dst_to_src") -> Tensor:
if direct == "dst_to_src":
work = self.route.gather_forward(values, None, async_op=False, return_active=False)
elif direct == "src_to_dst":
work = self.route.gather_backward(values, None, async_op=False, return_active=False)
else:
raise ValueError(f"direct must be 'src_to_dst' or 'dst_to_src', but got '{direct}'")
return work.get_values()
@contextmanager @contextmanager
def scoped_manager(self): def scoped_manager(self):
stacked_ndata = self.ndata stacked_ndata = self.ndata
......
...@@ -6,6 +6,7 @@ from typing import * ...@@ -6,6 +6,7 @@ from typing import *
def init_local_edge_index( def init_local_edge_index(
dst_ids: Tensor, dst_ids: Tensor,
edge_index: Tensor, edge_index: Tensor,
bipartite: bool = False,
) -> Tuple[Tensor, Tensor]: ) -> Tuple[Tensor, Tensor]:
max_ids = calc_max_ids(dst_ids, edge_index) max_ids = calc_max_ids(dst_ids, edge_index)
ikw = dict(dtype=torch.long, device=dst_ids.device) ikw = dict(dtype=torch.long, device=dst_ids.device)
...@@ -17,6 +18,10 @@ def init_local_edge_index( ...@@ -17,6 +18,10 @@ def init_local_edge_index(
if not (xmp != 0x01).all(): if not (xmp != 0x01).all():
raise RuntimeError(f"must be vertex-cut partition graph") raise RuntimeError(f"must be vertex-cut partition graph")
if bipartite:
src_ids = edge_index[0].unique()
else:
# assume the graph is homogeneous
# src_ids equals [dst_ids, edge_index[0] except dst_ids] # src_ids equals [dst_ids, edge_index[0] except dst_ids]
xmp.fill_(0) xmp.fill_(0)
xmp[edge_index[0]] = 1 xmp[edge_index[0]] = 1
...@@ -26,8 +31,13 @@ def init_local_edge_index( ...@@ -26,8 +31,13 @@ def init_local_edge_index(
# compute local indices # compute local indices
xmp.fill_((2**62-1)*2+1) xmp.fill_((2**62-1)*2+1)
xmp[src_ids] = torch.arange(src_ids.size(0), **ikw) xmp[src_ids] = torch.arange(src_ids.size(0), **ikw)
local_edge_index = xmp[edge_index.flatten()].view_as(edge_index) src = xmp[edge_index[0]]
xmp.fill_((2**62-1)*2+1)
xmp[dst_ids] = torch.arange(dst_ids.size(0), **ikw)
dst = xmp[edge_index[1]]
local_edge_index = torch.vstack([src, dst])
return src_ids, local_edge_index return src_ids, local_edge_index
def calc_max_ids(*ids: Tensor) -> int: def calc_max_ids(*ids: Tensor) -> int:
......
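The rewritten `init_local_edge_index` above maps global node ids to local positions with a sentinel-filled lookup table. The same trick in isolation, with made-up ids and an illustrative helper name:
import torch
def to_local(edge_index: torch.Tensor, src_ids: torch.Tensor, dst_ids: torch.Tensor) -> torch.Tensor:
    n = int(torch.cat([edge_index.flatten(), src_ids, dst_ids]).max()) + 1
    sentinel = (2**62 - 1) * 2 + 1                  # same "invalid" marker as above (int64 max)
    xmp = torch.full((n,), sentinel, dtype=torch.long)
    xmp[src_ids] = torch.arange(src_ids.size(0))    # global src id -> local src position
    src = xmp[edge_index[0]]
    xmp.fill_(sentinel)
    xmp[dst_ids] = torch.arange(dst_ids.size(0))    # global dst id -> local dst position
    dst = xmp[edge_index[1]]
    return torch.vstack([src, dst])
src_ids = torch.tensor([0, 2, 5])
dst_ids = torch.tensor([0, 2])
edge_index = torch.tensor([[0, 5, 2], [2, 0, 0]])
print(to_local(edge_index, src_ids, dst_ids))       # tensor([[0, 2, 1], [1, 0, 0]])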
from .convs import GCNConv, GATConv, GINConv from .convs import GCNConv, GATConv, GINConv
from .convs import ShrinkGCNConv, ShrinkGATConv, ShrinkGINConv # from .convs import ShrinkGCNConv, ShrinkGATConv, ShrinkGINConv
from .basic_gnn import BasicGNN, BasicLayerOptions, BasicInputOptions, BasicJKOptions, BasicStraightOptions # from .convs import ShrinkHelper
# from .basic_gnn import BasicGNN, BasicLayerOptions, BasicInputOptions, BasicStraightOptions
class GCN(BasicGNN): # class ShrinkGCN(BasicGNN):
def init_conv(self, in_channels: int, out_channels: int, **kwargs): # def init_conv(self, in_channels: int, out_channels: int, **kwargs):
return ShrinkGCNConv(in_channels, out_channels, **kwargs) # return ShrinkGCNConv(in_channels, out_channels, **kwargs)
class GAT(BasicGNN): # class ShrinkGAT(BasicGNN):
def init_conv(self, in_channels: int, out_channels: int, **kwargs): # def init_conv(self, in_channels: int, out_channels: int, **kwargs):
return ShrinkGATConv(in_channels, out_channels, **kwargs) # return ShrinkGATConv(in_channels, out_channels, **kwargs)
class GIN(BasicGNN): # class ShrinkGIN(BasicGNN):
def init_conv(self, in_channels: int, out_channels: int, **kwargs): # def init_conv(self, in_channels: int, out_channels: int, **kwargs):
return ShrinkGINConv(in_channels, out_channels, **kwargs) # return ShrinkGINConv(in_channels, out_channels, **kwargs)
\ No newline at end of file
import torch # import torch
import torch.nn as nn # import torch.nn as nn
import torch.nn.functional as F # import torch.nn.functional as F
from torch import Tensor # from torch import Tensor
from typing import * # from typing import *
import copy # import copy
import inspect # import inspect
from torch_geometric.nn import JumpingKnowledge # from torch_geometric.nn import JumpingKnowledge
from torch_geometric.nn.resolver import activation_resolver, normalization_resolver # from torch_geometric.nn.resolver import activation_resolver, normalization_resolver
from dataclasses import dataclass # from dataclasses import dataclass
from ..graph.distgraph import DistGraph # from ..core.cache import NodeProbe
from ..core.gather import Gather # from ..graph.distgraph import DistGraph
from ..core.straight import Straight # # from ..core.gather import Gather
# # from ..core.straight import Straight
@dataclass
class BasicInputOptions: # @dataclass
weight: bool = True # class BasicInputOptions:
bias: bool = True # weight: bool = True
gather: bool = True # bias: bool = True
gather_first: bool = False # gather: bool = True
dropout: float = 0.0 # gather_first: bool = False
act: Optional[str] = None # dropout: float = 0.0
act_kwargs: Optional[Dict[str, Any]] = None # act: Optional[str] = None
act_first: bool = False # act_kwargs: Optional[Dict[str, Any]] = None
norm: Optional[str] = None # act_first: bool = False
norm_kwargs: Optional[Dict[str, Any]] = None # norm: Optional[str] = None
straight_enabled: bool = True # norm_kwargs: Optional[Dict[str, Any]] = None
# straight_enabled: bool = True
@dataclass # straight_num_samples: Optional[int] = None
class BasicLayerOptions:
in_channels: int # @dataclass
hidden_channels: int # class BasicLayerOptions:
num_layers: int # in_channels: int
out_channels: Optional[int] = None # hidden_channels: int
gather_beta: float = 1.0 # num_layers: int
dropout: float = 0.0 # out_channels: Optional[int] = None
act: Optional[str] = "relu" # gather_beta: float = 1.0
act_kwargs: Optional[Dict[str, Any]] = None # dropout: float = 0.0
act_first: bool = False # act: Optional[str] = "relu"
norm: Optional[str] = None # act_kwargs: Optional[Dict[str, Any]] = None
norm_kwargs: Optional[Dict[str, Any]] = None # act_first: bool = False
# norm: Optional[str] = None
@dataclass # norm_kwargs: Optional[Dict[str, Any]] = None
class BasicJKOptions: # jk_mode: Optional[str] = None
jk_mode: Optional[str] = None
# @dataclass
@dataclass # class BasicStraightOptions:
class BasicStraightOptions: # enabled: bool = False
enabled: bool = False # num_samples: Optional[int] = None
p: int = 2 # beta: float = 1.0
beta: float = 1.0
# class BasicGNN(nn.Module):
class BasicGNN(nn.Module): # def __init__(
def __init__( # self,
self, # g: DistGraph,
g: DistGraph, # layer_options: BasicLayerOptions,
layer_options: BasicLayerOptions, # input_options: BasicInputOptions = BasicInputOptions(),
input_options: BasicInputOptions = BasicInputOptions(), # straight_options: BasicStraightOptions = BasicStraightOptions(),
jk_options: BasicJKOptions = BasicJKOptions(), # **kwargs,
straight_options: BasicStraightOptions = BasicStraightOptions(), # ):
**kwargs, # super().__init__()
):
super().__init__() # num_samples = straight_options.num_samples or g.dst_size
# prev_straight = None
self.in_channels = in_channels = layer_options.in_channels
self.hidden_channels = hidden_channels = layer_options.hidden_channels # self.in_channels = in_channels = layer_options.in_channels
self.num_layers = num_layers = layer_options.num_layers # self.hidden_channels = hidden_channels = layer_options.hidden_channels
# self.num_layers = num_layers = layer_options.num_layers
# default out_channels is hidden_channels
self.out_channels = out_channels = hidden_channels \ # # default out_channels is hidden_channels
if layer_options.out_channels is None else layer_options.out_channels # self.out_channels = out_channels = hidden_channels \
self.gather_beta = gather_beta = layer_options.gather_beta # if layer_options.out_channels is None else layer_options.out_channels
# self.gather_beta = gather_beta = layer_options.gather_beta
self.layer_options = layer_options
self.input_options = input_options # self.layer_options = layer_options
self.jk_options = jk_options # self.input_options = input_options
self.straight_options = straight_options # self.straight_options = straight_options
# initialize input layer # # initialize input layer
if input_options.weight: # if input_options.weight:
self.lin_x = nn.Linear(in_channels, hidden_channels, bias=input_options.bias) # self.lin_x = nn.Linear(in_channels, hidden_channels, bias=input_options.bias)
in_channels = hidden_channels # in_channels = hidden_channels
if straight_options.enabled and input_options.straight_enabled: # if straight_options.enabled and input_options.straight_enabled and not input_options.gather_first:
if input_options.gather_first: # sns = input_options.straight_num_samples or g.dst_size
self.straight_x = Straight( # self.straight_x = Straight(g.dst_size, sns, beta=straight_options.beta, prev=prev_straight)
g.src_size, p=straight_options.p, beta=straight_options.beta) # prev_straight = [self.straight_x]
else:
self.straight_x = Straight( # if input_options.gather:
g.dst_size, p=straight_options.p, beta=straight_options.beta) # self.gather_x = Gather(g.src_size, in_channels, beta=gather_beta)
if input_options.gather: # if input_options.act is not None:
self.gather_x = Gather(g.src_size, in_channels, beta=gather_beta) # self.act_x = activation_resolver(
# input_options.act, **(input_options.act_kwargs or {}))
if input_options.act is not None:
self.act_x = activation_resolver( # if input_options.norm is not None:
input_options.act, **(input_options.act_kwargs or {})) # self.norm_x = normalization_resolver(
# input_options.norm, in_channels, **(input_options.norm_kwargs or {}))
if input_options.norm is not None:
self.norm_x = normalization_resolver( # # initialize activation layers
input_options.norm, in_channels, **(input_options.norm_kwargs or {})) # if layer_options.act is not None:
# self.acts = nn.ModuleList()
# initialize activation layers # for _ in range(num_layers - 1):
if layer_options.act is not None: # self.acts.append(activation_resolver(
self.acts = nn.ModuleList() # layer_options.act, **(layer_options.act_kwargs or {})))
for _ in range(num_layers - 1): # if layer_options.jk_mode is not None:
self.acts.append(activation_resolver( # self.acts.append(activation_resolver(
layer_options.act, **(layer_options.act_kwargs or {}))) # layer_options.act, **(layer_options.act_kwargs or {})))
if jk_options.jk_mode is not None:
self.acts.append(activation_resolver( # # initialize normalization layers
layer_options.act, **(layer_options.act_kwargs or {}))) # if layer_options.norm is not None:
# self.norms = nn.ModuleList()
# initialize normalization layers # for _ in range(num_layers - 1):
if layer_options.norm is not None: # self.norms.append(normalization_resolver(
self.norms = nn.ModuleList() # layer_options.norm, hidden_channels, **(layer_options.norm_kwargs or {})))
for _ in range(num_layers - 1): # if layer_options.jk_mode is not None:
self.norms.append(normalization_resolver( # self.norms.append(normalization_resolver(
layer_options.norm, hidden_channels, **(layer_options.norm_kwargs or {}))) # layer_options.norm, hidden_channels, **(layer_options.norm_kwargs or {})))
if jk_options.jk_mode is not None:
self.norms.append(normalization_resolver( # # initialize straight layers
layer_options.norm, hidden_channels, **(layer_options.norm_kwargs or {}))) # if straight_options.enabled:
# self.straights = nn.ModuleList()
# initialize straight layers # for _ in range(num_layers):
if straight_options.enabled: # prev_straight = [Straight(g.dst_size, num_samples, beta=straight_options.beta, prev=prev_straight)]
self.straights = nn.ModuleList() # self.straights.append(prev_straight[0])
for _ in range(num_layers - 1): # # if layer_options.jk_mode is not None:
self.straights.append(Straight( # # prev_straight = [Straight(g.dst_size, num_samples, beta=straight_options.beta, prev=prev_straight)]
g.dst_size, p=straight_options.p, beta=straight_options.beta)) # # self.straights.append(prev_straight[0])
if jk_options.jk_mode is not None:
self.straights.append(Straight( # # initialize gather and conv layers
g.dst_size, p=straight_options.p, beta=straight_options.beta)) # self.convs = nn.ModuleList()
# self.gathers = nn.ModuleList()
# initialize gather and conv layers # for _ in range(num_layers - 1):
self.convs = nn.ModuleList() # self.convs.append(
self.gathers = nn.ModuleList() # self.init_conv(in_channels, hidden_channels, **kwargs))
for _ in range(num_layers - 1): # self.gathers.append(Gather(g.src_size, hidden_channels, beta=gather_beta))
self.convs.append( # in_channels = hidden_channels
self.init_conv(in_channels, hidden_channels, **kwargs))
self.gathers.append(Gather(g.src_size, hidden_channels, beta=gather_beta)) # if layer_options.jk_mode is None:
in_channels = hidden_channels # self.convs.append(
# self.init_conv(in_channels, out_channels, **kwargs))
if jk_options.jk_mode is None: # self.gathers.append(Gather(g.dst_size, out_channels)) # only fuse embeddings
self.convs.append( # else:
self.init_conv(in_channels, out_channels, **kwargs)) # self.convs.append(
self.gathers.append(Gather(g.dst_size, out_channels)) # only fuse embeddings # self.init_conv(in_channels, hidden_channels, **kwargs))
else: # self.gathers.append(Gather(g.dst_size, hidden_channels, beta=gather_beta)) # only fuse embeddings
self.convs.append(
self.init_conv(in_channels, hidden_channels, **kwargs)) # if layer_options.jk_mode != "last":
self.gathers.append(Gather(g.dst_size, hidden_channels, beta=gather_beta)) # only fuse embeddings # self.jk = JumpingKnowledge(layer_options.jk_mode, hidden_channels, num_layers)
if jk_options.jk_mode != "last": # if layer_options.jk_mode == "cat":
self.jk = JumpingKnowledge(jk_options.jk_mode, hidden_channels, num_layers) # jk_channels = num_layers * hidden_channels
# else:
if jk_options.jk_mode == "cat": # jk_channels = hidden_channels
jk_channels = num_layers * hidden_channels # self.lin_jk = nn.Linear(jk_channels, out_channels)
else:
jk_channels = hidden_channels # self.reset_parameters()
self.lin_jk = nn.Linear(jk_channels, out_channels)
# def init_conv(self,
self.reset_parameters() # in_channels: int,
# out_channels: int,
def init_conv(self, # **kwargs
in_channels: int, # ) -> nn.Module:
out_channels: int, # raise NotImplementedError
**kwargs
) -> nn.Module: # def reset_parameters(self):
raise NotImplementedError # if hasattr(self, "lin_x"):
# self.lin_x.reset_parameters()
def reset_parameters(self):
if hasattr(self, "lin_x"): # if hasattr(self, "straight_x"):
self.lin_x.reset_parameters() # self.straight_x.reset_parameters()
if hasattr(self, "straight_x"): # if hasattr(self, "gather_x"):
self.straight_x.reset_parameters() # self.gather_x.reset_parameters()
if hasattr(self, "gather_x"): # if hasattr(self, "norm_x"):
self.gather_x.reset_parameters() # self.norm_x.reset_parameters()
if hasattr(self, "norm_x"): # if hasattr(self, "norms"):
self.norm_x.reset_parameters() # for norm in self.norms:
# norm.reset_parameters()
if hasattr(self, "norms"):
for norm in self.norms: # if hasattr(self, "straights"):
norm.reset_parameters() # for straight in self.straights:
# straight.reset_parameters()
if hasattr(self, "straights"):
for straight in self.straights: # for conv in self.convs:
straight.reset_parameters() # conv.reset_parameters()
for conv in self.convs: # for gather in self.gathers:
conv.reset_parameters() # gather.reset_parameters()
for gather in self.gathers: # if hasattr(self, "jk"):
gather.reset_parameters() # self.jk.reset_parameters()
if hasattr(self, "jk"): # if hasattr(self, "lin_jk"):
self.jk.reset_parameters() # self.lin_jk.reset_parameters()
if hasattr(self, "lin_jk"): # def forward(
self.lin_jk.reset_parameters() # self,
# g: DistGraph,
def forward( # ):
self, # from ..utils.printer import main_print, sync_print
g: DistGraph,
): # x = g.ndata["x"]
# from ..utils.printer import main_print # async_op = g.args.get("async_op", False)
x = g.ndata["x"] # # input layer
# main_print(f"features_x: {x.size()}") # if hasattr(self, "straight_x"):
# dst_idx, _ = self.straight_x.pop_next_shrink_helper()
sample_k = g.args.get("sample_k", 0) # else:
async_op = g.args.get("async_op", False) # dst_idx = None
# input layer # if dst_idx is not None:
if hasattr(self, "gather_x") and self.input_options.gather_first: # x = x[dst_idx]
x, m = self.gather_x(x, None, g.route, async_op=async_op)
# if hasattr(self, "gather_x") and self.input_options.gather_first:
if self.input_options.dropout != 0.0: # x, _ = self.gather_x(x, dst_idx, g.route, async_op=async_op)
x = F.dropout(x, p=self.input_options.dropout, training=self.training)
# main_print(f"x_x: {x.size()} {x.requires_grad}") # if self.input_options.dropout != 0.0:
if hasattr(self, "lin_x"): # x = F.dropout(x, p=self.input_options.dropout, training=self.training)
# main_print("enter lin_x")
x = self.lin_x(x) # if hasattr(self, "lin_x"):
# main_print(f"lin_x: {x.size()} {x.requires_grad}") # x = self.lin_x(x)
if hasattr(self, "act_x") and self.input_options.act_first:
x = self.act_x(x) # if hasattr(self, "act_x") and self.input_options.act_first:
if hasattr(self, "norm_x"): # x = self.act_x(x)
x = self.norm_x(x) # if hasattr(self, "norm_x"):
if hasattr(self, "act_x") and not self.input_options.act_first: # x = self.norm_x(x)
x = self.act_x(x) # if hasattr(self, "act_x") and not self.input_options.act_first:
# x = self.act_x(x)
# main_print(f"act_x: {x.size()} {x.requires_grad}")
# # straight sampler
# straight sampler # if hasattr(self, "straight_x"):
if hasattr(self, "straight_x"): # # sync_print(f"{x.size()} - {'' if dst_idx is None else dst_idx.size()}")
x = self.straight_x(x) # x = self.straight_x(x, dst_idx, g)
m = self.straight_x.multinomial_mask(sample_k)
else: # # gather features
m = None # if hasattr(self, "gather_x") and not self.input_options.gather_first:
# main_print(f"straight_x: {x.size()} {'' if m is None else m.size()} {x.requires_grad}") # x, _ = self.gather_x(x, dst_idx, g.route, async_op=async_op)
# gather features # # conv layers
if hasattr(self, "gather_x") and not self.input_options.gather_first: # xs: List[Tensor] = []
x, m = self.gather_x(x, m, g.route, async_op=async_op) # for i in range(self.num_layers):
# if self.layer_options.dropout != 0.0:
# main_print(f"rg {0}: {x.requires_grad}") # x = F.dropout(x, p=self.layer_options.dropout, training=self.training)
# conv layers # if hasattr(self, "straights"):
xs: List[Tensor] = [] # straight: Straight = self.straights[i]
for i in range(self.num_layers): # dst_idx, sh = straight.pop_next_shrink_helper()
if self.layer_options.dropout != 0.0:
x = F.dropout(x, p=self.layer_options.dropout, training=self.training) # with g.scoped_manager():
# g.ndata["x"] = x
with g.scoped_manager(): # x = self.convs[i](g, sh=sh, dst_idx=dst_idx)
g.ndata["x"] = x
if m is not None: # if i == self.num_layers - 1 and not hasattr(self, "jk"):
g.ndata["m"] = m # x = self.gathers[i].fuse_embeddings(x, dst_idx, inplace=True)
x, m = self.convs[i](g) # break
# if hasattr(self, "acts") and self.layer_options.act_first:
# main_print(f"rg {i+1}: {x.requires_grad}") # x = self.acts[i](x)
# if hasattr(self, "norms"):
if i == self.num_layers - 1 and not hasattr(self, "jk"): # x = self.norms[i](x)
x = self.gathers[i].fuse_embeddings(x, m, inplace=True) # if hasattr(self, "acts") and not self.layer_options.act_first:
break # x = self.acts[i](x)
if hasattr(self, "acts") and self.layer_options.act_first: # if hasattr(self, "straights"):
x = self.acts[i](x) # x = self.straights[i](x, dst_idx, g)
if hasattr(self, "norms"):
x = self.norms[i](x) # if i == self.num_layers - 1:
if hasattr(self, "acts") and not self.layer_options.act_first: # x = self.gathers[i].fuse_embeddings(x, dst_idx, inplace=True)
x = self.acts[i](x) # else:
if hasattr(self, "straights"): # x, _ = self.gathers[i](x, dst_idx, g.route, async_op=async_op)
x = self.straights[i](x, m)
# if hasattr(self, "jk"):
if i == self.num_layers - 1: # xs.append(x[:g.dst_size])
x = self.gathers[i].fuse_embeddings(x, m, inplace=True)
else: # x = self.jk(xs) if hasattr(self, "jk") else x
x, m = self.gathers[i](x, m, g.route, async_op=async_op) # x = self.lin_jk(x) if hasattr(self, "lin_jk") else x
if hasattr(self, "jk"): # # sync_print(f"out: {x.size()}")
xs.append(x[:g.dst_size]) # return x
x = self.jk(xs) if hasattr(self, "jk") else x
x = self.lin_jk(x) if hasattr(self, "lin_jk") else x
# main_print(f"out: {x.size()}")
return x
from .gcn_conv import GCNConv from .gcn_conv import GCNConv
from .gat_conv import GATConv from .gat_conv import GATConv
from .gin_conv import GINConv from .gin_conv import GINConv
from .shrink_gcn_conv import ShrinkGCNConv # from .shrink_gcn_conv import ShrinkGCNConv
from .shrink_gat_conv import ShrinkGATConv # from .shrink_gat_conv import ShrinkGATConv
from .shrink_gin_conv import ShrinkGINConv # from .shrink_gin_conv import ShrinkGINConv
# from .utils import ShrinkHelper
\ No newline at end of file
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.utils import softmax
from torch_scatter import scatter_sum
from torch import Tensor
from typing import *
from starrygl.graph import DistGraph
class GATConv(nn.Module):
def __init__(self,
in_channels: int,
out_channels: int,
heads: int = 1,
concat: bool = False,
negative_slope: float = 0.2,
dropout: float = 0.0,
edge_dim: Optional[int] = None,
bias: bool = True,
**kwargs
) -> None:
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.heads = heads
self.concat = concat
self.negative_slope = negative_slope
self.dropout = dropout
self.edge_dim = edge_dim
self.weight = nn.Parameter(torch.Tensor(in_channels, heads * out_channels))
self.att_src = nn.Parameter(torch.Tensor(1, heads, out_channels))
self.att_dst = nn.Parameter(torch.Tensor(1, heads, out_channels))
if edge_dim is not None:
self.lin_edge = nn.Parameter(torch.Tensor(edge_dim, heads * out_channels))
self.att_edge = nn.Parameter(torch.Tensor(1, heads, out_channels))
if bias and concat:
self.bias = nn.Parameter(torch.Tensor(heads * out_channels))
elif bias and not concat:
self.bias = nn.Parameter(torch.Tensor(out_channels))
else:
self.bias = None
self.reset_parameters()
def reset_parameters(self):
nn.init.xavier_normal_(self.weight)
nn.init.xavier_normal_(self.att_src)
nn.init.xavier_normal_(self.att_dst)
if self.edge_dim is not None:
nn.init.xavier_normal_(self.lin_edge)
nn.init.xavier_normal_(self.att_edge)
if self.bias is not None:
nn.init.zeros_(self.bias)
def forward(self, g, x: Tensor, edge_attr: Optional[Tensor] = None):
H, C = self.heads, self.out_channels
edge_index = g.edge_index
x = (x @ self.weight).view(-1, H, C)
alpha_j = (x * self.att_src).sum(dim=-1)
alpha_j = alpha_j[edge_index[0]]
alpha_i = (x * self.att_dst).sum(dim=-1)
alpha_i = alpha_i[edge_index[1]]
if self.edge_dim is not None:
if edge_attr.dim() == 1:
edge_attr = edge_attr.view(-1, 1)
e = (edge_attr @ self.lin_edge).view(-1, H, C)
alpha_e = (e * self.att_edge).sum(dim=-1)
alpha = alpha_i + alpha_j + alpha_e
else:
alpha = alpha_i + alpha_j
alpha = F.leaky_relu(alpha, self.negative_slope)
alpha = softmax(
src=alpha,
index=edge_index[1],
num_nodes=g.dst_size,
)
alpha = F.dropout(alpha, p=self.dropout, training=self.training)
x = x[edge_index[0]] * alpha.view(-1, H, 1)
x = scatter_sum(x, edge_index[1], dim=0, dim_size=g.dst_size)
if self.concat:
x = x.view(-1, H * C)
else:
x = x.mean(dim=1)
if self.bias is not None:
x += self.bias
return x
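A quick smoke test for the `GATConv` defined above. Its forward only reads `edge_index` and `dst_size` from `g`, so a throwaway stand-in object is enough here; the import path is assumed from the package layout, and this is not how the module is driven inside the (commented-out) BasicGNN:
import torch
from types import SimpleNamespace
from starrygl.nn.convs import GATConv      # assumed import path
conv = GATConv(in_channels=8, out_channels=4, heads=2, concat=True)
g = SimpleNamespace(
    edge_index=torch.tensor([[0, 1, 2], [1, 2, 0]]),   # local COO edges
    dst_size=3,
)
x = torch.randn(3, 8)                      # one feature row per source node
out = conv(g, x)
print(out.shape)                           # torch.Size([3, 8]) = heads * out_channels with concat=True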
...@@ -10,7 +10,7 @@ from typing import * ...@@ -10,7 +10,7 @@ from typing import *
from starrygl.graph import DistGraph from starrygl.graph import DistGraph
from .gat_conv import GATConv from .gat_conv import GATConv
from .utils import ShrinkData from .utils import ShrinkHelper
...@@ -38,29 +38,36 @@ class ShrinkGATConv(GATConv): ...@@ -38,29 +38,36 @@ class ShrinkGATConv(GATConv):
**kwargs **kwargs
) )
def forward(self, g: DistGraph) -> Tuple[Tensor, Tensor]: def forward(self,
m = g.ndata.get("m") g: DistGraph,
if m is None: sh: Optional[ShrinkHelper] = None,
return super().forward(g), None dst_idx: Optional[Tensor] = None,
) -> Tensor:
if sh is None and dst_idx is None:
return super().forward(g)
if sh is None:
sh = ShrinkHelper(g, dst_idx)
H, C = self.heads, self.out_channels H, C = self.heads, self.out_channels
x = g.ndata["x"] x = g.ndata["x"]
s = ShrinkData(g, m) src_x = x[sh.src_idx]
x = x[s.src_mask] dst_x = x[sh.dst_idx]
edge_index = s.edge_index edge_index = sh.edges
x = (x @ self.weight).view(-1, H, C) src_x = (src_x @ self.weight).view(-1, H, C)
dst_x = (dst_x @ self.weight).view(-1, H, C)
alpha_j = (x * self.att_src).sum(dim=-1) alpha_j = (src_x * self.att_src).sum(dim=-1)
alpha_j = alpha_j[edge_index[0]] alpha_j = alpha_j[edge_index[0]]
alpha_i = (x * self.att_dst).sum(dim=-1) alpha_i = (dst_x * self.att_dst).sum(dim=-1)
alpha_i = alpha_i[edge_index[1]] alpha_i = alpha_i[edge_index[1]]
if self.edge_dim is not None: if self.edge_dim is not None:
edge_attr = g.edata["edge_attr"] edge_attr = g.edata["edge_attr"]
edge_attr = edge_attr[s.edge_mask] edge_attr = edge_attr[sh.edge_idx]
if edge_attr.dim() == 1: if edge_attr.dim() == 1:
edge_attr = edge_attr.view(-1, 1) edge_attr = edge_attr.view(-1, 1)
...@@ -73,13 +80,13 @@ class ShrinkGATConv(GATConv): ...@@ -73,13 +80,13 @@ class ShrinkGATConv(GATConv):
alpha = softmax( alpha = softmax(
src=alpha, src=alpha,
index=edge_index[1], index=edge_index[1],
num_nodes=s.dst_size, num_nodes=sh.dst_size,
) )
alpha = F.dropout(alpha, p=self.dropout, training=self.training) alpha = F.dropout(alpha, p=self.dropout, training=self.training)
x = x[edge_index[0]] * alpha.view(-1, H, 1) x = x[edge_index[0]] * alpha.view(-1, H, 1)
x = scatter_sum(x, edge_index[1], dim=0, dim_size=s.dst_size) x = scatter_sum(x, edge_index[1], dim=0, dim_size=sh.dst_size)
if self.concat: if self.concat:
x = x.view(-1, H * C) x = x.view(-1, H * C)
...@@ -88,5 +95,5 @@ class ShrinkGATConv(GATConv): ...@@ -88,5 +95,5 @@ class ShrinkGATConv(GATConv):
if self.bias is not None: if self.bias is not None:
x += self.bias x += self.bias
return x, s.dst_mask return x
...@@ -8,7 +8,7 @@ from typing import * ...@@ -8,7 +8,7 @@ from typing import *
from starrygl.graph import DistGraph from starrygl.graph import DistGraph
from .gcn_conv import GCNConv from .gcn_conv import GCNConv
from .utils import ShrinkData from .utils import ShrinkHelper
class ShrinkGCNConv(GCNConv): class ShrinkGCNConv(GCNConv):
def __init__(self, def __init__(self,
...@@ -24,24 +24,28 @@ class ShrinkGCNConv(GCNConv): ...@@ -24,24 +24,28 @@ class ShrinkGCNConv(GCNConv):
**kwargs **kwargs
) )
def forward(self, g: DistGraph) -> Tuple[Tensor, Tensor]: def forward(self,
m = g.ndata.get("m") g: DistGraph,
if m is None: sh: Optional[ShrinkHelper] = None,
return super().forward(g), None dst_idx: Optional[Tensor] = None,
) -> Tensor:
if sh is None and dst_idx is None:
return super().forward(g)
if sh is None:
sh = ShrinkHelper(g, dst_idx)
x = g.ndata["x"] x = g.ndata["x"]
gcn_norm = g.edata["gcn_norm"].view(-1, 1) gcn_norm = g.edata["gcn_norm"].view(-1, 1)
s = ShrinkData(g, m) x = x[sh.src_idx]
gcn_norm = gcn_norm[sh.edge_idx]
x = x[s.src_mask] edge_index = sh.edges
gcn_norm = gcn_norm[s.edge_mask]
edge_index = s.edge_index
x = x @ self.weight x = x @ self.weight
x = x[edge_index[0]] * gcn_norm x = x[edge_index[0]] * gcn_norm
x = scatter_sum(x, edge_index[1], dim=0, dim_size=s.dst_size) x = scatter_sum(x, edge_index[1], dim=0, dim_size=sh.dst_size)
if self.bias is not None: if self.bias is not None:
x += self.bias x += self.bias
return x, s.dst_mask return x
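The rewritten ShrinkGCNConv forward reduces to a weighted neighbor sum over the shrunk edge list. A standalone illustration of that aggregation step, with made-up shapes and `gcn_norm` standing in for the precomputed per-edge normalization:
import torch
from torch_scatter import scatter_sum
src_x = torch.randn(4, 16)                           # features of the shrunk source nodes
edges = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])   # local (src, dst) pairs after shrinking
gcn_norm = torch.rand(edges.size(1), 1)              # per-edge normalization weights
weight = torch.randn(16, 8)                          # layer weight matrix
h = src_x @ weight                                   # linear transform
msg = h[edges[0]] * gcn_norm                         # scale each message by its edge norm
out = scatter_sum(msg, edges[1], dim=0, dim_size=2)  # sum into the 2 destination nodes
print(out.shape)                                     # torch.Size([2, 8])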
...@@ -8,7 +8,7 @@ from typing import * ...@@ -8,7 +8,7 @@ from typing import *
from starrygl.graph import DistGraph from starrygl.graph import DistGraph
from .gin_conv import GINConv from .gin_conv import GINConv
from .utils import ShrinkData from .utils import ShrinkHelper
...@@ -30,17 +30,25 @@ class ShrinkGINConv(GINConv): ...@@ -30,17 +30,25 @@ class ShrinkGINConv(GINConv):
**kwargs **kwargs
) )
def forward(self, g: DistGraph) -> Tuple[Tensor, Tensor]: def forward(self,
m = g.ndata.get("m") g: DistGraph,
if m is None: sh: Optional[ShrinkHelper] = None,
return super().forward(g), None dst_idx: Optional[Tensor] = None,
) -> Tensor:
if sh is None and dst_idx is None:
return super().forward(g)
if sh is None:
sh = ShrinkHelper(g, dst_idx)
x = g.ndata["x"] x = g.ndata["x"]
s = ShrinkData(g, m)
edge_index = s.edge_index
src_x = x[sh.src_idx]
dst_x = x[sh.dst_idx]
edge_index = sh.edges
z = x[s.src_mask][edge_index[0]] z = scatter_sum(src_x[edge_index[0]], index=edge_index[1], dim=0, dim_size=sh.dst_size)
z = scatter_sum(z, edge_index[1], dim=0, dim_size=s.dst_size) x = z + (1 + self.eps) * dst_x
x = z + (1 + self.eps) * x[s.dst_mask] return self.nn(x)
return self.nn(x), s.dst_mask \ No newline at end of file
\ No newline at end of file
import torch # import torch
from torch import Tensor # from torch import Tensor
from typing import * # from typing import *
from starrygl.graph import DistGraph
class ShrinkData: # class ShrinkHelper:
def __init__(self, g: DistGraph, m: Tensor) -> None: # def __init__(self, g, dst_idx: Tensor) -> None:
edge_index = g.edge_index # from starrygl.graph import DistGraph
# g: DistGraph = g
# src_to_dst edge's mask # self.device = dst_idx.device
m = m[edge_index[0]]
# dst's mask # dst_m = torch.zeros(g.dst_size, dtype=torch.bool, device=self.device)
m = edge_index[1][m] # dst_m.index_fill_(0, dst_idx, 1)
dst_m: Tensor = torch.zeros(g.dst_size, dtype=torch.bool, device=m.device).index_fill_(0, m, 1)
self.dst_mask = dst_m
# dst_to_src edge's mask # edge_idx = torch.where(dst_m[g.edge_index[1]])[0]
m = dst_m[edge_index[1]] # edge_index = g.edge_index[:, edge_idx]
self.edge_mask = m
# src's mask # src_idx = edge_index[0]
edge_index = edge_index[:,m] # src_m = torch.zeros(g.src_size, dtype=torch.bool, device=self.device)
src_m: Tensor = torch.zeros(g.src_size, dtype=torch.bool, device=m.device).index_fill_(0, edge_index[0], 1) # src_m.index_fill_(0, src_idx, 1)
self.src_mask = src_m # src_idx = torch.where(src_m)[0]
self._src_size = src_m.count_nonzero() # imp = torch.empty(max(g.src_size, g.dst_size), dtype=torch.long, device=self.device)
self._dst_size = dst_m.count_nonzero()
# print(f"{dst_m.size()}_{self._dst_size}") # imp[src_idx] = torch.arange(src_idx.size(0), dtype=torch.long, device=self.device)
# src = imp[edge_index[0]]
# edge_index # imp[dst_idx] = torch.arange(dst_idx.size(0), dtype=torch.long, device=self.device)
imp = torch.empty(g.src_size, dtype=torch.long, device=m.device) # dst = imp[edge_index[1]]
imp[src_m] = torch.arange(self._src_size, dtype=torch.long, device=m.device)
src = imp[edge_index[0]]
imp = torch.empty(g.dst_size, dtype=torch.long, device=m.device) # self.src_idx = src_idx
imp[dst_m] = torch.arange(self._dst_size, dtype=torch.long, device=m.device) # self.dst_idx = dst_idx
dst = imp[edge_index[1]] # self.edge_idx = edge_idx
# self.edges = torch.vstack([src, dst])
self.edge_index = torch.vstack([src, dst]) # @property
# def src_size(self) -> int:
@property # return self.src_idx.size(0)
def src_size(self) -> int:
return self._src_size.item()
@property
def dst_size(self) -> int:
return self._dst_size.item()
# @property
# def dst_size(self) -> int:
# return self.dst_idx.size(0)
\ No newline at end of file
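Since the new `ShrinkHelper` above is still commented out, here is a runnable sketch of the compaction it is meant to perform: keep only edges whose destination is in `dst_idx`, then re-index both endpoints into dense local ranges. Sizes and names are illustrative, not taken from the repository:
import torch
def shrink(edge_index: torch.Tensor, dst_idx: torch.Tensor, src_size: int, dst_size: int):
    dst_m = torch.zeros(dst_size, dtype=torch.bool)
    dst_m[dst_idx] = True
    edge_idx = torch.where(dst_m[edge_index[1]])[0]   # edges pointing into dst_idx
    sub = edge_index[:, edge_idx]
    src_m = torch.zeros(src_size, dtype=torch.bool)
    src_m[sub[0]] = True
    src_idx = torch.where(src_m)[0]                   # surviving source vertices
    imp = torch.empty(max(src_size, dst_size), dtype=torch.long)
    imp[src_idx] = torch.arange(src_idx.size(0))      # global -> local src
    src = imp[sub[0]]
    imp[dst_idx] = torch.arange(dst_idx.size(0))      # global -> local dst (src already read)
    dst = imp[sub[1]]
    return src_idx, edge_idx, torch.vstack([src, dst])
edge_index = torch.tensor([[0, 1, 2, 3], [0, 1, 1, 2]])
src_idx, edge_idx, edges = shrink(edge_index, torch.tensor([1]), src_size=4, dst_size=3)
print(src_idx, edge_idx, edges)   # tensor([1, 2]) tensor([1, 2]) tensor([[0, 1], [0, 0]])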
...@@ -7,6 +7,7 @@ import os ...@@ -7,6 +7,7 @@ import os
from .degree import compute_degree, compute_gcn_norm from .degree import compute_degree, compute_gcn_norm
from .sync_bn import SyncBatchNorm from .sync_bn import SyncBatchNorm
from ..core import with_gloo, with_nccl from ..core import with_gloo, with_nccl
from ..core.route import get_executor
def convert_parallel_model( def convert_parallel_model(
net: nn.Module, net: nn.Module,
...@@ -20,7 +21,7 @@ def convert_parallel_model( ...@@ -20,7 +21,7 @@ def convert_parallel_model(
for name, buffer in net.named_buffers(): for name, buffer in net.named_buffers():
if name.endswith("last_embd"): if name.endswith("last_embd"):
continue continue
if name.endswith("grad_norm"): if name.endswith("last_w"):
continue continue
dist.broadcast(buffer, src=0) dist.broadcast(buffer, src=0)
return net return net
...@@ -37,4 +38,6 @@ def init_process_group(backend: str = "gloo") -> torch.device: ...@@ -37,4 +38,6 @@ def init_process_group(backend: str = "gloo") -> torch.device:
torch.cuda.set_device(device) torch.cuda.set_device(device)
else: else:
device = torch.device("cpu") device = torch.device("cpu")
get_executor() # initialize route executor
return device return device
\ No newline at end of file
...@@ -6,7 +6,7 @@ from typing import * ...@@ -6,7 +6,7 @@ from typing import *
from torch_scatter import scatter_sum from torch_scatter import scatter_sum
from ..graph.distgraph import DistGraph from ..graph.distgraph import DistGraph
from ..core.gather import Gather # from ..core.gather import Gather
def compute_degree(g: DistGraph) -> Tuple[Tensor, Tensor]: def compute_degree(g: DistGraph) -> Tuple[Tensor, Tensor]:
......
...@@ -13,4 +13,4 @@ def main_print(*args, **kwargs): ...@@ -13,4 +13,4 @@ def main_print(*args, **kwargs):
rank = dist.get_rank() rank = dist.get_rank()
if rank == 0: if rank == 0:
print(*args, **kwargs) print(*args, **kwargs)
dist.barrier() # dist.barrier()
\ No newline at end of file \ No newline at end of file