Commit 70c47cf1 by Wenjie Huang

fix route remap

parent b294749e
......@@ -315,10 +315,15 @@ class Route:
@staticmethod
def __remap_ind(ind: Tensor, mask: Tensor) -> Tuple[Tensor, int]:
n: int = mask.count_nonzero().item()
idx = torch.where(mask)[0]
imp = torch.full((mask.numel(),), (2**62-1)*2+1, dtype=ind.dtype, device=ind.device)
imp[mask] = torch.arange(n, dtype=ind.dtype, device=ind.device)
return ind, int(n)
imp[idx] = torch.arange(idx.numel(), dtype=ind.dtype, device=ind.device)
return imp[ind], idx.numel()
# n: int = mask.count_nonzero().item()
# imp = torch.full((mask.numel(),), (2**62-1)*2+1, dtype=ind.dtype, device=ind.device)
# imp[mask] = torch.arange(n, dtype=ind.dtype, device=ind.device)
# return ind, int(n)
@staticmethod
def __backward_fw_tables(
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment