Commit a4c4cadb by Wenjie Huang

with bugs

parent 06d10ed3
from .a2a import all_to_all, with_gloo, with_nccl
# from .cache import EmbeddingCache
from .gather import Gather
from .route import Route, GatherWork
\ No newline at end of file
# from .gather import Gather
# from .route import Route, GatherWork
from .route import Route, RouteWorkBase
from .cache import MessageCache, NodeProbe
\ No newline at end of file
......@@ -30,12 +30,14 @@ class Works:
def all_to_all(
output_tensor_list: List[Tensor],
input_tensor_list: List[Tensor],
group: Optional[Any] = None,
) -> Works:
assert len(output_tensor_list) == len(input_tensor_list)
if with_nccl():
work = dist.all_to_all(
output_tensor_list=output_tensor_list,
input_tensor_list=input_tensor_list,
group=group,
async_op=True,
)
return Works(work)
......@@ -48,8 +50,8 @@ def all_to_all(
send_i = (rank + i) % world_size
recv_i = (rank - i + world_size) % world_size
send_w = dist.isend(input_tensor_list[send_i], send_i)
recv_w = dist.irecv(output_tensor_list[recv_i], recv_i)
send_w = dist.isend(input_tensor_list[send_i], send_i, group=group)
recv_w = dist.irecv(output_tensor_list[recv_i], recv_i, group=group)
works.push(recv_w, send_w)
output_tensor_list[rank][:] = input_tensor_list[rank]
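A minimal usage sketch for the fallback above, assuming torch.distributed is already initialized and that Works exposes a blocking wait() (not shown in this hunk):

import torch
import torch.distributed as dist
from torch import Tensor
from typing import List

def exchange_features(dim: int = 4) -> List[Tensor]:
    # hypothetical example: every rank sends one small tensor to every other rank
    rank, world_size = dist.get_rank(), dist.get_world_size()
    inputs = [torch.full((dim,), float(rank)) for _ in range(world_size)]
    outputs = [torch.empty(dim) for _ in range(world_size)]
    work = all_to_all(outputs, inputs)
    work.wait()  # assumed to drain the pending isend/irecv pairs
    return outputs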
......
import torch
from torch import Tensor
from typing import *
from multiprocessing.pool import ThreadPool
from .event import Event
class AsyncCopyWorkBase:
def __init__(self) -> None:
self._events: Optional[Tuple[Event, Event]] = None
def wait(self):
raise NotImplementedError
def get(self) -> Tensor:
raise NotImplementedError
def has_events(self) -> bool:
return self._events is not None
def set_events(self, start, end):
self._events = (start, end)
return self
def time_used(self) -> float:
if self._events is None:
raise RuntimeError("not found events")
start, end = self._events
return start.elapsed_time(end)
class AsyncPushWork(AsyncCopyWorkBase):
def __init__(self, data: Tensor, index: Tensor, values: Tensor) -> None:
super().__init__()
assert data.device == index.device
self.set_events(
Event(use_cuda=index.is_cuda),
Event(use_cuda=index.is_cuda),
)
self._events[0].record()
data.index_copy_(0, index, values)
self._events[1].record()
def wait(self):
pass
def get(self):
pass
class AsyncPullWork(AsyncCopyWorkBase):
def __init__(self, data: Tensor, index: Tensor) -> None:
super().__init__()
assert data.device == index.device
self.set_events(
Event(use_cuda=index.is_cuda),
Event(use_cuda=index.is_cuda),
)
self._events[0].record()
self._val = data.index_select(0, index)
self._events[1].record()
def wait(self):
pass
def get(self) -> Tensor:
return self._val
class AsyncOffloadWork(AsyncCopyWorkBase):
def __init__(self, handle) -> None:
super().__init__()
self._handle = handle
def wait(self):
if self._events is not None:
self._events[1].wait(torch.cuda.current_stream())
self._handle.wait()
def get(self) -> Tensor:
if self._events is not None:
self._events[1].wait(torch.cuda.current_stream())
return self._handle.get()
class AsyncCopyExecutor:
def __init__(self) -> None:
self._stream = torch.cuda.Stream()
self._executor = ThreadPool(processes=1)
@torch.no_grad()
def async_pull(self, data: Tensor, index: Tensor) -> AsyncCopyWorkBase:
# Ideally all of this code would run in a single dedicated thread: it makes timing easier and keeps the copy asynchronous.
if data.device != index.device:
assert not data.is_cuda
assert index.is_cuda
start = Event(index.is_cuda)
end = Event(index.is_cuda)
stream: Optional[torch.cuda.Stream] = self._stream
def run():
start.wait(stream)
with torch.cuda.stream(stream):
idx = index.to(data.device)
dst = torch.zeros(
size=(index.size(0),) + data.shape[1:],
dtype=data.dtype,
device=index.device,
)
dst.copy_(data[idx])
end.record()
return dst
start.record()
handle = self._executor.apply_async(run)
return AsyncOffloadWork(handle).set_events(start, end)
else:
# data and index in the same device
return AsyncPullWork(data, index)
@torch.no_grad()
def async_push(self, data: Tensor, index: Tensor, values: Tensor) -> AsyncCopyWorkBase:
assert index.device == values.device
if data.device != index.device:
assert not data.is_cuda
assert index.is_cuda
start = Event(index.is_cuda)
end = Event(index.is_cuda)
stream: Optional[torch.cuda.Stream] = self._stream
def run():
start.wait(stream)
with torch.cuda.stream(stream):
idx = index.to(data.device)
val = values.to(data.device)
data[idx] = val
end.record()
start.record()
handle = self._executor.apply_async(run)
return AsyncOffloadWork(handle).set_events(start, end)
else:
return AsyncPushWork(data, index, values)
_THREAD_EXEC: Optional[AsyncCopyExecutor] = None
def get_executor() -> AsyncCopyExecutor:
global _THREAD_EXEC
if _THREAD_EXEC is None:
_THREAD_EXEC = AsyncCopyExecutor()
return _THREAD_EXEC
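A hedged usage sketch for the executor above, pulling rows of a CPU-resident table with a CUDA index (requires a CUDA device; the synchronize call is only there so the recorded events are complete before the timing is read):

import torch

feats = torch.randn(1000, 64)                       # CPU-resident feature table
index = torch.randint(0, 1000, (128,), device="cuda")

work = get_executor().async_pull(feats, index)      # AsyncOffloadWork on the copy stream
# ... overlap other computation here ...
values = work.get()                                 # (128, 64) tensor on the CUDA device
torch.cuda.synchronize()                            # ensure both recorded events have completed
print(f"pull took {work.time_used():.3f} ms")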
import torch
import time
from torch import Tensor
from typing import *
class Event:
def __init__(self, use_cuda: bool = True) -> None:
self._use_cuda = use_cuda
if use_cuda:
self._event = torch.cuda.Event(enable_timing=True)
else:
self._time: Optional[float] = None
def record(self):
if self._use_cuda:
self._event.record()
else:
self._time = time.time()
def wait(self, stream: Optional[torch.cuda.Stream] = None):
if not self._use_cuda:
return
self._event.wait(stream)
def elapsed_time(self, other) -> float:
if self._use_cuda:
return self._event.elapsed_time(other._event)
else:
return (other._time - self._time) * 1000.0
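A short sketch of how this Event wrapper can time a block of work uniformly on CPU or GPU (CUDA events must finish before elapsed_time is read, hence the synchronize):

import torch

use_cuda = torch.cuda.is_available()
device = "cuda" if use_cuda else "cpu"
start, end = Event(use_cuda=use_cuda), Event(use_cuda=use_cuda)

start.record()
x = torch.randn(1024, 1024, device=device)
y = x @ x
end.record()
if use_cuda:
    torch.cuda.synchronize()
print(f"matmul took {start.elapsed_time(end):.3f} ms")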
import torch
from torch import Tensor
from typing import *
class ShrinkData:
no_idx: int = (2**62-1)*2+1
def __init__(self,
src_size: int,
dst_size: int,
dst_idx: Tensor,
edge_index: Tensor,
bipartite: bool = False,
) -> None:
device = dst_idx.device
tmp = torch.empty(max(src_size, dst_size), dtype=torch.bool, device=device)
tmp.fill_(0)
tmp.index_fill_(0, dst_idx, 1)
edge_idx = torch.where(tmp[edge_index[1]])[0]
edge_index = edge_index[:, edge_idx]
if bipartite:
tmp.fill_(0)
tmp.index_fill_(0, edge_index[0], 1)
src_idx = torch.where(tmp)[0]
imp = torch.empty(max(src_size, dst_size), dtype=torch.long, device=device)
imp[dst_idx] = torch.arange(dst_idx.size(0), dtype=torch.long, device=device)
dst = imp[edge_index[1]]
imp.fill_(self.no_idx)
imp[src_idx] = torch.arange(src_idx.size(0), dtype=torch.long, device=device)
src = imp[edge_index[0]]
edge_index = torch.vstack([src, dst])
else:
tmp.index_fill_(0, edge_index[0], 1)
tmp.index_fill_(0, dst_idx, 0)
src_idx = torch.cat([dst_idx, torch.where(tmp)[0]], dim=0)
imp = torch.empty(max(src_size, dst_size), dtype=torch.long, device=device)
imp.fill_(self.no_idx)
imp[src_idx] = torch.arange(src_idx.size(0), dtype=torch.long, device=device)
edge_index = imp[edge_index.flatten()].view_as(edge_index)
self.src_idx = src_idx
self.dst_idx = dst_idx
self.edge_idx = edge_idx
self.edge_index = edge_index
self._src_imp = imp[:src_size]
def to(self, device):
self.src_idx = self.src_idx.to(device)
self.dst_idx = self.dst_idx.to(device)
self.edge_idx = self.edge_idx.to(device)
self.edge_index = self.edge_index.to(device)
self._src_imp = self._src_imp.to(device)
return self
def shrink_src_val_and_idx(self, val: Tensor, idx: Tensor) -> Tuple[Tensor, Tensor]:
idx = self._src_imp[idx]
m = (idx != self.no_idx)
return val[m], idx[m]
@property
def src_size(self) -> int:
return self.src_idx.size(0)
@property
def dst_size(self) -> int:
return self.dst_idx.size(0)
@property
def edge_size(self) -> int:
return self.edge_idx.size(0)
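A toy sketch of ShrinkData restricting a local graph to a couple of destination vertices (non-bipartite path; shapes follow the constructor above):

import torch

edge_index = torch.tensor([[0, 1, 2, 3, 4, 5],
                           [0, 0, 1, 2, 3, 3]])
dst_idx = torch.tensor([0, 2])               # keep only edges into destinations 0 and 2

s = ShrinkData(src_size=6, dst_size=4, dst_idx=dst_idx, edge_index=edge_index)
print(s.edge_idx)       # tensor([0, 1, 3]): surviving positions in the original edge list
print(s.edge_index)     # tensor([[0, 2, 3], [0, 0, 1]]): same edges in the shrunken index space
print(s.src_size, s.dst_size, s.edge_size)   # 4 2 3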
......@@ -5,26 +5,23 @@ from torch import Tensor
from typing import *
from contextlib import contextmanager
import torch_sparse
# import torch_sparse
from .ndata import NData
from .edata import EData
from .utils import init_local_edge_index
from ..core import Route
from ..core import MessageCache, Route
class DistGraph:
def __init__(self,
ids: Tensor,
edge_index: Tensor,
args: Dict[str, Any] = {},
device: Union[torch.device, str, int, None] = None,
num_features: int,
num_layers: int,
cache_device: str = "cpu",
**args: Dict[str, Any],
):
# graph's target device
if device is None:
device = ids.device
self.device = torch.device(device)
# build local_edge_index
dst_ids = ids
src_ids, local_edge_index = init_local_edge_index(
......@@ -34,8 +31,15 @@ class DistGraph:
self._src_ids = src_ids
self._dst_ids = dst_ids
self._local_edge_index = local_edge_index
self._local_edge_ptr: Optional[Tensor] = None
self._message_cache = MessageCache(
src_ids=src_ids,
dst_ids=dst_ids,
edge_index=local_edge_index,
num_features=num_features,
num_layers=num_layers,
cache_device=cache_device,
bipartite=False,
)
# node's attributes
self.ndata = NData(
......@@ -49,24 +53,30 @@ class DistGraph:
)
# graph's attributes
self.args = args
self.args = dict(args)
def to(self, device):
self._message_cache.to(device)
return self
def cache_data_to(self, device):
self._message_cache.cached_data_to(device)
@property
def cache(self) -> MessageCache:
return self._message_cache
# route table
self.route = Route(dst_ids, src_ids)
@property
def route(self) -> Route:
return self._message_cache.route
@property
def edge_index(self) -> Tensor:
if self._local_edge_index.device != self.device:
self._local_edge_index = self._local_edge_index.to(self.device)
return self._local_edge_index
def device(self) -> torch.device:
return self.edge_index.device
@property
def edge_ptr(self) -> Optional[Tensor]:
if self._local_edge_ptr is None:
return None
if self._local_edge_ptr.device != self.device:
self._local_edge_ptr = self._local_edge_ptr.to(self.device)
return self._local_edge_ptr
def edge_index(self) -> Tensor:
return self._message_cache.edge_index
@property
def src_ids(self) -> Tensor:
......@@ -88,16 +98,6 @@ class DistGraph:
def dst_size(self) -> int:
return self._dst_ids.size(0)
@torch.no_grad()
def gather(self, values: Tensor, direct: str = "dst_to_src") -> Tensor:
if direct == "dst_to_src":
work = self.route.gather_forward(values, None, async_op=False, return_active=False)
elif direct == "src_to_dst":
work = self.route.gather_backward(values, None, async_op=False, return_active=False)
else:
raise ValueError(f"direct must be 'src_to_dst' or 'dst_to_src', but got '{direct}'")
return work.get_values()
@contextmanager
def scoped_manager(self):
stacked_ndata = self.ndata
......
......@@ -6,6 +6,7 @@ from typing import *
def init_local_edge_index(
dst_ids: Tensor,
edge_index: Tensor,
bipartite: bool = False,
) -> Tuple[Tensor, Tensor]:
max_ids = calc_max_ids(dst_ids, edge_index)
ikw = dict(dtype=torch.long, device=dst_ids.device)
......@@ -17,17 +18,26 @@ def init_local_edge_index(
if not (xmp != 0x01).all():
raise RuntimeError(f"must be vertex-cut partition graph")
# src_ids 等于 [dst_ids, edge_index[0] except dst_ids]
xmp.fill_(0)
xmp[edge_index[0]] = 1
xmp[dst_ids] = 0
src_ids = torch.cat([dst_ids, torch.where(xmp > 0)[0]], dim=-1)
if bipartite:
src_ids = edge_index[0].unique()
else:
# assume a homogeneous graph
# src_ids equals [dst_ids, edge_index[0] except dst_ids]
xmp.fill_(0)
xmp[edge_index[0]] = 1
xmp[dst_ids] = 0
src_ids = torch.cat([dst_ids, torch.where(xmp > 0)[0]], dim=-1)
# compute local (re-indexed) edge indices
xmp.fill_((2**62-1)*2+1)
xmp[src_ids] = torch.arange(src_ids.size(0), **ikw)
local_edge_index = xmp[edge_index.flatten()].view_as(edge_index)
src = xmp[edge_index[0]]
xmp.fill_((2**62-1)*2+1)
xmp[dst_ids] = torch.arange(dst_ids.size(0), **ikw)
dst = xmp[edge_index[1]]
local_edge_index = torch.vstack([src, dst])
return src_ids, local_edge_index
def calc_max_ids(*ids: Tensor) -> int:
......
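A toy sketch of the re-indexing done by init_local_edge_index, assuming the (elided) vertex-cut check passes because every edge destination is locally owned:

import torch

dst_ids = torch.tensor([10, 11])             # global ids owned by this partition
edge_index = torch.tensor([[10, 12, 13],     # global source ids
                           [10, 11, 11]])    # global destination ids, all locally owned

src_ids, local_edge_index = init_local_edge_index(dst_ids, edge_index)
print(src_ids)            # expected tensor([10, 11, 12, 13]): dst_ids first, then remote sources
print(local_edge_index)   # expected tensor([[0, 2, 3], [0, 1, 1]]) in the local index space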
from .convs import GCNConv, GATConv, GINConv
from .convs import ShrinkGCNConv, ShrinkGATConv, ShrinkGINConv
from .basic_gnn import BasicGNN, BasicLayerOptions, BasicInputOptions, BasicJKOptions, BasicStraightOptions
# from .convs import ShrinkGCNConv, ShrinkGATConv, ShrinkGINConv
# from .convs import ShrinkHelper
# from .basic_gnn import BasicGNN, BasicLayerOptions, BasicInputOptions, BasicStraightOptions
class GCN(BasicGNN):
def init_conv(self, in_channels: int, out_channels: int, **kwargs):
return GCNConv(in_channels, out_channels, **kwargs)
# class ShrinkGCN(BasicGNN):
# def init_conv(self, in_channels: int, out_channels: int, **kwargs):
# return ShrinkGCNConv(in_channels, out_channels, **kwargs)
class GAT(BasicGNN):
def init_conv(self, in_channels: int, out_channels: int, **kwargs):
return GATConv(in_channels, out_channels, **kwargs)
# class ShrinkGAT(BasicGNN):
# def init_conv(self, in_channels: int, out_channels: int, **kwargs):
# return ShrinkGATConv(in_channels, out_channels, **kwargs)
class GIN(BasicGNN):
def init_conv(self, in_channels: int, out_channels: int, **kwargs):
return GINConv(in_channels, out_channels, **kwargs)
# class ShrinkGIN(BasicGNN):
# def init_conv(self, in_channels: int, out_channels: int, **kwargs):
# return ShrinkGINConv(in_channels, out_channels, **kwargs)
\ No newline at end of file
from .gcn_conv import GCNConv
from .gat_conv import GATConv
from .gin_conv import GINConv
from .shrink_gcn_conv import ShrinkGCNConv
from .shrink_gat_conv import ShrinkGATConv
from .shrink_gin_conv import ShrinkGINConv
# from .shrink_gcn_conv import ShrinkGCNConv
# from .shrink_gat_conv import ShrinkGATConv
# from .shrink_gin_conv import ShrinkGINConv
# from .utils import ShrinkHelper
\ No newline at end of file
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.utils import softmax
from torch_scatter import scatter_sum
from torch import Tensor
from typing import *
from starrygl.graph import DistGraph
class GATConv(nn.Module):
def __init__(self,
in_channels: int,
out_channels: int,
heads: int = 1,
concat: bool = False,
negative_slope: float = 0.2,
dropout: float = 0.0,
edge_dim: Optional[int] = None,
bias: bool = True,
**kwargs
) -> None:
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.heads = heads
self.concat = concat
self.negative_slope = negative_slope
self.dropout = dropout
self.edge_dim = edge_dim
self.weight = nn.Parameter(torch.Tensor(in_channels, heads * out_channels))
self.att_src = nn.Parameter(torch.Tensor(1, heads, out_channels))
self.att_dst = nn.Parameter(torch.Tensor(1, heads, out_channels))
if edge_dim is not None:
self.lin_edge = nn.Parameter(torch.Tensor(edge_dim, heads * out_channels))
self.att_edge = nn.Parameter(torch.Tensor(1, heads, out_channels))
if bias and concat:
self.bias = nn.Parameter(torch.Tensor(heads * out_channels))
elif bias and not concat:
self.bias = nn.Parameter(torch.Tensor(out_channels))
else:
self.bias = None
self.reset_parameters()
def reset_parameters(self):
nn.init.xavier_normal_(self.weight)
nn.init.xavier_normal_(self.att_src)
nn.init.xavier_normal_(self.att_dst)
if self.edge_dim is not None:
nn.init.xavier_normal_(self.lin_edge)
nn.init.xavier_normal_(self.att_edge)
if self.bias is not None:
nn.init.zeros_(self.bias)
def forward(self, g, x: Tensor, edge_attr: Optional[Tensor] = None):
H, C = self.heads, self.out_channels
edge_index = g.edge_index
x = (x @ self.weight).view(-1, H, C)
alpha_j = (x * self.att_src).sum(dim=-1)
alpha_j = alpha_j[edge_index[0]]
alpha_i = (x * self.att_dst).sum(dim=-1)
alpha_i = alpha_i[edge_index[1]]
if self.edge_dim is not None:
if edge_attr.dim() == 1:
edge_attr = edge_attr.view(-1, 1)
e = (edge_attr @ self.lin_edge).view(-1, H, C)
alpha_e = (e * self.att_edge).sum(dim=-1)
alpha = alpha_i + alpha_j + alpha_e
else:
alpha = alpha_i + alpha_j
alpha = F.leaky_relu(alpha, self.negative_slope)
alpha = softmax(
src=alpha,
index=edge_index[1],
num_nodes=g.dst_size,
)
alpha = F.dropout(alpha, p=self.dropout, training=self.training)
x = x[edge_index[0]] * alpha.view(-1, H, 1)
x = scatter_sum(x, edge_index[1], dim=0, dim_size=g.dst_size)
if self.concat:
x = x.view(-1, H * C)
else:
x = x.mean(dim=1)
if self.bias is not None:
x += self.bias
return x
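A hedged smoke test for the GATConv above, using a minimal stand-in object that only provides the edge_index and dst_size attributes the forward pass touches (a real caller passes a DistGraph):

import torch
from types import SimpleNamespace

conv = GATConv(in_channels=16, out_channels=8, heads=2, concat=True)
g = SimpleNamespace(
    edge_index=torch.tensor([[0, 1, 2, 3],
                             [0, 0, 1, 1]]),   # 4 edges into 2 destination nodes
    dst_size=2,
)
x = torch.randn(4, 16)                          # features for the 4 source nodes
out = conv(g, x)
print(out.shape)                                # torch.Size([2, 16]) with concat=True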
......@@ -10,7 +10,7 @@ from typing import *
from starrygl.graph import DistGraph
from .gat_conv import GATConv
from .utils import ShrinkData
from .utils import ShrinkHelper
......@@ -38,29 +38,36 @@ class ShrinkGATConv(GATConv):
**kwargs
)
def forward(self, g: DistGraph) -> Tuple[Tensor, Tensor]:
m = g.ndata.get("m")
if m is None:
return super().forward(g), None
def forward(self,
g: DistGraph,
sh: Optional[ShrinkHelper] = None,
dst_idx: Optional[Tensor] = None,
) -> Tensor:
if sh is None and dst_idx is None:
return super().forward(g)
if sh is None:
sh = ShrinkHelper(g, dst_idx)
H, C = self.heads, self.out_channels
x = g.ndata["x"]
s = ShrinkData(g, m)
x = x[s.src_mask]
edge_index = s.edge_index
src_x = x[sh.src_idx]
dst_x = x[sh.dst_idx]
edge_index = sh.edges
x = (x @ self.weight).view(-1, H, C)
src_x = (src_x @ self.weight).view(-1, H, C)
dst_x = (dst_x @ self.weight).view(-1, H, C)
alpha_j = (x * self.att_src).sum(dim=-1)
alpha_i = (src_x * self.att_src).sum(dim=-1)
alpha_j = alpha_j[edge_index[0]]
alpha_i = (x * self.att_dst).sum(dim=-1)
alpha_i = (dst_x * self.att_dst).sum(dim=-1)
alpha_i = alpha_i[edge_index[1]]
if self.edge_dim is not None:
edge_attr = g.edata["edge_attr"]
edge_attr = edge_attr[s.edge_mask]
edge_attr = edge_attr[sh.edge_idx]
if edge_attr.dim() == 1:
edge_attr = edge_attr.view(-1, 1)
......@@ -73,13 +80,13 @@ class ShrinkGATConv(GATConv):
alpha = softmax(
src=alpha,
index=edge_index[1],
num_nodes=s.dst_size,
num_nodes=sh.dst_size,
)
alpha = F.dropout(alpha, p=self.dropout, training=self.training)
x = x[edge_index[0]] * alpha.view(-1, H, 1)
x = scatter_sum(x, edge_index[1], dim=0, dim_size=s.dst_size)
x = scatter_sum(x, edge_index[1], dim=0, dim_size=sh.dst_size)
if self.concat:
x = x.view(-1, H * C)
......@@ -88,5 +95,5 @@ class ShrinkGATConv(GATConv):
if self.bias is not None:
x += self.bias
return x, s.dst_mask
return x
......@@ -8,7 +8,7 @@ from typing import *
from starrygl.graph import DistGraph
from .gcn_conv import GCNConv
from .utils import ShrinkData
from .utils import ShrinkHelper
class ShrinkGCNConv(GCNConv):
def __init__(self,
......@@ -24,24 +24,28 @@ class ShrinkGCNConv(GCNConv):
**kwargs
)
def forward(self, g: DistGraph) -> Tuple[Tensor, Tensor]:
m = g.ndata.get("m")
if m is None:
return super().forward(g), None
def forward(self,
g: DistGraph,
sh: Optional[ShrinkHelper] = None,
dst_idx: Optional[Tensor] = None,
) -> Tensor:
if sh is None and dst_idx is None:
return super().forward(g)
if sh is None:
sh = ShrinkHelper(g, dst_idx)
x = g.ndata["x"]
gcn_norm = g.edata["gcn_norm"].view(-1, 1)
s = ShrinkData(g, m)
x = x[s.src_mask]
gcn_norm = gcn_norm[s.edge_mask]
edge_index = s.edge_index
x = x[sh.src_idx]
gcn_norm = gcn_norm[sh.edge_idx]
edge_index = sh.edges
x = x @ self.weight
x = x[edge_index[0]] * gcn_norm
x = scatter_sum(x, edge_index[1], dim=0, dim_size=s.dst_size)
x = scatter_sum(x, edge_index[1], dim=0, dim_size=sh.dst_size)
if self.bias is not None:
x += self.bias
return x, s.dst_mask
return x
......@@ -8,7 +8,7 @@ from typing import *
from starrygl.graph import DistGraph
from .gin_conv import GINConv
from .utils import ShrinkData
from .utils import ShrinkHelper
......@@ -30,17 +30,25 @@ class ShrinkGINConv(GINConv):
**kwargs
)
def forward(self, g: DistGraph) -> Tuple[Tensor, Tensor]:
m = g.ndata.get("m")
if m is None:
return super().forward(g), None
def forward(self,
g: DistGraph,
sh: Optional[ShrinkHelper] = None,
dst_idx: Optional[Tensor] = None,
) -> Tensor:
if sh is None and dst_idx is None:
return super().forward(g)
if sh is None:
sh = ShrinkHelper(g, dst_idx)
x = g.ndata["x"]
s = ShrinkData(g, m)
edge_index = s.edge_index
src_x = x[sh.src_idx]
dst_x = x[sh.dst_idx]
edge_index = sh.edges
z = x[s.src_mask][edge_index[0]]
z = scatter_sum(z, edge_index[1], dim=0, dim_size=s.dst_size)
x = z + (1 + self.eps) * x[s.dst_mask]
return self.nn(x), s.dst_mask
\ No newline at end of file
z = scatter_sum(src_x[edge_index[0]], index=edge_index[1], dim=0, dim_size=sh.dst_size)
x = z + (1 + self.eps) * dst_x
return self.nn(x)
\ No newline at end of file
import torch
# import torch
from torch import Tensor
from typing import *
# from torch import Tensor
# from typing import *
from starrygl.graph import DistGraph
class ShrinkData:
def __init__(self, g: DistGraph, m: Tensor) -> None:
edge_index = g.edge_index
# src_to_dst edge's mask
m = m[edge_index[0]]
# class ShrinkHelper:
# def __init__(self, g, dst_idx: Tensor) -> None:
# from starrygl.graph import DistGraph
# g: DistGraph = g
# self.device = dst_idx.device
# dst's mask
m = edge_index[1][m]
dst_m: Tensor = torch.zeros(g.dst_size, dtype=torch.bool, device=m.device).index_fill_(0, m, 1)
self.dst_mask = dst_m
# dst_m = torch.zeros(g.dst_size, dtype=torch.bool, device=self.device)
# dst_m.index_fill_(0, dst_idx, 1)
# dst_to_src edge's mask
m = dst_m[edge_index[1]]
self.edge_mask = m
# edge_idx = torch.where(dst_m[g.edge_index[1]])[0]
# edge_index = g.edge_index[:, edge_idx]
# src's mask
edge_index = edge_index[:,m]
src_m: Tensor = torch.zeros(g.src_size, dtype=torch.bool, device=m.device).index_fill_(0, edge_index[0], 1)
self.src_mask = src_m
# src_idx = edge_index[0]
# src_m = torch.zeros(g.src_size, dtype=torch.bool, device=self.device)
# src_m.index_fill_(0, src_idx, 1)
# src_idx = torch.where(src_m)[0]
self._src_size = src_m.count_nonzero()
self._dst_size = dst_m.count_nonzero()
# imp = torch.empty(max(g.src_size, g.dst_size), dtype=torch.long, device=self.device)
# print(f"{dst_m.size()}_{self._dst_size}")
# imp[src_idx] = torch.arange(src_idx.size(0), dtype=torch.long, device=self.device)
# src = imp[edge_index[0]]
# edge_index
imp = torch.empty(g.src_size, dtype=torch.long, device=m.device)
imp[src_m] = torch.arange(self._src_size, dtype=torch.long, device=m.device)
src = imp[edge_index[0]]
# imp[dst_idx] = torch.arange(dst_idx.size(0), dtype=torch.long, device=self.device)
# dst = imp[edge_index[1]]
imp = torch.empty(g.dst_size, dtype=torch.long, device=m.device)
imp[dst_m] = torch.arange(self._dst_size, dtype=torch.long, device=m.device)
dst = imp[edge_index[1]]
# self.src_idx = src_idx
# self.dst_idx = dst_idx
# self.edge_idx = edge_idx
# self.edges = torch.vstack([src, dst])
self.edge_index = torch.vstack([src, dst])
@property
def src_size(self) -> int:
return self._src_size.item()
# @property
# def src_size(self) -> int:
# return self.src_idx.size(0)
@property
def dst_size(self) -> int:
return self._dst_size.item()
\ No newline at end of file
# @property
# def dst_size(self) -> int:
# return self.dst_idx.size(0)
\ No newline at end of file
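A hedged toy example of this mask-based ShrinkData with a mocked graph object exposing just the edge_index, src_size, and dst_size attributes it reads:

import torch
from types import SimpleNamespace

g = SimpleNamespace(
    edge_index=torch.tensor([[0, 1, 2, 3],
                             [0, 1, 2, 2]]),
    src_size=4,
    dst_size=3,
)
m = torch.tensor([True, False, False, True])   # per-source mask selecting nodes 0 and 3

s = ShrinkData(g, m)
print(s.dst_mask)      # tensor([True, False, True]): destinations reached by the kept sources
print(s.src_mask)      # tensor([True, False, True, True]): sources still touching those destinations
print(s.edge_index)    # tensor([[0, 1, 2], [0, 1, 1]]): surviving edges, locally re-indexed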
......@@ -7,6 +7,7 @@ import os
from .degree import compute_degree, compute_gcn_norm
from .sync_bn import SyncBatchNorm
from ..core import with_gloo, with_nccl
from ..core.route import get_executor
def convert_parallel_model(
net: nn.Module,
......@@ -20,7 +21,7 @@ def convert_parallel_model(
for name, buffer in net.named_buffers():
if name.endswith("last_embd"):
continue
if name.endswith("grad_norm"):
if name.endswith("last_w"):
continue
dist.broadcast(buffer, src=0)
return net
......@@ -37,4 +38,6 @@ def init_process_group(backend: str = "gloo") -> torch.device:
torch.cuda.set_device(device)
else:
device = torch.device("cpu")
get_executor() # initialize route executor
return device
\ No newline at end of file
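A hedged sketch of the intended startup flow combining these helpers (the import path starrygl.parallel is assumed from this commit's layout; buffer broadcasting follows the snippet above):

import torch.nn as nn
from starrygl.parallel import init_process_group, convert_parallel_model  # path assumed

device = init_process_group(backend="nccl")   # also warms up the async copy executor
net = nn.Linear(64, 16).to(device)
net = convert_parallel_model(net)             # broadcasts buffers from rank 0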
......@@ -6,7 +6,7 @@ from typing import *
from torch_scatter import scatter_sum
from ..graph.distgraph import DistGraph
from ..core.gather import Gather
# from ..core.gather import Gather
def compute_degree(g: DistGraph) -> Tuple[Tensor, Tensor]:
......
......@@ -22,7 +22,7 @@ def train_epoch(
if mask is not None:
pred = pred[mask]
targ = targ[mask]
loss: Tensor = criterion(pred, targ)
opt.zero_grad()
......
......@@ -13,4 +13,4 @@ def main_print(*args, **kwargs):
rank = dist.get_rank()
if rank == 0:
print(*args, **kwargs)
dist.barrier()
\ No newline at end of file
# dist.barrier()
\ No newline at end of file