Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
B
BTS-MTGNN
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
zhlj
BTS-MTGNN
Commits
df36000d
Commit
df36000d
authored
Feb 20, 2025
by
zlj
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
fix train_remote_pos and train_remote_neg delay
parent
7ae21c05
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
24 additions
and
12 deletions
+24
-12
examples/test_all.sh
+1
-1
examples/train_boundery.py
+8
-8
starrygl/module/utils.py
+12
-3
starrygl/sample/sample_core/neighbor_sampler.py
+3
-0
No files found.
examples/test_all.sh
View file @
df36000d
...
@@ -10,7 +10,7 @@ partitions="8"
...
@@ -10,7 +10,7 @@ partitions="8"
node_per
=
"4"
node_per
=
"4"
nnodes
=
"2"
nnodes
=
"2"
node_rank
=
"0"
node_rank
=
"0"
probability_params
=(
"
0.
1"
)
probability_params
=(
"1"
)
sample_type_params
=(
"boundery_recent_decay"
)
sample_type_params
=(
"boundery_recent_decay"
)
#sample_type_params=("recent" "boundery_recent_decay") #"boundery_recent_uniform")
#sample_type_params=("recent" "boundery_recent_decay") #"boundery_recent_uniform")
#memory_type=("all_update" "p2p" "all_reduce" "historical" "local")
#memory_type=("all_update" "p2p" "all_reduce" "historical" "local")
...
...
examples/train_boundery.py
View file @
df36000d
...
@@ -563,22 +563,22 @@ def main():
...
@@ -563,22 +563,22 @@ def main():
ones
=
torch
.
ones
(
metadata
[
'dst_neg_index'
]
.
shape
[
0
],
device
=
model
.
device
,
dtype
=
torch
.
float
)
ones
=
torch
.
ones
(
metadata
[
'dst_neg_index'
]
.
shape
[
0
],
device
=
model
.
device
,
dtype
=
torch
.
float
)
pred_pos
,
pred_neg
=
model
(
mfgs
,
metadata
,
neg_samples
=
args
.
neg_samples
,
async_param
=
param
)
pred_pos
,
pred_neg
=
model
(
mfgs
,
metadata
,
neg_samples
=
args
.
neg_samples
,
async_param
=
param
)
ada_param
.
update_gnn_aggregate_time
(
ada_param
.
last_start_event_gnn_aggregate
)
ada_param
.
update_gnn_aggregate_time
(
ada_param
.
last_start_event_gnn_aggregate
)
if
len
(
trainloader
.
result_queue
)
>
0
:
batch_data
,
dist_nid
,
dist_eid
,
edge_feat
,
node_feat0
=
trainloader
.
result_queue
[
0
]
edge_feat
[
1
]
.
wait
()
node_feat0
[
1
]
.
wait
()
if
ada_param
is
not
None
:
ada_param
.
update_fetch_time
(
ada_param
.
last_start_event_fetch
)
ada_param
.
update_parameter
()
#print(time_count.elapsed_event(t2))
#print(time_count.elapsed_event(t2))
loss
=
creterion
(
pred_pos
,
torch
.
ones_like
(
pred_pos
))
loss
=
creterion
(
pred_pos
,
torch
.
ones_like
(
pred_pos
))
if
args
.
local_neg_sample
is
False
:
if
args
.
local_neg_sample
is
False
:
weight
=
torch
.
where
(
DistIndex
(
mfgs
[
0
][
0
]
.
srcdata
[
'ID'
][
metadata
[
'dst_neg_index'
]])
.
part
==
torch
.
distributed
.
get_rank
(),
ones
*
train_neg_sampler
.
train_ratio_pos
,
ones
*
train_neg_sampler
.
train_ratio_neg
)
.
reshape
(
-
1
,
1
)
weight
=
torch
.
where
(
DistIndex
(
mfgs
[
0
][
0
]
.
srcdata
[
'ID'
][
metadata
[
'dst_neg_index'
]])
.
part
==
torch
.
distributed
.
get_rank
(),
ones
*
metadata
[
'train_ratio_pos'
],
ones
*
metadata
[
'train_ratio_neg'
]
)
.
reshape
(
-
1
,
1
)
neg_creterion
=
torch
.
nn
.
BCEWithLogitsLoss
(
weight
)
neg_creterion
=
torch
.
nn
.
BCEWithLogitsLoss
(
weight
)
loss
+=
neg_creterion
(
pred_neg
,
torch
.
zeros_like
(
pred_neg
))
loss
+=
neg_creterion
(
pred_neg
,
torch
.
zeros_like
(
pred_neg
))
else
:
else
:
loss
+=
creterion
(
pred_neg
,
torch
.
zeros_like
(
pred_neg
))
loss
+=
creterion
(
pred_neg
,
torch
.
zeros_like
(
pred_neg
))
total_loss
+=
float
(
loss
.
item
())
total_loss
+=
float
(
loss
.
item
())
if
len
(
trainloader
.
result_queue
)
>
0
:
_
,
_
,
_
,
edge_feat
,
node_feat0
=
trainloader
.
result_queue
[
0
]
edge_feat
[
1
]
.
wait
()
node_feat0
[
1
]
.
wait
()
if
ada_param
is
not
None
:
ada_param
.
update_fetch_time
(
ada_param
.
last_start_event_fetch
)
ada_param
.
update_parameter
()
#mailbox.handle_last_async()
#mailbox.handle_last_async()
#trainloader.async_feature()
#trainloader.async_feature()
#torch.cuda.synchronize()
#torch.cuda.synchronize()
...
...
starrygl/module/utils.py
View file @
df36000d
...
@@ -2,6 +2,8 @@ import yaml
...
@@ -2,6 +2,8 @@ import yaml
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
import
math
import
math
from
starrygl.distributed.context
import
DistributedContext
def
parse_config
(
f
):
def
parse_config
(
f
):
conf
=
yaml
.
safe_load
(
open
(
f
,
'r'
))
conf
=
yaml
.
safe_load
(
open
(
f
,
'r'
))
sample_param
=
conf
[
'sampling'
][
0
]
sample_param
=
conf
[
'sampling'
][
0
]
...
@@ -135,7 +137,7 @@ class AdaParameter:
...
@@ -135,7 +137,7 @@ class AdaParameter:
self
.
count_gnn_aggregate
=
0
self
.
count_gnn_aggregate
=
0
def
update_parameter
(
self
):
def
update_parameter
(
self
):
print
(
'beta is {} alpha is {}
\n
'
.
format
(
self
.
beta
,
self
.
alpha
))
#
print('beta is {} alpha is {}\n'.format(self.beta,self.alpha))
#if self.count_fetch == 0 or self.count_memory_sync == 0 or self.count_memory_update == 0 or self.count_gnn_aggregate == 0:
#if self.count_fetch == 0 or self.count_memory_sync == 0 or self.count_memory_update == 0 or self.count_gnn_aggregate == 0:
# return
# return
if
self
.
end_event_fetch
is
None
or
self
.
end_event_memory_sync
is
None
or
self
.
end_event_memory_update
is
None
or
self
.
end_event_gnn_aggregate
is
None
:
if
self
.
end_event_fetch
is
None
or
self
.
end_event_memory_sync
is
None
or
self
.
end_event_memory_update
is
None
or
self
.
end_event_gnn_aggregate
is
None
:
...
@@ -160,8 +162,15 @@ class AdaParameter:
...
@@ -160,8 +162,15 @@ class AdaParameter:
self
.
alpha
=
(
2
-
math
.
pow
((
2
-
self
.
alpha
),
average_memory_update_time
/
average_memory_sync_time
*
(
1
+
self
.
wait_threshold
)))
self
.
alpha
=
(
2
-
math
.
pow
((
2
-
self
.
alpha
),
average_memory_update_time
/
average_memory_sync_time
*
(
1
+
self
.
wait_threshold
)))
self
.
beta
=
max
(
min
(
self
.
beta
,
self
.
max_beta
),
self
.
min_beta
)
self
.
beta
=
max
(
min
(
self
.
beta
,
self
.
max_beta
),
self
.
min_beta
)
self
.
alpha
=
max
(
min
(
self
.
alpha
,
self
.
max_alpha
),
self
.
min_alpha
)
self
.
alpha
=
max
(
min
(
self
.
alpha
,
self
.
max_alpha
),
self
.
min_alpha
)
print
(
'gnn aggregate {} fetch {} memory sync {} memory update {}'
.
format
(
average_gnn_aggregate
,
average_fetch
,
average_memory_sync_time
,
average_memory_update_time
))
ctx
=
DistributedContext
.
get_default_context
()
print
(
'beta is {} alpha is {}
\n
'
.
format
(
self
.
beta
,
self
.
alpha
))
beta_comm
=
torch
.
tensor
([
self
.
beta
])
torch
.
distributed
.
all_reduce
(
beta_comm
,
group
=
ctx
.
gloo_group
)
self
.
beta
=
beta_comm
[
0
]
.
item
()
alpha_comm
=
torch
.
tensor
([
self
.
alpha
])
torch
.
distributed
.
all_reduce
(
alpha_comm
,
group
=
ctx
.
gloo_group
)
self
.
alpha
=
alpha_comm
[
0
]
.
item
()
#print('gnn aggregate {} fetch {} memory sync {} memory update {}'.format(average_gnn_aggregate,average_fetch,average_memory_sync_time,average_memory_update_time))
#print('beta is {} alpha is {}\n'.format(self.beta,self.alpha))
#self.reset_time()
#self.reset_time()
#log(2-a1 ) = log(2-a2) * t1/t2 * (1 + wait_threshold)
#log(2-a1 ) = log(2-a2) * t1/t2 * (1 + wait_threshold)
#2-a1 = 2-a2 ^(t1/t2 * (1 + wait_threshold))
#2-a1 = 2-a2 ^(t1/t2 * (1 + wait_threshold))
...
...
starrygl/sample/sample_core/neighbor_sampler.py
View file @
df36000d
...
@@ -332,6 +332,9 @@ class NeighborSampler(BaseSampler):
...
@@ -332,6 +332,9 @@ class NeighborSampler(BaseSampler):
metadata
[
'seed_ts'
]
=
seed_ts
metadata
[
'seed_ts'
]
=
seed_ts
metadata
[
'src_pos_index'
]
=
src_pos_index
metadata
[
'src_pos_index'
]
=
src_pos_index
metadata
[
'dst_pos_index'
]
=
dst_pos_index
metadata
[
'dst_pos_index'
]
=
dst_pos_index
if
'train_ratio_pos'
in
neg_sampling
.
__dict__
:
metadata
[
'train_ratio_pos'
]
=
neg_sampling
.
train_ratio_pos
metadata
[
'train_ratio_neg'
]
=
neg_sampling
.
train_ratio_neg
if
neg_sampling
is
not
None
:
if
neg_sampling
is
not
None
:
metadata
[
'dst_neg_index'
]
=
dst_neg_index
metadata
[
'dst_neg_index'
]
=
dst_neg_index
if
neg_sampling
.
is_triplet
()
or
neg_sampling
.
is_tgbtriplet
():
if
neg_sampling
.
is_triplet
()
or
neg_sampling
.
is_tgbtriplet
():
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment