Commit 3c22398e by Wenjie Huang

add sources

parents
### PPGNN
分布式分区并行图神经网络实验代码,目前包括以下特性:
1. 同时支持CPU和GPU训练
2. 模型定义简洁,没有额外优化
3. 支持图数据点分割和边分割
### 目录结构
1. /dataset 里面包含数据处理代码
2. /core 包含分布式索引,分布式图数据结构等相关代码
3. /utils 包含一些常用的处理函数,包括数据集加载等
4. /model 包含模型的定义等
### 说明
这份代码是早期用于实验分布式分区并行的代码,除了模型定义部分,训练推理接口和StarryGL静态图分区并行模式没有区别。StarryGL静态图组件目前仅支持GPU训练,仅支持点分割算法,重构了分布式通信和模型定义过程,暂时需要多块GPU卡进行调试。
from core.distributed.a2a import Work
from core.distributed.a2a import all_to_all, all_to_all_v
from core.distributed.pgraph import PGraph
\ No newline at end of file
import torch
import torch.distributed as dist
from torch import Tensor
from abc import *
from typing import *
class Work(ABC):
    """Handle for a pending collective operation.

    Subclasses wrap whatever the communication backend hands back and must
    implement :meth:`wait`. The class is a real abstract base class, so a
    subclass that forgets to override :meth:`wait` cannot be instantiated.
    """

    def __init__(self) -> None:
        pass

    @abstractmethod
    def wait(self) -> None:
        """Block until the wrapped operation has finished."""
        raise NotImplementedError
class NcclAllToAllWork(Work):
    """``Work`` adapter around an asynchronous handle that exposes ``wait()``."""

    def __init__(self, work: Any) -> None:
        super().__init__()
        # Only the handle's ``wait()`` method is ever used.
        self._work = work

    def wait(self) -> None:
        """Forward to the wrapped handle's ``wait()``."""
        handle = self._work
        handle.wait()
class GlooAllToAllWork(Work):
    """``Work`` for an all-to-all emulated with point-to-point sends/receives.

    Received data lands in staging buffers first; :meth:`wait` copies each
    staging buffer into the matching output tensor once its receive is done.

    Args:
        recv_works: Per-rank receive handles (the own-rank slot is not used).
        send_works: Per-rank send handles (the own-rank slot is not used).
        output_tensor_list: Per-rank tensors that finally hold received data.
        buffer_tensor_list: Per-rank staging tensors the receives write into.
    """

    def __init__(
        self,
        recv_works: List[Any],
        send_works: List[Any],
        output_tensor_list: List[Tensor],
        buffer_tensor_list: List[Tensor],
    ) -> None:
        super().__init__()
        self.recv_works = recv_works
        self.send_works = send_works
        self.output_tensor_list = output_tensor_list
        self.buffer_tensor_list = buffer_tensor_list

    def wait(self) -> None:
        """Complete all receives (copying results out), then all sends."""
        rank = dist.get_rank()
        world_size = dist.get_world_size()

        # Receives are visited at increasing distance "behind" this rank.
        for offset in range(1, world_size):
            peer = (rank - offset + world_size) % world_size
            self.recv_works[peer].wait()
            target = self.output_tensor_list[peer]
            staged = self.buffer_tensor_list[peer]
            # The staging buffer may live on another device than the output.
            target[:] = staged.to(target.device)

        # Sends are visited at increasing distance "ahead" of this rank.
        for offset in range(1, world_size):
            peer = (rank + offset) % world_size
            self.send_works[peer].wait()
def all_to_all(
    output_tensor_list: List[Tensor],
    input_tensor_list: List[Tensor],
    async_op: bool = False,
) -> Optional[Work]:
    """Exchange one tensor with every rank, for the "nccl" and "gloo" backends.

    With "nccl" the call is forwarded to ``dist.all_to_all``. With "gloo" the
    exchange is built from ``dist.isend`` / ``dist.irecv`` on CPU copies, and
    the own-rank slot is copied directly.

    Args:
        output_tensor_list: Pre-allocated tensors, one per rank, filled with
            the data received from that rank.
        input_tensor_list: Tensors to send, one per rank.
        async_op: When true, return a ``Work`` to wait on instead of blocking.

    Returns:
        A ``Work`` handle if ``async_op`` is true, otherwise ``None``.

    Raises:
        ValueError: If the active backend is neither "nccl" nor "gloo".
    """
    assert len(output_tensor_list) == len(input_tensor_list)

    backend = dist.get_backend()
    if backend == "nccl":
        handle = dist.all_to_all(
            output_tensor_list=output_tensor_list,
            input_tensor_list=input_tensor_list,
            async_op=True,
        )
        work = NcclAllToAllWork(handle)
    elif backend == "gloo":
        rank = dist.get_rank()
        world_size = dist.get_world_size()

        # Work on a copy of the list so the caller's list is left untouched.
        input_tensor_list = list(input_tensor_list)
        buffer_tensor_list = [None] * world_size
        for offset in range(1, world_size):
            peer = (rank + offset) % world_size
            input_tensor_list[peer] = input_tensor_list[peer].cpu()
            buffer_tensor_list[peer] = torch.empty_like(
                output_tensor_list[peer], device="cpu")

        # The slot addressed to ourselves never goes through the network.
        output_tensor_list[rank][:] = input_tensor_list[rank]

        send_works = [None] * world_size
        recv_works = [None] * world_size
        for offset in range(1, world_size):
            dst_rank = (rank + offset) % world_size
            src_rank = (rank - offset + world_size) % world_size
            send_works[dst_rank] = dist.isend(input_tensor_list[dst_rank], dst=dst_rank)
            recv_works[src_rank] = dist.irecv(buffer_tensor_list[src_rank], src=src_rank)

        work = GlooAllToAllWork(
            recv_works, send_works, output_tensor_list, buffer_tensor_list)
    else:
        raise ValueError(f"unsupported backend: {backend}")

    if async_op:
        return work
    return work.wait()
def all_to_all_v(
    output_tensor_list: List[Tensor],
    input_tensor_list: List[Tensor],
    async_op: bool = False,
) -> Optional[Work]:
    """All-to-all for 1-D tensors whose lengths are not known to the receiver.

    The element counts are exchanged first with ``dist.all_to_all_single``;
    each entry of ``output_tensor_list`` is then *replaced* by a freshly
    allocated 1-D tensor of the announced length before the payload exchange.

    Args:
        output_tensor_list: List with one slot per rank; its entries are
            overwritten, so placeholders such as ``None`` are acceptable.
        input_tensor_list: 1-D tensors to send, one per rank.
        async_op: Forwarded to :func:`all_to_all` for the payload exchange
            (the size exchange itself is always blocking).

    Returns:
        Whatever :func:`all_to_all` returns for the payload exchange.
    """
    assert len(output_tensor_list) == len(input_tensor_list)

    counts = []
    for payload in input_tensor_list:
        assert payload.dim() == 1
        counts.append(payload.numel())

    send_counts = torch.tensor(
        counts, dtype=torch.long, device=input_tensor_list[0].device)
    recv_counts = torch.empty_like(send_counts)
    dist.all_to_all_single(
        output=recv_counts,
        input=send_counts,
    )

    # Receive buffers mirror dtype and device of the tensor sent to that rank.
    for peer, length in enumerate(recv_counts.tolist()):
        template = input_tensor_list[peer]
        output_tensor_list[peer] = torch.empty(
            length, dtype=template.dtype, device=template.device)

    return all_to_all(
        output_tensor_list=output_tensor_list,
        input_tensor_list=input_tensor_list,
        async_op=async_op,
    )
import torch
import torch.distributed as dist
from torch import Tensor, LongTensor
from torch_sparse import SparseTensor
from core.distributed.a2a import all_to_all_v
from torch.types import _device
from typing import *
class P2PIndex:
    """Lookup tables between global vertex ids and slots of two local buffers.

    Two buffers are indexed:

    * the *node buffer* (``nbuf``): slot ``i`` holds the global vertex id
      ``local_node_table[i]``;
    * the *edge buffer* (``ebuf``): every distinct vertex id occurring in
      ``local_edge_table``. Ids that occur in row 1 of the edge table come
      first (``ebuf_sp_size`` of them), followed by ids that occur only in
      row 0.

    Every ``vid -> value`` map is kept as a one-row ``SparseTensor`` whose
    column is the vertex id and whose stored value is the mapped quantity.

    Args:
        local_node_table: 1-D tensor of global vertex ids (node buffer order).
        local_edge_table: ``2 x E`` tensor of global vertex ids per edge.
        local_edge_route: ``2 x K`` tensor; row 0 holds vertex ids and row 1
            the route value of each id. Route values are compared against
            rank numbers when the exchange indices are built.
        device: Device the tables live on; ``None`` selects the current CUDA
            device.
    """

    def __init__(self,
        local_node_table: LongTensor,
        local_edge_table: LongTensor,
        local_edge_route: LongTensor,
        device: Union[_device, str, None] = None,
    ) -> None:
        self._init_device(device)
        self._init_node_map(local_node_table)
        self._init_edge_map(local_edge_table, local_edge_route)

    def _init_device(self, device: Union[_device, str, None]) -> None:
        """Store the target device, always as a ``torch.device`` object."""
        if device is None:
            # torch.cuda.current_device() is a bare index; wrap it so that
            # ``self.device`` has the same type on both branches.
            self._device = torch.device("cuda", torch.cuda.current_device())
        else:
            self._device = torch.device(device)

    def _init_node_map(self,
        local_node_table: LongTensor,
    ) -> None:
        """Build the node buffer tables (slot -> vid and vid -> slot)."""
        assert local_node_table.dim() == 1
        ind2vid = local_node_table.to(self._device)
        # One-row sparse matrix: column = vertex id, value = slot number.
        # Duplicate ids collapse to their smallest slot ("min").
        vid2ind = SparseTensor.from_edge_index(
            edge_index=torch.vstack([torch.zeros_like(ind2vid), ind2vid]),
            edge_attr=torch.arange(ind2vid.size(0)).type_as(ind2vid)
        ).coalesce("min")
        self._nbuf_ind2vid: LongTensor = ind2vid
        self._nbuf_vid2ind: SparseTensor = vid2ind

    def _init_edge_map(self,
        local_edge_table: LongTensor,
        local_edge_route: LongTensor,
    ) -> None:
        """Build the edge buffer, the local edge index and the route tables."""
        self._init_edge_buf(local_edge_table)
        # Edges re-expressed with edge-buffer slots instead of global ids.
        # The table is moved to the index device first so that the lookup
        # runs against tables on the same device.
        self._local_edge_index: LongTensor = self.ebuf_vid2ind(
            local_edge_table.to(self._device))

        assert local_edge_route.dim() == 2
        assert local_edge_route.size(0) == 2
        vid, rot = local_edge_route.to(self._device)
        vid2rot = SparseTensor.from_edge_index(
            edge_index=torch.vstack([torch.zeros_like(vid), vid]),
            edge_attr=rot,
        ).coalesce("min")

        # After coalescing, the route table must cover exactly the vertices
        # of the edge buffer.
        _, vid, rot = vid2rot.csr()
        assert vid.size(0) == self.ebuf.size(0)

        # Reorder the route values from vertex-id order to slot order.
        ind = self.ebuf_vid2ind(vid)
        ind, perm = ind.sort()
        ind2rot = rot[perm]

        self._ebuf_ind2rot: LongTensor = ind2rot
        self._ebuf_vid2rot: SparseTensor = vid2rot

    def _init_edge_buf(self,
        local_edge_table: LongTensor,
    ) -> None:
        """Build the edge buffer tables (slot -> vid and vid -> slot)."""
        assert local_edge_table.dim() == 2
        assert local_edge_table.size(0) == 2
        ind2vid = local_edge_table.to(self._device)

        # Tag every occurrence: 0 for ids taken from row 0, 1 for row 1.
        ind0 = torch.vstack([torch.zeros_like(ind2vid[0]), ind2vid[0]])
        ind1 = torch.vstack([torch.zeros_like(ind2vid[1]), ind2vid[1]])
        inds = torch.hstack([ind0, ind1])

        val0 = torch.zeros_like(ind2vid[0])
        val1 = torch.ones_like(ind2vid[1])
        vals = torch.hstack([val0, val1])

        # Coalescing with "max" leaves one entry per distinct id whose tag is
        # 1 as soon as the id occurs in row 1 at least once.
        _, col, tag = SparseTensor.from_edge_index(inds, vals).coalesce("max").csr()
        col0 = col[tag == 0]
        col1 = col[tag == 1]

        # Row-1 vertices occupy the leading slots of the buffer.
        ind2vid = torch.hstack([col1, col0])
        vid2ind = SparseTensor.from_edge_index(
            edge_index=torch.vstack([torch.zeros_like(ind2vid), ind2vid]),
            edge_attr=torch.arange(ind2vid.size(0)).type_as(ind2vid)
        ).coalesce("min")

        self._ebuf_sp_size: int = col1.size(0)
        self._ebuf_ind2vid: LongTensor = ind2vid
        self._ebuf_vid2ind: SparseTensor = vid2ind

    @property
    def local_edge_index(self) -> LongTensor:
        """Edge table expressed in edge-buffer slots."""
        return self._local_edge_index

    @property
    def nbuf(self) -> LongTensor:
        """Global vertex id stored in each node-buffer slot."""
        return self._nbuf_ind2vid

    @property
    def ebuf(self) -> LongTensor:
        """Global vertex id stored in each edge-buffer slot."""
        return self._ebuf_ind2vid

    @property
    def ebuf_sp_size(self) -> int:
        """Number of leading edge-buffer slots whose id occurs in row 1."""
        return self._ebuf_sp_size

    # The lookups below accept index tensors of any shape and return a tensor
    # of the same shape. NOTE(review): the ``vid -> *`` variants presumably
    # require every queried id to be present in the table - confirm.

    def nbuf_ind2vid(self, ind: LongTensor) -> LongTensor:
        """Node-buffer slot -> global vertex id."""
        return self._nbuf_ind2vid[ind.view(-1)].view(*ind.size())

    def nbuf_vid2ind(self, vid: LongTensor) -> LongTensor:
        """Global vertex id -> node-buffer slot."""
        return self._nbuf_vid2ind[0, vid.view(-1)].csr()[2].view(*vid.size())

    def ebuf_ind2vid(self, ind: LongTensor) -> LongTensor:
        """Edge-buffer slot -> global vertex id."""
        return self._ebuf_ind2vid[ind.view(-1)].view(*ind.size())

    def ebuf_vid2ind(self, vid: LongTensor) -> LongTensor:
        """Global vertex id -> edge-buffer slot."""
        return self._ebuf_vid2ind[0, vid.view(-1)].csr()[2].view(*vid.size())

    def ebuf_ind2rot(self, ind: LongTensor) -> LongTensor:
        """Edge-buffer slot -> route value."""
        return self._ebuf_ind2rot[ind.view(-1)].view(*ind.size())

    def ebuf_vid2rot(self, vid: LongTensor) -> LongTensor:
        """Global vertex id -> route value."""
        return self._ebuf_vid2rot[0, vid.view(-1)].csr()[2].view(*vid.size())

    def _exchange_index(self, src: bool, dst: bool) -> Tuple[int, List[LongTensor], List[LongTensor]]:
        """Pair up edge-buffer slots with node-buffer slots on other ranks.

        The edge-buffer slots referenced by the selected edge endpoints are
        grouped by route value; the matching vertex ids are sent to the rank
        of that value, and the ids received from each rank are translated to
        local node-buffer slots. This is a collective call.

        Returns:
            ``(size, ebuf_inds, nbuf_inds)`` where ``size`` is the largest
            used edge-buffer slot plus one, ``ebuf_inds[r]`` are the local
            edge-buffer slots routed to rank ``r`` and ``nbuf_inds[r]`` are
            the local node-buffer slots requested by rank ``r``.
        """
        assert src or dst
        world_size = dist.get_world_size()

        used = []
        if src:
            used.append(self._local_edge_index[0])
        if dst:
            used.append(self._local_edge_index[1])
        used = torch.hstack(used).unique()

        used_vid = self.ebuf_ind2vid(used)
        used_rot = self.ebuf_ind2rot(used)

        ebuf_inds = []
        ebuf_vids = []
        for rank in range(world_size):
            mask = (used_rot == rank)
            ebuf_inds.append(used[mask])
            ebuf_vids.append(used_vid[mask])

        # Tell every rank which of its vertices this rank refers to.
        peer_vids = [None] * len(ebuf_vids)
        all_to_all_v(peer_vids, ebuf_vids)

        nbuf_inds = [self.nbuf_vid2ind(peer_vids[rank]) for rank in range(world_size)]
        size = used.max().item() + 1
        return size, ebuf_inds, nbuf_inds

    def gather_index(self, src: bool, dst: bool) -> Tuple[int, List[LongTensor], List[LongTensor]]:
        """Indices for moving data from node buffers into the edge buffer.

        Returns:
            ``(dst_size, send_ind, recv_ind)``: required edge-buffer length,
            node-buffer slots to send to each rank, and edge-buffer slots
            that receive the data coming from each rank.
        """
        dst_size, ebuf_inds, nbuf_inds = self._exchange_index(src, dst)
        return dst_size, nbuf_inds, ebuf_inds

    def scatter_index(self, src: bool, dst: bool) -> Tuple[int, List[LongTensor], List[LongTensor]]:
        """Indices for moving data from the edge buffer back to node buffers.

        Returns:
            ``(src_size, send_ind, recv_ind)``: required edge-buffer length,
            edge-buffer slots to send to each rank, and node-buffer slots
            that receive the data coming from each rank.
        """
        src_size, ebuf_inds, nbuf_inds = self._exchange_index(src, dst)
        return src_size, ebuf_inds, nbuf_inds

    @property
    def device(self) -> _device:
        """Device the tables are stored on."""
        return self._device

    def to(self, device: Union[_device, str]):
        """Move every stored table to ``device`` and return ``self``."""
        target = torch.device(device)
        # Walk instance attributes only. ``dir(self)`` would also list the
        # read-only properties above, which return tensors, and assigning to
        # a property without a setter raises AttributeError.
        for key, val in list(vars(self).items()):
            if isinstance(val, Tensor):
                setattr(self, key, val.to(target))
            elif isinstance(val, SparseTensor):
                setattr(self, key, val.to(target))
        self._device = target
        return self

    def cuda(self):
        """Move the tables to the current CUDA device."""
        device = torch.device("cuda", torch.cuda.current_device())
        return self.to(device)

    def cpu(self):
        """Move the tables to host memory."""
        device = torch.device("cpu")
        return self.to(device)
\ No newline at end of file
import torch
import torch.autograd as autograd
from torch import Tensor, LongTensor
from typing import *
from torch_scatter import scatter
from core.distributed.a2a import all_to_all, all_to_all_v
class ReducerOptions:
    """Plain parameter bundle for one :class:`Reducer` call.

    Attributes:
        src_size: Expected length of the input along dimension 0.
        dst_size: Length of the output along dimension 0.
        send_inds: Per-rank row indices of the input to send.
        recv_inds: Per-rank row indices of the output that receive data.
        reduce_op: Name of the reduction requested by the caller.
    """

    def __init__(
        self,
        src_size: int,
        dst_size: int,
        send_inds: List[LongTensor],
        recv_inds: List[LongTensor],
        reduce_op: str,
    ) -> None:
        self.src_size, self.dst_size = src_size, dst_size
        self.send_inds, self.recv_inds = send_inds, recv_inds
        self.reduce_op = reduce_op

    def backward(self):
        """Return the options of the reverse transfer.

        Sizes and index lists swap roles and the reduction is always "sum".
        """
        mirrored = dict(
            src_size=self.dst_size,
            dst_size=self.src_size,
            send_inds=self.recv_inds,
            recv_inds=self.send_inds,
            reduce_op="sum",
        )
        return ReducerOptions(**mirrored)
class Reducer(autograd.Function):
    """Differentiable exchange-and-reduce of rows between ranks.

    Forward: rows ``src[send_inds[r]]`` are sent to rank ``r``; the rows
    received from all ranks are reduced into an output of ``dst_size`` rows
    at the positions given by ``recv_inds``, using ``opt.reduce_op``.

    Backward: the same operation is run with the reversed options
    (:meth:`ReducerOptions.backward`), i.e. gradients travel the opposite
    way and are summed.
    """

    @staticmethod
    def forward(
        ctx: autograd.function.FunctionCtx,
        src: Tensor,
        opt: ReducerOptions,
    ) -> Tensor:
        assert src.size(0) == opt.src_size

        send_data = [src[ind] for ind in opt.send_inds]
        # Receive buffers share trailing shape, dtype and device with ``src``.
        recv_data = [
            torch.empty(ind.size(0), *src.shape[1:], dtype=src.dtype, device=src.device)
            for ind in opt.recv_inds
        ]
        all_to_all(recv_data, send_data)

        recv_inds = torch.cat(opt.recv_inds, dim=0)
        recv_data = torch.cat(recv_data, dim=0)
        # Honour the requested reduction. Without ``reduce=`` the scatter
        # always sums, which silently turned "max"/"min" requests into sums.
        dst = scatter(
            recv_data, recv_inds, dim=0,
            dim_size=opt.dst_size, reduce=opt.reduce_op,
        )

        ctx.saved_opt = opt.backward()
        return dst

    @staticmethod
    def backward(
        ctx: autograd.function.FunctionCtx,
        grad_output: Tensor,
    ):
        # One gradient per forward input: ``src`` and the (non-tensor) ``opt``.
        return Reducer.apply(grad_output, ctx.saved_opt), None
class DiffReducerOptions:
    """Plain parameter bundle for one :class:`DiffReducer` call.

    Attributes:
        acti_src_memory: Cached values compared against the forward input.
        acti_dst_memory: Cached values added to the received forward data.
        grad_src_memory: Counterpart of ``acti_dst_memory`` for the reverse pass.
        grad_dst_memory: Counterpart of ``acti_src_memory`` for the reverse pass.
        send_inds: Per-rank row indices of the input to send.
        recv_inds: Per-rank row indices of the output that receive data.
        reduce_op: Name of the reduction requested by the caller.
    """

    def __init__(
        self,
        acti_src_memory: Tensor,
        acti_dst_memory: Tensor,
        grad_src_memory: Tensor,
        grad_dst_memory: Tensor,
        send_inds: List[LongTensor],
        recv_inds: List[LongTensor],
        reduce_op: str,
    ) -> None:
        self.acti_src_memory, self.acti_dst_memory = acti_src_memory, acti_dst_memory
        self.grad_src_memory, self.grad_dst_memory = grad_src_memory, grad_dst_memory
        self.send_inds, self.recv_inds = send_inds, recv_inds
        self.reduce_op = reduce_op

    def backward(self):
        """Return the options of the reverse transfer.

        Gradient and activation memories trade places (destination side
        becoming source side), the index lists are swapped and the reduction
        is always "sum".
        """
        mirrored = dict(
            acti_src_memory=self.grad_dst_memory,
            acti_dst_memory=self.grad_src_memory,
            grad_src_memory=self.acti_dst_memory,
            grad_dst_memory=self.acti_src_memory,
            send_inds=self.recv_inds,
            recv_inds=self.send_inds,
            reduce_op="sum",
        )
        return DiffReducerOptions(**mirrored)
class DiffReducer(autograd.Function):
    """Exchange-and-reduce that transmits differences against cached values.

    Forward: the difference between ``src`` and ``opt.acti_src_memory`` is
    cast to float16 and sent; the receiver casts it back and adds the cached
    ``opt.acti_dst_memory`` rows before reducing into the output.

    Backward: the same operation is run with the reversed options
    (:meth:`DiffReducerOptions.backward`).

    NOTE(review): the payload goes through ``all_to_all_v``, which asserts
    1-D inputs, so ``src`` is effectively limited to 1-D tensors. The cached
    memories are not updated here - presumably the caller maintains them;
    confirm.
    """

    @staticmethod
    def forward(
        ctx: autograd.function.FunctionCtx,
        src: Tensor,
        opt: DiffReducerOptions,
    ) -> Tensor:
        assert src.size() == opt.acti_src_memory.size()

        diff_src = src - opt.acti_src_memory
        diff_src = diff_src.type(torch.float16)

        # Send the difference, not ``src`` itself: the receiver adds its
        # cached memory to whatever arrives, so sending raw values would
        # yield ``src + memory`` instead of a reconstruction of ``src``.
        send_data = [diff_src[ind] for ind in opt.send_inds]

        recv_data = [None] * len(opt.recv_inds)
        all_to_all_v(recv_data, send_data)

        recv_inds = torch.cat(opt.recv_inds, dim=0)
        recv_data = torch.cat(recv_data, dim=0)
        recv_data = recv_data.type_as(opt.acti_dst_memory) + opt.acti_dst_memory[recv_inds]
        # Honour the requested reduction instead of the scatter default (sum).
        dst = scatter(
            recv_data, recv_inds, dim=0,
            dim_size=opt.acti_dst_memory.size(0), reduce=opt.reduce_op,
        )

        ctx.saved_opt = opt.backward()
        return dst

    @staticmethod
    def backward(
        ctx: autograd.function.FunctionCtx,
        grad_output: Tensor,
    ):
        # One gradient per forward input: ``src`` and the (non-tensor) ``opt``.
        return DiffReducer.apply(grad_output, ctx.saved_opt), None
\ No newline at end of file
import torch
import torch.distributed as dist
from torch_sparse import SparseTensor
from torch_scatter import scatter
from torch import Tensor, LongTensor
from torch.types import _device
from typing import *
from core.distributed.ind import P2PIndex
import core.distributed.ops as ops
import contextlib
class PGraph:
    """Local part of a partitioned graph with cross-rank gather/scatter.

    Node features live in a per-rank node buffer. ``gather`` brings them to
    the endpoints of the local edges, ``scatter`` reduces per-edge values
    back into the node buffer, and ``softmax`` normalises per-edge scores
    over the edges sharing an endpoint. The exchange indices for a given
    ``(src, dst)`` combination must be prepared once with
    :meth:`enable_gather` / :meth:`enable_scatter` (collective calls) before
    the corresponding operation is used.

    Args:
        local_node_table: 1-D tensor of global vertex ids in the node buffer.
        local_edge_table: ``2 x E`` tensor of global vertex ids per edge
            (row 0 is used as source, row 1 as destination).
        local_edge_route: ``2 x K`` tensor of (vertex id, route) pairs.
        device: Device for the index tensors; ``None`` selects the current
            CUDA device.
    """

    def __init__(self,
        local_node_table: LongTensor,
        local_edge_table: LongTensor,
        local_edge_route: LongTensor,
        device: Union[_device, str, None] = None,
    ) -> None:
        self._p2p_index = P2PIndex(
            local_node_table=local_node_table,
            local_edge_table=local_edge_table,
            local_edge_route=local_edge_route,
            device=device
        )
        self._device = self._p2p_index.device
        self._nbuf_size = self._p2p_index.nbuf.size(0)
        self._ebuf_size = self._p2p_index.ebuf.size(0)
        self._local_edge_index = self._p2p_index.local_edge_index

    # ------------------------------------------------------------------
    # Cached exchange indices are stored as dynamically named attributes,
    # one triple per (src, dst) combination.
    # ------------------------------------------------------------------

    def _gather_index_names(self, src: bool, dst: bool) -> Tuple[str, str, str]:
        """Attribute names of the cached gather indices for ``(src, dst)``."""
        tag = f"{int(src)}{int(dst)}"
        gbuf_size = f"_gather_{tag}_gbuf_size"
        send_inds = f"_gather_{tag}_send_inds"
        recv_inds = f"_gather_{tag}_recv_inds"
        return gbuf_size, send_inds, recv_inds

    def _scatter_index_names(self, src: bool, dst: bool) -> Tuple[str, str, str]:
        """Attribute names of the cached scatter indices for ``(src, dst)``."""
        tag = f"{int(src)}{int(dst)}"
        sbuf_size = f"_scatter_{tag}_sbuf_size"
        send_inds = f"_scatter_{tag}_send_inds"
        recv_inds = f"_scatter_{tag}_recv_inds"
        return sbuf_size, send_inds, recv_inds

    def _gather_index(self, src: bool, dst: bool) -> Tuple[int, List[LongTensor], List[LongTensor]]:
        """Fetch the cached gather indices (AttributeError if not enabled)."""
        gbuf_size_name, send_inds_name, recv_inds_name = self._gather_index_names(src, dst)
        gbuf_size = getattr(self, gbuf_size_name)
        send_inds = getattr(self, send_inds_name)
        recv_inds = getattr(self, recv_inds_name)
        return gbuf_size, send_inds, recv_inds

    def _scatter_index(self, src: bool, dst: bool) -> Tuple[int, List[LongTensor], List[LongTensor]]:
        """Fetch the cached scatter indices (AttributeError if not enabled)."""
        sbuf_size_name, send_inds_name, recv_inds_name = self._scatter_index_names(src, dst)
        sbuf_size = getattr(self, sbuf_size_name)
        send_inds = getattr(self, send_inds_name)
        recv_inds = getattr(self, recv_inds_name)
        return sbuf_size, send_inds, recv_inds

    def _buf_gather(self,
        x: Tensor,
        src: bool,
        dst: bool,
        reduce_op: str,
        cache: bool,
    ) -> Tensor:
        """Move node-buffer rows of ``x`` into the edge buffer.

        ``cache`` is accepted but currently unused.
        """
        gbuf_size, send_inds, recv_inds = self._gather_index(src, dst)
        assert x.size(0) == self._nbuf_size
        assert reduce_op in {"sum", "max", "min"}
        opt = ops.ReducerOptions(
            src_size=x.size(0),
            dst_size=gbuf_size,
            send_inds=send_inds,
            recv_inds=recv_inds,
            reduce_op=reduce_op,
        )
        return ops.Reducer.apply(x, opt)

    def _buf_scatter(self,
        x: Tensor,
        src: bool,
        dst: bool,
        reduce_op: str,
        cache: bool,
    ) -> Tensor:
        """Reduce edge-buffer rows of ``x`` into the node buffer.

        ``cache`` is accepted but currently unused.
        """
        sbuf_size, send_inds, recv_inds = self._scatter_index(src, dst)
        assert x.size(0) == sbuf_size
        assert reduce_op in {"sum", "max", "min"}
        opt = ops.ReducerOptions(
            src_size=x.size(0),
            dst_size=self._nbuf_size,
            send_inds=send_inds,
            recv_inds=recv_inds,
            reduce_op=reduce_op,
        )
        return ops.Reducer.apply(x, opt)

    def enable_gather(self, src: bool = False, dst: bool = False):
        """Prepare the gather indices for ``(src, dst)``; defaults to source.

        Does nothing if the indices already exist. Returns ``self``.
        """
        if not src and not dst:
            src = True
        gbuf_size_name, send_inds_name, recv_inds_name = self._gather_index_names(src, dst)
        if hasattr(self, gbuf_size_name):
            return self
        gbuf_size, send_inds, recv_inds = self._p2p_index.gather_index(src, dst)
        send_inds = [x.to(self._device) for x in send_inds]
        recv_inds = [x.to(self._device) for x in recv_inds]
        setattr(self, gbuf_size_name, gbuf_size)
        setattr(self, send_inds_name, send_inds)
        setattr(self, recv_inds_name, recv_inds)
        return self

    def enable_scatter(self, src: bool = False, dst: bool = False):
        """Prepare the scatter indices for ``(src, dst)``; defaults to destination.

        Does nothing if the indices already exist. Returns ``self``.
        """
        if not src and not dst:
            dst = True
        sbuf_size_name, send_inds_name, recv_inds_name = self._scatter_index_names(src, dst)
        if hasattr(self, sbuf_size_name):
            return self
        sbuf_size, send_inds, recv_inds = self._p2p_index.scatter_index(src, dst)
        send_inds = [x.to(self._device) for x in send_inds]
        recv_inds = [x.to(self._device) for x in recv_inds]
        setattr(self, sbuf_size_name, sbuf_size)
        setattr(self, send_inds_name, send_inds)
        setattr(self, recv_inds_name, recv_inds)
        return self

    def enable_all_gather(self):
        """Prepare gather indices for every (src, dst) combination."""
        self.enable_gather(src=True, dst=True)
        self.enable_gather(src=True, dst=False)
        self.enable_gather(src=False, dst=True)
        return self

    def enable_all_scatter(self):
        """Prepare scatter indices for every (src, dst) combination."""
        self.enable_scatter(src=True, dst=True)
        self.enable_scatter(src=True, dst=False)
        self.enable_scatter(src=False, dst=True)
        return self

    def gather(self,
        x: Tensor,
        src: bool = True,
        dst: bool = False,
        cache: bool = False,
    ) -> Union[Tensor, Tuple[Tensor, Tensor]]:
        """Fetch node values at the endpoints of every local edge.

        Args:
            x: Node-buffer tensor (one row per node-buffer slot).
            src: Return the values at the source endpoints.
            dst: Return the values at the destination endpoints.
            cache: Forwarded to the buffer exchange (currently unused there).

        Returns:
            ``(m_i, m_j)`` - destination values then source values - when
            both flags are set, otherwise the single requested tensor.

        Raises:
            RuntimeError: If neither ``src`` nor ``dst`` is set.
        """
        # Validate before touching the cached indices: there is no index set
        # for "no endpoint", so the lookup would fail with an unrelated
        # AttributeError before this error could be reported.
        if not src and not dst:
            raise RuntimeError("not src and not dst")

        gbuf = self._buf_gather(x, src, dst, reduce_op="sum", cache=cache)
        if src and dst:
            m_j = gbuf[self._local_edge_index[0]]
            m_i = gbuf[self._local_edge_index[1]]
            return m_i, m_j
        if src:
            return gbuf[self._local_edge_index[0]]
        return gbuf[self._local_edge_index[1]]

    def scatter(self,
        x: Tensor,
        src: bool = False,
        dst: bool = True,
        reduce_op: str = "mean",
        cache: bool = False,
        eps: float = 1e-12,
    ) -> Tensor:
        """Reduce per-edge values into the node buffer.

        Values are first reduced locally per edge-buffer slot and then
        exchanged and reduced across ranks.

        Args:
            x: Per-edge tensor.
            src: Reduce onto the source endpoints.
            dst: Reduce onto the destination endpoints.
            reduce_op: "mean", or any reduction accepted by the buffer
                exchange ("sum", "max", "min"). "mean" is computed as a sum
                divided by a summed count.
            cache: Forwarded to the buffer exchange (currently unused there).
            eps: Added to the count in the "mean" case to avoid dividing by
                zero.

        Raises:
            RuntimeError: If neither ``src`` nor ``dst`` is set.
        """
        if src and dst:
            ind = self._local_edge_index.view(-1)
        elif src and not dst:
            ind = self._local_edge_index[0]
        elif not src and dst:
            ind = self._local_edge_index[1]
        else:
            raise RuntimeError("not src and not dst")

        # ``scatter`` below is the module-level torch_scatter function, not
        # this method.
        if reduce_op == "mean":
            w = torch.ones(x.size(0), 1, dtype=x.dtype, device=x.device)
            w = scatter(w, ind, dim=0, reduce="sum")
            w = self._buf_scatter(w, src, dst, reduce_op="sum", cache=cache)

            x = scatter(x, ind, dim=0, reduce="sum")
            x = self._buf_scatter(x, src, dst, reduce_op="sum", cache=cache)
            return x / (w + eps)

        x = scatter(x, ind, dim=0, reduce=reduce_op)
        return self._buf_scatter(x, src, dst, reduce_op=reduce_op, cache=cache)

    def softmax(self,
        a: Tensor,
        src: bool = False,
        dst: bool = True,
        eps: float = 1e-12,
    ):
        """Softmax of per-edge scores over the edges sharing an endpoint.

        The per-node maximum is subtracted before exponentiation; it is
        computed without gradient tracking.
        """
        assert src or dst
        with torch.no_grad():
            m = self.scatter(a, src, dst, reduce_op="max")
            if src and dst:
                m_i, m_j = self.gather(m, src, dst)
                m = m_i + m_j
            else:
                m = self.gather(m, src, dst)

        a = (a - m).exp()
        s = self.scatter(a, src, dst, reduce_op="sum")
        if src and dst:
            s_i, s_j = self.gather(s, src, dst)
            s = s_i + s_j
        else:
            s = self.gather(s, src, dst)
        return a / (s + eps)

    @property
    def device(self) -> _device:
        """Device the index tensors are stored on."""
        return self._device

    def to(self, device: Union[_device, str]):
        """Move tensors held directly by this object to ``device``.

        Tensors inside list/tuple attributes (the cached exchange indices)
        are moved as well. NOTE(review): the wrapped ``P2PIndex`` is left on
        its original device - confirm whether that is intended.
        """
        self._device = torch.device(device)
        # Instance attributes only: walking ``dir(self)`` would also evaluate
        # every property and method of the class for no benefit.
        for key, val in list(vars(self).items()):
            if isinstance(val, (Tensor, SparseTensor)):
                setattr(self, key, val.to(self._device))
            elif isinstance(val, (list, tuple)):
                moved = [
                    item.to(self._device)
                    if isinstance(item, (Tensor, SparseTensor)) else item
                    for item in val
                ]
                if isinstance(val, tuple):
                    moved = tuple(moved)
                setattr(self, key, moved)
        return self

    def cuda(self):
        """Move to the current CUDA device."""
        device = torch.cuda.current_device()
        return self.to(device)

    def cpu(self):
        """Move to host memory."""
        device = torch.device("cpu")
        return self.to(device)
\ No newline at end of file
from torch_geometric.datasets import Planetoid
from torch_geometric.utils import add_remaining_self_loops, to_undirected
import os.path as osp
import sys
sys.path.extend([
osp.dirname(osp.dirname(osp.abspath(__file__))),
])
from utils.partition import partition_save
if __name__ == "__main__":
    # Load Cora, make it undirected with self loops, then write one set of
    # partitions per (partition count, algorithm) pair next to this script.
    data = Planetoid("/mnt/nfs/hwj/pyg_datasets/Planetoid/Cora", "Cora")[0]
    if data.is_directed():
        # Without edge attributes ``to_undirected`` returns the edge index
        # alone; unpacking it as a pair would split the 2 x E tensor into
        # its two rows.
        data.edge_index = to_undirected(data.edge_index)
    data.edge_index, _ = add_remaining_self_loops(data.edge_index)

    print(f"num_nodes: {data.num_nodes}")
    print(f"num_edges: {data.num_edges}")
    print(f"num_features: {data.num_features}")

    num_parts_list = [1, 2, 3, 5, 7, 9, 11]
    algos = ["metis", "dbh"]

    # Output root: this file's absolute path without its extension.
    root = osp.splitext(osp.abspath(__file__))[0]
    print(f"root: {root}")

    for num_parts in num_parts_list:
        for algo in algos:
            print(f"======== {num_parts} + {algo} ========")
            partition_save(root, data, num_parts, algo)
# for num_parts in num_parts_list:
# for algo in algos:
# print(f"======== {num_parts} + {algo} ========")
# stat = PartitionStatistic(num_parts, algo)(base_dir)
# node_balance = min(stat["node_count"]) / max(stat["node_count"])
# edge_balance = min(stat["edge_count"]) / max(stat["edge_count"])
# local_node_percent = max(stat["local_node_percent"])
# global_node_percent = max(stat["global_node_percent"])
# print(f"node balance: {node_balance:.2f}")
# print(f"edge balance: {edge_balance:.2f}")
# print(f"local node percent: {local_node_percent:.2f}")
# print(f"global node percent: {global_node_percent:.2f}")
\ No newline at end of file
from model.basic_gnn import BasicGNN
from model.convs import GCNConv, GATConv
class GCN(BasicGNN):
    """``BasicGNN`` whose message passing layers are ``GCNConv``."""

    def init_conv(self, in_channels: int, out_channels: int, **kwargs):
        # Extra keyword arguments are accepted but not forwarded to the layer.
        layer = GCNConv(in_channels, out_channels)
        return layer
class GAT(BasicGNN):
    """``BasicGNN`` whose message passing layers are ``GATConv``."""

    def init_conv(self, in_channels: int, out_channels: int, **kwargs):
        # Extra keyword arguments go straight to the layer constructor.
        layer = GATConv(in_channels, out_channels, **kwargs)
        return layer
\ No newline at end of file
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from typing import *
import copy
import inspect
from core.distributed import PGraph
from torch_geometric.nn import JumpingKnowledge
from torch_geometric.nn.resolver import activation_resolver, normalization_resolver
class BasicGNN(nn.Module):
    r"""An abstract class for implementing basic GNN models on a ``PGraph``.

    Subclasses provide the message passing layer through :meth:`init_conv`.

    Args:
        in_channels (int): Size of each input sample.
        out_channels (int): Size of each output sample.
        num_layers (int): Number of message passing layers.
        dropout (float, optional): Dropout probability. (default: :obj:`0.`)
        dropout_input (bool, optional): If set to :obj:`True`, dropout is
            also applied to the input features. (default: :obj:`True`)
        transform_input (bool, optional): If set to :obj:`True`, a linear
            layer maps the input to :obj:`hidden_channels` before the first
            message passing layer. (default: :obj:`True`)
        hidden_channels (int, optional): Size of each hidden sample. If set
            to :obj:`None`, :obj:`out_channels` is used.
            (default: :obj:`None`)
        act (str or Callable, optional): The non-linear activation function to
            use. (default: :obj:`"relu"`)
        act_kwargs (Dict[str, Any], optional): Arguments passed to the
            respective activation function defined by :obj:`act`.
            (default: :obj:`None`)
        act_first (bool, optional): If set to :obj:`True`, activation is
            applied before normalization. (default: :obj:`False`)
        norm (str or Callable, optional): The normalization function to
            use. (default: :obj:`None`)
        norm_kwargs (Dict[str, Any], optional): Arguments passed to the
            respective normalization function defined by :obj:`norm`.
            (default: :obj:`None`)
        jk (str, optional): The Jumping Knowledge mode. If specified, every
            layer produces :obj:`hidden_channels` features and a final
            linear transformation maps to :obj:`out_channels`.
            (:obj:`None`, :obj:`"last"`, :obj:`"cat"`, :obj:`"max"`,
            :obj:`"lstm"`). (default: :obj:`None`)
        **kwargs (optional): Additional arguments passed to
            :meth:`init_conv` for every layer.
    """
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        num_layers: int,
        dropout: float = 0.0,
        dropout_input: bool = True,
        transform_input: bool = True,
        hidden_channels: Optional[int] = None,
        act: Optional[str] = "relu",
        act_kwargs: Optional[Dict[str, Any]] = None,
        act_first: bool = False,
        norm: Optional[str] = None,
        norm_kwargs: Optional[Dict[str, Any]] = None,
        jk: Optional[str] = None,
        **kwargs,
    ):
        super().__init__()

        if hidden_channels is None:
            hidden_channels = out_channels

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.hidden_channels = hidden_channels
        self.num_layers = num_layers

        self.dropout = dropout
        self.dropout_input = dropout_input
        self.act_first = act_first
        self.jk_mode = jk

        # Optional sub-modules (lin0, act, jk, lin_jk, norms) are only
        # created when configured; the rest of the class tests for them with
        # ``hasattr``.
        if transform_input:
            self.lin0 = nn.Linear(in_channels, hidden_channels)
            in_channels = hidden_channels

        if act is not None:
            self.act = activation_resolver(act, **(act_kwargs or {}))

        self.convs = nn.ModuleList()
        if num_layers > 1:
            self.convs.append(
                self.init_conv(in_channels, hidden_channels, **kwargs))
            in_channels = hidden_channels
        for _ in range(num_layers - 2):
            self.convs.append(
                self.init_conv(in_channels, hidden_channels, **kwargs))
            in_channels = hidden_channels

        if jk is None:
            # Without jumping knowledge the last layer maps to the output size.
            self.convs.append(
                self.init_conv(in_channels, out_channels, **kwargs))
        else:
            self.convs.append(
                self.init_conv(in_channels, hidden_channels, **kwargs))

            if jk != "last":
                self.jk = JumpingKnowledge(jk, hidden_channels, num_layers)

            if jk == "cat":
                jk_channels = num_layers * hidden_channels
            else:
                jk_channels = hidden_channels
            self.lin_jk = nn.Linear(jk_channels, out_channels)

        if norm is not None:
            self.norms = nn.ModuleList()
            norm_layer = normalization_resolver(
                norm,
                hidden_channels,
                **(norm_kwargs or {}),
            )
            for _ in range(num_layers - 1):
                self.norms.append(copy.deepcopy(norm_layer))
            # With jumping knowledge the last layer is normalised as well.
            if jk is not None:
                self.norms.append(copy.deepcopy(norm_layer))

        self.reset_parameters()

    def init_conv(self,
        in_channels: int,
        out_channels: int,
        **kwargs
    ) -> nn.Module:
        """Create one message passing layer; must be overridden."""
        raise NotImplementedError

    def reset_parameters(self):
        """Re-initialise every sub-module that exists on this instance."""
        if hasattr(self, "lin0"):
            self.lin0.reset_parameters()
        if hasattr(self, "jk"):
            self.jk.reset_parameters()
        if hasattr(self, "lin_jk"):
            self.lin_jk.reset_parameters()

        for conv in self.convs:
            conv.reset_parameters()
        if hasattr(self, "norms"):
            for norm in self.norms:
                norm.reset_parameters()

    @staticmethod
    def _accepts_edge_attr(conv: nn.Module) -> bool:
        """Tell whether ``conv.forward`` declares an ``edge_attr`` parameter.

        Both positional and keyword-only declarations count; checking the
        positional ones alone misses layers that declare ``edge_attr`` after
        a bare ``*``.
        """
        spec = inspect.getfullargspec(conv.forward)
        return "edge_attr" in spec.args or "edge_attr" in spec.kwonlyargs

    def forward(
        self,
        g: PGraph,
        x: Tensor,
        *,
        edge_attr: Optional[Tensor] = None,
    ) -> Tensor:
        """Run the layer stack.

        Args:
            g: Partitioned graph the layers operate on.
            x: Input features.
            edge_attr: Optional per-edge values, handed to every layer whose
                ``forward`` declares an ``edge_attr`` parameter.
        """
        if self.dropout_input:
            x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.lin0(x) if hasattr(self, "lin0") else x

        xs: List[Tensor] = []
        for i in range(self.num_layers):
            conv = self.convs[i]
            if self._accepts_edge_attr(conv):
                x = conv(g, x, edge_attr=edge_attr)
            else:
                x = conv(g, x)
            x = F.dropout(x, p=self.dropout, training=self.training)

            # Without jumping knowledge the last layer already is the output:
            # no activation or normalisation follows it.
            if i == self.num_layers - 1 and self.jk_mode is None:
                break

            if hasattr(self, "act") and self.act_first:
                x = self.act(x)
            if hasattr(self, "norms"):
                x = self.norms[i](x)
            if hasattr(self, "act") and not self.act_first:
                x = self.act(x)

            if hasattr(self, 'jk'):
                xs.append(x)

        x = self.jk(xs) if hasattr(self, 'jk') else x
        x = self.lin_jk(x) if hasattr(self, 'lin_jk') else x
        return x

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}({self.in_channels}, '
                f'{self.out_channels}, num_layers={self.num_layers})')
\ No newline at end of file
from .gcn_conv import GCNConv
from .gat_conv import GATConv
\ No newline at end of file
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from typing import *
from core.distributed import PGraph
from torch_geometric.nn import Linear
import torch_geometric.nn.inits as pyg_inits
class GATConv(nn.Module):
    """Graph attention layer operating on a ``PGraph``.

    Args:
        in_channels: Size of each input sample.
        out_channels: Size of each output sample per head.
        heads: Number of attention heads.
        concat: If true the head outputs are concatenated, otherwise averaged.
        negative_slope: Slope of the leaky ReLU applied to attention scores.
        dropout: Dropout probability; only used by :meth:`_scatter`.
        edge_dim: Size of the edge features, or ``None`` to ignore them.
        bias: Whether to add a learnable bias to the output.
    """

    def __init__(self,
        in_channels: int,
        out_channels: int,
        heads: int = 1,
        concat: bool = True,
        negative_slope: float = 0.2,
        dropout: float = 0.0,
        edge_dim: Optional[int] = None,
        bias: bool = True,
    ):
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.heads = heads
        self.concat = concat
        self.negative_slope = negative_slope
        self.dropout = dropout
        self.edge_dim = edge_dim

        self.lin_node = Linear(in_channels, heads * out_channels, bias=False, weight_initializer="glorot")

        self.att_src = nn.Parameter(torch.Tensor(1, heads, out_channels))
        self.att_dst = nn.Parameter(torch.Tensor(1, heads, out_channels))

        if edge_dim is not None:
            # Must be the same ``Linear`` class as ``lin_node``:
            # ``torch.nn.Linear`` has no ``weight_initializer`` argument and
            # raises TypeError when given one.
            self.lin_edge = Linear(edge_dim, heads * out_channels, bias=False, weight_initializer="glorot")
            self.att_edge = nn.Parameter(torch.Tensor(1, heads, out_channels))
        else:
            self.register_parameter("lin_edge", None)
            self.register_parameter("att_edge", None)

        if bias and concat:
            self.bias = nn.Parameter(torch.Tensor(heads * out_channels))
        elif bias and not concat:
            self.bias = nn.Parameter(torch.Tensor(out_channels))
        else:
            self.register_parameter("bias", None)

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise the linear maps, attention vectors and bias."""
        self.lin_node.reset_parameters()
        if self.lin_edge is not None:
            self.lin_edge.reset_parameters()
        pyg_inits.glorot(self.att_src)
        pyg_inits.glorot(self.att_dst)
        pyg_inits.glorot(self.att_edge)
        pyg_inits.zeros(self.bias)

    def forward(self,
        g: PGraph,
        x: Tensor,
        *,
        edge_attr: Optional[Tensor] = None,
    ):
        """Compute the attention-weighted aggregation over incoming edges.

        Args:
            g: Partitioned graph.
            x: Node features (node buffer layout).
            edge_attr: Optional per-edge features; only used when the layer
                was built with ``edge_dim``.
        """
        # With both flags set, gather returns destination values first.
        x_dst, x_src = g.gather(x, src=True, dst=True)

        H, C = self.heads, self.out_channels
        x_src: Tensor = self.lin_node(x_src).view(-1, H, C)
        x_dst: Tensor = self.lin_node(x_dst).view(-1, H, C)

        alpha = self._compute_alpha(g, x_src, x_dst, edge_attr)
        w = g.softmax(alpha, src=False, dst=True)
        out = g.scatter(x_src * w, src=False, dst=True, reduce_op="sum")

        if self.concat:
            out = out.view(-1, self.heads * self.out_channels)
        else:
            out = out.mean(dim=1)

        if self.bias is not None:
            # Out-of-place add: ``out`` may be a view of the scatter result.
            out = out + self.bias
        return out

    def _compute_alpha(self,
        g: PGraph,
        x_src: Tensor,
        x_dst: Tensor,
        edge_attr: Optional[Tensor],
    ) -> Tensor:
        """Unnormalised per-edge, per-head attention scores."""
        H, C = self.heads, self.out_channels
        alpha_src = (x_src * self.att_src).sum(dim=-1, keepdim=True)
        alpha_dst = (x_dst * self.att_dst).sum(dim=-1, keepdim=True)
        alpha = alpha_src + alpha_dst

        if edge_attr is not None and self.lin_edge is not None:
            if edge_attr.dim() == 1:
                edge_attr = edge_attr.view(-1, 1)
            edge_attr = self.lin_edge(edge_attr).view(-1, H, C)
            alpha += (edge_attr * self.att_edge).sum(dim=-1, keepdim=True)

        alpha = F.leaky_relu(alpha, self.negative_slope)
        return alpha

    def _scatter(self,
        g: PGraph,
        x: Tensor,
        alpha: Tensor,
    ):
        """Aggregate ``x`` weighted by dropped-out ``alpha``, rescaled by the
        reciprocal of ``g.scatter(alpha)``. Not called by :meth:`forward`."""
        s = 1.0 / g.scatter(alpha)
        a = F.dropout(alpha, p=self.dropout, training=self.training)
        return g.scatter(x * a) * s

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}({self.in_channels}, '
                f'{self.out_channels}, heads={self.heads})')
import torch
import torch.nn as nn
from torch import Tensor
from typing import *
from core.distributed import PGraph
from torch_geometric.nn import Linear
import torch_geometric.nn.inits as pyg_inits
class GCNConv(nn.Module):
    """Graph convolution on a ``PGraph``: linear map, weighted sum over
    incoming edges, bias.

    Args:
        in_channels: Size of each input sample.
        out_channels: Size of each output sample.
    """

    def __init__(self,
        in_channels: int,
        out_channels: int,
    ) -> None:
        super().__init__()
        self.linear = Linear(in_channels, out_channels, bias=False, weight_initializer="glorot")
        self.bias = nn.Parameter(torch.empty(out_channels))
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise the linear map and zero the bias."""
        self.linear.reset_parameters()
        pyg_inits.zeros(self.bias)

    def forward(self,
        g: PGraph,
        x: Tensor,
        edge_attr: Optional[Tensor] = None,
    ) -> Tensor:
        """Apply the layer.

        Args:
            g: Partitioned graph.
            x: Node features (node buffer layout).
            edge_attr: Optional per-edge weights (one scalar per edge). When
                omitted or ``None`` the messages are summed unweighted.
        """
        x = self.linear(x)
        x = g.gather(x, src=True, dst=False)
        # The weights are optional so the layer also works when the caller
        # has no edge weights and passes ``None``.
        if edge_attr is not None:
            x = x * edge_attr.view(-1, 1)
        x = g.scatter(x, src=False, dst=True, reduce_op="sum")
        return x + self.bias
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
import torch.multiprocessing as mp
from torch import Tensor
import os
import os.path as osp
from typing import *
from core.distributed import PGraph
from model import GAT
from utils.functional import train_epoch, eval_epoch
from utils.partition import partition_load
import argparse
import time
import psutil
# Command-line options of the training script, registered in display order.
parser = argparse.ArgumentParser()
_ARGUMENT_SPECS = (
    # Partitioned dataset location and partitioning algorithm (mandatory).
    ("--data_path", dict(type=str, required=True)),
    ("--part_algo", dict(type=str, required=True)),
    # Model size and training length.
    ("--hidden_channels", dict(type=int, default=64)),
    ("--num_layers", dict(type=int, default=2)),
    ("--epochs", dict(type=int, default=10)),
    ("--dropout", dict(type=float, default=0.0)),
    # Takes a value; the script only tests it for truthiness.
    ("--gpu", dict(action="store")),
    # Optional model variants.
    ("--jk", dict(type=str, default=None)),
    ("--norm", dict(type=str, default=None)),
    ("--heads", dict(type=int, default=1)),
)
for _flag, _options in _ARGUMENT_SPECS:
    parser.add_argument(_flag, **_options)
del _flag, _options
def main():
    """Train and evaluate a GAT model on this rank's graph partition.

    The default process group must already be initialised, because the
    function queries the rank and the data loader / model rely on
    collectives. Per-epoch statistics are printed by every rank; a summary
    of the average memory and duration is printed at the end when at least
    two epochs were run.
    """
    args = parser.parse_args()
    rank = dist.get_rank()

    if args.gpu and torch.cuda.is_available():
        device = torch.cuda.current_device()
    else:
        device = torch.device("cpu")

    data = partition_load(
        root=args.data_path,
        algo=args.part_algo,
    )
    g = PGraph(
        local_node_table=data.local_node_table,
        local_edge_table=data.local_edge_table,
        local_edge_route=data.local_edge_route,
        device=device,
    ).enable_all_gather().enable_all_scatter()

    num_features = data.num_features
    # NOTE(review): derived from the labels stored in this rank's partition
    # only; assumes every partition contains the largest label — confirm.
    num_classes = data.y.max().item() + 1
    print(f"{rank}: data loaded")

    net = GAT(
        in_channels=num_features,
        hidden_channels=args.hidden_channels,
        out_channels=num_classes,
        num_layers=args.num_layers,
        dropout=args.dropout,
        norm=args.norm,
        jk=args.jk,
        heads=args.heads,
        concat=False,
    ).to(device)
    # On the CUDA path `device` is an int ordinal, which never compares
    # equal to torch.device("cpu"), so this branch is taken exactly then.
    if device != torch.device("cpu"):
        net = nn.SyncBatchNorm.convert_sync_batchnorm(net)
    net = nn.parallel.DistributedDataParallel(net)
    opt = torch.optim.Adam(net.parameters(), lr=0.01, weight_decay=5e-4)
    print(f"{rank}: initialize model and optimizer")

    x = data.x.to(device)
    y = data.y.to(device)
    train_mask = data.train_mask.to(device)
    val_mask = data.val_mask.to(device)
    test_mask = data.test_mask.to(device)

    avg_mem = 0.0
    avg_dur = 0.0
    avg_num = 0
    best_val_acc = best_test_acc = 0
    for ep in range(1, args.epochs+1):
        time_start = time.time()
        train_loss, train_acc = train_epoch(net, opt, g, x, y, mask=train_mask)
        val_loss, val_acc = eval_epoch(net, g, x, y, mask=val_mask)
        test_loss, test_acc = eval_epoch(net, g, x, y, mask=test_mask)
        # Report the test accuracy at the epoch with the best validation score.
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            best_test_acc = test_acc
        duration = time.time() - time_start
        cur_mem = psutil.Process(os.getpid()).memory_info().rss
        cur_mem_mb = round(cur_mem / 1024**2)
        # The first epoch is left out of the averages (presumably warm-up).
        if ep > 1:
            avg_mem += cur_mem
            avg_dur += duration
            avg_num += 1
        print(
            f"ep: {ep}, mem: {cur_mem_mb}MiB, duration: {duration:.2f}s, "
            f"loss: [{train_loss:.4f}/{val_loss:.4f}/{test_loss:.6f}], "
            f"accuracy: [{train_acc:.4f}/{val_acc:.4f}/{test_acc:.4f}], "
            f"best_accuracy: {best_test_acc:.4f}")

    # With fewer than two epochs nothing was accumulated; skip the summary
    # instead of dividing by zero.
    if avg_num > 0:
        avg_mem = round(avg_mem / avg_num / 1024**2)
        avg_dur = avg_dur / avg_num
        print(f"average memory: {avg_mem}MiB, average duration: {avg_dur:.2f}s")
if __name__ == "__main__":
    # Script entry point: configure threads/devices, join the process
    # group, then run training.
    torch.set_num_threads(24)
    if torch.cuda.is_available():
        # LOCAL_RANK is presumably exported by the launcher; defaults to 0.
        local_rank = int(os.getenv("LOCAL_RANK", "0"))
        torch.cuda.set_device(local_rank)
    # The communication backend can be overridden through C10D_BACKEND.
    backend = os.getenv("C10D_BACKEND", "gloo")
    dist.init_process_group(backend)
    main()
\ No newline at end of file
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
import torch.multiprocessing as mp
from torch import Tensor
import os
import os.path as osp
from typing import *
from core.distributed import PGraph
from model import GCN
from utils.functional import train_epoch, eval_epoch
from utils.partition import partition_load
import argparse
import time
import psutil
def _build_parser() -> argparse.ArgumentParser:
    """Create the command-line parser of the GCN training script."""
    # (flag, add_argument keyword options), registered in this order.
    specs = [
        ("--data_path", dict(type=str, required=True)),
        ("--part_algo", dict(type=str, required=True)),
        ("--hidden_channels", dict(type=int, default=64)),
        ("--num_layers", dict(type=int, default=2)),
        ("--epochs", dict(type=int, default=10)),
        ("--dropout", dict(type=float, default=0.0)),
        # "store" means --gpu expects a value; any non-empty value is truthy.
        ("--gpu", dict(action="store")),
        ("--jk", dict(type=str, default=None)),
        ("--norm", dict(type=str, default=None)),
    ]
    built = argparse.ArgumentParser()
    for flag, options in specs:
        built.add_argument(flag, **options)
    return built


parser = _build_parser()
def main(local_rank: int):
    """Train and evaluate a GCN model on this rank's graph partition.

    The default process group must already be initialised, because the
    function queries the rank and the data loader / model rely on
    collectives.

    Args:
        local_rank: Rank of this process on its node; only the process with
            ``local_rank == 0`` prints the per-epoch statistics.
    """
    args = parser.parse_args()
    rank = dist.get_rank()

    if args.gpu and torch.cuda.is_available():
        device = torch.cuda.current_device()
    else:
        device = torch.device("cpu")

    data = partition_load(
        root=args.data_path,
        algo=args.part_algo,
    )
    g = PGraph(
        local_node_table=data.local_node_table,
        local_edge_table=data.local_edge_table,
        local_edge_route=data.local_edge_route,
        device=device,
    ).enable_all_gather().enable_all_scatter()

    num_features = data.num_features
    # NOTE(review): derived from the labels stored in this rank's partition
    # only; assumes every partition contains the largest label — confirm.
    num_classes = data.y.max().item() + 1
    print(f"{rank}: data loaded")

    net = GCN(
        in_channels=num_features,
        hidden_channels=args.hidden_channels,
        out_channels=num_classes,
        num_layers=args.num_layers,
        dropout=args.dropout,
        norm=args.norm,
        jk=args.jk,
    ).to(device)
    # On the CUDA path `device` is an int ordinal, which never compares
    # equal to torch.device("cpu"), so this branch is taken exactly then.
    if device != torch.device("cpu"):
        net = nn.SyncBatchNorm.convert_sync_batchnorm(net)
    net = nn.parallel.DistributedDataParallel(net)
    opt = torch.optim.Adam(net.parameters(), lr=0.01, weight_decay=5e-4)
    print(f"{rank}: initialize model and optimizer")

    x = data.x.to(device)
    y = data.y.to(device)
    # Precomputed per-edge normalisation weights stored with the partition.
    w = data.gcn_norm.to(device)
    train_mask = data.train_mask.to(device)
    val_mask = data.val_mask.to(device)
    test_mask = data.test_mask.to(device)

    avg_mem = 0.0
    avg_dur = 0.0
    avg_num = 0
    best_val_acc = best_test_acc = 0
    for ep in range(1, args.epochs+1):
        time_start = time.time()
        train_loss, train_acc = train_epoch(net, opt, g, x, y, w, train_mask)
        val_loss, val_acc = eval_epoch(net, g, x, y, w, val_mask)
        test_loss, test_acc = eval_epoch(net, g, x, y, w, test_mask)
        # Report the test accuracy at the epoch with the best validation score.
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            best_test_acc = test_acc
        duration = time.time() - time_start
        cur_mem = psutil.Process(os.getpid()).memory_info().rss
        cur_mem_mb = round(cur_mem / 1024**2)
        # The first epoch is left out of the averages (presumably warm-up).
        if ep > 1:
            avg_mem += cur_mem
            avg_dur += duration
            avg_num += 1
        if local_rank == 0:
            print(
                f"ep: {ep}, mem: {cur_mem_mb}MiB, duration: {duration:.2f}s, "
                f"loss: [{train_loss:.4f}/{val_loss:.4f}/{test_loss:.6f}], "
                f"accuracy: [{train_acc:.4f}/{val_acc:.4f}/{test_acc:.4f}], "
                f"best_accuracy: {best_test_acc:.4f}")

    # With fewer than two epochs nothing was accumulated; skip the summary
    # instead of dividing by zero.
    if avg_num > 0:
        avg_mem = round(avg_mem / avg_num / 1024**2)
        avg_dur = avg_dur / avg_num
        print(f"average memory: {avg_mem}MiB, average duration: {avg_dur:.2f}s")
if __name__ == "__main__":
    # Script entry point: configure threads/devices, join the process
    # group, then run training.
    torch.set_num_threads(24)
    # LOCAL_RANK is presumably exported by the launcher; defaults to 0.
    local_rank = int(os.getenv("LOCAL_RANK", "0"))
    if torch.cuda.is_available():
        torch.cuda.set_device(local_rank)
    # The communication backend can be overridden through C10D_BACKEND.
    backend = os.getenv("C10D_BACKEND", "gloo")
    dist.init_process_group(backend)
    main(local_rank)
\ No newline at end of file
import torch
import torch.nn as nn
from torch import Tensor
from core.distributed import PGraph
from typing import *
from .metrics import *
def train_epoch(
    model: nn.Module,
    opt: torch.optim.Optimizer,
    g: PGraph,
    x: Tensor,
    y: Tensor,
    w: Optional[Tensor] = None,
    mask: Optional[Tensor] = None,
) -> Tuple[float, float]:
    """Run one optimisation step over the local partition.

    Args:
        model: Network called as ``model(g, x)`` or, when ``w`` is given,
            ``model(g, x, edge_attr=w)``.
        opt: Optimizer updating ``model``'s parameters.
        g: Partitioned graph passed through to the model.
        x: Input features.
        y: Target class labels.
        w: Optional edge weights forwarded as ``edge_attr``.
        mask: Optional index/boolean mask selecting the rows used for the
            loss and the accuracy.

    Returns:
        ``(loss, accuracy)`` as reduced across all ranks by
        ``all_reduce_loss`` and ``accuracy``.
    """
    model.train()
    criterion = nn.CrossEntropyLoss()

    if w is None:
        pred: Tensor = model(g, x)
    else:
        pred = model(g, x, edge_attr=w)
    targ: Tensor = y
    if mask is not None:
        pred = pred[mask]
        targ = targ[mask]

    loss: Tensor = criterion(pred, targ)
    opt.zero_grad()
    loss.backward()
    opt.step()

    # The metrics are for reporting only, so no graph is recorded for them.
    with torch.no_grad():
        train_loss = all_reduce_loss(loss, targ.size(0))
        train_acc = accuracy(pred.argmax(dim=-1), targ)
    return train_loss, train_acc
@torch.no_grad()
def eval_epoch(
    model: nn.Module,
    g: PGraph,
    x: Tensor,
    y: Tensor,
    w: Optional[Tensor] = None,
    mask: Optional[Tensor] = None,
) -> Tuple[float, float]:
    """Evaluate the model on the local partition without gradients.

    Args:
        model: Network called as ``model(g, x)`` or, when ``w`` is given,
            ``model(g, x, edge_attr=w)``.
        g: Partitioned graph passed through to the model.
        x: Input features.
        y: Target class labels.
        w: Optional edge weights forwarded as ``edge_attr``.
        mask: Optional index/boolean mask selecting the evaluated rows.

    Returns:
        ``(loss, accuracy)`` as reduced across all ranks by
        ``all_reduce_loss`` and ``accuracy``.
    """
    model.eval()

    # Only pass edge_attr when edge weights were supplied.
    extra = {} if w is None else {"edge_attr": w}
    logits: Tensor = model(g, x, **extra)
    labels: Tensor = y
    if mask is not None:
        logits, labels = logits[mask], labels[mask]

    loss = nn.CrossEntropyLoss()(logits, labels)
    eval_loss = all_reduce_loss(loss, labels.size(0))
    eval_acc = accuracy(logits.argmax(dim=-1), labels)
    return eval_loss, eval_acc
\ No newline at end of file
import torch
import torch.distributed as dist
from torch import Tensor, LongTensor
from typing import *
def _local_TP_FP_FN(pred: LongTensor, targ: LongTensor, num_classes: int) -> Tensor:
    """Count per-class true positives, false positives and false negatives.

    Args:
        pred: Predicted class index per sample.
        targ: Target class index per sample.
        num_classes: Number of classes ``c`` in ``range(num_classes)`` to count.

    Returns:
        A float32 tensor of shape ``(3, num_classes)`` on ``pred``'s device;
        row 0 holds TP, row 1 FP and row 2 FN counts.
    """
    TP, FP, FN = 0, 1, 2
    counts = torch.empty(3, num_classes, dtype=torch.float32, device=pred.device)
    for c in range(num_classes):
        pred_c = pred == c
        targ_c = targ == c
        # Element-wise `&` / `~` are required here: Python's `and` / `not`
        # ask for the truth value of a whole tensor and raise for tensors
        # with more than one element.
        counts[TP, c] = torch.count_nonzero(pred_c & targ_c)
        counts[FP, c] = torch.count_nonzero(pred_c & ~targ_c)
        counts[FN, c] = torch.count_nonzero(~pred_c & targ_c)
    return counts
def micro_f1(pred: LongTensor, targ: LongTensor, num_classes: int) -> float:
    """Compute the micro-averaged F1 score across all ranks.

    The per-class TP/FP/FN counts are summed over classes locally and then
    summed over ranks with ``dist.all_reduce``, so every rank must call this
    function.

    Args:
        pred: Predicted class index per local sample.
        targ: Target class index per local sample.
        num_classes: Number of classes.

    Returns:
        The micro F1 score, or ``0.0`` when there is no true positive.
    """
    tmp = _local_TP_FP_FN(pred, targ, num_classes).sum(dim=-1)
    dist.all_reduce(tmp)
    TP, FP, FN = tmp.tolist()
    # Without any true positive, precision, recall or their sum is a zero
    # denominator; the score is defined as 0 in that case.
    if TP == 0:
        return 0.0
    precision = TP / (TP + FP)
    recall = TP / (TP + FN)
    return 2 * precision * recall / (precision + recall)
def macro_f1(pred: LongTensor, targ: LongTensor, num_classes: int) -> float:
    """Compute the macro-averaged F1 score across all ranks.

    The per-class TP/FP/FN counts are summed over ranks with
    ``dist.all_reduce``, so every rank must call this function. The F1 score
    is then computed per class and averaged with equal class weights.

    Args:
        pred: Predicted class index per local sample.
        targ: Target class index per local sample.
        num_classes: Number of classes.

    Returns:
        The mean of the per-class F1 scores; a class without any true
        positive contributes ``0``.
    """
    tmp = _local_TP_FP_FN(pred, targ, num_classes)
    dist.all_reduce(tmp)
    TP, FP, FN = tmp
    precision = TP / (TP + FP)
    recall = TP / (TP + FN)
    f1 = 2 * precision * recall / (precision + recall)
    # A class with zero true positives produces 0/0 = NaN somewhere above;
    # count it as an F1 of 0 so a single such class does not turn the whole
    # average into NaN.
    f1 = torch.nan_to_num(f1, nan=0.0)
    return f1.mean().item()
def accuracy(pred: LongTensor, targ: LongTensor) -> float:
    """Return the fraction of correct predictions over all ranks.

    The local number of matches and the local sample count are summed
    across ranks with ``dist.all_reduce``, so every rank must call this
    function.

    Args:
        pred: Predicted class index per local sample.
        targ: Target class index per local sample.

    Returns:
        Global number of matches divided by the global number of samples.
    """
    stats = torch.empty(2, dtype=torch.float32, device=pred.device)
    num_hits = pred.eq(targ).count_nonzero()
    stats[0] = num_hits
    stats[1] = pred.size(0)
    dist.all_reduce(stats)
    total_hits, total_count = stats.tolist()
    return total_hits / total_count
def all_reduce_loss(loss: Tensor, batch_size: int) -> float:
    """Average a per-rank mean loss over all ranks, weighted by batch size.

    Each rank contributes ``loss * batch_size`` and ``batch_size``; both are
    summed with ``dist.all_reduce``, so every rank must call this function.

    Args:
        loss: Scalar tensor holding the mean loss of the local batch.
        batch_size: Number of local samples the loss was averaged over.

    Returns:
        The sample-weighted mean loss over all ranks, or ``nan`` when no
        rank contributed any sample.
    """
    # A mean-reduced loss over zero samples is typically NaN, and NaN * 0 is
    # still NaN; an empty local batch must contribute exactly 0 so it does
    # not poison the global sum.
    local_sum = loss.item() * batch_size if batch_size > 0 else 0.0
    tmp = torch.tensor([
        local_sum,
        batch_size,
    ], dtype=torch.float32, device=loss.device)
    dist.all_reduce(tmp)
    cum_loss, n = tmp.tolist()
    if n == 0:
        # No samples anywhere: the mean is undefined.
        return float("nan")
    return cum_loss / n
import torch
import torch.distributed as dist
from torch import Tensor, LongTensor
from torch_sparse import SparseTensor
from torch_geometric.data import Data
from torch_geometric.utils import degree
import os
import os.path as osp
import shutil
from typing import *
def partition_save(root: str, data: Data, num_parts: int, algo: str = "dbh"):
    """Partition ``data`` and store one file per part below ``root``.

    The parts are written to ``<root>/<algo>_<num_parts>/<index:03d>``. An
    existing, non-empty target directory is emptied first.

    Args:
        root: Base output directory.
        data: Graph to partition.
        num_parts: Number of partitions to create.
        algo: Partition algorithm name passed on to ``partition_data``.

    Raises:
        ValueError: If ``root`` or the target directory exists but is not a
            directory.
    """
    root = osp.abspath(root)
    if osp.exists(root) and not osp.isdir(root):
        raise ValueError(f"path '{root}' should be a directory")

    path = osp.join(root, f"{algo}_{num_parts}")
    if osp.exists(path) and not osp.isdir(path):
        raise ValueError(f"path '{path}' should be a directory")

    stale = os.listdir(path) if osp.exists(path) else []
    if stale:
        print(f"directory '{path}' not empty and cleared")
        for entry in stale:
            target = osp.join(path, entry)
            if osp.isdir(target):
                shutil.rmtree(target)
            else:
                os.remove(target)

    if not osp.exists(path):
        print(f"creating directory '{path}'")
        os.makedirs(path)

    pdata = partition_data(data, num_parts, algo, verbose=True)
    for i in range(num_parts):
        print(f"saving partition data: {i+1}/{num_parts}")
        torch.save(pdata[i], osp.join(path, f"{i:03d}"))
def partition_load(root: str, algo: str = "dbh") -> Data:
    """Load the partition file that belongs to the calling rank.

    The file is ``<root>/<algo>_<world_size>/<rank:03d>``, where rank and
    world size come from the default process group.

    Args:
        root: Base directory the partitions were saved under.
        algo: Partition algorithm name used when saving.

    Returns:
        The object stored for this rank, as returned by ``torch.load``.
    """
    part_dir = f"{algo}_{dist.get_world_size()}"
    part_file = f"{dist.get_rank():03d}"
    # NOTE(review): torch.load unpickles the file; only load partitions
    # from a trusted source.
    return torch.load(osp.join(root, part_dir, part_file))
def partition_data(data: Data, num_parts: int, algo: str = "dbh", verbose: bool = True) -> List[Data]:
    """Split a graph into ``num_parts`` per-partition ``Data`` objects.

    Every returned object carries

    * ``local_node_table``: ids of the nodes assigned to the part,
    * ``local_edge_table``: the ``edge_index`` columns assigned to the part,
    * ``local_edge_route``: a 2-row tensor; row 0 holds the unique node ids
      touched by the part's edges and row 1 the part each of them is
      assigned to,
    * ``gcn_norm``: the normalisation weight of each local edge,

    plus the attributes of ``data``: tensors whose first dimension equals
    ``num_nodes`` are sliced to the part's nodes, other tensors and
    ``SparseTensor`` values are dropped, and every non-tensor value is
    copied as is. ``edge_index`` itself is not copied.

    Args:
        data: Graph to partition.
        num_parts: Number of partitions.
        algo: ``"metis"`` or ``"dbh"``.
        verbose: Print progress messages when true.

    Raises:
        ValueError: If ``algo`` is not a known algorithm.
    """
    if algo not in ("metis", "dbh"):
        raise ValueError(f"invalid algorithm: {algo}")
    part_fn = metis if algo == "metis" else dbh

    def log(message: str) -> None:
        if verbose:
            print(message)

    num_nodes = data.num_nodes
    edge_index = data.edge_index

    log(f"running partition algorithm: {algo}")
    node_parts, edge_parts = part_fn(edge_index, num_nodes, num_parts)

    log("computing GCN normalized edge_attr")
    src_scale = degree(edge_index[0], num_nodes).pow(-0.5)
    dst_scale = degree(edge_index[1], num_nodes).pow(-0.5)
    # Zero-degree nodes give inf after pow(-0.5); neutralise them.
    dst_scale[dst_scale.isinf() | dst_scale.isnan()] = 0.0
    src_scale[src_scale.isinf() | src_scale.isnan()] = 0.0
    edge_attr = src_scale[data.edge_index[0]] * dst_scale[data.edge_index[1]]

    log("computing local partition table")
    node_ids = [torch.where(node_parts == p)[0] for p in range(num_parts)]
    edge_ids = [torch.where(edge_parts == p)[0] for p in range(num_parts)]
    parts = []
    for nodes_p, edges_p in zip(node_ids, edge_ids):
        local_edges = edge_index[:, edges_p]
        touched = local_edges.flatten().unique()
        parts.append({
            "local_node_table": nodes_p,
            "local_edge_table": local_edges,
            "local_edge_route": torch.vstack([touched, node_parts[touched]]),
            "gcn_norm": edge_attr[edges_p],
        })

    log("partition local features")
    for key, val in data:
        if key == "edge_index":
            continue
        if isinstance(val, Tensor):
            # Only node-level tensors are distributed; others are skipped.
            if val.size(0) != num_nodes:
                continue
            for part, nodes_p in zip(parts, node_ids):
                part[key] = val[nodes_p]
        elif not isinstance(val, SparseTensor):
            for part in parts:
                part[key] = val

    return [Data(**part) for part in parts]
def _nopart(edge_index: LongTensor, num_nodes: int) -> Tuple[LongTensor, LongTensor]:
    """Assign every node and every edge to partition 0.

    Args:
        edge_index: Edge list with one column per edge.
        num_nodes: Number of nodes in the graph.

    Returns:
        ``(node_parts, edge_parts)``: two all-zero ``torch.long`` tensors of
        length ``num_nodes`` and ``edge_index.size(1)``.
    """
    num_edges = edge_index.size(1)
    return (
        torch.zeros(num_nodes, dtype=torch.long),
        torch.zeros(num_edges, dtype=torch.long),
    )
def metis(edge_index: LongTensor, num_nodes: int, num_parts: int) -> Tuple[LongTensor, LongTensor]:
    """Partition nodes with the ``torch_sparse`` partition op.

    Every edge is placed in the partition of its destination node
    (``edge_index[1]``).

    Args:
        edge_index: Edge list with one column per edge.
        num_nodes: Number of nodes in the graph.
        num_parts: Number of partitions; values ``<= 1`` put everything in
            partition 0.

    Returns:
        ``(node_parts, edge_parts)`` with one partition id per node / edge.
    """
    if num_parts <= 1:
        return _nopart(edge_index, num_nodes)

    adjacency = SparseTensor.from_edge_index(
        edge_index, sparse_sizes=(num_nodes, num_nodes),
    )
    rowptr, col, _ = adjacency.to_symmetric().csr()
    # presumably the op's "recursive" switch, enabled for fewer than 8
    # parts — TODO confirm against the torch_sparse partition signature
    small_split = num_parts < 8
    node_parts = torch.ops.torch_sparse.partition(rowptr, col, None, num_parts, small_split)
    return node_parts, node_parts[edge_index[1]]
def dbh(edge_index: LongTensor, num_nodes: int, num_parts: int) -> Tuple[LongTensor, LongTensor]:
    """Partition by degree: each edge follows its lower-degree endpoint.

    Node ``v`` is placed in partition ``v % num_parts``. An edge ``(u, v)``
    is placed in the partition of ``u`` when ``deg(u) <= deg(v)`` and in the
    partition of ``v`` otherwise, where the degree counts occurrences in
    both rows of ``edge_index``.

    Args:
        edge_index: Edge list with one column per edge.
        num_nodes: Number of nodes in the graph.
        num_parts: Number of partitions; values ``<= 1`` put everything in
            partition 0.

    Returns:
        ``(node_parts, edge_parts)`` with one partition id per node / edge.
    """
    if num_parts <= 1:
        return _nopart(edge_index, num_nodes)

    deg = degree(edge_index.flatten(), num_nodes, dtype=torch.long)
    src, dst = edge_index[0], edge_index[1]
    deg_u, deg_v = deg[src], deg[dst]

    # Node ids double as hash keys, so the owning node id is the key of an edge.
    node_keys = torch.arange(num_nodes).type_as(deg_u)
    edge_keys = torch.where(deg_u <= deg_v, src, dst).type_as(deg_v)
    return node_keys % num_parts, edge_keys % num_parts
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment