Commit 04887eec by zlj

Merge branch 'cmy_dev' into doc-v2

parents b792c909 c2ce9bbc
...@@ -110,19 +110,6 @@ if(WITH_MTMETIS)
     target_compile_definitions(mtmetis_partition PRIVATE -DMTMETIS_64BIT_PARTITIONS)
 endif()
-if (WITH_LDG)
-    # Imports neighbor-clustering based (e.g. LDG algorithm) graph partitioning implementation
-    add_definitions(-DWITH_LDG)
-    set(LDG_DIR "third_party/ldg_partition")
-    add_library(ldg_partition SHARED "csrc/partition/ldg.cpp")
-    target_link_libraries(ldg_partition PRIVATE ${TORCH_LIBRARIES})
-    add_subdirectory(${LDG_DIR})
-    target_include_directories(ldg_partition PRIVATE ${LDG_DIR})
-    target_link_libraries(ldg_partition PRIVATE ldg-vertex-partition)
-endif ()
 include_directories("csrc/include")
 add_library(${PROJECT_NAME} SHARED csrc/export.cpp)
......
cmake_minimum_required(VERSION 3.15)
project(starrygl VERSION 0.1)
option(WITH_PYTHON "Link to Python when building" ON)
option(WITH_CUDA "Link to CUDA when building" ON)
option(WITH_METIS "Link to METIS when building" ON)
option(WITH_MTMETIS "Link to multi-threaded METIS when building" ON)
option(WITH_LDG "Link to (multi-threaded optionally) LDG when building" ON)
message("third_party dir is ${CMAKE_SOURCE_DIR}")
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
set(CMAKE_CUDA_STANDARD 14)
set(CMAKE_CUDA_STANDARD_REQUIRED ON)
find_package(OpenMP REQUIRED)
link_libraries(OpenMP::OpenMP_CXX)
find_package(Torch REQUIRED)
include_directories(${TORCH_INCLUDE_DIRS})
add_compile_options(${TORCH_CXX_FLAGS})
if(WITH_PYTHON)
    add_definitions(-DWITH_PYTHON)
    find_package(Python3 COMPONENTS Interpreter Development REQUIRED)
    include_directories(${Python3_INCLUDE_DIRS})
endif()

if(WITH_CUDA)
    add_definitions(-DWITH_CUDA)
    add_definitions(-DWITH_UVM)
    find_package(CUDA REQUIRED)
    include_directories(${CUDA_INCLUDE_DIRS})
    set(CUDA_LIBRARIES "${CUDA_TOOLKIT_ROOT_DIR}/lib64/libcudart.so")
    file(GLOB_RECURSE UVM_SRCS "csrc/uvm/*.cpp")
    add_library(uvm_ops SHARED ${UVM_SRCS})
    target_link_libraries(uvm_ops PRIVATE ${TORCH_LIBRARIES})
endif()

if(WITH_METIS)
    add_definitions(-DWITH_METIS)
    set(GKLIB_DIR "${CMAKE_SOURCE_DIR}/third_party/GKlib")
    set(METIS_DIR "${CMAKE_SOURCE_DIR}/third_party/METIS")
    set(GKLIB_INCLUDE_DIRS "${GKLIB_DIR}/include")
    file(GLOB_RECURSE GKLIB_LIBRARIES "${GKLIB_DIR}/lib/lib*.a")
    set(METIS_INCLUDE_DIRS "${METIS_DIR}/include")
    file(GLOB_RECURSE METIS_LIBRARIES "${METIS_DIR}/lib/lib*.a")
    include_directories(${METIS_INCLUDE_DIRS})
    add_library(metis_partition SHARED "csrc/partition/metis.cpp")
    target_link_libraries(metis_partition PRIVATE ${TORCH_LIBRARIES})
    target_link_libraries(metis_partition PRIVATE ${GKLIB_LIBRARIES})
    target_link_libraries(metis_partition PRIVATE ${METIS_LIBRARIES})
endif()

if(WITH_MTMETIS)
    add_definitions(-DWITH_MTMETIS)
    set(MTMETIS_DIR "${CMAKE_SOURCE_DIR}/third_party/mt-metis")
    set(MTMETIS_INCLUDE_DIRS "${MTMETIS_DIR}/include")
    file(GLOB_RECURSE MTMETIS_LIBRARIES "${MTMETIS_DIR}/lib/lib*.a")
    include_directories(${MTMETIS_INCLUDE_DIRS})
    add_library(mtmetis_partition SHARED "csrc/partition/mtmetis.cpp")
    target_link_libraries(mtmetis_partition PRIVATE ${TORCH_LIBRARIES})
    target_link_libraries(mtmetis_partition PRIVATE ${MTMETIS_LIBRARIES})
    target_compile_definitions(mtmetis_partition PRIVATE -DMTMETIS_64BIT_VERTICES)
    target_compile_definitions(mtmetis_partition PRIVATE -DMTMETIS_64BIT_EDGES)
    target_compile_definitions(mtmetis_partition PRIVATE -DMTMETIS_64BIT_WEIGHTS)
    target_compile_definitions(mtmetis_partition PRIVATE -DMTMETIS_64BIT_PARTITIONS)
endif()

if(WITH_LDG)
    # Imports neighbor-clustering based (e.g. LDG algorithm) graph partitioning implementation
    add_definitions(-DWITH_LDG)
    # set(LDG_DIR "csrc/partition/neighbor_clustering")
    set(LDG_DIR "third_party/ldg_partition")
    add_library(ldg_partition SHARED "csrc/partition/ldg.cpp")
    target_link_libraries(ldg_partition PRIVATE ${TORCH_LIBRARIES})
    # add_subdirectory(${LDG_DIR})
    target_include_directories(ldg_partition PRIVATE ${LDG_DIR})
    target_link_libraries(ldg_partition PRIVATE ldg-vertex-partition)
endif()

include_directories("csrc/include")

add_library(${PROJECT_NAME} SHARED csrc/export.cpp)
target_link_libraries(${PROJECT_NAME} PRIVATE ${TORCH_LIBRARIES})
target_compile_definitions(${PROJECT_NAME} PRIVATE -DTORCH_EXTENSION_NAME=lib${PROJECT_NAME})

if(WITH_PYTHON)
    find_library(TORCH_PYTHON_LIBRARY torch_python PATHS "${TORCH_INSTALL_PREFIX}/lib")
    target_link_libraries(${PROJECT_NAME} PRIVATE ${TORCH_PYTHON_LIBRARY})
endif()

if(WITH_CUDA)
    target_link_libraries(${PROJECT_NAME} PRIVATE uvm_ops)
endif()

if(WITH_METIS)
    message(STATUS "Current project '${PROJECT_NAME}' uses METIS graph partitioning algorithm.")
    target_link_libraries(${PROJECT_NAME} PRIVATE metis_partition)
endif()

if(WITH_MTMETIS)
    message(STATUS "Current project '${PROJECT_NAME}' uses multi-threaded METIS graph partitioning algorithm.")
    target_link_libraries(${PROJECT_NAME} PRIVATE mtmetis_partition)
endif()

if(WITH_LDG)
    message(STATUS "Current project '${PROJECT_NAME}' uses LDG graph partitioning algorithm.")
    target_link_libraries(${PROJECT_NAME} PRIVATE ldg_partition)
endif()

# add libsampler.so
set(SAMLPER_NAME "${PROJECT_NAME}_sampler")
set(BOOST_INCLUDE_DIRS "${CMAKE_SOURCE_DIR}/third_party/boost_1_83_0")
include_directories(${BOOST_INCLUDE_DIRS})
file(GLOB_RECURSE SAMPLER_SRCS "csrc/sampler/*.cpp")
add_library(${SAMLPER_NAME} SHARED ${SAMPLER_SRCS})
target_include_directories(${SAMLPER_NAME} PRIVATE "csrc/sampler/include")
target_compile_options(${SAMLPER_NAME} PRIVATE -O3)
target_link_libraries(${SAMLPER_NAME} PRIVATE ${TORCH_LIBRARIES})
target_compile_definitions(${SAMLPER_NAME} PRIVATE -DTORCH_EXTENSION_NAME=lib${SAMLPER_NAME})

if(WITH_PYTHON)
    find_library(TORCH_PYTHON_LIBRARY torch_python PATHS "${TORCH_INSTALL_PREFIX}/lib")
    target_link_libraries(${SAMLPER_NAME} PRIVATE ${TORCH_PYTHON_LIBRARY})
endif()
ERROR:root:cannot import name 'libstarrygl' from 'starrygl.lib' (unknown location)
ERROR:root:unable to import libstarrygl.so, some features may not be available.
ERROR:root:cannot import name 'libstarrygl_sampler' from 'starrygl.lib' (unknown location)
ERROR:root:unable to import libstarrygl_sampler.so, some features may not be available.
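These four errors typically come from a guarded native-extension import; a minimal sketch of that pattern (hypothetical — the actual loader in starrygl/lib may differ):

# Hypothetical sketch of the guarded import that produces the ERROR:root lines above.
import logging

try:
    from starrygl.lib import libstarrygl          # built from csrc/export.cpp
except ImportError as e:
    logging.error(e)
    logging.error("unable to import libstarrygl.so, some features may not be available.")

try:
    from starrygl.lib import libstarrygl_sampler  # built from csrc/sampler/*.cpp
except ImportError as e:
    logging.error(e)
    logging.error("unable to import libstarrygl_sampler.so, some features may not be available.")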
libibverbs: Warning: couldn't load driver 'libsiw-rdmav34.so': libsiw-rdmav34.so: cannot open shared object file: No such file or directory
libibverbs: Warning: couldn't load driver 'libbnxt_re-rdmav34.so': libbnxt_re-rdmav34.so: cannot open shared object file: No such file or directory
libibverbs: Warning: couldn't load driver 'libhns-rdmav34.so': libhns-rdmav34.so: cannot open shared object file: No such file or directory
libibverbs: Warning: couldn't load driver 'libqedr-rdmav34.so': libqedr-rdmav34.so: cannot open shared object file: No such file or directory
libibverbs: Warning: couldn't load driver 'libmthca-rdmav34.so': libmthca-rdmav34.so: cannot open shared object file: No such file or directory
libibverbs: Warning: couldn't load driver 'libcxgb4-rdmav34.so': libcxgb4-rdmav34.so: cannot open shared object file: No such file or directory
libibverbs: Warning: couldn't load driver 'libhfi1verbs-rdmav34.so': libhfi1verbs-rdmav34.so: cannot open shared object file: No such file or directory
libibverbs: Warning: couldn't load driver 'libmlx4-rdmav34.so': libmlx4-rdmav34.so: cannot open shared object file: No such file or directory
libibverbs: Warning: couldn't load driver 'libefa-rdmav34.so': libefa-rdmav34.so: cannot open shared object file: No such file or directory
libibverbs: Warning: couldn't load driver 'libipathverbs-rdmav34.so': libipathverbs-rdmav34.so: cannot open shared object file: No such file or directory
libibverbs: Warning: couldn't load driver 'libocrdma-rdmav34.so': libocrdma-rdmav34.so: cannot open shared object file: No such file or directory
libibverbs: Warning: couldn't load driver 'libvmw_pvrdma-rdmav34.so': libvmw_pvrdma-rdmav34.so: cannot open shared object file: No such file or directory
libibverbs: Warning: couldn't load driver 'librxe-rdmav34.so': librxe-rdmav34.so: cannot open shared object file: No such file or directory
get_neighbors consume: 14.4084us
use cuda on 0
33336461 7121936 9088019
torch.Size([2, 33336461]) 33336461
torch.Size([2, 9088019]) 9088019
torch.Size([2, 7121936]) 7121936
tensor([33336], device='cuda:0')
Epoch 0:
torch.Size([2, 1000]) 1000
Traceback (most recent call last):
  File "train_tgnn.py", line 290, in <module>
    main()
  File "train_tgnn.py", line 215, in main
    for roots,mfgs,metadata,sample_time in trainloader:
  File "/home/zlj/starrygl/starrygl/sample/data_loader.py", line 146, in __next__
    batch_data = graph_sample(self.graph,
  File "/home/zlj/starrygl/starrygl/sample/batch_data.py", line 141, in graph_sample
    return to_block(graph,data,out,mailbox,device)
  File "/home/zlj/starrygl/starrygl/sample/batch_data.py", line 81, in to_block
    ind_dict = graph.edge_attr.all_to_all_ind2ptr(dist_eid,group = group)
  File "/home/zlj/starrygl/starrygl/distributed/utils.py", line 195, in all_to_all_ind2ptr
    send_ptr = torch.ops.torch_sparse.ind2ptr(dist_index.part, self.num_parts())
TypeError: 'int' object is not callable
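The TypeError arises because num_parts is an integer-valued attribute (or a property returning an int), so the trailing parentheses try to call an int; a three-line repro, with the corresponding fix visible in the DistributedTensor hunk further below:

class T:
    num_parts = 4      # plain int attribute, not a method

T().num_parts()        # TypeError: 'int' object is not callable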
sampling:
  - layer: 1
    neighbor:
      - 10
    strategy: 'recent'
    prop_time: False
    history: 1
    duration: 0
    num_thread: 32
memory:
  - type: 'node'
    dim_time: 100
    deliver_to: 'self'
    mail_combine: 'last'
    memory_update: 'gru'
    mailbox_size: 1
    combine_node_feature: True
    dim_out: 100
gnn:
  - arch: 'transformer_attention'
    use_src_emb: True
    use_dst_emb: True
    layer: 1
    att_head: 2
    dim_time: 100
    dim_out: 100
train:
  - epoch: 50
    batch_size: 100
    # reorder: 16
    lr: 0.0001
    dropout: 0.1
    att_dropout: 0.2
    all_on_gpu: True
\ No newline at end of file
...@@ -18,13 +18,15 @@ memory:
     dim_out: 100
 gnn:
   - arch: 'transformer_attention'
+    use_src_emb: False
+    use_dst_emb: False
     layer: 1
     att_head: 2
     dim_time: 100
     dim_out: 100
 train:
-  - epoch: 5
-    #batch_size: 100
+  - epoch: 20
+    batch_size: 200
     # reorder: 16
     lr: 0.0001
     dropout: 0.2
......
sampling:
  - layer: 1
    neighbor:
      - 10
    strategy: 'recent'
    prop_time: False
    history: 1
    duration: 0
    num_thread: 32
memory:
  - type: 'node'
    dim_time: 100
    deliver_to: 'self'
    mail_combine: 'last'
    memory_update: 'gru'
    mailbox_size: 1
    combine_node_feature: True
    dim_out: 100
gnn:
  - arch: 'transformer_attention'
    use_src_emb: True
    use_dst_emb: True
    layer: 1
    att_head: 2
    dim_time: 100
    dim_out: 100
train:
  - epoch: 20
    batch_size: 200
    # reorder: 16
    lr: 0.0001
    dropout: 0.2
    att_dropout: 0.2
    all_on_gpu: True
\ No newline at end of file
...@@ -5,6 +5,7 @@
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
 #ifdef WITH_CUDA
+#ifdef WITH_CUDA
     m.def("uvm_storage_new", &uvm_storage_new, "return storage of unified virtual memory");
     m.def("uvm_storage_to_cuda", &uvm_storage_to_cuda, "share uvm storage with another cuda device");
     m.def("uvm_storage_to_cpu", &uvm_storage_to_cpu, "share uvm storage with cpu");
......
...@@ -114,11 +114,15 @@ edge_weight_dict = {}
 edge_weight_dict['edata'] = 2*neg_nums
 edge_weight_dict['sample_data'] = 1*neg_nums
 edge_weight_dict['neg_data'] = 1
-partition_save('./dataset/here/'+data_name, data, 1, 'metis_for_tgnn',
-                edge_weight_dict=edge_weight_dict)
-partition_save('./dataset/here/'+data_name, data, 2, 'metis_for_tgnn',
-                edge_weight_dict=edge_weight_dict)
-partition_save('./dataset/here/'+data_name, data, 4, 'metis_for_tgnn',
+#partition_save('./dataset/here/'+data_name, data, 1, 'metis_for_tgnn',
+#                edge_weight_dict=edge_weight_dict)
+#partition_save('./dataset/here/'+data_name, data, 2, 'metis_for_tgnn',
+#                edge_weight_dict=edge_weight_dict)
+#partition_save('./dataset/here/'+data_name, data, 4, 'metis_for_tgnn',
+#                edge_weight_dict=edge_weight_dict)
+#partition_save('./dataset/here/'+data_name, data, 8, 'metis_for_tgnn',
+#                edge_weight_dict=edge_weight_dict)
+partition_save('./dataset/here/'+data_name, data, 16, 'metis_for_tgnn',
                 edge_weight_dict=edge_weight_dict)
 #
 # partition_save('./dataset/here/'+data_name, data, 4, 'metis_for_tgnn',
......
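Toggling these calls in and out by comment is easy to get wrong; assuming the partition_save signature shown above, a small loop over the desired partition counts is a minimal alternative sketch:

# Sketch only: iterate over partition counts instead of commenting calls in and out.
for num_parts in (1, 2, 4, 8, 16):
    partition_save('./dataset/here/' + data_name, data, num_parts,
                   'metis_for_tgnn', edge_weight_dict=edge_weight_dict)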
Advanced Data Preprocessing
===========================
.. note::
    Describe in detail the usage of StarryGL's data management classes, such as GraphData: the design of their internal index structures and the underlying low-level operations.
\ No newline at end of file
...@@ -4,4 +4,4 @@ Advanced Concepts
 .. toctree::
     sampling_parallel/index
     partition_parallel/index
     timeline_parallel/index
\ No newline at end of file
Distributed Partition Parallel
==============================
.. note::
    This section covers distributed partition-parallel training.
\ No newline at end of file
Distributed Timeline Parallel
=============================
.. note::
    Distributed timeline (temporal) parallelism.
\ No newline at end of file
Distributed Temporal Sampling
=============================
.. note::
    A training mode based on distributed temporal graph sampling.
\ No newline at end of file
...@@ -7,4 +7,4 @@ Package References
     memory
     data_loader
     graph_core
     cache
\ No newline at end of file
...@@ -13,4 +13,4 @@ torch_spline_conv==1.2.2+pt21cu118
 ogb
 tqdm
 networkx
\ No newline at end of file
...@@ -294,7 +294,7 @@ class DistributedTensor:
         index = dist_index.loc
         futs: List[torch.futures.Future] = []
-        for i in range(self.num_parts()):
+        for i in range(self.num_parts):
             mask = part_idx == i
             f = self.accessor.async_index_copy_(0, index[mask], source[mask], self.rrefs[i])
             futs.append(f)
...@@ -308,7 +308,7 @@ class DistributedTensor:
         index = dist_index.loc
         futs: List[torch.futures.Future] = []
-        for i in range(self.num_parts()):
+        for i in range(self.num_parts):
             mask = part_idx == i
             f = self.accessor.async_index_add_(0, index[mask], source[mask], self.rrefs[i])
             futs.append(f)
......
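This change matches the traceback above: if num_parts is exposed as a read-only property, self.num_parts() ends up calling the returned int and fails. A minimal sketch of that pattern (class body hypothetical):

class DistributedTensorSketch:
    def __init__(self, num_parts: int):
        self._num_parts = num_parts

    @property
    def num_parts(self) -> int:      # accessed without parentheses
        return self._num_parts

t = DistributedTensorSketch(4)
assert t.num_parts == 4              # attribute access works
# t.num_parts() would raise TypeError: 'int' object is not callable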
...@@ -47,7 +47,7 @@ class GeneralModel(torch.nn.Module):
         self.edge_predictor = EdgePredictor(gnn_param['dim_out'])
         if 'combine' in gnn_param and gnn_param['combine'] == 'rnn':
             self.combiner = torch.nn.RNN(gnn_param['dim_out'], gnn_param['dim_out'])
     def forward(self, mfgs, metadata = None,neg_samples=1):
         if self.memory_param['type'] == 'node':
...@@ -68,8 +68,14 @@ class GeneralModel(torch.nn.Module):
             out = torch.stack(out, dim=0)
             out = self.combiner(out)[0][-1, :, :]
         # metadata: the ids need to be recorded during the earlier deduplication step
+        if self.gnn_param['use_src_emb'] or self.gnn_param['use_dst_emb']:
+            self.embedding = out.detach().clone()
+        else:
+            self.embedding = None
         if metadata is not None:
             #out = torch.cat((out[metadata['dst_pos_pos']],out[metadata['src_id_pos']],out[metadata['dst_neg_pos']]),0)
+            if self.gnn_param['dyrep']:
+                out = self.memory_updater.last_updated_memory
             out = torch.cat((out[metadata['src_pos_index']],out[metadata['dst_pos_index']],out[metadata['src_neg_index']]),0)
         return self.edge_predictor(out, neg_samples=neg_samples)
......
 import yaml
+import numpy as np
 def parse_config(f):
     conf = yaml.safe_load(open(f, 'r'))
...@@ -7,4 +7,32 @@ def parse_config(f):
     memory_param = conf['memory'][0]
     gnn_param = conf['gnn'][0]
     train_param = conf['train'][0]
     return sample_param, memory_param, gnn_param, train_param
\ No newline at end of file
class EarlyStopMonitor(object):
    def __init__(self, max_round=3, higher_better=True, tolerance=1e-10):
        self.max_round = max_round
        self.num_round = 0
        self.epoch_count = 0
        self.best_epoch = 0
        self.last_best = None
        self.higher_better = higher_better
        self.tolerance = tolerance

    def early_stop_check(self, curr_val):
        if not self.higher_better:
            curr_val *= -1
        if self.last_best is None:
            self.last_best = curr_val
        elif (curr_val - self.last_best) / np.abs(self.last_best) > self.tolerance:
            self.last_best = curr_val
            self.num_round = 0
            self.best_epoch = self.epoch_count
        else:
            self.num_round += 1
        self.epoch_count += 1
        return self.num_round >= self.max_round
\ No newline at end of file
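A hedged usage sketch of the two utilities above; the config path and the train_one_epoch helper are illustrative placeholders:

sample_param, memory_param, gnn_param, train_param = parse_config('config/tgn.yml')

monitor = EarlyStopMonitor(max_round=3, higher_better=True)
for epoch in range(train_param['epoch']):
    val_ap = train_one_epoch()      # placeholder for the real training step
    if monitor.early_stop_check(val_ap):
        print(f'early stop at epoch {epoch}, best epoch {monitor.best_epoch}')
        break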
...@@ -192,7 +192,10 @@ class DistributedDataLoader:
                     self.device)
                 self.recv_idxs += 1
                 assert batch_data is not None
-                return batch_data
+                end_event.record()
+                torch.cuda.synchronize()
+                sample_time = start_event.elapsed_time(end_event)
+                return *batch_data,sample_time
             else :
                 raise StopIteration
         else:
......
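start_event is created and recorded earlier in __next__, outside this hunk; in isolation, the CUDA-event timing pattern used here looks roughly like this (a sketch, not the loader's exact code):

import torch

start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)

start_event.record()
# ... graph sampling / batch construction ...
end_event.record()
torch.cuda.synchronize()                            # wait for both recorded events
sample_time = start_event.elapsed_time(end_event)   # elapsed time in milliseconds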
...@@ -166,6 +166,7 @@ class DataSet:
         if labels is not None:
             self.labels = labels
         self.len = self.nodes.shape[0] if nodes is not None else self.edges.shape[1]
         for k, v in kwargs.items():
             assert isinstance(v,torch.Tensor) and v.shape[0]==self.len
             setattr(self, k, v.to(device))
...@@ -222,7 +223,7 @@ class DataSet:
 class TemporalGraphData(DistributedGraphStore):
     def __init__(self,pdata,device):
-        super(TemporalGraphData,self).__init__(pdata,device)
+        super(DistributedTempoGraphData,self).__init__(pdata,device)
     def _set_temporal_batch_cache(self,size,pin_size):
         pass
     def _load_feature_to_cuda(self,ids):
......
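The new super(DistributedTempoGraphData, self) names a class other than the defining one, which raises a NameError unless that name exists in scope; Python 3's zero-argument super() avoids hard-coding the class name. A self-contained sketch (the base class here is a stub):

class DistributedGraphStore:              # stub standing in for the real base
    def __init__(self, pdata, device):
        self.pdata, self.device = pdata, device

class TemporalGraphData(DistributedGraphStore):
    def __init__(self, pdata, device):
        super().__init__(pdata, device)   # rename-safe zero-argument form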
...@@ -299,7 +299,7 @@ class SharedMailBox():
     def get_update_mail(self,dist_indx_mapper,
                         src,dst,ts,edge_feats,
-                        memory):
+                        memory,embedding=None,use_src_emb=False,use_dst_emb=False):
         if edge_feats is not None:
             edge_feats = edge_feats.to(self.device).to(self.mailbox.dtype)
         src = src.to(self.device)
...@@ -309,12 +309,14 @@ class SharedMailBox():
         mem_src = memory[src]
         mem_dst = memory[dst]
+        if embedding is not None:
+            emb_src = embedding[src]
+            emb_dst = embedding[dst]
+        src_mail = torch.cat([emb_src if use_src_emb else mem_src, emb_dst if use_dst_emb else mem_dst], dim=1)
+        dst_mail = torch.cat([emb_dst if use_src_emb else mem_dst, emb_src if use_dst_emb else mem_src], dim=1)
         if edge_feats is not None:
-            src_mail = torch.cat([mem_src, mem_dst, edge_feats], dim=1)
-            dst_mail = torch.cat([mem_dst, mem_src, edge_feats], dim=1)
-        else:
-            src_mail = torch.cat([mem_src, mem_dst], dim=1)
-            dst_mail = torch.cat([mem_dst, mem_src], dim=1)
+            src_mail = torch.cat([src_mail, edge_feats], dim=1)
+            dst_mail = torch.cat([dst_mail, edge_feats], dim=1)
         mail = torch.cat([src_mail, dst_mail], dim=1).reshape(-1, src_mail.shape[1])
         mail_ts = torch.cat((ts,ts),-1).to(self.device).to(self.mailbox_ts.dtype)
         unq_index,inv = torch.unique(index,return_inverse = True)
...@@ -324,7 +326,6 @@ class SharedMailBox():
         index = unq_index
         return index,mail,mail_ts
     def get_update_memory(self,index,memory,memory_ts):
         unq_index,inv = torch.unique(index,return_inverse = True)
         max_ts,idx = torch_scatter.scatter_max(memory_ts,inv,0)
......
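How the new embedding path is wired from GeneralModel.forward into get_update_mail is not shown in this commit; a hypothetical call-site sketch consistent with the two hunks above:

# Hypothetical call-site wiring; names follow the GeneralModel and SharedMailBox hunks.
index, mail, mail_ts = mailbox.get_update_mail(
    dist_indx_mapper,
    src, dst, ts, edge_feats,
    memory,
    embedding=model.embedding,                 # cached in GeneralModel.forward
    use_src_emb=gnn_param['use_src_emb'],
    use_dst_emb=gnn_param['use_dst_emb'],
)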
from torch_sparse import SparseTensor
from torch_geometric.data import Data
from torch_geometric.utils import degree
import os.path as osp
import os
import shutil
......