Commit e4ac1c7a by Wenjie Huang

support cpu training

parent f4af3231
......@@ -138,7 +138,10 @@ class NodeHandle:
return x, fut
def async_select(self, src: Union[Tensor, TensorBuffer], idx: Tensor) -> torch.futures.Future[Tensor]:
if self._device != torch.device("cpu"):
fut = torch.futures.Future(devices=[self._device])
else:
fut = torch.futures.Future()
def run():
try:
with torch.no_grad():
......@@ -157,7 +160,10 @@ class NodeHandle:
def async_update(self, src: Union[Tensor, TensorBuffer], val: Tensor, idx: Tensor, ops: str = "mov") -> torch.futures.Future:
assert ops in ["mov", "add"]
if self._device != torch.device("cpu"):
fut = torch.futures.Future(devices=[self._device])
else:
fut = torch.futures.Future()
def run():
try:
with torch.no_grad():
......
......@@ -6,7 +6,7 @@ import torch.distributed.rpc as rpc
import os
from typing import *
from .degree import compute_in_degree, compute_out_degree, compute_gcn_norm
# from .degree import compute_in_degree, compute_out_degree, compute_gcn_norm
from .sync_bn import SyncBatchNorm
def convert_parallel_model(
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment