Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
B
BTS-MTGNN
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
zhlj
BTS-MTGNN
Commits
ff19c482
Commit
ff19c482
authored
Oct 04, 2024
by
zlj
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
test
parent
d1128852
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
19 additions
and
11 deletions
+19
-11
examples/train_boundery.py
+13
-7
starrygl/module/memorys.py
+2
-2
starrygl/sample/sample_core/LocalNegSampling.py
+4
-2
No files found.
examples/train_boundery.py
View file @
ff19c482
...
...
@@ -111,6 +111,7 @@ if not 'MASTER_ADDR' in os.environ:
os
.
environ
[
"MASTER_ADDR"
]
=
'192.168.2.107'
if
not
'MASTER_PORT'
in
os
.
environ
:
os
.
environ
[
"MASTER_PORT"
]
=
'9337'
os
.
environ
[
"NCCL_IB_DISABLE"
]
=
'1'
os
.
environ
[
'NCCL_SOCKET_IFNAME'
]
=
matching_interfaces
[
0
]
print
(
'rank {}'
.
format
(
int
(
os
.
environ
[
"LOCAL_RANK"
])))
...
...
@@ -262,7 +263,7 @@ def main():
train_neg_sampler
=
LocalNegativeSampling
(
'triplet'
,
amount
=
args
.
neg_samples
,
dst_node_list
=
graph
.
edge_index
[
1
,
mask
]
.
unique
())
else
:
#train_neg_sampler = LocalNegativeSampling('triplet',amount = args.neg_samples,dst_node_list = full_dst.unique())
train_neg_sampler
=
LocalNegativeSampling
(
'triplet'
,
amount
=
args
.
neg_samples
,
dst_node_list
=
full_dst
.
unique
(),
local_mask
=
(
DistIndex
(
graph
.
nids_mapper
[
full_dst
.
unique
()]
.
to
(
'cpu'
))
.
part
==
dist
.
get_rank
()))
train_neg_sampler
=
LocalNegativeSampling
(
'triplet'
,
amount
=
args
.
neg_samples
,
dst_node_list
=
full_dst
.
unique
(),
local_mask
=
(
DistIndex
(
graph
.
nids_mapper
[
full_dst
.
unique
()]
.
to
(
'cpu'
))
.
part
==
dist
.
get_rank
())
,
prob
=
args
.
probability
)
print
(
train_neg_sampler
.
dst_node_list
)
neg_sampler
=
LocalNegativeSampling
(
'triplet'
,
amount
=
neg_samples
,
dst_node_list
=
full_dst
.
unique
(),
seed
=
6773
)
...
...
@@ -279,7 +280,7 @@ def main():
is_pipeline
=
True
,
use_local_feature
=
False
,
device
=
torch
.
device
(
'cuda:{}'
.
format
(
local_rank
)),
probability
=
0.1
,
#train_cross_
probability,
probability
=
args
.
probability
,
reversed
=
(
gnn_param
[
'arch'
]
==
'identity'
)
)
...
...
@@ -457,12 +458,11 @@ def main():
optimizer
=
torch
.
optim
.
Adam
(
model
.
parameters
(),
lr
=
train_param
[
'lr'
],
weight_decay
=
1e-4
)
early_stopper
=
EarlyStopMonitor
(
max_round
=
args
.
patience
)
MODEL_SAVE_PATH
=
f
'../saved_models/{args.model}-{args.dataname}-{dist.get_world_size()}.pth'
total_test_time
=
0
epoch_cnt
=
0
test_ap_list
=
[]
val_list
=
[]
loss_list
=
[]
def
fetch_async
():
trainloader
.
async_feature
()
for
e
in
range
(
train_param
[
'epoch'
]):
model
.
module
.
memory_updater
.
empty_cache
()
tt
.
_zero
()
...
...
@@ -616,9 +616,15 @@ def main():
tt
.
weight_count_remote
=
0
tt
.
ssim_cnt
=
0
ap
,
auc
=
eval
(
'val'
)
torch
.
cuda
.
synchronize
()
t_test
=
time
.
time
()
test_ap
,
test_auc
=
eval
(
'test'
)
torch
.
cuda
.
synchronize
()
t_test
=
time
.
time
()
-
t_test
total_test_time
+=
t_test
test_ap_list
.
append
((
test_ap
,
test_auc
))
early_stop
=
early_stopper
.
early_stop_check
(
ap
)
early_stopper
.
early_stop_check
(
ap
)
early_stop
=
False
trainloader
.
local_node
=
0
trainloader
.
remote_node
=
0
trainloader
.
local_edge
=
0
...
...
@@ -647,7 +653,7 @@ def main():
break
else
:
print
(
'
\t
train loss:{:.4f} train ap:{:4f} val ap:{:4f} val auc:{:4f} test ap {:4f} test auc{:4f}
\n
'
.
format
(
total_loss
,
train_ap
,
ap
,
auc
,
test_ap
,
test_auc
))
print
(
'
\t
total time:{:.2f}s prep time:{:.2f}s
\n
'
.
format
(
time
.
time
()
-
epoch_start_time
,
time_prep
)
)
print
(
'
\t
total time:{:.2f}s prep time:{:.2f}s
\n
test time {:.2f}'
.
format
(
time
.
time
()
-
epoch_start_time
,
time_prep
),
t_test
)
torch
.
save
(
model
.
module
.
state_dict
(),
get_checkpoint_path
(
e
))
if
args
.
model
==
'TGN'
:
print
(
'weight {} {}
\n
'
.
format
(
tt
.
weight_count_local
,
tt
.
weight_count_remote
))
...
...
@@ -665,7 +671,7 @@ def main():
print
(
'best test AP:{:4f} test auc{:4f}'
.
format
(
*
test_ap_list
[
early_stopper
.
best_epoch
]))
val_list
=
torch
.
tensor
(
val_list
)
loss_list
=
torch
.
tensor
(
loss_list
)
print
(
'test_dataset {} avg_time {}
\n
'
.
format
(
test_data
.
edges
.
shape
[
1
],
avg
_time
/
epoch_cnt
))
print
(
'test_dataset {} avg_time {}
test time {}
\n
'
.
format
(
test_data
.
edges
.
shape
[
1
],
avg_time
/
epoch_cnt
,
total_test
_time
/
epoch_cnt
))
torch
.
save
(
model
.
module
.
state_dict
(),
MODEL_SAVE_PATH
)
ctx
.
shutdown
()
...
...
starrygl/module/memorys.py
View file @
ff19c482
...
...
@@ -464,8 +464,8 @@ class AsyncMemeoryUpdater(torch.nn.Module):
change
=
2
*
change
-
1
self
.
filter
.
update
(
shared_ind
,
change
)
#print(transition_dense)
print
(
torch
.
cosine_similarity
(
updated_memory
[
mask
],
b
.
srcdata
[
'his_mem'
][
mask
])
.
sum
()
/
torch
.
sum
(
mask
))
print
(
self
.
gamma
)
#
print(torch.cosine_similarity(updated_memory[mask],b.srcdata['his_mem'][mask]).sum()/torch.sum(mask))
#
print(self.gamma)
self
.
pre_mem
=
b
.
srcdata
[
'his_mem'
]
self
.
last_updated_ts
=
b
.
srcdata
[
'ts'
]
.
detach
()
.
clone
()
self
.
last_updated_memory
=
updated_memory
.
detach
()
.
clone
()
...
...
starrygl/sample/sample_core/LocalNegSampling.py
View file @
ff19c482
...
...
@@ -21,7 +21,8 @@ class LocalNegativeSampling(NegativeSampling):
src_node_list
:
torch
.
Tensor
=
None
,
dst_node_list
:
torch
.
Tensor
=
None
,
local_mask
=
None
,
seed
=
None
seed
=
None
,
prob
=
None
):
super
(
LocalNegativeSampling
,
self
)
.
__init__
(
mode
,
amount
,
unique
=
unique
)
self
.
src_node_list
=
src_node_list
.
to
(
'cpu'
)
if
src_node_list
is
not
None
else
None
...
...
@@ -37,6 +38,7 @@ class LocalNegativeSampling(NegativeSampling):
self
.
local_mask
=
local_mask
if
self
.
local_mask
is
not
None
:
self
.
local_dst
=
dst_node_list
[
local_mask
]
self
.
prob
=
prob
#self.rdm.manual_seed(42)
#print('dst_nde_list {}\n'.format(dst_node_list))
def
is_binary
(
self
)
->
bool
:
...
...
@@ -61,7 +63,7 @@ class LocalNegativeSampling(NegativeSampling):
p
=
torch
.
rand
(
size
=
(
num_samples
,))
sr
=
self
.
dst_node_list
[
torch
.
randint
(
len
(
self
.
dst_node_list
),
(
num_samples
,
),
generator
=
self
.
rdm
)]
sl
=
self
.
local_dst
[
torch
.
randint
(
len
(
self
.
local_dst
),
(
num_samples
,
),
generator
=
self
.
rdm
)]
s
=
torch
.
where
(
p
<=
1
,
sl
,
sr
)
s
=
torch
.
where
(
p
<=
self
.
prob
,
sl
,
sr
)
return
sr
else
:
s
=
torch
.
randint
(
len
(
self
.
dst_node_list
),
(
num_samples
,
),
generator
=
self
.
rdm
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment