Commit 86a4611d by zlj

add APAN model

parent 24a069d6
...@@ -10,6 +10,7 @@ __pycache__/ ...@@ -10,6 +10,7 @@ __pycache__/
# Distribution / packaging # Distribution / packaging
.Python .Python
examples/all*
build/ build/
develop-eggs/ develop-eggs/
dist/ dist/
......
sampling: sampling:
- layer: 1 - layer: 1
neighbor: neighbor:
- 10 - 20
strategy: 'recent' strategy: 'recent'
prop_time: False prop_time: False
history: 1 history: 1
......
...@@ -263,7 +263,7 @@ void ParallelSampler :: neighbor_sample_from_nodes_with_before_layer( ...@@ -263,7 +263,7 @@ void ParallelSampler :: neighbor_sample_from_nodes_with_before_layer(
int cal_cnt = 0; int cal_cnt = 0;
for(int cid = end_index-1;cid>=0;cid--){ for(int cid = end_index-1;cid>=0;cid--){
cal_cnt++; cal_cnt++;
if(cal_cnt > 2*fanout)break; if(cal_cnt > fanout)break;
int eid = tnb.eid[node][cid]; int eid = tnb.eid[node][cid];
if(part[tnb.eid[node][cid]] != local_part|| node_part[tnb.neighbors[node][cid]]!= local_part){ if(part[tnb.eid[node][cid]] != local_part|| node_part[tnb.neighbors[node][cid]]!= local_part){
double p0 = (double)rand_r(&loc_seeds[tid]) / (RAND_MAX + 1.0); double p0 = (double)rand_r(&loc_seeds[tid]) / (RAND_MAX + 1.0);
......
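The sampler hunk above tightens the scan over a node's most recent edges from 2*fanout to fanout candidates and, together with the surrounding code, keeps boundary-crossing edges only with some probability. A minimal Python sketch of that idea, assuming hypothetical array inputs (neighbors, eids, edge_part, node_part) rather than the repo's CSR structures:

    import random

    def sample_recent_with_boundary_filter(neighbors, eids, edge_part, node_part,
                                           local_part, fanout, keep_prob, rng=random):
        """Minimal sketch, not the repo's C++ sampler: walk a node's neighbor list from
        newest to oldest, look at no more than `fanout` candidates, and keep a
        cross-partition edge only with probability `keep_prob`."""
        kept = []
        scanned = 0
        for cid in range(len(neighbors) - 1, -1, -1):   # newest edge is stored last
            scanned += 1
            if scanned > fanout:                        # the hunk lowers this bound from 2*fanout
                break
            eid = eids[cid]
            is_remote = (edge_part[eid] != local_part
                         or node_part[neighbors[cid]] != local_part)
            if is_remote and rng.random() >= keep_prob: # drop most boundary-crossing edges
                continue
            kept.append((neighbors[cid], eid))
        return kept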

examples/all/WIKI_boundary_Convergence_rate.png (image replaced: 43.7 KB → 41.9 KB)

examples/all/boundary_AP_WIKI.png (image replaced: 14.2 KB → 14.2 KB)

examples/all/boundary_comm_WIKI.png (image replaced: 16.9 KB → 15.8 KB)
...@@ -2,12 +2,12 @@ import matplotlib.pyplot as plt ...@@ -2,12 +2,12 @@ import matplotlib.pyplot as plt
import numpy as np import numpy as np
import torch import torch
# Read the file contents # Read the file contents
ssim_values = [0, 0.1, 0.2, 0.3, 0.4, 2] # assume these are your ssim parameter values ssim_values = [0, 0.5, 1.0, 1.5, 2] # assume these are your ssim parameter values
data_values = ['WikiTalk'] # stores the data read from the files data_values = ['WIKI','WikiTalk','REDDIT','LASTFM','DGraphFin'] # stores the data read from the files
partition = 'ours_shared' partition = 'ours_shared'
# Read the data from a file; assume it is stored in data.txt # Read the data from a file; assume it is stored in data.txt
#all/"$data"/"$partitions"-ours_shared-0.01-"$mem"-"$ssim"-"$sample".out #all/"$data"/"$partitions"-ours_shared-0.01-"$mem"-"$ssim"-"$sample".out
partitions=8 partitions=4
topk=0.01 topk=0.01
mem='historical' mem='historical'
for data in data_values: for data in data_values:
...@@ -28,7 +28,7 @@ for data in data_values: ...@@ -28,7 +28,7 @@ for data in data_values:
comm = int(line[pos+2+len('shared comm tensor'):len(line)-3]) comm = int(line[pos+2+len('shared comm tensor'):len(line)-3])
ap_list.append(ap) ap_list.append(ap)
comm_list.append(comm) comm_list.append(comm)
print('{} TestAP={}\n'.format(data,ap_list))
# Draw the bar chart # Draw the bar chart
bar_width = 0.4 bar_width = 0.4
...@@ -40,12 +40,11 @@ for data in data_values: ...@@ -40,12 +40,11 @@ for data in data_values:
# Draw the bar chart # Draw the bar chart
plt.bar([b for b in bars], ap_list, width=bar_width) plt.bar([b for b in bars], ap_list, width=bar_width)
# Draw the bar chart # Draw the bar chart
plt.ylim([0.9,1])
plt.xticks([b for b in bars], ssim_values) plt.xticks([b for b in bars], ssim_values)
plt.xlabel('SSIM threshold Values') plt.xlabel('SSIM threshold Values')
plt.ylabel('Test AP') plt.ylabel('Test AP')
plt.title('{}({} partitions)'.format(data,partitions)) plt.title('{}({} partitions)'.format(data,partitions))
plt.savefig('ssim_{}.png'.format(data)) plt.savefig('ssim_{}_{}.png'.format(data,partitions))
plt.clf() plt.clf()
plt.bar([b for b in bars], comm_list, width=bar_width) plt.bar([b for b in bars], comm_list, width=bar_width)
...@@ -54,7 +53,7 @@ for data in data_values: ...@@ -54,7 +53,7 @@ for data in data_values:
plt.xlabel('SSIM threshold Values') plt.xlabel('SSIM threshold Values')
plt.ylabel('Communication volume') plt.ylabel('Communication volume')
plt.title('{}({} partitions)'.format(data,partitions)) plt.title('{}({} partitions)'.format(data,partitions))
plt.savefig('comm_{}.png'.format(data)) plt.savefig('ssim_comm_{}_{}.png'.format(data,partitions))
plt.clf() plt.clf()
if partition == 'ours_shared': if partition == 'ours_shared':
...@@ -69,11 +68,12 @@ for data in data_values: ...@@ -69,11 +68,12 @@ for data in data_values:
val_ap = torch.tensor(torch.load(file)) val_ap = torch.tensor(torch.load(file))
epoch = torch.arange(val_ap.shape[0]) epoch = torch.arange(val_ap.shape[0])
# Plot the curve # Plot the curve
print(val_ap)
plt.plot(epoch,val_ap, label='ssim={}'.format(ssim)) plt.plot(epoch,val_ap, label='ssim={}'.format(ssim))
plt.xlabel('Epoch') plt.xlabel('Epoch')
plt.ylabel('Val AP') plt.ylabel('Val AP')
plt.title('{}({} partitions)'.format(data,partitions)) plt.title('{}({} partitions)'.format(data,partitions))
# plt.grid(True) # plt.grid(True)
plt.legend() plt.legend()
plt.savefig('{}_ssim_Convergence_rate.png'.format(data)) plt.savefig('{}_{}_ssim_Convergence_rate.png'.format(data,partitions))
plt.clf() plt.clf()
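The plot script reads its numbers from .out files whose names encode the run parameters (the pattern quoted in the comment above). A small, hedged helper, not part of the repo, that rebuilds that path for cross-checking which file a given bar or curve came from:

    import os

    def result_path(data, partitions, mem, ssim, sample, topk=0.01, root='all'):
        # all/<data>/<partitions>-ours_shared-<topk>-<mem>-<ssim>-<sample>.out
        name = '{}-ours_shared-{}-{}-{}-{}.out'.format(partitions, topk, mem, ssim, sample)
        return os.path.join(root, data, name)

    # Example: one of the WIKI runs with historical memory at ssim threshold 0.5.
    print(result_path('WIKI', 4, 'historical', 0.5, 'recent'))
    # -> all/WIKI/4-ours_shared-0.01-historical-0.5-recent.out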
This source diff could not be displayed because it is too large.
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
# Define the array variables # Define the array variables
addr="192.168.1.107" addr="192.168.1.107"
partition_params=("ours") partition_params=("dis_tgl" "ours" "metis" "random")
#"metis" "ldg" "random") #"metis" "ldg" "random")
#("ours" "metis" "ldg" "random") #("ours" "metis" "ldg" "random")
partitions="4" partitions="4"
...@@ -10,15 +10,17 @@ node_per="4" ...@@ -10,15 +10,17 @@ node_per="4"
nnodes="1" nnodes="1"
node_rank="0" node_rank="0"
probability_params=("1" "0.5" "0.1" "0.05" "0.01" "0") probability_params=("1" "0.5" "0.1" "0.05" "0.01" "0")
#sample_type_params=("recent") #"boundery_recent_decay" "boundery_recent_uniform") #sample_type_params=("recent" "boundery_recent_decay" "boundery_recent_uniform")
sample_type_params=("recent" "boundery_recent_decay" "boundery_recent_uniform") #sample_type_params=("recent" "boundery_recent_decay") #"boundery_recent_uniform")
#sample_type_params=("recent") sample_type_params=("recent")
#memory_type=("all_update" "p2p" "all_reduce" "historical" "local") #memory_type=("all_update" "p2p" "all_reduce" "historical" "local")
memory_type=("all_update") memory_type=("all_update" "local" "historical")
#"historical" "all_update") #"local" "historical")
#memory_type=("local" "all_update" "historical" "all_reduce") #memory_type=("local" "all_update" "historical" "all_reduce")
shared_memory_ssim=("0" "0.1" "0.2" "0.3" "0.4" ) shared_memory_ssim=("0" "0.5" "1.0" "1.5")
#data_param=("WIKI" "REDDIT" "LASTFM" "WikiTalk") #data_param=("WIKI" "REDDIT" "LASTFM" "WikiTalk")
data_param=("DGraphFin" "WikiTalk") data_param=("LASTFM")
#data_param=("WIKI" "REDDIT" "LASTFM" "DGraphFin" "WikiTalk" "StackOverflow")
#data_param=("WIKI" "REDDIT" "LASTFM" "WikiTalk" "StackOverflow") #data_param=("WIKI" "REDDIT" "LASTFM" "WikiTalk" "StackOverflow")
#data_param=("REDDIT" "WikiTalk") #data_param=("REDDIT" "WikiTalk")
# Create the output directory # Create the output directory
...@@ -54,7 +56,7 @@ for data in "${data_param[@]}"; do ...@@ -54,7 +56,7 @@ for data in "${data_param[@]}"; do
else else
torchrun --nnodes "$nnodes" --node_rank "$node_rank" --nproc-per-node "$node_per" --master-addr "$addr" --master-port 9445 train_boundery.py --dataname "$data" --mode "$model" --partition "$partition" --topk 0 --sample_type "$sample" --memory_type "$mem" > all/"$data"/"$partitions"-"$partition"-0-"$mem"-"$sample".out & torchrun --nnodes "$nnodes" --node_rank "$node_rank" --nproc-per-node "$node_per" --master-addr "$addr" --master-port 9445 train_boundery.py --dataname "$data" --mode "$model" --partition "$partition" --topk 0 --sample_type "$sample" --memory_type "$mem" > all/"$data"/"$partitions"-"$partition"-0-"$mem"-"$sample".out &
wait wait
if [ "$partition" = "ours" ]; then if [ "$partition" = "ours" ] && [ "$mem" != "all_local" ]; then
torchrun --nnodes "$nnodes" --node_rank "$node_rank" --nproc-per-node "$node_per" --master-addr "$addr" --master-port 9445 train_boundery.py --dataname "$data" --mode "$model" --partition "$partition" --topk 0.01 --sample_type "$sample" --memory_type "$mem" > all/"$data"/"$partitions"-ours_shared-0.01-"$mem"-"$sample".out & torchrun --nnodes "$nnodes" --node_rank "$node_rank" --nproc-per-node "$node_per" --master-addr "$addr" --master-port 9445 train_boundery.py --dataname "$data" --mode "$model" --partition "$partition" --topk 0.01 --sample_type "$sample" --memory_type "$mem" > all/"$data"/"$partitions"-ours_shared-0.01-"$mem"-"$sample".out &
wait wait
fi fi
...@@ -64,21 +66,22 @@ for data in "${data_param[@]}"; do ...@@ -64,21 +66,22 @@ for data in "${data_param[@]}"; do
for pro in "${probability_params[@]}"; do for pro in "${probability_params[@]}"; do
for mem in "${memory_type[@]}"; do for mem in "${memory_type[@]}"; do
if [ "$mem" = "historical" ]; then if [ "$mem" = "historical" ]; then
for ssim in "${shared_memory_ssim[@]}"; do continue
if [ "$partition" = "ours" ]; then # for ssim in "${shared_memory_ssim[@]}"; do
torchrun --nnodes "$nnodes" --node_rank "$node_rank" --nproc-per-node "$node_per" --master-addr "$addr" --master-port 9445 train_boundery.py --dataname "$data" --mode "$model" --partition "$partition" --topk 0.01 --sample_type "$sample" --probability "$pro" --memory_type "$mem" --shared_memory_ssim "$ssim" > all/"$data"/"$partitions"-ours_shared-0.01"$mem"-"$ssim"-"$sample"-"$pro".out & # if [ "$partition" = "ours" ]; then
wait # torchrun --nnodes "$nnodes" --node_rank "$node_rank" --nproc-per-node "$node_per" --master-addr "$addr" --master-port 9445 train_boundery.py --dataname "$data" --mode "$model" --partition "$partition" --topk 0.01 --sample_type "$sample" --probability "$pro" --memory_type "$mem" --shared_memory_ssim "$ssim" > all/"$data"/"$partitions"-ours_shared-0.01"$mem"-"$ssim"-"$sample"-"$pro".out &
fi # wait
done # fi
# done
elif [ "$mem" = "all_reduce" ]; then elif [ "$mem" = "all_reduce" ]; then
if [ "$partition" = "ours" ]; then if [ "$partition" = "ours"]; then
torchrun --nnodes "$nnodes" --node_rank "$node_rank" --nproc-per-node "$node_per" --master-addr "$addr" --master-port 9445 train_boundery.py --dataname "$data" --mode "$model" --partition "$partition" --topk 0.01 --sample_type "$sample" --probability "$pro" --memory_type "$mem" > all/"$data"/"$partitions"-ours_shared-0.01-"$mem"-"$sample"-"$pro".out& torchrun --nnodes "$nnodes" --node_rank "$node_rank" --nproc-per-node "$node_per" --master-addr "$addr" --master-port 9445 train_boundery.py --dataname "$data" --mode "$model" --partition "$partition" --topk 0.01 --sample_type "$sample" --probability "$pro" --memory_type "$mem" > all/"$data"/"$partitions"-ours_shared-0.01-"$mem"-"$sample"-"$pro".out&
wait wait
fi fi
else else
torchrun --nnodes "$nnodes" --node_rank "$node_rank" --nproc-per-node "$node_per" --master-addr "$addr" --master-port 9445 train_boundery.py --dataname "$data" --mode "$model" --partition "$partition" --topk 0 --sample_type "$sample" --probability "$pro" --memory_type "$mem" > all/"$data"/"$partitions"-"$partition"-0-"$mem"-"$sample"-"$pro".out & torchrun --nnodes "$nnodes" --node_rank "$node_rank" --nproc-per-node "$node_per" --master-addr "$addr" --master-port 9445 train_boundery.py --dataname "$data" --mode "$model" --partition "$partition" --topk 0 --sample_type "$sample" --probability "$pro" --memory_type "$mem" > all/"$data"/"$partitions"-"$partition"-0-"$mem"-"$sample"-"$pro".out &
wait wait
if [ "$partition" = "ours" ]; then if [ "$partition" = "ours" ] && [ "$mem" != "all_local" ]; then
torchrun --nnodes "$nnodes" --node_rank "$node_rank" --nproc-per-node "$node_per" --master-addr "$addr" --master-port 9445 train_boundery.py --dataname "$data" --mode "$model" --partition "$partition" --topk 0.01 --sample_type "$sample" --probability "$pro" --memory_type "$mem" > all/"$data"/"$partitions"-ours_shared-0.01-"$mem"-"$sample"-"$pro".out & torchrun --nnodes "$nnodes" --node_rank "$node_rank" --nproc-per-node "$node_per" --master-addr "$addr" --master-port 9445 train_boundery.py --dataname "$data" --mode "$model" --partition "$partition" --topk 0.01 --sample_type "$sample" --probability "$pro" --memory_type "$mem" > all/"$data"/"$partitions"-ours_shared-0.01-"$mem"-"$sample"-"$pro".out &
wait wait
fi fi
......
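The revised runner launches one torchrun job per combination of dataset, partition scheme, sample type, cross-partition probability, and memory type, now skipping historical memory in the probability sweep and adding a second ours_shared --topk 0.01 run for the ours partition. A hedged Python sketch of that grid (values copied from the arrays above, control flow simplified):

    from itertools import product

    data_param    = ['LASTFM']
    partitions    = ['dis_tgl', 'ours', 'metis', 'random']
    sample_types  = ['recent']
    memory_types  = ['all_update', 'local', 'historical']
    probabilities = ['1', '0.5', '0.1', '0.05', '0.01', '0']

    runs = []
    for data, part, sample, pro, mem in product(data_param, partitions, sample_types,
                                                probabilities, memory_types):
        if mem == 'historical':                 # ssim sweep is commented out in this version
            continue
        runs.append((data, part, sample, pro, mem, 0))          # --topk 0 baseline
        if part == 'ours' and mem != 'all_local':
            runs.append((data, part, sample, pro, mem, 0.01))   # extra ours_shared run
    print(len(runs), 'jobs in the probability sweep')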
...@@ -79,6 +79,13 @@ parser.add_argument('--memory_type', default='all_update', type=str, metavar='W' ...@@ -79,6 +79,13 @@ parser.add_argument('--memory_type', default='all_update', type=str, metavar='W'
help='name of model') help='name of model')
#boundery_recent_uniform boundery_recent_decay #boundery_recent_uniform boundery_recent_decay
args = parser.parse_args() args = parser.parse_args()
if args.memory_type == 'all_local' or args.topk != '0':
train_cross_probability = 0
else:
train_cross_probability = 1
if args.memory_type == 'all_local':
args.sample_type = 'boundery_recent_uniform'
args.probability = 0
from sklearn.metrics import average_precision_score, roc_auc_score from sklearn.metrics import average_precision_score, roc_auc_score
import torch import torch
import time import time
...@@ -271,7 +278,7 @@ def main(): ...@@ -271,7 +278,7 @@ def main():
is_pipeline=True, is_pipeline=True,
use_local_feature = False, use_local_feature = False,
device = torch.device('cuda:{}'.format(local_rank)), device = torch.device('cuda:{}'.format(local_rank)),
probability=1 if float(args.topk) == 0 else 0 probability=train_cross_probability
) )
eval_trainloader = DistributedDataLoader(graph,eval_train_data,sampler = eval_sampler, eval_trainloader = DistributedDataLoader(graph,eval_train_data,sampler = eval_sampler,
......
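The new block in train_boundery.py derives a single train_cross_probability flag from the memory type and the topk argument (which arrives from the runner as a string), and the training DataLoader now consumes that flag instead of the inline conditional on float(args.topk). A hedged restatement of the rule, with the string comparison spelled out:

    def cross_partition_probability(memory_type: str, topk: str) -> int:
        """Remote (cross-partition) edges are sampled during training only when the
        memory mode is not 'all_local' and no shared top-k nodes are used; note that
        topk is compared as the string the shell script passes ('0' vs. '0.01')."""
        if memory_type == 'all_local' or topk != '0':
            return 0
        return 1

    assert cross_partition_probability('all_update', '0') == 1
    assert cross_partition_probability('all_update', '0.01') == 0
    assert cross_partition_probability('all_local', '0') == 0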
...@@ -124,7 +124,7 @@ class HistoricalCache: ...@@ -124,7 +124,7 @@ class HistoricalCache:
def synchronize_shared_update(self,filter=None): def synchronize_shared_update(self,filter=None):
if self.last_shared_update_wait is None: if self.last_shared_update_wait is None:
return return None
handle0,handle1,shared_index,shared_data = self.last_shared_update_wait handle0,handle1,shared_index,shared_data = self.last_shared_update_wait
self.last_shared_update_wait = None self.last_shared_update_wait = None
handle0.wait() handle0.wait()
...@@ -132,16 +132,21 @@ class HistoricalCache: ...@@ -132,16 +132,21 @@ class HistoricalCache:
shared_data = torch.cat(shared_data,dim = 0) shared_data = torch.cat(shared_data,dim = 0)
shared_index = torch.cat(shared_index) shared_index = torch.cat(shared_index)
if(shared_data.shape[0] == 0): if(shared_data.shape[0] == 0):
return return None
shared_ts = shared_data[:,-1] len = self.local_historical_data.shape[1]
shared_data = shared_data[:,:-1] mail_ts = shared_data[:,-1]
mail_data = shared_data[:,len+1:-1]
shared_ts = shared_data[:,len]
shared_mem = shared_data[:,:len]
#print(shared_index) #print(shared_index)
unq_index,inv = torch.unique(shared_index,return_inverse = True) unq_index,inv = torch.unique(shared_index,return_inverse = True)
max_ts,idx = torch_scatter.scatter_max(shared_ts,inv,0) max_ts,idx = torch_scatter.scatter_max(shared_ts,inv,0)
#shared_ts = torch_scatter.scatter_mean(shared_ts,inv,0) #shared_ts = torch_scatter.scatter_mean(shared_ts,inv,0)
#shared_data = torch_scatter.scatter_mean(shared_data,inv,0) #shared_data = torch_scatter.scatter_mean(shared_data,inv,0)
shared_data = shared_data[idx] shared_mem = shared_mem[idx]
shared_ts = shared_ts[idx] shared_ts = shared_ts[idx]
mail_data = mail_data[idx]
mail_ts = mail_ts[idx]
shared_index = unq_index shared_index = unq_index
#print('{} {} {}\n'.format(shared_index,shared_data,shared_ts)) #print('{} {} {}\n'.format(shared_index,shared_data,shared_ts))
# if filter is not None: # if filter is not None:
...@@ -151,10 +156,10 @@ class HistoricalCache: ...@@ -151,10 +156,10 @@ class HistoricalCache:
# change /=change.max() # change /=change.max()
# change = 2*change - 1 # change = 2*change - 1
# filter.update(shared_index,change) # filter.update(shared_index,change)
self.local_historical_data[shared_index] = shared_data self.local_historical_data[shared_index] = shared_mem
self.local_ts[shared_index] = shared_ts self.local_ts[shared_index] = shared_ts
self.last_shared_update_wait = None self.last_shared_update_wait = None
return shared_index,shared_data,shared_ts return shared_index,shared_mem,shared_ts,mail_data,mail_ts
......
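In the updated synchronize_shared_update, each broadcast row now carries the memory, its timestamp, the mail, and the mail timestamp packed side by side, and per-node duplicates are resolved by the newest memory timestamp. A hedged sketch of that unpack-and-deduplicate step, assuming rows laid out as [memory (mem_dim) | mem_ts | mail | mail_ts]:

    import torch
    import torch_scatter

    def split_and_dedup(shared_index, shared_data, mem_dim):
        mail_ts   = shared_data[:, -1]
        mail_data = shared_data[:, mem_dim + 1:-1]
        mem_ts    = shared_data[:, mem_dim]
        mem       = shared_data[:, :mem_dim]
        # For every shared node id keep only the update with the newest memory timestamp.
        unq_index, inv = torch.unique(shared_index, return_inverse=True)
        max_ts, idx = torch_scatter.scatter_max(mem_ts, inv, 0)
        return unq_index, mem[idx], mem_ts[idx], mail_data[idx], mail_ts[idx]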
...@@ -136,52 +136,52 @@ class RNNMemeoryUpdater(torch.nn.Module): ...@@ -136,52 +136,52 @@ class RNNMemeoryUpdater(torch.nn.Module):
def empty_cache(self): def empty_cache(self):
pass pass
class TransformerMemoryUpdater(torch.nn.Module): # class TransformerMemoryUpdater(torch.nn.Module):
def __init__(self, memory_param, dim_in, dim_out, dim_time, train_param): # def __init__(self, memory_param, dim_in, dim_out, dim_time, train_param):
super(TransformerMemoryUpdater, self).__init__() # super(TransformerMemoryUpdater, self).__init__()
self.memory_param = memory_param # self.memory_param = memory_param
self.dim_time = dim_time # self.dim_time = dim_time
self.att_h = memory_param['attention_head'] # self.att_h = memory_param['attention_head']
if dim_time > 0: # if dim_time > 0:
self.time_enc = TimeEncode(dim_time) # self.time_enc = TimeEncode(dim_time)
self.w_q = torch.nn.Linear(dim_out, dim_out) # self.w_q = torch.nn.Linear(dim_out, dim_out)
self.w_k = torch.nn.Linear(dim_in + dim_time, dim_out) # self.w_k = torch.nn.Linear(dim_in + dim_time, dim_out)
self.w_v = torch.nn.Linear(dim_in + dim_time, dim_out) # self.w_v = torch.nn.Linear(dim_in + dim_time, dim_out)
self.att_act = torch.nn.LeakyReLU(0.2) # self.att_act = torch.nn.LeakyReLU(0.2)
self.layer_norm = torch.nn.LayerNorm(dim_out) # self.layer_norm = torch.nn.LayerNorm(dim_out)
self.mlp = torch.nn.Linear(dim_out, dim_out) # self.mlp = torch.nn.Linear(dim_out, dim_out)
self.dropout = torch.nn.Dropout(train_param['dropout']) # self.dropout = torch.nn.Dropout(train_param['dropout'])
self.att_dropout = torch.nn.Dropout(train_param['att_dropout']) # self.att_dropout = torch.nn.Dropout(train_param['att_dropout'])
self.last_updated_memory = None # self.last_updated_memory = None
self.last_updated_ts = None # self.last_updated_ts = None
self.last_updated_nid = None # self.last_updated_nid = None
def forward(self, mfg, param = None): # def forward(self, mfg, param = None):
for b in mfg: # for b in mfg:
Q = self.w_q(b.srcdata['mem']).reshape((b.num_src_nodes(), self.att_h, -1)) # Q = self.w_q(b.srcdata['mem']).reshape((b.num_src_nodes(), self.att_h, -1))
mails = b.srcdata['mem_input'].reshape((b.num_src_nodes(), self.memory_param['mailbox_size'], -1)) # mails = b.srcdata['mem_input'].reshape((b.num_src_nodes(), self.memory_param['mailbox_size'], -1))
if self.dim_time > 0: # if self.dim_time > 0:
time_feat = self.time_enc(b.srcdata['ts'][:, None] - b.srcdata['mail_ts']).reshape((b.num_src_nodes(), self.memory_param['mailbox_size'], -1)) # time_feat = self.time_enc(b.srcdata['ts'][:, None] - b.srcdata['mail_ts']).reshape((b.num_src_nodes(), self.memory_param['mailbox_size'], -1))
mails = torch.cat([mails, time_feat], dim=2) # mails = torch.cat([mails, time_feat], dim=2)
K = self.w_k(mails).reshape((b.num_src_nodes(), self.memory_param['mailbox_size'], self.att_h, -1)) # K = self.w_k(mails).reshape((b.num_src_nodes(), self.memory_param['mailbox_size'], self.att_h, -1))
V = self.w_v(mails).reshape((b.num_src_nodes(), self.memory_param['mailbox_size'], self.att_h, -1)) # V = self.w_v(mails).reshape((b.num_src_nodes(), self.memory_param['mailbox_size'], self.att_h, -1))
att = self.att_act((Q[:,None,:,:]*K).sum(dim=3)) # att = self.att_act((Q[:,None,:,:]*K).sum(dim=3))
att = torch.nn.functional.softmax(att, dim=1) # att = torch.nn.functional.softmax(att, dim=1)
att = self.att_dropout(att) # att = self.att_dropout(att)
rst = (att[:,:,:,None]*V).sum(dim=1) # rst = (att[:,:,:,None]*V).sum(dim=1)
rst = rst.reshape((rst.shape[0], -1)) # rst = rst.reshape((rst.shape[0], -1))
rst += b.srcdata['mem'] # rst += b.srcdata['mem']
rst = self.layer_norm(rst) # rst = self.layer_norm(rst)
rst = self.mlp(rst) # rst = self.mlp(rst)
rst = self.dropout(rst) # rst = self.dropout(rst)
rst = torch.nn.functional.relu(rst) # rst = torch.nn.functional.relu(rst)
b.srcdata['h'] = rst # b.srcdata['h'] = rst
self.last_updated_memory = rst.detach().clone() # self.last_updated_memory = rst.detach().clone()
self.last_updated_nid = b.srcdata['ID'].detach().clone() # self.last_updated_nid = b.srcdata['ID'].detach().clone()
self.last_updated_ts = b.srcdata['ts'].detach().clone() # self.last_updated_ts = b.srcdata['ts'].detach().clone()
def empty_cache(self): # def empty_cache(self):
pass # pass
...@@ -301,7 +301,7 @@ class HistoricalMemeoryUpdater(torch.nn.Module): ...@@ -301,7 +301,7 @@ class HistoricalMemeoryUpdater(torch.nn.Module):
# #
## new_memory[~local_mask] = self.time_for_historical(torch.cat((history_mem,history_mem_ts),dim = 1)) ## new_memory[~local_mask] = self.time_for_historical(torch.cat((history_mem,history_mem_ts),dim = 1))
self.lasted_memory = memory.detach().clone() self.lasted_memory = memory.detach().clone()
if self.memory_param['combine_node_feature']: if self.memory_param['combine_node_feature'] and self.dim_node_feat > 0:
if self.dim_node_feat > 0: if self.dim_node_feat > 0:
if self.dim_node_feat == self.dim_hid: if self.dim_node_feat == self.dim_hid:
b.srcdata['h'] += memory b.srcdata['h'] += memory
...@@ -309,10 +309,49 @@ class HistoricalMemeoryUpdater(torch.nn.Module): ...@@ -309,10 +309,49 @@ class HistoricalMemeoryUpdater(torch.nn.Module):
b.srcdata['h'] = memory + self.node_feat_map(b.srcdata['h']) b.srcdata['h'] = memory + self.node_feat_map(b.srcdata['h'])
else: else:
b.srcdata['h'] = memory b.srcdata['h'] = memory
else:
b.srcdata['h'] = memory
tt.mem_update += tt.elapsed(t_s) tt.mem_update += tt.elapsed(t_s)
class TransformerMemoryUpdater(torch.nn.Module):
def __init__(self, memory_param, dim_in, dim_out, dim_time, train_param):
super(TransformerMemoryUpdater, self).__init__()
self.memory_param = memory_param
self.dim_time = dim_time
self.att_h = memory_param['attention_head']
if dim_time > 0:
self.time_enc = TimeEncode(dim_time)
self.w_q = torch.nn.Linear(dim_out, dim_out)
self.w_k = torch.nn.Linear(dim_in + dim_time, dim_out)
self.w_v = torch.nn.Linear(dim_in + dim_time, dim_out)
self.att_act = torch.nn.LeakyReLU(0.2)
self.layer_norm = torch.nn.LayerNorm(dim_out)
self.mlp = torch.nn.Linear(dim_out, dim_out)
self.dropout = torch.nn.Dropout(train_param['dropout'])
self.att_dropout = torch.nn.Dropout(train_param['att_dropout'])
def forward(self, b, param = None):
Q = self.w_q(b.srcdata['mem']).reshape((b.num_src_nodes(), self.att_h, -1))
mails = b.srcdata['mem_input'].reshape((b.num_src_nodes(), self.memory_param['mailbox_size'], -1))
if self.dim_time > 0:
time_feat = self.time_enc(b.srcdata['ts'][:, None] - b.srcdata['mail_ts']).reshape((b.num_src_nodes(), self.memory_param['mailbox_size'], -1))
mails = torch.cat([mails, time_feat], dim=2)
K = self.w_k(mails).reshape((b.num_src_nodes(), self.memory_param['mailbox_size'], self.att_h, -1))
V = self.w_v(mails).reshape((b.num_src_nodes(), self.memory_param['mailbox_size'], self.att_h, -1))
att = self.att_act((Q[:,None,:,:]*K).sum(dim=3))
att = torch.nn.functional.softmax(att, dim=1)
att = self.att_dropout(att)
rst = (att[:,:,:,None]*V).sum(dim=1)
rst = rst.reshape((rst.shape[0], -1))
rst += b.srcdata['mem']
rst = self.layer_norm(rst)
rst = self.mlp(rst)
rst = self.dropout(rst)
rst = torch.nn.functional.relu(rst)
return rst
class AsyncMemeoryUpdater(torch.nn.Module): class AsyncMemeoryUpdater(torch.nn.Module):
def all_update_func(self,index,memory,memory_ts,mail,mail_ts,nxt_fetch_func): def all_update_func(self,index,memory,memory_ts,mail,mail_ts,nxt_fetch_func):
self.mailbox.set_memory_all_reduce(index,memory,memory_ts,mail,mail_ts,reduce_Op = 'max', async_op = False,filter=None,mode='all_reduce',set_remote=True) self.mailbox.set_memory_all_reduce(index,memory,memory_ts,mail,mail_ts,reduce_Op = 'max', async_op = False,filter=None,mode='all_reduce',set_remote=True)
...@@ -339,13 +378,26 @@ class AsyncMemeoryUpdater(torch.nn.Module): ...@@ -339,13 +378,26 @@ class AsyncMemeoryUpdater(torch.nn.Module):
def local_func(self,index,memory,memory_ts,mail,mail_ts,nxt_fetch_func): def local_func(self,index,memory,memory_ts,mail,mail_ts,nxt_fetch_func):
if nxt_fetch_func is not None: if nxt_fetch_func is not None:
nxt_fetch_func() nxt_fetch_func()
def __init__(self, memory_param, dim_in, dim_hid, dim_time, dim_node_feat,updater,mode = None,mailbox = None): def transformer_updater(self,b):
return self.ceil_updater(b)
def rnn_updater(self,b):
if self.dim_time > 0:
#print(b.srcdata['ts'].shape,b.srcdata['mem_ts'].shape)
time_feat = self.time_enc(b.srcdata['ts'] - b.srcdata['mem_ts'])
b.srcdata['mem_input'] = torch.cat([b.srcdata['mem_input'], time_feat], dim=1)
return self.ceil_updater(b.srcdata['mem_input'], b.srcdata['mem'])
def __init__(self, memory_param, dim_in, dim_hid, dim_time, dim_node_feat,updater,mode = None,mailbox = None,train_param=None):
super(AsyncMemeoryUpdater, self).__init__() super(AsyncMemeoryUpdater, self).__init__()
self.dim_hid = dim_hid self.dim_hid = dim_hid
self.dim_node_feat = dim_node_feat self.dim_node_feat = dim_node_feat
self.memory_param = memory_param self.memory_param = memory_param
self.dim_time = dim_time self.dim_time = dim_time
self.updater = updater(dim_in + dim_time, dim_hid) if memory_param['memory_update'] == 'transformer':
self.ceil_updater = updater(memory_param, dim_in, dim_hid, dim_time, train_param)
self.updater = self.transformer_updater
else:
self.updater = updater(dim_in + dim_time, dim_hid)
self.updater = self.rnn_updater
self.last_updated_memory = None self.last_updated_memory = None
self.last_updated_ts = None self.last_updated_ts = None
self.last_updated_nid = None self.last_updated_nid = None
...@@ -365,20 +417,15 @@ class AsyncMemeoryUpdater(torch.nn.Module): ...@@ -365,20 +417,15 @@ class AsyncMemeoryUpdater(torch.nn.Module):
self.update_hunk = self.all_reduce_func self.update_hunk = self.all_reduce_func
elif self.mode == 'historical': elif self.mode == 'historical':
self.update_hunk = self.historical_func self.update_hunk = self.historical_func
elif self.mode == 'local': elif self.mode == 'local' or self.mode=='all_local':
self.update_hunk = self.local_func self.update_hunk = self.local_func
def forward(self, mfg, param = None): def forward(self, mfg, param = None):
for b in mfg: for b in mfg:
if self.dim_time > 0: updated_memory = self.updater(b)
#print(b.srcdata['ts'].shape,b.srcdata['mem_ts'].shape)
time_feat = self.time_enc(b.srcdata['ts'] - b.srcdata['mem_ts'])
b.srcdata['mem_input'] = torch.cat([b.srcdata['mem_input'], time_feat], dim=1)
updated_memory = self.updater(b.srcdata['mem_input'], b.srcdata['mem'])
self.last_updated_ts = b.srcdata['ts'].detach().clone() self.last_updated_ts = b.srcdata['ts'].detach().clone()
self.last_updated_memory = updated_memory.detach().clone() self.last_updated_memory = updated_memory.detach().clone()
self.last_updated_nid = b.srcdata['ID'].detach().clone() self.last_updated_nid = b.srcdata['ID'].detach().clone()
with torch.no_grad(): with torch.no_grad():
if param is not None: if param is not None:
_,src,dst,ts,edge_feats,nxt_fetch_func = param _,src,dst,ts,edge_feats,nxt_fetch_func = param
...@@ -401,14 +448,13 @@ class AsyncMemeoryUpdater(torch.nn.Module): ...@@ -401,14 +448,13 @@ class AsyncMemeoryUpdater(torch.nn.Module):
self.mailbox.set_memory_local(DistIndex(index[local_mask]).loc,memory[local_mask],memory_ts[local_mask], Reduce_Op = 'max') self.mailbox.set_memory_local(DistIndex(index[local_mask]).loc,memory[local_mask],memory_ts[local_mask], Reduce_Op = 'max')
self.update_hunk(index,memory,memory_ts,mail,mail_ts,nxt_fetch_func) self.update_hunk(index,memory,memory_ts,mail,mail_ts,nxt_fetch_func)
if self.memory_param['combine_node_feature']: if self.memory_param['combine_node_feature'] and self.dim_node_feat > 0:
if self.dim_node_feat > 0: if self.dim_node_feat == self.dim_hid:
if self.dim_node_feat == self.dim_hid: b.srcdata['h'] += memory
b.srcdata['h'] += updated_memory
else:
b.srcdata['h'] = updated_memory + self.node_feat_map(b.srcdata['h'])
else: else:
b.srcdata['h'] = updated_memory b.srcdata['h'] = memory + self.node_feat_map(b.srcdata['h'])
else:
b.srcdata['h'] = memory
def empty_cache(self): def empty_cache(self):
pass pass
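Both memory updaters now guard the node-feature combination on dim_node_feat > 0 and otherwise write the memory straight into b.srcdata['h']. A minimal sketch of the revised rule (node_feat_map stands in for the module's projection layer):

    import torch

    def combine_node_feature(h, memory, dim_node_feat, dim_hid,
                             node_feat_map=None, combine=True):
        if combine and dim_node_feat > 0:
            if dim_node_feat == dim_hid:
                return h + memory                   # same width: plain residual add
            return memory + node_feat_map(h)        # project node features first
        return memory                               # no features: hidden state is the memory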
...@@ -68,25 +68,26 @@ class GeneralModel(torch.nn.Module): ...@@ -68,25 +68,26 @@ class GeneralModel(torch.nn.Module):
self.train_param = train_param self.train_param = train_param
if memory_param['type'] == 'node': if memory_param['type'] == 'node':
if memory_param['memory_update'] == 'gru': if memory_param['memory_update'] == 'gru':
if memory_param['async'] == False: #if memory_param['async'] == False:
self.memory_updater = GRUMemeoryUpdater(memory_param, 2 * memory_param['dim_out'] + dim_edge, memory_param['dim_out'], memory_param['dim_time'], dim_node) # self.memory_updater = GRUMemeoryUpdater(memory_param, 2 * memory_param['dim_out'] + dim_edge, memory_param['dim_out'], memory_param['dim_time'], dim_node)
else: #else:
updater = torch.nn.GRUCell updater = torch.nn.GRUCell
if memory_param['historical_fix'] == False: # if memory_param['historical_fix'] == False:
self.memory_updater = AsyncMemeoryUpdater(memory_param, 2 * memory_param['dim_out'] + dim_edge, memory_param['dim_out'], memory_param['dim_time'], dim_node, updater=updater, mailbox=mailbox, mode = memory_param['mode']) self.memory_updater = AsyncMemeoryUpdater(memory_param, 2 * memory_param['dim_out'] + dim_edge, memory_param['dim_out'], memory_param['dim_time'], dim_node, updater=updater, mailbox=mailbox, mode = memory_param['mode'])
else: # else:
self.memory_updater = HistoricalMemeoryUpdater(memory_param, 2 * memory_param['dim_out'] + dim_edge, memory_param['dim_out'], memory_param['dim_time'], dim_node,updater=updater,learnable=True,num_nodes=num_nodes) # self.memory_updater = HistoricalMemeoryUpdater(memory_param, 2 * memory_param['dim_out'] + dim_edge, memory_param['dim_out'], memory_param['dim_time'], dim_node,updater=updater,learnable=True,num_nodes=num_nodes)
elif memory_param['memory_update'] == 'rnn': elif memory_param['memory_update'] == 'rnn':
if memory_param['async'] == False: #if memory_param['async'] == False:
self.memory_updater = RNNMemeoryUpdater(memory_param, 2 * memory_param['dim_out'] + dim_edge, memory_param['dim_out'], memory_param['dim_time'], dim_node) # self.memory_updater = RNNMemeoryUpdater(memory_param, 2 * memory_param['dim_out'] + dim_edge, memory_param['dim_out'], memory_param['dim_time'], dim_node)
else: #else:
updater = torch.nn.RNNCell updater = torch.nn.RNNCell
if memory_param['historical_fix'] == False: #if memory_param['historical_fix'] == False:
self.memory_updater = AsyncMemeoryUpdater(memory_param, 2 * memory_param['dim_out'] + dim_edge, memory_param['dim_out'], memory_param['dim_time'], dim_node, updater=updater, mailbox=mailbox, mode = memory_param['mode']) self.memory_updater = AsyncMemeoryUpdater(memory_param, 2 * memory_param['dim_out'] + dim_edge, memory_param['dim_out'], memory_param['dim_time'], dim_node, updater=updater, mailbox=mailbox, mode = memory_param['mode'])
else: # else:
self.memory_updater = HistoricalMemeoryUpdater(memory_param, 2 * memory_param['dim_out'] + dim_edge, memory_param['dim_out'], memory_param['dim_time'], dim_node,updater=updater,learnable=True,num_nodes=num_nodes) # self.memory_updater = HistoricalMemeoryUpdater(memory_param, 2 * memory_param['dim_out'] + dim_edge, memory_param['dim_out'], memory_param['dim_time'], dim_node,updater=updater,learnable=True,num_nodes=num_nodes)
elif memory_param['memory_update'] == 'transformer': elif memory_param['memory_update'] == 'transformer':
self.memory_updater = TransformerMemoryUpdater(memory_param, 2 * memory_param['dim_out'] + dim_edge, memory_param['dim_out'], memory_param['dim_time'], train_param) self.memory_updater = TransformerMemoryUpdater
self.memory_updater = AsyncMemeoryUpdater(memory_param, 2 * memory_param['dim_out'] + dim_edge, memory_param['dim_out'], memory_param['dim_time'], dim_node, updater=updater, mailbox=mailbox, mode = memory_param['mode'],train_param=train_param)
else: else:
raise NotImplementedError raise NotImplementedError
self.dim_node_input = memory_param['dim_out'] self.dim_node_input = memory_param['dim_out']
......
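GeneralModel now routes every memory_update type through AsyncMemeoryUpdater and lets it build the cell itself: a TransformerMemoryUpdater called on the whole block for 'transformer', or a GRUCell/RNNCell called on (mem_input, mem) otherwise. A hedged sketch of that dispatch (the transformer cell is passed in rather than constructed, to keep the example self-contained):

    import torch

    class CellDispatch(torch.nn.Module):
        def __init__(self, memory_update, dim_in, dim_hid, dim_time, transformer_cell=None):
            super().__init__()
            if memory_update == 'transformer':
                self.cell = transformer_cell                  # e.g. a TransformerMemoryUpdater
                self.step = lambda b: self.cell(b)            # consumes the whole block
            else:
                cell_cls = torch.nn.GRUCell if memory_update == 'gru' else torch.nn.RNNCell
                self.cell = cell_cls(dim_in + dim_time, dim_hid)
                self.step = lambda b: self.cell(b.srcdata['mem_input'], b.srcdata['mem'])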
...@@ -152,7 +152,7 @@ class SharedMailBox(): ...@@ -152,7 +152,7 @@ class SharedMailBox():
def get_update_mail(self,dist_indx_mapper, def get_update_mail(self,dist_indx_mapper,
src,dst,ts,edge_feats, src,dst,ts,edge_feats,
memory,embedding=None,use_src_emb=False,use_dst_emb=False, memory,embedding=None,use_src_emb=False,use_dst_emb=False,
remote_src = None, remote_dst = None,Reduce_score=None): deliver_to='self',block = None,Reduce_score=None,):
if edge_feats is not None: if edge_feats is not None:
edge_feats = edge_feats.to(self.device).to(self.mailbox.dtype) edge_feats = edge_feats.to(self.device).to(self.mailbox.dtype)
src = src.to(self.device) src = src.to(self.device)
...@@ -164,26 +164,21 @@ class SharedMailBox(): ...@@ -164,26 +164,21 @@ class SharedMailBox():
if embedding is not None: if embedding is not None:
emb_src = embedding[src] emb_src = embedding[src]
emb_dst = embedding[dst] emb_dst = embedding[dst]
if remote_src is None: src_mail = torch.cat([emb_src if use_src_emb else mem_src, emb_dst if use_dst_emb else mem_dst], dim=1)
src_mail = torch.cat([emb_src if use_src_emb else mem_src, emb_dst if use_dst_emb else mem_dst], dim=1) dst_mail = torch.cat([emb_dst if use_src_emb else mem_dst, emb_src if use_dst_emb else mem_src], dim=1)
dst_mail = torch.cat([emb_dst if use_src_emb else mem_dst, emb_src if use_dst_emb else mem_src], dim=1) if edge_feats is not None:
if edge_feats is not None: src_mail = torch.cat([src_mail, edge_feats], dim=1)
src_mail = torch.cat([src_mail, edge_feats], dim=1) dst_mail = torch.cat([dst_mail, edge_feats], dim=1)
dst_mail = torch.cat([dst_mail, edge_feats], dim=1) mail = torch.cat([src_mail, dst_mail], dim=0)
mail = torch.cat([src_mail, dst_mail], dim=0) mail_ts = torch.cat((ts,ts),-1).to(self.device).to(self.mailbox_ts.dtype)
mail_ts = torch.cat((ts,ts),-1).to(self.device).to(self.mailbox_ts.dtype)
#print(mail_ts) #print(mail_ts)
if Reduce_score is not None: if deliver_to == 'neighbor':
Reduce_score = torch.cat((Reduce_score,Reduce_score),-1).to(self.device) assert block is None and Reduce_score is not None
else: mail = torch.cat([mail, mail[block.edges()[1].long()]], dim=0)
src_mail = torch.cat([emb_src if use_src_emb else mem_src, remote_dst], dim=1) mail_ts = torch.cat([mail_ts, mail_ts[block.edges()[1].long()]], dim=0)
dst_mail = torch.cat([emb_dst if use_src_emb else mem_dst, remote_src], dim=1) index = torch.cat([index,dist_indx_mapper[block.edges()[0].long()]],dim=0)
if edge_feats is not None: if Reduce_score is not None:
src_mail = torch.cat([src_mail, edge_feats[:src_mail.shape[0]]], dim=1) Reduce_score = torch.cat((Reduce_score,Reduce_score),-1).to(self.device)
dst_mail = torch.cat([dst_mail, edge_feats[src_mail.shape[0]:]], dim=1)
mail = torch.cat([src_mail, dst_mail], dim=0)
mail_ts = ts.to(self.device).to(self.mailbox_ts.dtype)
#.reshape(-1, src_mail.shape[1])
if Reduce_score is None: if Reduce_score is None:
unq_index,inv = torch.unique(index,return_inverse = True) unq_index,inv = torch.unique(index,return_inverse = True)
max_ts,idx = torch_scatter.scatter_max(mail_ts,inv,0) max_ts,idx = torch_scatter.scatter_max(mail_ts,inv,0)
...@@ -250,10 +245,12 @@ class SharedMailBox(): ...@@ -250,10 +245,12 @@ class SharedMailBox():
def sychronize_shared(self): def sychronize_shared(self):
out=self.historical_cache.synchronize_shared_update() out=self.historical_cache.synchronize_shared_update()
if out is not None: if out is not None:
shared_index,shared_data,shared_ts = out shared_index,shared_data,shared_ts,mail,mail_ts = out
#print(shared_ts,self.node_memory_ts.accessor.data[self.shared_nodes_index[shared_index]]) index = self.shared_nodes_index[shared_index]
self.node_memory.accessor.data[self.shared_nodes_index[shared_index]] = shared_data self.node_memory.accessor.data[index] = shared_data
self.node_memory_ts.accessor.data[self.shared_nodes_index[shared_index]] = shared_ts self.node_memory_ts.accessor.data[index] = shared_ts
self.mailbox.accessor.data[index, torch.max(self.next_mail_pos[index]-1,torch.tensor([0],device=mail.device))] = mail
self.mailbox_ts.accessor.data[index, torch.max(self.next_mail_pos[index]-1,torch.tensor([0],device=mail_ts.device))] = mail_ts
def update_shared(self): def update_shared(self):
ctx = DistributedContext.get_default_context() ctx = DistributedContext.get_default_context()
if self.last_job is not None: if self.last_job is not None:
...@@ -300,7 +297,6 @@ class SharedMailBox(): ...@@ -300,7 +297,6 @@ class SharedMailBox():
shared_memory_ts = memory_ts[mask] shared_memory_ts = memory_ts[mask]
shared_mail = mail[mask] shared_mail = mail[mask]
shared_mail_ts = mail_ts[mask] shared_mail_ts = mail_ts[mask]
update_index = self.historical_cache.historical_check(shared_memory_ind,shared_memory,shared_memory_ts)
if mode == 'historical': if mode == 'historical':
#print(shared_memory_ind) #print(shared_memory_ind)
update_index = self.historical_cache.historical_check(shared_memory_ind,shared_memory,shared_memory_ts) update_index = self.historical_cache.historical_check(shared_memory_ind,shared_memory,shared_memory_ts)
...@@ -308,12 +304,14 @@ class SharedMailBox(): ...@@ -308,12 +304,14 @@ class SharedMailBox():
shared_memory_ind = shared_memory_ind[update_index] shared_memory_ind = shared_memory_ind[update_index]
shared_memory = shared_memory[update_index] shared_memory = shared_memory[update_index]
shared_memory_ts = shared_memory_ts[update_index] shared_memory_ts = shared_memory_ts[update_index]
shared_mail = shared_mail[update_index]
shared_mail_ts = shared_mail_ts[update_index]
#print(shared_memory_ind) #print(shared_memory_ind)
#mem = self.pack(memory=shared_memory,memory_ts=shared_memory_ts,mail=shared_mail,mail_ts=shared_mail_ts,index=shared_memory_ind,mode=mode) #mem = self.pack(memory=shared_memory,memory_ts=shared_memory_ts,mail=shared_mail,mail_ts=shared_mail_ts,index=shared_memory_ind,mode=mode)
#mem = self.pack(memory=shared_mail,memory_ts=shared_mail_ts,index=shared_memory_ind,mode=mode) #mem = self.pack(memory=shared_mail,memory_ts=shared_mail_ts,index=shared_memory_ind,mode=mode)
self.tot_shared_count += shared_memory_ind.shape[0] self.tot_shared_count += shared_memory_ind.shape[0]
mem = self.pack(memory=shared_memory,memory_ts=shared_memory_ts,index=shared_memory_ind,mode=mode) mem = self.pack(memory=shared_memory,memory_ts=shared_memory_ts,mail=shared_mail,mail_ts = shared_mail_ts,index=shared_memory_ind,mode=mode)
broadcast_len = torch.empty([1],device = mem.device,dtype = torch.int) broadcast_len = torch.empty([1],device = mem.device,dtype = torch.int)
broadcast_len[0] = shared_memory_ind.shape[0] broadcast_len[0] = shared_memory_ind.shape[0]
shared_len = [torch.empty([1],device = mem.device,dtype = torch.int) for _ in range(ctx.memory_group_size)] shared_len = [torch.empty([1],device = mem.device,dtype = torch.int) for _ in range(ctx.memory_group_size)]
...@@ -337,23 +335,23 @@ class SharedMailBox(): ...@@ -337,23 +335,23 @@ class SharedMailBox():
#id = shared_index.sort() #id = shared_index.sort()
#print(mem[id],shared_index[id]) #print(mem[id],shared_index[id])
#,shared_memory,shared_memory_ts, #,shared_memory,shared_memory_ts,
shared_memory,shared_memory_ts = self.unpack(mem) #shared_memory,shared_memory_ts = self.unpack(mem)
#shared_memory,shared_memory_ts,shared_mail,shared_mail_ts = self.unpack(mem) shared_memory,shared_memory_ts,shared_mail,shared_mail_ts = self.unpack(mem)
unq_index,inv = torch.unique(shared_index,return_inverse = True) unq_index,inv = torch.unique(shared_index,return_inverse = True)
#print(inv.shape,Reduce_score.shape) #print(inv.shape,Reduce_score.shape)
max_ts,idx = torch_scatter.scatter_max(shared_memory_ts,inv,0) max_ts,idx = torch_scatter.scatter_max(shared_memory_ts,inv,0)
#min_ts,_ = torch_scatter.scatter_min(shared_mail_ts,inv,0) #min_ts,_ = torch_scatter.scatter_min(shared_mail_ts,inv,0)
shared_memory = shared_memory[idx] shared_memory = shared_memory[idx]
shared_memory_ts = shared_memory_ts[idx] shared_memory_ts = shared_memory_ts[idx]
#shared_mail_ts = shared_mail_ts[idx] shared_mail_ts = shared_mail_ts[idx]
#shared_mail = shared_mail[idx] shared_mail = shared_mail[idx]
#shared_mail_ts = torch_scatter.scatter_mean(shared_mail_ts,inv,0) #shared_mail_ts = torch_scatter.scatter_mean(shared_mail_ts,inv,0)
#shared_mail = torch_scatter.scatter_mean(shared_mail,inv,0) #shared_mail = torch_scatter.scatter_mean(shared_mail,inv,0)
shared_index = unq_index shared_index = unq_index
self.set_memory_local(self.shared_nodes_index[shared_index],shared_memory,shared_memory_ts) self.set_memory_local(self.shared_nodes_index[shared_index],shared_memory,shared_memory_ts)
self.historical_cache.local_historical_data[shared_index] = shared_memory self.historical_cache.local_historical_data[shared_index] = shared_memory
self.historical_cache.local_ts[shared_index] = shared_memory_ts self.historical_cache.local_ts[shared_index] = shared_memory_ts
#self.set_mailbox_local(self.shared_nodes_index[shared_index],shared_mail,shared_mail_ts) self.set_mailbox_local(self.shared_nodes_index[shared_index],shared_mail,shared_mail_ts)
else: else:
start = torch.cuda.Event(enable_timing=True) start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True) end = torch.cuda.Event(enable_timing=True)
......
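sychronize_shared now also writes the broadcast mail back into each shared node's most recently used mailbox slot (next_mail_pos - 1, clamped at 0). A hedged sketch of that write, assuming a mailbox of shape [num_nodes, mailbox_size, dim]:

    import torch

    def write_shared_mail(mailbox, mailbox_ts, next_mail_pos, index, mail, mail_ts):
        slot = torch.clamp(next_mail_pos[index] - 1, min=0)   # last used slot, never negative
        mailbox[index, slot] = mail
        mailbox_ts[index, slot] = mail_ts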
...@@ -300,6 +300,11 @@ def load_from_speed(data,seed,top,sampler_graph_add_rev,device=torch.device('cud ...@@ -300,6 +300,11 @@ def load_from_speed(data,seed,top,sampler_graph_add_rev,device=torch.device('cud
fnode_share = '../../SPEED/partition/divided_nodes_ldg/{}/{}/{}_{}parts_top{}/outputshared.txt'.format(data,seed,data,ctx.memory_group_size,top) fnode_share = '../../SPEED/partition/divided_nodes_ldg/{}/{}/{}_{}parts_top{}/outputshared.txt'.format(data,seed,data,ctx.memory_group_size,top)
reorder = '../../SPEED/partition/divided_nodes_ldg/{}/reorder.txt'.format(data) reorder = '../../SPEED/partition/divided_nodes_ldg/{}/reorder.txt'.format(data)
edge_i = '../../SPEED/partition/divided_nodes_ldg/{}/{}/{}_{}parts_top{}/edge_output{}.txt'.format(data,seed,data,ctx.memory_group_size,top,ctx.memory_group_rank) edge_i = '../../SPEED/partition/divided_nodes_ldg/{}/{}/{}_{}parts_top{}/edge_output{}.txt'.format(data,seed,data,ctx.memory_group_size,top,ctx.memory_group_rank)
elif partition == 'dis_tgl':
fnode_i = '../../SPEED/partition/divided_nodes_seed_dis/{}/{}/{}_{}parts_top{}/output{}.txt'.format(data,seed,data,ctx.memory_group_size,top,ctx.memory_group_rank)
fnode_share = '../../SPEED/partition/divided_nodes_seed_dis/{}/{}/{}_{}parts_top{}/outputshared.txt'.format(data,seed,data,ctx.memory_group_size,top)
reorder = '../../SPEED/partition/divided_nodes_seed_dis/{}/reorder.txt'.format(data)
edge_i = '../../SPEED/partition/divided_nodes_seed_dis/{}/{}/{}_{}parts_top{}/edge_output{}.txt'.format(data,seed,data,ctx.memory_group_size,top,ctx.memory_group_rank)
elif partition == 'random': elif partition == 'random':
df = load_graph(data) df = load_graph(data)
src = torch.from_numpy(np.array(df.src.values)).long() src = torch.from_numpy(np.array(df.src.values)).long()
...@@ -348,7 +353,10 @@ def load_from_speed(data,seed,top,sampler_graph_add_rev,device=torch.device('cud ...@@ -348,7 +353,10 @@ def load_from_speed(data,seed,top,sampler_graph_add_rev,device=torch.device('cud
node_i = torch.tensor(node_list).reshape(-1).to(torch.long) node_i = torch.tensor(node_list).reshape(-1).to(torch.long)
edge_i = torch.tensor(eid_list).reshape(-1).to(torch.long) edge_i = torch.tensor(eid_list).reshape(-1).to(torch.long)
#reid = torch.arange(len(reid))#torch.tensor(reid).reshape(-1) #reid = torch.arange(len(reid))#torch.tensor(reid).reshape(-1)
shared_node = torch.tensor(shared_node_list).reshape(-1).to(torch.long) if partition == 'dis_tgl':
shared_node = torch.tensor([],dtype=torch.long)
else:
shared_node = torch.tensor(shared_node_list).reshape(-1).to(torch.long)
#print(reid) #print(reid)
return load_from_shared_node_partition(data,node_i,shared_node,sample_add_rev=sampler_graph_add_rev,edge_i=edge_i,reid=None,device=device,feature_device=feature_device) return load_from_shared_node_partition(data,node_i,shared_node,sample_add_rev=sampler_graph_add_rev,edge_i=edge_i,reid=None,device=device,feature_device=feature_device)
......
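The loader gains a 'dis_tgl' partition branch that reads per-rank node and edge lists from the divided_nodes_seed_dis layout and then forces an empty shared-node set. A hedged helper, not part of the repo, that spells out the expected file layout:

    def dis_tgl_partition_files(data, seed, parts, top, rank,
                                root='../../SPEED/partition/divided_nodes_seed_dis'):
        base = '{}/{}/{}/{}_{}parts_top{}'.format(root, data, seed, data, parts, top)
        return {
            'nodes':   '{}/output{}.txt'.format(base, rank),        # this rank's node list
            'shared':  '{}/outputshared.txt'.format(base),          # ignored: dis_tgl uses no shared nodes
            'reorder': '{}/{}/reorder.txt'.format(root, data),
            'edges':   '{}/edge_output{}.txt'.format(base, rank),   # this rank's edge list
        }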