Commit ffb150f7 by Wenjie Huang

add 1f1b pipeline parallel

parent 9c1f47bb
import torch
import torch.nn as nn
import torch.distributed as dist
import torch.distributed.rpc as rpc
import os
from typing import *
# from .degree import compute_in_degree, compute_out_degree, compute_gcn_norm
from .sync_bn import SyncBatchNorm
def convert_parallel_model(
net: nn.Module,
find_unused_parameters=False,
) -> nn.parallel.DistributedDataParallel:
net = SyncBatchNorm.convert_sync_batchnorm(net)
net = nn.parallel.DistributedDataParallel(net,
find_unused_parameters=find_unused_parameters,
)
return net
def init_process_group(backend: str = "gloo") -> torch.device:
rank = int(os.getenv("RANK") or os.getenv("OMPI_COMM_WORLD_RANK"))
world_size = int(os.getenv("WORLD_SIZE") or os.getenv("OMPI_COMM_WORLD_SIZE"))
dist.init_process_group(
backend=backend,
init_method=ccl_init_method(),
rank=rank, world_size=world_size,
)
rpc_backend_options = rpc.TensorPipeRpcBackendOptions()
rpc_backend_options.init_method = rpc_init_method()
for i in range(world_size):
rpc_backend_options.set_device_map(f"worker{i}", {rank: i})
rpc.init_rpc(
name=f"worker{rank}",
rank=rank, world_size=world_size,
rpc_backend_options=rpc_backend_options,
)
local_rank = os.getenv("LOCAL_RANK") or os.getenv("OMPI_COMM_WORLD_LOCAL_RANK")
if local_rank is not None:
local_rank = int(local_rank)
if backend == "nccl" or backend == "mpi":
device = torch.device(f"cuda:{local_rank or rank}")
torch.cuda.set_device(device)
else:
device = torch.device("cpu")
global _COMPUTE_DEVICE
_COMPUTE_DEVICE = device
return device
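# ---------------------------------------------------------------------------
# Hedged usage sketch (added for illustration; not part of the original commit).
# It assumes the usual launcher environment variables are set (MASTER_ADDR,
# MASTER_PORT, and RANK/WORLD_SIZE or their OMPI_* equivalents, e.g. via
# torchrun). `_example_setup` and the nn.Linear stand-in model are made up.
def _example_setup():
    device = init_process_group(backend="gloo")   # also brings up RPC on MASTER_PORT + 1
    model = nn.Linear(16, 4).to(device)           # any nn.Module works here
    model = convert_parallel_model(model)         # SyncBatchNorm conversion + DDP wrapping
    return model, device
# ---------------------------------------------------------------------------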
def rank_world_size() -> Tuple[int, int]:
return dist.get_rank(), dist.get_world_size()
def get_worker_info(rank: Optional[int] = None) -> rpc.WorkerInfo:
rank = dist.get_rank() if rank is None else rank
return rpc.get_worker_info(f"worker{rank}")
_COMPUTE_DEVICE = torch.device("cpu")
def get_compute_device() -> torch.device:
global _COMPUTE_DEVICE
return _COMPUTE_DEVICE
_TEMP_AG_REMOTE_OBJECT = None
def _remote_object():
global _TEMP_AG_REMOTE_OBJECT
return _TEMP_AG_REMOTE_OBJECT
def all_gather_remote_objects(obj: Any) -> List[rpc.RRef]:
global _TEMP_AG_REMOTE_OBJECT
_TEMP_AG_REMOTE_OBJECT = rpc.RRef(obj)
dist.barrier()
world_size = dist.get_world_size()
futs: List[torch.futures.Future] = []
for i in range(world_size):
info = get_worker_info(i)
futs.append(rpc.rpc_async(info, _remote_object))
rrefs: List[rpc.RRef] = []
for f in futs:
f.wait()
rrefs.append(f.value())
dist.barrier()
_TEMP_AG_REMOTE_OBJECT = None
return rrefs
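# (added comment) The two barriers above bracket the RPC round: the first guarantees
# every rank has published its RRef into _TEMP_AG_REMOTE_OBJECT before any peer calls
# _remote_object on it, and the second keeps the slot alive until all peers have
# finished fetching, after which it is safe to reset it to None.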
def ccl_init_method() -> str:
master_addr = os.environ["MASTER_ADDR"]
master_port = int(os.environ["MASTER_PORT"])
return f"tcp://{master_addr}:{master_port}"
def rpc_init_method() -> str:
master_addr = os.environ["MASTER_ADDR"]
master_port = int(os.environ["MASTER_PORT"])
return f"tcp://{master_addr}:{master_port+1}"
\ No newline at end of file
......@@ -5,7 +5,7 @@ from torch import Tensor
from typing import *
from starrygl.distributed import DistributedContext
from starrygl.graph import *
from starrygl.data import *
from torch_scatter import scatter_sum
......@@ -76,4 +76,14 @@ if __name__ == "__main__":
# dst_true = torch.ones(route.dst_len, dtype=torch.float, device=ctx.device)
# ctx.sync_print(route.fw_tensor(dst_true, "max"))
spb = g.to_sparse()
x = torch.randn(g.node("dst").num_nodes, 64, device=ctx.device).requires_grad_()
ctx.sync_print(x[:,0])
y = spb.apply(x)
ctx.sync_print(y[:,0])
y.sum().backward()
ctx.sync_print(x.grad[:,0])
ctx.shutdown()
from .graph import *
from .utils import *
\ No newline at end of file
......@@ -8,13 +8,15 @@ from pathlib import Path
from torch_sparse import SparseTensor
from starrygl.utils.partition import *
from .route import Route
from starrygl.parallel.route import Route
from starrygl.parallel.sparse import SparseBlocks
from .utils import init_vc_edge_index
import logging
__all__ = [
"GraphData",
"init_vc_edge_index",
]
......@@ -85,6 +87,17 @@ class GraphData:
dst_ids = self.node("dst")["raw_ids"]
return Route.from_raw_indices(src_ids, dst_ids, group=group)
def to_sparse(self, key: Optional[str] = None, group: Any = None) -> SparseBlocks:
src_ids = self.node("src")["raw_ids"]
dst_ids = self.node("dst")["raw_ids"]
edge_index = self.edge_index()
edge_index = torch.vstack([
src_ids[edge_index[0]],
dst_ids[edge_index[1]],
])
edge_attr = None if key is None else self.edge()[key]
return SparseBlocks.from_raw_indices(dst_ids, edge_index, edge_attr=edge_attr, group=group)
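# (added comment) to_sparse() lifts the locally renumbered edge list back into global
# (raw) node ids so that SparseBlocks.from_raw_indices can regroup edges by the rank
# that owns each source node; edge_attr is forwarded when a key is given.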
@property
def is_heterogeneous(self) -> bool:
return self._heterogeneous
......@@ -194,15 +207,6 @@ class GraphData:
raw_dst_ids, local_edges, bipartite=True,
)
# g = GraphData(
# edge_indices={
# ("src", "@", "dst"): local_edges,
# },
# num_nodes={
# "src": raw_src_ids.numel(),
# "dst": raw_dst_ids.numel(),
# }
# )
g = GraphData.from_bipartite(
local_edges,
raw_src_ids=raw_src_ids,
......@@ -321,139 +325,3 @@ class EdgeData:
def to(self, device: Any) -> 'EdgeData':
self._data = {k:v.to(device) for k,v in self._data.items()}
return self
def init_vc_edge_index(
dst_ids: Tensor,
edge_index: Tensor,
bipartite: bool = True,
) -> Tuple[Tensor, Tensor]:
ikw = dict(dtype=torch.long, device=dst_ids.device)
local_num_nodes = torch.zeros(1, **ikw)
if dst_ids.numel() > 0:
local_num_nodes = dst_ids.max().max(local_num_nodes)
if edge_index.numel() > 0:
local_num_nodes = edge_index.max().max(local_num_nodes)
local_num_nodes = local_num_nodes.item() + 1
xmp: Tensor = torch.zeros(local_num_nodes, **ikw)
xmp[edge_index[1].unique()] += 0b01
xmp[dst_ids.unique()] += 0b10
if not (xmp != 0x01).all():
raise RuntimeError(f"must be vertex-cut partition graph")
if bipartite:
src_ids = edge_index[0].unique()
else:
xmp.fill_(0)
xmp[edge_index[0]] = 1
xmp[dst_ids] = 0
src_ids = torch.cat([dst_ids, torch.where(xmp > 0)[0]], dim=-1)
xmp.fill_((2**62-1)*2+1)
xmp[src_ids] = torch.arange(src_ids.size(0), **ikw)
src = xmp[edge_index[0]]
xmp.fill_((2**62-1)*2+1)
xmp[dst_ids] = torch.arange(dst_ids.size(0), **ikw)
dst = xmp[edge_index[1]]
local_edge_index = torch.vstack([src, dst])
return src_ids, local_edge_index
# def partition_load(root: str, part_id: int, num_parts: int, algo: str = "metis") -> GraphData:
# p = Path(root).expanduser().resolve() / f"{algo}_{num_parts}" / f"{part_id:03d}"
# return torch.load(p.__str__())
# def partition_pyg(root: str, data, num_parts: int, algo: str = "metis"):
# root_path = Path(root).expanduser().resolve()
# base_path = root_path / f"{algo}_{num_parts}"
# if base_path.exists():
# shutil.rmtree(base_path.__str__())
# base_path.mkdir(parents=True, exist_ok=True)
# for i, g in enumerate(partition_pyg_data(data, num_parts, algo)):
# logging.info(f"saving partition data: {i+1}/{num_parts}")
# torch.save(g, (base_path / f"{i:03d}").__str__())
# def partition_pyg_data(data, num_parts: int, algo: str = "metis") -> Iterator[GraphData]:
# from torch_geometric.data import Data
# assert isinstance(data, Data), f"must be Data class in pyg"
# logging.info(f"running partition aglorithm: {algo}")
# num_nodes: int = data.num_nodes
# num_edges: int = data.num_edges
# edge_index: Tensor = data.edge_index
# if algo == "metis":
# node_parts = metis_partition(edge_index, num_nodes, num_parts)
# elif algo == "mt-metis":
# node_parts = mt_metis_partition(edge_index, num_nodes, num_parts)
# elif algo == "random":
# node_parts = random_partition(edge_index, num_nodes, num_parts)
# else:
# raise ValueError(f"unknown partition algorithm: {algo}")
# if data.y.dtype == torch.long:
# if data.y.dim() == 1:
# num_classes = data.y.max().item() + 1
# else:
# num_classes = data.y.size(1)
# else:
# num_classes = None
# for i in range(num_parts):
# npart_mask = node_parts == i
# epart_mask = npart_mask[edge_index[1]]
# local_edges = edge_index[:, epart_mask]
# raw_src_ids: Tensor = local_edges[0].unique()
# raw_dst_ids: Tensor = torch.where(npart_mask)[0]
# M: int = raw_src_ids.max().item() + 1
# imap = torch.full((M,), (2**62-1)*2+1).type_as(raw_src_ids)
# imap[raw_src_ids] = torch.arange(raw_src_ids.numel()).type_as(raw_src_ids)
# local_src = imap[local_edges[0]]
# M: int = raw_dst_ids.max().item() + 1
# imap = torch.full((M,), (2**62-1)*2+1).type_as(raw_dst_ids)
# imap[raw_dst_ids] = torch.arange(raw_dst_ids.numel()).type_as(raw_dst_ids)
# local_dst = imap[local_edges[1]]
# local_edges = torch.vstack([local_src, local_dst])
# g = GraphData(
# edge_indices={
# ("src", "@", "dst"): local_edges,
# },
# num_nodes={
# "src": raw_src_ids.numel(),
# "dst": raw_dst_ids.numel(),
# },
# )
# g.node("src")["raw_ids"] = raw_src_ids
# g.node("dst")["raw_ids"] = raw_dst_ids
# if num_classes is not None:
# g.meta()["num_classes"] = num_classes
# for key, val in data:
# if key == "edge_index":
# continue
# elif isinstance(val, Tensor):
# if val.size(0) == num_nodes:
# g.node("dst")[key] = val[npart_mask]
# elif val.size(0) == num_edges:
# g.edge()[key] = val[epart_mask]
# elif isinstance(val, SparseTensor):
# pass
# else:
# g.meta()[key] = val
# yield g
import torch
from torch import Tensor
from typing import *
__all__ = [
"init_vc_edge_index",
]
def init_vc_edge_index(
dst_ids: Tensor,
edge_index: Tensor,
bipartite: bool = True,
) -> Tuple[Tensor, Tensor]:
ikw = dict(dtype=torch.long, device=dst_ids.device)
local_num_nodes = torch.zeros(1, **ikw)
if dst_ids.numel() > 0:
local_num_nodes = dst_ids.max().max(local_num_nodes)
if edge_index.numel() > 0:
local_num_nodes = edge_index.max().max(local_num_nodes)
local_num_nodes = local_num_nodes.item() + 1
xmp: Tensor = torch.zeros(local_num_nodes, **ikw)
xmp[edge_index[1].unique()] += 0b01
xmp[dst_ids.unique()] += 0b10
if not (xmp != 0x01).all():
raise RuntimeError(f"must be vertex-cut partition graph")
if bipartite:
src_ids = edge_index[0].unique()
else:
xmp.fill_(0)
xmp[edge_index[0]] = 1
xmp[dst_ids] = 0
src_ids = torch.cat([dst_ids, torch.where(xmp > 0)[0]], dim=-1)
xmp.fill_((2**62-1)*2+1)
xmp[src_ids] = torch.arange(src_ids.size(0), **ikw)
src = xmp[edge_index[0]]
xmp.fill_((2**62-1)*2+1)
xmp[dst_ids] = torch.arange(dst_ids.size(0), **ikw)
dst = xmp[edge_index[1]]
local_edge_index = torch.vstack([src, dst])
return src_ids, local_edge_index
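# Hedged worked example (added for illustration; values are made up):
# with dst_ids = [3, 5] owned locally and global edges
#   edge_index = [[7, 3, 5],
#                 [3, 3, 5]]
# every edge destination lies in dst_ids, so the vertex-cut check passes, and
# bipartite=True yields
#   src_ids          = [3, 5, 7]       # unique global source ids
#   local_edge_index = [[2, 0, 1],     # sources renumbered by position in src_ids
#                       [0, 0, 1]]     # destinations renumbered by position in dst_ids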
......@@ -8,33 +8,45 @@ __all__ = [
"all_to_all_v",
"all_to_all_s",
"BatchWork",
"batch_send",
"batch_recv",
]
class BatchWork:
def __init__(self, works, buffer_tensor_list) -> None:
if isinstance(works, (list, tuple)):
assert len(works) // 2 == len(buffer_tensor_list)
self._works = works
self._buffer_tensor_list = buffer_tensor_list
def __init__(self,
works: Optional[List[Any]],
buffer_tensor_list: Optional[List[Tuple[Tensor, Optional[Tensor]]]],
step: int = 1,
) -> None:
if works is None:
self._step = None
self._works = None
self._buffer_tensor_list = None
else:
assert self._buffer_tensor_list is None
if buffer_tensor_list:
assert len(works) // step == len(buffer_tensor_list)
self._step = step
self._works = works
self._buffer_tensor_list = None
self._buffer_tensor_list = buffer_tensor_list
def wait(self):
if self._works is None:
return
if isinstance(self._works, (list, tuple)):
for i, w in enumerate(self._works):
if w is not None:
w.wait()
if i % 2 == 0:
if (i + 1) % self._step != 0:
continue
out, buf = self._buffer_tensor_list[i // 2]
if self._buffer_tensor_list:
out, buf = self._buffer_tensor_list[i // self._step]
if buf is not None:
out.copy_(buf)
else:
self._works.wait()
self._step = None
self._works = None
self._buffer_tensor_list = None
......@@ -60,7 +72,7 @@ def all_to_all_v(
group=group,
async_op=async_op,
)
return BatchWork(work, None) if async_op else None
return BatchWork([work], None) if async_op else None
elif backend == "mpi":
work = dist.all_to_all(
......@@ -69,7 +81,7 @@ def all_to_all_v(
group=group,
async_op=async_op,
)
return BatchWork(work, None) if async_op else None
return BatchWork([work], None) if async_op else None
else:
assert backend == "gloo", f"backend must be nccl, mpi or gloo"
......@@ -98,7 +110,7 @@ def all_to_all_v(
dist.irecv(recv_b, recv_i, group=group),
])
work = BatchWork(p2p_op_works, buffer_tensor_list)
work = BatchWork(p2p_op_works, buffer_tensor_list, 2)
output_tensor_list[rank].copy_(input_tensor_list[rank])
if async_op:
......@@ -130,3 +142,55 @@ def all_to_all_s(
group=group,
async_op=async_op,
)
def batch_send(
*tensors: Tensor,
dst: int,
group: Any = None,
async_op: bool = False,
):
# tensors = tuple(t.data for t in tensors)
backend = dist.get_backend(group)
if async_op:
works = []
for t in tensors:
if backend == "gloo" and t.is_cuda:
t = t.cpu()
works.append(dist.isend(t, dst=dst, group=group))
return BatchWork(works, None)
else:
for t in tensors:
if backend == "gloo" and t.is_cuda:
t = t.cpu()
dist.send(t, dst=dst, group=group)
def batch_recv(
*tensors: Tensor,
src: int,
group: Any = None,
async_op: bool = False,
):
# tensors = tuple(t.data for t in tensors)
backend = dist.get_backend(group)
if async_op:
works = []
output_tensor_list = []
for t in tensors:
if backend == "gloo" and t.is_cuda:
b = torch.empty_like(t, device="cpu")
works.append(dist.irecv(b, src=src, group=group))
else:
b = None
works.append(dist.irecv(t, src=src, group=group))
output_tensor_list.append((t, b))
return BatchWork(works, output_tensor_list, 1)
else:
for t in tensors:
if backend == "gloo" and t.is_cuda:
b = torch.empty_like(t, device="cpu")
dist.recv(b, src=src, group=group)
t.copy_(b)
else:
dist.recv(t, src=src, group=group)
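# Hedged usage sketch (added for illustration; not part of the original commit):
# rank 0 streams two tensors to rank 1 behind a single BatchWork handle; shapes and
# dtypes must match on both sides. `_example_p2p` is a made-up helper.
def _example_p2p(rank: int, x: Tensor, y: Tensor):
    if rank == 0:
        work = batch_send(x, y, dst=1, async_op=True)
    else:
        work = batch_recv(x, y, src=0, async_op=True)
    work.wait()  # with gloo + CUDA tensors, staging CPU buffers are copied back into x, y here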
......@@ -140,8 +140,11 @@ class DistributedTensor:
for rref in self.rrefs:
n = self.ctx.remote_call(Tensor.size, rref, dim=0).wait()
local_sizes.append(n)
self._num_nodes = DistInt(local_sizes)
self._num_parts = DistInt([1] * len(self.rrefs))
self._num_nodes: int = sum(local_sizes)
self._num_part_nodes: Tuple[int,...] = tuple(int(s) for s in local_sizes)
self._part_id: int = self.accessor.ctx.rank
self._num_parts: int = self.accessor.ctx.world_size
@property
def dtype(self):
......@@ -152,11 +155,19 @@ class DistributedTensor:
return self.accessor.data.device
@property
def num_nodes(self) -> DistInt:
def num_nodes(self) -> int:
return self._num_nodes
@property
def num_parts(self) -> DistInt:
def num_part_nodes(self) -> Tuple[int, ...]:
return self._num_part_nodes
@property
def part_id(self) -> int:
return self._part_id
@property
def num_parts(self) -> int:
return self._num_parts
def to(self,device):
......@@ -235,7 +246,7 @@ class DistributedTensor:
index = dist_index.loc
futs: List[torch.futures.Future] = []
for i in range(self.num_parts()):
for i in range(self.num_parts):
f = self.accessor.async_index_select(0, index[part_idx == i], self.rrefs[i])
futs.append(f)
......
from .data import *
from .route import *
\ No newline at end of file
import torch
import torch.nn as nn
import torch.distributed as dist
import torch.distributed.rpc as rpc
import os
from typing import *
# from .degree import compute_in_degree, compute_out_degree, compute_gcn_norm
from .sync_bn import SyncBatchNorm
def convert_parallel_model(
net: nn.Module,
find_unused_parameters=False,
) -> nn.parallel.DistributedDataParallel:
net = SyncBatchNorm.convert_sync_batchnorm(net)
net = nn.parallel.DistributedDataParallel(net,
find_unused_parameters=find_unused_parameters,
)
return net
def init_process_group(backend: str = "gloo") -> torch.device:
rank = int(os.getenv("RANK") or os.getenv("OMPI_COMM_WORLD_RANK"))
world_size = int(os.getenv("WORLD_SIZE") or os.getenv("OMPI_COMM_WORLD_SIZE"))
dist.init_process_group(
backend=backend,
init_method=ccl_init_method(),
rank=rank, world_size=world_size,
)
rpc_backend_options = rpc.TensorPipeRpcBackendOptions()
rpc_backend_options.init_method = rpc_init_method()
for i in range(world_size):
rpc_backend_options.set_device_map(f"worker{i}", {rank: i})
rpc.init_rpc(
name=f"worker{rank}",
rank=rank, world_size=world_size,
rpc_backend_options=rpc_backend_options,
)
local_rank = os.getenv("LOCAL_RANK") or os.getenv("OMPI_COMM_WORLD_LOCAL_RANK")
if local_rank is not None:
local_rank = int(local_rank)
if backend == "nccl" or backend == "mpi":
device = torch.device(f"cuda:{local_rank or rank}")
torch.cuda.set_device(device)
else:
device = torch.device("cpu")
global _COMPUTE_DEVICE
_COMPUTE_DEVICE = device
return device
def rank_world_size() -> Tuple[int, int]:
return dist.get_rank(), dist.get_world_size()
def get_worker_info(rank: Optional[int] = None) -> rpc.WorkerInfo:
rank = dist.get_rank() if rank is None else rank
return rpc.get_worker_info(f"worker{rank}")
_COMPUTE_DEVICE = torch.device("cpu")
def get_compute_device() -> torch.device:
global _COMPUTE_DEVICE
return _COMPUTE_DEVICE
_TEMP_AG_REMOTE_OBJECT = None
def _remote_object():
global _TEMP_AG_REMOTE_OBJECT
return _TEMP_AG_REMOTE_OBJECT
def all_gather_remote_objects(obj: Any) -> List[rpc.RRef]:
global _TEMP_AG_REMOTE_OBJECT
_TEMP_AG_REMOTE_OBJECT = rpc.RRef(obj)
dist.barrier()
world_size = dist.get_world_size()
futs: List[torch.futures.Future] = []
for i in range(world_size):
info = get_worker_info(i)
futs.append(rpc.rpc_async(info, _remote_object))
rrefs: List[rpc.RRef] = []
for f in futs:
f.wait()
rrefs.append(f.value())
dist.barrier()
_TEMP_AG_REMOTE_OBJECT = None
return rrefs
def ccl_init_method() -> str:
master_addr = os.environ["MASTER_ADDR"]
master_port = int(os.environ["MASTER_PORT"])
return f"tcp://{master_addr}:{master_port}"
def rpc_init_method() -> str:
master_addr = os.environ["MASTER_ADDR"]
master_port = int(os.environ["MASTER_PORT"])
return f"tcp://{master_addr}:{master_port+1}"
\ No newline at end of file
from .route import *
from .sequence import *
from .sparse import *
\ No newline at end of file
import torch
import torch.nn as nn
import torch.autograd as autograd
import torch.distributed as dist
from torch import Tensor
from typing import *
from starrygl.distributed import DistributedContext
from starrygl.distributed.cclib import batch_send, batch_recv, BatchWork
class SequencePipe:
def __init__(self,
batch_size: int,
seq_ranks: List[int],
group: Any,
) -> None:
self._batch_size = int(batch_size)
self._seq_ranks = tuple(seq_ranks)
self._group = group
rank = dist.get_rank(group)
self._index = self._seq_ranks.index(rank)
self._init_states: Optional[Tuple[Tensor,...]] = None
@property
def batch_size(self) -> int:
return self._batch_size
@property
def seq_ranks(self):
return self._seq_ranks
@property
def num_ranks(self) -> int:
return len(self._seq_ranks)
@property
def index(self) -> int:
return self._index
@property
def group(self):
return self._group
def _load_states(self, states: Tuple[Tensor,...], grad: bool = False):
if grad: # from target to source
for s in states:
s.grad = torch.zeros_like(s.data)
if self.index + 1 < self.num_ranks:
batch_recv(
*[s.grad for s in states],
src=self.seq_ranks[self.index + 1],
group=self.group,
async_op=True,
).wait()
else: # from source to target
for s in states:
s.grad = None
if self.index > 0:
batch_recv(
*[s.data for s in states],
src=self.seq_ranks[self.index - 1],
group=self.group,
async_op=True,
).wait()
def _save_states(self, states: Tuple[Tensor,...], grad: bool = False) -> BatchWork:
if grad: # from target to source
if self.index > 0:
return batch_send(
*[s.grad for s in states],
dst=self.seq_ranks[self.index - 1],
group=self.group,
async_op=True,
)
else: # from source to target
if self.index + 1 < self.num_ranks:
return batch_send(
*[s.data for s in states],
dst=self.seq_ranks[self.index + 1],
group=self.group,
async_op=True
)
return BatchWork(None, None)
def init_states(self, *states: Tensor):
for s in states:
assert isinstance(s, Tensor), "states must be a tuple of tensors"
self._init_states = tuple(states)
return self
def batch_size_(self, bs: int):
self._batch_size = bs
return self
def _get_batch_states(self, bs: int, requires_grad: bool = False) -> Tuple[Tensor,...]:
assert self._init_states is not None, "please call init_states()."
states: List[Tensor] = []
for s in self._init_states:
s = s.unsqueeze(0).broadcast_to(bs, *s.size())
states.append(s.contiguous())
if requires_grad and self.index > 0:
states = [s.requires_grad_() for s in states]
return tuple(states)
def _detach_inputs(self, inputs: Tuple[Tensor,...]) -> Tuple[Tensor,...]:
assert inputs[0].size(0) > 0, "input tensors must have a non-empty batch dimension."
detach_inputs: List[Tensor] = []
for t in inputs:
assert t.size(0) == inputs[0].size(0), "all input tensors must share the same batch dimension."
detach_inputs.append(t.detach())
return tuple(detach_inputs)
def state_forward(self,
batch_id: int,
inputs: Tuple[Tensor,...],
states: Tuple[Tensor,...],
) -> Tuple[Tuple[Tensor,...], Tuple[Tensor,...]]:
raise NotImplementedError()
def loss_fn(self,
batch_id: int,
outputs: Tuple[Tensor,...],
) -> Tensor:
raise NotImplementedError()
def _forward_inner(self,
batch_id: int,
inputs: Tuple[Tensor,...],
states: Tuple[Tensor,...],
) -> Tuple[Tuple[Tensor,...], Tuple[Tensor,...], BatchWork]:
self._load_states(states)
outputs, next_states = self.state_forward(batch_id, inputs, states)
work = self._save_states(next_states)
return outputs, next_states, work
def _backward_inner(self,
batch_id: int,
inputs: Tuple[Tensor,...],
outputs: Tuple[Tensor,...],
input_states: Tuple[Tensor,...],
output_states: Tuple[Tensor,...],
scale_factor: float = 1.0,
) -> Tuple[Tuple[Tensor,...], Tensor, BatchWork]:
loss = self.loss_fn(batch_id, outputs)
if scale_factor != 1.0:
loss = loss * scale_factor
total_loss = loss
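# (added comment) _load_states(grad=True) below receives, from the next pipeline stage,
# the gradients w.r.t. this stage's output states; adding sum(s * g) as a surrogate term
# lets the single backward() call propagate both the local loss and those downstream
# state gradients through this stage in one pass.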
self._load_states(output_states, grad=True)
for s in output_states:
g, s.grad = s.grad, None
total_loss += torch.sum(s * g)
total_loss.backward()
work = self._save_states(input_states, grad=True)
input_grads = []
for t in inputs:
input_grads.append(t.grad)
t.grad = None
return tuple(input_grads), loss.detach(), work
def forward(self, *inputs: Tensor) -> Tuple[Tensor,...]:
inputs = self._detach_inputs(inputs)
B = inputs[0].size(0)
num_batchs = (B + self.batch_size - 1) // self.batch_size
with torch.no_grad():
outputs = []
last_work = None
for batch_id in range(num_batchs):
start = batch_id * self.batch_size
end = min(B, start + self.batch_size)
batch_inputs = tuple(t[start:end] for t in inputs)
batch_states = self._get_batch_states(end - start)
batch_outputs, _, work = self._forward_inner(batch_id, batch_inputs, batch_states)
outputs.append(batch_outputs)
if last_work is not None:
last_work.wait()
last_work = work
concat_outputs = []
for ts in zip(*outputs):
concat_outputs.append(torch.cat(ts, dim=0))
if last_work is not None:
last_work.wait()
return tuple(concat_outputs)
def backward(self, *inputs: Tensor, scale: float = 1.0) -> Tuple[Tuple[Tensor,...], Tensor]:
inputs = self._detach_inputs(inputs)
B = inputs[0].size(0)
num_batchs = (B + self.batch_size - 1) // self.batch_size
fw_batch_id = 0
bw_batch_id = 0
bw_offset = 2 * self.num_ranks - self.index - 1
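# (added comment) 1F1B schedule: bw_offset is this stage's warm-up length. The forward of
# micro-batch 0 launches here at tick `index`, its backward starts at the last stage at
# tick num_ranks and travels back, arriving here at tick 2 * num_ranks - index - 1; from
# then on forwards and backwards are interleaved one-for-one (the 2 * bw_batch_id term).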
bw_ready: Dict[int, Any] = {}
last_bw_work = None
input_grads = []
total_loss = None
# hist = []
count = 0
while fw_batch_id < num_batchs or bw_batch_id < num_batchs:
if count >= bw_offset + 2 * bw_batch_id:
if bw_batch_id < num_batchs:
bs, work, *fw_graph = bw_ready.pop(bw_batch_id)
work.wait()
scale_factor = scale * self.batch_size / bs
grads, loss, work = self._backward_inner(bw_batch_id, *fw_graph, scale_factor=scale_factor)
# hist.append(f"{count+1}bw{bw_batch_id + 1}")
if last_bw_work is not None:
last_bw_work.wait()
last_bw_work = work
total_loss = loss if total_loss is None else total_loss + loss
input_grads.append(grads)
bw_batch_id += 1
else:
fw_starting = (fw_batch_id < self.num_ranks and count >= fw_batch_id + self.index)
fw_occupying = (count >= self.index + 2 * fw_batch_id)
if fw_batch_id < num_batchs and (fw_starting or fw_occupying):
start = fw_batch_id * self.batch_size
end = min(B, start + self.batch_size)
batch_inputs = tuple(t[start:end].requires_grad_() for t in inputs)
batch_states = self._get_batch_states(end - start, requires_grad=True)
batch_outputs, batch_next_states, work = self._forward_inner(fw_batch_id, batch_inputs, batch_states)
# hist.append(f"{count+1}fw{fw_batch_id + 1}")
bw_ready[fw_batch_id] = [
end - start,
work,
batch_inputs,
batch_outputs,
batch_states,
batch_next_states,
]
fw_batch_id += 1
count += 1
concat_grads = []
for ts in zip(*input_grads):
concat_grads.append(torch.cat(ts, dim=0))
if last_bw_work is not None:
last_bw_work.wait()
# print(f"{self.index+1}: {hist}")
return tuple(concat_grads), total_loss
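# ---------------------------------------------------------------------------
# Hedged usage sketch (added for illustration; not part of the original commit).
# A minimal stage that carries one hidden state along the pipeline; the class name,
# sizes and loss are made up, and in real use the initial state should live on the
# compute device.
class _ExampleGRUPipe(SequencePipe):
    def __init__(self, batch_size: int, seq_ranks: List[int], group: Any, dim: int = 32) -> None:
        super().__init__(batch_size, seq_ranks, group)
        self.cell = nn.GRUCell(dim, dim)
        self.init_states(torch.zeros(dim))  # broadcast to (micro_batch, dim) per micro-batch

    def state_forward(self, batch_id, inputs, states):
        x, = inputs
        h, = states
        h = self.cell(x, h)
        return (h,), (h,)  # (outputs kept locally, states sent to the next rank)

    def loss_fn(self, batch_id, outputs):
        h, = outputs
        return h.pow(2).mean()
# Each rank builds the pipe with its slice of the data and then calls forward(x) for
# inference or backward(x) to run the interleaved 1F1B schedule implemented above.
# ---------------------------------------------------------------------------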
from typing import Any
import torch
import torch.distributed as dist
import torch.autograd as autograd
from torch import Tensor
from typing import *
from torch_sparse import SparseTensor
__all__ = [
"SparseBlocks",
]
class SparseBlocks:
@staticmethod
def from_raw_indices(
dst_ids: Tensor,
edge_index: Tensor,
src_ids: Optional[Tensor] = None,
edge_attr: Optional[Tensor] = None,
group: Any = None,
) -> 'SparseBlocks':
assert edge_index.dim() == 2 and edge_index.size(0) == 2
if src_ids is None:
src_ids = dst_ids
src_ids, src_ptr = SparseBlocks.__fetch_ids_sizes(src_ids, group=group)
adj_ts = SparseBlocks.__remap_adj_t(dst_ids, edge_index, src_ids, src_ptr, edge_attr)
return SparseBlocks(adj_ts, group=group)
def __init__(self, adj_ts: List[SparseTensor], group: Any) -> None:
self._adj_ts = adj_ts
self._group = group
def adj_t(self, i: int) -> SparseTensor:
return self._adj_ts[i]
@property
def group(self):
return self._group
@property
def part_id(self) -> int:
return dist.get_rank(self._group)
@property
def num_parts(self) -> int:
return dist.get_world_size(self._group)
def apply(self, x: Tensor) -> Tensor:
return SparseBlockMM.apply(self, x)
@staticmethod
def __fetch_ids_sizes(local_ids: Tensor, group: Any):
assert local_ids.dim() == 1
rank = dist.get_rank(group)
world_size = dist.get_world_size(group)
ikw = dict(dtype=torch.long, device=local_ids.device)
all_lens: List[int] = [None] * world_size
dist.all_gather_object(all_lens, local_ids.numel(), group=group)
# all reduce num_nodes
num_nodes = torch.zeros(1, **ikw)
if local_ids.numel() > 0:
num_nodes = local_ids.max().max(num_nodes)
dist.all_reduce(num_nodes, op=dist.ReduceOp.MAX, group=group)
num_nodes: int = num_nodes.item() + 1
# async fetch remote ids
all_ids: List[Tensor] = [None] * world_size
all_get = [None] * world_size
def async_fetch(i: int):
if i == rank:
all_ids[i] = local_ids
else:
all_ids[i] = torch.empty(all_lens[i], **ikw)
all_get[i] = dist.broadcast(
all_ids[i], src=i, async_op=True, group=group
)
imp: Tensor = torch.full((num_nodes,), (2**62-1)*2+1, **ikw)
offset: int = 0
for i in range(world_size):
if i == 0:
async_fetch(i)
if i + 1 < world_size:
async_fetch(i + 1)
all_get[i].wait()
ids = all_ids[i]
assert (imp[ids] == (2**62-1)*2+1).all(), "id sets from different ranks must be disjoint."
imp[ids] = torch.arange(offset, offset + all_lens[i], **ikw)
offset += all_lens[i]
assert (imp != (2**62-1)*2+1).all(), "some node ids are not owned by any rank."
ids = torch.cat(all_ids, dim=0)
ptr: List[int] = [0]
for s in all_lens:
ptr.append(ptr[-1] + s)
return ids, ptr
@staticmethod
def __remap_adj_t(
dst_ids: Tensor,
edge_index: Tensor,
src_ids: Tensor,
src_ptr: List[int],
edge_attr: Optional[Tensor],
) -> List[SparseTensor]:
ikw = dict(dtype=torch.long, device=dst_ids.device)
imp: Tensor = torch.full((dst_ids.max().item()+1,), (2**62-1)*2+1, **ikw)
imp[dst_ids] = torch.arange(dst_ids.numel(), **ikw)
dst = imp[edge_index[1]]
assert (dst != (2**62-1)*2+1).all()
imp: Tensor = torch.full((src_ids.max().item()+1,), (2**62-1)*2+1, **ikw)
imp[src_ids] = torch.arange(src_ids.numel(), **ikw)
src = imp[edge_index[0]]
assert (src != (2**62-1)*2+1).all()
edge_index = torch.vstack([src, dst])
adj = SparseTensor.from_edge_index(
edge_index=edge_index,
edge_attr=edge_attr,
sparse_sizes=(src_ids.numel(), dst_ids.numel()),
)
adj_ts: List[SparseTensor] = []
for s, t in zip(src_ptr, src_ptr[1:]):
adj_ts.append(adj[s:t].t())
return adj_ts
class SparseBlockMM(autograd.Function):
@staticmethod
def forward(
ctx: autograd.function.FunctionCtx,
sp: SparseBlocks,
x: Tensor,
):
part_id = sp.part_id
num_parts = sp.num_parts
def async_fetch(i: int):
n = sp.adj_t(i).sparse_size(1)
if i == part_id:
h = x.clone()
else:
h = torch.empty(n, *x.shape[1:], dtype=x.dtype, device=x.device)
return dist.broadcast(h, src=i, group=sp.group, async_op=True)
last_work = None
out = None
for i in range(num_parts):
if i == 0:
work = async_fetch(0)
else:
work = last_work
if i + 1 < sp.num_parts:
last_work = async_fetch(i + 1)
work.wait()
h, = work.result()
if out is None:
out = sp.adj_t(i) @ h
else:
out += sp.adj_t(i) @ h
ctx.saved_sp = sp
return out
@staticmethod
def backward(
ctx: autograd.function.FunctionCtx,
grad: Tensor,
):
sp: SparseBlocks = ctx.saved_sp
part_id = sp.part_id
num_parts = sp.num_parts
def async_reduce(i: int, g: Tensor):
return dist.reduce(
g, dst=i, op=dist.ReduceOp.SUM,
group=sp.group, async_op=True,
)
out = None
last_work = None
for i in range(num_parts):
g = sp.adj_t(i).t() @ grad
if i > 0:
last_work.wait()
last_work = async_reduce(i, g)
if i == part_id:
out = g
if last_work is not None:
last_work.wait()
return None, out
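# (added comment) SparseBlocks stores the transposed adjacency split column-wise by the
# rank that owns each source block, so the local result is out_p = sum_i adj_t(i) @ x_i.
# The forward broadcasts one feature block per rank while prefetching the next broadcast
# to overlap communication with the SpMM; the backward mirrors this with one reduce per
# source block, each rank keeping the reduction addressed to its own block.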
......@@ -8,6 +8,9 @@ from torch import Tensor
from typing import *
__all__ = [
"SyncBatchNorm",
]
class SyncBatchNorm(nn.Module):
def __init__(self,
......
# from .functional import train_epoch, eval_epoch
# from .partition import partition_load, partition_save, partition_data
from .printer import sync_print, main_print
from .metrics import all_reduce_loss, accuracy
\ No newline at end of file
from .partition import *
from .uvm import *
\ No newline at end of file