Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
B
BTS-MTGNN
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
zhlj
BTS-MTGNN
Commits
b9450874
Commit
b9450874
authored
Dec 27, 2023
by
xxx
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
modify static graph sampler
parent
f84aa32e
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
332 additions
and
158 deletions
+332
-158
csrc/sampler/include/neighbors.h
+1
-4
csrc/sampler/include/sampler.h
+38
-24
starrygl/sample/sample_core/my_sample_test_static.py
+57
-48
starrygl/sample/sample_core/neighbor_sampler.py
+3
-2
starrygl/sample/sample_core/ogb_demo_static.py
+75
-70
starrygl/sample/sample_core/reddit_demo_before.py
+23
-10
starrygl/sample/sample_core/reddit_demo_static.py
+135
-0
No files found.
csrc/sampler/include/neighbors.h
View file @
b9450874
...
@@ -258,19 +258,16 @@ TemporalNeighborBlock& get_neighbors(
...
@@ -258,19 +258,16 @@ TemporalNeighborBlock& get_neighbors(
if
(
is_distinct
){
if
(
is_distinct
){
for
(
int64_t
i
=
0
;
i
<
num_nodes
;
i
++
){
for
(
int64_t
i
=
0
;
i
<
num_nodes
;
i
++
){
//收集
单边去重节点度
//收集
去重邻居
phmap
::
parallel_flat_hash_set
<
NodeIDType
>
temp_s
;
phmap
::
parallel_flat_hash_set
<
NodeIDType
>
temp_s
;
temp_s
.
insert
(
tnb
.
neighbors
[
i
].
begin
(),
tnb
.
neighbors
[
i
].
end
());
temp_s
.
insert
(
tnb
.
neighbors
[
i
].
begin
(),
tnb
.
neighbors
[
i
].
end
());
tnb
.
neighbors_set
.
emplace_back
(
temp_s
);
tnb
.
neighbors_set
.
emplace_back
(
temp_s
);
tnb
.
deg
[
i
]
=
tnb
.
neighbors_set
[
i
].
size
();
}
}
}
}
else
{
for
(
int64_t
i
=
0
;
i
<
num_nodes
;
i
++
){
for
(
int64_t
i
=
0
;
i
<
num_nodes
;
i
++
){
//收集单边节点度
//收集单边节点度
tnb
.
deg
[
i
]
=
tnb
.
neighbors
[
i
].
size
();
tnb
.
deg
[
i
]
=
tnb
.
neighbors
[
i
].
size
();
}
}
}
double
end_time
=
omp_get_wtime
();
double
end_time
=
omp_get_wtime
();
cout
<<
"get_neighbors consume: "
<<
end_time
-
start_time
<<
"s"
<<
endl
;
cout
<<
"get_neighbors consume: "
<<
end_time
-
start_time
<<
"s"
<<
endl
;
return
tnb
;
return
tnb
;
...
...
csrc/sampler/include/sampler.h
View file @
b9450874
...
@@ -33,16 +33,16 @@ class ParallelSampler
...
@@ -33,16 +33,16 @@ class ParallelSampler
ret
.
resize
(
num_layers
);
ret
.
resize
(
num_layers
);
}
}
void
neighbor_sample_from_nodes
(
th
::
Tensor
nodes
,
optional
<
th
::
Tensor
>
root_ts
);
void
neighbor_sample_from_nodes
(
th
::
Tensor
nodes
,
optional
<
th
::
Tensor
>
root_ts
,
optional
<
bool
>
part_unique
);
void
neighbor_sample_from_nodes_static
(
th
::
Tensor
nodes
);
void
neighbor_sample_from_nodes_static
(
th
::
Tensor
nodes
,
bool
part_unique
);
void
neighbor_sample_from_nodes_static_layer
(
th
::
Tensor
nodes
,
int
cur_layer
);
void
neighbor_sample_from_nodes_static_layer
(
th
::
Tensor
nodes
,
int
cur_layer
,
bool
part_unique
);
void
neighbor_sample_from_nodes_with_before
(
th
::
Tensor
nodes
,
th
::
Tensor
root_ts
);
void
neighbor_sample_from_nodes_with_before
(
th
::
Tensor
nodes
,
th
::
Tensor
root_ts
);
void
neighbor_sample_from_nodes_with_before_layer
(
th
::
Tensor
nodes
,
th
::
Tensor
root_ts
,
int
cur_layer
);
void
neighbor_sample_from_nodes_with_before_layer
(
th
::
Tensor
nodes
,
th
::
Tensor
root_ts
,
int
cur_layer
);
};
};
void
ParallelSampler
::
neighbor_sample_from_nodes
(
th
::
Tensor
nodes
,
optional
<
th
::
Tensor
>
root_ts
)
void
ParallelSampler
::
neighbor_sample_from_nodes
(
th
::
Tensor
nodes
,
optional
<
th
::
Tensor
>
root_ts
,
optional
<
bool
>
part_unique
)
{
{
omp_set_num_threads
(
threads
);
omp_set_num_threads
(
threads
);
if
(
policy
==
"weighted"
)
if
(
policy
==
"weighted"
)
...
@@ -60,11 +60,12 @@ void ParallelSampler :: neighbor_sample_from_nodes(th::Tensor nodes, optional<th
...
@@ -60,11 +60,12 @@ void ParallelSampler :: neighbor_sample_from_nodes(th::Tensor nodes, optional<th
neighbor_sample_from_nodes_with_before
(
nodes
,
root_ts
.
value
());
neighbor_sample_from_nodes_with_before
(
nodes
,
root_ts
.
value
());
}
}
else
{
else
{
neighbor_sample_from_nodes_static
(
nodes
);
bool
flag
=
part_unique
.
has_value
()
?
part_unique
.
value
()
:
true
;
neighbor_sample_from_nodes_static
(
nodes
,
flag
);
}
}
}
}
void
ParallelSampler
::
neighbor_sample_from_nodes_static_layer
(
th
::
Tensor
nodes
,
int
cur_layer
){
void
ParallelSampler
::
neighbor_sample_from_nodes_static_layer
(
th
::
Tensor
nodes
,
int
cur_layer
,
bool
part_unique
){
py
::
gil_scoped_release
release
;
py
::
gil_scoped_release
release
;
double
tot_start_time
=
omp_get_wtime
();
double
tot_start_time
=
omp_get_wtime
();
...
@@ -73,24 +74,25 @@ void ParallelSampler :: neighbor_sample_from_nodes_static_layer(th::Tensor nodes
...
@@ -73,24 +74,25 @@ void ParallelSampler :: neighbor_sample_from_nodes_static_layer(th::Tensor nodes
ret
[
cur_layer
]
=
TemporalGraphBlock
();
ret
[
cur_layer
]
=
TemporalGraphBlock
();
auto
nodes_data
=
get_data_ptr
<
NodeIDType
>
(
nodes
);
auto
nodes_data
=
get_data_ptr
<
NodeIDType
>
(
nodes
);
vector
<
phmap
::
parallel_flat_hash_set
<
NodeIDType
>>
node_s_threads
(
threads
);
vector
<
phmap
::
parallel_flat_hash_set
<
NodeIDType
>>
node_s_threads
(
threads
);
//
vector<vector<NodeIDType>> node_threads(threads);
vector
<
vector
<
NodeIDType
>>
node_threads
(
threads
);
phmap
::
parallel_flat_hash_set
<
NodeIDType
>
node_s
;
phmap
::
parallel_flat_hash_set
<
NodeIDType
>
node_s
;
vector
<
vector
<
NodeIDType
>>
eid_threads
(
threads
);
//row_threads(threads),col_threads(threads);
vector
<
vector
<
NodeIDType
>>
eid_threads
(
threads
);
vector
<
vector
<
NodeIDType
>>
src_index_threads
(
threads
);
vector
<
vector
<
NodeIDType
>>
src_index_threads
(
threads
);
AT_ASSERTM
(
tnb
.
with_eid
,
"Tnb has no eid infomation! We need eid!"
);
AT_ASSERTM
(
tnb
.
with_eid
,
"Tnb has no eid infomation! We need eid!"
);
// double start_time = omp_get_wtime();
// double start_time = omp_get_wtime();
int
reserve_capacity
=
int
(
ceil
(
nodes
.
size
(
0
)
/
threads
))
*
fanout
;
int
reserve_capacity
=
int
(
ceil
(
nodes
.
size
(
0
)
/
threads
))
*
fanout
;
#pragma omp parallel
#pragma omp parallel
{
{
int
tid
=
omp_get_thread_num
();
int
tid
=
omp_get_thread_num
();
unsigned
int
loc_seed
=
tid
;
unsigned
int
loc_seed
=
tid
;
eid_threads
[
tid
].
reserve
(
reserve_capacity
);
eid_threads
[
tid
].
reserve
(
reserve_capacity
);
src_index_threads
[
tid
].
reserve
(
reserve_capacity
);
src_index_threads
[
tid
].
reserve
(
reserve_capacity
);
// node_threads[tid].reserve(reserve_capacity);
if
(
!
part_unique
)
node_threads
[
tid
].
reserve
(
reserve_capacity
);
#pragma omp for schedule(static, int(ceil(static_cast<float>((nodes.size(0)) / threads))))
#pragma omp for schedule(static, int(ceil(static_cast<float>((nodes.size(0)) / threads))))
for
(
int64_t
i
=
0
;
i
<
nodes
.
size
(
0
);
i
++
){
for
(
int64_t
i
=
0
;
i
<
nodes
.
size
(
0
);
i
++
){
// int tid = omp_get_thread_num();
NodeIDType
node
=
nodes_data
[
i
];
NodeIDType
node
=
nodes_data
[
i
];
vector
<
NodeIDType
>&
nei
=
tnb
.
neighbors
[
node
];
vector
<
NodeIDType
>&
nei
=
tnb
.
neighbors
[
node
];
vector
<
EdgeIDType
>
edge
;
vector
<
EdgeIDType
>
edge
;
...
@@ -98,12 +100,11 @@ void ParallelSampler :: neighbor_sample_from_nodes_static_layer(th::Tensor nodes
...
@@ -98,12 +100,11 @@ void ParallelSampler :: neighbor_sample_from_nodes_static_layer(th::Tensor nodes
double
s_start_time
=
omp_get_wtime
();
double
s_start_time
=
omp_get_wtime
();
if
(
tnb
.
deg
[
node
]
>
fanout
){
if
(
tnb
.
deg
[
node
]
>
fanout
){
src_index_threads
[
tid
].
insert
(
src_index_threads
[
tid
].
end
(),
fanout
,
i
);
phmap
::
flat_hash_set
<
NodeIDType
>
temp_s
;
phmap
::
flat_hash_set
<
NodeIDType
>
temp_s
;
default_random_engine
e
(
8
);
//(time(0));
default_random_engine
e
(
8
);
//(time(0));
uniform_int_distribution
<>
u
(
0
,
tnb
.
deg
[
node
]
-
1
);
// uniform_int_distribution<> u(0, tnb.deg[node]-1);
while
(
temp_s
.
size
()
!=
fanout
){
// while(temp_s.size()!=fanout && temp_s.size()<tnb.neighbors_set[node].size()
){
//
for(int i=0;i<fanout;i++){
for
(
int
i
=
0
;
i
<
fanout
;
i
++
){
//循环选择fanout个邻居
//循环选择fanout个邻居
NodeIDType
indice
;
NodeIDType
indice
;
if
(
policy
==
"weighted"
){
//考虑边权重信息
if
(
policy
==
"weighted"
){
//考虑边权重信息
...
@@ -112,24 +113,34 @@ void ParallelSampler :: neighbor_sample_from_nodes_static_layer(th::Tensor nodes
...
@@ -112,24 +113,34 @@ void ParallelSampler :: neighbor_sample_from_nodes_static_layer(th::Tensor nodes
}
}
else
if
(
policy
==
"uniform"
){
//均匀采样
else
if
(
policy
==
"uniform"
){
//均匀采样
// indice = u(e);
// indice = u(e);
indice
=
rand_r
(
&
loc_seed
)
%
(
tnb
.
deg
[
node
]
);
indice
=
rand_r
(
&
loc_seed
)
%
(
nei
.
size
()
);
}
}
auto
chosen_n_iter
=
nei
.
begin
()
+
indice
;
auto
chosen_n_iter
=
nei
.
begin
()
+
indice
;
// auto chosen_e_iter = edge.begin() + indice;
auto
chosen_e_iter
=
edge
.
begin
()
+
indice
;
// eid_threads[tid].emplace_back(*chosen_e_iter);
if
(
part_unique
){
// node_threads[tid].emplace_back(*chosen_n_iter);
auto
rst
=
temp_s
.
insert
(
*
chosen_n_iter
);
auto
rst
=
temp_s
.
insert
(
*
chosen_n_iter
);
if
(
rst
.
second
){
//不重复
if
(
rst
.
second
){
//不重复
auto
chosen_e_iter
=
edge
.
begin
()
+
indice
;
eid_threads
[
tid
].
emplace_back
(
*
chosen_e_iter
);
eid_threads
[
tid
].
emplace_back
(
*
chosen_e_iter
);
node_s_threads
[
tid
].
insert
(
*
chosen_n_iter
);
node_s_threads
[
tid
].
insert
(
*
chosen_n_iter
);
if
(
temp_s
.
size
()
<
fanout
&&
!
tnb
.
neighbors_set
.
empty
()
&&
temp_s
.
size
()
<
tnb
.
neighbors_set
[
node
].
size
())
fanout
++
;
}
}
}
}
else
{
eid_threads
[
tid
].
emplace_back
(
*
chosen_e_iter
);
node_threads
[
tid
].
emplace_back
(
*
chosen_n_iter
);
}
}
if
(
part_unique
)
src_index_threads
[
tid
].
insert
(
src_index_threads
[
tid
].
end
(),
temp_s
.
size
(),
i
);
else
src_index_threads
[
tid
].
insert
(
src_index_threads
[
tid
].
end
(),
fanout
,
i
);
}
}
else
{
else
{
src_index_threads
[
tid
].
insert
(
src_index_threads
[
tid
].
end
(),
tnb
.
deg
[
node
],
i
);
src_index_threads
[
tid
].
insert
(
src_index_threads
[
tid
].
end
(),
tnb
.
deg
[
node
],
i
);
// node_threads[tid].insert(node_threads[tid].end(), nei.begin(), nei.end());
if
(
part_unique
)
node_s_threads
[
tid
].
insert
(
nei
.
begin
(),
nei
.
end
());
node_s_threads
[
tid
].
insert
(
nei
.
begin
(),
nei
.
end
());
else
node_threads
[
tid
].
insert
(
node_threads
[
tid
].
end
(),
nei
.
begin
(),
nei
.
end
());
eid_threads
[
tid
].
insert
(
eid_threads
[
tid
].
end
(),
edge
.
begin
(),
edge
.
end
());
eid_threads
[
tid
].
insert
(
eid_threads
[
tid
].
end
(),
edge
.
begin
(),
edge
.
end
());
}
}
if
(
tid
==
0
)
if
(
tid
==
0
)
...
@@ -153,22 +164,25 @@ void ParallelSampler :: neighbor_sample_from_nodes_static_layer(th::Tensor nodes
...
@@ -153,22 +164,25 @@ void ParallelSampler :: neighbor_sample_from_nodes_static_layer(th::Tensor nodes
#pragma omp parallel for schedule(static, 1)
#pragma omp parallel for schedule(static, 1)
for
(
int
i
=
0
;
i
<
threads
;
i
++
){
for
(
int
i
=
0
;
i
<
threads
;
i
++
){
copy
(
eid_threads
[
i
].
begin
(),
eid_threads
[
i
].
end
(),
ret
[
cur_layer
].
eid
.
begin
()
+
each_begin
[
i
]);
copy
(
eid_threads
[
i
].
begin
(),
eid_threads
[
i
].
end
(),
ret
[
cur_layer
].
eid
.
begin
()
+
each_begin
[
i
]);
// copy(node_threads[i].begin(), node_threads[i].end(), ret[cur_layer].sample_nodes.begin()+each_begin[i]);
if
(
!
part_unique
)
copy
(
node_threads
[
i
].
begin
(),
node_threads
[
i
].
end
(),
ret
[
cur_layer
].
sample_nodes
.
begin
()
+
each_begin
[
i
]);
copy
(
src_index_threads
[
i
].
begin
(),
src_index_threads
[
i
].
end
(),
ret
[
cur_layer
].
src_index
.
begin
()
+
each_begin
[
i
]);
copy
(
src_index_threads
[
i
].
begin
(),
src_index_threads
[
i
].
end
(),
ret
[
cur_layer
].
src_index
.
begin
()
+
each_begin
[
i
]);
}
}
if
(
part_unique
){
for
(
int
i
=
0
;
i
<
threads
;
i
++
)
for
(
int
i
=
0
;
i
<
threads
;
i
++
)
node_s
.
insert
(
node_s_threads
[
i
].
begin
(),
node_s_threads
[
i
].
end
());
node_s
.
insert
(
node_s_threads
[
i
].
begin
(),
node_s_threads
[
i
].
end
());
ret
[
cur_layer
].
sample_nodes
.
assign
(
node_s
.
begin
(),
node_s
.
end
());
ret
[
cur_layer
].
sample_nodes
.
assign
(
node_s
.
begin
(),
node_s
.
end
());
}
ret
[
0
].
tot_time
+=
omp_get_wtime
()
-
tot_start_time
;
ret
[
0
].
tot_time
+=
omp_get_wtime
()
-
tot_start_time
;
ret
[
0
].
sample_edge_num
+=
ret
[
cur_layer
].
eid
.
size
();
ret
[
0
].
sample_edge_num
+=
ret
[
cur_layer
].
eid
.
size
();
py
::
gil_scoped_acquire
acquire
;
py
::
gil_scoped_acquire
acquire
;
}
}
void
ParallelSampler
::
neighbor_sample_from_nodes_static
(
th
::
Tensor
nodes
){
void
ParallelSampler
::
neighbor_sample_from_nodes_static
(
th
::
Tensor
nodes
,
bool
part_unique
){
for
(
int
i
=
0
;
i
<
num_layers
;
i
++
){
for
(
int
i
=
0
;
i
<
num_layers
;
i
++
){
if
(
i
==
0
)
neighbor_sample_from_nodes_static_layer
(
nodes
,
i
);
if
(
i
==
0
)
neighbor_sample_from_nodes_static_layer
(
nodes
,
i
,
part_unique
);
else
neighbor_sample_from_nodes_static_layer
(
vecToTensor
<
NodeIDType
>
(
ret
[
i
-
1
].
sample_nodes
),
i
);
else
neighbor_sample_from_nodes_static_layer
(
vecToTensor
<
NodeIDType
>
(
ret
[
i
-
1
].
sample_nodes
),
i
,
part_unique
);
}
}
}
}
...
...
starrygl/sample/sample_core/my_sample_test_static.py
View file @
b9450874
import
torch
import
torch
import
time
import
time
from
Utils
import
GraphData
from
.Utils
import
GraphData
seed
=
10
# 你可以选择任何整数作为种子
def
test
():
torch
.
manual_seed
(
seed
)
seed
=
10
# 你可以选择任何整数作为种子
torch
.
manual_seed
(
seed
)
num_nodes1
=
10
num_nodes1
=
10
fanout1
=
[
2
,
2
]
# index 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25
fanout1
=
[
2
,
2
]
# index 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25
edge_index1
=
torch
.
tensor
([[
1
,
5
,
7
,
9
,
2
,
4
,
6
,
7
,
8
,
0
,
1
,
6
,
2
,
0
,
1
,
3
,
5
,
8
,
9
,
7
,
4
,
8
,
2
,
3
,
5
,
8
],
edge_index1
=
torch
.
tensor
([[
1
,
5
,
7
,
9
,
2
,
4
,
6
,
7
,
8
,
0
,
1
,
6
,
2
,
0
,
1
,
3
,
5
,
8
,
9
,
7
,
4
,
8
,
2
,
3
,
5
,
8
],
[
0
,
0
,
0
,
0
,
1
,
1
,
1
,
1
,
1
,
2
,
2
,
2
,
3
,
3
,
4
,
4
,
4
,
4
,
4
,
5
,
6
,
6
,
7
,
7
,
8
,
9
]])
[
0
,
0
,
0
,
0
,
1
,
1
,
1
,
1
,
1
,
2
,
2
,
2
,
3
,
3
,
4
,
4
,
4
,
4
,
4
,
5
,
6
,
6
,
7
,
7
,
8
,
9
]])
edge_weight1
=
torch
.
tensor
([
2
,
1
,
2
,
1
,
8
,
6
,
3
,
1
,
1
,
1
,
1
,
5
,
1
,
1
,
2
,
1
,
1
,
1
,
1
,
5
,
1
,
2
,
2
,
2
,
1
,
1
])
.
double
()
edge_weight1
=
torch
.
tensor
([
2
,
1
,
2
,
1
,
8
,
6
,
3
,
1
,
1
,
1
,
1
,
5
,
1
,
1
,
2
,
1
,
1
,
1
,
1
,
5
,
1
,
2
,
2
,
2
,
1
,
1
])
.
double
()
g_data
=
GraphData
(
id
=
0
,
edge_index
=
edge_index1
,
data
=
None
,
partptr
=
torch
.
tensor
([
0
,
num_nodes1
]))
src
,
dst
=
edge_index1
row
=
torch
.
cat
([
src
,
dst
])
col
=
torch
.
cat
([
dst
,
src
])
edge_index1
=
torch
.
stack
([
row
,
col
])
edge_weight1
=
None
g_data
=
GraphData
(
id
=
0
,
edge_index
=
edge_index1
,
data
=
None
,
partptr
=
torch
.
tensor
([
0
,
num_nodes1
]))
# g_data.eid=None
from
neighbor_sampler
import
NeighborSampler
,
SampleType
edge_weight1
=
None
pre
=
time
.
time
()
# g_data.eid=None
sampler
=
NeighborSampler
(
num_nodes1
,
from
.neighbor_sampler
import
NeighborSampler
,
SampleType
pre
=
time
.
time
()
sampler
=
NeighborSampler
(
num_nodes1
,
num_layers
=
2
,
num_layers
=
2
,
fanout
=
fanout1
,
fanout
=
fanout1
,
edge_weight
=
edge_weight1
,
edge_weight
=
edge_weight1
,
...
@@ -25,39 +31,42 @@ sampler = NeighborSampler(num_nodes1,
...
@@ -25,39 +31,42 @@ sampler = NeighborSampler(num_nodes1,
workers
=
2
,
workers
=
2
,
graph_name
=
'a'
,
graph_name
=
'a'
,
policy
=
"uniform"
)
policy
=
"uniform"
)
end
=
time
.
time
()
end
=
time
.
time
()
print
(
"init time:"
,
end
-
pre
)
print
(
"init time:"
,
end
-
pre
)
print
(
"tnb.neighbors:"
,
sampler
.
tnb
.
neighbors
)
print
(
"tnb.neighbors:"
,
sampler
.
tnb
.
neighbors
)
print
(
"tnb.eid:"
,
sampler
.
tnb
.
eid
)
print
(
"tnb.eid:"
,
sampler
.
tnb
.
eid
)
print
(
"tnb.deg:"
,
sampler
.
tnb
.
deg
)
print
(
"tnb.deg:"
,
sampler
.
tnb
.
deg
)
print
(
"tnb.weight:"
,
sampler
.
tnb
.
edge_weight
)
print
(
"tnb.weight:"
,
sampler
.
tnb
.
edge_weight
)
# row,col = edge_index1
# row,col = edge_index1
# update_edge_row = row
# update_edge_row = row
# update_edge_col = col
# update_edge_col = col
# update_edge_w = torch.FloatTensor([i for i in range(edge_weight1.size(0))])
# update_edge_w = torch.FloatTensor([i for i in range(edge_weight1.size(0))])
# print('tnb.edge_weight:', sampler.tnb.edge_weight)
# print('tnb.edge_weight:', sampler.tnb.edge_weight)
# print('begin update')
# print('begin update')
# pre = time.time()
# pre = time.time()
# sampler.tnb.update_edge_weight(sampler.tnb, update_edge_row.contiguous(), update_edge_col.contiguous(), update_edge_w.contiguous())
# sampler.tnb.update_edge_weight(sampler.tnb, update_edge_row.contiguous(), update_edge_col.contiguous(), update_edge_w.contiguous())
# end = time.time()
# end = time.time()
# print("update time:", end-pre)
# print("update time:", end-pre)
# print('update_edge_row:', update_edge_row)
# print('update_edge_row:', update_edge_row)
# print('update_edge_col:', update_edge_col)
# print('update_edge_col:', update_edge_col)
# print('tnb.edge_weight:', sampler.tnb.edge_weight)
# print('tnb.edge_weight:', sampler.tnb.edge_weight)
pre
=
time
.
time
()
pre
=
time
.
time
()
out
=
sampler
.
sample_from_nodes
(
torch
.
tensor
([
1
,
2
]),
out
=
sampler
.
sample_from_nodes
(
torch
.
tensor
([
1
,
2
]),
with_outer_sample
=
SampleType
.
Whole
)
# sampler.sample_from_nodes(torch.masked_select(torch.arange(g.num_nodes),node_data['train_mask']))
with_outer_sample
=
SampleType
.
Whole
)
# sampler.sample_from_nodes(torch.masked_select(torch.arange(g.num_nodes),node_data['train_mask']))
end
=
time
.
time
()
end
=
time
.
time
()
print
(
'node1:
\t
'
,
out
[
0
]
.
sample_nodes
()
.
tolist
())
print
(
'node1:
\t
'
,
out
[
0
]
.
sample_nodes
()
.
tolist
())
print
(
'eid1:
\t
'
,
out
[
0
]
.
eid
()
.
tolist
())
print
(
'eid1:
\t
'
,
out
[
0
]
.
eid
()
.
tolist
())
print
(
'edge1:
\t
'
,
edge_index1
[:,
out
[
0
]
.
eid
()]
.
tolist
())
print
(
'edge1:
\t
'
,
edge_index1
[:,
out
[
0
]
.
eid
()]
.
tolist
())
print
(
'node2:
\t
'
,
out
[
1
]
.
sample_nodes
()
.
tolist
())
print
(
'node2:
\t
'
,
out
[
1
]
.
sample_nodes
()
.
tolist
())
print
(
'eid2:
\t
'
,
out
[
1
]
.
eid
()
.
tolist
())
print
(
'eid2:
\t
'
,
out
[
1
]
.
eid
()
.
tolist
())
print
(
'edge2:
\t
'
,
edge_index1
[:,
out
[
1
]
.
eid
()]
.
tolist
())
print
(
'edge2:
\t
'
,
edge_index1
[:,
out
[
1
]
.
eid
()]
.
tolist
())
print
(
"sample time:"
,
end
-
pre
)
print
(
"sample time:"
,
end
-
pre
)
\ No newline at end of file
if
__name__
==
"__main__"
:
test
()
\ No newline at end of file
starrygl/sample/sample_core/neighbor_sampler.py
View file @
b9450874
...
@@ -157,11 +157,12 @@ class NeighborSampler(BaseSampler):
...
@@ -157,11 +157,12 @@ class NeighborSampler(BaseSampler):
sampled_edge_index_list: the edge sampled
sampled_edge_index_list: the edge sampled
"""
"""
if
(
ts
is
None
):
if
(
ts
is
None
):
self
.
p_sampler
.
neighbor_sample_from_nodes
(
nodes
.
contiguous
(),
None
)
self
.
part_unique
=
True
self
.
p_sampler
.
neighbor_sample_from_nodes
(
nodes
.
contiguous
(),
None
,
self
.
part_unique
)
ret
=
self
.
p_sampler
.
get_ret
()
ret
=
self
.
p_sampler
.
get_ret
()
return
ret
return
ret
else
:
else
:
self
.
p_sampler
.
neighbor_sample_from_nodes
(
nodes
.
contiguous
(),
ts
.
float
()
.
contiguous
())
self
.
p_sampler
.
neighbor_sample_from_nodes
(
nodes
.
contiguous
(),
ts
.
float
()
.
contiguous
()
,
None
)
ret
=
self
.
p_sampler
.
get_ret
()
ret
=
self
.
p_sampler
.
get_ret
()
return
ret
return
ret
...
...
starrygl/sample/sample_core/ogb_demo_static.py
View file @
b9450874
...
@@ -2,7 +2,7 @@ import torch
...
@@ -2,7 +2,7 @@ import torch
from
ogb.nodeproppred
import
PygNodePropPredDataset
from
ogb.nodeproppred
import
PygNodePropPredDataset
from
torch_geometric
import
datasets
from
torch_geometric
import
datasets
import
time
import
time
from
Utils
import
GraphData
from
.
Utils
import
GraphData
def
load_ogb_dataset
(
name
,
data_path
):
def
load_ogb_dataset
(
name
,
data_path
):
dataset
=
PygNodePropPredDataset
(
name
=
name
,
root
=
data_path
)
dataset
=
PygNodePropPredDataset
(
name
=
name
,
root
=
data_path
)
...
@@ -18,56 +18,57 @@ def load_ogb_dataset(name, data_path):
...
@@ -18,56 +18,57 @@ def load_ogb_dataset(name, data_path):
node_data
[
'test_mask'
][
split_idx
[
"test"
]]
=
True
node_data
[
'test_mask'
][
split_idx
[
"test"
]]
=
True
return
g
,
node_data
return
g
,
node_data
g
,
node_data
=
load_ogb_dataset
(
'ogbn-products'
,
"/home/zlj/hzq/code/gnn/dataset/"
)
def
test
():
print
(
g
)
g
,
node_data
=
load_ogb_dataset
(
'ogbn-products'
,
"/home/zlj/hzq/code/gnn/dataset/"
)
# for worker in [1,2,3,4,5,6,7,8,9,10,20,30]:
print
(
g
)
g_data
=
GraphData
(
id
=
1
,
edge_index
=
g
.
edge_index
,
data
=
g
,
partptr
=
torch
.
tensor
([
0
,
g
.
num_nodes
//
4
,
g
.
num_nodes
//
4
*
2
,
g
.
num_nodes
//
4
*
3
,
g
.
num_nodes
]))
# for worker in [1,2,3,4,5,6,7,8,9,10,20,30]:
g_data
=
GraphData
(
id
=
1
,
edge_index
=
g
.
edge_index
,
data
=
g
,
partptr
=
torch
.
tensor
([
0
,
g
.
num_nodes
//
4
,
g
.
num_nodes
//
4
*
2
,
g
.
num_nodes
//
4
*
3
,
g
.
num_nodes
]))
row
,
col
=
g
.
edge_index
row
,
col
=
g
.
edge_index
# edge_weight = torch.ones(g.num_edges).float()
# indices = [x for x in range(0, g.num_edges, 5)]
# edge_weight = torch.ones(g.num_edges).float()
# edge_weight[indices] = 2.0
# indices = [x for x in range(0, g.num_edges, 5)]
# edge_weight[indices] = 2.0
# g_data.eid = None
edge_weight
=
None
# g_data.eid = None
timestamp
=
None
edge_weight
=
None
timestamp
=
None
from
neighbor_sampler
import
NeighborSampler
,
SampleType
from
neighbor_sampler
import
get_neighbors
from
.neighbor_sampler
import
NeighborSampler
,
SampleType
from
.neighbor_sampler
import
get_neighbors
update_edge_row
=
row
update_edge_col
=
col
update_edge_row
=
row
update_edge_w
=
torch
.
DoubleTensor
([
i
for
i
in
range
(
g
.
num_edges
)])
update_edge_col
=
col
print
(
'begin update'
)
update_edge_w
=
torch
.
DoubleTensor
([
i
for
i
in
range
(
g
.
num_edges
)])
pre
=
time
.
time
()
# print('begin update')
# update_edge_weight(tnb, update_edge_row.contiguous(), update_edge_col.contiguous(), update_edge_w.contiguous())
# pre = time.time()
end
=
time
.
time
()
# # update_edge_weight(tnb, update_edge_row.contiguous(), update_edge_col.contiguous(), update_edge_w.contiguous())
print
(
"update time:"
,
end
-
pre
)
# end = time.time()
# print("update time:", end-pre)
print
(
'begin tnb'
)
pre
=
time
.
time
()
print
(
'begin tnb'
)
tnb
=
get_neighbors
(
"a"
,
pre
=
time
.
time
()
tnb
=
get_neighbors
(
"a"
,
row
.
contiguous
(),
row
.
contiguous
(),
col
.
contiguous
(),
col
.
contiguous
(),
g
.
num_nodes
,
0
,
g
.
num_nodes
,
0
,
g_data
.
eid
,
g_data
.
eid
,
edge_weight
,
edge_weight
,
timestamp
)
timestamp
)
end
=
time
.
time
()
end
=
time
.
time
()
print
(
"init tnb time:"
,
end
-
pre
)
print
(
"init tnb time:"
,
end
-
pre
)
torch
.
save
(
tnb
,
"/home/zlj/hzq/code/gnn/my_sampler/MergeSample/tnb_static.my"
)
#
torch.save(tnb, "/home/zlj/hzq/code/gnn/my_sampler/MergeSample/tnb_static.my")
# print('begin load')
# print('begin load')
# pre = time.time()
# pre = time.time()
# tnb = torch.load("/home/zlj/hzq/code/gnn/my_sampler/MergeSample/tnb_static.my")
# tnb = torch.load("/home/zlj/hzq/code/gnn/my_sampler/MergeSample/tnb_static.my")
# end = time.time()
# end = time.time()
# print("load time:", end-pre)
# print("load time:", end-pre)
print
(
'begin init'
)
print
(
'begin init'
)
pre
=
time
.
time
()
pre
=
time
.
time
()
sampler
=
NeighborSampler
(
g
.
num_nodes
,
sampler
=
NeighborSampler
(
g
.
num_nodes
,
tnb
=
tnb
,
tnb
=
tnb
,
num_layers
=
2
,
num_layers
=
2
,
fanout
=
[
100
,
100
],
fanout
=
[
100
,
100
],
...
@@ -75,30 +76,34 @@ sampler = NeighborSampler(g.num_nodes,
...
@@ -75,30 +76,34 @@ sampler = NeighborSampler(g.num_nodes,
workers
=
10
,
workers
=
10
,
graph_name
=
'a'
,
graph_name
=
'a'
,
policy
=
"uniform"
)
policy
=
"uniform"
)
end
=
time
.
time
()
end
=
time
.
time
()
print
(
"init time:"
,
end
-
pre
)
print
(
"init time:"
,
end
-
pre
)
# from torch_geometric.sampler import NeighborSampler, NumNeighbors, NodeSamplerInput, SamplerOutput
# from torch_geometric.sampler import NeighborSampler, NumNeighbors, NodeSamplerInput, SamplerOutput
# pre = time.time()
# pre = time.time()
# num_nei = NumNeighbors([100, 100])
# num_nei = NumNeighbors([100, 100])
# node_idx = NodeSamplerInput(input_id=None, node=torch.tensor(range(g.num_nodes//4, g.num_nodes//4+600000)))# (input_id=None, node=torch.masked_select(torch.arange(g.num_nodes),node_data['train_mask']))
# node_idx = NodeSamplerInput(input_id=None, node=torch.tensor(range(g.num_nodes//4, g.num_nodes//4+600000)))# (input_id=None, node=torch.masked_select(torch.arange(g.num_nodes),node_data['train_mask']))
# sampler = NeighborSampler(g, num_nei)
# sampler = NeighborSampler(g, num_nei)
# end = time.time()
# end = time.time()
# print("init time:", end-pre)
# print("init time:", end-pre)
pre
=
time
.
time
()
pre
=
time
.
time
()
out
=
sampler
.
sample_from_nodes
(
torch
.
tensor
(
range
(
g
.
num_nodes
//
4
,
g
.
num_nodes
//
4
+
600000
)),
with_outer_sample
=
SampleType
.
Inner
)
# sampler.sample_from_nodes(torch.masked_select(torch.arange(g.num_nodes),node_data['train_mask']))
out
=
sampler
.
sample_from_nodes
(
torch
.
tensor
(
range
(
g
.
num_nodes
//
4
,
g
.
num_nodes
//
4
+
600000
)))
# sampler.sample_from_nodes(torch.masked_select(torch.arange(g.num_nodes),node_data['train_mask']))
# out = sampler.sample_from_nodes(node_idx)
# out = sampler.sample_from_nodes(node_idx)
# node = out.node
# node = out.node
# edge = [out.row, out.col]
# edge = [out.row, out.col]
end
=
time
.
time
()
end
=
time
.
time
()
print
(
'node1:
\t
'
,
out
[
0
]
.
sample_nodes
())
print
(
'node1:
\t
'
,
out
[
0
]
.
sample_nodes
())
print
(
'eid1:
\t
'
,
out
[
0
]
.
eid
())
print
(
'eid1:
\t
'
,
out
[
0
]
.
eid
())
print
(
'edge1:
\t
'
,
g
.
edge_index
[:,
out
[
0
]
.
eid
()])
print
(
'edge1:
\t
'
,
g
.
edge_index
[:,
out
[
0
]
.
eid
()])
print
(
'node2:
\t
'
,
out
[
1
]
.
sample_nodes
())
print
(
'node2:
\t
'
,
out
[
1
]
.
sample_nodes
())
print
(
'eid2:
\t
'
,
out
[
1
]
.
eid
())
print
(
'eid2:
\t
'
,
out
[
1
]
.
eid
())
print
(
'edge2:
\t
'
,
g
.
edge_index
[:,
out
[
1
]
.
eid
()])
print
(
'edge2:
\t
'
,
g
.
edge_index
[:,
out
[
1
]
.
eid
()])
print
(
"sample time"
,
end
-
pre
)
print
(
"sample time"
,
end
-
pre
)
\ No newline at end of file
if
__name__
==
"__main__"
:
test
()
\ No newline at end of file
starrygl/sample/sample_core/reddit_demo_before.py
View file @
b9450874
...
@@ -3,8 +3,6 @@ import random
...
@@ -3,8 +3,6 @@ import random
import
pandas
as
pd
import
pandas
as
pd
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
from
ogb.nodeproppred
import
PygNodePropPredDataset
from
torch_geometric.datasets
import
Reddit
import
time
import
time
from
tqdm
import
tqdm
from
tqdm
import
tqdm
...
@@ -25,7 +23,7 @@ class NegLinkInductiveSampler:
...
@@ -25,7 +23,7 @@ class NegLinkInductiveSampler:
def
sample
(
self
,
n
):
def
sample
(
self
,
n
):
return
np
.
random
.
choice
(
self
.
nodes
,
size
=
n
)
return
np
.
random
.
choice
(
self
.
nodes
,
size
=
n
)
def
load_reddit_dataset
(
data_path
):
def
load_reddit_dataset
():
df
=
pd
.
read_csv
(
'/mnt/data/hzq/DATA/{}/edges.csv'
.
format
(
"REDDIT"
))
df
=
pd
.
read_csv
(
'/mnt/data/hzq/DATA/{}/edges.csv'
.
format
(
"REDDIT"
))
num_nodes
=
max
(
int
(
df
[
'src'
]
.
max
()),
int
(
df
[
'dst'
]
.
max
()))
+
1
num_nodes
=
max
(
int
(
df
[
'src'
]
.
max
()),
int
(
df
[
'dst'
]
.
max
()))
+
1
src
=
torch
.
tensor
(
df
[
'src'
]
.
to_numpy
(
dtype
=
int
))
src
=
torch
.
tensor
(
df
[
'src'
]
.
to_numpy
(
dtype
=
int
))
...
@@ -35,6 +33,16 @@ def load_reddit_dataset(data_path):
...
@@ -35,6 +33,16 @@ def load_reddit_dataset(data_path):
g
=
GraphData
(
0
,
edge_index
,
timestamp
=
timestamp
,
data
=
None
,
partptr
=
torch
.
tensor
([
0
,
num_nodes
]))
g
=
GraphData
(
0
,
edge_index
,
timestamp
=
timestamp
,
data
=
None
,
partptr
=
torch
.
tensor
([
0
,
num_nodes
]))
return
g
,
df
return
g
,
df
def
load_gdelt_dataset
():
df
=
pd
.
read_csv
(
'/mnt/data/hzq/DATA/{}/edges.csv'
.
format
(
"GDELT"
))
num_nodes
=
max
(
int
(
df
[
'src'
]
.
max
()),
int
(
df
[
'dst'
]
.
max
()))
+
1
src
=
torch
.
tensor
(
df
[
'src'
]
.
to_numpy
(
dtype
=
int
))
dst
=
torch
.
tensor
(
df
[
'dst'
]
.
to_numpy
(
dtype
=
int
))
edge_index
=
torch
.
stack
([
src
,
dst
])
timestamp
=
torch
.
tensor
(
df
[
'time'
])
.
float
()
g
=
GraphData
(
0
,
edge_index
,
timestamp
=
timestamp
,
data
=
None
,
partptr
=
torch
.
tensor
([
0
,
num_nodes
]))
return
g
,
df
def
test
():
def
test
():
parser
=
argparse
.
ArgumentParser
()
parser
=
argparse
.
ArgumentParser
()
...
@@ -43,6 +51,7 @@ def test():
...
@@ -43,6 +51,7 @@ def test():
parser
.
add_argument
(
'--batch_size'
,
type
=
int
,
default
=
600
,
help
=
'path to config file'
)
parser
.
add_argument
(
'--batch_size'
,
type
=
int
,
default
=
600
,
help
=
'path to config file'
)
parser
.
add_argument
(
'--num_thread'
,
type
=
int
,
default
=
64
,
help
=
'number of thread'
)
parser
.
add_argument
(
'--num_thread'
,
type
=
int
,
default
=
64
,
help
=
'number of thread'
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
dataset
=
"gdelt"
#"reddit"#"gdelt"
seed
=
10
seed
=
10
torch
.
manual_seed
(
seed
)
# 为CPU设置随机种子
torch
.
manual_seed
(
seed
)
# 为CPU设置随机种子
...
@@ -53,7 +62,7 @@ def test():
...
@@ -53,7 +62,7 @@ def test():
torch
.
backends
.
cudnn
.
benchmark
=
False
torch
.
backends
.
cudnn
.
benchmark
=
False
torch
.
backends
.
cudnn
.
deterministic
=
True
torch
.
backends
.
cudnn
.
deterministic
=
True
g_data
,
df
=
load_
reddit_dataset
(
"/mnt/data/hzq/DATA/REDDIT"
)
g_data
,
df
=
load_
gdelt_dataset
(
)
print
(
g_data
)
print
(
g_data
)
# for worker in [1,2,3,4,5,6,7,8,9,10,20,30]:
# for worker in [1,2,3,4,5,6,7,8,9,10,20,30]:
# import random
# import random
...
@@ -84,11 +93,13 @@ def test():
...
@@ -84,11 +93,13 @@ def test():
row
=
row
[
ind
]
row
=
row
[
ind
]
col
=
col
[
ind
]
col
=
col
[
ind
]
print
(
row
,
col
)
print
(
row
,
col
)
g2
=
GraphData
(
0
,
torch
.
stack
([
row
,
col
]),
timestamp
=
timestamp
,
data
=
None
,
partptr
=
torch
.
tensor
([
0
,
max
(
int
(
df
[
'src'
]
.
max
()),
int
(
df
[
'dst'
]
.
max
()))
+
1
]))
print
(
g2
)
pre
=
time
.
time
()
pre
=
time
.
time
()
tnb
=
get_neighbors
(
"reddit"
,
row
.
contiguous
(),
col
.
contiguous
(),
g_data
.
num_nodes
,
0
,
eid
,
edge_weight
,
timestamp
)
tnb
=
get_neighbors
(
dataset
,
row
.
contiguous
(),
col
.
contiguous
(),
g_data
.
num_nodes
,
0
,
eid
,
edge_weight
,
timestamp
)
end
=
time
.
time
()
end
=
time
.
time
()
print
(
"init tnb time:"
,
end
-
pre
)
print
(
"init tnb time:"
,
end
-
pre
)
torch
.
save
(
tnb
,
"tnb_reddit_before.my"
)
# torch.save(tnb, "tnb_{}_before.my".format(dataset), pickle_protocol=4
)
pre
=
time
.
time
()
pre
=
time
.
time
()
...
@@ -138,17 +149,17 @@ def test():
...
@@ -138,17 +149,17 @@ def test():
sam_time
+=
outi
[
0
]
.
sample_time
sam_time
+=
outi
[
0
]
.
sample_time
# print(outi[0].sample_edge_num)
# print(outi[0].sample_edge_num)
sam_edge
+=
outi
[
0
]
.
sample_edge_num
sam_edge
+=
outi
[
0
]
.
sample_edge_num
out
.
append
(
outi
)
#
out.append(outi)
end
=
time
.
time
()
end
=
time
.
time
()
print
(
"row"
,
out
[
23
][
0
]
.
row
())
#
print("row", out[23][0].row())
print
(
"sample time"
,
end
-
pre
)
print
(
"sample time"
,
end
-
pre
)
print
(
"tot_time"
,
tot_time
)
print
(
"tot_time"
,
tot_time
)
print
(
"sam_time"
,
sam_time
)
print
(
"sam_time"
,
sam_time
)
print
(
"sam_edge"
,
sam_edge
)
print
(
"sam_edge"
,
sam_edge
)
print
(
'eid_list:'
,
out
[
23
][
0
]
.
eid
())
#
print('eid_list:', out[23][0].eid())
# print('delta_ts_list:', out[10][0].delta_ts)
# print('delta_ts_list:', out[10][0].delta_ts)
print
(
'node:'
,
out
[
23
][
0
]
.
sample_nodes
())
#
print('node:', out[23][0].sample_nodes())
# print('node_ts:', out[23][0].sample_nodes_ts)
# print('node_ts:', out[23][0].sample_nodes_ts)
# print('eid_list:', out[23][1].eid)
# print('eid_list:', out[23][1].eid)
# print('node:', out[23][1].sample_nodes)
# print('node:', out[23][1].sample_nodes)
...
@@ -158,6 +169,7 @@ def test():
...
@@ -158,6 +169,7 @@ def test():
# print("min_than_ten_sum", min_than_ten_sum)
# print("min_than_ten_sum", min_than_ten_sum)
# print("seed_node_sum", seed_node_sum)
# print("seed_node_sum", seed_node_sum)
# print("predict edge_num", (seed_node_sum-min_than_ten)*9+min_than_ten_sum)
# print("predict edge_num", (seed_node_sum-min_than_ten)*9+min_than_ten_sum)
print
(
'吞吐量 : {:.4f}'
.
format
(
sam_edge
/
(
end
-
pre
)))
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
test
()
test
()
\ No newline at end of file
starrygl/sample/sample_core/reddit_demo_static.py
0 → 100644
View file @
b9450874
import
argparse
import
random
import
pandas
as
pd
import
numpy
as
np
import
torch
import
time
from
tqdm
import
tqdm
from
.Utils
import
GraphData
def
load_reddit_dataset
():
df
=
pd
.
read_csv
(
'/mnt/data/hzq/DATA/{}/edges.csv'
.
format
(
"REDDIT"
))
num_nodes
=
max
(
int
(
df
[
'src'
]
.
max
()),
int
(
df
[
'dst'
]
.
max
()))
+
1
src
=
torch
.
tensor
(
df
[
'src'
]
.
to_numpy
(
dtype
=
int
))
dst
=
torch
.
tensor
(
df
[
'dst'
]
.
to_numpy
(
dtype
=
int
))
row
=
torch
.
cat
([
src
,
dst
])
col
=
torch
.
cat
([
dst
,
src
])
edge_index
=
torch
.
stack
([
row
,
col
])
timestamp
=
torch
.
tensor
(
df
[
'time'
])
.
float
()
g
=
GraphData
(
0
,
edge_index
,
timestamp
=
None
,
data
=
None
,
partptr
=
torch
.
tensor
([
0
,
num_nodes
]))
return
g
def
load_gdelt_dataset
():
df
=
pd
.
read_csv
(
'/mnt/data/hzq/DATA/{}/edges.csv'
.
format
(
"GDELT"
))
num_nodes
=
max
(
int
(
df
[
'src'
]
.
max
()),
int
(
df
[
'dst'
]
.
max
()))
+
1
src
=
torch
.
tensor
(
df
[
'src'
]
.
to_numpy
(
dtype
=
int
))
dst
=
torch
.
tensor
(
df
[
'dst'
]
.
to_numpy
(
dtype
=
int
))
row
=
torch
.
cat
([
src
,
dst
])
col
=
torch
.
cat
([
dst
,
src
])
edge_index
=
torch
.
stack
([
row
,
col
])
timestamp
=
torch
.
tensor
(
df
[
'time'
])
.
float
()
g
=
GraphData
(
0
,
edge_index
,
timestamp
=
None
,
data
=
None
,
partptr
=
torch
.
tensor
([
0
,
num_nodes
]))
return
g
def
load_ogb_dataset
():
from
ogb.nodeproppred
import
PygNodePropPredDataset
dataset
=
PygNodePropPredDataset
(
name
=
'ogbn-products'
,
root
=
"/home/zlj/hzq/code/gnn/dataset/"
)
split_idx
=
dataset
.
get_idx_split
()
g
=
dataset
[
0
]
n_node
=
g
.
num_nodes
node_data
=
{}
node_data
[
'train_mask'
]
=
torch
.
zeros
(
n_node
,
dtype
=
torch
.
bool
)
node_data
[
'val_mask'
]
=
torch
.
zeros
(
n_node
,
dtype
=
torch
.
bool
)
node_data
[
'test_mask'
]
=
torch
.
zeros
(
n_node
,
dtype
=
torch
.
bool
)
node_data
[
'train_mask'
][
split_idx
[
"train"
]]
=
True
node_data
[
'val_mask'
][
split_idx
[
"valid"
]]
=
True
node_data
[
'test_mask'
][
split_idx
[
"test"
]]
=
True
src
,
dst
=
g
.
edge_index
row
=
torch
.
cat
([
src
,
dst
])
col
=
torch
.
cat
([
dst
,
src
])
edge_index
=
torch
.
stack
([
row
,
col
])
g
=
GraphData
(
id
=
0
,
edge_index
=
edge_index
,
data
=
g
,
partptr
=
torch
.
tensor
([
0
,
g
.
num_nodes
]))
return
g
# , node_data
def
test
():
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
'--data'
,
type
=
str
,
help
=
'dataset name'
,
default
=
"REDDIT"
)
# parser.add_argument('--config', type=str, help='path to config file',default="/home/zlj/hzq/project/code/TGL/config/TGN.yml")
parser
.
add_argument
(
'--batch_size'
,
type
=
int
,
default
=
600
,
help
=
'path to config file'
)
# parser.add_argument('--num_thread', type=int, default=64, help='number of thread')
args
=
parser
.
parse_args
()
seed
=
10
torch
.
manual_seed
(
seed
)
# 为CPU设置随机种子
torch
.
cuda
.
manual_seed
(
seed
)
# 为当前GPU设置随机种子
torch
.
cuda
.
manual_seed_all
(
seed
)
# if you are using multi-GPU,为所有GPU设置随机种子
np
.
random
.
seed
(
seed
)
# Numpy module.
random
.
seed
(
seed
)
# Python random module.
torch
.
backends
.
cudnn
.
benchmark
=
False
torch
.
backends
.
cudnn
.
deterministic
=
True
g_data
=
load_reddit_dataset
()
print
(
g_data
)
from
.neighbor_sampler
import
NeighborSampler
,
get_neighbors
# print('begin tnb')
# row, col = g_data.edge_index
# pre = time.time()
# tnb = get_neighbors("a",
# row.contiguous(),
# col.contiguous(),
# g_data.num_nodes, 0,
# g_data.eid,
# None,
# None)
# end = time.time()
# print("init tnb time:", end-pre)
pre
=
time
.
time
()
sampler
=
NeighborSampler
(
g_data
.
num_nodes
,
num_layers
=
2
,
fanout
=
[
10
,
10
],
graph_data
=
g_data
,
workers
=
32
,
policy
=
"uniform"
,
graph_name
=
'a'
,
is_distinct
=
1
)
end
=
time
.
time
()
print
(
"init time:"
,
end
-
pre
)
# print(sampler.tnb.deg[0:10])
n_list
=
[]
for
i
in
range
(
g_data
.
num_nodes
.
item
()
//
args
.
batch_size
):
if
i
+
args
.
batch_size
<
g_data
.
num_nodes
.
item
():
n_list
.
append
(
range
(
i
*
args
.
batch_size
,
i
*
args
.
batch_size
+
args
.
batch_size
))
else
:
n_list
.
append
(
range
(
i
*
args
.
batch_size
,
g_data
.
num_nodes
.
item
()))
# print(n_list)
out
=
[]
tot_time
=
0
sam_time
=
0
sam_edge
=
0
sam_node
=
0
pre
=
time
.
time
()
for
i
,
nodes
in
tqdm
(
enumerate
(
n_list
),
total
=
g_data
.
num_nodes
.
item
()
//
args
.
batch_size
):
# for nodes in n_list:
root_nodes
=
torch
.
tensor
(
nodes
)
.
long
()
outi
=
sampler
.
sample_from_nodes
(
root_nodes
)
sam_node
+=
outi
[
0
]
.
sample_nodes
()
.
size
(
0
)
sam_node
+=
outi
[
1
]
.
sample_nodes
()
.
size
(
0
)
sam_edge
+=
outi
[
0
]
.
sample_edge_num
end
=
time
.
time
()
print
(
"sample time"
,
end
-
pre
)
print
(
"sam_edge"
,
sam_edge
)
print
(
"sam_node"
,
sam_node
)
print
(
'边吞吐量 : {:.4f}'
.
format
(
sam_edge
/
(
end
-
pre
)))
print
(
'点吞吐量 : {:.4f}'
.
format
(
sam_node
/
(
end
-
pre
)))
if
__name__
==
"__main__"
:
test
()
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment