Commit 035ce537 by Wenjie Huang

apply async_execution() to DistTensor.[methods]

parent 1fa9b9fa
......@@ -2,12 +2,8 @@
#include "uvm.h"
#include "partition.h"
torch::Tensor add(torch::Tensor a, torch::Tensor b) {
return a + b;
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("add", &add, "a function implemented using pybind11");
m.def("uvm_storage_new", &uvm_storage_new, "return storage of unified virtual memory");
m.def("uvm_storage_to_cuda", &uvm_storage_to_cuda, "share uvm storage with another cuda device");
m.def("uvm_storage_to_cpu", &uvm_storage_to_cpu, "share uvm storage with cpu");
......
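For reference, a minimal sketch of loading and calling a pybind11/torch extension like the one above via torch.utils.cpp_extension.load; the module name and the source file list are assumptions for illustration, not taken from this repo's build setup.

import torch
from torch.utils.cpp_extension import load

# Hypothetical module name and source list; adjust to the actual build files.
ext = load(name="dist_tensor_ext", sources=["binding.cpp", "uvm.cpp", "partition.cpp"])

a, b = torch.randn(4), torch.randn(4)
print(ext.add(a, b))  # exercises the `add` binding registered above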
......@@ -11,7 +11,7 @@ from contextlib import contextmanager
import logging
from .rpc import rpc_remote_call, rpc_remote_void_call
from .rpc import *
......@@ -158,6 +158,9 @@ class DistributedContext:
def remote_void_call(self, method, rref: rpc.RRef, *args, **kwargs):
return rpc_remote_void_call(method, rref, *args, **kwargs)
def remote_exec(self, method, rref: rpc.RRef, *args, **kwargs):
return rpc_remote_exec(method, rref, *args, **kwargs)
@contextmanager
def use_stream(self, stream: torch.cuda.Stream, with_event: bool = True):
event = torch.cuda.Event() if with_event else None
......
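The body of use_stream is collapsed in the hunk above. Purely as an illustration (not the actual implementation), a stream/event context manager of this shape usually runs the block on the given stream and records the event afterwards so callers can wait on the queued work:

import torch
from contextlib import contextmanager

@contextmanager
def use_stream_sketch(stream: torch.cuda.Stream, with_event: bool = True):
    # Illustrative only: queue work on `stream`, hand the event to the caller,
    # and record it once the block has finished submitting work.
    event = torch.cuda.Event() if with_event else None
    with torch.cuda.stream(stream):
        yield event
    if event is not None:
        event.record(stream)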
......@@ -7,6 +7,7 @@ from typing import *
__all__ = [
"rpc_remote_call",
"rpc_remote_void_call",
"rpc_remote_exec"
]
def rpc_remote_call(method, rref: rpc.RRef, *args, **kwargs):
......@@ -24,3 +25,12 @@ def rpc_remote_void_call(method, rref: rpc.RRef, *args, **kwargs):
def rpc_method_void_call(method, rref: rpc.RRef, *args, **kwargs):
self = rref.local_value()
method(self, *args, **kwargs) # return None
def rpc_remote_exec(method, rref: rpc.RRef, *args, **kwargs):
args = (method, rref) + args
return rpc.rpc_async(rref.owner(), rpc_method_exec, args=args, kwargs=kwargs)
@rpc.functions.async_execution
def rpc_method_exec(method, rref: rpc.RRef, *args, **kwargs):
self = rref.local_value()
return method(self, *args, **kwargs)
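A self-contained sketch of the @rpc.functions.async_execution pattern used by rpc_method_exec above: the decorated callee returns a torch.futures.Future, and the RPC reply is sent only once that future completes. The worker names, port, and two-process setup below are illustrative assumptions.

import os
import torch
import torch.distributed.rpc as rpc

@rpc.functions.async_execution
def delayed_double(x: torch.Tensor) -> torch.futures.Future:
    fut = torch.futures.Future()
    fut.set_result(x * 2)  # a real callee could complete this later, off-thread
    return fut

def run(rank: int, world_size: int = 2):
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    rpc.init_rpc(f"worker{rank}", rank=rank, world_size=world_size)
    if rank == 0:
        # The reply arrives when delayed_double's future is done, not when it returns.
        out = rpc.rpc_sync("worker1", delayed_double, args=(torch.ones(3),))
        print(out)  # tensor([2., 2., 2.])
    rpc.shutdown()

if __name__ == "__main__":
    torch.multiprocessing.spawn(run, nprocs=2)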
......@@ -17,7 +17,6 @@ class TensorAccessor:
self._data = data
self._ctx = DistributedContext.get_default_context()
self._rref = rpc.RRef(data)
self._rref.confirmed_by_owner
@property
def data(self):
......@@ -31,52 +30,61 @@ class TensorAccessor:
def ctx(self):
return self._ctx
# def all_gather_index(self,index,input_split) -> Tensor:
# out_split = torch.empty_like(input_split)
# torch.distributed.all_to_all_single(out_split,input_split)
# input_split = list(input_split)
# output = torch.empty([out_split.sum()],dtype = index.dtype,device = index.device)
# out_split = list(out_split)
# torch.distributed.all_to_all_single(output,index,out_split,input_split)
# return output,out_split,input_split
# def all_gather_data(self,index,input_split,out_split):
# output = torch.empty([int(Tensor(out_split).sum().item()),*self._data.shape[1:]],dtype = self._data.dtype,device = 'cuda')#self._data.device)
# torch.distributed.all_to_all_single(output,self.data[index.to(self.data.device)].to('cuda'),output_split_sizes = out_split,input_split_sizes = input_split)
# return output
def all_gather_rrefs(self) -> List[rpc.RRef]:
return self.ctx.all_gather_remote_objects(self.rref)
def async_index_select(self, dim: int, index: Tensor, rref: Optional[rpc.RRef] = None):
if rref is None:
rref = self.rref
return self.ctx.remote_call(Tensor.index_select, rref, dim=dim, index=index)
return self.ctx.remote_exec(TensorAccessor._index_select, rref, dim=dim, index=index)
def async_index_copy_(self, dim: int, index: Tensor, source: Tensor, rref: Optional[rpc.RRef] = None):
if rref is None:
rref = self.rref
return self.ctx.remote_void_call(Tensor.index_copy_, rref, dim=dim, index=index, source=source)
return self.ctx.remote_exec(TensorAccessor._index_copy_, rref, dim=dim, index=index, source=source)
def async_index_add_(self, dim: int, index: Tensor, source: Tensor, rref: Optional[rpc.RRef] = None):
if rref is None:
rref = self.rref
return self.ctx.remote_void_call(Tensor.index_add_, rref, dim=dim, index=index, source=source)
def async_index_fill_(self, dim: int, index: Tensor, value: Number, rref: Optional[rpc.RRef] = None):
if rref is None:
rref = self.rref
return self.ctx.remote_void_call(Tensor.index_fill_, rref, dim=dim, index=index, value=value)
def async_fill_(self, value: Number, rref: Optional[rpc.RRef] = None):
if rref is None:
rref = self.rref
return self.ctx.remote_void_call(Tensor.fill_, rref, value=value)
def async_zero_(self, rref: Optional[rpc.RRef] = None):
if rref is None:
rref = self.rref
self.ctx.remote_void_call(Tensor.zero_, rref)
return self.ctx.remote_exec(TensorAccessor._index_add_, rref, dim=dim, index=index, source=source)
@staticmethod
def _index_select(data: Tensor, dim: int, index: Tensor):
stream = TensorAccessor.get_stream()
with torch.cuda.stream(stream):
data = data.index_select(dim, index)
fut = torch.futures.Future()
fut.set_result(data)
return fut
@staticmethod
def _index_copy_(data: Tensor, dim: int, index: Tensor, source: Tensor):
stream = TensorAccessor.get_stream()
with torch.cuda.stream(stream):
data.index_copy_(dim, index, source)
fut = torch.futures.Future()
fut.set_result(None)
return fut
@staticmethod
def _index_add_(data: Tensor, dim: int, index: Tensor, source: Tensor):
stream = TensorAccessor.get_stream()
with torch.cuda.stream(stream):
data.index_add_(dim, index, source)
fut = torch.futures.Future()
fut.set_result(None)
return fut
@staticmethod
def get_stream() -> Optional[torch.cuda.Stream]:
global _TENSOR_ACCESSOR_STREAM
if not torch.cuda.is_available(): # no CUDA: fall back to the default stream (None)
return None
if _TENSOR_ACCESSOR_STREAM is None:
_TENSOR_ACCESSOR_STREAM = torch.cuda.Stream()
return _TENSOR_ACCESSOR_STREAM
_TENSOR_ACCESSOR_STREAM: Optional[torch.cuda.Stream] = None
class DistInt:
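On the stream handling in the static helpers above: torch.cuda.stream(None) is a no-op context, so the same code path works on CPU-only workers where get_stream() returns None. A hedged standalone sketch of that pattern follows; the names are illustrative, and the explicit synchronize is an extra precaution rather than something the diff itself does.

import torch

def index_select_on_side_stream(data: torch.Tensor, index: torch.Tensor) -> torch.futures.Future:
    stream = torch.cuda.Stream() if torch.cuda.is_available() else None
    with torch.cuda.stream(stream):  # no-op context when stream is None
        out = data.index_select(0, index)
    if stream is not None:
        stream.synchronize()  # make sure the result is ready before handing it over
    fut = torch.futures.Future()
    fut.set_result(out)
    return fut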
......@@ -228,10 +236,6 @@ class DistributedTensor:
futs: List[torch.futures.Future] = []
for i in range(self.num_parts()):
#if i != torch.distributed.get_rank():
# continue
#f = torch.futures.Future()
#f.set_result(self.accessor.data[index[part_idx == i]])
f = self.accessor.async_index_select(0, index[part_idx == i], self.rrefs[i])
futs.append(f)
......@@ -276,16 +280,3 @@ class DistributedTensor:
futs.append(f)
return torch.futures.collect_all(futs)
\ No newline at end of file
def index_fill_(self, dist_index: Union[Tensor, DistIndex], value: Number):
if isinstance(dist_index, Tensor):
dist_index = DistIndex(dist_index)
part_idx = dist_index.part
index = dist_index.loc
futs: List[torch.futures.Future] = []
for i in range(self.num_parts()):
mask = part_idx == i
f = self.accessor.async_index_fill_(0, index[mask], value, self.rrefs[i])
futs.append(f)
return torch.futures.collect_all(futs)
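A hedged usage sketch for the new index_fill_ above; dist_tensor and the index values are placeholders. index_fill_ splits the global indices per partition, issues one asynchronous fill per part, and returns torch.futures.collect_all(futs), so a single wait() covers every partition.

import torch

def fill_rows(dist_tensor, global_index: torch.Tensor, value: float) -> None:
    # Blocks until every partition has applied the fill.
    dist_tensor.index_fill_(global_index, value).wait()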