Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
B
BTS-MTGNN
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
zhlj
BTS-MTGNN
Commits
542831ac
Commit
542831ac
authored
Dec 23, 2023
by
Wenjie Huang
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
add constraint params to GraphData.save_partition
parent
150eafe6
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
91 additions
and
16 deletions
+91
-16
starrygl/data/graph.py
+90
-15
train_hybrid.py
+1
-1
No files found.
starrygl/data/graph.py
View file @
542831ac
...
...
@@ -20,6 +20,9 @@ __all__ = [
]
Strings
=
Sequence
[
str
]
OptStrings
=
Optional
[
Strings
]
class
GraphData
:
def
__init__
(
self
,
edge_indices
:
Union
[
Tensor
,
Dict
[
Tuple
[
str
,
str
,
str
],
Tensor
]],
...
...
@@ -168,33 +171,99 @@ class GraphData:
return
g
@staticmethod
def
load_partition
(
root
:
str
,
part_id
:
int
,
num_parts
:
int
,
algo
:
str
=
"metis"
)
->
'GraphData'
:
p
=
Path
(
root
)
.
expanduser
()
.
resolve
()
/
f
"{algo}_{num_parts}"
/
f
"{part_id:03d}"
def
load_partition
(
root
:
str
,
part_id
:
int
,
num_parts
:
int
,
algorithm
:
str
=
"metis"
,
)
->
'GraphData'
:
p
=
Path
(
root
)
.
expanduser
()
.
resolve
()
/
f
"{algorithm}_{num_parts}"
/
f
"{part_id:03d}"
return
torch
.
load
(
p
.
__str__
())
def
save_partition
(
self
,
root
:
str
,
num_parts
:
int
,
algo
:
str
=
"metis"
):
def
save_partition
(
self
,
root
:
str
,
num_parts
:
int
,
node_weight
:
Optional
[
str
]
=
None
,
edge_weight
:
Optional
[
str
]
=
None
,
include_node_attrs
:
Optional
[
Sequence
[
str
]]
=
None
,
include_edge_attrs
:
Optional
[
Sequence
[
str
]]
=
None
,
include_meta_attrs
:
Optional
[
Sequence
[
str
]]
=
None
,
ignore_node_attrs
:
Optional
[
Sequence
[
str
]]
=
None
,
ignore_edge_attrs
:
Optional
[
Sequence
[
str
]]
=
None
,
ignore_meta_attrs
:
Optional
[
Sequence
[
str
]]
=
None
,
algorithm
:
str
=
"metis"
,
partition_kwargs
=
None
,
):
assert
not
self
.
is_heterogeneous
,
"only support homomorphic graph"
num_nodes
:
int
=
self
.
node
()
.
num_nodes
edge_index
:
Tensor
=
self
.
edge_index
()
logging
.
info
(
f
"running partition aglorithm: {algo}"
)
if
algo
==
"metis"
:
node_parts
=
metis_partition
(
edge_index
,
num_nodes
,
num_parts
)
elif
algo
==
"mt-metis"
:
node_parts
=
mt_metis_partition
(
edge_index
,
num_nodes
,
num_parts
)
elif
algo
==
"random"
:
node_parts
=
random_partition
(
edge_index
,
num_nodes
,
num_parts
)
logging
.
info
(
f
"running partition aglorithm: {algorithm}"
)
partition_kwargs
=
partition_kwargs
or
{}
if
node_weight
is
not
None
:
node_weight
=
self
.
node
()[
node_weight
]
if
edge_weight
is
not
None
:
edge_weight
=
self
.
edge
()[
edge_weight
]
if
algorithm
==
"metis"
:
node_parts
=
metis_partition
(
edge_index
,
num_nodes
,
num_parts
,
node_weight
=
node_weight
,
edge_weight
=
edge_weight
,
**
partition_kwargs
,
)
elif
algorithm
==
"mt-metis"
:
node_parts
=
mt_metis_partition
(
edge_index
,
num_nodes
,
num_parts
,
node_weight
=
node_weight
,
edge_weight
=
edge_weight
,
**
partition_kwargs
,
)
elif
algorithm
==
"random"
:
node_parts
=
random_partition
(
edge_index
,
num_nodes
,
num_parts
,
**
partition_kwargs
,
)
else
:
raise
ValueError
(
f
"unknown partition algorithm: {algo}"
)
raise
ValueError
(
f
"unknown partition algorithm: {algo
rithm
}"
)
root_path
=
Path
(
root
)
.
expanduser
()
.
resolve
()
base_path
=
root_path
/
f
"{algo}_{num_parts}"
base_path
=
root_path
/
f
"{algo
rithm
}_{num_parts}"
if
base_path
.
exists
():
logging
.
warning
(
f
"directory '{base_path.__str__()}' exists, and will be removed."
)
shutil
.
rmtree
(
base_path
.
__str__
())
base_path
.
mkdir
(
parents
=
True
)
if
include_node_attrs
is
None
:
include_node_attrs
=
self
.
node
()
.
keys
()
if
include_edge_attrs
is
None
:
include_edge_attrs
=
self
.
edge
()
.
keys
()
if
include_meta_attrs
is
None
:
include_meta_attrs
=
self
.
meta
()
.
keys
()
if
ignore_node_attrs
is
None
:
ignore_node_attrs
=
set
()
else
:
ignore_node_attrs
=
set
(
ignore_node_attrs
)
if
ignore_edge_attrs
is
None
:
ignore_edge_attrs
=
set
()
else
:
ignore_edge_attrs
=
set
(
ignore_edge_attrs
)
if
ignore_meta_attrs
is
None
:
ignore_meta_attrs
=
set
()
else
:
ignore_meta_attrs
=
set
(
ignore_meta_attrs
)
for
i
in
range
(
num_parts
):
npart_mask
=
node_parts
==
i
...
...
@@ -213,13 +282,19 @@ class GraphData:
raw_dst_ids
=
raw_dst_ids
,
)
for
key
in
self
.
node
()
.
keys
():
for
key
in
include_node_attrs
:
if
key
in
ignore_node_attrs
:
continue
g
.
node
(
"dst"
)[
key
]
=
self
.
node
()[
key
][
npart_mask
]
for
key
in
self
.
edge
()
.
keys
():
for
key
in
include_edge_attrs
:
if
key
in
ignore_edge_attrs
:
continue
g
.
edge
()[
key
]
=
self
.
edge
()[
key
][
epart_mask
]
for
key
in
self
.
meta
()
.
keys
():
for
key
in
include_meta_attrs
:
if
key
in
ignore_meta_attrs
:
continue
g
.
meta
()[
key
]
=
self
.
meta
()[
key
]
logging
.
info
(
f
"saving partition data: {i+1}/{num_parts}"
)
...
...
train_hybrid.py
View file @
542831ac
...
...
@@ -39,7 +39,7 @@ def prepare_data(root: str, num_parts, part_algo: str = "metis"):
logging
.
info
(
f
"GraphData.node().keys(): {g.node().keys()}"
)
logging
.
info
(
f
"GraphData.edge().keys(): {g.edge().keys()}"
)
g
.
save_partition
(
root
,
num_parts
,
part_algo
)
g
.
save_partition
(
root
,
num_parts
,
algorithm
=
part_algo
)
return
g
class
SimpleConv
(
pyg_nn
.
MessagePassing
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment