Commit 3814b57f by zhlj

Fix memory bugs and speed up training

parent 8adc9ed9
......@@ -30,7 +30,7 @@ train:
- epoch: 100
batch_size: 1000
# reorder: 16
lr: 0.0008
lr: 0.0004
dropout: 0.2
att_dropout: 0.2
all_on_gpu: True
......@@ -30,7 +30,7 @@ train:
- epoch: 50
batch_size: 3000
# reorder: 16
lr: 0.0008
lr: 0.0004
dropout: 0.2
att_dropout: 0.2
all_on_gpu: True
#!/bin/bash
# Define array variables
addr="192.168.1.107"
addr="192.168.1.105"
partition_params=("ours" )
#"metis" "ldg" "random")
#("ours" "metis" "ldg" "random")
partitions="8"
partitions="4"
node_per="4"
nnodes="2"
nnodes="1"
node_rank="0"
probability_params=("0.1" "0.05" "0.01" "0")
sample_type_params=("recent" "boundery_recent_decay")
probability_params=("0" "0.1" "0.05" "0.01")
sample_type_params=("boundery_recent_decay" "recent" )
#sample_type_params=("recent" "boundery_recent_decay") #"boundery_recent_uniform")
#memory_type=("all_update" "p2p" "all_reduce" "historical" "local")
#memory_type=("all_update" "historical" "local")
memory_type=("historical" "all_update" "local")
memory_type=( "historical" "local" "all_update")
#memory_type=("local" "all_update" "historical" "all_reduce")
shared_memory_ssim=("0.3" "0.7")
#data_param=("WIKI" "REDDIT" "LASTFM" "WikiTalk")
data_param=("WIKI" "LASTFM" "WikiTalk" "StackOverflow" "GDELT" "TaoBao")
data_param=( "TaoBao" "StackOverflow" "GDELT")
#data_param=("WIKI" "REDDIT" "LASTFM" "DGraphFin" "WikiTalk" "StackOverflow")
#data_param=("WIKI" "REDDIT" "LASTFM" "WikiTalk" "StackOverflow")
#data_param=("REDDIT" "WikiTalk")
......
......@@ -63,11 +63,11 @@ parser.add_argument('--partition', default='part', type=str, metavar='W',
help='name of model')
parser.add_argument('--topk', default='0', type=str, metavar='W',
help='name of model')
parser.add_argument('--probability', default=1, type=float, metavar='W',
parser.add_argument('--probability', default=True, type=float, metavar='W',
help='name of model')
parser.add_argument('--sample_type', default='recent', type=str, metavar='W',
help='name of model')
parser.add_argument('--local_neg_sample', default=False, type=bool, metavar='W',
parser.add_argument('--local_neg_sample', default=True, type=bool, metavar='W',
help='name of model')
parser.add_argument('--shared_memory_ssim', default=2, type=float, metavar='W',
help='name of model')
......@@ -197,7 +197,9 @@ def main():
torch.set_num_threads(10)
device_id = torch.cuda.current_device()
graph,full_sampler_graph,train_mask,val_mask,test_mask,full_train_mask,cache_route = load_from_speed(args.dataname,seed=123457,top=args.topk,sampler_graph_add_rev=True, feature_device=torch.device('cuda:{}'.format(ctx.local_rank)),partition=args.partition)#torch.device('cpu'))
torch.autograd.set_detect_anomaly(True)
if(args.dataname=='GDELT'):
train_param['epoch'] = 10
#torch.autograd.set_detect_anomaly(True)
# Ensure CUDA is available
if torch.cuda.is_available():
print("Total GPU memory: ", torch.cuda.get_device_properties(0).total_memory/1024**3)
......@@ -209,7 +211,7 @@ def main():
print("CUDA is not available.")
full_dst = full_sampler_graph['edge_index'][1,torch.arange(0,full_sampler_graph['edge_index'].shape[1],2)]
sample_graph = TemporalNeighborSampleGraph(full_sampler_graph,mode = 'full',dist_eid_mapper=graph.eids_mapper)#,local_eids=graph.eids)
sample_graph = TemporalNeighborSampleGraph(full_sampler_graph,mode = 'full',dist_eid_mapper=graph.eids_mapper)
eval_sample_graph = TemporalNeighborSampleGraph(full_sampler_graph,mode = 'full',dist_eid_mapper=graph.eids_mapper)
Path("../saved_models/").mkdir(parents=True, exist_ok=True)
Path("../saved_checkpoints/").mkdir(parents=True, exist_ok=True)
......@@ -495,19 +497,19 @@ def main():
for roots,mfgs,metadata in trainloader:
#print('rank is {} batch max ts is {} batch min ts is {}'.format(dist.get_rank(),roots.ts.min(),roots.ts.max()))
b_cnt = b_cnt + 1
local_access.append(trainloader.local_node)
remote_access.append(trainloader.remote_node)
local_edge_access.append(trainloader.local_edge)
remote_edge_access.append(trainloader.remote_edge)
local_comm.append((DistIndex(mfgs[0][0].srcdata['ID']).part == dist.get_rank()).sum().item())
remote_comm.append((DistIndex(mfgs[0][0].srcdata['ID']).part != dist.get_rank()).sum().item())
if 'ID' in mfgs[0][0].edata:
local_edge_comm.append((DistIndex(mfgs[0][0].edata['ID']).part == dist.get_rank()).sum().item())
remote_edge_comm.append((DistIndex(mfgs[0][0].edata['ID']).part != dist.get_rank()).sum().item())
sum_local_edge_comm +=local_edge_comm[b_cnt-1]
sum_remote_edge_comm +=remote_edge_comm[b_cnt-1]
sum_local_comm +=local_comm[b_cnt-1]
sum_remote_comm +=remote_comm[b_cnt-1]
#local_access.append(trainloader.local_node)
#remote_access.append(trainloader.remote_node)
#local_edge_access.append(trainloader.local_edge)
#remote_edge_access.append(trainloader.remote_edge)
#local_comm.append((DistIndex(mfgs[0][0].srcdata['ID']).part == dist.get_rank()).sum().item())
#remote_comm.append((DistIndex(mfgs[0][0].srcdata['ID']).part != dist.get_rank()).sum().item())
#if 'ID' in mfgs[0][0].edata:
# local_edge_comm.append((DistIndex(mfgs[0][0].edata['ID']).part == dist.get_rank()).sum().item())
# remote_edge_comm.append((DistIndex(mfgs[0][0].edata['ID']).part != dist.get_rank()).sum().item())
# sum_local_edge_comm +=local_edge_comm[b_cnt-1]
# sum_remote_edge_comm +=remote_edge_comm[b_cnt-1]
#sum_local_comm +=local_comm[b_cnt-1]
#sum_remote_comm +=remote_comm[b_cnt-1]
if mailbox is not None:
if(graph.efeat.device.type != 'cpu'):
......@@ -526,13 +528,9 @@ def main():
optimizer.zero_grad()
pred_pos, pred_neg = model(mfgs,metadata,neg_samples=args.neg_samples,async_param = param)
if memory_param['historical_fix'] == True:
loss = creterion(pred_pos, torch.ones_like(pred_pos)) + 0.1*inner_prod(model.module.memory_updater.update_memory,model.module.memory_updater.prev_memory)
else:
#loss = creterion(pred_pos, torch.ones_like(pred_pos)) + 0.1*inner_prod(model.module.memory_updater.last_updated_memory,model.module.memory_updater.pre_mem)
loss = creterion(pred_pos, torch.ones_like(pred_pos))
loss = creterion(pred_pos, torch.ones_like(pred_pos))
loss += creterion(pred_neg, torch.zeros_like(pred_neg))
total_loss += float(loss)
total_loss += float(loss.item())
#mailbox.handle_last_async()
#trainloader.async_feature()
#torch.cuda.synchronize()
......@@ -546,6 +544,8 @@ def main():
#torch.cuda.synchronize()
mailbox.update_shared()
mailbox.update_p2p()
#torch.cuda.empty_cache()
"""
if mailbox is not None:
#src = metadata['src_pos_index']
......@@ -606,8 +606,8 @@ def main():
print('local node number {} remote node number {} local edge {} remote edge{}\n'.format(local_node,remote_node,local_edge,remote_edge))
print(' comm local node number {} remote node number {} local edge {} remote edge{}\n'.format(sum_local_comm,sum_remote_comm,sum_local_edge_comm,sum_remote_edge_comm))
print('memory comm {} shared comm {}\n'.format(tot_comm_count,tot_shared_count))
if(e==0):
torch.save((local_access,remote_access,local_edge_access,remote_edge_access,local_comm,remote_comm,local_edge_comm,remote_edge_comm),'all/{}/{}/comm/comm_{}_{}_{}_{}_{}_{}_{}_{}.pt'.format(args.dataname,args.model,args.partition,args.topk,dist.get_world_size(),dist.get_rank(),args.sample_type,args.probability,args.memory_type,args.shared_memory_ssim))
#if(e==0):
# torch.save((local_access,remote_access,local_edge_access,remote_edge_access,local_comm,remote_comm,local_edge_comm,remote_edge_comm),'all/{}/{}/comm/comm_{}_{}_{}_{}_{}_{}_{}_{}.pt'.format(args.dataname,args.model,args.partition,args.topk,dist.get_world_size(),dist.get_rank(),args.sample_type,args.probability,args.memory_type,args.shared_memory_ssim))
ap = 0
auc = 0
tt.ssim_remote=0
......
......@@ -22,7 +22,6 @@ class Filter(nn.Module):
# Treat filter as parameter so that it is saved and loaded together with the model
self.count = torch.zeros((self.n_nodes),1).to(self.device)
self.incretment = torch.zeros((self.n_nodes, self.memory_dimension)).to(self.device)
self.incretment_sqr = torch.zeros((self.n_nodes, self.memory_dimension)).to(self.device)
def get_count(self, node_idxs):
......@@ -31,19 +30,7 @@ class Filter(nn.Module):
def get_incretment(self, node_idxs):
#print(self.incretment[node_idxs,:].shape,self.count[node_idxs].shape)
return self.incretment[node_idxs,:]#/torch.clamp(self.count[node_idxs,:],1)
def get_incretment_remote(self, idx):
remote_tensor = DistributedTensor(self.incretment)
remote_count = DistributedTensor(self.count)
idx,pos = idx.sort()
xpos = torch.empty_like(pos)
xpos[pos] = torch.arange(pos.shape[0],device = pos.device,dtype=pos.dtype)
return remote_tensor.all_to_all_get(idx)[xpos]/torch.clamp(remote_count.all_to_all_get(idx)[xpos],1)
def get_incretment_sqr(self, node_idxs):
return self.incretment_sqr[node_idxs, :]
return self.incretment[node_idxs,:]/torch.clamp(self.count[node_idxs,:],1)
def detach_filter(self):
self.incretment.detach_()
......@@ -51,11 +38,7 @@ class Filter(nn.Module):
def update(self, node_idxs, incret):
self.count[node_idxs, :] = self.count[node_idxs, :] + 1
self.incretment[node_idxs, :] = self.incretment[node_idxs, :] + incret
self.incretment_sqr[node_idxs, :] = self.incretment_sqr[node_idxs, :] + incret * incret
def clear(self):
self.count.zero_()
self.incretment.zero_()
def compute_prediction(self, node_idxs):
mu = self.incretment[node_idxs, :]/self.count[node_idxs, :]
sigma = self.incretment_sqr[node_idxs, :]/self.count[node_idxs, :]
return mu, sigma
\ No newline at end of file
......@@ -101,7 +101,7 @@ class HistoricalCache:
return mask
def read_synchronize(self):
torch.cuda.synchronize(get_stream_set(self.layer))
get_stream_set(self.layer).synchronize()
# def get_cold_data(self,index,to_device):
# with get_stream_set(self.layer):
......
......@@ -291,19 +291,19 @@ class TransfomerAttentionLayer(torch.nn.Module):
#tt.weight_count_remote+=torch.sum(att[DistIndex(b.srcdata['ID']).part[b.edges()[0]]!=torch.distributed.get_rank()]**2)
#tt.weight_count_local+=torch.sum(att[DistIndex(b.srcdata['ID']).part[b.edges()[0]]==torch.distributed.get_rank()]**2)
V = torch.reshape(V*att[:, :, None], (V.shape[0], -1))
V_local = V.clone()
V_remote = V.clone()
V_local[DistIndex(b.srcdata['ID']).part[b.edges()[0]]!=torch.distributed.get_rank()] = 0
V_remote[DistIndex(b.srcdata['ID']).part[b.edges()[0]]==torch.distributed.get_rank()] = 0
#V_local = V.clone()
#V_remote = V.clone()
#V_local[DistIndex(b.srcdata['ID']).part[b.edges()[0]]!=torch.distributed.get_rank()] = 0
#V_remote[DistIndex(b.srcdata['ID']).part[b.edges()[0]]==torch.distributed.get_rank()] = 0
b.edata['v'] = V
b.edata['v0'] = V_local
b.edata['v1'] = V_remote
b.update_all(dgl.function.copy_e('v0', 'm0'), dgl.function.sum('m0', 'h0'))
b.update_all(dgl.function.copy_e('v1', 'm1'), dgl.function.sum('m1', 'h1'))
#b.edata['v0'] = V_local
#b.edata['v1'] = V_remote
#b.update_all(dgl.function.copy_e('v0', 'm0'), dgl.function.sum('m0', 'h0'))
#b.update_all(dgl.function.copy_e('v1', 'm1'), dgl.function.sum('m1', 'h1'))
b.update_all(dgl.function.copy_e('v', 'm'), dgl.function.sum('m', 'h'))
tt.ssim_local+=torch.sum(torch.cosine_similarity(b.dstdata['h'],b.dstdata['h0']))
tt.ssim_remote+=torch.sum(torch.cosine_similarity(b.dstdata['h'],b.dstdata['h1']))
tt.ssim_cnt += b.num_dst_nodes()
#tt.ssim_local+=torch.sum(torch.cosine_similarity(b.dstdata['h'],b.dstdata['h0']))
#tt.ssim_remote+=torch.sum(torch.cosine_similarity(b.dstdata['h'],b.dstdata['h1']))
#tt.ssim_cnt += b.num_dst_nodes()
#print('dst {}'.format(b.dstdata['h']))
#b.srcdata['v'] = torch.cat([torch.zeros((b.num_dst_nodes(), V.shape[1]), device=torch.device('cuda:0')), V], dim=0)
#b.update_all(dgl.function.copy_u('v', 'm'), dgl.function.sum('m', 'h'))
......
......@@ -421,7 +421,7 @@ class AsyncMemeoryUpdater(torch.nn.Module):
elif self.mode == 'local' or self.mode=='all_local':
self.update_hunk = self.local_func
if self.mode == 'historical':
self.gamma = torch.nn.Parameter(torch.tensor([0.9]),
self.gamma = torch.nn.Parameter(torch.tensor([0.5]),
requires_grad=True)
self.filter = Filter(n_nodes=mailbox.shared_nodes_index.shape[0],
memory_dimension=self.dim_hid,
......@@ -430,7 +430,7 @@ class AsyncMemeoryUpdater(torch.nn.Module):
self.gamma = 1
def forward(self, mfg, param = None):
for b in mfg:
mail_input = b.srcdata['mem_input']
#print(b.srcdata['ID'].shape[0])
updated_memory0 = self.updater(b)
mask = DistIndex(b.srcdata['ID']).is_shared
#incr = updated_memory[mask] - b.srcdata['mem'][mask]
......@@ -440,14 +440,15 @@ class AsyncMemeoryUpdater(torch.nn.Module):
upd0 = torch.zeros_like(updated_memory0)
if self.mode == 'historical':
shared_ind = self.mailbox.is_shared_mask[DistIndex(b.srcdata['ID'][mask]).loc]
transition_dense = self.filter.get_incretment(shared_ind)
transition_dense*=2
transition_dense = b.srcdata['his_mem'][mask] + self.filter.get_incretment(shared_ind)
if not (transition_dense.max().item() == 0):
transition_dense -= transition_dense.min()
transition_dense /=torch.clamp(transition_dense.max() ,1)
transition_dense /= transition_dense.max()
transition_dense = 2*transition_dense - 1
upd0[mask] = b.srcdata['his_mem'][mask] + transition_dense
upd0[mask] = transition_dense#b.srcdata['his_mem'][mask] + transition_dense
#print(self.gamma)
#print('tran {} {} {}\n'.format(transition_dense.max().item(),upd0[mask].max().item(),b.srcdata['his_mem'][mask].max().item()))
else:
upd0[mask] = b.srcdata['his_mem'][mask]
#upd0[mask] = self.ceil_updater(his_mem, b.srcdata['his_mem'][mask])
......@@ -457,16 +458,7 @@ class AsyncMemeoryUpdater(torch.nn.Module):
with torch.no_grad():
if self.mode == 'historical':
change = updated_memory[mask] - b.srcdata['his_mem'][mask]
change.detach()
if not (change.max().item() == 0):
change -= change.min()
change /=change.max()
change = 2*change - 1
self.filter.update(shared_ind,change)
#print(transition_dense)
#print(torch.cosine_similarity(updated_memory[mask],b.srcdata['his_mem'][mask]).sum()/torch.sum(mask))
#print(self.gamma)
self.pre_mem = b.srcdata['his_mem']
self.filter.update(shared_ind,change)
self.last_updated_ts = b.srcdata['ts'].detach().clone()
self.last_updated_memory = updated_memory.detach().clone()
self.last_updated_nid = b.srcdata['ID'].detach().clone()
......@@ -504,4 +496,6 @@ class AsyncMemeoryUpdater(torch.nn.Module):
b.srcdata['h'] = updated_memory
def empty_cache(self):
pass
if self.mode == 'historical':
print('clear\n')
self.filter.clear()
......@@ -240,8 +240,8 @@ def to_block(graph,data, sample_out,device = torch.device('cuda'),unique = True)
#print(CountComm.origin_local,CountComm.origin_remote)
#for p in range(dist.get_world_size()):
# print((DistIndex(dist_nid).part == p).sum().item())
CountComm.origin_local = (DistIndex(dist_nid).part == dist.get_rank()).sum().item()
CountComm.origin_remote =(DistIndex(dist_nid).part != dist.get_rank()).sum().item()
#CountComm.origin_local = (DistIndex(dist_nid).part == dist.get_rank()).sum().item()
#CountComm.origin_remote =(DistIndex(dist_nid).part != dist.get_rank()).sum().item()
dist_nid,nid_inv = dist_nid.unique(return_inverse=True)
#print('nid_tensor {} \n nid {}\n'.format(nid_tensor,dist_nid))
......@@ -250,7 +250,7 @@ def to_block(graph,data, sample_out,device = torch.device('cuda'),unique = True)
"""
if unique:
block_node_list,unq_id = torch.stack((nid_inv.to(torch.float64),src_ts.to(torch.float64))).unique(dim = 1,return_inverse=True)
#print(block_node_list.shape,unq_id.shape)
first_index,_ = torch_scatter.scatter_min(torch.arange(unq_id.shape[0],device=unq_id.device,dtype=unq_id.dtype),unq_id)
first_mask = torch.zeros(unq_id.shape[0],device = unq_id.device,dtype=torch.bool)
first_mask[first_index] = True
......@@ -423,9 +423,9 @@ def graph_sample(graph,sampler,sample_fn,data,neg_sampling = None,out_device = t
param = {'is_unique':False,'nid_mapper':nid_mapper,'eid_mapper':eid_mapper,'out_device':out_device}
out = sample_fn(sampler,data,neg_sampling,**param)
if reversed is False:
out,dist_nid,dist_eid = to_block(graph,data,out,out_device,reversed)
out,dist_nid,dist_eid = to_block(graph,data,out,out_device)
else:
out,dist_nid,dist_eid = to_reversed_block(graph,data,out,out_device,reversed)
out,dist_nid,dist_eid = to_reversed_block(graph,data,out,out_device,identity=(sampler.policy=='identity'))
t_e = time.time()
#print(t_e-t_s)
return out,dist_nid,dist_eid
......
......@@ -257,19 +257,19 @@ class DistributedDataLoader:
pass
batch_data,dist_nid,dist_eid = self.result_queue[0].result()
b = batch_data[1][0][0]
self.remote_node += (DistIndex(dist_nid).part != dist.get_rank()).sum()
self.local_node += (DistIndex(dist_nid).part == dist.get_rank()).sum()
self.remote_edge += (DistIndex(dist_eid).part != dist.get_rank()).sum()
self.local_edge += (DistIndex(dist_eid).part == dist.get_rank()).sum()
#self.remote_node += (DistIndex(dist_nid).part != dist.get_rank()).sum()
#self.local_node += (DistIndex(dist_nid).part == dist.get_rank()).sum()
#self.remote_edge += (DistIndex(dist_eid).part != dist.get_rank()).sum()
#self.local_edge += (DistIndex(dist_eid).part == dist.get_rank()).sum()
#self.remote_root += (DistIndex(dist_nid[b.srcdata['__ID'][:self.batch_size*2]]).part != dist.get_rank()).sum()
#self.local_root += (DistIndex(dist_nid[b.srcdata['__ID'][:self.batch_size*2]]).part == dist.get_rank()).sum()
#torch.cuda.synchronize(stream)
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start.record()
#start = torch.cuda.Event(enable_timing=True)
#end = torch.cuda.Event(enable_timing=True)
#start.record()
stream.synchronize()
end.record()
end.synchronize()
#end.record()
#end.synchronize()
#print(start.elapsed_time(end))
self.result_queue.popleft()
#start = torch.cuda.Event(enable_timing=True)
......
......@@ -13,10 +13,10 @@ class MemoryMoniter:
else:
return torch.sum((x-y)**2,dim=1)
def add(self,nid,pre_memory,now_memory):
self.memorychange.append(self.ssim(pre_memory,now_memory,method = 'F'))
self.memory_ssim.append(self.ssim(pre_memory,now_memory,method = 'cos'))
self.nid_list.append(nid)
pass
#self.memorychange.append(self.ssim(pre_memory,now_memory,method = 'F'))
#self.memory_ssim.append(self.ssim(pre_memory,now_memory,method = 'cos'))
#self.nid_list.append(nid)
def draw(self,degree,data,model,e):
torch.save(self.nid_list,'all/{}/{}/memorynid_{}.pt'.format(data,model,e))
torch.save(self.memorychange,'all/{}/{}/memoryF_{}.pt'.format(data,model,e))
......