zhlj / BTS-MTGNN / Commits / 66045271

Commit 66045271 authored Dec 19, 2023 by Wenjie Huang
add DistTensor.all_to_all_[get|set]()
parent 510a41f8
Showing 4 changed files with 105 additions and 103 deletions

    starrygl/distributed/cclib.py      +6   -0
    starrygl/distributed/context.py    +0   -71
    starrygl/distributed/rpc.py        +4   -0
    starrygl/distributed/utils.py      +95  -32
starrygl/distributed/cclib.py  (+6 -0)

...
@@ -4,6 +4,12 @@ import torch.distributed as dist
 from torch import Tensor
 from typing import *
 
+__all__ = [
+    "all_to_all_v",
+    "all_to_all_s",
+    "BatchWork",
+]
+
 class BatchWork:
     def __init__(self, works, buffer_tensor_list) -> None:
...
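Note: all_to_all_v and all_to_all_s are the collectives consumed by utils.py below. As a point of reference only, a minimal sketch of what an all_to_all_s-style exchange does, assuming the rowptr convention visible in the removed context.py wrappers (output_rowptr/input_rowptr are per-rank slice boundaries along dim 0 of flat tensors), could look like the following; this is an illustration, not the project's implementation:

    import torch
    import torch.distributed as dist

    def all_to_all_s_sketch(output_tensor, input_tensor, output_rowptr, input_rowptr, group=None):
        # Per-rank split sizes are the differences between consecutive rowptr boundaries.
        output_split_sizes = [b - a for a, b in zip(output_rowptr[:-1], output_rowptr[1:])]
        input_split_sizes = [b - a for a, b in zip(input_rowptr[:-1], input_rowptr[1:])]
        # One fused all-to-all over the flat dim-0 layout (needs a backend with all_to_all support, e.g. NCCL).
        dist.all_to_all_single(output_tensor, input_tensor,
                               output_split_sizes=output_split_sizes,
                               input_split_sizes=input_split_sizes,
                               group=group)
        return output_tensor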
starrygl/distributed/context.py  (+0 -71)

...
@@ -11,7 +11,6 @@ from contextlib import contextmanager
 import logging
 
-from .cclib import all_to_all_v, all_to_all_s
 from .rpc import rpc_remote_call, rpc_remote_void_call
...
@@ -159,76 +158,6 @@ class DistributedContext:
     def remote_void_call(self, method, rref: rpc.RRef, *args, **kwargs):
         return rpc_remote_void_call(method, rref, *args, **kwargs)
 
-    # def all_to_all_v(self,
-    #     output_tensor_list: List[Tensor],
-    #     input_tensor_list: List[Tensor],
-    #     group: Any = None,
-    #     async_op: bool = False,
-    # ):
-    #     return all_to_all_v(
-    #         output_tensor_list,
-    #         input_tensor_list,
-    #         group=group,
-    #         async_op=async_op,
-    #     )
-
-    # def all_to_all_g(self,
-    #     input_tensor_list: List[Tensor],
-    #     group: Any = None,
-    #     async_op: bool = False,
-    # ):
-    #     send_sizes = [t.size(0) for t in input_tensor_list]
-    #     recv_sizes = self.get_all_to_all_recv_sizes(send_sizes, group)
-    #     output_tensor_list: List[Tensor] = []
-    #     for s, t in zip(recv_sizes, input_tensor_list):
-    #         output_tensor_list.append(
-    #             torch.empty(s, *t.shape[1:], dtype=t.dtype, device=t.device),
-    #         )
-    #     work = all_to_all_v(
-    #         output_tensor_list,
-    #         input_tensor_list,
-    #         group=group,
-    #         async_op=async_op,
-    #     )
-    #     if async_op:
-    #         assert work is not None
-    #         return output_tensor_list, work
-    #     else:
-    #         return output_tensor_list
-
-    # def all_to_all_s(self,
-    #     output_tensor: Tensor,
-    #     input_tensor: Tensor,
-    #     output_rowptr: List[int],
-    #     input_rowptr: List[int],
-    #     group: Any = None,
-    #     async_op: bool = False,
-    # ):
-    #     return all_to_all_s(
-    #         output_tensor, input_tensor,
-    #         output_rowptr, input_rowptr,
-    #         group=group, async_op=async_op,
-    #     )
-
-    # def get_all_to_all_recv_sizes(self,
-    #     send_sizes: List[int],
-    #     group: Optional[Any] = None,
-    # ) -> List[int]:
-    #     world_size = dist.get_world_size(group)
-    #     assert len(send_sizes) == world_size
-    #     if dist.get_backend(group) == "gloo":
-    #         send_t = torch.tensor(send_sizes, dtype=torch.long)
-    #     else:
-    #         send_t = torch.tensor(send_sizes, dtype=torch.long, device=self.device)
-    #     recv_t = torch.empty_like(send_t)
-    #     dist.all_to_all_single(recv_t, send_t, group=group)
-    #     return recv_t.tolist()
 
     @contextmanager
     def use_stream(self, stream: torch.cuda.Stream, with_event: bool = True):
         event = torch.cuda.Event() if with_event else None
...
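Note: only the first body line of use_stream() is visible in this hunk. A generic sketch of such a CUDA stream context manager, written here as a hypothetical completion rather than the repository's code, might be:

    import torch
    from contextlib import contextmanager

    @contextmanager
    def use_stream_sketch(stream: torch.cuda.Stream, with_event: bool = True):
        # Optional event lets callers later synchronize with work issued on `stream`.
        event = torch.cuda.Event() if with_event else None
        with torch.cuda.stream(stream):
            yield event
        if event is not None:
            event.record(stream)  # marks the end of the work enqueued inside the block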
starrygl/distributed/rpc.py  (+4 -0)

...
@@ -4,6 +4,10 @@ import torch.distributed.rpc as rpc
 from torch import Tensor
 from typing import *
 
+__all__ = [
+    "rpc_remote_call",
+    "rpc_remote_void_call",
+]
 
 def rpc_remote_call(method, rref: rpc.RRef, *args, **kwargs):
     args = (method, rref) + args
...
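Note: only the first body line of rpc_remote_call() appears in the hunk. The (method, rref) prefix suggests a dispatcher pattern; a generic, hypothetical sketch of that pattern (the _invoke helper below is illustrative and not part of the repository) is:

    import torch.distributed.rpc as rpc

    def _invoke(method, rref, *args, **kwargs):
        # Runs on the RRef's owner: resolve the local object and call the method on it.
        return method(rref.local_value(), *args, **kwargs)

    def rpc_remote_call_sketch(method, rref: rpc.RRef, *args, **kwargs):
        # Prepend (method, rref) to the positional arguments, as in the diff above,
        # then dispatch asynchronously to whichever worker owns the RRef.
        args = (method, rref) + args
        return rpc.rpc_async(rref.owner(), _invoke, args=args, kwargs=kwargs)

    # This matches how utils.py invokes it through the context, e.g.
    # n = ctx.remote_call(Tensor.size, rref, dim=0).wait()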
starrygl/distributed/utils.py  (+95 -32)
-from typing import Any
 import torch
+import torch.distributed as dist
 import torch.distributed.rpc as rpc
 
 from torch import Tensor
 from torch.types import Number
 from typing import *
 
+from torch_sparse import SparseTensor
+from .cclib import all_to_all_s
 
 class TensorAccessor:
...
@@ -29,19 +31,19 @@ class TensorAccessor:
     def ctx(self):
         return self._ctx
 
-    def all_gather_index(self,index,input_split) -> Tensor:
-        out_split = torch.empty_like(input_split)
-        torch.distributed.all_to_all_single(out_split,input_split)
-        input_split = list(input_split)
-        output = torch.empty([out_split.sum()],dtype = index.dtype,device = index.device)
-        out_split = list(out_split)
-        torch.distributed.all_to_all_single(output,index,out_split,input_split)
-        return output,out_split,input_split
-
-    def all_gather_data(self,index,input_split,out_split):
-        output = torch.empty([int(Tensor(out_split).sum().item()),*self._data.shape[1:]],dtype = self._data.dtype,device = 'cuda')#self._data.device)
-        torch.distributed.all_to_all_single(output,self.data[index.to(self.data.device)].to('cuda'),output_split_sizes = out_split,input_split_sizes = input_split)
-        return output
+    # def all_gather_index(self,index,input_split) -> Tensor:
+    #     out_split = torch.empty_like(input_split)
+    #     torch.distributed.all_to_all_single(out_split,input_split)
+    #     input_split = list(input_split)
+    #     output = torch.empty([out_split.sum()],dtype = index.dtype,device = index.device)
+    #     out_split = list(out_split)
+    #     torch.distributed.all_to_all_single(output,index,out_split,input_split)
+    #     return output,out_split,input_split
+
+    # def all_gather_data(self,index,input_split,out_split):
+    #     output = torch.empty([int(Tensor(out_split).sum().item()),*self._data.shape[1:]],dtype = self._data.dtype,device = 'cuda')#self._data.device)
+    #     torch.distributed.all_to_all_single(output,self.data[index.to(self.data.device)].to('cuda'),output_split_sizes = out_split,input_split_sizes = input_split)
+    #     return output
 
     def all_gather_rrefs(self) -> List[rpc.RRef]:
         return self.ctx.all_gather_remote_objects(self.rref)
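Note: both the commented-out helpers above and the new all_to_all_ind2ptr() below follow the same two-phase pattern: first exchange per-rank row counts, then exchange the payload with explicit split sizes. A standalone sketch of that pattern using only torch.distributed primitives (illustrative; assumes an initialized process group whose backend supports all_to_all_single, e.g. NCCL):

    import torch
    import torch.distributed as dist

    def exchange_by_splits(payload: torch.Tensor, send_sizes: list):
        # Phase 1: each rank announces how many rows it will send to every other rank.
        send_t = torch.tensor(send_sizes, dtype=torch.long, device=payload.device)
        recv_t = torch.empty_like(send_t)
        dist.all_to_all_single(recv_t, send_t)
        recv_sizes = recv_t.tolist()
        # Phase 2: exchange the rows themselves using the agreed split sizes.
        output = torch.empty(sum(recv_sizes), *payload.shape[1:],
                             dtype=payload.dtype, device=payload.device)
        dist.all_to_all_single(output, payload,
                               output_split_sizes=recv_sizes,
                               input_split_sizes=send_sizes)
        return output, recv_sizes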
...
@@ -92,40 +94,46 @@ class DistInt:
 class DistIndex:
     def __init__(self, index: Tensor, part_ids: Optional[Tensor] = None) -> None:
         if part_ids is None:
-            self.data = index.long()
+            self._data = index.long()
         else:
             index, part_ids = index.long(), part_ids.long()
-            self.data = (index & 0xFFFFFFFFFFFF) | ((part_ids & 0xFFFF) << 48)
+            self._data = (index & 0xFFFFFFFFFFFF) | ((part_ids & 0xFFFF) << 48)
 
     @property
     def loc(self) -> Tensor:
-        return self.data & 0xFFFFFFFFFFFF
+        return self._data & 0xFFFFFFFFFFFF
 
     @property
     def part(self) -> Tensor:
-        return (self.data >> 48).int() & 0xFFFF
+        return (self._data >> 48) & 0xFFFF
 
     @property
     def dist(self) -> Tensor:
-        return self.data
+        return self._data
 
+    @property
+    def dtype(self):
+        return self._data.dtype
+
+    @property
+    def device(self):
+        return self._data.device
+
     def to(self, device) -> Tensor:
-        return DistIndex(self.data.to(device))
+        return DistIndex(self._data.to(device))
 
 class DistributedTensor:
     def __init__(self, data: Tensor) -> None:
         self.accessor = TensorAccessor(data)
         self.rrefs = self.accessor.all_gather_rrefs()
-        # self.num_parts = len(self.rrefs)
 
         local_sizes = []
         for rref in self.rrefs:
             n = self.ctx.remote_call(Tensor.size, rref, dim=0).wait()
             local_sizes.append(n)
-        self.num_nodes = DistInt(local_sizes)
-        self.num_parts = DistInt([1] * len(self.rrefs))
-        self.distptr = torch.tensor([((part_ids & 0xFFFF) << 48) for part_ids in range(self.num_parts()+1)], device='cuda') #data.device)
+        self._num_nodes = DistInt(local_sizes)
+        self._num_parts = DistInt([1] * len(self.rrefs))
 
     @property
     def dtype(self):
...
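Note: DistIndex packs a partition id and a local index into one 64-bit integer: the low 48 bits hold the local position and the next 16 bits the owning partition, which is exactly what .loc and .part unpack above. A small standalone example mirroring those masks:

    import torch

    local = torch.tensor([7, 42], dtype=torch.long)   # local row ids
    parts = torch.tensor([0, 3], dtype=torch.long)    # owning partitions

    packed = (local & 0xFFFFFFFFFFFF) | ((parts & 0xFFFF) << 48)
    assert torch.equal(packed & 0xFFFFFFFFFFFF, local)   # same recovery as DistIndex.loc
    assert torch.equal((packed >> 48) & 0xFFFF, parts)   # same recovery as DistIndex.part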
@@ -135,6 +143,14 @@ class DistributedTensor:
     def device(self):
         return self.accessor.data.device
 
+    @property
+    def num_nodes(self) -> DistInt:
+        return self._num_nodes
+
+    @property
+    def num_parts(self) -> DistInt:
+        return self._num_parts
+
     def to(self, device):
         return self.accessor.data.to(device)
...
@@ -145,16 +161,63 @@ class DistributedTensor:
     def ctx(self):
         return self.accessor.ctx
 
-    def gather_select_index(self, dist_index: Union[Tensor, DistIndex]):
+    def all_to_all_ind2ptr(self, dist_index: Union[Tensor, DistIndex]) -> Dict[str, Union[List[int], Tensor]]:
         if isinstance(dist_index, Tensor):
             dist_index = DistIndex(dist_index)
-        data = dist_index.data
-        posptr = torch.searchsorted(data, self.distptr, right=False)
-        input_split = posptr[1:] - posptr[:-1]
-        return self.accessor.all_gather_index(DistIndex(data).loc, input_split)
-
-    def scatter_data(self, local_index, input_split, out_split):
-        return self.accessor.all_gather_data(local_index, input_split=input_split, out_split=out_split)
+        send_ptr = torch.ops.torch_sparse.ind2ptr(dist_index.part, self.num_parts())
+        send_sizes = send_ptr[1:] - send_ptr[:-1]
+
+        recv_sizes = torch.empty_like(send_sizes)
+        dist.all_to_all_single(recv_sizes, send_sizes)
+
+        recv_ptr = torch.zeros(recv_sizes.numel() + 1).type_as(recv_sizes)
+        recv_ptr[1:] = recv_sizes.cumsum(dim=0)
+
+        send_ptr = send_ptr.tolist()
+        recv_ptr = recv_ptr.tolist()
+
+        recv_ind = torch.full((recv_ptr[-1],), (2**62-1)*2+1, dtype=dist_index.dtype, device=dist_index.device)
+        all_to_all_s(recv_ind, dist_index.loc, send_ptr, recv_ptr)
+
+        return {
+            "send_ptr": send_ptr,
+            "recv_ptr": recv_ptr,
+            "recv_ind": recv_ind,
+        }
+
+    def all_to_all_get(self,
+        dist_index: Union[Tensor, DistIndex, None] = None,
+        send_ptr: Optional[List[int]] = None,
+        recv_ptr: Optional[List[int]] = None,
+        recv_ind: Optional[List[int]] = None,
+    ) -> Tensor:
+        if dist_index is not None:
+            dist_dict = self.all_to_all_ind2ptr(dist_index)
+            send_ptr = dist_dict["send_ptr"]
+            recv_ptr = dist_dict["recv_ptr"]
+            recv_ind = dist_dict["recv_ind"]
+        data = self.accessor.data[recv_ind]
+        recv = torch.empty(send_ptr[-1], *data.shape[1:], dtype=data.dtype, device=data.device)
+        all_to_all_s(recv, data, send_ptr, recv_ptr)
+        return recv
+
+    def all_to_all_set(self,
+        data: Tensor,
+        dist_index: Union[Tensor, DistIndex, None] = None,
+        send_ptr: Optional[List[int]] = None,
+        recv_ptr: Optional[List[int]] = None,
+        recv_ind: Optional[List[int]] = None,
+    ):
+        if dist_index is not None:
+            dist_dict = self.all_to_all_ind2ptr(dist_index)
+            send_ptr = dist_dict["send_ptr"]
+            recv_ptr = dist_dict["recv_ptr"]
+            recv_ind = dist_dict["recv_ind"]
+        recv = torch.empty(recv_ptr[-1], *data.shape[1:], dtype=data.dtype, device=data.device)
+        all_to_all_s(recv, data, recv_ptr, send_ptr)
+        self.accessor.data.index_copy_(0, recv_ind, recv)
 
     def index_select(self, dist_index: Union[Tensor, DistIndex]):
         if isinstance(dist_index, Tensor):
...
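Note: a usage sketch of the API added by this commit, assuming an initialized process group plus RPC, a local shard local_data, and a dist_index tensor whose upper 16 bits encode the owning partition (the setup names here are illustrative, not from the repository):

    dtensor = DistributedTensor(local_data)   # wraps this rank's shard

    # Resolve the routing once, then reuse it for repeated reads/writes.
    route = dtensor.all_to_all_ind2ptr(dist_index)
    feats = dtensor.all_to_all_get(send_ptr=route["send_ptr"],
                                   recv_ptr=route["recv_ptr"],
                                   recv_ind=route["recv_ind"])   # pull the addressed rows
    dtensor.all_to_all_set(feats, **route)                       # push rows back to their owners

    # Or let each call compute the routing itself:
    feats = dtensor.all_to_all_get(dist_index)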