zhlj / BTS-MTGNN · Commits

Commit 4fa85d33, authored Nov 14, 2024 by xxx

    synchronization

parent ab4e56d0

Showing 3 changed files with 95 additions and 53 deletions:

    examples/result.py           +69  -33
    examples/test_all.sh          +7   -7
    examples/train_boundery.py   +19  -13
examples/result.py

@@ -2,53 +2,89 @@ import matplotlib.pyplot as plt
 import numpy as np
 import torch
 # read the file contents
+import os
 probability_values = [0.1]  #[0.1,0.05,0.01,0]
-data_values = ['WikiTalk']  # data read from the files
+data_values = ['WIKI', 'LASTFM', 'WikiTalk', 'StackOverflow', 'GDELT']  # data read from the files
-seed = ['12357']  #,'12347','63377','53473','54763']
+seed = ['13357', '12347', '53473', '54763', '12347', '63377', '53473', '54763']
 partition = 'ours_shared'
 # read the data from the files; the data is assumed to be stored in data.txt
 #all/"$data"/"$partitions"-ours_shared-0.01-"$mem"-"$ssim"-"$sample".out
-partitions = 4
+partitions = 12
 topk = 0.01
 mem = 'historical-0.3'  #'historical'
-model0 = 'APAN'
+model0 = 'TGN'
 def average(l):
     return sum(l)/len(l)
-for data in data_values:
-    ap_list = []
-    comm_list = []
-    for sd in seed:
-        for p in probability_values:
-            if data == 'WIKI' or data == 'LASTFM':
-                model = model0
-            else:
-                model = model0 + '_large'
-            if p == 1:
-                file = 'all_{}/{}/{}/{}-{}-{}-{}-recent.out'.format(sd, data, model, partitions, partition, topk, mem)
-            else:
-                file = 'all_{}/{}/{}/{}-{}-{}-{}-boundery_recent_decay-{}.out'.format(sd, data, model, partitions, partition, topk, mem, p)
-            prefix = "val ap:"
-            max_val_ap = 0
-            test_ap = 0
-            with open(file, 'r') as file:
-                #print(file)
-                for line in file:
-                    if line.find('Epoch 50:') != -1:
-                        break
-                    if line.find(prefix) != -1:
-                        pos = line.find(prefix) + len(prefix)
-                        posr = line.find(' ', pos)
-                        #print(line[pos:posr])
-                        val_ap = float(line[pos:posr])
-                        pos = line.find("test ap ") + len("test ap ")
-                        posr = line.find(' ', pos)
-                        #print(line[pos:posr])
-                        _test_ap = float(line[pos:posr])
-                        if (val_ap > max_val_ap):
-                            max_val_ap = val_ap
-                            test_ap = _test_ap
-            ap_list.append(test_ap)
-    print('data {} model {} ap: {}'.format(data, model, ap_list))
+for p in probability_values:
+    for data in data_values:
+        ap_list = []
+        comm_list = []
+        test_time = []
+        train_time = []
+        total_communication = []
+        shared_synchronize = []
+        for sd in seed:
+            if data == 'WIKI' or data == 'LASTFM':
+                model = model0
+            else:
+                model = model0 + '_large'
+            if p == 1:
+                file = '../examples/all_{}/{}/{}/{}-{}-{}-{}-recent.out'.format(sd, data, model, partitions, partition, topk, mem)
+            else:
+                file = '../examples/all_{}/{}/{}/{}-{}-{}-{}-boundery_recent_decay-{}.out'.format(sd, data, model, partitions, partition, topk, mem, p)
+            #print(file)
+            prefix = "val ap:"
+            max_val_ap = 0
+            test_ap = 0
+            #if
+            if os.path.exists(file):
+                with open(file, 'r') as file:
+                    _total_communication = []
+                    _shared_synchronize = []
+                    for line in file:
+                        #if line.find('Epoch 50:') != -1:
+                        #    break
+                        if line.find(prefix) != -1:
+                            pos = line.find(prefix) + len(prefix)
+                            posr = line.find(' ', pos)
+                            #print(line[pos:posr])
+                            val_ap = float(line[pos:posr])
+                            pos = line.find("test ap ") + len("test ap ")
+                            posr = line.find(' ', pos)
+                            #print(line[pos:posr])
+                            _test_ap = float(line[pos:posr])
+                            if (val_ap > max_val_ap):
+                                max_val_ap = val_ap
+                                test_ap = _test_ap
+                        elif line.find('avg_time ') != -1:
+                            pl = line.find('avg_time ') + len('avg_time ')
+                            pr = line.find(' test time')
+                            train_time.append(float(line[pl:pr]))
+                            test_time.append(float(line[pr + len(' test time '):]))
+                            #print(line)
+                            ap_list.append(test_ap)
+                            total_communication.append(average(_total_communication))
+                            shared_synchronize.append(average(_shared_synchronize))
+                        elif line.find('remote node number tensor([') != -1:
+                            pl = line.find('remote node number tensor([') + len('remote node number tensor([')
+                            pr = line.find('])', pl)
+                            _total_communication.append(int(line[pl:pr]))
+                            #if(p==0):
+                            #print(file)
+                            #print(line)
+                        elif line.find('shared comm tensor([') != -1:
+                            pl = line.find('shared comm tensor([') + len('shared comm tensor([')
+                            pr = line.find('])', pl)
+                            _shared_synchronize.append(int(line[pl:pr]))
+            else:
+                print(file)
+        if len(ap_list) > 0:
+            #print('prob {} data {} model {} remote volume : {} synchronize volume : {}'.format(p, data, model, average(total_communication), average(shared_synchronize)))
+            print('prob {} data {} model {} ap: {} train_time: {} eval time: {} remote volume : {} synchronize volume : {}'.format(p, data, model, average(ap_list), average(train_time), average(test_time), average(total_communication), average(shared_synchronize)))
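The parser above reports, for each run, the test AP taken from the same log line as the highest validation AP seen so far. A minimal self-contained sketch of that selection rule, assuming the "val ap:" / "test ap " log format used by the script (the helper name and the sample log lines are illustrative, not taken from the repository):

    # Sketch: keep the test AP reported on the line with the best validation AP.
    def parse_best_test_ap(lines, val_prefix="val ap:", test_prefix="test ap "):
        best_val_ap, best_test_ap = 0.0, 0.0
        for line in lines:
            v = line.find(val_prefix)
            t = line.find(test_prefix)
            if v == -1 or t == -1:
                continue
            val_ap = float(line[v + len(val_prefix):].split()[0])
            test_ap = float(line[t + len(test_prefix):].split()[0])
            if val_ap > best_val_ap:
                best_val_ap, best_test_ap = val_ap, test_ap
        return best_test_ap

    # Example with made-up log lines in the expected format.
    log = [
        "Epoch 1: val ap:0.91 test ap 0.90",
        "Epoch 2: val ap:0.95 test ap 0.93",
        "Epoch 3: val ap:0.94 test ap 0.96",
    ]
    print(parse_best_test_ap(log))  # 0.93, the test AP at the best val AP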
examples/test_all.sh

@@ -2,14 +2,14 @@
 # ran TaoBao with 4 GPUs
 # define the array variables
 seed=$1
-addr="192.168.1.106"
+addr="192.168.1.107"
 partition_params=("ours")
 #"metis" "ldg" "random")
 #("ours" "metis" "ldg" "random")
-partitions="4"
+partitions="12"
 node_per="4"
-nnodes="1"
+nnodes="3"
-node_rank="0"
+node_rank="1"
 probability_params=("0.1")
 sample_type_params=("boundery_recent_decay")
 #sample_type_params=("recent" "boundery_recent_decay") #"boundery_recent_uniform")

@@ -163,11 +163,11 @@ for data in "${data_param[@]}"; do
       done
     done
 done
 data_param=("StackOverflow" "GDELT")
 for data in "${data_param[@]}"; do
-  model="TGN"
+  model="TGN_large"
   if [ "$data" = "WIKI" ] || [ "$data" = "REDDIT" ] || [ "$data" = "LASTFM" ]; then
-    model="TGN_large"
+    model="TGN"
     #continue
   fi
   #model="APAN"
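The updated values configure this host as node_rank 1 of a 3-node job with 4 processes per node (a world size of 12, consistent with the partitions="12" change), pointing at the new master address 192.168.1.107. The script's actual launch command is not shown in this excerpt; the sketch below only illustrates, as an assumption, how such variables typically map onto torchrun arguments:

    # Hedged illustration: how the shell variables above would typically feed a
    # multi-node torchrun launch. The real launch line is elided from this diff view.
    nnodes, node_per, node_rank = 3, 4, 1              # values set in test_all.sh
    master_addr, master_port = "192.168.1.107", 9337   # matches train_boundery.py defaults

    cmd = [
        "torchrun",
        f"--nnodes={nnodes}",
        f"--nproc_per_node={node_per}",
        f"--node_rank={node_rank}",
        f"--master_addr={master_addr}",
        f"--master_port={master_port}",
        "train_boundery.py",
        # dataset/model/partition flags would follow here
    ]
    print(" ".join(cmd))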
examples/train_boundery.py

@@ -110,7 +110,7 @@ if not 'WORLD_SIZE' in os.environ:
     os.environ["WORLD_SIZE"] = str(args.world_size)
     os.environ["LOCAL_RANK"] = str(args.local_rank)
 if not 'MASTER_ADDR' in os.environ:
-    os.environ["MASTER_ADDR"] = '192.168.2.107'
+    os.environ["MASTER_ADDR"] = '192.168.1.107'
 if not 'MASTER_PORT' in os.environ:
     os.environ["MASTER_PORT"] = '9337'

@@ -379,7 +379,10 @@ def main():
     with torch.no_grad():
         total_loss = 0
         signal = torch.tensor([0], dtype=int, device=device)
+        batch_cnt = 0
         for roots, mfgs, metadata in loader:
+            print(batch_cnt)
+            batch_cnt = batch_cnt + 1
             """
             if ctx.memory_group == 0:
                 pred_pos, pred_neg = model(mfgs,metadata,neg_samples=neg_samples)

@@ -450,15 +453,15 @@ def main():
     """
     ap = torch.empty([1])
     auc_mrr = torch.empty([1])
-    if (ctx.memory_group == 0):
+    #if(ctx.memory_group==0):
     world_size = dist.get_world_size()
     ap[0] = torch.tensor(aps).mean()
     auc_mrr[0] = torch.tensor(aucs_mrrs).mean()  #float(aucs_mrrs.clone().mean())
-    print('mode: {} {} {}'.format(mode, ap, auc_mrr))
+    print('mode: {} {} {}\n'.format(mode, ap, auc_mrr))
     dist.all_reduce(ap, group=ctx.gloo_group)
     ap /= ctx.memory_group_size
     dist.all_reduce(auc_mrr, group=ctx.gloo_group)
     auc_mrr /= ctx.memory_group_size
     dist.broadcast(ap, 0, group=ctx.gloo_group)
     dist.broadcast(auc_mrr, 0, group=ctx.gloo_group)
     return ap.item(), auc_mrr.item()

@@ -641,10 +644,13 @@ def main():
         tt.weight_count_remote = 0
         tt.ssim_cnt = 0
         ap, auc = eval('val')
-        torch.cuda.synchronize()
+        print('finish val')
+        #torch.cuda.synchronize()
+        print('start')
         t_test = time.time()
-        test_ap, test_auc = eval('test')
-        torch.cuda.synchronize()
+        print('test')
+        test_ap, test_auc = 0, 0  #eval('test')
+        #torch.cuda.synchronize()
         t_test = time.time() - t_test
         total_test_time += t_test
         test_ap_list.append((test_ap, test_auc))
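In the eval() hunk above, each rank computes a local mean AP/AUC, sums it across the memory group with dist.all_reduce, divides by the group size, and then broadcasts rank 0's value so every rank returns the same number. A minimal sketch of that aggregation pattern, assuming an already-initialized gloo process group (the helper name and arguments are illustrative; the real code uses ctx.gloo_group and ctx.memory_group_size):

    # Sketch of the metric-aggregation pattern used in eval().
    import torch
    import torch.distributed as dist

    def average_metric(local_values, group, group_size):
        metric = torch.empty([1])
        metric[0] = torch.tensor(local_values).mean()  # local mean (e.g. per-batch APs)
        dist.all_reduce(metric, group=group)           # sum the local means across ranks
        metric /= group_size                           # turn the sum into a group mean
        dist.broadcast(metric, 0, group=group)         # share rank 0's value with everyone
        return metric.item()

    # Usage (inside an initialized process group, e.g. dist.init_process_group("gloo")):
    # ap = average_metric(aps, group=gloo_group, group_size=dist.get_world_size(gloo_group))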