zhlj / BTS-MTGNN / Commits

Commit 9586742a, authored Dec 19, 2023 by Wenjie Huang
Parent: 10c38111

    fix bugs. Route
Showing 7 changed files with 552 additions and 365 deletions (+552 -365):

    cora.py                          +4    -3
    run_route.py                     +25   -10
    starrygl/distributed/cclib.py    +0    -2
    starrygl/graph/__init__.py       +3    -24
    starrygl/graph/data.py           +268  -98
    starrygl/graph/route.py          +252  -182
    starrygl/graph/utils.py          +0    -46
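Taken together, the diff below replaces the free functions partition_pyg / partition_load with methods on GraphData, moves init_vc_edge_index from starrygl/graph/utils.py into starrygl/graph/data.py, and reworks Route (filter, rev, asynchronous fw_tensor/bw_tensor, plus the new RouteWork and RouteWorkCache helpers). A minimal sketch of the new partition workflow, as far as it can be read off this commit; the dataset, root path and partition count here are illustrative placeholders, not values taken from the repository:

    # Sketch only: wrap a PyG graph, write vertex-cut partitions, reload one part.
    import torch
    from torch_geometric.datasets import Planetoid

    from starrygl.graph import GraphData

    root = "./data/cora"                                  # hypothetical path
    data = Planetoid(root, "Cora")[0]                     # any torch_geometric Data object

    g = GraphData.from_pyg_data(data)                     # new static method in data.py
    g.save_partition(root, num_parts=3, algo="metis")     # writes {root}/metis_3/000..002

    part0 = GraphData.load_partition(root, part_id=0, num_parts=3, algo="metis")
    print(part0.node("dst")["raw_ids"].shape, part0.edge_index().shape)

Building a Route from a loaded partition (part0.to_route()) additionally requires an initialized DistributedContext, which is what run_route.py below exercises.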
cora.py  (+4 -3)

@@ -4,7 +4,7 @@ from torch_geometric.utils import add_remaining_self_loops, to_undirected
 import os.path as osp
 import sys
-from starrygl.utils.data import partition_pyg
+from starrygl.graph import GraphData
 import logging
 logging.getLogger().setLevel(logging.INFO)
@@ -19,6 +19,8 @@ if __name__ == "__main__":
     print(f"num_edges: {data.num_edges}")
     print(f"num_features: {data.num_features}")
+    data = GraphData.from_pyg_data(data)
+
     num_parts_list = [1, 2, 3, 5, 7, 9, 11]
     algos = ["metis", 'mt-metis', "random"]
@@ -27,4 +29,4 @@ if __name__ == "__main__":
     for num_parts in num_parts_list:
         for algo in algos:
             print(f"======== {num_parts} + {algo} ========")
-            partition_pyg(root, data, num_parts, algo)
+            data.save_partition(root, num_parts, algo)
\ No newline at end of file
run_route.py  (+25 -10)

@@ -5,7 +5,7 @@ from torch import Tensor
 from typing import *
 from starrygl.distributed import DistributedContext
-from starrygl.graph import new_vc_route
+from starrygl.graph import *
 from torch_scatter import scatter_sum
@@ -28,32 +28,38 @@ all_eparts = [
     ],
 ]

-def get_route(bipartite: bool = True):
+def get_data():
     ctx = DistributedContext.get_default_context()
     assert ctx.world_size == 3
     dst_ids = torch.tensor(all_nparts[ctx.rank], dtype=torch.long, device=ctx.device)
     edge_index = torch.tensor(all_eparts[ctx.rank], dtype=torch.long, device=ctx.device).t()
-    return new_vc_route(dst_ids, edge_index, bipartite=bipartite)
+    src_ids, edge_index = init_vc_edge_index(dst_ids, edge_index)
+    return GraphData.from_bipartite(edge_index, raw_src_ids=src_ids, raw_dst_ids=dst_ids)

 if __name__ == "__main__":
     ctx = DistributedContext.init(backend="gloo", use_gpu=True)
-    src_ids, edge_index, dst_ids, route = get_route(False)
-    src_size = route.src_len
-    dst_size = route.dst_len
+    g = get_data()
+    route = g.to_route()
+    edge_index = g.edge_index()
+    # src_ids, edge_index, dst_ids, route = get_route(False)
+    # src_size = route.src_len
+    # dst_size = route.dst_len

     ctx.sync_print(route.src_len, route.dst_len)
+    ctx.sync_print(route._fw_ptr, route._fw_ind)
+    ctx.sync_print(route._bw_ptr, route._bw_ind)

     edge_ones = torch.ones(edge_index.size(1), device=ctx.device).requires_grad_()
     src_ones = scatter_sum(edge_ones, edge_index[0], dim=0, dim_size=route.src_len)
     dst_ones = scatter_sum(edge_ones, edge_index[1], dim=0, dim_size=route.dst_len)

-    # ctx.sync_print(route.fw_tensor(dst_ones))
-    # ctx.sync_print(route.bw_tensor(src_ones))
+    ctx.sync_print(route.fw_tensor(dst_ones))
+    ctx.sync_print(route.bw_tensor(src_ones))

-    out = route.reverse_route().apply(src_ones)
+    out = route.rev().apply(src_ones)
     ctx.sync_print(out)

     out.sum().backward()
@@ -61,4 +67,13 @@ if __name__ == "__main__":
     ctx.sync_print(route.get_src_part_ids())

+    dst_mask = torch.full((route.dst_len,), ctx.rank % 2, dtype=torch.bool, device=ctx.device)
+    ctx.main_print("="*64)
+    ctx.sync_print(dst_mask)
+    _, _, r2 = route.filter(dst_mask)
+    ctx.sync_print(r2.apply(dst_ones).detach())
+    ctx.sync_print(r2.rev().apply(src_ones).detach())
+    # dst_true = torch.ones(route.dst_len, dtype=torch.float, device=ctx.device)
+    # ctx.sync_print(route.fw_tensor(dst_true, "max"))
+
     ctx.shutdown()
starrygl/distributed/cclib.py  (+0 -2)

@@ -45,8 +45,6 @@ def all_to_all_v(
     assert len(output_tensor_list) == world_size
     assert len(input_tensor_list) == world_size

-    # if group is None:
-    #     group = dist.distributed_c10d._get_default_group()
     backend = dist.get_backend(group)
     if backend == "nccl":
starrygl/graph/__init__.py  (+3 -24)

-from .route import Route
-from .utils import init_vc_edge_index
-
-from torch import Tensor
-from typing import Tuple
-
-__all__ = [
-    "Route",
-    "init_vc_edge_index",
-    "new_vc_route",
-]
-
-def new_vc_route(
-    dst_ids: Tensor,
-    edge_index: Tensor,
-    bipartite: bool = True,
-) -> Tuple[Tensor, Tensor, Tensor, Route]:
-    src_ids, local_edge_index = init_vc_edge_index(dst_ids, edge_index, bipartite=bipartite)
-    route = Route.from_raw_indices(src_ids, dst_ids, bipartite=bipartite)
-    return src_ids, local_edge_index, dst_ids, route
+from .data import *
+from .route import *
\ No newline at end of file
starrygl/utils/data.py → starrygl/graph/data.py  (+268 -98)

@@ -3,19 +3,18 @@ import torch
 from torch import Tensor
 from typing import *

-import os
+import os.path as osp
 import shutil

 from pathlib import Path
 from torch_sparse import SparseTensor
-from .partition import *
+from starrygl.utils.partition import *
+from .route import Route

 import logging

 __all__ = [
     "GraphData",
-    "partition_pyg",
-    "partition_load",
+    "init_vc_edge_index",
 ]
@@ -75,6 +74,17 @@ class GraphData:
             return data
         return self._edge_indices[edge_type]

+    def node_types(self) -> List[str]:
+        return list(self._node_data.keys())
+
+    def edge_types(self) -> List[Tuple[str, str, str]]:
+        return list(self._edge_data.keys())
+
+    def to_route(self, group: Any = None) -> Route:
+        src_ids = self.node("src")["raw_ids"]
+        dst_ids = self.node("dst")["raw_ids"]
+        return Route.from_raw_indices(src_ids, dst_ids, group=group)
+
     @property
     def is_heterogeneous(self) -> bool:
         return self._heterogeneous
@@ -92,6 +102,126 @@ class GraphData:
         return self

+    @staticmethod
+    def from_bipartite(
+        edge_index: Tensor,
+        num_src_nodes: Optional[int] = None,
+        num_dst_nodes: Optional[int] = None,
+        raw_src_ids: Optional[Tensor] = None,
+        raw_dst_ids: Optional[Tensor] = None,
+    ) -> 'GraphData':
+        if num_src_nodes is None:
+            num_src_nodes = raw_src_ids.numel()
+        if num_dst_nodes is None:
+            num_dst_nodes = raw_dst_ids.numel()
+
+        g = GraphData(
+            edge_indices={
+                ("src", "@", "dst"): edge_index,
+            },
+            num_nodes={
+                "src": num_src_nodes,
+                "dst": num_dst_nodes,
+            }
+        )
+        if raw_src_ids is not None:
+            g.node("src")["raw_ids"] = raw_src_ids
+        if raw_dst_ids is not None:
+            g.node("dst")["raw_ids"] = raw_dst_ids
+        return g
+
+    @staticmethod
+    def from_pyg_data(data) -> 'GraphData':
+        from torch_geometric.data import Data
+        assert isinstance(data, Data), f"must be Data class in pyg"
+
+        g = GraphData(data.edge_index, data.num_nodes)
+        for key, val in data:
+            if key == "edge_index":
+                continue
+            elif isinstance(val, Tensor):
+                if val.size(0) == data.num_nodes:
+                    g.node()[key] = val
+                elif val.size(0) == data.num_edges:
+                    g.edge()[key] = val
+            elif isinstance(val, SparseTensor):
+                logging.warning(f"found sparse matrix {key}, but ignored.")
+            else:
+                g.meta()[key] = val
+        return g
+
+    @staticmethod
+    def load_partition(root: str, part_id: int, num_parts: int, algo: str = "metis") -> 'GraphData':
+        p = Path(root).expanduser().resolve() / f"{algo}_{num_parts}" / f"{part_id:03d}"
+        return torch.load(p.__str__())
+
+    def save_partition(self, root: str, num_parts: int, algo: str = "metis"):
+        assert not self.is_heterogeneous, "only support homomorphic graph"
+        num_nodes: int = self.node().num_nodes
+        edge_index: Tensor = self.edge_index()
+
+        logging.info(f"running partition aglorithm: {algo}")
+        if algo == "metis":
+            node_parts = metis_partition(edge_index, num_nodes, num_parts)
+        elif algo == "mt-metis":
+            node_parts = mt_metis_partition(edge_index, num_nodes, num_parts)
+        elif algo == "random":
+            node_parts = random_partition(edge_index, num_nodes, num_parts)
+        else:
+            raise ValueError(f"unknown partition algorithm: {algo}")
+
+        root_path = Path(root).expanduser().resolve()
+        base_path = root_path / f"{algo}_{num_parts}"
+        if base_path.exists():
+            logging.warning(f"directory '{base_path.__str__()}' exists, and will be removed.")
+            shutil.rmtree(base_path.__str__())
+        base_path.mkdir(parents=True)
+
+        for i in range(num_parts):
+            npart_mask = node_parts == i
+            epart_mask = npart_mask[edge_index[1]]
+
+            raw_dst_ids: Tensor = torch.where(npart_mask)[0]
+            local_edges = edge_index[:, epart_mask]
+            raw_src_ids, local_edges = init_vc_edge_index(
+                raw_dst_ids, local_edges, bipartite=True,
+            )
+            # g = GraphData(
+            #     edge_indices={
+            #         ("src", "@", "dst"): local_edges,
+            #     },
+            #     num_nodes={
+            #         "src": raw_src_ids.numel(),
+            #         "dst": raw_dst_ids.numel(),
+            #     }
+            # )
+            g = GraphData.from_bipartite(
+                local_edges,
+                raw_src_ids=raw_src_ids,
+                raw_dst_ids=raw_dst_ids,
+            )
+
+            for key in self.node().keys():
+                g.node("dst")[key] = self.node()[key][npart_mask]
+            for key in self.edge().keys():
+                g.edge()[key] = self.edge()[key][epart_mask]
+            for key in self.meta().keys():
+                g.meta()[key] = self.meta()[key]
+
+            logging.info(f"saving partition data: {i+1}/{num_parts}")
+            torch.save(g, (base_path / f"{i:03d}").__str__())
+
 class MetaData:
     def __init__(self) -> None:
         self._data: Dict[str, Any] = {}
@@ -193,97 +323,137 @@ class EdgeData:
         return self

-def partition_load(root: str, part_id: int, num_parts: int, algo: str = "metis") -> GraphData:
-    p = Path(root).expanduser().resolve() / f"{algo}_{num_parts}" / f"{part_id:03d}"
-    return torch.load(p.__str__())
-
-def partition_pyg(root: str, data, num_parts: int, algo: str = "metis"):
-    root_path = Path(root).expanduser().resolve()
-    base_path = root_path / f"{algo}_{num_parts}"
-    if base_path.exists():
-        shutil.rmtree(base_path.__str__())
-    base_path.mkdir(parents=True, exist_ok=True)
-    for i, g in enumerate(partition_pyg_data(data, num_parts, algo)):
-        logging.info(f"saving partition data: {i+1}/{num_parts}")
-        torch.save(g, (base_path / f"{i:03d}").__str__())
-
-def partition_pyg_data(data, num_parts: int, algo: str = "metis") -> Iterator[GraphData]:
-    from torch_geometric.data import Data
-    assert isinstance(data, Data), f"must be Data class in pyg"
-
-    logging.info(f"running partition aglorithm: {algo}")
-    num_nodes: int = data.num_nodes
-    num_edges: int = data.num_edges
-    edge_index: Tensor = data.edge_index
-
-    if algo == "metis":
-        node_parts = metis_partition(edge_index, num_nodes, num_parts)
-    elif algo == "mt-metis":
-        node_parts = mt_metis_partition(edge_index, num_nodes, num_parts)
-    elif algo == "random":
-        node_parts = random_partition(edge_index, num_nodes, num_parts)
-    else:
-        raise ValueError(f"unknown partition algorithm: {algo}")
-
-    if data.y.dtype == torch.long:
-        if data.y.dim() == 1:
-            num_classes = data.y.max().item() + 1
-        else:
-            num_classes = data.y.size(1)
-    else:
-        num_classes = None
-
-    for i in range(num_parts):
-        npart_mask = node_parts == i
-        epart_mask = npart_mask[edge_index[1]]
-        local_edges = edge_index[:, epart_mask]
-        raw_src_ids: Tensor = local_edges[0].unique()
-        raw_dst_ids: Tensor = torch.where(npart_mask)[0]
-
-        M: int = raw_src_ids.max().item() + 1
-        imap = torch.full((M,), (2**62-1)*2+1).type_as(raw_src_ids)
-        imap[raw_src_ids] = torch.arange(raw_src_ids.numel()).type_as(raw_src_ids)
-        local_src = imap[local_edges[0]]
-
-        M: int = raw_dst_ids.max().item() + 1
-        imap = torch.full((M,), (2**62-1)*2+1).type_as(raw_dst_ids)
-        imap[raw_dst_ids] = torch.arange(raw_dst_ids.numel()).type_as(raw_dst_ids)
-        local_dst = imap[local_edges[1]]
-
-        local_edges = torch.vstack([local_src, local_dst])
-        g = GraphData(
-            edge_indices={
-                ("src", "@", "dst"): local_edges,
-            },
-            num_nodes={
-                "src": raw_src_ids.numel(),
-                "dst": raw_dst_ids.numel(),
-            },
-        )
-        g.node("src")["raw_ids"] = raw_src_ids
-        g.node("dst")["raw_ids"] = raw_dst_ids
-
-        if num_classes is not None:
-            g.meta()["num_classes"] = num_classes
-
-        for key, val in data:
-            if key == "edge_index":
-                continue
-            elif isinstance(val, Tensor):
-                if val.size(0) == num_nodes:
-                    g.node("dst")[key] = val[npart_mask]
-                elif val.size(0) == num_edges:
-                    g.edge()[key] = val[epart_mask]
-            elif isinstance(val, SparseTensor):
-                pass
-            else:
-                g.meta()[key] = val
-        yield g
+def init_vc_edge_index(
+    dst_ids: Tensor,
+    edge_index: Tensor,
+    bipartite: bool = True,
+) -> Tuple[Tensor, Tensor]:
+    ikw = dict(dtype=torch.long, device=dst_ids.device)
+    local_num_nodes = torch.zeros(1, **ikw)
+    if dst_ids.numel() > 0:
+        local_num_nodes = dst_ids.max().max(local_num_nodes)
+    if edge_index.numel() > 0:
+        local_num_nodes = edge_index.max().max(local_num_nodes)
+    local_num_nodes = local_num_nodes.item() + 1
+
+    xmp: Tensor = torch.zeros(local_num_nodes, **ikw)
+    xmp[edge_index[1].unique()] += 0b01
+    xmp[dst_ids.unique()] += 0b10
+    if not (xmp != 0x01).all():
+        raise RuntimeError(f"must be vertex-cut partition graph")
+
+    if bipartite:
+        src_ids = edge_index[0].unique()
+    else:
+        xmp.fill_(0)
+        xmp[edge_index[0]] = 1
+        xmp[dst_ids] = 0
+        src_ids = torch.cat([dst_ids, torch.where(xmp > 0)[0]], dim=-1)
+
+    xmp.fill_((2**62-1)*2+1)
+    xmp[src_ids] = torch.arange(src_ids.size(0), **ikw)
+    src = xmp[edge_index[0]]
+
+    xmp.fill_((2**62-1)*2+1)
+    xmp[dst_ids] = torch.arange(dst_ids.size(0), **ikw)
+    dst = xmp[edge_index[1]]
+
+    local_edge_index = torch.vstack([src, dst])
+    return src_ids, local_edge_index

(The new file also keeps a fully commented-out copy of the removed partition_load / partition_pyg / partition_pyg_data functions at this point.)
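The core of the move is init_vc_edge_index, which checks that the local edge set really is a vertex-cut (every destination of a local edge must be owned locally) and re-numbers the global ids into per-partition local ids. A small self-contained illustration of what it returns; the ids below are invented, and the import path assumes the new `from .data import *` re-export in starrygl/graph/__init__.py:

    # Sketch: one partition of a toy vertex-cut graph, global ids throughout.
    import torch
    from starrygl.graph import init_vc_edge_index

    dst_ids = torch.tensor([3, 5, 8])              # vertices owned by this partition
    edge_index = torch.tensor([[5, 9, 2, 3],       # sources may live on other partitions
                               [3, 3, 5, 8]])      # destinations must all be in dst_ids

    src_ids, local_edge_index = init_vc_edge_index(dst_ids, edge_index, bipartite=True)
    # src_ids: sorted unique global source ids -> tensor([2, 3, 5, 9])
    # local_edge_index: same edges, row 0 re-numbered against src_ids,
    #                   row 1 re-numbered against dst_ids:
    #   tensor([[2, 3, 0, 1],
    #           [0, 0, 1, 2]])

If some edge points to a destination that is not in dst_ids, the bit-mask check (0b01 for edge destinations, 0b10 for owned vertices) leaves a value of exactly 0x01 and the function raises "must be vertex-cut partition graph".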
starrygl/graph/route.py  (+252 -182)

@@ -2,13 +2,18 @@ import torch
 import torch.autograd as autograd
 import torch.distributed as dist

-from multiprocessing.pool import ThreadPool
 from torch import Tensor
 from typing import *

+from starrygl.distributed import DistributedContext
+from starrygl.distributed.cclib import all_to_all_s, all_to_all_v
+
+__all__ = [
+    "Route",
+    "RouteWork",
+    "RouteWorkCache",
+    "RouteAlltoAll",
+]

 class Route:
@@ -31,50 +36,51 @@ class Route:
             group=group,
         )

-    def subroute(self,
-        dst_mask: Tensor,
+    def filter(self,
+        dst_mask: Optional[Tensor] = None,
         src_mask: Optional[Tensor] = None,
+        remap: bool = False,
     ):
-        if dst_mask.dtype != torch.bool:
-            dst_mask = Route.__expand_index_to_mask(dst_mask, self.dst_len)
-        assert dst_mask.size(0) == self.dst_len
-
-        if src_mask is None:
-            fw_dst_masks, work = self.all_to_all_fw(
-                input_tensor_list=[dst_mask[self.fw_table(i)[0]] for i in range(self.num_parts)],
-                async_op=True,
-            )
-        else:
-            if src_mask.dtype != torch.bool:
-                src_mask = Route.__expand_index_to_mask(src_mask, self.src_len)
-            assert src_mask.size(0) == self.src_len
-
-        fw_tables: List[Tensor] = []
-        for i in range(self.num_parts):
-            m = dst_mask[self.fw_table(i)[0]]
-            fw_tables.append(self.fw_table(i)[:, m])
-
-        bw_tables: List[Tensor] = []
-        if src_mask is None:
-            src_mask = torch.zeros(self.src_len, dtype=dst_mask.dtype, device=dst_mask.device)
-            work.wait()
-            for i, m in enumerate(fw_dst_masks):
-                src_mask[self.bw_table(i)[0]] |= m
-                bw_tables.append(self.bw_table(i)[:, m])
-        else:
-            for i in range(self.num_parts):
-                m = src_mask[self.bw_table(i)[0]]
-                bw_tables.append(self.bw_table(i)[:, m])
-
-        return dst_mask, Route(
-            src_len=self.src_len,
-            dst_len=self.dst_len,
-            **Route.__tables_to_indptr(fw_tables, bw_tables),
-            group=self.group,
-        ), src_mask
+        if dst_mask is None:
+            if src_mask is None:
+                raise ValueError("please provide at least one parameter.")
+            else:
+                assert src_mask.dtype == torch.bool
+                assert src_mask.numel() == self.src_len
+                dst_mask = self.bw_tensor(src_mask.long()) != 0
+        else:
+            assert dst_mask.dtype == torch.bool
+            assert dst_mask.numel() == self.dst_len
+            tmp_src_mask = self.fw_tensor(dst_mask.long()) != 0
+            if src_mask is None:
+                src_mask = tmp_src_mask
+            else:
+                tmp_src_mask &= src_mask
+                src_mask = tmp_src_mask
+                dst_mask = self.bw_tensor(src_mask.long()) != 0
+
+        fw_ptr, fw_ind = Route.__filter_ind_and_ptr(self._fw_ptr, self._fw_ind, dst_mask)
+        bw_ptr, bw_ind = Route.__filter_ind_and_ptr(self._bw_ptr, self._bw_ind, src_mask)
+        route = Route(
+            src_len=self.src_len,
+            dst_len=self.dst_len,
+            fw_ptr=fw_ptr, fw_ind=fw_ind,
+            bw_ptr=bw_ptr, bw_ind=bw_ind,
+            group=self.group,
+        )
+        if remap:
+            fw_ind, dst_len = Route.__remap_ind(route._fw_ind, dst_mask)
+            bw_ind, src_len = Route.__remap_ind(route._bw_ind, src_mask)
+            route = Route(
+                src_len=src_len,
+                dst_len=dst_len,
+                fw_ptr=route._fw_ptr, fw_ind=fw_ind,
+                bw_ptr=route._bw_ptr, bw_ind=bw_ind,
+                group=route.group,
+            )
+        return dst_mask, src_mask, route

-    def reverse_route(self):
+    def rev(self):
         return Route(
             src_len=self.dst_len,
             dst_len=self.src_len,
@@ -83,63 +89,22 @@ class Route:
             group=self.group,
         )

-    def filter_nodes(self,
-        dst_mask: Tensor,
-        src_mask: Optional[Tensor] = None,
-    ):
-        if dst_mask.dtype != torch.bool:
-            dst_mask = Route.__expand_index_to_mask(dst_mask, self.dst_len)
-        assert dst_mask.size(0) == self.dst_len
-
-        new_dst_len = dst_mask.count_nonzero().item()
-        xmp = torch.empty(self.dst_len, dtype=torch.long, device=dst_mask.device)
-        xmp.fill_((2**62-1)*2+1)
-        xmp[dst_mask] = torch.arange(new_dst_len, dtype=torch.long, device=dst_mask.device)
-
-        fw_dst_masks, fw_m_work = self.all_to_all_fw(
-            input_tensor_list=[dst_mask[self.fw_table(i)[0]] for i in range(self.num_parts)],
-            async_op=True,
-        )
-        # fw_dst_inds, fw_i_work =
-        # if src_mask is None:
-        #     fw_dst_masks, work = self.all_to_all_fw(
-        #         input_tensor_list=[dst_mask[self.fw_table(i)[0]] for i in range(self.num_parts)],
-        #         async_op=True,
-        #     )
-        # else:
-        #     if src_mask.dtype != torch.bool:
-        #         src_mask = Route.__expand_index_to_mask(src_mask, self.src_len)
-        #     assert src_mask.size(0) == self.src_len
-
-    def filter_edges(self,
-        edge_mask: Tensor,
-        edge_index: Tensor,
-    ):
-        pass
-
     def __init__(self,
         src_len: int, dst_len: int,
         fw_ptr: List[int], fw_ind: Tensor,
         bw_ptr: List[int], bw_ind: Tensor,
         group: Any,
     ) -> None:
+        self._ctx = DistributedContext.get_default_context()
+        assert len(fw_ptr) == len(bw_ptr)
         self._src_len = src_len
         self._dst_len = dst_len
-        self._fw_ptr = fw_ptr
+        self._fw_ptr = tuple(fw_ptr)
         self._fw_ind = fw_ind
-        self._bw_ptr = bw_ptr
+        self._bw_ptr = tuple(bw_ptr)
         self._bw_ind = bw_ind
         self._group = group

+    @property
+    def ctx(self):
+        return self._ctx
+
     @property
     def group(self):
         return self._group
@@ -164,104 +129,84 @@ class Route:
         self._bw_ind = self._bw_ind.to(device)
         return self

-    def fw_table(self, i: int):
-        return self._fw_ind[:, self._fw_ptr[i]:self._fw_ptr[i+1]]
-
-    def bw_table(self, i: int):
-        return self._bw_ind[:, self._bw_ptr[i]:self._bw_ptr[i+1]]
-
-    def apply(self, data: Tensor) -> Tensor:
-        return RouteAlltoAll.apply(data, self)
-
-    def fw_tensor(self, data: Tensor) -> Tensor:
-        assert data.size(0) == self.dst_len
-        xs = self.all_to_all_fw(
-            [data[self.fw_table(i)[0]] for i in range(self.num_parts)],
-            async_op=False,
-        )
-        out = torch.zeros(
-            self.src_len, *data.shape[1:],
-            dtype=data.dtype, device=data.device,
-        )
-        for i, t in enumerate(xs):
-            out[self.bw_table(i)[0]] += t
-        return out
-
-    def bw_tensor(self, data: Tensor) -> Tensor:
-        assert data.size(0) == self.src_len
-        xs = self.all_to_all_bw(
-            [data[self.bw_table(i)[0]] for i in range(self.num_parts)],
-            async_op=False,
-        )
-        out = torch.zeros(
-            self.dst_len, *data.shape[1:],
-            dtype=data.dtype, device=data.device,
-        )
-        for i, t in enumerate(xs):
-            out[self.fw_table(i)[0]] += t
-        return out
-
-    def get_src_part_ids(self) -> Tensor:
-        xs = self.all_to_all_fw(
-            [
-                torch.full(
-                    (self.fw_table(i).size(1),), self.part_id,
-                    dtype=torch.long, device=self.ctx.device,
-                )
-                for i in range(self.num_parts)
-            ],
-            async_op=False,
-        )
-        out = torch.full((self.src_len,), 2**16-1, dtype=torch.long, device=self.ctx.device)
-        for i, t in enumerate(xs):
-            out[self.bw_table(i)[0]] = t
-        return out.int() & 0xFFFF
-
-    def all_to_all_fw(self, input_tensor_list: List[Tensor], async_op: bool = False):
-        output_tensor_list: List[Tensor] = []
-        for i in range(self.num_parts):
-            t = input_tensor_list[i]
-            assert t.size(0) == self._fw_ptr[i+1] - self._fw_ptr[i]
-            s = self._bw_ptr[i+1] - self._bw_ptr[i]
-            output_tensor_list.append(
-                torch.empty(s, *t.shape[1:], dtype=t.dtype, device=t.device)
-            )
-        work = self.ctx.all_to_all_v(
-            output_tensor_list,
-            input_tensor_list,
-            group=self.group,
-            async_op=async_op,
-        )
-        if async_op:
-            return output_tensor_list, work
-        else:
-            return output_tensor_list
-
-    def all_to_all_bw(self, input_tensor_list: List[Tensor], async_op: bool = False):
-        output_tensor_list: List[Tensor] = []
-        for i in range(self.num_parts):
-            t = input_tensor_list[i]
-            assert t.size(0) == self._bw_ptr[i+1] - self._bw_ptr[i]
-            s = self._fw_ptr[i+1] - self._fw_ptr[i]
-            output_tensor_list.append(
-                torch.empty(s, *t.shape[1:], dtype=t.dtype, device=t.device)
-            )
-        work = self.ctx.all_to_all_v(
-            output_tensor_list,
-            input_tensor_list,
-            group=self.group,
-            async_op=async_op,
-        )
-        if async_op:
-            return output_tensor_list, work
-        return output_tensor_list
+    # def fw_table(self, i: int):
+    #     return self._fw_ind[self._fw_ptr[i]:self._fw_ptr[i+1]]
+
+    # def bw_table(self, i: int):
+    #     return self._bw_ind[self._bw_ptr[i]:self._bw_ptr[i+1]]
+
+    def apply(self,
+        data: Tensor,
+        cache: Optional['RouteWorkCache'] = None,
+        cache_key: Optional[str] = None,
+    ) -> Tensor:
+        return RouteAlltoAll.apply(data, self, cache, cache_key)
+
+    @torch.no_grad()
+    def fw_tensor(self, data: Tensor, async_op: bool = False):
+        assert data.size(0) == self.dst_len
+        output_tensor = torch.empty(
+            self._bw_ind.numel(), *data.shape[1:],
+            dtype=data.dtype, device=data.device,
+        )
+        work = all_to_all_s(
+            output_tensor, data[self._fw_ind],
+            self._bw_ptr, self._fw_ptr,
+            group=self.group,
+            async_op=async_op,
+        )
+        work = RouteWork(
+            work if async_op else None,
+            self._bw_ptr, self._bw_ind,
+            self.src_len, output_tensor,
+        )
+        return work if async_op else work.wait()
+
+    @torch.no_grad()
+    def bw_tensor(self, data: Tensor, async_op: bool = False):
+        assert data.size(0) == self.src_len
+        output_tensor = torch.empty(
+            self._fw_ind.numel(), *data.shape[1:],
+            dtype=data.dtype, device=data.device,
+        )
+        work = all_to_all_s(
+            output_tensor, data[self._bw_ind],
+            self._fw_ptr, self._bw_ptr,
+            group=self.group,
+            async_op=async_op,
+        )
+        work = RouteWork(
+            work if async_op else None,
+            self._fw_ptr, self._fw_ind,
+            self.dst_len, output_tensor,
+        )
+        return work if async_op else work.wait()
+
+    @torch.no_grad()
+    def get_src_part_ids(self) -> Tensor:
+        input_tensor = torch.full_like(self._fw_ind, self.part_id)
+        output_tensor = torch.empty_like(self._bw_ind)
+        all_to_all_s(
+            output_tensor, input_tensor,
+            self._bw_ptr, self._fw_ptr,
+            group=self.group,
+        )
+        out = torch.full(
+            (self.src_len,), 2**16-1,
+            dtype=self._bw_ind.dtype,
+            device=self._bw_ind.device,
+        )
+        for s, t in zip(self._bw_ptr, self._bw_ptr[1:]):
+            ind = self._bw_ind[s:t]
+            assert (out[ind] == 2**16-1).all(), f"some vertices exist in more than one partition"
+            out[ind] = output_tensor[s:t] & 0xFF
+        return out

     @staticmethod
     def _build_route_tables(
@@ -275,8 +220,6 @@ class Route:
         assert src_ids.dim() == 1
         assert dst_ids.dim() == 1

-        ctx = DistributedContext.get_default_context()
-
         src_len = src_ids.size(0)
         dst_len = dst_ids.size(0)
@@ -292,6 +235,7 @@ class Route:
         all_dst_lens: List[int] = [None] * world_size
         dist.all_gather_object(all_dst_lens, dst_len, group=group)

+        # all_reduce number of nodes
         num_nodes = torch.zeros(1, **ikw)
         if src_ids.numel() > 0:
             num_nodes = src_ids.max().max(num_nodes)
@@ -342,25 +286,54 @@ class Route:
             src_ind = smp[ind]
             dst_ind = xmp[ind]
-            fw_tables.append(torch.vstack([dst_ind, src_ind]))
-            bw_tables.append(torch.vstack([src_ind, dst_ind]))
-
-        fw_tables = [t.t().contiguous() for t in fw_tables]
-        fw_tables = ctx.all_to_all_g(fw_tables, group=group)
-        fw_tables = [t.t().contiguous() for t in fw_tables]
+            # 此时只有bw_route是正常的,fw_route需要发送给src_ids所在分区
+            # (at this point only the bw route is correct; the fw route still has to be
+            #  sent to the partitions that own the corresponding src_ids)
+            fw_tables.append(dst_ind)
+            bw_tables.append(src_ind)
+
+        fw_tables = Route.__backward_fw_tables(fw_tables, group=group)

-        # 非二分图,每个点添加自环 (non-bipartite graph: add a self-loop for every vertex)
+        # add self-loops if not bipartite graph
         if not bipartite:
             rank_ind = torch.arange(dst_len, **ikw)
-            fw_tables[rank] = bw_tables[rank] = torch.vstack([rank_ind, rank_ind])
+            fw_tables[rank] = bw_tables[rank] = rank_ind

         return fw_tables, bw_tables

     @staticmethod
-    def __expand_index_to_mask(index: Tensor, len: int) -> Tensor:
-        mask = torch.zeros(len, dtype=torch.bool, device=index.device)
-        mask[index] = True
-        return mask
+    def __filter_ind_and_ptr(ptr: List[int], ind: Tensor, mask: Tensor) -> Tuple[List[int], Tensor]:
+        m = mask[ind]
+        new_ptr: List[int] = [0]
+        new_ind: List[Tensor] = []
+        for s, t in zip(ptr, ptr[1:]):
+            new_ind.append(ind[s:t][m[s:t]])
+            new_ptr.append(new_ptr[-1] + new_ind[-1].numel())
+        return new_ptr, torch.cat(new_ind, dim=0)
+
+    @staticmethod
+    def __remap_ind(ind: Tensor, mask: Tensor) -> Tuple[Tensor, int]:
+        n: int = mask.count_nonzero().item()
+        imp = torch.full((mask.numel(),), (2**62-1)*2+1, dtype=ind.dtype, device=ind.device)
+        imp[mask] = torch.arange(n, dtype=ind.dtype, device=ind.device)
+        return ind, int(n)
+
+    @staticmethod
+    def __backward_fw_tables(
+        fw_tables: List[Tensor],
+        group: Any,
+    ) -> List[Tensor]:
+        rank = dist.get_rank(group)
+        world_size = dist.get_world_size(group)
+
+        send_sizes = [t.size() for t in fw_tables]
+        recv_sizes = [None] * world_size
+        dist.all_gather_object(recv_sizes, send_sizes, group=group)
+        recv_sizes = [s[rank] for s in recv_sizes]
+
+        fixed_tables = []
+        for s, t in zip(recv_sizes, fw_tables):
+            t = torch.empty(*s, dtype=t.dtype, device=t.device)
+            fixed_tables.append(t)
+        all_to_all_v(fixed_tables, fw_tables, group=group)
+        return fixed_tables

     @staticmethod
     def __tables_to_indptr(
@@ -369,21 +342,101 @@ class Route:
     ):
         fw_ptr: List[int] = [0]
         for t in fw_tables:
-            last_n = fw_ptr[-1]
-            fw_ptr.append(last_n + t.size(1))
-        fw_ind = torch.cat(fw_tables, dim=1)
+            assert t.dim() == 1
+            fw_ptr.append(fw_ptr[-1] + t.numel())
+        fw_ind = torch.cat(fw_tables, dim=0)

         bw_ptr: List[int] = [0]
         for t in bw_tables:
-            last_n = bw_ptr[-1]
-            bw_ptr.append(last_n + t.size(1))
-        bw_ind = torch.cat(bw_tables, dim=1)
+            assert t.dim() == 1
+            bw_ptr.append(bw_ptr[-1] + t.numel())
+        bw_ind = torch.cat(bw_tables, dim=0)

         return {
             "fw_ptr": fw_ptr, "bw_ptr": bw_ptr,
             "fw_ind": fw_ind, "bw_ind": bw_ind,
         }

+class RouteWork:
+    def __init__(self,
+        work: Any,
+        ptr: List[int], ind: Tensor,
+        len: int, recv_t: Tensor,
+    ) -> None:
+        self._work = work
+        self._ptr = ptr
+        self._ind = ind
+        self._len = len
+        self._recv_t = recv_t
+
+        if self._work is None:
+            self._reduce()
+
+    def _reduce(self):
+        out = torch.zeros(
+            self._len, *self._recv_t.shape[1:],
+            dtype=self._recv_t.dtype, device=self._recv_t.device,
+        )
+        for s, t in zip(self._ptr, self._ptr[1:]):
+            ind = self._ind[s:t]
+            out[ind] += self._recv_t[s:t]
+
+        self._work = None
+        self._ptr = None
+        self._ind = None
+        self._len = None
+        self._recv_t = out
+
+    def wait(self) -> Tensor:
+        if self._work is None:
+            return self._recv_t
+        self._work.wait()
+        self._reduce()
+        return self._recv_t
+
+class RouteWorkCache:
+    def __init__(self,
+        enable_fw: bool = True,
+        enable_bw: bool = True,
+    ) -> None:
+        self.enable_fw = enable_fw
+        self.enable_bw = enable_bw
+        self._cached_works: Dict[str, RouteWork] = {}
+
+    def enable_fw_(self, enable: bool = True):
+        self.enable_fw = enable
+        return self
+
+    def enable_bw_(self, enable: bool = True):
+        self.enable_bw = enable
+        return self
+
+    def wait(self):
+        for work in self._cached_works.values():
+            work.wait()
+
+    def clear(self):
+        self._cached_works.clear()
+
+    def get_and_set(self,
+        key: str, work: RouteWork,
+        bw: bool = False,
+    ) -> Optional[RouteWork]:
+        if bw and self.enable_bw:
+            key = key + "_bw"
+        elif not bw and self.enable_fw:
+            key = key + "_fw"
+        else:
+            return work
+
+        t = self._cached_works.get(key, work)
+        self._cached_works[key] = work
+        return t
+
 class RouteAlltoAll(autograd.Function):
     @staticmethod
@@ -391,10 +444,18 @@ class RouteAlltoAll(autograd.Function):
         ctx: autograd.function.FunctionCtx,
         x: Tensor,
         route: Route,
+        cache: Optional[RouteWorkCache],
+        cache_key: Optional[str],
     ):
         ctx.saved_route = route
-        return route.fw_tensor(x)
+        ctx.saved_cache = cache
+        ctx.saved_cache_key = cache_key
+
+        if cache is None or cache_key is None:
+            return route.fw_tensor(x)
+        else:
+            work = route.fw_tensor(x, async_op=True)
+            work = cache.get_and_set(cache_key, work, bw=False)
+            return work.wait()

     @staticmethod
     def backward(
@@ -402,4 +463,13 @@ class RouteAlltoAll(autograd.Function):
         grad: Tensor,
     ) -> Tensor:
         route: Route = ctx.saved_route
-        return route.bw_tensor(grad), None
+        cache: Optional[RouteWorkCache] = ctx.saved_cache
+        cache_key: Optional[str] = ctx.saved_cache_key
+
+        if cache is None or cache_key is None:
+            return route.bw_tensor(grad), None, None, None
+        else:
+            work = route.bw_tensor(grad, async_op=True)
+            work = cache.get_and_set(cache_key, work, bw=True)
+            return work.wait(), None, None, None
\ No newline at end of file
starrygl/graph/utils.py  deleted (100644 → 0)

-import torch
-import torch.distributed as dist
-
-from torch import Tensor
-from typing import *
-
-def init_vc_edge_index(
-    dst_ids: Tensor,
-    edge_index: Tensor,
-    bipartite: bool = True,
-) -> Tuple[Tensor, Tensor]:
-    ikw = dict(dtype=torch.long, device=dst_ids.device)
-    local_num_nodes = torch.zeros(1, **ikw)
-    if dst_ids.numel() > 0:
-        local_num_nodes = dst_ids.max().max(local_num_nodes)
-    if edge_index.numel() > 0:
-        local_num_nodes = edge_index.max().max(local_num_nodes)
-    local_num_nodes = local_num_nodes.item() + 1
-
-    xmp: Tensor = torch.zeros(local_num_nodes, **ikw)
-    xmp[edge_index[1].unique()] += 0b01
-    xmp[dst_ids.unique()] += 0b10
-    if not (xmp != 0x01).all():
-        raise RuntimeError(f"must be vertex-cut partition graph")
-
-    if bipartite:
-        src_ids = edge_index[0].unique()
-    else:
-        xmp.fill_(0)
-        xmp[edge_index[0]] = 1
-        xmp[dst_ids] = 0
-        src_ids = torch.cat([dst_ids, torch.where(xmp > 0)[0]], dim=-1)
-
-    xmp.fill_((2**62-1)*2+1)
-    xmp[src_ids] = torch.arange(src_ids.size(0), **ikw)
-    src = xmp[edge_index[0]]
-
-    xmp.fill_((2**62-1)*2+1)
-    xmp[dst_ids] = torch.arange(dst_ids.size(0), **ikw)
-    dst = xmp[edge_index[1]]
-
-    local_edge_index = torch.vstack([src, dst])
-    return src_ids, local_edge_index
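Reading the reworked Route.apply / RouteAlltoAll signatures together with run_route.py above, the intended round trip looks roughly like the sketch below. It assumes `route` was produced by GraphData.to_route() inside an initialized DistributedContext (`ctx`); the helper name and feature width are made up. Note that fw_tensor/bw_tensor now also accept async_op=True and return a RouteWork whose wait() performs the final scatter-add, and apply() takes an optional RouteWorkCache plus cache_key, apparently so repeated calls under the same key can reuse or overlap pending communication.

    # Sketch only: differentiable push/pull with the reworked Route API.
    import torch

    def push_and_pull(route, dst_feat: torch.Tensor) -> torch.Tensor:
        # dst_feat has route.dst_len rows; apply() runs the forward all-to-all
        # (RouteAlltoAll), producing one row per local source copy.
        src_feat = route.apply(dst_feat)
        # rev() swaps the forward/backward tables, so applying it sends the
        # source copies back and sums them onto the owning destinations.
        return route.rev().apply(src_feat)

    # x = torch.ones(route.dst_len, 16, requires_grad=True, device=ctx.device)
    # y = push_and_pull(route, x)
    # y.sum().backward()   # gradients flow back through both all-to-all exchanges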