Commit 3d03c937 by Wenjie Huang

add rpc_remote_void_call() which returns None always.

parent 649ec368
...@@ -10,7 +10,7 @@ from typing import * ...@@ -10,7 +10,7 @@ from typing import *
import logging import logging
from .cclib import all_to_all_v from .cclib import all_to_all_v
from .rpc import rpc_remote_call from .rpc import rpc_remote_call, rpc_remote_void_call
...@@ -154,6 +154,9 @@ class DistributedContext: ...@@ -154,6 +154,9 @@ class DistributedContext:
def remote_call(self, method, rref: rpc.RRef, *args, **kwargs): def remote_call(self, method, rref: rpc.RRef, *args, **kwargs):
return rpc_remote_call(method, rref, *args, **kwargs) return rpc_remote_call(method, rref, *args, **kwargs)
def remote_void_call(self, method, rref: rpc.RRef, *args, **kwargs):
return rpc_remote_void_call(method, rref, *args, **kwargs)
def all_to_all_v(self, def all_to_all_v(self,
output_tensor_list: List[Tensor], output_tensor_list: List[Tensor],
input_tensor_list: List[Tensor], input_tensor_list: List[Tensor],
......
...@@ -12,3 +12,11 @@ def rpc_remote_call(method, rref: rpc.RRef, *args, **kwargs): ...@@ -12,3 +12,11 @@ def rpc_remote_call(method, rref: rpc.RRef, *args, **kwargs):
def rpc_method_call(method, rref: rpc.RRef, *args, **kwargs): def rpc_method_call(method, rref: rpc.RRef, *args, **kwargs):
self = rref.local_value() self = rref.local_value()
return method(self, *args, **kwargs) return method(self, *args, **kwargs)
def rpc_remote_void_call(method, rref: rpc.RRef, *args, **kwargs):
args = (method, rref) + args
return rpc.rpc_async(rref.owner(), rpc_method_void_call, args=args, kwargs=kwargs)
def rpc_method_void_call(method, rref: rpc.RRef, *args, **kwargs):
self = rref.local_value()
method(self, *args, **kwargs) # return None
...@@ -39,47 +39,27 @@ class TensorAccessor: ...@@ -39,47 +39,27 @@ class TensorAccessor:
def async_index_copy_(self, dim: int, index: Tensor, source: Tensor, rref: Optional[rpc.RRef] = None): def async_index_copy_(self, dim: int, index: Tensor, source: Tensor, rref: Optional[rpc.RRef] = None):
if rref is None: if rref is None:
rref = self.rref rref = self.rref
return self.ctx.remote_call(TensorAccessor._index_copy_, rref, dim=dim, index=index, source=source) return self.ctx.remote_void_call(Tensor.index_copy_, rref, dim=dim, index=index, source=source)
def async_index_add_(self, dim: int, index: Tensor, source: Tensor, rref: Optional[rpc.RRef] = None): def async_index_add_(self, dim: int, index: Tensor, source: Tensor, rref: Optional[rpc.RRef] = None):
if rref is None: if rref is None:
rref = self.rref rref = self.rref
return self.ctx.remote_call(TensorAccessor._index_add_, rref, dim=dim, index=index, source=source) return self.ctx.remote_void_call(Tensor.index_add_, rref, dim=dim, index=index, source=source)
def async_index_fill_(self, dim: int, index: Tensor, value: Number, rref: Optional[rpc.RRef] = None): def async_index_fill_(self, dim: int, index: Tensor, value: Number, rref: Optional[rpc.RRef] = None):
if rref is None: if rref is None:
rref = self.rref rref = self.rref
return self.ctx.remote_call(TensorAccessor._index_fill_, rref, dim=dim, index=index, value=value) return self.ctx.remote_void_call(Tensor.index_fill_, rref, dim=dim, index=index, value=value)
def async_fill_(self, value: Number, rref: Optional[rpc.RRef] = None): def async_fill_(self, value: Number, rref: Optional[rpc.RRef] = None):
if rref is None: if rref is None:
rref = self.rref rref = self.rref
return self.ctx.remote_call(TensorAccessor._fill_, rref, value=value) return self.ctx.remote_void_call(Tensor.fill_, rref, value=value)
def async_zero_(self, rref: Optional[rpc.RRef] = None): def async_zero_(self, rref: Optional[rpc.RRef] = None):
if rref is None: if rref is None:
rref = self.rref rref = self.rref
self.ctx.remote_call(TensorAccessor._zero_, rref) self.ctx.remote_void_call(Tensor.zero_, rref)
@staticmethod
def _index_copy_(self: Tensor, dim: int, index: Tensor, source: Tensor):
self.index_copy_(dim=dim, index=index, source=source)
@staticmethod
def _index_add_(self: Tensor, dim: int, index: Tensor, source: Tensor):
self.index_add_(dim=dim, index=index, source=source)
@staticmethod
def _index_fill_(self: Tensor, dim: int, index: Tensor, value: Number):
self.index_fill_(dim=dim, index=index, value=value)
@staticmethod
def _fill_(self: Tensor, value: Number):
self.fill_(value=value)
@staticmethod
def _zero_(self: Tensor):
self.zero_()
class DistInt: class DistInt:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment