Commit 9586742a by Wenjie Huang

Fix bugs in Route: rework the exchange around ptr/ind-based all_to_all_s, add RouteWork and RouteWorkCache for asynchronous/cached communication, and move PyG partitioning into GraphData.save_partition/load_partition.

parent 10c38111
......@@ -4,7 +4,7 @@ from torch_geometric.utils import add_remaining_self_loops, to_undirected
import os.path as osp
import sys
from starrygl.utils.data import partition_pyg
from starrygl.graph import GraphData
import logging
logging.getLogger().setLevel(logging.INFO)
......@@ -18,7 +18,9 @@ if __name__ == "__main__":
print(f"num_nodes: {data.num_nodes}")
print(f"num_edges: {data.num_edges}")
print(f"num_features: {data.num_features}")
data = GraphData.from_pyg_data(data)
num_parts_list = [1, 2, 3, 5, 7, 9, 11]
algos = ["metis", 'mt-metis', "random"]
......@@ -27,4 +29,4 @@ if __name__ == "__main__":
for num_parts in num_parts_list:
for algo in algos:
print(f"======== {num_parts} + {algo} ========")
partition_pyg(root, data, num_parts, algo)
\ No newline at end of file
data.save_partition(root, num_parts, algo)
......@@ -5,7 +5,7 @@ from torch import Tensor
from typing import *
from starrygl.distributed import DistributedContext
from starrygl.graph import new_vc_route
from starrygl.graph import *
from torch_scatter import scatter_sum
......@@ -28,32 +28,38 @@ all_eparts = [
],
]
def get_route(bipartite: bool = True):
def get_data():
ctx = DistributedContext.get_default_context()
assert ctx.world_size == 3
dst_ids = torch.tensor(all_nparts[ctx.rank], dtype=torch.long, device=ctx.device)
edge_index = torch.tensor(all_eparts[ctx.rank], dtype=torch.long, device=ctx.device).t()
return new_vc_route(dst_ids, edge_index, bipartite=bipartite)
src_ids, edge_index = init_vc_edge_index(dst_ids, edge_index)
return GraphData.from_bipartite(edge_index, raw_src_ids=src_ids, raw_dst_ids=dst_ids)
if __name__ == "__main__":
ctx = DistributedContext.init(backend="gloo", use_gpu=True)
src_ids, edge_index, dst_ids, route = get_route(False)
src_size = route.src_len
dst_size = route.dst_len
g = get_data()
route = g.to_route()
edge_index = g.edge_index()
# src_ids, edge_index, dst_ids, route = get_route(False)
# src_size = route.src_len
# dst_size = route.dst_len
ctx.sync_print(route.src_len, route.dst_len)
ctx.sync_print(route._fw_ptr, route._fw_ind)
ctx.sync_print(route._bw_ptr, route._bw_ind)
edge_ones = torch.ones(edge_index.size(1), device=ctx.device).requires_grad_()
src_ones = scatter_sum(edge_ones, edge_index[0], dim=0, dim_size=route.src_len)
dst_ones = scatter_sum(edge_ones, edge_index[1], dim=0, dim_size=route.dst_len)
# ctx.sync_print(route.fw_tensor(dst_ones))
# ctx.sync_print(route.bw_tensor(src_ones))
ctx.sync_print(route.fw_tensor(dst_ones))
ctx.sync_print(route.bw_tensor(src_ones))
out = route.reverse_route().apply(src_ones)
out = route.rev().apply(src_ones)
ctx.sync_print(out)
out.sum().backward()
......@@ -61,4 +67,13 @@ if __name__ == "__main__":
ctx.sync_print(route.get_src_part_ids())
dst_mask = torch.full((route.dst_len,), ctx.rank % 2, dtype=torch.bool, device=ctx.device)
ctx.main_print("="*64)
ctx.sync_print(dst_mask)
_, _, r2 = route.filter(dst_mask)
ctx.sync_print(r2.apply(dst_ones).detach())
ctx.sync_print(r2.rev().apply(src_ones).detach())
# dst_true = torch.ones(route.dst_len, dtype=torch.float, device=ctx.device)
# ctx.sync_print(route.fw_tensor(dst_true, "max"))
ctx.shutdown()
......@@ -45,8 +45,6 @@ def all_to_all_v(
assert len(output_tensor_list) == world_size
assert len(input_tensor_list) == world_size
# if group is None:
# group = dist.distributed_c10d._get_default_group()
backend = dist.get_backend(group)
if backend == "nccl":
......
from .route import Route
from .utils import init_vc_edge_index
from torch import Tensor
from typing import Tuple
__all__ = [
"Route",
"init_vc_edge_index",
"new_vc_route",
]
def new_vc_route(
dst_ids: Tensor,
edge_index: Tensor,
bipartite: bool = True
) -> Tuple[Tensor, Tensor, Tensor, Route]:
src_ids, local_edge_index = init_vc_edge_index(
dst_ids, edge_index, bipartite=bipartite)
route = Route.from_raw_indices(
src_ids, dst_ids, bipartite=bipartite)
return src_ids, local_edge_index, dst_ids, route
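For reference, a minimal single-rank sketch of new_vc_route (values are illustrative; it assumes DistributedContext has already been initialized so that Route.from_raw_indices can build its exchange tables):
import torch
dst_ids = torch.tensor([4, 7, 9], dtype=torch.long)                   # global ids owned locally
edge_index = torch.tensor([[2, 4, 7], [4, 7, 9]], dtype=torch.long)   # global (src, dst) pairs, all dst local
src_ids, local_edges, dst_ids, route = new_vc_route(dst_ids, edge_index)
# local_edges is edge_index relabeled into positions within src_ids / dst_ids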
from .data import *
from .route import *
\ No newline at end of file
......@@ -3,19 +3,18 @@ import torch
from torch import Tensor
from typing import *
import os
import os.path as osp
import shutil
from pathlib import Path
from torch_sparse import SparseTensor
from .partition import *
from starrygl.utils.partition import *
from .route import Route
import logging
__all__ = [
"GraphData",
"partition_pyg",
"partition_load",
"init_vc_edge_index",
]
......@@ -75,6 +74,17 @@ class GraphData:
return data
return self._edge_indices[edge_type]
def node_types(self) -> List[str]:
return list(self._node_data.keys())
def edge_types(self) -> List[Tuple[str, str, str]]:
return list(self._edge_data.keys())
def to_route(self, group: Any = None) -> Route:
src_ids = self.node("src")["raw_ids"]
dst_ids = self.node("dst")["raw_ids"]
return Route.from_raw_indices(src_ids, dst_ids, group=group)
@property
def is_heterogeneous(self) -> bool:
return self._heterogeneous
......@@ -91,6 +101,126 @@ class GraphData:
self._edge_indices = {k:v.to(device) for k,v in self._edge_indices.items()}
return self
@staticmethod
def from_bipartite(
edge_index: Tensor,
num_src_nodes: Optional[int] = None,
num_dst_nodes: Optional[int] = None,
raw_src_ids: Optional[Tensor] = None,
raw_dst_ids: Optional[Tensor] = None,
) -> 'GraphData':
if num_src_nodes is None:
num_src_nodes = raw_src_ids.numel()
if num_dst_nodes is None:
num_dst_nodes = raw_dst_ids.numel()
g = GraphData(
edge_indices={
("src", "@", "dst"): edge_index,
},
num_nodes={
"src": num_src_nodes,
"dst": num_dst_nodes,
}
)
if raw_src_ids is not None:
g.node("src")["raw_ids"] = raw_src_ids
if raw_dst_ids is not None:
g.node("dst")["raw_ids"] = raw_dst_ids
return g
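A small local construction with from_bipartite, for illustration (tensor values are made up):
import torch
local_edges = torch.tensor([[0, 1, 2], [0, 0, 1]], dtype=torch.long)  # local src -> local dst
g = GraphData.from_bipartite(
    local_edges,
    raw_src_ids=torch.tensor([10, 11, 12], dtype=torch.long),
    raw_dst_ids=torch.tensor([10, 11], dtype=torch.long),
)
# g.node("src")["raw_ids"] and g.node("dst")["raw_ids"] now hold the global ids,
# which is exactly what to_route() consumes later.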
@staticmethod
def from_pyg_data(data) -> 'GraphData':
from torch_geometric.data import Data
assert isinstance(data, Data), f"must be Data class in pyg"
g = GraphData(data.edge_index, data.num_nodes)
for key, val in data:
if key == "edge_index":
continue
elif isinstance(val, Tensor):
if val.size(0) == data.num_nodes:
g.node()[key] = val
elif val.size(0) == data.num_edges:
g.edge()[key] = val
elif isinstance(val, SparseTensor):
logging.warning(f"found sparse matrix {key}, but ignored.")
else:
g.meta()[key] = val
return g
@staticmethod
def load_partition(root: str, part_id: int, num_parts: int, algo: str = "metis") -> 'GraphData':
p = Path(root).expanduser().resolve() / f"{algo}_{num_parts}" / f"{part_id:03d}"
return torch.load(p.__str__())
def save_partition(self, root: str, num_parts: int, algo: str = "metis"):
assert not self.is_heterogeneous, "only homogeneous graphs are supported"
num_nodes: int = self.node().num_nodes
edge_index: Tensor = self.edge_index()
logging.info(f"running partition aglorithm: {algo}")
if algo == "metis":
node_parts = metis_partition(edge_index, num_nodes, num_parts)
elif algo == "mt-metis":
node_parts = mt_metis_partition(edge_index, num_nodes, num_parts)
elif algo == "random":
node_parts = random_partition(edge_index, num_nodes, num_parts)
else:
raise ValueError(f"unknown partition algorithm: {algo}")
root_path = Path(root).expanduser().resolve()
base_path = root_path / f"{algo}_{num_parts}"
if base_path.exists():
logging.warning(f"directory '{base_path.__str__()}' exists, and will be removed.")
shutil.rmtree(base_path.__str__())
base_path.mkdir(parents=True)
for i in range(num_parts):
npart_mask = node_parts == i
epart_mask = npart_mask[edge_index[1]]
raw_dst_ids: Tensor = torch.where(npart_mask)[0]
local_edges = edge_index[:, epart_mask]
raw_src_ids, local_edges = init_vc_edge_index(
raw_dst_ids, local_edges, bipartite=True,
)
# g = GraphData(
# edge_indices={
# ("src", "@", "dst"): local_edges,
# },
# num_nodes={
# "src": raw_src_ids.numel(),
# "dst": raw_dst_ids.numel(),
# }
# )
g = GraphData.from_bipartite(
local_edges,
raw_src_ids=raw_src_ids,
raw_dst_ids=raw_dst_ids,
)
for key in self.node().keys():
g.node("dst")[key] = self.node()[key][npart_mask]
for key in self.edge().keys():
g.edge()[key] = self.edge()[key][epart_mask]
for key in self.meta().keys():
g.meta()[key] = self.meta()[key]
logging.info(f"saving partition data: {i+1}/{num_parts}")
torch.save(g, (base_path / f"{i:03d}").__str__())
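Taken together with load_partition above, the intended round trip is roughly the following (dataset and paths are placeholders):
from torch_geometric.datasets import Planetoid
data = Planetoid("/tmp/Cora", "Cora")[0]
g = GraphData.from_pyg_data(data)
g.save_partition("/tmp/cora_parts", num_parts=2, algo="metis")
# later, inside a 2-rank job, each rank loads its own shard:
part = GraphData.load_partition("/tmp/cora_parts", part_id=0, num_parts=2, algo="metis")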
class MetaData:
def __init__(self) -> None:
......@@ -193,97 +323,137 @@ class EdgeData:
return self
def partition_load(root: str, part_id: int, num_parts: int, algo: str = "metis") -> GraphData:
p = Path(root).expanduser().resolve() / f"{algo}_{num_parts}" / f"{part_id:03d}"
return torch.load(p.__str__())
def init_vc_edge_index(
dst_ids: Tensor,
edge_index: Tensor,
bipartite: bool = True,
) -> Tuple[Tensor, Tensor]:
ikw = dict(dtype=torch.long, device=dst_ids.device)
local_num_nodes = torch.zeros(1, **ikw)
if dst_ids.numel() > 0:
local_num_nodes = dst_ids.max().max(local_num_nodes)
if edge_index.numel() > 0:
local_num_nodes = edge_index.max().max(local_num_nodes)
local_num_nodes = local_num_nodes.item() + 1
xmp: Tensor = torch.zeros(local_num_nodes, **ikw)
xmp[edge_index[1].unique()] += 0b01
xmp[dst_ids.unique()] += 0b10
if not (xmp != 0x01).all():
raise RuntimeError(f"must be vertex-cut partition graph")
if bipartite:
src_ids = edge_index[0].unique()
else:
xmp.fill_(0)
xmp[edge_index[0]] = 1
xmp[dst_ids] = 0
src_ids = torch.cat([dst_ids, torch.where(xmp > 0)[0]], dim=-1)
xmp.fill_((2**62-1)*2+1)
xmp[src_ids] = torch.arange(src_ids.size(0), **ikw)
src = xmp[edge_index[0]]
xmp.fill_((2**62-1)*2+1)
xmp[dst_ids] = torch.arange(dst_ids.size(0), **ikw)
dst = xmp[edge_index[1]]
local_edge_index = torch.vstack([src, dst])
return src_ids, local_edge_index
def partition_pyg(root: str, data, num_parts: int, algo: str = "metis"):
root_path = Path(root).expanduser().resolve()
base_path = root_path / f"{algo}_{num_parts}"
# def partition_load(root: str, part_id: int, num_parts: int, algo: str = "metis") -> GraphData:
# p = Path(root).expanduser().resolve() / f"{algo}_{num_parts}" / f"{part_id:03d}"
# return torch.load(p.__str__())
if base_path.exists():
shutil.rmtree(base_path.__str__())
base_path.mkdir(parents=True, exist_ok=True)
# def partition_pyg(root: str, data, num_parts: int, algo: str = "metis"):
# root_path = Path(root).expanduser().resolve()
# base_path = root_path / f"{algo}_{num_parts}"
# if base_path.exists():
# shutil.rmtree(base_path.__str__())
# base_path.mkdir(parents=True, exist_ok=True)
for i, g in enumerate(partition_pyg_data(data, num_parts, algo)):
logging.info(f"saving partition data: {i+1}/{num_parts}")
torch.save(g, (base_path / f"{i:03d}").__str__())
# for i, g in enumerate(partition_pyg_data(data, num_parts, algo)):
# logging.info(f"saving partition data: {i+1}/{num_parts}")
# torch.save(g, (base_path / f"{i:03d}").__str__())
def partition_pyg_data(data, num_parts: int, algo: str = "metis") -> Iterator[GraphData]:
from torch_geometric.data import Data
assert isinstance(data, Data), f"must be Data class in pyg"
# def partition_pyg_data(data, num_parts: int, algo: str = "metis") -> Iterator[GraphData]:
# from torch_geometric.data import Data
# assert isinstance(data, Data), f"must be Data class in pyg"
logging.info(f"running partition aglorithm: {algo}")
# logging.info(f"running partition aglorithm: {algo}")
num_nodes: int = data.num_nodes
num_edges: int = data.num_edges
edge_index: Tensor = data.edge_index
if algo == "metis":
node_parts = metis_partition(edge_index, num_nodes, num_parts)
elif algo == "mt-metis":
node_parts = mt_metis_partition(edge_index, num_nodes, num_parts)
elif algo == "random":
node_parts = random_partition(edge_index, num_nodes, num_parts)
else:
raise ValueError(f"unknown partition algorithm: {algo}")
# num_nodes: int = data.num_nodes
# num_edges: int = data.num_edges
# edge_index: Tensor = data.edge_index
# if algo == "metis":
# node_parts = metis_partition(edge_index, num_nodes, num_parts)
# elif algo == "mt-metis":
# node_parts = mt_metis_partition(edge_index, num_nodes, num_parts)
# elif algo == "random":
# node_parts = random_partition(edge_index, num_nodes, num_parts)
# else:
# raise ValueError(f"unknown partition algorithm: {algo}")
if data.y.dtype == torch.long:
if data.y.dim() == 1:
num_classes = data.y.max().item() + 1
else:
num_classes = data.y.size(1)
else:
num_classes = None
for i in range(num_parts):
npart_mask = node_parts == i
epart_mask = npart_mask[edge_index[1]]
local_edges = edge_index[:, epart_mask]
raw_src_ids: Tensor = local_edges[0].unique()
raw_dst_ids: Tensor = torch.where(npart_mask)[0]
# if data.y.dtype == torch.long:
# if data.y.dim() == 1:
# num_classes = data.y.max().item() + 1
# else:
# num_classes = data.y.size(1)
# else:
# num_classes = None
# for i in range(num_parts):
# npart_mask = node_parts == i
# epart_mask = npart_mask[edge_index[1]]
# local_edges = edge_index[:, epart_mask]
# raw_src_ids: Tensor = local_edges[0].unique()
# raw_dst_ids: Tensor = torch.where(npart_mask)[0]
M: int = raw_src_ids.max().item() + 1
imap = torch.full((M,), (2**62-1)*2+1).type_as(raw_src_ids)
imap[raw_src_ids] = torch.arange(raw_src_ids.numel()).type_as(raw_src_ids)
local_src = imap[local_edges[0]]
# M: int = raw_src_ids.max().item() + 1
# imap = torch.full((M,), (2**62-1)*2+1).type_as(raw_src_ids)
# imap[raw_src_ids] = torch.arange(raw_src_ids.numel()).type_as(raw_src_ids)
# local_src = imap[local_edges[0]]
M: int = raw_dst_ids.max().item() + 1
imap = torch.full((M,), (2**62-1)*2+1).type_as(raw_dst_ids)
imap[raw_dst_ids] = torch.arange(raw_dst_ids.numel()).type_as(raw_dst_ids)
local_dst = imap[local_edges[1]]
local_edges = torch.vstack([local_src, local_dst])
g = GraphData(
edge_indices={
("src", "@", "dst"): local_edges,
},
num_nodes={
"src": raw_src_ids.numel(),
"dst": raw_dst_ids.numel(),
},
)
g.node("src")["raw_ids"] = raw_src_ids
g.node("dst")["raw_ids"] = raw_dst_ids
if num_classes is not None:
g.meta()["num_classes"] = num_classes
for key, val in data:
if key == "edge_index":
continue
elif isinstance(val, Tensor):
if val.size(0) == num_nodes:
g.node("dst")[key] = val[npart_mask]
elif val.size(0) == num_edges:
g.edge()[key] = val[epart_mask]
elif isinstance(val, SparseTensor):
pass
else:
g.meta()[key] = val
yield g
# M: int = raw_dst_ids.max().item() + 1
# imap = torch.full((M,), (2**62-1)*2+1).type_as(raw_dst_ids)
# imap[raw_dst_ids] = torch.arange(raw_dst_ids.numel()).type_as(raw_dst_ids)
# local_dst = imap[local_edges[1]]
# local_edges = torch.vstack([local_src, local_dst])
# g = GraphData(
# edge_indices={
# ("src", "@", "dst"): local_edges,
# },
# num_nodes={
# "src": raw_src_ids.numel(),
# "dst": raw_dst_ids.numel(),
# },
# )
# g.node("src")["raw_ids"] = raw_src_ids
# g.node("dst")["raw_ids"] = raw_dst_ids
# if num_classes is not None:
# g.meta()["num_classes"] = num_classes
# for key, val in data:
# if key == "edge_index":
# continue
# elif isinstance(val, Tensor):
# if val.size(0) == num_nodes:
# g.node("dst")[key] = val[npart_mask]
# elif val.size(0) == num_edges:
# g.edge()[key] = val[epart_mask]
# elif isinstance(val, SparseTensor):
# pass
# else:
# g.meta()[key] = val
# yield g
......@@ -2,13 +2,18 @@ import torch
import torch.autograd as autograd
import torch.distributed as dist
from multiprocessing.pool import ThreadPool
from torch import Tensor
from typing import *
from starrygl.distributed import DistributedContext
from starrygl.distributed.cclib import all_to_all_s, all_to_all_v
__all__ = [
"Route",
"RouteWork",
"RouteWorkCache",
"RouteAlltoAll",
]
class Route:
......@@ -31,50 +36,51 @@ class Route:
group=group,
)
def subroute(self,
dst_mask: Tensor,
def filter(self,
dst_mask: Optional[Tensor] = None,
src_mask: Optional[Tensor] = None,
remap: bool = False,
):
if dst_mask.dtype != torch.bool:
dst_mask = Route.__expand_index_to_mask(dst_mask, self.dst_len)
assert dst_mask.size(0) == self.dst_len
if src_mask is None:
fw_dst_masks, work = self.all_to_all_fw(
input_tensor_list=[dst_mask[self.fw_table(i)[0]] for i in range(self.num_parts)],
async_op=True,
)
else:
if src_mask.dtype != torch.bool:
src_mask = Route.__expand_index_to_mask(src_mask, self.src_len)
assert src_mask.size(0) == self.src_len
fw_tables: List[Tensor] = []
for i in range(self.num_parts):
m = dst_mask[self.fw_table(i)[0]]
fw_tables.append(self.fw_table(i)[:,m])
bw_tables: List[Tensor] = []
if src_mask is None:
src_mask = torch.zeros(self.src_len, dtype=dst_mask.dtype, device=dst_mask.device)
work.wait()
for i, m in enumerate(fw_dst_masks):
src_mask[self.bw_table(i)[0]] |= m
bw_tables.append(self.bw_table(i)[:,m])
if dst_mask is None:
if src_mask is None:
raise ValueError("please provide at least one parameter.")
else:
assert src_mask.dtype == torch.bool
assert src_mask.numel() == self.src_len
dst_mask = self.bw_tensor(src_mask.long()) != 0
else:
for i in range(self.num_parts):
m = src_mask[self.bw_table(i)[0]]
bw_tables.append(self.bw_table(i)[:,m])
return dst_mask, Route(
assert dst_mask.dtype == torch.bool
assert dst_mask.numel() == self.dst_len
tmp_src_mask = self.fw_tensor(dst_mask.long()) != 0
if src_mask is None:
src_mask = tmp_src_mask
else:
tmp_src_mask &= src_mask
src_mask = tmp_src_mask
dst_mask = self.bw_tensor(src_mask.long()) != 0
fw_ptr, fw_ind = Route.__filter_ind_and_ptr(self._fw_ptr, self._fw_ind, dst_mask)
bw_ptr, bw_ind = Route.__filter_ind_and_ptr(self._bw_ptr, self._bw_ind, src_mask)
route = Route(
src_len=self.src_len,
dst_len=self.dst_len,
**Route.__tables_to_indptr(fw_tables, bw_tables),
fw_ptr=fw_ptr, fw_ind=fw_ind,
bw_ptr=bw_ptr, bw_ind=bw_ind,
group=self.group,
), src_mask, dst_mask
)
if remap:
fw_ind, dst_len = Route.__remap_ind(route._fw_ind, dst_mask)
bw_ind, src_len = Route.__remap_ind(route._bw_ind, src_mask)
route = Route(
src_len=src_len,
dst_len=dst_len,
fw_ptr=route._fw_ptr, fw_ind=fw_ind,
bw_ptr=route._bw_ptr, bw_ind=bw_ind,
group=route.group,
)
return dst_mask, src_mask, route
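Given a Route built elsewhere (e.g. via GraphData.to_route()), the reworked filter accepts either mask and returns the propagated masks plus the sub-route; a hedged sketch, assuming `route` exists and `device` matches its index tensors:
keep = torch.zeros(route.dst_len, dtype=torch.bool, device=device)  # device is a placeholder
keep[::2] = True
dst_mask, src_mask, sub = route.filter(dst_mask=keep)
# sub only exchanges the surviving entries; pass remap=True to also compact the indices.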
def reverse_route(self):
def rev(self):
return Route(
src_len=self.dst_len,
dst_len=self.src_len,
......@@ -83,62 +89,21 @@ class Route:
group=self.group,
)
def filter_nodes(self,
dst_mask: Tensor,
src_mask: Optional[Tensor] = None,
):
if dst_mask.dtype != torch.bool:
dst_mask = Route.__expand_index_to_mask(dst_mask, self.dst_len)
assert dst_mask.size(0) == self.dst_len
new_dst_len = dst_mask.count_nonzero().item()
xmp = torch.empty(self.dst_len, dtype=torch.long, device=dst_mask.device)
xmp.fill_((2**62-1)*2+1)
xmp[dst_mask] = torch.arange(new_dst_len, dtype=torch.long, device=dst_mask.device)
fw_dst_masks, fw_m_work = self.all_to_all_fw(
input_tensor_list=[dst_mask[self.fw_table(i)[0]] for i in range(self.num_parts)],
async_op=True,
)
# fw_dst_inds, fw_i_work =
# if src_mask is None:
# fw_dst_masks, work = self.all_to_all_fw(
# input_tensor_list=[dst_mask[self.fw_table(i)[0]] for i in range(self.num_parts)],
# async_op=True,
# )
# else:
# if src_mask.dtype != torch.bool:
# src_mask = Route.__expand_index_to_mask(src_mask, self.src_len)
# assert src_mask.size(0) == self.src_len
def filter_edges(self,
edge_mask: Tensor,
edge_index: Tensor,
):
pass
def __init__(self,
src_len: int, dst_len: int,
fw_ptr: List[int], fw_ind: Tensor,
bw_ptr: List[int], bw_ind: Tensor,
group: Any,
) -> None:
self._ctx = DistributedContext.get_default_context()
assert len(fw_ptr) == len(bw_ptr)
self._src_len = src_len
self._dst_len = dst_len
self._fw_ptr = fw_ptr
self._fw_ptr = tuple(fw_ptr)
self._fw_ind = fw_ind
self._bw_ptr = bw_ptr
self._bw_ptr = tuple(bw_ptr)
self._bw_ind = bw_ind
self._group = group
@property
def ctx(self):
return self._ctx
@property
def group(self):
return self._group
......@@ -164,104 +129,84 @@ class Route:
self._bw_ind = self._bw_ind.to(device)
return self
def fw_table(self, i: int):
return self._fw_ind[:,self._fw_ptr[i]:self._fw_ptr[i+1]]
# def fw_table(self, i: int):
# return self._fw_ind[self._fw_ptr[i]:self._fw_ptr[i+1]]
def bw_table(self, i: int):
return self._bw_ind[:,self._bw_ptr[i]:self._bw_ptr[i+1]]
# def bw_table(self, i: int):
# return self._bw_ind[self._bw_ptr[i]:self._bw_ptr[i+1]]
def apply(self, data: Tensor) -> Tensor:
return RouteAlltoAll.apply(data, self)
def apply(self,
data: Tensor,
cache: Optional['RouteWorkCache'] = None,
cache_key: Optional[str] = None,
) -> Tensor:
return RouteAlltoAll.apply(data, self, cache, cache_key)
def fw_tensor(self, data: Tensor) -> Tensor:
@torch.no_grad()
def fw_tensor(self, data: Tensor, async_op: bool = False):
assert data.size(0) == self.dst_len
xs = self.all_to_all_fw(
[data[self.fw_table(i)[0]] for i in range(self.num_parts)],
async_op=False,
)
out = torch.zeros(
self.src_len, *data.shape[1:],
output_tensor = torch.empty(
self._bw_ind.numel(), *data.shape[1:],
dtype=data.dtype, device=data.device,
)
for i, t in enumerate(xs):
out[self.bw_table(i)[0]] += t
return out
work = all_to_all_s(
output_tensor, data[self._fw_ind],
self._bw_ptr, self._fw_ptr,
group=self.group,
async_op=async_op,
)
work = RouteWork(
work if async_op else None,
self._bw_ptr, self._bw_ind,
self.src_len, output_tensor,
)
return work if async_op else work.wait()
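The ptr/ind pair replaces the old per-partition tables; roughly, the layout and the exchange work as follows:
# _fw_ind[_fw_ptr[i] : _fw_ptr[i+1]] : local dst positions whose rows are sent to partition i
# _bw_ind[_bw_ptr[i] : _bw_ptr[i+1]] : local src positions that receive the chunk arriving from partition i
# fw_tensor gathers data[_fw_ind], exchanges the chunks via all_to_all_s
# (input splits _fw_ptr, output splits _bw_ptr), and RouteWork._reduce
# scatter-adds the received rows into a src_len-sized output through _bw_ind.
# bw_tensor is the same pattern with the fw/bw roles swapped.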
def bw_tensor(self, data: Tensor) -> Tensor:
@torch.no_grad()
def bw_tensor(self, data: Tensor, async_op: bool = False):
assert data.size(0) == self.src_len
xs = self.all_to_all_bw(
[data[self.bw_table(i)[0]] for i in range(self.num_parts)],
async_op=False,
)
out = torch.zeros(
self.dst_len, *data.shape[1:],
output_tensor = torch.empty(
self._fw_ind.numel(), *data.shape[1:],
dtype=data.dtype, device=data.device,
)
for i, t in enumerate(xs):
out[self.fw_table(i)[0]] += t
return out
def get_src_part_ids(self) -> Tensor:
xs = self.all_to_all_fw(
[
torch.full(
(self.fw_table(i).size(1),), self.part_id,
dtype=torch.long, device=self.ctx.device,
)
for i in range(self.num_parts)
],
async_op=False,
work = all_to_all_s(
output_tensor, data[self._bw_ind],
self._fw_ptr, self._bw_ptr,
group=self.group,
async_op=async_op,
)
out = torch.full((self.src_len,), 2**16-1, dtype=torch.long, device=self.ctx.device)
for i, t in enumerate(xs):
out[self.bw_table(i)[0]] = t
return out.int() & 0xFFFF
def all_to_all_fw(self, input_tensor_list: List[Tensor], async_op: bool = False):
output_tensor_list: List[Tensor] = []
for i in range(self.num_parts):
t = input_tensor_list[i]
assert t.size(0) == self._fw_ptr[i+1] - self._fw_ptr[i]
s = self._bw_ptr[i+1] - self._bw_ptr[i]
output_tensor_list.append(
torch.empty(s, *t.shape[1:], dtype=t.dtype, device=t.device)
)
work = self.ctx.all_to_all_v(
output_tensor_list,
input_tensor_list,
group=self.group, async_op=async_op,
work = RouteWork(
work if async_op else None,
self._fw_ptr, self._fw_ind,
self.dst_len, output_tensor,
)
if async_op:
return output_tensor_list, work
else:
return output_tensor_list
return work if async_op else work.wait()
def all_to_all_bw(self, input_tensor_list: List[Tensor], async_op: bool = False):
output_tensor_list: List[Tensor] = []
for i in range(self.num_parts):
t = input_tensor_list[i]
assert t.size(0) == self._bw_ptr[i+1] - self._bw_ptr[i]
s = self._fw_ptr[i+1] - self._fw_ptr[i]
output_tensor_list.append(
torch.empty(s, *t.shape[1:], dtype=t.dtype, device=t.device)
)
work = self.ctx.all_to_all_v(
output_tensor_list,
input_tensor_list,
group=self.group, async_op=async_op,
@torch.no_grad()
def get_src_part_ids(self) -> Tensor:
input_tensor = torch.full_like(self._fw_ind, self.part_id)
output_tensor = torch.empty_like(self._bw_ind)
all_to_all_s(
output_tensor, input_tensor,
self._bw_ptr, self._fw_ptr,
group=self.group,
)
if async_op:
return output_tensor_list, work
else:
return output_tensor_list
out = torch.full(
(self.src_len,), 2**16-1,
dtype=self._bw_ind.dtype,
device=self._bw_ind.device,
)
for s, t in zip(self._bw_ptr, self._bw_ptr[1:]):
ind = self._bw_ind[s:t]
assert (out[ind] == 2**16-1).all(), f"some vertices exist in more than one partition"
out[ind] = output_tensor[s:t] & 0xFF
return out
@staticmethod
def _build_route_tables(
......@@ -275,8 +220,6 @@ class Route:
assert src_ids.dim() == 1
assert dst_ids.dim() == 1
ctx = DistributedContext.get_default_context()
src_len = src_ids.size(0)
dst_len = dst_ids.size(0)
......@@ -292,6 +235,7 @@ class Route:
all_dst_lens: List[int] = [None] * world_size
dist.all_gather_object(all_dst_lens, dst_len, group=group)
# all_reduce number of nodes
num_nodes = torch.zeros(1, **ikw)
if src_ids.numel() > 0:
num_nodes = src_ids.max().max(num_nodes)
......@@ -342,25 +286,54 @@ class Route:
src_ind = smp[ind]
dst_ind = xmp[ind]
# at this point only the backward route is correct; the forward tables still need to be sent to the partitions that own the src_ids
fw_tables.append(torch.vstack([dst_ind, src_ind]))
bw_tables.append(torch.vstack([src_ind, dst_ind]))
fw_tables = [t.t().contiguous() for t in fw_tables]
fw_tables = ctx.all_to_all_g(fw_tables, group=group)
fw_tables = [t.t().contiguous() for t in fw_tables]
fw_tables.append(dst_ind)
bw_tables.append(src_ind)
fw_tables = Route.__backward_fw_tables(fw_tables, group=group)
# non-bipartite graph: add a self-loop for every local vertex
if not bipartite:
rank_ind = torch.arange(dst_len, **ikw)
fw_tables[rank] = bw_tables[rank] = torch.vstack([rank_ind, rank_ind])
fw_tables[rank] = bw_tables[rank] = rank_ind
return fw_tables, bw_tables
@staticmethod
def __expand_index_to_mask(index: Tensor, len: int) -> Tensor:
mask = torch.zeros(len, dtype=torch.bool, device=index.device)
mask[index] = True
return mask
def __filter_ind_and_ptr(ptr: List[int], ind: Tensor, mask: Tensor) -> Tuple[List[int], Tensor]:
m = mask[ind]
new_ptr: List[int] = [0]
new_ind: List[Tensor] = []
for s, t in zip(ptr, ptr[1:]):
new_ind.append(ind[s:t][m[s:t]])
new_ptr.append(new_ptr[-1] + new_ind[-1].numel())
return new_ptr, torch.cat(new_ind, dim=0)
@staticmethod
def __remap_ind(ind: Tensor, mask: Tensor) -> Tuple[Tensor, int]:
n: int = mask.count_nonzero().item()
imp = torch.full((mask.numel(),), (2**62-1)*2+1, dtype=ind.dtype, device=ind.device)
imp[mask] = torch.arange(n, dtype=ind.dtype, device=ind.device)
return imp[ind], int(n)  # apply the compaction so the indices refer to the filtered id space
@staticmethod
def __backward_fw_tables(
fw_tables: List[Tensor],
group: Any,
) -> List[Tensor]:
rank = dist.get_rank(group)
world_size = dist.get_world_size(group)
send_sizes = [t.size() for t in fw_tables]
recv_sizes = [None] * world_size
dist.all_gather_object(recv_sizes, send_sizes, group=group)
recv_sizes = [s[rank] for s in recv_sizes]
fixed_tables = []
for s, t in zip(recv_sizes, fw_tables):
t = torch.empty(*s, dtype=t.dtype, device=t.device)
fixed_tables.append(t)
all_to_all_v(fixed_tables, fw_tables, group=group)
return fixed_tables
@staticmethod
def __tables_to_indptr(
......@@ -369,21 +342,101 @@ class Route:
):
fw_ptr: List[int] = [0]
for t in fw_tables:
last_n = fw_ptr[-1]
fw_ptr.append(last_n + t.size(1))
fw_ind = torch.cat(fw_tables, dim=1)
assert t.dim() == 1
fw_ptr.append(fw_ptr[-1] + t.numel())
fw_ind = torch.cat(fw_tables, dim=0)
bw_ptr: List[int] = [0]
for t in bw_tables:
last_n = bw_ptr[-1]
bw_ptr.append(last_n + t.size(1))
bw_ind = torch.cat(bw_tables, dim=1)
assert t.dim() == 1
bw_ptr.append(bw_ptr[-1] + t.numel())
bw_ind = torch.cat(bw_tables, dim=0)
return {
"fw_ptr": fw_ptr, "bw_ptr": bw_ptr,
"fw_ind": fw_ind, "bw_ind": bw_ind,
}
class RouteWork:
def __init__(self,
work: Any,
ptr: List[int],
ind: Tensor,
len: int,
recv_t: Tensor,
) -> None:
self._work = work
self._ptr = ptr
self._ind = ind
self._len = len
self._recv_t = recv_t
if self._work is None:
self._reduce()
def _reduce(self):
out = torch.zeros(
self._len, *self._recv_t.shape[1:],
dtype=self._recv_t.dtype,
device=self._recv_t.device,
)
for s, t in zip(self._ptr, self._ptr[1:]):
ind = self._ind[s:t]
out[ind] += self._recv_t[s:t]
self._work = None
self._ptr = None
self._ind = None
self._len = None
self._recv_t = out
def wait(self) -> Tensor:
if self._work is None:
return self._recv_t
self._work.wait()
self._reduce()
return self._recv_t
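RouteWork defers the scatter-add until wait(), which lets communication overlap with local work; a rough usage sketch (route is a Route instance, x a dst_len-sized tensor):
work = route.fw_tensor(x, async_op=True)   # launches the all_to_all and returns immediately
# ... do unrelated local computation here ...
y = work.wait()                            # reduces the received chunks into a src_len-sized tensor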
class RouteWorkCache:
def __init__(self,
enable_fw: bool = True,
enable_bw: bool = True,
) -> None:
self.enable_fw = enable_fw
self.enable_bw = enable_bw
self._cached_works: Dict[str, RouteWork] = {}
def enable_fw_(self, enable: bool = True):
self.enable_fw = enable
return self
def enable_bw_(self, enable: bool = True):
self.enable_bw = enable
return self
def wait(self):
for work in self._cached_works.values():
work.wait()
def clear(self):
self._cached_works.clear()
def get_and_set(self,
key: str,
work: RouteWork,
bw: bool = False,
) -> Optional[RouteWork]:
if bw and self.enable_bw:
key = key + "_bw"
elif not bw and self.enable_fw:
key = key + "_fw"
else:
return work
t = self._cached_works.get(key, work)
self._cached_works[key] = work
return t
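get_and_set stores the freshly launched work under the key and hands back whatever was cached there before, so when a cache is threaded through Route.apply the result consumed at each step appears to be the exchange launched on the previous call (a one-step-stale overlap); a hedged sketch with placeholder names:
cache = RouteWorkCache()
for step in range(num_steps):              # num_steps and x are placeholders
    h = route.apply(x, cache, "layer0")    # this step's comm overlaps the next step's compute
cache.wait()                               # flush any work still in flight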
class RouteAlltoAll(autograd.Function):
@staticmethod
......@@ -391,10 +444,18 @@ class RouteAlltoAll(autograd.Function):
ctx: autograd.function.FunctionCtx,
x: Tensor,
route: Route,
cache: Optional[RouteWorkCache],
cache_key: Optional[str],
):
ctx.saved_route = route
return route.fw_tensor(x)
ctx.saved_cache = cache
ctx.saved_cache_key = cache_key
if cache is None or cache_key is None:
return route.fw_tensor(x)
else:
work = route.fw_tensor(x, async_op=True)
work = cache.get_and_set(cache_key, work, bw=False)
return work.wait()
@staticmethod
def backward(
......@@ -402,4 +463,13 @@ class RouteAlltoAll(autograd.Function):
grad: Tensor,
) -> Tensor:
route: Route = ctx.saved_route
return route.bw_tensor(grad), None
\ No newline at end of file
cache: Optional[RouteWorkCache] = ctx.saved_cache
cache_key: Optional[str] = ctx.saved_cache_key
if cache is None or cache_key is None:
return route.bw_tensor(grad), None, None, None
else:
work = route.bw_tensor(grad, async_op=True)
work = cache.get_and_set(cache_key, work, bw=True)
return work.wait(), None, None, None
\ No newline at end of file
import torch
import torch.distributed as dist
from torch import Tensor
from typing import *
def init_vc_edge_index(
dst_ids: Tensor,
edge_index: Tensor,
bipartite: bool = True,
) -> Tuple[Tensor, Tensor]:
ikw = dict(dtype=torch.long, device=dst_ids.device)
local_num_nodes = torch.zeros(1, **ikw)
if dst_ids.numel() > 0:
local_num_nodes = dst_ids.max().max(local_num_nodes)
if edge_index.numel() > 0:
local_num_nodes = edge_index.max().max(local_num_nodes)
local_num_nodes = local_num_nodes.item() + 1
xmp: Tensor = torch.zeros(local_num_nodes, **ikw)
xmp[edge_index[1].unique()] += 0b01
xmp[dst_ids.unique()] += 0b10
if not (xmp != 0x01).all():
raise RuntimeError(f"must be vertex-cut partition graph")
if bipartite:
src_ids = edge_index[0].unique()
else:
xmp.fill_(0)
xmp[edge_index[0]] = 1
xmp[dst_ids] = 0
src_ids = torch.cat([dst_ids, torch.where(xmp > 0)[0]], dim=-1)
xmp.fill_((2**62-1)*2+1)
xmp[src_ids] = torch.arange(src_ids.size(0), **ikw)
src = xmp[edge_index[0]]
xmp.fill_((2**62-1)*2+1)
xmp[dst_ids] = torch.arange(dst_ids.size(0), **ikw)
dst = xmp[edge_index[1]]
local_edge_index = torch.vstack([src, dst])
return src_ids, local_edge_index
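A tiny illustrative call (global vertex ids are made up):
import torch
dst_ids = torch.tensor([5, 8])                 # global ids owned by this partition
edge_index = torch.tensor([[3, 5], [5, 8]])    # edges 3->5 and 5->8, both dst-local
src_ids, local_edge_index = init_vc_edge_index(dst_ids, edge_index)
# src_ids == [3, 5]; local_edge_index == [[0, 1], [0, 1]]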