Commit 86a4611d by zlj

add APAN model

parent 24a069d6
......@@ -10,6 +10,7 @@ __pycache__/
# Distribution / packaging
.Python
examples/all*
build/
develop-eggs/
dist/
......
sampling:
- layer: 1
neighbor:
- 10
- 20
strategy: 'recent'
prop_time: False
history: 1
......
......@@ -263,7 +263,7 @@ void ParallelSampler :: neighbor_sample_from_nodes_with_before_layer(
int cal_cnt = 0;
for(int cid = end_index-1;cid>=0;cid--){
cal_cnt++;
if(cal_cnt > 2*fanout)break;
if(cal_cnt > fanout)break;
int eid = tnb.eid[node][cid];
if(part[tnb.eid[node][cid]] != local_part|| node_part[tnb.neighbors[node][cid]]!= local_part){
double p0 = (double)rand_r(&loc_seeds[tid]) / (RAND_MAX + 1.0);
......
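A minimal Python sketch of what this sampler hunk appears to do, for orientation only: candidate neighbors are scanned newest-first, the scan stops after `fanout` candidates (previously `2*fanout`), and a uniform draw `p0` decides whether a cross-partition edge is kept. The function and parameter names below are hypothetical, and the keep/drop rule is assumed, since the hunk ends right after the draw.

```python
import random

def sample_recent_neighbors(neighbors, eids, edge_part, node_part, local_part,
                            fanout, keep_prob, rng=random.random):
    """Illustrative restatement of the loop above: scan a node's neighbors from
    newest to oldest, stop after `fanout` candidates, and keep edges that cross
    the partition boundary only with probability `keep_prob`.
    All names here are hypothetical, not the sampler's real API."""
    sampled = []
    for cal_cnt, cid in enumerate(range(len(neighbors) - 1, -1, -1), start=1):
        if cal_cnt > fanout:                        # mirrors `if(cal_cnt > fanout) break;`
            break
        is_cross = (edge_part[eids[cid]] != local_part
                    or node_part[neighbors[cid]] != local_part)
        if is_cross and rng() >= keep_prob:         # assumed keep/drop rule for cross edges
            continue
        sampled.append((neighbors[cid], eids[cid]))
    return sampled
```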

examples/all/WIKI_boundary_Convergence_rate.png (image changed: 43.7 KB → 41.9 KB)
examples/all/boundary_AP_WIKI.png (image changed: 14.2 KB → 14.2 KB)
examples/all/boundary_comm_WIKI.png (image changed: 16.9 KB → 15.8 KB)
......@@ -2,12 +2,12 @@ import matplotlib.pyplot as plt
import numpy as np
import torch
# read the file contents
ssim_values = [0, 0.1, 0.2, 0.3, 0.4, 2] # assumed ssim parameter values
data_values = ['WikiTalk'] # data read from the files
ssim_values = [0, 0.5, 1.0, 1.5, 2] # assumed ssim parameter values
data_values = ['WIKI','WikiTalk','REDDIT','LASTFM','DGraphFin'] # data read from the files
partition = 'ours_shared'
# read the data from a file; assume the data is stored in data.txt
#all/"$data"/"$partitions"-ours_shared-0.01-"$mem"-"$ssim"-"$sample".out
partitions=8
partitions=4
topk=0.01
mem='historical'
for data in data_values:
......@@ -28,7 +28,7 @@ for data in data_values:
comm = int(line[pos+2+len('shared comm tensor'):len(line)-3])
ap_list.append(ap)
comm_list.append(comm)
print('{} TestAP={}\n'.format(data,ap_list))
# plot the bar chart
bar_width = 0.4
......@@ -40,12 +40,11 @@ for data in data_values:
# plot the bar chart
plt.bar([b for b in bars], ap_list, width=bar_width)
# plot the bar chart
plt.ylim([0.9,1])
plt.xticks([b for b in bars], ssim_values)
plt.xlabel('SSIM threshold Values')
plt.ylabel('Test AP')
plt.title('{}({} partitions)'.format(data,partitions))
plt.savefig('ssim_{}.png'.format(data))
plt.savefig('ssim_{}_{}.png'.format(data,partitions))
plt.clf()
plt.bar([b for b in bars], comm_list, width=bar_width)
......@@ -54,7 +53,7 @@ for data in data_values:
plt.xlabel('SSIM threshold Values')
plt.ylabel('Communication volume')
plt.title('{}({} partitions)'.format(data,partitions))
plt.savefig('comm_{}.png'.format(data))
plt.savefig('ssim_comm_{}_{}.png'.format(data,partitions))
plt.clf()
if partition == 'ours_shared':
......@@ -69,11 +68,12 @@ for data in data_values:
val_ap = torch.tensor(torch.load(file))
epoch = torch.arange(val_ap.shape[0])
# plot the line chart
print(val_ap)
plt.plot(epoch,val_ap, label='ssim={}'.format(ssim))
plt.xlabel('Epoch')
plt.ylabel('Val AP')
plt.title('{}({} partitions)'.format(data,partitions))
# plt.grid(True)
plt.legend()
plt.savefig('{}_ssim_Convergence_rate.png'.format(data))
plt.savefig('{}_{}_ssim_Convergence_rate.png'.format(data,partitions))
plt.clf()
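For readability, a hedged sketch of the log parsing the script above performs; the exact layout of the `.out` files is assumed from the substrings it searches for ('TestAP' / 'shared comm tensor'), and the helper name is hypothetical.

```python
import re

def parse_out_file(path):
    """Hypothetical helper: pull the test AP and the shared-communication volume
    out of one run's .out log, matching the substrings the script above uses."""
    ap, comm = None, None
    float_re = re.compile(r'[-+]?\d*\.?\d+')
    with open(path) as f:
        for line in f:
            if 'TestAP' in line or 'test AP' in line:
                m = float_re.search(line.split('AP')[-1])
                if m:
                    ap = float(m.group())          # last AP line wins
            if 'shared comm tensor' in line:
                m = float_re.search(line.split('shared comm tensor')[-1])
                if m:
                    comm = int(float(m.group()))   # communication volume
    return ap, comm
```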
This source diff could not be displayed because it is too large. You can view the blob instead.
......@@ -2,7 +2,7 @@
# define the array variables
addr="192.168.1.107"
partition_params=("ours")
partition_params=("dis_tgl" "ours" "metis" "random")
#"metis" "ldg" "random")
#("ours" "metis" "ldg" "random")
partitions="4"
......@@ -10,15 +10,17 @@ node_per="4"
nnodes="1"
node_rank="0"
probability_params=("1" "0.5" "0.1" "0.05" "0.01" "0")
#sample_type_params=("recent") #"boundery_recent_decay" "boundery_recent_uniform")
sample_type_params=("recent" "boundery_recent_decay" "boundery_recent_uniform")
#sample_type_params=("recent")
#sample_type_params=("recent" "boundery_recent_decay" "boundery_recent_uniform")
#sample_type_params=("recent" "boundery_recent_decay") #"boundery_recent_uniform")
sample_type_params=("recent")
#memory_type=("all_update" "p2p" "all_reduce" "historical" "local")
memory_type=("all_update")
memory_type=("all_update" "local" "historical")
#"historical" "all_update") #"local" "historical")
#memory_type=("local" "all_update" "historical" "all_reduce")
shared_memory_ssim=("0" "0.1" "0.2" "0.3" "0.4" )
shared_memory_ssim=("0" "0.5" "1.0" "1.5")
#data_param=("WIKI" "REDDIT" "LASTFM" "WikiTalk")
data_param=("DGraphFin" "WikiTalk")
data_param=("LASTFM")
#data_param=("WIKI" "REDDIT" "LASTFM" "DGraphFin" "WikiTalk" "StackOverflow")
#data_param=("WIKI" "REDDIT" "LASTFM" "WikiTalk" "StackOverflow")
#data_param=("REDDIT" "WikiTalk")
# create the output directory
......@@ -54,7 +56,7 @@ for data in "${data_param[@]}"; do
else
torchrun --nnodes "$nnodes" --node_rank "$node_rank" --nproc-per-node "$node_per" --master-addr "$addr" --master-port 9445 train_boundery.py --dataname "$data" --mode "$model" --partition "$partition" --topk 0 --sample_type "$sample" --memory_type "$mem" > all/"$data"/"$partitions"-"$partition"-0-"$mem"-"$sample".out &
wait
if [ "$partition" = "ours" ]; then
if [ "$partition" = "ours" ] && [ "$mem" != "all_local" ]; then
torchrun --nnodes "$nnodes" --node_rank "$node_rank" --nproc-per-node "$node_per" --master-addr "$addr" --master-port 9445 train_boundery.py --dataname "$data" --mode "$model" --partition "$partition" --topk 0.01 --sample_type "$sample" --memory_type "$mem" > all/"$data"/"$partitions"-ours_shared-0.01-"$mem"-"$sample".out &
wait
fi
......@@ -64,21 +66,22 @@ for data in "${data_param[@]}"; do
for pro in "${probability_params[@]}"; do
for mem in "${memory_type[@]}"; do
if [ "$mem" = "historical" ]; then
for ssim in "${shared_memory_ssim[@]}"; do
if [ "$partition" = "ours" ]; then
torchrun --nnodes "$nnodes" --node_rank "$node_rank" --nproc-per-node "$node_per" --master-addr "$addr" --master-port 9445 train_boundery.py --dataname "$data" --mode "$model" --partition "$partition" --topk 0.01 --sample_type "$sample" --probability "$pro" --memory_type "$mem" --shared_memory_ssim "$ssim" > all/"$data"/"$partitions"-ours_shared-0.01"$mem"-"$ssim"-"$sample"-"$pro".out &
wait
fi
done
continue
# for ssim in "${shared_memory_ssim[@]}"; do
# if [ "$partition" = "ours" ]; then
# torchrun --nnodes "$nnodes" --node_rank "$node_rank" --nproc-per-node "$node_per" --master-addr "$addr" --master-port 9445 train_boundery.py --dataname "$data" --mode "$model" --partition "$partition" --topk 0.01 --sample_type "$sample" --probability "$pro" --memory_type "$mem" --shared_memory_ssim "$ssim" > all/"$data"/"$partitions"-ours_shared-0.01"$mem"-"$ssim"-"$sample"-"$pro".out &
# wait
# fi
# done
elif [ "$mem" = "all_reduce" ]; then
if [ "$partition" = "ours" ]; then
if [ "$partition" = "ours"]; then
torchrun --nnodes "$nnodes" --node_rank "$node_rank" --nproc-per-node "$node_per" --master-addr "$addr" --master-port 9445 train_boundery.py --dataname "$data" --mode "$model" --partition "$partition" --topk 0.01 --sample_type "$sample" --probability "$pro" --memory_type "$mem" > all/"$data"/"$partitions"-ours_shared-0.01-"$mem"-"$sample"-"$pro".out&
wait
fi
else
torchrun --nnodes "$nnodes" --node_rank "$node_rank" --nproc-per-node "$node_per" --master-addr "$addr" --master-port 9445 train_boundery.py --dataname "$data" --mode "$model" --partition "$partition" --topk 0 --sample_type "$sample" --probability "$pro" --memory_type "$mem" > all/"$data"/"$partitions"-"$partition"-0-"$mem"-"$sample"-"$pro".out &
wait
if [ "$partition" = "ours" ]; then
if [ "$partition" = "ours" ] && [ "$mem" != "all_local" ]; then
torchrun --nnodes "$nnodes" --node_rank "$node_rank" --nproc-per-node "$node_per" --master-addr "$addr" --master-port 9445 train_boundery.py --dataname "$data" --mode "$model" --partition "$partition" --topk 0.01 --sample_type "$sample" --probability "$pro" --memory_type "$mem" > all/"$data"/"$partitions"-ours_shared-0.01-"$mem"-"$sample"-"$pro".out &
wait
fi
......
......@@ -79,6 +79,13 @@ parser.add_argument('--memory_type', default='all_update', type=str, metavar='W'
help='name of model')
#boundery_recent_uniform boundery_recent_decay
args = parser.parse_args()
if args.memory_type == 'all_local' or args.topk != '0':
train_cross_probability = 0
else:
train_cross_probability = 1
if args.memory_type == 'all_local':
args.sample_type = 'boundery_recent_uniform'
args.probability = 0
from sklearn.metrics import average_precision_score, roc_auc_score
import torch
import time
......@@ -271,7 +278,7 @@ def main():
is_pipeline=True,
use_local_feature = False,
device = torch.device('cuda:{}'.format(local_rank)),
probability=1 if float(args.topk) == 0 else 0
probability=train_cross_probability
)
eval_trainloader = DistributedDataLoader(graph,eval_train_data,sampler = eval_sampler,
......
......@@ -124,7 +124,7 @@ class HistoricalCache:
def synchronize_shared_update(self,filter=None):
if self.last_shared_update_wait is None:
return
return None
handle0,handle1,shared_index,shared_data = self.last_shared_update_wait
self.last_shared_update_wait = None
handle0.wait()
......@@ -132,16 +132,21 @@ class HistoricalCache:
shared_data = torch.cat(shared_data,dim = 0)
shared_index = torch.cat(shared_index)
if(shared_data.shape[0] == 0):
return
shared_ts = shared_data[:,-1]
shared_data = shared_data[:,:-1]
return None
len = self.local_historical_data.shape[1]
mail_ts = shared_data[:,-1]
mail_data = shared_data[:,len+1:-1]
shared_ts = shared_data[:,len]
shared_mem = shared_data[:,:len]
#print(shared_index)
unq_index,inv = torch.unique(shared_index,return_inverse = True)
max_ts,idx = torch_scatter.scatter_max(shared_ts,inv,0)
#shared_ts = torch_scatter.scatter_mean(shared_ts,inv,0)
#shared_data = torch_scatter.scatter_mean(shared_data,inv,0)
shared_data = shared_data[idx]
shared_mem = shared_mem[idx]
shared_ts = shared_ts[idx]
mail_data = mail_data[idx]
mail_ts = mail_ts[idx]
shared_index = unq_index
#print('{} {} {}\n'.format(shared_index,shared_data,shared_ts))
# if filter is not None:
......@@ -151,10 +156,10 @@ class HistoricalCache:
# change /=change.max()
# change = 2*change - 1
# filter.update(shared_index,change)
self.local_historical_data[shared_index] = shared_data
self.local_historical_data[shared_index] = shared_mem
self.local_ts[shared_index] = shared_ts
self.last_shared_update_wait = None
return shared_index,shared_data,shared_ts
return shared_index,shared_mem,shared_ts,mail_data,mail_ts
......
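The slicing introduced above implies each packed row is laid out as [memory | memory_ts | mail | mail_ts]. A minimal pack/unpack sketch with made-up widths (`mem_dim` and the mail width are hypothetical) illustrates the layout:

```python
import torch

def pack_rows(mem, mem_ts, mail, mail_ts):
    # row layout implied above: [memory (mem_dim) | memory_ts (1) | mail | mail_ts (1)]
    return torch.cat([mem, mem_ts.unsqueeze(1), mail, mail_ts.unsqueeze(1)], dim=1)

def unpack_rows(shared_data, mem_dim):
    shared_mem = shared_data[:, :mem_dim]
    shared_ts  = shared_data[:, mem_dim]
    mail_data  = shared_data[:, mem_dim + 1:-1]
    mail_ts    = shared_data[:, -1]
    return shared_mem, shared_ts, mail_data, mail_ts

# tiny self-check with made-up sizes
mem = torch.randn(4, 8); mail = torch.randn(4, 10)
ts = torch.arange(4.0); mts = torch.arange(4.0) + 0.5
m, t, md, mt = unpack_rows(pack_rows(mem, ts, mail, mts), mem_dim=8)
assert torch.equal(m, mem) and torch.equal(md, mail)
```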
......@@ -68,25 +68,26 @@ class GeneralModel(torch.nn.Module):
self.train_param = train_param
if memory_param['type'] == 'node':
if memory_param['memory_update'] == 'gru':
if memory_param['async'] == False:
self.memory_updater = GRUMemeoryUpdater(memory_param, 2 * memory_param['dim_out'] + dim_edge, memory_param['dim_out'], memory_param['dim_time'], dim_node)
else:
updater = torch.nn.GRUCell
if memory_param['historical_fix'] == False:
self.memory_updater = AsyncMemeoryUpdater(memory_param, 2 * memory_param['dim_out'] + dim_edge, memory_param['dim_out'], memory_param['dim_time'], dim_node, updater=updater, mailbox=mailbox, mode = memory_param['mode'])
else:
self.memory_updater = HistoricalMemeoryUpdater(memory_param, 2 * memory_param['dim_out'] + dim_edge, memory_param['dim_out'], memory_param['dim_time'], dim_node,updater=updater,learnable=True,num_nodes=num_nodes)
#if memory_param['async'] == False:
# self.memory_updater = GRUMemeoryUpdater(memory_param, 2 * memory_param['dim_out'] + dim_edge, memory_param['dim_out'], memory_param['dim_time'], dim_node)
#else:
updater = torch.nn.GRUCell
# if memory_param['historical_fix'] == False:
self.memory_updater = AsyncMemeoryUpdater(memory_param, 2 * memory_param['dim_out'] + dim_edge, memory_param['dim_out'], memory_param['dim_time'], dim_node, updater=updater, mailbox=mailbox, mode = memory_param['mode'])
# else:
# self.memory_updater = HistoricalMemeoryUpdater(memory_param, 2 * memory_param['dim_out'] + dim_edge, memory_param['dim_out'], memory_param['dim_time'], dim_node,updater=updater,learnable=True,num_nodes=num_nodes)
elif memory_param['memory_update'] == 'rnn':
if memory_param['async'] == False:
self.memory_updater = RNNMemeoryUpdater(memory_param, 2 * memory_param['dim_out'] + dim_edge, memory_param['dim_out'], memory_param['dim_time'], dim_node)
else:
updater = torch.nn.RNNCell
if memory_param['historical_fix'] == False:
self.memory_updater = AsyncMemeoryUpdater(memory_param, 2 * memory_param['dim_out'] + dim_edge, memory_param['dim_out'], memory_param['dim_time'], dim_node, updater=updater, mailbox=mailbox, mode = memory_param['mode'])
else:
self.memory_updater = HistoricalMemeoryUpdater(memory_param, 2 * memory_param['dim_out'] + dim_edge, memory_param['dim_out'], memory_param['dim_time'], dim_node,updater=updater,learnable=True,num_nodes=num_nodes)
#if memory_param['async'] == False:
# self.memory_updater = RNNMemeoryUpdater(memory_param, 2 * memory_param['dim_out'] + dim_edge, memory_param['dim_out'], memory_param['dim_time'], dim_node)
#else:
updater = torch.nn.RNNCell
#if memory_param['historical_fix'] == False:
self.memory_updater = AsyncMemeoryUpdater(memory_param, 2 * memory_param['dim_out'] + dim_edge, memory_param['dim_out'], memory_param['dim_time'], dim_node, updater=updater, mailbox=mailbox, mode = memory_param['mode'])
# else:
# self.memory_updater = HistoricalMemeoryUpdater(memory_param, 2 * memory_param['dim_out'] + dim_edge, memory_param['dim_out'], memory_param['dim_time'], dim_node,updater=updater,learnable=True,num_nodes=num_nodes)
elif memory_param['memory_update'] == 'transformer':
self.memory_updater = TransformerMemoryUpdater(memory_param, 2 * memory_param['dim_out'] + dim_edge, memory_param['dim_out'], memory_param['dim_time'], train_param)
self.memory_updater = TransformerMemoryUpdater
self.memory_updater = AsyncMemeoryUpdater(memory_param, 2 * memory_param['dim_out'] + dim_edge, memory_param['dim_out'], memory_param['dim_time'], dim_node, updater=updater, mailbox=mailbox, mode = memory_param['mode'],train_param=train_param)
else:
raise NotImplementedError
self.dim_node_input = memory_param['dim_out']
......
......@@ -152,7 +152,7 @@ class SharedMailBox():
def get_update_mail(self,dist_indx_mapper,
src,dst,ts,edge_feats,
memory,embedding=None,use_src_emb=False,use_dst_emb=False,
remote_src = None, remote_dst = None,Reduce_score=None):
deliver_to='self',block = None,Reduce_score=None,):
if edge_feats is not None:
edge_feats = edge_feats.to(self.device).to(self.mailbox.dtype)
src = src.to(self.device)
......@@ -164,26 +164,21 @@ class SharedMailBox():
if embedding is not None:
emb_src = embedding[src]
emb_dst = embedding[dst]
if remote_src is None:
src_mail = torch.cat([emb_src if use_src_emb else mem_src, emb_dst if use_dst_emb else mem_dst], dim=1)
dst_mail = torch.cat([emb_dst if use_src_emb else mem_dst, emb_src if use_dst_emb else mem_src], dim=1)
if edge_feats is not None:
src_mail = torch.cat([src_mail, edge_feats], dim=1)
dst_mail = torch.cat([dst_mail, edge_feats], dim=1)
mail = torch.cat([src_mail, dst_mail], dim=0)
mail_ts = torch.cat((ts,ts),-1).to(self.device).to(self.mailbox_ts.dtype)
src_mail = torch.cat([emb_src if use_src_emb else mem_src, emb_dst if use_dst_emb else mem_dst], dim=1)
dst_mail = torch.cat([emb_dst if use_src_emb else mem_dst, emb_src if use_dst_emb else mem_src], dim=1)
if edge_feats is not None:
src_mail = torch.cat([src_mail, edge_feats], dim=1)
dst_mail = torch.cat([dst_mail, edge_feats], dim=1)
mail = torch.cat([src_mail, dst_mail], dim=0)
mail_ts = torch.cat((ts,ts),-1).to(self.device).to(self.mailbox_ts.dtype)
#print(mail_ts)
if Reduce_score is not None:
Reduce_score = torch.cat((Reduce_score,Reduce_score),-1).to(self.device)
else:
src_mail = torch.cat([emb_src if use_src_emb else mem_src, remote_dst], dim=1)
dst_mail = torch.cat([emb_dst if use_src_emb else mem_dst, remote_src], dim=1)
if edge_feats is not None:
src_mail = torch.cat([src_mail, edge_feats[:src_mail.shape[0]]], dim=1)
dst_mail = torch.cat([dst_mail, edge_feats[src_mail.shape[0]:]], dim=1)
mail = torch.cat([src_mail, dst_mail], dim=0)
mail_ts = ts.to(self.device).to(self.mailbox_ts.dtype)
#.reshape(-1, src_mail.shape[1])
if deliver_to == 'neighbor':
assert block is not None and Reduce_score is not None
mail = torch.cat([mail, mail[block.edges()[1].long()]], dim=0)
mail_ts = torch.cat([mail_ts, mail_ts[block.edges()[1].long()]], dim=0)
index = torch.cat([index,dist_indx_mapper[block.edges()[0].long()]],dim=0)
if Reduce_score is not None:
Reduce_score = torch.cat((Reduce_score,Reduce_score),-1).to(self.device)
if Reduce_score is None:
unq_index,inv = torch.unique(index,return_inverse = True)
max_ts,idx = torch_scatter.scatter_max(mail_ts,inv,0)
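A small illustration of the new `deliver_to='neighbor'` branch above: besides the two endpoints, each edge's mail is also addressed to the node on the other side of the sampled block's edges. The toy tensors below are made up; only the cat/indexing pattern mirrors the code.

```python
import torch

# toy block: edges given as (src, dst) index tensors, standing in for block.edges()
edge_src = torch.tensor([0, 2, 3])
edge_dst = torch.tensor([1, 1, 0])

mail = torch.randn(4, 16)                          # one mail row per endpoint node
mail_ts = torch.arange(4.0)
dist_indx_mapper = torch.tensor([10, 11, 12, 13])  # hypothetical local -> global id map
index = dist_indx_mapper.clone()                   # mail initially addressed to the endpoints

# 'neighbor' delivery: duplicate the dst-side mail of every edge and address it
# to the src-side node, mirroring mail[block.edges()[1]] / dist_indx_mapper[block.edges()[0]]
mail    = torch.cat([mail,    mail[edge_dst]],    dim=0)
mail_ts = torch.cat([mail_ts, mail_ts[edge_dst]], dim=0)
index   = torch.cat([index,   dist_indx_mapper[edge_src]], dim=0)
```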
......@@ -250,10 +245,12 @@ class SharedMailBox():
def sychronize_shared(self):
out=self.historical_cache.synchronize_shared_update()
if out is not None:
shared_index,shared_data,shared_ts = out
#print(shared_ts,self.node_memory_ts.accessor.data[self.shared_nodes_index[shared_index]])
self.node_memory.accessor.data[self.shared_nodes_index[shared_index]] = shared_data
self.node_memory_ts.accessor.data[self.shared_nodes_index[shared_index]] = shared_ts
shared_index,shared_data,shared_ts,mail,mail_ts = out
index = self.shared_nodes_index[shared_index]
self.node_memory.accessor.data[index] = shared_data
self.node_memory_ts.accessor.data[index] = shared_ts
self.mailbox.accessor.data[index, torch.max(self.next_mail_pos[index]-1,torch.tensor([0],device=mail.device))] = mail
self.mailbox_ts.accessor.data[index, torch.max(self.next_mail_pos[index]-1,torch.tensor([0],device=mail_ts.device))] = mail_ts
def update_shared(self):
ctx = DistributedContext.get_default_context()
if self.last_job is not None:
......@@ -300,7 +297,6 @@ class SharedMailBox():
shared_memory_ts = memory_ts[mask]
shared_mail = mail[mask]
shared_mail_ts = mail_ts[mask]
update_index = self.historical_cache.historical_check(shared_memory_ind,shared_memory,shared_memory_ts)
if mode == 'historical':
#print(shared_memory_ind)
update_index = self.historical_cache.historical_check(shared_memory_ind,shared_memory,shared_memory_ts)
......@@ -308,12 +304,14 @@ class SharedMailBox():
shared_memory_ind = shared_memory_ind[update_index]
shared_memory = shared_memory[update_index]
shared_memory_ts = shared_memory_ts[update_index]
shared_mail = shared_mail[update_index]
shared_mail_ts = shared_mail_ts[update_index]
#print(shared_memory_ind)
#mem = self.pack(memory=shared_memory,memory_ts=shared_memory_ts,mail=shared_mail,mail_ts=shared_mail_ts,index=shared_memory_ind,mode=mode)
#mem = self.pack(memory=shared_mail,memory_ts=shared_mail_ts,index=shared_memory_ind,mode=mode)
self.tot_shared_count += shared_memory_ind.shape[0]
mem = self.pack(memory=shared_memory,memory_ts=shared_memory_ts,index=shared_memory_ind,mode=mode)
mem = self.pack(memory=shared_memory,memory_ts=shared_memory_ts,mail=shared_mail,mail_ts = shared_mail_ts,index=shared_memory_ind,mode=mode)
broadcast_len = torch.empty([1],device = mem.device,dtype = torch.int)
broadcast_len[0] = shared_memory_ind.shape[0]
shared_len = [torch.empty([1],device = mem.device,dtype = torch.int) for _ in range(ctx.memory_group_size)]
......@@ -337,23 +335,23 @@ class SharedMailBox():
#id = shared_index.sort()
#print(mem[id],shared_index[id])
#,shared_memory,shared_memory_ts,
shared_memory,shared_memory_ts = self.unpack(mem)
#shared_memory,shared_memory_ts,shared_mail,shared_mail_ts = self.unpack(mem)
#shared_memory,shared_memory_ts = self.unpack(mem)
shared_memory,shared_memory_ts,shared_mail,shared_mail_ts = self.unpack(mem)
unq_index,inv = torch.unique(shared_index,return_inverse = True)
#print(inv.shape,Reduce_score.shape)
max_ts,idx = torch_scatter.scatter_max(shared_memory_ts,inv,0)
#min_ts,_ = torch_scatter.scatter_min(shared_mail_ts,inv,0)
shared_memory = shared_memory[idx]
shared_memory_ts = shared_memory_ts[idx]
#shared_mail_ts = shared_mail_ts[idx]
#shared_mail = shared_mail[idx]
shared_mail_ts = shared_mail_ts[idx]
shared_mail = shared_mail[idx]
#shared_mail_ts = torch_scatter.scatter_mean(shared_mail_ts,inv,0)
#shared_mail = torch_scatter.scatter_mean(shared_mail,inv,0)
shared_index = unq_index
self.set_memory_local(self.shared_nodes_index[shared_index],shared_memory,shared_memory_ts)
self.historical_cache.local_historical_data[shared_index] = shared_memory
self.historical_cache.local_ts[shared_index] = shared_memory_ts
#self.set_mailbox_local(self.shared_nodes_index[shared_index],shared_mail,shared_mail_ts)
self.set_mailbox_local(self.shared_nodes_index[shared_index],shared_mail,shared_mail_ts)
else:
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
......
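The deduplication above keeps, for each shared node id, only the entry carrying the newest timestamp. A self-contained toy example of that pattern (same `torch.unique` + `torch_scatter.scatter_max` calls as in the code; the data is made up):

```python
import torch
import torch_scatter

shared_index = torch.tensor([5, 7, 5, 9, 7])           # duplicate shared node ids
shared_ts    = torch.tensor([1.0, 4.0, 3.0, 2.0, 1.0])
shared_mem   = torch.randn(5, 8)

# group duplicates, then take the row with the largest timestamp in each group
unq_index, inv = torch.unique(shared_index, return_inverse=True)
max_ts, idx = torch_scatter.scatter_max(shared_ts, inv, dim=0)

latest_mem = shared_mem[idx]      # one row per unique id, from its newest update
latest_ts  = shared_ts[idx]       # equals max_ts
# unq_index = [5, 7, 9]; latest_ts = [3., 4., 2.]
```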
......@@ -300,6 +300,11 @@ def load_from_speed(data,seed,top,sampler_graph_add_rev,device=torch.device('cud
fnode_share = '../../SPEED/partition/divided_nodes_ldg/{}/{}/{}_{}parts_top{}/outputshared.txt'.format(data,seed,data,ctx.memory_group_size,top)
reorder = '../../SPEED/partition/divided_nodes_ldg/{}/reorder.txt'.format(data)
edge_i = '../../SPEED/partition/divided_nodes_ldg/{}/{}/{}_{}parts_top{}/edge_output{}.txt'.format(data,seed,data,ctx.memory_group_size,top,ctx.memory_group_rank)
elif partition == 'dis_tgl':
fnode_i = '../../SPEED/partition/divided_nodes_seed_dis/{}/{}/{}_{}parts_top{}/output{}.txt'.format(data,seed,data,ctx.memory_group_size,top,ctx.memory_group_rank)
fnode_share = '../../SPEED/partition/divided_nodes_seed_dis/{}/{}/{}_{}parts_top{}/outputshared.txt'.format(data,seed,data,ctx.memory_group_size,top)
reorder = '../../SPEED/partition/divided_nodes_seed_dis/{}/reorder.txt'.format(data)
edge_i = '../../SPEED/partition/divided_nodes_seed_dis/{}/{}/{}_{}parts_top{}/edge_output{}.txt'.format(data,seed,data,ctx.memory_group_size,top,ctx.memory_group_rank)
elif partition == 'random':
df = load_graph(data)
src = torch.from_numpy(np.array(df.src.values)).long()
......@@ -348,7 +353,10 @@ def load_from_speed(data,seed,top,sampler_graph_add_rev,device=torch.device('cud
node_i = torch.tensor(node_list).reshape(-1).to(torch.long)
edge_i = torch.tensor(eid_list).reshape(-1).to(torch.long)
#reid = torch.arange(len(reid))#torch.tensor(reid).reshape(-1)
shared_node = torch.tensor(shared_node_list).reshape(-1).to(torch.long)
if partition == 'dis_tgl':
shared_node = torch.tensor([],dtype=torch.long)
else:
shared_node = torch.tensor(shared_node_list).reshape(-1).to(torch.long)
#print(reid)
return load_from_shared_node_partition(data,node_i,shared_node,sample_add_rev=sampler_graph_add_rev,edge_i=edge_i,reid=None,device=device,feature_device=feature_device)
......