Commit 6928f9b4 by Wenjie Huang

add layerpipe

parent 28564281
from .route import *
from .timeline import SequencePipe
from .layerpipe import LayerPipe, LayerDetach
from .sparse import *
\ No newline at end of file
import torch
import torch.nn as nn
import torch.autograd as autograd
from torch import Tensor
from typing import *
from .route import Route
from abc import ABC, abstractmethod
from contextlib import contextmanager
from .timeline.utils import vector_backward
from .utils import *
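# LayerPipe pipelines a multi-layer, multi-snapshot model: a subclass provides
# layer_inputs() / layer_forward(), LayerPipeRuntime schedules those calls over
# the (layer, snapshot) grid, and LayerDetach cuts the autograd graph at every
# stage boundary so that the backward pass can later be replayed stage by stage.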
__all__ = [
    "LayerPipe",
    "LayerDetach",
]


class LayerPipe(ABC):
    def __init__(self) -> None:
        self._layer_id: Optional[int] = None
        self._snapshot_id: Optional[int] = None
        self._rts: List[LayerPipeRuntime] = []

    @property
    def layer_id(self) -> int:
        assert self._layer_id is not None
        return self._layer_id

    @property
    def snapshot_id(self) -> int:
        assert self._snapshot_id is not None
        return self._snapshot_id
def apply(self,
num_layers: int,
num_snapshots: int,
) -> Sequence[Sequence[Tensor]]:
runtime = LayerPipeRuntime(num_layers, num_snapshots, self)
self._rts.append(runtime)
return runtime.forward()
def backward(self):
for runtime in self._rts:
runtime.backward()
self._rts.clear()
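    # Note: backward() assumes the caller has already back-propagated the loss
    # into the tensors returned by apply(); those are detached copies, and their
    # accumulated .grad fields seed each LayerDetach.backward() call.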
def all_reduce(self, async_op: bool = False):
works = []
for _, net in self.get_model():
ws = all_reduce_gradients(net, async_op=async_op)
if async_op:
works.extend(ws)
ws = all_reduce_buffers(net, async_op=async_op)
if async_op:
works.extend(ws)
        if async_op:
            return works
def to(self, device: Any):
for _, net in self.get_model():
net.to(device)
return self
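    # get_model() discovers nn.Module attributes by reflection; layer_id and
    # snapshot_id are skipped because reading those properties would assert
    # before any apply() call has set them.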
def get_model(self) -> Sequence[Tuple[str, nn.Module]]:
models = []
for key in dir(self):
if key in {"layer_id", "snapshot_id"}:
continue
val = getattr(self, key)
if isinstance(val, nn.Module):
models.append((key, val))
return tuple(models)
    def layer_inputs(self,
        inputs: Optional[Sequence[Tensor]] = None,
    ) -> Sequence[Tensor]:
        raise NotImplementedError

    def layer_forward(self,
        inputs: Sequence[Tensor],
    ) -> Sequence[Tensor]:
        raise NotImplementedError
def load_node(self) -> Sequence[Tensor]:
pass
def load_edge(self) -> Sequence[Tensor]:
pass
def load_conv_state(self) -> Sequence[Tensor]:
pass
def load_time_state(self) -> Sequence[Tensor]:
pass
    def save_as_node(self, *outputs: Tensor):
        pass

    def save_as_edge(self, *outputs: Tensor):
        pass

    def save_as_conv_state(self, *states: Tensor):
        pass

    def save_as_time_state(self, *states: Tensor):
        pass
@contextmanager
def _switch_layer(self,
layer_id: int,
snapshot_id: int,
):
saved_layer_id = self._layer_id
saved_snapshot_id = self._snapshot_id
self._layer_id = layer_id
self._snapshot_id = snapshot_id
try:
yield
finally:
self._layer_id = saved_layer_id
self._snapshot_id = saved_snapshot_id
class LayerPipeRuntime:
def __init__(self,
num_layers: int,
num_snapshots: int,
program: LayerPipe,
) -> None:
self.num_layers = num_layers
self.num_snapshots = num_snapshots
self.program = program
self.ready_bw: Dict[Any, LayerDetach] = {}
def forward(self) -> Sequence[Sequence[Tensor]]:
for op, layer_i, snap_i in ForwardFootprint(self.num_layers, self.num_snapshots):
if op == "sync":
                xs = self.ready_bw[(layer_i - 1, snap_i, 1)].values() if layer_i > 0 else None
                with self.program._switch_layer(layer_i, snap_i):
                    xs = self.program.layer_inputs(xs)
self.ready_bw[(layer_i, snap_i, 0)] = LayerDetach(*xs)
elif op == "comp":
xs = self.ready_bw[(layer_i, snap_i, 0)].values()
with self.program._switch_layer(layer_i, snap_i):
xs = self.program.layer_forward(xs)
self.ready_bw[(layer_i, snap_i, 1)] = LayerDetach(*xs)
xs = []
for snap_i in range(self.num_snapshots):
layer_i = self.num_layers - 1
xs.append(self.ready_bw[(layer_i, snap_i, 1)].values())
return xs
def backward(self):
for op, layer_i, snap_i in BackwardFootprint(self.num_layers, self.num_snapshots):
if op == "sync":
self.ready_bw.pop((layer_i, snap_i, 0)).backward()
elif op == "comp":
self.ready_bw.pop((layer_i, snap_i, 1)).backward()
assert len(self.ready_bw) == 0
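# ready_bw maps (layer, snapshot, stage) to a LayerDetach boundary: stage 0 holds
# the gathered inputs produced by "sync" steps, stage 1 the layer outputs produced
# by "comp" steps. backward() pops the boundaries in the reverse schedule until
# every one of them has been consumed.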
class LayerDetach:
def __init__(self,
*inputs: Tensor,
) -> None:
outputs = tuple(t.detach() for t in inputs)
for s, t in zip(inputs, outputs):
t.requires_grad_(s.requires_grad)
self._inputs = inputs
self._outputs = outputs
def values(self) -> Sequence[Tensor]:
return tuple(self._outputs)
def backward(self) -> None:
vec_loss, vec_grad = [], []
for s, t in zip(self._inputs, self._outputs):
g, t.grad = t.grad, None
if not s.requires_grad:
continue
vec_loss.append(s)
vec_grad.append(g)
vector_backward(vec_loss, vec_grad)
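# LayerDetach keeps the original tensors next to detached copies that preserve
# requires_grad. Later stages compute on the copies, so gradients accumulate in
# the copies' .grad fields; backward() then pushes those gradients into the
# originals via vector_backward, resuming autograd only inside the producing stage.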
class ForwardFootprint:
def __init__(self,
num_layers: int,
num_snapshots: int,
) -> None:
self._num_layers = num_layers
self._num_snapshots = num_snapshots
def __iter__(self):
if self._num_layers <= 0 or self._num_snapshots <= 0:
return
# starting
if self._num_snapshots > 1:
yield "sync", 0, 0
yield "sync", 0, 1
elif self._num_snapshots > 0:
yield "sync", 0, 0
for i in range(0, self._num_snapshots, 2):
for l in range(self._num_layers):
# snapshot i
yield "comp", l, i
if l + 1 < self._num_layers:
yield "sync", l + 1, i
elif i + 2 < self._num_snapshots:
yield "sync", 0, i + 2
# snapshot i + 1
if i + 1 >= self._num_snapshots:
continue
yield "comp", l, i + 1
if l + 1 < self._num_layers:
yield "sync", l + 1, i + 1
elif i + 3 < self._num_snapshots:
yield "sync", 0, i + 3
class BackwardFootprint:
def __init__(self,
num_layers: int,
num_snapshots: int,
) -> None:
self._num_layers = num_layers
self._num_snapshots = num_snapshots
def __iter__(self):
if self._num_layers <= 0 or self._num_snapshots <= 0:
return
for i in range(0, self._num_snapshots, 2):
for j in range(self._num_layers):
l = self._num_layers - j - 1
# snapshot i
yield "comp", l, i
yield "sync", l, i
# snapshot i + 1
if i + 1 >= self._num_snapshots:
continue
yield "comp", l, i + 1
yield "sync", l, i + 1
class SparseLayerPipe(ABC):
    def __init__(self) -> None:
        pass

    @abstractmethod
    def get_group(self):
        raise NotImplementedError

    @abstractmethod
    def get_route(self, tag: int) -> Route:
        raise NotImplementedError

    @abstractmethod
    def get_graph(self, tag: int) -> Any:
        raise NotImplementedError

    @abstractmethod
    def forward(self,
        tag: int,
        dist_inputs: Sequence[Tensor],
        self_inputs: Sequence[Tensor],
        time_states: Sequence[Tensor],
    ) -> Tuple[
        Sequence[Tensor],
        Sequence[Tensor],
        Sequence[Tensor],
    ]:
        raise NotImplementedError
\ No newline at end of file
@@ -11,6 +11,7 @@ from contextlib import contextmanager
from .sync import VirtualMotions, VirtualForward, BatchSync
from .utils import vector_backward
from starrygl.parallel.utils import *
class SequencePipe(ABC):
def __init__(self) -> None:
@@ -54,6 +55,24 @@ class SequencePipe(ABC):
models.append((key, val))
return tuple(models)
def to(self, device: Any):
for _, net in self.get_model():
net.to(device)
return self
def all_reduce(self, async_op: bool = False):
works = []
for name, net in self.get_model():
ws = all_reduce_gradients(net, async_op=async_op)
if async_op:
works.extend(ws)
ws = all_reduce_buffers(net, async_op=async_op)
if async_op:
works.extend(ws)
        if async_op:
            return works
def apply(self, bs: int, *inputs: Tensor) -> Sequence[Tensor]:
runtime = SequencePipeRuntime(bs, self)
return SequencePipeFunction.apply(runtime, *inputs)
@@ -87,8 +106,9 @@ class SequencePipe(ABC):
self._pos_begin = begin
self._pos_end = end
try:
yield
finally:
self._pos_begin = saved_begin
self._pos_end = saved_end
......
from typing import Any, List, Optional, Tuple
from typing import Any, List, Optional, Sequence, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
@@ -9,7 +9,7 @@ from typing import *
from starrygl.distributed import DistributedContext
from starrygl.data import GraphData
from starrygl.parallel import Route, SequencePipe
from starrygl.parallel import Route, SequencePipe, LayerPipe
from starrygl.parallel.utils import *
import torch_geometric.nn as pyg_nn
@@ -76,6 +76,35 @@ class SimpleGNN(nn.Module):
x = layer(x, edge_index, route)
return x
class SimpleGNNPipe(LayerPipe):
def __init__(self,
num_features: int,
hidden_dims: int,
num_layers: int,
features: Tensor,
edge_index: Tensor,
route: Route,
) -> None:
super().__init__()
self.features = features
self.edge_index = edge_index
self.route = route
self.net = SimpleGNN(num_features, hidden_dims, num_layers)
def layer_inputs(self, inputs: Sequence[Tensor] | None = None) -> Sequence[Tensor]:
if self.layer_id == 0:
x = self.features
else:
x, = inputs
x = self.route.apply(x)
return (x,)
def layer_forward(self, inputs: Sequence[Tensor]) -> Sequence[Tensor]:
x, = inputs
x = self.net.layers[self.layer_id](x, self.edge_index, self.route)
return (x,)
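# A rough usage sketch for SimpleGNNPipe (illustrative only; it assumes
# num_features, hidden_dims, num_layers, features, edge_index, route and device
# have been prepared as in the surrounding example):
#
#   pipe = SimpleGNNPipe(num_features, hidden_dims, num_layers,
#                        features, edge_index, route).to(device)
#   outs = pipe.apply(num_layers, 1)   # one group of outputs per snapshot
#   out, = outs[0]
#   loss = out.square().mean()         # placeholder loss
#   loss.backward()                    # grads land on the detached outputs
#   pipe.backward()                    # replay backward through all layers
#   pipe.all_reduce()                  # all-reduce gradients and buffers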
class SimpleRNN(SequencePipe, nn.Module):
def __init__(self,
num_classes: int,
......