Commit 66045271 by Wenjie Huang

add DistTensor.all_to_all_[get|set]()

parent 510a41f8
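
Below is a hypothetical usage sketch of the API this commit adds. The setup (process group / RPC initialization, the local shard, the index tensors) and all variable names are assumptions for illustration, not part of the diff; the import line is omitted because the package layout is not shown here.

import torch

# Assumes torch.distributed and its RPC framework are already initialized,
# and that DistributedTensor / DistIndex are importable from this package.
local_feat = torch.randn(1000, 64, device="cuda")   # this rank's shard of the global table
feat = DistributedTensor(local_feat)                 # wraps the shard and gathers remote RRefs

local_ids = torch.tensor([3, 7], device="cuda")      # row indices within each owning partition
part_ids = torch.tensor([0, 1], device="cuda")       # which partition owns each row
dist_idx = DistIndex(local_ids, part_ids)            # pack (partition id, local row) into one int64

rows = feat.all_to_all_get(dist_idx)                 # pull the referenced rows from their owners
feat.all_to_all_set(rows, dist_idx)                  # push (possibly updated) rows back to the owners
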
......@@ -4,6 +4,12 @@ import torch.distributed as dist
from torch import Tensor
from typing import *
__all__ = [
"all_to_all_v",
"all_to_all_s",
"BatchWork",
]
class BatchWork:
def __init__(self, works, buffer_tensor_list) -> None:
......
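
For reference, a rough sketch of the segmented all-to-all that all_to_all_s appears to implement, judging from its call sites later in this commit: rows [rowptr[i], rowptr[i+1]) of the flat input tensor go to rank i, and the matching slice of the output is received from rank i. The real cclib implementation may chunk, buffer, or run asynchronously via BatchWork; this sketch only mirrors the semantics with torch.distributed.all_to_all_single.

import torch
import torch.distributed as dist
from typing import List

def all_to_all_s_sketch(output: torch.Tensor, input: torch.Tensor,
                        output_rowptr: List[int], input_rowptr: List[int],
                        group=None):
    # Per-rank split sizes are the differences of consecutive row pointers.
    out_sizes = [output_rowptr[i + 1] - output_rowptr[i] for i in range(len(output_rowptr) - 1)]
    in_sizes = [input_rowptr[i + 1] - input_rowptr[i] for i in range(len(input_rowptr) - 1)]
    dist.all_to_all_single(output, input,
                           output_split_sizes=out_sizes,
                           input_split_sizes=in_sizes,
                           group=group)
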
......@@ -11,7 +11,6 @@ from contextlib import contextmanager
import logging
from .cclib import all_to_all_v, all_to_all_s
from .rpc import rpc_remote_call, rpc_remote_void_call
......@@ -159,76 +158,6 @@ class DistributedContext:
def remote_void_call(self, method, rref: rpc.RRef, *args, **kwargs):
return rpc_remote_void_call(method, rref, *args, **kwargs)
# def all_to_all_v(self,
# output_tensor_list: List[Tensor],
# input_tensor_list: List[Tensor],
# group: Any = None,
# async_op: bool = False,
# ):
# return all_to_all_v(
# output_tensor_list,
# input_tensor_list,
# group=group,
# async_op=async_op,
# )
# def all_to_all_g(self,
# input_tensor_list: List[Tensor],
# group: Any = None,
# async_op: bool = False,
# ):
# send_sizes = [t.size(0) for t in input_tensor_list]
# recv_sizes = self.get_all_to_all_recv_sizes(send_sizes, group)
# output_tensor_list: List[Tensor] = []
# for s, t in zip(recv_sizes, input_tensor_list):
# output_tensor_list.append(
# torch.empty(s, *t.shape[1:], dtype=t.dtype, device=t.device),
# )
# work = all_to_all_v(
# output_tensor_list,
# input_tensor_list,
# group=group,
# async_op=async_op,
# )
# if async_op:
# assert work is not None
# return output_tensor_list, work
# else:
# return output_tensor_list
# def all_to_all_s(self,
# output_tensor: Tensor,
# input_tensor: Tensor,
# output_rowptr: List[int],
# input_rowptr: List[int],
# group: Any = None,
# async_op: bool = False,
# ):
# return all_to_all_s(
# output_tensor, input_tensor,
# output_rowptr, input_rowptr,
# group=group, async_op=async_op,
# )
# def get_all_to_all_recv_sizes(self,
# send_sizes: List[int],
# group: Optional[Any] = None,
# ) -> List[int]:
# world_size = dist.get_world_size(group)
# assert len(send_sizes) == world_size
# if dist.get_backend(group) == "gloo":
# send_t = torch.tensor(send_sizes, dtype=torch.long)
# else:
# send_t = torch.tensor(send_sizes, dtype=torch.long, device=self.device)
# recv_t = torch.empty_like(send_t)
# dist.all_to_all_single(recv_t, send_t, group=group)
# return recv_t.tolist()
@contextmanager
def use_stream(self, stream: torch.cuda.Stream, with_event: bool = True):
event = torch.cuda.Event() if with_event else None
......
......@@ -4,6 +4,10 @@ import torch.distributed.rpc as rpc
from torch import Tensor
from typing import *
__all__ = [
"rpc_remote_call",
"rpc_remote_void_call",
]
def rpc_remote_call(method, rref: rpc.RRef, *args, **kwargs):
args = (method, rref) + args
......
from typing import Any
import torch
import torch.distributed as dist
import torch.distributed.rpc as rpc
from torch import Tensor
from torch.types import Number
from typing import *
from torch_sparse import SparseTensor
from .cclib import all_to_all_s
class TensorAccessor:
......@@ -29,19 +31,19 @@ class TensorAccessor:
def ctx(self):
return self._ctx
def all_gather_index(self,index,input_split) -> Tensor:
out_split = torch.empty_like(input_split)
torch.distributed.all_to_all_single(out_split,input_split)
input_split = list(input_split)
output = torch.empty([out_split.sum()],dtype = index.dtype,device = index.device)
out_split = list(out_split)
torch.distributed.all_to_all_single(output,index,out_split,input_split)
return output,out_split,input_split
# def all_gather_index(self,index,input_split) -> Tensor:
# out_split = torch.empty_like(input_split)
# torch.distributed.all_to_all_single(out_split,input_split)
# input_split = list(input_split)
# output = torch.empty([out_split.sum()],dtype = index.dtype,device = index.device)
# out_split = list(out_split)
# torch.distributed.all_to_all_single(output,index,out_split,input_split)
# return output,out_split,input_split
def all_gather_data(self,index,input_split,out_split):
output = torch.empty([int(Tensor(out_split).sum().item()),*self._data.shape[1:]],dtype = self._data.dtype,device = 'cuda')#self._data.device)
torch.distributed.all_to_all_single(output,self.data[index.to(self.data.device)].to('cuda'),output_split_sizes = out_split,input_split_sizes = input_split)
return output
# def all_gather_data(self,index,input_split,out_split):
# output = torch.empty([int(Tensor(out_split).sum().item()),*self._data.shape[1:]],dtype = self._data.dtype,device = 'cuda')#self._data.device)
# torch.distributed.all_to_all_single(output,self.data[index.to(self.data.device)].to('cuda'),output_split_sizes = out_split,input_split_sizes = input_split)
# return output
def all_gather_rrefs(self) -> List[rpc.RRef]:
return self.ctx.all_gather_remote_objects(self.rref)
......@@ -92,40 +94,46 @@ class DistInt:
class DistIndex:
def __init__(self, index: Tensor, part_ids: Optional[Tensor] = None) -> None:
if part_ids is None:
self.data = index.long()
self._data = index.long()
else:
index, part_ids = index.long(), part_ids.long()
self.data = (index & 0xFFFFFFFFFFFF) | ((part_ids & 0xFFFF) << 48)
self._data = (index & 0xFFFFFFFFFFFF) | ((part_ids & 0xFFFF) << 48)
@property
def loc(self) -> Tensor:
return self.data & 0xFFFFFFFFFFFF
return self._data & 0xFFFFFFFFFFFF
@property
def part(self) -> Tensor:
return (self.data >> 48).int() & 0xFFFF
return (self._data >> 48) & 0xFFFF
@property
def dist(self) -> Tensor:
return self.data
return self._data
@property
def dtype(self):
return self._data.dtype
@property
def device(self):
return self._data.device
def to(self,device) -> Tensor:
return DistIndex(self.data.to(device))
return DistIndex(self._data.to(device))
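
A small worked example of the packing scheme DistIndex uses above (low 48 bits hold the local row index, high 16 bits hold the partition id); the concrete values are made up for illustration.

import torch

index = torch.tensor([3, 7])        # local row indices
part_ids = torch.tensor([0, 2])     # owning partitions

packed = (index & 0xFFFFFFFFFFFF) | ((part_ids & 0xFFFF) << 48)
assert torch.equal(packed & 0xFFFFFFFFFFFF, index)      # recovers .loc
assert torch.equal((packed >> 48) & 0xFFFF, part_ids)   # recovers .part
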
class DistributedTensor:
def __init__(self, data: Tensor) -> None:
self.accessor = TensorAccessor(data)
self.rrefs = self.accessor.all_gather_rrefs()
# self.num_parts = len(self.rrefs)
local_sizes = []
for rref in self.rrefs:
n = self.ctx.remote_call(Tensor.size, rref, dim=0).wait()
local_sizes.append(n)
self.num_nodes = DistInt(local_sizes)
self.num_parts = DistInt([1] * len(self.rrefs))
self.distptr = torch.tensor([((part_ids & 0xFFFF) << 48) for part_ids in range(self.num_parts()+1)],device = 'cuda')#data.device)
self._num_nodes = DistInt(local_sizes)
self._num_parts = DistInt([1] * len(self.rrefs))
@property
def dtype(self):
......@@ -135,6 +143,14 @@ class DistributedTensor:
def device(self):
return self.accessor.data.device
@property
def num_nodes(self) -> DistInt:
return self._num_nodes
@property
def num_parts(self) -> DistInt:
return self._num_parts
def to(self,device):
return self.accessor.data.to(device)
......@@ -145,16 +161,63 @@ class DistributedTensor:
def ctx(self):
return self.accessor.ctx
def gather_select_index(self,dist_index: Union[Tensor,DistIndex]):
def all_to_all_ind2ptr(self, dist_index: Union[Tensor, DistIndex]) -> Dict[str, Union[List[int], Tensor]]:
if isinstance(dist_index, Tensor):
dist_index = DistIndex(dist_index)
data = dist_index.data
posptr = torch.searchsorted(data,self.distptr,right = False)
input_split = posptr[1:] - posptr[:-1]
return self.accessor.all_gather_index(DistIndex(data).loc,input_split)
send_ptr = torch.ops.torch_sparse.ind2ptr(dist_index.part, self.num_parts())
send_sizes = send_ptr[1:] - send_ptr[:-1]
recv_sizes = torch.empty_like(send_sizes)
dist.all_to_all_single(recv_sizes, send_sizes)
recv_ptr = torch.zeros(recv_sizes.numel() + 1).type_as(recv_sizes)
recv_ptr[1:] = recv_sizes.cumsum(dim=0)
send_ptr = send_ptr.tolist()
recv_ptr = recv_ptr.tolist()
recv_ind = torch.full((recv_ptr[-1],), (2**62-1)*2+1, dtype=dist_index.dtype, device=dist_index.device)
all_to_all_s(recv_ind, dist_index.loc, send_ptr, recv_ptr)
return {
"send_ptr": send_ptr,
"recv_ptr": recv_ptr,
"recv_ind": recv_ind,
}
def all_to_all_get(self,
dist_index: Union[Tensor, DistIndex, None] = None,
send_ptr: Optional[List[int]] = None,
recv_ptr: Optional[List[int]] = None,
recv_ind: Optional[List[int]] = None,
) -> Tensor:
if dist_index is not None:
dist_dict = self.all_to_all_ind2ptr(dist_index)
send_ptr = dist_dict["send_ptr"]
recv_ptr = dist_dict["recv_ptr"]
recv_ind = dist_dict["recv_ind"]
data = self.accessor.data[recv_ind]
recv = torch.empty(send_ptr[-1], *data.shape[1:], dtype=data.dtype, device=data.device)
all_to_all_s(recv, data, send_ptr, recv_ptr)
return recv
def scatter_data(self,local_index,input_split,out_split):
return self.accessor.all_gather_data(local_index,input_split=input_split,out_split=out_split)
def all_to_all_set(self,
data: Tensor,
dist_index: Union[Tensor, DistIndex, None] = None,
send_ptr: Optional[List[int]] = None,
recv_ptr: Optional[List[int]] = None,
recv_ind: Optional[List[int]] = None,
):
if dist_index is not None:
dist_dict = self.all_to_all_ind2ptr(dist_index)
send_ptr = dist_dict["send_ptr"]
recv_ptr = dist_dict["recv_ptr"]
recv_ind = dist_dict["recv_ind"]
recv = torch.empty(recv_ptr[-1], *data.shape[1:], dtype=data.dtype, device=data.device)
all_to_all_s(recv, data, recv_ptr, send_ptr)
self.accessor.data.index_copy_(0, recv_ind, recv)
def index_select(self, dist_index: Union[Tensor, DistIndex]):
if isinstance(dist_index, Tensor):
......
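
To close, a hedged sketch of how the three new methods compose: all_to_all_ind2ptr builds an exchange plan (send/recv pointers plus the local rows each peer requested), and that plan can be passed back into all_to_all_get / all_to_all_set to reuse the same index set without repeating the size exchange. The feat and dist_idx objects are the assumed setup from the sketch at the top of this page.

plan = feat.all_to_all_ind2ptr(dist_idx)              # one size exchange + index shuffle
rows = feat.all_to_all_get(send_ptr=plan["send_ptr"],
                           recv_ptr=plan["recv_ptr"],
                           recv_ind=plan["recv_ind"])
rows = rows * 2                                       # some local update to the pulled rows
feat.all_to_all_set(rows,
                    send_ptr=plan["send_ptr"],
                    recv_ptr=plan["recv_ptr"],
                    recv_ind=plan["recv_ind"])
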