Commit 09c175e2 by Wenjie Huang

add demo train_hybrid.py

parent 057dc57f
@@ -152,8 +152,11 @@ def batch_send(
     if len(tensors) == 0:
         return BatchWork(None, None)
+    if group is None:
+        group = dist.GroupMember.WORLD
+
     # tensors = tuple(t.data for t in tensors)
     backend = dist.get_backend(group)
+    dst = dist.get_global_rank(group, dst)
     if async_op:
         works = []
@@ -177,8 +180,11 @@ def batch_recv(
     if len(tensors) == 0:
         return BatchWork(None, None)
+    if group is None:
+        group = dist.GroupMember.WORLD
+
     # tensors = tuple(t.data for t in tensors)
     backend = dist.get_backend(group)
+    src = dist.get_global_rank(group, src)
     if async_op:
         works = []
...
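The two hunks above make `batch_send` / `batch_recv` usable on subgroups: the group now defaults to the WORLD group, and the group-local `dst` / `src` is translated to a global rank, since in current PyTorch releases point-to-point ops (and `dist.broadcast`) expect global ranks even when a `group` argument is passed. A minimal sketch of the same pattern, assuming an initialized process group; `send_in_group` is a hypothetical helper, not part of starrygl:

import torch
import torch.distributed as dist

def send_in_group(t: torch.Tensor, dst_group_rank: int, group=None):
    # torch.distributed p2p ops take *global* ranks regardless of `group`,
    # so a group-local rank must be translated first.
    if group is None:
        group = dist.GroupMember.WORLD
    dst = dist.get_global_rank(group, dst_group_rank)
    return dist.isend(t, dst=dst, group=group)  # returns an async Work handle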
@@ -7,6 +7,7 @@ import os
 from torch import Tensor
 from typing import *

+import socket
 from contextlib import contextmanager
 import logging
@@ -127,6 +128,18 @@ class DistributedContext:
         self._local_rank = local_rank
         self._compute_device = device

+        self._hostname = socket.gethostname()
+
+        if self.device.type == "cuda":
+            torch.cuda.set_device(self.device)
+
+        rank_to_host = [None] * self.world_size
+        dist.all_gather_object(rank_to_host, (self.hostname, self.local_rank))
+        self._rank_to_host: Tuple[Tuple[str, int], ...] = tuple(rank_to_host)
+
+        host_index = [h for h, _ in self.rank_to_host]
+        host_index.sort()
+        self._host_index: Dict[str, int] = {h: i for i, h in enumerate(host_index)}
+
         self.__temp_ag_remote_object: Optional[rpc.RRef] = None
@@ -145,16 +158,87 @@ class DistributedContext:
     @property
     def local_rank(self) -> int:
         return self._local_rank

+    @property
+    def hostname(self) -> str:
+        return self._hostname
+
+    @property
+    def rank_to_host(self):
+        return self._rank_to_host
+
+    @property
+    def host_index(self):
+        return self._host_index
+
     @property
     def device(self) -> torch.device:
         return self._compute_device

     def get_default_group(self):
-        return dist.distributed_c10d._get_default_group()
+        # return dist.distributed_c10d._get_default_group()
+        return dist.GroupMember.WORLD

     def get_default_store(self):
         return dist.distributed_c10d._get_default_store()

+    def get_ranks_by_host(self, hostname: Optional[str] = None) -> Tuple[int, ...]:
+        if hostname is None:
+            hostname = self.hostname
+        ranks: List[int] = []
+        for i, (h, r) in enumerate(self.rank_to_host):
+            if h == hostname:
+                ranks.append(i)
+        ranks.sort()
+        return tuple(ranks)
+
+    def get_ranks_by_local(self, local_rank: Optional[int] = None) -> Tuple[int, ...]:
+        if local_rank is None:
+            local_rank = self.local_rank
+        ranks: List[Tuple[int, str]] = []
+        for i, (h, r) in enumerate(self.rank_to_host):
+            if r == local_rank:
+                ranks.append((i, h))
+        ranks.sort(key=lambda x: self.host_index[x[1]])
+        return tuple(i for i, h in ranks)
+
+    def get_hybrid_matrix(self) -> Tensor:
+        hosts = sorted(self.host_index.items(), key=lambda x: x[1])
+        matrix = []
+        for h, _ in hosts:
+            rs = self.get_ranks_by_host(h)
+            matrix.append(rs)
+        return torch.tensor(matrix, dtype=torch.long, device="cpu")
+
+    def new_hybrid_subgroups(self,
+        matrix: Optional[Tensor] = None,
+        backend: Any = None,
+    ) -> Tuple[Any, Any]:
+        if matrix is None:
+            matrix = self.get_hybrid_matrix()
+        assert matrix.dim() == 2
+
+        row_group = None
+        col_group = None
+        for row in matrix.tolist():
+            if self.rank in row:
+                row_group = dist.new_group(
+                    row, backend=backend,
+                    use_local_synchronization=True,
+                )
+                break
+        for col in matrix.t().tolist():
+            if self.rank in col:
+                col_group = dist.new_group(
+                    col, backend=backend,
+                    use_local_synchronization=True,
+                )
+                break
+        assert row_group is not None
+        assert col_group is not None
+        return row_group, col_group
+
     def get_worker_info(self, rank: Optional[int] = None) -> rpc.WorkerInfo:
         rank = dist.get_rank() if rank is None else rank
...
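The new bookkeeping arranges the global ranks into a hosts x local-ranks matrix, and `new_hybrid_subgroups` gives each rank one intra-host (row) group and one inter-host (column) group. An illustration only, assuming two hosts with four ranks each, launched host-major (the numbers below are an assumed layout, not guaranteed output):

# get_hybrid_matrix() would then return
#   [[0, 1, 2, 3],
#    [4, 5, 6, 7]]
# and for global rank 5 new_hybrid_subgroups() builds
#   row group {4, 5, 6, 7}  (intra-host, used as sp_group in the demo below)
#   col group {1, 5}        (inter-host, used as pp_group in the demo below)
ctx = DistributedContext.init(backend="nccl", use_gpu=True)
sp_group, pp_group = ctx.new_hybrid_subgroups()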
@@ -24,6 +24,9 @@ class Route:
         bipartite: bool = True,
         group: Any = None,
     ) -> 'Route':
+        if group is None:
+            group = dist.GroupMember.WORLD
+
         fw_tables, bw_tables = Route._build_route_tables(
             src_ids=src_ids, dst_ids=dst_ids,
             bipartite=bipartite, group=group,
@@ -256,8 +259,11 @@ class Route:
                 all_dst_ids[i] = dst_ids
             else:
                 all_dst_ids[i] = torch.empty(all_dst_lens[i], **ikw)
+            src_rank = dist.get_global_rank(group, i)
             all_dst_get[i] = dist.broadcast(
-                all_dst_ids[i], src=i, async_op=True, group=group
+                all_dst_ids[i], src=src_rank,
+                async_op=True, group=group,
             )

         fw_tables: List[Tensor] = []
...
@@ -6,394 +6,788 @@ import torch.distributed as dist
 from torch import Tensor
 from typing import *

-from starrygl.distributed.cclib import batch_send, batch_recv, BatchWork
+from starrygl.distributed.cclib import *
 from abc import ABC, abstractmethod
+from contextlib import contextmanager
-from .utils import all_reduce_buffers, all_reduce_gradients

 __all__ = [
     "SequencePipe",
 ]

-class SequencePipe(nn.Module,ABC):
-    def __init__(self,
-        batch_size: int,
-        seq_ranks: Optional[List[int]] = None,
-        group: Any = None,
-    ) -> None:
-        super().__init__()
-        self._batch_size = int(batch_size)
-
-        if seq_ranks is None:
-            seq_ranks = list(range(dist.get_world_size(group)))
-        self._seq_ranks = tuple(seq_ranks)
-        self._group = group
-
-        rank = dist.get_rank(group)
-        self._index = self._seq_ranks.index(rank)
-
-        self._init_states: Optional[Tuple[Tensor,...]] = None
-
-        self._all_reduce_grads_op = None
-        self._all_reduce_buffers_op = None
-
-    @property
-    def batch_size(self) -> int:
-        return self._batch_size
-
-    @property
-    def seq_ranks(self):
-        return self._seq_ranks
-
-    @property
-    def num_ranks(self) -> int:
-        return len(self._seq_ranks)
-
-    @property
-    def index(self) -> int:
-        return self._index
-
-    @property
-    def group(self):
-        return self._group
-
-    def init_states(self, *states: Tensor):
-        for s in states:
-            assert isinstance(s, Tensor), f"states must be tuple of tensors"
-        self._init_states = tuple(states)
-        return self
-
-    def enable_all_reduce(self,
-        grads_op = dist.ReduceOp.SUM,
-        buffers_op = dist.ReduceOp.AVG,
-    ):
-        self._all_reduce_grads_op = grads_op
-        self._all_reduce_buffers_op = buffers_op
-        return self
-
-    def disable_all_reduce(self):
-        self._all_reduce_grads_op = None
-        self._all_reduce_buffers_op = None
-        return self
-
-    @abstractmethod
-    def state_forward(self,
-        batch_id: int,
-        inputs: Tuple[Tensor,...],
-        states: Tuple[Tensor,...],
-    ) -> Tuple[Tuple[Tensor,...], Tuple[Tensor,...]]:
-        raise NotImplementedError()
-
-    @abstractmethod
-    def loss_fn(self,
-        batch_id: int,
-        outputs: Tuple[Tensor,...],
-    ) -> Tensor:
-        raise NotImplementedError()
-
-    @torch.inference_mode()
-    def forward(self, *inputs: Tensor) -> Tuple[Tensor,...]:
-        detach = self._detach_inputs(inputs)
-
-        B = inputs[0].size(0)
-        num_batchs = (B + self.batch_size - 1) // self.batch_size
-        last_work = None
-
-        outputs = None
-        for batch_id in range(num_batchs):
-            start = batch_id * self.batch_size
-            end = min(B, start + self.batch_size)
-
-            batch_inputs = self._get_batch_inputs(start, end, detach)
-            batch_states = self._get_batch_states(end - start)
-            batch_outputs, batch_states = self._forward_inner(batch_id, batch_inputs, batch_states)
-
-            if outputs is None:
-                outputs = []
-                for t in batch_outputs:
-                    t = torch.empty(B, *t.shape[1:], dtype=t.dtype, device=t.device)
-                    outputs.append(t)
-                outputs = tuple(outputs)
-
-            for o, t in zip(outputs, batch_outputs):
-                o[start:end] = t.data
-
-            if last_work is not None:
-                last_work.wait()
-            last_work = self._save_states(batch_states)
-
-        if last_work is not None:
-            last_work.wait()
-        return outputs
-
-    def backward(self, *inputs: Tensor, scale: float = 1.0) -> Tensor:
-        detach = self._detach_inputs(inputs)
-        B = inputs[0].size(0)
-        num_batchs = (B + self.batch_size - 1) // self.batch_size
-
-        footprint = F1B1Footprint(self.index, self.num_ranks, num_batchs)
-        source_footprint = F1B1Footprint(self.index-1, self.num_ranks, num_batchs)
-        target_footprint = F1B1Footprint(self.index+1, self.num_ranks, num_batchs)
-
-        fw_ready: Dict[int, Any] = {}
-        bw_ready: Dict[int, Any] = {}
-
-        last_work = None
-
-        # input_batch_grads = []
-        input_grads = None
-        total_loss = None
-        while True:
-            _, op, batch_id = footprint.step()
-            _, source_op, source_batch_id = source_footprint.step()
-            _, target_op, target_batch_id = target_footprint.step()
-            if op is None and source_op is None and target_op is None:
-                break
-
-            if last_work is not None:
-                last_work.wait()
-                last_work = None
-
-            if source_op == "backward":
-                input_states, = bw_ready.pop(source_batch_id)
-                last_work = self._save_states(input_states, grad=True)
-                del input_states
-            elif target_op == "forward":
-                *_, output_states = fw_ready[target_batch_id]
-                last_work = self._save_states(output_states)
-                del _, output_states
-
-            if op == "forward":
-                start = batch_id * self.batch_size
-                end = min(B, start + self.batch_size)
-
-                batch_inputs = self._get_batch_inputs(start, end, detach, inputs)
-                batch_input_states = self._get_batch_states(end - start, requires_grad=True)
-
-                batch_outputs, batch_output_states = self._forward_inner(
-                    batch_id,
-                    batch_inputs, batch_input_states,
-                )
-                fw_ready[batch_id] = [
-                    batch_inputs,
-                    batch_outputs,
-                    batch_input_states,
-                    batch_output_states,
-                ]
-            elif op == "backward":
-                start = batch_id * self.batch_size
-                end = min(B, start + self.batch_size)
-
-                batch_inputs, batch_outputs, batch_input_states, batch_output_states = fw_ready.pop(batch_id)
-
-                scale_factor = scale * self.batch_size / (end - start)
-                grads, loss = self._backward_inner(
-                    batch_id,
-                    batch_inputs,
-                    batch_outputs,
-                    batch_input_states,
-                    batch_output_states,
-                    scale_factor=scale_factor,
-                )
-                bw_ready[batch_id] = [
-                    batch_input_states,
-                ]
-
-                total_loss = loss if total_loss is None else total_loss + loss
-
-                if input_grads is None:
-                    input_grads = []
-                    for t in grads:
-                        if t is not None:
-                            t = torch.empty(B, *t.shape[1:], dtype=t.dtype, device=t.device)
-                        input_grads.append(t)
-                    input_grads = tuple(input_grads)
-
-                for g, t in zip(input_grads, grads):
-                    if g is not None:
-                        g[start:end] = t.data
-
-        if last_work is not None:
-            last_work.wait()
-
-        prev_works = self._prev_inputs_backward()
-        self._vector_backward(inputs, input_grads)
-        self._post_inputs_backward(prev_works)
-
-        return total_loss
-
-    def _prev_inputs_backward(self):
-        works = []
-        works.extend(all_reduce_gradients(
-            self, op=self._all_reduce_grads_op,
-            group=self.group, async_op=True,
-        ))
-        works.extend(all_reduce_buffers(
-            self, op=self._all_reduce_buffers_op,
-            group=self.group, async_op=True,
-        ))
-        return works
-
-    def _post_inputs_backward(self, works):
-        for w in works:
-            w.wait()
-        works.clear()
-
-    def _load_states(self, states: Tuple[Tensor,...], grad: bool = False):
-        for s in states:
-            s.grad = None
-
-        if grad:  # from target to source
-            if self.index + 1 < self.num_ranks:
-                s_grads = []
-                for s in states:
-                    if s.dtype.is_floating_point:
-                        s.grad = torch.zeros_like(s.data)
-                        s_grads.append(s.grad)
-
-                batch_recv(
-                    *s_grads,
-                    src=self.seq_ranks[self.index + 1],
-                    group=self.group,
-                    async_op=True,
-                ).wait()
-        else:  # from source to target
-            if self.index > 0:
-                batch_recv(
-                    *[s.data for s in states],
-                    src=self.seq_ranks[self.index - 1],
-                    group=self.group,
-                    async_op=True,
-                ).wait()
-
-    def _save_states(self, states: Tuple[Tensor,...], grad: bool = False) -> BatchWork:
-        if grad:  # from target to source
-            if self.index > 0:
-                s_grads = []
-                for s in states:
-                    g, s.grad = s.grad, None
-                    if s.dtype.is_floating_point:
-                        s_grads.append(torch.zeros_like(s) if g is None else g)
-                return batch_send(
-                    *s_grads,
-                    dst=self.seq_ranks[self.index - 1],
-                    group=self.group,
-                    async_op=True,
-                )
-        else:  # from source to target
-            if self.index + 1 < self.num_ranks:
-                return batch_send(
-                    *[s.data for s in states],
-                    dst=self.seq_ranks[self.index + 1],
-                    group=self.group,
-                    async_op=True
-                )
-        return BatchWork(None, None)
-
-    def _vector_backward(self, vec_loss: List[Tensor], vec_grad: List[Optional[Tensor]]):
-        loss = []
-        grad = []
-        for x, g in zip(vec_loss, vec_grad):
-            if g is None:
-                continue
-            if not x.requires_grad:
-                continue
-            # if not x.dtype.is_floating_point:
-            #     continue
-            loss.append(x.flatten())
-            grad.append(g.flatten())
-        if len(loss) != 0:
-            loss = torch.cat(loss, dim=0)
-            grad = torch.cat(grad, dim=0)
-            loss.backward(grad)
-
-    def _get_batch_inputs(self,
-        start: int, end: int,
-        detach: Tuple[Tensor,...],
-        inputs: Optional[Tuple[Tensor,...]] = None,
-    ) -> Tuple[Tensor,...]:
-        outs: List[Tensor] = []
-        for i, t in enumerate(detach):
-            assert not t.requires_grad
-            if inputs is None:
-                outs.append(t[start:end])
-            elif inputs[i].requires_grad:
-                t = t[start:end]
-                t.requires_grad_()
-                t.retain_grad()
-                outs.append(t)
-            else:
-                outs.append(t[start:end])
-        return tuple(outs)
-
-    def _get_batch_states(self, bs: int, requires_grad: bool = False) -> Tuple[Tensor,...]:
-        assert self._init_states is not None, "please call init_states()."
-        states: List[Tensor] = []
-        for s in self._init_states:
-            s = s.unsqueeze(0).broadcast_to(bs, *s.size()).contiguous()
-            if requires_grad and self.index > 0 and s.dtype.is_floating_point:
-                s.requires_grad_()
-                s.retain_grad()
-            states.append(s)
-        return tuple(states)
-
-    def _detach_inputs(self, inputs: Tuple[Tensor,...]) -> Tuple[Tensor,...]:
-        assert inputs[0].size(0) > 0, "input tensors' batch dimension must be greater than one."
-        detach_inputs: List[Tensor] = []
-        for t in inputs:
-            assert t.size(0) == inputs[0].size(0), "all tensors' batch dimension must be the exact same."
-            detach_inputs.append(t.detach())
-        return tuple(detach_inputs)
-
-    def _forward_inner(self,
-        batch_id: int,
-        inputs: Tuple[Tensor,...],
-        states: Tuple[Tensor,...],
-    ) -> Tuple[Tuple[Tensor,...], Tuple[Tensor,...]]:
-        self._load_states(states)
-        outputs, next_states = self.state_forward(batch_id, inputs, states)
-        return outputs, next_states
-
-    def _backward_inner(self,
-        batch_id: int,
-        inputs: Tuple[Tensor,...],
-        outputs: Tuple[Tensor,...],
-        input_states: Tuple[Tensor,...],
-        output_states: Tuple[Tensor,...],
-        scale_factor: float = 1.0,
-    ) -> Tuple[Tuple[Tensor,...], Tensor]:
-        loss = self.loss_fn(batch_id, outputs)
-        if scale_factor != 1.0:
-            loss = loss * scale_factor
-
-        vec_loss = [loss]
-        vec_grad = [torch.ones_like(loss)]
-
-        self._load_states(output_states, grad=True)
-        for s in output_states:
-            if s.requires_grad:
-                s.retain_grad()
-            g, s.grad = s.grad, None
-            vec_loss.append(s)
-            vec_grad.append(g)
-        self._vector_backward(vec_loss, vec_grad)
-
-        input_grads = []
-        for t in inputs:
-            g, t.grad = t.grad, None
-            input_grads.append(g)
-        return tuple(input_grads), loss.detach()

+OptTensor = Optional[Tensor]
+SInteger = Sequence[int]
+STensor = Sequence[Tensor]
+SOptTensor = Sequence[OptTensor]
+PairSTensor = Tuple[STensor, STensor]
+OptSTensor = Optional[STensor]
+
+
+class SequencePipe(ABC):
+    def __init__(self) -> None:
+        super().__init__()
+        self._pos_begin = 0
+        self._pos_end = 0
+        self._pos_start = True
+
+    @abstractmethod
+    def get_group(self) -> Any:
+        raise NotImplementedError
+
+    @abstractmethod
+    def get_init_states(self) -> STensor:
+        raise NotImplementedError
+
+    @abstractmethod
+    def forward(self, inputs: STensor, states: STensor) -> PairSTensor:
+        raise NotImplementedError
+
+    def loss_fn(self, inputs: STensor, labels: STensor) -> Tensor:
+        raise NotImplementedError
+
+    def get_ranks(self) -> SInteger:
+        world_size = dist.get_world_size(self.get_group())
+        return tuple(range(world_size))
+
+    def apply(self, bs: int, *inputs: Tensor):
+        return SequencePipeFunction.apply(bs, self, *inputs)
+
+    def fast_backward(self, bs: int, inputs: STensor, labels: STensor) -> Tensor:
+        runtime = SequencePipeRuntime(bs, self, last_layer=True)
+        inputs_grads = runtime.backward(inputs, labels)
+        runtime.vector_backward(inputs, inputs_grads)
+        return runtime.acc_loss
+
+    @property
+    def begin(self) -> int:
+        return self._pos_begin
+
+    @property
+    def end(self) -> int:
+        return self._pos_end
+
+    @property
+    def start(self) -> bool:
+        return self._pos_start
+
+    @property
+    def batch_size(self) -> int:
+        return self._pos_end - self._pos_begin
+
+    @contextmanager
+    def _switch(self,
+        begin: int, end: int, start: bool,
+    ):
+        saved_begin = self._pos_begin
+        saved_end = self._pos_end
+        saved_start = self._pos_start
+
+        self._pos_begin = begin
+        self._pos_end = end
+        self._pos_start = start
+        yield
+
+        self._pos_begin = saved_begin
+        self._pos_end = saved_end
+        self._pos_start = saved_start
+
+
+class SequencePipeRuntime:
+    def __init__(self,
+        micro_batch_size: int,
+        program: SequencePipe,
+        last_layer: bool = False,
+    ) -> None:
+        self.micro_batch_size = micro_batch_size
+        self.program = program
+        self.last_layer = last_layer
+        self.acc_loss = None
+
+        self.group = program.get_group()
+        self.ranks = program.get_ranks()
+        self.index = self.ranks.index(dist.get_rank(self.group))
+        self._last_work = None
+
+    def forward(self, inputs: STensor) -> STensor:
+        detach = self.detach_inputs(inputs)
+        N = inputs[0].size(0)
+        S = (N + self.micro_batch_size - 1) // self.micro_batch_size
+
+        outputs = None
+        for i in range(S):
+            begin = i * self.micro_batch_size
+            end = min(N, begin + self.micro_batch_size)
+            start = (self.index == 0)
+            with self.program._switch(begin, end, start):
+                batch_inputs = self.get_batch_inputs(begin, end, detach)
+                batch_states = self.get_batch_states(begin, end)
+                batch_outputs, batch_states = self.forward_inner(batch_inputs, batch_states)
+
+            if outputs is None:
+                outputs = []
+                for t in batch_outputs:
+                    t = torch.empty(N, *t.shape[1:], dtype=t.dtype, device=t.device)
+                    outputs.append(t)
+                outputs = tuple(outputs)
+
+            for t, b in zip(outputs, batch_outputs):
+                t[begin:end] = b.data
+
+            self.wait_last_work(
+                next_one=self.save_states(batch_states),
+            )
+        self.wait_last_work()
+        return outputs
+
+    def backward(self, inputs: STensor, output_grads: OptSTensor = None) -> STensor:
+        detach = self.detach_inputs(inputs)
+        detach_grads = self.detach_inputs(output_grads)
+        N = inputs[0].size(0)
+        S = (N + self.micro_batch_size - 1) // self.micro_batch_size
+
+        footprint = F1B1Footprint(self.index, len(self.ranks), S)
+        source_footprint = F1B1Footprint(self.index-1, len(self.ranks), S)
+        target_footprint = F1B1Footprint(self.index+1, len(self.ranks), S)
+
+        fw_ready = {}
+        bw_ready = {}
+
+        input_grads = None
+        while True:
+            _, op, i = footprint.step()
+            _, source_op, source_i = source_footprint.step()
+            _, target_op, target_i = target_footprint.step()
+            if (not op) and (not source_op) and (not target_op):
+                break
+
+            self.wait_last_work()
+
+            if source_op == "backward":
+                batch_input_state_grads, = bw_ready.pop(source_i)
+                self.wait_last_work(
+                    next_one=self.save_grads(batch_input_state_grads),
+                )
+            elif target_op == "forward":
+                *_, batch_output_states = fw_ready[target_i]
+                self.wait_last_work(
+                    next_one=self.save_states(batch_output_states),
+                )
+                del _
+
+            begin = i * self.micro_batch_size
+            end = min(N, begin + self.micro_batch_size)
+            start = (self.index == 0)
+            with self.program._switch(begin, end, start):
+                if op == "forward":
+                    batch_inputs = self.get_batch_inputs(begin, end, detach, inputs)
+                    batch_input_states = self.get_batch_states(begin, end, requires_grad=True)
+
+                    batch_outputs, batch_output_states = \
+                        self.forward_inner(batch_inputs, batch_input_states)
+                    fw_ready[i] = [
+                        batch_inputs, batch_outputs,
+                        batch_input_states, batch_output_states,
+                    ]
+                elif op == "backward":
+                    batch_inputs, batch_outputs,\
+                        batch_input_states, batch_output_states = fw_ready.pop(i)
+
+                    batch_output_grads = self.get_batch_inputs(begin, end, detach_grads)
+                    batch_input_grads, batch_input_state_grads = \
+                        self.backward_inner(
+                            batch_inputs, batch_outputs,
+                            batch_input_states, batch_output_states,
+                            batch_output_grads,
+                        )
+                    bw_ready[i] = [
+                        batch_input_state_grads,
+                    ]
+
+                    if input_grads is None:
+                        input_grads = []
+                        for t in batch_input_grads:
+                            if t is not None:
+                                t = torch.empty(N, *t.shape[1:], dtype=t.dtype, device=t.device)
+                            input_grads.append(t)
+                        input_grads = tuple(input_grads)
+
+                    for g, t in zip(input_grads, batch_input_grads):
+                        if g is not None:
+                            g[begin:end] = t.data
+        self.wait_last_work()
+        return input_grads
+
+    def wait_last_work(self, next_one = None):
+        if self._last_work is not None:
+            self._last_work.wait()
+        self._last_work = next_one
+
+    def detach_inputs(self, inputs: STensor) -> STensor:
+        detach: STensor = []
+        for t in inputs:
+            assert t.size(0) == inputs[0].size(0), "The first dimension of all tensors must be the same."
+            detach.append(t.detach())
+        return detach
+
+    def get_batch_inputs(self,
+        begin: int, end: int,
+        detach: STensor, inputs: OptSTensor = None,
+    ) -> STensor:
+        batch: STensor = []
+        for i, t in enumerate(detach):
+            assert not t.requires_grad
+            if inputs and inputs[i].requires_grad:
+                t = t[begin:end]
+                t.requires_grad_()
+                t.retain_grad()
+                batch.append(t)
+            else:
+                batch.append(t[begin:end])
+        return batch
+
+    def get_batch_states(self,
+        begin: int, end: int,
+        requires_grad: bool = False,
+    ) -> STensor:
+        states = []
+        for s in self.program.get_init_states():
+            s = s.unsqueeze(0).broadcast_to(
+                end - begin, *s.size()).contiguous()
+            if requires_grad and self.index > 0 and s.is_floating_point():
+                s.requires_grad_()
+                s.retain_grad()
+            states.append(s)
+        return states
+
+    def forward_inner(self,
+        inputs: STensor,
+        states: STensor,
+    ):
+        states = self.load_states(states)
+        return self.program.forward(inputs, states)
+
+    def backward_inner(self,
+        inputs: STensor,
+        outputs: STensor,
+        input_states: STensor,
+        output_states: STensor,
+        output_grads: STensor,
+    ) -> PairSTensor:
+        vloss = []
+        vgrad = []
+        if self.last_layer:
+            vloss.append(self.program.loss_fn(outputs, output_grads))
+            vgrad.append(torch.ones_like(vloss[-1]))
+            if self.acc_loss is None:
+                self.acc_loss = vloss[-1].detach()
+            else:
+                self.acc_loss += vloss[-1].detach()
+        else:
+            vloss.extend(outputs)
+            vgrad.extend(output_grads)
+
+        vloss.extend(output_states)
+        vgrad.extend(self.load_grads(output_states))
+
+        self.vector_backward(vloss, vgrad)
+
+        input_grads = []
+        for t in inputs:
+            g, t.grad = t.grad, None
+            input_grads.append(g)
+
+        input_state_grads = []
+        for s in input_states:
+            g, s.grad = s.grad, None
+            if s.is_floating_point():
+                g = torch.zeros_like(s) if g is None else g
+            else:
+                g = None
+            input_state_grads.append(g)
+
+        return input_grads, input_state_grads
+
+    def load_states(self, states: STensor):
+        if self.index > 0:
+            batch_recv(
+                *[s.data for s in states],
+                src=self.ranks[self.index - 1],
+                group=self.group,
+                async_op=True,
+            ).wait()
+        return states
+
+    def load_grads(self, states: STensor):
+        grads: SOptTensor = []
+        if self.index + 1 < len(self.ranks):
+            for s in states:
+                if s.is_floating_point():
+                    g = torch.zeros_like(s)
+                    grads.append(g)
+                else:
+                    grads.append(None)
+
+            batch_recv(
+                *[g.data for g in grads if g is not None],
+                src=self.ranks[self.index + 1],
+                group=self.group,
+                async_op=True,
+            ).wait()
+        else:
+            for s in states:
+                grads.append(None)
+        return grads
+
+    def save_states(self, states: STensor):
+        if self.index + 1 < len(self.ranks):
+            return batch_send(
+                *[s.data for s in states],
+                dst=self.ranks[self.index + 1],
+                group=self.group,
+                async_op=True,
+            )
+        else:
+            return BatchWork(None, None)
+
+    def save_grads(self, grads: STensor):
+        if self.index > 0:
+            return batch_send(
+                *[g.data for g in grads if g is not None],
+                dst=self.ranks[self.index - 1],
+                group=self.group,
+                async_op=True,
+            )
+        else:
+            return BatchWork(None, None)
+
+    def vector_backward(self, vloss: STensor, vgrad: STensor):
+        loss: List[Tensor] = []
+        grad: List[Tensor] = []
+        for x, g in zip(vloss, vgrad):
+            if g is None:
+                continue
+            if not x.requires_grad:
+                continue
+            loss.append(x.flatten())
+            grad.append(g.flatten())
+        if loss:
+            loss = torch.cat(loss, dim=0)
+            grad = torch.cat(grad, dim=0)
+            loss.backward(grad)
+
+
+class SequencePipeFunction(autograd.Function):
+    @staticmethod
+    def forward(
+        ctx: autograd.function.FunctionCtx,
+        micro_batch_size: int,
+        program: SequencePipe,
+        *inputs: Tensor,
+    ):
+        runtime = SequencePipeRuntime(
+            micro_batch_size, program
+        )
+        ctx.save_for_backward(*inputs)
+        ctx.saved_runtime = runtime
+        return runtime.forward(inputs)
+
+    @staticmethod
+    def backward(
+        ctx: autograd.function.FunctionCtx,
+        *grads: Tensor,
+    ):
+        inputs: STensor = ctx.saved_tensors
+        runtime: SequencePipeRuntime = ctx.saved_runtime
+        with torch.enable_grad():
+            input_grads = runtime.backward(inputs, grads)
+        return None, None, *input_grads
+
+# (the removed SequencePipe(nn.Module, ABC) implementation shown above is retained
+#  at the end of the new file as a commented-out block, identical to the removed lines)

 class F1B1Footprint:
@@ -417,11 +811,11 @@ class F1B1Footprint:
         else:
             self._finish = False

-    def step(self) -> Tuple[int, Optional[str], Optional[int]]:
+    def step(self) -> Tuple[int, Optional[str], int]:
         if self._finish:
-            return (self._count, None, None)
+            return (self._count, None, -1)

-        ret = (self._count, "nop", None)
+        ret = (self._count, "nop", -1)
         if self._count >= self._bw_offset + 2 * self._bw_batch_id:
             ret = (self._count, "backward", self._bw_batch_id)
             self._bw_batch_id += 1
...
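Under the reworked API, a pipeline stage is written by subclassing SequencePipe and implementing `get_group`, `get_init_states`, and `forward`; `loss_fn` is only needed for `fast_backward`. The sketch below is a hypothetical minimal program, not part of this commit, assuming `torch.distributed` is already initialized (a full example is the SimpleRNN class in train_hybrid.py further down):

import torch
import torch.nn as nn

from starrygl.parallel import SequencePipe


class RunningSumPipe(SequencePipe, nn.Module):
    """Hypothetical minimal program: each rank adds its micro-batch slice to a
    running state that flows rank-to-rank through the pipeline."""

    def __init__(self, dim: int, device="cpu", group=None):
        super().__init__()
        self.dim = dim
        self.dev = device
        self.group = group

    def get_group(self):
        # None falls back to the default WORLD group (see get_ranks above)
        return self.group

    def get_init_states(self):
        # one tensor per state, without the batch dimension; the runtime
        # broadcasts it to (micro_batch, dim) in get_batch_states()
        return (torch.zeros(self.dim, device=self.dev),)

    def forward(self, inputs, states):
        x, = inputs        # (B, dim) micro-batch slice on this rank
        s, = states        # (B, dim) state received from the previous rank
        s = s + x
        return (s,), (s,)  # (outputs, states sent to the next rank)

    def loss_fn(self, inputs, labels):
        y, = inputs
        return y.square().mean()


# usage, assuming x and labels live on the matching device:
#   pipe = RunningSumPipe(16)
#   out, = pipe.apply(32, x)                        # micro-batches of 32, autograd path
#   loss = pipe.fast_backward(32, (x,), (labels,))  # fused loss + backward, returns acc_loss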
@@ -58,6 +58,9 @@ class SparseBlocks:
     def __fetch_ids_sizes(local_ids: Tensor, group: Any):
         assert local_ids.dim() == 1
+        if group is None:
+            group = dist.GroupMember.WORLD
+
         rank = dist.get_rank(group)
         world_size = dist.get_world_size(group)
         ikw = dict(dtype=torch.long, device=local_ids.device)
@@ -80,8 +83,9 @@ class SparseBlocks:
                 all_ids[i] = local_ids
             else:
                 all_ids[i] = torch.empty(all_lens[i], **ikw)
+            src = dist.get_global_rank(group, i)
             all_get[i] = dist.broadcast(
-                all_ids[i], src=i, async_op=True, group=group
+                all_ids[i], src=src, async_op=True, group=group
             )

         imp: Tensor = torch.full((num_nodes,), (2**62-1)*2+1, **ikw)
@@ -151,13 +155,18 @@ class SparseBlockMM(autograd.Function):
         part_id = sp.part_id
         num_parts = sp.num_parts

+        group = sp.group
+        if group is None:
+            group = dist.GroupMember.WORLD
+
         def async_fetch(i: int):
             n = sp.adj_t(i).sparse_size(1)
             if i == part_id:
                 h = x.clone()
             else:
                 h = torch.empty(n, *x.shape[1:], dtype=x.dtype, device=x.device)
-            return dist.broadcast(h, src=i, group=sp.group, async_op=True)
+            src = dist.get_global_rank(group, i)
+            return dist.broadcast(h, src=src, group=sp.group, async_op=True)

         last_work = None
         out = None
@@ -192,9 +201,14 @@ class SparseBlockMM(autograd.Function):
         part_id = sp.part_id
         num_parts = sp.num_parts

+        group = sp.group
+        if group is None:
+            group = dist.GroupMember.WORLD
+
         def async_reduce(i: int, g: Tensor):
+            dst = dist.get_global_rank(group, i)
             return dist.reduce(
-                g, dst=i, op=dist.ReduceOp.SUM,
+                g, dst=dst, op=dist.ReduceOp.SUM,
                 group=sp.group, async_op=True,
             )
...
from typing import Any, List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist

from torch import Tensor
from typing import *

from starrygl.distributed import DistributedContext
from starrygl.data import GraphData
from starrygl.parallel import Route, SequencePipe
from starrygl.parallel.sequence import STensor
from starrygl.parallel.utils import *

import torch_geometric.nn as pyg_nn
import torch_geometric.datasets as pyg_datasets
import torch_geometric.utils as pyg_utils

import logging
logging.getLogger().setLevel(logging.INFO)


def prepare_data(root: str, num_parts, part_algo: str = "metis"):
    ctx = DistributedContext.get_default_context()

    data = pyg_datasets.Planetoid(root, "Cora")[0]

    if data.is_directed():
        data.edge_index = pyg_utils.to_undirected(data.edge_index)
    data.edge_index, _ = pyg_utils.add_remaining_self_loops(data.edge_index)
    data.num_classes = data.y.max().item() + 1

    logging.info(f"num_nodes: {data.num_nodes}")
    logging.info(f"num_edges: {data.num_edges}")
    logging.info(f"num_features: {data.num_features}")
    logging.info(f"num_classes: {data.num_classes}")

    g = GraphData.from_pyg_data(data)
    logging.info(f"GraphData.meta().keys(): {g.meta().keys()}")
    logging.info(f"GraphData.node().keys(): {g.node().keys()}")
    logging.info(f"GraphData.edge().keys(): {g.edge().keys()}")

    g.save_partition(root, num_parts, part_algo)
    return g


class SimpleConv(pyg_nn.MessagePassing):
    def __init__(self, in_feats: int, out_feats: int):
        super().__init__(aggr="mean")
        self.linear = nn.Linear(in_feats, out_feats)

    def forward(self, x: Tensor, edge_index: Tensor, route: Route):
        dst_len = x.size(0)
        x = route.apply(x)  # exchange features
        return self.propagate(edge_index, x=x)[:dst_len]

    def message(self, x_j: Tensor):
        return x_j

    def update(self, x: Tensor):
        return F.relu(self.linear(x))


class SimpleGNN(nn.Module):
    def __init__(self,
        num_features: int,
        hidden_dims: int,
        num_layers: int,
    ) -> None:
        super().__init__()
        self.layers = nn.ModuleList()
        for i in range(num_layers):
            in_ch = hidden_dims if i > 0 else num_features
            out_ch = hidden_dims
            self.layers.append(SimpleConv(in_ch, out_ch))

    def forward(self, x: Tensor, edge_index: Tensor, route: Route):
        for layer in self.layers:
            x = layer(x, edge_index, route)
        return x


class SimpleRNN(SequencePipe, nn.Module):
    def __init__(self,
        num_classes: int,
        hidden_dims: int,
        num_layers: int,
        device: Any,
        group: Any,
    ) -> None:
        super().__init__()
        self.device = device
        self.group = group

        self.num_layers = num_layers
        self.hidden_dims = hidden_dims

        self.gru = nn.GRU(
            input_size = hidden_dims,
            hidden_size = hidden_dims,
            num_layers = num_layers,
            batch_first = True,
        )
        self.out = nn.Linear(hidden_dims, num_classes)

    def forward(self, inputs, states):
        x, = inputs  # (N, L, H)
        h, = states  # (N, L, H)
        h = h.transpose(0, 1).contiguous()  # (L, N, H)

        x, h = self.gru(x, h)  # (N, L, H), (L, N, H)

        h = h.transpose(0, 1).contiguous()  # (N, L, H)
        return (x,), (h,)

    def loss_fn(self, inputs, labels) -> Tensor:
        x, = inputs
        return x.square().mean()

    def get_group(self) -> Any:
        return self.group

    def get_init_states(self):
        s = torch.zeros(self.num_layers, self.hidden_dims).to(self.device)
        return (s,)


if __name__ == "__main__":
    data_root = "./dataset"
    ctx = DistributedContext.init(backend="nccl", use_gpu=True)

    hybrid_matrix = ctx.get_hybrid_matrix()
    if hybrid_matrix.size(0) == 1:
        hybrid_matrix = hybrid_matrix.view(2, -1)
    ctx.sync_print(hybrid_matrix)

    # sp is sequence parallel
    # pp is partition parallel
    sp_group, pp_group = ctx.new_hybrid_subgroups(hybrid_matrix)

    # partition data
    if ctx.rank == 0:
        prepare_data(data_root, dist.get_world_size(pp_group))
    dist.barrier()

    g = GraphData.load_partition(
        data_root,
        dist.get_rank(pp_group), dist.get_world_size(pp_group),
    ).to(ctx.device)

    route = g.to_route(pp_group)  # only on subgroup
    num_features = g.node("dst")["x"].size(-1)
    num_classes = g.meta()["num_classes"]

    hidden_dims = 128
    num_layers = 3

    gnn = SimpleGNN(num_features, hidden_dims, num_layers).to(ctx.device)
    rnn = SimpleRNN(num_classes, hidden_dims, num_layers, device=ctx.device, group=sp_group).to(ctx.device)
    opt = torch.optim.Adam([p for p in gnn.parameters()] + [p for p in rnn.parameters()])

    for ep in range(1, 100+1):
        seq_len = 200
        xs = []
        opt.zero_grad()
        for _ in range(seq_len):  # snapshot parallel between partition parallel subgroups
            z = gnn(
                x = g.node("dst")["x"],
                edge_index = g.edge_index(),
                route = route,
            )
            xs.append(z.unsqueeze(1))
        x = torch.cat(xs, dim=1)  # (N, S, H)

        # loss = rnn.apply(32, x)[0].square().mean()
        # loss.backward()  # sequence and pipeline parallel on each graph nodes
        loss = rnn.fast_backward(32, (x,), (g.node("dst")["train_mask"],))

        # all reduce
        all_reduce_gradients(rnn)
        all_reduce_buffers(rnn)
        all_reduce_gradients(gnn)
        all_reduce_buffers(gnn)
        opt.step()
        ctx.sync_print(loss)
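For reference, the two commented-out lines in the training loop above are the autograd-integrated alternative to `fast_backward`. A minimal sketch using the same objects as the script (the loss mirrors the demo's placeholder `square().mean()`, not a real classification loss):

# SequencePipeFunction stitches the micro-batched sequence forward/backward into
# the surrounding autograd graph, so gradients also reach `gnn` through `x`.
out, = rnn.apply(32, x)        # (N, S, hidden_dims), assembled from micro-batches of 32
loss = out.square().mean()
loss.backward()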