zhlj / starrygl-DynamicHistory
Commit 88de1d9c
authored Dec 21, 2023 by Wenjie Huang

    SequencePipe: support long-typed tensors

parent 32fec45c
Showing 4 changed files with 199 additions and 117 deletions:

    cora.py                          +1    -1
    starrygl/distributed/cclib.py    +6    -0
    starrygl/parallel/sequence.py    +172  -116
    starrygl/parallel/utils.py       +20   -0
cora.py  (+1 -1)

@@ -4,7 +4,7 @@ from torch_geometric.utils import add_remaining_self_loops, to_undirected
 import os.path as osp
 import sys
 
-from starrygl.graph import GraphData
+from starrygl.data import GraphData
 
 import logging
 logging.getLogger().setLevel(logging.INFO)
starrygl/distributed/cclib.py  (+6 -0)

@@ -149,6 +149,9 @@ def batch_send(
     group: Any = None,
     async_op: bool = False,
 ):
+    if len(tensors) == 0:
+        return BatchWork(None, None)
+
     # tensors = tuple(t.data for t in tensors)
     backend = dist.get_backend(group)

@@ -171,6 +174,9 @@ def batch_recv(
     group: Any = None,
     async_op: bool = False,
 ):
+    if len(tensors) == 0:
+        return BatchWork(None, None)
+
     # tensors = tuple(t.data for t in tensors)
     backend = dist.get_backend(group)
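Note: the two hunks above add the same guard to batch_send and batch_recv, so an empty tensor list now short-circuits to an inert BatchWork instead of reaching the communication backend. Below is a minimal caller sketch of the pattern this enables; the helper functions and their arguments are illustrative, not part of this repository.

    # Hypothetical callers: with the new guard, a rank whose tensor tuple happens
    # to be empty can invoke the collectives unconditionally and just wait().
    from starrygl.distributed.cclib import batch_send, batch_recv

    def push_states(states, dst_rank, group=None):
        # `states` may be an empty tuple; batch_send then returns BatchWork(None, None),
        # which the rest of the codebase already treats as immediately-complete work.
        work = batch_send(*states, dst=dst_rank, group=group, async_op=True)
        work.wait()

    def pull_states(buffers, src_rank, group=None):
        # Same on the receive side: nothing is posted when `buffers` is empty.
        work = batch_recv(*buffers, src=src_rank, group=group, async_op=True)
        work.wait()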
starrygl/parallel/sequence.py  (+172 -116)

@@ -6,10 +6,14 @@ import torch.distributed as dist
 from torch import Tensor
 from typing import *
 
+from starrygl.distributed import DistributedContext
 from starrygl.distributed.cclib import batch_send, batch_recv, BatchWork
 
+__all__ = [
+    "SequencePipe",
+]
+
 class SequencePipe:
     def __init__(self,
         batch_size: int,

@@ -45,77 +49,12 @@ class SequencePipe:
     def group(self):
         return self._group
     
-    def _load_states(self, states: Tuple[Tensor,...], grad: bool = False):
-        if grad: # from target to source
-            for s in states:
-                s.grad = torch.zeros_like(s.data)
-            if self.index + 1 < self.num_ranks:
-                batch_recv(
-                    *[s.grad for s in states],
-                    src=self.seq_ranks[self.index + 1],
-                    group=self.group,
-                    async_op=True,
-                ).wait()
-        else: # from source to target
-            for s in states:
-                s.grad = None
-            if self.index > 0:
-                batch_recv(
-                    *[s.data for s in states],
-                    src=self.seq_ranks[self.index - 1],
-                    group=self.group,
-                    async_op=True,
-                ).wait()
-    
-    def _save_states(self, states: Tuple[Tensor,...], grad: bool = False) -> BatchWork:
-        if grad: # from target to source
-            if self.index > 0:
-                return batch_send(
-                    *[s.grad for s in states],
-                    dst=self.seq_ranks[self.index - 1],
-                    group=self.group,
-                    async_op=True,
-                )
-        else: # from source to target
-            if self.index + 1 < self.num_ranks:
-                return batch_send(
-                    *[s.data for s in states],
-                    dst=self.seq_ranks[self.index + 1],
-                    group=self.group,
-                    async_op=True
-                )
-        return BatchWork(None, None)
-    
     def init_states(self, *states: Tensor):
         for s in states:
             assert isinstance(s, Tensor), f"states must be tuple of tensors"
         self._init_states = tuple(states)
         return self
     
-    def batch_size_(self, bs: int):
-        self._batch_size = bs
-        return self
-    
-    def _get_batch_states(self, bs: int, requires_grad: bool = False) -> Tuple[Tensor,...]:
-        assert self._init_states is not None, "please call init_states()."
-        states: List[Tensor] = []
-        for s in self._init_states:
-            s = s.unsqueeze(0).broadcast_to(bs, *s.size())
-            states.append(s.contiguous())
-        if requires_grad and self.index > 0:
-            states = [s.requires_grad_() for s in states]
-        return tuple(states)
-    
-    def _detach_inputs(self, inputs: Tuple[Tensor,...]) -> Tuple[Tensor,...]:
-        assert inputs[0].size(0) > 0, "input tensors' batch dimension must be greater than one."
-        detach_inputs: List[Tensor] = []
-        for t in inputs:
-            assert t.size(0) == inputs[0].size(0), "all tensors' batch dimension must be the exact same."
-            detach_inputs.append(t.detach())
-        return tuple(detach_inputs)
-    
     def state_forward(self,
         batch_id: int,
         inputs: Tuple[Tensor,...],

@@ -129,58 +68,21 @@ class SequencePipe:
     ) -> Tensor:
         raise NotImplementedError()
     
-    def _forward_inner(self,
-        batch_id: int,
-        inputs: Tuple[Tensor,...],
-        states: Tuple[Tensor,...],
-    ) -> Tuple[Tuple[Tensor,...], Tuple[Tensor,...], BatchWork]:
-        self._load_states(states)
-        outputs, next_states = self.state_forward(batch_id, inputs, states)
-        work = self._save_states(next_states)
-        return outputs, next_states, work
-    
-    def _backward_inner(self,
-        batch_id: int,
-        inputs: Tuple[Tensor,...],
-        outputs: Tuple[Tensor,...],
-        input_states: Tuple[Tensor,...],
-        output_states: Tuple[Tensor,...],
-        scale_factor: float = 1.0,
-    ) -> Tuple[Tuple[Tensor,...], Tensor, BatchWork]:
-        loss = self.loss_fn(batch_id, outputs)
-        if scale_factor != 1.0:
-            loss = loss * scale_factor
-        total_loss = loss
-        
-        self._load_states(output_states, grad=True)
-        for s in output_states:
-            g, s.grad = s.grad, None
-            total_loss += torch.sum(s * g)
-        total_loss.backward()
-        
-        work = self._save_states(input_states, grad=True)
-        input_grads = []
-        for t in inputs:
-            input_grads.append(t.grad)
-            t.grad = None
-        return tuple(input_grads), loss.detach(), work
-    
+    @torch.inference_mode()
     def forward(self, *inputs: Tensor) -> Tuple[Tensor,...]:
         inputs = self._detach_inputs(inputs)
         B = inputs[0].size(0)
         num_batchs = (B + self.batch_size - 1) // self.batch_size
         
-        with torch.no_grad():
-            outputs = []
-            last_work = None
+        last_work = None
+        outputs = []
         for batch_id in range(num_batchs):
             start = batch_id * self.batch_size
             end = min(B, start + self.batch_size)
-            batch_inputs = tuple(t[start:end] for t in inputs)
+            batch_inputs = self._get_batch_inputs(start, end, inputs)
             batch_states = self._get_batch_states(end - start)
             batch_outputs, _, work = self._forward_inner(batch_id, batch_inputs, batch_states)
             outputs.append(batch_outputs)

@@ -198,8 +100,8 @@ class SequencePipe:
         return tuple(concat_outputs)
     
-    def backward(self, *inputs: Tensor, scale: float = 1.0) -> Tuple[Tuple[Tensor,...], Tensor]:
-        inputs = self._detach_inputs(inputs)
+    def backward(self, *inputs: Tensor, scale: float = 1.0) -> Tensor:
+        # inputs = self._detach_inputs(inputs)
         B = inputs[0].size(0)
         num_batchs = (B + self.batch_size - 1) // self.batch_size

@@ -211,7 +113,7 @@ class SequencePipe:
         bw_ready: Dict[int, Any] = {}
         last_bw_work = None
-        input_grads = []
+        input_batch_grads = []
         total_loss = None
         # hist = []

@@ -232,7 +134,7 @@ class SequencePipe:
                     last_bw_work = work
                     
                     total_loss = loss if total_loss is None else total_loss + loss
-                    input_grads.append(grads)
+                    input_batch_grads.append(grads)
                     bw_batch_id += 1
                 else:

@@ -241,7 +143,7 @@ class SequencePipe:
             if fw_batch_id < num_batchs and (fw_starting or fw_occupying):
                 start = fw_batch_id * self.batch_size
                 end = min(B, start + self.batch_size)
-                batch_inputs = tuple(t[start:end].requires_grad_() for t in inputs)
+                batch_inputs = self._get_batch_inputs(start, end, inputs, requires_grad=True)
                 batch_states = self._get_batch_states(end - start, requires_grad=True)
                 batch_outputs, batch_next_states, work = self._forward_inner(fw_batch_id, batch_inputs, batch_states)

@@ -258,13 +160,167 @@ class SequencePipe:
                 fw_batch_id += 1
             count += 1
         
-        concat_grads = []
-        for ts in zip(*input_grads):
-            concat_grads.append(torch.cat(ts, dim=0))
+        input_grads = []
+        for ts in zip(*input_batch_grads):
+            if None in ts:
+                input_grads.append(None)
+            else:
+                input_grads.append(torch.cat(ts, dim=0))
         
         if last_bw_work is not None:
             last_bw_work.wait()
+        self._vector_backward(inputs, input_grads)
         
         # print(f"{self.index+1}: {hist}")
-        return tuple(concat_grads), total_loss
+        return total_loss
+    
+    def _load_states(self, states: Tuple[Tensor,...], grad: bool = False):
+        for s in states:
+            s.grad = None
+        if grad: # from target to source
+            if self.index + 1 < self.num_ranks:
+                s_grads = []
+                for s in states:
+                    if s.dtype.is_floating_point:
+                        s.grad = torch.zeros_like(s.data)
+                        s_grads.append(s.grad)
+                batch_recv(
+                    *s_grads,
+                    src=self.seq_ranks[self.index + 1],
+                    group=self.group,
+                    async_op=True,
+                ).wait()
+        else: # from source to target
+            if self.index > 0:
+                batch_recv(
+                    *[s.data for s in states],
+                    src=self.seq_ranks[self.index - 1],
+                    group=self.group,
+                    async_op=True,
+                ).wait()
+    
+    def _save_states(self, states: Tuple[Tensor,...], grad: bool = False) -> BatchWork:
+        if grad: # from target to source
+            if self.index > 0:
+                s_grads = []
+                for s in states:
+                    g, s.grad = s.grad, None
+                    if s.dtype.is_floating_point:
+                        s_grads.append(torch.zeros_like(s) if g is None else g)
+                return batch_send(
+                    *s_grads,
+                    dst=self.seq_ranks[self.index - 1],
+                    group=self.group,
+                    async_op=True,
+                )
+        else: # from source to target
+            if self.index + 1 < self.num_ranks:
+                return batch_send(
+                    *[s.data for s in states],
+                    dst=self.seq_ranks[self.index + 1],
+                    group=self.group,
+                    async_op=True
+                )
+        return BatchWork(None, None)
+    
+    def _vector_backward(self, vec_loss: List[Tensor], vec_grad: List[Optional[Tensor]]):
+        loss = []
+        grad = []
+        for x, g in zip(vec_loss, vec_grad):
+            if g is None:
+                continue
+            if not x.requires_grad:
+                continue
+            # if not x.dtype.is_floating_point:
+            #     continue
+            loss.append(x.flatten())
+            grad.append(g.flatten())
+        if len(loss) != 0:
+            loss = torch.cat(loss, dim=0)
+            grad = torch.cat(grad, dim=0)
+            loss.backward(grad)
+    
+    def _get_batch_inputs(self, start: int, end: int, inputs: Tuple[Tensor,...], requires_grad: bool = False) -> Tuple[Tensor,...]:
+        _inputs: List[Tensor] = []
+        for t in inputs:
+            if requires_grad and t.requires_grad:
+                t = t.detach()[start:end]
+                t.requires_grad_()
+                t.retain_grad()
+            else:
+                t = t.detach()[start:end]
+            _inputs.append(t)
+        return tuple(_inputs)
+    
+    def _get_batch_states(self, bs: int, requires_grad: bool = False) -> Tuple[Tensor,...]:
+        assert self._init_states is not None, "please call init_states()."
+        states: List[Tensor] = []
+        for s in self._init_states:
+            s = s.unsqueeze(0).broadcast_to(bs, *s.size()).contiguous()
+            if requires_grad and self.index > 0 and s.dtype.is_floating_point:
+                s.requires_grad_()
+                s.retain_grad()
+            states.append(s)
+        return tuple(states)
+    
+    def _detach_inputs(self, inputs: Tuple[Tensor,...]) -> Tuple[Tensor,...]:
+        assert inputs[0].size(0) > 0, "input tensors' batch dimension must be greater than one."
+        detach_inputs: List[Tensor] = []
+        for t in inputs:
+            assert t.size(0) == inputs[0].size(0), "all tensors' batch dimension must be the exact same."
+            detach_inputs.append(t.detach())
+        return tuple(detach_inputs)
+    
+    def _forward_inner(self,
+        batch_id: int,
+        inputs: Tuple[Tensor,...],
+        states: Tuple[Tensor,...],
+    ) -> Tuple[Tuple[Tensor,...], Tuple[Tensor,...], BatchWork]:
+        self._load_states(states)
+        outputs, next_states = self.state_forward(batch_id, inputs, states)
+        work = self._save_states(next_states)
+        return outputs, next_states, work
+    
+    def _backward_inner(self,
+        batch_id: int,
+        inputs: Tuple[Tensor,...],
+        outputs: Tuple[Tensor,...],
+        input_states: Tuple[Tensor,...],
+        output_states: Tuple[Tensor,...],
+        scale_factor: float = 1.0,
+    ) -> Tuple[Tuple[Tensor,...], Tensor, BatchWork]:
+        loss = self.loss_fn(batch_id, outputs)
+        if scale_factor != 1.0:
+            loss = loss * scale_factor
+        
+        vec_loss = [loss]
+        vec_grad = [torch.ones_like(loss)]
+        
+        self._load_states(output_states, grad=True)
+        for s in output_states:
+            if s.requires_grad:
+                s.retain_grad()
+            g, s.grad = s.grad, None
+            vec_loss.append(s)
+            vec_grad.append(g)
+        self._vector_backward(vec_loss, vec_grad)
+        
+        work = self._save_states(input_states, grad=True)
+        input_grads = []
+        for t in inputs:
+            g, t.grad = t.grad, None
+            input_grads.append(g)
+        return tuple(input_grads), loss.detach(), work
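Note: in the reworked SequencePipe, _load_states and _save_states exchange gradients only for floating-point state tensors, and _vector_backward skips entries with no gradient. This is what lets a stage carry long-typed state (counters, indices) between ranks; when every state tensor is integer-typed, the gradient exchange reduces to the empty batch_send/batch_recv calls guarded above in cclib.py. A rough sketch of a subclass mixing a float hidden state with a long step counter follows; state_forward, loss_fn and init_states are the hooks defined in this file, while the class name, model and shapes are invented for illustration.

    import torch
    from torch import Tensor
    from typing import Tuple
    from starrygl.parallel.sequence import SequencePipe

    class CounterPipe(SequencePipe):
        # Illustrative subclass: one float state (h) and one long state (step).
        def state_forward(self, batch_id: int,
                          inputs: Tuple[Tensor, ...],
                          states: Tuple[Tensor, ...]):
            x, = inputs
            h, step = states          # h: floating point, step: torch.long
            h = torch.tanh(h + x)     # differentiable part of the state
            step = step + 1           # long state: its data is forwarded to the next
                                      # rank, but no gradient is sent back for it
            return (h,), (h, step)    # (outputs, next_states), as unpacked by _forward_inner

        def loss_fn(self, batch_id: int, outputs: Tuple[Tensor, ...]) -> Tensor:
            return outputs[0].square().mean()

    # Hypothetical setup (constructor arguments beyond batch_size are not shown in
    # this diff): per-element initial states are broadcast to the batch size by
    # _get_batch_states(); the long state simply rides along without gradients.
    # pipe = CounterPipe(batch_size=32).init_states(
    #     torch.zeros(16),                    # float hidden state
    #     torch.zeros(16, dtype=torch.long),  # long counter state
    # )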
starrygl/parallel/utils.py  (new file, +20 -0)

import torch
import torch.nn as nn
import torch.distributed as dist

from torch import Tensor
from typing import *

__all__ = [
    "all_reduce_gradients",
    "all_reduce_buffers",
]

def all_reduce_gradients(net: nn.Module, op=dist.ReduceOp.SUM, group=None):
    for p in net.parameters():
        dist.all_reduce(p.grad, op=op, group=group)

def all_reduce_buffers(net: nn.Module, op=dist.ReduceOp.AVG, group=None):
    for b in net.buffers():
        dist.all_reduce(b.data, op=op, group=group)
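Note: the new helpers are thin wrappers around dist.all_reduce for manually synchronized data parallelism, summing parameter gradients and averaging buffers (for example, BatchNorm running statistics). A usage sketch under assumptions not stated in the file: the process group is already initialized, every parameter has a non-None .grad after backward, and the model, optimizer and loss below are placeholders.

    import torch
    import torch.nn as nn
    import torch.distributed as dist
    from starrygl.parallel.utils import all_reduce_gradients, all_reduce_buffers

    def train_step(net: nn.Module, opt: torch.optim.Optimizer, x: torch.Tensor, y: torch.Tensor):
        # Assumes dist.init_process_group(...) has been called on every rank.
        opt.zero_grad()
        loss = nn.functional.mse_loss(net(x), y)
        loss.backward()
        all_reduce_gradients(net)   # SUM-reduce each parameter's .grad across ranks
        all_reduce_buffers(net)     # AVG-reduce registered buffers
        opt.step()
        return loss.detach()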