Commit 2cdc59cc by Wenjie Huang

optimize SequencePipe's communication

parent 7302be6b
import torch
from torch import Tensor
from contextlib import contextmanager
from typing import *
__all__ = [
    "ABCStream",
    "ABCEvent",
    "phony_tensor",
    "new_stream",
    "current_stream",
    "default_stream",
    "use_stream",
    "use_device",
    "wait_stream",
    "wait_event",
    "record_stream",
]

# Singleton placeholders that stand in for CUDA streams/events on CPU devices;
# the helpers below special-case them instead of issuing CUDA calls.
class CPUStreamType:
    def __init__(self) -> None:
        self._device = torch.device("cpu")

    @property
    def device(self):
        return self._device

    def __call__(self):
        return self


class CPUEventType:
    def __init__(self) -> None:
        self._device = torch.device("cpu")

    @property
    def device(self):
        return self._device

    def __call__(self):
        return self


CPUStream = CPUStreamType()
ABCStream = Union[torch.cuda.Stream, CPUStreamType]

CPUEvent = CPUEventType()
ABCEvent = Union[torch.cuda.Event, CPUEventType]

def new_stream(device: Any) -> ABCStream:
    """Create a new stream on ``device`` (the CPU placeholder for non-CUDA devices)."""
    device = torch.device(device)
    if device.type != "cuda":
        return CPUStream()
    return torch.cuda.Stream(device)

_phonies: Dict[Tuple[torch.device, bool], Tensor] = {}

def phony_tensor(device: Any, requires_grad: bool = True):
    """Return a cached zero-element tensor on ``device``.

    Phony tensors carry no data but can participate in autograd, which makes
    them a cheap way to express ordering dependencies between graph branches.
    """
    device = torch.device(device)
    key = (device, requires_grad)
    if key not in _phonies:
        with use_stream(default_stream(device)):
            _phonies[key] = torch.empty(
                0, device=device,
                requires_grad=requires_grad,
            )
    return _phonies[key]
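
# Illustrative sketch, not part of this commit: the phony is cached per
# (device, requires_grad) key, so repeated calls return the very same
# zero-element tensor, e.g.
#
#     token = phony_tensor("cpu")
#     assert token.numel() == 0 and token.requires_grad
#     assert token is phony_tensor("cpu")                            # cache hit
#     assert token is not phony_tensor("cpu", requires_grad=False)   # different key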
def current_stream(device: Any) -> ABCStream:
    device = torch.device(device)
    if device.type != "cuda":
        return CPUStream()
    return torch.cuda.current_stream(device)

def default_stream(device: Any) -> ABCStream:
    device = torch.device(device)
    if device.type != "cuda":
        return CPUStream()
    return torch.cuda.default_stream(device)

@contextmanager
def use_stream(stream: ABCStream, fence_event: bool = False):
    """Make ``stream`` the current stream inside the block.

    With ``fence_event=True`` the manager yields an event that is recorded on
    ``stream`` once the block exits (a no-op placeholder event on CPU), so
    other streams can later wait on it.
    """
    if isinstance(stream, CPUStreamType):
        if fence_event:
            event = CPUEvent()
            yield event
        else:
            yield
        return
    with torch.cuda.stream(stream):
        if fence_event:
            event = torch.cuda.Event()
            yield event
            event.record()
        else:
            yield

@contextmanager
def use_device(device: Any):
    device = torch.device(device)
    if device.type != "cuda":
        yield
        return
    with torch.cuda.device(device):
        yield

def wait_stream(source: ABCStream, target: ABCStream):
    """Make ``source`` wait until all work currently queued on ``target`` is done."""
    if isinstance(target, CPUStreamType):
        # Nothing to wait for on a CPU target.
        return
    if isinstance(source, CPUStreamType):
        # A CPU source has to block the host until the CUDA target finishes.
        target.synchronize()
    else:
        source.wait_stream(target)

def wait_event(source: ABCStream, target: ABCEvent):
    """Make ``source`` wait until the ``target`` event has completed."""
    if isinstance(target, CPUEventType):
        return
    if isinstance(source, CPUStreamType):
        target.synchronize()
    else:
        source.wait_event(target)

def record_stream(tensor: Tensor, stream: ABCStream):
    """Mark ``tensor`` as in use on ``stream`` for the caching allocator."""
    if isinstance(stream, CPUStreamType):
        return
    # Record on a temporary tensor that views the full untyped storage rather
    # than on ``tensor`` itself, which may be an offset view of a larger allocation.
    storage = tensor.untyped_storage()
    tensor = tensor.new_empty(0).set_(storage)
    tensor.record_stream(stream)
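
# --- Illustrative sketch, not part of this commit ----------------------------
# The helpers above can be combined to run a host-to-device copy on a side
# stream and hand the result to the current stream without a device-wide
# synchronization. `_example_async_copy` and its argument names are
# placeholders added purely for illustration.
def _example_async_copy(host: Tensor, device: Any) -> Tensor:
    device = torch.device(device)
    copy_stream = new_stream(device)
    with use_stream(copy_stream, fence_event=True) as done:
        dev = host.to(device, non_blocking=True)
    # `done` was recorded on `copy_stream` when the block exited; make the
    # current stream wait for it before anything consumes `dev`, and mark the
    # allocation as also used on the current stream.
    wait_event(current_stream(device), done)
    record_stream(dev, current_stream(device))
    return dev
# ------------------------------------------------------------------------------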
from .route import *
from .sequence import *
from .timeline import SequencePipe
from .sparse import *
\ No newline at end of file
from .pipe import SequencePipe
\ No newline at end of file
import torch
from torch import Tensor
from typing import *
def vector_backward(
    vec_loss: Sequence[Tensor],
    vec_grad: Sequence[Optional[Tensor]],
):
    """Backward through several (output, gradient) pairs with one call.

    Pairs whose gradient is ``None`` or whose output does not require grad are
    skipped; the rest are flattened and concatenated so a single ``backward``
    launches the whole accumulation.
    """
    loss: List[Tensor] = []
    grad: List[Tensor] = []
    for x, g in zip(vec_loss, vec_grad):
        if g is None:
            continue
        if not x.requires_grad:
            continue
        loss.append(x.flatten())
        grad.append(g.flatten())
    if loss:
        torch.cat(loss, dim=0).backward(torch.cat(grad, dim=0))
\ No newline at end of file
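
# Illustrative sketch, not part of this commit: `vector_backward` runs one
# backward pass over several (output, gradient) pairs at once; pairs with a
# ``None`` gradient are skipped. The tensors below are placeholders.
#
#     w = torch.randn(3, requires_grad=True)
#     out1, out2 = w * 2, w + 1
#     vector_backward([out1, out2], [torch.ones(3), None])
#     # only the (out1, ones) pair contributes, so w.grad == tensor([2., 2., 2.])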
@@ -10,7 +10,6 @@ from typing import *
from starrygl.distributed import DistributedContext
from starrygl.data import GraphData
from starrygl.parallel import Route, SequencePipe
from starrygl.parallel.sequence import STensor
from starrygl.parallel.utils import *
import torch_geometric.nn as pyg_nn
......@@ -129,7 +128,6 @@ if __name__ == "__main__":
hybrid_matrix = ctx.get_hybrid_matrix()
if hybrid_matrix.size(0) == 1:
hybrid_matrix = hybrid_matrix.view(2, -1)
ctx.sync_print(hybrid_matrix)
# sp is sequence parallel
# pp is partition parallel
sp_group, pp_group = ctx.new_hybrid_subgroups(hybrid_matrix)