zhlj / starrygl-DynamicHistory / Commits / 9586742a

Commit 9586742a, authored Dec 19, 2023 by Wenjie Huang

    fix bugs. Route

parent 10c38111

Showing 7 changed files with 552 additions and 365 deletions
cora.py                          +5    -4
run_route.py                     +25   -10
starrygl/distributed/cclib.py    +0    -2
starrygl/graph/__init__.py       +3    -24
starrygl/graph/data.py           +258  -88
starrygl/graph/route.py          +261  -191
starrygl/graph/utils.py          +0    -46
cora.py

@@ -4,7 +4,7 @@ from torch_geometric.utils import add_remaining_self_loops, to_undirected
 import os.path as osp
 import sys
 
-from starrygl.utils.data import partition_pyg
+from starrygl.graph import GraphData
 
 import logging
 logging.getLogger().setLevel(logging.INFO)
@@ -18,7 +18,9 @@ if __name__ == "__main__":
     print(f"num_nodes: {data.num_nodes}")
     print(f"num_edges: {data.num_edges}")
     print(f"num_features: {data.num_features}")
+    data = GraphData.from_pyg_data(data)
 
     num_parts_list = [1, 2, 3, 5, 7, 9, 11]
     algos = ["metis", 'mt-metis', "random"]
@@ -27,4 +29,4 @@ if __name__ == "__main__":
     for num_parts in num_parts_list:
         for algo in algos:
             print(f"======== {num_parts} + {algo} ========")
-            partition_pyg(root, data, num_parts, algo)
\ No newline at end of file
+            data.save_partition(root, num_parts, algo)
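The updated cora.py builds a GraphData first and writes partitions through save_partition. For reference only (not part of this commit), reading one of those partitions back could look like the sketch below, using the GraphData.load_partition added in starrygl/graph/data.py further down; the root path is hypothetical.

from starrygl.graph import GraphData

root = "./dataset/cora"   # hypothetical; use the same root that was passed to save_partition
g = GraphData.load_partition(root, part_id=0, num_parts=3, algo="metis")
print(g.node("dst")["raw_ids"].numel())   # destination vertices owned by partition 0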
run_route.py

@@ -5,7 +5,7 @@ from torch import Tensor
 from typing import *
 
 from starrygl.distributed import DistributedContext
-from starrygl.graph import new_vc_route
+from starrygl.graph import *
 
 from torch_scatter import scatter_sum
@@ -28,32 +28,38 @@ all_eparts = [
     ],
 ]
 
-def get_route(bipartite: bool = True):
+def get_data():
     ctx = DistributedContext.get_default_context()
     assert ctx.world_size == 3
 
     dst_ids = torch.tensor(all_nparts[ctx.rank], dtype=torch.long, device=ctx.device)
     edge_index = torch.tensor(all_eparts[ctx.rank], dtype=torch.long, device=ctx.device).t()
-    return new_vc_route(dst_ids, edge_index, bipartite=bipartite)
+    src_ids, edge_index = init_vc_edge_index(dst_ids, edge_index)
+    return GraphData.from_bipartite(edge_index, raw_src_ids=src_ids, raw_dst_ids=dst_ids)
 
 if __name__ == "__main__":
     ctx = DistributedContext.init(backend="gloo", use_gpu=True)
-    src_ids, edge_index, dst_ids, route = get_route(False)
-    src_size = route.src_len
-    dst_size = route.dst_len
+    g = get_data()
+    route = g.to_route()
+    edge_index = g.edge_index()
+    # src_ids, edge_index, dst_ids, route = get_route(False)
+    # src_size = route.src_len
+    # dst_size = route.dst_len
 
     ctx.sync_print(route.src_len, route.dst_len)
+    ctx.sync_print(route._fw_ptr, route._fw_ind)
+    ctx.sync_print(route._bw_ptr, route._bw_ind)
 
     edge_ones = torch.ones(edge_index.size(1), device=ctx.device).requires_grad_()
     src_ones = scatter_sum(edge_ones, edge_index[0], dim=0, dim_size=route.src_len)
     dst_ones = scatter_sum(edge_ones, edge_index[1], dim=0, dim_size=route.dst_len)
 
-    # ctx.sync_print(route.fw_tensor(dst_ones))
-    # ctx.sync_print(route.bw_tensor(src_ones))
+    ctx.sync_print(route.fw_tensor(dst_ones))
+    ctx.sync_print(route.bw_tensor(src_ones))
 
-    out = route.reverse_route().apply(src_ones)
+    out = route.rev().apply(src_ones)
     ctx.sync_print(out)
 
     out.sum().backward()
@@ -61,4 +67,13 @@ if __name__ == "__main__":
     ctx.sync_print(route.get_src_part_ids())
 
+    dst_mask = torch.full((route.dst_len,), ctx.rank % 2, dtype=torch.bool, device=ctx.device)
+    ctx.main_print("=" * 64)
+    ctx.sync_print(dst_mask)
+    _, _, r2 = route.filter(dst_mask)
+    ctx.sync_print(r2.apply(dst_ones).detach())
+    ctx.sync_print(r2.rev().apply(src_ones).detach())
+
+    # dst_true = torch.ones(route.dst_len, dtype=torch.float, device=ctx.device)
+    # ctx.sync_print(route.fw_tensor(dst_true, "max"))
 
     ctx.shutdown()
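route.apply is the autograd entry point exercised above: RouteAlltoAll (in starrygl/graph/route.py below) forwards through fw_tensor and backpropagates through bw_tensor. A minimal sketch of that round trip, assuming a route built exactly as in get_data():

x = torch.ones(route.dst_len, device=ctx.device, requires_grad=True)
y = route.apply(x)     # forward: dst values pushed to the owning src copies (fw_tensor)
y.sum().backward()     # backward: gradients gathered back onto dst (bw_tensor)
# x.grad[i] should equal the number of src-side copies that destination vertex i was sent to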
starrygl/distributed/cclib.py

@@ -45,8 +45,6 @@ def all_to_all_v(
     assert len(output_tensor_list) == world_size
     assert len(input_tensor_list) == world_size
 
-    # if group is None:
-    #     group = dist.distributed_c10d._get_default_group()
     backend = dist.get_backend(group)
     if backend == "nccl":
starrygl/graph/__init__.py

-from .route import Route
-from .utils import init_vc_edge_index
-
-from torch import Tensor
-from typing import Tuple
-
-__all__ = [
-    "Route",
-    "init_vc_edge_index",
-    "new_vc_route",
-]
-
-def new_vc_route(
-    dst_ids: Tensor,
-    edge_index: Tensor,
-    bipartite: bool = True
-) -> Tuple[Tensor, Tensor, Tensor, Route]:
-    src_ids, local_edge_index = init_vc_edge_index(dst_ids, edge_index, bipartite=bipartite)
-    route = Route.from_raw_indices(src_ids, dst_ids, bipartite=bipartite)
-    return src_ids, local_edge_index, dst_ids, route
+from .data import *
+from .route import *
\ No newline at end of file
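The removed new_vc_route helper has no single replacement in the package root; its two steps survive separately (init_vc_edge_index now lives in starrygl/graph/data.py, Route.from_raw_indices in route.py). A rough equivalent, assuming from_raw_indices still accepts the bipartite keyword it took before this commit:

src_ids, local_edge_index = init_vc_edge_index(dst_ids, edge_index, bipartite=True)
route = Route.from_raw_indices(src_ids, dst_ids, bipartite=True)
# or, as run_route.py now does it:
route = GraphData.from_bipartite(local_edge_index, raw_src_ids=src_ids, raw_dst_ids=dst_ids).to_route()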
starrygl/utils/data.py → starrygl/graph/data.py

@@ -3,19 +3,18 @@ import torch
 from torch import Tensor
 from typing import *
 
 import os
 import os.path as osp
 import shutil
 
 from pathlib import Path
 from torch_sparse import SparseTensor
 
-from .partition import *
+from starrygl.utils.partition import *
+from .route import Route
 
 import logging
 
 __all__ = [
     "GraphData",
-    "partition_pyg",
-    "partition_load",
+    "init_vc_edge_index",
 ]
@@ -75,6 +74,17 @@ class GraphData:
             return data
         return self._edge_indices[edge_type]
 
+    def node_types(self) -> List[str]:
+        return list(self._node_data.keys())
+
+    def edge_types(self) -> List[Tuple[str, str, str]]:
+        return list(self._edge_data.keys())
+
+    def to_route(self, group: Any = None) -> Route:
+        src_ids = self.node("src")["raw_ids"]
+        dst_ids = self.node("dst")["raw_ids"]
+        return Route.from_raw_indices(src_ids, dst_ids, group=group)
+
     @property
     def is_heterogeneous(self) -> bool:
         return self._heterogeneous
@@ -91,6 +101,126 @@ class GraphData:
         self._edge_indices = {k: v.to(device) for k, v in self._edge_indices.items()}
         return self
 
+    @staticmethod
+    def from_bipartite(
+        edge_index: Tensor,
+        num_src_nodes: Optional[int] = None,
+        num_dst_nodes: Optional[int] = None,
+        raw_src_ids: Optional[Tensor] = None,
+        raw_dst_ids: Optional[Tensor] = None,
+    ) -> 'GraphData':
+        if num_src_nodes is None:
+            num_src_nodes = raw_src_ids.numel()
+        if num_dst_nodes is None:
+            num_dst_nodes = raw_dst_ids.numel()
+
+        g = GraphData(
+            edge_indices={
+                ("src", "@", "dst"): edge_index,
+            },
+            num_nodes={
+                "src": num_src_nodes,
+                "dst": num_dst_nodes,
+            }
+        )
+
+        if raw_src_ids is not None:
+            g.node("src")["raw_ids"] = raw_src_ids
+        if raw_dst_ids is not None:
+            g.node("dst")["raw_ids"] = raw_dst_ids
+        return g
+
+    @staticmethod
+    def from_pyg_data(data) -> 'GraphData':
+        from torch_geometric.data import Data
+        assert isinstance(data, Data), f"must be Data class in pyg"
+
+        g = GraphData(data.edge_index, data.num_nodes)
+        for key, val in data:
+            if key == "edge_index":
+                continue
+            elif isinstance(val, Tensor):
+                if val.size(0) == data.num_nodes:
+                    g.node()[key] = val
+                elif val.size(0) == data.num_edges:
+                    g.edge()[key] = val
+            elif isinstance(val, SparseTensor):
+                logging.warning(f"found sparse matrix {key}, but ignored.")
+            else:
+                g.meta()[key] = val
+        return g
+
+    @staticmethod
+    def load_partition(root: str, part_id: int, num_parts: int, algo: str = "metis") -> 'GraphData':
+        p = Path(root).expanduser().resolve() / f"{algo}_{num_parts}" / f"{part_id:03d}"
+        return torch.load(p.__str__())
+
+    def save_partition(self, root: str, num_parts: int, algo: str = "metis"):
+        assert not self.is_heterogeneous, "only support homomorphic graph"
+        num_nodes: int = self.node().num_nodes
+        edge_index: Tensor = self.edge_index()
+
+        logging.info(f"running partition aglorithm: {algo}")
+        if algo == "metis":
+            node_parts = metis_partition(edge_index, num_nodes, num_parts)
+        elif algo == "mt-metis":
+            node_parts = mt_metis_partition(edge_index, num_nodes, num_parts)
+        elif algo == "random":
+            node_parts = random_partition(edge_index, num_nodes, num_parts)
+        else:
+            raise ValueError(f"unknown partition algorithm: {algo}")
+
+        root_path = Path(root).expanduser().resolve()
+        base_path = root_path / f"{algo}_{num_parts}"
+
+        if base_path.exists():
+            logging.warning(f"directory '{base_path.__str__()}' exists, and will be removed.")
+            shutil.rmtree(base_path.__str__())
+        base_path.mkdir(parents=True)
+
+        for i in range(num_parts):
+            npart_mask = node_parts == i
+            epart_mask = npart_mask[edge_index[1]]
+
+            raw_dst_ids: Tensor = torch.where(npart_mask)[0]
+            local_edges = edge_index[:, epart_mask]
+            raw_src_ids, local_edges = init_vc_edge_index(
+                raw_dst_ids, local_edges, bipartite=True,
+            )
+
+            # g = GraphData(
+            #     edge_indices={
+            #         ("src", "@", "dst"): local_edges,
+            #     },
+            #     num_nodes={
+            #         "src": raw_src_ids.numel(),
+            #         "dst": raw_dst_ids.numel(),
+            #     }
+            # )
+            g = GraphData.from_bipartite(
+                local_edges,
+                raw_src_ids=raw_src_ids,
+                raw_dst_ids=raw_dst_ids,
+            )
+
+            for key in self.node().keys():
+                g.node("dst")[key] = self.node()[key][npart_mask]
+            for key in self.edge().keys():
+                g.edge()[key] = self.edge()[key][epart_mask]
+            for key in self.meta().keys():
+                g.meta()[key] = self.meta()[key]
+
+            logging.info(f"saving partition data: {i+1}/{num_parts}")
+            torch.save(g, (base_path / f"{i:03d}").__str__())
+
 class MetaData:
     def __init__(self) -> None:
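save_partition cuts the graph by destination: partition i keeps every edge whose destination it owns, while the edge's source may live on another partition. A tiny standalone illustration of npart_mask and epart_mask with made-up numbers (not from the repository):

import torch

edge_index = torch.tensor([[0, 2, 3], [1, 1, 2]])   # edges 0->1, 2->1, 3->2
node_parts = torch.tensor([0, 0, 1, 1])             # vertex -> partition id

i = 1
npart_mask = node_parts == i              # tensor([False, False,  True,  True])
epart_mask = npart_mask[edge_index[1]]    # keep edges whose destination is owned: [False, False, True]
raw_dst_ids = torch.where(npart_mask)[0]  # tensor([2, 3])
local_edges = edge_index[:, epart_mask]   # tensor([[3], [2]]), i.e. only the edge 3->2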
@@ -193,97 +323,137 @@ class EdgeData:
         return self
 
-def partition_load(root: str, part_id: int, num_parts: int, algo: str = "metis") -> GraphData:
-    p = Path(root).expanduser().resolve() / f"{algo}_{num_parts}" / f"{part_id:03d}"
-    return torch.load(p.__str__())
+def init_vc_edge_index(
+    dst_ids: Tensor,
+    edge_index: Tensor,
+    bipartite: bool = True,
+) -> Tuple[Tensor, Tensor]:
+    ikw = dict(dtype=torch.long, device=dst_ids.device)
+    local_num_nodes = torch.zeros(1, **ikw)
+
+    if dst_ids.numel() > 0:
+        local_num_nodes = dst_ids.max().max(local_num_nodes)
+    if edge_index.numel() > 0:
+        local_num_nodes = edge_index.max().max(local_num_nodes)
+    local_num_nodes = local_num_nodes.item() + 1
+
+    xmp: Tensor = torch.zeros(local_num_nodes, **ikw)
+    xmp[edge_index[1].unique()] += 0b01
+    xmp[dst_ids.unique()] += 0b10
+    if not (xmp != 0x01).all():
+        raise RuntimeError(f"must be vertex-cut partition graph")
+
+    if bipartite:
+        src_ids = edge_index[0].unique()
+    else:
+        xmp.fill_(0)
+        xmp[edge_index[0]] = 1
+        xmp[dst_ids] = 0
+        src_ids = torch.cat([dst_ids, torch.where(xmp > 0)[0]], dim=-1)
+
+    xmp.fill_((2**62-1)*2+1)
+    xmp[src_ids] = torch.arange(src_ids.size(0), **ikw)
+    src = xmp[edge_index[0]]
+
+    xmp.fill_((2**62-1)*2+1)
+    xmp[dst_ids] = torch.arange(dst_ids.size(0), **ikw)
+    dst = xmp[edge_index[1]]
+
+    local_edge_index = torch.vstack([src, dst])
+    return src_ids, local_edge_index
+
+# def partition_load(root: str, part_id: int, num_parts: int, algo: str = "metis") -> GraphData:
+#     p = Path(root).expanduser().resolve() / f"{algo}_{num_parts}" / f"{part_id:03d}"
+#     return torch.load(p.__str__())
 
-def partition_pyg(root: str, data, num_parts: int, algo: str = "metis"):
-    root_path = Path(root).expanduser().resolve()
-    base_path = root_path / f"{algo}_{num_parts}"
-    if base_path.exists():
-        shutil.rmtree(base_path.__str__())
-    base_path.mkdir(parents=True, exist_ok=True)
-
-    for i, g in enumerate(partition_pyg_data(data, num_parts, algo)):
-        logging.info(f"saving partition data: {i+1}/{num_parts}")
-        torch.save(g, (base_path / f"{i:03d}").__str__())
+# def partition_pyg(root: str, data, num_parts: int, algo: str = "metis"):
+#     root_path = Path(root).expanduser().resolve()
+#     base_path = root_path / f"{algo}_{num_parts}"
+#     if base_path.exists():
+#         shutil.rmtree(base_path.__str__())
+#     base_path.mkdir(parents=True, exist_ok=True)
+#     for i, g in enumerate(partition_pyg_data(data, num_parts, algo)):
+#         logging.info(f"saving partition data: {i+1}/{num_parts}")
+#         torch.save(g, (base_path / f"{i:03d}").__str__())
 
-def partition_pyg_data(data, num_parts: int, algo: str = "metis") -> Iterator[GraphData]:
-    from torch_geometric.data import Data
-    assert isinstance(data, Data), f"must be Data class in pyg"
-
-    logging.info(f"running partition aglorithm: {algo}")
-
-    num_nodes: int = data.num_nodes
-    num_edges: int = data.num_edges
-    edge_index: Tensor = data.edge_index
-
-    if algo == "metis":
-        node_parts = metis_partition(edge_index, num_nodes, num_parts)
-    elif algo == "mt-metis":
-        node_parts = mt_metis_partition(edge_index, num_nodes, num_parts)
-    elif algo == "random":
-        node_parts = random_partition(edge_index, num_nodes, num_parts)
-    else:
-        raise ValueError(f"unknown partition algorithm: {algo}")
-
-    if data.y.dtype == torch.long:
-        if data.y.dim() == 1:
-            num_classes = data.y.max().item() + 1
-        else:
-            num_classes = data.y.size(1)
-    else:
-        num_classes = None
-
-    for i in range(num_parts):
-        npart_mask = node_parts == i
-        epart_mask = npart_mask[edge_index[1]]
-
-        local_edges = edge_index[:, epart_mask]
-        raw_src_ids: Tensor = local_edges[0].unique()
-        raw_dst_ids: Tensor = torch.where(npart_mask)[0]
-
-        M: int = raw_src_ids.max().item() + 1
-        imap = torch.full((M,), (2**62-1)*2+1).type_as(raw_src_ids)
-        imap[raw_src_ids] = torch.arange(raw_src_ids.numel()).type_as(raw_src_ids)
-        local_src = imap[local_edges[0]]
-
-        M: int = raw_dst_ids.max().item() + 1
-        imap = torch.full((M,), (2**62-1)*2+1).type_as(raw_dst_ids)
-        imap[raw_dst_ids] = torch.arange(raw_dst_ids.numel()).type_as(raw_dst_ids)
-        local_dst = imap[local_edges[1]]
-
-        local_edges = torch.vstack([local_src, local_dst])
-
-        g = GraphData(
-            edge_indices={
-                ("src", "@", "dst"): local_edges,
-            },
-            num_nodes={
-                "src": raw_src_ids.numel(),
-                "dst": raw_dst_ids.numel(),
-            },
-        )
-        g.node("src")["raw_ids"] = raw_src_ids
-        g.node("dst")["raw_ids"] = raw_dst_ids
-
-        if num_classes is not None:
-            g.meta()["num_classes"] = num_classes
-
-        for key, val in data:
-            if key == "edge_index":
-                continue
-            elif isinstance(val, Tensor):
-                if val.size(0) == num_nodes:
-                    g.node("dst")[key] = val[npart_mask]
-                elif val.size(0) == num_edges:
-                    g.edge()[key] = val[epart_mask]
-            elif isinstance(val, SparseTensor):
-                pass
-            else:
-                g.meta()[key] = val
-
-        yield g
+# def partition_pyg_data(data, num_parts: int, algo: str = "metis") -> Iterator[GraphData]:
+#     from torch_geometric.data import Data
+#     assert isinstance(data, Data), f"must be Data class in pyg"
+#     logging.info(f"running partition aglorithm: {algo}")
+#     num_nodes: int = data.num_nodes
+#     num_edges: int = data.num_edges
+#     edge_index: Tensor = data.edge_index
+#     if algo == "metis":
+#         node_parts = metis_partition(edge_index, num_nodes, num_parts)
+#     elif algo == "mt-metis":
+#         node_parts = mt_metis_partition(edge_index, num_nodes, num_parts)
+#     elif algo == "random":
+#         node_parts = random_partition(edge_index, num_nodes, num_parts)
+#     else:
+#         raise ValueError(f"unknown partition algorithm: {algo}")
+#     if data.y.dtype == torch.long:
+#         if data.y.dim() == 1:
+#             num_classes = data.y.max().item() + 1
+#         else:
+#             num_classes = data.y.size(1)
+#     else:
+#         num_classes = None
+#     for i in range(num_parts):
+#         npart_mask = node_parts == i
+#         epart_mask = npart_mask[edge_index[1]]
+#         local_edges = edge_index[:, epart_mask]
+#         raw_src_ids: Tensor = local_edges[0].unique()
+#         raw_dst_ids: Tensor = torch.where(npart_mask)[0]
+#         M: int = raw_src_ids.max().item() + 1
+#         imap = torch.full((M,), (2**62-1)*2+1).type_as(raw_src_ids)
+#         imap[raw_src_ids] = torch.arange(raw_src_ids.numel()).type_as(raw_src_ids)
+#         local_src = imap[local_edges[0]]
+#         M: int = raw_dst_ids.max().item() + 1
+#         imap = torch.full((M,), (2**62-1)*2+1).type_as(raw_dst_ids)
+#         imap[raw_dst_ids] = torch.arange(raw_dst_ids.numel()).type_as(raw_dst_ids)
+#         local_dst = imap[local_edges[1]]
+#         local_edges = torch.vstack([local_src, local_dst])
+#         g = GraphData(
+#             edge_indices={
+#                 ("src", "@", "dst"): local_edges,
+#             },
+#             num_nodes={
+#                 "src": raw_src_ids.numel(),
+#                 "dst": raw_dst_ids.numel(),
+#             },
+#         )
+#         g.node("src")["raw_ids"] = raw_src_ids
+#         g.node("dst")["raw_ids"] = raw_dst_ids
+#         if num_classes is not None:
+#             g.meta()["num_classes"] = num_classes
+#         for key, val in data:
+#             if key == "edge_index":
+#                 continue
+#             elif isinstance(val, Tensor):
+#                 if val.size(0) == num_nodes:
+#                     g.node("dst")[key] = val[npart_mask]
+#                 elif val.size(0) == num_edges:
+#                     g.edge()[key] = val[epart_mask]
+#             elif isinstance(val, SparseTensor):
+#                 pass
+#             else:
+#                 g.meta()[key] = val
+#         yield g
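init_vc_edge_index validates the vertex-cut property with a small bitmask: bit 0 marks vertices that appear as an edge destination, bit 1 marks vertices this partition owns, and a final value of exactly 0b01 means some edge targets an unowned vertex. A short sketch with made-up ids showing the check fire:

import torch

dst_ids = torch.tensor([0, 1, 2])             # vertices owned by this partition
edge_index = torch.tensor([[3, 4], [1, 5]])   # edge 4->5 targets an unowned vertex

xmp = torch.zeros(6, dtype=torch.long)
xmp[edge_index[1].unique()] += 0b01
xmp[dst_ids.unique()] += 0b10
print((xmp != 0x01).all())                    # tensor(False) -> init_vc_edge_index would raise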
starrygl/graph/route.py

@@ -2,13 +2,18 @@ import torch
 import torch.autograd as autograd
 import torch.distributed as dist
 
 from multiprocessing.pool import ThreadPool
 
 from torch import Tensor
 from typing import *
 
 from starrygl.distributed import DistributedContext
+from starrygl.distributed.cclib import all_to_all_s, all_to_all_v
 
 __all__ = [
     "Route",
+    "RouteWork",
+    "RouteWorkCache",
+    "RouteAlltoAll",
 ]
 
 class Route:
@@ -31,50 +36,51 @@ class Route:
             group=group,
         )
 
-    def subroute(self,
-        dst_mask: Tensor,
-        src_mask: Optional[Tensor] = None,
-    ):
-        if dst_mask.dtype != torch.bool:
-            dst_mask = Route.__expand_index_to_mask(dst_mask, self.dst_len)
-        assert dst_mask.size(0) == self.dst_len
-
-        if src_mask is None:
-            fw_dst_masks, work = self.all_to_all_fw(
-                input_tensor_list=[dst_mask[self.fw_table(i)[0]] for i in range(self.num_parts)],
-                async_op=True,
-            )
-        else:
-            if src_mask.dtype != torch.bool:
-                src_mask = Route.__expand_index_to_mask(src_mask, self.src_len)
-            assert src_mask.size(0) == self.src_len
-
-        fw_tables: List[Tensor] = []
-        for i in range(self.num_parts):
-            m = dst_mask[self.fw_table(i)[0]]
-            fw_tables.append(self.fw_table(i)[:, m])
-
-        bw_tables: List[Tensor] = []
-        if src_mask is None:
-            src_mask = torch.zeros(self.src_len, dtype=dst_mask.dtype, device=dst_mask.device)
-
-            work.wait()
-            for i, m in enumerate(fw_dst_masks):
-                src_mask[self.bw_table(i)[0]] |= m
-                bw_tables.append(self.bw_table(i)[:, m])
-        else:
-            for i in range(self.num_parts):
-                m = src_mask[self.bw_table(i)[0]]
-                bw_tables.append(self.bw_table(i)[:, m])
-
-        return dst_mask, Route(
-            src_len=self.src_len,
-            dst_len=self.dst_len,
-            **Route.__tables_to_indptr(fw_tables, bw_tables),
-            group=self.group,
-        ), src_mask, dst_mask
+    def filter(self,
+        dst_mask: Optional[Tensor] = None,
+        src_mask: Optional[Tensor] = None,
+        remap: bool = False,
+    ):
+        if dst_mask is None:
+            if src_mask is None:
+                raise ValueError("please provide at least one parameter.")
+            else:
+                assert src_mask.dtype == torch.bool
+                assert src_mask.numel() == self.src_len
+                dst_mask = self.bw_tensor(src_mask.long()) != 0
+        else:
+            assert dst_mask.dtype == torch.bool
+            assert dst_mask.numel() == self.dst_len
+            tmp_src_mask = self.fw_tensor(dst_mask.long()) != 0
+            if src_mask is None:
+                src_mask = tmp_src_mask
+            else:
+                tmp_src_mask &= src_mask
+                src_mask = tmp_src_mask
+                dst_mask = self.bw_tensor(src_mask.long()) != 0
+
+        fw_ptr, fw_ind = Route.__filter_ind_and_ptr(self._fw_ptr, self._fw_ind, dst_mask)
+        bw_ptr, bw_ind = Route.__filter_ind_and_ptr(self._bw_ptr, self._bw_ind, src_mask)
+
+        route = Route(
+            src_len=self.src_len,
+            dst_len=self.dst_len,
+            fw_ptr=fw_ptr, fw_ind=fw_ind,
+            bw_ptr=bw_ptr, bw_ind=bw_ind,
+            group=self.group,
+        )
+
+        if remap:
+            fw_ind, dst_len = Route.__remap_ind(route._fw_ind, dst_mask)
+            bw_ind, src_len = Route.__remap_ind(route._bw_ind, src_mask)
+            route = Route(
+                src_len=src_len,
+                dst_len=dst_len,
+                fw_ptr=route._fw_ptr, fw_ind=fw_ind,
+                bw_ptr=route._bw_ptr, bw_ind=bw_ind,
+                group=route.group,
+            )
+        return dst_mask, src_mask, route
 
-    def reverse_route(self):
+    def rev(self):
         return Route(
             src_len=self.dst_len,
             dst_len=self.src_len,
@@ -83,62 +89,21 @@ class Route:
             group=self.group,
         )
 
-    def filter_nodes(self,
-        dst_mask: Tensor,
-        src_mask: Optional[Tensor] = None,
-    ):
-        if dst_mask.dtype != torch.bool:
-            dst_mask = Route.__expand_index_to_mask(dst_mask, self.dst_len)
-        assert dst_mask.size(0) == self.dst_len
-
-        new_dst_len = dst_mask.count_nonzero().item()
-        xmp = torch.empty(self.dst_len, dtype=torch.long, device=dst_mask.device)
-        xmp.fill_((2**62-1)*2+1)
-        xmp[dst_mask] = torch.arange(new_dst_len, dtype=torch.long, device=dst_mask.device)
-
-        fw_dst_masks, fw_m_work = self.all_to_all_fw(
-            input_tensor_list=[dst_mask[self.fw_table(i)[0]] for i in range(self.num_parts)],
-            async_op=True,
-        )
-        # fw_dst_inds, fw_i_work =
-        # if src_mask is None:
-        #     fw_dst_masks, work = self.all_to_all_fw(
-        #         input_tensor_list=[dst_mask[self.fw_table(i)[0]] for i in range(self.num_parts)],
-        #         async_op=True,
-        #     )
-        # else:
-        #     if src_mask.dtype != torch.bool:
-        #         src_mask = Route.__expand_index_to_mask(src_mask, self.src_len)
-        #     assert src_mask.size(0) == self.src_len
-
-    def filter_edges(self,
-        edge_mask: Tensor,
-        edge_index: Tensor,
-    ):
-        pass
-
     def __init__(self,
         src_len: int,
         dst_len: int,
         fw_ptr: List[int],
         fw_ind: Tensor,
         bw_ptr: List[int],
         bw_ind: Tensor,
         group: Any,
     ) -> None:
         self._ctx = DistributedContext.get_default_context()
         assert len(fw_ptr) == len(bw_ptr)
         self._src_len = src_len
         self._dst_len = dst_len
-        self._fw_ptr = fw_ptr
+        self._fw_ptr = tuple(fw_ptr)
         self._fw_ind = fw_ind
-        self._bw_ptr = bw_ptr
+        self._bw_ptr = tuple(bw_ptr)
         self._bw_ind = bw_ind
         self._group = group
 
     @property
     def ctx(self):
         return self._ctx
 
     @property
     def group(self):
         return self._group
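Route.filter replaces the old subroute/filter_nodes pair. As run_route.py shows, it returns the (possibly derived) destination mask, the matching source mask, and a pruned Route; with remap=True the index tables appear to be re-numbered to the compacted lengths. A short sketch under that reading, reusing the names from run_route.py:

dst_mask2, src_mask2, sub = route.filter(dst_mask=dst_mask)      # same lengths, pruned send/recv tables
_, _, sub_small = route.filter(dst_mask=dst_mask, remap=True)    # lengths compacted to the kept vertices
out = sub.apply(dst_ones)                                        # communicate only the selected vertices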
@@ -164,104 +129,84 @@ class Route:
         self._bw_ind = self._bw_ind.to(device)
         return self
 
-    def fw_table(self, i: int):
-        return self._fw_ind[:, self._fw_ptr[i]:self._fw_ptr[i+1]]
+    # def fw_table(self, i: int):
+    #     return self._fw_ind[self._fw_ptr[i]:self._fw_ptr[i+1]]
 
-    def bw_table(self, i: int):
-        return self._bw_ind[:, self._bw_ptr[i]:self._bw_ptr[i+1]]
+    # def bw_table(self, i: int):
+    #     return self._bw_ind[self._bw_ptr[i]:self._bw_ptr[i+1]]
 
-    def apply(self, data: Tensor) -> Tensor:
-        return RouteAlltoAll.apply(data, self)
+    def apply(self,
+        data: Tensor,
+        cache: Optional['RouteWorkCache'] = None,
+        cache_key: Optional[str] = None,
+    ) -> Tensor:
+        return RouteAlltoAll.apply(data, self, cache, cache_key)
 
-    def fw_tensor(self, data: Tensor) -> Tensor:
+    @torch.no_grad()
+    def fw_tensor(self, data: Tensor, async_op: bool = False):
         assert data.size(0) == self.dst_len
-        xs = self.all_to_all_fw(
-            [data[self.fw_table(i)[0]] for i in range(self.num_parts)],
-            async_op=False,
-        )
-        out = torch.zeros(self.src_len, *data.shape[1:],
-            dtype=data.dtype, device=data.device,
-        )
-        for i, t in enumerate(xs):
-            out[self.bw_table(i)[0]] += t
-        return out
+        output_tensor = torch.empty(
+            self._bw_ind.numel(), *data.shape[1:],
+            dtype=data.dtype, device=data.device,
+        )
+        work = all_to_all_s(
+            output_tensor, data[self._fw_ind],
+            self._bw_ptr, self._fw_ptr,
+            group=self.group, async_op=async_op,
+        )
+        work = RouteWork(
+            work if async_op else None,
+            self._bw_ptr, self._bw_ind,
+            self.src_len, output_tensor,
+        )
+        return work if async_op else work.wait()
 
-    def bw_tensor(self, data: Tensor) -> Tensor:
+    @torch.no_grad()
+    def bw_tensor(self, data: Tensor, async_op: bool = False):
         assert data.size(0) == self.src_len
-        xs = self.all_to_all_bw(
-            [data[self.bw_table(i)[0]] for i in range(self.num_parts)],
-            async_op=False,
-        )
-        out = torch.zeros(self.dst_len, *data.shape[1:],
-            dtype=data.dtype, device=data.device,
-        )
-        for i, t in enumerate(xs):
-            out[self.fw_table(i)[0]] += t
-        return out
+        output_tensor = torch.empty(
+            self._fw_ind.numel(), *data.shape[1:],
+            dtype=data.dtype, device=data.device,
+        )
+        work = all_to_all_s(
+            output_tensor, data[self._bw_ind],
+            self._fw_ptr, self._bw_ptr,
+            group=self.group, async_op=async_op,
+        )
+        work = RouteWork(
+            work if async_op else None,
+            self._fw_ptr, self._fw_ind,
+            self.dst_len, output_tensor,
+        )
+        return work if async_op else work.wait()
 
-    def get_src_part_ids(self) -> Tensor:
-        xs = self.all_to_all_fw(
-            [
-                torch.full(
-                    (self.fw_table(i).size(1),), self.part_id,
-                    dtype=torch.long, device=self.ctx.device,
-                )
-                for i in range(self.num_parts)
-            ],
-            async_op=False,
-        )
-        out = torch.full((self.src_len,), 2**16-1, dtype=torch.long, device=self.ctx.device)
-        for i, t in enumerate(xs):
-            out[self.bw_table(i)[0]] = t
-        return out.int() & 0xFFFF
-
-    def all_to_all_fw(self, input_tensor_list: List[Tensor], async_op: bool = False):
-        output_tensor_list: List[Tensor] = []
-        for i in range(self.num_parts):
-            t = input_tensor_list[i]
-            assert t.size(0) == self._fw_ptr[i+1] - self._fw_ptr[i]
-            s = self._bw_ptr[i+1] - self._bw_ptr[i]
-            output_tensor_list.append(
-                torch.empty(s, *t.shape[1:], dtype=t.dtype, device=t.device)
-            )
-        work = self.ctx.all_to_all_v(
-            output_tensor_list, input_tensor_list,
-            group=self.group, async_op=async_op,
-        )
-        if async_op:
-            return output_tensor_list, work
-        else:
-            return output_tensor_list
-
-    def all_to_all_bw(self, input_tensor_list: List[Tensor], async_op: bool = False):
-        output_tensor_list: List[Tensor] = []
-        for i in range(self.num_parts):
-            t = input_tensor_list[i]
-            assert t.size(0) == self._bw_ptr[i+1] - self._bw_ptr[i]
-            s = self._fw_ptr[i+1] - self._fw_ptr[i]
-            output_tensor_list.append(
-                torch.empty(s, *t.shape[1:], dtype=t.dtype, device=t.device)
-            )
-        work = self.ctx.all_to_all_v(
-            output_tensor_list, input_tensor_list,
-            group=self.group, async_op=async_op,
-        )
-        if async_op:
-            return output_tensor_list, work
-        else:
-            return output_tensor_list
+    @torch.no_grad()
+    def get_src_part_ids(self) -> Tensor:
+        input_tensor = torch.full_like(self._fw_ind, self.part_id)
+        output_tensor = torch.empty_like(self._bw_ind)
+        all_to_all_s(
+            output_tensor, input_tensor,
+            self._bw_ptr, self._fw_ptr,
+            group=self.group,
+        )
+        out = torch.full(
+            (self.src_len,), 2**16-1,
+            dtype=self._bw_ind.dtype, device=self._bw_ind.device,
+        )
+        for s, t in zip(self._bw_ptr, self._bw_ptr[1:]):
+            ind = self._bw_ind[s:t]
+            assert (out[ind] == 2**16-1).all(), f"some vertices exist in more than one partition"
+            out[ind] = output_tensor[s:t] & 0xFF
+        return out
 
     @staticmethod
     def _build_route_tables(
@@ -275,8 +220,6 @@
     assert src_ids.dim() == 1
     assert dst_ids.dim() == 1
 
-    ctx = DistributedContext.get_default_context()
-
     src_len = src_ids.size(0)
     dst_len = dst_ids.size(0)
@@ -292,6 +235,7 @@
     all_dst_lens: List[int] = [None] * world_size
     dist.all_gather_object(all_dst_lens, dst_len, group=group)
 
+    # all_reduce number of nodes
     num_nodes = torch.zeros(1, **ikw)
     if src_ids.numel() > 0:
         num_nodes = src_ids.max().max(num_nodes)
@@ -342,25 +286,54 @@
         src_ind = smp[ind]
         dst_ind = xmp[ind]
 
         # at this point only bw_route is correct; fw_route still has to be sent to the partition that owns src_ids
-        fw_tables.append(torch.vstack([dst_ind, src_ind]))
-        bw_tables.append(torch.vstack([src_ind, dst_ind]))
-
-    fw_tables = [t.t().contiguous() for t in fw_tables]
-    fw_tables = ctx.all_to_all_g(fw_tables, group=group)
-    fw_tables = [t.t().contiguous() for t in fw_tables]
+        fw_tables.append(dst_ind)
+        bw_tables.append(src_ind)
+
+    fw_tables = Route.__backward_fw_tables(fw_tables, group=group)
 
-    # 非二分图,每个点添加自环
+    # add self-loops if not bipartite graph
     if not bipartite:
         rank_ind = torch.arange(dst_len, **ikw)
-        fw_tables[rank] = bw_tables[rank] = torch.vstack([rank_ind, rank_ind])
+        fw_tables[rank] = bw_tables[rank] = rank_ind
 
     return fw_tables, bw_tables
 
     @staticmethod
     def __expand_index_to_mask(index: Tensor, len: int) -> Tensor:
         mask = torch.zeros(len, dtype=torch.bool, device=index.device)
         mask[index] = True
         return mask
 
+    def __filter_ind_and_ptr(ptr: List[int], ind: Tensor, mask: Tensor) -> Tuple[List[int], Tensor]:
+        m = mask[ind]
+        new_ptr: List[int] = [0]
+        new_ind: List[Tensor] = []
+        for s, t in zip(ptr, ptr[1:]):
+            new_ind.append(ind[s:t][m[s:t]])
+            new_ptr.append(new_ptr[-1] + new_ind[-1].numel())
+        return new_ptr, torch.cat(new_ind, dim=0)
+
+    @staticmethod
+    def __remap_ind(ind: Tensor, mask: Tensor) -> Tuple[Tensor, int]:
+        n: int = mask.count_nonzero().item()
+        imp = torch.full((mask.numel(),), (2**62-1)*2+1, dtype=ind.dtype, device=ind.device)
+        imp[mask] = torch.arange(n, dtype=ind.dtype, device=ind.device)
+        return ind, int(n)
+
+    @staticmethod
+    def __backward_fw_tables(
+        fw_tables: List[Tensor],
+        group: Any,
+    ) -> List[Tensor]:
+        rank = dist.get_rank(group)
+        world_size = dist.get_world_size(group)
+
+        send_sizes = [t.size() for t in fw_tables]
+        recv_sizes = [None] * world_size
+        dist.all_gather_object(recv_sizes, send_sizes, group=group)
+        recv_sizes = [s[rank] for s in recv_sizes]
+
+        fixed_tables = []
+        for s, t in zip(recv_sizes, fw_tables):
+            t = torch.empty(*s, dtype=t.dtype, device=t.device)
+            fixed_tables.append(t)
+        all_to_all_v(fixed_tables, fw_tables, group=group)
+        return fixed_tables
@@ -369,21 +342,101 @@
 ):
     fw_ptr: List[int] = [0]
     for t in fw_tables:
-        last_n = fw_ptr[-1]
-        fw_ptr.append(last_n + t.size(1))
-    fw_ind = torch.cat(fw_tables, dim=1)
+        assert t.dim() == 1
+        fw_ptr.append(fw_ptr[-1] + t.numel())
+    fw_ind = torch.cat(fw_tables, dim=0)
 
     bw_ptr: List[int] = [0]
     for t in bw_tables:
-        last_n = bw_ptr[-1]
-        bw_ptr.append(last_n + t.size(1))
-    bw_ind = torch.cat(bw_tables, dim=1)
+        assert t.dim() == 1
+        bw_ptr.append(bw_ptr[-1] + t.numel())
+    bw_ind = torch.cat(bw_tables, dim=0)
 
     return {
         "fw_ptr": fw_ptr,
         "bw_ptr": bw_ptr,
         "fw_ind": fw_ind,
         "bw_ind": bw_ind,
     }
 
+class RouteWork:
+    def __init__(self,
+        work: Any,
+        ptr: List[int],
+        ind: Tensor,
+        len: int,
+        recv_t: Tensor,
+    ) -> None:
+        self._work = work
+        self._ptr = ptr
+        self._ind = ind
+        self._len = len
+        self._recv_t = recv_t
+
+        if self._work is None:
+            self._reduce()
+
+    def _reduce(self):
+        out = torch.zeros(
+            self._len, *self._recv_t.shape[1:],
+            dtype=self._recv_t.dtype,
+            device=self._recv_t.device,
+        )
+        for s, t in zip(self._ptr, self._ptr[1:]):
+            ind = self._ind[s:t]
+            out[ind] += self._recv_t[s:t]
+
+        self._work = None
+        self._ptr = None
+        self._ind = None
+        self._len = None
+        self._recv_t = out
+
+    def wait(self) -> Tensor:
+        if self._work is None:
+            return self._recv_t
+        self._work.wait()
+        self._reduce()
+        return self._recv_t
+
+class RouteWorkCache:
+    def __init__(self,
+        enable_fw: bool = True,
+        enable_bw: bool = True,
+    ) -> None:
+        self.enable_fw = enable_fw
+        self.enable_bw = enable_bw
+        self._cached_works: Dict[str, RouteWork] = {}
+
+    def enable_fw_(self, enable: bool = True):
+        self.enable_fw = enable
+        return self
+
+    def enable_bw_(self, enable: bool = True):
+        self.enable_bw = enable
+        return self
+
+    def wait(self):
+        for work in self._cached_works.values():
+            work.wait()
+
+    def clear(self):
+        self._cached_works.clear()
+
+    def get_and_set(self,
+        key: str,
+        work: RouteWork,
+        bw: bool = False,
+    ) -> Optional[RouteWork]:
+        if bw and self.enable_bw:
+            key = key + "_bw"
+        elif not bw and self.enable_fw:
+            key = key + "_fw"
+        else:
+            return work
+
+        t = self._cached_works.get(key, work)
+        self._cached_works[key] = work
+        return t
 
 class RouteAlltoAll(autograd.Function):
     @staticmethod
@@ -391,10 +444,18 @@ class RouteAlltoAll(autograd.Function):
         ctx: autograd.function.FunctionCtx,
         x: Tensor,
         route: Route,
+        cache: Optional[RouteWorkCache],
+        cache_key: Optional[str],
     ):
         ctx.saved_route = route
-        return route.fw_tensor(x)
+        ctx.saved_cache = cache
+        ctx.saved_cache_key = cache_key
+        if cache is None or cache_key is None:
+            return route.fw_tensor(x)
+        else:
+            work = route.fw_tensor(x, async_op=True)
+            work = cache.get_and_set(cache_key, work, bw=False)
+            return work.wait()
 
     @staticmethod
     def backward(
@@ -402,4 +463,13 @@ class RouteAlltoAll(autograd.Function):
         grad: Tensor,
     ) -> Tensor:
         route: Route = ctx.saved_route
-        return route.bw_tensor(grad), None
\ No newline at end of file
+        cache: Optional[RouteWorkCache] = ctx.saved_cache
+        cache_key: Optional[str] = ctx.saved_cache_key
+        if cache is None or cache_key is None:
+            return route.bw_tensor(grad), None, None, None
+        else:
+            work = route.bw_tensor(grad, async_op=True)
+            work = cache.get_and_set(cache_key, work, bw=True)
+            return work.wait(), None, None, None
\ No newline at end of file
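fw_tensor and bw_tensor are now thin wrappers around all_to_all_s plus a RouteWork that performs the local reduce, which is what makes the async_op path possible. A minimal sketch of driving it directly, assuming an already-built route and a [dst_len, d] feature tensor named dst_feat:

work = route.fw_tensor(dst_feat, async_op=True)   # launches the all-to-all and returns a RouteWork
# ... do independent local computation here to overlap with communication ...
src_feat = work.wait()                            # reduces the received chunks into a [src_len, d] tensor

RouteWorkCache, when passed to apply(data, cache, cache_key), appears to store each newly issued RouteWork under the key and hand back the one cached there previously, so repeated calls with the same key can overlap communication across iterations at the price of consuming the previous call's result.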
starrygl/graph/utils.py (deleted, 100644 → 0)

import torch
import torch.distributed as dist

from torch import Tensor
from typing import *

def init_vc_edge_index(
    dst_ids: Tensor,
    edge_index: Tensor,
    bipartite: bool = True,
) -> Tuple[Tensor, Tensor]:
    ikw = dict(dtype=torch.long, device=dst_ids.device)
    local_num_nodes = torch.zeros(1, **ikw)

    if dst_ids.numel() > 0:
        local_num_nodes = dst_ids.max().max(local_num_nodes)
    if edge_index.numel() > 0:
        local_num_nodes = edge_index.max().max(local_num_nodes)
    local_num_nodes = local_num_nodes.item() + 1

    xmp: Tensor = torch.zeros(local_num_nodes, **ikw)
    xmp[edge_index[1].unique()] += 0b01
    xmp[dst_ids.unique()] += 0b10
    if not (xmp != 0x01).all():
        raise RuntimeError(f"must be vertex-cut partition graph")

    if bipartite:
        src_ids = edge_index[0].unique()
    else:
        xmp.fill_(0)
        xmp[edge_index[0]] = 1
        xmp[dst_ids] = 0
        src_ids = torch.cat([dst_ids, torch.where(xmp > 0)[0]], dim=-1)

    xmp.fill_((2**62-1)*2+1)
    xmp[src_ids] = torch.arange(src_ids.size(0), **ikw)
    src = xmp[edge_index[0]]

    xmp.fill_((2**62-1)*2+1)
    xmp[dst_ids] = torch.arange(dst_ids.size(0), **ikw)
    dst = xmp[edge_index[1]]

    local_edge_index = torch.vstack([src, dst])
    return src_ids, local_edge_index