Commit f84aa32e by xxx

Merge branch 'master' of http://192.168.1.53:8082/wjie98/starrygl into hzq

parents 64861da6 2177c0ed
...@@ -110,6 +110,7 @@ add_library(${SAMLPER_NAME} SHARED ${SAMPLER_SRCS})
target_include_directories(${SAMLPER_NAME} PRIVATE "csrc/sampler/include")
+target_compile_options(${SAMLPER_NAME} PRIVATE -O3)
target_link_libraries(${SAMLPER_NAME} PRIVATE ${TORCH_LIBRARIES})
target_compile_definitions(${SAMLPER_NAME} PRIVATE -DTORCH_EXTENSION_NAME=lib${SAMLPER_NAME})
......
import argparse
import os
import sys
from os.path import abspath, join, dirname
from starrygl.distributed.context import DistributedContext
from starrygl.distributed.utils import DistIndex
from starrygl.module.modules import GeneralModel
from starrygl.module.utils import parse_config
from starrygl.sample.graph_core import DataSet, GraphData, TemporalNeighborSampleGraph
from starrygl.sample.memory.shared_mailbox import SharedMailBox
from starrygl.sample.sample_core.base import NegativeSampling
from starrygl.sample.sample_core.neighbor_sampler import NeighborSampler
from starrygl.sample.part_utils.partition_tgnn import partition_load
import torch
import time
import torch.nn.functional as F
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import init_process_group, destroy_process_group
from starrygl.sample.data_loader import DistributedDataLoader
from starrygl.sample.batch_data import SAMPLE_TYPE
"""
test command
python test.py --world_size 2 --rank 0
--world_size', default=4, type=int, metavar='W',
help='number of workers')
parser.add_argument('--rank', default=0, type=int, metavar='W',
help='rank of the worker')
parser.add_argument('--log_interval', type=int, default=10, metavar='N',
help='interval between training status logs')
parser.add_argument('--gamma', type=float, default=0.99, metavar='G',
help='how much to value future rewards')
parser.add_argument('--seed', type=int, default=1, metavar='S',
help='random seed for reproducibility')
parser.add_argument('--num_sampler', type=int, default=10, metavar='S',
help='number of samplers')
parser.add_argument('--queue_size', type=int, default=10, metavar='S',
help='sampler queue size')
"""
parser = argparse.ArgumentParser(
description="RPC Reinforcement Learning Example",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument('--rank', default=0, type=int, metavar='W',
                    help='rank of the worker')
parser.add_argument('--world_size', default=1, type=int, metavar='W',
                    help='number of workers')
args = parser.parse_args()
from sklearn.metrics import average_precision_score, roc_auc_score
import random
import dgl
import numpy as np
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.rank)
os.environ["RANK"] = str(args.rank)
os.environ["WORLD_SIZE"] = str(args.world_size)
os.environ["LOCAL_RANK"] = str(0)
os.environ["MASTER_ADDR"] = '127.0.0.1'
os.environ["MASTER_PORT"] = '9337'
def seed_everything(seed=42):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
seed_everything(1234)
def main():
use_cuda = True
sample_param, memory_param, gnn_param, train_param = parse_config('./config/TGN.yml')
torch.set_num_threads(12)
ctx = DistributedContext.init(backend="nccl", use_gpu=True)
device_id = torch.cuda.current_device()
print('use cuda on',device_id)
pdata = partition_load("./dataset/here/WIKI", algo="metis_for_tgnn")
graph = GraphData(pdata = pdata)
#dist.barrier()
#for i in range(100):
# print(i)
dist.barrier()
idx = ((graph.eids_mapper >> 48).int() & 0xFFFF)
print((idx==0).nonzero().shape,(idx==1).nonzero().shape)
t1 = time.time()
"""
fut = []
for i in range(1000):
#print(i)
out = graph.edge_attr.index_select(graph.eids_mapper[(idx== 0)|(idx ==1)].to('cuda'))
fut.append(out)
#out.wait()
#out.value()
if i>0 and i%100==0:
f = torch.futures.collect_all(fut)
f.wait()
f.value()
fut = []
"""
partptr = torch.tensor([ ((i & 0xFFFF)<<48) for i in range(3) ],device = 'cuda')
for i in range(1000):
if i%100==0:
idx = graph.eids_mapper.to('cuda')
idx,inv = idx.unique(return_inverse=True)
ind = torch.searchsorted(idx,partptr,right=False)
len = ind[1:]-ind[:-1]
gatherlen = torch.empty([2],dtype = torch.long,device = 'cuda')
dist.all_to_all_single(gatherlen,len)
query_idx = torch.empty([gatherlen.sum()],dtype = torch.long,device = 'cuda')
input_s = list(len)
output_s = list(gatherlen)
dist.all_to_all_single(query_idx,idx,output_s,input_s)
input_f = graph.edge_attr.accessor.data[DistIndex(query_idx).loc]
f = torch.empty([idx.shape[0],graph.edge_attr.accessor.data.shape[1]],dtype=torch.float,device='cuda')
dist.all_to_all_single(f,input_f,input_s,output_s)
torch.cuda.synchronize()
t2 = time.time()-t1
print(t2)
#dist.barrier()
ctx.shutdown()
if __name__ == "__main__":
main()
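
For reference, the bit arithmetic used above (`eids_mapper >> 48`, `(i & 0xFFFF) << 48`) assumes a packed 64-bit distributed index with the partition id in the high 16 bits and the local offset in the low 48 bits. A minimal sketch of that packing, with field widths inferred from the shifts in this script rather than from the DistIndex implementation itself:

import torch

PART_BITS = 16   # assumed width of the partition-id field
LOC_BITS = 48    # assumed width of the local-offset field

def pack_dist_index(part: torch.Tensor, loc: torch.Tensor) -> torch.Tensor:
    # pack (partition id, local offset) into one int64 per element
    return ((part.long() & 0xFFFF) << LOC_BITS) | (loc.long() & ((1 << LOC_BITS) - 1))

def unpack_dist_index(idx: torch.Tensor):
    part = (idx >> LOC_BITS) & 0xFFFF            # same extraction as `idx >> 48` above
    loc = idx & ((1 << LOC_BITS) - 1)
    return part, loc
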
...@@ -4,7 +4,7 @@ from torch_geometric.utils import add_remaining_self_loops, to_undirected
import os.path as osp
import sys

-from starrygl.graph import GraphData
+from starrygl.data import GraphData

import logging
logging.getLogger().setLevel(logging.INFO)
......
...@@ -2,12 +2,8 @@
#include "uvm.h"
#include "partition.h"

-torch::Tensor add(torch::Tensor a, torch::Tensor b) {
-    return a + b;
-}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-    m.def("add", &add, "a function implemented using pybind11");
    m.def("uvm_storage_new", &uvm_storage_new, "return storage of unified virtual memory");
    m.def("uvm_storage_to_cuda", &uvm_storage_to_cuda, "share uvm storage with another cuda device");
    m.def("uvm_storage_to_cpu", &uvm_storage_to_cpu, "share uvm storage with cpu");
......
import torch
import torch.nn as nn
import torch.distributed as dist
import torch.distributed.rpc as rpc
import os
from typing import *
# from .degree import compute_in_degree, compute_out_degree, compute_gcn_norm
from .sync_bn import SyncBatchNorm
def convert_parallel_model(
net: nn.Module,
find_unused_parameters=False,
) -> nn.parallel.DistributedDataParallel:
net = SyncBatchNorm.convert_sync_batchnorm(net)
net = nn.parallel.DistributedDataParallel(net,
find_unused_parameters=find_unused_parameters,
)
return net
def init_process_group(backend: str = "gloo") -> torch.device:
rank = int(os.getenv("RANK") or os.getenv("OMPI_COMM_WORLD_RANK"))
world_size = int(os.getenv("WORLD_SIZE") or os.getenv("OMPI_COMM_WORLD_SIZE"))
dist.init_process_group(
backend=backend,
init_method=ccl_init_method(),
rank=rank, world_size=world_size,
)
rpc_backend_options = rpc.TensorPipeRpcBackendOptions()
rpc_backend_options.init_method = rpc_init_method()
for i in range(world_size):
rpc_backend_options.set_device_map(f"worker{i}", {rank: i})
rpc.init_rpc(
name=f"worker{rank}",
rank=rank, world_size=world_size,
rpc_backend_options=rpc_backend_options,
)
local_rank = os.getenv("LOCAL_RANK") or os.getenv("OMPI_COMM_WORLD_LOCAL_RANK")
if local_rank is not None:
local_rank = int(local_rank)
if backend == "nccl" or backend == "mpi":
device = torch.device(f"cuda:{local_rank or rank}")
torch.cuda.set_device(device)
else:
device = torch.device("cpu")
global _COMPUTE_DEVICE
_COMPUTE_DEVICE = device
return device
def rank_world_size() -> Tuple[int, int]:
return dist.get_rank(), dist.get_world_size()
def get_worker_info(rank: Optional[int] = None) -> rpc.WorkerInfo:
rank = dist.get_rank() if rank is None else rank
return rpc.get_worker_info(f"worker{rank}")
_COMPUTE_DEVICE = torch.device("cpu")
def get_compute_device() -> torch.device:
global _COMPUTE_DEVICE
return _COMPUTE_DEVICE
_TEMP_AG_REMOTE_OBJECT = None
def _remote_object():
global _TEMP_AG_REMOTE_OBJECT
return _TEMP_AG_REMOTE_OBJECT
def all_gather_remote_objects(obj: Any) -> List[rpc.RRef]:
global _TEMP_AG_REMOTE_OBJECT
_TEMP_AG_REMOTE_OBJECT = rpc.RRef(obj)
dist.barrier()
world_size = dist.get_world_size()
futs: List[torch.futures.Future] = []
for i in range(world_size):
info = get_worker_info(i)
futs.append(rpc.rpc_async(info, _remote_object))
rrefs: List[rpc.RRef] = []
for f in futs:
f.wait()
rrefs.append(f.value())
dist.barrier()
_TEMP_AG_REMOTE_OBJECT = None
return rrefs
def ccl_init_method() -> str:
master_addr = os.environ["MASTER_ADDR"]
master_port = int(os.environ["MASTER_PORT"])
return f"tcp://{master_addr}:{master_port}"
def rpc_init_method() -> str:
master_addr = os.environ["MASTER_ADDR"]
master_port = int(os.environ["MASTER_PORT"])
return f"tcp://{master_addr}:{master_port+1}"
\ No newline at end of file
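
A minimal driver for the helpers above might look like the sketch below. It is illustrative only: it assumes the definitions above (init_process_group, convert_parallel_model, all_gather_remote_objects, get_compute_device) are in scope, and that RANK, WORLD_SIZE, MASTER_ADDR and MASTER_PORT are already exported, as init_process_group requires.

import torch.nn as nn

if __name__ == "__main__":
    device = init_process_group(backend="gloo")        # sets up both the process group and RPC
    model = convert_parallel_model(nn.Linear(16, 4).to(device))
    rrefs = all_gather_remote_objects(model)           # one RRef to the wrapped model per worker
    print(get_compute_device(), len(rrefs))
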
ERROR:root:unable to import libstarrygl.so, some features may not be available.
the number of nodes in graph is 1980, the number of edges in graph is 1293103
directory '/home/zlj/starrygl/dataset/here/LASTFM/metis_for_tgnn_1' not empty and cleared
running partition algorithm: metis_for_tgnn
saving partition data: 1/1
running partition algorithm: metis_for_tgnn
saving partition data: 1/2
saving partition data: 2/2
creating directory '/home/zlj/starrygl/dataset/here/LASTFM/metis_for_tgnn_4'
running partition algorithm: metis_for_tgnn
saving partition data: 1/4
saving partition data: 2/4
saving partition data: 3/4
saving partition data: 4/4
...@@ -5,7 +5,7 @@ from torch import Tensor
from typing import *

from starrygl.distributed import DistributedContext
-from starrygl.graph import *
+from starrygl.data import *
from torch_scatter import scatter_sum
...@@ -76,4 +76,14 @@ if __name__ == "__main__":
    # dst_true = torch.ones(route.dst_len, dtype=torch.float, device=ctx.device)
    # ctx.sync_print(route.fw_tensor(dst_true, "max"))

+    spb = g.to_sparse()
+    x = torch.randn(g.node("dst").num_nodes, 64, device=ctx.device).requires_grad_()
+    ctx.sync_print(x[:,0])
+
+    y = spb.apply(x)
+    ctx.sync_print(y[:,0])
+
+    y.sum().backward()
+    ctx.sync_print(x.grad[:,0])
+
    ctx.shutdown()
#!/bin/sh
#conda activate gnn
cd ./starrygl/sample/sample_core
if [ -f "setup.py" ]; then
rm -r build
rm sample_cores.cpython-*.so
python setup.py build_ext --inplace
fi
cd ../part_utils
if [ -f "setup.py" ]; then
rm -r build
rm torch_utils.cpython-*.so
python setup.py build_ext --inplace
fi
cd ../../
from .graph import *
from .utils import *
\ No newline at end of file
...@@ -8,13 +8,15 @@ from pathlib import Path
from torch_sparse import SparseTensor
from starrygl.utils.partition import *
-from .route import Route
+from starrygl.parallel.route import Route
+from starrygl.parallel.sparse import SparseBlocks
+from .utils import init_vc_edge_index
import logging

__all__ = [
    "GraphData",
+    "init_vc_edge_index",
]
...@@ -85,6 +87,17 @@ class GraphData:
        dst_ids = self.node("dst")["raw_ids"]
        return Route.from_raw_indices(src_ids, dst_ids, group=group)

+    def to_sparse(self, key: Optional[str] = None, group: Any = None) -> SparseBlocks:
+        src_ids = self.node("src")["raw_ids"]
+        dst_ids = self.node("dst")["raw_ids"]
+        edge_index = self.edge_index()
+        edge_index = torch.vstack([
+            src_ids[edge_index[0]],
+            dst_ids[edge_index[1]],
+        ])
+        edge_attr = None if key is None else self.edge()[key]
+        return SparseBlocks.from_raw_indices(dst_ids, edge_index, edge_attr=edge_attr, group=group)

    @property
    def is_heterogeneous(self) -> bool:
        return self._heterogeneous
...@@ -194,15 +207,6 @@ class GraphData:
            raw_dst_ids, local_edges, bipartite=True,
        )
-        # g = GraphData(
-        #     edge_indices={
-        #         ("src", "@", "dst"): local_edges,
-        #     },
-        #     num_nodes={
-        #         "src": raw_src_ids.numel(),
-        #         "dst": raw_dst_ids.numel(),
-        #     }
-        # )
        g = GraphData.from_bipartite(
            local_edges,
            raw_src_ids=raw_src_ids,
...@@ -321,139 +325,3 @@ class EdgeData:
    def to(self, device: Any) -> 'EdgeData':
        self._data = {k:v.to(device) for k,v in self._data.items()}
        return self
def init_vc_edge_index(
dst_ids: Tensor,
edge_index: Tensor,
bipartite: bool = True,
) -> Tuple[Tensor, Tensor]:
ikw = dict(dtype=torch.long, device=dst_ids.device)
local_num_nodes = torch.zeros(1, **ikw)
if dst_ids.numel() > 0:
local_num_nodes = dst_ids.max().max(local_num_nodes)
if edge_index.numel() > 0:
local_num_nodes = edge_index.max().max(local_num_nodes)
local_num_nodes = local_num_nodes.item() + 1
xmp: Tensor = torch.zeros(local_num_nodes, **ikw)
xmp[edge_index[1].unique()] += 0b01
xmp[dst_ids.unique()] += 0b10
if not (xmp != 0x01).all():
raise RuntimeError(f"must be vertex-cut partition graph")
if bipartite:
src_ids = edge_index[0].unique()
else:
xmp.fill_(0)
xmp[edge_index[0]] = 1
xmp[dst_ids] = 0
src_ids = torch.cat([dst_ids, torch.where(xmp > 0)[0]], dim=-1)
xmp.fill_((2**62-1)*2+1)
xmp[src_ids] = torch.arange(src_ids.size(0), **ikw)
src = xmp[edge_index[0]]
xmp.fill_((2**62-1)*2+1)
xmp[dst_ids] = torch.arange(dst_ids.size(0), **ikw)
dst = xmp[edge_index[1]]
local_edge_index = torch.vstack([src, dst])
return src_ids, local_edge_index
# def partition_load(root: str, part_id: int, num_parts: int, algo: str = "metis") -> GraphData:
# p = Path(root).expanduser().resolve() / f"{algo}_{num_parts}" / f"{part_id:03d}"
# return torch.load(p.__str__())
# def partition_pyg(root: str, data, num_parts: int, algo: str = "metis"):
# root_path = Path(root).expanduser().resolve()
# base_path = root_path / f"{algo}_{num_parts}"
# if base_path.exists():
# shutil.rmtree(base_path.__str__())
# base_path.mkdir(parents=True, exist_ok=True)
# for i, g in enumerate(partition_pyg_data(data, num_parts, algo)):
# logging.info(f"saving partition data: {i+1}/{num_parts}")
# torch.save(g, (base_path / f"{i:03d}").__str__())
# def partition_pyg_data(data, num_parts: int, algo: str = "metis") -> Iterator[GraphData]:
# from torch_geometric.data import Data
# assert isinstance(data, Data), f"must be Data class in pyg"
# logging.info(f"running partition aglorithm: {algo}")
# num_nodes: int = data.num_nodes
# num_edges: int = data.num_edges
# edge_index: Tensor = data.edge_index
# if algo == "metis":
# node_parts = metis_partition(edge_index, num_nodes, num_parts)
# elif algo == "mt-metis":
# node_parts = mt_metis_partition(edge_index, num_nodes, num_parts)
# elif algo == "random":
# node_parts = random_partition(edge_index, num_nodes, num_parts)
# else:
# raise ValueError(f"unknown partition algorithm: {algo}")
# if data.y.dtype == torch.long:
# if data.y.dim() == 1:
# num_classes = data.y.max().item() + 1
# else:
# num_classes = data.y.size(1)
# else:
# num_classes = None
# for i in range(num_parts):
# npart_mask = node_parts == i
# epart_mask = npart_mask[edge_index[1]]
# local_edges = edge_index[:, epart_mask]
# raw_src_ids: Tensor = local_edges[0].unique()
# raw_dst_ids: Tensor = torch.where(npart_mask)[0]
# M: int = raw_src_ids.max().item() + 1
# imap = torch.full((M,), (2**62-1)*2+1).type_as(raw_src_ids)
# imap[raw_src_ids] = torch.arange(raw_src_ids.numel()).type_as(raw_src_ids)
# local_src = imap[local_edges[0]]
# M: int = raw_dst_ids.max().item() + 1
# imap = torch.full((M,), (2**62-1)*2+1).type_as(raw_dst_ids)
# imap[raw_dst_ids] = torch.arange(raw_dst_ids.numel()).type_as(raw_dst_ids)
# local_dst = imap[local_edges[1]]
# local_edges = torch.vstack([local_src, local_dst])
# g = GraphData(
# edge_indices={
# ("src", "@", "dst"): local_edges,
# },
# num_nodes={
# "src": raw_src_ids.numel(),
# "dst": raw_dst_ids.numel(),
# },
# )
# g.node("src")["raw_ids"] = raw_src_ids
# g.node("dst")["raw_ids"] = raw_dst_ids
# if num_classes is not None:
# g.meta()["num_classes"] = num_classes
# for key, val in data:
# if key == "edge_index":
# continue
# elif isinstance(val, Tensor):
# if val.size(0) == num_nodes:
# g.node("dst")[key] = val[npart_mask]
# elif val.size(0) == num_edges:
# g.edge()[key] = val[epart_mask]
# elif isinstance(val, SparseTensor):
# pass
# else:
# g.meta()[key] = val
# yield g
import torch
from torch import Tensor
from typing import *
__all__ = [
"init_vc_edge_index",
]
def init_vc_edge_index(
dst_ids: Tensor,
edge_index: Tensor,
bipartite: bool = True,
) -> Tuple[Tensor, Tensor]:
ikw = dict(dtype=torch.long, device=dst_ids.device)
local_num_nodes = torch.zeros(1, **ikw)
if dst_ids.numel() > 0:
local_num_nodes = dst_ids.max().max(local_num_nodes)
if edge_index.numel() > 0:
local_num_nodes = edge_index.max().max(local_num_nodes)
local_num_nodes = local_num_nodes.item() + 1
xmp: Tensor = torch.zeros(local_num_nodes, **ikw)
xmp[edge_index[1].unique()] += 0b01
xmp[dst_ids.unique()] += 0b10
if not (xmp != 0x01).all():
raise RuntimeError(f"must be vertex-cut partition graph")
if bipartite:
src_ids = edge_index[0].unique()
else:
xmp.fill_(0)
xmp[edge_index[0]] = 1
xmp[dst_ids] = 0
src_ids = torch.cat([dst_ids, torch.where(xmp > 0)[0]], dim=-1)
xmp.fill_((2**62-1)*2+1)
xmp[src_ids] = torch.arange(src_ids.size(0), **ikw)
src = xmp[edge_index[0]]
xmp.fill_((2**62-1)*2+1)
xmp[dst_ids] = torch.arange(dst_ids.size(0), **ikw)
dst = xmp[edge_index[1]]
local_edge_index = torch.vstack([src, dst])
return src_ids, local_edge_index
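
A toy usage sketch for init_vc_edge_index (illustrative tensors only; assumes this partition owns destination nodes 0 and 1):

import torch

dst_ids = torch.tensor([0, 1])                 # destination nodes owned locally
edge_index = torch.tensor([[3, 4, 0],          # global source ids
                           [0, 1, 1]])         # global destination ids
src_ids, local_edge_index = init_vc_edge_index(dst_ids, edge_index, bipartite=True)
# src_ids -> tensor([0, 3, 4]); local_edge_index re-indexes each edge into the
# [0, len) ranges of src_ids and dst_ids, i.e. tensor([[1, 2, 0], [0, 1, 1]])
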
...@@ -8,33 +8,45 @@ __all__ = [
    "all_to_all_v",
    "all_to_all_s",
    "BatchWork",
+    "batch_send",
+    "batch_recv",
]

class BatchWork:
-    def __init__(self, works, buffer_tensor_list) -> None:
-        if isinstance(works, (list, tuple)):
-            assert len(works) // 2 == len(buffer_tensor_list)
-            self._works = works
-            self._buffer_tensor_list = buffer_tensor_list
+    def __init__(self,
+        works: Optional[List[Any]],
+        buffer_tensor_list: Optional[List[Tuple[Tensor, Optional[Tensor]]]],
+        step: int = 1,
+    ) -> None:
+        if works is None:
+            self._step = None
+            self._works = None
+            self._buffer_tensor_list = None
        else:
-            assert self._buffer_tensor_list is None
+            if buffer_tensor_list:
+                assert len(works) // step == len(buffer_tensor_list)
+            self._step = step
            self._works = works
-            self._buffer_tensor_list = None
+            self._buffer_tensor_list = buffer_tensor_list

    def wait(self):
        if self._works is None:
            return
-        if isinstance(self._works, (list, tuple)):
-            for i, w in enumerate(self._works):
-                w.wait()
-                if i % 2 == 0:
-                    continue
-                out, buf = self._buffer_tensor_list[i // 2]
-                if buf is not None:
-                    out.copy_(buf)
-        else:
-            self._works.wait()
+        for i, w in enumerate(self._works):
+            if w is not None:
+                w.wait()
+            if (i + 1) % self._step != 0:
+                continue
+            if self._buffer_tensor_list:
+                out, buf = self._buffer_tensor_list[i // self._step]
+                if buf is not None:
+                    out.copy_(buf)
+        self._step = None
        self._works = None
        self._buffer_tensor_list = None
...@@ -60,7 +72,7 @@ def all_to_all_v(
            group=group,
            async_op=async_op,
        )
-        return BatchWork(work, None) if async_op else None
+        return BatchWork([work], None) if async_op else None
    elif backend == "mpi":
        work = dist.all_to_all(
...@@ -69,7 +81,7 @@ def all_to_all_v(
            group=group,
            async_op=async_op,
        )
-        return BatchWork(work, None) if async_op else None
+        return BatchWork([work], None) if async_op else None
    else:
        assert backend == "gloo", f"backend must be nccl, mpi or gloo"
...@@ -98,7 +110,7 @@ def all_to_all_v(
                dist.irecv(recv_b, recv_i, group=group),
            ])
-        work = BatchWork(p2p_op_works, buffer_tensor_list)
+        work = BatchWork(p2p_op_works, buffer_tensor_list, 2)
        output_tensor_list[rank].copy_(input_tensor_list[rank])
        if async_op:
...@@ -130,3 +142,61 @@ def all_to_all_s(
        group=group,
        async_op=async_op,
    )
def batch_send(
*tensors: Tensor,
dst: int,
group: Any = None,
async_op: bool = False,
):
if len(tensors) == 0:
return BatchWork(None, None)
# tensors = tuple(t.data for t in tensors)
backend = dist.get_backend(group)
if async_op:
works = []
for t in tensors:
if backend == "gloo" and t.is_cuda:
t = t.cpu()
works.append(dist.isend(t, dst=dst, group=group))
return BatchWork(works, None)
else:
for t in tensors:
if backend == "gloo" and t.is_cuda:
t = t.cpu()
dist.send(t, dst=dst, group=group)
def batch_recv(
*tensors: Tensor,
src: int,
group: Any = None,
async_op: bool = False,
):
if len(tensors) == 0:
return BatchWork(None, None)
# tensors = tuple(t.data for t in tensors)
backend = dist.get_backend(group)
if async_op:
works = []
output_tensor_list = []
for t in tensors:
if backend == "gloo" and t.is_cuda:
b = torch.empty_like(t, device="cpu")
works.append(dist.irecv(b, src=src, group=group))
else:
b = None
works.append(dist.irecv(t, src=src, group=group))
output_tensor_list.append((t, b))
return BatchWork(works, output_tensor_list, 1)
else:
for t in tensors:
if backend == "gloo" and t.is_cuda:
b = torch.empty_like(t, device="cpu")
dist.recv(b, src=src, group=group)
t.copy_(b)
else:
dist.recv(t, src=src, group=group)
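
A hedged sketch of how batch_send and batch_recv above pair up across two ranks (assumes the default process group is already initialized and the functions above are in scope):

import torch

def exchange(rank: int):
    a = torch.arange(4, dtype=torch.float32)
    b = torch.zeros(8)
    if rank == 0:
        work = batch_send(a, b, dst=1, async_op=True)
    else:
        a_buf, b_buf = torch.empty(4), torch.empty(8)
        work = batch_recv(a_buf, b_buf, src=0, async_op=True)
    # BatchWork.wait() also copies gloo CPU staging buffers back into CUDA outputs when needed
    work.wait()
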
...@@ -11,7 +11,7 @@ from contextlib import contextmanager
import logging

-from .rpc import rpc_remote_call, rpc_remote_void_call
+from .rpc import *
...@@ -158,6 +158,9 @@ class DistributedContext:
    def remote_void_call(self, method, rref: rpc.RRef, *args, **kwargs):
        return rpc_remote_void_call(method, rref, *args, **kwargs)

+    def remote_exec(self, method, rref: rpc.RRef, *args, **kwargs):
+        return rpc_remote_exec(method, rref, *args, **kwargs)

    @contextmanager
    def use_stream(self, stream: torch.cuda.Stream, with_event: bool = True):
        event = torch.cuda.Event() if with_event else None
......
...@@ -7,6 +7,7 @@ from typing import *
__all__ = [
    "rpc_remote_call",
    "rpc_remote_void_call",
+    "rpc_remote_exec",
]

def rpc_remote_call(method, rref: rpc.RRef, *args, **kwargs):
...@@ -24,3 +25,12 @@ def rpc_remote_void_call(method, rref: rpc.RRef, *args, **kwargs):
def rpc_method_void_call(method, rref: rpc.RRef, *args, **kwargs):
    self = rref.local_value()
    method(self, *args, **kwargs) # return None
def rpc_remote_exec(method, rref: rpc.RRef, *args, **kwargs):
args = (method, rref) + args
return rpc.rpc_async(rref.owner(), rpc_method_exec, args=args, kwargs=kwargs)
@rpc.functions.async_execution
def rpc_method_exec(method, rref: rpc.RRef, *args, **kwargs):
self = rref.local_value()
return method(self, *args, **kwargs)
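
The @rpc.functions.async_execution path above expects the target method to return a torch.futures.Future. A minimal caller sketch with a hypothetical Counter class (names invented for illustration; the RPC layer is assumed to be initialized already):

import torch
import torch.distributed.rpc as rpc

class Counter:
    def __init__(self):
        self.value = 0

    def add(self, x):
        fut = torch.futures.Future()
        self.value += x
        fut.set_result(self.value)     # resolve immediately; could also be deferred
        return fut

# owner side:  counter_rref = rpc.RRef(Counter())
# caller side: rpc_remote_exec(Counter.add, counter_rref, 5).wait()  ->  5
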
...@@ -17,7 +17,6 @@ class TensorAccessor:
        self._data = data
        self._ctx = DistributedContext.get_default_context()
        self._rref = rpc.RRef(data)
-        self._rref.confirmed_by_owner

    @property
    def data(self):
...@@ -31,52 +30,61 @@ def ctx(self):
    def ctx(self):
        return self._ctx
# def all_gather_index(self,index,input_split) -> Tensor:
# out_split = torch.empty_like(input_split)
# torch.distributed.all_to_all_single(out_split,input_split)
# input_split = list(input_split)
# output = torch.empty([out_split.sum()],dtype = index.dtype,device = index.device)
# out_split = list(out_split)
# torch.distributed.all_to_all_single(output,index,out_split,input_split)
# return output,out_split,input_split
# def all_gather_data(self,index,input_split,out_split):
# output = torch.empty([int(Tensor(out_split).sum().item()),*self._data.shape[1:]],dtype = self._data.dtype,device = 'cuda')#self._data.device)
# torch.distributed.all_to_all_single(output,self.data[index.to(self.data.device)].to('cuda'),output_split_sizes = out_split,input_split_sizes = input_split)
# return output
    def all_gather_rrefs(self) -> List[rpc.RRef]:
        return self.ctx.all_gather_remote_objects(self.rref)

    def async_index_select(self, dim: int, index: Tensor, rref: Optional[rpc.RRef] = None):
        if rref is None:
            rref = self.rref
-        return self.ctx.remote_call(Tensor.index_select, rref, dim=dim, index=index)
+        return self.ctx.remote_exec(TensorAccessor._index_select, rref, dim=dim, index=index)

    def async_index_copy_(self, dim: int, index: Tensor, source: Tensor, rref: Optional[rpc.RRef] = None):
        if rref is None:
            rref = self.rref
-        return self.ctx.remote_void_call(Tensor.index_copy_, rref, dim=dim, index=index, source=source)
+        return self.ctx.remote_exec(TensorAccessor._index_copy_, rref, dim=dim, index=index, source=source)

    def async_index_add_(self, dim: int, index: Tensor, source: Tensor, rref: Optional[rpc.RRef] = None):
        if rref is None:
            rref = self.rref
-        return self.ctx.remote_void_call(Tensor.index_add_, rref, dim=dim, index=index, source=source)
+        return self.ctx.remote_exec(TensorAccessor._index_add_, rref, dim=dim, index=index, source=source)

-    def async_index_fill_(self, dim: int, index: Tensor, value: Number, rref: Optional[rpc.RRef] = None):
-        if rref is None:
-            rref = self.rref
-        return self.ctx.remote_void_call(Tensor.index_fill_, rref, dim=dim, index=index, value=value)
-
-    def async_fill_(self, value: Number, rref: Optional[rpc.RRef] = None):
-        if rref is None:
-            rref = self.rref
-        return self.ctx.remote_void_call(Tensor.fill_, rref, value=value)
-
-    def async_zero_(self, rref: Optional[rpc.RRef] = None):
-        if rref is None:
-            rref = self.rref
-        self.ctx.remote_void_call(Tensor.zero_, rref)

+    @staticmethod
+    def _index_select(data: Tensor, dim: int, index: Tensor):
+        stream = TensorAccessor.get_stream()
+        with torch.cuda.stream(stream):
+            data = data.index_select(dim, index)
+        fut = torch.futures.Future()
+        fut.set_result(data)
+        return fut
+
+    @staticmethod
+    def _index_copy_(data: Tensor, dim: int, index: Tensor, source: Tensor):
+        stream = TensorAccessor.get_stream()
+        with torch.cuda.stream(stream):
+            data.index_copy_(dim, index, source)
+        fut = torch.futures.Future()
+        fut.set_result(None)
+        return fut
+
+    @staticmethod
+    def _index_add_(data: Tensor, dim: int, index: Tensor, source: Tensor):
+        stream = TensorAccessor.get_stream()
+        with torch.cuda.stream(stream):
+            data.index_add_(dim, index, source)
+        fut = torch.futures.Future()
+        fut.set_result(None)
+        return fut
+
+    @staticmethod
+    def get_stream() -> Optional[torch.cuda.Stream]:
+        global _TENSOR_ACCESSOR_STREAM
+        if not torch.cuda.is_available():
+            # without CUDA there is no side stream; callers then run on the default stream
+            return None
+        if _TENSOR_ACCESSOR_STREAM is None:
+            _TENSOR_ACCESSOR_STREAM = torch.cuda.Stream()
+        return _TENSOR_ACCESSOR_STREAM

+_TENSOR_ACCESSOR_STREAM: Optional[torch.cuda.Stream] = None
class DistInt:
...@@ -132,8 +140,11 @@ class DistributedTensor:
        for rref in self.rrefs:
            n = self.ctx.remote_call(Tensor.size, rref, dim=0).wait()
            local_sizes.append(n)
-        self._num_nodes = DistInt(local_sizes)
-        self._num_parts = DistInt([1] * len(self.rrefs))
+        self._num_nodes: int = sum(local_sizes)
+        self._num_part_nodes: Tuple[int,...] = tuple(int(s) for s in local_sizes)
+        self._part_id: int = self.accessor.ctx.rank
+        self._num_parts: int = self.accessor.ctx.world_size

    @property
    def dtype(self):
...@@ -144,11 +155,19 @@
        return self.accessor.data.device

    @property
-    def num_nodes(self) -> DistInt:
+    def num_nodes(self) -> int:
        return self._num_nodes

    @property
-    def num_parts(self) -> DistInt:
+    def num_part_nodes(self) -> tuple[int,...]:
+        return self._num_part_nodes
+
+    @property
+    def part_id(self) -> int:
+        return self._part_id
+
+    @property
+    def num_parts(self) -> int:
        return self._num_parts

    def to(self,device):
...@@ -227,11 +246,7 @@ class DistributedTensor:
        index = dist_index.loc
        futs: List[torch.futures.Future] = []
-        for i in range(self.num_parts()):
-            #if i != torch.distributed.get_rank():
-            #    continue
-            #f = torch.futures.Future()
-            #f.set_result(self.accessor.data[index[part_idx == i]])
+        for i in range(self.num_parts):
            f = self.accessor.async_index_select(0, index[part_idx == i], self.rrefs[i])
            futs.append(f)
...@@ -276,16 +291,3 @@ class DistributedTensor:
            futs.append(f)
        return torch.futures.collect_all(futs)
\ No newline at end of file
def index_fill_(self, dist_index: Union[Tensor, DistIndex], value: Number):
if isinstance(dist_index, Tensor):
dist_index = DistIndex(dist_index)
part_idx = dist_index.part
index = dist_index.loc
futs: List[torch.futures.Future] = []
for i in range(self.num_parts()):
mask = part_idx == i
f = self.accessor.async_index_fill_(0, index[mask], value, self.rrefs[i])
futs.append(f)
return torch.futures.collect_all(futs)
from .data import *
from .route import *
\ No newline at end of file
from .route import *
from .sequence import *
from .sparse import *
\ No newline at end of file
import torch
import torch.nn as nn
import torch.distributed as dist
import torch.distributed.rpc as rpc
import os
from typing import *
# from .degree import compute_in_degree, compute_out_degree, compute_gcn_norm
from .sync_bn import SyncBatchNorm
def convert_parallel_model(
net: nn.Module,
find_unused_parameters=False,
) -> nn.parallel.DistributedDataParallel:
net = SyncBatchNorm.convert_sync_batchnorm(net)
net = nn.parallel.DistributedDataParallel(net,
find_unused_parameters=find_unused_parameters,
)
return net
def init_process_group(backend: str = "gloo") -> torch.device:
rank = int(os.getenv("RANK") or os.getenv("OMPI_COMM_WORLD_RANK"))
world_size = int(os.getenv("WORLD_SIZE") or os.getenv("OMPI_COMM_WORLD_SIZE"))
dist.init_process_group(
backend=backend,
init_method=ccl_init_method(),
rank=rank, world_size=world_size,
)
rpc_backend_options = rpc.TensorPipeRpcBackendOptions()
rpc_backend_options.init_method = rpc_init_method()
for i in range(world_size):
rpc_backend_options.set_device_map(f"worker{i}", {rank: i})
rpc.init_rpc(
name=f"worker{rank}",
rank=rank, world_size=world_size,
rpc_backend_options=rpc_backend_options,
)
local_rank = os.getenv("LOCAL_RANK") or os.getenv("OMPI_COMM_WORLD_LOCAL_RANK")
if local_rank is not None:
local_rank = int(local_rank)
if backend == "nccl" or backend == "mpi":
device = torch.device(f"cuda:{local_rank or rank}")
torch.cuda.set_device(device)
else:
device = torch.device("cpu")
global _COMPUTE_DEVICE
_COMPUTE_DEVICE = device
return device
def rank_world_size() -> Tuple[int, int]:
return dist.get_rank(), dist.get_world_size()
def get_worker_info(rank: Optional[int] = None) -> rpc.WorkerInfo:
rank = dist.get_rank() if rank is None else rank
return rpc.get_worker_info(f"worker{rank}")
_COMPUTE_DEVICE = torch.device("cpu")
def get_compute_device() -> torch.device:
global _COMPUTE_DEVICE
return _COMPUTE_DEVICE
_TEMP_AG_REMOTE_OBJECT = None
def _remote_object():
global _TEMP_AG_REMOTE_OBJECT
return _TEMP_AG_REMOTE_OBJECT
def all_gather_remote_objects(obj: Any) -> List[rpc.RRef]:
global _TEMP_AG_REMOTE_OBJECT
_TEMP_AG_REMOTE_OBJECT = rpc.RRef(obj)
dist.barrier()
world_size = dist.get_world_size()
futs: List[torch.futures.Future] = []
for i in range(world_size):
info = get_worker_info(i)
futs.append(rpc.rpc_async(info, _remote_object))
rrefs: List[rpc.RRef] = []
for f in futs:
f.wait()
rrefs.append(f.value())
dist.barrier()
_TEMP_AG_REMOTE_OBJECT = None
return rrefs
def ccl_init_method() -> str:
master_addr = os.environ["MASTER_ADDR"]
master_port = int(os.environ["MASTER_PORT"])
return f"tcp://{master_addr}:{master_port}"
def rpc_init_method() -> str:
master_addr = os.environ["MASTER_ADDR"]
master_port = int(os.environ["MASTER_PORT"])
return f"tcp://{master_addr}:{master_port+1}"
\ No newline at end of file
from typing import Any
import torch
import torch.distributed as dist
import torch.autograd as autograd
from torch import Tensor
from typing import *
from torch_sparse import SparseTensor
__all__ = [
"SparseBlocks",
]
class SparseBlocks:
@staticmethod
def from_raw_indices(
dst_ids: Tensor,
edge_index: Tensor,
src_ids: Optional[Tensor] = None,
edge_attr: Optional[Tensor] = None,
group: Any = None,
) -> 'SparseBlocks':
assert edge_index.dim() == 2 and edge_index.size(0) == 2
if src_ids is None:
src_ids = dst_ids
src_ids, src_ptr = SparseBlocks.__fetch_ids_sizes(src_ids, group=group)
adj_ts = SparseBlocks.__remap_adj_t(dst_ids, edge_index, src_ids, src_ptr, edge_attr)
return SparseBlocks(adj_ts, group=group)
def __init__(self, adj_ts: List[SparseTensor], group: Any) -> None:
self._adj_ts = adj_ts
self._group = group
def adj_t(self, i: int) -> SparseTensor:
return self._adj_ts[i]
@property
def group(self):
return self._group
@property
def part_id(self) -> int:
return dist.get_rank(self._group)
@property
def num_parts(self) -> int:
return dist.get_world_size(self._group)
def apply(self, x: Tensor) -> Tensor:
return SparseBlockMM.apply(self, x)
@staticmethod
def __fetch_ids_sizes(local_ids: Tensor, group: Any):
assert local_ids.dim() == 1
rank = dist.get_rank(group)
world_size = dist.get_world_size(group)
ikw = dict(dtype=torch.long, device=local_ids.device)
all_lens: List[int] = [None] * world_size
dist.all_gather_object(all_lens, local_ids.numel(), group=group)
# all reduce num_nodes
num_nodes = torch.zeros(1, **ikw)
if local_ids.numel() > 0:
num_nodes = local_ids.max().max(num_nodes)
dist.all_reduce(num_nodes, op=dist.ReduceOp.MAX, group=group)
num_nodes: int = num_nodes.item() + 1
# async fetch remote ids
all_ids: List[Tensor] = [None] * world_size
all_get = [None] * world_size
def async_fetch(i: int):
if i == rank:
all_ids[i] = local_ids
else:
all_ids[i] = torch.empty(all_lens[i], **ikw)
all_get[i] = dist.broadcast(
all_ids[i], src=i, async_op=True, group=group
)
imp: Tensor = torch.full((num_nodes,), (2**62-1)*2+1, **ikw)
offset: int = 0
for i in range(world_size):
if i == 0:
async_fetch(i)
if i + 1 < world_size:
async_fetch(i + 1)
all_get[i].wait()
ids = all_ids[i]
            assert (imp[ids] == (2**62-1)*2+1).all(), "ids must not overlap across partitions."
imp[ids] = torch.arange(offset, offset + all_lens[i], **ikw)
offset += all_lens[i]
        assert (imp != (2**62-1)*2+1).all(), "some nodes are not covered by any partition."
ids = torch.cat(all_ids, dim=0)
ptr: List[int] = [0]
for s in all_lens:
ptr.append(ptr[-1] + s)
return ids, ptr
@staticmethod
def __remap_adj_t(
dst_ids: Tensor,
edge_index: Tensor,
src_ids: Tensor,
src_ptr: List[int],
edge_attr: Optional[Tensor],
) -> List[SparseTensor]:
ikw = dict(dtype=torch.long, device=dst_ids.device)
imp: Tensor = torch.full((dst_ids.max().item()+1,), (2**62-1)*2+1, **ikw)
imp[dst_ids] = torch.arange(dst_ids.numel(), **ikw)
dst = imp[edge_index[1]]
assert (dst != (2**62-1)*2+1).all()
imp: Tensor = torch.full((src_ids.max().item()+1,), (2**62-1)*2+1, **ikw)
imp[src_ids] = torch.arange(src_ids.numel(), **ikw)
src = imp[edge_index[0]]
assert (src != (2**62-1)*2+1).all()
edge_index = torch.vstack([src, dst])
adj = SparseTensor.from_edge_index(
edge_index=edge_index,
edge_attr=edge_attr,
sparse_sizes=(src_ids.numel(), dst_ids.numel()),
)
adj_ts: List[SparseTensor] = []
for s, t in zip(src_ptr, src_ptr[1:]):
adj_ts.append(adj[s:t].t())
return adj_ts
class SparseBlockMM(autograd.Function):
@staticmethod
def forward(
ctx: autograd.function.FunctionCtx,
sp: SparseBlocks,
x: Tensor,
):
part_id = sp.part_id
num_parts = sp.num_parts
def async_fetch(i: int):
n = sp.adj_t(i).sparse_size(1)
if i == part_id:
h = x.clone()
else:
h = torch.empty(n, *x.shape[1:], dtype=x.dtype, device=x.device)
return dist.broadcast(h, src=i, group=sp.group, async_op=True)
last_work = None
out = None
for i in range(num_parts):
if i == 0:
work = async_fetch(0)
else:
work = last_work
if i + 1 < sp.num_parts:
last_work = async_fetch(i + 1)
work.wait()
h, = work.result()
if out is None:
out = sp.adj_t(i) @ h
else:
out += sp.adj_t(i) @ h
ctx.saved_sp = sp
return out
@staticmethod
def backward(
ctx: autograd.function.FunctionCtx,
grad: Tensor,
):
sp: SparseBlocks = ctx.saved_sp
part_id = sp.part_id
num_parts = sp.num_parts
def async_reduce(i: int, g: Tensor):
return dist.reduce(
g, dst=i, op=dist.ReduceOp.SUM,
group=sp.group, async_op=True,
)
out = None
last_work = None
for i in range(num_parts):
g = sp.adj_t(i).t() @ grad
if i > 0:
last_work.wait()
last_work = async_reduce(i, g)
if i == part_id:
out = g
if last_work is not None:
last_work.wait()
return None, out
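
In dense terms, the forward pass above computes, for partition p, out_p = sum_i adj_t(i) @ x_i, overlapping each broadcast of x_i with the local product. A single-process sketch of the same accumulation, with ordinary tensors standing in for the sparse blocks:

import torch

def blockwise_matmul(adj_blocks, x_parts):
    # adj_blocks[i] plays the role of sp.adj_t(i); x_parts[i] is the feature slice owned by partition i
    out = None
    for A_i, x_i in zip(adj_blocks, x_parts):
        out = A_i @ x_i if out is None else out + A_i @ x_i
    return out
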
...@@ -8,6 +8,9 @@ from torch import Tensor
from typing import *

+__all__ = [
+    "SyncBatchNorm",
+]

class SyncBatchNorm(nn.Module):
    def __init__(self,
......
import torch
import torch.nn as nn
import torch.distributed as dist
from torch import Tensor
from typing import *
from collections import defaultdict
__all__ = [
"all_reduce_gradients",
"all_reduce_buffers",
]
# def all_reduce_gradients(net: nn.Module, op = dist.ReduceOp.SUM, group = None, async_op: bool = False):
# works = []
# for p in net.parameters():
# if p.grad is None:
# p.grad = torch.zeros_like(p.data)
# w = dist.all_reduce(p.grad, op=op, group=group, async_op=async_op)
# works.append(w)
# if async_op:
# return works
# def all_reduce_buffers(net: nn.Module, op = dist.ReduceOp.AVG, group = None, async_op: bool = False):
# works = []
# for b in net.buffers():
# w = dist.all_reduce(b.data, op=op, group=group, async_op=async_op)
# works.append(w)
# if async_op:
# return works
def all_reduce_gradients(net: nn.Module, op = dist.ReduceOp.SUM, group = None, async_op: bool = False):
device = None
works = []
if op is None:
return works
typed_numel = defaultdict(lambda: 0)
for p in net.parameters():
typed_numel[p.dtype] += p.numel()
device = p.device
if device is None:
return works
typed_tensors: Dict[torch.dtype, Tensor] = {}
for t, n in typed_numel.items():
typed_tensors[t] = torch.zeros(n, dtype=t, device=device)
typed_offset = defaultdict(lambda: 0)
for p in net.parameters():
s = typed_offset[p.dtype]
t = s + p.numel()
typed_offset[p.dtype] = t
if p.grad is not None:
typed_tensors[p.dtype][s:t] = p.grad.flatten()
storage = typed_tensors[p.dtype].untyped_storage()
g = torch.empty(0, dtype=p.dtype, device=device)
p.grad = g.set_(storage, s, p.size(), default_stride(*p.size()))
for t in typed_tensors.values():
w = dist.all_reduce(t, op=op, group=group, async_op=async_op)
if async_op:
works.append(w)
return works
def all_reduce_buffers(net: nn.Module, op = dist.ReduceOp.AVG, group = None, async_op: bool = False):
device = None
works = []
if op is None:
return works
typed_numel = defaultdict(lambda: 0)
for p in net.buffers():
typed_numel[p.dtype] += p.numel()
device = p.device
if device is None:
return works
typed_tensors: Dict[torch.dtype, Tensor] = {}
for t, n in typed_numel.items():
        typed_tensors[t] = torch.zeros(n, dtype=t, device=device)  # one flat buffer per dtype
typed_offset = defaultdict(lambda: 0)
for p in net.buffers():
s = typed_offset[p.dtype]
t = s + p.numel()
typed_offset[p.dtype] = t
typed_tensors[p.dtype][s:t] = p.flatten()
storage = typed_tensors[p.dtype].untyped_storage()
p.set_(storage, s, p.size(), default_stride(*p.size()))
for t in typed_tensors.values():
w = dist.all_reduce(t, op=op, group=group, async_op=async_op)
if async_op:
works.append(w)
return works
def default_stride(*size: int) -> Tuple[int,...]:
dims = len(size)
stride = [1] * dims
for i in range(1, dims):
k = dims - i
stride[k - 1] = stride[k] * size[k]
return tuple(stride)
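
default_stride reproduces the row-major strides of a contiguous tensor, which is what the set_ calls above rely on when viewing a slice of the flat buffer as a parameter-shaped gradient. For example:

assert default_stride(2, 3, 4) == (12, 4, 1)   # matches torch.empty(2, 3, 4).stride()
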
from .partition import *
from .uvm import *
\ No newline at end of file
# from .functional import train_epoch, eval_epoch
# from .partition import partition_load, partition_save, partition_data
from .printer import sync_print, main_print
from .metrics import all_reduce_loss, accuracy
\ No newline at end of file