Commit 2cdc59cc by Wenjie Huang

optimize SequencePipe's communication

parent 7302be6b
import torch
from torch import Tensor
from contextlib import contextmanager
from typing import *
__all__ = [
"ABCStream",
"ABCEvent",
"phony_tensor",
"new_stream",
"current_stream",
"default_stream",
"use_stream",
"use_device",
"wait_stream",
"wait_event",
"record_stream",
]
class CPUStreamType:
def __init__(self) -> None:
self._device = torch.device("cpu")
@property
def device(self):
return self._device
def __call__(self):
return self
class CPUEventType:
def __init__(self) -> None:
self._device = torch.device("cpu")
@property
def device(self):
return self._device
def __call__(self):
return self
CPUStream = CPUStreamType()
ABCStream = Union[torch.cuda.Stream, CPUStreamType]
CPUEvent = CPUEventType()
ABCEvent = Union[torch.cuda.Event, CPUEventType]
def new_stream(device: Any) -> ABCStream:
device = torch.device(device)
if device.type != "cuda":
return CPUStream()
return torch.cuda.Stream(device)
_phonies: Dict[Tuple[torch.device, bool], Tensor] = {}
def phony_tensor(device: Any, requires_grad: bool = True):
device = torch.device(device)
key = (device, requires_grad)
if key not in _phonies:
with use_stream(default_stream(device)):
_phonies[key] = torch.empty(
0, device=device,
requires_grad=requires_grad,
)
return _phonies[key]
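# Usage note (ours, not part of the original file): phonies are cached per
# (device, requires_grad) pair, so repeated calls return the same zero-element
# tensor and never allocate again.
#
#   >>> phony_tensor("cpu") is phony_tensor("cpu")
#   True
#   >>> phony_tensor("cpu").numel()
#   0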
def current_stream(device: Any) -> ABCStream:
device = torch.device(device)
if device.type != "cuda":
return CPUStream()
return torch.cuda.current_stream(device)
def default_stream(device: Any) -> ABCStream:
device = torch.device(device)
if device.type != "cuda":
return CPUStream()
return torch.cuda.default_stream(device)
@contextmanager
def use_stream(stream: ABCStream, fence_event: bool = False):
if isinstance(stream, CPUStreamType):
if fence_event:
event = CPUEvent()
yield event
else:
yield
return
with torch.cuda.stream(stream):
if fence_event:
event = torch.cuda.Event()
yield event
event.record()
else:
yield
@contextmanager
def use_device(device: Any):
device = torch.device(device)
if device.type != "cuda":
yield
return
with torch.cuda.device(device):
yield
def wait_stream(source: ABCStream, target: ABCStream):
if isinstance(target, CPUStreamType):
return
if isinstance(source, CPUStreamType):
target.synchronize()
else:
source.wait_stream(target)
def wait_event(source: ABCStream, target: ABCEvent):
if isinstance(target, CPUEventType):
return
if isinstance(source, CPUStreamType):
target.synchronize()
else:
source.wait_event(target)
def record_stream(tensor: Tensor, stream: ABCStream):
if isinstance(stream, CPUStreamType):
return
storage = tensor.untyped_storage()
tensor = tensor.new_empty(0).set_(storage)
tensor.record_stream(stream)
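# Usage sketch (ours, not part of the original file): how these helpers compose.
# `new_stream` degrades to the CPUStream singleton on non-CUDA devices, so the same
# code path works with or without a GPU; `_demo_streams` is an illustrative name.
def _demo_streams(device: Any = "cuda" if torch.cuda.is_available() else "cpu") -> Tensor:
    copy_stream = new_stream(device)              # side stream (CPUStream on CPU)
    x = torch.ones(4, device=device)
    with use_stream(copy_stream):
        y = x * 2                                 # launched on the side stream if CUDA
    # make the ambient stream wait for the side stream before consuming `y` ...
    wait_stream(current_stream(device), copy_stream)
    # ... and tell the caching allocator that `y` is also used on the ambient stream
    record_stream(y, current_stream(device))
    return y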
from .route import *
-from .sequence import *
+from .timeline import SequencePipe
from .sparse import *
\ No newline at end of file
import torch
import torch.nn as nn
import torch.autograd as autograd
import torch.distributed as dist
from torch import Tensor
from typing import *
from starrygl.distributed.cclib import *
from abc import ABC, abstractmethod
from contextlib import contextmanager
__all__ = [
"SequencePipe",
]
OptTensor = Optional[Tensor]
SInteger = Sequence[int]
STensor = Sequence[Tensor]
SOptTensor = Sequence[OptTensor]
PairSTensor = Tuple[STensor, STensor]
OptSTensor = Optional[STensor]
class SequencePipe(ABC):
def __init__(self) -> None:
super().__init__()
self._pos_begin = 0
self._pos_end = 0
self._pos_start = True
@abstractmethod
def get_group(self) -> Any:
raise NotImplementedError
@abstractmethod
def get_init_states(self) -> STensor:
raise NotImplementedError
@abstractmethod
def forward(self, inputs: STensor, states: STensor) -> PairSTensor:
raise NotImplementedError
def loss_fn(self, inputs: STensor, labels: STensor) -> Tensor:
raise NotImplementedError
def get_ranks(self) -> SInteger:
world_size = dist.get_world_size(self.get_group())
return tuple(range(world_size))
def apply(self, bs: int, *inputs: Tensor):
return SequencePipeFunction.apply(bs, self, *inputs)
def fast_backward(self, bs: int, inputs: STensor, labels: STensor) -> Tensor:
runtime = SequencePipeRuntime(bs, self, last_layer=True)
inputs_grads = runtime.backward(inputs, labels)
runtime.vector_backward(inputs, inputs_grads)
return runtime.acc_loss
@property
def begin(self) -> int:
return self._pos_begin
@property
def end(self) -> int:
return self._pos_end
@property
def start(self) -> bool:
return self._pos_start
@property
def batch_size(self) -> int:
return self._pos_end - self._pos_begin
@contextmanager
def _switch(self,
begin: int, end: int, start: bool,
):
saved_begin = self._pos_begin
saved_end = self._pos_end
saved_start = self._pos_start
self._pos_begin = begin
self._pos_end = end
self._pos_start = start
yield
self._pos_begin = saved_begin
self._pos_end = saved_end
self._pos_start = saved_start
class SequencePipeRuntime:
def __init__(self,
micro_batch_size: int,
program: SequencePipe,
last_layer: bool = False,
) -> None:
self.micro_batch_size = micro_batch_size
self.program = program
self.last_layer = last_layer
self.acc_loss = None
self.group = program.get_group()
self.ranks = program.get_ranks()
self.index = self.ranks.index(dist.get_rank(self.group))
self._last_work = None
def forward(self, inputs: STensor) -> STensor:
detach = self.detach_inputs(inputs)
N = inputs[0].size(0)
S = (N + self.micro_batch_size - 1) // self.micro_batch_size
outputs = None
for i in range(S):
begin = i * self.micro_batch_size
end = min(N, begin + self.micro_batch_size)
start = (self.index == 0)
with self.program._switch(begin, end, start):
batch_inputs = self.get_batch_inputs(begin, end, detach)
batch_states = self.get_batch_states(begin, end)
batch_outputs, batch_states = self.forward_inner(batch_inputs, batch_states)
if outputs is None:
outputs = []
for t in batch_outputs:
t = torch.empty(N, *t.shape[1:], dtype=t.dtype, device=t.device)
outputs.append(t)
outputs = tuple(outputs)
for t, b in zip(outputs, batch_outputs):
t[begin:end] = b.data
self.wait_last_work(
next_one=self.save_states(batch_states),
)
self.wait_last_work()
return outputs
def backward(self, inputs: STensor, output_grads: OptSTensor = None) -> STensor:
detach = self.detach_inputs(inputs)
detach_grads = self.detach_inputs(output_grads)
N = inputs[0].size(0)
S = (N + self.micro_batch_size - 1) // self.micro_batch_size
footprint = F1B1Footprint(self.index, len(self.ranks), S)
source_footprint = F1B1Footprint(self.index-1, len(self.ranks), S)
target_footprint = F1B1Footprint(self.index+1, len(self.ranks), S)
fw_ready = {}
bw_ready = {}
input_grads = None
while True:
_, op, i = footprint.step()
_, source_op, source_i = source_footprint.step()
_, target_op, target_i = target_footprint.step()
if (not op) and (not source_op) and (not target_op):
break
self.wait_last_work()
if source_op == "backward":
batch_input_state_grads, = bw_ready.pop(source_i)
self.wait_last_work(
next_one=self.save_grads(batch_input_state_grads),
)
elif target_op == "forward":
*_, batch_output_states = fw_ready[target_i]
self.wait_last_work(
next_one=self.save_states(batch_output_states),
)
del _
begin = i * self.micro_batch_size
end = min(N, begin + self.micro_batch_size)
start = (self.index == 0)
with self.program._switch(begin, end, start):
if op == "forward":
batch_inputs = self.get_batch_inputs(begin, end, detach, inputs)
batch_input_states = self.get_batch_states(begin, end, requires_grad=True)
batch_outputs, batch_output_states = \
self.forward_inner(batch_inputs, batch_input_states)
fw_ready[i] = [
batch_inputs, batch_outputs,
batch_input_states, batch_output_states,
]
elif op == "backward":
batch_inputs, batch_outputs,\
batch_input_states, batch_output_states = fw_ready.pop(i)
batch_output_grads = self.get_batch_inputs(begin, end, detach_grads)
batch_input_grads, batch_input_state_grads = \
self.backward_inner(
batch_inputs, batch_outputs,
batch_input_states, batch_output_states,
batch_output_grads,
)
bw_ready[i] = [
batch_input_state_grads,
]
if input_grads is None:
input_grads = []
for t in batch_input_grads:
if t is not None:
t = torch.empty(N, *t.shape[1:], dtype=t.dtype, device=t.device)
input_grads.append(t)
input_grads = tuple(input_grads)
for g, t in zip(input_grads, batch_input_grads):
if g is not None:
g[begin:end] = t.data
self.wait_last_work()
return input_grads
def wait_last_work(self, next_one = None):
if self._last_work is not None:
self._last_work.wait()
self._last_work = next_one
def detach_inputs(self, inputs: STensor) -> STensor:
detach: STensor = []
for t in inputs:
assert t.size(0) == inputs[0].size(0), "The first dimension of all tensors must be the same."
detach.append(t.detach())
return detach
def get_batch_inputs(self,
begin: int, end: int,
detach: STensor, inputs: OptSTensor = None,
) -> STensor:
batch: STensor = []
for i, t in enumerate(detach):
assert not t.requires_grad
if inputs and inputs[i].requires_grad:
t = t[begin:end]
t.requires_grad_()
t.retain_grad()
batch.append(t)
else:
batch.append(t[begin:end])
return batch
def get_batch_states(self,
begin: int, end: int,
requires_grad: bool = False,
) -> STensor:
states = []
for s in self.program.get_init_states():
s = s.unsqueeze(0).broadcast_to(
end - begin, *s.size()).contiguous()
if requires_grad and self.index > 0 and s.is_floating_point():
s.requires_grad_()
s.retain_grad()
states.append(s)
return states
def forward_inner(self,
inputs: STensor,
states: STensor,
):
states = self.load_states(states)
return self.program.forward(inputs, states)
def backward_inner(self,
inputs: STensor,
outputs: STensor,
input_states: STensor,
output_states: STensor,
output_grads: STensor,
) -> PairSTensor:
vloss = []
vgrad = []
if self.last_layer:
vloss.append(self.program.loss_fn(outputs, output_grads))
vgrad.append(torch.ones_like(vloss[-1]))
if self.acc_loss is None:
self.acc_loss = vloss[-1].detach()
else:
self.acc_loss += vloss[-1].detach()
else:
vloss.extend(outputs)
vgrad.extend(output_grads)
vloss.extend(output_states)
vgrad.extend(self.load_grads(output_states))
self.vector_backward(vloss, vgrad)
input_grads = []
for t in inputs:
g, t.grad = t.grad, None
input_grads.append(g)
input_state_grads = []
for s in input_states:
g, s.grad = s.grad, None
if s.is_floating_point():
g = torch.zeros_like(s) if g is None else g
else:
g = None
input_state_grads.append(g)
return input_grads, input_state_grads
def load_states(self, states: STensor):
if self.index > 0:
batch_recv(
*[s.data for s in states],
src=self.ranks[self.index - 1],
group=self.group,
async_op=True,
).wait()
return states
def load_grads(self, states: STensor):
grads: SOptTensor = []
if self.index + 1 < len(self.ranks):
for s in states:
if s.is_floating_point():
g = torch.zeros_like(s)
grads.append(g)
else:
grads.append(None)
batch_recv(
*[g.data for g in grads if g is not None],
src=self.ranks[self.index + 1],
group=self.group,
async_op=True,
).wait()
else:
for s in states:
grads.append(None)
return grads
def save_states(self, states: STensor):
if self.index + 1 < len(self.ranks):
return batch_send(
*[s.data for s in states],
dst=self.ranks[self.index + 1],
group=self.group,
async_op=True,
)
else:
return BatchWork(None, None)
def save_grads(self, grads: STensor):
if self.index > 0:
return batch_send(
*[g.data for g in grads if g is not None],
dst=self.ranks[self.index - 1],
group=self.group,
async_op=True,
)
else:
return BatchWork(None, None)
def vector_backward(self, vloss: STensor, vgrad: STensor):
loss: List[Tensor] = []
grad: List[Tensor] = []
for x, g in zip(vloss, vgrad):
if g is None:
continue
if not x.requires_grad:
continue
loss.append(x.flatten())
grad.append(g.flatten())
if loss:
loss = torch.cat(loss, dim=0)
grad = torch.cat(grad, dim=0)
loss.backward(grad)
class SequencePipeFunction(autograd.Function):
@staticmethod
def forward(
ctx: autograd.function.FunctionCtx,
micro_batch_size: int,
program: SequencePipe,
*inputs: Tensor,
):
runtime = SequencePipeRuntime(
micro_batch_size, program
)
ctx.save_for_backward(*inputs)
ctx.saved_runtime = runtime
return runtime.forward(inputs)
@staticmethod
def backward(
ctx: autograd.function.FunctionCtx,
*grads: Tensor,
):
inputs: STensor = ctx.saved_tensors
runtime: SequencePipeRuntime = ctx.saved_runtime
with torch.enable_grad():
input_grads = runtime.backward(inputs, grads)
return None, None, *input_grads
# class SequencePipe(nn.Module,ABC):
# def __init__(self,
# batch_size: int,
# seq_ranks: Optional[List[int]] = None,
# group: Any = None,
# ) -> None:
# super().__init__()
# self._batch_size = int(batch_size)
# if seq_ranks is None:
# seq_ranks = list(range(dist.get_world_size(group)))
# self._seq_ranks = tuple(seq_ranks)
# self._group = group
# rank = dist.get_rank(group)
# self._index = self._seq_ranks.index(rank)
# self._init_states: Optional[Tuple[Tensor,...]] = None
# self._all_reduce_grads_op = None
# self._all_reduce_buffers_op = None
# self._all_reduce_group = None
# @property
# def batch_size(self) -> int:
# return self._batch_size
# @property
# def seq_ranks(self):
# return self._seq_ranks
# @property
# def num_ranks(self) -> int:
# return len(self._seq_ranks)
# @property
# def index(self) -> int:
# return self._index
# @property
# def group(self):
# return self._group
# def init_states(self, *states: Tensor):
# for s in states:
# assert isinstance(s, Tensor), f"states must be tuple of tensors"
# self._init_states = tuple(states)
# return self
# def enable_all_reduce(self,
# grads_op = dist.ReduceOp.SUM,
# buffers_op = dist.ReduceOp.AVG,
# group: Any = None,
# ):
# self._all_reduce_grads_op = grads_op
# self._all_reduce_buffers_op = buffers_op
# self._all_reduce_group = group
# return self
# def disable_all_reduce(self):
# self._all_reduce_grads_op = None
# self._all_reduce_buffers_op = None
# self._all_reduce_group = None
# return self
# @abstractmethod
# def state_forward(self,
# batch_id: int,
# inputs: Tuple[Tensor,...],
# states: Tuple[Tensor,...],
# ) -> Tuple[Tuple[Tensor,...], Tuple[Tensor,...]]:
# raise NotImplementedError()
# @abstractmethod
# def loss_fn(self,
# batch_id: int,
# outputs: Tuple[Tensor,...],
# ) -> Tensor:
# raise NotImplementedError()
# @torch.inference_mode()
# def forward(self, *inputs: Tensor) -> Tuple[Tensor,...]:
# detach = self._detach_inputs(inputs)
# B = inputs[0].size(0)
# num_batchs = (B + self.batch_size - 1) // self.batch_size
# last_work = None
# outputs = None
# for batch_id in range(num_batchs):
# start = batch_id * self.batch_size
# end = min(B, start + self.batch_size)
# batch_inputs = self._get_batch_inputs(start, end, detach)
# batch_states = self._get_batch_states(end - start)
# batch_outputs, batch_states = self._forward_inner(batch_id, batch_inputs, batch_states)
# if outputs is None:
# outputs = []
# for t in batch_outputs:
# t = torch.empty(B, *t.shape[1:], dtype=t.dtype, device=t.device)
# outputs.append(t)
# outputs = tuple(outputs)
# for o, t in zip(outputs, batch_outputs):
# o[start:end] = t.data
# if last_work is not None:
# last_work.wait()
# last_work = self._save_states(batch_states)
# if last_work is not None:
# last_work.wait()
# return outputs
# def backward(self, *inputs: Tensor, scale: float = 1.0) -> Tensor:
# detach = self._detach_inputs(inputs)
# B = inputs[0].size(0)
# num_batchs = (B + self.batch_size - 1) // self.batch_size
# footprint = F1B1Footprint(self.index, self.num_ranks, num_batchs)
# source_footprint = F1B1Footprint(self.index-1, self.num_ranks, num_batchs)
# target_footprint = F1B1Footprint(self.index+1, self.num_ranks, num_batchs)
# fw_ready: Dict[int, Any] = {}
# bw_ready: Dict[int, Any] = {}
# last_work = None
# # input_batch_grads = []
# input_grads = None
# total_loss = None
# while True:
# _, op, batch_id = footprint.step()
# _, source_op, source_batch_id = source_footprint.step()
# _, target_op, target_batch_id = target_footprint.step()
# if op is None and source_op is None and target_op is None:
# break
# if last_work is not None:
# last_work.wait()
# last_work = None
# if source_op == "backward":
# input_states, = bw_ready.pop(source_batch_id)
# last_work = self._save_states(input_states, grad=True)
# del input_states
# elif target_op == "forward":
# *_, output_states = fw_ready[target_batch_id]
# last_work = self._save_states(output_states)
# del _, output_states
# if op == "forward":
# start = batch_id * self.batch_size
# end = min(B, start + self.batch_size)
# batch_inputs = self._get_batch_inputs(start, end, detach, inputs)
# batch_input_states = self._get_batch_states(end - start, requires_grad=True)
# batch_outputs, batch_output_states = self._forward_inner(
# batch_id,
# batch_inputs, batch_input_states,
# )
# fw_ready[batch_id] = [
# batch_inputs,
# batch_outputs,
# batch_input_states,
# batch_output_states,
# ]
# elif op == "backward":
# start = batch_id * self.batch_size
# end = min(B, start + self.batch_size)
# batch_inputs, batch_outputs, batch_input_states, batch_output_states = fw_ready.pop(batch_id)
# scale_factor = scale * self.batch_size / (end - start)
# grads, loss = self._backward_inner(
# batch_id,
# batch_inputs,
# batch_outputs,
# batch_input_states,
# batch_output_states,
# scale_factor=scale_factor,
# )
# bw_ready[batch_id] = [
# batch_input_states,
# ]
# total_loss = loss if total_loss is None else total_loss + loss
# if input_grads is None:
# input_grads = []
# for t in grads:
# if t is not None:
# t = torch.empty(B, *t.shape[1:], dtype=t.dtype, device=t.device)
# input_grads.append(t)
# input_grads = tuple(input_grads)
# for g, t in zip(input_grads, grads):
# if g is not None:
# g[start:end] = t.data
# if last_work is not None:
# last_work.wait()
# prev_works = self._prev_inputs_backward()
# self._vector_backward(inputs, input_grads)
# self._post_inputs_backward(prev_works)
# return total_loss
# def _prev_inputs_backward(self):
# works = []
# works.extend(all_reduce_gradients(
# self, op=self._all_reduce_grads_op,
# group=self._all_reduce_group, async_op=True,
# ))
# works.extend(all_reduce_buffers(
# self, op=self._all_reduce_buffers_op,
# group=self._all_reduce_group, async_op=True,
# ))
# return works
# def _post_inputs_backward(self, works):
# for w in works:
# w.wait()
# works.clear()
# def _load_states(self, states: Tuple[Tensor,...], grad: bool = False):
# for s in states:
# s.grad = None
# if grad: # from target to source
# if self.index + 1 < self.num_ranks:
# s_grads = []
# for s in states:
# if s.dtype.is_floating_point:
# s.grad = torch.zeros_like(s.data)
# s_grads.append(s.grad)
# batch_recv(
# *s_grads,
# src=self.seq_ranks[self.index + 1],
# group=self.group,
# async_op=True,
# ).wait()
# else: # from source to target
# if self.index > 0:
# batch_recv(
# *[s.data for s in states],
# src=self.seq_ranks[self.index - 1],
# group=self.group,
# async_op=True,
# ).wait()
# def _save_states(self, states: Tuple[Tensor,...], grad: bool = False) -> BatchWork:
# if grad: # from target to source
# if self.index > 0:
# s_grads = []
# for s in states:
# g, s.grad = s.grad, None
# if s.dtype.is_floating_point:
# s_grads.append(torch.zeros_like(s) if g is None else g)
# return batch_send(
# *s_grads,
# dst=self.seq_ranks[self.index - 1],
# group=self.group,
# async_op=True,
# )
# else: # from source to target
# if self.index + 1 < self.num_ranks:
# return batch_send(
# *[s.data for s in states],
# dst=self.seq_ranks[self.index + 1],
# group=self.group,
# async_op=True
# )
# return BatchWork(None, None)
# def _vector_backward(self, vec_loss: List[Tensor], vec_grad: List[Optional[Tensor]]):
# loss = []
# grad = []
# for x, g in zip(vec_loss, vec_grad):
# if g is None:
# continue
# if not x.requires_grad:
# continue
# # if not x.dtype.is_floating_point:
# # continue
# loss.append(x.flatten())
# grad.append(g.flatten())
# if len(loss) != 0:
# loss = torch.cat(loss, dim=0)
# grad = torch.cat(grad, dim=0)
# loss.backward(grad)
# def _get_batch_inputs(self,
# start: int, end: int,
# detach: Tuple[Tensor,...],
# inputs: Optional[Tuple[Tensor,...]] = None,
# ) -> Tuple[Tensor,...]:
# outs: List[Tensor] = []
# for i, t in enumerate(detach):
# assert not t.requires_grad
# if inputs is None:
# outs.append(t[start:end])
# elif inputs[i].requires_grad:
# t = t[start:end]
# t.requires_grad_()
# t.retain_grad()
# outs.append(t)
# else:
# outs.append(t[start:end])
# return tuple(outs)
# def _get_batch_states(self, bs: int, requires_grad: bool = False) -> Tuple[Tensor,...]:
# assert self._init_states is not None, "please call init_states()."
# states: List[Tensor] = []
# for s in self._init_states:
# s = s.unsqueeze(0).broadcast_to(bs, *s.size()).contiguous()
# if requires_grad and self.index > 0 and s.dtype.is_floating_point:
# s.requires_grad_()
# s.retain_grad()
# states.append(s)
# return tuple(states)
# def _detach_inputs(self, inputs: Tuple[Tensor,...]) -> Tuple[Tensor,...]:
# assert inputs[0].size(0) > 0, "input tensors' batch dimension must be greater than one."
# detach_inputs: List[Tensor] = []
# for t in inputs:
# assert t.size(0) == inputs[0].size(0), "all tensors' batch dimension must be the exact same."
# detach_inputs.append(t.detach())
# return tuple(detach_inputs)
# def _forward_inner(self,
# batch_id: int,
# inputs: Tuple[Tensor,...],
# states: Tuple[Tensor,...],
# ) -> Tuple[Tuple[Tensor,...], Tuple[Tensor,...]]:
# self._load_states(states)
# outputs, next_states = self.state_forward(batch_id, inputs, states)
# return outputs, next_states
# def _backward_inner(self,
# batch_id: int,
# inputs: Tuple[Tensor,...],
# outputs: Tuple[Tensor,...],
# input_states: Tuple[Tensor,...],
# output_states: Tuple[Tensor,...],
# scale_factor: float = 1.0,
# ) -> Tuple[Tuple[Tensor,...], Tensor]:
# loss = self.loss_fn(batch_id, outputs)
# if scale_factor != 1.0:
# loss = loss * scale_factor
# vec_loss = [loss]
# vec_grad = [torch.ones_like(loss)]
# self._load_states(output_states, grad=True)
# for s in output_states:
# if s.requires_grad:
# s.retain_grad()
# g, s.grad = s.grad, None
# vec_loss.append(s)
# vec_grad.append(g)
# self._vector_backward(vec_loss, vec_grad)
# input_grads = []
# for t in inputs:
# g, t.grad = t.grad, None
# input_grads.append(g)
# return tuple(input_grads), loss.detach()
class F1B1Footprint:
def __init__(self,
index: int,
num_ranks: int,
num_batchs: int,
) -> None:
self._index = index
self._num_ranks = num_ranks
self._num_batchs = num_batchs
self._bw_offset = 2 * self._num_ranks - self._index - 1
self._fw_batch_id = 0
self._bw_batch_id = 0
self._count = 0
if index < 0 or index >= num_ranks:
self._finish = True
else:
self._finish = False
def step(self) -> Tuple[int, Optional[str], int]:
if self._finish:
return (self._count, None, -1)
ret = (self._count, "nop", -1)
if self._count >= self._bw_offset + 2 * self._bw_batch_id:
ret = (self._count, "backward", self._bw_batch_id)
self._bw_batch_id += 1
elif self._fw_batch_id < self._num_batchs:
if self._count >= self._index + 2 * self._fw_batch_id:
ret = (self._count, "forward", self._fw_batch_id)
self._fw_batch_id += 1
if self._bw_batch_id >= self._num_batchs:
self._finish = True
self._count += 1
return ret
\ No newline at end of file
from .pipe import SequencePipe
\ No newline at end of file
import torch
import torch.nn as nn
import torch.autograd as autograd
import torch.distributed as dist
from torch import Tensor
from typing import *
from abc import ABC, abstractmethod
from contextlib import contextmanager
from .sync import VirtualMotions, VirtualForward, BatchSync
from .utils import vector_backward
class SequencePipe(ABC):
def __init__(self) -> None:
super().__init__()
self._pos_begin = 0
self._pos_end = 0
@abstractmethod
def get_group(self) -> Any:
raise NotImplementedError
@abstractmethod
def get_init_states(self) -> Union[Tensor, Sequence[Tensor]]:
raise NotImplementedError
@abstractmethod
def forward(self,
inputs: Sequence[Tensor],
states: Sequence[Tensor],
) -> Tuple[
Sequence[Tensor],
Sequence[Tensor],
]:
raise NotImplementedError
def loss_fn(self,
inputs: Sequence[Tensor],
labels: Sequence[Tensor],
) -> Tensor:
raise NotImplementedError
def get_ranks(self) -> Sequence[int]:
world_size = dist.get_world_size(self.get_group())
return tuple(range(world_size))
def get_model(self) -> Sequence[Tuple[str, nn.Module]]:
models = []
for key in dir(self):
val = getattr(self, key)
if isinstance(val, nn.Module):
models.append((key, val))
return tuple(models)
def apply(self, bs: int, *inputs: Tensor) -> Sequence[Tensor]:
runtime = SequencePipeRuntime(bs, self)
return SequencePipeFunction.apply(runtime, *inputs)
def fast_backward(self,
bs: int,
inputs: Sequence[Tensor],
labels: Sequence[Tensor],
) -> Optional[Tensor]:
runtime = SequencePipeRuntime(bs, self, use_fast_backward=True)
inputs_grads = runtime.backward(inputs, labels)
vector_backward(inputs, inputs_grads)
return runtime.acc_loss
@property
def begin(self) -> int:
return self._pos_begin
@property
def end(self) -> int:
return self._pos_end
@property
def batch_size(self) -> int:
return self._pos_end - self._pos_begin
@contextmanager
def _switch_batch(self, begin: int, end: int):
saved_begin = self._pos_begin
saved_end = self._pos_end
self._pos_begin = begin
self._pos_end = end
yield
self._pos_begin = saved_begin
self._pos_end = saved_end
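# Minimal concrete example (ours, not part of the original commit): the smallest
# SequencePipe that threads a running sum along the sequence split across ranks.
# `PrefixSumPipe` and its fields are illustrative names, not starrygl API.
class PrefixSumPipe(SequencePipe):
    def __init__(self, group: Any, hidden: int) -> None:
        super().__init__()
        self._group = group
        self._hidden = hidden

    def get_group(self) -> Any:
        return self._group

    def get_init_states(self) -> Sequence[Tensor]:
        # one zero state row; the runtime broadcasts it to the micro-batch size
        return (torch.zeros(self._hidden),)

    def forward(self, inputs: Sequence[Tensor], states: Sequence[Tensor]):
        x, = inputs            # micro-batch slice of the tensor passed to apply()
        s, = states            # running state received from the previous rank
        s = s + x              # carry the running sum on to the next rank
        return (s,), (s,)      # (outputs, next states)

# Intended call sites (hedged): `pipe.apply(bs, x)` runs the pipelined forward in
# micro-batches of `bs` rows and is autograd-aware; `pipe.fast_backward(bs, (x,), (y,))`
# additionally evaluates `loss_fn` on the last rank (it must be overridden for that)
# and returns the accumulated, detached loss.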
class SequencePipeRuntime:
def __init__(self,
micro_batch_size: int,
program: SequencePipe,
use_fast_backward: bool = False,
) -> None:
self.micro_batch_size = micro_batch_size
self.program = program
self.use_fast_backward = use_fast_backward
self.acc_loss = None
self.group = program.get_group()
self.ranks = program.get_ranks()
self.index = self.ranks.index(dist.get_rank(self.group))
self._last_work = None
def detach_inputs(self, inputs: Sequence[Tensor]) -> Sequence[Tensor]:
detach = []
for t in inputs:
assert t.size(0) == inputs[0].size(0), "The first dimension of all tensors must be the same."
detach.append(t.detach())
return detach
def get_begin_end(self, i: int, n: int) -> Tuple[int, int]:
begin = i * self.micro_batch_size
end = min(n, begin + self.micro_batch_size)
return begin, end
def get_batch_sync(self, tensors: Sequence[Tensor], device: Any) -> BatchSync:
return BatchSync(
*tensors,
seq_index=self.index,
seq_ranks=self.ranks,
group=self.group, device=device,
)
def forward(self, inputs: Sequence[Tensor]) -> Sequence[Tensor]:
detach = self.detach_inputs(inputs)
N = inputs[0].size(0)
S = (N + self.micro_batch_size - 1) // self.micro_batch_size
motion = VirtualForward(self.index, len(self.ranks), S, batch_vsz=3)
outputs = None
ready_recv: Dict[int, BatchSync] = {}
ready_send: Dict[int, BatchSync] = {}
while not motion.finished:
for op, i in motion.step_comp():
begin, end = self.get_begin_end(i, N)
if op == "forward":
batch_inputs = self.get_batch_inputs(begin, end, detach)
if self.index > 0:
batch_states = ready_recv.pop(i).wait_state()
else:
batch_states = self.get_batch_states(begin, end)
with self.program._switch_batch(begin, end):
batch_outputs, batch_states = \
self.program.forward(batch_inputs, batch_states)
if self.index + 1 < len(self.ranks):
ready_send[i] = self.get_batch_sync(
batch_states,
device=detach[0].device,
)
del batch_inputs, batch_states
if outputs is None:
outputs = []
for t in batch_outputs:
t = torch.empty(N, *t.shape[1:], dtype=t.dtype, device=t.device)
outputs.append(t)
outputs = tuple(outputs)
for t, b in zip(outputs, batch_outputs):
t[begin:end] = b.detach()
del batch_outputs
for op, type, i in motion.step_sync():
assert type == "state"
begin, end = self.get_begin_end(i, N)
if op == "send":
ready_send.pop(i).send_state()
elif op == "recv":
ready_recv[i] = self.get_batch_sync(
self.get_batch_states(begin, end),
device=detach[0].device,
)
ready_recv[i].recv_state()
assert not ready_recv
assert not ready_send
return outputs
def backward(self,
inputs: Sequence[Tensor],
gradients: Sequence[Tensor],
) -> Sequence[Tensor]:
detach = self.detach_inputs(inputs)
detach_grads = self.detach_inputs(gradients)
N = inputs[0].size(0)
S = (N + self.micro_batch_size - 1) // self.micro_batch_size
motions = VirtualMotions(self.index, len(self.ranks), S, batch_vsz=3)
ready_recv_s: Dict[int, BatchSync] = {}
ready_recv_g: Dict[int, BatchSync] = {}
ready_send_s: Dict[int, BatchSync] = {}
ready_send_g: Dict[int, BatchSync] = {}
ready_bw_cmp = {}
input_grads = [None] * len(detach)
while not motions.finished:
for op, i in motions.step_comp():
begin, end = self.get_begin_end(i, N)
if op == "forward":
batch_inputs = self.get_batch_inputs(begin, end, detach, inputs)
if self.index > 0:
ready_send_g[i] = ready_recv_s.pop(i)
batch_states = ready_send_g[i].wait_state()
else:
batch_states = self.get_batch_states(begin, end)
with self.program._switch_batch(begin, end):
batch_outputs, batch_states = self.program.forward(batch_inputs, batch_states)
if self.index + 1 < len(self.ranks):
ready_send_s[i] = self.get_batch_sync(
batch_states,
device=detach[0].device,
)
ready_recv_g[i] = ready_send_s[i]
ready_bw_cmp[i] = (batch_inputs, batch_outputs, batch_states)
del batch_inputs, batch_outputs, batch_states
elif op == "backward":
batch_output_grads = self.get_batch_inputs(begin, end, detach_grads)
batch_inputs, batch_outputs, batch_states = ready_bw_cmp.pop(i)
if self.use_fast_backward:
with self.program._switch_batch(begin, end):
vec_loss = [self.program.loss_fn(batch_outputs, batch_output_grads)]
vec_grad = [torch.ones_like(vec_loss[0])]
if self.acc_loss is None:
self.acc_loss = vec_loss[0].detach()
else:
self.acc_loss += vec_loss[0].detach()
else:
vec_loss = list(batch_outputs)
vec_grad = list(batch_output_grads)
del batch_outputs, batch_output_grads
if self.index + 1 < len(self.ranks):
batch_state_grads = ready_recv_g.pop(i).wait_grads()
vec_loss.extend(batch_states)
vec_grad.extend(batch_state_grads)
del batch_state_grads
del batch_states
vector_backward(vec_loss, vec_grad)
# use a separate name so the micro-batch index `i` from step_comp() is not shadowed
for j, t in enumerate(batch_inputs):
g, t.grad = t.grad, None
if g is None:
continue
if input_grads[j] is None:
input_grads[j] = torch.zeros(N, *g.shape[1:], dtype=g.dtype, device=g.device)
input_grads[j][begin:end] = g
del batch_inputs
for op, type, i in motions.step_sync():
begin, end = self.get_begin_end(i, N)
if op == "send":
if type == "state":
ready_send_s.pop(i).send_state()
elif type == "grads":
ready_send_g.pop(i).send_grads()
elif op == "recv":
if type == "state":
ready_recv_s[i] = self.get_batch_sync(
self.get_batch_states(begin, end),
device=detach[0].device,
)
ready_recv_s[i].recv_state()
elif type == "grads":
ready_recv_g[i].recv_grads()
assert not ready_recv_s
assert not ready_recv_g
assert not ready_send_s
assert not ready_send_g
assert not ready_bw_cmp
return input_grads
def get_batch_inputs(self,
begin: int, end: int,
detach: Sequence[Tensor],
inputs: Optional[Sequence[Tensor]] = None,
) -> Sequence[Tensor]:
batch = []
for i, t in enumerate(detach):
assert not t.requires_grad
t = t[begin:end]
if inputs and inputs[i].requires_grad:
t.requires_grad_()
t.retain_grad()
batch.append(t)
return batch
def get_batch_states(self,
begin: int, end: int,
) -> Sequence[Tensor]:
states = []
for s in self.program.get_init_states():
s = s.unsqueeze(0).broadcast_to(
end - begin, *s.size(),
).contiguous()
states.append(s)
return states
class SequencePipeFunction(autograd.Function):
@staticmethod
def forward(
ctx: autograd.function.FunctionCtx,
runtime: SequencePipeRuntime,
*inputs: Tensor,
):
ctx.save_for_backward(*inputs)
ctx.saved_runtime = runtime
return runtime.forward(inputs)
@staticmethod
def backward(
ctx: autograd.function.FunctionCtx,
*grads: Tensor,
):
inputs: Sequence[Tensor] = ctx.saved_tensors
runtime: SequencePipeRuntime = ctx.saved_runtime
with torch.enable_grad():
input_grads = runtime.backward(inputs, grads)
return None, *input_grads
\ No newline at end of file
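# Pattern note (ours, not part of the original commit): SequencePipeFunction.backward
# returns one gradient slot per forward argument, with None for the non-tensor
# `runtime` slot. A tiny self-contained instance of the same autograd.Function pattern:
import torch
import torch.autograd as autograd
from torch import Tensor

class _ScaleByConst(autograd.Function):
    @staticmethod
    def forward(ctx, factor: float, x: Tensor) -> Tensor:
        ctx.factor = factor
        return x * factor

    @staticmethod
    def backward(ctx, grad: Tensor):
        # one slot per forward argument: None for `factor`, scaled grad for `x`
        return None, grad * ctx.factor

# x = torch.ones(3, requires_grad=True)
# _ScaleByConst.apply(2.0, x).sum().backward()   # leaves x.grad == tensor([2., 2., 2.])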
import torch
import torch.distributed as dist
from torch import Tensor
from typing import *
class BatchSync:
def __init__(self,
*state: Tensor,
seq_index: int,
seq_ranks: Optional[List[int]] = None,
group: Any = None,
device: Any = None,
) -> None:
self._state = state
self._grads = [None] * len(self._state)
self._rgrad = torch.tensor(
[t.requires_grad for t in self._state],
dtype=torch.bool, device=device,
)
self._seq_index = int(seq_index)
if group is None:
group = dist.GroupMember.WORLD
self._group = group
self._device = torch.device(device)
if seq_ranks is None:
group_size = dist.get_world_size(group)
seq_ranks = range(group_size)
self._seq_ranks: Tuple[int,...] = tuple(seq_ranks)
self._works = []
def zip_for_backward(self, *grads: Optional[Tensor]):
assert len(grads) == len(self._state)
self._grads = grads
vec_loss, vec_grad = [], []
for s, g in zip(self._state, self._grads):
if s.requires_grad:
vec_loss.append(s)
vec_grad.append(g)
return vec_loss, vec_grad
def wait_state(self) -> Sequence[Tensor]:
for w in self._works:
w.wait()
self._works.clear()
rgrad = self._rgrad.tolist()
for r, t in zip(rgrad, self._state):
assert t.is_leaf
t.requires_grad_(r)
return self._state
def wait_grads(self) -> Sequence[Tensor]:
for w in self._works:
w.wait()
self._works.clear()
assert self._grads is not None
return self._grads
def send_state(self):
if not self._state:
return
if self._seq_index + 1 >= len(self._seq_ranks):
return
dst = self._seq_ranks[self._seq_index + 1]
dst = dist.get_global_rank(self._group, dst)
dist.isend(self._rgrad, dst=dst, group=self._group)
for t in self._state:
dist.isend(t, dst=dst, group=self._group)
def send_grads(self):
if not self._state:
return
if self._seq_index <= 0:
return
rgrad = self._rgrad.tolist()
dst = self._seq_ranks[self._seq_index - 1]
dst = dist.get_global_rank(self._group, dst)
for r, t in zip(rgrad, self._state):
if not r:
continue
g, t.grad = t.grad, None
if g is None:
g = torch.zeros_like(t)
dist.isend(g, dst=dst, group=self._group)
def recv_state(self):
if not self._state:
return
if self._seq_index <= 0:
return
src = self._seq_ranks[self._seq_index - 1]
src = dist.get_global_rank(self._group, src)
self._works.append(
dist.irecv(self._rgrad, src=src, group=self._group)
)
for t in self._state:
self._works.append(
dist.irecv(t, src=src, group=self._group)
)
def recv_grads(self):
if not self._state:
return
if self._seq_index + 1 >= len(self._seq_ranks):
return
rgrad = self._rgrad.tolist()
src = self._seq_ranks[self._seq_index + 1]
src = dist.get_global_rank(self._group, src)
for i, (r, t) in enumerate(zip(rgrad, self._state)):
if not r:
self._grads[i] = None
continue
if self._grads[i] is None:
self._grads[i] = torch.empty_like(t)
self._works.append(
dist.irecv(self._grads[i], src=src, group=self._group)
)
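# Handshake sketch (ours): how neighbouring ranks are expected to pair the calls
# above. The requires_grad mask (`_rgrad`) travels ahead of the state tensors so the
# receiver can restore which leaves participate in backward. Not runnable standalone;
# it assumes an initialized process group with at least two ranks.
#
#   rank i   : sync = BatchSync(*states, seq_index=i, group=g, device=dev)
#              sync.send_state()               # isend mask + state tensors to rank i+1
#   rank i+1 : sync = BatchSync(*placeholders, seq_index=i+1, group=g, device=dev)
#              sync.recv_state()               # irecv into the placeholder tensors
#              states = sync.wait_state()      # wait, then re-apply requires_grad flags
#
#   The gradient leg runs the other way: rank i reuses the BatchSync it sent states
#   from and calls recv_grads()/wait_grads(), while rank i+1 answers with send_grads()
#   on the BatchSync it received the states into.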
class VirtualForward:
def __init__(self,
index: int,
num_ranks: int,
max_count: int,
batch_vsz: int = 2,
) -> None:
assert batch_vsz > 0
self._max_count = max_count
self._bs = batch_vsz
vmax_count = (max_count + batch_vsz - 1) // batch_vsz
self._motions: List[ForwardGenerator] = []
for _ in range(batch_vsz):
self._motions.append(
ForwardGenerator(index, num_ranks, vmax_count)
)
self._step_count = 0
@property
def finished(self):
return self._motions[self._step_count].finished
def step_comp(self):
for op, i in self._motions[self._step_count].step_comp():
k = i * self._bs + self._step_count
if k < self._max_count:
yield op, k
def step_sync(self):
for op, d, i in self._motions[self._step_count].step_sync():
k = i * self._bs + self._step_count
if k < self._max_count:
yield op, d, k
self._step_count += 1
self._step_count %= self._bs
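# Schedule sketch (ours, illustrative only): drive a VirtualForward exactly the way
# the runtime does and print the compute / communication actions it emits per tick.
def _print_virtual_forward(index: int = 1, num_ranks: int = 3,
                           num_batches: int = 5, batch_vsz: int = 3) -> None:
    vf = VirtualForward(index, num_ranks, num_batches, batch_vsz=batch_vsz)
    tick = 0
    while not vf.finished:
        comp = list(vf.step_comp())   # e.g. [("forward", 4)]
        sync = list(vf.step_sync())   # e.g. [("send", "state", 1), ("recv", "state", 4)]
        print(f"tick {tick}: comp={comp} sync={sync}")
        tick += 1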
class ForwardGenerator:
def __init__(self,
index: int,
num_ranks: int,
max_count: int,
) -> None:
self._index = index
self._num_ranks = num_ranks
self._dst_fp = ForwardFootprint(index+1, num_ranks, max_count)
self._fp = ForwardFootprint(index, num_ranks, max_count)
self._dst_fp.step()
_, op, i = self._fp.step()
self._last_action = op, i
self._finished = False
@property
def finished(self):
t = self._dst_fp.finished
k = self._fp.finished
op, _ = self._last_action
return t and k and not op
def step_comp(self):
if self.finished:
return
op, i = self._last_action
self._last_action = None, -1
if op == "forward":
yield "forward", i
def step_sync(self):
if self.finished:
return
_, dst_op, dst_i = self._dst_fp.step()
_, op, i = self._fp.step()
self._last_action = op, i
if dst_op == "forward":
yield "send", "state", dst_i
if op == "forward" and self._index > 0:
yield "recv", "state", i
class ForwardFootprint:
def __init__(self,
index: int,
num_ranks: int,
max_count: int,
) -> None:
self._index = index
self._num_ranks = num_ranks
self._max_count = max_count
self._fw_batch_id = 0
self._count = 0
if index < 0 or index >= num_ranks:
self._finished = True
else:
self._finished = False
@property
def finished(self):
return self._finished
def step(self) -> Tuple[int, Optional[str], int]:
if self._finished:
return (self._count, None, -1)
ret = (self._count, "nop", -1)
if self._count == self._index + self._fw_batch_id:
ret = (self._count, "forward", self._fw_batch_id)
self._fw_batch_id += 1
if self._fw_batch_id >= self._max_count:
self._finished = True
self._count += 1
return ret
class VirtualMotions:
def __init__(self,
index: int,
num_ranks: int,
max_count: int,
batch_vsz: int = 2,
) -> None:
assert batch_vsz > 0
self._max_count = max_count
self._bs = batch_vsz
vmax_count = (max_count + batch_vsz - 1) // batch_vsz
self._motions: List[MotionGenerator] = []
for _ in range(batch_vsz):
self._motions.append(
MotionGenerator(index, num_ranks, vmax_count)
)
self._step_count = 0
@property
def finished(self):
return self._motions[self._step_count].finished
def step_comp(self):
for op, i in self._motions[self._step_count].step_comp():
k = i * self._bs + self._step_count
if k < self._max_count:
yield op, k
def step_sync(self):
for op, d, i in self._motions[self._step_count].step_sync():
k = i * self._bs + self._step_count
if k < self._max_count:
yield op, d, k
self._step_count += 1
self._step_count %= self._bs
class MotionGenerator:
def __init__(self,
index: int,
num_ranks: int,
max_count: int,
) -> None:
self._index = index
self._num_ranks = num_ranks
self._src_fp = F1B1Footprint(index-1, num_ranks, max_count)
self._dst_fp = F1B1Footprint(index+1, num_ranks, max_count)
self._fp = F1B1Footprint(index, num_ranks, max_count)
self._src_fp.step()
self._dst_fp.step()
_, op, i = self._fp.step()
self._last_action = op, i
self._finished = False
@property
def finished(self):
s = self._src_fp.finished
t = self._dst_fp.finished
k = self._fp.finished
op, _ = self._last_action
return s and t and k and not op
def step_comp(self):
if self.finished:
return
op, i = self._last_action
self._last_action = None, -1
if op == "forward":
yield "forward", i
elif op == "backward":
yield "backward", i
def step_sync(self):
if self.finished:
return
_, src_op, src_i = self._src_fp.step()
_, dst_op, dst_i = self._dst_fp.step()
_, op, i = self._fp.step()
self._last_action = op, i
if op == "backward" and \
self._index + 1 < self._num_ranks:
yield "recv", "grads", i
if src_op == "backward":
assert dst_op != "forward"
yield "send", "grads", src_i
elif dst_op == "forward":
assert src_op != "backward"
yield "send", "state", dst_i
if op == "forward" and self._index > 0:
yield "recv", "state", i
class F1B1Footprint:
def __init__(self,
index: int,
num_ranks: int,
max_count: int,
) -> None:
self._index = index
self._num_ranks = num_ranks
self._max_count = max_count
self._bw_offset = 2 * self._num_ranks - self._index - 1
self._fw_batch_id = 0
self._bw_batch_id = 0
self._count = 0
if index < 0 or index >= num_ranks:
self._finished = True
else:
self._finished = False
@property
def finished(self):
return self._finished
def step(self) -> Tuple[int, Optional[str], int]:
if self._finished:
return (self._count, None, -1)
ret = (self._count, "nop", -1)
if self._count >= self._bw_offset + 2 * self._bw_batch_id:
ret = (self._count, "backward", self._bw_batch_id)
self._bw_batch_id += 1
elif self._fw_batch_id < self._max_count:
if self._count >= self._index + 2 * self._fw_batch_id:
ret = (self._count, "forward", self._fw_batch_id)
self._fw_batch_id += 1
if self._bw_batch_id >= self._max_count:
self._finished = True
self._count += 1
return ret
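# Timetable sketch (ours, illustrative only): print the raw 1F1B schedule emitted by
# F1B1Footprint for every rank of a small pipeline, one column per rank.
def _print_f1b1_schedule(num_ranks: int = 3, num_batches: int = 4) -> None:
    footprints = [F1B1Footprint(r, num_ranks, num_batches) for r in range(num_ranks)]
    while not all(fp.finished for fp in footprints):
        row = []
        for fp in footprints:
            _, op, i = fp.step()
            row.append(f"{op}:{i}" if op and op != "nop" else ".")
        print(" | ".join(f"{cell:>12}" for cell in row))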
import torch
from torch import Tensor
from typing import *
def vector_backward(
vec_loss: Sequence[Tensor],
vec_grad: Sequence[Tensor],
):
loss: List[Tensor] = []
grad: List[Tensor] = []
for x, g in zip(vec_loss, vec_grad):
if g is None:
continue
if not x.requires_grad:
continue
loss.append(x.flatten())
grad.append(g.flatten())
if loss:
loss = torch.cat(loss, dim=0)
grad = torch.cat(grad, dim=0)
loss.backward(grad)
\ No newline at end of file
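# Usage sketch (ours, not part of the original commit): vector_backward fuses several
# (tensor, gradient) pairs into a single autograd call, which is how the pipeline
# pushes the output and state gradients of one micro-batch back in one backward().
def _demo_vector_backward() -> None:
    a = torch.randn(3, requires_grad=True)
    b = torch.randn(2, requires_grad=True)
    y1, y2 = a * 2, b * 3
    vector_backward([y1, y2], [torch.ones_like(y1), torch.ones_like(y2)])
    assert torch.allclose(a.grad, torch.full_like(a, 2.0))
    assert torch.allclose(b.grad, torch.full_like(b, 3.0))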
@@ -10,7 +10,6 @@ from typing import *
from starrygl.distributed import DistributedContext
from starrygl.data import GraphData
from starrygl.parallel import Route, SequencePipe
-from starrygl.parallel.sequence import STensor
from starrygl.parallel.utils import *
import torch_geometric.nn as pyg_nn
@@ -129,7 +128,6 @@ if __name__ == "__main__":
hybrid_matrix = ctx.get_hybrid_matrix()
if hybrid_matrix.size(0) == 1:
hybrid_matrix = hybrid_matrix.view(2, -1)
-ctx.sync_print(hybrid_matrix)
# sp is sequence parallel
# pp is partition parallel
sp_group, pp_group = ctx.new_hybrid_subgroups(hybrid_matrix)