Commit 057dc57f by Wenjie Huang

disable rpc by default

parent f84aa32e
@@ -24,6 +24,7 @@ class DistributedContext:
     @staticmethod
     def init(
         backend: str,
+        use_rpc: bool = False,
         use_gpu: Optional[bool] = None,
         rpc_gpu: Optional[bool] = None,
     ) -> 'DistributedContext':
@@ -63,7 +64,9 @@
             rpc_init_method=rpc_init_url,
             rank=rank, world_size=world_size,
             local_rank=local_rank,
-            use_gpu=use_gpu, rpc_gpu=rpc_gpu,
+            use_rpc=use_rpc,
+            use_gpu=use_gpu,
+            rpc_gpu=rpc_gpu,
         )
         _set_default_dist_context(ctx)
@@ -86,7 +89,9 @@
         rpc_init_method: str,
         rank: int, world_size: int,
         local_rank: int,
-        use_gpu: bool, rpc_gpu: bool,
+        use_rpc: bool,
+        use_gpu: bool,
+        rpc_gpu: bool,
     ) -> None:
         if use_gpu:
             device = torch.device(f"cuda:{local_rank}")
@@ -111,12 +116,14 @@
                     to=f"worker{i}",
                     device_map={local_rank: dev},
                 )
-        rpc.init_rpc(
-            name=f"worker{rank}",
-            rank=rank, world_size=world_size,
-            rpc_backend_options=rpc_backend_options,
-        )
+        if use_rpc:
+            rpc.init_rpc(
+                name=f"worker{rank}",
+                rank=rank, world_size=world_size,
+                rpc_backend_options=rpc_backend_options,
+            )
+        self._use_rpc = use_rpc
         self._local_rank = local_rank
         self._compute_device = device
@@ -124,7 +131,8 @@
         self.__temp_ag_remote_object: Optional[rpc.RRef] = None

     def shutdown(self):
-        rpc.shutdown()
+        if self._use_rpc:
+            rpc.shutdown()

     @property
     def rank(self) -> int:
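With this commit, RPC becomes opt-in: `DistributedContext.init` only calls `rpc.init_rpc` when `use_rpc=True`, and `shutdown` only calls `rpc.shutdown` if the agent was actually started. A minimal usage sketch of both modes follows; the import path is hypothetical (the diff only shows the class), and it assumes the script is launched through a standard launcher such as `torchrun` so that rank/world-size information is available to `init`.

```python
# Hypothetical import path; adjust to wherever DistributedContext lives.
from dist_context import DistributedContext

# New default after this commit: only the collective backend (e.g. NCCL)
# is initialized; no torch.distributed.rpc agent is created.
ctx = DistributedContext.init(backend="nccl")

# Opt in explicitly when RPC features (e.g. rpc.RRef) are needed:
# ctx = DistributedContext.init(backend="nccl", use_rpc=True)

# shutdown() is now safe in both modes: rpc.shutdown() runs only when
# use_rpc=True was passed at init time.
ctx.shutdown()
```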