Commit d5b33231 by zljJoan

Merge branch 'main' of github.com:zhljJoan/startGNN_sample into main

parents e3346a38 b86ba1e3
...@@ -7,15 +7,18 @@ edge_index = torch.tensor([[0, 1, 1, 1, 2, 2, 2, 3, 3, 4, 4, 4, 5], [1, 0, 2, 4, ...@@ -7,15 +7,18 @@ edge_index = torch.tensor([[0, 1, 1, 1, 2, 2, 2, 3, 3, 4, 4, 4, 5], [1, 0, 2, 4,
num_nodes = 6 num_nodes = 6
num_neighbors = 2 num_neighbors = 2
# Run the neighbor sampling # Run the neighbor sampling
sampler=NeighborSampler() sampler=NeighborSampler(edge_index=edge_index, num_nodes=num_nodes, num_layers=2, workers=2, fanout=[2, 1])
# neighbor_nodes, edge_index = sampler.sample_from_node(2, edge_index, num_nodes, num_neighbors)
# neighbor_nodes, edge_index = sampler.sample_from_nodes(torch.tensor([1,2]), edge_index, num_nodes, num_neighbors)
# neighbor_nodes, edge_index = sampler.sample_from_nodes_parallel(torch.tensor([1,2]), edge_index, num_nodes, workers=1, fanout=num_neighbors)
neighbor_nodes, edge_index = sampler.sample_from_nodes_parallel(torch.tensor([1,2,3,4,5]), edge_index, num_nodes, workers=4, fanout=num_neighbors)
# Print the result # neighbor_nodes, sampled_edge_index = sampler._sample_one_layer_from_node(node=2, fanout=2)
print('neighbor_nodes_id: \n',neighbor_nodes, '\nedge_index: \n',edge_index) # neighbor_nodes, sampled_edge_index = sampler._sample_one_layer_from_nodes(nodes=torch.tensor([1,3]), fanout=num_neighbors)
# sampler.workers=3
# neighbor_nodes, sampled_edge_index = sampler._sample_one_layer_from_nodes_parallel(nodes=torch.tensor([1,2,3]), fanout=num_neighbors)
# sampler.workers=4
# neighbor_nodes, sampled_edge_index = sampler._sample_one_layer_from_nodes_parallel(nodes=torch.tensor([1,2,3,4,5]), fanout=num_neighbors)
neighbor_nodes, sampled_edge_index = sampler.sample_from_nodes(torch.tensor([1,2,3]))
# Print the result
print('neighbor_nodes_id: \n',neighbor_nodes, '\nedge_index: \n',sampled_edge_index)
# import torch_scatter # import torch_scatter
# nodes=torch.Tensor([1,2]) # nodes=torch.Tensor([1,2])
......
#include <cstddef>
#include <cstdlib>
#include <ctime>
#include <iostream>
#include <random>
#include<set>
#include <stdexcept>
#include <unordered_set>
#include <vector>
#include<pybind11/pybind11.h>
#include<pybind11/numpy.h>
#include <pybind11/stl.h>
using namespace std;
namespace py = pybind11;
// Integer type used for node identifiers by every block and function in this file.
typedef int NodeIDType;
// Edge-id and timestamp aliases are currently disabled.
// typedef int EdgeIDType;
// typedef float TimeStampType;
/*
 * Convert a std::vector into a py::array.
 *
 * The input vector is copied once into a heap-allocated vector. That heap
 * copy is handed to a py::capsule whose destructor deletes it, and the
 * capsule is passed to the py::array constructor together with the copy's
 * size and data pointer, so the buffer's lifetime is tied to the capsule
 * rather than to the caller's vector.
 * Background: https://github.com/pybind/pybind11/issues/1042
 */
template<typename T>
inline py::array vec2npy(const std::vector<T> &vec)
{
    using Buffer = std::vector<T>;

    // Heap copy that is released by the capsule destructor below.
    Buffer *heap_copy = new Buffer(vec);

    // The capsule destructor is the only place the heap copy is deleted.
    py::capsule owner(heap_copy, [](void *raw) {
        delete reinterpret_cast<Buffer *>(raw);
    });

    return py::array(heap_copy->size(), heap_copy->data(), owner);
}
/*
* NeighborSampler Utils
*/
/*
 * Adjacency information for a graph: one neighbor list and one degree per node.
 *
 * `neighbors[i]` points to the neighbor-id list of node i and `deg[i]` is the
 * degree recorded for node i.
 *
 * NOTE(review): the neighbor lists are held through raw pointers and this
 * class never deletes them; copies of a block share the same pointed-to
 * lists. Ownership/cleanup should be revisited (changing the member type
 * would alter the public interface, so it is left as-is here).
 */
class TemporalNeighborBlock
{
    public:
        std::vector<std::vector<NodeIDType>*> neighbors;
        std::vector<int> deg;

        TemporalNeighborBlock(){}

        // Copies the pointer vector and the degree vector; the pointed-to
        // neighbor lists themselves are not duplicated.
        TemporalNeighborBlock(std::vector<std::vector<NodeIDType>*>& neighbors,
                              std::vector<int> &deg):
                              neighbors(neighbors), deg(deg){}

        // Returns the neighbor list of `node_id` as an array (via vec2npy).
        // Throws std::out_of_range when `node_id` is not a valid index;
        // a negative id converts to a huge unsigned index and is rejected too.
        py::array get_node_neighbor(int node_id) const {
            return vec2npy(*(neighbors.at(static_cast<std::size_t>(node_id))));
        }

        // Returns the stored degree of `node_id`.
        // Throws std::out_of_range when `node_id` is not a valid index.
        int get_node_deg(int node_id) const {
            return deg.at(static_cast<std::size_t>(node_id));
        }
};
/*
 * Result of a sampling step, stored as three parallel id vectors:
 *   row   - first endpoint of each sampled edge,
 *   col   - second endpoint of each sampled edge,
 *   nodes - node ids associated with the sample.
 */
class TemporalGraphBlock
{
    public:
        std::vector<NodeIDType> row;
        std::vector<NodeIDType> col;
        std::vector<NodeIDType> nodes;

        // Creates a block with all three vectors empty.
        TemporalGraphBlock() = default;

        // Creates a block holding copies of the three given vectors.
        TemporalGraphBlock(std::vector<NodeIDType> &_row,
                           std::vector<NodeIDType> &_col,
                           std::vector<NodeIDType> &_nodes)
            : row(_row),
              col(_col),
              nodes(_nodes)
        {
        }
};
/*
 * Build per-node neighbor lists and degrees from an edge list.
 *
 * For every edge i, col[i] is appended to the neighbor list of row[i] and the
 * degree of row[i] is incremented.
 *
 * @param row        source node id of each edge
 * @param col        destination node id of each edge (only the first
 *                   row.size() entries are read)
 * @param num_nodes  number of nodes; valid source ids are [0, num_nodes)
 * @return a TemporalNeighborBlock with num_nodes neighbor lists and degrees
 * @throws std::invalid_argument if num_nodes is negative or col is shorter
 *         than row
 * @throws std::out_of_range if any source id lies outside [0, num_nodes)
 */
TemporalNeighborBlock get_neighbors(
        std::vector<NodeIDType>& row, std::vector<NodeIDType>& col, int num_nodes){
    // Validate everything up front, before any neighbor list is allocated,
    // so that a rejected input cannot leave heap allocations behind.
    if(num_nodes < 0)
        throw std::invalid_argument("get_neighbors: num_nodes must be non-negative");
    if(col.size() < row.size())
        throw std::invalid_argument("get_neighbors: col must have at least as many entries as row");
    for(std::size_t i=0; i<row.size(); i++){
        if(row[i] < 0 || row[i] >= num_nodes)
            throw std::out_of_range("get_neighbors: source node id outside [0, num_nodes)");
    }

    TemporalNeighborBlock tnb = TemporalNeighborBlock();
    tnb.deg.resize(num_nodes, 0);
    tnb.neighbors.reserve(num_nodes);
    // The block stores raw pointers to its lists (see TemporalNeighborBlock).
    for(int i=0; i<num_nodes; i++)
        tnb.neighbors.push_back(new std::vector<NodeIDType>());

    for(std::size_t i=0; i<row.size(); i++){
        // Record the neighbor of the source node.
        tnb.neighbors[row[i]]->push_back(col[i]);
        // Update the degree of the source node.
        tnb.deg[row[i]]++;
    }
    return tnb;
}
/*
 * Sample up to `fanout` neighbors of a single node, uniformly at random and
 * without replacement.
 *
 * @param node       the node whose neighbors are sampled
 * @param neighbors  neighbor ids of `node`; not modified
 * @param deg        declared degree of `node`. Kept for interface
 *                   compatibility; the actual length of `neighbors` is used
 *                   so that an inconsistent value cannot cause out-of-range
 *                   access.
 * @param fanout     maximum number of neighbors to keep; values <= 0 keep none
 * @return a block where `col` holds the kept neighbors (in their original
 *         relative order), `row` holds `node` once per kept neighbor, and
 *         `nodes` holds the distinct ids among `col` plus `node` itself
 *         (in unspecified order).
 */
TemporalGraphBlock neighbor_sample_from_node(
        NodeIDType node, std::vector<NodeIDType>& neighbors,
        int deg, int fanout){
    // Seeded exactly once per thread. Re-seeding from the wall clock on every
    // call would make all calls within the same second produce the same draws.
    static thread_local std::mt19937 rng{std::random_device{}()};
    (void)deg;

    TemporalGraphBlock tgb = TemporalGraphBlock();

    const std::size_t available = neighbors.size();
    const std::size_t wanted = fanout > 0 ? static_cast<std::size_t>(fanout) : 0;
    std::size_t to_pick = wanted < available ? wanted : available;
    tgb.col.reserve(to_pick);

    // Selection sampling: walk the list once and keep element i with
    // probability to_pick / (elements left). When the degree does not exceed
    // the fanout this keeps every neighbor; otherwise it keeps exactly
    // `fanout` of them, in linear time and without reordering.
    for(std::size_t i=0; i<available && to_pick>0; i++){
        std::uniform_int_distribution<std::size_t> draw(0, available - i - 1);
        if(draw(rng) < to_pick){
            tgb.col.push_back(neighbors[i]);
            --to_pick;
        }
    }

    tgb.row.assign(tgb.col.size(), node);

    // De-duplicate the sampled node ids; the source node is always included.
    std::unordered_set<NodeIDType> unique_ids;
    for(NodeIDType id : tgb.col)
        unique_ids.insert(id);
    unique_ids.insert(node);
    tgb.nodes.assign(unique_ids.begin(), unique_ids.end());
    return tgb;
}
/*
 * Sample neighbors for a batch of nodes and merge the results into one block.
 *
 * Each node in `nodes` is sampled independently with
 * neighbor_sample_from_node, and the per-node edges are concatenated in the
 * order the nodes are given.
 *
 * @param nodes      ids of the nodes to sample from
 * @param neighbors  neighbors[v] is the neighbor list of node v
 * @param deg        deg[v] is the degree of node v
 * @param fanout     maximum number of neighbors kept per node
 * @return a block whose `row`/`col` are the concatenated sampled edges and
 *         whose `nodes` holds the distinct ids that appear in `col`
 *         (in unspecified order)
 * @throws std::out_of_range if a node id is not a valid index into
 *         `neighbors` or `deg`
 */
TemporalGraphBlock neighbor_sample_from_nodes(
        std::vector<NodeIDType>& nodes, std::vector<std::vector<NodeIDType>>& neighbors,
        std::vector<NodeIDType>& deg, int fanout){
    TemporalGraphBlock tgb = TemporalGraphBlock();
    for(const NodeIDType node : nodes){
        // at() rejects invalid ids instead of reading out of bounds; a
        // negative id converts to a huge unsigned index and is rejected too.
        const std::size_t idx = static_cast<std::size_t>(node);
        TemporalGraphBlock tgb_i = neighbor_sample_from_node(
            node, neighbors.at(idx), deg.at(idx), fanout);
        tgb.row.insert(tgb.row.end(), tgb_i.row.begin(), tgb_i.row.end());
        tgb.col.insert(tgb.col.end(), tgb_i.col.begin(), tgb_i.col.end());
        // tgb_i.nodes is intentionally not merged: the final node list is
        // rebuilt from the merged `col` below.
    }
    // De-duplicate the sampled neighbor ids.
    std::unordered_set<NodeIDType> unique_ids;
    for(NodeIDType id : tgb.col)
        unique_ids.insert(id);
    tgb.nodes.assign(unique_ids.begin(), unique_ids.end());
    return tgb;
}
// Python bindings for the `sample_cores` extension module.
//
// The classes are registered before the free functions that return them, so
// that the signatures pybind11 generates for those functions can refer to the
// bound Python classes rather than to raw C++ type names.
PYBIND11_MODULE(sample_cores, m)
{
    // Sampling result; row/col/nodes are exposed as methods that return a
    // fresh array built by vec2npy on every call.
    py::class_<TemporalGraphBlock>(m, "TemporalGraphBlock")
        .def(py::init<std::vector<NodeIDType> &, std::vector<NodeIDType> &,
                      std::vector<NodeIDType> &>())
        .def("row", [](const TemporalGraphBlock &tgb) { return vec2npy(tgb.row); })
        .def("col", [](const TemporalGraphBlock &tgb) { return vec2npy(tgb.col); })
        .def("nodes", [](const TemporalGraphBlock &tgb) { return vec2npy(tgb.nodes); });

    // Adjacency information; neighbors/deg are exposed as read-only attributes.
    py::class_<TemporalNeighborBlock>(m, "TemporalNeighborBlock")
        .def(py::init<std::vector<std::vector<NodeIDType>*>&,
                      std::vector<int> &>())
        .def_readonly("neighbors", &TemporalNeighborBlock::neighbors)
        .def_readonly("deg", &TemporalNeighborBlock::deg);

    // Free functions.
    m
    .def("neighbor_sample_from_nodes",
        &neighbor_sample_from_nodes)
    .def("get_neighbors",
        &get_neighbors);
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment