Commit 8adc9ed9 by zlj

fix smooth aggregation

parent 27dbfee5
sampling: sampling:
- layer: 1 - layer: 1
neighbor: neighbor:
- 20 - 10
strategy: 'recent' strategy: 'recent'
prop_time: False prop_time: False
history: 1 history: 1
...@@ -27,10 +27,10 @@ gnn: ...@@ -27,10 +27,10 @@ gnn:
dim_time: 100 dim_time: 100
dim_out: 100 dim_out: 100
train: train:
- epoch: 200 - epoch: 100
batch_size: 1000 batch_size: 1000
# reorder: 16 # reorder: 16
lr: 0.0004 lr: 0.0008
dropout: 0.2 dropout: 0.2
att_dropout: 0.2 att_dropout: 0.2
all_on_gpu: True all_on_gpu: True
...@@ -30,7 +30,7 @@ train: ...@@ -30,7 +30,7 @@ train:
- epoch: 50 - epoch: 50
batch_size: 3000 batch_size: 3000
# reorder: 16 # reorder: 16
lr: 0.0005 lr: 0.0008
dropout: 0.2 dropout: 0.2
att_dropout: 0.2 att_dropout: 0.2
all_on_gpu: True all_on_gpu: True
...@@ -300,13 +300,16 @@ void ParallelSampler :: neighbor_sample_from_nodes_with_before_layer( ...@@ -300,13 +300,16 @@ void ParallelSampler :: neighbor_sample_from_nodes_with_before_layer(
cal_cnt = 0; cal_cnt = 0;
for(int cid = end_index-1;cid>=0;cid--){ for(int cid = end_index-1;cid>=0;cid--){
cal_cnt++; cal_cnt++;
if(cal_cnt > fanout)break; //if(cal_cnt > fanout)break;
int eid = tnb.eid[node][cid]; int eid = tnb.eid[node][cid];
if(part[tnb.eid[node][cid]] != local_part|| node_part[tnb.neighbors[node][cid]]!= local_part){ if((part[tnb.eid[node][cid]] != local_part|| node_part[tnb.neighbors[node][cid]]!= local_part)){
if(cal_cnt<=fanout){
double p0 = (double)rand_r(&loc_seeds[tid]) / (RAND_MAX + 1.0); double p0 = (double)rand_r(&loc_seeds[tid]) / (RAND_MAX + 1.0);
double ep = boundery_probility*pr[cal_cnt-1]/sum_p*sum_1; double ep = boundery_probility*pr[cal_cnt-1]/sum_p*sum_1;
if(p0 > ep)continue; if(p0 > ep)continue;
}
else continue;
//cout<<"in"<<endl; //cout<<"in"<<endl;
} }
tgb_i[tid].src_index.emplace_back(i); tgb_i[tid].src_index.emplace_back(i);
......
...@@ -31,7 +31,7 @@ class Filter(nn.Module): ...@@ -31,7 +31,7 @@ class Filter(nn.Module):
def get_incretment(self, node_idxs): def get_incretment(self, node_idxs):
#print(self.incretment[node_idxs,:].shape,self.count[node_idxs].shape) #print(self.incretment[node_idxs,:].shape,self.count[node_idxs].shape)
return self.incretment[node_idxs,:]/torch.clamp(self.count[node_idxs,:],1) return self.incretment[node_idxs,:]#/torch.clamp(self.count[node_idxs,:],1)
def get_incretment_remote(self, idx): def get_incretment_remote(self, idx):
remote_tensor = DistributedTensor(self.incretment) remote_tensor = DistributedTensor(self.incretment)
......
...@@ -139,8 +139,8 @@ class HistoricalCache: ...@@ -139,8 +139,8 @@ class HistoricalCache:
if(shared_data.shape[0] == 0): if(shared_data.shape[0] == 0):
return None return None
len = self.local_historical_data.shape[1] len = self.local_historical_data.shape[1]
mail_ts = shared_data[:,-1] #mail_ts = shared_data[:,-1]
mail_data = shared_data[:,len+1:-1] #mail_data = shared_data[:,len+1:-1]
shared_ts = shared_data[:,len] shared_ts = shared_data[:,len]
shared_mem = shared_data[:,:len] shared_mem = shared_data[:,:len]
#print(shared_index) #print(shared_index)
...@@ -150,8 +150,8 @@ class HistoricalCache: ...@@ -150,8 +150,8 @@ class HistoricalCache:
#shared_data = torch_scatter.scatter_mean(shared_data,inv,0) #shared_data = torch_scatter.scatter_mean(shared_data,inv,0)
shared_mem = shared_mem[idx] shared_mem = shared_mem[idx]
shared_ts = shared_ts[idx] shared_ts = shared_ts[idx]
mail_data = mail_data[idx] #mail_data = mail_data[idx]
mail_ts = mail_ts[idx] #mail_ts = mail_ts[idx]
shared_index = unq_index shared_index = unq_index
#print('{} {} {}\n'.format(shared_index,shared_data,shared_ts)) #print('{} {} {}\n'.format(shared_index,shared_data,shared_ts))
# if filter is not None: # if filter is not None:
...@@ -164,7 +164,7 @@ class HistoricalCache: ...@@ -164,7 +164,7 @@ class HistoricalCache:
self.local_historical_data[shared_index] = shared_mem self.local_historical_data[shared_index] = shared_mem
self.local_ts[shared_index] = shared_ts self.local_ts[shared_index] = shared_ts
self.last_shared_update_wait = None self.last_shared_update_wait = None
return shared_index,shared_mem,shared_ts,mail_data,mail_ts return shared_index,shared_mem,shared_ts#,mail_data,mail_ts
......
...@@ -288,8 +288,8 @@ class TransfomerAttentionLayer(torch.nn.Module): ...@@ -288,8 +288,8 @@ class TransfomerAttentionLayer(torch.nn.Module):
#att = dgl.ops.e_div_v(b,att_e_sub_max,torch.clamp_min(dgl.ops.copy_e_sum(b,att_e_sub_max),1)) #att = dgl.ops.e_div_v(b,att_e_sub_max,torch.clamp_min(dgl.ops.copy_e_sum(b,att_e_sub_max),1))
att = dgl.ops.edge_softmax(b, self.att_act(torch.sum(Q*K, dim=2))) att = dgl.ops.edge_softmax(b, self.att_act(torch.sum(Q*K, dim=2)))
att = self.att_dropout(att) att = self.att_dropout(att)
tt.weight_count_remote+=torch.sum(att[DistIndex(b.srcdata['ID']).part[b.edges()[0]]!=torch.distributed.get_rank()]**2) #tt.weight_count_remote+=torch.sum(att[DistIndex(b.srcdata['ID']).part[b.edges()[0]]!=torch.distributed.get_rank()]**2)
tt.weight_count_local+=torch.sum(att[DistIndex(b.srcdata['ID']).part[b.edges()[0]]==torch.distributed.get_rank()]**2) #tt.weight_count_local+=torch.sum(att[DistIndex(b.srcdata['ID']).part[b.edges()[0]]==torch.distributed.get_rank()]**2)
V = torch.reshape(V*att[:, :, None], (V.shape[0], -1)) V = torch.reshape(V*att[:, :, None], (V.shape[0], -1))
V_local = V.clone() V_local = V.clone()
V_remote = V.clone() V_remote = V.clone()
......
...@@ -445,7 +445,7 @@ class AsyncMemeoryUpdater(torch.nn.Module): ...@@ -445,7 +445,7 @@ class AsyncMemeoryUpdater(torch.nn.Module):
transition_dense*=2 transition_dense*=2
if not (transition_dense.max().item() == 0): if not (transition_dense.max().item() == 0):
transition_dense -= transition_dense.min() transition_dense -= transition_dense.min()
transition_dense /=transition_dense.max() transition_dense /=torch.clamp(transition_dense.max() ,1)
transition_dense = 2*transition_dense - 1 transition_dense = 2*transition_dense - 1
upd0[mask] = b.srcdata['his_mem'][mask] + transition_dense upd0[mask] = b.srcdata['his_mem'][mask] + transition_dense
else: else:
...@@ -456,7 +456,7 @@ class AsyncMemeoryUpdater(torch.nn.Module): ...@@ -456,7 +456,7 @@ class AsyncMemeoryUpdater(torch.nn.Module):
updated_memory = torch.where(mask.unsqueeze(1),self.gamma*updated_memory0 + (1-self.gamma)*(upd0),updated_memory0) updated_memory = torch.where(mask.unsqueeze(1),self.gamma*updated_memory0 + (1-self.gamma)*(upd0),updated_memory0)
with torch.no_grad(): with torch.no_grad():
if self.mode == 'historical': if self.mode == 'historical':
change = upd0[mask] - b.srcdata['his_mem'][mask] change = updated_memory[mask] - b.srcdata['his_mem'][mask]
change.detach() change.detach()
if not (change.max().item() == 0): if not (change.max().item() == 0):
change -= change.min() change -= change.min()
......
...@@ -42,14 +42,15 @@ class time_count: ...@@ -42,14 +42,15 @@ class time_count:
return time.perf_counter(),0 return time.perf_counter(),0
@staticmethod @staticmethod
def elapsed_event(start_event): def elapsed_event(start_event):
if isinstance(start_event,tuple): #if isinstance(start_event,tuple):
start_event,end_event = start_event # start_event,end_event = start_event
end_event.record() # end_event.record()
end_event.synchronize() # end_event.synchronize()
return start_event.elapsed_time(end_event) # return start_event.elapsed_time(end_event)
else: #else:
torch.cuda.synchronize() # torch.cuda.synchronize()
return time.perf_counter() - start_event # return time.perf_counter() - start_event
return 0
@staticmethod @staticmethod
def print(): def print():
print('time_count.time_forward={} time_count.time_backward={} time_count.time_memory_updater={} time_count.time_embedding={} time_count.time_local_update={} time_count.time_memory_sync={} time_count.time_sample_and_build={} time_count.time_memory_fetch={}\n'.format( print('time_count.time_forward={} time_count.time_backward={} time_count.time_memory_updater={} time_count.time_embedding={} time_count.time_local_update={} time_count.time_memory_sync={} time_count.time_sample_and_build={} time_count.time_memory_fetch={}\n'.format(
......
...@@ -253,7 +253,7 @@ class SharedMailBox(): ...@@ -253,7 +253,7 @@ class SharedMailBox():
def sychronize_shared(self): def sychronize_shared(self):
out=self.historical_cache.synchronize_shared_update() out=self.historical_cache.synchronize_shared_update()
if out is not None: if out is not None:
shared_index,shared_data,shared_ts,mail,mail_ts = out shared_index,shared_data,shared_ts = out
index = self.shared_nodes_index[shared_index] index = self.shared_nodes_index[shared_index]
mask= (shared_ts > self.node_memory_ts.accessor.data[index]) mask= (shared_ts > self.node_memory_ts.accessor.data[index])
self.node_memory.accessor.data[index][mask] = shared_data[mask] self.node_memory.accessor.data[index][mask] = shared_data[mask]
...@@ -320,8 +320,10 @@ class SharedMailBox(): ...@@ -320,8 +320,10 @@ class SharedMailBox():
#mem = self.pack(memory=shared_memory,memory_ts=shared_memory_ts,mail=shared_mail,mail_ts=shared_mail_ts,index=shared_memory_ind,mode=mode) #mem = self.pack(memory=shared_memory,memory_ts=shared_memory_ts,mail=shared_mail,mail_ts=shared_mail_ts,index=shared_memory_ind,mode=mode)
#mem = self.pack(memory=shared_mail,memory_ts=shared_mail_ts,index=shared_memory_ind,mode=mode) #mem = self.pack(memory=shared_mail,memory_ts=shared_mail_ts,index=shared_memory_ind,mode=mode)
self.tot_shared_count += shared_memory_ind.shape[0] mem = self.pack(memory=shared_memory,memory_ts=shared_memory_ts,index=shared_memory_ind,mode=mode)
else:
mem = self.pack(memory=shared_memory,memory_ts=shared_memory_ts,mail=shared_mail,mail_ts = shared_mail_ts,index=shared_memory_ind,mode=mode) mem = self.pack(memory=shared_memory,memory_ts=shared_memory_ts,mail=shared_mail,mail_ts = shared_mail_ts,index=shared_memory_ind,mode=mode)
self.tot_shared_count += shared_memory_ind.shape[0]
broadcast_len = torch.empty([1],device = mem.device,dtype = torch.int) broadcast_len = torch.empty([1],device = mem.device,dtype = torch.int)
broadcast_len[0] = shared_memory_ind.shape[0] broadcast_len[0] = shared_memory_ind.shape[0]
shared_len = [torch.empty([1],device = mem.device,dtype = torch.int) for _ in range(ctx.memory_group_size)] shared_len = [torch.empty([1],device = mem.device,dtype = torch.int) for _ in range(ctx.memory_group_size)]
......
...@@ -286,10 +286,10 @@ def load_from_speed(data,seed,top,sampler_graph_add_rev,device=torch.device('cud ...@@ -286,10 +286,10 @@ def load_from_speed(data,seed,top,sampler_graph_add_rev,device=torch.device('cud
return load_from_shared_node_partition(data,None,None,sample_add_rev=sampler_graph_add_rev,device=device,feature_device=feature_device) return load_from_shared_node_partition(data,None,None,sample_add_rev=sampler_graph_add_rev,device=device,feature_device=feature_device)
else: else:
if partition == 'ours': if partition == 'ours':
fnode_i = '../../SPEED/partition/divided_nodes_seed_t2/{}/{}/{}_{}parts_top{}/output{}.txt'.format(data,seed,data,ctx.memory_group_size,top,ctx.memory_group_rank) fnode_i = '../../SPEED/partition/divided_nodes_seed_starrygl/{}/{}/{}_{}parts_top{}/output{}.txt'.format(data,seed,data,ctx.memory_group_size,top,ctx.memory_group_rank)
fnode_share = '../../SPEED/partition/divided_nodes_seed_t2/{}/{}/{}_{}parts_top{}/outputshared.txt'.format(data,seed,data,ctx.memory_group_size,top) fnode_share = '../../SPEED/partition/divided_nodes_seed_starrygl/{}/{}/{}_{}parts_top{}/outputshared.txt'.format(data,seed,data,ctx.memory_group_size,top)
reorder = '../../SPEED/partition/divided_nodes_seed_t2/{}/reorder.txt'.format(data) reorder = '../../SPEED/partition/divided_nodes_seed_starrygl/{}/reorder.txt'.format(data)
edge_i = '../../SPEED/partition/divided_nodes_seed_t2/{}/{}/{}_{}parts_top{}/edge_output{}.txt'.format(data,seed,data,ctx.memory_group_size,top,ctx.memory_group_rank) edge_i = '../../SPEED/partition/divided_nodes_seed_starrygl/{}/{}/{}_{}parts_top{}/edge_output{}.txt'.format(data,seed,data,ctx.memory_group_size,top,ctx.memory_group_rank)
elif partition == 'metis': elif partition == 'metis':
fnode_i = '../../SPEED/partition/divided_nodes_metis_test/{}/{}/{}_{}parts_top{}/output{}.txt'.format(data,seed,data,ctx.memory_group_size,top,ctx.memory_group_rank) fnode_i = '../../SPEED/partition/divided_nodes_metis_test/{}/{}/{}_{}parts_top{}/output{}.txt'.format(data,seed,data,ctx.memory_group_size,top,ctx.memory_group_rank)
fnode_share = '../../SPEED/partition/divided_nodes_metis_test/{}/{}/{}_{}parts_top{}/outputshared.txt'.format(data,seed,data,ctx.memory_group_size,top) fnode_share = '../../SPEED/partition/divided_nodes_metis_test/{}/{}/{}_{}parts_top{}/outputshared.txt'.format(data,seed,data,ctx.memory_group_size,top)
......
...@@ -63,8 +63,8 @@ class LocalNegativeSampling(NegativeSampling): ...@@ -63,8 +63,8 @@ class LocalNegativeSampling(NegativeSampling):
p = torch.rand(size=(num_samples,)) p = torch.rand(size=(num_samples,))
sr = self.dst_node_list[torch.randint(len(self.dst_node_list), (num_samples, ),generator=self.rdm)] sr = self.dst_node_list[torch.randint(len(self.dst_node_list), (num_samples, ),generator=self.rdm)]
sl = self.local_dst[torch.randint(len(self.local_dst), (num_samples, ),generator=self.rdm)] sl = self.local_dst[torch.randint(len(self.local_dst), (num_samples, ),generator=self.rdm)]
s=torch.where(p<=self.prob,sl,sr) s=torch.where(p<=self.prob,sr,sl)
return sr return s
else: else:
s = torch.randint(len(self.dst_node_list), (num_samples, ),generator=self.rdm) s = torch.randint(len(self.dst_node_list), (num_samples, ),generator=self.rdm)
return self.dst_node_list[s] return self.dst_node_list[s]
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment