Commit d8158bb9 by zlj

temp

parent b792c909
...@@ -174,3 +174,5 @@ cython_debug/ ...@@ -174,3 +174,5 @@ cython_debug/
/dataset /dataset
/test_* /test_*
/*.ipynb /*.ipynb
saved_models/
saved_checkpoints/
\ No newline at end of file
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# IDE temporary files (generated by IDEs like CLion, etc.)
.idea/
cmake-build-*/
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
*.pt
/*.out
/a.out
/third_party
/.vscode
/run_route.py
/dataset
/test_*
/*.ipynb
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# IDE temporary files (generated by IDEs like CLion, etc.)
.idea/
cmake-build-*/
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
*.pt
/*.out
/a.out
/third_party
/.vscode
/run_route.py
/dataset
/test_*
/*.ipynb
saved_models/
saved_checkpoints/
\ No newline at end of file
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# IDE temporary files (generated by IDEs like CLion, etc.)
.idea/
cmake-build-*/
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
*.pt
/*.out
/a.out
/third_party
/.vscode
/run_route.py
/dataset
/test_*
/*.ipynb
saved_models/
saved_checkpoints/
\ No newline at end of file
<<<<<<< HEAD
[submodule "third_party/ldg_partition"]
path = third_party/ldg_partition
url = https://gitee.com/onlynagesha/graph-partition-v4
[submodule "third_party/METIS"]
path = third_party/METIS
url = https://github.com/KarypisLab/METIS
branch = v5.1.1-DistDGL-v0.5
=======
[submodule "csrc/partition/neighbor_clustering"]
path = csrc/partition/neighbor_clustering
url = https://gitee.com/onlynagesha/graph-partition-v4
>>>>>>> cmy_dev
[submodule "third_party/ldg_partition"]
path = third_party/ldg_partition
url = https://gitee.com/onlynagesha/graph-partition-v4
[submodule "third_party/METIS"]
path = third_party/METIS
url = https://github.com/KarypisLab/METIS
branch = v5.1.1-DistDGL-v0.5
cmake_minimum_required(VERSION 3.15)
project(starrygl VERSION 0.1)
option(WITH_PYTHON "Link to Python when building" ON)
option(WITH_CUDA "Link to CUDA when building" ON)
option(WITH_METIS "Link to METIS when building" ON)
<<<<<<< HEAD
option(WITH_MTMETIS "Link to multi-threaded METIS when building" OFF)
=======
option(WITH_MTMETIS "Link to multi-threaded METIS when building" ON)
>>>>>>> cmy_dev
option(WITH_LDG "Link to (multi-threaded optionally) LDG when building" ON)
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
set(CMAKE_CUDA_STANDARD 14)
set(CMAKE_CUDA_STANDARD_REQUIRED ON)
find_package(OpenMP REQUIRED)
link_libraries(OpenMP::OpenMP_CXX)
find_package(Torch REQUIRED)
include_directories(${TORCH_INCLUDE_DIRS})
add_compile_options(${TORCH_CXX_FLAGS})
if(WITH_PYTHON)
add_definitions(-DWITH_PYTHON)
find_package(Python3 COMPONENTS Interpreter Development REQUIRED)
include_directories(${Python3_INCLUDE_DIRS})
endif()
if(WITH_CUDA)
add_definitions(-DWITH_CUDA)
add_definitions(-DWITH_UVM)
find_package(CUDA REQUIRED)
include_directories(${CUDA_INCLUDE_DIRS})
set(CUDA_LIBRARIES "${CUDA_TOOLKIT_ROOT_DIR}/lib64/libcudart.so")
file(GLOB_RECURSE UVM_SRCS "csrc/uvm/*.cpp")
add_library(uvm_ops SHARED ${UVM_SRCS})
target_link_libraries(uvm_ops PRIVATE ${TORCH_LIBRARIES})
endif()
if(WITH_METIS)
# add_definitions(-DWITH_METIS)
# set(GKLIB_DIR "${CMAKE_SOURCE_DIR}/third_party/GKlib")
# set(METIS_DIR "${CMAKE_SOURCE_DIR}/third_party/METIS")
# set(GKLIB_INCLUDE_DIRS "${GKLIB_DIR}/include")
# file(GLOB_RECURSE GKLIB_LIBRARIES "${GKLIB_DIR}/lib/lib*.a")
# set(METIS_INCLUDE_DIRS "${METIS_DIR}/include")
# file(GLOB_RECURSE METIS_LIBRARIES "${METIS_DIR}/lib/lib*.a")
# include_directories(${METIS_INCLUDE_DIRS})
<<<<<<< HEAD
# add_library(metis_partition SHARED "csrc/partition/metis.cpp")
# target_link_libraries(metis_partition PRIVATE ${TORCH_LIBRARIES})
# target_link_libraries(metis_partition PRIVATE ${GKLIB_LIBRARIES})
# target_link_libraries(metis_partition PRIVATE ${METIS_LIBRARIES})
add_definitions(-DWITH_METIS)
set(METIS_DIR "${CMAKE_SOURCE_DIR}/third_party/METIS")
set(METIS_GKLIB_DIR "${METIS_DIR}/GKlib")
file(GLOB METIS_SRCS "${METIS_DIR}/libmetis/*.c")
file(GLOB METIS_GKLIB_SRCS "${METIS_GKLIB_DIR}/*.c")
if (MSVC)
file(GLOB METIS_GKLIB_WIN32_SRCS "${METIS_GKLIB_DIR}/win32/*.c")
set(METIS_GKLIB_SRCS ${METIS_GKLIB_SRCS} ${METIS_GKLIB_WIN32_SRCS})
endif()
add_library(metis_partition SHARED
"csrc/partition/metis.cpp"
${METIS_SRCS} ${METIS_GKLIB_SRCS}
)
target_include_directories(metis_partition PRIVATE "${METIS_DIR}/include")
target_include_directories(metis_partition PRIVATE "${METIS_GKLIB_DIR}")
if (MSVC)
target_include_directories(metis_partition PRIVATE "${METIS_GKLIB_DIR}/win32")
endif()
target_compile_definitions(metis_partition PRIVATE -DIDXTYPEWIDTH=64)
target_compile_definitions(metis_partition PRIVATE -DREALTYPEWIDTH=32)
target_compile_options(metis_partition PRIVATE -O3)
target_link_libraries(metis_partition PRIVATE ${TORCH_LIBRARIES})
if (UNIX)
target_link_libraries(metis_partition PRIVATE m)
endif()
=======
add_library(metis_partition SHARED "csrc/partition/metis.cpp")
target_link_libraries(metis_partition PRIVATE ${TORCH_LIBRARIES})
target_link_libraries(metis_partition PRIVATE ${GKLIB_LIBRARIES})
target_link_libraries(metis_partition PRIVATE ${METIS_LIBRARIES})
>>>>>>> cmy_dev
endif()
if(WITH_MTMETIS)
add_definitions(-DWITH_MTMETIS)
set(MTMETIS_DIR "${CMAKE_SOURCE_DIR}/third_party/mt-metis")
set(MTMETIS_INCLUDE_DIRS "${MTMETIS_DIR}/include")
file(GLOB_RECURSE MTMETIS_LIBRARIES "${MTMETIS_DIR}/lib/lib*.a")
include_directories(${MTMETIS_INCLUDE_DIRS})
add_library(mtmetis_partition SHARED "csrc/partition/mtmetis.cpp")
target_link_libraries(mtmetis_partition PRIVATE ${TORCH_LIBRARIES})
target_link_libraries(mtmetis_partition PRIVATE ${MTMETIS_LIBRARIES})
target_compile_definitions(mtmetis_partition PRIVATE -DMTMETIS_64BIT_VERTICES)
target_compile_definitions(mtmetis_partition PRIVATE -DMTMETIS_64BIT_EDGES)
target_compile_definitions(mtmetis_partition PRIVATE -DMTMETIS_64BIT_WEIGHTS)
target_compile_definitions(mtmetis_partition PRIVATE -DMTMETIS_64BIT_PARTITIONS)
endif()
if (WITH_LDG)
# Imports neighbor-clustering based (e.g. LDG algorithm) graph partitioning implementation
add_definitions(-DWITH_LDG)
<<<<<<< HEAD
set(LDG_DIR "third_party/ldg_partition")
=======
set(LDG_DIR "csrc/partition/neighbor_clustering")
>>>>>>> cmy_dev
add_library(ldg_partition SHARED "csrc/partition/ldg.cpp")
target_link_libraries(ldg_partition PRIVATE ${TORCH_LIBRARIES})
add_subdirectory(${LDG_DIR})
target_include_directories(ldg_partition PRIVATE ${LDG_DIR})
target_link_libraries(ldg_partition PRIVATE ldg-vertex-partition)
endif ()
include_directories("csrc/include")
add_library(${PROJECT_NAME} SHARED csrc/export.cpp)
target_link_libraries(${PROJECT_NAME} PRIVATE ${TORCH_LIBRARIES})
target_compile_definitions(${PROJECT_NAME} PRIVATE -DTORCH_EXTENSION_NAME=lib${PROJECT_NAME})
if(WITH_PYTHON)
find_library(TORCH_PYTHON_LIBRARY torch_python PATHS "${TORCH_INSTALL_PREFIX}/lib")
target_link_libraries(${PROJECT_NAME} PRIVATE ${TORCH_PYTHON_LIBRARY})
endif()
if (WITH_CUDA)
target_link_libraries(${PROJECT_NAME} PRIVATE uvm_ops)
endif()
if (WITH_METIS)
message(STATUS "Current project '${PROJECT_NAME}' uses METIS graph partitioning algorithm.")
target_link_libraries(${PROJECT_NAME} PRIVATE metis_partition)
endif()
if (WITH_MTMETIS)
message(STATUS "Current project '${PROJECT_NAME}' uses multi-threaded METIS graph partitioning algorithm.")
target_link_libraries(${PROJECT_NAME} PRIVATE mtmetis_partition)
endif()
if (WITH_LDG)
message(STATUS "Current project '${PROJECT_NAME}' uses LDG graph partitioning algorithm.")
target_link_libraries(${PROJECT_NAME} PRIVATE ldg_partition)
endif()
# add libsampler.so
set(SAMLPER_NAME "${PROJECT_NAME}_sampler")
# set(BOOST_INCLUDE_DIRS "${CMAKE_SOURCE_DIR}/third_party/boost_1_83_0")
# include_directories(${BOOST_INCLUDE_DIRS})
file(GLOB_RECURSE SAMPLER_SRCS "csrc/sampler/*.cpp")
add_library(${SAMLPER_NAME} SHARED ${SAMPLER_SRCS})
target_include_directories(${SAMLPER_NAME} PRIVATE "csrc/sampler/include")
target_compile_options(${SAMLPER_NAME} PRIVATE -O3)
target_link_libraries(${SAMLPER_NAME} PRIVATE ${TORCH_LIBRARIES})
target_compile_definitions(${SAMLPER_NAME} PRIVATE -DTORCH_EXTENSION_NAME=lib${SAMLPER_NAME})
if(WITH_PYTHON)
find_library(TORCH_PYTHON_LIBRARY torch_python PATHS "${TORCH_INSTALL_PREFIX}/lib")
target_link_libraries(${SAMLPER_NAME} PRIVATE ${TORCH_PYTHON_LIBRARY})
endif()
cmake_minimum_required(VERSION 3.15)
project(starrygl VERSION 0.1)
option(WITH_PYTHON "Link to Python when building" ON)
option(WITH_CUDA "Link to CUDA when building" ON)
option(WITH_METIS "Link to METIS when building" ON)
option(WITH_MTMETIS "Link to multi-threaded METIS when building" OFF)
option(WITH_LDG "Link to (multi-threaded optionally) LDG when building" ON)
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
set(CMAKE_CUDA_STANDARD 14)
set(CMAKE_CUDA_STANDARD_REQUIRED ON)
find_package(OpenMP REQUIRED)
link_libraries(OpenMP::OpenMP_CXX)
find_package(Torch REQUIRED)
include_directories(${TORCH_INCLUDE_DIRS})
add_compile_options(${TORCH_CXX_FLAGS})
if(WITH_PYTHON)
add_definitions(-DWITH_PYTHON)
find_package(Python3 COMPONENTS Interpreter Development REQUIRED)
include_directories(${Python3_INCLUDE_DIRS})
endif()
if(WITH_CUDA)
add_definitions(-DWITH_CUDA)
add_definitions(-DWITH_UVM)
find_package(CUDA REQUIRED)
include_directories(${CUDA_INCLUDE_DIRS})
set(CUDA_LIBRARIES "${CUDA_TOOLKIT_ROOT_DIR}/lib64/libcudart.so")
file(GLOB_RECURSE UVM_SRCS "csrc/uvm/*.cpp")
add_library(uvm_ops SHARED ${UVM_SRCS})
target_link_libraries(uvm_ops PRIVATE ${TORCH_LIBRARIES})
endif()
if(WITH_METIS)
# add_definitions(-DWITH_METIS)
# set(GKLIB_DIR "${CMAKE_SOURCE_DIR}/third_party/GKlib")
# set(METIS_DIR "${CMAKE_SOURCE_DIR}/third_party/METIS")
# set(GKLIB_INCLUDE_DIRS "${GKLIB_DIR}/include")
# file(GLOB_RECURSE GKLIB_LIBRARIES "${GKLIB_DIR}/lib/lib*.a")
# set(METIS_INCLUDE_DIRS "${METIS_DIR}/include")
# file(GLOB_RECURSE METIS_LIBRARIES "${METIS_DIR}/lib/lib*.a")
# include_directories(${METIS_INCLUDE_DIRS})
# add_library(metis_partition SHARED "csrc/partition/metis.cpp")
# target_link_libraries(metis_partition PRIVATE ${TORCH_LIBRARIES})
# target_link_libraries(metis_partition PRIVATE ${GKLIB_LIBRARIES})
# target_link_libraries(metis_partition PRIVATE ${METIS_LIBRARIES})
add_definitions(-DWITH_METIS)
set(METIS_DIR "${CMAKE_SOURCE_DIR}/third_party/METIS")
set(METIS_GKLIB_DIR "${METIS_DIR}/GKlib")
file(GLOB METIS_SRCS "${METIS_DIR}/libmetis/*.c")
file(GLOB METIS_GKLIB_SRCS "${METIS_GKLIB_DIR}/*.c")
if (MSVC)
file(GLOB METIS_GKLIB_WIN32_SRCS "${METIS_GKLIB_DIR}/win32/*.c")
set(METIS_GKLIB_SRCS ${METIS_GKLIB_SRCS} ${METIS_GKLIB_WIN32_SRCS})
endif()
add_library(metis_partition SHARED
"csrc/partition/metis.cpp"
${METIS_SRCS} ${METIS_GKLIB_SRCS}
)
target_include_directories(metis_partition PRIVATE "${METIS_DIR}/include")
target_include_directories(metis_partition PRIVATE "${METIS_GKLIB_DIR}")
if (MSVC)
target_include_directories(metis_partition PRIVATE "${METIS_GKLIB_DIR}/win32")
endif()
target_compile_definitions(metis_partition PRIVATE -DIDXTYPEWIDTH=64)
target_compile_definitions(metis_partition PRIVATE -DREALTYPEWIDTH=32)
target_compile_options(metis_partition PRIVATE -O3)
target_link_libraries(metis_partition PRIVATE ${TORCH_LIBRARIES})
if (UNIX)
target_link_libraries(metis_partition PRIVATE m)
endif()
endif()
if(WITH_MTMETIS)
add_definitions(-DWITH_MTMETIS)
set(MTMETIS_DIR "${CMAKE_SOURCE_DIR}/third_party/mt-metis")
set(MTMETIS_INCLUDE_DIRS "${MTMETIS_DIR}/include")
file(GLOB_RECURSE MTMETIS_LIBRARIES "${MTMETIS_DIR}/lib/lib*.a")
include_directories(${MTMETIS_INCLUDE_DIRS})
add_library(mtmetis_partition SHARED "csrc/partition/mtmetis.cpp")
target_link_libraries(mtmetis_partition PRIVATE ${TORCH_LIBRARIES})
target_link_libraries(mtmetis_partition PRIVATE ${MTMETIS_LIBRARIES})
target_compile_definitions(mtmetis_partition PRIVATE -DMTMETIS_64BIT_VERTICES)
target_compile_definitions(mtmetis_partition PRIVATE -DMTMETIS_64BIT_EDGES)
target_compile_definitions(mtmetis_partition PRIVATE -DMTMETIS_64BIT_WEIGHTS)
target_compile_definitions(mtmetis_partition PRIVATE -DMTMETIS_64BIT_PARTITIONS)
endif()
include_directories("csrc/include")
add_library(${PROJECT_NAME} SHARED csrc/export.cpp)
target_link_libraries(${PROJECT_NAME} PRIVATE ${TORCH_LIBRARIES})
target_compile_definitions(${PROJECT_NAME} PRIVATE -DTORCH_EXTENSION_NAME=lib${PROJECT_NAME})
if(WITH_PYTHON)
find_library(TORCH_PYTHON_LIBRARY torch_python PATHS "${TORCH_INSTALL_PREFIX}/lib")
target_link_libraries(${PROJECT_NAME} PRIVATE ${TORCH_PYTHON_LIBRARY})
endif()
if (WITH_CUDA)
target_link_libraries(${PROJECT_NAME} PRIVATE uvm_ops)
endif()
if (WITH_METIS)
message(STATUS "Current project '${PROJECT_NAME}' uses METIS graph partitioning algorithm.")
target_link_libraries(${PROJECT_NAME} PRIVATE metis_partition)
endif()
if (WITH_MTMETIS)
message(STATUS "Current project '${PROJECT_NAME}' uses multi-threaded METIS graph partitioning algorithm.")
target_link_libraries(${PROJECT_NAME} PRIVATE mtmetis_partition)
endif()
if (WITH_LDG)
message(STATUS "Current project '${PROJECT_NAME}' uses LDG graph partitioning algorithm.")
target_link_libraries(${PROJECT_NAME} PRIVATE ldg_partition)
endif()
# add libsampler.so
set(SAMLPER_NAME "${PROJECT_NAME}_sampler")
# set(BOOST_INCLUDE_DIRS "${CMAKE_SOURCE_DIR}/third_party/boost_1_83_0")
# include_directories(${BOOST_INCLUDE_DIRS})
file(GLOB_RECURSE SAMPLER_SRCS "csrc/sampler/*.cpp")
add_library(${SAMLPER_NAME} SHARED ${SAMPLER_SRCS})
target_include_directories(${SAMLPER_NAME} PRIVATE "csrc/sampler/include")
target_compile_options(${SAMLPER_NAME} PRIVATE -O3)
target_link_libraries(${SAMLPER_NAME} PRIVATE ${TORCH_LIBRARIES})
target_compile_definitions(${SAMLPER_NAME} PRIVATE -DTORCH_EXTENSION_NAME=lib${SAMLPER_NAME})
if(WITH_PYTHON)
find_library(TORCH_PYTHON_LIBRARY torch_python PATHS "${TORCH_INSTALL_PREFIX}/lib")
target_link_libraries(${SAMLPER_NAME} PRIVATE ${TORCH_PYTHON_LIBRARY})
endif()
#include "extension.h"
#include "uvm.h"
#include "partition.h"
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
#ifdef WITH_CUDA
m.def("uvm_storage_new", &uvm_storage_new, "return storage of unified virtual memory");
m.def("uvm_storage_to_cuda", &uvm_storage_to_cuda, "share uvm storage with another cuda device");
m.def("uvm_storage_to_cpu", &uvm_storage_to_cpu, "share uvm storage with cpu");
m.def("uvm_storage_advise", &uvm_storage_advise, "apply cudaMemAdvise() to uvm storage");
m.def("uvm_storage_prefetch", &uvm_storage_prefetch, "apply cudaMemPrefetchAsync() to uvm storage");
py::enum_<cudaMemoryAdvise>(m, "cudaMemoryAdvise")
.value("cudaMemAdviseSetAccessedBy", cudaMemoryAdvise::cudaMemAdviseSetAccessedBy)
.value("cudaMemAdviseUnsetAccessedBy", cudaMemoryAdvise::cudaMemAdviseUnsetAccessedBy)
.value("cudaMemAdviseSetPreferredLocation", cudaMemoryAdvise::cudaMemAdviseSetPreferredLocation)
.value("cudaMemAdviseUnsetPreferredLocation", cudaMemoryAdvise::cudaMemAdviseUnsetPreferredLocation)
.value("cudaMemAdviseSetReadMostly", cudaMemoryAdvise::cudaMemAdviseSetReadMostly)
.value("cudaMemAdviseUnsetReadMostly", cudaMemoryAdvise::cudaMemAdviseUnsetReadMostly);
#endif
#ifdef WITH_METIS
m.def("metis_partition", &metis_partition, "metis graph partition");
<<<<<<< HEAD
m.def("metis_cache_friendly_reordering", &metis_cache_friendly_reordering, "metis cache-friendly reordering");
=======
>>>>>>> cmy_dev
#endif
#ifdef WITH_MTMETIS
m.def("mt_metis_partition", &mt_metis_partition, "multi-threaded metis graph partition");
#endif
#ifdef WITH_LGD
// Note: the switch WITH_MULTITHREADING=ON shall be triggered during compilation
// to enable multi-threading functionality.
m.def("ldg_partition", &ldg_partition, "(multi-threaded optionally) LDG graph partition");
#endif
}
#include "extension.h"
#include "uvm.h"
#include "partition.h"
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
#ifdef WITH_CUDA
m.def("uvm_storage_new", &uvm_storage_new, "return storage of unified virtual memory");
m.def("uvm_storage_to_cuda", &uvm_storage_to_cuda, "share uvm storage with another cuda device");
m.def("uvm_storage_to_cpu", &uvm_storage_to_cpu, "share uvm storage with cpu");
m.def("uvm_storage_advise", &uvm_storage_advise, "apply cudaMemAdvise() to uvm storage");
m.def("uvm_storage_prefetch", &uvm_storage_prefetch, "apply cudaMemPrefetchAsync() to uvm storage");
py::enum_<cudaMemoryAdvise>(m, "cudaMemoryAdvise")
.value("cudaMemAdviseSetAccessedBy", cudaMemoryAdvise::cudaMemAdviseSetAccessedBy)
.value("cudaMemAdviseUnsetAccessedBy", cudaMemoryAdvise::cudaMemAdviseUnsetAccessedBy)
.value("cudaMemAdviseSetPreferredLocation", cudaMemoryAdvise::cudaMemAdviseSetPreferredLocation)
.value("cudaMemAdviseUnsetPreferredLocation", cudaMemoryAdvise::cudaMemAdviseUnsetPreferredLocation)
.value("cudaMemAdviseSetReadMostly", cudaMemoryAdvise::cudaMemAdviseSetReadMostly)
.value("cudaMemAdviseUnsetReadMostly", cudaMemoryAdvise::cudaMemAdviseUnsetReadMostly);
#endif
#ifdef WITH_METIS
m.def("metis_partition", &metis_partition, "metis graph partition");
m.def("metis_cache_friendly_reordering", &metis_cache_friendly_reordering, "metis cache-friendly reordering");
#endif
#ifdef WITH_MTMETIS
m.def("mt_metis_partition", &mt_metis_partition, "multi-threaded metis graph partition");
#endif
#ifdef WITH_LGD
// Note: the switch WITH_MULTITHREADING=ON shall be triggered during compilation
// to enable multi-threading functionality.
m.def("ldg_partition", &ldg_partition, "(multi-threaded optionally) LDG graph partition");
#endif
}
#include "extension.h"
#include "uvm.h"
#include "partition.h"
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
#ifdef WITH_CUDA
#ifdef WITH_CUDA
m.def("uvm_storage_new", &uvm_storage_new, "return storage of unified virtual memory");
m.def("uvm_storage_to_cuda", &uvm_storage_to_cuda, "share uvm storage with another cuda device");
m.def("uvm_storage_to_cpu", &uvm_storage_to_cpu, "share uvm storage with cpu");
m.def("uvm_storage_advise", &uvm_storage_advise, "apply cudaMemAdvise() to uvm storage");
m.def("uvm_storage_prefetch", &uvm_storage_prefetch, "apply cudaMemPrefetchAsync() to uvm storage");
py::enum_<cudaMemoryAdvise>(m, "cudaMemoryAdvise")
.value("cudaMemAdviseSetAccessedBy", cudaMemoryAdvise::cudaMemAdviseSetAccessedBy)
.value("cudaMemAdviseUnsetAccessedBy", cudaMemoryAdvise::cudaMemAdviseUnsetAccessedBy)
.value("cudaMemAdviseSetPreferredLocation", cudaMemoryAdvise::cudaMemAdviseSetPreferredLocation)
.value("cudaMemAdviseUnsetPreferredLocation", cudaMemoryAdvise::cudaMemAdviseUnsetPreferredLocation)
.value("cudaMemAdviseSetReadMostly", cudaMemoryAdvise::cudaMemAdviseSetReadMostly)
.value("cudaMemAdviseUnsetReadMostly", cudaMemoryAdvise::cudaMemAdviseUnsetReadMostly);
#endif
#ifdef WITH_METIS
m.def("metis_partition", &metis_partition, "metis graph partition");
m.def("metis_cache_friendly_reordering", &metis_cache_friendly_reordering, "metis cache-friendly reordering");
#endif
#ifdef WITH_MTMETIS
m.def("mt_metis_partition", &mt_metis_partition, "multi-threaded metis graph partition");
#endif
#ifdef WITH_LGD
// Note: the switch WITH_MULTITHREADING=ON shall be triggered during compilation
// to enable multi-threading functionality.
m.def("ldg_partition", &ldg_partition, "(multi-threaded optionally) LDG graph partition");
#endif
}
Advanced Concepts
=================
.. toctree::
<<<<<<< HEAD
sampling_parallel/index
partition_parallel/index
timeline_parallel/index
=======
ts_sampling
pp_training
tp_training
data_proc
>>>>>>> cmy_dev
Advanced Concepts
=================
.. toctree::
sampling_parallel/index
partition_parallel/index
timeline_parallel/index
Package References
==================
.. toctree::
distributed
neighbor_sampler
memory
data_loader
graph_core
cache
import os
import sys
sys.path.insert(0, os.path.abspath("../.."))
import starrygl
project = 'StarryGL'
copyright = '2023, StarryGL Team'
author = 'StarryGL Team'
version = starrygl.__version__
release = starrygl.__version__
extensions = [
"sphinx.ext.autodoc",
"sphinx.ext.autosummary",
"sphinx.ext.duration",
"sphinx.ext.viewcode",
]
templates_path = ['_templates']
exclude_patterns = []
# -- Options for HTML output -------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output
html_theme = 'sphinx_rtd_theme'
html_static_path = ['_static']
#!/bin/bash
mkdir -p build && cd build
cmake .. \
-DCMAKE_EXPORT_COMPILE_COMMANDS=ON \
-DCMAKE_PREFIX_PATH="/home/hwj/.miniconda3/envs/sgl/lib/python3.10/site-packages" \
-DPython3_ROOT_DIR="/home/hwj/.miniconda3/envs/sgl" \
-DCUDA_TOOLKIT_ROOT_DIR="/home/hwj/.local/cuda-11.8" \
&& make -j32 \
&& rm -rf ../starrygl/lib \
&& mkdir ../starrygl/lib \
&& cp lib*.so ../starrygl/lib/ \
&& patchelf --set-rpath '$ORIGIN:$ORIGIN/lib' --force-rpath ../starrygl/lib/*.so
#!/bin/bash
mkdir -p build && cd build
cmake .. \
-DCMAKE_EXPORT_COMPILE_COMMANDS=ON \
-DCMAKE_PREFIX_PATH="/home/zlj/.miniconda3/envs/dgnn/lib/python3.8/site-packages" \
-DPython3_ROOT_DIR="/home/zlj/.miniconda3/envs/dgnn" \
-DCUDA_TOOLKIT_ROOT_DIR="/home/zlj/local/cuda-12.2" \
&& make -j32 \
&& rm -rf ../starrygl/lib \
&& mkdir ../starrygl/lib \
&& cp lib*.so ../starrygl/lib/ \
&& patchelf --set-rpath '$ORIGIN:$ORIGIN/lib' --force-rpath ../starrygl/lib/*.so
#!/bin/bash
mkdir -p build && cd build
cmake .. \
-DCMAKE_EXPORT_COMPILE_COMMANDS=ON \
-DCMAKE_PREFIX_PATH="/home/zlj/.miniconda3/envs/sgl/lib/python3.10/site-packages" \
-DPython3_ROOT_DIR="/home/zlj/.miniconda3/envs/sgl" \
-DCUDA_TOOLKIT_ROOT_DIR="/home/zlj/.local/cuda-11.8" \
&& make -j32 \
&& rm -rf ../starrygl/lib \
&& mkdir ../starrygl/lib \
&& cp lib*.so ../starrygl/lib/ \
&& patchelf --set-rpath '$ORIGIN:$ORIGIN/lib' --force-rpath ../starrygl/lib/*.so
#!/bin/bash
mkdir -p build && cd build
cmake .. \
-DCMAKE_EXPORT_COMPILE_COMMANDS=ON \
-DCMAKE_PREFIX_PATH="/home/zlj/.miniconda3/envs/sgl/lib/python3.10/site-packages" \
-DPython3_ROOT_DIR="/home/zlj/.miniconda3/envs/sgl" \
-DCUDA_TOOLKIT_ROOT_DIR="/home/zlj/.local/cuda-11.8" \
&& make -j32 \
&& rm -rf ../starrygl/lib \
&& mkdir ../starrygl/lib \
&& cp lib*.so ../starrygl/lib/ \
&& patchelf --set-rpath '$ORIGIN:$ORIGIN/lib' --force-rpath ../starrygl/lib/*.so
--extra-index-url https://download.pytorch.org/whl/cu118
torch==2.1.1+cu118
torchvision==0.16.1+cu118
torchaudio==2.1.1+cu118
--extra-index-url https://data.pyg.org/whl/torch-2.1.0+cu118.html
torch_geometric==2.4.0
pyg_lib==0.3.1+pt21cu118
torch_scatter==2.1.2+pt21cu118
torch_sparse==0.6.18+pt21cu118
torch_cluster==1.6.3+pt21cu118
torch_spline_conv==1.2.2+pt21cu118
ogb
<<<<<<< HEAD
tqdm
networkx
=======
tqdm
>>>>>>> cmy_dev
--extra-index-url https://download.pytorch.org/whl/cu118
torch==2.1.1+cu118
torchvision==0.16.1+cu118
torchaudio==2.1.1+cu118
--extra-index-url https://data.pyg.org/whl/torch-2.1.0+cu118.html
torch_geometric==2.4.0
pyg_lib==0.3.1+pt21cu118
torch_scatter==2.1.2+pt21cu118
torch_sparse==0.6.18+pt21cu118
torch_cluster==1.6.3+pt21cu118
torch_spline_conv==1.2.2+pt21cu118
ogb
tqdm
networkx
import torch
from torch import Tensor
from typing import *
import shutil
from pathlib import Path
from torch_sparse import SparseTensor
from starrygl.utils.partition import *
from starrygl.parallel.route import Route
from starrygl.parallel.sparse import SparseBlocks
from .utils import init_vc_edge_index
import logging
__all__ = [
"GraphData",
]
Strings = Sequence[str]
OptStrings = Optional[Strings]
class GraphData:
def __init__(self,
edge_indices: Union[Tensor, Dict[Tuple[str, str, str], Tensor]],
num_nodes: Union[int, Dict[str, int]],
) -> None:
if isinstance(edge_indices, Tensor):
self._heterogeneous = False
edge_indices = {("#", "@", "#"): edge_indices}
num_nodes = {"#": int(num_nodes)}
else:
self._heterogeneous = True
self._num_nodes: Dict[str, int] = {}
self._node_data: Dict[str, 'NodeData'] = {}
for ntype, num in num_nodes.items():
ntype, num = str(ntype), int(num)
self._num_nodes[ntype] = num
self._node_data[ntype] = NodeData(ntype, num)
self._edge_indices: Dict[Tuple[str, str, str], Tensor] = {}
self._edge_data: Dict[Tuple[str, str, str], 'EdgeData'] = {}
for (es, et, ed), edge_index in edge_indices.items():
assert isinstance(edge_index, Tensor), f"edge_index must be a tensor, got {type(edge_index)}"
assert edge_index.dim() == 2 and edge_index.size(0) == 2
es, et, ed = str(es), str(et), str(ed)
assert es in self._num_nodes, f"unknown node type '{es}', should be one of {list(self._num_nodes.keys())}."
assert ed in self._num_nodes, f"unknown node type '{ed}', should be one of {list(self._num_nodes.keys())}."
etype = (es, et, ed)
self._edge_indices[etype] = edge_index
self._edge_data[etype] = EdgeData(etype, edge_index.size(1))
self._meta = MetaData()
def meta(self) -> 'MetaData':
return self._meta
def node(self, node_type: Optional[str] = None) -> 'NodeData':
if len(self._node_data) == 1:
for data in self._node_data.values():
return data
return self._node_data[node_type]
def edge(self, edge_type: Optional[Tuple[str, str, str]] = None) -> 'EdgeData':
if len(self._edge_data) == 1:
for data in self._edge_data.values():
return data
return self._edge_data[edge_type]
def edge_index(self, edge_type: Optional[Tuple[str, str, str]] = None) -> Tensor:
if len(self._edge_indices) == 1:
for data in self._edge_indices.values():
return data
return self._edge_indices[edge_type]
def node_types(self) -> List[str]:
return list(self._node_data.keys())
def edge_types(self) -> List[Tuple[str, str, str]]:
return list(self._edge_data.keys())
def to_route(self, group: Any = None) -> Route:
src_ids = self.node("src")["raw_ids"]
dst_ids = self.node("dst")["raw_ids"]
return Route.from_raw_indices(src_ids, dst_ids, group=group)
def to_sparse(self, key: Optional[str] = None, group: Any = None) -> SparseBlocks:
src_ids = self.node("src")["raw_ids"]
dst_ids = self.node("dst")["raw_ids"]
edge_index = self.edge_index()
edge_index = torch.vstack([
src_ids[edge_index[0]],
dst_ids[edge_index[1]],
])
edge_attr = None if key is None else self.edge()[key]
return SparseBlocks.from_raw_indices(dst_ids, edge_index, edge_attr=edge_attr, group=group)
@property
def is_heterogeneous(self) -> bool:
return self._heterogeneous
def to(self, device: Any) -> 'GraphData':
self._meta.to(device)
for ndata in self._node_data.values():
ndata.to(device)
for edata in self._edge_data.values():
edata.to(device)
self._edge_indices = {k:v.to(device) for k,v in self._edge_indices.items()}
return self
@staticmethod
def from_bipartite(
edge_index: Tensor,
num_src_nodes: Optional[int] = None,
num_dst_nodes: Optional[int] = None,
raw_src_ids: Optional[Tensor] = None,
raw_dst_ids: Optional[Tensor] = None,
) -> 'GraphData':
if num_src_nodes is None:
num_src_nodes = raw_src_ids.numel()
if num_dst_nodes is None:
num_dst_nodes = raw_dst_ids.numel()
g = GraphData(
edge_indices={
("src", "@", "dst"): edge_index,
},
num_nodes={
"src": num_src_nodes,
"dst": num_dst_nodes,
}
)
if raw_src_ids is not None:
g.node("src")["raw_ids"] = raw_src_ids
if raw_dst_ids is not None:
g.node("dst")["raw_ids"] = raw_dst_ids
return g
@staticmethod
def from_pyg_data(data) -> 'GraphData':
from torch_geometric.data import Data
assert isinstance(data, Data), f"must be Data class in pyg"
g = GraphData(data.edge_index, data.num_nodes)
for key, val in data:
if key == "edge_index":
continue
elif isinstance(val, Tensor):
if val.size(0) == data.num_nodes:
g.node()[key] = val
elif val.size(0) == data.num_edges:
g.edge()[key] = val
elif isinstance(val, SparseTensor):
logging.warning(f"found sparse matrix {key}, but ignored.")
else:
g.meta()[key] = val
return g
@staticmethod
def load_partition(
root: str,
part_id: int,
num_parts: int,
algorithm: str = "metis",
) -> 'GraphData':
p = Path(root).expanduser().resolve() / f"{algorithm}_{num_parts}" / f"{part_id:03d}"
return torch.load(p.__str__())
def save_partition(self,
root: str,
num_parts: int,
node_weight: Optional[str] = None,
edge_weight: Optional[str] = None,
include_node_attrs: Optional[Sequence[str]] = None,
include_edge_attrs: Optional[Sequence[str]] = None,
include_meta_attrs: Optional[Sequence[str]] = None,
ignore_node_attrs: Optional[Sequence[str]] = None,
ignore_edge_attrs: Optional[Sequence[str]] = None,
ignore_meta_attrs: Optional[Sequence[str]] = None,
algorithm: str = "metis",
partition_kwargs = None,
):
assert not self.is_heterogeneous, "only support homomorphic graph"
num_nodes: int = self.node().num_nodes
edge_index: Tensor = self.edge_index()
logging.info(f"running partition aglorithm: {algorithm}")
partition_kwargs = partition_kwargs or {}
<<<<<<< HEAD
not_self_loop = (edge_index[0] != edge_index[1])
if node_weight is not None:
node_weight = self.node()[node_weight]
if edge_weight is not None:
edge_weight = self.edge()[edge_weight]
edge_weight = edge_weight[not_self_loop]
if algorithm == "metis":
node_parts = metis_partition(
edge_index[:,not_self_loop],
=======
if node_weight is not None:
node_weight = self.node()[node_weight]
if edge_weight is not None:
edge_weight = self.edge()[edge_weight]
if algorithm == "metis":
node_parts = metis_partition(
edge_index,
>>>>>>> cmy_dev
num_nodes, num_parts,
node_weight=node_weight,
edge_weight=edge_weight,
**partition_kwargs,
)
elif algorithm == "mt-metis":
node_parts = mt_metis_partition(
<<<<<<< HEAD
edge_index[:,not_self_loop],
=======
edge_index,
>>>>>>> cmy_dev
num_nodes, num_parts,
node_weight=node_weight,
edge_weight=edge_weight,
**partition_kwargs,
)
elif algorithm == "random":
node_parts = random_partition(
<<<<<<< HEAD
edge_index[:,not_self_loop],
num_nodes, num_parts,
**partition_kwargs,
)
elif algorithm == "pyg-metis":
node_parts = pyg_metis_partition(
edge_index[:,not_self_loop],
num_nodes, num_parts,
)
=======
edge_index,
num_nodes, num_parts,
**partition_kwargs,
)
>>>>>>> cmy_dev
else:
raise ValueError(f"unknown partition algorithm: {algorithm}")
root_path = Path(root).expanduser().resolve()
base_path = root_path / f"{algorithm}_{num_parts}"
if base_path.exists():
logging.warning(f"directory '{base_path.__str__()}' exists, and will be removed.")
shutil.rmtree(base_path.__str__())
base_path.mkdir(parents=True)
if include_node_attrs is None:
include_node_attrs = self.node().keys()
if include_edge_attrs is None:
include_edge_attrs = self.edge().keys()
if include_meta_attrs is None:
include_meta_attrs = self.meta().keys()
if ignore_node_attrs is None:
ignore_node_attrs = set()
else:
ignore_node_attrs = set(ignore_node_attrs)
if ignore_edge_attrs is None:
ignore_edge_attrs = set()
else:
ignore_edge_attrs = set(ignore_edge_attrs)
if ignore_meta_attrs is None:
ignore_meta_attrs = set()
else:
ignore_meta_attrs = set(ignore_meta_attrs)
for i in range(num_parts):
npart_mask = node_parts == i
epart_mask = npart_mask[edge_index[1]]
raw_dst_ids: Tensor = torch.where(npart_mask)[0]
local_edges = edge_index[:, epart_mask]
raw_src_ids, local_edges = init_vc_edge_index(
raw_dst_ids, local_edges, bipartite=True,
)
g = GraphData.from_bipartite(
local_edges,
raw_src_ids=raw_src_ids,
raw_dst_ids=raw_dst_ids,
)
for key in include_node_attrs:
if key in ignore_node_attrs:
continue
g.node("dst")[key] = self.node()[key][npart_mask]
for key in include_edge_attrs:
if key in ignore_edge_attrs:
continue
g.edge()[key] = self.edge()[key][epart_mask]
for key in include_meta_attrs:
if key in ignore_meta_attrs:
continue
g.meta()[key] = self.meta()[key]
logging.info(f"saving partition data: {i+1}/{num_parts}")
torch.save(g, (base_path / f"{i:03d}").__str__())
class MetaData:
def __init__(self) -> None:
self._data: Dict[str, Any] = {}
def keys(self) -> List[str]:
return list(self._data.keys())
def __getitem__(self, key: str) -> Any:
return self._data[key]
def __setitem__(self, key: str, val: Any):
assert isinstance(key, str)
self._data[key] = val
def pop(self, key: str) -> Tensor:
if key in self._data:
return self._data.pop(key)
def to(self, device: Any) -> 'MetaData':
for k in self.keys():
v = self._data[k]
if isinstance(v, Tensor):
self._data[k] = v.to(device)
return self
class NodeData:
def __init__(self,
node_type: str,
num_nodes: int,
) -> None:
self._node_type = str(node_type)
self._num_nodes = int(num_nodes)
self._data: Dict[str, Tensor] = {}
@property
def node_type(self) -> str:
return self._node_type
@property
def num_nodes(self) -> int:
return self._num_nodes
def keys(self) -> List[str]:
return list(self._data.keys())
def __getitem__(self, key: str) -> Tensor:
return self._data[key]
def __setitem__(self, key: str, val: Tensor):
assert isinstance(key, str)
assert val.size(0) == self._num_nodes
self._data[key] = val
def pop(self, key: str) -> Tensor:
if key in self._data:
return self._data.pop(key)
def to(self, device: Any) -> 'NodeData':
self._data = {k:v.to(device) for k,v in self._data.items()}
return self
class EdgeData:
def __init__(self,
edge_type: Tuple[str, str, str],
num_edges: int,
) -> None:
self._edge_type = tuple(str(t) for t in edge_type)
self._num_edges = num_edges
assert len(self._edge_type) == 3
self._data: Dict[str, Tensor] = {}
@property
def edge_type(self) -> Tuple[str, str, str]:
return self._edge_type
@property
def num_edges(self) -> int:
return self._num_edges
def keys(self) -> List[str]:
return list(self._data.keys())
def __getitem__(self, key: str) -> Tensor:
return self._data[key]
def __setitem__(self, key: str, val: Optional[Tensor]) -> Tensor:
assert isinstance(key, str)
assert val.size(0) == self._num_edges
self._data[key] = val
def pop(self, key: str) -> Tensor:
if key in self._data:
return self._data.pop(key)
def to(self, device: Any) -> 'EdgeData':
self._data = {k:v.to(device) for k,v in self._data.items()}
return self
import torch
from torch import Tensor
from typing import *
import shutil
from pathlib import Path
from torch_sparse import SparseTensor
from starrygl.utils.partition import *
from starrygl.parallel.route import Route
from starrygl.parallel.sparse import SparseBlocks
from .utils import init_vc_edge_index
import logging
__all__ = [
"GraphData",
]
Strings = Sequence[str]
OptStrings = Optional[Strings]
class GraphData:
def __init__(self,
edge_indices: Union[Tensor, Dict[Tuple[str, str, str], Tensor]],
num_nodes: Union[int, Dict[str, int]],
) -> None:
if isinstance(edge_indices, Tensor):
self._heterogeneous = False
edge_indices = {("#", "@", "#"): edge_indices}
num_nodes = {"#": int(num_nodes)}
else:
self._heterogeneous = True
self._num_nodes: Dict[str, int] = {}
self._node_data: Dict[str, 'NodeData'] = {}
for ntype, num in num_nodes.items():
ntype, num = str(ntype), int(num)
self._num_nodes[ntype] = num
self._node_data[ntype] = NodeData(ntype, num)
self._edge_indices: Dict[Tuple[str, str, str], Tensor] = {}
self._edge_data: Dict[Tuple[str, str, str], 'EdgeData'] = {}
for (es, et, ed), edge_index in edge_indices.items():
assert isinstance(edge_index, Tensor), f"edge_index must be a tensor, got {type(edge_index)}"
assert edge_index.dim() == 2 and edge_index.size(0) == 2
es, et, ed = str(es), str(et), str(ed)
assert es in self._num_nodes, f"unknown node type '{es}', should be one of {list(self._num_nodes.keys())}."
assert ed in self._num_nodes, f"unknown node type '{ed}', should be one of {list(self._num_nodes.keys())}."
etype = (es, et, ed)
self._edge_indices[etype] = edge_index
self._edge_data[etype] = EdgeData(etype, edge_index.size(1))
self._meta = MetaData()
def meta(self) -> 'MetaData':
return self._meta
def node(self, node_type: Optional[str] = None) -> 'NodeData':
if len(self._node_data) == 1:
for data in self._node_data.values():
return data
return self._node_data[node_type]
def edge(self, edge_type: Optional[Tuple[str, str, str]] = None) -> 'EdgeData':
if len(self._edge_data) == 1:
for data in self._edge_data.values():
return data
return self._edge_data[edge_type]
def edge_index(self, edge_type: Optional[Tuple[str, str, str]] = None) -> Tensor:
if len(self._edge_indices) == 1:
for data in self._edge_indices.values():
return data
return self._edge_indices[edge_type]
def node_types(self) -> List[str]:
return list(self._node_data.keys())
def edge_types(self) -> List[Tuple[str, str, str]]:
return list(self._edge_data.keys())
def to_route(self, group: Any = None) -> Route:
src_ids = self.node("src")["raw_ids"]
dst_ids = self.node("dst")["raw_ids"]
return Route.from_raw_indices(src_ids, dst_ids, group=group)
def to_sparse(self, key: Optional[str] = None, group: Any = None) -> SparseBlocks:
src_ids = self.node("src")["raw_ids"]
dst_ids = self.node("dst")["raw_ids"]
edge_index = self.edge_index()
edge_index = torch.vstack([
src_ids[edge_index[0]],
dst_ids[edge_index[1]],
])
edge_attr = None if key is None else self.edge()[key]
return SparseBlocks.from_raw_indices(dst_ids, edge_index, edge_attr=edge_attr, group=group)
@property
def is_heterogeneous(self) -> bool:
return self._heterogeneous
def to(self, device: Any) -> 'GraphData':
self._meta.to(device)
for ndata in self._node_data.values():
ndata.to(device)
for edata in self._edge_data.values():
edata.to(device)
self._edge_indices = {k:v.to(device) for k,v in self._edge_indices.items()}
return self
@staticmethod
def from_bipartite(
edge_index: Tensor,
num_src_nodes: Optional[int] = None,
num_dst_nodes: Optional[int] = None,
raw_src_ids: Optional[Tensor] = None,
raw_dst_ids: Optional[Tensor] = None,
) -> 'GraphData':
if num_src_nodes is None:
num_src_nodes = raw_src_ids.numel()
if num_dst_nodes is None:
num_dst_nodes = raw_dst_ids.numel()
g = GraphData(
edge_indices={
("src", "@", "dst"): edge_index,
},
num_nodes={
"src": num_src_nodes,
"dst": num_dst_nodes,
}
)
if raw_src_ids is not None:
g.node("src")["raw_ids"] = raw_src_ids
if raw_dst_ids is not None:
g.node("dst")["raw_ids"] = raw_dst_ids
return g
@staticmethod
def from_pyg_data(data) -> 'GraphData':
from torch_geometric.data import Data
assert isinstance(data, Data), f"must be Data class in pyg"
g = GraphData(data.edge_index, data.num_nodes)
for key, val in data:
if key == "edge_index":
continue
elif isinstance(val, Tensor):
if val.size(0) == data.num_nodes:
g.node()[key] = val
elif val.size(0) == data.num_edges:
g.edge()[key] = val
elif isinstance(val, SparseTensor):
logging.warning(f"found sparse matrix {key}, but ignored.")
else:
g.meta()[key] = val
return g
@staticmethod
def load_partition(
root: str,
part_id: int,
num_parts: int,
algorithm: str = "metis",
) -> 'GraphData':
p = Path(root).expanduser().resolve() / f"{algorithm}_{num_parts}" / f"{part_id:03d}"
return torch.load(p.__str__())
def save_partition(self,
root: str,
num_parts: int,
node_weight: Optional[str] = None,
edge_weight: Optional[str] = None,
include_node_attrs: Optional[Sequence[str]] = None,
include_edge_attrs: Optional[Sequence[str]] = None,
include_meta_attrs: Optional[Sequence[str]] = None,
ignore_node_attrs: Optional[Sequence[str]] = None,
ignore_edge_attrs: Optional[Sequence[str]] = None,
ignore_meta_attrs: Optional[Sequence[str]] = None,
algorithm: str = "metis",
partition_kwargs = None,
):
assert not self.is_heterogeneous, "only support homomorphic graph"
num_nodes: int = self.node().num_nodes
edge_index: Tensor = self.edge_index()
logging.info(f"running partition aglorithm: {algorithm}")
partition_kwargs = partition_kwargs or {}
not_self_loop = (edge_index[0] != edge_index[1])
if node_weight is not None:
node_weight = self.node()[node_weight]
if edge_weight is not None:
edge_weight = self.edge()[edge_weight]
edge_weight = edge_weight[not_self_loop]
if algorithm == "metis":
node_parts = metis_partition(
edge_index[:,not_self_loop],
num_nodes, num_parts,
node_weight=node_weight,
edge_weight=edge_weight,
**partition_kwargs,
)
elif algorithm == "mt-metis":
node_parts = mt_metis_partition(
edge_index[:,not_self_loop],
num_nodes, num_parts,
node_weight=node_weight,
edge_weight=edge_weight,
**partition_kwargs,
)
elif algorithm == "random":
node_parts = random_partition(
edge_index[:,not_self_loop],
num_nodes, num_parts,
**partition_kwargs,
)
elif algorithm == "pyg-metis":
node_parts = pyg_metis_partition(
edge_index[:,not_self_loop],
num_nodes, num_parts,
)
else:
raise ValueError(f"unknown partition algorithm: {algorithm}")
root_path = Path(root).expanduser().resolve()
base_path = root_path / f"{algorithm}_{num_parts}"
if base_path.exists():
logging.warning(f"directory '{base_path.__str__()}' exists, and will be removed.")
shutil.rmtree(base_path.__str__())
base_path.mkdir(parents=True)
if include_node_attrs is None:
include_node_attrs = self.node().keys()
if include_edge_attrs is None:
include_edge_attrs = self.edge().keys()
if include_meta_attrs is None:
include_meta_attrs = self.meta().keys()
if ignore_node_attrs is None:
ignore_node_attrs = set()
else:
ignore_node_attrs = set(ignore_node_attrs)
if ignore_edge_attrs is None:
ignore_edge_attrs = set()
else:
ignore_edge_attrs = set(ignore_edge_attrs)
if ignore_meta_attrs is None:
ignore_meta_attrs = set()
else:
ignore_meta_attrs = set(ignore_meta_attrs)
for i in range(num_parts):
npart_mask = node_parts == i
epart_mask = npart_mask[edge_index[1]]
raw_dst_ids: Tensor = torch.where(npart_mask)[0]
local_edges = edge_index[:, epart_mask]
raw_src_ids, local_edges = init_vc_edge_index(
raw_dst_ids, local_edges, bipartite=True,
)
g = GraphData.from_bipartite(
local_edges,
raw_src_ids=raw_src_ids,
raw_dst_ids=raw_dst_ids,
)
for key in include_node_attrs:
if key in ignore_node_attrs:
continue
g.node("dst")[key] = self.node()[key][npart_mask]
for key in include_edge_attrs:
if key in ignore_edge_attrs:
continue
g.edge()[key] = self.edge()[key][epart_mask]
for key in include_meta_attrs:
if key in ignore_meta_attrs:
continue
g.meta()[key] = self.meta()[key]
logging.info(f"saving partition data: {i+1}/{num_parts}")
torch.save(g, (base_path / f"{i:03d}").__str__())
class MetaData:
def __init__(self) -> None:
self._data: Dict[str, Any] = {}
def keys(self) -> List[str]:
return list(self._data.keys())
def __getitem__(self, key: str) -> Any:
return self._data[key]
def __setitem__(self, key: str, val: Any):
assert isinstance(key, str)
self._data[key] = val
def pop(self, key: str) -> Tensor:
if key in self._data:
return self._data.pop(key)
def to(self, device: Any) -> 'MetaData':
for k in self.keys():
v = self._data[k]
if isinstance(v, Tensor):
self._data[k] = v.to(device)
return self
class NodeData:
def __init__(self,
node_type: str,
num_nodes: int,
) -> None:
self._node_type = str(node_type)
self._num_nodes = int(num_nodes)
self._data: Dict[str, Tensor] = {}
@property
def node_type(self) -> str:
return self._node_type
@property
def num_nodes(self) -> int:
return self._num_nodes
def keys(self) -> List[str]:
return list(self._data.keys())
def __getitem__(self, key: str) -> Tensor:
return self._data[key]
def __setitem__(self, key: str, val: Tensor):
assert isinstance(key, str)
assert val.size(0) == self._num_nodes
self._data[key] = val
def pop(self, key: str) -> Tensor:
if key in self._data:
return self._data.pop(key)
def to(self, device: Any) -> 'NodeData':
self._data = {k:v.to(device) for k,v in self._data.items()}
return self
class EdgeData:
def __init__(self,
edge_type: Tuple[str, str, str],
num_edges: int,
) -> None:
self._edge_type = tuple(str(t) for t in edge_type)
self._num_edges = num_edges
assert len(self._edge_type) == 3
self._data: Dict[str, Tensor] = {}
@property
def edge_type(self) -> Tuple[str, str, str]:
return self._edge_type
@property
def num_edges(self) -> int:
return self._num_edges
def keys(self) -> List[str]:
return list(self._data.keys())
def __getitem__(self, key: str) -> Tensor:
return self._data[key]
def __setitem__(self, key: str, val: Optional[Tensor]) -> Tensor:
assert isinstance(key, str)
assert val.size(0) == self._num_edges
self._data[key] = val
def pop(self, key: str) -> Tensor:
if key in self._data:
return self._data.pop(key)
def to(self, device: Any) -> 'EdgeData':
self._data = {k:v.to(device) for k,v in self._data.items()}
return self
import torch
import torch.distributed as dist
import torch.distributed.rpc as rpc
from torch import Tensor
from torch.types import Number
from typing import *
from torch_sparse import SparseTensor
from .cclib import all_to_all_s
class TensorAccessor:
def __init__(self, data: Tensor) -> None:
from .context import DistributedContext
self._data = data
self._ctx = DistributedContext.get_default_context()
if self._ctx._use_rpc is True:
self._rref = rpc.RRef(data)
<<<<<<< HEAD
else:
self._rref = None
=======
>>>>>>> cmy_dev
self.stream = torch.cuda.Stream()
@property
def data(self):
return self._data
@property
def rref(self):
return self._rref
@property
def ctx(self):
return self._ctx
@staticmethod
@rpc.functions.async_execution
def _index_selet(self):
fut = torch.futures.Future()
fut.set_result(None)
return fut
def all_gather_rrefs(self) -> List[rpc.RRef]:
return self.ctx.all_gather_remote_objects(self.rref)
def async_index_select(self, dim: int, index: Tensor, rref: Optional[rpc.RRef] = None):
if rref is None:
rref = self.rref
return self.ctx.remote_exec(TensorAccessor._index_select, rref, dim=dim, index=index)
def async_index_copy_(self, dim: int, index: Tensor, source: Tensor, rref: Optional[rpc.RRef] = None):
if rref is None:
rref = self.rref
return self.ctx.remote_exec(TensorAccessor._index_copy_, rref, dim=dim, index=index, source=source)
def async_index_add_(self, dim: int, index: Tensor, source: Tensor, rref: Optional[rpc.RRef] = None):
if rref is None:
rref = self.rref
return self.ctx.remote_exec(TensorAccessor._index_add_, rref, dim=dim, index=index, source=source)
@staticmethod
def _index_select(data: Tensor, dim: int, index: Tensor):
stream = TensorAccessor.get_stream()
with torch.cuda.stream(stream):
data = data.index_select(dim, index)
fut = torch.futures.Future()
fut.set_result(data)
return fut
@staticmethod
def _index_copy_(data: Tensor, dim: int, index: Tensor, source: Tensor):
stream = TensorAccessor.get_stream()
with torch.cuda.stream(stream):
data.index_copy_(dim, index, source)
fut = torch.futures.Future()
fut.set_result(None)
return fut
@staticmethod
def _index_add_(data: Tensor, dim: int, index: Tensor, source: Tensor):
stream = TensorAccessor.get_stream()
with torch.cuda.stream(stream):
data.index_add_(dim, index, source)
fut = torch.futures.Future()
fut.set_result(None)
return fut
@staticmethod
def get_stream() -> Optional[torch.cuda.Stream]:
global _TENSOR_ACCESSOR_STREAM
if torch.cuda.is_available():
return None
if _TENSOR_ACCESSOR_STREAM is None:
_TENSOR_ACCESSOR_STREAM = torch.cuda.Stream()
return _TENSOR_ACCESSOR_STREAM
_TENSOR_ACCESSOR_STREAM: Optional[torch.cuda.Stream] = None
class DistInt:
def __init__(self, sizes: List[int]) -> None:
self._data = tuple([int(t) for t in sizes])
self._total = sum(self._data)
def __getitem__(self, idx: int) -> int:
return self._data[idx]
def __call__(self) -> int:
return self._total
class DistIndex:
def __init__(self, index: Tensor, part_ids: Optional[Tensor] = None) -> None:
if part_ids is None:
self._data = index.long()
else:
index, part_ids = index.long(), part_ids.long()
self._data = (index & 0xFFFFFFFFFFFF) | ((part_ids & 0xFFFF) << 48)
@property
def loc(self) -> Tensor:
return self._data & 0xFFFFFFFFFFFF
@property
def part(self) -> Tensor:
return (self._data >> 48) & 0xFFFF
@property
def dist(self) -> Tensor:
return self._data
@property
def dtype(self):
return self._data.dtype
@property
def device(self):
return self._data.device
def to(self,device) -> Tensor:
return DistIndex(self._data.to(device))
class DistributedTensor:
def __init__(self, data: Tensor) -> None:
self.accessor = TensorAccessor(data)
<<<<<<< HEAD
if self.accessor.rref is not None:
=======
if hasattr(self,'_rref'):
>>>>>>> cmy_dev
self.rrefs = self.accessor.all_gather_rrefs()
local_sizes = []
for rref in self.rrefs:
n = self.ctx.remote_call(Tensor.size, rref, dim=0).wait()
local_sizes.append(n)
self._num_nodes: int = sum(local_sizes)
self._num_part_nodes: Tuple[int,...] = tuple(int(s) for s in local_sizes)
<<<<<<< HEAD
else:
self.rrefs = None
self._num_nodes: int = dist.get_world_size()
self._num_part_nodes:List = [torch.tensor(data.size(0),device = data.device) for _ in range(self._num_nodes)]
dist.all_gather(self._num_part_nodes,torch.tensor(data.size(0),device = data.device))
self._num_nodes = sum(self._num_part_nodes)
=======
>>>>>>> cmy_dev
self._part_id: int = self.accessor.ctx.rank
self._num_parts: int = self.accessor.ctx.world_size
@property
def shape(self):
return self.accessor.data.shape
@property
def dtype(self):
return self.accessor.data.dtype
@property
def device(self):
return self.accessor.data.device
@property
def num_nodes(self) -> int:
return self._num_nodes
@property
<<<<<<< HEAD
def num_part_nodes(self):# -> tuple[int,...]:
=======
def num_part_nodes(self) -> Tuple[int,...]:
>>>>>>> cmy_dev
return self._num_part_nodes
@property
def part_id(self) -> int:
return self._part_id
@property
def num_parts(self) -> int:
return self._num_parts
def to(self,device):
return self.accessor.data.to(device)
def __getitem__(self,index):
return self.accessor.data[index]
@property
def ctx(self):
return self.accessor.ctx
def all_to_all_ind2ptr(self, dist_index: Union[Tensor, DistIndex],group = None) -> Dict[str, Union[List[int], Tensor]]:
if isinstance(dist_index, Tensor):
dist_index = DistIndex(dist_index)
send_ptr = torch.ops.torch_sparse.ind2ptr(dist_index.part, self.num_parts)
send_sizes = send_ptr[1:] - send_ptr[:-1]
recv_sizes = torch.empty_like(send_sizes)
dist.all_to_all_single(recv_sizes, send_sizes)
recv_ptr = torch.zeros(recv_sizes.numel() + 1).type_as(recv_sizes)
recv_ptr[1:] = recv_sizes.cumsum(dim=0)
send_ptr = send_ptr.tolist()
recv_ptr = recv_ptr.tolist()
recv_ind = torch.full((recv_ptr[-1],), (2**62-1)*2+1, dtype=dist_index.dtype, device=self.device)
all_to_all_s(recv_ind, dist_index.loc, recv_ptr, send_ptr,group=group)
return {
"send_ptr": send_ptr,
"recv_ptr": recv_ptr,
"recv_ind": recv_ind,
}
def all_to_all_get(self,
dist_index: Union[Tensor, DistIndex, None] = None,
send_ptr: Optional[List[int]] = None,
recv_ptr: Optional[List[int]] = None,
recv_ind: Optional[List[int]] = None,
group = None
) -> Tensor:
if dist_index is not None:
dist_dict = self.all_to_all_ind2ptr(dist_index)
send_ptr = dist_dict["send_ptr"]
recv_ptr = dist_dict["recv_ptr"]
recv_ind = dist_dict["recv_ind"]
data = self.accessor.data[recv_ind]
recv = torch.empty(send_ptr[-1], *data.shape[1:], dtype=data.dtype, device=self.device)
all_to_all_s(recv, data, send_ptr, recv_ptr,group=group)
return recv
def all_to_all_set(self,
data: Tensor,
dist_index: Union[Tensor, DistIndex, None] = None,
send_ptr: Optional[List[int]] = None,
recv_ptr: Optional[List[int]] = None,
recv_ind: Optional[List[int]] = None,
group = None
):
if dist_index is not None:
dist_dict = self.all_to_all_ind2ptr(dist_index)
send_ptr = dist_dict["send_ptr"]
recv_ptr = dist_dict["recv_ptr"]
recv_ind = dist_dict["recv_ind"]
recv = torch.empty(recv_ptr[-1], *data.shape[1:], dtype=data.dtype, device=data.device)
all_to_all_s(recv, data, recv_ptr, send_ptr,group=group)
self.accessor.data.index_copy_(0, recv_ind, recv)
def index_select(self, dist_index: Union[Tensor, DistIndex]):
if isinstance(dist_index, Tensor):
dist_index = DistIndex(dist_index)
part_idx = dist_index.part
index = dist_index.loc
futs: List[torch.futures.Future] = []
for i in range(self.num_parts):
f = self.accessor.async_index_select(0, index[part_idx == i], self.rrefs[i])
futs.append(f)
def callback(fs: torch.futures.Future[List[torch.futures.Future]]) -> Tensor:
result: Optional[Tensor] = None
for i, f in enumerate(fs.value()):
t: Tensor = f.value()
if result is None:
result = torch.empty(
part_idx.size(0), *t.shape[1:], dtype=t.dtype, device=t.device,
)
result[part_idx == i] = t
return result
return torch.futures.collect_all(futs).then(callback)
def index_copy_(self, dist_index: Union[Tensor, DistIndex], source: Tensor):
if isinstance(dist_index, Tensor):
dist_index = DistIndex(dist_index)
part_idx = dist_index.part
index = dist_index.loc
futs: List[torch.futures.Future] = []
for i in range(self.num_parts):
mask = part_idx == i
f = self.accessor.async_index_copy_(0, index[mask], source[mask], self.rrefs[i])
futs.append(f)
return torch.futures.collect_all(futs)
def index_add_(self, dist_index: Union[Tensor, DistIndex], source: Tensor):
if isinstance(dist_index, Tensor):
dist_index = DistIndex(dist_index)
part_idx = dist_index.part
index = dist_index.loc
futs: List[torch.futures.Future] = []
for i in range(self.num_parts):
mask = part_idx == i
f = self.accessor.async_index_add_(0, index[mask], source[mask], self.rrefs[i])
futs.append(f)
return torch.futures.collect_all(futs)
\ No newline at end of file
import torch
import torch.distributed as dist
import torch.distributed.rpc as rpc
from torch import Tensor
from torch.types import Number
from typing import *
from torch_sparse import SparseTensor
from .cclib import all_to_all_s
class TensorAccessor:
def __init__(self, data: Tensor) -> None:
from .context import DistributedContext
self._data = data
self._ctx = DistributedContext.get_default_context()
if self._ctx._use_rpc is True:
self._rref = rpc.RRef(data)
else:
self._rref = None
self.stream = torch.cuda.Stream()
@property
def data(self):
return self._data
@property
def rref(self):
return self._rref
@property
def ctx(self):
return self._ctx
@staticmethod
@rpc.functions.async_execution
def _index_selet(self):
fut = torch.futures.Future()
fut.set_result(None)
return fut
def all_gather_rrefs(self) -> List[rpc.RRef]:
return self.ctx.all_gather_remote_objects(self.rref)
def async_index_select(self, dim: int, index: Tensor, rref: Optional[rpc.RRef] = None):
if rref is None:
rref = self.rref
return self.ctx.remote_exec(TensorAccessor._index_select, rref, dim=dim, index=index)
def async_index_copy_(self, dim: int, index: Tensor, source: Tensor, rref: Optional[rpc.RRef] = None):
if rref is None:
rref = self.rref
return self.ctx.remote_exec(TensorAccessor._index_copy_, rref, dim=dim, index=index, source=source)
def async_index_add_(self, dim: int, index: Tensor, source: Tensor, rref: Optional[rpc.RRef] = None):
if rref is None:
rref = self.rref
return self.ctx.remote_exec(TensorAccessor._index_add_, rref, dim=dim, index=index, source=source)
@staticmethod
def _index_select(data: Tensor, dim: int, index: Tensor):
stream = TensorAccessor.get_stream()
with torch.cuda.stream(stream):
data = data.index_select(dim, index)
fut = torch.futures.Future()
fut.set_result(data)
return fut
@staticmethod
def _index_copy_(data: Tensor, dim: int, index: Tensor, source: Tensor):
stream = TensorAccessor.get_stream()
with torch.cuda.stream(stream):
data.index_copy_(dim, index, source)
fut = torch.futures.Future()
fut.set_result(None)
return fut
@staticmethod
def _index_add_(data: Tensor, dim: int, index: Tensor, source: Tensor):
stream = TensorAccessor.get_stream()
with torch.cuda.stream(stream):
data.index_add_(dim, index, source)
fut = torch.futures.Future()
fut.set_result(None)
return fut
@staticmethod
def get_stream() -> Optional[torch.cuda.Stream]:
global _TENSOR_ACCESSOR_STREAM
if torch.cuda.is_available():
return None
if _TENSOR_ACCESSOR_STREAM is None:
_TENSOR_ACCESSOR_STREAM = torch.cuda.Stream()
return _TENSOR_ACCESSOR_STREAM
_TENSOR_ACCESSOR_STREAM: Optional[torch.cuda.Stream] = None
class DistInt:
def __init__(self, sizes: List[int]) -> None:
self._data = tuple([int(t) for t in sizes])
self._total = sum(self._data)
def __getitem__(self, idx: int) -> int:
return self._data[idx]
def __call__(self) -> int:
return self._total
class DistIndex:
def __init__(self, index: Tensor, part_ids: Optional[Tensor] = None) -> None:
if part_ids is None:
self._data = index.long()
else:
index, part_ids = index.long(), part_ids.long()
self._data = (index & 0xFFFFFFFFFFFF) | ((part_ids & 0xFFFF) << 48)
@property
def loc(self) -> Tensor:
return self._data & 0xFFFFFFFFFFFF
@property
def part(self) -> Tensor:
return (self._data >> 48) & 0xFFFF
@property
def dist(self) -> Tensor:
return self._data
@property
def dtype(self):
return self._data.dtype
@property
def device(self):
return self._data.device
def to(self,device) -> Tensor:
return DistIndex(self._data.to(device))
class DistributedTensor:
def __init__(self, data: Tensor) -> None:
self.accessor = TensorAccessor(data)
if self.accessor.rref is not None:
self.rrefs = self.accessor.all_gather_rrefs()
local_sizes = []
for rref in self.rrefs:
n = self.ctx.remote_call(Tensor.size, rref, dim=0).wait()
local_sizes.append(n)
self._num_nodes: int = sum(local_sizes)
self._num_part_nodes: Tuple[int,...] = tuple(int(s) for s in local_sizes)
else:
self.rrefs = None
self._num_nodes: int = dist.get_world_size()
self._num_part_nodes:List = [torch.tensor(data.size(0),device = data.device) for _ in range(self._num_nodes)]
dist.all_gather(self._num_part_nodes,torch.tensor(data.size(0),device = data.device))
self._num_nodes = sum(self._num_part_nodes)
self._part_id: int = self.accessor.ctx.rank
self._num_parts: int = self.accessor.ctx.world_size
@property
def shape(self):
return self.accessor.data.shape
@property
def dtype(self):
return self.accessor.data.dtype
@property
def device(self):
return self.accessor.data.device
@property
def num_nodes(self) -> int:
return self._num_nodes
@property
def num_part_nodes(self):# -> tuple[int,...]:
return self._num_part_nodes
@property
def part_id(self) -> int:
return self._part_id
@property
def num_parts(self) -> int:
return self._num_parts
def to(self,device):
return self.accessor.data.to(device)
def __getitem__(self,index):
return self.accessor.data[index]
@property
def ctx(self):
return self.accessor.ctx
def all_to_all_ind2ptr(self, dist_index: Union[Tensor, DistIndex],group = None) -> Dict[str, Union[List[int], Tensor]]:
if isinstance(dist_index, Tensor):
dist_index = DistIndex(dist_index)
send_ptr = torch.ops.torch_sparse.ind2ptr(dist_index.part, self.num_parts)
send_sizes = send_ptr[1:] - send_ptr[:-1]
recv_sizes = torch.empty_like(send_sizes)
dist.all_to_all_single(recv_sizes, send_sizes)
recv_ptr = torch.zeros(recv_sizes.numel() + 1).type_as(recv_sizes)
recv_ptr[1:] = recv_sizes.cumsum(dim=0)
send_ptr = send_ptr.tolist()
recv_ptr = recv_ptr.tolist()
recv_ind = torch.full((recv_ptr[-1],), (2**62-1)*2+1, dtype=dist_index.dtype, device=self.device)
all_to_all_s(recv_ind, dist_index.loc, recv_ptr, send_ptr,group=group)
return {
"send_ptr": send_ptr,
"recv_ptr": recv_ptr,
"recv_ind": recv_ind,
}
def all_to_all_get(self,
dist_index: Union[Tensor, DistIndex, None] = None,
send_ptr: Optional[List[int]] = None,
recv_ptr: Optional[List[int]] = None,
recv_ind: Optional[List[int]] = None,
group = None
) -> Tensor:
if dist_index is not None:
dist_dict = self.all_to_all_ind2ptr(dist_index)
send_ptr = dist_dict["send_ptr"]
recv_ptr = dist_dict["recv_ptr"]
recv_ind = dist_dict["recv_ind"]
data = self.accessor.data[recv_ind]
recv = torch.empty(send_ptr[-1], *data.shape[1:], dtype=data.dtype, device=self.device)
all_to_all_s(recv, data, send_ptr, recv_ptr,group=group)
return recv
def all_to_all_set(self,
data: Tensor,
dist_index: Union[Tensor, DistIndex, None] = None,
send_ptr: Optional[List[int]] = None,
recv_ptr: Optional[List[int]] = None,
recv_ind: Optional[List[int]] = None,
group = None
):
if dist_index is not None:
dist_dict = self.all_to_all_ind2ptr(dist_index)
send_ptr = dist_dict["send_ptr"]
recv_ptr = dist_dict["recv_ptr"]
recv_ind = dist_dict["recv_ind"]
recv = torch.empty(recv_ptr[-1], *data.shape[1:], dtype=data.dtype, device=data.device)
all_to_all_s(recv, data, recv_ptr, send_ptr,group=group)
self.accessor.data.index_copy_(0, recv_ind, recv)
def index_select(self, dist_index: Union[Tensor, DistIndex]):
if isinstance(dist_index, Tensor):
dist_index = DistIndex(dist_index)
part_idx = dist_index.part
index = dist_index.loc
futs: List[torch.futures.Future] = []
for i in range(self.num_parts):
f = self.accessor.async_index_select(0, index[part_idx == i], self.rrefs[i])
futs.append(f)
def callback(fs: torch.futures.Future[List[torch.futures.Future]]) -> Tensor:
result: Optional[Tensor] = None
for i, f in enumerate(fs.value()):
t: Tensor = f.value()
if result is None:
result = torch.empty(
part_idx.size(0), *t.shape[1:], dtype=t.dtype, device=t.device,
)
result[part_idx == i] = t
return result
return torch.futures.collect_all(futs).then(callback)
def index_copy_(self, dist_index: Union[Tensor, DistIndex], source: Tensor):
if isinstance(dist_index, Tensor):
dist_index = DistIndex(dist_index)
part_idx = dist_index.part
index = dist_index.loc
futs: List[torch.futures.Future] = []
for i in range(self.num_parts):
mask = part_idx == i
f = self.accessor.async_index_copy_(0, index[mask], source[mask], self.rrefs[i])
futs.append(f)
return torch.futures.collect_all(futs)
def index_add_(self, dist_index: Union[Tensor, DistIndex], source: Tensor):
if isinstance(dist_index, Tensor):
dist_index = DistIndex(dist_index)
part_idx = dist_index.part
index = dist_index.loc
futs: List[torch.futures.Future] = []
for i in range(self.num_parts):
mask = part_idx == i
f = self.accessor.async_index_add_(0, index[mask], source[mask], self.rrefs[i])
futs.append(f)
return torch.futures.collect_all(futs)
\ No newline at end of file
from typing import List, Tuple
import torch
import torch.distributed as dist
from starrygl.distributed.utils import DistributedTensor
from starrygl.module.memorys import MailBox
from starrygl.sample.cache.fetch_cache import FetchFeatureCache
from starrygl.sample.graph_core import DataSet
from starrygl.sample.graph_core import DistributedGraphStore
from starrygl.sample.sample_core.base import BaseSampler, NegativeSampling
import dgl
from starrygl.sample.stream_manager import PipelineManager, getPipelineManger
"""
入参不变,出参变为:
sample_from_nodes
node: list[tensor,tensor, tensor...]
eid: list[tensor,tensor, tensor...]
src_index: list[tensor,tensor, tensor...]
sample_from_edges:
node
eid: list[tensor,tensor, tensor...]
src_index: list[tensor,tensor, tensor...]
delta_ts: list[tensor,tensor, tensor...]
metadata
"""
def prepare_input(node_feat, edge_feat, mem_embedding,mfgs,dist_nid,dist_eid):
for mfg in mfgs:
for i,b in enumerate(mfg):
e_idx = b.edata['ID']
idx = b.srcdata['ID']
b.edata['ID'] = dist_eid[e_idx]
b.srcdata['ID'] = dist_nid[idx]
if edge_feat is not None:
b.edata['f'] = edge_feat[e_idx]
if i == 0:
if node_feat is not None:
b.srcdata['h'] = node_feat[idx]
if mem_embedding is not None:
node_memory,node_memory_ts,mailbox,mailbox_ts = mem_embedding
b.srcdata['mem'] = node_memory[idx]
b.srcdata['mem_ts'] = node_memory_ts[idx]
b.srcdata['mem_input'] = mailbox[idx].reshape(b.srcdata['ID'].shape[0], -1)
b.srcdata['mail_ts'] = mailbox_ts[idx]
#print(idx.shape[0],b.srcdata['mem_ts'].shape)
return mfgs
def to_block(graph: DistributedGraphStore, data, sample_out, mailbox:MailBox = None,device = torch.device('cuda'),group = None):
<<<<<<< HEAD
=======
>>>>>>> cmy_dev
if len(sample_out) > 1:
sample_out,metadata = sample_out
else:
metadata = None
eid = [ret.eid() for ret in sample_out]
eid_len = [e.shape[0] for e in eid ]
eid_mapper: torch.Tensor = graph.eids_mapper
nid_mapper: torch.Tensor = graph.nids_mapper
eid_tensor = torch.cat(eid,dim = 0).to(eid_mapper.device)
dist_eid = eid_mapper[eid_tensor].to(device)
dist_eid,eid_inv = dist_eid.unique(return_inverse=True)
src_node = graph.sample_graph['edge_index'][0,eid_tensor*2].to(graph.nids_mapper.device)
src_ts = None
if metadata is None:
root_node = data.nodes.to(graph.nids_mapper.device)
root_len = [root_node.shape[0]]
if hasattr(data,'ts'):
src_ts = torch.cat([data.ts,
graph.sample_graph['ts'][eid_tensor*2].to(device)])
elif 'seed' in metadata:
root_node = metadata.pop('seed').to(graph.nids_mapper.device)
root_len = root_node.shape[0]
if 'seed_ts' in metadata:
src_ts = torch.cat([metadata.pop('seed_ts').to(device),\
graph.sample_graph['ts'][eid_tensor*2].to(device)])
for k in metadata:
metadata[k] = metadata[k].to(device)
nid_tensor = torch.cat([root_node,src_node],dim = 0)
dist_nid = nid_mapper[nid_tensor].to(device)
dist_nid,nid_inv = dist_nid.unique(return_inverse = True)
fetchCache = FetchFeatureCache.getFetchCache()
if fetchCache is None:
if isinstance(graph.edge_attr,DistributedTensor):
ind_dict = graph.edge_attr.all_to_all_ind2ptr(dist_eid,group = group)
edge_feat = graph.edge_attr.all_to_all_get(group = group,**ind_dict)
else:
edge_feat = graph._get_edge_attr(dist_eid)
ind_dict = None
if isinstance(graph.x,DistributedTensor):
ind_dict = graph.x.all_to_all_ind2ptr(dist_nid,group = group)
node_feat = graph.x.all_to_all_get(group = group,**ind_dict)
else:
node_feat = graph._get_node_attr(dist_nid)
if mailbox is not None:
if torch.distributed.get_world_size() > 1:
if node_feat is None:
ind_dict = mailbox.node_memory.all_to_all_ind2ptr(dist_nid,group = group)
mem = mailbox.gather_memory(**ind_dict)
else:
mem = mailbox.get_memory(dist_nid)
else:
mem = None
else:
raw_nid = torch.empty_like(dist_nid)
raw_eid = torch.empty_like(dist_eid)
nid_tensor = nid_tensor.to(device)
eid_tensor = eid_tensor.to(device)
raw_nid[nid_inv] = nid_tensor
raw_eid[eid_inv] = eid_tensor
node_feat,edge_feat,mem = fetchCache.fetch_feature(raw_nid,
dist_nid,raw_eid,
dist_eid)
def build_block():
mfgs = list()
col = torch.arange(0,root_len,device = device)
col_len = 0
row_len = root_len
for r in range(len(eid_len)):
elen = eid_len[r]
row = torch.arange(row_len,row_len+elen,device = device)
b = dgl.create_block((row,col[sample_out[r].src_index().to(device)]),
num_src_nodes = row_len + elen,
num_dst_nodes = row_len,
device = device)
idx = nid_inv[0:row_len + elen]
e_idx = eid_inv[col_len:col_len+elen]
b.srcdata['ID'] = idx
if sample_out[r].delta_ts().shape[0] > 0:
b.edata['dt'] = sample_out[r].delta_ts().to(device)
if src_ts is not None:
b.srcdata['ts'] = src_ts[0:row_len + eid_len[r]]
b.edata['ID'] = e_idx
col = row
col_len += eid_len[r]
row_len += eid_len[r]
mfgs.append(b)
mfgs = list(map(list, zip(*[iter(mfgs)])))
mfgs.reverse()
return data,mfgs,metadata
data,mfgs,metadata = build_block()
mfgs = prepare_input(node_feat,edge_feat,mem,mfgs,dist_nid,dist_eid)
#return build_block(node_feat,edge_feat,mem)#data,mfgs,metadata
return (data,mfgs,metadata)
def graph_sample(graph, sampler:BaseSampler,
sample_fn, data,
neg_sampling = None,
mailbox = None,
device = torch.device('cuda'),
async_op = False):
out = sample_fn(sampler,data,neg_sampling)
if async_op == False:
return to_block(graph,data,out,mailbox,device)
else:
manger = getPipelineManger()
future = manger.submit('lookup',to_block,{'graph':graph,'data':data,\
'sample_out':out,\
'mailbox':mailbox,\
'device':device})
return future
def sample_from_nodes(sampler:BaseSampler, data:DataSet, **kwargs):
out = sampler.sample_from_nodes(nodes=data.nodes.reshape(-1))
#out.metadata = None
return out
def sample_from_edges(sampler:BaseSampler,
data:DataSet,
neg_sampling:NegativeSampling = None):
edge_label = data.labels if hasattr(data,'labels') else None
out = sampler.sample_from_edges(edges = data.edges,
neg_sampling=neg_sampling)
return out
def sample_from_temporal_nodes(sampler:BaseSampler,data:DataSet,
**kwargs):
out = sampler.sample_from_nodes(nodes=data.nodes.reshape(-1),
ts = data.ts.reshape(-1))
#out.metadata = None
return out
def sample_from_temporal_edges(sampler:BaseSampler, data:DataSet,
neg_sampling: NegativeSampling = None):
edge_label = data.labels if hasattr(data,'labels') else None
out = sampler.sample_from_edges(edges=data.edges.to('cpu'),
ets=data.ts.to('cpu'),
neg_sampling = neg_sampling
)
return out
class SAMPLE_TYPE:
SAMPLE_FROM_NODES = sample_from_nodes,
SAMPLE_FROM_EDGES = sample_from_edges,
SAMPLE_FROM_TEMPORAL_NODES = sample_from_temporal_nodes,
SAMPLE_FROM_TEMPORAL_EDGES = sample_from_temporal_edges
\ No newline at end of file
from typing import List, Tuple
import torch
import torch.distributed as dist
from starrygl.distributed.utils import DistributedTensor
from starrygl.module.memorys import MailBox
from starrygl.sample.cache.fetch_cache import FetchFeatureCache
from starrygl.sample.graph_core import DataSet
from starrygl.sample.graph_core import DistributedGraphStore
from starrygl.sample.sample_core.base import BaseSampler, NegativeSampling
import dgl
from starrygl.sample.stream_manager import PipelineManager, getPipelineManger
"""
入参不变,出参变为:
sample_from_nodes
node: list[tensor,tensor, tensor...]
eid: list[tensor,tensor, tensor...]
src_index: list[tensor,tensor, tensor...]
sample_from_edges:
node
eid: list[tensor,tensor, tensor...]
src_index: list[tensor,tensor, tensor...]
delta_ts: list[tensor,tensor, tensor...]
metadata
"""
def prepare_input(node_feat, edge_feat, mem_embedding,mfgs,dist_nid,dist_eid):
for mfg in mfgs:
for i,b in enumerate(mfg):
e_idx = b.edata['ID']
idx = b.srcdata['ID']
b.edata['ID'] = dist_eid[e_idx]
b.srcdata['ID'] = dist_nid[idx]
if edge_feat is not None:
b.edata['f'] = edge_feat[e_idx]
if i == 0:
if node_feat is not None:
b.srcdata['h'] = node_feat[idx]
if mem_embedding is not None:
node_memory,node_memory_ts,mailbox,mailbox_ts = mem_embedding
b.srcdata['mem'] = node_memory[idx]
b.srcdata['mem_ts'] = node_memory_ts[idx]
b.srcdata['mem_input'] = mailbox[idx].reshape(b.srcdata['ID'].shape[0], -1)
b.srcdata['mail_ts'] = mailbox_ts[idx]
#print(idx.shape[0],b.srcdata['mem_ts'].shape)
return mfgs
def to_block(graph: DistributedGraphStore, data, sample_out, mailbox:MailBox = None,device = torch.device('cuda'),group = None):
if len(sample_out) > 1:
sample_out,metadata = sample_out
else:
metadata = None
eid = [ret.eid() for ret in sample_out]
eid_len = [e.shape[0] for e in eid ]
eid_mapper: torch.Tensor = graph.eids_mapper
nid_mapper: torch.Tensor = graph.nids_mapper
eid_tensor = torch.cat(eid,dim = 0).to(eid_mapper.device)
dist_eid = eid_mapper[eid_tensor].to(device)
dist_eid,eid_inv = dist_eid.unique(return_inverse=True)
src_node = graph.sample_graph['edge_index'][0,eid_tensor*2].to(graph.nids_mapper.device)
src_ts = None
if metadata is None:
root_node = data.nodes.to(graph.nids_mapper.device)
root_len = [root_node.shape[0]]
if hasattr(data,'ts'):
src_ts = torch.cat([data.ts,
graph.sample_graph['ts'][eid_tensor*2].to(device)])
elif 'seed' in metadata:
root_node = metadata.pop('seed').to(graph.nids_mapper.device)
root_len = root_node.shape[0]
if 'seed_ts' in metadata:
src_ts = torch.cat([metadata.pop('seed_ts').to(device),\
graph.sample_graph['ts'][eid_tensor*2].to(device)])
for k in metadata:
metadata[k] = metadata[k].to(device)
nid_tensor = torch.cat([root_node,src_node],dim = 0)
dist_nid = nid_mapper[nid_tensor].to(device)
dist_nid,nid_inv = dist_nid.unique(return_inverse = True)
fetchCache = FetchFeatureCache.getFetchCache()
if fetchCache is None:
if isinstance(graph.edge_attr,DistributedTensor):
ind_dict = graph.edge_attr.all_to_all_ind2ptr(dist_eid,group = group)
edge_feat = graph.edge_attr.all_to_all_get(group = group,**ind_dict)
else:
edge_feat = graph._get_edge_attr(dist_eid)
ind_dict = None
if isinstance(graph.x,DistributedTensor):
ind_dict = graph.x.all_to_all_ind2ptr(dist_nid,group = group)
node_feat = graph.x.all_to_all_get(group = group,**ind_dict)
else:
node_feat = graph._get_node_attr(dist_nid)
if mailbox is not None:
if torch.distributed.get_world_size() > 1:
if node_feat is None:
ind_dict = mailbox.node_memory.all_to_all_ind2ptr(dist_nid,group = group)
mem = mailbox.gather_memory(**ind_dict)
else:
mem = mailbox.get_memory(dist_nid)
else:
mem = None
else:
raw_nid = torch.empty_like(dist_nid)
raw_eid = torch.empty_like(dist_eid)
nid_tensor = nid_tensor.to(device)
eid_tensor = eid_tensor.to(device)
raw_nid[nid_inv] = nid_tensor
raw_eid[eid_inv] = eid_tensor
node_feat,edge_feat,mem = fetchCache.fetch_feature(raw_nid,
dist_nid,raw_eid,
dist_eid)
def build_block():
mfgs = list()
col = torch.arange(0,root_len,device = device)
col_len = 0
row_len = root_len
for r in range(len(eid_len)):
elen = eid_len[r]
row = torch.arange(row_len,row_len+elen,device = device)
b = dgl.create_block((row,col[sample_out[r].src_index().to(device)]),
num_src_nodes = row_len + elen,
num_dst_nodes = row_len,
device = device)
idx = nid_inv[0:row_len + elen]
e_idx = eid_inv[col_len:col_len+elen]
b.srcdata['ID'] = idx
if sample_out[r].delta_ts().shape[0] > 0:
b.edata['dt'] = sample_out[r].delta_ts().to(device)
if src_ts is not None:
b.srcdata['ts'] = src_ts[0:row_len + eid_len[r]]
b.edata['ID'] = e_idx
col = row
col_len += eid_len[r]
row_len += eid_len[r]
mfgs.append(b)
mfgs = list(map(list, zip(*[iter(mfgs)])))
mfgs.reverse()
return data,mfgs,metadata
data,mfgs,metadata = build_block()
mfgs = prepare_input(node_feat,edge_feat,mem,mfgs,dist_nid,dist_eid)
#return build_block(node_feat,edge_feat,mem)#data,mfgs,metadata
return (data,mfgs,metadata)
def graph_sample(graph, sampler:BaseSampler,
sample_fn, data,
neg_sampling = None,
mailbox = None,
device = torch.device('cuda'),
async_op = False):
out = sample_fn(sampler,data,neg_sampling)
if async_op == False:
return to_block(graph,data,out,mailbox,device)
else:
manger = getPipelineManger()
future = manger.submit('lookup',to_block,{'graph':graph,'data':data,\
'sample_out':out,\
'mailbox':mailbox,\
'device':device})
return future
def sample_from_nodes(sampler:BaseSampler, data:DataSet, **kwargs):
out = sampler.sample_from_nodes(nodes=data.nodes.reshape(-1))
#out.metadata = None
return out
def sample_from_edges(sampler:BaseSampler,
data:DataSet,
neg_sampling:NegativeSampling = None):
edge_label = data.labels if hasattr(data,'labels') else None
out = sampler.sample_from_edges(edges = data.edges,
neg_sampling=neg_sampling)
return out
def sample_from_temporal_nodes(sampler:BaseSampler,data:DataSet,
**kwargs):
out = sampler.sample_from_nodes(nodes=data.nodes.reshape(-1),
ts = data.ts.reshape(-1))
#out.metadata = None
return out
def sample_from_temporal_edges(sampler:BaseSampler, data:DataSet,
neg_sampling: NegativeSampling = None):
edge_label = data.labels if hasattr(data,'labels') else None
out = sampler.sample_from_edges(edges=data.edges.to('cpu'),
ets=data.ts.to('cpu'),
neg_sampling = neg_sampling
)
return out
class SAMPLE_TYPE:
SAMPLE_FROM_NODES = sample_from_nodes,
SAMPLE_FROM_EDGES = sample_from_edges,
SAMPLE_FROM_TEMPORAL_NODES = sample_from_temporal_nodes,
SAMPLE_FROM_TEMPORAL_EDGES = sample_from_temporal_edges
\ No newline at end of file
from collections import deque
from enum import Enum
import queue
import torch
import sys
from os.path import abspath, join, dirname
import numpy as np
from starrygl.sample.batch_data import graph_sample
from starrygl.sample.sample_core.PreNegSampling import PreNegativeSampling
sys.path.insert(0, join(abspath(dirname(__file__))))
from typing import Deque, Optional
import torch.distributed as dist
from torch_geometric.data import Data
import os.path as osp
import math
class DistributedDataLoader:
'''
We will perform feature fetch in the data loader.
you can simply define a data loader for use, while starrygl assisting in fetching node or edge features:
Args:
graph: distributed graph store
data: the graph data
sampler: a parallel sampler like `NeighborSampler` above
sampler_fn: sample type
neg_sampler: negative sampler
batch_size: batch size
mailbox: APAN's mailbox and TGN's memory implemented by starrygl
Examples:
.. code-block:: python
import torch
from starrygl.sample.data_loader import DistributedDataLoader
from starrygl.sample.part_utils.partition_tgnn import partition_load
from starrygl.sample.graph_core import DataSet, DistributedGraphStore, TemporalNeighborSampleGraph
from starrygl.sample.memory.shared_mailbox import SharedMailBox
from starrygl.sample.sample_core.neighbor_sampler import NeighborSampler
from starrygl.sample.sample_core.base import NegativeSampling
from starrygl.sample.batch_data import SAMPLE_TYPE
pdata = partition_load("PATH/{}".format(dataname), algo="metis_for_tgnn")
graph = DistributedGraphStore(pdata = pdata, uvm_edge = False, uvm_node = False)
sample_graph = TemporalNeighborSampleGraph(sample_graph = pdata.sample_graph,mode = 'full')
mailbox = SharedMailBox(pdata.ids.shape[0], memory_param, dim_edge_feat=pdata.edge_attr.shape[1] if pdata. edge_attr is not None else 0)
sampler = NeighborSampler(num_nodes=graph.num_nodes, num_layers=1, fanout=[10], graph_data=sample_graph, workers=15,policy = 'recent',graph_name = "wiki_train")
neg_sampler = NegativeSampling('triplet')
train_data = torch.masked_select(graph.edge_index, pdata.train_mask.to(graph.edge_index.device)).reshape (2, -1)
trainloader = DistributedDataLoader(graph, train_data, sampler=sampler, sampler_fn=SAMPLE_TYPE. SAMPLE_FROM_TEMPORAL_EDGES,neg_sampler=neg_sampler, batch_size=1000, shuffle=False, drop_last=True, chunk_size = None,train=True, mailbox=mailbox )
In the data loader, we will call the `graph_sample`, sourced from `starrygl.sample.batch_data`.
And the `to_block` function in the `graph_sample` will implement feature fetching.
If cache is not used, we will directly fetch node or edge features from the graph data,
otherwise we will call `fetch_data` for feature fetching.
'''
def __init__(
self,
graph,
dataset = None,
sampler = None,
sampler_fn = None,
neg_sampler = None,
batch_size: Optional[int]=None,
drop_last = False,
device: torch.device = torch.device('cuda'),
shuffle:bool = True,
chunk_size = None,
train = False,
queue_size = 10,
mailbox = None,
is_pipeline = False,
**kwargs
):
assert sampler is not None
self.chunk_size = chunk_size
self.batch_size = batch_size
self.queue_size = queue_size
self.num_pending = 0
self.current_pos = 0
self.recv_idxs = 0
self.drop_last = drop_last
self.result_queue = deque(maxlen = self.queue_size)
self.shuffle = shuffle
self.is_closed = False
self.sampler = sampler
self.sampler_fn = sampler_fn
self.neg_sampler = neg_sampler
self.graph = graph
self.shuffle=shuffle
self.dataset = dataset
self.mailbox = mailbox
self.device = device
self.is_pipeline = is_pipeline
if train is True:
self._get_expected_idx(self.dataset.len)
else:
self._get_expected_idx(self.dataset.len,op = dist.ReduceOp.MAX)
#self.expected_idx = int(math.ceil(self.dataset.len/self.batch_size))
torch.distributed.barrier()
def __iter__(self):
if self.chunk_size is None:
if self.shuffle:
self.input_dataset = self.dataset.shuffle()
else:
self.input_dataset = self.dataset
self.recv_idxs = 0
self.current_pos = 0
self.num_pending = 0
self.submitted = 0
else:
self.input_dataset = self.dataset
self.recv_idxs = 0
self.num_pending = 0
self.submitted = 0
if dist.get_rank == 0:
self.current_pos = int(
math.floor(
np.random.uniform(0,self.batch_size/self.chunk_size)
)*self.chunk_size
)
else:
self.current_pos = 0
current_pos = torch.tensor([self.current_pos],dtype = torch.long,device=self.device)
dist.broadcast(current_pos, src = 0)
self.current_pos = int(current_pos.item())
self._get_expected_idx(self.dataset.len-self.current_pos)
if self.neg_sampler is not None \
and isinstance(self.neg_sampler,PreNegativeSampling):
self.neg_sampler.set_next_pos(self.current_pos)
return self
def _get_expected_idx(self,data_size,op = dist.ReduceOp.MIN):
world_size = dist.get_world_size()
self.expected_idx = data_size // self.batch_size if self.drop_last is True else int(math.ceil(data_size/self.batch_size))
if dist.get_world_size() > 1:
num_epochs = torch.tensor([self.expected_idx],dtype = torch.long,device=self.device)
print(num_epochs)
dist.all_reduce(num_epochs, op=op)
self.expected_idx = int(num_epochs.item())
def _next_data(self):
if self.current_pos >= self.dataset.len:
return self.input_dataset._get_empty()
if self.current_pos + self.batch_size > self.input_dataset.len:
if self.drop_last:
return None
else:
next_data = self.input_dataset.get_next(
slice(self.current_pos,None,None)
)
self.current_pos = 0
else:
next_data = self.input_dataset.get_next(
slice(self.current_pos,self.current_pos + self.batch_size,None)
)
self.current_pos += self.batch_size
return next_data
def __next__(self):
<<<<<<< HEAD
if self.is_pipeline is False:
=======
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
start_event.record()
if(dist.get_world_size() > 0):
>>>>>>> cmy_dev
if self.recv_idxs < self.expected_idx:
data = self._next_data()
batch_data = graph_sample(self.graph,
self.sampler,
self.sampler_fn,
data,self.neg_sampler,
self.mailbox,
self.device)
self.recv_idxs += 1
assert batch_data is not None
end_event.record()
torch.cuda.synchronize()
sample_time = start_event.elapsed_time(end_event)
return *batch_data,sample_time
else :
raise StopIteration
else:
if self.recv_idxs == 0:
data = self._next_data()
batch_data = graph_sample(self.graph,
self.sampler,
self.sampler_fn,
data,self.neg_sampler,
self.mailbox,
self.device)
self.recv_idxs += 1
else:
if(self.recv_idxs < self.expected_idx):
assert len(self.result_queue) > 0
result= self.result_queue[0]
self.result_queue.popleft()
batch_data = result.result()
self.recv_idxs += 1
else:
raise StopIteration
if(self.recv_idxs+1<=self.expected_idx):
data = self._next_data()
next_batch = graph_sample(self.graph,
self.sampler,
self.sampler_fn,
data,self.neg_sampler,
self.mailbox,
self.device,
async_op=True)
self.result_queue.append(next_batch)
return batch_data
from collections import deque
from enum import Enum
import queue
import torch
import sys
from os.path import abspath, join, dirname
import numpy as np
from starrygl.sample.batch_data import graph_sample
from starrygl.sample.sample_core.PreNegSampling import PreNegativeSampling
sys.path.insert(0, join(abspath(dirname(__file__))))
from typing import Deque, Optional
import torch.distributed as dist
from torch_geometric.data import Data
import os.path as osp
import math
class DistributedDataLoader:
'''
We will perform feature fetch in the data loader.
you can simply define a data loader for use, while starrygl assisting in fetching node or edge features:
Args:
graph: distributed graph store
data: the graph data
sampler: a parallel sampler like `NeighborSampler` above
sampler_fn: sample type
neg_sampler: negative sampler
batch_size: batch size
mailbox: APAN's mailbox and TGN's memory implemented by starrygl
Examples:
.. code-block:: python
import torch
from starrygl.sample.data_loader import DistributedDataLoader
from starrygl.sample.part_utils.partition_tgnn import partition_load
from starrygl.sample.graph_core import DataSet, DistributedGraphStore, TemporalNeighborSampleGraph
from starrygl.sample.memory.shared_mailbox import SharedMailBox
from starrygl.sample.sample_core.neighbor_sampler import NeighborSampler
from starrygl.sample.sample_core.base import NegativeSampling
from starrygl.sample.batch_data import SAMPLE_TYPE
pdata = partition_load("PATH/{}".format(dataname), algo="metis_for_tgnn")
graph = DistributedGraphStore(pdata = pdata, uvm_edge = False, uvm_node = False)
sample_graph = TemporalNeighborSampleGraph(sample_graph = pdata.sample_graph,mode = 'full')
mailbox = SharedMailBox(pdata.ids.shape[0], memory_param, dim_edge_feat=pdata.edge_attr.shape[1] if pdata. edge_attr is not None else 0)
sampler = NeighborSampler(num_nodes=graph.num_nodes, num_layers=1, fanout=[10], graph_data=sample_graph, workers=15,policy = 'recent',graph_name = "wiki_train")
neg_sampler = NegativeSampling('triplet')
train_data = torch.masked_select(graph.edge_index, pdata.train_mask.to(graph.edge_index.device)).reshape (2, -1)
trainloader = DistributedDataLoader(graph, train_data, sampler=sampler, sampler_fn=SAMPLE_TYPE. SAMPLE_FROM_TEMPORAL_EDGES,neg_sampler=neg_sampler, batch_size=1000, shuffle=False, drop_last=True, chunk_size = None,train=True, mailbox=mailbox )
In the data loader, we will call the `graph_sample`, sourced from `starrygl.sample.batch_data`.
And the `to_block` function in the `graph_sample` will implement feature fetching.
If cache is not used, we will directly fetch node or edge features from the graph data,
otherwise we will call `fetch_data` for feature fetching.
'''
def __init__(
self,
graph,
dataset = None,
sampler = None,
sampler_fn = None,
neg_sampler = None,
batch_size: Optional[int]=None,
drop_last = False,
device: torch.device = torch.device('cuda'),
shuffle:bool = True,
chunk_size = None,
train = False,
queue_size = 10,
mailbox = None,
is_pipeline = False,
**kwargs
):
assert sampler is not None
self.chunk_size = chunk_size
self.batch_size = batch_size
self.queue_size = queue_size
self.num_pending = 0
self.current_pos = 0
self.recv_idxs = 0
self.drop_last = drop_last
self.result_queue = deque(maxlen = self.queue_size)
self.shuffle = shuffle
self.is_closed = False
self.sampler = sampler
self.sampler_fn = sampler_fn
self.neg_sampler = neg_sampler
self.graph = graph
self.shuffle=shuffle
self.dataset = dataset
self.mailbox = mailbox
self.device = device
self.is_pipeline = is_pipeline
if train is True:
self._get_expected_idx(self.dataset.len)
else:
self._get_expected_idx(self.dataset.len,op = dist.ReduceOp.MAX)
#self.expected_idx = int(math.ceil(self.dataset.len/self.batch_size))
torch.distributed.barrier()
def __iter__(self):
if self.chunk_size is None:
if self.shuffle:
self.input_dataset = self.dataset.shuffle()
else:
self.input_dataset = self.dataset
self.recv_idxs = 0
self.current_pos = 0
self.num_pending = 0
self.submitted = 0
else:
self.input_dataset = self.dataset
self.recv_idxs = 0
self.num_pending = 0
self.submitted = 0
if dist.get_rank == 0:
self.current_pos = int(
math.floor(
np.random.uniform(0,self.batch_size/self.chunk_size)
)*self.chunk_size
)
else:
self.current_pos = 0
current_pos = torch.tensor([self.current_pos],dtype = torch.long,device=self.device)
dist.broadcast(current_pos, src = 0)
self.current_pos = int(current_pos.item())
self._get_expected_idx(self.dataset.len-self.current_pos)
if self.neg_sampler is not None \
and isinstance(self.neg_sampler,PreNegativeSampling):
self.neg_sampler.set_next_pos(self.current_pos)
return self
def _get_expected_idx(self,data_size,op = dist.ReduceOp.MIN):
world_size = dist.get_world_size()
self.expected_idx = data_size // self.batch_size if self.drop_last is True else int(math.ceil(data_size/self.batch_size))
if dist.get_world_size() > 1:
num_epochs = torch.tensor([self.expected_idx],dtype = torch.long,device=self.device)
print(num_epochs)
dist.all_reduce(num_epochs, op=op)
self.expected_idx = int(num_epochs.item())
def _next_data(self):
if self.current_pos >= self.dataset.len:
return self.input_dataset._get_empty()
if self.current_pos + self.batch_size > self.input_dataset.len:
if self.drop_last:
return None
else:
next_data = self.input_dataset.get_next(
slice(self.current_pos,None,None)
)
self.current_pos = 0
else:
next_data = self.input_dataset.get_next(
slice(self.current_pos,self.current_pos + self.batch_size,None)
)
self.current_pos += self.batch_size
return next_data
def __next__(self):
if self.is_pipeline is False:
if self.recv_idxs < self.expected_idx:
data = self._next_data()
batch_data = graph_sample(self.graph,
self.sampler,
self.sampler_fn,
data,self.neg_sampler,
self.mailbox,
self.device)
self.recv_idxs += 1
assert batch_data is not None
end_event.record()
torch.cuda.synchronize()
sample_time = start_event.elapsed_time(end_event)
return *batch_data,sample_time
else :
raise StopIteration
else:
if self.recv_idxs == 0:
data = self._next_data()
batch_data = graph_sample(self.graph,
self.sampler,
self.sampler_fn,
data,self.neg_sampler,
self.mailbox,
self.device)
self.recv_idxs += 1
else:
if(self.recv_idxs < self.expected_idx):
assert len(self.result_queue) > 0
result= self.result_queue[0]
self.result_queue.popleft()
batch_data = result.result()
self.recv_idxs += 1
else:
raise StopIteration
if(self.recv_idxs+1<=self.expected_idx):
data = self._next_data()
next_batch = graph_sample(self.graph,
self.sampler,
self.sampler_fn,
data,self.neg_sampler,
self.mailbox,
self.device,
async_op=True)
self.result_queue.append(next_batch)
return batch_data
from collections import deque
from enum import Enum
import queue
import torch
import sys
from os.path import abspath, join, dirname
import numpy as np
from starrygl.sample.batch_data import graph_sample
from starrygl.sample.sample_core.PreNegSampling import PreNegativeSampling
sys.path.insert(0, join(abspath(dirname(__file__))))
from typing import Deque, Optional
import torch.distributed as dist
from torch_geometric.data import Data
import os.path as osp
import math
class DistributedDataLoader:
'''
We will perform feature fetch in the data loader.
you can simply define a data loader for use, while starrygl assisting in fetching node or edge features:
Args:
graph: distributed graph store
data: the graph data
sampler: a parallel sampler like `NeighborSampler` above
sampler_fn: sample type
neg_sampler: negative sampler
batch_size: batch size
mailbox: APAN's mailbox and TGN's memory implemented by starrygl
Examples:
.. code-block:: python
import torch
from starrygl.sample.data_loader import DistributedDataLoader
from starrygl.sample.part_utils.partition_tgnn import partition_load
from starrygl.sample.graph_core import DataSet, DistributedGraphStore, TemporalNeighborSampleGraph
from starrygl.sample.memory.shared_mailbox import SharedMailBox
from starrygl.sample.sample_core.neighbor_sampler import NeighborSampler
from starrygl.sample.sample_core.base import NegativeSampling
from starrygl.sample.batch_data import SAMPLE_TYPE
pdata = partition_load("PATH/{}".format(dataname), algo="metis_for_tgnn")
graph = DistributedGraphStore(pdata = pdata, uvm_edge = False, uvm_node = False)
sample_graph = TemporalNeighborSampleGraph(sample_graph = pdata.sample_graph,mode = 'full')
mailbox = SharedMailBox(pdata.ids.shape[0], memory_param, dim_edge_feat=pdata.edge_attr.shape[1] if pdata. edge_attr is not None else 0)
sampler = NeighborSampler(num_nodes=graph.num_nodes, num_layers=1, fanout=[10], graph_data=sample_graph, workers=15,policy = 'recent',graph_name = "wiki_train")
neg_sampler = NegativeSampling('triplet')
train_data = torch.masked_select(graph.edge_index, pdata.train_mask.to(graph.edge_index.device)).reshape (2, -1)
trainloader = DistributedDataLoader(graph, train_data, sampler=sampler, sampler_fn=SAMPLE_TYPE. SAMPLE_FROM_TEMPORAL_EDGES,neg_sampler=neg_sampler, batch_size=1000, shuffle=False, drop_last=True, chunk_size = None,train=True, mailbox=mailbox )
In the data loader, we will call the `graph_sample`, sourced from `starrygl.sample.batch_data`.
And the `to_block` function in the `graph_sample` will implement feature fetching.
If cache is not used, we will directly fetch node or edge features from the graph data,
otherwise we will call `fetch_data` for feature fetching.
'''
def __init__(
self,
graph,
dataset = None,
sampler = None,
sampler_fn = None,
neg_sampler = None,
batch_size: Optional[int]=None,
drop_last = False,
device: torch.device = torch.device('cuda'),
shuffle:bool = True,
chunk_size = None,
train = False,
queue_size = 10,
mailbox = None,
is_pipeline = False,
**kwargs
):
assert sampler is not None
self.chunk_size = chunk_size
self.batch_size = batch_size
self.queue_size = queue_size
self.num_pending = 0
self.current_pos = 0
self.recv_idxs = 0
self.drop_last = drop_last
self.result_queue = deque(maxlen = self.queue_size)
self.shuffle = shuffle
self.is_closed = False
self.sampler = sampler
self.sampler_fn = sampler_fn
self.neg_sampler = neg_sampler
self.graph = graph
self.shuffle=shuffle
self.dataset = dataset
self.mailbox = mailbox
self.device = device
self.is_pipeline = is_pipeline
if train is True:
self._get_expected_idx(self.dataset.len)
else:
self._get_expected_idx(self.dataset.len,op = dist.ReduceOp.MAX)
#self.expected_idx = int(math.ceil(self.dataset.len/self.batch_size))
torch.distributed.barrier()
def __iter__(self):
if self.chunk_size is None:
if self.shuffle:
self.input_dataset = self.dataset.shuffle()
else:
self.input_dataset = self.dataset
self.recv_idxs = 0
self.current_pos = 0
self.num_pending = 0
self.submitted = 0
else:
self.input_dataset = self.dataset
self.recv_idxs = 0
self.num_pending = 0
self.submitted = 0
if dist.get_rank == 0:
self.current_pos = int(
math.floor(
np.random.uniform(0,self.batch_size/self.chunk_size)
)*self.chunk_size
)
else:
self.current_pos = 0
current_pos = torch.tensor([self.current_pos],dtype = torch.long,device=self.device)
dist.broadcast(current_pos, src = 0)
self.current_pos = int(current_pos.item())
self._get_expected_idx(self.dataset.len-self.current_pos)
if self.neg_sampler is not None \
and isinstance(self.neg_sampler,PreNegativeSampling):
self.neg_sampler.set_next_pos(self.current_pos)
return self
def _get_expected_idx(self,data_size,op = dist.ReduceOp.MIN):
world_size = dist.get_world_size()
self.expected_idx = data_size // self.batch_size if self.drop_last is True else int(math.ceil(data_size/self.batch_size))
if dist.get_world_size() > 1:
num_epochs = torch.tensor([self.expected_idx],dtype = torch.long,device=self.device)
print(num_epochs)
dist.all_reduce(num_epochs, op=op)
self.expected_idx = int(num_epochs.item())
def _next_data(self):
if self.current_pos >= self.dataset.len:
return self.input_dataset._get_empty()
if self.current_pos + self.batch_size > self.input_dataset.len:
if self.drop_last:
return None
else:
next_data = self.input_dataset.get_next(
slice(self.current_pos,None,None)
)
self.current_pos = 0
else:
next_data = self.input_dataset.get_next(
slice(self.current_pos,self.current_pos + self.batch_size,None)
)
self.current_pos += self.batch_size
return next_data
def __next__(self):
if self.is_pipeline is False:
if self.recv_idxs < self.expected_idx:
data = self._next_data()
batch_data = graph_sample(self.graph,
self.sampler,
self.sampler_fn,
data,self.neg_sampler,
self.mailbox,
self.device)
self.recv_idxs += 1
assert batch_data is not None
return *batch_data
else :
raise StopIteration
else:
if self.recv_idxs == 0:
data = self._next_data()
batch_data = graph_sample(self.graph,
self.sampler,
self.sampler_fn,
data,self.neg_sampler,
self.mailbox,
self.device)
self.recv_idxs += 1
else:
if(self.recv_idxs < self.expected_idx):
assert len(self.result_queue) > 0
result= self.result_queue[0]
self.result_queue.popleft()
batch_data = result.result()
self.recv_idxs += 1
else:
raise StopIteration
if(self.recv_idxs+1<=self.expected_idx):
data = self._next_data()
next_batch = graph_sample(self.graph,
self.sampler,
self.sampler_fn,
data,self.neg_sampler,
self.mailbox,
self.device,
async_op=True)
self.result_queue.append(next_batch)
return batch_data
from collections import deque
from enum import Enum
import queue
import torch
import sys
from os.path import abspath, join, dirname
import numpy as np
from starrygl.sample.batch_data import graph_sample
from starrygl.sample.sample_core.PreNegSampling import PreNegativeSampling
sys.path.insert(0, join(abspath(dirname(__file__))))
from typing import Deque, Optional
import torch.distributed as dist
from torch_geometric.data import Data
import os.path as osp
import math
class DistributedDataLoader:
'''
We will perform feature fetch in the data loader.
you can simply define a data loader for use, while starrygl assisting in fetching node or edge features:
Args:
graph: distributed graph store
data: the graph data
sampler: a parallel sampler like `NeighborSampler` above
sampler_fn: sample type
neg_sampler: negative sampler
batch_size: batch size
mailbox: APAN's mailbox and TGN's memory implemented by starrygl
Examples:
.. code-block:: python
import torch
from starrygl.sample.data_loader import DistributedDataLoader
from starrygl.sample.part_utils.partition_tgnn import partition_load
from starrygl.sample.graph_core import DataSet, DistributedGraphStore, TemporalNeighborSampleGraph
from starrygl.sample.memory.shared_mailbox import SharedMailBox
from starrygl.sample.sample_core.neighbor_sampler import NeighborSampler
from starrygl.sample.sample_core.base import NegativeSampling
from starrygl.sample.batch_data import SAMPLE_TYPE
pdata = partition_load("PATH/{}".format(dataname), algo="metis_for_tgnn")
graph = DistributedGraphStore(pdata = pdata, uvm_edge = False, uvm_node = False)
sample_graph = TemporalNeighborSampleGraph(sample_graph = pdata.sample_graph,mode = 'full')
mailbox = SharedMailBox(pdata.ids.shape[0], memory_param, dim_edge_feat=pdata.edge_attr.shape[1] if pdata. edge_attr is not None else 0)
sampler = NeighborSampler(num_nodes=graph.num_nodes, num_layers=1, fanout=[10], graph_data=sample_graph, workers=15,policy = 'recent',graph_name = "wiki_train")
neg_sampler = NegativeSampling('triplet')
train_data = torch.masked_select(graph.edge_index, pdata.train_mask.to(graph.edge_index.device)).reshape (2, -1)
trainloader = DistributedDataLoader(graph, train_data, sampler=sampler, sampler_fn=SAMPLE_TYPE. SAMPLE_FROM_TEMPORAL_EDGES,neg_sampler=neg_sampler, batch_size=1000, shuffle=False, drop_last=True, chunk_size = None,train=True, mailbox=mailbox )
In the data loader, we will call the `graph_sample`, sourced from `starrygl.sample.batch_data`.
And the `to_block` function in the `graph_sample` will implement feature fetching.
If cache is not used, we will directly fetch node or edge features from the graph data,
otherwise we will call `fetch_data` for feature fetching.
'''
def __init__(
self,
graph,
dataset = None,
sampler = None,
sampler_fn = None,
neg_sampler = None,
batch_size: Optional[int]=None,
drop_last = False,
device: torch.device = torch.device('cuda'),
shuffle:bool = True,
chunk_size = None,
train = False,
queue_size = 10,
mailbox = None,
is_pipeline = False,
**kwargs
):
assert sampler is not None
self.chunk_size = chunk_size
self.batch_size = batch_size
self.queue_size = queue_size
self.num_pending = 0
self.current_pos = 0
self.recv_idxs = 0
self.drop_last = drop_last
self.result_queue = deque(maxlen = self.queue_size)
self.shuffle = shuffle
self.is_closed = False
self.sampler = sampler
self.sampler_fn = sampler_fn
self.neg_sampler = neg_sampler
self.graph = graph
self.shuffle=shuffle
self.dataset = dataset
self.mailbox = mailbox
self.device = device
self.is_pipeline = is_pipeline
if train is True:
self._get_expected_idx(self.dataset.len)
else:
self._get_expected_idx(self.dataset.len,op = dist.ReduceOp.MAX)
#self.expected_idx = int(math.ceil(self.dataset.len/self.batch_size))
torch.distributed.barrier()
def __iter__(self):
if self.chunk_size is None:
if self.shuffle:
self.input_dataset = self.dataset.shuffle()
else:
self.input_dataset = self.dataset
self.recv_idxs = 0
self.current_pos = 0
self.num_pending = 0
self.submitted = 0
else:
self.input_dataset = self.dataset
self.recv_idxs = 0
self.num_pending = 0
self.submitted = 0
if dist.get_rank == 0:
self.current_pos = int(
math.floor(
np.random.uniform(0,self.batch_size/self.chunk_size)
)*self.chunk_size
)
else:
self.current_pos = 0
current_pos = torch.tensor([self.current_pos],dtype = torch.long,device=self.device)
dist.broadcast(current_pos, src = 0)
self.current_pos = int(current_pos.item())
self._get_expected_idx(self.dataset.len-self.current_pos)
if self.neg_sampler is not None \
and isinstance(self.neg_sampler,PreNegativeSampling):
self.neg_sampler.set_next_pos(self.current_pos)
return self
def _get_expected_idx(self,data_size,op = dist.ReduceOp.MIN):
world_size = dist.get_world_size()
self.expected_idx = data_size // self.batch_size if self.drop_last is True else int(math.ceil(data_size/self.batch_size))
if dist.get_world_size() > 1:
num_epochs = torch.tensor([self.expected_idx],dtype = torch.long,device=self.device)
print(num_epochs)
dist.all_reduce(num_epochs, op=op)
self.expected_idx = int(num_epochs.item())
def _next_data(self):
if self.current_pos >= self.dataset.len:
return self.input_dataset._get_empty()
if self.current_pos + self.batch_size > self.input_dataset.len:
if self.drop_last:
return None
else:
next_data = self.input_dataset.get_next(
slice(self.current_pos,None,None)
)
self.current_pos = 0
else:
next_data = self.input_dataset.get_next(
slice(self.current_pos,self.current_pos + self.batch_size,None)
)
self.current_pos += self.batch_size
return next_data
def __next__(self):
if self.is_pipeline is False:
if self.recv_idxs < self.expected_idx:
data = self._next_data()
batch_data = graph_sample(self.graph,
self.sampler,
self.sampler_fn,
data,self.neg_sampler,
self.mailbox,
self.device)
self.recv_idxs += 1
assert batch_data is not None
return batch_data
else :
raise StopIteration
else:
if self.recv_idxs == 0:
data = self._next_data()
batch_data = graph_sample(self.graph,
self.sampler,
self.sampler_fn,
data,self.neg_sampler,
self.mailbox,
self.device)
self.recv_idxs += 1
else:
if(self.recv_idxs < self.expected_idx):
assert len(self.result_queue) > 0
result= self.result_queue[0]
self.result_queue.popleft()
batch_data = result.result()
self.recv_idxs += 1
else:
raise StopIteration
if(self.recv_idxs+1<=self.expected_idx):
data = self._next_data()
next_batch = graph_sample(self.graph,
self.sampler,
self.sampler_fn,
data,self.neg_sampler,
self.mailbox,
self.device,
async_op=True)
self.result_queue.append(next_batch)
return batch_data
from collections import deque
from enum import Enum
import queue
import torch
import sys
from os.path import abspath, join, dirname
import numpy as np
from starrygl.sample.batch_data import graph_sample
from starrygl.sample.sample_core.PreNegSampling import PreNegativeSampling
sys.path.insert(0, join(abspath(dirname(__file__))))
from typing import Deque, Optional
import torch.distributed as dist
from torch_geometric.data import Data
import os.path as osp
import math
class DistributedDataLoader:
'''
We will perform feature fetch in the data loader.
you can simply define a data loader for use, while starrygl assisting in fetching node or edge features:
Args:
graph: distributed graph store
data: the graph data
sampler: a parallel sampler like `NeighborSampler` above
sampler_fn: sample type
neg_sampler: negative sampler
batch_size: batch size
mailbox: APAN's mailbox and TGN's memory implemented by starrygl
Examples:
.. code-block:: python
import torch
from starrygl.sample.data_loader import DistributedDataLoader
from starrygl.sample.part_utils.partition_tgnn import partition_load
from starrygl.sample.graph_core import DataSet, DistributedGraphStore, TemporalNeighborSampleGraph
from starrygl.sample.memory.shared_mailbox import SharedMailBox
from starrygl.sample.sample_core.neighbor_sampler import NeighborSampler
from starrygl.sample.sample_core.base import NegativeSampling
from starrygl.sample.batch_data import SAMPLE_TYPE
pdata = partition_load("PATH/{}".format(dataname), algo="metis_for_tgnn")
graph = DistributedGraphStore(pdata = pdata, uvm_edge = False, uvm_node = False)
sample_graph = TemporalNeighborSampleGraph(sample_graph = pdata.sample_graph,mode = 'full')
mailbox = SharedMailBox(pdata.ids.shape[0], memory_param, dim_edge_feat=pdata.edge_attr.shape[1] if pdata. edge_attr is not None else 0)
sampler = NeighborSampler(num_nodes=graph.num_nodes, num_layers=1, fanout=[10], graph_data=sample_graph, workers=15,policy = 'recent',graph_name = "wiki_train")
neg_sampler = NegativeSampling('triplet')
train_data = torch.masked_select(graph.edge_index, pdata.train_mask.to(graph.edge_index.device)).reshape (2, -1)
trainloader = DistributedDataLoader(graph, train_data, sampler=sampler, sampler_fn=SAMPLE_TYPE. SAMPLE_FROM_TEMPORAL_EDGES,neg_sampler=neg_sampler, batch_size=1000, shuffle=False, drop_last=True, chunk_size = None,train=True, mailbox=mailbox )
In the data loader, we will call the `graph_sample`, sourced from `starrygl.sample.batch_data`.
And the `to_block` function in the `graph_sample` will implement feature fetching.
If cache is not used, we will directly fetch node or edge features from the graph data,
otherwise we will call `fetch_data` for feature fetching.
'''
def __init__(
self,
graph,
dataset = None,
sampler = None,
sampler_fn = None,
neg_sampler = None,
batch_size: Optional[int]=None,
drop_last = False,
device: torch.device = torch.device('cuda'),
shuffle:bool = True,
chunk_size = None,
train = False,
queue_size = 10,
mailbox = None,
is_pipeline = False,
**kwargs
):
assert sampler is not None
self.chunk_size = chunk_size
self.batch_size = batch_size
self.queue_size = queue_size
self.num_pending = 0
self.current_pos = 0
self.recv_idxs = 0
self.drop_last = drop_last
self.result_queue = deque(maxlen = self.queue_size)
self.shuffle = shuffle
self.is_closed = False
self.sampler = sampler
self.sampler_fn = sampler_fn
self.neg_sampler = neg_sampler
self.graph = graph
self.shuffle=shuffle
self.dataset = dataset
self.mailbox = mailbox
self.device = device
self.is_pipeline = is_pipeline
if train is True:
self._get_expected_idx(self.dataset.len)
else:
self._get_expected_idx(self.dataset.len,op = dist.ReduceOp.MAX)
#self.expected_idx = int(math.ceil(self.dataset.len/self.batch_size))
torch.distributed.barrier()
def __iter__(self):
if self.chunk_size is None:
if self.shuffle:
self.input_dataset = self.dataset.shuffle()
else:
self.input_dataset = self.dataset
self.recv_idxs = 0
self.current_pos = 0
self.num_pending = 0
self.submitted = 0
else:
self.input_dataset = self.dataset
self.recv_idxs = 0
self.num_pending = 0
self.submitted = 0
if dist.get_rank == 0:
self.current_pos = int(
math.floor(
np.random.uniform(0,self.batch_size/self.chunk_size)
)*self.chunk_size
)
else:
self.current_pos = 0
current_pos = torch.tensor([self.current_pos],dtype = torch.long,device=self.device)
dist.broadcast(current_pos, src = 0)
self.current_pos = int(current_pos.item())
self._get_expected_idx(self.dataset.len-self.current_pos)
if self.neg_sampler is not None \
and isinstance(self.neg_sampler,PreNegativeSampling):
self.neg_sampler.set_next_pos(self.current_pos)
return self
def _get_expected_idx(self,data_size,op = dist.ReduceOp.MIN):
world_size = dist.get_world_size()
self.expected_idx = data_size // self.batch_size if self.drop_last is True else int(math.ceil(data_size/self.batch_size))
if dist.get_world_size() > 1:
num_epochs = torch.tensor([self.expected_idx],dtype = torch.long,device=self.device)
print(num_epochs)
dist.all_reduce(num_epochs, op=op)
self.expected_idx = int(num_epochs.item())
def _next_data(self):
if self.current_pos >= self.dataset.len:
return self.input_dataset._get_empty()
if self.current_pos + self.batch_size > self.input_dataset.len:
if self.drop_last:
return None
else:
next_data = self.input_dataset.get_next(
slice(self.current_pos,None,None)
)
self.current_pos = 0
else:
next_data = self.input_dataset.get_next(
slice(self.current_pos,self.current_pos + self.batch_size,None)
)
self.current_pos += self.batch_size
return next_data
def __next__(self):
if self.is_pipeline is False:
if self.recv_idxs < self.expected_idx:
data = self._next_data()
batch_data = graph_sample(self.graph,
self.sampler,
self.sampler_fn,
data,self.neg_sampler,
self.mailbox,
self.device)
self.recv_idxs += 1
assert batch_data is not None
return batch_data
else :
raise StopIteration
else:
if self.recv_idxs == 0:
data = self._next_data()
batch_data = graph_sample(self.graph,
self.sampler,
self.sampler_fn,
data,self.neg_sampler,
self.mailbox,
self.device)
self.recv_idxs += 1
else:
if(self.recv_idxs < self.expected_idx):
assert len(self.result_queue) > 0
result= self.result_queue[0]
self.result_queue.popleft()
batch_data = result.result()
self.recv_idxs += 1
else:
raise StopIteration
if(self.recv_idxs+1<=self.expected_idx):
data = self._next_data()
next_batch = graph_sample(self.graph,
self.sampler,
self.sampler_fn,
data,self.neg_sampler,
self.mailbox,
self.device,
async_op=True)
self.result_queue.append(next_batch)
return batch_data
import starrygl
from starrygl.distributed.context import DistributedContext
from starrygl.distributed.utils import DistIndex, DistributedTensor
from starrygl.sample.graph_core.utils import build_mapper
import os.path as osp
import torch
import torch.distributed as dist
from torch_geometric.data import Data
class DistributedGraphStore:
'''
Initializes the DistributedGraphStore with distributed graph data.
Args:
pdata: Graph data object containing ids, eids, edge_index, edge_ts, sample_graph, x, and edge_attr.
device: Device to which tensors are moved (default is 'cuda').
uvm_node: If True, enables Unified Virtual Memory (UVM) for node data.
uvm_edge: If True, enables Unified Virtual Memory (UVM) for edge data.
'''
def __init__(self, pdata, device = torch.device('cuda'),
uvm_node = False,
uvm_edge = False):
self.device = device
self.ids = pdata.ids.to(device)
self.eids = pdata.eids
self.edge_index = pdata.edge_index.to(device)
if hasattr(pdata,'edge_ts'):
self.edge_ts = pdata.edge_ts.to(device).to(torch.float)
else:
self.edge_ts = None
self.sample_graph = pdata.sample_graph
self.nids_mapper = build_mapper(nids=pdata.ids.to(device)).dist.to('cpu')
self.eids_mapper = build_mapper(nids=pdata.eids.to(device)).dist.to('cpu')
torch.cuda.empty_cache()
self.num_nodes = self.nids_mapper.data.shape[0]
self.num_edges = self.eids_mapper.data.shape[0]
world_size = dist.get_world_size()
self.uvm_node = uvm_node
self.uvm_edge = uvm_edge
if hasattr(pdata,'x') and pdata.x is not None:
pdata.x = pdata.x.to(torch.float)
if uvm_node == False :
x = pdata.x.to(self.device)
else:
if self.device.type == 'cuda':
x = starrygl.utils.uvm.uvm_empty(*pdata.x.size(),
dtype=pdata.x.dtype,
device=ctx.device)
starrygl.utils.uvm.uvm_share(x,device = ctx.device)
starrygl.utils.uvm.uvm_advise(x,starrygl.utils.uvm.cudaMemoryAdvise.cudaMemAdviseSetAccessedBy)
starrygl.utils.uvm.uvm_prefetch(x)
if world_size > 1:
self.x = DistributedTensor(pdata.x.to(self.device).to(torch.float))
else:
self.x = x
else:
self.x = None
if hasattr(pdata,'edge_attr') and pdata.edge_attr is not None:
ctx = DistributedContext.get_default_context()
pdata.edge_attr = pdata.edge_attr.to(torch.float)
if uvm_edge == False :
edge_attr = pdata.edge_attr.to(self.device)
else:
if self.device.type == 'cuda':
edge_attr = starrygl.utils.uvm.uvm_empty(*pdata.edge_attr.size(),
dtype=pdata.edge_attr.dtype,
device=ctx.device)
starrygl.utils.uvm.uvm_share(edge_attr,device = ctx.device)
starrygl.utils.uvm.uvm_advise(edge_attr,starrygl.utils.uvm.cudaMemoryAdvise.cudaMemAdviseSetAccessedBy)
starrygl.utils.uvm.uvm_prefetch(edge_attr)
if world_size > 1:
self.edge_attr = DistributedTensor(edge_attr)
else:
self.edge_attr = edge_attr
else:
self.edge_attr = None
def _get_node_attr(self,ids,asyncOp = False):
'''
Retrieves node attributes for the specified node IDs.
Args:
ids: Node IDs for which to retrieve attributes.
asyncOp: If True, performs asynchronous operation for distributed data.
'''
if self.x is None:
return None
elif dist.get_world_size() == 1:
return self.x[ids]
else:
if self.x.rrefs is None or asyncOp is False:
ids = self.x.all_to_all_ind2ptr(ids)
return self.x.all_to_all_get(**ids)
return self.x.index_select(ids)
def _get_edge_attr(self,ids,asyncOp = False):
'''
Retrieves edge attributes for the specified edge IDs.
Args:
ids: Edge IDs for which to retrieve attributes.
asyncOp: If True, performs asynchronous operation for distributed data.
'''
if self.edge_attr is None:
return None
elif dist.get_world_size() == 1:
return self.edge_attr[ids]
else:
if self.edge_attr.rrefs is None or asyncOp is False:
ids = self.edge_attr.all_to_all_ind2ptr(ids)
return self.edge_attr.all_to_all_get(**ids)
return self.edge_attr.index_select(ids)
def _get_dist_index(self,ind,mapper):
'''
Retrieves the distributed index for the specified local index using the provided mapper.
Args:
ind: Local index for which to retrieve the distributed index.
mapper: Mapper providing the distributed index.
'''
return mapper[ind.to(mapper.device)]
class DataSet:
'''
Args:
nodes: Tensor representing nodes. If not None, it is moved to the specified device.
edges: Tensor representing edges. If not None, it is moved to the specified device.
labels: Optional parameter for labels.
ts: Tensor representing timestamps. If not None, it is moved to the specified device.
device: Device to which tensors are moved (default is 'cuda').
'''
def __init__(self,nodes = None,
edges = None,
labels = None,
ts = None,
device = torch.device('cuda'),**kwargs):
if nodes is not None:
self.nodes = nodes.to(device)
if edges is not None:
self.edges = edges.to(device)
if ts is not None:
self.ts = ts.to(device)
if labels is not None:
self.labels = labels
self.len = self.nodes.shape[0] if nodes is not None else self.edges.shape[1]
for k, v in kwargs.items():
assert isinstance(v,torch.Tensor) and v.shape[0]==self.len
setattr(self, k, v.to(device))
def _get_empty(self):
'''
Creates an empty dataset with the same device and data types as the current instance.
'''
nodes = torch.empty([],dtype = self.nodes.dtype,device= self.nodes.device)if hasattr(self,'nodes') else None
edges = torch.empty([[],[]],dtype = self.edges.dtype,device= self.edge.device)if hasattr(self,'edges') else None
d = DataSet(nodes,edges)
for k,v in self.__dict__.items():
if k == 'edges' or k=='nodes' or k == 'len':
continue
else:
setattr(d,k,torch.empty([]))
return d
#@staticmethod
def get_next(self,indx):
'''
Retrieves the next dataset based on the provided index.
Args:
indx: Index specifying the dataset to retrieve.
'''
nodes = self.nodes[indx] if hasattr(self,'nodes') else None
edges = self.edges[:,indx] if hasattr(self,'edges') else None
d = DataSet(nodes,edges)
for k,v in self.__dict__.items():
if k == 'edges' or k=='nodes' or k == 'len':
continue
else:
setattr(d,k,v[indx])
return d
#@staticmethod
def shuffle(self):
'''
Shuffles the dataset and returns a new dataset with the same attributes.
'''
indx = torch.randperm(self.len)
nodes = self.nodes[indx] if hasattr(self,'nodes') else None
edges = self.edges[:,indx] if hasattr(self,'edges') else None
d = DataSet(nodes,edges)
for k,v in self.__dict__.items():
if k == 'edges' or k=='nodes' or k == 'len':
continue
else:
setattr(d,k,v[indx])
return d
class TemporalGraphData(DistributedGraphStore):
def __init__(self,pdata,device):
super(DistributedTempoGraphData,self).__init__(pdata,device)
def _set_temporal_batch_cache(self,size,pin_size):
pass
def _load_feature_to_cuda(self,ids):
pass
class TemporalNeighborSampleGraph(DistributedGraphStore):
'''
Args:
sample_graph: A dictionary containing graph structure information, including 'edge_index', 'ts' (edge timestamp), and 'eids' (edge identifiers).
mode: Specifies the dataset mode ('train', 'val', 'test', or 'full').
eids_mapper: Optional parameter for edge identifiers mapping.
'''
def __init__(self, sample_graph=None, mode='full', eids_mapper=None):
self.edge_index = sample_graph['edge_index']
self.num_edges = self.edge_index.shape[1]
if 'ts' in sample_graph:
self.edge_ts = sample_graph['ts']
else:
self.edge_ts = None
self.eid = sample_graph['eids']
if mode == 'train':
mask = sample_graph['train_mask']
if mode == 'val':
mask = sample_graph['val_mask']
if mode == 'test':
mask = sample_graph['test_mask']
if mode != 'full':
self.edge_index = self.edge_index[:, mask]
self.edge_ts = self.edge_ts[mask]
self.eid = self.eid[mask]
import starrygl
from starrygl.distributed.context import DistributedContext
from starrygl.distributed.utils import DistIndex, DistributedTensor
from starrygl.sample.graph_core.utils import build_mapper
import os.path as osp
import torch
import torch.distributed as dist
from torch_geometric.data import Data
class DistributedGraphStore:
'''
Initializes the DistributedGraphStore with distributed graph data.
Args:
pdata: Graph data object containing ids, eids, edge_index, edge_ts, sample_graph, x, and edge_attr.
device: Device to which tensors are moved (default is 'cuda').
uvm_node: If True, enables Unified Virtual Memory (UVM) for node data.
uvm_edge: If True, enables Unified Virtual Memory (UVM) for edge data.
'''
def __init__(self, pdata, device = torch.device('cuda'),
uvm_node = False,
uvm_edge = False):
self.device = device
self.ids = pdata.ids.to(device)
self.eids = pdata.eids
self.edge_index = pdata.edge_index.to(device)
if hasattr(pdata,'edge_ts'):
self.edge_ts = pdata.edge_ts.to(device).to(torch.float)
else:
self.edge_ts = None
self.sample_graph = pdata.sample_graph
self.nids_mapper = build_mapper(nids=pdata.ids.to(device)).dist.to('cpu')
self.eids_mapper = build_mapper(nids=pdata.eids.to(device)).dist.to('cpu')
torch.cuda.empty_cache()
self.num_nodes = self.nids_mapper.data.shape[0]
self.num_edges = self.eids_mapper.data.shape[0]
world_size = dist.get_world_size()
self.uvm_node = uvm_node
self.uvm_edge = uvm_edge
if hasattr(pdata,'x') and pdata.x is not None:
pdata.x = pdata.x.to(torch.float)
if uvm_node == False :
x = pdata.x.to(self.device)
else:
if self.device.type == 'cuda':
x = starrygl.utils.uvm.uvm_empty(*pdata.x.size(),
dtype=pdata.x.dtype,
device=ctx.device)
starrygl.utils.uvm.uvm_share(x,device = ctx.device)
starrygl.utils.uvm.uvm_advise(x,starrygl.utils.uvm.cudaMemoryAdvise.cudaMemAdviseSetAccessedBy)
starrygl.utils.uvm.uvm_prefetch(x)
if world_size > 1:
self.x = DistributedTensor(pdata.x.to(self.device).to(torch.float))
else:
self.x = x
else:
self.x = None
if hasattr(pdata,'edge_attr') and pdata.edge_attr is not None:
ctx = DistributedContext.get_default_context()
pdata.edge_attr = pdata.edge_attr.to(torch.float)
if uvm_edge == False :
edge_attr = pdata.edge_attr.to(self.device)
else:
if self.device.type == 'cuda':
edge_attr = starrygl.utils.uvm.uvm_empty(*pdata.edge_attr.size(),
dtype=pdata.edge_attr.dtype,
device=ctx.device)
starrygl.utils.uvm.uvm_share(edge_attr,device = ctx.device)
starrygl.utils.uvm.uvm_advise(edge_attr,starrygl.utils.uvm.cudaMemoryAdvise.cudaMemAdviseSetAccessedBy)
starrygl.utils.uvm.uvm_prefetch(edge_attr)
if world_size > 1:
self.edge_attr = DistributedTensor(edge_attr)
else:
self.edge_attr = edge_attr
else:
self.edge_attr = None
def _get_node_attr(self,ids,asyncOp = False):
'''
Retrieves node attributes for the specified node IDs.
Args:
ids: Node IDs for which to retrieve attributes.
asyncOp: If True, performs asynchronous operation for distributed data.
'''
if self.x is None:
return None
elif dist.get_world_size() == 1:
return self.x[ids]
else:
if self.x.rrefs is None or asyncOp is False:
ids = self.x.all_to_all_ind2ptr(ids)
return self.x.all_to_all_get(**ids)
return self.x.index_select(ids)
def _get_edge_attr(self,ids,asyncOp = False):
'''
Retrieves edge attributes for the specified edge IDs.
Args:
ids: Edge IDs for which to retrieve attributes.
asyncOp: If True, performs asynchronous operation for distributed data.
'''
if self.edge_attr is None:
return None
elif dist.get_world_size() == 1:
return self.edge_attr[ids]
else:
if self.edge_attr.rrefs is None or asyncOp is False:
ids = self.edge_attr.all_to_all_ind2ptr(ids)
return self.edge_attr.all_to_all_get(**ids)
return self.edge_attr.index_select(ids)
def _get_dist_index(self,ind,mapper):
'''
Retrieves the distributed index for the specified local index using the provided mapper.
Args:
ind: Local index for which to retrieve the distributed index.
mapper: Mapper providing the distributed index.
'''
return mapper[ind.to(mapper.device)]
class DataSet:
'''
Args:
nodes: Tensor representing nodes. If not None, it is moved to the specified device.
edges: Tensor representing edges. If not None, it is moved to the specified device.
labels: Optional parameter for labels.
ts: Tensor representing timestamps. If not None, it is moved to the specified device.
device: Device to which tensors are moved (default is 'cuda').
'''
def __init__(self,nodes = None,
edges = None,
labels = None,
ts = None,
device = torch.device('cuda'),**kwargs):
if nodes is not None:
self.nodes = nodes.to(device)
if edges is not None:
self.edges = edges.to(device)
if ts is not None:
self.ts = ts.to(device)
if labels is not None:
self.labels = labels
self.len = self.nodes.shape[0] if nodes is not None else self.edges.shape[1]
for k, v in kwargs.items():
assert isinstance(v,torch.Tensor) and v.shape[0]==self.len
setattr(self, k, v.to(device))
def _get_empty(self):
'''
Creates an empty dataset with the same device and data types as the current instance.
'''
nodes = torch.empty([],dtype = self.nodes.dtype,device= self.nodes.device)if hasattr(self,'nodes') else None
edges = torch.empty([[],[]],dtype = self.edges.dtype,device= self.edge.device)if hasattr(self,'edges') else None
d = DataSet(nodes,edges)
for k,v in self.__dict__.items():
if k == 'edges' or k=='nodes' or k == 'len':
continue
else:
setattr(d,k,torch.empty([]))
return d
#@staticmethod
def get_next(self,indx):
'''
Retrieves the next dataset based on the provided index.
Args:
indx: Index specifying the dataset to retrieve.
'''
nodes = self.nodes[indx] if hasattr(self,'nodes') else None
edges = self.edges[:,indx] if hasattr(self,'edges') else None
d = DataSet(nodes,edges)
for k,v in self.__dict__.items():
if k == 'edges' or k=='nodes' or k == 'len':
continue
else:
setattr(d,k,v[indx])
return d
#@staticmethod
def shuffle(self):
'''
Shuffles the dataset and returns a new dataset with the same attributes.
'''
indx = torch.randperm(self.len)
nodes = self.nodes[indx] if hasattr(self,'nodes') else None
edges = self.edges[:,indx] if hasattr(self,'edges') else None
d = DataSet(nodes,edges)
for k,v in self.__dict__.items():
if k == 'edges' or k=='nodes' or k == 'len':
continue
else:
setattr(d,k,v[indx])
return d
class TemporalGraphData(DistributedGraphStore):
def __init__(self,pdata,device):
super(DistributedGraphData,self).__init__(pdata,device)
def _set_temporal_batch_cache(self,size,pin_size):
pass
def _load_feature_to_cuda(self,ids):
pass
class TemporalNeighborSampleGraph(DistributedGraphStore):
'''
Args:
sample_graph: A dictionary containing graph structure information, including 'edge_index', 'ts' (edge timestamp), and 'eids' (edge identifiers).
mode: Specifies the dataset mode ('train', 'val', 'test', or 'full').
eids_mapper: Optional parameter for edge identifiers mapping.
'''
def __init__(self, sample_graph=None, mode='full', eids_mapper=None):
self.edge_index = sample_graph['edge_index']
self.num_edges = self.edge_index.shape[1]
if 'ts' in sample_graph:
self.edge_ts = sample_graph['ts']
else:
self.edge_ts = None
self.eid = sample_graph['eids']
if mode == 'train':
mask = sample_graph['train_mask']
if mode == 'val':
mask = sample_graph['val_mask']
if mode == 'test':
mask = sample_graph['test_mask']
if mode != 'full':
self.edge_index = self.edge_index[:, mask]
self.edge_ts = self.edge_ts[mask]
self.eid = self.eid[mask]
import starrygl
from starrygl.distributed.context import DistributedContext
from starrygl.distributed.utils import DistIndex, DistributedTensor
from starrygl.sample.graph_core.utils import build_mapper
import os.path as osp
import torch
import torch.distributed as dist
from torch_geometric.data import Data
class DistributedGraphStore:
'''
Initializes the DistributedGraphStore with distributed graph data.
Args:
pdata: Graph data object containing ids, eids, edge_index, edge_ts, sample_graph, x, and edge_attr.
device: Device to which tensors are moved (default is 'cuda').
uvm_node: If True, enables Unified Virtual Memory (UVM) for node data.
uvm_edge: If True, enables Unified Virtual Memory (UVM) for edge data.
'''
def __init__(self, pdata, device = torch.device('cuda'),
uvm_node = False,
uvm_edge = False):
self.device = device
self.ids = pdata.ids.to(device)
self.eids = pdata.eids
self.edge_index = pdata.edge_index.to(device)
if hasattr(pdata,'edge_ts'):
self.edge_ts = pdata.edge_ts.to(device).to(torch.float)
else:
self.edge_ts = None
self.sample_graph = pdata.sample_graph
self.nids_mapper = build_mapper(nids=pdata.ids.to(device)).dist.to('cpu')
self.eids_mapper = build_mapper(nids=pdata.eids.to(device)).dist.to('cpu')
torch.cuda.empty_cache()
self.num_nodes = self.nids_mapper.data.shape[0]
self.num_edges = self.eids_mapper.data.shape[0]
world_size = dist.get_world_size()
self.uvm_node = uvm_node
self.uvm_edge = uvm_edge
if hasattr(pdata,'x') and pdata.x is not None:
pdata.x = pdata.x.to(torch.float)
if uvm_node == False :
x = pdata.x.to(self.device)
else:
if self.device.type == 'cuda':
x = starrygl.utils.uvm.uvm_empty(*pdata.x.size(),
dtype=pdata.x.dtype,
device=ctx.device)
starrygl.utils.uvm.uvm_share(x,device = ctx.device)
starrygl.utils.uvm.uvm_advise(x,starrygl.utils.uvm.cudaMemoryAdvise.cudaMemAdviseSetAccessedBy)
starrygl.utils.uvm.uvm_prefetch(x)
if world_size > 1:
self.x = DistributedTensor(pdata.x.to(self.device).to(torch.float))
else:
self.x = x
else:
self.x = None
if hasattr(pdata,'edge_attr') and pdata.edge_attr is not None:
ctx = DistributedContext.get_default_context()
pdata.edge_attr = pdata.edge_attr.to(torch.float)
if uvm_edge == False :
edge_attr = pdata.edge_attr.to(self.device)
else:
if self.device.type == 'cuda':
edge_attr = starrygl.utils.uvm.uvm_empty(*pdata.edge_attr.size(),
dtype=pdata.edge_attr.dtype,
device=ctx.device)
starrygl.utils.uvm.uvm_share(edge_attr,device = ctx.device)
starrygl.utils.uvm.uvm_advise(edge_attr,starrygl.utils.uvm.cudaMemoryAdvise.cudaMemAdviseSetAccessedBy)
starrygl.utils.uvm.uvm_prefetch(edge_attr)
if world_size > 1:
self.edge_attr = DistributedTensor(edge_attr)
else:
self.edge_attr = edge_attr
else:
self.edge_attr = None
def _get_node_attr(self,ids,asyncOp = False):
'''
Retrieves node attributes for the specified node IDs.
Args:
ids: Node IDs for which to retrieve attributes.
asyncOp: If True, performs asynchronous operation for distributed data.
'''
if self.x is None:
return None
elif dist.get_world_size() == 1:
return self.x[ids]
else:
if self.x.rrefs is None or asyncOp is False:
ids = self.x.all_to_all_ind2ptr(ids)
return self.x.all_to_all_get(**ids)
return self.x.index_select(ids)
def _get_edge_attr(self,ids,asyncOp = False):
'''
Retrieves edge attributes for the specified edge IDs.
Args:
ids: Edge IDs for which to retrieve attributes.
asyncOp: If True, performs asynchronous operation for distributed data.
'''
if self.edge_attr is None:
return None
elif dist.get_world_size() == 1:
return self.edge_attr[ids]
else:
if self.edge_attr.rrefs is None or asyncOp is False:
ids = self.edge_attr.all_to_all_ind2ptr(ids)
return self.edge_attr.all_to_all_get(**ids)
return self.edge_attr.index_select(ids)
def _get_dist_index(self,ind,mapper):
'''
Retrieves the distributed index for the specified local index using the provided mapper.
Args:
ind: Local index for which to retrieve the distributed index.
mapper: Mapper providing the distributed index.
'''
return mapper[ind.to(mapper.device)]
class DataSet:
'''
Args:
nodes: Tensor representing nodes. If not None, it is moved to the specified device.
edges: Tensor representing edges. If not None, it is moved to the specified device.
labels: Optional parameter for labels.
ts: Tensor representing timestamps. If not None, it is moved to the specified device.
device: Device to which tensors are moved (default is 'cuda').
'''
def __init__(self,nodes = None,
edges = None,
labels = None,
ts = None,
device = torch.device('cuda'),**kwargs):
if nodes is not None:
self.nodes = nodes.to(device)
if edges is not None:
self.edges = edges.to(device)
if ts is not None:
self.ts = ts.to(device)
if labels is not None:
self.labels = labels
self.len = self.nodes.shape[0] if nodes is not None else self.edges.shape[1]
for k, v in kwargs.items():
assert isinstance(v,torch.Tensor) and v.shape[0]==self.len
setattr(self, k, v.to(device))
def _get_empty(self):
'''
Creates an empty dataset with the same device and data types as the current instance.
'''
nodes = torch.empty([],dtype = self.nodes.dtype,device= self.nodes.device)if hasattr(self,'nodes') else None
edges = torch.empty([[],[]],dtype = self.edges.dtype,device= self.edge.device)if hasattr(self,'edges') else None
d = DataSet(nodes,edges)
for k,v in self.__dict__.items():
if k == 'edges' or k=='nodes' or k == 'len':
continue
else:
setattr(d,k,torch.empty([]))
return d
#@staticmethod
def get_next(self,indx):
'''
Retrieves the next dataset based on the provided index.
Args:
indx: Index specifying the dataset to retrieve.
'''
nodes = self.nodes[indx] if hasattr(self,'nodes') else None
edges = self.edges[:,indx] if hasattr(self,'edges') else None
d = DataSet(nodes,edges)
for k,v in self.__dict__.items():
if k == 'edges' or k=='nodes' or k == 'len':
continue
else:
setattr(d,k,v[indx])
return d
#@staticmethod
def shuffle(self):
'''
Shuffles the dataset and returns a new dataset with the same attributes.
'''
indx = torch.randperm(self.len)
nodes = self.nodes[indx] if hasattr(self,'nodes') else None
edges = self.edges[:,indx] if hasattr(self,'edges') else None
d = DataSet(nodes,edges)
for k,v in self.__dict__.items():
if k == 'edges' or k=='nodes' or k == 'len':
continue
else:
setattr(d,k,v[indx])
return d
class TemporalGraphData(DistributedGraphStore):
def __init__(self,pdata,device):
super(DistributedGraphStore,self).__init__(pdata,device)
def _set_temporal_batch_cache(self,size,pin_size):
pass
def _load_feature_to_cuda(self,ids):
pass
class TemporalNeighborSampleGraph(DistributedGraphStore):
'''
Args:
sample_graph: A dictionary containing graph structure information, including 'edge_index', 'ts' (edge timestamp), and 'eids' (edge identifiers).
mode: Specifies the dataset mode ('train', 'val', 'test', or 'full').
eids_mapper: Optional parameter for edge identifiers mapping.
'''
def __init__(self, sample_graph=None, mode='full', eids_mapper=None):
self.edge_index = sample_graph['edge_index']
self.num_edges = self.edge_index.shape[1]
if 'ts' in sample_graph:
self.edge_ts = sample_graph['ts']
else:
self.edge_ts = None
self.eid = sample_graph['eids']
if mode == 'train':
mask = sample_graph['train_mask']
if mode == 'val':
mask = sample_graph['val_mask']
if mode == 'test':
mask = sample_graph['test_mask']
if mode != 'full':
self.edge_index = self.edge_index[:, mask]
self.edge_ts = self.edge_ts[mask]
self.eid = self.eid[mask]
import starrygl
from starrygl.distributed.context import DistributedContext
from starrygl.distributed.utils import DistIndex, DistributedTensor
from starrygl.sample.graph_core.utils import build_mapper
import os.path as osp
import torch
import torch.distributed as dist
from torch_geometric.data import Data
class DistributedGraphStore:
'''
Initializes the DistributedGraphStore with distributed graph data.
Args:
pdata: Graph data object containing ids, eids, edge_index, edge_ts, sample_graph, x, and edge_attr.
device: Device to which tensors are moved (default is 'cuda').
uvm_node: If True, enables Unified Virtual Memory (UVM) for node data.
uvm_edge: If True, enables Unified Virtual Memory (UVM) for edge data.
'''
def __init__(self, pdata, device = torch.device('cuda'),
uvm_node = False,
uvm_edge = False):
self.device = device
self.ids = pdata.ids.to(device)
self.eids = pdata.eids
self.edge_index = pdata.edge_index.to(device)
if hasattr(pdata,'edge_ts'):
self.edge_ts = pdata.edge_ts.to(device).to(torch.float)
else:
self.edge_ts = None
self.sample_graph = pdata.sample_graph
self.nids_mapper = build_mapper(nids=pdata.ids.to(device)).dist.to('cpu')
self.eids_mapper = build_mapper(nids=pdata.eids.to(device)).dist.to('cpu')
torch.cuda.empty_cache()
self.num_nodes = self.nids_mapper.data.shape[0]
self.num_edges = self.eids_mapper.data.shape[0]
world_size = dist.get_world_size()
self.uvm_node = uvm_node
self.uvm_edge = uvm_edge
if hasattr(pdata,'x') and pdata.x is not None:
pdata.x = pdata.x.to(torch.float)
if uvm_node == False :
x = pdata.x.to(self.device)
else:
if self.device.type == 'cuda':
x = starrygl.utils.uvm.uvm_empty(*pdata.x.size(),
dtype=pdata.x.dtype,
device=ctx.device)
starrygl.utils.uvm.uvm_share(x,device = ctx.device)
starrygl.utils.uvm.uvm_advise(x,starrygl.utils.uvm.cudaMemoryAdvise.cudaMemAdviseSetAccessedBy)
starrygl.utils.uvm.uvm_prefetch(x)
if world_size > 1:
self.x = DistributedTensor(pdata.x.to(self.device).to(torch.float))
else:
self.x = x
else:
self.x = None
if hasattr(pdata,'edge_attr') and pdata.edge_attr is not None:
ctx = DistributedContext.get_default_context()
pdata.edge_attr = pdata.edge_attr.to(torch.float)
if uvm_edge == False :
edge_attr = pdata.edge_attr.to(self.device)
else:
if self.device.type == 'cuda':
edge_attr = starrygl.utils.uvm.uvm_empty(*pdata.edge_attr.size(),
dtype=pdata.edge_attr.dtype,
device=ctx.device)
starrygl.utils.uvm.uvm_share(edge_attr,device = ctx.device)
starrygl.utils.uvm.uvm_advise(edge_attr,starrygl.utils.uvm.cudaMemoryAdvise.cudaMemAdviseSetAccessedBy)
starrygl.utils.uvm.uvm_prefetch(edge_attr)
if world_size > 1:
self.edge_attr = DistributedTensor(edge_attr)
else:
self.edge_attr = edge_attr
else:
self.edge_attr = None
def _get_node_attr(self,ids,asyncOp = False):
'''
Retrieves node attributes for the specified node IDs.
Args:
ids: Node IDs for which to retrieve attributes.
asyncOp: If True, performs asynchronous operation for distributed data.
'''
if self.x is None:
return None
elif dist.get_world_size() == 1:
return self.x[ids]
else:
if self.x.rrefs is None or asyncOp is False:
ids = self.x.all_to_all_ind2ptr(ids)
return self.x.all_to_all_get(**ids)
return self.x.index_select(ids)
def _get_edge_attr(self,ids,asyncOp = False):
'''
Retrieves edge attributes for the specified edge IDs.
Args:
ids: Edge IDs for which to retrieve attributes.
asyncOp: If True, performs asynchronous operation for distributed data.
'''
if self.edge_attr is None:
return None
elif dist.get_world_size() == 1:
return self.edge_attr[ids]
else:
if self.edge_attr.rrefs is None or asyncOp is False:
ids = self.edge_attr.all_to_all_ind2ptr(ids)
return self.edge_attr.all_to_all_get(**ids)
return self.edge_attr.index_select(ids)
def _get_dist_index(self,ind,mapper):
'''
Retrieves the distributed index for the specified local index using the provided mapper.
Args:
ind: Local index for which to retrieve the distributed index.
mapper: Mapper providing the distributed index.
'''
return mapper[ind.to(mapper.device)]
class DataSet:
'''
Args:
nodes: Tensor representing nodes. If not None, it is moved to the specified device.
edges: Tensor representing edges. If not None, it is moved to the specified device.
labels: Optional parameter for labels.
ts: Tensor representing timestamps. If not None, it is moved to the specified device.
device: Device to which tensors are moved (default is 'cuda').
'''
def __init__(self,nodes = None,
edges = None,
labels = None,
ts = None,
device = torch.device('cuda'),**kwargs):
if nodes is not None:
self.nodes = nodes.to(device)
if edges is not None:
self.edges = edges.to(device)
if ts is not None:
self.ts = ts.to(device)
if labels is not None:
self.labels = labels
self.len = self.nodes.shape[0] if nodes is not None else self.edges.shape[1]
for k, v in kwargs.items():
assert isinstance(v,torch.Tensor) and v.shape[0]==self.len
setattr(self, k, v.to(device))
def _get_empty(self):
'''
Creates an empty dataset with the same device and data types as the current instance.
'''
nodes = torch.empty([],dtype = self.nodes.dtype,device= self.nodes.device)if hasattr(self,'nodes') else None
edges = torch.empty([[],[]],dtype = self.edges.dtype,device= self.edge.device)if hasattr(self,'edges') else None
d = DataSet(nodes,edges)
for k,v in self.__dict__.items():
if k == 'edges' or k=='nodes' or k == 'len':
continue
else:
setattr(d,k,torch.empty([]))
return d
#@staticmethod
def get_next(self,indx):
'''
Retrieves the next dataset based on the provided index.
Args:
indx: Index specifying the dataset to retrieve.
'''
nodes = self.nodes[indx] if hasattr(self,'nodes') else None
edges = self.edges[:,indx] if hasattr(self,'edges') else None
d = DataSet(nodes,edges)
for k,v in self.__dict__.items():
if k == 'edges' or k=='nodes' or k == 'len':
continue
else:
setattr(d,k,v[indx])
return d
#@staticmethod
def shuffle(self):
'''
Shuffles the dataset and returns a new dataset with the same attributes.
'''
indx = torch.randperm(self.len)
nodes = self.nodes[indx] if hasattr(self,'nodes') else None
edges = self.edges[:,indx] if hasattr(self,'edges') else None
d = DataSet(nodes,edges)
for k,v in self.__dict__.items():
if k == 'edges' or k=='nodes' or k == 'len':
continue
else:
setattr(d,k,v[indx])
return d
class TemporalGraphData(DistributedGraphStore):
def __init__(self,pdata,device):
super(DistributedGraphStore,self).__init__(pdata,device)
def _set_temporal_batch_cache(self,size,pin_size):
pass
def _load_feature_to_cuda(self,ids):
pass
class TemporalNeighborSampleGraph(DistributedGraphStore):
'''
Args:
sample_graph: A dictionary containing graph structure information, including 'edge_index', 'ts' (edge timestamp), and 'eids' (edge identifiers).
mode: Specifies the dataset mode ('train', 'val', 'test', or 'full').
eids_mapper: Optional parameter for edge identifiers mapping.
'''
def __init__(self, sample_graph=None, mode='full', eids_mapper=None):
self.edge_index = sample_graph['edge_index']
self.num_edges = self.edge_index.shape[1]
if 'ts' in sample_graph:
self.edge_ts = sample_graph['ts']
else:
self.edge_ts = None
self.eid = sample_graph['eids']
if mode == 'train':
mask = sample_graph['train_mask']
if mode == 'val':
mask = sample_graph['val_mask']
if mode == 'test':
mask = sample_graph['test_mask']
if mode != 'full':
self.edge_index = self.edge_index[:, mask]
self.edge_ts = self.edge_ts[mask]
self.eid = self.eid[mask]
import starrygl
from typing import Union
from typing import List
from typing import Optional
import torch
from torch.distributed import rpc
import torch_scatter
from starrygl.distributed.context import DistributedContext
from starrygl.distributed.utils import DistIndex, DistributedTensor
import torch.distributed as dist
#from starrygl.utils.uvm import cudaMemoryAdvise
class SharedMailBox():
'''
We will first define our mailbox, including our definitions of mialbox and memory:
.. code-block:: python
from starrygl.sample.memory.shared_mailbox import SharedMailBox
mailbox = SharedMailBox(num_nodes=num_nodes, memory_param=memory_param, dim_edge_feat=dim_edge_feat)
Args:
num_nodes (int): number of nodes
memory_param (dict): the memory parameters in the yaml file,refer to TGL
dim_edge_feat (int): the dim of edge feature
device (torch.device): the device used to store MailBox
uvm (bool): 1-use uvm, 0-don't use uvm
Examples:
.. code-block:: python
from starrygl.sample.part_utils.partition_tgnn import partition_load
from starrygl.sample.memory.shared_mailbox import SharedMailBox
pdata = partition_load("PATH/{}".format(dataname), algo="metis_for_tgnn")
mailbox = SharedMailBox(pdata.ids.shape[0], memory_param, dim_edge_feat=pdata.edge_attr.shape[1] if pdata.edge_attr is not None else 0)
We then need to hand over the mailbox to the data loader as in the above example, so that the relevant memory/mailbox can be directly loaded during training.
During the training, we will call `get_update_memory`/`get_update_mail` function constantly updates
the relevant storage,which is the idea related to TGN.
'''
def __init__(self,
num_nodes,
memory_param,
dim_edge_feat,
device = torch.device('cuda'),
uvm = False):
self.device = device
self.num_nodes = num_nodes
self.num_parts = dist.get_world_size()
if memory_param['type'] != 'node':
raise NotImplementedError
self.memory_param = memory_param
self.memory_size = memory_param['dim_out']
assert not (device.type =='cpu' and uvm is True),\
'set uvm must set device on cuda'
memory_device = device
if device.type == 'cuda' and uvm is True:
memory_device = torch.device('cpu')
node_memory = torch.zeros((
self.num_nodes, memory_param['dim_out']),
dtype=torch.float32,device =memory_device)
node_memory_ts = torch.zeros(self.num_nodes,
dtype=torch.float32,
device = self.device)
mailbox = torch.zeros(self.num_nodes,
memory_param['mailbox_size'],
2 * memory_param['dim_out'] + dim_edge_feat,
device = memory_device, dtype=torch.float32)
mailbox_ts = torch.zeros((self.num_nodes,
memory_param['mailbox_size']),
dtype=torch.float32,device = self.device)
self.uvm = uvm
if uvm is True:
ctx = DistributedContext.get_default_context()
node_memory = starrygl.utils.uvm.uvm_empty(*node_memory.shape,
dtype=node_memory.dtype,
device=ctx.device)
starrygl.utils.uvm.uvm_share(node_memory,device = ctx.device)
starrygl.utils.uvm.uvm_advise(node_memory,starrygl.utils.uvm.cudaMemoryAdvise.cudaMemAdviseSetAccessedBy)
starrygl.utils.uvm.uvm_prefetch(node_memory)
mailbox = starrygl.utils.uvm.uvm_empty(*mailbox.shape,
dtype=mailbox.dtype,
device=ctx.device)
starrygl.utils.uvm.uvm_share(mailbox,device = ctx.device)
starrygl.utils.uvm.vm_advise(mailbox,starrygl.utils.uvm.cudaMemoryAdvise.cudaMemAdviseSetAccessedBy)
starrygl.utils.uvm.uvm_prefetch(mailbox)
self.node_memory = DistributedTensor(node_memory)
self.node_memory_ts = DistributedTensor(node_memory_ts)
self.mailbox = DistributedTensor(mailbox)
self.mailbox_ts = DistributedTensor(mailbox_ts)
self.next_mail_pos = torch.zeros((self.num_nodes),
dtype=torch.long,
device = self.device)
self._ctx = DistributedContext.get_default_context()
if self._ctx._use_rpc == True:
self.rref = rpc.RRef(self)
self.rrefs = self._ctx.all_gather_remote_objects(self.rref)
self.partptr = torch.tensor([ ((i & 0xFFFF)<<48) for i in range(self.num_parts+1) ],device = device)
def reset(self):
self.node_memory.accessor.data.zero_()
self.node_memory_ts.accessor.data.zero_()
self.mailbox.accessor.data.zero_()
self.mailbox_ts.accessor.data.zero_()
self.next_mail_pos.zero_()
def set_memory_local(self,index,source,source_ts,Reduce_Op = None):
if Reduce_Op == 'max':
unq_id,inv = index.unique(return_inverse = True)
max_ts,id = torch_scatter.scatter_max(source_ts,inv,dim=0)
source_ts = max_ts
source = source[id]
index = unq_id
self.node_memory.accessor.data[index] = source
self.node_memory_ts.accessor.data[index] = source_ts
def set_mailbox_local(self,index,source,source_ts,Reduce_Op = None):
if Reduce_Op == 'max' and self.num_parts > 1:
unq_id,inv = index.unique(return_inverse = True)
max_ts,id = torch_scatter.scatter_max(source_ts,inv,dim=0)
source_ts = max_ts
source = source[id]
index = unq_id
self.mailbox_ts.accessor.data[index, self.next_mail_pos[index]] = source_ts
self.mailbox.accessor.data[index, self.next_mail_pos[index]] = source
if self.memory_param['mailbox_size'] > 1:
self.next_mail_pos[index] = torch.remainder(
self.next_mail_pos[index] + 1,
self.memory_param['mailbox_size'])
def set_memory_async(self,index,source,source_ts):
dist_index = DistIndex(index)
part_idx = dist_index.part
index = dist_index.loc
futs: List[torch.futures.Future] = []
if self.num_parts == 1:
self.set_memory_local(index,source,source_ts)
for i in range(self.num_parts):
fut = self.ctx.remote_call(
SharedMailBox.set_memory_local,
self.rrefs[i],
index[part_idx == i],
source[part_idx == i],
source_ts[part_idx == i])
futs.append(fut)
return torch.futures.collect_all(futs)
def add_to_mailbox_async(self,index,source,source_ts):
dist_index = DistIndex(index)
part_idx = dist_index.part
index = dist_index.loc
futs: List[torch.futures.Future] = []
if self.num_parts == 1:
self.set_mailbox_local(index,source,source_ts)
else:
for i in range(self.num_parts):
fut = self.ctx.remote_call(
SharedMailBox.set_mailbox_local,
self.rrefs[i],
index[part_idx == i],
source[part_idx == i],
source_ts[part_idx == i])
futs.append(fut)
return torch.futures.collect_all(futs)
def set_mailbox_all_to_all(self,index,memory,
memory_ts,mail,mail_ts,
reduce_Op = None,group = None):
#futs: List[torch.futures.Future] = []
if self.num_parts == 1:
dist_index = DistIndex(index)
part_idx = dist_index.part
index = dist_index.loc
self.set_mailbox_local(index,mail,mail_ts)
self.set_memory_local(index,memory,memory_ts)
else:
gather_len_list = torch.empty([self.num_parts],
dtype = int,
device = self.device)
indic = torch.searchsorted(index,self.partptr,right=False)
scatter_len_list = indic[1:] - indic[0:-1]
torch.distributed.all_to_all_single(gather_len_list,scatter_len_list,group = group)
input_split = scatter_len_list.tolist()
output_split = gather_len_list.tolist()
gather_id_list = torch.empty(
[gather_len_list.sum()],
dtype = torch.long,
device = self.device)
input_split = scatter_len_list.tolist()
output_split = gather_len_list.tolist()
torch.distributed.all_to_all_single(
gather_id_list,index,output_split_sizes=output_split,
input_split_sizes=input_split,group = group)
index = gather_id_list
gather_memory = torch.empty(
[gather_len_list.sum(),memory.shape[1]],
dtype = memory.dtype,device = self.device)
gather_memory_ts = torch.empty(
[gather_len_list.sum()],
dtype = memory_ts.dtype,device = self.device)
gather_mail = torch.empty(
[gather_len_list.sum(),mail.shape[1]],
dtype = mail.dtype,device = self.device)
gather_mail_ts = torch.empty(
[gather_len_list.sum()],
dtype = mail_ts.dtype,device = self.device)
torch.distributed.all_to_all_single(
gather_memory,memory,
output_split_sizes=output_split,
input_split_sizes=input_split,group = group)
torch.distributed.all_to_all_single(
gather_memory_ts,memory_ts,
output_split_sizes=output_split,
input_split_sizes=input_split,group = group)
torch.distributed.all_to_all_single(
gather_mail,mail,
output_split_sizes=output_split,
input_split_sizes=input_split,group = group)
torch.distributed.all_to_all_single(
gather_mail_ts,mail_ts,
output_split_sizes=output_split,
input_split_sizes=input_split,group = group)
self.set_mailbox_local(DistIndex(index).loc,gather_mail,gather_mail_ts,Reduce_Op = reduce_Op)
self.set_memory_local(DistIndex(index).loc,gather_memory,gather_memory_ts, Reduce_Op = reduce_Op)
def set_mailbox_all_to_all(self,index,memory,
memory_ts,mail,mail_ts,
reduce_Op = None,group = None):
#futs: List[torch.futures.Future] = []
if self.num_parts == 1:
dist_index = DistIndex(index)
part_idx = dist_index.part
index = dist_index.loc
self.set_mailbox_local(index,mail,mail_ts)
self.set_memory_local(index,memory,memory_ts)
else:
gather_len_list = torch.empty([self.num_parts],
dtype = int,
device = self.device)
indic = torch.searchsorted(index,self.partptr,right=False)
scatter_len_list = indic[1:] - indic[0:-1]
torch.distributed.all_to_all_single(gather_len_list,scatter_len_list,group = group)
input_split = scatter_len_list.tolist()
output_split = gather_len_list.tolist()
gather_id_list = torch.empty(
[gather_len_list.sum()],
dtype = torch.long,
device = self.device)
input_split = scatter_len_list.tolist()
output_split = gather_len_list.tolist()
torch.distributed.all_to_all_single(
gather_id_list,index,output_split_sizes=output_split,
input_split_sizes=input_split,group = group)
index = gather_id_list
gather_memory = torch.empty(
[gather_len_list.sum(),memory.shape[1]],
dtype = memory.dtype,device = self.device)
gather_memory_ts = torch.empty(
[gather_len_list.sum()],
dtype = memory_ts.dtype,device = self.device)
gather_mail = torch.empty(
[gather_len_list.sum(),mail.shape[1]],
dtype = mail.dtype,device = self.device)
gather_mail_ts = torch.empty(
[gather_len_list.sum()],
dtype = mail_ts.dtype,device = self.device)
torch.distributed.all_to_all_single(
gather_memory,memory,
output_split_sizes=output_split,
input_split_sizes=input_split,group = group)
torch.distributed.all_to_all_single(
gather_memory_ts,memory_ts,
output_split_sizes=output_split,
input_split_sizes=input_split,group = group)
torch.distributed.all_to_all_single(
gather_mail,mail,
output_split_sizes=output_split,
input_split_sizes=input_split,group = group)
torch.distributed.all_to_all_single(
gather_mail_ts,mail_ts,
output_split_sizes=output_split,
input_split_sizes=input_split,group = group)
self.set_mailbox_local(DistIndex(index).loc,gather_mail,gather_mail_ts,Reduce_Op = reduce_Op)
self.set_memory_local(DistIndex(index).loc,gather_memory,gather_memory_ts, Reduce_Op = reduce_Op)
def get_update_mail(self,dist_indx_mapper,
src,dst,ts,edge_feats,
memory,embedding=None,use_src_emb=False,use_dst_emb=False):
if edge_feats is not None:
edge_feats = edge_feats.to(self.device).to(self.mailbox.dtype)
src = src.to(self.device)
dst = dst.to(self.device)
index = torch.cat([src, dst]).reshape(-1)
index = dist_indx_mapper[index]
mem_src = memory[src]
mem_dst = memory[dst]
if embedding is not None:
emb_src = embedding[src]
emb_dst = embedding[dst]
src_mail = torch.cat([emb_src if use_src_emb else mem_src, emb_dst if use_dst_emb else mem_dst], dim=1)
dst_mail = torch.cat([emb_dst if use_src_emb else mem_dst, emb_src if use_dst_emb else mem_src], dim=1)
if edge_feats is not None:
src_mail = torch.cat([src_mail, edge_feats], dim=1)
dst_mail = torch.cat([dst_mail, edge_feats], dim=1)
mail = torch.cat([src_mail, dst_mail], dim=1).reshape(-1, src_mail.shape[1])
mail_ts = torch.cat((ts,ts),-1).to(self.device).to(self.mailbox_ts.dtype)
unq_index,inv = torch.unique(index,return_inverse = True)
max_ts,idx = torch_scatter.scatter_max(mail_ts,inv,0)
mail_ts = max_ts
mail = mail[idx]
index = unq_index
return index,mail,mail_ts
def get_update_memory(self,index,memory,memory_ts):
unq_index,inv = torch.unique(index,return_inverse = True)
max_ts,idx = torch_scatter.scatter_max(memory_ts,inv,0)
ts = max_ts
memory = memory[idx]
index = unq_index
return index,memory,ts
def get_memory(self,index):
if self.num_parts == 1:
return self.node_memory.accessor.data[index],\
self.node_memory_ts.accessor.data[index],\
self.mailbox.accessor.data[index],\
self.mailbox_ts.accessor.data[index]
elif self.node_memory.rrefs is None:
return self.gather_memory(dist_index = index)
else:
memory = self.node_memory.index_select(index)
memory_ts = self.node_memory_ts.index_select(index)
mail = self.mailbox.index_select(index)
mail_ts = self.mailbox_ts.index_select(index)
def callback(fs):
memory,memory_ts,mail,mail_ts = fs.value()
memory = memory.value()
memory_ts = memory_ts.value()
mail = mail.value()
mail_ts = mail_ts.value()
#print(memory.shape[0])
return memory,memory_ts,mail,mail_ts
return torch.futures.collect_all([memory,memory_ts,mail,mail_ts]).then(callback)
def gather_memory(
self,
dist_index: Union[torch.Tensor, DistIndex, None] = None,
send_ptr: Optional[List[int]] = None,
recv_ptr: Optional[List[int]] = None,
recv_ind: Optional[List[int]] = None,
group = None
):
if dist_index is None:
return self.node_memory.all_to_all_get(dist_index,send_ptr,recv_ptr,recv_ind,group),\
self.node_memory_ts.all_to_all_get(dist_index,send_ptr,recv_ptr,recv_ind,group),\
self.mailbox.all_to_all_get(dist_index,send_ptr,recv_ptr,recv_ind,group),\
self.mailbox_ts.all_to_all_get(dist_index,send_ptr,recv_ptr,recv_ind,group)
else:
ids = self.node_memory.all_to_all_ind2ptr(dist_index)
return self.node_memory.all_to_all_get(**ids,group = group),\
self.node_memory_ts.all_to_all_get(**ids,group = group),\
self.mailbox.all_to_all_get(**ids,group = group),\
self.mailbox_ts.all_to_all_get(**ids,group = group)
from torch_sparse import SparseTensor
from torch_geometric.data import Data
from torch_geometric.utils import degree
<<<<<<< HEAD
=======
from starrygl.lib.libstarrygl_sampler import get_norm_temporal
>>>>>>> cmy_dev
import os.path as osp
import os
import shutil
import torch
import torch.utils.data
import torch.distributed as dist
from starrygl.lib.libstarrygl_sampler import get_norm_temporal
from starrygl.utils.partition import mt_metis_partition
def partition_load(root: str, algo: str = "metis") -> Data:
rank = dist.get_rank()
world_size = dist.get_world_size()
fn = osp.join(root, f"{algo}_{world_size}", f"{rank:03d}")
return torch.load(fn)
def partition_save(root: str, data: Data, num_parts: int,
algo: str = "metis",
node_weight = None,
edge_weight_dict=None):
root = osp.abspath(root)
if osp.exists(root) and not osp.isdir(root):
raise ValueError(f"path '{root}' should be a directory")
path = osp.join(root, f"{algo}_{num_parts}")
if osp.exists(path) and not osp.isdir(path):
raise ValueError(f"path '{path}' should be a directory")
if osp.exists(path) and os.listdir(path):
print(f"directory '{path}' not empty and cleared")
for p in os.listdir(path):
p = osp.join(path, p)
if osp.isdir(p):
shutil.rmtree(osp.join(path, p))
else:
os.remove(p)
if not osp.exists(path):
print(f"creating directory '{path}'")
os.makedirs(path)
if algo == 'metis_for_tgnn':
for i, pdata in enumerate(partition_data_for_tgnn(
data, num_parts, algo, verbose=True,
node_weight = node_weight,
edge_weight_dict=edge_weight_dict)):
print(f"saving partition data: {i+1}/{num_parts}")
fn = osp.join(path, f"{i:03d}")
torch.save(pdata, fn)
else:
for i, pdata in enumerate(partition_data_for_gnn(data, num_parts,
algo, verbose=True)):
print(f"saving partition data: {i+1}/{num_parts}")
fn = osp.join(path, f"{i:03d}")
torch.save(pdata, fn)
def partition_data_for_gnn(data: Data, num_parts: int,
algo: str, verbose: bool = False):
if algo == "metis":
part_fn = metis_partition
else:
raise ValueError(f"invalid algorithm: {algo}")
num_nodes = data.num_nodes
num_edges = data.num_edges
edge_index = data.edge_index
if verbose:
print(f"running partition algorithm: {algo}")
node_parts, edge_parts = part_fn(edge_index, num_nodes, num_parts)
if verbose:
print("computing GCN normalized factor")
gcn_norm = compute_gcn_norm(edge_index, num_nodes)
if data.y.dtype == torch.long:
if verbose:
print("compute num_classes")
num_classes = data.y.max().item() + 1
else:
num_classes = None
eids = torch.zeros(num_edges, dtype=torch.long)
len = 0
edgeptr = torch.zeros(num_parts+1, dtype=eids.dtype)
for i in range(num_parts):
epart_i = torch.where(edge_parts == i)[0]
eids[epart_i] = torch.arange(epart_i.shape[0]) + len
len += epart_i.shape[0]
edgeptr[i+1] = len
data.eids = eids
data.sample_graph.sample_eids = eids[data.sample_graph.sample_eid]
nids = torch.zeros(num_nodes, dtype=torch.long)
len = 0
partptr = torch.zeros(num_parts+1, dtype=nids.dtype)
for i in range(num_parts):
npart_i = torch.where(node_parts == i)[0]
nids[npart_i] = torch.arange(npart_i.shape[0]) + len
len += npart_i.shape[0]
partptr[i+1] = len
data.edge_index = nids[data.edge_index]
data.sample_graph.edge_index = nids[data.sample_graph.edge_index]
for i in range(num_parts):
npart_i = torch.where(node_parts == i)[0]
epart_i = torch.where(edge_parts == i)[0]
npart = npart_i
epart = edge_index[:, epart_i]
pdata = {
"ids": npart,
"edge_index": epart,
"gcn_norm": gcn_norm[epart_i],
"sample_graph": data.sample_graph,
"partptr": partptr,
"edgeptr": edgeptr
}
if num_classes is not None:
pdata["num_classes"] = num_classes
for key, val in data:
if key == "edge_index" or key == "sample_graph":
continue
if isinstance(val, torch.Tensor):
if val.size(0) == num_nodes:
pdata[key] = val[npart_i]
elif val.size(0) == num_edges:
pdata[key] = val[epart_i]
# else:
# pdata[key] = val
elif isinstance(val, SparseTensor):
pass
else:
pdata[key] = val
pdata = Data(**pdata)
yield pdata
def _nopart(edge_index: torch.LongTensor, num_nodes: int):
node_parts = torch.zeros(num_nodes, dtype=torch.long)
if isinstance(edge_index, torch.Tensor):
edge_parts = torch.zeros(edge_index.size(1), dtype=torch.long)
return node_parts, edge_parts
return node_parts
def metis_for_tgnn(edge_index_dict: dict,
num_nodes: int,
num_parts: int,
node_weight = None,
edge_weight_dict=None):
if num_parts <= 1:
return _nopart(edge_index_dict, num_nodes)
edge_list = []
weight_list = []
for i,key in enumerate(edge_index_dict):
v = edge_index_dict[key]
edge_list.append(v)
weight_list.append(torch.ones(v.shape[1])*edge_weight_dict[key])
edge_index = torch.cat(edge_list,dim = 1)
edge_weight = torch.cat(weight_list,dim = 0)
node_parts = mt_metis_partition(edge_index,num_nodes,num_parts,node_weight,edge_weight)
return node_parts
#G = nx.Graph()
#G.add_nodes_from(torch.arange(0, num_nodes).tolist())
#value, counts = torch.unique(edge_index_dict['edata'][1, :].view(-1),
# return_counts=True)
#nodes = torch.tensor(list(G.adj.keys()))
#for i in range(value.shape[0]):
# if (value[i].item() in G.nodes):
# G.nodes[int(value[i].item())]['weight'] = counts[i]
# G.nodes[int(value[i].item())]['ones'] = 1
#G.graph['node_weight_attr'] = ['weight', 'ones']
#for i, key in enumerate(edge_index_dict):
# v = edge_index_dict[key]
# edges = torch.cat((v, (torch.ones(v.shape[1], dtype=torch.long) *
# edge_weight_dict[key]).unsqueeze(0)), dim=0)
# # w = edges.T
# G.add_weighted_edges_from((edges.T).tolist())
#G.graph['edge_weight_attr'] = 'weight'
#cuts, part = metis.part_graph(G, num_parts)
#node_parts = torch.zeros(num_nodes, dtype=torch.long)
#node_parts[nodes] = torch.tensor(part)
#return node_parts
"""
weight: 各种工作负载边划分权重
按照点均衡划分
"""
def partition_data_for_tgnn(data: Data, num_parts: int, algo: str,
verbose: bool = False,
node_weight: torch.Tensor = None,
edge_weight_dict: dict = None):
if algo == "metis_for_tgnn":
part_fn = metis_for_tgnn
else:
raise ValueError(f"invalid algorithm: {algo}")
num_nodes = data.num_nodes
num_edges = data.num_edges
edge_index_dict = data.edge_index_dict
tgnn_norm = compute_temporal_norm(data.edge_index, data.edge_ts, num_nodes)
if verbose:
print(f"running partition algorithm: {algo}")
node_parts = part_fn(edge_index_dict, num_nodes, num_parts,
node_weight,
edge_weight_dict)
edge_parts = node_parts[data.edge_index[1, :]]
eids = torch.arange(num_edges, dtype=torch.long)
data.eids = eids
data.sample_graph['eids'] = eids[data.sample_graph['eids']]
if data.y.dtype == torch.long:
if verbose:
print("compute num_classes")
num_classes = data.y.max().item() + 1
else:
num_classes = None
for i in range(num_parts):
npart_i = torch.where(node_parts == i)[0]
epart_i = torch.where(edge_parts == i)[0]
pdata = {
"ids": npart_i,
"tgnn_norm": tgnn_norm,
"edge_index": data.edge_index[:, epart_i],
"sample_graph": data.sample_graph
}
if num_classes is not None:
pdata["num_classes"] = num_classes
for key, val in data:
if key == "edge_index" or key == "edge_index_dict" \
or key == "sample_graph":
continue
if isinstance(val, torch.Tensor):
if val.size(0) == num_nodes:
pdata[key] = val[npart_i]
elif val.size(0) == num_edges:
pdata[key] = val[epart_i]
# else:
# pdata[key] = val
elif isinstance(val, SparseTensor):
pass
else:
pdata[key] = val
pdata = Data(**pdata)
yield pdata
def metis_partition(edge_index, num_nodes: int, num_parts: int):
if num_parts <= 1:
return _nopart(edge_index, num_nodes)
G = nx.Graph()
G.add_nodes_from(torch.arange(0, num_nodes).tolist())
G.add_edges_from(edge_index.T.tolist())
nodes = torch.tensor(list(G.adj.keys()))
nodes = torch.tensor(list(G.adj.keys()))
cuts, part = metis.part_graph(G, num_parts)
node_parts = torch.zeros(num_nodes, dtype=torch.long)
node_parts[nodes] = torch.tensor(part)
edge_parts = node_parts[edge_index[1]]
return node_parts, edge_parts
def metis_partition_bydegree(edge_index, num_nodes: int, num_parts: int):
if num_parts <= 1:
return _nopart(edge_index, num_nodes)
G = nx.Graph()
G.add_nodes_from(torch.arange(0, num_nodes).tolist())
G.add_edges_from(edge_index.T.tolist())
value, counts = torch.unique(edge_index[1, :].view(-1), return_counts=True)
nodes = torch.tensor(list(G.adj.keys()))
for i in range(value.shape[0]):
if (value[i].item() in G.nodes):
G.nodes[int(value[i].item())]['weight'] = counts[i]
G.graph['node_weight_attr'] = 'weight'
nodes = torch.tensor(list(G.adj.keys()))
cuts, part = metis.part_graph(G, num_parts)
node_parts = torch.zeros(num_nodes, dtype=torch.long)
node_parts[nodes] = torch.tensor(part)
edge_parts = node_parts[edge_index[1]]
return node_parts, edge_parts
def compute_gcn_norm(edge_index: torch.LongTensor, num_nodes: int):
deg_j = degree(edge_index[0], num_nodes).pow(-0.5)
deg_i = degree(edge_index[1], num_nodes).pow(-0.5)
deg_i[deg_i.isinf() | deg_i.isnan()] = 0.0
deg_j[deg_j.isinf() | deg_j.isnan()] = 0.0
return deg_j[edge_index[0]] * deg_i[edge_index[1]]
def compute_temporal_norm(edge_index: torch.LongTensor,
timestamp: torch.FloatTensor,
num_nodes: int):
srcavg, srcvar, dstavg, dstvar = get_norm_temporal(edge_index[0, :],
edge_index[1, :],
timestamp, num_nodes)
return srcavg, srcvar, dstavg, dstvar
from torch_sparse import SparseTensor
from torch_geometric.data import Data
from torch_geometric.utils import degree
import os.path as osp
import os
import shutil
import torch
import torch.utils.data
import torch.distributed as dist
from starrygl.lib.libstarrygl_sampler import get_norm_temporal
from starrygl.utils.partition import mt_metis_partition
def partition_load(root: str, algo: str = "metis") -> Data:
rank = dist.get_rank()
world_size = dist.get_world_size()
fn = osp.join(root, f"{algo}_{world_size}", f"{rank:03d}")
return torch.load(fn)
def partition_save(root: str, data: Data, num_parts: int,
algo: str = "metis",
node_weight = None,
edge_weight_dict=None):
root = osp.abspath(root)
if osp.exists(root) and not osp.isdir(root):
raise ValueError(f"path '{root}' should be a directory")
path = osp.join(root, f"{algo}_{num_parts}")
if osp.exists(path) and not osp.isdir(path):
raise ValueError(f"path '{path}' should be a directory")
if osp.exists(path) and os.listdir(path):
print(f"directory '{path}' not empty and cleared")
for p in os.listdir(path):
p = osp.join(path, p)
if osp.isdir(p):
shutil.rmtree(osp.join(path, p))
else:
os.remove(p)
if not osp.exists(path):
print(f"creating directory '{path}'")
os.makedirs(path)
if algo == 'metis_for_tgnn':
for i, pdata in enumerate(partition_data_for_tgnn(
data, num_parts, algo, verbose=True,
node_weight = node_weight,
edge_weight_dict=edge_weight_dict)):
print(f"saving partition data: {i+1}/{num_parts}")
fn = osp.join(path, f"{i:03d}")
torch.save(pdata, fn)
else:
for i, pdata in enumerate(partition_data_for_gnn(data, num_parts,
algo, verbose=True)):
print(f"saving partition data: {i+1}/{num_parts}")
fn = osp.join(path, f"{i:03d}")
torch.save(pdata, fn)
def partition_data_for_gnn(data: Data, num_parts: int,
algo: str, verbose: bool = False):
if algo == "metis":
part_fn = metis_partition
else:
raise ValueError(f"invalid algorithm: {algo}")
num_nodes = data.num_nodes
num_edges = data.num_edges
edge_index = data.edge_index
if verbose:
print(f"running partition algorithm: {algo}")
node_parts, edge_parts = part_fn(edge_index, num_nodes, num_parts)
if verbose:
print("computing GCN normalized factor")
gcn_norm = compute_gcn_norm(edge_index, num_nodes)
if data.y.dtype == torch.long:
if verbose:
print("compute num_classes")
num_classes = data.y.max().item() + 1
else:
num_classes = None
eids = torch.zeros(num_edges, dtype=torch.long)
len = 0
edgeptr = torch.zeros(num_parts+1, dtype=eids.dtype)
for i in range(num_parts):
epart_i = torch.where(edge_parts == i)[0]
eids[epart_i] = torch.arange(epart_i.shape[0]) + len
len += epart_i.shape[0]
edgeptr[i+1] = len
data.eids = eids
data.sample_graph.sample_eids = eids[data.sample_graph.sample_eid]
nids = torch.zeros(num_nodes, dtype=torch.long)
len = 0
partptr = torch.zeros(num_parts+1, dtype=nids.dtype)
for i in range(num_parts):
npart_i = torch.where(node_parts == i)[0]
nids[npart_i] = torch.arange(npart_i.shape[0]) + len
len += npart_i.shape[0]
partptr[i+1] = len
data.edge_index = nids[data.edge_index]
data.sample_graph.edge_index = nids[data.sample_graph.edge_index]
for i in range(num_parts):
npart_i = torch.where(node_parts == i)[0]
epart_i = torch.where(edge_parts == i)[0]
npart = npart_i
epart = edge_index[:, epart_i]
pdata = {
"ids": npart,
"edge_index": epart,
"gcn_norm": gcn_norm[epart_i],
"sample_graph": data.sample_graph,
"partptr": partptr,
"edgeptr": edgeptr
}
if num_classes is not None:
pdata["num_classes"] = num_classes
for key, val in data:
if key == "edge_index" or key == "sample_graph":
continue
if isinstance(val, torch.Tensor):
if val.size(0) == num_nodes:
pdata[key] = val[npart_i]
elif val.size(0) == num_edges:
pdata[key] = val[epart_i]
# else:
# pdata[key] = val
elif isinstance(val, SparseTensor):
pass
else:
pdata[key] = val
pdata = Data(**pdata)
yield pdata
def _nopart(edge_index: torch.LongTensor, num_nodes: int):
node_parts = torch.zeros(num_nodes, dtype=torch.long)
if isinstance(edge_index, torch.Tensor):
edge_parts = torch.zeros(edge_index.size(1), dtype=torch.long)
return node_parts, edge_parts
return node_parts
def metis_for_tgnn(edge_index_dict: dict,
num_nodes: int,
num_parts: int,
node_weight = None,
edge_weight_dict=None):
if num_parts <= 1:
return _nopart(edge_index_dict, num_nodes)
edge_list = []
weight_list = []
for i,key in enumerate(edge_index_dict):
v = edge_index_dict[key]
edge_list.append(v)
weight_list.append(torch.ones(v.shape[1])*edge_weight_dict[key])
edge_index = torch.cat(edge_list,dim = 1)
edge_weight = torch.cat(weight_list,dim = 0)
node_parts = mt_metis_partition(edge_index,num_nodes,num_parts,node_weight,edge_weight)
return node_parts
#G = nx.Graph()
#G.add_nodes_from(torch.arange(0, num_nodes).tolist())
#value, counts = torch.unique(edge_index_dict['edata'][1, :].view(-1),
# return_counts=True)
#nodes = torch.tensor(list(G.adj.keys()))
#for i in range(value.shape[0]):
# if (value[i].item() in G.nodes):
# G.nodes[int(value[i].item())]['weight'] = counts[i]
# G.nodes[int(value[i].item())]['ones'] = 1
#G.graph['node_weight_attr'] = ['weight', 'ones']
#for i, key in enumerate(edge_index_dict):
# v = edge_index_dict[key]
# edges = torch.cat((v, (torch.ones(v.shape[1], dtype=torch.long) *
# edge_weight_dict[key]).unsqueeze(0)), dim=0)
# # w = edges.T
# G.add_weighted_edges_from((edges.T).tolist())
#G.graph['edge_weight_attr'] = 'weight'
#cuts, part = metis.part_graph(G, num_parts)
#node_parts = torch.zeros(num_nodes, dtype=torch.long)
#node_parts[nodes] = torch.tensor(part)
#return node_parts
"""
weight: 各种工作负载边划分权重
按照点均衡划分
"""
def partition_data_for_tgnn(data: Data, num_parts: int, algo: str,
verbose: bool = False,
node_weight: torch.Tensor = None,
edge_weight_dict: dict = None):
if algo == "metis_for_tgnn":
part_fn = metis_for_tgnn
else:
raise ValueError(f"invalid algorithm: {algo}")
num_nodes = data.num_nodes
num_edges = data.num_edges
edge_index_dict = data.edge_index_dict
tgnn_norm = compute_temporal_norm(data.edge_index, data.edge_ts, num_nodes)
if verbose:
print(f"running partition algorithm: {algo}")
node_parts = part_fn(edge_index_dict, num_nodes, num_parts,
node_weight,
edge_weight_dict)
edge_parts = node_parts[data.edge_index[1, :]]
eids = torch.arange(num_edges, dtype=torch.long)
data.eids = eids
data.sample_graph['eids'] = eids[data.sample_graph['eids']]
if data.y.dtype == torch.long:
if verbose:
print("compute num_classes")
num_classes = data.y.max().item() + 1
else:
num_classes = None
for i in range(num_parts):
npart_i = torch.where(node_parts == i)[0]
epart_i = torch.where(edge_parts == i)[0]
pdata = {
"ids": npart_i,
"tgnn_norm": tgnn_norm,
"edge_index": data.edge_index[:, epart_i],
"sample_graph": data.sample_graph
}
if num_classes is not None:
pdata["num_classes"] = num_classes
for key, val in data:
if key == "edge_index" or key == "edge_index_dict" \
or key == "sample_graph":
continue
if isinstance(val, torch.Tensor):
if val.size(0) == num_nodes:
pdata[key] = val[npart_i]
elif val.size(0) == num_edges:
pdata[key] = val[epart_i]
# else:
# pdata[key] = val
elif isinstance(val, SparseTensor):
pass
else:
pdata[key] = val
pdata = Data(**pdata)
yield pdata
def metis_partition(edge_index, num_nodes: int, num_parts: int):
if num_parts <= 1:
return _nopart(edge_index, num_nodes)
G = nx.Graph()
G.add_nodes_from(torch.arange(0, num_nodes).tolist())
G.add_edges_from(edge_index.T.tolist())
nodes = torch.tensor(list(G.adj.keys()))
nodes = torch.tensor(list(G.adj.keys()))
cuts, part = metis.part_graph(G, num_parts)
node_parts = torch.zeros(num_nodes, dtype=torch.long)
node_parts[nodes] = torch.tensor(part)
edge_parts = node_parts[edge_index[1]]
return node_parts, edge_parts
def metis_partition_bydegree(edge_index, num_nodes: int, num_parts: int):
if num_parts <= 1:
return _nopart(edge_index, num_nodes)
G = nx.Graph()
G.add_nodes_from(torch.arange(0, num_nodes).tolist())
G.add_edges_from(edge_index.T.tolist())
value, counts = torch.unique(edge_index[1, :].view(-1), return_counts=True)
nodes = torch.tensor(list(G.adj.keys()))
for i in range(value.shape[0]):
if (value[i].item() in G.nodes):
G.nodes[int(value[i].item())]['weight'] = counts[i]
G.graph['node_weight_attr'] = 'weight'
nodes = torch.tensor(list(G.adj.keys()))
cuts, part = metis.part_graph(G, num_parts)
node_parts = torch.zeros(num_nodes, dtype=torch.long)
node_parts[nodes] = torch.tensor(part)
edge_parts = node_parts[edge_index[1]]
return node_parts, edge_parts
def compute_gcn_norm(edge_index: torch.LongTensor, num_nodes: int):
deg_j = degree(edge_index[0], num_nodes).pow(-0.5)
deg_i = degree(edge_index[1], num_nodes).pow(-0.5)
deg_i[deg_i.isinf() | deg_i.isnan()] = 0.0
deg_j[deg_j.isinf() | deg_j.isnan()] = 0.0
return deg_j[edge_index[0]] * deg_i[edge_index[1]]
def compute_temporal_norm(edge_index: torch.LongTensor,
timestamp: torch.FloatTensor,
num_nodes: int):
srcavg, srcvar, dstavg, dstvar = get_norm_temporal(edge_index[0, :],
edge_index[1, :],
timestamp, num_nodes)
return srcavg, srcvar, dstavg, dstvar
import starrygl
import sys
from os.path import abspath, join, dirname
sys.path.insert(0, join(abspath(dirname(__file__))))
import math
import torch
import torch.multiprocessing as mp
from typing import Optional, Tuple
from .base import BaseSampler, NegativeSampling, SampleOutput, SampleType
# from sample_cores import ParallelSampler, get_neighbors, heads_unique
from torch.distributed.rpc import rpc_async
class NeighborSampler(BaseSampler):
r'''
Parallel sampling is crucial for expanding model training to a large amount of data.Due to the large scale and complexity of graph data, traditional serial sampling may lead to significant waste of computing and storage resources. The significance of parallel sampling lies in improving the efficiency and overall computational speed of sampling by simultaneously sampling from multiple nodes or neighbors.
This helps to accelerate the training and inference process of the model, making it more scalable and practical when dealing with large-scale graph data.
Our parallel sampling adopts a hybrid approach of CPU and GPU, where the entire graph structure is stored on the CPU and then uploaded to the GPU after sampling the graph structure on the CPU. Each trainer has a separate sampler for parallel training.
We have encapsulated the functions for parallel sampling, and you can easily use them in the following ways:
.. code-block:: python
# First,you need to import Python packages
from starrygl.sample.sample_core.neighbor_sampler import NeighborSampler
# Then,you can use ours parallel sampler
sampler = NeighborSampler(num_nodes=num_nodes, num_layers=num_layers, fanout=fanout, graph_data=graph_data,
workers=workers, is_distinct = is_distinct, policy = policy, edge_weight= edge_weight, graph_name = graph_name)
Args:
num_nodes (int): the num of all nodes in the graph
num_layers (int): the num of layers to be sampled
fanout (list): the list of max neighbors' number chosen for each layer
graph_data (:class: starrygl.sample.sample_core.neighbor_sampler): the graph data you want to sample
workers (int): the number of threads, default value is 1
is_distinct (bool): 1-need distinct muti-edge, 0-don't need distinct muti-edge
policy (str): "uniform" or "recent" or "weighted"
edge_weight (torch.Tensor,Optional): the initial weights of edges
graph_name (str): the name of graph should provide edge_index or (neighbors, deg)
Examples:
.. code-block:: python
from starrygl.sample.part_utils.partition_tgnn import partition_load
from starrygl.sample.graph_core import DataSet, DistributedGraphStore, TemporalNeighborSampleGraph
from starrygl.sample.sample_core.neighbor_sampler import NeighborSampler
pdata = partition_load("PATH/{}".format(dataname), algo="metis_for_tgnn")
graph = DistributedGraphStore(pdata = pdata,uvm_edge = False,uvm_node = False)
sample_graph = TemporalNeighborSampleGraph(sample_graph = pdata.sample_graph,mode = 'full')
sampler = NeighborSampler(num_nodes=graph.num_nodes, num_layers=1, fanout=[10],
graph_data=sample_graph, workers=15, policy = 'recent', graph_name = "wiki_train")
If you want to directly call parallel sampling functions, use the following methods:
.. code-block:: python
# the parameter meaning is the same as the `Args` above
from starrygl.lib.libstarrygl_sampler import ParallelSampler, get_neighbors
# get neighbor infomation table,row and col come from graph_data.edge_index=(row, col)
tnb = get_neighbors(graph_name, row.contiguous(), col.contiguous(), num_nodes, is_distinct, graph_data. eid, edge_weight, timestamp)
# call parallel sampler
p_sampler = ParallelSampler(self.tnb, num_nodes, graph_data.num_edges, workers, fanout, num_layers, policy)
For complete usage and more details, please refer to `~starrygl.sample.sample_core.neighbor_sampler`
'''
def __init__(
self,
num_nodes: int,
num_layers: int,
fanout: list,
graph_data,
workers = 1,
tnb = None,
is_distinct = 0,
policy = "uniform",
edge_weight: Optional[torch.Tensor] = None,
graph_name = None
) -> None:
r"""__init__
Args:
num_nodes: the num of all nodes in the graph
num_layers: the num of layers to be sampled
fanout: the list of max neighbors' number chosen for each layer
workers: the number of threads, default value is 1
tnb: neighbor infomation table
is_distinct: 1-need distinct muti-edge, 0-don't need distinct muti-edge
policy: "uniform" or "recent" or "weighted"
edge_weight: the initial weights of edges
graph_name: the name of graph
should provide edge_index or (neighbors, deg)
"""
super().__init__()
self.num_layers = num_layers
# 线程数不超过torch默认的omp线程数
self.workers = workers # min(workers, torch.get_num_threads())
self.fanout = fanout
self.num_nodes = num_nodes
self.graph_data=graph_data
self.policy = policy
self.is_distinct = is_distinct
assert graph_name is not None
self.graph_name = graph_name
if(tnb is None):
if(graph_data.edge_ts is not None):
timestamp,ind = graph_data.edge_ts.sort()
timestamp = timestamp.float().contiguous()
eid = graph_data.eid[ind].contiguous()
row, col = graph_data.edge_index[:,ind]
else:
eid = graph_data.eid
timestamp = None
row, col = graph_data.edge_index
if(edge_weight is not None):
edge_weight = edge_weight.float().contiguous()
self.tnb = starrygl.sampler_ops.get_neighbors(graph_name, row.contiguous(), col.contiguous(), num_nodes, is_distinct, eid, edge_weight, timestamp)
else:
assert tnb is not None
self.tnb = tnb
self.p_sampler = starrygl.sampler_ops.ParallelSampler(self.tnb, num_nodes, graph_data.num_edges, workers,
fanout, num_layers, policy)
def _get_sample_info(self):
return self.num_nodes,self.num_layers,self.fanout,self.workers
def _get_sample_options(self):
return {"is_distinct" : self.is_distinct,
"policy" : self.policy,
"with_eid" : self.tnb.with_eid,
"weighted" : self.tnb.weighted,
"with_timestamp" : self.tnb.with_timestamp}
def insert_edges_with_timestamp(
self,
edge_index : torch.Tensor,
eid : torch.Tensor,
timestamp : torch.Tensor,
edge_weight : Optional[torch.Tensor] = None):
row, col = edge_index
# 更新节点数和tnb
self.num_nodes = self.tnb.update_neighbors_with_time(
row.contiguous(),
col.contiguous(),
timestamp.contiguous(),
eid.contiguous(),
self.is_distinct,
edge_weight.contiguous())
def update_edges_weight(
self,
edge_index : torch.Tensor,
eid : torch.Tensor,
edge_weight : Optional[torch.Tensor] = None):
row, col = edge_index
# 更新tnb的权重信息
if self.tnb.with_eid:
self.tnb.update_edge_weight(
eid.contiguous(),
col.contiguous(),
edge_weight.contiguous()
)
else:
self.tnb.update_edge_weight(
row.contiguous(),
col.contiguous(),
edge_weight.contiguous()
)
def update_nodes_weight(
self,
nid : torch.Tensor,
node_weight : Optional[torch.Tensor] = None):
# 更新tnb的权重信息
self.tnb.update_node_weight(
nid.contiguous(),
node_weight.contiguous()
)
def update_all_node_weight(
self,
node_weight : torch.Tensor):
# 更新tnb的权重信息
self.tnb.update_all_node_weight(node_weight.contiguous())
def sample_from_nodes(
self,
nodes: torch.Tensor,
ts: Optional[torch.Tensor] = None,
with_outer_sample: SampleType = SampleType.Whole
) -> SampleOutput:
r"""Performs mutilayer sampling from the nodes specified in: nodes
The specific number of layers is determined by parameter: num_layers
returning a sampled subgraph in the specified output format: Tuple[torch.Tensor, list].
Args:
nodes: the list of seed nodes index,
ts: the timestamp of nodes, optional,
with_outer_sample: 0-sample in whole graph structure; 1-sample onehop outer nodel; 2-cross partition sampling
fanout_index: optional. Specify the index to fanout
Returns:
sampled_nodes: the node sampled
sampled_edge_index_list: the edge sampled
"""
if(ts is None):
self.part_unique = True
self.p_sampler.neighbor_sample_from_nodes(nodes.contiguous(), None, self.part_unique)
ret = self.p_sampler.get_ret()
return ret
else:
self.p_sampler.neighbor_sample_from_nodes(nodes.contiguous(), ts.float().contiguous(), None)
ret = self.p_sampler.get_ret()
return ret
def sample_from_edges(
self,
edges: torch.Tensor,
ets: Optional[torch.Tensor] = None,
neg_sampling: Optional[NegativeSampling] = None,
with_outer_sample: SampleType = SampleType.Whole
) -> SampleOutput:
r"""Performs sampling from the edges specified in :obj:`index`,
returning a sampled subgraph in the specified output format.
Args:
edges: the list of seed edges index
with_outer_sample: 0-sample in whole graph structure; 1-sample onehop outer nodel; 2-cross partition sampling
ets: the timestamp of edges, optional
neg_sampling: The negative sampling configuration
Returns:
sampled_edge_index_list: the edges sampled
sampled_eid_list: the edges' id sampled
sampled_delta_ts_list:the edges' delta time sampled
metadata: other infomation
"""
src, dst = edges
num_pos = src.numel()
num_neg = 0
with_timestap = ets is not None
seed_ts = None
if neg_sampling is not None:
num_neg = math.ceil(num_pos * neg_sampling.amount)
if neg_sampling.is_binary():
src_neg = neg_sampling.sample(num_neg, self.num_nodes)
dst_neg = neg_sampling.sample(num_neg, self.num_nodes)
seed = torch.cat([src, dst, src_neg, dst_neg], dim=0)
if with_timestap: # ts操作
seed_ts = torch.cat([ets, ets, ets, ets], dim=0)
if neg_sampling.is_triplet():
src_neg = neg_sampling.sample(num_neg, self.num_nodes)
seed = torch.cat([src, dst, src_neg], dim=0)
if with_timestap: # ts操作
seed_ts = torch.cat([ets, ets, ets], dim=0)
else:
seed = torch.cat([src, dst], dim=0)
if with_timestap: # ts操作
seed_ts = torch.cat([ets, ets], dim=0)
# 去重负采样
if neg_sampling is not None and neg_sampling.unique:
if with_timestap: # ts操作
pair, inverse_seed= torch.unique(torch.stack([seed, seed_ts],0), return_inverse=True, dim=1)
seed, seed_ts = pair
seed = seed.long()
else:
seed, inverse_seed = seed.unique(return_inverse=True)
out = self.sample_from_nodes(seed, seed_ts, with_outer_sample)
if neg_sampling is None or (not neg_sampling.unique):
if with_timestap:
return out, {'seed':seed,'seed_ts':seed_ts,
'src_pos_index':torch.arange(0,num_pos),
'dst_pos_index':torch.arange(num_pos,2*num_pos),
'src_neg_index':torch.arange(2*num_pos,3*num_pos)}
else:
return out, {'seed':seed,
'src_pos_index':slice(0,num_pos),
'dst_pos_index':slice(num_pos,2*num_pos),
'src_neg_index':slice(2*num_pos,3*num_pos)}
metadata = {}
if neg_sampling.is_binary():
src_pos_index = inverse_seed[:num_pos]
dst_pos_index = inverse_seed[num_pos:2 * num_pos]
src_neg_index = inverse_seed[2 * num_pos:3 * num_pos]
src_neg_index = src_neg_index.view(num_pos, -1).squeeze(-1)
dst_neg_index = inverse_seed[3 * num_pos:]
dst_neg_index = dst_neg_index.view(num_pos, -1).squeeze(-1)
metadata = {'seed':seed, 'src_neg_index':src_neg_index, 'dst_pos_index':dst_pos_index, 'dst_neg_index':dst_neg_index}
if with_timestap:
metadata['seed_ts'] = seed_ts
elif neg_sampling.is_triplet():
src_pos_index = inverse_seed[:num_pos]
dst_pos_index = inverse_seed[num_pos:2 * num_pos]
src_neg_index = inverse_seed[2 * num_pos:]
src_neg_index = src_neg_index.view(num_pos, -1).squeeze(-1)
# src_index是seed里src点的索引
# dst_pos_index是seed里dst_pos点的索引
# dst_neg_index是seed里dst_neg点的索引
metadata = {'seed':seed, 'src_pos_index':src_pos_index, 'src_neg_index':src_neg_index, 'dst_pos_index':dst_pos_index}
if with_timestap:
metadata['seed_ts'] = seed_ts
# sampled_nodes最前方是原始序列的采样起点也就是去重后的seed
return out, metadata
if __name__=="__main__":
# edge_index1 = torch.tensor([[0, 1, 1, 1, 2, 2, 2, 4, 4, 4, 5], # , 3, 3
# [1, 0, 2, 4, 1, 3, 0, 3, 5, 0, 2]])# , 2, 5
edge_index1 = torch.tensor([[0, 1, 1, 1, 1, 2, 2, 2, 2, 4, 4, 4, 5], # , 3, 3
[1, 0, 2, 0, 4, 1, 3, 0, 3, 3, 5, 0, 2]])# , 2, 5
edge_weight1 = None
timeStamp=torch.FloatTensor([1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4])
num_nodes1 = 6
num_neighbors = 2
# Run the neighbor sampling
from Utils import GraphData
g_data = GraphData(id=0, edge_index=edge_index1, timestamp=timeStamp, data=None, partptr=torch.tensor([0, num_nodes1]))
sampler = NeighborSampler(num_nodes=num_nodes1,
num_layers=3,
fanout=[2, 1, 1],
edge_weight=edge_weight1,
graph_data=g_data,
graph_name='a',
workers=4,
is_distinct = 0)
out = sampler.sample_from_nodes(torch.tensor([1,2]),
ts=torch.tensor([1, 2]),
with_outer_sample=SampleType.Whole)
# out = sampler.sample_from_edges(torch.tensor([[1,2],[4,0]]),
# with_outer_sample=SampleType.Whole,
# ets = torch.tensor([1, 2]))
# Print the result
print('node:', out.node)
print('edge_index_list:', out.edge_index_list)
print('eid_list:', out.eid_list)
print('delta_ts_list:', out.delta_ts_list)
print('metadata: ', out.metadata)
import argparse
import os
import sys
from os.path import abspath, join, dirname
from starrygl.distributed.context import DistributedContext
from starrygl.distributed.utils import DistIndex
from starrygl.module.modules import GeneralModel
from pathlib import Path
<<<<<<< HEAD
from starrygl.module.utils import parse_config
from starrygl.sample.cache.fetch_cache import FetchFeatureCache
=======
from starrygl.module.utils import parse_config, EarlyStopMonitor
>>>>>>> cmy_dev
from starrygl.sample.graph_core import DataSet, DistributedGraphStore, TemporalNeighborSampleGraph
from starrygl.sample.memory.shared_mailbox import SharedMailBox
from starrygl.sample.sample_core.base import NegativeSampling
from starrygl.sample.sample_core.neighbor_sampler import NeighborSampler
from starrygl.sample.part_utils.partition_tgnn import partition_load
import torch
import time
import torch
import torch.nn.functional as F
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import init_process_group, destroy_process_group
import os
from starrygl.sample.data_loader import DistributedDataLoader
from starrygl.sample.batch_data import SAMPLE_TYPE
from starrygl.sample.stream_manager import getPipelineManger
parser = argparse.ArgumentParser(
description="RPC Reinforcement Learning Example",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument('--rank', default=0, type=int, metavar='W',
help='name of dataset')
parser.add_argument('--patience', type=int, default=5, help='Patience for early stopping')
parser.add_argument('--world_size', default=1, type=int, metavar='W',
help='number of negative samples')
parser.add_argument('--dataname', default=1, type=str, metavar='W',
help='name of dataset')
parser.add_argument('--model', default='TGN', type=str, metavar='W',
help='name of model')
args = parser.parse_args()
from sklearn.metrics import average_precision_score, roc_auc_score
import torch
import time
import random
import dgl
import numpy as np
from sklearn.metrics import average_precision_score, roc_auc_score
from torch.nn.parallel import DistributedDataParallel as DDP
#os.environ['CUDA_VISIBLE_DEVICES'] = str(args.rank)
#os.environ["RANK"] = str(args.rank)
#os.environ["WORLD_SIZE"] = str(args.world_size)
#os.environ["LOCAL_RANK"] = str(0)
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
os.environ["MASTER_ADDR"] = '10.214.211.187'
os.environ["MASTER_PORT"] = '9337'
def seed_everything(seed=42):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
seed_everything(1234)
def main():
print('main')
use_cuda = True
sample_param, memory_param, gnn_param, train_param = parse_config('./config/{}.yml'.format(args.model))
torch.set_num_threads(12)
ctx = DistributedContext.init(backend="nccl", use_gpu=True)
device_id = torch.cuda.current_device()
print('use cuda on',device_id)
pdata = partition_load("/mnt/data/part_data/dataset/here/{}".format(args.dataname), algo="metis_for_tgnn")
<<<<<<< HEAD
graph = DistributedGraphStore(pdata = pdata,uvm_edge = False,uvm_node = False)
=======
graph = DistributedGraphStore(pdata = pdata)
>>>>>>> cmy_dev
Path("./saved_models/").mkdir(parents=True, exist_ok=True)
Path("./saved_checkpoints/").mkdir(parents=True, exist_ok=True)
get_checkpoint_path = lambda \
epoch: f'./saved_checkpoints/{args.model}-{args.dataname}-{epoch}.pth'
gnn_param['dyrep'] = True if args.model == 'DyRep' else False
use_src_emb = gnn_param['use_src_emb'] if 'use_src_emb' in gnn_param else False
use_dst_emb = gnn_param['use_dst_emb'] if 'use_dst_emb' in gnn_param else False
sample_graph = TemporalNeighborSampleGraph(sample_graph = pdata.sample_graph,mode = 'full')
mailbox = SharedMailBox(pdata.ids.shape[0], memory_param, dim_edge_feat = pdata.edge_attr.shape[1] if pdata.edge_attr is not None else 0)
sampler = NeighborSampler(num_nodes=graph.num_nodes, num_layers=1, fanout=[10],graph_data=sample_graph, workers=15,policy = 'recent',graph_name = "wiki_train")
train_data = torch.masked_select(graph.edge_index,pdata.train_mask.to(graph.edge_index.device)).reshape(2,-1)
train_ts = torch.masked_select(graph.edge_ts,pdata.train_mask.to(graph.edge_index.device))
val_data = torch.masked_select(graph.edge_index,pdata.val_mask.to(graph.edge_index.device)).reshape(2,-1)
val_ts = torch.masked_select(graph.edge_ts,pdata.val_mask.to(graph.edge_index.device))
test_data = torch.masked_select(graph.edge_index,pdata.test_mask.to(graph.edge_index.device)).reshape(2,-1)
test_ts = torch.masked_select(graph.edge_ts,pdata.test_mask.to(graph.edge_index.device))
#print(train_data.shape[1],val_data.shape[1],test_data.shape[1])
train_data = DataSet(edges = train_data,ts =train_ts,eids = torch.nonzero(pdata.train_mask).view(-1))
#if dist.get_rank() == 0:
test_data = DataSet(edges = test_data,ts =test_ts,eids = torch.nonzero(pdata.test_mask).view(-1))
val_data = DataSet(edges = val_data,ts = val_ts,eids = torch.nonzero(pdata.val_mask).view(-1))
#else:
#test_data = torch.tensor([[],[]],device = graph.edge_index.device,detype = graph.edge_index.#dtype)
#val_data = torch.tensor([[],[]],device = graph.edge_index.device,detype = graph.edge_index.dtype)
#test_ts = torch.tensor([[],[]],device = graph.ts.device,detype = graph.ts.dtype)
#val_ts = torch.tensor([[],[]],device = graph.ts.device,detype = graph.ts.dtype)
#test_data = DataSet(edges = test_data,ts =test_ts,eids = torch.tensor([],dtype = torch.long,#device = torch.cuda))
#val_data = DataSet(edges = val_data,ts = val_ts,eids = torch.tensor([],dtype = torch.long,device #= torch.cuda))
#train_neg_sampler = PreNegativeSampling('triplet',torch.masked_select(pdata.edge_index['pos_edge'],graph.data.train_mask).reshape(2,-1))
neg_sampler = NegativeSampling('triplet')
trainloader = DistributedDataLoader(graph,train_data,sampler = sampler,
sampler_fn = SAMPLE_TYPE.SAMPLE_FROM_TEMPORAL_EDGES,
neg_sampler=neg_sampler,
batch_size = train_param['batch_size'],
shuffle=False,
drop_last=True,
chunk_size = None,
train=True,
queue_size = 1000,
mailbox = mailbox,
)
testloader = DistributedDataLoader(graph,test_data,sampler = sampler,
sampler_fn = SAMPLE_TYPE.SAMPLE_FROM_TEMPORAL_EDGES,
neg_sampler=neg_sampler,
batch_size = train_param['batch_size'],
shuffle=False,
drop_last=False,
chunk_size = None,
train=False,
queue_size = 100,
mailbox = mailbox)
valloader = DistributedDataLoader(graph,val_data,sampler = sampler,
sampler_fn = SAMPLE_TYPE.SAMPLE_FROM_TEMPORAL_EDGES,
neg_sampler=neg_sampler,
batch_size = train_param['batch_size'],
shuffle=False,
drop_last=False,
chunk_size = None,
train=False,
queue_size = 100,
mailbox = mailbox)
#FetchFeatureCache.create_fetch_cache(graph.num_nodes,graph.eids_mapper.shape[0],0.1,0.1,graph,mailbox,policy = 'static')
#cache = FetchFeatureCache.getFetchCache()
#cache.init_cache_with_presample(trainloader,3)
gnn_dim_node = 0 if graph.x is None else pdata.x.shape[1]
gnn_dim_edge = 0 if graph.edge_attr is None else pdata.edge_attr.shape[1]
print(gnn_dim_node,gnn_dim_edge)
avg_time = 0
if use_cuda:
model = GeneralModel(gnn_dim_node, gnn_dim_edge, sample_param, memory_param, gnn_param, train_param).cuda()
device = torch.device('cuda')
else:
model = GeneralModel(gnn_dim_node, gnn_dim_edge, sample_param, memory_param, gnn_param, train_param)
device = torch.device('cpu')
model = DDP(model,find_unused_parameters=True)
train_stream = torch.cuda.Stream()
send_stream = torch.cuda.Stream()
scatter_stream = torch.cuda.Stream()
val_losses = list()
def eval(mode='val'):
neg_samples = 1
model.eval()
aps = list()
aucs_mrrs = list()
if mode == 'val':
loader = valloader
elif mode == 'test':
loader = testloader
elif mode == 'train':
loader = trainloader
with torch.no_grad():
total_loss = 0
signal = torch.tensor([0],dtype = int,device = device)
<<<<<<< HEAD
with torch.cuda.stream(train_stream):
for roots,mfgs,metadata in loader:
pred_pos, pred_neg = model(mfgs,metadata)
total_loss += creterion(pred_pos, torch.ones_like(pred_pos))
total_loss += creterion(pred_neg, torch.zeros_like(pred_neg))
y_pred = torch.cat([pred_pos, pred_neg], dim=0).sigmoid().cpu()
y_true = torch.cat([torch.ones(pred_pos.size(0)), torch.zeros(pred_neg.size(0))], dim=0)
aps.append(average_precision_score(y_true, y_pred.detach().numpy()))
aucs_mrrs.append(roc_auc_score(y_true, y_pred))
if mailbox is not None:
src = metadata['src_pos_index']
dst = metadata['dst_pos_index']
ts = roots.ts
if graph.edge_attr is None:
edge_feats = None
elif(graph.edge_attr.device == torch.device('cpu')):
edge_feats = graph.edge_attr[roots.eids.to('cpu')].to('cuda')
else:
edge_feats = graph.edge_attr[roots.eids]
dist_index_mapper = mfgs[0][0].srcdata['ID']
root_index = torch.cat((src,dst))
last_updated_nid = model.module.memory_updater.last_updated_nid[root_index]
last_updated_memory = model.module.memory_updater.last_updated_memory[root_index]
last_updated_ts=model.module.memory_updater.last_updated_ts[root_index]
index, memory, memory_ts = mailbox.get_update_memory(last_updated_nid,
last_updated_memory,
last_updated_ts)
#
index, mail, mail_ts = mailbox.get_update_mail(dist_index_mapper,
src,dst,ts,edge_feats,
model.module.memory_updater.last_updated_memory,
)
mailbox.set_mailbox_all_to_all(index,memory,memory_ts,mail,mail_ts,reduce_Op = 'max')
=======
for roots,mfgs,metadata,sample_time in loader:
pred_pos, pred_neg = model(mfgs,metadata)
total_loss += creterion(pred_pos, torch.ones_like(pred_pos))
total_loss += creterion(pred_neg, torch.zeros_like(pred_neg))
y_pred = torch.cat([pred_pos, pred_neg], dim=0).sigmoid().cpu()
y_true = torch.cat([torch.ones(pred_pos.size(0)), torch.zeros(pred_neg.size(0))], dim=0)
aps.append(average_precision_score(y_true, y_pred.detach().numpy()))
aucs_mrrs.append(roc_auc_score(y_true, y_pred))
if mailbox is not None:
src = metadata['src_pos_index']
dst = metadata['dst_pos_index']
ts = roots.ts
if graph.edge_attr is None:
edge_feats = None
elif(graph.edge_attr.device == torch.device('cpu')):
edge_feats = graph.edge_attr[roots.eids.to('cpu')].to('cuda')
else:
edge_feats = graph.edge_attr[roots.eids]
dist_index_mapper = mfgs[0][0].srcdata['ID']
root_index = torch.cat((src,dst))
last_updated_nid = model.module.memory_updater.last_updated_nid[root_index]
last_updated_memory = model.module.memory_updater.last_updated_memory[root_index]
last_updated_ts=model.module.memory_updater.last_updated_ts[root_index]
index, memory, memory_ts = mailbox.get_update_memory(last_updated_nid,
last_updated_memory,
last_updated_ts)
index, mail, mail_ts = mailbox.get_update_mail(dist_index_mapper,
src,dst,ts,edge_feats,
model.module.memory_updater.last_updated_memory,
model.module.embedding,use_src_emb,use_dst_emb,
)
mailbox.set_mailbox_all_to_all(index,memory,memory_ts,mail,mail_ts,reduce_Op = 'max')
>>>>>>> cmy_dev
#ap = float(torch.tensor(aps).mean())
#if neg_samples > 1:
# auc_mrr = float(torch.cat(aucs_mrrs).mean())
#else:
# auc_mrr = float(torch.tensor(aucs_mrrs).mean())
world_size = dist.get_world_size()
apc = torch.empty([loader.expected_idx*world_size],dtype = torch.float,device='cuda')
auc_mrr = torch.empty([loader.expected_idx*world_size],dtype = torch.float,device = 'cuda')
dist.all_gather_into_tensor(apc,torch.tensor(aps,device ='cuda',dtype=torch.float))
dist.all_gather_into_tensor(auc_mrr,torch.tensor(aucs_mrrs,device ='cuda',dtype=torch.float))
ap = float(torch.tensor(apc).mean())
auc_mrr = float(torch.tensor(auc_mrr).mean())
return ap, auc_mrr
creterion = torch.nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=train_param['lr'])
early_stopper = EarlyStopMonitor(max_round=args.patience)
MODEL_SAVE_PATH = f'./saved_models/{args.model}-{args.dataname}.pth'
for e in range(train_param['epoch']):
torch.cuda.synchronize()
write_back_time = 0
fetch_time = 0
epoch_start_time = time.time()
train_aps = list()
print('Epoch {:d}:'.format(e))
time_prep = 0
total_loss = 0
model.train()
if mailbox is not None:
mailbox.reset()
model.module.memory_updater.last_updated_nid = None
model.module.memory_updater.last_updated_memory = None
model.module.memory_updater.last_updated_ts = None
for roots,mfgs,metadata,sample_time in trainloader:
fetch_time +=sample_time/1000
t_prep_s = time.time()
with torch.cuda.stream(train_stream):
optimizer.zero_grad()
pred_pos, pred_neg = model(mfgs,metadata)
loss = creterion(pred_pos, torch.ones_like(pred_pos))
loss += creterion(pred_neg, torch.zeros_like(pred_neg))
total_loss += float(loss)
loss.backward()
optimizer.step()
#torch.cuda.synchronize()
t_prep_s = time.time()
y_pred = torch.cat([pred_pos, pred_neg], dim=0).sigmoid().cpu()
y_true = torch.cat([torch.ones(pred_pos.size(0)), torch.zeros(pred_neg.size(0))], dim=0)
train_aps.append(average_precision_score(y_true, y_pred.detach().numpy()))
#start_event = torch.cuda.Event(enable_timing=True)
#end_event = torch.cuda.Event(enable_timing=True)
#start_event.record()
if mailbox is not None:
src = metadata['src_pos_index']
dst = metadata['dst_pos_index']
ts = roots.ts
if graph.edge_attr is None:
edge_feats = None
elif(graph.edge_attr.device == torch.device('cpu')):
edge_feats = graph.edge_attr[roots.eids.to('cpu')].to('cuda')
else:
edge_feats = graph.edge_attr[roots.eids]
dist_index_mapper = mfgs[0][0].srcdata['ID']
root_index = torch.cat((src,dst))
last_updated_nid = model.module.memory_updater.last_updated_nid[root_index]
last_updated_memory = model.module.memory_updater.last_updated_memory[root_index]
last_updated_ts=model.module.memory_updater.last_updated_ts[root_index]
index, memory, memory_ts = mailbox.get_update_memory(last_updated_nid,
last_updated_memory,
last_updated_ts)
index, mail, mail_ts = mailbox.get_update_mail(dist_index_mapper,
<<<<<<< HEAD
src,dst,ts,edge_feats,
model.module.memory_updater.last_updated_memory,
)
mailbox.set_mailbox_all_to_all(index,memory,memory_ts,mail,mail_ts,reduce_Op = 'max')
#end_event.record()
#torch.cuda.synchronize()
#write_back_time += start_event.elapsed_time(end_event)/1000
=======
src,dst,ts,edge_feats,
model.module.memory_updater.last_updated_memory,
model.module.embedding,use_src_emb,use_dst_emb,
)
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
start_event.record()
mailbox.set_mailbox_all_to_all(index,memory,memory_ts,mail,mail_ts,reduce_Op = 'max')
end_event.record()
torch.cuda.synchronize()
write_back_time += start_event.elapsed_time(end_event)/1000
>>>>>>> cmy_dev
torch.cuda.synchronize()
time_prep = time.time() - epoch_start_time
avg_time += time.time() - epoch_start_time
train_ap = float(torch.tensor(train_aps).mean())
ap = 0
auc = 0
#if cache.edge_cache is not None:
# print('hit {}'.format(cache.edge_cache.hit_/ cache.edge_cache.hit_sum))
#if cache.node_cache is not None:
# print('hit {}'.format(cache.node_cache.hit_/ cache.node_cache.hit_sum))
ap, auc = eval('val')
<<<<<<< HEAD
print('\ttrain loss:{:.4f} train ap:{:4f} val ap:{:4f} val auc:{:4f}'.format(total_loss,train_ap, ap, auc))
print('\ttotal time:{:.2f}s prep time:{:.2f}s'.format(time.time()-epoch_start_time, time_prep))
#print('\t fetch time:{:.2f}s write back time:{:.2f}s'.format(fetch_time,write_back_time))
=======
early_stop = early_stopper.early_stop_check(ap)
if early_stop:
print("Early stopping at epoch {:d}".format(e))
print(f"Loading the best model at epoch {early_stopper.best_epoch}")
best_model_path = get_checkpoint_path(early_stopper.best_epoch)
model.load_state_dict(torch.load(best_model_path))
break
else:
print('\ttrain loss:{:.4f} train ap:{:4f} val ap:{:4f} val auc:{:4f}'.format(total_loss,train_ap, ap, auc))
print('\ttotal time:{:.2f}s prep time:{:.2f}s'.format(time.time()-epoch_start_time, time_prep))
print('\t fetch time:{:.2f}s write back time:{:.2f}s'.format(fetch_time,write_back_time))
torch.save(model.state_dict(), get_checkpoint_path(e))
>>>>>>> cmy_dev
model.eval()
if mailbox is not None:
mailbox.reset()
model.module.memory_updater.last_updated_nid = None
eval('train')
eval('val')
ap, auc = eval('test')
eval_neg_samples = 1
if eval_neg_samples > 1:
print('\ttest AP:{:4f} test MRR:{:4f}'.format(ap, auc))
else:
print('\ttest AP:{:4f} test AUC:{:4f}'.format(ap, auc))
print('test_dataset',test_data.edges.shape[1],'avg_time',avg_time/train_param['epoch'])
torch.save(model.state_dict(), MODEL_SAVE_PATH)
ctx.shutdown()
if __name__ == "__main__":
main()
import argparse
import os
import sys
from os.path import abspath, join, dirname
from starrygl.distributed.context import DistributedContext
from starrygl.distributed.utils import DistIndex
from starrygl.module.modules import GeneralModel
from pathlib import Path
from starrygl.sample.cache.fetch_cache import FetchFeatureCache
from starrygl.module.utils import parse_config, EarlyStopMonitor
from starrygl.sample.graph_core import DataSet, DistributedGraphStore, TemporalNeighborSampleGraph
from starrygl.sample.memory.shared_mailbox import SharedMailBox
from starrygl.sample.sample_core.base import NegativeSampling
from starrygl.sample.sample_core.neighbor_sampler import NeighborSampler
from starrygl.sample.part_utils.partition_tgnn import partition_load
import torch
import time
import torch
import torch.nn.functional as F
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import init_process_group, destroy_process_group
import os
from starrygl.sample.data_loader import DistributedDataLoader
from starrygl.sample.batch_data import SAMPLE_TYPE
from starrygl.sample.stream_manager import getPipelineManger
parser = argparse.ArgumentParser(
description="RPC Reinforcement Learning Example",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument('--rank', default=0, type=int, metavar='W',
help='name of dataset')
parser.add_argument('--patience', type=int, default=5, help='Patience for early stopping')
parser.add_argument('--world_size', default=1, type=int, metavar='W',
help='number of negative samples')
parser.add_argument('--dataname', default=1, type=str, metavar='W',
help='name of dataset')
parser.add_argument('--model', default='TGN', type=str, metavar='W',
help='name of model')
args = parser.parse_args()
from sklearn.metrics import average_precision_score, roc_auc_score
import torch
import time
import random
import dgl
import numpy as np
from sklearn.metrics import average_precision_score, roc_auc_score
from torch.nn.parallel import DistributedDataParallel as DDP
#os.environ['CUDA_VISIBLE_DEVICES'] = str(args.rank)
#os.environ["RANK"] = str(args.rank)
#os.environ["WORLD_SIZE"] = str(args.world_size)
#os.environ["LOCAL_RANK"] = str(0)
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
os.environ["MASTER_ADDR"] = '10.214.211.187'
os.environ["MASTER_PORT"] = '9337'
def seed_everything(seed=42):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
seed_everything(1234)
def main():
print('main')
use_cuda = True
sample_param, memory_param, gnn_param, train_param = parse_config('./config/{}.yml'.format(args.model))
torch.set_num_threads(12)
ctx = DistributedContext.init(backend="nccl", use_gpu=True)
device_id = torch.cuda.current_device()
print('use cuda on',device_id)
pdata = partition_load("/mnt/data/part_data/dataset/here/{}".format(args.dataname), algo="metis_for_tgnn")
<<<<<<< HEAD
graph = DistributedGraphStore(pdata = pdata,uvm_edge = False,uvm_node = False)
=======
graph = DistributedGraphStore(pdata = pdata)
>>>>>>> cmy_dev
Path("./saved_models/").mkdir(parents=True, exist_ok=True)
Path("./saved_checkpoints/").mkdir(parents=True, exist_ok=True)
get_checkpoint_path = lambda \
epoch: f'./saved_checkpoints/{args.model}-{args.dataname}-{epoch}.pth'
gnn_param['dyrep'] = True if args.model == 'DyRep' else False
use_src_emb = gnn_param['use_src_emb'] if 'use_src_emb' in gnn_param else False
use_dst_emb = gnn_param['use_dst_emb'] if 'use_dst_emb' in gnn_param else False
sample_graph = TemporalNeighborSampleGraph(sample_graph = pdata.sample_graph,mode = 'full')
mailbox = SharedMailBox(pdata.ids.shape[0], memory_param, dim_edge_feat = pdata.edge_attr.shape[1] if pdata.edge_attr is not None else 0)
sampler = NeighborSampler(num_nodes=graph.num_nodes, num_layers=1, fanout=[10],graph_data=sample_graph, workers=15,policy = 'recent',graph_name = "wiki_train")
train_data = torch.masked_select(graph.edge_index,pdata.train_mask.to(graph.edge_index.device)).reshape(2,-1)
train_ts = torch.masked_select(graph.edge_ts,pdata.train_mask.to(graph.edge_index.device))
val_data = torch.masked_select(graph.edge_index,pdata.val_mask.to(graph.edge_index.device)).reshape(2,-1)
val_ts = torch.masked_select(graph.edge_ts,pdata.val_mask.to(graph.edge_index.device))
test_data = torch.masked_select(graph.edge_index,pdata.test_mask.to(graph.edge_index.device)).reshape(2,-1)
test_ts = torch.masked_select(graph.edge_ts,pdata.test_mask.to(graph.edge_index.device))
#print(train_data.shape[1],val_data.shape[1],test_data.shape[1])
train_data = DataSet(edges = train_data,ts =train_ts,eids = torch.nonzero(pdata.train_mask).view(-1))
#if dist.get_rank() == 0:
test_data = DataSet(edges = test_data,ts =test_ts,eids = torch.nonzero(pdata.test_mask).view(-1))
val_data = DataSet(edges = val_data,ts = val_ts,eids = torch.nonzero(pdata.val_mask).view(-1))
#else:
#test_data = torch.tensor([[],[]],device = graph.edge_index.device,detype = graph.edge_index.#dtype)
#val_data = torch.tensor([[],[]],device = graph.edge_index.device,detype = graph.edge_index.dtype)
#test_ts = torch.tensor([[],[]],device = graph.ts.device,detype = graph.ts.dtype)
#val_ts = torch.tensor([[],[]],device = graph.ts.device,detype = graph.ts.dtype)
#test_data = DataSet(edges = test_data,ts =test_ts,eids = torch.tensor([],dtype = torch.long,#device = torch.cuda))
#val_data = DataSet(edges = val_data,ts = val_ts,eids = torch.tensor([],dtype = torch.long,device #= torch.cuda))
#train_neg_sampler = PreNegativeSampling('triplet',torch.masked_select(pdata.edge_index['pos_edge'],graph.data.train_mask).reshape(2,-1))
neg_sampler = NegativeSampling('triplet')
trainloader = DistributedDataLoader(graph,train_data,sampler = sampler,
sampler_fn = SAMPLE_TYPE.SAMPLE_FROM_TEMPORAL_EDGES,
neg_sampler=neg_sampler,
batch_size = train_param['batch_size'],
shuffle=False,
drop_last=True,
chunk_size = None,
train=True,
queue_size = 1000,
mailbox = mailbox,
)
testloader = DistributedDataLoader(graph,test_data,sampler = sampler,
sampler_fn = SAMPLE_TYPE.SAMPLE_FROM_TEMPORAL_EDGES,
neg_sampler=neg_sampler,
batch_size = train_param['batch_size'],
shuffle=False,
drop_last=False,
chunk_size = None,
train=False,
queue_size = 100,
mailbox = mailbox)
valloader = DistributedDataLoader(graph,val_data,sampler = sampler,
sampler_fn = SAMPLE_TYPE.SAMPLE_FROM_TEMPORAL_EDGES,
neg_sampler=neg_sampler,
batch_size = train_param['batch_size'],
shuffle=False,
drop_last=False,
chunk_size = None,
train=False,
queue_size = 100,
mailbox = mailbox)
#FetchFeatureCache.create_fetch_cache(graph.num_nodes,graph.eids_mapper.shape[0],0.1,0.1,graph,mailbox,policy = 'static')
#cache = FetchFeatureCache.getFetchCache()
#cache.init_cache_with_presample(trainloader,3)
gnn_dim_node = 0 if graph.x is None else pdata.x.shape[1]
gnn_dim_edge = 0 if graph.edge_attr is None else pdata.edge_attr.shape[1]
print(gnn_dim_node,gnn_dim_edge)
avg_time = 0
if use_cuda:
model = GeneralModel(gnn_dim_node, gnn_dim_edge, sample_param, memory_param, gnn_param, train_param).cuda()
device = torch.device('cuda')
else:
model = GeneralModel(gnn_dim_node, gnn_dim_edge, sample_param, memory_param, gnn_param, train_param)
device = torch.device('cpu')
model = DDP(model,find_unused_parameters=True)
train_stream = torch.cuda.Stream()
send_stream = torch.cuda.Stream()
scatter_stream = torch.cuda.Stream()
val_losses = list()
def eval(mode='val'):
neg_samples = 1
model.eval()
aps = list()
aucs_mrrs = list()
if mode == 'val':
loader = valloader
elif mode == 'test':
loader = testloader
elif mode == 'train':
loader = trainloader
with torch.no_grad():
total_loss = 0
signal = torch.tensor([0],dtype = int,device = device)
<<<<<<< HEAD
with torch.cuda.stream(train_stream):
for roots,mfgs,metadata in loader:
pred_pos, pred_neg = model(mfgs,metadata)
total_loss += creterion(pred_pos, torch.ones_like(pred_pos))
total_loss += creterion(pred_neg, torch.zeros_like(pred_neg))
y_pred = torch.cat([pred_pos, pred_neg], dim=0).sigmoid().cpu()
y_true = torch.cat([torch.ones(pred_pos.size(0)), torch.zeros(pred_neg.size(0))], dim=0)
aps.append(average_precision_score(y_true, y_pred.detach().numpy()))
aucs_mrrs.append(roc_auc_score(y_true, y_pred))
if mailbox is not None:
src = metadata['src_pos_index']
dst = metadata['dst_pos_index']
ts = roots.ts
if graph.edge_attr is None:
edge_feats = None
elif(graph.edge_attr.device == torch.device('cpu')):
edge_feats = graph.edge_attr[roots.eids.to('cpu')].to('cuda')
else:
edge_feats = graph.edge_attr[roots.eids]
dist_index_mapper = mfgs[0][0].srcdata['ID']
root_index = torch.cat((src,dst))
last_updated_nid = model.module.memory_updater.last_updated_nid[root_index]
last_updated_memory = model.module.memory_updater.last_updated_memory[root_index]
last_updated_ts=model.module.memory_updater.last_updated_ts[root_index]
index, memory, memory_ts = mailbox.get_update_memory(last_updated_nid,
last_updated_memory,
last_updated_ts)
#
index, mail, mail_ts = mailbox.get_update_mail(dist_index_mapper,
src,dst,ts,edge_feats,
model.module.memory_updater.last_updated_memory,
)
mailbox.set_mailbox_all_to_all(index,memory,memory_ts,mail,mail_ts,reduce_Op = 'max')
=======
for roots,mfgs,metadata,sample_time in loader:
pred_pos, pred_neg = model(mfgs,metadata)
total_loss += creterion(pred_pos, torch.ones_like(pred_pos))
total_loss += creterion(pred_neg, torch.zeros_like(pred_neg))
y_pred = torch.cat([pred_pos, pred_neg], dim=0).sigmoid().cpu()
y_true = torch.cat([torch.ones(pred_pos.size(0)), torch.zeros(pred_neg.size(0))], dim=0)
aps.append(average_precision_score(y_true, y_pred.detach().numpy()))
aucs_mrrs.append(roc_auc_score(y_true, y_pred))
if mailbox is not None:
src = metadata['src_pos_index']
dst = metadata['dst_pos_index']
ts = roots.ts
if graph.edge_attr is None:
edge_feats = None
elif(graph.edge_attr.device == torch.device('cpu')):
edge_feats = graph.edge_attr[roots.eids.to('cpu')].to('cuda')
else:
edge_feats = graph.edge_attr[roots.eids]
dist_index_mapper = mfgs[0][0].srcdata['ID']
root_index = torch.cat((src,dst))
last_updated_nid = model.module.memory_updater.last_updated_nid[root_index]
last_updated_memory = model.module.memory_updater.last_updated_memory[root_index]
last_updated_ts=model.module.memory_updater.last_updated_ts[root_index]
index, memory, memory_ts = mailbox.get_update_memory(last_updated_nid,
last_updated_memory,
last_updated_ts)
index, mail, mail_ts = mailbox.get_update_mail(dist_index_mapper,
src,dst,ts,edge_feats,
model.module.memory_updater.last_updated_memory,
model.module.embedding,use_src_emb,use_dst_emb,
)
mailbox.set_mailbox_all_to_all(index,memory,memory_ts,mail,mail_ts,reduce_Op = 'max')
>>>>>>> cmy_dev
#ap = float(torch.tensor(aps).mean())
#if neg_samples > 1:
# auc_mrr = float(torch.cat(aucs_mrrs).mean())
#else:
# auc_mrr = float(torch.tensor(aucs_mrrs).mean())
world_size = dist.get_world_size()
apc = torch.empty([loader.expected_idx*world_size],dtype = torch.float,device='cuda')
auc_mrr = torch.empty([loader.expected_idx*world_size],dtype = torch.float,device = 'cuda')
dist.all_gather_into_tensor(apc,torch.tensor(aps,device ='cuda',dtype=torch.float))
dist.all_gather_into_tensor(auc_mrr,torch.tensor(aucs_mrrs,device ='cuda',dtype=torch.float))
ap = float(torch.tensor(apc).mean())
auc_mrr = float(torch.tensor(auc_mrr).mean())
return ap, auc_mrr
creterion = torch.nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=train_param['lr'])
early_stopper = EarlyStopMonitor(max_round=args.patience)
MODEL_SAVE_PATH = f'./saved_models/{args.model}-{args.dataname}.pth'
for e in range(train_param['epoch']):
torch.cuda.synchronize()
write_back_time = 0
fetch_time = 0
epoch_start_time = time.time()
train_aps = list()
print('Epoch {:d}:'.format(e))
time_prep = 0
total_loss = 0
model.train()
if mailbox is not None:
mailbox.reset()
model.module.memory_updater.last_updated_nid = None
model.module.memory_updater.last_updated_memory = None
model.module.memory_updater.last_updated_ts = None
for roots,mfgs,metadata,sample_time in trainloader:
fetch_time +=sample_time/1000
t_prep_s = time.time()
with torch.cuda.stream(train_stream):
optimizer.zero_grad()
pred_pos, pred_neg = model(mfgs,metadata)
loss = creterion(pred_pos, torch.ones_like(pred_pos))
loss += creterion(pred_neg, torch.zeros_like(pred_neg))
total_loss += float(loss)
loss.backward()
optimizer.step()
#torch.cuda.synchronize()
t_prep_s = time.time()
y_pred = torch.cat([pred_pos, pred_neg], dim=0).sigmoid().cpu()
y_true = torch.cat([torch.ones(pred_pos.size(0)), torch.zeros(pred_neg.size(0))], dim=0)
train_aps.append(average_precision_score(y_true, y_pred.detach().numpy()))
#start_event = torch.cuda.Event(enable_timing=True)
#end_event = torch.cuda.Event(enable_timing=True)
#start_event.record()
if mailbox is not None:
src = metadata['src_pos_index']
dst = metadata['dst_pos_index']
ts = roots.ts
if graph.edge_attr is None:
edge_feats = None
elif(graph.edge_attr.device == torch.device('cpu')):
edge_feats = graph.edge_attr[roots.eids.to('cpu')].to('cuda')
else:
edge_feats = graph.edge_attr[roots.eids]
dist_index_mapper = mfgs[0][0].srcdata['ID']
root_index = torch.cat((src,dst))
last_updated_nid = model.module.memory_updater.last_updated_nid[root_index]
last_updated_memory = model.module.memory_updater.last_updated_memory[root_index]
last_updated_ts=model.module.memory_updater.last_updated_ts[root_index]
index, memory, memory_ts = mailbox.get_update_memory(last_updated_nid,
last_updated_memory,
last_updated_ts)
index, mail, mail_ts = mailbox.get_update_mail(dist_index_mapper,
<<<<<<< HEAD
src,dst,ts,edge_feats,
model.module.memory_updater.last_updated_memory,
)
mailbox.set_mailbox_all_to_all(index,memory,memory_ts,mail,mail_ts,reduce_Op = 'max')
#end_event.record()
#torch.cuda.synchronize()
#write_back_time += start_event.elapsed_time(end_event)/1000
=======
src,dst,ts,edge_feats,
model.module.memory_updater.last_updated_memory,
model.module.embedding,use_src_emb,use_dst_emb,
)
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
start_event.record()
mailbox.set_mailbox_all_to_all(index,memory,memory_ts,mail,mail_ts,reduce_Op = 'max')
end_event.record()
torch.cuda.synchronize()
write_back_time += start_event.elapsed_time(end_event)/1000
>>>>>>> cmy_dev
torch.cuda.synchronize()
time_prep = time.time() - epoch_start_time
avg_time += time.time() - epoch_start_time
train_ap = float(torch.tensor(train_aps).mean())
ap = 0
auc = 0
#if cache.edge_cache is not None:
# print('hit {}'.format(cache.edge_cache.hit_/ cache.edge_cache.hit_sum))
#if cache.node_cache is not None:
# print('hit {}'.format(cache.node_cache.hit_/ cache.node_cache.hit_sum))
ap, auc = eval('val')
<<<<<<< HEAD
print('\ttrain loss:{:.4f} train ap:{:4f} val ap:{:4f} val auc:{:4f}'.format(total_loss,train_ap, ap, auc))
print('\ttotal time:{:.2f}s prep time:{:.2f}s'.format(time.time()-epoch_start_time, time_prep))
#print('\t fetch time:{:.2f}s write back time:{:.2f}s'.format(fetch_time,write_back_time))
=======
early_stop = early_stopper.early_stop_check(ap)
if early_stop:
print("Early stopping at epoch {:d}".format(e))
print(f"Loading the best model at epoch {early_stopper.best_epoch}")
best_model_path = get_checkpoint_path(early_stopper.best_epoch)
model.load_state_dict(torch.load(best_model_path))
break
else:
print('\ttrain loss:{:.4f} train ap:{:4f} val ap:{:4f} val auc:{:4f}'.format(total_loss,train_ap, ap, auc))
print('\ttotal time:{:.2f}s prep time:{:.2f}s'.format(time.time()-epoch_start_time, time_prep))
print('\t fetch time:{:.2f}s write back time:{:.2f}s'.format(fetch_time,write_back_time))
torch.save(model.state_dict(), get_checkpoint_path(e))
>>>>>>> cmy_dev
model.eval()
if mailbox is not None:
mailbox.reset()
model.module.memory_updater.last_updated_nid = None
eval('train')
eval('val')
ap, auc = eval('test')
eval_neg_samples = 1
if eval_neg_samples > 1:
print('\ttest AP:{:4f} test MRR:{:4f}'.format(ap, auc))
else:
print('\ttest AP:{:4f} test AUC:{:4f}'.format(ap, auc))
print('test_dataset',test_data.edges.shape[1],'avg_time',avg_time/train_param['epoch'])
torch.save(model.state_dict(), MODEL_SAVE_PATH)
ctx.shutdown()
if __name__ == "__main__":
main()
import argparse
import os
import sys
from os.path import abspath, join, dirname
from starrygl.distributed.context import DistributedContext
from starrygl.distributed.utils import DistIndex
from starrygl.module.modules import GeneralModel
from pathlib import Path
from starrygl.sample.cache.fetch_cache import FetchFeatureCache
from starrygl.module.utils import parse_config, EarlyStopMonitor
from starrygl.sample.graph_core import DataSet, DistributedGraphStore, TemporalNeighborSampleGraph
from starrygl.sample.memory.shared_mailbox import SharedMailBox
from starrygl.sample.sample_core.base import NegativeSampling
from starrygl.sample.sample_core.neighbor_sampler import NeighborSampler
from starrygl.sample.part_utils.partition_tgnn import partition_load
import torch
import time
import torch
import torch.nn.functional as F
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import init_process_group, destroy_process_group
import os
from starrygl.sample.data_loader import DistributedDataLoader
from starrygl.sample.batch_data import SAMPLE_TYPE
from starrygl.sample.stream_manager import getPipelineManger
parser = argparse.ArgumentParser(
description="RPC Reinforcement Learning Example",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument('--rank', default=0, type=int, metavar='W',
help='name of dataset')
parser.add_argument('--patience', type=int, default=5, help='Patience for early stopping')
parser.add_argument('--world_size', default=1, type=int, metavar='W',
help='number of negative samples')
parser.add_argument('--dataname', default=1, type=str, metavar='W',
help='name of dataset')
parser.add_argument('--model', default='TGN', type=str, metavar='W',
help='name of model')
args = parser.parse_args()
from sklearn.metrics import average_precision_score, roc_auc_score
import torch
import time
import random
import dgl
import numpy as np
from sklearn.metrics import average_precision_score, roc_auc_score
from torch.nn.parallel import DistributedDataParallel as DDP
#os.environ['CUDA_VISIBLE_DEVICES'] = str(args.rank)
#os.environ["RANK"] = str(args.rank)
#os.environ["WORLD_SIZE"] = str(args.world_size)
#os.environ["LOCAL_RANK"] = str(0)
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
os.environ["MASTER_ADDR"] = '10.214.211.187'
os.environ["MASTER_PORT"] = '9337'
def seed_everything(seed=42):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
seed_everything(1234)
def main():
print('main')
use_cuda = True
sample_param, memory_param, gnn_param, train_param = parse_config('./config/{}.yml'.format(args.model))
torch.set_num_threads(12)
ctx = DistributedContext.init(backend="nccl", use_gpu=True)
device_id = torch.cuda.current_device()
print('use cuda on',device_id)
pdata = partition_load("/mnt/data/part_data/dataset/here/{}".format(args.dataname), algo="metis_for_tgnn")
<<<<<<< HEAD
graph = DistributedGraphStore(pdata = pdata,uvm_edge = False,uvm_node = False)
=======
graph = DistributedGraphStore(pdata = pdata)
>>>>>>> cmy_dev
Path("./saved_models/").mkdir(parents=True, exist_ok=True)
Path("./saved_checkpoints/").mkdir(parents=True, exist_ok=True)
get_checkpoint_path = lambda \
epoch: f'./saved_checkpoints/{args.model}-{args.dataname}-{epoch}.pth'
gnn_param['dyrep'] = True if args.model == 'DyRep' else False
use_src_emb = gnn_param['use_src_emb'] if 'use_src_emb' in gnn_param else False
use_dst_emb = gnn_param['use_dst_emb'] if 'use_dst_emb' in gnn_param else False
sample_graph = TemporalNeighborSampleGraph(sample_graph = pdata.sample_graph,mode = 'full')
mailbox = SharedMailBox(pdata.ids.shape[0], memory_param, dim_edge_feat = pdata.edge_attr.shape[1] if pdata.edge_attr is not None else 0)
sampler = NeighborSampler(num_nodes=graph.num_nodes, num_layers=1, fanout=[10],graph_data=sample_graph, workers=15,policy = 'recent',graph_name = "wiki_train")
train_data = torch.masked_select(graph.edge_index,pdata.train_mask.to(graph.edge_index.device)).reshape(2,-1)
train_ts = torch.masked_select(graph.edge_ts,pdata.train_mask.to(graph.edge_index.device))
val_data = torch.masked_select(graph.edge_index,pdata.val_mask.to(graph.edge_index.device)).reshape(2,-1)
val_ts = torch.masked_select(graph.edge_ts,pdata.val_mask.to(graph.edge_index.device))
test_data = torch.masked_select(graph.edge_index,pdata.test_mask.to(graph.edge_index.device)).reshape(2,-1)
test_ts = torch.masked_select(graph.edge_ts,pdata.test_mask.to(graph.edge_index.device))
#print(train_data.shape[1],val_data.shape[1],test_data.shape[1])
train_data = DataSet(edges = train_data,ts =train_ts,eids = torch.nonzero(pdata.train_mask).view(-1))
#if dist.get_rank() == 0:
test_data = DataSet(edges = test_data,ts =test_ts,eids = torch.nonzero(pdata.test_mask).view(-1))
val_data = DataSet(edges = val_data,ts = val_ts,eids = torch.nonzero(pdata.val_mask).view(-1))
#else:
#test_data = torch.tensor([[],[]],device = graph.edge_index.device,detype = graph.edge_index.#dtype)
#val_data = torch.tensor([[],[]],device = graph.edge_index.device,detype = graph.edge_index.dtype)
#test_ts = torch.tensor([[],[]],device = graph.ts.device,detype = graph.ts.dtype)
#val_ts = torch.tensor([[],[]],device = graph.ts.device,detype = graph.ts.dtype)
#test_data = DataSet(edges = test_data,ts =test_ts,eids = torch.tensor([],dtype = torch.long,#device = torch.cuda))
#val_data = DataSet(edges = val_data,ts = val_ts,eids = torch.tensor([],dtype = torch.long,device #= torch.cuda))
#train_neg_sampler = PreNegativeSampling('triplet',torch.masked_select(pdata.edge_index['pos_edge'],graph.data.train_mask).reshape(2,-1))
neg_sampler = NegativeSampling('triplet')
trainloader = DistributedDataLoader(graph,train_data,sampler = sampler,
sampler_fn = SAMPLE_TYPE.SAMPLE_FROM_TEMPORAL_EDGES,
neg_sampler=neg_sampler,
batch_size = train_param['batch_size'],
shuffle=False,
drop_last=True,
chunk_size = None,
train=True,
queue_size = 1000,
mailbox = mailbox,
)
testloader = DistributedDataLoader(graph,test_data,sampler = sampler,
sampler_fn = SAMPLE_TYPE.SAMPLE_FROM_TEMPORAL_EDGES,
neg_sampler=neg_sampler,
batch_size = train_param['batch_size'],
shuffle=False,
drop_last=False,
chunk_size = None,
train=False,
queue_size = 100,
mailbox = mailbox)
valloader = DistributedDataLoader(graph,val_data,sampler = sampler,
sampler_fn = SAMPLE_TYPE.SAMPLE_FROM_TEMPORAL_EDGES,
neg_sampler=neg_sampler,
batch_size = train_param['batch_size'],
shuffle=False,
drop_last=False,
chunk_size = None,
train=False,
queue_size = 100,
mailbox = mailbox)
#FetchFeatureCache.create_fetch_cache(graph.num_nodes,graph.eids_mapper.shape[0],0.1,0.1,graph,mailbox,policy = 'static')
#cache = FetchFeatureCache.getFetchCache()
#cache.init_cache_with_presample(trainloader,3)
gnn_dim_node = 0 if graph.x is None else pdata.x.shape[1]
gnn_dim_edge = 0 if graph.edge_attr is None else pdata.edge_attr.shape[1]
print(gnn_dim_node,gnn_dim_edge)
avg_time = 0
if use_cuda:
model = GeneralModel(gnn_dim_node, gnn_dim_edge, sample_param, memory_param, gnn_param, train_param).cuda()
device = torch.device('cuda')
else:
model = GeneralModel(gnn_dim_node, gnn_dim_edge, sample_param, memory_param, gnn_param, train_param)
device = torch.device('cpu')
model = DDP(model,find_unused_parameters=True)
train_stream = torch.cuda.Stream()
send_stream = torch.cuda.Stream()
scatter_stream = torch.cuda.Stream()
val_losses = list()
def eval(mode='val'):
neg_samples = 1
model.eval()
aps = list()
aucs_mrrs = list()
if mode == 'val':
loader = valloader
elif mode == 'test':
loader = testloader
elif mode == 'train':
loader = trainloader
with torch.no_grad():
total_loss = 0
signal = torch.tensor([0],dtype = int,device = device)
<<<<<<< HEAD
with torch.cuda.stream(train_stream):
for roots,mfgs,metadata in loader:
pred_pos, pred_neg = model(mfgs,metadata)
total_loss += creterion(pred_pos, torch.ones_like(pred_pos))
total_loss += creterion(pred_neg, torch.zeros_like(pred_neg))
y_pred = torch.cat([pred_pos, pred_neg], dim=0).sigmoid().cpu()
y_true = torch.cat([torch.ones(pred_pos.size(0)), torch.zeros(pred_neg.size(0))], dim=0)
aps.append(average_precision_score(y_true, y_pred.detach().numpy()))
aucs_mrrs.append(roc_auc_score(y_true, y_pred))
if mailbox is not None:
src = metadata['src_pos_index']
dst = metadata['dst_pos_index']
ts = roots.ts
if graph.edge_attr is None:
edge_feats = None
elif(graph.edge_attr.device == torch.device('cpu')):
edge_feats = graph.edge_attr[roots.eids.to('cpu')].to('cuda')
else:
edge_feats = graph.edge_attr[roots.eids]
dist_index_mapper = mfgs[0][0].srcdata['ID']
root_index = torch.cat((src,dst))
last_updated_nid = model.module.memory_updater.last_updated_nid[root_index]
last_updated_memory = model.module.memory_updater.last_updated_memory[root_index]
last_updated_ts=model.module.memory_updater.last_updated_ts[root_index]
index, memory, memory_ts = mailbox.get_update_memory(last_updated_nid,
last_updated_memory,
last_updated_ts)
#
index, mail, mail_ts = mailbox.get_update_mail(dist_index_mapper,
src,dst,ts,edge_feats,
model.module.memory_updater.last_updated_memory,
)
mailbox.set_mailbox_all_to_all(index,memory,memory_ts,mail,mail_ts,reduce_Op = 'max')
=======
for roots,mfgs,metadata,sample_time in loader:
pred_pos, pred_neg = model(mfgs,metadata)
total_loss += creterion(pred_pos, torch.ones_like(pred_pos))
total_loss += creterion(pred_neg, torch.zeros_like(pred_neg))
y_pred = torch.cat([pred_pos, pred_neg], dim=0).sigmoid().cpu()
y_true = torch.cat([torch.ones(pred_pos.size(0)), torch.zeros(pred_neg.size(0))], dim=0)
aps.append(average_precision_score(y_true, y_pred.detach().numpy()))
aucs_mrrs.append(roc_auc_score(y_true, y_pred))
if mailbox is not None:
src = metadata['src_pos_index']
dst = metadata['dst_pos_index']
ts = roots.ts
if graph.edge_attr is None:
edge_feats = None
elif(graph.edge_attr.device == torch.device('cpu')):
edge_feats = graph.edge_attr[roots.eids.to('cpu')].to('cuda')
else:
edge_feats = graph.edge_attr[roots.eids]
dist_index_mapper = mfgs[0][0].srcdata['ID']
root_index = torch.cat((src,dst))
last_updated_nid = model.module.memory_updater.last_updated_nid[root_index]
last_updated_memory = model.module.memory_updater.last_updated_memory[root_index]
last_updated_ts=model.module.memory_updater.last_updated_ts[root_index]
index, memory, memory_ts = mailbox.get_update_memory(last_updated_nid,
last_updated_memory,
last_updated_ts)
index, mail, mail_ts = mailbox.get_update_mail(dist_index_mapper,
src,dst,ts,edge_feats,
model.module.memory_updater.last_updated_memory,
model.module.embedding,use_src_emb,use_dst_emb,
)
mailbox.set_mailbox_all_to_all(index,memory,memory_ts,mail,mail_ts,reduce_Op = 'max')
>>>>>>> cmy_dev
#ap = float(torch.tensor(aps).mean())
#if neg_samples > 1:
# auc_mrr = float(torch.cat(aucs_mrrs).mean())
#else:
# auc_mrr = float(torch.tensor(aucs_mrrs).mean())
world_size = dist.get_world_size()
apc = torch.empty([loader.expected_idx*world_size],dtype = torch.float,device='cuda')
auc_mrr = torch.empty([loader.expected_idx*world_size],dtype = torch.float,device = 'cuda')
dist.all_gather_into_tensor(apc,torch.tensor(aps,device ='cuda',dtype=torch.float))
dist.all_gather_into_tensor(auc_mrr,torch.tensor(aucs_mrrs,device ='cuda',dtype=torch.float))
ap = float(torch.tensor(apc).mean())
auc_mrr = float(torch.tensor(auc_mrr).mean())
return ap, auc_mrr
creterion = torch.nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=train_param['lr'])
early_stopper = EarlyStopMonitor(max_round=args.patience)
MODEL_SAVE_PATH = f'./saved_models/{args.model}-{args.dataname}.pth'
for e in range(train_param['epoch']):
torch.cuda.synchronize()
write_back_time = 0
fetch_time = 0
epoch_start_time = time.time()
train_aps = list()
print('Epoch {:d}:'.format(e))
time_prep = 0
total_loss = 0
model.train()
if mailbox is not None:
mailbox.reset()
model.module.memory_updater.last_updated_nid = None
model.module.memory_updater.last_updated_memory = None
model.module.memory_updater.last_updated_ts = None
for roots,mfgs,metadata,sample_time in trainloader:
fetch_time +=sample_time/1000
t_prep_s = time.time()
with torch.cuda.stream(train_stream):
optimizer.zero_grad()
pred_pos, pred_neg = model(mfgs,metadata)
loss = creterion(pred_pos, torch.ones_like(pred_pos))
loss += creterion(pred_neg, torch.zeros_like(pred_neg))
total_loss += float(loss)
loss.backward()
optimizer.step()
#torch.cuda.synchronize()
t_prep_s = time.time()
y_pred = torch.cat([pred_pos, pred_neg], dim=0).sigmoid().cpu()
y_true = torch.cat([torch.ones(pred_pos.size(0)), torch.zeros(pred_neg.size(0))], dim=0)
train_aps.append(average_precision_score(y_true, y_pred.detach().numpy()))
#start_event = torch.cuda.Event(enable_timing=True)
#end_event = torch.cuda.Event(enable_timing=True)
#start_event.record()
if mailbox is not None:
src = metadata['src_pos_index']
dst = metadata['dst_pos_index']
ts = roots.ts
if graph.edge_attr is None:
edge_feats = None
elif(graph.edge_attr.device == torch.device('cpu')):
edge_feats = graph.edge_attr[roots.eids.to('cpu')].to('cuda')
else:
edge_feats = graph.edge_attr[roots.eids]
dist_index_mapper = mfgs[0][0].srcdata['ID']
root_index = torch.cat((src,dst))
last_updated_nid = model.module.memory_updater.last_updated_nid[root_index]
last_updated_memory = model.module.memory_updater.last_updated_memory[root_index]
last_updated_ts=model.module.memory_updater.last_updated_ts[root_index]
index, memory, memory_ts = mailbox.get_update_memory(last_updated_nid,
last_updated_memory,
last_updated_ts)
index, mail, mail_ts = mailbox.get_update_mail(dist_index_mapper,
<<<<<<< HEAD
src,dst,ts,edge_feats,
model.module.memory_updater.last_updated_memory,
)
mailbox.set_mailbox_all_to_all(index,memory,memory_ts,mail,mail_ts,reduce_Op = 'max')
#end_event.record()
#torch.cuda.synchronize()
#write_back_time += start_event.elapsed_time(end_event)/1000
=======
src,dst,ts,edge_feats,
model.module.memory_updater.last_updated_memory,
model.module.embedding,use_src_emb,use_dst_emb,
)
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
start_event.record()
mailbox.set_mailbox_all_to_all(index,memory,memory_ts,mail,mail_ts,reduce_Op = 'max')
end_event.record()
torch.cuda.synchronize()
write_back_time += start_event.elapsed_time(end_event)/1000
>>>>>>> cmy_dev
torch.cuda.synchronize()
time_prep = time.time() - epoch_start_time
avg_time += time.time() - epoch_start_time
train_ap = float(torch.tensor(train_aps).mean())
ap = 0
auc = 0
#if cache.edge_cache is not None:
# print('hit {}'.format(cache.edge_cache.hit_/ cache.edge_cache.hit_sum))
#if cache.node_cache is not None:
# print('hit {}'.format(cache.node_cache.hit_/ cache.node_cache.hit_sum))
ap, auc = eval('val')
<<<<<<< HEAD
print('\ttrain loss:{:.4f} train ap:{:4f} val ap:{:4f} val auc:{:4f}'.format(total_loss,train_ap, ap, auc))
print('\ttotal time:{:.2f}s prep time:{:.2f}s'.format(time.time()-epoch_start_time, time_prep))
#print('\t fetch time:{:.2f}s write back time:{:.2f}s'.format(fetch_time,write_back_time))
=======
early_stop = early_stopper.early_stop_check(ap)
if early_stop:
print("Early stopping at epoch {:d}".format(e))
print(f"Loading the best model at epoch {early_stopper.best_epoch}")
best_model_path = get_checkpoint_path(early_stopper.best_epoch)
model.load_state_dict(torch.load(best_model_path))
break
else:
print('\ttrain loss:{:.4f} train ap:{:4f} val ap:{:4f} val auc:{:4f}'.format(total_loss,train_ap, ap, auc))
print('\ttotal time:{:.2f}s prep time:{:.2f}s'.format(time.time()-epoch_start_time, time_prep))
print('\t fetch time:{:.2f}s write back time:{:.2f}s'.format(fetch_time,write_back_time))
torch.save(model.state_dict(), get_checkpoint_path(e))
>>>>>>> cmy_dev
model.eval()
if mailbox is not None:
mailbox.reset()
model.module.memory_updater.last_updated_nid = None
eval('train')
eval('val')
ap, auc = eval('test')
eval_neg_samples = 1
if eval_neg_samples > 1:
print('\ttest AP:{:4f} test MRR:{:4f}'.format(ap, auc))
else:
print('\ttest AP:{:4f} test AUC:{:4f}'.format(ap, auc))
print('test_dataset',test_data.edges.shape[1],'avg_time',avg_time/train_param['epoch'])
torch.save(model.state_dict(), MODEL_SAVE_PATH)
ctx.shutdown()
if __name__ == "__main__":
main()
import argparse
import os
import sys
from os.path import abspath, join, dirname
from starrygl.distributed.context import DistributedContext
from starrygl.distributed.utils import DistIndex
from starrygl.module.modules import GeneralModel
from pathlib import Path
from starrygl.module.utils import parse_config
from starrygl.sample.cache.fetch_cache import FetchFeatureCache
from starrygl.sample.graph_core import DataSet, DistributedGraphStore, TemporalNeighborSampleGraph
from starrygl.module.utils import parse_config, EarlyStopMonitor
from starrygl.sample.graph_core import DataSet, DistributedGraphStore, TemporalNeighborSampleGraph
from starrygl.sample.memory.shared_mailbox import SharedMailBox
from starrygl.sample.sample_core.base import NegativeSampling
from starrygl.sample.sample_core.neighbor_sampler import NeighborSampler
from starrygl.sample.part_utils.partition_tgnn import partition_load
import torch
import time
import torch
import torch.nn.functional as F
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import init_process_group, destroy_process_group
import os
from starrygl.sample.data_loader import DistributedDataLoader
from starrygl.sample.batch_data import SAMPLE_TYPE
from starrygl.sample.stream_manager import getPipelineManger
parser = argparse.ArgumentParser(
description="RPC Reinforcement Learning Example",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument('--rank', default=0, type=int, metavar='W',
help='name of dataset')
parser.add_argument('--patience', type=int, default=5, help='Patience for early stopping')
parser.add_argument('--world_size', default=1, type=int, metavar='W',
help='number of negative samples')
parser.add_argument('--dataname', default=1, type=str, metavar='W',
help='name of dataset')
parser.add_argument('--model', default='TGN', type=str, metavar='W',
help='name of model')
args = parser.parse_args()
from sklearn.metrics import average_precision_score, roc_auc_score
import torch
import time
import random
import dgl
import numpy as np
from sklearn.metrics import average_precision_score, roc_auc_score
from torch.nn.parallel import DistributedDataParallel as DDP
#os.environ['CUDA_VISIBLE_DEVICES'] = str(args.rank)
#os.environ["RANK"] = str(args.rank)
#os.environ["WORLD_SIZE"] = str(args.world_size)
#os.environ["LOCAL_RANK"] = str(0)
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
os.environ["MASTER_ADDR"] = '10.214.211.187'
os.environ["MASTER_PORT"] = '9337'
def seed_everything(seed=42):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
seed_everything(1234)
def main():
print('main')
use_cuda = True
sample_param, memory_param, gnn_param, train_param = parse_config('./config/{}.yml'.format(args.model))
torch.set_num_threads(12)
ctx = DistributedContext.init(backend="nccl", use_gpu=True)
device_id = torch.cuda.current_device()
print('use cuda on',device_id)
pdata = partition_load("/mnt/data/part_data/dataset/here/{}".format(args.dataname), algo="metis_for_tgnn")
graph = DistributedGraphStore(pdata = pdata,uvm_edge = False,uvm_node = False)
sample_graph = TemporalNeighborSampleGraph(sample_graph = pdata.sample_graph,mode = 'full')
mailbox = SharedMailBox(pdata.ids.shape[0], memory_param, dim_edge_feat = pdata.edge_attr.shape[1] if pdata.edge_attr is not None else 0)
sampler = NeighborSampler(num_nodes=graph.num_nodes, num_layers=1, fanout=[10],graph_data=sample_graph, workers=15,policy = 'recent',graph_name = "wiki_train")
train_data = torch.masked_select(graph.edge_index,pdata.train_mask.to(graph.edge_index.device)).reshape(2,-1)
train_ts = torch.masked_select(graph.edge_ts,pdata.train_mask.to(graph.edge_index.device))
val_data = torch.masked_select(graph.edge_index,pdata.val_mask.to(graph.edge_index.device)).reshape(2,-1)
val_ts = torch.masked_select(graph.edge_ts,pdata.val_mask.to(graph.edge_index.device))
test_data = torch.masked_select(graph.edge_index,pdata.test_mask.to(graph.edge_index.device)).reshape(2,-1)
test_ts = torch.masked_select(graph.edge_ts,pdata.test_mask.to(graph.edge_index.device))
#print(train_data.shape[1],val_data.shape[1],test_data.shape[1])
train_data = DataSet(edges = train_data,ts =train_ts,eids = torch.nonzero(pdata.train_mask).view(-1))
#if dist.get_rank() == 0:
test_data = DataSet(edges = test_data,ts =test_ts,eids = torch.nonzero(pdata.test_mask).view(-1))
val_data = DataSet(edges = val_data,ts = val_ts,eids = torch.nonzero(pdata.val_mask).view(-1))
#else:
#test_data = torch.tensor([[],[]],device = graph.edge_index.device,detype = graph.edge_index.#dtype)
#val_data = torch.tensor([[],[]],device = graph.edge_index.device,detype = graph.edge_index.dtype)
#test_ts = torch.tensor([[],[]],device = graph.ts.device,detype = graph.ts.dtype)
#val_ts = torch.tensor([[],[]],device = graph.ts.device,detype = graph.ts.dtype)
#test_data = DataSet(edges = test_data,ts =test_ts,eids = torch.tensor([],dtype = torch.long,#device = torch.cuda))
#val_data = DataSet(edges = val_data,ts = val_ts,eids = torch.tensor([],dtype = torch.long,device #= torch.cuda))
#train_neg_sampler = PreNegativeSampling('triplet',torch.masked_select(pdata.edge_index['pos_edge'],graph.data.train_mask).reshape(2,-1))
neg_sampler = NegativeSampling('triplet')
trainloader = DistributedDataLoader(graph,train_data,sampler = sampler,
sampler_fn = SAMPLE_TYPE.SAMPLE_FROM_TEMPORAL_EDGES,
neg_sampler=neg_sampler,
batch_size = train_param['batch_size'],
shuffle=False,
drop_last=True,
chunk_size = None,
train=True,
queue_size = 1000,
mailbox = mailbox,
)
testloader = DistributedDataLoader(graph,test_data,sampler = sampler,
sampler_fn = SAMPLE_TYPE.SAMPLE_FROM_TEMPORAL_EDGES,
neg_sampler=neg_sampler,
batch_size = train_param['batch_size'],
shuffle=False,
drop_last=False,
chunk_size = None,
train=False,
queue_size = 100,
mailbox = mailbox)
valloader = DistributedDataLoader(graph,val_data,sampler = sampler,
sampler_fn = SAMPLE_TYPE.SAMPLE_FROM_TEMPORAL_EDGES,
neg_sampler=neg_sampler,
batch_size = train_param['batch_size'],
shuffle=False,
drop_last=False,
chunk_size = None,
train=False,
queue_size = 100,
mailbox = mailbox)
#FetchFeatureCache.create_fetch_cache(graph.num_nodes,graph.eids_mapper.shape[0],0.1,0.1,graph,mailbox,policy = 'static')
#cache = FetchFeatureCache.getFetchCache()
#cache.init_cache_with_presample(trainloader,3)
gnn_dim_node = 0 if graph.x is None else pdata.x.shape[1]
gnn_dim_edge = 0 if graph.edge_attr is None else pdata.edge_attr.shape[1]
print(gnn_dim_node,gnn_dim_edge)
avg_time = 0
if use_cuda:
model = GeneralModel(gnn_dim_node, gnn_dim_edge, sample_param, memory_param, gnn_param, train_param).cuda()
device = torch.device('cuda')
else:
model = GeneralModel(gnn_dim_node, gnn_dim_edge, sample_param, memory_param, gnn_param, train_param)
device = torch.device('cpu')
model = DDP(model,find_unused_parameters=True)
train_stream = torch.cuda.Stream()
send_stream = torch.cuda.Stream()
scatter_stream = torch.cuda.Stream()
val_losses = list()
def eval(mode='val'):
neg_samples = 1
model.eval()
aps = list()
aucs_mrrs = list()
if mode == 'val':
loader = valloader
elif mode == 'test':
loader = testloader
elif mode == 'train':
loader = trainloader
with torch.no_grad():
total_loss = 0
signal = torch.tensor([0],dtype = int,device = device)
with torch.cuda.stream(train_stream):
for roots,mfgs,metadata in loader:
pred_pos, pred_neg = model(mfgs,metadata)
total_loss += creterion(pred_pos, torch.ones_like(pred_pos))
total_loss += creterion(pred_neg, torch.zeros_like(pred_neg))
y_pred = torch.cat([pred_pos, pred_neg], dim=0).sigmoid().cpu()
y_true = torch.cat([torch.ones(pred_pos.size(0)), torch.zeros(pred_neg.size(0))], dim=0)
aps.append(average_precision_score(y_true, y_pred.detach().numpy()))
aucs_mrrs.append(roc_auc_score(y_true, y_pred))
if mailbox is not None:
src = metadata['src_pos_index']
dst = metadata['dst_pos_index']
ts = roots.ts
if graph.edge_attr is None:
edge_feats = None
elif(graph.edge_attr.device == torch.device('cpu')):
edge_feats = graph.edge_attr[roots.eids.to('cpu')].to('cuda')
else:
edge_feats = graph.edge_attr[roots.eids]
dist_index_mapper = mfgs[0][0].srcdata['ID']
root_index = torch.cat((src,dst))
last_updated_nid = model.module.memory_updater.last_updated_nid[root_index]
last_updated_memory = model.module.memory_updater.last_updated_memory[root_index]
last_updated_ts=model.module.memory_updater.last_updated_ts[root_index]
index, memory, memory_ts = mailbox.get_update_memory(last_updated_nid,
last_updated_memory,
last_updated_ts)
#
index, mail, mail_ts = mailbox.get_update_mail(dist_index_mapper,
src,dst,ts,edge_feats,
model.module.memory_updater.last_updated_memory,
model.module.embedding,use_src_emb,
use_dst_emb,
)
mailbox.set_mailbox_all_to_all(index,memory,memory_ts,mail,mail_ts,reduce_Op = 'max')
#ap = float(torch.tensor(aps).mean())
#if neg_samples > 1:
# auc_mrr = float(torch.cat(aucs_mrrs).mean())
#else:
# auc_mrr = float(torch.tensor(aucs_mrrs).mean())
world_size = dist.get_world_size()
apc = torch.empty([loader.expected_idx*world_size],dtype = torch.float,device='cuda')
auc_mrr = torch.empty([loader.expected_idx*world_size],dtype = torch.float,device = 'cuda')
dist.all_gather_into_tensor(apc,torch.tensor(aps,device ='cuda',dtype=torch.float))
dist.all_gather_into_tensor(auc_mrr,torch.tensor(aucs_mrrs,device ='cuda',dtype=torch.float))
ap = float(torch.tensor(apc).mean())
auc_mrr = float(torch.tensor(auc_mrr).mean())
return ap, auc_mrr
creterion = torch.nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=train_param['lr'])
early_stopper = EarlyStopMonitor(max_round=args.patience)
MODEL_SAVE_PATH = f'./saved_models/{args.model}-{args.dataname}.pth'
for e in range(train_param['epoch']):
torch.cuda.synchronize()
write_back_time = 0
fetch_time = 0
epoch_start_time = time.time()
train_aps = list()
print('Epoch {:d}:'.format(e))
time_prep = 0
total_loss = 0
model.train()
if mailbox is not None:
mailbox.reset()
model.module.memory_updater.last_updated_nid = None
model.module.memory_updater.last_updated_memory = None
model.module.memory_updater.last_updated_ts = None
for roots,mfgs,metadata,sample_time in trainloader:
fetch_time +=sample_time/1000
t_prep_s = time.time()
with torch.cuda.stream(train_stream):
optimizer.zero_grad()
pred_pos, pred_neg = model(mfgs,metadata)
loss = creterion(pred_pos, torch.ones_like(pred_pos))
loss += creterion(pred_neg, torch.zeros_like(pred_neg))
total_loss += float(loss)
loss.backward()
optimizer.step()
#torch.cuda.synchronize()
t_prep_s = time.time()
y_pred = torch.cat([pred_pos, pred_neg], dim=0).sigmoid().cpu()
y_true = torch.cat([torch.ones(pred_pos.size(0)), torch.zeros(pred_neg.size(0))], dim=0)
train_aps.append(average_precision_score(y_true, y_pred.detach().numpy()))
#start_event = torch.cuda.Event(enable_timing=True)
#end_event = torch.cuda.Event(enable_timing=True)
#start_event.record()
if mailbox is not None:
src = metadata['src_pos_index']
dst = metadata['dst_pos_index']
ts = roots.ts
if graph.edge_attr is None:
edge_feats = None
elif(graph.edge_attr.device == torch.device('cpu')):
edge_feats = graph.edge_attr[roots.eids.to('cpu')].to('cuda')
else:
edge_feats = graph.edge_attr[roots.eids]
dist_index_mapper = mfgs[0][0].srcdata['ID']
root_index = torch.cat((src,dst))
last_updated_nid = model.module.memory_updater.last_updated_nid[root_index]
last_updated_memory = model.module.memory_updater.last_updated_memory[root_index]
last_updated_ts=model.module.memory_updater.last_updated_ts[root_index]
index, memory, memory_ts = mailbox.get_update_memory(last_updated_nid,
last_updated_memory,
last_updated_ts)
index, mail, mail_ts = mailbox.get_update_mail(dist_index_mapper,
src,dst,ts,edge_feats,
model.module.memory_updater.last_updated_memory,
)
mailbox.set_mailbox_all_to_all(index,memory,memory_ts,mail,mail_ts,reduce_Op = 'max')
torch.cuda.synchronize()
time_prep = time.time() - epoch_start_time
avg_time += time.time() - epoch_start_time
train_ap = float(torch.tensor(train_aps).mean())
ap = 0
auc = 0
#if cache.edge_cache is not None:
# print('hit {}'.format(cache.edge_cache.hit_/ cache.edge_cache.hit_sum))
#if cache.node_cache is not None:
# print('hit {}'.format(cache.node_cache.hit_/ cache.node_cache.hit_sum))
ap, auc = eval('val')
print('\ttrain loss:{:.4f} train ap:{:4f} val ap:{:4f} val auc:{:4f}'.format(total_loss,train_ap, ap, auc))
print('\ttotal time:{:.2f}s prep time:{:.2f}s'.format(time.time()-epoch_start_time, time_prep))
model.eval()
if mailbox is not None:
mailbox.reset()
model.module.memory_updater.last_updated_nid = None
eval('train')
eval('val')
ap, auc = eval('test')
eval_neg_samples = 1
if eval_neg_samples > 1:
print('\ttest AP:{:4f} test MRR:{:4f}'.format(ap, auc))
else:
print('\ttest AP:{:4f} test AUC:{:4f}'.format(ap, auc))
print('test_dataset',test_data.edges.shape[1],'avg_time',avg_time/train_param['epoch'])
torch.save(model.state_dict(), MODEL_SAVE_PATH)
ctx.shutdown()
if __name__ == "__main__":
main()
import argparse
import os
import sys
from os.path import abspath, join, dirname
from starrygl.distributed.context import DistributedContext
from starrygl.distributed.utils import DistIndex
from starrygl.module.modules import GeneralModel
from pathlib import Path
from starrygl.module.utils import parse_config
from starrygl.sample.cache.fetch_cache import FetchFeatureCache
from starrygl.sample.graph_core import DataSet, DistributedGraphStore, TemporalNeighborSampleGraph
from starrygl.module.utils import parse_config, EarlyStopMonitor
from starrygl.sample.graph_core import DataSet, DistributedGraphStore, TemporalNeighborSampleGraph
from starrygl.sample.memory.shared_mailbox import SharedMailBox
from starrygl.sample.sample_core.base import NegativeSampling
from starrygl.sample.sample_core.neighbor_sampler import NeighborSampler
from starrygl.sample.part_utils.partition_tgnn import partition_load
import torch
import time
import torch
import torch.nn.functional as F
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import init_process_group, destroy_process_group
import os
from starrygl.sample.data_loader import DistributedDataLoader
from starrygl.sample.batch_data import SAMPLE_TYPE
from starrygl.sample.stream_manager import getPipelineManger
parser = argparse.ArgumentParser(
description="RPC Reinforcement Learning Example",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument('--rank', default=0, type=int, metavar='W',
help='name of dataset')
parser.add_argument('--patience', type=int, default=5, help='Patience for early stopping')
parser.add_argument('--world_size', default=1, type=int, metavar='W',
help='number of negative samples')
parser.add_argument('--dataname', default=1, type=str, metavar='W',
help='name of dataset')
parser.add_argument('--model', default='TGN', type=str, metavar='W',
help='name of model')
args = parser.parse_args()
from sklearn.metrics import average_precision_score, roc_auc_score
import torch
import time
import random
import dgl
import numpy as np
from sklearn.metrics import average_precision_score, roc_auc_score
from torch.nn.parallel import DistributedDataParallel as DDP
#os.environ['CUDA_VISIBLE_DEVICES'] = str(args.rank)
#os.environ["RANK"] = str(args.rank)
#os.environ["WORLD_SIZE"] = str(args.world_size)
#os.environ["LOCAL_RANK"] = str(0)
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
os.environ["MASTER_ADDR"] = '10.214.211.187'
os.environ["MASTER_PORT"] = '9337'
def seed_everything(seed=42):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
seed_everything(1234)
def main():
print('main')
use_cuda = True
sample_param, memory_param, gnn_param, train_param = parse_config('./config/{}.yml'.format(args.model))
torch.set_num_threads(12)
ctx = DistributedContext.init(backend="nccl", use_gpu=True)
device_id = torch.cuda.current_device()
print('use cuda on',device_id)
pdata = partition_load("/mnt/data/part_data/dataset/here/{}".format(args.dataname), algo="metis_for_tgnn")
graph = DistributedGraphStore(pdata = pdata,uvm_edge = False,uvm_node = False)
sample_graph = TemporalNeighborSampleGraph(sample_graph = pdata.sample_graph,mode = 'full')
mailbox = SharedMailBox(pdata.ids.shape[0], memory_param, dim_edge_feat = pdata.edge_attr.shape[1] if pdata.edge_attr is not None else 0)
sampler = NeighborSampler(num_nodes=graph.num_nodes, num_layers=1, fanout=[10],graph_data=sample_graph, workers=15,policy = 'recent',graph_name = "wiki_train")
train_data = torch.masked_select(graph.edge_index,pdata.train_mask.to(graph.edge_index.device)).reshape(2,-1)
train_ts = torch.masked_select(graph.edge_ts,pdata.train_mask.to(graph.edge_index.device))
val_data = torch.masked_select(graph.edge_index,pdata.val_mask.to(graph.edge_index.device)).reshape(2,-1)
val_ts = torch.masked_select(graph.edge_ts,pdata.val_mask.to(graph.edge_index.device))
test_data = torch.masked_select(graph.edge_index,pdata.test_mask.to(graph.edge_index.device)).reshape(2,-1)
test_ts = torch.masked_select(graph.edge_ts,pdata.test_mask.to(graph.edge_index.device))
#print(train_data.shape[1],val_data.shape[1],test_data.shape[1])
train_data = DataSet(edges = train_data,ts =train_ts,eids = torch.nonzero(pdata.train_mask).view(-1))
#if dist.get_rank() == 0:
test_data = DataSet(edges = test_data,ts =test_ts,eids = torch.nonzero(pdata.test_mask).view(-1))
val_data = DataSet(edges = val_data,ts = val_ts,eids = torch.nonzero(pdata.val_mask).view(-1))
#else:
#test_data = torch.tensor([[],[]],device = graph.edge_index.device,detype = graph.edge_index.#dtype)
#val_data = torch.tensor([[],[]],device = graph.edge_index.device,detype = graph.edge_index.dtype)
#test_ts = torch.tensor([[],[]],device = graph.ts.device,detype = graph.ts.dtype)
#val_ts = torch.tensor([[],[]],device = graph.ts.device,detype = graph.ts.dtype)
#test_data = DataSet(edges = test_data,ts =test_ts,eids = torch.tensor([],dtype = torch.long,#device = torch.cuda))
#val_data = DataSet(edges = val_data,ts = val_ts,eids = torch.tensor([],dtype = torch.long,device #= torch.cuda))
#train_neg_sampler = PreNegativeSampling('triplet',torch.masked_select(pdata.edge_index['pos_edge'],graph.data.train_mask).reshape(2,-1))
neg_sampler = NegativeSampling('triplet')
trainloader = DistributedDataLoader(graph,train_data,sampler = sampler,
sampler_fn = SAMPLE_TYPE.SAMPLE_FROM_TEMPORAL_EDGES,
neg_sampler=neg_sampler,
batch_size = train_param['batch_size'],
shuffle=False,
drop_last=True,
chunk_size = None,
train=True,
queue_size = 1000,
mailbox = mailbox,
)
testloader = DistributedDataLoader(graph,test_data,sampler = sampler,
sampler_fn = SAMPLE_TYPE.SAMPLE_FROM_TEMPORAL_EDGES,
neg_sampler=neg_sampler,
batch_size = train_param['batch_size'],
shuffle=False,
drop_last=False,
chunk_size = None,
train=False,
queue_size = 100,
mailbox = mailbox)
valloader = DistributedDataLoader(graph,val_data,sampler = sampler,
sampler_fn = SAMPLE_TYPE.SAMPLE_FROM_TEMPORAL_EDGES,
neg_sampler=neg_sampler,
batch_size = train_param['batch_size'],
shuffle=False,
drop_last=False,
chunk_size = None,
train=False,
queue_size = 100,
mailbox = mailbox)
#FetchFeatureCache.create_fetch_cache(graph.num_nodes,graph.eids_mapper.shape[0],0.1,0.1,graph,mailbox,policy = 'static')
#cache = FetchFeatureCache.getFetchCache()
#cache.init_cache_with_presample(trainloader,3)
gnn_dim_node = 0 if graph.x is None else pdata.x.shape[1]
gnn_dim_edge = 0 if graph.edge_attr is None else pdata.edge_attr.shape[1]
print(gnn_dim_node,gnn_dim_edge)
avg_time = 0
if use_cuda:
model = GeneralModel(gnn_dim_node, gnn_dim_edge, sample_param, memory_param, gnn_param, train_param).cuda()
device = torch.device('cuda')
else:
model = GeneralModel(gnn_dim_node, gnn_dim_edge, sample_param, memory_param, gnn_param, train_param)
device = torch.device('cpu')
model = DDP(model,find_unused_parameters=True)
train_stream = torch.cuda.Stream()
send_stream = torch.cuda.Stream()
scatter_stream = torch.cuda.Stream()
val_losses = list()
def eval(mode='val'):
neg_samples = 1
model.eval()
aps = list()
aucs_mrrs = list()
if mode == 'val':
loader = valloader
elif mode == 'test':
loader = testloader
elif mode == 'train':
loader = trainloader
with torch.no_grad():
total_loss = 0
signal = torch.tensor([0],dtype = int,device = device)
with torch.cuda.stream(train_stream):
for roots,mfgs,metadata in loader:
pred_pos, pred_neg = model(mfgs,metadata)
total_loss += creterion(pred_pos, torch.ones_like(pred_pos))
total_loss += creterion(pred_neg, torch.zeros_like(pred_neg))
y_pred = torch.cat([pred_pos, pred_neg], dim=0).sigmoid().cpu()
y_true = torch.cat([torch.ones(pred_pos.size(0)), torch.zeros(pred_neg.size(0))], dim=0)
aps.append(average_precision_score(y_true, y_pred.detach().numpy()))
aucs_mrrs.append(roc_auc_score(y_true, y_pred))
if mailbox is not None:
src = metadata['src_pos_index']
dst = metadata['dst_pos_index']
ts = roots.ts
if graph.edge_attr is None:
edge_feats = None
elif(graph.edge_attr.device == torch.device('cpu')):
edge_feats = graph.edge_attr[roots.eids.to('cpu')].to('cuda')
else:
edge_feats = graph.edge_attr[roots.eids]
dist_index_mapper = mfgs[0][0].srcdata['ID']
root_index = torch.cat((src,dst))
last_updated_nid = model.module.memory_updater.last_updated_nid[root_index]
last_updated_memory = model.module.memory_updater.last_updated_memory[root_index]
last_updated_ts=model.module.memory_updater.last_updated_ts[root_index]
index, memory, memory_ts = mailbox.get_update_memory(last_updated_nid,
last_updated_memory,
last_updated_ts)
#
index, mail, mail_ts = mailbox.get_update_mail(dist_index_mapper,
src,dst,ts,edge_feats,
model.module.memory_updater.last_updated_memory,
model.module.embedding,use_src_emb,
use_dst_emb,
)
mailbox.set_mailbox_all_to_all(index,memory,memory_ts,mail,mail_ts,reduce_Op = 'max')
#ap = float(torch.tensor(aps).mean())
#if neg_samples > 1:
# auc_mrr = float(torch.cat(aucs_mrrs).mean())
#else:
# auc_mrr = float(torch.tensor(aucs_mrrs).mean())
world_size = dist.get_world_size()
apc = torch.empty([loader.expected_idx*world_size],dtype = torch.float,device='cuda')
auc_mrr = torch.empty([loader.expected_idx*world_size],dtype = torch.float,device = 'cuda')
dist.all_gather_into_tensor(apc,torch.tensor(aps,device ='cuda',dtype=torch.float))
dist.all_gather_into_tensor(auc_mrr,torch.tensor(aucs_mrrs,device ='cuda',dtype=torch.float))
ap = float(torch.tensor(apc).mean())
auc_mrr = float(torch.tensor(auc_mrr).mean())
return ap, auc_mrr
creterion = torch.nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=train_param['lr'])
early_stopper = EarlyStopMonitor(max_round=args.patience)
MODEL_SAVE_PATH = f'./saved_models/{args.model}-{args.dataname}.pth'
for e in range(train_param['epoch']):
torch.cuda.synchronize()
write_back_time = 0
fetch_time = 0
epoch_start_time = time.time()
train_aps = list()
print('Epoch {:d}:'.format(e))
time_prep = 0
total_loss = 0
model.train()
if mailbox is not None:
mailbox.reset()
model.module.memory_updater.last_updated_nid = None
model.module.memory_updater.last_updated_memory = None
model.module.memory_updater.last_updated_ts = None
for roots,mfgs,metadata,sample_time in trainloader:
fetch_time +=sample_time/1000
t_prep_s = time.time()
with torch.cuda.stream(train_stream):
optimizer.zero_grad()
pred_pos, pred_neg = model(mfgs,metadata)
loss = creterion(pred_pos, torch.ones_like(pred_pos))
loss += creterion(pred_neg, torch.zeros_like(pred_neg))
total_loss += float(loss)
loss.backward()
optimizer.step()
#torch.cuda.synchronize()
t_prep_s = time.time()
y_pred = torch.cat([pred_pos, pred_neg], dim=0).sigmoid().cpu()
y_true = torch.cat([torch.ones(pred_pos.size(0)), torch.zeros(pred_neg.size(0))], dim=0)
train_aps.append(average_precision_score(y_true, y_pred.detach().numpy()))
#start_event = torch.cuda.Event(enable_timing=True)
#end_event = torch.cuda.Event(enable_timing=True)
#start_event.record()
if mailbox is not None:
src = metadata['src_pos_index']
dst = metadata['dst_pos_index']
ts = roots.ts
if graph.edge_attr is None:
edge_feats = None
elif(graph.edge_attr.device == torch.device('cpu')):
edge_feats = graph.edge_attr[roots.eids.to('cpu')].to('cuda')
else:
edge_feats = graph.edge_attr[roots.eids]
dist_index_mapper = mfgs[0][0].srcdata['ID']
root_index = torch.cat((src,dst))
last_updated_nid = model.module.memory_updater.last_updated_nid[root_index]
last_updated_memory = model.module.memory_updater.last_updated_memory[root_index]
last_updated_ts=model.module.memory_updater.last_updated_ts[root_index]
index, memory, memory_ts = mailbox.get_update_memory(last_updated_nid,
last_updated_memory,
last_updated_ts)
index, mail, mail_ts = mailbox.get_update_mail(dist_index_mapper,
src,dst,ts,edge_feats,
model.module.memory_updater.last_updated_memory,
model.module.embedding,use_src_emb,use_dst_emb,
)
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
start_event.record()
mailbox.set_mailbox_all_to_all(index,memory,memory_ts,mail,mail_ts,reduce_Op = 'max')
torch.cuda.synchronize()
time_prep = time.time() - epoch_start_time
avg_time += time.time() - epoch_start_time
train_ap = float(torch.tensor(train_aps).mean())
ap = 0
auc = 0
#if cache.edge_cache is not None:
# print('hit {}'.format(cache.edge_cache.hit_/ cache.edge_cache.hit_sum))
#if cache.node_cache is not None:
# print('hit {}'.format(cache.node_cache.hit_/ cache.node_cache.hit_sum))
ap, auc = eval('val')
early_stop = early_stopper.early_stop_check(ap)
if early_stop:
print("Early stopping at epoch {:d}".format(e))
print(f"Loading the best model at epoch {early_stopper.best_epoch}")
best_model_path = get_checkpoint_path(early_stopper.best_epoch)
model.load_state_dict(torch.load(best_model_path))
break
else:
print('\ttrain loss:{:.4f} train ap:{:4f} val ap:{:4f} val auc:{:4f}'.format(total_loss,train_ap, ap, auc))
print('\ttotal time:{:.2f}s prep time:{:.2f}s'.format(time.time()-epoch_start_time, time_prep))
print('\t fetch time:{:.2f}s write back time:{:.2f}s'.format(fetch_time,write_back_time))
torch.save(model.state_dict(), get_checkpoint_path(e))
model.eval()
if mailbox is not None:
mailbox.reset()
model.module.memory_updater.last_updated_nid = None
eval('train')
eval('val')
ap, auc = eval('test')
eval_neg_samples = 1
if eval_neg_samples > 1:
print('\ttest AP:{:4f} test MRR:{:4f}'.format(ap, auc))
else:
print('\ttest AP:{:4f} test AUC:{:4f}'.format(ap, auc))
print('test_dataset',test_data.edges.shape[1],'avg_time',avg_time/train_param['epoch'])
torch.save(model.state_dict(), MODEL_SAVE_PATH)
ctx.shutdown()
if __name__ == "__main__":
main()
import argparse
import os
import sys
from os.path import abspath, join, dirname
from starrygl.distributed.context import DistributedContext
from starrygl.distributed.utils import DistIndex
from starrygl.module.modules import GeneralModel
from pathlib import Path
from starrygl.module.utils import parse_config
from starrygl.sample.cache.fetch_cache import FetchFeatureCache
from starrygl.sample.graph_core import DataSet, DistributedGraphStore, TemporalNeighborSampleGraph
from starrygl.module.utils import parse_config, EarlyStopMonitor
from starrygl.sample.graph_core import DataSet, DistributedGraphStore, TemporalNeighborSampleGraph
from starrygl.sample.memory.shared_mailbox import SharedMailBox
from starrygl.sample.sample_core.base import NegativeSampling
from starrygl.sample.sample_core.neighbor_sampler import NeighborSampler
from starrygl.sample.part_utils.partition_tgnn import partition_load
import torch
import time
import torch
import torch.nn.functional as F
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import init_process_group, destroy_process_group
import os
from starrygl.sample.data_loader import DistributedDataLoader
from starrygl.sample.batch_data import SAMPLE_TYPE
from starrygl.sample.stream_manager import getPipelineManger
parser = argparse.ArgumentParser(
description="RPC Reinforcement Learning Example",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument('--rank', default=0, type=int, metavar='W',
help='name of dataset')
parser.add_argument('--patience', type=int, default=5, help='Patience for early stopping')
parser.add_argument('--world_size', default=1, type=int, metavar='W',
help='number of negative samples')
parser.add_argument('--dataname', default=1, type=str, metavar='W',
help='name of dataset')
parser.add_argument('--model', default='TGN', type=str, metavar='W',
help='name of model')
args = parser.parse_args()
from sklearn.metrics import average_precision_score, roc_auc_score
import torch
import time
import random
import dgl
import numpy as np
from sklearn.metrics import average_precision_score, roc_auc_score
from torch.nn.parallel import DistributedDataParallel as DDP
#os.environ['CUDA_VISIBLE_DEVICES'] = str(args.rank)
#os.environ["RANK"] = str(args.rank)
#os.environ["WORLD_SIZE"] = str(args.world_size)
#os.environ["LOCAL_RANK"] = str(0)
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
os.environ["MASTER_ADDR"] = '10.214.211.187'
os.environ["MASTER_PORT"] = '9337'
def seed_everything(seed=42):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
seed_everything(1234)
def main():
print('main')
use_cuda = True
sample_param, memory_param, gnn_param, train_param = parse_config('./config/{}.yml'.format(args.model))
torch.set_num_threads(12)
ctx = DistributedContext.init(backend="nccl", use_gpu=True)
device_id = torch.cuda.current_device()
print('use cuda on',device_id)
pdata = partition_load("/mnt/data/part_data/dataset/here/{}".format(args.dataname), algo="metis_for_tgnn")
graph = DistributedGraphStore(pdata = pdata,uvm_edge = False,uvm_node = False)
sample_graph = TemporalNeighborSampleGraph(sample_graph = pdata.sample_graph,mode = 'full')
mailbox = SharedMailBox(pdata.ids.shape[0], memory_param, dim_edge_feat = pdata.edge_attr.shape[1] if pdata.edge_attr is not None else 0)
sampler = NeighborSampler(num_nodes=graph.num_nodes, num_layers=1, fanout=[10],graph_data=sample_graph, workers=15,policy = 'recent',graph_name = "wiki_train")
train_data = torch.masked_select(graph.edge_index,pdata.train_mask.to(graph.edge_index.device)).reshape(2,-1)
train_ts = torch.masked_select(graph.edge_ts,pdata.train_mask.to(graph.edge_index.device))
val_data = torch.masked_select(graph.edge_index,pdata.val_mask.to(graph.edge_index.device)).reshape(2,-1)
val_ts = torch.masked_select(graph.edge_ts,pdata.val_mask.to(graph.edge_index.device))
test_data = torch.masked_select(graph.edge_index,pdata.test_mask.to(graph.edge_index.device)).reshape(2,-1)
test_ts = torch.masked_select(graph.edge_ts,pdata.test_mask.to(graph.edge_index.device))
#print(train_data.shape[1],val_data.shape[1],test_data.shape[1])
train_data = DataSet(edges = train_data,ts =train_ts,eids = torch.nonzero(pdata.train_mask).view(-1))
#if dist.get_rank() == 0:
test_data = DataSet(edges = test_data,ts =test_ts,eids = torch.nonzero(pdata.test_mask).view(-1))
val_data = DataSet(edges = val_data,ts = val_ts,eids = torch.nonzero(pdata.val_mask).view(-1))
#else:
#test_data = torch.tensor([[],[]],device = graph.edge_index.device,detype = graph.edge_index.#dtype)
#val_data = torch.tensor([[],[]],device = graph.edge_index.device,detype = graph.edge_index.dtype)
#test_ts = torch.tensor([[],[]],device = graph.ts.device,detype = graph.ts.dtype)
#val_ts = torch.tensor([[],[]],device = graph.ts.device,detype = graph.ts.dtype)
#test_data = DataSet(edges = test_data,ts =test_ts,eids = torch.tensor([],dtype = torch.long,#device = torch.cuda))
#val_data = DataSet(edges = val_data,ts = val_ts,eids = torch.tensor([],dtype = torch.long,device #= torch.cuda))
#train_neg_sampler = PreNegativeSampling('triplet',torch.masked_select(pdata.edge_index['pos_edge'],graph.data.train_mask).reshape(2,-1))
neg_sampler = NegativeSampling('triplet')
trainloader = DistributedDataLoader(graph,train_data,sampler = sampler,
sampler_fn = SAMPLE_TYPE.SAMPLE_FROM_TEMPORAL_EDGES,
neg_sampler=neg_sampler,
batch_size = train_param['batch_size'],
shuffle=False,
drop_last=True,
chunk_size = None,
train=True,
queue_size = 1000,
mailbox = mailbox,
)
testloader = DistributedDataLoader(graph,test_data,sampler = sampler,
sampler_fn = SAMPLE_TYPE.SAMPLE_FROM_TEMPORAL_EDGES,
neg_sampler=neg_sampler,
batch_size = train_param['batch_size'],
shuffle=False,
drop_last=False,
chunk_size = None,
train=False,
queue_size = 100,
mailbox = mailbox)
valloader = DistributedDataLoader(graph,val_data,sampler = sampler,
sampler_fn = SAMPLE_TYPE.SAMPLE_FROM_TEMPORAL_EDGES,
neg_sampler=neg_sampler,
batch_size = train_param['batch_size'],
shuffle=False,
drop_last=False,
chunk_size = None,
train=False,
queue_size = 100,
mailbox = mailbox)
#FetchFeatureCache.create_fetch_cache(graph.num_nodes,graph.eids_mapper.shape[0],0.1,0.1,graph,mailbox,policy = 'static')
#cache = FetchFeatureCache.getFetchCache()
#cache.init_cache_with_presample(trainloader,3)
gnn_dim_node = 0 if graph.x is None else pdata.x.shape[1]
gnn_dim_edge = 0 if graph.edge_attr is None else pdata.edge_attr.shape[1]
print(gnn_dim_node,gnn_dim_edge)
avg_time = 0
if use_cuda:
model = GeneralModel(gnn_dim_node, gnn_dim_edge, sample_param, memory_param, gnn_param, train_param).cuda()
device = torch.device('cuda')
else:
model = GeneralModel(gnn_dim_node, gnn_dim_edge, sample_param, memory_param, gnn_param, train_param)
device = torch.device('cpu')
model = DDP(model,find_unused_parameters=True)
train_stream = torch.cuda.Stream()
send_stream = torch.cuda.Stream()
scatter_stream = torch.cuda.Stream()
val_losses = list()
def eval(mode='val'):
neg_samples = 1
model.eval()
aps = list()
aucs_mrrs = list()
if mode == 'val':
loader = valloader
elif mode == 'test':
loader = testloader
elif mode == 'train':
loader = trainloader
with torch.no_grad():
total_loss = 0
signal = torch.tensor([0],dtype = int,device = device)
with torch.cuda.stream(train_stream):
for roots,mfgs,metadata in loader:
pred_pos, pred_neg = model(mfgs,metadata)
total_loss += creterion(pred_pos, torch.ones_like(pred_pos))
total_loss += creterion(pred_neg, torch.zeros_like(pred_neg))
y_pred = torch.cat([pred_pos, pred_neg], dim=0).sigmoid().cpu()
y_true = torch.cat([torch.ones(pred_pos.size(0)), torch.zeros(pred_neg.size(0))], dim=0)
aps.append(average_precision_score(y_true, y_pred.detach().numpy()))
aucs_mrrs.append(roc_auc_score(y_true, y_pred))
if mailbox is not None:
src = metadata['src_pos_index']
dst = metadata['dst_pos_index']
ts = roots.ts
if graph.edge_attr is None:
edge_feats = None
elif(graph.edge_attr.device == torch.device('cpu')):
edge_feats = graph.edge_attr[roots.eids.to('cpu')].to('cuda')
else:
edge_feats = graph.edge_attr[roots.eids]
dist_index_mapper = mfgs[0][0].srcdata['ID']
root_index = torch.cat((src,dst))
last_updated_nid = model.module.memory_updater.last_updated_nid[root_index]
last_updated_memory = model.module.memory_updater.last_updated_memory[root_index]
last_updated_ts=model.module.memory_updater.last_updated_ts[root_index]
index, memory, memory_ts = mailbox.get_update_memory(last_updated_nid,
last_updated_memory,
last_updated_ts)
#
index, mail, mail_ts = mailbox.get_update_mail(dist_index_mapper,
src,dst,ts,edge_feats,
model.module.memory_updater.last_updated_memory,
model.module.embedding,use_src_emb,
use_dst_emb,
)
mailbox.set_mailbox_all_to_all(index,memory,memory_ts,mail,mail_ts,reduce_Op = 'max')
#ap = float(torch.tensor(aps).mean())
#if neg_samples > 1:
# auc_mrr = float(torch.cat(aucs_mrrs).mean())
#else:
# auc_mrr = float(torch.tensor(aucs_mrrs).mean())
world_size = dist.get_world_size()
apc = torch.empty([loader.expected_idx*world_size],dtype = torch.float,device='cuda')
auc_mrr = torch.empty([loader.expected_idx*world_size],dtype = torch.float,device = 'cuda')
dist.all_gather_into_tensor(apc,torch.tensor(aps,device ='cuda',dtype=torch.float))
dist.all_gather_into_tensor(auc_mrr,torch.tensor(aucs_mrrs,device ='cuda',dtype=torch.float))
ap = float(torch.tensor(apc).mean())
auc_mrr = float(torch.tensor(auc_mrr).mean())
return ap, auc_mrr
creterion = torch.nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=train_param['lr'])
early_stopper = EarlyStopMonitor(max_round=args.patience)
MODEL_SAVE_PATH = f'./saved_models/{args.model}-{args.dataname}.pth'
for e in range(train_param['epoch']):
torch.cuda.synchronize()
write_back_time = 0
fetch_time = 0
epoch_start_time = time.time()
train_aps = list()
print('Epoch {:d}:'.format(e))
time_prep = 0
total_loss = 0
model.train()
if mailbox is not None:
mailbox.reset()
model.module.memory_updater.last_updated_nid = None
model.module.memory_updater.last_updated_memory = None
model.module.memory_updater.last_updated_ts = None
for roots,mfgs,metadata,sample_time in trainloader:
fetch_time +=sample_time/1000
t_prep_s = time.time()
with torch.cuda.stream(train_stream):
optimizer.zero_grad()
pred_pos, pred_neg = model(mfgs,metadata)
loss = creterion(pred_pos, torch.ones_like(pred_pos))
loss += creterion(pred_neg, torch.zeros_like(pred_neg))
total_loss += float(loss)
loss.backward()
optimizer.step()
#torch.cuda.synchronize()
t_prep_s = time.time()
y_pred = torch.cat([pred_pos, pred_neg], dim=0).sigmoid().cpu()
y_true = torch.cat([torch.ones(pred_pos.size(0)), torch.zeros(pred_neg.size(0))], dim=0)
train_aps.append(average_precision_score(y_true, y_pred.detach().numpy()))
#start_event = torch.cuda.Event(enable_timing=True)
#end_event = torch.cuda.Event(enable_timing=True)
#start_event.record()
if mailbox is not None:
src = metadata['src_pos_index']
dst = metadata['dst_pos_index']
ts = roots.ts
if graph.edge_attr is None:
edge_feats = None
elif(graph.edge_attr.device == torch.device('cpu')):
edge_feats = graph.edge_attr[roots.eids.to('cpu')].to('cuda')
else:
edge_feats = graph.edge_attr[roots.eids]
dist_index_mapper = mfgs[0][0].srcdata['ID']
root_index = torch.cat((src,dst))
last_updated_nid = model.module.memory_updater.last_updated_nid[root_index]
last_updated_memory = model.module.memory_updater.last_updated_memory[root_index]
last_updated_ts=model.module.memory_updater.last_updated_ts[root_index]
index, memory, memory_ts = mailbox.get_update_memory(last_updated_nid,
last_updated_memory,
last_updated_ts)
index, mail, mail_ts = mailbox.get_update_mail(dist_index_mapper,
src,dst,ts,edge_feats,
model.module.memory_updater.last_updated_memory,
model.module.embedding,use_src_emb,use_dst_emb,
)
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
start_event.record()
mailbox.set_mailbox_all_to_all(index,memory,memory_ts,mail,mail_ts,reduce_Op = 'max')
torch.cuda.synchronize()
time_prep = time.time() - epoch_start_time
avg_time += time.time() - epoch_start_time
train_ap = float(torch.tensor(train_aps).mean())
ap = 0
auc = 0
#if cache.edge_cache is not None:
# print('hit {}'.format(cache.edge_cache.hit_/ cache.edge_cache.hit_sum))
#if cache.node_cache is not None:
# print('hit {}'.format(cache.node_cache.hit_/ cache.node_cache.hit_sum))
ap, auc = eval('val')
early_stop = early_stopper.early_stop_check(ap)
if early_stop:
print("Early stopping at epoch {:d}".format(e))
print(f"Loading the best model at epoch {early_stopper.best_epoch}")
best_model_path = get_checkpoint_path(early_stopper.best_epoch)
model.load_state_dict(torch.load(best_model_path))
break
else:
print('\ttrain loss:{:.4f} train ap:{:4f} val ap:{:4f} val auc:{:4f}'.format(total_loss,train_ap, ap, auc))
print('\ttotal time:{:.2f}s prep time:{:.2f}s'.format(time.time()-epoch_start_time, time_prep))
print('\t fetch time:{:.2f}s write back time:{:.2f}s'.format(fetch_time,write_back_time))
torch.save(model.state_dict(), get_checkpoint_path(e))
model.eval()
if mailbox is not None:
mailbox.reset()
model.module.memory_updater.last_updated_nid = None
eval('train')
eval('val')
ap, auc = eval('test')
eval_neg_samples = 1
if eval_neg_samples > 1:
print('\ttest AP:{:4f} test MRR:{:4f}'.format(ap, auc))
else:
print('\ttest AP:{:4f} test AUC:{:4f}'.format(ap, auc))
print('test_dataset',test_data.edges.shape[1],'avg_time',avg_time/train_param['epoch'])
torch.save(model.state_dict(), MODEL_SAVE_PATH)
ctx.shutdown()
if __name__ == "__main__":
main()
import argparse
import os
import sys
from os.path import abspath, join, dirname
from starrygl.distributed.context import DistributedContext
from starrygl.distributed.utils import DistIndex
from starrygl.module.modules import GeneralModel
from pathlib import Path
from starrygl.module.utils import parse_config
from starrygl.sample.cache.fetch_cache import FetchFeatureCache
from starrygl.sample.graph_core import DataSet, DistributedGraphStore, TemporalNeighborSampleGraph
from starrygl.module.utils import parse_config, EarlyStopMonitor
from starrygl.sample.graph_core import DataSet, DistributedGraphStore, TemporalNeighborSampleGraph
from starrygl.sample.memory.shared_mailbox import SharedMailBox
from starrygl.sample.sample_core.base import NegativeSampling
from starrygl.sample.sample_core.neighbor_sampler import NeighborSampler
from starrygl.sample.part_utils.partition_tgnn import partition_load
import torch
import time
import torch
import torch.nn.functional as F
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import init_process_group, destroy_process_group
import os
from starrygl.sample.data_loader import DistributedDataLoader
from starrygl.sample.batch_data import SAMPLE_TYPE
from starrygl.sample.stream_manager import getPipelineManger
parser = argparse.ArgumentParser(
description="RPC Reinforcement Learning Example",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument('--rank', default=0, type=int, metavar='W',
help='name of dataset')
parser.add_argument('--patience', type=int, default=5, help='Patience for early stopping')
parser.add_argument('--world_size', default=1, type=int, metavar='W',
help='number of negative samples')
parser.add_argument('--dataname', default=1, type=str, metavar='W',
help='name of dataset')
parser.add_argument('--model', default='TGN', type=str, metavar='W',
help='name of model')
args = parser.parse_args()
from sklearn.metrics import average_precision_score, roc_auc_score
import torch
import time
import random
import dgl
import numpy as np
from sklearn.metrics import average_precision_score, roc_auc_score
from torch.nn.parallel import DistributedDataParallel as DDP
#os.environ['CUDA_VISIBLE_DEVICES'] = str(args.rank)
#os.environ["RANK"] = str(args.rank)
#os.environ["WORLD_SIZE"] = str(args.world_size)
#os.environ["LOCAL_RANK"] = str(0)
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
os.environ["MASTER_ADDR"] = '10.214.211.187'
os.environ["MASTER_PORT"] = '9337'
def seed_everything(seed=42):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
seed_everything(1234)
def main():
print('main')
use_cuda = True
sample_param, memory_param, gnn_param, train_param = parse_config('./config/{}.yml'.format(args.model))
torch.set_num_threads(12)
ctx = DistributedContext.init(backend="nccl", use_gpu=True)
device_id = torch.cuda.current_device()
print('use cuda on',device_id)
pdata = partition_load("/mnt/data/part_data/dataset/here/{}".format(args.dataname), algo="metis_for_tgnn")
graph = DistributedGraphStore(pdata = pdata,uvm_edge = False,uvm_node = False)
sample_graph = TemporalNeighborSampleGraph(sample_graph = pdata.sample_graph,mode = 'full')
mailbox = SharedMailBox(pdata.ids.shape[0], memory_param, dim_edge_feat = pdata.edge_attr.shape[1] if pdata.edge_attr is not None else 0)
sampler = NeighborSampler(num_nodes=graph.num_nodes, num_layers=1, fanout=[10],graph_data=sample_graph, workers=15,policy = 'recent',graph_name = "wiki_train")
train_data = torch.masked_select(graph.edge_index,pdata.train_mask.to(graph.edge_index.device)).reshape(2,-1)
train_ts = torch.masked_select(graph.edge_ts,pdata.train_mask.to(graph.edge_index.device))
val_data = torch.masked_select(graph.edge_index,pdata.val_mask.to(graph.edge_index.device)).reshape(2,-1)
val_ts = torch.masked_select(graph.edge_ts,pdata.val_mask.to(graph.edge_index.device))
test_data = torch.masked_select(graph.edge_index,pdata.test_mask.to(graph.edge_index.device)).reshape(2,-1)
test_ts = torch.masked_select(graph.edge_ts,pdata.test_mask.to(graph.edge_index.device))
#print(train_data.shape[1],val_data.shape[1],test_data.shape[1])
train_data = DataSet(edges = train_data,ts =train_ts,eids = torch.nonzero(pdata.train_mask).view(-1))
#if dist.get_rank() == 0:
test_data = DataSet(edges = test_data,ts =test_ts,eids = torch.nonzero(pdata.test_mask).view(-1))
val_data = DataSet(edges = val_data,ts = val_ts,eids = torch.nonzero(pdata.val_mask).view(-1))
#else:
#test_data = torch.tensor([[],[]],device = graph.edge_index.device,detype = graph.edge_index.#dtype)
#val_data = torch.tensor([[],[]],device = graph.edge_index.device,detype = graph.edge_index.dtype)
#test_ts = torch.tensor([[],[]],device = graph.ts.device,detype = graph.ts.dtype)
#val_ts = torch.tensor([[],[]],device = graph.ts.device,detype = graph.ts.dtype)
#test_data = DataSet(edges = test_data,ts =test_ts,eids = torch.tensor([],dtype = torch.long,#device = torch.cuda))
#val_data = DataSet(edges = val_data,ts = val_ts,eids = torch.tensor([],dtype = torch.long,device #= torch.cuda))
#train_neg_sampler = PreNegativeSampling('triplet',torch.masked_select(pdata.edge_index['pos_edge'],graph.data.train_mask).reshape(2,-1))
neg_sampler = NegativeSampling('triplet')
trainloader = DistributedDataLoader(graph,train_data,sampler = sampler,
sampler_fn = SAMPLE_TYPE.SAMPLE_FROM_TEMPORAL_EDGES,
neg_sampler=neg_sampler,
batch_size = train_param['batch_size'],
shuffle=False,
drop_last=True,
chunk_size = None,
train=True,
queue_size = 1000,
mailbox = mailbox,
)
testloader = DistributedDataLoader(graph,test_data,sampler = sampler,
sampler_fn = SAMPLE_TYPE.SAMPLE_FROM_TEMPORAL_EDGES,
neg_sampler=neg_sampler,
batch_size = train_param['batch_size'],
shuffle=False,
drop_last=False,
chunk_size = None,
train=False,
queue_size = 100,
mailbox = mailbox)
valloader = DistributedDataLoader(graph,val_data,sampler = sampler,
sampler_fn = SAMPLE_TYPE.SAMPLE_FROM_TEMPORAL_EDGES,
neg_sampler=neg_sampler,
batch_size = train_param['batch_size'],
shuffle=False,
drop_last=False,
chunk_size = None,
train=False,
queue_size = 100,
mailbox = mailbox)
#FetchFeatureCache.create_fetch_cache(graph.num_nodes,graph.eids_mapper.shape[0],0.1,0.1,graph,mailbox,policy = 'static')
#cache = FetchFeatureCache.getFetchCache()
#cache.init_cache_with_presample(trainloader,3)
gnn_dim_node = 0 if graph.x is None else pdata.x.shape[1]
gnn_dim_edge = 0 if graph.edge_attr is None else pdata.edge_attr.shape[1]
print(gnn_dim_node,gnn_dim_edge)
avg_time = 0
if use_cuda:
model = GeneralModel(gnn_dim_node, gnn_dim_edge, sample_param, memory_param, gnn_param, train_param).cuda()
device = torch.device('cuda')
else:
model = GeneralModel(gnn_dim_node, gnn_dim_edge, sample_param, memory_param, gnn_param, train_param)
device = torch.device('cpu')
model = DDP(model,find_unused_parameters=True)
train_stream = torch.cuda.Stream()
send_stream = torch.cuda.Stream()
scatter_stream = torch.cuda.Stream()
val_losses = list()
def eval(mode='val'):
neg_samples = 1
model.eval()
aps = list()
aucs_mrrs = list()
if mode == 'val':
loader = valloader
elif mode == 'test':
loader = testloader
elif mode == 'train':
loader = trainloader
with torch.no_grad():
total_loss = 0
signal = torch.tensor([0],dtype = int,device = device)
with torch.cuda.stream(train_stream):
for roots,mfgs,metadata in loader:
pred_pos, pred_neg = model(mfgs,metadata)
total_loss += creterion(pred_pos, torch.ones_like(pred_pos))
total_loss += creterion(pred_neg, torch.zeros_like(pred_neg))
y_pred = torch.cat([pred_pos, pred_neg], dim=0).sigmoid().cpu()
y_true = torch.cat([torch.ones(pred_pos.size(0)), torch.zeros(pred_neg.size(0))], dim=0)
aps.append(average_precision_score(y_true, y_pred.detach().numpy()))
aucs_mrrs.append(roc_auc_score(y_true, y_pred))
if mailbox is not None:
src = metadata['src_pos_index']
dst = metadata['dst_pos_index']
ts = roots.ts
if graph.edge_attr is None:
edge_feats = None
elif(graph.edge_attr.device == torch.device('cpu')):
edge_feats = graph.edge_attr[roots.eids.to('cpu')].to('cuda')
else:
edge_feats = graph.edge_attr[roots.eids]
dist_index_mapper = mfgs[0][0].srcdata['ID']
root_index = torch.cat((src,dst))
last_updated_nid = model.module.memory_updater.last_updated_nid[root_index]
last_updated_memory = model.module.memory_updater.last_updated_memory[root_index]
last_updated_ts=model.module.memory_updater.last_updated_ts[root_index]
index, memory, memory_ts = mailbox.get_update_memory(last_updated_nid,
last_updated_memory,
last_updated_ts)
#
index, mail, mail_ts = mailbox.get_update_mail(dist_index_mapper,
src,dst,ts,edge_feats,
model.module.memory_updater.last_updated_memory,
model.module.embedding,use_src_emb,
use_dst_emb,
)
mailbox.set_mailbox_all_to_all(index,memory,memory_ts,mail,mail_ts,reduce_Op = 'max')
#ap = float(torch.tensor(aps).mean())
#if neg_samples > 1:
# auc_mrr = float(torch.cat(aucs_mrrs).mean())
#else:
# auc_mrr = float(torch.tensor(aucs_mrrs).mean())
world_size = dist.get_world_size()
apc = torch.empty([loader.expected_idx*world_size],dtype = torch.float,device='cuda')
auc_mrr = torch.empty([loader.expected_idx*world_size],dtype = torch.float,device = 'cuda')
dist.all_gather_into_tensor(apc,torch.tensor(aps,device ='cuda',dtype=torch.float))
dist.all_gather_into_tensor(auc_mrr,torch.tensor(aucs_mrrs,device ='cuda',dtype=torch.float))
ap = float(torch.tensor(apc).mean())
auc_mrr = float(torch.tensor(auc_mrr).mean())
return ap, auc_mrr
creterion = torch.nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=train_param['lr'])
early_stopper = EarlyStopMonitor(max_round=args.patience)
MODEL_SAVE_PATH = f'./saved_models/{args.model}-{args.dataname}.pth'
for e in range(train_param['epoch']):
torch.cuda.synchronize()
write_back_time = 0
fetch_time = 0
epoch_start_time = time.time()
train_aps = list()
print('Epoch {:d}:'.format(e))
time_prep = 0
total_loss = 0
model.train()
if mailbox is not None:
mailbox.reset()
model.module.memory_updater.last_updated_nid = None
model.module.memory_updater.last_updated_memory = None
model.module.memory_updater.last_updated_ts = None
for roots,mfgs,metadata,sample_time in trainloader:
fetch_time +=sample_time/1000
t_prep_s = time.time()
with torch.cuda.stream(train_stream):
optimizer.zero_grad()
pred_pos, pred_neg = model(mfgs,metadata)
loss = creterion(pred_pos, torch.ones_like(pred_pos))
loss += creterion(pred_neg, torch.zeros_like(pred_neg))
total_loss += float(loss)
loss.backward()
optimizer.step()
#torch.cuda.synchronize()
t_prep_s = time.time()
y_pred = torch.cat([pred_pos, pred_neg], dim=0).sigmoid().cpu()
y_true = torch.cat([torch.ones(pred_pos.size(0)), torch.zeros(pred_neg.size(0))], dim=0)
train_aps.append(average_precision_score(y_true, y_pred.detach().numpy()))
#start_event = torch.cuda.Event(enable_timing=True)
#end_event = torch.cuda.Event(enable_timing=True)
#start_event.record()
if mailbox is not None:
src = metadata['src_pos_index']
dst = metadata['dst_pos_index']
ts = roots.ts
if graph.edge_attr is None:
edge_feats = None
elif(graph.edge_attr.device == torch.device('cpu')):
edge_feats = graph.edge_attr[roots.eids.to('cpu')].to('cuda')
else:
edge_feats = graph.edge_attr[roots.eids]
dist_index_mapper = mfgs[0][0].srcdata['ID']
root_index = torch.cat((src,dst))
last_updated_nid = model.module.memory_updater.last_updated_nid[root_index]
last_updated_memory = model.module.memory_updater.last_updated_memory[root_index]
last_updated_ts=model.module.memory_updater.last_updated_ts[root_index]
index, memory, memory_ts = mailbox.get_update_memory(last_updated_nid,
last_updated_memory,
last_updated_ts)
index, mail, mail_ts = mailbox.get_update_mail(dist_index_mapper,
src,dst,ts,edge_feats,
model.module.memory_updater.last_updated_memory,
model.module.embedding,use_src_emb,use_dst_emb,
)
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
start_event.record()
mailbox.set_mailbox_all_to_all(index,memory,memory_ts,mail,mail_ts,reduce_Op = 'max')
torch.cuda.synchronize()
time_prep = time.time() - epoch_start_time
avg_time += time.time() - epoch_start_time
train_ap = float(torch.tensor(train_aps).mean())
ap = 0
auc = 0
#if cache.edge_cache is not None:
# print('hit {}'.format(cache.edge_cache.hit_/ cache.edge_cache.hit_sum))
#if cache.node_cache is not None:
# print('hit {}'.format(cache.node_cache.hit_/ cache.node_cache.hit_sum))
ap, auc = eval('val')
early_stop = early_stopper.early_stop_check(ap)
if early_stop:
print("Early stopping at epoch {:d}".format(e))
print(f"Loading the best model at epoch {early_stopper.best_epoch}")
best_model_path = get_checkpoint_path(early_stopper.best_epoch)
model.load_state_dict(torch.load(best_model_path))
break
else:
print('\ttrain loss:{:.4f} train ap:{:4f} val ap:{:4f} val auc:{:4f}'.format(total_loss,train_ap, ap, auc))
print('\ttotal time:{:.2f}s prep time:{:.2f}s'.format(time.time()-epoch_start_time, time_prep))
print('\t fetch time:{:.2f}s write back time:{:.2f}s'.format(fetch_time,write_back_time))
torch.save(model.state_dict(), get_checkpoint_path(e))
model.eval()
if mailbox is not None:
mailbox.reset()
model.module.memory_updater.last_updated_nid = None
eval('train')
eval('val')
ap, auc = eval('test')
eval_neg_samples = 1
if eval_neg_samples > 1:
print('\ttest AP:{:4f} test MRR:{:4f}'.format(ap, auc))
else:
print('\ttest AP:{:4f} test AUC:{:4f}'.format(ap, auc))
print('test_dataset',test_data.edges.shape[1],'avg_time',avg_time/train_param['epoch'])
torch.save(model.state_dict(), MODEL_SAVE_PATH)
ctx.shutdown()
if __name__ == "__main__":
main()
import argparse
import os
import sys
from os.path import abspath, join, dirname
from starrygl.distributed.context import DistributedContext
from starrygl.distributed.utils import DistIndex
from starrygl.module.modules import GeneralModel
from pathlib import Path
from starrygl.module.utils import parse_config
from starrygl.sample.cache.fetch_cache import FetchFeatureCache
from starrygl.sample.graph_core import DataSet, DistributedGraphStore, TemporalNeighborSampleGraph
from starrygl.module.utils import parse_config, EarlyStopMonitor
from starrygl.sample.graph_core import DataSet, DistributedGraphStore, TemporalNeighborSampleGraph
from starrygl.sample.memory.shared_mailbox import SharedMailBox
from starrygl.sample.sample_core.base import NegativeSampling
from starrygl.sample.sample_core.neighbor_sampler import NeighborSampler
from starrygl.sample.part_utils.partition_tgnn import partition_load
import torch
import time
import torch
import torch.nn.functional as F
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import init_process_group, destroy_process_group
import os
from starrygl.sample.data_loader import DistributedDataLoader
from starrygl.sample.batch_data import SAMPLE_TYPE
from starrygl.sample.stream_manager import getPipelineManger
parser = argparse.ArgumentParser(
description="RPC Reinforcement Learning Example",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument('--rank', default=0, type=int, metavar='W',
help='name of dataset')
parser.add_argument('--patience', type=int, default=5, help='Patience for early stopping')
parser.add_argument('--world_size', default=1, type=int, metavar='W',
help='number of negative samples')
parser.add_argument('--dataname', default=1, type=str, metavar='W',
help='name of dataset')
parser.add_argument('--model', default='TGN', type=str, metavar='W',
help='name of model')
args = parser.parse_args()
from sklearn.metrics import average_precision_score, roc_auc_score
import torch
import time
import random
import dgl
import numpy as np
from sklearn.metrics import average_precision_score, roc_auc_score
from torch.nn.parallel import DistributedDataParallel as DDP
#os.environ['CUDA_VISIBLE_DEVICES'] = str(args.rank)
#os.environ["RANK"] = str(args.rank)
#os.environ["WORLD_SIZE"] = str(args.world_size)
#os.environ["LOCAL_RANK"] = str(0)
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
os.environ["MASTER_ADDR"] = '10.214.211.187'
os.environ["MASTER_PORT"] = '9337'
def seed_everything(seed=42):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
seed_everything(1234)
def main():
print('main')
use_cuda = True
sample_param, memory_param, gnn_param, train_param = parse_config('./config/{}.yml'.format(args.model))
torch.set_num_threads(12)
ctx = DistributedContext.init(backend="nccl", use_gpu=True)
device_id = torch.cuda.current_device()
print('use cuda on',device_id)
pdata = partition_load("/mnt/data/part_data/dataset/here/{}".format(args.dataname), algo="metis_for_tgnn")
graph = DistributedGraphStore(pdata = pdata,uvm_edge = False,uvm_node = False)
sample_graph = TemporalNeighborSampleGraph(sample_graph = pdata.sample_graph,mode = 'full')
mailbox = SharedMailBox(pdata.ids.shape[0], memory_param, dim_edge_feat = pdata.edge_attr.shape[1] if pdata.edge_attr is not None else 0)
sampler = NeighborSampler(num_nodes=graph.num_nodes, num_layers=1, fanout=[10],graph_data=sample_graph, workers=15,policy = 'recent',graph_name = "wiki_train")
train_data = torch.masked_select(graph.edge_index,pdata.train_mask.to(graph.edge_index.device)).reshape(2,-1)
train_ts = torch.masked_select(graph.edge_ts,pdata.train_mask.to(graph.edge_index.device))
val_data = torch.masked_select(graph.edge_index,pdata.val_mask.to(graph.edge_index.device)).reshape(2,-1)
val_ts = torch.masked_select(graph.edge_ts,pdata.val_mask.to(graph.edge_index.device))
test_data = torch.masked_select(graph.edge_index,pdata.test_mask.to(graph.edge_index.device)).reshape(2,-1)
test_ts = torch.masked_select(graph.edge_ts,pdata.test_mask.to(graph.edge_index.device))
#print(train_data.shape[1],val_data.shape[1],test_data.shape[1])
train_data = DataSet(edges = train_data,ts =train_ts,eids = torch.nonzero(pdata.train_mask).view(-1))
#if dist.get_rank() == 0:
test_data = DataSet(edges = test_data,ts =test_ts,eids = torch.nonzero(pdata.test_mask).view(-1))
val_data = DataSet(edges = val_data,ts = val_ts,eids = torch.nonzero(pdata.val_mask).view(-1))
#else:
#test_data = torch.tensor([[],[]],device = graph.edge_index.device,detype = graph.edge_index.#dtype)
#val_data = torch.tensor([[],[]],device = graph.edge_index.device,detype = graph.edge_index.dtype)
#test_ts = torch.tensor([[],[]],device = graph.ts.device,detype = graph.ts.dtype)
#val_ts = torch.tensor([[],[]],device = graph.ts.device,detype = graph.ts.dtype)
#test_data = DataSet(edges = test_data,ts =test_ts,eids = torch.tensor([],dtype = torch.long,#device = torch.cuda))
#val_data = DataSet(edges = val_data,ts = val_ts,eids = torch.tensor([],dtype = torch.long,device #= torch.cuda))
#train_neg_sampler = PreNegativeSampling('triplet',torch.masked_select(pdata.edge_index['pos_edge'],graph.data.train_mask).reshape(2,-1))
neg_sampler = NegativeSampling('triplet')
trainloader = DistributedDataLoader(graph,train_data,sampler = sampler,
sampler_fn = SAMPLE_TYPE.SAMPLE_FROM_TEMPORAL_EDGES,
neg_sampler=neg_sampler,
batch_size = train_param['batch_size'],
shuffle=False,
drop_last=True,
chunk_size = None,
train=True,
queue_size = 1000,
mailbox = mailbox,
)
testloader = DistributedDataLoader(graph,test_data,sampler = sampler,
sampler_fn = SAMPLE_TYPE.SAMPLE_FROM_TEMPORAL_EDGES,
neg_sampler=neg_sampler,
batch_size = train_param['batch_size'],
shuffle=False,
drop_last=False,
chunk_size = None,
train=False,
queue_size = 100,
mailbox = mailbox)
valloader = DistributedDataLoader(graph,val_data,sampler = sampler,
sampler_fn = SAMPLE_TYPE.SAMPLE_FROM_TEMPORAL_EDGES,
neg_sampler=neg_sampler,
batch_size = train_param['batch_size'],
shuffle=False,
drop_last=False,
chunk_size = None,
train=False,
queue_size = 100,
mailbox = mailbox)
#FetchFeatureCache.create_fetch_cache(graph.num_nodes,graph.eids_mapper.shape[0],0.1,0.1,graph,mailbox,policy = 'static')
#cache = FetchFeatureCache.getFetchCache()
#cache.init_cache_with_presample(trainloader,3)
gnn_dim_node = 0 if graph.x is None else pdata.x.shape[1]
gnn_dim_edge = 0 if graph.edge_attr is None else pdata.edge_attr.shape[1]
print(gnn_dim_node,gnn_dim_edge)
avg_time = 0
if use_cuda:
model = GeneralModel(gnn_dim_node, gnn_dim_edge, sample_param, memory_param, gnn_param, train_param).cuda()
device = torch.device('cuda')
else:
model = GeneralModel(gnn_dim_node, gnn_dim_edge, sample_param, memory_param, gnn_param, train_param)
device = torch.device('cpu')
model = DDP(model,find_unused_parameters=True)
train_stream = torch.cuda.Stream()
send_stream = torch.cuda.Stream()
scatter_stream = torch.cuda.Stream()
val_losses = list()
def eval(mode='val'):
neg_samples = 1
model.eval()
aps = list()
aucs_mrrs = list()
if mode == 'val':
loader = valloader
elif mode == 'test':
loader = testloader
elif mode == 'train':
loader = trainloader
with torch.no_grad():
total_loss = 0
signal = torch.tensor([0],dtype = int,device = device)
with torch.cuda.stream(train_stream):
for roots,mfgs,metadata in loader:
pred_pos, pred_neg = model(mfgs,metadata)
total_loss += creterion(pred_pos, torch.ones_like(pred_pos))
total_loss += creterion(pred_neg, torch.zeros_like(pred_neg))
y_pred = torch.cat([pred_pos, pred_neg], dim=0).sigmoid().cpu()
y_true = torch.cat([torch.ones(pred_pos.size(0)), torch.zeros(pred_neg.size(0))], dim=0)
aps.append(average_precision_score(y_true, y_pred.detach().numpy()))
aucs_mrrs.append(roc_auc_score(y_true, y_pred))
if mailbox is not None:
src = metadata['src_pos_index']
dst = metadata['dst_pos_index']
ts = roots.ts
if graph.edge_attr is None:
edge_feats = None
elif(graph.edge_attr.device == torch.device('cpu')):
edge_feats = graph.edge_attr[roots.eids.to('cpu')].to('cuda')
else:
edge_feats = graph.edge_attr[roots.eids]
dist_index_mapper = mfgs[0][0].srcdata['ID']
root_index = torch.cat((src,dst))
last_updated_nid = model.module.memory_updater.last_updated_nid[root_index]
last_updated_memory = model.module.memory_updater.last_updated_memory[root_index]
last_updated_ts=model.module.memory_updater.last_updated_ts[root_index]
index, memory, memory_ts = mailbox.get_update_memory(last_updated_nid,
last_updated_memory,
last_updated_ts)
#
index, mail, mail_ts = mailbox.get_update_mail(dist_index_mapper,
src,dst,ts,edge_feats,
model.module.memory_updater.last_updated_memory,
model.module.embedding,use_src_emb,
use_dst_emb,
)
mailbox.set_mailbox_all_to_all(index,memory,memory_ts,mail,mail_ts,reduce_Op = 'max')
#ap = float(torch.tensor(aps).mean())
#if neg_samples > 1:
# auc_mrr = float(torch.cat(aucs_mrrs).mean())
#else:
# auc_mrr = float(torch.tensor(aucs_mrrs).mean())
world_size = dist.get_world_size()
apc = torch.empty([loader.expected_idx*world_size],dtype = torch.float,device='cuda')
auc_mrr = torch.empty([loader.expected_idx*world_size],dtype = torch.float,device = 'cuda')
dist.all_gather_into_tensor(apc,torch.tensor(aps,device ='cuda',dtype=torch.float))
dist.all_gather_into_tensor(auc_mrr,torch.tensor(aucs_mrrs,device ='cuda',dtype=torch.float))
ap = float(torch.tensor(apc).mean())
auc_mrr = float(torch.tensor(auc_mrr).mean())
return ap, auc_mrr
creterion = torch.nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=train_param['lr'])
early_stopper = EarlyStopMonitor(max_round=args.patience)
MODEL_SAVE_PATH = f'./saved_models/{args.model}-{args.dataname}.pth'
for e in range(train_param['epoch']):
torch.cuda.synchronize()
write_back_time = 0
fetch_time = 0
epoch_start_time = time.time()
train_aps = list()
print('Epoch {:d}:'.format(e))
time_prep = 0
total_loss = 0
model.train()
if mailbox is not None:
mailbox.reset()
model.module.memory_updater.last_updated_nid = None
model.module.memory_updater.last_updated_memory = None
model.module.memory_updater.last_updated_ts = None
for roots,mfgs,metadata,sample_time in trainloader:
fetch_time +=sample_time/1000
t_prep_s = time.time()
with torch.cuda.stream(train_stream):
optimizer.zero_grad()
pred_pos, pred_neg = model(mfgs,metadata)
loss = creterion(pred_pos, torch.ones_like(pred_pos))
loss += creterion(pred_neg, torch.zeros_like(pred_neg))
total_loss += float(loss)
loss.backward()
optimizer.step()
#torch.cuda.synchronize()
t_prep_s = time.time()
y_pred = torch.cat([pred_pos, pred_neg], dim=0).sigmoid().cpu()
y_true = torch.cat([torch.ones(pred_pos.size(0)), torch.zeros(pred_neg.size(0))], dim=0)
train_aps.append(average_precision_score(y_true, y_pred.detach().numpy()))
#start_event = torch.cuda.Event(enable_timing=True)
#end_event = torch.cuda.Event(enable_timing=True)
#start_event.record()
if mailbox is not None:
src = metadata['src_pos_index']
dst = metadata['dst_pos_index']
ts = roots.ts
if graph.edge_attr is None:
edge_feats = None
elif(graph.edge_attr.device == torch.device('cpu')):
edge_feats = graph.edge_attr[roots.eids.to('cpu')].to('cuda')
else:
edge_feats = graph.edge_attr[roots.eids]
dist_index_mapper = mfgs[0][0].srcdata['ID']
root_index = torch.cat((src,dst))
last_updated_nid = model.module.memory_updater.last_updated_nid[root_index]
last_updated_memory = model.module.memory_updater.last_updated_memory[root_index]
last_updated_ts=model.module.memory_updater.last_updated_ts[root_index]
index, memory, memory_ts = mailbox.get_update_memory(last_updated_nid,
last_updated_memory,
last_updated_ts)
index, mail, mail_ts = mailbox.get_update_mail(dist_index_mapper,
src,dst,ts,edge_feats,
model.module.memory_updater.last_updated_memory,
model.module.embedding,use_src_emb,use_dst_emb,
)
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
start_event.record()
mailbox.set_mailbox_all_to_all(index,memory,memory_ts,mail,mail_ts,reduce_Op = 'max')
torch.cuda.synchronize()
time_prep = time.time() - epoch_start_time
avg_time += time.time() - epoch_start_time
train_ap = float(torch.tensor(train_aps).mean())
ap = 0
auc = 0
#if cache.edge_cache is not None:
# print('hit {}'.format(cache.edge_cache.hit_/ cache.edge_cache.hit_sum))
#if cache.node_cache is not None:
# print('hit {}'.format(cache.node_cache.hit_/ cache.node_cache.hit_sum))
ap, auc = eval('val')
early_stop = early_stopper.early_stop_check(ap)
if early_stop:
print("Early stopping at epoch {:d}".format(e))
print(f"Loading the best model at epoch {early_stopper.best_epoch}")
best_model_path = get_checkpoint_path(early_stopper.best_epoch)
model.load_state_dict(torch.load(best_model_path))
break
else:
print('\ttrain loss:{:.4f} train ap:{:4f} val ap:{:4f} val auc:{:4f}'.format(total_loss,train_ap, ap, auc))
print('\ttotal time:{:.2f}s prep time:{:.2f}s'.format(time.time()-epoch_start_time, time_prep))
print('\t fetch time:{:.2f}s write back time:{:.2f}s'.format(fetch_time,write_back_time))
torch.save(model.state_dict(), get_checkpoint_path(e))
model.eval()
if mailbox is not None:
mailbox.reset()
model.module.memory_updater.last_updated_nid = None
eval('train')
eval('val')
ap, auc = eval('test')
eval_neg_samples = 1
if eval_neg_samples > 1:
print('\ttest AP:{:4f} test MRR:{:4f}'.format(ap, auc))
else:
print('\ttest AP:{:4f} test AUC:{:4f}'.format(ap, auc))
print('test_dataset',test_data.edges.shape[1],'avg_time',avg_time/train_param['epoch'])
torch.save(model.state_dict(), MODEL_SAVE_PATH)
ctx.shutdown()
if __name__ == "__main__":
main()
import argparse
import os
import sys
from os.path import abspath, join, dirname
from starrygl.distributed.context import DistributedContext
from starrygl.distributed.utils import DistIndex
from starrygl.module.modules import GeneralModel
from pathlib import Path
from starrygl.module.utils import parse_config
from starrygl.sample.cache.fetch_cache import FetchFeatureCache
from starrygl.sample.graph_core import DataSet, DistributedGraphStore, TemporalNeighborSampleGraph
from starrygl.module.utils import parse_config, EarlyStopMonitor
from starrygl.sample.graph_core import DataSet, DistributedGraphStore, TemporalNeighborSampleGraph
from starrygl.sample.memory.shared_mailbox import SharedMailBox
from starrygl.sample.sample_core.base import NegativeSampling
from starrygl.sample.sample_core.neighbor_sampler import NeighborSampler
from starrygl.sample.part_utils.partition_tgnn import partition_load
import torch
import time
import torch
import torch.nn.functional as F
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import init_process_group, destroy_process_group
import os
from starrygl.sample.data_loader import DistributedDataLoader
from starrygl.sample.batch_data import SAMPLE_TYPE
from starrygl.sample.stream_manager import getPipelineManger
parser = argparse.ArgumentParser(
description="RPC Reinforcement Learning Example",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument('--rank', default=0, type=int, metavar='W',
help='name of dataset')
parser.add_argument('--patience', type=int, default=5, help='Patience for early stopping')
parser.add_argument('--world_size', default=1, type=int, metavar='W',
help='number of negative samples')
parser.add_argument('--dataname', default=1, type=str, metavar='W',
help='name of dataset')
parser.add_argument('--model', default='TGN', type=str, metavar='W',
help='name of model')
args = parser.parse_args()
from sklearn.metrics import average_precision_score, roc_auc_score
import torch
import time
import random
import dgl
import numpy as np
from sklearn.metrics import average_precision_score, roc_auc_score
from torch.nn.parallel import DistributedDataParallel as DDP
#os.environ['CUDA_VISIBLE_DEVICES'] = str(args.rank)
#os.environ["RANK"] = str(args.rank)
#os.environ["WORLD_SIZE"] = str(args.world_size)
#os.environ["LOCAL_RANK"] = str(0)
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
os.environ["MASTER_ADDR"] = '10.214.211.187'
os.environ["MASTER_PORT"] = '9337'
def seed_everything(seed=42):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
seed_everything(1234)
def main():
print('main')
use_cuda = True
sample_param, memory_param, gnn_param, train_param = parse_config('./config/{}.yml'.format(args.model))
torch.set_num_threads(12)
ctx = DistributedContext.init(backend="nccl", use_gpu=True)
device_id = torch.cuda.current_device()
print('use cuda on',device_id)
pdata = partition_load("/mnt/data/part_data/dataset/here/{}".format(args.dataname), algo="metis_for_tgnn")
graph = DistributedGraphStore(pdata = pdata,uvm_edge = False,uvm_node = False)
sample_graph = TemporalNeighborSampleGraph(sample_graph = pdata.sample_graph,mode = 'full')
mailbox = SharedMailBox(pdata.ids.shape[0], memory_param, dim_edge_feat = pdata.edge_attr.shape[1] if pdata.edge_attr is not None else 0)
sampler = NeighborSampler(num_nodes=graph.num_nodes, num_layers=1, fanout=[10],graph_data=sample_graph, workers=15,policy = 'recent',graph_name = "wiki_train")
train_data = torch.masked_select(graph.edge_index,pdata.train_mask.to(graph.edge_index.device)).reshape(2,-1)
train_ts = torch.masked_select(graph.edge_ts,pdata.train_mask.to(graph.edge_index.device))
val_data = torch.masked_select(graph.edge_index,pdata.val_mask.to(graph.edge_index.device)).reshape(2,-1)
val_ts = torch.masked_select(graph.edge_ts,pdata.val_mask.to(graph.edge_index.device))
test_data = torch.masked_select(graph.edge_index,pdata.test_mask.to(graph.edge_index.device)).reshape(2,-1)
test_ts = torch.masked_select(graph.edge_ts,pdata.test_mask.to(graph.edge_index.device))
#print(train_data.shape[1],val_data.shape[1],test_data.shape[1])
train_data = DataSet(edges = train_data,ts =train_ts,eids = torch.nonzero(pdata.train_mask).view(-1))
#if dist.get_rank() == 0:
test_data = DataSet(edges = test_data,ts =test_ts,eids = torch.nonzero(pdata.test_mask).view(-1))
val_data = DataSet(edges = val_data,ts = val_ts,eids = torch.nonzero(pdata.val_mask).view(-1))
#else:
#test_data = torch.tensor([[],[]],device = graph.edge_index.device,detype = graph.edge_index.#dtype)
#val_data = torch.tensor([[],[]],device = graph.edge_index.device,detype = graph.edge_index.dtype)
#test_ts = torch.tensor([[],[]],device = graph.ts.device,detype = graph.ts.dtype)
#val_ts = torch.tensor([[],[]],device = graph.ts.device,detype = graph.ts.dtype)
#test_data = DataSet(edges = test_data,ts =test_ts,eids = torch.tensor([],dtype = torch.long,#device = torch.cuda))
#val_data = DataSet(edges = val_data,ts = val_ts,eids = torch.tensor([],dtype = torch.long,device #= torch.cuda))
#train_neg_sampler = PreNegativeSampling('triplet',torch.masked_select(pdata.edge_index['pos_edge'],graph.data.train_mask).reshape(2,-1))
neg_sampler = NegativeSampling('triplet')
trainloader = DistributedDataLoader(graph,train_data,sampler = sampler,
sampler_fn = SAMPLE_TYPE.SAMPLE_FROM_TEMPORAL_EDGES,
neg_sampler=neg_sampler,
batch_size = train_param['batch_size'],
shuffle=False,
drop_last=True,
chunk_size = None,
train=True,
queue_size = 1000,
mailbox = mailbox,
)
testloader = DistributedDataLoader(graph,test_data,sampler = sampler,
sampler_fn = SAMPLE_TYPE.SAMPLE_FROM_TEMPORAL_EDGES,
neg_sampler=neg_sampler,
batch_size = train_param['batch_size'],
shuffle=False,
drop_last=False,
chunk_size = None,
train=False,
queue_size = 100,
mailbox = mailbox)
valloader = DistributedDataLoader(graph,val_data,sampler = sampler,
sampler_fn = SAMPLE_TYPE.SAMPLE_FROM_TEMPORAL_EDGES,
neg_sampler=neg_sampler,
batch_size = train_param['batch_size'],
shuffle=False,
drop_last=False,
chunk_size = None,
train=False,
queue_size = 100,
mailbox = mailbox)
#FetchFeatureCache.create_fetch_cache(graph.num_nodes,graph.eids_mapper.shape[0],0.1,0.1,graph,mailbox,policy = 'static')
#cache = FetchFeatureCache.getFetchCache()
#cache.init_cache_with_presample(trainloader,3)
gnn_dim_node = 0 if graph.x is None else pdata.x.shape[1]
gnn_dim_edge = 0 if graph.edge_attr is None else pdata.edge_attr.shape[1]
print(gnn_dim_node,gnn_dim_edge)
avg_time = 0
if use_cuda:
model = GeneralModel(gnn_dim_node, gnn_dim_edge, sample_param, memory_param, gnn_param, train_param).cuda()
device = torch.device('cuda')
else:
model = GeneralModel(gnn_dim_node, gnn_dim_edge, sample_param, memory_param, gnn_param, train_param)
device = torch.device('cpu')
model = DDP(model,find_unused_parameters=True)
train_stream = torch.cuda.Stream()
send_stream = torch.cuda.Stream()
scatter_stream = torch.cuda.Stream()
val_losses = list()
def eval(mode='val'):
neg_samples = 1
model.eval()
aps = list()
aucs_mrrs = list()
if mode == 'val':
loader = valloader
elif mode == 'test':
loader = testloader
elif mode == 'train':
loader = trainloader
with torch.no_grad():
total_loss = 0
signal = torch.tensor([0],dtype = int,device = device)
with torch.cuda.stream(train_stream):
for roots,mfgs,metadata in loader:
pred_pos, pred_neg = model(mfgs,metadata)
total_loss += creterion(pred_pos, torch.ones_like(pred_pos))
total_loss += creterion(pred_neg, torch.zeros_like(pred_neg))
y_pred = torch.cat([pred_pos, pred_neg], dim=0).sigmoid().cpu()
y_true = torch.cat([torch.ones(pred_pos.size(0)), torch.zeros(pred_neg.size(0))], dim=0)
aps.append(average_precision_score(y_true, y_pred.detach().numpy()))
aucs_mrrs.append(roc_auc_score(y_true, y_pred))
if mailbox is not None:
src = metadata['src_pos_index']
dst = metadata['dst_pos_index']
ts = roots.ts
if graph.edge_attr is None:
edge_feats = None
elif(graph.edge_attr.device == torch.device('cpu')):
edge_feats = graph.edge_attr[roots.eids.to('cpu')].to('cuda')
else:
edge_feats = graph.edge_attr[roots.eids]
dist_index_mapper = mfgs[0][0].srcdata['ID']
root_index = torch.cat((src,dst))
last_updated_nid = model.module.memory_updater.last_updated_nid[root_index]
last_updated_memory = model.module.memory_updater.last_updated_memory[root_index]
last_updated_ts=model.module.memory_updater.last_updated_ts[root_index]
index, memory, memory_ts = mailbox.get_update_memory(last_updated_nid,
last_updated_memory,
last_updated_ts)
#
index, mail, mail_ts = mailbox.get_update_mail(dist_index_mapper,
src,dst,ts,edge_feats,
model.module.memory_updater.last_updated_memory,
model.module.embedding,use_src_emb,
use_dst_emb,
)
mailbox.set_mailbox_all_to_all(index,memory,memory_ts,mail,mail_ts,reduce_Op = 'max')
#ap = float(torch.tensor(aps).mean())
#if neg_samples > 1:
# auc_mrr = float(torch.cat(aucs_mrrs).mean())
#else:
# auc_mrr = float(torch.tensor(aucs_mrrs).mean())
world_size = dist.get_world_size()
apc = torch.empty([loader.expected_idx*world_size],dtype = torch.float,device='cuda')
auc_mrr = torch.empty([loader.expected_idx*world_size],dtype = torch.float,device = 'cuda')
dist.all_gather_into_tensor(apc,torch.tensor(aps,device ='cuda',dtype=torch.float))
dist.all_gather_into_tensor(auc_mrr,torch.tensor(aucs_mrrs,device ='cuda',dtype=torch.float))
ap = float(torch.tensor(apc).mean())
auc_mrr = float(torch.tensor(auc_mrr).mean())
return ap, auc_mrr
creterion = torch.nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=train_param['lr'])
early_stopper = EarlyStopMonitor(max_round=args.patience)
MODEL_SAVE_PATH = f'./saved_models/{args.model}-{args.dataname}.pth'
for e in range(train_param['epoch']):
torch.cuda.synchronize()
write_back_time = 0
fetch_time = 0
epoch_start_time = time.time()
train_aps = list()
print('Epoch {:d}:'.format(e))
time_prep = 0
total_loss = 0
model.train()
if mailbox is not None:
mailbox.reset()
model.module.memory_updater.last_updated_nid = None
model.module.memory_updater.last_updated_memory = None
model.module.memory_updater.last_updated_ts = None
for roots,mfgs,metadata,sample_time in trainloader:
fetch_time +=sample_time/1000
t_prep_s = time.time()
with torch.cuda.stream(train_stream):
optimizer.zero_grad()
pred_pos, pred_neg = model(mfgs,metadata)
loss = creterion(pred_pos, torch.ones_like(pred_pos))
loss += creterion(pred_neg, torch.zeros_like(pred_neg))
total_loss += float(loss)
loss.backward()
optimizer.step()
#torch.cuda.synchronize()
t_prep_s = time.time()
y_pred = torch.cat([pred_pos, pred_neg], dim=0).sigmoid().cpu()
y_true = torch.cat([torch.ones(pred_pos.size(0)), torch.zeros(pred_neg.size(0))], dim=0)
train_aps.append(average_precision_score(y_true, y_pred.detach().numpy()))
#start_event = torch.cuda.Event(enable_timing=True)
#end_event = torch.cuda.Event(enable_timing=True)
#start_event.record()
if mailbox is not None:
src = metadata['src_pos_index']
dst = metadata['dst_pos_index']
ts = roots.ts
if graph.edge_attr is None:
edge_feats = None
elif(graph.edge_attr.device == torch.device('cpu')):
edge_feats = graph.edge_attr[roots.eids.to('cpu')].to('cuda')
else:
edge_feats = graph.edge_attr[roots.eids]
dist_index_mapper = mfgs[0][0].srcdata['ID']
root_index = torch.cat((src,dst))
last_updated_nid = model.module.memory_updater.last_updated_nid[root_index]
last_updated_memory = model.module.memory_updater.last_updated_memory[root_index]
last_updated_ts=model.module.memory_updater.last_updated_ts[root_index]
index, memory, memory_ts = mailbox.get_update_memory(last_updated_nid,
last_updated_memory,
last_updated_ts)
index, mail, mail_ts = mailbox.get_update_mail(dist_index_mapper,
src,dst,ts,edge_feats,
model.module.memory_updater.last_updated_memory,
model.module.embedding,use_src_emb,use_dst_emb,
)
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
start_event.record()
mailbox.set_mailbox_all_to_all(index,memory,memory_ts,mail,mail_ts,reduce_Op = 'max')
torch.cuda.synchronize()
time_prep = time.time() - epoch_start_time
avg_time += time.time() - epoch_start_time
train_ap = float(torch.tensor(train_aps).mean())
ap = 0
auc = 0
#if cache.edge_cache is not None:
# print('hit {}'.format(cache.edge_cache.hit_/ cache.edge_cache.hit_sum))
#if cache.node_cache is not None:
# print('hit {}'.format(cache.node_cache.hit_/ cache.node_cache.hit_sum))
ap, auc = eval('val')
early_stop = early_stopper.early_stop_check(ap)
if early_stop:
print("Early stopping at epoch {:d}".format(e))
print(f"Loading the best model at epoch {early_stopper.best_epoch}")
best_model_path = get_checkpoint_path(early_stopper.best_epoch)
model.load_state_dict(torch.load(best_model_path))
break
else:
print('\ttrain loss:{:.4f} train ap:{:4f} val ap:{:4f} val auc:{:4f}'.format(total_loss,train_ap, ap, auc))
print('\ttotal time:{:.2f}s prep time:{:.2f}s'.format(time.time()-epoch_start_time, time_prep))
print('\t fetch time:{:.2f}s write back time:{:.2f}s'.format(fetch_time,write_back_time))
torch.save(model.state_dict(), get_checkpoint_path(e))
model.eval()
if mailbox is not None:
mailbox.reset()
model.module.memory_updater.last_updated_nid = None
eval('train')
eval('val')
ap, auc = eval('test')
eval_neg_samples = 1
if eval_neg_samples > 1:
print('\ttest AP:{:4f} test MRR:{:4f}'.format(ap, auc))
else:
print('\ttest AP:{:4f} test AUC:{:4f}'.format(ap, auc))
print('test_dataset',test_data.edges.shape[1],'avg_time',avg_time/train_param['epoch'])
torch.save(model.state_dict(), MODEL_SAVE_PATH)
ctx.shutdown()
if __name__ == "__main__":
main()
import argparse
import os
import sys
from os.path import abspath, join, dirname
from starrygl.distributed.context import DistributedContext
from starrygl.distributed.utils import DistIndex
from starrygl.module.modules import GeneralModel
from pathlib import Path
from starrygl.module.utils import parse_config
from starrygl.sample.cache.fetch_cache import FetchFeatureCache
from starrygl.sample.graph_core import DataSet, DistributedGraphStore, TemporalNeighborSampleGraph
from starrygl.module.utils import parse_config, EarlyStopMonitor
from starrygl.sample.graph_core import DataSet, DistributedGraphStore, TemporalNeighborSampleGraph
from starrygl.sample.memory.shared_mailbox import SharedMailBox
from starrygl.sample.sample_core.base import NegativeSampling
from starrygl.sample.sample_core.neighbor_sampler import NeighborSampler
from starrygl.sample.part_utils.partition_tgnn import partition_load
import torch
import time
import torch
import torch.nn.functional as F
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import init_process_group, destroy_process_group
import os
from starrygl.sample.data_loader import DistributedDataLoader
from starrygl.sample.batch_data import SAMPLE_TYPE
from starrygl.sample.stream_manager import getPipelineManger
parser = argparse.ArgumentParser(
description="RPC Reinforcement Learning Example",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument('--rank', default=0, type=int, metavar='W',
help='name of dataset')
parser.add_argument('--patience', type=int, default=5, help='Patience for early stopping')
parser.add_argument('--world_size', default=1, type=int, metavar='W',
help='number of negative samples')
parser.add_argument('--dataname', default=1, type=str, metavar='W',
help='name of dataset')
parser.add_argument('--model', default='TGN', type=str, metavar='W',
help='name of model')
args = parser.parse_args()
from sklearn.metrics import average_precision_score, roc_auc_score
import torch
import time
import random
import dgl
import numpy as np
from sklearn.metrics import average_precision_score, roc_auc_score
from torch.nn.parallel import DistributedDataParallel as DDP
#os.environ['CUDA_VISIBLE_DEVICES'] = str(args.rank)
#os.environ["RANK"] = str(args.rank)
#os.environ["WORLD_SIZE"] = str(args.world_size)
#os.environ["LOCAL_RANK"] = str(0)
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
os.environ["MASTER_ADDR"] = '10.214.211.187'
os.environ["MASTER_PORT"] = '9337'
def seed_everything(seed=42):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
seed_everything(1234)
def main():
print('main')
use_cuda = True
sample_param, memory_param, gnn_param, train_param = parse_config('./config/{}.yml'.format(args.model))
torch.set_num_threads(12)
ctx = DistributedContext.init(backend="nccl", use_gpu=True)
device_id = torch.cuda.current_device()
print('use cuda on',device_id)
pdata = partition_load("/mnt/data/part_data/dataset/here/{}".format(args.dataname), algo="metis_for_tgnn")
graph = DistributedGraphStore(pdata = pdata,uvm_edge = False,uvm_node = False)
sample_graph = TemporalNeighborSampleGraph(sample_graph = pdata.sample_graph,mode = 'full')
mailbox = SharedMailBox(pdata.ids.shape[0], memory_param, dim_edge_feat = pdata.edge_attr.shape[1] if pdata.edge_attr is not None else 0)
sampler = NeighborSampler(num_nodes=graph.num_nodes, num_layers=1, fanout=[10],graph_data=sample_graph, workers=15,policy = 'recent',graph_name = "wiki_train")
train_data = torch.masked_select(graph.edge_index,pdata.train_mask.to(graph.edge_index.device)).reshape(2,-1)
train_ts = torch.masked_select(graph.edge_ts,pdata.train_mask.to(graph.edge_index.device))
val_data = torch.masked_select(graph.edge_index,pdata.val_mask.to(graph.edge_index.device)).reshape(2,-1)
val_ts = torch.masked_select(graph.edge_ts,pdata.val_mask.to(graph.edge_index.device))
test_data = torch.masked_select(graph.edge_index,pdata.test_mask.to(graph.edge_index.device)).reshape(2,-1)
test_ts = torch.masked_select(graph.edge_ts,pdata.test_mask.to(graph.edge_index.device))
#print(train_data.shape[1],val_data.shape[1],test_data.shape[1])
train_data = DataSet(edges = train_data,ts =train_ts,eids = torch.nonzero(pdata.train_mask).view(-1))
#if dist.get_rank() == 0:
test_data = DataSet(edges = test_data,ts =test_ts,eids = torch.nonzero(pdata.test_mask).view(-1))
val_data = DataSet(edges = val_data,ts = val_ts,eids = torch.nonzero(pdata.val_mask).view(-1))
#else:
#test_data = torch.tensor([[],[]],device = graph.edge_index.device,detype = graph.edge_index.#dtype)
#val_data = torch.tensor([[],[]],device = graph.edge_index.device,detype = graph.edge_index.dtype)
#test_ts = torch.tensor([[],[]],device = graph.ts.device,detype = graph.ts.dtype)
#val_ts = torch.tensor([[],[]],device = graph.ts.device,detype = graph.ts.dtype)
#test_data = DataSet(edges = test_data,ts =test_ts,eids = torch.tensor([],dtype = torch.long,#device = torch.cuda))
#val_data = DataSet(edges = val_data,ts = val_ts,eids = torch.tensor([],dtype = torch.long,device #= torch.cuda))
#train_neg_sampler = PreNegativeSampling('triplet',torch.masked_select(pdata.edge_index['pos_edge'],graph.data.train_mask).reshape(2,-1))
neg_sampler = NegativeSampling('triplet')
trainloader = DistributedDataLoader(graph,train_data,sampler = sampler,
sampler_fn = SAMPLE_TYPE.SAMPLE_FROM_TEMPORAL_EDGES,
neg_sampler=neg_sampler,
batch_size = train_param['batch_size'],
shuffle=False,
drop_last=True,
chunk_size = None,
train=True,
queue_size = 1000,
mailbox = mailbox,
)
testloader = DistributedDataLoader(graph,test_data,sampler = sampler,
sampler_fn = SAMPLE_TYPE.SAMPLE_FROM_TEMPORAL_EDGES,
neg_sampler=neg_sampler,
batch_size = train_param['batch_size'],
shuffle=False,
drop_last=False,
chunk_size = None,
train=False,
queue_size = 100,
mailbox = mailbox)
valloader = DistributedDataLoader(graph,val_data,sampler = sampler,
sampler_fn = SAMPLE_TYPE.SAMPLE_FROM_TEMPORAL_EDGES,
neg_sampler=neg_sampler,
batch_size = train_param['batch_size'],
shuffle=False,
drop_last=False,
chunk_size = None,
train=False,
queue_size = 100,
mailbox = mailbox)
#FetchFeatureCache.create_fetch_cache(graph.num_nodes,graph.eids_mapper.shape[0],0.1,0.1,graph,mailbox,policy = 'static')
#cache = FetchFeatureCache.getFetchCache()
#cache.init_cache_with_presample(trainloader,3)
gnn_dim_node = 0 if graph.x is None else pdata.x.shape[1]
gnn_dim_edge = 0 if graph.edge_attr is None else pdata.edge_attr.shape[1]
print(gnn_dim_node,gnn_dim_edge)
avg_time = 0
if use_cuda:
model = GeneralModel(gnn_dim_node, gnn_dim_edge, sample_param, memory_param, gnn_param, train_param).cuda()
device = torch.device('cuda')
else:
model = GeneralModel(gnn_dim_node, gnn_dim_edge, sample_param, memory_param, gnn_param, train_param)
device = torch.device('cpu')
model = DDP(model,find_unused_parameters=True)
train_stream = torch.cuda.Stream()
send_stream = torch.cuda.Stream()
scatter_stream = torch.cuda.Stream()
val_losses = list()
def eval(mode='val'):
neg_samples = 1
model.eval()
aps = list()
aucs_mrrs = list()
if mode == 'val':
loader = valloader
elif mode == 'test':
loader = testloader
elif mode == 'train':
loader = trainloader
with torch.no_grad():
total_loss = 0
signal = torch.tensor([0],dtype = int,device = device)
with torch.cuda.stream(train_stream):
for roots,mfgs,metadata in loader:
pred_pos, pred_neg = model(mfgs,metadata)
total_loss += creterion(pred_pos, torch.ones_like(pred_pos))
total_loss += creterion(pred_neg, torch.zeros_like(pred_neg))
y_pred = torch.cat([pred_pos, pred_neg], dim=0).sigmoid().cpu()
y_true = torch.cat([torch.ones(pred_pos.size(0)), torch.zeros(pred_neg.size(0))], dim=0)
aps.append(average_precision_score(y_true, y_pred.detach().numpy()))
aucs_mrrs.append(roc_auc_score(y_true, y_pred))
if mailbox is not None:
src = metadata['src_pos_index']
dst = metadata['dst_pos_index']
ts = roots.ts
if graph.edge_attr is None:
edge_feats = None
elif(graph.edge_attr.device == torch.device('cpu')):
edge_feats = graph.edge_attr[roots.eids.to('cpu')].to('cuda')
else:
edge_feats = graph.edge_attr[roots.eids]
dist_index_mapper = mfgs[0][0].srcdata['ID']
root_index = torch.cat((src,dst))
last_updated_nid = model.module.memory_updater.last_updated_nid[root_index]
last_updated_memory = model.module.memory_updater.last_updated_memory[root_index]
last_updated_ts=model.module.memory_updater.last_updated_ts[root_index]
index, memory, memory_ts = mailbox.get_update_memory(last_updated_nid,
last_updated_memory,
last_updated_ts)
#
index, mail, mail_ts = mailbox.get_update_mail(dist_index_mapper,
src,dst,ts,edge_feats,
model.module.memory_updater.last_updated_memory,
model.module.embedding,use_src_emb,
use_dst_emb,
)
mailbox.set_mailbox_all_to_all(index,memory,memory_ts,mail,mail_ts,reduce_Op = 'max')
#ap = float(torch.tensor(aps).mean())
#if neg_samples > 1:
# auc_mrr = float(torch.cat(aucs_mrrs).mean())
#else:
# auc_mrr = float(torch.tensor(aucs_mrrs).mean())
world_size = dist.get_world_size()
apc = torch.empty([loader.expected_idx*world_size],dtype = torch.float,device='cuda')
auc_mrr = torch.empty([loader.expected_idx*world_size],dtype = torch.float,device = 'cuda')
dist.all_gather_into_tensor(apc,torch.tensor(aps,device ='cuda',dtype=torch.float))
dist.all_gather_into_tensor(auc_mrr,torch.tensor(aucs_mrrs,device ='cuda',dtype=torch.float))
ap = float(torch.tensor(apc).mean())
auc_mrr = float(torch.tensor(auc_mrr).mean())
return ap, auc_mrr
creterion = torch.nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=train_param['lr'])
early_stopper = EarlyStopMonitor(max_round=args.patience)
MODEL_SAVE_PATH = f'./saved_models/{args.model}-{args.dataname}.pth'
for e in range(train_param['epoch']):
torch.cuda.synchronize()
write_back_time = 0
fetch_time = 0
epoch_start_time = time.time()
train_aps = list()
print('Epoch {:d}:'.format(e))
time_prep = 0
total_loss = 0
model.train()
if mailbox is not None:
mailbox.reset()
model.module.memory_updater.last_updated_nid = None
model.module.memory_updater.last_updated_memory = None
model.module.memory_updater.last_updated_ts = None
for roots,mfgs,metadata,sample_time in trainloader:
fetch_time +=sample_time/1000
t_prep_s = time.time()
with torch.cuda.stream(train_stream):
optimizer.zero_grad()
pred_pos, pred_neg = model(mfgs,metadata)
loss = creterion(pred_pos, torch.ones_like(pred_pos))
loss += creterion(pred_neg, torch.zeros_like(pred_neg))
total_loss += float(loss)
loss.backward()
optimizer.step()
#torch.cuda.synchronize()
t_prep_s = time.time()
y_pred = torch.cat([pred_pos, pred_neg], dim=0).sigmoid().cpu()
y_true = torch.cat([torch.ones(pred_pos.size(0)), torch.zeros(pred_neg.size(0))], dim=0)
train_aps.append(average_precision_score(y_true, y_pred.detach().numpy()))
#start_event = torch.cuda.Event(enable_timing=True)
#end_event = torch.cuda.Event(enable_timing=True)
#start_event.record()
if mailbox is not None:
src = metadata['src_pos_index']
dst = metadata['dst_pos_index']
ts = roots.ts
if graph.edge_attr is None:
edge_feats = None
elif(graph.edge_attr.device == torch.device('cpu')):
edge_feats = graph.edge_attr[roots.eids.to('cpu')].to('cuda')
else:
edge_feats = graph.edge_attr[roots.eids]
dist_index_mapper = mfgs[0][0].srcdata['ID']
root_index = torch.cat((src,dst))
last_updated_nid = model.module.memory_updater.last_updated_nid[root_index]
last_updated_memory = model.module.memory_updater.last_updated_memory[root_index]
last_updated_ts=model.module.memory_updater.last_updated_ts[root_index]
index, memory, memory_ts = mailbox.get_update_memory(last_updated_nid,
last_updated_memory,
last_updated_ts)
index, mail, mail_ts = mailbox.get_update_mail(dist_index_mapper,
src,dst,ts,edge_feats,
model.module.memory_updater.last_updated_memory,
model.module.embedding,use_src_emb,use_dst_emb,
)
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
start_event.record()
mailbox.set_mailbox_all_to_all(index,memory,memory_ts,mail,mail_ts,reduce_Op = 'max')
torch.cuda.synchronize()
time_prep = time.time() - epoch_start_time
avg_time += time.time() - epoch_start_time
train_ap = float(torch.tensor(train_aps).mean())
ap = 0
auc = 0
#if cache.edge_cache is not None:
# print('hit {}'.format(cache.edge_cache.hit_/ cache.edge_cache.hit_sum))
#if cache.node_cache is not None:
# print('hit {}'.format(cache.node_cache.hit_/ cache.node_cache.hit_sum))
ap, auc = eval('val')
early_stop = early_stopper.early_stop_check(ap)
if early_stop:
print("Early stopping at epoch {:d}".format(e))
print(f"Loading the best model at epoch {early_stopper.best_epoch}")
best_model_path = get_checkpoint_path(early_stopper.best_epoch)
model.load_state_dict(torch.load(best_model_path))
break
else:
print('\ttrain loss:{:.4f} train ap:{:4f} val ap:{:4f} val auc:{:4f}'.format(total_loss,train_ap, ap, auc))
print('\ttotal time:{:.2f}s prep time:{:.2f}s'.format(time.time()-epoch_start_time, time_prep))
print('\t fetch time:{:.2f}s write back time:{:.2f}s'.format(fetch_time,write_back_time))
torch.save(model.state_dict(), get_checkpoint_path(e))
model.eval()
if mailbox is not None:
mailbox.reset()
model.module.memory_updater.last_updated_nid = None
eval('train')
eval('val')
ap, auc = eval('test')
eval_neg_samples = 1
if eval_neg_samples > 1:
print('\ttest AP:{:4f} test MRR:{:4f}'.format(ap, auc))
else:
print('\ttest AP:{:4f} test AUC:{:4f}'.format(ap, auc))
print('test_dataset',test_data.edges.shape[1],'avg_time',avg_time/train_param['epoch'])
torch.save(model.state_dict(), MODEL_SAVE_PATH)
ctx.shutdown()
if __name__ == "__main__":
main()
import argparse
import os
import sys
from os.path import abspath, join, dirname
from starrygl.distributed.context import DistributedContext
from starrygl.distributed.utils import DistIndex
from starrygl.module.modules import GeneralModel
from pathlib import Path
from starrygl.module.utils import parse_config
from starrygl.sample.cache.fetch_cache import FetchFeatureCache
from starrygl.sample.graph_core import DataSet, DistributedGraphStore, TemporalNeighborSampleGraph
from starrygl.module.utils import parse_config, EarlyStopMonitor
from starrygl.sample.graph_core import DataSet, DistributedGraphStore, TemporalNeighborSampleGraph
from starrygl.sample.memory.shared_mailbox import SharedMailBox
from starrygl.sample.sample_core.base import NegativeSampling
from starrygl.sample.sample_core.neighbor_sampler import NeighborSampler
from starrygl.sample.part_utils.partition_tgnn import partition_load
import torch
import time
import torch
import torch.nn.functional as F
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import init_process_group, destroy_process_group
import os
from starrygl.sample.data_loader import DistributedDataLoader
from starrygl.sample.batch_data import SAMPLE_TYPE
from starrygl.sample.stream_manager import getPipelineManger
parser = argparse.ArgumentParser(
description="RPC Reinforcement Learning Example",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument('--rank', default=0, type=int, metavar='W',
help='name of dataset')
parser.add_argument('--patience', type=int, default=5, help='Patience for early stopping')
parser.add_argument('--world_size', default=1, type=int, metavar='W',
help='number of negative samples')
parser.add_argument('--dataname', default=1, type=str, metavar='W',
help='name of dataset')
parser.add_argument('--model', default='TGN', type=str, metavar='W',
help='name of model')
args = parser.parse_args()
from sklearn.metrics import average_precision_score, roc_auc_score
import torch
import time
import random
import dgl
import numpy as np
from sklearn.metrics import average_precision_score, roc_auc_score
from torch.nn.parallel import DistributedDataParallel as DDP
#os.environ['CUDA_VISIBLE_DEVICES'] = str(args.rank)
#os.environ["RANK"] = str(args.rank)
#os.environ["WORLD_SIZE"] = str(args.world_size)
#os.environ["LOCAL_RANK"] = str(0)
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
os.environ["MASTER_ADDR"] = '10.214.211.187'
os.environ["MASTER_PORT"] = '9337'
def seed_everything(seed=42):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
seed_everything(1234)
def main():
print('main')
use_cuda = True
sample_param, memory_param, gnn_param, train_param = parse_config('./config/{}.yml'.format(args.model))
torch.set_num_threads(12)
ctx = DistributedContext.init(backend="nccl", use_gpu=True)
device_id = torch.cuda.current_device()
print('use cuda on',device_id)
pdata = partition_load("/mnt/data/part_data/dataset/here/{}".format(args.dataname), algo="metis_for_tgnn")
graph = DistributedGraphStore(pdata = pdata,uvm_edge = False,uvm_node = False)
sample_graph = TemporalNeighborSampleGraph(sample_graph = pdata.sample_graph,mode = 'full')
mailbox = SharedMailBox(pdata.ids.shape[0], memory_param, dim_edge_feat = pdata.edge_attr.shape[1] if pdata.edge_attr is not None else 0)
sampler = NeighborSampler(num_nodes=graph.num_nodes, num_layers=1, fanout=[10],graph_data=sample_graph, workers=15,policy = 'recent',graph_name = "wiki_train")
train_data = torch.masked_select(graph.edge_index,pdata.train_mask.to(graph.edge_index.device)).reshape(2,-1)
train_ts = torch.masked_select(graph.edge_ts,pdata.train_mask.to(graph.edge_index.device))
val_data = torch.masked_select(graph.edge_index,pdata.val_mask.to(graph.edge_index.device)).reshape(2,-1)
val_ts = torch.masked_select(graph.edge_ts,pdata.val_mask.to(graph.edge_index.device))
test_data = torch.masked_select(graph.edge_index,pdata.test_mask.to(graph.edge_index.device)).reshape(2,-1)
test_ts = torch.masked_select(graph.edge_ts,pdata.test_mask.to(graph.edge_index.device))
#print(train_data.shape[1],val_data.shape[1],test_data.shape[1])
train_data = DataSet(edges = train_data,ts =train_ts,eids = torch.nonzero(pdata.train_mask).view(-1))
#if dist.get_rank() == 0:
test_data = DataSet(edges = test_data,ts =test_ts,eids = torch.nonzero(pdata.test_mask).view(-1))
val_data = DataSet(edges = val_data,ts = val_ts,eids = torch.nonzero(pdata.val_mask).view(-1))
#else:
#test_data = torch.tensor([[],[]],device = graph.edge_index.device,detype = graph.edge_index.#dtype)
#val_data = torch.tensor([[],[]],device = graph.edge_index.device,detype = graph.edge_index.dtype)
#test_ts = torch.tensor([[],[]],device = graph.ts.device,detype = graph.ts.dtype)
#val_ts = torch.tensor([[],[]],device = graph.ts.device,detype = graph.ts.dtype)
#test_data = DataSet(edges = test_data,ts =test_ts,eids = torch.tensor([],dtype = torch.long,#device = torch.cuda))
#val_data = DataSet(edges = val_data,ts = val_ts,eids = torch.tensor([],dtype = torch.long,device #= torch.cuda))
#train_neg_sampler = PreNegativeSampling('triplet',torch.masked_select(pdata.edge_index['pos_edge'],graph.data.train_mask).reshape(2,-1))
neg_sampler = NegativeSampling('triplet')
trainloader = DistributedDataLoader(graph,train_data,sampler = sampler,
sampler_fn = SAMPLE_TYPE.SAMPLE_FROM_TEMPORAL_EDGES,
neg_sampler=neg_sampler,
batch_size = train_param['batch_size'],
shuffle=False,
drop_last=True,
chunk_size = None,
train=True,
queue_size = 1000,
mailbox = mailbox,
)
testloader = DistributedDataLoader(graph,test_data,sampler = sampler,
sampler_fn = SAMPLE_TYPE.SAMPLE_FROM_TEMPORAL_EDGES,
neg_sampler=neg_sampler,
batch_size = train_param['batch_size'],
shuffle=False,
drop_last=False,
chunk_size = None,
train=False,
queue_size = 100,
mailbox = mailbox)
valloader = DistributedDataLoader(graph,val_data,sampler = sampler,
sampler_fn = SAMPLE_TYPE.SAMPLE_FROM_TEMPORAL_EDGES,
neg_sampler=neg_sampler,
batch_size = train_param['batch_size'],
shuffle=False,
drop_last=False,
chunk_size = None,
train=False,
queue_size = 100,
mailbox = mailbox)
#FetchFeatureCache.create_fetch_cache(graph.num_nodes,graph.eids_mapper.shape[0],0.1,0.1,graph,mailbox,policy = 'static')
#cache = FetchFeatureCache.getFetchCache()
#cache.init_cache_with_presample(trainloader,3)
gnn_dim_node = 0 if graph.x is None else pdata.x.shape[1]
gnn_dim_edge = 0 if graph.edge_attr is None else pdata.edge_attr.shape[1]
print(gnn_dim_node,gnn_dim_edge)
avg_time = 0
if use_cuda:
model = GeneralModel(gnn_dim_node, gnn_dim_edge, sample_param, memory_param, gnn_param, train_param).cuda()
device = torch.device('cuda')
else:
model = GeneralModel(gnn_dim_node, gnn_dim_edge, sample_param, memory_param, gnn_param, train_param)
device = torch.device('cpu')
model = DDP(model,find_unused_parameters=True)
train_stream = torch.cuda.Stream()
send_stream = torch.cuda.Stream()
scatter_stream = torch.cuda.Stream()
val_losses = list()
def eval(mode='val'):
neg_samples = 1
model.eval()
aps = list()
aucs_mrrs = list()
if mode == 'val':
loader = valloader
elif mode == 'test':
loader = testloader
elif mode == 'train':
loader = trainloader
with torch.no_grad():
total_loss = 0
signal = torch.tensor([0],dtype = int,device = device)
with torch.cuda.stream(train_stream):
for roots,mfgs,metadata in loader:
pred_pos, pred_neg = model(mfgs,metadata)
total_loss += creterion(pred_pos, torch.ones_like(pred_pos))
total_loss += creterion(pred_neg, torch.zeros_like(pred_neg))
y_pred = torch.cat([pred_pos, pred_neg], dim=0).sigmoid().cpu()
y_true = torch.cat([torch.ones(pred_pos.size(0)), torch.zeros(pred_neg.size(0))], dim=0)
aps.append(average_precision_score(y_true, y_pred.detach().numpy()))
aucs_mrrs.append(roc_auc_score(y_true, y_pred))
if mailbox is not None:
src = metadata['src_pos_index']
dst = metadata['dst_pos_index']
ts = roots.ts
if graph.edge_attr is None:
edge_feats = None
elif(graph.edge_attr.device == torch.device('cpu')):
edge_feats = graph.edge_attr[roots.eids.to('cpu')].to('cuda')
else:
edge_feats = graph.edge_attr[roots.eids]
dist_index_mapper = mfgs[0][0].srcdata['ID']
root_index = torch.cat((src,dst))
last_updated_nid = model.module.memory_updater.last_updated_nid[root_index]
last_updated_memory = model.module.memory_updater.last_updated_memory[root_index]
last_updated_ts=model.module.memory_updater.last_updated_ts[root_index]
index, memory, memory_ts = mailbox.get_update_memory(last_updated_nid,
last_updated_memory,
last_updated_ts)
#
index, mail, mail_ts = mailbox.get_update_mail(dist_index_mapper,
src,dst,ts,edge_feats,
model.module.memory_updater.last_updated_memory,
model.module.embedding,use_src_emb,
use_dst_emb,
)
mailbox.set_mailbox_all_to_all(index,memory,memory_ts,mail,mail_ts,reduce_Op = 'max')
#ap = float(torch.tensor(aps).mean())
#if neg_samples > 1:
# auc_mrr = float(torch.cat(aucs_mrrs).mean())
#else:
# auc_mrr = float(torch.tensor(aucs_mrrs).mean())
world_size = dist.get_world_size()
apc = torch.empty([loader.expected_idx*world_size],dtype = torch.float,device='cuda')
auc_mrr = torch.empty([loader.expected_idx*world_size],dtype = torch.float,device = 'cuda')
dist.all_gather_into_tensor(apc,torch.tensor(aps,device ='cuda',dtype=torch.float))
dist.all_gather_into_tensor(auc_mrr,torch.tensor(aucs_mrrs,device ='cuda',dtype=torch.float))
ap = float(torch.tensor(apc).mean())
auc_mrr = float(torch.tensor(auc_mrr).mean())
return ap, auc_mrr
creterion = torch.nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=train_param['lr'])
early_stopper = EarlyStopMonitor(max_round=args.patience)
MODEL_SAVE_PATH = f'./saved_models/{args.model}-{args.dataname}.pth'
for e in range(train_param['epoch']):
torch.cuda.synchronize()
write_back_time = 0
fetch_time = 0
epoch_start_time = time.time()
train_aps = list()
print('Epoch {:d}:'.format(e))
time_prep = 0
total_loss = 0
model.train()
if mailbox is not None:
mailbox.reset()
model.module.memory_updater.last_updated_nid = None
model.module.memory_updater.last_updated_memory = None
model.module.memory_updater.last_updated_ts = None
for roots,mfgs,metadata,sample_time in trainloader:
fetch_time +=sample_time/1000
t_prep_s = time.time()
with torch.cuda.stream(train_stream):
optimizer.zero_grad()
pred_pos, pred_neg = model(mfgs,metadata)
loss = creterion(pred_pos, torch.ones_like(pred_pos))
loss += creterion(pred_neg, torch.zeros_like(pred_neg))
total_loss += float(loss)
loss.backward()
optimizer.step()
#torch.cuda.synchronize()
t_prep_s = time.time()
y_pred = torch.cat([pred_pos, pred_neg], dim=0).sigmoid().cpu()
y_true = torch.cat([torch.ones(pred_pos.size(0)), torch.zeros(pred_neg.size(0))], dim=0)
train_aps.append(average_precision_score(y_true, y_pred.detach().numpy()))
#start_event = torch.cuda.Event(enable_timing=True)
#end_event = torch.cuda.Event(enable_timing=True)
#start_event.record()
if mailbox is not None:
src = metadata['src_pos_index']
dst = metadata['dst_pos_index']
ts = roots.ts
if graph.edge_attr is None:
edge_feats = None
elif(graph.edge_attr.device == torch.device('cpu')):
edge_feats = graph.edge_attr[roots.eids.to('cpu')].to('cuda')
else:
edge_feats = graph.edge_attr[roots.eids]
dist_index_mapper = mfgs[0][0].srcdata['ID']
root_index = torch.cat((src,dst))
last_updated_nid = model.module.memory_updater.last_updated_nid[root_index]
last_updated_memory = model.module.memory_updater.last_updated_memory[root_index]
last_updated_ts=model.module.memory_updater.last_updated_ts[root_index]
index, memory, memory_ts = mailbox.get_update_memory(last_updated_nid,
last_updated_memory,
last_updated_ts)
index, mail, mail_ts = mailbox.get_update_mail(dist_index_mapper,
src,dst,ts,edge_feats,
model.module.memory_updater.last_updated_memory,
model.module.embedding,use_src_emb,use_dst_emb,
)
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
start_event.record()
mailbox.set_mailbox_all_to_all(index,memory,memory_ts,mail,mail_ts,reduce_Op = 'max')
torch.cuda.synchronize()
time_prep = time.time() - epoch_start_time
avg_time += time.time() - epoch_start_time
train_ap = float(torch.tensor(train_aps).mean())
ap = 0
auc = 0
#if cache.edge_cache is not None:
# print('hit {}'.format(cache.edge_cache.hit_/ cache.edge_cache.hit_sum))
#if cache.node_cache is not None:
# print('hit {}'.format(cache.node_cache.hit_/ cache.node_cache.hit_sum))
ap, auc = eval('val')
early_stop = early_stopper.early_stop_check(ap)
if early_stop:
print("Early stopping at epoch {:d}".format(e))
print(f"Loading the best model at epoch {early_stopper.best_epoch}")
best_model_path = get_checkpoint_path(early_stopper.best_epoch)
model.load_state_dict(torch.load(best_model_path))
break
else:
print('\ttrain loss:{:.4f} train ap:{:4f} val ap:{:4f} val auc:{:4f}'.format(total_loss,train_ap, ap, auc))
print('\ttotal time:{:.2f}s prep time:{:.2f}s'.format(time.time()-epoch_start_time, time_prep))
print('\t fetch time:{:.2f}s write back time:{:.2f}s'.format(fetch_time,write_back_time))
torch.save(model.state_dict(), get_checkpoint_path(e))
model.eval()
if mailbox is not None:
mailbox.reset()
model.module.memory_updater.last_updated_nid = None
eval('train')
eval('val')
ap, auc = eval('test')
eval_neg_samples = 1
if eval_neg_samples > 1:
print('\ttest AP:{:4f} test MRR:{:4f}'.format(ap, auc))
else:
print('\ttest AP:{:4f} test AUC:{:4f}'.format(ap, auc))
print('test_dataset',test_data.edges.shape[1],'avg_time',avg_time/train_param['epoch'])
torch.save(model.state_dict(), MODEL_SAVE_PATH)
ctx.shutdown()
if __name__ == "__main__":
main()
import argparse
import os
import sys
from os.path import abspath, join, dirname
from starrygl.distributed.context import DistributedContext
from starrygl.distributed.utils import DistIndex
from starrygl.module.modules import GeneralModel
from pathlib import Path
from starrygl.module.utils import parse_config
from starrygl.sample.cache.fetch_cache import FetchFeatureCache
from starrygl.sample.graph_core import DataSet, DistributedGraphStore, TemporalNeighborSampleGraph
from starrygl.module.utils import parse_config, EarlyStopMonitor
from starrygl.sample.graph_core import DataSet, DistributedGraphStore, TemporalNeighborSampleGraph
from starrygl.sample.memory.shared_mailbox import SharedMailBox
from starrygl.sample.sample_core.base import NegativeSampling
from starrygl.sample.sample_core.neighbor_sampler import NeighborSampler
from starrygl.sample.part_utils.partition_tgnn import partition_load
import torch
import time
import torch
import torch.nn.functional as F
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import init_process_group, destroy_process_group
import os
from starrygl.sample.data_loader import DistributedDataLoader
from starrygl.sample.batch_data import SAMPLE_TYPE
from starrygl.sample.stream_manager import getPipelineManger
parser = argparse.ArgumentParser(
description="RPC Reinforcement Learning Example",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument('--rank', default=0, type=int, metavar='W',
help='name of dataset')
parser.add_argument('--patience', type=int, default=5, help='Patience for early stopping')
parser.add_argument('--world_size', default=1, type=int, metavar='W',
help='number of negative samples')
parser.add_argument('--dataname', default=1, type=str, metavar='W',
help='name of dataset')
parser.add_argument('--model', default='TGN', type=str, metavar='W',
help='name of model')
args = parser.parse_args()
from sklearn.metrics import average_precision_score, roc_auc_score
import torch
import time
import random
import dgl
import numpy as np
from sklearn.metrics import average_precision_score, roc_auc_score
from torch.nn.parallel import DistributedDataParallel as DDP
#os.environ['CUDA_VISIBLE_DEVICES'] = str(args.rank)
#os.environ["RANK"] = str(args.rank)
#os.environ["WORLD_SIZE"] = str(args.world_size)
#os.environ["LOCAL_RANK"] = str(0)
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
os.environ["MASTER_ADDR"] = '10.214.211.187'
os.environ["MASTER_PORT"] = '9337'
def seed_everything(seed=42):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
seed_everything(1234)
def main():
print('main')
use_cuda = True
sample_param, memory_param, gnn_param, train_param = parse_config('./config/{}.yml'.format(args.model))
torch.set_num_threads(12)
ctx = DistributedContext.init(backend="nccl", use_gpu=True)
device_id = torch.cuda.current_device()
print('use cuda on',device_id)
pdata = partition_load("/mnt/data/part_data/dataset/here/{}".format(args.dataname), algo="metis_for_tgnn")
graph = DistributedGraphStore(pdata = pdata,uvm_edge = False,uvm_node = False)
sample_graph = TemporalNeighborSampleGraph(sample_graph = pdata.sample_graph,mode = 'full')
mailbox = SharedMailBox(pdata.ids.shape[0], memory_param, dim_edge_feat = pdata.edge_attr.shape[1] if pdata.edge_attr is not None else 0)
sampler = NeighborSampler(num_nodes=graph.num_nodes, num_layers=1, fanout=[10],graph_data=sample_graph, workers=15,policy = 'recent',graph_name = "wiki_train")
train_data = torch.masked_select(graph.edge_index,pdata.train_mask.to(graph.edge_index.device)).reshape(2,-1)
train_ts = torch.masked_select(graph.edge_ts,pdata.train_mask.to(graph.edge_index.device))
val_data = torch.masked_select(graph.edge_index,pdata.val_mask.to(graph.edge_index.device)).reshape(2,-1)
val_ts = torch.masked_select(graph.edge_ts,pdata.val_mask.to(graph.edge_index.device))
test_data = torch.masked_select(graph.edge_index,pdata.test_mask.to(graph.edge_index.device)).reshape(2,-1)
test_ts = torch.masked_select(graph.edge_ts,pdata.test_mask.to(graph.edge_index.device))
#print(train_data.shape[1],val_data.shape[1],test_data.shape[1])
train_data = DataSet(edges = train_data,ts =train_ts,eids = torch.nonzero(pdata.train_mask).view(-1))
#if dist.get_rank() == 0:
test_data = DataSet(edges = test_data,ts =test_ts,eids = torch.nonzero(pdata.test_mask).view(-1))
val_data = DataSet(edges = val_data,ts = val_ts,eids = torch.nonzero(pdata.val_mask).view(-1))
#else:
#test_data = torch.tensor([[],[]],device = graph.edge_index.device,detype = graph.edge_index.#dtype)
#val_data = torch.tensor([[],[]],device = graph.edge_index.device,detype = graph.edge_index.dtype)
#test_ts = torch.tensor([[],[]],device = graph.ts.device,detype = graph.ts.dtype)
#val_ts = torch.tensor([[],[]],device = graph.ts.device,detype = graph.ts.dtype)
#test_data = DataSet(edges = test_data,ts =test_ts,eids = torch.tensor([],dtype = torch.long,#device = torch.cuda))
#val_data = DataSet(edges = val_data,ts = val_ts,eids = torch.tensor([],dtype = torch.long,device #= torch.cuda))
#train_neg_sampler = PreNegativeSampling('triplet',torch.masked_select(pdata.edge_index['pos_edge'],graph.data.train_mask).reshape(2,-1))
neg_sampler = NegativeSampling('triplet')
trainloader = DistributedDataLoader(graph,train_data,sampler = sampler,
sampler_fn = SAMPLE_TYPE.SAMPLE_FROM_TEMPORAL_EDGES,
neg_sampler=neg_sampler,
batch_size = train_param['batch_size'],
shuffle=False,
drop_last=True,
chunk_size = None,
train=True,
queue_size = 1000,
mailbox = mailbox,
)
testloader = DistributedDataLoader(graph,test_data,sampler = sampler,
sampler_fn = SAMPLE_TYPE.SAMPLE_FROM_TEMPORAL_EDGES,
neg_sampler=neg_sampler,
batch_size = train_param['batch_size'],
shuffle=False,
drop_last=False,
chunk_size = None,
train=False,
queue_size = 100,
mailbox = mailbox)
valloader = DistributedDataLoader(graph,val_data,sampler = sampler,
sampler_fn = SAMPLE_TYPE.SAMPLE_FROM_TEMPORAL_EDGES,
neg_sampler=neg_sampler,
batch_size = train_param['batch_size'],
shuffle=False,
drop_last=False,
chunk_size = None,
train=False,
queue_size = 100,
mailbox = mailbox)
#FetchFeatureCache.create_fetch_cache(graph.num_nodes,graph.eids_mapper.shape[0],0.1,0.1,graph,mailbox,policy = 'static')
#cache = FetchFeatureCache.getFetchCache()
#cache.init_cache_with_presample(trainloader,3)
gnn_dim_node = 0 if graph.x is None else pdata.x.shape[1]
gnn_dim_edge = 0 if graph.edge_attr is None else pdata.edge_attr.shape[1]
print(gnn_dim_node,gnn_dim_edge)
avg_time = 0
if use_cuda:
model = GeneralModel(gnn_dim_node, gnn_dim_edge, sample_param, memory_param, gnn_param, train_param).cuda()
device = torch.device('cuda')
else:
model = GeneralModel(gnn_dim_node, gnn_dim_edge, sample_param, memory_param, gnn_param, train_param)
device = torch.device('cpu')
model = DDP(model,find_unused_parameters=True)
train_stream = torch.cuda.Stream()
send_stream = torch.cuda.Stream()
scatter_stream = torch.cuda.Stream()
val_losses = list()
def eval(mode='val'):
neg_samples = 1
model.eval()
aps = list()
aucs_mrrs = list()
if mode == 'val':
loader = valloader
elif mode == 'test':
loader = testloader
elif mode == 'train':
loader = trainloader
with torch.no_grad():
total_loss = 0
signal = torch.tensor([0],dtype = int,device = device)
with torch.cuda.stream(train_stream):
for roots,mfgs,metadata in loader:
pred_pos, pred_neg = model(mfgs,metadata)
total_loss += creterion(pred_pos, torch.ones_like(pred_pos))
total_loss += creterion(pred_neg, torch.zeros_like(pred_neg))
y_pred = torch.cat([pred_pos, pred_neg], dim=0).sigmoid().cpu()
y_true = torch.cat([torch.ones(pred_pos.size(0)), torch.zeros(pred_neg.size(0))], dim=0)
aps.append(average_precision_score(y_true, y_pred.detach().numpy()))
aucs_mrrs.append(roc_auc_score(y_true, y_pred))
if mailbox is not None:
src = metadata['src_pos_index']
dst = metadata['dst_pos_index']
ts = roots.ts
if graph.edge_attr is None:
edge_feats = None
elif(graph.edge_attr.device == torch.device('cpu')):
edge_feats = graph.edge_attr[roots.eids.to('cpu')].to('cuda')
else:
edge_feats = graph.edge_attr[roots.eids]
dist_index_mapper = mfgs[0][0].srcdata['ID']
root_index = torch.cat((src,dst))
last_updated_nid = model.module.memory_updater.last_updated_nid[root_index]
last_updated_memory = model.module.memory_updater.last_updated_memory[root_index]
last_updated_ts=model.module.memory_updater.last_updated_ts[root_index]
index, memory, memory_ts = mailbox.get_update_memory(last_updated_nid,
last_updated_memory,
last_updated_ts)
#
index, mail, mail_ts = mailbox.get_update_mail(dist_index_mapper,
src,dst,ts,edge_feats,
model.module.memory_updater.last_updated_memory,
model.module.embedding,use_src_emb,
use_dst_emb,
)
mailbox.set_mailbox_all_to_all(index,memory,memory_ts,mail,mail_ts,reduce_Op = 'max')
#ap = float(torch.tensor(aps).mean())
#if neg_samples > 1:
# auc_mrr = float(torch.cat(aucs_mrrs).mean())
#else:
# auc_mrr = float(torch.tensor(aucs_mrrs).mean())
world_size = dist.get_world_size()
apc = torch.empty([loader.expected_idx*world_size],dtype = torch.float,device='cuda')
auc_mrr = torch.empty([loader.expected_idx*world_size],dtype = torch.float,device = 'cuda')
dist.all_gather_into_tensor(apc,torch.tensor(aps,device ='cuda',dtype=torch.float))
dist.all_gather_into_tensor(auc_mrr,torch.tensor(aucs_mrrs,device ='cuda',dtype=torch.float))
ap = float(torch.tensor(apc).mean())
auc_mrr = float(torch.tensor(auc_mrr).mean())
return ap, auc_mrr
creterion = torch.nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=train_param['lr'])
early_stopper = EarlyStopMonitor(max_round=args.patience)
MODEL_SAVE_PATH = f'./saved_models/{args.model}-{args.dataname}.pth'
for e in range(train_param['epoch']):
torch.cuda.synchronize()
write_back_time = 0
fetch_time = 0
epoch_start_time = time.time()
train_aps = list()
print('Epoch {:d}:'.format(e))
time_prep = 0
total_loss = 0
model.train()
if mailbox is not None:
mailbox.reset()
model.module.memory_updater.last_updated_nid = None
model.module.memory_updater.last_updated_memory = None
model.module.memory_updater.last_updated_ts = None
for roots,mfgs,metadata,sample_time in trainloader:
fetch_time +=sample_time/1000
t_prep_s = time.time()
with torch.cuda.stream(train_stream):
optimizer.zero_grad()
pred_pos, pred_neg = model(mfgs,metadata)
loss = creterion(pred_pos, torch.ones_like(pred_pos))
loss += creterion(pred_neg, torch.zeros_like(pred_neg))
total_loss += float(loss)
loss.backward()
optimizer.step()
#torch.cuda.synchronize()
t_prep_s = time.time()
y_pred = torch.cat([pred_pos, pred_neg], dim=0).sigmoid().cpu()
y_true = torch.cat([torch.ones(pred_pos.size(0)), torch.zeros(pred_neg.size(0))], dim=0)
train_aps.append(average_precision_score(y_true, y_pred.detach().numpy()))
#start_event = torch.cuda.Event(enable_timing=True)
#end_event = torch.cuda.Event(enable_timing=True)
#start_event.record()
if mailbox is not None:
src = metadata['src_pos_index']
dst = metadata['dst_pos_index']
ts = roots.ts
if graph.edge_attr is None:
edge_feats = None
elif(graph.edge_attr.device == torch.device('cpu')):
edge_feats = graph.edge_attr[roots.eids.to('cpu')].to('cuda')
else:
edge_feats = graph.edge_attr[roots.eids]
dist_index_mapper = mfgs[0][0].srcdata['ID']
root_index = torch.cat((src,dst))
last_updated_nid = model.module.memory_updater.last_updated_nid[root_index]
last_updated_memory = model.module.memory_updater.last_updated_memory[root_index]
last_updated_ts=model.module.memory_updater.last_updated_ts[root_index]
index, memory, memory_ts = mailbox.get_update_memory(last_updated_nid,
last_updated_memory,
last_updated_ts)
index, mail, mail_ts = mailbox.get_update_mail(dist_index_mapper,
src,dst,ts,edge_feats,
model.module.memory_updater.last_updated_memory,
model.module.embedding,use_src_emb,use_dst_emb,
)
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
start_event.record()
mailbox.set_mailbox_all_to_all(index,memory,memory_ts,mail,mail_ts,reduce_Op = 'max')
torch.cuda.synchronize()
time_prep = time.time() - epoch_start_time
avg_time += time.time() - epoch_start_time
train_ap = float(torch.tensor(train_aps).mean())
ap = 0
auc = 0
#if cache.edge_cache is not None:
# print('hit {}'.format(cache.edge_cache.hit_/ cache.edge_cache.hit_sum))
#if cache.node_cache is not None:
# print('hit {}'.format(cache.node_cache.hit_/ cache.node_cache.hit_sum))
ap, auc = eval('val')
early_stop = early_stopper.early_stop_check(ap)
if early_stop:
print("Early stopping at epoch {:d}".format(e))
print(f"Loading the best model at epoch {early_stopper.best_epoch}")
best_model_path = get_checkpoint_path(early_stopper.best_epoch)
model.load_state_dict(torch.load(best_model_path))
break
else:
print('\ttrain loss:{:.4f} train ap:{:4f} val ap:{:4f} val auc:{:4f}'.format(total_loss,train_ap, ap, auc))
print('\ttotal time:{:.2f}s prep time:{:.2f}s'.format(time.time()-epoch_start_time, time_prep))
print('\t fetch time:{:.2f}s write back time:{:.2f}s'.format(fetch_time,write_back_time))
torch.save(model.state_dict(), get_checkpoint_path(e))
model.eval()
if mailbox is not None:
mailbox.reset()
model.module.memory_updater.last_updated_nid = None
eval('train')
eval('val')
ap, auc = eval('test')
eval_neg_samples = 1
if eval_neg_samples > 1:
print('\ttest AP:{:4f} test MRR:{:4f}'.format(ap, auc))
else:
print('\ttest AP:{:4f} test AUC:{:4f}'.format(ap, auc))
print('test_dataset',test_data.edges.shape[1],'avg_time',avg_time/train_param['epoch'])
torch.save(model.state_dict(), MODEL_SAVE_PATH)
ctx.shutdown()
if __name__ == "__main__":
main()
import argparse
import os
import sys
from os.path import abspath, join, dirname
from starrygl.distributed.context import DistributedContext
from starrygl.distributed.utils import DistIndex
from starrygl.module.modules import GeneralModel
from pathlib import Path
from starrygl.module.utils import parse_config
from starrygl.sample.cache.fetch_cache import FetchFeatureCache
from starrygl.sample.graph_core import DataSet, DistributedGraphStore, TemporalNeighborSampleGraph
from starrygl.module.utils import parse_config, EarlyStopMonitor
from starrygl.sample.graph_core import DataSet, DistributedGraphStore, TemporalNeighborSampleGraph
from starrygl.sample.memory.shared_mailbox import SharedMailBox
from starrygl.sample.sample_core.base import NegativeSampling
from starrygl.sample.sample_core.neighbor_sampler import NeighborSampler
from starrygl.sample.part_utils.partition_tgnn import partition_load
import torch
import time
import torch
import torch.nn.functional as F
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import init_process_group, destroy_process_group
import os
from starrygl.sample.data_loader import DistributedDataLoader
from starrygl.sample.batch_data import SAMPLE_TYPE
from starrygl.sample.stream_manager import getPipelineManger
parser = argparse.ArgumentParser(
description="RPC Reinforcement Learning Example",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument('--rank', default=0, type=int, metavar='W',
help='name of dataset')
parser.add_argument('--patience', type=int, default=5, help='Patience for early stopping')
parser.add_argument('--world_size', default=1, type=int, metavar='W',
help='number of negative samples')
parser.add_argument('--dataname', default=1, type=str, metavar='W',
help='name of dataset')
parser.add_argument('--model', default='TGN', type=str, metavar='W',
help='name of model')
args = parser.parse_args()
from sklearn.metrics import average_precision_score, roc_auc_score
import torch
import time
import random
import dgl
import numpy as np
from sklearn.metrics import average_precision_score, roc_auc_score
from torch.nn.parallel import DistributedDataParallel as DDP
#os.environ['CUDA_VISIBLE_DEVICES'] = str(args.rank)
#os.environ["RANK"] = str(args.rank)
#os.environ["WORLD_SIZE"] = str(args.world_size)
#os.environ["LOCAL_RANK"] = str(0)
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
os.environ["MASTER_ADDR"] = '10.214.211.187'
os.environ["MASTER_PORT"] = '9337'
def seed_everything(seed=42):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
seed_everything(1234)
def main():
print('main')
use_cuda = True
sample_param, memory_param, gnn_param, train_param = parse_config('./config/{}.yml'.format(args.model))
torch.set_num_threads(12)
ctx = DistributedContext.init(backend="nccl", use_gpu=True)
device_id = torch.cuda.current_device()
print('use cuda on',device_id)
pdata = partition_load("/mnt/data/part_data/dataset/here/{}".format(args.dataname), algo="metis_for_tgnn")
graph = DistributedGraphStore(pdata = pdata,uvm_edge = False,uvm_node = False)
sample_graph = TemporalNeighborSampleGraph(sample_graph = pdata.sample_graph,mode = 'full')
mailbox = SharedMailBox(pdata.ids.shape[0], memory_param, dim_edge_feat = pdata.edge_attr.shape[1] if pdata.edge_attr is not None else 0)
sampler = NeighborSampler(num_nodes=graph.num_nodes, num_layers=1, fanout=[10],graph_data=sample_graph, workers=15,policy = 'recent',graph_name = "wiki_train")
train_data = torch.masked_select(graph.edge_index,pdata.train_mask.to(graph.edge_index.device)).reshape(2,-1)
train_ts = torch.masked_select(graph.edge_ts,pdata.train_mask.to(graph.edge_index.device))
val_data = torch.masked_select(graph.edge_index,pdata.val_mask.to(graph.edge_index.device)).reshape(2,-1)
val_ts = torch.masked_select(graph.edge_ts,pdata.val_mask.to(graph.edge_index.device))
test_data = torch.masked_select(graph.edge_index,pdata.test_mask.to(graph.edge_index.device)).reshape(2,-1)
test_ts = torch.masked_select(graph.edge_ts,pdata.test_mask.to(graph.edge_index.device))
#print(train_data.shape[1],val_data.shape[1],test_data.shape[1])
train_data = DataSet(edges = train_data,ts =train_ts,eids = torch.nonzero(pdata.train_mask).view(-1))
#if dist.get_rank() == 0:
test_data = DataSet(edges = test_data,ts =test_ts,eids = torch.nonzero(pdata.test_mask).view(-1))
val_data = DataSet(edges = val_data,ts = val_ts,eids = torch.nonzero(pdata.val_mask).view(-1))
#else:
#test_data = torch.tensor([[],[]],device = graph.edge_index.device,detype = graph.edge_index.#dtype)
#val_data = torch.tensor([[],[]],device = graph.edge_index.device,detype = graph.edge_index.dtype)
#test_ts = torch.tensor([[],[]],device = graph.ts.device,detype = graph.ts.dtype)
#val_ts = torch.tensor([[],[]],device = graph.ts.device,detype = graph.ts.dtype)
#test_data = DataSet(edges = test_data,ts =test_ts,eids = torch.tensor([],dtype = torch.long,#device = torch.cuda))
#val_data = DataSet(edges = val_data,ts = val_ts,eids = torch.tensor([],dtype = torch.long,device #= torch.cuda))
#train_neg_sampler = PreNegativeSampling('triplet',torch.masked_select(pdata.edge_index['pos_edge'],graph.data.train_mask).reshape(2,-1))
neg_sampler = NegativeSampling('triplet')
trainloader = DistributedDataLoader(graph,train_data,sampler = sampler,
sampler_fn = SAMPLE_TYPE.SAMPLE_FROM_TEMPORAL_EDGES,
neg_sampler=neg_sampler,
batch_size = train_param['batch_size'],
shuffle=False,
drop_last=True,
chunk_size = None,
train=True,
queue_size = 1000,
mailbox = mailbox,
)
testloader = DistributedDataLoader(graph,test_data,sampler = sampler,
sampler_fn = SAMPLE_TYPE.SAMPLE_FROM_TEMPORAL_EDGES,
neg_sampler=neg_sampler,
batch_size = train_param['batch_size'],
shuffle=False,
drop_last=False,
chunk_size = None,
train=False,
queue_size = 100,
mailbox = mailbox)
valloader = DistributedDataLoader(graph,val_data,sampler = sampler,
sampler_fn = SAMPLE_TYPE.SAMPLE_FROM_TEMPORAL_EDGES,
neg_sampler=neg_sampler,
batch_size = train_param['batch_size'],
shuffle=False,
drop_last=False,
chunk_size = None,
train=False,
queue_size = 100,
mailbox = mailbox)
#FetchFeatureCache.create_fetch_cache(graph.num_nodes,graph.eids_mapper.shape[0],0.1,0.1,graph,mailbox,policy = 'static')
#cache = FetchFeatureCache.getFetchCache()
#cache.init_cache_with_presample(trainloader,3)
gnn_dim_node = 0 if graph.x is None else pdata.x.shape[1]
gnn_dim_edge = 0 if graph.edge_attr is None else pdata.edge_attr.shape[1]
print(gnn_dim_node,gnn_dim_edge)
avg_time = 0
if use_cuda:
model = GeneralModel(gnn_dim_node, gnn_dim_edge, sample_param, memory_param, gnn_param, train_param).cuda()
device = torch.device('cuda')
else:
model = GeneralModel(gnn_dim_node, gnn_dim_edge, sample_param, memory_param, gnn_param, train_param)
device = torch.device('cpu')
model = DDP(model,find_unused_parameters=True)
train_stream = torch.cuda.Stream()
send_stream = torch.cuda.Stream()
scatter_stream = torch.cuda.Stream()
val_losses = list()
def eval(mode='val'):
neg_samples = 1
model.eval()
aps = list()
aucs_mrrs = list()
if mode == 'val':
loader = valloader
elif mode == 'test':
loader = testloader
elif mode == 'train':
loader = trainloader
with torch.no_grad():
total_loss = 0
signal = torch.tensor([0],dtype = int,device = device)
with torch.cuda.stream(train_stream):
for roots,mfgs,metadata in loader:
pred_pos, pred_neg = model(mfgs,metadata)
total_loss += creterion(pred_pos, torch.ones_like(pred_pos))
total_loss += creterion(pred_neg, torch.zeros_like(pred_neg))
y_pred = torch.cat([pred_pos, pred_neg], dim=0).sigmoid().cpu()
y_true = torch.cat([torch.ones(pred_pos.size(0)), torch.zeros(pred_neg.size(0))], dim=0)
aps.append(average_precision_score(y_true, y_pred.detach().numpy()))
aucs_mrrs.append(roc_auc_score(y_true, y_pred))
if mailbox is not None:
src = metadata['src_pos_index']
dst = metadata['dst_pos_index']
ts = roots.ts
if graph.edge_attr is None:
edge_feats = None
elif(graph.edge_attr.device == torch.device('cpu')):
edge_feats = graph.edge_attr[roots.eids.to('cpu')].to('cuda')
else:
edge_feats = graph.edge_attr[roots.eids]
dist_index_mapper = mfgs[0][0].srcdata['ID']
root_index = torch.cat((src,dst))
last_updated_nid = model.module.memory_updater.last_updated_nid[root_index]
last_updated_memory = model.module.memory_updater.last_updated_memory[root_index]
last_updated_ts=model.module.memory_updater.last_updated_ts[root_index]
index, memory, memory_ts = mailbox.get_update_memory(last_updated_nid,
last_updated_memory,
last_updated_ts)
#
index, mail, mail_ts = mailbox.get_update_mail(dist_index_mapper,
src,dst,ts,edge_feats,
model.module.memory_updater.last_updated_memory,
model.module.embedding,use_src_emb,
use_dst_emb,
)
mailbox.set_mailbox_all_to_all(index,memory,memory_ts,mail,mail_ts,reduce_Op = 'max')
#ap = float(torch.tensor(aps).mean())
#if neg_samples > 1:
# auc_mrr = float(torch.cat(aucs_mrrs).mean())
#else:
# auc_mrr = float(torch.tensor(aucs_mrrs).mean())
world_size = dist.get_world_size()
apc = torch.empty([loader.expected_idx*world_size],dtype = torch.float,device='cuda')
auc_mrr = torch.empty([loader.expected_idx*world_size],dtype = torch.float,device = 'cuda')
dist.all_gather_into_tensor(apc,torch.tensor(aps,device ='cuda',dtype=torch.float))
dist.all_gather_into_tensor(auc_mrr,torch.tensor(aucs_mrrs,device ='cuda',dtype=torch.float))
ap = float(torch.tensor(apc).mean())
auc_mrr = float(torch.tensor(auc_mrr).mean())
return ap, auc_mrr
creterion = torch.nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=train_param['lr'])
early_stopper = EarlyStopMonitor(max_round=args.patience)
MODEL_SAVE_PATH = f'./saved_models/{args.model}-{args.dataname}.pth'
for e in range(train_param['epoch']):
torch.cuda.synchronize()
write_back_time = 0
fetch_time = 0
epoch_start_time = time.time()
train_aps = list()
print('Epoch {:d}:'.format(e))
time_prep = 0
total_loss = 0
model.train()
if mailbox is not None:
mailbox.reset()
model.module.memory_updater.last_updated_nid = None
model.module.memory_updater.last_updated_memory = None
model.module.memory_updater.last_updated_ts = None
for roots,mfgs,metadata,sample_time in trainloader:
fetch_time +=sample_time/1000
t_prep_s = time.time()
with torch.cuda.stream(train_stream):
optimizer.zero_grad()
pred_pos, pred_neg = model(mfgs,metadata)
loss = creterion(pred_pos, torch.ones_like(pred_pos))
loss += creterion(pred_neg, torch.zeros_like(pred_neg))
total_loss += float(loss)
loss.backward()
optimizer.step()
#torch.cuda.synchronize()
t_prep_s = time.time()
y_pred = torch.cat([pred_pos, pred_neg], dim=0).sigmoid().cpu()
y_true = torch.cat([torch.ones(pred_pos.size(0)), torch.zeros(pred_neg.size(0))], dim=0)
train_aps.append(average_precision_score(y_true, y_pred.detach().numpy()))
#start_event = torch.cuda.Event(enable_timing=True)
#end_event = torch.cuda.Event(enable_timing=True)
#start_event.record()
if mailbox is not None:
src = metadata['src_pos_index']
dst = metadata['dst_pos_index']
ts = roots.ts
if graph.edge_attr is None:
edge_feats = None
elif(graph.edge_attr.device == torch.device('cpu')):
edge_feats = graph.edge_attr[roots.eids.to('cpu')].to('cuda')
else:
edge_feats = graph.edge_attr[roots.eids]
dist_index_mapper = mfgs[0][0].srcdata['ID']
root_index = torch.cat((src,dst))
last_updated_nid = model.module.memory_updater.last_updated_nid[root_index]
last_updated_memory = model.module.memory_updater.last_updated_memory[root_index]
last_updated_ts=model.module.memory_updater.last_updated_ts[root_index]
index, memory, memory_ts = mailbox.get_update_memory(last_updated_nid,
last_updated_memory,
last_updated_ts)
index, mail, mail_ts = mailbox.get_update_mail(dist_index_mapper,
src,dst,ts,edge_feats,
model.module.memory_updater.last_updated_memory,
model.module.embedding,use_src_emb,use_dst_emb,
)
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
start_event.record()
mailbox.set_mailbox_all_to_all(index,memory,memory_ts,mail,mail_ts,reduce_Op = 'max')
torch.cuda.synchronize()
time_prep = time.time() - epoch_start_time
avg_time += time.time() - epoch_start_time
train_ap = float(torch.tensor(train_aps).mean())
ap = 0
auc = 0
#if cache.edge_cache is not None:
# print('hit {}'.format(cache.edge_cache.hit_/ cache.edge_cache.hit_sum))
#if cache.node_cache is not None:
# print('hit {}'.format(cache.node_cache.hit_/ cache.node_cache.hit_sum))
ap, auc = eval('val')
early_stop = early_stopper.early_stop_check(ap)
if early_stop:
print("Early stopping at epoch {:d}".format(e))
print(f"Loading the best model at epoch {early_stopper.best_epoch}")
best_model_path = get_checkpoint_path(early_stopper.best_epoch)
model.load_state_dict(torch.load(best_model_path))
break
else:
print('\ttrain loss:{:.4f} train ap:{:4f} val ap:{:4f} val auc:{:4f}'.format(total_loss,train_ap, ap, auc))
print('\ttotal time:{:.2f}s prep time:{:.2f}s'.format(time.time()-epoch_start_time, time_prep))
print('\t fetch time:{:.2f}s write back time:{:.2f}s'.format(fetch_time,write_back_time))
torch.save(model.state_dict(), get_checkpoint_path(e))
model.eval()
if mailbox is not None:
mailbox.reset()
model.module.memory_updater.last_updated_nid = None
eval('train')
eval('val')
ap, auc = eval('test')
eval_neg_samples = 1
if eval_neg_samples > 1:
print('\ttest AP:{:4f} test MRR:{:4f}'.format(ap, auc))
else:
print('\ttest AP:{:4f} test AUC:{:4f}'.format(ap, auc))
print('test_dataset',test_data.edges.shape[1],'avg_time',avg_time/train_param['epoch'])
torch.save(model.state_dict(), MODEL_SAVE_PATH)
ctx.shutdown()
if __name__ == "__main__":
main()
import argparse
import os
import sys
from os.path import abspath, join, dirname
from starrygl.distributed.context import DistributedContext
from starrygl.distributed.utils import DistIndex
from starrygl.module.modules import GeneralModel
from pathlib import Path
from starrygl.module.utils import parse_config
from starrygl.sample.cache.fetch_cache import FetchFeatureCache
from starrygl.sample.graph_core import DataSet, DistributedGraphStore, TemporalNeighborSampleGraph
from starrygl.module.utils import parse_config, EarlyStopMonitor
from starrygl.sample.graph_core import DataSet, DistributedGraphStore, TemporalNeighborSampleGraph
from starrygl.sample.memory.shared_mailbox import SharedMailBox
from starrygl.sample.sample_core.base import NegativeSampling
from starrygl.sample.sample_core.neighbor_sampler import NeighborSampler
from starrygl.sample.part_utils.partition_tgnn import partition_load
import torch
import time
import torch
import torch.nn.functional as F
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import init_process_group, destroy_process_group
import os
from starrygl.sample.data_loader import DistributedDataLoader
from starrygl.sample.batch_data import SAMPLE_TYPE
from starrygl.sample.stream_manager import getPipelineManger
parser = argparse.ArgumentParser(
description="RPC Reinforcement Learning Example",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument('--rank', default=0, type=int, metavar='W',
help='name of dataset')
parser.add_argument('--patience', type=int, default=5, help='Patience for early stopping')
parser.add_argument('--world_size', default=1, type=int, metavar='W',
help='number of negative samples')
parser.add_argument('--dataname', default=1, type=str, metavar='W',
help='name of dataset')
parser.add_argument('--model', default='TGN', type=str, metavar='W',
help='name of model')
args = parser.parse_args()
from sklearn.metrics import average_precision_score, roc_auc_score
import torch
import time
import random
import dgl
import numpy as np
from sklearn.metrics import average_precision_score, roc_auc_score
from torch.nn.parallel import DistributedDataParallel as DDP
#os.environ['CUDA_VISIBLE_DEVICES'] = str(args.rank)
#os.environ["RANK"] = str(args.rank)
#os.environ["WORLD_SIZE"] = str(args.world_size)
#os.environ["LOCAL_RANK"] = str(0)
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
os.environ["MASTER_ADDR"] = '10.214.211.187'
os.environ["MASTER_PORT"] = '9337'
def seed_everything(seed=42):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
seed_everything(1234)
def main():
print('main')
use_cuda = True
sample_param, memory_param, gnn_param, train_param = parse_config('./config/{}.yml'.format(args.model))
torch.set_num_threads(12)
ctx = DistributedContext.init(backend="nccl", use_gpu=True)
device_id = torch.cuda.current_device()
print('use cuda on',device_id)
pdata = partition_load("/mnt/data/part_data/dataset/here/{}".format(args.dataname), algo="metis_for_tgnn")
graph = DistributedGraphStore(pdata = pdata,uvm_edge = False,uvm_node = False)
sample_graph = TemporalNeighborSampleGraph(sample_graph = pdata.sample_graph,mode = 'full')
mailbox = SharedMailBox(pdata.ids.shape[0], memory_param, dim_edge_feat = pdata.edge_attr.shape[1] if pdata.edge_attr is not None else 0)
sampler = NeighborSampler(num_nodes=graph.num_nodes, num_layers=1, fanout=[10],graph_data=sample_graph, workers=15,policy = 'recent',graph_name = "wiki_train")
train_data = torch.masked_select(graph.edge_index,pdata.train_mask.to(graph.edge_index.device)).reshape(2,-1)
train_ts = torch.masked_select(graph.edge_ts,pdata.train_mask.to(graph.edge_index.device))
val_data = torch.masked_select(graph.edge_index,pdata.val_mask.to(graph.edge_index.device)).reshape(2,-1)
val_ts = torch.masked_select(graph.edge_ts,pdata.val_mask.to(graph.edge_index.device))
test_data = torch.masked_select(graph.edge_index,pdata.test_mask.to(graph.edge_index.device)).reshape(2,-1)
test_ts = torch.masked_select(graph.edge_ts,pdata.test_mask.to(graph.edge_index.device))
#print(train_data.shape[1],val_data.shape[1],test_data.shape[1])
train_data = DataSet(edges = train_data,ts =train_ts,eids = torch.nonzero(pdata.train_mask).view(-1))
#if dist.get_rank() == 0:
test_data = DataSet(edges = test_data,ts =test_ts,eids = torch.nonzero(pdata.test_mask).view(-1))
val_data = DataSet(edges = val_data,ts = val_ts,eids = torch.nonzero(pdata.val_mask).view(-1))
#else:
#test_data = torch.tensor([[],[]],device = graph.edge_index.device,detype = graph.edge_index.#dtype)
#val_data = torch.tensor([[],[]],device = graph.edge_index.device,detype = graph.edge_index.dtype)
#test_ts = torch.tensor([[],[]],device = graph.ts.device,detype = graph.ts.dtype)
#val_ts = torch.tensor([[],[]],device = graph.ts.device,detype = graph.ts.dtype)
#test_data = DataSet(edges = test_data,ts =test_ts,eids = torch.tensor([],dtype = torch.long,#device = torch.cuda))
#val_data = DataSet(edges = val_data,ts = val_ts,eids = torch.tensor([],dtype = torch.long,device #= torch.cuda))
#train_neg_sampler = PreNegativeSampling('triplet',torch.masked_select(pdata.edge_index['pos_edge'],graph.data.train_mask).reshape(2,-1))
neg_sampler = NegativeSampling('triplet')
trainloader = DistributedDataLoader(graph,train_data,sampler = sampler,
sampler_fn = SAMPLE_TYPE.SAMPLE_FROM_TEMPORAL_EDGES,
neg_sampler=neg_sampler,
batch_size = train_param['batch_size'],
shuffle=False,
drop_last=True,
chunk_size = None,
train=True,
queue_size = 1000,
mailbox = mailbox,
)
testloader = DistributedDataLoader(graph,test_data,sampler = sampler,
sampler_fn = SAMPLE_TYPE.SAMPLE_FROM_TEMPORAL_EDGES,
neg_sampler=neg_sampler,
batch_size = train_param['batch_size'],
shuffle=False,
drop_last=False,
chunk_size = None,
train=False,
queue_size = 100,
mailbox = mailbox)
valloader = DistributedDataLoader(graph,val_data,sampler = sampler,
sampler_fn = SAMPLE_TYPE.SAMPLE_FROM_TEMPORAL_EDGES,
neg_sampler=neg_sampler,
batch_size = train_param['batch_size'],
shuffle=False,
drop_last=False,
chunk_size = None,
train=False,
queue_size = 100,
mailbox = mailbox)
#FetchFeatureCache.create_fetch_cache(graph.num_nodes,graph.eids_mapper.shape[0],0.1,0.1,graph,mailbox,policy = 'static')
#cache = FetchFeatureCache.getFetchCache()
#cache.init_cache_with_presample(trainloader,3)
gnn_dim_node = 0 if graph.x is None else pdata.x.shape[1]
gnn_dim_edge = 0 if graph.edge_attr is None else pdata.edge_attr.shape[1]
print(gnn_dim_node,gnn_dim_edge)
avg_time = 0
if use_cuda:
model = GeneralModel(gnn_dim_node, gnn_dim_edge, sample_param, memory_param, gnn_param, train_param).cuda()
device = torch.device('cuda')
else:
model = GeneralModel(gnn_dim_node, gnn_dim_edge, sample_param, memory_param, gnn_param, train_param)
device = torch.device('cpu')
model = DDP(model,find_unused_parameters=True)
train_stream = torch.cuda.Stream()
send_stream = torch.cuda.Stream()
scatter_stream = torch.cuda.Stream()
val_losses = list()
def eval(mode='val'):
neg_samples = 1
model.eval()
aps = list()
aucs_mrrs = list()
if mode == 'val':
loader = valloader
elif mode == 'test':
loader = testloader
elif mode == 'train':
loader = trainloader
with torch.no_grad():
total_loss = 0
signal = torch.tensor([0],dtype = int,device = device)
with torch.cuda.stream(train_stream):
for roots,mfgs,metadata in loader:
pred_pos, pred_neg = model(mfgs,metadata)
total_loss += creterion(pred_pos, torch.ones_like(pred_pos))
total_loss += creterion(pred_neg, torch.zeros_like(pred_neg))
y_pred = torch.cat([pred_pos, pred_neg], dim=0).sigmoid().cpu()
y_true = torch.cat([torch.ones(pred_pos.size(0)), torch.zeros(pred_neg.size(0))], dim=0)
aps.append(average_precision_score(y_true, y_pred.detach().numpy()))
aucs_mrrs.append(roc_auc_score(y_true, y_pred))
if mailbox is not None:
src = metadata['src_pos_index']
dst = metadata['dst_pos_index']
ts = roots.ts
if graph.edge_attr is None:
edge_feats = None
elif(graph.edge_attr.device == torch.device('cpu')):
edge_feats = graph.edge_attr[roots.eids.to('cpu')].to('cuda')
else:
edge_feats = graph.edge_attr[roots.eids]
dist_index_mapper = mfgs[0][0].srcdata['ID']
root_index = torch.cat((src,dst))
last_updated_nid = model.module.memory_updater.last_updated_nid[root_index]
last_updated_memory = model.module.memory_updater.last_updated_memory[root_index]
last_updated_ts=model.module.memory_updater.last_updated_ts[root_index]
index, memory, memory_ts = mailbox.get_update_memory(last_updated_nid,
last_updated_memory,
last_updated_ts)
#
index, mail, mail_ts = mailbox.get_update_mail(dist_index_mapper,
src,dst,ts,edge_feats,
model.module.memory_updater.last_updated_memory,
model.module.embedding,use_src_emb,
use_dst_emb,
)
mailbox.set_mailbox_all_to_all(index,memory,memory_ts,mail,mail_ts,reduce_Op = 'max')
#ap = float(torch.tensor(aps).mean())
#if neg_samples > 1:
# auc_mrr = float(torch.cat(aucs_mrrs).mean())
#else:
# auc_mrr = float(torch.tensor(aucs_mrrs).mean())
world_size = dist.get_world_size()
apc = torch.empty([loader.expected_idx*world_size],dtype = torch.float,device='cuda')
auc_mrr = torch.empty([loader.expected_idx*world_size],dtype = torch.float,device = 'cuda')
dist.all_gather_into_tensor(apc,torch.tensor(aps,device ='cuda',dtype=torch.float))
dist.all_gather_into_tensor(auc_mrr,torch.tensor(aucs_mrrs,device ='cuda',dtype=torch.float))
ap = float(torch.tensor(apc).mean())
auc_mrr = float(torch.tensor(auc_mrr).mean())
return ap, auc_mrr
creterion = torch.nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=train_param['lr'])
early_stopper = EarlyStopMonitor(max_round=args.patience)
MODEL_SAVE_PATH = f'./saved_models/{args.model}-{args.dataname}.pth'
for e in range(train_param['epoch']):
torch.cuda.synchronize()
write_back_time = 0
fetch_time = 0
epoch_start_time = time.time()
train_aps = list()
print('Epoch {:d}:'.format(e))
time_prep = 0
total_loss = 0
model.train()
if mailbox is not None:
mailbox.reset()
model.module.memory_updater.last_updated_nid = None
model.module.memory_updater.last_updated_memory = None
model.module.memory_updater.last_updated_ts = None
for roots,mfgs,metadata,sample_time in trainloader:
fetch_time +=sample_time/1000
t_prep_s = time.time()
with torch.cuda.stream(train_stream):
optimizer.zero_grad()
pred_pos, pred_neg = model(mfgs,metadata)
loss = creterion(pred_pos, torch.ones_like(pred_pos))
loss += creterion(pred_neg, torch.zeros_like(pred_neg))
total_loss += float(loss)
loss.backward()
optimizer.step()
#torch.cuda.synchronize()
t_prep_s = time.time()
y_pred = torch.cat([pred_pos, pred_neg], dim=0).sigmoid().cpu()
y_true = torch.cat([torch.ones(pred_pos.size(0)), torch.zeros(pred_neg.size(0))], dim=0)
train_aps.append(average_precision_score(y_true, y_pred.detach().numpy()))
#start_event = torch.cuda.Event(enable_timing=True)
#end_event = torch.cuda.Event(enable_timing=True)
#start_event.record()
if mailbox is not None:
src = metadata['src_pos_index']
dst = metadata['dst_pos_index']
ts = roots.ts
if graph.edge_attr is None:
edge_feats = None
elif(graph.edge_attr.device == torch.device('cpu')):
edge_feats = graph.edge_attr[roots.eids.to('cpu')].to('cuda')
else:
edge_feats = graph.edge_attr[roots.eids]
dist_index_mapper = mfgs[0][0].srcdata['ID']
root_index = torch.cat((src,dst))
last_updated_nid = model.module.memory_updater.last_updated_nid[root_index]
last_updated_memory = model.module.memory_updater.last_updated_memory[root_index]
last_updated_ts=model.module.memory_updater.last_updated_ts[root_index]
index, memory, memory_ts = mailbox.get_update_memory(last_updated_nid,
last_updated_memory,
last_updated_ts)
index, mail, mail_ts = mailbox.get_update_mail(dist_index_mapper,
src,dst,ts,edge_feats,
model.module.memory_updater.last_updated_memory,
model.module.embedding,use_src_emb,use_dst_emb,
)
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
start_event.record()
mailbox.set_mailbox_all_to_all(index,memory,memory_ts,mail,mail_ts,reduce_Op = 'max')
torch.cuda.synchronize()
time_prep = time.time() - epoch_start_time
avg_time += time.time() - epoch_start_time
train_ap = float(torch.tensor(train_aps).mean())
ap = 0
auc = 0
#if cache.edge_cache is not None:
# print('hit {}'.format(cache.edge_cache.hit_/ cache.edge_cache.hit_sum))
#if cache.node_cache is not None:
# print('hit {}'.format(cache.node_cache.hit_/ cache.node_cache.hit_sum))
ap, auc = eval('val')
early_stop = early_stopper.early_stop_check(ap)
if early_stop:
print("Early stopping at epoch {:d}".format(e))
print(f"Loading the best model at epoch {early_stopper.best_epoch}")
best_model_path = get_checkpoint_path(early_stopper.best_epoch)
model.load_state_dict(torch.load(best_model_path))
break
else:
print('\ttrain loss:{:.4f} train ap:{:4f} val ap:{:4f} val auc:{:4f}'.format(total_loss,train_ap, ap, auc))
print('\ttotal time:{:.2f}s prep time:{:.2f}s'.format(time.time()-epoch_start_time, time_prep))
print('\t fetch time:{:.2f}s write back time:{:.2f}s'.format(fetch_time,write_back_time))
torch.save(model.state_dict(), get_checkpoint_path(e))
model.eval()
if mailbox is not None:
mailbox.reset()
model.module.memory_updater.last_updated_nid = None
eval('train')
eval('val')
ap, auc = eval('test')
eval_neg_samples = 1
if eval_neg_samples > 1:
print('\ttest AP:{:4f} test MRR:{:4f}'.format(ap, auc))
else:
print('\ttest AP:{:4f} test AUC:{:4f}'.format(ap, auc))
print('test_dataset',test_data.edges.shape[1],'avg_time',avg_time/train_param['epoch'])
torch.save(model.state_dict(), MODEL_SAVE_PATH)
ctx.shutdown()
if __name__ == "__main__":
main()
...@@ -110,19 +110,6 @@ if(WITH_MTMETIS) ...@@ -110,19 +110,6 @@ if(WITH_MTMETIS)
target_compile_definitions(mtmetis_partition PRIVATE -DMTMETIS_64BIT_PARTITIONS) target_compile_definitions(mtmetis_partition PRIVATE -DMTMETIS_64BIT_PARTITIONS)
endif() endif()
if (WITH_LDG)
# Imports neighbor-clustering based (e.g. LDG algorithm) graph partitioning implementation
add_definitions(-DWITH_LDG)
set(LDG_DIR "third_party/ldg_partition")
add_library(ldg_partition SHARED "csrc/partition/ldg.cpp")
target_link_libraries(ldg_partition PRIVATE ${TORCH_LIBRARIES})
add_subdirectory(${LDG_DIR})
target_include_directories(ldg_partition PRIVATE ${LDG_DIR})
target_link_libraries(ldg_partition PRIVATE ldg-vertex-partition)
endif ()
include_directories("csrc/include") include_directories("csrc/include")
add_library(${PROJECT_NAME} SHARED csrc/export.cpp) add_library(${PROJECT_NAME} SHARED csrc/export.cpp)
......
cmake_minimum_required(VERSION 3.15)
project(starrygl VERSION 0.1)
option(WITH_PYTHON "Link to Python when building" ON)
option(WITH_CUDA "Link to CUDA when building" ON)
option(WITH_METIS "Link to METIS when building" ON)
option(WITH_MTMETIS "Link to multi-threaded METIS when building" ON)
option(WITH_LDG "Link to (multi-threaded optionally) LDG when building" ON)
message("third_party dir is ${CMAKE_SOURCE_DIR}")
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
set(CMAKE_CUDA_STANDARD 14)
set(CMAKE_CUDA_STANDARD_REQUIRED ON)
find_package(OpenMP REQUIRED)
link_libraries(OpenMP::OpenMP_CXX)
find_package(Torch REQUIRED)
include_directories(${TORCH_INCLUDE_DIRS})
add_compile_options(${TORCH_CXX_FLAGS})
if(WITH_PYTHON)
add_definitions(-DWITH_PYTHON)
find_package(Python3 COMPONENTS Interpreter Development REQUIRED)
include_directories(${Python3_INCLUDE_DIRS})
endif()
if(WITH_CUDA)
add_definitions(-DWITH_CUDA)
add_definitions(-DWITH_UVM)
find_package(CUDA REQUIRED)
include_directories(${CUDA_INCLUDE_DIRS})
set(CUDA_LIBRARIES "${CUDA_TOOLKIT_ROOT_DIR}/lib64/libcudart.so")
file(GLOB_RECURSE UVM_SRCS "csrc/uvm/*.cpp")
add_library(uvm_ops SHARED ${UVM_SRCS})
target_link_libraries(uvm_ops PRIVATE ${TORCH_LIBRARIES})
endif()
if(WITH_METIS)
add_definitions(-DWITH_METIS)
set(GKLIB_DIR "${CMAKE_SOURCE_DIR}/third_party/GKlib")
set(METIS_DIR "${CMAKE_SOURCE_DIR}/third_party/METIS")
set(GKLIB_INCLUDE_DIRS "${GKLIB_DIR}/include")
file(GLOB_RECURSE GKLIB_LIBRARIES "${GKLIB_DIR}/lib/lib*.a")
set(METIS_INCLUDE_DIRS "${METIS_DIR}/include")
file(GLOB_RECURSE METIS_LIBRARIES "${METIS_DIR}/lib/lib*.a")
include_directories(${METIS_INCLUDE_DIRS})
add_library(metis_partition SHARED "csrc/partition/metis.cpp")
target_link_libraries(metis_partition PRIVATE ${TORCH_LIBRARIES})
target_link_libraries(metis_partition PRIVATE ${GKLIB_LIBRARIES})
target_link_libraries(metis_partition PRIVATE ${METIS_LIBRARIES})
endif()
if(WITH_MTMETIS)
add_definitions(-DWITH_MTMETIS)
set(MTMETIS_DIR "${CMAKE_SOURCE_DIR}/third_party/mt-metis")
set(MTMETIS_INCLUDE_DIRS "${MTMETIS_DIR}/include")
file(GLOB_RECURSE MTMETIS_LIBRARIES "${MTMETIS_DIR}/lib/lib*.a")
include_directories(${MTMETIS_INCLUDE_DIRS})
add_library(mtmetis_partition SHARED "csrc/partition/mtmetis.cpp")
target_link_libraries(mtmetis_partition PRIVATE ${TORCH_LIBRARIES})
target_link_libraries(mtmetis_partition PRIVATE ${MTMETIS_LIBRARIES})
target_compile_definitions(mtmetis_partition PRIVATE -DMTMETIS_64BIT_VERTICES)
target_compile_definitions(mtmetis_partition PRIVATE -DMTMETIS_64BIT_EDGES)
target_compile_definitions(mtmetis_partition PRIVATE -DMTMETIS_64BIT_WEIGHTS)
target_compile_definitions(mtmetis_partition PRIVATE -DMTMETIS_64BIT_PARTITIONS)
endif()
if (WITH_LDG)
# Imports neighbor-clustering based (e.g. LDG algorithm) graph partitioning implementation
add_definitions(-DWITH_LDG)
# set(LDG_DIR "csrc/partition/neighbor_clustering")
set(LDG_DIR "third_party/ldg_partition")
add_library(ldg_partition SHARED "csrc/partition/ldg.cpp")
target_link_libraries(ldg_partition PRIVATE ${TORCH_LIBRARIES})
# add_subdirectory(${LDG_DIR})
target_include_directories(ldg_partition PRIVATE ${LDG_DIR})
target_link_libraries(ldg_partition PRIVATE ldg-vertex-partition)
endif ()
include_directories("csrc/include")
add_library(${PROJECT_NAME} SHARED csrc/export.cpp)
target_link_libraries(${PROJECT_NAME} PRIVATE ${TORCH_LIBRARIES})
target_compile_definitions(${PROJECT_NAME} PRIVATE -DTORCH_EXTENSION_NAME=lib${PROJECT_NAME})
if(WITH_PYTHON)
find_library(TORCH_PYTHON_LIBRARY torch_python PATHS "${TORCH_INSTALL_PREFIX}/lib")
target_link_libraries(${PROJECT_NAME} PRIVATE ${TORCH_PYTHON_LIBRARY})
endif()
if (WITH_CUDA)
target_link_libraries(${PROJECT_NAME} PRIVATE uvm_ops)
endif()
if (WITH_METIS)
message(STATUS "Current project '${PROJECT_NAME}' uses METIS graph partitioning algorithm.")
target_link_libraries(${PROJECT_NAME} PRIVATE metis_partition)
endif()
if (WITH_MTMETIS)
message(STATUS "Current project '${PROJECT_NAME}' uses multi-threaded METIS graph partitioning algorithm.")
target_link_libraries(${PROJECT_NAME} PRIVATE mtmetis_partition)
endif()
if (WITH_LDG)
message(STATUS "Current project '${PROJECT_NAME}' uses LDG graph partitioning algorithm.")
target_link_libraries(${PROJECT_NAME} PRIVATE ldg_partition)
endif()
# add libsampler.so
set(SAMLPER_NAME "${PROJECT_NAME}_sampler")
set(BOOST_INCLUDE_DIRS "${CMAKE_SOURCE_DIR}/third_party/boost_1_83_0")
include_directories(${BOOST_INCLUDE_DIRS})
file(GLOB_RECURSE SAMPLER_SRCS "csrc/sampler/*.cpp")
add_library(${SAMLPER_NAME} SHARED ${SAMPLER_SRCS})
target_include_directories(${SAMLPER_NAME} PRIVATE "csrc/sampler/include")
target_compile_options(${SAMLPER_NAME} PRIVATE -O3)
target_link_libraries(${SAMLPER_NAME} PRIVATE ${TORCH_LIBRARIES})
target_compile_definitions(${SAMLPER_NAME} PRIVATE -DTORCH_EXTENSION_NAME=lib${SAMLPER_NAME})
if(WITH_PYTHON)
find_library(TORCH_PYTHON_LIBRARY torch_python PATHS "${TORCH_INSTALL_PREFIX}/lib")
target_link_libraries(${SAMLPER_NAME} PRIVATE ${TORCH_PYTHON_LIBRARY})
endif()
sampling:
- layer: 1
neighbor:
- 10
strategy: 'recent'
prop_time: False
history: 1
duration: 0
num_thread: 32
memory:
- type: 'node'
dim_time: 100
deliver_to: 'self'
mail_combine: 'last'
memory_update: 'gru'
mailbox_size: 1
combine_node_feature: True
dim_out: 100
gnn:
- arch: 'transformer_attention'
use_src_emb: True
use_dst_emb: True
layer: 1
att_head: 2
dim_time: 100
dim_out: 100
train:
- epoch: 50
batch_size: 100
# reorder: 16
lr: 0.0001
dropout: 0.1
att_dropout: 0.2
all_on_gpu: True
\ No newline at end of file
...@@ -18,13 +18,15 @@ memory: ...@@ -18,13 +18,15 @@ memory:
dim_out: 100 dim_out: 100
gnn: gnn:
- arch: 'transformer_attention' - arch: 'transformer_attention'
use_src_emb: False
use_dst_emb: False
layer: 1 layer: 1
att_head: 2 att_head: 2
dim_time: 100 dim_time: 100
dim_out: 100 dim_out: 100
train: train:
- epoch: 5 - epoch: 20
#batch_size: 100 batch_size: 200
# reorder: 16 # reorder: 16
lr: 0.0001 lr: 0.0001
dropout: 0.2 dropout: 0.2
......
sampling:
- layer: 1
neighbor:
- 10
strategy: 'recent'
prop_time: False
history: 1
duration: 0
num_thread: 32
memory:
- type: 'node'
dim_time: 100
deliver_to: 'self'
mail_combine: 'last'
memory_update: 'gru'
mailbox_size: 1
combine_node_feature: True
dim_out: 100
gnn:
- arch: 'transformer_attention'
use_src_emb: True
use_dst_emb: True
layer: 1
att_head: 2
dim_time: 100
dim_out: 100
train:
- epoch: 20
batch_size: 200
# reorder: 16
lr: 0.0001
dropout: 0.2
att_dropout: 0.2
all_on_gpu: True
\ No newline at end of file
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
#ifdef WITH_CUDA #ifdef WITH_CUDA
#ifdef WITH_CUDA
m.def("uvm_storage_new", &uvm_storage_new, "return storage of unified virtual memory"); m.def("uvm_storage_new", &uvm_storage_new, "return storage of unified virtual memory");
m.def("uvm_storage_to_cuda", &uvm_storage_to_cuda, "share uvm storage with another cuda device"); m.def("uvm_storage_to_cuda", &uvm_storage_to_cuda, "share uvm storage with another cuda device");
m.def("uvm_storage_to_cpu", &uvm_storage_to_cpu, "share uvm storage with cpu"); m.def("uvm_storage_to_cpu", &uvm_storage_to_cpu, "share uvm storage with cpu");
......
...@@ -114,11 +114,15 @@ edge_weight_dict = {} ...@@ -114,11 +114,15 @@ edge_weight_dict = {}
edge_weight_dict['edata'] = 2*neg_nums edge_weight_dict['edata'] = 2*neg_nums
edge_weight_dict['sample_data'] = 1*neg_nums edge_weight_dict['sample_data'] = 1*neg_nums
edge_weight_dict['neg_data'] = 1 edge_weight_dict['neg_data'] = 1
partition_save('./dataset/here/'+data_name, data, 1, 'metis_for_tgnn', #partition_save('./dataset/here/'+data_name, data, 1, 'metis_for_tgnn',
edge_weight_dict=edge_weight_dict) # edge_weight_dict=edge_weight_dict)
partition_save('./dataset/here/'+data_name, data, 2, 'metis_for_tgnn', #partition_save('./dataset/here/'+data_name, data, 2, 'metis_for_tgnn',
edge_weight_dict=edge_weight_dict) # edge_weight_dict=edge_weight_dict)
partition_save('./dataset/here/'+data_name, data, 4, 'metis_for_tgnn', #partition_save('./dataset/here/'+data_name, data, 4, 'metis_for_tgnn',
# edge_weight_dict=edge_weight_dict)
#partition_save('./dataset/here/'+data_name, data, 8, 'metis_for_tgnn',
# edge_weight_dict=edge_weight_dict)
partition_save('./dataset/here/'+data_name, data, 16, 'metis_for_tgnn',
edge_weight_dict=edge_weight_dict) edge_weight_dict=edge_weight_dict)
# #
# partition_save('./dataset/here/'+data_name, data, 4, 'metis_for_tgnn', # partition_save('./dataset/here/'+data_name, data, 4, 'metis_for_tgnn',
......
Advanced Data Preprocessing
===========================
.. note::
详细介绍一下StarryGL几种数据管理类,例如GraphData,的使用细节,内部索引结构的设计和底层操作。
\ No newline at end of file
Distributed Partition Parallel
==============================
.. note::
分布式分区并行训练部分
\ No newline at end of file
Distributed Timeline Parallel
=============================
.. note::
分布式时序并行
\ No newline at end of file
Distributed Temporal Sampling
=============================
.. note::
基于分布式时序图采样的训练模式
\ No newline at end of file
...@@ -294,7 +294,7 @@ class DistributedTensor: ...@@ -294,7 +294,7 @@ class DistributedTensor:
index = dist_index.loc index = dist_index.loc
futs: List[torch.futures.Future] = [] futs: List[torch.futures.Future] = []
for i in range(self.num_parts()): for i in range(self.num_parts):
mask = part_idx == i mask = part_idx == i
f = self.accessor.async_index_copy_(0, index[mask], source[mask], self.rrefs[i]) f = self.accessor.async_index_copy_(0, index[mask], source[mask], self.rrefs[i])
futs.append(f) futs.append(f)
...@@ -308,7 +308,7 @@ class DistributedTensor: ...@@ -308,7 +308,7 @@ class DistributedTensor:
index = dist_index.loc index = dist_index.loc
futs: List[torch.futures.Future] = [] futs: List[torch.futures.Future] = []
for i in range(self.num_parts()): for i in range(self.num_parts):
mask = part_idx == i mask = part_idx == i
f = self.accessor.async_index_add_(0, index[mask], source[mask], self.rrefs[i]) f = self.accessor.async_index_add_(0, index[mask], source[mask], self.rrefs[i])
futs.append(f) futs.append(f)
......
...@@ -68,8 +68,14 @@ class GeneralModel(torch.nn.Module): ...@@ -68,8 +68,14 @@ class GeneralModel(torch.nn.Module):
out = torch.stack(out, dim=0) out = torch.stack(out, dim=0)
out = self.combiner(out)[0][-1, :, :] out = self.combiner(out)[0][-1, :, :]
#metadata需要在前面去重的时候记一下id #metadata需要在前面去重的时候记一下id
if self.gnn_param['use_src_emb'] or self.gnn_param['use_dst_emb']:
self.embedding = out.detach().clone()
else:
self.embedding = None
if metadata is not None: if metadata is not None:
#out = torch.cat((out[metadata['dst_pos_pos']],out[metadata['src_id_pos']],out[metadata['dst_neg_pos']]),0) #out = torch.cat((out[metadata['dst_pos_pos']],out[metadata['src_id_pos']],out[metadata['dst_neg_pos']]),0)
if self.gnn_param['dyrep']:
out = self.memory_updater.last_updated_memory
out = torch.cat((out[metadata['src_pos_index']],out[metadata['dst_pos_index']],out[metadata['src_neg_index']]),0) out = torch.cat((out[metadata['src_pos_index']],out[metadata['dst_pos_index']],out[metadata['src_neg_index']]),0)
return self.edge_predictor(out, neg_samples=neg_samples) return self.edge_predictor(out, neg_samples=neg_samples)
......
import yaml import yaml
import numpy as np
def parse_config(f): def parse_config(f):
conf = yaml.safe_load(open(f, 'r')) conf = yaml.safe_load(open(f, 'r'))
...@@ -8,3 +8,31 @@ def parse_config(f): ...@@ -8,3 +8,31 @@ def parse_config(f):
gnn_param = conf['gnn'][0] gnn_param = conf['gnn'][0]
train_param = conf['train'][0] train_param = conf['train'][0]
return sample_param, memory_param, gnn_param, train_param return sample_param, memory_param, gnn_param, train_param
class EarlyStopMonitor(object):
def __init__(self, max_round=3, higher_better=True, tolerance=1e-10):
self.max_round = max_round
self.num_round = 0
self.epoch_count = 0
self.best_epoch = 0
self.last_best = None
self.higher_better = higher_better
self.tolerance = tolerance
def early_stop_check(self, curr_val):
if not self.higher_better:
curr_val *= -1
if self.last_best is None:
self.last_best = curr_val
elif (curr_val - self.last_best) / np.abs(self.last_best) > self.tolerance:
self.last_best = curr_val
self.num_round = 0
self.best_epoch = self.epoch_count
else:
self.num_round += 1
self.epoch_count += 1
return self.num_round >= self.max_round
\ No newline at end of file
...@@ -166,6 +166,7 @@ class DataSet: ...@@ -166,6 +166,7 @@ class DataSet:
if labels is not None: if labels is not None:
self.labels = labels self.labels = labels
self.len = self.nodes.shape[0] if nodes is not None else self.edges.shape[1] self.len = self.nodes.shape[0] if nodes is not None else self.edges.shape[1]
for k, v in kwargs.items(): for k, v in kwargs.items():
assert isinstance(v,torch.Tensor) and v.shape[0]==self.len assert isinstance(v,torch.Tensor) and v.shape[0]==self.len
setattr(self, k, v.to(device)) setattr(self, k, v.to(device))
...@@ -222,7 +223,7 @@ class DataSet: ...@@ -222,7 +223,7 @@ class DataSet:
class TemporalGraphData(DistributedGraphStore): class TemporalGraphData(DistributedGraphStore):
def __init__(self,pdata,device): def __init__(self,pdata,device):
super(TemporalGraphData,self).__init__(pdata,device) super(DistributedGraphStore,self).__init__(pdata,device)
def _set_temporal_batch_cache(self,size,pin_size): def _set_temporal_batch_cache(self,size,pin_size):
pass pass
def _load_feature_to_cuda(self,ids): def _load_feature_to_cuda(self,ids):
......
...@@ -299,7 +299,7 @@ class SharedMailBox(): ...@@ -299,7 +299,7 @@ class SharedMailBox():
def get_update_mail(self,dist_indx_mapper, def get_update_mail(self,dist_indx_mapper,
src,dst,ts,edge_feats, src,dst,ts,edge_feats,
memory): memory,embedding=None,use_src_emb=False,use_dst_emb=False):
if edge_feats is not None: if edge_feats is not None:
edge_feats = edge_feats.to(self.device).to(self.mailbox.dtype) edge_feats = edge_feats.to(self.device).to(self.mailbox.dtype)
src = src.to(self.device) src = src.to(self.device)
...@@ -309,12 +309,14 @@ class SharedMailBox(): ...@@ -309,12 +309,14 @@ class SharedMailBox():
mem_src = memory[src] mem_src = memory[src]
mem_dst = memory[dst] mem_dst = memory[dst]
if embedding is not None:
emb_src = embedding[src]
emb_dst = embedding[dst]
src_mail = torch.cat([emb_src if use_src_emb else mem_src, emb_dst if use_dst_emb else mem_dst], dim=1)
dst_mail = torch.cat([emb_dst if use_src_emb else mem_dst, emb_src if use_dst_emb else mem_src], dim=1)
if edge_feats is not None: if edge_feats is not None:
src_mail = torch.cat([mem_src, mem_dst, edge_feats], dim=1) src_mail = torch.cat([src_mail, edge_feats], dim=1)
dst_mail = torch.cat([mem_dst, mem_src, edge_feats], dim=1) dst_mail = torch.cat([dst_mail, edge_feats], dim=1)
else:
src_mail = torch.cat([mem_src, mem_dst], dim=1)
dst_mail = torch.cat([mem_dst, mem_src], dim=1)
mail = torch.cat([src_mail, dst_mail], dim=1).reshape(-1, src_mail.shape[1]) mail = torch.cat([src_mail, dst_mail], dim=1).reshape(-1, src_mail.shape[1])
mail_ts = torch.cat((ts,ts),-1).to(self.device).to(self.mailbox_ts.dtype) mail_ts = torch.cat((ts,ts),-1).to(self.device).to(self.mailbox_ts.dtype)
unq_index,inv = torch.unique(index,return_inverse = True) unq_index,inv = torch.unique(index,return_inverse = True)
...@@ -324,7 +326,6 @@ class SharedMailBox(): ...@@ -324,7 +326,6 @@ class SharedMailBox():
index = unq_index index = unq_index
return index,mail,mail_ts return index,mail,mail_ts
def get_update_memory(self,index,memory,memory_ts): def get_update_memory(self,index,memory,memory_ts):
unq_index,inv = torch.unique(index,return_inverse = True) unq_index,inv = torch.unique(index,return_inverse = True)
max_ts,idx = torch_scatter.scatter_max(memory_ts,inv,0) max_ts,idx = torch_scatter.scatter_max(memory_ts,inv,0)
......
from torch_sparse import SparseTensor from torch_sparse import SparseTensor
from torch_geometric.data import Data from torch_geometric.data import Data
from torch_geometric.utils import degree from torch_geometric.utils import degree
import os.path as osp import os.path as osp
import os import os
import shutil import shutil
......
...@@ -5,10 +5,14 @@ from os.path import abspath, join, dirname ...@@ -5,10 +5,14 @@ from os.path import abspath, join, dirname
from starrygl.distributed.context import DistributedContext from starrygl.distributed.context import DistributedContext
from starrygl.distributed.utils import DistIndex from starrygl.distributed.utils import DistIndex
from starrygl.module.modules import GeneralModel from starrygl.module.modules import GeneralModel
from pathlib import Path
from starrygl.module.utils import parse_config from starrygl.module.utils import parse_config
from starrygl.sample.cache.fetch_cache import FetchFeatureCache from starrygl.sample.cache.fetch_cache import FetchFeatureCache
from starrygl.sample.graph_core import DataSet, DistributedGraphStore, TemporalNeighborSampleGraph from starrygl.sample.graph_core import DataSet, DistributedGraphStore, TemporalNeighborSampleGraph
from starrygl.module.utils import parse_config, EarlyStopMonitor
from starrygl.sample.graph_core import DataSet, DistributedGraphStore, TemporalNeighborSampleGraph
from starrygl.sample.memory.shared_mailbox import SharedMailBox from starrygl.sample.memory.shared_mailbox import SharedMailBox
from starrygl.sample.sample_core.base import NegativeSampling from starrygl.sample.sample_core.base import NegativeSampling
from starrygl.sample.sample_core.neighbor_sampler import NeighborSampler from starrygl.sample.sample_core.neighbor_sampler import NeighborSampler
...@@ -34,10 +38,13 @@ parser = argparse.ArgumentParser( ...@@ -34,10 +38,13 @@ parser = argparse.ArgumentParser(
) )
parser.add_argument('--rank', default=0, type=int, metavar='W', parser.add_argument('--rank', default=0, type=int, metavar='W',
help='name of dataset') help='name of dataset')
parser.add_argument('--patience', type=int, default=5, help='Patience for early stopping')
parser.add_argument('--world_size', default=1, type=int, metavar='W', parser.add_argument('--world_size', default=1, type=int, metavar='W',
help='number of negative samples') help='number of negative samples')
parser.add_argument('--dataname', default=1, type=str, metavar='W', parser.add_argument('--dataname', default=1, type=str, metavar='W',
help='number of negative samples') help='name of dataset')
parser.add_argument('--model', default='TGN', type=str, metavar='W',
help='name of model')
args = parser.parse_args() args = parser.parse_args()
from sklearn.metrics import average_precision_score, roc_auc_score from sklearn.metrics import average_precision_score, roc_auc_score
import torch import torch
...@@ -66,7 +73,7 @@ seed_everything(1234) ...@@ -66,7 +73,7 @@ seed_everything(1234)
def main(): def main():
print('main') print('main')
use_cuda = True use_cuda = True
sample_param, memory_param, gnn_param, train_param = parse_config('./config/TGN.yml') sample_param, memory_param, gnn_param, train_param = parse_config('./config/{}.yml'.format(args.model))
torch.set_num_threads(12) torch.set_num_threads(12)
ctx = DistributedContext.init(backend="nccl", use_gpu=True) ctx = DistributedContext.init(backend="nccl", use_gpu=True)
device_id = torch.cuda.current_device() device_id = torch.cuda.current_device()
...@@ -83,7 +90,7 @@ def main(): ...@@ -83,7 +90,7 @@ def main():
val_ts = torch.masked_select(graph.edge_ts,pdata.val_mask.to(graph.edge_index.device)) val_ts = torch.masked_select(graph.edge_ts,pdata.val_mask.to(graph.edge_index.device))
test_data = torch.masked_select(graph.edge_index,pdata.test_mask.to(graph.edge_index.device)).reshape(2,-1) test_data = torch.masked_select(graph.edge_index,pdata.test_mask.to(graph.edge_index.device)).reshape(2,-1)
test_ts = torch.masked_select(graph.edge_ts,pdata.test_mask.to(graph.edge_index.device)) test_ts = torch.masked_select(graph.edge_ts,pdata.test_mask.to(graph.edge_index.device))
print(train_data.shape[1],val_data.shape[1],test_data.shape[1]) #print(train_data.shape[1],val_data.shape[1],test_data.shape[1])
train_data = DataSet(edges = train_data,ts =train_ts,eids = torch.nonzero(pdata.train_mask).view(-1)) train_data = DataSet(edges = train_data,ts =train_ts,eids = torch.nonzero(pdata.train_mask).view(-1))
#if dist.get_rank() == 0: #if dist.get_rank() == 0:
test_data = DataSet(edges = test_data,ts =test_ts,eids = torch.nonzero(pdata.test_mask).view(-1)) test_data = DataSet(edges = test_data,ts =test_ts,eids = torch.nonzero(pdata.test_mask).view(-1))
...@@ -100,7 +107,7 @@ def main(): ...@@ -100,7 +107,7 @@ def main():
trainloader = DistributedDataLoader(graph,train_data,sampler = sampler, trainloader = DistributedDataLoader(graph,train_data,sampler = sampler,
sampler_fn = SAMPLE_TYPE.SAMPLE_FROM_TEMPORAL_EDGES, sampler_fn = SAMPLE_TYPE.SAMPLE_FROM_TEMPORAL_EDGES,
neg_sampler=neg_sampler, neg_sampler=neg_sampler,
batch_size = 1000, batch_size = train_param['batch_size'],
shuffle=False, shuffle=False,
drop_last=True, drop_last=True,
chunk_size = None, chunk_size = None,
...@@ -111,7 +118,7 @@ def main(): ...@@ -111,7 +118,7 @@ def main():
testloader = DistributedDataLoader(graph,test_data,sampler = sampler, testloader = DistributedDataLoader(graph,test_data,sampler = sampler,
sampler_fn = SAMPLE_TYPE.SAMPLE_FROM_TEMPORAL_EDGES, sampler_fn = SAMPLE_TYPE.SAMPLE_FROM_TEMPORAL_EDGES,
neg_sampler=neg_sampler, neg_sampler=neg_sampler,
batch_size = 1000, batch_size = train_param['batch_size'],
shuffle=False, shuffle=False,
drop_last=False, drop_last=False,
chunk_size = None, chunk_size = None,
...@@ -121,7 +128,7 @@ def main(): ...@@ -121,7 +128,7 @@ def main():
valloader = DistributedDataLoader(graph,val_data,sampler = sampler, valloader = DistributedDataLoader(graph,val_data,sampler = sampler,
sampler_fn = SAMPLE_TYPE.SAMPLE_FROM_TEMPORAL_EDGES, sampler_fn = SAMPLE_TYPE.SAMPLE_FROM_TEMPORAL_EDGES,
neg_sampler=neg_sampler, neg_sampler=neg_sampler,
batch_size = 1000, batch_size = train_param['batch_size'],
shuffle=False, shuffle=False,
drop_last=False, drop_last=False,
chunk_size = None, chunk_size = None,
...@@ -194,6 +201,8 @@ def main(): ...@@ -194,6 +201,8 @@ def main():
index, mail, mail_ts = mailbox.get_update_mail(dist_index_mapper, index, mail, mail_ts = mailbox.get_update_mail(dist_index_mapper,
src,dst,ts,edge_feats, src,dst,ts,edge_feats,
model.module.memory_updater.last_updated_memory, model.module.memory_updater.last_updated_memory,
model.module.embedding,use_src_emb,
use_dst_emb,
) )
mailbox.set_mailbox_all_to_all(index,memory,memory_ts,mail,mail_ts,reduce_Op = 'max') mailbox.set_mailbox_all_to_all(index,memory,memory_ts,mail,mail_ts,reduce_Op = 'max')
...@@ -212,10 +221,13 @@ def main(): ...@@ -212,10 +221,13 @@ def main():
auc_mrr = float(torch.tensor(auc_mrr).mean()) auc_mrr = float(torch.tensor(auc_mrr).mean())
return ap, auc_mrr return ap, auc_mrr
creterion = torch.nn.BCEWithLogitsLoss() creterion = torch.nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=train_param['lr']) optimizer = torch.optim.Adam(model.parameters(), lr=train_param['lr'])
early_stopper = EarlyStopMonitor(max_round=args.patience)
MODEL_SAVE_PATH = f'./saved_models/{args.model}-{args.dataname}.pth'
for e in range(train_param['epoch']): for e in range(train_param['epoch']):
torch.cuda.synchronize() torch.cuda.synchronize()
write_back_time = 0
fetch_time = 0
epoch_start_time = time.time() epoch_start_time = time.time()
train_aps = list() train_aps = list()
print('Epoch {:d}:'.format(e)) print('Epoch {:d}:'.format(e))
...@@ -227,7 +239,8 @@ def main(): ...@@ -227,7 +239,8 @@ def main():
model.module.memory_updater.last_updated_nid = None model.module.memory_updater.last_updated_nid = None
model.module.memory_updater.last_updated_memory = None model.module.memory_updater.last_updated_memory = None
model.module.memory_updater.last_updated_ts = None model.module.memory_updater.last_updated_ts = None
for roots,mfgs,metadata in trainloader: for roots,mfgs,metadata,sample_time in trainloader:
fetch_time +=sample_time/1000
t_prep_s = time.time() t_prep_s = time.time()
with torch.cuda.stream(train_stream): with torch.cuda.stream(train_stream):
...@@ -270,13 +283,13 @@ def main(): ...@@ -270,13 +283,13 @@ def main():
index, mail, mail_ts = mailbox.get_update_mail(dist_index_mapper, index, mail, mail_ts = mailbox.get_update_mail(dist_index_mapper,
src,dst,ts,edge_feats, src,dst,ts,edge_feats,
model.module.memory_updater.last_updated_memory, model.module.memory_updater.last_updated_memory,
model.module.embedding,use_src_emb,use_dst_emb,
) )
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
start_event.record()
mailbox.set_mailbox_all_to_all(index,memory,memory_ts,mail,mail_ts,reduce_Op = 'max') mailbox.set_mailbox_all_to_all(index,memory,memory_ts,mail,mail_ts,reduce_Op = 'max')
#end_event.record()
#torch.cuda.synchronize()
#write_back_time += start_event.elapsed_time(end_event)/1000
torch.cuda.synchronize() torch.cuda.synchronize()
time_prep = time.time() - epoch_start_time time_prep = time.time() - epoch_start_time
avg_time += time.time() - epoch_start_time avg_time += time.time() - epoch_start_time
...@@ -288,9 +301,19 @@ def main(): ...@@ -288,9 +301,19 @@ def main():
#if cache.node_cache is not None: #if cache.node_cache is not None:
# print('hit {}'.format(cache.node_cache.hit_/ cache.node_cache.hit_sum)) # print('hit {}'.format(cache.node_cache.hit_/ cache.node_cache.hit_sum))
ap, auc = eval('val') ap, auc = eval('val')
early_stop = early_stopper.early_stop_check(ap)
if early_stop:
print("Early stopping at epoch {:d}".format(e))
print(f"Loading the best model at epoch {early_stopper.best_epoch}")
best_model_path = get_checkpoint_path(early_stopper.best_epoch)
model.load_state_dict(torch.load(best_model_path))
break
else:
print('\ttrain loss:{:.4f} train ap:{:4f} val ap:{:4f} val auc:{:4f}'.format(total_loss,train_ap, ap, auc)) print('\ttrain loss:{:.4f} train ap:{:4f} val ap:{:4f} val auc:{:4f}'.format(total_loss,train_ap, ap, auc))
print('\ttotal time:{:.2f}s prep time:{:.2f}s'.format(time.time()-epoch_start_time, time_prep)) print('\ttotal time:{:.2f}s prep time:{:.2f}s'.format(time.time()-epoch_start_time, time_prep))
#print('\t fetch time:{:.2f}s write back time:{:.2f}s'.format(fetch_time,write_back_time)) print('\t fetch time:{:.2f}s write back time:{:.2f}s'.format(fetch_time,write_back_time))
torch.save(model.state_dict(), get_checkpoint_path(e))
model.eval() model.eval()
if mailbox is not None: if mailbox is not None:
mailbox.reset() mailbox.reset()
...@@ -304,6 +327,7 @@ def main(): ...@@ -304,6 +327,7 @@ def main():
else: else:
print('\ttest AP:{:4f} test AUC:{:4f}'.format(ap, auc)) print('\ttest AP:{:4f} test AUC:{:4f}'.format(ap, auc))
print('test_dataset',test_data.edges.shape[1],'avg_time',avg_time/train_param['epoch']) print('test_dataset',test_data.edges.shape[1],'avg_time',avg_time/train_param['epoch'])
torch.save(model.state_dict(), MODEL_SAVE_PATH)
ctx.shutdown() ctx.shutdown()
if __name__ == "__main__": if __name__ == "__main__":
main() main()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment