Commit e893493a by zlj

delete weight in c scourcee

parent 529db53a
......@@ -308,14 +308,14 @@ void ParallelSampler :: neighbor_sample_from_nodes_with_before_layer(
double p0 = (double)rand_r(&loc_seeds[tid]) / (RAND_MAX + 1.0);
double ep = boundery_probility*pr[cal_cnt-1]/sum_p*sum_1;
if(p0 > ep)continue;
tgb_i[tid].sample_weight.emplace_back((float)ep);
//tgb_i[tid].sample_weight.emplace_back((float)ep);
}
else continue;
//cout<<"in"<<endl;
}
else{
tgb_i[tid].sample_weight.emplace_back((float)1.0);
//tgb_i[tid].sample_weight.emplace_back((float)1.0);
}
tgb_i[tid].src_index.emplace_back(i);
tgb_i[tid].sample_nodes.emplace_back(tnb.neighbors[node][cid]);
......@@ -373,8 +373,8 @@ void ParallelSampler :: neighbor_sample_from_nodes_with_before_layer(
#pragma omp parallel for schedule(static, 1)
for(int i = 0; i<threads; i++){
if(policy == "boundery_recent_decay")
copy(tgb_i[i].sample_weight.begin(), tgb_i[i].sample_weight.end(), ret[cur_layer].sample_weight.begin()+each_begin[i]);
//if(policy == "boundery_recent_decay")
// copy(tgb_i[i].sample_weight.begin(), tgb_i[i].sample_weight.end(), ret[cur_layer].sample_weight.begin()+each_begin[i]);
copy(tgb_i[i].eid.begin(), tgb_i[i].eid.end(), ret[cur_layer].eid.begin()+each_begin[i]);
copy(tgb_i[i].src_index.begin(), tgb_i[i].src_index.end(), ret[cur_layer].src_index.begin()+each_begin[i]);
copy(tgb_i[i].delta_ts.begin(), tgb_i[i].delta_ts.end(), ret[cur_layer].delta_ts.begin()+each_begin[i]);
......
......@@ -496,7 +496,7 @@ class AsyncMemeoryUpdater(torch.nn.Module):
self.last_updated_ts = b.srcdata['ts'].detach().clone()
self.last_updated_memory = updated_memory.detach().clone()
self.last_updated_nid = b.srcdata['ID'].detach().clone()
"""
with torch.no_grad():
if param is not None:
_,src,dst,ts,edge_feats,nxt_fetch_func = param
......@@ -523,7 +523,7 @@ class AsyncMemeoryUpdater(torch.nn.Module):
self.mailbox.set_memory_local(DistIndex(index[local_mask]).loc,memory[local_mask],memory_ts[local_mask], Reduce_Op = 'max')
is_deliver=(self.mailbox.deliver_to == 'neighbors')
self.update_hunk(index,memory,memory_ts,index0,mail,mail_ts,nxt_fetch_func,spread_mail= is_deliver)
"""
if self.memory_param['combine_node_feature'] and self.dim_node_feat > 0:
if self.dim_node_feat == self.dim_hid:
b.srcdata['h'] += updated_memory
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment