zhlj / BTS-MTGNN · Commits

Commit d7bc324c, authored Jan 20, 2025 by zlj
increament mtgnn

parent abb7e9e8
Showing 6 changed files with 652 additions and 19 deletions (+652 / −19):
examples/convergence.pdf        +0    −0
examples/draw_convergence.py    +425  −0
examples/train_boundery.py      +1    −1
linear_attention.md             +31   −0
starrygl/module/layers.py       +192  −15
starrygl/module/modules.py      +3    −3
examples/convergence.pdf  0 → 100644 (new file)
File added
examples/draw_convergence.py  0 → 100644 (new file)
labels = ['4', '8', '12', '16']
methods = ['TGL', 'MSPipe', 'DistTGL', 'Ours']
table_label = ['TGL-1', 'TGL-4', 'MSPipe-4', 'DistTGL-4', 'Ours-4',
               'MSPipe-8', 'DistTGL-8', 'Ours-8',
               'MSPipe-12', 'DistTGL-12', 'Ours-12',
               'MSPipe-16', 'DistTGL-16', 'Ours-16']
table_label_no = ['TGL-1', 'TGL-4', 'MSPipe-4', 'Ours-4',
                  'MSPipe-8', 'Ours-8',
                  'MSPipe-12', 'Ours-12',
                  'MSPipe-16', 'Ours-16']
dataset_label = ['WIKI', 'LASTFM', 'WikiTalk', 'StackOverflow', 'GDELT']
table_ap_tgn = [
    [0.9827, 0.8023, 0.9611, 0.9574, 0.9837],  #tgl-1
    [0.9808, 0.7820, 0.9632, 0.9547, 0.9770],  #tgl-4
    [0.9709, 0.8277, 0.9552, 0.9548, 0.9832],  #mspipe
    [0.9682, 0.7683, 0.9355, 0.9273, 0.9860],  #
    [0.9770, 0.8998, 0.9765, 0.9760, 0.9866],
    [0.9703, 0.7846, 0.9361, 0.9216, 0.9867],
    [0.9648, 0.7419, 0.9301, 0.9328, 0.9860],
    [0.9768, 0.9329, 0.9800, 0.9792, 0.9897],
    [0.9698, 0.7707, 0.9301, 0.9111, 0.9868],
    [0.9624, 0.7713, 0.9234, 0.9074, 0.9871],
    [0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
    [0.9699, 0.7868, 0.9257, 0.9090, 0.9870],
    [0.9646, 0.6999, 0.8997, 0.9031, 0.9862],
    [0.9659, 0.9498, 0.9816, 0.9830, 0.9916]
]
table_ap_jodie = [
    [0.9080, 0.6202, 0.9578, 0.9620, 0.9831],  #tgl-1
    [0.8586, 0.6116, 0.9463, 0.9494, 0.9825],  #tgl-4
    [0.785868, 0.4998, 0.9123, 0.8856, 0.9831],
    [0.9004, 0.6182, 0.9541, 0.9612, 0.9901],
    [0.7735, 0.5566, 0.8252, 0.8055, 0.9812],
    [0.8769, 0.6131, 0.9495, 0.9580, 0.9914],
    [0.7747, 0.5511, 0.7983, 0.7806, 0.9811],
    [0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
    [0.7691, 0.5427, 0.7858, 0.7684, 0.9810],
    [0.8548, 0.6044, 0.9418, 0.9540, 0.9908]
]
table_ap_apan = [
    [0.9415, 0.5659, 0.9043, 0.8752, 0.9612],
    [0.8834, 0.5711, 0.9136, 0.8886, 0.9547],
    [0.9124, 0.5975, 0.9147, 0.8615, 0.9361],
    [0.9190, 0.6741, 0.9298, 0.9400, 0.9839],
    [0.6752, 0.5509, 0.7030, 0.6824, 0.7252],
    [0.8982, 0.6268, 0.9398, 0.9402, 0.9876],
    [0.6011, 0.5281, 0.6402, 0.6340, 0.6012],
    [0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
    [0.5449, 0.5156, 0.6004, 0.5898, 0.6163],
    [0.8658, 0.6682, 0.9419, 0.9522, 0.9880]
]
table_train_tgn = [
    [1.7701, 14.7205, 88.1896, 938.8964, 2001.318524],
    [0.6651, 6.185821226, 35.9636, 328.1245, 854.80837],
    [0.8258, 3.7705, 15.3471, 141.1648, 465.2264],
    [0.4001, 3.5691, 16.5933, 167.2491, 445.4328],
    [0.6327, 4.4322, 14.7757, 132.7098208, 537.1398258],
    [0.7085, 2.2837, 8.6870, 76.9830, 250.0945],
    [0.2992, 2.0908, 9.0872, 97.8960, 238.3315],
    [0.2772, 2.3960, 7.1694, 72.0319, 251.7710],
    [0.598028193, 1.708853126, 6.600371795, 59.73907948, 172.2328092],
    [0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
    [0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
    [0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
    [0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
    [0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
]
table_train_jodie = [
    [0.7056, 5.9280, 16.0147, 134.6145, 376.2186],
    [0.2494, 2.0697, 6.4145, 52.0684, 147.3938],
    [0.5604, 2.0900, 4.5566, 33.3442, 84.38458259],
    [0.3246, 2.5305, 5.7808, 48.8147, 148.8899],
    [0.518362013, 1.148036031, 2.871488156, 19.42606455, 48.72649877],
    [0.1708, 1.3234, 3.1610, 26.3006, 78.1362],
    [0.495687011, 0.929189181, 2.223804102, 15.14540188, 36.91594124],
    [0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
    [0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
    [0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
]
table_train_apan = [
    [1.5406, 13.5400, 118.8137, 1032.4779, 1609.9425],
    [0.6653, 5.6160, 41.4232, 368.6106, 1045.6425],
    [0.7319, 2.9211, 25.9092, 238.7242, 643.3855],
    [0.4428, 3.3608, 11.3518, 96.5700, 244.8932],
    [0.6765, 1.9470, 14.0065, 123.8830, 331.8195],
    [0.2521, 2.0037, 8.5989, 74.9677, 139.2487],
    [0.870823467, 2.671342006, 14.96090715, 99.97286532, 233.0780075],
    [0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
    [0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
    [0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
]
table_eval_tgn = [
    [0.7499, 4.8555, 48.0304, 505.1028, 824.7955338],
    [0.2636, 2.09782438, 15.9930, 162.6785, 412.4661068],
    [0.5704, 1.2416, 6.4172, 68.7693, 117.9146],
    [0.2692, 2.0977, 13.2586, 163.4167, 152.3379],
    [0.1166, 0.8080, 3.0650, 28.00140551, 115.4568481],
    [0.5747, 0.8989, 4.0630, 36.5432, 61.79147074],
    [0.2714, 2.0767, 11.6519, 143.2905, 145.8947],
    [0.0627, 0.6212, 2.5015, 23.1395, 80.9373],
    [0.555030217, 0.779903646, 3.302122216, 27.71362416, 42.97800901],
    [0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
    [0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
    [0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
    [0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
    [0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
]
table_eval_jodie = [
    [0.2091, 1.6959, 5.9168, 48.6228, 129.0989],
    [0.0719, 0.6373, 2.1611, 18.0129, 50.1390],
    [0.5078, 0.4018, 2.7975, 15.9332, 28.17862575],
    [0.0633, 0.5023, 1.1726, 9.4948, 42.8286],
    [0.502160359, 0.669311229, 2.076292474, 9.030235369, 16.12763978],
    [0.0377, 0.3370, 1.2482, 10.2044, 34.6001],
    [0.525156487, 0.617744214, 1.698658919, 6.896312807, 12.20505439],
    [0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
    [0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
    [0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
]
table_eval_apan = [
    [0.5907, 4.8906, 37.4472, 332.8560, 516.5578],
    [0.2405, 1.9909, 11.9007, 105.6109, 298.5957],
    [0.6249, 1.4006, 8.5734, 66.5308, 161.1591],
    [0.0819, 0.6588, 2.2410, 19.3600, 100.0944],
    [0.7001, 1.1295, 4.9139, 38.4971, 68.67845647],
    [0.0709, 0.8897, 4.9156, 41.6653, 114.7303],
    [0.860863774, 1.248189147, 4.57928021, 30.29166182, 69.67092261],
    [0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
    [0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
    [0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
]
def get_data(model, data, method, part, mode):
    table = globals()['table_{}_{}'.format(mode, model.lower())]
    index_name = method + '-{}'.format(part)
    if model == 'TGN':
        index = table_label.index(index_name) if index_name in table_label else -1
        if index == -1:
            return None
        else:
            #print(index,dataset_label.index(data),method)
            return table[index][dataset_label.index(data)]
    else:
        index = table_label_no.index(index_name) if index_name in table_label_no else -1
        if index == -1:
            return None
        else:
            return table[index][dataset_label.index(data)]
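
# A quick sanity check of get_data added for illustration (commented out; it is
# not part of the original script). 'Ours-8' sits at index 7 of table_label and
# WIKI is column 0 of dataset_label, so the first call reads table_train_tgn[7][0].
# print(get_data('TGN', 'WIKI', 'Ours', '8', 'train'))    # 0.2772
# print(get_data('TGN', 'WIKI', 'MSPipe', '8', 'train'))  # 0.7085
# print(get_data('TGN', 'WIKI', 'TGL', '8', 'train'))     # None: 'TGL-8' is not tabulated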
import matplotlib.pyplot as plt
import numpy as np
import torch
# read the file contents
import os
probability_values = [0.1]  #][0.1,0.05,0.01,0]
data_values = ['WIKI', 'LASTFM', 'WikiTalk', 'StackOverflow', 'GDELT']
# storage for the data read from the result files
seed = ['13357', '12347', '53473', '54763', '12347', '63377', '53473', '54763']
partition = 'ours_shared'
import math
# read the data from file; the data is assumed to be stored in data.txt
#all/"$data"/"$partitions"-ours_shared-0.01-"$mem"-"$ssim"-"$sample".out
def read_ours(file, data=None):
    test_ap_list = []
    time_list = []
    prefix = 'val ap:'
    cntv = 0
    cntt = 0
    max_val_ap = 0
    test_ap = 0
    final_test_ap = 0
    final_average_time = 0
    if os.path.exists(file):
        with open(file, 'r') as file:
            for line in file:
                #if line.find('Epoch 50:')!=-1 and (data not in ['WIKI','LASTFM']):
                #    break
                if line.find(prefix) != -1:
                    pos = line.find(prefix) + len(prefix)
                    posr = line.find(' ', pos)
                    #print(line[pos:posr])
                    val_ap = float(line[pos:posr])
                    pos = line.find("test ap ") + len("test ap ")
                    posr = line.find(' ', pos)
                    #print(line[pos:posr])
                    _test_ap = float(line[pos:posr])
                    if cntv == 0:
                        test_ap_list.append(_test_ap)
                    cntv = (cntv + 1) % 4
                    if (val_ap > max_val_ap):
                        max_val_ap = val_ap
                        final_test_ap = _test_ap
                if line.find('prep time') != -1:
                    pos = line.find('prep time') + len('prep time:')
                    posr = line.find('s', pos)
                    _iter_time = float(line[pos:posr])
                    #print(_iter_time,cntt)
                    if len(time_list) > 0:
                        time_list.append(_iter_time + time_list[-1])
                        #print(time_list[-1])
                    else:
                        time_list.append(_iter_time)
                        #print(cntt)
                #elif line.find('avg_time ')!=-1:
                #    pl = line.find('avg_time ') + len('avg_time ')
                #    pr = line.find(' test time')
                #    train_time.append(float(line[pl:pr]))
                #    test_time.append(float(line[pr+len(' test time '):]))
                #    print(line)
                #    ap_list.append(test_ap)
                #    total_communication.append(average(_total_communication))
                #    shared_synchronize.append(average(_shared_synchronize))
                #elif line.find('remote node number tensor([')!=-1:
                #    pl = line.find('remote node number tensor([')+len('remote node number tensor([')
                #    pr = line.find('])',pl)
                #    _total_communication.append(int(line[pl:pr]))
                #    if(p==0):
                #        print(file)
                #        print(line)
                #elif line.find('shared comm tensor([')!=-1:
                #    pl = line.find('shared comm tensor([')+len('shared comm tensor([')
                #    pr = line.find('])',pl)
                #    _shared_synchronize.append(int(line[pl:pr]))
    return test_ap_list, time_list
def check_ours(file):
    #print(file)
    if os.path.exists(file):
        with open(file, 'r') as file:
            for line in file:
                if line.find('avg_time ') != -1:
                    return True
    return False

def read_dist_tgl(file):
    test_ap_list = []
    time_list = []
    final_test_ap = 0
    final_average_time = 0
    if os.path.exists(file):
        with open(file, 'r') as file:
            for line in file:
                #if line.find('Epoch 50:')!=-1:
                #    break
                if line.find('test AP') != -1:
                    posl = line.find('test AP') + len('test AP:')
                    posr = line.find(' ', posl)
                    ap = float(line[posl:posr])  #:0.784809 '
                    test_ap_list.append(ap)
                if line.find('train loss') != -1:
                    posl = line.find('train time:') + len('train time:')
                    #posr = line.find('s',posl)
                    time = float(line[posl:])
                    #print(time)
                    if len(time_list) > 0:
                        time_list.append(time + time_list[-1])
                    else:
                        time_list.append(time)
    return test_ap_list[:len(time_list)], time_list

def check_dist_tgl(file):
    if os.path.exists(file):
        with open(file, 'r') as file:
            for line in file:
                if line.find('avg train time ') != -1:
                    return True
    return False

def read_tgl(file):
    test_ap_list = []
    time_list = []
    if os.path.exists(file):
        with open(file, 'r') as file:
            for line in file:
                #if line.find('Epoch 50:')!=-1:
                #    break
                if line.find('test ap') != -1:
                    posl = line.find('test ap') + len('test ap:')
                    posr = line.find(' ', posl)
                    ap = float(line[posl:posr])  #:0.784809 '
                    test_ap_list.append(ap)
                    #print(line)
                if line.find('total time') != -1:
                    #print(line)
                    posl = line.find('total time') + len('total time:')
                    posr = line.find('s', posl)
                    time = float(line[posl:posr])
                    if len(time_list) > 0:
                        time_list.append(time + time_list[-1])
                        #print(len(time_list))
                    else:
                        time_list.append(time)
    return test_ap_list[:len(time_list)], time_list

def check_tgl(file):
    if os.path.exists(file):
        with open(file, 'r') as file:
            for line in file:
                if line.find('avg time') != -1:
                    return True
    return False

def check_mspipe(file):
    #print(file)
    if os.path.exists(file):
        with open(file, 'r') as file:
            for line in file:
                if line.find('Avg epoch time') != -1:
                    return True
    return False

#Best val AP:
#Avg epoch time
def read_mspipe(file):
    test_ap_list = []
    time_list = []
    if os.path.exists(file):
        with open(file, 'r') as file:
            for line in file:
                #if line.find('Epoch 50:')!=-1:
                #    break
                if line.find('Test ap') != -1:
                    posl = line.find('Test ap') + len('Test ap:')
                    posr = line.find(' ', posl)
                    ap = float(line[posl:posr])  #:0.784809 '
                    test_ap_list.append(ap)
                if line.find('Train time ') != -1:
                    posl = line.find('Train time ') + len('Train time ')
                    posr = line.find('s', posl)
                    time = float(line[posl:posr])
                    if len(time_list) > 0:
                        time_list.append(time + time_list[-1])
                    else:
                        time_list.append(time)
    return test_ap_list, time_list
def draw_table(dict, model, part):
    if model == 'TGN' and part == '4':
        method_list = ['TGL', 'DistTGL', 'MSPipe', 'Ours']
    elif model != 'TGN' and part == '4':
        method_list = ['TGL', 'MSPipe', 'Ours']
    elif model == 'TGN':
        method_list = ['DistTGL', 'MSPipe', 'Ours']
    else:
        method_list = ['MSPipe', 'Ours']
    for id, method in enumerate(method_list):
        char_list = ['&']
        if (id != 0):
            char_list += '&'
        char_list += method
        for d in ['WIKI', 'LASTFM', 'WikiTalk', 'StackOverflow', 'GDELT']:
            char_list += ['&', str(f"{dict[(model,part,d,method)]:.2f}")]
            speed_up = dict[(model, part, d, method)] / dict[(model, part, d, 'Ours')]
            char_list += ['&', str(f"{speed_up:.2f}"), 'x']
        print('{}'.format(''.join(char_list)))
    print('\n')
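
# Hypothetical usage of draw_table added for illustration (commented out; not
# part of the original script): assemble the (model, partition, dataset, method)
# -> training-time dictionary from the tables via get_data, then print one block
# of LaTeX table rows for TGN trained with 8 partitions.
# times = {}
# for method in ['DistTGL', 'MSPipe', 'Ours']:
#     for d in dataset_label:
#         times[('TGN', '8', d, method)] = get_data('TGN', d, method, '8', 'train')
# draw_table(times, 'TGN', '8')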
data = ['WIKI', 'REDDIT', 'LASTFM', 'WikiTalk']
parts = ['8']
models = ['TGN_large']
seed = ['12347', '12357', '13357', '53473', '63377', '63457', '']

def find_ours_file(d, model, neighbor):
    path = 'result_neighbor/{}/{}/'.format(d, model)
    file_name = path + '1-ours-0.01-local--recent-{}.out'.format(neighbor)
    print(neighbor, file_name)
    if check_ours(file_name):
        return file_name
    return None

dict = {}
color_map = {'TGL': 'blue', 'DistTGL': 'green', 'MSPipe': 'red', 'Ours': 'Orange'}
for model in models:
    plt.clf()
    fig, axes = plt.subplots(1, len(data), figsize=(12, 2), sharey=False)
    for axx, data_item in enumerate(data):
        ax = axes[axx]
        d = data_item
        for neighbor in ['2', '4', '8', '16']:
            print(neighbor)
            ours_file = find_ours_file(d, model, neighbor)
            ours_file = read_ours(ours_file, d)
            #print(ours_file)
            ap_list = torch.tensor(ours_file[0])
            line, = ax.plot(range(1, len(ours_file[0]) + 1), ours_file[0], label=neighbor)
        ax.set_xlabel('Training Epoch', fontsize=12)
        ax.set_ylabel('Test AP', fontsize=12)
        ax.set_title('{}'.format(d))
        if d == 'LASTFM':
            ax.set_ylim([0.6, 0.85])
        elif d == 'WikiTalk' or d == 'StackOverflow':
            if model == 'TGN_large':
                ax.set_xlim([0, 20])
                ax.set_ylim([0.95, 1])
            elif model == 'APAN':
                ax.set_xlim([0, 20])
                ax.set_ylim([0.6, 1])
            else:
                ax.set_xlim([0, 20])
                ax.set_ylim([0.7, 1])
        else:
            if model == 'TGN_large':
                if d == 'WIKI':
                    ax.set_ylim([0.97, 0.99])
                elif d == 'REDDIT':
                    ax.set_ylim([0.98, 0.99])
            elif model == 'APAN':
                pass
                #ax.set_ylim([0.4,1])
            else:
                ax.set_ylim([0.9, 1])
        ax.grid(True, linestyle='--', color='gray', linewidth=0.5)
        if axx == 0:
            fig.legend(fontsize=11, loc='upper center', ncol=4)
    #print('{} {} {} {} {} {} {}\n'.format(model,d,part,tgl_time,disttgl_time,mspipe_time,ours_time))
    #print
    #handles = {line.get_label(): line for line in lines.values()}.values()
    plt.tight_layout(rect=[0, 0, 1, 0.9])
    plt.savefig('convergence.pdf'.format(model))
examples/train_boundery.py

...
@@ -353,7 +353,7 @@ def main():
     print('dim_node {} dim_edge {}\n'.format(gnn_dim_node, gnn_dim_edge))
     avg_time = 0
     if use_cuda:
-        model = GeneralModel(gnn_dim_node, gnn_dim_edge, sample_param, memory_param, gnn_param, train_param, graph.ids.shape[0], mailbox).cuda()
+        model = GeneralModel(gnn_dim_node, gnn_dim_edge, sample_param, memory_param, gnn_param, train_param, graph.ids.shape[0], mailbox, num_node=graph.num_nodes, num_edge=graph.num_edges).cuda()
         device = torch.device('cuda')
     else:
         model = GeneralModel(gnn_dim_node, gnn_dim_edge, sample_param, memory_param, gnn_param, train_param, graph.ids.shape[0], mailbox)
...
linear_attention.md  0 → 100644 (new file)

$h_u(t) = \sum_{(v,\tau) \in N_u(t)} \phi(q_u)^T \phi(k_v) x_v$

In particular,

$h_u(t) = \sum_{(v,\tau) \in N_u(t)} \text{softmax}(q_u)^T \text{edgesoftmax}(k_v) x_v$

$w_u(t^-) = \sum_{(v,\tau) \in N_u(t^-)} \phi(k_v, t^-) x_v$

$\phi(k_v, t^-) = e^{k_v - \text{edgemax}(k_v)} / \text{edgesum}(\exp(k_v - \text{edgemax}(k_v)))$

$w_u(t) = \sum_{(v,\tau) \in N_u(t)} \phi(k_v, t) x_v$

In particular,

$\text{newmax} = \max\big(\text{lastmax}(k_v),\ \max_{(v,\tau) \in N_u(t) \setminus N_u(t^-)} k_v\big)$

$\text{newsum} = \text{lastsum}(\exp(k_v - \text{lastmax}(k_v))) \cdot \exp(\text{lastmax}(k_v) - \text{newmax}) + \sum_{(v,\tau) \in N_u(t) \setminus N_u(t^-)} e^{k_v - \text{newmax}}$

$h_u(t^-) = \sum_{(v,\tau) \in N_u(t^-)} \phi(q_u)^T \phi(k_v) x_v$

$h_u(t) = \phi(q_u)^T \Big\{ \sum_{(v,\tau) \in N_u(t^-)} \phi(k_v, t^-) \cdot \text{lastsum}(\exp(k_v - \text{lastmax}(k_v))) \exp(\text{lastmax}(k_v) - \text{edgemax}(k_v)) / \text{newsum} \cdot x_v \ +\ \sum_{(v,\tau) \in N_u(t) \setminus N_u(t^-)} e^{k_v - \text{newmax}} / \text{newsum} \cdot x_v \Big\}$

$x_v = s_v + \phi(wt + b)$

$[\cos(w(t+\Delta t)+b),\ \sin(w(t+\Delta t)+b)] = [\cos(wt+b)\cos(w\Delta t) - \sin(wt+b)\sin(w\Delta t),\ \sin(wt+b)\cos(w\Delta t) + \cos(wt+b)\sin(w\Delta t)]$
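
The incremental max/sum bookkeeping above can be checked numerically. The following is a minimal sketch added for illustration (it is not part of the commit; the function name streaming_softmax_aggregate and the two-batch example are introduced here): folding the neighbors of N_u(t^-) and the newly arrived neighbors of N_u(t) \ N_u(t^-) through a running (max, sum, weighted-sum) state reproduces a softmax aggregation recomputed from scratch over all neighbors.

import torch

def streaming_softmax_aggregate(last_max, last_sum, last_num, new_k, new_x):
    # new_k: (m,) keys of newly arrived neighbors; new_x: (m, d) their values.
    # last_max / last_sum are the running "lastmax" / "lastsum" statistics and
    # last_num is the running numerator sum_v exp(k_v - lastmax) * x_v.
    new_max = torch.maximum(last_max, new_k.max())
    scale = torch.exp(last_max - new_max)               # re-base the old statistics to the new max
    new_sum = last_sum * scale + torch.exp(new_k - new_max).sum()
    new_num = last_num * scale + (torch.exp(new_k - new_max).unsqueeze(-1) * new_x).sum(0)
    return new_max, new_sum, new_num                    # h_u(t) = new_num / new_sum

# Fold two batches of neighbor keys/values, then compare against a one-shot softmax.
k1, x1 = torch.randn(3), torch.randn(3, 4)
k2, x2 = torch.randn(5), torch.randn(5, 4)
state = (torch.tensor(float('-inf')), torch.tensor(0.0), torch.zeros(4))
state = streaming_softmax_aggregate(*state, k1, x1)
state = streaming_softmax_aggregate(*state, k2, x2)
h_streamed = state[2] / state[1]
h_full = (torch.softmax(torch.cat([k1, k2]), dim=0).unsqueeze(-1) * torch.cat([x1, x2])).sum(0)
assert torch.allclose(h_streamed, h_full, atol=1e-5)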
starrygl/module/layers.py

...
@@ -2,7 +2,7 @@ from os.path import abspath, join, dirname
 import os
 import sys
 from os.path import abspath, join, dirname
+import torch_scatter
 from starrygl.distributed.utils import DistIndex
 sys.path.insert(0, join(abspath(dirname(__file__))))
 import torch
...
@@ -10,7 +10,73 @@ import dgl
 import math
 import numpy as np
 from starrygl.sample.count_static import time_count as tt
+
+class AggregatorCache():
+    def __init__(self, node_num, edge_num, dim, num_head, kernel='softmax'):
+        #self.aggregator=torch.zeros(node_num,dim,device=torch.device('cuda'))
+        #self.last_time = torch.zeros(node_num,dtype=torch.float,device=torch.device('cuda'))
+        #self.attention_list = torch.zeros(edge_num,dim,num_head,device=torch.device('cuda'))
+        #self.attention_max = torch.zeros(node_num,dim,device=torch.device('cuda'))
+        #self.attention_sum = torch.zeros(node_num,dim,device=torch.device('cuda'))
+        #self.last_kv = torch.zeros(node_num,dim,device=torch.device('cuda'))
+        self.last_attention_max = torch.zeros(node_num, dim, device=torch.device('cuda'))
+        self.last_attention_sum = torch.zeros(node_num, dim, device=torch.device('cuda'))
+        self.last_kv = torch.zeros(node_num, dim, dim, device=torch.device('cuda'))
+        self.kernel = kernel
+
+    def get_historical_aggregator(self, b, _q, _k, _v, time_projection_matrix):
+        with torch.no_grad():
+            srcid = b.srcdata['ID'][:b.num_dst_nodes()][b.edges()[1]]
+            #last_aggregation = self.aggregator[srcid]
+            last_attention_max = self.last_attention_max[srcid]
+            last_attention_sum = self.last_attention_sum[srcid]
+            last_kv = self.last_kv[srcid]
+        if self.kernel == 'softmax':
+            increament_max = dgl.ops.copy_e_max(b, _k)
+            new_attention_max = torch.max(last_attention_max, increament_max)
+            new_attention_sum = last_attention_sum * torch.exp(last_attention_max - new_attention_max) \
+                + dgl.ops.copy_e_sum(b, torch.exp(_k - new_attention_max[b.edges()[1]]))
+            increament_kv = dgl.ops.copy_e_sum(b, torch.einsum('nj,nk->njk',
+                (torch.exp(_k - new_attention_max[b.edges()[1]]) / new_attention_sum[b.edges()[1]]), _v))
+            new_kv = last_kv * last_attention_sum * torch.exp(last_attention_max - new_attention_max) / new_attention_sum + increament_kv
+            h_u = torch.einsum('ni,nik->nk', _q, new_kv)
+            b.srcdata['h'] = h_u
+            with torch.no_grad():
+                unq_index, inv = torch.unique(b.srcdata['ID'][:b.num_dst_nodes()], return_inverse=True)
+                _, idx = torch_scatter.scatter_max(b.srcdata['ts'][:b.num_dst_nodes()], inv, 0)
+                self.last_kv = new_kv
+                self.last_attention_max = new_attention_max
+                self.last_attention_sum = new_attention_sum
+        #calculate the max value for both new insert attention and older attention
+        #next_h = b.srcdata['h'][srcid]
+        #att_v_max_new = dgl.ops.copy_e_max(b,new_attention)
+        # delta_Q = (w_q(next_h-last_h))
+        # print(delta_Q,last_attention_max)
+        # historical_attention_max = delta_Q*(last_attention_max)
+        # print(att_v_max_new)
+        # max_QK = torch.max((att_v_max_new,historical_attention_max))
+        # historical_attention_sum = (delta_Q+1)*last_attention_sum/torch.exp(max_QK-historical_attention_max)
+        # last_aggregation = (delta_Q+1)*last_aggregation/torch.exp(max_QK-historical_attention_max)
+        # att_v_new = torch.exp(torch.exp(dgl.ops.e_sub_v(b,new_attention,max_QK)))
+        # att_v_sum = dgl.ops.copy_e_sum(b,att_v_new)+historical_attention_sum
+        # att_new = dgl.ops.e_div_v(b,att_v_new,torch.clamp_min(att_v_sum ,1))
+        # att = self.att_dropout(att_new)
+        # V = torch.reshape(V*att[:, :, None], (V.shape[0], -1)) + last_aggregation
+        # b.edata['v'] = V
+        # b.update_all(dgl.function.copy_e('v', 'm'), dgl.function.sum('m', 'h'))
+        # b.srcdata['h'] += last_aggregation
+        # with torch.no_grad():
+        #     unq_index,inv = torch.unique(b.srcdata['ID'][:b.num_dst_nodes()],return_inverse = True)
+        #     _,idx = torch_scatter.scatter_max(b.srcdata['ts'][:b.num_dst_nodes()],inv,0)
+        #     self.aggregator[idx] = b.srcdata['h'][idx]#b.srcdata['h']
+        #     self.attention_max = max_QK[idx]#torch.zeros(node_num,dim)
+        #     self.attention_sum = att_v_sum[idx]#torch.zeros(node_num,dim)
+        #     self.last_h = next_h #torch.zeros(node_num,dim)
+        #att_v_max = torch.max(att_v_max_new,last_attention_max)
+        #att_e_sub_max = torch.exp(dgl.ops.e_sub_v(b,new_attention,att_v_max))
+        #att = dgl.ops.e_div_v(b,att_e_sub_max,torch.clamp_min(dgl.ops.copy_e_sum(b,att_e_sub_max)+last_attention ,1))
+
 class TimeEncode(torch.nn.Module):
     def __init__(self, dim, alpha=None, beta=None, parameter_requires_grad: bool = True):
...
@@ -174,11 +240,111 @@ class MixerMLP(torch.nn.Module):
         self.block_padding(b)
         #return x
+# class TransformerAttentionLayer0(torch.nn.Module):
+class FastAttention(torch.nn.Module):
+    def __init__(self, dim_node_feat, dim_edge_feat, dim_time, num_head, dropout, att_dropout, dim_out, combined=False, num_node=None, num_edge=None):
+        super(FastAttention, self).__init__()
+        self.num_head = num_head
+        self.dim_node_feat = dim_node_feat
+        self.dim_edge_feat = dim_edge_feat
+        self.dim_time = dim_time
+        self.dim_out = dim_out
+        self.dropout = torch.nn.Dropout(dropout)
+        self.att_dropout = torch.nn.Dropout(att_dropout)
+        self.att_act = torch.nn.LeakyReLU(0.2)
+        self.combined = combined
+        if dim_time > 0:
+            self.time_enc = TimeEncode(dim_time)
+        if combined:
+            if dim_node_feat > 0:
+                self.w_q_n = torch.nn.Linear(dim_node_feat, dim_out)
+                self.w_k_n = torch.nn.Linear(dim_node_feat, dim_out)
+                self.w_v_n = torch.nn.Linear(dim_node_feat, dim_out)
+            if dim_edge_feat > 0:
+                self.w_k_e = torch.nn.Linear(dim_edge_feat, dim_out)
+                self.w_v_e = torch.nn.Linear(dim_edge_feat, dim_out)
+            if dim_time > 0:
+                self.w_q_t = torch.nn.Linear(dim_time, dim_out)
+                self.w_k_t = torch.nn.Linear(dim_time, dim_out)
+                self.w_v_t = torch.nn.Linear(dim_time, dim_out)
+        else:
+            if dim_node_feat + dim_time > 0:
+                self.w_q = torch.nn.Linear(dim_node_feat + dim_time, dim_out)
+            self.w_k = torch.nn.Linear(dim_node_feat + dim_edge_feat + dim_time, dim_out)
+            self.w_v = torch.nn.Linear(dim_node_feat + dim_edge_feat + dim_time, dim_out)
+        self.w_out = torch.nn.Linear(dim_node_feat + dim_out, dim_out)
+        self.layer_norm = torch.nn.LayerNorm(dim_out)
+        self.cache = AggregatorCache(num_node, num_edge, dim_out, num_head)
+
+    def forward(self, b):
+        assert(self.dim_time + self.dim_node_feat + self.dim_edge_feat > 0)
+        self.device = b.device
+        if b.num_edges() == 0:
+            return torch.zeros((b.num_dst_nodes(), self.dim_out), device=self.device)
+        if self.dim_time > 0:
+            time_feat = self.time_enc(b.edata['dt'])
+            zero_time_feat = self.time_enc(torch.zeros(b.num_dst_nodes(), dtype=torch.float32, device=self.device))
+        if self.dim_time == 0 and self.dim_node_feat == 0:
+            Q = torch.ones((b.num_edges(), self.dim_out), device=self.device)
+            K = self.w_k(b.edata['f'])
+            V = self.w_v(b.edata['f'])
+        elif self.dim_time == 0 and self.dim_edge_feat == 0:
+            Q = self.w_q(b.srcdata['h'][:b.num_dst_nodes()])[b.edges()[1]]
+            K = self.w_k(b.srcdata['h'][b.edges()[0]])
+            V = self.w_v(b.srcdata['h'][b.edges()[0]])
+        elif self.dim_time == 0:
+            Q = self.w_q(b.srcdata['h'][:b.num_dst_nodes()])[b.edges()[1]]
+            K = self.w_k(torch.cat([b.srcdata['h'][b.edges()[0]], b.edata['f']], dim=1))
+            V = self.w_v(torch.cat([b.srcdata['h'][b.edges()[0]], b.edata['f']], dim=1))
+            #K = self.w_k(torch.cat([b.srcdata['h'][b.num_dst_nodes():], b.edat['f']], dim=1))
+            #V = self.w_v(torch.cat([b.srcdata['h'][b.num_dst_nodes():], b.edat['f']], dim=1))
+        elif self.dim_node_feat == 0 and self.dim_edge_feat == 0:
+            Q = self.w_q(zero_time_feat)[b.edges()[1]]
+            K = self.w_k(time_feat)
+            V = self.w_v(time_feat)
+        elif self.dim_node_feat == 0:
+            Q = self.w_q(zero_time_feat)[b.edges()[1]]
+            K = self.w_k(torch.cat([b.edata['f'], time_feat], dim=1))
+            V = self.w_v(torch.cat([b.edata['f'], time_feat], dim=1))
+        elif self.dim_edge_feat == 0:
+            Q = self.w_q(torch.cat([b.srcdata['h'][:b.num_dst_nodes()], zero_time_feat], dim=1))[b.edges()[1]]
+            K = self.w_k(torch.cat([b.srcdata['h'][b.edges()[0]], time_feat], dim=1))
+            V = self.w_v(torch.cat([b.srcdata['h'][b.edges()[0]], time_feat], dim=1))
+        else:
+            Q = self.w_q(torch.cat([b.srcdata['h'][:b.num_dst_nodes()], zero_time_feat], dim=1))[b.edges()[1]]
+            K = self.w_k(torch.cat([b.srcdata['h'][b.edges()[0]], b.edata['f'], time_feat], dim=1))
+            V = self.w_v(torch.cat([b.srcdata['h'][b.edges()[0]], b.edata['f'], time_feat], dim=1))
+        Q = torch.reshape(Q, (Q.shape[0], self.num_head, -1))
+        K = torch.reshape(K, (K.shape[0], self.num_head, -1))
+        V = torch.reshape(V, (V.shape[0], self.num_head, -1))
+        #att = dgl.ops.edge_softmax(b, self.att_act(torch.sum(Q*K, dim=2)))
+        #att = self.att_dropout(att)
+        #if self.no_projection:
+        Q = Q.softmax(dim=-1)
+        att = self.att_act(torch.sum(Q * dgl.ops.edge_softmax(b, K, dim=0), dim=2))
+        #k_cumsum = K.sum(dim = -2)
+        #D_inv = 1. / torch.einsum('...nd,...d->...n', q, k_cumsum.type_as(q))
+        #context = torch.einsum('...nd,...ne->...de', k, v)
+        #out = torch.einsum('...de,...nd,...n->...ne', context, q, D_inv)
+        V = torch.reshape(V * att[:, :, None], (V.shape[0], -1))
+        if self.cache is None:
+            b.edata['v'] = V
+            b.update_all(dgl.function.copy_e('v', 'm'), dgl.function.sum('m', 'h'))
+        else:
+            self.cache.get_historical_aggregator(b, att, self.w_q, V)
+        if self.dim_node_feat != 0:
+            rst = torch.cat([b.dstdata['h'], b.srcdata['h'][:b.num_dst_nodes()]], dim=1)
+        else:
+            rst = b.dstdata['h']
+        rst = self.w_out(rst)
+        rst = torch.nn.functional.relu(self.dropout(rst))
+        return self.layer_norm(rst)
+
 class TransfomerAttentionLayer(torch.nn.Module):
-    def __init__(self, dim_node_feat, dim_edge_feat, dim_time, num_head, dropout, att_dropout, dim_out, combined=False):
+    def __init__(self, dim_node_feat, dim_edge_feat, dim_time, num_head, dropout, att_dropout, dim_out, combined=False, num_node=None, num_edge=None):
         super(TransfomerAttentionLayer, self).__init__()
         self.num_head = num_head
         self.dim_node_feat = dim_node_feat
...
@@ -210,7 +376,7 @@ class TransfomerAttentionLayer(torch.nn.Module):
         self.w_v = torch.nn.Linear(dim_node_feat + dim_edge_feat + dim_time, dim_out)
         self.w_out = torch.nn.Linear(dim_node_feat + dim_out, dim_out)
         self.layer_norm = torch.nn.LayerNorm(dim_out)
+        self.cache = None  #AggregatorCache(num_node,num_edge,dim_out,num_head)

     def forward(self, b):
         assert(self.dim_time + self.dim_node_feat + self.dim_edge_feat > 0)
         self.device = b.device
...
@@ -287,15 +453,21 @@ class TransfomerAttentionLayer(torch.nn.Module):
         #att_v_max = dgl.ops.copy_e_max(b,att_sum)
         #att_e_sub_max = torch.exp(dgl.ops.e_sub_v(b,att_sum,att_v_max))
         #att = dgl.ops.e_div_v(b,att_e_sub_max,torch.clamp_min(dgl.ops.copy_e_sum(b,att_e_sub_max),1))
-        att = dgl.ops.edge_softmax(b, self.att_act(torch.sum(Q*K, dim=2)))
-        att = self.att_dropout(att)
+        self.linear = True
+        if self.linear is False:
+            att = dgl.ops.edge_softmax(b, self.att_act(torch.sum(Q*K, dim=2)))
+            att = self.att_dropout(att)
+        else:
+            Q = Q.softmax(dim=-1)
+            K = dgl.ops.edge_softmax(b, K)
+            att = self.att_act(torch.sum(Q*K, dim=2))
         #tt.weight_count_remote+=torch.sum(att[DistIndex(b.srcdata['ID']).part[b.edges()[0]]!=torch.distributed.get_rank()]**2)
         #tt.weight_count_local+=torch.sum(att[DistIndex(b.srcdata['ID']).part[b.edges()[0]]==torch.distributed.get_rank()]**2)
-        attention_value = att[:,:]
-        attention_delta_t = b.edata['dt']
-        attention_number = b.edata['weight']
-        tt.insert_attention(attention_delta_t,attention_value,attention_number)
-        V = torch.reshape(V*att[:, :, None], (V.shape[0], -1))
+        # attention_value = att[:,:]
+        # attention_delta_t = b.edata['dt']
+        # attention_number = b.edata['weight']
+        # tt.insert_attention(attention_delta_t,attention_value,attention_number)
         #V_local = V.clone()
         #V_remote = V.clone()
         #V_local[DistIndex(b.srcdata['ID']).part[b.edges()[0]]!=torch.distributed.get_rank()] = 0
...
@@ -312,9 +484,14 @@ class TransfomerAttentionLayer(torch.nn.Module):
         # b.edata['v'] = V*weight
         #else:
         # weight = b.edata['weight'].reshape(-1,1)
-        b.edata['v'] = V
         #print(torch.sum(torch.sum(((V-V*weight)**2))))
-        b.update_all(dgl.function.copy_e('v', 'm'), dgl.function.sum('m', 'h'))
+        if self.cache is None:
+            V = torch.reshape(V*att[:, :, None], (V.shape[0], -1))
+            b.edata['v'] = V
+            b.update_all(dgl.function.copy_e('v', 'm'), dgl.function.sum('m', 'h'))
+        else:
+            self.cache.get_historical_aggregator(b, Q, K, V)
         #tt.ssim_local+=torch.sum(torch.cosine_similarity(b.dstdata['h'],b.dstdata['h0']))
         #tt.ssim_remote+=torch.sum(torch.cosine_similarity(b.dstdata['h'],b.dstdata['h1']))
         #tt.ssim_cnt += b.num_dst_nodes()
...
starrygl/module/modules.py

...
@@ -69,7 +69,7 @@ class NegFixLayer(torch.autograd.Function):
 class GeneralModel(torch.nn.Module):
-    def __init__(self, dim_node, dim_edge, sample_param, memory_param, gnn_param, train_param, num_nodes=None, mailbox=None, combined=False, train_ratio=None):
+    def __init__(self, dim_node, dim_edge, sample_param, memory_param, gnn_param, train_param, num_nodes=None, mailbox=None, combined=False, train_ratio=None, num_node=0, num_edge=0):
         super(GeneralModel, self).__init__()
         self.dim_node = dim_node
         self.dim_node_input = dim_node
...
@@ -110,10 +110,10 @@ class GeneralModel(torch.nn.Module):
         self.layers = torch.nn.ModuleDict()
         if gnn_param['arch'] == 'transformer_attention':
             for h in range(sample_param['history']):
-                self.layers['l0h' + str(h)] = TransfomerAttentionLayer(self.dim_node_input, dim_edge, gnn_param['dim_time'], gnn_param['att_head'], train_param['dropout'], train_param['att_dropout'], gnn_param['dim_out'], combined=combined)
+                self.layers['l0h' + str(h)] = TransfomerAttentionLayer(self.dim_node_input, dim_edge, gnn_param['dim_time'], gnn_param['att_head'], train_param['dropout'], train_param['att_dropout'], gnn_param['dim_out'], combined=combined, num_node=num_node, num_edge=num_edge)
             for l in range(1, gnn_param['layer']):
                 for h in range(sample_param['history']):
-                    self.layers['l' + str(l) + 'h' + str(h)] = TransfomerAttentionLayer(gnn_param['dim_out'], dim_edge, gnn_param['dim_time'], gnn_param['att_head'], train_param['dropout'], train_param['att_dropout'], gnn_param['dim_out'], combined=False)
+                    self.layers['l' + str(l) + 'h' + str(h)] = TransfomerAttentionLayer(gnn_param['dim_out'], dim_edge, gnn_param['dim_time'], gnn_param['att_head'], train_param['dropout'], train_param['att_dropout'], gnn_param['dim_out'], combined=False, num_node=num_node, num_edge=num_edge)
         elif gnn_param['arch'] == 'identity':
             self.gnn_param['layer'] = 1
             for h in range(sample_param['history']):
...