Commit 684facbb by Wenjie Huang

Working static full-graph training

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
cora/
from torch_geometric.datasets import Planetoid
from torch_geometric.utils import add_remaining_self_loops, to_undirected
import os.path as osp
import sys
from starrygl.utils.partition import partition_save
if __name__ == "__main__":
data = Planetoid("/mnt/nfs/hwj/pyg_datasets/Planetoid/Cora", "Cora")[0]
if data.is_directed():
data.edge_index = to_undirected(data.edge_index)
data.edge_index, _ = add_remaining_self_loops(data.edge_index)
print(f"num_nodes: {data.num_nodes}")
print(f"num_edges: {data.num_edges}")
print(f"num_features: {data.num_features}")
num_parts_list = [1, 2, 3, 5, 7, 9, 11]
algos = ["metis", "random"]
root = osp.splitext(osp.abspath(__file__))[0]
print(f"root: {root}")
for num_parts in num_parts_list:
for algo in algos:
print(f"======== {num_parts} + {algo} ========")
partition_save(root, data, num_parts, algo)
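
# Note (illustrative, not part of the commit): each run above writes its
# partition files to {root}/{algo}_{num_parts}/000, 001, ...; at training time
# every rank loads its own shard with starrygl.utils.partition.partition_load,
# which picks the sub-directory matching the current world size (see the demo
# script at the end of this commit).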
from .route import Route, GatherWork
import torch
import torch.distributed as dist
from torch import Tensor
from typing import *
def with_nccl() -> bool:
return dist.get_backend() == "nccl"
def with_gloo() -> bool:
return dist.get_backend() == "gloo"
class Works:
def __init__(self, *works) -> None:
self._works = list(works)
self._waited = False
def wait(self) -> None:
assert not self._waited
for w in self._works:
if w is None:
continue
w.wait()
self._waited = True
def push(self, *works) -> None:
assert not self._waited
self._works.extend(works)
def all_to_all(
output_tensor_list: List[Tensor],
input_tensor_list: List[Tensor],
) -> Works:
assert len(output_tensor_list) == len(input_tensor_list)
if with_nccl():
work = dist.all_to_all(
output_tensor_list=output_tensor_list,
input_tensor_list=input_tensor_list,
async_op=True,
)
return Works(work)
elif with_gloo():
rank = dist.get_rank()
world_size = dist.get_world_size()
works = Works()
for i in range(1, world_size):
send_i = (rank + i) % world_size
recv_i = (rank - i + world_size) % world_size
send_w = dist.isend(input_tensor_list[send_i], send_i)
recv_w = dist.irecv(output_tensor_list[recv_i], recv_i)
works.push(recv_w, send_w)
output_tensor_list[rank][:] = input_tensor_list[rank]
return works
else:
backend = dist.get_backend()
raise RuntimeError(f"unsupported backend: {backend}")
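
# Illustrative usage sketch (not part of the commit). It assumes the default
# process group is already initialized (e.g. launched via torchrun); with the
# NCCL backend the tensors would additionally have to live on the current CUDA
# device. Each rank prepares one tensor per peer plus matching, pre-sized
# receive buffers, then waits on the returned Works handle before reading them.
def _example_all_to_all_exchange(feature_dim: int = 4) -> List[Tensor]:
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    # tag every outgoing row with the sender's rank so the result is easy to verify
    send = [torch.full((1, feature_dim), float(rank)) for _ in range(world_size)]
    recv = [torch.zeros(1, feature_dim) for _ in range(world_size)]
    work = all_to_all(recv, send)
    work.wait()
    # recv[i] now holds the row sent by rank i
    return recv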
import torch
import torch.nn as nn
import torch.autograd as autograd
from torch import Tensor
from contextlib import contextmanager
from typing import *
from .route import Route, GatherWork
class GatherContext:
def __init__(self,
this,
route: Route,
async_op: bool,
) -> None:
self.this = this
self.route = route
self.async_op = async_op
def gather(self,
values: Tensor,
active: Optional[Tensor] = None,
) -> Tuple[Tensor, Tensor]:
values, active = GatherFunction.apply(values, active)
return values, active
class Gather:
def __init__(self) -> None:
self.cache: Dict[str, Tensor] = {}
self.last_fw_work: Optional[GatherWork] = None
self.last_bw_work: Optional[GatherWork] = None
def __call__(self,
values: Tensor,
active: Optional[Tensor],
route: Route,
async_op: bool = False,
) -> Tuple[Tensor, Tensor]:
with self._gather_manager(route=route, async_op=async_op) as gtx:
values, active = gtx.gather(values=values, active=active)
return values, active
def clear(self):
self.cache.clear()
self.last_fw_work = None
self.last_bw_work = None
@contextmanager
def _gather_manager(self, route: Route, async_op: bool):
global _global_gather_context
prev_gtx = _global_gather_context
try:
_global_gather_context = GatherContext(
this=self, route=route, async_op=async_op)
yield _global_gather_context
finally:
_global_gather_context = prev_gtx
class GatherFunction(autograd.Function):
@staticmethod
def forward(
ctx: autograd.function.FunctionCtx,
values: Tensor,
active: Optional[Tensor] = None,
):
gtx = _last_global_gather_context()
this: Gather = gtx.this
route: Route = gtx.route
fw_key = "__forward_activation__"
cached = this.cache.get(fw_key)
if cached is None:
cached = torch.zeros(
size=(route.dst_size,) + values.shape[1:],
dtype=values.dtype,
device=values.device,
)
current_work = route.gather_forward(values, active, async_op=gtx.async_op)
if gtx.async_op:
work = this.last_fw_work or current_work
this.last_fw_work = current_work
else:
work = current_work
recv_values = work.get_values()
recv_active = work.get_active()
cached[recv_active] = recv_values[recv_active]
this.cache[fw_key] = cached
# if the input values were compacted (active rows only), the holes must be removed again when computing the gradient
if active is not None and values.size(0) < active.size(0):
ctx.shrink_values = True
ctx.save_for_backward(active, recv_active)
else:
ctx.shrink_values = False
ctx.save_for_backward(recv_active)
ctx.gtx = gtx
return cached, recv_active
@staticmethod
def backward(
ctx: autograd.function.FunctionCtx,
grad_values: Tensor,
grad_active: Optional[Tensor], # always None
):
gtx: GatherContext = ctx.gtx
this: Gather = gtx.this
route: Route = gtx.route
# whether the returned gradient needs its holes removed (compacted back to active rows)
shrink_values: bool = ctx.shrink_values
with torch.no_grad():
# the backward pass propagates activations along the reverse direction of the forward pass
if shrink_values:
active, recv_active = ctx.saved_tensors
else:
recv_active, = ctx.saved_tensors
bw_key = "__backward_gradient__"
cached = this.cache.get(bw_key)
if cached is None:
cached = torch.zeros(
size=(route.src_size,) + grad_values.shape[1:],
dtype=grad_values.dtype,
device=grad_values.device,
)
current_work = route.gather_backward(grad_values, recv_active, async_op=gtx.async_op)
if gtx.async_op:
work = this.last_bw_work or current_work
this.last_bw_work = current_work
else:
work = current_work
recv_grad_values = work.get_values()
recv_grad_active = work.get_active()
cached[recv_grad_active] = recv_grad_values[recv_grad_active]
this.cache[bw_key] = cached
if shrink_values:
return cached[active], None
else:
return cached, None
#### private functions
_global_gather_context: Optional[GatherContext] = None
def _last_global_gather_context() -> GatherContext:
global _global_gather_context
assert _global_gather_context is not None
return _global_gather_context
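
# Illustrative usage sketch (not part of the commit). A Gather instance is
# driven once per layer with node features of route.src_size rows and returns
# a buffer of route.dst_size rows plus a mask of the rows actually received.
# With async_op=True the first call still waits on its own communication, while
# later calls consume the result of the previous call (one-step-stale
# pipelining via last_fw_work / last_bw_work).
def _example_gather_layer(x: Tensor, route: Route, gather: Gather) -> Tensor:
    out, out_active = gather(x, None, route, async_op=False)
    # out: (route.dst_size, *x.shape[1:]); out_active: bool mask of received rows
    return out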
import torch
import torch.distributed as dist
from torch import Tensor
from typing import *
from .a2a import all_to_all, Works, with_nccl
class GatherWork:
def __init__(self,
values_buffer: Tensor,
active_buffer: Optional[Tensor],
recv_val_idx: List[Tensor],
recv_val_dat: List[Tensor],
works_list: List[Works],
) -> None:
self._waited = False
self._values_buffer = values_buffer
self._active_buffer = active_buffer
self._recv_val_idx = recv_val_idx
self._recv_val_dat = recv_val_dat
self._works_list = works_list
def _wait(self) -> None:
assert not self._waited
for w in self._works_list:
if w is None:
continue
w.wait()
if self._active_buffer is None:
for idx, dat in zip(self._recv_val_idx, self._recv_val_dat):
self._values_buffer[idx] += dat
else:
for idx, dat in zip(self._recv_val_idx, self._recv_val_dat):
self._values_buffer[idx] += dat
self._active_buffer[idx] = True
self._waited = True
def get_values(self) -> Tensor:
if not self._waited:
self._wait()
return self._values_buffer
def get_active(self) -> Optional[Tensor]:
if not self._waited:
self._wait()
return self._active_buffer
class Route:
def __init__(self,
src_ids: Tensor,
dst_ids: Tensor,
):
assert src_ids.dtype == torch.long
assert dst_ids.dtype == torch.long
assert src_ids.dim() == 1
assert dst_ids.dim() == 1
rank = dist.get_rank()
world_size = dist.get_world_size()
# note: src_ids and dst_ids here do not carry the same meaning as the bipartite graph in PGraph
# src_ids is the sender side, which is actually PGraph.dst_ids
# dst_ids is the receiver side, which is actually PGraph.src_ids
self.src_size: int = src_ids.size(0)
self.dst_size: int = dst_ids.size(0)
# tensor options for index tensors
ikw = dict(dtype=torch.long, device=src_ids.device)
# collect the number of sender nodes in every partition
all_src_sizes: List[torch.Size] = [None] * world_size
dist.all_gather_object(all_src_sizes, src_ids.size())
# total number of nodes across all partitions
num_nodes = _all_reduce_num_nodes(src_ids, dst_ids)
# map dst_ids (global ids) to local indices
imp = torch.empty(num_nodes, **ikw).fill_((2**62-1)*2+1)
imp[dst_ids] = torch.arange(self.dst_size, **ikw)
# asynchronously fetch node ids from the other partitions to overlap computation with communication
all_src_ids: List[Tensor] = [None] * world_size
all_src_get = [None] * world_size
def fetch_src_ids(i: int):
if i == rank:
all_src_ids[i] = src_ids
else:
s = all_src_sizes[i]
all_src_ids[i] = torch.empty(s, **ikw)
all_src_get[i] = dist.broadcast(
all_src_ids[i], src=i, async_op=True)
self.forward_routes: List[Tensor] = []
self.backward_routes: List[Tensor] = []
# xmp serves two purposes: computing the node intersection to build indices,
# and mapping src_ids to local indices
xmp = torch.empty_like(imp)
for i in range(world_size):
# prefetch the next partition's ids
if i == 0:
fetch_src_ids(i)
if i + 1 < world_size:
fetch_src_ids(i + 1)
all_src_get[i].wait()
ids = all_src_ids[i]
# compute the intersection to determine which nodes must be transferred
xmp.fill_(0)
xmp[ids] += 1
xmp[dst_ids] += 1
ind = torch.where(xmp > 1)[0]
# map src_ids (global ids) to local indices
xmp.fill_((2**62-1)*2+1)
xmp[ids] = torch.arange(ids.size(0), **ikw)
# local indices
src_ind = xmp[ind]
dst_ind = imp[ind]
# at this point only bw_route is final; fw_route still has to be sent to the partition that owns src_ids
fw_route = torch.vstack([src_ind, dst_ind])
bw_route = torch.vstack([dst_ind, src_ind])
self.forward_routes.append(fw_route)
self.backward_routes.append(bw_route)
# send fw_route to the partitions that own src_ids to build the final routing tables
self.forward_routes = _p2p_recv(self.forward_routes)
# if the homogeneous-graph condition holds, add a self-loop route for every local node
if src_ids.size(0) <= dst_ids.size(0):
t = dst_ids[:src_ids.size(0)] == src_ids
if t.all():
rank_ind = torch.arange(src_ids.size(0), **ikw)
rank_route = torch.vstack([rank_ind, rank_ind])
self.forward_routes[rank] = rank_route
self.backward_routes[rank] = rank_route
dist.barrier()
def _gather_impl(self,
values_buffer: Tensor,
active_buffer: Optional[Tensor],
values: Tensor,
active: Optional[Tensor],
send_routes: List[Tensor],
recv_routes: List[Tensor],
async_op: bool,
):
ikw = dict(dtype=torch.long, device=values.device)
fkw = dict(dtype=values.dtype, device=values.device)
# if active is given, only values flagged as active are sent to neighbor partitions; otherwise all values are sent
if active is not None:
if values.size(0) < active.size(0):
values = _spread_values(values, active)
assert values.size(0) == active.size(0)
# build the filtered routing tables
active_routes = [ro[:,active[ro[0]]] for ro in send_routes]
# compute the receive buffer sizes
scatter_sizes = [ro.size(1) for ro in active_routes]
scatter_sizes = torch.tensor(scatter_sizes, **ikw)
gather_sizes = torch.zeros_like(scatter_sizes)
dist.all_to_all_single(gather_sizes, scatter_sizes)
gather_sizes = gather_sizes.tolist()
def async_run():
if active is not None:
send_val_idx, recv_val_idx = [], []
send_val_dat, recv_val_dat = [], []
for i, s in enumerate(gather_sizes):
c = (s,) + values.shape[1:]
send_val_idx.append(active_routes[i][1])
recv_val_idx.append(torch.zeros(s, **ikw))
send_val_dat.append(values[active_routes[i][0]])
recv_val_dat.append(torch.zeros(c, **fkw))
works_list = [
all_to_all(recv_val_idx, send_val_idx),
all_to_all(recv_val_dat, send_val_dat),
]
return GatherWork(
values_buffer=values_buffer,
active_buffer=active_buffer,
recv_val_idx=recv_val_idx,
recv_val_dat=recv_val_dat,
works_list=works_list,
)
else:
send_val_dat, recv_val_dat = [], []
recv_val_idx = [ro[0] for ro in recv_routes]
for i, r in enumerate(recv_routes):
s = r.size(1)
c = (s,) + values.shape[1:]
send_val_dat.append(values[send_routes[i][0]])
recv_val_dat.append(torch.zeros(c, **fkw))
works_list = [
all_to_all(recv_val_dat, send_val_dat),
]
return GatherWork(
values_buffer=values_buffer,
active_buffer=active_buffer,
recv_val_idx=recv_val_idx,
recv_val_dat=recv_val_dat,
works_list=works_list,
)
if async_op and with_nccl():
_stream = _get_stream()
_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(_stream):
return async_run()
else:
return async_run()
def gather_forward(self,
values: Tensor,
active: Optional[Tensor],
async_op: bool = False,
):
bkw = dict(dtype=torch.bool, device=values.device)
fkw = dict(dtype=values.dtype, device=values.device)
s = (self.dst_size,) + values.shape[1:]
values_buffer = torch.zeros(s, **fkw)
active_buffer = torch.zeros(self.dst_size, **bkw)
return self._gather_impl(
values_buffer=values_buffer,
active_buffer=active_buffer,
values=values,
active=active,
send_routes=self.forward_routes,
recv_routes=self.backward_routes,
async_op=async_op,
)
def gather_backward(self,
values: Tensor,
active: Optional[Tensor],
async_op: bool = False,
):
bkw = dict(dtype=torch.bool, device=values.device)
fkw = dict(dtype=values.dtype, device=values.device)
s = (self.src_size,) + values.shape[1:]
values_buffer = torch.zeros(s, **fkw)
active_buffer = torch.zeros(self.src_size, **bkw)
return self._gather_impl(
values_buffer=values_buffer,
active_buffer=active_buffer,
values=values,
active=active,
send_routes=self.backward_routes,
recv_routes=self.forward_routes,
async_op=async_op,
)
#### private functions
def _all_reduce_num_nodes(
src_ids: Tensor,
dst_ids: Tensor,
) -> int:
max_ids = torch.zeros(1, dtype=src_ids.dtype, device=src_ids.device)
max_ids = max_ids.max(src_ids.max()) if src_ids.numel() > 0 else max_ids
max_ids = max_ids.max(dst_ids.max()) if dst_ids.numel() > 0 else max_ids
dist.all_reduce(max_ids, op=dist.ReduceOp.MAX)
return max_ids.item() + 1
def _p2p_recv(tensors: List[Tensor]) -> List[Tensor]:
rank = dist.get_rank()
world_size = dist.get_world_size()
tensor_sizes: List[torch.Size] = [t.size() for t in tensors]
all_tensor_sizes: List[List[torch.Size]] = [None] * world_size
dist.all_gather_object(all_tensor_sizes, tensor_sizes)
new_tensors: List[Tensor] = []
for i in range(world_size):
s = all_tensor_sizes[i][rank]
t = torch.zeros(s).type_as(tensors[i])
new_tensors.append(t)
all_to_all(new_tensors, tensors).wait()
return new_tensors
def _spread_values(values: Tensor, active: Tensor) -> Tensor:
new_values = torch.zeros(
size=active.shape[:1] + values.shape[1:],
dtype=values.dtype,
device=values.device,
)
new_values[active] = values
return new_values
# create a dedicated CUDA stream for asynchronous Gather communication
_STREAM: Optional[torch.cuda.Stream] = None
def _get_stream() -> torch.cuda.Stream:
global _STREAM
if _STREAM is None:
_STREAM = torch.cuda.Stream()
return _STREAM
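
# Illustrative usage sketch (not part of the commit). Route's first argument is
# the sender side (the node ids whose features this partition owns) and the
# second the receiver side (the ids it needs after gathering, owned + halo);
# DistGraph builds exactly this with Route(dst_ids, src_ids). The constructor
# and the gather calls are collectives, so every rank must call them together
# on an initialized process group.
def _example_route_forward(owned_ids: Tensor, needed_ids: Tensor, feat: Tensor) -> Tensor:
    route = Route(owned_ids, needed_ids)
    work = route.gather_forward(feat, None)  # feat has owned_ids.size(0) rows
    return work.get_values()                 # result has needed_ids.size(0) rows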
from .distgraph import DistGraph, NID, EID, SID, DID
import torch
from torch import Tensor
from typing import *
from contextlib import contextmanager
from .ndata import NID, NData
from .edata import EID, SID, DID, EData
from .utils import init_local_edge_index
from ..core import Route
class DistGraph:
def __init__(self,
ids: Tensor,
edge_index: Tensor,
):
# build local_edge_index
dst_ids = ids
src_ids, local_edge_index = init_local_edge_index(
dst_ids=dst_ids,
edge_index=edge_index,
)
# node's attributes
self.ndata = NData(
src_size=src_ids.size(0),
dst_size=dst_ids.size(0),
)
self.ndata[NID] = src_ids
# edge's attributes
self.edata = EData(
edge_size=local_edge_index.size(1),
)
self.edata[EID] = local_edge_index
# graph's attributes
self.data = {}
self.args = {}
self.route = Route(dst_ids, src_ids)
# @property
# def edge_index(self) -> Tensor:
# return self.edata[EID]
# @property
# def src_size(self) -> int:
# self.ndata.src_size
# @property
# def dst_size(self) -> int:
# self.ndata.dst_size
@contextmanager
def scoped_manager(self):
last_ndata = self.ndata
last_edata = self.edata
try:
self.ndata = NData(
src_size=None,
dst_size=None,
p=last_ndata,
)
self.edata = EData(
edge_size=None,
p=last_edata,
)
yield self
finally:
self.ndata = last_ndata
self.edata = last_edata
import torch
from torch import Tensor
from typing import *
EID: str = "__edata_eid_key__"
SID: str = "__edata_sid_key__"
DID: str = "__edata_did_key__"
class EData:
def __init__(self, edge_size: int, p = None) -> None:
if p is None:
self.edge_size = edge_size
else:
assert edge_size is None
self.edge_size = p.edge_size
self.prev_data = p
self.data: Dict[str, Tensor] = {}
def __getitem__(self, name: str) -> Tensor:
t = self.get(name)
if t is None:
if name in (EID, SID, DID):
raise ValueError(f"not found EID in data")
else:
raise ValueError(f"not found '{name}' in data")
return t
def __setitem__(self, name: str, tensor: Tensor):
if not isinstance(tensor, Tensor):
raise ValueError(f"the second parameter's type must be Tensor")
if name in (SID, DID):
raise ValueError(f"please try EID")
if name == EID:
if tensor.dim() == 2 \
and tensor.size(0) == 2 \
and tensor.size(1) == self.edge_size:
self.data[name] = tensor
else:
raise ValueError(f"EID's shape must be 2 x edge_size")
else:
if tensor.size(0) == self.edge_size:
self.data[name] = tensor
else:
raise ValueError(f"tensor's shape must match the edge_size")
def __delitem__(self, name: str) -> None:
self.pop(name)
def _get_impl(self, name: str) -> Optional[Tensor]:
p, t = self, None
while p is not None:
t = p.data.get(name)
if t is not None:
break
p = p.prev_data
return t
def get(self, name: str) -> Optional[Tensor]:
if name in (EID, SID, DID):
t = self._get_impl(EID)
if t is None:
raise ValueError(f"no edge info in data")
assert t.dim() == 2 and t.size(0) == 2
if name == SID:
t = t[0]
elif name == DID:
t = t[1]
return t
else:
return self._get_impl(name)
def pop(self, name: str) -> Tensor:
if name in (SID, DID):
raise ValueError(f"please try EID")
if name == EID:
return self.data.pop(EID)
else:
return self.data.pop(name)
import torch
from torch import Tensor
from typing import *
NID: str = "__ndata_nid_key__"
class NData:
def __init__(self, src_size: int, dst_size: int, p = None) -> None:
if p is None:
self.src_size = src_size
self.dst_size = dst_size
else:
assert src_size is None
assert dst_size is None
self.src_size = p.src_size
self.dst_size = p.dst_size
self.prev_data = p
self.data: Dict[str, Tensor] = {}
def __getitem__(self, name: str) -> Tensor:
t = self.get(name)
if t is None:
raise ValueError(f"not found '{name}' in data")
return t
def __setitem__(self, name: str, tensor: Tensor):
if not isinstance(tensor, Tensor):
raise ValueError(f"the second parameter's type must be Tensor")
if tensor.size(0) == self.src_size:
self.data[name] = tensor
elif tensor.size(0) == self.dst_size:
self.data[name] = tensor
else:
raise ValueError(f"tensor's shape must match the src_size or dst_size")
def __delitem__(self, name: str) -> None:
self.pop(name)
def get(self, name: str) -> Optional[Tensor]:
p, t = self, None
while p is not None:
t = p.data.get(name)
if t is not None:
break
p = p.prev_data
return t
def pop(self, name: str) -> Tensor:
return self.data.pop(name)
def get_type(self, name: str):
t = self.get(name)
if t.size(0) == self.src_size:
return "src"
elif t.size(0) == self.dst_size:
return "dst"
else:
raise RuntimeError
import torch
from torch import Tensor
from typing import *
def init_local_edge_index(
dst_ids: Tensor,
edge_index: Tensor,
) -> Tuple[Tensor, Tensor]:
max_ids = calc_max_ids(dst_ids, edge_index)
ikw = dict(dtype=torch.long, device=dst_ids.device)
xmp = torch.zeros(max_ids + 1, **ikw)
# check that this is a vertex-cut partition, i.e. every edge is assigned to the partition that owns its destination node
xmp[edge_index[1].unique()] += 0b01
xmp[dst_ids.unique()] += 0b10
if not (xmp != 0x01).all():
raise RuntimeError(f"must be vertex-cut partition graph")
# src_ids equals [dst_ids, nodes in edge_index[0] that are not in dst_ids]
xmp.fill_(0)
xmp[edge_index[0]] = 1
xmp[dst_ids] = 0
src_ids = torch.cat([dst_ids, torch.where(xmp > 0)[0]], dim=-1)
# compute local indices
xmp.fill_((2**62-1)*2+1)
xmp[src_ids] = torch.arange(src_ids.size(0), **ikw)
local_edge_index = xmp[edge_index.flatten()].view_as(edge_index)
return src_ids, local_edge_index
def calc_max_ids(*ids: Tensor) -> int:
x = [t.max().item() if t.numel() > 0 else 0 for t in ids]
return max(x)
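
# Illustrative example (not part of the commit). A single-partition sketch of
# init_local_edge_index(): the partition owns nodes {0, 1}; the edge 2 -> 1
# pulls node 2 in as a halo node, so src_ids becomes [0, 1, 2] and the edge
# index is relabelled with local indices. The vertex-cut check passes because
# every edge's destination is an owned node.
def _example_local_edge_index() -> Tuple[Tensor, Tensor]:
    dst_ids = torch.tensor([0, 1])
    edge_index = torch.tensor([[0, 1, 2],
                               [1, 0, 1]])
    src_ids, local_edge_index = init_local_edge_index(dst_ids, edge_index)
    # src_ids == [0, 1, 2]; local_edge_index == [[0, 1, 2], [1, 0, 1]]
    return src_ids, local_edge_index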
from .convs import GCNConv, GATConv, GINConv
from .basic_gnn import BasicGNN
class GCN(BasicGNN):
def init_conv(self, in_channels: int, out_channels: int, **kwargs):
return GCNConv(in_channels, out_channels, **kwargs)
class GAT(BasicGNN):
def init_conv(self, in_channels: int, out_channels: int, **kwargs):
return GATConv(in_channels, out_channels, **kwargs)
class GIN(BasicGNN):
def init_conv(self, in_channels: int, out_channels: int, **kwargs):
return GINConv(in_channels, out_channels, **kwargs)
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from typing import *
import copy
import inspect
from torch_geometric.nn import JumpingKnowledge
from torch_geometric.nn.resolver import activation_resolver, normalization_resolver
from ..graph.distgraph import DistGraph
from ..core.gather import Gather
class BasicGNN(nn.Module):
def __init__(
self,
in_channels: int,
hidden_channels: int,
num_layers: int,
out_channels: Optional[int] = None,
dropout: float = 0.0,
act: Optional[str] = "relu",
act_kwargs: Optional[Dict[str, Any]] = None,
norm: Optional[str] = None,
norm_kwargs: Optional[Dict[str, Any]] = None,
jk: Optional[str] = None,
**kwargs,
):
super().__init__()
self.in_channels = in_channels
self.hidden_channels = hidden_channels
self.num_layers = num_layers
out_channels = hidden_channels \
if out_channels is None else out_channels
self.out_channels = out_channels
self.dropout = dropout
self.gathers: List[Gather] = []
for _ in range(num_layers):
self.gathers.append(Gather())
self.convs = nn.ModuleList()
for _ in range(num_layers - 1):
self.convs.append(
self.init_conv(in_channels, hidden_channels, **kwargs))
in_channels = hidden_channels
if jk is None:
self.convs.append(
self.init_conv(in_channels, out_channels, **kwargs))
else:
self.convs.append(
self.init_conv(in_channels, hidden_channels, **kwargs))
if jk != "last":
self.jk = JumpingKnowledge(jk, hidden_channels, num_layers)
if jk == "cat":
jk_channels = num_layers * hidden_channels
else:
jk_channels = hidden_channels
self.lin_jk = nn.Linear(jk_channels, out_channels)
self.act = activation_resolver(act, **(act_kwargs or {}))
if norm is not None:
self.norms = nn.ModuleList()
for _ in range(num_layers - 1):
self.norms.append(normalization_resolver(
norm, hidden_channels, **(norm_kwargs or {})))
if jk is not None:
self.norms.append(normalization_resolver(
norm, hidden_channels, **(norm_kwargs or {})))
self.reset_parameters()
def init_conv(self,
in_channels: int,
out_channels: int,
**kwargs
) -> nn.Module:
raise NotImplementedError
def reset_parameters(self):
if hasattr(self, "jk"):
self.jk.reset_parameters()
if hasattr(self, "lin_jk"):
self.lin_jk.reset_parameters()
for conv in self.convs:
conv.reset_parameters()
if hasattr(self, "norms"):
for norm in self.norms:
norm.reset_parameters()
def forward(
self,
g: DistGraph,
) -> Tensor:
xs: List[Tensor] = []
x = g.ndata["x"]
async_op = g.args.get("async_op", False)
for i in range(self.num_layers):
# gather neighbor features from remote partitions
x, _ = self.gathers[i](x, None, g.route, async_op=async_op)
# TODO: reconsider the placement of this dropout layer
x = F.dropout(x, p=self.dropout, training=self.training)
# run the graph convolution
x = self.convs[i](g, x)
# without jk, the last layer skips norm and activation
if i == self.num_layers - 1 and not hasattr(self, "jk"):
break
if hasattr(self, "norms"):
x = self.norms[i](x)
if hasattr(self, "act"):
x = self.act(x)
if hasattr(self, "jk"):
xs.append(x)
if hasattr(self, "jk"):
x = self.jk([t for t in xs])
x = self.lin_jk(x) if hasattr(self, "lin_jk") else x
return x
from .gcn_conv import GCNConv
from .gat_conv import GATConv
from .gin_conv import GINConv
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_geometric as pyg
import torch_geometric.nn as pyg_nn
import torch_geometric.nn.inits as pyg_inits
import torch_geometric.utils as pyg_utils
from torch_scatter import scatter_sum
from torch import Tensor
from typing import *
from starrygl.graph import DistGraph, EID
class GATConv(nn.Module):
def __init__(self,
in_channels: int,
out_channels: int,
heads: int = 1,
concat: bool = False,
negative_slope: float = 0.2,
dropout: float = 0.0,
edge_dim: Optional[int] = None,
bias: bool = True,
**kwargs
) -> None:
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.heads = heads
self.concat = concat
self.negative_slope = negative_slope
self.dropout = dropout
self.edge_dim = edge_dim
self.lin = pyg_nn.Linear(in_channels, heads * out_channels,
bias=False, weight_initializer="glorot")
self.att_src = nn.Parameter(torch.Tensor(1, heads, out_channels))
self.att_dst = nn.Parameter(torch.Tensor(1, heads, out_channels))
if edge_dim is not None:
self.lin_edge = pyg_nn.Linear(edge_dim, heads * out_channels,
bias=False, weight_initializer="glorot")
self.att_edge = nn.Parameter(torch.Tensor(1, heads, out_channels))
if bias and concat:
self.bias = nn.Parameter(torch.Tensor(heads * out_channels))
elif bias and not concat:
self.bias = nn.Parameter(torch.Tensor(out_channels))
self.reset_parameters()
def reset_parameters(self):
self.lin.reset_parameters()
pyg_inits.glorot(self.att_src)
pyg_inits.glorot(self.att_dst)
if hasattr(self, "lin_edge"):
self.lin_edge.reset_parameters()
pyg_inits.glorot(self.att_edge)
if hasattr(self, "bias"):
pyg_inits.zeros(self.bias)
def forward(self, g: DistGraph, x: Tensor) -> Tensor:
H, C = self.heads, self.out_channels
edge_index = g.edata[EID]
x = self.lin(x).view(-1, H, C)
alpha_j = (x * self.att_src).sum(dim=-1)
alpha_j = alpha_j[edge_index[0]]
alpha_i = (x * self.att_dst).sum(dim=-1)
alpha_i = alpha_i[edge_index[1]]
if hasattr(self, "lin_edge"):
edge_attr = g.edata["edge_attr"]
if edge_attr.dim() == 1:
edge_attr = edge_attr.view(-1, 1)
e = self.lin_edge(edge_attr).view(-1, H, C)
alpha_e = (e * self.att_edge).sum(dim=-1)
alpha = alpha_i + alpha_j + alpha_e
else:
alpha = alpha_i + alpha_j
alpha = F.leaky_relu(alpha, self.negative_slope)
alpha = pyg_utils.softmax(
src=alpha,
index=edge_index[1],
num_nodes=g.ndata.dst_size,
)
if self.dropout != 0.0:
alpha = F.dropout(alpha, p=self.dropout, training=self.training)
x = alpha.unsqueeze(-1) * x[edge_index[0]]
x = scatter_sum(x, edge_index[1], dim=0, dim_size=g.ndata.dst_size)
if self.concat:
x = x.view(-1, self.heads * self.out_channels)
else:
x = x.mean(dim=1)
if hasattr(self, "bias"):
x += self.bias
return x
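
# Note (illustrative, not part of the commit): the forward pass above is the
# standard GAT update restricted to locally owned destination nodes,
#   alpha_ij = softmax_i( LeakyReLU( a_src^T x_j + a_dst^T x_i [+ a_edge^T e_ij] ) )
#   x_i' = sum_{j in N(i)} alpha_ij * x_j
# followed by concatenation or averaging over the attention heads.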
import torch
import torch.nn as nn
import torch_geometric as pyg
import torch_geometric.nn as pyg_nn
from torch_scatter import scatter_sum
from torch import Tensor
from typing import *
from starrygl.graph import DistGraph, EID
class GCNConv(nn.Module):
def __init__(self,
in_channels: int,
out_channels: int,
bias: bool = True,
**kwargs
) -> None:
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.lin = pyg_nn.Linear(in_channels, out_channels,
bias=False, weight_initializer="glorot")
if bias:
self.bias = nn.Parameter(torch.Tensor(out_channels))
else:
self.register_parameter("bias", None)
self.reset_parameters()
def reset_parameters(self):
self.lin.reset_parameters()
if self.bias is not None: self.bias.data.zero_()
def forward(self, g: DistGraph, x: Tensor) -> Tensor:
edge_index = g.edata[EID]
x = self.lin(x)
x = x[edge_index[0]]
x = x * g.edata["gcn_norm"].view(-1, 1)
x = scatter_sum(x, edge_index[1], dim=0, dim_size=g.ndata.dst_size)
if self.bias is not None:
x += self.bias
return x
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_geometric as pyg
import torch_geometric.nn as pyg_nn
import torch_geometric.nn.inits as pyg_inits
import torch_geometric.utils as pyg_utils
from torch_scatter import scatter_sum
from torch import Tensor
from typing import *
from starrygl.graph import DistGraph, EID
class GINConv(nn.Module):
def __init__(self,
in_channels: int,
out_channels: int,
mlp_channels: Optional[int] = None,
eps: float = 0,
train_eps: bool = False,
**kwargs
) -> None:
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
if mlp_channels is None:
mlp_channels = in_channels + out_channels
self.mlp_channels = mlp_channels
self.nn = nn.Sequential(
nn.Linear(in_channels, mlp_channels),
nn.ReLU(),
nn.Linear(mlp_channels, out_channels),
)
self.initial_eps = eps
if train_eps:
self.eps = torch.nn.Parameter(torch.tensor([eps]))
else:
self.register_buffer("eps", torch.tensor([eps]))
self.reset_parameters()
def reset_parameters(self):
pyg_inits.reset(self.nn)
self.eps.data.fill_(self.initial_eps)
def forward(self, g: DistGraph, x: Tensor) -> Tensor:
edge_index = g.edata[EID]
t = x[edge_index[0]]
t = scatter_sum(t, edge_index[1], dim=0, dim_size=g.ndata.dst_size)
x = t + (1 + self.eps) * x[:g.ndata.dst_size]
return self.nn(x)
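
# Note (illustrative, not part of the commit): the update above is the standard
# GIN rule, evaluated only for locally owned destination nodes:
#   x_i' = MLP( (1 + eps) * x_i + sum_{j in N(i)} x_j )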
import torch
import torch.nn as nn
import torch.distributed as dist
import os
def convert_parallel_model(
net: nn.Module,
find_unused_parameters=False,
) -> nn.parallel.DistributedDataParallel:
if dist.get_backend() == "nccl":
net = torch.nn.SyncBatchNorm.convert_sync_batchnorm(net)
net = nn.parallel.DistributedDataParallel(net,
find_unused_parameters=find_unused_parameters,
)
return net
def init_process_group(backend: str = "gloo") -> torch.device:
dist.init_process_group(backend)
if backend == "nccl":
local_rank = os.getenv("LOCAL_RANK")
if local_rank is None:
device = torch.device(f"cuda:{dist.get_rank()}")
else:
local_rank = int(local_rank)
device = torch.device(f"cuda:{local_rank}")
torch.cuda.set_device(device)
else:
device = torch.device("cpu")
return device
import torch
import torch.nn as nn
from torch import Tensor
from core.distributed import PGraph
from typing import *
from .metrics import *
def train_epoch(
model: nn.Module,
opt: torch.optim.Optimizer,
g: PGraph,
x: Tensor,
y: Tensor,
w: Optional[Tensor] = None,
mask: Optional[Tensor] = None,
) -> Tuple[float, float]:
model.train()
criterion = nn.CrossEntropyLoss()
if w is None:
pred: Tensor = model(g, x)
else:
pred: Tensor = model(g, x, edge_attr = w)
targ: Tensor = y
if mask is not None:
pred = pred[mask]
targ = targ[mask]
loss: Tensor = criterion(pred, targ)
opt.zero_grad()
loss.backward()
opt.step()
with torch.no_grad():
train_loss = all_reduce_loss(loss, targ.size(0))
train_acc = accuracy(pred.argmax(dim=-1), targ)
return train_loss, train_acc
@torch.no_grad()
def eval_epoch(
model: nn.Module,
g: PGraph,
x: Tensor,
y: Tensor,
w: Optional[Tensor] = None,
mask: Optional[Tensor] = None,
) -> Tuple[float, float]:
model.eval()
criterion = nn.CrossEntropyLoss()
if w is None:
pred: Tensor = model(g, x)
else:
pred: Tensor = model(g, x, edge_attr = w)
targ: Tensor = y
if mask is not None:
pred = pred[mask]
targ = targ[mask]
loss = criterion(pred, targ)
eval_loss = all_reduce_loss(loss, targ.size(0))
eval_acc = accuracy(pred.argmax(dim=-1), targ)
return eval_loss, eval_acc
import torch
import torch.distributed as dist
from torch import Tensor, LongTensor
from typing import *
def _local_TP_FP_FN(pred: LongTensor, targ: LongTensor, num_classes: int) -> Tensor:
TP, FP, FN = 0, 1, 2
tmp = torch.empty(3, num_classes, dtype=torch.float32, device=pred.device)
for c in range(num_classes):
pred_c = (pred == c)
targ_c = (targ == c)
tmp[TP, c] = torch.count_nonzero(pred_c & targ_c)
tmp[FP, c] = torch.count_nonzero(pred_c & ~targ_c)
tmp[FN, c] = torch.count_nonzero(~pred_c & targ_c)
return tmp
def micro_f1(pred: LongTensor, targ: LongTensor, num_classes: int) -> float:
tmp = _local_TP_FP_FN(pred, targ, num_classes).sum(dim=-1)
dist.all_reduce(tmp)
TP, FP, FN = tmp.tolist()
precision = TP / (TP + FP)
recall = TP / (TP + FN)
return 2 * precision * recall / (precision + recall)
def macro_f1(pred: LongTensor, targ: LongTensor, num_classes: int) -> float:
tmp = _local_TP_FP_FN(pred, targ, num_classes)
dist.all_reduce(tmp)
TP, FP, FN = tmp
precision = TP / (TP + FP)
recall = TP / (TP + FN)
f1 = 2 * precision * recall / (precision + recall)
return f1.mean().item()
def accuracy(pred: LongTensor, targ: LongTensor) -> float:
tmp = torch.empty(2, dtype=torch.float32, device=pred.device)
tmp[0] = pred.eq(targ).count_nonzero()
tmp[1] = pred.size(0)
dist.all_reduce(tmp)
a, b = tmp.tolist()
return a / b
def all_reduce_loss(loss: Tensor, batch_size: int) -> float:
tmp = torch.tensor([
loss.item() * batch_size,
batch_size
], dtype=torch.float32, device=loss.device)
dist.all_reduce(tmp)
cum_loss, n = tmp.tolist()
return cum_loss / n
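
# Note (illustrative, not part of the commit): each metric reduces a small
# [numerator, denominator] (or per-class TP/FP/FN) tensor across all ranks
# before dividing, so the result is the exact global value rather than an
# average of per-rank averages. micro_f1 sums TP/FP/FN over classes and then
# forms a single precision/recall pair; macro_f1 computes one F1 per class and
# averages them.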
import torch
import torch.distributed as dist
from torch import Tensor, LongTensor
from torch_sparse import SparseTensor
from torch_geometric.data import Data
from torch_geometric.utils import degree
import os
import os.path as osp
import shutil
from typing import *
def partition_save(root: str, data: Data, num_parts: int, algo: str = "metis"):
root = osp.abspath(root)
if osp.exists(root) and not osp.isdir(root):
raise ValueError(f"path '{root}' should be a directory")
path = osp.join(root, f"{algo}_{num_parts}")
if osp.exists(path) and not osp.isdir(path):
raise ValueError(f"path '{path}' should be a directory")
if osp.exists(path) and os.listdir(path):
print(f"directory '{path}' not empty and cleared")
for p in os.listdir(path):
p = osp.join(path, p)
if osp.isdir(p):
shutil.rmtree(p)
else:
os.remove(p)
if not osp.exists(path):
print(f"creating directory '{path}'")
os.makedirs(path)
for i, pdata in enumerate(partition_data(data, num_parts, algo, verbose=True)):
print(f"saving partition data: {i+1}/{num_parts}")
fn = osp.join(path, f"{i:03d}")
torch.save(pdata, fn)
def partition_load(root: str, algo: str = "metis") -> Data:
rank = dist.get_rank()
world_size = dist.get_world_size()
fn = osp.join(root, f"{algo}_{world_size}", f"{rank:03d}")
return torch.load(fn)
def partition_data(data: Data, num_parts: int, algo: str, verbose: bool = False) -> List[Data]:
if algo == "metis":
part_fn = metis_partition
elif algo == "random":
part_fn = random_partition
else:
raise ValueError(f"invalid algorithm: {algo}")
num_nodes = data.num_nodes
num_edges = data.num_edges
edge_index = data.edge_index
if verbose: print(f"running partition algorithm: {algo}")
node_parts, edge_parts = part_fn(edge_index, num_nodes, num_parts)
if verbose: print("computing GCN normalized factor")
gcn_norm = compute_gcn_norm(edge_index, num_nodes)
if data.y.dtype == torch.long:
if verbose: print("compute num_classes")
num_classes = data.y.max().item() + 1
else:
num_classes = None
for i in range(num_parts):
npart_i = torch.where(node_parts == i)[0]
epart_i = torch.where(edge_parts == i)[0]
npart = npart_i
epart = edge_index[:,epart_i]
pdata = {
"ids": npart,
"edge_index": epart,
"gcn_norm": gcn_norm[epart_i],
}
if num_classes is not None:
pdata["num_classes"] = num_classes
for key, val in data:
if key == "edge_index":
continue
if isinstance(val, Tensor):
if val.size(0) == num_nodes:
pdata[key] = val[npart_i]
elif val.size(0) == num_edges:
pdata[key] = val[epart_i]
# else:
# pdata[key] = val
elif isinstance(val, SparseTensor):
pass
else:
pdata[key] = val
pdata = Data(**pdata)
yield pdata
def compute_gcn_norm(edge_index: LongTensor, num_nodes: int) -> Tensor:
deg_j = degree(edge_index[0], num_nodes).pow(-0.5)
deg_i = degree(edge_index[1], num_nodes).pow(-0.5)
deg_i[deg_i.isinf() | deg_i.isnan()] = 0.0
deg_j[deg_j.isinf() | deg_j.isnan()] = 0.0
return deg_j[edge_index[0]] * deg_i[edge_index[1]]
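
# Note (illustrative, not part of the commit): compute_gcn_norm() returns the
# symmetric GCN normalization per edge j -> i,
#   norm_ji = deg(j)^(-1/2) * deg(i)^(-1/2),
# i.e. the edge-wise form of D^(-1/2) A D^(-1/2); infinities from isolated
# nodes are clamped to zero.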
def _nopart(edge_index: LongTensor, num_nodes: int) -> Tuple[LongTensor, LongTensor]:
node_parts = torch.zeros(num_nodes, dtype=torch.long)
edge_parts = torch.zeros(edge_index.size(1), dtype=torch.long)
return node_parts, edge_parts
def metis_partition(edge_index: LongTensor, num_nodes: int, num_parts: int) -> Tuple[LongTensor, LongTensor]:
if num_parts <= 1:
return _nopart(edge_index, num_nodes)
adj_t = SparseTensor.from_edge_index(edge_index, sparse_sizes=(num_nodes, num_nodes)).to_symmetric()
rowptr, col, _ = adj_t.csr()
node_parts = torch.ops.torch_sparse.partition(rowptr, col, None, num_parts, num_parts < 8)
edge_parts = node_parts[edge_index[1]]
return node_parts, edge_parts
def random_partition(edge_index: LongTensor, num_nodes: int, num_parts: int) -> Tuple[LongTensor, LongTensor]:
if num_parts <= 1:
return _nopart(edge_index, num_nodes)
node_parts = torch.randint(num_parts, size=(num_nodes,), dtype=edge_index.dtype)
edge_parts = node_parts[edge_index[1]]
return node_parts, edge_parts
import torch.distributed as dist
def sync_print(*args, **kwargs):
rank = dist.get_rank()
world_size = dist.get_world_size()
for i in range(world_size):
if i == rank:
print(f"rank {rank}:", *args, **kwargs)
dist.barrier()
def main_print(*args, **kwargs):
rank = dist.get_rank()
if rank == 0:
print(*args, **kwargs)
dist.barrier()
import torch
import torch.nn as nn
from torch import Tensor
from typing import *
from starrygl.graph.distgraph import DistGraph
from starrygl.nn import *
from starrygl.nn.parallel import init_process_group, convert_parallel_model
from starrygl.utils.printer import main_print
from starrygl.utils.partition import partition_load
from starrygl.utils.metrics import accuracy
if __name__ == "__main__":
# initialize the distributed process group and pick the compute device
device = init_process_group(backend="nccl")
# load this rank's partition of the dataset
pdata = partition_load("./cora", algo="metis").to(device)
g = DistGraph(ids=pdata.ids, edge_index=pdata.edge_index)
# define the GNN model (a GCN here)
net = GCN(
in_channels=pdata.num_features,
hidden_channels=512,
num_layers=3,
out_channels=pdata.num_classes,
# norm="batchnorm",
dropout=0.5,
# heads=8,
).to(device)
# wrap the model for distributed data-parallel training
net = convert_parallel_model(net)
# define the optimizer and the loss function
optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()
# training loop
for ep in range(100):
## training step
net.train()
with g.scoped_manager():
g.ndata["x"] = pdata.x
g.edata["gcn_norm"] = pdata.gcn_norm
h = net(g)
loss: Tensor = criterion(
h[pdata.train_mask],
pdata.y[pdata.train_mask],
)
optimizer.zero_grad()
loss.backward()
optimizer.step()
## validation step
net.eval()
with g.scoped_manager():
g.ndata["x"] = pdata.x
g.edata["gcn_norm"] = pdata.gcn_norm
with torch.no_grad():
pred = net(g).argmax(dim=-1)
acc = accuracy(
pred[pdata.val_mask],
pdata.y[pdata.val_mask],
)
main_print(f"{ep}: {loss.item()} acc: {acc}")
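
# Illustrative launch (not part of the commit): the script expects one process
# per partition. With the metis 2-way split produced by the Cora example above,
# and assuming this file is saved as train_cora.py on a machine with 2 GPUs:
#   torchrun --nproc_per_node=2 train_cora.py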