zhlj / starrygl-DynamicHistory · Commits

Commit a70d1b3b, authored Dec 19, 2023 by zlj
    add sample module
Parent: abce8f53

Showing 11 changed files with 290 additions and 85 deletions (+290 / -85):
    a.py                                       +0    -20
    b.py                                       +133  -0
    data_maker.py                              +2    -0
    nohup.out                                  +14   -0
    starrygl/distributed/utils.py              +32   -1
    starrygl/sample/batch_data.py              +38   -22
    starrygl/sample/data_loader.py             +13   -5
    starrygl/sample/graph_core/__init__.py     +16   -5
    starrygl/sample/memory/shared_mailbox.py   +6    -3
    starrygl/sample/stream_manager.py          +4    -0
    train_tgnn.py                              +32   -29
a.py  (deleted, 100644 → 0, +0 -20)

def foo():
    a = 1
    def fa():
        print(a)
        a += 1
        print(a)
    def fb(a):
        def apply():
            print(a)
        return apply
    fc = lambda: print(a)
    return fa, fb(a), fc

fa, fb, fc = foo()
fa()
fb()
fc()
b.py  (new file, +133 -0)

import argparse
import os
import sys
from os.path import abspath, join, dirname
from starrygl.distributed.context import DistributedContext
from starrygl.distributed.utils import DistIndex
from starrygl.module.modules import GeneralModel
from starrygl.module.utils import parse_config
from starrygl.sample.graph_core import DataSet, GraphData, TemporalNeighborSampleGraph
from starrygl.sample.memory.shared_mailbox import SharedMailBox
from starrygl.sample.sample_core.base import NegativeSampling
from starrygl.sample.sample_core.neighbor_sampler import NeighborSampler
from starrygl.sample.part_utils.partition_tgnn import partition_load
import torch
import time
import torch
import torch.nn.functional as F
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import init_process_group, destroy_process_group
import os
from starrygl.sample.data_loader import DistributedDataLoader
from starrygl.sample.batch_data import SAMPLE_TYPE

"""
test command
python test.py --world_size 2 --rank 0
--world_size', default=4, type=int, metavar='W',
                    help='number of workers')
parser.add_argument('--rank', default=0, type=int, metavar='W',
                    help='rank of the worker')
parser.add_argument('--log_interval', type=int, default=10, metavar='N',
                    help='interval between training status logs')
parser.add_argument('--gamma', type=float, default=0.99, metavar='G',
                    help='how much to value future rewards')
parser.add_argument('--seed', type=int, default=1, metavar='S',
                    help='random seed for reproducibility')
parser.add_argument('--num_sampler', type=int, default=10, metavar='S',
                    help='number of samplers')
parser.add_argument('--queue_size', type=int, default=10, metavar='S',
                    help='sampler queue size')
"""
parser = argparse.ArgumentParser(
    description="RPC Reinforcement Learning Example",
    formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument('--rank', default=0, type=str, metavar='W',
                    help='name of dataset')
parser.add_argument('--world_size', default=1, type=int, metavar='W',
                    help='number of negative samples')
args = parser.parse_args()

from sklearn.metrics import average_precision_score, roc_auc_score
import torch
import time
import random
import dgl
import numpy as np
from sklearn.metrics import average_precision_score, roc_auc_score
from torch.nn.parallel import DistributedDataParallel as DDP

os.environ['CUDA_VISIBLE_DEVICES'] = str(args.rank)
os.environ["RANK"] = str(args.rank)
os.environ["WORLD_SIZE"] = str(args.world_size)
os.environ["LOCAL_RANK"] = str(0)
os.environ["MASTER_ADDR"] = '127.0.0.1'
os.environ["MASTER_PORT"] = '9337'

def seed_everything(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

seed_everything(1234)

def main():
    use_cuda = True
    sample_param, memory_param, gnn_param, train_param = parse_config('./config/TGN.yml')
    torch.set_num_threads(12)
    ctx = DistributedContext.init(backend="nccl", use_gpu=True)
    device_id = torch.cuda.current_device()
    print('use cuda on', device_id)
    pdata = partition_load("./dataset/here/WIKI", algo="metis_for_tgnn")
    graph = GraphData(pdata=pdata)
    #dist.barrier()
    #for i in range(100):
    #    print(i)
    dist.barrier()
    idx = ((graph.eids_mapper >> 48).int() & 0xFFFF)
    print((idx == 0).nonzero().shape, (idx == 1).nonzero().shape)
    t1 = time.time()
    """
    fut = []
    for i in range(1000):
        #print(i)
        out = graph.edge_attr.index_select(graph.eids_mapper[(idx== 0)|(idx ==1)].to('cuda'))
        fut.append(out)
        #out.wait()
        #out.value()
        if i>0 and i %100 ==0:
            f = torch.futures.collect_all(fut)
            f.wait()
            f.value()
            fut = []
    """
    partptr = torch.tensor([((i & 0xFFFF) << 48) for i in range(3)], device='cuda')
    for i in range(1000):
        if i % 100 == 0:
            idx = graph.eids_mapper.to('cuda')
        idx, inv = idx.unique(return_inverse=True)
        ind = torch.searchsorted(idx, partptr, right=False)
        len = ind[1:] - ind[:-1]
        gatherlen = torch.empty([2], dtype=torch.long, device='cuda')
        dist.all_to_all_single(gatherlen, len)
        query_idx = torch.empty([gatherlen.sum()], dtype=torch.long, device='cuda')
        input_s = list(len)
        output_s = list(gatherlen)
        dist.all_to_all_single(query_idx, idx, output_s, input_s)
        input_f = graph.edge_attr.accessor.data[DistIndex(query_idx).loc]
        f = torch.empty([idx.shape[0], graph.edge_attr.accessor.data.shape[1]],
                        dtype=torch.float, device='cuda')
        dist.all_to_all_single(f, input_f, input_s, output_s)
    torch.cuda.synchronize()
    t2 = time.time() - t1
    print(t2)
    #dist.barrier()
    ctx.shutdown()

if __name__ == "__main__":
    main()
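The benchmark above leans on the packed ID scheme used throughout this commit: a 64-bit global ID stores the owning partition in its high 16 bits and the local row index in its low 48 bits, so `>> 48` recovers ownership and a `searchsorted` against the partition boundaries yields the per-rank split sizes fed to `all_to_all_single`. Below is a minimal, self-contained sketch of that layout; the helper names `pack` and `split_sizes` are illustrative and not part of the repository.

    import torch

    def pack(part_id: torch.Tensor, local_idx: torch.Tensor) -> torch.Tensor:
        # assumes local_idx < 2**48 and part_id < 2**16
        return (part_id.long() & 0xFFFF) << 48 | local_idx.long()

    def split_sizes(sorted_global_ids: torch.Tensor, num_parts: int) -> torch.Tensor:
        # one boundary per partition plus a sentinel past the last partition
        partptr = torch.tensor([(p & 0xFFFF) << 48 for p in range(num_parts + 1)])
        ind = torch.searchsorted(sorted_global_ids, partptr, right=False)
        return ind[1:] - ind[:-1]   # how many requested IDs live on each rank

    ids = pack(torch.tensor([0, 0, 1, 1, 1]), torch.tensor([3, 7, 0, 2, 5]))
    ids, _ = ids.sort()
    print(split_sizes(ids, num_parts=2))  # tensor([2, 3])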
data_maker.py  (+2 -0)

@@ -118,6 +118,8 @@ partition_save('./dataset/here/'+data_name, data, 1, 'metis_for_tgnn',
                 edge_weight_dict=edge_weight_dict)
 partition_save('./dataset/here/'+data_name, data, 2, 'metis_for_tgnn',
                 edge_weight_dict=edge_weight_dict)
+partition_save('./dataset/here/'+data_name, data, 4, 'metis_for_tgnn',
+                edge_weight_dict=edge_weight_dict)
 #
 # partition_save('./dataset/here/'+data_name, data, 4, 'metis_for_tgnn',
 #                 edge_weight_dict=edge_weight_dict )
nohup.out  (new file, +14 -0)
ERROR:root:unable to import libstarrygl.so, some features may not be available.
the number of nodes in graph is 1980, the number of edges in graph is 1293103
directory '/home/zlj/starrygl/dataset/here/LASTFM/metis_for_tgnn_1' not empty and cleared
running partition algorithm: metis_for_tgnn
saving partition data: 1/1
running partition algorithm: metis_for_tgnn
saving partition data: 1/2
saving partition data: 2/2
creating directory '/home/zlj/starrygl/dataset/here/LASTFM/metis_for_tgnn_4'
running partition algorithm: metis_for_tgnn
saving partition data: 1/4
saving partition data: 2/4
saving partition data: 3/4
saving partition data: 4/4
starrygl/distributed/utils.py  (+32 -1)

@@ -15,6 +15,7 @@ class TensorAccessor:
         self._data = data
         self._ctx = DistributedContext.get_default_context()
         self._rref = rpc.RRef(data)
+        self._rref.confirmed_by_owner
     @property
     def data(self):
@@ -28,6 +29,20 @@ class TensorAccessor:
     def ctx(self):
         return self._ctx
+    def all_gather_index(self, index, input_split) -> Tensor:
+        out_split = torch.empty_like(input_split)
+        torch.distributed.all_to_all_single(out_split, input_split)
+        input_split = list(input_split)
+        output = torch.empty([out_split.sum()], dtype=index.dtype, device=index.device)
+        out_split = list(out_split)
+        torch.distributed.all_to_all_single(output, index, out_split, input_split)
+        return output, out_split, input_split
+
+    def all_gather_data(self, index, input_split, out_split):
+        output = torch.empty([int(Tensor(out_split).sum().item()), *self._data.shape[1:]],
+                             dtype=self._data.dtype, device='cuda')  #self._data.device)
+        torch.distributed.all_to_all_single(output, self.data[index.to(self.data.device)].to('cuda'),
+                                            output_split_sizes=out_split, input_split_sizes=input_split)
+        return output
     def all_gather_rrefs(self) -> List[rpc.RRef]:
         return self.ctx.all_gather_remote_objects(self.rref)
@@ -101,7 +116,6 @@ class DistributedTensor:
     def __init__(self, data: Tensor) -> None:
         self.accessor = TensorAccessor(data)
         self.rrefs = self.accessor.all_gather_rrefs()
         # self.num_parts = len(self.rrefs)
         local_sizes = []
@@ -110,6 +124,7 @@ class DistributedTensor:
             local_sizes.append(n)
         self.num_nodes = DistInt(local_sizes)
         self.num_parts = DistInt([1] * len(self.rrefs))
+        self.distptr = torch.tensor([((part_ids & 0xFFFF) << 48) for part_ids in range(self.num_parts() + 1)], device='cuda')  #data.device)
     @property
@@ -130,6 +145,17 @@ class DistributedTensor:
     def ctx(self):
         return self.accessor.ctx
+    def gather_select_index(self, dist_index: Union[Tensor, DistIndex]):
+        if isinstance(dist_index, Tensor):
+            dist_index = DistIndex(dist_index)
+        data = dist_index.data
+        posptr = torch.searchsorted(data, self.distptr, right=False)
+        input_split = posptr[1:] - posptr[:-1]
+        return self.accessor.all_gather_index(DistIndex(data).loc, input_split)
+
+    def scatter_data(self, local_index, input_split, out_split):
+        return self.accessor.all_gather_data(local_index, input_split=input_split, out_split=out_split)
+
     def index_select(self, dist_index: Union[Tensor, DistIndex]):
         if isinstance(dist_index, Tensor):
             dist_index = DistIndex(dist_index)
@@ -139,6 +165,10 @@ class DistributedTensor:
         futs: List[torch.futures.Future] = []
         for i in range(self.num_parts()):
+            #if i != torch.distributed.get_rank():
+            #    continue
+            #f = torch.futures.Future()
+            #f.set_result(self.accessor.data[index[part_idx == i]])
             f = self.accessor.async_index_select(0, index[part_idx == i], self.rrefs[i])
             futs.append(f)
@@ -150,6 +180,7 @@ class DistributedTensor:
             result = torch.empty(
                 part_idx.size(0), *t.shape[1:],
                 dtype=t.dtype, device=t.device,
             )
             result[part_idx == i] = t
             return result
         return torch.futures.collect_all(futs).then(callback)
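Together, all_gather_index and all_gather_data implement a two-round exchange: round one trades split sizes and the requested local indices, round two returns the selected feature rows along the reversed splits. The sketch below simulates both rounds for two ranks inside one process; the plain list indexing stands in for what torch.distributed.all_to_all_single does across processes, and none of these names exist in the repository.

    import torch

    WORLD = 2
    feats = [torch.arange(4.).unsqueeze(1),            # rows owned by rank 0
             torch.arange(4.).unsqueeze(1) + 100]      # rows owned by rank 1

    # local indices each requester wants, grouped by owning rank
    wants = [[torch.tensor([0, 3]), torch.tensor([2])],     # rank 0: rows 0,3 of rank 0 and row 2 of rank 1
             [torch.tensor([1]),    torch.tensor([1, 3])]]  # rank 1: row 1 of rank 0 and rows 1,3 of rank 1

    # round 1 (all_gather_index): each owner receives the index lists addressed to it
    requests = [[wants[src][dst] for src in range(WORLD)] for dst in range(WORLD)]

    # round 2 (all_gather_data): owners look the rows up and send them back along the reversed splits
    received = [
        torch.cat([feats[owner][requests[owner][rank]] for owner in range(WORLD)])
        for rank in range(WORLD)
    ]
    print(received[0])   # rows 0,3 of rank 0 followed by row 2 of rank 1
    print(received[1])   # row 1 of rank 0 followed by rows 1,3 of rank 1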
starrygl/sample/batch_data.py  (+38 -22)

 from typing import List, Tuple
 import torch
 import torch.distributed as dist
+from starrygl.distributed.utils import DistributedTensor
+from starrygl.module.memorys import MailBox
 from starrygl.sample.graph_core import DataSet
 from starrygl.sample.graph_core import GraphData
 from starrygl.sample.sample_core.base import BaseSampler, NegativeSampling
@@ -40,7 +42,7 @@ def prepare_input(node_feat, edge_feat, mem_embedding, mfgs, dist_nid, dist_eid):
     #print(idx.shape[0],b.srcdata['mem_ts'].shape)
     return mfgs
-def to_block(graph: GraphData, data, sample_out, mailbox = None, device = torch.device('cuda')):
+def to_block(graph: GraphData, data, sample_out, mailbox: MailBox = None, device = torch.device('cuda')):
     if len(sample_out) > 1:
         sample_out, metadata = sample_out
@@ -55,7 +57,6 @@ def to_block(graph: GraphData, data, sample_out, mailbox = None, device = torch.d
     dist_eid, eid_inv = dist_eid.unique(return_inverse=True)
     src_node = graph.sample_graph['edge_index'][0, eid_tensor*2].to(graph.nids_mapper.device)
     src_ts = None
-    edge_feat = graph._get_edge_attr(dist_eid)
     if metadata is None:
         root_node = data.nodes.to(graph.nids_mapper.device)
         root_len = [root_node.shape[0]]
@@ -74,9 +75,24 @@ def to_block(graph: GraphData, data, sample_out, mailbox = None, device = torch.d
     nid_tensor = torch.cat([root_node, src_node], dim=0)
     dist_nid = nid_mapper[nid_tensor].to(device)
     dist_nid, nid_inv = dist_nid.unique(return_inverse=True)
+    if isinstance(graph.edge_attr, DistributedTensor):
+        local_index, input_split, output_split = graph.edge_attr.gather_select_index(dist_eid)
+        edge_feat = graph.edge_attr.scatter_data(local_index, input_split=input_split, out_split=output_split)
+    else:
+        edge_feat = graph._get_edge_attr(dist_eid)
+        local_index = None
+    if isinstance(graph.x, DistributedTensor):
+        local_index, input_split, output_split = graph.x.gather_select_index(dist_nid)
+        node_feat = graph.x.scatter_data(local_index, input_split=input_split, out_split=output_split)
+    else:
+        node_feat = graph._get_node_attr(dist_nid)
     if mailbox is not None:
-        mem = mailbox._get_memory(dist_nid)
+        if torch.distributed.get_world_size() > 1:
+            if node_feat is None:
+                local_index, input_split, output_split = mailbox.node_memory.gather_select_index(dist_nid)
+            mem = mailbox.gather_memory(local_index, input_split, output_split)
+        else:
+            mem = mailbox.get_memory(dist_nid)
     else:
         mem = None
@@ -120,25 +136,25 @@ def to_block(graph: GraphData, data, sample_out, mailbox = None, device = torch.d
         return data, mfgs, metadata
     data, mfgs, metadata = build_block()
-    if dist.get_world_size() > 1:
-        if(node_feat is None):
-            node_feat = torch.futures.Future()
-            node_feat.set_result(None)
-        if(edge_feat is None):
-            edge_feat = torch.futures.Future()
-            edge_feat.set_result(None)
-        if(mem is None):
-            mem = torch.futures.Future()
-            mem.set_result(None)
-        def callback(fs, mfgs, dist_nid, dist_eid):
-            node_feat, edge_feat, mem_embedding = fs.value()
-            node_feat = node_feat.value()
-            edge_feat = edge_feat.value()
-            mem_embedding = mem_embedding.value()
-            return prepare_input(node_feat, edge_feat, mem_embedding, mfgs, dist_nid, dist_eid)
-        cal = lambda fut: callback(fs=fut, mfgs=mfgs, dist_nid=dist_nid, dist_eid=dist_eid)
-        return data, torch.futures.collect_all([node_feat, edge_feat, mem]).then(cal), metadata
-    else:
+    # if dist.get_world_size() > 1:
+    #     if(node_feat is None):
+    #         node_feat = torch.futures.Future()
+    #         node_feat.set_result(None)
+    #     if(edge_feat is None):
+    #         edge_feat = torch.futures.Future()
+    #         edge_feat.set_result(None)
+    #     if(mem is None):
+    #         mem = torch.futures.Future()
+    #         mem.set_result(None)
+    #     def callback(fs, mfgs, dist_nid, dist_eid):
+    #         node_feat, edge_feat, mem_embedding = fs.value()
+    #         node_feat = node_feat.value()
+    #         edge_feat = edge_feat.value()
+    #         mem_embedding = mem_embedding.value()
+    #         return prepare_input(node_feat, edge_feat, mem_embedding, mfgs, dist_nid, dist_eid)
+    #     cal = lambda fut: callback(fs=fut, mfgs=mfgs, dist_nid=dist_nid, dist_eid=dist_eid)
+    #     return data, torch.futures.collect_all([node_feat, edge_feat, mem]).then(cal), metadata
+    # else:
    mfgs = prepare_input(node_feat, edge_feat, mem, mfgs, dist_nid, dist_eid)
    #return build_block(node_feat,edge_feat,mem)#data,mfgs,metadata
    return data, mfgs, metadata
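With this change, to_block() fetches node features, edge features and memory through the synchronous DistributedTensor path (gather_select_index followed by scatter_data) and the earlier torch.futures branch is left commented out. For reference, the future composition that the commented branch relies on looks like the following self-contained sketch; the values here are dummies standing in for what asynchronous fetches would eventually produce.

    import torch

    node_feat, edge_feat, mem = torch.ones(2), None, torch.zeros(2)

    futs = []
    for value in (node_feat, edge_feat, mem):
        f = torch.futures.Future()
        f.set_result(value)          # a real asynchronous fetch would set this later
        futs.append(f)

    def assemble(all_done):
        nf, ef, m = (f.value() for f in all_done.value())
        return {"node_feat": nf, "edge_feat": ef, "mem": m}

    batch = torch.futures.collect_all(futs).then(assemble)
    print(batch.wait())              # {'node_feat': ..., 'edge_feat': None, 'mem': ...}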
starrygl/sample/data_loader.py  (+13 -5)

@@ -66,7 +66,8 @@ class DistributedDataLoader:
         if train is True:
             self._get_expected_idx(self.dataset.len)
         else:
-            self.expected_idx = int(math.ceil(self.dataset.len/self.batch_size))
+            self._get_expected_idx(self.dataset.len, op=dist.ReduceOp.MAX)
+            #self.expected_idx = int(math.ceil(self.dataset.len/self.batch_size))
     def __iter__(self):
         if self.chunk_size is None:
@@ -102,18 +103,19 @@ class DistributedDataLoader:
             self.neg_sampler.set_next_pos(self.current_pos)
         return self
-    def _get_expected_idx(self, data_size):
+    def _get_expected_idx(self, data_size, op=dist.ReduceOp.MIN):
         world_size = dist.get_world_size()
         self.expected_idx = data_size // self.batch_size if self.drop_last is True else int(math.ceil(data_size/self.batch_size))
         if dist.get_world_size() > 1:
             num_epochs = torch.tensor([self.expected_idx], dtype=torch.long, device=self.device)
-            dist.all_reduce(num_epochs, op=dist.ReduceOp.MIN)
+            print(num_epochs)
+            dist.all_reduce(num_epochs, op=op)
             self.expected_idx = int(num_epochs.item())
     def _next_data(self):
         if self.current_pos >= self.dataset.len:
-            return None
+            return self.input_dataset._get_empty()
         if self.current_pos + self.batch_size > self.input_dataset.len:
             if self.drop_last:
@@ -132,7 +134,7 @@ class DistributedDataLoader:
         return next_data
     def __next__(self):
-        if (dist.get_world_size() == 1):
+        if (dist.get_world_size() > 0):
             if self.recv_idxs < self.expected_idx:
                 data = self._next_data()
                 batch_data = graph_sample(self.graph,
@@ -165,6 +167,12 @@ class DistributedDataLoader:
                                           next_data, self.neg_sampler,
                                           self.mailbox,
                                           self.device)
+                batch_data[1].wait()
+                self.submitted = self.submitted + 1
+                self.num_pending = self.num_pending + 1
+                self.recv_idxs += 1
+                self.num_pending -= 1
+                return batch_data[0], batch_data[1].value(), batch_data[2]
                 self.result_queue.append(batch_data)
                 self.submitted = self.submitted + 1
                 self.num_pending = self.num_pending + 1
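_get_expected_idx() now takes the reduce op as a parameter: training keeps ReduceOp.MIN, while evaluation passes ReduceOp.MAX and _next_data() hands out an empty batch once a rank runs out of data, so every rank executes the same number of iterations and the collective mailbox synchronization stays aligned. A small illustration of why the op matters, with made-up partition sizes and no process group required:

    import math

    batch_size = 1000
    local_lens = [4321, 5678]                      # hypothetical per-rank dataset sizes
    per_rank = [math.ceil(n / batch_size) for n in local_lens]
    print(per_rank)                                # [5, 6]
    print("train (MIN):", min(per_rank))           # every rank runs 5 batches
    print("eval  (MAX):", max(per_rank))           # every rank runs 6; short ranks emit empty batches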
starrygl/sample/graph_core/__init__.py  (+16 -5)

@@ -23,16 +23,16 @@ class GraphData():
         world_size = dist.get_world_size()
         if hasattr(pdata, 'x') and pdata.x is not None:
             if world_size > 1:
-                self.x = DistributedTensor(pdata.x.to(self.device))
+                self.x = DistributedTensor(pdata.x.to(self.device).to(torch.float))
             else:
                 self.x = pdata.x.to(device).to(torch.float)
         else:
             self.x = None
         if hasattr(pdata, 'edge_attr') and pdata.edge_attr is not None:
             if world_size > 1:
-                self.edge_attr = DistributedTensor(pdata.edge_attr.to(self.device))
+                self.edge_attr = DistributedTensor(pdata.edge_attr.to('cpu').to(torch.float))
             else:
-                self.edge_attr = pdata.edge_attr.to('cuda').to(torch.float)
+                self.edge_attr = pdata.edge_attr.to('cpu').to(torch.float)
         else:
             self.edge_attr = None
@@ -43,14 +43,15 @@ class GraphData():
             return self.x[ids]
         else:
             return self.x.index_select(ids)
-    def _get_edge_attr(self, ids):
+    def _get_edge_attr(self, ids,):
         if self.edge_attr is None:
             return None
         elif dist.get_world_size() == 1:
-            return self.edge_attr[ids.to('cpu')].to('cuda')
+            return self.edge_attr[ids]
         else:
             return self.edge_attr.index_select(ids)
 class DataSet:
     def __init__(self, nodes=None,
                  edges=None,
@@ -69,6 +70,16 @@ class DataSet:
         for k, v in kwargs.items():
             assert isinstance(v, torch.Tensor) and v.shape[0] == self.len
             setattr(self, k, v.to(device))
+    def _get_empty(self):
+        nodes = torch.empty([], dtype=self.nodes.dtype, device=self.nodes.device) if hasattr(self, 'nodes') else None
+        edges = torch.empty([[], []], dtype=self.edges.dtype, device=self.edge.device) if hasattr(self, 'edges') else None
+        d = DataSet(nodes, edges)
+        for k, v in self.__dict__.items():
+            if k == 'edges' or k == 'nodes' or k == 'len':
+                continue
+            else:
+                setattr(d, k, torch.empty([]))
+        return d
     #@staticmethod
     def get_next(self, indx):
starrygl/sample/memory/shared_mailbox.py  (+6 -3)

@@ -214,7 +214,7 @@ class SharedMailBox():
         return index, memory, ts
-    def _get_memory(self, index):
+    def get_memory(self, index):
         if self.num_parts == 1:
             return self.node_memory.accessor.data[index], \
                 self.node_memory_ts.accessor.data[index], \
@@ -234,6 +234,9 @@ class SharedMailBox():
             #print(memory.shape[0])
             return memory, memory_ts, mail, mail_ts
         return torch.futures.collect_all([memory, memory_ts, mail, mail_ts]).then(callback)
+    def gather_memory(self, index, input_split, out_split):
+        return self.node_memory.scatter_data(index, input_split, out_split), \
+            self.node_memory_ts.scatter_data(index, input_split, out_split), \
+            self.mailbox.scatter_data(index, input_split, out_split), \
+            self.mailbox_ts.scatter_data(index, input_split, out_split)
starrygl/sample/stream_manager.py  (+4 -0)

@@ -5,3 +5,7 @@ class WorkStreamEvent:
         self.write_memory_stream = torch.cuda.Stream()
         self.fetch_stream = torch.cuda.Stream()
         self.write_mail_stream = torch.cuda.Stream()
+        self.event = None
+event = WorkStreamEvent()
+def get_event():
+    return event.event
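The module now exposes a process-wide WorkStreamEvent instance plus get_event(), so other components can coordinate work issued on the side streams. A hedged sketch of the underlying stream/event pattern, assuming a CUDA device (the tensor and variable names below are illustrative, not part of the repository):

    import torch

    if torch.cuda.is_available():
        side = torch.cuda.Stream()
        x = torch.randn(1024, 1024, device='cuda')
        with torch.cuda.stream(side):
            y = x @ x                               # runs on the side stream
            done = torch.cuda.Event()
            done.record(side)                       # mark completion of the side-stream work
        torch.cuda.current_stream().wait_event(done)  # default stream waits before reading y
        print(y.sum().item())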
train_tgnn.py  (+32 -29)

@@ -82,8 +82,9 @@ def main():
     ctx = DistributedContext.init(backend="nccl", use_gpu=True)
     device_id = torch.cuda.current_device()
     print('use cuda on', device_id)
-    pdata = partition_load("./dataset/here/WIKI", algo="metis_for_tgnn")
+    pdata = partition_load("./dataset/here/GDELT", algo="metis_for_tgnn")
     graph = GraphData(pdata=pdata)
     sample_graph = TemporalNeighborSampleGraph(sample_graph=pdata.sample_graph, mode='full')
     mailbox = SharedMailBox(pdata.ids.shape[0], memory_param, dim_edge_feat=pdata.edge_attr.shape[1] if pdata.edge_attr is not None else 0)
     sampler = NeighborSampler(num_nodes=graph.num_nodes, num_layers=1, fanout=[10], graph_data=sample_graph, workers=10, policy='recent', graph_name="wiki_train")
@@ -102,17 +103,17 @@ def main():
     trainloader = DistributedDataLoader(graph, train_data, sampler=sampler,
                                         sampler_fn=SAMPLE_TYPE.SAMPLE_FROM_TEMPORAL_EDGES,
                                         neg_sampler=neg_sampler,
-                                        batch_size=2000,
+                                        batch_size=1000,
                                         shuffle=False,
                                         drop_last=True,
                                         chunk_size=None,
                                         train=True,
-                                        queue_size=100,
+                                        queue_size=1000,
                                         mailbox=mailbox)
     testloader = DistributedDataLoader(graph, test_data, sampler=sampler,
                                        sampler_fn=SAMPLE_TYPE.SAMPLE_FROM_TEMPORAL_EDGES,
                                        neg_sampler=neg_sampler,
-                                       batch_size=2000,
+                                       batch_size=1000,
                                        shuffle=False,
                                        drop_last=False,
                                        chunk_size=None,
@@ -122,7 +123,7 @@ def main():
     valloader = DistributedDataLoader(graph, val_data, sampler=sampler,
                                       sampler_fn=SAMPLE_TYPE.SAMPLE_FROM_TEMPORAL_EDGES,
                                       neg_sampler=neg_sampler,
-                                      batch_size=2000,
+                                      batch_size=1000,
                                       shuffle=False,
                                       drop_last=False,
                                       chunk_size=None,
@@ -158,9 +159,9 @@ def main():
         with torch.no_grad():
             total_loss = 0
             signal = torch.tensor([0], dtype=int, device=device)
             for roots, mfgs, metadata in loader:
-                signal[0] = 0
-                dist.all_reduce(signal, async_op=False)
                 pred_pos, pred_neg = model(mfgs, metadata)
                 total_loss += creterion(pred_pos, torch.ones_like(pred_pos))
                 total_loss += creterion(pred_neg, torch.zeros_like(pred_neg))
@@ -190,32 +191,28 @@ def main():
                     src, dst, ts, edge_feats,
                     model.module.memory_updater.last_updated_memory,
                 )
-                #mailbox.set_mailbox_all_to_all(index,memory,memory_ts,mail,mail_ts,reduce_Op = 'max')
+                mailbox.set_mailbox_all_to_all(index, memory, memory_ts, mail, mail_ts, reduce_Op='max')
-            if mode == 'val':
-                val_losses.append(float(total_loss))
-            while (signal[0].item() != dist.get_world_size()):
-                signal[0] = 1
-                dist.all_reduce(signal, async_op=False)
-                if (signal[0].item() == dist.get_world_size()):
-                    break
-                if mailbox is not None:
-                    mailbox.set_mailbox_all_to_all(torch.tensor([], device=device).reshape(-1),
-                                                   torch.tensor([], device=device).reshape(-1, mailbox.memory_size),
-                                                   torch.tensor([], device=device).reshape(-1),
-                                                   torch.tensor([], device=device).reshape(-1, mailbox.mailbox.accessor.data.size(2)),
-                                                   torch.tensor([], device=device).reshape(-1),
-                                                   reduce_Op='max')
-            ap = float(torch.tensor(aps).mean())
-            if neg_samples > 1:
-                auc_mrr = float(torch.cat(aucs_mrrs).mean())
-            else:
-                auc_mrr = float(torch.tensor(aucs_mrrs).mean())
+            #ap = float(torch.tensor(aps).mean())
+            #if neg_samples > 1:
+            #    auc_mrr = float(torch.cat(aucs_mrrs).mean())
+            #else:
+            #    auc_mrr = float(torch.tensor(aucs_mrrs).mean())
+            world_size = dist.get_world_size()
+            apc = torch.empty([loader.expected_idx*world_size], dtype=torch.float, device='cuda')
+            auc_mrr = torch.empty([loader.expected_idx*world_size], dtype=torch.float, device='cuda')
+            dist.all_gather_into_tensor(apc, torch.tensor(aps, device='cuda', dtype=torch.float))
+            dist.all_gather_into_tensor(auc_mrr, torch.tensor(aucs_mrrs, device='cuda', dtype=torch.float))
+            ap = float(torch.tensor(apc).mean())
+            auc_mrr = float(torch.tensor(auc_mrr).mean())
         return ap, auc_mrr

     creterion = torch.nn.BCEWithLogitsLoss()
     optimizer = torch.optim.Adam(model.parameters(), lr=train_param['lr'])
     for e in range(train_param['epoch']):
+        torch.cuda.synchronize()
         epoch_start_time = time.time()
         train_aps = list()
         print('Epoch {:d}:'.format(e))
@@ -229,14 +226,18 @@ def main():
         model.module.memory_updater.last_updated_ts = None
         for roots, mfgs, metadata in trainloader:
             t_prep_s = time.time()
-            optimizer.zero_grad()
             with torch.cuda.stream(train_stream):
+                optimizer.zero_grad()
                 pred_pos, pred_neg = model(mfgs, metadata)
                 loss = creterion(pred_pos, torch.ones_like(pred_pos))
                 loss += creterion(pred_neg, torch.zeros_like(pred_neg))
                 total_loss += float(loss)
                 loss.backward()
                 optimizer.step()
+                #torch.cuda.synchronize()
                 t_prep_s = time.time()
                 y_pred = torch.cat([pred_pos, pred_neg], dim=0).sigmoid().cpu()
                 y_true = torch.cat([torch.ones(pred_pos.size(0)), torch.zeros(pred_neg.size(0))], dim=0)
@@ -262,7 +263,9 @@ def main():
                     src, dst, ts, edge_feats,
                     model.module.memory_updater.last_updated_memory,
                 )
-            #mailbox.set_mailbox_all_to_all(index,memory,memory_ts,mail,mail_ts,reduce_Op = 'max')
+            mailbox.set_mailbox_all_to_all(index, memory, memory_ts, mail, mail_ts, reduce_Op='max')
             torch.cuda.synchronize()
         time_prep = time.time() - epoch_start_time
         avg_time += time.time() - epoch_start_time
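In evaluation the per-batch APs and AUC/MRRs are now gathered from every rank with dist.all_gather_into_tensor and averaged once over all of them, instead of only over the local lists. With unequal batch counts per rank those two quantities can differ, as the made-up numbers below show; this is only an illustration of the averaging, not code from the repository.

    import torch

    rank0 = torch.tensor([0.90, 0.80])          # 2 eval batches on rank 0
    rank1 = torch.tensor([0.60, 0.70, 0.50])    # 3 eval batches on rank 1

    mean_of_means = torch.stack([rank0.mean(), rank1.mean()]).mean()
    global_mean = torch.cat([rank0, rank1]).mean()
    print(float(mean_of_means))  # 0.725  (per-rank average first)
    print(float(global_mean))    # 0.70   (one average over every batch)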