Commit 3e80e491 by XXX

update Sample v3 with parallel

parent 05e596e1
......@@ -8,20 +8,20 @@ class BaseSampler(ABC):
r"""An abstract base class that initializes a graph sampler and provides
:meth:`sample_from_nodes` and :meth:`sample_from_edges` routines.
"""
def sample_from_nodes(
def sample_from_node(
self,
node:int,
edge_index:torch.Tensor,
num_nodes:int,
**kwargs
) -> Tuple[torch.tensor, torch.tensor]:
r"""Performs sampling from the nodes specified in: node,
r"""Performs sampling from the node specified in: node,
returning a sampled subgraph in the specified output format: Tuple[int, torch.tensor].
Args:
node: the seed node index
edge_index: edges in the graph
num_nodes: the num of all
num_nodes: the num of all node in the graph
**kwargs: other kwargs
Returns:
samples_nodes: the node sampled
......@@ -36,19 +36,42 @@ class BaseSampler(ABC):
num_nodes:int,
**kwargs
) -> Tuple[torch.Tensor, torch.tensor]:
r"""Performs sampling from the nodes specified in: node,
r"""Performs sampling from the nodes specified in: nodes,
returning a sampled subgraph in the specified output format: Tuple[int, torch.tensor].
Args:
nodes: the seed nodes index
edge_index: edges in the graph
num_nodes: the num of all
num_nodes: the num of all node in the graph
**kwargs: other kwargs
Returns:
samples_nodes: the node sampled
edge_index: the edge sampled
"""
raise NotImplementedError
def sample_from_nodes_parallel(
self,
nodes: torch.Tensor,
edge_index: torch.Tensor,
num_nodes: int,
workers: int,
fanout: int
) -> Tuple[torch.Tensor, torch.tensor]:
r"""Performs sampling paralleled from the nodes specified in: nodes,
returning a sampled subgraph in the specified output format: Tuple[int, torch.tensor].
Args:
node: the seed node index
edge_index: edges in the graph
num_nodes: the num of all node in the graph
workers: the number of threads
fanout: the number of max neighbor chosen
Returns:
nodes: the node sampled
edge_index: the edge sampled
"""
raise NotImplementedError
class NeighborSampler(BaseSampler):
def __init__(self) -> None:
......@@ -61,13 +84,13 @@ class NeighborSampler(BaseSampler):
num_nodes: int,
fanout: int
) -> Tuple[torch.Tensor, torch.tensor]:
r"""Performs sampling from the nodes specified in: node,
r"""Performs sampling from the node specified in: node,
returning a sampled subgraph in the specified output format: Tuple[int, torch.tensor].
Args:
node: the seed node index
edge_index: edges in the graph
num_nodes: the num of all node
num_nodes: the num of all node in the graph
fanout: the number of max neighbor chosen
Returns:
samples_nodes: the node sampled
......@@ -85,7 +108,7 @@ class NeighborSampler(BaseSampler):
edge_index = neighbors.index_select(dim = 1, index=random_index)
samples_nodes = torch.unique(edge_index.view(-1), dim=0)
return samples_nodes, edge_index
def sample_from_nodes(
self,
nodes: torch.Tensor,
......@@ -93,13 +116,47 @@ class NeighborSampler(BaseSampler):
num_nodes: int,
fanout: int
) -> Tuple[torch.Tensor, torch.tensor]:
r"""Performs sampling from the nodes specified in: node,
r"""Performs sampling from the nodes specified in: nodes,
returning a sampled subgraph in the specified output format: Tuple[int, torch.tensor].
Args:
nodes: the seed nodes index
edge_index: edges in the graph
num_nodes: the num of all node in the graph
**kwargs: other kwargs
Returns:
samples_nodes: the node sampled
edge_index: the edge sampled
"""
if len(nodes)==1:
return self.sample_from_node(nodes[0], edge_index, num_nodes, fanout)
samples_nodes=torch.IntTensor([])
row=torch.IntTensor([])
col=torch.IntTensor([])
# 单线程循环法:
for node in nodes:
samples_nodes_i,edge_index_i = self.sample_from_node(node, edge_index, num_nodes, fanout)
samples_nodes=torch.unique(torch.concat([samples_nodes,samples_nodes_i]))
row=torch.concat([row,edge_index_i[0]])
col=torch.concat([col,edge_index_i[1]])
return samples_nodes, torch.stack([row, col], dim=0)
def sample_from_nodes_parallel(
self,
nodes: torch.Tensor,
edge_index: torch.Tensor,
num_nodes: int,
workers: int,
fanout: int
) -> Tuple[torch.Tensor, torch.tensor]:
r"""Performs sampling from the nodes specified in: nodes,
returning a sampled subgraph in the specified output format: Tuple[int, torch.tensor].
Args:
node: the seed node index
edge_index: edges in the graph
num_nodes: the num of all node
num_nodes: the num of all node in the graph
workers: the number of threads
fanout: the number of max neighbor chosen
Returns:
nodes: the node sampled
......@@ -109,20 +166,28 @@ class NeighborSampler(BaseSampler):
row=torch.IntTensor([])
col=torch.IntTensor([])
with mp.Pool(processes=torch.get_num_threads()) as p:
results = [p.apply_async(self.sample_from_node,
(node, edge_index, num_nodes, fanout))
for node in nodes]
n=len(nodes)
if(workers>=n):
results = [p.apply_async(self.sample_from_node,
(node, edge_index, num_nodes, fanout))
for node in nodes]
else:
quotient = n//workers
remainder = n%workers
# 每个batch先分配quotient个nodes,然后将余数remainder平均分配给其中一些batch
nodes1 = nodes[0:(quotient+1)*(remainder)].resize_(remainder,quotient+1)# 分配了余数的batch
nodes2 = nodes[(quotient+1)*(remainder):n].resize_(workers - remainder,quotient)# 未分配余数的batch
results = [p.apply_async(self.sample_from_nodes,
(nodes1[i], edge_index, num_nodes, fanout))
for i in range(0, remainder)]
results.extend([p.apply_async(self.sample_from_nodes,
(nodes2[i], edge_index, num_nodes, fanout))
for i in range(0, workers - remainder)])
for result in results:
samples_nodes_i, edge_index_i = result.get()
samples_nodes = torch.unique(torch.cat([samples_nodes, samples_nodes_i]))
row = torch.cat([row, edge_index_i[0]])
col = torch.cat([col, edge_index_i[1]])
# 单线程循环法:
# for node in nodes:
# samples_nodes_i,edge_index_i = self.sample_from_node(node, edge_index, num_nodes, fanout)
# samples_nodes=torch.unique(torch.concat([samples_nodes,samples_nodes_i]))
# row=torch.concat([row,edge_index_i[0]])
# col=torch.concat([col,edge_index_i[1]])
return samples_nodes, torch.stack([row, col], dim=0)
# 不使用sample_from_node直接取所有点邻居方法:
......@@ -131,3 +196,17 @@ class NeighborSampler(BaseSampler):
# neighbors2=torch.concat([col[row==nodes[i]] for i in range(0, nodes.shape[0])])
# neighbors=torch.stack([neighbors1, neighbors2], dim=0)
# print('neighbors: \n', neighbors)
if __name__=="__main__":
edge_index = torch.tensor([[0, 1, 1, 1, 2, 2, 2, 3, 3, 4, 4, 4, 5], [1, 0, 2, 4, 1, 3, 0, 2, 5, 3, 5, 0, 2]])
num_nodes = 6
num_neighbors = 2
# Run the neighbor sampling
sampler=NeighborSampler()
# neighbor_nodes, edge_index = sampler.sample_from_node(2, edge_index, num_nodes, num_neighbors)
# neighbor_nodes, edge_index = sampler.sample_from_nodes(torch.tensor([1,2]), edge_index, num_nodes, num_neighbors)
# neighbor_nodes, edge_index = sampler.sample_from_nodes_parallel(torch.tensor([1,2,3]), edge_index, num_nodes, workers=3, fanout=num_neighbors)
neighbor_nodes, edge_index = sampler.sample_from_nodes_parallel(torch.tensor([1,2,3,4,5]), edge_index, num_nodes, workers=4, fanout=num_neighbors)
# Print the result
print('neighbor_nodes_id: \n',neighbor_nodes, '\nedge_index: \n',edge_index)
import torch
from Sampler import NeighborSampler
edge_index = torch.tensor([[0, 1, 1, 2, 2, 2, 3], [1, 0, 2, 1, 3, 0, 2]])
num_nodes = 4
num_neighbors = 1
# edge_index = torch.tensor([[0, 1, 1, 2, 2, 2, 3], [1, 0, 2, 1, 3, 0, 2]])
# num_nodes = 4
edge_index = torch.tensor([[0, 1, 1, 1, 2, 2, 2, 3, 3, 4, 4, 4, 5], [1, 0, 2, 4, 1, 3, 0, 2, 5, 3, 5, 0, 2]])
num_nodes = 6
num_neighbors = 2
# Run the neighbor sampling
sampler=NeighborSampler()
# neighbors, edge_index = sampler.sample_from_node(2, edge_index, num_nodes, num_neighbors)
neighbors, edge_index = sampler.sample_from_nodes(torch.tensor([1,2]), edge_index, num_nodes, num_neighbors)
# neighbor_nodes, edge_index = sampler.sample_from_node(2, edge_index, num_nodes, num_neighbors)
# neighbor_nodes, edge_index = sampler.sample_from_nodes(torch.tensor([1,2]), edge_index, num_nodes, num_neighbors)
# neighbor_nodes, edge_index = sampler.sample_from_nodes_parallel(torch.tensor([1,2]), edge_index, num_nodes, workers=1, fanout=num_neighbors)
neighbor_nodes, edge_index = sampler.sample_from_nodes_parallel(torch.tensor([1,2,3,4,5]), edge_index, num_nodes, workers=4, fanout=num_neighbors)
# Print the result
print('neighbor_nodes_id: \n',neighbors, '\nedge_index: \n',edge_index)
print('neighbor_nodes_id: \n',neighbor_nodes, '\nedge_index: \n',edge_index)
# import torch_scatter
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment