Commit 3e80e491 by XXX

update Sample v3 with parallel

parent 05e596e1
...@@ -8,20 +8,20 @@ class BaseSampler(ABC): ...@@ -8,20 +8,20 @@ class BaseSampler(ABC):
r"""An abstract base class that initializes a graph sampler and provides r"""An abstract base class that initializes a graph sampler and provides
:meth:`sample_from_nodes` and :meth:`sample_from_edges` routines. :meth:`sample_from_nodes` and :meth:`sample_from_edges` routines.
""" """
def sample_from_nodes( def sample_from_node(
self, self,
node:int, node:int,
edge_index:torch.Tensor, edge_index:torch.Tensor,
num_nodes:int, num_nodes:int,
**kwargs **kwargs
) -> Tuple[torch.tensor, torch.tensor]: ) -> Tuple[torch.tensor, torch.tensor]:
r"""Performs sampling from the nodes specified in: node, r"""Performs sampling from the node specified in: node,
returning a sampled subgraph in the specified output format: Tuple[int, torch.tensor]. returning a sampled subgraph in the specified output format: Tuple[int, torch.tensor].
Args: Args:
node: the seed node index node: the seed node index
edge_index: edges in the graph edge_index: edges in the graph
num_nodes: the num of all num_nodes: the num of all node in the graph
**kwargs: other kwargs **kwargs: other kwargs
Returns: Returns:
samples_nodes: the node sampled samples_nodes: the node sampled
...@@ -36,13 +36,13 @@ class BaseSampler(ABC): ...@@ -36,13 +36,13 @@ class BaseSampler(ABC):
num_nodes:int, num_nodes:int,
**kwargs **kwargs
) -> Tuple[torch.Tensor, torch.tensor]: ) -> Tuple[torch.Tensor, torch.tensor]:
r"""Performs sampling from the nodes specified in: node, r"""Performs sampling from the nodes specified in: nodes,
returning a sampled subgraph in the specified output format: Tuple[int, torch.tensor]. returning a sampled subgraph in the specified output format: Tuple[int, torch.tensor].
Args: Args:
nodes: the seed nodes index nodes: the seed nodes index
edge_index: edges in the graph edge_index: edges in the graph
num_nodes: the num of all num_nodes: the num of all node in the graph
**kwargs: other kwargs **kwargs: other kwargs
Returns: Returns:
samples_nodes: the node sampled samples_nodes: the node sampled
...@@ -50,6 +50,29 @@ class BaseSampler(ABC): ...@@ -50,6 +50,29 @@ class BaseSampler(ABC):
""" """
raise NotImplementedError raise NotImplementedError
def sample_from_nodes_parallel(
self,
nodes: torch.Tensor,
edge_index: torch.Tensor,
num_nodes: int,
workers: int,
fanout: int
) -> Tuple[torch.Tensor, torch.tensor]:
r"""Performs sampling paralleled from the nodes specified in: nodes,
returning a sampled subgraph in the specified output format: Tuple[int, torch.tensor].
Args:
node: the seed node index
edge_index: edges in the graph
num_nodes: the num of all node in the graph
workers: the number of threads
fanout: the number of max neighbor chosen
Returns:
nodes: the node sampled
edge_index: the edge sampled
"""
raise NotImplementedError
class NeighborSampler(BaseSampler): class NeighborSampler(BaseSampler):
def __init__(self) -> None: def __init__(self) -> None:
super().__init__() super().__init__()
...@@ -61,13 +84,13 @@ class NeighborSampler(BaseSampler): ...@@ -61,13 +84,13 @@ class NeighborSampler(BaseSampler):
num_nodes: int, num_nodes: int,
fanout: int fanout: int
) -> Tuple[torch.Tensor, torch.tensor]: ) -> Tuple[torch.Tensor, torch.tensor]:
r"""Performs sampling from the nodes specified in: node, r"""Performs sampling from the node specified in: node,
returning a sampled subgraph in the specified output format: Tuple[int, torch.tensor]. returning a sampled subgraph in the specified output format: Tuple[int, torch.tensor].
Args: Args:
node: the seed node index node: the seed node index
edge_index: edges in the graph edge_index: edges in the graph
num_nodes: the num of all node num_nodes: the num of all node in the graph
fanout: the number of max neighbor chosen fanout: the number of max neighbor chosen
Returns: Returns:
samples_nodes: the node sampled samples_nodes: the node sampled
...@@ -93,13 +116,47 @@ class NeighborSampler(BaseSampler): ...@@ -93,13 +116,47 @@ class NeighborSampler(BaseSampler):
num_nodes: int, num_nodes: int,
fanout: int fanout: int
) -> Tuple[torch.Tensor, torch.tensor]: ) -> Tuple[torch.Tensor, torch.tensor]:
r"""Performs sampling from the nodes specified in: node, r"""Performs sampling from the nodes specified in: nodes,
returning a sampled subgraph in the specified output format: Tuple[int, torch.tensor].
Args:
nodes: the seed nodes index
edge_index: edges in the graph
num_nodes: the num of all node in the graph
**kwargs: other kwargs
Returns:
samples_nodes: the node sampled
edge_index: the edge sampled
"""
if len(nodes)==1:
return self.sample_from_node(nodes[0], edge_index, num_nodes, fanout)
samples_nodes=torch.IntTensor([])
row=torch.IntTensor([])
col=torch.IntTensor([])
# 单线程循环法:
for node in nodes:
samples_nodes_i,edge_index_i = self.sample_from_node(node, edge_index, num_nodes, fanout)
samples_nodes=torch.unique(torch.concat([samples_nodes,samples_nodes_i]))
row=torch.concat([row,edge_index_i[0]])
col=torch.concat([col,edge_index_i[1]])
return samples_nodes, torch.stack([row, col], dim=0)
def sample_from_nodes_parallel(
self,
nodes: torch.Tensor,
edge_index: torch.Tensor,
num_nodes: int,
workers: int,
fanout: int
) -> Tuple[torch.Tensor, torch.tensor]:
r"""Performs sampling from the nodes specified in: nodes,
returning a sampled subgraph in the specified output format: Tuple[int, torch.tensor]. returning a sampled subgraph in the specified output format: Tuple[int, torch.tensor].
Args: Args:
node: the seed node index node: the seed node index
edge_index: edges in the graph edge_index: edges in the graph
num_nodes: the num of all node num_nodes: the num of all node in the graph
workers: the number of threads
fanout: the number of max neighbor chosen fanout: the number of max neighbor chosen
Returns: Returns:
nodes: the node sampled nodes: the node sampled
...@@ -109,20 +166,28 @@ class NeighborSampler(BaseSampler): ...@@ -109,20 +166,28 @@ class NeighborSampler(BaseSampler):
row=torch.IntTensor([]) row=torch.IntTensor([])
col=torch.IntTensor([]) col=torch.IntTensor([])
with mp.Pool(processes=torch.get_num_threads()) as p: with mp.Pool(processes=torch.get_num_threads()) as p:
n=len(nodes)
if(workers>=n):
results = [p.apply_async(self.sample_from_node, results = [p.apply_async(self.sample_from_node,
(node, edge_index, num_nodes, fanout)) (node, edge_index, num_nodes, fanout))
for node in nodes] for node in nodes]
else:
quotient = n//workers
remainder = n%workers
# 每个batch先分配quotient个nodes,然后将余数remainder平均分配给其中一些batch
nodes1 = nodes[0:(quotient+1)*(remainder)].resize_(remainder,quotient+1)# 分配了余数的batch
nodes2 = nodes[(quotient+1)*(remainder):n].resize_(workers - remainder,quotient)# 未分配余数的batch
results = [p.apply_async(self.sample_from_nodes,
(nodes1[i], edge_index, num_nodes, fanout))
for i in range(0, remainder)]
results.extend([p.apply_async(self.sample_from_nodes,
(nodes2[i], edge_index, num_nodes, fanout))
for i in range(0, workers - remainder)])
for result in results: for result in results:
samples_nodes_i, edge_index_i = result.get() samples_nodes_i, edge_index_i = result.get()
samples_nodes = torch.unique(torch.cat([samples_nodes, samples_nodes_i])) samples_nodes = torch.unique(torch.cat([samples_nodes, samples_nodes_i]))
row = torch.cat([row, edge_index_i[0]]) row = torch.cat([row, edge_index_i[0]])
col = torch.cat([col, edge_index_i[1]]) col = torch.cat([col, edge_index_i[1]])
# 单线程循环法:
# for node in nodes:
# samples_nodes_i,edge_index_i = self.sample_from_node(node, edge_index, num_nodes, fanout)
# samples_nodes=torch.unique(torch.concat([samples_nodes,samples_nodes_i]))
# row=torch.concat([row,edge_index_i[0]])
# col=torch.concat([col,edge_index_i[1]])
return samples_nodes, torch.stack([row, col], dim=0) return samples_nodes, torch.stack([row, col], dim=0)
# 不使用sample_from_node直接取所有点邻居方法: # 不使用sample_from_node直接取所有点邻居方法:
...@@ -131,3 +196,17 @@ class NeighborSampler(BaseSampler): ...@@ -131,3 +196,17 @@ class NeighborSampler(BaseSampler):
# neighbors2=torch.concat([col[row==nodes[i]] for i in range(0, nodes.shape[0])]) # neighbors2=torch.concat([col[row==nodes[i]] for i in range(0, nodes.shape[0])])
# neighbors=torch.stack([neighbors1, neighbors2], dim=0) # neighbors=torch.stack([neighbors1, neighbors2], dim=0)
# print('neighbors: \n', neighbors) # print('neighbors: \n', neighbors)
if __name__=="__main__":
edge_index = torch.tensor([[0, 1, 1, 1, 2, 2, 2, 3, 3, 4, 4, 4, 5], [1, 0, 2, 4, 1, 3, 0, 2, 5, 3, 5, 0, 2]])
num_nodes = 6
num_neighbors = 2
# Run the neighbor sampling
sampler=NeighborSampler()
# neighbor_nodes, edge_index = sampler.sample_from_node(2, edge_index, num_nodes, num_neighbors)
# neighbor_nodes, edge_index = sampler.sample_from_nodes(torch.tensor([1,2]), edge_index, num_nodes, num_neighbors)
# neighbor_nodes, edge_index = sampler.sample_from_nodes_parallel(torch.tensor([1,2,3]), edge_index, num_nodes, workers=3, fanout=num_neighbors)
neighbor_nodes, edge_index = sampler.sample_from_nodes_parallel(torch.tensor([1,2,3,4,5]), edge_index, num_nodes, workers=4, fanout=num_neighbors)
# Print the result
print('neighbor_nodes_id: \n',neighbor_nodes, '\nedge_index: \n',edge_index)
import torch import torch
from Sampler import NeighborSampler from Sampler import NeighborSampler
edge_index = torch.tensor([[0, 1, 1, 2, 2, 2, 3], [1, 0, 2, 1, 3, 0, 2]]) # edge_index = torch.tensor([[0, 1, 1, 2, 2, 2, 3], [1, 0, 2, 1, 3, 0, 2]])
num_nodes = 4 # num_nodes = 4
num_neighbors = 1 edge_index = torch.tensor([[0, 1, 1, 1, 2, 2, 2, 3, 3, 4, 4, 4, 5], [1, 0, 2, 4, 1, 3, 0, 2, 5, 3, 5, 0, 2]])
num_nodes = 6
num_neighbors = 2
# Run the neighbor sampling # Run the neighbor sampling
sampler=NeighborSampler() sampler=NeighborSampler()
# neighbor_nodes, edge_index = sampler.sample_from_node(2, edge_index, num_nodes, num_neighbors)
# neighbor_nodes, edge_index = sampler.sample_from_nodes(torch.tensor([1,2]), edge_index, num_nodes, num_neighbors)
# neighbors, edge_index = sampler.sample_from_node(2, edge_index, num_nodes, num_neighbors) # neighbor_nodes, edge_index = sampler.sample_from_nodes_parallel(torch.tensor([1,2]), edge_index, num_nodes, workers=1, fanout=num_neighbors)
neighbors, edge_index = sampler.sample_from_nodes(torch.tensor([1,2]), edge_index, num_nodes, num_neighbors) neighbor_nodes, edge_index = sampler.sample_from_nodes_parallel(torch.tensor([1,2,3,4,5]), edge_index, num_nodes, workers=4, fanout=num_neighbors)
# Print the result # Print the result
print('neighbor_nodes_id: \n',neighbors, '\nedge_index: \n',edge_index) print('neighbor_nodes_id: \n',neighbor_nodes, '\nedge_index: \n',edge_index)
# import torch_scatter # import torch_scatter
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment