Commit 34f10818 by xxx

first commit

parent 71cb0206
......@@ -42,6 +42,7 @@ if(WITH_CUDA)
target_link_libraries(uvm_ops PRIVATE ${TORCH_LIBRARIES})
endif()
if(WITH_METIS)
# add_definitions(-DWITH_METIS)
# set(GKLIB_DIR "${CMAKE_SOURCE_DIR}/third_party/GKlib")
......@@ -164,6 +165,7 @@ add_library(${SAMLPER_NAME} SHARED ${SAMPLER_SRCS})
# Per-target build settings for the sampler extension library.
target_include_directories(${SAMLPER_NAME} PRIVATE "csrc/sampler/include")
target_compile_options(${SAMLPER_NAME} PRIVATE -O3)
# TORCH_LIBRARIES holds the libraries that get linked below, not include directories.
message(STATUS "Torch libraries: ${TORCH_LIBRARIES}")
target_link_libraries(${SAMLPER_NAME} PRIVATE ${TORCH_LIBRARIES})
# Defines the module name as lib<target name>; target_compile_definitions() strips a
# leading -D itself, so the bare NAME=value form is used.
target_compile_definitions(${SAMLPER_NAME} PRIVATE TORCH_EXTENSION_NAME=lib${SAMLPER_NAME})
......
sampling:
- layer: 2
neighbor:
- 20
- 20
strategy: 'uniform'
prop_time: False
history: 1
duration: 0
num_thread: 32
memory:
- type: 'none'
dim_out: 0
gnn:
- arch: 'transformer_attention'
use_src_emb: False
use_dst_emb: False
layer: 2
att_head: 2
dim_time: 100
dim_out: 100
train:
- epoch: 50
batch_size: 200
lr: 0.0001
dropout: 0.1
att_dropout: 0.1
all_on_gpu: True
\ No newline at end of file
......@@ -21,7 +21,7 @@ gnn:
dim_out: 100
train:
- epoch: 50
batch_size: 1000
batch_size: 200
lr: 0.0001
dropout: 0.1
att_dropout: 0.1
......
sampling:
- layer: 1
neighbor:
- 10
- 20
strategy: 'recent'
prop_time: False
history: 1
......@@ -13,6 +13,8 @@ memory:
deliver_to: 'self'
mail_combine: 'last'
memory_update: 'gru'
historical: False
async: True
mailbox_size: 1
combine_node_feature: True
dim_out: 100
......@@ -25,10 +27,10 @@ gnn:
dim_time: 100
dim_out: 100
train:
- epoch: 5
batch_size: 1000
- epoch: 100
batch_size: 600
# reorder: 16
lr: 0.0001
lr: 0.0005
dropout: 0.2
att_dropout: 0.2
all_on_gpu: True
sampling:
- layer: 1
neighbor:
- 10
strategy: 'recent'
prop_time: False
history: 1
duration: 0
num_thread: 32
memory:
- type: 'node'
dim_time: 100
deliver_to: 'self'
mail_combine: 'last'
memory_update: 'gru'
mailbox_size: 1
combine_node_feature: True
dim_out: 100
gnn:
- arch: 'transformer_attention'
use_src_emb: False
use_dst_emb: False
layer: 1
att_head: 2
dim_time: 100
dim_out: 100
train:
- epoch: 5
batch_size: 1000
# reorder: 16
lr: 0.0001
dropout: 0.2
att_dropout: 0.2
all_on_gpu: True
\ No newline at end of file
sampling:
- layer: 1
neighbor:
- 10
strategy: 'recent'
prop_time: False
history: 1
duration: 0
num_thread: 32
memory:
- type: 'node'
dim_time: 100
deliver_to: 'self'
mail_combine: 'last'
memory_update: 'gru'
mailbox_size: 1
combine_node_feature: True
dim_out: 100
gnn:
- arch: 'transformer_attention'
use_src_emb: False
use_dst_emb: False
layer: 1
att_head: 2
dim_time: 100
dim_out: 100
train:
- epoch: 5
batch_size: 1000
# reorder: 16
lr: 0.0001
dropout: 0.2
att_dropout: 0.2
all_on_gpu: True
\ No newline at end of file
sampling:
- layer: 1
neighbor:
- 10
strategy: 'recent'
prop_time: False
history: 1
duration: 0
num_thread: 32
memory:
- type: 'node'
dim_time: 100
deliver_to: 'self'
mail_combine: 'last'
memory_update: 'gru'
mailbox_size: 1
combine_node_feature: True
dim_out: 100
gnn:
- arch: 'transformer_attention'
use_src_emb: False
use_dst_emb: False
layer: 1
att_head: 2
dim_time: 100
dim_out: 100
train:
- epoch: 50
batch_size: 600
# reorder: 16
lr: 0.0001
dropout: 0.2
att_dropout: 0.2
all_on_gpu: True
\ No newline at end of file
sampling:
- layer: 1
neighbor:
- 20
strategy: 'recent'
prop_time: False
history: 1
duration: 0
num_thread: 32
memory:
- type: 'node'
dim_time: 100
deliver_to: 'self'
mail_combine: 'last'
memory_update: 'gru'
historical: False
async: True
mailbox_size: 1
combine_node_feature: True
dim_out: 100
gnn:
- arch: 'transformer_attention'
use_src_emb: False
use_dst_emb: False
layer: 1
att_head: 2
dim_time: 100
dim_out: 100
train:
- epoch: 50
batch_size: 3000
# reorder: 16
lr: 0.0005
dropout: 0.2
att_dropout: 0.2
all_on_gpu: True
......@@ -16,6 +16,7 @@ memory:
mailbox_size: 1
combine_node_feature: True
dim_out: 100
historical: True
gnn:
- arch: 'transformer_attention'
use_src_emb: True
......@@ -24,6 +25,7 @@ gnn:
att_head: 2
dim_time: 100
dim_out: 100
historical: True
train:
- epoch: 50
batch_size: 1000
......
sampling:
- layer: 1
neighbor:
- 10
strategy: 'recent'
prop_time: False
history: 1
duration: 0
num_thread: 32
memory:
- type: 'node'
#'node'
dim_time: 100
deliver_to: 'self'
mail_combine: 'last'
memory_update: 'gru'
mailbox_size: 1
combine_node_feature: True
dim_out: 100
historical: True
gnn:
- arch: 'transformer_attention'
use_src_emb: False
use_dst_emb: False
layer: 1
att_head: 2
dim_time: 100
dim_out: 100
historical: True
train:
- epoch: 20
batch_size: 3000
# reorder: 16
lr: 0.0002
dropout: 0.2
att_dropout: 0.2
all_on_gpu: True
\ No newline at end of file
sampling:
- layer: 1
neighbor:
- 10
strategy: 'recent'
prop_time: False
history: 1
duration: 0
num_thread: 32
memory:
- type: 'node'
dim_time: 100
deliver_to: 'self'
mail_combine: 'last'
memory_update: 'smart'
mailbox_size: 1
combine_node_feature: True
dim_out: 100
gnn:
- arch: 'transformer_attention'
use_src_emb: False
use_dst_emb: False
layer: 1
att_head: 2
dim_time: 100
dim_out: 100
train:
- epoch: 50
batch_size: 3000
# reorder: 16
lr: 0.0002
dropout: 0.1
att_dropout: 0.1
all_on_gpu: True
train_neg_samples: 1
eval_neg_samples: 1
sampling:
- layer: 2
neighbor:
- 10
- 10
strategy: 'uniform'
prop_time: False
history: 1
duration: 0
num_thread: 32
gnn:
- arch: 'transformer_attention'
use_src_emb: False
use_dst_emb: False
layer: 1
att_head: 2
dim_time: 100
dim_out: 100
train:
- epoch: 100
batch_size: 3000
# reorder: 16
lr: 0.0001
dropout: 0.2
att_dropout: 0.2
all_on_gpu: True
\ No newline at end of file
#include "head.h"
struct
\ No newline at end of file
#include<head.h>
#include <sampler.h>
#include <tppr.h>
#include <output.h>
#include <neighbors.h>
#include <temporal_utils.h>
/*------------Python Bind--------------------------------------------------------------*/
// Python module definition. The module name comes from the TORCH_EXTENSION_NAME
// compile definition. Registers the free helper functions first, then the data
// classes and the two computation classes (ParallelSampler, ParallelTppRComputer).
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
// Free functions; every one is registered with return_value_policy::reference.
m
.def("get_neighbors",
&get_neighbors,
py::return_value_policy::reference)
.def("heads_unique",
&heads_unique,
py::return_value_policy::reference)
.def("divide_nodes_to_part",
&divide_nodes_to_part,
py::return_value_policy::reference)
.def("sparse_get_index",
&sparse_get_index,
py::return_value_policy::reference)
.def("get_norm_temporal",
&get_norm_temporal,
py::return_value_policy::reference
);
// TemporalGraphBlock: vector members are exposed as methods that convert through
// vecToTensor; the three scalar statistics are exposed as read-only attributes.
// Only the three-vector constructor is bound.
py::class_<TemporalGraphBlock>(m, "TemporalGraphBlock")
.def(py::init<vector<NodeIDType> &, vector<NodeIDType> &,
vector<NodeIDType> &>())
.def("row", [](const TemporalGraphBlock &tgb) { return vecToTensor<NodeIDType>(tgb.row); })
.def("col", [](const TemporalGraphBlock &tgb) { return vecToTensor<NodeIDType>(tgb.col); })
.def("eid", [](const TemporalGraphBlock &tgb) { return vecToTensor<EdgeIDType>(tgb.eid); })
.def("delta_ts", [](const TemporalGraphBlock &tgb) { return vecToTensor<TimeStampType>(tgb.delta_ts); })
.def("src_index", [](const TemporalGraphBlock &tgb) { return vecToTensor<EdgeIDType>(tgb.src_index); })
.def("sample_nodes", [](const TemporalGraphBlock &tgb) { return vecToTensor<NodeIDType>(tgb.sample_nodes); })
.def("sample_nodes_ts", [](const TemporalGraphBlock &tgb) { return vecToTensor<TimeStampType>(tgb.sample_nodes_ts); })
.def_readonly("sample_time", &TemporalGraphBlock::sample_time, py::return_value_policy::reference)
.def_readonly("tot_time", &TemporalGraphBlock::tot_time, py::return_value_policy::reference)
.def_readonly("sample_edge_num", &TemporalGraphBlock::sample_edge_num, py::return_value_policy::reference);
// T_TemporalGraphBlock: tensor-based counterpart; all fields are read-only attributes
// (here row/col/... are attributes, whereas on TemporalGraphBlock they are methods).
py::class_<T_TemporalGraphBlock>(m, "T_TemporalGraphBlock")
.def(py::init<th::Tensor &, th::Tensor &,
th::Tensor &>())
.def_readonly("row", &T_TemporalGraphBlock::row, py::return_value_policy::reference)
.def_readonly("col", &T_TemporalGraphBlock::col, py::return_value_policy::reference)
.def_readonly("eid", &T_TemporalGraphBlock::eid, py::return_value_policy::reference)
.def_readonly("delta_ts", &T_TemporalGraphBlock::delta_ts, py::return_value_policy::reference)
.def_readonly("src_index", &T_TemporalGraphBlock::src_index, py::return_value_policy::reference)
.def_readonly("sample_nodes", &T_TemporalGraphBlock::sample_nodes, py::return_value_policy::reference)
.def_readonly("sample_nodes_ts", &T_TemporalGraphBlock::sample_nodes_ts, py::return_value_policy::reference)
.def_readonly("sample_time", &T_TemporalGraphBlock::sample_time, py::return_value_policy::reference)
.def_readonly("tot_time", &T_TemporalGraphBlock::tot_time, py::return_value_policy::reference)
.def_readonly("sample_edge_num", &T_TemporalGraphBlock::sample_edge_num, py::return_value_policy::reference);
// TemporalNeighborList: three tensors plus three lists of tensors, all read-only.
py::class_<TemporalNeighborList>(m,"TemporalNeighborList")
.def(py::init<th::Tensor &, th::Tensor &, th::Tensor &,
vector<th::Tensor> &, vector<th::Tensor> &,vector<th::Tensor> &>())
.def_readonly("nids",&TemporalNeighborList::nids,py::return_value_policy::reference)
.def_readonly("ts",&TemporalNeighborList::ts,py::return_value_policy::reference)
.def_readonly("eids",&TemporalNeighborList::eids,py::return_value_policy::reference)
.def_readonly("uid",&TemporalNeighborList::uid,py::return_value_policy::reference)
.def_readonly("src_index",&TemporalNeighborList::src_index,py::return_value_policy::reference)
.def_readonly("ueid",&TemporalNeighborList::ueid,py::return_value_policy::reference);
// TemporalNeighborBlock: supports Python pickling through its serialize()/deserialize()
// members, exposes the update_* mutators and read-only views of its storage and flags.
py::class_<TemporalNeighborBlock>(m, "TemporalNeighborBlock")
.def(py::init<vector<vector<NodeIDType>>&,
vector<int64_t> &>())
.def(py::pickle(
[](const TemporalNeighborBlock& tnb) { return tnb.serialize(); },
[](const std::string& s) { return TemporalNeighborBlock::deserialize(s); }
))
.def("update_neighbors_with_time",
&TemporalNeighborBlock::update_neighbors_with_time)
.def("update_edge_weight",
&TemporalNeighborBlock::update_edge_weight)
.def("update_node_weight",
&TemporalNeighborBlock::update_node_weight)
.def("update_all_node_weight",
&TemporalNeighborBlock::update_all_node_weight)
// .def("get_node_neighbor",&TemporalNeighborBlock::get_node_neighbor)
// .def("get_node_deg", &TemporalNeighborBlock::get_node_deg)
.def_readonly("neighbors", &TemporalNeighborBlock::neighbors, py::return_value_policy::reference)
.def_readonly("timestamp", &TemporalNeighborBlock::timestamp, py::return_value_policy::reference)
.def_readonly("edge_weight", &TemporalNeighborBlock::edge_weight, py::return_value_policy::reference)
.def_readonly("eid", &TemporalNeighborBlock::eid, py::return_value_policy::reference)
.def_readonly("deg", &TemporalNeighborBlock::deg, py::return_value_policy::reference)
.def_readonly("with_eid", &TemporalNeighborBlock::with_eid, py::return_value_policy::reference)
.def_readonly("with_timestamp", &TemporalNeighborBlock::with_timestamp, py::return_value_policy::reference)
.def_readonly("weighted", &TemporalNeighborBlock::weighted, py::return_value_policy::reference);
// ParallelSampler: `ret` is readable as an attribute (reference policy) and through
// get_ret(), whose lambda returns ps.ret by value.
py::class_<ParallelSampler>(m, "ParallelSampler")
.def(py::init<TemporalNeighborBlock &, NodeIDType, EdgeIDType, int,
vector<int>&, int, string, int, th::Tensor &>())
.def_readonly("ret", &ParallelSampler::ret, py::return_value_policy::reference)
.def("neighbor_sample_from_nodes", &ParallelSampler::neighbor_sample_from_nodes)
.def("reset", &ParallelSampler::reset)
.def("get_ret", [](const ParallelSampler &ps) { return ps.ret; })
.def_readonly("block",&ParallelSampler::block, py::return_value_policy::reference);
// ParallelTppRComputer: exposes its reset/backup/restore and top-k entry points;
// `ret` is available as an attribute and through get_ret() (returned by value).
py::class_<ParallelTppRComputer>(m, "ParallelTppRComputer")
.def(py::init<TemporalNeighborBlock &, NodeIDType, EdgeIDType, int,
int, int, int, vector<float>&, vector<float>& >())
.def_readonly("ret", &ParallelTppRComputer::ret, py::return_value_policy::reference)
.def("reset_ret", &ParallelTppRComputer::reset_ret)
.def("reset_tppr", &ParallelTppRComputer::reset_tppr)
.def("reset_val_tppr", &ParallelTppRComputer::reset_val_tppr)
.def("backup_tppr", &ParallelTppRComputer::backup_tppr)
.def("restore_tppr", &ParallelTppRComputer::restore_tppr)
.def("restore_val_tppr", &ParallelTppRComputer::restore_val_tppr)
.def("get_pruned_topk", &ParallelTppRComputer::get_pruned_topk)
.def("extract_streaming_tppr", &ParallelTppRComputer::extract_streaming_tppr)
.def("streaming_topk", &ParallelTppRComputer::streaming_topk)
.def("single_streaming_topk", &ParallelTppRComputer::single_streaming_topk)
.def("streaming_topk_no_fake", &ParallelTppRComputer::streaming_topk_no_fake)
.def("compute_val_tppr", &ParallelTppRComputer::compute_val_tppr)
.def("get_ret", [](const ParallelTppRComputer &ps) { return ps.ret; });
}
\ No newline at end of file
#pragma once
#include <iostream>
#include <algorithm>
#include <torch/extension.h>
#include <omp.h>
#include <time.h>
#include <random>
#include <parallel_hashmap/phmap.h>
#include <pybind11/pybind11.h>
#include <pybind11/numpy.h>
#include <pybind11/stl.h>
#include <atomic>
using namespace std;
namespace py = pybind11;
namespace th = torch;
typedef int64_t NodeIDType;
typedef int64_t EdgeIDType;
typedef float WeightType;
typedef float TimeStampType;
typedef tuple<NodeIDType, EdgeIDType, TimeStampType> PPRKeyType;
typedef double PPRValueType;
typedef phmap::parallel_flat_hash_map<PPRKeyType, PPRValueType> PPRDictType;
typedef vector<PPRDictType> PPRListDictType;
typedef vector<vector<PPRDictType>> PPRListListDictType;
typedef vector<vector<double>> NormListType;
class TemporalNeighborBlock;
class TemporalGraphBlock;
class ParallelSampler;
TemporalNeighborBlock& get_neighbors(string graph_name, th::Tensor row, th::Tensor col, int64_t num_nodes, int is_distinct, optional<th::Tensor> eid, optional<th::Tensor> dist_eid,optional<th::Tensor> edge_weight, optional<th::Tensor> time);
th::Tensor heads_unique(th::Tensor array, th::Tensor heads, int threads);
int nodeIdToInOut(NodeIDType nid, int pid, const vector<NodeIDType>& part_ptr);
int nodeIdToPartId(NodeIDType nid, const vector<NodeIDType>& part_ptr);
vector<th::Tensor> divide_nodes_to_part(th::Tensor nodes, const vector<NodeIDType>& part_ptr, int threads);
NodeIDType sample_multinomial(const vector<WeightType>& weights, default_random_engine& e);
vector<int64_t> sample_max(const vector<WeightType>& weights, int k);
// Mutex type passed to the phmap containers declared below.
#define MTX std::mutex
// Trailing template arguments shared by HashT and HashM: phmap's default hash and
// equality functors for K, std::allocator<K>, the literal 10 and MTX.
// NOTE(review): presumably 10 is phmap's submap-count exponent and MTX the per-submap
// lock type; confirm against the parameter order in phmap.h. HashM also passes
// std::allocator<K> although its elements are key/value pairs - verify this is intended.
#define EXTRAARGS , phmap::priv::hash_default_hash<K>, \
phmap::priv::hash_default_eq<K>, \
std::allocator<K>, 10, MTX
// Set of K: phmap::parallel_flat_hash_set configured with EXTRAARGS.
template <class K>
using HashT = phmap::parallel_flat_hash_set<K EXTRAARGS>;
// Map from K to V: phmap::parallel_flat_hash_map configured with EXTRAARGS.
template <class K, class V>
using HashM = phmap::parallel_flat_hash_map<K, V EXTRAARGS>;
// Helper functions

/**
 * Expose the contents of `vec` to Python as a NumPy array.
 *
 * The elements are copied into a heap-allocated std::vector<T>. That copy is handed to a
 * py::capsule which serves as the array's base object, so the copy is deleted when the
 * capsule is destroyed (see https://github.com/pybind/pybind11/issues/1042).
 * The caller's vector is not referenced by the returned array.
 */
template<typename T>
inline py::array vec2npy(const std::vector<T> &vec)
{
    std::vector<T> *owned = new std::vector<T>(vec);
    py::capsule keeper(owned, [](void *raw)
    {
        delete reinterpret_cast<std::vector<T> *>(raw);
    });
    return py::array(owned->size(), owned->data(), keeper);
}
/**
 * Return the typed data pointer of `tensor`.
 *
 * Asserts (via AT_ASSERTM) that the tensor is contiguous and one-dimensional before
 * handing out the pointer obtained from data_ptr<T>().
 */
template <typename T>
T* get_data_ptr(const th::Tensor& tensor) {
    AT_ASSERTM(tensor.is_contiguous(), "Offset tensor must be contiguous");
    AT_ASSERTM(tensor.dim() == 1, "Offset tensor must be one-dimensional");
    T* typed_data = tensor.data_ptr<T>();
    return typed_data;
}
/**
 * Convert a std::vector<T> into a one-dimensional torch::Tensor.
 *
 * Supported element types are int64_t (-> torch::kInt64) and float (-> torch::kFloat32);
 * any other T throws std::runtime_error("Unsupported data type").
 *
 * The returned tensor owns a copy of the data. torch::from_blob only wraps the vector's
 * buffer without taking ownership, so returning that wrapper directly leaves the tensor
 * pointing at freed memory as soon as the vector is destroyed or reallocated. The
 * accessors that hand these tensors to Python cannot guarantee the vector outlives the
 * tensor, hence the clone.
 */
template <typename T>
torch::Tensor vecToTensor(const std::vector<T>& vec) {
    // Pick the tensor dtype that matches T.
    torch::ScalarType dtype;
    if (std::is_same<T, int64_t>::value) {
        dtype = torch::kInt64;
    } else if (std::is_same<T, float>::value) {
        dtype = torch::kFloat32;
    } else {
        throw std::runtime_error("Unsupported data type");
    }
    // Non-owning wrapper over the vector's buffer. The size is cast to int64_t (the
    // element type of tensor shapes) rather than long, which is 32-bit on LLP64 platforms.
    torch::Tensor wrapped = torch::from_blob(
        const_cast<T*>(vec.data()),
        {static_cast<int64_t>(vec.size())},
        dtype
    );
    // Copy so that the result no longer depends on the lifetime of `vec`.
    return wrapped.clone();
}
/*-------------------------------------------------------------------------------------**
**------------Utils--------------------------------------------------------------------**
**-------------------------------------------------------------------------------------*/
/**
 * Build a tensor that starts with the entries of `heads` (in their given order) followed
 * by the distinct values of `array` that do not occur in `heads`.
 *
 * If `heads` is empty the result is just the distinct values of `array`. The order of the
 * values taken from the hash set is the set's iteration order.
 *
 * `threads` is kept for interface compatibility but no longer used. The previous version
 * called count() on the set outside the critical section while other threads erased from
 * it inside the section; that is a data race on a container with no internal locking.
 * erase() on a missing key is a no-op, so a plain serial loop gives the same result.
 */
th::Tensor heads_unique(th::Tensor array, th::Tensor heads, int threads){
    (void)threads;
    auto array_ptr = array.data_ptr<NodeIDType>();
    phmap::parallel_flat_hash_set<NodeIDType> s(array_ptr, array_ptr+array.numel());
    if(heads.numel()==0) return th::tensor(vector<NodeIDType>(s.begin(), s.end()));

    AT_ASSERTM(heads.is_contiguous(), "Offset tensor must be contiguous");
    AT_ASSERTM(heads.dim() == 1, "Offset tensor must be one-dimensional");
    auto heads_ptr = heads.data_ptr<NodeIDType>();
    const int64_t num_heads = heads.size(0);

    // Drop every head value from the set so it is not emitted a second time below.
    for(int64_t i=0; i<num_heads; i++){
        s.erase(heads_ptr[i]);
    }

    vector<NodeIDType> ret;
    ret.reserve(s.size()+heads.numel());
    ret.assign(heads_ptr, heads_ptr+heads.numel());
    ret.insert(ret.end(), s.begin(), s.end());
    return th::tensor(ret);
}
/**
 * Find the partition whose half-open range [part_ptr[i], part_ptr[i+1]) contains `nid`.
 *
 * @param nid       node id to look up
 * @param part_ptr  partition boundaries; partition i spans part_ptr[i] .. part_ptr[i+1]-1
 * @return index of the first matching partition
 * @throws const char* (string literal) when no partition contains `nid`, including when
 *         `part_ptr` holds fewer than two boundaries
 */
int nodeIdToPartId(NodeIDType nid, const vector<NodeIDType>& part_ptr){
    // "i + 1 < size()" instead of "i < size() - 1": the latter underflows to a huge
    // unsigned value for an empty vector and then reads out of bounds.
    for(size_t i=0; i+1<part_ptr.size(); i++){
        if(nid>=part_ptr[i]&&nid<part_ptr[i+1]){
            return static_cast<int>(i);
        }
    }
    throw "nid 不存在对应的分区";
}
/**
 * Classify `nid` relative to partition `pid`.
 *
 * @return 0 (inner) when part_ptr[pid] <= nid < part_ptr[pid+1], otherwise 1 (outer).
 *         `pid` is used as an index without any bounds check.
 */
int nodeIdToInOut(NodeIDType nid, int pid, const vector<NodeIDType>& part_ptr){
    const bool inside = nid>=part_ptr[pid] && nid<part_ptr[pid+1];
    return inside ? 0 : 1;
}
/**
 * Split a 1-D tensor of node ids into one tensor per partition.
 *
 * @param nodes     contiguous one-dimensional tensor of node ids
 * @param part_ptr  partition boundaries; partition i covers [part_ptr[i], part_ptr[i+1])
 * @param threads   number of OpenMP threads for the bucketing pass (values < 1 are
 *                  treated as 1)
 * @return part_ptr.size()-1 tensors (none when fewer than two boundaries are given);
 *         entry i holds the ids that fall into partition i
 * @throws const char* when some id belongs to no partition. The lookup error raised
 *         inside the parallel loop is captured and rethrown after the loop, because an
 *         exception must not leave an OpenMP parallel region.
 */
vector<th::Tensor> divide_nodes_to_part(
        th::Tensor nodes, const vector<NodeIDType>& part_ptr, int threads){
    AT_ASSERTM(nodes.is_contiguous(), "Offset tensor must be contiguous");
    AT_ASSERTM(nodes.dim() == 1, "Offset tensor must be one-dimensional");
    const NodeIDType* nodes_id = nodes.data_ptr<NodeIDType>();
    const int64_t num_nodes = nodes.size(0);

    // size()-1 would underflow for an empty boundary list, so derive the count safely.
    const int num_parts = part_ptr.size() < 2 ? 0 : static_cast<int>(part_ptr.size() - 1);
    if(num_parts == 0){
        // Without partitions no id can be placed anywhere.
        if(num_nodes > 0) throw "nid 不存在对应的分区";
        return vector<th::Tensor>();
    }
    const int num_workers = threads > 0 ? threads : 1;

    // One bucket list per worker thread so the parallel pass needs no locking.
    vector<vector<vector<NodeIDType>>> node_part_threads(
        num_workers, vector<vector<NodeIDType>>(num_parts));

    const char* lookup_error = nullptr;
#pragma omp parallel for num_threads(num_workers) default(shared)
    for(int64_t i=0; i<num_nodes; i++){
        try{
            int pid = nodeIdToPartId(nodes_id[i], part_ptr);
            node_part_threads[omp_get_thread_num()][pid].emplace_back(nodes_id[i]);
        }catch(const char* msg){
#pragma omp critical(divide_nodes_to_part_error)
            lookup_error = msg;
        }
    }
    if(lookup_error != nullptr) throw lookup_error;

    // Concatenate the per-thread buckets of each partition in thread order.
    vector<th::Tensor> result(num_parts);
#pragma omp parallel for num_threads(num_parts) default(shared)
    for(int i = 0; i<num_parts; i++){
        vector<NodeIDType> merged;
        for(int j=0; j<num_workers; j++){
            merged.insert(merged.end(), node_part_threads[j][i].begin(), node_part_threads[j][i].end());
        }
        result[i]=th::tensor(merged);
    }
    return result;
}
/**
 * Draw a pseudo-random float between `min` and `max`.
 *
 * Advances the caller-owned `seed` through rand_r(), scales the result by RAND_MAX to a
 * fraction in [0, 1] and maps that fraction onto the requested interval.
 */
float getRandomFloat(unsigned int *seed, float min, float max) {
    const float fraction = rand_r(seed) / (float) RAND_MAX;
    const float span = max - min;
    return min + fraction * span;
}
/**
 * Draw one index with probability proportional to its weight.
 *
 * Builds the running sum of `weights`, draws a value in [0, total] and returns the
 * position of the first running sum that is not less than the drawn value.
 *
 * @param weights  per-item weights; must be non-empty with a positive sum (asserted)
 * @param e        unused; kept so existing callers keep compiling
 * @return index into `weights`
 *
 * The rand_r state is kept per thread and persists across calls. The previous version
 * re-seeded it with the thread number on every call, so a given thread drew the same
 * value each time and the sampling was not random at all.
 */
NodeIDType sample_multinomial(const vector<WeightType>& weights, default_random_engine& e){
    (void)e;
    AT_ASSERTM(!weights.empty(), "Edge weight list should not be empty.");

    vector<WeightType> cumulative_weights;
    cumulative_weights.reserve(weights.size());
    partial_sum(weights.begin(), weights.end(), back_inserter(cumulative_weights));
    AT_ASSERTM(cumulative_weights.back() > 0, "Edge weight sum should be greater than 0.");

    // Seeded once per thread from the OpenMP thread number, then advanced by every draw.
    static thread_local unsigned int loc_seed = static_cast<unsigned int>(omp_get_thread_num());
    WeightType random_value = getRandomFloat(&loc_seed, 0.0, cumulative_weights.back());

    auto it = lower_bound(cumulative_weights.begin(), cumulative_weights.end(), random_value);
    NodeIDType sample_indice = distance(cumulative_weights.begin(), it);
    return sample_indice;
}
/**
 * Return the indices of the `k` largest weights, largest first.
 *
 * @param weights  values to rank
 * @param k        number of indices wanted; clamped to [0, weights.size()] so that a
 *                 negative or oversized request no longer forms out-of-range iterators
 * @return at most `k` indices into `weights`, ordered by descending weight
 */
vector<int64_t> sample_max(const vector<WeightType>& weights, int k) {
    const size_t total = weights.size();
    const size_t take = k <= 0 ? 0 : min(static_cast<size_t>(k), total);

    vector<int64_t> indices(total);
    for (size_t i = 0; i < total; ++i) {
        indices[i] = static_cast<int64_t>(i);
    }
    // Partial sort: only the first `take` positions need to be in final order.
    partial_sort(indices.begin(), indices.begin() + take, indices.end(),
                 [&weights](int64_t a, int64_t b) { return weights[a] > weights[b]; });
    return vector<int64_t>(indices.begin(), indices.begin() + take);
}
\ No newline at end of file
#pragma once
#include <head.h>
/**
 * Plain data holder with std::vector fields (row, col, eid, delta_ts, src_index,
 * sample_nodes, sample_nodes_ts, e_weights) and three scalar statistics that start at 0.
 * NOTE(review): presumably filled by the sampler for one sampling step - confirm.
 * Constructors copy the vectors they are given; all other fields stay empty/zero.
 */
class TemporalGraphBlock
{
public:
    vector<NodeIDType> row;
    vector<NodeIDType> col;
    vector<EdgeIDType> eid;
    vector<TimeStampType> delta_ts;
    vector<int64_t> src_index;
    vector<NodeIDType> sample_nodes;
    vector<TimeStampType> sample_nodes_ts;
    vector<WeightType> e_weights;
    double sample_time = 0;
    double tot_time = 0;
    int64_t sample_edge_num = 0;

    // Empty block.
    TemporalGraphBlock(){}

    // Block initialised with row, col and sample_nodes.
    TemporalGraphBlock(vector<NodeIDType> &row_in, vector<NodeIDType> &col_in,
                       vector<NodeIDType> &nodes_in)
        : row(row_in), col(col_in), sample_nodes(nodes_in) {}

    // Block initialised with row, col, sample_nodes and sample_nodes_ts.
    TemporalGraphBlock(vector<NodeIDType> &row_in, vector<NodeIDType> &col_in,
                       vector<NodeIDType> &nodes_in,
                       vector<TimeStampType> &nodes_ts_in)
        : row(row_in), col(col_in), sample_nodes(nodes_in),
          sample_nodes_ts(nodes_ts_in) {}
};
/**
 * Tensor-based data holder with the fields row, col, eid, delta_ts, src_index,
 * sample_nodes and sample_nodes_ts plus three scalar statistics that start at 0.
 */
class T_TemporalGraphBlock
{
public:
    th::Tensor row;
    th::Tensor col;
    th::Tensor eid;
    th::Tensor delta_ts;
    th::Tensor src_index;
    th::Tensor sample_nodes;
    th::Tensor sample_nodes_ts;
    double sample_time = 0;
    double tot_time = 0;
    int64_t sample_edge_num = 0;

    // Empty block; every tensor is default-constructed.
    T_TemporalGraphBlock(){}

    // Block initialised with row, col and sample_nodes; the other tensors stay default.
    T_TemporalGraphBlock(th::Tensor &row_in, th::Tensor &col_in,
                         th::Tensor &nodes_in)
        : row(row_in), col(col_in), sample_nodes(nodes_in) {}
};
/**
 * Data holder grouping three tensors (nids, ts, eids) with three lists of tensors
 * (uid, src_index, ueid).
 */
class TemporalNeighborList{
public:
    th::Tensor nids;
    th::Tensor ts;
    th::Tensor eids;
    vector<th::Tensor> uid;
    vector<th::Tensor> src_index;
    vector<th::Tensor> ueid;

    // Empty list; all members are default-constructed.
    TemporalNeighborList(){}

    // Initialise every member from the corresponding argument.
    TemporalNeighborList(th::Tensor &nids_in, th::Tensor &ts_in, th::Tensor &eids_in,
                         vector<th::Tensor> &uid_in, vector<th::Tensor> &src_index_in,
                         vector<th::Tensor> &ueid_in)
        : nids(nids_in), ts(ts_in), eids(eids_in),
          uid(uid_in), src_index(src_index_in), ueid(ueid_in) {}
};
\ No newline at end of file
This source diff could not be displayed because it is too large. You can view the blob instead.
#!/usr/bin/env python
# -*- coding: utf-8 -*-
from conans import ConanFile, tools
import os
class SparseppConan(ConanFile):
    """Conan recipe that packages the header-only parallel_hashmap library."""

    name = "parallel_hashmap"
    version = "1.27"
    description = "A header-only, very fast and memory-friendly hash map"

    # Indicates License type of the packaged library
    license = "https://github.com/greg7mdp/parallel-hashmap/blob/master/LICENSE"

    # Packages the license for the conanfile.py
    exports = ["LICENSE"]

    # Custom attributes for Bincrafters recipe conventions
    source_subfolder = "source_subfolder"

    def source(self):
        """Download the release archive and move it to `source_subfolder`."""
        source_url = "https://github.com/greg7mdp/parallel-hashmap"
        tools.get("{0}/archive/{1}.tar.gz".format(source_url, self.version))
        # The archive unpacks into "<repository>-<version>". The repository is spelled
        # with a hyphen while self.name uses an underscore, so the folder name has to be
        # derived from the URL; building it from self.name makes os.rename fail.
        repo_name = source_url.rsplit("/", 1)[-1]
        extracted_dir = "{0}-{1}".format(repo_name, self.version)

        # Rename to "source_folder" is a convention to simplify later steps
        os.rename(extracted_dir, self.source_subfolder)

    def package(self):
        """Copy the license and every file of the header folder into the package."""
        include_folder = os.path.join(self.source_subfolder, "parallel_hashmap")
        self.copy(pattern="LICENSE")
        self.copy(pattern="*", dst="include/parallel_hashmap", src=include_folder)

    def package_id(self):
        """Mark the package as header-only for package-id computation."""
        self.info.header_only()
#if !defined(spp_memory_h_guard)
#define spp_memory_h_guard
#include <cstdint>
#include <cstring>
#include <cstdlib>
#if defined(_WIN32) || defined( __CYGWIN__)
#define SPP_WIN
#endif
#ifdef SPP_WIN
#include <windows.h>
#include <Psapi.h>
#undef min
#undef max
#elif defined(__linux__)
#include <sys/types.h>
#include <sys/sysinfo.h>
#elif defined(__FreeBSD__)
#include <paths.h>
#include <fcntl.h>
#include <kvm.h>
#include <unistd.h>
#include <sys/sysctl.h>
#include <sys/user.h>
#endif
namespace spp
{
uint64_t GetSystemMemory();
uint64_t GetTotalMemoryUsed();
uint64_t GetProcessMemoryUsed();
uint64_t GetPhysicalMemory();
// Total memory of the system in bytes: the page-file total on Windows, RAM plus swap on
// Linux and FreeBSD. Returns 0 on unsupported platforms or when the query fails.
// Declared inline because this definition lives in a header and would otherwise be
// emitted once per including translation unit.
inline uint64_t GetSystemMemory()
{
#ifdef SPP_WIN
    MEMORYSTATUSEX memInfo;
    memInfo.dwLength = sizeof(MEMORYSTATUSEX);
    if (!GlobalMemoryStatusEx(&memInfo))
        return 0;
    return static_cast<uint64_t>(memInfo.ullTotalPageFile);
#elif defined(__linux__)
    struct sysinfo memInfo;
    if (sysinfo(&memInfo) != 0)
        return 0;
    // Accumulate in 64 bits so the multiplication by mem_unit cannot wrap where
    // unsigned long is 32 bits wide.
    uint64_t totalVirtualMem = memInfo.totalram;

    totalVirtualMem += memInfo.totalswap;
    totalVirtualMem *= memInfo.mem_unit;
    return totalVirtualMem;
#elif defined(__FreeBSD__)
    u_int pageCnt = 0;
    size_t pageCntLen = sizeof(pageCnt);
    const uint64_t pageSize = static_cast<uint64_t>(getpagesize());

    if (sysctlbyname("vm.stats.vm.v_page_count", &pageCnt, &pageCntLen, NULL, 0) != 0)
        return 0;
    // Widen before multiplying: page count times page size exceeds 32 bits beyond 4 GiB.
    uint64_t totalVirtualMem = static_cast<uint64_t>(pageCnt) * pageSize;

    // Swap is added only when the kvm handle and the swap query are both available.
    kvm_t *kd = kvm_open(NULL, _PATH_DEVNULL, NULL, O_RDONLY, "kvm_open");
    if (kd != NULL)
    {
        struct kvm_swap kswap;
        if (kvm_getswapinfo(kd, &kswap, 1, 0) != -1)
            totalVirtualMem += static_cast<uint64_t>(kswap.ksw_total) * pageSize;
        kvm_close(kd);
    }
    return totalVirtualMem;
#else
    return 0;
#endif
}
// Memory currently in use system-wide, in bytes: used page file on Windows, used RAM
// plus used swap on Linux and FreeBSD. Returns 0 on unsupported platforms or when the
// query fails. Declared inline because this definition lives in a header.
inline uint64_t GetTotalMemoryUsed()
{
#ifdef SPP_WIN
    MEMORYSTATUSEX memInfo;
    memInfo.dwLength = sizeof(MEMORYSTATUSEX);
    if (!GlobalMemoryStatusEx(&memInfo))
        return 0;
    return static_cast<uint64_t>(memInfo.ullTotalPageFile - memInfo.ullAvailPageFile);
#elif defined(__linux__)
    struct sysinfo memInfo;
    if (sysinfo(&memInfo) != 0)
        return 0;
    // 64-bit accumulator so the multiplication by mem_unit cannot wrap.
    uint64_t virtualMemUsed = memInfo.totalram - memInfo.freeram;

    virtualMemUsed += memInfo.totalswap - memInfo.freeswap;
    virtualMemUsed *= memInfo.mem_unit;
    return virtualMemUsed;
#elif defined(__FreeBSD__)
    u_int pageCnt = 0, freeCnt = 0;
    size_t pageCntLen = sizeof(pageCnt);
    size_t freeCntLen = sizeof(freeCnt);
    const uint64_t pageSize = static_cast<uint64_t>(getpagesize());

    if (sysctlbyname("vm.stats.vm.v_page_count", &pageCnt, &pageCntLen, NULL, 0) != 0 ||
        sysctlbyname("vm.stats.vm.v_free_count", &freeCnt, &freeCntLen, NULL, 0) != 0)
        return 0;
    // Widen before multiplying to avoid 32-bit overflow.
    uint64_t virtualMemUsed = static_cast<uint64_t>(pageCnt - freeCnt) * pageSize;

    // Swap usage is added only when the kvm handle and the swap query are available.
    kvm_t *kd = kvm_open(NULL, _PATH_DEVNULL, NULL, O_RDONLY, "kvm_open");
    if (kd != NULL)
    {
        struct kvm_swap kswap;
        if (kvm_getswapinfo(kd, &kswap, 1, 0) != -1)
            virtualMemUsed += static_cast<uint64_t>(kswap.ksw_used) * pageSize;
        kvm_close(kd);
    }
    return virtualMemUsed;
#else
    return 0;
#endif
}
// Memory used by the current process, in bytes: PrivateUsage on Windows, the VmSize
// entry of /proc/self/status on Linux, resident pages times page size on FreeBSD.
// Returns 0 on unsupported platforms or when the value cannot be obtained.
// Declared inline because this definition lives in a header.
inline uint64_t GetProcessMemoryUsed()
{
#ifdef SPP_WIN
    PROCESS_MEMORY_COUNTERS_EX pmc;
    if (!GetProcessMemoryInfo(GetCurrentProcess(), reinterpret_cast<PPROCESS_MEMORY_COUNTERS>(&pmc), sizeof(pmc)))
        return 0;
    return static_cast<uint64_t>(pmc.PrivateUsage);
#elif defined(__linux__)
    auto file = fopen("/proc/self/status", "r");
    if (file == nullptr)
        return 0;

    uint64_t sizeKb = 0;
    char line[128];

    while (fgets(line, sizeof(line), file) != nullptr)
    {
        if (strncmp(line, "VmSize:", 7) == 0)
        {
            // strtoull skips the blanks after the label and stops at the first
            // non-digit (the " kB" suffix), so the buffer never has to be modified.
            // The old parser wrote a terminator at an index computed before the
            // pointer was advanced, which could land outside the line.
            sizeKb = strtoull(line + 7, nullptr, 10);
            break;
        }
    }

    fclose(file);
    // The status file reports kB.
    return sizeKb * 1024;
#elif defined(__FreeBSD__)
    struct kinfo_proc info;
    size_t infoLen = sizeof(info);
    int mib[] = { CTL_KERN, KERN_PROC, KERN_PROC_PID, getpid() };

    if (sysctl(mib, sizeof(mib) / sizeof(*mib), &info, &infoLen, NULL, 0) != 0)
        return 0;
    // Widen before multiplying to avoid overflow in the page arithmetic.
    return static_cast<uint64_t>(info.ki_rssize) * static_cast<uint64_t>(getpagesize());
#else
    return 0;
#endif
}
// Installed physical memory in bytes. Returns 0 on unsupported platforms or when the
// query fails. Declared inline because this definition lives in a header.
inline uint64_t GetPhysicalMemory()
{
#ifdef SPP_WIN
    MEMORYSTATUSEX memInfo;
    memInfo.dwLength = sizeof(MEMORYSTATUSEX);
    if (!GlobalMemoryStatusEx(&memInfo))
        return 0;
    return static_cast<uint64_t>(memInfo.ullTotalPhys);
#elif defined(__linux__)
    struct sysinfo memInfo;
    if (sysinfo(&memInfo) != 0)
        return 0;
    // 64-bit accumulator so the multiplication by mem_unit cannot wrap.
    uint64_t totalPhysMem = memInfo.totalram;

    totalPhysMem *= memInfo.mem_unit;
    return totalPhysMem;
#elif defined(__FreeBSD__)
    u_long physMem = 0;
    size_t physMemLen = sizeof(physMem);
    int mib[] = { CTL_HW, HW_PHYSMEM };

    if (sysctl(mib, sizeof(mib) / sizeof(*mib), &physMem, &physMemLen, NULL, 0) != 0)
        return 0;
    return physMem;
#else
    return 0;
#endif
}
}
#endif // spp_memory_h_guard
This source diff could not be displayed because it is too large. You can view the blob instead.
This source diff could not be displayed because it is too large. You can view the blob instead.
#if !defined(phmap_dump_h_guard_)
#define phmap_dump_h_guard_
// ---------------------------------------------------------------------------
// Copyright (c) 2019, Gregory Popovitch - greg7mdp@gmail.com
//
// providing dump/load/mmap_load
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// ---------------------------------------------------------------------------
#include <iostream>
#include <fstream>
#include <sstream>
#include "phmap.h"
namespace phmap
{
namespace type_traits_internal {

// IsTriviallyCopyable<T>: trait used by the dump/load code to decide whether values may
// be written and read as raw bytes. libstdc++ releases dated before 2015-08-01 fall back
// to the __has_trivial_copy builtin; everything else uses std::is_trivially_copyable.
#if defined(__GLIBCXX__) && __GLIBCXX__ < 20150801
    template<typename T>
    struct IsTriviallyCopyable
        : public std::integral_constant<bool, __has_trivial_copy(T)>
    {};
#else
    template<typename T>
    struct IsTriviallyCopyable
        : public std::is_trivially_copyable<T>
    {};
#endif

// A std::pair counts as trivially copyable when both of its member types do.
template <class T1, class T2>
struct IsTriviallyCopyable<std::pair<T1, T2>>
{
    static constexpr bool value =
        IsTriviallyCopyable<T1>::value && IsTriviallyCopyable<T2>::value;
};

}
namespace priv {
// ------------------------------------------------------------------------
// dump/load for raw_hash_set
// ------------------------------------------------------------------------
// Write this table to `ar` as raw bytes.
// Layout: size_; then, only when size_ is non-zero, capacity_, followed by
// (capacity_ + Group::kWidth + 1) control bytes and capacity_ slots.
// An empty table is therefore stored as size_ alone.
// Requires value_type to be trivially copyable (static_assert).
// Returns false, after printing a message to std::cerr, as soon as one archive write
// reports failure; returns true otherwise.
template <class Policy, class Hash, class Eq, class Alloc>
template<typename OutputArchive>
bool raw_hash_set<Policy, Hash, Eq, Alloc>::dump(OutputArchive& ar) const {
static_assert(type_traits_internal::IsTriviallyCopyable<value_type>::value,
"value_type should be trivially copyable");
if (!ar.dump(size_)) {
std::cerr << "Failed to dump size_" << std::endl;
return false;
}
// Nothing else is stored for an empty table.
if (size_ == 0) {
return true;
}
if (!ar.dump(capacity_)) {
std::cerr << "Failed to dump capacity_" << std::endl;
return false;
}
if (!ar.dump(reinterpret_cast<char*>(ctrl_),
sizeof(ctrl_t) * (capacity_ + Group::kWidth + 1))) {
std::cerr << "Failed to dump ctrl_" << std::endl;
return false;
}
if (!ar.dump(reinterpret_cast<char*>(slots_),
sizeof(slot_type) * capacity_)) {
std::cerr << "Failed to dump slot_" << std::endl;
return false;
}
return true;
}
// Replace the contents of this table with data read from `ar`, in the same order that
// dump() writes it: size_, then (for non-empty tables) capacity_, control bytes, slots.
// Existing content is discarded first by swapping with a fresh table.
// Requires value_type to be trivially copyable (static_assert).
// Returns false, after printing a message to std::cerr, as soon as one archive read
// reports failure; the table may then be partially filled.
template <class Policy, class Hash, class Eq, class Alloc>
template<typename InputArchive>
bool raw_hash_set<Policy, Hash, Eq, Alloc>::load(InputArchive& ar) {
static_assert(type_traits_internal::IsTriviallyCopyable<value_type>::value,
"value_type should be trivially copyable");
raw_hash_set<Policy, Hash, Eq, Alloc>().swap(*this); // clear any existing content
if (!ar.load(&size_)) {
std::cerr << "Failed to load size_" << std::endl;
return false;
}
// An empty table was stored as size_ alone, so there is nothing more to read.
if (size_ == 0) {
return true;
}
if (!ar.load(&capacity_)) {
std::cerr << "Failed to load capacity_" << std::endl;
return false;
}
// allocate memory for ctrl_ and slots_
// (must happen after capacity_ has been read and before the raw reads below)
initialize_slots();
if (!ar.load(reinterpret_cast<char*>(ctrl_),
sizeof(ctrl_t) * (capacity_ + Group::kWidth + 1))) {
std::cerr << "Failed to load ctrl" << std::endl;
return false;
}
if (!ar.load(reinterpret_cast<char*>(slots_),
sizeof(slot_type) * capacity_)) {
std::cerr << "Failed to load slot" << std::endl;
return false;
}
return true;
}
// ------------------------------------------------------------------------
// dump/load for parallel_hash_set
// ------------------------------------------------------------------------
// Write the whole parallel set to `ar`: first subcnt(), then every inner set in index
// order through its own dump(). Each inner set is locked with Lockable::UniqueLock while
// it is written (the const_cast is needed because dump() is a const member).
// Requires value_type to be trivially copyable (static_assert).
// Returns false, after printing a message to std::cerr, on the first failed write.
template <size_t N,
template <class, class, class, class> class RefSet,
class Mtx_,
class Policy, class Hash, class Eq, class Alloc>
template<typename OutputArchive>
bool parallel_hash_set<N, RefSet, Mtx_, Policy, Hash, Eq, Alloc>::dump(OutputArchive& ar) const {
static_assert(type_traits_internal::IsTriviallyCopyable<value_type>::value,
"value_type should be trivially copyable");
if (! ar.dump(subcnt())) {
std::cerr << "Failed to dump meta!" << std::endl;
return false;
}
for (size_t i = 0; i < sets_.size(); ++i) {
auto& inner = sets_[i];
typename Lockable::UniqueLock m(const_cast<Inner&>(inner));
if (!inner.set_.dump(ar)) {
std::cerr << "Failed to dump submap " << i << std::endl;
return false;
}
}
return true;
}
// Read a parallel set previously written by dump(). The stored submap count must equal
// this container's subcnt(); otherwise the archive is rejected. Each inner set is then
// loaded in index order while holding its Lockable::UniqueLock.
// Requires value_type to be trivially copyable (static_assert).
// Returns false, after printing a message to std::cerr, on a count mismatch or on the
// first failed read; inner sets loaded before the failure keep their new content.
template <size_t N,
template <class, class, class, class> class RefSet,
class Mtx_,
class Policy, class Hash, class Eq, class Alloc>
template<typename InputArchive>
bool parallel_hash_set<N, RefSet, Mtx_, Policy, Hash, Eq, Alloc>::load(InputArchive& ar) {
static_assert(type_traits_internal::IsTriviallyCopyable<value_type>::value,
"value_type should be trivially copyable");
size_t submap_count = 0;
if (!ar.load(&submap_count)) {
std::cerr << "Failed to load submap count!" << std::endl;
return false;
}
// An archive written with a different number of submaps cannot be mapped onto sets_.
if (submap_count != subcnt()) {
std::cerr << "submap count(" << submap_count << ") != N(" << N << ")" << std::endl;
return false;
}
for (size_t i = 0; i < submap_count; ++i) {
auto& inner = sets_[i];
typename Lockable::UniqueLock m(const_cast<Inner&>(inner));
if (!inner.set_.load(ar)) {
std::cerr << "Failed to load submap " << i << std::endl;
return false;
}
}
return true;
}
} // namespace priv
// ------------------------------------------------------------------------
// BinaryArchive
// File is closed when archive object is destroyed
// ------------------------------------------------------------------------
// ------------------------------------------------------------------------
// ------------------------------------------------------------------------
// Writes raw bytes to a binary file. The file is opened in the constructor
// and closed when the archive object is destroyed.
//
// Both dump() overloads return false when the underlying stream is in a
// failed state after the write (for example the file could not be opened or
// the disk is full), so callers can detect an incomplete dump.
class BinaryOutputArchive {
public:
    BinaryOutputArchive(const char *file_path) {
        ofs_.open(file_path, std::ios_base::binary);
    }

    // Writes `sz` bytes starting at `p`. Returns true on success.
    bool dump(const char *p, size_t sz) {
        ofs_.write(p, static_cast<std::streamsize>(sz));
        return !ofs_.fail();
    }

    // Writes the object representation of a trivially copyable value.
    // Returns true on success.
    template<typename V>
    typename std::enable_if<type_traits_internal::IsTriviallyCopyable<V>::value, bool>::type
    dump(const V& v) {
        ofs_.write(reinterpret_cast<const char *>(&v), sizeof(V));
        return !ofs_.fail();
    }

private:
    std::ofstream ofs_;
};
// Reads raw bytes from a binary file. The file is opened in the constructor
// and closed when the archive object is destroyed.
//
// Both load() overloads return false when the requested number of bytes
// could not be read (missing file, truncated file, I/O error), so callers
// never continue with partially filled buffers.
class BinaryInputArchive {
public:
    BinaryInputArchive(const char * file_path) {
        ifs_.open(file_path, std::ios_base::binary);
    }

    // Reads exactly `sz` bytes into `p`. Returns true on success.
    bool load(char* p, size_t sz) {
        ifs_.read(p, static_cast<std::streamsize>(sz));
        // A short read sets failbit, so fail() covers EOF and I/O errors.
        return !ifs_.fail();
    }

    // Reads the object representation of a trivially copyable value into *v.
    // Returns true on success.
    template<typename V>
    typename std::enable_if<type_traits_internal::IsTriviallyCopyable<V>::value, bool>::type
    load(V* v) {
        ifs_.read(reinterpret_cast<char *>(v), sizeof(V));
        return !ifs_.fail();
    }

private:
    std::ifstream ifs_;
};
} // namespace phmap
#endif // phmap_dump_h_guard_
#if !defined(phmap_fwd_decl_h_guard_)
#define phmap_fwd_decl_h_guard_
// ---------------------------------------------------------------------------
// Copyright (c) 2019, Gregory Popovitch - greg7mdp@gmail.com
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
// ---------------------------------------------------------------------------
#ifdef _MSC_VER
#pragma warning(push)
#pragma warning(disable : 4514) // unreferenced inline function has been removed
#pragma warning(disable : 4710) // function not inlined
#pragma warning(disable : 4711) // selected for automatic inline expansion
#endif
#include <memory>
#include <utility>
#if defined(PHMAP_USE_ABSL_HASH) && !defined(ABSL_HASH_HASH_H_)
namespace absl { template <class T> struct Hash; };
#endif
namespace phmap {
#if defined(PHMAP_USE_ABSL_HASH)
template <class T> using Hash = ::absl::Hash<T>;
#else
template <class T> struct Hash;
#endif
template <class T> struct EqualTo;
template <class T> struct Less;
template <class T> using Allocator = typename std::allocator<T>;
template<class T1, class T2> using Pair = typename std::pair<T1, T2>;
class NullMutex;
namespace priv {
// Default hash/equality policy for a key of type T: hashing is done with
// phmap::Hash<T> and equality with phmap::EqualTo<T>.
// The second parameter E defaults to void and is unused by this primary
// template (presumably a hook for specializations -- TODO confirm).
template <class T, class E = void>
struct HashEq
{
// Hash functor type selected for T.
using Hash = phmap::Hash<T>;
// Equality functor type selected for T.
using Eq = phmap::EqualTo<T>;
};
template <class T>
using hash_default_hash = typename priv::HashEq<T>::Hash;
template <class T>
using hash_default_eq = typename priv::HashEq<T>::Eq;
// type alias for std::allocator so we can forward declare without including other headers
template <class T>
using Allocator = typename phmap::Allocator<T>;
// type alias for std::pair so we can forward declare without including other headers
template<class T1, class T2>
using Pair = typename phmap::Pair<T1, T2>;
} // namespace priv
// ------------- forward declarations for hash containers ----------------------------------
template <class T,
class Hash = phmap::priv::hash_default_hash<T>,
class Eq = phmap::priv::hash_default_eq<T>,
class Alloc = phmap::priv::Allocator<T>> // alias for std::allocator
class flat_hash_set;
template <class K, class V,
class Hash = phmap::priv::hash_default_hash<K>,
class Eq = phmap::priv::hash_default_eq<K>,
class Alloc = phmap::priv::Allocator<
phmap::priv::Pair<const K, V>>> // alias for std::allocator
class flat_hash_map;
template <class T,
class Hash = phmap::priv::hash_default_hash<T>,
class Eq = phmap::priv::hash_default_eq<T>,
class Alloc = phmap::priv::Allocator<T>> // alias for std::allocator
class node_hash_set;
template <class Key, class Value,
class Hash = phmap::priv::hash_default_hash<Key>,
class Eq = phmap::priv::hash_default_eq<Key>,
class Alloc = phmap::priv::Allocator<
phmap::priv::Pair<const Key, Value>>> // alias for std::allocator
class node_hash_map;
template <class T,
class Hash = phmap::priv::hash_default_hash<T>,
class Eq = phmap::priv::hash_default_eq<T>,
class Alloc = phmap::priv::Allocator<T>, // alias for std::allocator
size_t N = 4, // 2**N submaps
class Mutex = phmap::NullMutex> // use std::mutex to enable internal locks
class parallel_flat_hash_set;
template <class K, class V,
class Hash = phmap::priv::hash_default_hash<K>,
class Eq = phmap::priv::hash_default_eq<K>,
class Alloc = phmap::priv::Allocator<
phmap::priv::Pair<const K, V>>, // alias for std::allocator
size_t N = 4, // 2**N submaps
class Mutex = phmap::NullMutex> // use std::mutex to enable internal locks
class parallel_flat_hash_map;
template <class T,
class Hash = phmap::priv::hash_default_hash<T>,
class Eq = phmap::priv::hash_default_eq<T>,
class Alloc = phmap::priv::Allocator<T>, // alias for std::allocator
size_t N = 4, // 2**N submaps
class Mutex = phmap::NullMutex> // use std::mutex to enable internal locks
class parallel_node_hash_set;
template <class Key, class Value,
class Hash = phmap::priv::hash_default_hash<Key>,
class Eq = phmap::priv::hash_default_eq<Key>,
class Alloc = phmap::priv::Allocator<
phmap::priv::Pair<const Key, Value>>, // alias for std::allocator
size_t N = 4, // 2**N submaps
class Mutex = phmap::NullMutex> // use std::mutex to enable internal locks
class parallel_node_hash_map;
// ------------- forward declarations for btree containers ----------------------------------
template <typename Key, typename Compare = phmap::Less<Key>,
typename Alloc = phmap::Allocator<Key>>
class btree_set;
template <typename Key, typename Compare = phmap::Less<Key>,
typename Alloc = phmap::Allocator<Key>>
class btree_multiset;
template <typename Key, typename Value, typename Compare = phmap::Less<Key>,
typename Alloc = phmap::Allocator<phmap::priv::Pair<const Key, Value>>>
class btree_map;
template <typename Key, typename Value, typename Compare = phmap::Less<Key>,
typename Alloc = phmap::Allocator<phmap::priv::Pair<const Key, Value>>>
class btree_multimap;
} // namespace phmap
#ifdef _MSC_VER
#pragma warning(pop)
#endif
#endif // phmap_fwd_decl_h_guard_
#pragma once
#include <torch/extension.h>
#include <parallel_hashmap/phmap.h>
#include <cstring>
#include <vector>
#include <iostream>
#include <map>
#include "head.h"
// #include <boost/thread/mutex.hpp>
using namespace std;
#define work_thread 10
// For every element of `in`, returns the position at which that value occurs
// in `map_key`.
//
// \param in       1-D tensor of NodeIDType values to look up.
// \param map_key  1-D tensor of NodeIDType keys; the value stored for a key
//                 is its position in this tensor. When a key occurs more than
//                 once, which position is kept depends on the hash map's
//                 range constructor (presumably the first -- TODO confirm).
// \return tensor with in.size(0) entries holding the looked-up positions.
//
// Raises an error (TORCH_CHECK) when some element of `in` is not in
// `map_key`; the previous code dereferenced the end iterator in that case.
th::Tensor sparse_get_index(th::Tensor in,th::Tensor map_key){
    auto key_ptr = map_key.data_ptr<NodeIDType>();
    auto in_ptr = in.data_ptr<NodeIDType>();
    // Tensor sizes are 64-bit; keep them that way to avoid truncation.
    const int64_t num_keys = map_key.size(0);
    const int64_t num_queries = in.size(0);

    // Build (key, position) pairs in parallel, then construct the map once.
    vector<pair<NodeIDType,NodeIDType>> mp(num_keys);
#pragma omp parallel for
    for(int64_t i=0;i<num_keys;i++){
        mp[i] = make_pair(key_ptr[i],static_cast<NodeIDType>(i));
    }
    phmap::parallel_flat_hash_map<NodeIDType,NodeIDType> dict(mp.begin(),mp.end());

    vector<NodeIDType> out(num_queries);
    int64_t num_missing = 0;
    // Lookups only read the map. Misses are counted with a reduction because
    // an exception must not be thrown out of the parallel region.
#pragma omp parallel for reduction(+:num_missing)
    for(int64_t i=0;i<num_queries;i++){
        auto it = dict.find(in_ptr[i]);
        if(it == dict.end()){
            num_missing += 1;
        }
        else{
            out[i] = it->second;
        }
    }
    TORCH_CHECK(num_missing == 0,
                "sparse_get_index: ", num_missing,
                " input id(s) are not present in map_key");
    return th::tensor(out);
}
// Computes statistics of the time gaps between consecutive events of the
// same node, separately for source nodes (`row`) and destination nodes
// (`col`).
//
// \param row        source node id of each event (NodeIDType).
// \param col        destination node id of each event (NodeIDType).
// \param timestamp  time of each event (TimeStampType); a gap is the
//                   difference to the previous event of the same node in
//                   iteration order (assumes events are sorted by time --
//                   TODO confirm with callers).
// \param num_nodes  unused; kept so the Python-visible signature is unchanged.
// \return {srcavg, srcvar, dstavg, dstvar}.
//
// The scan is sequential on purpose: each gap depends on the previous event
// of the same node, and the previous OpenMP version updated the hash maps
// from several threads without synchronisation. The accumulations are
// sequential as well, which removes the data races on the running sums.
vector<double> get_norm_temporal(th::Tensor row,th::Tensor col,th::Tensor timestamp,int num_nodes){
    (void)num_nodes;
    vector<double> ret(4);
    auto rowptr = row.data_ptr<NodeIDType>();
    auto colptr = col.data_ptr<NodeIDType>();
    auto time_ptr = timestamp.data_ptr<TimeStampType>();
    const int64_t num_events = row.size(0);

    // Last timestamp seen per node, as a source and as a destination.
    HashM<NodeIDType,TimeStampType> last_src_time;
    HashM<NodeIDType,TimeStampType> last_dst_time;
    vector<TimeStampType> out_timestamp;
    vector<TimeStampType> in_timestamp;

    for(int64_t i = 0;i<num_events;i++){
        const TimeStampType t = time_ptr[i];

        auto src_it = last_src_time.find(rowptr[i]);
        if(src_it!=last_src_time.end()){
            out_timestamp.push_back(t-src_it->second);
            src_it->second = t;
        }
        else last_src_time.insert(make_pair(rowptr[i],t));

        auto dst_it = last_dst_time.find(colptr[i]);
        if(dst_it!=last_dst_time.end()){
            in_timestamp.push_back(t-dst_it->second);
            dst_it->second = t;
        }
        else last_dst_time.insert(make_pair(colptr[i],t));
    }

    double srcavg = 0;
    double dstavg = 0;
    double srcvar = 0;
    double dstvar = 0;

    for(const auto &v: in_timestamp) dstavg += v;
    for(const auto &v: out_timestamp) srcavg += v;
    // Leave the mean at 0 instead of producing 0/0 when there are no gaps.
    if(!in_timestamp.empty()) dstavg /= static_cast<double>(in_timestamp.size());
    if(!out_timestamp.empty()) srcavg /= static_cast<double>(out_timestamp.size());

    // NOTE(review): the squared deviations are normalised by the mean, not by
    // the sample count, exactly as before -- confirm this is the intended
    // statistic. A mean of zero with non-empty input still divides by zero.
    for(const auto &v: in_timestamp){
        dstvar += (v-dstavg)*(v-dstavg)/dstavg;
    }
    for(const auto &v: out_timestamp){
        // Previously accumulated into srcavg by mistake, which left srcvar at
        // 0 and corrupted the returned mean.
        srcvar += (v-srcavg)*(v-srcavg)/srcavg;
    }

    ret[0]=srcavg;
    ret[1]=srcvar;
    ret[2]=dstavg;
    ret[3]=dstvar;
    return ret;
}
// Python bindings for the `torch_utils` extension module: exposes
// sparse_get_index and get_norm_temporal under their C++ names.
PYBIND11_MODULE(torch_utils, m)
{
    // Both functions are registered with the same return-value policy.
    const auto policy = py::return_value_policy::reference;

    m.def("sparse_get_index", &sparse_get_index, policy);
    m.def("get_norm_temporal", &get_norm_temporal, policy);
}
\ No newline at end of file
/*!
* Copyright (c) 2016 by Contributors
* \file array_view.h
* \brief Read only data structure to reference array
*/
#ifndef DMLC_ARRAY_VIEW_H_
#define DMLC_ARRAY_VIEW_H_
#include <vector>
#include <array>
namespace dmlc {
/*!
* \brief Read only data structure to reference continuous memory region of array.
* Provide unified view for vector, array and C style array.
* This data structure do not guarantee aliveness of referenced array.
*
* Make sure do not use array_view to record data in async function closures.
* Also do not use array_view to create reference to temporary data structure.
*
* \tparam ValueType The value
*
* \code
* std::vector<int> myvec{1,2,3};
* dmlc::array_view<int> view(myvec);
* // indexed visit to the view.
* LOG(INFO) << view[0];
*
* for (int v : view) {
* // visit each element in the view
* }
* \endcode
*/
template<typename ValueType>
class array_view {
 public:
  /*! \brief default constructor, creates an empty view */
  array_view() = default;
  /*!
   * \brief default copy constructor
   * \param other another array view.
   */
  array_view(const array_view<ValueType> &other) = default;  // NOLINT(*)
#ifndef _MSC_VER
  /*!
   * \brief default move constructor
   * \param other another array view.
   */
  array_view(array_view<ValueType>&& other) = default;  // NOLINT(*)
#else
  /*!
   * \brief move constructor for MSVC
   * \param other another array view; left as an empty view afterwards.
   */
  array_view(array_view<ValueType>&& other) {  // NOLINT(*)
    begin_ = other.begin_;
    size_ = other.size_;
    // Reset both members so the moved-from view is a consistent empty view
    // rather than a null pointer paired with a non-zero size.
    other.begin_ = nullptr;
    other.size_ = 0;
  }
#endif
  /*!
   * \brief default copy assignment operator
   * \param other another array view.
   * \return self.
   */
  array_view<ValueType>& operator=(const array_view<ValueType>& other) = default;  // NOLINT(*)
  /*!
   * \brief construct array view from std::vector
   * \param other vector container; an empty vector yields an empty view.
   */
  array_view(const std::vector<ValueType>& other) {  // NOLINT(*)
    if (other.size() != 0) {
      begin_ = &other[0];
      size_ = other.size();
    }
  }
  /*!
   * \brief construct array view from std::array
   * \param other array container; a zero-length array yields an empty view.
   */
  template<std::size_t size>
  array_view(const std::array<ValueType, size>& other) {  // NOLINT(*)
    if (size != 0) {
      begin_ = &other[0];
      size_ = size;
    }
  }
  /*!
   * \brief construct array view from continuous segment
   * \param begin beginning pointer
   * \param end end pointer; when begin >= end the view is empty.
   */
  array_view(const ValueType* begin, const ValueType* end) {
    if (begin < end) {
      begin_ = begin;
      size_ = end - begin;
    }
  }
  /*! \return size of the array */
  inline size_t size() const {
    return size_;
  }
  /*! \return begin of the array */
  inline const ValueType* begin() const {
    return begin_;
  }
  /*! \return end point of the array */
  inline const ValueType* end() const {
    return begin_ + size_;
  }
  /*!
   * \brief get i-th element from the view (no bounds checking)
   * \param i The index.
   * \return const reference to i-th element.
   */
  inline const ValueType& operator[](size_t i) const {
    return begin_[i];
  }

 private:
  /*! \brief the begin of the view */
  const ValueType* begin_{nullptr};
  /*! \brief The size of the view */
  size_t size_{0};
};
} // namespace dmlc
#endif // DMLC_ARRAY_VIEW_H_
/*!
* Copyright (c) 2015 by Contributors
* \file base.h
* \brief defines configuration macros
*/
#ifndef DMLC_BASE_H_
#define DMLC_BASE_H_
/*! \brief whether use glog for logging */
#ifndef DMLC_USE_GLOG
#define DMLC_USE_GLOG 0
#endif
/*!
* \brief whether throw dmlc::Error instead of
* directly calling abort when FATAL error occurred
* NOTE: this may still not be perfect.
* do not use FATAL and CHECK in destructors
*/
#ifndef DMLC_LOG_FATAL_THROW
#define DMLC_LOG_FATAL_THROW 1
#endif
/*!
* \brief whether always log a message before throw
* This can help identify the error that cannot be caught.
*/
#ifndef DMLC_LOG_BEFORE_THROW
#define DMLC_LOG_BEFORE_THROW 0
#endif
/*!
* \brief Whether to use customized logger,
* whose output can be decided by other libraries.
*/
#ifndef DMLC_LOG_CUSTOMIZE
#define DMLC_LOG_CUSTOMIZE 0
#endif
/*!
* \brief Whether to enable debug logging feature.
*/
#ifndef DMLC_LOG_DEBUG
#ifdef NDEBUG
#define DMLC_LOG_DEBUG 0
#else
#define DMLC_LOG_DEBUG 1
#endif
#endif
/*!
* \brief Whether to disable date message on the log.
*/
#ifndef DMLC_LOG_NODATE
#define DMLC_LOG_NODATE 0
#endif
/*! \brief whether compile with hdfs support */
#ifndef DMLC_USE_HDFS
#define DMLC_USE_HDFS 0
#endif
/*! \brief whether compile with s3 support */
#ifndef DMLC_USE_S3
#define DMLC_USE_S3 0
#endif
/*! \brief whether or not use parameter server */
#ifndef DMLC_USE_PS
#define DMLC_USE_PS 0
#endif
/*! \brief whether or not use c++11 support */
#ifndef DMLC_USE_CXX11
#if defined(__GXX_EXPERIMENTAL_CXX0X__) || defined(_MSC_VER)
#define DMLC_USE_CXX11 1
#else
#define DMLC_USE_CXX11 (__cplusplus >= 201103L)
#endif
#endif
/*! \brief strict CXX11 support */
#ifndef DMLC_STRICT_CXX11
#if defined(_MSC_VER)
#define DMLC_STRICT_CXX11 1
#else
#define DMLC_STRICT_CXX11 (__cplusplus >= 201103L)
#endif
#endif
/*! \brief Whether cxx11 thread local is supported */
#ifndef DMLC_CXX11_THREAD_LOCAL
#if defined(_MSC_VER)
#define DMLC_CXX11_THREAD_LOCAL (_MSC_VER >= 1900)
#elif defined(__clang__)
#define DMLC_CXX11_THREAD_LOCAL (__has_feature(cxx_thread_local))
#else
#define DMLC_CXX11_THREAD_LOCAL (__cplusplus >= 201103L)
#endif
#endif
/*! \brief Whether to use modern thread local construct */
#ifndef DMLC_MODERN_THREAD_LOCAL
#define DMLC_MODERN_THREAD_LOCAL 1
#endif
/*! \brief whether RTTI is enabled */
#ifndef DMLC_ENABLE_RTTI
#define DMLC_ENABLE_RTTI 1
#endif
/*! \brief whether use fopen64 */
#ifndef DMLC_USE_FOPEN64
#define DMLC_USE_FOPEN64 1
#endif
/// check if g++ is before 5.0
#if DMLC_USE_CXX11 && defined(__GNUC__) && !defined(__clang_version__)
#if __GNUC__ < 5
#pragma message("Will need g++-5.0 or higher to compile all" \
"the features in dmlc-core, " \
"compile without c++11, some features may be disabled")
#undef DMLC_USE_CXX11
#define DMLC_USE_CXX11 0
#endif
#endif
/*!
* \brief Use little endian for binary serialization
* if this is set to 0, use big endian.
*/
#ifndef DMLC_IO_USE_LITTLE_ENDIAN
#define DMLC_IO_USE_LITTLE_ENDIAN 1
#endif
/*!
* \brief Enable std::thread related modules,
* Used to disable some module in mingw compile.
*/
#ifndef DMLC_ENABLE_STD_THREAD
#define DMLC_ENABLE_STD_THREAD DMLC_USE_CXX11
#endif
/*! \brief whether enable regex support, actually need g++-4.9 or higher*/
#ifndef DMLC_USE_REGEX
#define DMLC_USE_REGEX DMLC_STRICT_CXX11
#endif
/*! \brief helper macro to suppress unused warning */
#if defined(__GNUC__)
#define DMLC_ATTRIBUTE_UNUSED __attribute__((unused))
#else
#define DMLC_ATTRIBUTE_UNUSED
#endif
/*! \brief helper macro to suppress Undefined Behavior Sanitizer for a specific function */
#if defined(__clang__)
#define DMLC_SUPPRESS_UBSAN __attribute__((no_sanitize("undefined")))
#elif defined(__GNUC__) && (__GNUC__ * 100 + __GNUC_MINOR__ >= 409)
#define DMLC_SUPPRESS_UBSAN __attribute__((no_sanitize_undefined))
#else
#define DMLC_SUPPRESS_UBSAN
#endif
/*! \brief helper macro to generate string concat */
#define DMLC_STR_CONCAT_(__x, __y) __x##__y
#define DMLC_STR_CONCAT(__x, __y) DMLC_STR_CONCAT_(__x, __y)
/*!
* \brief Disable copy constructor and assignment operator.
*
* If C++11 is supported, both copy and move constructors and
* assignment operators are deleted explicitly. Otherwise, they are
* only declared but not implemented. Place this macro in private
* section if C++11 is not available.
*/
#ifndef DISALLOW_COPY_AND_ASSIGN
# if DMLC_USE_CXX11
# define DISALLOW_COPY_AND_ASSIGN(T) \
T(T const&) = delete; \
T(T&&) = delete; \
T& operator=(T const&) = delete; \
T& operator=(T&&) = delete
# else
# define DISALLOW_COPY_AND_ASSIGN(T) \
T(T const&); \
T& operator=(T const&)
# endif
#endif
#ifdef __APPLE__
# define off64_t off_t
#endif
#ifdef _MSC_VER
#if _MSC_VER < 1900
// NOTE: sprintf_s is not equivalent to snprintf,
// they are equivalent when success, which is sufficient for our case
#define snprintf sprintf_s
#define vsnprintf vsprintf_s
#endif
#else
#ifdef _FILE_OFFSET_BITS
#if _FILE_OFFSET_BITS == 32
#pragma message("Warning: FILE OFFSET BITS defined to be 32 bit")
#endif
#endif
extern "C" {
#include <sys/types.h>
}
#endif
#ifdef _MSC_VER
//! \cond Doxygen_Suppress
typedef signed char int8_t;
typedef __int16 int16_t;
typedef __int32 int32_t;
typedef __int64 int64_t;
typedef unsigned char uint8_t;
typedef unsigned __int16 uint16_t;
typedef unsigned __int32 uint32_t;
typedef unsigned __int64 uint64_t;
//! \endcond
#else
#include <inttypes.h>
#endif
#include <string>
#include <vector>
#if defined(_MSC_VER) && _MSC_VER < 1900
#define noexcept_true throw ()
#define noexcept_false
#define noexcept(a) noexcept_##a
#endif
#if defined(_MSC_VER)
#define DMLC_NO_INLINE __declspec(noinline)
#else
#define DMLC_NO_INLINE __attribute__((noinline))
#endif
#if defined(__GNUC__) || defined(__clang__)
#define DMLC_ALWAYS_INLINE inline __attribute__((__always_inline__))
#elif defined(_MSC_VER)
#define DMLC_ALWAYS_INLINE __forceinline
#else
#define DMLC_ALWAYS_INLINE inline
#endif
#if DMLC_USE_CXX11
#define DMLC_THROW_EXCEPTION noexcept(false)
#define DMLC_NO_EXCEPTION noexcept(true)
#else
#define DMLC_THROW_EXCEPTION
#define DMLC_NO_EXCEPTION
#endif
/*! \brief namespace for dmlc */
namespace dmlc {
/*!
 * \brief safely obtain a pointer to the first element of a mutable vector
 * \param vec input vector
 * \return address of the first element, or NULL when the vector is empty
 */
template<typename T>
inline T *BeginPtr(std::vector<T> &vec) {  // NOLINT(*)
  // Indexing an empty vector is invalid, so the empty case is handled first.
  if (vec.empty()) return NULL;
  return &vec[0];
}
/*!
 * \brief obtain a pointer to the first element of a const vector
 * \param vec input vector
 * \return address of the first element, or NULL when the vector is empty
 */
template<typename T>
inline const T *BeginPtr(const std::vector<T> &vec) {
  // Indexing an empty vector is invalid, so the empty case is handled first.
  if (vec.empty()) return NULL;
  return &vec[0];
}
/*!
 * \brief obtain a pointer to the first character of a mutable string
 * \param str input string
 * \return address of the first character, or NULL when the string is empty
 */
inline char* BeginPtr(std::string &str) {  // NOLINT(*)
  if (str.empty()) {
    return NULL;
  } else {
    return &str[0];
  }
}
/*!
 * \brief obtain a pointer to the first character of a const string
 * \param str input string
 * \return address of the first character, or NULL when the string is empty
 */
inline const char* BeginPtr(const std::string &str) {
  if (str.empty()) {
    return NULL;
  } else {
    return &str[0];
  }
}
} // namespace dmlc
#if defined(_MSC_VER) && _MSC_VER < 1900
#define constexpr const
#define alignof __alignof
#endif
/* If fopen64 is not defined by current machine,
replace fopen64 with std::fopen. Also determine ability to print stack trace
for fatal error and define DMLC_LOG_STACK_TRACE if stack trace can be
produced. Always keep this include directive at the bottom of dmlc/base.h */
#ifdef DMLC_CORE_USE_CMAKE
#include <dmlc/build_config.h>
#else
#include <dmlc/build_config_default.h>
#endif
#endif // DMLC_BASE_H_
/*!
* Copyright (c) 2018 by Contributors
* \file build_config_default.h
* \brief Default detection logic for fopen64 and other symbols.
* May be overridden by CMake
* \author KOLANICH
*/
#ifndef DMLC_BUILD_CONFIG_DEFAULT_H_
#define DMLC_BUILD_CONFIG_DEFAULT_H_
/* default logic for fopen64 */
#if DMLC_USE_FOPEN64 && \
(!defined(__GNUC__) || (defined __ANDROID__) || (defined __FreeBSD__) \
|| (defined __APPLE__) || ((defined __MINGW32__) && !(defined __MINGW64__)) \
|| (defined __CYGWIN__) )
#define fopen64 std::fopen
#endif
/* default logic for stack trace */
#if (defined(__GNUC__) && !defined(__MINGW32__)\
&& !defined(__sun) && !defined(__SVR4)\
&& !(defined __MINGW64__) && !(defined __ANDROID__))\
&& !defined(__CYGWIN__) && !defined(__EMSCRIPTEN__)\
&& !defined(__RISCV__) && !defined(__hexagon__)
#ifndef DMLC_LOG_STACK_TRACE
#define DMLC_LOG_STACK_TRACE 1
#endif
#ifndef DMLC_LOG_STACK_TRACE_SIZE
#define DMLC_LOG_STACK_TRACE_SIZE 10
#endif
#define DMLC_EXECINFO_H <execinfo.h>
#endif
/* default logic for detecting existence of nanosleep() */
#if !(defined _WIN32) || (defined __CYGWIN__)
#define DMLC_NANOSLEEP_PRESENT
#endif
#endif // DMLC_BUILD_CONFIG_DEFAULT_H_
/*!
* Copyright (c) 2015 by Contributors
* \file common.h
* \brief defines some common utility function.
*/
#ifndef DMLC_COMMON_H_
#define DMLC_COMMON_H_
#include <vector>
#include <string>
#include <sstream>
#include <mutex>
#include <utility>
#include "./logging.h"
namespace dmlc {
/*!
 * \brief Split a string into the pieces separated by a delimiter
 * \param s String to be split.
 * \param delim The delimiter character.
 * \return the pieces in order, as produced by repeated std::getline calls
 *         (an empty input yields an empty vector; a trailing delimiter does
 *         not produce a trailing empty piece).
 */
inline std::vector<std::string> Split(const std::string& s, char delim) {
  std::vector<std::string> pieces;
  std::istringstream stream(s);
  for (std::string token; std::getline(stream, token, delim); ) {
    pieces.push_back(token);
  }
  return pieces;
}
/*!
 * \brief hash an object with std::hash and fold the result into a running key
 * \param key the hash accumulated so far
 * \param value the object to mix in
 * \return the combined hash
 */
template<typename T>
inline size_t HashCombine(size_t key, const T& value) {
  const size_t hashed = std::hash<T>()(value);
  const size_t mixed = hashed + 0x9e3779b9 + (key << 6) + (key >> 2);
  return key ^ mixed;
}
/*!
 * \brief specialization for size_t: the value is mixed in directly,
 *  without going through std::hash
 */
template<>
inline size_t HashCombine<size_t>(size_t key, const size_t& value) {
  const size_t mixed = value + 0x9e3779b9 + (key << 6) + (key >> 2);
  return key ^ mixed;
}
/*!
* \brief OMP Exception class catches, saves and rethrows exception from OMP blocks
*/
class OMPException {
private:
// exception_ptr member to store the exception
std::exception_ptr omp_exception_;
// mutex to be acquired during catch to set the exception_ptr
std::mutex mutex_;
public:
/*!
* \brief Parallel OMP blocks should be placed within Run to save exception
*/
template <typename Function, typename... Parameters>
void Run(Function f, Parameters... params) {
try {
f(params...);
} catch (dmlc::Error &ex) {
std::lock_guard<std::mutex> lock(mutex_);
if (!omp_exception_) {
omp_exception_ = std::current_exception();
}
} catch (std::exception &ex) {
std::lock_guard<std::mutex> lock(mutex_);
if (!omp_exception_) {
omp_exception_ = std::current_exception();
}
}
}
/*!
* \brief should be called from the main thread to rethrow the exception
*/
void Rethrow() {
if (this->omp_exception_) std::rethrow_exception(this->omp_exception_);
}
};
} // namespace dmlc
#endif // DMLC_COMMON_H_
/*!
* Copyright (c) 2015 by Contributors
* \file concurrency.h
* \brief thread-safe data structures.
* \author Yutian Li
*/
#ifndef DMLC_CONCURRENCY_H_
#define DMLC_CONCURRENCY_H_
// this code depends on c++11
#if DMLC_USE_CXX11
#include <atomic>
#include <deque>
#include <queue>
#include <mutex>
#include <vector>
#include <utility>
#include <condition_variable>
#include "dmlc/base.h"
namespace dmlc {
/*!
* \brief Simple userspace spinlock implementation built on std::atomic_flag.
*
* lock() busy-waits until the flag can be acquired; unlock() clears it.
* The class is neither copyable nor movable.
*/
class Spinlock {
public:
#ifdef _MSC_VER
// MSVC path: the flag is cleared explicitly in the constructor body.
Spinlock() {
lock_.clear();
}
#else
// clang warns about initializing the flag from ATOMIC_FLAG_INIT in a
// member initializer; the warning is silenced around the constructor only.
#if defined(__clang__)
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wbraced-scalar-init"
#endif // defined(__clang__)
Spinlock() : lock_(ATOMIC_FLAG_INIT) {
}
#if defined(__clang__)
#pragma clang diagnostic pop
#endif // defined(__clang__)
#endif
~Spinlock() = default;
/*!
* \brief Acquire lock, spinning until it becomes available.
*/
inline void lock() noexcept(true);
/*!
* \brief Release lock.
*/
inline void unlock() noexcept(true);
private:
/*! \brief the flag representing the lock state; set while the lock is held */
std::atomic_flag lock_;
/*!
* \brief Disable copy and move.
*/
DISALLOW_COPY_AND_ASSIGN(Spinlock);
};
/*! \brief type of concurrent queue, selects the ordering used by ConcurrentBlockingQueue */
enum class ConcurrentQueueType {
/*! \brief FIFO queue: elements are popped from the front of a deque */
kFIFO,
/*! \brief queue with priority: elements are kept in a heap ordered by priority */
kPriority
};
/*!
* \brief Concurrent blocking queue.
*
* Producers call Push/PushFront; consumers call Pop, which blocks until an
* element is available or SignalForKill has been called.
*
* \tparam T the element type
* \tparam type the queue flavour (FIFO or priority)
*/
template <typename T,
ConcurrentQueueType type = ConcurrentQueueType::kFIFO>
class ConcurrentBlockingQueue {
public:
ConcurrentBlockingQueue();
~ConcurrentBlockingQueue() = default;
/*!
* \brief Push element to the end of the queue.
* \param e Element to push into.
* \param priority the priority of the element, only used for priority queue.
* The higher the priority is, the better.
* \tparam E the element type
*
* It will copy or move the element into the queue, depending on the type of
* the parameter.
*/
template <typename E>
void Push(E&& e, int priority = 0);
/*!
* \brief Push element to the front of the queue. Only works for FIFO queue.
* For priority queue it is the same as Push.
* \param e Element to push into.
* \param priority the priority of the element, only used for priority queue.
* The higher the priority is, the better.
* \tparam E the element type
*
* It will copy or move the element into the queue, depending on the type of
* the parameter.
*/
template <typename E>
void PushFront(E&& e, int priority = 0);
/*!
* \brief Pop element from the queue, blocking while the queue is empty.
* \param rv Element popped.
* \return On false, the queue is exiting.
*
* The element will be copied or moved into the object passed in.
*/
bool Pop(T* rv);
/*!
* \brief Signal the queue for destruction.
*
* After calling this method, all blocking pop call to the queue will return
* false.
*/
void SignalForKill();
/*!
* \brief Get the size of the queue.
* \return The size of the queue.
*/
size_t Size();
private:
/*! \brief element of the priority queue: payload plus its priority */
struct Entry {
T data;
int priority;
/*! \brief order entries by priority only */
inline bool operator<(const Entry &b) const {
return priority < b.priority;
}
};
/*! \brief protects the containers and nwait_consumer_ */
std::mutex mutex_;
/*! \brief consumers wait on this until an element arrives or the queue exits */
std::condition_variable cv_;
/*! \brief set by SignalForKill; once true, Pop returns false */
std::atomic<bool> exit_now_;
/*! \brief number of consumers currently blocked inside Pop */
int nwait_consumer_;
// a priority queue
std::vector<Entry> priority_queue_;
// a FIFO queue
std::deque<T> fifo_queue_;
/*!
* \brief Disable copy and move.
*/
DISALLOW_COPY_AND_ASSIGN(ConcurrentBlockingQueue);
};
// Busy-waits until the flag is acquired: test_and_set returns the previous
// value, so a false result means this caller now holds the lock.
inline void Spinlock::lock() noexcept(true) {
  for (;;) {
    const bool already_held = lock_.test_and_set(std::memory_order_acquire);
    if (!already_held) {
      return;
    }
  }
}
// Releases the lock by clearing the flag with release ordering, pairing with
// the acquire ordering used by lock().
inline void Spinlock::unlock() noexcept(true) {
lock_.clear(std::memory_order_release);
}
// Creates an empty queue that is not exiting and has no waiting consumers.
template <typename T, ConcurrentQueueType type>
ConcurrentBlockingQueue<T, type>::ConcurrentBlockingQueue()
    : exit_now_(false),
      nwait_consumer_(0) {
}
// Appends an element (FIFO) or inserts it into the heap (priority queue) and
// wakes one waiting consumer, if any.
//
// \param e        element to enqueue; copied when an lvalue is passed and
//                 moved when an rvalue is passed.
// \param priority only used by the priority queue; higher is served first.
template <typename T, ConcurrentQueueType type>
template <typename E>
void ConcurrentBlockingQueue<T, type>::Push(E&& e, int priority) {
  static_assert(std::is_same<typename std::remove_cv<
                                 typename std::remove_reference<E>::type>::type,
                             T>::value,
                "Types must match.");
  bool notify;
  {
    std::lock_guard<std::mutex> lock{mutex_};
    if (type == ConcurrentQueueType::kFIFO) {
      fifo_queue_.emplace_back(std::forward<E>(e));
    } else {
      Entry entry;
      // Forward rather than move: `e` is a forwarding reference, and an
      // unconditional std::move would steal from a caller's lvalue.
      entry.data = std::forward<E>(e);
      entry.priority = priority;
      priority_queue_.push_back(std::move(entry));
      std::push_heap(priority_queue_.begin(), priority_queue_.end());
    }
    // Only notify when a consumer is actually blocked in Pop.
    notify = nwait_consumer_ != 0;
  }
  // Notify outside the lock so the woken consumer can take it immediately.
  if (notify) cv_.notify_one();
}
// Inserts an element at the front of a FIFO queue; for a priority queue this
// behaves exactly like Push. Wakes one waiting consumer, if any.
//
// \param e        element to enqueue; copied when an lvalue is passed and
//                 moved when an rvalue is passed.
// \param priority only used by the priority queue; higher is served first.
template <typename T, ConcurrentQueueType type>
template <typename E>
void ConcurrentBlockingQueue<T, type>::PushFront(E&& e, int priority) {
  static_assert(std::is_same<typename std::remove_cv<
                                 typename std::remove_reference<E>::type>::type,
                             T>::value,
                "Types must match.");
  bool notify;
  {
    std::lock_guard<std::mutex> lock{mutex_};
    if (type == ConcurrentQueueType::kFIFO) {
      fifo_queue_.emplace_front(std::forward<E>(e));
    } else {
      Entry entry;
      // Forward rather than move: `e` is a forwarding reference, and an
      // unconditional std::move would steal from a caller's lvalue.
      entry.data = std::forward<E>(e);
      entry.priority = priority;
      priority_queue_.push_back(std::move(entry));
      std::push_heap(priority_queue_.begin(), priority_queue_.end());
    }
    // Only notify when a consumer is actually blocked in Pop.
    notify = nwait_consumer_ != 0;
  }
  // Notify outside the lock so the woken consumer can take it immediately.
  if (notify) cv_.notify_one();
}
// Blocks until an element is available or the queue has been signalled to
// exit. On success the element is moved into *rv and true is returned; once
// exit has been signalled, false is returned and *rv is left untouched, even
// if elements remain queued.
template <typename T, ConcurrentQueueType type>
bool ConcurrentBlockingQueue<T, type>::Pop(T* rv) {
  const bool is_fifo = (type == ConcurrentQueueType::kFIFO);
  std::unique_lock<std::mutex> guard{mutex_};

  // Register as a waiting consumer so producers know a notify is needed.
  ++nwait_consumer_;
  cv_.wait(guard, [this, is_fifo] {
    const bool has_item =
        is_fifo ? !fifo_queue_.empty() : !priority_queue_.empty();
    return has_item || exit_now_.load();
  });
  --nwait_consumer_;

  if (exit_now_.load()) {
    return false;
  }

  if (is_fifo) {
    *rv = std::move(fifo_queue_.front());
    fifo_queue_.pop_front();
  } else {
    // pop_heap moves the highest-priority entry to the back of the vector.
    std::pop_heap(priority_queue_.begin(), priority_queue_.end());
    *rv = std::move(priority_queue_.back().data);
    priority_queue_.pop_back();
  }
  return true;
}
// Marks the queue as exiting and wakes every consumer blocked in Pop, which
// will then return false.
template <typename T, ConcurrentQueueType type>
void ConcurrentBlockingQueue<T, type>::SignalForKill() {
  std::unique_lock<std::mutex> guard{mutex_};
  exit_now_.store(true);
  // Release the mutex before notifying so woken consumers can acquire it.
  guard.unlock();
  cv_.notify_all();
}
// Returns the number of queued elements, read under the queue mutex from
// whichever container backs this queue type.
template <typename T, ConcurrentQueueType type>
size_t ConcurrentBlockingQueue<T, type>::Size() {
  std::lock_guard<std::mutex> guard{mutex_};
  return (type == ConcurrentQueueType::kFIFO) ? fifo_queue_.size()
                                              : priority_queue_.size();
}
} // namespace dmlc
#endif // DMLC_USE_CXX11
#endif // DMLC_CONCURRENCY_H_
This source diff could not be displayed because it is too large. You can view the blob instead.
/*!
* Copyright (c) 2015 by Contributors
* \file config.h
* \brief defines config parser class
*/
#ifndef DMLC_CONFIG_H_
#define DMLC_CONFIG_H_
#include <cstring>
#include <iostream>
#include <iterator>
#include <map>
#include <vector>
#include <utility>
#include <string>
#include <sstream>
/*! \brief namespace for dmlc */
namespace dmlc {
/*!
* \brief class for config parser
*
* Two modes are supported:
* 1. non-multi value mode: if the same key appears twice in the configure file, the later one will
* replace the earlier one; when using iterator, the order will be the "last effective insertion" order
* 2. multi value mode: multiple values with the same key could co-exist; when using iterator, the
* order will be the insertion order.
*
* [Basic usage]
*
* Config cfg(file_input_stream);
* for(Config::ConfigIterator iter = cfg.begin(); iter != cfg.end(); ++iter) {
* ConfigEntry ent = *iter;
* std::string key = ent.first;
* std::string value = ent.second;
* do_something_with(key, value);
* }
*/
class Config {
public:
/*!
* \brief type when extracting from iterator
*/
typedef std::pair<std::string, std::string> ConfigEntry;
/*!
* \brief iterator class
*/
class ConfigIterator;
/*!
* \brief create empty config
* \param multi_value whether the config supports multi value
*/
explicit Config(bool multi_value = false);
/*!
* \brief create config and load content from the given stream
* \param is input stream
* \param multi_value whether the config supports multi value
*/
explicit Config(std::istream& is, bool multi_value = false); // NOLINT(*)
/*!
* \brief clear all the values
*/
void Clear(void);
/*!
* \brief load the contents from the stream
* \param is the stream as input
*/
void LoadFromStream(std::istream& is); // NOLINT(*)
/*!
* \brief set a key-value pair into the config; if the key already exists in the configure file,
* it will either replace the old value with the given one (in non-multi value mode) or
* store it directly (in multi-value mode);
* \param key key
* \param value value
* \param is_string whether the value should be wrapped by quotes in proto string
*/
template<class T>
void SetParam(const std::string& key, const T& value, bool is_string = false);
/*!
* \brief get the config under the key; if multiple values exist for the same key,
* return the last inserted one.
* \param key key
* \return config value
*/
const std::string& GetParam(const std::string& key) const;
/*!
* \brief check whether the configure value given by the key should be wrapped by quotes
* \param key key
* \return whether the configure value is represented by string
*/
bool IsGenuineString(const std::string& key) const;
/*!
* \brief transform all the configuration into string recognizable to protobuf
* \return string that could be parsed directly by protobuf
*/
std::string ToProtoString(void) const;
/*!
* \brief get begin iterator
* \return begin iterator
*/
ConfigIterator begin() const;
/*!
* \brief get end iterator
* \return end iterator
*/
ConfigIterator end() const;
public:
/*!
* \brief iterator class
*/
class ConfigIterator : public std::iterator< std::input_iterator_tag, ConfigEntry > {
friend class Config;
public:
/*!
* \brief copy constructor
*/
ConfigIterator(const ConfigIterator& other);
/*!
* \brief uni-increment operators
* \return the reference of current config
*/
ConfigIterator& operator++();
/*!
* \brief uni-increment operators
* \return the reference of current config
*/
ConfigIterator operator++(int); // NOLINT(*)
/*!
* \brief compare operators
* \param rhs the other config to compare against
* \return the compared result
*/
bool operator == (const ConfigIterator& rhs) const;
/*!
* \brief compare operators not equal
* \param rhs the other config to compare against
* \return the compared result
*/
bool operator != (const ConfigIterator& rhs) const;
/*!
* \brief retrieve value from operator
*/
ConfigEntry operator * () const;
private:
ConfigIterator(size_t index, const Config* config);
void FindNextIndex();
private:
size_t index_;
const Config* config_;
};
private:
struct ConfigValue {
std::vector<std::string> val;
std::vector<size_t> insert_index;
bool is_string;
};
void Insert(const std::string& key, const std::string& value, bool is_string);
private:
std::map<std::string, ConfigValue> config_map_;
std::vector<std::pair<std::string, size_t> > order_;
const bool multi_value_;
};
/*!
 * \brief store a value under the given key after rendering it as text
 * \param key key
 * \param value value; it is formatted with operator<< into a string
 * \param is_string forwarded unchanged to Insert together with the text
 */
template<class T>
void Config::SetParam(const std::string& key, const T& value, bool is_string) {
  std::ostringstream formatted;
  formatted << value;
  const std::string text = formatted.str();
  Insert(key, text, is_string);
}
} // namespace dmlc
#endif // DMLC_CONFIG_H_
/*!
* Copyright (c) 2017 by Contributors
* \file endian.h
* \brief Endian testing, need c++11
*/
#ifndef DMLC_ENDIAN_H_
#define DMLC_ENDIAN_H_
#include "./base.h"
// Endianness detection: defines DMLC_LITTLE_ENDIAN by trying, in order,
// a CMake-provided value, hard-coded per-platform answers, and the
// platform's own byte-order headers. Compilation stops with #error when
// none of the branches applies.
#ifdef DMLC_CMAKE_LITTLE_ENDIAN
// If compiled with CMake, use CMake's endian detection logic
#define DMLC_LITTLE_ENDIAN DMLC_CMAKE_LITTLE_ENDIAN
#else
// Apple and Windows targets are treated as little endian unconditionally.
#if defined(__APPLE__) || defined(_WIN32)
#define DMLC_LITTLE_ENDIAN 1
// glibc / Android / RISC-V: compare the macros from <endian.h>.
#elif defined(__GLIBC__) || defined(__GNU_LIBRARY__) \
|| defined(__ANDROID__) || defined(__RISCV__)
#include <endian.h>
#define DMLC_LITTLE_ENDIAN (__BYTE_ORDER == __LITTLE_ENDIAN)
// FreeBSD / OpenBSD: same comparison, using the underscore-prefixed
// macros from <sys/endian.h>.
#elif defined(__FreeBSD__) || defined(__OpenBSD__)
#include <sys/endian.h>
#define DMLC_LITTLE_ENDIAN (_BYTE_ORDER == _LITTLE_ENDIAN)
// Emscripten and Hexagon are treated as little endian unconditionally.
#elif defined(__EMSCRIPTEN__) || defined(__hexagon__)
#define DMLC_LITTLE_ENDIAN 1
// Solaris: <sys/isa_defs.h> defines _LITTLE_ENDIAN only on little-endian
// targets, so its mere presence is the test.
#elif defined(__sun) || defined(sun)
#include <sys/isa_defs.h>
#if defined(_LITTLE_ENDIAN)
#define DMLC_LITTLE_ENDIAN 1
#else
#define DMLC_LITTLE_ENDIAN 0
#endif
#else
#error "Unable to determine endianness of your machine; use CMake to compile"
#endif
#endif
/*!
 * \brief whether serialization can skip byte swapping: true when the host
 *  byte order equals the byte order selected by DMLC_IO_USE_LITTLE_ENDIAN
 *  (not defined in this header; presumably provided by ./base.h - confirm)
 */
#define DMLC_IO_NO_ENDIAN_SWAP (DMLC_LITTLE_ENDIAN == DMLC_IO_USE_LITTLE_ENDIAN)
namespace dmlc {
/*!
 * \brief Reverse the byte order of every element of an array, in place.
 * \param data Pointer to the first element.
 * \param elem_bytes Size of one element in bytes.
 * \param num_elems Number of consecutive elements to process.
 * \note Passing a compile-time constant elem_bytes helps the compiler
 *       optimize the inner loop.
 */
inline void ByteSwap(void* data, size_t elem_bytes, size_t num_elems) {
  uint8_t* elem = reinterpret_cast<uint8_t*>(data);
  for (size_t remaining = num_elems; remaining != 0; --remaining, elem += elem_bytes) {
    // Elements of zero or one byte have nothing to reverse.
    if (elem_bytes < 2) continue;
    // Walk inwards from both ends, exchanging bytes until the cursors meet.
    uint8_t* front = elem;
    uint8_t* back = elem + (elem_bytes - 1);
    while (front < back) {
      const uint8_t saved = *front;
      *front = *back;
      *back = saved;
      ++front;
      --back;
    }
  }
}
} // namespace dmlc
#endif // DMLC_ENDIAN_H_
/*!
* Copyright (c) 2018 by Contributors
* \file filesystem.h
* \brief Utilities to manipulate files
* \author Hyunsu Philip Cho
*/
#ifndef DMLC_FILESYSTEM_H_
#define DMLC_FILESYSTEM_H_
#include <dmlc/logging.h>
#include <dmlc/io.h>
#include <algorithm>
#include <string>
#include <vector>
#include <random>
/* platform specific headers */
#ifdef _WIN32
#define NOMINMAX
#include <windows.h>
#include <Shlwapi.h>
#pragma comment(lib, "Shlwapi.lib")
#else // _WIN32
#include <unistd.h>
#include <sys/stat.h>
#include <sys/types.h>
#endif // _WIN32
namespace dmlc {
/*!
 * \brief Manager class for temporary directories. Whenever a new
 *        TemporaryDirectory object is constructed, a temporary directory is
 *        created. The directory is deleted when the object is deleted or goes
 *        out of scope. Note: no symbolic links are allowed inside the
 *        temporary directory.
 *
 * Usage example:
 * \code
 *
 *   void foo() {
 *     dmlc::TemporaryDirectory tempdir;
 *     // Create a file my_file.txt inside the temporary directory
 *     std::ofstream of(tempdir.path + "/my_file.txt");
 *     // ... write to my_file.txt ...
 *
 *     // ... use my_file.txt
 *
 *     // When tempdir goes out of scope, the temporary directory is deleted
 *   }
 *
 * \endcode
 *
 * NOTE(review): the class does not disable copying, so a copy would hold the
 * same path and its destructor would run RecursiveDelete() on it as well.
 * Avoid copying instances.
 */
class TemporaryDirectory {
 public:
  /*!
   * \brief Default constructor.
   *        Creates a new temporary directory with a unique name.
   *        On non-Windows platforms the directory is created under $TMPDIR,
   *        or under /tmp when TMPDIR is unset or empty.
   * \param verbose whether to emit extra messages
   */
  explicit TemporaryDirectory(bool verbose = false)
    : verbose_(verbose) {
#ifdef _WIN32
    /* locate the root directory of temporary area */
    char tmproot[MAX_PATH] = {0};
    const DWORD dw_retval = GetTempPathA(MAX_PATH, tmproot);
    if (dw_retval > MAX_PATH || dw_retval == 0) {
      LOG(FATAL) << "TemporaryDirectory(): "
                 << "Could not create temporary directory";
    }
    /* generate a unique 8-letter alphanumeric string */
    const std::string letters = "abcdefghijklmnopqrstuvwxyz0123456789_";
    std::string uniqstr(8, '\0');
    std::random_device rd;
    std::mt19937 gen(rd());
    std::uniform_int_distribution<int> dis(0, letters.length() - 1);
    std::generate(uniqstr.begin(), uniqstr.end(),
      [&dis, &gen, &letters]() -> char {
        return letters[dis(gen)];
      });
    /* combine paths to get the name of the temporary directory */
    char tmpdir[MAX_PATH] = {0};
    PathCombineA(tmpdir, tmproot, uniqstr.c_str());
    if (!CreateDirectoryA(tmpdir, NULL)) {
      LOG(FATAL) << "TemporaryDirectory(): "
                 << "Could not create temporary directory";
    }
    path = std::string(tmpdir);
#else  // _WIN32
    std::string tmproot;  /* root directory of temporary area */
    std::string dirtemplate;  /* template for temporary directory name */
    /* Get TMPDIR env variable or fall back to /tmp/ */
    {
      const char* tmpenv = getenv("TMPDIR");
      // An empty TMPDIR is treated like an unset one. Without this, the
      // template below would become "/tmpdir.XXXXXX", i.e. the directory
      // would be created in the filesystem root.
      if (tmpenv != nullptr && tmpenv[0] != '\0') {
        tmproot = std::string(tmpenv);
        // strip trailing forward slashes; a TMPDIR made only of slashes
        // collapses to "" so the template still resolves to the root
        while (tmproot.length() != 0 && tmproot[tmproot.length() - 1] == '/') {
          tmproot.resize(tmproot.length() - 1);
        }
      } else {
        tmproot = "/tmp";
      }
    }
    dirtemplate = tmproot + "/tmpdir.XXXXXX";
    // mkdtemp() rewrites the template in place, so it needs a writable,
    // NUL-terminated buffer rather than the std::string's own storage.
    std::vector<char> dirtemplate_buf(dirtemplate.begin(), dirtemplate.end());
    dirtemplate_buf.push_back('\0');
    char* tmpdir = mkdtemp(&dirtemplate_buf[0]);
    if (!tmpdir) {
      LOG(FATAL) << "TemporaryDirectory(): "
                 << "Could not create temporary directory";
    }
    path = std::string(tmpdir);
#endif  // _WIN32
    if (verbose_) {
      LOG(INFO) << "Created temporary directory " << path;
    }
  }

  /*! \brief Destructor. Will perform recursive deletion via RecursiveDelete() */
  ~TemporaryDirectory() {
    RecursiveDelete(path);
  }

  /*! \brief Full path of the temporary directory */
  std::string path;

 private:
  /*! \brief Whether to emit extra messages */
  bool verbose_;

  /*!
   * \brief Determine whether a given path is a symbolic link
   * \param path String representation of path
   * \return true if the path is a symbolic link (on Windows: if it carries
   *         the reparse-point attribute)
   */
  inline bool IsSymlink(const std::string& path) {
#ifdef _WIN32
    DWORD attr = GetFileAttributesA(path.c_str());
    CHECK_NE(attr, INVALID_FILE_ATTRIBUTES)
      << "dmlc::TemporaryDirectory::IsSymlink(): Unable to read file attributes";
    return attr & FILE_ATTRIBUTE_REPARSE_POINT;
#else  // _WIN32
    struct stat sb;
    // lstat() inspects the link itself instead of following it.
    CHECK_EQ(lstat(path.c_str(), &sb), 0)
      << "dmlc::TemporaryDirectory::IsSymlink(): Unable to read file attributes";
    return S_ISLNK(sb.st_mode);
#endif  // _WIN32
  }

  /*!
   * \brief Delete a directory recursively, along with sub-directories and files.
   * \param path String representation of path. It must refer to a directory.
   */
  void RecursiveDelete(const std::string& path);
};
} // namespace dmlc
#endif // DMLC_FILESYSTEM_H_
/*!
* Copyright (c) 2016 by Contributors
* \file input_split_shuffle.h
* \brief base class to construct input split with global shuffling
* \author Yifeng Geng
*/
#ifndef DMLC_INPUT_SPLIT_SHUFFLE_H_
#define DMLC_INPUT_SPLIT_SHUFFLE_H_
#include <cstdio>
#include <cstring>
#include <vector>
#include <string>
#include <algorithm>
#include <memory>
namespace dmlc {
/*!
 * \brief class to construct input split with global shuffling
 *
 * Every logical part is subdivided into num_shuffle_parts chunks of the
 * wrapped InputSplit; the chunks of the current part are visited in a
 * shuffled order.
 *
 * NOTE(review): this class uses std::mt19937 and InputSplit, but the include
 * list of this header shows neither <random> nor the header declaring
 * InputSplit; it presumably relies on the includer - confirm.
 */
class InputSplitShuffle : public InputSplit {
 public:
  // destructor
  virtual ~InputSplitShuffle(void) { source_.reset(); }
  // implement BeforeFirst: reshuffle the chunk order and restart from the
  // first chunk of the current part
  virtual void BeforeFirst(void) {
    if (num_shuffle_parts_ > 1) {
      std::shuffle(shuffle_indexes_.begin(), shuffle_indexes_.end(), trnd_);
      source_->ResetPartition(SourcePartIndex(part_index_, 0),
                              num_parts_ * num_shuffle_parts_);
      cur_shuffle_idx_ = 0;
    } else {
      source_->BeforeFirst();
    }
  }
  virtual void HintChunkSize(size_t chunk_size) {
    source_->HintChunkSize(chunk_size);
  }
  virtual size_t GetTotalSize(void) {
    return source_->GetTotalSize();
  }
  // implement next record
  virtual bool NextRecord(Blob *out_rec) {
    if (num_shuffle_parts_ <= 1) {
      return source_->NextRecord(out_rec);
    }
    // When the current chunk is exhausted move on to the next shuffled chunk;
    // keep going so that empty chunks are skipped.
    while (!source_->NextRecord(out_rec)) {
      if (!AdvanceShuffleChunk()) return false;
    }
    return true;
  }
  // implement next chunk
  virtual bool NextChunk(Blob* out_chunk) {
    if (num_shuffle_parts_ <= 1) {
      return source_->NextChunk(out_chunk);
    }
    while (!source_->NextChunk(out_chunk)) {
      if (!AdvanceShuffleChunk()) return false;
    }
    return true;
  }
  // implement ResetPartition.
  virtual void ResetPartition(unsigned rank, unsigned nsplit) {
    CHECK(nsplit == num_parts_) << "num_parts is not consistent!";
    // Remember the new part: later chunk advances and BeforeFirst() derive
    // their partition index from part_index_, so leaving it stale would send
    // every chunk after the first one back to the previous part.
    part_index_ = rank;
    source_->ResetPartition(SourcePartIndex(rank, 0),
                            nsplit * num_shuffle_parts_);
    cur_shuffle_idx_ = 0;
  }
  /*!
   * \brief constructor
   * \param uri the uri of the input, can contain hdfs prefix
   * \param part_index the part id of current input
   * \param num_parts total number of splits
   * \param type type of record
   *   List of possible types: "text", "recordio"
   *     - "text":
   *         text file, each line is treated as a record
   *         input split will split on '\\n' or '\\r'
   *     - "recordio":
   *         binary recordio file, see recordio.h
   * \param num_shuffle_parts number of shuffle chunks for each split
   * \param shuffle_seed shuffle seed for chunk shuffling
   */
  InputSplitShuffle(const char* uri,
                    unsigned part_index,
                    unsigned num_parts,
                    const char* type,
                    unsigned num_shuffle_parts,
                    int shuffle_seed)
      : part_index_(part_index),
        num_parts_(num_parts),
        num_shuffle_parts_(num_shuffle_parts),
        cur_shuffle_idx_(0) {
    for (unsigned i = 0; i < num_shuffle_parts_; i++) {
      shuffle_indexes_.push_back(i);
    }
    trnd_.seed(kRandMagic_ + part_index_ + num_parts_ + num_shuffle_parts_ +
               shuffle_seed);
    std::shuffle(shuffle_indexes_.begin(), shuffle_indexes_.end(), trnd_);
    int idx = SourcePartIndex(part_index_, cur_shuffle_idx_);
    source_.reset(
        InputSplit::Create(uri, idx , num_parts_ * num_shuffle_parts_, type));
  }
  /*!
   * \brief factory function:
   *  create input split with chunk shuffling given a uri
   * \param uri the uri of the input, can contain hdfs prefix
   * \param part_index the part id of current input
   * \param num_parts total number of splits
   * \param type type of record
   *   List of possible types: "text", "recordio"
   *     - "text":
   *         text file, each line is treated as a record
   *         input split will split on '\\n' or '\\r'
   *     - "recordio":
   *         binary recordio file, see recordio.h
   * \param num_shuffle_parts number of shuffle chunks for each split
   * \param shuffle_seed shuffle seed for chunk shuffling
   * \return a new input split
   * \sa InputSplit::Type
   */
  static InputSplit* Create(const char* uri,
                            unsigned part_index,
                            unsigned num_parts,
                            const char* type,
                            unsigned num_shuffle_parts,
                            int shuffle_seed) {
    CHECK(num_shuffle_parts > 0) << "number of shuffle parts should be greater than zero!";
    return new InputSplitShuffle(
        uri, part_index, num_parts, type, num_shuffle_parts, shuffle_seed);
  }

 private:
  /*!
   * \brief partition index inside the wrapped split
   * \param part logical part id
   * \param shuffle_pos position in the shuffled chunk order
   * \return shuffle_indexes_[shuffle_pos] + part * num_shuffle_parts_
   */
  int SourcePartIndex(unsigned part, unsigned shuffle_pos) const {
    return shuffle_indexes_[shuffle_pos] + part * num_shuffle_parts_;
  }
  /*!
   * \brief point the wrapped split at the next shuffled chunk of this part
   * \return false when the current chunk was already the last one
   */
  bool AdvanceShuffleChunk() {
    if (cur_shuffle_idx_ == num_shuffle_parts_ - 1) {
      return false;
    }
    ++cur_shuffle_idx_;
    source_->ResetPartition(SourcePartIndex(part_index_, cur_shuffle_idx_),
                            num_parts_ * num_shuffle_parts_);
    return true;
  }
  // magic number for seed
  static const int kRandMagic_ = 666;
  /*! \brief random engine */
  std::mt19937 trnd_;
  /*! \brief inner inputsplit */
  std::unique_ptr<InputSplit> source_;
  /*! \brief part index */
  unsigned part_index_;
  /*! \brief number of parts */
  unsigned num_parts_;
  /*! \brief the number of block for shuffling*/
  unsigned num_shuffle_parts_;
  /*! \brief current shuffle block index */
  unsigned cur_shuffle_idx_;
  /*! \brief shuffled indexes */
  std::vector<int> shuffle_indexes_;
};
} // namespace dmlc
#endif // DMLC_INPUT_SPLIT_SHUFFLE_H_
/*!
* Copyright (c) 2015 by Contributors
* \file memory.h
 * \brief Additional memory handling utilities.
*/
#ifndef DMLC_MEMORY_H_
#define DMLC_MEMORY_H_
#include <vector>
#include <memory>
#include <utility>
#include "./base.h"
#include "./logging.h"
#include "./thread_local.h"
namespace dmlc {
/*!
 * \brief Pool handing out fixed-size, fixed-alignment pieces of memory.
 *        Pieces are carved out of pages holding (1 << 22) / size slots;
 *        returned pieces are kept on a free list and reused first.
 * \tparam size The size of each piece.
 * \tparam align The alignment requirement of the memory.
 */
template<size_t size, size_t align>
class MemoryPool {
 public:
  /*! \brief constructor, allocates the first page */
  MemoryPool() {
    // A freed slot is reused to hold a free-list node, so the slot alignment
    // must be compatible with that node type.
    static_assert(align % alignof(LinkedList) == 0,
                  "alignment requirement failed.");
    curr_page_.reset(new Page());
  }
  /*!
   * \brief allocate a new memory of size
   * \return pointer to one uninitialized slot
   */
  inline void* allocate() {
    // 1) most recently freed slot, if any
    if (head_ != nullptr) {
      LinkedList* recycled = head_;
      head_ = recycled->next;
      return recycled;
    }
    // 2) next untouched slot of the current page
    if (page_ptr_ < kPageSize) {
      const size_t slot = page_ptr_;
      ++page_ptr_;
      return &(curr_page_->data[slot]);
    }
    // 3) current page exhausted: retire it and start a fresh one, handing
    //    out the new page's first slot
    allocated_.push_back(std::move(curr_page_));
    curr_page_.reset(new Page());
    page_ptr_ = 1;
    return &(curr_page_->data[0]);
  }
  /*!
   * \brief deallocate a piece of memory
   * \param p The pointer to the memory to be de-allocated.
   */
  inline void deallocate(void* p) {
    // Push the slot onto the front of the free list.
    LinkedList* node = static_cast<LinkedList*>(p);
    node->next = head_;
    head_ = node;
  }

 private:
  // number of slots in each page
  static const int kPageSize = ((1 << 22) / size);
  // one page worth of raw slots
  struct Page {
    typename std::aligned_storage<size, align>::type data[kPageSize];
  };
  // free-list node stored inside a released slot
  struct LinkedList {
    LinkedList* next{nullptr};
  };
  // head of free list
  LinkedList* head_{nullptr};
  // page new slots are currently taken from
  std::unique_ptr<Page> curr_page_;
  // index of the next untouched slot in the current page
  size_t page_ptr_{0};
  // fully used pages, kept alive until the pool is destroyed
  std::vector<std::unique_ptr<Page> > allocated_;
};
/*!
* \brief A thread local allocator that get memory from a threadlocal memory pool.
* This is suitable to allocate objects that do not cross thread.
* \tparam T the type of the data to be allocated.
*/
template<typename T>
class ThreadlocalAllocator {
public:
/*! \brief pointer type */
typedef T* pointer;
/*! \brief const pointer type */
typedef const T* const_ptr;
/*! \brief value type */
typedef T value_type;
/*! \brief default constructor */
ThreadlocalAllocator() {}
/*!
* \brief constructor from another allocator
* \param other another allocator
* \tparam U another type
*/
template<typename U>
ThreadlocalAllocator(const ThreadlocalAllocator<U>& other) {}
/*!
* \brief allocate memory
* \param n number of blocks
* \return an uninitialized memory of type T.
*/
inline T* allocate(size_t n) {
CHECK_EQ(n, 1);
typedef ThreadLocalStore<MemoryPool<sizeof(T), alignof(T)> > Store;
return static_cast<T*>(Store::Get()->allocate());
}
/*!
* \brief deallocate memory
* \param p a memory to be returned.
* \param n number of blocks
*/
inline void deallocate(T* p, size_t n) {
CHECK_EQ(n, 1);
typedef ThreadLocalStore<MemoryPool<sizeof(T), alignof(T)> > Store;
Store::Get()->deallocate(p);
}
};
/*!
 * \brief a shared pointer like type that allocate object
 *  from a threadlocal object pool. This object is not thread-safe
 *  but can be faster than shared_ptr in certain usecases.
 *  The reference counter is a plain unsigned stored next to the object.
 * \tparam T the data type.
 */
template<typename T>
struct ThreadlocalSharedPtr {
 public:
  /*! \brief default constructor */
  ThreadlocalSharedPtr() : block_(nullptr) {}
  /*!
   * \brief constructor from nullptr
   * \param other the nullptr type
   */
  ThreadlocalSharedPtr(std::nullptr_t other) : block_(nullptr) {}  // NOLINT(*)
  /*!
   * \brief copy constructor
   * \param other another pointer.
   */
  ThreadlocalSharedPtr(const ThreadlocalSharedPtr<T>& other)
      : block_(other.block_) {
    IncRef(block_);
  }
  /*!
   * \brief move constructor
   * \param other another pointer.
   */
  ThreadlocalSharedPtr(ThreadlocalSharedPtr<T>&& other)
      : block_(other.block_) {
    other.block_ = nullptr;
  }
  /*!
   * \brief destructor
   */
  ~ThreadlocalSharedPtr() {
    DecRef(block_);
  }
  /*!
   * \brief move assignment
   * \param other another object to be assigned.
   * \return self.
   */
  inline ThreadlocalSharedPtr<T>& operator=(ThreadlocalSharedPtr<T>&& other) {
    // Self-move must be a no-op: releasing first would destroy the object
    // that is about to be taken over.
    if (this != &other) {
      DecRef(block_);
      block_ = other.block_;
      other.block_ = nullptr;
    }
    return *this;
  }
  /*!
   * \brief copy assignment
   * \param other another object to be assigned.
   * \return self.
   */
  inline ThreadlocalSharedPtr<T> &operator=(const ThreadlocalSharedPtr<T>& other) {
    // Take the new reference before dropping the old one. With the opposite
    // order, assigning a sole owner to itself would free the block and then
    // increment the counter inside the freed memory.
    RefBlock* incoming = other.block_;
    IncRef(incoming);
    DecRef(block_);
    block_ = incoming;
    return *this;
  }
  /*! \brief check if nullptr */
  inline bool operator==(std::nullptr_t other) const {
    return block_ == nullptr;
  }
  /*!
   * \return get the pointer content.
   */
  inline T* get() const {
    if (block_ == nullptr) return nullptr;
    return reinterpret_cast<T*>(&(block_->data));
  }
  /*!
   * \brief reset the pointer to nullptr.
   */
  inline void reset() {
    DecRef(block_);
    block_ = nullptr;
  }
  /*! \return if use_count == 1*/
  inline bool unique() const {
    if (block_ == nullptr) return false;
    return block_->use_count_ == 1;
  }
  /*!
   * \return pointer to the held object (note: a pointer, not a reference);
   *         must not be called on an empty pointer
   */
  inline T* operator*() const {
    return reinterpret_cast<T*>(&(block_->data));
  }
  /*! \return pointer to the held object; must not be called on an empty pointer */
  inline T* operator->() const {
    return reinterpret_cast<T*>(&(block_->data));
  }
  /*!
   * \brief create a new space from threadlocal storage and return it.
   * \tparam Args the arguments.
   * \param args The input argument
   * \return the allocated pointer.
   */
  template <typename... Args>
  inline static ThreadlocalSharedPtr<T> Create(Args&&... args) {
    ThreadlocalAllocator<RefBlock> arena;
    RefBlock* block = arena.allocate(1);
    // Construct the object before any smart pointer owns the block. If T's
    // constructor throws, no destructor runs on the unconstructed storage
    // (the slot is simply not handed back to the pool).
    new (&(block->data)) T(std::forward<Args>(args)...);
    block->use_count_ = 1;
    ThreadlocalSharedPtr<T> p;
    p.block_ = block;
    return p;
  }

 private:
  // internal reference block
  struct RefBlock {
    typename std::aligned_storage<sizeof(T), alignof(T)>::type data;
    unsigned use_count_;
  };
  // decrease ref counter; destroys the object and returns the block to the
  // allocator when the counter reaches zero
  inline static void DecRef(RefBlock* block) {
    if (block != nullptr) {
      if (--block->use_count_ == 0) {
        ThreadlocalAllocator<RefBlock> arena;
        T* dptr = reinterpret_cast<T*>(&(block->data));
        dptr->~T();
        arena.deallocate(block, 1);
      }
    }
  }
  // increase ref counter
  inline static void IncRef(RefBlock* block) {
    if (block != nullptr) {
      ++block->use_count_;
    }
  }
  // internal block
  RefBlock *block_;
};
} // namespace dmlc
#endif // DMLC_MEMORY_H_
/*!
* Copyright (c) 2015 by Contributors
* \file memory_io.h
* \brief defines binary serialization class to serialize things into/from memory region.
*/
#ifndef DMLC_MEMORY_IO_H_
#define DMLC_MEMORY_IO_H_
#include <cstring>
#include <string>
#include <algorithm>
#include "./base.h"
#include "./io.h"
#include "./logging.h"
namespace dmlc {
/*!
 * \brief A Stream that operates on fixed region of memory
 *  This class allows us to read/write from/to a fixed memory region.
 *  The region is not owned by the stream and never grows; any access that
 *  would leave the region fails a CHECK.
 */
struct MemoryFixedSizeStream : public SeekStream {
 public:
  /*!
   * \brief constructor
   * \param p_buffer the head pointer of the memory region.
   * \param buffer_size the size of the memorybuffer
   */
  MemoryFixedSizeStream(void *p_buffer, size_t buffer_size)
      : p_buffer_(reinterpret_cast<char*>(p_buffer)),
        buffer_size_(buffer_size) {
    curr_ptr_ = 0;
  }
  /*!
   * \brief copy bytes from the current position and advance past them
   * \param ptr destination buffer
   * \param size number of bytes to read; must not exceed the bytes left
   * \return number of bytes read, always equal to size
   */
  virtual size_t Read(void *ptr, size_t size) {
    // Written as a subtraction so that a huge size, or a position that was
    // seeked past the end, cannot wrap around and slip through the check.
    CHECK(curr_ptr_ <= buffer_size_ && size <= buffer_size_ - curr_ptr_);
    if (size != 0) std::memcpy(ptr, p_buffer_ + curr_ptr_, size);
    curr_ptr_ += size;
    return size;
  }
  /*!
   * \brief copy bytes to the current position and advance past them
   * \param ptr source buffer
   * \param size number of bytes to write; must fit in the space left
   */
  virtual void Write(const void *ptr, size_t size) {
    if (size == 0) return;
    // Overflow-safe form of "curr_ptr_ + size <= buffer_size_".
    CHECK(curr_ptr_ <= buffer_size_ && size <= buffer_size_ - curr_ptr_);
    std::memcpy(p_buffer_ + curr_ptr_, ptr, size);
    curr_ptr_ += size;
  }
  /*!
   * \brief move the current position; the position is not validated here,
   *        the bounds are enforced by the next Read or Write
   * \param pos new position, in bytes from the start of the region
   */
  virtual void Seek(size_t pos) {
    curr_ptr_ = static_cast<size_t>(pos);
  }
  /*! \return the current position, in bytes from the start of the region */
  virtual size_t Tell(void) {
    return curr_ptr_;
  }

 private:
  /*! \brief in memory buffer */
  char *p_buffer_;
  /*! \brief total size of the buffer in bytes */
  size_t buffer_size_;
  /*! \brief current pointer */
  size_t curr_ptr_;
};  // class MemoryFixedSizeStream
/*!
 * \brief Seekable in-memory stream whose storage is a caller-provided
 *  std::string. Reads are served from the string's contents; writes
 *  overwrite them and enlarge the string when they run past its end.
 */
struct MemoryStringStream : public dmlc::SeekStream {
 public:
  /*!
   * \brief constructor
   * \param p_buffer the pointer to the string.
   */
  explicit MemoryStringStream(std::string *p_buffer)
      : p_buffer_(p_buffer) {
    curr_ptr_ = 0;
  }
  /*!
   * \brief copy up to size bytes starting at the current position
   * \param ptr destination buffer
   * \param size maximum number of bytes to copy
   * \return number of bytes copied; less than size when the string ends first
   */
  virtual size_t Read(void *ptr, size_t size) {
    CHECK(curr_ptr_ <= p_buffer_->length());
    const size_t remaining = p_buffer_->length() - curr_ptr_;
    const size_t ncopy = (remaining < size) ? remaining : size;
    if (ncopy == 0) {
      return 0;
    }
    std::memcpy(ptr, &(*p_buffer_)[0] + curr_ptr_, ncopy);
    curr_ptr_ += ncopy;
    return ncopy;
  }
  /*!
   * \brief copy size bytes into the string at the current position
   * \param ptr source buffer
   * \param size number of bytes to write
   */
  virtual void Write(const void *ptr, size_t size) {
    if (size == 0) return;
    const size_t end_pos = curr_ptr_ + size;
    // Grow the string first when the write would run past its end.
    if (end_pos > p_buffer_->length()) {
      p_buffer_->resize(end_pos);
    }
    std::memcpy(&(*p_buffer_)[0] + curr_ptr_, ptr, size);
    curr_ptr_ = end_pos;
  }
  /*!
   * \brief move the current position
   * \param pos new position, in bytes from the start of the string
   */
  virtual void Seek(size_t pos) {
    curr_ptr_ = static_cast<size_t>(pos);
  }
  /*! \return the current position, in bytes from the start of the string */
  virtual size_t Tell(void) {
    return curr_ptr_;
  }

 private:
  /*! \brief in memory buffer */
  std::string *p_buffer_;
  /*! \brief current pointer */
  size_t curr_ptr_;
};  // class MemoryStringStream
} // namespace dmlc
#endif // DMLC_MEMORY_IO_H_
/*!
* Copyright (c) 2015 by Contributors
* \file omp.h
* \brief header to handle OpenMP compatibility issues
*/
#ifndef DMLC_OMP_H_
#define DMLC_OMP_H_
#if defined(_OPENMP)
#include <omp.h>
#else
// OpenMP is not enabled: provide single-threaded stand-ins for the handful
// of omp_* functions so that callers compile unchanged.
//
// __GOMP_NOTHROW is the "does not throw" decoration put on the stand-ins:
//  - clang: nothing
//  - C++11 and newer: noexcept (the dynamic exception specification throw()
//    is deprecated since C++17 and removed from the language in C++20)
//  - older C++: throw()
//  - C: the GNU nothrow attribute
#undef __GOMP_NOTHROW
#if defined(__clang__)
#define __GOMP_NOTHROW
#elif defined(__cplusplus) && __cplusplus >= 201103L
#define __GOMP_NOTHROW noexcept
#elif defined(__cplusplus)
#define __GOMP_NOTHROW throw()
#else
#define __GOMP_NOTHROW __attribute__((__nothrow__))
#endif
//! \cond Doxygen_Suppress
#ifdef __cplusplus
extern "C" {
#endif
// The stand-ins describe a single thread on a single processor, outside of
// any parallel region; setting the thread count is ignored.
inline int omp_get_thread_num() __GOMP_NOTHROW { return 0; }
inline int omp_get_num_threads() __GOMP_NOTHROW { return 1; }
inline int omp_get_max_threads() __GOMP_NOTHROW { return 1; }
inline int omp_get_num_procs() __GOMP_NOTHROW { return 1; }
inline void omp_set_num_threads(int nthread) __GOMP_NOTHROW {}
inline int omp_in_parallel() __GOMP_NOTHROW { return 0; }
#ifdef __cplusplus
}
#endif  // __cplusplus
#endif  // _OPENMP
// loop variable used in openmp
namespace dmlc {
// Index types for OpenMP loops: signed under MSVC, unsigned elsewhere
// (presumably because MSVC's OpenMP only accepts signed loop indices -
// confirm).
#ifdef _MSC_VER
typedef int omp_uint;
typedef long omp_ulong;  // NOLINT(*)
#else
typedef unsigned omp_uint;
typedef unsigned long omp_ulong;  // NOLINT(*)
#endif
//! \endcond
}  // namespace dmlc
#endif // DMLC_OMP_H_
/*!
* Copyright (c) 2016 by Contributors
* \file optional.h
* \brief Container to hold optional data.
*/
#ifndef DMLC_OPTIONAL_H_
#define DMLC_OPTIONAL_H_
#include <iostream>
#include <string>
#include <utility>
#include <algorithm>
#include "./base.h"
#include "./common.h"
#include "./logging.h"
#include "./type_traits.h"
namespace dmlc {
/*! \brief dummy type for assign null to optional */
struct nullopt_t {
#if defined(_MSC_VER) && _MSC_VER < 1900
  /*! \brief dummy constructor */
  explicit nullopt_t(int a) {}
#else
  /*! \brief dummy constructor */
  constexpr explicit nullopt_t(int a) {}
#endif
};

/*! Assign null to optional: optional<T> x = nullopt; */
constexpr const nullopt_t nullopt = nullopt_t(0);

/*!
 * \brief c++17 compatible optional class.
 *
 * At any time an optional<T> instance either
 * hold no value (string representation "None")
 * or hold a value of type T.
 *
 * The value lives in aligned raw storage inside the object; is_none tells
 * whether that storage currently contains a constructed T.
 */
template<typename T>
class optional {
 public:
  /*! \brief construct an optional object that contains no value */
  optional() : is_none(true) {}
  /*! \brief construct an optional object with value */
  explicit optional(const T& value) {
#pragma GCC diagnostic push
#if __GNUC__ >= 6
#pragma GCC diagnostic ignored "-Wmaybe-uninitialized"
#endif
    is_none = false;
    new (&val) T(value);
#pragma GCC diagnostic pop
  }
  /*! \brief construct an optional object with another optional object */
  optional(const optional<T>& other) {
#pragma GCC diagnostic push
#if __GNUC__ >= 6
#pragma GCC diagnostic ignored "-Wmaybe-uninitialized"
#endif
    is_none = other.is_none;
    if (!is_none) {
      new (&val) T(other.value());
    }
#pragma GCC diagnostic pop
  }
  /*! \brief deconstructor */
  ~optional() {
    if (!is_none) {
      reinterpret_cast<T*>(&val)->~T();
    }
  }
  /*!
   * \brief swap two optional
   *
   * Values are transferred through T's move (or copy) constructor and
   * destructor. Exchanging the raw storage bytes instead would corrupt any T
   * that keeps a pointer into itself, for example a string with an inline
   * small-string buffer.
   * \param other the optional to exchange state with
   */
  void swap(optional<T>& other) {
    if (this == &other) return;
    if (is_none && other.is_none) return;
    if (is_none) {
      MoveValue(&other, this);
      return;
    }
    if (other.is_none) {
      MoveValue(this, &other);
      return;
    }
    // Both hold a value: rotate through a temporary empty optional.
    optional<T> parked;
    MoveValue(this, &parked);
    MoveValue(&other, this);
    MoveValue(&parked, &other);
  }
  /*! \brief set this object to hold value
   *  \param value the value to hold
   *  \return return self to support chain assignment
   */
  optional<T>& operator=(const T& value) {
    (optional<T>(value)).swap(*this);
    return *this;
  }
  /*! \brief set this object to hold the same value with other
   *  \param other the other object
   *  \return return self to support chain assignment
   */
  optional<T>& operator=(const optional<T> &other) {
    (optional<T>(other)).swap(*this);
    return *this;
  }
  /*! \brief clear the value this object is holding.
   *         optional<T> x = nullopt;
   */
  optional<T>& operator=(nullopt_t) {
    (optional<T>()).swap(*this);
    return *this;
  }
  /*! \brief non-const dereference operator; does not check for a value */
  T& operator*() {  // NOLINT(*)
    return *reinterpret_cast<T*>(&val);
  }
  /*! \brief const dereference operator; does not check for a value */
  const T& operator*() const {
    return *reinterpret_cast<const T*>(&val);
  }
  /*! \brief equal comparison: both empty, or both holding equal values */
  bool operator==(const optional<T>& other) const {
    return this->is_none == other.is_none &&
           (this->is_none == true || this->value() == other.value());
  }
  /*! \brief return the holded value.
   *         throws std::logic_error if holding no value
   */
  const T& value() const {
    if (is_none) {
      throw std::logic_error("bad optional access");
    }
    return *reinterpret_cast<const T*>(&val);
  }
  /*! \brief whether this object is holding a value */
  explicit operator bool() const { return !is_none; }
  /*! \brief whether this object is holding a value (alternate form). */
  bool has_value() const { return operator bool(); }

 private:
  /*!
   * \brief transfer the value of one optional into an empty one
   * \param from optional that holds a value; it is left empty
   * \param to optional that holds no value; it receives the value
   */
  static void MoveValue(optional<T>* from, optional<T>* to) {
    T* source = reinterpret_cast<T*>(&from->val);
    // The flags change only after the construction succeeded, so a throwing
    // constructor leaves both objects as they were.
    new (&to->val) T(std::move(*source));
    to->is_none = false;
    source->~T();
    from->is_none = true;
  }
  // whether this is none
  bool is_none;
  // on stack storage of value
  typename std::aligned_storage<sizeof(T), alignof(T)>::type val;
};
/*! \brief write an optional object to an output stream.
 *
 *  An optional holding no value is written as the text "None"; otherwise
 *  the held value is written with its own operator<<.
 *
 *  \code
 *    dmlc::optional<int> x;
 *    std::cout << x;  // None
 *    x = 0;
 *    std::cout << x;  // 0
 *  \endcode
 *
 *  \param os output stream
 *  \param t source optional<T> object
 *  \return output stream
 */
template<typename T>
std::ostream &operator<<(std::ostream &os, const optional<T> &t) {
  if (!t) {
    os << "None";
    return os;
  }
  os << *t;
  return os;
}
/*! \brief read an optional<T> from an input stream
 *
 *  The next four characters are probed for the literal "None". On a match
 *  the optional is cleared; otherwise the stream is rewound to where the
 *  probe started and a T is extracted and stored. For integral T one
 *  directly following 'L' is consumed as well.
 *
 *  \code
 *    dmlc::optional<int> x;
 *    std::string s1 = "1";
 *    std::istringstream is1(s1);
 *    is1 >> x;  // x == optional<int>(1)
 *
 *    std::string s2 = "None";
 *    std::istringstream is2(s2);
 *    is2 >> x;  // x == optional<int>()
 *  \endcode
 *
 *  \param is input stream
 *  \param t target optional<T> object
 *  \return input stream
 */
template<typename T>
std::istream &operator>>(std::istream &is, optional<T> &t) {
  char probe[4];
  const std::streampos start = is.tellg();
  is.read(probe, 4);
  const bool is_none_literal =
      !is.fail() && probe[0] == 'N' && probe[1] == 'o' &&
      probe[2] == 'n' && probe[3] == 'e';
  if (is_none_literal) {
    t = nullopt;
    return is;
  }
  // Not "None": undo the probe (including any error flags it raised) and
  // parse a regular value from the original position.
  is.clear();
  is.seekg(start);
  T parsed;
  is >> parsed;
  t = parsed;
  if (std::is_integral<T>::value && !is.eof() && is.peek() == 'L') {
    is.get();
  }
  return is;
}
/*! \brief specialization of '>>' istream parsing for optional<bool>
*
* Permits use of generic parameter FieldEntry<DType> class to create
* FieldEntry<optional<bool>> without explicit specialization.
*
* \code
* dmlc::optional<bool> x;
* std::string s1 = "true";
* std::istringstream is1(s1);
* s1 >> x; // x == optional<bool>(true)
*
* std::string s2 = "None";
* std::istringstream is2(s2);
* s2 >> x; // x == optional<bool>()
* \endcode
*
* \param is input stream
* \param t target optional<bool> object
* \return input stream
*/
inline std::istream &operator>>(std::istream &is, optional<bool> &t) {
// Discard initial whitespace
while (isspace(is.peek()))
is.get();
// Extract chars that might be valid into a separate string, stopping
// on whitespace or other non-alphanumerics such as ",)]".
std::string s;
while (isalnum(is.peek()))
s.push_back(is.get());
if (!is.fail()) {
std::transform(s.begin(), s.end(), s.begin(), ::tolower);
if (s == "1" || s == "true")
t = true;
else if (s == "0" || s == "false")
t = false;
else if (s == "none")
t = nullopt;
else
is.setstate(std::ios::failbit);
}
return is;
}
// Each invocation below pairs an optional<T> instantiation with a
// human-readable description string. DMLC_DECLARE_TYPE_NAME is not defined
// in this header (presumably it comes from ./type_traits.h - confirm).
/*! \brief description for optional int */
DMLC_DECLARE_TYPE_NAME(optional<int>, "int or None");
/*! \brief description for optional bool */
DMLC_DECLARE_TYPE_NAME(optional<bool>, "boolean or None");
/*! \brief description for optional float */
DMLC_DECLARE_TYPE_NAME(optional<float>, "float or None");
/*! \brief description for optional double */
DMLC_DECLARE_TYPE_NAME(optional<double>, "double or None");
} // namespace dmlc
namespace std {
/*! \brief std::hash specialization so dmlc::optional can be hashed */
template<typename T>
struct hash<dmlc::optional<T> > {
  /*!
   * \brief hash an optional: the hash of its "has value" flag, combined
   *        with the held value through dmlc::HashCombine when one is present.
   * \param val value.
   * \return hash code.
   */
  size_t operator()(const dmlc::optional<T>& val) const {
    const bool engaged = val.has_value();
    size_t seed = std::hash<bool>()(engaged);
    if (!engaged) {
      return seed;
    }
    seed = dmlc::HashCombine(seed, val.value());
    return seed;
  }
};
}  // namespace std
#endif // DMLC_OPTIONAL_H_
/*!
* Copyright (c) 2015 by Contributors
* \file recordio.h
* \brief recordio that is able to pack binary data into a splittable
* format, useful to exchange data in binary serialization,
* such as binary raw data or protobuf
*/
#ifndef DMLC_RECORDIO_H_
#define DMLC_RECORDIO_H_
#include <cstring>
#include <string>
#include "./io.h"
#include "./logging.h"
namespace dmlc {
/*!
 * \brief writer of binary recordio
 *  binary format for recordio
 *  recordio format: magic lrecord data pad
 *
 *  - magic is magic number
 *  - pad is simply a padding space to make record align to 4 bytes
 *  - lrecord encodes length and continue bit
 *     - data.length() = (lrecord & ((1U << 29U) - 1U));
 *     - cflag == (lrecord >> 29U) & 7;
 *
 *  cflag was used to handle (rare) special case when magic number
 *  occurred in the data sequence.
 *
 *  In such case, the data is split into multiple records by
 *  the cells of magic number
 *
 *  (1) cflag == 0: this is a complete record;
 *  (2) cflag == 1: start of a multiple-rec;
 *      cflag == 2: middle of multiple-rec;
 *      cflag == 3: end of multiple-rec
 */
class RecordIOWriter {
 public:
  /*!
   * \brief magic number of recordio
   * note: (kMagic >> 29U) & 7 > 3
   * this ensures lrec will not be kMagic
   */
  static const uint32_t kMagic = 0xced7230a;
  /*!
   * \brief encode the lrecord
   * \param cflag cflag part of the lrecord
   * \param length length part of lrecord; it is OR-ed in without masking, so
   *        it must fit in the low 29 bits
   * \return the encoded data
   */
  inline static uint32_t EncodeLRec(uint32_t cflag, uint32_t length) {
    return (cflag << 29U) | length;
  }
  /*!
   * \brief decode the flag part of lrecord
   * \param rec the lrecord
   * \return the flag
   */
  inline static uint32_t DecodeFlag(uint32_t rec) {
    return (rec >> 29U) & 7U;
  }
  /*!
   * \brief decode the length part of lrecord
   * \param rec the lrecord
   * \return the length
   */
  inline static uint32_t DecodeLength(uint32_t rec) {
    return rec & ((1U << 29U) - 1U);
  }
  /*!
   * \brief constructor
   * \param stream the stream to be constructed
   */
  explicit RecordIOWriter(Stream *stream)
      : stream_(stream), seek_stream_(dynamic_cast<SeekStream*>(stream)),
        except_counter_(0) {
    // The size is a compile-time constant, so verify it at compile time
    // instead of on every construction.
    static_assert(sizeof(uint32_t) == 4, "uint32_t needs to be 4 bytes");
  }
  /*!
   * \brief write record to the stream
   * \param buf the buffer of memory region
   * \param size the size of record to write out
   */
  void WriteRecord(const void *buf, size_t size);
  /*!
   * \brief write record to the stream
   * \param data the data to write out
   */
  inline void WriteRecord(const std::string &data) {
    this->WriteRecord(data.c_str(), data.length());
  }
  /*!
   * \return number of exceptions(occurrence of magic number)
   *   during the writing process
   */
  inline size_t except_counter(void) const {
    return except_counter_;
  }
  /*!
   * \brief tell the current position of the output stream;
   *        fails a CHECK when the stream is not a SeekStream
   */
  inline size_t Tell(void) {
    CHECK(seek_stream_ != NULL) << "The input stream is not seekable";
    return seek_stream_->Tell();
  }

 private:
  /*! \brief output stream */
  Stream *stream_;
  /*! \brief the same stream viewed as a SeekStream, NULL if it is not one */
  SeekStream *seek_stream_;
  /*! \brief counts the number of exceptions */
  size_t except_counter_;
};
/*!
* \brief reader of binary recordio that reads records from a stream
* \sa RecordIOWriter
*/
class RecordIOReader {
public:
/*!
* \brief constructor
* \param stream the stream to read records from; it is also probed with
* dynamic_cast for SeekStream support, which Seek() and Tell() require
*/
explicit RecordIOReader(Stream *stream)
: stream_(stream), seek_stream_(dynamic_cast<SeekStream*>(stream)),
end_of_stream_(false) {
CHECK(sizeof(uint32_t) == 4) << "uint32_t needs to be 4 bytes";
}
/*!
* \brief read next complete record from stream
* \param out_rec used to store output record in string
* \return true if read was successful, false if end of stream was reached
*/
bool NextRecord(std::string *out_rec);
/*!
* \brief seek to certain position of the input stream;
* the CHECK fails when the stream is not a SeekStream
* \param pos the position handed to SeekStream::Seek
*/
inline void Seek(size_t pos) {
CHECK(seek_stream_ != NULL) << "The input stream is not seekable";
seek_stream_->Seek(pos);
}
/*!
* \brief tell the current position of the input stream;
* the CHECK fails when the stream is not a SeekStream
*/
inline size_t Tell(void) {
CHECK(seek_stream_ != NULL) << "The input stream is not seekable";
return seek_stream_->Tell();
}
private:
/*! \brief input stream */
Stream *stream_;
/*! \brief stream_ viewed as a SeekStream, NULL when the cast failed */
SeekStream *seek_stream_;
/*! \brief whether we are at end of stream */
bool end_of_stream_;
};
/*!
* \brief reader of binary recordio from Blob returned by InputSplit
* This class divides the blob into several independent parts specified by the
* caller, and reads from one of those segments.
* The part reading can be used together with InputSplit::NextChunk for
* multi-threaded parsing (each thread takes its own RecordIOChunkReader)
*
* \sa RecordIOWriter, InputSplit
*/
class RecordIOChunkReader {
public:
/*!
* \brief constructor
* \param chunk source data returned by InputSplit
* \param part_index which part we want to read (defaults to 0)
* \param num_parts number of total segments (defaults to 1)
*/
explicit RecordIOChunkReader(InputSplit::Blob chunk,
unsigned part_index = 0,
unsigned num_parts = 1);
/*!
* \brief read next complete record from stream
* the blob contains the memory content
* NOTE: this function is not threadsafe, use one
* RecordIOChunkReader per thread
* \param out_rec used to store output blob, the header is already
* removed and out_rec only contains the memory content
* \return true if read was successful, false if end was reached
*/
bool NextRecord(InputSplit::Blob *out_rec);
private:
/*! \brief internal temporary data */
std::string temp_;
/*!
* \brief internal data pointers
* NOTE(review): presumably the [begin, end) range of the selected part;
* the implementation is not in this header - confirm.
*/
char *pbegin_, *pend_;
};
} // namespace dmlc
#endif // DMLC_RECORDIO_H_
/*!
* Copyright (c) 2015 by Contributors
* \file thread_local.h
* \brief Portable thread local storage.
*/
#ifndef DMLC_THREAD_LOCAL_H_
#define DMLC_THREAD_LOCAL_H_
#include <mutex>
#include <memory>
#include <vector>
#include "./base.h"
namespace dmlc {
// macro handling for thread-local variables
#ifdef __GNUC__
#define MX_THREAD_LOCAL __thread
#elif __STDC_VERSION__ >= 201112L
#define MX_THREAD_LOCAL _Thread_local
#elif defined(_MSC_VER)
#define MX_THREAD_LOCAL __declspec(thread)
#endif
#if DMLC_CXX11_THREAD_LOCAL == 0
#pragma message("Warning: CXX11 thread_local is not formally supported")
#endif
/*!
 * \brief A store that hands out one instance of T per thread.
 *  Get() returns the calling thread's singleton of type T.
 * \tparam T the type we like to store
 */
template<typename T>
class ThreadLocalStore {
 public:
  /*! \return the thread local singleton of the calling thread */
  static T* Get() {
#if DMLC_CXX11_THREAD_LOCAL && DMLC_MODERN_THREAD_LOCAL == 1
    static thread_local T inst;
    return &inst;
#else
    static MX_THREAD_LOCAL T* ptr = nullptr;
    if (ptr != nullptr) {
      return ptr;
    }
    ptr = new T();
    // Syntactic work-around for the nvcc of the initial cuda v10.1 release,
    // which fails to compile 'Singleton()->' below. Fixed in v10.1 update 1.
    (*Singleton()).RegisterDelete(ptr);
    return ptr;
#endif
  }

 private:
  /*! \brief constructor */
  ThreadLocalStore() {}
  /*! \brief destructor, deletes every instance handed to RegisterDelete */
  ~ThreadLocalStore() {
    for (T* registered : data_) {
      delete registered;
    }
  }
  /*! \return the process-wide store that owns the registered instances */
  static ThreadLocalStore<T> *Singleton() {
    static ThreadLocalStore<T> inst;
    return &inst;
  }
  /*!
   * \brief register a pointer for deletion when the store is destroyed
   * \param str the pointer to take ownership of
   */
  void RegisterDelete(T *str) {
    // data_ is shared by all threads, so appends are serialized by mutex_.
    std::lock_guard<std::mutex> guard(mutex_);
    data_.push_back(str);
  }
  /*! \brief internal mutex */
  std::mutex mutex_;
  /*!\brief internal data */
  std::vector<T*> data_;
};
} // namespace dmlc
#endif // DMLC_THREAD_LOCAL_H_
/*!
* Copyright (c) 2015 by Contributors
* \file timer.h
* \brief cross platform timer for timing
* \author Tianqi Chen
*/
#ifndef DMLC_TIMER_H_
#define DMLC_TIMER_H_
#include "base.h"
#if DMLC_USE_CXX11
#include <chrono>
#endif
#include <time.h>
#ifdef __MACH__
#include <mach/clock.h>
#include <mach/mach.h>
#endif
#include "./logging.h"
namespace dmlc {
/*!
 * \brief return time in seconds
 *
 *  The clock source is picked at compile time: std::chrono when
 *  DMLC_USE_CXX11 is set, the Mach calendar clock on __MACH__,
 *  clock_gettime(CLOCK_REALTIME) on unix/linux, and time() otherwise.
 */
inline double GetTime(void) {
#if DMLC_USE_CXX11
  typedef std::chrono::high_resolution_clock Clock;
  const std::chrono::duration<double> since_epoch(
      Clock::now().time_since_epoch());
  return since_epoch.count();
#elif defined __MACH__
  clock_serv_t clock_service;
  mach_timespec_t now;
  host_get_clock_service(mach_host_self(), CALENDAR_CLOCK, &clock_service);
  CHECK(clock_get_time(clock_service, &now) == 0) << "failed to get time";
  mach_port_deallocate(mach_task_self(), clock_service);
  const double whole_seconds = static_cast<double>(now.tv_sec);
  const double nanoseconds = static_cast<double>(now.tv_nsec);
  return whole_seconds + nanoseconds * 1e-9;
#elif defined(__unix__) || defined(__linux__)
  timespec now;
  CHECK(clock_gettime(CLOCK_REALTIME, &now) == 0) << "failed to get time";
  const double whole_seconds = static_cast<double>(now.tv_sec);
  const double nanoseconds = static_cast<double>(now.tv_nsec);
  return whole_seconds + nanoseconds * 1e-9;
#else
  return static_cast<double>(time(NULL));
#endif
}
} // namespace dmlc
#endif // DMLC_TIMER_H_
/*!
* Copyright (c) 2015 by Contributors
* \file type_traits.h
* \brief type traits information header
*/
#ifndef DMLC_TYPE_TRAITS_H_
#define DMLC_TYPE_TRAITS_H_
#include "./base.h"
#if DMLC_USE_CXX11
#include <type_traits>
#endif
#include <string>
namespace dmlc {
/*!
 * \brief whether a type is pod type
 * \tparam T the type to query
 */
template<typename T>
struct is_pod {
#if DMLC_USE_CXX11
  /*!
   * \brief the value of the traits
   *
   * Spelled as "trivial and standard-layout", which is the definition of a
   * POD type, because std::is_pod itself is deprecated since C++20.
   */
  static const bool value =
      std::is_trivial<T>::value && std::is_standard_layout<T>::value;
#else
  /*! \brief the value of the traits; false unless specialized for the type */
  static const bool value = false;
#endif
};
/*!
 * \brief whether a type is integer type
 * \tparam T the type to query
 */
template<typename T>
struct is_integral {
  /*!
   * \brief the value of the traits: forwarded from the standard library
   *  when C++11 is available, otherwise false unless specialized
   */
#if DMLC_USE_CXX11 == 0
  static const bool value = false;
#else
  static const bool value = std::is_integral<T>::value;
#endif
};
/*!
 * \brief whether a type is floating point type
 * \tparam T the type to query
 */
template<typename T>
struct is_floating_point {
  /*!
   * \brief the value of the traits: forwarded from the standard library
   *  when C++11 is available, otherwise false unless specialized
   */
#if DMLC_USE_CXX11 == 0
  static const bool value = false;
#else
  static const bool value = std::is_floating_point<T>::value;
#endif
};
/*!
 * \brief whether a type is arithmetic type
 * \tparam T the type to query
 */
template<typename T>
struct is_arithmetic {
  /*!
   * \brief the value of the traits: forwarded from the standard library
   *  when C++11 is available, otherwise true when either
   *  dmlc::is_integral or dmlc::is_floating_point holds
   */
#if DMLC_USE_CXX11 == 0
  static const bool value = (dmlc::is_integral<T>::value ||
                             dmlc::is_floating_point<T>::value);
#else
  static const bool value = std::is_arithmetic<T>::value;
#endif
};
/*!
 * \brief helper class that produces the string naming a type
 *
 *  Specialize this class to define the type name of a custom type;
 *  the primary template yields an empty name.
 *
 * \tparam T the type to query
 */
template<typename T>
struct type_name_helper {
  /*!
   * \return a string of typename (empty for the primary template).
   */
  static inline std::string value() {
    return std::string();
  }
};
/*!
* \brief the string representation of type name
* \tparam T the type to query
* \return a const string of typename.
*/
template<typename T>
inline std::string type_name() {
return type_name_helper<T>::value();
}
/*!
* \brief whether a type has save/load functions
*
* The primary template reports false for every type; a type is marked as
* having them by specializing this trait (for example with
* DMLC_DECLARE_TRAITS(has_saveload, Type, true)).
* \tparam T the type to query
*/
template<typename T>
struct has_saveload {
/*! \brief the value of the traits */
static const bool value = false;
};
/*!
* \brief template to select type based on condition
* For example, IfThenElseType<true, int, float>::Type will give int
*
* Only declared here; the partial specializations for cond == true and
* cond == false further down in this header provide the nested Type.
* \tparam cond the condition
* \tparam Then the typename to be returned if cond is true
* \tparam Else typename to be returned if cond is false
*/
template<bool cond, typename Then, typename Else>
struct IfThenElseType;
/*!
* \brief macro to quickly declare traits information
*
* Expands to a full specialization Trait<Type> whose static member
* `value` equals Value. The expansion has no trailing semicolon, so the
* invocation must be followed by one.
*/
#define DMLC_DECLARE_TRAITS(Trait, Type, Value) \
template<> \
struct Trait<Type> { \
static const bool value = Value; \
}
/*!
* \brief macro to quickly declare the type name string of a type
*
* Expands to a full specialization type_name_helper<Type> whose value()
* returns Name. The expansion has no trailing semicolon, so the
* invocation must be followed by one.
*/
#define DMLC_DECLARE_TYPE_NAME(Type, Name) \
template<> \
struct type_name_helper<Type> { \
static inline std::string value() { \
return Name; \
} \
}
//! \cond Doxygen_Suppress
// declare special traits when C++11 is not available
// (without <type_traits> the primary templates default to false, so the
// builtin types are listed by hand)
#if DMLC_USE_CXX11 == 0
// builtin types treated as POD
DMLC_DECLARE_TRAITS(is_pod, char, true);
DMLC_DECLARE_TRAITS(is_pod, int8_t, true);
DMLC_DECLARE_TRAITS(is_pod, int16_t, true);
DMLC_DECLARE_TRAITS(is_pod, int32_t, true);
DMLC_DECLARE_TRAITS(is_pod, int64_t, true);
DMLC_DECLARE_TRAITS(is_pod, uint8_t, true);
DMLC_DECLARE_TRAITS(is_pod, uint16_t, true);
DMLC_DECLARE_TRAITS(is_pod, uint32_t, true);
DMLC_DECLARE_TRAITS(is_pod, uint64_t, true);
DMLC_DECLARE_TRAITS(is_pod, float, true);
DMLC_DECLARE_TRAITS(is_pod, double, true);
// builtin integer types
DMLC_DECLARE_TRAITS(is_integral, char, true);
DMLC_DECLARE_TRAITS(is_integral, int8_t, true);
DMLC_DECLARE_TRAITS(is_integral, int16_t, true);
DMLC_DECLARE_TRAITS(is_integral, int32_t, true);
DMLC_DECLARE_TRAITS(is_integral, int64_t, true);
DMLC_DECLARE_TRAITS(is_integral, uint8_t, true);
DMLC_DECLARE_TRAITS(is_integral, uint16_t, true);
DMLC_DECLARE_TRAITS(is_integral, uint32_t, true);
DMLC_DECLARE_TRAITS(is_integral, uint64_t, true);
// builtin floating point types
DMLC_DECLARE_TRAITS(is_floating_point, float, true);
DMLC_DECLARE_TRAITS(is_floating_point, double, true);
#endif
// type name strings returned by type_name<T>() for common types
DMLC_DECLARE_TYPE_NAME(float, "float");
DMLC_DECLARE_TYPE_NAME(double, "double");
DMLC_DECLARE_TYPE_NAME(int, "int");
DMLC_DECLARE_TYPE_NAME(int64_t, "long");
DMLC_DECLARE_TYPE_NAME(uint32_t, "int (non-negative)");
DMLC_DECLARE_TYPE_NAME(uint64_t, "long (non-negative)");
DMLC_DECLARE_TYPE_NAME(std::string, "string");
DMLC_DECLARE_TYPE_NAME(bool, "boolean");
DMLC_DECLARE_TYPE_NAME(void*, "ptr");
// IfThenElseType with a true condition: Type is the Then argument.
template<typename Then, typename Else>
struct IfThenElseType<true, Then, Else> {
typedef Then Type;
};
// IfThenElseType with a false condition: Type is the Else argument.
template<typename Then, typename Else>
struct IfThenElseType<false, Then, Else> {
typedef Else Type;
};
//! \endcond
} // namespace dmlc
#endif // DMLC_TYPE_TRAITS_H_
......@@ -41,20 +41,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
.def_readonly("tot_time", &TemporalGraphBlock::tot_time, py::return_value_policy::reference)
.def_readonly("sample_edge_num", &TemporalGraphBlock::sample_edge_num, py::return_value_policy::reference);
py::class_<T_TemporalGraphBlock>(m, "T_TemporalGraphBlock")
.def(py::init<th::Tensor &, th::Tensor &,
th::Tensor &>())
.def_readonly("row", &T_TemporalGraphBlock::row, py::return_value_policy::reference)
.def_readonly("col", &T_TemporalGraphBlock::col, py::return_value_policy::reference)
.def_readonly("eid", &T_TemporalGraphBlock::eid, py::return_value_policy::reference)
.def_readonly("delta_ts", &T_TemporalGraphBlock::delta_ts, py::return_value_policy::reference)
.def_readonly("src_index", &T_TemporalGraphBlock::src_index, py::return_value_policy::reference)
.def_readonly("sample_nodes", &T_TemporalGraphBlock::sample_nodes, py::return_value_policy::reference)
.def_readonly("sample_nodes_ts", &T_TemporalGraphBlock::sample_nodes_ts, py::return_value_policy::reference)
.def_readonly("sample_time", &T_TemporalGraphBlock::sample_time, py::return_value_policy::reference)
.def_readonly("tot_time", &T_TemporalGraphBlock::tot_time, py::return_value_policy::reference)
.def_readonly("sample_edge_num", &T_TemporalGraphBlock::sample_edge_num, py::return_value_policy::reference);
py::class_<TemporalNeighborBlock>(m, "TemporalNeighborBlock")
.def(py::init<vector<vector<NodeIDType>>&,
vector<int64_t> &>())
......@@ -83,11 +69,18 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
py::class_<ParallelSampler>(m, "ParallelSampler")
.def(py::init<TemporalNeighborBlock &, NodeIDType, EdgeIDType, int,
vector<int>&, int, string>())
vector<int>&, int, string, int, th::Tensor &,th::Tensor &,double>())
.def_readonly("ret", &ParallelSampler::ret, py::return_value_policy::reference)
.def("neighbor_sample_from_nodes", &ParallelSampler::neighbor_sample_from_nodes)
.def("reset", &ParallelSampler::reset)
.def("get_ret", [](const ParallelSampler &ps) { return ps.ret; });
.def("get_ret", [](const ParallelSampler &ps) { return ps.ret; })
.def("sample_unique", &ParallelSampler::sample_unique)
.def_readonly("dist_nid",&ParallelSampler::dist_nid,py::return_value_policy::reference)
.def_readonly("dist_eid",&ParallelSampler::dist_eid,py::return_value_policy::reference)
.def_readonly("block_node_list",&ParallelSampler::block_node_list,py::return_value_policy::reference)
.def_readonly("eid_inv",&ParallelSampler::eid_inv,py::return_value_policy::reference)
.def_readonly("unq_id",&ParallelSampler::unq_id,py::return_value_policy::reference)
.def_readonly("first_block_id",&ParallelSampler::first_block_id,py::return_value_policy::reference);
py::class_<ParallelTppRComputer>(m, "ParallelTppRComputer")
.def(py::init<TemporalNeighborBlock &, NodeIDType, EdgeIDType, int,
......
/**
* Copyright (c) 2020 by Contributors
* @file dgl/array.h
* @brief Common array operations required by DGL.
*
* Note that this is not meant for a full support of array library such as ATen.
* Only a limited set of operators required by DGL are implemented.
*/
#ifndef DGL_ARRAY_H_
#define DGL_ARRAY_H_
#include "./aten/array_ops.h"
#include "./aten/coo.h"
#include "./aten/csr.h"
#include "./aten/macro.h"
#include "./aten/spmat.h"
#include "./aten/types.h"
#endif // DGL_ARRAY_H_
/**
* Copyright (c) 2020 by Contributors
* @file dgl/array_iterator.h
* @brief Various iterators.
*/
#ifndef DGL_ARRAY_ITERATOR_H_
#define DGL_ARRAY_ITERATOR_H_
#ifdef __CUDA_ARCH__
#define CUB_INLINE __host__ __device__ __forceinline__
#else
#define CUB_INLINE inline
#endif // __CUDA_ARCH__
#include <algorithm>
#include <iterator>
#include <utility>
namespace dgl {
namespace aten {
using std::swap;
// Make std::pair work on both host and device: a plain value type holding
// two members of the same type.
template <typename DType>
struct Pair {
  Pair() = default;
  Pair(const Pair& other) = default;
  Pair(Pair&& other) = default;
  // Construct from the two member values.
  CUB_INLINE Pair(DType a, DType b) : first(a), second(b) {}
  // Member-wise copy assignment.
  CUB_INLINE Pair& operator=(const Pair& other) {
    first = other.first;
    second = other.second;
    return *this;
  }
  // Conversion to the equivalent std::pair.
  CUB_INLINE operator std::pair<DType, DType>() const {
    return std::make_pair(first, second);
  }
  // Two pairs are equal when both members are equal.
  CUB_INLINE bool operator==(const Pair& other) const {
    return (first == other.first) && (second == other.second);
  }
  // Exchange the members of *this and `other`.
  // Both objects are modified, so neither may be const: the previous
  // `swap(const Pair&) const` signature handed const members to std::swap
  // and could not compile once instantiated.
  CUB_INLINE void swap(Pair& other) {
    std::swap(first, other.first);
    std::swap(second, other.second);
  }
  DType first, second;
};
// Free swap for Pair, forwarding to the member swap.
template <typename DType>
CUB_INLINE void swap(Pair<DType>& r1, Pair<DType>& r2) {
  r1.swap(r2);
}
// PairRef and PairIterator that serves as an iterator over a pair of arrays in
// a zipped fashion like zip(a, b).
//
// PairRef is the proxy "reference" of that iteration: it stores one pointer
// into each array and reads/writes the pointed-to elements.
template <typename DType>
struct PairRef {
PairRef() = delete;
PairRef(const PairRef& other) = default;
PairRef(PairRef&& other) = default;
// Bind the proxy to element *r of the first array and *c of the second.
CUB_INLINE PairRef(DType* const r, DType* const c) : a(r), b(c) {}
// Assignment writes through: copies the referenced values of `other` into
// the referenced elements of *this (the pointers themselves are unchanged).
CUB_INLINE PairRef& operator=(const PairRef& other) {
*a = *other.a;
*b = *other.b;
return *this;
}
// Store the members of a Pair value into the referenced elements.
CUB_INLINE PairRef& operator=(const Pair<DType>& val) {
*a = val.first;
*b = val.second;
return *this;
}
// Load the referenced elements into a Pair value.
CUB_INLINE operator Pair<DType>() const { return Pair<DType>(*a, *b); }
// Load the referenced elements into a std::pair value.
CUB_INLINE operator std::pair<DType, DType>() const {
return std::make_pair(*a, *b);
}
// Compares the referenced values, not the pointers.
CUB_INLINE bool operator==(const PairRef& other) const {
return (*a == *(other.a)) && (*b == *(other.b));
}
// Swap the referenced elements. This is const and takes a const reference
// because only the pointees are modified, never the proxy's own pointers.
CUB_INLINE void swap(const PairRef& other) const {
std::swap(*a, *other.a);
std::swap(*b, *other.b);
}
DType *a, *b;
};
// Free swap for PairRef; the const reference parameters also accept the
// temporaries returned by PairIterator::operator*.
template <typename DType>
CUB_INLINE void swap(const PairRef<DType>& r1, const PairRef<DType>& r2) {
r1.swap(r2);
}
template <typename DType>
struct PairIterator : public std::iterator<
std::random_access_iterator_tag, Pair<DType>,
std::ptrdiff_t, Pair<DType*>, PairRef<DType>> {
PairIterator() = default;
PairIterator(const PairIterator& other) = default;
PairIterator(PairIterator&& other) = default;
CUB_INLINE PairIterator(DType* x, DType* y) : a(x), b(y) {}
PairIterator& operator=(const PairIterator& other) = default;
PairIterator& operator=(PairIterator&& other) = default;
~PairIterator() = default;
CUB_INLINE bool operator==(const PairIterator& other) const {
return a == other.a;
}
CUB_INLINE bool operator!=(const PairIterator& other) const {
return a != other.a;
}
CUB_INLINE bool operator<(const PairIterator& other) const {
return a < other.a;
}
CUB_INLINE bool operator>(const PairIterator& other) const {
return a > other.a;
}
CUB_INLINE bool operator<=(const PairIterator& other) const {
return a <= other.a;
}
CUB_INLINE bool operator>=(const PairIterator& other) const {
return a >= other.a;
}
CUB_INLINE PairIterator& operator+=(const std::ptrdiff_t& movement) {
a += movement;
b += movement;
return *this;
}
CUB_INLINE PairIterator& operator-=(const std::ptrdiff_t& movement) {
a -= movement;
b -= movement;
return *this;
}
CUB_INLINE PairIterator& operator++() {
++a;
++b;
return *this;
}
CUB_INLINE PairIterator& operator--() {
--a;
--b;
return *this;
}
CUB_INLINE PairIterator operator++(int) {
PairIterator ret(*this);
operator++();
return ret;
}
CUB_INLINE PairIterator operator--(int) {
PairIterator ret(*this);
operator--();
return ret;
}
CUB_INLINE PairIterator operator+(const std::ptrdiff_t& movement) const {
return PairIterator(a + movement, b + movement);
}
CUB_INLINE PairIterator operator-(const std::ptrdiff_t& movement) const {
return PairIterator(a - movement, b - movement);
}
CUB_INLINE std::ptrdiff_t operator-(const PairIterator& other) const {
return a - other.a;
}
CUB_INLINE PairRef<DType> operator*() const { return PairRef<DType>(a, b); }
CUB_INLINE PairRef<DType> operator*() { return PairRef<DType>(a, b); }
CUB_INLINE PairRef<DType> operator[](size_t offset) const {
return PairRef<DType>(a + offset, b + offset);
}
CUB_INLINE PairRef<DType> operator[](size_t offset) {
return PairRef<DType>(a + offset, b + offset);
}
DType *a, *b;
};
}; // namespace aten
}; // namespace dgl
#endif // DGL_ARRAY_ITERATOR_H_
/**
* Copyright (c) 2020 by Contributors
* @file dgl/aten/spmat.h
* @brief Sparse matrix definitions
*/
#ifndef DGL_ATEN_SPMAT_H_
#define DGL_ATEN_SPMAT_H_
#include <string>
#include <vector>
#include "../runtime/object.h"
#include "./types.h"
namespace dgl {
/**
* @brief Sparse format.
*
* @note The enumerator values (1, 2, 3) are not the same as the bit codes
* COO_CODE / CSR_CODE / CSC_CODE (0x1, 0x2, 0x4) declared in this header.
*/
enum class SparseFormat {
kCOO = 1,
kCSR = 2,
kCSC = 3,
};
/**
* @brief Sparse format codes
*
* Each format owns one bit of a dgl_format_code_t, so a code can describe
* a set of formats.
*/
// All three format bits set (COO_CODE | CSR_CODE | CSC_CODE).
const dgl_format_code_t ALL_CODE = 0x7;
// No format bit set.
const dgl_format_code_t ANY_CODE = 0x0;
// Bit for the COO format.
const dgl_format_code_t COO_CODE = 0x1;
// Bit for the CSR format.
const dgl_format_code_t CSR_CODE = 0x2;
// Bit for the CSC format.
const dgl_format_code_t CSC_CODE = 0x4;
// Parse a sparse format from its lower-case name ("coo", "csr" or "csc").
// Any other name is reported through LOG(FATAL); the trailing return only
// exists so that every path returns a value.
inline SparseFormat ParseSparseFormat(const std::string& name) {
  if (name == "coo") {
    return SparseFormat::kCOO;
  }
  if (name == "csr") {
    return SparseFormat::kCSR;
  }
  if (name == "csc") {
    return SparseFormat::kCSC;
  }
  LOG(FATAL) << "Sparse format not recognized";
  return SparseFormat::kCOO;
}
// Create the lower-case name of a sparse format.
// Everything that is neither kCOO nor kCSR is rendered as "csc".
inline std::string ToStringSparseFormat(SparseFormat sparse_format) {
  switch (sparse_format) {
    case SparseFormat::kCOO:
      return "coo";
    case SparseFormat::kCSR:
      return "csr";
    default:
      return "csc";
  }
}
// Expand a format code into the list of formats whose bits are set.
// The result is always ordered COO, CSR, CSC.
inline std::vector<SparseFormat> CodeToSparseFormats(dgl_format_code_t code) {
  std::vector<SparseFormat> formats;
  if (code & COO_CODE) {
    formats.push_back(SparseFormat::kCOO);
  }
  if (code & CSR_CODE) {
    formats.push_back(SparseFormat::kCSR);
  }
  if (code & CSC_CODE) {
    formats.push_back(SparseFormat::kCSC);
  }
  return formats;
}
// Fold a list of formats into a format code by OR-ing the bit of every
// listed format. Values other than kCOO/kCSR/kCSC are reported through
// LOG(FATAL).
inline dgl_format_code_t SparseFormatsToCode(
    const std::vector<SparseFormat>& formats) {
  dgl_format_code_t code = 0;
  for (auto fmt : formats) {
    if (fmt == SparseFormat::kCOO) {
      code |= COO_CODE;
    } else if (fmt == SparseFormat::kCSR) {
      code |= CSR_CODE;
    } else if (fmt == SparseFormat::kCSC) {
      code |= CSC_CODE;
    } else {
      LOG(FATAL) << "Only support COO/CSR/CSC formats.";
    }
  }
  return code;
}
// Render a format code as text: the name of every format whose bit is set,
// each followed by a single space, in the order coo, csr, csc.
inline std::string CodeToStr(dgl_format_code_t code) {
  std::string text;
  if (code & COO_CODE) {
    text.append("coo ");
  }
  if (code & CSR_CODE) {
    text.append("csr ");
  }
  if (code & CSC_CODE) {
    text.append("csc ");
  }
  return text;
}
// Pick a single format out of a format code.
// COO wins when its bit is set, then CSC; CSR is the result otherwise
// (including when no bit is set at all).
inline SparseFormat DecodeFormat(dgl_format_code_t code) {
  const bool has_coo = (code & COO_CODE) != 0;
  const bool has_csc = (code & CSC_CODE) != 0;
  return has_coo ? SparseFormat::kCOO
                 : (has_csc ? SparseFormat::kCSC : SparseFormat::kCSR);
}
// Sparse matrix object that is exposed to python API.
// A plain container of format tag, shape, index arrays and flags.
struct SparseMatrix : public runtime::Object {
// Sparse format.
// NOTE(review): presumably holds a SparseFormat value stored as int32_t -
// confirm against the code that fills it in.
int32_t format = 0;
// Shape of this matrix.
int64_t num_rows = 0, num_cols = 0;
// Index arrays. For CSR, it is {indptr, indices, data}. For COO, it is {row,
// col, data}.
std::vector<IdArray> indices;
// Boolean flags.
// TODO(minjie): We might revisit this later to provide a more general
// solution. Currently, we only consider aten::COOMatrix and aten::CSRMatrix.
std::vector<bool> flags;
// Default construction leaves format and shape at 0 and the vectors empty.
SparseMatrix() {}
// Construct from all fields; idx and flg are copied into the object.
SparseMatrix(
int32_t fmt, int64_t nrows, int64_t ncols,
const std::vector<IdArray>& idx, const std::vector<bool>& flg)
: format(fmt),
num_rows(nrows),
num_cols(ncols),
indices(idx),
flags(flg) {}
// Type key passed to the object system via the macro below.
static constexpr const char* _type_key = "aten.SparseMatrix";
DGL_DECLARE_OBJECT_TYPE_INFO(SparseMatrix, runtime::Object);
};
// Define SparseMatrixRef
DGL_DEFINE_OBJECT_REF(SparseMatrixRef, SparseMatrix);
} // namespace dgl
#endif // DGL_ATEN_SPMAT_H_
/**
* Copyright (c) 2020 by Contributors
* @file dgl/aten/types.h
* @brief Array and ID types
*/
#ifndef DGL_ATEN_TYPES_H_
#define DGL_ATEN_TYPES_H_
#include <cstdint>
#include "../runtime/ndarray.h"
namespace dgl {
// Integer type of ids.
typedef uint64_t dgl_id_t;
// Integer type of type ids.
typedef uint64_t dgl_type_t;
/** @brief Type for dgl format code, whose binary representation indicates
* which sparse format is in use and which is not.
*
* Suppose the binary representation is xyz, then
* - x indicates whether csc is in use (1 for true and 0 for false).
* - y indicates whether csr is in use.
* - z indicates whether coo is in use.
*/
typedef uint8_t dgl_format_code_t;
using dgl::runtime::NDArray;
// The array aliases below are all the same NDArray type; the different names
// only document what an array is meant to hold.
typedef NDArray IdArray;
typedef NDArray DegreeArray;
typedef NDArray BoolArray;
typedef NDArray IntArray;
typedef NDArray FloatArray;
typedef NDArray TypeArray;
namespace aten {
// Context constant built from kDGLCPU and 0.
static const DGLContext CPU{kDGLCPU, 0};
} // namespace aten
} // namespace dgl
#endif // DGL_ATEN_TYPES_H_
/**
* Copyright (c) 2020 by Contributors
* @file dgl/aten/bcast.h
* @brief Broadcast related function C++ header.
*/
#ifndef DGL_BCAST_H_
#define DGL_BCAST_H_
#include <string>
#include <vector>
#include "./runtime/ndarray.h"
using namespace dgl::runtime;
namespace dgl {
/**
* @brief Broadcast offsets and auxiliary information.
*/
struct BcastOff {
/**
* @brief offset vector of lhs operand and rhs operand.
* @note lhs_offset[i] indicates the start position of the scalar
* in lhs operand that is required to compute the i-th element
* in the output, likewise for rhs_offset.
*
* For example, when lhs array has shape (1, 3) and rhs array
* has shape (5, 1), the resulting array would have shape (5, 3),
* then both lhs_offset and rhs_offset would contain 15 elements.
*
* lhs_offset: 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2
* rhs_offset: 0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4
*
* in order to compute the 7-th (row 2, column 0) element in the output,
* we need the 0-th element in the lhs array and the 2nd element in the
* rhs array (i.e. lhs_offset[6] == 0 and rhs_offset[6] == 2).
*/
std::vector<int64_t> lhs_offset, rhs_offset;
/** @brief Whether broadcast is required or not. */
bool use_bcast;
/**
* @brief Auxiliary information for kernel computation
* @note lhs_len refers to the left hand side operand length.
* e.g. 15 for shape (1, 3, 5)
* rhs_len refers to the right hand side operand length.
* e.g. 15 for shape (3, 1, 5)
* out_len refers to the output length.
* e.g. 45 for shape (3, 3, 5)
* reduce_size refers to the reduction size (for op like dot).
* e.g. 1 for add, 5 for dot and lhs_shape,rhs_shape=(3,5)
*/
int64_t lhs_len, rhs_len, out_len, reduce_size;
};
/**
* @brief Compute broadcast and auxiliary information given operator
* and operands for kernel computation.
* @param op a string that indicates the operator, could be `add`, `sub`,
* `mul`, `div`, `dot`, `copy_u`, `copy_e`.
* @param lhs The left hand side operand of NDArray class.
* @param rhs The right hand side operand of NDArray class.
* @return the broadcast information of BcastOff class.
*/
BcastOff CalcBcastOff(const std::string& op, NDArray lhs, NDArray rhs);
} // namespace dgl
#endif // DGL_BCAST_H_
/**
* Copyright (c) 2023 by Contributors
* @file dgl/env_variable.h
* @brief Class about environment variables.
*/
#ifndef DGL_ENV_VARIABLE_H_
#define DGL_ENV_VARIABLE_H_
#include <cstdlib>
namespace dgl {
// Value of the DGL_PARALLEL_FOR_GRAIN_SIZE environment variable, read once
// during static initialization; a null pointer when the variable is not set.
// Because it is `static` in a header, every translation unit that includes
// this file gets its own copy.
static const char* kDGLParallelForGrainSize =
std::getenv("DGL_PARALLEL_FOR_GRAIN_SIZE");
} // namespace dgl
#endif // DGL_ENV_VARIABLE_H_
/**
* Copyright (c) 2018 by Contributors
* @file dgl/graph_op.h
* @brief Operations on graph index.
*/
#ifndef DGL_GRAPH_OP_H_
#define DGL_GRAPH_OP_H_
#include <vector>
#include "graph.h"
#include "immutable_graph.h"
namespace dgl {
class GraphOp {
public:
/**
* @brief Return a new graph with all the edges reversed.
*
* The returned graph preserves the vertex and edge index in the original
* graph.
*
* @return the reversed graph
*/
static GraphPtr Reverse(GraphPtr graph);
/**
* @brief Return the line graph.
*
* If i~j and j~i are two edges in original graph G, then
* (i,j)~(j,i) and (j,i)~(i,j) are the "backtracking" edges on
* the line graph.
*
* @param graph The input graph.
* @param backtracking Whether the backtracking edges are included or not
* @return the line graph
*/
static GraphPtr LineGraph(GraphPtr graph, bool backtracking);
/**
* @brief Return a disjoint union of the input graphs.
*
* The new graph will include all the nodes/edges in the given graphs.
* Nodes/Edges will be relabled by adding the cumsum of the previous graph
* sizes in the given sequence order. For example, giving input [g1, g2, g3],
* where they have 5, 6, 7 nodes respectively. Then node#2 of g2 will become
* node#7 in the result graph. Edge ids are re-assigned similarly.
*
* The input list must be either ALL mutable graphs or ALL immutable graphs.
* The returned graph type is also determined by the input graph type.
*
* @param graphs A list of input graphs to be unioned.
* @return the disjoint union of the graphs
*/
static GraphPtr DisjointUnion(std::vector<GraphPtr> graphs);
/**
* @brief Partition the graph into several subgraphs.
*
* This is a reverse operation of DisjointUnion. The graph will be partitioned
* into num graphs. This requires the given number of partitions to evenly
* divides the number of nodes in the graph.
*
* If the input graph is mutable, the result graphs are mutable.
* If the input graph is immutable, the result graphs are immutable.
*
* @param graph The graph to be partitioned.
* @param num The number of partitions.
* @return a list of partitioned graphs
*/
static std::vector<GraphPtr> DisjointPartitionByNum(
GraphPtr graph, int64_t num);
/**
* @brief Partition the graph into several subgraphs.
*
* This is a reverse operation of DisjointUnion. The graph will be partitioned
* based on the given sizes. This requires the sum of the given sizes is equal
* to the number of nodes in the graph.
*
* If the input graph is mutable, the result graphs are mutable.
* If the input graph is immutable, the result graphs are immutable.
*
* @param graph The graph to be partitioned.
* @param sizes The number of partitions.
* @return a list of partitioned graphs
*/
static std::vector<GraphPtr> DisjointPartitionBySizes(
GraphPtr graph, IdArray sizes);
/**
* @brief Map vids in the parent graph to the vids in the subgraph.
*
* If the Id doesn't exist in the subgraph, -1 will be used.
*
* @param parent_vid_map An array that maps the vids in the parent graph to
* the subgraph. The elements store the vertex Ids in the parent graph, and
* the indices indicate the vertex Ids in the subgraph.
* @param query The vertex Ids in the parent graph.
* @return an Id array that contains the subgraph node Ids.
*/
static IdArray MapParentIdToSubgraphId(IdArray parent_vid_map, IdArray query);
/**
* @brief Expand an Id array based on the offset array.
*
* For example,
* ids: [0, 1, 2, 3, 4],
* offset: [0, 2, 2, 5, 6, 7],
* result: [0, 0, 2, 2, 2, 3, 4].
* The offset array has one more element than the ids array.
* (offset[i], offset[i+1]) shows the location of ids[i] in the result array.
*
* @param ids An array that contains the node or edge Ids.
* @param offset An array that contains the offset after expansion.
* @return a expanded Id array.
*/
static IdArray ExpandIds(IdArray ids, IdArray offset);
/**
* @brief Convert the graph to a simple graph.
* @param graph The input graph.
* @return a new immutable simple graph with no multi-edge.
*/
static GraphPtr ToSimpleGraph(GraphPtr graph);
/**
* @brief Convert the graph to a mutable bidirected graph.
*
* If the original graph has m edges for i -> j and n edges for
* j -> i, the new graph will have max(m, n) edges for both
* i -> j and j -> i.
*
* @param graph The input graph.
* @return a new mutable bidirected graph.
*/
static GraphPtr ToBidirectedMutableGraph(GraphPtr graph);
/**
* @brief Same as BidirectedMutableGraph except that the returned graph is
* immutable.
* @param graph The input graph.
* @return a new immutable bidirected
* graph.
*/
static GraphPtr ToBidirectedImmutableGraph(GraphPtr graph);
/**
* @brief Same as BidirectedMutableGraph except that the returned graph is
* immutable and call gk_csr_MakeSymmetric in GKlib. This is more efficient
* than ToBidirectedImmutableGraph. It return a null pointer if the conversion
* fails.
*
* @param graph The input graph.
* @return a new immutable bidirected graph.
*/
static GraphPtr ToBidirectedSimpleImmutableGraph(ImmutableGraphPtr ig);
/**
* @brief Get an induced subgraph with HALO nodes.
* The HALO nodes are the ones that can be reached from `nodes` within
* `num_hops`.
* @param graph The input graph.
* @param nodes The input nodes that form the core of the induced subgraph.
* @param num_hops The number of hops to reach.
* @return the induced subgraph with HALO nodes.
*/
static HaloSubgraph GetSubgraphWithHalo(
GraphPtr graph, IdArray nodes, int num_hops);
/**
* @brief Reorder the nodes in the immutable graph.
* @param ig The input immutable graph.
* @param new_order The node Ids in the new graph. The index into `new_order`
* is the old node Id, i.e. new_order[old_id] is that node's Id in the new
* graph.
* @return the graph with reordered node Ids.
*/
static GraphPtr ReorderImmutableGraph(
ImmutableGraphPtr ig, IdArray new_order);
};
} // namespace dgl
#endif // DGL_GRAPH_OP_H_
/**
* Copyright (c) 2018 by Contributors
* @file dgl/graph_serializer.h
* @brief DGL serializer APIs
*/
#ifndef DGL_GRAPH_SERIALIZER_H_
#define DGL_GRAPH_SERIALIZER_H_
#include <memory>
namespace dgl {
// Helper used by the serialization code to default-construct objects.
// It invokes the empty constructor of T, whether that constructor is public
// or private (a private one presumably requires T to grant Serializer
// access -- verify against the serialized classes).
class Serializer {
 public:
  /**
   * @brief Heap-allocate a value-initialized T.
   * @return a raw pointer to the new object; the caller owns it.
   */
  template <typename T>
  static T* new_object() {
    T* instance = new T();
    return instance;
  }

  /**
   * @brief Heap-allocate a value-initialized T managed by a shared pointer.
   * @return a std::shared_ptr that owns the new object.
   */
  template <typename T>
  static std::shared_ptr<T> make_shared() {
    // The `new T()` expression is spelled out here, not delegated to
    // std::make_shared<T>(), so the constructor call is made from inside
    // Serializer itself.
    std::shared_ptr<T> owner(new T());
    return owner;
  }
};
} // namespace dgl
#endif // DGL_GRAPH_SERIALIZER_H_
/**
* Copyright (c) 2020 by Contributors
* @file dgl/graph_traversal.h
* @brief common graph traversal operations
*/
#ifndef DGL_GRAPH_TRAVERSAL_H_
#define DGL_GRAPH_TRAVERSAL_H_
#include "array.h"
#include "base_heterograph.h"
namespace dgl {
///////////////////////// Graph Traverse routines //////////////////////////
/**
* @brief Class for representing frontiers.
*
* Each frontier is a list of nodes/edges (specified by their ids).
* An optional tag can be specified on each node/edge (represented by an int
* value).
*
* All frontiers are stored together: `ids` holds the ids of every frontier,
* and `sections` indicates how `ids` is divided into individual frontiers.
*/
struct Frontiers {
/** @brief The ids of the nodes/edges in all the frontiers. */
IdArray ids;
/**
* @brief The node/edge tags. Dtype is int64.
* Empty if no tags are requested.
*/
IdArray tags;
/**
* @brief A section vector that indicates each frontier. Dtype is int64.
* NOTE(review): presumably each entry is the size of one frontier within
* `ids` -- confirm against the implementation.
*/
IdArray sections;
};
namespace aten {
/**
* @brief Traverse the graph in a breadth-first-search (BFS) order.
*
* @param csr The input csr matrix.
* @param source Source nodes.
* @return A Frontiers object containing the search result.
*/
Frontiers BFSNodesFrontiers(const CSRMatrix& csr, IdArray source);
/**
* @brief Traverse the graph in a breadth-first-search (BFS) order, returning
* the edges of the BFS tree.
*
* @param csr The input csr matrix.
* @param source Source nodes.
* @return A Frontiers object containing the search result.
*/
Frontiers BFSEdgesFrontiers(const CSRMatrix& csr, IdArray source);
/**
* @brief Traverse the graph in topological order.
*
* Unlike the BFS/DFS routines, no source nodes are passed in.
*
* @param csr The input csr matrix.
* @return A Frontiers object containing the search result.
*/
Frontiers TopologicalNodesFrontiers(const CSRMatrix& csr);
/**
* @brief Traverse the graph in a depth-first-search (DFS) order.
*
* @param csr The input csr matrix.
* @param source Source nodes.
* @return A Frontiers object containing the search result.
*/
Frontiers DGLDFSEdges(const CSRMatrix& csr, IdArray source);
/**
* @brief Traverse the graph in a depth-first-search (DFS) order and return the
* recorded edge tags if return_labels is specified.
*
* The traversal visits edges in DFS order. Edges have three tags:
* FORWARD(0), REVERSE(1), NONTREE(2)
*
* A FORWARD edge is one in which `u` has been visited but `v` has not.
* A REVERSE edge is one in which both `u` and `v` have been visited and the
* edge is in the DFS tree.
* A NONTREE edge is one in which both `u` and `v` have been visited but the
* edge is NOT in the DFS tree.
*
* @param csr The input csr matrix.
* @param source Source nodes.
* @param has_reverse_edge If true, REVERSE edges are included.
* @param has_nontree_edge If true, NONTREE edges are included.
* @param return_labels If true, return the recorded edge tags.
* @return A Frontiers object containing the search result.
*/
Frontiers DGLDFSLabeledEdges(
const CSRMatrix& csr, IdArray source, const bool has_reverse_edge,
const bool has_nontree_edge, const bool return_labels);
} // namespace aten
} // namespace dgl
#endif // DGL_GRAPH_TRAVERSAL_H_
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment