zhlj / starrygl-DynamicHistory / Commits / 09c175e2

Commit 09c175e2, authored Dec 22, 2023 by Wenjie Huang
parent 057dc57f

    add demo train_hybrid.py

Showing 6 changed files with 1032 additions and 342 deletions:

    starrygl/distributed/cclib.py      +6    -0
    starrygl/distributed/context.py    +85   -1
    starrygl/parallel/route.py         +7    -1
    starrygl/parallel/sequence.py      +731  -337
    starrygl/parallel/sparse.py        +17   -3
    train_hybrid.py                    +186  -0
starrygl/distributed/cclib.py  (view file @ 09c175e2)

@@ -152,8 +152,11 @@ def batch_send(
    if len(tensors) == 0:
        return BatchWork(None, None)
+    if group is None:
+        group = dist.GroupMember.WORLD
    # tensors = tuple(t.data for t in tensors)
    backend = dist.get_backend(group)
+    dst = dist.get_global_rank(group, dst)
    if async_op:
        works = []
...

@@ -177,8 +180,11 @@ def batch_recv(
    if len(tensors) == 0:
        return BatchWork(None, None)
+    if group is None:
+        group = dist.GroupMember.WORLD
    # tensors = tuple(t.data for t in tensors)
    backend = dist.get_backend(group)
+    src = dist.get_global_rank(group, src)
    if async_op:
        works = []
...
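Both hunks follow the same pattern: fall back to the WORLD group when no group is given, then translate the caller's group-local rank into a global rank, because torch.distributed point-to-point ops address peers by global rank even when a group argument is passed. The snippet below is a minimal, hypothetical sketch of that behaviour outside StarryGL (the file name is invented; it assumes a torchrun --nproc_per_node=2 launch with the gloo backend):

# batch_send_sketch.py -- illustrative only, not part of this commit.
import torch
import torch.distributed as dist

def main():
    dist.init_process_group("gloo")          # WORLD group, set up by torchrun
    group = dist.new_group([0, 1])           # a subgroup with its own rank numbering
    t = torch.full((4,), float(dist.get_rank()))
    if dist.get_rank(group) == 0:
        # isend/irecv expect *global* ranks, so translate the group rank first;
        # with larger worlds and smaller subgroups the two numberings genuinely differ.
        dst = dist.get_global_rank(group, 1)
        dist.isend(t, dst=dst, group=group).wait()
    else:
        src = dist.get_global_rank(group, 0)
        dist.irecv(t, src=src, group=group).wait()
    print(dist.get_rank(), t)

if __name__ == "__main__":
    main()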
starrygl/distributed/context.py  (view file @ 09c175e2)

@@ -7,6 +7,7 @@ import os
 from torch import Tensor
 from typing import *
+import socket
 from contextlib import contextmanager
 import logging
...

@@ -127,6 +128,18 @@ class DistributedContext:
        self._local_rank = local_rank
        self._compute_device = device
+        self._hostname = socket.gethostname()
+
+        if self.device.type == "cuda":
+            torch.cuda.set_device(self.device)
+
+        rank_to_host = [None] * self.world_size
+        dist.all_gather_object(rank_to_host, (self.hostname, self.local_rank))
+        self._rank_to_host: Tuple[Tuple[str, int], ...] = tuple(rank_to_host)
+
+        host_index = [h for h, _ in self.rank_to_host]
+        host_index.sort()
+        self._host_index: Dict[str, int] = {h: i for i, h in enumerate(host_index)}

        self.__temp_ag_remote_object: Optional[rpc.RRef] = None
...

@@ -145,16 +158,87 @@ class DistributedContext:
    @property
    def local_rank(self) -> int:
        return self._local_rank

+    @property
+    def hostname(self) -> str:
+        return self._hostname
+
+    @property
+    def rank_to_host(self):
+        return self._rank_to_host
+
+    @property
+    def host_index(self):
+        return self._host_index
+
    @property
    def device(self) -> torch.device:
        return self._compute_device

    def get_default_group(self):
-        return dist.distributed_c10d._get_default_group()
+        # return dist.distributed_c10d._get_default_group()
+        return dist.GroupMember.WORLD

    def get_default_store(self):
        return dist.distributed_c10d._get_default_store()

+    def get_ranks_by_host(self, hostname: Optional[str] = None) -> Tuple[int, ...]:
+        if hostname is None:
+            hostname = self.hostname
+        ranks: List[int] = []
+        for i, (h, r) in enumerate(self.rank_to_host):
+            if h == hostname:
+                ranks.append(i)
+        ranks.sort()
+        return tuple(ranks)
+
+    def get_ranks_by_local(self, local_rank: Optional[int] = None) -> Tuple[int, ...]:
+        if local_rank is None:
+            local_rank = self.local_rank
+        ranks: List[Tuple[int, str]] = []
+        for i, (h, r) in enumerate(self.rank_to_host):
+            if r == local_rank:
+                ranks.append((i, h))
+        ranks.sort(key=lambda x: self.host_index[x[1]])
+        return tuple(i for i, h in ranks)
+
+    def get_hybrid_matrix(self) -> Tensor:
+        hosts = sorted(self.host_index.items(), key=lambda x: x[1])
+        matrix = []
+        for h, _ in hosts:
+            rs = self.get_ranks_by_host(h)
+            matrix.append(rs)
+        return torch.tensor(matrix, dtype=torch.long, device="cpu")
+
+    def new_hybrid_subgroups(self,
+        matrix: Optional[Tensor] = None,
+        backend: Any = None,
+    ) -> Tuple[Any, Any]:
+        if matrix is None:
+            matrix = self.get_hybrid_matrix()
+        assert matrix.dim() == 2
+
+        row_group = None
+        col_group = None
+        for row in matrix.tolist():
+            if self.rank in row:
+                row_group = dist.new_group(
+                    row, backend=backend,
+                    use_local_synchronization=True,
+                )
+                break
+        for col in matrix.t().tolist():
+            if self.rank in col:
+                col_group = dist.new_group(
+                    col, backend=backend,
+                    use_local_synchronization=True,
+                )
+                break
+        assert row_group is not None
+        assert col_group is not None
+        return row_group, col_group
+
    def get_worker_info(self, rank: Optional[int] = None) -> rpc.WorkerInfo:
        rank = dist.get_rank() if rank is None else rank
...
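The new context helpers arrange the global ranks into a hosts x local-ranks grid: get_hybrid_matrix() builds one row per hostname, and new_hybrid_subgroups() then creates one process group per row (ranks sharing a host) and one per column (the same local rank across hosts). Below is a small standalone sketch of the grid construction only, using an assumed placement of two hosts with two GPUs each; it mirrors the logic but is not the library code and needs no distributed init:

# hybrid_matrix_sketch.py -- illustrative only; assumed (hostname, local_rank) pairs.
import torch

rank_to_host = [("a", 0), ("a", 1), ("b", 0), ("b", 1)]   # indexed by global rank

host_index = {h: i for i, h in enumerate(sorted({h for h, _ in rank_to_host}))}
matrix = []
for h in sorted(host_index, key=host_index.get):
    row = sorted(i for i, (hh, _) in enumerate(rank_to_host) if hh == h)
    matrix.append(row)
matrix = torch.tensor(matrix, dtype=torch.long)

print(matrix)   # tensor([[0, 1],
                #         [2, 3]])
# new_hybrid_subgroups() would turn each row ([0, 1] and [2, 3]) into an intra-host
# group and each column ([0, 2] and [1, 3]) into a cross-host group; the demo later
# in this commit uses the row group for sequence parallelism and the column group
# for graph partition parallelism.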
starrygl/parallel/route.py  (view file @ 09c175e2)

@@ -24,6 +24,9 @@ class Route:
        bipartite: bool = True,
        group: Any = None,
    ) -> 'Route':
+        if group is None:
+            group = dist.GroupMember.WORLD
+
        fw_tables, bw_tables = Route._build_route_tables(
            src_ids=src_ids, dst_ids=dst_ids,
            bipartite=bipartite, group=group,
...

@@ -256,8 +259,11 @@ class Route:
                all_dst_ids[i] = dst_ids
            else:
                all_dst_ids[i] = torch.empty(all_dst_lens[i], **ikw)
+
+            src_rank = dist.get_global_rank(group, i)
            all_dst_get[i] = dist.broadcast(
-                all_dst_ids[i], src=i, async_op=True, group=group
+                all_dst_ids[i], src=src_rank, async_op=True, group=group,
            )
        fw_tables: List[Tensor] = []
...
starrygl/parallel/sequence.py  (view file @ 09c175e2)

@@ -6,394 +6,788 @@ import torch.distributed as dist

This hunk rewrites the module: the nn.Module-based SequencePipe is replaced by an
abstract SequencePipe "program" plus a SequencePipeRuntime that drives the 1F1B
micro-batch schedule, and an autograd.Function wrapper. Reconstructed new version:

from torch import Tensor
from typing import *

from starrygl.distributed.cclib import *
from abc import ABC, abstractmethod
from contextlib import contextmanager


__all__ = [
    "SequencePipe",
]

OptTensor = Optional[Tensor]
SInteger = Sequence[int]
STensor = Sequence[Tensor]
SOptTensor = Sequence[OptTensor]
PairSTensor = Tuple[STensor, STensor]
OptSTensor = Optional[STensor]


class SequencePipe(ABC):
    def __init__(self) -> None:
        super().__init__()
        self._pos_begin = 0
        self._pos_end = 0
        self._pos_start = True

    @abstractmethod
    def get_group(self) -> Any:
        raise NotImplementedError

    @abstractmethod
    def get_init_states(self) -> STensor:
        raise NotImplementedError

    @abstractmethod
    def forward(self, inputs: STensor, states: STensor) -> PairSTensor:
        raise NotImplementedError

    def loss_fn(self, inputs: STensor, labels: STensor) -> Tensor:
        raise NotImplementedError

    def get_ranks(self) -> SInteger:
        world_size = dist.get_world_size(self.get_group())
        return tuple(range(world_size))

    def apply(self, bs: int, *inputs: Tensor):
        return SequencePipeFunction.apply(bs, self, *inputs)

    def fast_backward(self, bs: int, inputs: STensor, labels: STensor) -> Tensor:
        runtime = SequencePipeRuntime(bs, self, last_layer=True)
        inputs_grads = runtime.backward(inputs, labels)
        runtime.vector_backward(inputs, inputs_grads)
        return runtime.acc_loss

    @property
    def begin(self) -> int:
        return self._pos_begin

    @property
    def end(self) -> int:
        return self._pos_end

    @property
    def start(self) -> bool:
        return self._pos_start

    @property
    def batch_size(self) -> int:
        return self._pos_end - self._pos_begin

    @contextmanager
    def _switch(self, begin: int, end: int, start: bool):
        saved_begin = self._pos_begin
        saved_end = self._pos_end
        saved_start = self._pos_start

        self._pos_begin = begin
        self._pos_end = end
        self._pos_start = start
        yield
        self._pos_begin = saved_begin
        self._pos_end = saved_end
        self._pos_start = saved_start


class SequencePipeRuntime:
    def __init__(self,
        micro_batch_size: int,
        program: SequencePipe,
        last_layer: bool = False,
    ) -> None:
        self.micro_batch_size = micro_batch_size
        self.program = program
        self.last_layer = last_layer
        self.acc_loss = None

        self.group = program.get_group()
        self.ranks = program.get_ranks()
        self.index = self.ranks.index(dist.get_rank(self.group))
        self._last_work = None

    def forward(self, inputs: STensor) -> STensor:
        detach = self.detach_inputs(inputs)
        N = inputs[0].size(0)
        S = (N + self.micro_batch_size - 1) // self.micro_batch_size

        outputs = None
        for i in range(S):
            begin = i * self.micro_batch_size
            end = min(N, begin + self.micro_batch_size)
            start = (self.index == 0)

            with self.program._switch(begin, end, start):
                batch_inputs = self.get_batch_inputs(begin, end, detach)
                batch_states = self.get_batch_states(begin, end)
                batch_outputs, batch_states = self.forward_inner(batch_inputs, batch_states)

            if outputs is None:
                outputs = []
                for t in batch_outputs:
                    t = torch.empty(N, *t.shape[1:], dtype=t.dtype, device=t.device)
                    outputs.append(t)
                outputs = tuple(outputs)

            for t, b in zip(outputs, batch_outputs):
                t[begin:end] = b.data

            self.wait_last_work(
                next_one=self.save_states(batch_states),
            )
        self.wait_last_work()
        return outputs

    def backward(self, inputs: STensor, output_grads: OptSTensor = None) -> STensor:
        detach = self.detach_inputs(inputs)
        detach_grads = self.detach_inputs(output_grads)

        N = inputs[0].size(0)
        S = (N + self.micro_batch_size - 1) // self.micro_batch_size

        footprint = F1B1Footprint(self.index, len(self.ranks), S)
        source_footprint = F1B1Footprint(self.index - 1, len(self.ranks), S)
        target_footprint = F1B1Footprint(self.index + 1, len(self.ranks), S)

        fw_ready = {}
        bw_ready = {}

        input_grads = None
        while True:
            _, op, i = footprint.step()
            _, source_op, source_i = source_footprint.step()
            _, target_op, target_i = target_footprint.step()

            if (not op) and (not source_op) and (not target_op):
                break

            self.wait_last_work()
            if source_op == "backward":
                batch_input_state_grads, = bw_ready.pop(source_i)
                self.wait_last_work(
                    next_one=self.save_grads(batch_input_state_grads),
                )
            elif target_op == "forward":
                *_, batch_output_states = fw_ready[target_i]
                self.wait_last_work(
                    next_one=self.save_states(batch_output_states),
                )
                del _

            begin = i * self.micro_batch_size
            end = min(N, begin + self.micro_batch_size)
            start = (self.index == 0)

            with self.program._switch(begin, end, start):
                if op == "forward":
                    batch_inputs = self.get_batch_inputs(begin, end, detach, inputs)
                    batch_input_states = self.get_batch_states(begin, end, requires_grad=True)
                    batch_outputs, batch_output_states = \
                        self.forward_inner(batch_inputs, batch_input_states)
                    fw_ready[i] = [
                        batch_inputs, batch_outputs,
                        batch_input_states, batch_output_states,
                    ]
                elif op == "backward":
                    batch_inputs, batch_outputs, \
                        batch_input_states, batch_output_states = fw_ready.pop(i)
                    batch_output_grads = self.get_batch_inputs(begin, end, detach_grads)
                    batch_input_grads, batch_input_state_grads = \
                        self.backward_inner(
                            batch_inputs, batch_outputs,
                            batch_input_states, batch_output_states,
                            batch_output_grads,
                        )
                    bw_ready[i] = [
                        batch_input_state_grads,
                    ]

                    if input_grads is None:
                        input_grads = []
                        for t in batch_input_grads:
                            if t is not None:
                                t = torch.empty(N, *t.shape[1:], dtype=t.dtype, device=t.device)
                            input_grads.append(t)
                        input_grads = tuple(input_grads)

                    for g, t in zip(input_grads, batch_input_grads):
                        if g is not None:
                            g[begin:end] = t.data
        self.wait_last_work()
        return input_grads

    def wait_last_work(self, next_one=None):
        if self._last_work is not None:
            self._last_work.wait()
        self._last_work = next_one

    def detach_inputs(self, inputs: STensor) -> STensor:
        detach: STensor = []
        for t in inputs:
            assert t.size(0) == inputs[0].size(0), \
                "The first dimension of all tensors must be the same."
            detach.append(t.detach())
        return detach

    def get_batch_inputs(self,
        begin: int, end: int,
        detach: STensor,
        inputs: OptSTensor = None,
    ) -> STensor:
        batch: STensor = []
        for i, t in enumerate(detach):
            assert not t.requires_grad
            if inputs and inputs[i].requires_grad:
                t = t[begin:end]
                t.requires_grad_()
                t.retain_grad()
                batch.append(t)
            else:
                batch.append(t[begin:end])
        return batch

    def get_batch_states(self,
        begin: int, end: int,
        requires_grad: bool = False,
    ) -> STensor:
        states = []
        for s in self.program.get_init_states():
            s = s.unsqueeze(0).broadcast_to(end - begin, *s.size()).contiguous()
            if requires_grad and self.index > 0 and s.is_floating_point():
                s.requires_grad_()
                s.retain_grad()
            states.append(s)
        return states

    def forward_inner(self,
        inputs: STensor,
        states: STensor,
    ):
        states = self.load_states(states)
        return self.program.forward(inputs, states)

    def backward_inner(self,
        inputs: STensor,
        outputs: STensor,
        input_states: STensor,
        output_states: STensor,
        output_grads: STensor,
    ) -> PairSTensor:
        vloss = []
        vgrad = []
        if self.last_layer:
            vloss.append(self.program.loss_fn(outputs, output_grads))
            vgrad.append(torch.ones_like(vloss[-1]))
            if self.acc_loss is None:
                self.acc_loss = vloss[-1].detach()
            else:
                self.acc_loss += vloss[-1].detach()
        else:
            vloss.extend(outputs)
            vgrad.extend(output_grads)

        vloss.extend(output_states)
        vgrad.extend(self.load_grads(output_states))

        self.vector_backward(vloss, vgrad)

        input_grads = []
        for t in inputs:
            g, t.grad = t.grad, None
            input_grads.append(g)

        input_state_grads = []
        for s in input_states:
            g, s.grad = s.grad, None
            if s.is_floating_point():
                g = torch.zeros_like(s) if g is None else g
            else:
                g = None
            input_state_grads.append(g)
        return input_grads, input_state_grads

    def load_states(self, states: STensor):
        if self.index > 0:
            batch_recv(
                *[s.data for s in states],
                src=self.ranks[self.index - 1],
                group=self.group,
                async_op=True,
            ).wait()
        return states

    def load_grads(self, states: STensor):
        grads: SOptTensor = []
        if self.index + 1 < len(self.ranks):
            for s in states:
                if s.is_floating_point():
                    g = torch.zeros_like(s)
                    grads.append(g)
                else:
                    grads.append(None)
            batch_recv(
                *[g.data for g in grads if g is not None],
                src=self.ranks[self.index + 1],
                group=self.group,
                async_op=True,
            ).wait()
        else:
            for s in states:
                grads.append(None)
        return grads

    def save_states(self, states: STensor):
        if self.index + 1 < len(self.ranks):
            return batch_send(
                *[s.data for s in states],
                dst=self.ranks[self.index + 1],
                group=self.group,
                async_op=True,
            )
        else:
            return BatchWork(None, None)

    def save_grads(self, grads: STensor):
        if self.index > 0:
            return batch_send(
                *[g.data for g in grads if g is not None],
                dst=self.ranks[self.index - 1],
                group=self.group,
                async_op=True,
            )
        else:
            return BatchWork(None, None)

    def vector_backward(self, vloss: STensor, vgrad: STensor):
        loss: List[Tensor] = []
        grad: List[Tensor] = []
        for x, g in zip(vloss, vgrad):
            if g is None:
                continue
            if not x.requires_grad:
                continue
            loss.append(x.flatten())
            grad.append(g.flatten())
        if loss:
            loss = torch.cat(loss, dim=0)
            grad = torch.cat(grad, dim=0)
            loss.backward(grad)


class SequencePipeFunction(autograd.Function):
    @staticmethod
    def forward(
        ctx: autograd.function.FunctionCtx,
        micro_batch_size: int,
        program: SequencePipe,
        *inputs: Tensor,
    ):
        runtime = SequencePipeRuntime(micro_batch_size, program)
        ctx.save_for_backward(*inputs)
        ctx.saved_runtime = runtime
        return runtime.forward(inputs)

    @staticmethod
    def backward(
        ctx: autograd.function.FunctionCtx,
        *grads: Tensor,
    ):
        inputs: STensor = ctx.saved_tensors
        runtime: SequencePipeRuntime = ctx.saved_runtime
        with torch.enable_grad():
            input_grads = runtime.backward(inputs, grads)
        return None, None, *input_grads


# (The previous nn.Module-based SequencePipe implementation -- its __init__ with
# batch_size/seq_ranks/group, the state_forward/loss_fn abstract methods, the
# inference-mode forward(), the 1F1B backward(), and the _load_states /
# _save_states / _vector_backward helpers -- is retained at this point in the new
# file as a long commented-out block.)


class F1B1Footprint:
...
@@ -417,11 +811,11 @@ class F1B1Footprint:
        else:
            self._finish = False

-    def step(self) -> Tuple[int, Optional[str], Optional[int]]:
+    def step(self) -> Tuple[int, Optional[str], int]:
        if self._finish:
-            return (self._count, None, None)
+            return (self._count, None, -1)

-        ret = (self._count, "nop", None)
+        ret = (self._count, "nop", -1)
        if self._count >= self._bw_offset + 2 * self._bw_batch_id:
            ret = (self._count, "backward", self._bw_batch_id)
            self._bw_batch_id += 1
...
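SequencePipeRuntime.backward() interleaves forwards and backwards per micro-batch by stepping three F1B1Footprint schedules at once (its own, its predecessor's, and its successor's), so the state send/recv of neighbouring stages overlaps with local compute. F1B1Footprint itself is only partially visible in this diff, so the following is a hypothetical, self-contained illustration of a plain 1F1B ordering, not the library's implementation:

# one_f_one_b_sketch.py -- illustrative only; NOT F1B1Footprint's actual code.
def one_f_one_b(stage: int, num_stages: int, num_micro: int):
    """Return the (op, micro_batch) order a classic 1F1B schedule runs on one stage."""
    warmup = min(num_stages - stage - 1, num_micro)   # forwards before the first backward
    steps, fwd, bwd = [], 0, 0
    for _ in range(warmup):
        steps.append(("forward", fwd)); fwd += 1
    while bwd < num_micro:
        if fwd < num_micro:
            steps.append(("forward", fwd)); fwd += 1
        steps.append(("backward", bwd)); bwd += 1
    return steps

for stage in range(3):
    print(stage, one_f_one_b(stage, num_stages=3, num_micro=4))
# The first stage warms up with extra forwards while the last stage alternates
# forward/backward immediately; this is the same shape of schedule the runtime's
# footprint.step() calls walk through, batch by batch.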
starrygl/parallel/sparse.py  (view file @ 09c175e2)

@@ -58,6 +58,9 @@ class SparseBlocks:
    def __fetch_ids_sizes(local_ids: Tensor, group: Any):
        assert local_ids.dim() == 1
+        if group is None:
+            group = dist.GroupMember.WORLD
+
        rank = dist.get_rank(group)
        world_size = dist.get_world_size(group)
        ikw = dict(dtype=torch.long, device=local_ids.device)
...

@@ -80,8 +83,9 @@ class SparseBlocks:
                all_ids[i] = local_ids
            else:
                all_ids[i] = torch.empty(all_lens[i], **ikw)
+            src = dist.get_global_rank(group, i)
            all_get[i] = dist.broadcast(
-                all_ids[i], src=i, async_op=True, group=group
+                all_ids[i], src=src, async_op=True, group=group
            )
        imp: Tensor = torch.full((num_nodes,), (2**62-1)*2+1, **ikw)
...

@@ -151,13 +155,18 @@ class SparseBlockMM(autograd.Function):
        part_id = sp.part_id
        num_parts = sp.num_parts

+        group = sp.group
+        if group is None:
+            group = dist.GroupMember.WORLD
+
        def async_fetch(i: int):
            n = sp.adj_t(i).sparse_size(1)
            if i == part_id:
                h = x.clone()
            else:
                h = torch.empty(n, *x.shape[1:], dtype=x.dtype, device=x.device)
-            return dist.broadcast(h, src=i, group=sp.group, async_op=True)
+            src = dist.get_global_rank(group, i)
+            return dist.broadcast(h, src=src, group=sp.group, async_op=True)

        last_work = None
        out = None
...

@@ -192,9 +201,14 @@ class SparseBlockMM(autograd.Function):
        part_id = sp.part_id
        num_parts = sp.num_parts

+        group = sp.group
+        if group is None:
+            group = dist.GroupMember.WORLD
+
        def async_reduce(i: int, g: Tensor):
+            dst = dist.get_global_rank(group, i)
            return dist.reduce(
-                g, dst=i, op=dist.ReduceOp.SUM,
+                g, dst=dst, op=dist.ReduceOp.SUM,
                group=sp.group, async_op=True,
            )
...
train_hybrid.py  (new file, mode 100644, view file @ 09c175e2)

from typing import Any, List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist

from torch import Tensor
from typing import *

from starrygl.distributed import DistributedContext
from starrygl.data import GraphData
from starrygl.parallel import Route, SequencePipe
from starrygl.parallel.sequence import STensor
from starrygl.parallel.utils import *

import torch_geometric.nn as pyg_nn
import torch_geometric.datasets as pyg_datasets
import torch_geometric.utils as pyg_utils

import logging
logging.getLogger().setLevel(logging.INFO)


def prepare_data(root: str, num_parts, part_algo: str = "metis"):
    ctx = DistributedContext.get_default_context()

    data = pyg_datasets.Planetoid(root, "Cora")[0]

    if data.is_directed():
        data.edge_index, _ = pyg_utils.to_undirected(data.edge_index)
    data.edge_index, _ = pyg_utils.add_remaining_self_loops(data.edge_index)
    data.num_classes = data.y.max().item() + 1

    logging.info(f"num_nodes: {data.num_nodes}")
    logging.info(f"num_edges: {data.num_edges}")
    logging.info(f"num_features: {data.num_features}")
    logging.info(f"num_classes: {data.num_classes}")

    g = GraphData.from_pyg_data(data)

    logging.info(f"GraphData.meta().keys(): {g.meta().keys()}")
    logging.info(f"GraphData.node().keys(): {g.node().keys()}")
    logging.info(f"GraphData.edge().keys(): {g.edge().keys()}")

    g.save_partition(root, num_parts, part_algo)
    return g


class SimpleConv(pyg_nn.MessagePassing):
    def __init__(self, in_feats: int, out_feats: int):
        super().__init__(aggr="mean")
        self.linear = nn.Linear(in_feats, out_feats)

    def forward(self, x: Tensor, edge_index: Tensor, route: Route):
        dst_len = x.size(0)
        x = route.apply(x)  # exchange features
        return self.propagate(edge_index, x=x)[:dst_len]

    def message(self, x_j: Tensor):
        return x_j

    def update(self, x: Tensor):
        return F.relu(self.linear(x))


class SimpleGNN(nn.Module):
    def __init__(self,
        num_features: int,
        hidden_dims: int,
        num_layers: int,
    ) -> None:
        super().__init__()
        self.layers = nn.ModuleList()
        for i in range(num_layers):
            in_ch = hidden_dims if i > 0 else num_features
            out_ch = hidden_dims
            self.layers.append(SimpleConv(in_ch, out_ch))

    def forward(self, x: Tensor, edge_index: Tensor, route: Route):
        for layer in self.layers:
            x = layer(x, edge_index, route)
        return x


class SimpleRNN(SequencePipe, nn.Module):
    def __init__(self,
        num_classes: int,
        hidden_dims: int,
        num_layers: int,
        device: Any,
        group: Any,
    ) -> None:
        super().__init__()
        self.device = device
        self.group = group

        self.num_layers = num_layers
        self.hidden_dims = hidden_dims

        self.gru = nn.GRU(
            input_size=hidden_dims,
            hidden_size=hidden_dims,
            num_layers=num_layers,
            batch_first=True,
        )
        self.out = nn.Linear(hidden_dims, num_classes)

    def forward(self, inputs, states):
        x, = inputs  # (N, L, H)
        h, = states  # (N, L, H)

        h = h.transpose(0, 1).contiguous()  # (L, N, H)
        x, h = self.gru(x, h)               # (N, L, H), (L, N, H)
        h = h.transpose(0, 1).contiguous()  # (N, L, H)
        return (x,), (h,)

    def loss_fn(self, inputs, labels) -> Tensor:
        x, = inputs
        return x.square().mean()

    def get_group(self) -> Any:
        return self.group

    def get_init_states(self):
        s = torch.zeros(self.num_layers, self.hidden_dims).to(self.device)
        return (s,)


if __name__ == "__main__":
    data_root = "./dataset"
    ctx = DistributedContext.init(backend="nccl", use_gpu=True)

    hybrid_matrix = ctx.get_hybrid_matrix()
    if hybrid_matrix.size(0) == 1:
        hybrid_matrix = hybrid_matrix.view(2, -1)
    ctx.sync_print(hybrid_matrix)

    # sp is sequence parallel
    # pp is partition parallel
    sp_group, pp_group = ctx.new_hybrid_subgroups(hybrid_matrix)

    # partition data
    if ctx.rank == 0:
        prepare_data(data_root, dist.get_world_size(pp_group))
    dist.barrier()

    g = GraphData.load_partition(
        data_root,
        dist.get_rank(pp_group),
        dist.get_world_size(pp_group),
    ).to(ctx.device)

    route = g.to_route(pp_group)  # only on subgroup

    num_features = g.node("dst")["x"].size(-1)
    num_classes = g.meta()["num_classes"]
    hidden_dims = 128
    num_layers = 3

    gnn = SimpleGNN(num_features, hidden_dims, num_layers).to(ctx.device)
    rnn = SimpleRNN(num_classes, hidden_dims, num_layers,
                    device=ctx.device, group=sp_group).to(ctx.device)
    opt = torch.optim.Adam([p for p in gnn.parameters()] + [p for p in rnn.parameters()])

    for ep in range(1, 100 + 1):
        seq_len = 200
        xs = []
        opt.zero_grad()
        for _ in range(seq_len):
            # snapshot parallel between partition parallel subgroups
            z = gnn(
                x=g.node("dst")["x"],
                edge_index=g.edge_index(),
                route=route,
                #
            )
            xs.append(z.unsqueeze(1))
        x = torch.cat(xs, dim=1)  # (N, S, H)

        # loss = rnn.apply(32, x)[0].square().mean()
        # loss.backward()  # sequence and pipeline parallel on each graph nodes
        loss = rnn.fast_backward(32, (x,), (g.node("dst")["train_mask"],))

        # all reduce
        all_reduce_gradients(rnn)
        all_reduce_buffers(rnn)
        all_reduce_gradients(gnn)
        all_reduce_buffers(gnn)
        opt.step()
        ctx.sync_print(loss)
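One detail of the demo worth noting: on a single machine get_hybrid_matrix() returns a single row, so the script reshapes it with view(2, -1) to keep both parallel dimensions non-trivial (this implicitly assumes an even world size). A small runnable sketch of just that reshaping, using plain tensors and no process groups:

# hybrid_reshape_sketch.py -- illustrative only, not part of this commit.
import torch

hybrid_matrix = torch.arange(4, dtype=torch.long).view(1, -1)   # one host, 4 ranks
if hybrid_matrix.size(0) == 1:
    hybrid_matrix = hybrid_matrix.view(2, -1)

print(hybrid_matrix)
# tensor([[0, 1],
#         [2, 3]])
# Rows become the sequence-parallel groups the SimpleRNN pipeline runs over;
# columns become the partition-parallel groups the SimpleGNN's Route exchanges over.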