Commit 9586742a by Wenjie Huang

fix bugs. Route

parent 10c38111
@@ -4,7 +4,7 @@ from torch_geometric.utils import add_remaining_self_loops, to_undirected
 import os.path as osp
 import sys
-from starrygl.utils.data import partition_pyg
+from starrygl.graph import GraphData
 import logging
 logging.getLogger().setLevel(logging.INFO)
@@ -18,7 +18,9 @@ if __name__ == "__main__":
     print(f"num_nodes: {data.num_nodes}")
     print(f"num_edges: {data.num_edges}")
     print(f"num_features: {data.num_features}")
+    data = GraphData.from_pyg_data(data)
+
     num_parts_list = [1, 2, 3, 5, 7, 9, 11]
     algos = ["metis", 'mt-metis', "random"]
@@ -27,4 +29,4 @@ if __name__ == "__main__":
     for num_parts in num_parts_list:
         for algo in algos:
             print(f"======== {num_parts} + {algo} ========")
-            partition_pyg(root, data, num_parts, algo)
+            data.save_partition(root, num_parts, algo)
\ No newline at end of file
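For reference, the new partitioning flow in this hunk reduces to wrapping a PyG Data object in GraphData and saving one partitioning per configuration. A minimal sketch, assuming a Planetoid/Cora dataset and an output directory purely for illustration (only GraphData.from_pyg_data and save_partition are taken from this commit):

import os.path as osp
from torch_geometric.datasets import Planetoid
from starrygl.graph import GraphData

root = osp.abspath("./dataset")                    # illustrative output directory
data = Planetoid(root, name="Cora")[0]             # any PyG Data object would do here

g = GraphData.from_pyg_data(data)                  # wrap the in-memory PyG graph
for num_parts in (1, 2, 3):                        # a subset of the num_parts_list above
    for algo in ("metis", "mt-metis", "random"):   # the algorithms exercised by the script
        g.save_partition(root, num_parts, algo)    # replaces the old partition_pyg(root, data, ...)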
@@ -5,7 +5,7 @@ from torch import Tensor
 from typing import *
 from starrygl.distributed import DistributedContext
-from starrygl.graph import new_vc_route
+from starrygl.graph import *
 from torch_scatter import scatter_sum
@@ -28,32 +28,38 @@ all_eparts = [
     ],
 ]

-def get_route(bipartite: bool = True):
+def get_data():
     ctx = DistributedContext.get_default_context()
     assert ctx.world_size == 3
     dst_ids = torch.tensor(all_nparts[ctx.rank], dtype=torch.long, device=ctx.device)
     edge_index = torch.tensor(all_eparts[ctx.rank], dtype=torch.long, device=ctx.device).t()
-    return new_vc_route(dst_ids, edge_index, bipartite=bipartite)
+    src_ids, edge_index = init_vc_edge_index(dst_ids, edge_index)
+    return GraphData.from_bipartite(edge_index, raw_src_ids=src_ids, raw_dst_ids=dst_ids)

 if __name__ == "__main__":
     ctx = DistributedContext.init(backend="gloo", use_gpu=True)
-    src_ids, edge_index, dst_ids, route = get_route(False)
-    src_size = route.src_len
-    dst_size = route.dst_len
+    g = get_data()
+    route = g.to_route()
+    edge_index = g.edge_index()
+    # src_ids, edge_index, dst_ids, route = get_route(False)
+    # src_size = route.src_len
+    # dst_size = route.dst_len

     ctx.sync_print(route.src_len, route.dst_len)
+    ctx.sync_print(route._fw_ptr, route._fw_ind)
+    ctx.sync_print(route._bw_ptr, route._bw_ind)

     edge_ones = torch.ones(edge_index.size(1), device=ctx.device).requires_grad_()
     src_ones = scatter_sum(edge_ones, edge_index[0], dim=0, dim_size=route.src_len)
     dst_ones = scatter_sum(edge_ones, edge_index[1], dim=0, dim_size=route.dst_len)

-    # ctx.sync_print(route.fw_tensor(dst_ones))
-    # ctx.sync_print(route.bw_tensor(src_ones))
+    ctx.sync_print(route.fw_tensor(dst_ones))
+    ctx.sync_print(route.bw_tensor(src_ones))

-    out = route.reverse_route().apply(src_ones)
+    out = route.rev().apply(src_ones)
     ctx.sync_print(out)
     out.sum().backward()
@@ -61,4 +67,13 @@ if __name__ == "__main__":
     ctx.sync_print(route.get_src_part_ids())

+    dst_mask = torch.full((route.dst_len,), ctx.rank % 2, dtype=torch.bool, device=ctx.device)
+    ctx.main_print("="*64)
+    ctx.sync_print(dst_mask)
+    _, _, r2 = route.filter(dst_mask)
+    ctx.sync_print(r2.apply(dst_ones).detach())
+    ctx.sync_print(r2.rev().apply(src_ones).detach())
+
+    # dst_true = torch.ones(route.dst_len, dtype=torch.float, device=ctx.device)
+    # ctx.sync_print(route.fw_tensor(dst_true, "max"))
     ctx.shutdown()
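The src_ones/dst_ones tensors above are per-node edge counts built with torch_scatter. A tiny standalone illustration of the same scatter_sum pattern, independent of starrygl (the values here are made up):

import torch
from torch_scatter import scatter_sum

edge_index = torch.tensor([[0, 0, 1],     # local source indices of three edges
                           [1, 2, 2]])    # local destination indices of the same edges
edge_ones = torch.ones(edge_index.size(1))

# Sum a weight of 1 per edge onto its source / destination node.
src_ones = scatter_sum(edge_ones, edge_index[0], dim=0, dim_size=3)  # tensor([2., 1., 0.])
dst_ones = scatter_sum(edge_ones, edge_index[1], dim=0, dim_size=3)  # tensor([0., 1., 2.])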
@@ -45,8 +45,6 @@ def all_to_all_v(
     assert len(output_tensor_list) == world_size
     assert len(input_tensor_list) == world_size

-    # if group is None:
-    #     group = dist.distributed_c10d._get_default_group()
     backend = dist.get_backend(group)
     if backend == "nccl":
......
-from .route import Route
-from .utils import init_vc_edge_index
-
-from torch import Tensor
-from typing import Tuple
-
-__all__ = [
-    "Route",
-    "init_vc_edge_index",
-    "new_vc_route",
-]
-
-def new_vc_route(
-    dst_ids: Tensor,
-    edge_index: Tensor,
-    bipartite: bool = True
-) -> Tuple[Tensor, Tensor, Tensor, Route]:
-    src_ids, local_edge_index = init_vc_edge_index(
-        dst_ids, edge_index, bipartite=bipartite)
-    route = Route.from_raw_indices(
-        src_ids, dst_ids, bipartite=bipartite)
-    return src_ids, local_edge_index, dst_ids, route
+from .data import *
+from .route import *
\ No newline at end of file
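The removed new_vc_route helper is mirrored by the updated test: the same route is now obtained by calling init_vc_edge_index directly and wrapping the result in GraphData. A sketch of the replacement pattern, assembled only from calls that appear elsewhere in this commit (to_route() is assumed to build the Route that new_vc_route used to return, and a DistributedContext must already be initialized):

# Given a partition's locally owned dst_ids and its global edge_index, as in the test above:
src_ids, local_edge_index = init_vc_edge_index(dst_ids, edge_index)
g = GraphData.from_bipartite(local_edge_index, raw_src_ids=src_ids, raw_dst_ids=dst_ids)
route = g.to_route()   # formerly: _, _, _, route = new_vc_route(dst_ids, edge_index)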
import torch
import torch.distributed as dist

from torch import Tensor
from typing import *


def init_vc_edge_index(
    dst_ids: Tensor,
    edge_index: Tensor,
    bipartite: bool = True,
) -> Tuple[Tensor, Tensor]:
    ikw = dict(dtype=torch.long, device=dst_ids.device)
    local_num_nodes = torch.zeros(1, **ikw)
    if dst_ids.numel() > 0:
        local_num_nodes = dst_ids.max().max(local_num_nodes)
    if edge_index.numel() > 0:
        local_num_nodes = edge_index.max().max(local_num_nodes)
    local_num_nodes = local_num_nodes.item() + 1

    xmp: Tensor = torch.zeros(local_num_nodes, **ikw)
    xmp[edge_index[1].unique()] += 0b01
    xmp[dst_ids.unique()] += 0b10
    if not (xmp != 0x01).all():
        raise RuntimeError(f"must be vertex-cut partition graph")

    if bipartite:
        src_ids = edge_index[0].unique()
    else:
        xmp.fill_(0)
        xmp[edge_index[0]] = 1
        xmp[dst_ids] = 0
        src_ids = torch.cat([dst_ids, torch.where(xmp > 0)[0]], dim=-1)

    xmp.fill_((2**62-1)*2+1)
    xmp[src_ids] = torch.arange(src_ids.size(0), **ikw)
    src = xmp[edge_index[0]]

    xmp.fill_((2**62-1)*2+1)
    xmp[dst_ids] = torch.arange(dst_ids.size(0), **ikw)
    dst = xmp[edge_index[1]]

    local_edge_index = torch.vstack([src, dst])
    return src_ids, local_edge_index
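In the function above, the 0b01/0b10 flags verify that every edge destination is locally owned (the vertex-cut property), and the two fill/arange passes renumber global ids into positions within src_ids and dst_ids. A small self-contained illustration of the return values (the input tensors are made up; the import assumes init_vc_edge_index is exported from starrygl.graph, as in the updated test):

import torch
from starrygl.graph import init_vc_edge_index

dst_ids = torch.tensor([0, 1, 2])              # nodes owned by this partition
edge_index = torch.tensor([[3, 0, 4],          # global source ids (3 and 4 live on other partitions)
                           [0, 1, 2]])         # global destination ids, all locally owned

src_ids, local_edge_index = init_vc_edge_index(dst_ids, edge_index)
# src_ids          -> tensor([0, 3, 4])        sorted unique global source ids
# local_edge_index -> tensor([[1, 0, 2],       sources renumbered to their positions in src_ids
#                             [0, 1, 2]])      destinations renumbered to their positions in dst_ids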