Commit 88de1d9c by Wenjie Huang

SequencePipe: support long-typed tensors

parent 32fec45c
@@ -4,7 +4,7 @@ from torch_geometric.utils import add_remaining_self_loops, to_undirected
 import os.path as osp
 import sys

-from starrygl.graph import GraphData
+from starrygl.data import GraphData

 import logging
 logging.getLogger().setLevel(logging.INFO)
...
@@ -149,6 +149,9 @@ def batch_send(
     group: Any = None,
     async_op: bool = False,
 ):
+    if len(tensors) == 0:
+        return BatchWork(None, None)
+
     # tensors = tuple(t.data for t in tensors)
     backend = dist.get_backend(group)
@@ -171,6 +174,9 @@ def batch_recv(
     group: Any = None,
     async_op: bool = False,
 ):
+    if len(tensors) == 0:
+        return BatchWork(None, None)
+
     # tensors = tuple(t.data for t in tensors)
     backend = dist.get_backend(group)
...
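Note: the new early-return turns batch_send()/batch_recv() into no-ops when the tensor list is empty, which is exactly what happens once every pipeline state is long-typed and there is no gradient to exchange. A minimal sketch of the calling pattern, assuming an initialized process group and a hypothetical peer_rank (not code from this commit):

```python
# Sketch only: peer_rank and group are placeholders; the process group is
# assumed to be initialized elsewhere. With the guard above, sending an
# empty tensor list returns BatchWork(None, None) instead of failing.
work = batch_send(dst=peer_rank, group=group, async_op=True)  # zero tensors passed
work.wait()  # completes immediately
```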
@@ -6,10 +6,14 @@ import torch.distributed as dist
 from torch import Tensor
 from typing import *

+from starrygl.distributed import DistributedContext
 from starrygl.distributed.cclib import batch_send, batch_recv, BatchWork

+__all__ = [
+    "SequencePipe",
+]

 class SequencePipe:
     def __init__(self,
         batch_size: int,
@@ -45,77 +49,12 @@ class SequencePipe:
     def group(self):
         return self._group
-    def _load_states(self, states: Tuple[Tensor,...], grad: bool = False):
-        if grad: # from target to source
-            for s in states:
-                s.grad = torch.zeros_like(s.data)
-            if self.index + 1 < self.num_ranks:
-                batch_recv(
-                    *[s.grad for s in states],
-                    src=self.seq_ranks[self.index + 1],
-                    group=self.group,
-                    async_op=True,
-                ).wait()
-        else: # from source to target
-            for s in states:
-                s.grad = None
-            if self.index > 0:
-                batch_recv(
-                    *[s.data for s in states],
-                    src=self.seq_ranks[self.index - 1],
-                    group=self.group,
-                    async_op=True,
-                ).wait()
-
-    def _save_states(self, states: Tuple[Tensor,...], grad: bool = False) -> BatchWork:
-        if grad: # from target to source
-            if self.index > 0:
-                return batch_send(
-                    *[s.grad for s in states],
-                    dst=self.seq_ranks[self.index - 1],
-                    group=self.group,
-                    async_op=True,
-                )
-        else: # from source to target
-            if self.index + 1 < self.num_ranks:
-                return batch_send(
-                    *[s.data for s in states],
-                    dst=self.seq_ranks[self.index + 1],
-                    group=self.group,
-                    async_op=True
-                )
-        return BatchWork(None, None)

     def init_states(self, *states: Tensor):
         for s in states:
             assert isinstance(s, Tensor), f"states must be tuple of tensors"
         self._init_states = tuple(states)
         return self

-    def batch_size_(self, bs: int):
-        self._batch_size = bs
-        return self
-
-    def _get_batch_states(self, bs: int, requires_grad: bool = False) -> Tuple[Tensor,...]:
-        assert self._init_states is not None, "please call init_states()."
-
-        states: List[Tensor] = []
-        for s in self._init_states:
-            s = s.unsqueeze(0).broadcast_to(bs, *s.size())
-            states.append(s.contiguous())
-
-        if requires_grad and self.index > 0:
-            states = [s.requires_grad_() for s in states]
-        return tuple(states)
-
-    def _detach_inputs(self, inputs: Tuple[Tensor,...]) -> Tuple[Tensor,...]:
-        assert inputs[0].size(0) > 0, "input tensors' batch dimension must be greater than one."
-
-        detach_inputs: List[Tensor] = []
-        for t in inputs:
-            assert t.size(0) == inputs[0].size(0), "all tensors' batch dimension must be the exact same."
-            detach_inputs.append(t.detach())
-        return tuple(detach_inputs)
     def state_forward(self,
         batch_id: int,
         inputs: Tuple[Tensor,...],
@@ -129,58 +68,21 @@ class SequencePipe:
     ) -> Tensor:
         raise NotImplementedError()

-    def _forward_inner(self,
-        batch_id: int,
-        inputs: Tuple[Tensor,...],
-        states: Tuple[Tensor,...],
-    ) -> Tuple[Tuple[Tensor,...], Tuple[Tensor,...], BatchWork]:
-        self._load_states(states)
-        outputs, next_states = self.state_forward(batch_id, inputs, states)
-        work = self._save_states(next_states)
-        return outputs, next_states, work
-
-    def _backward_inner(self,
-        batch_id: int,
-        inputs: Tuple[Tensor,...],
-        outputs: Tuple[Tensor,...],
-        input_states: Tuple[Tensor,...],
-        output_states: Tuple[Tensor,...],
-        scale_factor: float = 1.0,
-    ) -> Tuple[Tuple[Tensor,...], Tensor, BatchWork]:
-        loss = self.loss_fn(batch_id, outputs)
-        if scale_factor != 1.0:
-            loss = loss * scale_factor
-        total_loss = loss
-
-        self._load_states(output_states, grad=True)
-        for s in output_states:
-            g, s.grad = s.grad, None
-            total_loss += torch.sum(s * g)
-        total_loss.backward()
-
-        work = self._save_states(input_states, grad=True)
-
-        input_grads = []
-        for t in inputs:
-            input_grads.append(t.grad)
-            t.grad = None
-        return tuple(input_grads), loss.detach(), work
+    @torch.inference_mode()
     def forward(self, *inputs: Tensor) -> Tuple[Tensor,...]:
         inputs = self._detach_inputs(inputs)

         B = inputs[0].size(0)
         num_batchs = (B + self.batch_size - 1) // self.batch_size

-        with torch.no_grad():
-            outputs = []
         last_work = None
+        outputs = []

         for batch_id in range(num_batchs):
             start = batch_id * self.batch_size
             end = min(B, start + self.batch_size)

-            batch_inputs = tuple(t[start:end] for t in inputs)
+            batch_inputs = self._get_batch_inputs(start, end, inputs)
             batch_states = self._get_batch_states(end - start)

             batch_outputs, _, work = self._forward_inner(batch_id, batch_inputs, batch_states)
             outputs.append(batch_outputs)
@@ -198,8 +100,8 @@
         return tuple(concat_outputs)
-    def backward(self, *inputs: Tensor, scale: float = 1.0) -> Tuple[Tuple[Tensor,...], Tensor]:
-        inputs = self._detach_inputs(inputs)
+    def backward(self, *inputs: Tensor, scale: float = 1.0) -> Tensor:
+        # inputs = self._detach_inputs(inputs)

         B = inputs[0].size(0)
         num_batchs = (B + self.batch_size - 1) // self.batch_size
@@ -211,7 +113,7 @@
         bw_ready: Dict[int, Any] = {}
         last_bw_work = None

-        input_grads = []
+        input_batch_grads = []
         total_loss = None

         # hist = []
@@ -232,7 +134,7 @@
                 last_bw_work = work

                 total_loss = loss if total_loss is None else total_loss + loss
-                input_grads.append(grads)
+                input_batch_grads.append(grads)
                 bw_batch_id += 1
             else:
@@ -241,7 +143,7 @@
             if fw_batch_id < num_batchs and (fw_starting or fw_occupying):
                 start = fw_batch_id * self.batch_size
                 end = min(B, start + self.batch_size)
-                batch_inputs = tuple(t[start:end].requires_grad_() for t in inputs)
+                batch_inputs = self._get_batch_inputs(start, end, inputs, requires_grad=True)
                 batch_states = self._get_batch_states(end - start, requires_grad=True)

                 batch_outputs, batch_next_states, work = self._forward_inner(fw_batch_id, batch_inputs, batch_states)
@@ -258,13 +160,167 @@
                 fw_batch_id += 1
             count += 1

-        concat_grads = []
-        for ts in zip(*input_grads):
-            concat_grads.append(torch.cat(ts, dim=0))
+        input_grads = []
+        for ts in zip(*input_batch_grads):
+            if None in ts:
+                input_grads.append(None)
+            else:
+                input_grads.append(torch.cat(ts, dim=0))

         if last_bw_work is not None:
             last_bw_work.wait()

+        self._vector_backward(inputs, input_grads)
         # print(f"{self.index+1}: {hist}")
-        return tuple(concat_grads), total_loss
+        return total_loss
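Note: backward() no longer detaches its inputs and no longer returns per-input gradients; it now returns only the accumulated loss and pushes the gradients into whatever graph produced the inputs via _vector_backward(). A hypothetical call site (encoder and features are placeholders, not part of this commit):

```python
# encoder is a placeholder upstream module; its parameters receive
# gradients through the non-detached inputs of pipe.backward().
h = encoder(features)
loss = pipe.backward(h)   # returns a Tensor; gradients are already populated
print(float(loss))
```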
+    def _load_states(self, states: Tuple[Tensor,...], grad: bool = False):
+        for s in states:
+            s.grad = None
+
+        if grad: # from target to source
+            if self.index + 1 < self.num_ranks:
+                s_grads = []
+                for s in states:
+                    if s.dtype.is_floating_point:
+                        s.grad = torch.zeros_like(s.data)
+                        s_grads.append(s.grad)
+                batch_recv(
+                    *s_grads,
+                    src=self.seq_ranks[self.index + 1],
+                    group=self.group,
+                    async_op=True,
+                ).wait()
+        else: # from source to target
+            if self.index > 0:
+                batch_recv(
+                    *[s.data for s in states],
+                    src=self.seq_ranks[self.index - 1],
+                    group=self.group,
+                    async_op=True,
+                ).wait()
+
+    def _save_states(self, states: Tuple[Tensor,...], grad: bool = False) -> BatchWork:
+        if grad: # from target to source
+            if self.index > 0:
+                s_grads = []
+                for s in states:
+                    g, s.grad = s.grad, None
+                    if s.dtype.is_floating_point:
+                        s_grads.append(torch.zeros_like(s) if g is None else g)
+                return batch_send(
+                    *s_grads,
+                    dst=self.seq_ranks[self.index - 1],
+                    group=self.group,
+                    async_op=True,
+                )
+        else: # from source to target
+            if self.index + 1 < self.num_ranks:
+                return batch_send(
+                    *[s.data for s in states],
+                    dst=self.seq_ranks[self.index + 1],
+                    group=self.group,
+                    async_op=True
+                )
+        return BatchWork(None, None)
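Note: only floating-point states take part in the gradient exchange; long-typed states are still forwarded with grad=False, but are filtered out of the grad path above. The dtype test is the plain torch.dtype.is_floating_point flag, illustrated here on hypothetical state tensors (not code from the commit):

```python
import torch

# Hypothetical states: a float hidden state and a long step counter.
states = (torch.randn(4, 8), torch.zeros(4, dtype=torch.long))

# Same filtering rule as the grad=True branches above: only float states
# get a gradient buffer for the send/recv.
s_grads = [torch.zeros_like(s) for s in states if s.dtype.is_floating_point]
assert len(s_grads) == 1  # the long tensor contributes nothing
```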
+    def _vector_backward(self, vec_loss: List[Tensor], vec_grad: List[Optional[Tensor]]):
+        loss = []
+        grad = []
+        for x, g in zip(vec_loss, vec_grad):
+            if g is None:
+                continue
+            if not x.requires_grad:
+                continue
+            # if not x.dtype.is_floating_point:
+            #     continue
+            loss.append(x.flatten())
+            grad.append(g.flatten())
+        if len(loss) != 0:
+            loss = torch.cat(loss, dim=0)
+            grad = torch.cat(grad, dim=0)
+            loss.backward(grad)
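Note: _vector_backward() concatenates the flattened tensors and their gradients and runs a single backward(), which is equivalent to backpropagating each (tensor, grad) pair separately. A self-contained sanity check of that equivalence on toy tensors (names are illustrative only):

```python
import torch

x = torch.randn(3, requires_grad=True)
y1 = x * 2
y2 = x.sum() * torch.ones(3)
g1 = torch.ones(3)
g2 = torch.full((3,), 0.5)

# Single backward over concatenated, flattened tensors (the trick above).
torch.cat([y1.flatten(), y2.flatten()]).backward(torch.cat([g1.flatten(), g2.flatten()]))

# Reference: backprop each (output, grad) pair directly.
ref = torch.autograd.grad([y1, y2], [x], [g1, g2])[0]
assert torch.allclose(x.grad, ref)
```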
+    def _get_batch_inputs(self,
+        start: int, end: int,
+        inputs: Tuple[Tensor,...],
+        requires_grad: bool = False
+    ) -> Tuple[Tensor,...]:
+        _inputs: List[Tensor] = []
+        for t in inputs:
+            if requires_grad and t.requires_grad:
+                t = t.detach()[start:end]
+                t.requires_grad_()
+                t.retain_grad()
+            else:
+                t = t.detach()[start:end]
+            _inputs.append(t)
+        return tuple(_inputs)
+
+    def _get_batch_states(self, bs: int, requires_grad: bool = False) -> Tuple[Tensor,...]:
+        assert self._init_states is not None, "please call init_states()."
+
+        states: List[Tensor] = []
+        for s in self._init_states:
+            s = s.unsqueeze(0).broadcast_to(bs, *s.size()).contiguous()
+            if requires_grad and self.index > 0 and s.dtype.is_floating_point:
+                s.requires_grad_()
+                s.retain_grad()
+            states.append(s)
+        return tuple(states)
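Note: _get_batch_states() only calls requires_grad_() on floating-point states because autograd refuses integral dtypes; this guard is the core of the long-tensor support in this commit. Illustration (standalone, not code from the commit):

```python
import torch

try:
    torch.zeros(4, dtype=torch.long).requires_grad_()
except RuntimeError as err:
    print("long states cannot require grad:", err)

torch.zeros(4, dtype=torch.float32).requires_grad_()  # fine
```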
+    def _detach_inputs(self, inputs: Tuple[Tensor,...]) -> Tuple[Tensor,...]:
+        assert inputs[0].size(0) > 0, "input tensors' batch dimension must be greater than one."
+
+        detach_inputs: List[Tensor] = []
+        for t in inputs:
+            assert t.size(0) == inputs[0].size(0), "all tensors' batch dimension must be the exact same."
+            detach_inputs.append(t.detach())
+        return tuple(detach_inputs)
+
+    def _forward_inner(self,
+        batch_id: int,
+        inputs: Tuple[Tensor,...],
+        states: Tuple[Tensor,...],
+    ) -> Tuple[Tuple[Tensor,...], Tuple[Tensor,...], BatchWork]:
+        self._load_states(states)
+        outputs, next_states = self.state_forward(batch_id, inputs, states)
+        work = self._save_states(next_states)
+        return outputs, next_states, work
+    def _backward_inner(self,
+        batch_id: int,
+        inputs: Tuple[Tensor,...],
+        outputs: Tuple[Tensor,...],
+        input_states: Tuple[Tensor,...],
+        output_states: Tuple[Tensor,...],
+        scale_factor: float = 1.0,
+    ) -> Tuple[Tuple[Tensor,...], Tensor, BatchWork]:
+        loss = self.loss_fn(batch_id, outputs)
+        if scale_factor != 1.0:
+            loss = loss * scale_factor
+
+        vec_loss = [loss]
+        vec_grad = [torch.ones_like(loss)]
+
+        self._load_states(output_states, grad=True)
+        for s in output_states:
+            if s.requires_grad:
+                s.retain_grad()
+            g, s.grad = s.grad, None
+            vec_loss.append(s)
+            vec_grad.append(g)
+        self._vector_backward(vec_loss, vec_grad)
+
+        work = self._save_states(input_states, grad=True)
+
+        input_grads = []
+        for t in inputs:
+            g, t.grad = t.grad, None
+            input_grads.append(g)
+        return tuple(input_grads), loss.detach(), work
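Note: the constructor arguments beyond batch_size are collapsed in this diff, so the subclass below is only a hypothetical sketch of how the updated API is meant to be used; its second state tensor is long-typed and, with this commit, is forwarded along the pipeline without joining the gradient exchange.

```python
import torch
from torch import Tensor

class ToyPipe(SequencePipe):
    """Hypothetical subclass: a linear recurrence plus a long step counter."""
    def __init__(self, weight: Tensor, batch_size: int, **kwargs):
        super().__init__(batch_size, **kwargs)   # remaining ctor args not shown in the diff
        self.weight = weight

    def state_forward(self, batch_id, inputs, states):
        x, = inputs
        h, step = states                          # step is a long tensor
        h = torch.tanh(x @ self.weight + h)
        return (h,), (h, step + 1)

    def loss_fn(self, batch_id, outputs):
        h, = outputs
        return h.square().mean()

# Usage sketch (elided ctor args stay elided):
# pipe = ToyPipe(torch.randn(16, 16), batch_size=8, ...).init_states(
#     torch.zeros(16),                            # float state: carries grad
#     torch.zeros((), dtype=torch.long),          # long state: forwarded only
# )
# loss = pipe.backward(x)                         # returns the accumulated loss
```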
...
+import torch
+import torch.nn as nn
+import torch.distributed as dist
+
+from torch import Tensor
+from typing import *
+
+__all__ = [
+    "all_reduce_gradients",
+    "all_reduce_buffers",
+]
+
+def all_reduce_gradients(net: nn.Module, op = dist.ReduceOp.SUM, group = None):
+    for p in net.parameters():
+        dist.all_reduce(p.grad, op=op, group=group)
+
+def all_reduce_buffers(net: nn.Module, op = dist.ReduceOp.AVG, group = None):
+    for b in net.buffers():
+        dist.all_reduce(b.data, op=op, group=group)
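Note: the new helpers mirror the usual data-parallel reduction step. A hypothetical training-loop fragment (net, opt, loss, and the process group are assumed to exist elsewhere; this is not code from the commit):

```python
# Sketch: after the local backward pass, sum parameter gradients and
# average buffers (e.g. BatchNorm running statistics) across ranks.
loss.backward()                                   # or SequencePipe.backward(...)
all_reduce_gradients(net, op=dist.ReduceOp.SUM, group=None)
all_reduce_buffers(net, op=dist.ReduceOp.AVG, group=None)
opt.step()
opt.zero_grad()
```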