Commit e6c06ee9 by Wenjie Huang

apply async routing to layerpipe

parent 6928f9b4
@@ -7,6 +7,7 @@ from typing import *
 from abc import ABC, abstractmethod
 from contextlib import contextmanager
+from .route import Route, RouteWork
 from .timeline.utils import vector_backward
 from .utils import *
@@ -74,6 +75,14 @@ class LayerPipe(ABC):
             models.append((key, val))
         return tuple(models)
+
+    def register_route(self, *xs: Tensor):
+        for t in xs:
+            t.requires_route = True
+
+    @abstractmethod
+    def get_route(self) -> Route:
+        raise NotImplementedError
 
     @abstractmethod
     def layer_inputs(self,
         inputs: Optional[Sequence[Tensor]] = None,
@@ -84,7 +93,7 @@ class LayerPipe(ABC):
     def layer_forward(self,
         inputs: Sequence[Tensor],
     ) -> Sequence[Tensor]:
-        raise NotImplemented
+        raise NotImplementedError
 
     @contextmanager
     def _switch_layer(self,
@@ -112,7 +121,7 @@ class LayerPipeRuntime:
         self.num_layers = num_layers
         self.num_snapshots = num_snapshots
         self.program = program
-        self.ready_bw: Dict[Any, LayerDetach] = {}
+        self.ready_bw: Dict[Any, Union[LayerDetach, LayerRoute]] = {}
 
     def forward(self) -> Sequence[Sequence[Tensor]]:
         for op, layer_i, snap_i in ForwardFootprint(self.num_layers, self.num_snapshots):
@@ -120,7 +129,8 @@ class LayerPipeRuntime:
                 xs = self.ready_bw[(layer_i - 1, snap_i, 1)].values() if layer_i > 0 else None
                 with self.program._switch_layer(layer_i, snap_i):
                     xs = self.program.layer_inputs(None)
-                    self.ready_bw[(layer_i, snap_i, 0)] = LayerDetach(*xs)
+                    route = self.program.get_route()
+                    self.ready_bw[(layer_i, snap_i, 0)] = LayerRoute(route, *xs)
             elif op == "comp":
                 xs = self.ready_bw[(layer_i, snap_i, 0)].values()
                 with self.program._switch_layer(layer_i, snap_i):
@@ -136,12 +146,15 @@ class LayerPipeRuntime:
     def backward(self):
         for op, layer_i, snap_i in BackwardFootprint(self.num_layers, self.num_snapshots):
             if op == "sync":
-                self.ready_bw.pop((layer_i, snap_i, 0)).backward()
+                self.ready_bw[(layer_i, snap_i, 0)].backward()
             elif op == "comp":
+                if layer_i + 1 < self.num_layers:
+                    self.ready_bw.pop((layer_i + 1, snap_i, 0)).wait_gradients()
                 self.ready_bw.pop((layer_i, snap_i, 1)).backward()
+        for snap_i in range(self.num_snapshots):
+            self.ready_bw.pop((0, snap_i, 0)).wait_gradients()
         assert len(self.ready_bw) == 0
 
 class LayerDetach:
     def __init__(self,
         *inputs: Tensor,
@@ -166,6 +179,69 @@ class LayerDetach:
             vec_grad.append(g)
         vector_backward(vec_loss, vec_grad)
+
+
+class LayerRoute:
+    def __init__(self,
+        route: Route,
+        *inputs: Tensor,
+    ) -> None:
+        self._route = route
+        self._works: Optional[List[Union[Tensor, RouteWork]]] = []
+        for t in inputs:
+            r = t.requires_route if hasattr(t, "requires_route") else False
+            if r:
+                self._works.append(self._route.fw_tensor(t, async_op=True))
+            else:
+                self._works.append(t.detach())
+        self._inputs = inputs
+        self._outputs: Optional[List[Tensor]] = None
+
+    def values(self) -> Sequence[Tensor]:
+        if self._outputs is None:
+            works, self._works = self._works, None
+            assert works is not None
+            outputs = []
+            for s, t in zip(self._inputs, works):
+                if isinstance(t, RouteWork):
+                    t = t.wait()
+                t = t.requires_grad_(s.requires_grad)
+                outputs.append(t)
+            self._outputs = outputs
+        return self._outputs
+
+    def backward(self):
+        assert self._works is None
+        assert self._outputs is not None
+        works = []
+        for s, t in zip(self._inputs, self._outputs):
+            g, t.grad = t.grad, None
+            rs = s.requires_route if hasattr(s, "requires_route") else False
+            rg = s.requires_grad
+            if rg and rs:
+                works.append(self._route.bw_tensor(g, async_op=True))
+            elif rg:
+                works.append(g)
+            else:
+                works.append(None)
+        self._works = works
+        self._outputs = None
+
+    def wait_gradients(self):
+        if self._works is None:
+            return
+        works, self._works = self._works, None
+        vec_loss, vec_grad = [], []
+        for t, g in zip(self._inputs, works):
+            if isinstance(g, RouteWork):
+                g = g.wait()
+            if not t.requires_grad:
+                continue
+            vec_loss.append(t)
+            vec_grad.append(g)
+        vector_backward(vec_loss, vec_grad)
+
+
 class ForwardFootprint:
     def __init__(self,
...
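
For reference, below is a minimal, self-contained sketch of the issue-then-wait pattern that LayerRoute builds on. DummyRoute and DummyWork are hypothetical stand-ins invented for this sketch, not part of the repo; only the fw_tensor/bw_tensor(..., async_op=True) followed by wait() contract and the requires_route marker (set by LayerPipe.register_route) mirror the diff above.

import torch
from torch import Tensor

class DummyWork:
    # Stand-in for RouteWork: wraps an already-finished result and returns it on wait().
    def __init__(self, result: Tensor) -> None:
        self._result = result

    def wait(self) -> Tensor:
        return self._result

class DummyRoute:
    # Stand-in for Route: "routes" a tensor by handing back a copy. A real Route would
    # launch communication here and let the caller overlap other work before wait().
    def fw_tensor(self, t: Tensor, async_op: bool = False):
        work = DummyWork(t.detach().clone())
        return work if async_op else work.wait()

    def bw_tensor(self, g: Tensor, async_op: bool = False):
        work = DummyWork(g)
        return work if async_op else work.wait()

route = DummyRoute()
x = torch.randn(4, 8, requires_grad=True)
x.requires_route = True                     # marker attribute set by register_route

# Forward: issue the routing immediately, consume the result only when it is needed.
work = route.fw_tensor(x, async_op=True)    # communication would be in flight here
y = work.wait().requires_grad_(x.requires_grad)
loss = (y * 2).sum()
loss.backward()

# Backward: route the produced gradient the same way, waiting only before it is used.
grad_work = route.bw_tensor(y.grad, async_op=True)
print(grad_work.wait().shape)               # torch.Size([4, 8])

In the runtime above, LayerRoute issues fw_tensor in the forward "sync" step and waits in values() during the "comp" step; likewise backward() launches bw_tensor asynchronously, and wait_gradients() is only called just before the preceding layer's backward compute (or in the final cleanup loop), which is what allows gradient routing to overlap with the intervening backward work.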