Commit 057dc57f by Wenjie Huang

disable rpc by default

parent f84aa32e
...@@ -24,6 +24,7 @@ class DistributedContext: ...@@ -24,6 +24,7 @@ class DistributedContext:
@staticmethod @staticmethod
def init( def init(
backend: str, backend: str,
use_rpc: bool = False,
use_gpu: Optional[bool] = None, use_gpu: Optional[bool] = None,
rpc_gpu: Optional[bool] = None, rpc_gpu: Optional[bool] = None,
) -> 'DistributedContext': ) -> 'DistributedContext':
...@@ -63,7 +64,9 @@ class DistributedContext: ...@@ -63,7 +64,9 @@ class DistributedContext:
rpc_init_method=rpc_init_url, rpc_init_method=rpc_init_url,
rank=rank, world_size=world_size, rank=rank, world_size=world_size,
local_rank=local_rank, local_rank=local_rank,
use_gpu=use_gpu, rpc_gpu=rpc_gpu, use_rpc=use_rpc,
use_gpu=use_gpu,
rpc_gpu=rpc_gpu,
) )
_set_default_dist_context(ctx) _set_default_dist_context(ctx)
...@@ -86,7 +89,9 @@ class DistributedContext: ...@@ -86,7 +89,9 @@ class DistributedContext:
rpc_init_method: str, rpc_init_method: str,
rank: int, world_size: int, rank: int, world_size: int,
local_rank: int, local_rank: int,
use_gpu: bool, rpc_gpu: bool, use_rpc: bool,
use_gpu: bool,
rpc_gpu: bool,
) -> None: ) -> None:
if use_gpu: if use_gpu:
device = torch.device(f"cuda:{local_rank}") device = torch.device(f"cuda:{local_rank}")
...@@ -112,11 +117,13 @@ class DistributedContext: ...@@ -112,11 +117,13 @@ class DistributedContext:
device_map={local_rank: dev}, device_map={local_rank: dev},
) )
if use_rpc:
rpc.init_rpc( rpc.init_rpc(
name=f"worker{rank}", name=f"worker{rank}",
rank=rank, world_size=world_size, rank=rank, world_size=world_size,
rpc_backend_options=rpc_backend_options, rpc_backend_options=rpc_backend_options,
) )
self._use_rpc = use_rpc
self._local_rank = local_rank self._local_rank = local_rank
self._compute_device = device self._compute_device = device
...@@ -124,6 +131,7 @@ class DistributedContext: ...@@ -124,6 +131,7 @@ class DistributedContext:
self.__temp_ag_remote_object: Optional[rpc.RRef] = None self.__temp_ag_remote_object: Optional[rpc.RRef] = None
def shutdown(self): def shutdown(self):
if self._use_rpc:
rpc.shutdown() rpc.shutdown()
@property @property
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment