Commit e35ddb0f by zlj

Fix the training sampler policy selection and the numeric topk comparison, merge shared historical memory by latest timestamp, record per-batch local/remote communication statistics, and update the distributed experiment scripts and partition paths.

parent 711af292
......@@ -166,6 +166,7 @@ cython_debug/
#.idea/
*.pt
*.out
/*.out
/a.out
......
......@@ -244,6 +244,11 @@ void ParallelSampler :: neighbor_sample_from_nodes_with_before_layer(
double s_start_time = omp_get_wtime();
if ((policy == "recent") || (end_index <= fanout)&& policy.substr(0,8) != "boundery" ){
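// take at most 'fanout' of the most recent neighbors; cnt only counts them, the actual window is chosen via start_index below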
int cnt = 0;
for(int cid = end_index-1;cid>=0;cid--){
cnt++;
if(cnt>fanout)break;
}
int start_index = max(0, end_index-fanout);
tgb_i[tid].src_index.insert(tgb_i[tid].src_index.end(), end_index-start_index, i);
tgb_i[tid].sample_nodes.insert(tgb_i[tid].sample_nodes.end(), tnb.neighbors[node].begin()+start_index, tnb.neighbors[node].begin()+end_index);
......
#!/bin/bash
# define the parameter arrays
addr="192.168.1.107"
partition_params=("ours" "metis" "ldg" "random")
partitions="16"
nnodes="4"
#("ours" "metis" "ldg" "random")
partitions="4"
nnodes="1"
node_rank="0"
probability_params=("1" "0.5" "0.1" "0.05" "0.01" "0")
sample_type_params=("recent" "boundery_recent_decay" "boundery_recent_uniform")
memory_type=("all_update" "p2p" "all_reduce" "historical" "local")
shared_memory_ssim=("1" "2" "3" "10" "100")
data_param=("WIKI" "REDDIT" "LASTFM" "WikiTalk" "StackOverflow")
#sample_type_params=("recent" "boundery_recent_decay" "boundery_recent_uniform")
sample_type_params=("recent")
#memory_type=("all_update" "p2p" "all_reduce" "historical" "local")
#memory_type=("all_update" "local" "historical")
memory_type=("local" "all_update")
shared_memory_ssim=("0" "0.1" "0.2" "0.3" "0.4" )
#data_param=("WIKI" "REDDIT" "LASTFM" "WikiTalk")
#data_param=("WIKI" "REDDIT" "LASTFM" "WikiTalk" "StackOverflow")
data_param=("REDDIT" "WikiTalk")
# create the output directory
mkdir -p all
# iterate over the parameter arrays and launch the training runs
for data in "${data_param[@]}"; do
mkdir all/"$data"
torchrun --nnodes "$nnodes" --node_rank 0 --nproc-per-node 1 --master-addr 192.168.1.105 --master-port 9445 train_boundery.py --dataname "$data" --mode TGN_large --partition ours --memory_type local --sample_type recent --topk 0 > all/"$data"/1.out &
mkdir all/"$data"/comm
#torchrun --nnodes "$nnodes" --node_rank 0 --nproc-per-node 1 --master-addr "$addr" --master-port 9445 train_boundery.py --dataname "$data" --mode TGN_large --partition ours --memory_type local --sample_type recent --topk 0 > all/"$data"/1.out &
wait
for partition in "${partition_params[@]}"; do
for sample in "${sample_type_params[@]}"; do
......@@ -26,15 +33,20 @@ for data in "${data_param[@]}"; do
if [ "$mem" = "historical" ]; then
for ssim in "${shared_memory_ssim[@]}"; do
if [ "$partition" = "ours" ]; then
torchrun --nnodes "$nnodes" --node_rank "$node_rank" --nproc-per-node "$partitions" --master-addr 192.168.1.105 --master-port 9445 train_boundery.py --dataname "$data" --mode TGN_large --partition "$partition" --topk 0.01 --sample_type "$sample" --memory_type "$mem" --shared_memory_ssim "$ssim" > all/"$data"/"$partitions"-ours_shared-0.01-"$mem"-"$ssim"-"$sample".out &
torchrun --nnodes "$nnodes" --node_rank "$node_rank" --nproc-per-node "$partitions" --master-addr "$addr" --master-port 9445 train_boundery.py --dataname "$data" --mode TGN_large --partition "$partition" --topk 0.01 --sample_type "$sample" --memory_type "$mem" --shared_memory_ssim "$ssim" > all/"$data"/"$partitions"-ours_shared-0.01-"$mem"-"$ssim"-"$sample".out &
wait
fi
done
elif [ "$mem" = "all_reduce" ]; then
if [ "$partition" = "ours" ]; then
torchrun --nnodes "$nnodes" --node_rank "$node_rank" --nproc-per-node "$partitions" --master-addr "$addr" --master-port 9445 train_boundery.py --dataname "$data" --mode TGN_large --partition "$partition" --topk 0.01 --sample_type "$sample" --memory_type "$mem" > all/"$data"/"$partitions"-ours_shared-0.01-"$mem"-"$sample".out &
wait
fi
else
torchrun --nnodes "$nnodes" --node_rank "$node_rank" --nproc-per-node "$partitions" --master-addr 192.168.1.105 --master-port 9445 train_boundery.py --dataname "$data" --mode TGN_large --partition "$partition" --topk 0 --sample_type "$sample" --memory_type "$mem" > all/"$data"/"$partitions"-"$partition"-0-"$mem"-"$sample".out &
torchrun --nnodes "$nnodes" --node_rank "$node_rank" --nproc-per-node "$partitions" --master-addr "$addr" --master-port 9445 train_boundery.py --dataname "$data" --mode TGN_large --partition "$partition" --topk 0 --sample_type "$sample" --memory_type "$mem" > all/"$data"/"$partitions"-"$partition"-0-"$mem"-"$sample".out &
wait
if [ "$partition" = "ours" ]; then
torchrun --nnodes "$nnodes" --node_rank "$node_rank" --nproc-per-node "$partitions" --master-addr 192.168.1.105 --master-port 9445 train_boundery.py --dataname "$data" --mode TGN_large --partition "$partition" --topk 0.01 --sample_type "$sample" --memory_type "$mem" > all/"$data"/"$partitions"-ours_shared-0.01-"$mem"-"$sample".out &
torchrun --nnodes "$nnodes" --node_rank "$node_rank" --nproc-per-node "$partitions" --master-addr "$addr" --master-port 9445 train_boundery.py --dataname "$data" --mode TGN_large --partition "$partition" --topk 0.01 --sample_type "$sample" --memory_type "$mem" > all/"$data"/"$partitions"-ours_shared-0.01-"$mem"-"$sample".out &
wait
fi
fi
......@@ -45,15 +57,20 @@ for data in "${data_param[@]}"; do
if [ "$mem" = "historical" ]; then
for ssim in "${shared_memory_ssim[@]}"; do
if [ "$partition" = "ours" ]; then
torchrun --nnodes "$nnodes" --node_rank "$node_rank" --nproc-per-node "$partitions" --master-addr 192.168.1.105 --master-port 9445 train_boundery.py --dataname "$data" --mode TGN_large --partition "$partition" --topk 0.01 --sample_type "$sample" --probability "$pro" --memory_type "$mem" --shared_memory_ssim "$ssim" > all/"$data"/"$partitions"-ours_shared-0.01"$mem"-"$ssim"-"$sample"-"$pro".out &
torchrun --nnodes "$nnodes" --node_rank "$node_rank" --nproc-per-node "$partitions" --master-addr "$addr" --master-port 9445 train_boundery.py --dataname "$data" --mode TGN_large --partition "$partition" --topk 0.01 --sample_type "$sample" --probability "$pro" --memory_type "$mem" --shared_memory_ssim "$ssim" > all/"$data"/"$partitions"-ours_shared-0.01"$mem"-"$ssim"-"$sample"-"$pro".out &
wait
fi
done
elif [ "$mem" = "all_reduce" ]; then
if [ "$partition" = "ours" ]; then
torchrun --nnodes "$nnodes" --node_rank "$node_rank" --nproc-per-node "$partitions" --master-addr "$addr" --master-port 9445 train_boundery.py --dataname "$data" --mode TGN_large --partition "$partition" --topk 0.01 --sample_type "$sample" --probability "$pro" --memory_type "$mem" > all/"$data"/"$partitions"-ours_shared-0.01-"$mem"-"$sample"-"$pro".out&
wait
fi
else
torchrun --nnodes "$nnodes" --node_rank "$node_rank" --nproc-per-node "$partitions" --master-addr 192.168.1.105 --master-port 9445 train_boundery.py --dataname "$data" --mode TGN_large --partition "$partition" --topk 0 --sample_type "$sample" --probability "$pro" --memory_type "$mem" > all/"$data"/"$partitions"-"$partition"-0-"$mem"-"$sample"-"$pro".out &
torchrun --nnodes "$nnodes" --node_rank "$node_rank" --nproc-per-node "$partitions" --master-addr "$addr" --master-port 9445 train_boundery.py --dataname "$data" --mode TGN_large --partition "$partition" --topk 0 --sample_type "$sample" --probability "$pro" --memory_type "$mem" > all/"$data"/"$partitions"-"$partition"-0-"$mem"-"$sample"-"$pro".out &
wait
if [ "$partition" = "ours" ]; then
torchrun --nnodes "$nnodes" --node_rank "$node_rank" --nproc-per-node "$partitions" --master-addr 192.168.1.105 --master-port 9445 train_boundery.py --dataname "$data" --mode TGN_large --partition "$partition" --topk 0.01 --sample_type "$sample" --probability "$pro" --memory_type "$mem" > all/"$data"/"$partitions"-ours_shared-0.01-"$mem"-"$sample"-"$pro".out &
torchrun --nnodes "$nnodes" --node_rank "$node_rank" --nproc-per-node "$partitions" --master-addr "$addr" --master-port 9445 train_boundery.py --dataname "$data" --mode TGN_large --partition "$partition" --topk 0.01 --sample_type "$sample" --probability "$pro" --memory_type "$mem" > all/"$data"/"$partitions"-ours_shared-0.01-"$mem"-"$sample"-"$pro".out &
wait
fi
fi
......
......@@ -211,7 +211,7 @@ def main():
num_layers = sample_param['layer'] if 'layer' in sample_param else 1
fanout = sample_param['neighbor'] if 'neighbor' in sample_param else [10]
policy = sample_param['strategy'] if 'strategy' in sample_param else 'recent'
policy = args.sample_type#'boundery_recent_decay'
policy_train = args.sample_type#'boundery_recent_decay'
if memory_param['type'] != 'none':
mailbox = SharedMailBox(graph.ids.shape[0], memory_param, dim_edge_feat = graph.efeat.shape[1] if graph.efeat is not None else 0,
......@@ -219,7 +219,7 @@ def main():
else:
mailbox = None
sampler = NeighborSampler(num_nodes=graph.num_nodes, num_layers=num_layers, fanout=fanout,graph_data=sample_graph, workers=10,policy = policy, graph_name = "train",local_part=dist.get_rank(),edge_part=DistIndex(graph.eids_mapper).part,node_part=DistIndex(graph.nids_mapper).part,probability=args.probability)
sampler = NeighborSampler(num_nodes=graph.num_nodes, num_layers=num_layers, fanout=fanout,graph_data=sample_graph, workers=1,policy = policy_train, graph_name = "train",local_part=dist.get_rank(),edge_part=DistIndex(graph.eids_mapper).part,node_part=DistIndex(graph.nids_mapper).part,probability=args.probability)
eval_sampler = NeighborSampler(num_nodes=graph.num_nodes, num_layers=num_layers, fanout=fanout,graph_data=eval_sample_graph, workers=10,policy = 'recent', graph_name = "eval",local_part=dist.get_rank(),edge_part=DistIndex(graph.eids_mapper).part,node_part=DistIndex(graph.nids_mapper).part,probability=args.probability)
train_data = torch.masked_select(graph.edge_index,train_mask.to(graph.edge_index.device)).reshape(2,-1)
......@@ -246,7 +246,9 @@ def main():
print('ts {} {} {} {}'.format(train_data.ts,eval_train_data.ts,test_data.ts,val_data.ts))
neg_samples = args.neg_samples
mask = DistIndex(graph.nids_mapper[graph.edge_index[1,:]].to('cpu')).part == dist.get_rank()
if args.local_neg_sample:
print('dst len {} origin len {}'.format(graph.edge_index[1,mask].unique().shape[0],full_dst.unique().shape[0]))
train_neg_sampler = LocalNegativeSampling('triplet',amount = neg_samples,dst_node_list = graph.edge_index[1,mask].unique())
else:
train_neg_sampler = LocalNegativeSampling('triplet',amount = neg_samples,dst_node_list = full_dst.unique())
......@@ -266,7 +268,7 @@ def main():
is_pipeline=True,
use_local_feature = False,
device = torch.device('cuda:{}'.format(local_rank)),
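# topk comes from the CLI (possibly as a string), so compare its float value when choosing the loader's probability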
probability=1 if args.topk == 0 else 0
probability=1 if float(args.topk) == 0 else 0
)
eval_trainloader = DistributedDataLoader(graph,eval_train_data,sampler = eval_sampler,
......@@ -440,8 +442,34 @@ def main():
model.module.memory_updater.last_updated_ts = None
t0 = time.time()
t_s = tt.start()
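# accumulators for per-epoch communication statistics (local/remote node and edge accesses per batch)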
sum_local_comm = 0
sum_remote_comm = 0
sum_local_edge_comm = 0
sum_remote_edge_comm = 0
local_access = []
remote_access = []
local_comm = []
remote_comm = []
local_edge_access = []
remote_edge_access = []
local_edge_comm = []
remote_edge_comm = []
b_cnt = 0
for roots,mfgs,metadata in trainloader:
#print('rank is {} batch max ts is {} batch min ts is {}'.format(dist.get_rank(),roots.ts.min(),roots.ts.max()))
b_cnt = b_cnt + 1
local_access.append(trainloader.local_node)
remote_access.append(trainloader.remote_node)
local_edge_access.append(trainloader.local_edge)
remote_edge_access.append(trainloader.remote_edge)
local_comm.append((DistIndex(mfgs[0][0].srcdata['ID']).part == dist.get_rank()).sum().item())
remote_comm.append((DistIndex(mfgs[0][0].srcdata['ID']).part != dist.get_rank()).sum().item())
local_edge_comm.append((DistIndex(mfgs[0][0].edata['ID']).part == dist.get_rank()).sum().item())
remote_edge_comm.append((DistIndex(mfgs[0][0].edata['ID']).part != dist.get_rank()).sum().item())
sum_local_comm +=local_comm[b_cnt-1]
sum_remote_comm +=remote_comm[b_cnt-1]
sum_local_edge_comm +=local_edge_comm[b_cnt-1]
sum_remote_edge_comm +=remote_edge_comm[b_cnt-1]
tt.pre_input += tt.elapsed(t_s)
t_prep_s = time.time()
t1 = time.time()
......@@ -547,7 +575,10 @@ def main():
torch.distributed.all_reduce(tot_comm_count,group=ctx.gloo_group)
torch.distributed.all_reduce(tot_shared_count,group=ctx.gloo_group)
print('local node number {} remote node number {} local edge {} remote edge{}\n'.format(local_node,remote_node,local_edge,remote_edge))
print(' comm local node number {} remote node number {} local edge {} remote edge{}\n'.format(sum_local_comm,sum_remote_comm,sum_local_edge_comm,sum_remote_edge_comm))
print('memory comm {} shared comm {}\n'.format(tot_comm_count,tot_shared_count))
if(e==0):
torch.save((local_access,remote_access,local_edge_access,remote_edge_access,local_comm,remote_comm,local_edge_comm,remote_edge_comm),'all/{}/comm/comm_{}_{}_{}_{}_{}_{}_{}_{}.pt'.format(args.dataname,args.partition,args.topk,dist.get_world_size(),dist.get_rank(),args.sample_type,args.probability,args.memory_type,args.shared_memory_ssim))
tt.print()
ap = 0
auc = 0
......@@ -586,7 +617,7 @@ def main():
torch.save(model.module.state_dict(), get_checkpoint_path(e))
torch.save(val_list,'all/{}/val_{}_{}_{}_{}_{}_{}_{}_{}.pt'.format(args.dataname,args.partition,args.topk,dist.get_world_size(),dist.get_rank(),args.sample_type,args.probability,args.memory_type,args.shared_memory_ssim))
torch.save(loss_list,'all/{}/loss_{}_{}_{}_{}_{}_{}_{}_{}.pt'.format(args.dataname,args.partition,args.topk,dist.get_world_size(),dist.get_rank(),args.sample_type,args.probability,args.memory_type,args.shared_memory_ssim))
print(avg_time)
if not early_stop:
dist.barrier()
......
......@@ -80,7 +80,7 @@ class HistoricalCache:
self.local_ts.zero_()
self.loss_count.zero_()
def ssim(self,x,y,type = 'F'):
def ssim(self,x,y,type = 'cos'):
if type == 'cos':
return 1-torch.nn.functional.cosine_similarity(x,y)
......@@ -90,6 +90,7 @@ class HistoricalCache:
if self.time_threshold is not None:
return (self.ssim(new_data,self.local_historical_data[index]) > self.threshold) | (ts - self.local_ts[index] > self.time_threshold) | (self.loss_count[index] > self.times_threshold)
else:
#print('{} {} {} {} \n'.format(index,self.ssim(new_data,self.local_historical_data[index]),new_data,self.local_historical_data[index]))
#print(new_data,self.local_historical_data[index])
#print(self.ssim(new_data,self.local_historical_data[index]) < self.threshold, (self.loss_count[index] > self.times_threshold))
return (self.ssim(new_data,self.local_historical_data[index]) > self.threshold) | (self.loss_count[index] > self.times_threshold)
......@@ -136,12 +137,13 @@ class HistoricalCache:
shared_data = shared_data[:,:-1]
#print(shared_index)
unq_index,inv = torch.unique(shared_index,return_inverse = True)
#max_ts,idx = torch_scatter.scatter_mean(shared_ts,inv,0)
shared_ts = torch_scatter.scatter_mean(shared_ts,inv,0)
shared_data = torch_scatter.scatter_mean(shared_data,inv,0)
#shared_data = shared_data[idx]
#shared_ts = shared_ts[idx]
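# merge duplicate shared indices by keeping the update with the latest timestamp (scatter_max) instead of averaging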
max_ts,idx = torch_scatter.scatter_max(shared_ts,inv,0)
#shared_ts = torch_scatter.scatter_mean(shared_ts,inv,0)
#shared_data = torch_scatter.scatter_mean(shared_data,inv,0)
shared_data = shared_data[idx]
shared_ts = shared_ts[idx]
shared_index = unq_index
#print('{} {} {}\n'.format(shared_index,shared_data,shared_ts))
# if filter is not None:
# change = shared_data - self.local_historical_data[shared_index]
# if not (change.max().item() == 0):
......
......@@ -292,7 +292,6 @@ class TransfomerAttentionLayer(torch.nn.Module):
V = torch.reshape(V*att[:, :, None], (V.shape[0], -1))
b.edata['v'] = V
b.update_all(dgl.function.copy_e('v', 'm'), dgl.function.sum('m', 'h'))
torch.set_printoptions(threshold=100000)
#print('dst {}'.format(b.dstdata['h']))
#b.srcdata['v'] = torch.cat([torch.zeros((b.num_dst_nodes(), V.shape[1]), device=torch.device('cuda:0')), V], dim=0)
#b.update_all(dgl.function.copy_u('v', 'm'), dgl.function.sum('m', 'h'))
......
......@@ -320,10 +320,10 @@ class AsyncMemeoryUpdater(torch.nn.Module):
nxt_fetch_func()
def p2p_func(self,index,memory,memory_ts,mail,mail_ts,nxt_fetch_func):
self.mailbox.handle_last_async()
submit_to_queue = False
if nxt_fetch_func is not None:
nxt_fetch_func()
submit_to_queue = True
submit_to_queue = False
self.mailbox.set_memory_all_reduce(index,memory,memory_ts,mail,mail_ts,reduce_Op = 'max', async_op = True,filter=None,mode=None,set_remote=True,submit = submit_to_queue)
def all_reduce_func(self,index,memory,memory_ts,mail,mail_ts,nxt_fetch_func):
self.mailbox.set_memory_all_reduce(index,memory,memory_ts,mail,mail_ts,reduce_Op = 'max', async_op = False,filter=None,mode='all_reduce',set_remote=False)
......@@ -331,10 +331,10 @@ class AsyncMemeoryUpdater(torch.nn.Module):
nxt_fetch_func()
def historical_func(self,index,memory,memory_ts,mail,mail_ts,nxt_fetch_func):
self.mailbox.sychronize_shared()
submit_to_queue = False
if nxt_fetch_func is not None:
nxt_fetch_func()
submit_to_queue = True
submit_to_queue = False
self.mailbox.set_memory_all_reduce(index,memory,memory_ts,mail,mail_ts,reduce_Op = 'max', async_op = False,filter=None,mode='historical',set_remote=False,submit = submit_to_queue)
def local_func(self,index,memory,memory_ts,mail,mail_ts,nxt_fetch_func):
if nxt_fetch_func is not None:
......
......@@ -30,6 +30,11 @@ src_index: list[tensor,tensor, tensor...]
delta_ts: list[tensor,tensor, tensor...]
metadata
"""
class CountComm:
    origin_remote = 0
    origin_local = 0
    def __init__(self):
        pass
total_sample_core_time = 0
total_fetch_prepare_time = 0
total_comm_time = 0
......@@ -213,7 +218,7 @@ def to_block(graph,data, sample_out,device = torch.device('cuda'),unique = True)
dist_eid = eid_mapper[eid_tensor].to(device)
dist_eid,eid_inv = dist_eid.unique(return_inverse=True)
src_node = dst.to(graph.nids_mapper.device)
#print(src_node)
src_ts = None
if metadata is None:
root_node = data.nodes.to(graph.nids_mapper.device)
......@@ -232,6 +237,11 @@ def to_block(graph,data, sample_out,device = torch.device('cuda'),unique = True)
metadata[k] = metadata[k].to(device)
nid_tensor = torch.cat([root_node,src_node],dim = 0)
dist_nid = nid_mapper[nid_tensor].to(device)
#print(CountComm.origin_local,CountComm.origin_remote)
#for p in range(dist.get_world_size()):
# print((DistIndex(dist_nid).part == p).sum().item())
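# count how many sampled node IDs are owned locally vs. remotely by this rank (before deduplication)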
CountComm.origin_local = (DistIndex(dist_nid).part == dist.get_rank()).sum().item()
CountComm.origin_remote =(DistIndex(dist_nid).part != dist.get_rank()).sum().item()
dist_nid,nid_inv = dist_nid.unique(return_inverse=True)
#print('nid_tensor {} \n nid {}\n'.format(nid_tensor,dist_nid))
......
class count_comm:
    origin_remote = 0
    origin_local = 0
    def __init__(self):
        pass
\ No newline at end of file
......@@ -99,7 +99,7 @@ class DistributedDataLoader:
local_embedding = False,
cache_mask = None,
use_local_feature = True,
probability = 0,
probability = 1,
**kwargs
):
self.use_local_feature = use_local_feature
......@@ -149,6 +149,7 @@ class DistributedDataLoader:
self.remote_root = 0
self.local_root = 0
self.probability = probability
print('pro {}\n'.format(self.probability))
def __iter__(self):
if self.chunk_size is None:
......@@ -215,19 +216,20 @@ class DistributedDataLoader:
# self.current_pos += self.batch_size
#else:
if self.batch_pos_l[self.recv_idxs] == -1:
if self.batch_pos_l[self.submitted] == -1:
next_data = self.input_dataset[torch.tensor([],device=self.device,dtype= torch.long)]
else:
next_data = self.input_dataset[self.batch_pos_l[self.recv_idxs]:self.batch_pos_r[self.recv_idxs] +1]
next_data = self.input_dataset[self.batch_pos_l[self.submitted]:self.batch_pos_r[self.submitted] +1]
if self.mode=='train' and self.probability < 1:
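# keep all edges with both endpoints local; keep cross-partition edges with probability self.probability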
mask = ((DistIndex(self.graph.nids_mapper[next_data.edges[0,:].to('cpu')]).part == dist.get_rank())&(DistIndex(self.graph.nids_mapper[next_data.edges[1,:].to('cpu')]).part == dist.get_rank()))
if self.probability > 0:
mask[~mask] = (torch.rand(int((~mask).sum())) < self.probability)
next_data = next_data[mask.to(next_data.device)]
self.submitted = self.submitted + 1
return next_data
def submit(self):
global executor
if(self.recv_idxs+1<=self.expected_idx):
if(self.submitted < self.expected_idx):
data = self._next_data()
with torch.cuda.stream(stream):
fut = executor.submit(
......@@ -306,9 +308,9 @@ class DistributedDataLoader:
edge_feat = get_edge_feature_by_dist(self.graph,dist_eid,is_local,out_device=self.device)
node_feat,mem = get_node_feature_by_dist(self.graph,self.mailbox,dist_nid, is_local,out_device=self.device)
prepare_input(node_feat,edge_feat,mem,mfgs,dist_nid,dist_eid)
if(self.mailbox is not None and self.mailbox.historical_cache is not None):
id = batch_data[1][0][0].srcdata['ID']
mask = DistIndex(id).is_shared
#if(self.mailbox is not None and self.mailbox.historical_cache is not None):
# id = batch_data[1][0][0].srcdata['ID']
# mask = DistIndex(id).is_shared
# a = DistIndex(id[mask]).loc
# b = self.mailbox.historical_cache.local_historical_data[a]
#batch_data[1][0][0].srcdata['mem'][mask] = self.mailbox.historical_cache.local_historical_data[DistIndex(id[mask]).loc]
......@@ -331,9 +333,9 @@ class DistributedDataLoader:
edge_feat = get_edge_feature_by_dist(self.graph,dist_eid,is_local,out_device=self.device)
node_feat,mem = get_node_feature_by_dist(self.graph,self.mailbox,dist_nid, is_local,out_device=self.device)
prepare_input(node_feat,edge_feat,mem,batch_data[1],dist_nid,dist_eid)
if(self.mailbox is not None and self.mailbox.historical_cache is not None):
id = batch_data[1][0][0].srcdata['ID']
mask = DistIndex(id).is_shared
#if(self.mailbox is not None and self.mailbox.historical_cache is not None):
# id = batch_data[1][0][0].srcdata['ID']
# mask = DistIndex(id).is_shared
#batch_data[1][0][0].srcdata['mem'][mask] = self.mailbox.historical_cache.local_historical_data[DistIndex(id).loc[mask]]
self.recv_idxs += 1
else:
......
......@@ -300,6 +300,7 @@ class SharedMailBox():
shared_memory_ts = memory_ts[mask]
shared_mail = mail[mask]
shared_mail_ts = mail_ts[mask]
update_index = self.historical_cache.historical_check(shared_memory_ind,shared_memory,shared_memory_ts)
if mode == 'historical':
#print(shared_memory_ind)
update_index = self.historical_cache.historical_check(shared_memory_ind,shared_memory,shared_memory_ts)
......@@ -350,6 +351,8 @@ class SharedMailBox():
#shared_mail = torch_scatter.scatter_mean(shared_mail,inv,0)
shared_index = unq_index
self.set_memory_local(self.shared_nodes_index[shared_index],shared_memory,shared_memory_ts)
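# mirror the merged shared memory and timestamps into the local historical cache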
self.historical_cache.local_historical_data[shared_index] = shared_memory
self.historical_cache.local_ts[shared_index] = shared_memory_ts
#self.set_mailbox_local(self.shared_nodes_index[shared_index],shared_mail,shared_mail_ts)
else:
start = torch.cuda.Event(enable_timing=True)
......
......@@ -48,6 +48,7 @@ def load_feat(d, node_num = 0, edge_num = 0, rand_de=0, rand_dn=0):
def load_graph(d):
df = pd.read_csv('/mnt/nfs/fzz/TGL-DATA/{}/edges.csv'.format(d))
return df
def build_shared_index(ids,shared_ids,to_shared_idexs):
......@@ -130,6 +131,7 @@ def load_from_shared_node_partition(data,node_i,shared_node,sample_add_rev = Tru
df = load_graph(data)
src = torch.from_numpy(np.array(df.src.values)).long()
dst = torch.from_numpy(np.array(df.dst.values)).long()
print('tot edge {} circ edge {} same edge {}\n'.format(src.shape[0],torch.stack((src,dst)).unique(dim = 1).shape[1],(src==dst).sum().item()))
ts = torch.from_numpy(np.array(df.time.values)).long()
train_mask = (torch.from_numpy(np.array(df.ext_roll.values)) == 0)
val_mask = (torch.from_numpy(np.array(df.ext_roll.values)) == 1)
......@@ -260,10 +262,10 @@ def get_eval_batch(eid,nid_mapper,eid_mapper,batch_size):
eid_for_test = mask[eid.to('cpu')].to('cuda')
local_belong_batch = belong_batch[eid_for_test]
pos = torch.arange(eid_for_test.sum().item(),device=torch.device('cuda'),dtype=torch.long)
posl,_ = torch_scatter.scatter_min(pos,local_belong_batch)
posr,_ = torch_scatter.scatter_max(pos,local_belong_batch)
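# dim_size=max_batch+1 ensures every batch gets a slot even if this rank holds no edges for it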
posl,_ = torch_scatter.scatter_min(pos,local_belong_batch,dim_size=max_batch+1)
posr,_ = torch_scatter.scatter_max(pos,local_belong_batch,dim_size=max_batch+1)
global_edge[i] = posr-posl+1
print(global_edge)
print('average: {}\n'.format((global_edge.max(dim = 0)[0]/global_edge.min(dim=0)[0]).float().mean()))
print('max average: {}\n'.format((global_edge.max(dim = 0)[0]).float().mean()))
print('min average: {}\n'.format((global_edge.min(dim = 0)[0]).float().mean()))
......@@ -284,24 +286,35 @@ def load_from_speed(data,seed,top,sampler_graph_add_rev,device=torch.device('cud
return load_from_shared_node_partition(data,None,None,sample_add_rev=sampler_graph_add_rev,device=device,feature_device=feature_device)
else:
if partition == 'ours':
fnode_i = '../../../SPEED/partition/divided_nodes_seed_t2/{}/{}/{}_{}parts_top{}/output{}.txt'.format(data,seed,data,ctx.memory_group_size,top,ctx.memory_group_rank)
fnode_share = '../../../SPEED/partition/divided_nodes_seed_t2/{}/{}/{}_{}parts_top{}/outputshared.txt'.format(data,seed,data,ctx.memory_group_size,top)
reorder = '../../../SPEED/partition/divided_nodes_seed_t2/{}/reorder.txt'.format(data)
edge_i = '../../../SPEED/partition/divided_nodes_seed_t2/{}/{}/{}_{}parts_top{}/edge_output{}.txt'.format(data,seed,data,ctx.memory_group_size,top,ctx.memory_group_rank)
fnode_i = '../../SPEED/partition/divided_nodes_seed_t2/{}/{}/{}_{}parts_top{}/output{}.txt'.format(data,seed,data,ctx.memory_group_size,top,ctx.memory_group_rank)
fnode_share = '../../SPEED/partition/divided_nodes_seed_t2/{}/{}/{}_{}parts_top{}/outputshared.txt'.format(data,seed,data,ctx.memory_group_size,top)
reorder = '../../SPEED/partition/divided_nodes_seed_t2/{}/reorder.txt'.format(data)
edge_i = '../../SPEED/partition/divided_nodes_seed_t2/{}/{}/{}_{}parts_top{}/edge_output{}.txt'.format(data,seed,data,ctx.memory_group_size,top,ctx.memory_group_rank)
elif partition == 'metis':
fnode_i = '../../../SPEED/partition/divided_nodes_metis_balance/{}/{}/{}_{}parts_top{}/output{}.txt'.format(data,seed,data,ctx.memory_group_size,top,ctx.memory_group_rank)
fnode_share = '../../../SPEED/partition/divided_nodes_metis_balance/{}/{}/{}_{}parts_top{}/outputshared.txt'.format(data,seed,data,ctx.memory_group_size,top)
reorder = '../../../SPEED/partition/divided_nodes_seed_ldg/{}/reorder.txt'.format(data)
edge_i = '../../../SPEED/partition/divided_nodes_metis_balance/{}/{}/{}_{}parts_top{}/edge_output{}.txt'.format(data,seed,data,ctx.memory_group_size,top,ctx.memory_group_rank)
fnode_i = '../../SPEED/partition/divided_nodes_metis/{}/{}/{}_{}parts_top{}/output{}.txt'.format(data,seed,data,ctx.memory_group_size,top,ctx.memory_group_rank)
fnode_share = '../../SPEED/partition/divided_nodes_metis/{}/{}/{}_{}parts_top{}/outputshared.txt'.format(data,seed,data,ctx.memory_group_size,top)
reorder = '../../SPEED/partition/divided_nodes_metis/{}/reorder.txt'.format(data)
edge_i = '../../SPEED/partition/divided_nodes_metis/{}/{}/{}_{}parts_top{}/edge_output{}.txt'.format(data,seed,data,ctx.memory_group_size,top,ctx.memory_group_rank)
elif partition == 'ldg':
fnode_i = '../../SPEED/partition/divided_nodes_ldg/{}/{}/{}_{}parts_top{}/output{}.txt'.format(data,seed,data,ctx.memory_group_size,top,ctx.memory_group_rank)
fnode_share = '../../SPEED/partition/divided_nodes_ldg/{}/{}/{}_{}parts_top{}/outputshared.txt'.format(data,seed,data,ctx.memory_group_size,top)
reorder = '../../SPEED/partition/divided_nodes_ldg/{}/reorder.txt'.format(data)
edge_i = '../../SPEED/partition/divided_nodes_ldg/{}/{}/{}_{}parts_top{}/edge_output{}.txt'.format(data,seed,data,ctx.memory_group_size,top,ctx.memory_group_rank)
elif partition == 'random':
df = load_graph(data)
src = torch.from_numpy(np.array(df.src.values)).long()
dst = torch.from_numpy(np.array(df.dst.values)).long()
num_node = max(src.max().item(),dst.max().item())+1
part = torch.randint(0,dist.get_world_size(),size = [num_node])
#part = torch.arange(num_node)%dist.get_world_size()
#edge_part = torch.randint(0,dist.get_world_size(),size = [src.shape[0]])
node_i = (part == dist.get_rank()).nonzero().reshape(-1)
dist.broadcast(part,src=0,group=ctx.gloo_group)
#dist.broadcast(edge_part,src=0,group=ctx.gloo_group)
print(part)
#edge_i = (edge_part == dist.get_rank()).nonzero().reshape(-1)#
edge_i = (part[src] == dist.get_rank()).nonzero().reshape(-1)
print(node_i,edge_i)
shared_node_list = []
shared_node = torch.tensor(shared_node_list).reshape(-1).to(torch.long)
return load_from_shared_node_partition(data,node_i,shared_node,sample_add_rev=sampler_graph_add_rev,edge_i=edge_i,reid=None,device=device,feature_device=feature_device)
......
......@@ -139,7 +139,7 @@ class NeighborSampler(BaseSampler):
else:
assert tnb is not None
self.tnb = tnb
print(local_part,node_part)
self.p_sampler = starrygl.sampler_ops.ParallelSampler(self.tnb, num_nodes, graph_data.num_edges, workers,
fanout, num_layers, policy, local_part,edge_part.to(torch.int),node_part.to(torch.int),probability)
......@@ -257,7 +257,7 @@ class NeighborSampler(BaseSampler):
ets: Optional[torch.Tensor] = None,
neg_sampling: Optional[NegativeSampling] = None,
with_outer_sample: SampleType = SampleType.Whole,
is_unique:bool = True,
is_unique:bool = False,
nid_mapper = None,
eid_mapper = None,
out_device = None
......@@ -349,6 +349,7 @@ class NeighborSampler(BaseSampler):
# dst_pos_index = first_block_id[dst_pos_index].contiguous()
# dst_neg_index = first_block_id[dst_neg_index].contiguous()
# src_neg_index = first_block_id[src_neg_index].contiguous()
metadata['seed'] = seed
metadata['seed_ts'] = seed_ts
metadata['src_pos_index']=src_pos_index
......