Commit 88de1d9c by Wenjie Huang

SequencePipe supports long-dtype tensors

parent 32fec45c
@@ -4,7 +4,7 @@ from torch_geometric.utils import add_remaining_self_loops, to_undirected
 import os.path as osp
 import sys
-from starrygl.graph import GraphData
+from starrygl.data import GraphData
 import logging
 logging.getLogger().setLevel(logging.INFO)
...
@@ -149,6 +149,9 @@ def batch_send(
     group: Any = None,
     async_op: bool = False,
 ):
+    if len(tensors) == 0:
+        return BatchWork(None, None)
     # tensors = tuple(t.data for t in tensors)
     backend = dist.get_backend(group)
@@ -171,6 +174,9 @@ def batch_recv(
     group: Any = None,
     async_op: bool = False,
 ):
+    if len(tensors) == 0:
+        return BatchWork(None, None)
     # tensors = tuple(t.data for t in tensors)
     backend = dist.get_backend(group)
...
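
The two guards are symmetric: a rank that has nothing to send or receive in a given step can now call these helpers with an empty tensor list and get back an inert BatchWork instead of reaching dist.get_backend() and the point-to-point calls. A hedged usage sketch follows; the exact positional signature of batch_send is not shown in the hunks, so the destination-rank keyword (dst) is assumed, as is that BatchWork.wait() is a no-op when the object was built as BatchWork(None, None):

# Illustrative only: dst and the call shape are assumptions, not from the diff.
send_bufs: list = []   # this rank happens to have nothing to transmit this round

work = batch_send(*send_bufs, dst=1, group=None, async_op=True)
work.wait()            # returns immediately: no backend requests were issued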
import torch
import torch.nn as nn
import torch.distributed as dist

from torch import Tensor
from typing import *

__all__ = [
    "all_reduce_gradients",
    "all_reduce_buffers",
]

def all_reduce_gradients(net: nn.Module, op = dist.ReduceOp.SUM, group = None):
    # Sum (by default) every parameter's gradient across the process group,
    # so each rank ends up with the globally accumulated gradient.
    for p in net.parameters():
        if p.grad is not None:  # parameters that received no gradient are skipped
            dist.all_reduce(p.grad, op=op, group=group)

def all_reduce_buffers(net: nn.Module, op = dist.ReduceOp.AVG, group = None):
    # Average (by default) non-parameter state such as BatchNorm running
    # statistics, keeping buffers identical across ranks.
    for b in net.buffers():
        dist.all_reduce(b.data, op=op, group=group)
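
Together these helpers implement manual gradient and buffer synchronization for data-parallel training. A minimal sketch of how they might slot into a training step, assuming the default process group is already initialized and the backend supports ReduceOp.AVG (e.g. NCCL); train_step and its arguments are illustrative, not part of the commit:

import torch
import torch.nn as nn
import torch.distributed as dist

def train_step(net: nn.Module, batch: torch.Tensor, target: torch.Tensor,
               opt: torch.optim.Optimizer, world_size: int):
    opt.zero_grad()
    loss = nn.functional.mse_loss(net(batch), target)
    loss.backward()
    # Sum gradients over all ranks, then rescale to a mean so the
    # effective step size does not grow with world size.
    all_reduce_gradients(net, op=dist.ReduceOp.SUM)
    for p in net.parameters():
        if p.grad is not None:
            p.grad.div_(world_size)
    opt.step()
    # Keep running statistics (e.g. BatchNorm) consistent across ranks.
    all_reduce_buffers(net, op=dist.ReduceOp.AVG)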