zhlj / BTS-MTGNN

Commit 4ad96131, authored Nov 19, 2024 by zlj
Parent: 4fa85d33

Delete some useless code
Showing 6 changed files with 120 additions and 38 deletions:

  examples-probability-sample/result.py       +99  -0
  examples/test_all.sh                         +4  -3
  examples/train_boundery.py                   +3  -1
  starrygl/module/modules.py                   +2  -2
  starrygl/sample/batch_data.py                +3  -3
  starrygl/sample/memory/shared_mailbox.py     +9  -29
examples-probability-sample/result.py (new file, 0 → 100644, +99 -0)
import matplotlib.pyplot as plt
import numpy as np
import torch
# Read the file contents
import os

probability_values = [1, 0.1, 0.05, 0.01, 0]
#[0.1,0.05,0.01,0]
data_values = ['WIKI', 'LASTFM', 'WikiTalk', 'StackOverflow', 'GDELT']
# Stores the data read from the files
seed = ['13357', '12347', '53473', '54763', '12347', '63377', '53473', '54763']
partition = 'ours_shared'
# Read the data from files; assumes the data is stored in the file data.txt
#all/"$data"/"$partitions"-ours_shared-0.01-"$mem"-"$ssim"-"$sample".out
partitions = 4
topk = 0.01
mem = 'all_update'
#'historical'
model0 = 'TGN'

def average(l):
    return sum(l) / len(l)

for p in probability_values:
    for data in data_values:
        ap_list = []
        comm_list = []
        test_time = []
        train_time = []
        total_communication = []
        shared_synchronize = []
        for sd in seed:
            # WIKI and LASTFM use the base model config, the rest the *_large config
            if data == 'WIKI' or data == 'LASTFM':
                model = model0
            else:
                model = model0 + '_large'
            if p == 1:
                file = '../examples-probability-sample/all_{}/{}/{}/{}-{}-{}-{}-recent.out'.format(sd, data, model, partitions, partition, topk, mem)
            else:
                file = '../examples-probability-sample/all_{}/{}/{}/{}-{}-{}-{}-boundery_recent_decay-{}.out'.format(sd, data, model, partitions, partition, topk, mem, p)
            #print(file)
            prefix = "val ap:"
            max_val_ap = 0
            test_ap = 0
            #if
            #print(file)
            if os.path.exists(file):
                # note: rebinding 'file' here shadows the path string above
                with open(file, 'r') as file:
                    _total_communication = []
                    _shared_synchronize = []
                    for line in file:
                        #if line.find('Epoch 50:')!=-1:
                        #    break
                        if line.find(prefix) != -1:
                            pos = line.find(prefix) + len(prefix)
                            posr = line.find(' ', pos)
                            #print(line[pos:posr])
                            val_ap = float(line[pos:posr])
                            pos = line.find("test ap ") + len("test ap ")
                            posr = line.find(' ', pos)
                            #print(line[pos:posr])
                            _test_ap = float(line[pos:posr])
                            # keep the test AP of the epoch with the best validation AP
                            if (val_ap > max_val_ap):
                                max_val_ap = val_ap
                                test_ap = _test_ap
                        elif line.find('avg_time ') != -1:
                            pl = line.find('avg_time ') + len('avg_time ')
                            pr = line.find(' test time')
                            train_time.append(float(line[pl:pr]))
                            test_time.append(float(line[pr + len(' test time '):]))
                            #print(line)
                            ap_list.append(test_ap)
                            total_communication.append(average(_total_communication))
                            shared_synchronize.append(average(_shared_synchronize))
                            """
                            local node number tensor([114453298]) remote node number tensor([20457479]) local edge tensor([1592867807]) remote edgetensor([326448632])
                            local node number tensor([114453298]) remote node number tensor([20457479]) local edge tensor([1592867807]) remote edgetensor([326448632])
                            comm local node number 0 remote node number 0 local edge 0 remote edge0
                            comm local node number 0 remote node number 0 local edge 0 remote edge0
                            memory comm tensor([0]) shared comm tensor([24860])
                            memory comm tensor([0]) shared comm tensor([24860])
                            """
                        elif line.find('remote node number tensor([') != -1:
                            pl = line.find('remote node number tensor([') + len('remote node number tensor([')
                            pr = line.find('])', pl)
                            _total_communication.append(int(line[pl:pr]))
                            #if(p==0):
                            #print(file)
                            #print(line)
                        elif line.find('shared comm tensor([') != -1:
                            pl = line.find('shared comm tensor([') + len('shared comm tensor([')
                            pr = line.find('])', pl)
                            _shared_synchronize.append(int(line[pl:pr]))
            #else:
            #    print(file)
        if len(ap_list) > 0:
            #print('prob {} data {} model {} remote volume : {} synchronize volume : {}'.format(p,data,model,average(total_communication),average(shared_synchronize)))
            print('prob {} data {} model {} ap: {} train_time: {} eval time: {} remote volume : {} synchronize volume : {}'.format(p, data, model, average(ap_list), average(train_time), average(test_time), average(total_communication), average(shared_synchronize)))
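For context on the string slicing above: a minimal, self-contained sketch of the log-line format the parser assumes. The sample line is hypothetical; only the marker strings 'val ap:' and 'test ap ' are taken from the script.

# Hypothetical log line; only the 'val ap:' / 'test ap ' markers come from result.py.
line = "Epoch 49: val ap:0.9871 test ap 0.9793 loss 0.1\n"

prefix = "val ap:"
pos = line.find(prefix) + len(prefix)
posr = line.find(' ', pos)
val_ap = float(line[pos:posr])               # 0.9871

pos = line.find("test ap ") + len("test ap ")
posr = line.find(' ', pos)
test_ap = float(line[pos:posr])              # 0.9793
print(val_ap, test_ap)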
examples/test_all.sh (+4 -3)
...
@@ -11,15 +11,16 @@ node_per="4"
 nnodes="3"
 node_rank="1"
 probability_params=("0.1")
+probability_params=("0.1")
 sample_type_params=("boundery_recent_decay")
 #sample_type_params=("recent" "boundery_recent_decay") #"boundery_recent_uniform")
 #memory_type=("all_update" "p2p" "all_reduce" "historical" "local")
-memory_type=("historical")
+memory_type=("local")
 #"historical")
 #memory_type=("local" "all_update" "historical" "all_reduce")
 shared_memory_ssim=("0.3")
 #data_param=("WIKI" "REDDIT" "LASTFM" "WikiTalk")
-data_param=("WIKI" "LASTFM" "WikiTalk" "StackOverflow" "GDELT")
+data_param=("WIKI")
 #"GDELT")
 #data_param=("WIKI" "REDDIT" "LASTFM" "DGraphFin" "WikiTalk" "StackOverflow")
 #data_param=("WIKI" "REDDIT" "LASTFM" "WikiTalk" "StackOverflow")
...
@@ -34,7 +35,7 @@ mkdir -p all_"$seed"
 for data in "${data_param[@]}"; do
     model="APAN_large"
     if [ "$data" = "WIKI" ] || [ "$data" = "REDDIT" ] || [ "$data" = "LASTFM" ]; then
-        model="APAN"
+        model="TGN"
     fi
     #model="APAN"
     mkdir all_"$seed"/"$data"
...
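A hedged sketch (not part of this commit) of how one combination of the arrays above maps to the .out path that examples-probability-sample/result.py parses; the 4 / 'ours_shared' / 0.01 constants are the partitions, partition, and topk values hard-coded in result.py, and the example values are illustrative.

seed, data, model = '13357', 'WIKI', 'TGN'             # example values from the arrays
mem, sample, prob = 'local', 'boundery_recent_decay', '0.1'
out = 'all_{}/{}/{}/{}-{}-{}-{}-{}-{}.out'.format(
    seed, data, model, 4, 'ours_shared', 0.01, mem, sample, prob)
print(out)  # all_13357/WIKI/TGN/4-ours_shared-0.01-local-boundery_recent_decay-0.1.out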
examples/train_boundery.py (+3 -1)
...
@@ -236,7 +236,9 @@ def main():
     policy_train = policy
     if memory_param['type'] != 'none':
         mailbox = SharedMailBox(graph.ids.shape[0], memory_param, dim_edge_feat=graph.efeat.shape[1] if graph.efeat is not None else 0,
-                                shared_nodes_index=graph.shared_nids_list[ctx.memory_group_rank], device=torch.device('cuda:{}'.format(local_rank)), shared_ssim=args.shared_memory_ssim)
+                                shared_nodes_index=graph.shared_nids_list[ctx.memory_group_rank], device=torch.device('cuda:{}'.format(local_rank)),
+                                start_historical=(args.memory_type == 'historical'),
+                                shared_ssim=args.shared_memory_ssim)
     else:
         mailbox = None
...
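The new start_historical argument makes the historical cache an explicit opt-in (see the shared_mailbox.py hunk below, where `if start_historical is not None:` becomes `if start_historical:`). A minimal sketch of the pattern, with hypothetical names in place of the real SharedMailBox signature:

class MailboxSketch:
    def __init__(self, start_historical=False):
        # allocate the historical cache only when the memory type asks for it
        self.historical_cache = {} if start_historical else None

memory_type = 'local'                                   # stand-in for args.memory_type
mb = MailboxSketch(start_historical=(memory_type == 'historical'))
print(mb.historical_cache)                              # None: no cache for 'local'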
starrygl/module/modules.py (+2 -2)
...
@@ -76,12 +76,12 @@ class GeneralModel(torch.nn.Module):
         self.dim_edge = dim_edge
         self.sample_param = sample_param
         self.memory_param = memory_param
-        self.train_pos_ratio,self.train_neg_ratio = train_ratio
+        # self.train_pos_ratio,self.train_neg_ratio = train_ratio
         if not 'dim_out' in gnn_param:
             gnn_param['dim_out'] = memory_param['dim_out']
         self.gnn_param = gnn_param
         self.train_param = train_param
-        self.neg_fix_layer = NegFixLayer()
+        # self.neg_fix_layer = NegFixLayer()
         if memory_param['type'] == 'node':
             if memory_param['memory_update'] == 'gru':
                 #if memory_param['async'] == False:
...
starrygl/sample/batch_data.py (+3 -3)
...
@@ -290,9 +290,9 @@ def to_block(graph,data, sample_out,device = torch.device('cuda'),unique = True)
             if sample_out[r].delta_ts().shape[0] > 0:
                 b.edata['dt'] = sample_out[r].delta_ts().to(device)
                 b.srcdata['ts'] = block_node_list[1, b.srcnodes()].to(torch.float)
-            weight = sample_out[r].sample_weight()
-            if(weight.shape[0] > 0):
-                b.edata['weight'] = 1/torch.clamp(sample_out[r].sample_weight(),0.0001).to(b.device)
+            # weight = sample_out[r].sample_weight()
+            # if(weight.shape[0] > 0):
+            #     b.edata['weight'] = 1/torch.clamp(sample_out[r].sample_weight(),0.0001).to(b.device)
             b.edata['__ID'] = e_idx
         col = row
         col_len += eid_len[r]
...
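For the line being disabled above: the clamp is what keeps the inverse sample weight finite when a sampled weight is zero. A small demonstration with a synthetic tensor (not repo data):

import torch

w = torch.tensor([0.0, 0.5, 2.0])        # hypothetical sample weights
inv = 1 / torch.clamp(w, 0.0001)         # clamp bounds 1/w at 10000 instead of inf
print(inv)                               # tensor([1.0000e+04, 2.0000e+00, 5.0000e-01])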
starrygl/sample/memory/shared_mailbox.py (+9 -29)
...
@@ -105,7 +105,7 @@ class SharedMailBox():
         self.is_shared_mask[shared_nodes_index] = torch.arange(self.shared_nodes_index.shape[0], dtype=torch.int,
                                                                device=torch.device('cuda:{}'.format(ctx.local_rank)))
-        if start_historical is not None:
+        if start_historical:
             self.historical_cache = historical_cache.HistoricalCache(self.shared_nodes_index, 0, self.node_memory.shape[1], self.node_memory.dtype, self.node_memory.device, threshold=shared_ssim)
         self._mem_pin = {}
         self._mail_pin = {}
...
@@ -180,26 +180,12 @@ class SharedMailBox():
         if self.deliver_to == 'neighbors':
             assert block is not None and Reduce_score is None
-            # print(block.edges().shape)
             root = torch.cat([src,dst]).reshape(-1)
-            #pos = torch.empty(root.max()+1,dtype=torch.long,device=block.device)
-            _,idx = torch_scatter.scatter_max(mail_ts,root,0)
-            #print('edge {} {}\n'.format(block.num_src_nodes(),block.edges()[0].max()))
-            #print('root is {} {} {} {}\n'.format(root,root.shape,root.max(),block.edges()[0].shape))
-            #pos_index = torch.arange(root.shape[0],device=root.device,dtype=root.dtype)
-            #print(block.number_of_edges())
+            pos,idx = torch_scatter.scatter_max(mail_ts,root,0)
             mail = torch.cat([mail, mail[idx[block.edges()[0].long()]]],dim=0)
             mail_ts = torch.cat([mail_ts, mail_ts[idx[block.edges()[0].long()]]],dim=0)
-            #print('pos is {} {}\n'.format(pos,block.edges()[0].long()))
-            #mail = torch.cat([mail, mail[pos[block.edges()[0].long()]]],dim=0)
-            #mail_ts = torch.cat([mail_ts, mail_ts[pos[block.edges()[0].long()]]], dim=0)
-            #print(root,block.edges()[1].long())
             index = torch.cat([index,block.dstdata['ID'][block.edges()[1].long()]],dim=0)
-            #print(index)
-            #mail = torch.cat([mail, mail[block.edges()[0].long()]], dim=0)
-            #mail_ts = torch.cat([mail_ts, mail_ts[block.edges()[0].long()]], dim=0)
-            #index = torch.cat([index,block.dstdata['ID'][block.edges()[1].long()]],dim=0)
         if Reduce_score is not None:
             Reduce_score = torch.cat((Reduce_score,Reduce_score),-1).to(self.device)
         if Reduce_score is None:
...
@@ -209,18 +195,12 @@ class SharedMailBox():
             mail = mail[idx]
             index = unq_index
         else:
-            uni,inv = torch.unique(index,return_inverse = True)
-            perm = torch.arange(inv.size(0), dtype=inv.dtype, device=inv.device)
-            perm = inv.new_empty(uni.size(0)).scatter_(0, inv, perm)
-            index = index[perm]
-            mail = mail[perm]
-            mail_ts = mail_ts[perm]
-            #unq_index,inv = torch.unique(index,return_inverse = True)
-            #print(inv.shape,Reduce_score.shape)
-            #max_score,idx = torch_scatter.scatter_max(Reduce_score,inv,0)
-            #mail_ts = mail_ts[idx]
-            #mail = mail[idx]
-            #index = unq_index
+            unq_index,inv = torch.unique(index,return_inverse = True)
+            print(inv.shape,Reduce_score.shape)
+            max_score,idx = torch_scatter.scatter_max(Reduce_score,inv,0)
+            mail_ts = mail_ts[idx]
+            mail = mail[idx]
+            index = unq_index
             #print('mail {} {}\n'.format(index.shape,mail.shape,mail_ts.shape))
         return index,mail,mail_ts
...
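A self-contained sketch of the unique + scatter_max reduction that the last hunk switches to: among duplicate destination indices, keep the mail entry with the highest Reduce_score. The tensors here are synthetic; only torch and torch_scatter (already used by this file) are assumed.

import torch
import torch_scatter

index = torch.tensor([3, 1, 3, 1])            # destination node ids, with duplicates
score = torch.tensor([0.2, 0.9, 0.7, 0.1])    # stand-in for Reduce_score
mail  = torch.arange(4.).unsqueeze(1)         # one mail row per entry

unq_index, inv = torch.unique(index, return_inverse=True)
max_score, idx = torch_scatter.scatter_max(score, inv, 0)
print(unq_index)   # tensor([1, 3])
print(mail[idx])   # rows 1 and 2: the highest-scoring mail per unique id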