zhlj / BTS-MTGNN / Commits

Commit 66045271, authored Dec 19, 2023 by Wenjie Huang
add DistTensor.all_to_all_[get|set]()

parent 510a41f8
Showing 4 changed files with 105 additions and 103 deletions:

    starrygl/distributed/cclib.py      +6    -0
    starrygl/distributed/context.py    +0   -71
    starrygl/distributed/rpc.py        +4    -0
    starrygl/distributed/utils.py     +95   -32
starrygl/distributed/cclib.py

...
@@ -4,6 +4,12 @@ import torch.distributed as dist
from torch import Tensor
from typing import *

__all__ = [
    "all_to_all_v",
    "all_to_all_s",
    "BatchWork",
]

class BatchWork:
    def __init__(self, works, buffer_tensor_list) -> None:
...
starrygl/distributed/context.py

...
@@ -11,7 +11,6 @@ from contextlib import contextmanager
import logging
from .cclib import all_to_all_v, all_to_all_s
from .rpc import rpc_remote_call, rpc_remote_void_call
...
@@ -159,76 +158,6 @@ class DistributedContext:
    def remote_void_call(self, method, rref: rpc.RRef, *args, **kwargs):
        return rpc_remote_void_call(method, rref, *args, **kwargs)
# def all_to_all_v(self,
# output_tensor_list: List[Tensor],
# input_tensor_list: List[Tensor],
# group: Any = None,
# async_op: bool = False,
# ):
# return all_to_all_v(
# output_tensor_list,
# input_tensor_list,
# group=group,
# async_op=async_op,
# )
# def all_to_all_g(self,
# input_tensor_list: List[Tensor],
# group: Any = None,
# async_op: bool = False,
# ):
# send_sizes = [t.size(0) for t in input_tensor_list]
# recv_sizes = self.get_all_to_all_recv_sizes(send_sizes, group)
# output_tensor_list: List[Tensor] = []
# for s, t in zip(recv_sizes, input_tensor_list):
# output_tensor_list.append(
# torch.empty(s, *t.shape[1:], dtype=t.dtype, device=t.device),
# )
# work = all_to_all_v(
# output_tensor_list,
# input_tensor_list,
# group=group,
# async_op=async_op,
# )
# if async_op:
# assert work is not None
# return output_tensor_list, work
# else:
# return output_tensor_list
# def all_to_all_s(self,
# output_tensor: Tensor,
# input_tensor: Tensor,
# output_rowptr: List[int],
# input_rowptr: List[int],
# group: Any = None,
# async_op: bool = False,
# ):
# return all_to_all_s(
# output_tensor, input_tensor,
# output_rowptr, input_rowptr,
# group=group, async_op=async_op,
# )
# def get_all_to_all_recv_sizes(self,
# send_sizes: List[int],
# group: Optional[Any] = None,
# ) -> List[int]:
# world_size = dist.get_world_size(group)
# assert len(send_sizes) == world_size
# if dist.get_backend(group) == "gloo":
# send_t = torch.tensor(send_sizes, dtype=torch.long)
# else:
# send_t = torch.tensor(send_sizes, dtype=torch.long, device=self.device)
# recv_t = torch.empty_like(send_t)
# dist.all_to_all_single(recv_t, send_t, group=group)
# return recv_t.tolist()
    @contextmanager
    def use_stream(self, stream: torch.cuda.Stream, with_event: bool = True):
        event = torch.cuda.Event() if with_event else None
...
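The helpers deleted above (already commented out) spell out the size-exchange step that precedes any variable-length all-to-all: each rank first learns how many rows it will receive from every peer, then allocates its receive buffers. A minimal standalone sketch of that step, assuming an initialized default process group; the function name is illustrative and not part of this commit:

import torch
import torch.distributed as dist

def exchange_recv_sizes(send_sizes):
    # send_sizes[i]: number of rows this rank will send to rank i
    send_t = torch.tensor(send_sizes, dtype=torch.long)
    recv_t = torch.empty_like(send_t)
    dist.all_to_all_single(recv_t, send_t)  # recv_t[i]: rows rank i will send to this rank
    return recv_t.tolist()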
starrygl/distributed/rpc.py

...
@@ -4,6 +4,10 @@ import torch.distributed.rpc as rpc
from torch import Tensor
from typing import *

__all__ = [
    "rpc_remote_call",
    "rpc_remote_void_call",
]

def rpc_remote_call(method, rref: rpc.RRef, *args, **kwargs):
    args = (method, rref) + args
...
starrygl/distributed/utils.py

from typing import Any
import torch
import torch.distributed as dist
import torch.distributed.rpc as rpc
from torch import Tensor
from torch.types import Number
from typing import *
from torch_sparse import SparseTensor
from .cclib import all_to_all_s

class TensorAccessor:
...
@@ -29,19 +31,19 @@ class TensorAccessor:
    def ctx(self):
        return self._ctx

    def all_gather_index(self, index, input_split) -> Tensor:
        out_split = torch.empty_like(input_split)
        torch.distributed.all_to_all_single(out_split, input_split)
        input_split = list(input_split)
        output = torch.empty([out_split.sum()], dtype=index.dtype, device=index.device)
        out_split = list(out_split)
        torch.distributed.all_to_all_single(output, index, out_split, input_split)
        return output, out_split, input_split
    # def all_gather_index(self,index,input_split) -> Tensor:
    #     out_split = torch.empty_like(input_split)
    #     torch.distributed.all_to_all_single(out_split,input_split)
    #     input_split = list(input_split)
    #     output = torch.empty([out_split.sum()],dtype = index.dtype,device = index.device)
    #     out_split = list(out_split)
    #     torch.distributed.all_to_all_single(output,index,out_split,input_split)
    #     return output,out_split,input_split

    def all_gather_data(self, index, input_split, out_split):
        output = torch.empty([int(Tensor(out_split).sum().item()), *self._data.shape[1:]], dtype=self._data.dtype, device='cuda')  #self._data.device)
        torch.distributed.all_to_all_single(output, self.data[index.to(self.data.device)].to('cuda'), output_split_sizes=out_split, input_split_sizes=input_split)
        return output
    # def all_gather_data(self,index,input_split,out_split):
    #     output = torch.empty([int(Tensor(out_split).sum().item()),*self._data.shape[1:]],dtype = self._data.dtype,device = 'cuda')#self._data.device)
    #     torch.distributed.all_to_all_single(output,self.data[index.to(self.data.device)].to('cuda'),output_split_sizes = out_split,input_split_sizes = input_split)
    #     return output

    def all_gather_rrefs(self) -> List[rpc.RRef]:
        return self.ctx.all_gather_remote_objects(self.rref)
...
@@ -92,40 +94,46 @@ class DistInt:
class DistIndex:
    def __init__(self, index: Tensor, part_ids: Optional[Tensor] = None) -> None:
        if part_ids is None:
            self.data = index.long()
            self._data = index.long()
        else:
            index, part_ids = index.long(), part_ids.long()
            self.data = (index & 0xFFFFFFFFFFFF) | ((part_ids & 0xFFFF) << 48)
            self._data = (index & 0xFFFFFFFFFFFF) | ((part_ids & 0xFFFF) << 48)

    @property
    def loc(self) -> Tensor:
        return self.data & 0xFFFFFFFFFFFF
        return self._data & 0xFFFFFFFFFFFF

    @property
    def part(self) -> Tensor:
        return (self.data >> 48).int() & 0xFFFF
        return (self._data >> 48) & 0xFFFF

    @property
    def dist(self) -> Tensor:
        return self.data
        return self._data

    @property
    def dtype(self):
        return self._data.dtype

    @property
    def device(self):
        return self._data.device

    def to(self, device) -> Tensor:
        return DistIndex(self.data.to(device))
        return DistIndex(self._data.to(device))
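DistIndex packs a global identifier into a single 64-bit value: the low 48 bits hold the local row index and the high 16 bits hold the owning partition id, which is exactly what the loc and part properties above mask out. A small worked example of that packing, independent of the class itself (illustrative only, not part of the diff):

import torch

index    = torch.tensor([7, 42])   # local row ids on their owning partitions
part_ids = torch.tensor([0, 3])    # owning partition of each row

packed = (index & 0xFFFFFFFFFFFF) | ((part_ids & 0xFFFF) << 48)

assert torch.equal(packed & 0xFFFFFFFFFFFF, index)  # what DistIndex.loc returns
assert torch.equal(packed >> 48, part_ids)          # DistIndex.part, before the 0xFFFF mask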
class DistributedTensor:
    def __init__(self, data: Tensor) -> None:
        self.accessor = TensorAccessor(data)
        self.rrefs = self.accessor.all_gather_rrefs()
        # self.num_parts = len(self.rrefs)
        local_sizes = []
        for rref in self.rrefs:
            n = self.ctx.remote_call(Tensor.size, rref, dim=0).wait()
            local_sizes.append(n)
        self.num_nodes = DistInt(local_sizes)
        self.num_parts = DistInt([1] * len(self.rrefs))
        self.distptr = torch.tensor([((part_ids & 0xFFFF) << 48) for part_ids in range(self.num_parts() + 1)], device='cuda')  #data.device)
        self._num_nodes = DistInt(local_sizes)
        self._num_parts = DistInt([1] * len(self.rrefs))

    @property
    def dtype(self):
...
@@ -135,6 +143,14 @@ class DistributedTensor:
    def device(self):
        return self.accessor.data.device

    @property
    def num_nodes(self) -> DistInt:
        return self._num_nodes

    @property
    def num_parts(self) -> DistInt:
        return self._num_parts

    def to(self, device):
        return self.accessor.data.to(device)
...
@@ -145,16 +161,63 @@ class DistributedTensor:
    def ctx(self):
        return self.accessor.ctx

    def gather_select_index(self, dist_index: Union[Tensor, DistIndex]):
    def all_to_all_ind2ptr(self, dist_index: Union[Tensor, DistIndex]) -> Dict[str, Union[List[int], Tensor]]:
        if isinstance(dist_index, Tensor):
            dist_index = DistIndex(dist_index)
        data = dist_index.data
        posptr = torch.searchsorted(data, self.distptr, right=False)
        input_split = posptr[1:] - posptr[:-1]
        return self.accessor.all_gather_index(DistIndex(data).loc, input_split)
    def scatter_data(self, local_index, input_split, out_split):
        return self.accessor.all_gather_data(local_index, input_split=input_split, out_split=out_split)
        send_ptr = torch.ops.torch_sparse.ind2ptr(dist_index.part, self.num_parts())
        send_sizes = send_ptr[1:] - send_ptr[:-1]
        recv_sizes = torch.empty_like(send_sizes)
        dist.all_to_all_single(recv_sizes, send_sizes)
        recv_ptr = torch.zeros(recv_sizes.numel() + 1).type_as(recv_sizes)
        recv_ptr[1:] = recv_sizes.cumsum(dim=0)
        send_ptr = send_ptr.tolist()
        recv_ptr = recv_ptr.tolist()
        recv_ind = torch.full((recv_ptr[-1],), (2**62 - 1) * 2 + 1, dtype=dist_index.dtype, device=dist_index.device)
        all_to_all_s(recv_ind, dist_index.loc, send_ptr, recv_ptr)
        return {
            "send_ptr": send_ptr,
            "recv_ptr": recv_ptr,
            "recv_ind": recv_ind,
        }

    def all_to_all_get(self,
        dist_index: Union[Tensor, DistIndex, None] = None,
        send_ptr: Optional[List[int]] = None,
        recv_ptr: Optional[List[int]] = None,
        recv_ind: Optional[List[int]] = None,
    ) -> Tensor:
        if dist_index is not None:
            dist_dict = self.all_to_all_ind2ptr(dist_index)
            send_ptr = dist_dict["send_ptr"]
            recv_ptr = dist_dict["recv_ptr"]
            recv_ind = dist_dict["recv_ind"]
        data = self.accessor.data[recv_ind]
        recv = torch.empty(send_ptr[-1], *data.shape[1:], dtype=data.dtype, device=data.device)
        all_to_all_s(recv, data, send_ptr, recv_ptr)
        return recv

    def all_to_all_set(self,
        data: Tensor,
        dist_index: Union[Tensor, DistIndex, None] = None,
        send_ptr: Optional[List[int]] = None,
        recv_ptr: Optional[List[int]] = None,
        recv_ind: Optional[List[int]] = None,
    ):
        if dist_index is not None:
            dist_dict = self.all_to_all_ind2ptr(dist_index)
            send_ptr = dist_dict["send_ptr"]
            recv_ptr = dist_dict["recv_ptr"]
            recv_ind = dist_dict["recv_ind"]
        recv = torch.empty(recv_ptr[-1], *data.shape[1:], dtype=data.dtype, device=data.device)
        all_to_all_s(recv, data, recv_ptr, send_ptr)
        self.accessor.data.index_copy_(0, recv_ind, recv)

    def index_select(self, dist_index: Union[Tensor, DistIndex]):
        if isinstance(dist_index, Tensor):
...
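The new all_to_all_ind2ptr / all_to_all_get pair above follows a two-phase pattern: first exchange the requested local indices (bucketed per owning rank), then exchange the corresponding data rows back. A rough standalone sketch of the same gather pattern, written directly against torch.distributed rather than this repository's API; it assumes an initialized process group whose backend supports all_to_all_single, and all names are illustrative:

import torch
import torch.distributed as dist

def gather_remote_rows(local_data, part, loc):
    # part[i]: rank that owns requested row i (sorted, so same-rank requests are contiguous)
    # loc[i]:  that row's local index on its owning rank
    world_size = dist.get_world_size()
    send_sizes = torch.bincount(part, minlength=world_size)
    recv_sizes = torch.empty_like(send_sizes)
    dist.all_to_all_single(recv_sizes, send_sizes)  # how many indices each rank asks of us

    # phase 1: ship the requested local indices to their owners
    recv_ind = torch.empty(int(recv_sizes.sum()), dtype=loc.dtype, device=loc.device)
    dist.all_to_all_single(recv_ind, loc,
                           output_split_sizes=recv_sizes.tolist(),
                           input_split_sizes=send_sizes.tolist())

    # phase 2: answer every request with the corresponding local rows
    rows = local_data[recv_ind]
    out = torch.empty(int(send_sizes.sum()), *local_data.shape[1:],
                      dtype=local_data.dtype, device=local_data.device)
    dist.all_to_all_single(out, rows,
                           output_split_sizes=send_sizes.tolist(),
                           input_split_sizes=recv_sizes.tolist())
    return out  # out[i] is the row addressed by (part[i], loc[i])

all_to_all_set in the diff is the mirror image of this gather: the same routing metadata is reused, the rows flow in the opposite direction, and the received rows are written back into the local shard with index_copy_.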