zhlj / BTS-MTGNN · Commits

Commit 2177c0ed
authored Dec 21, 2023 by Wenjie Huang

    add batched all_reduce() in SequencePipe

parent 88de1d9c

Showing 2 changed files with 300 additions and 93 deletions (+300 -93)

  starrygl/parallel/sequence.py   +200 -88
  starrygl/parallel/utils.py      +100 -5
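For context, the new enable_all_reduce()/disable_all_reduce() switches added below might be used roughly as follows. This is a hedged sketch only: MySeqPipe, its constructor arguments, and the tensors passed to backward() stand in for whatever a concrete SequencePipe subclass defines; the defaults (SUM for gradients, AVG for buffers) are the ones set in this commit.

    import torch

    # hypothetical concrete subclass of starrygl.parallel.sequence.SequencePipe
    pipe = MySeqPipe(batch_size=32).enable_all_reduce()
    opt = torch.optim.Adam(pipe.parameters(), lr=1e-3)  # parameters() is available now that SequencePipe is an nn.Module

    opt.zero_grad()
    loss = pipe.backward(x, y)   # 1F1B micro-batch pipeline, then batched async all_reduce of grads/buffers
    opt.step()

    pipe.disable_all_reduce()    # e.g. when running on a single rank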
starrygl/parallel/sequence.py

@@ -7,6 +7,9 @@ from torch import Tensor
 from typing import *
 
 from starrygl.distributed.cclib import batch_send, batch_recv, BatchWork
+from abc import ABC, abstractmethod
+
+from .utils import all_reduce_buffers, all_reduce_gradients
 
 __all__ = [
@@ -14,13 +17,17 @@ __all__ = [
 ]
 
-class SequencePipe:
+class SequencePipe(nn.Module, ABC):
     def __init__(self,
         batch_size: int,
-        seq_ranks: List[int],
-        group: Any,
+        seq_ranks: Optional[List[int]] = None,
+        group: Any = None,
     ) -> None:
+        super().__init__()
         self._batch_size = int(batch_size)
+        if seq_ranks is None:
+            seq_ranks = list(range(dist.get_world_size(group)))
+
         self._seq_ranks = tuple(seq_ranks)
         self._group = group
@@ -28,6 +35,9 @@ class SequencePipe:
         self._index = self._seq_ranks.index(rank)
 
         self._init_states: Optional[Tuple[Tensor,...]] = None
+        self._all_reduce_grads_op = None
+        self._all_reduce_buffers_op = None
+
     @property
     def batch_size(self) -> int:
@@ -55,6 +65,20 @@ class SequencePipe:
         self._init_states = tuple(states)
         return self
 
+    def enable_all_reduce(self,
+        grads_op = dist.ReduceOp.SUM,
+        buffers_op = dist.ReduceOp.AVG,
+    ):
+        self._all_reduce_grads_op = grads_op
+        self._all_reduce_buffers_op = buffers_op
+        return self
+
+    def disable_all_reduce(self):
+        self._all_reduce_grads_op = None
+        self._all_reduce_buffers_op = None
+        return self
+
+    @abstractmethod
     def state_forward(self,
         batch_id: int,
         inputs: Tuple[Tensor,...],
@@ -62,6 +86,7 @@ class SequencePipe:
     ) -> Tuple[Tuple[Tensor,...], Tuple[Tensor,...]]:
         raise NotImplementedError()
 
+    @abstractmethod
     def loss_fn(self,
         batch_id: int,
         outputs: Tuple[Tensor,...],
@@ -70,111 +95,157 @@ class SequencePipe:
     @torch.inference_mode()
     def forward(self, *inputs: Tensor) -> Tuple[Tensor,...]:
-        inputs = self._detach_inputs(inputs)
+        detach = self._detach_inputs(inputs)
         B = inputs[0].size(0)
         num_batchs = (B + self.batch_size - 1) // self.batch_size
 
         last_work = None
-        outputs = []
+        outputs = None
         for batch_id in range(num_batchs):
             start = batch_id * self.batch_size
             end = min(B, start + self.batch_size)
 
-            batch_inputs = self._get_batch_inputs(start, end, inputs)
+            batch_inputs = self._get_batch_inputs(start, end, detach)
             batch_states = self._get_batch_states(end - start)
 
-            batch_outputs, _, work = self._forward_inner(batch_id, batch_inputs, batch_states)
-            outputs.append(batch_outputs)
+            batch_outputs, batch_states = self._forward_inner(batch_id, batch_inputs, batch_states)
+
+            if outputs is None:
+                outputs = []
+                for t in batch_outputs:
+                    t = torch.empty(B, *t.shape[1:], dtype=t.dtype, device=t.device)
+                    outputs.append(t)
+                outputs = tuple(outputs)
+
+            for o, t in zip(outputs, batch_outputs):
+                o[start:end] = t.data
 
             if last_work is not None:
                 last_work.wait()
-            last_work = work
-
-        concat_outputs = []
-        for ts in zip(*outputs):
-            concat_outputs.append(torch.cat(ts, dim=0))
+            last_work = self._save_states(batch_states)
 
         if last_work is not None:
             last_work.wait()
-        return tuple(concat_outputs)
+        return outputs
 
     def backward(self, *inputs: Tensor, scale: float = 1.0) -> Tensor:
+        # inputs = self._detach_inputs(inputs)
+        detach = self._detach_inputs(inputs)
         B = inputs[0].size(0)
         num_batchs = (B + self.batch_size - 1) // self.batch_size
 
-        fw_batch_id = 0
-        bw_batch_id = 0
-        bw_offset = 2 * self.num_ranks - self.index - 1
+        footprint = F1B1Footprint(self.index, self.num_ranks, num_batchs)
+        source_footprint = F1B1Footprint(self.index - 1, self.num_ranks, num_batchs)
+        target_footprint = F1B1Footprint(self.index + 1, self.num_ranks, num_batchs)
 
         fw_ready: Dict[int, Any] = {}
         bw_ready: Dict[int, Any] = {}
 
-        last_bw_work = None
-        input_batch_grads = []
+        last_work = None
+        # input_batch_grads = []
+        input_grads = None
         total_loss = None
-        # hist = []
-        count = 0
-        while fw_batch_id < num_batchs or bw_batch_id < num_batchs:
-            if count >= bw_offset + 2 * bw_batch_id:
-                if bw_batch_id < num_batchs:
-                    bs, work, *fw_graph = bw_ready.pop(bw_batch_id)
-                    work.wait()
-
-                    scale_factor = scale * self.batch_size / bs
-                    grads, loss, work = self._backward_inner(bw_batch_id, *fw_graph, scale_factor=scale_factor)
-                    # hist.append(f"{count+1}bw{bw_batch_id + 1}")
-
-                    if last_bw_work is not None:
-                        last_bw_work.wait()
-                    last_bw_work = work
-
-                    total_loss = loss if total_loss is None else total_loss + loss
-                    input_batch_grads.append(grads)
-                bw_batch_id += 1
-            else:
-                fw_starting = (fw_batch_id < self.num_ranks and count >= fw_batch_id + self.index)
-                fw_occupying = (count >= self.index + 2 * fw_batch_id)
-                if fw_batch_id < num_batchs and (fw_starting or fw_occupying):
-                    start = fw_batch_id * self.batch_size
-                    end = min(B, start + self.batch_size)
-
-                    batch_inputs = self._get_batch_inputs(start, end, inputs, requires_grad=True)
-                    batch_states = self._get_batch_states(end - start, requires_grad=True)
-
-                    batch_outputs, batch_next_states, work = self._forward_inner(fw_batch_id, batch_inputs, batch_states)
-                    # hist.append(f"{count+1}fw{fw_batch_id + 1}")
-
-                    bw_ready[fw_batch_id] = [
-                        end - start, work,
-                        batch_inputs, batch_outputs,
-                        batch_states, batch_next_states,
-                    ]
-                    fw_batch_id += 1
-            count += 1
-
-        input_grads = []
-        for ts in zip(*input_batch_grads):
-            if None in ts:
-                input_grads.append(None)
-            else:
-                input_grads.append(torch.cat(ts, dim=0))
-
-        if last_bw_work is not None:
-            last_bw_work.wait()
+        while True:
+            _, op, batch_id = footprint.step()
+            _, source_op, source_batch_id = source_footprint.step()
+            _, target_op, target_batch_id = target_footprint.step()
+            if op is None and source_op is None and target_op is None:
+                break
+
+            if last_work is not None:
+                last_work.wait()
+                last_work = None
+
+            if source_op == "backward":
+                input_states, = bw_ready.pop(source_batch_id)
+                last_work = self._save_states(input_states, grad=True)
+                del input_states
+            elif target_op == "forward":
+                *_, output_states = fw_ready[target_batch_id]
+                last_work = self._save_states(output_states)
+                del _, output_states
+
+            if op == "forward":
+                start = batch_id * self.batch_size
+                end = min(B, start + self.batch_size)
+
+                batch_inputs = self._get_batch_inputs(start, end, detach, inputs)
+                batch_input_states = self._get_batch_states(end - start, requires_grad=True)
+
+                batch_outputs, batch_output_states = self._forward_inner(
+                    batch_id, batch_inputs, batch_input_states,
+                )
+
+                fw_ready[batch_id] = [
+                    batch_inputs,
+                    batch_outputs,
+                    batch_input_states,
+                    batch_output_states,
+                ]
+            elif op == "backward":
+                start = batch_id * self.batch_size
+                end = min(B, start + self.batch_size)
+
+                batch_inputs, batch_outputs, batch_input_states, batch_output_states = fw_ready.pop(batch_id)
+
+                scale_factor = scale * self.batch_size / (end - start)
+                grads, loss = self._backward_inner(
+                    batch_id,
+                    batch_inputs, batch_outputs,
+                    batch_input_states, batch_output_states,
+                    scale_factor=scale_factor,
+                )
+
+                bw_ready[batch_id] = [batch_input_states,]
+
+                total_loss = loss if total_loss is None else total_loss + loss
+
+                if input_grads is None:
+                    input_grads = []
+                    for t in grads:
+                        if t is not None:
+                            t = torch.empty(B, *t.shape[1:], dtype=t.dtype, device=t.device)
+                        input_grads.append(t)
+                    input_grads = tuple(input_grads)
+
+                for g, t in zip(input_grads, grads):
+                    if g is not None:
+                        g[start:end] = t.data
+
+        if last_work is not None:
+            last_work.wait()
 
+        prev_works = self._prev_inputs_backward()
         self._vector_backward(inputs, input_grads)
+        self._post_inputs_backward(prev_works)
 
-        # print(f"{self.index+1}: {hist}")
         return total_loss
 
+    def _prev_inputs_backward(self):
+        works = []
+        works.extend(all_reduce_gradients(self,
+            op=self._all_reduce_grads_op,
+            group=self.group,
+            async_op=True,
+        ))
+        works.extend(all_reduce_buffers(self,
+            op=self._all_reduce_buffers_op,
+            group=self.group,
+            async_op=True,
+        ))
+        return works
+
+    def _post_inputs_backward(self, works):
+        for w in works:
+            w.wait()
+        works.clear()
+
     def _load_states(self, states: Tuple[Tensor,...], grad: bool = False):
         for s in states:
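For orientation between the hunks above and below: user code plugs into this driver through the two abstract methods, state_forward() and loss_fn(), plus init_states(). The following is a minimal, hypothetical sketch only; ToyGRUPipe, the tensor shapes, and the loss are illustrative, init_states()' exact signature is not visible in this diff, and running it requires an initialized torch.distributed process group.

    import torch
    import torch.nn as nn
    from torch import Tensor
    from typing import Tuple
    from starrygl.parallel.sequence import SequencePipe

    class ToyGRUPipe(SequencePipe):
        def __init__(self, input_dim: int, hidden_dim: int, batch_size: int):
            super().__init__(batch_size)              # seq_ranks/group now default to the whole world
            self.cell = nn.GRUCell(input_dim, hidden_dim)
            # assumed call: register the per-sample initial state template(s)
            self.init_states(torch.zeros(hidden_dim))

        def state_forward(self, batch_id: int,
                          inputs: Tuple[Tensor, ...],
                          states: Tuple[Tensor, ...]) -> Tuple[Tuple[Tensor, ...], Tuple[Tensor, ...]]:
            x, y = inputs                             # (bs, input_dim), (bs,) float targets
            h, = states                               # (bs, hidden_dim), received from the previous stage
            h = self.cell(x, h)
            return (h, y), (h,)                       # outputs go to loss_fn, next states to the next stage

        def loss_fn(self, batch_id: int, outputs: Tuple[Tensor, ...]) -> Tensor:
            h, y = outputs
            return torch.nn.functional.mse_loss(h.sum(dim=-1), y)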
@@ -248,19 +319,22 @@ class SequencePipe:
     def _get_batch_inputs(self,
         start: int, end: int,
-        inputs: Tuple[Tensor,...],
-        requires_grad: bool = False
+        detach: Tuple[Tensor,...],
+        inputs: Optional[Tuple[Tensor,...]] = None,
     ) -> Tuple[Tensor,...]:
-        _inputs: List[Tensor] = []
-        for t in inputs:
-            if requires_grad and t.requires_grad:
-                t = t.detach()[start:end]
+        outs: List[Tensor] = []
+        for i, t in enumerate(detach):
+            assert not t.requires_grad
+            if inputs is None:
+                outs.append(t[start:end])
+            elif inputs[i].requires_grad:
+                t = t[start:end]
                 t.requires_grad_()
                 t.retain_grad()
+                outs.append(t)
             else:
-                t = t.detach()[start:end]
-            _inputs.append(t)
-        return tuple(_inputs)
+                outs.append(t[start:end])
+        return tuple(outs)
 
     def _get_batch_states(self, bs: int, requires_grad: bool = False) -> Tuple[Tensor,...]:
         assert self._init_states is not None, "please call init_states()."
@@ -285,12 +359,11 @@ class SequencePipe:
         batch_id: int,
         inputs: Tuple[Tensor,...],
         states: Tuple[Tensor,...],
-    ) -> Tuple[Tuple[Tensor,...], Tuple[Tensor,...], BatchWork]:
+    ) -> Tuple[Tuple[Tensor,...], Tuple[Tensor,...]]:
         self._load_states(states)
         outputs, next_states = self.state_forward(batch_id, inputs, states)
-        work = self._save_states(next_states)
-        return outputs, next_states, work
+        return outputs, next_states
 
     def _backward_inner(self,
         batch_id: int,
@@ -299,7 +372,7 @@ class SequencePipe:
         input_states: Tuple[Tensor,...],
         output_states: Tuple[Tensor,...],
         scale_factor: float = 1.0,
-    ) -> Tuple[Tuple[Tensor,...], Tensor, BatchWork]:
+    ) -> Tuple[Tuple[Tensor,...], Tensor]:
         loss = self.loss_fn(batch_id, outputs)
         if scale_factor != 1.0:
             loss = loss * scale_factor
@@ -316,11 +389,49 @@ class SequencePipe:
             vec_grad.append(g)
         self._vector_backward(vec_loss, vec_grad)
 
-        work = self._save_states(input_states, grad=True)
-
         input_grads = []
         for t in inputs:
             g, t.grad = t.grad, None
             input_grads.append(g)
-        return tuple(input_grads), loss.detach(), work
+        return tuple(input_grads), loss.detach()
+
+
+class F1B1Footprint:
+    def __init__(self,
+        index: int,
+        num_ranks: int,
+        num_batchs: int,
+    ) -> None:
+        self._index = index
+        self._num_ranks = num_ranks
+        self._num_batchs = num_batchs
+
+        self._bw_offset = 2 * self._num_ranks - self._index - 1
+
+        self._fw_batch_id = 0
+        self._bw_batch_id = 0
+        self._count = 0
+
+        if index < 0 or index > num_ranks:
+            self._finish = True
+        else:
+            self._finish = False
+
+    def step(self) -> Tuple[int, Optional[str], Optional[int]]:
+        if self._finish:
+            return (self._count, None, None)
+
+        ret = (self._count, "nop", None)
+        if self._count >= self._bw_offset + 2 * self._bw_batch_id:
+            ret = (self._count, "backward", self._bw_batch_id)
+            self._bw_batch_id += 1
+        elif self._fw_batch_id < self._num_batchs:
+            if self._count >= self._index + 2 * self._fw_batch_id:
+                ret = (self._count, "forward", self._fw_batch_id)
+                self._fw_batch_id += 1
+
+        if self._bw_batch_id >= self._num_batchs:
+            self._finish = True
+
+        self._count += 1
+        return ret
\ No newline at end of file
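The schedule F1B1Footprint emits can be inspected on its own. An illustrative run (assuming the starrygl package and its torch dependencies are importable); the expected sequence below is traced by hand from the step() logic above for stage 0 of a 2-stage pipeline with 3 micro-batches:

    from starrygl.parallel.sequence import F1B1Footprint

    fp = F1B1Footprint(index=0, num_ranks=2, num_batchs=3)
    while True:
        count, op, batch_id = fp.step()
        if op is None:
            break
        print(count, op, batch_id)

    # should print a one-forward-one-backward interleaving:
    #   0 forward 0
    #   1 nop None
    #   2 forward 1
    #   3 backward 0
    #   4 forward 2
    #   5 backward 1
    #   6 nop None
    #   7 backward 2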
starrygl/parallel/utils.py

@@ -5,16 +5,111 @@ import torch.distributed as dist
 from torch import Tensor
 
 from typing import *
+from collections import defaultdict
 
 __all__ = [
     "all_reduce_gradients",
     "all_reduce_buffers",
 ]
 
-def all_reduce_gradients(net: nn.Module, op = dist.ReduceOp.SUM, group = None):
+# def all_reduce_gradients(net: nn.Module, op = dist.ReduceOp.SUM, group = None, async_op: bool = False):
+#     works = []
+#     for p in net.parameters():
+#         if p.grad is None:
+#             p.grad = torch.zeros_like(p.data)
+#         w = dist.all_reduce(p.grad, op=op, group=group, async_op=async_op)
+#         works.append(w)
+#     if async_op:
+#         return works
+
+# def all_reduce_buffers(net: nn.Module, op = dist.ReduceOp.AVG, group = None, async_op: bool = False):
+#     works = []
+#     for b in net.buffers():
+#         w = dist.all_reduce(b.data, op=op, group=group, async_op=async_op)
+#         works.append(w)
+#     if async_op:
+#         return works
+
+def all_reduce_gradients(net: nn.Module, op = dist.ReduceOp.SUM, group = None, async_op: bool = False):
+    device = None
+    works = []
+    if op is None:
+        return works
+
+    typed_numel = defaultdict(lambda: 0)
+    for p in net.parameters():
+        typed_numel[p.dtype] += p.numel()
+        device = p.device
+
+    if device is None:
+        return works
+
+    typed_tensors: Dict[torch.dtype, Tensor] = {}
+    for t, n in typed_numel.items():
+        typed_tensors[t] = torch.zeros(n, dtype=t, device=device)
+
+    typed_offset = defaultdict(lambda: 0)
     for p in net.parameters():
-        dist.all_reduce(p.grad, op=op, group=group)
+        s = typed_offset[p.dtype]
+        t = s + p.numel()
+        typed_offset[p.dtype] = t
+
+        if p.grad is not None:
+            typed_tensors[p.dtype][s:t] = p.grad.flatten()
+
+        storage = typed_tensors[p.dtype].untyped_storage()
+        g = torch.empty(0, dtype=p.dtype, device=device)
+        p.grad = g.set_(storage, s, p.size(), default_stride(*p.size()))
+
+    for t in typed_tensors.values():
+        w = dist.all_reduce(t, op=op, group=group, async_op=async_op)
+        if async_op:
+            works.append(w)
+    return works
+
+def all_reduce_buffers(net: nn.Module, op = dist.ReduceOp.AVG, group = None, async_op: bool = False):
+    device = None
+    works = []
+    if op is None:
+        return works
+
+    typed_numel = defaultdict(lambda: 0)
+    for p in net.buffers():
+        typed_numel[p.dtype] += p.numel()
+        device = p.device
+
+    if device is None:
+        return works
+
+    typed_tensors: Dict[torch.dtype, Tensor] = {}
+    for t, n in typed_numel.items():
+        typed_numel[t] = torch.zeros(n, dtype=t, device=device)
+
+    typed_offset = defaultdict(lambda: 0)
+    for p in net.buffers():
+        s = typed_offset[p.dtype]
+        t = s + p.numel()
+        typed_offset[p.dtype] = t
+
+        typed_tensors[p.dtype][s:t] = p.flatten()
+
+        storage = typed_tensors[p.dtype].untyped_storage()
+        p.set_(storage, s, p.size(), default_stride(*p.size()))
+
+    for t in typed_tensors.values():
+        w = dist.all_reduce(t, op=op, group=group, async_op=async_op)
+        if async_op:
+            works.append(w)
+    return works
 
-def all_reduce_buffers(net: nn.Module, op = dist.ReduceOp.AVG, group = None):
-    for b in net.buffers():
-        dist.all_reduce(b.data, op=op, group=group)
+def default_stride(*size: int) -> Tuple[int,...]:
+    dims = len(size)
+    stride = [1] * dims
+    for i in range(1, dims):
+        k = dims - i
+        stride[k - 1] = stride[k] * size[k]
+    return tuple(stride)
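A hedged usage sketch of the batched helpers above (assumes torch.distributed is already initialized, e.g. with the gloo or nccl backend; the Linear model is arbitrary):

    import torch
    import torch.nn as nn
    import torch.distributed as dist
    from starrygl.parallel.utils import all_reduce_gradients, all_reduce_buffers

    model = nn.Linear(16, 4)
    model(torch.randn(8, 16)).sum().backward()

    # one fused all_reduce per dtype instead of one collective per parameter/buffer
    works = []
    works += all_reduce_gradients(model, op=dist.ReduceOp.SUM, async_op=True)
    works += all_reduce_buffers(model, op=dist.ReduceOp.AVG, async_op=True)
    for w in works:
        w.wait()

After all_reduce_gradients returns, each parameter's .grad is rebound as a view into the fused per-dtype buffer, so the reduced values are visible to the optimizer without an extra copy.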