Commit 66045271 by Wenjie Huang

add DistTensor.all_to_all_[get|set]()

parent 510a41f8
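Overview (summary of the diff): the commit removes the old index-routing helpers on DistributedContext and TensorAccessor and adds three methods to DistributedTensor: all_to_all_ind2ptr() builds the exchange plan for a distributed index, and all_to_all_get() / all_to_all_set() use that plan to read or write remote rows through cclib.all_to_all_s. A minimal usage sketch of the new API follows; the import path, the process-group/RPC setup, the tensor sizes, and the assumption that the index is grouped by partition id (torch_sparse's ind2ptr expects a sorted vector) are not shown in the diff and are illustrative only.

# Minimal usage sketch (not part of the commit). Only the DistributedTensor /
# DistIndex calls below come from this diff; everything else is assumed.
import torch
# from <package>.distributed import DistributedTensor, DistIndex   # import path not shown in the diff
# ... torch.distributed process group and RPC initialization elided ...

local_feat = torch.randn(1000, 128, device="cuda")   # this rank's partition of the data
dtensor = DistributedTensor(local_feat)

# A distributed index packs (partition id, local row) into one int64:
# the low 48 bits hold the local row, the high 16 bits hold the partition id.
# Entries are assumed to be grouped by partition id.
local_rows = torch.tensor([3, 7, 42], device="cuda")
owner_part = torch.tensor([0, 1, 1], device="cuda")
dist_index = DistIndex(local_rows, owner_part)

# One-shot gather of the addressed rows from their owning ranks.
rows = dtensor.all_to_all_get(dist_index)

# Or build the routing tables once and reuse them for both get and set.
route = dtensor.all_to_all_ind2ptr(dist_index)
rows = dtensor.all_to_all_get(send_ptr=route["send_ptr"],
                              recv_ptr=route["recv_ptr"],
                              recv_ind=route["recv_ind"])
dtensor.all_to_all_set(rows, send_ptr=route["send_ptr"],
                       recv_ptr=route["recv_ptr"],
                       recv_ind=route["recv_ind"])

All ranks must enter these calls together, since they are built on all_to_all collectives.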
@@ -4,6 +4,12 @@ import torch.distributed as dist
 from torch import Tensor
 from typing import *
+__all__ = [
+    "all_to_all_v",
+    "all_to_all_s",
+    "BatchWork",
+]
 class BatchWork:
     def __init__(self, works, buffer_tensor_list) -> None:
...
@@ -11,7 +11,6 @@ from contextlib import contextmanager
 import logging
-from .cclib import all_to_all_v, all_to_all_s
 from .rpc import rpc_remote_call, rpc_remote_void_call
@@ -159,76 +158,6 @@ class DistributedContext:
     def remote_void_call(self, method, rref: rpc.RRef, *args, **kwargs):
         return rpc_remote_void_call(method, rref, *args, **kwargs)
-    # def all_to_all_v(self,
-    #     output_tensor_list: List[Tensor],
-    #     input_tensor_list: List[Tensor],
-    #     group: Any = None,
-    #     async_op: bool = False,
-    # ):
-    #     return all_to_all_v(
-    #         output_tensor_list,
-    #         input_tensor_list,
-    #         group=group,
-    #         async_op=async_op,
-    #     )
-    # def all_to_all_g(self,
-    #     input_tensor_list: List[Tensor],
-    #     group: Any = None,
-    #     async_op: bool = False,
-    # ):
-    #     send_sizes = [t.size(0) for t in input_tensor_list]
-    #     recv_sizes = self.get_all_to_all_recv_sizes(send_sizes, group)
-    #     output_tensor_list: List[Tensor] = []
-    #     for s, t in zip(recv_sizes, input_tensor_list):
-    #         output_tensor_list.append(
-    #             torch.empty(s, *t.shape[1:], dtype=t.dtype, device=t.device),
-    #         )
-    #     work = all_to_all_v(
-    #         output_tensor_list,
-    #         input_tensor_list,
-    #         group=group,
-    #         async_op=async_op,
-    #     )
-    #     if async_op:
-    #         assert work is not None
-    #         return output_tensor_list, work
-    #     else:
-    #         return output_tensor_list
-    # def all_to_all_s(self,
-    #     output_tensor: Tensor,
-    #     input_tensor: Tensor,
-    #     output_rowptr: List[int],
-    #     input_rowptr: List[int],
-    #     group: Any = None,
-    #     async_op: bool = False,
-    # ):
-    #     return all_to_all_s(
-    #         output_tensor, input_tensor,
-    #         output_rowptr, input_rowptr,
-    #         group=group, async_op=async_op,
-    #     )
-    # def get_all_to_all_recv_sizes(self,
-    #     send_sizes: List[int],
-    #     group: Optional[Any] = None,
-    # ) -> List[int]:
-    #     world_size = dist.get_world_size(group)
-    #     assert len(send_sizes) == world_size
-    #     if dist.get_backend(group) == "gloo":
-    #         send_t = torch.tensor(send_sizes, dtype=torch.long)
-    #     else:
-    #         send_t = torch.tensor(send_sizes, dtype=torch.long, device=self.device)
-    #     recv_t = torch.empty_like(send_t)
-    #     dist.all_to_all_single(recv_t, send_t, group=group)
-    #     return recv_t.tolist()
     @contextmanager
     def use_stream(self, stream: torch.cuda.Stream, with_event: bool = True):
         event = torch.cuda.Event() if with_event else None
...
@@ -4,6 +4,10 @@ import torch.distributed.rpc as rpc
 from torch import Tensor
 from typing import *
+__all__ = [
+    "rpc_remote_call",
+    "rpc_remote_void_call",
+]
 def rpc_remote_call(method, rref: rpc.RRef, *args, **kwargs):
     args = (method, rref) + args
...
+from typing import Any
 import torch
+import torch.distributed as dist
 import torch.distributed.rpc as rpc
 from torch import Tensor
 from torch.types import Number
 from typing import *
+from torch_sparse import SparseTensor
+from .cclib import all_to_all_s
 class TensorAccessor:
@@ -29,19 +31,19 @@ class TensorAccessor:
     def ctx(self):
         return self._ctx
-    def all_gather_index(self,index,input_split) -> Tensor:
-        out_split = torch.empty_like(input_split)
-        torch.distributed.all_to_all_single(out_split,input_split)
-        input_split = list(input_split)
-        output = torch.empty([out_split.sum()],dtype = index.dtype,device = index.device)
-        out_split = list(out_split)
-        torch.distributed.all_to_all_single(output,index,out_split,input_split)
-        return output,out_split,input_split
-    def all_gather_data(self,index,input_split,out_split):
-        output = torch.empty([int(Tensor(out_split).sum().item()),*self._data.shape[1:]],dtype = self._data.dtype,device = 'cuda')#self._data.device)
-        torch.distributed.all_to_all_single(output,self.data[index.to(self.data.device)].to('cuda'),output_split_sizes = out_split,input_split_sizes = input_split)
-        return output
+    # def all_gather_index(self,index,input_split) -> Tensor:
+    #     out_split = torch.empty_like(input_split)
+    #     torch.distributed.all_to_all_single(out_split,input_split)
+    #     input_split = list(input_split)
+    #     output = torch.empty([out_split.sum()],dtype = index.dtype,device = index.device)
+    #     out_split = list(out_split)
+    #     torch.distributed.all_to_all_single(output,index,out_split,input_split)
+    #     return output,out_split,input_split
+    # def all_gather_data(self,index,input_split,out_split):
+    #     output = torch.empty([int(Tensor(out_split).sum().item()),*self._data.shape[1:]],dtype = self._data.dtype,device = 'cuda')#self._data.device)
+    #     torch.distributed.all_to_all_single(output,self.data[index.to(self.data.device)].to('cuda'),output_split_sizes = out_split,input_split_sizes = input_split)
+    #     return output
     def all_gather_rrefs(self) -> List[rpc.RRef]:
         return self.ctx.all_gather_remote_objects(self.rref)
@@ -92,40 +94,46 @@ class DistInt:
 class DistIndex:
     def __init__(self, index: Tensor, part_ids: Optional[Tensor] = None) -> None:
         if part_ids is None:
-            self.data = index.long()
+            self._data = index.long()
         else:
             index, part_ids = index.long(), part_ids.long()
-            self.data = (index & 0xFFFFFFFFFFFF) | ((part_ids & 0xFFFF) << 48)
+            self._data = (index & 0xFFFFFFFFFFFF) | ((part_ids & 0xFFFF) << 48)
     @property
     def loc(self) -> Tensor:
-        return self.data & 0xFFFFFFFFFFFF
+        return self._data & 0xFFFFFFFFFFFF
     @property
     def part(self) -> Tensor:
-        return (self.data >> 48).int() & 0xFFFF
+        return (self._data >> 48) & 0xFFFF
     @property
     def dist(self) -> Tensor:
-        return self.data
+        return self._data
+    @property
+    def dtype(self):
+        return self._data.dtype
+    @property
+    def device(self):
+        return self._data.device
     def to(self,device) -> Tensor:
-        return DistIndex(self.data.to(device))
+        return DistIndex(self._data.to(device))
 class DistributedTensor:
     def __init__(self, data: Tensor) -> None:
         self.accessor = TensorAccessor(data)
         self.rrefs = self.accessor.all_gather_rrefs()
-        # self.num_parts = len(self.rrefs)
         local_sizes = []
         for rref in self.rrefs:
             n = self.ctx.remote_call(Tensor.size, rref, dim=0).wait()
             local_sizes.append(n)
-        self.num_nodes = DistInt(local_sizes)
-        self.num_parts = DistInt([1] * len(self.rrefs))
-        self.distptr = torch.tensor([((part_ids & 0xFFFF) << 48) for part_ids in range(self.num_parts()+1)],device = 'cuda')#data.device)
+        self._num_nodes = DistInt(local_sizes)
+        self._num_parts = DistInt([1] * len(self.rrefs))
     @property
     def dtype(self):
@@ -135,6 +143,14 @@ class DistributedTensor:
     def device(self):
         return self.accessor.data.device
+    @property
+    def num_nodes(self) -> DistInt:
+        return self._num_nodes
+    @property
+    def num_parts(self) -> DistInt:
+        return self._num_parts
     def to(self,device):
         return self.accessor.data.to(device)
@@ -145,16 +161,63 @@
     def ctx(self):
         return self.accessor.ctx
-    def gather_select_index(self,dist_index: Union[Tensor,DistIndex]):
+    def all_to_all_ind2ptr(self, dist_index: Union[Tensor, DistIndex]) -> Dict[str, Union[List[int], Tensor]]:
         if isinstance(dist_index, Tensor):
             dist_index = DistIndex(dist_index)
-        data = dist_index.data
-        posptr = torch.searchsorted(data,self.distptr,right = False)
-        input_split = posptr[1:] - posptr[:-1]
-        return self.accessor.all_gather_index(DistIndex(data).loc,input_split)
-    def scatter_data(self,local_index,input_split,out_split):
-        return self.accessor.all_gather_data(local_index,input_split=input_split,out_split=out_split)
+        send_ptr = torch.ops.torch_sparse.ind2ptr(dist_index.part, self.num_parts())
+        send_sizes = send_ptr[1:] - send_ptr[:-1]
+        recv_sizes = torch.empty_like(send_sizes)
+        dist.all_to_all_single(recv_sizes, send_sizes)
+        recv_ptr = torch.zeros(recv_sizes.numel() + 1).type_as(recv_sizes)
+        recv_ptr[1:] = recv_sizes.cumsum(dim=0)
+        send_ptr = send_ptr.tolist()
+        recv_ptr = recv_ptr.tolist()
+        recv_ind = torch.full((recv_ptr[-1],), (2**62-1)*2+1, dtype=dist_index.dtype, device=dist_index.device)
+        all_to_all_s(recv_ind, dist_index.loc, send_ptr, recv_ptr)
+        return {
+            "send_ptr": send_ptr,
+            "recv_ptr": recv_ptr,
+            "recv_ind": recv_ind,
+        }
+    def all_to_all_get(self,
+        dist_index: Union[Tensor, DistIndex, None] = None,
+        send_ptr: Optional[List[int]] = None,
+        recv_ptr: Optional[List[int]] = None,
+        recv_ind: Optional[List[int]] = None,
+    ) -> Tensor:
+        if dist_index is not None:
+            dist_dict = self.all_to_all_ind2ptr(dist_index)
+            send_ptr = dist_dict["send_ptr"]
+            recv_ptr = dist_dict["recv_ptr"]
+            recv_ind = dist_dict["recv_ind"]
+        data = self.accessor.data[recv_ind]
+        recv = torch.empty(send_ptr[-1], *data.shape[1:], dtype=data.dtype, device=data.device)
+        all_to_all_s(recv, data, send_ptr, recv_ptr)
+        return recv
+    def all_to_all_set(self,
+        data: Tensor,
+        dist_index: Union[Tensor, DistIndex, None] = None,
+        send_ptr: Optional[List[int]] = None,
+        recv_ptr: Optional[List[int]] = None,
+        recv_ind: Optional[List[int]] = None,
+    ):
+        if dist_index is not None:
+            dist_dict = self.all_to_all_ind2ptr(dist_index)
+            send_ptr = dist_dict["send_ptr"]
+            recv_ptr = dist_dict["recv_ptr"]
+            recv_ind = dist_dict["recv_ind"]
+        recv = torch.empty(recv_ptr[-1], *data.shape[1:], dtype=data.dtype, device=data.device)
+        all_to_all_s(recv, data, recv_ptr, send_ptr)
+        self.accessor.data.index_copy_(0, recv_ind, recv)
     def index_select(self, dist_index: Union[Tensor, DistIndex]):
         if isinstance(dist_index, Tensor):
...
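For reference, the send_ptr / recv_ptr values returned by all_to_all_ind2ptr() read like CSR row pointers over partitions: send_ptr[p]:send_ptr[p+1] brackets the indices this rank requests from partition p, and recv_ptr[p]:recv_ptr[p+1] brackets the rows partition p requests from the local shard. The DistIndex encoding underneath keeps the local row in the low 48 bits and the partition id in the high 16 bits of a single int64. A small self-contained check of that packing (illustrative values only, no process group needed; not part of the commit):

import torch

# Pack, mirroring DistIndex.__init__: low 48 bits = local row, high 16 bits = partition id.
local = torch.tensor([3, 7, 42], dtype=torch.long)
part = torch.tensor([0, 1, 1], dtype=torch.long)
packed = (local & 0xFFFFFFFFFFFF) | ((part & 0xFFFF) << 48)

# Unpack, mirroring DistIndex.loc and DistIndex.part.
assert torch.equal(packed & 0xFFFFFFFFFFFF, local)
assert torch.equal((packed >> 48) & 0xFFFF, part)

Since (2**62-1)*2+1 equals 2**63-1, the recv_ind buffer in all_to_all_ind2ptr is pre-filled with the largest int64 value, presumably so that any slot the exchange never writes stands out immediately.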