Commit e0e37a11 by zlj

add time count

parent 572f53f7
alpha.png (image, 83.5 KB)
......@@ -25,12 +25,12 @@ class ParallelSampler
th::Tensor eid_inv;
th::Tensor unq_id;
th::Tensor first_block_id;
double boundery_probility;
double init_boundary_probability;
vector<unsigned int> loc_seeds;
ParallelSampler(TemporalNeighborBlock& _tnb, NodeIDType _num_nodes, EdgeIDType _num_edges, int _threads,
vector<int>& _fanouts, int _num_layers, string _policy, int _local_part, th::Tensor _part, th::Tensor _node_part,double _p) :
tnb(_tnb), num_nodes(_num_nodes), num_edges(_num_edges), threads(_threads),
fanouts(_fanouts), num_layers(_num_layers), policy(_policy), local_part(_local_part), boundery_probility(_p)
fanouts(_fanouts), num_layers(_num_layers), policy(_policy), local_part(_local_part), init_boundary_probability(_p)
{
omp_set_num_threads(_threads);
ret.clear();
......@@ -52,12 +52,12 @@ class ParallelSampler
ret.resize(num_layers);
}
void neighbor_sample_from_nodes(th::Tensor nodes, optional<th::Tensor> root_ts, optional<bool> part_unique);
void neighbor_sample_from_nodes(th::Tensor nodes, optional<th::Tensor> root_ts, optional<bool> part_unique,optional<double> boundary_probability);
void neighbor_sample_from_nodes_static(th::Tensor nodes, bool part_unique);
void neighbor_sample_from_nodes_static_layer(th::Tensor nodes, int cur_layer, bool part_unique);
void neighbor_sample_from_nodes_with_before(th::Tensor nodes, th::Tensor root_ts);
void neighbor_sample_from_nodes_with_before(th::Tensor nodes, th::Tensor root_ts,double boundary_probability);
void neighbor_sample_from_dynamic_nodes(th::Tensor nodes, th::Tensor root_ts);
void neighbor_sample_from_nodes_with_before_layer(th::Tensor nodes, th::Tensor root_ts, int cur_layer);
void neighbor_sample_from_nodes_with_before_layer(th::Tensor nodes, th::Tensor root_ts, int cur_layer,double boundary_probability);
template<typename T>
void union_to_vector(vector<T> *p, vector<T> &to_vec);
void sample_unique( th::Tensor seed, th::Tensor seed_ts,
......@@ -66,7 +66,7 @@ class ParallelSampler
void ParallelSampler :: neighbor_sample_from_nodes(th::Tensor nodes, optional<th::Tensor> root_ts, optional<bool> part_unique)
void ParallelSampler :: neighbor_sample_from_nodes(th::Tensor nodes, optional<th::Tensor> root_ts, optional<bool> part_unique,optional<double> boundary_probability)
{
omp_set_num_threads(threads);
if(policy == "weighted")
......@@ -81,7 +81,7 @@ void ParallelSampler :: neighbor_sample_from_nodes(th::Tensor nodes, optional<th
AT_ASSERTM(tnb.with_timestamp, "Tnb has no timestamp information!");
AT_ASSERTM(root_ts.has_value(), "Parameter mismatch!");
//neighbor_sample_from_dynamic_nodes(nodes,root_ts.value());
neighbor_sample_from_nodes_with_before(nodes, root_ts.value());
neighbor_sample_from_nodes_with_before(nodes, root_ts.value(), boundary_probability.has_value()?boundary_probability.value():0);
}
else{
bool flag = part_unique.has_value() ? part_unique.value() : true;
......@@ -211,7 +211,7 @@ void ParallelSampler :: neighbor_sample_from_nodes_static(th::Tensor nodes, bool
}
void ParallelSampler :: neighbor_sample_from_nodes_with_before_layer(
th::Tensor nodes, th::Tensor root_ts, int cur_layer){
th::Tensor nodes, th::Tensor root_ts, int cur_layer, double boundary_probability){
py::gil_scoped_release release;
double tot_start_time = omp_get_wtime();
ret[cur_layer] = TemporalGraphBlock();
......@@ -267,7 +267,7 @@ void ParallelSampler :: neighbor_sample_from_nodes_with_before_layer(
int eid = tnb.eid[node][cid];
if(part[tnb.eid[node][cid]] != local_part|| node_part[tnb.neighbors[node][cid]]!= local_part){
double p0 = (double)rand_r(&loc_seeds[tid]) / (RAND_MAX + 1.0);
if(p0 > boundery_probility)continue;
if(p0 > boundary_probability)continue;
}
tgb_i[tid].src_index.emplace_back(i);
tgb_i[tid].sample_nodes.emplace_back(tnb.neighbors[node][cid]);
......@@ -306,7 +306,7 @@ void ParallelSampler :: neighbor_sample_from_nodes_with_before_layer(
if((part[tnb.eid[node][cid]] != local_part|| node_part[tnb.neighbors[node][cid]]!= local_part)){
if(cal_cnt<=fanout){
double p0 = (double)rand_r(&loc_seeds[tid]) / (RAND_MAX + 1.0);
double ep = boundery_probility*pr[cal_cnt-1]/sum_p*sum_1;
double ep = boundary_probability*pr[cal_cnt-1]/sum_p*sum_1;
if(p0 > ep)continue;
//tgb_i[tid].sample_weight.emplace_back((float)ep);
}
......@@ -389,11 +389,11 @@ void ParallelSampler :: neighbor_sample_from_nodes_with_before_layer(
py::gil_scoped_acquire acquire;
}
void ParallelSampler :: neighbor_sample_from_nodes_with_before(th::Tensor nodes, th::Tensor root_ts){
void ParallelSampler :: neighbor_sample_from_nodes_with_before(th::Tensor nodes, th::Tensor root_ts, double boundary_probability){
for(int i=0;i<num_layers;i++){
if(i==0) neighbor_sample_from_nodes_with_before_layer(nodes, root_ts, i);
if(i==0) neighbor_sample_from_nodes_with_before_layer(nodes, root_ts, i, boundary_probability);
else neighbor_sample_from_nodes_with_before_layer(vecToTensor<NodeIDType>(ret[i-1].sample_nodes),
vecToTensor<TimeStampType>(ret[i-1].sample_nodes_ts), i);
vecToTensor<TimeStampType>(ret[i-1].sample_nodes_ts), i, boundary_probability);
}
}
......
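The hunks above thread an optional cross-partition sampling probability through the C++ sampler: a candidate neighbor whose edge or endpoint belongs to another partition is kept only with that probability, while local neighbors are always kept. A minimal Python sketch of this gating rule (keep_neighbor is a hypothetical illustration, not part of the repository's API):

import random

def keep_neighbor(edge_part, neighbor_part, local_part, boundary_probability):
    # Mirrors the C++ check above: non-local edges/nodes survive only with
    # probability boundary_probability; local ones are always kept.
    if edge_part != local_part or neighbor_part != local_part:
        return random.random() <= boundary_probability
    return True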
import matplotlib.pyplot as plt
# Data
alpha_values = [0.1, 0.3, 0.5, 0.7, 0.9, 1.3, 1.5, 1.7, 2]
lastfm = [0.933538, 0.931761, 0.932499, 0.933383, 0.931225, 0.929983, 0.933971, 0.928771, 0.932748]
wikitalk = [0.979627, 0.97997, 0.979484, 0.980269, 0.980758, 0.97979, 0.980233, 0.980004, 0.980353]
stackoverflow = [0.979641, 0.979372, 0.97967, 0.978169, 0.979624, 0.978846, 0.978428, 0.978397, 0.978925]
# Create a new figure with subplots
fig, axs = plt.subplots(1,3, figsize=(15,4))
# Set the global title
#fig.suptitle('Accuracy line charts', fontsize=24, fontweight='bold')
# Plot the LASTFM curve
axs[0].plot(alpha_values, lastfm, marker='o', linestyle='-', color='blue', linewidth=3)
axs[0].set_title('LASTFM', fontsize=20, fontweight='bold')
axs[0].set_xlabel('α', fontsize=16)
axs[0].set_ylabel('Test AP', fontsize=16)
axs[0].grid(True)
# Plot the WikiTalk curve
axs[1].plot(alpha_values, wikitalk, marker='o', linestyle='-', color='green', linewidth=3)
axs[1].set_title('WikiTalk', fontsize=20, fontweight='bold')
axs[1].set_xlabel('α', fontsize=16)
axs[1].set_ylabel('Test AP', fontsize=16)
axs[1].grid(True)
# Plot the StackOverflow curve
axs[2].plot(alpha_values, stackoverflow, marker='o', linestyle='-', color='red', linewidth=3)
axs[2].set_title('StackOverflow', fontsize=20, fontweight='bold')
axs[2].set_xlabel('α', fontsize=16)
axs[2].set_ylabel('Test AP', fontsize=16)
axs[2].grid(True)
plt.tight_layout()
plt.subplots_adjust(top=0.92)
plt.savefig('alpha.png')
import matplotlib.pyplot as plt
# Data
theta_values = [0, 0.01, 0.05, 0.1, 0.2, 0.4, 0.6, 0.8, 1]
lastfm = [0.813412, 0.920839, 0.92926, 0.930115, 0.930719, 0.923493, 0.927908, 0.925335, 0.896277]
wikitalk = [0.967408, 0.979705, 0.979805, 0.97515, 0.979248, 0.979162, 0.978207, 0.977578, 0.975614]
stackoverflow = [0.919117667, 0.9778155, 0.978890667, 0.979456, 0.97914, 0.980143, 0.980606, 0.980387, 0.977536]
# Create a new figure with subplots
fig, axs = plt.subplots(1,3, figsize=(15, 4))
# Set the global title
#fig.suptitle('Accuracy line charts', fontsize=24, fontweight='bold')
# Plot the LASTFM curve
axs[0].plot(theta_values, lastfm, marker='o', linestyle='-', color='blue', linewidth=3)
axs[0].set_title('LASTFM', fontsize=20, fontweight='bold')
axs[0].set_xlabel('θ', fontsize=16)
axs[0].set_ylabel('Test AP', fontsize=16)
axs[0].grid(True)
# Plot the WikiTalk curve
axs[1].plot(theta_values, wikitalk, marker='o', linestyle='-', color='green', linewidth=3)
axs[1].set_title('WikiTalk', fontsize=20, fontweight='bold')
axs[1].set_xlabel('θ', fontsize=16)
axs[1].set_ylabel('Test AP', fontsize=16)
axs[1].grid(True)
# Plot the StackOverflow curve
axs[2].plot(theta_values, stackoverflow, marker='o', linestyle='-', color='red', linewidth=3)
axs[2].set_title('StackOverflow', fontsize=20, fontweight='bold')
axs[2].set_xlabel('θ', fontsize=16)
axs[2].set_ylabel('Test AP', fontsize=16)
axs[2].grid(True)
plt.tight_layout()
plt.subplots_adjust(top=0.92)
plt.savefig('theta.png')
\ No newline at end of file
import matplotlib.pyplot as plt
# Data
k_values = [0.02, 0.04, 0.06, 0.08, 0.1, 0.2, 0.3]
lastfm = [0.903726, 0.921197, 0.931237, 0.926789, 0.930719, 0.929332, 0.915848]
wikitalk = [0.981577, 0.980716, 0.979996, 0.979597, 0.979248, 0.975468, 0.972785]
stackoverflow = [0.974805, 0.978219, 0.97924, 0.979436, 0.979456, 0.976746, 0.972544]
# Create a new figure with subplots
fig, axs = plt.subplots(1,3, figsize=(15, 4))
# Set the global title
#fig.suptitle('Test AP vs topK', fontsize=24, fontweight='bold')
# Plot the LASTFM curve
axs[0].plot(k_values, lastfm, marker='o', linestyle='-', color='blue', linewidth=3)
axs[0].set_title('LASTFM', fontsize=20, fontweight='bold')
axs[0].set_xlabel('k', fontsize=16)
axs[0].set_ylabel('Test AP', fontsize=16)
axs[0].grid(True)
# Plot the WikiTalk curve
axs[1].plot(k_values, wikitalk, marker='o', linestyle='-', color='green', linewidth=3)
axs[1].set_title('WikiTalk', fontsize=20, fontweight='bold')
axs[1].set_xlabel('k', fontsize=16)
axs[1].set_ylabel('Test AP', fontsize=16)
axs[1].grid(True)
# Plot the StackOverflow curve
axs[2].plot(k_values, stackoverflow, marker='o', linestyle='-', color='red', linewidth=3)
axs[2].set_title('StackOverflow', fontsize=20, fontweight='bold')
axs[2].set_xlabel('k', fontsize=16)
axs[2].set_ylabel('Test AP', fontsize=16)
axs[2].grid(True)
plt.tight_layout()
plt.subplots_adjust(top=0.92)
plt.savefig('topk.png')
......@@ -19,7 +19,7 @@ memory_type=("historical")
#memory_type=("local" "all_update" "historical" "all_reduce")
shared_memory_ssim=("0.3")
#data_param=("WIKI" "REDDIT" "LASTFM" "WikiTalk")
data_param=("WIKI" "LASTFM")
data_param=("WikiTalk")
#"GDELT")
#data_param=("WIKI" "REDDIT" "LASTFM" "DGraphFin" "WikiTalk" "StackOverflow")
#data_param=("WIKI" "REDDIT" "LASTFM" "WikiTalk" "StackOverflow")
......@@ -32,9 +32,9 @@ data_param=("WIKI" "LASTFM")
#seed=(( RANDOM % 1000000 + 1 ))
mkdir -p all_"$seed"
for data in "${data_param[@]}"; do
model="TGN_large"
model="JODIE_large"
if [ "$data" = "WIKI" ] || [ "$data" = "REDDIT" ] || [ "$data" = "LASTFM" ]; then
model="TGN"
model="JODIE"
fi
#model="APAN"
mkdir all_"$seed"/"$data"
......
......@@ -238,6 +238,7 @@ def main():
ada_param = AdaParameter(init_alpha=args.shared_memory_ssim, init_beta=args.probability)
else:
policy_train = policy
ada_param = AdaParameter(init_alpha=args.shared_memory_ssim, init_beta=args.probability)
if memory_param['type'] != 'none':
mailbox = SharedMailBox(graph.ids.shape[0], memory_param, dim_edge_feat = graph.efeat.shape[1] if graph.efeat is not None else 0,
shared_nodes_index=graph.shared_nids_list[ctx.memory_group_rank],device = torch.device('cuda:{}'.format(local_rank)),
......
import torch
import dgl
import math
import numpy as np
class TimeEncode(torch.nn.Module):
def __init__(self, dim):
super(TimeEncode, self).__init__()
self.dim = dim
self.w = torch.nn.Linear(1, dim)
self.w.weight = torch.nn.Parameter((torch.from_numpy(1 / 10 ** np.linspace(0, 9, dim, dtype=np.float32))).reshape(dim, -1))
self.w.bias = torch.nn.Parameter(torch.zeros(dim))
def forward(self, t):
output = torch.cos(self.w(t.reshape((-1, 1))))
return output
class EdgePredictor(torch.nn.Module):
def __init__(self, dim_in):
super(EdgePredictor, self).__init__()
self.dim_in = dim_in
self.src_fc = torch.nn.Linear(dim_in, dim_in)
self.dst_fc = torch.nn.Linear(dim_in, dim_in)
self.out_fc = torch.nn.Linear(dim_in, 1)
def forward(self, h_src, h_pos_dst, h_neg_dst=None):
h_src = self.src_fc(h_src)
h_pos_dst = self.dst_fc(h_pos_dst)
h_pos_edge = torch.nn.functional.relu(h_src + h_pos_dst)
prob_pos = self.out_fc(h_pos_edge)
prob_neg = None
if h_neg_dst is not None:
h_neg_dst = self.dst_fc(h_neg_dst)
neg_samples = h_neg_dst.shape[0] // h_src.shape[0]
h_neg_edge = torch.nn.functional.relu(h_src.tile(neg_samples, 1) + h_neg_dst)
prob_neg = self.out_fc(h_neg_edge)
return prob_pos, prob_neg
class TransfomerAttentionLayer(torch.nn.Module):
def __init__(self, dim_node_feat, dim_edge_feat, dim_time, num_head, dropout, att_dropout, dim_out, combined=False):
super(TransfomerAttentionLayer, self).__init__()
self.num_head = num_head
self.dim_node_feat = dim_node_feat
self.dim_edge_feat = dim_edge_feat
self.dim_time = dim_time
self.dim_out = dim_out
self.dropout = torch.nn.Dropout(dropout)
self.att_dropout = torch.nn.Dropout(att_dropout)
self.att_act = torch.nn.LeakyReLU(0.2)
self.combined = combined
if dim_time > 0:
self.time_enc = TimeEncode(dim_time)
if combined:
if dim_node_feat > 0:
self.w_q_n = torch.nn.Linear(dim_node_feat, dim_out)
self.w_k_n = torch.nn.Linear(dim_node_feat, dim_out)
self.w_v_n = torch.nn.Linear(dim_node_feat, dim_out)
if dim_edge_feat > 0:
self.w_k_e = torch.nn.Linear(dim_edge_feat, dim_out)
self.w_v_e = torch.nn.Linear(dim_edge_feat, dim_out)
if dim_time > 0:
self.w_q_t = torch.nn.Linear(dim_time, dim_out)
self.w_k_t = torch.nn.Linear(dim_time, dim_out)
self.w_v_t = torch.nn.Linear(dim_time, dim_out)
else:
if dim_node_feat + dim_time > 0:
self.w_q = torch.nn.Linear(dim_node_feat + dim_time, dim_out)
self.w_k = torch.nn.Linear(dim_node_feat + dim_edge_feat + dim_time, dim_out)
self.w_v = torch.nn.Linear(dim_node_feat + dim_edge_feat + dim_time, dim_out)
self.w_out = torch.nn.Linear(dim_node_feat + dim_out, dim_out)
self.layer_norm = torch.nn.LayerNorm(dim_out)
def forward(self, b):
if self.dim_time > 0:
time_feat = self.time_enc(b.edata['dt'])
zero_time_feat = self.time_enc(torch.zeros(b.num_dst_nodes(), dtype=torch.float32, device=b.src_idx_cuda.device))
src_data = b.srcdata['h'][b.src_idx_cuda]
if b.combined:
q_data = torch.cat([src_data[:b.num_pos_dst], src_data[b.num_pos_idx:b.num_pos_idx + b.num_neg_dst]], dim=0)
kv_data = torch.cat([src_data[b.num_pos_dst:b.num_pos_idx], src_data[b.num_pos_idx + b.num_neg_dst:]], dim=0)
else:
q_data = src_data[:b.num_dst_nodes()]
kv_data = src_data[b.num_dst_nodes():]
if b.num_edges() > 0:
if self.dim_edge_feat == 0:
Q = self.w_q(torch.cat([q_data, zero_time_feat], dim=1))[b.edges()[1]]
K = self.w_k(torch.cat([kv_data, time_feat], dim=1))
V = self.w_v(torch.cat([kv_data, time_feat], dim=1))
else:
Q = self.w_q(torch.cat([q_data, zero_time_feat], dim=1))[b.edges()[1]]
K = self.w_k(torch.cat([kv_data, b.edata['f'], time_feat], dim=1))
V = self.w_v(torch.cat([kv_data, b.edata['f'], time_feat], dim=1))
Q = torch.reshape(Q, (Q.shape[0], self.num_head, -1))
K = torch.reshape(K, (K.shape[0], self.num_head, -1))
V = torch.reshape(V, (V.shape[0], self.num_head, -1))
att = dgl.ops.edge_softmax(b, self.att_act(torch.sum(Q*K, dim=2)))
att = self.att_dropout(att)
V = torch.reshape(V*att[:, :, None], (V.shape[0], -1))
b.edata['v'] = V
b.update_all(dgl.function.copy_e('v', 'm'), dgl.function.sum('m', 'h'))
else:
b.dstdata['h'] = torch.zeros((b.num_dst_nodes(), self.dim_out), device=b.src_idx_cuda.device)
if b.combined:
rst_pos = torch.cat([b.dstdata['h'][:b.num_pos_dst], src_data[:b.num_pos_dst]], dim=1)
rst_neg = torch.cat([b.dstdata['h'][b.num_pos_dst:], src_data[b.num_pos_idx:b.num_pos_idx + b.num_neg_dst]], dim=1)
rst = torch.cat([rst_pos, rst_neg])
else:
rst = torch.cat([b.dstdata['h'], src_data[:b.num_dst_nodes()]], dim=1)
rst = self.w_out(rst)
rst = torch.nn.functional.relu(self.dropout(rst))
return self.layer_norm(rst)
class IdentityNormLayer(torch.nn.Module):
def __init__(self, dim_out):
super(IdentityNormLayer, self).__init__()
self.norm = torch.nn.LayerNorm(dim_out)
def forward(self, b):
return self.norm(b.srcdata['h'])
class JODIETimeEmbedding(torch.nn.Module):
def __init__(self, dim_out):
super(JODIETimeEmbedding, self).__init__()
self.dim_out = dim_out
class NormalLinear(torch.nn.Linear):
# From Jodie code
def reset_parameters(self):
stdv = 1. / math.sqrt(self.weight.size(1))
self.weight.data.normal_(0, stdv)
if self.bias is not None:
self.bias.data.normal_(0, stdv)
self.time_emb = NormalLinear(1, dim_out)
def forward(self, h, mem_ts, ts):
time_diff = (ts - mem_ts) / (ts + 1)
rst = h * (1 + self.time_emb(time_diff.unsqueeze(1)))
return rst
\ No newline at end of file
import torch
import dgl
import time
from starrygl.module.memorys import *
from starrygl.module.disttgl_layers import *
class GeneralModel(torch.nn.Module):
def __init__(self, dim_node, dim_edge, sample_param, memory_param, gnn_param, train_param, num_node=None, no_learn_node=False, combined=False, edge_classification=False, edge_classes=0):
super(GeneralModel, self).__init__()
self.edge_classification = edge_classification
self.dim_node = dim_node
self.dim_node_input = dim_node
self.dim_edge = dim_edge
self.sample_param = sample_param
self.memory_param = memory_param
if not 'dim_out' in gnn_param:
gnn_param['dim_out'] = memory_param['dim_out']
self.gnn_param = gnn_param
self.train_param = train_param
if memory_param['type'] == 'node':
if memory_param['memory_update'] == 'smart':
self.memory_updater = SmartMemoryUpdater(memory_param, 2 * memory_param['dim_out'] + dim_edge, memory_param['dim_out'], memory_param['dim_time'], dim_node, num_node, no_learn_node=no_learn_node)
else:
raise NotImplementedError
self.dim_node_input = memory_param['dim_out']
self.layers = torch.nn.ModuleDict()
if gnn_param['arch'] == 'transformer_attention':
for h in range(sample_param['history']):
self.layers['l0h' + str(h)] = TransfomerAttentionLayer(self.dim_node_input, dim_edge, gnn_param['dim_time'], gnn_param['att_head'], train_param['dropout'], train_param['att_dropout'], gnn_param['dim_out'], combined=combined)
for l in range(1, gnn_param['layer']):
for h in range(sample_param['history']):
self.layers['l' + str(l) + 'h' + str(h)] = TransfomerAttentionLayer(gnn_param['dim_out'], dim_edge, gnn_param['dim_time'], gnn_param['att_head'], train_param['dropout'], train_param['att_dropout'], gnn_param['dim_out'], combined=False)
else:
raise NotImplementedError
if not self.edge_classification:
self.edge_predictor = EdgePredictor(gnn_param['dim_out'])
else:
self.edge_classifier = torch.nn.Linear(2 * gnn_param['dim_out'], edge_classes)
def forward(self, mfg, metadata = None,neg_samples=1, mode = 'triplet'):
# import pdb; pdb.set_trace()
# torch.cuda.synchronize()
# t_s=time.time()
if self.memory_param['type'] == 'node':
self.memory_updater(mfg)
# torch.cuda.synchronize()
# t_mem = time.time() - t_s
# t_s = time.time()
rst = self.layers['l0h0'](mfg)
self.embedding = rst.detach().clone()
if not self.edge_classification:
return self.edge_predictor(rst[metadata['src_pos_index']],rst[metadata['dst_pos_index']],rst[metadata['dst_neg_index']])
else:
rst = torch.cat([rst[metadata['src_pos_index']], rst[metadata['dst_pos_index']]], dim=1)
return self.edge_classifier(rst), None
class NodeClassificationModel(torch.nn.Module):
def __init__(self, dim_in, dim_hid, num_class):
super(NodeClassificationModel, self).__init__()
self.fc1 = torch.nn.Linear(dim_in, dim_hid)
self.fc2 = torch.nn.Linear(dim_hid, num_class)
def forward(self, x):
x = self.fc1(x)
x = torch.nn.functional.relu(x)
x = self.fc2(x)
return x
\ No newline at end of file
......@@ -41,7 +41,7 @@ pinn_memory = {}
class HistoricalCache:
def __init__(self,cache_index,layer,shape,dtype,device,threshold = 3,time_threshold = None, times_threshold = 10, use_rpc = True, num_threshold = 0):
def __init__(self,cache_index,layer,shape,dtype,device,ada_param = None,time_threshold = None, times_threshold = 10, use_rpc = True, num_threshold = 0):
#self.cache_index = cache_index
self.layer = layer
print(shape)
......@@ -51,7 +51,8 @@ class HistoricalCache:
#self.ts = torch.zeros(cache_index.historical_num,dtype = torch.float,device = torch.device('cpu'))
self.local_ts = torch.zeros(cache_index.shape[0],dtype = torch.float,device = device)
self.loss_count = torch.zeros(cache_index.shape[0],dtype = torch.int,device = device)
self.threshold = threshold
self.threshold = ada_param.alpha
self.ada_param = ada_param
self.time_threshold = time_threshold
self.times_threshold = times_threshold
self.num_threshold = num_threshold
......@@ -87,6 +88,7 @@ class HistoricalCache:
else:
return torch.sum((x -y)**2,dim = 1)
def historical_check(self,index,new_data,ts):
self.threshold = self.ada_param.alpha
if self.time_threshold is not None:
mask = (self.ssim(new_data,self.local_historical_data[index]) > self.threshold) | (ts - self.local_ts[index] > self.time_threshold) | (self.loss_count[index] > self.times_threshold)
self.loss_count[index][~mask] += 1
......@@ -134,6 +136,8 @@ class HistoricalCache:
self.last_shared_update_wait = None
handle0.wait()
handle1.wait()
if self.ada_param.training is True:
self.ada_param.update_memory_sync_time(self.ada_param.last_start_event_memory_sync)
shared_data = torch.cat(shared_data,dim = 0)
shared_index = torch.cat(shared_index)
if(shared_data.shape[0] == 0):
......
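The time counting this commit adds follows the standard CUDA-event pattern: an event is recorded where the asynchronous work is launched (for example before the all_gather calls in SharedMailBox) and a second event is recorded and synchronized once the handles complete, as in the hunk above. A minimal, self-contained sketch of that pattern; it mirrors the start_event/update_*_time helpers of the AdaParameter class further down and assumes a CUDA device is available:

import torch

start = torch.cuda.Event(enable_timing=True)
start.record()                  # recorded where the async work is launched
# ... asynchronous work: all_gather, memory update, feature fetch ...
end = torch.cuda.Event(enable_timing=True)
end.record()
end.synchronize()               # wait so that elapsed_time() is valid
elapsed_ms = start.elapsed_time(end)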
......@@ -395,6 +395,9 @@ class AsyncMemeoryUpdater(torch.nn.Module):
def historical_func(self,index,memory,memory_ts,mail_index,mail,mail_ts,nxt_fetch_func,spread_mail=False):
self.mailbox.sychronize_shared()
self.mailbox.handle_last_async()
#if self.ada_param.training is True:
# self.ada_param.last_start_event_memory_sync = self.ada_param.start_event()
#self.ada_param.update_memory_sync_time(self.ada_param.last_start_event_memory_sync)
submit_to_queue = False
if nxt_fetch_func is not None:
submit_to_queue = True
......@@ -406,6 +409,7 @@ class AsyncMemeoryUpdater(torch.nn.Module):
wait_submit=submit_to_queue,spread_mail=spread_mail,
update_cross_mm=False,
)
self.ada_param.update_memory_update_time(self.ada_param.last_start_event_memory_update)
if nxt_fetch_func is not None:
nxt_fetch_func()
......@@ -420,7 +424,7 @@ class AsyncMemeoryUpdater(torch.nn.Module):
time_feat = self.time_enc(b.srcdata['ts'] - b.srcdata['mem_ts'])
b.srcdata['mem_input'] = torch.cat([b.srcdata['mem_input'], time_feat], dim=1)
return self.ceil_updater(b.srcdata['mem_input'], b.srcdata['mem'])
def __init__(self, memory_param, dim_in, dim_hid, dim_time, dim_node_feat,updater,mode = None,mailbox = None,train_param=None):
def __init__(self, memory_param, dim_in, dim_hid, dim_time, dim_node_feat,updater,mode = None,mailbox = None,train_param=None, ada_param = None):
super(AsyncMemeoryUpdater, self).__init__()
self.dim_hid = dim_hid
self.dim_node_feat = dim_node_feat
......@@ -436,6 +440,7 @@ class AsyncMemeoryUpdater(torch.nn.Module):
self.last_updated_ts = None
self.last_updated_nid = None
self.delta_memory = 0
self.ada_param = ada_param
if dim_time > 0:
self.time_enc = TimeEncode(dim_time)
if memory_param['combine_node_feature']:
......@@ -523,7 +528,12 @@ class AsyncMemeoryUpdater(torch.nn.Module):
self.mailbox.set_memory_local(DistIndex(index[local_mask]).loc,memory[local_mask],memory_ts[local_mask], Reduce_Op = 'max')
is_deliver=(self.mailbox.deliver_to == 'neighbors')
self.update_hunk(index,memory,memory_ts,index0,mail,mail_ts,nxt_fetch_func,spread_mail= is_deliver)
if self.training is True:
self.ada_param.last_start_event_gnn_aggregate = self.ada_param.start_event()
if self.memory_param['combine_node_feature'] and self.dim_node_feat > 0:
if self.dim_node_feat == self.dim_hid:
b.srcdata['h'] += updated_memory
......
......@@ -69,13 +69,14 @@ class NegFixLayer(torch.autograd.Function):
class GeneralModel(torch.nn.Module):
def __init__(self, dim_node, dim_edge, sample_param, memory_param, gnn_param, train_param, num_nodes = None,mailbox = None,combined=False,train_ratio = None):
def __init__(self, dim_node, dim_edge, sample_param, memory_param, gnn_param, train_param, num_nodes = None,mailbox = None,combined=False,train_ratio = None,ada_param = None):
super(GeneralModel, self).__init__()
self.dim_node = dim_node
self.dim_node_input = dim_node
self.dim_edge = dim_edge
self.sample_param = sample_param
self.memory_param = memory_param
self.ada_param = ada_param
#self.train_pos_ratio,self.train_neg_ratio = train_ratio
if not 'dim_out' in gnn_param:
gnn_param['dim_out'] = memory_param['dim_out']
......@@ -89,7 +90,7 @@ class GeneralModel(torch.nn.Module):
#else:
updater = torch.nn.GRUCell
# if memory_param['historical_fix'] == False:
self.memory_updater = AsyncMemeoryUpdater(memory_param, 2 * memory_param['dim_out'] + dim_edge, memory_param['dim_out'], memory_param['dim_time'], dim_node, updater=updater, mailbox=mailbox, mode = memory_param['mode'])
self.memory_updater = AsyncMemeoryUpdater(memory_param, 2 * memory_param['dim_out'] + dim_edge, memory_param['dim_out'], memory_param['dim_time'], dim_node, updater=updater, mailbox=mailbox, mode = memory_param['mode'],ada_param=ada_param)
# else:
# self.memory_updater = HistoricalMemeoryUpdater(memory_param, 2 * memory_param['dim_out'] + dim_edge, memory_param['dim_out'], memory_param['dim_time'], dim_node,updater=updater,learnable=True,num_nodes=num_nodes)
elif memory_param['memory_update'] == 'rnn':
......@@ -98,12 +99,12 @@ class GeneralModel(torch.nn.Module):
#else:
updater = torch.nn.RNNCell
#if memory_param['historical_fix'] == False:
self.memory_updater = AsyncMemeoryUpdater(memory_param, 2 * memory_param['dim_out'] + dim_edge, memory_param['dim_out'], memory_param['dim_time'], dim_node, updater=updater, mailbox=mailbox, mode = memory_param['mode'])
self.memory_updater = AsyncMemeoryUpdater(memory_param, 2 * memory_param['dim_out'] + dim_edge, memory_param['dim_out'], memory_param['dim_time'], dim_node, updater=updater, mailbox=mailbox, mode = memory_param['mode'],ada_param=ada_param)
# else:
# self.memory_updater = HistoricalMemeoryUpdater(memory_param, 2 * memory_param['dim_out'] + dim_edge, memory_param['dim_out'], memory_param['dim_time'], dim_node,updater=updater,learnable=True,num_nodes=num_nodes)
elif memory_param['memory_update'] == 'transformer':
updater = TransformerMemoryUpdater
self.memory_updater = AsyncMemeoryUpdater(memory_param, 2 * memory_param['dim_out'] + dim_edge, memory_param['dim_out'], memory_param['dim_time'], dim_node, updater=updater, mailbox=mailbox, mode = memory_param['mode'],train_param=train_param)
self.memory_updater = AsyncMemeoryUpdater(memory_param, 2 * memory_param['dim_out'] + dim_edge, memory_param['dim_out'], memory_param['dim_time'], dim_node, updater=updater, mailbox=mailbox, mode = memory_param['mode'],train_param=train_param,ada_param=ada_param)
else:
raise NotImplementedError
self.dim_node_input = memory_param['dim_out']
......
import yaml
import numpy as np
import torch
import math
def parse_config(f):
conf = yaml.safe_load(open(f, 'r'))
sample_param = conf['sampling'][0]
......@@ -35,4 +36,110 @@ class EarlyStopMonitor(object):
self.epoch_count += 1
return self.num_round >= self.max_round
\ No newline at end of file
return self.num_round >= self.max_round
class AdaParameter:
def __init__(self, wait_threshold=0.1, init_beta = 0.1 ,init_alpha = 0.1, min_beta = 0.01, max_beta = 1, min_alpha = 0.1, max_alpha = 1):
self.wait_threshold = wait_threshold
self.beta = init_beta
self.alpha = init_alpha
self.average_fetch = 0
self.count_fetch = 0
self.average_memory_sync = 0
self.count_memory_sync = 0
self.average_memory_update= 0
self.count_memory_update = 0
self.average_gnn_aggregate = 0
self.count_gnn_aggregate = 0
self.counts = 0
self.last_start_event_fetch = None
self.last_start_event_memory_sync = None
self.last_start_event_memory_update = None
self.last_start_event_gnn_aggregate = None
self.min_beta = min_beta
self.max_beta = max_beta
self.max_alpha = max_alpha
self.min_alpha = min_alpha
self.training = False
def train(self):
self.training = True
def eval(self):
self.training = False
def start_event(self):
start_event = torch.cuda.Event(enable_timing=True)
start_event.record()
return start_event
def update_fetch_time(self,start_event):
if start_event is None:
return
end_event = torch.cuda.Event(enable_timing=True)
end_event.record()
end_event.synchronize()
elapsed_time_ms = start_event.elapsed_time(end_event)
self.average_fetch += elapsed_time_ms
self.count_fetch += 1
def update_memory_sync_time(self,start_event):
if start_event is None:
return
end_event = torch.cuda.Event(enable_timing=True)
end_event.record()
end_event.synchronize()
elapsed_time_ms = start_event.elapsed_time(end_event)
self.average_memory_sync += elapsed_time_ms
self.count_memory_sync += 1
def update_memory_update_time(self,start_event):
if start_event is None:
return
end_event = torch.cuda.Event(enable_timing=True)
end_event.record()
end_event.synchronize()
elapsed_time_ms = start_event.elapsed_time(end_event)
self.average_memory_update += elapsed_time_ms
self.count_memory_update += 1
def update_gnn_aggregate_time(self,start_event):
if start_event is None:
return
end_event = torch.cuda.Event(enable_timing=True)
end_event.record()
end_event.synchronize()
elapsed_time_ms = start_event.elapsed_time(end_event)
self.average_gnn_aggregate += elapsed_time_ms
self.count_gnn_aggregate += 1
def reset_time(self):
self.average_fetch = 0
self.count_fetch = 0
self.average_memory_sync = 0
self.count_memory_sync = 0
self.average_memory_update= 0
self.count_memory_update = 0
self.average_gnn_aggregate = 0
self.count_gnn_aggregate = 0
def update_parameter(self):
print('beta is {} alpha is {}\n'.format(self.beta,self.alpha))
if self.count_fetch == 0 or self.count_memory_sync == 0 or self.count_memory_update == 0 or self.count_gnn_aggregate == 0:
return
average_gnn_aggregate = self.average_gnn_aggregate/self.count_gnn_aggregate
average_fetch = self.average_fetch/self.count_fetch
self.beta = self.beta * average_gnn_aggregate/average_fetch * (1 + self.wait_threshold)
average_memory_sync_time = self.average_memory_sync/self.count_memory_sync
average_memory_update_time = self.average_memory_update/self.count_memory_update
self.alpha = (2-math.pow((2-self.alpha),average_memory_update_time/average_memory_sync_time * (1 + self.wait_threshold)))
self.beta = max(min(self.beta, self.max_beta),self.min_beta)
self.alpha = max(min(self.alpha, self.max_alpha),self.min_alpha)
print('gnn aggregate {} fetch {} memory sync {} memory update {}'.format(average_gnn_aggregate,average_fetch,average_memory_sync_time,average_memory_update_time))
print('beta is {} alpha is {}\n'.format(self.beta,self.alpha))
self.reset_time()
# log(2 - a1) = log(2 - a2) * t1/t2 * (1 + wait_threshold)
# 2 - a1 = (2 - a2) ^ (t1/t2 * (1 + wait_threshold))
# a1 = 2 - (2 - a2) ^ (t1/t2 * (1 + wait_threshold))
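A small worked example of update_parameter above; the timings are made up purely for illustration (they are not measurements from this repository) and reproduce the beta/alpha updates and the clamping to the configured bounds:

import math

wait_threshold = 0.1
beta, alpha = 0.1, 0.1
avg_gnn_aggregate, avg_fetch = 20.0, 40.0        # ms, illustrative only
avg_memory_update, avg_memory_sync = 10.0, 30.0  # ms, illustrative only

# beta scales with the ratio of aggregation time to fetch time (plus slack)
beta = beta * avg_gnn_aggregate / avg_fetch * (1 + wait_threshold)
# alpha follows log(2 - a1) = log(2 - a2) * t_update/t_sync * (1 + wait_threshold)
alpha = 2 - math.pow(2 - alpha, avg_memory_update / avg_memory_sync * (1 + wait_threshold))

beta = max(min(beta, 1.0), 0.01)   # clamp to [min_beta, max_beta]
alpha = max(min(alpha, 1.0), 0.1)  # clamp to [min_alpha, max_alpha]
print(beta, alpha)                 # roughly 0.055 and 0.73 with these numbers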
......@@ -101,8 +101,10 @@ class DistributedDataLoader:
use_local_feature = True,
probability = 1,
reversed = False,
ada_param = None,
**kwargs
):
self.ada_param = ada_param
self.reversed = reversed
self.use_local_feature = use_local_feature
self.local_embedding = local_embedding
......@@ -255,6 +257,7 @@ class DistributedDataLoader:
return
while(len(self.result_queue)==0):
pass
batch_data,dist_nid,dist_eid = self.result_queue[0].result()
b = batch_data[1][0][0]
self.remote_node += (DistIndex(dist_nid).part != dist.get_rank()).sum().item()
......@@ -275,16 +278,21 @@ class DistributedDataLoader:
#start = torch.cuda.Event(enable_timing=True)
#end = torch.cuda.Event(enable_timing=True)
#start.record()
nind,ndata = get_node_all_to_all_route(self.graph,self.mailbox,dist_nid,out_device=self.device)
eind,edata = get_edge_all_to_all_route(self.graph,dist_eid,out_device=self.device)
if nind is not None:
node_feat = DistributedTensor.all_to_all_get_data(ndata,send_ptr=nind['send_ptr'],recv_ptr=nind['recv_ptr'],is_async=True)
else:
node_feat = None
if eind is not None:
edge_feat = DistributedTensor.all_to_all_get_data(edata,send_ptr=eind['send_ptr'],recv_ptr=eind['recv_ptr'],is_async=True)
else:
edge_feat = None
if self.ada_param is not None:
self.ada_param.last_start_event_fetch = self.ada_param.start_event()
t3 = time.time()
self.result_queue.append((batch_data,dist_nid,dist_eid,edge_feat,node_feat))
self.submit()
......@@ -356,10 +364,12 @@ class DistributedDataLoader:
#batch_data[1][0][0].srcdata['mem'][mask] = self.mailbox.historical_cache.local_historical_data[DistIndex(id).loc[mask]]
self.recv_idxs += 1
else:
if(self.recv_idxs < self.expected_idx):
assert len(self.result_queue) > 0
#print(len(self.result_queue[0]))
if isinstance(self.result_queue[0],tuple) :
t0 = time.time()
batch_data,dist_nid,dist_eid,edge_feat,node_feat0 = self.result_queue[0]
self.result_queue.popleft()
......@@ -382,6 +392,8 @@ class DistributedDataLoader:
node_feat0 = node_feat0[0]
node_feat = None
mem = self.mailbox.unpack(node_feat0,mailbox = True)
if self.ada_param is not None:
self.ada_param.update_fetch_time(self.ada_param.last_start_event_fetch)
#print(node_feat.shape,edge_feat.shape,mem[0].shape)
#node_feat[1].wait()
#node_feat = node_feat[0]
......@@ -395,6 +407,8 @@ class DistributedDataLoader:
#mem = (mem[0][0],mem[1][0],mem[2][0],mem[3][0])
#node_feat,mem = get_node_feature_by_dist(self.graph,self.mailbox, dist_nid,is_local,out_device=self.device)
t1 = time.time()
#if self.ada_param is not None:
# self.ada_param.update_fetch_time(self.ada_param.last_start_event_fetch)
else:
batch_data,dist_nid,dist_eid = self.result_queue[0].result()
stream.synchronize()
......@@ -410,6 +424,7 @@ class DistributedDataLoader:
indx = self.mailbox.is_shared_mask[DistIndex(batch_data[1][0][0].srcdata['ID']).loc[mask]]
batch_data[1][0][0].srcdata['his_mem'][mask] = self.mailbox.historical_cache.local_historical_data[indx]
batch_data[1][0][0].srcdata['his_ts'][mask] = self.mailbox.historical_cache.local_ts[indx].reshape(-1,1)
self.recv_idxs += 1
else:
raise StopIteration
......
......@@ -59,8 +59,9 @@ class SharedMailBox():
uvm = False,
use_pin = False,
start_historical = False,
shared_ssim = 2):
ada_param = None):
ctx = distributed.context._get_default_dist_context()
self.ada_param = ada_param
self.device = device
self.num_nodes = num_nodes
self.num_parts = dist.get_world_size()
......@@ -106,7 +107,7 @@ class SharedMailBox():
device=torch.device('cuda:{}'.format(ctx.local_rank)))
if start_historical:
self.historical_cache = historical_cache.HistoricalCache(self.shared_nodes_index,0,self.node_memory.shape[1],self.node_memory.dtype,self.node_memory.device,threshold=shared_ssim)
self.historical_cache = historical_cache.HistoricalCache(self.shared_nodes_index,0,self.node_memory.shape[1],self.node_memory.dtype,self.node_memory.device,ada_param)
else:
self.historical_cache = None
self._mem_pin = {}
......@@ -294,6 +295,8 @@ class SharedMailBox():
if self.next_wait_gather_memory_job is not None:
shared_list,mem,shared_id_list,shared_memory_ind = self.next_wait_gather_memory_job
self.next_wait_gather_memory_job = None
if self.ada_param.training is True:
self.ada_param.last_start_event_memory_sync = self.ada_param.start_event()
handle0 = dist.all_gather(shared_list,mem,group=ctx.memory_nccl_group,async_op=True)
handle1 = dist.all_gather(shared_id_list,shared_memory_ind,group=ctx.memory_nccl_group,async_op=True)
self.historical_cache.add_shared_to_queue(handle0,handle1,shared_id_list,shared_list)
......
......@@ -22,7 +22,7 @@ class LocalNegativeSampling(NegativeSampling):
dst_node_list: torch.Tensor = None,
local_mask = None,
seed = None,
prob = None
ada_param = None
):
super(LocalNegativeSampling,self).__init__(mode,amount,unique=unique)
self.src_node_list = src_node_list.to('cpu') if src_node_list is not None else None
......@@ -38,7 +38,11 @@ class LocalNegativeSampling(NegativeSampling):
self.local_mask = local_mask
if self.local_mask is not None:
self.local_dst = dst_node_list[local_mask]
self.prob = prob
self.ada_param = ada_param
self.remote_weight = None
self.local_weight = None
#self.prob = prob
#self.rdm.manual_seed(42)
#print('dst_nde_list {}\n'.format(dst_node_list))
def is_binary(self) -> bool:
......@@ -50,6 +54,7 @@ class LocalNegativeSampling(NegativeSampling):
def sample(self, num_samples: int,
num_nodes: Optional[int] = None) -> Tensor:
r"""Generates :obj:`num_samples` negative samples."""
if self.is_binary():
if self.src_node_list is None or self.dst_node_list is None:
return torch.randint(num_nodes, (num_samples, )),torch.randint(num_nodes, (num_samples, ))
......@@ -60,10 +65,16 @@ class LocalNegativeSampling(NegativeSampling):
if self.dst_node_list is None:
return torch.randint(num_nodes, (num_samples, ),generator=self.rdm)
elif self.local_mask is not None:
prob = self.ada_param.beta
remote_ratio = self.local_dst.shape[0] / self.dst_node_list.shape[0]
#train_ratio_pos = (1 - args.probability) + args.probability * remote_ratio
#train_ratio_neg = args.probability * (1-remote_ratio)
self.train_ratio_pos = 1.0/(1-prob+ prob * remote_ratio) if ((prob<1) & (prob > 0)) else 1
self.train_ratio_neg = 1.0/(prob*remote_ratio) if ((prob <1) & (prob > 0)) else 1
p = torch.rand(size=(num_samples,))
sr = self.dst_node_list[torch.randint(len(self.dst_node_list), (num_samples, ),generator=self.rdm)]
sl = self.local_dst[torch.randint(len(self.local_dst), (num_samples, ),generator=self.rdm)]
s=torch.where(p<=self.prob,sr,sl)
s=torch.where(p<=prob,sr,sl)
return s
else:
s = torch.randint(len(self.dst_node_list), (num_samples, ),generator=self.rdm)
......
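In the sample() hunk above, a negative destination is drawn from the full destination list with probability beta (taken from ada_param) and from the local-partition destinations otherwise; train_ratio_pos and train_ratio_neg are the corresponding reweighting factors. A short, self-contained check of that rule, where the node counts are placeholders rather than dataset values:

import torch

beta = 0.4                               # stands in for ada_param.beta
num_samples = 8
dst_node_list = torch.arange(100)        # all destination nodes (placeholder)
local_dst = torch.arange(30)             # destinations on this partition (placeholder)

p = torch.rand(num_samples)
sr = dst_node_list[torch.randint(len(dst_node_list), (num_samples,))]
sl = local_dst[torch.randint(len(local_dst), (num_samples,))]
neg = torch.where(p <= beta, sr, sl)     # same selection as in sample()

remote_ratio = local_dst.shape[0] / dst_node_list.shape[0]
ratio_pos = 1.0 / (1 - beta + beta * remote_ratio)   # train_ratio_pos
ratio_neg = 1.0 / (beta * remote_ratio)              # train_ratio_neg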
......@@ -96,8 +96,8 @@ class NeighborSampler(BaseSampler):
local_part = -1,
node_part = None,
edge_part = None,
probability = 1,
no_neg = False,
ada_param = None
) -> None:
r"""__init__
Args:
......@@ -140,9 +140,9 @@ class NeighborSampler(BaseSampler):
else:
assert tnb is not None
self.tnb = tnb
self.ada_param = ada_param
self.p_sampler = starrygl.sampler_ops.ParallelSampler(self.tnb, num_nodes, graph_data.num_edges, workers,
fanout, num_layers, policy, local_part,edge_part.to(torch.int),node_part.to(torch.int),probability)
fanout, num_layers, policy, local_part,edge_part.to(torch.int),node_part.to(torch.int),ada_param.beta)
def _get_sample_info(self):
return self.num_nodes,self.num_layers,self.fanout,self.workers
......@@ -229,7 +229,7 @@ class NeighborSampler(BaseSampler):
sampled_edge_index_list: the edge sampled
"""
if self.policy != 'identity':
self.p_sampler.neighbor_sample_from_nodes(nodes.contiguous(), ts.contiguous(), None)
self.p_sampler.neighbor_sample_from_nodes(nodes.contiguous(), ts.contiguous(), None, self.ada_param.beta if self.ada_param is not None else None)
ret = self.p_sampler.get_ret()
else:
ret = None
......
theta.png (image, 55.6 KB)
topk.png (image, 61.2 KB)