Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
S
starrygl-DynamicHistory
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
zhlj
starrygl-DynamicHistory
Commits
f4af3231
Commit
f4af3231
authored
Jul 24, 2023
by
Wenjie Huang
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
rpc based
parent
3f94b191
Show whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
1076 additions
and
360 deletions
+1076
-360
.gitignore
+4
-0
starrygl/loader/__init__.py
+313
-163
starrygl/loader/buffer.py
+164
-25
starrygl/loader/context.py
+31
-0
starrygl/loader/route.py
+238
-0
starrygl/loader/utils.py
+0
-6
starrygl/nn/base_model.py
+41
-41
starrygl/parallel/__init__.py
+63
-13
starrygl/parallel/degree.py
+27
-27
train.py
+195
-85
No files found.
.gitignore
View file @
f4af3231
...
@@ -160,3 +160,6 @@ cython_debug/
...
@@ -160,3 +160,6 @@ cython_debug/
#.idea/
#.idea/
cora/
cora/
/test_*
/*.ipynb
/s.py
\ No newline at end of file
starrygl/loader/__init__.py
View file @
f4af3231
...
@@ -4,201 +4,351 @@ import torch.nn as nn
...
@@ -4,201 +4,351 @@ import torch.nn as nn
from
torch
import
Tensor
from
torch
import
Tensor
from
typing
import
*
from
typing
import
*
from
.buffer
import
TensorBuffers
from
torch_sparse
import
SparseTensor
from
torch_scatter
import
scatter_min
from
multiprocessing.pool
import
ThreadPool
class
BatchHandle
:
from
.buffer
import
TensorBuffer
from
.route
import
Route
from
.context
import
RouteContext
# class BatchHandle:
# def __init__(self,
# src_ids: Tensor,
# dst_size: int,
# feat_buffer: TensorBuffers,
# grad_buffer: TensorBuffers,
# with_feat0: bool = False,
# layer_id: Optional[int] = None,
# target_device: Any = None,
# ) -> None:
# self.src_ids = src_ids
# self.dst_size = dst_size
# self.feat_buffer = feat_buffer
# self.grad_buffer = grad_buffer
# self.with_feat0 = with_feat0
# self.layer_id = layer_id
# if target_device is None:
# self.target_device = src_ids.device
# else:
# self.target_device = torch.device(target_device)
# if with_feat0:
# _ = self.feat0
# @property
# def dst_ids(self) -> Tensor:
# return self.src_ids[:self.dst_size]
# @property
# def src_size(self) -> int:
# return self.src_ids.size(0)
# @property
# def device(self) -> torch.device:
# return self.src_ids.device
# @property
# def feat0(self) -> Tensor:
# if not hasattr(self, "_feat0"):
# self._feat0 = self.feat_buffer.get(0, self.src_ids).to(self.target_device)
# return self._feat0
# def fetch_feat(self, layer_id: Optional[int] = None) -> Tensor:
# layer_id = int(self.layer_id if layer_id is None else layer_id)
# return self.feat_buffer.get(layer_id, self.src_ids).to(self.target_device)
# def update_feat(self, x: Tensor, layer_id: Optional[int] = None):
# assert x.size(0) == self.dst_size
# layer_id = int(self.layer_id if layer_id is None else layer_id)
# self.feat_buffer.set(layer_id + 1, self.dst_ids, x.detach().to(self.device))
# def push_and_pull(self, x: Tensor, layer_id: Optional[int] = None) -> Tensor:
# assert x.size(0) == self.dst_size
# layer_id = int(self.layer_id if layer_id is None else layer_id)
# self.feat_buffer.set(layer_id + 1, self.dst_ids, x.detach().to(self.device))
# o = self.feat_buffer.get(layer_id + 1, self.src_ids[self.dst_size:]).to(x.device)
# return torch.cat([x, o], dim=0)
# def fetch_grad(self, layer_id: Optional[int] = None) -> Tensor:
# layer_id = int(self.layer_id if layer_id is None else layer_id)
# return self.grad_buffer.get(layer_id + 1, self.dst_ids).to(self.target_device)
# def accumulate_grad(self, x: Tensor, layer_id: Optional[int] = None):
# assert x.size(0) == self.src_size
# layer_id = int(self.layer_id if layer_id is None else layer_id)
# self.grad_buffer.add(layer_id, self.src_ids, x.detach().to(self.device))
class
NodeHandle
:
def
__init__
(
self
,
def
__init__
(
self
,
src_ids
:
Tensor
,
src_ids
:
Tensor
,
dst_size
:
int
,
dst_size
:
int
,
feat_buffer
:
TensorBuffers
,
edge_ids
:
Tensor
,
grad_buffer
:
TensorBuffers
,
adj_t
:
SparseTensor
,
with_feat0
:
bool
=
False
,
executor
:
ThreadPool
,
layer_id
:
Optional
[
int
]
=
None
,
device
:
Any
,
target_device
:
Any
=
None
,
stream
:
Optional
[
torch
.
cuda
.
Stream
]
=
None
,
edge_time
:
Optional
[
Tensor
]
=
None
,
node_time
:
Optional
[
Tensor
]
=
None
,
)
->
None
:
)
->
None
:
self
.
src_ids
=
src_ids
self
.
_src_ids
=
src_ids
self
.
dst_size
=
dst_size
self
.
_dst_size
=
dst_size
self
.
feat_buffer
=
feat_buffer
self
.
_edge_ids
=
edge_ids
self
.
grad_buffer
=
grad_buffer
self
.
with_feat0
=
with_feat0
self
.
_device
=
torch
.
device
(
device
)
self
.
layer_id
=
layer_id
if
self
.
_device
!=
torch
.
device
(
"cpu"
)
and
stream
is
None
:
self
.
_stream
=
torch
.
cuda
.
Stream
(
self
.
_device
)
if
target_device
is
None
:
else
:
self
.
target_device
=
src_ids
.
device
self
.
_stream
=
stream
self
.
_executor
=
executor
self
.
_adj_t
=
adj_t
.
to
(
self
.
device
)
self
.
_node_time
=
node_time
self
.
_edge_time
=
edge_time
def
get_src_feats
(
self
,
data
:
Union
[
Tensor
,
TensorBuffer
]):
return
self
.
async_select
(
data
,
self
.
src_ids
)
def
get_dst_feats
(
self
,
data
:
Union
[
Tensor
,
TensorBuffer
]):
return
self
.
async_select
(
data
,
self
.
dst_ids
)
def
get_ext_feats
(
self
,
data
:
Union
[
Tensor
,
TensorBuffer
]):
return
self
.
async_select
(
data
,
self
.
ext_ids
)
def
get_edge_feats
(
self
,
data
:
Union
[
Tensor
,
TensorBuffer
]):
return
self
.
async_select
(
data
,
self
.
edge_ids
)
def
get_src_time
(
self
):
return
self
.
async_select
(
self
.
_node_time
,
self
.
src_ids
)
def
get_dst_time
(
self
):
return
self
.
async_select
(
self
.
_node_time
,
self
.
dst_ids
)
def
get_ext_time
(
self
):
return
self
.
async_select
(
self
.
_node_time
,
self
.
ext_ids
)
def
get_edge_time
(
self
)
->
Tensor
:
return
self
.
async_select
(
self
.
_edge_time
,
self
.
edge_ids
)
def
push_and_pull
(
self
,
value
:
Tensor
,
ext_fut
:
torch
.
futures
.
Future
[
Tensor
],
data
:
Union
[
Tensor
,
TensorBuffer
]):
fut
=
self
.
async_update
(
data
,
value
,
self
.
dst_ids
)
ext_value
=
ext_fut
.
wait
()
x
=
torch
.
cat
([
value
,
ext_value
],
dim
=
0
)
return
x
,
fut
def
async_select
(
self
,
src
:
Union
[
Tensor
,
TensorBuffer
],
idx
:
Tensor
)
->
torch
.
futures
.
Future
[
Tensor
]:
fut
=
torch
.
futures
.
Future
(
devices
=
[
self
.
_device
])
def
run
():
try
:
with
torch
.
no_grad
():
with
torch
.
cuda
.
stream
(
self
.
_stream
):
index
=
idx
.
to
(
src
.
device
)
if
isinstance
(
src
,
TensorBuffer
):
val
=
src
.
local_get
(
index
,
lock
=
True
)
else
:
else
:
self
.
target_device
=
torch
.
device
(
target_device
)
val
=
src
.
index_select
(
0
,
index
)
val
=
val
.
to
(
self
.
_device
)
fut
.
set_result
(
val
)
except
Exception
as
e
:
fut
.
set_exception
(
e
)
self
.
_executor
.
apply_async
(
run
)
return
fut
def
async_update
(
self
,
src
:
Union
[
Tensor
,
TensorBuffer
],
val
:
Tensor
,
idx
:
Tensor
,
ops
:
str
=
"mov"
)
->
torch
.
futures
.
Future
:
assert
ops
in
[
"mov"
,
"add"
]
fut
=
torch
.
futures
.
Future
(
devices
=
[
self
.
_device
])
def
run
():
try
:
with
torch
.
no_grad
():
with
torch
.
cuda
.
stream
(
self
.
_stream
):
value
=
val
.
to
(
src
.
device
)
index
=
idx
.
to
(
src
.
device
)
if
isinstance
(
src
,
TensorBuffer
):
if
ops
==
"mov"
:
src
.
all_remote_set
(
value
,
index
,
lock
=
True
)
elif
ops
==
"add"
:
src
.
all_remote_add
(
value
,
index
,
lock
=
True
)
else
:
if
ops
==
"mov"
:
src
[
index
]
=
value
elif
ops
==
"add"
:
src
[
index
]
+=
value
fut
.
set_result
(
None
)
except
Exception
as
e
:
fut
.
set_exception
(
e
)
self
.
_executor
.
apply_async
(
run
)
return
fut
if
with_feat0
:
@property
_
=
self
.
feat0
def
adj_t
(
self
)
->
SparseTensor
:
return
self
.
_adj_t
@property
@property
def
dst
_ids
(
self
)
->
Tensor
:
def
src
_ids
(
self
)
->
Tensor
:
return
self
.
src_ids
[:
self
.
dst_size
]
return
self
.
_src_ids
@property
@property
def
src_size
(
self
)
->
int
:
def
src_size
(
self
)
->
int
:
return
self
.
src_ids
.
size
(
0
)
return
self
.
_
src_ids
.
size
(
0
)
@property
@property
def
d
evice
(
self
)
->
torch
.
device
:
def
d
st_ids
(
self
)
->
Tensor
:
return
self
.
src_ids
.
device
return
self
.
_src_ids
[:
self
.
_dst_size
]
@property
@property
def
feat0
(
self
)
->
Tensor
:
def
dst_size
(
self
)
->
int
:
if
not
hasattr
(
self
,
"_feat0"
):
return
self
.
_dst_size
self
.
_feat0
=
self
.
feat_buffer
.
get
(
0
,
self
.
src_ids
)
.
to
(
self
.
target_device
)
return
self
.
_feat0
def
fetch_feat
(
self
,
layer_id
:
Optional
[
int
]
=
None
)
->
Tensor
:
layer_id
=
int
(
self
.
layer_id
if
layer_id
is
None
else
layer_id
)
return
self
.
feat_buffer
.
get
(
layer_id
,
self
.
src_ids
)
.
to
(
self
.
target_device
)
def
update_feat
(
self
,
x
:
Tensor
,
layer_id
:
Optional
[
int
]
=
None
):
assert
x
.
size
(
0
)
==
self
.
dst_size
layer_id
=
int
(
self
.
layer_id
if
layer_id
is
None
else
layer_id
)
self
.
feat_buffer
.
set
(
layer_id
+
1
,
self
.
dst_ids
,
x
.
detach
()
.
to
(
self
.
device
))
def
push_and_pull
(
self
,
x
:
Tensor
,
layer_id
:
Optional
[
int
]
=
None
)
->
Tensor
:
assert
x
.
size
(
0
)
==
self
.
dst_size
layer_id
=
int
(
self
.
layer_id
if
layer_id
is
None
else
layer_id
)
self
.
feat_buffer
.
set
(
layer_id
+
1
,
self
.
dst_ids
,
x
.
detach
()
.
to
(
self
.
device
))
o
=
self
.
feat_buffer
.
get
(
layer_id
+
1
,
self
.
src_ids
[
self
.
dst_size
:])
.
to
(
x
.
device
)
return
torch
.
cat
([
x
,
o
],
dim
=
0
)
def
fetch_grad
(
self
,
layer_id
:
Optional
[
int
]
=
None
)
->
Tensor
:
layer_id
=
int
(
self
.
layer_id
if
layer_id
is
None
else
layer_id
)
return
self
.
grad_buffer
.
get
(
layer_id
+
1
,
self
.
dst_ids
)
.
to
(
self
.
target_device
)
def
accumulate_grad
(
self
,
x
:
Tensor
,
layer_id
:
Optional
[
int
]
=
None
):
assert
x
.
size
(
0
)
==
self
.
src_size
layer_id
=
int
(
self
.
layer_id
if
layer_id
is
None
else
layer_id
)
self
.
grad_buffer
.
add
(
layer_id
,
self
.
src_ids
,
x
.
detach
()
.
to
(
self
.
device
))
class
DataLoader
:
def
__init__
(
self
,
node_parts
:
Tensor
,
feat_buffer
:
TensorBuffers
,
grad_buffer
:
TensorBuffers
,
edge_index
:
Tensor
,
edge_attr
:
Optional
[
Tensor
]
=
None
,
edge_time
:
Optional
[
Tensor
]
=
None
,
node_time
:
Optional
[
Tensor
]
=
None
,
)
->
None
:
self
.
src_size
=
feat_buffer
.
src_size
self
.
dst_size
=
node_parts
.
size
(
0
)
assert
node_parts
.
size
(
0
)
<=
self
.
src_size
assert
grad_buffer
.
src_size
==
self
.
src_size
num_parts
=
node_parts
.
max
()
.
item
()
+
1
@property
cluster
,
node_perm
=
node_parts
.
sort
(
dim
=
0
)
def
ext_ids
(
self
)
->
Tensor
:
node_ptr
:
Tensor
=
torch
.
ops
.
torch_sparse
.
ind2ptr
(
cluster
,
num_parts
)
return
self
.
_src_ids
[
self
.
_dst_size
:]
edge_parts
=
node_parts
[
edge_index
[
1
]]
@property
cluster
,
edge_perm
=
edge_parts
.
sort
()
def
ext_size
(
self
)
->
int
:
edge_ptr
:
Tensor
=
torch
.
ops
.
torch_sparse
.
ind2ptr
(
cluster
,
num_parts
)
return
self
.
src_size
-
self
.
dst_size
self
.
num_parts
=
num_parts
@property
self
.
node_ptr
=
node_ptr
def
edge_ids
(
self
)
->
Tensor
:
self
.
edge_ptr
=
edge_ptr
return
self
.
_edge_ids
self
.
node_perm
=
node_perm
self
.
edge_perm
=
edge_perm
self
.
edge_index
=
edge_index
[:,
edge_perm
]
if
edge_attr
is
not
None
:
@property
self
.
edge_attr
=
edge_attr
[
edge_perm
]
def
edge_size
(
self
)
->
Tensor
:
return
self
.
_edge_ids
.
size
(
0
)
if
node_time
is
not
None
:
@property
self
.
node_time
=
node_time
[
node_perm
]
def
device
(
self
)
->
torch
.
device
:
return
self
.
_device
if
edge_time
is
not
None
:
class
NodeLoader
:
self
.
edge_time
=
edge_time
[
edge_perm
]
def
__init__
(
self
,
global_ids
:
Tensor
,
global_edges
:
Tensor
,
device
:
Any
,
edge_time
:
Optional
[
Tensor
]
=
None
,
node_time
:
Optional
[
Tensor
]
=
None
,
num_threads
:
int
=
1
,
)
->
None
:
self
.
route
,
edge_index
=
Route
.
from_edge_index
(
global_ids
,
global_edges
)
self
.
feat_buffer
=
feat_buffer
if
node_time
is
None
and
edge_time
is
not
None
:
self
.
grad_buffer
=
grad_buffer
node_time
=
scatter_min
(
edge_time
[
edge_index
[
0
]],
edge_index
[
1
],
dim
=
0
,
dim_size
=
self
.
dst_size
)
def
iter
(
self
,
batch_size
:
int
=
1
,
layer_id
:
Optional
[
int
]
=
None
,
seed
:
int
=
0
,
filter
:
Callable
[[
Tensor
],
Tensor
]
=
None
,
device
=
None
):
node_ids
=
torch
.
arange
(
self
.
dst_size
)
.
type_as
(
global_ids
)
rnd
=
torch
.
Generator
()
if
node_time
is
not
None
:
if
seed
!=
0
:
perm
=
node_time
.
argsort
()
rnd
.
manual_seed
(
seed
)
node_ids
=
node_ids
[
perm
]
sampled
=
torch
.
randperm
(
self
.
num_parts
,
generator
=
rnd
,
dtype
=
torch
.
long
)
s
=
0
self
.
node_ids
=
node_ids
imp
=
torch
.
empty
(
self
.
src_size
,
dtype
=
torch
.
long
)
self
.
node_time
=
node_time
while
s
<
sampled
.
size
(
0
):
t
=
min
(
s
+
batch_size
,
sampled
.
size
(
0
))
dst_ids
=
[]
edge_ids
=
torch
.
arange
(
edge_index
.
size
(
1
))
.
type_as
(
global_ids
)
edge_index
=
[]
perm
=
edge_index
[
1
]
.
argsort
()
edge_attr
=
[]
imp
.
zero_
()
for
i
in
sampled
[
s
:
t
]
.
tolist
():
s
+=
batch_size
a
,
b
=
self
.
node_ptr
[
i
:
i
+
2
]
.
tolist
()
nidx
=
self
.
node_perm
[
a
:
b
]
if
hasattr
(
self
,
"node_time"
)
and
filter
is
not
None
:
node_mask
=
filter
(
self
.
node_time
[
a
:
b
])
if
not
node_mask
.
any
():
continue
nidx
=
nidx
[
node_mask
]
else
:
node_mask
=
None
a
,
b
=
self
.
edge_ptr
[
i
:
i
+
2
]
.
tolist
()
edge_ids
=
edge_ids
[
perm
]
eidx
=
self
.
edge_index
[:,
a
:
b
]
edge_index
=
edge_index
[:,
perm
]
if
hasattr
(
self
,
"edge_time"
)
and
filter
is
not
None
:
if
node_mask
is
None
:
edge_mask
=
filter
(
self
.
edge_time
[
a
:
b
])
else
:
imp
[
nidx
]
=
i
+
1
edge_mask
=
(
imp
[
eidx
[
1
]]
==
i
+
1
)
edge_mask
&=
filter
(
self
.
edge_time
[
a
:
b
])
if
not
edge_mask
.
any
():
continue
eidx
=
eidx
[:,
edge_mask
]
else
:
edge_mask
=
None
dst_ids
.
append
(
nidx
)
self
.
edge_ids
=
edge_ids
edge_index
.
append
(
eidx
)
self
.
edge_index
=
edge_index
self
.
edge_time
=
edge_time
# self.rowptr = torch.ops.torch_sparse.ind2ptr(self.edge_index[1], self.dst_size)
if
hasattr
(
self
,
"edge_attr"
):
self
.
device
=
torch
.
device
(
device
)
attr
=
self
.
edge_attr
[
a
:
b
]
if
self
.
device
!=
torch
.
device
(
"cpu"
):
if
edge_mask
is
None
:
self
.
stream
=
torch
.
cuda
.
Stream
(
self
.
device
)
edge_attr
.
append
(
attr
)
else
:
else
:
edge_attr
.
append
(
attr
[
edge_mask
])
self
.
stream
=
None
self
.
executor
=
ThreadPool
(
num_threads
)
if
len
(
dst_ids
)
==
0
or
len
(
edge_index
)
==
0
:
continue
dst_ids
=
torch
.
cat
(
dst_ids
,
dim
=-
1
)
edge_index
=
torch
.
cat
(
edge_index
,
dim
=-
1
)
if
hasattr
(
self
,
"edge_attr"
):
edge_attr
=
torch
.
cat
(
edge_attr
,
dim
=
0
)
else
:
edge_attr
=
None
imp
.
zero_
()
@property
imp
.
index_fill_
(
0
,
edge_index
[
0
],
1
)
.
index_fill_
(
0
,
dst_ids
,
0
)
def
src_size
(
self
)
->
int
:
src_ids
=
torch
.
cat
([
dst_ids
,
torch
.
where
(
imp
>
0
)[
0
]],
dim
=-
1
)
return
self
.
route
.
src_size
assert
(
src_ids
[:
dst_ids
.
size
(
0
)]
==
dst_ids
)
.
all
()
imp
.
fill_
((
2
**
62
-
1
)
*
2
+
1
)
@property
imp
[
src_ids
]
=
torch
.
arange
(
src_ids
.
size
(
0
),
dtype
=
torch
.
long
)
def
dst_size
(
self
)
->
int
:
return
self
.
route
.
dst_size
# def _select_nodes(self, start_time: int, end_time: int) -> Tuple[Optional[Tensor], Optional[Tensor]]:
# if self.node_time is None:
# if self.edge_time is None:
# raise RuntimeError("neither node_time nor edge_time exists!")
# else:
# edge_mask = (start_time <= self.edge_time) & (self.edge_time < end_time)
# idx = self.edge_index[1, edge_mask]
# node_mask = torch.zeros(self.dst_size, dtype=torch.bool, device=idx.device).index_fill_(0, idx, 1)
# return torch.where(node_mask)[0], None
# else:
# time_range = torch.tensor([start_time, end_time]).type_as(self.node_time)
# s, t = torch.searchsorted(self.node_time, time_range).tolist()
# return self.node_ids[s:t], self.node_time[s:t]
def
_get_node_handle
(
self
,
node_ids
:
Tensor
,
edge_ids
:
Tensor
,
edge_index
:
Tensor
)
->
NodeHandle
:
num_node
:
int
=
torch
.
max
(
node_ids
.
max
(),
edge_index
.
max
())
.
item
()
+
1
imp
=
torch
.
zeros
(
num_node
,
dtype
=
torch
.
bool
,
device
=
node_ids
.
device
)
imp
.
index_fill_
(
0
,
edge_index
[
0
],
1
)
imp
.
index_fill_
(
0
,
node_ids
,
0
)
src_ids
=
torch
.
cat
([
node_ids
,
torch
.
where
(
imp
)[
0
]],
dim
=
0
)
src_size
=
src_ids
.
size
(
0
)
dst_size
=
node_ids
.
size
(
0
)
imp
=
torch
.
zeros
(
num_node
,
dtype
=
torch
.
long
,
device
=
node_ids
.
device
)
.
fill_
((
2
**
62
-
1
)
*
2
+
1
)
imp
[
src_ids
]
=
torch
.
arange
(
src_size
,
dtype
=
torch
.
long
,
device
=
node_ids
.
device
)
edge_index
=
imp
[
edge_index
.
flatten
()]
.
view_as
(
edge_index
)
edge_index
=
imp
[
edge_index
.
flatten
()]
.
view_as
(
edge_index
)
handle
=
BatchHandle
(
perm
=
edge_index
[
1
]
.
argsort
()
src_ids
,
dst_ids
.
size
(
0
),
edge_ids
=
edge_ids
[
perm
]
self
.
feat_buffer
,
self
.
grad_buffer
,
edge_index
=
edge_index
[:,
perm
]
with_feat0
=
(
layer_id
is
None
),
layer_id
=
layer_id
,
rowptr
=
torch
.
ops
.
torch_sparse
.
ind2ptr
(
edge_index
[
1
],
dst_size
)
target_device
=
device
,
adj_t
=
SparseTensor
(
rowptr
=
rowptr
,
col
=
edge_index
[
0
],
sparse_sizes
=
(
dst_size
,
src_size
))
return
NodeHandle
(
src_ids
=
src_ids
,
dst_size
=
dst_size
,
edge_ids
=
edge_ids
,
adj_t
=
adj_t
,
executor
=
self
.
executor
,
device
=
self
.
device
,
stream
=
self
.
stream
,
edge_time
=
self
.
edge_time
,
node_time
=
self
.
node_time
,
)
)
edge_index
=
edge_index
.
to
(
device
)
edge_attr
=
edge_attr
.
to
(
device
)
yield
handle
,
edge_index
,
edge_attr
# very very very slow !!!
def iter(self,
         batch_size: int,
         wind: bool = False,
         start_time: int = 0,
         end_time: int = 9223372036854775807,
         seed: int = 0):
    """Yield one node handle per mini-batch of destination nodes.

    Args:
        batch_size: maximum number of destination nodes per batch.
        wind: when False, batches are drawn from a random permutation of all
            ``self.dst_size`` nodes and no time filtering is applied. When
            True, only nodes whose ``self.node_time`` lies in
            ``[start_time, end_time)`` are batched, and (if ``self.edge_time``
            is set) edges are restricted to the same half-open interval.
        start_time: inclusive lower bound of the window (used when ``wind``).
        end_time: exclusive upper bound of the window (used when ``wind``).
        seed: seed for the permutation generator; ``0`` leaves the freshly
            created ``torch.Generator`` unseeded.

    Yields:
        Whatever ``self._get_node_handle(node_ids, edge_ids, edge_index)``
        returns for each batch.

    Raises:
        AssertionError: if ``wind`` is True and ``self.node_time`` is None.
    """
    rand_gen = torch.Generator()
    if seed != 0:
        rand_gen.manual_seed(seed)

    if wind:
        assert self.node_time is not None, "wind=True requires node_time"
        node_mask = (start_time <= self.node_time) & (self.node_time < end_time)
        all_node_ids = torch.where(node_mask)[0]
        # Shuffle the nodes that fall inside the window. The permutation must
        # be sized by `all_node_ids`; the previous code used `node_ids`, which
        # is not assigned yet at this point.
        sampled = torch.randperm(all_node_ids.size(0), generator=rand_gen, dtype=torch.long)
        all_node_ids = all_node_ids[sampled]

        # The edge time mask does not depend on the batch, so build it once.
        if self.edge_time is not None:
            edge_time = self.edge_time[self.edge_ids]
            edge_time_mask = (start_time <= edge_time) & (edge_time < end_time)
        else:
            edge_time_mask = None
    else:
        all_node_ids = torch.randperm(self.dst_size, generator=rand_gen, dtype=torch.long)
        edge_time_mask = None

    total = all_node_ids.size(0)
    for s in range(0, total, batch_size):
        node_ids = all_node_ids[s:min(s + batch_size, total)]

        # Mark the batch's nodes, then keep the edges whose destination
        # (row 1 of edge_index) is one of them.
        imp = torch.zeros(self.dst_size, dtype=torch.bool, device=node_ids.device)
        imp.index_fill_(0, node_ids, 1)
        edge_mask = imp[self.edge_index[1]]
        if edge_time_mask is not None:
            edge_mask &= edge_time_mask

        edge_ids = self.edge_ids[edge_mask]
        edge_index = self.edge_index[:, edge_mask]
        yield self._get_node_handle(node_ids, edge_ids, edge_index)
starrygl/loader/buffer.py
View file @
f4af3231
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
import
torch.distributed
as
dist
import
torch.distributed.rpc
as
rpc
from
torch.futures
import
Future
from
torch
import
Tensor
from
torch
import
Tensor
from
typing
import
*
from
typing
import
*
from
starrygl.parallel
import
all_gather_remote_objects
from
.route
import
Route
from
threading
import
Lock
class
TensorBuffers
(
nn
.
Module
):
def
__init__
(
self
,
src_size
:
int
,
channels
:
Tuple
[
Union
[
int
,
Tuple
[
int
]]])
->
None
:
class
TensorBuffer
(
nn
.
Module
):
def
__init__
(
self
,
channels
:
int
,
num_nodes
:
int
,
route
:
Route
,
)
->
None
:
super
()
.
__init__
()
super
()
.
__init__
()
self
.
src_size
=
src_size
self
.
channels
=
channels
self
.
num_layers
=
len
(
channels
)
self
.
num_nodes
=
num_nodes
for
i
,
s
in
enumerate
(
channels
):
self
.
route
=
route
s
=
(
s
,)
if
isinstance
(
s
,
int
)
else
s
self
.
register_buffer
(
f
"data_{i}"
,
torch
.
zeros
(
src_size
,
*
s
),
persistent
=
False
)
self
.
local_lock
=
Lock
()
self
.
register_buffer
(
"_data"
,
torch
.
zeros
(
num_nodes
,
channels
),
persistent
=
False
)
self
.
rrefs
=
all_gather_remote_objects
(
self
)
@property
def
data
(
self
)
->
Tensor
:
return
self
.
get_buffer
(
"_data"
)
def
get
(
self
,
i
:
int
,
index
:
Optional
[
Tensor
])
->
Tensor
:
@property
i
=
self
.
_idx
(
i
)
def
device
(
self
)
->
torch
.
device
:
return
self
.
data
.
device
def
local_get
(
self
,
index
:
Optional
[
Tensor
]
=
None
,
lock
:
bool
=
True
)
->
Tensor
:
if
lock
:
with
self
.
local_lock
:
return
self
.
local_get
(
index
,
lock
=
False
)
if
index
is
None
:
if
index
is
None
:
return
self
.
get_buffer
(
f
"data_{i}"
)
return
self
.
data
else
:
else
:
return
self
.
get_buffer
(
f
"data_{i}"
)
[
index
]
return
self
.
data
[
index
]
def
set
(
self
,
i
:
int
,
index
:
Optional
[
Tensor
],
value
:
Tensor
):
def
local_set
(
self
,
value
:
Tensor
,
index
:
Optional
[
Tensor
]
=
None
,
lock
:
bool
=
True
):
i
=
self
.
_idx
(
i
)
if
lock
:
with
self
.
local_lock
:
return
self
.
local_set
(
value
,
index
,
lock
=
False
)
if
index
is
None
:
if
index
is
None
:
self
.
get_buffer
(
f
"data_{i}"
)[:,
...
]
=
value
self
.
data
.
copy_
(
value
)
else
:
else
:
self
.
get_buffer
(
f
"data_{i}"
)[
index
]
=
value
# value = value.to(self.device)
self
.
data
[
index
]
=
value
def
add
(
self
,
i
:
int
,
index
:
Optional
[
Tensor
],
value
:
Tensor
):
def
local_add
(
self
,
value
:
Tensor
,
index
:
Optional
[
Tensor
]
=
None
,
lock
:
bool
=
True
):
i
=
self
.
_idx
(
i
)
if
lock
:
with
self
.
local_lock
:
return
self
.
local_add
(
value
,
index
,
lock
=
False
)
if
index
is
None
:
if
index
is
None
:
self
.
get_buffer
(
f
"data_{i}"
)[:,
...
]
+=
value
self
.
data
.
add_
(
value
)
else
:
else
:
self
.
get_buffer
(
f
"data_{i}"
)[
index
]
+=
value
# value = value.to(self.device)
self
.
data
[
index
]
+=
value
def
local_cls
(
self
,
index
:
Optional
[
Tensor
]
=
None
,
lock
:
bool
=
True
):
if
lock
:
with
self
.
local_lock
:
return
self
.
local_cls
(
index
,
lock
=
False
)
if
index
is
None
:
self
.
data
.
zero_
()
else
:
self
.
data
[
index
]
=
0
def
remote_get
(
self
,
dst
:
int
,
index
:
Tensor
,
lock
:
bool
=
True
):
return
TensorBuffer
.
_remote_call
(
TensorBuffer
.
local_get
,
self
.
rrefs
[
dst
],
index
=
index
,
lock
=
lock
)
def
remote_set
(
self
,
dst
:
int
,
value
:
Tensor
,
index
:
Tensor
,
lock
:
bool
=
True
):
return
TensorBuffer
.
_remote_call
(
TensorBuffer
.
local_set
,
self
.
rrefs
[
dst
],
value
,
index
=
index
,
lock
=
lock
)
def
remote_add
(
self
,
dst
:
int
,
value
:
Tensor
,
index
:
Tensor
,
lock
:
bool
=
True
):
return
TensorBuffer
.
_remote_call
(
TensorBuffer
.
local_add
,
self
.
rrefs
[
dst
],
value
,
index
=
index
,
lock
=
lock
)
def
all_remote_get
(
self
,
index
:
Tensor
,
lock
:
bool
=
True
):
def
cb0
(
idx
):
def
f
(
x
:
torch
.
futures
.
Future
[
Tensor
]):
return
x
.
value
(),
idx
return
f
def
cb1
(
buf
):
def
f
(
xs
:
torch
.
futures
.
Future
[
List
[
torch
.
futures
.
Future
]])
->
Tensor
:
for
x
in
xs
.
value
():
dat
,
idx
=
x
.
value
()
# print(dat.size(), idx.size())
buf
[
idx
]
+=
dat
return
buf
return
f
futs
=
[]
for
i
,
(
idx
,
remote_idx
)
in
enumerate
(
self
.
route
.
parts_iter
(
index
)):
futs
.
append
(
self
.
remote_get
(
i
,
remote_idx
,
lock
=
lock
)
.
then
(
cb0
(
idx
)))
futs
=
torch
.
futures
.
collect_all
(
futs
)
buf
=
torch
.
zeros
(
index
.
size
(
0
),
self
.
channels
,
dtype
=
self
.
data
.
dtype
,
device
=
self
.
data
.
device
)
return
futs
.
then
(
cb1
(
buf
))
def
all_remote_set
(
self
,
value
:
Tensor
,
index
:
Tensor
,
lock
:
bool
=
True
):
futs
=
[]
for
i
,
(
idx
,
remote_idx
)
in
enumerate
(
self
.
route
.
parts_iter
(
index
)):
futs
.
append
(
self
.
remote_set
(
i
,
value
[
idx
],
remote_idx
,
lock
=
lock
))
return
torch
.
futures
.
collect_all
(
futs
)
def
all_remote_add
(
self
,
value
:
Tensor
,
index
:
Tensor
,
lock
:
bool
=
True
):
futs
=
[]
for
i
,
(
idx
,
remote_idx
)
in
enumerate
(
self
.
route
.
parts_iter
(
index
)):
futs
.
append
(
self
.
remote_add
(
i
,
value
[
idx
],
remote_idx
,
lock
=
lock
))
return
torch
.
futures
.
collect_all
(
futs
)
def
broadcast
(
self
,
barrier
:
bool
=
True
):
if
barrier
:
dist
.
barrier
()
index
=
torch
.
arange
(
self
.
num_nodes
,
dtype
=
torch
.
long
,
device
=
self
.
data
.
device
)
data
=
self
.
all_remote_get
(
index
,
lock
=
True
)
.
wait
()
self
.
local_set
(
data
,
lock
=
True
)
if
barrier
:
dist
.
barrier
()
# def remote_get(self, dst: int, i: int, index: Optional[Tensor] = None, global_index: bool = False, async_op: bool = False) -> Union[Tensor, Future]:
# return TensorBuffer._remote_call(async_op, TensorBuffer.local_get, self.rrefs[dst], i, index, global_index = global_index)
# def remote_set(self, dst: int, i: int, value: Tensor, index: Optional[Tensor], global_index: bool = False, async_op: bool = False) -> Optional[Future]:
# return TensorBuffer._remote_call(async_op, TensorBuffer.local_set, self.rrefs[dst], i, value, index, global_index = global_index)
# def remote_add(self, dst: int, i: int, value: Tensor, index: Optional[Tensor] = None, global_index: bool = False, async_op: bool = False) -> Optional[Future]:
# return TensorBuffer._remote_call(async_op, TensorBuffer.local_add, self.rrefs[dst], i, value, index, global_index = global_index)
# def async_scatter_fw_set(self, i: int, value: Tensor, index: Optional[Tensor] = None) -> Tuple[Future]:
# futures: List[Future] = []
# for dst in range(self.world_size):
# val, ind = self.router.fw_value_index(dst, value, index)
# futures.append(self.remote_set(dst, i, val, ind, global_index=True, async_op=True))
# return tuple(futures)
# def async_scatter_fw_add(self, i: int, value: Tensor, index: Optional[Tensor] = None) -> Tuple[Future]:
# futures: List[Future] = []
# for dst in range(self.world_size):
# val, ind = self.router.fw_value_index(dst, value, index)
# futures.append(self.remote_add(dst, i, val, ind, global_index=True, async_op=True))
# return tuple(futures)
# def async_scatter_bw_set(self, i: int, value: Tensor, index: Optional[Tensor] = None) -> Tuple[Future]:
# futures: List[Future] = []
# for dst in range(self.world_size):
# val, ind = self.router.bw_value_index(dst, value, index)
# futures.append(self.remote_set(dst, i, val, ind, global_index=True, async_op=True))
# return tuple(futures)
# def async_scatter_bw_add(self, i: int, value: Tensor, index: Optional[Tensor] = None) -> Tuple[Future]:
# futures: List[Future] = []
# for dst in range(self.world_size):
# val, ind = self.router.bw_value_index(dst, value, index)
# futures.append(self.remote_add(dst, i, val, ind, global_index=True, async_op=True))
# return tuple(futures)
# def _idx_data(self, i: int) -> Tuple[int, Tensor]:
# assert -self.num_layers < i and i < self.num_layers
# i = (self.num_layers + i) % self.num_layers
# return i, self.get_buffer(f"data{i}")
def
zero_grad
(
self
):
@staticmethod
for
name
,
grad
in
self
.
named_buffers
(
"data_"
,
recurse
=
False
):
def
_remote_call
(
method
,
rref
:
rpc
.
RRef
,
*
args
,
**
kwargs
):
grad
.
zero_
()
args
=
(
method
,
rref
)
+
args
return
rpc
.
rpc_async
(
rref
.
owner
(),
TensorBuffer
.
_method_call
,
args
=
args
,
kwargs
=
kwargs
)
def
_idx
(
self
,
i
:
int
)
->
int
:
@staticmethod
assert
-
self
.
num_layers
<
i
and
i
<
self
.
num_layers
def
_method_call
(
method
,
rref
:
rpc
.
RRef
,
*
args
,
**
kwargs
):
return
(
self
.
num_layers
+
i
)
%
self
.
num_layers
self
:
TensorBuffer
=
rref
.
local_value
()
index
=
kwargs
[
"index"
]
kwargs
[
"index"
]
=
self
.
route
.
to_local_ids
(
index
)
return
method
(
self
,
*
args
,
**
kwargs
)
\ No newline at end of file
starrygl/loader/context.py
0 → 100644
View file @
f4af3231
import
torch
from
contextlib
import
contextmanager
from
torch
import
Tensor
from
typing
import
*
class RouteContext:
    """Collects futures and waits for all of them when the context exits.

    Used as a context manager: futures registered with ``add_futures`` inside
    the ``with`` body are waited on by ``synchronize`` on a clean exit.
    """

    def __init__(self) -> None:
        # Futures registered since the last successful synchronize().
        self._futs: List[torch.futures.Future] = []

    def synchronize(self):
        """Wait on every registered future, in order, then forget them.

        If a ``wait()`` raises, the exception propagates and the list of
        registered futures is left untouched.
        """
        for fut in self._futs:
            fut.wait()
        self._futs = []

    def add_futures(self, *futs):
        """Register one or more ``torch.futures.Future`` objects.

        Raises:
            AssertionError: if an argument is not a ``torch.futures.Future``.
        """
        for fut in futs:
            assert isinstance(fut, torch.futures.Future)
            self._futs.append(fut)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Synchronize on a clean exit; let an in-flight exception propagate.

        Returning False tells the interpreter to re-raise the original
        exception object unchanged. The previous ``raise exc_type(exc_value)``
        built a new exception, which lost the original traceback and arguments
        and failed with TypeError for exception classes whose constructor does
        not accept a single positional argument.
        """
        if exc_type is not None:
            # Same as before: do not wait on pending futures when the body failed.
            return False
        self.synchronize()
        return False
starrygl/loader/route.py
0 → 100644
View file @
f4af3231
import
torch
import
torch.nn
as
nn
import
torch.distributed
as
dist
import
torch.distributed.rpc
as
rpc
from
torch
import
Tensor
from
typing
import
*
from
starrygl.parallel
import
all_gather_remote_objects
from
.utils
import
init_local_edge_index
class Route(nn.Module):
    """Lookup tables between node ids in ``src_ids`` and their positions.

    ``src_ids`` holds node ids; the first ``dst_size`` entries are exposed as
    ``dst_ids`` and the remainder as ``ext_ids``. A "local id" is a position in
    ``src_ids``. Two non-persistent buffers are built at construction:

    * ``nids_mapper``: indexed by node id, gives the local id, or
      ``_UNMAPPED`` for ids that do not occur in ``src_ids``.
    * ``part_mapper``: indexed by local id, gives the index ``i`` of the
      entry returned by ``all_gather_remote_objects(self.dst_ids)`` whose ids
      contain that node (presumably the owning rank; confirm against
      ``starrygl.parallel``).
    """

    # Sentinel stored in nids_mapper for ids absent from src_ids
    # (equals 2**63 - 1).
    _UNMAPPED = (2 ** 62 - 1) * 2 + 1

    def __init__(self,
                 src_ids: Tensor,
                 dst_size: int,
                 ) -> None:
        super().__init__()
        self.register_buffer("_src_ids", src_ids, persistent=False)
        self.dst_size = dst_size

        # part_mapper is derived from nids_mapper, so the order matters.
        self._init_nids_mapper()
        self._init_part_mapper()

    @staticmethod
    def from_edge_index(dst_ids: Tensor, edge_index: Tensor):
        """Build a Route from ``dst_ids`` and an edge index.

        Returns:
            ``(route, local_edge_index)`` where both ``src_ids`` and
            ``local_edge_index`` come from ``init_local_edge_index``
            (NOTE(review): that helper is not visible here; presumably it
            relabels the edges with local ids).
        """
        src_ids, local_edge_index = init_local_edge_index(dst_ids, edge_index)
        return Route(src_ids, dst_ids.size(0)), local_edge_index

    @property
    def src_ids(self) -> Tensor:
        return self.get_buffer("_src_ids")

    @property
    def src_size(self) -> int:
        return self.src_ids.size(0)

    @property
    def dst_ids(self) -> Tensor:
        return self.src_ids[:self.dst_size]

    @property
    def ext_ids(self) -> Tensor:
        return self.src_ids[self.dst_size:]

    @property
    def ext_size(self) -> int:
        return self.src_size - self.dst_size

    def parts_iter(self, local_ids: Tensor) -> Iterator[Tuple[Tensor, Tensor]]:
        """Split ``local_ids`` by partition, one item per rank.

        Yields, for each ``i`` in ``range(dist.get_world_size())``:
            ``(positions, glob_ids)`` where ``positions`` are the indices into
            ``local_ids`` whose ``part_mapper`` entry equals ``i`` and
            ``glob_ids`` are the corresponding entries of ``src_ids``.
        """
        world_size = dist.get_world_size()
        part_mapper = self.part_mapper[local_ids]
        src_ids = self.src_ids
        for i in range(world_size):
            # Positions inside `local_ids`, so callers can slice a value
            # tensor that is aligned with `local_ids`.
            part_ids = torch.where(part_mapper == i)[0]
            # Translate positions back to local ids before the src_ids lookup.
            # Indexing src_ids with the positions directly is only correct
            # when local_ids happens to be arange(n).
            glob_ids = src_ids[local_ids[part_ids]]
            yield part_ids, glob_ids

    def to_local_ids(self, ids: Tensor) -> Tensor:
        """Map node ids to local ids via ``nids_mapper``.

        Ids below ``nids_mapper.size(0)`` that are absent from ``src_ids``
        map to ``_UNMAPPED``.
        """
        return self.nids_mapper[ids]

    def _init_nids_mapper(self):
        """Build ``nids_mapper`` (node id -> position in ``src_ids``)."""
        src_ids = self.src_ids
        num_nodes: int = src_ids.max().item() + 1
        device: torch.device = src_ids.device

        mapper = torch.full((num_nodes,), self._UNMAPPED, dtype=torch.long, device=device)
        mapper[src_ids] = torch.arange(src_ids.size(0), dtype=torch.long, device=device)
        self.register_buffer("nids_mapper", mapper, persistent=False)

    def _init_part_mapper(self):
        """Build ``part_mapper`` (local id -> index of the gathered part).

        Raises:
            AssertionError: if some local id is not covered by any gathered
                ``dst_ids``.
        """
        device: torch.device = self.src_ids.device
        nids_mapper = self.get_buffer("nids_mapper")

        mapper = torch.full((self.src_size,), -1, dtype=torch.int32, device=device)
        for i, dst_ids in enumerate(all_gather_remote_objects(self.dst_ids)):
            dst_ids: Tensor = dst_ids.to_here().to(device)
            # Ids at or beyond the table size cannot occur in src_ids.
            dst_ids = dst_ids[dst_ids < nids_mapper.size(0)]

            dst_local_inds = nids_mapper[dst_ids]
            dst_local_mask = dst_local_inds != self._UNMAPPED
            dst_local_inds = dst_local_inds[dst_local_mask]

            mapper[dst_local_inds] = i
        assert (mapper >= 0).all()
        self.register_buffer("part_mapper", mapper, persistent=False)
# class RouteTable(nn.Module):
# def __init__(self,
# src_ids: Tensor,
# dst_size: int,
# ) -> None:
# super().__init__()
# self.register_buffer("src_ids", src_ids)
# self.src_size: int = src_ids.size(0)
# self.dst_size = dst_size
# assert self.src_size >= self.dst_size
# self._init_mapper()
# rank, world_size = rank_world_size()
# rrefs = all_gather_remote_objects(self)
# gather_futures: List[torch.futures.Future] = []
# for i in range(world_size):
# rref = rrefs[i]
# fut = rpc.rpc_async(rref.owner(), RouteTable._get_dst_ids, args=(rref,))
# gather_futures.append(fut)
# max_src_ids: int = src_ids.max().item()
# smp = torch.empty(max_src_ids + 1, dtype=torch.long, device=src_ids.device).fill_((2**62-1)*2+1)
# smp[src_ids] = torch.arange(src_ids.size(0), dtype=smp.dtype, device=smp.device)
# self.fw_masker = RouteMasker(self.dst_size, world_size)
# self.bw_masker = RouteMasker(self.src_size, world_size)
# dist.barrier()
# scatter_futures: List[torch.futures.Future] = []
# for i in range(world_size):
# fut = gather_futures[i]
# s_ids: Tensor = src_ids
# d_ids: Tensor = fut.wait()
# num_ids: int = max(s_ids.max().item(), d_ids.max().item()) + 1
# imp = torch.zeros(num_ids, dtype=torch.long, device=self._get_device())
# imp[s_ids] += 1
# imp[d_ids] += 1
# ind = torch.where(imp > 1)[0]
# imp.fill_((2**62-1)*2+1)
# imp[d_ids] = torch.arange(d_ids.size(0), dtype=imp.dtype, device=imp.device)
# s_ind = smp[ind]
# d_ind = imp[ind]
# rref = rrefs[i]
# fut = rpc.rpc_async(rref.owner(), RouteTable._set_fw_mask, args=(rref, rank, d_ind))
# scatter_futures.append(fut)
# bw_mask = torch.zeros(self.src_size, dtype=torch.bool).index_fill_(0, s_ind, 1)
# self.bw_masker.set_mask(i, bw_mask)
# for fut in scatter_futures:
# fut.wait()
# dist.barrier()
# # def fw_index(self, dst: int, index: Tensor) -> Tensor:
# # mask = self.fw_masker.select(dst, index)
# # return self.get_global_index(index[mask])
# # def bw_index(self, dst: int, index: Tensor) -> Tensor:
# # mask = self.bw_masker.select(dst, index)
# # return self.get_global_index(index[mask])
# def fw_value_index(self, dst: int, value: Tensor, index: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]:
# if index is None:
# assert value.size(0) == self.dst_size
# mask = self.fw_masker.select(dst)
# return value[mask], self.get_buffer("src_ids")[:self.dst_size][mask]
# else:
# assert value.size(0) == index.size(0)
# mask = self.fw_masker.select(dst, index)
# value, index = value[mask], index[mask]
# return value, self.get_global_index(index)
# def bw_value_index(self, dst: int, value: Tensor, index: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]:
# if index is None:
# assert value.size(0) == self.src_size
# mask = self.bw_masker.select(dst)
# return value[mask], self.get_buffer("src_ids")[mask]
# else:
# assert value.size(0) == index.size(0)
# mask = self.bw_masker.select(dst, index)
# value, index = value[mask], index[mask]
# return value, self.get_global_index(index)
# def get_global_index(self, index: Tensor) -> Tensor:
# return self.get_buffer("src_ids")[index]
# def get_local_index(self, index: Tensor) -> Tensor:
# return self.get_buffer("mapper")[index]
# @staticmethod
# def _get_dst_ids(rref: rpc.RRef):
# self: RouteTable = rref.local_value()
# src_ids = self.get_buffer("src_ids")
# return src_ids[:self.dst_size]
# @staticmethod
# def _set_fw_mask(rref: rpc.RRef, dst: int, fw_ind: Tensor):
# self: RouteTable = rref.local_value()
# fw_mask = torch.zeros(self.dst_size, dtype=torch.bool).index_fill_(0, fw_ind, 1)
# self.fw_masker.set_mask(dst, fw_mask)
# def _get_device(self):
# return self.get_buffer("src_ids").device
# def _init_mapper(self):
# src_ids = self.get_buffer("src_ids")
# num_nodes: int = src_ids.max().item() + 1
# mapper = torch.empty(num_nodes, dtype=torch.long, device=src_ids.device).fill_((2**62-1)*2+1)
# mapper[src_ids] = torch.arange(src_ids.size(0), dtype=torch.long)
# self.register_buffer("mapper", mapper)
# class RouteMasker(nn.Module):
# def __init__(self,
# num_nodes: int,
# world_size: int,
# ) -> None:
# super().__init__()
# m = (world_size + 7) // 8
# self.num_nodes = num_nodes
# self.world_size = world_size
# self.register_buffer("data", torch.zeros(m, num_nodes, dtype=torch.uint8))
# def forward(self, i: int, index: Optional[Tensor] = None) -> Tensor:
# return self.select(i, index)
# def select(self, i: int, index: Optional[Tensor] = None) -> Tensor:
# i, data = self._idx_data(i)
# k, r = i // 8, i % 8
# if index is None:
# mask = data[k].bitwise_right_shift(r).bitwise_and_(1)
# else:
# mask = data[k][index].bitwise_right_shift_(r).bitwise_and_(1)
# return mask.type(dtype=torch.bool)
# def set_mask(self, i: int, mask: Tensor) -> Tensor:
# assert mask.size(0) == self.num_nodes
# i, data = self._idx_data(i)
# k, r = i // 8, i % 8
# data[k] &= ~(1<<r)
# data[k] |= mask.type(torch.uint8).bitwise_left_shift_(r)
# def _idx_data(self, i: int) -> Tuple[int, Tensor]:
# assert -self.world_size < i and i < self.world_size
# i = (i + self.world_size) % self.world_size
# return i, self.get_buffer("data")
\ No newline at end of file
starrygl/loader/utils.py
View file @
f4af3231
...
@@ -50,11 +50,6 @@ def calc_max_ids(*ids: Tensor) -> int:
...
@@ -50,11 +50,6 @@ def calc_max_ids(*ids: Tensor) -> int:
x
=
[
t
.
max
()
.
item
()
if
t
.
numel
()
>
0
else
0
for
t
in
ids
]
x
=
[
t
.
max
()
.
item
()
if
t
.
numel
()
>
0
else
0
for
t
in
ids
]
return
max
(
*
x
)
return
max
(
*
x
)
def
collect_feat0
(
src_ids
:
Tensor
,
dst_ids
:
Tensor
,
feat0
:
Tensor
):
device
=
get_compute_device
()
route
=
Route
(
dst_ids
.
to
(
device
),
src_ids
.
to
(
device
))
return
route
.
forward_a2a
(
feat0
.
to
(
device
))[
0
]
.
to
(
feat0
.
device
)
def
local_partition_fn
(
dst_size
:
Tensor
,
edge_index
:
Tensor
,
num_parts
:
int
)
->
Tensor
:
def
local_partition_fn
(
dst_size
:
Tensor
,
edge_index
:
Tensor
,
num_parts
:
int
)
->
Tensor
:
edge_index
=
edge_index
[:,
edge_index
[
0
]
<
dst_size
]
edge_index
=
edge_index
[:,
edge_index
[
0
]
<
dst_size
]
return
metis_partition
(
edge_index
,
dst_size
,
num_parts
)[
0
]
return
metis_partition
(
edge_index
,
dst_size
,
num_parts
)[
0
]
\ No newline at end of file
starrygl/nn/base_model.py
View file @
f4af3231
import
torch
#
import torch
import
torch.nn
as
nn
#
import torch.nn as nn
import
torch.distributed
as
dist
#
import torch.distributed as dist
from
torch
import
Tensor
#
from torch import Tensor
from
typing
import
*
#
from typing import *
from
starrygl.loader
import
BatchHandle
#
from starrygl.loader import BatchHandle
class
BaseLayer
(
nn
.
Module
):
#
class BaseLayer(nn.Module):
def
__init__
(
self
)
->
None
:
#
def __init__(self) -> None:
super
()
.
__init__
()
#
super().__init__()
def
forward
(
self
,
x
:
Tensor
,
edge_index
:
Tensor
,
edge_attr
:
Optional
[
Tensor
]
=
None
)
->
Tensor
:
#
def forward(self, x: Tensor, edge_index: Tensor, edge_attr: Optional[Tensor] = None) -> Tensor:
return
x
#
return x
def
update_forward
(
self
,
handle
:
BatchHandle
,
edge_index
:
Tensor
,
edge_attr
:
Optional
[
Tensor
]
=
None
):
#
def update_forward(self, handle: BatchHandle, edge_index: Tensor, edge_attr: Optional[Tensor] = None):
x
=
handle
.
fetch_feat
()
#
x = handle.fetch_feat()
with
torch
.
no_grad
():
#
with torch.no_grad():
x
=
self
.
forward
(
x
,
edge_index
,
edge_attr
)
#
x = self.forward(x, edge_index, edge_attr)
handle
.
update_feat
(
x
)
#
handle.update_feat(x)
def
block_backward
(
self
,
handle
:
BatchHandle
,
edge_index
:
Tensor
,
edge_attr
:
Optional
[
Tensor
]
=
None
):
#
def block_backward(self, handle: BatchHandle, edge_index: Tensor, edge_attr: Optional[Tensor] = None):
x
=
handle
.
fetch_feat
()
.
requires_grad_
()
#
x = handle.fetch_feat().requires_grad_()
g
=
handle
.
fetch_grad
()
#
g = handle.fetch_grad()
self
.
forward
(
x
,
edge_index
,
edge_attr
)
.
backward
(
g
)
#
self.forward(x, edge_index, edge_attr).backward(g)
handle
.
accumulate_grad
(
x
.
grad
)
#
handle.accumulate_grad(x.grad)
x
.
grad
=
None
#
x.grad = None
def
all_reduce_grad
(
self
):
#
def all_reduce_grad(self):
for
p
in
self
.
parameters
():
#
for p in self.parameters():
if
p
.
grad
is
not
None
:
#
if p.grad is not None:
dist
.
all_reduce
(
p
.
grad
,
op
=
dist
.
ReduceOp
.
SUM
)
#
dist.all_reduce(p.grad, op=dist.ReduceOp.SUM)
class
BaseModel
(
nn
.
Module
):
#
class BaseModel(nn.Module):
def
__init__
(
self
,
#
def __init__(self,
num_features
:
int
,
#
num_features: int,
layers
:
List
[
int
],
#
layers: List[int],
prev_layer
:
bool
=
False
,
#
prev_layer: bool = False,
post_layer
:
bool
=
False
,
#
post_layer: bool = False,
)
->
None
:
#
) -> None:
super
()
.
__init__
()
#
super().__init__()
def
init_prev_layer
(
self
)
->
Tensor
:
#
def init_prev_layer(self) -> Tensor:
pass
#
pass
def
init_post_layer
(
self
)
->
Tensor
:
#
def init_post_layer(self) -> Tensor:
pass
#
pass
def
init_conv_layer
(
self
)
->
Tensor
:
# def init_conv_layer(self) -> Tensor:
pass
# pass
\ No newline at end of file
\ No newline at end of file
starrygl/parallel/__init__.py
View file @
f4af3231
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
import
torch.distributed
as
dist
import
torch.distributed
as
dist
import
torch.distributed.rpc
as
rpc
import
os
import
os
from
typing
import
*
from
typing
import
*
...
@@ -15,29 +16,36 @@ def convert_parallel_model(
...
@@ -15,29 +16,36 @@ def convert_parallel_model(
net
=
SyncBatchNorm
.
convert_sync_batchnorm
(
net
)
net
=
SyncBatchNorm
.
convert_sync_batchnorm
(
net
)
net
=
nn
.
parallel
.
DistributedDataParallel
(
net
,
net
=
nn
.
parallel
.
DistributedDataParallel
(
net
,
find_unused_parameters
=
find_unused_parameters
,
find_unused_parameters
=
find_unused_parameters
,
# broadcast_buffers=False,
)
)
# for name, buffer in net.named_buffers():
# if name.endswith("last_embd"):
# continue
# if name.endswith("last_w"):
# continue
# dist.broadcast(buffer, src=0)
return
net
return
net
def
init_process_group
(
backend
:
str
=
"gloo"
)
->
torch
.
device
:
def
init_process_group
(
backend
:
str
=
"gloo"
)
->
torch
.
device
:
rank
=
int
(
os
.
getenv
(
"RANK"
)
or
os
.
getenv
(
"OMPI_COMM_WORLD_RANK"
))
world_size
=
int
(
os
.
getenv
(
"WORLD_SIZE"
)
or
os
.
getenv
(
"OMPI_COMM_WORLD_SIZE"
))
dist
.
init_process_group
(
backend
=
backend
,
init_method
=
ccl_init_method
(),
rank
=
rank
,
world_size
=
world_size
,
)
rpc_backend_options
=
rpc
.
TensorPipeRpcBackendOptions
()
rpc_backend_options
.
init_method
=
rpc_init_method
()
for
i
in
range
(
world_size
):
rpc_backend_options
.
set_device_map
(
f
"worker{i}"
,
{
rank
:
i
})
rpc
.
init_rpc
(
name
=
f
"worker{rank}"
,
rank
=
rank
,
world_size
=
world_size
,
rpc_backend_options
=
rpc_backend_options
,
)
local_rank
=
os
.
getenv
(
"LOCAL_RANK"
)
or
os
.
getenv
(
"OMPI_COMM_WORLD_LOCAL_RANK"
)
local_rank
=
os
.
getenv
(
"LOCAL_RANK"
)
or
os
.
getenv
(
"OMPI_COMM_WORLD_LOCAL_RANK"
)
if
local_rank
is
not
None
:
if
local_rank
is
not
None
:
local_rank
=
int
(
local_rank
)
local_rank
=
int
(
local_rank
)
dist
.
init_process_group
(
backend
)
if
backend
==
"nccl"
or
backend
==
"mpi"
:
if
backend
==
"nccl"
or
backend
==
"mpi"
:
if
local_rank
is
None
:
device
=
torch
.
device
(
f
"cuda:{local_rank or rank}"
)
device
=
torch
.
device
(
f
"cuda:{dist.get_rank()}"
)
else
:
device
=
torch
.
device
(
f
"cuda:{local_rank}"
)
torch
.
cuda
.
set_device
(
device
)
torch
.
cuda
.
set_device
(
device
)
else
:
else
:
device
=
torch
.
device
(
"cpu"
)
device
=
torch
.
device
(
"cpu"
)
...
@@ -46,7 +54,48 @@ def init_process_group(backend: str = "gloo") -> torch.device:
...
@@ -46,7 +54,48 @@ def init_process_group(backend: str = "gloo") -> torch.device:
_COMPUTE_DEVICE
=
device
_COMPUTE_DEVICE
=
device
return
device
return
device
def rank_world_size() -> Tuple[int, int]:
    """Return ``(rank, world_size)`` as reported by ``torch.distributed``."""
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    return rank, world_size
def get_worker_info(rank: Optional[int] = None) -> rpc.WorkerInfo:
    """Return the RPC ``WorkerInfo`` of the worker named ``worker{rank}``.

    Args:
        rank: rank of the worker to look up. When ``None``, the value of
            ``dist.get_rank()`` is used instead.

    Returns:
        The result of ``rpc.get_worker_info`` for that worker name.
    """
    if rank is None:
        rank = dist.get_rank()
    return rpc.get_worker_info(f"worker{rank}")
_COMPUTE_DEVICE
=
torch
.
device
(
"cpu"
)
_COMPUTE_DEVICE
=
torch
.
device
(
"cpu"
)
def
get_compute_device
()
->
torch
.
device
:
def
get_compute_device
()
->
torch
.
device
:
global
_COMPUTE_DEVICE
global
_COMPUTE_DEVICE
return
_COMPUTE_DEVICE
return
_COMPUTE_DEVICE
# Module-level slot read by ``_remote_object``; ``None`` when nothing is stored.
_TEMP_AG_REMOTE_OBJECT = None


def _remote_object():
    """Return the current value of the module-level ``_TEMP_AG_REMOTE_OBJECT``."""
    # Reading a module global needs no ``global`` declaration.
    return _TEMP_AG_REMOTE_OBJECT
def all_gather_remote_objects(obj: Any) -> List[rpc.RRef]:
    """Collect one ``RRef`` from every worker, ordered by worker index.

    ``obj`` is wrapped in an ``rpc.RRef`` and stored in the module-level
    ``_TEMP_AG_REMOTE_OBJECT`` slot. Then ``_remote_object`` is invoked
    asynchronously on each of the ``dist.get_world_size()`` workers and the
    results are collected.

    Two ``dist.barrier()`` calls bracket the exchange: the first runs after
    the local slot is filled and before any request is sent, the second
    runs after all replies arrived and before the slot is cleared.

    Args:
        obj: the local object to wrap in an ``RRef``.

    Returns:
        A list with one entry per worker, entry ``i`` being the value
        returned by ``_remote_object`` on worker ``i``.
    """
    global _TEMP_AG_REMOTE_OBJECT
    _TEMP_AG_REMOTE_OBJECT = rpc.RRef(obj)
    dist.barrier()

    pending: List[torch.futures.Future] = [
        rpc.rpc_async(get_worker_info(worker), _remote_object)
        for worker in range(dist.get_world_size())
    ]
    gathered: List[rpc.RRef] = [fut.wait() for fut in pending]

    dist.barrier()
    _TEMP_AG_REMOTE_OBJECT = None
    return gathered
def ccl_init_method() -> str:
    """Return the ``tcp://`` URL built from ``MASTER_ADDR`` and ``MASTER_PORT``.

    Both environment variables are required (``KeyError`` if missing) and
    ``MASTER_PORT`` must parse as an integer (``ValueError`` otherwise).
    """
    env = os.environ
    master_addr = env["MASTER_ADDR"]
    master_port = int(env["MASTER_PORT"])
    return f"tcp://{master_addr}:{master_port}"
def rpc_init_method() -> str:
    """Return the ``tcp://`` URL for ``MASTER_ADDR`` on port ``MASTER_PORT + 1``.

    Both environment variables are required (``KeyError`` if missing) and
    ``MASTER_PORT`` must parse as an integer (``ValueError`` otherwise).
    """
    env = os.environ
    master_addr = env["MASTER_ADDR"]
    master_port = int(env["MASTER_PORT"])
    return f"tcp://{master_addr}:{master_port+1}"
\ No newline at end of file
starrygl/parallel/degree.py
View file @
f4af3231
import
torch
#
import torch
from
torch
import
Tensor
#
from torch import Tensor
from
typing
import
*
#
from typing import *
from
torch_scatter
import
scatter_sum
#
from torch_scatter import scatter_sum
from
starrygl.core.route
import
Route
#
from starrygl.core.route import Route
def
compute_in_degree
(
edge_index
:
Tensor
,
route
:
Route
)
->
Tensor
:
#
def compute_in_degree(edge_index: Tensor, route: Route) -> Tensor:
dst_size
=
route
.
src_size
#
dst_size = route.src_size
x
=
torch
.
ones
(
edge_index
.
size
(
1
),
dtype
=
torch
.
long
,
device
=
edge_index
.
device
)
#
x = torch.ones(edge_index.size(1), dtype=torch.long, device=edge_index.device)
in_deg
=
scatter_sum
(
x
,
edge_index
[
1
],
dim
=
0
,
dim_size
=
dst_size
)
#
in_deg = scatter_sum(x, edge_index[1], dim=0, dim_size=dst_size)
in_deg
,
_
=
route
.
forward_a2a
(
in_deg
)
#
in_deg, _ = route.forward_a2a(in_deg)
return
in_deg
#
return in_deg
def
compute_out_degree
(
edge_index
:
Tensor
,
route
:
Route
)
->
Tensor
:
#
def compute_out_degree(edge_index: Tensor, route: Route) -> Tensor:
src_size
=
route
.
dst_size
#
src_size = route.dst_size
x
=
torch
.
ones
(
edge_index
.
size
(
1
),
dtype
=
torch
.
long
,
device
=
edge_index
.
device
)
#
x = torch.ones(edge_index.size(1), dtype=torch.long, device=edge_index.device)
out_deg
=
scatter_sum
(
x
,
edge_index
[
0
],
dim
=
0
,
dim_size
=
src_size
)
#
out_deg = scatter_sum(x, edge_index[0], dim=0, dim_size=src_size)
out_deg
,
_
=
route
.
backward_a2a
(
out_deg
)
#
out_deg, _ = route.backward_a2a(out_deg)
out_deg
,
_
=
route
.
forward_a2a
(
out_deg
)
#
out_deg, _ = route.forward_a2a(out_deg)
return
out_deg
#
return out_deg
def
compute_gcn_norm
(
edge_index
:
Tensor
,
route
:
Route
)
->
Tensor
:
#
def compute_gcn_norm(edge_index: Tensor, route: Route) -> Tensor:
in_deg
=
compute_in_degree
(
edge_index
,
route
)
#
in_deg = compute_in_degree(edge_index, route)
out_deg
=
compute_out_degree
(
edge_index
,
route
)
#
out_deg = compute_out_degree(edge_index, route)
a
=
in_deg
[
edge_index
[
0
]]
.
pow
(
-
0.5
)
#
a = in_deg[edge_index[0]].pow(-0.5)
b
=
out_deg
[
edge_index
[
0
]]
.
pow
(
-
0.5
)
#
b = out_deg[edge_index[0]].pow(-0.5)
x
=
a
*
b
#
x = a * b
x
[
x
.
isinf
()]
=
0.0
#
x[x.isinf()] = 0.0
x
[
x
.
isnan
()]
=
0.0
#
x[x.isnan()] = 0.0
return
x
#
return x
train.py
View file @
f4af3231
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
import
torch.nn.functional
as
F
import
torch.distributed
as
dist
import
torch.distributed.rpc
as
rpc
from
torch
import
Tensor
from
torch
import
Tensor
from
typing
import
*
from
typing
import
*
import
os
from
torch_sparse
import
SparseTensor
import
time
import
psutil
from
starrygl.nn
import
*
from
starrygl.graph
import
DistGraph
from
starrygl.loader
import
NodeLoader
,
NodeHandle
,
TensorBuffer
,
RouteContext
from
starrygl.parallel
import
init_process_group
,
convert_parallel_model
from
starrygl.parallel
import
init_process_group
,
convert_parallel_model
from
starrygl.parallel
import
compute_gcn_norm
,
SyncBatchNorm
,
with_nccl
from
starrygl.utils
import
partition_load
,
main_print
,
sync_print
from
starrygl.utils
import
train_epoch
,
eval_epoch
,
partition_load
,
main_print
class SimpleGNNConv(nn.Module):
    """One graph layer: linear projection, adjacency product, layer norm.

    Args:
        in_channels: input size passed to ``nn.Linear``.
        out_channels: output size of ``nn.Linear`` and the normalized
            shape passed to ``nn.LayerNorm``.
    """

    def __init__(self, in_channels: int, out_channels: int) -> None:
        super().__init__()
        self.linear = nn.Linear(in_channels, out_channels)
        self.norm = nn.LayerNorm(out_channels)

    def forward(self, x: Tensor, adj_t: SparseTensor) -> Tensor:
        """Return ``norm(adj_t @ linear(x))``."""
        projected = self.linear(x)
        aggregated = adj_t @ projected
        return self.norm(aggregated)
class
SimpleGNN
(
nn
.
Module
):
def
__init__
(
self
,
in_channels
:
int
,
hidden_channels
:
int
,
out_channels
:
int
,
)
->
None
:
super
()
.
__init__
()
self
.
conv1
=
SimpleGNNConv
(
in_channels
,
hidden_channels
)
self
.
conv2
=
SimpleGNNConv
(
hidden_channels
,
hidden_channels
)
self
.
fc_out
=
nn
.
Linear
(
hidden_channels
,
out_channels
)
def
forward
(
self
,
handle
:
NodeHandle
,
buffers
:
List
[
TensorBuffer
])
->
Tensor
:
futs
=
[
handle
.
get_src_feats
(
buffers
[
0
]),
handle
.
get_ext_feats
(
buffers
[
1
]),
]
with
RouteContext
()
as
ctx
:
x
=
futs
[
0
]
.
wait
()
x
=
self
.
conv1
(
x
,
handle
.
adj_t
)
x
,
f
=
handle
.
push_and_pull
(
x
,
futs
[
1
],
buffers
[
1
])
x
=
self
.
conv2
(
x
,
handle
.
adj_t
)
ctx
.
add_futures
(
f
)
# 等当前batch推理完成后需要等待所有futures完成
x
=
self
.
fc_out
(
x
)
return
x
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
# 启动分布式进程组,并分配计算设备
# 启动分布式进程组,并分配计算设备
device
=
init_process_group
(
backend
=
"nccl"
)
device
=
init_process_group
(
backend
=
"nccl"
)
# 加载数据集
# 加载数据集
pdata
=
partition_load
(
"./cora"
,
algo
=
"metis"
)
.
to
(
device
)
pdata
=
partition_load
(
"./cora"
,
algo
=
"metis"
)
g
=
DistGraph
(
ids
=
pdata
.
ids
,
edge_index
=
pdata
.
edge_index
)
loader
=
NodeLoader
(
pdata
.
ids
,
pdata
.
edge_index
,
device
)
# g.args["async_op"] = True
# 创建历史缓存
g
.
args
[
"sample_k"
]
=
20
hidden_size
=
64
buffers
:
List
[
TensorBuffer
]
=
[
g
.
edata
[
"gcn_norm"
]
=
compute_gcn_norm
(
g
)
TensorBuffer
(
pdata
.
num_features
,
loader
.
src_size
,
loader
.
route
),
g
.
ndata
[
"x"
]
=
pdata
.
x
TensorBuffer
(
hidden_size
,
loader
.
src_size
,
loader
.
route
),
g
.
ndata
[
"y"
]
=
pdata
.
y
]
# 定义GAT图神经网络模型
# 设置节点初始特征,并预同步到其它分区
net
=
GCN
(
buffers
[
0
]
.
data
[:
loader
.
dst_size
]
=
pdata
.
x
g
=
g
,
buffers
[
0
]
.
broadcast
()
layer_options
=
BasicLayerOptions
(
in_channels
=
pdata
.
num_features
,
# 创建模型
hidden_channels
=
64
,
net
=
SimpleGNN
(
pdata
.
num_features
,
hidden_size
,
pdata
.
num_classes
)
.
to
(
device
)
num_layers
=
2
,
out_channels
=
pdata
.
num_classes
,
norm
=
"batchnorm"
,
),
input_options
=
BasicInputOptions
(
straight_enabled
=
True
,
),
jk_options
=
BasicJKOptions
(
jk_mode
=
None
,
),
straight_options
=
BasicStraightOptions
(
enabled
=
True
,
),
)
.
to
(
device
)
# 转换成分布式并行版本
net
=
convert_parallel_model
(
net
)
net
=
convert_parallel_model
(
net
)
opt
=
torch
.
optim
.
Adam
(
net
.
parameters
(),
lr
=
1e-3
)
# 训练阶段
for
ep
in
range
(
1
,
100
+
1
):
epoch_loss
=
0.0
net
.
train
()
for
handle
in
loader
.
iter
(
128
):
fut_m
=
handle
.
get_dst_feats
(
pdata
.
train_mask
)
fut_y
=
handle
.
get_dst_feats
(
pdata
.
y
)
# 定义优化器
h
=
net
(
handle
,
buffers
)
opt
=
torch
.
optim
.
Adam
(
net
.
parameters
(),
lr
=
0.01
,
weight_decay
=
5e-4
)
train_mask
=
fut_m
.
wait
()
avg_mem
=
0.0
logits
=
h
[
train_mask
]
avg_dur
=
0.0
if
logits
.
size
(
0
)
>
0
:
avg_num
=
0
y
=
fut_y
.
wait
()[
train_mask
]
loss
=
nn
.
CrossEntropyLoss
()(
logits
,
y
)
# 开始训练
best_val_acc
=
best_test_acc
=
0
opt
.
zero_grad
()
for
ep
in
range
(
1
,
10
+
1
):
loss
.
backward
()
time_start
=
time
.
time
()
opt
.
step
()
train_loss
,
train_acc
=
train_epoch
(
net
,
opt
,
g
,
pdata
.
train_mask
)
epoch_loss
+=
loss
.
item
()
val_loss
,
val_acc
=
eval_epoch
(
net
,
g
,
pdata
.
val_mask
)
main_print
(
ep
,
epoch_loss
)
test_loss
,
test_acc
=
eval_epoch
(
net
,
g
,
pdata
.
test_mask
)
rpc
.
shutdown
()
# val_loss, val_acc = train_loss, train_acc
# test_loss, test_acc = train_loss, train_acc
if
val_acc
>
best_val_acc
:
# import torch
best_val_acc
=
val_acc
# import torch.nn as nn
best_test_acc
=
test_acc
# from torch import Tensor
duration
=
time
.
time
()
-
time_start
# from typing import *
if
with_nccl
():
cur_mem
=
torch
.
cuda
.
memory_reserved
()
# import os
else
:
# import time
cur_mem
=
psutil
.
Process
(
os
.
getpid
())
.
memory_info
()
.
rss
# import psutil
cur_mem_mb
=
round
(
cur_mem
/
1024
**
2
)
# from starrygl.nn import *
if
ep
>
1
:
# from starrygl.graph import DistGraph
avg_mem
+=
cur_mem
avg_dur
+=
duration
# from starrygl.parallel import init_process_group, convert_parallel_model
avg_num
+=
1
# from starrygl.parallel import compute_gcn_norm, SyncBatchNorm, with_nccl
main_print
(
# from starrygl.utils import train_epoch, eval_epoch, partition_load, main_print, sync_print
f
"ep: {ep}, mem: {cur_mem_mb}MiB, duration: {duration:.2f}s, "
f
"loss: [{train_loss:.4f}/{val_loss:.4f}/{test_loss:.6f}], "
f
"accuracy: [{train_acc:.4f}/{val_acc:.4f}/{test_acc:.4f}], "
f
"best_accuracy: {best_test_acc:.4f}"
)
# if __name__ == "__main__":
avg_mem
=
round
(
avg_mem
/
avg_num
/
1024
**
2
)
# # 启动分布式进程组,并分配计算设备
avg_dur
=
avg_dur
/
avg_num
# device = init_process_group(backend="nccl")
main_print
(
f
"average memory: {avg_mem}MiB, average duration: {avg_dur:.2f}s"
)
# # 加载数据集
# pdata = partition_load("./cora", algo="metis").to(device)
# g = DistGraph(ids=pdata.ids, edge_index=pdata.edge_index)
# # g.args["async_op"] = True
# # g.args["num_samples"] = 20
# g.edata["gcn_norm"] = compute_gcn_norm(g)
# g.ndata["x"] = pdata.x
# g.ndata["y"] = pdata.y
# # 定义GAT图神经网络模型
# net = ShrinkGCN(
# g=g,
# layer_options=BasicLayerOptions(
# in_channels=pdata.num_features,
# hidden_channels=64,
# num_layers=3,
# out_channels=pdata.num_classes,
# norm="batchnorm",
# ),
# input_options=BasicInputOptions(
# straight_enabled=True,
# straight_num_samples = 200,
# ),
# straight_options=BasicStraightOptions(
# enabled=True,
# num_samples = 20,
# # beta=1.1,
# ),
# ).to(device)
# # 转换成分布式并行版本
# net = convert_parallel_model(net)
# # 定义优化器
# opt = torch.optim.Adam(net.parameters(), lr=0.01, weight_decay=5e-4)
# avg_mem = 0.0
# avg_dur = 0.0
# avg_num = 0
# # 开始训练
# best_val_acc = best_test_acc = 0
# for ep in range(1, 50+1):
# time_start = time.time()
# train_loss, train_acc = train_epoch(net, opt, g, pdata.train_mask)
# val_loss, val_acc = eval_epoch(net, g, pdata.val_mask)
# test_loss, test_acc = eval_epoch(net, g, pdata.test_mask)
# # val_loss, val_acc = train_loss, train_acc
# # test_loss, test_acc = train_loss, train_acc
# if val_acc > best_val_acc:
# best_val_acc = val_acc
# best_test_acc = test_acc
# duration = time.time() - time_start
# if with_nccl():
# cur_mem = torch.cuda.memory_reserved()
# else:
# cur_mem = psutil.Process(os.getpid()).memory_info().rss
# cur_mem_mb = round(cur_mem / 1024**2)
# if ep > 1:
# avg_mem += cur_mem
# avg_dur += duration
# avg_num += 1
# main_print(
# f"ep: {ep}, mem: {cur_mem_mb}MiB, duration: {duration:.2f}s, "
# f"loss: [{train_loss:.4f}/{val_loss:.4f}/{test_loss:.6f}], "
# f"accuracy: [{train_acc:.4f}/{val_acc:.4f}/{test_acc:.4f}], "
# f"best_accuracy: {best_test_acc:.4f}")
# avg_mem = round(avg_mem / avg_num / 1024**2)
# avg_dur = avg_dur / avg_num
# main_print(f"average memory: {avg_mem}MiB, average duration: {avg_dur:.2f}s")
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment