Commit 9586742a by Wenjie Huang

Fix bugs in Route

parent 10c38111
@@ -4,7 +4,7 @@ from torch_geometric.utils import add_remaining_self_loops, to_undirected
import os.path as osp
import sys
-from starrygl.utils.data import partition_pyg
from starrygl.graph import GraphData
import logging
logging.getLogger().setLevel(logging.INFO)
@@ -18,7 +18,9 @@ if __name__ == "__main__":
print(f"num_nodes: {data.num_nodes}")
print(f"num_edges: {data.num_edges}")
print(f"num_features: {data.num_features}")
data = GraphData.from_pyg_data(data)
num_parts_list = [1, 2, 3, 5, 7, 9, 11]
algos = ["metis", "mt-metis", "random"]
@@ -27,4 +29,4 @@ if __name__ == "__main__":
for num_parts in num_parts_list:
for algo in algos:
print(f"======== {num_parts} + {algo} ========")
-partition_pyg(root, data, num_parts, algo)
data.save_partition(root, num_parts, algo)
\ No newline at end of file
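For reference, the partitions written by the loop above can be read back one part at a time with the GraphData.load_partition helper added in this commit. A minimal sketch, reusing the script's root variable and one of the algorithm/num_parts combinations it iterates over:

# minimal sketch: read back one saved partition
from starrygl.graph import GraphData

g = GraphData.load_partition(root, part_id=0, num_parts=3, algo="metis")
print(g.node("dst")["raw_ids"].numel(), g.edge_index().size(1))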
@@ -5,7 +5,7 @@ from torch import Tensor
from typing import *
from starrygl.distributed import DistributedContext
-from starrygl.graph import new_vc_route
from starrygl.graph import *
from torch_scatter import scatter_sum
@@ -28,32 +28,38 @@ all_eparts = [
],
]
-def get_route(bipartite: bool = True):
def get_data():
ctx = DistributedContext.get_default_context()
assert ctx.world_size == 3
dst_ids = torch.tensor(all_nparts[ctx.rank], dtype=torch.long, device=ctx.device)
edge_index = torch.tensor(all_eparts[ctx.rank], dtype=torch.long, device=ctx.device).t()
-return new_vc_route(dst_ids, edge_index, bipartite=bipartite)
src_ids, edge_index = init_vc_edge_index(dst_ids, edge_index)
return GraphData.from_bipartite(edge_index, raw_src_ids=src_ids, raw_dst_ids=dst_ids)
if __name__ == "__main__":
ctx = DistributedContext.init(backend="gloo", use_gpu=True)
-src_ids, edge_index, dst_ids, route = get_route(False)
-src_size = route.src_len
-dst_size = route.dst_len
g = get_data()
route = g.to_route()
edge_index = g.edge_index()
# src_ids, edge_index, dst_ids, route = get_route(False)
# src_size = route.src_len
# dst_size = route.dst_len
ctx.sync_print(route.src_len, route.dst_len)
ctx.sync_print(route._fw_ptr, route._fw_ind)
ctx.sync_print(route._bw_ptr, route._bw_ind)
edge_ones = torch.ones(edge_index.size(1), device=ctx.device).requires_grad_()
src_ones = scatter_sum(edge_ones, edge_index[0], dim=0, dim_size=route.src_len)
dst_ones = scatter_sum(edge_ones, edge_index[1], dim=0, dim_size=route.dst_len)
-# ctx.sync_print(route.fw_tensor(dst_ones))
-# ctx.sync_print(route.bw_tensor(src_ones))
ctx.sync_print(route.fw_tensor(dst_ones))
ctx.sync_print(route.bw_tensor(src_ones))
-out = route.reverse_route().apply(src_ones)
out = route.rev().apply(src_ones)
ctx.sync_print(out)
out.sum().backward()
@@ -61,4 +67,13 @@ if __name__ == "__main__":
ctx.sync_print(route.get_src_part_ids())
dst_mask = torch.full((route.dst_len,), ctx.rank % 2, dtype=torch.bool, device=ctx.device)
ctx.main_print("="*64)
ctx.sync_print(dst_mask)
_, _, r2 = route.filter(dst_mask)
ctx.sync_print(r2.apply(dst_ones).detach())
ctx.sync_print(r2.rev().apply(src_ones).detach())
# dst_true = torch.ones(route.dst_len, dtype=torch.float, device=ctx.device)
# ctx.sync_print(route.fw_tensor(dst_true, "max"))
ctx.shutdown()
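The block below is an illustrative sketch, not part of the commit, showing how Route.apply typically slots into a message-passing step on the same toy partition: it moves dst-side features onto the local src copies with autograd support, then aggregates along the local edges just as the ones-tensors above do.

# illustrative only; reuses ctx, route and edge_index from the test above
feat_dim = 8                                  # arbitrary feature width
x_dst = torch.randn(route.dst_len, feat_dim, device=ctx.device, requires_grad=True)
x_src = route.apply(x_dst)                    # dst -> src halo exchange (differentiable)
msg = x_src[edge_index[0]]                    # per-edge source features
out = scatter_sum(msg, edge_index[1], dim=0, dim_size=route.dst_len)
out.sum().backward()                          # gradients flow back through the route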
@@ -45,8 +45,6 @@ def all_to_all_v(
assert len(output_tensor_list) == world_size
assert len(input_tensor_list) == world_size
-# if group is None:
-# group = dist.distributed_c10d._get_default_group()
backend = dist.get_backend(group)
if backend == "nccl":
...
-from .route import Route
-from .utils import init_vc_edge_index
from .data import *
from .route import *
\ No newline at end of file
from torch import Tensor
from typing import Tuple
__all__ = [
"Route",
"init_vc_edge_index",
"new_vc_route",
]
def new_vc_route(
dst_ids: Tensor,
edge_index: Tensor,
bipartite: bool = True
) -> Tuple[Tensor, Tensor, Tensor, Route]:
src_ids, local_edge_index = init_vc_edge_index(
dst_ids, edge_index, bipartite=bipartite)
route = Route.from_raw_indices(
src_ids, dst_ids, bipartite=bipartite)
return src_ids, local_edge_index, dst_ids, route
@@ -3,19 +3,18 @@ import torch
from torch import Tensor
from typing import *
import os
import os.path as osp
import shutil
from pathlib import Path
from torch_sparse import SparseTensor
-from .partition import *
from starrygl.utils.partition import *
from .route import Route
import logging
__all__ = [
"GraphData",
-"partition_pyg",
"init_vc_edge_index",
-"partition_load",
]
@@ -75,6 +74,17 @@ class GraphData:
return data
return self._edge_indices[edge_type]
def node_types(self) -> List[str]:
return list(self._node_data.keys())
def edge_types(self) -> List[Tuple[str, str, str]]:
return list(self._edge_data.keys())
def to_route(self, group: Any = None) -> Route:
src_ids = self.node("src")["raw_ids"]
dst_ids = self.node("dst")["raw_ids"]
return Route.from_raw_indices(src_ids, dst_ids, group=group)
@property
def is_heterogeneous(self) -> bool:
return self._heterogeneous
@@ -91,6 +101,126 @@ class GraphData:
self._edge_indices = {k:v.to(device) for k,v in self._edge_indices.items()}
return self
@staticmethod
def from_bipartite(
edge_index: Tensor,
num_src_nodes: Optional[int] = None,
num_dst_nodes: Optional[int] = None,
raw_src_ids: Optional[Tensor] = None,
raw_dst_ids: Optional[Tensor] = None,
) -> 'GraphData':
if num_src_nodes is None:
num_src_nodes = raw_src_ids.numel()
if num_dst_nodes is None:
num_dst_nodes = raw_dst_ids.numel()
g = GraphData(
edge_indices={
("src", "@", "dst"): edge_index,
},
num_nodes={
"src": num_src_nodes,
"dst": num_dst_nodes,
}
)
if raw_src_ids is not None:
g.node("src")["raw_ids"] = raw_src_ids
if raw_dst_ids is not None:
g.node("dst")["raw_ids"] = raw_dst_ids
return g
@staticmethod
def from_pyg_data(data) -> 'GraphData':
from torch_geometric.data import Data
assert isinstance(data, Data), f"must be Data class in pyg"
g = GraphData(data.edge_index, data.num_nodes)
for key, val in data:
if key == "edge_index":
continue
elif isinstance(val, Tensor):
if val.size(0) == data.num_nodes:
g.node()[key] = val
elif val.size(0) == data.num_edges:
g.edge()[key] = val
elif isinstance(val, SparseTensor):
logging.warning(f"found sparse matrix {key}, but ignored.")
else:
g.meta()[key] = val
return g
@staticmethod
def load_partition(root: str, part_id: int, num_parts: int, algo: str = "metis") -> 'GraphData':
p = Path(root).expanduser().resolve() / f"{algo}_{num_parts}" / f"{part_id:03d}"
return torch.load(p.__str__())
def save_partition(self, root: str, num_parts: int, algo: str = "metis"):
assert not self.is_heterogeneous, "only supports homogeneous graphs"
num_nodes: int = self.node().num_nodes
edge_index: Tensor = self.edge_index()
logging.info(f"running partition aglorithm: {algo}")
if algo == "metis":
node_parts = metis_partition(edge_index, num_nodes, num_parts)
elif algo == "mt-metis":
node_parts = mt_metis_partition(edge_index, num_nodes, num_parts)
elif algo == "random":
node_parts = random_partition(edge_index, num_nodes, num_parts)
else:
raise ValueError(f"unknown partition algorithm: {algo}")
root_path = Path(root).expanduser().resolve()
base_path = root_path / f"{algo}_{num_parts}"
if base_path.exists():
logging.warning(f"directory '{base_path.__str__()}' exists, and will be removed.")
shutil.rmtree(base_path.__str__())
base_path.mkdir(parents=True)
for i in range(num_parts):
npart_mask = node_parts == i
epart_mask = npart_mask[edge_index[1]]
raw_dst_ids: Tensor = torch.where(npart_mask)[0]
local_edges = edge_index[:, epart_mask]
raw_src_ids, local_edges = init_vc_edge_index(
raw_dst_ids, local_edges, bipartite=True,
)
# g = GraphData(
# edge_indices={
# ("src", "@", "dst"): local_edges,
# },
# num_nodes={
# "src": raw_src_ids.numel(),
# "dst": raw_dst_ids.numel(),
# }
# )
g = GraphData.from_bipartite(
local_edges,
raw_src_ids=raw_src_ids,
raw_dst_ids=raw_dst_ids,
)
for key in self.node().keys():
g.node("dst")[key] = self.node()[key][npart_mask]
for key in self.edge().keys():
g.edge()[key] = self.edge()[key][epart_mask]
for key in self.meta().keys():
g.meta()[key] = self.meta()[key]
logging.info(f"saving partition data: {i+1}/{num_parts}")
torch.save(g, (base_path / f"{i:03d}").__str__())
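Taken together, from_pyg_data, save_partition, load_partition and to_route form the intended pipeline. The sketch below is a hedged outline of that flow; pyg_data, rank, the directory and the choice of three parts are all assumptions:

# offline, single process: convert a PyG Data object and write the partitions
g = GraphData.from_pyg_data(pyg_data)
g.save_partition("./partitions", num_parts=3, algo="metis")

# later, on each rank of a 3-process job:
local = GraphData.load_partition("./partitions", part_id=rank, num_parts=3, algo="metis")
route = local.to_route()                       # Route rebuilt from the stored raw_ids
y_dst = local.node("dst")["y"]                 # present only if the PyG data carried a node-level "y"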
class MetaData:
def __init__(self) -> None:
@@ -193,97 +323,137 @@ class EdgeData:
return self
-def partition_load(root: str, part_id: int, num_parts: int, algo: str = "metis") -> GraphData:
-p = Path(root).expanduser().resolve() / f"{algo}_{num_parts}" / f"{part_id:03d}"
-return torch.load(p.__str__())
def init_vc_edge_index(
dst_ids: Tensor,
edge_index: Tensor,
bipartite: bool = True,
) -> Tuple[Tensor, Tensor]:
ikw = dict(dtype=torch.long, device=dst_ids.device)
local_num_nodes = torch.zeros(1, **ikw)
if dst_ids.numel() > 0:
local_num_nodes = dst_ids.max().max(local_num_nodes)
if edge_index.numel() > 0:
local_num_nodes = edge_index.max().max(local_num_nodes)
local_num_nodes = local_num_nodes.item() + 1
xmp: Tensor = torch.zeros(local_num_nodes, **ikw)
xmp[edge_index[1].unique()] += 0b01
xmp[dst_ids.unique()] += 0b10
if not (xmp != 0x01).all():
raise RuntimeError(f"must be vertex-cut partition graph")
if bipartite:
src_ids = edge_index[0].unique()
else:
xmp.fill_(0)
xmp[edge_index[0]] = 1
xmp[dst_ids] = 0
src_ids = torch.cat([dst_ids, torch.where(xmp > 0)[0]], dim=-1)
xmp.fill_((2**62-1)*2+1)
xmp[src_ids] = torch.arange(src_ids.size(0), **ikw)
src = xmp[edge_index[0]]
xmp.fill_((2**62-1)*2+1)
xmp[dst_ids] = torch.arange(dst_ids.size(0), **ikw)
dst = xmp[edge_index[1]]
local_edge_index = torch.vstack([src, dst])
return src_ids, local_edge_index
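A small worked example of init_vc_edge_index on global ids: every edge must already end at a locally owned dst vertex (the vertex-cut check), and the returned edge index is renumbered into local positions.

# toy vertex-cut fragment: this partition owns dst nodes 3 and 5
dst_ids = torch.tensor([3, 5])
edge_index = torch.tensor([[0, 7, 3],          # global src ids
                           [3, 5, 5]])         # global dst ids, all owned locally
src_ids, local_edge_index = init_vc_edge_index(dst_ids, edge_index, bipartite=True)
# src_ids          -> tensor([0, 3, 7])
# local_edge_index -> tensor([[0, 2, 1],
#                             [0, 1, 1]])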
-def partition_pyg(root: str, data, num_parts: int, algo: str = "metis"):
-root_path = Path(root).expanduser().resolve()
-base_path = root_path / f"{algo}_{num_parts}"
-if base_path.exists():
-shutil.rmtree(base_path.__str__())
-base_path.mkdir(parents=True, exist_ok=True)
-for i, g in enumerate(partition_pyg_data(data, num_parts, algo)):
-logging.info(f"saving partition data: {i+1}/{num_parts}")
-torch.save(g, (base_path / f"{i:03d}").__str__())
-def partition_pyg_data(data, num_parts: int, algo: str = "metis") -> Iterator[GraphData]:
-from torch_geometric.data import Data
-assert isinstance(data, Data), f"must be Data class in pyg"
-logging.info(f"running partition aglorithm: {algo}")
-num_nodes: int = data.num_nodes
-num_edges: int = data.num_edges
-edge_index: Tensor = data.edge_index
-if algo == "metis":
-node_parts = metis_partition(edge_index, num_nodes, num_parts)
-elif algo == "mt-metis":
-node_parts = mt_metis_partition(edge_index, num_nodes, num_parts)
-elif algo == "random":
-node_parts = random_partition(edge_index, num_nodes, num_parts)
-else:
-raise ValueError(f"unknown partition algorithm: {algo}")
-if data.y.dtype == torch.long:
-if data.y.dim() == 1:
-num_classes = data.y.max().item() + 1
-else:
-num_classes = data.y.size(1)
-else:
-num_classes = None
-for i in range(num_parts):
-npart_mask = node_parts == i
-epart_mask = npart_mask[edge_index[1]]
-local_edges = edge_index[:, epart_mask]
-raw_src_ids: Tensor = local_edges[0].unique()
-raw_dst_ids: Tensor = torch.where(npart_mask)[0]
-M: int = raw_src_ids.max().item() + 1
-imap = torch.full((M,), (2**62-1)*2+1).type_as(raw_src_ids)
-imap[raw_src_ids] = torch.arange(raw_src_ids.numel()).type_as(raw_src_ids)
-local_src = imap[local_edges[0]]
-M: int = raw_dst_ids.max().item() + 1
-imap = torch.full((M,), (2**62-1)*2+1).type_as(raw_dst_ids)
-imap[raw_dst_ids] = torch.arange(raw_dst_ids.numel()).type_as(raw_dst_ids)
-local_dst = imap[local_edges[1]]
-local_edges = torch.vstack([local_src, local_dst])
-g = GraphData(
-edge_indices={
-("src", "@", "dst"): local_edges,
-},
-num_nodes={
-"src": raw_src_ids.numel(),
-"dst": raw_dst_ids.numel(),
-},
-)
-g.node("src")["raw_ids"] = raw_src_ids
-g.node("dst")["raw_ids"] = raw_dst_ids
-if num_classes is not None:
-g.meta()["num_classes"] = num_classes
-for key, val in data:
-if key == "edge_index":
-continue
-elif isinstance(val, Tensor):
-if val.size(0) == num_nodes:
-g.node("dst")[key] = val[npart_mask]
-elif val.size(0) == num_edges:
-g.edge()[key] = val[epart_mask]
-elif isinstance(val, SparseTensor):
-pass
-else:
-g.meta()[key] = val
-yield g
# def partition_load(root: str, part_id: int, num_parts: int, algo: str = "metis") -> GraphData:
# p = Path(root).expanduser().resolve() / f"{algo}_{num_parts}" / f"{part_id:03d}"
# return torch.load(p.__str__())
# def partition_pyg(root: str, data, num_parts: int, algo: str = "metis"):
# root_path = Path(root).expanduser().resolve()
# base_path = root_path / f"{algo}_{num_parts}"
# if base_path.exists():
# shutil.rmtree(base_path.__str__())
# base_path.mkdir(parents=True, exist_ok=True)
# for i, g in enumerate(partition_pyg_data(data, num_parts, algo)):
# logging.info(f"saving partition data: {i+1}/{num_parts}")
# torch.save(g, (base_path / f"{i:03d}").__str__())
# def partition_pyg_data(data, num_parts: int, algo: str = "metis") -> Iterator[GraphData]:
# from torch_geometric.data import Data
# assert isinstance(data, Data), f"must be Data class in pyg"
# logging.info(f"running partition aglorithm: {algo}")
# num_nodes: int = data.num_nodes
# num_edges: int = data.num_edges
# edge_index: Tensor = data.edge_index
# if algo == "metis":
# node_parts = metis_partition(edge_index, num_nodes, num_parts)
# elif algo == "mt-metis":
# node_parts = mt_metis_partition(edge_index, num_nodes, num_parts)
# elif algo == "random":
# node_parts = random_partition(edge_index, num_nodes, num_parts)
# else:
# raise ValueError(f"unknown partition algorithm: {algo}")
# if data.y.dtype == torch.long:
# if data.y.dim() == 1:
# num_classes = data.y.max().item() + 1
# else:
# num_classes = data.y.size(1)
# else:
# num_classes = None
# for i in range(num_parts):
# npart_mask = node_parts == i
# epart_mask = npart_mask[edge_index[1]]
# local_edges = edge_index[:, epart_mask]
# raw_src_ids: Tensor = local_edges[0].unique()
# raw_dst_ids: Tensor = torch.where(npart_mask)[0]
# M: int = raw_src_ids.max().item() + 1
# imap = torch.full((M,), (2**62-1)*2+1).type_as(raw_src_ids)
# imap[raw_src_ids] = torch.arange(raw_src_ids.numel()).type_as(raw_src_ids)
# local_src = imap[local_edges[0]]
# M: int = raw_dst_ids.max().item() + 1
# imap = torch.full((M,), (2**62-1)*2+1).type_as(raw_dst_ids)
# imap[raw_dst_ids] = torch.arange(raw_dst_ids.numel()).type_as(raw_dst_ids)
# local_dst = imap[local_edges[1]]
# local_edges = torch.vstack([local_src, local_dst])
# g = GraphData(
# edge_indices={
# ("src", "@", "dst"): local_edges,
# },
# num_nodes={
# "src": raw_src_ids.numel(),
# "dst": raw_dst_ids.numel(),
# },
# )
# g.node("src")["raw_ids"] = raw_src_ids
# g.node("dst")["raw_ids"] = raw_dst_ids
# if num_classes is not None:
# g.meta()["num_classes"] = num_classes
# for key, val in data:
# if key == "edge_index":
# continue
# elif isinstance(val, Tensor):
# if val.size(0) == num_nodes:
# g.node("dst")[key] = val[npart_mask]
# elif val.size(0) == num_edges:
# g.edge()[key] = val[epart_mask]
# elif isinstance(val, SparseTensor):
# pass
# else:
# g.meta()[key] = val
# yield g
@@ -2,13 +2,18 @@ import torch
import torch.autograd as autograd
import torch.distributed as dist
-from multiprocessing.pool import ThreadPool
from torch import Tensor
from typing import *
-from starrygl.distributed import DistributedContext
from starrygl.distributed.cclib import all_to_all_s, all_to_all_v
__all__ = [
"Route",
"RouteWork",
"RouteWorkCache",
"RouteAlltoAll",
]
class Route:
@@ -31,50 +36,51 @@ class Route:
group=group,
)
-def subroute(self,
-dst_mask: Tensor,
-src_mask: Optional[Tensor] = None,
-):
-if dst_mask.dtype != torch.bool:
-dst_mask = Route.__expand_index_to_mask(dst_mask, self.dst_len)
-assert dst_mask.size(0) == self.dst_len
-if src_mask is None:
-fw_dst_masks, work = self.all_to_all_fw(
-input_tensor_list=[dst_mask[self.fw_table(i)[0]] for i in range(self.num_parts)],
-async_op=True,
-)
-else:
-if src_mask.dtype != torch.bool:
-src_mask = Route.__expand_index_to_mask(src_mask, self.src_len)
-assert src_mask.size(0) == self.src_len
-fw_tables: List[Tensor] = []
-for i in range(self.num_parts):
-m = dst_mask[self.fw_table(i)[0]]
-fw_tables.append(self.fw_table(i)[:,m])
-bw_tables: List[Tensor] = []
-if src_mask is None:
-src_mask = torch.zeros(self.src_len, dtype=dst_mask.dtype, device=dst_mask.device)
-work.wait()
-for i, m in enumerate(fw_dst_masks):
-src_mask[self.bw_table(i)[0]] |= m
-bw_tables.append(self.bw_table(i)[:,m])
-else:
-for i in range(self.num_parts):
-m = src_mask[self.bw_table(i)[0]]
-bw_tables.append(self.bw_table(i)[:,m])
-return dst_mask, Route(
-src_len=self.src_len,
-dst_len=self.dst_len,
-**Route.__tables_to_indptr(fw_tables, bw_tables),
-group=self.group,
-), src_mask, dst_mask
def filter(self,
dst_mask: Optional[Tensor] = None,
src_mask: Optional[Tensor] = None,
remap: bool = False,
):
if dst_mask is None:
if src_mask is None:
raise ValueError("please provide at least one parameter.")
else:
assert src_mask.dtype == torch.bool
assert src_mask.numel() == self.src_len
dst_mask = self.bw_tensor(src_mask.long()) != 0
else:
assert dst_mask.dtype == torch.bool
assert dst_mask.numel() == self.dst_len
tmp_src_mask = self.fw_tensor(dst_mask.long()) != 0
if src_mask is None:
src_mask = tmp_src_mask
else:
tmp_src_mask &= src_mask
src_mask = tmp_src_mask
dst_mask = self.bw_tensor(src_mask.long()) != 0
fw_ptr, fw_ind = Route.__filter_ind_and_ptr(self._fw_ptr, self._fw_ind, dst_mask)
bw_ptr, bw_ind = Route.__filter_ind_and_ptr(self._bw_ptr, self._bw_ind, src_mask)
route = Route(
src_len=self.src_len,
dst_len=self.dst_len,
fw_ptr=fw_ptr, fw_ind=fw_ind,
bw_ptr=bw_ptr, bw_ind=bw_ind,
group=self.group,
)
if remap:
fw_ind, dst_len = Route.__remap_ind(route._fw_ind, dst_mask)
bw_ind, src_len = Route.__remap_ind(route._bw_ind, src_mask)
route = Route(
src_len=src_len,
dst_len=dst_len,
fw_ptr=route._fw_ptr, fw_ind=fw_ind,
bw_ptr=route._bw_ptr, bw_ind=bw_ind,
group=route.group,
)
return dst_mask, src_mask, route
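A usage sketch of the new filter API, following the test file above (the random mask and feature width are arbitrary); it returns the possibly expanded dst/src masks together with the restricted route:

# illustrative only: restrict an existing route to roughly half of the local dst vertices
keep = torch.rand(route.dst_len, device=ctx.device) < 0.5
dst_mask, src_mask, sub = route.filter(dst_mask=keep)
x_dst = torch.randn(route.dst_len, 4, device=ctx.device)
y = sub.apply(x_dst)            # still (src_len, ...); remap=True is meant to compact the index space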
-def reverse_route(self):
def rev(self):
return Route(
src_len=self.dst_len,
dst_len=self.src_len,
@@ -83,62 +89,21 @@ class Route:
group=self.group,
)
def filter_nodes(self,
dst_mask: Tensor,
src_mask: Optional[Tensor] = None,
):
if dst_mask.dtype != torch.bool:
dst_mask = Route.__expand_index_to_mask(dst_mask, self.dst_len)
assert dst_mask.size(0) == self.dst_len
new_dst_len = dst_mask.count_nonzero().item()
xmp = torch.empty(self.dst_len, dtype=torch.long, device=dst_mask.device)
xmp.fill_((2**62-1)*2+1)
xmp[dst_mask] = torch.arange(new_dst_len, dtype=torch.long, device=dst_mask.device)
fw_dst_masks, fw_m_work = self.all_to_all_fw(
input_tensor_list=[dst_mask[self.fw_table(i)[0]] for i in range(self.num_parts)],
async_op=True,
)
# fw_dst_inds, fw_i_work =
# if src_mask is None:
# fw_dst_masks, work = self.all_to_all_fw(
# input_tensor_list=[dst_mask[self.fw_table(i)[0]] for i in range(self.num_parts)],
# async_op=True,
# )
# else:
# if src_mask.dtype != torch.bool:
# src_mask = Route.__expand_index_to_mask(src_mask, self.src_len)
# assert src_mask.size(0) == self.src_len
def filter_edges(self,
edge_mask: Tensor,
edge_index: Tensor,
):
pass
def __init__(self,
src_len: int, dst_len: int,
fw_ptr: List[int], fw_ind: Tensor,
bw_ptr: List[int], bw_ind: Tensor,
group: Any,
) -> None:
-self._ctx = DistributedContext.get_default_context()
assert len(fw_ptr) == len(bw_ptr)
self._src_len = src_len
self._dst_len = dst_len
-self._fw_ptr = fw_ptr
self._fw_ptr = tuple(fw_ptr)
self._fw_ind = fw_ind
-self._bw_ptr = bw_ptr
self._bw_ptr = tuple(bw_ptr)
self._bw_ind = bw_ind
self._group = group
@property
def ctx(self):
return self._ctx
@property
def group(self):
return self._group
@@ -164,104 +129,84 @@ class Route:
self._bw_ind = self._bw_ind.to(device)
return self
-def fw_table(self, i: int):
-return self._fw_ind[:,self._fw_ptr[i]:self._fw_ptr[i+1]]
-def bw_table(self, i: int):
-return self._bw_ind[:,self._bw_ptr[i]:self._bw_ptr[i+1]]
# def fw_table(self, i: int):
# return self._fw_ind[self._fw_ptr[i]:self._fw_ptr[i+1]]
# def bw_table(self, i: int):
# return self._bw_ind[self._bw_ptr[i]:self._bw_ptr[i+1]]
-def apply(self, data: Tensor) -> Tensor:
-return RouteAlltoAll.apply(data, self)
def apply(self,
data: Tensor,
cache: Optional['RouteWorkCache'] = None,
cache_key: Optional[str] = None,
) -> Tensor:
return RouteAlltoAll.apply(data, self, cache, cache_key)
-def fw_tensor(self, data: Tensor) -> Tensor:
-assert data.size(0) == self.dst_len
-xs = self.all_to_all_fw(
-[data[self.fw_table(i)[0]] for i in range(self.num_parts)],
-async_op=False,
-)
-out = torch.zeros(
-self.src_len, *data.shape[1:],
-dtype=data.dtype, device=data.device,
-)
-for i, t in enumerate(xs):
-out[self.bw_table(i)[0]] += t
-return out
-def bw_tensor(self, data: Tensor) -> Tensor:
-assert data.size(0) == self.src_len
-xs = self.all_to_all_bw(
-[data[self.bw_table(i)[0]] for i in range(self.num_parts)],
-async_op=False,
-)
-out = torch.zeros(
-self.dst_len, *data.shape[1:],
-dtype=data.dtype, device=data.device,
-)
-for i, t in enumerate(xs):
-out[self.fw_table(i)[0]] += t
-return out
-def get_src_part_ids(self) -> Tensor:
-xs = self.all_to_all_fw(
-[
-torch.full(
-(self.fw_table(i).size(1),), self.part_id,
-dtype=torch.long, device=self.ctx.device,
-)
-for i in range(self.num_parts)
-],
-async_op=False,
-)
-out = torch.full((self.src_len,), 2**16-1, dtype=torch.long, device=self.ctx.device)
-for i, t in enumerate(xs):
-out[self.bw_table(i)[0]] = t
-return out.int() & 0xFFFF
-def all_to_all_fw(self, input_tensor_list: List[Tensor], async_op: bool = False):
-output_tensor_list: List[Tensor] = []
-for i in range(self.num_parts):
-t = input_tensor_list[i]
-assert t.size(0) == self._fw_ptr[i+1] - self._fw_ptr[i]
-s = self._bw_ptr[i+1] - self._bw_ptr[i]
-output_tensor_list.append(
-torch.empty(s, *t.shape[1:], dtype=t.dtype, device=t.device)
-)
-work = self.ctx.all_to_all_v(
-output_tensor_list,
-input_tensor_list,
-group=self.group, async_op=async_op,
-)
-if async_op:
-return output_tensor_list, work
-else:
-return output_tensor_list
-def all_to_all_bw(self, input_tensor_list: List[Tensor], async_op: bool = False):
-output_tensor_list: List[Tensor] = []
-for i in range(self.num_parts):
-t = input_tensor_list[i]
-assert t.size(0) == self._bw_ptr[i+1] - self._bw_ptr[i]
-s = self._fw_ptr[i+1] - self._fw_ptr[i]
-output_tensor_list.append(
-torch.empty(s, *t.shape[1:], dtype=t.dtype, device=t.device)
-)
-work = self.ctx.all_to_all_v(
-output_tensor_list,
-input_tensor_list,
-group=self.group, async_op=async_op,
-)
-if async_op:
-return output_tensor_list, work
-else:
-return output_tensor_list
@torch.no_grad()
def fw_tensor(self, data: Tensor, async_op: bool = False):
assert data.size(0) == self.dst_len
output_tensor = torch.empty(
self._bw_ind.numel(), *data.shape[1:],
dtype=data.dtype, device=data.device,
)
work = all_to_all_s(
output_tensor, data[self._fw_ind],
self._bw_ptr, self._fw_ptr,
group=self.group,
async_op=async_op,
)
work = RouteWork(
work if async_op else None,
self._bw_ptr, self._bw_ind,
self.src_len, output_tensor,
)
return work if async_op else work.wait()
@torch.no_grad()
def bw_tensor(self, data: Tensor, async_op: bool = False):
assert data.size(0) == self.src_len
output_tensor = torch.empty(
self._fw_ind.numel(), *data.shape[1:],
dtype=data.dtype, device=data.device,
)
work = all_to_all_s(
output_tensor, data[self._bw_ind],
self._fw_ptr, self._bw_ptr,
group=self.group,
async_op=async_op,
)
work = RouteWork(
work if async_op else None,
self._fw_ptr, self._fw_ind,
self.dst_len, output_tensor,
)
return work if async_op else work.wait()
@torch.no_grad()
def get_src_part_ids(self) -> Tensor:
input_tensor = torch.full_like(self._fw_ind, self.part_id)
output_tensor = torch.empty_like(self._bw_ind)
all_to_all_s(
output_tensor, input_tensor,
self._bw_ptr, self._fw_ptr,
group=self.group,
)
out = torch.full(
(self.src_len,), 2**16-1,
dtype=self._bw_ind.dtype,
device=self._bw_ind.device,
)
for s, t in zip(self._bw_ptr, self._bw_ptr[1:]):
ind = self._bw_ind[s:t]
assert (out[ind] == 2**16-1).all(), f"some vertices exist in more than one partition"
out[ind] = output_tensor[s:t] & 0xFF
return out
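The rewritten communication path replaces the per-destination Python lists with one flat index buffer plus offsets, in the spirit of a CSR layout. The sketch below, on made-up numbers and with no actual communication, shows the send-side packing that fw_tensor performs before handing the buffer to all_to_all_s:

# what data[self._fw_ind] and _fw_ptr mean, illustratively
fw_ptr = [0, 2, 5]                          # 2 rows go to peer 0, 3 rows to peer 1
fw_ind = torch.tensor([4, 1, 0, 0, 3])      # local dst rows to copy, concatenated per peer
data = torch.arange(6.0).view(6, 1)         # one value per local dst vertex
send_buf = data[fw_ind]                     # packed send buffer handed to all_to_all_s
chunks = [send_buf[s:t] for s, t in zip(fw_ptr, fw_ptr[1:])]   # per-peer slices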
@staticmethod
def _build_route_tables(
@@ -275,8 +220,6 @@ class Route:
assert src_ids.dim() == 1
assert dst_ids.dim() == 1
-ctx = DistributedContext.get_default_context()
src_len = src_ids.size(0)
dst_len = dst_ids.size(0)
@@ -292,6 +235,7 @@ class Route:
all_dst_lens: List[int] = [None] * world_size
dist.all_gather_object(all_dst_lens, dst_len, group=group)
# all_reduce number of nodes
num_nodes = torch.zeros(1, **ikw)
if src_ids.numel() > 0:
num_nodes = src_ids.max().max(num_nodes)
@@ -342,25 +286,54 @@ class Route:
src_ind = smp[ind]
dst_ind = xmp[ind]
-# at this point only the bw_route is correct; the fw_route still has to be sent to the partitions that own the src_ids
-fw_tables.append(torch.vstack([dst_ind, src_ind]))
-bw_tables.append(torch.vstack([src_ind, dst_ind]))
-fw_tables = [t.t().contiguous() for t in fw_tables]
-fw_tables = ctx.all_to_all_g(fw_tables, group=group)
-fw_tables = [t.t().contiguous() for t in fw_tables]
fw_tables.append(dst_ind)
bw_tables.append(src_ind)
fw_tables = Route.__backward_fw_tables(fw_tables, group=group)
-# non-bipartite graph: add a self-loop for every vertex
# add self-loops if not bipartite graph
if not bipartite:
rank_ind = torch.arange(dst_len, **ikw)
-fw_tables[rank] = bw_tables[rank] = torch.vstack([rank_ind, rank_ind])
fw_tables[rank] = bw_tables[rank] = rank_ind
return fw_tables, bw_tables
@staticmethod
-def __expand_index_to_mask(index: Tensor, len: int) -> Tensor:
-mask = torch.zeros(len, dtype=torch.bool, device=index.device)
-mask[index] = True
-return mask
def __filter_ind_and_ptr(ptr: List[int], ind: Tensor, mask: Tensor) -> Tuple[List[int], Tensor]:
m = mask[ind]
new_ptr: List[int] = [0]
new_ind: List[Tensor] = []
for s, t in zip(ptr, ptr[1:]):
new_ind.append(ind[s:t][m[s:t]])
new_ptr.append(new_ptr[-1] + new_ind[-1].numel())
return new_ptr, torch.cat(new_ind, dim=0)
@staticmethod
def __remap_ind(ind: Tensor, mask: Tensor) -> Tuple[Tensor, int]:
n: int = mask.count_nonzero().item()
imp = torch.full((mask.numel(),), (2**62-1)*2+1, dtype=ind.dtype, device=ind.device)
imp[mask] = torch.arange(n, dtype=ind.dtype, device=ind.device)
return imp[ind], int(n)
@staticmethod
def __backward_fw_tables(
fw_tables: List[Tensor],
group: Any,
) -> List[Tensor]:
rank = dist.get_rank(group)
world_size = dist.get_world_size(group)
send_sizes = [t.size() for t in fw_tables]
recv_sizes = [None] * world_size
dist.all_gather_object(recv_sizes, send_sizes, group=group)
recv_sizes = [s[rank] for s in recv_sizes]
fixed_tables = []
for s, t in zip(recv_sizes, fw_tables):
t = torch.empty(*s, dtype=t.dtype, device=t.device)
fixed_tables.append(t)
all_to_all_v(fixed_tables, fw_tables, group=group)
return fixed_tables
@staticmethod
def __tables_to_indptr(
@@ -369,21 +342,101 @@ class Route:
):
fw_ptr: List[int] = [0]
for t in fw_tables:
-last_n = fw_ptr[-1]
-fw_ptr.append(last_n + t.size(1))
-fw_ind = torch.cat(fw_tables, dim=1)
assert t.dim() == 1
fw_ptr.append(fw_ptr[-1] + t.numel())
fw_ind = torch.cat(fw_tables, dim=0)
bw_ptr: List[int] = [0]
for t in bw_tables:
-last_n = bw_ptr[-1]
-bw_ptr.append(last_n + t.size(1))
-bw_ind = torch.cat(bw_tables, dim=1)
assert t.dim() == 1
bw_ptr.append(bw_ptr[-1] + t.numel())
bw_ind = torch.cat(bw_tables, dim=0)
return {
"fw_ptr": fw_ptr, "bw_ptr": bw_ptr,
"fw_ind": fw_ind, "bw_ind": bw_ind,
}
class RouteWork:
def __init__(self,
work: Any,
ptr: List[int],
ind: Tensor,
len: int,
recv_t: Tensor,
) -> None:
self._work = work
self._ptr = ptr
self._ind = ind
self._len = len
self._recv_t = recv_t
if self._work is None:
self._reduce()
def _reduce(self):
out = torch.zeros(
self._len, *self._recv_t.shape[1:],
dtype=self._recv_t.dtype,
device=self._recv_t.device,
)
for s, t in zip(self._ptr, self._ptr[1:]):
ind = self._ind[s:t]
out[ind] += self._recv_t[s:t]
self._work = None
self._ptr = None
self._ind = None
self._len = None
self._recv_t = out
def wait(self) -> Tensor:
if self._work is None:
return self._recv_t
self._work.wait()
self._reduce()
return self._recv_t
class RouteWorkCache:
def __init__(self,
enable_fw: bool = True,
enable_bw: bool = True,
) -> None:
self.enable_fw = enable_fw
self.enable_bw = enable_bw
self._cached_works: Dict[str, RouteWork] = {}
def enable_fw_(self, enable: bool = True):
self.enable_fw = enable
return self
def enable_bw_(self, enable: bool = True):
self.enable_bw = enable
return self
def wait(self):
for work in self._cached_works.values():
work.wait()
def clear(self):
self._cached_works.clear()
def get_and_set(self,
key: str,
work: RouteWork,
bw: bool = False,
) -> Optional[RouteWork]:
if bw and self.enable_bw:
key = key + "_bw"
elif not bw and self.enable_fw:
key = key + "_fw"
else:
return work
t = self._cached_works.get(key, work)
self._cached_works[key] = work
return t
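Judging from get_and_set, the cache is meant to overlap communication across successive calls: each apply stores the freshly issued asynchronous exchange under its key and returns the one stored previously, so cached results are one call stale. A hedged usage sketch, with num_steps, route and x_dst assumed to exist:

# illustrative only: reuse one cache across training iterations
cache = RouteWorkCache()                    # forward and backward caching enabled by default
for step in range(num_steps):
    h_src = route.apply(x_dst, cache, "layer0")
    ...                                     # rest of the training step
cache.wait()                                # drain outstanding exchanges
cache.clear()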
class RouteAlltoAll(autograd.Function):
@staticmethod
@@ -391,10 +444,18 @@ class RouteAlltoAll(autograd.Function):
ctx: autograd.function.FunctionCtx,
x: Tensor,
route: Route,
cache: Optional[RouteWorkCache],
cache_key: Optional[str],
):
ctx.saved_route = route
-return route.fw_tensor(x)
ctx.saved_cache = cache
ctx.saved_cache_key = cache_key
if cache is None or cache_key is None:
return route.fw_tensor(x)
else:
work = route.fw_tensor(x, async_op=True)
work = cache.get_and_set(cache_key, work, bw=False)
return work.wait()
@staticmethod
def backward(
@@ -402,4 +463,13 @@ class RouteAlltoAll(autograd.Function):
grad: Tensor,
) -> Tensor:
route: Route = ctx.saved_route
-return route.bw_tensor(grad), None
cache: Optional[RouteWorkCache] = ctx.saved_cache
cache_key: Optional[str] = ctx.saved_cache_key
if cache is None or cache_key is None:
return route.bw_tensor(grad), None, None, None
else:
work = route.bw_tensor(grad, async_op=True)
work = cache.get_and_set(cache_key, work, bw=True)
return work.wait(), None, None, None
\ No newline at end of file
import torch
import torch.distributed as dist
from torch import Tensor
from typing import *
def init_vc_edge_index(
dst_ids: Tensor,
edge_index: Tensor,
bipartite: bool = True,
) -> Tuple[Tensor, Tensor]:
ikw = dict(dtype=torch.long, device=dst_ids.device)
local_num_nodes = torch.zeros(1, **ikw)
if dst_ids.numel() > 0:
local_num_nodes = dst_ids.max().max(local_num_nodes)
if edge_index.numel() > 0:
local_num_nodes = edge_index.max().max(local_num_nodes)
local_num_nodes = local_num_nodes.item() + 1
xmp: Tensor = torch.zeros(local_num_nodes, **ikw)
xmp[edge_index[1].unique()] += 0b01
xmp[dst_ids.unique()] += 0b10
if not (xmp != 0x01).all():
raise RuntimeError(f"must be vertex-cut partition graph")
if bipartite:
src_ids = edge_index[0].unique()
else:
xmp.fill_(0)
xmp[edge_index[0]] = 1
xmp[dst_ids] = 0
src_ids = torch.cat([dst_ids, torch.where(xmp > 0)[0]], dim=-1)
xmp.fill_((2**62-1)*2+1)
xmp[src_ids] = torch.arange(src_ids.size(0), **ikw)
src = xmp[edge_index[0]]
xmp.fill_((2**62-1)*2+1)
xmp[dst_ids] = torch.arange(dst_ids.size(0), **ikw)
dst = xmp[edge_index[1]]
local_edge_index = torch.vstack([src, dst])
return src_ids, local_edge_index