zhlj / BTS-MTGNN · Commits

Commit 035ce537
authored Dec 19, 2023 by Wenjie Huang
parent 1fa9b9fa

apply async_execution() to DistTensor.[methods]
Showing 4 changed files with 56 additions and 55 deletions:

    csrc/export.cpp                    +0   -4
    starrygl/distributed/context.py    +4   -1
    starrygl/distributed/rpc.py        +10  -0
    starrygl/distributed/utils.py      +42  -50
csrc/export.cpp

@@ -2,12 +2,8 @@
 #include "uvm.h"
 #include "partition.h"

-torch::Tensor add(torch::Tensor a, torch::Tensor b) {
-    return a + b;
-}
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-    m.def("add", &add, "a function implemented using pybind11");
     m.def("uvm_storage_new", &uvm_storage_new, "return storage of unified virtual memory");
     m.def("uvm_storage_to_cuda", &uvm_storage_to_cuda, "share uvm storage with another cuda device");
     m.def("uvm_storage_to_cpu", &uvm_storage_to_cpu, "share uvm storage with cpu");
     ...
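For context, a module defined via PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) is normally compiled and imported through torch.utils.cpp_extension. The snippet below is a hypothetical build-and-call sketch, not code from this repository; the extension name and source list are assumptions.

# Hypothetical JIT build of the bindings above (name and sources assumed;
# the .cu/.cpp files implementing uvm.h and partition.h would also be needed).
import torch
from torch.utils.cpp_extension import load

ext = load(name="starrygl_csrc", sources=["csrc/export.cpp"], verbose=True)

# After this commit the toy add() binding is gone; the remaining entry points
# are the uvm_storage_* functions, e.g. ext.uvm_storage_new(...) to allocate
# storage backed by CUDA unified virtual memory.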
starrygl/distributed/context.py

@@ -11,7 +11,7 @@ from contextlib import contextmanager
 import logging

-from .rpc import rpc_remote_call, rpc_remote_void_call
+from .rpc import *
 ...
@@ -158,6 +158,9 @@ class DistributedContext:
     def remote_void_call(self, method, rref: rpc.RRef, *args, **kwargs):
         return rpc_remote_void_call(method, rref, *args, **kwargs)

+    def remote_exec(self, method, rref: rpc.RRef, *args, **kwargs):
+        return rpc_remote_exec(method, rref, *args, **kwargs)
+
     @contextmanager
     def use_stream(self, stream: torch.cuda.Stream, with_event: bool = True):
         event = torch.cuda.Event() if with_event else None
         ...
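The new remote_exec mirrors remote_call and remote_void_call but routes through the async_execution path added in rpc.py below. The use_stream context manager is only partially visible in this hunk; a hedged usage sketch follows, where the event-based ordering is an assumption about the unshown body.

# Hypothetical use of DistributedContext.use_stream (the diff shows only the
# signature and first line; synchronization behavior is assumed).
import torch
from starrygl.distributed.context import DistributedContext  # import path assumed

ctx = DistributedContext.get_default_context()  # requires prior RPC/process-group init
x = torch.ones(1024, device="cuda")

side = torch.cuda.Stream()
with ctx.use_stream(side, with_event=True):
    y = x * 2.0  # kernels in this block are presumably launched on `side`,
                 # with the recorded event ordering later work after the block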
starrygl/distributed/rpc.py

@@ -7,6 +7,7 @@ from typing import *
 __all__ = [
     "rpc_remote_call",
     "rpc_remote_void_call",
+    "rpc_remote_exec",
 ]

 def rpc_remote_call(method, rref: rpc.RRef, *args, **kwargs):
     ...
@@ -24,3 +25,12 @@ def rpc_remote_void_call(method, rref: rpc.RRef, *args, **kwargs):
 def rpc_method_void_call(method, rref: rpc.RRef, *args, **kwargs):
     self = rref.local_value()
     method(self, *args, **kwargs)
     # return None

+def rpc_remote_exec(method, rref: rpc.RRef, *args, **kwargs):
+    args = (method, rref) + args
+    return rpc.rpc_async(rref.owner(), rpc_method_exec, args=args, kwargs=kwargs)
+
+@rpc.functions.async_execution
+def rpc_method_exec(method, rref: rpc.RRef, *args, **kwargs):
+    self = rref.local_value()
+    return method(self, *args, **kwargs)
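These nine added lines are the heart of the commit. rpc_remote_exec ships a bound method to the RRef's owner with rpc.rpc_async, and because rpc_method_exec is decorated with @rpc.functions.async_execution, the target may return a torch.futures.Future; the RPC reply is then sent only when that future completes, without tying up an RPC worker thread in the meantime. A minimal self-contained sketch of the same pattern (worker names and setup are illustrative assumptions, not from this repository):

# Minimal sketch of the async_execution round trip (assumes two workers,
# with rpc.init_rpc already called on each as "worker0"/"worker1").
import torch
import torch.distributed.rpc as rpc

@rpc.functions.async_execution
def exec_on_owner(method, rref, *args, **kwargs):
    # The decorated target must return a Future; the RPC reply is sent
    # only once that future is marked complete.
    fut = torch.futures.Future()
    fut.set_result(method(rref.local_value(), *args, **kwargs))
    return fut

def remote_exec(method, rref, *args, **kwargs):
    return rpc.rpc_async(rref.owner(), exec_on_owner,
                         args=(method, rref) + args, kwargs=kwargs)

# On worker0:
#   rref = rpc.remote("worker1", torch.ones, args=(4,))
#   fut = remote_exec(torch.Tensor.sum, rref)
#   print(fut.wait())  # tensor(4.)

Note one difference: the repository's rpc_method_exec returns method(self, ...) directly, so the invoked method itself must produce the Future. That is exactly what the new TensorAccessor._index_* helpers in utils.py below do.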
starrygl/distributed/utils.py

@@ -17,7 +17,6 @@ class TensorAccessor:
         self._data = data
         self._ctx = DistributedContext.get_default_context()
         self._rref = rpc.RRef(data)
-        self._rref.confirmed_by_owner

     @property
     def data(self):
         ...
@@ -30,20 +29,6 @@ class TensorAccessor:
     @property
     def ctx(self):
         return self._ctx

-    # def all_gather_index(self,index,input_split) -> Tensor:
-    #     out_split = torch.empty_like(input_split)
-    #     torch.distributed.all_to_all_single(out_split,input_split)
-    #     input_split = list(input_split)
-    #     output = torch.empty([out_split.sum()],dtype = index.dtype,device = index.device)
-    #     out_split = list(out_split)
-    #     torch.distributed.all_to_all_single(output,index,out_split,input_split)
-    #     return output,out_split,input_split
-    # def all_gather_data(self,index,input_split,out_split):
-    #     output = torch.empty([int(Tensor(out_split).sum().item()),*self._data.shape[1:]],dtype = self._data.dtype,device = 'cuda')#self._data.device)
-    #     torch.distributed.all_to_all_single(output,self.data[index.to(self.data.device)].to('cuda'),output_split_sizes = out_split,input_split_sizes = input_split)
-    #     return output

     def all_gather_rrefs(self) -> List[rpc.RRef]:
         return self.ctx.all_gather_remote_objects(self.rref)
 ...
@@ -51,32 +36,55 @@ class TensorAccessor:
     def async_index_select(self, dim: int, index: Tensor, rref: Optional[rpc.RRef] = None):
         if rref is None:
             rref = self.rref
-        return self.ctx.remote_call(Tensor.index_select, rref, dim=dim, index=index)
+        return self.ctx.remote_exec(TensorAccessor._index_select, rref, dim=dim, index=index)

     def async_index_copy_(self, dim: int, index: Tensor, source: Tensor, rref: Optional[rpc.RRef] = None):
         if rref is None:
             rref = self.rref
-        return self.ctx.remote_void_call(Tensor.index_copy_, rref, dim=dim, index=index, source=source)
+        return self.ctx.remote_exec(TensorAccessor._index_copy_, rref, dim=dim, index=index, source=source)

     def async_index_add_(self, dim: int, index: Tensor, source: Tensor, rref: Optional[rpc.RRef] = None):
         if rref is None:
             rref = self.rref
-        return self.ctx.remote_void_call(Tensor.index_add_, rref, dim=dim, index=index, source=source)
+        return self.ctx.remote_exec(TensorAccessor._index_add_, rref, dim=dim, index=index, source=source)

     def async_index_fill_(self, dim: int, index: Tensor, value: Number, rref: Optional[rpc.RRef] = None):
         if rref is None:
             rref = self.rref
         return self.ctx.remote_void_call(Tensor.index_fill_, rref, dim=dim, index=index, value=value)

     def async_fill_(self, value: Number, rref: Optional[rpc.RRef] = None):
         if rref is None:
             rref = self.rref
         return self.ctx.remote_void_call(Tensor.fill_, rref, value=value)

     def async_zero_(self, rref: Optional[rpc.RRef] = None):
         if rref is None:
             rref = self.rref
         self.ctx.remote_void_call(Tensor.zero_, rref)

+    @staticmethod
+    def _index_select(data: Tensor, dim: int, index: Tensor):
+        stream = TensorAccessor.get_stream()
+        with torch.cuda.stream(stream):
+            data = data.index_select(dim, index)
+        fut = torch.futures.Future()
+        fut.set_result(data)
+        return fut
+
+    @staticmethod
+    def _index_copy_(data: Tensor, dim: int, index: Tensor, source: Tensor):
+        stream = TensorAccessor.get_stream()
+        with torch.cuda.stream(stream):
+            data.index_copy_(dim, index, source)
+        fut = torch.futures.Future()
+        fut.set_result(None)
+        return fut
+
+    @staticmethod
+    def _index_add_(data: Tensor, dim: int, index: Tensor, source: Tensor):
+        stream = TensorAccessor.get_stream()
+        with torch.cuda.stream(stream):
+            data.index_add_(dim, index, source)
+        fut = torch.futures.Future()
+        fut.set_result(None)
+        return fut
+
+    @staticmethod
+    def get_stream() -> Optional[torch.cuda.Stream]:
+        global _TENSOR_ACCESSOR_STREAM
+        if not torch.cuda.is_available():  # no CUDA: run on the default stream
+            return None
+        if _TENSOR_ACCESSOR_STREAM is None:
+            _TENSOR_ACCESSOR_STREAM = torch.cuda.Stream()
+        return _TENSOR_ACCESSOR_STREAM
+
+_TENSOR_ACCESSOR_STREAM: Optional[torch.cuda.Stream] = None

 class DistInt:
     ...
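Each new static helper launches its work on a shared side stream (lazily created by get_stream) and hands back an already-completed torch.futures.Future, satisfying the async_execution contract from rpc.py. Note that set_result fires as soon as the kernel is enqueued, so consumers on other streams presumably rely on later synchronization. A hedged caller-side sketch follows; it assumes rpc.init_rpc has run and that TensorAccessor is constructed from a plain tensor, as its __init__ suggests.

# Hypothetical single-process view of the new API (setup is illustrative).
import torch

acc = TensorAccessor(torch.arange(6.0))               # wraps the local shard
fut = acc.async_index_select(0, torch.tensor([0, 2, 4]))
print(fut.wait())                                     # tensor([0., 2., 4.])

# In-place variants resolve to None once the remote op is enqueued:
acc.async_index_copy_(0, torch.tensor([1]), torch.tensor([9.0])).wait()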
@@ -228,10 +236,6 @@ class DistributedTensor:
         futs: List[torch.futures.Future] = []
         for i in range(self.num_parts()):
-            #if i != torch.distributed.get_rank():
-            #    continue
-            #f = torch.futures.Future()
-            #f.set_result(self.accessor.data[index[part_idx == i]])
             f = self.accessor.async_index_select(0, index[part_idx == i], self.rrefs[i])
             futs.append(f)
         ...
@@ -275,17 +279,4 @@ class DistributedTensor:
             f = self.accessor.async_index_add_(0, index[mask], source[mask], self.rrefs[i])
             futs.append(f)
         return torch.futures.collect_all(futs)

-    def index_fill_(self, dist_index: Union[Tensor, DistIndex], value: Number):
-        if isinstance(dist_index, Tensor):
-            dist_index = DistIndex(dist_index)
-        part_idx = dist_index.part
-        index = dist_index.loc
-        futs: List[torch.futures.Future] = []
-        for i in range(self.num_parts()):
-            mask = part_idx == i
-            f = self.accessor.async_index_fill_(0, index[mask], value, self.rrefs[i])
-            futs.append(f)
-        return torch.futures.collect_all(futs)
\ No newline at end of file
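The DistributedTensor methods fan one async call out per partition and return torch.futures.collect_all(futs), whose wait() yields the per-partition futures in rrefs order. A hedged sketch of consuming such a result (the dist_tensor and dist_index names are placeholders, and reassembly into the original index order is an assumption about code outside this hunk):

# Hypothetical consumer of a DistributedTensor.index_select result.
import torch

cfut = dist_tensor.index_select(dist_index)   # collect_all over all partitions
chunks = [f.value() for f in cfut.wait()]     # one tensor per partition
gathered = torch.cat(chunks, dim=0)           # grouped by partition; restoring
                                              # dist_index order happens elsewhere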