Commit d7bc324c by zlj

incremental mtgnn

parent abb7e9e8
labels = ['4', '8', '12', '16']
methods = ['TGL', 'MSPipe','DistTGL','Ours']
table_label = ['TGL-1','TGL-4','MSPipe-4','DistTGL-4','Ours-4','MSPipe-8','DistTGL-8','Ours-8','MSPipe-12','DistTGL-12','Ours-12','MSPipe-16','DistTGL-16','Ours-16']
table_label_no = ['TGL-1','TGL-4','MSPipe-4','Ours-4','MSPipe-8','Ours-8','MSPipe-12','Ours-12','MSPipe-16','Ours-16']
dataset_label = ['WIKI','LASTFM','WikiTalk','StackOverflow','GDELT']
table_ap_tgn = [
[0.9827, 0.8023, 0.9611, 0.9574, 0.9837],#tgl-1
[0.9808, 0.7820, 0.9632, 0.9547, 0.9770],#tgl-4
[0.9709, 0.8277, 0.9552, 0.9548, 0.9832],#mspipe
[0.9682, 0.7683, 0.9355, 0.9273, 0.9860],#
[0.9770, 0.8998, 0.9765, 0.9760, 0.9866],
[0.9703, 0.7846, 0.9361, 0.9216, 0.9867],
[0.9648, 0.7419, 0.9301, 0.9328, 0.9860],
[0.9768, 0.9329, 0.9800, 0.9792, 0.9897],
[0.9698, 0.7707, 0.9301, 0.9111, 0.9868],
[0.9624, 0.7713, 0.9234, 0.9074, 0.9871],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.9699, 0.7868, 0.9257, 0.9090, 0.9870],
[0.9646, 0.6999, 0.8997, 0.9031, 0.9862],
[0.9659, 0.9498, 0.9816, 0.9830, 0.9916]
]
table_ap_jodie = [
[0.9080, 0.6202, 0.9578, 0.9620, 0.9831],#tgl-1
[0.8586, 0.6116, 0.9463 ,0.9494 , 0.9825],#tgl-4
[0.785868, 0.4998, 0.9123, 0.8856, 0.9831],
[0.9004, 0.6182, 0.9541, 0.9612, 0.9901],
[0.7735, 0.5566, 0.8252, 0.8055, 0.9812],
[0.8769, 0.6131, 0.9495, 0.9580, 0.9914],
[0.7747, 0.5511, 0.7983, 0.7806, 0.9811],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.7691, 0.5427, 0.7858, 0.7684, 0.9810],
[0.8548, 0.6044, 0.9418, 0.9540, 0.9908]
]
table_ap_apan = [
[0.9415, 0.5659, 0.9043, 0.8752, 0.9612],
[0.8834, 0.5711, 0.9136, 0.8886, 0.9547],
[0.9124, 0.5975, 0.9147, 0.8615, 0.9361],
[0.9190, 0.6741, 0.9298, 0.9400, 0.9839],
[0.6752, 0.5509, 0.7030, 0.6824, 0.7252],
[0.8982, 0.6268, 0.9398, 0.9402, 0.9876],
[0.6011, 0.5281, 0.6402, 0.6340, 0.6012],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.5449, 0.5156, 0.6004, 0.5898, 0.6163],
[0.8658, 0.6682, 0.9419, 0.9522, 0.9880]
]
table_train_tgn =[
[1.7701 ,14.7205 ,88.1896 ,938.8964 , 2001.318524],
[0.6651 ,6.185821226 ,35.9636 , 328.1245 , 854.80837],
[0.8258 ,3.7705 ,15.3471 , 141.1648 ,465.2264 ],
[0.4001 ,3.5691 ,16.5933 , 167.2491 , 445.4328 ],
[0.6327 ,4.4322 ,14.7757 , 132.7098208, 537.1398258],
[0.7085 ,2.2837 ,8.6870 ,76.9830 ,250.0945 ],
[0.2992 ,2.0908 ,9.0872 ,97.8960 ,238.3315 ],
[0.2772 ,2.3960 ,7.1694 ,72.0319 ,251.7710 ],
[0.598028193 ,1.708853126 ,6.600371795 ,59.73907948 ,172.2328092],
[0.0000 ,0.0000 ,0.0000 ,0.0000 ,0.0000 ],
[0.0000 ,0.0000 ,0.0000 ,0.0000 ,0.0000 ],
[0.0000 ,0.0000 ,0.0000 ,0.0000 ,0.0000 ],
[0.0000 ,0.0000 ,0.0000 ,0.0000 ,0.0000 ],
[0.0000 ,0.0000 ,0.0000 ,0.0000 ,0.0000 ],
]
table_train_jodie =[
[0.7056 , 5.9280, 16.0147 , 134.6145 , 376.2186 ],
[0.2494 , 2.0697, 6.4145 ,52.0684 ,147.3938 ],
[0.5604 , 2.0900, 4.5566 ,33.3442 ,84.38458259],
[0.3246 , 2.5305, 5.7808 ,48.8147 ,148.8899 ],
[0.518362013, 1.148036031 ,2.871488156 ,19.42606455, 48.72649877],
[0.1708 , 1.3234 ,3.1610 ,26.3006 , 78.1362 ],
[0.495687011 ,0.929189181, 2.223804102 ,15.14540188, 36.91594124],
[0.0000, 0.0000 , 0.0000 ,0.0000 , 0.0000 ],
[0.0000, 0.0000 , 0.0000 ,0.0000 , 0.0000 ],
[0.0000, 0.0000 , 0.0000 ,0.0000 , 0.0000 ],
]
table_train_apan =[
[1.5406, 13.5400, 118.8137, 1032.4779 , 1609.9425 ],
[0.6653, 5.6160 , 41.4232 , 368.6106 , 1045.6425 ],
[0.7319, 2.9211 , 25.9092 , 238.7242 , 643.3855 ],
[0.4428, 3.3608 , 11.3518 , 96.5700 , 244.8932 ],
[0.6765, 1.9470 , 14.0065 , 123.8830 , 331.8195 ],
[0.2521, 2.0037 , 8.5989 ,74.9677 , 139.2487 ],
[0.870823467, 2.671342006, 14.96090715, 99.97286532, 233.0780075],
[0.0000 ,0.0000 ,0.0000 , 0.0000 , 0.0000 ],
[0.0000 ,0.0000 ,0.0000 , 0.0000 , 0.0000 ],
[0.0000 ,0.0000 ,0.0000 , 0.0000 , 0.0000 ],
]
table_eval_tgn = [
[0.7499 ,4.8555 ,48.0304 ,505.1028 ,824.7955338],
[0.2636 ,2.09782438 ,15.9930 ,162.6785 ,412.4661068],
[0.5704 ,1.2416 ,6.4172 ,68.7693 ,117.9146 ],
[0.2692 ,2.0977 ,13.2586 ,163.4167 ,152.3379 ],
[0.1166 ,0.8080 ,3.0650 ,28.00140551, 115.4568481],
[0.5747 ,0.8989 ,4.0630 ,36.5432 ,61.79147074],
[0.2714 ,2.0767 ,11.6519 ,143.2905 ,145.8947 ],
[0.0627 ,0.6212 ,2.5015 ,23.1395 ,80.9373 ],
[0.555030217 ,0.779903646 ,3.302122216 ,27.71362416 ,42.97800901],
[0.0000 , 0.0000 ,0.0000 ,0.0000 , 0.0000 ],
[0.0000 , 0.0000 ,0.0000 ,0.0000 , 0.0000 ],
[0.0000 , 0.0000 ,0.0000 ,0.0000 , 0.0000 ],
[0.0000 , 0.0000 ,0.0000 ,0.0000 , 0.0000 ],
[0.0000 , 0.0000 ,0.0000 ,0.0000 , 0.0000 ],
]
table_eval_jodie =[
[0.2091 ,1.6959 ,5.9168 ,48.6228 ,129.0989 ],
[0.0719 ,0.6373 ,2.1611 ,18.0129 ,50.1390 ],
[0.5078 ,0.4018 ,2.7975 ,15.9332 ,28.17862575],
[0.0633 ,0.5023 ,1.1726 ,9.4948 ,42.8286 ],
[0.502160359 ,0.669311229 ,2.076292474 ,9.030235369, 16.12763978],
[0.0377 ,0.3370 ,1.2482 , 10.2044 ,34.6001 ],
[0.525156487 ,0.617744214, 1.698658919 ,6.896312807 ,12.20505439],
[0.0000 ,0.0000 , 0.0000, 0.0000 , 0.0000 ],
[0.0000 ,0.0000 , 0.0000, 0.0000 , 0.0000 ],
[0.0000 ,0.0000 , 0.0000, 0.0000 , 0.0000 ],
]
table_eval_apan =[
[0.5907 , 4.8906 ,37.4472 ,332.8560, 516.5578 ],
[0.2405 , 1.9909 ,11.9007 ,105.6109, 298.5957 ],
[0.6249 , 1.4006 ,8.5734 ,66.5308 , 161.1591 ],
[0.0819 , 0.6588 ,2.2410 ,19.3600 , 100.0944 ],
[0.7001 , 1.1295 ,4.9139 ,38.4971 , 68.67845647],
[0.0709 , 0.8897 ,4.9156 ,41.6653 , 114.7303 ],
[0.860863774, 1.248189147 ,4.57928021 ,30.29166182, 69.67092261],
[0.0000 , 0.0000 ,0.0000 ,0.0000 ,0.0000 ],
[0.0000 , 0.0000 ,0.0000 ,0.0000 ,0.0000 ],
[0.0000 , 0.0000 ,0.0000 ,0.0000 ,0.0000 ],
]
def get_data(model,data,method,part,mode):
table = globals()['table_{}_{}'.format(mode,model.lower())]
index_name = method+'-{}'.format(part)
if model =='TGN':
index = table_label.index(index_name) if index_name in table_label else -1
if index == -1:
return None
else:
#print(index,dataset_label.index(data),method)
return table[index][dataset_label.index(data)]
else:
index = table_label_no.index(index_name) if index_name in table_label_no else -1
if index == -1:
return None
else:
return table[index][dataset_label.index(data)]
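# For reference, a minimal usage sketch (not part of the original script) of the
# lookup helper above; the numbers are read from the tables defined in this file.
# Example: TGN training time of Ours vs. MSPipe on WIKI with 8 partitions.
ours_time = get_data('TGN', 'WIKI', 'Ours', '8', 'train')
mspipe_time = get_data('TGN', 'WIKI', 'MSPipe', '8', 'train')
if ours_time is not None and mspipe_time is not None:
    print('speed-up over MSPipe: {:.2f}x'.format(mspipe_time / ours_time))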
import matplotlib.pyplot as plt
import numpy as np
import torch
# Read the result files
import os
probability_values = [0.1]  # [0.1, 0.05, 0.01, 0]
data_values = ['WIKI','LASTFM','WikiTalk','StackOverflow','GDELT'] # datasets read from the result files
seed = ['13357','12347','53473','54763','12347','63377','53473','54763']
partition = 'ours_shared'
import math
# Read the data from files; the results are assumed to be stored in files such as data.txt
#all/"$data"/"$partitions"-ours_shared-0.01-"$mem"-"$ssim"-"$sample".out
def read_ours(file,data=None):
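# Parse one of our training logs: every fourth 'val ap:' line (starting from the
# first) contributes a test-AP point, and the 'prep time' entries are accumulated
# into a cumulative training-time curve.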
test_ap_list = []
time_list = []
prefix = 'val ap:'
cntv = 0
cntt = 0
max_val_ap=0
test_ap=0
final_test_ap = 0
final_average_time = 0
if os.path.exists(file):
with open(file, 'r') as file:
for line in file:
#if line.find('Epoch 50:')!=-1 and (data not in ['WIKI','LASTFM']):
# break
if line.find(prefix)!=-1:
pos = line.find(prefix)+len(prefix)
posr = line.find(' ',pos)
#print(line[pos:posr])
val_ap = float(line[pos:posr])
pos = line.find("test ap ")+len("test ap ")
posr = line.find(' ',pos)
#print(line[pos:posr])
_test_ap = float(line[pos:posr])
if cntv == 0:
test_ap_list.append(_test_ap)
cntv = (cntv +1)%4
if(val_ap>max_val_ap):
max_val_ap = val_ap
final_test_ap = _test_ap
if line.find('prep time')!=-1:
pos = line.find('prep time')+len('prep time:')
posr = line.find('s',pos)
_iter_time = float(line[pos:posr])
#print(_iter_time,cntt)
if len(time_list) > 0:
time_list.append(_iter_time+time_list[-1])
#print(time_list[-1])
else:
time_list.append(_iter_time)
#print(cntt)
#elif line.find('avg_time ')!=-1:
# pl = line.find('avg_time ') + len('avg_time ')
# pr = line.find(' test time')
# train_time.append(float(line[pl:pr]))
# test_time.append(float(line[pr+len(' test time '):]))
#print(line)
# ap_list.append(test_ap)
#total_communication.append(average(_total_communication))
#shared_synchronize.append(average(_shared_synchronize))
#elif line.find('remote node number tensor([')!=-1:
# pl = line.find('remote node number tensor([')+len('remote node number tensor([')
# pr = line.find('])',pl)
# _total_communication.append(int(line[pl:pr]))
#if(p==0):
#print(file)
#print(line)
#elif line.find('shared comm tensor([')!=-1:
# pl = line.find('shared comm tensor([')+len('shared comm tensor([')
# pr = line.find('])',pl)
# _shared_synchronize.append(int(line[pl:pr]))
return test_ap_list,time_list
def check_ours(file):
#print(file)
if os.path.exists(file):
with open(file, 'r') as file:
for line in file:
if line.find('avg_time ')!=-1:
return True
return False
def read_dist_tgl(file):
test_ap_list = []
time_list = []
final_test_ap = 0
final_average_time = 0
if os.path.exists(file):
with open(file, 'r') as file:
for line in file:
#if line.find('Epoch 50:')!=-1:
# break
if line.find('test AP')!=-1:
posl = line.find('test AP')+len('test AP:')
posr = line.find(' ',posl)
ap = float(line[posl:posr])#:0.784809 '
test_ap_list.append(ap)
if line.find('train loss')!=-1:
posl = line.find('train time:')+len('train time:')
#posr = line.find('s',posl)
time = float(line[posl:])
#print(time)
if len(time_list) > 0:
time_list.append(time+time_list[-1])
else:
time_list.append(time)
return test_ap_list[:len(time_list)],time_list
def check_dist_tgl(file):
if os.path.exists(file):
with open(file, 'r') as file:
for line in file:
if line.find('avg train time ')!=-1:
return True
return False
def read_tgl(file):
test_ap_list = []
time_list = []
if os.path.exists(file):
with open(file, 'r') as file:
for line in file:
#if line.find('Epoch 50:')!=-1:
# break
if line.find('test ap')!=-1:
posl = line.find('test ap')+len('test ap:')
posr = line.find(' ',posl)
ap = float(line[posl:posr])#:0.784809 '
test_ap_list.append(ap)
#print(line)
if line.find('total time')!=-1:
#print(line)
posl = line.find('total time')+len('total time:')
posr = line.find('s',posl)
time = float(line[posl:posr])
if len(time_list) > 0:
time_list.append(time+time_list[-1])
#print(len(time_list))
else:
time_list.append(time)
return test_ap_list[:len(time_list)],time_list
def check_tgl(file):
if os.path.exists(file):
with open(file, 'r') as file:
for line in file:
if line.find('avg time')!=-1:
return True
return False
def check_mspipe(file):
#print(file)
if os.path.exists(file):
with open(file, 'r') as file:
for line in file:
if line.find('Avg epoch time')!=-1:
return True
return False
#Best val AP:
#Avg epoch time
def read_mspipe(file):
test_ap_list = []
time_list = []
if os.path.exists(file):
with open(file, 'r') as file:
for line in file:
#if line.find('Epoch 50:')!=-1:
# break
if line.find('Test ap')!=-1:
posl = line.find('Test ap')+len('Test ap:')
posr = line.find(' ',posl)
ap = float(line[posl:posr])#:0.784809 '
test_ap_list.append(ap)
if line.find('Train time ')!=-1:
posl = line.find('Train time ')+len('Train time ')
posr = line.find('s',posl)
time = float(line[posl:posr])
if len(time_list) > 0:
time_list.append(time+time_list[-1])
else:
time_list.append(time)
return test_ap_list,time_list
def draw_table(dict,model,part):
if model == 'TGN' and part == '4':
method_list = ['TGL','DistTGL','MSPipe','Ours']
elif model != 'TGN' and part == '4':
method_list = ['TGL','MSPipe','Ours']
elif model == 'TGN':
method_list = ['DistTGL','MSPipe','Ours']
else:
method_list = ['MSPipe','Ours']
for id,method in enumerate(method_list):
char_list = ['&']
if(id != 0):
char_list+='&'
char_list+=method
for d in ['WIKI','LASTFM','WikiTalk','StackOverflow','GDELT']:
char_list += ['&',str(f"{dict[(model,part,d,method)]:.2f}")]
speed_up = dict[(model,part,d,method)]/dict[(model,part,d,'Ours')]
char_list += ['&',str(f"{speed_up:.2f}"),'x']
print('{}'.format(''.join(char_list)))
print('\n')
data = ['WIKI','REDDIT','LASTFM','WikiTalk']
parts = ['8']
models=['TGN_large']
seed = ['12347','12357','13357','53473','63377','63457','']
def find_ours_file(d,model,neighbor):
path = 'result_neighbor/{}/{}/'.format(d,model)
file_name = path+'1-ours-0.01-local--recent-{}.out'.format(neighbor)
print(neighbor,file_name)
if check_ours(file_name):
return file_name
return None
dict = {}
color_map = {
'TGL': 'blue',
'DistTGL': 'green',
'MSPipe': 'red',
'Ours': 'Orange'
}
for model in models:
plt.clf()
fig, axes = plt.subplots(1, len(data), figsize=(12, 2), sharey=False)
for axx, data_item in enumerate(data):
ax = axes[axx]
d=data_item
for neighbor in ['2','4','8','16']:
print(neighbor)
ours_file = find_ours_file(d,model,neighbor)
if ours_file is None: continue
ours_file = read_ours(ours_file,d)
#print(ours_file)
ap_list = torch.tensor(ours_file[0])
line, = ax.plot(range(1,len(ours_file[0])+1),ours_file[0],label=neighbor)
ax.set_xlabel('Training Epoch', fontsize=12)
ax.set_ylabel('Test AP', fontsize=12)
ax.set_title('{}'.format(d))
if d == 'LASTFM':
ax.set_ylim([0.6,0.85])
elif d=='WikiTalk' or d == 'StackOverflow':
if model == 'TGN_large':
ax.set_xlim([0,20])
ax.set_ylim([0.95,1])
elif model == 'APAN':
ax.set_xlim([0,20])
ax.set_ylim([0.6,1])
else:
ax.set_xlim([0,20])
ax.set_ylim([0.7,1])
else:
if model == 'TGN_large':
if d == 'WIKI':
ax.set_ylim([0.97,0.99])
elif d == 'REDDIT':
ax.set_ylim([0.98,0.99])
elif model == 'APAN':
pass
#ax.set_ylim([0.4,1])
else:
ax.set_ylim([0.9,1])
ax.grid(True, linestyle='--', color='gray', linewidth=0.5)
if axx == 0:
fig.legend(fontsize=11,loc='upper center',ncol=4)
#print('{} {} {} {} {} {} {}\n'.format(model,d,part,tgl_time,disttgl_time,mspipe_time,ours_time))
#print
#handles = {line.get_label(): line for line in lines.values()}.values()
plt.tight_layout(rect=[0,0,1,0.9])
plt.savefig('convergence_{}.pdf'.format(model))
@@ -353,7 +353,7 @@ def main():
print('dim_node {} dim_edge {}\n'.format(gnn_dim_node,gnn_dim_edge))
avg_time = 0
if use_cuda:
model = GeneralModel(gnn_dim_node, gnn_dim_edge, sample_param, memory_param, gnn_param, train_param,graph.ids.shape[0],mailbox,num_node=graph.num_nodes, num_edge=graph.num_edges).cuda()
device = torch.device('cuda')
else:
model = GeneralModel(gnn_dim_node, gnn_dim_edge, sample_param, memory_param, gnn_param, train_param,graph.ids.shape[0],mailbox)
...
$h_u(t) = \sum_{(v,\tau)\in N_u(t)} \phi(q_u)^T\phi(k_v)\,x_v$
Specifically, with the softmax kernel,
$h_u(t) = \sum_{(v,\tau)\in N_u(t)} \text{softmax}(q_u)^T\,\text{edgesoftmax}(k_v)\,x_v$
$w_u(t^-) = \sum_{(v,\tau)\in N_u(t^-)}\phi(k_v,t^-)\,x_v$
$\phi(k_v,t^-) = e^{k_v-\mathrm{edgemax}(k_v)} \,/\, \mathrm{edgesum}\!\left(e^{k_v-\mathrm{edgemax}(k_v)}\right)$
$w_u(t) = \sum_{(v,\tau)\in N_u(t)}\phi(k_v,t)\,x_v$
The running maximum and normalizer are updated incrementally over the newly arrived edges $N_u(t)\setminus N_u(t^-)$:
$\mathrm{newmax} = \max\left(\mathrm{lastmax},\ \max_{(v,\tau)\in N_u(t)\setminus N_u(t^-)} k_v\right)$
$\mathrm{newsum} = \mathrm{lastsum}\cdot e^{\mathrm{lastmax}-\mathrm{newmax}} + \sum_{(v,\tau)\in N_u(t)\setminus N_u(t^-)} e^{k_v-\mathrm{newmax}}$, where $\mathrm{lastsum}=\sum_{(v,\tau)\in N_u(t^-)} e^{k_v-\mathrm{lastmax}}$.
$h_u(t^-) = \sum_{(v,\tau)\in N_u(t^-)} \phi(q_u)^T\phi(k_v)\,x_v$
$h_u(t) = \phi(q_u)^T \left\{ \sum_{(v,\tau)\in N_u(t^-)} \phi(k_v,t^-)\,\frac{\mathrm{lastsum}\, e^{\mathrm{lastmax}-\mathrm{newmax}}}{\mathrm{newsum}}\, x_v \;+\; \sum_{(v,\tau)\in N_u(t)\setminus N_u(t^-)} \frac{e^{k_v-\mathrm{newmax}}}{\mathrm{newsum}}\, x_v \right\}$
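As a sanity check of the recursion above, the following self-contained sketch (illustrative only, not part of the commit; the helper name and the chunking are hypothetical) verifies that maintaining the running max/normalizer over incoming edge batches and rescaling the old aggregate reproduces the full softmax aggregation:

import torch

def streaming_softmax_agg(k_chunks, x_chunks):
    # Maintain a running max (lastmax), normalizer (lastsum) and aggregate w_u
    # over successive batches of neighbor scores k_v and values x_v.
    run_max = torch.tensor(float('-inf'))
    run_sum = torch.tensor(0.0)
    agg = torch.zeros(x_chunks[0].shape[1])
    for k, x in zip(k_chunks, x_chunks):
        new_max = torch.maximum(run_max, k.max())                              # newmax
        new_terms = torch.exp(k - new_max)
        new_sum = run_sum * torch.exp(run_max - new_max) + new_terms.sum()     # newsum
        # rescale the previously accumulated aggregate, then add the new edges
        agg = agg * (run_sum * torch.exp(run_max - new_max) / new_sum) + (new_terms / new_sum) @ x
        run_max, run_sum = new_max, new_sum
    return agg

k = torch.randn(10)     # per-edge scores k_v
x = torch.randn(10, 4)  # per-edge values x_v
full = (torch.softmax(k, dim=0).unsqueeze(1) * x).sum(0)
incremental = streaming_softmax_agg([k[:6], k[6:]], [x[:6], x[6:]])
assert torch.allclose(full, incremental, atol=1e-5)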
$x_v = s_v + \phi(wt+b)$
The time encoding itself can be advanced incrementally, since
$[\cos(w(t+\Delta t)+b),\ \sin(w(t+\Delta t)+b)] = [\cos(wt+b)\cos(w\Delta t) - \sin(wt+b)\sin(w\Delta t),\ \sin(wt+b)\cos(w\Delta t) + \cos(wt+b)\sin(w\Delta t)]$
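A cached $[\cos(wt+b), \sin(wt+b)]$ vector can therefore be advanced by $\Delta t$ with a per-dimension rotation instead of re-encoding from scratch; a small illustrative sketch (the function name is hypothetical, not from the repo):

import torch

def advance_time_encoding(cos_enc, sin_enc, w, dt):
    # Rotate a cached [cos(wt+b), sin(wt+b)] encoding forward by dt, per frequency.
    c, s = torch.cos(w * dt), torch.sin(w * dt)
    return cos_enc * c - sin_enc * s, sin_enc * c + cos_enc * s

w, b = torch.rand(8), torch.rand(8)   # per-dimension frequency and phase
t, dt = 3.0, 0.7
cos0, sin0 = torch.cos(w * t + b), torch.sin(w * t + b)
cos1, sin1 = advance_time_encoding(cos0, sin0, w, dt)
assert torch.allclose(cos1, torch.cos(w * (t + dt) + b), atol=1e-6)
assert torch.allclose(sin1, torch.sin(w * (t + dt) + b), atol=1e-6)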
@@ -2,7 +2,7 @@ from os.path import abspath, join, dirname
import os
import sys
from os.path import abspath, join, dirname
import torch_scatter
from starrygl.distributed.utils import DistIndex
sys.path.insert(0, join(abspath(dirname(__file__))))
import torch
@@ -10,7 +10,73 @@ import dgl
import math
import numpy as np
from starrygl.sample.count_static import time_count as tt
class AggregatorCache():
def __init__(self,node_num,edge_num,dim,num_head,kernel='softmax'):
#self.aggregator=torch.zeros(node_num,dim,device=torch.device('cuda'))
#self.last_time = torch.zeros(node_num,dtype=torch.float,device=torch.device('cuda'))
#self.attention_list = torch.zeros(edge_num,dim,num_head,device=torch.device('cuda'))
#self.attention_max = torch.zeros(node_num,dim,device=torch.device('cuda'))
#self.attention_sum = torch.zeros(node_num,dim,device=torch.device('cuda'))
#self.last_kv = torch.zeros(node_num,dim,device=torch.device('cuda'))
self.last_attention_max = torch.zeros(node_num,dim,device = torch.device('cuda'))
self.last_attention_sum = torch.zeros(node_num,dim,device = torch.device('cuda'))
self.last_kv = torch.zeros(node_num,dim,dim,device = torch.device('cuda'))
self.kernel = kernel
def get_historical_aggregator(self,b,_q,_k,_v,time_projection_matrix=None):
with torch.no_grad():
srcid = b.srcdata['ID'][:b.num_dst_nodes()][b.edges()[1]]
#last_aggregation = self.aggregator[srcid]
last_attention_max = self.last_attention_max[srcid]
last_attention_sum = self.last_attention_sum[srcid]
last_kv = self.last_kv[srcid]
if self.kernel == 'softmax':
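# Streaming-softmax update: take the elementwise max of the cached per-node
# maximum and the max over the newly arrived edge scores, rescale the cached
# normalizer and key-value aggregate to the new maximum, then add the
# contributions of the new edges (the newmax/newsum recursion in the notes above).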
increment_max = dgl.ops.copy_e_max(b,_k)
new_attention_max = torch.max(last_attention_max,increment_max)
new_attention_sum = last_attention_sum*torch.exp(last_attention_max - new_attention_max) + dgl.ops.copy_e_sum(b,torch.exp(_k - new_attention_max[b.edges()[1]]))
increment_kv = dgl.ops.copy_e_sum(b,
torch.einsum('nj,nk->njk',(torch.exp(_k - new_attention_max[b.edges()[1]])/new_attention_sum[b.edges()[1]]),_v)
)
new_kv = last_kv*last_attention_sum * torch.exp(last_attention_max - new_attention_max)/new_attention_sum + increment_kv
h_u = torch.einsum('ni,nik->nk',_q,new_kv)
b.srcdata['h'] = h_u
with torch.no_grad():
unq_index,inv = torch.unique(b.srcdata['ID'][:b.num_dst_nodes()],return_inverse = True)
_,idx = torch_scatter.scatter_max(b.srcdata['ts'][:b.num_dst_nodes()],inv,0)
self.last_kv[unq_index] = new_kv[idx]
self.last_attention_max[unq_index] = new_attention_max[idx]
self.last_attention_sum[unq_index] = new_attention_sum[idx]
#calculate the max value for both new insert attention and older attention
#next_h = b.srcdata['h'][srcid]
#att_v_max_new = dgl.ops.copy_e_max(b,new_attention)
# delta_Q = (w_q(next_h-last_h))
# print(delta_Q,last_attention_max)
# historical_attention_max = delta_Q*(last_attention_max)
# print(att_v_max_new)
# max_QK = torch.max((att_v_max_new,historical_attention_max))
# historical_attention_sum = (delta_Q+1)*last_attention_sum/torch.exp(max_QK-historical_attention_max)
# last_aggregation = (delta_Q+1)*last_aggregation/torch.exp(max_QK-historical_attention_max)
# att_v_new = torch.exp(torch.exp(dgl.ops.e_sub_v(b,new_attention,max_QK)))
# att_v_sum = dgl.ops.copy_e_sum(b,att_v_new)+historical_attention_sum
# att_new = dgl.ops.e_div_v(b,att_v_new,torch.clamp_min(att_v_sum ,1))
# att = self.att_dropout(att_new)
# V = torch.reshape(V*att[:, :, None], (V.shape[0], -1)) + last_aggregation
# b.edata['v'] = V
# b.update_all(dgl.function.copy_e('v', 'm'), dgl.function.sum('m', 'h'))
# b.srcdata['h'] += last_aggregation
# with torch.no_grad():
# unq_index,inv = torch.unique(b.srcdata['ID'][:b.num_dst_nodes()],return_inverse = True)
# _,idx = torch_scatter.scatter_max(b.srcdata['ts'][:b.num_dst_nodes()],inv,0)
# self.aggregator[idx] = b.srcdata['h'][idx]#b.srcdata['h']
# self.attention_max = max_QK[idx]#torch.zeros(node_num,dim)
# self.attention_sum = att_v_sum[idx]#torch.zeros(node_num,dim)
# self.last_h = next_h #torch.zeros(node_num,dim)
#att_v_max = torch.max(att_v_max_new,last_attention_max)
#att_e_sub_max = torch.exp(dgl.ops.e_sub_v(b,new_attention,att_v_max))
#att = dgl.ops.e_div_v(b,att_e_sub_max,torch.clamp_min(dgl.ops.copy_e_sum(b,att_e_sub_max)+last_attention ,1))
class TimeEncode(torch.nn.Module):
def __init__(self, dim, alpha = None, beta= None, parameter_requires_grad: bool = True):
@@ -174,11 +240,111 @@ class MixerMLP(torch.nn.Module):
self.block_padding(b)
#return x
# class TransformerAttentionLayer0(torch.nn.Module):
class FastAttention(torch.nn.Module):
def __init__(self, dim_node_feat, dim_edge_feat, dim_time, num_head, dropout, att_dropout, dim_out, combined=False,num_node=None,num_edge = None):
super(FastAttention, self).__init__()
self.num_head = num_head
self.dim_node_feat = dim_node_feat
self.dim_edge_feat = dim_edge_feat
self.dim_time = dim_time
self.dim_out = dim_out
self.dropout = torch.nn.Dropout(dropout)
self.att_dropout = torch.nn.Dropout(att_dropout)
self.att_act = torch.nn.LeakyReLU(0.2)
self.combined = combined
if dim_time > 0:
self.time_enc = TimeEncode(dim_time)
if combined:
if dim_node_feat > 0:
self.w_q_n = torch.nn.Linear(dim_node_feat, dim_out)
self.w_k_n = torch.nn.Linear(dim_node_feat, dim_out)
self.w_v_n = torch.nn.Linear(dim_node_feat, dim_out)
if dim_edge_feat > 0:
self.w_k_e = torch.nn.Linear(dim_edge_feat, dim_out)
self.w_v_e = torch.nn.Linear(dim_edge_feat, dim_out)
if dim_time > 0:
self.w_q_t = torch.nn.Linear(dim_time, dim_out)
self.w_k_t = torch.nn.Linear(dim_time, dim_out)
self.w_v_t = torch.nn.Linear(dim_time, dim_out)
else:
if dim_node_feat + dim_time > 0:
self.w_q = torch.nn.Linear(dim_node_feat + dim_time, dim_out)
self.w_k = torch.nn.Linear(dim_node_feat + dim_edge_feat + dim_time, dim_out)
self.w_v = torch.nn.Linear(dim_node_feat + dim_edge_feat + dim_time, dim_out)
self.w_out = torch.nn.Linear(dim_node_feat + dim_out, dim_out)
self.layer_norm = torch.nn.LayerNorm(dim_out)
self.cache = AggregatorCache(num_node,num_edge,dim_out,num_head)
def forward(self,b):
assert(self.dim_time + self.dim_node_feat + self.dim_edge_feat > 0)
self.device = b.device
if b.num_edges() == 0:
return torch.zeros((b.num_dst_nodes(), self.dim_out), device=self.device)
if self.dim_time > 0:
time_feat = self.time_enc(b.edata['dt'])
zero_time_feat = self.time_enc(torch.zeros(b.num_dst_nodes(), dtype=torch.float32, device=self.device))
if self.dim_time == 0 and self.dim_node_feat == 0:
Q = torch.ones((b.num_edges(), self.dim_out), device=self.device)
K = self.w_k(b.edata['f'])
V = self.w_v(b.edata['f'])
elif self.dim_time == 0 and self.dim_edge_feat == 0:
Q = self.w_q(b.srcdata['h'][:b.num_dst_nodes()])[b.edges()[1]]
K = self.w_k(b.srcdata['h'][b.edges()[0]])
V = self.w_v(b.srcdata['h'][b.edges()[0]])
elif self.dim_time == 0:
Q = self.w_q(b.srcdata['h'][:b.num_dst_nodes()])[b.edges()[1]]
K = self.w_k(torch.cat([b.srcdata['h'][b.edges()[0]], b.edata['f']],dim=1))
V = self.w_v(torch.cat([b.srcdata['h'][b.edges()[0]], b.edata['f']],dim=1))
#K = self.w_k(torch.cat([b.srcdata['h'][b.num_dst_nodes():], b.edat['f']], dim=1))
#V = self.w_v(torch.cat([b.srcdata['h'][b.num_dst_nodes():], b.edat['f']], dim=1))
elif self.dim_node_feat == 0 and self.dim_edge_feat == 0:
Q = self.w_q(zero_time_feat)[b.edges()[1]]
K = self.w_k(time_feat)
V = self.w_v(time_feat)
elif self.dim_node_feat == 0:
Q = self.w_q(zero_time_feat)[b.edges()[1]]
K = self.w_k(torch.cat([b.edata['f'], time_feat], dim=1))
V = self.w_v(torch.cat([b.edata['f'], time_feat], dim=1))
elif self.dim_edge_feat == 0:
Q = self.w_q(torch.cat([b.srcdata['h'][:b.num_dst_nodes()],zero_time_feat], dim=1))[b.edges()[1]]
K = self.w_k(torch.cat([b.srcdata['h'][b.edges()[0]], time_feat],dim=1))
V = self.w_v(torch.cat([b.srcdata['h'][b.edges()[0]], time_feat],dim=1))
else:
Q = self.w_q(torch.cat([b.srcdata['h'][:b.num_dst_nodes()],zero_time_feat], dim=1))[b.edges()[1]]
K = self.w_k(torch.cat([b.srcdata['h'][b.edges()[0]], b.edata['f'],time_feat], dim=1))
V = self.w_v(torch.cat([b.srcdata['h'][b.edges()[0]], b.edata['f'],time_feat], dim=1))
Q = torch.reshape(Q, (Q.shape[0], self.num_head, -1))
K = torch.reshape(K, (K.shape[0], self.num_head, -1))
V = torch.reshape(V, (V.shape[0], self.num_head, -1))
#att = dgl.ops.edge_softmax(b, self.att_act(torch.sum(Q*K, dim=2)))
#att = self.att_dropout(att)
#if self.no_projection:
Q = Q.softmax(dim=-1)
att = self.att_act(torch.sum(Q*dgl.ops.edge_softmax(b,K),dim=2))
#k_cumsum = K.sum(dim = -2)
#D_inv = 1. / torch.einsum('...nd,...d->...n', q, k_cumsum.type_as(q))
#context = torch.einsum('...nd,...ne->...de', k, v)
#out = torch.einsum('...de,...nd,...n->...ne', context, q, D_inv)
V = torch.reshape(V*att[:, :, None], (V.shape[0], -1))
if self.cache is None:
b.edata['v'] = V
b.update_all(dgl.function.copy_e('v', 'm'), dgl.function.sum('m', 'h'))
else:
self.cache.get_historical_aggregator(b,Q,K,V)
if self.dim_node_feat != 0:
rst = torch.cat([b.dstdata['h'], b.srcdata['h'][:b.num_dst_nodes()]], dim=1)
else:
rst = b.dstdata['h']
rst = self.w_out(rst)
rst = torch.nn.functional.relu(self.dropout(rst))
return self.layer_norm(rst)
class TransfomerAttentionLayer(torch.nn.Module):
def __init__(self, dim_node_feat, dim_edge_feat, dim_time, num_head, dropout, att_dropout, dim_out, combined=False,num_node=None,num_edge = None):
super(TransfomerAttentionLayer, self).__init__()
self.num_head = num_head
self.dim_node_feat = dim_node_feat
@@ -210,7 +376,7 @@ class TransfomerAttentionLayer(torch.nn.Module):
self.w_v = torch.nn.Linear(dim_node_feat + dim_edge_feat + dim_time, dim_out)
self.w_out = torch.nn.Linear(dim_node_feat + dim_out, dim_out)
self.layer_norm = torch.nn.LayerNorm(dim_out)
self.cache = None#AggregatorCache(num_node,num_edge,dim_out,num_head)
def forward(self, b):
assert(self.dim_time + self.dim_node_feat + self.dim_edge_feat > 0)
self.device = b.device
@@ -287,15 +453,21 @@ class TransfomerAttentionLayer(torch.nn.Module):
#att_v_max = dgl.ops.copy_e_max(b,att_sum)
#att_e_sub_max = torch.exp(dgl.ops.e_sub_v(b,att_sum,att_v_max))
#att = dgl.ops.e_div_v(b,att_e_sub_max,torch.clamp_min(dgl.ops.copy_e_sum(b,att_e_sub_max),1))
self.linear = True
if self.linear is False:
att = dgl.ops.edge_softmax(b, self.att_act(torch.sum(Q*K, dim=2)))
att = self.att_dropout(att)
else:
Q = Q.softmax(dim=-1)
K = dgl.ops.edge_softmax(b,K)
att = self.att_act(torch.sum(Q*K,dim=2))
#tt.weight_count_remote+=torch.sum(att[DistIndex(b.srcdata['ID']).part[b.edges()[0]]!=torch.distributed.get_rank()]**2)
#tt.weight_count_local+=torch.sum(att[DistIndex(b.srcdata['ID']).part[b.edges()[0]]==torch.distributed.get_rank()]**2)
#attention_value = att[:,:]
#attention_delta_t = b.edata['dt']
#attention_number = b.edata['weight']
#tt.insert_attention(attention_delta_t,attention_value,attention_number)
#V_local = V.clone()
#V_remote = V.clone()
#V_local[DistIndex(b.srcdata['ID']).part[b.edges()[0]]!=torch.distributed.get_rank()] = 0
@@ -312,9 +484,14 @@ class TransfomerAttentionLayer(torch.nn.Module):
# b.edata['v'] = V*weight
#else:
# weight = b.edata['weight'].reshape(-1,1)
#print(torch.sum(torch.sum(((V-V*weight)**2))))
if self.cache is None:
V = torch.reshape(V*att[:, :, None], (V.shape[0], -1))
b.edata['v'] = V
b.update_all(dgl.function.copy_e('v', 'm'), dgl.function.sum('m', 'h'))
else:
self.cache.get_historical_aggregator(b,Q,K,V)
#tt.ssim_local+=torch.sum(torch.cosine_similarity(b.dstdata['h'],b.dstdata['h0']))
#tt.ssim_remote+=torch.sum(torch.cosine_similarity(b.dstdata['h'],b.dstdata['h1']))
#tt.ssim_cnt += b.num_dst_nodes()
...
@@ -69,7 +69,7 @@ class NegFixLayer(torch.autograd.Function):
class GeneralModel(torch.nn.Module):
def __init__(self, dim_node, dim_edge, sample_param, memory_param, gnn_param, train_param, num_nodes = None,mailbox = None,combined=False,train_ratio = None,num_node = 0, num_edge = 0):
super(GeneralModel, self).__init__()
self.dim_node = dim_node
self.dim_node_input = dim_node
@@ -110,10 +110,10 @@ class GeneralModel(torch.nn.Module):
self.layers = torch.nn.ModuleDict()
if gnn_param['arch'] == 'transformer_attention':
for h in range(sample_param['history']):
self.layers['l0h' + str(h)] = TransfomerAttentionLayer(self.dim_node_input, dim_edge, gnn_param['dim_time'], gnn_param['att_head'], train_param['dropout'], train_param['att_dropout'], gnn_param['dim_out'], combined=combined, num_node=num_node, num_edge=num_edge)
for l in range(1, gnn_param['layer']):
for h in range(sample_param['history']):
self.layers['l' + str(l) + 'h' + str(h)] = TransfomerAttentionLayer(gnn_param['dim_out'], dim_edge, gnn_param['dim_time'], gnn_param['att_head'], train_param['dropout'], train_param['att_dropout'], gnn_param['dim_out'], combined=False, num_node=num_node, num_edge=num_edge)
elif gnn_param['arch'] == 'identity':
self.gnn_param['layer'] = 1
for h in range(sample_param['history']):
...