Commit 09c175e2 by Wenjie Huang

add demo train_hybrid.py

parent 057dc57f
...@@ -152,8 +152,11 @@ def batch_send(
    if len(tensors) == 0:
        return BatchWork(None, None)
    if group is None:
        group = dist.GroupMember.WORLD

    # tensors = tuple(t.data for t in tensors)
    backend = dist.get_backend(group)
    dst = dist.get_global_rank(group, dst)

    if async_op:
        works = []
...@@ -177,8 +180,11 @@ def batch_recv(
    if len(tensors) == 0:
        return BatchWork(None, None)
    if group is None:
        group = dist.GroupMember.WORLD

    # tensors = tuple(t.data for t in tensors)
    backend = dist.get_backend(group)
    src = dist.get_global_rank(group, src)

    if async_op:
        works = []
...
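The recurring change in these hunks (and in the Route and SparseBlocks hunks below) is translating a group-local index into a global rank before handing it to a collective: torch.distributed's broadcast/reduce/send/recv address peers by global rank, while the loop indices here enumerate positions inside a possibly non-WORLD subgroup. A minimal sketch of the pattern, with an illustrative helper name that is not part of the repo:

import torch
import torch.distributed as dist

def broadcast_from_group_member(t: torch.Tensor, group_index: int, group) -> None:
    # group_index is the position of the source inside `group`;
    # dist.broadcast() expects the global rank, so translate first.
    global_src = dist.get_global_rank(group, group_index)
    dist.broadcast(t, src=global_src, group=group)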
...@@ -7,6 +7,7 @@ import os
from torch import Tensor
from typing import *

import socket
from contextlib import contextmanager
import logging
...@@ -127,6 +128,18 @@ class DistributedContext:
        self._local_rank = local_rank
        self._compute_device = device
        self._hostname = socket.gethostname()

        if self.device.type == "cuda":
            torch.cuda.set_device(self.device)

        rank_to_host = [None] * self.world_size
        dist.all_gather_object(rank_to_host, (self.hostname, self.local_rank))
        self._rank_to_host: Tuple[Tuple[str, int], ...] = tuple(rank_to_host)

        host_index = [h for h, _ in self.rank_to_host]
        host_index.sort()
        self._host_index: Dict[str, int] = {h: i for i, h in enumerate(host_index)}

        self.__temp_ag_remote_object: Optional[rpc.RRef] = None
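The added all_gather_object exchange gives every process an identical picture of the job topology: global rank i maps to a (hostname, local_rank) pair. For a hypothetical 2-host, 2-GPU-per-host launch (hostnames made up, ranks assumed to be assigned host-major as torchrun normally does) the derived attributes would look like:

# illustrative values only
rank_to_host = (("node-a", 0), ("node-a", 1), ("node-b", 0), ("node-b", 1))
# host_index maps each hostname to the position of its last occurrence in the
# sorted hostname list; the values are only ever used as sort keys, so the
# gaps are harmless.
host_index = {"node-a": 1, "node-b": 3}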
...@@ -145,16 +158,87 @@ class DistributedContext:
    @property
    def local_rank(self) -> int:
        return self._local_rank

    @property
    def hostname(self) -> str:
        return self._hostname

    @property
    def rank_to_host(self):
        return self._rank_to_host

    @property
    def host_index(self):
        return self._host_index

    @property
    def device(self) -> torch.device:
        return self._compute_device

    def get_default_group(self):
        # return dist.distributed_c10d._get_default_group()
        return dist.GroupMember.WORLD

    def get_default_store(self):
        return dist.distributed_c10d._get_default_store()

    def get_ranks_by_host(self, hostname: Optional[str] = None) -> Tuple[int,...]:
        if hostname is None:
            hostname = self.hostname
        ranks: List[int] = []
        for i, (h, r) in enumerate(self.rank_to_host):
            if h == hostname:
                ranks.append(i)
        ranks.sort()
        return tuple(ranks)

    def get_ranks_by_local(self, local_rank: Optional[int] = None) -> Tuple[int,...]:
        if local_rank is None:
            local_rank = self.local_rank
        ranks: List[Tuple[int, str]] = []
        for i, (h, r) in enumerate(self.rank_to_host):
            if r == local_rank:
                ranks.append((i, h))
        ranks.sort(key=lambda x: self.host_index[x[1]])
        return tuple(i for i, h in ranks)

    def get_hybrid_matrix(self) -> Tensor:
        hosts = sorted(self.host_index.items(), key=lambda x: x[1])
        matrix = []
        for h, _ in hosts:
            rs = self.get_ranks_by_host(h)
            matrix.append(rs)
        return torch.tensor(matrix, dtype=torch.long, device="cpu")

    def new_hybrid_subgroups(self,
        matrix: Optional[Tensor] = None,
        backend: Any = None,
    ) -> Tuple[Any, Any]:
        if matrix is None:
            matrix = self.get_hybrid_matrix()
        assert matrix.dim() == 2

        row_group = None
        col_group = None
        for row in matrix.tolist():
            if self.rank in row:
                row_group = dist.new_group(
                    row, backend=backend,
                    use_local_synchronization=True,
                )
                break
        for col in matrix.t().tolist():
            if self.rank in col:
                col_group = dist.new_group(
                    col, backend=backend,
                    use_local_synchronization=True,
                )
                break
        assert row_group is not None
        assert col_group is not None
        return row_group, col_group

    def get_worker_info(self, rank: Optional[int] = None) -> rpc.WorkerInfo:
        rank = dist.get_rank() if rank is None else rank
...
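get_hybrid_matrix arranges the global ranks into a hosts x local_ranks grid, and new_hybrid_subgroups returns, for the calling rank, the process group of its row together with the process group of its column. Continuing the illustrative 2-host x 2-GPU example:

matrix = ctx.get_hybrid_matrix()
# tensor([[0, 1],
#         [2, 3]])
row_group, col_group = ctx.new_hybrid_subgroups(matrix)
# rank 1 lands in row group {0, 1} (its host) and column group {1, 3}
# (same local rank on every host); train_hybrid.py below uses the row
# group for sequence parallelism and the column group for graph
# partition parallelism.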
...@@ -24,6 +24,9 @@ class Route:
        bipartite: bool = True,
        group: Any = None,
    ) -> 'Route':
        if group is None:
            group = dist.GroupMember.WORLD

        fw_tables, bw_tables = Route._build_route_tables(
            src_ids=src_ids, dst_ids=dst_ids,
            bipartite=bipartite, group=group,
...@@ -256,8 +259,11 @@ class Route:
                all_dst_ids[i] = dst_ids
            else:
                all_dst_ids[i] = torch.empty(all_dst_lens[i], **ikw)
            src_rank = dist.get_global_rank(group, i)
            all_dst_get[i] = dist.broadcast(
                all_dst_ids[i], src=src_rank,
                async_op=True, group=group,
            )

        fw_tables: List[Tensor] = []
...
...@@ -58,6 +58,9 @@ class SparseBlocks:
    def __fetch_ids_sizes(local_ids: Tensor, group: Any):
        assert local_ids.dim() == 1
        if group is None:
            group = dist.GroupMember.WORLD

        rank = dist.get_rank(group)
        world_size = dist.get_world_size(group)
        ikw = dict(dtype=torch.long, device=local_ids.device)
...@@ -80,8 +83,9 @@ class SparseBlocks:
                all_ids[i] = local_ids
            else:
                all_ids[i] = torch.empty(all_lens[i], **ikw)
            src = dist.get_global_rank(group, i)
            all_get[i] = dist.broadcast(
                all_ids[i], src=src, async_op=True, group=group
            )

        imp: Tensor = torch.full((num_nodes,), (2**62-1)*2+1, **ikw)
...@@ -151,13 +155,18 @@ class SparseBlockMM(autograd.Function):
        part_id = sp.part_id
        num_parts = sp.num_parts

        group = sp.group
        if group is None:
            group = dist.GroupMember.WORLD

        def async_fetch(i: int):
            n = sp.adj_t(i).sparse_size(1)
            if i == part_id:
                h = x.clone()
            else:
                h = torch.empty(n, *x.shape[1:], dtype=x.dtype, device=x.device)
            src = dist.get_global_rank(group, i)
            return dist.broadcast(h, src=src, group=sp.group, async_op=True)

        last_work = None
        out = None
...@@ -192,9 +201,14 @@ class SparseBlockMM(autograd.Function):
        part_id = sp.part_id
        num_parts = sp.num_parts

        group = sp.group
        if group is None:
            group = dist.GroupMember.WORLD

        def async_reduce(i: int, g: Tensor):
            dst = dist.get_global_rank(group, i)
            return dist.reduce(
                g, dst=dst, op=dist.ReduceOp.SUM,
                group=sp.group, async_op=True,
            )
...
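SparseBlockMM computes the distributed sparse-dense product by letting each partition broadcast its feature block in turn; every rank multiplies the incoming block against its local adjacency block and accumulates, and the backward pass mirrors this with a summed reduce back to the owning partition. A self-contained sketch of the forward pattern, using made-up names (local_blocks, feats) and plain torch sparse tensors, without the prefetch overlap the real code performs:

import torch
import torch.distributed as dist

def block_spmm(local_blocks, feats: torch.Tensor, group) -> torch.Tensor:
    # local_blocks[i]: sparse block whose columns index partition i's nodes
    # feats: this rank's own node features
    rank = dist.get_rank(group)
    out = None
    for i in range(dist.get_world_size(group)):
        if i == rank:
            h = feats.clone()
        else:
            h = torch.empty(local_blocks[i].shape[1], feats.shape[1],
                            dtype=feats.dtype, device=feats.device)
        dist.broadcast(h, src=dist.get_global_rank(group, i), group=group)
        part = torch.sparse.mm(local_blocks[i], h)  # local block x incoming features
        out = part if out is None else out + part
    return out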
from typing import Any, List, Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
from torch import Tensor
from typing import *
from starrygl.distributed import DistributedContext
from starrygl.data import GraphData
from starrygl.parallel import Route, SequencePipe
from starrygl.parallel.sequence import STensor
from starrygl.parallel.utils import *
import torch_geometric.nn as pyg_nn
import torch_geometric.datasets as pyg_datasets
import torch_geometric.utils as pyg_utils
import logging
logging.getLogger().setLevel(logging.INFO)
def prepare_data(root: str, num_parts, part_algo: str = "metis"):
    ctx = DistributedContext.get_default_context()
    data = pyg_datasets.Planetoid(root, "Cora")[0]

    if data.is_directed():
        data.edge_index = pyg_utils.to_undirected(data.edge_index)
    data.edge_index, _ = pyg_utils.add_remaining_self_loops(data.edge_index)
    data.num_classes = data.y.max().item() + 1

    logging.info(f"num_nodes: {data.num_nodes}")
    logging.info(f"num_edges: {data.num_edges}")
    logging.info(f"num_features: {data.num_features}")
    logging.info(f"num_classes: {data.num_classes}")

    g = GraphData.from_pyg_data(data)
    logging.info(f"GraphData.meta().keys(): {g.meta().keys()}")
    logging.info(f"GraphData.node().keys(): {g.node().keys()}")
    logging.info(f"GraphData.edge().keys(): {g.edge().keys()}")

    g.save_partition(root, num_parts, part_algo)
    return g
class SimpleConv(pyg_nn.MessagePassing):
    def __init__(self, in_feats: int, out_feats: int):
        super().__init__(aggr="mean")
        self.linear = nn.Linear(in_feats, out_feats)

    def forward(self, x: Tensor, edge_index: Tensor, route: Route):
        dst_len = x.size(0)
        x = route.apply(x)  # exchange features
        return self.propagate(edge_index, x=x)[:dst_len]

    def message(self, x_j: Tensor):
        return x_j

    def update(self, x: Tensor):
        return F.relu(self.linear(x))
class SimpleGNN(nn.Module):
    def __init__(self,
        num_features: int,
        hidden_dims: int,
        num_layers: int,
    ) -> None:
        super().__init__()
        self.layers = nn.ModuleList()
        for i in range(num_layers):
            in_ch = hidden_dims if i > 0 else num_features
            out_ch = hidden_dims
            self.layers.append(SimpleConv(in_ch, out_ch))

    def forward(self, x: Tensor, edge_index: Tensor, route: Route):
        for layer in self.layers:
            x = layer(x, edge_index, route)
        return x
class SimpleRNN(SequencePipe, nn.Module):
    def __init__(self,
        num_classes: int,
        hidden_dims: int,
        num_layers: int,
        device: Any,
        group: Any,
    ) -> None:
        super().__init__()
        self.device = device
        self.group = group

        self.num_layers = num_layers
        self.hidden_dims = hidden_dims

        self.gru = nn.GRU(
            input_size = hidden_dims,
            hidden_size = hidden_dims,
            num_layers = num_layers,
            batch_first = True,
        )
        self.out = nn.Linear(hidden_dims, num_classes)

    def forward(self, inputs, states):
        x, = inputs  # (N, L, H)
        h, = states  # (N, L, H)

        h = h.transpose(0, 1).contiguous()  # (L, N, H)
        x, h = self.gru(x, h)  # (N, L, H), (L, N, H)
        h = h.transpose(0, 1).contiguous()  # (N, L, H)
        return (x,), (h,)

    def loss_fn(self, inputs, labels) -> Tensor:
        x, = inputs
        return x.square().mean()

    def get_group(self) -> Any:
        return self.group

    def get_init_states(self):
        s = torch.zeros(self.num_layers, self.hidden_dims).to(self.device)
        return (s,)
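SequencePipe coordinates the recurrent part across the ranks returned by get_group(), apparently threading the returned states from one chunk of work into the next; that is why forward hands back the updated GRU hidden state alongside its output and why get_init_states must supply a per-node initial state. A minimal single-process sketch of that state-carrying idea, assuming nothing about SequencePipe internals:

import torch
import torch.nn as nn

N, S, H, L = 8, 200, 128, 3
gru = nn.GRU(H, H, num_layers=L, batch_first=True)
x = torch.randn(N, S, H)

h = torch.zeros(L, N, H)              # initial state, cf. get_init_states()
outs = []
for chunk in x.split(50, dim=1):      # process the sequence 50 steps at a time
    y, h = gru(chunk, h)              # the final state of one chunk seeds the next
    outs.append(y)
y = torch.cat(outs, dim=1)            # matches running gru(x) in one shot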
if __name__ == "__main__":
    data_root = "./dataset"
    ctx = DistributedContext.init(backend="nccl", use_gpu=True)

    hybrid_matrix = ctx.get_hybrid_matrix()
    if hybrid_matrix.size(0) == 1:
        hybrid_matrix = hybrid_matrix.view(2, -1)
    ctx.sync_print(hybrid_matrix)

    # sp is sequence parallel
    # pp is partition parallel
    sp_group, pp_group = ctx.new_hybrid_subgroups(hybrid_matrix)

    # partition data
    if ctx.rank == 0:
        prepare_data(data_root, dist.get_world_size(pp_group))
    dist.barrier()

    g = GraphData.load_partition(
        data_root,
        dist.get_rank(pp_group), dist.get_world_size(pp_group),
    ).to(ctx.device)
    route = g.to_route(pp_group)  # only on subgroup

    num_features = g.node("dst")["x"].size(-1)
    num_classes = g.meta()["num_classes"]
    hidden_dims = 128
    num_layers = 3

    gnn = SimpleGNN(num_features, hidden_dims, num_layers).to(ctx.device)
    rnn = SimpleRNN(num_classes, hidden_dims, num_layers, device=ctx.device, group=sp_group).to(ctx.device)
    opt = torch.optim.Adam([p for p in gnn.parameters()] + [p for p in rnn.parameters()])

    for ep in range(1, 100 + 1):
        seq_len = 200
        xs = []
        opt.zero_grad()
        for _ in range(seq_len):  # snapshot parallel between partition parallel subgroups
            z = gnn(
                x = g.node("dst")["x"],
                edge_index = g.edge_index(),
                route = route,
            )
            xs.append(z.unsqueeze(1))
        x = torch.cat(xs, dim=1)  # (N, S, H)

        # loss = rnn.apply(32, x)[0].square().mean()
        # loss.backward()  # sequence and pipeline parallel on each graph nodes
        loss = rnn.fast_backward(32, (x,), (g.node("dst")["train_mask"],))

        # all reduce
        all_reduce_gradients(rnn)
        all_reduce_buffers(rnn)
        all_reduce_gradients(gnn)
        all_reduce_buffers(gnn)
        opt.step()
        ctx.sync_print(loss)
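Every rank holds its own replica of gnn and rnn, so their gradients and buffers are synchronized across ranks before opt.step(); the helpers imported from starrygl.parallel.utils are used as-is above. A generic sketch of what such a gradient all-reduce typically does (the real all_reduce_gradients may differ in details such as the group it uses or whether it averages):

import torch.distributed as dist
import torch.nn as nn

def all_reduce_gradients_sketch(module: nn.Module, group=None) -> None:
    # sum gradients across the group, then average them
    world_size = dist.get_world_size(group)
    for p in module.parameters():
        if p.grad is None:
            continue
        dist.all_reduce(p.grad, op=dist.ReduceOp.SUM, group=group)
        p.grad.div_(world_size)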