zhlj / starrygl-DynamicHistory · Commits · 6fb76097

Commit 6fb76097, authored Dec 07, 2023 by Wenjie Huang
Parent: 57d75031

Commit message: update

Showing 3 changed files with 161 additions and 7 deletions:

    run_route.py                  +64  -0
    starrygl/graph/__init__.py    +22  -0
    starrygl/graph/route.py       +75  -7
run_route.py  0 → 100644  (new file, view @ 6fb76097)

```python
import torch
import torch.distributed as dist
from torch import Tensor
from typing import *

from starrygl.distributed import DistributedContext
from starrygl.graph import new_vc_route

from torch_scatter import scatter_sum

all_nparts = [
    [0, 1],
    [2, 3, 4],
    [5, 6],
]

all_eparts = [
    [
        [2, 0],
        [1, 0],
        [0, 1],
        [3, 1],
        [4, 1],
    ],
    [
        [0, 2],
        [3, 2],
        [5, 2],
        [1, 3],
        [2, 3],
        [4, 3],
        [6, 3],
        [1, 4],
        [3, 4],
    ],
    [
        [2, 5],
        [6, 5],
        [5, 6],
        [3, 6],
    ],
]

def get_route(bipartite: bool = True):
    ctx = DistributedContext.get_default_context()
    assert ctx.world_size == 3

    dst_ids = torch.tensor(all_nparts[ctx.rank], dtype=torch.long, device=ctx.device)
    edge_index = torch.tensor(all_eparts[ctx.rank], dtype=torch.long, device=ctx.device).t()
    return new_vc_route(dst_ids, edge_index, bipartite=bipartite)

if __name__ == "__main__":
    ctx = DistributedContext.init(backend="gloo", use_gpu=True)
    src_ids, edge_index, dst_ids, route = get_route(False)

    src_size = route.src_len
    dst_size = route.dst_len
    ctx.sync_print(route.src_len, route.dst_len)

    edge_ones = torch.ones(edge_index.size(1), device=ctx.device).requires_grad_()
    src_ones = scatter_sum(edge_ones, edge_index[0], dim=0, dim_size=route.src_len)
    dst_ones = scatter_sum(edge_ones, edge_index[1], dim=0, dim_size=route.dst_len)

    # ctx.sync_print(route.fw_tensor(dst_ones))
    # ctx.sync_print(route.bw_tensor(src_ones))

    out = route.reverse_route().apply(src_ones)
    ctx.sync_print(out)

    out.sum().backward()
    ctx.sync_print(edge_ones.grad)

    ctx.sync_print(route.get_src_part_ids())

    ctx.shutdown()
```
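For reference, a minimal single-process sketch of the degree-count step in the test above, using rank 0's edge partition from `all_eparts`. It only exercises `torch_scatter.scatter_sum` and plain autograd, not the distributed `Route` exchange; the names `local_src_len` / `local_dst_len` are illustrative assumptions rather than part of the StarryGL API.

```python
import torch
from torch_scatter import scatter_sum

# Rank 0's edge partition from run_route.py, transposed to shape [2, E].
edge_index = torch.tensor([[2, 0], [1, 0], [0, 1], [3, 1], [4, 1]], dtype=torch.long).t()

# Hypothetical local sizes; in the real script these come from route.src_len / route.dst_len.
local_src_len = int(edge_index[0].max()) + 1
local_dst_len = int(edge_index[1].max()) + 1

edge_ones = torch.ones(edge_index.size(1), requires_grad=True)

# Summing one unit per edge by its endpoint index yields out-/in-degree counts.
src_deg = scatter_sum(edge_ones, edge_index[0], dim=0, dim_size=local_src_len)
dst_deg = scatter_sum(edge_ones, edge_index[1], dim=0, dim_size=local_dst_len)

# Each edge contributes exactly once, so the gradient w.r.t. edge_ones is all ones.
dst_deg.sum().backward()
assert torch.equal(edge_ones.grad, torch.ones_like(edge_ones))
print(src_deg, dst_deg)
```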
starrygl/graph/__init__.py  (view file @ 6fb76097)

```python
from .route import Route
from .utils import init_vc_edge_index

from torch import Tensor
from typing import Tuple

__all__ = [
    "Route",
    "init_vc_edge_index",
    "new_vc_route",
]

def new_vc_route(
    dst_ids: Tensor,
    edge_index: Tensor,
    bipartite: bool = True,
) -> Tuple[Tensor, Tensor, Tensor, Route]:
    src_ids, local_edge_index = init_vc_edge_index(dst_ids, edge_index, bipartite=bipartite)
    route = Route.from_raw_indices(src_ids, dst_ids, bipartite=bipartite)
    return src_ids, local_edge_index, dst_ids, route
```
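`new_vc_route` bundles the two steps a caller would otherwise do by hand: relabel the partition-local edge list and build the `Route` used to exchange src-sized and dst-sized tensors. Below is a rough, self-contained sketch of the relabeling half for the bipartite case; it is an assumption about what `init_vc_edge_index` conceptually does (map global endpoint ids to partition-local indices), not the actual StarryGL implementation, and `relabel_vc_edges` is a hypothetical name.

```python
import torch
from torch import Tensor
from typing import Tuple

def relabel_vc_edges(dst_ids: Tensor, edge_index: Tensor) -> Tuple[Tensor, Tensor]:
    """Hypothetical illustration of vertex-cut relabeling (bipartite case).

    dst_ids:    global ids of destination nodes owned by this partition, shape [D]
    edge_index: global [2, E] edge list whose destinations all lie in dst_ids
    Returns (src_ids, local_edge_index): the global ids referenced as sources,
    and the edge list rewritten in local coordinates.
    """
    src_ids = edge_index[0].unique(sorted=True)   # global ids referenced as sources
    # Dense global -> local lookup tables for both endpoint spaces.
    max_id = int(torch.cat([src_ids, dst_ids]).max()) + 1
    src_map = torch.full((max_id,), -1, dtype=torch.long)
    dst_map = torch.full((max_id,), -1, dtype=torch.long)
    src_map[src_ids] = torch.arange(src_ids.numel())
    dst_map[dst_ids] = torch.arange(dst_ids.numel())
    local_edge_index = torch.stack([src_map[edge_index[0]], dst_map[edge_index[1]]])
    return src_ids, local_edge_index

# Example with rank 0's data from run_route.py:
dst_ids = torch.tensor([0, 1])
edge_index = torch.tensor([[2, 0], [1, 0], [0, 1], [3, 1], [4, 1]]).t()
src_ids, local_edge_index = relabel_vc_edges(dst_ids, edge_index)
print(src_ids)           # tensor([0, 1, 2, 3, 4])
print(local_edge_index)  # endpoints renumbered into local coordinates
```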
starrygl/graph/route.py  (view file @ 6fb76097)

```diff
...
@@ -31,33 +31,48 @@ class Route:
             group=group,
         )

-    def new_subroute(self, dst_mask: Tensor):
+    def subroute(self,
+        dst_mask: Tensor,
+        src_mask: Optional[Tensor] = None,
+    ):
         if dst_mask.dtype != torch.bool:
             dst_mask = Route.__expand_index_to_mask(dst_mask, self.dst_len)
         assert dst_mask.size(0) == self.dst_len

+        if src_mask is None:
             fw_dst_masks, work = self.all_to_all_fw(
                 input_tensor_list=[dst_mask[self.fw_table(i)[0]] for i in range(self.num_parts)],
                 async_op=True,
             )
+        else:
+            if src_mask.dtype != torch.bool:
+                src_mask = Route.__expand_index_to_mask(src_mask, self.src_len)
+            assert src_mask.size(0) == self.src_len

         fw_tables: List[Tensor] = []
         for i in range(self.num_parts):
             m = dst_mask[self.fw_table(i)[0]]
             fw_tables.append(self.fw_table(i)[:, m])
-        work.wait()

         bw_tables: List[Tensor] = []
+        if src_mask is None:
             src_mask = torch.zeros(self.src_len, dtype=dst_mask.dtype, device=dst_mask.device)
+            work.wait()
             for i, m in enumerate(fw_dst_masks):
                 src_mask[self.bw_table(i)[0]] |= m
                 bw_tables.append(self.bw_table(i)[:, m])
+        else:
+            for i in range(self.num_parts):
+                m = src_mask[self.bw_table(i)[0]]
+                bw_tables.append(self.bw_table(i)[:, m])

-        return Route(
+        return dst_mask, Route(
             src_len=self.src_len,
             dst_len=self.dst_len,
             **Route.__tables_to_indptr(fw_tables, bw_tables),
             group=self.group,
-        )
+        ), src_mask, dst_mask

     def reverse_route(self):
         return Route(
```
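`subroute` accepts `dst_mask` and `src_mask` either as boolean masks or as index tensors, converting the latter through the private `Route.__expand_index_to_mask` helper. That helper's body is not shown in this hunk; the sketch below is an assumption about the usual index-to-mask conversion it presumably performs, not the StarryGL source, and the function name here is only a stand-in.

```python
import torch
from torch import Tensor

def expand_index_to_mask(index: Tensor, length: int) -> Tensor:
    """Hypothetical stand-in for Route.__expand_index_to_mask:
    turn a 1-D tensor of selected positions into a boolean mask of size `length`."""
    mask = torch.zeros(length, dtype=torch.bool, device=index.device)
    mask[index] = True
    return mask

# e.g. selecting destination nodes 0 and 2 out of dst_len == 4
print(expand_index_to_mask(torch.tensor([0, 2]), 4))  # tensor([ True, False,  True, False])
```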
```diff
...
@@ -68,6 +83,43 @@ class Route:
             group=self.group,
         )

+    def filter_nodes(self,
+        dst_mask: Tensor,
+        src_mask: Optional[Tensor] = None,
+    ):
+        if dst_mask.dtype != torch.bool:
+            dst_mask = Route.__expand_index_to_mask(dst_mask, self.dst_len)
+        assert dst_mask.size(0) == self.dst_len
+
+        new_dst_len = dst_mask.count_nonzero().item()
+        xmp = torch.empty(self.dst_len, dtype=torch.long, device=dst_mask.device)
+        xmp.fill_((2**62-1) * 2 + 1)
+        xmp[dst_mask] = torch.arange(new_dst_len, dtype=torch.long, device=dst_mask.device)
+
+        fw_dst_masks, fw_m_work = self.all_to_all_fw(
+            input_tensor_list=[dst_mask[self.fw_table(i)[0]] for i in range(self.num_parts)],
+            async_op=True,
+        )
+        # fw_dst_inds, fw_i_work =
+        # if src_mask is None:
+        #     fw_dst_masks, work = self.all_to_all_fw(
+        #         input_tensor_list=[dst_mask[self.fw_table(i)[0]] for i in range(self.num_parts)],
+        #         async_op=True,
+        #     )
+        # else:
+        #     if src_mask.dtype != torch.bool:
+        #         src_mask = Route.__expand_index_to_mask(src_mask, self.src_len)
+        #     assert src_mask.size(0) == self.src_len
+
+    def filter_edges(self,
+        edge_mask: Tensor,
+        edge_index: Tensor,
+    ):
+        pass
+
     def __init__(self,
         src_len: int,
         dst_len: int,
         fw_ptr: List[int],
         fw_ind: Tensor,
```
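`filter_nodes` appears to be a work in progress here (it builds the remapping and fires off the mask exchange but does not yet return anything), but the `xmp` construction is a standard compaction trick: `(2**62 - 1) * 2 + 1` is simply `2**63 - 1`, the maximum int64, used as a sentinel for nodes that are filtered out, while surviving nodes receive consecutive new indices. A small standalone illustration of that old-index-to-new-index mapping follows; the variable names are mine, not StarryGL's.

```python
import torch

dst_mask = torch.tensor([True, False, True, True, False])   # keep nodes 0, 2, 3
new_dst_len = int(dst_mask.count_nonzero())

# Old-index -> new-index lookup: kept nodes get 0..new_dst_len-1, dropped nodes get a sentinel.
old_to_new = torch.empty(dst_mask.numel(), dtype=torch.long)
old_to_new.fill_((2**62 - 1) * 2 + 1)    # == 2**63 - 1 == torch.iinfo(torch.long).max
old_to_new[dst_mask] = torch.arange(new_dst_len, dtype=torch.long)

print(old_to_new)  # tensor([0, 9223372036854775807, 1, 2, 9223372036854775807])

# Any edge list expressed in old indices can now be renumbered in one gather:
old_dst = torch.tensor([0, 2, 3])
print(old_to_new[old_dst])  # tensor([0, 1, 2])
```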
```diff
...
@@ -151,6 +203,22 @@ class Route:
             out[self.fw_table(i)[0]] += t
         return out

+    def get_src_part_ids(self) -> Tensor:
+        xs = self.all_to_all_fw([
+                torch.full(
+                    (self.fw_table(i).size(1),), self.part_id,
+                    dtype=torch.long, device=self.ctx.device,
+                )
+                for i in range(self.num_parts)
+            ],
+            async_op=False,
+        )
+        out = torch.full((self.src_len,), 2**16-1, dtype=torch.long, device=self.ctx.device)
+        for i, t in enumerate(xs):
+            out[self.bw_table(i)[0]] = t
+        return out.int() & 0xFFFF

     def all_to_all_fw(self,
         input_tensor_list: List[Tensor],
         async_op: bool = False,
     ):
         output_tensor_list: List[Tensor] = []
         for i in range(self.num_parts):
```
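`get_src_part_ids` stamps, for every local source slot, the id of the partition that owns it: each rank sends its own `part_id` along the forward tables, and the receiving rank writes the values back through the backward tables. Unclaimed slots keep the default `2**16 - 1`, and the final `.int() & 0xFFFF` packs the result into 16-bit owner ids. A toy single-process illustration of that packing step only (the sample values and sizes are made up, and no communication happens here):

```python
import torch

# Pretend partition 2 claimed source slots 1 and 3, and nobody claimed slots 0 and 2.
src_len = 4
out = torch.full((src_len,), 2**16 - 1, dtype=torch.long)   # 0xFFFF marks "no owner yet"
out[torch.tensor([1, 3])] = 2                                # stamp the owning partition id

part_ids = out.int() & 0xFFFF   # keep only the low 16 bits -> int32 partition ids
print(part_ids)                 # tensor([65535, 2, 65535, 2], dtype=torch.int32)
```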
```diff
...
@@ -334,4 +402,4 @@ class RouteAlltoAll(autograd.Function):
         grad: Tensor,
     ) -> Tensor:
         route: Route = ctx.saved_route
-        return route.bw_tensor(grad)
\ No newline at end of file
+        return route.bw_tensor(grad), None
\ No newline at end of file
```
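The change to `RouteAlltoAll.backward` follows the usual `torch.autograd.Function` contract: `backward` must return one value per argument of `forward`, with `None` standing in for non-differentiable inputs such as the route object. A generic, self-contained example of that pattern, unrelated to StarryGL's classes:

```python
import torch
from torch import autograd, Tensor

class ScaleBy(autograd.Function):
    """Multiply a tensor by a plain Python float (a non-differentiable argument)."""

    @staticmethod
    def forward(ctx, x: Tensor, factor: float) -> Tensor:
        ctx.factor = factor
        return x * factor

    @staticmethod
    def backward(ctx, grad: Tensor):
        # One return value per forward argument: a gradient for x, None for factor.
        return grad * ctx.factor, None

x = torch.ones(3, requires_grad=True)
ScaleBy.apply(x, 2.5).sum().backward()
print(x.grad)  # tensor([2.5000, 2.5000, 2.5000])
```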