zhlj / BTS-MTGNN

Commit 09c175e2
Authored Dec 22, 2023 by Wenjie Huang
    add demo train_hybrid.py

Parent: 057dc57f
Showing 6 changed files with 301 additions and 5 deletions (+301 -5)
  starrygl/distributed/cclib.py      +6    -0
  starrygl/distributed/context.py    +85   -1
  starrygl/parallel/route.py         +7    -1
  starrygl/parallel/sequence.py      +0    -0
  starrygl/parallel/sparse.py        +17   -3
  train_hybrid.py                    +186  -0
starrygl/distributed/cclib.py
@@ -152,8 +152,11 @@ def batch_send(
     if len(tensors) == 0:
         return BatchWork(None, None)
+    if group is None:
+        group = dist.GroupMember.WORLD
     # tensors = tuple(t.data for t in tensors)
     backend = dist.get_backend(group)
+    dst = dist.get_global_rank(group, dst)
     if async_op:
         works = []

@@ -177,8 +180,11 @@ def batch_recv(
     if len(tensors) == 0:
         return BatchWork(None, None)
+    if group is None:
+        group = dist.GroupMember.WORLD
     # tensors = tuple(t.data for t in tensors)
     backend = dist.get_backend(group)
+    src = dist.get_global_rank(group, src)
     if async_op:
         works = []
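
Note: the pattern introduced in both hunks above (and repeated in route.py and sparse.py below) is that torch.distributed point-to-point and collective calls such as dist.send, dist.broadcast and dist.reduce address peers by global rank even when a subgroup is passed via group=, so a loop index over the subgroup must first be mapped with dist.get_global_rank(group, i). The minimal sketch below is not part of the commit; it assumes a 4-process gloo run on localhost, and the worker() helper is illustrative only.

import torch
import torch.distributed as dist
import torch.multiprocessing as mp

def worker(rank: int, world_size: int) -> None:
    dist.init_process_group(
        "gloo", init_method="tcp://127.0.0.1:29501",
        rank=rank, world_size=world_size,
    )
    # a subgroup over the last two processes only
    group = dist.new_group([world_size - 2, world_size - 1])
    if rank >= world_size - 2:
        t = torch.full((1,), float(rank))
        for i in range(dist.get_world_size(group)):  # i is a group-local index
            src = dist.get_global_rank(group, i)     # map it to the global rank
            buf = t.clone() if src == rank else torch.empty(1)
            dist.broadcast(buf, src=src, group=group)
            # buf now holds the value broadcast by global rank `src`
    dist.destroy_process_group()

if __name__ == "__main__":
    world_size = 4
    mp.spawn(worker, args=(world_size,), nprocs=world_size, join=True)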
starrygl/distributed/context.py
@@ -7,6 +7,7 @@ import os
 from torch import Tensor
 from typing import *
+import socket
 from contextlib import contextmanager
 import logging

@@ -127,6 +128,18 @@ class DistributedContext:
         self._local_rank = local_rank
         self._compute_device = device
+        self._hostname = socket.gethostname()
+
+        if self.device.type == "cuda":
+            torch.cuda.set_device(self.device)
+
+        rank_to_host = [None] * self.world_size
+        dist.all_gather_object(rank_to_host, (self.hostname, self.local_rank))
+        self._rank_to_host: Tuple[Tuple[str, int], ...] = tuple(rank_to_host)
+
+        host_index = [h for h, _ in self.rank_to_host]
+        host_index.sort()
+        self._host_index: Dict[str, int] = {h: i for i, h in enumerate(host_index)}
+
         self.__temp_ag_remote_object: Optional[rpc.RRef] = None

@@ -145,16 +158,87 @@ class DistributedContext:
     @property
     def local_rank(self) -> int:
         return self._local_rank
+
+    @property
+    def hostname(self) -> str:
+        return self._hostname
+
+    @property
+    def rank_to_host(self):
+        return self._rank_to_host
+
+    @property
+    def host_index(self):
+        return self._host_index
 
     @property
     def device(self) -> torch.device:
         return self._compute_device
 
     def get_default_group(self):
-        return dist.distributed_c10d._get_default_group()
+        # return dist.distributed_c10d._get_default_group()
+        return dist.GroupMember.WORLD
 
     def get_default_store(self):
         return dist.distributed_c10d._get_default_store()
+
+    def get_ranks_by_host(self, hostname: Optional[str] = None) -> Tuple[int, ...]:
+        if hostname is None:
+            hostname = self.hostname
+        ranks: List[int] = []
+        for i, (h, r) in enumerate(self.rank_to_host):
+            if h == hostname:
+                ranks.append(i)
+        ranks.sort()
+        return tuple(ranks)
+
+    def get_ranks_by_local(self, local_rank: Optional[int] = None) -> Tuple[int, ...]:
+        if local_rank is None:
+            local_rank = self.local_rank
+        ranks: List[Tuple[int, str]] = []
+        for i, (h, r) in enumerate(self.rank_to_host):
+            if r == local_rank:
+                ranks.append((i, h))
+        ranks.sort(key=lambda x: self.host_index[x[1]])
+        return tuple(i for i, h in ranks)
+
+    def get_hybrid_matrix(self) -> Tensor:
+        hosts = sorted(self.host_index.items(), key=lambda x: x[1])
+        matrix = []
+        for h, _ in hosts:
+            rs = self.get_ranks_by_host(h)
+            matrix.append(rs)
+        return torch.tensor(matrix, dtype=torch.long, device="cpu")
+
+    def new_hybrid_subgroups(self,
+        matrix: Optional[Tensor] = None,
+        backend: Any = None,
+    ) -> Tuple[Any, Any]:
+        if matrix is None:
+            matrix = self.get_hybrid_matrix()
+        assert matrix.dim() == 2
+        row_group = None
+        col_group = None
+        for row in matrix.tolist():
+            if self.rank in row:
+                row_group = dist.new_group(
+                    row, backend=backend,
+                    use_local_synchronization=True,
+                )
+                break
+        for col in matrix.t().tolist():
+            if self.rank in col:
+                col_group = dist.new_group(
+                    col, backend=backend,
+                    use_local_synchronization=True,
+                )
+                break
+        assert row_group is not None
+        assert col_group is not None
+        return row_group, col_group
 
     def get_worker_info(self, rank: Optional[int] = None) -> rpc.WorkerInfo:
         rank = dist.get_rank() if rank is None else rank
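
The new get_hybrid_matrix() arranges global ranks into a hosts x local-ranks grid, and new_hybrid_subgroups() gives each rank one row group (its peers on the same host) and one column group (the same local rank across hosts). The sketch below is not from the repository; it only mimics the row/column membership logic on a hard-coded 2x4 grid (the groups_for() helper is illustrative) and creates no process groups.

import torch

# 2 hosts x 4 local ranks; entry (i, j) is the global rank of local rank j on host i
matrix = torch.arange(8).view(2, 4)

def groups_for(rank: int):
    row = next(r for r in matrix.tolist() if rank in r)      # intra-host peers
    col = next(c for c in matrix.t().tolist() if rank in c)  # same local rank across hosts
    return row, col

for rank in range(8):
    print(rank, groups_for(rank))
# e.g. rank 5 -> ([4, 5, 6, 7], [1, 5])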
starrygl/parallel/route.py
@@ -24,6 +24,9 @@ class Route:
         bipartite: bool = True,
         group: Any = None,
     ) -> 'Route':
+        if group is None:
+            group = dist.GroupMember.WORLD
+
         fw_tables, bw_tables = Route._build_route_tables(
             src_ids=src_ids, dst_ids=dst_ids,
             bipartite=bipartite, group=group,

@@ -256,8 +259,11 @@ class Route:
                 all_dst_ids[i] = dst_ids
             else:
                 all_dst_ids[i] = torch.empty(all_dst_lens[i], **ikw)
+            src_rank = dist.get_global_rank(group, i)
             all_dst_get[i] = dist.broadcast(
-                all_dst_ids[i], src=i, async_op=True, group=group
+                all_dst_ids[i], src=src_rank, async_op=True, group=group,
             )
 
         fw_tables: List[Tensor] = []
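
The surrounding code in this hunk posts one asynchronous broadcast per partition and keeps the returned work handles. The hypothetical exchange_ids() helper below shows the same pattern in isolation and is not the repository's code: post one async broadcast per group member, mapping the loop index to a global rank as the commit does, then wait on all handles.

import torch
import torch.distributed as dist

def exchange_ids(local_ids: torch.Tensor, all_lens, group) -> list:
    world_size = dist.get_world_size(group)
    rank = dist.get_rank(group)
    all_ids, works = [None] * world_size, [None] * world_size
    for i in range(world_size):
        if i == rank:
            all_ids[i] = local_ids
        else:
            all_ids[i] = torch.empty(all_lens[i], dtype=local_ids.dtype, device=local_ids.device)
        src = dist.get_global_rank(group, i)  # group-local index -> global rank
        works[i] = dist.broadcast(all_ids[i], src=src, async_op=True, group=group)
    for w in works:
        w.wait()                              # complete all outstanding broadcasts
    return all_ids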
starrygl/parallel/sequence.py

(diff collapsed in this view; not shown)
starrygl/parallel/sparse.py
@@ -58,6 +58,9 @@ class SparseBlocks:
     def __fetch_ids_sizes(local_ids: Tensor, group: Any):
         assert local_ids.dim() == 1
+        if group is None:
+            group = dist.GroupMember.WORLD
+
         rank = dist.get_rank(group)
         world_size = dist.get_world_size(group)
         ikw = dict(dtype=torch.long, device=local_ids.device)

@@ -80,8 +83,9 @@ class SparseBlocks:
                 all_ids[i] = local_ids
             else:
                 all_ids[i] = torch.empty(all_lens[i], **ikw)
+            src = dist.get_global_rank(group, i)
             all_get[i] = dist.broadcast(
-                all_ids[i], src=i, async_op=True, group=group
+                all_ids[i], src=src, async_op=True, group=group
             )
 
         imp: Tensor = torch.full((num_nodes,), (2**62-1)*2+1, **ikw)

@@ -151,13 +155,18 @@ class SparseBlockMM(autograd.Function):
         part_id = sp.part_id
         num_parts = sp.num_parts
+        group = sp.group
+        if group is None:
+            group = dist.GroupMember.WORLD
+
         def async_fetch(i: int):
             n = sp.adj_t(i).sparse_size(1)
             if i == part_id:
                 h = x.clone()
             else:
                 h = torch.empty(n, *x.shape[1:], dtype=x.dtype, device=x.device)
-            return dist.broadcast(h, src=i, group=sp.group, async_op=True)
+            src = dist.get_global_rank(group, i)
+            return dist.broadcast(h, src=src, group=sp.group, async_op=True)
 
         last_work = None
         out = None

@@ -192,9 +201,14 @@ class SparseBlockMM(autograd.Function):
         part_id = sp.part_id
         num_parts = sp.num_parts
+        group = sp.group
+        if group is None:
+            group = dist.GroupMember.WORLD
+
         def async_reduce(i: int, g: Tensor):
+            dst = dist.get_global_rank(group, i)
             return dist.reduce(
-                g, dst=i, op=dist.ReduceOp.SUM,
+                g, dst=dst, op=dist.ReduceOp.SUM,
                 group=sp.group, async_op=True,
             )
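
For context, async_fetch and async_reduce are used so that communication can overlap the local sparse-block matmul. The sketch below is an assumption about how such a driver loop could look, not the repository's code: sp, sp.adj_t(i) and sparse_size(1) follow the names visible in the diff, while spmm_pipeline() itself is hypothetical. It prefetches partition i+1 with an asynchronous broadcast while computing on partition i, converting group-local indices to global ranks as the commit does.

import torch
import torch.distributed as dist

def spmm_pipeline(sp, x: torch.Tensor, group) -> torch.Tensor:
    num_parts = dist.get_world_size(group)
    part_id = dist.get_rank(group)

    def async_fetch(i: int):
        # allocate (or copy) the feature block owned by partition i and start
        # an asynchronous broadcast from its owner
        n = sp.adj_t(i).sparse_size(1)
        if i == part_id:
            h = x.clone()
        else:
            h = torch.empty(n, *x.shape[1:], dtype=x.dtype, device=x.device)
        src = dist.get_global_rank(group, i)  # group-local index -> global rank
        return h, dist.broadcast(h, src=src, group=group, async_op=True)

    out = None
    h, work = async_fetch(0)
    for i in range(num_parts):
        nxt = async_fetch(i + 1) if i + 1 < num_parts else None
        work.wait()                 # block i's features are now valid
        y = sp.adj_t(i) @ h         # local sparse-block matmul
        out = y if out is None else out + y
        if nxt is not None:
            h, work = nxt
    return out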
train_hybrid.py  0 → 100644 (new file)
from typing import Any, List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist

from torch import Tensor
from typing import *

from starrygl.distributed import DistributedContext
from starrygl.data import GraphData
from starrygl.parallel import Route, SequencePipe
from starrygl.parallel.sequence import STensor
from starrygl.parallel.utils import *

import torch_geometric.nn as pyg_nn
import torch_geometric.datasets as pyg_datasets
import torch_geometric.utils as pyg_utils

import logging
logging.getLogger().setLevel(logging.INFO)


def prepare_data(root: str, num_parts, part_algo: str = "metis"):
    ctx = DistributedContext.get_default_context()

    data = pyg_datasets.Planetoid(root, "Cora")[0]

    if data.is_directed():
        data.edge_index, _ = pyg_utils.to_undirected(data.edge_index)
    data.edge_index, _ = pyg_utils.add_remaining_self_loops(data.edge_index)
    data.num_classes = data.y.max().item() + 1

    logging.info(f"num_nodes: {data.num_nodes}")
    logging.info(f"num_edges: {data.num_edges}")
    logging.info(f"num_features: {data.num_features}")
    logging.info(f"num_classes: {data.num_classes}")

    g = GraphData.from_pyg_data(data)

    logging.info(f"GraphData.meta().keys(): {g.meta().keys()}")
    logging.info(f"GraphData.node().keys(): {g.node().keys()}")
    logging.info(f"GraphData.edge().keys(): {g.edge().keys()}")

    g.save_partition(root, num_parts, part_algo)
    return g


class SimpleConv(pyg_nn.MessagePassing):
    def __init__(self, in_feats: int, out_feats: int):
        super().__init__(aggr="mean")
        self.linear = nn.Linear(in_feats, out_feats)

    def forward(self, x: Tensor, edge_index: Tensor, route: Route):
        dst_len = x.size(0)
        x = route.apply(x)  # exchange features
        return self.propagate(edge_index, x=x)[:dst_len]

    def message(self, x_j: Tensor):
        return x_j

    def update(self, x: Tensor):
        return F.relu(self.linear(x))


class SimpleGNN(nn.Module):
    def __init__(self,
        num_features: int,
        hidden_dims: int,
        num_layers: int,
    ) -> None:
        super().__init__()
        self.layers = nn.ModuleList()
        for i in range(num_layers):
            in_ch = hidden_dims if i > 0 else num_features
            out_ch = hidden_dims
            self.layers.append(SimpleConv(in_ch, out_ch))

    def forward(self, x: Tensor, edge_index: Tensor, route: Route):
        for layer in self.layers:
            x = layer(x, edge_index, route)
        return x


class SimpleRNN(SequencePipe, nn.Module):
    def __init__(self,
        num_classes: int,
        hidden_dims: int,
        num_layers: int,
        device: Any,
        group: Any,
    ) -> None:
        super().__init__()
        self.device = device
        self.group = group

        self.num_layers = num_layers
        self.hidden_dims = hidden_dims

        self.gru = nn.GRU(
            input_size=hidden_dims,
            hidden_size=hidden_dims,
            num_layers=num_layers,
            batch_first=True,
        )
        self.out = nn.Linear(hidden_dims, num_classes)

    def forward(self, inputs, states):
        x, = inputs  # (N, L, H)
        h, = states  # (N, L, H)

        h = h.transpose(0, 1).contiguous()  # (L, N, H)
        x, h = self.gru(x, h)  # (N, L, H), (L, N, H)
        h = h.transpose(0, 1).contiguous()  # (N, L, H)
        return (x,), (h,)

    def loss_fn(self, inputs, labels) -> Tensor:
        x, = inputs
        return x.square().mean()

    def get_group(self) -> Any:
        return self.group

    def get_init_states(self):
        s = torch.zeros(self.num_layers, self.hidden_dims).to(self.device)
        return (s,)


if __name__ == "__main__":
    data_root = "./dataset"
    ctx = DistributedContext.init(backend="nccl", use_gpu=True)

    hybrid_matrix = ctx.get_hybrid_matrix()
    if hybrid_matrix.size(0) == 1:
        hybrid_matrix = hybrid_matrix.view(2, -1)
    ctx.sync_print(hybrid_matrix)

    # sp is sequence parallel
    # pp is partition parallel
    sp_group, pp_group = ctx.new_hybrid_subgroups(hybrid_matrix)

    # partition data
    if ctx.rank == 0:
        prepare_data(data_root, dist.get_world_size(pp_group))
    dist.barrier()

    g = GraphData.load_partition(
        data_root,
        dist.get_rank(pp_group),
        dist.get_world_size(pp_group),
    ).to(ctx.device)

    route = g.to_route(pp_group)  # only on subgroup
    num_features = g.node("dst")["x"].size(-1)
    num_classes = g.meta()["num_classes"]
    hidden_dims = 128
    num_layers = 3

    gnn = SimpleGNN(num_features, hidden_dims, num_layers).to(ctx.device)
    rnn = SimpleRNN(num_classes, hidden_dims, num_layers,
                    device=ctx.device, group=sp_group).to(ctx.device)
    opt = torch.optim.Adam([p for p in gnn.parameters()] + [p for p in rnn.parameters()])

    for ep in range(1, 100 + 1):
        seq_len = 200
        xs = []
        opt.zero_grad()
        for _ in range(seq_len):
            # snapshot parallel between partition parallel subgroups
            z = gnn(
                x=g.node("dst")["x"],
                edge_index=g.edge_index(),
                route=route,
            )
            xs.append(z.unsqueeze(1))
        x = torch.cat(xs, dim=1)  # (N, S, H)

        # loss = rnn.apply(32, x)[0].square().mean()
        # loss.backward()  # sequence and pipeline parallel on each graph nodes
        loss = rnn.fast_backward(32, (x,), (g.node("dst")["train_mask"],))

        # all reduce
        all_reduce_gradients(rnn)
        all_reduce_buffers(rnn)
        all_reduce_gradients(gnn)
        all_reduce_buffers(gnn)
        opt.step()

        ctx.sync_print(loss)
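
A note on the reshape near the top of the demo's main block: on a single host, get_hybrid_matrix() returns a single row, which leaves nothing to split into sequence-parallel and partition-parallel groups, so the demo views it as a 2 x (world_size/2) grid (this assumes an even world size). A minimal standalone illustration, not part of the commit:

import torch

hybrid_matrix = torch.arange(4).view(1, -1)   # what a single host with 4 GPUs yields
if hybrid_matrix.size(0) == 1:
    hybrid_matrix = hybrid_matrix.view(2, -1)
print(hybrid_matrix)
# tensor([[0, 1],
#         [2, 3]])

Rows then become the sequence-parallel (sp) groups and columns the partition-parallel (pp) groups that the demo passes to SimpleRNN and g.to_route respectively.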