Commit 6928f9b4 by Wenjie Huang

add layerpipe

parent 28564281
from .route import *
from .timeline import SequencePipe
from .layerpipe import LayerPipe, LayerDetach
from .sparse import *
\ No newline at end of file
import torch
import torch.nn as nn
import torch.autograd as autograd
from torch import Tensor
from typing import *
from .route import Route
from abc import ABC, abstractmethod
from contextlib import contextmanager
from .timeline.utils import vector_backward
from .utils import *
__all__ = [
    "LayerPipe",
    "LayerDetach",
]

class LayerPipe(ABC):
    def __init__(self) -> None:
self._layer_id: Optional[int] = None
self._snapshot_id: Optional[int] = None
self._rts: List[LayerPipeRuntime] = []
@property
def layer_id(self) -> int:
assert self._layer_id is not None
return self._layer_id
@property
def snapshot_id(self) -> int:
assert self._snapshot_id is not None
return self._snapshot_id
def apply(self,
num_layers: int,
num_snapshots: int,
) -> Sequence[Sequence[Tensor]]:
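        # A new runtime is created per call and kept in self._rts so that a later
        # backward() can replay the stored backward passes for every forward sweep.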
runtime = LayerPipeRuntime(num_layers, num_snapshots, self)
self._rts.append(runtime)
return runtime.forward()
def backward(self):
for runtime in self._rts:
runtime.backward()
self._rts.clear()
def all_reduce(self, async_op: bool = False):
works = []
for _, net in self.get_model():
ws = all_reduce_gradients(net, async_op=async_op)
if async_op:
works.extend(ws)
ws = all_reduce_buffers(net, async_op=async_op)
if async_op:
works.extend(ws)
        if async_op:
            return works
def to(self, device: Any):
for _, net in self.get_model():
net.to(device)
return self
def get_model(self) -> Sequence[Tuple[str, nn.Module]]:
models = []
for key in dir(self):
if key in {"layer_id", "snapshot_id"}:
continue
val = getattr(self, key)
if isinstance(val, nn.Module):
models.append((key, val))
return tuple(models)
    @abstractmethod
    def layer_inputs(self,
        inputs: Optional[Sequence[Tensor]] = None,
    ) -> Sequence[Tensor]:
        raise NotImplementedError

    @abstractmethod
    def layer_forward(self,
        inputs: Sequence[Tensor],
    ) -> Sequence[Tensor]:
        raise NotImplementedError

    @contextmanager
    def _switch_layer(self,
        layer_id: int,
        snapshot_id: int,
    ):
        # Temporarily expose (layer_id, snapshot_id) to the user-defined
        # layer_inputs/layer_forward hooks, restoring the previous values on exit.
        saved_layer_id = self._layer_id
        saved_snapshot_id = self._snapshot_id

        self._layer_id = layer_id
        self._snapshot_id = snapshot_id
        try:
            yield
        finally:
            self._layer_id = saved_layer_id
            self._snapshot_id = saved_snapshot_id
class LayerPipeRuntime:
def __init__(self,
num_layers: int,
num_snapshots: int,
program: LayerPipe,
) -> None:
self.num_layers = num_layers
self.num_snapshots = num_snapshots
self.program = program
self.ready_bw: Dict[Any, LayerDetach] = {}
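        # Keys are (layer, snapshot, stage): stage 0 holds the detached inputs
        # produced by layer_inputs ("sync"), stage 1 the detached outputs of
        # layer_forward ("comp"); entries are consumed again during backward().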
def forward(self) -> Sequence[Sequence[Tensor]]:
for op, layer_i, snap_i in ForwardFootprint(self.num_layers, self.num_snapshots):
if op == "sync":
                xs = self.ready_bw[(layer_i - 1, snap_i, 1)].values() if layer_i > 0 else None
                with self.program._switch_layer(layer_i, snap_i):
                    xs = self.program.layer_inputs(xs)
self.ready_bw[(layer_i, snap_i, 0)] = LayerDetach(*xs)
elif op == "comp":
xs = self.ready_bw[(layer_i, snap_i, 0)].values()
with self.program._switch_layer(layer_i, snap_i):
xs = self.program.layer_forward(xs)
self.ready_bw[(layer_i, snap_i, 1)] = LayerDetach(*xs)
xs = []
for snap_i in range(self.num_snapshots):
layer_i = self.num_layers - 1
xs.append(self.ready_bw[(layer_i, snap_i, 1)].values())
return xs
def backward(self):
for op, layer_i, snap_i in BackwardFootprint(self.num_layers, self.num_snapshots):
if op == "sync":
self.ready_bw.pop((layer_i, snap_i, 0)).backward()
elif op == "comp":
self.ready_bw.pop((layer_i, snap_i, 1)).backward()
assert len(self.ready_bw) == 0
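
# LayerDetach keeps the original tensors alongside detached copies that are handed
# to the next pipeline stage; backward() takes the gradients accumulated on the
# detached copies and pushes them back into the original graph via vector_backward.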
class LayerDetach:
def __init__(self,
*inputs: Tensor,
) -> None:
outputs = tuple(t.detach() for t in inputs)
for s, t in zip(inputs, outputs):
t.requires_grad_(s.requires_grad)
self._inputs = inputs
self._outputs = outputs
def values(self) -> Sequence[Tensor]:
return tuple(self._outputs)
def backward(self) -> None:
vec_loss, vec_grad = [], []
for s, t in zip(self._inputs, self._outputs):
g, t.grad = t.grad, None
if not s.requires_grad:
continue
vec_loss.append(s)
vec_grad.append(g)
vector_backward(vec_loss, vec_grad)
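
# ForwardFootprint yields a ("sync" | "comp", layer, snapshot) schedule that walks
# the layers for snapshots in pairs, issuing the next "sync" right after each
# "comp" so that communication can be overlapped with computation.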
class ForwardFootprint:
def __init__(self,
num_layers: int,
num_snapshots: int,
) -> None:
self._num_layers = num_layers
self._num_snapshots = num_snapshots
def __iter__(self):
if self._num_layers <= 0 or self._num_snapshots <= 0:
return
# starting
if self._num_snapshots > 1:
yield "sync", 0, 0
yield "sync", 0, 1
elif self._num_snapshots > 0:
yield "sync", 0, 0
for i in range(0, self._num_snapshots, 2):
for l in range(self._num_layers):
# snapshot i
yield "comp", l, i
if l + 1 < self._num_layers:
yield "sync", l + 1, i
elif i + 2 < self._num_snapshots:
yield "sync", 0, i + 2
# snapshot i + 1
if i + 1 >= self._num_snapshots:
continue
yield "comp", l, i + 1
if l + 1 < self._num_layers:
yield "sync", l + 1, i + 1
elif i + 3 < self._num_snapshots:
yield "sync", 0, i + 3
class BackwardFootprint:
def __init__(self,
num_layers: int,
num_snapshots: int,
) -> None:
self._num_layers = num_layers
self._num_snapshots = num_snapshots
def __iter__(self):
if self._num_layers <= 0 or self._num_snapshots <= 0:
return
for i in range(0, self._num_snapshots, 2):
for j in range(self._num_layers):
l = self._num_layers - j - 1
# snapshot i
yield "comp", l, i
yield "sync", l, i
# snapshot i + 1
if i + 1 >= self._num_snapshots:
continue
yield "comp", l, i + 1
yield "sync", l, i + 1
\ No newline at end of file
@@ -11,6 +11,7 @@ from contextlib import contextmanager
from .sync import VirtualMotions, VirtualForward, BatchSync
from .utils import vector_backward
from starrygl.parallel.utils import *
class SequencePipe(ABC):
    def __init__(self) -> None:
@@ -54,6 +55,24 @@ class SequencePipe(ABC):
            models.append((key, val))
        return tuple(models)
def to(self, device: Any):
for _, net in self.get_model():
net.to(device)
return self
def all_reduce(self, async_op: bool = False):
works = []
for name, net in self.get_model():
ws = all_reduce_gradients(net, async_op=async_op)
if async_op:
works.extend(ws)
ws = all_reduce_buffers(net, async_op=async_op)
if async_op:
works.extend(ws)
        if async_op:
            return works
    def apply(self, bs: int, *inputs: Tensor) -> Sequence[Tensor]:
        runtime = SequencePipeRuntime(bs, self)
        return SequencePipeFunction.apply(runtime, *inputs)
@@ -87,8 +106,9 @@ class SequencePipe(ABC):
        self._pos_begin = begin
        self._pos_end = end
        try:
            yield
        finally:
            self._pos_begin = saved_begin
            self._pos_end = saved_end
...
from typing import Any, List, Optional, Sequence, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
@@ -9,7 +9,7 @@ from typing import *
from starrygl.distributed import DistributedContext
from starrygl.data import GraphData
from starrygl.parallel import Route, SequencePipe, LayerPipe
from starrygl.parallel.utils import *
import torch_geometric.nn as pyg_nn
@@ -76,6 +76,35 @@ class SimpleGNN(nn.Module):
            x = layer(x, edge_index, route)
        return x
class SimpleGNNPipe(LayerPipe):
def __init__(self,
num_features: int,
hidden_dims: int,
num_layers: int,
features: Tensor,
edge_index: Tensor,
route: Route,
) -> None:
super().__init__()
self.features = features
self.edge_index = edge_index
self.route = route
self.net = SimpleGNN(num_features, hidden_dims, num_layers)
def layer_inputs(self, inputs: Sequence[Tensor] | None = None) -> Sequence[Tensor]:
if self.layer_id == 0:
x = self.features
else:
x, = inputs
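        # route.apply presumably exchanges this layer's input features across
        # graph partitions before the local convolution runs.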
x = self.route.apply(x)
return (x,)
def layer_forward(self, inputs: Sequence[Tensor]) -> Sequence[Tensor]:
x, = inputs
x = self.net.layers[self.layer_id](x, self.edge_index, self.route)
return (x,)
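
# Illustrative usage sketch (not part of the original commit): one possible way to
# drive SimpleGNNPipe for a single training step. The layer/snapshot counts, hidden
# size and the placeholder loss below are assumptions made for demonstration only.
def _layerpipe_demo_step(features: Tensor, edge_index: Tensor, route: Route):
    num_layers, num_snapshots = 3, 2  # hypothetical pipeline shape
    pipe = SimpleGNNPipe(
        features.size(-1), 128, num_layers,
        features, edge_index, route,
    ).to(features.device)

    outputs = pipe.apply(num_layers, num_snapshots)  # forward over all layers/snapshots
    loss = sum(x.square().mean() for xs in outputs for x in xs)  # placeholder loss
    loss.backward()    # gradients land on the detached last-layer outputs
    pipe.backward()    # replay the per-layer backward passes through the pipeline
    pipe.all_reduce()  # synchronize gradients and buffers across workers
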
class SimpleRNN(SequencePipe, nn.Module):
    def __init__(self,
        num_classes: int,
...