zhlj / BTS-MTGNN · Commits

Commit 6928f9b4
authored Jan 02, 2024 by Wenjie Huang
add layerpipe

parent 28564281
Showing 4 changed files with 270 additions and 47 deletions:

    starrygl/parallel/__init__.py        +2    -0
    starrygl/parallel/layerpipe.py       +213  -41
    starrygl/parallel/timeline/pipe.py   +24   -4
    train_hybrid.py                      +31   -2
starrygl/parallel/__init__.py

 from .route import *
 from .timeline import SequencePipe
+from .layerpipe import LayerPipe, LayerDetach
 from .sparse import *
\ No newline at end of file
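With the re-export above, the new classes become available from the package root. A minimal sketch; only the import path is given by this commit, the rest of the environment is assumed:

    # assumes starrygl and its torch dependency are importable
    from starrygl.parallel import LayerPipe, LayerDetach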
starrygl/parallel/layerpipe.py

Removed (the previous stub interface):

import torch
from torch import Tensor
from typing import *

from .route import Route
from abc import ABC, abstractmethod


class LayerPipe(ABC):
    def __init__(self) -> None:
        pass

    @abstractmethod
    def get_group(self):
        raise NotImplemented

    @abstractmethod
    def get_route(self, tag: int) -> Route:
        raise NotImplemented

    @abstractmethod
    def get_graph(self, tag: int) -> Any:
        raise NotImplemented

    @abstractmethod
    def forward(self,
        tag: int,
        dist_inputs: Sequence[Tensor],
        self_inputs: Sequence[Tensor],
        time_states: Sequence[Tensor],
    ) -> Tuple[
        Sequence[Tensor],
        Sequence[Tensor],
        Sequence[Tensor],
    ]:
        raise NotImplemented

    def load_node(self) -> Sequence[Tensor]:
        pass

    def load_edge(self) -> Sequence[Tensor]:
        pass

    def load_conv_state(self) -> Sequence[Tensor]:
        pass

    def load_time_state(self) -> Sequence[Tensor]:
        pass

    def save_as_node(self, *outputs: Tensor):
        pass

    def save_as_edge(self, *outputs: Tensor):
        pass

    def save_as_conv_state(self, *states: Tensor):
        pass

    def save_as_time_state(self, *states: Tensor):
        pass

Added (the new layer-pipeline implementation):

import torch
import torch.nn as nn
import torch.autograd as autograd
from torch import Tensor

from typing import *

from abc import ABC, abstractmethod
from contextlib import contextmanager

from .timeline.utils import vector_backward
from .utils import *

__all__ = [
    "LayerPipe",
    "LayerDetach",
]


class LayerPipe(ABC):
    def __init__(self) -> None:
        self._layer_id: Optional[int] = None
        self._snapshot_id: Optional[int] = None
        self._rts: List[LayerPipeRuntime] = []

    @property
    def layer_id(self) -> int:
        assert self._layer_id is not None
        return self._layer_id

    @property
    def snapshot_id(self) -> int:
        assert self._snapshot_id is not None
        return self._snapshot_id

    def apply(self,
        num_layers: int,
        num_snapshots: int,
    ) -> Sequence[Sequence[Tensor]]:
        runtime = LayerPipeRuntime(num_layers, num_snapshots, self)
        self._rts.append(runtime)
        return runtime.forward()

    def backward(self):
        for runtime in self._rts:
            runtime.backward()
        self._rts.clear()

    def all_reduce(self, async_op: bool = False):
        works = []
        for _, net in self.get_model():
            ws = all_reduce_gradients(net, async_op=async_op)
            if async_op:
                works.extend(ws)
            ws = all_reduce_buffers(net, async_op=async_op)
            if async_op:
                works.extend(ws)
        if async_op:
            return ws

    def to(self, device: Any):
        for _, net in self.get_model():
            net.to(device)
        return self

    def get_model(self) -> Sequence[Tuple[str, nn.Module]]:
        models = []
        for key in dir(self):
            if key in {"layer_id", "snapshot_id"}:
                continue
            val = getattr(self, key)
            if isinstance(val, nn.Module):
                models.append((key, val))
        return tuple(models)

    @abstractmethod
    def layer_inputs(self,
        inputs: Optional[Sequence[Tensor]] = None,
    ) -> Sequence[Tensor]:
        raise NotImplementedError

    @abstractmethod
    def layer_forward(self,
        inputs: Sequence[Tensor],
    ) -> Sequence[Tensor]:
        raise NotImplemented

    @contextmanager
    def _switch_layer(self,
        layer_id: int,
        snapshot_id: int,
    ):
        saved_layer_id = self._layer_id
        saved_snapshot_id = self._snapshot_id

        self._layer_id = layer_id
        self._snapshot_id = snapshot_id
        try:
            yield
        finally:
            self._layer_id = saved_layer_id
            self._snapshot_id = saved_snapshot_id


class LayerPipeRuntime:
    def __init__(self,
        num_layers: int,
        num_snapshots: int,
        program: LayerPipe,
    ) -> None:
        self.num_layers = num_layers
        self.num_snapshots = num_snapshots
        self.program = program
        self.ready_bw: Dict[Any, LayerDetach] = {}

    def forward(self) -> Sequence[Sequence[Tensor]]:
        for op, layer_i, snap_i in ForwardFootprint(self.num_layers, self.num_snapshots):
            if op == "sync":
                xs = self.ready_bw[(layer_i - 1, snap_i, 1)].values() if layer_i > 0 else None
                with self.program._switch_layer(layer_i, snap_i):
                    xs = self.program.layer_inputs(xs)
                self.ready_bw[(layer_i, snap_i, 0)] = LayerDetach(*xs)
            elif op == "comp":
                xs = self.ready_bw[(layer_i, snap_i, 0)].values()
                with self.program._switch_layer(layer_i, snap_i):
                    xs = self.program.layer_forward(xs)
                self.ready_bw[(layer_i, snap_i, 1)] = LayerDetach(*xs)

        xs = []
        for snap_i in range(self.num_snapshots):
            layer_i = self.num_layers - 1
            xs.append(self.ready_bw[(layer_i, snap_i, 1)].values())
        return xs

    def backward(self):
        for op, layer_i, snap_i in BackwardFootprint(self.num_layers, self.num_snapshots):
            if op == "sync":
                self.ready_bw.pop((layer_i, snap_i, 0)).backward()
            elif op == "comp":
                self.ready_bw.pop((layer_i, snap_i, 1)).backward()
        assert len(self.ready_bw) == 0


class LayerDetach:
    def __init__(self,
        *inputs: Tensor,
    ) -> None:
        outputs = tuple(t.detach() for t in inputs)
        for s, t in zip(inputs, outputs):
            t.requires_grad_(s.requires_grad)
        self._inputs = inputs
        self._outputs = outputs

    def values(self) -> Sequence[Tensor]:
        return tuple(self._outputs)

    def backward(self) -> None:
        vec_loss, vec_grad = [], []
        for s, t in zip(self._inputs, self._outputs):
            g, t.grad = t.grad, None
            if not s.requires_grad:
                continue
            vec_loss.append(s)
            vec_grad.append(g)
        vector_backward(vec_loss, vec_grad)


class ForwardFootprint:
    def __init__(self,
        num_layers: int,
        num_snapshots: int,
    ) -> None:
        self._num_layers = num_layers
        self._num_snapshots = num_snapshots

    def __iter__(self):
        if self._num_layers <= 0 or self._num_snapshots <= 0:
            return

        # starting
        if self._num_snapshots > 1:
            yield "sync", 0, 0
            yield "sync", 0, 1
        elif self._num_snapshots > 0:
            yield "sync", 0, 0

        for i in range(0, self._num_snapshots, 2):
            for l in range(self._num_layers):
                # snapshot i
                yield "comp", l, i
                if l + 1 < self._num_layers:
                    yield "sync", l + 1, i
                elif i + 2 < self._num_snapshots:
                    yield "sync", 0, i + 2

                # snapshot i + 1
                if i + 1 >= self._num_snapshots:
                    continue
                yield "comp", l, i + 1
                if l + 1 < self._num_layers:
                    yield "sync", l + 1, i + 1
                elif i + 3 < self._num_snapshots:
                    yield "sync", 0, i + 3


class BackwardFootprint:
    def __init__(self,
        num_layers: int,
        num_snapshots: int,
    ) -> None:
        self._num_layers = num_layers
        self._num_snapshots = num_snapshots

    def __iter__(self):
        if self._num_layers <= 0 or self._num_snapshots <= 0:
            return

        for i in range(0, self._num_snapshots, 2):
            for j in range(self._num_layers):
                l = self._num_layers - j - 1
                # snapshot i
                yield "comp", l, i
                yield "sync", l, i

                # snapshot i + 1
                if i + 1 >= self._num_snapshots:
                    continue
                yield "comp", l, i + 1
                yield "sync", l, i + 1


class SparseLayerPipe:
    def __init__(self) -> None:
        pass
\ No newline at end of file
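For intuition about the schedule: in LayerPipeRuntime above, a "sync" step calls layer_inputs (where, for example, the SimpleGNNPipe added in train_hybrid.py routes features across graph partitions) and a "comp" step calls layer_forward; the footprint generators interleave two snapshots at a time, presumably so communication for one snapshot can overlap with computation on the other. A small sketch that only enumerates the schedule; the classes are the ones added in this commit, while the sizes (2 layers, 2 snapshots) are arbitrary:

    # Print the forward/backward pipeline schedules for a 2-layer, 2-snapshot run.
    # Assumes the starrygl package (and torch) are importable.
    from starrygl.parallel.layerpipe import ForwardFootprint, BackwardFootprint

    for op, layer_i, snap_i in ForwardFootprint(2, 2):
        print("fw", op, layer_i, snap_i)
    # fw: sync(0,0) sync(0,1) comp(0,0) sync(1,0) comp(0,1) sync(1,1) comp(1,0) comp(1,1)

    for op, layer_i, snap_i in BackwardFootprint(2, 2):
        print("bw", op, layer_i, snap_i)
    # bw: comp(1,0) sync(1,0) comp(1,1) sync(1,1) comp(0,0) sync(0,0) comp(0,1) sync(0,1)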
starrygl/parallel/timeline/pipe.py

@@ -11,6 +11,7 @@ from contextlib import contextmanager
 from .sync import VirtualMotions, VirtualForward, BatchSync
 from .utils import vector_backward
+from starrygl.parallel.utils import *

 class SequencePipe(ABC):
     def __init__(self) -> None:
@@ -54,6 +55,24 @@ class SequencePipe(ABC):
             models.append((key, val))
         return tuple(models)

+    def to(self, device: Any):
+        for _, net in self.get_model():
+            net.to(device)
+        return self
+
+    def all_reduce(self, async_op: bool = False):
+        works = []
+        for name, net in self.get_model():
+            ws = all_reduce_gradients(net, async_op=async_op)
+            if async_op:
+                works.extend(ws)
+            ws = all_reduce_buffers(net, async_op=async_op)
+            if async_op:
+                works.extend(ws)
+        if async_op:
+            return ws
+
     def apply(self, bs: int, *inputs: Tensor) -> Sequence[Tensor]:
         runtime = SequencePipeRuntime(bs, self)
         return SequencePipeFunction.apply(runtime, *inputs)
@@ -87,10 +106,11 @@ class SequencePipe(ABC):
         self._pos_begin = begin
         self._pos_end = end
-        yield
-        self._pos_begin = saved_begin
-        self._pos_end = saved_end
+        try:
+            yield
+        finally:
+            self._pos_begin = saved_begin
+            self._pos_end = saved_end

 class SequencePipeRuntime:
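The last hunk wraps the yield of SequencePipe's position-switching context manager in try/finally, so the saved begin/end positions are restored even if the with-body raises. This is the standard @contextmanager cleanup pattern rather than anything StarryGL-specific; an illustrative sketch with made-up names:

    from contextlib import contextmanager

    @contextmanager
    def switch(obj, value):
        saved = obj.state
        obj.state = value
        try:
            yield
        finally:
            # runs even if the with-body raises, so state is always restored
            obj.state = saved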
train_hybrid.py

-from typing import Any, List, Optional, Tuple
+from typing import Any, List, Optional, Sequence, Tuple
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -9,7 +9,7 @@ from typing import *
 from starrygl.distributed import DistributedContext
 from starrygl.data import GraphData
-from starrygl.parallel import Route, SequencePipe
+from starrygl.parallel import Route, SequencePipe, LayerPipe
 from starrygl.parallel.utils import *

 import torch_geometric.nn as pyg_nn
@@ -76,6 +76,35 @@ class SimpleGNN(nn.Module):
             x = layer(x, edge_index, route)
         return x

+class SimpleGNNPipe(LayerPipe):
+    def __init__(self,
+        num_features: int,
+        hidden_dims: int,
+        num_layers: int,
+        features: Tensor,
+        edge_index: Tensor,
+        route: Route,
+    ) -> None:
+        super().__init__()
+        self.features = features
+        self.edge_index = edge_index
+        self.route = route
+        self.net = SimpleGNN(num_features, hidden_dims, num_layers)
+
+    def layer_inputs(self, inputs: Sequence[Tensor] | None = None) -> Sequence[Tensor]:
+        if self.layer_id == 0:
+            x = self.features
+        else:
+            x, = inputs
+        x = self.route.apply(x)
+        return (x,)
+
+    def layer_forward(self, inputs: Sequence[Tensor]) -> Sequence[Tensor]:
+        x, = inputs
+        x = self.net.layers[self.layer_id](x, self.edge_index, self.route)
+        return (x,)
+
 class SimpleRNN(SequencePipe, nn.Module):
     def __init__(self,
         num_classes: int,
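The training loop that drives SimpleGNNPipe is outside the hunks shown here. Based only on the LayerPipe API added above, a hypothetical sketch of how it could be wired up; the model sizes, loss, optimizer, and device names are placeholders, not taken from the diff:

    # Hypothetical driver; only the SimpleGNNPipe/LayerPipe calls come from this commit.
    pipe = SimpleGNNPipe(num_features, hidden_dims, num_layers,
                         features, edge_index, route).to(device)

    outs = pipe.apply(num_layers, num_snapshots)    # one Sequence[Tensor] per snapshot
    loss = sum(compute_loss(*out) for out in outs)  # compute_loss is a placeholder
    loss.backward()    # populates .grad on the detached per-snapshot outputs
    pipe.backward()    # replays the pipelined backward through earlier layers/snapshots
    pipe.all_reduce()  # all-reduce gradients and buffers across workers
    optimizer.step()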