Commit 09c175e2 by Wenjie Huang

add demo train_hybrid.py

parent 057dc57f
@@ -152,8 +152,11 @@ def batch_send(
if len(tensors) == 0:
return BatchWork(None, None)
if group is None:
group = dist.GroupMember.WORLD
# tensors = tuple(t.data for t in tensors)
backend = dist.get_backend(group)
dst = dist.get_global_rank(group, dst)
if async_op:
works = []
@@ -177,8 +180,11 @@ def batch_recv(
if len(tensors) == 0:
return BatchWork(None, None)
if group is None:
group = dist.GroupMember.WORLD
# tensors = tuple(t.data for t in tensors)
backend = dist.get_backend(group)
src = dist.get_global_rank(group, src)
if async_op:
works = []
...
@@ -7,6 +7,7 @@ import os
from torch import Tensor
from typing import *
import socket
from contextlib import contextmanager
import logging
@@ -127,6 +128,18 @@ class DistributedContext:
self._local_rank = local_rank
self._compute_device = device
self._hostname = socket.gethostname()
if self.device.type == "cuda":
torch.cuda.set_device(self.device)
rank_to_host = [None] * self.world_size
dist.all_gather_object(rank_to_host, (self.hostname, self.local_rank))
self._rank_to_host: Tuple[Tuple[str, int], ...] = tuple(rank_to_host)
host_index = [h for h, _ in self.rank_to_host]
host_index.sort()
self._host_index: Dict[str, int] = {h:i for i, h in enumerate(host_index)}
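# rank_to_host maps every global rank to its (hostname, local_rank) pair,
# collected with all_gather_object; host_index then assigns each distinct
# hostname a stable index based on the sorted order of hostnames.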
self.__temp_ag_remote_object: Optional[rpc.RRef] = None
@@ -147,15 +160,86 @@ class DistributedContext:
return self._local_rank
@property
def hostname(self) -> str:
return self._hostname
@property
def rank_to_host(self):
return self._rank_to_host
@property
def host_index(self):
return self._host_index
@property
def device(self) -> torch.device:
return self._compute_device
def get_default_group(self):
# return dist.distributed_c10d._get_default_group()
return dist.GroupMember.WORLD
def get_default_store(self):
return dist.distributed_c10d._get_default_store()
def get_ranks_by_host(self, hostname: Optional[str] = None) -> Tuple[int,...]:
if hostname is None:
hostname = self.hostname
ranks: List[int] = []
for i, (h, r) in enumerate(self.rank_to_host):
if h == hostname:
ranks.append(i)
ranks.sort()
return tuple(ranks)
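# get_ranks_by_host returns the sorted global ranks running on the given
# host (defaulting to the caller's own host).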
def get_ranks_by_local(self, local_rank: Optional[int] = None) -> Tuple[int,...]:
if local_rank is None:
local_rank = self.local_rank
ranks: List[Tuple[int, str]] = []
for i, (h, r) in enumerate(self.rank_to_host):
if r == local_rank:
ranks.append((i, h))
ranks.sort(key=lambda x: self.host_index[x[1]])
return tuple(i for i, h in ranks)
def get_hybrid_matrix(self) -> Tensor:
hosts = sorted(self.host_index.items(), key=lambda x: x[1])
matrix = []
for h, _ in hosts:
rs = self.get_ranks_by_host(h)
matrix.append(rs)
return torch.tensor(matrix, dtype=torch.long, device="cpu")
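# The hybrid matrix has one row per host and one column per local position.
# For example, two hosts with two processes each and ranks launched host by
# host (an assumption about the launcher) would give [[0, 1], [2, 3]]. Note
# that torch.tensor requires every host to contribute the same number of ranks.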
def new_hybrid_subgroups(self,
matrix: Optional[Tensor] = None,
backend: Any = None,
) -> Tuple[Any, Any]:
if matrix is None:
matrix = self.get_hybrid_matrix()
assert matrix.dim() == 2
row_group = None
col_group = None
for row in matrix.tolist():
if self.rank in row:
row_group = dist.new_group(
row, backend=backend,
use_local_synchronization=True,
)
break
for col in matrix.t().tolist():
if self.rank in col:
col_group = dist.new_group(
col, backend=backend,
use_local_synchronization=True,
)
break
assert row_group is not None
assert col_group is not None
return row_group, col_group
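# Each rank joins exactly one row group (all ranks on its host) and one
# column group (the ranks sharing its column across hosts). With
# use_local_synchronization=True only the members of a new group need to
# call dist.new_group, which is what lets each rank create just the two
# groups it belongs to here.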
def get_worker_info(self, rank: Optional[int] = None) -> rpc.WorkerInfo:
rank = dist.get_rank() if rank is None else rank
return rpc.get_worker_info(f"worker{rank}")
...
@@ -24,6 +24,9 @@ class Route:
bipartite: bool = True,
group: Any = None,
) -> 'Route':
if group is None:
group = dist.GroupMember.WORLD
fw_tables, bw_tables = Route._build_route_tables(
src_ids=src_ids, dst_ids=dst_ids,
bipartite=bipartite, group=group,
@@ -256,8 +259,11 @@ class Route:
all_dst_ids[i] = dst_ids
else:
all_dst_ids[i] = torch.empty(all_dst_lens[i], **ikw)
src_rank = dist.get_global_rank(group, i)
all_dst_get[i] = dist.broadcast(
all_dst_ids[i], src=src_rank,
async_op=True, group=group,
)
fw_tables: List[Tensor] = []
...
@@ -6,394 +6,788 @@ import torch.distributed as dist
from torch import Tensor
from typing import *
from starrygl.distributed.cclib import *
from abc import ABC, abstractmethod
from contextlib import contextmanager
from .utils import all_reduce_buffers, all_reduce_gradients
__all__ = [
"SequencePipe",
]
OptTensor = Optional[Tensor]
SInteger = Sequence[int]
STensor = Sequence[Tensor]
SOptTensor = Sequence[OptTensor]
PairSTensor = Tuple[STensor, STensor]
OptSTensor = Optional[STensor]
class SequencePipe(ABC):
def __init__(self) -> None:
super().__init__()
self._pos_begin = 0
self._pos_end = 0
self._pos_start = True
@abstractmethod
def get_group(self) -> Any:
raise NotImplementedError
@abstractmethod
def get_init_states(self) -> STensor:
raise NotImplementedError
@abstractmethod
def forward(self, inputs: STensor, states: STensor) -> PairSTensor:
raise NotImplementedError
def loss_fn(self, inputs: STensor, labels: STensor) -> Tensor:
raise NotImplementedError
def get_ranks(self) -> SInteger:
world_size = dist.get_world_size(self.get_group())
return tuple(range(world_size))
def apply(self, bs: int, *inputs: Tensor):
return SequencePipeFunction.apply(bs, self, *inputs)
def fast_backward(self, bs: int, inputs: STensor, labels: STensor) -> Tensor:
runtime = SequencePipeRuntime(bs, self, last_layer=True)
inputs_grads = runtime.backward(inputs, labels)
runtime.vector_backward(inputs, inputs_grads)
return runtime.acc_loss
@property
def begin(self) -> int:
return self._pos_begin
@property
def end(self) -> int:
return self._pos_end
@property
def start(self) -> bool:
return self._pos_start
@property
def batch_size(self) -> int:
return self._pos_end - self._pos_begin
@contextmanager
def _switch(self,
begin: int, end: int, start: bool,
):
saved_begin = self._pos_begin
saved_end = self._pos_end
saved_start = self._pos_start
self._pos_begin = begin
self._pos_end = end
self._pos_start = start
yield
self._pos_begin = saved_begin
self._pos_end = saved_end
self._pos_start = saved_start
class SequencePipeRuntime:
def __init__(self,
micro_batch_size: int,
program: SequencePipe,
last_layer: bool = False,
) -> None:
self.micro_batch_size = micro_batch_size
self.program = program
self.last_layer = last_layer
self.acc_loss = None
self.group = program.get_group()
self.ranks = program.get_ranks()
self.index = self.ranks.index(dist.get_rank(self.group))
self._last_work = None
def forward(self, inputs: STensor) -> STensor:
detach = self.detach_inputs(inputs)
N = inputs[0].size(0)
S = (N + self.micro_batch_size - 1) // self.micro_batch_size
outputs = None
for i in range(S):
begin = i * self.micro_batch_size
end = min(N, begin + self.micro_batch_size)
start = (self.index == 0)
with self.program._switch(begin, end, start):
batch_inputs = self.get_batch_inputs(begin, end, detach)
batch_states = self.get_batch_states(begin, end)
batch_outputs, batch_states = self.forward_inner(batch_inputs, batch_states)
if outputs is None:
outputs = []
for t in batch_outputs:
t = torch.empty(N, *t.shape[1:], dtype=t.dtype, device=t.device)
outputs.append(t)
outputs = tuple(outputs)
for t, b in zip(outputs, batch_outputs):
t[begin:end] = b.data
self.wait_last_work(
next_one=self.save_states(batch_states),
)
self.wait_last_work()
return outputs
def backward(self, inputs: STensor, output_grads: OptSTensor = None) -> STensor:
detach = self.detach_inputs(inputs)
detach_grads = self.detach_inputs(output_grads)
N = inputs[0].size(0)
S = (N + self.micro_batch_size - 1) // self.micro_batch_size
footprint = F1B1Footprint(self.index, len(self.ranks), S)
source_footprint = F1B1Footprint(self.index-1, len(self.ranks), S)
target_footprint = F1B1Footprint(self.index+1, len(self.ranks), S)
fw_ready = {}
bw_ready = {}
input_grads = None
while True:
_, op, i = footprint.step()
_, source_op, source_i = source_footprint.step()
_, target_op, target_i = target_footprint.step()
if (not op) and (not source_op) and (not target_op):
break
self.wait_last_work()
if source_op == "backward":
batch_input_state_grads, = bw_ready.pop(source_i)
self.wait_last_work(
next_one=self.save_grads(batch_input_state_grads),
)
elif target_op == "forward":
*_, batch_output_states = fw_ready[target_i]
self.wait_last_work(
next_one=self.save_states(batch_output_states),
)
del _
begin = i * self.micro_batch_size
end = min(N, begin + self.micro_batch_size)
start = (self.index == 0)
with self.program._switch(begin, end, start):
if op == "forward":
batch_inputs = self.get_batch_inputs(begin, end, detach, inputs)
batch_input_states = self.get_batch_states(begin, end, requires_grad=True)
batch_outputs, batch_output_states = \
self.forward_inner(batch_inputs, batch_input_states)
fw_ready[i] = [
batch_inputs, batch_outputs,
batch_input_states, batch_output_states,
]
elif op == "backward":
batch_inputs, batch_outputs,\
batch_input_states, batch_output_states = fw_ready.pop(i)
batch_output_grads = self.get_batch_inputs(begin, end, detach_grads)
batch_input_grads, batch_input_state_grads = \
self.backward_inner(
batch_inputs, batch_outputs,
batch_input_states, batch_output_states,
batch_output_grads,
)
bw_ready[i] = [
batch_input_state_grads,
]
if input_grads is None:
input_grads = []
for t in batch_input_grads:
if t is not None:
t = torch.empty(N, *t.shape[1:], dtype=t.dtype, device=t.device)
input_grads.append(t)
input_grads = tuple(input_grads)
for g, t in zip(input_grads, batch_input_grads):
if g is not None:
g[begin:end] = t.data
self.wait_last_work()
return input_grads
def wait_last_work(self, next_one = None):
if self._last_work is not None:
self._last_work.wait()
self._last_work = next_one
def detach_inputs(self, inputs: STensor) -> STensor:
detach: STensor = []
for t in inputs:
assert t.size(0) == inputs[0].size(0), "The first dimension of all tensors must be the same."
detach.append(t.detach())
return detach
def get_batch_inputs(self,
begin: int, end: int,
detach: STensor, inputs: OptSTensor = None,
) -> STensor:
batch: STensor = []
for i, t in enumerate(detach):
assert not t.requires_grad
if inputs and inputs[i].requires_grad:
t = t[begin:end]
t.requires_grad_()
t.retain_grad()
batch.append(t)
else:
batch.append(t[begin:end])
return batch
def get_batch_states(self,
begin: int, end: int,
requires_grad: bool = False,
) -> STensor:
states = []
for s in self.program.get_init_states():
s = s.unsqueeze(0).broadcast_to(
end - begin, *s.size()).contiguous()
if requires_grad and self.index > 0 and s.is_floating_point():
s.requires_grad_()
s.retain_grad()
states.append(s)
return states
def forward_inner(self,
inputs: STensor,
states: STensor,
):
states = self.load_states(states)
return self.program.forward(inputs, states)
def backward_inner(self,
inputs: STensor,
outputs: STensor,
input_states: STensor,
output_states: STensor,
output_grads: STensor,
) -> PairSTensor:
vloss = []
vgrad = []
if self.last_layer:
vloss.append(self.program.loss_fn(outputs, output_grads))
vgrad.append(torch.ones_like(vloss[-1]))
if self.acc_loss is None:
self.acc_loss = vloss[-1].detach()
else:
self.acc_loss += vloss[-1].detach()
else:
vloss.extend(outputs)
vgrad.extend(output_grads)
vloss.extend(output_states)
vgrad.extend(self.load_grads(output_states))
self.vector_backward(vloss, vgrad)
input_grads = []
for t in inputs:
g, t.grad = t.grad, None
input_grads.append(g)
input_state_grads = []
for s in input_states:
g, s.grad = s.grad, None
if s.is_floating_point():
g = torch.zeros_like(s) if g is None else g
else:
g = None
input_state_grads.append(g)
return input_grads, input_state_grads
def load_states(self, states: STensor):
if self.index > 0:
batch_recv(
*[s.data for s in states],
src=self.ranks[self.index - 1],
group=self.group,
async_op=True,
).wait()
return states
def load_grads(self, states: STensor):
grads: SOptTensor = []
if self.index + 1 < len(self.ranks):
for s in states:
if s.is_floating_point():
g = torch.zeros_like(s)
grads.append(g)
else:
grads.append(None)
batch_recv(
*[g.data for g in grads if g is not None],
src=self.ranks[self.index + 1],
group=self.group,
async_op=True,
).wait()
else:
for s in states:
grads.append(None)
return grads
def save_states(self, states: STensor):
if self.index + 1 < len(self.ranks):
return batch_send(
*[s.data for s in states],
dst=self.ranks[self.index + 1],
group=self.group,
async_op=True,
)
else:
return BatchWork(None, None)
def save_grads(self, grads: STensor):
if self.index > 0:
return batch_send(
*[g.data for g in grads if g is not None],
dst=self.ranks[self.index - 1],
group=self.group,
async_op=True,
)
else:
return BatchWork(None, None)
def vector_backward(self, vloss: STensor, vgrad: STensor):
loss: List[Tensor] = []
grad: List[Tensor] = []
for x, g in zip(vloss, vgrad):
if g is None:
continue
if not x.requires_grad:
continue
loss.append(x.flatten())
grad.append(g.flatten())
if loss:
loss = torch.cat(loss, dim=0)
grad = torch.cat(grad, dim=0)
loss.backward(grad)
class SequencePipeFunction(autograd.Function):
@staticmethod
def forward(
ctx: autograd.function.FunctionCtx,
micro_batch_size: int,
program: SequencePipe,
*inputs: Tensor,
):
runtime = SequencePipeRuntime(
micro_batch_size, program
)
ctx.save_for_backward(*inputs)
ctx.saved_runtime = runtime
return runtime.forward(inputs)
@staticmethod
def backward(
ctx: autograd.function.FunctionCtx,
*grads: Tensor,
):
inputs: STensor = ctx.saved_tensors
runtime: SequencePipeRuntime = ctx.saved_runtime
with torch.enable_grad():
input_grads = runtime.backward(inputs, grads)
return None, None, *input_grads
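# Typical usage (sketch): a subclass implements get_group(), get_init_states(),
# forward(inputs, states) and, on the last pipeline stage, loss_fn(); then
# model.apply(bs, *inputs) runs the F1B1 schedule under autograd via
# SequencePipeFunction, while model.fast_backward(bs, (x,), (labels,)) runs
# forward and backward in one pass and returns the accumulated loss.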
# class SequencePipe(nn.Module,ABC):
# def __init__(self,
# batch_size: int,
# seq_ranks: Optional[List[int]] = None,
# group: Any = None,
# ) -> None:
# super().__init__()
# self._batch_size = int(batch_size)
# if seq_ranks is None:
# seq_ranks = list(range(dist.get_world_size(group)))
# self._seq_ranks = tuple(seq_ranks)
# self._group = group
# rank = dist.get_rank(group)
# self._index = self._seq_ranks.index(rank)
# self._init_states: Optional[Tuple[Tensor,...]] = None
# self._all_reduce_grads_op = None
# self._all_reduce_buffers_op = None
# self._all_reduce_group = None
# @property
# def batch_size(self) -> int:
# return self._batch_size
# @property
# def seq_ranks(self):
# return self._seq_ranks
# @property
# def num_ranks(self) -> int:
# return len(self._seq_ranks)
# @property
# def index(self) -> int:
# return self._index
# @property
# def group(self):
# return self._group
# def init_states(self, *states: Tensor):
# for s in states:
# assert isinstance(s, Tensor), f"states must be tuple of tensors"
# self._init_states = tuple(states)
# return self
# def enable_all_reduce(self,
# grads_op = dist.ReduceOp.SUM,
# buffers_op = dist.ReduceOp.AVG,
# group: Any = None,
# ):
# self._all_reduce_grads_op = grads_op
# self._all_reduce_buffers_op = buffers_op
# self._all_reduce_group = group
# return self
# def disable_all_reduce(self):
# self._all_reduce_grads_op = None
# self._all_reduce_buffers_op = None
# self._all_reduce_group = None
# return self
# @abstractmethod
# def state_forward(self,
# batch_id: int,
# inputs: Tuple[Tensor,...],
# states: Tuple[Tensor,...],
# ) -> Tuple[Tuple[Tensor,...], Tuple[Tensor,...]]:
# raise NotImplementedError()
# @abstractmethod
# def loss_fn(self,
# batch_id: int,
# outputs: Tuple[Tensor,...],
# ) -> Tensor:
# raise NotImplementedError()
# @torch.inference_mode()
# def forward(self, *inputs: Tensor) -> Tuple[Tensor,...]:
# detach = self._detach_inputs(inputs)
# B = inputs[0].size(0)
# num_batchs = (B + self.batch_size - 1) // self.batch_size
# last_work = None
# outputs = None
# for batch_id in range(num_batchs):
# start = batch_id * self.batch_size
# end = min(B, start + self.batch_size)
# batch_inputs = self._get_batch_inputs(start, end, detach)
# batch_states = self._get_batch_states(end - start)
# batch_outputs, batch_states = self._forward_inner(batch_id, batch_inputs, batch_states)
# if outputs is None:
# outputs = []
# for t in batch_outputs:
# t = torch.empty(B, *t.shape[1:], dtype=t.dtype, device=t.device)
# outputs.append(t)
# outputs = tuple(outputs)
# for o, t in zip(outputs, batch_outputs):
# o[start:end] = t.data
# if last_work is not None:
# last_work.wait()
# last_work = self._save_states(batch_states)
# if last_work is not None:
# last_work.wait()
# return outputs
# def backward(self, *inputs: Tensor, scale: float = 1.0) -> Tensor:
# detach = self._detach_inputs(inputs)
# B = inputs[0].size(0)
# num_batchs = (B + self.batch_size - 1) // self.batch_size
# footprint = F1B1Footprint(self.index, self.num_ranks, num_batchs)
# source_footprint = F1B1Footprint(self.index-1, self.num_ranks, num_batchs)
# target_footprint = F1B1Footprint(self.index+1, self.num_ranks, num_batchs)
# fw_ready: Dict[int, Any] = {}
# bw_ready: Dict[int, Any] = {}
# last_work = None
# # input_batch_grads = []
# input_grads = None
# total_loss = None
# while True:
# _, op, batch_id = footprint.step()
# _, source_op, source_batch_id = source_footprint.step()
# _, target_op, target_batch_id = target_footprint.step()
# if op is None and source_op is None and target_op is None:
# break
# if last_work is not None:
# last_work.wait()
# last_work = None
# if source_op == "backward":
# input_states, = bw_ready.pop(source_batch_id)
# last_work = self._save_states(input_states, grad=True)
# del input_states
# elif target_op == "forward":
# *_, output_states = fw_ready[target_batch_id]
# last_work = self._save_states(output_states)
# del _, output_states
# if op == "forward":
# start = batch_id * self.batch_size
# end = min(B, start + self.batch_size)
# batch_inputs = self._get_batch_inputs(start, end, detach, inputs)
# batch_input_states = self._get_batch_states(end - start, requires_grad=True)
# batch_outputs, batch_output_states = self._forward_inner(
# batch_id,
# batch_inputs, batch_input_states,
# )
# fw_ready[batch_id] = [
# batch_inputs,
# batch_outputs,
# batch_input_states,
# batch_output_states,
# ]
# elif op == "backward":
# start = batch_id * self.batch_size
# end = min(B, start + self.batch_size)
# batch_inputs, batch_outputs, batch_input_states, batch_output_states = fw_ready.pop(batch_id)
# scale_factor = scale * self.batch_size / (end - start)
# grads, loss = self._backward_inner(
# batch_id,
# batch_inputs,
# batch_outputs,
# batch_input_states,
# batch_output_states,
# scale_factor=scale_factor,
# )
# bw_ready[batch_id] = [
# batch_input_states,
# ]
# total_loss = loss if total_loss is None else total_loss + loss
# if input_grads is None:
# input_grads = []
# for t in grads:
# if t is not None:
# t = torch.empty(B, *t.shape[1:], dtype=t.dtype, device=t.device)
# input_grads.append(t)
# input_grads = tuple(input_grads)
# for g, t in zip(input_grads, grads):
# if g is not None:
# g[start:end] = t.data
# if last_work is not None:
# last_work.wait()
# prev_works = self._prev_inputs_backward()
# self._vector_backward(inputs, input_grads)
# self._post_inputs_backward(prev_works)
# return total_loss
# def _prev_inputs_backward(self):
# works = []
# works.extend(all_reduce_gradients(
# self, op=self._all_reduce_grads_op,
# group=self._all_reduce_group, async_op=True,
# ))
# works.extend(all_reduce_buffers(
# self, op=self._all_reduce_buffers_op,
# group=self._all_reduce_group, async_op=True,
# ))
# return works
# def _post_inputs_backward(self, works):
# for w in works:
# w.wait()
# works.clear()
# def _load_states(self, states: Tuple[Tensor,...], grad: bool = False):
# for s in states:
# s.grad = None
# if grad: # from target to source
# if self.index + 1 < self.num_ranks:
# s_grads = []
# for s in states:
# if s.dtype.is_floating_point:
# s.grad = torch.zeros_like(s.data)
# s_grads.append(s.grad)
# batch_recv(
# *s_grads,
# src=self.seq_ranks[self.index + 1],
# group=self.group,
# async_op=True,
# ).wait()
# else: # from source to target
# if self.index > 0:
# batch_recv(
# *[s.data for s in states],
# src=self.seq_ranks[self.index - 1],
# group=self.group,
# async_op=True,
# ).wait()
# def _save_states(self, states: Tuple[Tensor,...], grad: bool = False) -> BatchWork:
# if grad: # from target to source
# if self.index > 0:
# s_grads = []
# for s in states:
# g, s.grad = s.grad, None
# if s.dtype.is_floating_point:
# s_grads.append(torch.zeros_like(s) if g is None else g)
# return batch_send(
# *s_grads,
# dst=self.seq_ranks[self.index - 1],
# group=self.group,
# async_op=True,
# )
# else: # from source to target
# if self.index + 1 < self.num_ranks:
# return batch_send(
# *[s.data for s in states],
# dst=self.seq_ranks[self.index + 1],
# group=self.group,
# async_op=True
# )
# return BatchWork(None, None)
# def _vector_backward(self, vec_loss: List[Tensor], vec_grad: List[Optional[Tensor]]):
# loss = []
# grad = []
# for x, g in zip(vec_loss, vec_grad):
# if g is None:
# continue
# if not x.requires_grad:
# continue
# # if not x.dtype.is_floating_point:
# # continue
# loss.append(x.flatten())
# grad.append(g.flatten())
# if len(loss) != 0:
# loss = torch.cat(loss, dim=0)
# grad = torch.cat(grad, dim=0)
# loss.backward(grad)
# def _get_batch_inputs(self,
# start: int, end: int,
# detach: Tuple[Tensor,...],
# inputs: Optional[Tuple[Tensor,...]] = None,
# ) -> Tuple[Tensor,...]:
# outs: List[Tensor] = []
# for i, t in enumerate(detach):
# assert not t.requires_grad
# if inputs is None:
# outs.append(t[start:end])
# elif inputs[i].requires_grad:
# t = t[start:end]
# t.requires_grad_()
# t.retain_grad()
# outs.append(t)
# else:
# outs.append(t[start:end])
# return tuple(outs)
# def _get_batch_states(self, bs: int, requires_grad: bool = False) -> Tuple[Tensor,...]:
# assert self._init_states is not None, "please call init_states()."
# states: List[Tensor] = []
# for s in self._init_states:
# s = s.unsqueeze(0).broadcast_to(bs, *s.size()).contiguous()
# if requires_grad and self.index > 0 and s.dtype.is_floating_point:
# s.requires_grad_()
# s.retain_grad()
# states.append(s)
# return tuple(states)
# def _detach_inputs(self, inputs: Tuple[Tensor,...]) -> Tuple[Tensor,...]:
# assert inputs[0].size(0) > 0, "input tensors' batch dimension must be greater than one."
# detach_inputs: List[Tensor] = []
# for t in inputs:
# assert t.size(0) == inputs[0].size(0), "all tensors' batch dimension must be the exact same."
# detach_inputs.append(t.detach())
# return tuple(detach_inputs)
# def _forward_inner(self,
# batch_id: int,
# inputs: Tuple[Tensor,...],
# states: Tuple[Tensor,...],
# ) -> Tuple[Tuple[Tensor,...], Tuple[Tensor,...]]:
# self._load_states(states)
# outputs, next_states = self.state_forward(batch_id, inputs, states)
# return outputs, next_states
# def _backward_inner(self,
# batch_id: int,
# inputs: Tuple[Tensor,...],
# outputs: Tuple[Tensor,...],
# input_states: Tuple[Tensor,...],
# output_states: Tuple[Tensor,...],
# scale_factor: float = 1.0,
# ) -> Tuple[Tuple[Tensor,...], Tensor]:
# loss = self.loss_fn(batch_id, outputs)
# if scale_factor != 1.0:
# loss = loss * scale_factor
# vec_loss = [loss]
# vec_grad = [torch.ones_like(loss)]
# self._load_states(output_states, grad=True)
# for s in output_states:
# if s.requires_grad:
# s.retain_grad()
# g, s.grad = s.grad, None
# vec_loss.append(s)
# vec_grad.append(g)
# self._vector_backward(vec_loss, vec_grad)
# input_grads = []
# for t in inputs:
# g, t.grad = t.grad, None
# input_grads.append(g)
# return tuple(input_grads), loss.detach()
class F1B1Footprint:
@@ -417,11 +811,11 @@ class F1B1Footprint:
else:
self._finish = False
def step(self) -> Tuple[int, Optional[str], int]:
if self._finish:
return (self._count, None, -1)
ret = (self._count, "nop", -1)
if self._count >= self._bw_offset + 2 * self._bw_batch_id:
ret = (self._count, "backward", self._bw_batch_id)
self._bw_batch_id += 1
...
@@ -58,6 +58,9 @@ class SparseBlocks:
def __fetch_ids_sizes(local_ids: Tensor, group: Any):
assert local_ids.dim() == 1
if group is None:
group = dist.GroupMember.WORLD
rank = dist.get_rank(group)
world_size = dist.get_world_size(group)
ikw = dict(dtype=torch.long, device=local_ids.device)
@@ -80,8 +83,9 @@ class SparseBlocks:
all_ids[i] = local_ids
else:
all_ids[i] = torch.empty(all_lens[i], **ikw)
src = dist.get_global_rank(group, i)
all_get[i] = dist.broadcast(
all_ids[i], src=src, async_op=True, group=group
)
imp: Tensor = torch.full((num_nodes,), (2**62-1)*2+1, **ikw)
@@ -151,13 +155,18 @@ class SparseBlockMM(autograd.Function):
part_id = sp.part_id
num_parts = sp.num_parts
group = sp.group
if group is None:
group = dist.GroupMember.WORLD
def async_fetch(i: int):
n = sp.adj_t(i).sparse_size(1)
if i == part_id:
h = x.clone()
else:
h = torch.empty(n, *x.shape[1:], dtype=x.dtype, device=x.device)
src = dist.get_global_rank(group, i)
return dist.broadcast(h, src=src, group=sp.group, async_op=True)
last_work = None
out = None
@@ -192,9 +201,14 @@ class SparseBlockMM(autograd.Function):
part_id = sp.part_id
num_parts = sp.num_parts
group = sp.group
if group is None:
group = dist.GroupMember.WORLD
def async_reduce(i: int, g: Tensor):
dst = dist.get_global_rank(group, i)
return dist.reduce(
g, dst=dst, op=dist.ReduceOp.SUM,
group=sp.group, async_op=True,
)
...
from typing import Any, List, Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
from torch import Tensor
from typing import *
from starrygl.distributed import DistributedContext
from starrygl.data import GraphData
from starrygl.parallel import Route, SequencePipe
from starrygl.parallel.sequence import STensor
from starrygl.parallel.utils import *
import torch_geometric.nn as pyg_nn
import torch_geometric.datasets as pyg_datasets
import torch_geometric.utils as pyg_utils
import logging
logging.getLogger().setLevel(logging.INFO)
def prepare_data(root: str, num_parts, part_algo: str = "metis"):
ctx = DistributedContext.get_default_context()
data = pyg_datasets.Planetoid(root, "Cora")[0]
if data.is_directed():
data.edge_index = pyg_utils.to_undirected(data.edge_index)
data.edge_index, _ = pyg_utils.add_remaining_self_loops(data.edge_index)
data.num_classes = data.y.max().item() + 1
logging.info(f"num_nodes: {data.num_nodes}")
logging.info(f"num_edges: {data.num_edges}")
logging.info(f"num_features: {data.num_features}")
logging.info(f"num_classes: {data.num_classes}")
g = GraphData.from_pyg_data(data)
logging.info(f"GraphData.meta().keys(): {g.meta().keys()}")
logging.info(f"GraphData.node().keys(): {g.node().keys()}")
logging.info(f"GraphData.edge().keys(): {g.edge().keys()}")
g.save_partition(root, num_parts, part_algo)
return g
class SimpleConv(pyg_nn.MessagePassing):
def __init__(self, in_feats: int, out_feats: int):
super().__init__(aggr="mean")
self.linear = nn.Linear(in_feats, out_feats)
def forward(self, x: Tensor, edge_index: Tensor, route: Route):
dst_len = x.size(0)
x = route.apply(x) # exchange features
return self.propagate(edge_index, x=x)[:dst_len]
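# route.apply is expected to gather the features of remote source nodes into
# x before message passing, so propagate sees the full neighborhood; the
# result is then truncated back to the local destination nodes.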
def message(self, x_j: Tensor):
return x_j
def update(self, x: Tensor):
return F.relu(self.linear(x))
class SimpleGNN(nn.Module):
def __init__(self,
num_features: int,
hidden_dims: int,
num_layers: int,
) -> None:
super().__init__()
self.layers = nn.ModuleList()
for i in range(num_layers):
in_ch = hidden_dims if i > 0 else num_features
out_ch = hidden_dims
self.layers.append(SimpleConv(in_ch, out_ch))
def forward(self, x: Tensor, edge_index: Tensor, route: Route):
for layer in self.layers:
x = layer(x, edge_index, route)
return x
class SimpleRNN(SequencePipe, nn.Module):
def __init__(self,
num_classes: int,
hidden_dims: int,
num_layers: int,
device: Any,
group: Any,
) -> None:
super().__init__()
self.device = device
self.group = group
self.num_layers = num_layers
self.hidden_dims = hidden_dims
self.gru = nn.GRU(
input_size = hidden_dims,
hidden_size = hidden_dims,
num_layers = num_layers,
batch_first = True,
)
self.out = nn.Linear(hidden_dims, num_classes)
def forward(self, inputs, states):
x, = inputs # (N, L, H)
h, = states # (N, L, H)
h = h.transpose(0, 1).contiguous() # (L, N, H)
x, h = self.gru(x, h) # (N, L, H), (L, N, H)
h = h.transpose(0, 1).contiguous() # (N, L, H)
return (x,), (h, )
def loss_fn(self, inputs, labels) -> Tensor:
x, = inputs
return x.square().mean()
def get_group(self) -> Any:
return self.group
def get_init_states(self):
s = torch.zeros(self.num_layers, self.hidden_dims).to(self.device)
return (s,)
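# The initial state is one (num_layers, hidden_dims) GRU hidden state per
# graph node; the pipeline runtime broadcasts it across each micro batch
# via get_batch_states.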
if __name__ == "__main__":
data_root = "./dataset"
ctx = DistributedContext.init(backend="nccl", use_gpu=True)
hybrid_matrix = ctx.get_hybrid_matrix()
if hybrid_matrix.size(0) == 1:
hybrid_matrix = hybrid_matrix.view(2, -1)
ctx.sync_print(hybrid_matrix)
# sp is sequence parallel
# pp is partition parallel
sp_group, pp_group = ctx.new_hybrid_subgroups(hybrid_matrix)
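# Rows of the hybrid matrix become sp_group (sequence parallelism among the
# ranks of one host) and columns become pp_group (graph partition parallelism
# across hosts), matching new_hybrid_subgroups above. On a single host the
# matrix is reshaped to 2 x (world_size / 2) so both dimensions still exist.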
# partition data
if ctx.rank == 0:
prepare_data(data_root, dist.get_world_size(pp_group))
dist.barrier()
g = GraphData.load_partition(
data_root,
dist.get_rank(pp_group), dist.get_world_size(pp_group),
).to(ctx.device)
route = g.to_route(pp_group) # only on subgroup
num_features = g.node("dst")["x"].size(-1)
num_classes = g.meta()["num_classes"]
hidden_dims = 128
num_layers = 3
gnn = SimpleGNN(num_features, hidden_dims, num_layers).to(ctx.device)
rnn = SimpleRNN(num_classes, hidden_dims, num_layers, device=ctx.device, group=sp_group).to(ctx.device)
opt = torch.optim.Adam([p for p in gnn.parameters()] + [p for p in rnn.parameters()])
for ep in range(1, 100+1):
seq_len = 200
xs = []
opt.zero_grad()
for _ in range(seq_len): # snapshot parallel between partition parallel subgroups
z = gnn(
x = g.node("dst")["x"],
edge_index = g.edge_index(),
route = route, #
)
xs.append(z.unsqueeze(1))
x = torch.cat(xs, dim=1) # (N, S, H)
# loss = rnn.apply(32, x)[0].square().mean()
# loss.backward() # sequence and pipeline parallel on each graph nodes
loss = rnn.fast_backward(32, (x,), (g.node("dst")["train_mask"],))
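# fast_backward runs forward and backward through the F1B1 pipeline with
# micro batches of 32 nodes, populates the gradients of x (and hence of the
# GNN through autograd), and returns the accumulated loss from loss_fn.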
# all reduce
all_reduce_gradients(rnn)
all_reduce_buffers(rnn)
all_reduce_gradients(gnn)
all_reduce_buffers(gnn)
opt.step()
ctx.sync_print(loss)