Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
B
BTS-MTGNN
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
zhlj
BTS-MTGNN
Commits
506058f4
Commit
506058f4
authored
Mar 26, 2025
by
zhlj
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
fist
parent
911d3eb8
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
8 additions
and
5 deletions
+8
-5
examples/train_boundery.py
+8
-5
No files found.
examples/train_boundery.py
View file @
506058f4
...
@@ -67,7 +67,7 @@ parser.add_argument('--probability', default=1, type=float, metavar='W',
...
@@ -67,7 +67,7 @@ parser.add_argument('--probability', default=1, type=float, metavar='W',
help
=
'name of model'
)
help
=
'name of model'
)
parser
.
add_argument
(
'--sample_type'
,
default
=
'recent'
,
type
=
str
,
metavar
=
'W'
,
parser
.
add_argument
(
'--sample_type'
,
default
=
'recent'
,
type
=
str
,
metavar
=
'W'
,
help
=
'name of model'
)
help
=
'name of model'
)
parser
.
add_argument
(
'--local_neg_sample'
,
default
=
False
,
type
=
bool
,
metavar
=
'W'
,
parser
.
add_argument
(
'--local_neg_sample'
,
default
=
'local'
,
type
=
str
,
metavar
=
'W'
,
help
=
'name of model'
)
help
=
'name of model'
)
parser
.
add_argument
(
'--shared_memory_ssim'
,
default
=
2
,
type
=
float
,
metavar
=
'W'
,
parser
.
add_argument
(
'--shared_memory_ssim'
,
default
=
2
,
type
=
float
,
metavar
=
'W'
,
help
=
'name of model'
)
help
=
'name of model'
)
...
@@ -275,10 +275,13 @@ def main():
...
@@ -275,10 +275,13 @@ def main():
neg_samples
=
args
.
eval_neg_samples
neg_samples
=
args
.
eval_neg_samples
mask
=
DistIndex
(
graph
.
nids_mapper
[
graph
.
edge_index
[
1
,:]]
.
to
(
'cpu'
))
.
part
==
dist
.
get_rank
()
mask
=
DistIndex
(
graph
.
nids_mapper
[
graph
.
edge_index
[
1
,:]]
.
to
(
'cpu'
))
.
part
==
dist
.
get_rank
()
if
args
.
local_neg_sample
:
if
args
.
local_neg_sample
==
'local'
:
print
(
'dst len {} origin len {}'
.
format
(
graph
.
edge_index
[
1
,
mask
]
.
unique
()
.
shape
[
0
],
full_dst
.
unique
()
.
shape
[
0
]))
print
(
'dst len {} origin len {}'
.
format
(
graph
.
edge_index
[
1
,
mask
]
.
unique
()
.
shape
[
0
],
full_dst
.
unique
()
.
shape
[
0
]))
train_neg_sampler
=
LocalNegativeSampling
(
'triplet'
,
amount
=
args
.
neg_samples
,
dst_node_list
=
graph
.
edge_index
[
1
,
mask
]
.
unique
())
train_neg_sampler
=
LocalNegativeSampling
(
'triplet'
,
amount
=
args
.
neg_samples
,
dst_node_list
=
graph
.
edge_index
[
1
,
mask
]
.
unique
())
elif
args
.
local_neg_sample
==
'all'
:
train_neg_sampler
=
LocalNegativeSampling
(
'triplet'
,
amount
=
args
.
neg_samples
,
dst_node_list
=
full_dst
.
unique
(),
local_mask
=
(
DistIndex
(
graph
.
nids_mapper
[
full_dst
.
unique
()]
.
to
(
'cpu'
))
.
part
==
dist
.
get_rank
()),
prob
=
1
)
train_ratio_pos
=
1
train_ratio_neg
=
1
else
:
else
:
#train_neg_sampler = LocalNegativeSampling('triplet',amount = args.neg_samples,dst_node_list = full_dst.unique())
#train_neg_sampler = LocalNegativeSampling('triplet',amount = args.neg_samples,dst_node_list = full_dst.unique())
train_neg_sampler
=
LocalNegativeSampling
(
'triplet'
,
amount
=
args
.
neg_samples
,
dst_node_list
=
full_dst
.
unique
(),
local_mask
=
(
DistIndex
(
graph
.
nids_mapper
[
full_dst
.
unique
()]
.
to
(
'cpu'
))
.
part
==
dist
.
get_rank
()),
prob
=
args
.
probability
)
train_neg_sampler
=
LocalNegativeSampling
(
'triplet'
,
amount
=
args
.
neg_samples
,
dst_node_list
=
full_dst
.
unique
(),
local_mask
=
(
DistIndex
(
graph
.
nids_mapper
[
full_dst
.
unique
()]
.
to
(
'cpu'
))
.
part
==
dist
.
get_rank
()),
prob
=
args
.
probability
)
...
@@ -300,7 +303,7 @@ def main():
...
@@ -300,7 +303,7 @@ def main():
mode
=
'train'
,
mode
=
'train'
,
queue_size
=
200
,
queue_size
=
200
,
mailbox
=
mailbox
,
mailbox
=
mailbox
,
is_pipeline
=
Tru
e
,
is_pipeline
=
Fals
e
,
use_local_feature
=
False
,
use_local_feature
=
False
,
device
=
torch
.
device
(
'cuda:{}'
.
format
(
local_rank
)),
device
=
torch
.
device
(
'cuda:{}'
.
format
(
local_rank
)),
probability
=
args
.
probability
,
probability
=
args
.
probability
,
...
@@ -561,7 +564,7 @@ def main():
...
@@ -561,7 +564,7 @@ def main():
#print(time_count.elapsed_event(t2))
#print(time_count.elapsed_event(t2))
loss
=
creterion
(
pred_pos
,
torch
.
ones_like
(
pred_pos
))
loss
=
creterion
(
pred_pos
,
torch
.
ones_like
(
pred_pos
))
if
args
.
local_neg_sample
is
False
:
if
args
.
local_neg_sample
!=
'local'
:
weight
=
torch
.
where
(
DistIndex
(
mfgs
[
0
][
0
]
.
srcdata
[
'ID'
][
metadata
[
'dst_neg_index'
]])
.
part
==
torch
.
distributed
.
get_rank
(),
ones
*
train_ratio_pos
,
ones
*
train_ratio_neg
)
.
reshape
(
-
1
,
1
)
weight
=
torch
.
where
(
DistIndex
(
mfgs
[
0
][
0
]
.
srcdata
[
'ID'
][
metadata
[
'dst_neg_index'
]])
.
part
==
torch
.
distributed
.
get_rank
(),
ones
*
train_ratio_pos
,
ones
*
train_ratio_neg
)
.
reshape
(
-
1
,
1
)
neg_creterion
=
torch
.
nn
.
BCEWithLogitsLoss
(
weight
)
neg_creterion
=
torch
.
nn
.
BCEWithLogitsLoss
(
weight
)
loss
+=
neg_creterion
(
pred_neg
,
torch
.
zeros_like
(
pred_neg
))
loss
+=
neg_creterion
(
pred_neg
,
torch
.
zeros_like
(
pred_neg
))
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment