Commit 8adc9ed9 by zlj

fix smooth aggregation

parent 27dbfee5
sampling:
  - layer: 1
    neighbor:
      - 20
      - 10
    strategy: 'recent'
    prop_time: False
    history: 1
@@ -27,10 +27,10 @@ gnn:
     dim_time: 100
     dim_out: 100
 train:
-  - epoch: 200
+  - epoch: 100
     batch_size: 1000
     # reorder: 16
-    lr: 0.0004
+    lr: 0.0008
     dropout: 0.2
     att_dropout: 0.2
     all_on_gpu: True
@@ -30,7 +30,7 @@ train:
   - epoch: 50
     batch_size: 3000
     # reorder: 16
-    lr: 0.0005
+    lr: 0.0008
     dropout: 0.2
     att_dropout: 0.2
     all_on_gpu: True
@@ -300,13 +300,16 @@ void ParallelSampler :: neighbor_sample_from_nodes_with_before_layer(
             cal_cnt = 0;
             for(int cid = end_index-1; cid >= 0; cid--){
                 cal_cnt++;
-                if(cal_cnt > fanout) break;
+                //if(cal_cnt > fanout) break;
                 int eid = tnb.eid[node][cid];
-                if(part[tnb.eid[node][cid]] != local_part || node_part[tnb.neighbors[node][cid]] != local_part){
-                    double p0 = (double)rand_r(&loc_seeds[tid]) / (RAND_MAX + 1.0);
-                    double ep = boundery_probility*pr[cal_cnt-1]/sum_p*sum_1;
-                    if(p0 > ep) continue;
+                if((part[tnb.eid[node][cid]] != local_part || node_part[tnb.neighbors[node][cid]] != local_part)){
+                    if(cal_cnt <= fanout){
+                        double p0 = (double)rand_r(&loc_seeds[tid]) / (RAND_MAX + 1.0);
+                        double ep = boundery_probility*pr[cal_cnt-1]/sum_p*sum_1;
+                        if(p0 > ep) continue;
+                    }
+                    else continue;
                     //cout<<"in"<<endl;
                 }
                 tgb_i[tid].src_index.emplace_back(i);
......
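Note on the sampler hunk above: with the fanout break commented out, local in-partition neighbors are no longer truncated at the fanout, while cross-partition neighbors are kept only inside the fanout window, and then only with probability boundery_probility*pr[cal_cnt-1]/sum_p*sum_1. A minimal Python sketch of that acceptance rule (the helper and its arguments are hypothetical stand-ins for the C++ variables):

    import random

    def keep_neighbor(cal_cnt, is_remote, fanout, boundary_prob, pr, sum_p, sum_1):
        if not is_remote:
            return True   # local neighbors are always kept now
        if cal_cnt > fanout:
            return False  # remote neighbors past the fanout window are dropped
        # remote neighbors inside the window survive a Bernoulli test
        ep = boundary_prob * pr[cal_cnt - 1] / sum_p * sum_1
        return random.random() <= ep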
@@ -31,7 +31,7 @@ class Filter(nn.Module):
     def get_incretment(self, node_idxs):
         #print(self.incretment[node_idxs,:].shape,self.count[node_idxs].shape)
-        return self.incretment[node_idxs,:]/torch.clamp(self.count[node_idxs,:],1)
+        return self.incretment[node_idxs,:]#/torch.clamp(self.count[node_idxs,:],1)
     def get_incretment_remote(self, idx):
         remote_tensor = DistributedTensor(self.incretment)
......
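Note on the Filter hunk: get_incretment now returns the raw accumulated increment instead of the count-normalized mean. A small sketch of the two behaviours (shapes are assumed for illustration):

    import torch

    incretment = torch.randn(4, 8)                 # per-node accumulated updates
    count = torch.tensor([[0.], [1.], [2.], [4.]]) # updates seen per node

    mean_update = incretment / torch.clamp(count, 1)  # old: average per update;
                                                      # clamp avoids divide-by-zero
    sum_update = incretment                           # new: raw accumulated sum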
@@ -139,8 +139,8 @@ class HistoricalCache:
         if(shared_data.shape[0] == 0):
             return None
         len = self.local_historical_data.shape[1]
-        mail_ts = shared_data[:,-1]
-        mail_data = shared_data[:,len+1:-1]
+        #mail_ts = shared_data[:,-1]
+        #mail_data = shared_data[:,len+1:-1]
         shared_ts = shared_data[:,len]
         shared_mem = shared_data[:,:len]
         #print(shared_index)
@@ -150,8 +150,8 @@ class HistoricalCache:
         #shared_data = torch_scatter.scatter_mean(shared_data,inv,0)
         shared_mem = shared_mem[idx]
         shared_ts = shared_ts[idx]
-        mail_data = mail_data[idx]
-        mail_ts = mail_ts[idx]
+        #mail_data = mail_data[idx]
+        #mail_ts = mail_ts[idx]
         shared_index = unq_index
         #print('{} {} {}\n'.format(shared_index,shared_data,shared_ts))
         # if filter is not None:
# if filter is not None:
@@ -164,7 +164,7 @@ class HistoricalCache:
         self.local_historical_data[shared_index] = shared_mem
         self.local_ts[shared_index] = shared_ts
         self.last_shared_update_wait = None
-        return shared_index,shared_mem,shared_ts,mail_data,mail_ts
+        return shared_index,shared_mem,shared_ts#,mail_data,mail_ts
......
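Taken together, the HistoricalCache hunks stop unpacking and returning the mail columns of the shared buffer; synchronize_shared_update now yields only (shared_index, shared_mem, shared_ts). The row layout implied by the slicing code (len being the memory dimension) is:

    # shared_data row layout, as implied by the indexing above:
    #   [0:len)    -> shared memory vector     (still used)
    #   [len]      -> shared memory timestamp  (still used)
    #   [len+1:-1) -> mail payload             (now ignored here)
    #   [-1]       -> mail timestamp           (now ignored here)
    shared_mem = shared_data[:, :len]
    shared_ts = shared_data[:, len]

The consumer side of this change is the SharedMailBox.sychronize_shared hunk further down.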
@@ -288,8 +288,8 @@ class TransfomerAttentionLayer(torch.nn.Module):
             #att = dgl.ops.e_div_v(b,att_e_sub_max,torch.clamp_min(dgl.ops.copy_e_sum(b,att_e_sub_max),1))
             att = dgl.ops.edge_softmax(b, self.att_act(torch.sum(Q*K, dim=2)))
             att = self.att_dropout(att)
-            tt.weight_count_remote+=torch.sum(att[DistIndex(b.srcdata['ID']).part[b.edges()[0]]!=torch.distributed.get_rank()]**2)
-            tt.weight_count_local+=torch.sum(att[DistIndex(b.srcdata['ID']).part[b.edges()[0]]==torch.distributed.get_rank()]**2)
+            #tt.weight_count_remote+=torch.sum(att[DistIndex(b.srcdata['ID']).part[b.edges()[0]]!=torch.distributed.get_rank()]**2)
+            #tt.weight_count_local+=torch.sum(att[DistIndex(b.srcdata['ID']).part[b.edges()[0]]==torch.distributed.get_rank()]**2)
             V = torch.reshape(V*att[:, :, None], (V.shape[0], -1))
             V_local = V.clone()
             V_remote = V.clone()
......
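For context, dgl.ops.edge_softmax (used above on block b) normalizes one attention logit per edge over the incoming edges of each destination node. A self-contained toy example (the graph here is an assumption, not the model's sampled block):

    import dgl
    import torch

    g = dgl.graph(([0, 1, 2], [2, 2, 2]))  # three edges, all into node 2
    logits = torch.randn(g.num_edges())
    att = dgl.ops.edge_softmax(g, logits)   # att sums to 1 over node 2's in-edges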
@@ -445,7 +445,7 @@ class AsyncMemeoryUpdater(torch.nn.Module):
                 transition_dense *= 2
                 if not (transition_dense.max().item() == 0):
                     transition_dense -= transition_dense.min()
-                    transition_dense /= transition_dense.max()
+                    transition_dense /= torch.clamp(transition_dense.max(), 1)
                     transition_dense = 2*transition_dense - 1
                     upd0[mask] = b.srcdata['his_mem'][mask] + transition_dense
             else:
@@ -456,7 +456,7 @@ class AsyncMemeoryUpdater(torch.nn.Module):
             updated_memory = torch.where(mask.unsqueeze(1), self.gamma*updated_memory0 + (1-self.gamma)*upd0, updated_memory0)
             with torch.no_grad():
                 if self.mode == 'historical':
-                    change = upd0[mask] - b.srcdata['his_mem'][mask]
+                    change = updated_memory[mask] - b.srcdata['his_mem'][mask]
                     change.detach()
                     if not (change.max().item() == 0):
                         change -= change.min()
......
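Both AsyncMemeoryUpdater hunks touch the same min-max rescaling of the memory delta; the clamp makes the division safe when the shifted tensor is all zeros. A minimal sketch of the hardened normalization (the function name is hypothetical):

    import torch

    def rescale_signed(x: torch.Tensor) -> torch.Tensor:
        x = x - x.min()                  # shift into [0, max - min]
        x = x / torch.clamp(x.max(), 1)  # clamp: no divide-by-zero on constant input
        return 2 * x - 1                 # map into [-1, 1]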
@@ -42,14 +42,15 @@ class time_count:
         return time.perf_counter(), 0
     @staticmethod
     def elapsed_event(start_event):
-        if isinstance(start_event, tuple):
-            start_event, end_event = start_event
-            end_event.record()
-            end_event.synchronize()
-            return start_event.elapsed_time(end_event)
-        else:
-            torch.cuda.synchronize()
-            return time.perf_counter() - start_event
+        #if isinstance(start_event, tuple):
+        #    start_event, end_event = start_event
+        #    end_event.record()
+        #    end_event.synchronize()
+        #    return start_event.elapsed_time(end_event)
+        #else:
+        #    torch.cuda.synchronize()
+        #    return time.perf_counter() - start_event
+        return 0
     @staticmethod
     def print():
         print('time_count.time_forward={} time_count.time_backward={} time_count.time_memory_updater={} time_count.time_embedding={} time_count.time_local_update={} time_count.time_memory_sync={} time_count.time_sample_and_build={} time_count.time_memory_fetch={}\n'.format(
......
@@ -253,7 +253,7 @@ class SharedMailBox():
     def sychronize_shared(self):
         out = self.historical_cache.synchronize_shared_update()
         if out is not None:
-            shared_index,shared_data,shared_ts,mail,mail_ts = out
+            shared_index,shared_data,shared_ts = out
             index = self.shared_nodes_index[shared_index]
             mask = (shared_ts > self.node_memory_ts.accessor.data[index])
             self.node_memory.accessor.data[index][mask] = shared_data[mask]
@@ -320,8 +320,10 @@ class SharedMailBox():
             #mem = self.pack(memory=shared_memory,memory_ts=shared_memory_ts,mail=shared_mail,mail_ts=shared_mail_ts,index=shared_memory_ind,mode=mode)
             #mem = self.pack(memory=shared_mail,memory_ts=shared_mail_ts,index=shared_memory_ind,mode=mode)
+            mem = self.pack(memory=shared_memory,memory_ts=shared_memory_ts,index=shared_memory_ind,mode=mode)
+        else:
+            mem = self.pack(memory=shared_memory,memory_ts=shared_memory_ts,mail=shared_mail,mail_ts=shared_mail_ts,index=shared_memory_ind,mode=mode)
         self.tot_shared_count += shared_memory_ind.shape[0]
-        mem = self.pack(memory=shared_memory,memory_ts=shared_memory_ts,mail=shared_mail,mail_ts=shared_mail_ts,index=shared_memory_ind,mode=mode)
         broadcast_len = torch.empty([1], device=mem.device, dtype=torch.int)
         broadcast_len[0] = shared_memory_ind.shape[0]
         shared_len = [torch.empty([1], device=mem.device, dtype=torch.int) for _ in range(ctx.memory_group_size)]
......
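The added broadcast_len / shared_len tensors set up a fixed-size length exchange before the variable-sized packed memory is exchanged: every rank first learns how many shared entries each peer will send, so receive buffers can be sized. A hedged sketch of that pattern with torch.distributed (the payload and group names are assumptions, not this codebase's API):

    import torch
    import torch.distributed as dist

    my_len = torch.tensor([payload.shape[0]], dtype=torch.int, device=payload.device)
    all_lens = [torch.empty(1, dtype=torch.int, device=payload.device)
                for _ in range(dist.get_world_size())]
    dist.all_gather(all_lens, my_len)  # now each rank can size its receive buffers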
@@ -286,10 +286,10 @@ def load_from_speed(data,seed,top,sampler_graph_add_rev,device=torch.device('cud
         return load_from_shared_node_partition(data,None,None,sample_add_rev=sampler_graph_add_rev,device=device,feature_device=feature_device)
     else:
         if partition == 'ours':
-            fnode_i = '../../SPEED/partition/divided_nodes_seed_t2/{}/{}/{}_{}parts_top{}/output{}.txt'.format(data,seed,data,ctx.memory_group_size,top,ctx.memory_group_rank)
-            fnode_share = '../../SPEED/partition/divided_nodes_seed_t2/{}/{}/{}_{}parts_top{}/outputshared.txt'.format(data,seed,data,ctx.memory_group_size,top)
-            reorder = '../../SPEED/partition/divided_nodes_seed_t2/{}/reorder.txt'.format(data)
-            edge_i = '../../SPEED/partition/divided_nodes_seed_t2/{}/{}/{}_{}parts_top{}/edge_output{}.txt'.format(data,seed,data,ctx.memory_group_size,top,ctx.memory_group_rank)
+            fnode_i = '../../SPEED/partition/divided_nodes_seed_starrygl/{}/{}/{}_{}parts_top{}/output{}.txt'.format(data,seed,data,ctx.memory_group_size,top,ctx.memory_group_rank)
+            fnode_share = '../../SPEED/partition/divided_nodes_seed_starrygl/{}/{}/{}_{}parts_top{}/outputshared.txt'.format(data,seed,data,ctx.memory_group_size,top)
+            reorder = '../../SPEED/partition/divided_nodes_seed_starrygl/{}/reorder.txt'.format(data)
+            edge_i = '../../SPEED/partition/divided_nodes_seed_starrygl/{}/{}/{}_{}parts_top{}/edge_output{}.txt'.format(data,seed,data,ctx.memory_group_size,top,ctx.memory_group_rank)
         elif partition == 'metis':
             fnode_i = '../../SPEED/partition/divided_nodes_metis_test/{}/{}/{}_{}parts_top{}/output{}.txt'.format(data,seed,data,ctx.memory_group_size,top,ctx.memory_group_rank)
             fnode_share = '../../SPEED/partition/divided_nodes_metis_test/{}/{}/{}_{}parts_top{}/outputshared.txt'.format(data,seed,data,ctx.memory_group_size,top)
......
@@ -63,8 +63,8 @@ class LocalNegativeSampling(NegativeSampling):
             p = torch.rand(size=(num_samples,))
             sr = self.dst_node_list[torch.randint(len(self.dst_node_list), (num_samples,), generator=self.rdm)]
             sl = self.local_dst[torch.randint(len(self.local_dst), (num_samples,), generator=self.rdm)]
-            s = torch.where(p<=self.prob, sl, sr)
-            return sr
+            s = torch.where(p<=self.prob, sr, sl)
+            return s
         else:
             s = torch.randint(len(self.dst_node_list), (num_samples,), generator=self.rdm)
             return self.dst_node_list[s]
......
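After the fix, LocalNegativeSampling draws each negative from the full destination list with probability self.prob and from the local partition's destinations otherwise (the old code both swapped the torch.where arms and then returned sr unconditionally). A standalone sketch of the corrected mixture (argument names mirror the class fields):

    import torch

    def sample_negatives(num_samples, dst_node_list, local_dst, prob, rdm):
        p = torch.rand(size=(num_samples,))
        sr = dst_node_list[torch.randint(len(dst_node_list), (num_samples,), generator=rdm)]
        sl = local_dst[torch.randint(len(local_dst), (num_samples,), generator=rdm)]
        return torch.where(p <= prob, sr, sl)  # global draw with prob, local otherwise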