zhlj / BTS-MTGNN · Commits

Commit ef367556, authored Nov 06, 2024 by zhlj
Parents: 4927a9e0, ce4b726f

fix bugs in jodie and APAN

Showing 11 changed files with 85 additions and 53 deletions.

bts-mtgnn.md                                       +3   -0
config/APAN.yml                                    +2   -1
examples/a.out                                     +0   -0
examples/test_JODIE.sh                             +7   -0
examples/test_all.sh                               +8   -29
examples/train_boundery.py                         +3   -2
starrygl/module/memorys.py                         +8   -3
starrygl/sample/batch_data.py                      +5   -0
starrygl/sample/memory/change.py                   +8   -6
starrygl/sample/memory/shared_mailbox.py           +39  -11
starrygl/sample/sample_core/neighbor_sampler.py    +2   -1

bts-mtgnn.md  (new file, mode 100644, +3 -0)
+# Introduction
+·
\ No newline at end of file

config/APAN.yml  (+2 -1)
@@ -22,7 +22,7 @@ gnn:
 train:
   - epoch: 100
     batch_size: 1000
-    lr: 0.0001
+    lr: 0.0002
     dropout: 0.1
     att_dropout: 0.1
     # all_on_gpu: True
\ No newline at end of file

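The only functional change in config/APAN.yml is the learning rate, raised from 0.0001 to 0.0002. As a minimal sketch of how this section might be consumed (assuming PyYAML, and assuming `train:` is a single-element list as the `- epoch:` item suggests; illustration only, not code from the repo):

    import yaml  # PyYAML

    with open("config/APAN.yml") as f:
        config = yaml.safe_load(f)

    # Assumption: the `train:` block is a one-element list, so the training
    # parameters live in its first entry.
    train_param = config["train"][0]
    print(train_param["epoch"], train_param["batch_size"], train_param["lr"])

train_boundery.py (further down) iterates over `range(train_param['epoch'])` in the same spirit.
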
examples/a.out  (+0 -0)
Diff collapsed in the web view; contents not shown.

examples/test_JODIE.sh  (new file, mode 100644, +7 -0)
+bash test_all.sh 12347
+wait
+bash test_all.sh 12357
+wait
+bash test_all.sh 63457
+wait
\ No newline at end of file

examples/test_all.sh  (+8 -29)
@@ -2,41 +2,24 @@
 # Ran the TaoBao job on 4 GPUs
 # Define array variables
 seed=$1
-addr="192.168.1.106"
+addr="192.168.1.105"
 partition_params=("ours")
 #"metis" "ldg" "random")
 #("ours" "metis" "ldg" "random")
-partitions="8"
+partitions="4"
 node_per="4"
-nnodes="2"
+nnodes="1"
-<<<<<<< Updated upstream
 node_rank="0"
 probability_params=("0.1")
 sample_type_params=("boundery_recent_decay")
 #sample_type_params=("recent" "boundery_recent_decay") #"boundery_recent_uniform")
 #memory_type=("all_update" "p2p" "all_reduce" "historical" "local")
 memory_type=("historical")
+#"historical")
 #memory_type=("local" "all_update" "historical" "all_reduce")
 shared_memory_ssim=("0.3")
 #data_param=("WIKI" "REDDIT" "LASTFM" "WikiTalk")
-<<<<<<< HEAD
-data_param=("WIKI" "LASTFM" "WikiTalk" "StackOverflow" "GDELT")
-=======
-data_param=("LASTFM")
->>>>>>> 8233776274204f6cf2f8a2eb37022d426d6197d8
+data_param=("GDELT")
-=======
-node_rank="1"
-probability_params=("0.1" "0.05" "0.01" "0")
-sample_type_params=("recent" "boundery_recent_decay")
-#sample_type_params=("recent" "boundery_recent_decay") #"boundery_recent_uniform")
-#memory_type=("all_update" "p2p" "all_reduce" "historical" "local")
-#memory_type=("all_update" "historical" "local")
-memory_type=("historical" "all_update" "local")
-#memory_type=("local" "all_update" "historical" "all_reduce")
-shared_memory_ssim=("0.3" "0.7")
-#data_param=("WIKI" "REDDIT" "LASTFM" "WikiTalk")
-data_param=("WIKI" "LASTFM" "WikiTalk" "StackOverflow" "GDELT" "TaoBao")
->>>>>>> Stashed changes
 #data_param=("WIKI" "REDDIT" "LASTFM" "DGraphFin" "WikiTalk" "StackOverflow")
 #data_param=("WIKI" "REDDIT" "LASTFM" "WikiTalk" "StackOverflow")
 #data_param=("REDDIT" "WikiTalk")

@@ -48,9 +31,9 @@ data_param=("WIKI" "LASTFM" "WikiTalk" "StackOverflow" "GDELT" "TaoBao")
 #seed=(( RANDOM % 1000000 + 1 ))
 mkdir -p all_"$seed"
 for data in "${data_param[@]}"; do
-    model="APAN"
+    model="JODIE"
     if [ "$data" = "WIKI" ] || [ "$data" = "REDDIT" ] || [ "$data" = "LASTFM" ]; then
-        model="APAN"
+        model="JODIE"
     fi
     #model="APAN"
     mkdir all_"$seed"/"$data"

@@ -89,11 +72,7 @@ for data in "${data_param[@]}"; do
                 if [ "$mem" = "historical" ]; then
                     for ssim in "${shared_memory_ssim[@]}"; do
                         if [ "$partition" = "ours" ]; then
-<<<<<<< Updated upstream
                             torchrun --nnodes "$nnodes" --node_rank "$node_rank" --nproc-per-node "$node_per" --master-addr "$addr" --master-port 9445 train_boundery.py --dataname "$data" --mode "$model" --partition "$partition" --topk 0.1 --sample_type "$sample" --probability "$pro" --memory_type "$mem" --shared_memory_ssim "$ssim" --seed "$seed" > all_"$seed"/"$data"/"$model"/"$partitions"-ours_shared-0.01-"$mem"-"$ssim"-"$sample"-"$pro".out &
-=======
-                            torchrun --nnodes "$nnodes" --node_rank "$node_rank" --nproc-per-node "$node_per" --master-addr "$addr" --master-port 9445 train_boundery.py --dataname "$data" --mode "$model" --partition "$partition" --topk 0.1 --sample_type "$sample" --probability "$pro" --memory_type "$mem" --shared_memory_ssim "$ssim" > all/"$data"/"$model"/"$partitions"-ours_shared-0.01-"$mem"-"$ssim"-"$sample"-"$pro".out &
->>>>>>> Stashed changes
                             wait
                         fi
                     done

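The resolved script sweeps datasets, sampling policies, probabilities, memory types and ssim thresholds, and launches one torchrun job per combination. A rough Python rendering of that loop structure (hypothetical illustration only; the values mirror the resolved shell variables above and train_boundery.py is invoked exactly as in the script):

    import itertools
    import os
    import subprocess

    seed = "12347"                               # test_JODIE.sh passes the seed as $1
    addr, nnodes, node_rank, node_per, partitions = "192.168.1.105", "1", "0", "4", "4"

    grid = itertools.product(
        ["GDELT"],                    # data_param
        ["boundery_recent_decay"],    # sample_type_params
        ["0.1"],                      # probability_params
        ["historical"],               # memory_type
        ["0.3"],                      # shared_memory_ssim
    )
    for data, sample, pro, mem, ssim in grid:
        model = "JODIE"               # after this commit both branches of the if pick JODIE
        out = f'all_{seed}/{data}/{model}/{partitions}-ours_shared-0.01-{mem}-{ssim}-{sample}-{pro}.out'
        os.makedirs(os.path.dirname(out), exist_ok=True)
        cmd = ["torchrun", "--nnodes", nnodes, "--node_rank", node_rank,
               "--nproc-per-node", node_per, "--master-addr", addr, "--master-port", "9445",
               "train_boundery.py", "--dataname", data, "--mode", model, "--partition", "ours",
               "--topk", "0.1", "--sample_type", sample, "--probability", pro,
               "--memory_type", mem, "--shared_memory_ssim", ssim, "--seed", seed]
        with open(out, "w") as f:
            # the shell script backgrounds the job and then waits; here it runs synchronously
            subprocess.run(cmd, stdout=f, check=False)
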
examples/train_boundery.py  (+3 -2)
@@ -229,7 +229,8 @@ def main():
     fanout = sample_param['neighbor'] if 'neighbor' in sample_param else [10]
     policy = sample_param['strategy'] if 'strategy' in sample_param else 'recent'
     no_neg = sample_param['no_neg'] if 'no_neg' in sample_param else False
-    if policy != 'recent':
+    print(policy)
+    if policy == 'recent':
         policy_train = args.sample_type #'boundery_recent_decay'
     else:
         policy_train = policy

@@ -480,7 +481,7 @@ def main():
     val_list = []
     loss_list = []
     for e in range(train_param['epoch']):
-        model.module.memory_updater.empty_cache()
+        #model.module.memory_updater.empty_cache()
         tt._zero()
         torch.cuda.synchronize()
         epoch_start_time = time.time()

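The first hunk selects sampler settings with the `value if key in dict else default` idiom and then decides which training-time sampling policy to use; the fix flips the condition so the configured strategy is only overridden when it is 'recent'. The same logic as a standalone sketch using `dict.get` (illustration only; `args_sample_type` stands in for `args.sample_type`):

    # Standalone sketch of the defaulting and policy selection above (illustration only).
    sample_param = {'strategy': 'recent', 'neighbor': [10]}   # example config values
    args_sample_type = 'boundery_recent_decay'                # stands in for args.sample_type

    fanout = sample_param.get('neighbor', [10])
    policy = sample_param.get('strategy', 'recent')
    no_neg = sample_param.get('no_neg', False)

    print(policy)
    # Override the policy only when the config asks for plain 'recent' sampling;
    # any other configured strategy is kept as-is.
    policy_train = args_sample_type if policy == 'recent' else policy
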
starrygl/module/memorys.py  (+8 -3)
@@ -336,9 +336,12 @@ class TransformerMemoryUpdater(torch.nn.Module):
     def forward(self, b, param = None):
         Q = self.w_q(b.srcdata['mem']).reshape((b.num_src_nodes(), self.att_h, -1))
         mails = b.srcdata['mem_input'].reshape((b.num_src_nodes(), self.memory_param['mailbox_size'], -1))
+        #print(mails.shape,b.srcdata['mem_input'].shape,b.srcdata['mail_ts'].shape)
         if self.dim_time > 0:
             time_feat = self.time_enc(b.srcdata['ts'][:,None] - b.srcdata['mail_ts']).reshape((b.num_src_nodes(), self.memory_param['mailbox_size'], -1))
+            #print(time_feat.shape)
             mails = torch.cat([mails, time_feat], dim=2)
+        #print(mails.shape)
         K = self.w_k(mails).reshape((b.num_src_nodes(), self.memory_param['mailbox_size'], self.att_h, -1))
         V = self.w_v(mails).reshape((b.num_src_nodes(), self.memory_param['mailbox_size'], self.att_h, -1))
         att = self.att_act((Q[:,None,:,:]*K).sum(dim=3))

@@ -394,7 +397,6 @@ class AsyncMemeoryUpdater(torch.nn.Module):
             self.mailbox.handle_last_async()
             submit_to_queue = False
             if nxt_fetch_func is not None:
-                nxt_fetch_func()
                 submit_to_queue = True
             self.mailbox.set_memory_all_reduce(
                 index, memory, memory_ts,

@@ -404,6 +406,8 @@ class AsyncMemeoryUpdater(torch.nn.Module):
                 wait_submit=submit_to_queue, spread_mail=spread_mail,
                 update_cross_mm=False,
             )
+            if nxt_fetch_func is not None:
+                nxt_fetch_func()
     def local_func(self, index, memory, memory_ts, mail_index, mail, mail_ts, nxt_fetch_func, spread_mail=False):
         if nxt_fetch_func is not None:

@@ -471,6 +475,7 @@ class AsyncMemeoryUpdater(torch.nn.Module):
             shared_ind = self.mailbox.is_shared_mask[DistIndex(b.srcdata['ID'][mask]).loc]
             transition_dense = b.srcdata['his_mem'][mask] + self.filter.get_incretment(shared_ind)
+            #print(transition_dense.shape)
             if not (transition_dense.max().item() == 0):
                 transition_dense -= transition_dense.min()
                 transition_dense /= transition_dense.max()

@@ -514,8 +519,8 @@ class AsyncMemeoryUpdater(torch.nn.Module):
         local_mask = (DistIndex(index).part == torch.distributed.get_rank())
         local_mask_mail = (DistIndex(index0).part == torch.distributed.get_rank())
-        #self.mailbox.set_mailbox_local(DistIndex(index0[local_mask_mail]).loc,mail[local_mask_mail],mail_ts[local_mask_mail],Reduce_Op = 'max')
+        self.mailbox.set_mailbox_local(DistIndex(index0[local_mask_mail]).loc, mail[local_mask_mail], mail_ts[local_mask_mail], Reduce_Op = 'max')
-        #self.mailbox.set_memory_local(DistIndex(index[local_mask]).loc,memory[local_mask],memory_ts[local_mask], Reduce_Op = 'max')
+        self.mailbox.set_memory_local(DistIndex(index[local_mask]).loc, memory[local_mask], memory_ts[local_mask], Reduce_Op = 'max')
         is_deliver = (self.mailbox.deliver_to == 'neighbors')
         self.update_hunk(index, memory, memory_ts, index0, mail, mail_ts, nxt_fetch_func, spread_mail=is_deliver)

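The first hunk of memorys.py only adds commented-out shape checks around the mailbox attention in TransformerMemoryUpdater. For orientation, the tensor shapes those prints would report look roughly like this standalone sketch (made-up sizes, softmax used in place of the module's att_act; illustration only):

    import torch

    N, S, H, D = 5, 3, 2, 8           # nodes, mailbox_size, attention heads, per-head dim
    Q = torch.randn(N, H, D)          # w_q(mem), reshaped to (N, att_h, -1)
    mails = torch.randn(N, S, H * D)  # mem_input (plus optional time encoding), one row per mail slot
    K = mails.reshape(N, S, H, D)     # w_k(mails) reshaped
    V = mails.reshape(N, S, H, D)     # w_v(mails) reshaped

    att = torch.softmax((Q[:, None, :, :] * K).sum(dim=3), dim=1)  # (N, S, H): weight per mail slot and head
    out = (att[:, :, :, None] * V).sum(dim=1)                      # (N, H, D): aggregated mail per node
    print(att.shape, out.shape)
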
starrygl/sample/batch_data.py  (+5 -0)
@@ -344,6 +344,7 @@ def to_reversed_block(graph,data, sample_out,device = torch.device('cuda'),uniqu
     else:
         metadata = None
     nid_mapper: torch.Tensor = graph.nids_mapper
+    #print('reverse block {}\n'.format(identity))
     if identity is False:
         assert len(sample_out) == 1
         ret = sample_out[0]

@@ -354,6 +355,8 @@ def to_reversed_block(graph,data, sample_out,device = torch.device('cuda'),uniqu
         dist_eid = torch.tensor([], dtype=torch.long, device=device)
         src_index = ret.src_index().to(device)
     else:
+        #print('is jodie')
+        #print(sample_out)
         src_index = torch.tensor([], dtype=torch.long, device=device)
         dst = torch.tensor([], dtype=torch.long, device=device)
         dist_eid = torch.tensor([], dtype=torch.long, device=device)

@@ -401,6 +404,7 @@ def to_reversed_block(graph,data, sample_out,device = torch.device('cuda'),uniqu
         row_len = root_len
     col = first_block_id[:row_len]
     max_row = col.max().item() + 1
+    #print(src_index,dst)
     b = dgl.create_block((col[src_index].to(device),
                           torch.arange(dst.shape[0], device=device, dtype=torch.long)),
                          num_src_nodes=first_block_id.max().item()+1,
                          num_dst_nodes=dst.shape[0])

@@ -424,6 +428,7 @@ def graph_sample(graph,sampler,sample_fn,data,neg_sampling = None,out_device = t
     t_s = time.time()
     param = {'is_unique': False, 'nid_mapper': nid_mapper, 'eid_mapper': eid_mapper, 'out_device': out_device}
     out = sample_fn(sampler, data, neg_sampling, **param)
+    #print(sampler.policy)
     if reversed is False:
         out, dist_nid, dist_eid = to_block(graph, data, out, out_device)
     else:

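All five additions in batch_data.py are commented-out debug prints; the surrounding `dgl.create_block` call builds a bipartite message-flow block from a (src, dst) index pair plus explicit node counts. A minimal standalone example of that DGL API (toy indices, assumes dgl and torch are installed; not the repo's data):

    import torch
    import dgl

    src = torch.tensor([0, 2, 3], dtype=torch.long)   # positions into the source-node frontier
    dst = torch.tensor([0, 1, 2], dtype=torch.long)   # one destination slot per root node
    block = dgl.create_block((src, dst), num_src_nodes=4, num_dst_nodes=3)
    print(block.num_src_nodes(), block.num_dst_nodes())
    print(block.edges())
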
starrygl/sample/memory/change.py  (+8 -6)
@@ -18,9 +18,10 @@ class MemoryMoniter:
         #self.memory_ssim.append(self.ssim(pre_memory,now_memory,method = 'cos'))
         #self.nid_list.append(nid)
     def draw(self, degree, data, model, e):
-        torch.save(self.nid_list,'all_args.seed/{}/{}/memorynid_{}.pt'.format(data,model,e))
-        torch.save(self.memorychange,'all_args.seed/{}/{}/memoryF_{}.pt'.format(data,model,e))
-        torch.save(self.memory_ssim,'all_args.seed/{}/{}/memcos_{}.pt'.format(data,model,e))
+        pass
+        #torch.save(self.nid_list,'all_args.seed/{}/{}/memorynid_{}.pt'.format(data,model,e))
+        #torch.save(self.memorychange,'all_args.seed/{}/{}/memoryF_{}.pt'.format(data,model,e))
+        #torch.save(self.memory_ssim,'all_args.seed/{}/{}/memcos_{}.pt'.format(data,model,e))
         # path = './memory/{}/'.format(data)
         # if not os.path.exists(path):

@@ -87,6 +88,7 @@ class MemoryMoniter:
     def set_zero(self):
-        self.memorychange = []
-        self.nid_list = []
-        self.memory_ssim = []
+        pass
+        #self.memorychange = []
+        #self.nid_list =[]
+        #self.memory_ssim = []

starrygl/sample/memory/shared_mailbox.py  (+39 -11)
@@ -146,7 +146,7 @@ class SharedMailBox():
             source_ts = max_ts
             source = source[id]
             index = unq_id
+            #print(self.next_mail_pos[index])
             self.mailbox_ts.accessor.data[index, self.next_mail_pos[index]] = source_ts
             self.mailbox.accessor.data[index, self.next_mail_pos[index]] = source
             if self.memory_param['mailbox_size'] > 1:

@@ -180,9 +180,23 @@ class SharedMailBox():
         if self.deliver_to == 'neighbors':
             assert block is not None and Reduce_score is None
-            mail = torch.cat([mail, mail[block.edges()[0].long()]], dim=0)
-            mail_ts = torch.cat([mail_ts, mail_ts[block.edges()[0].long()]], dim=0)
+            # print(block.edges().shape)
+            root = torch.cat([src, dst]).reshape(-1)
+            #pos = torch.empty(root.max()+1,dtype=torch.long,device=block.device)
+            #print('edge {} {}\n'.format(block.num_src_nodes(),block.edges()[0].max()))
+            #print('root is {} {} {} {}\n'.format(root,root.shape,root.max(),block.edges()[0].shape))
+            #pos_index = torch.arange(root.shape[0],device=root.device,dtype=root.dtype)
+            pos, idx = torch_scatter.scatter_max(mail_ts, root, 0)
+            mail = torch.cat([mail, mail[idx]], dim=0)
+            mail_ts = torch.cat([mail_ts, mail_ts[idx]], dim=0)
+            #print('pos is {} {}\n'.format(pos,block.edges()[0].long()))
+            #mail = torch.cat([mail, mail[pos[block.edges()[0].long()]]],dim=0)
+            #mail_ts = torch.cat([mail_ts, mail_ts[pos[block.edges()[0].long()]]], dim=0)
+            #print(root,block.edges()[1].long())
             index = torch.cat([index, block.dstdata['ID'][block.edges()[1].long()]], dim=0)
+            #mail = torch.cat([mail, mail[block.edges()[0].long()]], dim=0)
+            #mail_ts = torch.cat([mail_ts, mail_ts[block.edges()[0].long()]], dim=0)
+            #index = torch.cat([index,block.dstdata['ID'][block.edges()[1].long()]],dim=0)
             if Reduce_score is not None:
                 Reduce_score = torch.cat((Reduce_score, Reduce_score), -1).to(self.device)
         if Reduce_score is None:

@@ -192,12 +206,19 @@ class SharedMailBox():
             mail = mail[idx]
             index = unq_index
         else:
-            unq_index,inv = torch.unique(index,return_inverse = True)
+            uni, inv = torch.unique(index, return_inverse=True)
+            perm = torch.arange(inv.size(0), dtype=inv.dtype, device=inv.device)
+            perm = inv.new_empty(uni.size(0)).scatter_(0, inv, perm)
+            index = index[perm]
+            mail = mail[perm]
+            mail_ts = mail_ts[perm]
+            #unq_index,inv = torch.unique(index,return_inverse = True)
             #print(inv.shape,Reduce_score.shape)
-            max_score,idx = torch_scatter.scatter_max(Reduce_score,inv,0)
-            mail_ts = mail_ts[idx]
-            mail = mail[idx]
-            index = unq_index
+            #max_score,idx = torch_scatter.scatter_max(Reduce_score,inv,0)
+            #mail_ts = mail_ts[idx]
+            #mail = mail[idx]
+            #index = unq_index
+            #print('mail {} {}\n'.format(index.shape,mail.shape,mail_ts.shape))
         return index, mail, mail_ts
     def get_update_memory(self, index, memory, memory_ts, embedding):

@@ -205,7 +226,8 @@ class SharedMailBox():
         max_ts, idx = torch_scatter.scatter_max(memory_ts, inv, 0)
         ts = max_ts
         index = unq_index
         memory = memory[idx]
+        #print('memory {} {}\n'.format(index.shape,memory.shape,ts.shape))
         return index, memory, ts
     def pack(self, memory=None, memory_ts=None, mail=None, mail_ts=None, index=None, mode=None):

@@ -250,6 +272,7 @@ class SharedMailBox():
             self.set_mailbox_local(DistIndex(gather_id_list).loc, gather_mail, gather_mail_ts, Reduce_Op=reduce_Op)
         else:
             gather_memory, gather_memory_ts = self.unpack(gather_memory)
+            #print(gather_id_list.shape,gather_memory.shape,gather_memory_ts.shape)
             self.set_memory_local(DistIndex(gather_id_list).loc, gather_memory, gather_memory_ts, Reduce_Op=reduce_Op)
     def handle_last_mail(self, reduce_Op=None,):
         if self.last_mail_sync is not None:

@@ -260,6 +283,7 @@ class SharedMailBox():
             if isinstance(gather_memory, list):
                 gather_memory = torch.cat(gather_memory, dim=0)
             gather_memory, gather_memory_ts = self.unpack(gather_memory)
+            #print(gather_id_list.shape,gather_memory.shape,gather_memory_ts.shape)
             self.set_mailbox_local(DistIndex(gather_id_list).loc, gather_memory, gather_memory_ts, Reduce_Op=reduce_Op)
     def handle_last_async(self, reduce_Op=None):
         self.handle_last_memory(reduce_Op)

@@ -303,6 +327,7 @@ class SharedMailBox():
             return
         index, gather_id_list, mem, gather_memory, input_split, output_split, group, async_op = self.next_wait_mail_job
         self.next_wait_mail_job = None
+        #print(index,gather_id_list)
        handle0 = torch.distributed.all_to_all_single(
            gather_id_list, index, output_split_sizes=output_split,
            input_split_sizes=input_split, group=group, async_op=async_op)

@@ -409,6 +434,7 @@ class SharedMailBox():
         self.update_p2p_mail()
         self.update_p2p_mem()
         self.handle_last_async()
         ctx = DistributedContext.get_default_context()

@@ -483,7 +509,7 @@ class SharedMailBox():
         unq_index, inv = torch.unique(shared_index, return_inverse=True)
         max_ts, idx = torch_scatter.scatter_max(shared_memory_ts, inv, 0)
         shared_memory = shared_memory[idx]
-        shared_memory = shared_memory[idx]
+        # shared_memory = shared_memory[idx]
         shared_memory_ts = shared_memory_ts[idx]
         shared_mail_ts = shared_mail_ts[idx]
         shared_mail = shared_mail[idx]

@@ -495,11 +521,13 @@ class SharedMailBox():
                 self.set_mailbox_local(self.shared_nodes_index[shared_mail_indx], shared_mail, shared_mail_ts)
         else:
             self.next_wait_gather_memory_job = (shared_list, mem, shared_id_list, shared_memory_ind)
         if not wait_submit:
             self.update_shared()
             self.update_p2p_mail()
             self.update_p2p_mem()
+            self.handle_last_async()
+            self.sychronize_shared()
         #self.historical_cache.add_shared_to_queue(handle0,handle1,shared_id_list,shared_list)
         """
         shared_memory = self.node_memory.accessor.data[self.shared_nodes_index]

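The largest change above replaces a scatter_max over Reduce_score with a first-occurrence deduplication: `torch.unique(..., return_inverse=True)` followed by a `scatter_` of positions yields, for each unique index, the position of one of its occurrences, and that permutation then picks one mail row per node. A standalone sketch of that trick (toy tensors, illustration only):

    import torch

    index = torch.tensor([7, 3, 7, 5, 3])                 # node ids with duplicates
    mail = torch.arange(5, dtype=torch.float)[:, None]    # one mail row per entry

    uni, inv = torch.unique(index, return_inverse=True)   # uni: sorted unique ids; inv: slot of each entry in uni
    perm = torch.arange(inv.size(0), dtype=inv.dtype, device=inv.device)
    perm = inv.new_empty(uni.size(0)).scatter_(0, inv, perm)  # one original position per unique id
                                                              # (which duplicate wins is unspecified)
    print(index[perm].tolist())                # deduplicated ids, e.g. [3, 5, 7]
    print(mail[perm].squeeze(1).tolist())      # the mail rows selected for those ids
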
starrygl/sample/sample_core/neighbor_sampler.py  (+2 -1)
@@ -317,8 +317,9 @@ class NeighborSampler(BaseSampler):
         """
         if self.no_neg:
             out, metadata = self.sample_from_nodes(seed[:seed.shape[0]//3*2], seed_ts[:seed.shape[0]//3*2], is_unique=False)
         else:
-            out, metadata = self.sample_from_nodes(seed[:seed.shape[0]], seed_ts[:seed.shape[0]//3*2], is_unique=False)
+            out, metadata = self.sample_from_nodes(seed, seed_ts, is_unique=False)
         src_pos_index = torch.arange(0, num_pos, dtype=torch.long, device=out_device)
         dst_pos_index = torch.arange(num_pos, 2*num_pos, dtype=torch.long, device=out_device)
         if neg_sampling.is_triplet() or neg_sampling.is_tgbtriplet():

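The one-line fix removes a length mismatch: the old `else` branch passed every seed but only the first two thirds of `seed_ts`. Judging by the `no_neg` branch, seeds appear to be laid out as source, destination and negative thirds, so `seed.shape[0]//3*2` keeps only the positive pairs; that layout is inferred from the slice, not stated in the diff. A toy illustration of the inferred convention:

    import torch

    # Assumed layout (inferred): [src | dst | neg], each third the same length.
    src = torch.tensor([0, 1])
    dst = torch.tensor([10, 11])
    neg = torch.tensor([20, 21])
    seed = torch.cat([src, dst, neg])
    seed_ts = torch.tensor([5, 6, 5, 6, 5, 6])

    two_thirds = seed.shape[0] // 3 * 2
    pos_seed, pos_ts = seed[:two_thirds], seed_ts[:two_thirds]   # what the no_neg branch samples
    print(pos_seed.tolist(), pos_ts.tolist())                    # [0, 1, 10, 11] [5, 6, 5, 6]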