Commit e6c06ee9 by Wenjie Huang

apply async routing to layerpipe

parent 6928f9b4
@@ -7,6 +7,7 @@ from typing import *
from abc import ABC, abstractmethod
from contextlib import contextmanager
from .route import Route, RouteWork
from .timeline.utils import vector_backward
from .utils import *
@@ -74,6 +75,14 @@ class LayerPipe(ABC):
models.append((key, val))
return tuple(models)
def register_route(self, *xs: Tensor):
for t in xs:
t.requires_route = True
@abstractmethod
def get_route(self) -> Route:
raise NotImplementedError
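The two new hooks define the contract a concrete pipeline must now satisfy: layer_inputs tags the tensors whose values have to be fetched from other partitions by calling register_route (which only sets a requires_route flag on each tensor), and get_route returns the Route object that performs the exchange. A minimal, hypothetical subclass sketch follows (the names route and feats, the no-argument base __init__, and the abridged override signatures are assumptions, not part of this commit):
class MyPipe(LayerPipe):
    def __init__(self, route: Route, feats: Tensor):
        super().__init__()           # assumed: LayerPipe.__init__ needs no arguments
        self.route = route           # assumed: the Route for this graph partition
        self.feats = feats
    def get_route(self) -> Route:
        return self.route
    def layer_inputs(self, inputs: Optional[Sequence[Tensor]] = None) -> Sequence[Tensor]:
        xs = (self.feats,) if inputs is None else tuple(inputs)
        self.register_route(*xs)     # sets t.requires_route = True on each tensor
        return xs
    def layer_forward(self, inputs: Sequence[Tensor]) -> Sequence[Tensor]:
        return inputs                # placeholder; a real subclass applies the current layer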
@abstractmethod
def layer_inputs(self,
inputs: Optional[Sequence[Tensor]] = None,
@@ -84,7 +93,7 @@
def layer_forward(self,
inputs: Sequence[Tensor],
) -> Sequence[Tensor]:
raise NotImplemented
raise NotImplementedError
@contextmanager
def _switch_layer(self,
@@ -112,7 +121,7 @@ class LayerPipeRuntime:
self.num_layers = num_layers
self.num_snapshots = num_snapshots
self.program = program
self.ready_bw: Dict[Any, LayerDetach] = {}
self.ready_bw: Dict[Any, Union[LayerDetach, LayerRoute]] = {}
def forward(self) -> Sequence[Sequence[Tensor]]:
for op, layer_i, snap_i in ForwardFootprint(self.num_layers, self.num_snapshots):
@@ -120,7 +129,8 @@
xs = self.ready_bw[(layer_i - 1, snap_i, 1)].values() if layer_i > 0 else None
with self.program._switch_layer(layer_i, snap_i):
xs = self.program.layer_inputs(None)
self.ready_bw[(layer_i, snap_i, 0)] = LayerDetach(*xs)
route = self.program.get_route()
self.ready_bw[(layer_i, snap_i, 0)] = LayerRoute(route, *xs)
elif op == "comp":
xs = self.ready_bw[(layer_i, snap_i, 0)].values()
with self.program._switch_layer(layer_i, snap_i):
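The "sync" step now does more than detach: wrapping the inputs in a LayerRoute launches the asynchronous fw_tensor exchange right away, and the later "comp" step's call to values() is the first point that blocks on it, so the communication overlaps with whatever other layer or snapshot work runs in between. A hedged sketch of that overlap pattern (the helper name in the middle line is illustrative, not from this diff):
work = route.fw_tensor(x, async_op=True)   # returns a RouteWork without blocking
run_other_layers_or_snapshots()            # exchange proceeds in the background
x_remote = work.wait()                     # block only once the value is needed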
@@ -136,12 +146,15 @@
def backward(self):
for op, layer_i, snap_i in BackwardFootprint(self.num_layers, self.num_snapshots):
if op == "sync":
self.ready_bw.pop((layer_i, snap_i, 0)).backward()
self.ready_bw[(layer_i, snap_i, 0)].backward()
elif op == "comp":
if layer_i + 1 < self.num_layers:
self.ready_bw.pop((layer_i + 1, snap_i, 0)).wait_gradients()
self.ready_bw.pop((layer_i, snap_i, 1)).backward()
for snap_i in range(self.num_snapshots):
self.ready_bw.pop((0, snap_i, 0)).wait_gradients()
assert len(self.ready_bw) == 0
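On the backward pass the (layer_i, snap_i, 0) entry is no longer popped at its "sync" step, because LayerRoute.backward() only launches the asynchronous bw_tensor sends; the entry is popped later, either by the next lower layer's "comp" step or by the final per-snapshot loop, when wait_gradients() completes the exchange, and the closing assert confirms nothing was left pending. A hedged sketch of how a caller might drive the runtime (the loss computation and its backward are outside this diff and assumed):
runtime = LayerPipeRuntime(num_layers, num_snapshots, program)
outs = runtime.forward()        # per-snapshot output tensors
compute_loss(outs).backward()   # assumed: fills .grad on the detached outputs
runtime.backward()              # drains ready_bw, overlapping gradient exchange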
class LayerDetach:
def __init__(self,
*inputs: Tensor,
@@ -166,6 +179,69 @@ class LayerDetach:
vec_grad.append(g)
vector_backward(vec_loss, vec_grad)
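Both wait paths end in vector_backward, which is imported from .timeline.utils but not shown in this diff. A minimal stand-in, assuming it simply back-propagates each tensor with its matching incoming gradient in a single autograd call (the real helper may do more, e.g. filtering or scaling):
import torch

def vector_backward_sketch(vec_loss, vec_grad):
    tensors, grads = [], []
    for t, g in zip(vec_loss, vec_grad):
        if t is None or g is None:      # skip entries with nothing to propagate
            continue
        tensors.append(t)
        grads.append(g)
    if tensors:
        torch.autograd.backward(tensors, grads)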
class LayerRoute:
def __init__(self,
route: Route,
*inputs: Tensor,
) -> None:
self._route = route
self._works: Optional[List[Union[Tensor, RouteWork]]] = []
for t in inputs:
r = t.requires_route if hasattr(t, "requires_route") else False
if r:
self._works.append(self._route.fw_tensor(t, async_op=True))
else:
self._works.append(t.detach())
self._inputs = inputs
self._outputs: Optional[List[Tensor]] = None
def values(self) -> Sequence[Tensor]:
if self._outputs is None:
works, self._works = self._works, None
assert works is not None
outputs = []
for s, t in zip(self._inputs, works):
if isinstance(t, RouteWork):
t = t.wait()
t = t.requires_grad_(s.requires_grad)
outputs.append(t)
self._outputs = outputs
return self._outputs
def backward(self):
assert self._works is None
assert self._outputs is not None
works = []
for s, t in zip(self._inputs, self._outputs):
g, t.grad = t.grad, None
rs = s.requires_route if hasattr(s, "requires_route") else False
rg = s.requires_grad
if rg and rs:
works.append(self._route.bw_tensor(g, async_op=True))
elif rg:
works.append(g)
else:
works.append(None)
self._works = works
self._outputs = None
def wait_gradients(self):
if self._works is None:
return
works, self._works = self._works, None
vec_loss, vec_grad = [], []
for t, g in zip(self._inputs, works):
if isinstance(g, RouteWork):
g = g.wait()
if not t.requires_grad:
continue
vec_loss.append(t)
vec_grad.append(g)
vector_backward(vec_loss, vec_grad)
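Taken together, LayerRoute splits a routed tensor's round trip into four phases, each of which blocks only where it must. A hedged end-to-end sketch (route and layer are assumed stand-ins; the per-call behavior follows the class above):
x.requires_route = True    # what program.register_route(x) does
lr = LayerRoute(route, x)  # phase 1: launches route.fw_tensor(x, async_op=True)
(h,) = lr.values()         # phase 2: waits on the RouteWork, restores requires_grad
layer(h).sum().backward()  # assumed compute; autograd fills h.grad (h is a leaf)
lr.backward()              # phase 3: launches route.bw_tensor(h.grad, async_op=True)
lr.wait_gradients()        # phase 4: waits, then vector_backward into x.grad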
class ForwardFootprint:
def __init__(self,
......