zhlj / starrygl-DynamicHistory / Commits / a70d1b3b

Commit a70d1b3b, authored Dec 19, 2023 by zlj
Parent: abce8f53

    add sample module

Showing 11 changed files with 295 additions and 91 deletions:

    a.py                                        +0    -20
    b.py                                        +133  -0
    data_maker.py                               +2    -0
    nohup.out                                   +14   -0
    starrygl/distributed/utils.py               +33   -2
    starrygl/sample/batch_data.py               +41   -25
    starrygl/sample/data_loader.py              +13   -5
    starrygl/sample/graph_core/__init__.py      +16   -5
    starrygl/sample/memory/shared_mailbox.py    +6    -3
    starrygl/sample/stream_manager.py           +5    -2
    train_tgnn.py                               +32   -29
a.py (deleted, 100644 → 0)

def foo():
    a = 1
    def fa():
        print(a)
    a += 1
    print(a)
    def fb(a):
        def apply():
            print(a)
        return apply
    fc = lambda: print(a)
    return fa, fb(a), fc

fa, fb, fc = foo()
fa()
fb()
fc()
b.py (new file, 0 → 100644)
import argparse
import os
import sys
from os.path import abspath, join, dirname
from starrygl.distributed.context import DistributedContext
from starrygl.distributed.utils import DistIndex
from starrygl.module.modules import GeneralModel
from starrygl.module.utils import parse_config
from starrygl.sample.graph_core import DataSet, GraphData, TemporalNeighborSampleGraph
from starrygl.sample.memory.shared_mailbox import SharedMailBox
from starrygl.sample.sample_core.base import NegativeSampling
from starrygl.sample.sample_core.neighbor_sampler import NeighborSampler
from starrygl.sample.part_utils.partition_tgnn import partition_load
import torch
import time
import torch
import torch.nn.functional as F
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import init_process_group, destroy_process_group
import os
from starrygl.sample.data_loader import DistributedDataLoader
from starrygl.sample.batch_data import SAMPLE_TYPE
"""
test command
python test.py --world_size 2 --rank 0
--world_size', default=4, type=int, metavar='W',
help='number of workers')
parser.add_argument('--rank', default=0, type=int, metavar='W',
help='rank of the worker')
parser.add_argument('--log_interval', type=int, default=10, metavar='N',
help='interval between training status logs')
parser.add_argument('--gamma', type=float, default=0.99, metavar='G',
help='how much to value future rewards')
parser.add_argument('--seed', type=int, default=1, metavar='S',
help='random seed for reproducibility')
parser.add_argument('--num_sampler', type=int, default=10, metavar='S',
help='number of samplers')
parser.add_argument('--queue_size', type=int, default=10, metavar='S',
help='sampler queue size')
"""
parser = argparse.ArgumentParser(
    description="RPC Reinforcement Learning Example",
    formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument('--rank', default=0, type=str, metavar='W',
                    help='name of dataset')
parser.add_argument('--world_size', default=1, type=int, metavar='W',
                    help='number of negative samples')
args = parser.parse_args()
from sklearn.metrics import average_precision_score, roc_auc_score
import torch
import time
import random
import dgl
import numpy as np
from sklearn.metrics import average_precision_score, roc_auc_score
from torch.nn.parallel import DistributedDataParallel as DDP
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.rank)
os.environ["RANK"] = str(args.rank)
os.environ["WORLD_SIZE"] = str(args.world_size)
os.environ["LOCAL_RANK"] = str(0)
os.environ["MASTER_ADDR"] = '127.0.0.1'
os.environ["MASTER_PORT"] = '9337'
def seed_everything(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

seed_everything(1234)
def main():
    use_cuda = True
    sample_param, memory_param, gnn_param, train_param = parse_config('./config/TGN.yml')
    torch.set_num_threads(12)
    ctx = DistributedContext.init(backend="nccl", use_gpu=True)
    device_id = torch.cuda.current_device()
    print('use cuda on', device_id)
    pdata = partition_load("./dataset/here/WIKI", algo="metis_for_tgnn")
    graph = GraphData(pdata=pdata)
    #dist.barrier()
    #for i in range(100):
    #    print(i)
    dist.barrier()
    idx = ((graph.eids_mapper >> 48).int() & 0xFFFF)
    print((idx == 0).nonzero().shape, (idx == 1).nonzero().shape)
    t1 = time.time()
    """
    fut = []
    for i in range(1000):
        #print(i)
        out = graph.edge_attr.index_select(graph.eids_mapper[(idx== 0)|(idx ==1)].to('cuda'))
        fut.append(out)
        #out.wait()
        #out.value()
        if i>0 and i %100 ==0:
            f = torch.futures.collect_all(fut)
            f.wait()
            f.value()
            fut = []
    """
    partptr = torch.tensor([((i & 0xFFFF) << 48) for i in range(3)], device='cuda')
    for i in range(1000):
        if i % 100 == 0:
            idx = graph.eids_mapper.to('cuda')
            idx, inv = idx.unique(return_inverse=True)
            ind = torch.searchsorted(idx, partptr, right=False)
            len = ind[1:] - ind[:-1]
            gatherlen = torch.empty([2], dtype=torch.long, device='cuda')
            dist.all_to_all_single(gatherlen, len)
            query_idx = torch.empty([gatherlen.sum()], dtype=torch.long, device='cuda')
            input_s = list(len)
            output_s = list(gatherlen)
            dist.all_to_all_single(query_idx, idx, output_s, input_s)
            input_f = graph.edge_attr.accessor.data[DistIndex(query_idx).loc]
            f = torch.empty([idx.shape[0], graph.edge_attr.accessor.data.shape[1]],
                            dtype=torch.float, device='cuda')
            dist.all_to_all_single(f, input_f, input_s, output_s)
    torch.cuda.synchronize()
    t2 = time.time() - t1
    print(t2)
    #dist.barrier()
    ctx.shutdown()

if __name__ == "__main__":
    main()
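Editor's note: the bit operations in main() ((eids_mapper >> 48).int() & 0xFFFF and ((i & 0xFFFF) << 48)) suggest that a distributed id packs a 16-bit partition number into the high bits of a 64-bit integer and keeps the local offset in the low 48 bits, with DistIndex(...).loc presumably recovering the local offset. The snippet below is a minimal standalone sketch of that packing scheme under this assumption; pack_dist_id and unpack_dist_id are illustrative names, not part of starrygl.

    import torch

    def pack_dist_id(part_id: torch.Tensor, local_off: torch.Tensor) -> torch.Tensor:
        # assumed layout: high 16 bits = partition id, low 48 bits = local offset
        return ((part_id.long() & 0xFFFF) << 48) | (local_off.long() & ((1 << 48) - 1))

    def unpack_dist_id(dist_id: torch.Tensor):
        part_id = (dist_id >> 48) & 0xFFFF      # mirrors (eids_mapper >> 48).int() & 0xFFFF
        local_off = dist_id & ((1 << 48) - 1)   # what DistIndex(...).loc presumably returns
        return part_id, local_off

    ids = pack_dist_id(torch.tensor([0, 1, 1]), torch.tensor([7, 0, 3]))
    print(unpack_dist_id(ids))  # (tensor([0, 1, 1]), tensor([7, 0, 3]))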
data_maker.py

@@ -118,6 +118,8 @@ partition_save('./dataset/here/'+data_name, data, 1, 'metis_for_tgnn',
                edge_weight_dict=edge_weight_dict)
partition_save('./dataset/here/'+data_name, data, 2, 'metis_for_tgnn',
               edge_weight_dict=edge_weight_dict)
partition_save('./dataset/here/'+data_name, data, 4, 'metis_for_tgnn',
               edge_weight_dict=edge_weight_dict)
#
# partition_save('./dataset/here/'+data_name, data, 4, 'metis_for_tgnn',
#                edge_weight_dict=edge_weight_dict )
nohup.out (new file, 0 → 100644)
ERROR:root:unable to import libstarrygl.so, some features may not be available.
the number of nodes in graph is 1980, the number of edges in graph is 1293103
directory '/home/zlj/starrygl/dataset/here/LASTFM/metis_for_tgnn_1' not empty and cleared
running partition algorithm: metis_for_tgnn
saving partition data: 1/1
running partition algorithm: metis_for_tgnn
saving partition data: 1/2
saving partition data: 2/2
creating directory '/home/zlj/starrygl/dataset/here/LASTFM/metis_for_tgnn_4'
running partition algorithm: metis_for_tgnn
saving partition data: 1/4
saving partition data: 2/4
saving partition data: 3/4
saving partition data: 4/4
starrygl/distributed/utils.py

@@ -15,6 +15,7 @@ class TensorAccessor:
        self._data = data
        self._ctx = DistributedContext.get_default_context()
        self._rref = rpc.RRef(data)
        self._rref.confirmed_by_owner

    @property
    def data(self):
@@ -28,6 +29,20 @@ class TensorAccessor:
    def ctx(self):
        return self._ctx

    def all_gather_index(self, index, input_split) -> Tensor:
        out_split = torch.empty_like(input_split)
        torch.distributed.all_to_all_single(out_split, input_split)
        input_split = list(input_split)
        output = torch.empty([out_split.sum()], dtype=index.dtype, device=index.device)
        out_split = list(out_split)
        torch.distributed.all_to_all_single(output, index, out_split, input_split)
        return output, out_split, input_split

    def all_gather_data(self, index, input_split, out_split):
        output = torch.empty([int(Tensor(out_split).sum().item()), *self._data.shape[1:]],
                             dtype=self._data.dtype, device='cuda')  #self._data.device)
        torch.distributed.all_to_all_single(output, self.data[index.to(self.data.device)].to('cuda'),
                                            output_split_sizes=out_split, input_split_sizes=input_split)
        return output

    def all_gather_rrefs(self) -> List[rpc.RRef]:
        return self.ctx.all_gather_remote_objects(self.rref)
@@ -101,7 +116,6 @@ class DistributedTensor:
    def __init__(self, data: Tensor) -> None:
        self.accessor = TensorAccessor(data)
        self.rrefs = self.accessor.all_gather_rrefs()
        # self.num_parts = len(self.rrefs)
        local_sizes = []
@@ -110,6 +124,7 @@ class DistributedTensor:
            local_sizes.append(n)
        self.num_nodes = DistInt(local_sizes)
        self.num_parts = DistInt([1] * len(self.rrefs))
        self.distptr = torch.tensor([((part_ids & 0xFFFF) << 48) for part_ids in range(self.num_parts() + 1)],
                                    device='cuda')  #data.device)

    @property
@@ -129,7 +144,18 @@ class DistributedTensor:
    @property
    def ctx(self):
        return self.accessor.ctx

    def gather_select_index(self, dist_index: Union[Tensor, DistIndex]):
        if isinstance(dist_index, Tensor):
            dist_index = DistIndex(dist_index)
        data = dist_index.data
        posptr = torch.searchsorted(data, self.distptr, right=False)
        input_split = posptr[1:] - posptr[:-1]
        return self.accessor.all_gather_index(DistIndex(data).loc, input_split)

    def scatter_data(self, local_index, input_split, out_split):
        return self.accessor.all_gather_data(local_index, input_split=input_split, out_split=out_split)

    def index_select(self, dist_index: Union[Tensor, DistIndex]):
        if isinstance(dist_index, Tensor):
            dist_index = DistIndex(dist_index)
@@ -139,6 +165,10 @@ class DistributedTensor:
        futs: List[torch.futures.Future] = []
        for i in range(self.num_parts()):
            #if i != torch.distributed.get_rank():
            #    continue
            #f = torch.futures.Future()
            #f.set_result(self.accessor.data[index[part_idx == i]])
            f = self.accessor.async_index_select(0, index[part_idx == i], self.rrefs[i])
            futs.append(f)
@@ -150,6 +180,7 @@ class DistributedTensor:
            result = torch.empty(
                part_idx.size(0), *t.shape[1:],
                dtype=t.dtype, device=t.device,
            )
            result[part_idx == i] = t
            return result
        return torch.futures.collect_all(futs).then(callback)
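Editor's note: gather_select_index and scatter_data together pull remote rows with three all_to_all rounds: each rank buckets the requested distributed ids per partition (searchsorted against distptr), exchanges the bucket sizes, exchanges the indices themselves, and finally ships the selected rows back with the split sizes reversed. The function below is a minimal sketch of that pattern with plain torch.distributed calls, assuming a process group that supports all_to_all_single (e.g. NCCL) launched with torchrun; names and shapes are illustrative, not the starrygl API.

    import torch
    import torch.distributed as dist

    def pull_rows(local_table: torch.Tensor, want_idx: torch.Tensor, split_sizes: torch.Tensor):
        """want_idx: remote row numbers grouped per destination rank; split_sizes[r] = #indices for rank r."""
        # round 1: tell every rank how many indices it will receive from us
        recv_sizes = torch.empty_like(split_sizes)
        dist.all_to_all_single(recv_sizes, split_sizes)
        # round 2: exchange the indices themselves
        recv_idx = torch.empty([int(recv_sizes.sum())], dtype=want_idx.dtype)
        dist.all_to_all_single(recv_idx, want_idx,
                               output_split_sizes=recv_sizes.tolist(),
                               input_split_sizes=split_sizes.tolist())
        # round 3: answer with the selected rows, split sizes reversed
        out = torch.empty([int(split_sizes.sum()), *local_table.shape[1:]], dtype=local_table.dtype)
        dist.all_to_all_single(out, local_table[recv_idx],
                               output_split_sizes=split_sizes.tolist(),
                               input_split_sizes=recv_sizes.tolist())
        return out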
starrygl/sample/batch_data.py

from typing import List, Tuple
import torch
import torch.distributed as dist
from starrygl.distributed.utils import DistributedTensor
from starrygl.module.memorys import MailBox
from starrygl.sample.graph_core import DataSet
from starrygl.sample.graph_core import GraphData
from starrygl.sample.sample_core.base import BaseSampler, NegativeSampling
@@ -40,7 +42,7 @@ def prepare_input(node_feat, edge_feat, mem_embedding,mfgs,dist_nid,dist_eid):
    #print(idx.shape[0],b.srcdata['mem_ts'].shape)
    return mfgs

def to_block(graph: GraphData, data, sample_out, mailbox = None, device = torch.device('cuda')):
def to_block(graph: GraphData, data, sample_out, mailbox: MailBox = None, device = torch.device('cuda')):
    if len(sample_out) > 1:
        sample_out, metadata = sample_out
@@ -55,7 +57,6 @@ def to_block(graph: GraphData, data, sample_out, mailbox = None,device = torch.d
    dist_eid, eid_inv = dist_eid.unique(return_inverse=True)
    src_node = graph.sample_graph['edge_index'][0, eid_tensor*2].to(graph.nids_mapper.device)
    src_ts = None
    edge_feat = graph._get_edge_attr(dist_eid)
    if metadata is None:
        root_node = data.nodes.to(graph.nids_mapper.device)
        root_len = [root_node.shape[0]]
@@ -74,9 +75,24 @@ def to_block(graph: GraphData, data, sample_out, mailbox = None,device = torch.d
    nid_tensor = torch.cat([root_node, src_node], dim=0)
    dist_nid = nid_mapper[nid_tensor].to(device)
    dist_nid, nid_inv = dist_nid.unique(return_inverse=True)
    node_feat = graph._get_node_attr(dist_nid)
    if isinstance(graph.edge_attr, DistributedTensor):
        local_index, input_split, output_split = graph.edge_attr.gather_select_index(dist_eid)
        edge_feat = graph.edge_attr.scatter_data(local_index, input_split=input_split, out_split=output_split)
    else:
        edge_feat = graph._get_edge_attr(dist_eid)
        local_index = None
    if isinstance(graph.x, DistributedTensor):
        local_index, input_split, output_split = graph.x.gather_select_index(dist_nid)
        node_feat = graph.x.scatter_data(local_index, input_split=input_split, out_split=output_split)
    else:
        node_feat = graph._get_node_attr(dist_nid)
    if mailbox is not None:
        mem = mailbox._get_memory(dist_nid)
        if torch.distributed.get_world_size() > 1:
            if node_feat is None:
                local_index, input_split, output_split = mailbox.node_memory.gather_select_index(dist_nid)
                mem = mailbox.gather_memory(local_index, input_split, output_split)
            else:
                mem = mailbox.get_memory(dist_nid)
    else:
        mem = None
@@ -120,28 +136,28 @@ def to_block(graph: GraphData, data, sample_out, mailbox = None,device = torch.d
        return data, mfgs, metadata

    data, mfgs, metadata = build_block()
    if dist.get_world_size() > 1:
        if(node_feat is None):
            node_feat = torch.futures.Future()
            node_feat.set_result(None)
        if(edge_feat is None):
            edge_feat = torch.futures.Future()
            edge_feat.set_result(None)
        if(mem is None):
            mem = torch.futures.Future()
            mem.set_result(None)
        def callback(fs, mfgs, dist_nid, dist_eid):
            node_feat, edge_feat, mem_embedding = fs.value()
            node_feat = node_feat.value()
            edge_feat = edge_feat.value()
            mem_embedding = mem_embedding.value()
            return prepare_input(node_feat, edge_feat, mem_embedding, mfgs, dist_nid, dist_eid)
        cal = lambda fut: callback(fs=fut, mfgs=mfgs, dist_nid=dist_nid, dist_eid=dist_eid)
        return data, torch.futures.collect_all([node_feat, edge_feat, mem]).then(cal), metadata
    else:
        mfgs = prepare_input(node_feat, edge_feat, mem, mfgs, dist_nid, dist_eid)
    # if dist.get_world_size() > 1:
    #     if(node_feat is None):
    #         node_feat = torch.futures.Future()
    #         node_feat.set_result(None)
    #     if(edge_feat is None):
    #         edge_feat = torch.futures.Future()
    #         edge_feat.set_result(None)
    #     if(mem is None):
    #         mem = torch.futures.Future()
    #         mem.set_result(None)
    #     def callback(fs,mfgs,dist_nid,dist_eid):
    #         node_feat,edge_feat,mem_embedding = fs.value()
    #         node_feat = node_feat.value()
    #         edge_feat = edge_feat.value()
    #         mem_embedding = mem_embedding.value()
    #         return prepare_input(node_feat,edge_feat,mem_embedding,mfgs,dist_nid,dist_eid)
    #     cal = lambda fut: callback(fs=fut,mfgs = mfgs,dist_nid = dist_nid,dist_eid =dist_eid)
    #     return data,torch.futures.collect_all([node_feat,edge_feat,mem]).then(cal),metadata
    # else:
    mfgs = prepare_input(node_feat, edge_feat, mem, mfgs, dist_nid, dist_eid)
    #return build_block(node_feat,edge_feat,mem)#data,mfgs,metadata
    return data, mfgs, metadata
    return data, mfgs, metadata

def graph_sample(graph, sampler: BaseSampler,
starrygl/sample/data_loader.py

@@ -66,7 +66,8 @@ class DistributedDataLoader:
        if train is True:
            self._get_expected_idx(self.dataset.len)
        else:
            self.expected_idx = int(math.ceil(self.dataset.len / self.batch_size))
            self._get_expected_idx(self.dataset.len, op=dist.ReduceOp.MAX)
            #self.expected_idx = int(math.ceil(self.dataset.len/self.batch_size))

    def __iter__(self):
        if self.chunk_size is None:
@@ -102,18 +103,19 @@ class DistributedDataLoader:
            self.neg_sampler.set_next_pos(self.current_pos)
        return self

    def _get_expected_idx(self, data_size):
    def _get_expected_idx(self, data_size, op=dist.ReduceOp.MIN):
        world_size = dist.get_world_size()
        self.expected_idx = data_size // self.batch_size if self.drop_last is True else int(math.ceil(data_size / self.batch_size))
        if dist.get_world_size() > 1:
            num_epochs = torch.tensor([self.expected_idx], dtype=torch.long, device=self.device)
            dist.all_reduce(num_epochs, op=dist.ReduceOp.MIN)
            print(num_epochs)
            dist.all_reduce(num_epochs, op=op)
            self.expected_idx = int(num_epochs.item())

    def _next_data(self):
        if self.current_pos >= self.dataset.len:
            return None
            return self.input_dataset._get_empty()
        if self.current_pos + self.batch_size > self.input_dataset.len:
            if self.drop_last:
@@ -132,7 +134,7 @@ class DistributedDataLoader:
        return next_data

    def __next__(self):
        if(dist.get_world_size() == 1):
        if(dist.get_world_size() > 0):
            if self.recv_idxs < self.expected_idx:
                data = self._next_data()
                batch_data = graph_sample(self.graph,
@@ -165,6 +167,12 @@ class DistributedDataLoader:
                            next_data, self.neg_sampler,
                            self.mailbox, self.device)
                batch_data[1].wait()
                self.submitted = self.submitted + 1
                self.num_pending = self.num_pending + 1
                self.recv_idxs += 1
                self.num_pending -= 1
                return batch_data[0], batch_data[1].value(), batch_data[2]
                self.result_queue.append(batch_data)
                self.submitted = self.submitted + 1
                self.num_pending = self.num_pending + 1
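Editor's note: the new op argument means a training loader aligns every rank to the smallest per-rank batch count (ReduceOp.MIN, so no rank waits on a collective the others never enter), while evaluation loaders use ReduceOp.MAX together with the new _get_empty() dataset so ranks that run out of data keep issuing empty batches until the longest rank finishes. Below is a minimal sketch of that agreement step, with illustrative names and assuming an already-initialized default process group.

    import math
    import torch
    import torch.distributed as dist

    def agree_on_num_batches(local_len: int, batch_size: int, op=dist.ReduceOp.MIN) -> int:
        # every rank proposes its own batch count ...
        n = torch.tensor([int(math.ceil(local_len / batch_size))], dtype=torch.long)
        # ... then all ranks reduce to one value (MIN for training, MAX for evaluation)
        if dist.is_initialized() and dist.get_world_size() > 1:
            dist.all_reduce(n, op=op)
        return int(n.item())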
starrygl/sample/graph_core/__init__.py

@@ -23,16 +23,16 @@ class GraphData():
        world_size = dist.get_world_size()
        if hasattr(pdata, 'x') and pdata.x is not None:
            if world_size > 1:
                self.x = DistributedTensor(pdata.x.to(self.device))
                self.x = DistributedTensor(pdata.x.to(self.device).to(torch.float))
            else:
                self.x = pdata.x.to(device).to(torch.float)
        else:
            self.x = None
        if hasattr(pdata, 'edge_attr') and pdata.edge_attr is not None:
            if world_size > 1:
                self.edge_attr = DistributedTensor(pdata.edge_attr.to(self.device))
                self.edge_attr = DistributedTensor(pdata.edge_attr.to('cpu').to(torch.float))
            else:
                self.edge_attr = pdata.edge_attr.to('cuda').to(torch.float)
                self.edge_attr = pdata.edge_attr.to('cpu').to(torch.float)
        else:
            self.edge_attr = None
@@ -43,14 +43,15 @@ class GraphData():
            return self.x[ids]
        else:
            return self.x.index_select(ids)

    def _get_edge_attr(self, ids):
    def _get_edge_attr(self, ids,):
        if self.edge_attr is None:
            return None
        elif dist.get_world_size() == 1:
            return self.edge_attr[ids.to('cpu')].to('cuda')
            return self.edge_attr[ids]
        else:
            return self.edge_attr.index_select(ids)

class DataSet:
    def __init__(self, nodes=None, edges=None,
@@ -69,6 +70,16 @@ class DataSet:
        for k, v in kwargs.items():
            assert isinstance(v, torch.Tensor) and v.shape[0] == self.len
            setattr(self, k, v.to(device))

    def _get_empty(self):
        nodes = torch.empty([], dtype=self.nodes.dtype, device=self.nodes.device) if hasattr(self, 'nodes') else None
        edges = torch.empty([[], []], dtype=self.edges.dtype, device=self.edge.device) if hasattr(self, 'edges') else None
        d = DataSet(nodes, edges)
        for k, v in self.__dict__.items():
            if k == 'edges' or k == 'nodes' or k == 'len':
                continue
            else:
                setattr(d, k, torch.empty([]))
        return d

    #@staticmethod
    def get_next(self, indx):
starrygl/sample/memory/shared_mailbox.py

@@ -214,7 +214,7 @@ class SharedMailBox():
        return index, memory, ts

    def _get_memory(self, index):
    def get_memory(self, index):
        if self.num_parts == 1:
            return self.node_memory.accessor.data[index], \
                self.node_memory_ts.accessor.data[index], \
@@ -234,6 +234,9 @@ class SharedMailBox():
            #print(memory.shape[0])
            return memory, memory_ts, mail, mail_ts
        return torch.futures.collect_all([memory, memory_ts, mail, mail_ts]).then(callback)

    def gather_memory(self, index, input_split, out_split):
        return self.node_memory.scatter_data(index, input_split, out_split), \
            self.node_memory_ts.scatter_data(index, input_split, out_split), \
            self.mailbox.scatter_data(index, input_split, out_split), \
            self.mailbox_ts.scatter_data(index, input_split, out_split)
starrygl/sample/stream_manager.py

@@ -4,4 +4,8 @@ class WorkStreamEvent:
        self.train_stream = torch.cuda.Stream()
        self.write_memory_stream = torch.cuda.Stream()
        self.fetch_stream = torch.cuda.Stream()
        self.write_mail_stream = torch.cuda.Stream()
\ No newline at end of file
        self.write_mail_stream = torch.cuda.Stream()
        self.event = None

event = WorkStreamEvent()

def get_event():
    return event.event
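Editor's note: WorkStreamEvent keeps one CUDA stream per kind of work (training, memory writes, prefetching, mail writes) plus a shared event slot, so those stages can overlap instead of serializing on the default stream. Below is a minimal sketch of the usage pattern this enables, assuming a CUDA device is available; the tensors and stream names are illustrative and not part of starrygl.

    import torch

    compute_stream = torch.cuda.Stream()
    copy_stream = torch.cuda.Stream()

    x = torch.randn(1024, 1024, device='cuda')
    host_buf = torch.empty(1024, 1024, pin_memory=True)

    with torch.cuda.stream(compute_stream):
        y = x @ x                      # compute runs on its own stream
    done = torch.cuda.Event()
    done.record(compute_stream)

    with torch.cuda.stream(copy_stream):
        copy_stream.wait_event(done)   # copy starts only after the matmul finished
        host_buf.copy_(y, non_blocking=True)

    torch.cuda.synchronize()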
train_tgnn.py

@@ -82,8 +82,9 @@ def main():
    ctx = DistributedContext.init(backend="nccl", use_gpu=True)
    device_id = torch.cuda.current_device()
    print('use cuda on', device_id)
    pdata = partition_load("./dataset/here/WIKI", algo="metis_for_tgnn")
    pdata = partition_load("./dataset/here/GDELT", algo="metis_for_tgnn")
    graph = GraphData(pdata=pdata)
    sample_graph = TemporalNeighborSampleGraph(sample_graph=pdata.sample_graph, mode='full')
    mailbox = SharedMailBox(pdata.ids.shape[0], memory_param,
                            dim_edge_feat=pdata.edge_attr.shape[1] if pdata.edge_attr is not None else 0)
    sampler = NeighborSampler(num_nodes=graph.num_nodes, num_layers=1, fanout=[10],
                              graph_data=sample_graph, workers=10, policy='recent', graph_name="wiki_train")
@@ -102,17 +103,17 @@ def main():
    trainloader = DistributedDataLoader(graph, train_data, sampler=sampler,
                                        sampler_fn=SAMPLE_TYPE.SAMPLE_FROM_TEMPORAL_EDGES,
                                        neg_sampler=neg_sampler,
                                        batch_size=2000,
                                        batch_size=1000,
                                        shuffle=False,
                                        drop_last=True,
                                        chunk_size=None,
                                        train=True,
                                        queue_size=100,
                                        queue_size=1000,
                                        mailbox=mailbox)
    testloader = DistributedDataLoader(graph, test_data, sampler=sampler,
                                       sampler_fn=SAMPLE_TYPE.SAMPLE_FROM_TEMPORAL_EDGES,
                                       neg_sampler=neg_sampler,
                                       batch_size=2000,
                                       batch_size=1000,
                                       shuffle=False,
                                       drop_last=False,
                                       chunk_size=None,
@@ -122,7 +123,7 @@ def main():
    valloader = DistributedDataLoader(graph, val_data, sampler=sampler,
                                      sampler_fn=SAMPLE_TYPE.SAMPLE_FROM_TEMPORAL_EDGES,
                                      neg_sampler=neg_sampler,
                                      batch_size=2000,
                                      batch_size=1000,
                                      shuffle=False,
                                      drop_last=False,
                                      chunk_size=None,
@@ -158,9 +159,9 @@ def main():
        with torch.no_grad():
            total_loss = 0
            signal = torch.tensor([0], dtype=int, device=device)
            for roots, mfgs, metadata in loader:
                signal[0] = 0
                dist.all_reduce(signal, async_op=False)
                pred_pos, pred_neg = model(mfgs, metadata)
                total_loss += creterion(pred_pos, torch.ones_like(pred_pos))
                total_loss += creterion(pred_neg, torch.zeros_like(pred_neg))
@@ -190,32 +191,28 @@ def main():
                        src, dst, ts, edge_feats,
                        model.module.memory_updater.last_updated_memory,
                        )
                #mailbox.set_mailbox_all_to_all(index,memory,memory_ts,mail,mail_ts,reduce_Op = 'max')
        if mode == 'val':
            val_losses.append(float(total_loss))
        while(signal[0].item() != dist.get_world_size()):
            signal[0] = 1
            dist.all_reduce(signal, async_op=False)
            if(signal[0].item() == dist.get_world_size()):
                break
            if mailbox is not None:
                mailbox.set_mailbox_all_to_all(
                    torch.tensor([], device=device).reshape(-1),
                    torch.tensor([], device=device).reshape(-1, mailbox.memory_size),
                    torch.tensor([], device=device).reshape(-1),
                    torch.tensor([], device=device).reshape(-1, mailbox.mailbox.accessor.data.size(2)),
                    torch.tensor([], device=device).reshape(-1),
                    reduce_Op='max')
                mailbox.set_mailbox_all_to_all(index, memory, memory_ts, mail, mail_ts, reduce_Op='max')
        ap = float(torch.tensor(aps).mean())
        if neg_samples > 1:
            auc_mrr = float(torch.cat(aucs_mrrs).mean())
        else:
            auc_mrr = float(torch.tensor(aucs_mrrs).mean())
        #ap = float(torch.tensor(aps).mean())
        #if neg_samples > 1:
        #    auc_mrr = float(torch.cat(aucs_mrrs).mean())
        #else:
        #    auc_mrr = float(torch.tensor(aucs_mrrs).mean())
        world_size = dist.get_world_size()
        apc = torch.empty([loader.expected_idx*world_size], dtype=torch.float, device='cuda')
        auc_mrr = torch.empty([loader.expected_idx*world_size], dtype=torch.float, device='cuda')
        dist.all_gather_into_tensor(apc, torch.tensor(aps, device='cuda', dtype=torch.float))
        dist.all_gather_into_tensor(auc_mrr, torch.tensor(aucs_mrrs, device='cuda', dtype=torch.float))
        ap = float(torch.tensor(apc).mean())
        auc_mrr = float(torch.tensor(auc_mrr).mean())
        return ap, auc_mrr

    creterion = torch.nn.BCEWithLogitsLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=train_param['lr'])
    for e in range(train_param['epoch']):
        torch.cuda.synchronize()
        epoch_start_time = time.time()
        train_aps = list()
        print('Epoch {:d}:'.format(e))
@@ -229,14 +226,18 @@ def main():
        model.module.memory_updater.last_updated_ts = None
        for roots, mfgs, metadata in trainloader:
            t_prep_s = time.time()
            optimizer.zero_grad()
            with torch.cuda.stream(train_stream):
                optimizer.zero_grad()
                pred_pos, pred_neg = model(mfgs, metadata)
                loss = creterion(pred_pos, torch.ones_like(pred_pos))
                loss += creterion(pred_neg, torch.zeros_like(pred_neg))
                total_loss += float(loss)
                loss.backward()
                optimizer.step()
                #torch.cuda.synchronize()
                t_prep_s = time.time()
                y_pred = torch.cat([pred_pos, pred_neg], dim=0).sigmoid().cpu()
                y_true = torch.cat([torch.ones(pred_pos.size(0)), torch.zeros(pred_neg.size(0))], dim=0)
@@ -262,7 +263,9 @@ def main():
                        src, dst, ts, edge_feats,
                        model.module.memory_updater.last_updated_memory,
                        )
                #mailbox.set_mailbox_all_to_all(index,memory,memory_ts,mail,mail_ts,reduce_Op = 'max')
                mailbox.set_mailbox_all_to_all(index, memory, memory_ts, mail, mail_ts, reduce_Op='max')
        torch.cuda.synchronize()
        time_prep = time.time() - epoch_start_time
        avg_time += time.time() - epoch_start_time
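Editor's note: the new signal tensor in the evaluation loop appears to implement a barrier-free termination handshake for ranks with uneven batch counts: ranks still iterating contribute 0 to each all_reduce, finished ranks contribute 1 and keep feeding empty tensors into set_mailbox_all_to_all so the collectives stay matched, and everyone exits once the reduced value equals the world size. Below is a minimal sketch of that handshake in isolation, with illustrative names (process is a hypothetical per-batch function) and assuming an initialized process group.

    import torch
    import torch.distributed as dist

    def run_uneven_loop(my_batches, process):
        """Each rank may have a different number of batches; collectives must stay matched."""
        world = dist.get_world_size()
        signal = torch.tensor([0], dtype=torch.long)
        for batch in my_batches:
            signal[0] = 0
            dist.all_reduce(signal)      # active ranks keep the sum below `world`
            process(batch)               # hypothetical work containing its own collectives
        while True:
            signal[0] = 1
            dist.all_reduce(signal)      # finished ranks report 1
            if int(signal.item()) == world:
                break                    # every rank is done
            process(None)                # issue the same collectives with empty payloads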