Commit 6fb76097 by Wenjie Huang

update

parent 57d75031
import torch
import torch.distributed as dist
from torch import Tensor
from typing import *

from starrygl.distributed import DistributedContext
from starrygl.graph import new_vc_route

from torch_scatter import scatter_sum

# destination vertices owned by each of the 3 partitions
all_nparts = [
    [0, 1],
    [2, 3, 4],
    [5, 6],
]

# incoming edges of each partition's destination vertices, as [src, dst] pairs in global ids
all_eparts = [
    [
        [2, 0], [1, 0], [0, 1], [3, 1], [4, 1],
    ],
    [
        [0, 2], [3, 2], [5, 2], [1, 3], [2, 3],
        [4, 3], [6, 3], [1, 4], [3, 4],
    ],
    [
        [2, 5], [6, 5], [5, 6], [3, 6],
    ],
]

def get_route(bipartite: bool = True):
    ctx = DistributedContext.get_default_context()
    assert ctx.world_size == 3

    dst_ids = torch.tensor(all_nparts[ctx.rank], dtype=torch.long, device=ctx.device)
    edge_index = torch.tensor(all_eparts[ctx.rank], dtype=torch.long, device=ctx.device).t()
    return new_vc_route(dst_ids, edge_index, bipartite=bipartite)

if __name__ == "__main__":
    ctx = DistributedContext.init(backend="gloo", use_gpu=True)

    src_ids, edge_index, dst_ids, route = get_route(False)
    src_size = route.src_len
    dst_size = route.dst_len
    ctx.sync_print(route.src_len, route.dst_len)

    # one differentiable 1.0 per local edge, scattered into per-source and per-destination counts
    edge_ones = torch.ones(edge_index.size(1), device=ctx.device).requires_grad_()
    src_ones = scatter_sum(edge_ones, edge_index[0], dim=0, dim_size=route.src_len)
    dst_ones = scatter_sum(edge_ones, edge_index[1], dim=0, dim_size=route.dst_len)

    # ctx.sync_print(route.fw_tensor(dst_ones))
    # ctx.sync_print(route.bw_tensor(src_ones))

    out = route.reverse_route().apply(src_ones)
    ctx.sync_print(out)

    out.sum().backward()
    ctx.sync_print(edge_ones.grad)

    ctx.sync_print(route.get_src_part_ids())

    ctx.shutdown()
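As a sanity check on the scatter_sum calls above: since each partition stores every incoming edge of its own destination vertices, dst_ones on each rank is just the in-degree of those vertices in the full 7-node toy graph. The standalone sketch below is not part of the commit and needs no distributed setup; it only re-derives the expected per-vertex values from all_eparts.

import torch
from torch_scatter import scatter_sum

# all 18 edges from all_eparts, concatenated, as [src, dst] pairs
edges = torch.tensor(
    [[2, 0], [1, 0], [0, 1], [3, 1], [4, 1],
     [0, 2], [3, 2], [5, 2], [1, 3], [2, 3], [4, 3], [6, 3], [1, 4], [3, 4],
     [2, 5], [6, 5], [5, 6], [3, 6]], dtype=torch.long).t()

in_deg = scatter_sum(torch.ones(edges.size(1)), edges[1], dim=0, dim_size=7)
print(in_deg)  # tensor([2., 3., 3., 4., 2., 2., 2.]) -> per rank: [2, 3], [3, 4, 2], [2, 2]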
from .route import Route
from .utils import init_vc_edge_index

from torch import Tensor
from typing import Tuple

__all__ = [
    "Route",
    "init_vc_edge_index",
    "new_vc_route",
]

def new_vc_route(
    dst_ids: Tensor,
    edge_index: Tensor,
    bipartite: bool = True,
) -> Tuple[Tensor, Tensor, Tensor, Route]:
    src_ids, local_edge_index = init_vc_edge_index(
        dst_ids, edge_index, bipartite=bipartite)

    route = Route.from_raw_indices(
        src_ids, dst_ids, bipartite=bipartite)

    return src_ids, local_edge_index, dst_ids, route
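A hedged per-rank usage sketch of this helper, mirroring the test script at the top of the commit and using rank 0's data (all_nparts[0] and all_eparts[0] transposed); it still requires the 3-rank DistributedContext to be initialized, since Route.from_raw_indices exchanges indices across partitions.

dst_ids = torch.tensor([0, 1], dtype=torch.long)                 # destination vertices owned by rank 0
edge_index = torch.tensor([[2, 1, 0, 3, 4],                      # global source ids
                           [0, 0, 1, 1, 1]], dtype=torch.long)   # global destination ids
src_ids, local_edge_index, dst_ids, route = new_vc_route(dst_ids, edge_index, bipartite=False)
# local_edge_index is re-indexed against src_ids/dst_ids; route drives the cross-partition exchange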
@@ -31,33 +31,48 @@ class Route:
             group=group,
         )
 
-    def new_subroute(self, dst_mask: Tensor):
+    def subroute(self,
+        dst_mask: Tensor,
+        src_mask: Optional[Tensor] = None,
+    ):
         if dst_mask.dtype != torch.bool:
             dst_mask = Route.__expand_index_to_mask(dst_mask, self.dst_len)
         assert dst_mask.size(0) == self.dst_len
 
-        fw_dst_masks, work = self.all_to_all_fw(
-            input_tensor_list=[dst_mask[self.fw_table(i)[0]] for i in range(self.num_parts)],
-            async_op=True,
-        )
+        if src_mask is None:
+            fw_dst_masks, work = self.all_to_all_fw(
+                input_tensor_list=[dst_mask[self.fw_table(i)[0]] for i in range(self.num_parts)],
+                async_op=True,
+            )
+        else:
+            if src_mask.dtype != torch.bool:
+                src_mask = Route.__expand_index_to_mask(src_mask, self.src_len)
+            assert src_mask.size(0) == self.src_len
 
         fw_tables: List[Tensor] = []
         for i in range(self.num_parts):
             m = dst_mask[self.fw_table(i)[0]]
             fw_tables.append(self.fw_table(i)[:,m])
 
-        work.wait()
-
         bw_tables: List[Tensor] = []
-        for i, m in enumerate(fw_dst_masks):
-            bw_tables.append(self.bw_table(i)[:,m])
+        if src_mask is None:
+            src_mask = torch.zeros(self.src_len, dtype=dst_mask.dtype, device=dst_mask.device)
+            work.wait()
+            for i, m in enumerate(fw_dst_masks):
+                src_mask[self.bw_table(i)[0]] |= m
+                bw_tables.append(self.bw_table(i)[:,m])
+        else:
+            for i in range(self.num_parts):
+                m = src_mask[self.bw_table(i)[0]]
+                bw_tables.append(self.bw_table(i)[:,m])
 
-        return Route(
+        return dst_mask, Route(
             src_len=self.src_len,
             dst_len=self.dst_len,
             **Route.__tables_to_indptr(fw_tables, bw_tables),
             group=self.group,
-        )
+        ), src_mask, dst_mask
 
     def reverse_route(self):
         return Route(
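Reading the new return statement above, the renamed subroute now appears to hand back the normalized dst_mask, the restricted Route, the (possibly derived) src_mask, and the dst_mask once more. A hedged call sketch, with names borrowed from the test script at the top of the commit and purely illustrative:

# illustrative only: drop the last destination vertex of this partition
keep = torch.ones(route.dst_len, dtype=torch.bool, device=ctx.device)
keep[-1] = False

dst_mask, sub_route, src_mask, _ = route.subroute(keep)
# when no src_mask is passed in, the returned src_mask appears to mark the local
# source vertices whose destination counterpart on some partition was kept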
@@ -68,6 +83,43 @@ class Route:
             group=self.group,
         )
 
+    def filter_nodes(self,
+        dst_mask: Tensor,
+        src_mask: Optional[Tensor] = None,
+    ):
+        if dst_mask.dtype != torch.bool:
+            dst_mask = Route.__expand_index_to_mask(dst_mask, self.dst_len)
+        assert dst_mask.size(0) == self.dst_len
+
+        new_dst_len = dst_mask.count_nonzero().item()
+
+        xmp = torch.empty(self.dst_len, dtype=torch.long, device=dst_mask.device)
+        xmp.fill_((2**62-1)*2+1)
+        xmp[dst_mask] = torch.arange(new_dst_len, dtype=torch.long, device=dst_mask.device)
+
+        fw_dst_masks, fw_m_work = self.all_to_all_fw(
+            input_tensor_list=[dst_mask[self.fw_table(i)[0]] for i in range(self.num_parts)],
+            async_op=True,
+        )
+
+        # fw_dst_inds, fw_i_work =
+
+        # if src_mask is None:
+        #     fw_dst_masks, work = self.all_to_all_fw(
+        #         input_tensor_list=[dst_mask[self.fw_table(i)[0]] for i in range(self.num_parts)],
+        #         async_op=True,
+        #     )
+        # else:
+        #     if src_mask.dtype != torch.bool:
+        #         src_mask = Route.__expand_index_to_mask(src_mask, self.src_len)
+        #     assert src_mask.size(0) == self.src_len
+
+    def filter_edges(self,
+        edge_mask: Tensor,
+        edge_index: Tensor,
+    ):
+        pass
+
     def __init__(self,
         src_len: int, dst_len: int,
         fw_ptr: List[int], fw_ind: Tensor,
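The xmp buffer in the filter_nodes draft is a compaction map: destination slots kept by dst_mask get new consecutive ids, and every other slot stays at (2**62-1)*2+1, i.e. the largest int64, as a "dropped" sentinel. A minimal standalone illustration of that trick, not part of the commit:

import torch

dst_mask = torch.tensor([True, False, True, True, False])
new_dst_len = int(dst_mask.count_nonzero())

xmp = torch.full((dst_mask.numel(),), (2**62 - 1) * 2 + 1, dtype=torch.long)  # int64 max marks dropped slots
xmp[dst_mask] = torch.arange(new_dst_len, dtype=torch.long)
print(xmp)  # tensor([0, 9223372036854775807, 1, 2, 9223372036854775807])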
@@ -151,6 +203,22 @@ class Route:
             out[self.fw_table(i)[0]] += t
         return out
 
+    def get_src_part_ids(self) -> Tensor:
+        xs = self.all_to_all_fw(
+            [
+                torch.full(
+                    (self.fw_table(i).size(1),), self.part_id,
+                    dtype=torch.long, device=self.ctx.device,
+                )
+                for i in range(self.num_parts)
+            ],
+            async_op=False,
+        )
+        out = torch.full((self.src_len,), 2**16-1, dtype=torch.long, device=self.ctx.device)
+        for i, t in enumerate(xs):
+            out[self.bw_table(i)[0]] = t
+        return out.int() & 0xFFFF
+
     def all_to_all_fw(self, input_tensor_list: List[Tensor], async_op: bool = False):
         output_tensor_list: List[Tensor] = []
         for i in range(self.num_parts):
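get_src_part_ids fills the output with 2**16-1 before scattering in the senders' partition ids, so after the final out.int() & 0xFFFF a source slot that no partition claimed reads as 65535 while claimed slots keep their (sub-16-bit) partition id. A tiny standalone check of that masking, with illustrative values only:

import torch

owner = torch.tensor([2**16 - 1, 0, 1, 2], dtype=torch.long)  # 0xFFFF = "no owner recorded"
print(owner.int() & 0xFFFF)  # tensor([65535, 0, 1, 2], dtype=torch.int32)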
@@ -334,4 +402,4 @@ class RouteAlltoAll(autograd.Function):
         grad: Tensor,
     ) -> Tensor:
         route: Route = ctx.saved_route
-        return route.bw_tensor(grad)
+        return route.bw_tensor(grad), None
\ No newline at end of file
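The added None follows PyTorch's autograd contract: backward must return one value per forward input, with None for inputs that need no gradient (here, presumably the Route object now passed to forward). A minimal generic sketch of that pattern, independent of starrygl's actual forward signature:

import torch
from torch import Tensor, autograd

class ScaleBy(autograd.Function):
    @staticmethod
    def forward(ctx, x: Tensor, factor: float) -> Tensor:
        ctx.factor = factor
        return x * factor

    @staticmethod
    def backward(ctx, grad: Tensor):
        # one return per forward input: a gradient for x, None for the non-tensor factor
        return grad * ctx.factor, None

x = torch.ones(3, requires_grad=True)
ScaleBy.apply(x, 2.0).sum().backward()
print(x.grad)  # tensor([2., 2., 2.])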