Commit c6330899 by XXX

update Sample v5: add multi-layer sample

parent aee067ba
...@@ -4,30 +4,36 @@ import torch_scatter ...@@ -4,30 +4,36 @@ import torch_scatter
import torch.multiprocessing as mp import torch.multiprocessing as mp
from abc import ABC from abc import ABC
from Sample.sample_cores import neighbor_sample_from_node, TemporalGraphBlock from sample_cores import neighbor_sample_from_node
class BaseSampler(ABC): class BaseSampler(ABC):
r"""An abstract base class that initializes a graph sampler and provides r"""An abstract base class that initializes a graph sampler and provides
:meth:`sample_from_nodes` and :meth:`sample_from_edges` routines. :meth:`sample_from_node`
:meth:`sample_from_nodes`
:meth:`sample_from_nodes_parallel`
:meth:`sample_mutilayer_from_nodes` routines.
""" """
def __init__(self) -> None:
super().__init__()
def sample_from_node( def sample_from_node(
self, self,
node:int, node:int,
edge_index:torch.Tensor, edge_index:torch.Tensor,
num_nodes:int, num_nodes:int,
**kwargs **kwargs
) -> Tuple[torch.tensor, torch.tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
r"""Performs sampling from the node specified in: node, r"""Performs sampling from the node specified in: node,
returning a sampled subgraph in the specified output format: Tuple[int, torch.tensor]. returning a sampled subgraph in the specified output format: Tuple[torch.Tensor, torch.Tensor].
Args: Args:
node: the seed node index node: the seed node index
edge_index: edges in the graph edge_index: all edges in the graph
num_nodes: the num of all node in the graph num_nodes: the num of all nodes in the graph
**kwargs: other kwargs **kwargs: other kwargs
Returns: Returns:
samples_nodes: the node sampled sampled_nodes: the nodes sampled
edge_index: the edge sampled sampled_edge_index: the edges sampled
""" """
raise NotImplementedError raise NotImplementedError
...@@ -37,18 +43,18 @@ class BaseSampler(ABC): ...@@ -37,18 +43,18 @@ class BaseSampler(ABC):
edge_index:torch.Tensor, edge_index:torch.Tensor,
num_nodes:int, num_nodes:int,
**kwargs **kwargs
) -> Tuple[torch.Tensor, torch.tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
r"""Performs sampling from the nodes specified in: nodes, r"""Performs sampling from the nodes specified in: nodes,
returning a sampled subgraph in the specified output format: Tuple[int, torch.tensor]. returning a sampled subgraph in the specified output format: Tuple[torch.Tensor, torch.Tensor].
Args: Args:
nodes: the seed nodes index nodes: the list of seed nodes index
edge_index: edges in the graph edge_index: all edges in the graph
num_nodes: the num of all node in the graph num_nodes: the num of all nodes in the graph
**kwargs: other kwargs **kwargs: other kwargs
Returns: Returns:
samples_nodes: the node sampled sampled_nodes: the nodes sampled
edge_index: the edge sampled sampled_edge_index: the edges sampled
""" """
raise NotImplementedError raise NotImplementedError
...@@ -59,24 +65,49 @@ class BaseSampler(ABC): ...@@ -59,24 +65,49 @@ class BaseSampler(ABC):
num_nodes: int, num_nodes: int,
workers: int, workers: int,
**kwargs **kwargs
) -> Tuple[torch.Tensor, torch.tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
r"""Performs sampling paralleled from the nodes specified in: nodes, r"""Performs sampling paralleled from the nodes specified in: nodes,
returning a sampled subgraph in the specified output format: Tuple[int, torch.tensor]. returning a sampled subgraph in the specified output format: Tuple[torch.Tensor, torch.Tensor].
Args: Args:
node: the seed node index nodes: the list of seed nodes index
edge_index: edges in the graph edge_index: all edges in the graph
num_nodes: the num of all node in the graph num_nodes: the num of all nodes in the graph
workers: the number of threads workers: the number of threads
**kwargs: other kwargs **kwargs: other kwargs
Returns: Returns:
nodes: the node sampled sampled_nodes: the nodes sampled
edge_index: the edge sampled sampled_edge_index: the edges sampled
"""
raise NotImplementedError
def sample_mutilayer_from_nodes(
self,
nodes: torch.Tensor,
edge_index: torch.Tensor,
num_nodes: int,
num_layers: int,
workers: int,
**kwargs
) -> Tuple[torch.Tensor, list]:
r"""Performs mutilayer sampling from the nodes specified in: nodes
The specific number of layers is determined by parameter: num_layers
returning a sampled subgraph in the specified output format: Tuple[torch.Tensor, list].
Args:
nodes: the list of seed nodes index
edge_index: all edges in the graph
num_nodes: the num of all nodes in the graph
num_layers: the num of layers to be sampled
workers: the number of threads
**kwargs: other kwargs
Returns:
sampled_nodes: the nodes sampled
sampled_edge_index: the edges sampled
""" """
raise NotImplementedError raise NotImplementedError
class NeighborSampler(BaseSampler): class NeighborSampler(BaseSampler):
def __init__(self) -> None: def __init__(self, edge_index, num_nodes) -> None:
super().__init__() super().__init__()
def sample_from_node( def sample_from_node(
...@@ -85,18 +116,18 @@ class NeighborSampler(BaseSampler): ...@@ -85,18 +116,18 @@ class NeighborSampler(BaseSampler):
edge_index: torch.Tensor, edge_index: torch.Tensor,
num_nodes: int, num_nodes: int,
fanout: int fanout: int
) -> Tuple[torch.Tensor, torch.tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
r"""Performs sampling from the node specified in: node, r"""Performs sampling from the node specified in: node,
returning a sampled subgraph in the specified output format: Tuple[int, torch.tensor]. returning a sampled subgraph in the specified output format: Tuple[torch.Tensor, torch.Tensor].
Args: Args:
node: the seed node index node: the seed node index
edge_index: edges in the graph edge_index: all edges in the graph
num_nodes: the num of all node in the graph num_nodes: the num of all nodes in the graph
fanout: the number of max neighbor chosen fanout: the number of max neighbor chosen
Returns: Returns:
samples_nodes: the node sampled sampled_nodes: the nodes sampled
edge_index: the edge sampled sampled_edge_index: the edges sampled
""" """
row, col = edge_index row, col = edge_index
row = row.numpy().tolist() row = row.numpy().tolist()
...@@ -104,8 +135,8 @@ class NeighborSampler(BaseSampler): ...@@ -104,8 +135,8 @@ class NeighborSampler(BaseSampler):
tgb = neighbor_sample_from_node(node, row, col, num_nodes, fanout) tgb = neighbor_sample_from_node(node, row, col, num_nodes, fanout)
row = torch.IntTensor(tgb.row()) row = torch.IntTensor(tgb.row())
col = torch.IntTensor(tgb.col()) col = torch.IntTensor(tgb.col())
samples_nodes = torch.IntTensor(tgb.nodes()) sampled_nodes = torch.IntTensor(tgb.nodes())
return samples_nodes, torch.stack([row, col], dim=0) return sampled_nodes, torch.stack([row, col], dim=0)
def sample_from_nodes( def sample_from_nodes(
self, self,
...@@ -113,31 +144,31 @@ class NeighborSampler(BaseSampler): ...@@ -113,31 +144,31 @@ class NeighborSampler(BaseSampler):
edge_index: torch.Tensor, edge_index: torch.Tensor,
num_nodes: int, num_nodes: int,
fanout: int fanout: int
) -> Tuple[torch.Tensor, torch.tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
r"""Performs sampling from the nodes specified in: nodes, r"""Performs sampling from the nodes specified in: nodes,
returning a sampled subgraph in the specified output format: Tuple[int, torch.tensor]. returning a sampled subgraph in the specified output format: Tuple[torch.Tensor, torch.Tensor].
Args: Args:
nodes: the seed nodes index nodes: the list of seed nodes index
edge_index: edges in the graph edge_index: all edges in the graph
num_nodes: the num of all node in the graph num_nodes: the num of all nodes in the graph
**kwargs: other kwargs **kwargs: other kwargs
Returns: Returns:
samples_nodes: the node sampled sampled_nodes: the nodes sampled
edge_index: the edge sampled sampled_edge_index: the edges sampled
""" """
if len(nodes)==1: if len(nodes)==1:
return self.sample_from_node(nodes[0], edge_index, num_nodes, fanout) return self.sample_from_node(nodes[0], edge_index, num_nodes, fanout)
samples_nodes=torch.IntTensor([]) sampled_nodes=torch.IntTensor([])
row=torch.IntTensor([]) row=torch.IntTensor([])
col=torch.IntTensor([]) col=torch.IntTensor([])
# 单线程循环法: # 单线程循环法:
for node in nodes: for node in nodes:
samples_nodes_i,edge_index_i = self.sample_from_node(node, edge_index, num_nodes, fanout) sampled_nodes_i,sampled_edge_index_i = self.sample_from_node(node, edge_index, num_nodes, fanout)
samples_nodes=torch.unique(torch.concat([samples_nodes,samples_nodes_i])) sampled_nodes=torch.unique(torch.concat([sampled_nodes,sampled_nodes_i]))
row=torch.concat([row,edge_index_i[0]]) row=torch.concat([row,sampled_edge_index_i[0]])
col=torch.concat([col,edge_index_i[1]]) col=torch.concat([col,sampled_edge_index_i[1]])
return samples_nodes, torch.stack([row, col], dim=0) return sampled_nodes, torch.stack([row, col], dim=0)
def sample_from_nodes_parallel( def sample_from_nodes_parallel(
self, self,
...@@ -146,23 +177,24 @@ class NeighborSampler(BaseSampler): ...@@ -146,23 +177,24 @@ class NeighborSampler(BaseSampler):
num_nodes: int, num_nodes: int,
workers: int, workers: int,
fanout: int fanout: int
) -> Tuple[torch.Tensor, torch.tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
r"""Performs sampling from the nodes specified in: nodes, r"""Performs sampling from the nodes specified in: nodes,
returning a sampled subgraph in the specified output format: Tuple[int, torch.tensor]. returning a sampled subgraph in the specified output format: Tuple[torch.Tensor, torch.Tensor].
Args: Args:
node: the seed node index nodes: the list of seed nodes index
edge_index: edges in the graph edge_index: all edges in the graph
num_nodes: the num of all node in the graph num_nodes: the num of all nodes in the graph
workers: the number of threads workers: the number of threads
fanout: the number of max neighbor chosen fanout: the number of max neighbor chosen
Returns: Returns:
nodes: the node sampled sampled_nodes: the node sampled
edge_index: the edge sampled sampled_edge_index: the edge sampled
""" """
samples_nodes=torch.IntTensor([]) sampled_nodes=torch.IntTensor([])
row=torch.IntTensor([]) row=torch.IntTensor([])
col=torch.IntTensor([]) col=torch.IntTensor([])
assert workers > 0, 'Workers should be positive integer!!!'
with mp.Pool(processes=torch.get_num_threads()) as p: with mp.Pool(processes=torch.get_num_threads()) as p:
n=len(nodes) n=len(nodes)
if(workers>=n): if(workers>=n):
...@@ -182,11 +214,11 @@ class NeighborSampler(BaseSampler): ...@@ -182,11 +214,11 @@ class NeighborSampler(BaseSampler):
(nodes2[i], edge_index, num_nodes, fanout)) (nodes2[i], edge_index, num_nodes, fanout))
for i in range(0, workers - remainder)]) for i in range(0, workers - remainder)])
for result in results: for result in results:
samples_nodes_i, edge_index_i = result.get() sampled_nodes_i, sampled_edge_index_i = result.get()
samples_nodes = torch.unique(torch.cat([samples_nodes, samples_nodes_i])) sampled_nodes = torch.unique(torch.cat([sampled_nodes, sampled_nodes_i]))
row = torch.cat([row, edge_index_i[0]]) row = torch.cat([row, sampled_edge_index_i[0]])
col = torch.cat([col, edge_index_i[1]]) col = torch.cat([col, sampled_edge_index_i[1]])
return samples_nodes, torch.stack([row, col], dim=0) return sampled_nodes, torch.stack([row, col], dim=0)
# 不使用sample_from_node直接取所有点邻居方法: # 不使用sample_from_node直接取所有点邻居方法:
# row, col = edge_index # row, col = edge_index
...@@ -195,16 +227,51 @@ class NeighborSampler(BaseSampler): ...@@ -195,16 +227,51 @@ class NeighborSampler(BaseSampler):
# neighbors=torch.stack([neighbors1, neighbors2], dim=0) # neighbors=torch.stack([neighbors1, neighbors2], dim=0)
# print('neighbors: \n', neighbors) # print('neighbors: \n', neighbors)
def sample_mutilayer_from_nodes(
    self,
    nodes: torch.Tensor,
    edge_index: torch.Tensor,
    num_nodes: int,
    num_layers: int,
    fanout: list,
    workers: int = 1,
) -> Tuple[torch.Tensor, list]:
    r"""Performs multi-layer sampling starting from the seeds in: nodes.

    Layer ``i`` is sampled with :meth:`sample_from_nodes_parallel` using
    ``fanout[i]``; the unique values in the second row of that layer's
    sampled edge index become the seeds of layer ``i + 1``.

    Args:
        nodes: the seed node indices of the first layer
        edge_index: all edges in the graph
        num_nodes: the num of all nodes in the graph
        num_layers: the num of layers to be sampled
        fanout: the max number of neighbors chosen for each layer; needs
            at least ``num_layers`` entries
        workers: forwarded to :meth:`sample_from_nodes_parallel`

    Returns:
        sampled_nodes: the sorted unique nodes sampled over all layers
        sampled_edge_index_list: one sampled edge index per layer, in
            layer order

    Raises:
        ValueError: if ``workers`` is not positive or ``fanout`` has fewer
            than ``num_layers`` entries.
    """
    # Validate up front with real exceptions: ``assert`` disappears under
    # ``python -O`` and a short ``fanout`` would otherwise only fail with an
    # IndexError after earlier layers had already been sampled.
    if workers <= 0:
        raise ValueError('Workers should be positive integer!!!')
    if len(fanout) < num_layers:
        raise ValueError(
            f'fanout has {len(fanout)} entries but num_layers is {num_layers}'
        )

    sampled_edge_index_list = []
    # Seeded with an empty IntTensor so the zero-layer case still returns an
    # empty IntTensor and ``torch.cat`` never sees an empty list.
    node_parts = [torch.IntTensor([])]
    for layer in range(num_layers):
        sampled_nodes_i, sampled_edge_index_i = self.sample_from_nodes_parallel(
            nodes, edge_index, num_nodes, workers, fanout[layer]
        )
        # Next layer expands from the endpoints stored in row 1.
        nodes = torch.unique(sampled_edge_index_i[1])
        node_parts.append(sampled_nodes_i)
        sampled_edge_index_list.append(sampled_edge_index_i)

    # Deduplicate once at the end instead of once per layer.
    sampled_nodes = torch.unique(torch.cat(node_parts))
    return sampled_nodes, sampled_edge_index_list
if __name__=="__main__": if __name__=="__main__":
edge_index = torch.tensor([[0, 1, 1, 1, 2, 2, 2, 3, 3, 4, 4, 4, 5], [1, 0, 2, 4, 1, 3, 0, 2, 5, 3, 5, 0, 2]]) edge_index = torch.tensor([[0, 1, 1, 1, 2, 2, 2, 3, 3, 4, 4, 4, 5], [1, 0, 2, 4, 1, 3, 0, 2, 5, 3, 5, 0, 2]])
num_nodes = 6 num_nodes = 6
num_neighbors = 2 num_neighbors = 2
# Run the neighbor sampling # Run the neighbor sampling
sampler=NeighborSampler() sampler=NeighborSampler()
# neighbor_nodes, edge_index = sampler.sample_from_node(2, edge_index, num_nodes, num_neighbors) # neighbor_nodes, sampled_edge_index = sampler.sample_from_node(2, edge_index, num_nodes, num_neighbors)
# neighbor_nodes, edge_index = sampler.sample_from_nodes(torch.tensor([1,2]), edge_index, num_nodes, num_neighbors) # neighbor_nodes, sampled_edge_index = sampler.sample_from_nodes(torch.tensor([1,2]), edge_index, num_nodes, num_neighbors)
# neighbor_nodes, edge_index = sampler.sample_from_nodes_parallel(torch.tensor([1,2,3]), edge_index, num_nodes, workers=3, fanout=num_neighbors) # neighbor_nodes, sampled_edge_index = sampler.sample_from_nodes_parallel(torch.tensor([1,2,3]), edge_index, num_nodes, workers=1, fanout=num_neighbors)
neighbor_nodes, edge_index = sampler.sample_from_nodes_parallel(torch.tensor([1,2,3,4,5]), edge_index, num_nodes, workers=4, fanout=num_neighbors) # neighbor_nodes, sampled_edge_index = sampler.sample_from_nodes_parallel(torch.tensor([1,2,3,4,5]), edge_index, num_nodes, workers=4, fanout=num_neighbors)
neighbor_nodes, sampled_edge_index = sampler.sample_mutilayer_from_nodes(torch.tensor([1,2]), edge_index, num_nodes, 2, workers=1, fanout=[2,1])
# Print the result # Print the result
print('neighbor_nodes_id: \n',neighbor_nodes, '\nedge_index: \n',edge_index) print('neighbor_nodes_id: \n',neighbor_nodes, '\nedge_index: \n',sampled_edge_index)
from typing import Tuple
import torch
import torch_scatter
import torch.multiprocessing as mp
from abc import ABC
from Sample.sample_cores import neighbor_sample_from_node, TemporalGraphBlock
class BaseSampler(ABC):
    r"""An abstract base class that defines the interface of a graph sampler.

    Subclasses are expected to override
    :meth:`sample_from_node`,
    :meth:`sample_from_nodes` and
    :meth:`sample_from_nodes_parallel`; every method here raises
    :class:`NotImplementedError`.
    """

    def sample_from_node(
        self,
        node: int,
        edge_index: torch.Tensor,
        num_nodes: int,
        **kwargs
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        r"""Performs sampling from the node specified in: node,
        returning a sampled subgraph in the specified output format:
        Tuple[torch.Tensor, torch.Tensor].

        Args:
            node: the seed node index
            edge_index: all edges in the graph
            num_nodes: the num of all nodes in the graph
            **kwargs: other kwargs

        Returns:
            sampled_nodes: the nodes sampled
            sampled_edge_index: the edges sampled
        """
        raise NotImplementedError

    def sample_from_nodes(
        self,
        nodes: torch.Tensor,
        edge_index: torch.Tensor,
        num_nodes: int,
        **kwargs
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        r"""Performs sampling from the nodes specified in: nodes,
        returning a sampled subgraph in the specified output format:
        Tuple[torch.Tensor, torch.Tensor].

        Args:
            nodes: the seed node indices
            edge_index: all edges in the graph
            num_nodes: the num of all nodes in the graph
            **kwargs: other kwargs

        Returns:
            sampled_nodes: the nodes sampled
            sampled_edge_index: the edges sampled
        """
        raise NotImplementedError

    def sample_from_nodes_parallel(
        self,
        nodes: torch.Tensor,
        edge_index: torch.Tensor,
        num_nodes: int,
        workers: int,
        **kwargs
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        r"""Performs parallel sampling from the nodes specified in: nodes,
        returning a sampled subgraph in the specified output format:
        Tuple[torch.Tensor, torch.Tensor].

        Args:
            nodes: the seed node indices
            edge_index: all edges in the graph
            num_nodes: the num of all nodes in the graph
            workers: the number of parallel workers
            **kwargs: other kwargs

        Returns:
            sampled_nodes: the nodes sampled
            sampled_edge_index: the edges sampled
        """
        raise NotImplementedError
class NeighborSampler(BaseSampler):
    r"""Neighbor sampler built on top of ``neighbor_sample_from_node``.

    The per-node work is delegated to ``neighbor_sample_from_node``; this
    class converts between tensors and the plain lists that call takes, and
    merges per-seed results into one node set and one edge index.
    """

    def __init__(self) -> None:
        super().__init__()

    @staticmethod
    def _merge_samples(samples) -> Tuple[torch.Tensor, torch.Tensor]:
        """Merge ``(sampled_nodes, sampled_edge_index)`` pairs into one pair."""
        # Each list starts with an empty IntTensor so that an empty ``samples``
        # still yields empty IntTensor outputs and ``torch.cat`` never receives
        # an empty list.
        node_parts = [torch.IntTensor([])]
        row_parts = [torch.IntTensor([])]
        col_parts = [torch.IntTensor([])]
        for sampled_nodes_i, sampled_edge_index_i in samples:
            node_parts.append(sampled_nodes_i)
            row_parts.append(sampled_edge_index_i[0])
            col_parts.append(sampled_edge_index_i[1])
        # One concatenate + one unique instead of repeating both per result;
        # the final sorted unique set is the same.
        sampled_nodes = torch.unique(torch.cat(node_parts))
        sampled_edge_index = torch.stack(
            [torch.cat(row_parts), torch.cat(col_parts)], dim=0
        )
        return sampled_nodes, sampled_edge_index

    def sample_from_node(
        self,
        node: int,
        edge_index: torch.Tensor,
        num_nodes: int,
        fanout: int
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        r"""Performs sampling from the node specified in: node,
        returning a sampled subgraph in the specified output format:
        Tuple[torch.Tensor, torch.Tensor].

        Args:
            node: the seed node index (a plain int or a one-element tensor)
            edge_index: all edges in the graph, unpacked into two rows
            num_nodes: the num of all nodes in the graph
            fanout: the number of max neighbor chosen

        Returns:
            sampled_nodes: ``nodes()`` of the sampling result as an IntTensor
            sampled_edge_index: ``row()`` and ``col()`` of the sampling
                result stacked into a two-row IntTensor
        """
        row, col = edge_index
        # Callers in this class pass elements of a seed tensor; normalise to
        # a plain int before handing it to the external sampling routine.
        # NOTE(review): the argument types accepted by
        # ``neighbor_sample_from_node`` are not visible here -- confirm.
        tgb = neighbor_sample_from_node(
            int(node), row.tolist(), col.tolist(), num_nodes, fanout
        )
        sampled_row = torch.IntTensor(tgb.row())
        sampled_col = torch.IntTensor(tgb.col())
        sampled_nodes = torch.IntTensor(tgb.nodes())
        return sampled_nodes, torch.stack([sampled_row, sampled_col], dim=0)

    def sample_from_nodes(
        self,
        nodes: torch.Tensor,
        edge_index: torch.Tensor,
        num_nodes: int,
        fanout: int
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        r"""Performs sampling from the nodes specified in: nodes,
        returning a sampled subgraph in the specified output format:
        Tuple[torch.Tensor, torch.Tensor].

        Args:
            nodes: the seed node indices
            edge_index: all edges in the graph
            num_nodes: the num of all nodes in the graph
            fanout: the number of max neighbor chosen per seed

        Returns:
            sampled_nodes: the nodes sampled (sorted and unique when more
                than one seed is given)
            sampled_edge_index: the edges sampled, concatenated in seed order
        """
        # A single seed is returned exactly as sample_from_node produced it.
        if len(nodes) == 1:
            return self.sample_from_node(nodes[0], edge_index, num_nodes, fanout)
        # Single-threaded loop over the seeds.
        return self._merge_samples(
            self.sample_from_node(node, edge_index, num_nodes, fanout)
            for node in nodes
        )

    def sample_from_nodes_parallel(
        self,
        nodes: torch.Tensor,
        edge_index: torch.Tensor,
        num_nodes: int,
        workers: int,
        fanout: int
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        r"""Performs sampling from the nodes specified in: nodes, using a
        process pool, returning a sampled subgraph in the specified output
        format: Tuple[torch.Tensor, torch.Tensor].

        Args:
            nodes: the seed node indices
            edge_index: all edges in the graph
            num_nodes: the num of all nodes in the graph
            workers: the max number of worker processes / seed batches
            fanout: the number of max neighbor chosen per seed

        Returns:
            sampled_nodes: the sorted unique nodes sampled
            sampled_edge_index: the edges sampled, concatenated in batch order

        Raises:
            ValueError: if ``workers`` is not a positive integer.
        """
        if workers <= 0:
            raise ValueError('Workers should be positive integer!!!')

        n = len(nodes)
        if n == 0:
            # Nothing to sample: skip starting a pool at all.
            return self._merge_samples([])

        # Never start more processes than there are batches to run.
        num_batches = min(workers, n)
        # tensor_split gives the first ``n % num_batches`` batches one extra
        # seed and the rest ``n // num_batches`` seeds, without resizing views
        # of the caller's tensor in place. When workers >= n every batch holds
        # one seed and sample_from_nodes forwards it to sample_from_node.
        batches = torch.tensor_split(nodes, num_batches)
        with mp.Pool(processes=num_batches) as pool:
            pending = [
                pool.apply_async(
                    self.sample_from_nodes,
                    (batch, edge_index, num_nodes, fanout),
                )
                for batch in batches
            ]
            # Collect inside the ``with`` block: leaving it terminates the pool.
            return self._merge_samples(task.get() for task in pending)

    # Alternative that gathers all neighbors of every seed directly, without
    # going through sample_from_node (kept for reference):
    # row, col = edge_index
    # neighbors1=torch.concat([row[row==nodes[i]] for i in range(0, nodes.shape[0])])
    # neighbors2=torch.concat([col[row==nodes[i]] for i in range(0, nodes.shape[0])])
    # neighbors=torch.stack([neighbors1, neighbors2], dim=0)
    # print('neighbors: \n', neighbors)
if __name__ == "__main__":
    # Small demo graph: a two-row edge index with 13 columns over 6 nodes.
    first_row = [0, 1, 1, 1, 2, 2, 2, 3, 3, 4, 4, 4, 5]
    second_row = [1, 0, 2, 4, 1, 3, 0, 2, 5, 3, 5, 0, 2]
    edge_index = torch.tensor([first_row, second_row])
    num_nodes = 6
    num_neighbors = 2

    # Run the neighbor sampling
    sampler = NeighborSampler()
    seeds = torch.tensor([1, 2, 3, 4, 5])
    # Other entry points that can be tried instead:
    # sampled_nodes, sampled_edges = sampler.sample_from_node(2, edge_index, num_nodes, num_neighbors)
    # sampled_nodes, sampled_edges = sampler.sample_from_nodes(torch.tensor([1,2]), edge_index, num_nodes, num_neighbors)
    # sampled_nodes, sampled_edges = sampler.sample_from_nodes_parallel(torch.tensor([1,2,3]), edge_index, num_nodes, workers=3, fanout=num_neighbors)
    sampled_nodes, sampled_edges = sampler.sample_from_nodes_parallel(
        seeds,
        edge_index,
        num_nodes,
        workers=4,
        fanout=num_neighbors,
    )

    # Print the result
    print('neighbor_nodes_id: \n', sampled_nodes, '\nedge_index: \n', sampled_edges)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment