Commit 3814b57f by zhlj

fix bugs for memory and speed up

parent 8adc9ed9
@@ -30,7 +30,7 @@ train:
 - epoch: 100
   batch_size: 1000
   # reorder: 16
-  lr: 0.0008
+  lr: 0.0004
   dropout: 0.2
   att_dropout: 0.2
   all_on_gpu: True
@@ -30,7 +30,7 @@ train:
 - epoch: 50
   batch_size: 3000
   # reorder: 16
-  lr: 0.0008
+  lr: 0.0004
   dropout: 0.2
   att_dropout: 0.2
   all_on_gpu: True
 #!/bin/bash
 # Define array variables
-addr="192.168.1.107"
+addr="192.168.1.105"
 partition_params=("ours" )
 #"metis" "ldg" "random")
 #("ours" "metis" "ldg" "random")
-partitions="8"
+partitions="4"
 node_per="4"
-nnodes="2"
+nnodes="1"
 node_rank="0"
-probability_params=("0.1" "0.05" "0.01" "0")
-sample_type_params=("recent" "boundery_recent_decay")
+probability_params=("0" "0.1" "0.05" "0.01")
+sample_type_params=("boundery_recent_decay" "recent" )
 #sample_type_params=("recent" "boundery_recent_decay") #"boundery_recent_uniform")
 #memory_type=("all_update" "p2p" "all_reduce" "historical" "local")
-#memory_type=("all_update" "historical" "local")
-memory_type=("historical" "all_update" "local")
+memory_type=( "historical" "local" "all_update")
 #memory_type=("local" "all_update" "historical" "all_reduce")
 shared_memory_ssim=("0.3" "0.7")
 #data_param=("WIKI" "REDDIT" "LASTFM" "WikiTalk")
-data_param=("WIKI" "LASTFM" "WikiTalk" "StackOverflow" "GDELT" "TaoBao")
+data_param=( "TaoBao" "StackOverflow" "GDELT")
 #data_param=("WIKI" "REDDIT" "LASTFM" "DGraphFin" "WikiTalk" "StackOverflow")
 #data_param=("WIKI" "REDDIT" "LASTFM" "WikiTalk" "StackOverflow")
 #data_param=("REDDIT" "WikiTalk")
......
@@ -63,11 +63,11 @@ parser.add_argument('--partition', default='part', type=str, metavar='W',
                     help='name of model')
 parser.add_argument('--topk', default='0', type=str, metavar='W',
                     help='name of model')
-parser.add_argument('--probability', default=1, type=float, metavar='W',
+parser.add_argument('--probability', default=True, type=float, metavar='W',
                     help='name of model')
 parser.add_argument('--sample_type', default='recent', type=str, metavar='W',
                     help='name of model')
-parser.add_argument('--local_neg_sample', default=False, type=bool, metavar='W',
+parser.add_argument('--local_neg_sample', default=True, type=bool, metavar='W',
                     help='name of model')
 parser.add_argument('--shared_memory_ssim', default=2, type=float, metavar='W',
                     help='name of model')
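Note (not part of the diff): two of these flags deserve caution. argparse does not run `type` on defaults, so with `default=True, type=float` the unset value of `--probability` is literally the bool `True` (it behaves as `1.0` in arithmetic, but `default=1.0` would be cleaner). And `type=bool` does not parse strings the way one might hope: `bool('False')` is `True`. A minimal sketch of a safer pattern; the `str2bool` helper is illustrative, not from this repo:

```python
import argparse

def str2bool(v: str) -> bool:
    # argparse hands us the raw string; bool('False') == True,
    # so spell out the accepted values explicitly.
    if v.lower() in ('yes', 'true', 't', '1'):
        return True
    if v.lower() in ('no', 'false', 'f', '0'):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean, got {v!r}')

parser = argparse.ArgumentParser()
parser.add_argument('--probability', default=1.0, type=float)
parser.add_argument('--local_neg_sample', default=True, type=str2bool)
args = parser.parse_args(['--local_neg_sample', 'false'])
assert args.local_neg_sample is False
```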
@@ -197,7 +197,9 @@ def main():
     torch.set_num_threads(10)
     device_id = torch.cuda.current_device()
     graph,full_sampler_graph,train_mask,val_mask,test_mask,full_train_mask,cache_route = load_from_speed(args.dataname,seed=123457,top=args.topk,sampler_graph_add_rev=True, feature_device=torch.device('cuda:{}'.format(ctx.local_rank)),partition=args.partition)#torch.device('cpu'))
-    torch.autograd.set_detect_anomaly(True)
+    if(args.dataname=='GDELT'):
+        train_param['epoch'] = 10
+    #torch.autograd.set_detect_anomaly(True)
     # make sure CUDA is available
     if torch.cuda.is_available():
         print("Total GPU memory: ", torch.cuda.get_device_properties(0).total_memory/1024**3)
@@ -209,7 +211,7 @@ def main():
         print("CUDA is not available.")
     full_dst = full_sampler_graph['edge_index'][1,torch.arange(0,full_sampler_graph['edge_index'].shape[1],2)]
-    sample_graph = TemporalNeighborSampleGraph(full_sampler_graph,mode = 'full',dist_eid_mapper=graph.eids_mapper)#,local_eids=graph.eids)
+    sample_graph = TemporalNeighborSampleGraph(full_sampler_graph,mode = 'full',dist_eid_mapper=graph.eids_mapper)
     eval_sample_graph = TemporalNeighborSampleGraph(full_sampler_graph,mode = 'full',dist_eid_mapper=graph.eids_mapper)
     Path("../saved_models/").mkdir(parents=True, exist_ok=True)
     Path("../saved_checkpoints/").mkdir(parents=True, exist_ok=True)
@@ -495,19 +497,19 @@ def main():
        for roots,mfgs,metadata in trainloader:
            #print('rank is {} batch max ts is {} batch min ts is {}'.format(dist.get_rank(),roots.ts.min(),roots.ts.max()))
            b_cnt = b_cnt + 1
-           local_access.append(trainloader.local_node)
-           remote_access.append(trainloader.remote_node)
-           local_edge_access.append(trainloader.local_edge)
-           remote_edge_access.append(trainloader.remote_edge)
-           local_comm.append((DistIndex(mfgs[0][0].srcdata['ID']).part == dist.get_rank()).sum().item())
-           remote_comm.append((DistIndex(mfgs[0][0].srcdata['ID']).part != dist.get_rank()).sum().item())
-           if 'ID' in mfgs[0][0].edata:
-               local_edge_comm.append((DistIndex(mfgs[0][0].edata['ID']).part == dist.get_rank()).sum().item())
-               remote_edge_comm.append((DistIndex(mfgs[0][0].edata['ID']).part != dist.get_rank()).sum().item())
-               sum_local_edge_comm +=local_edge_comm[b_cnt-1]
-               sum_remote_edge_comm +=remote_edge_comm[b_cnt-1]
-           sum_local_comm +=local_comm[b_cnt-1]
-           sum_remote_comm +=remote_comm[b_cnt-1]
+           #local_access.append(trainloader.local_node)
+           #remote_access.append(trainloader.remote_node)
+           #local_edge_access.append(trainloader.local_edge)
+           #remote_edge_access.append(trainloader.remote_edge)
+           #local_comm.append((DistIndex(mfgs[0][0].srcdata['ID']).part == dist.get_rank()).sum().item())
+           #remote_comm.append((DistIndex(mfgs[0][0].srcdata['ID']).part != dist.get_rank()).sum().item())
+           #if 'ID' in mfgs[0][0].edata:
+           #    local_edge_comm.append((DistIndex(mfgs[0][0].edata['ID']).part == dist.get_rank()).sum().item())
+           #    remote_edge_comm.append((DistIndex(mfgs[0][0].edata['ID']).part != dist.get_rank()).sum().item())
+           #    sum_local_edge_comm +=local_edge_comm[b_cnt-1]
+           #    sum_remote_edge_comm +=remote_edge_comm[b_cnt-1]
+           #sum_local_comm +=local_comm[b_cnt-1]
+           #sum_remote_comm +=remote_comm[b_cnt-1]
            if mailbox is not None:
                if(graph.efeat.device.type != 'cpu'):
@@ -526,13 +528,9 @@ def main():
            optimizer.zero_grad()
            pred_pos, pred_neg = model(mfgs,metadata,neg_samples=args.neg_samples,async_param = param)
-           if memory_param['historical_fix'] == True:
-               loss = creterion(pred_pos, torch.ones_like(pred_pos)) + 0.1*inner_prod(model.module.memory_updater.update_memory,model.module.memory_updater.prev_memory)
-           else:
-               #loss = creterion(pred_pos, torch.ones_like(pred_pos)) + 0.1*inner_prod(model.module.memory_updater.last_updated_memory,model.module.memory_updater.pre_mem)
-               loss = creterion(pred_pos, torch.ones_like(pred_pos))
+           loss = creterion(pred_pos, torch.ones_like(pred_pos))
            loss += creterion(pred_neg, torch.zeros_like(pred_neg))
-           total_loss += float(loss)
+           total_loss += float(loss.item())
            #mailbox.handle_last_async()
            #trainloader.async_feature()
            #torch.cuda.synchronize()
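Note: `float(loss)` and `float(loss.item())` are equivalent (the outer `float` is redundant), and both force a GPU-to-CPU synchronization on every batch. If that per-batch sync ever matters in a profile, a common alternative is to accumulate the detached loss on-device and read it once per epoch. A minimal self-contained sketch:

```python
import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
total_loss = torch.zeros((), device=device)   # 0-dim on-device accumulator
for _ in range(10):
    pred = torch.randn(8, device=device, requires_grad=True)
    loss = torch.nn.functional.binary_cross_entropy_with_logits(
        pred, torch.ones_like(pred))
    loss.backward()
    total_loss += loss.detach()   # stays on-device, no host sync per batch
epoch_loss = total_loss.item()    # single GPU->CPU transfer at epoch end
print(epoch_loss)
```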
@@ -546,6 +544,8 @@ def main():
                #torch.cuda.synchronize()
                mailbox.update_shared()
                mailbox.update_p2p()
+               #torch.cuda.empty_cache()
            """
            if mailbox is not None:
                #src = metadata['src_pos_index']
@@ -606,8 +606,8 @@ def main():
        print('local node number {} remote node number {} local edge {} remote edge{}\n'.format(local_node,remote_node,local_edge,remote_edge))
        print(' comm local node number {} remote node number {} local edge {} remote edge{}\n'.format(sum_local_comm,sum_remote_comm,sum_local_edge_comm,sum_remote_edge_comm))
        print('memory comm {} shared comm {}\n'.format(tot_comm_count,tot_shared_count))
-       if(e==0):
-           torch.save((local_access,remote_access,local_edge_access,remote_edge_access,local_comm,remote_comm,local_edge_comm,remote_edge_comm),'all/{}/{}/comm/comm_{}_{}_{}_{}_{}_{}_{}_{}.pt'.format(args.dataname,args.model,args.partition,args.topk,dist.get_world_size(),dist.get_rank(),args.sample_type,args.probability,args.memory_type,args.shared_memory_ssim))
+       #if(e==0):
+       #    torch.save((local_access,remote_access,local_edge_access,remote_edge_access,local_comm,remote_comm,local_edge_comm,remote_edge_comm),'all/{}/{}/comm/comm_{}_{}_{}_{}_{}_{}_{}_{}.pt'.format(args.dataname,args.model,args.partition,args.topk,dist.get_world_size(),dist.get_rank(),args.sample_type,args.probability,args.memory_type,args.shared_memory_ssim))
        ap = 0
        auc = 0
        tt.ssim_remote=0
......
@@ -22,7 +22,6 @@ class Filter(nn.Module):
        # Treat filter as parameter so that it is saved and loaded together with the model
        self.count = torch.zeros((self.n_nodes),1).to(self.device)
        self.incretment = torch.zeros((self.n_nodes, self.memory_dimension)).to(self.device)
-       self.incretment_sqr = torch.zeros((self.n_nodes, self.memory_dimension)).to(self.device)
    def get_count(self, node_idxs):
@@ -31,19 +30,7 @@ class Filter(nn.Module):
    def get_incretment(self, node_idxs):
        #print(self.incretment[node_idxs,:].shape,self.count[node_idxs].shape)
-       return self.incretment[node_idxs,:]#/torch.clamp(self.count[node_idxs,:],1)
-   def get_incretment_remote(self, idx):
-       remote_tensor = DistributedTensor(self.incretment)
-       remote_count = DistributedTensor(self.count)
-       idx,pos = idx.sort()
-       xpos = torch.empty_like(pos)
-       xpos[pos] = torch.arange(pos.shape[0],device = pos.device,dtype=pos.dtype)
-       return remote_tensor.all_to_all_get(idx)[xpos]/torch.clamp(remote_count.all_to_all_get(idx)[xpos],1)
-   def get_incretment_sqr(self, node_idxs):
-       return self.incretment_sqr[node_idxs, :]
+       return self.incretment[node_idxs,:]/torch.clamp(self.count[node_idxs,:],1)
    def detach_filter(self):
        self.incretment.detach_()
@@ -51,11 +38,7 @@ class Filter(nn.Module):
    def update(self, node_idxs, incret):
        self.count[node_idxs, :] = self.count[node_idxs, :] + 1
        self.incretment[node_idxs, :] = self.incretment[node_idxs, :] + incret
-       self.incretment_sqr[node_idxs, :] = self.incretment_sqr[node_idxs, :] + incret * incret
    def clear(self):
+       self.count.zero_()
        self.incretment.zero_()
-   def compute_prediction(self, node_idxs):
-       mu = self.incretment[node_idxs, :]/self.count[node_idxs, :]
-       sigma = self.incretment_sqr[node_idxs, :]/self.count[node_idxs, :]
-       return mu, sigma
\ No newline at end of file
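After this commit, `Filter` keeps a per-node visit count and a running sum of memory deltas, `get_incretment` returns the mean (with the count clamped to at least 1), and `clear()` resets both buffers; the variance tracking (`incretment_sqr`, `compute_prediction`) is gone. A self-contained sketch of the same bookkeeping, with illustrative names:

```python
import torch

class RunningDeltaFilter:
    """Minimal sketch of Filter's bookkeeping after this commit (names differ)."""
    def __init__(self, n_nodes: int, dim: int, device='cpu'):
        self.count = torch.zeros(n_nodes, 1, device=device)
        self.increment = torch.zeros(n_nodes, dim, device=device)

    def update(self, idx: torch.Tensor, delta: torch.Tensor):
        self.count[idx, :] += 1
        self.increment[idx, :] += delta

    def get_increment(self, idx: torch.Tensor) -> torch.Tensor:
        # Mean delta per node; the clamp avoids dividing by zero
        # for nodes that have never been updated.
        return self.increment[idx, :] / torch.clamp(self.count[idx, :], 1)

    def clear(self):
        # Reset both buffers: zeroing the sums but not the counts
        # would silently shrink every future mean.
        self.count.zero_()
        self.increment.zero_()

f = RunningDeltaFilter(n_nodes=4, dim=2)
f.update(torch.tensor([0, 1]), torch.ones(2, 2))
print(f.get_increment(torch.tensor([0, 2])))  # node 0: mean delta, node 2: zeros
```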
@@ -101,7 +101,7 @@ class HistoricalCache:
        return mask
    def read_synchronize(self):
-       torch.cuda.synchronize(get_stream_set(self.layer))
+       get_stream_set(self.layer).synchronize()
    # def get_cold_data(self,index,to_device):
    #     with get_stream_set(self.layer):
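This one-liner is a real fix, not a rewrite: `torch.cuda.synchronize()` takes a device argument, not a stream, so passing a stream object was incorrect, and a device-wide sync would stall every stream anyway. `Stream.synchronize()` waits only for the work queued on that one stream. A minimal sketch of the pattern:

```python
import torch

if torch.cuda.is_available():
    s = torch.cuda.Stream()
    x = torch.randn(1024, 1024, device='cuda')
    s.wait_stream(torch.cuda.current_stream())  # make x's creation visible to s
    with torch.cuda.stream(s):                  # queue work on the side stream
        y = x @ x
    s.synchronize()   # wait for this stream only
    # torch.cuda.synchronize() would block on every stream on the device,
    # and torch.cuda.synchronize(stream) is not a valid call at all.
    print(y.norm().item())
```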
......
@@ -291,19 +291,19 @@ class TransfomerAttentionLayer(torch.nn.Module):
            #tt.weight_count_remote+=torch.sum(att[DistIndex(b.srcdata['ID']).part[b.edges()[0]]!=torch.distributed.get_rank()]**2)
            #tt.weight_count_local+=torch.sum(att[DistIndex(b.srcdata['ID']).part[b.edges()[0]]==torch.distributed.get_rank()]**2)
            V = torch.reshape(V*att[:, :, None], (V.shape[0], -1))
-           V_local = V.clone()
-           V_remote = V.clone()
-           V_local[DistIndex(b.srcdata['ID']).part[b.edges()[0]]!=torch.distributed.get_rank()] = 0
-           V_remote[DistIndex(b.srcdata['ID']).part[b.edges()[0]]==torch.distributed.get_rank()] = 0
+           #V_local = V.clone()
+           #V_remote = V.clone()
+           #V_local[DistIndex(b.srcdata['ID']).part[b.edges()[0]]!=torch.distributed.get_rank()] = 0
+           #V_remote[DistIndex(b.srcdata['ID']).part[b.edges()[0]]==torch.distributed.get_rank()] = 0
            b.edata['v'] = V
-           b.edata['v0'] = V_local
-           b.edata['v1'] = V_remote
-           b.update_all(dgl.function.copy_e('v0', 'm0'), dgl.function.sum('m0', 'h0'))
-           b.update_all(dgl.function.copy_e('v1', 'm1'), dgl.function.sum('m1', 'h1'))
+           #b.edata['v0'] = V_local
+           #b.edata['v1'] = V_remote
+           #b.update_all(dgl.function.copy_e('v0', 'm0'), dgl.function.sum('m0', 'h0'))
+           #b.update_all(dgl.function.copy_e('v1', 'm1'), dgl.function.sum('m1', 'h1'))
            b.update_all(dgl.function.copy_e('v', 'm'), dgl.function.sum('m', 'h'))
-           tt.ssim_local+=torch.sum(torch.cosine_similarity(b.dstdata['h'],b.dstdata['h0']))
-           tt.ssim_remote+=torch.sum(torch.cosine_similarity(b.dstdata['h'],b.dstdata['h1']))
-           tt.ssim_cnt += b.num_dst_nodes()
+           #tt.ssim_local+=torch.sum(torch.cosine_similarity(b.dstdata['h'],b.dstdata['h0']))
+           #tt.ssim_remote+=torch.sum(torch.cosine_similarity(b.dstdata['h'],b.dstdata['h1']))
+           #tt.ssim_cnt += b.num_dst_nodes()
            #print('dst {}'.format(b.dstdata['h']))
            #b.srcdata['v'] = torch.cat([torch.zeros((b.num_dst_nodes(), V.shape[1]), device=torch.device('cuda:0')), V], dim=0)
            #b.update_all(dgl.function.copy_u('v', 'm'), dgl.function.sum('m', 'h'))
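For context: each `update_all` call is a full edge-wise scatter-sum over the block, so the `v0`/`v1` passes cost two extra aggregations per attention layer purely to measure the local/remote contribution split; only the real aggregation over `v` survives. A toy example of that remaining pass:

```python
import dgl
import dgl.function as fn
import torch

g = dgl.graph(([0, 1, 2], [3, 3, 3]))   # three edges into node 3
g.edata['v'] = torch.randn(3, 4)
# One scatter-sum over all edges -- what the kept update_all does:
g.update_all(fn.copy_e('v', 'm'), fn.sum('m', 'h'))
print(g.ndata['h'][3])                   # sum of the three edge vectors
```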
......
@@ -421,7 +421,7 @@ class AsyncMemeoryUpdater(torch.nn.Module):
        elif self.mode == 'local' or self.mode=='all_local':
            self.update_hunk = self.local_func
        if self.mode == 'historical':
-           self.gamma = torch.nn.Parameter(torch.tensor([0.9]),
+           self.gamma = torch.nn.Parameter(torch.tensor([0.5]),
                                            requires_grad=True)
            self.filter = Filter(n_nodes=mailbox.shared_nodes_index.shape[0],
                                 memory_dimension=self.dim_hid,
@@ -430,7 +430,7 @@ class AsyncMemeoryUpdater(torch.nn.Module):
            self.gamma = 1
    def forward(self, mfg, param = None):
        for b in mfg:
-           mail_input = b.srcdata['mem_input']
+           #print(b.srcdata['ID'].shape[0])
            updated_memory0 = self.updater(b)
            mask = DistIndex(b.srcdata['ID']).is_shared
            #incr = updated_memory[mask] - b.srcdata['mem'][mask]
@@ -441,13 +441,14 @@ class AsyncMemeoryUpdater(torch.nn.Module):
            if self.mode == 'historical':
                shared_ind = self.mailbox.is_shared_mask[DistIndex(b.srcdata['ID'][mask]).loc]
-               transition_dense = self.filter.get_incretment(shared_ind)
-               transition_dense*=2
+               transition_dense = b.srcdata['his_mem'][mask] + self.filter.get_incretment(shared_ind)
                if not (transition_dense.max().item() == 0):
                    transition_dense -= transition_dense.min()
-                   transition_dense /=torch.clamp(transition_dense.max() ,1)
+                   transition_dense /= transition_dense.max()
                    transition_dense = 2*transition_dense - 1
-                   upd0[mask] = b.srcdata['his_mem'][mask] + transition_dense
+                   upd0[mask] = transition_dense#b.srcdata['his_mem'][mask] + transition_dense
+                   #print(self.gamma)
+                   #print('tran {} {} {}\n'.format(transition_dense.max().item(),upd0[mask].max().item(),b.srcdata['his_mem'][mask].max().item()))
                else:
                    upd0[mask] = b.srcdata['his_mem'][mask]
                #upd0[mask] = self.ceil_updater(his_mem, b.srcdata['his_mem'][mask])
@@ -457,16 +458,7 @@ class AsyncMemeoryUpdater(torch.nn.Module):
            with torch.no_grad():
                if self.mode == 'historical':
                    change = updated_memory[mask] - b.srcdata['his_mem'][mask]
-                   change.detach()
-                   if not (change.max().item() == 0):
-                       change -= change.min()
-                       change /=change.max()
-                       change = 2*change - 1
                    self.filter.update(shared_ind,change)
-               #print(transition_dense)
-               #print(torch.cosine_similarity(updated_memory[mask],b.srcdata['his_mem'][mask]).sum()/torch.sum(mask))
-               #print(self.gamma)
-               self.pre_mem = b.srcdata['his_mem']
                self.last_updated_ts = b.srcdata['ts'].detach().clone()
                self.last_updated_memory = updated_memory.detach().clone()
                self.last_updated_nid = b.srcdata['ID'].detach().clone()
@@ -504,4 +496,6 @@ class AsyncMemeoryUpdater(torch.nn.Module):
            b.srcdata['h'] = updated_memory
    def empty_cache(self):
-       pass
+       if self.mode == 'historical':
+           print('clear\n')
+           self.filter.clear()
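Two things happen across these hunks: the filter now accumulates the raw memory delta (the old code min-max-rescaled `change` before storing it, destroying its scale), and `empty_cache()` actually clears the filter instead of `pass`, so stale per-node sums no longer persist; together with `count.zero_()` in `Filter.clear()` this appears to be the memory fix the commit title refers to. The rescale that remains in the forward path maps the predicted transition into [-1, 1]; the old `torch.clamp(max, 1)` divisor silently skipped normalization whenever the max was below 1. A standalone sketch of that rescale:

```python
import torch

def rescale_signed(x: torch.Tensor) -> torch.Tensor:
    # Min-max rescale into [-1, 1]; mirrors the historical-update branch.
    if x.max().item() == 0:
        return x          # degenerate case; the caller falls back to his_mem
    x = x - x.min()
    x = x / x.max()       # old code divided by clamp(max, 1), which
                          # skipped scaling whenever max < 1
    return 2 * x - 1

t = torch.tensor([0.2, 0.5, 0.8])
print(rescale_signed(t))  # tensor([-1., 0., 1.])
```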
@@ -240,8 +240,8 @@ def to_block(graph,data, sample_out,device = torch.device('cuda'),unique = True):
    #print(CountComm.origin_local,CountComm.origin_remote)
    #for p in range(dist.get_world_size()):
    #    print((DistIndex(dist_nid).part == p).sum().item())
-   CountComm.origin_local = (DistIndex(dist_nid).part == dist.get_rank()).sum().item()
-   CountComm.origin_remote =(DistIndex(dist_nid).part != dist.get_rank()).sum().item()
+   #CountComm.origin_local = (DistIndex(dist_nid).part == dist.get_rank()).sum().item()
+   #CountComm.origin_remote =(DistIndex(dist_nid).part != dist.get_rank()).sum().item()
    dist_nid,nid_inv = dist_nid.unique(return_inverse=True)
    #print('nid_tensor {} \n nid {}\n'.format(nid_tensor,dist_nid))
@@ -250,7 +250,7 @@ def to_block(graph,data, sample_out,device = torch.device('cuda'),unique = True):
    """
    if unique:
        block_node_list,unq_id = torch.stack((nid_inv.to(torch.float64),src_ts.to(torch.float64))).unique(dim = 1,return_inverse=True)
-       #print(block_node_list.shape,unq_id.shape)
        first_index,_ = torch_scatter.scatter_min(torch.arange(unq_id.shape[0],device=unq_id.device,dtype=unq_id.dtype),unq_id)
        first_mask = torch.zeros(unq_id.shape[0],device = unq_id.device,dtype=torch.bool)
        first_mask[first_index] = True
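The de-dup this hunk touches works by stacking (node, timestamp) as a 2×E matrix, taking `unique` over columns, then using `torch_scatter.scatter_min` to locate each pair's first occurrence. A minimal sketch of the same idea in plain PyTorch, with `scatter_reduce_(..., reduce='amin')` standing in for `torch_scatter.scatter_min`:

```python
import torch

nid = torch.tensor([3, 3, 7, 3])
ts  = torch.tensor([1., 1., 2., 5.])
# Stack the two keys as columns and de-duplicate (node, ts) pairs.
keys = torch.stack((nid.to(torch.float64), ts.to(torch.float64)))
uniq, inv = keys.unique(dim=1, return_inverse=True)

# First occurrence of each unique pair.
order = torch.arange(inv.shape[0])
first = torch.full((uniq.shape[1],), inv.shape[0], dtype=torch.long)
first.scatter_reduce_(0, inv, order, reduce='amin')
print(uniq)   # unique (nid, ts) pairs, one per column
print(first)  # index where each pair first appears
```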
@@ -423,9 +423,9 @@ def graph_sample(graph,sampler,sample_fn,data,neg_sampling = None,out_device = t
    param = {'is_unique':False,'nid_mapper':nid_mapper,'eid_mapper':eid_mapper,'out_device':out_device}
    out = sample_fn(sampler,data,neg_sampling,**param)
    if reversed is False:
-       out,dist_nid,dist_eid = to_block(graph,data,out,out_device,reversed)
+       out,dist_nid,dist_eid = to_block(graph,data,out,out_device)
    else:
-       out,dist_nid,dist_eid = to_reversed_block(graph,data,out,out_device,reversed)
+       out,dist_nid,dist_eid = to_reversed_block(graph,data,out,out_device,identity=(sampler.policy=='identity'))
    t_e = time.time()
    #print(t_e-t_s)
    return out,dist_nid,dist_eid
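Judging from the `to_block` signature in the previous hunk (`to_block(graph, data, sample_out, device=..., unique=True)`), the old call passed `reversed` (which is `False` on this branch) positionally into the `unique` slot, silently disabling deduplication; dropping the extra argument restores `unique=True`. Passing such flags by keyword, as the new `to_reversed_block(..., identity=...)` call does, avoids this class of bug.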
......
@@ -257,19 +257,19 @@ class DistributedDataLoader:
                pass
            batch_data,dist_nid,dist_eid = self.result_queue[0].result()
            b = batch_data[1][0][0]
-           self.remote_node += (DistIndex(dist_nid).part != dist.get_rank()).sum()
-           self.local_node += (DistIndex(dist_nid).part == dist.get_rank()).sum()
-           self.remote_edge += (DistIndex(dist_eid).part != dist.get_rank()).sum()
-           self.local_edge += (DistIndex(dist_eid).part == dist.get_rank()).sum()
+           #self.remote_node += (DistIndex(dist_nid).part != dist.get_rank()).sum()
+           #self.local_node += (DistIndex(dist_nid).part == dist.get_rank()).sum()
+           #self.remote_edge += (DistIndex(dist_eid).part != dist.get_rank()).sum()
+           #self.local_edge += (DistIndex(dist_eid).part == dist.get_rank()).sum()
            #self.remote_root += (DistIndex(dist_nid[b.srcdata['__ID'][:self.batch_size*2]]).part != dist.get_rank()).sum()
            #self.local_root += (DistIndex(dist_nid[b.srcdata['__ID'][:self.batch_size*2]]).part == dist.get_rank()).sum()
            #torch.cuda.synchronize(stream)
-           start = torch.cuda.Event(enable_timing=True)
-           end = torch.cuda.Event(enable_timing=True)
-           start.record()
+           #start = torch.cuda.Event(enable_timing=True)
+           #end = torch.cuda.Event(enable_timing=True)
+           #start.record()
            stream.synchronize()
-           end.record()
-           end.synchronize()
+           #end.record()
+           #end.synchronize()
            #print(start.elapsed_time(end))
            self.result_queue.popleft()
            #start = torch.cuda.Event(enable_timing=True)
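The timing being commented out here follows the standard CUDA event pattern; for reference, a self-contained sketch of it (worth keeping out of the hot loop, since the `synchronize` it requires is itself a stall):

```python
import torch

if torch.cuda.is_available():
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    x = torch.randn(2048, 2048, device='cuda')
    start.record()            # mark the start on the current stream
    y = x @ x
    end.record()              # mark the end
    end.synchronize()         # wait until 'end' has actually been reached
    print(start.elapsed_time(end), 'ms')  # GPU-side elapsed time
```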
......
@@ -13,10 +13,10 @@ class MemoryMoniter:
        else:
            return torch.sum((x-y)**2,dim=1)
    def add(self,nid,pre_memory,now_memory):
-       self.memorychange.append(self.ssim(pre_memory,now_memory,method = 'F'))
-       self.memory_ssim.append(self.ssim(pre_memory,now_memory,method = 'cos'))
-       self.nid_list.append(nid)
+       pass
+       #self.memorychange.append(self.ssim(pre_memory,now_memory,method = 'F'))
+       #self.memory_ssim.append(self.ssim(pre_memory,now_memory,method = 'cos'))
+       #self.nid_list.append(nid)
    def draw(self,degree,data,model,e):
        torch.save(self.nid_list,'all/{}/{}/memorynid_{}.pt'.format(data,model,e))
        torch.save(self.memorychange,'all/{}/{}/memoryF_{}.pt'.format(data,model,e))
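Stubbing out `add` looks like another piece of the memory fix: every call appended per-batch GPU tensors (the `ssim` results and `nid`) to Python lists, keeping them alive for the whole run. The access/comm lists disabled in the training loop above followed the same pattern.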
......