Commit e0e37a11 by zlj

add time count

parent 572f53f7
alpha.png (binary file, 83.5 KB)

@@ -25,12 +25,12 @@ class ParallelSampler
     th::Tensor eid_inv;
     th::Tensor unq_id;
     th::Tensor first_block_id;
-    double boundery_probility;
+    double init_boundary_probability;
     vector<unsigned int> loc_seeds;
     ParallelSampler(TemporalNeighborBlock& _tnb, NodeIDType _num_nodes, EdgeIDType _num_edges, int _threads,
                     vector<int>& _fanouts, int _num_layers, string _policy, int _local_part, th::Tensor _part, th::Tensor _node_part,double _p) :
                     tnb(_tnb), num_nodes(_num_nodes), num_edges(_num_edges), threads(_threads),
-                    fanouts(_fanouts), num_layers(_num_layers), policy(_policy), local_part(_local_part), boundery_probility(_p)
+                    fanouts(_fanouts), num_layers(_num_layers), policy(_policy), local_part(_local_part), init_boundary_probability(_p)
     {
         omp_set_num_threads(_threads);
         ret.clear();
@@ -52,12 +52,12 @@ class ParallelSampler
         ret.resize(num_layers);
     }
-    void neighbor_sample_from_nodes(th::Tensor nodes, optional<th::Tensor> root_ts, optional<bool> part_unique);
+    void neighbor_sample_from_nodes(th::Tensor nodes, optional<th::Tensor> root_ts, optional<bool> part_unique,optional<double> boundary_probability);
     void neighbor_sample_from_nodes_static(th::Tensor nodes, bool part_unique);
     void neighbor_sample_from_nodes_static_layer(th::Tensor nodes, int cur_layer, bool part_unique);
-    void neighbor_sample_from_nodes_with_before(th::Tensor nodes, th::Tensor root_ts);
+    void neighbor_sample_from_nodes_with_before(th::Tensor nodes, th::Tensor root_ts,double boundary_probability);
     void neighbor_sample_from_dynamic_nodes(th::Tensor nodes, th::Tensor root_ts);
-    void neighbor_sample_from_nodes_with_before_layer(th::Tensor nodes, th::Tensor root_ts, int cur_layer);
+    void neighbor_sample_from_nodes_with_before_layer(th::Tensor nodes, th::Tensor root_ts, int cur_layer,double boundary_probability);
     template<typename T>
     void union_to_vector(vector<T> *p, vector<T> &to_vec);
     void sample_unique( th::Tensor seed, th::Tensor seed_ts,
@@ -66,7 +66,7 @@ class ParallelSampler
-void ParallelSampler :: neighbor_sample_from_nodes(th::Tensor nodes, optional<th::Tensor> root_ts, optional<bool> part_unique)
+void ParallelSampler :: neighbor_sample_from_nodes(th::Tensor nodes, optional<th::Tensor> root_ts, optional<bool> part_unique,optional<double> boundary_probability)
 {
     omp_set_num_threads(threads);
     if(policy == "weighted")
@@ -81,7 +81,7 @@ void ParallelSampler :: neighbor_sample_from_nodes(th::Tensor nodes, optional<th
         AT_ASSERTM(tnb.with_timestamp, "Tnb has no timestamp infomation!");
         AT_ASSERTM(root_ts.has_value(), "Parameter mismatch!");
         //neighbor_sample_from_dynamic_nodes(nodes,root_ts.value());
-        neighbor_sample_from_nodes_with_before(nodes, root_ts.value());
+        neighbor_sample_from_nodes_with_before(nodes, root_ts.value(), boundary_probability.has_value()?boundary_probability.value():0);
     }
     else{
         bool flag = part_unique.has_value() ? part_unique.value() : true;
@@ -211,7 +211,7 @@ void ParallelSampler :: neighbor_sample_from_nodes_static(th::Tensor nodes, bool
 }
 void ParallelSampler :: neighbor_sample_from_nodes_with_before_layer(
-        th::Tensor nodes, th::Tensor root_ts, int cur_layer){
+        th::Tensor nodes, th::Tensor root_ts, int cur_layer, double boundary_probability){
     py::gil_scoped_release release;
     double tot_start_time = omp_get_wtime();
     ret[cur_layer] = TemporalGraphBlock();
@@ -267,7 +267,7 @@ void ParallelSampler :: neighbor_sample_from_nodes_with_before_layer(
                 int eid = tnb.eid[node][cid];
                 if(part[tnb.eid[node][cid]] != local_part|| node_part[tnb.neighbors[node][cid]]!= local_part){
                     double p0 = (double)rand_r(&loc_seeds[tid]) / (RAND_MAX + 1.0);
-                    if(p0 > boundery_probility)continue;
+                    if(p0 > boundary_probability)continue;
                 }
                 tgb_i[tid].src_index.emplace_back(i);
                 tgb_i[tid].sample_nodes.emplace_back(tnb.neighbors[node][cid]);
@@ -306,7 +306,7 @@ void ParallelSampler :: neighbor_sample_from_nodes_with_before_layer(
                 if((part[tnb.eid[node][cid]] != local_part|| node_part[tnb.neighbors[node][cid]]!= local_part)){
                     if(cal_cnt<=fanout){
                         double p0 = (double)rand_r(&loc_seeds[tid]) / (RAND_MAX + 1.0);
-                        double ep = boundery_probility*pr[cal_cnt-1]/sum_p*sum_1;
+                        double ep = boundary_probability*pr[cal_cnt-1]/sum_p*sum_1;
                         if(p0 > ep)continue;
                         //tgb_i[tid].sample_weight.emplace_back((float)ep);
                     }
@@ -389,11 +389,11 @@ void ParallelSampler :: neighbor_sample_from_nodes_with_before_layer(
     py::gil_scoped_acquire acquire;
 }
-void ParallelSampler :: neighbor_sample_from_nodes_with_before(th::Tensor nodes, th::Tensor root_ts){
+void ParallelSampler :: neighbor_sample_from_nodes_with_before(th::Tensor nodes, th::Tensor root_ts, double boundary_probability){
     for(int i=0;i<num_layers;i++){
-        if(i==0) neighbor_sample_from_nodes_with_before_layer(nodes, root_ts, i);
+        if(i==0) neighbor_sample_from_nodes_with_before_layer(nodes, root_ts, i, boundary_probability);
         else neighbor_sample_from_nodes_with_before_layer(vecToTensor<NodeIDType>(ret[i-1].sample_nodes),
-                                                          vecToTensor<TimeStampType>(ret[i-1].sample_nodes_ts), i);
+                                                          vecToTensor<TimeStampType>(ret[i-1].sample_nodes_ts), i, boundary_probability);
     }
 }
...
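The boundary_probability argument threaded through the sampler above controls how often cross-partition candidates survive: when an edge or its endpoint is owned by another partition, the sampler draws p0 uniformly and keeps the candidate only if p0 does not exceed the probability (the weighted branch additionally rescales by the candidate's sampling weight). A minimal Python sketch of the uniform branch, using plain lists and dicts instead of the sampler's internal tnb structures; the function and variable names here are illustrative, not part of the codebase:

import random

def sample_neighbors_with_boundary(neighbors, edge_part, node_part, local_part,
                                   boundary_probability, rng=random.random):
    # Keep every local neighbor; keep a cross-partition neighbor only with
    # probability `boundary_probability` (mirrors the `p0 > boundary_probability`
    # check in ParallelSampler). `neighbors` is a list of (neighbor_id, edge_id).
    kept = []
    for nid, eid in neighbors:
        if edge_part[eid] != local_part or node_part[nid] != local_part:
            if rng() > boundary_probability:   # remote candidate rejected
                continue
        kept.append((nid, eid))
    return kept

# toy usage: partition 0 keeps remote neighbors only ~30% of the time
neighbors = [(0, 0), (1, 1), (2, 2)]
edge_part = {0: 0, 1: 1, 2: 0}
node_part = {0: 0, 1: 1, 2: 1}
print(sample_neighbors_with_boundary(neighbors, edge_part, node_part,
                                     local_part=0, boundary_probability=0.3))
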
import matplotlib.pyplot as plt
# Data
alpha_values = [0.1, 0.3, 0.5, 0.7, 0.9, 1.3, 1.5, 1.7, 2]
lastfm = [0.933538, 0.931761, 0.932499, 0.933383, 0.931225, 0.929983, 0.933971, 0.928771, 0.932748]
wikitalk = [0.979627, 0.97997, 0.979484, 0.980269, 0.980758, 0.97979, 0.980233, 0.980004, 0.980353]
stackoverflow = [0.979641, 0.979372, 0.97967, 0.978169, 0.979624, 0.978846, 0.978428, 0.978397, 0.978925]
# Create the figure and subplots
fig, axs = plt.subplots(1,3, figsize=(15,4))
# Global title (disabled)
#fig.suptitle('Accuracy line chart', fontsize=24, fontweight='bold')
# Plot the LASTFM curve
axs[0].plot(alpha_values, lastfm, marker='o', linestyle='-', color='blue', linewidth=3)
axs[0].set_title('LASTFM', fontsize=20, fontweight='bold')
axs[0].set_xlabel('α', fontsize=16)
axs[0].set_ylabel('Test AP', fontsize=16)
axs[0].grid(True)
# Plot the WikiTalk curve
axs[1].plot(alpha_values, wikitalk, marker='o', linestyle='-', color='green', linewidth=3)
axs[1].set_title('WikiTalk', fontsize=20, fontweight='bold')
axs[1].set_xlabel('α', fontsize=16)
axs[1].set_ylabel('Test AP', fontsize=16)
axs[1].grid(True)
# Plot the StackOverflow curve
axs[2].plot(alpha_values, stackoverflow, marker='o', linestyle='-', color='red', linewidth=3)
axs[2].set_title('StackOverflow', fontsize=20, fontweight='bold')
axs[2].set_xlabel('α', fontsize=16)
axs[2].set_ylabel('Test AP', fontsize=16)
axs[2].grid(True)
plt.tight_layout()
plt.subplots_adjust(top=0.92)
plt.savefig('alpha.png')
import matplotlib.pyplot as plt
# Data
theta_values = [0, 0.01, 0.05, 0.1, 0.2, 0.4, 0.6, 0.8, 1]
lastfm = [0.813412, 0.920839, 0.92926, 0.930115, 0.930719, 0.923493, 0.927908, 0.925335, 0.896277]
wikitalk = [0.967408, 0.979705, 0.979805, 0.97515, 0.979248, 0.979162, 0.978207, 0.977578, 0.975614]
stackoverflow = [0.919117667, 0.9778155, 0.978890667, 0.979456, 0.97914, 0.980143, 0.980606, 0.980387, 0.977536]
# Create the figure and subplots
fig, axs = plt.subplots(1,3, figsize=(15, 4))
# Global title (disabled)
#fig.suptitle('Accuracy line chart', fontsize=24, fontweight='bold')
# Plot the LASTFM curve
axs[0].plot(theta_values, lastfm, marker='o', linestyle='-', color='blue', linewidth=3)
axs[0].set_title('LASTFM', fontsize=20, fontweight='bold')
axs[0].set_xlabel('θ', fontsize=16)
axs[0].set_ylabel('Test AP', fontsize=16)
axs[0].grid(True)
# Plot the WikiTalk curve
axs[1].plot(theta_values, wikitalk, marker='o', linestyle='-', color='green', linewidth=3)
axs[1].set_title('WikiTalk', fontsize=20, fontweight='bold')
axs[1].set_xlabel('θ', fontsize=16)
axs[1].set_ylabel('Test AP', fontsize=16)
axs[1].grid(True)
# Plot the StackOverflow curve
axs[2].plot(theta_values, stackoverflow, marker='o', linestyle='-', color='red', linewidth=3)
axs[2].set_title('StackOverflow', fontsize=20, fontweight='bold')
axs[2].set_xlabel('θ', fontsize=16)
axs[2].set_ylabel('Test AP', fontsize=16)
axs[2].grid(True)
plt.tight_layout()
plt.subplots_adjust(top=0.92)
plt.savefig('theta.png')
\ No newline at end of file
import matplotlib.pyplot as plt
# Data
k_values = [0.02, 0.04, 0.06, 0.08, 0.1, 0.2, 0.3]
lastfm = [0.903726, 0.921197, 0.931237, 0.926789, 0.930719, 0.929332, 0.915848]
wikitalk = [0.981577, 0.980716, 0.979996, 0.979597, 0.979248, 0.975468, 0.972785]
stackoverflow = [0.974805, 0.978219, 0.97924, 0.979436, 0.979456, 0.976746, 0.972544]
# Create the figure and subplots
fig, axs = plt.subplots(1,3, figsize=(15, 4))
# Global title (disabled)
#fig.suptitle('Test AP vs topK', fontsize=24, fontweight='bold')
# Plot the LASTFM curve
axs[0].plot(k_values, lastfm, marker='o', linestyle='-', color='blue', linewidth=3)
axs[0].set_title('LASTFM', fontsize=20, fontweight='bold')
axs[0].set_xlabel('k', fontsize=16)
axs[0].set_ylabel('Test AP', fontsize=16)
axs[0].grid(True)
# Plot the WikiTalk curve
axs[1].plot(k_values, wikitalk, marker='o', linestyle='-', color='green', linewidth=3)
axs[1].set_title('WikiTalk', fontsize=20, fontweight='bold')
axs[1].set_xlabel('k', fontsize=16)
axs[1].set_ylabel('Test AP', fontsize=16)
axs[1].grid(True)
# Plot the StackOverflow curve
axs[2].plot(k_values, stackoverflow, marker='o', linestyle='-', color='red', linewidth=3)
axs[2].set_title('StackOverflow', fontsize=20, fontweight='bold')
axs[2].set_xlabel('k', fontsize=16)
axs[2].set_ylabel('Test AP', fontsize=16)
axs[2].grid(True)
plt.tight_layout()
plt.subplots_adjust(top=0.92)
plt.savefig('topk.png')
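The three plotting scripts above are identical except for the swept hyperparameter (α, θ, k), the data series, and the output file. If the duplication ever needs trimming, one parameterized helper would reproduce the same panels; plot_sweep is a hypothetical name, not an existing utility in this repository:

import matplotlib.pyplot as plt

def plot_sweep(x_values, results, x_label, out_file):
    # One Test-AP panel per dataset, in the same style as the scripts above.
    # `results` maps panel title -> y values; expects at least two datasets
    # so that `axs` is an array of Axes.
    fig, axs = plt.subplots(1, len(results), figsize=(5 * len(results), 4))
    colors = ['blue', 'green', 'red', 'purple', 'orange']
    for ax, (title, ys), color in zip(axs, results.items(), colors):
        ax.plot(x_values, ys, marker='o', linestyle='-', color=color, linewidth=3)
        ax.set_title(title, fontsize=20, fontweight='bold')
        ax.set_xlabel(x_label, fontsize=16)
        ax.set_ylabel('Test AP', fontsize=16)
        ax.grid(True)
    plt.tight_layout()
    plt.subplots_adjust(top=0.92)
    plt.savefig(out_file)

# e.g. plot_sweep([0.02, 0.04], {'LASTFM': [0.90, 0.92], 'WikiTalk': [0.98, 0.98]}, 'k', 'topk.png')
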
@@ -19,7 +19,7 @@ memory_type=("historical")
 #memory_type=("local" "all_update" "historical" "all_reduce")
 shared_memory_ssim=("0.3")
 #data_param=("WIKI" "REDDIT" "LASTFM" "WikiTalk")
-data_param=("WIKI" "LASTFM")
+data_param=("WikiTalk")
 #"GDELT")
 #data_param=("WIKI" "REDDIT" "LASTFM" "DGraphFin" "WikiTalk" "StackOverflow")
 #data_param=("WIKI" "REDDIT" "LASTFM" "WikiTalk" "StackOverflow")
@@ -32,9 +32,9 @@ data_param=("WIKI" "LASTFM")
 #seed=(( RANDOM % 1000000 + 1 ))
 mkdir -p all_"$seed"
 for data in "${data_param[@]}"; do
-    model="TGN_large"
+    model="JODIE_large"
     if [ "$data" = "WIKI" ] || [ "$data" = "REDDIT" ] || [ "$data" = "LASTFM" ]; then
-        model="TGN"
+        model="JODIE"
     fi
     #model="APAN"
     mkdir all_"$seed"/"$data"
...
@@ -238,6 +238,7 @@ def main():
         ada_param = AdaParameter(init_alpha=args.shared_memory_ssim, init_beta=args.probability)
     else:
         policy_train = policy
+        ada_param = AdaParameter(init_alpha=args.shared_memory_ssim, init_beta=args.probability)
     if memory_param['type'] != 'none':
         mailbox = SharedMailBox(graph.ids.shape[0], memory_param, dim_edge_feat = graph.efeat.shape[1] if graph.efeat is not None else 0,
                                 shared_nodes_index=graph.shared_nids_list[ctx.memory_group_rank],device = torch.device('cuda:{}'.format(local_rank)),
...
import torch
from os.path import abspath, join, dirname
import sys
from starrygl.distributed.utils import DistIndex, DistributedTensor
from starrygl.sample.count_static import time_count
sys.path.insert(0, join(abspath(dirname(__file__))))
from layers import *
from memorys import *
import time
forward_time = 0
backward_time = 0
t = [0,0,0,0]
def get_forward_time():
global forward_time
return forward_time
def get_backward_time():
global backward_time
return backward_time
def get_t():
global t
return t
class all_to_all_embedding(torch.autograd.Function):
@staticmethod
def forward(ctx, input,metadata,neg_samples, memory = None, use_emb = False ):
ctx.save_for_backward(input,
metadata['src_pos_index'],
metadata['dst_pos_index'],
metadata['dst_neg_index'],
metadata['dist_src_index'],
metadata['dist_neg_src_index']
)
with torch.no_grad():
out = input
h_pos_src = out[metadata['src_pos_index']]
h_dst_index = torch.cat((metadata['dst_pos_index'],metadata['dst_neg_index']))
#print(h_dst_index)
h_dst_src_index = torch.cat((metadata['dist_src_index'],metadata['dist_neg_src_index']))
#print(h_dst_src_index)
h_dst = DistributedTensor(torch.empty(h_pos_src.shape[0]*(neg_samples+1),
h_pos_src.shape[1],dtype = h_pos_src.dtype,device = h_pos_src.device))
h_data = out[h_dst_index]
dist_index,ind = h_dst_src_index.sort()
h_dst.all_to_all_set(h_data[ind],dist_index)
h_pos_dst = h_dst.accessor.data[:h_pos_src.shape[0],:]
h_neg_dst = h_dst.accessor.data[h_pos_src.shape[0]:,:]
#print(h_pos_dst,h_neg_dst)
"""
h_pos_dst_data = out[metadata['dst_pos_index']]
h_neg_dst_data = out[metadata['dst_neg_index']]
h_pos_dst = DistributedTensor(torch.empty_like(h_pos_src,device = h_pos_src.device))
h_neg_dst = DistributedTensor(torch.empty(h_pos_src.shape[0]*neg_samples,
h_pos_src.shape[1],dtype = h_pos_src.dtype,device = h_pos_src.device))
dist_index,ind = metadata['dist_src_index'].sort()
h_pos_dst.all_to_all_set(h_pos_dst_data[ind],dist_index)
h_pos_dst = h_pos_dst.accessor.data
dist_index0,ind0 = metadata['dist_neg_src_index'].sort()
h_neg_dst.all_to_all_set(h_neg_dst_data[ind0],dist_index0)
h_neg_dst = h_neg_dst.accessor.data
"""
src_mem = None
mem = None
if memory is not None:
local_memory = DistributedTensor(memory[metadata['src_pos_index']])
dst_memory = memory[metadata['dst_pos_index']]
dist_index0,ind0 = metadata['dist_src_index'].sort()
send_ptr = local_memory.all_to_all_ind2ptr(dist_index0)
mem = DistributedTensor(torch.empty_like(local_memory.accessor.data))
mem.all_to_all_set(dst_memory[ind0],**send_ptr)
#print(send_ptr,ind.max(),src_mem,local_memory.shape)
src_mem = local_memory.all_to_all_get(**send_ptr)[ind0]
#print(src_mem)
mem = mem.accessor.data
"""
local_memory = DistributedTensor(memory[metadata['src_pos_index']])
dst_memory = memory[metadata['dst_pos_index']]
mem = DistributedTensor(torch.empty_like(local_memory.accessor.data))
mem.all_to_all_set(dst_memory[ind],dist_index)
src_mem = DistributedTensor(torch.empty_like(dst_memory))
src_mem = local_memory.all_to_all_send(**metadata['dst_send_dict'])
mem = mem.accessor.data
"""
elif use_emb is True:
mem = h_pos_dst
local_embedding = DistributedTensor(h_pos_src)
src_mem = local_embedding.all_to_all_send(**metadata['dst_send_dict'])
#t[2] += t3-t2
#t[3] += t4-t3
return h_pos_src,h_pos_dst,h_neg_dst,mem,src_mem
@staticmethod
def backward(ctx, grad_pos_src,remote_pos_dst,remote_neg_dst,grad0,grad1):
out,src_pos_index,dst_pos_index,dst_neg_index,dist_src_index,dist_neg_src_index = ctx.saved_tensors
with torch.no_grad():
torch.cuda.synchronize()
t0 = time.time()
remote_dst = DistributedTensor(torch.cat((remote_pos_dst,remote_neg_dst),dim = 0))
dist_index,ind = torch.cat((dist_src_index,dist_neg_src_index)).sort()
grad_dst = remote_dst.all_to_all_get(dist_index)
grad_dst[ind] = grad_dst.clone()
grad = torch.empty_like(out)
grad[src_pos_index] = grad_pos_src
grad[dst_pos_index] = grad_dst[:dst_pos_index.shape[0],:]
grad[dst_neg_index] = grad_dst[dst_pos_index.shape[0]:,:]
"""
remote_pos_dst = DistributedTensor(remote_pos_dst)
remote_neg_dst = DistributedTensor(remote_neg_dst)
dist_index,ind = dist_src_index.sort()
grad_pos_dst = remote_pos_dst.all_to_all_get(dist_index)
grad_pos_dst[ind] = grad_pos_dst.clone()
dist_index_neg,ind_neg = dist_neg_src_index.sort()
grad_neg_dst = remote_neg_dst.all_to_all_get(dist_index_neg)
grad_neg_dst[ind_neg] = grad_neg_dst.clone()
grad = torch.empty_like(out)
grad[src_pos_index] = grad_pos_src
grad[dst_pos_index] = grad_pos_dst
grad[dst_neg_index] = grad_neg_dst
"""
torch.cuda.synchronize()
t1 = time.time()
time_count.add_backward_all_to_all(t1 -t0)
return grad,None,None,None,None
class GeneralModel(torch.nn.Module):
def __init__(self, dim_node, dim_edge, sample_param, memory_param, gnn_param, train_param, combined=False, cache_index=None):
super(GeneralModel, self).__init__()
self.dim_node = dim_node
self.dim_node_input = dim_node
self.dim_edge = dim_edge
self.sample_param = sample_param
self.memory_param = memory_param
if not 'dim_out' in gnn_param:
gnn_param['dim_out'] = memory_param['dim_out']
self.gnn_param = gnn_param
self.train_param = train_param
if memory_param['type'] == 'node':
if memory_param['memory_update'] == 'gru':
self.memory_updater = GRUMemeoryUpdater(memory_param, 2 * memory_param['dim_out'] + dim_edge, memory_param['dim_out'], memory_param['dim_time'], dim_node,cache_index)
elif memory_param['memory_update'] == 'rnn':
self.memory_updater = RNNMemeoryUpdater(memory_param, 2 * memory_param['dim_out'] + dim_edge, memory_param['dim_out'], memory_param['dim_time'], dim_node)
elif memory_param['memory_update'] == 'transformer':
self.memory_updater = TransformerMemoryUpdater(memory_param, 2 * memory_param['dim_out'] + dim_edge, memory_param['dim_out'], memory_param['dim_time'], train_param)
else:
raise NotImplementedError
self.dim_node_input = memory_param['dim_out']
self.layers = torch.nn.ModuleDict()
if gnn_param['arch'] == 'transformer_attention':
for h in range(sample_param['history']):
self.layers['l0h' + str(h)] = TransfomerAttentionLayer(self.dim_node_input, dim_edge, gnn_param['dim_time'], gnn_param['att_head'], train_param['dropout'], train_param['att_dropout'], gnn_param['dim_out'], combined=combined)
for l in range(1, gnn_param['layer']):
for h in range(sample_param['history']):
self.layers['l' + str(l) + 'h' + str(h)] = TransfomerAttentionLayer(gnn_param['dim_out'], dim_edge, gnn_param['dim_time'], gnn_param['att_head'], train_param['dropout'], train_param['att_dropout'], gnn_param['dim_out'], combined=False)
elif gnn_param['arch'] == 'identity':
self.gnn_param['layer'] = 1
for h in range(sample_param['history']):
self.layers['l0h' + str(h)] = IdentityNormLayer(self.dim_node_input)
if 'time_transform' in gnn_param and gnn_param['time_transform'] == 'JODIE':
self.layers['l0h' + str(h) + 't'] = JODIETimeEmbedding(gnn_param['dim_out'])
else:
raise NotImplementedError
if 'historical' in self.gnn_param and self.gnn_param['historical'] == True:
self.historical_cache = {}
for l in range(0,gnn_param['layer']):
self.historical_cache['l'+str(l)] = HistoricalCache(cache_index,shape=(cache_index.size,gnn_param['dim_out']),dtype = torch.float,device = torch.device('cuda'))
self.all_to_all_embedding = all_to_all_embedding.apply
self.edge_predictor = EdgePredictor(gnn_param['dim_out'])
self.history_embedding_time = TimeEncode(gnn_param['dim_out'])
if 'combine' in gnn_param and gnn_param['combine'] == 'rnn':
self.combiner = torch.nn.RNN(gnn_param['dim_out'], gnn_param['dim_out'])
def empty_cache(self):
if 'historical' in self.gnn_param and self.gnn_param['historical'] == True:
for l in range(0,self.gnn_param['layer']):
self.historical_cache['l'+str(l)].empty()
if self.memory_param['type'] == 'node':
self.memory_updater.empty_cache()
def get_sub_block(self,block,src_index):
all_dst_nodes = src_index
src,dst,eids = block.in_edges(src_index,form = 'all')
unq,ind = torch.cat((src,dst,all_dst_nodes)).unique(return_inverse=True)
subgraph = dgl.create_block((ind[:src.shape[0]],ind[src.shape[0]:src.shape[0]+dst.shape[0]]),
num_src_nodes = int(unq.shape[0]),
num_dst_nodes = int(ind[src.shape[0]:].max().item())+1,
device = block.device)
for k in block.srcdata:
subgraph.srcdata[k] = block.srcdata[k][unq]
for k in block.edata:
subgraph.edata[k] = block.edata[k][eids]
return subgraph
def forward(self, mfgs, metadata = None,neg_samples=1, mode = 'triplet'):
torch.cuda.synchronize()
t0 = time.time()
if(metadata['src_pos_index'].shape[0] == 0 or metadata['dst_pos_index'].shape[0] == 0 or metadata['dst_neg_index'].shape[0] == 0):
print(metadata['src_pos_index'].shape,metadata['dst_pos_index'].shape,metadata['dst_neg_index'].shape)
if self.memory_param['type'] == 'node':
self.memory_updater(mfgs[0])
out = list()
for l in range(self.gnn_param['layer']):
for h in range(self.sample_param['history']):
"""
g = mfgs[l][h]
subgraph0 = self.get_sub_block(g,metadata['src_pos_index'])
subgraph1 = self.get_sub_block(g,metadata['dst_pos_index'])
subgraph2 = self.get_sub_block(g,metadata['dst_neg_index'])
rst0 = self.layers['l' + str(l) + 'h' + str(h)](subgraph0)
rst1 = self.layers['l' + str(l) + 'h' + str(h)](subgraph1)
rst2 = self.layers['l' + str(l) + 'h' + str(h)](subgraph2)
rst = torch.empty(mfgs[l][h].num_dst_nodes(),rst0.shape[1],device = rst0.device,dtype = rst0.dtype)
rst[metadata['src_pos_index']] = rst0
rst[metadata['dst_pos_index']] = rst1
rst[metadata['dst_neg_index']] = rst2
"""
rst = self.layers['l' + str(l) + 'h' + str(h)](mfgs[l][h])
if 'time_transform' in self.gnn_param and self.gnn_param['time_transform'] == 'JODIE':
rst = self.layers['l0h' + str(h) + 't'](rst, mfgs[l][h].srcdata['mem_ts'], mfgs[l][h].srcdata['ts'])
if l != self.gnn_param['layer'] - 1:
if 'historical' in self.gnn_param and self.gnn_param['historical'] is True:
local_mask = DistIndex(mfgs[l+1][h].srcdata['ID']).part == torch.distributed.get_rank()
with torch.no_grad():
historical_embedding = self.historical_cache['l'+str(l)].get_data(mfgs[l+1][h].srcdata['_ID'][~local_mask])
historical_ts = self.historical_cache['l'+str(l)].get_ts(mfgs[l+1][h].srcdata['_ID'][~local_mask])
self.historical_cache['l'+str(l)].update(mfgs[l+1][h].srcdata['_ID'][local_mask],rst[local_mask],mfgs[l+1][h].srcdata['ts'][:mfgs[l][h].num_dst_nodes()][local_mask])
rst[~local_mask] = historical_embedding #+ self.history_embedding_time(historical_ts)
mfgs[l + 1][h].srcdata['h'] = rst
else:
##test using historical embedding for remote nodes
#if 'historical' in self.gnn_param and self.gnn_param['historical'] is True:
# local_mask = DistIndex(mfgs[l][h].srcdata['ID'][:mfgs[l][h].num_dst_nodes()]).part == torch.distributed.get_rank()
# with torch.no_grad():
# history_embedding = self.historical_cache['l'+str(l)].get_data(mfgs[l][h].srcdata['_ID'][:mfgs[l][h].num_dst_nodes()][~local_mask])
# self.historical_cache['l'+str(l)].update(mfgs[l][h].srcdata['_ID'][:mfgs[l][h].num_dst_nodes()][local_mask],rst[local_mask],mfgs[l][h].srcdata#['ts'][:mfgs[l][h].num_dst_nodes()][local_mask])
# historical_ts = self.historical_cache['l'+str(l)].get_ts(mfgs[l][h].srcdata['_ID'][:mfgs[l][h].num_dst_nodes()][~local_mask])
# rst[~local_mask] = history_embedding + self.history_embedding_time(historical_ts)
#######
out.append(rst)
if self.sample_param['history'] == 1:
out = out[0]
else:
out = torch.stack(out, dim=0)
out = self.combiner(out)[0][-1, :, :]
#metadata: the ids need to be recorded during the earlier deduplication step
if self.gnn_param['use_src_emb'] or self.gnn_param['use_dst_emb']:
self.embedding = out.detach().clone()
else:
self.embedding = None
if self.gnn_param['dyrep']:
out = self.memory_updater.last_updated_memory
self.out = out
torch.cuda.synchronize()
t1 = time.time()
if metadata is not None:
#out = torch.cat((out[metadata['dst_pos_pos']],out[metadata['src_id_pos']],out[metadata['dst_neg_pos']]),0)
if 'dist_src_index' not in metadata:
h_pos_src = out[metadata['src_pos_index']]
h_pos_dst = out[metadata['dst_pos_index']]
h_neg_dst = out[metadata['dst_neg_index']]
if 'src_neg_index' in metadata:
h_neg_src = out[metadata['src_neg_index']]
return self.edge_predictor(h_pos_src, h_pos_dst, h_neg_src, h_neg_dst, neg_samples=neg_samples, mode = mode)
else:
return self.edge_predictor(h_pos_src, h_pos_dst, None , h_neg_dst, neg_samples=neg_samples, mode = mode)
else:
if self.memory_param['type'] == 'node':
h_pos_src,h_pos_dst,h_neg_dst,mem,src_mem = self.all_to_all_embedding(out,metadata,neg_samples,self.memory_updater.last_updated_memory,self.gnn_param['use_src_emb'])
self.dst_mem = mem.detach().clone()
self.src_mem = src_mem.detach().clone()
else:
h_pos_src,h_pos_dst,h_neg_dst,mem,src_mem = self.all_to_all_embedding(out,metadata,neg_samples,None,self.gnn_param['use_src_emb'])
torch.cuda.synchronize()
t2 = time.time()
time_count.add_train_forward_embedding(t1-t0)
time_count.add_train_foward_all_to_all(t2-t1)
return self.edge_predictor(h_pos_src, h_pos_dst, None, h_neg_dst, neg_samples=neg_samples, mode = mode)
else:
return out
class NodeClassificationModel(torch.nn.Module):
def __init__(self, dim_in, dim_hid, num_class):
super(NodeClassificationModel, self).__init__()
self.fc1 = torch.nn.Linear(dim_in, dim_hid)
self.fc2 = torch.nn.Linear(dim_hid, num_class)
def forward(self, x):
x = self.fc1(x)
x = torch.nn.functional.relu(x)
x = self.fc2(x)
return x
\ No newline at end of file
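Both model variants in this commit time the embedding forward pass, the all-to-all exchange, and the backward all-to-all through time_count from starrygl.sample.count_static, which is not shown here. A minimal sketch of an accumulator exposing the interface these call sites assume follows; the attribute and method names are inferred from the calls above, so the real module may differ:

# Hypothetical sketch of a time accumulator compatible with the call sites above
# (add_train_forward_embedding, add_train_foward_all_to_all, add_backward_all_to_all);
# the actual starrygl.sample.count_static.time_count may be implemented differently.
class _TimeCount:
    def __init__(self):
        self.train_forward_embedding = 0.0
        self.train_foward_all_to_all = 0.0   # spelling kept to match the call sites
        self.backward_all_to_all = 0.0

    def add_train_forward_embedding(self, dt):
        self.train_forward_embedding += dt

    def add_train_foward_all_to_all(self, dt):
        self.train_foward_all_to_all += dt

    def add_backward_all_to_all(self, dt):
        self.backward_all_to_all += dt

    def summary(self):
        # accumulated seconds per phase
        return {
            'forward_embedding': self.train_forward_embedding,
            'forward_all_to_all': self.train_foward_all_to_all,
            'backward_all_to_all': self.backward_all_to_all,
        }

time_count = _TimeCount()
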
import torch
from os.path import abspath, join, dirname
import sys
from starrygl.distributed.utils import DistIndex, DistributedTensor
from starrygl.sample.count_static import time_count
sys.path.insert(0, join(abspath(dirname(__file__))))
from layers import *
from memorys import *
import time
import concurrent.futures
forward_time = 0
backward_time = 0
t = [0,0,0,0]
def get_forward_time():
global forward_time
return forward_time
def get_backward_time():
global backward_time
return backward_time
def get_t():
global t
return t
class all_to_all_embedding(torch.autograd.Function):
@staticmethod
def forward(ctx, input,metadata,neg_samples, memory = None, use_emb = False ):
ctx.save_for_backward(input,
metadata['src_pos_index'],
metadata['dst_pos_index'],
metadata['dst_neg_index'],
metadata['dist_src_index'],
metadata['dist_neg_src_index']
)
with torch.no_grad():
out = input
h_pos_src = out[metadata['src_pos_index']]
h_dst_index = torch.cat((metadata['dst_pos_index'],metadata['dst_neg_index']))
#print(h_dst_index)
h_dst_src_index = torch.cat((metadata['dist_src_index'],metadata['dist_neg_src_index']))
#print(h_dst_src_index)
h_dst = DistributedTensor(torch.empty(h_pos_src.shape[0]*(neg_samples+1),
h_pos_src.shape[1],dtype = h_pos_src.dtype,device = h_pos_src.device))
h_data = out[h_dst_index]
dist_index,ind = h_dst_src_index.sort()
h_dst.all_to_all_set(h_data[ind],dist_index)
h_pos_dst = h_dst.accessor.data[:h_pos_src.shape[0],:]
h_neg_dst = h_dst.accessor.data[h_pos_src.shape[0]:,:]
#print(h_pos_dst,h_neg_dst)
"""
h_pos_dst_data = out[metadata['dst_pos_index']]
h_neg_dst_data = out[metadata['dst_neg_index']]
h_pos_dst = DistributedTensor(torch.empty_like(h_pos_src,device = h_pos_src.device))
h_neg_dst = DistributedTensor(torch.empty(h_pos_src.shape[0]*neg_samples,
h_pos_src.shape[1],dtype = h_pos_src.dtype,device = h_pos_src.device))
dist_index,ind = metadata['dist_src_index'].sort()
h_pos_dst.all_to_all_set(h_pos_dst_data[ind],dist_index)
h_pos_dst = h_pos_dst.accessor.data
dist_index0,ind0 = metadata['dist_neg_src_index'].sort()
h_neg_dst.all_to_all_set(h_neg_dst_data[ind0],dist_index0)
h_neg_dst = h_neg_dst.accessor.data
"""
src_mem = None
mem = None
if memory is not None:
local_memory = DistributedTensor(memory[metadata['src_pos_index']])
dst_memory = memory[metadata['dst_pos_index']]
dist_index0,ind0 = metadata['dist_src_index'].sort()
send_ptr = local_memory.all_to_all_ind2ptr(dist_index0)
mem = DistributedTensor(torch.empty_like(local_memory.accessor.data))
mem.all_to_all_set(dst_memory[ind0],**send_ptr)
#print(send_ptr,ind.max(),src_mem,local_memory.shape)
src_mem = local_memory.all_to_all_get(**send_ptr)[ind0]
#print(src_mem)
mem = mem.accessor.data
"""
local_memory = DistributedTensor(memory[metadata['src_pos_index']])
dst_memory = memory[metadata['dst_pos_index']]
mem = DistributedTensor(torch.empty_like(local_memory.accessor.data))
mem.all_to_all_set(dst_memory[ind],dist_index)
src_mem = DistributedTensor(torch.empty_like(dst_memory))
src_mem = local_memory.all_to_all_send(**metadata['dst_send_dict'])
mem = mem.accessor.data
"""
if use_emb is True:
mem = h_pos_dst
local_embedding = DistributedTensor(h_pos_src)
src_mem = local_embedding.all_to_all_send(**metadata['dst_send_dict'])
#t[2] += t3-t2
#t[3] += t4-t3
return h_pos_src,h_pos_dst,h_neg_dst,mem,src_mem
@staticmethod
def backward(ctx, grad_pos_src,remote_pos_dst,remote_neg_dst,grad0,grad1):
out,src_pos_index,dst_pos_index,dst_neg_index,dist_src_index,dist_neg_src_index = ctx.saved_tensors
with torch.no_grad():
torch.cuda.synchronize()
t0 = time.time()
remote_dst = DistributedTensor(torch.cat((remote_pos_dst,remote_neg_dst),dim = 0))
dist_index,ind = torch.cat((dist_src_index,dist_neg_src_index)).sort()
grad_dst = remote_dst.all_to_all_get(dist_index)
grad_dst[ind] = grad_dst.clone()
grad = torch.empty_like(out)
grad[src_pos_index] = grad_pos_src
grad[dst_pos_index] = grad_dst[:dst_pos_index.shape[0],:]
grad[dst_neg_index] = grad_dst[dst_pos_index.shape[0]:,:]
"""
remote_pos_dst = DistributedTensor(remote_pos_dst)
remote_neg_dst = DistributedTensor(remote_neg_dst)
dist_index,ind = dist_src_index.sort()
grad_pos_dst = remote_pos_dst.all_to_all_get(dist_index)
grad_pos_dst[ind] = grad_pos_dst.clone()
dist_index_neg,ind_neg = dist_neg_src_index.sort()
grad_neg_dst = remote_neg_dst.all_to_all_get(dist_index_neg)
grad_neg_dst[ind_neg] = grad_neg_dst.clone()
grad = torch.empty_like(out)
grad[src_pos_index] = grad_pos_src
grad[dst_pos_index] = grad_pos_dst
grad[dst_neg_index] = grad_neg_dst
"""
torch.cuda.synchronize()
t1 = time.time()
time_count.add_backward_all_to_all(t1 -t0)
return grad,None,None,None,None
executor = concurrent.futures.ThreadPoolExecutor(1)
class GeneralModel(torch.nn.Module):
def __init__(self, dim_node, dim_edge, sample_param, memory_param, gnn_param, train_param, combined=False, cache_index=None):
super(GeneralModel, self).__init__()
self.dim_node = dim_node
self.dim_node_input = dim_node
self.dim_edge = dim_edge
self.sample_param = sample_param
self.memory_param = memory_param
if not 'dim_out' in gnn_param:
gnn_param['dim_out'] = memory_param['dim_out']
self.gnn_param = gnn_param
self.train_param = train_param
if memory_param['type'] == 'node':
if memory_param['memory_update'] == 'gru':
self.memory_updater = GRUMemeoryUpdater(memory_param, 2 * memory_param['dim_out'] + dim_edge, memory_param['dim_out'], memory_param['dim_time'], dim_node,cache_index)
elif memory_param['memory_update'] == 'rnn':
self.memory_updater = RNNMemeoryUpdater(memory_param, 2 * memory_param['dim_out'] + dim_edge, memory_param['dim_out'], memory_param['dim_time'], dim_node)
elif memory_param['memory_update'] == 'transformer':
self.memory_updater = TransformerMemoryUpdater(memory_param, 2 * memory_param['dim_out'] + dim_edge, memory_param['dim_out'], memory_param['dim_time'], train_param)
else:
raise NotImplementedError
self.dim_node_input = memory_param['dim_out']
self.layers = torch.nn.ModuleDict()
if gnn_param['arch'] == 'transformer_attention':
for h in range(sample_param['history']):
self.layers['l0h' + str(h)] = TransfomerAttentionLayer(self.dim_node_input, dim_edge, gnn_param
['dim_time'], gnn_param['att_head'], train_param['dropout'], train_param['att_dropout'], gnn_param['dim_out'], combined=combined)
for l in range(1, gnn_param['layer']):
for h in range(sample_param['history']):
self.layers['l' + str(l) + 'h' + str(h)] = TransfomerAttentionLayer(gnn_param['dim_out'], dim_edge, gnn_param['dim_time'], gnn_param['att_head'], train_param['dropout'], train_param['att_dropout'], gnn_param['dim_out'], combined=False)
elif gnn_param['arch'] == 'identity':
self.gnn_param['layer'] = 1
for h in range(sample_param['history']):
self.layers['l0h' + str(h)] = IdentityNormLayer(self.dim_node_input)
if 'time_transform' in gnn_param and gnn_param['time_transform'] == 'JODIE':
self.layers['l0h' + str(h) + 't'] = JODIETimeEmbedding(gnn_param['dim_out'])
else:
raise NotImplementedError
if 'historical' in self.gnn_param and self.gnn_param['historical'] == True:
self.historical_cache = {}
for l in range(0,gnn_param['layer']):
self.historical_cache['l'+str(l)] = HistoricalCache(cache_index,shape=(cache_index.size,gnn_param['dim_out']),dtype = torch.float,device = torch.device('cuda'))
self.all_to_all_embedding = all_to_all_embedding.apply
self.edge_predictor = EdgePredictor(gnn_param['dim_out'])
self.history_embedding_time = TimeEncode(gnn_param['dim_out'])
self.historical_update = torch.nn.Linear(gnn_param['dim_out'] + gnn_param['dim_out'],gnn_param['dim_out'])
if 'combine' in gnn_param and gnn_param['combine'] == 'rnn':
self.combiner = torch.nn.RNN(gnn_param['dim_out'], gnn_param['dim_out'])
def empty_cache(self):
if 'historical' in self.gnn_param and self.gnn_param['historical'] == True:
for l in range(0,self.gnn_param['layer']):
self.historical_cache['l'+str(l)].empty()
if self.memory_param['type'] == 'node':
self.memory_updater.empty_cache()
def get_sub_block(self,block,src_index):
all_dst_nodes = src_index
src,dst,eids = block.in_edges(src_index,form = 'all')
unq,ind = torch.cat((src,dst,all_dst_nodes)).unique(return_inverse=True)
subgraph = dgl.create_block((ind[:src.shape[0]],ind[src.shape[0]:src.shape[0]+dst.shape[0]]),
num_src_nodes = int(unq.shape[0]),
num_dst_nodes = int(ind[src.shape[0]:].max().item())+1,
device = block.device)
for k in block.srcdata:
subgraph.srcdata[k] = block.srcdata[k][unq]
for k in block.edata:
subgraph.edata[k] = block.edata[k][eids]
return subgraph
@staticmethod
def send_memory_async(metadata,memory = None):
if memory is not None:
local_memory = DistributedTensor(memory[metadata['src_pos_index']])
dst_memory = memory[metadata['dst_pos_index']]
dist_index0,ind0 = metadata['dist_src_index'].sort()
send_ptr = local_memory.all_to_all_ind2ptr(dist_index0)
mem = DistributedTensor(torch.empty_like(local_memory.accessor.data))
mem.all_to_all_set(dst_memory[ind0],**send_ptr,)
src_mem = local_memory.all_to_all_get(**send_ptr)[ind0]
mem = mem.accessor.data
return mem,src_mem
else:
return None,None
def forward(self, mfgs, metadata = None,neg_samples=1, mode = 'triplet'):
#torch.cuda.synchronize()
t0 = time.time()
#if(metadata['src_pos_index'].shape[0] == 0 or metadata['dst_pos_index'].shape[0] == 0 or metadata['dst_neg_index'].shape[0] == 0):
#print(metadata['src_pos_index'].shape,metadata['dst_pos_index'].shape,metadata['dst_neg_index'].shape)
if self.memory_param['type'] == 'node':
self.memory_updater(mfgs[0])
if (metadata is not None) and ('dist_src_index' in metadata):
update_mem = self.memory_updater.last_updated_memory
fut = executor.submit(GeneralModel.send_memory_async,metadata,update_mem)
out = list()
for l in range(self.gnn_param['layer']):
for h in range(self.sample_param['history']):
rst = self.layers['l' + str(l) + 'h' + str(h)](mfgs[l][h])
if 'time_transform' in self.gnn_param and self.gnn_param['time_transform'] == 'JODIE':
rst = self.layers['l0h' + str(h) + 't'](rst, mfgs[l][h].srcdata['mem_ts'], mfgs[l][h].srcdata['ts'])
if l != self.gnn_param['layer'] - 1:
if 'historical' in self.gnn_param and self.gnn_param['historical'] is True:
local_mask = DistIndex(mfgs[l+1][h].srcdata['ID']).part == torch.distributed.get_rank()
with torch.no_grad():
historical_embedding,historical_ts = self.historical_cache['l'+str(l)].get_data(mfgs[l+1][h].srcdata['_ID'][~local_mask])
self.historical_cache['l'+str(l)].update(mfgs[l+1][h].srcdata['_ID'][local_mask],rst[local_mask],mfgs[l+1][h].srcdata['ts'][:mfgs[l][h].num_dst_nodes()][local_mask])
rst[~local_mask] = historical_embedding #+ self.history_embedding_time(historical_ts)
mfgs[l + 1][h].srcdata['h'] = rst
else:
##test using historical embedding for remote nodes
if 'historical' in self.gnn_param and self.gnn_param['historical'] is True:
local_mask = DistIndex(mfgs[l][h].srcdata['ID'][:mfgs[l][h].num_dst_nodes()]).part == torch.distributed.get_rank()
with torch.no_grad():
history_embedding,historical_ts = self.historical_cache['l'+str(l)].get_data(mfgs[l][h].srcdata['_ID'][:mfgs[l][h].num_dst_nodes()][~local_mask])
self.historical_cache['l'+str(l)].update(mfgs[l][h].srcdata['_ID'][:mfgs[l][h].num_dst_nodes()][local_mask],rst[local_mask],mfgs[l][h].srcdata['ts'][:mfgs[l][h].num_dst_nodes()][local_mask])
rst[~local_mask] = self.historical_update(torch.cat((history_embedding ,self.history_embedding_time(historical_ts)),dim = 1))
#history_embedding + self.history_embedding_time(historical_ts)
#######
out.append(rst)
if self.sample_param['history'] == 1:
out = out[0]
else:
out = torch.stack(out, dim=0)
out = self.combiner(out)[0][-1, :, :]
#metadata: the ids need to be recorded during the earlier deduplication step
if self.gnn_param['use_src_emb'] or self.gnn_param['use_dst_emb']:
self.embedding = out.detach().clone()
else:
self.embedding = None
if self.gnn_param['dyrep']:
out = self.memory_updater.last_updated_memory
self.out = out
if self.memory_param['type'] == 'node':
if (metadata is not None) and ('dist_src_index' in metadata):
with torch.no_grad():
self.dst_mem,self.src_mem = fut.result()
#torch.cuda.synchronize()
t1 = time.time()
if metadata is not None:
#out = torch.cat((out[metadata['dst_pos_pos']],out[metadata['src_id_pos']],out[metadata['dst_neg_pos']]),0)
if 'dist_src_index' not in metadata:
h_pos_src = out[metadata['src_pos_index']]
h_pos_dst = out[metadata['dst_pos_index']]
h_neg_dst = out[metadata['dst_neg_index']]
if 'src_neg_index' in metadata:
h_neg_src = out[metadata['src_neg_index']]
return self.edge_predictor(h_pos_src, h_pos_dst, h_neg_src, h_neg_dst, neg_samples=neg_samples, mode = mode)
else:
return self.edge_predictor(h_pos_src, h_pos_dst, None , h_neg_dst, neg_samples=neg_samples, mode = mode)
else:
if self.memory_param['type'] == 'node':
h_pos_src,h_pos_dst,h_neg_dst,mem,src_mem = self.all_to_all_embedding(out,metadata,neg_samples,None,self.gnn_param['use_src_emb'])
#self.dst_mem = mem.detach().clone()
#self.src_mem = src_mem.detach().clone()
else:
h_pos_src,h_pos_dst,h_neg_dst,mem,src_mem = self.all_to_all_embedding(out,metadata,neg_samples,None,self.gnn_param['use_src_emb'])
#torch.cuda.synchronize()
t2 = time.time()
time_count.add_train_forward_embedding(t1-t0)
time_count.add_train_foward_all_to_all(t2-t1)
return self.edge_predictor(h_pos_src, h_pos_dst, None, h_neg_dst, neg_samples=neg_samples, mode = mode)
else:
return out
class NodeClassificationModel(torch.nn.Module):
def __init__(self, dim_in, dim_hid, num_class):
super(NodeClassificationModel, self).__init__()
self.fc1 = torch.nn.Linear(dim_in, dim_hid)
self.fc2 = torch.nn.Linear(dim_hid, num_class)
def forward(self, x):
x = self.fc1(x)
x = torch.nn.functional.relu(x)
x = self.fc2(x)
return x
\ No newline at end of file
import torch
import dgl
import math
import numpy as np
class TimeEncode(torch.nn.Module):
def __init__(self, dim):
super(TimeEncode, self).__init__()
self.dim = dim
self.w = torch.nn.Linear(1, dim)
self.w.weight = torch.nn.Parameter((torch.from_numpy(1 / 10 ** np.linspace(0, 9, dim, dtype=np.float32))).reshape(dim, -1))
self.w.bias = torch.nn.Parameter(torch.zeros(dim))
def forward(self, t):
output = torch.cos(self.w(t.reshape((-1, 1))))
return output
class EdgePredictor(torch.nn.Module):
def __init__(self, dim_in):
super(EdgePredictor, self).__init__()
self.dim_in = dim_in
self.src_fc = torch.nn.Linear(dim_in, dim_in)
self.dst_fc = torch.nn.Linear(dim_in, dim_in)
self.out_fc = torch.nn.Linear(dim_in, 1)
def forward(self, h_src, h_pos_dst, h_neg_dst=None):
h_src = self.src_fc(h_src)
h_pos_dst = self.dst_fc(h_pos_dst)
h_pos_edge = torch.nn.functional.relu(h_src + h_pos_dst)
prob_pos = self.out_fc(h_pos_edge)
prob_neg = None
if h_neg_dst is not None:
h_neg_dst = self.dst_fc(h_neg_dst)
neg_samples = h_neg_dst.shape[0] // h_src.shape[0]
h_neg_edge = torch.nn.functional.relu(h_src.tile(neg_samples, 1) + h_neg_dst)
prob_neg = self.out_fc(h_neg_edge)
return prob_pos, prob_neg
class TransfomerAttentionLayer(torch.nn.Module):
def __init__(self, dim_node_feat, dim_edge_feat, dim_time, num_head, dropout, att_dropout, dim_out, combined=False):
super(TransfomerAttentionLayer, self).__init__()
self.num_head = num_head
self.dim_node_feat = dim_node_feat
self.dim_edge_feat = dim_edge_feat
self.dim_time = dim_time
self.dim_out = dim_out
self.dropout = torch.nn.Dropout(dropout)
self.att_dropout = torch.nn.Dropout(att_dropout)
self.att_act = torch.nn.LeakyReLU(0.2)
self.combined = combined
if dim_time > 0:
self.time_enc = TimeEncode(dim_time)
if combined:
if dim_node_feat > 0:
self.w_q_n = torch.nn.Linear(dim_node_feat, dim_out)
self.w_k_n = torch.nn.Linear(dim_node_feat, dim_out)
self.w_v_n = torch.nn.Linear(dim_node_feat, dim_out)
if dim_edge_feat > 0:
self.w_k_e = torch.nn.Linear(dim_edge_feat, dim_out)
self.w_v_e = torch.nn.Linear(dim_edge_feat, dim_out)
if dim_time > 0:
self.w_q_t = torch.nn.Linear(dim_time, dim_out)
self.w_k_t = torch.nn.Linear(dim_time, dim_out)
self.w_v_t = torch.nn.Linear(dim_time, dim_out)
else:
if dim_node_feat + dim_time > 0:
self.w_q = torch.nn.Linear(dim_node_feat + dim_time, dim_out)
self.w_k = torch.nn.Linear(dim_node_feat + dim_edge_feat + dim_time, dim_out)
self.w_v = torch.nn.Linear(dim_node_feat + dim_edge_feat + dim_time, dim_out)
self.w_out = torch.nn.Linear(dim_node_feat + dim_out, dim_out)
self.layer_norm = torch.nn.LayerNorm(dim_out)
def forward(self, b):
if self.dim_time > 0:
time_feat = self.time_enc(b.edata['dt'])
zero_time_feat = self.time_enc(torch.zeros(b.num_dst_nodes(), dtype=torch.float32, device=b.src_idx_cuda.device))
src_data = b.srcdata['h'][b.src_idx_cuda]
if b.combined:
q_data = torch.cat([src_data[:b.num_pos_dst], src_data[b.num_pos_idx:b.num_pos_idx + b.num_neg_dst]], dim=0)
kv_data = torch.cat([src_data[b.num_pos_dst:b.num_pos_idx], src_data[b.num_pos_idx + b.num_neg_dst:]], dim=0)
else:
q_data = src_data[:b.num_dst_nodes()]
kv_data = src_data[b.num_dst_nodes():]
if b.num_edges() > 0:
if self.dim_edge_feat == 0:
Q = self.w_q(torch.cat([q_data, zero_time_feat], dim=1))[b.edges()[1]]
K = self.w_k(torch.cat([kv_data, time_feat], dim=1))
V = self.w_v(torch.cat([kv_data, time_feat], dim=1))
else:
Q = self.w_q(torch.cat([q_data, zero_time_feat], dim=1))[b.edges()[1]]
K = self.w_k(torch.cat([kv_data, b.edata['f'], time_feat], dim=1))
V = self.w_v(torch.cat([kv_data, b.edata['f'], time_feat], dim=1))
Q = torch.reshape(Q, (Q.shape[0], self.num_head, -1))
K = torch.reshape(K, (K.shape[0], self.num_head, -1))
V = torch.reshape(V, (V.shape[0], self.num_head, -1))
att = dgl.ops.edge_softmax(b, self.att_act(torch.sum(Q*K, dim=2)))
att = self.att_dropout(att)
V = torch.reshape(V*att[:, :, None], (V.shape[0], -1))
b.edata['v'] = V
b.update_all(dgl.function.copy_e('v', 'm'), dgl.function.sum('m', 'h'))
else:
b.dstdata['h'] = torch.zeros((b.num_dst_nodes(), self.dim_out), device=b.src_idx_cuda.device)
if b.combined:
rst_pos = torch.cat([b.dstdata['h'][:b.num_pos_dst], src_data[:b.num_pos_dst]], dim=1)
rst_neg = torch.cat([b.dstdata['h'][b.num_pos_dst:], src_data[b.num_pos_idx:b.num_pos_idx + b.num_neg_dst]], dim=1)
rst = torch.cat([rst_pos, rst_neg])
else:
rst = torch.cat([b.dstdata['h'], src_data[:b.num_dst_nodes()]], dim=1)
rst = self.w_out(rst)
rst = torch.nn.functional.relu(self.dropout(rst))
return self.layer_norm(rst)
class IdentityNormLayer(torch.nn.Module):
def __init__(self, dim_out):
super(IdentityNormLayer, self).__init__()
self.norm = torch.nn.LayerNorm(dim_out)
def forward(self, b):
return self.norm(b.srcdata['h'])
class JODIETimeEmbedding(torch.nn.Module):
def __init__(self, dim_out):
super(JODIETimeEmbedding, self).__init__()
self.dim_out = dim_out
class NormalLinear(torch.nn.Linear):
# From Jodie code
def reset_parameters(self):
stdv = 1. / math.sqrt(self.weight.size(1))
self.weight.data.normal_(0, stdv)
if self.bias is not None:
self.bias.data.normal_(0, stdv)
self.time_emb = NormalLinear(1, dim_out)
def forward(self, h, mem_ts, ts):
time_diff = (ts - mem_ts) / (ts + 1)
rst = h * (1 + self.time_emb(time_diff.unsqueeze(1)))
return rst
\ No newline at end of file
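For reference, the layers defined above can be exercised in isolation: TimeEncode maps a batch of time deltas to cosine features through a fixed-initialized linear layer, and EdgePredictor scores positive pairs and stacked negatives. A short usage sketch, assuming the file is importable as a module named layers (the exact module path is an assumption):

import torch
from layers import TimeEncode, EdgePredictor  # module name assumed

dim = 16
time_enc = TimeEncode(dim)
edge_pred = EdgePredictor(dim)

dt = torch.rand(8)                      # eight time deltas
time_feat = time_enc(dt)                # -> (8, dim) cosine features

h_src = torch.rand(4, dim)              # four source embeddings
h_pos_dst = torch.rand(4, dim)          # matching positive destinations
h_neg_dst = torch.rand(8, dim)          # two negatives per source
prob_pos, prob_neg = edge_pred(h_src, h_pos_dst, h_neg_dst)
print(prob_pos.shape, prob_neg.shape)   # torch.Size([4, 1]) torch.Size([8, 1])
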
import torch
import dgl
import time
from starrygl.module.memorys import *
from starrygl.module.disttgl_layers import *
class GeneralModel(torch.nn.Module):
def __init__(self, dim_node, dim_edge, sample_param, memory_param, gnn_param, train_param, num_node=None, no_learn_node=False, combined=False, edge_classification=False, edge_classes=0):
super(GeneralModel, self).__init__()
self.edge_classification = edge_classification
self.dim_node = dim_node
self.dim_node_input = dim_node
self.dim_edge = dim_edge
self.sample_param = sample_param
self.memory_param = memory_param
if not 'dim_out' in gnn_param:
gnn_param['dim_out'] = memory_param['dim_out']
self.gnn_param = gnn_param
self.train_param = train_param
if memory_param['type'] == 'node':
if memory_param['memory_update'] == 'smart':
self.memory_updater = SmartMemoryUpdater(memory_param, 2 * memory_param['dim_out'] + dim_edge, memory_param['dim_out'], memory_param['dim_time'], dim_node, num_node, no_learn_node=no_learn_node)
else:
raise NotImplementedError
self.dim_node_input = memory_param['dim_out']
self.layers = torch.nn.ModuleDict()
if gnn_param['arch'] == 'transformer_attention':
for h in range(sample_param['history']):
self.layers['l0h' + str(h)] = TransfomerAttentionLayer(self.dim_node_input, dim_edge, gnn_param['dim_time'], gnn_param['att_head'], train_param['dropout'], train_param['att_dropout'], gnn_param['dim_out'], combined=combined)
for l in range(1, gnn_param['layer']):
for h in range(sample_param['history']):
self.layers['l' + str(l) + 'h' + str(h)] = TransfomerAttentionLayer(gnn_param['dim_out'], dim_edge, gnn_param['dim_time'], gnn_param['att_head'], train_param['dropout'], train_param['att_dropout'], gnn_param['dim_out'], combined=False)
else:
raise NotImplementedError
if not self.edge_classification:
self.edge_predictor = EdgePredictor(gnn_param['dim_out'])
else:
self.edge_classifier = torch.nn.Linear(2 * gnn_param['dim_out'], edge_classes)
def forward(self, mfg, metadata = None,neg_samples=1, mode = 'triplet'):
# import pdb; pdb.set_trace()
# torch.cuda.synchronize()
# t_s=time.time()
if self.memory_param['type'] == 'node':
self.memory_updater(mfg)
# torch.cuda.synchronize()
# t_mem = time.time() - t_s
# t_s = time.time()
rst = self.layers['l0h0'](mfg)
self.embedding = rst.detach().clone()
if not self.edge_classification:
return self.edge_predictor(rst[metadata['src_pos_index']],rst[metadata['dst_pos_index']],rst[metadata['dst_neg_index']])
else:
rst = torch.cat([rst[metadata['src_pos_index']], rst[metadata['dst_pos_index']]], dim=1)
return self.edge_classifier(rst), None
class NodeClassificationModel(torch.nn.Module):
def __init__(self, dim_in, dim_hid, num_class):
super(NodeClassificationModel, self).__init__()
self.fc1 = torch.nn.Linear(dim_in, dim_hid)
self.fc2 = torch.nn.Linear(dim_hid, num_class)
def forward(self, x):
x = self.fc1(x)
x = torch.nn.functional.relu(x)
x = self.fc2(x)
return x
\ No newline at end of file
@@ -41,7 +41,7 @@ pinn_memory = {}
 class HistoricalCache:
-    def __init__(self,cache_index,layer,shape,dtype,device,threshold = 3,time_threshold = None, times_threshold = 10, use_rpc = True, num_threshold = 0):
+    def __init__(self,cache_index,layer,shape,dtype,device,ada_param = None,time_threshold = None, times_threshold = 10, use_rpc = True, num_threshold = 0):
         #self.cache_index = cache_index
         self.layer = layer
         print(shape)
@@ -51,7 +51,8 @@ class HistoricalCache:
         #self.ts = torch.zeros(cache_index.historical_num,dtype = torch.float,device = torch.device('cpu'))
         self.local_ts = torch.zeros(cache_index.shape[0],dtype = torch.float,device = device)
         self.loss_count = torch.zeros(cache_index.shape[0],dtype = torch.int,device = device)
-        self.threshold = threshold
+        self.threshold = ada_param.alpha
+        self.ada_param = ada_param
         self.time_threshold = time_threshold
         self.times_threshold = times_threshold
         self.num_threshold = num_threshold
@@ -87,6 +88,7 @@ class HistoricalCache:
         else:
             return torch.sum((x -y)**2,dim = 1)
     def historical_check(self,index,new_data,ts):
+        self.threshold = self.ada_param.alpha
         if self.time_threshold is not None:
             mask = (self.ssim(new_data,self.local_historical_data[index]) > self.threshold | (ts - self.local_ts[index] > self.time_threshold | self.loss_count[index] > self.times_threshold))
             self.loss_count[index][~mask] += 1
@@ -134,6 +136,8 @@ class HistoricalCache:
         self.last_shared_update_wait = None
         handle0.wait()
         handle1.wait()
+        if self.ada_param.training is True:
+            self.ada_param.update_memory_sync_time(self.ada_param.last_start_event_memory_sync)
         shared_data = torch.cat(shared_data,dim = 0)
         shared_index = torch.cat(shared_index)
         if(shared_data.shape[0] == 0):
...
@@ -395,6 +395,9 @@ class AsyncMemeoryUpdater(torch.nn.Module):
     def historical_func(self,index,memory,memory_ts,mail_index,mail,mail_ts,nxt_fetch_func,spread_mail=False):
         self.mailbox.sychronize_shared()
         self.mailbox.handle_last_async()
+        #if self.ada_param.training is True:
+        #    self.ada_param.last_start_event_memory_sync = self.ada_param.start_event()
+        #self.ada_param.update_memory_sync_time(self.ada_param.last_start_event_memory_sync)
         submit_to_queue = False
         if nxt_fetch_func is not None:
             submit_to_queue = True
@@ -406,6 +409,7 @@ class AsyncMemeoryUpdater(torch.nn.Module):
             wait_submit=submit_to_queue,spread_mail=spread_mail,
             update_cross_mm=False,
         )
+        self.ada_param.update_memory_update_time(self.ada_param.last_start_event_memory_update)
         if nxt_fetch_func is not None:
             nxt_fetch_func()
@@ -420,7 +424,7 @@ class AsyncMemeoryUpdater(torch.nn.Module):
         time_feat = self.time_enc(b.srcdata['ts'] - b.srcdata['mem_ts'])
         b.srcdata['mem_input'] = torch.cat([b.srcdata['mem_input'], time_feat], dim=1)
         return self.ceil_updater(b.srcdata['mem_input'], b.srcdata['mem'])
-    def __init__(self, memory_param, dim_in, dim_hid, dim_time, dim_node_feat,updater,mode = None,mailbox = None,train_param=None):
+    def __init__(self, memory_param, dim_in, dim_hid, dim_time, dim_node_feat,updater,mode = None,mailbox = None,train_param=None, ada_param = None):
         super(AsyncMemeoryUpdater, self).__init__()
         self.dim_hid = dim_hid
         self.dim_node_feat = dim_node_feat
@@ -436,6 +440,7 @@ class AsyncMemeoryUpdater(torch.nn.Module):
         self.last_updated_ts = None
         self.last_updated_nid = None
         self.delta_memory = 0
+        self.ada_param = ada_param
         if dim_time > 0:
             self.time_enc = TimeEncode(dim_time)
         if memory_param['combine_node_feature']:
@@ -524,6 +529,11 @@ class AsyncMemeoryUpdater(torch.nn.Module):
         is_deliver=(self.mailbox.deliver_to == 'neighbors')
         self.update_hunk(index,memory,memory_ts,index0,mail,mail_ts,nxt_fetch_func,spread_mail= is_deliver)
+        if self.training is True:
+            self.ada_param.last_start_event_gnn_aggregate = self.ada_param.start_event()
         if self.memory_param['combine_node_feature'] and self.dim_node_feat > 0:
             if self.dim_node_feat == self.dim_hid:
                 b.srcdata['h'] += updated_memory
...
@@ -69,13 +69,14 @@ class NegFixLayer(torch.autograd.Function):
 class GeneralModel(torch.nn.Module):
-    def __init__(self, dim_node, dim_edge, sample_param, memory_param, gnn_param, train_param, num_nodes = None,mailbox = None,combined=False,train_ratio = None):
+    def __init__(self, dim_node, dim_edge, sample_param, memory_param, gnn_param, train_param, num_nodes = None,mailbox = None,combined=False,train_ratio = None,ada_param = None):
         super(GeneralModel, self).__init__()
         self.dim_node = dim_node
         self.dim_node_input = dim_node
         self.dim_edge = dim_edge
         self.sample_param = sample_param
         self.memory_param = memory_param
+        self.ada_param = ada_param
         #self.train_pos_ratio,self.train_neg_ratio = train_ratio
         if not 'dim_out' in gnn_param:
             gnn_param['dim_out'] = memory_param['dim_out']
@@ -89,7 +90,7 @@ class GeneralModel(torch.nn.Module):
             #else:
             updater = torch.nn.GRUCell
             # if memory_param['historical_fix'] == False:
-            self.memory_updater = AsyncMemeoryUpdater(memory_param, 2 * memory_param['dim_out'] + dim_edge, memory_param['dim_out'], memory_param['dim_time'], dim_node, updater=updater, mailbox=mailbox, mode = memory_param['mode'])
+            self.memory_updater = AsyncMemeoryUpdater(memory_param, 2 * memory_param['dim_out'] + dim_edge, memory_param['dim_out'], memory_param['dim_time'], dim_node, updater=updater, mailbox=mailbox, mode = memory_param['mode'],ada_param=ada_param)
             # else:
             #     self.memory_updater = HistoricalMemeoryUpdater(memory_param, 2 * memory_param['dim_out'] + dim_edge, memory_param['dim_out'], memory_param['dim_time'], dim_node,updater=updater,learnable=True,num_nodes=num_nodes)
         elif memory_param['memory_update'] == 'rnn':
@@ -98,12 +99,12 @@ class GeneralModel(torch.nn.Module):
             #else:
             updater = torch.nn.RNNCell
             #if memory_param['historical_fix'] == False:
-            self.memory_updater = AsyncMemeoryUpdater(memory_param, 2 * memory_param['dim_out'] + dim_edge, memory_param['dim_out'], memory_param['dim_time'], dim_node, updater=updater, mailbox=mailbox, mode = memory_param['mode'])
+            self.memory_updater = AsyncMemeoryUpdater(memory_param, 2 * memory_param['dim_out'] + dim_edge, memory_param['dim_out'], memory_param['dim_time'], dim_node, updater=updater, mailbox=mailbox, mode = memory_param['mode'],ada_param=ada_param)
             # else:
             #     self.memory_updater = HistoricalMemeoryUpdater(memory_param, 2 * memory_param['dim_out'] + dim_edge, memory_param['dim_out'], memory_param['dim_time'], dim_node,updater=updater,learnable=True,num_nodes=num_nodes)
         elif memory_param['memory_update'] == 'transformer':
             updater = TransformerMemoryUpdater
-            self.memory_updater = AsyncMemeoryUpdater(memory_param, 2 * memory_param['dim_out'] + dim_edge, memory_param['dim_out'], memory_param['dim_time'], dim_node, updater=updater, mailbox=mailbox, mode = memory_param['mode'],train_param=train_param)
+            self.memory_updater = AsyncMemeoryUpdater(memory_param, 2 * memory_param['dim_out'] + dim_edge, memory_param['dim_out'], memory_param['dim_time'], dim_node, updater=updater, mailbox=mailbox, mode = memory_param['mode'],train_param=train_param,ada_param=ada_param)
         else:
             raise NotImplementedError
         self.dim_node_input = memory_param['dim_out']
......
import yaml import yaml
import numpy as np import numpy as np
import torch
import math
def parse_config(f): def parse_config(f):
conf = yaml.safe_load(open(f, 'r')) conf = yaml.safe_load(open(f, 'r'))
sample_param = conf['sampling'][0] sample_param = conf['sampling'][0]
...@@ -36,3 +37,109 @@ class EarlyStopMonitor(object): ...@@ -36,3 +37,109 @@ class EarlyStopMonitor(object):
self.epoch_count += 1 self.epoch_count += 1
return self.num_round >= self.max_round return self.num_round >= self.max_round
class AdaParameter:
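"""Adaptive controller for the parameters beta and alpha: it accumulates
CUDA-event timings of the fetch, memory-sync, memory-update and GNN-aggregate
stages and re-balances beta/alpha from their averaged ratios in update_parameter()."""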
def __init__(self, wait_threshold=0.1, init_beta = 0.1, init_alpha = 0.1, min_beta = 0.01, max_beta = 1, min_alpha = 0.1, max_alpha = 1):
self.wait_threshold = wait_threshold
self.beta = init_beta
self.alpha = init_alpha
self.average_fetch = 0
self.count_fetch = 0
self.average_memory_sync = 0
self.count_memory_sync = 0
self.average_memory_update = 0
self.count_memory_update = 0
self.average_gnn_aggregate = 0
self.count_gnn_aggregate = 0
self.counts = 0
self.last_start_event_fetch = None
self.last_start_event_memory_sync = None
self.last_start_event_memory_update = None
self.last_start_event_gnn_aggregate = None
self.min_beta = min_beta
self.max_beta = max_beta
self.max_alpha = max_alpha
self.min_alpha = min_alpha
self.training = False
def train(self):
self.training = True
def eval(self):
self.training = False
def start_event(self):
start_event = torch.cuda.Event(enable_timing=True)
start_event.record()
return start_event
def update_fetch_time(self,start_event):
if start_event is None:
return
end_event = torch.cuda.Event(enable_timing=True)
end_event.record()
end_event.synchronize()
elapsed_time_ms = start_event.elapsed_time(end_event)
self.average_fetch += elapsed_time_ms
self.count_fetch += 1
def update_memory_sync_time(self,start_event):
if start_event is None:
return
end_event = torch.cuda.Event(enable_timing=True)
end_event.record()
end_event.synchronize()
elapsed_time_ms = start_event.elapsed_time(end_event)
self.average_memory_sync += elapsed_time_ms
self.count_memory_sync += 1
def update_memory_update_time(self,start_event):
if start_event is None:
return
end_event = torch.cuda.Event(enable_timing=True)
end_event.record()
end_event.synchronize()
elapsed_time_ms = start_event.elapsed_time(end_event)
self.average_memory_update += elapsed_time_ms
self.count_memory_update += 1
def update_gnn_aggregate_time(self,start_event):
if start_event is None:
return
end_event = torch.cuda.Event(enable_timing=True)
end_event.record()
end_event.synchronize()
elapsed_time_ms = start_event.elapsed_time(end_event)
self.average_gnn_aggregate += elapsed_time_ms
self.count_gnn_aggregate += 1
def reset_time(self):
self.average_fetch = 0
self.count_fetch = 0
self.average_memory_sync = 0
self.count_memory_sync = 0
self.average_memory_update = 0
self.count_memory_update = 0
self.average_gnn_aggregate = 0
self.count_gnn_aggregate = 0
def update_parameter(self):
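# scale beta by the averaged (GNN-aggregate time / fetch time) ratio with a
# (1 + wait_threshold) margin, adapt alpha from the (memory-update time /
# memory-sync time) ratio via the log relation noted below, then clamp both
# to their configured [min, max] ranges and reset the accumulators.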
print('beta is {} alpha is {}\n'.format(self.beta,self.alpha))
if self.count_fetch == 0 or self.count_memory_sync == 0 or self.count_memory_update == 0 or self.count_gnn_aggregate == 0:
return
average_gnn_aggregate = self.average_gnn_aggregate/self.count_gnn_aggregate
average_fetch = self.average_fetch/self.count_fetch
self.beta = self.beta * average_gnn_aggregate/average_fetch * (1 + self.wait_threshold)
average_memory_sync_time = self.average_memory_sync/self.count_memory_sync
average_memory_update_time = self.average_memory_update/self.count_memory_update
self.alpha = 2 - math.pow(2 - self.alpha, average_memory_update_time / average_memory_sync_time * (1 + self.wait_threshold))
self.beta = max(min(self.beta, self.max_beta),self.min_beta)
self.alpha = max(min(self.alpha, self.max_alpha),self.min_alpha)
print('gnn aggregate {} fetch {} memory sync {} memory update {}'.format(average_gnn_aggregate,average_fetch,average_memory_sync_time,average_memory_update_time))
print('beta is {} alpha is {}\n'.format(self.beta,self.alpha))
self.reset_time()
#log(2 - a1) = log(2 - a2) * t1/t2 * (1 + wait_threshold)
#2 - a1 = (2 - a2) ** (t1/t2 * (1 + wait_threshold))
#a1 = 2 - (2 - a2) ** (t1/t2 * (1 + wait_threshold))
#(t1/t2 is the averaged memory-update / memory-sync time ratio used above)
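For instance, with alpha = 0.5, a memory-update/memory-sync time ratio of 1.2 and wait_threshold = 0.1, the rule gives alpha_new = 2 - 1.5 ** 1.32 ≈ 0.29. The sketch below (not part of this commit) shows how the timing hooks are intended to be driven from a training loop; it assumes a CUDA device is available, since start_event() records torch.cuda.Event objects, and the placeholder comments stand in for the real fetch, aggregation, synchronization and memory-update work done by the modules patched in this commit.

ada_param = AdaParameter(wait_threshold=0.1, init_beta=0.1, init_alpha=0.1)
ada_param.train()
for _ in range(10):  # stand-in for the batches of one epoch
    ada_param.last_start_event_fetch = ada_param.start_event()
    # ... asynchronous remote feature/mailbox fetch ...
    ada_param.update_fetch_time(ada_param.last_start_event_fetch)
    ada_param.last_start_event_gnn_aggregate = ada_param.start_event()
    # ... local GNN aggregation / forward and backward pass ...
    ada_param.update_gnn_aggregate_time(ada_param.last_start_event_gnn_aggregate)
    ada_param.last_start_event_memory_sync = ada_param.start_event()
    # ... all_gather of shared node memory ...
    ada_param.update_memory_sync_time(ada_param.last_start_event_memory_sync)
    ada_param.last_start_event_memory_update = ada_param.start_event()
    # ... memory updater cell (GRU/RNN/transformer) ...
    ada_param.update_memory_update_time(ada_param.last_start_event_memory_update)
ada_param.update_parameter()  # re-balances beta/alpha from the averages and resets the counters
ada_param.eval()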
...@@ -101,8 +101,10 @@ class DistributedDataLoader: ...@@ -101,8 +101,10 @@ class DistributedDataLoader:
use_local_feature = True, use_local_feature = True,
probability = 1, probability = 1,
reversed = False, reversed = False,
ada_param = None,
**kwargs **kwargs
): ):
self.ada_param = ada_param
self.reversed = reversed self.reversed = reversed
self.use_local_feature = use_local_feature self.use_local_feature = use_local_feature
self.local_embedding = local_embedding self.local_embedding = local_embedding
...@@ -255,6 +257,7 @@ class DistributedDataLoader: ...@@ -255,6 +257,7 @@ class DistributedDataLoader:
return return
while(len(self.result_queue)==0): while(len(self.result_queue)==0):
pass pass
batch_data,dist_nid,dist_eid = self.result_queue[0].result() batch_data,dist_nid,dist_eid = self.result_queue[0].result()
b = batch_data[1][0][0] b = batch_data[1][0][0]
self.remote_node += (DistIndex(dist_nid).part != dist.get_rank()).sum().item() self.remote_node += (DistIndex(dist_nid).part != dist.get_rank()).sum().item()
...@@ -275,16 +278,21 @@ class DistributedDataLoader: ...@@ -275,16 +278,21 @@ class DistributedDataLoader:
#start = torch.cuda.Event(enable_timing=True) #start = torch.cuda.Event(enable_timing=True)
#end = torch.cuda.Event(enable_timing=True) #end = torch.cuda.Event(enable_timing=True)
#start.record() #start.record()
nind,ndata = get_node_all_to_all_route(self.graph,self.mailbox,dist_nid,out_device=self.device) nind,ndata = get_node_all_to_all_route(self.graph,self.mailbox,dist_nid,out_device=self.device)
eind,edata = get_edge_all_to_all_route(self.graph,dist_eid,out_device=self.device) eind,edata = get_edge_all_to_all_route(self.graph,dist_eid,out_device=self.device)
if nind is not None: if nind is not None:
node_feat = DistributedTensor.all_to_all_get_data(ndata,send_ptr=nind['send_ptr'],recv_ptr=nind['recv_ptr'],is_async=True) node_feat = DistributedTensor.all_to_all_get_data(ndata,send_ptr=nind['send_ptr'],recv_ptr=nind['recv_ptr'],is_async=True)
else: else:
node_feat = None node_feat = None
if eind is not None: if eind is not None:
edge_feat = DistributedTensor.all_to_all_get_data(edata,send_ptr=eind['send_ptr'],recv_ptr=eind['recv_ptr'],is_async=True) edge_feat = DistributedTensor.all_to_all_get_data(edata,send_ptr=eind['send_ptr'],recv_ptr=eind['recv_ptr'],is_async=True)
else: else:
edge_feat = None edge_feat = None
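# record the start of the asynchronous feature/mailbox fetch; the matching
# update_fetch_time() call when this prefetched batch is consumed closes the interval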
if self.ada_param is not None:
self.ada_param.last_start_event_fetch = self.ada_param.start_event()
t3 = time.time() t3 = time.time()
self.result_queue.append((batch_data,dist_nid,dist_eid,edge_feat,node_feat)) self.result_queue.append((batch_data,dist_nid,dist_eid,edge_feat,node_feat))
self.submit() self.submit()
...@@ -356,10 +364,12 @@ class DistributedDataLoader: ...@@ -356,10 +364,12 @@ class DistributedDataLoader:
#batch_data[1][0][0].srcdata['mem'][mask] = self.mailbox.historical_cache.local_historical_data[DistIndex(id).loc[mask]] #batch_data[1][0][0].srcdata['mem'][mask] = self.mailbox.historical_cache.local_historical_data[DistIndex(id).loc[mask]]
self.recv_idxs += 1 self.recv_idxs += 1
else: else:
if(self.recv_idxs < self.expected_idx): if(self.recv_idxs < self.expected_idx):
assert len(self.result_queue) > 0 assert len(self.result_queue) > 0
#print(len(self.result_queue[0])) #print(len(self.result_queue[0]))
if isinstance(self.result_queue[0],tuple) : if isinstance(self.result_queue[0],tuple) :
t0 = time.time() t0 = time.time()
batch_data,dist_nid,dist_eid,edge_feat,node_feat0 = self.result_queue[0] batch_data,dist_nid,dist_eid,edge_feat,node_feat0 = self.result_queue[0]
self.result_queue.popleft() self.result_queue.popleft()
...@@ -382,6 +392,8 @@ class DistributedDataLoader: ...@@ -382,6 +392,8 @@ class DistributedDataLoader:
node_feat0 = node_feat0[0] node_feat0 = node_feat0[0]
node_feat = None node_feat = None
mem = self.mailbox.unpack(node_feat0,mailbox = True) mem = self.mailbox.unpack(node_feat0,mailbox = True)
if self.ada_param is not None:
self.ada_param.update_fetch_time(self.ada_param.last_start_event_fetch)
#print(node_feat.shape,edge_feat.shape,mem[0].shape) #print(node_feat.shape,edge_feat.shape,mem[0].shape)
#node_feat[1].wait() #node_feat[1].wait()
#node_feat = node_feat[0] #node_feat = node_feat[0]
...@@ -395,6 +407,8 @@ class DistributedDataLoader: ...@@ -395,6 +407,8 @@ class DistributedDataLoader:
#mem = (mem[0][0],mem[1][0],mem[2][0],mem[3][0]) #mem = (mem[0][0],mem[1][0],mem[2][0],mem[3][0])
#node_feat,mem = get_node_feature_by_dist(self.graph,self.mailbox, dist_nid,is_local,out_device=self.device) #node_feat,mem = get_node_feature_by_dist(self.graph,self.mailbox, dist_nid,is_local,out_device=self.device)
t1 = time.time() t1 = time.time()
#if self.ada_param is not None:
# self.ada_param.update_fetch_time(self.ada_param.last_start_event_fetch)
else: else:
batch_data,dist_nid,dist_eid = self.result_queue[0].result() batch_data,dist_nid,dist_eid = self.result_queue[0].result()
stream.synchronize() stream.synchronize()
...@@ -410,6 +424,7 @@ class DistributedDataLoader: ...@@ -410,6 +424,7 @@ class DistributedDataLoader:
indx = self.mailbox.is_shared_mask[DistIndex(batch_data[1][0][0].srcdata['ID']).loc[mask]] indx = self.mailbox.is_shared_mask[DistIndex(batch_data[1][0][0].srcdata['ID']).loc[mask]]
batch_data[1][0][0].srcdata['his_mem'][mask] = self.mailbox.historical_cache.local_historical_data[indx] batch_data[1][0][0].srcdata['his_mem'][mask] = self.mailbox.historical_cache.local_historical_data[indx]
batch_data[1][0][0].srcdata['his_ts'][mask] = self.mailbox.historical_cache.local_ts[indx].reshape(-1,1) batch_data[1][0][0].srcdata['his_ts'][mask] = self.mailbox.historical_cache.local_ts[indx].reshape(-1,1)
self.recv_idxs += 1 self.recv_idxs += 1
else: else:
raise StopIteration raise StopIteration
......
...@@ -59,8 +59,9 @@ class SharedMailBox(): ...@@ -59,8 +59,9 @@ class SharedMailBox():
uvm = False, uvm = False,
use_pin = False, use_pin = False,
start_historical = False, start_historical = False,
shared_ssim = 2): ada_param = None):
ctx = distributed.context._get_default_dist_context() ctx = distributed.context._get_default_dist_context()
self.ada_param = ada_param
self.device = device self.device = device
self.num_nodes = num_nodes self.num_nodes = num_nodes
self.num_parts = dist.get_world_size() self.num_parts = dist.get_world_size()
...@@ -106,7 +107,7 @@ class SharedMailBox(): ...@@ -106,7 +107,7 @@ class SharedMailBox():
device=torch.device('cuda:{}'.format(ctx.local_rank))) device=torch.device('cuda:{}'.format(ctx.local_rank)))
if start_historical: if start_historical:
self.historical_cache = historical_cache.HistoricalCache(self.shared_nodes_index,0,self.node_memory.shape[1],self.node_memory.dtype,self.node_memory.device,threshold=shared_ssim) self.historical_cache = historical_cache.HistoricalCache(self.shared_nodes_index,0,self.node_memory.shape[1],self.node_memory.dtype,self.node_memory.device,ada_param)
else: else:
self.historical_cache = None self.historical_cache = None
self._mem_pin = {} self._mem_pin = {}
...@@ -294,6 +295,8 @@ class SharedMailBox(): ...@@ -294,6 +295,8 @@ class SharedMailBox():
if self.next_wait_gather_memory_job is not None: if self.next_wait_gather_memory_job is not None:
shared_list,mem,shared_id_list,shared_memory_ind = self.next_wait_gather_memory_job shared_list,mem,shared_id_list,shared_memory_ind = self.next_wait_gather_memory_job
self.next_wait_gather_memory_job = None self.next_wait_gather_memory_job = None
# start timing the shared-memory all_gather synchronization (training only)
if self.ada_param is not None and self.ada_param.training:
self.ada_param.last_start_event_memory_sync = self.ada_param.start_event()
handle0 = dist.all_gather(shared_list,mem,group=ctx.memory_nccl_group,async_op=True) handle0 = dist.all_gather(shared_list,mem,group=ctx.memory_nccl_group,async_op=True)
handle1 = dist.all_gather(shared_id_list,shared_memory_ind,group=ctx.memory_nccl_group,async_op=True) handle1 = dist.all_gather(shared_id_list,shared_memory_ind,group=ctx.memory_nccl_group,async_op=True)
self.historical_cache.add_shared_to_queue(handle0,handle1,shared_id_list,shared_list) self.historical_cache.add_shared_to_queue(handle0,handle1,shared_id_list,shared_list)
......
...@@ -22,7 +22,7 @@ class LocalNegativeSampling(NegativeSampling): ...@@ -22,7 +22,7 @@ class LocalNegativeSampling(NegativeSampling):
dst_node_list: torch.Tensor = None, dst_node_list: torch.Tensor = None,
local_mask = None, local_mask = None,
seed = None, seed = None,
prob = None ada_param = None
): ):
super(LocalNegativeSampling,self).__init__(mode,amount,unique=unique) super(LocalNegativeSampling,self).__init__(mode,amount,unique=unique)
self.src_node_list = src_node_list.to('cpu') if src_node_list is not None else None self.src_node_list = src_node_list.to('cpu') if src_node_list is not None else None
...@@ -38,7 +38,11 @@ class LocalNegativeSampling(NegativeSampling): ...@@ -38,7 +38,11 @@ class LocalNegativeSampling(NegativeSampling):
self.local_mask = local_mask self.local_mask = local_mask
if self.local_mask is not None: if self.local_mask is not None:
self.local_dst = dst_node_list[local_mask] self.local_dst = dst_node_list[local_mask]
self.prob = prob self.ada_param = ada_param
self.remote_weight = None
self.local_weight = None
#self.prob = prob
#self.rdm.manual_seed(42) #self.rdm.manual_seed(42)
#print('dst_nde_list {}\n'.format(dst_node_list)) #print('dst_nde_list {}\n'.format(dst_node_list))
def is_binary(self) -> bool: def is_binary(self) -> bool:
...@@ -50,6 +54,7 @@ class LocalNegativeSampling(NegativeSampling): ...@@ -50,6 +54,7 @@ class LocalNegativeSampling(NegativeSampling):
def sample(self, num_samples: int, def sample(self, num_samples: int,
num_nodes: Optional[int] = None) -> Tensor: num_nodes: Optional[int] = None) -> Tensor:
r"""Generates :obj:`num_samples` negative samples.""" r"""Generates :obj:`num_samples` negative samples."""
if self.is_binary(): if self.is_binary():
if self.src_node_list is None or self.dst_node_list is None: if self.src_node_list is None or self.dst_node_list is None:
return torch.randint(num_nodes, (num_samples, )),torch.randint(num_nodes, (num_samples, )) return torch.randint(num_nodes, (num_samples, )),torch.randint(num_nodes, (num_samples, ))
...@@ -60,10 +65,16 @@ class LocalNegativeSampling(NegativeSampling): ...@@ -60,10 +65,16 @@ class LocalNegativeSampling(NegativeSampling):
if self.dst_node_list is None: if self.dst_node_list is None:
return torch.randint(num_nodes, (num_samples, ),generator=self.rdm) return torch.randint(num_nodes, (num_samples, ),generator=self.rdm)
elif self.local_mask is not None: elif self.local_mask is not None:
prob = self.ada_param.beta if self.ada_param is not None else 1.0
remote_ratio = self.local_dst.shape[0] / self.dst_node_list.shape[0]
#train_ratio_pos = (1 - args.probability) + args.probability * remote_ratio
#train_ratio_neg = args.probability * (1-remote_ratio)
self.train_ratio_pos = 1.0 / (1 - prob + prob * remote_ratio) if 0 < prob < 1 else 1
self.train_ratio_neg = 1.0 / (prob * remote_ratio) if 0 < prob < 1 else 1
p = torch.rand(size=(num_samples,)) p = torch.rand(size=(num_samples,))
sr = self.dst_node_list[torch.randint(len(self.dst_node_list), (num_samples, ),generator=self.rdm)] sr = self.dst_node_list[torch.randint(len(self.dst_node_list), (num_samples, ),generator=self.rdm)]
sl = self.local_dst[torch.randint(len(self.local_dst), (num_samples, ),generator=self.rdm)] sl = self.local_dst[torch.randint(len(self.local_dst), (num_samples, ),generator=self.rdm)]
s=torch.where(p<=self.prob,sr,sl) s=torch.where(p<=prob,sr,sl)
return s return s
else: else:
s = torch.randint(len(self.dst_node_list), (num_samples, ),generator=self.rdm) s = torch.randint(len(self.dst_node_list), (num_samples, ),generator=self.rdm)
......
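As a standalone illustration (illustrative values only; in the module the mixing probability comes from ada_param.beta and the destination lists from the loaded partition), the torch.where-based mixture used above reduces to:

import torch

beta = 0.1                    # mixing probability (ada_param.beta)
dst_all = torch.arange(100)   # all destination candidates
dst_local = torch.arange(30)  # destinations local to this partition
num_samples = 8
p = torch.rand(num_samples)
from_all = dst_all[torch.randint(len(dst_all), (num_samples,))]
from_local = dst_local[torch.randint(len(dst_local), (num_samples,))]
neg = torch.where(p <= beta, from_all, from_local)  # sample from the full list with probability beta, otherwise locally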
...@@ -96,8 +96,8 @@ class NeighborSampler(BaseSampler): ...@@ -96,8 +96,8 @@ class NeighborSampler(BaseSampler):
local_part = -1, local_part = -1,
node_part = None, node_part = None,
edge_part = None, edge_part = None,
probability = 1,
no_neg = False, no_neg = False,
ada_param = None
) -> None: ) -> None:
r"""__init__ r"""__init__
Args: Args:
...@@ -140,9 +140,9 @@ class NeighborSampler(BaseSampler): ...@@ -140,9 +140,9 @@ class NeighborSampler(BaseSampler):
else: else:
assert tnb is not None assert tnb is not None
self.tnb = tnb self.tnb = tnb
self.ada_param = ada_param
self.p_sampler = starrygl.sampler_ops.ParallelSampler(self.tnb, num_nodes, graph_data.num_edges, workers, self.p_sampler = starrygl.sampler_ops.ParallelSampler(self.tnb, num_nodes, graph_data.num_edges, workers,
fanout, num_layers, policy, local_part,edge_part.to(torch.int),node_part.to(torch.int),probability) fanout, num_layers, policy, local_part, edge_part.to(torch.int), node_part.to(torch.int), ada_param.beta if ada_param is not None else 1)
def _get_sample_info(self): def _get_sample_info(self):
return self.num_nodes,self.num_layers,self.fanout,self.workers return self.num_nodes,self.num_layers,self.fanout,self.workers
...@@ -229,7 +229,7 @@ class NeighborSampler(BaseSampler): ...@@ -229,7 +229,7 @@ class NeighborSampler(BaseSampler):
sampled_edge_index_list: the edge sampled sampled_edge_index_list: the edge sampled
""" """
if self.policy != 'identity': if self.policy != 'identity':
self.p_sampler.neighbor_sample_from_nodes(nodes.contiguous(), ts.contiguous(), None) self.p_sampler.neighbor_sample_from_nodes(nodes.contiguous(), ts.contiguous(), None, self.ada_param.beta if self.ada_param is not None else None)
ret = self.p_sampler.get_ret() ret = self.p_sampler.get_ret()
else: else:
ret = None ret = None
......
theta.png

55.6 KB

topk.png

61.2 KB