Commit 2177c0ed by Wenjie Huang

add batched all_reduce() in SequencePipe
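For context before the diff, a minimal usage sketch of the new interface. It is not part of the commit: the subclass, the shapes, the import path, and the exact init_states() call form are assumptions, while the constructor defaults, the state_forward()/loss_fn() hooks, enable_all_reduce(), forward() and backward() follow the code added below.

import torch
import torch.nn as nn

from starrygl.parallel.sequence import SequencePipe  # module path assumed, not shown in this diff


class GRUPipe(SequencePipe):
    """Toy per-step GRU pipe; hidden_size and tensor shapes are illustrative only."""

    def __init__(self, batch_size: int, hidden_size: int):
        super().__init__(batch_size)                 # seq_ranks/group default to the whole world
        self.cell = nn.GRUCell(hidden_size, hidden_size)
        # per-sample initial state, batch dimension excluded (call form assumed)
        self.init_states(torch.zeros(hidden_size))

    def state_forward(self, batch_id, inputs, states):
        x, = inputs
        h, = states
        h = self.cell(x, h)
        return (h,), (h,)                            # (outputs, next_states)

    def loss_fn(self, batch_id, outputs):
        y, = outputs
        return y.square().mean()


# With torch.distributed initialized (e.g. launched via torchrun):
#   pipe = GRUPipe(batch_size=32, hidden_size=64).cuda()
#   pipe.enable_all_reduce()            # batched grad/buffer all_reduce after the 1F1B schedule
#   x = torch.randn(256, 64, device="cuda")
#   loss = pipe.backward(x)             # runs forward+backward micro-batches, returns summed loss
#   opt.step()                          # hypothetical optimizer built over pipe.parameters()
#   y, = pipe.forward(x)                # inference-mode full forward pass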

parent 88de1d9c
......@@ -7,6 +7,9 @@ from torch import Tensor
from typing import *
from starrygl.distributed.cclib import batch_send, batch_recv, BatchWork
from abc import ABC, abstractmethod
from .utils import all_reduce_buffers, all_reduce_gradients
__all__ = [
......@@ -14,13 +17,17 @@ __all__ = [
]
class SequencePipe:
class SequencePipe(nn.Module,ABC):
def __init__(self,
batch_size: int,
seq_ranks: List[int],
group: Any,
seq_ranks: Optional[List[int]] = None,
group: Any = None,
) -> None:
super().__init__()
self._batch_size = int(batch_size)
if seq_ranks is None:
seq_ranks = list(range(dist.get_world_size(group)))
self._seq_ranks = tuple(seq_ranks)
self._group = group
......@@ -28,6 +35,9 @@ class SequencePipe:
self._index = self._seq_ranks.index(rank)
self._init_states: Optional[Tuple[Tensor,...]] = None
self._all_reduce_grads_op = None
self._all_reduce_buffers_op = None
@property
def batch_size(self) -> int:
......@@ -55,6 +65,20 @@ class SequencePipe:
self._init_states = tuple(states)
return self
def enable_all_reduce(self,
grads_op = dist.ReduceOp.SUM,
buffers_op = dist.ReduceOp.AVG,
):
self._all_reduce_grads_op = grads_op
self._all_reduce_buffers_op = buffers_op
return self
def disable_all_reduce(self):
self._all_reduce_grads_op = None
self._all_reduce_buffers_op = None
return self
@abstractmethod
def state_forward(self,
batch_id: int,
inputs: Tuple[Tensor,...],
......@@ -62,6 +86,7 @@ class SequencePipe:
) -> Tuple[Tuple[Tensor,...], Tuple[Tensor,...]]:
raise NotImplementedError()
@abstractmethod
def loss_fn(self,
batch_id: int,
outputs: Tuple[Tensor,...],
......@@ -70,111 +95,157 @@ class SequencePipe:
@torch.inference_mode()
def forward(self, *inputs: Tensor) -> Tuple[Tensor,...]:
inputs = self._detach_inputs(inputs)
detach = self._detach_inputs(inputs)
B = inputs[0].size(0)
num_batchs = (B + self.batch_size - 1) // self.batch_size
last_work = None
outputs = []
outputs = None
for batch_id in range(num_batchs):
start = batch_id * self.batch_size
end = min(B, start + self.batch_size)
batch_inputs = self._get_batch_inputs(start, end, inputs)
batch_inputs = self._get_batch_inputs(start, end, detach)
batch_states = self._get_batch_states(end - start)
batch_outputs, _, work = self._forward_inner(batch_id, batch_inputs, batch_states)
outputs.append(batch_outputs)
batch_outputs, batch_states = self._forward_inner(batch_id, batch_inputs, batch_states)
if outputs is None:
outputs = []
for t in batch_outputs:
t = torch.empty(B, *t.shape[1:], dtype=t.dtype, device=t.device)
outputs.append(t)
outputs = tuple(outputs)
for o, t in zip(outputs, batch_outputs):
o[start:end] = t.data
if last_work is not None:
last_work.wait()
last_work = work
concat_outputs = []
for ts in zip(*outputs):
concat_outputs.append(torch.cat(ts, dim=0))
last_work = self._save_states(batch_states)
if last_work is not None:
last_work.wait()
return tuple(concat_outputs)
return outputs
def backward(self, *inputs: Tensor, scale: float = 1.0) -> Tensor:
# inputs = self._detach_inputs(inputs)
detach = self._detach_inputs(inputs)
B = inputs[0].size(0)
num_batchs = (B + self.batch_size - 1) // self.batch_size
fw_batch_id = 0
bw_batch_id = 0
footprint = F1B1Footprint(self.index, self.num_ranks, num_batchs)
source_footprint = F1B1Footprint(self.index-1, self.num_ranks, num_batchs)
target_footprint = F1B1Footprint(self.index+1, self.num_ranks, num_batchs)
bw_offset = 2 * self.num_ranks - self.index - 1
fw_ready: Dict[int, Any] = {}
bw_ready: Dict[int, Any] = {}
last_bw_work = None
input_batch_grads = []
total_loss = None
# hist = []
count = 0
while fw_batch_id < num_batchs or bw_batch_id < num_batchs:
if count >= bw_offset + 2 * bw_batch_id:
if bw_batch_id < num_batchs:
bs, work, *fw_graph = bw_ready.pop(bw_batch_id)
work.wait()
scale_factor = scale * self.batch_size / bs
grads, loss, work = self._backward_inner(bw_batch_id, *fw_graph, scale_factor=scale_factor)
# hist.append(f"{count+1}bw{bw_batch_id + 1}")
last_work = None
if last_bw_work is not None:
last_bw_work.wait()
last_bw_work = work
# input_batch_grads = []
input_grads = None
total_loss = None
total_loss = loss if total_loss is None else total_loss + loss
input_batch_grads.append(grads)
while True:
_, op, batch_id = footprint.step()
_, source_op, source_batch_id = source_footprint.step()
_, target_op, target_batch_id = target_footprint.step()
if op is None and source_op is None and target_op is None:
break
bw_batch_id += 1
else:
fw_starting = (fw_batch_id < self.num_ranks and count >= fw_batch_id + self.index)
fw_occupying = (count >= self.index + 2 * fw_batch_id)
if fw_batch_id < num_batchs and (fw_starting or fw_occupying):
start = fw_batch_id * self.batch_size
end = min(B, start + self.batch_size)
batch_inputs = self._get_batch_inputs(start, end, inputs, requires_grad=True)
batch_states = self._get_batch_states(end - start, requires_grad=True)
batch_outputs, batch_next_states, work = self._forward_inner(fw_batch_id, batch_inputs, batch_states)
# hist.append(f"{count+1}fw{fw_batch_id + 1}")
bw_ready[fw_batch_id] = [
end - start,
work,
batch_inputs,
batch_outputs,
batch_states,
batch_next_states,
]
fw_batch_id += 1
count += 1
if last_work is not None:
last_work.wait()
last_work = None
if source_op == "backward":
input_states, = bw_ready.pop(source_batch_id)
last_work = self._save_states(input_states, grad=True)
del input_states
elif target_op == "forward":
*_, output_states = fw_ready[target_batch_id]
last_work = self._save_states(output_states)
del _, output_states
if op == "forward":
start = batch_id * self.batch_size
end = min(B, start + self.batch_size)
batch_inputs = self._get_batch_inputs(start, end, detach, inputs)
batch_input_states = self._get_batch_states(end - start, requires_grad=True)
batch_outputs, batch_output_states = self._forward_inner(
batch_id,
batch_inputs, batch_input_states,
)
fw_ready[batch_id] = [
batch_inputs,
batch_outputs,
batch_input_states,
batch_output_states,
]
elif op == "backward":
start = batch_id * self.batch_size
end = min(B, start + self.batch_size)
batch_inputs, batch_outputs, batch_input_states, batch_output_states = fw_ready.pop(batch_id)
scale_factor = scale * self.batch_size / (end - start)
grads, loss = self._backward_inner(
batch_id,
batch_inputs,
batch_outputs,
batch_input_states,
batch_output_states,
scale_factor=scale_factor,
)
bw_ready[batch_id] = [
batch_input_states,
]
total_loss = loss if total_loss is None else total_loss + loss
if input_grads is None:
input_grads = []
for t in grads:
if t is not None:
t = torch.empty(B, *t.shape[1:], dtype=t.dtype, device=t.device)
input_grads.append(t)
input_grads = tuple(input_grads)
for g, t in zip(input_grads, grads):
if g is not None:
g[start:end] = t.data
input_grads = []
for ts in zip(*input_batch_grads):
if None in ts:
input_grads.append(None)
else:
input_grads.append(torch.cat(ts, dim=0))
if last_work is not None:
last_work.wait()
if last_bw_work is not None:
last_bw_work.wait()
prev_works = self._prev_inputs_backward()
self._vector_backward(inputs, input_grads)
self._post_inputs_backward(prev_works)
# print(f"{self.index+1}: {hist}")
return total_loss
def _prev_inputs_backward(self):
works = []
works.extend(all_reduce_gradients(
self, op=self._all_reduce_grads_op,
group=self.group, async_op=True,
))
works.extend(all_reduce_buffers(
self, op=self._all_reduce_buffers_op,
group=self.group, async_op=True,
))
return works
def _post_inputs_backward(self, works):
for w in works:
w.wait()
works.clear()
def _load_states(self, states: Tuple[Tensor,...], grad: bool = False):
for s in states:
......@@ -248,19 +319,22 @@ class SequencePipe:
def _get_batch_inputs(self,
start: int, end: int,
inputs: Tuple[Tensor,...],
requires_grad: bool = False
detach: Tuple[Tensor,...],
inputs: Optional[Tuple[Tensor,...]] = None,
) -> Tuple[Tensor,...]:
_inputs: List[Tensor] = []
for t in inputs:
if requires_grad and t.requires_grad:
t = t.detach()[start:end]
outs: List[Tensor] = []
for i, t in enumerate(detach):
assert not t.requires_grad
if inputs is None:
outs.append(t[start:end])
elif inputs[i].requires_grad:
t = t[start:end]
t.requires_grad_()
t.retain_grad()
outs.append(t)
else:
t = t.detach()[start:end]
_inputs.append(t)
return tuple(_inputs)
outs.append(t[start:end])
return tuple(outs)
def _get_batch_states(self, bs: int, requires_grad: bool = False) -> Tuple[Tensor,...]:
assert self._init_states is not None, "please call init_states()."
......@@ -285,12 +359,11 @@ class SequencePipe:
batch_id: int,
inputs: Tuple[Tensor,...],
states: Tuple[Tensor,...],
) -> Tuple[Tuple[Tensor,...], Tuple[Tensor,...], BatchWork]:
) -> Tuple[Tuple[Tensor,...], Tuple[Tensor,...]]:
self._load_states(states)
outputs, next_states = self.state_forward(batch_id, inputs, states)
work = self._save_states(next_states)
return outputs, next_states, work
return outputs, next_states
def _backward_inner(self,
batch_id: int,
......@@ -299,7 +372,7 @@ class SequencePipe:
input_states: Tuple[Tensor,...],
output_states: Tuple[Tensor,...],
scale_factor: float = 1.0,
) -> Tuple[Tuple[Tensor,...], Tensor, BatchWork]:
) -> Tuple[Tuple[Tensor,...], Tensor]:
loss = self.loss_fn(batch_id, outputs)
if scale_factor != 1.0:
loss = loss * scale_factor
......@@ -316,11 +389,49 @@ class SequencePipe:
vec_grad.append(g)
self._vector_backward(vec_loss, vec_grad)
work = self._save_states(input_states, grad=True)
input_grads = []
for t in inputs:
g, t.grad = t.grad, None
input_grads.append(g)
return tuple(input_grads), loss.detach(), work
return tuple(input_grads), loss.detach()
class F1B1Footprint:
def __init__(self,
index: int,
num_ranks: int,
num_batchs: int,
) -> None:
self._index = index
self._num_ranks = num_ranks
self._num_batchs = num_batchs
self._bw_offset = 2 * self._num_ranks - self._index - 1
self._fw_batch_id = 0
self._bw_batch_id = 0
self._count = 0
if index < 0 or index > num_ranks:
self._finish = True
else:
self._finish = False
def step(self) -> Tuple[int, Optional[str], Optional[int]]:
if self._finish:
return (self._count, None, None)
ret = (self._count, "nop", None)
if self._count >= self._bw_offset + 2 * self._bw_batch_id:
ret = (self._count, "backward", self._bw_batch_id)
self._bw_batch_id += 1
elif self._fw_batch_id < self._num_batchs:
if self._count >= self._index + 2 * self._fw_batch_id:
ret = (self._count, "forward", self._fw_batch_id)
self._fw_batch_id += 1
if self._bw_batch_id >= self._num_batchs:
self._finish = True
self._count += 1
return ret
\ No newline at end of file
......@@ -5,16 +5,111 @@ import torch.distributed as dist
from torch import Tensor
from typing import *
from collections import defaultdict
__all__ = [
"all_reduce_gradients",
"all_reduce_buffers",
]
def all_reduce_gradients(net: nn.Module, op = dist.ReduceOp.SUM, group = None):
# def all_reduce_gradients(net: nn.Module, op = dist.ReduceOp.SUM, group = None, async_op: bool = False):
# works = []
# for p in net.parameters():
# if p.grad is None:
# p.grad = torch.zeros_like(p.data)
# w = dist.all_reduce(p.grad, op=op, group=group, async_op=async_op)
# works.append(w)
# if async_op:
# return works
# def all_reduce_buffers(net: nn.Module, op = dist.ReduceOp.AVG, group = None, async_op: bool = False):
# works = []
# for b in net.buffers():
# w = dist.all_reduce(b.data, op=op, group=group, async_op=async_op)
# works.append(w)
# if async_op:
# return works
def all_reduce_gradients(net: nn.Module, op = dist.ReduceOp.SUM, group = None, async_op: bool = False):
device = None
works = []
if op is None:
return works
typed_numel = defaultdict(lambda: 0)
for p in net.parameters():
typed_numel[p.dtype] += p.numel()
device = p.device
if device is None:
return works
typed_tensors: Dict[torch.dtype, Tensor] = {}
for t, n in typed_numel.items():
typed_tensors[t] = torch.zeros(n, dtype=t, device=device)
typed_offset = defaultdict(lambda: 0)
for p in net.parameters():
dist.all_reduce(p.grad, op=op, group=group)
s = typed_offset[p.dtype]
t = s + p.numel()
typed_offset[p.dtype] = t
if p.grad is not None:
typed_tensors[p.dtype][s:t] = p.grad.flatten()
storage = typed_tensors[p.dtype].untyped_storage()
g = torch.empty(0, dtype=p.dtype, device=device)
p.grad = g.set_(storage, s, p.size(), default_stride(*p.size()))
for t in typed_tensors.values():
w = dist.all_reduce(t, op=op, group=group, async_op=async_op)
if async_op:
works.append(w)
return works
def all_reduce_buffers(net: nn.Module, op = dist.ReduceOp.AVG, group = None, async_op: bool = False):
device = None
works = []
if op is None:
return works
typed_numel = defaultdict(lambda: 0)
for p in net.buffers():
typed_numel[p.dtype] += p.numel()
device = p.device
if device is None:
return works
typed_tensors: Dict[torch.dtype, Tensor] = {}
for t, n in typed_numel.items():
typed_tensors[t] = torch.zeros(n, dtype=t, device=device)
typed_offset = defaultdict(lambda: 0)
for p in net.buffers():
s = typed_offset[p.dtype]
t = s + p.numel()
typed_offset[p.dtype] = t
typed_tensors[p.dtype][s:t] = p.flatten()
storage = typed_tensors[p.dtype].untyped_storage()
p.set_(storage, s, p.size(), default_stride(*p.size()))
for t in typed_tensors.values():
w = dist.all_reduce(t, op=op, group=group, async_op=async_op)
if async_op:
works.append(w)
return works
def all_reduce_buffers(net: nn.Module, op = dist.ReduceOp.AVG, group = None):
for b in net.buffers():
dist.all_reduce(b.data, op=op, group=group)
def default_stride(*size: int) -> Tuple[int,...]:
dims = len(size)
stride = [1] * dims
for i in range(1, dims):
k = dims - i
stride[k - 1] = stride[k] * size[k]
return tuple(stride)
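The batched all_reduce above packs all gradients (or buffers) of each dtype into one flat tensor and re-points each p.grad (or buffer) into that buffer with Tensor.set_() and default_stride(), so a single collective per dtype replaces one call per tensor. Below is a simplified standalone sketch of the same idea, not part of the commit, assuming a single dtype and using an explicit copy-back instead of the storage-aliasing trick.

import torch
import torch.nn as nn
import torch.distributed as dist


def batched_all_reduce_grads(net: nn.Module, op=dist.ReduceOp.SUM, group=None):
    # Collect existing gradients (single dtype assumed for brevity).
    grads = [p.grad for p in net.parameters() if p.grad is not None]
    if not grads:
        return
    # One flat buffer and one collective instead of one all_reduce per parameter.
    flat = torch.cat([g.flatten() for g in grads])
    dist.all_reduce(flat, op=op, group=group)
    # Copy the reduced slices back into each gradient tensor.
    offset = 0
    for g in grads:
        g.copy_(flat[offset:offset + g.numel()].view_as(g))
        offset += g.numel()

# Typical call site: once per step, after loss.backward() and before optimizer.step().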