zhlj / BTS-MTGNN / Commits / ef367556

Commit ef367556 authored Nov 06, 2024 by zhlj
fix bugs in jodie and APAN

Parents: 4927a9e0, ce4b726f
Showing 11 changed files with 85 additions and 53 deletions (+85, -53)

Changed files:

    bts-mtgnn.md                                        +3   -0
    config/APAN.yml                                     +2   -1
    examples/a.out                                      +0   -0
    examples/test_JODIE.sh                              +7   -0
    examples/test_all.sh                                +8   -29
    examples/train_boundery.py                          +3   -2
    starrygl/module/memorys.py                          +8   -3
    starrygl/sample/batch_data.py                       +5   -0
    starrygl/sample/memory/change.py                    +8   -6
    starrygl/sample/memory/shared_mailbox.py            +39  -11
    starrygl/sample/sample_core/neighbor_sampler.py     +2   -1

bts-mtgnn.md (new file, mode 100644)

    # Introduction
    \ No newline at end of file

config/APAN.yml

    @@ -22,7 +22,7 @@ gnn:
    train:
      - epoch: 100
        batch_size: 1000
        lr: 0.0001
        lr: 0.0002
        dropout: 0.1
        att_dropout: 0.1
        # all_on_gpu: True
    \ No newline at end of file
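
The hunk above appears to bump the APAN training learning rate from 0.0001 to 0.0002, leaving epoch, batch_size and the dropout settings untouched. Since `train:` is a one-element list (note the leading dash), whatever loads this file presumably indexes into it; below is a minimal sketch of reading the block, assuming PyYAML and the field names shown in the hunk.

```python
# Minimal sketch of reading the train block shown above.
# Assumes PyYAML is available and that `train:` stays a one-element list.
import yaml

with open("config/APAN.yml") as f:
    cfg = yaml.safe_load(f)

train_param = cfg["train"][0]           # the leading "-" makes this a list
lr = train_param.get("lr", 0.0001)      # 0.0002 after this commit
batch_size = train_param.get("batch_size", 1000)
print(lr, batch_size, train_param.get("dropout", 0.1))
```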

examples/a.out

    This diff is collapsed.

examples/test_JODIE.sh (new file, mode 100644)

    bash test_all.sh 12347
    wait
    bash test_all.sh 12357
    wait
    bash test_all.sh 63457
    wait
    \ No newline at end of file

examples/test_all.sh

    @@ -2,41 +2,24 @@
    # ran TaoBao on 4 GPUs
    # define array variables
    seed=$1
    addr="192.168.1.106"
    addr="192.168.1.105"
    partition_params=("ours")
    #"metis" "ldg" "random")
    #("ours" "metis" "ldg" "random")
    partitions="8"
    partitions="4"
    node_per="4"
    nnodes="2"
    <<<<<<< Updated upstream
    nnodes="1"
    node_rank="0"
    probability_params=("0.1")
    sample_type_params=("boundery_recent_decay")
    #sample_type_params=("recent" "boundery_recent_decay") #"boundery_recent_uniform")
    #memory_type=("all_update" "p2p" "all_reduce" "historical" "local")
    memory_type=("historical")
    memory_type=("historical")
    #"historical")
    #memory_type=("local" "all_update" "historical" "all_reduce")
    shared_memory_ssim=("0.3")
    #data_param=("WIKI" "REDDIT" "LASTFM" "WikiTalk")
    <<<<<<< HEAD
    data_param=("GDELT")
    =======
    data_param=("LASTFM")
    >>>>>>> 8233776274204f6cf2f8a2eb37022d426d6197d8
    =======
    node_rank="1"
    probability_params=("0.1" "0.05" "0.01" "0")
    sample_type_params=("recent" "boundery_recent_decay")
    #sample_type_params=("recent" "boundery_recent_decay") #"boundery_recent_uniform")
    #memory_type=("all_update" "p2p" "all_reduce" "historical" "local")
    #memory_type=("all_update" "historical" "local")
    memory_type=("historical" "all_update" "local")
    #memory_type=("local" "all_update" "historical" "all_reduce")
    shared_memory_ssim=("0.3" "0.7")
    #data_param=("WIKI" "REDDIT" "LASTFM" "WikiTalk")
    data_param=("WIKI" "LASTFM" "WikiTalk" "StackOverflow" "GDELT" "TaoBao")
    >>>>>>> Stashed changes
    data_param=("WIKI" "LASTFM" "WikiTalk" "StackOverflow" "GDELT")
    #data_param=("WIKI" "REDDIT" "LASTFM" "DGraphFin" "WikiTalk" "StackOverflow")
    #data_param=("WIKI" "REDDIT" "LASTFM" "WikiTalk" "StackOverflow")
    #data_param=("REDDIT" "WikiTalk")

    @@ -48,9 +31,9 @@ data_param=("WIKI" "LASTFM" "WikiTalk" "StackOverflow" "GDELT" "TaoBao")
    #seed=(( RANDOM % 1000000 + 1 ))
    mkdir -p all_"$seed"
    for data in "${data_param[@]}"; do
        model="APAN"
        model="JODIE"
        if [ "$data" = "WIKI" ] || [ "$data" = "REDDIT" ] || [ "$data" = "LASTFM" ]; then
            model="APAN"
            model="JODIE"
        fi
        #model="APAN"
        mkdir all_"$seed"/"$data"

    @@ -89,11 +72,7 @@ for data in "${data_param[@]}"; do
        if [ "$mem" = "historical" ]; then
            for ssim in "${shared_memory_ssim[@]}"; do
                if [ "$partition" = "ours" ]; then
    <<<<<<< Updated upstream
                    torchrun --nnodes "$nnodes" --node_rank "$node_rank" --nproc-per-node "$node_per" --master-addr "$addr" --master-port 9445 train_boundery.py --dataname "$data" --mode "$model" --partition "$partition" --topk 0.1 --sample_type "$sample" --probability "$pro" --memory_type "$mem" --shared_memory_ssim "$ssim" --seed "$seed" > all_"$seed"/"$data"/"$model"/"$partitions"-ours_shared-0.01-"$mem"-"$ssim"-"$sample"-"$pro".out &
    =======
                    torchrun --nnodes "$nnodes" --node_rank "$node_rank" --nproc-per-node "$node_per" --master-addr "$addr" --master-port 9445 train_boundery.py --dataname "$data" --mode "$model" --partition "$partition" --topk 0.1 --sample_type "$sample" --probability "$pro" --memory_type "$mem" --shared_memory_ssim "$ssim" > all/"$data"/"$model"/"$partitions"-ours_shared-0.01-"$mem"-"$ssim"-"$sample"-"$pro".out &
    >>>>>>> Stashed changes
                    wait
                fi
            done

examples/train_boundery.py

    @@ -229,7 +229,8 @@ def main():
        fanout = sample_param['neighbor'] if 'neighbor' in sample_param else [10]
        policy = sample_param['strategy'] if 'strategy' in sample_param else 'recent'
        no_neg = sample_param['no_neg'] if 'no_neg' in sample_param else False
        if policy != 'recent':
            print(policy)
        if policy == 'recent':
            policy_train = args.sample_type #'boundery_recent_decay'
        else:
            policy_train = policy

    @@ -480,7 +481,7 @@ def main():
        val_list = []
        loss_list = []
        for e in range(train_param['epoch']):
            model.module.memory_updater.empty_cache()
            # model.module.memory_updater.empty_cache()
            tt._zero()
            torch.cuda.synchronize()
            epoch_start_time = time.time()
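
The first hunk reads the sampler settings out of `sample_param` with defaults, then picks the training-time sampling policy: when the configured strategy is 'recent' the `--sample_type` command-line value takes over, otherwise the configured strategy is kept, and the commit adds a debug print for the non-'recent' case. Below is a standalone sketch of that selection logic; the `sample_param` contents and the `args` namespace are illustrative stand-ins for the script's real config and argparse values.

```python
# Sketch of the policy-selection logic in the hunk above.
# `sample_param` and `args.sample_type` mirror the names used in
# train_boundery.py; the concrete values here are illustrative only.
from argparse import Namespace

sample_param = {"neighbor": [10], "strategy": "recent", "no_neg": True}
args = Namespace(sample_type="boundery_recent_decay")

fanout = sample_param.get("neighbor", [10])
policy = sample_param.get("strategy", "recent")
no_neg = sample_param.get("no_neg", False)

if policy != "recent":
    print(policy)  # debug print added by this commit

# When the config says 'recent', the CLI flag decides the actual
# training policy (e.g. 'boundery_recent_decay'); otherwise the
# configured strategy wins.
policy_train = args.sample_type if policy == "recent" else policy
print(fanout, policy_train, no_neg)
```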

starrygl/module/memorys.py

    @@ -336,9 +336,12 @@ class TransformerMemoryUpdater(torch.nn.Module):
        def forward(self, b, param = None):
            Q = self.w_q(b.srcdata['mem']).reshape((b.num_src_nodes(), self.att_h, -1))
            mails = b.srcdata['mem_input'].reshape((b.num_src_nodes(), self.memory_param['mailbox_size'], -1))
            #print(mails.shape,b.srcdata['mem_input'].shape,b.srcdata['mail_ts'].shape)
            if self.dim_time > 0:
                time_feat = self.time_enc(b.srcdata['ts'][:,None] - b.srcdata['mail_ts']).reshape((b.num_src_nodes(), self.memory_param['mailbox_size'], -1))
                #print(time_feat.shape)
                mails = torch.cat([mails, time_feat], dim=2)
            #print(mails.shape)
            K = self.w_k(mails).reshape((b.num_src_nodes(), self.memory_param['mailbox_size'], self.att_h, -1))
            V = self.w_v(mails).reshape((b.num_src_nodes(), self.memory_param['mailbox_size'], self.att_h, -1))
            att = self.att_act((Q[:,None,:,:]*K).sum(dim=3))

    @@ -394,7 +397,6 @@ class AsyncMemeoryUpdater(torch.nn.Module):
                self.mailbox.handle_last_async()
                submit_to_queue = False
                if nxt_fetch_func is not None:
                    nxt_fetch_func()
                    submit_to_queue = True
                self.mailbox.set_memory_all_reduce(index, memory, memory_ts,

    @@ -404,6 +406,8 @@ class AsyncMemeoryUpdater(torch.nn.Module):
                    wait_submit=submit_to_queue, spread_mail=spread_mail,
                    update_cross_mm=False,
                )
                if nxt_fetch_func is not None:
                    nxt_fetch_func()
        def local_func(self, index, memory, memory_ts, mail_index, mail, mail_ts, nxt_fetch_func, spread_mail=False):
            if nxt_fetch_func is not None:

    @@ -471,6 +475,7 @@ class AsyncMemeoryUpdater(torch.nn.Module):
                shared_ind = self.mailbox.is_shared_mask[DistIndex(b.srcdata['ID'][mask]).loc]
                transition_dense = b.srcdata['his_mem'][mask] + self.filter.get_incretment(shared_ind)
                #print(transition_dense.shape)
                if not (transition_dense.max().item() == 0):
                    transition_dense -= transition_dense.min()
                    transition_dense /= transition_dense.max()

    @@ -514,8 +519,8 @@ class AsyncMemeoryUpdater(torch.nn.Module):
                local_mask = (DistIndex(index).part == torch.distributed.get_rank())
                local_mask_mail = (DistIndex(index0).part == torch.distributed.get_rank())
                # self.mailbox.set_mailbox_local(DistIndex(index0[local_mask_mail]).loc,mail[local_mask_mail],mail_ts[local_mask_mail],Reduce_Op = 'max')
                # self.mailbox.set_memory_local(DistIndex(index[local_mask]).loc,memory[local_mask],memory_ts[local_mask], Reduce_Op = 'max')
                self.mailbox.set_mailbox_local(DistIndex(index0[local_mask_mail]).loc, mail[local_mask_mail], mail_ts[local_mask_mail], Reduce_Op='max')
                self.mailbox.set_memory_local(DistIndex(index[local_mask]).loc, memory[local_mask], memory_ts[local_mask], Reduce_Op='max')
                is_deliver = (self.mailbox.deliver_to == 'neighbors')
                self.update_hunk(index, memory, memory_ts, index0, mail, mail_ts, nxt_fetch_func, spread_mail=is_deliver)
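
The `TransformerMemoryUpdater.forward` hunk reshapes the node memory into per-head queries and the mailbox into per-head keys and values, then scores each of the `mailbox_size` mails against the query. Below is a self-contained shape sketch of that attention step with made-up sizes; the time-encoding branch is omitted and a plain softmax stands in for the module's `att_act`.

```python
# Shape sketch of the attention over the mailbox in
# TransformerMemoryUpdater.forward (illustrative dimensions only).
import torch

N, dim_mem, dim_mail, att_h, mailbox_size = 5, 64, 64, 4, 2

w_q = torch.nn.Linear(dim_mem, dim_mem)
w_k = torch.nn.Linear(dim_mail, dim_mem)
w_v = torch.nn.Linear(dim_mail, dim_mem)

mem = torch.randn(N, dim_mem)                       # b.srcdata['mem']
mail = torch.randn(N * mailbox_size, dim_mail)      # b.srcdata['mem_input']

Q = w_q(mem).reshape(N, att_h, -1)                  # (N, heads, d_head)
mails = mail.reshape(N, mailbox_size, -1)           # (N, mailbox, d_mail)
K = w_k(mails).reshape(N, mailbox_size, att_h, -1)  # (N, mailbox, heads, d_head)
V = w_v(mails).reshape(N, mailbox_size, att_h, -1)

# One attention score per (node, mail, head), as in the diff; softmax over
# the mailbox dimension here (the module's att_act may differ).
att = torch.softmax((Q[:, None, :, :] * K).sum(dim=3), dim=1)
out = (att[..., None] * V).sum(dim=1)               # (N, heads, d_head)
print(att.shape, out.shape)
```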

starrygl/sample/batch_data.py

    @@ -344,6 +344,7 @@ def to_reversed_block(graph,data, sample_out,device = torch.device('cuda'),uniqu
        else:
            metadata = None
        nid_mapper: torch.Tensor = graph.nids_mapper
        #print('reverse block {}\n'.format(identity))
        if identity is False:
            assert len(sample_out) == 1
            ret = sample_out[0]

    @@ -354,6 +355,8 @@ def to_reversed_block(graph,data, sample_out,device = torch.device('cuda'),uniqu
            dist_eid = torch.tensor([], dtype=torch.long, device=device)
            src_index = ret.src_index().to(device)
        else:
            #print('is jodie')
            #print(sample_out)
            src_index = torch.tensor([], dtype=torch.long, device=device)
            dst = torch.tensor([], dtype=torch.long, device=device)
            dist_eid = torch.tensor([], dtype=torch.long, device=device)

    @@ -401,6 +404,7 @@ def to_reversed_block(graph,data, sample_out,device = torch.device('cuda'),uniqu
            row_len = root_len
        col = first_block_id[:row_len]
        max_row = col.max().item() + 1
        #print(src_index,dst)
        b = dgl.create_block((col[src_index].to(device), torch.arange(dst.shape[0], device=device, dtype=torch.long)), num_src_nodes=first_block_id.max().item() + 1, num_dst_nodes=dst.shape[0])

    @@ -424,6 +428,7 @@ def graph_sample(graph,sampler,sample_fn,data,neg_sampling = None,out_device = t
        t_s = time.time()
        param = {'is_unique': False, 'nid_mapper': nid_mapper, 'eid_mapper': eid_mapper, 'out_device': out_device}
        out = sample_fn(sampler, data, neg_sampling, **param)
        #print(sampler.policy)
        if reversed is False:
            out, dist_nid, dist_eid = to_block(graph, data, out, out_device)
        else:
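
The additions to `to_reversed_block` give the identity (JODIE/APAN-style) path its own branch, in which the sampled edge tensors are empty and the DGL block is built from explicit node counts. Below is a minimal sketch of constructing such an edgeless block, assuming the `dgl` package already used by this file; the node counts are placeholders.

```python
# Sketch: building a block with no edges, as in the identity branch of
# to_reversed_block (placeholder node counts, CPU tensors for brevity).
import torch
import dgl

device = torch.device("cpu")

# In the identity (JODIE) branch the sampled edge tensors are empty:
src_index = torch.tensor([], dtype=torch.long, device=device)
dst_index = torch.tensor([], dtype=torch.long, device=device)

# Placeholder counts standing in for first_block_id.max()+1 and dst.shape[0].
b = dgl.create_block((src_index, dst_index),
                     num_src_nodes=3, num_dst_nodes=2)
print(b.num_src_nodes(), b.num_dst_nodes(), b.num_edges())  # 3 2 0
```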

starrygl/sample/memory/change.py

    @@ -18,9 +18,10 @@ class MemoryMoniter:
            #self.memory_ssim.append(self.ssim(pre_memory,now_memory,method = 'cos'))
            #self.nid_list.append(nid)
        def draw(self, degree, data, model, e):
            torch.save(self.nid_list, 'all_args.seed/{}/{}/memorynid_{}.pt'.format(data, model, e))
            torch.save(self.memorychange, 'all_args.seed/{}/{}/memoryF_{}.pt'.format(data, model, e))
            torch.save(self.memory_ssim, 'all_args.seed/{}/{}/memcos_{}.pt'.format(data, model, e))
            pass
            #torch.save(self.nid_list,'all_args.seed/{}/{}/memorynid_{}.pt'.format(data,model,e))
            #torch.save(self.memorychange,'all_args.seed/{}/{}/memoryF_{}.pt'.format(data,model,e))
            #torch.save(self.memory_ssim,'all_args.seed/{}/{}/memcos_{}.pt'.format(data,model,e))
            # path = './memory/{}/'.format(data)
            # if not os.path.exists(path):

    @@ -87,6 +88,7 @@ class MemoryMoniter:
        def set_zero(self):
            self.memorychange = []
            self.nid_list = []
            self.memory_ssim = []
            pass
            #self.memorychange = []
            #self.nid_list =[]
            #self.memory_ssim = []
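
`MemoryMoniter.draw` persists the per-epoch monitoring lists with `torch.save`; note that the template `'all_args.seed/{}/{}/...'` keeps `all_args.seed` as a literal directory name rather than substituting a seed value. Below is a small sketch of the same per-epoch persistence with an explicit output directory; `save_monitor` and `output_dir` are hypothetical names, not part of the repository.

```python
# Sketch of per-epoch persistence of monitor lists, in the spirit of
# MemoryMoniter.draw. `output_dir` is a hypothetical argument; the real
# method hard-codes an 'all_args.seed/...' path template.
import os
import torch

def save_monitor(nid_list, memory_change, memory_ssim,
                 output_dir, data, model, epoch):
    path = os.path.join(output_dir, data, model)
    os.makedirs(path, exist_ok=True)
    torch.save(nid_list, os.path.join(path, f"memorynid_{epoch}.pt"))
    torch.save(memory_change, os.path.join(path, f"memoryF_{epoch}.pt"))
    torch.save(memory_ssim, os.path.join(path, f"memcos_{epoch}.pt"))

save_monitor([torch.arange(3)], [torch.zeros(3)], [torch.ones(3)],
             "all_demo", "WIKI", "JODIE", epoch=0)
```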

starrygl/sample/memory/shared_mailbox.py

    @@ -146,7 +146,7 @@ class SharedMailBox():
                source_ts = max_ts
                source = source[id]
                index = unq_id
            #print(self.next_mail_pos[index])
            self.mailbox_ts.accessor.data[index, self.next_mail_pos[index]] = source_ts
            self.mailbox.accessor.data[index, self.next_mail_pos[index]] = source
            if self.memory_param['mailbox_size'] > 1:

    @@ -180,9 +180,23 @@
            if self.deliver_to == 'neighbors':
                assert block is not None and Reduce_score is None
                mail = torch.cat([mail, mail[block.edges()[0].long()]], dim=0)
                mail_ts = torch.cat([mail_ts, mail_ts[block.edges()[0].long()]], dim=0)
                # print(block.edges().shape)
                root = torch.cat([src, dst]).reshape(-1)
                #pos = torch.empty(root.max()+1,dtype=torch.long,device=block.device)
                #print('edge {} {}\n'.format(block.num_src_nodes(),block.edges()[0].max()))
                #print('root is {} {} {} {}\n'.format(root,root.shape,root.max(),block.edges()[0].shape))
                #pos_index = torch.arange(root.shape[0],device=root.device,dtype=root.dtype)
                pos, idx = torch_scatter.scatter_max(mail_ts, root, 0)
                mail = torch.cat([mail, mail[idx]], dim=0)
                mail_ts = torch.cat([mail_ts, mail_ts[idx]], dim=0)
                #print('pos is {} {}\n'.format(pos,block.edges()[0].long()))
                #mail = torch.cat([mail, mail[pos[block.edges()[0].long()]]],dim=0)
                #mail_ts = torch.cat([mail_ts, mail_ts[pos[block.edges()[0].long()]]], dim=0)
                #print(root,block.edges()[1].long())
                index = torch.cat([index, block.dstdata['ID'][block.edges()[1].long()]], dim=0)
                #mail = torch.cat([mail, mail[block.edges()[0].long()]], dim=0)
                #mail_ts = torch.cat([mail_ts, mail_ts[block.edges()[0].long()]], dim=0)
                #index = torch.cat([index,block.dstdata['ID'][block.edges()[1].long()]],dim=0)
            if Reduce_score is not None:
                Reduce_score = torch.cat((Reduce_score, Reduce_score), -1).to(self.device)
            if Reduce_score is None:

    @@ -192,12 +206,19 @@ class SharedMailBox():
                mail = mail[idx]
                index = unq_index
            else:
                unq_index, inv = torch.unique(index, return_inverse=True)
                uni, inv = torch.unique(index, return_inverse=True)
                perm = torch.arange(inv.size(0), dtype=inv.dtype, device=inv.device)
                perm = inv.new_empty(uni.size(0)).scatter_(0, inv, perm)
                index = index[perm]
                mail = mail[perm]
                mail_ts = mail_ts[perm]
                #unq_index,inv = torch.unique(index,return_inverse = True)
                #print(inv.shape,Reduce_score.shape)
                max_score, idx = torch_scatter.scatter_max(Reduce_score, inv, 0)
                mail_ts = mail_ts[idx]
                mail = mail[idx]
                index = unq_index
                #max_score,idx = torch_scatter.scatter_max(Reduce_score,inv,0)
                #mail_ts = mail_ts[idx]
                #mail = mail[idx]
                #index = unq_index
            #print('mail {} {}\n'.format(index.shape,mail.shape,mail_ts.shape))
            return index, mail, mail_ts
        def get_update_memory(self, index, memory, memory_ts, embedding):

    @@ -205,7 +226,8 @@ class SharedMailBox():
            max_ts, idx = torch_scatter.scatter_max(memory_ts, inv, 0)
            ts = max_ts
            index = unq_index
            memory = memory[idx]
            memory = memory[idx]
            #print('memory {} {}\n'.format(index.shape,memory.shape,ts.shape))
            return index, memory, ts
        def pack(self, memory=None, memory_ts=None, mail=None, mail_ts=None, index=None, mode=None):

    @@ -250,6 +272,7 @@ class SharedMailBox():
                self.set_mailbox_local(DistIndex(gather_id_list).loc, gather_mail, gather_mail_ts, Reduce_Op=reduce_Op)
            else:
                gather_memory, gather_memory_ts = self.unpack(gather_memory)
                #print(gather_id_list.shape,gather_memory.shape,gather_memory_ts.shape)
                self.set_memory_local(DistIndex(gather_id_list).loc, gather_memory, gather_memory_ts, Reduce_Op=reduce_Op)
        def handle_last_mail(self, reduce_Op=None,):
            if self.last_mail_sync is not None:

    @@ -260,6 +283,7 @@ class SharedMailBox():
                if isinstance(gather_memory, list):
                    gather_memory = torch.cat(gather_memory, dim=0)
                gather_memory, gather_memory_ts = self.unpack(gather_memory)
                #print(gather_id_list.shape,gather_memory.shape,gather_memory_ts.shape)
                self.set_mailbox_local(DistIndex(gather_id_list).loc, gather_memory, gather_memory_ts, Reduce_Op=reduce_Op)
        def handle_last_async(self, reduce_Op=None):
            self.handle_last_memory(reduce_Op)

    @@ -303,6 +327,7 @@ class SharedMailBox():
                return
            index, gather_id_list, mem, gather_memory, input_split, output_split, group, async_op = self.next_wait_mail_job
            self.next_wait_mail_job = None
            #print(index,gather_id_list)
            handle0 = torch.distributed.all_to_all_single(gather_id_list, index, output_split_sizes=output_split, input_split_sizes=input_split, group=group, async_op=async_op)

    @@ -409,6 +434,7 @@ class SharedMailBox():
                self.update_p2p_mail()
                self.update_p2p_mem()
                self.handle_last_async()
            ctx = DistributedContext.get_default_context()

    @@ -483,7 +509,7 @@ class SharedMailBox():
            unq_index, inv = torch.unique(shared_index, return_inverse=True)
            max_ts, idx = torch_scatter.scatter_max(shared_memory_ts, inv, 0)
            shared_memory = shared_memory[idx]
            shared_memory = shared_memory[idx]
            # shared_memory = shared_memory[idx]
            shared_memory_ts = shared_memory_ts[idx]
            shared_mail_ts = shared_mail_ts[idx]
            shared_mail = shared_mail[idx]

    @@ -495,11 +521,13 @@ class SharedMailBox():
                    self.set_mailbox_local(self.shared_nodes_index[shared_mail_indx], shared_mail, shared_mail_ts)
            else:
                self.next_wait_gather_memory_job = (shared_list, mem, shared_id_list, shared_memory_ind)
            if not wait_submit:
                self.update_shared()
                self.update_p2p_mail()
                self.update_p2p_mem()
                self.handle_last_async()
                self.sychronize_shared()
            #self.historical_cache.add_shared_to_queue(handle0,handle1,shared_id_list,shared_list)
            """
            shared_memory = self.node_memory.accessor.data[self.shared_nodes_index]
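
Most of this diff circles one idiom: given a batch of node indices with repeats, keep a single mail or memory entry per node, either one representative occurrence (via `torch.unique(..., return_inverse=True)` plus an inverse-index `scatter_`) or the entry with the newest timestamp (via `torch_scatter.scatter_max`). Below is a toy, self-contained sketch of both steps; the tensors are made up and `torch_scatter` is the package the file already imports.

```python
# Toy illustration of the dedup idioms used in SharedMailBox:
# 1) pick one representative row per unique index with scatter_,
# 2) pick the row with the newest timestamp with scatter_max.
import torch
import torch_scatter

index   = torch.tensor([3, 1, 3, 2, 1])
mail    = torch.arange(5, dtype=torch.float).unsqueeze(1)  # one value per row
mail_ts = torch.tensor([10., 5., 7., 1., 9.])

# (1) one representative occurrence per unique index (the perm/scatter_ trick)
uni, inv = torch.unique(index, return_inverse=True)
perm = torch.arange(inv.size(0), dtype=inv.dtype, device=inv.device)
rep = inv.new_empty(uni.size(0)).scatter_(0, inv, perm)
print(uni, rep, mail[rep].squeeze(1))

# (2) newest entry per unique index (scatter_max over the timestamps)
max_ts, argmax = torch_scatter.scatter_max(mail_ts, inv, dim=0)
print(uni, max_ts, mail[argmax].squeeze(1))
```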

starrygl/sample/sample_core/neighbor_sampler.py

    @@ -317,8 +317,9 @@ class NeighborSampler(BaseSampler):
            """
            if self.no_neg:
                out, metadata = self.sample_from_nodes(seed[:seed.shape[0]//3*2], seed_ts[:seed.shape[0]//3*2], is_unique=False)
            else:
                out, metadata = self.sample_from_nodes(seed[:seed.shape[0]], seed_ts[:seed.shape[0]//3*2], is_unique=False)
                out, metadata = self.sample_from_nodes(seed, seed_ts, is_unique=False)
            src_pos_index = torch.arange(0, num_pos, dtype=torch.long, device=out_device)
            dst_pos_index = torch.arange(num_pos, 2*num_pos, dtype=torch.long, device=out_device)
            if neg_sampling.is_triplet() or neg_sampling.is_tgbtriplet():
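
The seed tensor handed to `sample_from_nodes` appears to be laid out as [positive sources | positive destinations | negatives], which is why the `no_neg` path samples only the first two thirds and why the positive index ranges are `arange(0, num_pos)` and `arange(num_pos, 2*num_pos)`. Below is a toy sketch of that layout and slice; `num_pos` and the node ids are illustrative.

```python
# Toy illustration of the [src | dst | neg] seed layout behind the
# seed[: seed.shape[0] // 3 * 2] slice in NeighborSampler.
import torch

num_pos = 4
src = torch.arange(0, num_pos)            # positive source nodes
dst = torch.arange(100, 100 + num_pos)    # positive destination nodes
neg = torch.arange(200, 200 + num_pos)    # sampled negatives
seed = torch.cat([src, dst, neg])

no_neg = True
if no_neg:
    # keep only the positive endpoints (first 2/3 of the seed)
    seed_for_sampling = seed[: seed.shape[0] // 3 * 2]
else:
    seed_for_sampling = seed

src_pos_index = torch.arange(0, num_pos, dtype=torch.long)
dst_pos_index = torch.arange(num_pos, 2 * num_pos, dtype=torch.long)
print(seed_for_sampling, src_pos_index, dst_pos_index)
```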