Commit 5b92c413 by zljJoan

Merge branch 'main' of github.com:zhljJoan/startGNN_sample into main

parents 639d7bc0 31d23a02
import torch
from torch import Tensor
from enum import Enum
import math
from abc import ABC
from typing import Tuple
from typing import Optional, Tuple, Union
class NegativeSamplingMode(Enum):
    """Strategies available for drawing negative samples."""

    binary = 'binary'  # negative edges are drawn at random from the graph
    # For each positive source node, negative destination nodes are drawn
    # at random.
    triplet = 'triplet'
class NegativeSampling:
    r"""Configuration object describing how negative samples are drawn when
    :meth:`~torch_geometric.sampler.BaseSampler.sample_from_edges` of a
    :class:`~torch_geometric.sampler.BaseSampler` is called.

    Args:
        mode (str): Either :obj:`"binary"` (random negative links are drawn
            from the graph) or :obj:`"triplet"` (random negative destination
            nodes are drawn for each positive source node).
        amount (int or float, optional): Ratio of sampled negative edges to
            positive edges. Must be positive, and integral in
            :obj:`"triplet"` mode. (default: :obj:`1`)
        weight (torch.Tensor, optional): Node-level sampling weights. They do
            not necessarily need to sum up to one. Without weights, negative
            nodes are sampled uniformly. (default: :obj:`None`)
    """
    mode: NegativeSamplingMode
    amount: Union[int, float] = 1
    weight: Optional[Tensor] = None

    def __init__(
        self,
        mode: Union[NegativeSamplingMode, str],
        amount: Union[int, float] = 1,
        weight: Optional[Tensor] = None,
    ):
        self.mode = NegativeSamplingMode(mode)
        self.amount = amount
        self.weight = weight

        owner = self.__class__.__name__
        if self.amount <= 0:
            raise ValueError(f"The attribute 'amount' needs to be positive "
                             f"for '{owner}' "
                             f"(got {self.amount})")

        # Only triplet sampling imposes the additional integrality constraint.
        if not self.is_triplet():
            return

        rounded = math.ceil(self.amount)
        if rounded != self.amount:
            raise ValueError(f"The attribute 'amount' needs to be an "
                             f"integer for '{owner}' "
                             f"with 'triplet' negative sampling "
                             f"(got {self.amount}).")
        # Store the value as an ``int`` even if an integral float was passed.
        self.amount = rounded

    def is_binary(self) -> bool:
        """Return ``True`` when the mode is ``binary``."""
        return self.mode is NegativeSamplingMode.binary

    def is_triplet(self) -> bool:
        """Return ``True`` when the mode is ``triplet``."""
        return self.mode is NegativeSamplingMode.triplet

    def sample(self, num_samples: int,
               num_nodes: Optional[int] = None) -> Tensor:
        r"""Draws :obj:`num_samples` negative node indices.

        With a :obj:`weight` vector the indices are drawn with replacement
        via :func:`torch.multinomial`; otherwise they are drawn uniformly
        from ``[0, num_nodes)``.

        Raises:
            ValueError: If no weights are set and :obj:`num_nodes` is missing,
                or if :obj:`num_nodes` is given and differs from the number
                of weights.
        """
        owner = self.__class__.__name__

        if self.weight is not None:
            num_weights = self.weight.numel()
            if num_nodes is not None and num_weights != num_nodes:
                raise ValueError(
                    f"The 'weight' attribute in '{owner}' "
                    f"needs to match the number of nodes {num_nodes} "
                    f"(got {num_weights})")
            return torch.multinomial(self.weight, num_samples,
                                     replacement=True)

        if num_nodes is None:
            raise ValueError(
                f"Cannot sample negatives in '{owner}' "
                f"without passing the 'num_nodes' argument")
        return torch.randint(num_nodes, (num_samples, ))
class BaseSampler(ABC):
......@@ -25,40 +105,59 @@ class BaseSampler(ABC):
**kwargs: other kwargs
Returns:
sampled_nodes: the nodes sampled
sampled_edge_index: the edges sampled
sampled_edge_index_list: the edges sampled
"""
raise NotImplementedError
def _sample_one_layer_from_nodes(
def sample_from_edges(
self,
nodes:torch.Tensor,
**kwargs
) -> Tuple[torch.Tensor, torch.Tensor]:
r"""Performs sampling from the nodes specified in: nodes,
returning a sampled subgraph in the specified output format: Tuple[torch.Tensor, torch.Tensor].
edges: torch.Tensor,
edge_label: Optional[torch.Tensor] = None,
neg_sampling: Optional[NegativeSampling] = None
) -> Tuple[torch.Tensor, list]:
r"""Performs sampling from the edges specified in :obj:`index`,
returning a sampled subgraph in the specified output format.
Args:
nodes: the list of seed nodes index
**kwargs: other kwargs
edges: the list of seed edges index
edge_label: the label for the seed edges.
neg_sampling: The negative sampling configuration
Returns:
sampled_nodes: the nodes sampled
sampled_edge_index: the edges sampled
sampled_edge_index_list: the edges sampled
"""
raise NotImplementedError
# def _sample_one_layer_from_nodes(
# self,
# nodes:torch.Tensor,
# **kwargs
# ) -> Tuple[torch.Tensor, torch.Tensor]:
# r"""Performs sampling from the nodes specified in: nodes,
# returning a sampled subgraph in the specified output format: Tuple[torch.Tensor, torch.Tensor].
# Args:
# nodes: the list of seed nodes index
# **kwargs: other kwargs
# Returns:
# sampled_nodes: the nodes sampled
# sampled_edge_index: the edges sampled
# """
# raise NotImplementedError
def _sample_one_layer_from_nodes_parallel(
self,
nodes: torch.Tensor,
**kwargs
) -> Tuple[torch.Tensor, torch.Tensor]:
r"""Performs sampling paralleled from the nodes specified in: nodes,
returning a sampled subgraph in the specified output format: Tuple[torch.Tensor, torch.Tensor].
# def _sample_one_layer_from_nodes_parallel(
# self,
# nodes: torch.Tensor,
# **kwargs
# ) -> Tuple[torch.Tensor, torch.Tensor]:
# r"""Performs sampling paralleled from the nodes specified in: nodes,
# returning a sampled subgraph in the specified output format: Tuple[torch.Tensor, torch.Tensor].
Args:
nodes: the list of seed nodes index
**kwargs: other kwargs
Returns:
sampled_nodes: the nodes sampled
sampled_edge_index: the edges sampled
"""
raise NotImplementedError
# Args:
# nodes: the list of seed nodes index
# **kwargs: other kwargs
# Returns:
# sampled_nodes: the nodes sampled
# sampled_edge_index: the edges sampled
# """
# raise NotImplementedError
......@@ -9,12 +9,7 @@ num_neighbors = 2
# Run the neighbor sampling
sampler=NeighborSampler(edge_index=edge_index, num_nodes=num_nodes, num_layers=2, workers=2, fanout=[2, 1])
# neighbor_nodes, sampled_edge_index = sampler._sample_one_layer_from_node(node=2, fanout=2)
# neighbor_nodes, sampled_edge_index = sampler._sample_one_layer_from_nodes(nodes=torch.tensor([1,3]), fanout=num_neighbors)
# sampler.workers=3
# neighbor_nodes, sampled_edge_index = sampler._sample_one_layer_from_nodes_parallel(nodes=torch.tensor([1,2,3]), fanout=num_neighbors)
# sampler.workers=4
# neighbor_nodes, sampled_edge_index = sampler._sample_one_layer_from_nodes_parallel(nodes=torch.tensor([1,2,3,4,5]), fanout=num_neighbors)
neighbor_nodes, sampled_edge_index = sampler.sample_from_nodes(torch.tensor([1,2,3]))
# Print the result
......@@ -29,4 +24,22 @@ print('neighbor_nodes_id: \n',neighbor_nodes, '\nedge_index: \n',sampled_edge_in
# neighbors2=torch.concat([col[row==nodes[i]] for i in range(0, nodes.shape[0])])
# print(neighbors2)
# neighbors=torch.stack([neighbors1, neighbors2], dim=0)
# print('neighbors: \n', neighbors[0]==1)
\ No newline at end of file
# print('neighbors: \n', neighbors[0]==1)
# Demo: sample a subgraph around two seed edges, with negative sampling.
from base import NegativeSampling
# Toy graph with 6 nodes; edge_index is a 2 x 13 tensor (one column per edge).
edge_index = torch.tensor([[0, 1, 1, 1, 2, 2, 2, 3, 3, 4, 4, 4, 5], [1, 0, 2, 4, 1, 3, 0, 2, 5, 3, 5, 0, 2]])
num_nodes = 6
# Two-layer neighbor sampler: fan-out 2 in the first layer, 1 in the second.
sampler=NeighborSampler(edge_index=edge_index, num_nodes=num_nodes, num_layers=2, workers=2, fanout=[2, 1])
# Negative-sampling configuration: one weight per node (6 entries), binary
# mode, amount 2 (the ratio of negative to positive edges).
weight = torch.tensor([0.3,0.1,0.1,0.1,0.3,0.1])
negative = NegativeSampling('binary', 2, weight)
# Alternative configuration using triplet mode:
# negative = NegativeSampling('triplet', 2, weight)
# One label per seed edge.
label=torch.tensor([1,2])
# Seed edges; sample_from_edges unpacks the first row as sources and the
# second row as destinations, i.e. the seeds are (0, 1) and (1, 4).
seed_edges = torch.tensor([[0,1],
[1,4]])
# Variant without labels and without negative sampling:
# result = sampler.sample_from_edges(edges=seed_edges)
result = sampler.sample_from_edges(edges=seed_edges, edge_label=label, neg_sampling=negative)
# Print the three returned values: sampled nodes, per-layer edge list, metadata.
print('neighbor_nodes_id: \n',result[0], '\nedge_index: \n',result[1], '\nmetadata: \n',result[2])
\ No newline at end of file
# Scratch check: unpack the two rows of edge_index and convert the first row
# to a Python list (the resulting list is not stored anywhere).
import torch
edge_index = torch.tensor([[0, 1, 1, 1, 2, 2, 2, 3, 3, 4, 4, 4, 5], [1, 0, 2, 4, 1, 3, 0, 2, 5, 3, 5, 0, 2]])
row,col=edge_index
row.numpy().tolist()
# Demo of a NeighborSampler API in which the graph is passed on every call.
from neighbor_sampler import NeighborSampler
edge_index = torch.tensor([[0, 1, 1, 1, 2, 2, 2, 3, 3, 4, 4, 4, 5], [1, 0, 2, 4, 1, 3, 0, 2, 5, 3, 5, 0, 2]])
num_nodes = 6
num_neighbors = 2
# Run the neighbor sampling
# NOTE(review): `sampler` is bound to the class itself, not to an instance, so
# the calls below presumably target static/class-level methods -- confirm
# against the NeighborSampler revision this script was written for.
sampler=NeighborSampler
# NOTE(review): each call rebinds `edge_index` to the returned edge index, so
# every later call receives the previous result instead of the full graph
# defined above -- verify that this is intended.
neighbor_nodes, edge_index = sampler.sample_from_node(2, edge_index, num_nodes, num_neighbors)
neighbor_nodes, edge_index = sampler.sample_from_nodes(torch.tensor([1,2]), edge_index, num_nodes, num_neighbors)
neighbor_nodes, edge_index = sampler.sample_from_nodes_parallel(torch.tensor([1,2,3]), edge_index, num_nodes, workers=3, fanout=num_neighbors)
neighbor_nodes, edge_index = sampler.sample_from_nodes_parallel(torch.tensor([1,2,3,4,5]), edge_index, num_nodes, workers=4, fanout=num_neighbors)
import math
import torch
import torch.multiprocessing as mp
from typing import Tuple
from typing import Optional, Tuple
from base import BaseSampler
from sample_cores import get_neighbors, neighbor_sample_from_nodes
from base import BaseSampler, NegativeSampling
from Sample.sample_cores import get_neighbors, neighbor_sample_from_nodes, heads_unique
class NeighborSampler(BaseSampler):
def __init__(
......@@ -22,12 +23,14 @@ class NeighborSampler(BaseSampler):
fanout: the list of max neighbors' number chosen for each layer
workers: the number of threads, default value is 1
"""
super().__init__(edge_index, num_nodes, num_layers, workers)
super().__init__()
self.num_layers = num_layers
self.workers = workers
# 线程数不超过torch默认的omp线程数
self.workers = min(workers, torch.get_num_threads())
self.fanout = fanout
self.num_nodes = num_nodes
row, col = edge_index
tnb = get_neighbors(row.tolist(), col.tolist(), num_nodes)
tnb = get_neighbors(row.tolist(), col.tolist(), num_nodes, self.workers)
self.neighbors = tnb.neighbors
self.deg = tnb.deg
......@@ -43,20 +46,74 @@ class NeighborSampler(BaseSampler):
nodes: the list of seed nodes index
Returns:
sampled_nodes: the node sampled
sampled_edge_index: the edge sampled
sampled_edge_index_list: the edge sampled
"""
sampled_edge_index_list=[]
sampled_nodes=torch.IntTensor([])
sampled_edge_index_list = []
sampled_nodes = torch.IntTensor([])
src_nodes = nodes.tolist()
assert self.workers > 0, 'Workers should be positive integer!!!'
for i in range(0, self.num_layers):
sampled_nodes_i, sampled_edge_index_i = self._sample_one_layer_from_nodes_parallel(nodes, self.fanout[i])
sampled_nodes_i, sampled_edge_index_i = self._sample_one_layer_from_nodes(nodes, self.fanout[i])
sampled_nodes = torch.cat([sampled_nodes, sampled_nodes_i])
nodes = torch.unique(sampled_edge_index_i[1])
sampled_nodes = torch.unique(torch.cat([sampled_nodes, sampled_nodes_i]))
sampled_edge_index_list.append(sampled_edge_index_i)
return sampled_nodes, sampled_edge_index_list
sampled_nodes = heads_unique(sampled_nodes.tolist(), src_nodes)
return torch.tensor(sampled_nodes), sampled_edge_index_list
def sample_from_edges(
    self,
    edges: torch.Tensor,
    edge_label: Optional[torch.Tensor] = None,
    neg_sampling: Optional[NegativeSampling] = None
) -> Tuple[torch.Tensor, list, dict]:
    r"""Samples a subgraph around the given seed edges, optionally adding
    negative samples first.

    The endpoints of all (positive and negative) seed edges are merged into
    one de-duplicated seed-node set that is handed to
    :meth:`sample_from_nodes`.

    Args:
        edges: seed edges; the first row holds the source nodes and the
            second row the destination nodes.
        edge_label: labels of the seed edges. Defaults to a vector of ones
            with one entry per seed edge.
        neg_sampling: the negative sampling configuration, or :obj:`None`
            to sample around the positive edges only.
    Returns:
        sampled_nodes: the nodes sampled
        sampled_edge_index_list: the edges sampled, one entry per layer
        metadata: a dict with

            * ``'edge_label_index'`` and ``'edge_label'`` when no negative
              sampling or ``binary`` negative sampling is used. The index
              refers to positions in the de-duplicated seed set, and in
              ``binary`` mode the sampled negative edges are appended
              with label ``0``;
            * ``'src_index'``, ``'dst_pos_index'`` and ``'dst_neg_index'``
              when ``triplet`` negative sampling is used.
    """
    src, dst = edges
    num_pos = src.numel()
    if edge_label is None:
        edge_label = torch.ones(num_pos)

    if neg_sampling is not None:
        num_neg = math.ceil(num_pos * neg_sampling.amount)
        if neg_sampling.is_binary():
            # Negative edges: both endpoints are drawn independently.
            src_neg = neg_sampling.sample(num_neg, self.num_nodes)
            src = torch.cat([src, src_neg], dim=0)
            dst_neg = neg_sampling.sample(num_neg, self.num_nodes)
            dst = torch.cat([dst, dst_neg], dim=0)
            # Keep the labels aligned with 'edge_label_index': every
            # sampled negative edge gets label 0.
            neg_label = edge_label.new_zeros(
                (num_neg, ) + edge_label.size()[1:])
            edge_label = torch.cat([edge_label, neg_label], dim=0)
        elif neg_sampling.is_triplet():
            # Only negative destinations are drawn; sources are reused.
            dst_neg = neg_sampling.sample(num_neg, self.num_nodes)
            dst = torch.cat([dst, dst_neg], dim=0)

    # De-duplicate the seed nodes; 'inverse_seed' maps every endpoint back
    # to its position in the unique seed set.
    seed = torch.cat([src, dst], dim=0)
    seed, inverse_seed = seed.unique(return_inverse=True)
    sampled_nodes, sampled_edge_index_list = self.sample_from_nodes(seed)

    if neg_sampling is None or neg_sampling.is_binary():
        # First half of 'inverse_seed' are sources, second half destinations.
        edge_label_index = inverse_seed.view(2, -1)
        metadata = {'edge_label_index':edge_label_index, 'edge_label':edge_label}
    elif neg_sampling.is_triplet():
        # Layout of 'inverse_seed': [src | positive dst | negative dst].
        src_index = inverse_seed[:num_pos]
        dst_pos_index = inverse_seed[num_pos:2 * num_pos]
        dst_neg_index = inverse_seed[2 * num_pos:]
        # One row of negatives per positive edge; collapse to 1-D when
        # there is exactly one negative per positive.
        dst_neg_index = dst_neg_index.view(num_pos, -1).squeeze(-1)
        metadata = {'src_index':src_index, 'dst_pos_index':dst_pos_index, 'dst_neg_index':dst_neg_index}
    return sampled_nodes, sampled_edge_index_list, metadata
def _sample_one_layer_from_nodes(
self,
self,
nodes: torch.Tensor,
fanout: int
) -> Tuple[torch.Tensor, torch.Tensor]:
......@@ -70,55 +127,55 @@ class NeighborSampler(BaseSampler):
sampled_nodes: the nodes sampled
sampled_edge_index: the edges sampled
"""
tgb = neighbor_sample_from_nodes(nodes.tolist(), self.neighbors, self.deg, fanout)
tgb = neighbor_sample_from_nodes(nodes.tolist(), self.neighbors, self.deg, fanout, self.workers)
row = torch.IntTensor(tgb.row())
col = torch.IntTensor(tgb.col())
sampled_nodes = torch.IntTensor(tgb.nodes())
return sampled_nodes, torch.stack([row,col], dim=0)
return sampled_nodes, torch.stack([row,col], dim=0)
def _sample_one_layer_from_nodes_parallel(
self,
nodes: torch.Tensor,
fanout: int
) -> Tuple[torch.Tensor, torch.Tensor]:
r"""Performs sampling from the nodes specified in: nodes,
returning a sampled subgraph in the specified output format: Tuple[torch.Tensor, torch.Tensor].
# def _sample_one_layer_from_nodes_parallel(
# self,
# nodes: torch.Tensor,
# fanout: int
# ) -> Tuple[torch.Tensor, torch.Tensor]:
# r"""Performs sampling from the nodes specified in: nodes,
# returning a sampled subgraph in the specified output format: Tuple[torch.Tensor, torch.Tensor].
Args:
nodes: the list of seed nodes index
fanout: the number of max neighbor chosen
Returns:
sampled_nodes: the node sampled
sampled_edge_index: the edge sampled
"""
sampled_nodes=torch.IntTensor([])
row=torch.IntTensor([])
col=torch.IntTensor([])
assert self.workers > 0, 'Workers should be positive integer!!!'
with mp.Pool(processes=torch.get_num_threads()) as p:
n=len(nodes)
if(self.workers>=n):
results = [p.apply_async(self._sample_one_layer_from_nodes,
(torch.tensor([node.item()]), fanout))
for node in nodes]
else:
quotient = n//self.workers
remainder = n%self.workers
# 每个batch先分配quotient个nodes,然后将余数remainder平均分配给其中一些batch
nodes1 = nodes[0:(quotient+1)*(remainder)].resize_(remainder,quotient+1)# 分配了余数的batch
nodes2 = nodes[(quotient+1)*(remainder):n].resize_(self.workers - remainder,quotient)# 未分配余数的batch
results = [p.apply_async(self._sample_one_layer_from_nodes,
(nodes1[i], fanout))
for i in range(0, remainder)]
results.extend([p.apply_async(self._sample_one_layer_from_nodes,
(nodes2[i], fanout))
for i in range(0, self.workers - remainder)])
for result in results:
sampled_nodes_i, sampled_edge_index_i = result.get()
sampled_nodes = torch.unique(torch.cat([sampled_nodes, sampled_nodes_i]))
row = torch.cat([row, sampled_edge_index_i[0]])
col = torch.cat([col, sampled_edge_index_i[1]])
return sampled_nodes, torch.stack([row, col], dim=0)
# Args:
# nodes: the list of seed nodes index
# fanout: the number of max neighbor chosen
# Returns:
# sampled_nodes: the node sampled
# sampled_edge_index: the edge sampled
# """
# sampled_nodes=torch.IntTensor([])
# row=torch.IntTensor([])
# col=torch.IntTensor([])
# assert self.workers > 0, 'Workers should be positive integer!!!'
# with mp.Pool(processes=self.workers) as p:
# n=len(nodes)
# if(self.workers>=n):
# results = [p.apply_async(self._sample_one_layer_from_nodes,
# (torch.tensor([node.item()]), fanout))
# for node in nodes]
# else:
# quotient = n//self.workers
# remainder = n%self.workers
# # 每个batch先分配quotient个nodes,然后将余数remainder平均分配给其中一些batch
# nodes1 = nodes[0:(quotient+1)*(remainder)].resize_(remainder,quotient+1)# 分配了余数的batch
# nodes2 = nodes[(quotient+1)*(remainder):n].resize_(self.workers - remainder,quotient)# 未分配余数的batch
# results = [p.apply_async(self._sample_one_layer_from_nodes,
# (nodes1[i], fanout))
# for i in range(0, remainder)]
# results.extend([p.apply_async(self._sample_one_layer_from_nodes,
# (nodes2[i], fanout))
# for i in range(0, self.workers - remainder)])
# for result in results:
# sampled_nodes_i, sampled_edge_index_i = result.get()
# sampled_nodes = torch.unique(torch.cat([sampled_nodes, sampled_nodes_i]))
# row = torch.cat([row, sampled_edge_index_i[0]])
# col = torch.cat([col, sampled_edge_index_i[1]])
# return sampled_nodes, torch.stack([row, col], dim=0)
# 不使用_sample_one_layer_from_node直接取所有点邻居方法:
# row, col = edge_index
......@@ -133,15 +190,10 @@ if __name__=="__main__":
num_nodes1 = 6
num_neighbors = 2
# Run the neighbor sampling
sampler=NeighborSampler(edge_index=edge_index1, num_nodes=num_nodes1, num_layers=2, workers=2, fanout=[2, 1])
sampler=NeighborSampler(edge_index=edge_index1, num_nodes=num_nodes1, num_layers=3, workers=4, fanout=[2, 1, 1])
# neighbor_nodes, sampled_edge_index = sampler._sample_one_layer_from_node(node=2, fanout=2)
# neighbor_nodes, sampled_edge_index = sampler._sample_one_layer_from_nodes(nodes=torch.tensor([1,3]), fanout=num_neighbors)
# sampler.workers=3
# neighbor_nodes, sampled_edge_index = sampler._sample_one_layer_from_nodes_parallel(nodes=torch.tensor([1,2,3]), fanout=num_neighbors)
# sampler.workers=4
# neighbor_nodes, sampled_edge_index = sampler._sample_one_layer_from_nodes_parallel(nodes=torch.tensor([1,2,3,4,5]), fanout=num_neighbors)
neighbor_nodes, sampled_edge_index = sampler.sample_from_nodes(torch.tensor([1,2,3]))
# neighbor_nodes, sampled_edge_index = sampler._sample_one_layer_from_nodes(nodes=torch.tensor([2,1]), fanout=num_neighbors)
neighbor_nodes, sampled_edge_index = sampler.sample_from_nodes(torch.tensor([1,2]))
# Print the result
print('neighbor_nodes_id: \n',neighbor_nodes, '\nedge_index: \n',sampled_edge_index)
import torch
import torch.multiprocessing as mp
from typing import Optional, Tuple
from base import BaseSampler, NegativeSampling
from neighbor_sampler import NeighborSampler
class RandomWalkSampler(BaseSampler):
    r"""Random-walk style sampler implemented as a :class:`NeighborSampler`
    whose fan-out is 1 in every layer. All sampling calls are delegated to
    that inner sampler.
    """

    def __init__(
        self,
        edge_index: torch.Tensor,
        num_nodes: int,
        num_layers: int,
        workers = 1
    ) -> None:
        r"""__init__
        Args:
            edge_index: all edges in the graph
            num_nodes: the num of all nodes in the graph
            num_layers: the num of layers to be sampled
            workers: the number of threads, default value is 1
        """
        super().__init__()
        # Pass 'fanout' and 'workers' by keyword so the call does not depend
        # on the positional order of NeighborSampler's parameters.
        self.sampler = NeighborSampler(edge_index=edge_index,
                                       num_nodes=num_nodes,
                                       num_layers=num_layers,
                                       fanout=[1] * num_layers,
                                       workers=workers)
        self.num_layers = num_layers
        # Never use more workers than torch's default OMP thread count.
        self.workers = min(workers, torch.get_num_threads())

    def sample_from_nodes(
        self,
        nodes: torch.Tensor
    ) -> Tuple[torch.Tensor, list]:
        r"""Performs multi-layer sampling from the nodes specified in: nodes.
        The number of layers is determined by parameter: num_layers.
        Delegates to the inner :class:`NeighborSampler`.

        Args:
            nodes: the list of seed nodes index
        Returns:
            sampled_nodes: the node sampled
            sampled_edge_index_list: the edges sampled, one entry per layer
        """
        return self.sampler.sample_from_nodes(nodes)

    def sample_from_edges(
        self,
        edges: torch.Tensor,
        edge_label: Optional[torch.Tensor] = None,
        neg_sampling: Optional[NegativeSampling] = None
    ) -> Tuple[torch.Tensor, list, dict]:
        r"""Performs sampling from the edges specified in :obj:`edges`,
        returning a sampled subgraph in the specified output format.
        Delegates to the inner :class:`NeighborSampler`.

        Args:
            edges: the list of seed edges index
            edge_label: the label for the seed edges.
            neg_sampling: The negative sampling configuration
        Returns:
            sampled_nodes: the nodes sampled
            sampled_edge_index_list: the edges sampled
            metadata: other information returned by the inner sampler
        """
        return self.sampler.sample_from_edges(edges, edge_label, neg_sampling)
if __name__ == "__main__":
    # Demo graph with 6 nodes; the pairs (3, 2) and (3, 5) are left out.
    first_row = [0, 1, 1, 1, 2, 2, 2, 4, 4, 4, 5]
    second_row = [1, 0, 2, 4, 1, 3, 0, 3, 5, 0, 2]
    demo_edges = torch.tensor([first_row, second_row])
    node_count = 6

    # Build a three-layer random-walk sampler and start from nodes 1 and 2.
    walker = RandomWalkSampler(edge_index=demo_edges, num_nodes=node_count,
                               num_layers=3, workers=4)
    start_nodes = torch.tensor([1, 2])
    visited, layer_edges = walker.sample_from_nodes(start_nodes)

    # Show what was sampled.
    print('sampled_nodes_id: \n', visited, '\nedge_index: \n', layer_edges)
#include <iostream>
#include<set>
#include<pybind11/pybind11.h>
#include<pybind11/numpy.h>
#include <set>
#include <omp.h>
#include <pybind11/pybind11.h>
#include <pybind11/numpy.h>
#include <pybind11/stl.h>
using namespace std;
......@@ -13,10 +14,10 @@ typedef int NodeIDType;
class TemporalNeighborBlock;
class TemporalGraphBlock;
TemporalNeighborBlock get_neighbors(vector<NodeIDType>& row, vector<NodeIDType>& col, int num_nodes);
TemporalGraphBlock neighbor_sample_from_node(NodeIDType node, vector<NodeIDType>& neighbors, int deg, int fanout);
TemporalGraphBlock neighbor_sample_from_nodes(vector<NodeIDType>& nodes, vector<vector<NodeIDType>>& neighbors, vector<NodeIDType>& deg, int fanout);
TemporalNeighborBlock get_neighbors(vector<NodeIDType>& row, vector<NodeIDType>& col, int num_nodes, int threads);
TemporalGraphBlock neighbor_sample_from_node(NodeIDType node, vector<NodeIDType>& neighbors, int deg, int fanout, int threads);
TemporalGraphBlock neighbor_sample_from_nodes(vector<NodeIDType>& nodes, vector<vector<NodeIDType>>& neighbors, vector<NodeIDType>& deg, int fanout, int threads);
vector<NodeIDType> heads_unique(vector<NodeIDType>& array, vector<NodeIDType>& heads);
template<typename T>
inline py::array vec2npy(const std::vector<T> &vec)
......@@ -66,37 +67,47 @@ class TemporalGraphBlock
};
TemporalNeighborBlock get_neighbors(
vector<NodeIDType>& row, vector<NodeIDType>& col, int num_nodes){
vector<NodeIDType>& row, vector<NodeIDType>& col, int num_nodes, int threads){
int edge_num = row.size();
TemporalNeighborBlock tnb = TemporalNeighborBlock();
tnb.deg.resize(num_nodes, 0);
double start_time = omp_get_wtime();
#pragma omp parallel for num_threads(threads)
for(int i=0; i<num_nodes; i++)
tnb.neighbors.push_back(new vector<NodeIDType>());
#pragma omp parallel for num_threads(threads)
for(int i=0; i<edge_num; i++){
//计算节点邻居
tnb.neighbors[row[i]]->push_back(col[i]);
//计算节点度
tnb.deg[row[i]]++;
}
double end_time = omp_get_wtime();
cout<<"get_neighbors consume: "<<end_time-start_time<<"s"<<endl;
return tnb;
}
TemporalGraphBlock neighbor_sample_from_nodes(
vector<NodeIDType>& nodes, vector<vector<NodeIDType>>& neighbors,
vector<NodeIDType>& deg, int fanout){
vector<NodeIDType>& deg, int fanout, int threads){
TemporalGraphBlock tgb = TemporalGraphBlock();
double start_time = omp_get_wtime();
#pragma omp parallel for num_threads(threads)
for(int i=0; i<nodes.size(); i++){
NodeIDType node = nodes[i];
TemporalGraphBlock tgb_i = neighbor_sample_from_node(node, neighbors[node], deg[node], fanout);
TemporalGraphBlock tgb_i = neighbor_sample_from_node(node, neighbors[node], deg[node], fanout, threads);
tgb.row.insert(tgb.row.end(),tgb_i.row.begin(),tgb_i.row.end());
tgb.col.insert(tgb.col.end(),tgb_i.col.begin(),tgb_i.col.end());
tgb.nodes.insert(tgb.nodes.end(),tgb_i.nodes.begin(),tgb_i.nodes.end());
}
//sampled nodes 去重
unordered_set<int> s;
for (int i : tgb.col)
s.insert(i);
tgb.nodes.assign(s.begin(), s.end());
double end_time = omp_get_wtime();
cout<<"neighbor_sample_from_nodes consume: "<<end_time-start_time<<"s"<<endl;
//sampled nodes 插入去重
start_time = end_time;
tgb.nodes.assign(tgb.col.begin(), tgb.col.end());
heads_unique(tgb.nodes, nodes);
// cout<<"nodes: "<<tgb.nodes.size()<<endl;
end_time = omp_get_wtime();
cout<<"unique consume: "<<end_time-start_time<<"s"<<endl;
return tgb;
}
......@@ -107,29 +118,39 @@ TemporalGraphBlock neighbor_sample_from_nodes(
TemporalGraphBlock neighbor_sample_from_node(
NodeIDType node, vector<NodeIDType>& neighbors,
int deg, int fanout){
int deg, int fanout, int threads){
TemporalGraphBlock tgb = TemporalGraphBlock();
tgb.col = neighbors;
srand((int)time(0));
if(deg>fanout){
//度大于扇出的话需要随机删除一些邻居
for(int i=0; i<deg-fanout; i++){
//循环删除deg-fanout个邻居
auto erase_iter = tgb.col.begin() + rand()%(deg-i);
tgb.col.erase(erase_iter);
//度大于扇出的话需要随机选择fanout个邻居
#pragma omp parallel for num_threads(threads)
for(int i=0; i<fanout; i++){
//循环选择fanout个邻居
auto chosen_iter = neighbors.begin() + rand()%(deg-i);
tgb.col.push_back(*chosen_iter);
neighbors.erase(chosen_iter);
}
}
else
tgb.col.assign(neighbors.begin(), neighbors.end());
tgb.row.resize(tgb.col.size(), node);
//sampled nodes 去重
unordered_set<int> s;
for (int i : tgb.col)
s.insert(i);
s.insert(node);
tgb.nodes.assign(s.begin(), s.end());
//sampled nodes 暂不插入也不去重,待合并后一起插入并去重
return tgb;
}
vector<NodeIDType> heads_unique(vector<NodeIDType>& array, vector<NodeIDType>& heads){
unordered_set<NodeIDType> s(array.begin(), array.end());
#pragma omp parallel for num_threads(threads)
for(int i=0; i<heads.size(); i++){
if(s.count(heads[i])==1)
s.erase(heads[i]);
}
array.assign(s.begin(), s.end());
array.insert(array.begin(), heads.begin(), heads.end());
// cout<<"s: "<<s.size()<<" array: "<<array.size()<<endl;
return array;
}
/*------------Python Bind--------------------------------------------------------------*/
PYBIND11_MODULE(sample_cores, m)
{
......@@ -137,7 +158,9 @@ PYBIND11_MODULE(sample_cores, m)
.def("neighbor_sample_from_nodes",
&neighbor_sample_from_nodes)
.def("get_neighbors",
&get_neighbors);
&get_neighbors)
.def("heads_unique",
&heads_unique);
py::class_<TemporalGraphBlock>(m, "TemporalGraphBlock")
.def(py::init<std::vector<NodeIDType> &, std::vector<NodeIDType> &,
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment