Commit 85e4b04c by zlj

build torch-utils by cmake

parent 54424543
...@@ -26,6 +26,10 @@ share/python-wheels/ ...@@ -26,6 +26,10 @@ share/python-wheels/
*.egg *.egg
MANIFEST MANIFEST
# IDE temporary files (generated by IDEs like CLion, etc.)
.idea/
cmake-build-*/
# PyInstaller # PyInstaller
# Usually these files are written by a python script from a template # Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it. # before PyInstaller builds the exe, so as to inject date/other infos into it.
......
[submodule "csrc/partition/neighbor_clustering"]
path = csrc/partition/neighbor_clustering
url = https://gitee.com/onlynagesha/graph-partition-v4
cmake_minimum_required(VERSION 3.15) cmake_minimum_required(VERSION 3.15)
project(starrygl_ops VERSION 0.1) project(starrygl VERSION 0.1)
option(WITH_PYTHON "Link to Python when building" ON) option(WITH_PYTHON "Link to Python when building" ON)
option(WITH_CUDA "Link to CUDA when building" ON) option(WITH_CUDA "Link to CUDA when building" ON)
option(WITH_METIS "Link to METIS when building" ON) option(WITH_METIS "Link to METIS when building" ON)
option(WITH_MTMETIS "Link to multi-threaded METIS when building" ON) option(WITH_MTMETIS "Link to multi-threaded METIS when building" ON)
option(WITH_LDG "Link to (optionally multi-threaded) LDG when building" ON)
set(CMAKE_CXX_STANDARD 17) set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON) set(CMAKE_CXX_STANDARD_REQUIRED ON)
...@@ -38,8 +39,8 @@ if(WITH_CUDA) ...@@ -38,8 +39,8 @@ if(WITH_CUDA)
file(GLOB_RECURSE UVM_SRCS "csrc/uvm/*.cpp") file(GLOB_RECURSE UVM_SRCS "csrc/uvm/*.cpp")
add_library(uvm SHARED ${UVM_SRCS}) add_library(uvm_ops SHARED ${UVM_SRCS})
target_link_libraries(uvm PRIVATE ${TORCH_LIBRARIES}) target_link_libraries(uvm_ops PRIVATE ${TORCH_LIBRARIES})
endif() endif()
if(WITH_METIS) if(WITH_METIS)
...@@ -55,10 +56,10 @@ if(WITH_METIS) ...@@ -55,10 +56,10 @@ if(WITH_METIS)
include_directories(${METIS_INCLUDE_DIRS}) include_directories(${METIS_INCLUDE_DIRS})
add_library(metis SHARED "csrc/partition/metis.cpp") add_library(metis_partition SHARED "csrc/partition/metis.cpp")
target_link_libraries(metis PRIVATE ${TORCH_LIBRARIES}) target_link_libraries(metis_partition PRIVATE ${TORCH_LIBRARIES})
target_link_libraries(metis PRIVATE ${GKLIB_LIBRARIES}) target_link_libraries(metis_partition PRIVATE ${GKLIB_LIBRARIES})
target_link_libraries(metis PRIVATE ${METIS_LIBRARIES}) target_link_libraries(metis_partition PRIVATE ${METIS_LIBRARIES})
endif() endif()
if(WITH_MTMETIS) if(WITH_MTMETIS)
...@@ -69,16 +70,29 @@ if(WITH_MTMETIS) ...@@ -69,16 +70,29 @@ if(WITH_MTMETIS)
file(GLOB_RECURSE MTMETIS_LIBRARIES "${MTMETIS_DIR}/lib/lib*.a") file(GLOB_RECURSE MTMETIS_LIBRARIES "${MTMETIS_DIR}/lib/lib*.a")
include_directories(${MTMETIS_INCLUDE_DIRS}) include_directories(${MTMETIS_INCLUDE_DIRS})
add_library(mtmetis SHARED "csrc/partition/mtmetis.cpp") add_library(mtmetis_partition SHARED "csrc/partition/mtmetis.cpp")
target_link_libraries(mtmetis PRIVATE ${TORCH_LIBRARIES}) target_link_libraries(mtmetis_partition PRIVATE ${TORCH_LIBRARIES})
target_link_libraries(mtmetis PRIVATE ${MTMETIS_LIBRARIES}) target_link_libraries(mtmetis_partition PRIVATE ${MTMETIS_LIBRARIES})
target_compile_definitions(mtmetis PRIVATE -DMTMETIS_64BIT_VERTICES) target_compile_definitions(mtmetis_partition PRIVATE -DMTMETIS_64BIT_VERTICES)
target_compile_definitions(mtmetis PRIVATE -DMTMETIS_64BIT_EDGES) target_compile_definitions(mtmetis_partition PRIVATE -DMTMETIS_64BIT_EDGES)
target_compile_definitions(mtmetis PRIVATE -DMTMETIS_64BIT_WEIGHTS) target_compile_definitions(mtmetis_partition PRIVATE -DMTMETIS_64BIT_WEIGHTS)
target_compile_definitions(mtmetis PRIVATE -DMTMETIS_64BIT_PARTITIONS) target_compile_definitions(mtmetis_partition PRIVATE -DMTMETIS_64BIT_PARTITIONS)
endif() endif()
if (WITH_LDG)
# Imports the neighbor-clustering-based graph partitioning implementation (e.g. the LDG algorithm)
add_definitions(-DWITH_LDG)
set(LDG_DIR "csrc/partition/neighbor_clustering")
add_library(ldg_partition SHARED "csrc/partition/ldg.cpp")
target_link_libraries(ldg_partition PRIVATE ${TORCH_LIBRARIES})
add_subdirectory(${LDG_DIR})
target_include_directories(ldg_partition PRIVATE ${LDG_DIR})
target_link_libraries(ldg_partition PRIVATE ldg-vertex-partition)
endif ()
include_directories("csrc/include") include_directories("csrc/include")
add_library(${PROJECT_NAME} SHARED csrc/export.cpp) add_library(${PROJECT_NAME} SHARED csrc/export.cpp)
...@@ -91,17 +105,23 @@ if(WITH_PYTHON) ...@@ -91,17 +105,23 @@ if(WITH_PYTHON)
endif() endif()
if (WITH_CUDA) if (WITH_CUDA)
target_link_libraries(${PROJECT_NAME} PRIVATE uvm) target_link_libraries(${PROJECT_NAME} PRIVATE uvm_ops)
endif() endif()
if (WITH_METIS) if (WITH_METIS)
target_link_libraries(${PROJECT_NAME} PRIVATE metis) message(STATUS "Current project '${PROJECT_NAME}' uses METIS graph partitioning algorithm.")
target_link_libraries(${PROJECT_NAME} PRIVATE metis_partition)
endif() endif()
if (WITH_MTMETIS) if (WITH_MTMETIS)
target_link_libraries(${PROJECT_NAME} PRIVATE mtmetis) message(STATUS "Current project '${PROJECT_NAME}' uses multi-threaded METIS graph partitioning algorithm.")
target_link_libraries(${PROJECT_NAME} PRIVATE mtmetis_partition)
endif() endif()
if (WITH_LDG)
message(STATUS "Current project '${PROJECT_NAME}' uses LDG graph partitioning algorithm.")
target_link_libraries(${PROJECT_NAME} PRIVATE ldg_partition)
endif()
# add libsampler.so # add libsampler.so
set(SAMLPER_NAME "${PROJECT_NAME}_sampler") set(SAMLPER_NAME "${PROJECT_NAME}_sampler")
......
sampling:
- layer: 1
neighbor:
- 10
strategy: 'recent'
prop_time: False
history: 1
duration: 0
num_thread: 32
no_neg: True
memory:
- type: 'node'
dim_time: 100
deliver_to: 'neighbors'
mail_combine: 'last'
memory_update: 'transformer'
attention_head: 2
mailbox_size: 10
combine_node_feature: False
dim_out: 100
gnn:
- arch: 'identity'
train:
- epoch: 100
batch_size: 600
lr: 0.0001
dropout: 0.1
att_dropout: 0.1
# all_on_gpu: True
\ No newline at end of file
sampling:
- layer: 2
neighbor:
- 10
- 10
strategy: 'uniform'
prop_time: True
history: 3
duration: 10000
num_thread: 32
memory:
- type: 'none'
dim_out: 0
gnn:
- arch: 'transformer_attention'
layer: 2
att_head: 2
dim_time: 0
dim_out: 100
combine: 'rnn'
train:
- epoch: 50
batch_size: 600
lr: 0.0001
dropout: 0.1
att_dropout: 0.1
all_on_gpu: True
\ No newline at end of file
sampling:
- no_sample: True
history: 1
memory:
- type: 'node'
dim_time: 100
deliver_to: 'self'
mail_combine: 'last'
memory_update: 'rnn'
mailbox_size: 1
combine_node_feature: True
dim_out: 100
gnn:
- arch: 'identity'
time_transform: 'JODIE'
train:
- epoch: 100
batch_size: 600
lr: 0.0001
dropout: 0.1
all_on_gpu: True
\ No newline at end of file
sampling:
- layer: 2
neighbor:
- 10
- 10
strategy: 'uniform'
prop_time: False
history: 1
duration: 0
num_thread: 32
memory:
- type: 'none'
dim_out: 0
gnn:
- arch: 'transformer_attention'
layer: 2
att_head: 2
dim_time: 100
dim_out: 100
train:
- epoch: 10
batch_size: 600
lr: 0.0001
dropout: 0.1
att_dropout: 0.1
all_on_gpu: True
\ No newline at end of file
sampling:
- layer: <number of layers to sample>
neighbor: <a list of integers indicating how many neighbors are sampled in each layer>
    strategy: <'recent' that samples the most recent neighbors, or 'uniform' that uniformly samples neighbors from the past>
    prop_time: <False or True that specifies whether to use the timestamp of the root nodes when sampling their multi-hop neighbors>
history: <number of snapshots to sample on>
    duration: <length in time of each snapshot; 0 for infinite length (used in non-snapshot-based methods)>
num_thread: <number of threads of the sampler>
memory:
- type: <'node', we only support node memory now>
dim_time: <an integer, the dimension of the time embedding>
    deliver_to: <'self' that delivers the mails only to the involved nodes, or 'neighbors' that delivers the mails to their neighbors>
    mail_combine: <'last' that uses the latest mail as the input to the memory updater>
memory_update: <'gru' or 'rnn'>
mailbox_size: <an integer, the size of the mailbox for each node>
    combine_node_feature: <False or True that specifies whether to combine node features (with the updated memory) as the input to the GNN>
dim_out: <an integer, the dimension of the output node memory>
gnn:
- arch: <'transformer_attention' or 'identity' (no GNN)>
layer: <an integer, number of layers>
att_head: <an integer, number of attention heads>
dim_time: <an integer, the dimension of the time embedding>
dim_out: <an integer, the dimension of the output dynamic node embedding>
train:
- epoch: <an integer, number of epochs to train>
    batch_size: <an integer, the batch size (of edges); for multi-GPU training, this is the per-GPU (local) batch size>
    reorder: <(optional) an integer, divisible by the batch size, that specifies how many chunks per batch are used in the random chunk scheduling>
lr: <floating point, learning rate>
dropout: <floating point, dropout>
att_dropout: <floating point, dropout for attention>
all_on_gpu: <False or True that decides if the node/edge features and node memory are completely stored on GPU>
\ No newline at end of file
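For reference, a minimal sketch of how a configuration in this format might be loaded on the Python side. This snippet is illustrative only and is not part of this commit; it assumes PyYAML is installed, and the file path is a hypothetical placeholder.

import yaml  # assumption: PyYAML; the repository's own training scripts may load configs differently

with open("config/example.yml", "r") as f:  # hypothetical path
    cfg = yaml.safe_load(f)

# each top-level section holds a list with a single entry, as in the examples above
sample_cfg = cfg["sampling"][0]
memory_cfg = cfg["memory"][0]
gnn_cfg = cfg["gnn"][0]
train_cfg = cfg["train"][0]

print(sample_cfg.get("layer"), sample_cfg.get("neighbor"))
print(train_cfg["batch_size"], train_cfg["lr"])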
...@@ -4,15 +4,13 @@ ...@@ -4,15 +4,13 @@
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
#ifdef WITH_CUDA
m.def("uvm_storage_new", &uvm_storage_new, "return storage of unified virtual memory"); m.def("uvm_storage_new", &uvm_storage_new, "return storage of unified virtual memory");
m.def("uvm_storage_to_cuda", &uvm_storage_to_cuda, "share uvm storage with another cuda device"); m.def("uvm_storage_to_cuda", &uvm_storage_to_cuda, "share uvm storage with another cuda device");
m.def("uvm_storage_to_cpu", &uvm_storage_to_cpu, "share uvm storage with cpu"); m.def("uvm_storage_to_cpu", &uvm_storage_to_cpu, "share uvm storage with cpu");
m.def("uvm_storage_advise", &uvm_storage_advise, "apply cudaMemAdvise() to uvm storage"); m.def("uvm_storage_advise", &uvm_storage_advise, "apply cudaMemAdvise() to uvm storage");
m.def("uvm_storage_prefetch", &uvm_storage_prefetch, "apply cudaMemPrefetchAsync() to uvm storage"); m.def("uvm_storage_prefetch", &uvm_storage_prefetch, "apply cudaMemPrefetchAsync() to uvm storage");
m.def("metis_partition", &metis_partition, "metis graph partition");
m.def("mt_metis_partition", &mt_metis_partition, "multi-threaded metis graph partition");
py::enum_<cudaMemoryAdvise>(m, "cudaMemoryAdvise") py::enum_<cudaMemoryAdvise>(m, "cudaMemoryAdvise")
.value("cudaMemAdviseSetAccessedBy", cudaMemoryAdvise::cudaMemAdviseSetAccessedBy) .value("cudaMemAdviseSetAccessedBy", cudaMemoryAdvise::cudaMemAdviseSetAccessedBy)
.value("cudaMemAdviseUnsetAccessedBy", cudaMemoryAdvise::cudaMemAdviseUnsetAccessedBy) .value("cudaMemAdviseUnsetAccessedBy", cudaMemoryAdvise::cudaMemAdviseUnsetAccessedBy)
...@@ -20,4 +18,19 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { ...@@ -20,4 +18,19 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
.value("cudaMemAdviseUnsetPreferredLocation", cudaMemoryAdvise::cudaMemAdviseUnsetPreferredLocation) .value("cudaMemAdviseUnsetPreferredLocation", cudaMemoryAdvise::cudaMemAdviseUnsetPreferredLocation)
.value("cudaMemAdviseSetReadMostly", cudaMemoryAdvise::cudaMemAdviseSetReadMostly) .value("cudaMemAdviseSetReadMostly", cudaMemoryAdvise::cudaMemAdviseSetReadMostly)
.value("cudaMemAdviseUnsetReadMostly", cudaMemoryAdvise::cudaMemAdviseUnsetReadMostly); .value("cudaMemAdviseUnsetReadMostly", cudaMemoryAdvise::cudaMemAdviseUnsetReadMostly);
#endif
#ifdef WITH_METIS
m.def("metis_partition", &metis_partition, "metis graph partition");
#endif
#ifdef WITH_MTMETIS
m.def("mt_metis_partition", &mt_metis_partition, "multi-threaded metis graph partition");
#endif
#ifdef WITH_LDG
    // Note: the WITH_MULTITHREADING=ON switch must be enabled at compile time
    // to turn on the multi-threaded code path.
    m.def("ldg_partition", &ldg_partition, "(optionally multi-threaded) LDG graph partition");
#endif
} }
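As a rough usage sketch of the bindings above (assumptions: the extension was built with WITH_LDG=ON and is importable as `starrygl`, matching the renamed CMake project; the actual module name and packaging may differ):

import torch
import starrygl  # assumption: import path of the built pybind11 extension

# a small edge list of shape [E, 2]; the binding also accepts [E, 3]
edges = torch.tensor([[0, 1], [1, 2], [2, 3], [3, 0]], dtype=torch.long)

# LDG partitioning into 2 parts with 4 worker threads;
# vertex_weights and initial_partition are optional and may be passed as None
parts = starrygl.ldg_partition(edges, None, None, 2, 4)
print(parts)  # one partition id per vertex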
...@@ -21,4 +21,12 @@ at::Tensor mt_metis_partition( ...@@ -21,4 +21,12 @@ at::Tensor mt_metis_partition(
int64_t num_parts, int64_t num_parts,
int64_t num_workers, int64_t num_workers,
bool recursive bool recursive
); );
\ No newline at end of file
at::Tensor ldg_partition(
at::Tensor edges,
at::optional<at::Tensor> vertex_weights,
at::optional<at::Tensor> initial_partition,
int64_t num_parts,
int64_t num_workers
);
#include <torch/all.h>
#include "neighbor_clustering/vertex_partition/vertex_partition.h"
#include "neighbor_clustering/vertex_partition/params.h"
at::Tensor ldg_partition(at::Tensor edges,
at::optional<at::Tensor> vertex_weights,
at::optional<at::Tensor> initial_partition,
int64_t num_parts,
int64_t num_workers) {
AT_ASSERT(edges.dim() == 2);
auto edges_n_cols = edges.size(1);
AT_ASSERT(edges_n_cols >= 2 && edges_n_cols <= 3);
// Note: other checks are performed in the implementation of vp::ldg_partition function series.
auto n = edges.slice(1, 0, 2).max().item<int64_t>() + 1;
auto params = vp::LDGParams{.N = n, .K = num_parts, .openmp_n_threads = static_cast<int>(num_workers)};
auto edges_clone = edges.clone();
if (vertex_weights.has_value()) {
auto vertex_weights_clone = vertex_weights->clone();
if (initial_partition.has_value()) {
auto initial_partition_clone = initial_partition->clone();
vp::ldg_partition_v_init(edges_clone, vertex_weights_clone, initial_partition_clone, params);
return initial_partition_clone;
} else {
return vp::ldg_partition_v(edges_clone, vertex_weights_clone, params);
}
} else {
if (initial_partition.has_value()) {
auto initial_partition_clone = initial_partition->clone();
vp::ldg_partition_init(edges_clone, initial_partition_clone, params);
return initial_partition_clone;
} else {
return vp::ldg_partition(edges_clone, params);
}
}
}
...@@ -114,11 +114,15 @@ edge_weight_dict = {} ...@@ -114,11 +114,15 @@ edge_weight_dict = {}
edge_weight_dict['edata'] = 2*neg_nums edge_weight_dict['edata'] = 2*neg_nums
edge_weight_dict['sample_data'] = 1*neg_nums edge_weight_dict['sample_data'] = 1*neg_nums
edge_weight_dict['neg_data'] = 1 edge_weight_dict['neg_data'] = 1
partition_save('./dataset/here/'+data_name, data, 1, 'metis_for_tgnn', #partition_save('./dataset/here/'+data_name, data, 1, 'metis_for_tgnn',
edge_weight_dict=edge_weight_dict) # edge_weight_dict=edge_weight_dict)
partition_save('./dataset/here/'+data_name, data, 2, 'metis_for_tgnn', #partition_save('./dataset/here/'+data_name, data, 2, 'metis_for_tgnn',
edge_weight_dict=edge_weight_dict) # edge_weight_dict=edge_weight_dict)
partition_save('./dataset/here/'+data_name, data, 4, 'metis_for_tgnn', #partition_save('./dataset/here/'+data_name, data, 4, 'metis_for_tgnn',
# edge_weight_dict=edge_weight_dict)
#partition_save('./dataset/here/'+data_name, data, 8, 'metis_for_tgnn',
# edge_weight_dict=edge_weight_dict)
partition_save('./dataset/here/'+data_name, data, 16, 'metis_for_tgnn',
edge_weight_dict=edge_weight_dict) edge_weight_dict=edge_weight_dict)
# #
# partition_save('./dataset/here/'+data_name, data, 4, 'metis_for_tgnn', # partition_save('./dataset/here/'+data_name, data, 4, 'metis_for_tgnn',
......
#include <cuda_runtime.h>
#include <cstdio>
int main()
{
int count = 0;
if (cudaSuccess != cudaGetDeviceCount(&count)) return -1;
if (count == 0) return -1;
for (int device = 0; device < count; ++device)
{
cudaDeviceProp prop;
if (cudaSuccess == cudaGetDeviceProperties(&prop, device))
std::printf("%d.%d ", prop.major, prop.minor);
}
return 0;
}
#include <cuda.h>
#include <cstdio>
int main() {
printf("%d.%d", CUDA_VERSION / 1000, (CUDA_VERSION / 10) % 100);
return 0;
}
# Minimal makefile for Sphinx documentation
#
# You can set these variables from the command line, and also
# from the environment for the first two.
SPHINXOPTS ?=
SPHINXBUILD ?= sphinx-build
SOURCEDIR = source
BUILDDIR = build
# Put it first so that "make" without argument is like "make help".
help:
@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
.PHONY: help Makefile
# Catch-all target: route all unknown targets to Sphinx using the new
# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
%: Makefile
@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
@ECHO OFF
pushd %~dp0
REM Command file for Sphinx documentation
if "%SPHINXBUILD%" == "" (
set SPHINXBUILD=sphinx-build
)
set SOURCEDIR=source
set BUILDDIR=build
%SPHINXBUILD% >NUL 2>NUL
if errorlevel 9009 (
echo.
echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
echo.installed, then set the SPHINXBUILD environment variable to point
echo.to the full path of the 'sphinx-build' executable. Alternatively you
echo.may add the Sphinx directory to PATH.
echo.
echo.If you don't have Sphinx installed, grab it from
echo.https://www.sphinx-doc.org/
exit /b 1
)
if "%1" == "" goto help
%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
goto end
:help
%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
:end
popd
Advanced Data Preprocessing
===========================
.. note::
    Describe in detail how to use StarryGL's data management classes (e.g. GraphData), the design of their internal index structures, and the low-level operations on them.
\ No newline at end of file
Advanced Concepts
=================
.. toctree::
ts_sampling
pp_training
tp_training
data_proc
\ No newline at end of file
Distributed Partition Parallel
==============================
.. note::
    This section covers distributed partition-parallel training.
\ No newline at end of file
Distributed Timeline Parallel
=============================
.. note::
    This section covers distributed timeline (temporal) parallelism.
\ No newline at end of file
Distributed Temporal Sampling
=============================
.. note::
    Training modes based on distributed temporal graph sampling.
\ No newline at end of file
starrygl.distributed
====================
.. note::
    Auto-generated API documentation; docstrings still need to be added inside the starrygl source tree.
.. currentmodule:: starrygl.distributed
.. autosummary::
DistributedContext
\ No newline at end of file
Package References
==================
.. toctree::
distributed
\ No newline at end of file
Cheatsheets
===========
.. note::
    List the framework's built-in GNN models, the available GNN datasets, and the default API calling conventions; for reference, see
    `pyg cheatsheets <https://pytorch-geometric.readthedocs.io/en/latest/cheatsheet/data_cheatsheet.html>`_
\ No newline at end of file
import os
import sys
sys.path.insert(0, os.path.abspath("../.."))
import starrygl
project = 'StarryGL'
copyright = '2023, StarryGL Team'
author = 'StarryGL Team'
version = starrygl.__version__
release = starrygl.__version__
extensions = [
"sphinx.ext.autodoc",
"sphinx.ext.autosummary",
"sphinx.ext.duration",
]
templates_path = ['_templates']
exclude_patterns = []
# -- Options for HTML output -------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output
html_theme = 'sphinx_rtd_theme'
html_static_path = ['_static']
External Resources
==================
.. note::
    Collect StarryGL-related papers, GNN-related projects, and other resources that help users make the most of StarryGL; for reference, see:
    `PyG Resources <https://pytorch-geometric.readthedocs.io/en/latest/external/resources.html>`_
\ No newline at end of file
Get Started
===========
.. toctree::
install_guide
intro_example
\ No newline at end of file
Installation
============
.. note::
    How to install StarryGL.
\ No newline at end of file
Introduction by Example
=======================
.. note::
    After installation, quickly get started with training a temporal graph model. Start with basic usage, possibly on a single machine.
\ No newline at end of file
StarryGL Documentation
======================
.. toctree::
guide/index
tutorial/index
advanced/index
api/python/index
cheatsheets/index
external/index
.. Indices and tables
.. ==================
.. * :ref:`genindex`
.. * :ref:`modindex`
.. * :ref:`search`
User Cases and Applications
===========================
.. note::
    Present use cases and applications of temporal GNNs.
\ No newline at end of file
Preparing the Temporal Graph Dataset
====================================
.. note::
    Covers data cleaning and preprocessing starting from the raw data, producing data files that StarryGL can consume.
\ No newline at end of file
Distributed Training
====================
.. note::
    Introduces methods related to distributed training and walks through a minimal distributed DDP training workflow.
\ No newline at end of file
Tutorials
===============
.. toctree::
intro
module
dataset
application
distributed
\ No newline at end of file
Introduction to Temporal GNN
==============================================
.. note::
    A brief introduction to temporal GNNs: application scenarios, the problems to be solved, and so on; serves as an overall overview.
Creating Temporal GNN Models
============================
.. note::
    Explain how to create GNN models using the two most classic and concise examples, covering the construction of a **discrete-time dynamic graph model** and a **continuous-time dynamic graph model**.
\ No newline at end of file
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/hwj/.miniconda3/envs/pyg/lib/python3.8/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n"
]
}
],
"source": [
"import torch\n",
"\n",
"from torch import Tensor\n",
"from typing import *\n",
"\n",
"from torch_scatter import scatter"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"def act_blks(act_mask: Tensor, blk_size: int) -> Tensor:\n",
" assert act_mask.dtype == torch.bool\n",
" assert act_mask.dim() == 1\n",
" n = act_mask.size(0)\n",
" m = (n + blk_size - 1) // blk_size * blk_size\n",
"\n",
" blk_mask = act_mask.detach()\n",
" if n != m:\n",
" blk_mask = torch.zeros(m, dtype=act_mask.dtype, device=act_mask.device)\n",
" blk_mask[:n] = act_mask.detach()\n",
" return blk_mask.view(-1, blk_size).max(dim=-1).values"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"act_mask = torch.rand(1000) < 0.01\n",
"blk_mask = act_blks(act_mask, 64)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"cached_blocks = torch.randn(blk_mask.size(0), 64, 8)\n",
"blks = cached_blocks[blk_mask]"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"def gather_block(blks: Tensor, blk_mask: Tensor, index: Tensor) -> Tensor:\n",
" blk_size = blks.size(1)\n",
" blk_mask = blk_mask.type(torch.long).cumsum_(dim=0) * blk_mask - 1\n",
" \n",
" idx0 = index.div(blk_size, rounding_mode=\"floor\")\n",
" ind_mask = blk_mask[idx0] >= 0\n",
" \n",
" idx0 = blk_mask[idx0][ind_mask]\n",
" idx1 = index[ind_mask] % blk_size\n",
"\n",
" s = index.shape[:1] + blks.shape[2:]\n",
" data = torch.zeros(s, dtype=blks.dtype, device=blks.device)\n",
" data[ind_mask] = blks[idx0, idx1, ...]\n",
" return data"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"def scatter_block(x: Tensor, index: Tensor, blk_mask: Tensor, blk_size: int, reduce: str = \"sum\") -> Tensor:\n",
" num_blks = blk_mask.count_nonzero()\n",
" blk_mask = blk_mask.type(torch.long).cumsum_(dim=0) * blk_mask - 1\n",
"\n",
" idx0 = index.div(blk_size, rounding_mode=\"floor\")\n",
" ind_mask = blk_mask[idx0] >= 0\n",
"\n",
" idx0 = blk_mask[idx0][ind_mask]\n",
" idx1 = index[ind_mask] % blk_size\n",
"\n",
" data = scatter(\n",
" src=x[ind_mask],\n",
" index=idx0*blk_size+idx1,\n",
" dim=0,\n",
" dim_size=num_blks*blk_size,\n",
" reduce=reduce,\n",
" )\n",
" return data.view(num_blks, blk_size, *x.shape[1:])"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"# def index_block(index: Tensor, blk_mask: Tensor, blk_size: int) -> Tensor:\n",
" "
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([False, True, True, False, True, False, False, True, False, True,\n",
" False, True, True, False, False, False, False, False, False, False,\n",
" False, True, True, False, False, False, True, False, False, True,\n",
" True, False, False, True, True, False, True, True, False, False,\n",
" True, False, True, False, False, False, False, False, False, False,\n",
" True, True, False, False, False, True, True, True, False, False,\n",
" True, False, True, False, True, False, True, False, False, True,\n",
" False, True, True, True, False, False, False, False, False, True,\n",
" False, True, False, True, True, True, False, True, False, True,\n",
" False, True, False, True, True, False, False, True, False, True])"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"index = torch.randint(1000, size=(100,), dtype=torch.long)\n",
"(gather_block(blks, blk_mask, index)**2).sum(dim=-1) != 0"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([8, 64, 8])"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"x = torch.randn(100, 8)\n",
"scatter_block(x, index, blk_mask, 64).size()"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([True, True, True, True, True, True, True, True])"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"t = scatter_block(x, index, blk_mask, 64)\n",
"t = t.view(-1, *t.shape[2:])\n",
"act_blks((t**2).sum(dim=-1) != 0, 64)"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.return_types.sort(\n",
"values=tensor([ 5, 15, 18, 19, 25, 48, 62, 64, 68, 76, 88, 92, 94, 106,\n",
" 114, 115, 120, 122, 126, 131, 135, 165, 169, 169, 177, 192, 214, 215,\n",
" 245, 249, 255, 269, 282, 286, 294, 299, 334, 358, 359, 372, 401, 403,\n",
" 430, 437, 455, 459, 459, 459, 471, 484, 490, 504, 505, 516, 528, 529,\n",
" 559, 584, 613, 614, 614, 624, 629, 637, 638, 674, 679, 701, 706, 735,\n",
" 742, 746, 748, 750, 759, 774, 774, 776, 778, 787, 803, 825, 831, 834,\n",
" 861, 862, 868, 877, 882, 889, 899, 909, 921, 933, 938, 954, 960, 969,\n",
" 981, 986]),\n",
"indices=tensor([ 6, 77, 14, 41, 43, 0, 13, 53, 38, 54, 15, 76, 98, 47, 75, 80, 8, 95,\n",
" 17, 34, 83, 87, 62, 93, 81, 20, 67, 39, 32, 5, 70, 4, 89, 72, 37, 64,\n",
" 94, 42, 57, 29, 56, 21, 60, 7, 45, 59, 86, 96, 49, 92, 28, 88, 24, 68,\n",
" 3, 61, 44, 2, 79, 55, 85, 33, 12, 1, 11, 97, 50, 9, 84, 40, 71, 22,\n",
" 26, 30, 73, 16, 90, 27, 19, 35, 48, 18, 46, 74, 65, 78, 52, 31, 23, 58,\n",
" 91, 66, 99, 69, 51, 36, 82, 63, 25, 10]))"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"index.sort()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "mpi",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.13"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}
# from .a2a import all_to_all, with_gloo, with_nccl
# from .cache import EmbeddingCache
# from .gather import Gather
# from .route import Route, GatherWork
from .route import Route, RouteWorkBase
# from .cache import MessageCache, NodeProbe
\ No newline at end of file
import torch
import torch.distributed as dist
from torch import Tensor
from typing import *
def all_to_all(
output_tensor_list: List[Tensor],
input_tensor_list: List[Tensor],
group: Optional[Any] = None,
):
assert len(output_tensor_list) == len(input_tensor_list)
if group is None:
group = dist.distributed_c10d._get_default_group()
backend = dist.get_backend(group)
if backend == "nccl":
dist.all_to_all(
output_tensor_list=output_tensor_list,
input_tensor_list=input_tensor_list,
group=group,
)
elif backend == "mpi":
dist.all_to_all(
output_tensor_list=output_tensor_list,
input_tensor_list=input_tensor_list,
group=group,
)
else:
assert backend == "gloo", f"backend must be nccl, mpi or gloo"
rank = dist.get_rank()
world_size = dist.get_world_size()
p2p_op_list: List[dist.P2POp] = []
for i in range(1, world_size):
send_i = (rank + i) % world_size
recv_i = (rank - i + world_size) % world_size
p2p_op_list.extend([
dist.P2POp(dist.isend, input_tensor_list[send_i], send_i, group=group),
dist.P2POp(dist.irecv, output_tensor_list[recv_i], recv_i, group=group),
])
        works = dist.batch_isend_irecv(p2p_op_list)
        output_tensor_list[rank][:] = input_tensor_list[rank]
        # wait for every pending isend/irecv to finish before returning
        for work in works:
            work.wait()
\ No newline at end of file
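A minimal sketch of how this `all_to_all` helper might be exercised (assumptions: it is launched under `torchrun` with a process group that can be initialized, and the import path below is illustrative rather than the actual module path):

import torch
import torch.distributed as dist

from starrygl.parallel.a2a import all_to_all  # assumption: actual module path may differ

def main():
    dist.init_process_group(backend="gloo")  # nccl or mpi also take the fast path above
    rank, world_size = dist.get_rank(), dist.get_world_size()

    # rank r sends the scalar r to every peer and receives one scalar from each peer
    inputs = [torch.full((1,), float(rank)) for _ in range(world_size)]
    outputs = [torch.zeros(1) for _ in range(world_size)]

    all_to_all(outputs, inputs)
    print(rank, [t.item() for t in outputs])  # expected: [0.0, 1.0, ..., world_size - 1]

if __name__ == "__main__":
    main()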
# import torch
# from torch import Tensor
# from typing import *
# from multiprocessing.pool import ThreadPool
# from .event import Event
# class AsyncCopyWorkBase:
# def __init__(self) -> None:
# self._events: Optional[Tuple[Event, Event]] = None
# def wait(self):
# raise NotImplementedError
# def get(self) -> Tensor:
# raise NotImplementedError
# def has_events(self) -> bool:
# return self._events is not None
# def set_events(self, start, end):
# self._events = (start, end)
# return self
# def time_used(self) -> float:
# if self._events is None:
# raise RuntimeError("not found events")
# start, end = self._events
# return start.elapsed_time(end)
# class AsyncPushWork(AsyncCopyWorkBase):
# def __init__(self, data: Tensor, index: Tensor, values: Tensor) -> None:
# super().__init__()
# assert data.device == index.device
# self.set_events(
# Event(use_cuda=index.is_cuda),
# Event(use_cuda=index.is_cuda),
# )
# self._events[0].record()
# data.index_copy_(0, index, values)
# self._events[1].record()
# def wait(self):
# pass
# def get(self):
# pass
# class AsyncPullWork(AsyncCopyWorkBase):
# def __init__(self, data: Tensor, index: Tensor) -> None:
# super().__init__()
# assert data.device == index.device
# self.set_events(
# Event(use_cuda=index.is_cuda),
# Event(use_cuda=index.is_cuda),
# )
# self._events[0].record()
# self._val = data.index_select(0, index)
# self._events[1].record()
# def wait(self):
# pass
# def get(self) -> Tensor:
# return self._val
# class AsyncOffloadWork(AsyncCopyWorkBase):
# def __init__(self, handle) -> None:
# super().__init__()
# self._handle = handle
# def wait(self):
# if self._events is not None:
# self._events[1].wait(torch.cuda.current_stream())
# self._handle.wait()
# def get(self) -> Tensor:
# if self._events is not None:
# self._events[1].wait(torch.cuda.current_stream())
# return self._handle.get()
# class AsyncCopyExecutor:
# def __init__(self) -> None:
# self._stream = torch.cuda.Stream()
# self._executor = ThreadPool(processes=1)
# @torch.no_grad()
# def async_pull(self, data: Tensor, index: Tensor) -> AsyncCopyWorkBase:
# # Ideally all of this code would run in a dedicated thread, which both simplifies timing and keeps it asynchronous
# if data.device != index.device:
# assert not data.is_cuda
# assert index.is_cuda
# start = Event(index.is_cuda)
# end = Event(index.is_cuda)
# stream: Optional[torch.cuda.Stream] = self._stream
# def run():
# start.wait(stream)
# with torch.cuda.stream(stream):
# idx = index.to(data.device)
# dst = torch.zeros(
# size=(index.size(0),) + data.shape[1:],
# dtype=data.dtype,
# device=index.device,
# )
# dst.copy_(data[idx])
# end.record()
# return dst
# start.record()
# handle = self._executor.apply_async(run)
# return AsyncOffloadWork(handle).set_events(start, end)
# else:
# # data and index in the same device
# return AsyncPullWork(data, index)
# @torch.no_grad()
# def async_push(self, data: Tensor, index: Tensor, values: Tensor) -> AsyncCopyWorkBase:
# assert index.device == values.device
# if data.device != index.device:
# assert not data.is_cuda
# assert index.is_cuda
# start = Event(index.is_cuda)
# end = Event(index.is_cuda)
# stream: Optional[torch.cuda.Stream] = self._stream
# def run():
# start.wait(stream)
# with torch.cuda.stream(stream):
# idx = index.to(data.device)
# val = values.to(data.device)
# data[idx] = val
# end.record()
# start.record()
# handle = self._executor.apply_async(run)
# return AsyncOffloadWork(handle).set_events(start, end)
# else:
# return AsyncPushWork(data, index, values)
# _THREAD_EXEC: Optional[AsyncCopyExecutor] = None
# def get_executor() -> AsyncCopyExecutor:
# global _THREAD_EXEC
# if _THREAD_EXEC is None:
# _THREAD_EXEC = AsyncCopyExecutor()
# return _THREAD_EXEC
# import torch
# import torch.nn as nn
# from torch.utils import hooks
# from torch import Tensor
# from typing import *
# from .route import Route, RouteWorkBase
# from .shrink import ShrinkData
# from .acopy import get_executor as get_acopy_executor, AsyncCopyWorkBase
# from .route import get_route_executor
# class ProbeLayer:
# def __init__(self, layer_id, probe_obj) -> None:
# self.layer_id: int = layer_id
# self.probe_obj: NodeProbe = probe_obj
# self.last_w = torch.ones(self.probe_obj.num_nodes, dtype=torch.float32)
# self._val_hook_handle: Optional[hooks.RemovableHandle] = None
# def to(self, device):
# self.last_w = self.last_w.to(device)
# return self
# def remove_val_hook(self):
# if self._val_hook_handle is not None:
# self._val_hook_handle.remove()
# self._val_hook_handle = None
# def register_val_hook(self, val: Tensor, idx: Optional[Tensor]):
# assert self._val_hook_handle is None, "cannot call register_val_hook() twice"
# def hook(grad: Tensor):
# from starrygl.utils.printer import main_print
# import time
# self.probe_obj.update_sample_w(self.last_w, grad, idx)
# self.remove_val_hook()
# self._backward_sample(False)
# main_print(self.layer_id, grad.size(0), idx.size(0))
# self.val_hook_handle = val.register_hook(hook)
# def warmup_sample(self):
# assert self._is_last_layer()
# for probe_layer in self.probe_obj.layers[::-1]:
# probe_layer._backward_sample()
# # max_norm = max(probe_layer.last_w.max().item(), 1.0)
# # probe_layer.last_w[probe_layer.last_w == 1.0] = max_norm
# def _backward_sample(self, x: bool = True):
# dst_idx = self._collect_prev_dst_idx()
# if dst_idx is None:
# return
# for cache in self.probe_obj.hook_caches:
# next_cache_layer = self._next_cache_layer(cache)
# shrink = next_cache_layer.set_shrink_data(dst_idx)
# src_idx = shrink.src_idx
# work = get_acopy_executor().async_pull(next_cache_layer.data, src_idx)
# next_cache_layer.sync_pull_work(work)
# if not self._is_first_layer() and x:
# next_cache_layer.sync_backward_sample_work(src_idx)
# def _collect_prev_dst_idx(self) -> Optional[Tensor]:
# if len(self.probe_obj.hook_caches) == 0:
# return None
# if self._is_last_layer():
# return self._wrapped_sample(self.probe_obj.num_samples, None)
# if len(self.probe_obj.hook_caches) == 1:
# for cache in self.probe_obj.hook_caches:
# prev_cache_layer = self._prev_cache_layer(cache)
# dst_idx = prev_cache_layer.sync_backward_sample_work()
# else:
# tmp = torch.zeros(self.probe_obj.num_nodes, dtype=torch.bool, device=self.last_w.device)
# for cache in self.probe_obj.hook_caches:
# prev_cache_layer = self._prev_cache_layer(cache)
# prev_idx = prev_cache_layer.sync_backward_sample_work()
# if prev_idx is not None:
# tmp.index_fill_(0, prev_idx, 1)
# dst_idx = torch.where(tmp)[0]
# if dst_idx.size(0) == 0:
# dst_idx = None
# if dst_idx is None:
# return None
# return self._wrapped_sample(self.probe_obj.num_samples, dst_idx)
# def _prev_cache_layer(self, cache):
# cache: MessageCache = cache
# i = self.layer_id + 1
# if i < self.probe_obj.num_layers:
# return cache.layers[i]
# return None
# def _next_cache_layer(self, cache):
# cache: MessageCache = cache
# return cache.layers[self.layer_id]
# def _is_last_layer(self):
# return self.layer_id + 1 >= self.probe_obj.num_layers
# def _is_first_layer(self):
# return self.layer_id == 0
# def _wrapped_sample(self, k: int, idx: Optional[Tensor]) -> Optional[Tensor]:
# w = self.last_w
# if idx is None:
# return self.probe_obj.sample(w, k) if 0 < k and k < self.probe_obj.num_nodes else None
# if 0 < k and k < idx.size(0):
# t = self.probe_obj.sample(w[idx], k)
# return idx[t]
# else:
# return idx
# class NodeProbe:
# def __init__(self,
# num_nodes: int,
# num_layers: int,
# num_samples: int = 0,
# p: str = "fro",
# dim: int = -1,
# beta: float = 1.0,
# ) -> None:
# super().__init__()
# self.num_nodes = num_nodes
# self.num_layers = num_layers
# self.num_samples = num_samples
# self.p = p
# self.dim = dim
# self.beta = beta
# self.layers = [ProbeLayer(i, self) for i in range(num_layers)]
# self.hook_caches: Set[MessageCache] = set()
# def to(self, device):
# for i in range(self.num_layers):
# self.layers[i].to(device)
# return self
# def assign_message_cache(self, cache):
# cache: MessageCache = cache
# assert self.num_layers == cache.num_layers
# self.hook_caches.add(cache)
# def update_sample_w(self, last_w: Tensor, grad: Tensor, idx: Optional[Tensor]):
# val_norm = grad.norm(p=self.p, dim=self.dim)
# if idx is None:
# last_w[:] = val_norm
# else:
# if self.beta != 1.0:
# last_w.mul_(self.beta)
# last_w[idx] = val_norm
# def sample(self, w: Tensor, k: int) -> Optional[Tensor]:
# w = w / w.sum()
# return torch.multinomial(w, num_samples=k, replacement=False)
# def apply(self, i: int, val: Tensor, idx: Optional[Tensor]) -> Tensor:
# self.layers[i].register_val_hook(val, idx)
# return val
# class CacheLayer:
# def __init__(self, layer_id, cache_obj, num_features) -> None:
# self.layer_id: int = layer_id
# self.cache_obj: MessageCache = cache_obj
# self.data: Tensor = torch.randn(
# size=(self.cache_obj.src_size, num_features),
# dtype=torch.float32,
# device=self.cache_obj.cache_device,
# )
# self.state: Dict[str, Any] = {}
# self.shrink: Optional[ShrinkData] = None
# self._async_push_work: Optional[AsyncCopyWorkBase] = None
# self._async_pull_work: Optional[AsyncCopyWorkBase] = None
# self._backward_sample_work: Optional[RouteWorkBase] = None
# def to(self, device):
# if self.shrink is not None:
# self.shrink = self.shrink.to(device)
# return self
# def cache_data_to(self):
# self.data = self.data.to(self.cache_obj.cache_device)
# return self
# def set_shrink_data(self, dst_idx: Optional[Tensor]) -> Optional[ShrinkData]:
# self.shrink = self.cache_obj.new_shrink_data(dst_idx)
# return self.shrink
# def get_shrink_data(self) -> Optional[ShrinkData]:
# return self.shrink
# # def set_backward_sample_work(self, src_idx: Optional[Tensor]):
# # if src_idx is None:
# # return
# # assert self._backward_sample_work is None
# # self._backward_sample_work = self.cache_obj.route.backward_a2a(src_idx, src_idx)
# # def pop_backward_sample_work(self) -> Optional[RouteWorkBase]:
# # work = self._backward_sample_work
# # self._backward_sample_work = None
# # return work
# def sync_backward_sample_work(self, src_idx: Optional[Tensor] = None) -> Optional[Tensor]:
# dst_idx = None
# if self._backward_sample_work is not None:
# _, dst_idx = self._backward_sample_work.get()
# if src_idx is not None:
# # self._backward_sample_work = self.cache_obj.route.backward_a2a(src_idx, src_idx)
# work = get_route_executor().async_backward_a2a(src_idx, src_idx, self.cache_obj.route)
# self._backward_sample_work = work
# else:
# self._backward_sample_work = None
# return dst_idx
# def sync_push_work(self, work: Optional[AsyncCopyWorkBase] = None):
# if self._async_push_work is not None:
# self._async_push_work.get()
# self._async_push_work = work
# def sync_pull_work(self, work: Optional[AsyncCopyWorkBase] = None) -> Optional[Tensor]:
# out = None
# if self._async_pull_work is not None:
# out = self._async_pull_work.get()
# self._async_pull_work = work
# return out
# class MessageCache:
# def __init__(self,
# src_ids: Tensor,
# dst_ids: Tensor,
# edge_index: Tensor,
# num_features: Union[int, torch.Size],
# num_layers: int,
# cache_device: Union[str, torch.device, None],
# bipartite: bool = False,
# ) -> None:
# self.src_size = src_ids.size(0)
# self.dst_size = dst_ids.size(0)
# self.edge_index = edge_index
# self.num_layers = num_layers
# if cache_device is None:
# cache_device = edge_index.device
# self.cache_device = cache_device
# self.bipartite = bipartite
# self.route = Route(dst_ids, src_ids, bipartite=bipartite)
# self.layers = [CacheLayer(i, self, num_features) for i in range(num_layers)]
# def to(self, device):
# self.edge_index = self.edge_index.to(device)
# self.route = self.route.to(device)
# for i in range(self.num_layers):
# self.layers[i].to(device)
# return self
# def cached_data_to(self, device):
# self.cache_device = torch.device(device)
# for i in range(self.num_layers):
# self.layers[i].cache_data_to()
# return self
# # @property
# # def is_offload(self) -> bool:
# # return self.edge_index.device != self.cache_device
# def new_shrink_data(self, dst_idx: Optional[Tensor]) -> Optional[ShrinkData]:
# if dst_idx is None:
# return None
# return ShrinkData(
# src_size=self.src_size,
# dst_size=self.dst_size,
# dst_idx=dst_idx,
# edge_index=self.edge_index,
# bipartite=self.bipartite,
# )
# # def clear_shrink_data(self):
# # for i in range(self.num_layers):
# # self.layers[i].shrink = None
# def replace_layer_data(self, i: int, data: Tensor):
# layer = self.layers[i]
# assert layer.data.size(0) == data.size(0)
# layer.data = data
# return self
# def update_cache(self,
# i: int,
# val: Tensor,
# idx: Optional[Tensor],
# async_op: bool = False,
# ) -> Tuple[Tensor, Optional[Tensor]]:
# layer: CacheLayer = self.layers[i]
# src_val, src_idx = self.route.apply(val, idx, layer.state, async_op=async_op)
# # full graph
# if idx is None:
# return src_val, None
# shrink = layer.get_shrink_data()
# if shrink is None: # this may be problematic at initialization time
# data = torch.empty_like(layer.data, device=val.device).copy_(layer.data)
# data.index_copy_(0, src_idx, src_val)
# return data, src_idx
# else:
# push_work = get_acopy_executor().async_push(layer.data, src_idx, src_val)
# layer.sync_push_work(push_work)
# data = layer.sync_pull_work()
# sval, sidx = shrink.shrink_src_val_and_idx(src_val, src_idx)
# data.index_copy_(0, sidx, sval)
# return data, sidx
# # first, update the cache asynchronously
# # meanwhile, download the cache asynchronously
# # then merge the local and remote caches and return the result
# def fetch_pull_tensor(self, i: int) -> Optional[Tensor]:
# return self.layers[i].sync_pull_work()
# # def _update_cache_impl(self,
# # i: int,
# # dst_val: Tensor,
# # dst_idx: Optional[Tensor],
# # route: Route,
# # async_op: bool,
# # ):
# # # communcation
# # state = self.layers[i].state
# # src_val, src_idx = route.apply(dst_val, dst_idx, state, async_op=async_op)
# # # push latest embeddings
# # data = self.layers[i].data
# # if src_idx is None:
# # data[:] = src_val
# # else:
# # data[src_idx] = src_val
# # # get previous generated shrink data
# # shr: ShrinkData = self.layers[i].shrink
# # if shr is None:
# # return src_val, src_idx
# # else:
# # # pull latest embeddings
# # return data[shr.src_idx], shr.src_idx
# # def _update_cache_offload(self,
# # i: int,
# # dst_val: Tensor,
# # dst_idx: Optional[Tensor],
# # route: Route,
# # async_op: bool,
# # ):
# # raise NotImplementedError
# # def _compute_cached_data_size(self) -> torch.Size:
# # num_features = self.num_features
# # if num_features is None:
# # cached_data_size = torch.Size([self.src_size,])
# # else:
# # cached_data_size = torch.Size([self.src_size, num_features])
# # return cached_data_size
# # def _compute_cpu_buf_size(self) -> torch.Size:
# # num_features = self.num_features
# # if num_features is None:
# # cpu_buf_size = torch.Size([2**32,])
# # else:
# # cpu_buf_size = torch.Size([2**32 // num_features, num_features])
# # return cpu_buf_size
\ No newline at end of file
import torch
import time
from torch import Tensor
from typing import *
class Event:
def __init__(self, use_cuda: bool = True) -> None:
self._use_cuda = use_cuda
if use_cuda:
self._event = torch.cuda.Event(enable_timing=True)
else:
self._time: Optional[float] = None
def record(self):
if self._use_cuda:
self._event.record()
else:
self._time = time.time()
    def wait(self, stream: Optional[torch.cuda.Stream] = None):
        # only CUDA events support waiting on a stream; the CPU path is a no-op
        if not self._use_cuda:
            return
        self._event.wait(stream)
def elapsed_time(self, other) -> float:
if self._use_cuda:
return self._event.elapsed_time(other._event)
else:
return (other._time - self._time) * 1000.0
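A small usage sketch of this Event wrapper for timing a code region (assumptions: Event is importable from the module above via a hypothetical path; on the CUDA path both events must have completed before elapsed_time is read, hence the synchronize):

import torch

from starrygl.parallel.event import Event  # assumption: actual module path may differ

use_cuda = torch.cuda.is_available()
start, end = Event(use_cuda=use_cuda), Event(use_cuda=use_cuda)

start.record()
x = torch.randn(1024, 1024) @ torch.randn(1024, 1024)  # some work to time
end.record()

if use_cuda:
    torch.cuda.synchronize()  # make sure both CUDA events have completed
print(f"elapsed: {start.elapsed_time(end):.3f} ms")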
# import torch
# import torch.nn as nn
# import torch.autograd as autograd
# from torch import Tensor
# from contextlib import contextmanager
# from typing import *
# from .route import Route, GatherWork
# # from .cache import CachedEmbeddings
# class GatherContext:
# def __init__(self,
# this,
# route: Route,
# async_op: bool,
# ) -> None:
# self.this = this
# self.route = route
# self.async_op = async_op
# class Gather(nn.Module):
# def __init__(self,
# num_nodes: int,
# num_features: Optional[int] = None,
# beta: float = 1.0,
# ) -> None:
# super().__init__()
# self.beta = beta
# if num_features is None:
# self.register_buffer("last_embd", torch.zeros(num_nodes))
# else:
# self.register_buffer("last_embd", torch.zeros(num_nodes, num_features))
# self.last_fw_work: Optional[GatherWork] = None
# self.last_bw_work: Optional[GatherWork] = None
# self.reset_parameters()
# def forward(self,
# val: Tensor,
# idx: Optional[Tensor],
# route: Route,
# async_op: bool = False,
# ) -> Tuple[Tensor, Tensor]:
# with self._manager(route=route, async_op=async_op):
# return GatherFunction.apply(val, idx)
# def reset_parameters(self):
# last_embd = self.get_buffer("last_embd")
# nn.init.normal_(last_embd, mean=0, std=1)
# def fuse_embeddings(self,
# val: Tensor,
# idx: Optional[Tensor] = None,
# inplace: bool = False,
# ) -> Tensor:
# last_embd = self.get_buffer("last_embd")
# return GatherFuseFunction.apply(val, idx, last_embd, self.beta, inplace, self.training)
# @contextmanager
# def _manager(self, route: Route, async_op: bool):
# global _global_gather_context
# stacked = _global_gather_context
# try:
# _global_gather_context = GatherContext(
# this=self,
# route=route,
# async_op=async_op,
# )
# yield _global_gather_context
# finally:
# _global_gather_context = stacked
# class GatherFunction(autograd.Function):
# @staticmethod
# def forward(
# ctx: autograd.function.FunctionCtx,
# val: Tensor,
# idx: Optional[Tensor] = None,
# ):
# gather_ctx = _last_global_gather_context()
# this: Gather = gather_ctx.this
# route: Route = gather_ctx.route
# async_op: bool = gather_ctx.async_op
# return_idx: bool = idx is not None
# current_work = route.gather_forward(val, idx, async_op=async_op, return_idx=return_idx)
# if async_op:
# work = this.last_fw_work or current_work
# this.last_fw_work = current_work
# else:
# work = current_work
# recv_val, recv_idx = work.get()
# recv_val = recv_val if recv_idx is None else recv_val[recv_idx]
# recv_val = this.fuse_embeddings(recv_val, recv_idx, inplace=True)
# if this.training:
# ctx.save_for_backward(idx, recv_idx)
# ctx.this = this
# ctx.route = route
# ctx.async_op = async_op
# return recv_val, recv_idx
# @staticmethod
# def backward(
# ctx: autograd.function.FunctionCtx,
# val_grad: Tensor,
# idx_grad: Optional[Tensor],
# ):
# this: Gather = ctx.this
# route: Route = ctx.route
# async_op: bool = ctx.async_op
# with torch.no_grad():
# recv_idx, idx_grad = ctx.saved_tensors
# if idx_grad is not None:
# val_grad = val_grad[idx_grad]
# current_work = route.gather_backward(val_grad, idx_grad, async_op=async_op, return_idx=False)
# if async_op:
# work = this.last_bw_work or current_work
# this.last_bw_work = current_work
# else:
# work = current_work
# recv_val = work.get_val()
# if recv_idx is not None:
# recv_val = recv_val[recv_idx]
# return recv_val, None
# class GatherFuseFunction(autograd.Function):
# @staticmethod
# def forward(
# ctx: autograd.function.FunctionCtx,
# val: Tensor,
# idx: Optional[Tensor],
# last_embd: Tensor,
# beta: float,
# inplace: bool,
# training: bool,
# ):
# if not inplace:
# last_embd = last_embd.clone()
# ctx.beta = beta
# if idx is None:
# assert val.size(0) == last_embd.size(0)
# if beta != 1.0:
# last_embd.mul_(1 - beta).add_(val * beta)
# else:
# last_embd[:] = (val)
# else:
# assert val.size(0) == idx.size(0)
# if beta != 1.0:
# last_embd[idx] = last_embd[idx] * (1 - beta) + val * beta
# else:
# last_embd[idx] = val
# if training:
# ctx.beta = beta
# ctx.save_for_backward(idx)
# return last_embd
# @staticmethod
# def backward(
# ctx: autograd.function.FunctionCtx,
# grad: Tensor,
# ):
# beta: float = ctx.beta
# idx, = ctx.saved_tensors
# if idx is not None:
# grad = grad[idx]
# if beta != 1.0:
# grad = grad * beta
# return grad, None, None, None, None, None
# #### private functions
# _global_gather_context: Optional[GatherContext] = None
# def _last_global_gather_context() -> GatherContext:
# global _global_gather_context
# assert _global_gather_context is not None
# return _global_gather_context
# import torch
# from torch import Tensor
# from typing import *
# from multiprocessing.pool import ThreadPool
# from .event import Event
# class AsyncCopyWorkBase:
# def __init__(self) -> None:
# self._events: Optional[Tuple[Event, Event]] = None
# def wait(self):
# raise NotImplementedError
# def get(self) -> Tensor:
# raise NotImplementedError
# def has_events(self) -> bool:
# return self._events is not None
# def set_events(self, start, end):
# self._events = (start, end)
# return self
# def time_used(self) -> float:
# if self._events is None:
# raise RuntimeError("not found events")
# start, end = self._events
# return start.elapsed_time(end)
# class AsyncPushWork(AsyncCopyWorkBase):
# def __init__(self, data: Tensor, index: Tensor, values: Tensor) -> None:
# super().__init__()
# assert data.device == index.device
# self.set_events(
# Event(use_cuda=index.is_cuda),
# Event(use_cuda=index.is_cuda),
# )
# self._events[0].record()
# data.index_copy_(0, index, values)
# self._events[1].record()
# def wait(self):
# pass
# def get(self):
# pass
# class AsyncPullWork(AsyncCopyWorkBase):
# def __init__(self, data: Tensor, index: Tensor) -> None:
# super().__init__()
# assert data.device == index.device
# self.set_events(
# Event(use_cuda=index.is_cuda),
# Event(use_cuda=index.is_cuda),
# )
# self._events[0].record()
# self._val = data.index_select(0, index)
# self._events[1].record()
# def wait(self):
# pass
# def get(self) -> Tensor:
# return self._val
# class AsyncOffloadWork(AsyncCopyWorkBase):
# def __init__(self, handle) -> None:
# super().__init__()
# self._handle = handle
# def wait(self):
# if self._events is not None:
# self._events[1].wait(torch.cuda.current_stream())
# self._handle.wait()
# def get(self) -> Tensor:
# if self._events is not None:
# self._events[1].wait(torch.cuda.current_stream())
# return self._handle.get()
# class AsyncCopyExecutor:
# def __init__(self) -> None:
# self._stream = torch.cuda.Stream()
# self._executor = ThreadPool(processes=1)
# @torch.no_grad()
# def async_pull(self, data: Tensor, index: Tensor) -> AsyncCopyWorkBase:
# # Ideally all of this code would run in a dedicated thread, which both simplifies timing and keeps it asynchronous
# if data.device != index.device:
# assert not data.is_cuda
# assert index.is_cuda
# start = Event(index.is_cuda)
# end = Event(index.is_cuda)
# stream: Optional[torch.cuda.Stream] = self._stream
# def run():
# start.wait(stream)
# with torch.cuda.stream(stream):
# idx = index.to(data.device)
# dst = torch.zeros(
# size=(index.size(0),) + data.shape[1:],
# dtype=data.dtype,
# device=index.device,
# )
# dst.copy_(data[idx])
# end.record()
# return dst
# start.record()
# handle = self._executor.apply_async(run)
# return AsyncOffloadWork(handle).set_events(start, end)
# else:
# # data and index in the same device
# return AsyncPullWork(data, index)
# @torch.no_grad()
# def async_push(self, data: Tensor, index: Tensor, values: Tensor) -> AsyncCopyWorkBase:
# assert index.device == values.device
# if data.device != index.device:
# assert not data.is_cuda
# assert index.is_cuda
# start = Event(index.is_cuda)
# end = Event(index.is_cuda)
# stream: Optional[torch.cuda.Stream] = self._stream
# def run():
# start.wait(stream)
# with torch.cuda.stream(stream):
# idx = index.to(data.device)
# val = values.to(data.device)
# data[idx] = val
# end.record()
# start.record()
# handle = self._executor.apply_async(run)
# return AsyncOffloadWork(handle).set_events(start, end)
# else:
# return AsyncPushWork(data, index, values)
# class Lache:
# def __init__(self,
# cache_size: int,
# data: Tensor,
# ) -> None:
# assert not data.is_cuda, "data must be in CPU Memory"
# cache_size = (cache_size,) + data.shape[1:]
# self.fdata = data
# self.cache = torch.zeros(cache_size, dtype=data.dtype)
# self.no_idx: int = (2**62-1)*2+1
# self.cached_idx = torch.empty(cache_size, dtype=torch.long).fill_(self.no_idx)
# self.read_count = torch.zeros(data.size(0), dtype=torch.long)
# def to(self, device):
# self.cache = self.cache.to(device)
# self.cached_idx = self.cached_idx.to(device)
# self.read_count = self.read_count.to(device)
# return self
# def _push_impl(self, value: Tensor, index: Tensor):
# # self.read_count[index] += 1
# imp = torch.zeros_like(self.read_count)
# def _pull_impl(self, index: Tensor):
# assert index.device == self.cache.device
# self.read_count[index] += 1
# s = index.shape[:1] + self.cache.shape[1:]
# x = torch.empty(s, dtype=self.cache.dtype, device=self.cache.device)
# cache_mask = torch.zeros_like(self.read_count, dtype=torch.bool).index_fill_(0, self.cached_idx, 1)[index]
# cache_index = index[cache_mask]
# x[cache_mask] =
# no_cache_index = index[~cache_mask]
# rt_index = index[~lc_mask].to(self.fdata.device)
# rt_data = self.fdata[rt_index].to(index.device)
# _THREAD_EXEC: Optional[AsyncCopyExecutor] = None
# def get_executor() -> AsyncCopyExecutor:
# global _THREAD_EXEC
# if _THREAD_EXEC is None:
# _THREAD_EXEC = AsyncCopyExecutor()
# return _THREAD_EXEC
import torch
import torch.autograd as autograd
import torch.distributed as dist
from multiprocessing.pool import ThreadPool
from torch import Tensor
from typing import *
from contextlib import contextmanager
from .a2a import all_to_all
from .event import Event
class RouteWorkBase:
def __init__(self) -> None:
self._events: Optional[Tuple[Event, Event]] = None
def wait(self) -> None:
raise NotImplementedError
def get(self) -> Tuple[Tensor, Optional[Tensor]]:
raise NotImplementedError
def has_events(self) -> bool:
return self._events is not None
def set_events(self, start, end):
self._events = (start, end)
return self
def time_used(self) -> float:
if self._events is None:
raise RuntimeError("not found events")
start, end = self._events
return start.elapsed_time(end)
class RouteExecutorWork(RouteWorkBase):
def __init__(self, handle) -> None:
super().__init__()
self._handle = handle
self._events: Optional[Tuple[Event, Event]] = None
def wait(self) -> None:
if self._events is not None:
self._events[1].wait(torch.cuda.current_stream())
return self._handle.wait()
def get(self) -> Tuple[Tensor, Optional[Tensor]]:
if self._events is not None:
self._events[1].wait(torch.cuda.current_stream())
return self._handle.get()
class RouteExecutor:
def __init__(self) -> None:
if torch.cuda.is_available():
self._stream = torch.cuda.Stream()
else:
self._stream = None
self._executor = ThreadPool(processes=1)
self._group = dist.new_group()
def async_forward_a2a(self, val: Tensor, idx: Optional[Tensor], route) -> RouteWorkBase:
return self._async_a2a_impl(route.forward_a2a, val, idx)
def async_backward_a2a(self, val: Tensor, idx: Optional[Tensor], route) -> RouteWorkBase:
return self._async_a2a_impl(route.backward_a2a, val, idx)
@torch.no_grad()
def _async_a2a_impl(self, func, val: Tensor, idx: Optional[Tensor]) -> RouteWorkBase:
start = Event(use_cuda=val.is_cuda)
end = Event(use_cuda=val.is_cuda)
stream: Optional[torch.cuda.Stream] = self._stream
def run():
start.wait(stream)
with torch.cuda.stream(stream):
ret = func(val, idx, group=self._group)
end.record()
return ret
start.record()
handle = self._executor.apply_async(run)
return RouteExecutorWork(handle).set_events(start, end)
class Route:
def __init__(self,
src_ids: Tensor,
dst_ids: Tensor,
bipartite: bool = False,
group: Any = None,
):
assert src_ids.dtype == torch.long
assert dst_ids.dtype == torch.long
assert src_ids.dim() == 1
assert dst_ids.dim() == 1
if not bipartite:
            # requires vertex-partitioned data, i.e. the src_ids of all partitions are mutually disjoint
assert src_ids.size(0) <= dst_ids.size(0)
assert (dst_ids[:src_ids.size(0)] == src_ids).all()
rank = dist.get_rank()
world_size = dist.get_world_size()
        # Note: src_ids and dst_ids here do not follow the bipartite-graph convention used in PGraph:
        # src_ids is the sending side (actually PGraph.dst_ids)
        # dst_ids is the receiving side (actually PGraph.src_ids)
self.src_size: int = src_ids.size(0)
self.dst_size: int = dst_ids.size(0)
        # dtype/device used for index tensors
ikw = dict(dtype=torch.long, device=src_ids.device)
        # collect the number of sender-side nodes in each partition
all_src_sizes: List[torch.Size] = [None] * world_size
dist.all_gather_object(all_src_sizes, src_ids.size(), group=group)
        # compute the total number of nodes
num_nodes = _all_reduce_num_nodes(src_ids, dst_ids, group=group)
        # mapping from dst_ids to local indices
imp = torch.empty(num_nodes, **ikw).fill_((2**62-1)*2+1)
imp[dst_ids] = torch.arange(self.dst_size, **ikw)
        # asynchronously fetch the nodes of the other partitions, overlapping some computation with I/O
all_src_ids: List[Tensor] = [None] * world_size
all_src_get = [None] * world_size
def fetch_src_ids(i: int):
if i == rank:
all_src_ids[i] = src_ids
else:
s = all_src_sizes[i]
all_src_ids[i] = torch.empty(s, **ikw)
all_src_get[i] = dist.broadcast(
all_src_ids[i], src=i, async_op=True, group=group)
self.forward_routes: List[Tensor] = []
self.backward_routes: List[Tensor] = []
        # xmp serves two purposes: computing the node intersection used to build the indices,
        # and mapping src_ids to local indices
xmp = torch.empty_like(imp)
for i in range(world_size):
            # prefetch the next partition's data
if i == 0:
fetch_src_ids(i)
if i + 1 < world_size:
fetch_src_ids(i + 1)
all_src_get[i].wait()
ids = all_src_ids[i]
            # compute the intersection to determine which part needs to be transferred
xmp.fill_(0)
xmp[ids] += 1
xmp[dst_ids] += 1
ind = torch.where(xmp > 1)[0]
# mapping from src_ids to local indices
xmp.fill_((2**62-1)*2+1)
xmp[ids] = torch.arange(ids.size(0), **ikw)
# local indices
src_ind = xmp[ind]
dst_ind = imp[ind]
# at this point only bw_route is final; fw_route still has to be sent to the partition that owns src_ids
fw_route = torch.vstack([src_ind, dst_ind])
bw_route = torch.vstack([dst_ind, src_ind])
self.forward_routes.append(fw_route)
self.backward_routes.append(bw_route)
# send fw_route to the partition that owns src_ids to build the final routing table
self.forward_routes = _fix_fw_routes(self.forward_routes, group=group)
# in the homogeneous (non-bipartite) case, add a self-loop for every node
if not bipartite:
rank_ind = torch.arange(src_ids.size(0), **ikw)
rank_route = torch.vstack([rank_ind, rank_ind])
self.forward_routes[rank] = rank_route
self.backward_routes[rank] = rank_route
dist.barrier()
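# At this point forward_routes[i] and backward_routes[i] are 2-row index tables
# for peer i: row 0 indexes this rank's own buffer and row 1 the matching
# positions on peer i. forward_a2a sends using forward_routes and scatter-adds
# the received chunks using backward_routes; backward_a2a swaps the two roles.
# In the non-bipartite case the rank-th entry is an identity route over the
# locally owned nodes, so local features are kept in place.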
def to(self, device):
self.forward_routes = [ro.to(device) for ro in self.forward_routes]
self.backward_routes = [ro.to(device) for ro in self.backward_routes]
return self
def _a2a_impl(self,
value: Tensor,
index: Optional[Tensor],
smask: Union[Tensor, int],
rmask: Union[Tensor, int],
send_routes: List[Tensor],
recv_routes: List[Tensor],
group: Optional[Any],
) -> Tuple[Tensor, Optional[Tensor]]:
bkw = dict(dtype=torch.bool, device=value.device)
ikw = dict(dtype=torch.long, device=value.device)
fkw = dict(dtype=value.dtype, device=value.device)
send_buf_size = smask.size(0) if isinstance(smask, Tensor) else smask
recv_buf_size = rmask.size(0) if isinstance(rmask, Tensor) else rmask
if index is None:
send_val_idx = [ro[0] for ro in send_routes]
recv_val_idx = [ro[0] for ro in recv_routes]
send_val_dat, recv_val_dat = [], []
for sidx, ridx in zip(send_val_idx, recv_val_idx):
s = (ridx.size(0),) + value.shape[1:]
send_val_dat.append(value[sidx])
recv_val_dat.append(torch.zeros(s, **fkw))
all_to_all(recv_val_dat, send_val_dat, group=group)
s = (recv_buf_size,) + value.shape[1:]
recv_value = torch.zeros(s, **fkw)
for idx, dat in zip(recv_val_idx, recv_val_dat):
recv_value[idx] += dat
return recv_value, None
else:
assert value.size(0) == index.size(0)
if not isinstance(smask, Tensor):
smask = torch.zeros(send_buf_size, **bkw).index_fill_(0, index, 1)
send_routes = [ro[:,smask[ro[0]]] for ro in send_routes]
send_sizes = torch.tensor([ro.size(1) for ro in send_routes], **ikw)
recv_sizes = torch.zeros_like(send_sizes)
dist.all_to_all_single(recv_sizes, send_sizes, group=group)
recv_sizes = recv_sizes.tolist()
imp = torch.empty(send_buf_size, **ikw).fill_((2**62-1)*2+1)
imp[index] = torch.arange(index.size(0), **ikw)
send_val_idx, recv_val_idx = [], []
send_val_dat, recv_val_dat = [], []
for i, s in enumerate(recv_sizes):
c = (s,) + value.shape[1:]
send_val_idx.append(send_routes[i][1])
recv_val_idx.append(torch.zeros(s, **ikw))
send_index = imp[send_routes[i][0]]
send_val_dat.append(value[send_index])
recv_val_dat.append(torch.zeros(c, **fkw))
all_to_all(recv_val_idx, send_val_idx, group=group)
all_to_all(recv_val_dat, send_val_dat, group=group)
rmask = rmask if isinstance(rmask, Tensor) else torch.zeros(recv_buf_size, **bkw)
for idx in recv_val_idx:
rmask[idx] = True
recv_index = torch.where(rmask)[0]
imp = torch.empty(recv_buf_size, **ikw).fill_((2**62-1)*2+1)
imp[recv_index] = torch.arange(recv_index.size(0), **ikw)
s = recv_index.shape[:1] + value.shape[1:]
recv_value = torch.zeros(s, **fkw)
for idx, dat in zip(recv_val_idx, recv_val_dat):
recv_value[imp[idx]] += dat
return recv_value, recv_index
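# _a2a_impl covers two cases. With index=None every route entry is exchanged and
# the received chunks are scatter-added into a dense buffer of size recv_buf_size.
# With an index (or boolean send mask) only the selected rows are packed, the
# per-peer counts are agreed on via all_to_all_single, and both the values and
# their destination positions are exchanged, so the result comes back compacted
# together with recv_index.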
def forward_a2a(self,
value: Tensor,
index: Optional[Tensor] = None,
group: Optional[Any] = None,
) -> Tuple[Tensor, Optional[Tensor]]:
if index is None or index.dtype == torch.long:
return self._a2a_impl(
value=value,
index=index,
smask=self.src_size,
rmask=self.dst_size,
send_routes=self.forward_routes,
recv_routes=self.backward_routes,
group=group,
)
else:
smask, index = index, torch.where(index)[0]
assert smask.size(0) == self.src_size
rmask = torch.zeros(self.dst_size, dtype=smask.dtype, device=smask.device)
value, index = self._a2a_impl(
value=value,
index=index,
smask=smask,
rmask=rmask,
send_routes=self.forward_routes,
recv_routes=self.backward_routes,
group=group,
)
return value, rmask
def backward_a2a(self,
value: Tensor,
index: Optional[Tensor] = None,
group: Optional[Any] = None,
) -> Tuple[Tensor, Optional[Tensor]]:
if index is None or index.dtype == torch.long:
return self._a2a_impl(
value=value,
index=index,
smask=self.dst_size,
rmask=self.src_size,
send_routes=self.backward_routes,
recv_routes=self.forward_routes,
group=group,
)
else:
smask, index = index, torch.where(index)[0]
assert smask.size(0) == self.dst_size
rmask = torch.zeros(self.src_size, dtype=smask.dtype, device=smask.device)
value, index = self._a2a_impl(
value=value,
index=index,
smask=smask,
rmask=rmask,
send_routes=self.backward_routes,
recv_routes=self.forward_routes,
group=group,
)
return value, rmask
def apply(self,
value: Tensor,
index: Optional[Tensor] = None,
) -> Tuple[Tensor, Optional[Tensor]]:
with self._a2a_manager():
return RouteFunction.apply(value, index)
@contextmanager
def _a2a_manager(self):
global _global_route_context
stacked_ctx = _global_route_context
try:
_global_route_context = RouteCtx(self)
yield _global_route_context
finally:
_global_route_context = stacked_ctx
class RouteCtx:
def __init__(self, route) -> None:
self.route = route
class RouteFunction(autograd.Function):
@staticmethod
def forward(
ctx: autograd.function.FunctionCtx,
value: Tensor,
index: Optional[Tensor],
):
route_ctx = get_global_route_context()
route: Route = route_ctx.route
value = value.detach()
if index is not None:
index = index.detach()
work = get_route_executor().async_forward_a2a(value, index, route)
recv_value, recv_index = work.get()
# if work.has_events():
# route.total_time_used += work.time_used()
ctx.route = route
ctx.save_for_backward(recv_index)
return recv_value, recv_index
@staticmethod
def backward(
ctx: autograd.function.FunctionCtx,
grad: Tensor,
_: None,
):
route: Route = ctx.route
index, = ctx.saved_tensors
grad = grad.detach()
work = get_route_executor().async_backward_a2a(grad, index, route)
recv_grad, recv_index = work.get()
# if work.has_events():
# route.total_time_used += work.time_used()
return recv_grad, recv_index
#### private functions
def _all_reduce_num_nodes(
src_ids: Tensor,
dst_ids: Tensor,
group: Any = None,
) -> int:
max_ids = torch.zeros(1, dtype=src_ids.dtype, device=src_ids.device)
max_ids = max_ids.max(src_ids.max()) if src_ids.numel() > 0 else max_ids
max_ids = max_ids.max(dst_ids.max()) if dst_ids.numel() > 0 else max_ids
dist.all_reduce(max_ids, op=dist.ReduceOp.MAX, group=group)
return max_ids.item() + 1
def _fix_fw_routes(tensors: List[Tensor], group: Any = None) -> List[Tensor]:
rank = dist.get_rank()
world_size = dist.get_world_size()
tensor_sizes: List[torch.Size] = [t.size() for t in tensors]
all_tensor_sizes: List[List[torch.Size]] = [None] * world_size
dist.all_gather_object(all_tensor_sizes, tensor_sizes, group=group)
new_tensors: List[Tensor] = []
for i in range(world_size):
s = all_tensor_sizes[i][rank]
t = torch.zeros(s).type_as(tensors[i])
new_tensors.append(t)
all_to_all(new_tensors, tensors, group=group)
return new_tensors
_global_route_context: Optional[RouteCtx] = None
def get_global_route_context() -> RouteCtx:
global _global_route_context
assert _global_route_context is not None
return _global_route_context
_THREAD_EXEC: Optional[RouteExecutor] = None
def get_route_executor() -> RouteExecutor:
global _THREAD_EXEC
if _THREAD_EXEC is None:
_THREAD_EXEC = RouteExecutor()
return _THREAD_EXEC
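# A minimal end-to-end sketch of the autograd path (assumptions: the default
# process group is initialized and src_ids/dst_ids/x are this partition's
# tensors; all names below are placeholders):
#
#   route = Route(src_ids, dst_ids, bipartite=False).to(x.device)
#   h, h_idx = route.apply(x)   # RouteFunction -> async_forward_a2a
#   h.sum().backward()          # gradients return through async_backward_a2a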
# import torch
# from torch import Tensor
# from typing import *
# class ShrinkData:
# no_idx: int = (2**62-1)*2+1
# def __init__(self,
# src_size: int,
# dst_size: int,
# dst_idx: Tensor,
# edge_index: Tensor,
# bipartite: bool = False,
# ) -> None:
# device = dst_idx.device
# tmp = torch.empty(max(src_size, dst_size), dtype=torch.bool, device=device)
# tmp.fill_(0)
# tmp.index_fill_(0, dst_idx, 1)
# edge_idx = torch.where(tmp[edge_index[1]])[0]
# edge_index = edge_index[:, edge_idx]
# if bipartite:
# tmp.fill_(0)
# tmp.index_fill_(0, edge_index[0], 1)
# src_idx = torch.where(tmp)[0]
# imp = torch.empty(max(src_size, dst_size), dtype=torch.long, device=device)
# imp[dst_idx] = torch.arange(dst_idx.size(0), dtype=torch.long, device=device)
# dst = imp[edge_index[1]]
# imp.fill_(self.no_idx)
# imp[src_idx] = torch.arange(src_idx.size(0), dtype=torch.long, device=device)
# src = imp[edge_index[0]]
# edge_index = torch.vstack([src, dst])
# else:
# tmp.index_fill_(0, edge_index[0], 1)
# tmp.index_fill_(0, dst_idx, 0)
# src_idx = torch.cat([dst_idx, torch.where(tmp)[0]], dim=0)
# imp = torch.empty(max(src_size, dst_size), dtype=torch.long, device=device)
# imp.fill_(self.no_idx)
# imp[src_idx] = torch.arange(src_idx.size(0), dtype=torch.long, device=device)
# edge_index = imp[edge_index.flatten()].view_as(edge_index)
# self.src_idx = src_idx
# self.dst_idx = dst_idx
# self.edge_idx = edge_idx
# self.edge_index = edge_index
# self._src_imp = imp[:src_size]
# def to(self, device):
# self.src_idx = self.src_idx.to(device)
# self.dst_idx = self.dst_idx.to(device)
# self.edge_idx = self.edge_idx.to(device)
# self.edge_index = self.edge_index.to(device)
# self._src_imp = self._src_imp.to(device)
# return self
# def shrink_src_val_and_idx(self, val: Tensor, idx: Tensor) -> Tuple[Tensor, Tensor]:
# idx = self._src_imp[idx]
# m = (idx != self.no_idx)
# return val[m], idx[m]
# @property
# def src_size(self) -> int:
# return self.src_idx.size(0)
# @property
# def dst_size(self) -> int:
# return self.dst_idx.size(0)
# @property
# def edge_size(self) -> int:
# return self.edge_idx.size(0)
# import torch
# import torch.nn as nn
# import torch.autograd as autograd
# from torch import Tensor
# from contextlib import contextmanager
# from typing import *
# class StraightContext:
# def __init__(self, this, g) -> None:
# self.this = this
# self.g = g
# class Straight(nn.Module):
# def __init__(self,
# num_nodes: int,
# num_samples: int,
# norm_kwargs: Optional[Dict[str, Any]] = None,
# beta: float = 1.0,
# prev: Optional[List[Any]] = None,
# ) -> None:
# super().__init__()
# assert num_samples <= num_nodes
# self.num_nodes = num_nodes
# self.num_samples = num_samples
# self.norm_kwargs = norm_kwargs or dict(p=2, dim=-1)
# self.beta = beta
# self.prev = prev
# self.register_buffer("last_w", torch.ones(num_nodes))
# self._next_idx = None
# self._next_shrink_helper = None
# self.reset_parameters()
# def reset_parameters(self):
# last_w = self.get_buffer("last_w")
# nn.init.constant_(last_w, 1.0)
# def forward(self,
# val: Tensor,
# idx: Optional[Tensor],
# g,
# ) -> Tensor:
# with self._manager(self, g):
# return StraightFunction.apply(val, idx)
# def pop_next_shrink_helper(self) -> Tuple[Optional[Tensor], Any]:
# if not self.training:
# return None, None
# next_idx = self._next_idx
# self._next_idx = None
# next_sh = self._next_shrink_helper
# self._next_shrink_helper = None
# return next_idx, next_sh
# def _sample_next(self) -> Tensor:
# w = self.get_buffer("last_w")
# if self._next_idx is None:
# if self.num_samples < w.size(0):
# self._next_idx = self.sample_impl(w)
# else:
# self._next_idx = torch.arange(w.size(0), dtype=torch.long, device=w.device)
# elif self.num_samples < self._next_idx.size(0):
# idx = self.sample_impl(w[self._next_idx])
# self._next_idx = self._next_idx[idx]
# return self._next_idx
# def sample_impl(self, w: Tensor) -> Tensor:
# w = w / w.sum()
# return torch.multinomial(w, num_samples=self.num_samples, replacement=False)
# # def multinomial(self,
# # num_samples: int,
# # replacement: bool = False,
# # ) -> Tensor:
# # w = self.get_buffer("last_w")
# # if num_samples <= 0:
# # return torch.arange(self.num_nodes, dtype=torch.long, device=w.device)
# # w = w / w.sum()
# # return torch.multinomial(w, num_samples=num_samples, replacement=replacement)
# # def multinomial_mask(self,
# # num_samples: int,
# # replacement: bool = False,
# # ) -> Tensor:
# # w = self.get_buffer("last_w")
# # if num_samples <= 0:
# # return torch.ones(self.num_nodes, dtype=torch.bool, device=w.device)
# # w = self.multinomial(num_samples, replacement)
# # m = torch.zeros(self.num_nodes, dtype=torch.bool, device=w.device)
# # m[w] = True
# # return m
# @contextmanager
# def _manager(self, this, g):
# global _global_straight_context
# stacked = _global_straight_context
# try:
# _global_straight_context = StraightContext(this=this, g=g)
# yield _global_straight_context
# finally:
# _global_straight_context = stacked
# class StraightFunction(autograd.Function):
# @staticmethod
# def forward(
# ctx: autograd.function.FunctionCtx,
# val: Tensor,
# idx: Optional[Tensor],
# ):
# from ..graph import DistGraph
# stx = _last_global_straight_context()
# this: Straight = stx.this
# g: DistGraph = stx.g
# last_w = this.get_buffer("last_w")
# if idx is None:
# assert val.size(0) == last_w.size(0)
# else:
# assert val.size(0) == idx.size(0)
# if this.training:
# ctx.this = this
# ctx.g = g
# ctx.save_for_backward(idx)
# return val
# @staticmethod
# def backward(
# ctx: autograd.function.FunctionCtx,
# grad: Tensor,
# ):
# from ..graph import DistGraph
# this: Straight = ctx.this
# g: DistGraph = ctx.g
# idx, = ctx.saved_tensors
# last_w = this.get_buffer("last_w")
# if this.beta != 1.0:
# last_w.mul_(this.beta)
# norm = grad.norm(**this.norm_kwargs)
# if idx is None:
# last_w[:] = norm
# else:
# last_w[idx] = norm
# if this.prev is not None or this._next_idx is None:
# from ..nn.convs.utils import ShrinkHelper
# dst_idx = this._sample_next()
# if this.prev is not None and this._next_idx is not None:
# this._next_shrink_helper = ShrinkHelper(g, dst_idx)
# src_idx = this._next_shrink_helper.src_idx
# work = g.route.gather_backward(
# src_idx, src_idx, async_op=False, return_idx=True)
# prev_dst_idx = work.get_idx()
# for p in this.prev:
# assert isinstance(p, Straight)
# p._next_idx = prev_dst_idx
# return grad, None
# #### private functions
# _global_straight_context: Optional[StraightContext] = None
# def _last_global_straight_context() -> StraightContext:
# global _global_straight_context
# assert _global_straight_context is not None
# return _global_straight_context
# if __name__ == "__main__":
# s = Straight(3, beta=1.1)
# x = torch.rand(3, 10).requires_grad_()
# m = torch.tensor([0, 1, 0], dtype=torch.bool)
# s(x).sum().backward()
# print(s.grad_norm)
# print(s.multinomial(2))
# s(x, m).sum().backward()
# print(s.grad_norm)
# print(s.multinomial(2))
# s(x[m], m).sum().backward()
# print(s.grad_norm)
# print(s.multinomial(2))
# print(s.multinomial_mask(2))
\ No newline at end of file
# import torch
# from torch import Tensor
# from typing import *
# from torch_scatter import scatter_sum
# from starrygl.core.route import Route
# def compute_in_degree(edge_index: Tensor, route: Route) -> Tensor:
# dst_size = route.src_size
# x = torch.ones(edge_index.size(1), dtype=torch.long, device=edge_index.device)
# in_deg = scatter_sum(x, edge_index[1], dim=0, dim_size=dst_size)
# in_deg, _ = route.forward_a2a(in_deg)
# return in_deg
# def compute_out_degree(edge_index: Tensor, route: Route) -> Tensor:
# src_size = route.dst_size
# x = torch.ones(edge_index.size(1), dtype=torch.long, device=edge_index.device)
# out_deg = scatter_sum(x, edge_index[0], dim=0, dim_size=src_size)
# out_deg, _ = route.backward_a2a(out_deg)
# out_deg, _ = route.forward_a2a(out_deg)
# return out_deg
# def compute_gcn_norm(edge_index: Tensor, route: Route) -> Tensor:
# in_deg = compute_in_degree(edge_index, route)
# out_deg = compute_out_degree(edge_index, route)
# a = in_deg[edge_index[0]].pow(-0.5)
# b = out_deg[edge_index[0]].pow(-0.5)
# x = a * b
# x[x.isinf()] = 0.0
# x[x.isnan()] = 0.0
# return x
# import torch
# import torch.nn as nn
# from torch import Tensor
# from ..graph import DistGraph
# from typing import *
# from .metrics import *
# def train_epoch(
# model: nn.Module,
# opt: torch.optim.Optimizer,
# g: DistGraph,
# mask: Optional[Tensor] = None,
# ) -> float:
# model.train()
# criterion = nn.CrossEntropyLoss()
# pred: Tensor = model(g)
# targ: Tensor = g.ndata["y"]
# if mask is not None:
# pred = pred[mask]
# targ = targ[mask]
# loss: Tensor = criterion(pred, targ)
# opt.zero_grad()
# loss.backward()
# opt.step()
# with torch.no_grad():
# train_loss = all_reduce_loss(loss, targ.size(0))
# train_acc = accuracy(pred.argmax(dim=-1), targ)
# return train_loss, train_acc
# @torch.no_grad()
# def eval_epoch(
# model: nn.Module,
# g: DistGraph,
# mask: Optional[Tensor] = None,
# ) -> Tuple[float, float]:
# model.eval()
# criterion = nn.CrossEntropyLoss()
# pred: Tensor = model(g)
# targ: Tensor = g.ndata["y"]
# if mask is not None:
# pred = pred[mask]
# targ = targ[mask]
# loss = criterion(pred, targ)
# eval_loss = all_reduce_loss(loss, targ.size(0))
# eval_acc = accuracy(pred.argmax(dim=-1), targ)
# return eval_loss, eval_acc
\ No newline at end of file
# import torch
# import torch.nn as nn
# from torch import Tensor
# from typing import *
# from contextlib import contextmanager
# # import torch_sparse
# from .ndata import NData
# from .edata import EData
# from .utils import init_local_edge_index
# from ..core import MessageCache, Route
# class DistGraph:
# def __init__(self,
# ids: Tensor,
# edge_index: Tensor,
# num_features: int,
# num_layers: int,
# cache_device: str = "cpu",
# **args: Dict[str, Any],
# ):
# # build local_edge_index
# dst_ids = ids
# src_ids, local_edge_index = init_local_edge_index(
# dst_ids=dst_ids,
# edge_index=edge_index,
# )
# self._src_ids = src_ids
# self._dst_ids = dst_ids
# self._message_cache = MessageCache(
# src_ids=src_ids,
# dst_ids=dst_ids,
# edge_index=local_edge_index,
# num_features=num_features,
# num_layers=num_layers,
# cache_device=cache_device,
# bipartite=False,
# )
# # node's attributes
# self.ndata = NData(
# src_size=src_ids.size(0),
# dst_size=dst_ids.size(0),
# )
# # edge's attributes
# self.edata = EData(
# edge_size=local_edge_index.size(1),
# )
# # graph's attributes
# self.args = dict(args)
# def to(self, device):
# self._message_cache.to(device)
# return self
# def cache_data_to(self, device):
# self._message_cache.cached_data_to(device)
# @property
# def cache(self) -> MessageCache:
# return self._message_cache
# @property
# def route(self) -> Route:
# return self._message_cache.route
# @property
# def device(self) -> torch.device:
# return self.edge_index.device
# @property
# def edge_index(self) -> Tensor:
# return self._message_cache.edge_index
# @property
# def src_ids(self) -> Tensor:
# if self._src_ids.device != self.device:
# self._src_ids = self._src_ids.to(self.device)
# return self._src_ids
# @property
# def src_size(self) -> int:
# return self._src_ids.size(0)
# @property
# def dst_ids(self) -> Tensor:
# if self._dst_ids.device != self.device:
# self._dst_ids = self._dst_ids.to(self.device)
# return self._dst_ids
# @property
# def dst_size(self) -> int:
# return self._dst_ids.size(0)
# @contextmanager
# def scoped_manager(self):
# stacked_ndata = self.ndata
# stacked_edata = self.edata
# try:
# self.ndata = NData(
# p=stacked_ndata,
# )
# self.edata = EData(
# p=stacked_edata,
# )
# yield self
# finally:
# self.ndata = stacked_ndata
# self.edata = stacked_edata
# # def permute_edge_(self):
# # perm = self._local_edge_index[1].argsort()
# # self._local_edge_index = self._local_edge_index[:,perm]
# # self._local_edge_ptr = torch.ops.torch_sparse.ind2ptr(self._local_edge_index[1], self.dst_size)
# # self.edata.permute_(perm)
# # @torch.no_grad()
# # def shrink(self,
# # src_mask: Optional[Tensor] = None,
# # dst_mask: Optional[Tensor] = None,
# # ):
# # edge_index = self.edge_index
# # device = edge_index.device
# # ikw = dict(dtype=torch.long, device=device)
# # bkw = dict(dtype=torch.bool, device=device)
# # if src_mask is None and dst_mask is None:
# # return self
# # else:
# # if dst_mask is None:
# # dst_mask = torch.zeros(self.ndata.dst_size, **bkw)
# # else:
# # assert dst_mask.size(0) == self.ndata.dst_size
# # dst_mask = dst_mask.clone()
# # if src_mask is not None:
# # m = src_mask[edge_index[0]]
# # m = edge_index[1][m]
# # dst_mask[m] = True
# # edge_mask = dst_mask[edge_index[1]]
# # edge_index = edge_index[:,edge_mask]
# # src_mask = torch.zeros(self.ndata.src_size, **bkw)
# # src_mask[:self.ndata.dst_size] = dst_mask
# # src_mask[edge_index[0]] = True
# # dst_mask = src_mask[:self.ndata.dst_size]
# # renumber edge_index
# # imp = torch.empty(self.ndata.src_size, **ikw).fill_((2**62-1)*2+1)
# # idx = torch.where(src_mask)[0]
# # imp[idx] = torch.arange(idx.size(0), **ikw)
# # edge_index = imp[edge_index.flatten()].view_as(edge_index)
# # # assert dst_mask.count_nonzero() > edge_index[1].max().item()
# # ndata = NData(
# # src_size=None,
# # dst_size=None,
# # p=self.ndata,
# # src_mask=src_mask,
# # dst_mask=dst_mask,
# # )
# # edata = EData(
# # edge_size=None,
# # p=self.edata,
# # edge_mask=edge_mask,
# # )
# # edata[EID] = edge_index
# # return DistGraph(self.args, ndata, edata, self.route)
# # class ShrinkGraph:
# # def __init__(self,
# # g: DistGraph,
# # src_mask: Optional[Tensor] = None,
# # dst_mask: Optional[Tensor] = None,
# # ) -> None:
# # device = g.edge_index.device
# # ikw = dict(dtype=torch.long, device=device)
# # bkw = dict(dtype=torch.bool, device=device)
# # if src_mask is None and dst_mask is None:
# # if neither src_mask nor dst_mask is given,
# # ShrinkGraph is a copy of DistGraph
# # self.src_ids = g.sr
# # self.dst_size = g.dst_size
# # self.src_mask = torch.ones(self.src_size, **bkw)
# # self.dst_mask_size = g.dst_size
# # self.edge_index = g.edge_index
# # self.pgraph = g
# # else:
# # if src_mask is not None:
# # tmp_mask = torch.zeros(g.dst_size, **bkw)
# # compute directly activated edges
# # m = src_mask[g.edge_index[0]]
# # m = g.edge_index[1][m]
# # compute directly activated dst_ids
# # tmp_mask[m] = True
# # merge with the already activated dst_ids
# # if dst_mask is not None:
# # tmp_mask |= dst_mask
# # dst_mask = tmp_mask
# # compute indirectly activated edges
# # edge_mask = dst_mask[g.edge_index[1]]
# # edge_index = g.edge_index[:,edge_mask]
# # compute indirectly activated src_ids
# # src_mask = torch.zeros(g.src_size, **bkw)
# # src_mask[edge_index[0]] = True
# # src_mask[:g.dst_size] |= dst_mask
# # self.src_ids = g.src_ids[src_mask]
# # self.dst_size = src_mask[:g.dst_size].count_nonzero().item()
# # self.src_mask = src_mask
# # self.dst_mask_size = g.dst_size
# # renumber edge_index
# # imp = torch.empty(g.src_size, **ikw).fill_((2**62-1)*2+1)
# # idx = torch.where(src_mask)[0]
# # imp[idx] = torch.arange(idx.size(0), **ikw)
# # edge_index = imp[edge_index.flatten()].view_as(edge_index)
# # self.edge_index = edge_index
# # self.pgraph = g
# # self.ndata = ShrinkData(self.dst_mask, g.ndata)
# # self.edata = ShrinkData(self.edge_mask, g.edata)
\ No newline at end of file
# import torch
# from torch import Tensor
# from typing import *
# class EData:
# def __init__(self,
# edge_size: Optional[int] = None,
# p = None,
# ) -> None:
# if p is None:
# assert edge_size is not None
# self.edge_size = edge_size
# else:
# assert edge_size is None
# self.edge_size = p.edge_size
# self.prev_data = p
# self.data: Dict[str, Tensor] = {}
# def __getitem__(self, name: str) -> Tensor:
# t = self.get(name)
# if t is None:
# raise ValueError(f"not found '{name}' in data")
# return t
# def __setitem__(self, name: str, tensor: Tensor):
# if not isinstance(tensor, Tensor):
# raise ValueError(f"the second parameter's type must be Tensor")
# if tensor.size(0) == self.edge_size:
# self.data[name] = tensor
# else:
# raise ValueError(f"tensor's shape must match the edge_size")
# def __delitem__(self, name: str) -> None:
# self.pop(name)
# def get(self, name: str) -> Optional[Tensor]:
# p, t = self, None
# while p is not None:
# t = p.data.get(name)
# if t is not None:
# break
# p = p.prev_data
# return t
# def pop(self, name: str) -> Tensor:
# return self.data.pop(name)
# def permute_(self, perm: Tensor):
# p = self
# while p is not None:
# for key in list(p.data.keys()):
# val = p.data.get(key)
# p.data[key] = val[perm]
# p = p.prev_data
\ No newline at end of file
# import torch
# from torch import Tensor
# from typing import *
# class NData:
# def __init__(self,
# src_size: Optional[int] = None,
# dst_size: Optional[int] = None,
# p = None,
# ) -> None:
# if p is None:
# assert src_size is not None
# assert dst_size is not None
# self.src_size = src_size
# self.dst_size = dst_size
# else:
# assert src_size is None
# assert dst_size is None
# self.src_size = p.src_size
# self.dst_size = p.dst_size
# self.prev_data = p
# self.data: Dict[str, Tensor] = {}
# def __getitem__(self, name: str) -> Tensor:
# t = self.get(name)
# if t is None:
# raise ValueError(f"not found '{name}' in data")
# return t
# def __setitem__(self, name: str, tensor: Tensor):
# if not isinstance(tensor, Tensor):
# raise ValueError("the second parameter's type must be Tensor")
# if tensor.size(0) == self.src_size:
# self.data[name] = tensor
# elif tensor.size(0) == self.dst_size:
# self.data[name] = tensor
# else:
# raise ValueError("tensor's shape must match the src_size or dst_size")
# def __delitem__(self, name: str) -> None:
# self.pop(name)
# def get(self, name: str) -> Optional[Tensor]:
# p, t = self, None
# while p is not None:
# t = p.data.get(name)
# if t is not None:
# break
# p = p.prev_data
# return t
# def pop(self, name: str) -> Tensor:
# return self.data.pop(name)
# def permute_(self, perm: Tensor):
# p = self
# while p is not None:
# for key in list(p.data.keys()):
# val = p.data.get(key)
# p.data[key] = val[perm]
# p = p.prev_data
# def get_type(self, key: Union[str, Tensor]):
# if isinstance(key, Tensor):
# if key.size(0) == self.src_size:
# return "src"
# elif key.size(0) == self.dst_size:
# return "dst"
# else:
# raise RuntimeError
# t = self.__getitem__(key)
# if t.size(0) == self.src_size:
# return "src"
# elif t.size(0) == self.dst_size:
# return "dst"
# else:
# raise RuntimeError
\ No newline at end of file
{
"cells": [
{
"cell_type": "code",
"execution_count": 383,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"import torch.autograd as autograd\n",
"\n",
"from torch import Tensor\n",
"from typing import *\n",
"\n",
"from torch_scatter import scatter\n",
"from torch_geometric.utils import softmax\n",
"\n",
"import triton\n",
"import triton.language as tl\n",
"\n",
"from copy import deepcopy"
]
},
{
"cell_type": "code",
"execution_count": 319,
"metadata": {},
"outputs": [],
"source": [
"# class SoftmaxSumFunction(autograd.Function):\n",
"# @staticmethod\n",
"# def forward(\n",
"# ctx: autograd.function.FunctionCtx,\n",
"# x: Tensor,\n",
"# a: Tensor,\n",
"# _x: Tensor,\n",
"# index: Tensor,\n",
"# mask: Tensor,\n",
"# exp_a: Tensor,\n",
"# sum_exp_a: Tensor,\n",
"# sum_exp_ah: Tensor,\n",
"# training: bool,\n",
"# ):\n",
"# dlt_exp_ah = exp_a[mask, None] * (x - _x)\n",
"# sum_exp_ah = scatter(dlt_exp_ah, index, dim=0, out=sum_exp_ah.clone(), reduce=\"sum\")\n",
"# out = sum_exp_ah / sum_exp_a[:, None]\n",
"# if training:\n",
"# w = exp_a[mask] / sum_exp_a.index_select(0, index)\n",
"# ctx.save_for_backward(out, x, w, index)\n",
"# return out\n",
" \n",
"# @staticmethod\n",
"# def backward(\n",
"# ctx: autograd.function.FunctionCtx,\n",
"# grad: Tensor,\n",
"# ):\n",
"# out, x, w, index = ctx.saved_tensors\n",
"# g = grad.index_select(0, index)\n",
"# o = out.index_select(0, index)\n",
"\n",
"# grad_x = g * w[:, None]\n",
"# grad_a = (g * (x - o)).sum(dim=-1) * w\n",
"# return grad_x, grad_a, None, None, None, None, None, None, None"
]
},
{
"cell_type": "code",
"execution_count": 320,
"metadata": {},
"outputs": [],
"source": [
"# class SoftmaxSum(nn.Module):\n",
"# def __init__(self,\n",
"# num_nodes: int,\n",
"# num_edges: int,\n",
"# num_features: int,\n",
"# dtype: Optional[torch.dtype] = None,\n",
"# ) -> None:\n",
"# super().__init__()\n",
"# self.num_nodes = num_nodes\n",
"# self.num_edges = num_edges\n",
"\n",
"# self.register_buffer(\"max_hist_a\", torch.zeros(num_nodes, dtype=dtype))\n",
"# self.register_buffer(\"exp_hist_a\", torch.zeros(num_edges, dtype=dtype))\n",
"# self.register_buffer(\"sum_exp_a\", torch.zeros(num_nodes, dtype=dtype))\n",
"# self.register_buffer(\"sum_exp_ah\", torch.zeros(num_nodes, num_features, dtype=dtype))\n",
" \n",
"# @torch.no_grad()\n",
"# def update_snapshot(self, x: Tensor, a: Tensor, index: Tensor):\n",
"# max_a = scatter(a, index, dim=0, dim_size=self.num_nodes, reduce=\"max\")\n",
"# exp_a = (a - max_a.index_select(0, index)).exp_()\n",
"# self.get_buffer(\"max_hist_a\")[:] = max_a\n",
"# self.get_buffer(\"exp_hist_a\")[:] = exp_a\n",
"\n",
"# sum_exp_a = scatter(exp_a, index, dim=0, dim_size=self.num_nodes, reduce=\"sum\") + 1e-10\n",
"# sum_exp_ah = scatter(exp_a[:,None] * x, index, dim=0, dim_size=self.num_nodes, reduce=\"sum\")\n",
"# self.get_buffer(\"sum_exp_a\")[:] = sum_exp_a\n",
"# self.get_buffer(\"sum_exp_ah\")[:] = sum_exp_ah\n",
"# return self\n",
" \n",
"# @torch.no_grad()\n",
"# def update_attention(self, x: Tensor, a: Tensor, _x: Tensor, index: Tensor, mask: Tensor):\n",
"# exp_a = (a - self.get_buffer(\"max_hist_a\").index_select(0, index)).exp_()\n",
"# exp_ah = exp_a[:, None] * x\n",
"\n",
"# exp_hist_a = self.get_buffer(\"exp_hist_a\")[mask]\n",
"# exp_hist_ah = exp_hist_a[:, None] * _x\n",
"\n",
"# self.get_buffer(\"exp_hist_a\")[mask] = exp_a\n",
"# scatter(exp_a - exp_hist_a, index, dim=0, out=self.get_buffer(\"sum_exp_a\"), reduce=\"sum\")\n",
"# scatter(exp_ah - exp_hist_ah, index, dim=0, out=self.get_buffer(\"sum_exp_ah\"), reduce=\"sum\")\n",
"# return self\n",
" \n",
"# def forward(self, x: Tensor, a: Tensor, _x: Tensor, index: Tensor, mask: Tensor):\n",
"# return SoftmaxSumFunction.apply(x, a, _x, index, mask,\n",
"# self.get_buffer(\"exp_hist_a\"),\n",
"# self.get_buffer(\"sum_exp_a\"),\n",
"# self.get_buffer(\"sum_exp_ah\"),\n",
"# self.training)"
]
},
{
"cell_type": "code",
"execution_count": 321,
"metadata": {},
"outputs": [],
"source": [
"# num_nodes = 200\n",
"# num_edges = 100\n",
"# num_features = 8\n",
"# for i in range(10):\n",
"# x = torch.randn(num_edges, num_features)\n",
"# a = torch.randn(num_edges)\n",
"# index = torch.randint(num_nodes, size=(num_edges,), dtype=torch.long)\n",
"# mask = torch.ones(num_edges, dtype=torch.bool)\n",
"\n",
"# ss = SoftmaxSum(num_nodes, num_edges, num_features)\n",
"# ss.update_snapshot(x, a, index)\n",
"# ss.update_attention(x, a, x, index, mask)\n",
"\n",
"# x.requires_grad_()\n",
"# a.requires_grad_()\n",
"# ss(x, a, x, index, mask).sum().backward()\n",
"# grad_x0 = x.grad.clone()\n",
"# grad_a0 = a.grad.clone()\n",
"\n",
"# x.grad = None\n",
"# a.grad = None\n",
"# t = softmax(a, index, num_nodes=num_nodes)[:,None] * x\n",
"# scatter(t, index, dim=0, dim_size=num_nodes, reduce=\"sum\").sum().backward()\n",
"\n",
"# # print(torch.norm(grad_x0 - x.grad), torch.norm(grad_a0 - a.grad))\n",
"# # print(torch.abs(grad_x0 - x.grad).max(), torch.abs(grad_a0 - a.grad).max())\n",
"# # print(grad_x0)\n",
"# # print(grad_a0)\n",
"# print(grad_a0 - a.grad)"
]
},
{
"cell_type": "code",
"execution_count": 355,
"metadata": {},
"outputs": [],
"source": [
"class SoftmaxSumFunction(autograd.Function):\n",
" @staticmethod\n",
" def forward(\n",
" ctx: autograd.function.FunctionCtx,\n",
" x1: Tensor,\n",
" a1: Tensor,\n",
" x0: Tensor,\n",
" a0: Tensor,\n",
" index: Tensor,\n",
" max_a: Tensor,\n",
" sum_exp_a: Tensor,\n",
" sum_exp_ah: Tensor,\n",
" training: bool,\n",
" ):\n",
" exp_a = (a0 - max_a.index_select(0, index)).exp_()\n",
" dlt_exp_ah = exp_a[:, :, None] * (x1 - x0)\n",
" sum_exp_ah = scatter(dlt_exp_ah, index, dim=0, out=sum_exp_ah.clone(), reduce=\"sum\")\n",
" out = sum_exp_ah / sum_exp_a[:, :, None]\n",
" if training:\n",
" w = exp_a / sum_exp_a.index_select(0, index)\n",
" ctx.save_for_backward(out, x1, w, index)\n",
" return out\n",
" \n",
" @staticmethod\n",
" def backward(\n",
" ctx: autograd.function.FunctionCtx,\n",
" grad: Tensor,\n",
" ):\n",
" out, x, w, index = ctx.saved_tensors\n",
" g = grad.index_select(0, index)\n",
" o = out.index_select(0, index)\n",
"\n",
" grad_x = g * w[:, :, None]\n",
" grad_a = (g * (x - o)).sum(dim=-1) * w\n",
" return grad_x, grad_a, None, None, None, None, None, None, None\n",
"\n",
"class SoftmaxSum(nn.Module):\n",
" def __init__(self,\n",
" num_nodes: int,\n",
" num_edges: int,\n",
" num_heads: int,\n",
" num_features: int,\n",
" dtype: Optional[torch.dtype] = None,\n",
" ) -> None:\n",
" super().__init__()\n",
" self.num_nodes = num_nodes\n",
" self.num_edges = num_edges\n",
" self.num_heads = num_heads\n",
"\n",
" self.register_buffer(\"max_a\", torch.zeros(num_nodes, num_heads, dtype=dtype))\n",
" self.register_buffer(\"sum_exp_a\", torch.zeros(num_nodes, num_heads, dtype=dtype))\n",
" self.register_buffer(\"sum_exp_ah\", torch.zeros(num_nodes, num_heads, num_features, dtype=dtype))\n",
" \n",
" @torch.no_grad()\n",
" def update_snapshot(self, x: Tensor, a: Tensor, index: Tensor):\n",
" max_a = scatter(a, index, dim=0, dim_size=self.num_nodes, reduce=\"max\")\n",
" exp_a = (a - max_a.index_select(0, index)).exp_()\n",
" self.get_buffer(\"max_a\")[:] = max_a\n",
"\n",
" sum_exp_a = scatter(exp_a, index, dim=0, dim_size=self.num_nodes, reduce=\"sum\") + 1e-10\n",
" sum_exp_ah = scatter(exp_a[:, :, None] * x, index, dim=0, dim_size=self.num_nodes, reduce=\"sum\")\n",
" self.get_buffer(\"sum_exp_a\")[:] = sum_exp_a\n",
" self.get_buffer(\"sum_exp_ah\")[:] = sum_exp_ah\n",
" return self\n",
" \n",
" @torch.no_grad()\n",
" def update_attention(self, x1: Tensor, a1: Tensor, x0: Tensor, a0: Tensor, index: Tensor):\n",
" max_hist_a = self.get_buffer(\"max_a\").index_select(0, index)\n",
" exp_a0 = (a0 - max_hist_a).exp_()\n",
" exp_a1 = (a1 - max_hist_a).exp_()\n",
" scatter(exp_a1 - exp_a0, index, dim=0, out=self.get_buffer(\"sum_exp_a\"), reduce=\"sum\")\n",
"\n",
" exp_ah0 = exp_a0[:, :, None] * x0\n",
" exp_ah1 = exp_a1[:, :, None] * x1\n",
" scatter(exp_ah1 - exp_ah0, index, dim=0, out=self.get_buffer(\"sum_exp_ah\"), reduce=\"sum\")\n",
" return self\n",
" \n",
" def forward(self, x1: Tensor, a1: Tensor, x0: Tensor, a0: Tensor, index: Tensor) -> Tensor:\n",
" return SoftmaxSumFunction.apply(x1, a1, x0, a0, index,\n",
" self.get_buffer(\"max_a\"),\n",
" self.get_buffer(\"sum_exp_a\"),\n",
" self.get_buffer(\"sum_exp_ah\"),\n",
" self.training)"
]
},
{
"cell_type": "code",
"execution_count": 359,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor(0.) tensor(6.2413e-06)\n",
"tensor([[-2.1367, -3.5242],\n",
" [ 0.0000, 0.0000],\n",
" [ 0.0000, 0.0000],\n",
" [ 0.0000, 0.0000],\n",
" [ 0.0000, 0.0000],\n",
" [ 0.0000, 0.0000],\n",
" [ 0.0000, 0.0000],\n",
" [ 0.8726, -6.0982]])\n",
"tensor([[-2.1367, -3.5242],\n",
" [ 0.0000, 0.0000],\n",
" [ 0.0000, 0.0000],\n",
" [ 0.0000, 0.0000],\n",
" [ 0.0000, 0.0000],\n",
" [ 0.0000, 0.0000],\n",
" [ 0.0000, 0.0000],\n",
" [ 0.8726, -6.0982]])\n",
"tensor(0.) tensor(7.5973e-06)\n",
"tensor([[ 4.1189, -1.3866],\n",
" [ 6.2586, -4.8426],\n",
" [ 0.0000, 0.0000],\n",
" [ 2.7955, 8.4197],\n",
" [ 0.0000, 0.0000],\n",
" [ 0.0000, 0.0000],\n",
" [ 0.0000, 0.0000],\n",
" [ 0.0000, 0.0000]])\n",
"tensor([[ 4.1189, -1.3866],\n",
" [ 6.2586, -4.8426],\n",
" [ 0.0000, 0.0000],\n",
" [ 2.7955, 8.4197],\n",
" [ 0.0000, 0.0000],\n",
" [ 0.0000, 0.0000],\n",
" [ 0.0000, 0.0000],\n",
" [ 0.0000, 0.0000]])\n",
"tensor(0.) tensor(6.0887e-06)\n",
"tensor([[ 0.0000, 0.0000],\n",
" [-1.7955, 4.7750],\n",
" [ 0.0000, 0.0000],\n",
" [ 0.0000, 0.0000],\n",
" [ 0.0000, 0.0000],\n",
" [-3.7791, -3.4205],\n",
" [ 3.1130, 3.6583],\n",
" [ 0.0000, 0.0000]])\n",
"tensor([[ 0.0000, 0.0000],\n",
" [-1.7955, 4.7750],\n",
" [ 0.0000, 0.0000],\n",
" [ 0.0000, 0.0000],\n",
" [ 0.0000, 0.0000],\n",
" [-3.7791, -3.4205],\n",
" [ 3.1130, 3.6583],\n",
" [ 0.0000, 0.0000]])\n",
"tensor(0.) tensor(5.9715e-06)\n",
"tensor([[ 0.0000, 0.0000],\n",
" [ 0.0000, 0.0000],\n",
" [ 9.4426, -2.2261],\n",
" [ 0.0000, 0.0000],\n",
" [ 0.0000, 0.0000],\n",
" [ 0.0000, 0.0000],\n",
" [ 2.7483, 3.0900],\n",
" [-9.4426, 2.2261]])\n",
"tensor([[ 0.0000, 0.0000],\n",
" [ 0.0000, 0.0000],\n",
" [ 9.4426, -2.2261],\n",
" [ 0.0000, 0.0000],\n",
" [ 0.0000, 0.0000],\n",
" [ 0.0000, 0.0000],\n",
" [ 2.7483, 3.0900],\n",
" [-9.4426, 2.2261]])\n",
"tensor(0.) tensor(7.1696e-06)\n",
"tensor([[ 0.0000, 0.0000],\n",
" [ 0.0000, 0.0000],\n",
" [ 2.2263, 0.6597],\n",
" [ 0.0000, 0.0000],\n",
" [ 0.0000, 0.0000],\n",
" [-2.2893, -4.1325],\n",
" [-3.9234, -3.9451],\n",
" [ 0.0000, 0.0000]])\n",
"tensor([[ 0.0000, 0.0000],\n",
" [ 0.0000, 0.0000],\n",
" [ 2.2263, 0.6597],\n",
" [ 0.0000, 0.0000],\n",
" [ 0.0000, 0.0000],\n",
" [-2.2893, -4.1325],\n",
" [-3.9234, -3.9451],\n",
" [ 0.0000, 0.0000]])\n",
"tensor(0.) tensor(6.6798e-06)\n",
"tensor([[ 0.0000, 0.0000],\n",
" [ 0.0000, 0.0000],\n",
" [ 0.0000, 0.0000],\n",
" [ 1.8056, 3.9442],\n",
" [ 0.0000, 0.0000],\n",
" [ 0.0000, 0.0000],\n",
" [ 0.0000, 0.0000],\n",
" [-2.4050, -5.5951]])\n",
"tensor([[ 0.0000, 0.0000],\n",
" [ 0.0000, 0.0000],\n",
" [ 0.0000, 0.0000],\n",
" [ 1.8056, 3.9442],\n",
" [ 0.0000, 0.0000],\n",
" [ 0.0000, 0.0000],\n",
" [ 0.0000, 0.0000],\n",
" [-2.4050, -5.5951]])\n",
"tensor(0.) tensor(4.7071e-06)\n",
"tensor([[ 0.0000, 0.0000],\n",
" [ -3.6159, -10.4370],\n",
" [ 0.0000, 0.0000],\n",
" [ 0.0000, 0.0000],\n",
" [ 0.0000, 0.0000],\n",
" [ 0.0000, 0.0000],\n",
" [ 2.1697, 4.4867],\n",
" [ 0.8153, 9.0103]])\n",
"tensor([[ 0.0000, 0.0000],\n",
" [ -3.6159, -10.4370],\n",
" [ 0.0000, 0.0000],\n",
" [ 0.0000, 0.0000],\n",
" [ 0.0000, 0.0000],\n",
" [ 0.0000, 0.0000],\n",
" [ 2.1697, 4.4867],\n",
" [ 0.8153, 9.0103]])\n",
"tensor(0.) tensor(6.2858e-06)\n",
"tensor([[ 0.0000, 0.0000],\n",
" [ 0.0000, 0.0000],\n",
" [ 0.0000, 0.0000],\n",
" [ 0.0000, 0.0000],\n",
" [ 5.5491, -3.5294],\n",
" [ 2.2001, -5.5199],\n",
" [ 1.6753, 1.5286],\n",
" [ 5.7522, -2.5286]])\n",
"tensor([[ 0.0000, 0.0000],\n",
" [ 0.0000, 0.0000],\n",
" [ 0.0000, 0.0000],\n",
" [ 0.0000, 0.0000],\n",
" [ 5.5491, -3.5294],\n",
" [ 2.2001, -5.5199],\n",
" [ 1.6753, 1.5286],\n",
" [ 5.7522, -2.5286]])\n",
"tensor(0.) tensor(5.8452e-06)\n",
"tensor([[ 0.0000, 0.0000],\n",
" [ 0.0000, 0.0000],\n",
" [ 0.0000, 0.0000],\n",
" [ 0.0000, 0.0000],\n",
" [-0.3675, 3.0714],\n",
" [ 0.4008, -9.1039],\n",
" [ 0.0000, 0.0000],\n",
" [ 0.0000, 0.0000]])\n",
"tensor([[ 0.0000, 0.0000],\n",
" [ 0.0000, 0.0000],\n",
" [ 0.0000, 0.0000],\n",
" [ 0.0000, 0.0000],\n",
" [-0.3675, 3.0714],\n",
" [ 0.4008, -9.1039],\n",
" [ 0.0000, 0.0000],\n",
" [ 0.0000, 0.0000]])\n",
"tensor(0.) tensor(6.5390e-06)\n",
"tensor([[ 0.0000, 0.0000],\n",
" [ 0.0000, 0.0000],\n",
" [ 4.4908, 0.8768],\n",
" [ 0.1810, 2.1797],\n",
" [ 0.0000, 0.0000],\n",
" [ 0.0000, 0.0000],\n",
" [ 0.0000, 0.0000],\n",
" [-4.4908, -0.8768]])\n",
"tensor([[ 0.0000, 0.0000],\n",
" [ 0.0000, 0.0000],\n",
" [ 4.4908, 0.8768],\n",
" [ 0.1810, 2.1797],\n",
" [ 0.0000, 0.0000],\n",
" [ 0.0000, 0.0000],\n",
" [ 0.0000, 0.0000],\n",
" [-4.4908, -0.8768]])\n"
]
}
],
"source": [
"num_nodes = 200\n",
"num_edges = 100\n",
"num_heads = 2\n",
"num_features = 256\n",
"for i in range(10):\n",
" x = torch.randn(num_edges, num_heads, num_features)\n",
" a = torch.randn(num_edges, num_heads)\n",
" index = torch.randint(num_nodes, size=(num_edges,), dtype=torch.long)\n",
"\n",
" ss = SoftmaxSum(num_nodes, num_edges, num_heads, num_features)\n",
" ss.update_snapshot(x, a, index)\n",
" ss.update_attention(x, a, x, a, index)\n",
"\n",
" x.requires_grad_()\n",
" a.requires_grad_()\n",
" ss(x, a, x, a, index).sum().backward()\n",
" grad_x0 = x.grad.clone()\n",
" grad_a0 = a.grad.clone()\n",
"\n",
" x.grad = None\n",
" a.grad = None\n",
" t = softmax(a, index, num_nodes=num_nodes)[:, :, None] * x\n",
" scatter(t, index, dim=0, dim_size=num_nodes, reduce=\"sum\").sum().backward()\n",
"\n",
" print(torch.norm(grad_x0 - x.grad), torch.norm(grad_a0 - a.grad))\n",
" # print(torch.abs(grad_x0 - x.grad).max(), torch.abs(grad_a0 - a.grad).max())\n",
" # print(grad_x0)\n",
" print(grad_a0[:8, :])\n",
" print(a.grad[:8, :])\n",
" # print(grad_a0 - a.grad)"
]
},
{
"cell_type": "code",
"execution_count": 399,
"metadata": {},
"outputs": [],
"source": [
"class GATConv(nn.Module):\n",
" def __init__(self,\n",
" in_channels: int,\n",
" out_channels: int,\n",
" heads: int = 1,\n",
" concat: bool = False,\n",
" negative_slope: float = 0.2,\n",
" bias: bool = True,\n",
" **kwargs\n",
" ) -> None:\n",
" super().__init__()\n",
" \n",
" self.in_channels = in_channels\n",
" self.out_channels = out_channels\n",
" self.heads = heads\n",
" self.concat = concat\n",
" self.negative_slope = negative_slope\n",
" self._softmax_sum_list = [None]\n",
" \n",
" self.weight = nn.Parameter(torch.Tensor(in_channels, heads * out_channels))\n",
" self.att_src = nn.Parameter(torch.Tensor(1, heads, out_channels))\n",
" self.att_dst = nn.Parameter(torch.Tensor(1, heads, out_channels))\n",
" \n",
" if bias and concat:\n",
" self.bias = nn.Parameter(torch.Tensor(heads * out_channels))\n",
" elif bias and not concat:\n",
" self.bias = nn.Parameter(torch.Tensor(out_channels))\n",
" else:\n",
" self.bias = None\n",
"\n",
" self.reset_parameters()\n",
" \n",
" def reset_parameters(self):\n",
" nn.init.xavier_normal_(self.weight)\n",
" nn.init.xavier_normal_(self.att_src)\n",
" nn.init.xavier_normal_(self.att_dst)\n",
" if self.bias is not None:\n",
" nn.init.zeros_(self.bias)\n",
" \n",
" @property\n",
" def softmax_sum(self) -> Optional[SoftmaxSum]:\n",
" return self._softmax_sum_list[0]\n",
" \n",
" def set_softmax_sum(self, m: Optional[SoftmaxSum]):\n",
" self._softmax_sum_list = [m]\n",
" return self\n",
" \n",
" def train(self, mode: bool = True):\n",
" if self.softmax_sum is not None:\n",
" self.softmax_sum.train(mode)\n",
" return super().train(mode)\n",
" \n",
" def update_softmax_snapshot(self,\n",
" x: Tensor,\n",
" edge_index: Tensor,\n",
" snap_net: nn.Module,\n",
" ):\n",
" snap_net: GATConv = snap_net.eval()\n",
" with torch.no_grad():\n",
" x, a = snap_net._att_x(x, edge_index)\n",
" self.softmax_sum.update_snapshot(x, a, edge_index[1])\n",
"\n",
" def _att_x(self, x: Tensor, edge_index: Tensor) -> Tuple[Tensor, Tensor]:\n",
" H, C = self.heads, self.out_channels\n",
" x = (x @ self.weight).view(-1, H, C)\n",
"\n",
" src_a = (x * self.att_src).sum(dim=-1)\n",
" src_a = src_a[edge_index[0]]\n",
" dst_a = (x * self.att_dst).sum(dim=-1)\n",
" dst_a = dst_a[edge_index[1]]\n",
"\n",
" x = x[edge_index[0]]\n",
" a = F.leaky_relu(src_a + dst_a, self.negative_slope)\n",
" return x, a\n",
" \n",
" def forward(self,\n",
" x1: Tensor,\n",
" x0: Tensor,\n",
" edge_index: Tensor,\n",
" snap_net: nn.Module,\n",
" update_att: bool,\n",
" ) -> Tensor:\n",
" H, C = self.heads, self.out_channels\n",
" snap_net: GATConv = snap_net.eval()\n",
"\n",
" index = edge_index[1]\n",
" with torch.no_grad():\n",
" x0, a0 = snap_net._att_x(x0, edge_index)\n",
" if update_att:\n",
" x, a = snap_net._att_x(x1, edge_index)\n",
" self.softmax_sum.update_attention(x, a, x0, a0, index)\n",
" x0, a0 = x, a\n",
" \n",
" x1, a1 = self._att_x(x1, edge_index)\n",
" out = self.softmax_sum.forward(x1, a1, x0, a0, index)\n",
" \n",
" if self.concat:\n",
" out = out.view(-1, H * C)\n",
" else:\n",
" out = out.mean(dim=1)\n",
"\n",
" if self.bias is not None:\n",
" out += self.bias\n",
" return out\n",
" "
]
},
{
"cell_type": "code",
"execution_count": 400,
"metadata": {},
"outputs": [],
"source": [
"@torch.no_grad()\n",
"def slim_local_graph(src_mask: Tensor, edge_index: Tensor) -> Tuple[Tensor, Tensor, Tensor]:\n",
" edge_mask = src_mask[edge_index[0]]\n",
" slim_edge_index = edge_index[:, edge_mask]\n",
" \n",
" node_mask = src_mask.clone().index_fill_(0, slim_edge_index[1], 1)\n",
"\n",
" idx = torch.where(node_mask)[0]\n",
" imp = torch.empty_like(src_mask, dtype=torch.long).fill_((2**62-1)*2+1)\n",
" imp[idx] = torch.arange(idx.size(0))\n",
" \n",
" slim_edge_index = imp[slim_edge_index.flatten()].view_as(slim_edge_index)\n",
" return node_mask, edge_mask, slim_edge_index"
]
},
{
"cell_type": "code",
"execution_count": 401,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([200])\n",
"torch.Size([100])\n",
"torch.Size([2, 35])\n"
]
}
],
"source": [
"src_mask = torch.rand(num_nodes) < 0.3\n",
"edge_index = torch.randint(num_nodes, size=(2, num_edges), dtype=torch.long)\n",
"node_mask, edge_mask, slim_edge_index = slim_local_graph(src_mask, edge_index)\n",
"print(node_mask.size())\n",
"print(edge_mask.size())\n",
"print(slim_edge_index.size())"
]
},
{
"cell_type": "code",
"execution_count": 484,
"metadata": {},
"outputs": [],
"source": [
"net = GATConv(3, num_features, num_heads)\n",
"att = SoftmaxSum(num_nodes, num_edges, num_heads, num_features)\n",
"net.set_softmax_sum(att)\n",
"\n",
"snap_net = deepcopy(net)"
]
},
{
"cell_type": "code",
"execution_count": 485,
"metadata": {},
"outputs": [],
"source": [
"x1 = torch.randn(num_nodes, 3)\n",
"x0 = torch.randn(num_nodes, 3)\n",
"edge_index = torch.randint(num_nodes, size=(2, num_edges), dtype=torch.long)\n",
"\n",
"net.update_softmax_snapshot(x0, edge_index, snap_net)"
]
},
{
"cell_type": "code",
"execution_count": 511,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[ 0.1482, -0.0744, -0.0213, ..., -0.0421, -0.1849, -0.1039],\n",
" [-0.0075, 0.0321, -0.0046, ..., 0.0154, 0.0506, -0.0110],\n",
" [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],\n",
" ...,\n",
" [ 0.0878, -0.0887, -0.0550, ..., -0.0982, -0.1763, -0.0175],\n",
" [ 0.1193, 0.0695, -0.0637, ..., 0.0060, 0.0391, -0.1508],\n",
" [ 0.0294, 0.0186, 0.0140, ..., 0.0327, 0.0128, -0.0487]],\n",
" grad_fn=<AddBackward0>)"
]
},
"execution_count": 511,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"x = net.forward(x1, x0, edge_index, snap_net, update_att=True)\n",
"x0 = x1\n",
"x"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "mpi",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.13"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}
import torch
import torch.nn as nn
from torch import Tensor
from typing import *
from torch_sparse import SparseTensor
from torch_scatter import scatter_min
from multiprocessing.pool import ThreadPool
from .buffer import TensorBuffer
from .route import Route
from .context import RouteContext
# class BatchHandle:
# def __init__(self,
# src_ids: Tensor,
# dst_size: int,
# feat_buffer: TensorBuffers,
# grad_buffer: TensorBuffers,
# with_feat0: bool = False,
# layer_id: Optional[int] = None,
# target_device: Any = None,
# ) -> None:
# self.src_ids = src_ids
# self.dst_size = dst_size
# self.feat_buffer = feat_buffer
# self.grad_buffer = grad_buffer
# self.with_feat0 = with_feat0
# self.layer_id = layer_id
# if target_device is None:
# self.target_device = src_ids.device
# else:
# self.target_device = torch.device(target_device)
# if with_feat0:
# _ = self.feat0
# @property
# def dst_ids(self) -> Tensor:
# return self.src_ids[:self.dst_size]
# @property
# def src_size(self) -> int:
# return self.src_ids.size(0)
# @property
# def device(self) -> torch.device:
# return self.src_ids.device
# @property
# def feat0(self) -> Tensor:
# if not hasattr(self, "_feat0"):
# self._feat0 = self.feat_buffer.get(0, self.src_ids).to(self.target_device)
# return self._feat0
# def fetch_feat(self, layer_id: Optional[int] = None) -> Tensor:
# layer_id = int(self.layer_id if layer_id is None else layer_id)
# return self.feat_buffer.get(layer_id, self.src_ids).to(self.target_device)
# def update_feat(self, x: Tensor, layer_id: Optional[int] = None):
# assert x.size(0) == self.dst_size
# layer_id = int(self.layer_id if layer_id is None else layer_id)
# self.feat_buffer.set(layer_id + 1, self.dst_ids, x.detach().to(self.device))
# def push_and_pull(self, x: Tensor, layer_id: Optional[int] = None) -> Tensor:
# assert x.size(0) == self.dst_size
# layer_id = int(self.layer_id if layer_id is None else layer_id)
# self.feat_buffer.set(layer_id + 1, self.dst_ids, x.detach().to(self.device))
# o = self.feat_buffer.get(layer_id + 1, self.src_ids[self.dst_size:]).to(x.device)
# return torch.cat([x, o], dim=0)
# def fetch_grad(self, layer_id: Optional[int] = None) -> Tensor:
# layer_id = int(self.layer_id if layer_id is None else layer_id)
# return self.grad_buffer.get(layer_id + 1, self.dst_ids).to(self.target_device)
# def accumulate_grad(self, x: Tensor, layer_id: Optional[int] = None):
# assert x.size(0) == self.src_size
# layer_id = int(self.layer_id if layer_id is None else layer_id)
# self.grad_buffer.add(layer_id, self.src_ids, x.detach().to(self.device))
class NodeHandle:
def __init__(self,
src_ids: Tensor,
dst_size: int,
edge_ids: Tensor,
adj_t: SparseTensor,
executor: ThreadPool,
device: Any,
stream: Optional[torch.cuda.Stream] = None,
edge_time: Optional[Tensor] = None,
node_time: Optional[Tensor] = None,
) -> None:
self._src_ids = src_ids
self._dst_size = dst_size
self._edge_ids = edge_ids
self._device = torch.device(device)
if self._device != torch.device("cpu") and stream is None:
self._stream = torch.cuda.Stream(self._device)
else:
self._stream = stream
self._executor = executor
self._adj_t = adj_t.to(self.device)
self._node_time = node_time
self._edge_time = edge_time
def get_src_feats(self, data: Union[Tensor, TensorBuffer]):
return self.async_select(data, self.src_ids)
def get_dst_feats(self, data: Union[Tensor, TensorBuffer]):
return self.async_select(data, self.dst_ids)
def get_ext_feats(self, data: Union[Tensor, TensorBuffer]):
return self.async_select(data, self.ext_ids)
def get_edge_feats(self, data: Union[Tensor, TensorBuffer]):
return self.async_select(data, self.edge_ids)
def get_src_time(self):
return self.async_select(self._node_time, self.src_ids)
def get_dst_time(self):
return self.async_select(self._node_time, self.dst_ids)
def get_ext_time(self):
return self.async_select(self._node_time, self.ext_ids)
def get_edge_time(self) -> Tensor:
return self.async_select(self._edge_time, self.edge_ids)
def push_and_pull(self, value: Tensor, ext_fut: torch.futures.Future[Tensor], data: Union[Tensor, TensorBuffer]):
fut = self.async_update(data, value, self.dst_ids)
ext_value = ext_fut.wait()
x = torch.cat([value, ext_value], dim=0)
return x, fut
def async_select(self, src: Union[Tensor, TensorBuffer], idx: Tensor) -> torch.futures.Future[Tensor]:
if self._device != torch.device("cpu"):
fut = torch.futures.Future(devices=[self._device])
else:
fut = torch.futures.Future()
def run():
try:
with torch.no_grad():
with torch.cuda.stream(self._stream):
index = idx.to(src.device)
if isinstance(src, TensorBuffer):
val = src.local_get(index, lock=True)
else:
val = src.index_select(0, index)
val = val.to(self._device)
fut.set_result(val)
except Exception as e:
fut.set_exception(e)
self._executor.apply_async(run)
return fut
def async_update(self, src: Union[Tensor, TensorBuffer], val: Tensor, idx: Tensor, ops: str = "mov") -> torch.futures.Future:
assert ops in ["mov", "add"]
if self._device != torch.device("cpu"):
fut = torch.futures.Future(devices=[self._device])
else:
fut = torch.futures.Future()
def run():
try:
with torch.no_grad():
with torch.cuda.stream(self._stream):
value = val.to(src.device)
index = idx.to(src.device)
if isinstance(src, TensorBuffer):
if ops == "mov":
src.all_remote_set(value, index, lock=True)
elif ops == "add":
src.all_remote_add(value, index, lock=True)
else:
if ops == "mov":
src[index] = value
elif ops == "add":
src[index] += value
fut.set_result(None)
except Exception as e:
fut.set_exception(e)
self._executor.apply_async(run)
return fut
@property
def adj_t(self) -> SparseTensor:
return self._adj_t
@property
def src_ids(self) -> Tensor:
return self._src_ids
@property
def src_size(self) -> int:
return self._src_ids.size(0)
@property
def dst_ids(self) -> Tensor:
return self._src_ids[:self._dst_size]
@property
def dst_size(self) -> int:
return self._dst_size
@property
def ext_ids(self) -> Tensor:
return self._src_ids[self._dst_size:]
@property
def ext_size(self) -> int:
return self.src_size - self.dst_size
@property
def edge_ids(self) -> Tensor:
return self._edge_ids
@property
def edge_size(self) -> int:
return self._edge_ids.size(0)
@property
def device(self) -> torch.device:
return self._device
class NodeLoader:
def __init__(self,
global_ids: Tensor,
global_edges: Tensor,
device: Any,
edge_time: Optional[Tensor] = None,
node_time: Optional[Tensor] = None,
num_threads: int = 1,
) -> None:
self.route, edge_index = Route.from_edge_index(global_ids, global_edges)
if node_time is None and edge_time is not None:
node_time, _ = scatter_min(edge_time[edge_index[0]], edge_index[1], dim=0, dim_size=self.dst_size)  # scatter_min returns (values, argmin)
node_ids = torch.arange(self.dst_size).type_as(global_ids)
if node_time is not None:
perm = node_time.argsort()
node_ids = node_ids[perm]
self.node_ids = node_ids
self.node_time = node_time
edge_ids = torch.arange(edge_index.size(1)).type_as(global_ids)
perm = edge_index[1].argsort()
edge_ids = edge_ids[perm]
edge_index = edge_index[:,perm]
self.edge_ids = edge_ids
self.edge_index = edge_index
self.edge_time = edge_time
# self.rowptr = torch.ops.torch_sparse.ind2ptr(self.edge_index[1], self.dst_size)
self.device = torch.device(device)
if self.device != torch.device("cpu"):
self.stream = torch.cuda.Stream(self.device)
else:
self.stream = None
self.executor = ThreadPool(num_threads)
@property
def src_size(self) -> int:
return self.route.src_size
@property
def dst_size(self) -> int:
return self.route.dst_size
# def _select_nodes(self, start_time: int, end_time: int) -> Tuple[Optional[Tensor], Optional[Tensor]]:
# if self.node_time is None:
# if self.edge_time is None:
# raise RuntimeError("neither node_time nor edge_time exists!")
# else:
# edge_mask = (start_time <= self.edge_time) & (self.edge_time < end_time)
# idx = self.edge_index[1, edge_mask]
# node_mask = torch.zeros(self.dst_size, dtype=torch.bool, device=idx.device).index_fill_(0, idx, 1)
# return torch.where(node_mask)[0], None
# else:
# time_range = torch.tensor([start_time, end_time]).type_as(self.node_time)
# s, t = torch.searchsorted(self.node_time, time_range).tolist()
# return self.node_ids[s:t], self.node_time[s:t]
def _get_node_handle(self, node_ids: Tensor, edge_ids: Tensor, edge_index: Tensor) -> NodeHandle:
num_node: int = torch.max(node_ids.max(), edge_index.max()).item() + 1
imp = torch.zeros(num_node, dtype=torch.bool, device=node_ids.device)
imp.index_fill_(0, edge_index[0], 1)
imp.index_fill_(0, node_ids, 0)
src_ids = torch.cat([node_ids, torch.where(imp)[0]], dim=0)
src_size = src_ids.size(0)
dst_size = node_ids.size(0)
imp = torch.zeros(num_node, dtype=torch.long, device=node_ids.device).fill_((2**62-1)*2+1)
imp[src_ids] = torch.arange(src_size, dtype=torch.long, device=node_ids.device)
edge_index = imp[edge_index.flatten()].view_as(edge_index)
perm = edge_index[1].argsort()
edge_ids = edge_ids[perm]
edge_index = edge_index[:,perm]
rowptr = torch.ops.torch_sparse.ind2ptr(edge_index[1], dst_size)
adj_t = SparseTensor(rowptr=rowptr, col=edge_index[0], sparse_sizes=(dst_size, src_size))
return NodeHandle(
src_ids=src_ids,
dst_size=dst_size,
edge_ids=edge_ids,
adj_t=adj_t,
executor=self.executor,
device=self.device,
stream=self.stream,
edge_time=self.edge_time,
node_time=self.node_time,
)
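# _get_node_handle builds a batch-local view: src_ids places the seed nodes first
# followed by the extra in-neighbors, edge_index is relabelled into that local
# numbering, and the CSR adj_t has shape (dst_size, src_size).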
# NOTE: this iteration path is currently very slow
def iter(self, batch_size: int, wind: bool = False, start_time: int = 0, end_time: int = 9223372036854775807, seed: int = 0):
rand_gen = torch.Generator()
if seed != 0:
rand_gen.manual_seed(seed)
if not wind:
sampled = torch.randperm(self.dst_size, generator=rand_gen, dtype=torch.long)
for s in range(0, sampled.size(0), batch_size):
t = min(s + batch_size, sampled.size(0))
node_ids = sampled[s:t]
imp = torch.zeros(self.dst_size, dtype=torch.bool, device=node_ids.device)
imp.index_fill_(0, node_ids, 1)
edge_mask = imp[self.edge_index[1]]
edge_ids = self.edge_ids[edge_mask]
edge_index = self.edge_index[:,edge_mask]
yield self._get_node_handle(node_ids, edge_ids, edge_index)
else:
assert self.node_time is not None
node_mask = (start_time <= self.node_time) & (self.node_time < end_time)
all_node_ids = torch.where(node_mask)[0]
sampled = torch.randperm(all_node_ids.size(0), generator=rand_gen, dtype=torch.long)
all_node_ids = all_node_ids[sampled]
for s in range(0, sampled.size(0), batch_size):
t = min(s + batch_size, sampled.size(0))
node_ids = all_node_ids[s:t]
imp = torch.zeros(self.dst_size, dtype=torch.bool, device=node_ids.device)
imp.index_fill_(0, node_ids, 1)
edge_mask = imp[self.edge_index[1]]
if self.edge_time is not None:
edge_time = self.edge_time[self.edge_ids]
edge_mask &= (start_time <= edge_time) & (edge_time < end_time)
edge_ids = self.edge_ids[edge_mask]
edge_index = self.edge_index[:,edge_mask]
yield self._get_node_handle(node_ids, edge_ids, edge_index)
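# Illustrative sketch only (assuming this loader is the NodeLoader exercised in the
# tests further below, with feature buffers already broadcast):
#
#     for handle in loader.iter(batch_size=128):
#         dst_feats = handle.get_dst_feats(feat_buffer).wait()  # feat_buffer: a prepared TensorBuffer
#         ...  # run one mini-batch step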
import torch
import torch.nn as nn
import torch.distributed as dist
import torch.distributed.rpc as rpc
from torch.futures import Future
from torch import Tensor
from typing import *
from starrygl.parallel import all_gather_remote_objects
from .route import Route
from threading import Lock
class TensorBuffer(nn.Module):
def __init__(self,
channels: int,
num_nodes: int,
route: Route,
) -> None:
super().__init__()
self.channels = channels
self.num_nodes = num_nodes
self.route = route
self.local_lock = Lock()
self.register_buffer("_data", torch.zeros(num_nodes, channels), persistent=False)
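# every rank wraps its buffer in an RRef and collects one RRef per peer, so the
# remote_* methods below can address the partition that owns any given row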
self.rrefs = all_gather_remote_objects(self)
@property
def data(self) -> Tensor:
return self.get_buffer("_data")
@property
def device(self) -> torch.device:
return self.data.device
def local_get(self, index: Optional[Tensor] = None, lock: bool = True) -> Tensor:
if lock:
with self.local_lock:
return self.local_get(index, lock=False)
if index is None:
return self.data
else:
return self.data[index]
def local_set(self, value: Tensor, index: Optional[Tensor] = None, lock: bool = True):
if lock:
with self.local_lock:
return self.local_set(value, index, lock=False)
if index is None:
self.data.copy_(value)
else:
# value = value.to(self.device)
self.data[index] = value
def local_add(self, value: Tensor, index: Optional[Tensor] = None, lock: bool = True):
if lock:
with self.local_lock:
return self.local_add(value, index, lock=False)
if index is None:
self.data.add_(value)
else:
# value = value.to(self.device)
self.data[index] += value
def local_cls(self, index: Optional[Tensor] = None, lock: bool = True):
if lock:
with self.local_lock:
return self.local_cls(index, lock=False)
if index is None:
self.data.zero_()
else:
self.data[index] = 0
def remote_get(self, dst: int, index: Tensor, lock: bool = True):
return TensorBuffer._remote_call(TensorBuffer.local_get, self.rrefs[dst], index=index, lock=lock)
def remote_set(self, dst: int, value: Tensor, index: Tensor, lock: bool = True):
return TensorBuffer._remote_call(TensorBuffer.local_set, self.rrefs[dst], value, index=index, lock=lock)
def remote_add(self, dst: int, value: Tensor, index: Tensor, lock: bool = True):
return TensorBuffer._remote_call(TensorBuffer.local_add, self.rrefs[dst], value, index=index, lock=lock)
def all_remote_get(self, index: Tensor, lock: bool = True):
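# split `index` by owning partition (route.parts_iter), issue one async RPC per
# partition, and accumulate each returned slice into a single (index.size(0), channels)
# result ordered like `index`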
def cb0(idx):
def f(x: torch.futures.Future[Tensor]):
return x.value(), idx
return f
def cb1(buf):
def f(xs: torch.futures.Future[List[torch.futures.Future]]) -> Tensor:
for x in xs.value():
dat, idx = x.value()
# print(dat.size(), idx.size())
buf[idx] += dat
return buf
return f
futs = []
for i, (idx, remote_idx) in enumerate(self.route.parts_iter(index)):
futs.append(self.remote_get(i, remote_idx, lock=lock).then(cb0(idx)))
futs = torch.futures.collect_all(futs)
buf = torch.zeros(index.size(0), self.channels, dtype=self.data.dtype, device=self.data.device)
return futs.then(cb1(buf))
def all_remote_set(self, value: Tensor, index: Tensor, lock: bool = True):
futs = []
for i, (idx, remote_idx) in enumerate(self.route.parts_iter(index)):
futs.append(self.remote_set(i, value[idx], remote_idx, lock=lock))
return torch.futures.collect_all(futs)
def all_remote_add(self, value: Tensor, index: Tensor, lock: bool = True):
futs = []
for i, (idx, remote_idx) in enumerate(self.route.parts_iter(index)):
futs.append(self.remote_add(i, value[idx], remote_idx, lock=lock))
return torch.futures.collect_all(futs)
def broadcast(self, barrier: bool = True):
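# collective refresh: every rank pulls the authoritative value of each local row from
# its owning partition and overwrites the local copy; all ranks must call this together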
if barrier:
dist.barrier()
index = torch.arange(self.num_nodes, dtype=torch.long, device=self.data.device)
data = self.all_remote_get(index, lock=True).wait()
self.local_set(data, lock=True)
if barrier:
dist.barrier()
# def remote_get(self, dst: int, i: int, index: Optional[Tensor] = None, global_index: bool = False, async_op: bool = False) -> Union[Tensor, Future]:
# return TensorBuffer._remote_call(async_op, TensorBuffer.local_get, self.rrefs[dst], i, index, global_index = global_index)
# def remote_set(self, dst: int, i: int, value: Tensor, index: Optional[Tensor], global_index: bool = False, async_op: bool = False) -> Optional[Future]:
# return TensorBuffer._remote_call(async_op, TensorBuffer.local_set, self.rrefs[dst], i, value, index, global_index = global_index)
# def remote_add(self, dst: int, i: int, value: Tensor, index: Optional[Tensor] = None, global_index: bool = False, async_op: bool = False) -> Optional[Future]:
# return TensorBuffer._remote_call(async_op, TensorBuffer.local_add, self.rrefs[dst], i, value, index, global_index = global_index)
# def async_scatter_fw_set(self, i: int, value: Tensor, index: Optional[Tensor] = None) -> Tuple[Future]:
# futures: List[Future] = []
# for dst in range(self.world_size):
# val, ind = self.router.fw_value_index(dst, value, index)
# futures.append(self.remote_set(dst, i, val, ind, global_index=True, async_op=True))
# return tuple(futures)
# def async_scatter_fw_add(self, i: int, value: Tensor, index: Optional[Tensor] = None) -> Tuple[Future]:
# futures: List[Future] = []
# for dst in range(self.world_size):
# val, ind = self.router.fw_value_index(dst, value, index)
# futures.append(self.remote_add(dst, i, val, ind, global_index=True, async_op=True))
# return tuple(futures)
# def async_scatter_bw_set(self, i: int, value: Tensor, index: Optional[Tensor] = None) -> Tuple[Future]:
# futures: List[Future] = []
# for dst in range(self.world_size):
# val, ind = self.router.bw_value_index(dst, value, index)
# futures.append(self.remote_set(dst, i, val, ind, global_index=True, async_op=True))
# return tuple(futures)
# def async_scatter_bw_add(self, i: int, value: Tensor, index: Optional[Tensor] = None) -> Tuple[Future]:
# futures: List[Future] = []
# for dst in range(self.world_size):
# val, ind = self.router.bw_value_index(dst, value, index)
# futures.append(self.remote_add(dst, i, val, ind, global_index=True, async_op=True))
# return tuple(futures)
# def _idx_data(self, i: int) -> Tuple[int, Tensor]:
# assert -self.num_layers < i and i < self.num_layers
# i = (self.num_layers + i) % self.num_layers
# return i, self.get_buffer(f"data{i}")
@staticmethod
def _remote_call(method, rref: rpc.RRef, *args, **kwargs):
args = (method, rref) + args
return rpc.rpc_async(rref.owner(), TensorBuffer._method_call, args=args, kwargs=kwargs)
@staticmethod
def _method_call(method, rref: rpc.RRef, *args, **kwargs):
self: TensorBuffer = rref.local_value()
index = kwargs["index"]
kwargs["index"] = self.route.to_local_ids(index)
return method(self, *args, **kwargs)
\ No newline at end of file
import torch
from contextlib import contextmanager
from torch import Tensor
from typing import *
class RouteContext:
def __init__(self) -> None:
self._futs: List[torch.futures.Future] = []
def synchronize(self):
for fut in self._futs:
fut.wait()
self._futs = []
def add_futures(self, *futs):
for fut in futs:
assert isinstance(fut, torch.futures.Future)
self._futs.append(fut)
def __enter__(self):
return self
def __exit__(self, exc_type, exc_value, traceback):
if exc_type is not None:
return False  # let the original exception propagate unchanged instead of re-wrapping it
self.synchronize()
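# Minimal usage sketch (mirrors the loader test further below): futures returned by
# asynchronous push/pull calls are registered on the context and waited on at exit.
#
#     with RouteContext() as ctx:
#         src_feats, fut = handle.push_and_pull(dst_feats, ext_fut, buffer)
#         ctx.add_futures(fut)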
import torch
import torch.nn as nn
import torch.distributed as dist
import torch.distributed.rpc as rpc
from torch import Tensor
from typing import *
from starrygl.parallel import all_gather_remote_objects
from .utils import init_local_edge_index
class Route(nn.Module):
def __init__(self,
src_ids: Tensor,
dst_size: int,
) -> None:
super().__init__()
self.register_buffer("_src_ids", src_ids, persistent=False)
self.dst_size = dst_size
self._init_nids_mapper()
self._init_part_mapper()
@staticmethod
def from_edge_index(dst_ids: Tensor, edge_index: Tensor):
src_ids, local_edge_index = init_local_edge_index(dst_ids, edge_index)
return Route(src_ids, dst_ids.size(0)), local_edge_index
@property
def src_ids(self) -> Tensor:
return self.get_buffer("_src_ids")
@property
def src_size(self) -> int:
return self.src_ids.size(0)
@property
def dst_ids(self) -> Tensor:
return self.src_ids[:self.dst_size]
@property
def ext_ids(self) -> Tensor:
return self.src_ids[self.dst_size:]
@property
def ext_size(self) -> int:
return self.src_size - self.dst_size
def parts_iter(self, local_ids: Tensor) -> Iterator[Tuple[Tensor, Tensor]]:
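# for each rank i, yield (positions within local_ids owned by rank i, their global
# node ids); TensorBuffer.all_remote_* uses the positions to scatter results locally
# and the global ids for the remote lookup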
world_size = dist.get_world_size()
part_mapper = self.part_mapper[local_ids]
for i in range(world_size):
# part_ids = local_ids[part_mapper == i]
part_ids = torch.where(part_mapper == i)[0]
glob_ids = self.src_ids[local_ids[part_ids]]
yield part_ids, glob_ids
def to_local_ids(self, ids: Tensor) -> Tensor:
return self.nids_mapper[ids]
def _init_nids_mapper(self):
num_nodes: int = self.src_ids.max().item() + 1
device: torch.device = self.src_ids.device
mapper = torch.empty(num_nodes, dtype=torch.long, device=device).fill_((2**62-1)*2+1)
mapper[self.src_ids] = torch.arange(self.src_ids.size(0), dtype=torch.long, device=device)
self.register_buffer("nids_mapper", mapper, persistent=False)
def _init_part_mapper(self):
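# gather every rank's dst_ids over RPC and record, for each local source node, the
# rank that owns it; the assert below checks that every local node found an owner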
device: torch.device = self.src_ids.device
nids_mapper = self.get_buffer("nids_mapper")
mapper = torch.empty(self.src_size, dtype=torch.int32, device=device).fill_(-1)
for i, dst_ids in enumerate(all_gather_remote_objects(self.dst_ids)):
dst_ids: Tensor = dst_ids.to_here().to(device)
dst_ids = dst_ids[dst_ids < nids_mapper.size(0)]
dst_local_inds = nids_mapper[dst_ids]
dst_local_mask = dst_local_inds != ((2**62-1)*2+1)
dst_local_inds = dst_local_inds[dst_local_mask]
mapper[dst_local_inds] = i
assert (mapper >= 0).all()
self.register_buffer("part_mapper", mapper, persistent=False)
# class RouteTable(nn.Module):
# def __init__(self,
# src_ids: Tensor,
# dst_size: int,
# ) -> None:
# super().__init__()
# self.register_buffer("src_ids", src_ids)
# self.src_size: int = src_ids.size(0)
# self.dst_size = dst_size
# assert self.src_size >= self.dst_size
# self._init_mapper()
# rank, world_size = rank_world_size()
# rrefs = all_gather_remote_objects(self)
# gather_futures: List[torch.futures.Future] = []
# for i in range(world_size):
# rref = rrefs[i]
# fut = rpc.rpc_async(rref.owner(), RouteTable._get_dst_ids, args=(rref,))
# gather_futures.append(fut)
# max_src_ids: int = src_ids.max().item()
# smp = torch.empty(max_src_ids + 1, dtype=torch.long, device=src_ids.device).fill_((2**62-1)*2+1)
# smp[src_ids] = torch.arange(src_ids.size(0), dtype=smp.dtype, device=smp.device)
# self.fw_masker = RouteMasker(self.dst_size, world_size)
# self.bw_masker = RouteMasker(self.src_size, world_size)
# dist.barrier()
# scatter_futures: List[torch.futures.Future] = []
# for i in range(world_size):
# fut = gather_futures[i]
# s_ids: Tensor = src_ids
# d_ids: Tensor = fut.wait()
# num_ids: int = max(s_ids.max().item(), d_ids.max().item()) + 1
# imp = torch.zeros(num_ids, dtype=torch.long, device=self._get_device())
# imp[s_ids] += 1
# imp[d_ids] += 1
# ind = torch.where(imp > 1)[0]
# imp.fill_((2**62-1)*2+1)
# imp[d_ids] = torch.arange(d_ids.size(0), dtype=imp.dtype, device=imp.device)
# s_ind = smp[ind]
# d_ind = imp[ind]
# rref = rrefs[i]
# fut = rpc.rpc_async(rref.owner(), RouteTable._set_fw_mask, args=(rref, rank, d_ind))
# scatter_futures.append(fut)
# bw_mask = torch.zeros(self.src_size, dtype=torch.bool).index_fill_(0, s_ind, 1)
# self.bw_masker.set_mask(i, bw_mask)
# for fut in scatter_futures:
# fut.wait()
# dist.barrier()
# # def fw_index(self, dst: int, index: Tensor) -> Tensor:
# # mask = self.fw_masker.select(dst, index)
# # return self.get_global_index(index[mask])
# # def bw_index(self, dst: int, index: Tensor) -> Tensor:
# # mask = self.bw_masker.select(dst, index)
# # return self.get_global_index(index[mask])
# def fw_value_index(self, dst: int, value: Tensor, index: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]:
# if index is None:
# assert value.size(0) == self.dst_size
# mask = self.fw_masker.select(dst)
# return value[mask], self.get_buffer("src_ids")[:self.dst_size][mask]
# else:
# assert value.size(0) == index.size(0)
# mask = self.fw_masker.select(dst, index)
# value, index = value[mask], index[mask]
# return value, self.get_global_index(index)
# def bw_value_index(self, dst: int, value: Tensor, index: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]:
# if index is None:
# assert value.size(0) == self.src_size
# mask = self.bw_masker.select(dst)
# return value[mask], self.get_buffer("src_ids")[mask]
# else:
# assert value.size(0) == index.size(0)
# mask = self.bw_masker.select(dst, index)
# value, index = value[mask], index[mask]
# return value, self.get_global_index(index)
# def get_global_index(self, index: Tensor) -> Tensor:
# return self.get_buffer("src_ids")[index]
# def get_local_index(self, index: Tensor) -> Tensor:
# return self.get_buffer("mapper")[index]
# @staticmethod
# def _get_dst_ids(rref: rpc.RRef):
# self: RouteTable = rref.local_value()
# src_ids = self.get_buffer("src_ids")
# return src_ids[:self.dst_size]
# @staticmethod
# def _set_fw_mask(rref: rpc.RRef, dst: int, fw_ind: Tensor):
# self: RouteTable = rref.local_value()
# fw_mask = torch.zeros(self.dst_size, dtype=torch.bool).index_fill_(0, fw_ind, 1)
# self.fw_masker.set_mask(dst, fw_mask)
# def _get_device(self):
# return self.get_buffer("src_ids").device
# def _init_mapper(self):
# src_ids = self.get_buffer("src_ids")
# num_nodes: int = src_ids.max().item() + 1
# mapper = torch.empty(num_nodes, dtype=torch.long, device=src_ids.device).fill_((2**62-1)*2+1)
# mapper[src_ids] = torch.arange(src_ids.size(0), dtype=torch.long)
# self.register_buffer("mapper", mapper)
# class RouteMasker(nn.Module):
# def __init__(self,
# num_nodes: int,
# world_size: int,
# ) -> None:
# super().__init__()
# m = (world_size + 7) // 8
# self.num_nodes = num_nodes
# self.world_size = world_size
# self.register_buffer("data", torch.zeros(m, num_nodes, dtype=torch.uint8))
# def forward(self, i: int, index: Optional[Tensor] = None) -> Tensor:
# return self.select(i, index)
# def select(self, i: int, index: Optional[Tensor] = None) -> Tensor:
# i, data = self._idx_data(i)
# k, r = i // 8, i % 8
# if index is None:
# mask = data[k].bitwise_right_shift(r).bitwise_and_(1)
# else:
# mask = data[k][index].bitwise_right_shift_(r).bitwise_and_(1)
# return mask.type(dtype=torch.bool)
# def set_mask(self, i: int, mask: Tensor) -> Tensor:
# assert mask.size(0) == self.num_nodes
# i, data = self._idx_data(i)
# k, r = i // 8, i % 8
# data[k] &= ~(1<<r)
# data[k] |= mask.type(torch.uint8).bitwise_left_shift_(r)
# def _idx_data(self, i: int) -> Tuple[int, Tensor]:
# assert -self.world_size < i and i < self.world_size
# i = (i + self.world_size) % self.world_size
# return i, self.get_buffer("data")
\ No newline at end of file
import torch
from torch import Tensor
from typing import *
from starrygl.parallel import get_compute_device
from starrygl.core.route import Route
from starrygl.utils.partition import metis_partition
def init_local_edge_index(
dst_ids: Tensor,
edge_index: Tensor,
bipartite: bool = False,
) -> Tuple[Tensor, Tensor]:
max_ids = calc_max_ids(dst_ids, edge_index)
ikw = dict(dtype=torch.long, device=dst_ids.device)
xmp = torch.zeros(max_ids + 1, **ikw)
# check that this is a vertex-cut partition, i.e. every edge is assigned to the partition owning its destination node
xmp[edge_index[1].unique()] += 0b01
xmp[dst_ids.unique()] += 0b10
if not (xmp != 0x01).all():
raise RuntimeError("must be a vertex-cut partitioned graph")
if bipartite:
src_ids = edge_index[0].unique()
else:
# assume a homogeneous graph
# src_ids equals [dst_ids, edge_index[0] except dst_ids]
xmp.fill_(0)
xmp[edge_index[0]] = 1
xmp[dst_ids] = 0
src_ids = torch.cat([dst_ids, torch.where(xmp > 0)[0]], dim=-1)
# compute local indices
xmp.fill_((2**62-1)*2+1)
xmp[src_ids] = torch.arange(src_ids.size(0), **ikw)
src = xmp[edge_index[0]]
xmp.fill_((2**62-1)*2+1)
xmp[dst_ids] = torch.arange(dst_ids.size(0), **ikw)
dst = xmp[edge_index[1]]
local_edge_index = torch.vstack([src, dst])
return src_ids, local_edge_index
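# Example (matching the toy 3-way partition used in the tests later in this commit):
# with dst_ids = [2, 3, 4] and edge_index = [[0,3,5,1,2,4,6,1,3], [2,2,2,3,3,3,3,4,4]],
# src_ids becomes [2, 3, 4, 0, 1, 5, 6]; local_edge_index re-indexes the source row into
# positions within src_ids and the destination row into positions within dst_ids.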
def calc_max_ids(*ids: Tensor) -> int:
x = [t.max().item() if t.numel() > 0 else 0 for t in ids]
return max(x)
def local_partition_fn(dst_size: Tensor, edge_index: Tensor, num_parts: int) -> Tensor:
edge_index = edge_index[:, edge_index[0] < dst_size]
return metis_partition(edge_index, dst_size, num_parts)[0]
\ No newline at end of file
import torch
import torch.distributed as dist
from torch import Tensor, LongTensor
from typing import *
def _local_TP_FP_FN(pred: LongTensor, targ: LongTensor, num_classes: int) -> Tensor:
TP, FP, FN = 0, 1, 2
tmp = torch.empty(3, num_classes, dtype=torch.float32, device=pred.device)
for c in range(num_classes):
pred_c = (pred == c)
targ_c = (targ == c)
tmp[TP, c] = torch.count_nonzero(pred_c & targ_c)
tmp[FP, c] = torch.count_nonzero(pred_c & ~targ_c)
tmp[FN, c] = torch.count_nonzero(~pred_c & targ_c)
return tmp
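# micro_f1 pools TP/FP/FN across classes (and across ranks via all_reduce) before
# computing a single precision/recall pair; macro_f1 keeps per-class counts, computes
# one F1 per class, then averages the per-class scores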
def micro_f1(pred: LongTensor, targ: LongTensor, num_classes: int) -> float:
tmp = _local_TP_FP_FN(pred, targ, num_classes).sum(dim=-1)
dist.all_reduce(tmp)
TP, FP, FN = tmp.tolist()
precision = TP / (TP + FP)
recall = TP / (TP + FN)
return 2 * precision * recall / (precision + recall)
def macro_f1(pred: LongTensor, targ: LongTensor, num_classes: int) -> float:
tmp = _local_TP_FP_FN(pred, targ, num_classes)
dist.all_reduce(tmp)
TP, FP, FN = tmp
precision = TP / (TP + FP)
recall = TP / (TP + FN)
f1 = 2 * precision * recall / (precision + recall)
return f1.mean().item()
def accuracy(pred: LongTensor, targ: LongTensor) -> float:
tmp = torch.empty(2, dtype=torch.float32, device=pred.device)
tmp[0] = pred.eq(targ).count_nonzero()
tmp[1] = pred.size(0)
dist.all_reduce(tmp)
a, b = tmp.tolist()
return a / b
def all_reduce_loss(loss: Tensor, batch_size: int) -> float:
tmp = torch.tensor([
loss.item() * batch_size,
batch_size
], dtype=torch.float32, device=loss.device)
dist.all_reduce(tmp)
cum_loss, n = tmp.tolist()
return cum_loss / n
import torch
import torch.nn as nn
import torch.distributed as dist
import torch.distributed.rpc as rpc
import os
from typing import *
# from .degree import compute_in_degree, compute_out_degree, compute_gcn_norm
from .sync_bn import SyncBatchNorm
def convert_parallel_model(
net: nn.Module,
find_unused_parameters=False,
) -> nn.parallel.DistributedDataParallel:
net = SyncBatchNorm.convert_sync_batchnorm(net)
net = nn.parallel.DistributedDataParallel(net,
find_unused_parameters=find_unused_parameters,
)
return net
def init_process_group(backend: str = "gloo") -> torch.device:
rank = int(os.getenv("RANK") or os.getenv("OMPI_COMM_WORLD_RANK"))
world_size = int(os.getenv("WORLD_SIZE") or os.getenv("OMPI_COMM_WORLD_SIZE"))
dist.init_process_group(
backend=backend,
init_method=ccl_init_method(),
rank=rank, world_size=world_size,
)
rpc_backend_options = rpc.TensorPipeRpcBackendOptions()
rpc_backend_options.init_method = rpc_init_method()
for i in range(world_size):
rpc_backend_options.set_device_map(f"worker{i}", {rank: i})
rpc.init_rpc(
name=f"worker{rank}",
rank=rank, world_size=world_size,
rpc_backend_options=rpc_backend_options,
)
local_rank = os.getenv("LOCAL_RANK") or os.getenv("OMPI_COMM_WORLD_LOCAL_RANK")
if local_rank is not None:
local_rank = int(local_rank)
if backend == "nccl" or backend == "mpi":
device = torch.device(f"cuda:{rank if local_rank is None else local_rank}")
torch.cuda.set_device(device)
else:
device = torch.device("cpu")
global _COMPUTE_DEVICE
_COMPUTE_DEVICE = device
return device
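# Assumed launch environment (not enforced here): RANK / WORLD_SIZE and
# MASTER_ADDR / MASTER_PORT (or their OMPI_COMM_WORLD_* equivalents) must be exported,
# e.g. by torchrun or mpirun; the RPC layer reuses MASTER_ADDR and listens on MASTER_PORT + 1.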
def rank_world_size() -> Tuple[int, int]:
return dist.get_rank(), dist.get_world_size()
def get_worker_info(rank: Optional[int] = None) -> rpc.WorkerInfo:
rank = dist.get_rank() if rank is None else rank
return rpc.get_worker_info(f"worker{rank}")
_COMPUTE_DEVICE = torch.device("cpu")
def get_compute_device() -> torch.device:
global _COMPUTE_DEVICE
return _COMPUTE_DEVICE
_TEMP_AG_REMOTE_OBJECT = None
def _remote_object():
global _TEMP_AG_REMOTE_OBJECT
return _TEMP_AG_REMOTE_OBJECT
def all_gather_remote_objects(obj: Any) -> List[rpc.RRef]:
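# each rank stashes an RRef to `obj` in the module-level global above, barriers, then
# asks every worker for that global over RPC; the trailing barrier keeps the global
# alive until all peers have fetched it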
global _TEMP_AG_REMOTE_OBJECT
_TEMP_AG_REMOTE_OBJECT = rpc.RRef(obj)
dist.barrier()
world_size = dist.get_world_size()
futs: List[torch.futures.Future] = []
for i in range(world_size):
info = get_worker_info(i)
futs.append(rpc.rpc_async(info, _remote_object))
rrefs: List[rpc.RRef] = []
for f in futs:
f.wait()
rrefs.append(f.value())
dist.barrier()
_TEMP_AG_REMOTE_OBJECT = None
return rrefs
def ccl_init_method() -> str:
master_addr = os.environ["MASTER_ADDR"]
master_port = int(os.environ["MASTER_PORT"])
return f"tcp://{master_addr}:{master_port}"
def rpc_init_method() -> str:
master_addr = os.environ["MASTER_ADDR"]
master_port = int(os.environ["MASTER_PORT"])
return f"tcp://{master_addr}:{master_port+1}"
\ No newline at end of file
import torch.distributed as dist
def sync_print(*args, **kwargs):
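# print once per rank, in rank order; the barrier after each turn serializes the output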
rank = dist.get_rank()
world_size = dist.get_world_size()
for i in range(world_size):
if i == rank:
print(f"rank {rank}:", *args, **kwargs)
dist.barrier()
def main_print(*args, **kwargs):
rank = dist.get_rank()
if rank == 0:
print(*args, **kwargs)
# dist.barrier()
\ No newline at end of file
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from typing import *
import os
import time
import psutil
from starrygl.core import NodeProbe
from starrygl.nn.convs.s_gat_conv import GATConv
from starrygl.graph import DistGraph
from starrygl.parallel import init_process_group, convert_parallel_model
from starrygl.parallel import compute_gcn_norm, SyncBatchNorm, with_nccl
from starrygl.utils import train_epoch, eval_epoch, partition_load, main_print, sync_print
class Net(nn.Module):
def __init__(self,
in_channels: int,
hidden_channels: int,
out_channels: int,
num_layers: int,
heads: int = 8,
) -> None:
super().__init__()
self.in_channels = in_channels
self.hidden_channels = hidden_channels
self.out_channels = out_channels
self.num_layers = num_layers
last_ch = in_channels
self.convs = nn.ModuleList()
for i in range(num_layers):
out_ch = out_channels if i == num_layers - 1 else hidden_channels
self.convs.append(GATConv(last_ch, out_ch, heads))
last_ch = out_ch
def forward(self, g: DistGraph, probe: NodeProbe):
for i in range(self.num_layers):
shrink = g.cache.layers[i].get_shrink_data()
if i == 0:
x = g.cache.fetch_pull_tensor(i)
else:
# x = torch.ones(dst_idx.size(0), self.hidden_channels, device=x.device)
x, _ = g.cache.update_cache(i, x, dst_idx)
dst_idx = shrink.dst_idx
# print(shrink.edge_index[1].unique().size(0), shrink.dst_size)
# print(x.size(0), shrink.edge_index[0].unique().size(0), shrink.src_size)
# e = x[shrink.edge_index[0]]
# from torch_scatter import scatter
from starrygl.utils.printer import main_print
main_print(x.requires_grad)
# print(e.size(0), shrink.edge_index[1].size(0), shrink.dst_size, shrink.edge_index[1].max().item())
# scatter(e, shrink.edge_index[1], dim=0, dim_size=shrink.dst_size)
# return
x = self.convs[i](shrink, x)
x = F.relu(x)
x = probe.apply(i, x, dst_idx)
return x
if __name__ == "__main__":
# start the distributed process group and pick the compute device
device = init_process_group(backend="nccl")
# load the partitioned dataset
pdata = partition_load("./cora", algo="metis").to(device)
g = DistGraph(
ids=pdata.ids,
edge_index=pdata.edge_index,
num_features=64,
num_layers=3,
cache_device="cpu"
).to(device)
cached_data_0, _ = g.route.forward_a2a(pdata.x).get()
g.cache.replace_layer_data(0, cached_data_0.to(g.cache.cache_device))
probe = NodeProbe(
g.dst_size,
num_layers=3,
num_samples=128,
).to(device)
probe.assign_message_cache(g.cache)
probe.layers[-1].warmup_sample()
net = Net(pdata.num_features, 64, pdata.num_classes, num_layers=3).to(device)
net(g, probe).sum().backward()
\ No newline at end of file
import torch
import torch.distributed as dist
from starrygl.parallel import init_process_group
from typing import *
if __name__ == "__main__":
device = init_process_group(backend="gloo")
rank = dist.get_rank()
world_size = dist.get_world_size()
assert world_size == 4
xs = [torch.randn(100+i*10, device=device) for i in range(world_size)]
send_sizes = torch.tensor([x.size(0) for x in xs], dtype=torch.long, device=device)
recv_sizes = torch.zeros_like(send_sizes)
dist.all_to_all_single(recv_sizes, send_sizes)
\ No newline at end of file
# import torch
# from starrygl.core.acopy import get_executor, AsyncCopyWorkBase, List
# if __name__ == "__main__":
# data = torch.arange(1024*1024*1024, dtype=torch.float32).view(-1, 1).repeat(1, 8)
# # buffer = torch.arange(1000, dtype=torch.float32).view(-1, 1).repeat(1, 8).pin_memory()
# index = torch.arange(1024*1024 * 256, dtype=torch.long).cuda()
# values = torch.ones(1024*1024 * 256, 8, dtype=torch.float32).cuda()
# pull_works: List[AsyncCopyWorkBase] = []
# push_works: List[AsyncCopyWorkBase] = []
# for s in range(0, index.size(0), 1024 * 1024 * 256):
# t = min(s + 1024 * 1024 * 256, index.size(0))
# idx = index[s:t]
# val = values[idx]
# pullw = get_executor().async_pull(data, idx)
# pushw = get_executor().async_push(data, idx, val)
# pull_works.append(pullw)
# push_works.append(pushw)
# pull_time = 0.0
# for w in pull_works:
# val = w.get()
# pull_time += w.time_used()
# # print(val)
# push_time = 0.0
# for w in push_works:
# w.get()
# push_time += w.time_used()
# # print(data)
# print(pull_time, push_time)
\ No newline at end of file
import torch
import torch.autograd as autograd
from torch import Tensor
from typing import *
class AFunction(autograd.Function):
@staticmethod
def forward(
ctx: autograd.function.FunctionCtx,
a: Tensor,
b: Tensor,
):
return a, b.float()
@staticmethod
def backward(
ctx: autograd.function.FunctionCtx,
a: Tensor,
b: Tensor,
):
print(a, b)
return a, b.bool()
if __name__ == "__main__":
x = torch.rand(10).requires_grad_()
y = torch.ones(10).bool()
a, b = AFunction.apply(x, y)
(a + b).sum().backward()
print(a, b)
print(x.grad, y.grad)
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
import torch.distributed.rpc as rpc
from torch import Tensor
from typing import *
import os
import time
import psutil
from starrygl.loader import NodeLoader, TensorBuffer, RouteContext
from starrygl.parallel import init_process_group
from starrygl.utils import partition_load, main_print, sync_print
if __name__ == "__main__":
# start the distributed process group and pick the compute device
device = init_process_group(backend="nccl")
# load the partitioned dataset
pdata = partition_load("./cora", algo="metis")
loader = NodeLoader(pdata.ids, pdata.edge_index, device)
hidden_size = 64
buffers: List[TensorBuffer] = [
TensorBuffer(pdata.num_features, loader.src_size, loader.route),
TensorBuffer(hidden_size, loader.src_size, loader.route),
]
buffers[0].data[:loader.dst_size] = pdata.x
buffers[0].broadcast()
for handle in loader.iter(128):
with RouteContext() as ctx:
sync_print(handle.src_size, handle.dst_size)
dst_feats = handle.get_dst_feats(buffers[0]).wait()
ext_fut = handle.get_ext_feats(buffers[1])
sync_print(dst_feats.size())
src_feats, fut = handle.push_and_pull(dst_feats[:,:hidden_size], ext_fut, buffers[1])
ctx.add_futures(fut)
sync_print(src_feats.size())
rpc.shutdown()
\ No newline at end of file
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
import torch.distributed.rpc as rpc
from torch import Tensor
from typing import *
import os
import time
import psutil
# from starrygl.loader import DataLoader, TensorBuffers
from starrygl.loader.route import Route
from starrygl.loader.buffer import TensorBuffer
from starrygl.loader.utils import init_local_edge_index, local_partition_fn
from starrygl.parallel import init_process_group
from starrygl.utils import partition_load, main_print, sync_print
all_nparts = [
[0, 1],
[2, 3, 4],
[5, 6],
]
all_eparts = [
[
[2, 0], [1, 0], [0, 1], [3, 1], [4, 1],
],
[
[0, 2], [3, 2], [5, 2], [1, 3], [2, 3],
[4, 3], [6, 3], [1, 4], [3, 4],
],
[
[2, 5], [6, 5], [5, 6], [3, 6],
],
]
# def get_route_table(device = "cpu") -> Tuple[RouteTable, Tensor]:
# assert dist.get_world_size() == 3
# rank = dist.get_rank()
# ids = torch.tensor(all_nparts[rank], dtype=torch.long, device=device)
# edge_index = torch.tensor(all_eparts[rank], dtype=torch.long, device=device).t()
# src_ids, local_edge_index = init_local_edge_index(ids, edge_index)
# return RouteTable(src_ids, ids.size(0)), local_edge_index
# def get_features(device = "cpu") -> Tensor:
# assert dist.get_world_size() == 3
# rank = dist.get_rank()
# return torch.tensor(all_nparts[rank], dtype=torch.float32, device=device) + 1.0
if __name__ == "__main__":
# start the distributed process group and pick the compute device
device = init_process_group(backend="nccl")
# load the partitioned dataset
pdata = partition_load("./cora", algo="metis").to(device)
route, local_edge_index = Route.from_edge_index(pdata.ids, pdata.edge_index)
# route_table = RouteTable(src_ids, dst_size)
# sync_print(route_table.fw_0)
# route_table = get_route_table()[0]
# for name, buffer in route_table.named_buffers():
# if name.startswith("fw_") or name.startswith("bw_"):
# sync_print(name, buffer.tolist())
tb = TensorBuffer(pdata.num_features, route.src_size, route).cpu()
print(tb.data.device, pdata.x.device)
tb.data[:pdata.num_nodes] = pdata.x
dist.barrier()
tb.broadcast()
print(tb.data[pdata.num_nodes:].sum(dim=-1))
# assert (tb.data[:pdata.num_nodes] == pdata.x).all()
# local_src_ids = torch.arange(route.src_size)
# for local_ids, remote_ids in route.parts_iter(local_src_ids):
# sync_print(local_ids.size(), remote_ids.size())
# router, edge_index = get_route_table()
# x = get_features()[:,None]
# buffer = TensorBuffers(router, channels=[1])
# buffer.local_set(0, x)
# sync_print(buffer.local_get(0).view(-1).tolist())
# buffer.broadcast()
# main_print("="*64)
# sync_print(buffer.local_get(0).view(-1).tolist())
# main_print("+"*64)
# buffer.zero_grad()
# buffer.local_set(0, x)
# sync_print(buffer.local_get(0).view(-1).tolist())
# buffer.broadcast2()
# main_print("="*64)
# sync_print(buffer.local_get(0).view(-1).tolist())
rpc.shutdown()
# buffer.remote_get(0, 0, async_op=True).then()
# num_parts = 50
# node_parts = local_partition_fn(dst_size, local_edge_index, num_parts)
# feat_buffer = TensorBuffers(src_ids.size(0), [pdata.num_features] * 4)
# grad_buffer = TensorBuffers(src_ids.size(0), [pdata.num_features] * 4)
# node_time = torch.rand(dst_size)
# edge_time = torch.rand(local_edge_index.size(1))
# edge_attr = torch.randn(local_edge_index.size(1), 10)
# loader = DataLoader(node_parts, feat_buffer, grad_buffer, local_edge_index, edge_attr, edge_time, node_time)
# feat_buffer.set(0, None, collect_feat0(src_ids, src_ids[:dst_size], pdata.x))
# for handle, edge_index, edge_attr in loader.iter(2, layer_id=0, device=device, filter=lambda x: x < 0.5):
# print(handle.feat0.sum(dim=-1))
# print(edge_index)
# print(edge_attr.sum(dim=-1))
# handle.push_and_pull(handle.feat0[:handle.dst_size], 0)
# assert (handle.fetch_feat() == handle.feat0).all()
\ No newline at end of file
import torch
import torch.distributed as dist
from torch import Tensor
from typing import *
from starrygl.parallel import init_process_group
from starrygl.utils.printer import main_print, sync_print
if __name__ == "__main__":
import os
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
device = init_process_group(backend="mpi")
sync_print(dist.get_backend())
\ No newline at end of file
import torch
import torch.distributed as dist
from torch import Tensor
from typing import *
from starrygl.parallel import init_process_group
from starrygl.utils.printer import main_print, sync_print
from multiprocessing.pool import ThreadPool
num_threads = 1
num_tasks = 5
pool = ThreadPool(num_threads)
if __name__ == "__main__":
import os
# os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
device = init_process_group(backend="nccl")
group = dist.new_group()
stream = torch.cuda.Stream(device)
handles = []
for i in range(num_tasks):
x = torch.ones(10, dtype=torch.float32, device=device)
def run(x):
# stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(stream):
group.allreduce(x)
return x
stream.wait_stream(torch.cuda.current_stream())
handles.append(pool.apply_async(run, (x,)))
x = torch.ones(3, dtype=torch.float32, device=device)
dist.all_reduce(x)
sync_print(x.tolist())
for i in range(num_tasks):
main_print("="*64)
x = handles[i].get()
assert (handles[i].get() == x).all()
sync_print(x.tolist())
import torch
import torch.distributed as dist
from torch import Tensor
from typing import *
from starrygl.graph import DistGraph
from starrygl.core.gather import Gather
from starrygl.core.straight import Straight
from starrygl.parallel import init_process_group, compute_degree
from starrygl.utils.printer import main_print, sync_print
from torch_scatter import scatter_sum
all_nparts = [
[0, 1],
[2, 3, 4],
[5, 6],
]
all_eparts = [
[
[2, 0], [1, 0], [0, 1], [3, 1], [4, 1],
],
[
[0, 2], [3, 2], [5, 2], [1, 3], [2, 3],
[4, 3], [6, 3], [1, 4], [3, 4],
],
[
[2, 5], [6, 5], [5, 6], [3, 6],
],
]
def get_distgraph(device = "cpu") -> DistGraph:
assert dist.get_world_size() == 3
rank = dist.get_rank()
ids = torch.tensor(all_nparts[rank], dtype=torch.long, device=device)
edge_index = torch.tensor(all_eparts[rank], dtype=torch.long, device=device).t()
return DistGraph(ids, edge_index)
def get_features(device = "cpu") -> Tensor:
assert dist.get_world_size() == 3
rank = dist.get_rank()
return torch.tensor(all_nparts[rank], dtype=torch.float32, device=device) + 1.0
def reduce(g: DistGraph, x: Tensor) -> Tensor:
edge_index = g.edge_index
x = x[edge_index[0]]
x = scatter_sum(x, edge_index[1], dim=0, dim_size=g.dst_size)
return x
if __name__ == "__main__":
import os
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
device = init_process_group(backend="nccl")
g = get_distgraph(device)
# sync_print(g.src_ids.tolist())
# main_print("="*64)
# sync_print(g.dst_ids.tolist())
# main_print("="*64)
# sync_print(g.route.forward_routes[0])
# main_print("="*64)
in_degree, out_degree = compute_degree(g)
# sync_print(in_degree.tolist())
# sync_print(out_degree.tolist())
x = get_features(device).requires_grad_()
# straight = Straight(g.dst_size, num_samples=1)
# gather = Gather(g.src_size).to(device)
# sync_print(g)
# sync_print(gather.get_buffer("last_embd").tolist())
# main_print("="*64)
sync_print(x.tolist(), g.route.src_size)
main_print("="*64)
a = torch.arange(g.route.src_size, dtype=torch.long, device=device)
# a = None
# z, a = gather(x, None, route=g.route)
z, a = g.route.apply(x, a, {}, async_op=False)
sync_print(z.tolist(), "" if a is None else a.tolist(), g.route.dst_size)
main_print("="*64)
sync_print(g.edge_index[0])
main_print("="*64)
z = reduce(g, z)
sync_print(z.tolist())
main_print("="*64)
# z = straight(x, a, g)
# sync_print(z.tolist())
# main_print("="*64)
loss = z.sum()
sync_print(loss.item())
main_print("="*64)
loss.backward()
sync_print(x.grad.tolist())
main_print("="*64)
sync_print(f"time_used: {g.route.total_time_used}ms")
\ No newline at end of file
import torch
import torch.distributed as dist
from torch import Tensor
from typing import *
from starrygl.core.route import Route
from starrygl.graph.utils import init_local_edge_index
from starrygl.parallel import init_process_group
from starrygl.utils.printer import main_print, sync_print
from torch_scatter import scatter_sum
all_nparts = [
[0, 1],
[2, 3, 4],
[5, 6],
]
all_eparts = [
[
[2, 0], [1, 0], [0, 1], [3, 1], [4, 1],
],
[
[0, 2], [3, 2], [5, 2], [1, 3], [2, 3],
[4, 3], [6, 3], [1, 4], [3, 4],
],
[
[2, 5], [6, 5], [5, 6], [3, 6],
],
]
def get_route(device = "cpu") -> Tuple[Route, Tensor]:
assert dist.get_world_size() == 3
rank = dist.get_rank()
ids = torch.tensor(all_nparts[rank], dtype=torch.long, device=device)
edge_index = torch.tensor(all_eparts[rank], dtype=torch.long, device=device).t()
src_ids, local_edge_index = init_local_edge_index(ids, edge_index)
return Route(ids, src_ids), local_edge_index
def get_features(device = "cpu") -> Tensor:
assert dist.get_world_size() == 3
rank = dist.get_rank()
return torch.tensor(all_nparts[rank], dtype=torch.float32, device=device) + 1.0
def reduce(x: Tensor, edge_index: Tensor, dim_size: int) -> Tensor:
x = x[edge_index[0]]
x = scatter_sum(x, edge_index[1], dim=0, dim_size=dim_size)
return x
if __name__ == "__main__":
device = init_process_group(backend="mpi")
route, edge_index = get_route(device)
src_size = route.dst_size
dst_size = route.src_size
# in_degree, out_degree = compute_degree(g)
# sync_print(in_degree.tolist())
# sync_print(out_degree.tolist())
x = get_features(device).requires_grad_()
sync_print(x.tolist())
main_print("="*64)
z, _ = route.apply(x, None)
sync_print(z.tolist())
main_print("="*64)
h = reduce(z, edge_index, dst_size)
sync_print(h.tolist())
main_print("="*64)
h.sum().backward()
sync_print(x.grad.tolist())
main_print("="*64)
import torch
import torch.distributed as dist
import torch.distributed.rpc as rpc
import os
from torch import Tensor
from typing import *
from starrygl.parallel import init_process_group, all_gather_remote_objects
from starrygl.utils.printer import main_print, sync_print
class A:
def __init__(self, device) -> None:
self.id = rpc.get_worker_info().id
self.data = torch.tensor([self.id], device=device)
@staticmethod
def get_data(rref: rpc.RRef) -> Tensor:
self: A = rref.local_value()
return self.data
def set_remote_a(self, *rrefs: rpc.RRef):
self.rrefs = tuple(rrefs)
def display(self):
sync_print(f"id = {self.id}")
def gather_data(self, device) -> List[Tensor]:
world_size = dist.get_world_size()
futs: List[torch.futures.Future] = []
for i in range(world_size):
f = rpc.rpc_async(self.rrefs[i].owner(), A.get_data, args=(self.rrefs[i],))
futs.append(f)
results: List[Tensor] = []
for f in futs:
f.wait()
results.append(f.value())
return results
if __name__ == "__main__":
device = init_process_group("nccl")
a = A(device)
a.set_remote_a(*all_gather_remote_objects(a))
a.display()
for t in a.gather_data(device):
sync_print(t, t.device)
rpc.shutdown()
\ No newline at end of file
import torch
import torch.nn as nn
import torch.distributed as dist
from torch import Tensor
from typing import *
from starrygl.core.gather import Gather
from starrygl.parallel import init_process_group
from starrygl.parallel.sync_bn import SyncBatchNorm
from starrygl.graph import DistGraph, EID, SID, DID
from starrygl.utils.printer import main_print, sync_print
if __name__ == "__main__":
device = init_process_group(backend="nccl")
bn = SyncBatchNorm(100).to(device)
bn.train()
x = torch.randn(10, 100, dtype=torch.float64).to(device).requires_grad_()
sync_print(x)
from torch.autograd import gradcheck
sync_print(gradcheck(lambda x: bn(x), x))
# sync_print(bn(x))
-r requirements.txt
# used by docs
sphinx-autobuild
sphinx_rtd_theme
matplotlib
ipykernel
\ No newline at end of file
--extra-index-url https://download.pytorch.org/whl/cu118
torch==2.1.1+cu118
torchvision==0.16.1+cu118
torchaudio==2.1.1+cu118
--extra-index-url https://data.pyg.org/whl/torch-2.1.0+cu118.html
torch_geometric==2.4.0
pyg_lib==0.3.1+pt21cu118
torch_scatter==2.1.2+pt21cu118
torch_sparse==0.6.18+pt21cu118
torch_cluster==1.6.3+pt21cu118
torch_spline_conv==1.2.2+pt21cu118
ogb
tqdm
\ No newline at end of file
...@@ -2,14 +2,17 @@ import torch ...@@ -2,14 +2,17 @@ import torch
import logging import logging
__version__ = "0.1.0"
try: try:
from .lib import libstarrygl_ops as ops from .lib import libstarrygl as ops
except Exception as e: except Exception as e:
logging.error(e) logging.error(e)
logging.error("unable to import libstarrygl.so, some features may not be available.") logging.error("unable to import libstarrygl.so, some features may not be available.")
try: try:
from .lib import libstarrygl_ops_sampler as sampler_ops from .lib import libstarrygl_sampler as sampler_ops
except Exception as e: except Exception as e:
logging.error(e) logging.error(e)
logging.error("unable to import libstarrygl_sampler.so, some features may not be available.") logging.error("unable to import libstarrygl_sampler.so, some features may not be available.")
\ No newline at end of file
\ No newline at end of file
...@@ -20,6 +20,9 @@ __all__ = [ ...@@ -20,6 +20,9 @@ __all__ = [
] ]
Strings = Sequence[str]
OptStrings = Optional[Strings]
class GraphData: class GraphData:
def __init__(self, def __init__(self,
edge_indices: Union[Tensor, Dict[Tuple[str, str, str], Tensor]], edge_indices: Union[Tensor, Dict[Tuple[str, str, str], Tensor]],
...@@ -168,33 +171,99 @@ class GraphData: ...@@ -168,33 +171,99 @@ class GraphData:
return g return g
@staticmethod @staticmethod
def load_partition(root: str, part_id: int, num_parts: int, algo: str = "metis") -> 'GraphData': def load_partition(
p = Path(root).expanduser().resolve() / f"{algo}_{num_parts}" / f"{part_id:03d}" root: str,
part_id: int,
num_parts: int,
algorithm: str = "metis",
) -> 'GraphData':
p = Path(root).expanduser().resolve() / f"{algorithm}_{num_parts}" / f"{part_id:03d}"
return torch.load(p.__str__()) return torch.load(p.__str__())
def save_partition(self, root: str, num_parts: int, algo: str = "metis"): def save_partition(self,
root: str,
num_parts: int,
node_weight: Optional[str] = None,
edge_weight: Optional[str] = None,
include_node_attrs: Optional[Sequence[str]] = None,
include_edge_attrs: Optional[Sequence[str]] = None,
include_meta_attrs: Optional[Sequence[str]] = None,
ignore_node_attrs: Optional[Sequence[str]] = None,
ignore_edge_attrs: Optional[Sequence[str]] = None,
ignore_meta_attrs: Optional[Sequence[str]] = None,
algorithm: str = "metis",
partition_kwargs = None,
):
assert not self.is_heterogeneous, "only support homomorphic graph" assert not self.is_heterogeneous, "only support homomorphic graph"
num_nodes: int = self.node().num_nodes num_nodes: int = self.node().num_nodes
edge_index: Tensor = self.edge_index() edge_index: Tensor = self.edge_index()
logging.info(f"running partition algorithm: {algo}") logging.info(f"running partition algorithm: {algorithm}")
if algo == "metis": partition_kwargs = partition_kwargs or {}
node_parts = metis_partition(edge_index, num_nodes, num_parts)
elif algo == "mt-metis": if node_weight is not None:
node_parts = mt_metis_partition(edge_index, num_nodes, num_parts) node_weight = self.node()[node_weight]
elif algo == "random":
node_parts = random_partition(edge_index, num_nodes, num_parts) if edge_weight is not None:
edge_weight = self.edge()[edge_weight]
if algorithm == "metis":
node_parts = metis_partition(
edge_index,
num_nodes, num_parts,
node_weight=node_weight,
edge_weight=edge_weight,
**partition_kwargs,
)
elif algorithm == "mt-metis":
node_parts = mt_metis_partition(
edge_index,
num_nodes, num_parts,
node_weight=node_weight,
edge_weight=edge_weight,
**partition_kwargs,
)
elif algorithm == "random":
node_parts = random_partition(
edge_index,
num_nodes, num_parts,
**partition_kwargs,
)
else: else:
raise ValueError(f"unknown partition algorithm: {algo}") raise ValueError(f"unknown partition algorithm: {algorithm}")
root_path = Path(root).expanduser().resolve() root_path = Path(root).expanduser().resolve()
base_path = root_path / f"{algo}_{num_parts}" base_path = root_path / f"{algorithm}_{num_parts}"
if base_path.exists(): if base_path.exists():
logging.warning(f"directory '{base_path.__str__()}' exists, and will be removed.") logging.warning(f"directory '{base_path.__str__()}' exists, and will be removed.")
shutil.rmtree(base_path.__str__()) shutil.rmtree(base_path.__str__())
base_path.mkdir(parents=True) base_path.mkdir(parents=True)
if include_node_attrs is None:
include_node_attrs = self.node().keys()
if include_edge_attrs is None:
include_edge_attrs = self.edge().keys()
if include_meta_attrs is None:
include_meta_attrs = self.meta().keys()
if ignore_node_attrs is None:
ignore_node_attrs = set()
else:
ignore_node_attrs = set(ignore_node_attrs)
if ignore_edge_attrs is None:
ignore_edge_attrs = set()
else:
ignore_edge_attrs = set(ignore_edge_attrs)
if ignore_meta_attrs is None:
ignore_meta_attrs = set()
else:
ignore_meta_attrs = set(ignore_meta_attrs)
for i in range(num_parts): for i in range(num_parts):
npart_mask = node_parts == i npart_mask = node_parts == i
...@@ -213,13 +282,19 @@ class GraphData: ...@@ -213,13 +282,19 @@ class GraphData:
raw_dst_ids=raw_dst_ids, raw_dst_ids=raw_dst_ids,
) )
for key in self.node().keys(): for key in include_node_attrs:
if key in ignore_node_attrs:
continue
g.node("dst")[key] = self.node()[key][npart_mask] g.node("dst")[key] = self.node()[key][npart_mask]
for key in self.edge().keys(): for key in include_edge_attrs:
if key in ignore_edge_attrs:
continue
g.edge()[key] = self.edge()[key][epart_mask] g.edge()[key] = self.edge()[key][epart_mask]
for key in self.meta().keys(): for key in include_meta_attrs:
if key in ignore_meta_attrs:
continue
g.meta()[key] = self.meta()[key] g.meta()[key] = self.meta()[key]
logging.info(f"saving partition data: {i+1}/{num_parts}") logging.info(f"saving partition data: {i+1}/{num_parts}")
......
...@@ -16,7 +16,8 @@ class TensorAccessor: ...@@ -16,7 +16,8 @@ class TensorAccessor:
self._data = data self._data = data
self._ctx = DistributedContext.get_default_context() self._ctx = DistributedContext.get_default_context()
self._rref = rpc.RRef(data) if self._ctx._use_rpc is True:
self._rref = rpc.RRef(data)
self.stream = torch.cuda.Stream() self.stream = torch.cuda.Stream()
@property @property
...@@ -141,14 +142,15 @@ class DistIndex: ...@@ -141,14 +142,15 @@ class DistIndex:
class DistributedTensor: class DistributedTensor:
def __init__(self, data: Tensor) -> None: def __init__(self, data: Tensor) -> None:
self.accessor = TensorAccessor(data) self.accessor = TensorAccessor(data)
self.rrefs = self.accessor.all_gather_rrefs() if hasattr(self.accessor, '_rref'):
self.rrefs = self.accessor.all_gather_rrefs()
local_sizes = [] local_sizes = []
for rref in self.rrefs: for rref in self.rrefs:
n = self.ctx.remote_call(Tensor.size, rref, dim=0).wait() n = self.ctx.remote_call(Tensor.size, rref, dim=0).wait()
local_sizes.append(n) local_sizes.append(n)
self._num_nodes: int = sum(local_sizes) self._num_nodes: int = sum(local_sizes)
self._num_part_nodes: Tuple[int,...] = tuple(int(s) for s in local_sizes) self._num_part_nodes: Tuple[int,...] = tuple(int(s) for s in local_sizes)
self._part_id: int = self.accessor.ctx.rank self._part_id: int = self.accessor.ctx.rank
self._num_parts: int = self.accessor.ctx.world_size self._num_parts: int = self.accessor.ctx.world_size
...@@ -166,7 +168,7 @@ class DistributedTensor: ...@@ -166,7 +168,7 @@ class DistributedTensor:
return self._num_nodes return self._num_nodes
@property @property
def num_part_nodes(self) -> tuple[int,...]: def num_part_nodes(self) -> Tuple[int,...]:
return self._num_part_nodes return self._num_part_nodes
@property @property
......
from .convs import GCNConv, GATConv, GINConv
# from .convs import ShrinkGCNConv, ShrinkGATConv, ShrinkGINConv
# from .convs import ShrinkHelper
# from .basic_gnn import BasicGNN, BasicLayerOptions, BasicInputOptions, BasicStraightOptions
# class ShrinkGCN(BasicGNN):
# def init_conv(self, in_channels: int, out_channels: int, **kwargs):
# return ShrinkGCNConv(in_channels, out_channels, **kwargs)
# class ShrinkGAT(BasicGNN):
# def init_conv(self, in_channels: int, out_channels: int, **kwargs):
# return ShrinkGATConv(in_channels, out_channels, **kwargs)
# class ShrinkGIN(BasicGNN):
# def init_conv(self, in_channels: int, out_channels: int, **kwargs):
# return ShrinkGINConv(in_channels, out_channels, **kwargs)
\ No newline at end of file
# import torch
# import torch.nn as nn
# import torch.distributed as dist
# from torch import Tensor
# from typing import *
# from starrygl.loader import BatchHandle
# class BaseLayer(nn.Module):
# def __init__(self) -> None:
# super().__init__()
# def forward(self, x: Tensor, edge_index: Tensor, edge_attr: Optional[Tensor] = None) -> Tensor:
# return x
# def update_forward(self, handle: BatchHandle, edge_index: Tensor, edge_attr: Optional[Tensor] = None):
# x = handle.fetch_feat()
# with torch.no_grad():
# x = self.forward(x, edge_index, edge_attr)
# handle.update_feat(x)
# def block_backward(self, handle: BatchHandle, edge_index: Tensor, edge_attr: Optional[Tensor] = None):
# x = handle.fetch_feat().requires_grad_()
# g = handle.fetch_grad()
# self.forward(x, edge_index, edge_attr).backward(g)
# handle.accumulate_grad(x.grad)
# x.grad = None
# def all_reduce_grad(self):
# for p in self.parameters():
# if p.grad is not None:
# dist.all_reduce(p.grad, op=dist.ReduceOp.SUM)
# class BaseModel(nn.Module):
# def __init__(self,
# num_features: int,
# layers: List[int],
# prev_layer: bool = False,
# post_layer: bool = False,
# ) -> None:
# super().__init__()
# def init_prev_layer(self) -> Tensor:
# pass
# def init_post_layer(self) -> Tensor:
# pass
# def init_conv_layer(self) -> Tensor:
# pass
\ No newline at end of file
# import torch
# import torch.nn as nn
# import torch.nn.functional as F
# from torch import Tensor
# from typing import *
# import copy
# import inspect
# from torch_geometric.nn import JumpingKnowledge
# from torch_geometric.nn.resolver import activation_resolver, normalization_resolver
# from dataclasses import dataclass
# from ..core.cache import NodeProbe
# from ..graph.distgraph import DistGraph
# # from ..core.gather import Gather
# # from ..core.straight import Straight
# @dataclass
# class BasicInputOptions:
# weight: bool = True
# bias: bool = True
# gather: bool = True
# gather_first: bool = False
# dropout: float = 0.0
# act: Optional[str] = None
# act_kwargs: Optional[Dict[str, Any]] = None
# act_first: bool = False
# norm: Optional[str] = None
# norm_kwargs: Optional[Dict[str, Any]] = None
# straight_enabled: bool = True
# straight_num_samples: Optional[int] = None
# @dataclass
# class BasicLayerOptions:
# in_channels: int
# hidden_channels: int
# num_layers: int
# out_channels: Optional[int] = None
# gather_beta: float = 1.0
# dropout: float = 0.0
# act: Optional[str] = "relu"
# act_kwargs: Optional[Dict[str, Any]] = None
# act_first: bool = False
# norm: Optional[str] = None
# norm_kwargs: Optional[Dict[str, Any]] = None
# jk_mode: Optional[str] = None
# @dataclass
# class BasicStraightOptions:
# enabled: bool = False
# num_samples: Optional[int] = None
# beta: float = 1.0
# class BasicGNN(nn.Module):
# def __init__(
# self,
# g: DistGraph,
# layer_options: BasicLayerOptions,
# input_options: BasicInputOptions = BasicInputOptions(),
# straight_options: BasicStraightOptions = BasicStraightOptions(),
# **kwargs,
# ):
# super().__init__()
# num_samples = straight_options.num_samples or g.dst_size
# prev_straight = None
# self.in_channels = in_channels = layer_options.in_channels
# self.hidden_channels = hidden_channels = layer_options.hidden_channels
# self.num_layers = num_layers = layer_options.num_layers
# # default out_channels is hidden_channels
# self.out_channels = out_channels = hidden_channels \
# if layer_options.out_channels is None else layer_options.out_channels
# self.gather_beta = gather_beta = layer_options.gather_beta
# self.layer_options = layer_options
# self.input_options = input_options
# self.straight_options = straight_options
# # initialize input layer
# if input_options.weight:
# self.lin_x = nn.Linear(in_channels, hidden_channels, bias=input_options.bias)
# in_channels = hidden_channels
# if straight_options.enabled and input_options.straight_enabled and not input_options.gather_first:
# sns = input_options.straight_num_samples or g.dst_size
# self.straight_x = Straight(g.dst_size, sns, beta=straight_options.beta, prev=prev_straight)
# prev_straight = [self.straight_x]
# if input_options.gather:
# self.gather_x = Gather(g.src_size, in_channels, beta=gather_beta)
# if input_options.act is not None:
# self.act_x = activation_resolver(
# input_options.act, **(input_options.act_kwargs or {}))
# if input_options.norm is not None:
# self.norm_x = normalization_resolver(
# input_options.norm, in_channels, **(input_options.norm_kwargs or {}))
# # initialize activation layers
# if layer_options.act is not None:
# self.acts = nn.ModuleList()
# for _ in range(num_layers - 1):
# self.acts.append(activation_resolver(
# layer_options.act, **(layer_options.act_kwargs or {})))
# if layer_options.jk_mode is not None:
# self.acts.append(activation_resolver(
# layer_options.act, **(layer_options.act_kwargs or {})))
# # initialize normalization layers
# if layer_options.norm is not None:
# self.norms = nn.ModuleList()
# for _ in range(num_layers - 1):
# self.norms.append(normalization_resolver(
# layer_options.norm, hidden_channels, **(layer_options.norm_kwargs or {})))
# if layer_options.jk_mode is not None:
# self.norms.append(normalization_resolver(
# layer_options.norm, hidden_channels, **(layer_options.norm_kwargs or {})))
# # initialize straight layers
# if straight_options.enabled:
# self.straights = nn.ModuleList()
# for _ in range(num_layers):
# prev_straight = [Straight(g.dst_size, num_samples, beta=straight_options.beta, prev=prev_straight)]
# self.straights.append(prev_straight[0])
# # if layer_options.jk_mode is not None:
# # prev_straight = [Straight(g.dst_size, num_samples, beta=straight_options.beta, prev=prev_straight)]
# # self.straights.append(prev_straight[0])
# # initialize gather and conv layers
# self.convs = nn.ModuleList()
# self.gathers = nn.ModuleList()
# for _ in range(num_layers - 1):
# self.convs.append(
# self.init_conv(in_channels, hidden_channels, **kwargs))
# self.gathers.append(Gather(g.src_size, hidden_channels, beta=gather_beta))
# in_channels = hidden_channels
# if layer_options.jk_mode is None:
# self.convs.append(
# self.init_conv(in_channels, out_channels, **kwargs))
# self.gathers.append(Gather(g.dst_size, out_channels)) # only fuse embeddings
# else:
# self.convs.append(
# self.init_conv(in_channels, hidden_channels, **kwargs))
# self.gathers.append(Gather(g.dst_size, hidden_channels, beta=gather_beta)) # only fuse embeddings
# if layer_options.jk_mode != "last":
# self.jk = JumpingKnowledge(layer_options.jk_mode, hidden_channels, num_layers)
# if layer_options.jk_mode == "cat":
# jk_channels = num_layers * hidden_channels
# else:
# jk_channels = hidden_channels
# self.lin_jk = nn.Linear(jk_channels, out_channels)
# self.reset_parameters()
# def init_conv(self,
# in_channels: int,
# out_channels: int,
# **kwargs
# ) -> nn.Module:
# raise NotImplementedError
# def reset_parameters(self):
# if hasattr(self, "lin_x"):
# self.lin_x.reset_parameters()
# if hasattr(self, "straight_x"):
# self.straight_x.reset_parameters()
# if hasattr(self, "gather_x"):
# self.gather_x.reset_parameters()
# if hasattr(self, "norm_x"):
# self.norm_x.reset_parameters()
# if hasattr(self, "norms"):
# for norm in self.norms:
# norm.reset_parameters()
# if hasattr(self, "straights"):
# for straight in self.straights:
# straight.reset_parameters()
# for conv in self.convs:
# conv.reset_parameters()
# for gather in self.gathers:
# gather.reset_parameters()
# if hasattr(self, "jk"):
# self.jk.reset_parameters()
# if hasattr(self, "lin_jk"):
# self.lin_jk.reset_parameters()
# def forward(
# self,
# g: DistGraph,
# ):
# from ..utils.printer import main_print, sync_print
# x = g.ndata["x"]
# async_op = g.args.get("async_op", False)
# # input layer
# if hasattr(self, "straight_x"):
# dst_idx, _ = self.straight_x.pop_next_shrink_helper()
# else:
# dst_idx = None
# if dst_idx is not None:
# x = x[dst_idx]
# if hasattr(self, "gather_x") and self.input_options.gather_first:
# x, _ = self.gather_x(x, dst_idx, g.route, async_op=async_op)
# if self.input_options.dropout != 0.0:
# x = F.dropout(x, p=self.input_options.dropout, training=self.training)
# if hasattr(self, "lin_x"):
# x = self.lin_x(x)
# if hasattr(self, "act_x") and self.input_options.act_first:
# x = self.act_x(x)
# if hasattr(self, "norm_x"):
# x = self.norm_x(x)
# if hasattr(self, "act_x") and not self.input_options.act_first:
# x = self.act_x(x)
# # straight sampler
# if hasattr(self, "straight_x"):
# # sync_print(f"{x.size()} - {'' if dst_idx is None else dst_idx.size()}")
# x = self.straight_x(x, dst_idx, g)
# # gather features
# if hasattr(self, "gather_x") and not self.input_options.gather_first:
# x, _ = self.gather_x(x, dst_idx, g.route, async_op=async_op)
# # conv layers
# xs: List[Tensor] = []
# for i in range(self.num_layers):
# if self.layer_options.dropout != 0.0:
# x = F.dropout(x, p=self.layer_options.dropout, training=self.training)
# if hasattr(self, "straights"):
# straight: Straight = self.straights[i]
# dst_idx, sh = straight.pop_next_shrink_helper()
# with g.scoped_manager():
# g.ndata["x"] = x
# x = self.convs[i](g, sh=sh, dst_idx=dst_idx)
# if i == self.num_layers - 1 and not hasattr(self, "jk"):
# x = self.gathers[i].fuse_embeddings(x, dst_idx, inplace=True)
# break
# if hasattr(self, "acts") and self.layer_options.act_first:
# x = self.acts[i](x)
# if hasattr(self, "norms"):
# x = self.norms[i](x)
# if hasattr(self, "acts") and not self.layer_options.act_first:
# x = self.acts[i](x)
# if hasattr(self, "straights"):
# x = self.straights[i](x, dst_idx, g)
# if i == self.num_layers - 1:
# x = self.gathers[i].fuse_embeddings(x, dst_idx, inplace=True)
# else:
# x, _ = self.gathers[i](x, dst_idx, g.route, async_op=async_op)
# if hasattr(self, "jk"):
# xs.append(x[:g.dst_size])
# x = self.jk(xs) if hasattr(self, "jk") else x
# x = self.lin_jk(x) if hasattr(self, "lin_jk") else x
# # sync_print(f"out: {x.size()}")
# return x
from .gcn_conv import GCNConv
from .gat_conv import GATConv
from .gin_conv import GINConv
# from .shrink_gcn_conv import ShrinkGCNConv
# from .shrink_gat_conv import ShrinkGATConv
# from .shrink_gin_conv import ShrinkGINConv
# from .utils import ShrinkHelper
\ No newline at end of file
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.utils import softmax
from torch_scatter import scatter_sum
from torch import Tensor
from typing import *
from starrygl.graph import DistGraph
class GATConv(nn.Module):
def __init__(self,
in_channels: int,
out_channels: int,
heads: int = 1,
concat: bool = False,
negative_slope: float = 0.2,
dropout: float = 0.0,
edge_dim: Optional[int] = None,
bias: bool = True,
**kwargs
) -> None:
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.heads = heads
self.concat = concat
self.negative_slope = negative_slope
self.dropout = dropout
self.edge_dim = edge_dim
self.weight = nn.Parameter(torch.Tensor(in_channels, heads * out_channels))
self.att_src = nn.Parameter(torch.Tensor(1, heads, out_channels))
self.att_dst = nn.Parameter(torch.Tensor(1, heads, out_channels))
if edge_dim is not None:
self.lin_edge = nn.Parameter(torch.Tensor(edge_dim, heads * out_channels))
self.att_edge = nn.Parameter(torch.Tensor(1, heads, out_channels))
if bias and concat:
self.bias = nn.Parameter(torch.Tensor(heads * out_channels))
elif bias and not concat:
self.bias = nn.Parameter(torch.Tensor(out_channels))
else:
self.bias = None
self.reset_parameters()
def reset_parameters(self):
nn.init.xavier_normal_(self.weight)
nn.init.xavier_normal_(self.att_src)
nn.init.xavier_normal_(self.att_dst)
if self.edge_dim is not None:
nn.init.xavier_normal_(self.lin_edge)
nn.init.xavier_normal_(self.att_edge)
if self.bias is not None:
nn.init.zeros_(self.bias)
def forward(self, g: DistGraph) -> Tensor:
H, C = self.heads, self.out_channels
x = g.ndata["x"]
edge_index = g.edge_index
x = (x @ self.weight).view(-1, H, C)
alpha_j = (x * self.att_src).sum(dim=-1)
alpha_j = alpha_j[edge_index[0]]
alpha_i = (x * self.att_dst).sum(dim=-1)
alpha_i = alpha_i[edge_index[1]]
if self.edge_dim is not None:
edge_attr = g.edata["edge_attr"]
if edge_attr.dim() == 1:
edge_attr = edge_attr.view(-1, 1)
e = (edge_attr @ self.lin_edge).view(-1, H, C)
alpha_e = (e * self.att_edge).sum(dim=-1)
alpha = alpha_i + alpha_j + alpha_e
else:
alpha = alpha_i + alpha_j
alpha = F.leaky_relu(alpha, self.negative_slope)
alpha = softmax(
src=alpha,
index=edge_index[1],
num_nodes=g.dst_size,
)
alpha = F.dropout(alpha, p=self.dropout, training=self.training)
x = x[edge_index[0]] * alpha.view(-1, H, 1)
x = scatter_sum(x, edge_index[1], dim=0, dim_size=g.dst_size)
if self.concat:
x = x.view(-1, H * C)
else:
x = x.mean(dim=1)
if self.bias is not None:
x += self.bias
return x
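# --- Illustration only, not part of this commit: a minimal, self-contained sketch of the
# edge-softmax + scatter aggregation pattern that GATConv.forward above relies on.
# The toy sizes and edge_index below are made up; only torch, torch_scatter and
# torch_geometric.utils.softmax are assumed, as in the layer itself.
import torch
from torch_geometric.utils import softmax
from torch_scatter import scatter_sum

num_src, num_dst, heads, channels = 4, 2, 1, 8
x = torch.randn(num_src, heads, channels)                 # per-head source features
edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])   # src -> dst
alpha = torch.randn(edge_index.size(1), heads)            # unnormalized edge scores
alpha = softmax(src=alpha, index=edge_index[1], num_nodes=num_dst)  # softmax per destination
out = scatter_sum(x[edge_index[0]] * alpha.unsqueeze(-1),
                  edge_index[1], dim=0, dim_size=num_dst)  # attention-weighted sum per dst
print(out.shape)  # torch.Size([2, 1, 8])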
import torch
import torch.nn as nn
from torch_scatter import scatter_sum
from torch import Tensor
from typing import *
from starrygl.graph import DistGraph
class GCNConv(nn.Module):
def __init__(self,
in_channels: int,
out_channels: int,
bias: bool = True,
**kwargs
) -> None:
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.weight = nn.Parameter(torch.Tensor(in_channels, out_channels))
if bias:
self.bias = nn.Parameter(torch.Tensor(out_channels))
else:
self.bias = None
self.reset_parameters()
def reset_parameters(self):
nn.init.xavier_normal_(self.weight)
if self.bias is not None:
nn.init.zeros_(self.bias)
def forward(self, g: DistGraph) -> Tensor:
x = g.ndata["x"]
gcn_norm = g.edata["gcn_norm"].view(-1, 1)
edge_index = g.edge_index
x = x @ self.weight
x = x[edge_index[0]] * gcn_norm
x = scatter_sum(x, edge_index[1], dim=0, dim_size=g.dst_size)
if self.bias is not None:
x += self.bias
return x
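# --- Hedged sketch, not from this commit: GCNConv.forward above expects a precomputed
# per-edge weight in g.edata["gcn_norm"]. One common way to build such a weight is the
# symmetric normalization 1/sqrt(deg(src) * deg(dst)); how StarryGL actually fills
# g.edata["gcn_norm"] may differ. compute_gcn_norm, src_size and dst_size are
# hypothetical names introduced here for illustration.
import torch
from torch_scatter import scatter_sum

def compute_gcn_norm(edge_index: torch.Tensor, src_size: int, dst_size: int) -> torch.Tensor:
    ones = torch.ones(edge_index.size(1), device=edge_index.device)
    out_deg = scatter_sum(ones, edge_index[0], dim=0, dim_size=src_size)  # per-source degree
    in_deg = scatter_sum(ones, edge_index[1], dim=0, dim_size=dst_size)   # per-destination degree
    norm = out_deg[edge_index[0]].pow(-0.5) * in_deg[edge_index[1]].pow(-0.5)
    norm[torch.isinf(norm)] = 0.0  # defensive guard; endpoints of an edge always have degree >= 1
    return norm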
import torch
import torch.nn as nn
from torch_scatter import scatter_sum
from torch import Tensor
from typing import *
from starrygl.graph import DistGraph
class GINConv(nn.Module):
def __init__(self,
in_channels: int,
out_channels: int,
mlp_channels: Optional[int] = None,
eps: float = 0,
train_eps: bool = False,
**kwargs
) -> None:
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
if mlp_channels is None:
mlp_channels = in_channels + out_channels
self.mlp_channels = mlp_channels
self.initial_eps = eps
self.nn = nn.Sequential(
nn.Linear(in_channels, mlp_channels),
nn.ReLU(),
nn.Linear(mlp_channels, out_channels),
)
if train_eps:
self.eps = torch.nn.Parameter(torch.tensor([eps]))
else:
self.register_buffer("eps", torch.tensor([eps]))
self.reset_parameters()
def reset_parameters(self):
for name, param in self.nn.named_parameters():
if "weight" in name:
nn.init.xavier_normal_(param)
if "bias" in name:
nn.init.zeros_(param)
nn.init.constant_(self.eps, self.initial_eps)
def forward(self, g: DistGraph) -> Tensor:
x = g.ndata["x"]
edge_index = g.edge_index
z = x[edge_index[0]]
z = scatter_sum(z, edge_index[1], dim=0, dim_size=g.dst_size)
x = z + (1 + self.eps) * x[:g.dst_size]
return self.nn(x)
\ No newline at end of file
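# --- Toy check (illustration only, not part of the commit) of the GIN update implemented
# above: out = MLP((1 + eps) * x_dst + sum of neighbor features). The MLP is omitted here
# to keep the arithmetic visible; node values and edges are made up.
import torch
from torch_scatter import scatter_sum

x = torch.tensor([[1.], [2.], [3.]])               # 3 nodes; the first 2 are destinations
edge_index = torch.tensor([[2, 2, 1], [0, 1, 0]])  # src -> dst
dst_size, eps = 2, 0.0
z = scatter_sum(x[edge_index[0]], edge_index[1], dim=0, dim_size=dst_size)
out = z + (1 + eps) * x[:dst_size]
print(out)  # tensor([[6.], [5.]])  -> dst 0: (3 + 2) + 1, dst 1: 3 + 2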
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.utils import softmax
from torch_scatter import scatter_sum
from torch import Tensor
from typing import *
from starrygl.graph import DistGraph
class GATConv(nn.Module):
def __init__(self,
in_channels: int,
out_channels: int,
heads: int = 1,
concat: bool = False,
negative_slope: float = 0.2,
dropout: float = 0.0,
edge_dim: Optional[int] = None,
bias: bool = True,
**kwargs
) -> None:
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.heads = heads
self.concat = concat
self.negative_slope = negative_slope
self.dropout = dropout
self.edge_dim = edge_dim
self.weight = nn.Parameter(torch.Tensor(in_channels, heads * out_channels))
self.att_src = nn.Parameter(torch.Tensor(1, heads, out_channels))
self.att_dst = nn.Parameter(torch.Tensor(1, heads, out_channels))
if edge_dim is not None:
self.lin_edge = nn.Parameter(torch.Tensor(edge_dim, heads * out_channels))
self.att_edge = nn.Parameter(torch.Tensor(1, heads, out_channels))
if bias and concat:
self.bias = nn.Parameter(torch.Tensor(heads * out_channels))
elif bias and not concat:
self.bias = nn.Parameter(torch.Tensor(out_channels))
else:
self.bias = None
self.reset_parameters()
def reset_parameters(self):
nn.init.xavier_normal_(self.weight)
nn.init.xavier_normal_(self.att_src)
nn.init.xavier_normal_(self.att_dst)
if self.edge_dim is not None:
nn.init.xavier_normal_(self.lin_edge)
nn.init.xavier_normal_(self.att_edge)
if self.bias is not None:
nn.init.zeros_(self.bias)
def forward(self, g: DistGraph, x: Tensor, edge_attr: Optional[Tensor] = None) -> Tensor:
H, C = self.heads, self.out_channels
edge_index = g.edge_index
x = (x @ self.weight).view(-1, H, C)
alpha_j = (x * self.att_src).sum(dim=-1)
alpha_j = alpha_j[edge_index[0]]
alpha_i = (x * self.att_dst).sum(dim=-1)
alpha_i = alpha_i[edge_index[1]]
if self.edge_dim is not None:
if edge_attr.dim() == 1:
edge_attr = edge_attr.view(-1, 1)
e = (edge_attr @ self.lin_edge).view(-1, H, C)
alpha_e = (e * self.att_edge).sum(dim=-1)
alpha = alpha_i + alpha_j + alpha_e
else:
alpha = alpha_i + alpha_j
alpha = F.leaky_relu(alpha, self.negative_slope)
alpha = softmax(
src=alpha,
index=edge_index[1],
num_nodes=g.dst_size,
)
alpha = F.dropout(alpha, p=self.dropout, training=self.training)
x = x[edge_index[0]] * alpha.view(-1, H, 1)
x = scatter_sum(x, edge_index[1], dim=0, dim_size=g.dst_size)
if self.concat:
x = x.view(-1, H * C)
else:
x = x.mean(dim=1)
if self.bias is not None:
x += self.bias
return x
# import torch
# import torch.nn as nn
# import torch.nn.functional as F
# from torch_geometric.utils import softmax
# from torch_scatter import scatter_sum
# from torch import Tensor
# from typing import *
# from starrygl.graph import DistGraph
# from .gat_conv import GATConv
# from .utils import ShrinkHelper
# class ShrinkGATConv(GATConv):
# def __init__(self,
# in_channels: int,
# out_channels: int,
# heads: int = 1,
# concat: bool = False,
# negative_slope: float = 0.2,
# dropout: float = 0.0,
# edge_dim: Optional[int] = None,
# bias: bool = True,
# **kwargs
# ) -> None:
# super().__init__(
# in_channels=in_channels,
# out_channels=out_channels,
# heads=heads,
# concat=concat,
# negative_slope=negative_slope,
# dropout=dropout,
# edge_dim=edge_dim,
# bias=bias,
# **kwargs
# )
# def forward(self,
# g: DistGraph,
# sh: Optional[ShrinkHelper] = None,
# dst_idx: Optional[Tensor] = None,
# ) -> Tensor:
# if sh is None and dst_idx is None:
# return super().forward(g)
# if sh is None:
# sh = ShrinkHelper(g, dst_idx)
# H, C = self.heads, self.out_channels
# x = g.ndata["x"]
# src_x = x[sh.src_idx]
# dst_x = x[sh.dst_idx]
# edge_index = sh.edges
# src_x = (src_x @ self.weight).view(-1, H, C)
# dst_x = (dst_x @ self.weight).view(-1, H, C)
# alpha_j = (src_x * self.att_src).sum(dim=-1)
# alpha_j = alpha_j[edge_index[0]]
# alpha_i = (dst_x * self.att_dst).sum(dim=-1)
# alpha_i = alpha_i[edge_index[1]]
# if self.edge_dim is not None:
# edge_attr = g.edata["edge_attr"]
# edge_attr = edge_attr[sh.edge_idx]
# if edge_attr.dim() == 1:
# edge_attr = edge_attr.view(-1, 1)
# e = (edge_attr @ self.lin_edge).view(-1, H, C)
# alpha_e = (e * self.att_edge).sum(dim=-1)
# alpha = alpha_i + alpha_j + alpha_e
# else:
# alpha = alpha_i + alpha_j
# alpha = F.leaky_relu(alpha, self.negative_slope)
# alpha = softmax(
# src=alpha,
# index=edge_index[1],
# num_nodes=sh.dst_size,
# )
# alpha = F.dropout(alpha, p=self.dropout, training=self.training)
# x = src_x[edge_index[0]] * alpha.view(-1, H, 1)
# x = scatter_sum(x, edge_index[1], dim=0, dim_size=sh.dst_size)
# if self.concat:
# x = x.view(-1, H * C)
# else:
# x = x.mean(dim=1)
# if self.bias is not None:
# x += self.bias
# return x
# import torch
# import torch.nn as nn
# from torch_scatter import scatter_sum
# from torch import Tensor
# from typing import *
# from starrygl.graph import DistGraph
# from .gcn_conv import GCNConv
# from .utils import ShrinkHelper
# class ShrinkGCNConv(GCNConv):
# def __init__(self,
# in_channels: int,
# out_channels: int,
# bias: bool = True,
# **kwargs
# ) -> None:
# super().__init__(
# in_channels=in_channels,
# out_channels=out_channels,
# bias=bias,
# **kwargs
# )
# def forward(self,
# g: DistGraph,
# sh: Optional[ShrinkHelper] = None,
# dst_idx: Optional[Tensor] = None,
# ) -> Tensor:
# if sh is None and dst_idx is None:
# return super().forward(g)
# if sh is None:
# sh = ShrinkHelper(g, dst_idx)
# x = g.ndata["x"]
# gcn_norm = g.edata["gcn_norm"].view(-1, 1)
# x = x[sh.src_idx]
# gcn_norm = gcn_norm[sh.edge_idx]
# edge_index = sh.edges
# x = x @ self.weight
# x = x[edge_index[0]] * gcn_norm
# x = scatter_sum(x, edge_index[1], dim=0, dim_size=sh.dst_size)
# if self.bias is not None:
# x += self.bias
# return x
# import torch
# import torch.nn as nn
# from torch_scatter import scatter_sum
# from torch import Tensor
# from typing import *
# from starrygl.graph import DistGraph
# from .gin_conv import GINConv
# from .utils import ShrinkHelper
# class ShrinkGINConv(GINConv):
# def __init__(self,
# in_channels: int,
# out_channels: int,
# mlp_channels: Optional[int] = None,
# eps: float = 0,
# train_eps: bool = False,
# **kwargs
# ) -> None:
# super().__init__(
# in_channels=in_channels,
# out_channels=out_channels,
# mlp_channels=mlp_channels,
# eps=eps,
# train_eps=train_eps,
# **kwargs
# )
# def forward(self,
# g: DistGraph,
# sh: Optional[ShrinkHelper] = None,
# dst_idx: Optional[Tensor] = None,
# ) -> Tensor:
# if sh is None and dst_idx is None:
# return super().forward(g)
# if sh is None:
# sh = ShrinkHelper(g, dst_idx)
# x = g.ndata["x"]
# src_x = x[sh.src_idx]
# dst_x = x[sh.dst_idx]
# edge_index = sh.edges
# z = scatter_sum(src_x[edge_index[0]], index=edge_index[1], dim=0, dim_size=sh.dst_size)
# x = z + (1 + self.eps) * dst_x
# return self.nn(x)
\ No newline at end of file
# import torch
# from torch import Tensor
# from typing import *
# class ShrinkHelper:
# def __init__(self, g, dst_idx: Tensor) -> None:
# from starrygl.graph import DistGraph
# g: DistGraph = g
# self.device = dst_idx.device
# dst_m = torch.zeros(g.dst_size, dtype=torch.bool, device=self.device)
# dst_m.index_fill_(0, dst_idx, 1)
# edge_idx = torch.where(dst_m[g.edge_index[1]])[0]
# edge_index = g.edge_index[:, edge_idx]
# src_idx = edge_index[0]
# src_m = torch.zeros(g.src_size, dtype=torch.bool, device=self.device)
# src_m.index_fill_(0, src_idx, 1)
# src_idx = torch.where(src_m)[0]
# imp = torch.empty(max(g.src_size, g.dst_size), dtype=torch.long, device=self.device)
# imp[src_idx] = torch.arange(src_idx.size(0), dtype=torch.long, device=self.device)
# src = imp[edge_index[0]]
# imp[dst_idx] = torch.arange(dst_idx.size(0), dtype=torch.long, device=self.device)
# dst = imp[edge_index[1]]
# self.src_idx = src_idx
# self.dst_idx = dst_idx
# self.edge_idx = edge_idx
# self.edges = torch.vstack([src, dst])
# @property
# def src_size(self) -> int:
# return self.src_idx.size(0)
# @property
# def dst_size(self) -> int:
# return self.dst_idx.size(0)
\ No newline at end of file
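# --- Hedged illustration (not part of the commit) of the index-compaction trick used by the
# commented-out ShrinkHelper above: the global src/dst ids of a selected edge subset are
# remapped to contiguous local indices through an inverse-permutation lookup table.
# The edge list and the size of the global id space below are invented for the example.
import torch

edge_index = torch.tensor([[5, 7, 5], [2, 2, 9]])   # global node ids
src_idx = edge_index[0].unique()                    # kept source nodes: [5, 7]
dst_idx = edge_index[1].unique()                    # kept destination nodes: [2, 9]
imp = torch.empty(10, dtype=torch.long)             # lookup table over the global id space (toy size)
imp[src_idx] = torch.arange(src_idx.numel())
src = imp[edge_index[0]]                            # -> [0, 1, 0]
imp[dst_idx] = torch.arange(dst_idx.numel())
dst = imp[edge_index[1]]                            # -> [0, 0, 1]
local_edges = torch.vstack([src, dst])              # compact edge_index over the shrunk subgraph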
...@@ -4,7 +4,7 @@ import torch.distributed as dist ...@@ -4,7 +4,7 @@ import torch.distributed as dist
from starrygl.distributed.utils import DistributedTensor from starrygl.distributed.utils import DistributedTensor
from starrygl.module.memorys import MailBox from starrygl.module.memorys import MailBox
from starrygl.sample.graph_core import DataSet from starrygl.sample.graph_core import DataSet
from starrygl.sample.graph_core import GraphData from starrygl.sample.graph_core import DistributedGraphStore
from starrygl.sample.sample_core.base import BaseSampler, NegativeSampling from starrygl.sample.sample_core.base import BaseSampler, NegativeSampling
import dgl import dgl
...@@ -44,7 +44,7 @@ def prepare_input(node_feat, edge_feat, mem_embedding,mfgs,dist_nid,dist_eid): ...@@ -44,7 +44,7 @@ def prepare_input(node_feat, edge_feat, mem_embedding,mfgs,dist_nid,dist_eid):
#print(idx.shape[0],b.srcdata['mem_ts'].shape) #print(idx.shape[0],b.srcdata['mem_ts'].shape)
return mfgs return mfgs
def to_block(graph: GraphData, data, sample_out, mailbox:MailBox = None,device = torch.device('cuda'),group = None): def to_block(graph: DistributedGraphStore, data, sample_out, mailbox:MailBox = None,device = torch.device('cuda'),group = None):
if len(sample_out) > 1: if len(sample_out) > 1:
sample_out,metadata = sample_out sample_out,metadata = sample_out
......
...@@ -66,9 +66,12 @@ class DistributedDataLoader: ...@@ -66,9 +66,12 @@ class DistributedDataLoader:
if train is True: if train is True:
self._get_expected_idx(self.dataset.len) self._get_expected_idx(self.dataset.len)
else: else:
self._get_expected_idx(self.dataset.len,op = dist.ReduceOp.MAX) if torch.distributed.get_rank() == 0:
#self.expected_idx = int(math.ceil(self.dataset.len/self.batch_size)) #self._get_expected_idx(self.dataset.len,op = dist.ReduceOp.MAX)
self.expected_idx = int(math.ceil(self.dataset.len/self.batch_size))
else:
self.expected_idx = 0
def __iter__(self): def __iter__(self):
if self.chunk_size is None: if self.chunk_size is None:
if self.shuffle: if self.shuffle:
...@@ -134,6 +137,9 @@ class DistributedDataLoader: ...@@ -134,6 +137,9 @@ class DistributedDataLoader:
return next_data return next_data
def __next__(self): def __next__(self):
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
start_event.record()
if(dist.get_world_size() > 0): if(dist.get_world_size() > 0):
if self.recv_idxs < self.expected_idx: if self.recv_idxs < self.expected_idx:
data = self._next_data() data = self._next_data()
...@@ -145,7 +151,9 @@ class DistributedDataLoader: ...@@ -145,7 +151,9 @@ class DistributedDataLoader:
self.device) self.device)
self.recv_idxs += 1 self.recv_idxs += 1
assert batch_data is not None assert batch_data is not None
return batch_data end_event.record()
sample_time = start_event.elapsed_time(end_event)
return batch_data,sample_time
else : else :
raise StopIteration raise StopIteration
if self.queue_size > 0 : if self.queue_size > 0 :
......
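# --- Hedged sketch (illustration only) of the torch.cuda.Event timing pattern that the hunk
# above adds to DistributedDataLoader.__next__. elapsed_time() can only be read once the end
# event has completed, so a synchronize is needed before querying it; a CUDA device is assumed.
import torch

start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
start_event.record()
# ... work to be timed, e.g. sampling and building the next batch ...
end_event.record()
torch.cuda.synchronize()                               # ensure both events have completed
sample_time_ms = start_event.elapsed_time(end_event)   # elapsed GPU time in milliseconds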
...@@ -4,7 +4,7 @@ import os.path as osp ...@@ -4,7 +4,7 @@ import os.path as osp
import torch import torch
import torch.distributed as dist import torch.distributed as dist
from torch_geometric.data import Data from torch_geometric.data import Data
class GraphStore(): class DistributedGraphStore:
def __init__(self, pdata, device = torch.device('cuda'), all_on_gpu = False): def __init__(self, pdata, device = torch.device('cuda'), all_on_gpu = False):
self.device = device self.device = device
self.ids = pdata.ids.to(device) self.ids = pdata.ids.to(device)
...@@ -67,6 +67,7 @@ class DataSet: ...@@ -67,6 +67,7 @@ class DataSet:
if labels is not None: if labels is not None:
self.labels = labels self.labels = labels
self.len = self.nodes.shape[0] if nodes is not None else self.edges.shape[1] self.len = self.nodes.shape[0] if nodes is not None else self.edges.shape[1]
print(self.edges.shape,self.len)
for k, v in kwargs.items(): for k, v in kwargs.items():
assert isinstance(v,torch.Tensor) and v.shape[0]==self.len assert isinstance(v,torch.Tensor) and v.shape[0]==self.len
setattr(self, k, v.to(device)) setattr(self, k, v.to(device))
...@@ -106,9 +107,9 @@ class DataSet: ...@@ -106,9 +107,9 @@ class DataSet:
setattr(d,k,v[indx]) setattr(d,k,v[indx])
return d return d
class TemporalGraphData(GraphStore): class DistributedTemporalGraphData(DistributedGraphStore):
def __init__(self,pdata,device): def __init__(self,pdata,device):
super(TemporalGraphData,self).__init__(pdata,device) super(DistributedTemporalGraphData,self).__init__(pdata,device)
def _set_temporal_batch_cache(self,size,pin_size): def _set_temporal_batch_cache(self,size,pin_size):
pass pass
def _load_feature_to_cuda(self,ids): def _load_feature_to_cuda(self,ids):
...@@ -117,7 +118,7 @@ class TemporalGraphData(GraphStore): ...@@ -117,7 +118,7 @@ class TemporalGraphData(GraphStore):
class TemporalNeighborSampleGraph(GraphStore): class TemporalNeighborSampleGraph(DistributedGraphStore):
def __init__(self, sample_graph=None, mode='full', eids_mapper=None): def __init__(self, sample_graph=None, mode='full', eids_mapper=None):
self.edge_index = sample_graph['edge_index'] self.edge_index = sample_graph['edge_index']
self.num_edges = self.edge_index.shape[1] self.num_edges = self.edge_index.shape[1]
......
...@@ -42,9 +42,10 @@ class SharedMailBox(): ...@@ -42,9 +42,10 @@ class SharedMailBox():
dtype=torch.long, dtype=torch.long,
device = self.device) device = self.device)
self._ctx = DistributedContext.get_default_context() self._ctx = DistributedContext.get_default_context()
self.rref = rpc.RRef(self) if self._ctx._use_rpc is True:
self.rrefs = self._ctx.all_gather_remote_objects(self.rref) self.rref = rpc.RRef(self)
self.partptr = torch.tensor([ ((i & 0xFFFF)<<48) for i in range(self.num_parts+1) ],device = device) self.rrefs = self._ctx.all_gather_remote_objects(self.rref)
self.partptr = torch.tensor([ ((i & 0xFFFF)<<48) for i in range(self.num_parts+1) ],device = device)
def reset(self): def reset(self):
......
...@@ -39,7 +39,7 @@ def prepare_data(root: str, num_parts, part_algo: str = "metis"): ...@@ -39,7 +39,7 @@ def prepare_data(root: str, num_parts, part_algo: str = "metis"):
logging.info(f"GraphData.node().keys(): {g.node().keys()}") logging.info(f"GraphData.node().keys(): {g.node().keys()}")
logging.info(f"GraphData.edge().keys(): {g.edge().keys()}") logging.info(f"GraphData.edge().keys(): {g.edge().keys()}")
g.save_partition(root, num_parts, part_algo) g.save_partition(root, num_parts, algorithm=part_algo)
return g return g
class SimpleConv(pyg_nn.MessagePassing): class SimpleConv(pyg_nn.MessagePassing):
......
...@@ -7,7 +7,7 @@ from starrygl.distributed.utils import DistIndex ...@@ -7,7 +7,7 @@ from starrygl.distributed.utils import DistIndex
from starrygl.module.modules import GeneralModel from starrygl.module.modules import GeneralModel
from starrygl.module.utils import parse_config from starrygl.module.utils import parse_config
from starrygl.sample.graph_core import DataSet, GraphData, TemporalNeighborSampleGraph from starrygl.sample.graph_core import DataSet, DistributedGraphStore, TemporalNeighborSampleGraph
from starrygl.sample.memory.shared_mailbox import SharedMailBox from starrygl.sample.memory.shared_mailbox import SharedMailBox
from starrygl.sample.sample_core.base import NegativeSampling from starrygl.sample.sample_core.base import NegativeSampling
from starrygl.sample.sample_core.neighbor_sampler import NeighborSampler from starrygl.sample.sample_core.neighbor_sampler import NeighborSampler
...@@ -67,8 +67,8 @@ def main(): ...@@ -67,8 +67,8 @@ def main():
ctx = DistributedContext.init(backend="nccl", use_gpu=True) ctx = DistributedContext.init(backend="nccl", use_gpu=True)
device_id = torch.cuda.current_device() device_id = torch.cuda.current_device()
print('use cuda on',device_id) print('use cuda on',device_id)
pdata = partition_load("./dataset/here/{}".format(args.dataname), algo="metis_for_tgnn") pdata = partition_load("/mnt/data/part_data/dataset/here/{}".format(args.dataname), algo="metis_for_tgnn")
graph = GraphData(pdata = pdata) graph = DistributedGraphStore(pdata = pdata)
sample_graph = TemporalNeighborSampleGraph(sample_graph = pdata.sample_graph,mode = 'full') sample_graph = TemporalNeighborSampleGraph(sample_graph = pdata.sample_graph,mode = 'full')
mailbox = SharedMailBox(pdata.ids.shape[0], memory_param, dim_edge_feat = pdata.edge_attr.shape[1] if pdata.edge_attr is not None else 0) mailbox = SharedMailBox(pdata.ids.shape[0], memory_param, dim_edge_feat = pdata.edge_attr.shape[1] if pdata.edge_attr is not None else 0)
...@@ -145,7 +145,7 @@ def main(): ...@@ -145,7 +145,7 @@ def main():
total_loss = 0 total_loss = 0
signal = torch.tensor([0],dtype = int,device = device) signal = torch.tensor([0],dtype = int,device = device)
for roots,mfgs,metadata in loader: for roots,mfgs,metadata,sample_time in loader:
pred_pos, pred_neg = model(mfgs,metadata) pred_pos, pred_neg = model(mfgs,metadata)
total_loss += creterion(pred_pos, torch.ones_like(pred_pos)) total_loss += creterion(pred_pos, torch.ones_like(pred_pos))
...@@ -196,10 +196,11 @@ def main(): ...@@ -196,10 +196,11 @@ def main():
return ap, auc_mrr return ap, auc_mrr
creterion = torch.nn.BCEWithLogitsLoss() creterion = torch.nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=train_param['lr']) optimizer = torch.optim.Adam(model.parameters(), lr=train_param['lr'])
for e in range(train_param['epoch']): for e in range(train_param['epoch']):
torch.cuda.synchronize() torch.cuda.synchronize()
write_back_time = 0
fetch_time = 0
epoch_start_time = time.time() epoch_start_time = time.time()
train_aps = list() train_aps = list()
print('Epoch {:d}:'.format(e)) print('Epoch {:d}:'.format(e))
...@@ -211,7 +212,8 @@ def main(): ...@@ -211,7 +212,8 @@ def main():
model.module.memory_updater.last_updated_nid = None model.module.memory_updater.last_updated_nid = None
model.module.memory_updater.last_updated_memory = None model.module.memory_updater.last_updated_memory = None
model.module.memory_updater.last_updated_ts = None model.module.memory_updater.last_updated_ts = None
for roots,mfgs,metadata in trainloader: for roots,mfgs,metadata,sample_time in trainloader:
fetch_time +=sample_time/1000
t_prep_s = time.time() t_prep_s = time.time()
with torch.cuda.stream(train_stream): with torch.cuda.stream(train_stream):
...@@ -252,8 +254,12 @@ def main(): ...@@ -252,8 +254,12 @@ def main():
src,dst,ts,edge_feats, src,dst,ts,edge_feats,
model.module.memory_updater.last_updated_memory, model.module.memory_updater.last_updated_memory,
) )
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
start_event.record()
mailbox.set_mailbox_all_to_all(index,memory,memory_ts,mail,mail_ts,reduce_Op = 'max') mailbox.set_mailbox_all_to_all(index,memory,memory_ts,mail,mail_ts,reduce_Op = 'max')
end_event.record()
write_back_time += start_event.elapsed_time(end_event)/1000
torch.cuda.synchronize() torch.cuda.synchronize()
time_prep = time.time() - epoch_start_time time_prep = time.time() - epoch_start_time
...@@ -264,6 +270,7 @@ def main(): ...@@ -264,6 +270,7 @@ def main():
ap, auc = eval('val') ap, auc = eval('val')
print('\ttrain loss:{:.4f} train ap:{:4f} val ap:{:4f} val auc:{:4f}'.format(total_loss,train_ap, ap, auc)) print('\ttrain loss:{:.4f} train ap:{:4f} val ap:{:4f} val auc:{:4f}'.format(total_loss,train_ap, ap, auc))
print('\ttotal time:{:.2f}s prep time:{:.2f}s'.format(time.time()-epoch_start_time, time_prep)) print('\ttotal time:{:.2f}s prep time:{:.2f}s'.format(time.time()-epoch_start_time, time_prep))
print('\t fetch time:{:.2f}s write back time:{:.2f}s'.format(fetch_time,write_back_time))
model.eval() model.eval()
if mailbox is not None: if mailbox is not None:
......