Commit 639d7bc0 by zljJoan

Merge branch 'main' of github.com:zhljJoan/startGNN_sample into main

parents d5b33231 3eb58d91
import torch
from abc import ABC
from typing import Tuple
class BaseSampler(ABC):
    r"""Common interface for graph samplers.

    The class declares three routines that concrete samplers are expected
    to override:

    * :meth:`sample_from_nodes`
    * :meth:`_sample_one_layer_from_nodes`
    * :meth:`_sample_one_layer_from_nodes_parallel`

    In this base class every one of them raises
    :class:`NotImplementedError`.
    """

    def sample_from_nodes(
        self, nodes: torch.Tensor, **kwargs
    ) -> Tuple[torch.Tensor, list]:
        r"""Multi-layer sampling entry point.

        Implementations are expected to sample several layers starting from
        the seeds in ``nodes``; the number of layers is governed by
        ``num_layers``.

        Args:
            nodes: indices of the seed nodes.
            **kwargs: additional sampler-specific options.

        Returns:
            A ``(sampled_nodes, sampled_edge_index)`` pair: the sampled
            nodes and the sampled edges.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError()

    def _sample_one_layer_from_nodes(
        self, nodes: torch.Tensor, **kwargs
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        r"""Single-layer sampling hook.

        Args:
            nodes: indices of the seed nodes.
            **kwargs: additional sampler-specific options.

        Returns:
            A ``(sampled_nodes, sampled_edge_index)`` pair: the sampled
            nodes and the sampled edges.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError()

    def _sample_one_layer_from_nodes_parallel(
        self, nodes: torch.Tensor, **kwargs
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        r"""Single-layer sampling hook, parallel variant.

        Args:
            nodes: indices of the seed nodes.
            **kwargs: additional sampler-specific options.

        Returns:
            A ``(sampled_nodes, sampled_edge_index)`` pair: the sampled
            nodes and the sampled edges.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError()
import torch
from Sampler import NeighborSampler
from neighbor_sampler import NeighborSampler
# edge_index = torch.tensor([[0, 1, 1, 2, 2, 2, 3], [1, 0, 2, 1, 3, 0, 2]])
# num_nodes = 4
......
# Ad-hoc smoke script: exercises NeighborSampler on a small graph with
# node ids 0..5.
import torch
# Edge list as a 2 x 13 tensor; the two rows are unpacked as row / col below.
edge_index = torch.tensor([[0, 1, 1, 1, 2, 2, 2, 3, 3, 4, 4, 4, 5], [1, 0, 2, 4, 1, 3, 0, 2, 5, 3, 5, 0, 2]])
row,col=edge_index
# NOTE(review): the resulting list is discarded, so this statement has no
# effect on the rest of the script -- looks like interactive-session residue;
# confirm before removing.
row.numpy().tolist()
from neighbor_sampler import NeighborSampler
# Same edge list as above, re-created.
edge_index = torch.tensor([[0, 1, 1, 1, 2, 2, 2, 3, 3, 4, 4, 4, 5], [1, 0, 2, 4, 1, 3, 0, 2, 5, 3, 5, 0, 2]])
num_nodes = 6
num_neighbors = 2
# Run the neighbor sampling
# NOTE(review): this binds the NeighborSampler class itself, not an instance,
# so every call below is made on the class -- confirm that these are
# static/class methods in neighbor_sampler, otherwise the first positional
# argument is consumed as `self`.
sampler=NeighborSampler
# NOTE: each call rebinds edge_index to its second return value, so every
# later call receives the previous call's output rather than the full edge
# list defined above.
neighbor_nodes, edge_index = sampler.sample_from_node(2, edge_index, num_nodes, num_neighbors)
neighbor_nodes, edge_index = sampler.sample_from_nodes(torch.tensor([1,2]), edge_index, num_nodes, num_neighbors)
neighbor_nodes, edge_index = sampler.sample_from_nodes_parallel(torch.tensor([1,2,3]), edge_index, num_nodes, workers=3, fanout=num_neighbors)
neighbor_nodes, edge_index = sampler.sample_from_nodes_parallel(torch.tensor([1,2,3,4,5]), edge_index, num_nodes, workers=4, fanout=num_neighbors)
from typing import Tuple
import torch
import torch_scatter
import torch.multiprocessing as mp
from abc import ABC
from Sample.sample_cores import get_neighbors, neighbor_sample_from_nodes
class BaseSampler(ABC):
    r"""Common state and interface for graph samplers.

    The constructor records the graph and the sampling configuration.
    Concrete samplers are expected to override:

    * :meth:`sample_from_nodes`
    * :meth:`_sample_one_layer_from_nodes`
    * :meth:`_sample_one_layer_from_nodes_parallel`

    In this base class every one of them raises
    :class:`NotImplementedError`.
    """

    def __init__(
        self,
        edge_index: torch.Tensor,
        num_nodes: int,
        num_layers: int,
        workers = 1,
        **kwargs
    ) -> None:
        r"""Store the graph and the sampling configuration on the instance.

        Args:
            edge_index: all edges in the graph.
            num_nodes: number of nodes in the graph.
            num_layers: number of layers to be sampled.
            workers: number of threads; defaults to 1.
            **kwargs: accepted for subclasses; not used by this constructor.
        """
        super().__init__()
        # Sampling configuration.
        self.workers = workers
        self.num_layers = num_layers
        # Graph description.
        self.num_nodes = num_nodes
        self.edge_index = edge_index

    def sample_from_nodes(
        self, nodes: torch.Tensor, **kwargs
    ) -> Tuple[torch.Tensor, list]:
        r"""Multi-layer sampling entry point.

        Implementations are expected to sample several layers starting from
        the seeds in ``nodes``; the number of layers is governed by
        ``num_layers``.

        Args:
            nodes: indices of the seed nodes.
            **kwargs: additional sampler-specific options.

        Returns:
            A ``(sampled_nodes, sampled_edge_index)`` pair: the sampled
            nodes and the sampled edges.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError()

    def _sample_one_layer_from_nodes(
        self, nodes: torch.Tensor, **kwargs
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        r"""Single-layer sampling hook.

        Args:
            nodes: indices of the seed nodes.
            **kwargs: additional sampler-specific options.

        Returns:
            A ``(sampled_nodes, sampled_edge_index)`` pair: the sampled
            nodes and the sampled edges.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError()

    def _sample_one_layer_from_nodes_parallel(
        self, nodes: torch.Tensor, **kwargs
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        r"""Single-layer sampling hook, parallel variant.

        Args:
            nodes: indices of the seed nodes.
            **kwargs: additional sampler-specific options.

        Returns:
            A ``(sampled_nodes, sampled_edge_index)`` pair: the sampled
            nodes and the sampled edges.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError()
from typing import Tuple
from base import BaseSampler
from sample_cores import get_neighbors, neighbor_sample_from_nodes
class NeighborSampler(BaseSampler):
def __init__(
......@@ -106,6 +23,8 @@ class NeighborSampler(BaseSampler):
workers: the number of threads, default value is 1
"""
super().__init__(edge_index, num_nodes, num_layers, workers)
self.num_layers = num_layers
self.workers = workers
self.fanout = fanout
row, col = edge_index
tnb = get_neighbors(row.tolist(), col.tolist(), num_nodes)
......
......@@ -11,6 +11,13 @@ typedef int NodeIDType;
// typedef int EdgeIDType;
// typedef float TimeStampType;
// Forward declarations of the block types named in the prototypes below.
class TemporalNeighborBlock;
class TemporalGraphBlock;
// Function prototypes, declared ahead of the definitions so that a function
// can be called before its definition appears later in this file
// (neighbor_sample_from_nodes calls neighbor_sample_from_node).
// Takes the row/col vectors and the node count; returns a TemporalNeighborBlock.
TemporalNeighborBlock get_neighbors(vector<NodeIDType>& row, vector<NodeIDType>& col, int num_nodes);
// Single-seed sampling: one node, that node's neighbor list, `deg` and `fanout`.
TemporalGraphBlock neighbor_sample_from_node(NodeIDType node, vector<NodeIDType>& neighbors, int deg, int fanout);
// Multi-seed sampling: a list of nodes plus per-node neighbor lists and `deg` values.
TemporalGraphBlock neighbor_sample_from_nodes(vector<NodeIDType>& nodes, vector<vector<NodeIDType>>& neighbors, vector<NodeIDType>& deg, int fanout);
template<typename T>
inline py::array vec2npy(const std::vector<T> &vec)
{
......@@ -74,6 +81,30 @@ TemporalNeighborBlock get_neighbors(
return tnb;
}
/**
 * Sample neighbors for every seed in `nodes` and merge the per-seed results.
 *
 * For each seed, neighbor_sample_from_node() is called with that seed's
 * neighbor list (`neighbors[node]`), its `deg[node]` entry and `fanout`; the
 * row/col entries it returns are appended to the result in seed order.
 *
 * @param nodes     seed node ids; each id is used directly as an index into
 *                  `neighbors` and `deg` (no bounds check is performed here)
 * @param neighbors per-node neighbor lists, indexed by node id
 * @param deg       per-node values forwarded as `deg`, indexed by node id
 * @param fanout    forwarded unchanged to neighbor_sample_from_node()
 * @return a block whose row/col are the concatenated per-seed results and
 *         whose nodes are the distinct values found in col (the order of
 *         nodes is whatever the unordered_set iteration yields)
 */
TemporalGraphBlock neighbor_sample_from_nodes(
        vector<NodeIDType>& nodes, vector<vector<NodeIDType>>& neighbors,
        vector<NodeIDType>& deg, int fanout){
    TemporalGraphBlock tgb = TemporalGraphBlock();
    // Range-for avoids the signed/unsigned comparison of an int index
    // against nodes.size().
    for(NodeIDType node : nodes){
        TemporalGraphBlock tgb_i = neighbor_sample_from_node(node, neighbors[node], deg[node], fanout);
        tgb.row.insert(tgb.row.end(), tgb_i.row.begin(), tgb_i.row.end());
        tgb.col.insert(tgb.col.end(), tgb_i.col.begin(), tgb_i.col.end());
        // tgb_i.nodes is intentionally not accumulated: tgb.nodes is rebuilt
        // from col below, so anything appended here would be overwritten.
    }
    // Deduplicate the sampled nodes: keep each value of col exactly once.
    unordered_set<int> unique_nodes;
    for (int sampled : tgb.col)
        unique_nodes.insert(sampled);
    tgb.nodes.assign(unique_nodes.begin(), unique_nodes.end());
    return tgb;
}
/*-------------------------------------------------------------------------------------**
**------------Utils--------------------------------------------------------------------**
**-------------------------------------------------------------------------------------*/
TemporalGraphBlock neighbor_sample_from_node(
NodeIDType node, vector<NodeIDType>& neighbors,
int deg, int fanout){
......@@ -99,27 +130,7 @@ TemporalGraphBlock neighbor_sample_from_node(
return tgb;
}
/**
 * Sample neighbors for every seed in `nodes` and merge the per-seed results.
 *
 * NOTE(review): a definition with this same signature also appears earlier
 * in this file; two definitions of one function cannot both be compiled, so
 * confirm which copy is meant to stay.
 *
 * For each seed, neighbor_sample_from_node() is called with that seed's
 * neighbor list (`neighbors[node]`), its `deg[node]` entry and `fanout`; the
 * row/col entries it returns are appended to the result in seed order.
 *
 * @param nodes     seed node ids; each id is used directly as an index into
 *                  `neighbors` and `deg` (no bounds check is performed here)
 * @param neighbors per-node neighbor lists, indexed by node id
 * @param deg       per-node values forwarded as `deg`, indexed by node id
 * @param fanout    forwarded unchanged to neighbor_sample_from_node()
 * @return a block whose row/col are the concatenated per-seed results and
 *         whose nodes are the distinct values found in col (the order of
 *         nodes is whatever the unordered_set iteration yields)
 */
TemporalGraphBlock neighbor_sample_from_nodes(
        vector<NodeIDType>& nodes, vector<vector<NodeIDType>>& neighbors,
        vector<NodeIDType>& deg, int fanout){
    TemporalGraphBlock tgb = TemporalGraphBlock();
    // Range-for avoids the signed/unsigned comparison of an int index
    // against nodes.size().
    for(NodeIDType node : nodes){
        TemporalGraphBlock tgb_i = neighbor_sample_from_node(node, neighbors[node], deg[node], fanout);
        tgb.row.insert(tgb.row.end(), tgb_i.row.begin(), tgb_i.row.end());
        tgb.col.insert(tgb.col.end(), tgb_i.col.begin(), tgb_i.col.end());
        // tgb_i.nodes is intentionally not accumulated: tgb.nodes is rebuilt
        // from col below, so anything appended here would be overwritten.
    }
    // Deduplicate the sampled nodes: keep each value of col exactly once.
    unordered_set<int> unique_nodes;
    for (int sampled : tgb.col)
        unique_nodes.insert(sampled);
    tgb.nodes.assign(unique_nodes.begin(), unique_nodes.end());
    return tgb;
}
/*------------Python Bind--------------------------------------------------------------*/
PYBIND11_MODULE(sample_cores, m)
{
m
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment