Repository: zhlj/BTS-MTGNN

Commit f4af3231, authored Jul 24, 2023 by Wenjie Huang
Commit message: rpc based

Parent commit: 3f94b191
Showing 10 changed files with 760 additions and 196 deletions (+760 / -196):

    .gitignore                       +5    -2
    starrygl/loader/__init__.py      +0    -0
    starrygl/loader/buffer.py        +165  -26
    starrygl/loader/context.py       +31   -0
    starrygl/loader/route.py         +238  -0
    starrygl/loader/utils.py         +0    -6
    starrygl/nn/base_model.py        +44   -44
    starrygl/parallel/__init__.py    +64   -15
    starrygl/parallel/degree.py      +27   -27
    train.py                         +186  -76
.gitignore  (view file @ f4af3231)

@@ -159,4 +159,7 @@ cython_debug/
 # option (not recommended) you can uncomment the following to ignore the entire idea folder.
 #.idea/
-cora/
\ No newline at end of file
+cora/
+/test_*
+/*.ipynb
+/s.py
\ No newline at end of file
starrygl/loader/__init__.py  (view file @ f4af3231; diff collapsed in this view)
starrygl/loader/buffer.py  (view file @ f4af3231; the commit replaces the previous TensorBuffers class, which kept per-layer "data_{i}" buffers of shape (src_size, *s) with get/set/add/zero_grad/_idx helpers, with the RPC-backed TensorBuffer below)

import torch
import torch.nn as nn
import torch.distributed as dist
import torch.distributed.rpc as rpc

from torch.futures import Future
from torch import Tensor
from typing import *

from starrygl.parallel import all_gather_remote_objects
from .route import Route

from threading import Lock


class TensorBuffer(nn.Module):
    def __init__(self,
        channels: int,
        num_nodes: int,
        route: Route,
    ) -> None:
        super().__init__()
        self.channels = channels
        self.num_nodes = num_nodes
        self.route = route
        self.local_lock = Lock()
        self.register_buffer("_data", torch.zeros(num_nodes, channels), persistent=False)
        self.rrefs = all_gather_remote_objects(self)

    @property
    def data(self) -> Tensor:
        return self.get_buffer("_data")

    @property
    def device(self) -> torch.device:
        return self.data.device

    def local_get(self, index: Optional[Tensor] = None, lock: bool = True) -> Tensor:
        if lock:
            with self.local_lock:
                return self.local_get(index, lock=False)
        if index is None:
            return self.data
        else:
            return self.data[index]

    def local_set(self, value: Tensor, index: Optional[Tensor] = None, lock: bool = True):
        if lock:
            with self.local_lock:
                return self.local_set(value, index, lock=False)
        if index is None:
            self.data.copy_(value)
        else:
            # value = value.to(self.device)
            self.data[index] = value

    def local_add(self, value: Tensor, index: Optional[Tensor] = None, lock: bool = True):
        if lock:
            with self.local_lock:
                return self.local_add(value, index, lock=False)
        if index is None:
            self.data.add_(value)
        else:
            # value = value.to(self.device)
            self.data[index] += value

    def local_cls(self, index: Optional[Tensor] = None, lock: bool = True):
        if lock:
            with self.local_lock:
                return self.local_cls(index, lock=False)
        if index is None:
            self.data.zero_()
        else:
            self.data[index] = 0

    def remote_get(self, dst: int, index: Tensor, lock: bool = True):
        return TensorBuffer._remote_call(TensorBuffer.local_get, self.rrefs[dst], index=index, lock=lock)

    def remote_set(self, dst: int, value: Tensor, index: Tensor, lock: bool = True):
        return TensorBuffer._remote_call(TensorBuffer.local_set, self.rrefs[dst], value, index=index, lock=lock)

    def remote_add(self, dst: int, value: Tensor, index: Tensor, lock: bool = True):
        return TensorBuffer._remote_call(TensorBuffer.local_add, self.rrefs[dst], value, index=index, lock=lock)

    def all_remote_get(self, index: Tensor, lock: bool = True):
        def cb0(idx):
            def f(x: torch.futures.Future[Tensor]):
                return x.value(), idx
            return f

        def cb1(buf):
            def f(xs: torch.futures.Future[List[torch.futures.Future]]) -> Tensor:
                for x in xs.value():
                    dat, idx = x.value()
                    # print(dat.size(), idx.size())
                    buf[idx] += dat
                return buf
            return f

        futs = []
        for i, (idx, remote_idx) in enumerate(self.route.parts_iter(index)):
            futs.append(self.remote_get(i, remote_idx, lock=lock).then(cb0(idx)))
        futs = torch.futures.collect_all(futs)

        buf = torch.zeros(index.size(0), self.channels, dtype=self.data.dtype, device=self.data.device)
        return futs.then(cb1(buf))

    def all_remote_set(self, value: Tensor, index: Tensor, lock: bool = True):
        futs = []
        for i, (idx, remote_idx) in enumerate(self.route.parts_iter(index)):
            futs.append(self.remote_set(i, value[idx], remote_idx, lock=lock))
        return torch.futures.collect_all(futs)

    def all_remote_add(self, value: Tensor, index: Tensor, lock: bool = True):
        futs = []
        for i, (idx, remote_idx) in enumerate(self.route.parts_iter(index)):
            futs.append(self.remote_add(i, value[idx], remote_idx, lock=lock))
        return torch.futures.collect_all(futs)

    def broadcast(self, barrier: bool = True):
        if barrier:
            dist.barrier()
        index = torch.arange(self.num_nodes, dtype=torch.long, device=self.data.device)
        data = self.all_remote_get(index, lock=True).wait()
        self.local_set(data, lock=True)
        if barrier:
            dist.barrier()

# def remote_get(self, dst: int, i: int, index: Optional[Tensor] = None, global_index: bool = False, async_op: bool = False) -> Union[Tensor, Future]:
# return TensorBuffer._remote_call(async_op, TensorBuffer.local_get, self.rrefs[dst], i, index, global_index = global_index)
# def remote_set(self, dst: int, i: int, value: Tensor, index: Optional[Tensor], global_index: bool = False, async_op: bool = False) -> Optional[Future]:
# return TensorBuffer._remote_call(async_op, TensorBuffer.local_set, self.rrefs[dst], i, value, index, global_index = global_index)
# def remote_add(self, dst: int, i: int, value: Tensor, index: Optional[Tensor] = None, global_index: bool = False, async_op: bool = False) -> Optional[Future]:
# return TensorBuffer._remote_call(async_op, TensorBuffer.local_add, self.rrefs[dst], i, value, index, global_index = global_index)
# def async_scatter_fw_set(self, i: int, value: Tensor, index: Optional[Tensor] = None) -> Tuple[Future]:
# futures: List[Future] = []
# for dst in range(self.world_size):
# val, ind = self.router.fw_value_index(dst, value, index)
# futures.append(self.remote_set(dst, i, val, ind, global_index=True, async_op=True))
# return tuple(futures)
# def async_scatter_fw_add(self, i: int, value: Tensor, index: Optional[Tensor] = None) -> Tuple[Future]:
# futures: List[Future] = []
# for dst in range(self.world_size):
# val, ind = self.router.fw_value_index(dst, value, index)
# futures.append(self.remote_add(dst, i, val, ind, global_index=True, async_op=True))
# return tuple(futures)
# def async_scatter_bw_set(self, i: int, value: Tensor, index: Optional[Tensor] = None) -> Tuple[Future]:
# futures: List[Future] = []
# for dst in range(self.world_size):
# val, ind = self.router.bw_value_index(dst, value, index)
# futures.append(self.remote_set(dst, i, val, ind, global_index=True, async_op=True))
# return tuple(futures)
# def async_scatter_bw_add(self, i: int, value: Tensor, index: Optional[Tensor] = None) -> Tuple[Future]:
# futures: List[Future] = []
# for dst in range(self.world_size):
# val, ind = self.router.bw_value_index(dst, value, index)
# futures.append(self.remote_add(dst, i, val, ind, global_index=True, async_op=True))
# return tuple(futures)
# def _idx_data(self, i: int) -> Tuple[int, Tensor]:
# assert -self.num_layers < i and i < self.num_layers
# i = (self.num_layers + i) % self.num_layers
# return i, self.get_buffer(f"data{i}")
    @staticmethod
    def _remote_call(method, rref: rpc.RRef, *args, **kwargs):
        args = (method, rref) + args
        return rpc.rpc_async(rref.owner(), TensorBuffer._method_call, args=args, kwargs=kwargs)

    @staticmethod
    def _method_call(method, rref: rpc.RRef, *args, **kwargs):
        self: TensorBuffer = rref.local_value()
        index = kwargs["index"]
        kwargs["index"] = self.route.to_local_ids(index)
        return method(self, *args, **kwargs)
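A minimal usage sketch of my reading of this class (not code from the repository): each rank seeds the buffer rows it owns and then fills the rest over RPC, mirroring what train.py below does with its feature buffer. The variables pdata and loader are assumed to exist as in train.py, and the process group plus RPC are assumed to be initialized already.

# Hedged sketch, assuming init_process_group() has already been called on every rank
# and that `loader` exposes src_size, dst_size and route as NodeLoader does in train.py.
import torch
from starrygl.loader import TensorBuffer

# one row per local node; the route records which rank owns each row
feat_buf = TensorBuffer(channels=pdata.num_features,
                        num_nodes=loader.src_size,
                        route=loader.route)

feat_buf.data[:loader.dst_size] = pdata.x   # rows owned by this rank
feat_buf.broadcast()                        # pull the remaining rows from their owners via RPC

# the same machinery, driven asynchronously (this is what broadcast() does internally):
index = torch.arange(feat_buf.num_nodes)
fut = feat_buf.all_remote_get(index, lock=True)   # a torch.futures.Future
feat_buf.local_set(fut.wait(), lock=True)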
starrygl/loader/context.py  (new file, 0 → 100644; view file @ f4af3231)

import torch
from contextlib import contextmanager

from torch import Tensor
from typing import *


class RouteContext:
    def __init__(self) -> None:
        self._futs: List[torch.futures.Future] = []

    def synchronize(self):
        for fut in self._futs:
            fut.wait()
        self._futs = []

    def add_futures(self, *futs):
        for fut in futs:
            assert isinstance(fut, torch.futures.Future)
            self._futs.append(fut)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        if exc_type is not None:
            raise exc_type(exc_value)
        self.synchronize()
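A short illustrative sketch of how this context manager appears to be meant to be used (compare SimpleGNN.forward in train.py below). hidden_buf, h, index and fc_out are placeholders, not names from the repository.

# Hedged sketch: register fire-and-forget RPC futures on the context and let
# __exit__ -> synchronize() block until all of them have completed.
from starrygl.loader import RouteContext

with RouteContext() as ctx:
    # index: local row ids covering the buffer, e.g. torch.arange(hidden_buf.num_nodes)
    fut = hidden_buf.all_remote_set(h, index, lock=True)  # asynchronous remote write
    ctx.add_futures(fut)                                   # collected, not awaited yet
    out = fc_out(h)                                        # overlap local compute with RPC traffic
# leaving the block waits on every registered future (and re-raises if an exception occurred)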
starrygl/loader/route.py  (new file, 0 → 100644; view file @ f4af3231)

import torch
import torch.nn as nn
import torch.distributed as dist
import torch.distributed.rpc as rpc

from torch import Tensor
from typing import *

from starrygl.parallel import all_gather_remote_objects
from .utils import init_local_edge_index


class Route(nn.Module):
    def __init__(self,
        src_ids: Tensor,
        dst_size: int,
    ) -> None:
        super().__init__()
        self.register_buffer("_src_ids", src_ids, persistent=False)
        self.dst_size = dst_size

        self._init_nids_mapper()
        self._init_part_mapper()

    @staticmethod
    def from_edge_index(dst_ids: Tensor, edge_index: Tensor):
        src_ids, local_edge_index = init_local_edge_index(dst_ids, edge_index)
        return Route(src_ids, dst_ids.size(0)), local_edge_index

    @property
    def src_ids(self) -> Tensor:
        return self.get_buffer("_src_ids")

    @property
    def src_size(self) -> int:
        return self.src_ids.size(0)

    @property
    def dst_ids(self) -> Tensor:
        return self.src_ids[:self.dst_size]

    @property
    def ext_ids(self) -> Tensor:
        return self.src_ids[self.dst_size:]

    @property
    def ext_size(self) -> int:
        return self.src_size - self.dst_size

    def parts_iter(self, local_ids: Tensor) -> Iterator[Tuple[Tensor, Tensor]]:
        world_size = dist.get_world_size()
        part_mapper = self.part_mapper[local_ids]
        for i in range(world_size):
            # part_ids = local_ids[part_mapper == i]
            part_ids = torch.where(part_mapper == i)[0]
            glob_ids = self.src_ids[part_ids]
            yield part_ids, glob_ids

    def to_local_ids(self, ids: Tensor) -> Tensor:
        return self.nids_mapper[ids]

    def _init_nids_mapper(self):
        num_nodes: int = self.src_ids.max().item() + 1
        device: torch.device = self.src_ids.device

        mapper = torch.empty(num_nodes, dtype=torch.long, device=device).fill_((2**62-1)*2+1)
        mapper[self.src_ids] = torch.arange(self.src_ids.size(0), dtype=torch.long, device=device)
        self.register_buffer("nids_mapper", mapper, persistent=False)

    def _init_part_mapper(self):
        device: torch.device = self.src_ids.device
        nids_mapper = self.get_buffer("nids_mapper")

        mapper = torch.empty(self.src_size, dtype=torch.int32, device=device).fill_(-1)
        for i, dst_ids in enumerate(all_gather_remote_objects(self.dst_ids)):
            dst_ids: Tensor = dst_ids.to_here().to(device)
            dst_ids = dst_ids[dst_ids < nids_mapper.size(0)]
            dst_local_inds = nids_mapper[dst_ids]
            dst_local_mask = dst_local_inds != ((2**62-1)*2+1)
            dst_local_inds = dst_local_inds[dst_local_mask]
            mapper[dst_local_inds] = i
        assert (mapper >= 0).all()
        self.register_buffer("part_mapper", mapper, persistent=False)

# class RouteTable(nn.Module):
# def __init__(self,
# src_ids: Tensor,
# dst_size: int,
# ) -> None:
# super().__init__()
# self.register_buffer("src_ids", src_ids)
# self.src_size: int = src_ids.size(0)
# self.dst_size = dst_size
# assert self.src_size >= self.dst_size
# self._init_mapper()
# rank, world_size = rank_world_size()
# rrefs = all_gather_remote_objects(self)
# gather_futures: List[torch.futures.Future] = []
# for i in range(world_size):
# rref = rrefs[i]
# fut = rpc.rpc_async(rref.owner(), RouteTable._get_dst_ids, args=(rref,))
# gather_futures.append(fut)
# max_src_ids: int = src_ids.max().item()
# smp = torch.empty(max_src_ids + 1, dtype=torch.long, device=src_ids.device).fill_((2**62-1)*2+1)
# smp[src_ids] = torch.arange(src_ids.size(0), dtype=smp.dtype, device=smp.device)
# self.fw_masker = RouteMasker(self.dst_size, world_size)
# self.bw_masker = RouteMasker(self.src_size, world_size)
# dist.barrier()
# scatter_futures: List[torch.futures.Future] = []
# for i in range(world_size):
# fut = gather_futures[i]
# s_ids: Tensor = src_ids
# d_ids: Tensor = fut.wait()
# num_ids: int = max(s_ids.max().item(), d_ids.max().item()) + 1
# imp = torch.zeros(num_ids, dtype=torch.long, device=self._get_device())
# imp[s_ids] += 1
# imp[d_ids] += 1
# ind = torch.where(imp > 1)[0]
# imp.fill_((2**62-1)*2+1)
# imp[d_ids] = torch.arange(d_ids.size(0), dtype=imp.dtype, device=imp.device)
# s_ind = smp[ind]
# d_ind = imp[ind]
# rref = rrefs[i]
# fut = rpc.rpc_async(rref.owner(), RouteTable._set_fw_mask, args=(rref, rank, d_ind))
# scatter_futures.append(fut)
# bw_mask = torch.zeros(self.src_size, dtype=torch.bool).index_fill_(0, s_ind, 1)
# self.bw_masker.set_mask(i, bw_mask)
# for fut in scatter_futures:
# fut.wait()
# dist.barrier()
# # def fw_index(self, dst: int, index: Tensor) -> Tensor:
# # mask = self.fw_masker.select(dst, index)
# # return self.get_global_index(index[mask])
# # def bw_index(self, dst: int, index: Tensor) -> Tensor:
# # mask = self.bw_masker.select(dst, index)
# # return self.get_global_index(index[mask])
# def fw_value_index(self, dst: int, value: Tensor, index: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]:
# if index is None:
# assert value.size(0) == self.dst_size
# mask = self.fw_masker.select(dst)
# return value[mask], self.get_buffer("src_ids")[:self.dst_size][mask]
# else:
# assert value.size(0) == index.size(0)
# mask = self.fw_masker.select(dst, index)
# value, index = value[mask], index[mask]
# return value, self.get_global_index(index)
# def bw_value_index(self, dst: int, value: Tensor, index: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]:
# if index is None:
# assert value.size(0) == self.src_size
# mask = self.bw_masker.select(dst)
# return value[mask], self.get_buffer("src_ids")[mask]
# else:
# assert value.size(0) == index.size(0)
# mask = self.bw_masker.select(dst, index)
# value, index = value[mask], index[mask]
# return value, self.get_global_index(index)
# def get_global_index(self, index: Tensor) -> Tensor:
# return self.get_buffer("src_ids")[index]
# def get_local_index(self, index: Tensor) -> Tensor:
# return self.get_buffer("mapper")[index]
# @staticmethod
# def _get_dst_ids(rref: rpc.RRef):
# self: RouteTable = rref.local_value()
# src_ids = self.get_buffer("src_ids")
# return src_ids[:self.dst_size]
# @staticmethod
# def _set_fw_mask(rref: rpc.RRef, dst: int, fw_ind: Tensor):
# self: RouteTable = rref.local_value()
# fw_mask = torch.zeros(self.dst_size, dtype=torch.bool).index_fill_(0, fw_ind, 1)
# self.fw_masker.set_mask(dst, fw_mask)
# def _get_device(self):
# return self.get_buffer("src_ids").device
# def _init_mapper(self):
# src_ids = self.get_buffer("src_ids")
# num_nodes: int = src_ids.max().item() + 1
# mapper = torch.empty(num_nodes, dtype=torch.long, device=src_ids.device).fill_((2**62-1)*2+1)
# mapper[src_ids] = torch.arange(src_ids.size(0), dtype=torch.long)
# self.register_buffer("mapper", mapper)
# class RouteMasker(nn.Module):
# def __init__(self,
# num_nodes: int,
# world_size: int,
# ) -> None:
# super().__init__()
# m = (world_size + 7) // 8
# self.num_nodes = num_nodes
# self.world_size = world_size
# self.register_buffer("data", torch.zeros(m, num_nodes, dtype=torch.uint8))
# def forward(self, i: int, index: Optional[Tensor] = None) -> Tensor:
# return self.select(i, index)
# def select(self, i: int, index: Optional[Tensor] = None) -> Tensor:
# i, data = self._idx_data(i)
# k, r = i // 8, i % 8
# if index is None:
# mask = data[k].bitwise_right_shift(r).bitwise_and_(1)
# else:
# mask = data[k][index].bitwise_right_shift_(r).bitwise_and_(1)
# return mask.type(dtype=torch.bool)
# def set_mask(self, i: int, mask: Tensor) -> Tensor:
# assert mask.size(0) == self.num_nodes
# i, data = self._idx_data(i)
# k, r = i // 8, i % 8
# data[k] &= ~(1<<r)
# data[k] |= mask.type(torch.uint8).bitwise_left_shift_(r)
# def _idx_data(self, i: int) -> Tuple[int, Tensor]:
# assert -self.world_size < i and i < self.world_size
# i = (i + self.world_size) % self.world_size
# return i, self.get_buffer("data")
\ No newline at end of file
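A small sketch of how Route seems intended to be used on one rank (illustrative only; dst_ids and edge_index stand for this rank's owned global node ids and the global edges touching them, which in the repository come from the partitioned dataset).

# Hedged sketch: build the Route for one partition and see how a batch of
# local rows is split per destination rank.
import torch
from starrygl.loader.route import Route

route, local_edge_index = Route.from_edge_index(dst_ids, edge_index)
print(route.dst_size, route.ext_size)        # owned nodes vs. halo nodes

index = torch.arange(route.src_size)         # one entry per local node
for part, (idx, glob_ids) in enumerate(route.parts_iter(index)):
    # idx: positions inside `index` whose owner is rank `part`
    # glob_ids: their global ids; the owning rank converts them back to its own
    #           storage with route.to_local_ids(glob_ids) (see TensorBuffer._method_call)
    print(part, idx.numel())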
starrygl/loader/utils.py  (view file @ f4af3231)

@@ -50,11 +50,6 @@ def calc_max_ids(*ids: Tensor) -> int:
     x = [t.max().item() if t.numel() > 0 else 0 for t in ids]
     return max(*x)
 
-def collect_feat0(src_ids: Tensor, dst_ids: Tensor, feat0: Tensor):
-    device = get_compute_device()
-    route = Route(dst_ids.to(device), src_ids.to(device))
-    return route.forward_a2a(feat0.to(device))[0].to(feat0.device)
-
 def local_partition_fn(dst_size: Tensor, edge_index: Tensor, num_parts: int) -> Tensor:
     edge_index = edge_index[:, edge_index[0] < dst_size]
     return metis_partition(edge_index, dst_size, num_parts)[0]
\ No newline at end of file
starrygl/nn/base_model.py  (view file @ f4af3231; every line of this file is commented out by this commit)

# import torch
# import torch.nn as nn
# import torch.distributed as dist
# from torch import Tensor
# from typing import *
# from starrygl.loader import BatchHandle

# class BaseLayer(nn.Module):
#     def __init__(self) -> None:
#         super().__init__()
#     def forward(self, x: Tensor, edge_index: Tensor, edge_attr: Optional[Tensor] = None) -> Tensor:
#         return x
#     def update_forward(self, handle: BatchHandle, edge_index: Tensor, edge_attr: Optional[Tensor] = None):
#         x = handle.fetch_feat()
#         with torch.no_grad():
#             x = self.forward(x, edge_index, edge_attr)
#         handle.update_feat(x)
#     def block_backward(self, handle: BatchHandle, edge_index: Tensor, edge_attr: Optional[Tensor] = None):
#         x = handle.fetch_feat().requires_grad_()
#         g = handle.fetch_grad()
#         self.forward(x, edge_index, edge_attr).backward(g)
#         handle.accumulate_grad(x.grad)
#         x.grad = None
#     def all_reduce_grad(self):
#         for p in self.parameters():
#             if p.grad is not None:
#                 dist.all_reduce(p.grad, op=dist.ReduceOp.SUM)

# class BaseModel(nn.Module):
#     def __init__(self,
#         num_features: int,
#         layers: List[int],
#         prev_layer: bool = False,
#         post_layer: bool = False,
#     ) -> None:
#         super().__init__()
#     def init_prev_layer(self) -> Tensor:
#         pass
#     def init_post_layer(self) -> Tensor:
#         pass
#     def init_conv_layer(self) -> Tensor:
#         pass
starrygl/parallel/__init__.py  (view file @ f4af3231; the old init_process_group called dist.init_process_group(backend) directly and picked cuda:{dist.get_rank()} as the device, which this commit replaces with an explicit TCP rendezvous plus RPC initialization)

import torch
import torch.nn as nn
import torch.distributed as dist
import torch.distributed.rpc as rpc

import os
from typing import *

@@ -15,29 +16,36 @@ def convert_parallel_model(
    net = SyncBatchNorm.convert_sync_batchnorm(net)
    net = nn.parallel.DistributedDataParallel(net,
        find_unused_parameters=find_unused_parameters,
        # broadcast_buffers=False,
    )
    # for name, buffer in net.named_buffers():
    #     if name.endswith("last_embd"):
    #         continue
    #     if name.endswith("last_w"):
    #         continue
    #     dist.broadcast(buffer, src=0)
    return net

def init_process_group(backend: str = "gloo") -> torch.device:
    rank = int(os.getenv("RANK") or os.getenv("OMPI_COMM_WORLD_RANK"))
    world_size = int(os.getenv("WORLD_SIZE") or os.getenv("OMPI_COMM_WORLD_SIZE"))

    dist.init_process_group(
        backend=backend,
        init_method=ccl_init_method(),
        rank=rank, world_size=world_size,
    )

    rpc_backend_options = rpc.TensorPipeRpcBackendOptions()
    rpc_backend_options.init_method = rpc_init_method()
    for i in range(world_size):
        rpc_backend_options.set_device_map(f"worker{i}", {rank: i})
    rpc.init_rpc(
        name=f"worker{rank}",
        rank=rank, world_size=world_size,
        rpc_backend_options=rpc_backend_options,
    )

    local_rank = os.getenv("LOCAL_RANK") or os.getenv("OMPI_COMM_WORLD_LOCAL_RANK")
    if local_rank is not None:
        local_rank = int(local_rank)

    if backend == "nccl" or backend == "mpi":
        if local_rank is None:
            device = torch.device(f"cuda:{local_rank or rank}")
        else:
            device = torch.device(f"cuda:{local_rank}")
        torch.cuda.set_device(device)
    else:
        device = torch.device("cpu")

@@ -46,7 +54,48 @@ def init_process_group(backend: str = "gloo") -> torch.device:
    _COMPUTE_DEVICE = device
    return device

def rank_world_size() -> Tuple[int, int]:
    return dist.get_rank(), dist.get_world_size()

def get_worker_info(rank: Optional[int] = None) -> rpc.WorkerInfo:
    rank = dist.get_rank() if rank is None else rank
    return rpc.get_worker_info(f"worker{rank}")


_COMPUTE_DEVICE = torch.device("cpu")

def get_compute_device() -> torch.device:
    global _COMPUTE_DEVICE
    return _COMPUTE_DEVICE


_TEMP_AG_REMOTE_OBJECT = None

def _remote_object():
    global _TEMP_AG_REMOTE_OBJECT
    return _TEMP_AG_REMOTE_OBJECT

def all_gather_remote_objects(obj: Any) -> List[rpc.RRef]:
    global _TEMP_AG_REMOTE_OBJECT
    _TEMP_AG_REMOTE_OBJECT = rpc.RRef(obj)
    dist.barrier()

    world_size = dist.get_world_size()
    futs: List[torch.futures.Future] = []
    for i in range(world_size):
        info = get_worker_info(i)
        futs.append(rpc.rpc_async(info, _remote_object))

    rrefs: List[rpc.RRef] = []
    for f in futs:
        f.wait()
        rrefs.append(f.value())

    dist.barrier()
    _TEMP_AG_REMOTE_OBJECT = None
    return rrefs


def ccl_init_method() -> str:
    master_addr = os.environ["MASTER_ADDR"]
    master_port = int(os.environ["MASTER_PORT"])
    return f"tcp://{master_addr}:{master_port}"

def rpc_init_method() -> str:
    master_addr = os.environ["MASTER_ADDR"]
    master_port = int(os.environ["MASTER_PORT"])
    return f"tcp://{master_addr}:{master_port+1}"
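A hedged sketch of the rendezvous convention implied above: the collective process group binds to MASTER_PORT and RPC to MASTER_PORT+1, every rank registers as worker{rank}, and all_gather_remote_objects exchanges one RRef per rank. The snippet below is illustrative only, assuming MASTER_ADDR, MASTER_PORT, RANK and WORLD_SIZE (or their OMPI_COMM_WORLD_* equivalents) are exported by the launcher.

# Hedged per-rank entry point (not from the repository).
import torch.distributed.rpc as rpc
from starrygl.parallel import init_process_group, all_gather_remote_objects

def main():
    device = init_process_group(backend="gloo")   # also calls rpc.init_rpc as "worker{rank}"

    # every rank publishes a local object and receives an RRef to each peer's copy
    rrefs = all_gather_remote_objects({"device": str(device)})
    print([rref.owner().name for rref in rrefs])  # ['worker0', 'worker1', ...]

    rpc.shutdown()

if __name__ == "__main__":
    main()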
starrygl/parallel/degree.py  (view file @ f4af3231; every line of this file is commented out by this commit)

# import torch
# from torch import Tensor
# from typing import *

# from torch_scatter import scatter_sum
# from starrygl.core.route import Route

# def compute_in_degree(edge_index: Tensor, route: Route) -> Tensor:
#     dst_size = route.src_size
#     x = torch.ones(edge_index.size(1), dtype=torch.long, device=edge_index.device)
#     in_deg = scatter_sum(x, edge_index[1], dim=0, dim_size=dst_size)
#     in_deg, _ = route.forward_a2a(in_deg)
#     return in_deg

# def compute_out_degree(edge_index: Tensor, route: Route) -> Tensor:
#     src_size = route.dst_size
#     x = torch.ones(edge_index.size(1), dtype=torch.long, device=edge_index.device)
#     out_deg = scatter_sum(x, edge_index[0], dim=0, dim_size=src_size)
#     out_deg, _ = route.backward_a2a(out_deg)
#     out_deg, _ = route.forward_a2a(out_deg)
#     return out_deg

# def compute_gcn_norm(edge_index: Tensor, route: Route) -> Tensor:
#     in_deg = compute_in_degree(edge_index, route)
#     out_deg = compute_out_degree(edge_index, route)
#     a = in_deg[edge_index[0]].pow(-0.5)
#     b = out_deg[edge_index[0]].pow(-0.5)
#     x = a * b
#     x[x.isinf()] = 0.0
#     x[x.isnan()] = 0.0
#     return x
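For reference (my reading of the code above, which is commented out at this commit): for every edge $(u, v)$, with $u$ taken from edge_index[0], compute_gcn_norm evaluates the per-edge coefficient

$$w_{(u,v)} = d_{\mathrm{in}}(u)^{-1/2} \cdot d_{\mathrm{out}}(u)^{-1/2},$$

zeroing the result wherever it is infinite or NaN (isolated endpoints), with the degree vectors exchanged across partitions through route.backward_a2a / forward_a2a. Note that both factors are indexed at the source endpoint as written, so when in- and out-degree coincide this equals $1/d(u)$, a source-side normalization.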
train.py  (view file @ f4af3231)

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
import torch.distributed.rpc as rpc

from torch import Tensor
from typing import *

from torch_sparse import SparseTensor

from starrygl.loader import NodeLoader, NodeHandle, TensorBuffer, RouteContext
from starrygl.parallel import init_process_group, convert_parallel_model
from starrygl.utils import partition_load, main_print, sync_print


class SimpleGNNConv(nn.Module):
    def __init__(self,
        in_channels: int,
        out_channels: int,
    ) -> None:
        super().__init__()
        self.linear = nn.Linear(in_channels, out_channels)
        self.norm = nn.LayerNorm(out_channels)

    def forward(self, x: Tensor, adj_t: SparseTensor) -> Tensor:
        x = self.linear(x)
        x = adj_t @ x
        return self.norm(x)


class SimpleGNN(nn.Module):
    def __init__(self,
        in_channels: int,
        hidden_channels: int,
        out_channels: int,
    ) -> None:
        super().__init__()
        self.conv1 = SimpleGNNConv(in_channels, hidden_channels)
        self.conv2 = SimpleGNNConv(hidden_channels, hidden_channels)
        self.fc_out = nn.Linear(hidden_channels, out_channels)

    def forward(self, handle: NodeHandle, buffers: List[TensorBuffer]) -> Tensor:
        futs = [
            handle.get_src_feats(buffers[0]),
            handle.get_ext_feats(buffers[1]),
        ]
        with RouteContext() as ctx:
            x = futs[0].wait()
            x = self.conv1(x, handle.adj_t)
            x, f = handle.push_and_pull(x, futs[1], buffers[1])
            x = self.conv2(x, handle.adj_t)
            ctx.add_futures(f)  # once the current batch's forward pass is done, all futures must be waited on
        x = self.fc_out(x)
        return x


if __name__ == "__main__":
    # start the distributed process group and assign the compute device
    device = init_process_group(backend="nccl")

    # load the dataset
    pdata = partition_load("./cora", algo="metis")
    loader = NodeLoader(pdata.ids, pdata.edge_index, device)

    # create the historical caches
    hidden_size = 64
    buffers: List[TensorBuffer] = [
        TensorBuffer(pdata.num_features, loader.src_size, loader.route),
        TensorBuffer(hidden_size, loader.src_size, loader.route),
    ]

    # set the initial node features and pre-synchronize them to the other partitions
    buffers[0].data[:loader.dst_size] = pdata.x
    buffers[0].broadcast()

    # build the model
    net = SimpleGNN(pdata.num_features, hidden_size, pdata.num_classes).to(device)
    net = convert_parallel_model(net)
    opt = torch.optim.Adam(net.parameters(), lr=1e-3)

    # training stage
    for ep in range(1, 100+1):
        epoch_loss = 0.0
        net.train()
        for handle in loader.iter(128):
            fut_m = handle.get_dst_feats(pdata.train_mask)
            fut_y = handle.get_dst_feats(pdata.y)
            h = net(handle, buffers)

            train_mask = fut_m.wait()
            logits = h[train_mask]
            if logits.size(0) > 0:
                y = fut_y.wait()[train_mask]
                loss = nn.CrossEntropyLoss()(logits, y)

                opt.zero_grad()
                loss.backward()
                opt.step()
                epoch_loss += loss.item()
        main_print(ep, epoch_loss)

    rpc.shutdown()
(The new file ends with the old GCN-based training script kept as a commented-out block, reproduced below with its Chinese comments translated. The commit deletes that script's previous uncommented form, which also built the graph directly with g = DistGraph(ids=pdata.ids, edge_index=pdata.edge_index), set g.args["sample_k"] = 20, defined net = GCN(...) with num_layers=2 and jk_options=BasicJKOptions(jk_mode=None), and trained for 10 epochs.)

    # import torch
    # import torch.nn as nn
    # from torch import Tensor
    # from typing import *
    # import os
    # import time
    # import psutil
    # from starrygl.nn import *
    # from starrygl.graph import DistGraph
    # from starrygl.parallel import init_process_group, convert_parallel_model
    # from starrygl.parallel import compute_gcn_norm, SyncBatchNorm, with_nccl
    # from starrygl.utils import train_epoch, eval_epoch, partition_load, main_print, sync_print

    # if __name__ == "__main__":
    #     # start the distributed process group and assign the compute device
    #     device = init_process_group(backend="nccl")

    #     # load the dataset
    #     pdata = partition_load("./cora", algo="metis").to(device)
    #     g = DistGraph(ids=pdata.ids, edge_index=pdata.edge_index)
    #     # g.args["async_op"] = True
    #     # g.args["num_samples"] = 20

    #     g.edata["gcn_norm"] = compute_gcn_norm(g)
    #     g.ndata["x"] = pdata.x
    #     g.ndata["y"] = pdata.y

    #     # define the GAT graph neural network model
    #     net = ShrinkGCN(
    #         g=g,
    #         layer_options=BasicLayerOptions(
    #             in_channels=pdata.num_features,
    #             hidden_channels=64,
    #             num_layers=3,
    #             out_channels=pdata.num_classes,
    #             norm="batchnorm",
    #         ),
    #         input_options=BasicInputOptions(
    #             straight_enabled=True,
    #             straight_num_samples = 200,
    #         ),
    #         straight_options=BasicStraightOptions(
    #             enabled=True,
    #             num_samples = 20,
    #             # beta=1.1,
    #         ),
    #     ).to(device)

    #     # convert to the distributed parallel version
    #     net = convert_parallel_model(net)

    #     # define the optimizer
    #     opt = torch.optim.Adam(net.parameters(), lr=0.01, weight_decay=5e-4)

    #     avg_mem = 0.0
    #     avg_dur = 0.0
    #     avg_num = 0

    #     # start training
    #     best_val_acc = best_test_acc = 0
    #     for ep in range(1, 50+1):
    #         time_start = time.time()
    #         train_loss, train_acc = train_epoch(net, opt, g, pdata.train_mask)
    #         val_loss, val_acc = eval_epoch(net, g, pdata.val_mask)
    #         test_loss, test_acc = eval_epoch(net, g, pdata.test_mask)
    #         # val_loss, val_acc = train_loss, train_acc
    #         # test_loss, test_acc = train_loss, train_acc
    #         if val_acc > best_val_acc:
    #             best_val_acc = val_acc
    #             best_test_acc = test_acc
    #         duration = time.time() - time_start

    #         if with_nccl():
    #             cur_mem = torch.cuda.memory_reserved()
    #         else:
    #             cur_mem = psutil.Process(os.getpid()).memory_info().rss
    #         cur_mem_mb = round(cur_mem / 1024**2)

    #         if ep > 1:
    #             avg_mem += cur_mem
    #             avg_dur += duration
    #             avg_num += 1
    #         main_print(
    #             f"ep: {ep}, mem: {cur_mem_mb}MiB, duration: {duration:.2f}s, "
    #             f"loss: [{train_loss:.4f}/{val_loss:.4f}/{test_loss:.6f}], "
    #             f"accuracy: [{train_acc:.4f}/{val_acc:.4f}/{test_acc:.4f}], "
    #             f"best_accuracy: {best_test_acc:.4f}")

    #     avg_mem = round(avg_mem / avg_num / 1024**2)
    #     avg_dur = avg_dur / avg_num
    #     main_print(f"average memory: {avg_mem}MiB, average duration: {avg_dur:.2f}s")
\ No newline at end of file