Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
S
starrty_sampler
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
zhlj
starrty_sampler
Commits
0605331d
Commit
0605331d
authored
Feb 28, 2023
by
XXX
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
v9: add basic random walk sampler and make some changes in sample_cores
parent
8e6fdda3
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
78 additions
and
21 deletions
+78
-21
Sample/base.py
+15
-15
Sample/random_walk_sampler.py
+57
-0
Sample/sample_cores.cpp
+6
-6
Sample/sample_cores.cpython-37m-x86_64-linux-gnu.so
+0
-0
No files found.
Sample/base.py
View file @
0605331d
...
...
@@ -46,22 +46,22 @@ class BaseSampler(ABC):
# """
# raise NotImplementedError
def
_sample_one_layer_from_nodes
(
self
,
nodes
:
torch
.
Tensor
,
**
kwargs
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
r"""Performs sampling from the nodes specified in: nodes,
returning a sampled subgraph in the specified output format: Tuple[torch.Tensor, torch.Tensor].
#
def _sample_one_layer_from_nodes(
#
self,
#
nodes:torch.Tensor,
#
**kwargs
#
) -> Tuple[torch.Tensor, torch.Tensor]:
#
r"""Performs sampling from the nodes specified in: nodes,
#
returning a sampled subgraph in the specified output format: Tuple[torch.Tensor, torch.Tensor].
Args:
nodes: the list of seed nodes index
**kwargs: other kwargs
Returns:
sampled_nodes: the nodes sampled
sampled_edge_index: the edges sampled
"""
raise
NotImplementedError
#
Args:
#
nodes: the list of seed nodes index
#
**kwargs: other kwargs
#
Returns:
#
sampled_nodes: the nodes sampled
#
sampled_edge_index: the edges sampled
#
"""
#
raise NotImplementedError
# def _sample_one_layer_from_nodes_parallel(
# self,
...
...
Sample/random_walk_sampler.py
0 → 100644
View file @
0605331d
import
torch
import
torch.multiprocessing
as
mp
from
typing
import
Tuple
from
base
import
BaseSampler
from
neighbor_sampler
import
NeighborSampler
class RandomWalkSampler(BaseSampler):
    r"""Random-walk sampler that delegates all sampling to a ``NeighborSampler``.

    The wrapped ``NeighborSampler`` is built with a list of ``num_layers``
    ones as its fourth argument (presumably the per-layer fanout, i.e. one
    neighbor per hop -- confirm against ``NeighborSampler``).
    """

    def __init__(
        self,
        edge_index: torch.Tensor,
        num_nodes: int,
        num_layers: int,
        workers=1,
    ) -> None:
        r"""Build the underlying neighbor sampler and record the settings.

        Args:
            edge_index: all edges in the graph
            num_nodes: the number of nodes in the graph
            num_layers: the number of layers to be sampled
            workers: the number of threads, default value is 1
        """
        super().__init__()

        # One entry per layer, each equal to 1.
        per_layer = [1] * num_layers
        self.sampler = NeighborSampler(
            edge_index, num_nodes, num_layers, per_layer, workers
        )
        self.num_layers = num_layers

        # The stored thread count never exceeds torch's default OMP thread
        # count; the wrapped sampler above receives the caller's raw value.
        thread_cap = torch.get_num_threads()
        self.workers = min(workers, thread_cap)

    def sample_from_nodes(
        self,
        nodes: torch.Tensor,
    ) -> Tuple[torch.Tensor, list]:
        r"""Run multilayer sampling starting from the given seed nodes.

        The call is forwarded unchanged to the wrapped sampler, which was
        configured with ``num_layers`` layers at construction time.

        Args:
            nodes: the list of seed nodes index
        Returns:
            sampled_nodes: the nodes sampled
            sampled_edge_index: the edges sampled
        """
        inner = self.sampler
        return inner.sample_from_nodes(nodes)
if __name__ == "__main__":
    # Small demo graph with 6 nodes, given as two parallel rows of endpoints
    # (first row / second row of the edge_index tensor).
    first_row = [0, 1, 1, 1, 2, 2, 2, 4, 4, 4, 5]
    second_row = [1, 0, 2, 4, 1, 3, 0, 3, 5, 0, 2]
    edge_index1 = torch.tensor([first_row, second_row])
    num_nodes1 = 6

    # Build a three-layer random walk sampler and sample from seeds 1 and 2.
    sampler = RandomWalkSampler(
        edge_index=edge_index1,
        num_nodes=num_nodes1,
        num_layers=3,
        workers=4,
    )
    seeds = torch.tensor([1, 2])
    sampled_nodes, sampled_edge_index = sampler.sample_from_nodes(seeds)

    # Show what was sampled.
    print(
        'sampled_nodes_id:\n',
        sampled_nodes,
        '\nedge_index:\n',
        sampled_edge_index,
    )
Sample/sample_cores.cpp
View file @
0605331d
...
...
@@ -120,15 +120,15 @@ TemporalGraphBlock neighbor_sample_from_node(
NodeIDType
node
,
vector
<
NodeIDType
>&
neighbors
,
int
deg
,
int
fanout
,
int
threads
){
TemporalGraphBlock
tgb
=
TemporalGraphBlock
();
tgb
.
col
=
neighbors
;
srand
((
int
)
time
(
0
));
if
(
deg
>
fanout
){
//度大于扇出的话需要随机
删除一些
邻居
//度大于扇出的话需要随机
选择fanout个
邻居
#pragma omp parallel for num_threads(threads)
for
(
int
i
=
0
;
i
<
deg
-
fanout
;
i
++
){
//循环删除deg-fanout个邻居
auto
erase_iter
=
tgb
.
col
.
begin
()
+
rand
()
%
(
deg
-
i
);
tgb
.
col
.
erase
(
erase_iter
);
for
(
int
i
=
0
;
i
<
fanout
;
i
++
){
//循环选择fanout个邻居
auto
chosen_iter
=
neighbors
.
begin
()
+
rand
()
%
(
deg
-
i
);
tgb
.
col
.
push_back
(
*
chosen_iter
);
neighbors
.
erase
(
chosen_iter
);
}
}
tgb
.
row
.
resize
(
tgb
.
col
.
size
(),
node
);
...
...
Sample/sample_cores.cpython-37m-x86_64-linux-gnu.so
View file @
0605331d
No preview for this file type
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment