Commit 0dc877d4 by xxx

fix weight for training when boundary sample

parent 74b73c7d
bound.png

16.4 KB

......@@ -37,6 +37,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
.def("src_index", [](const TemporalGraphBlock &tgb) { return vecToTensor<EdgeIDType>(tgb.src_index); })
.def("sample_nodes", [](const TemporalGraphBlock &tgb) { return vecToTensor<NodeIDType>(tgb.sample_nodes); })
.def("sample_nodes_ts", [](const TemporalGraphBlock &tgb) { return vecToTensor<TimeStampType>(tgb.sample_nodes_ts); })
.def("sample_weight",[](const TemporalGraphBlock &tgb){
return vecToTensor<float>(tgb.sample_weight);
})
.def_readonly("sample_time", &TemporalGraphBlock::sample_time, py::return_value_policy::reference)
.def_readonly("tot_time", &TemporalGraphBlock::tot_time, py::return_value_policy::reference)
.def_readonly("sample_edge_num", &TemporalGraphBlock::sample_edge_num, py::return_value_policy::reference);
......
......@@ -11,6 +11,7 @@ class TemporalGraphBlock
vector<int64_t> src_index;
vector<NodeIDType> sample_nodes;
vector<TimeStampType> sample_nodes_ts;
vector<float> sample_weight;
vector<WeightType> e_weights;
double sample_time = 0;
double tot_time = 0;
......
......@@ -308,9 +308,14 @@ void ParallelSampler :: neighbor_sample_from_nodes_with_before_layer(
double p0 = (double)rand_r(&loc_seeds[tid]) / (RAND_MAX + 1.0);
double ep = boundery_probility*pr[cal_cnt-1]/sum_p*sum_1;
if(p0 > ep)continue;
tgb_i[tid].sample_weight.emplace_back((float)ep);
}
else continue;
//cout<<"in"<<endl;
}
else{
tgb_i[tid].sample_weight.emplace_back((float)1.0);
}
tgb_i[tid].src_index.emplace_back(i);
tgb_i[tid].sample_nodes.emplace_back(tnb.neighbors[node][cid]);
......@@ -358,6 +363,8 @@ void ParallelSampler :: neighbor_sample_from_nodes_with_before_layer(
each_begin[i]=size;
size += s;
}
if(policy == "boundery_recent_decay")
ret[cur_layer].sample_weight.resize(size);
ret[cur_layer].eid.resize(size);
ret[cur_layer].src_index.resize(size);
ret[cur_layer].delta_ts.resize(size);
......@@ -366,6 +373,8 @@ void ParallelSampler :: neighbor_sample_from_nodes_with_before_layer(
#pragma omp parallel for schedule(static, 1)
for(int i = 0; i<threads; i++){
if(policy == "boundery_recent_decay")
copy(tgb_i[i].sample_weight.begin(), tgb_i[i].sample_weight.end(), ret[cur_layer].sample_weight.begin()+each_begin[i]);
copy(tgb_i[i].eid.begin(), tgb_i[i].eid.end(), ret[cur_layer].eid.begin()+each_begin[i]);
copy(tgb_i[i].src_index.begin(), tgb_i[i].src_index.end(), ret[cur_layer].src_index.begin()+each_begin[i]);
copy(tgb_i[i].delta_ts.begin(), tgb_i[i].delta_ts.end(), ret[cur_layer].delta_ts.begin()+each_begin[i]);
......
import matplotlib.pyplot as plt
import numpy as np
# 数据
p_values = ['recent', 'p=0.1', 'p=0.05', 'p=0.01', 'p=0']
wiki_values = [0.979832, 0.980298, 0.975079, 0.97349, 0.96381]
lastfm_values = [0.820161, 0.852725, 0.848085, 0.817381, 0.796689]
wikitalk_values = [0.969647, 0.974473, 0.973996, 0.968961, 0.964867]
gdelt_values = [0.987338, 0.987454, 0.987038, 0.98812, 0.98726]
# 柱状图的宽度
barWidth = 0.15
# 柱状图的位置
r1 = np.arange(len(wiki_values))
r2 = [x + barWidth for x in r1]
r3 = [x + barWidth for x in r2]
r4 = [x + barWidth for x in r3]
# 创建图形
plt.figure(figsize=(12,8))
plt.bar(r1, wiki_values, color='b', width=barWidth, edgecolor='grey', label='WIKI')
plt.bar(r2, lastfm_values, color='r', width=barWidth, edgecolor='grey', label='LASTFM')
plt.bar(r3, wikitalk_values, color='g', width=barWidth, edgecolor='grey', label='WikiTalk')
plt.bar(r4, gdelt_values, color='y', width=barWidth, edgecolor='grey', label='GDELT')
# 添加标签
plt.xlabel('p values', fontweight='bold', fontsize=15)
plt.ylabel('SSIM', fontweight='bold', fontsize=15)
plt.xticks([r + barWidth for r in range(len(wiki_values))], p_values)
plt.savefig('bound.png')
plt.legend()
plt.show()
......@@ -295,11 +295,17 @@ class TransfomerAttentionLayer(torch.nn.Module):
#V_remote = V.clone()
#V_local[DistIndex(b.srcdata['ID']).part[b.edges()[0]]!=torch.distributed.get_rank()] = 0
#V_remote[DistIndex(b.srcdata['ID']).part[b.edges()[0]]==torch.distributed.get_rank()] = 0
b.edata['v'] = V
#b.edata['v0'] = V_local
#b.edata['v1'] = V_remote
#b.update_all(dgl.function.copy_e('v0', 'm0'), dgl.function.sum('m0', 'h0'))
#b.update_all(dgl.function.copy_e('v1', 'm1'), dgl.function.sum('m1', 'h1'))
if 'weight' in b.edata:
with torch.no_grad():
weight = b.edata['weight'].reshape(-1,1)#(b.edata['weight']/torch.sum(b.edata['weight']).item()).reshape(-1,1)
#print(weight.max())
b.edata['v'] = V*weight
else:
b.edata['v'] = V
b.update_all(dgl.function.copy_e('v', 'm'), dgl.function.sum('m', 'h'))
#tt.ssim_local+=torch.sum(torch.cosine_similarity(b.dstdata['h'],b.dstdata['h0']))
#tt.ssim_remote+=torch.sum(torch.cosine_similarity(b.dstdata['h'],b.dstdata['h1']))
......
......@@ -290,7 +290,9 @@ def to_block(graph,data, sample_out,device = torch.device('cuda'),unique = True)
if sample_out[r].delta_ts().shape[0] > 0:
b.edata['dt'] = sample_out[r].delta_ts().to(device)
b.srcdata['ts'] = block_node_list[1,b.srcnodes()].to(torch.float)
weight = sample_out[r].sample_weight()
if(weight.shape[0] > 0):
b.edata['weight'] = 1/torch.clamp(sample_out[r].sample_weight(),0.0001).to(b.device)
b.edata['__ID'] = e_idx
col = row
col_len += eid_len[r]
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment