Commit 65ba618b by xxx

Merge branch 'master' of http://192.168.1.53:8082/wjie98/starrygl into hzq

parents e0b25dd1 1968a4f2
......@@ -28,6 +28,10 @@ share/python-wheels/
*.egg
MANIFEST
# IDE temporary files (generated by IDEs like CLion, etc.)
.idea/
cmake-build-*/
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
......@@ -164,8 +168,8 @@ cython_debug/
*.pt
/nohup.out
/cora
/dataset
/test_*
/*.ipynb
/third_party
!/third_party/ldg_partition
[submodule "third_party/ldg_partition"]
path = third_party/ldg_partition
url = https://gitee.com/onlynagesha/graph-partition-v4
cmake_minimum_required(VERSION 3.15)
project(starrygl_ops VERSION 0.1)
project(starrygl VERSION 0.1)
option(WITH_PYTHON "Link to Python when building" ON)
option(WITH_CUDA "Link to CUDA when building" ON)
option(WITH_METIS "Link to METIS when building" ON)
option(WITH_MTMETIS "Link to multi-threaded METIS when building" ON)
option(WITH_LDG "Link to (multi-threaded optionally) LDG when building" ON)
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
......@@ -37,8 +38,8 @@ if(WITH_CUDA)
file(GLOB_RECURSE UVM_SRCS "csrc/uvm/*.cpp")
add_library(uvm SHARED ${UVM_SRCS})
target_link_libraries(uvm PRIVATE ${TORCH_LIBRARIES})
add_library(uvm_ops SHARED ${UVM_SRCS})
target_link_libraries(uvm_ops PRIVATE ${TORCH_LIBRARIES})
endif()
if(WITH_METIS)
......@@ -54,10 +55,10 @@ if(WITH_METIS)
include_directories(${METIS_INCLUDE_DIRS})
add_library(metis SHARED "csrc/partition/metis.cpp")
target_link_libraries(metis PRIVATE ${TORCH_LIBRARIES})
target_link_libraries(metis PRIVATE ${GKLIB_LIBRARIES})
target_link_libraries(metis PRIVATE ${METIS_LIBRARIES})
add_library(metis_partition SHARED "csrc/partition/metis.cpp")
target_link_libraries(metis_partition PRIVATE ${TORCH_LIBRARIES})
target_link_libraries(metis_partition PRIVATE ${GKLIB_LIBRARIES})
target_link_libraries(metis_partition PRIVATE ${METIS_LIBRARIES})
endif()
if(WITH_MTMETIS)
......@@ -68,16 +69,29 @@ if(WITH_MTMETIS)
file(GLOB_RECURSE MTMETIS_LIBRARIES "${MTMETIS_DIR}/lib/lib*.a")
include_directories(${MTMETIS_INCLUDE_DIRS})
add_library(mtmetis SHARED "csrc/partition/mtmetis.cpp")
target_link_libraries(mtmetis PRIVATE ${TORCH_LIBRARIES})
target_link_libraries(mtmetis PRIVATE ${MTMETIS_LIBRARIES})
target_compile_definitions(mtmetis PRIVATE -DMTMETIS_64BIT_VERTICES)
target_compile_definitions(mtmetis PRIVATE -DMTMETIS_64BIT_EDGES)
target_compile_definitions(mtmetis PRIVATE -DMTMETIS_64BIT_WEIGHTS)
target_compile_definitions(mtmetis PRIVATE -DMTMETIS_64BIT_PARTITIONS)
add_library(mtmetis_partition SHARED "csrc/partition/mtmetis.cpp")
target_link_libraries(mtmetis_partition PRIVATE ${TORCH_LIBRARIES})
target_link_libraries(mtmetis_partition PRIVATE ${MTMETIS_LIBRARIES})
target_compile_definitions(mtmetis_partition PRIVATE -DMTMETIS_64BIT_VERTICES)
target_compile_definitions(mtmetis_partition PRIVATE -DMTMETIS_64BIT_EDGES)
target_compile_definitions(mtmetis_partition PRIVATE -DMTMETIS_64BIT_WEIGHTS)
target_compile_definitions(mtmetis_partition PRIVATE -DMTMETIS_64BIT_PARTITIONS)
endif()
if (WITH_LDG)
# Import the neighbor-clustering-based graph partitioning implementation (e.g. the LDG algorithm)
add_definitions(-DWITH_LDG)
set(LDG_DIR "third_party/ldg_partition")
add_library(ldg_partition SHARED "csrc/partition/ldg.cpp")
target_link_libraries(ldg_partition PRIVATE ${TORCH_LIBRARIES})
add_subdirectory(${LDG_DIR})
target_include_directories(ldg_partition PRIVATE ${LDG_DIR})
target_link_libraries(ldg_partition PRIVATE ldg-vertex-partition)
endif ()
include_directories("csrc/include")
add_library(${PROJECT_NAME} SHARED csrc/export.cpp)
......@@ -90,17 +104,23 @@ if(WITH_PYTHON)
endif()
if (WITH_CUDA)
target_link_libraries(${PROJECT_NAME} PRIVATE uvm)
target_link_libraries(${PROJECT_NAME} PRIVATE uvm_ops)
endif()
if (WITH_METIS)
target_link_libraries(${PROJECT_NAME} PRIVATE metis)
message(STATUS "Current project '${PROJECT_NAME}' uses METIS graph partitioning algorithm.")
target_link_libraries(${PROJECT_NAME} PRIVATE metis_partition)
endif()
if (WITH_MTMETIS)
target_link_libraries(${PROJECT_NAME} PRIVATE mtmetis)
message(STATUS "Current project '${PROJECT_NAME}' uses multi-threaded METIS graph partitioning algorithm.")
target_link_libraries(${PROJECT_NAME} PRIVATE mtmetis_partition)
endif()
if (WITH_LDG)
message(STATUS "Current project '${PROJECT_NAME}' uses LDG graph partitioning algorithm.")
target_link_libraries(${PROJECT_NAME} PRIVATE ldg_partition)
endif()
# add libsampler.so
set(SAMLPER_NAME "${PROJECT_NAME}_sampler")
......
sampling:
- layer: 1
neighbor:
- 10
strategy: 'recent'
prop_time: False
history: 1
duration: 0
num_thread: 32
no_neg: True
memory:
- type: 'node'
dim_time: 100
deliver_to: 'neighbors'
mail_combine: 'last'
memory_update: 'transformer'
attention_head: 2
mailbox_size: 10
combine_node_feature: False
dim_out: 100
gnn:
- arch: 'identity'
train:
- epoch: 100
batch_size: 600
lr: 0.0001
dropout: 0.1
att_dropout: 0.1
# all_on_gpu: True
\ No newline at end of file
sampling:
- layer: 2
neighbor:
- 10
- 10
strategy: 'uniform'
prop_time: True
history: 3
duration: 10000
num_thread: 32
memory:
- type: 'none'
dim_out: 0
gnn:
- arch: 'transformer_attention'
layer: 2
att_head: 2
dim_time: 0
dim_out: 100
combine: 'rnn'
train:
- epoch: 50
batch_size: 600
lr: 0.0001
dropout: 0.1
att_dropout: 0.1
all_on_gpu: True
\ No newline at end of file
sampling:
- no_sample: True
history: 1
memory:
- type: 'node'
dim_time: 100
deliver_to: 'self'
mail_combine: 'last'
memory_update: 'rnn'
mailbox_size: 1
combine_node_feature: True
dim_out: 100
gnn:
- arch: 'identity'
time_transform: 'JODIE'
train:
- epoch: 100
batch_size: 600
lr: 0.0001
dropout: 0.1
all_on_gpu: True
\ No newline at end of file
sampling:
- layer: 2
neighbor:
- 10
- 10
strategy: 'uniform'
prop_time: False
history: 1
duration: 0
num_thread: 32
memory:
- type: 'none'
dim_out: 0
gnn:
- arch: 'transformer_attention'
layer: 2
att_head: 2
dim_time: 100
dim_out: 100
train:
- epoch: 10
batch_size: 600
lr: 0.0001
dropout: 0.1
att_dropout: 0.1
all_on_gpu: True
\ No newline at end of file
sampling:
- layer: 1
neighbor:
- 10
strategy: 'recent'
prop_time: False
history: 1
duration: 0
num_thread: 32
memory:
- type: 'node'
dim_time: 100
deliver_to: 'self'
mail_combine: 'last'
memory_update: 'gru'
mailbox_size: 1
combine_node_feature: True
dim_out: 100
gnn:
- arch: 'transformer_attention'
layer: 1
att_head: 2
dim_time: 100
dim_out: 100
train:
- epoch: 10
#batch_size: 100
# reorder: 16
lr: 0.0001
dropout: 0.2
att_dropout: 0.2
all_on_gpu: True
\ No newline at end of file
sampling:
- layer: <number of layers to sample>
neighbor: <a list of integers indicating how many neighbors are sampled in each layer>
strategy: <'recent' that samples the most recent neighbors, or 'uniform' that uniformly samples neighbors from the past>
prop_time: <False or True that specifies whether to use the timestamps of the root nodes when sampling their multi-hop neighbors>
history: <number of snapshots to sample on>
duration: <length in time of each snapshot; 0 for infinite length (used in non-snapshot-based methods)>
num_thread: <number of threads of the sampler>
memory:
- type: <'node', we only support node memory now>
dim_time: <an integer, the dimension of the time embedding>
deliver_to: <'self' that delivers the mails only to the involved nodes, or 'neighbors' that delivers the mails to their neighbors>
mail_combine: <'last' that uses the latest mail as the input to the memory updater>
memory_update: <'gru' or 'rnn'>
mailbox_size: <an integer, the size of the mailbox for each node>
combine_node_feature: <False or True that specifies whether to combine node features (with the updated memory) as the input to the GNN>
dim_out: <an integer, the dimension of the output node memory>
gnn:
- arch: <'transformer_attention' or 'identity' (no GNN)>
layer: <an integer, number of layers>
att_head: <an integer, number of attention heads>
dim_time: <an integer, the dimension of the time embedding>
dim_out: <an integer, the dimension of the output dynamic node embedding>
train:
- epoch: <an integer, number of epochs to train>
batch_size: <an integer, the batch size (of edges); for multi-GPU training, this is the local batch size>
reorder: <(optional) an integer, divisible by the batch size, that specifies how many chunks per batch are used in the random chunk scheduling>
lr: <floating point, learning rate>
dropout: <floating point, dropout>
att_dropout: <floating point, dropout for attention>
all_on_gpu: <False or True that decides if the node/edge features and node memory are completely stored on GPU>
\ No newline at end of file
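A minimal sketch (not part of this commit) of loading one of these YAML files with PyYAML and reading a few of the fields described above; the file path is hypothetical:

import yaml  # PyYAML

with open("config/example.yml", "r") as f:  # hypothetical path to one of the configs above
    cfg = yaml.safe_load(f)

# each top-level section is a one-element list, matching the templates above
sample_cfg = cfg["sampling"][0]
train_cfg = cfg["train"][0]

print("layers sampled:", sample_cfg.get("layer"))
print("neighbors per layer:", sample_cfg.get("neighbor"))
print("epochs:", train_cfg["epoch"], "lr:", train_cfg["lr"])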
......@@ -4,15 +4,13 @@
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
#ifdef WITH_CUDA
m.def("uvm_storage_new", &uvm_storage_new, "return storage of unified virtual memory");
m.def("uvm_storage_to_cuda", &uvm_storage_to_cuda, "share uvm storage with another cuda device");
m.def("uvm_storage_to_cpu", &uvm_storage_to_cpu, "share uvm storage with cpu");
m.def("uvm_storage_advise", &uvm_storage_advise, "apply cudaMemAdvise() to uvm storage");
m.def("uvm_storage_prefetch", &uvm_storage_prefetch, "apply cudaMemPrefetchAsync() to uvm storage");
m.def("metis_partition", &metis_partition, "metis graph partition");
m.def("mt_metis_partition", &mt_metis_partition, "multi-threaded metis graph partition");
py::enum_<cudaMemoryAdvise>(m, "cudaMemoryAdvise")
.value("cudaMemAdviseSetAccessedBy", cudaMemoryAdvise::cudaMemAdviseSetAccessedBy)
.value("cudaMemAdviseUnsetAccessedBy", cudaMemoryAdvise::cudaMemAdviseUnsetAccessedBy)
......@@ -20,4 +18,19 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
.value("cudaMemAdviseUnsetPreferredLocation", cudaMemoryAdvise::cudaMemAdviseUnsetPreferredLocation)
.value("cudaMemAdviseSetReadMostly", cudaMemoryAdvise::cudaMemAdviseSetReadMostly)
.value("cudaMemAdviseUnsetReadMostly", cudaMemoryAdvise::cudaMemAdviseUnsetReadMostly);
#endif
#ifdef WITH_METIS
m.def("metis_partition", &metis_partition, "metis graph partition");
#endif
#ifdef WITH_MTMETIS
m.def("mt_metis_partition", &mt_metis_partition, "multi-threaded metis graph partition");
#endif
#ifdef WITH_LDG
// Note: WITH_MULTITHREADING=ON must be set at compile time
// to enable the multi-threaded functionality.
m.def("ldg_partition", &ldg_partition, "(multi-threaded optionally) LDG graph partition");
#endif
}
......@@ -22,3 +22,11 @@ at::Tensor mt_metis_partition(
int64_t num_workers,
bool recursive
);
at::Tensor ldg_partition(
at::Tensor edges,
at::optional<at::Tensor> vertex_weights,
at::optional<at::Tensor> initial_partition,
int64_t num_parts,
int64_t num_workers
);
#include <torch/all.h>
#include "vertex_partition/vertex_partition.h"
#include "vertex_partition/params.h"
at::Tensor ldg_partition(at::Tensor edges,
at::optional<at::Tensor> vertex_weights,
at::optional<at::Tensor> initial_partition,
int64_t num_parts,
int64_t num_workers) {
AT_ASSERT(edges.dim() == 2);
auto edges_n_cols = edges.size(1);
AT_ASSERT(edges_n_cols >= 2 && edges_n_cols <= 3);
// Note: further argument checks are performed inside the vp::ldg_partition family of functions.
auto n = edges.slice(1, 0, 2).max().item<int64_t>() + 1;
auto params = vp::LDGParams{.N = n, .K = num_parts, .openmp_n_threads = static_cast<int>(num_workers)};
auto edges_clone = edges.clone();
if (vertex_weights.has_value()) {
auto vertex_weights_clone = vertex_weights->clone();
if (initial_partition.has_value()) {
auto initial_partition_clone = initial_partition->clone();
vp::ldg_partition_v_init(edges_clone, vertex_weights_clone, initial_partition_clone, params);
return initial_partition_clone;
} else {
return vp::ldg_partition_v(edges_clone, vertex_weights_clone, params);
}
} else {
if (initial_partition.has_value()) {
auto initial_partition_clone = initial_partition->clone();
vp::ldg_partition_init(edges_clone, initial_partition_clone, params);
return initial_partition_clone;
} else {
return vp::ldg_partition(edges_clone, params);
}
}
}
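A hedged usage sketch (not part of this commit) for the ldg_partition binding implemented above; the Python module name is an assumption and depends on how the compiled extension is actually imported:

import torch
# hypothetical import; substitute the real extension module exposed by the build
# import starrygl_ops as ops

def ldg_partition_demo(ops):
    # edge list with two columns (src, dst); a third column is also accepted per the checks above
    edges = torch.tensor([[0, 1], [1, 2], [2, 3], [3, 0]], dtype=torch.long)
    # no vertex weights, no initial partition; 2 parts, 4 worker threads
    parts = ops.ldg_partition(edges, None, None, 2, 4)
    # expected (per the C++ code above): one partition id per vertex
    print(parts)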
# Minimal makefile for Sphinx documentation
#
# You can set these variables from the command line, and also
# from the environment for the first two.
SPHINXOPTS ?=
SPHINXBUILD ?= sphinx-build
SOURCEDIR = source
BUILDDIR = build
# Put it first so that "make" without argument is like "make help".
help:
@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
.PHONY: help Makefile
# Catch-all target: route all unknown targets to Sphinx using the new
# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
%: Makefile
@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
@ECHO OFF
pushd %~dp0
REM Command file for Sphinx documentation
if "%SPHINXBUILD%" == "" (
set SPHINXBUILD=sphinx-build
)
set SOURCEDIR=source
set BUILDDIR=build
%SPHINXBUILD% >NUL 2>NUL
if errorlevel 9009 (
echo.
echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
echo.installed, then set the SPHINXBUILD environment variable to point
echo.to the full path of the 'sphinx-build' executable. Alternatively you
echo.may add the Sphinx directory to PATH.
echo.
echo.If you don't have Sphinx installed, grab it from
echo.https://www.sphinx-doc.org/
exit /b 1
)
if "%1" == "" goto help
%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
goto end
:help
%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
:end
popd
Advanced Data Preprocessing
===========================
.. note::
Describe in detail how to use StarryGL's data management classes, such as GraphData: usage details, the design of their internal index structures, and the underlying operations.
\ No newline at end of file
Advanced Concepts
=================
.. toctree::
sampling_parallel/index
partition_parallel/index
timeline_parallel/index
\ No newline at end of file
Distributed Partition Parallel
==============================
.. note::
This section will cover distributed partition-parallel training.
\ No newline at end of file
Distributed Feature Fetching
============================
\ No newline at end of file
Distributed Sampling Parallel
=============================
.. note::
Training modes based on distributed temporal graph sampling.
.. toctree::
sampler
features
memory
Distributed Memory Updater
==========================
\ No newline at end of file
Distributed Temporal Sampling
=============================
\ No newline at end of file
Distributed Timeline Parallel
=============================
.. note::
Distributed timeline parallelism.
\ No newline at end of file
starrygl.distributed
====================
.. note::
Automatically generated API documentation; in-source docstrings still need to be added under the starrygl source directory.
.. currentmodule:: starrygl.distributed
.. autosummary::
DistributedContext
\ No newline at end of file
Package References
==================
.. toctree::
distributed
\ No newline at end of file
Cheatsheets
===========
.. note::
This page can list the framework's built-in GNN models, the available GNN datasets, and the default interface conventions; for reference, see
`pyg cheatsheets <https://pytorch-geometric.readthedocs.io/en/latest/cheatsheet/data_cheatsheet.html>`_
\ No newline at end of file
import os
import sys
sys.path.insert(0, os.path.abspath("../.."))
import starrygl
project = 'StarryGL'
copyright = '2023, StarryGL Team'
author = 'StarryGL Team'
version = starrygl.__version__
release = starrygl.__version__
extensions = [
"sphinx.ext.autodoc",
"sphinx.ext.autosummary",
"sphinx.ext.duration",
]
templates_path = ['_templates']
exclude_patterns = []
# -- Options for HTML output -------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output
html_theme = 'sphinx_rtd_theme'
html_static_path = ['_static']
External Resources
==================
.. note::
Collect StarryGL-related papers, GNN-related projects, and other resources that help with using StarryGL; for reference, see:
`PyG Resources <https://pytorch-geometric.readthedocs.io/en/latest/external/resources.html>`_
\ No newline at end of file
Get Started
===========
.. toctree::
install_guide
intro_example
\ No newline at end of file
Installation
============
.. note::
Installation instructions.
\ No newline at end of file
Introduction by Example
=======================
.. note::
After installation, quickly get started with training a temporal graph model. Show basic usage first, possibly on a single machine.
\ No newline at end of file
StarryGL Documentation
======================
.. toctree::
guide/index
tutorial/index
advanced/index
api/python/index
cheatsheets/index
external/index
.. Indices and tables
.. ==================
.. * :ref:`genindex`
.. * :ref:`modindex`
.. * :ref:`search`
User Cases and Applications
===========================
.. note::
Collect use cases and applications of temporal GNNs.
\ No newline at end of file
Preparing the Temporal Graph Dataset
====================================
.. note::
Covers data cleaning and preprocessing starting from the raw data, ending with data files that StarryGL can consume.
\ No newline at end of file
Distributed Training
====================
.. note::
Introduces the methods related to distributed training and walks through a minimal distributed DDP training workflow.
\ No newline at end of file
Tutorials
===============
.. toctree::
intro
module
dataset
application
distributed
\ No newline at end of file
Introduction to Temporal GNN
==============================================
.. note::
A brief introduction to temporal GNNs: their application scenarios, the problems they need to solve, and so on; essentially an overall overview.
Creating Temporal GNN Models
============================
.. note::
Explains how to create GNN models, using the two most classic and concise examples: building a **discrete-time dynamic graph model** and a **continuous-time dynamic graph model**.
\ No newline at end of file
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/hwj/.miniconda3/envs/pyg/lib/python3.8/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n"
]
}
],
"source": [
"import torch\n",
"\n",
"from torch import Tensor\n",
"from typing import *\n",
"\n",
"from torch_scatter import scatter"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"def act_blks(act_mask: Tensor, blk_size: int) -> Tensor:\n",
" assert act_mask.dtype == torch.bool\n",
" assert act_mask.dim() == 1\n",
" n = act_mask.size(0)\n",
" m = (n + blk_size - 1) // blk_size * blk_size\n",
"\n",
" blk_mask = act_mask.detach()\n",
" if n != m:\n",
" blk_mask = torch.zeros(m, dtype=act_mask.dtype, device=act_mask.device)\n",
" blk_mask[:n] = act_mask.detach()\n",
" return blk_mask.view(-1, blk_size).max(dim=-1).values"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"act_mask = torch.rand(1000) < 0.01\n",
"blk_mask = act_blks(act_mask, 64)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"cached_blocks = torch.randn(blk_mask.size(0), 64, 8)\n",
"blks = cached_blocks[blk_mask]"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"def gather_block(blks: Tensor, blk_mask: Tensor, index: Tensor) -> Tensor:\n",
" blk_size = blks.size(1)\n",
" blk_mask = blk_mask.type(torch.long).cumsum_(dim=0) * blk_mask - 1\n",
" \n",
" idx0 = index.div(blk_size, rounding_mode=\"floor\")\n",
" ind_mask = blk_mask[idx0] >= 0\n",
" \n",
" idx0 = blk_mask[idx0][ind_mask]\n",
" idx1 = index[ind_mask] % blk_size\n",
"\n",
" s = index.shape[:1] + blks.shape[2:]\n",
" data = torch.zeros(s, dtype=blks.dtype, device=blks.device)\n",
" data[ind_mask] = blks[idx0, idx1, ...]\n",
" return data"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"def scatter_block(x: Tensor, index: Tensor, blk_mask: Tensor, blk_size: int, reduce: str = \"sum\") -> Tensor:\n",
" num_blks = blk_mask.count_nonzero()\n",
" blk_mask = blk_mask.type(torch.long).cumsum_(dim=0) * blk_mask - 1\n",
"\n",
" idx0 = index.div(blk_size, rounding_mode=\"floor\")\n",
" ind_mask = blk_mask[idx0] >= 0\n",
"\n",
" idx0 = blk_mask[idx0][ind_mask]\n",
" idx1 = index[ind_mask] % blk_size\n",
"\n",
" data = scatter(\n",
" src=x[ind_mask],\n",
" index=idx0*blk_size+idx1,\n",
" dim=0,\n",
" dim_size=num_blks*blk_size,\n",
" reduce=reduce,\n",
" )\n",
" return data.view(num_blks, blk_size, *x.shape[1:])"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"# def index_block(index: Tensor, blk_mask: Tensor, blk_size: int) -> Tensor:\n",
" "
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([False, True, True, False, True, False, False, True, False, True,\n",
" False, True, True, False, False, False, False, False, False, False,\n",
" False, True, True, False, False, False, True, False, False, True,\n",
" True, False, False, True, True, False, True, True, False, False,\n",
" True, False, True, False, False, False, False, False, False, False,\n",
" True, True, False, False, False, True, True, True, False, False,\n",
" True, False, True, False, True, False, True, False, False, True,\n",
" False, True, True, True, False, False, False, False, False, True,\n",
" False, True, False, True, True, True, False, True, False, True,\n",
" False, True, False, True, True, False, False, True, False, True])"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"index = torch.randint(1000, size=(100,), dtype=torch.long)\n",
"(gather_block(blks, blk_mask, index)**2).sum(dim=-1) != 0"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([8, 64, 8])"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"x = torch.randn(100, 8)\n",
"scatter_block(x, index, blk_mask, 64).size()"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([True, True, True, True, True, True, True, True])"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"t = scatter_block(x, index, blk_mask, 64)\n",
"t = t.view(-1, *t.shape[2:])\n",
"act_blks((t**2).sum(dim=-1) != 0, 64)"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.return_types.sort(\n",
"values=tensor([ 5, 15, 18, 19, 25, 48, 62, 64, 68, 76, 88, 92, 94, 106,\n",
" 114, 115, 120, 122, 126, 131, 135, 165, 169, 169, 177, 192, 214, 215,\n",
" 245, 249, 255, 269, 282, 286, 294, 299, 334, 358, 359, 372, 401, 403,\n",
" 430, 437, 455, 459, 459, 459, 471, 484, 490, 504, 505, 516, 528, 529,\n",
" 559, 584, 613, 614, 614, 624, 629, 637, 638, 674, 679, 701, 706, 735,\n",
" 742, 746, 748, 750, 759, 774, 774, 776, 778, 787, 803, 825, 831, 834,\n",
" 861, 862, 868, 877, 882, 889, 899, 909, 921, 933, 938, 954, 960, 969,\n",
" 981, 986]),\n",
"indices=tensor([ 6, 77, 14, 41, 43, 0, 13, 53, 38, 54, 15, 76, 98, 47, 75, 80, 8, 95,\n",
" 17, 34, 83, 87, 62, 93, 81, 20, 67, 39, 32, 5, 70, 4, 89, 72, 37, 64,\n",
" 94, 42, 57, 29, 56, 21, 60, 7, 45, 59, 86, 96, 49, 92, 28, 88, 24, 68,\n",
" 3, 61, 44, 2, 79, 55, 85, 33, 12, 1, 11, 97, 50, 9, 84, 40, 71, 22,\n",
" 26, 30, 73, 16, 90, 27, 19, 35, 48, 18, 46, 74, 65, 78, 52, 31, 23, 58,\n",
" 91, 66, 99, 69, 51, 36, 82, 63, 25, 10]))"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"index.sort()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "mpi",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.13"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}
# from .a2a import all_to_all, with_gloo, with_nccl
# from .cache import EmbeddingCache
# from .gather import Gather
# from .route import Route, GatherWork
from .route import Route, RouteWorkBase
# from .cache import MessageCache, NodeProbe
\ No newline at end of file
import torch
import torch.distributed as dist
from torch import Tensor
from typing import *
def all_to_all(
output_tensor_list: List[Tensor],
input_tensor_list: List[Tensor],
group: Optional[Any] = None,
):
assert len(output_tensor_list) == len(input_tensor_list)
if group is None:
group = dist.distributed_c10d._get_default_group()
backend = dist.get_backend(group)
if backend == "nccl":
dist.all_to_all(
output_tensor_list=output_tensor_list,
input_tensor_list=input_tensor_list,
group=group,
)
elif backend == "mpi":
dist.all_to_all(
output_tensor_list=output_tensor_list,
input_tensor_list=input_tensor_list,
group=group,
)
else:
assert backend == "gloo", f"backend must be nccl, mpi or gloo"
rank = dist.get_rank()
world_size = dist.get_world_size()
p2p_op_list: List[dist.P2POp] = []
for i in range(1, world_size):
send_i = (rank + i) % world_size
recv_i = (rank - i + world_size) % world_size
p2p_op_list.extend([
dist.P2POp(dist.isend, input_tensor_list[send_i], send_i, group=group),
dist.P2POp(dist.irecv, output_tensor_list[recv_i], recv_i, group=group),
])
works = dist.batch_isend_irecv(p2p_op_list) if p2p_op_list else []
output_tensor_list[rank][:] = input_tensor_list[rank]
for work in works: work.wait()  # block until every isend/irecv has completed
\ No newline at end of file
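A minimal driver sketch (not part of this commit) for the all_to_all() helper above; it assumes the helper is importable (or defined in the same script) and that the script is launched with torchrun so rank and world size come from the environment:

import torch
import torch.distributed as dist

def demo():
    dist.init_process_group(backend="gloo")
    rank, world_size = dist.get_rank(), dist.get_world_size()

    # rank r sends the value r*10 + dst to every destination rank dst
    inputs = [torch.tensor([rank * 10 + dst], dtype=torch.float32) for dst in range(world_size)]
    outputs = [torch.zeros(1) for _ in range(world_size)]

    all_to_all(outputs, inputs)
    # after the exchange, outputs[src] holds the tensor sent by rank src
    print(f"rank {rank} received {[t.item() for t in outputs]}")
    dist.destroy_process_group()

if __name__ == "__main__":
    demo()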
# import torch
# from torch import Tensor
# from typing import *
# from multiprocessing.pool import ThreadPool
# from .event import Event
# class AsyncCopyWorkBase:
# def __init__(self) -> None:
# self._events: Optional[Tuple[Event, Event]] = None
# def wait(self):
# raise NotImplementedError
# def get(self) -> Tensor:
# raise NotImplementedError
# def has_events(self) -> bool:
# return self._events is not None
# def set_events(self, start, end):
# self._events = (start, end)
# return self
# def time_used(self) -> float:
# if self._events is None:
# raise RuntimeError("not found events")
# start, end = self._events
# return start.elapsed_time(end)
# class AsyncPushWork(AsyncCopyWorkBase):
# def __init__(self, data: Tensor, index: Tensor, values: Tensor) -> None:
# super().__init__()
# assert data.device == index.device
# self.set_events(
# Event(use_cuda=index.is_cuda),
# Event(use_cuda=index.is_cuda),
# )
# self._events[0].record()
# data.index_copy_(0, index, values)
# self._events[1].record()
# def wait(self):
# pass
# def get(self):
# pass
# class AsyncPullWork(AsyncCopyWorkBase):
# def __init__(self, data: Tensor, index: Tensor) -> None:
# super().__init__()
# assert data.device == index.device
# self.set_events(
# Event(use_cuda=index.is_cuda),
# Event(use_cuda=index.is_cuda),
# )
# self._events[0].record()
# self._val = data.index_select(0, index)
# self._events[1].record()
# def wait(self):
# pass
# def get(self) -> Tensor:
# return self._val
# class AsyncOffloadWork(AsyncCopyWorkBase):
# def __init__(self, handle) -> None:
# super().__init__()
# self._handle = handle
# def wait(self):
# if self._events is not None:
# self._events[1].wait(torch.cuda.current_stream())
# self._handle.wait()
# def get(self) -> Tensor:
# if self._events is not None:
# self._events[1].wait(torch.cuda.current_stream())
# return self._handle.get()
# class AsyncCopyExecutor:
# def __init__(self) -> None:
# self._stream = torch.cuda.Stream()
# self._executor = ThreadPool(processes=1)
# @torch.no_grad()
# def async_pull(self, data: Tensor, index: Tensor) -> AsyncCopyWorkBase:
# # Ideally this code would all live in a dedicated thread, which makes timing easier and also helps with asynchrony
# if data.device != index.device:
# assert not data.is_cuda
# assert index.is_cuda
# start = Event(index.is_cuda)
# end = Event(index.is_cuda)
# stream: Optional[torch.cuda.Stream] = self._stream
# def run():
# start.wait(stream)
# with torch.cuda.stream(stream):
# idx = index.to(data.device)
# dst = torch.zeros(
# size=(index.size(0),) + data.shape[1:],
# dtype=data.dtype,
# device=index.device,
# )
# dst.copy_(data[idx])
# end.record()
# return dst
# start.record()
# handle = self._executor.apply_async(run)
# return AsyncOffloadWork(handle).set_events(start, end)
# else:
# # data and index in the same device
# return AsyncPullWork(data, index)
# @torch.no_grad()
# def async_push(self, data: Tensor, index: Tensor, values: Tensor) -> AsyncCopyWorkBase:
# assert index.device == values.device
# if data.device != index.device:
# assert not data.is_cuda
# assert index.is_cuda
# start = Event(index.is_cuda)
# end = Event(index.is_cuda)
# stream: Optional[torch.cuda.Stream] = self._stream
# def run():
# start.wait(stream)
# with torch.cuda.stream(stream):
# idx = index.to(data.device)
# val = values.to(data.device)
# data[idx] = val
# end.record()
# start.record()
# handle = self._executor.apply_async(run)
# return AsyncOffloadWork(handle).set_events(start, end)
# else:
# return AsyncPushWork(data, index, values)
# _THREAD_EXEC: Optional[AsyncCopyExecutor] = None
# def get_executor() -> AsyncCopyExecutor:
# global _THREAD_EXEC
# if _THREAD_EXEC is None:
# _THREAD_EXEC = AsyncCopyExecutor()
# return _THREAD_EXEC
import torch
import time
from torch import Tensor
from typing import *
class Event:
def __init__(self, use_cuda: bool = True) -> None:
self._use_cuda = use_cuda
if use_cuda:
self._event = torch.cuda.Event(enable_timing=True)
else:
self._time: Optional[float] = None
def record(self):
if self._use_cuda:
self._event.record()
else:
self._time = time.time()
def wait(self, stream: Optional[torch.cuda.Stream] = None):
# CPU timing needs no synchronization; for CUDA, make the stream wait on the event
if self._use_cuda:
self._event.wait(stream)
def elapsed_time(self, other) -> float:
if self._use_cuda:
return self._event.elapsed_time(other._event)
else:
return (other._time - self._time) * 1000.0
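A short usage sketch for the Event helper above (not part of the original file); it times a CPU matmul, and the same pattern works with use_cuda=True as long as the CUDA work is synchronized before reading the elapsed time:

if __name__ == "__main__":
    start, end = Event(use_cuda=False), Event(use_cuda=False)
    start.record()
    x = torch.randn(1024, 1024)
    y = x @ x
    end.record()
    # with use_cuda=True, call torch.cuda.synchronize() before elapsed_time()
    print(f"matmul took {start.elapsed_time(end):.3f} ms")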
# import torch
# import torch.nn as nn
# import torch.autograd as autograd
# from torch import Tensor
# from contextlib import contextmanager
# from typing import *
# from .route import Route, GatherWork
# # from .cache import CachedEmbeddings
# class GatherContext:
# def __init__(self,
# this,
# route: Route,
# async_op: bool,
# ) -> None:
# self.this = this
# self.route = route
# self.async_op = async_op
# class Gather(nn.Module):
# def __init__(self,
# num_nodes: int,
# num_features: Optional[int] = None,
# beta: float = 1.0,
# ) -> None:
# super().__init__()
# self.beta = beta
# if num_features is None:
# self.register_buffer("last_embd", torch.zeros(num_nodes))
# else:
# self.register_buffer("last_embd", torch.zeros(num_nodes, num_features))
# self.last_fw_work: Optional[GatherWork] = None
# self.last_bw_work: Optional[GatherWork] = None
# self.reset_parameters()
# def forward(self,
# val: Tensor,
# idx: Optional[Tensor],
# route: Route,
# async_op: bool = False,
# ) -> Tuple[Tensor, Tensor]:
# with self._manager(route=route, async_op=async_op):
# return GatherFunction.apply(val, idx)
# def reset_parameters(self):
# last_embd = self.get_buffer("last_embd")
# nn.init.normal_(last_embd, mean=0, std=1)
# def fuse_embeddings(self,
# val: Tensor,
# idx: Optional[Tensor] = None,
# inplace: bool = False,
# ) -> Tensor:
# last_embd = self.get_buffer("last_embd")
# return GatherFuseFunction.apply(val, idx, last_embd, self.beta, inplace, self.training)
# @contextmanager
# def _manager(self, route: Route, async_op: bool):
# global _global_gather_context
# stacked = _global_gather_context
# try:
# _global_gather_context = GatherContext(
# this=self,
# route=route,
# async_op=async_op,
# )
# yield _global_gather_context
# finally:
# _global_gather_context = stacked
# class GatherFunction(autograd.Function):
# @staticmethod
# def forward(
# ctx: autograd.function.FunctionCtx,
# val: Tensor,
# idx: Optional[Tensor] = None,
# ):
# gather_ctx = _last_global_gather_context()
# this: Gather = gather_ctx.this
# route: Route = gather_ctx.route
# async_op: bool = gather_ctx.async_op
# return_idx: bool = idx is not None
# current_work = route.gather_forward(val, idx, async_op=async_op, return_idx=return_idx)
# if async_op:
# work = this.last_fw_work or current_work
# this.last_fw_work = current_work
# else:
# work = current_work
# recv_val, recv_idx = work.get()
# recv_val = recv_val if recv_idx is None else recv_val[recv_idx]
# recv_val = this.fuse_embeddings(recv_val, recv_idx, inplace=True)
# if this.training:
# ctx.save_for_backward(idx, recv_idx)
# ctx.this = this
# ctx.route = route
# ctx.async_op = async_op
# return recv_val, recv_idx
# @staticmethod
# def backward(
# ctx: autograd.function.FunctionCtx,
# val_grad: Tensor,
# idx_grad: Optional[Tensor],
# ):
# this: Gather = ctx.this
# route: Route = ctx.route
# async_op: bool = ctx.async_op
# with torch.no_grad():
# recv_idx, idx_grad = ctx.saved_tensors
# if idx_grad is not None:
# val_grad = val_grad[idx_grad]
# current_work = route.gather_backward(val_grad, idx_grad, async_op=async_op, return_idx=False)
# if async_op:
# work = this.last_bw_work or current_work
# this.last_bw_work = current_work
# else:
# work = current_work
# recv_val = work.get_val()
# if recv_idx is not None:
# recv_val = recv_val[recv_idx]
# return recv_val, None
# class GatherFuseFunction(autograd.Function):
# @staticmethod
# def forward(
# ctx: autograd.function.FunctionCtx,
# val: Tensor,
# idx: Optional[Tensor],
# last_embd: Tensor,
# beta: float,
# inplace: bool,
# training: bool,
# ):
# if not inplace:
# last_embd = last_embd.clone()
# ctx.beta = beta
# if idx is None:
# assert val.size(0) == last_embd.size(0)
# if beta != 1.0:
# last_embd.mul_(1 - beta).add_(val * beta)
# else:
# last_embd[:] = (val)
# else:
# assert val.size(0) == idx.size(0)
# if beta != 1.0:
# last_embd[idx] = last_embd[idx] * (1 - beta) + val * beta
# else:
# last_embd[idx] = val
# if training:
# ctx.beta = beta
# ctx.save_for_backward(idx)
# return last_embd
# @staticmethod
# def backward(
# ctx: autograd.function.FunctionCtx,
# grad: Tensor,
# ):
# beta: float = ctx.beta
# idx, = ctx.saved_tensors
# if idx is not None:
# grad = grad[idx]
# if beta != 1.0:
# grad = grad * beta
# return grad, None, None, None, None, None
# #### private functions
# _global_gather_context: Optional[GatherContext] = None
# def _last_global_gather_context() -> GatherContext:
# global _global_gather_context
# assert _global_gather_context is not None
# return _global_gather_context
# import torch
# from torch import Tensor
# from typing import *
# from multiprocessing.pool import ThreadPool
# from .event import Event
# class AsyncCopyWorkBase:
# def __init__(self) -> None:
# self._events: Optional[Tuple[Event, Event]] = None
# def wait(self):
# raise NotImplementedError
# def get(self) -> Tensor:
# raise NotImplementedError
# def has_events(self) -> bool:
# return self._events is not None
# def set_events(self, start, end):
# self._events = (start, end)
# return self
# def time_used(self) -> float:
# if self._events is None:
# raise RuntimeError("not found events")
# start, end = self._events
# return start.elapsed_time(end)
# class AsyncPushWork(AsyncCopyWorkBase):
# def __init__(self, data: Tensor, index: Tensor, values: Tensor) -> None:
# super().__init__()
# assert data.device == index.device
# self.set_events(
# Event(use_cuda=index.is_cuda),
# Event(use_cuda=index.is_cuda),
# )
# self._events[0].record()
# data.index_copy_(0, index, values)
# self._events[1].record()
# def wait(self):
# pass
# def get(self):
# pass
# class AsyncPullWork(AsyncCopyWorkBase):
# def __init__(self, data: Tensor, index: Tensor) -> None:
# super().__init__()
# assert data.device == index.device
# self.set_events(
# Event(use_cuda=index.is_cuda),
# Event(use_cuda=index.is_cuda),
# )
# self._events[0].record()
# self._val = data.index_select(0, index)
# self._events[1].record()
# def wait(self):
# pass
# def get(self) -> Tensor:
# return self._val
# class AsyncOffloadWork(AsyncCopyWorkBase):
# def __init__(self, handle) -> None:
# super().__init__()
# self._handle = handle
# def wait(self):
# if self._events is not None:
# self._events[1].wait(torch.cuda.current_stream())
# self._handle.wait()
# def get(self) -> Tensor:
# if self._events is not None:
# self._events[1].wait(torch.cuda.current_stream())
# return self._handle.get()
# class AsyncCopyExecutor:
# def __init__(self) -> None:
# self._stream = torch.cuda.Stream()
# self._executor = ThreadPool(processes=1)
# @torch.no_grad()
# def async_pull(self, data: Tensor, index: Tensor) -> AsyncCopyWorkBase:
# # Ideally this code would all live in a dedicated thread, which makes timing easier and also helps with asynchrony
# if data.device != index.device:
# assert not data.is_cuda
# assert index.is_cuda
# start = Event(index.is_cuda)
# end = Event(index.is_cuda)
# stream: Optional[torch.cuda.Stream] = self._stream
# def run():
# start.wait(stream)
# with torch.cuda.stream(stream):
# idx = index.to(data.device)
# dst = torch.zeros(
# size=(index.size(0),) + data.shape[1:],
# dtype=data.dtype,
# device=index.device,
# )
# dst.copy_(data[idx])
# end.record()
# return dst
# start.record()
# handle = self._executor.apply_async(run)
# return AsyncOffloadWork(handle).set_events(start, end)
# else:
# # data and index in the same device
# return AsyncPullWork(data, index)
# @torch.no_grad()
# def async_push(self, data: Tensor, index: Tensor, values: Tensor) -> AsyncCopyWorkBase:
# assert index.device == values.device
# if data.device != index.device:
# assert not data.is_cuda
# assert index.is_cuda
# start = Event(index.is_cuda)
# end = Event(index.is_cuda)
# stream: Optional[torch.cuda.Stream] = self._stream
# def run():
# start.wait(stream)
# with torch.cuda.stream(stream):
# idx = index.to(data.device)
# val = values.to(data.device)
# data[idx] = val
# end.record()
# start.record()
# handle = self._executor.apply_async(run)
# return AsyncOffloadWork(handle).set_events(start, end)
# else:
# return AsyncPushWork(data, index, values)
# class Lache:
# def __init__(self,
# cache_size: int,
# data: Tensor,
# ) -> None:
# assert not data.is_cuda, "data must be in CPU Memory"
# cache_size = (cache_size,) + data.shape[1:]
# self.fdata = data
# self.cache = torch.zeros(cache_size, dtype=data.dtype)
# self.no_idx: int = (2**62-1)*2+1
# self.cached_idx = torch.empty(cache_size, dtype=torch.long).fill_(self.no_idx)
# self.read_count = torch.zeros(data.size(0), dtype=torch.long)
# def to(self, device):
# self.cache = self.cache.to(device)
# self.cached_idx = self.cached_idx.to(device)
# self.read_count = self.read_count.to(device)
# return self
# def _push_impl(self, value: Tensor, index: Tensor):
# # self.read_count[index] += 1
# imp = torch.zeros_like(self.read_count)
# def _pull_impl(self, index: Tensor):
# assert index.device == self.cache.device
# self.read_count[index] += 1
# s = index.shape[:1] + self.cache.shape[1:]
# x = torch.empty(s, dtype=self.cache.dtype, device=self.cache.device)
# cache_mask = torch.zeros_like(self.read_count, dtype=torch.bool).index_fill_(0, self.cached_idx, 1)[index]
# cache_index = index[cache_mask]
# x[cache_mask] =
# no_cache_index = index[~cache_mask]
# rt_index = index[~lc_mask].to(self.fdata.device)
# rt_data = self.fdata[rt_index].to(index.device)
# _THREAD_EXEC: Optional[AsyncCopyExecutor] = None
# def get_executor() -> AsyncCopyExecutor:
# global _THREAD_EXEC
# if _THREAD_EXEC is None:
# _THREAD_EXEC = AsyncCopyExecutor()
# return _THREAD_EXEC
# import torch
# from torch import Tensor
# from typing import *
# class ShrinkData:
# no_idx: int = (2**62-1)*2+1
# def __init__(self,
# src_size: int,
# dst_size: int,
# dst_idx: Tensor,
# edge_index: Tensor,
# bipartite: bool = False,
# ) -> None:
# device = dst_idx.device
# tmp = torch.empty(max(src_size, dst_size), dtype=torch.bool, device=device)
# tmp.fill_(0)
# tmp.index_fill_(0, dst_idx, 1)
# edge_idx = torch.where(tmp[edge_index[1]])[0]
# edge_index = edge_index[:, edge_idx]
# if bipartite:
# tmp.fill_(0)
# tmp.index_fill_(0, edge_index[0], 1)
# src_idx = torch.where(tmp)[0]
# imp = torch.empty(max(src_size, dst_size), dtype=torch.long, device=device)
# imp[dst_idx] = torch.arange(dst_idx.size(0), dtype=torch.long, device=device)
# dst = imp[edge_index[1]]
# imp.fill_(self.no_idx)
# imp[src_idx] = torch.arange(src_idx.size(0), dtype=torch.long, device=device)
# src = imp[edge_index[0]]
# edge_index = torch.vstack([src, dst])
# else:
# tmp.index_fill_(0, edge_index[0], 1)
# tmp.index_fill_(0, dst_idx, 0)
# src_idx = torch.cat([dst_idx, torch.where(tmp)[0]], dim=0)
# imp = torch.empty(max(src_size, dst_size), dtype=torch.long, device=device)
# imp.fill_(self.no_idx)
# imp[src_idx] = torch.arange(src_idx.size(0), dtype=torch.long, device=device)
# edge_index = imp[edge_index.flatten()].view_as(edge_index)
# self.src_idx = src_idx
# self.dst_idx = dst_idx
# self.edge_idx = edge_idx
# self.edge_index = edge_index
# self._src_imp = imp[:src_size]
# def to(self, device):
# self.src_idx = self.src_idx.to(device)
# self.dst_idx = self.dst_idx.to(device)
# self.edge_idx = self.edge_idx.to(device)
# self.edge_index = self.edge_index.to(device)
# self._src_imp = self._src_imp.to(device)
# return self
# def shrink_src_val_and_idx(self, val: Tensor, idx: Tensor) -> Tuple[Tensor, Tensor]:
# idx = self._src_imp[idx]
# m = (idx != self.no_idx)
# return val[m], idx[m]
# @property
# def src_size(self) -> int:
# return self.src_idx.size(0)
# @property
# def dst_size(self) -> int:
# return self.dst_idx.size(0)
# @property
# def edge_size(self) -> int:
# return self.edge_idx.size(0)
# import torch
# import torch.nn as nn
# import torch.autograd as autograd
# from torch import Tensor
# from contextlib import contextmanager
# from typing import *
# class StraightContext:
# def __init__(self, this, g) -> None:
# self.this = this
# self.g = g
# class Straight(nn.Module):
# def __init__(self,
# num_nodes: int,
# num_samples: int,
# norm_kwargs: Optional[Dict[str, Any]] = None,
# beta: float = 1.0,
# prev: Optional[List[Any]] = None,
# ) -> None:
# super().__init__()
# assert num_samples <= num_nodes
# self.num_nodes = num_nodes
# self.num_samples = num_samples
# self.norm_kwargs = norm_kwargs or dict(p=2, dim=-1)
# self.beta = beta
# self.prev = prev
# self.register_buffer("last_w", torch.ones(num_nodes))
# self._next_idx = None
# self._next_shrink_helper = None
# self.reset_parameters()
# def reset_parameters(self):
# last_w = self.get_buffer("last_w")
# nn.init.constant_(last_w, 1.0)
# def forward(self,
# val: Tensor,
# idx: Optional[Tensor],
# g,
# ) -> Tensor:
# with self._manager(self, g):
# return StraightFunction.apply(val, idx)
# def pop_next_shrink_helper(self) -> Tuple[Optional[Tensor], Any]:
# if not self.training:
# return None, None
# next_idx = self._next_idx
# self._next_idx = None
# next_sh = self._next_shrink_helper
# self._next_shrink_helper = None
# return next_idx, next_sh
# def _sample_next(self) -> Tensor:
# w = self.get_buffer("last_w")
# if self._next_idx is None:
# if self.num_samples < w.size(0):
# self._next_idx = self.sample_impl(w)
# else:
# self._next_idx = torch.arange(w.size(0), dtype=torch.long, device=w.device)
# elif self.num_samples < self._next_idx.size(0):
# idx = self.sample_impl(w[self._next_idx])
# self._next_idx = self._next_idx[idx]
# return self._next_idx
# def sample_impl(self, w: Tensor) -> Tensor:
# w = w / w.sum()
# return torch.multinomial(w, num_samples=self.num_samples, replacement=False)
# # def multinomial(self,
# # num_samples: int,
# # replacement: bool = False,
# # ) -> Tensor:
# # w = self.get_buffer("last_w")
# # if num_samples <= 0:
# # return torch.arange(self.num_nodes, dtype=torch.long, device=w.device)
# # w = w / w.sum()
# # return torch.multinomial(w, num_samples=num_samples, replacement=replacement)
# # def multinomial_mask(self,
# # num_samples: int,
# # replacement: bool = False,
# # ) -> Tensor:
# # w = self.get_buffer("last_w")
# # if num_samples <= 0:
# # return torch.ones(self.num_nodes, dtype=torch.bool, device=w.device)
# # w = self.multinomial(num_samples, replacement)
# # m = torch.zeros(self.num_nodes, dtype=torch.bool, device=w.device)
# # m[w] = True
# # return m
# @contextmanager
# def _manager(self, this, g):
# global _global_straight_context
# stacked = _global_straight_context
# try:
# _global_straight_context = StraightContext(this=this, g=g)
# yield _global_straight_context
# finally:
# _global_straight_context = stacked
# class StraightFunction(autograd.Function):
# @staticmethod
# def forward(
# ctx: autograd.function.FunctionCtx,
# val: Tensor,
# idx: Optional[Tensor],
# ):
# from ..graph import DistGraph
# stx = _last_global_straight_context()
# this: Straight = stx.this
# g: DistGraph = stx.g
# last_w = this.get_buffer("last_w")
# if idx is None:
# assert val.size(0) == last_w.size(0)
# else:
# assert val.size(0) == idx.size(0)
# if this.training:
# ctx.this = this
# ctx.g = g
# ctx.save_for_backward(idx)
# return val
# @staticmethod
# def backward(
# ctx: autograd.function.FunctionCtx,
# grad: Tensor,
# ):
# from ..graph import DistGraph
# this: Straight = ctx.this
# g: DistGraph = ctx.g
# idx, = ctx.saved_tensors
# last_w = this.get_buffer("last_w")
# if this.beta != 1.0:
# last_w.mul_(this.beta)
# norm = grad.norm(**this.norm_kwargs)
# if idx is None:
# last_w[:] = norm
# else:
# last_w[idx] = norm
# if this.prev is not None or this._next_idx is None:
# from ..nn.convs.utils import ShrinkHelper
# dst_idx = this._sample_next()
# if this.prev is not None and this._next_idx is not None:
# this._next_shrink_helper = ShrinkHelper(g, dst_idx)
# src_idx = this._next_shrink_helper.src_idx
# work = g.route.gather_backward(
# src_idx, src_idx, async_op=False, return_idx=True)
# prev_dst_idx = work.get_idx()
# for p in this.prev:
# assert isinstance(p, Straight)
# p._next_idx = prev_dst_idx
# return grad, None
# #### private functions
# _global_straight_context: Optional[StraightContext] = None
# def _last_global_straight_context() -> StraightContext:
# global _global_straight_context
# assert _global_straight_context is not None
# return _global_straight_context
# if __name__ == "__main__":
# s = Straight(3, beta=1.1)
# x = torch.rand(3, 10).requires_grad_()
# m = torch.tensor([0, 1, 0], dtype=torch.bool)
# s(x).sum().backward()
# print(s.grad_norm)
# print(s.multinomial(2))
# s(x, m).sum().backward()
# print(s.grad_norm)
# print(s.multinomial(2))
# s(x[m], m).sum().backward()
# print(s.grad_norm)
# print(s.multinomial(2))
# print(s.multinomial_mask(2))
\ No newline at end of file
# import torch
# from torch import Tensor
# from typing import *
# from torch_scatter import scatter_sum
# from starrygl.core.route import Route
# def compute_in_degree(edge_index: Tensor, route: Route) -> Tensor:
# dst_size = route.src_size
# x = torch.ones(edge_index.size(1), dtype=torch.long, device=edge_index.device)
# in_deg = scatter_sum(x, edge_index[1], dim=0, dim_size=dst_size)
# in_deg, _ = route.forward_a2a(in_deg)
# return in_deg
# def compute_out_degree(edge_index: Tensor, route: Route) -> Tensor:
# src_size = route.dst_size
# x = torch.ones(edge_index.size(1), dtype=torch.long, device=edge_index.device)
# out_deg = scatter_sum(x, edge_index[0], dim=0, dim_size=src_size)
# out_deg, _ = route.backward_a2a(out_deg)
# out_deg, _ = route.forward_a2a(out_deg)
# return out_deg
# def compute_gcn_norm(edge_index: Tensor, route: Route) -> Tensor:
# in_deg = compute_in_degree(edge_index, route)
# out_deg = compute_out_degree(edge_index, route)
# a = in_deg[edge_index[0]].pow(-0.5)
# b = out_deg[edge_index[0]].pow(-0.5)
# x = a * b
# x[x.isinf()] = 0.0
# x[x.isnan()] = 0.0
# return x
# import torch
# import torch.nn as nn
# from torch import Tensor
# from ..graph import DistGraph
# from typing import *
# from .metrics import *
# def train_epoch(
# model: nn.Module,
# opt: torch.optim.Optimizer,
# g: DistGraph,
# mask: Optional[Tensor] = None,
# ) -> float:
# model.train()
# criterion = nn.CrossEntropyLoss()
# pred: Tensor = model(g)
# targ: Tensor = g.ndata["y"]
# if mask is not None:
# pred = pred[mask]
# targ = targ[mask]
# loss: Tensor = criterion(pred, targ)
# opt.zero_grad()
# loss.backward()
# opt.step()
# with torch.no_grad():
# train_loss = all_reduce_loss(loss, targ.size(0))
# train_acc = accuracy(pred.argmax(dim=-1), targ)
# return train_loss, train_acc
# @torch.no_grad()
# def eval_epoch(
# model: nn.Module,
# g: DistGraph,
# mask: Optional[Tensor] = None,
# ) -> Tuple[float, float]:
# model.eval()
# criterion = nn.CrossEntropyLoss()
# pred: Tensor = model(g)
# targ: Tensor = g.ndata["y"]
# if mask is not None:
# pred = pred[mask]
# targ = targ[mask]
# loss = criterion(pred, targ)
# eval_loss = all_reduce_loss(loss, targ.size(0))
# eval_acc = accuracy(pred.argmax(dim=-1), targ)
# return eval_loss, eval_acc
\ No newline at end of file
# import torch
# import torch.nn as nn
# from torch import Tensor
# from typing import *
# from contextlib import contextmanager
# # import torch_sparse
# from .ndata import NData
# from .edata import EData
# from .utils import init_local_edge_index
# from ..core import MessageCache, Route
# class DistGraph:
# def __init__(self,
# ids: Tensor,
# edge_index: Tensor,
# num_features: int,
# num_layers: int,
# cache_device: str = "cpu",
# **args: Dict[str, Any],
# ):
# # build local_edge_index
# dst_ids = ids
# src_ids, local_edge_index = init_local_edge_index(
# dst_ids=dst_ids,
# edge_index=edge_index,
# )
# self._src_ids = src_ids
# self._dst_ids = dst_ids
# self._message_cache = MessageCache(
# src_ids=src_ids,
# dst_ids=dst_ids,
# edge_index=local_edge_index,
# num_features=num_features,
# num_layers=num_layers,
# cache_device=cache_device,
# bipartite=False,
# )
# # node's attributes
# self.ndata = NData(
# src_size=src_ids.size(0),
# dst_size=dst_ids.size(0),
# )
# # edge's attributes
# self.edata = EData(
# edge_size=local_edge_index.size(1),
# )
# # graph's attributes
# self.args = dict(args)
# def to(self, device):
# self._message_cache.to(device)
# return self
# def cache_data_to(self, device):
# self._message_cache.cached_data_to(device)
# @property
# def cache(self) -> MessageCache:
# return self._message_cache
# @property
# def route(self) -> Route:
# return self._message_cache.route
# @property
# def device(self) -> torch.device:
# return self.edge_index.device
# @property
# def edge_index(self) -> Tensor:
# return self._message_cache.edge_index
# @property
# def src_ids(self) -> Tensor:
# if self._src_ids.device != self.device:
# self._src_ids = self._src_ids.to(self.device)
# return self._src_ids
# @property
# def src_size(self) -> int:
# return self._src_ids.size(0)
# @property
# def dst_ids(self) -> Tensor:
# if self._dst_ids.device != self.device:
# self._dst_ids = self._dst_ids.to(self.device)
# return self._dst_ids
# @property
# def dst_size(self) -> int:
# return self._dst_ids.size(0)
# @contextmanager
# def scoped_manager(self):
# stacked_ndata = self.ndata
# stacked_edata = self.edata
# try:
# self.ndata = NData(
# p=stacked_ndata,
# )
# self.edata = EData(
# p=stacked_edata,
# )
# yield self
# finally:
# self.ndata = stacked_ndata
# self.edata = stacked_edata
# # def permute_edge_(self):
# # perm = self._local_edge_index[1].argsort()
# # self._local_edge_index = self._local_edge_index[:,perm]
# # self._local_edge_ptr = torch.ops.torch_sparse.ind2ptr(self._local_edge_index[1], self.dst_size)
# # self.edata.permute_(perm)
# # @torch.no_grad()
# # def shrink(self,
# # src_mask: Optional[Tensor] = None,
# # dst_mask: Optional[Tensor] = None,
# # ):
# # edge_index = self.edge_index
# # device = edge_index.device
# # ikw = dict(dtype=torch.long, device=device)
# # bkw = dict(dtype=torch.bool, device=device)
# # if src_mask is None and dst_mask is None:
# # return self
# # else:
# # if dst_mask is None:
# # dst_mask = torch.zeros(self.ndata.dst_size, **bkw)
# # else:
# # assert dst_mask.size(0) == self.ndata.dst_size
# # dst_mask = dst_mask.clone()
# # if src_mask is not None:
# # m = src_mask[edge_index[0]]
# # m = edge_index[1][m]
# # dst_mask[m] = True
# # edge_mask = dst_mask[edge_index[1]]
# # edge_index = edge_index[:,edge_mask]
# # src_mask = torch.zeros(self.ndata.src_size, **bkw)
# # src_mask[:self.ndata.dst_size] = dst_mask
# # src_mask[edge_index[0]] = True
# # dst_mask = src_mask[:self.ndata.dst_size]
# # # re-number edge_index
# # imp = torch.empty(self.ndata.src_size, **ikw).fill_((2**62-1)*2+1)
# # idx = torch.where(src_mask)[0]
# # imp[idx] = torch.arange(idx.size(0), **ikw)
# # edge_index = imp[edge_index.flatten()].view_as(edge_index)
# # # assert dst_mask.count_nonzero() > edge_index[1].max().item()
# # ndata = NData(
# # src_size=None,
# # dst_size=None,
# # p=self.ndata,
# # src_mask=src_mask,
# # dst_mask=dst_mask,
# # )
# # edata = EData(
# # edge_size=None,
# # p=self.edata,
# # edge_mask=edge_mask,
# # )
# # edata[EID] = edge_index
# # return DistGraph(self.args, ndata, edata, self.route)
# # class ShrinkGraph:
# # def __init__(self,
# # g: DistGraph,
# # src_mask: Optional[Tensor] = None,
# # dst_mask: Optional[Tensor] = None,
# # ) -> None:
# # device = g.edge_index.device
# # ikw = dict(dtype=torch.long, device=device)
# # bkw = dict(dtype=torch.bool, device=device)
# # if src_mask is None and dst_mask is None:
# # # if neither src_mask nor dst_mask is given,
# # # ShrinkGraph is simply a replica of DistGraph
# # self.src_ids = g.sr
# # self.dst_size = g.dst_size
# # self.src_mask = torch.ones(self.src_size, **bkw)
# # self.dst_mask_size = g.dst_size
# # self.edge_index = g.edge_index
# # self.pgraph = g
# # else:
# # if src_mask is not None:
# # tmp_mask = torch.zeros(g.dst_size, **bkw)
# # # compute the directly activated edges
# # m = src_mask[g.edge_index[0]]
# # m = g.edge_index[1][m]
# # # compute the directly activated dst_ids
# # tmp_mask[m] = True
# # # merge with the already-activated dst_ids
# # if dst_mask is not None:
# # tmp_mask |= dst_mask
# # dst_mask = tmp_mask
# # # compute the indirectly activated edges
# # edge_mask = dst_mask[g.edge_index[1]]
# # edge_index = g.edge_index[:,edge_mask]
# # # compute the indirectly activated src_ids
# # src_mask = torch.zeros(g.src_size, **bkw)
# # src_mask[edge_index[0]] = True
# # src_mask[:g.dst_size] |= dst_mask
# # self.src_ids = g.src_ids[src_mask]
# # self.dst_size = src_mask[:g.dst_size].count_nonzero().item()
# # self.src_mask = src_mask
# # self.dst_mask_size = g.dst_size
# # # re-number edge_index
# # imp = torch.empty(g.src_size, **ikw).fill_((2**62-1)*2+1)
# # idx = torch.where(src_mask)[0]
# # imp[idx] = torch.arange(idx.size(0), **ikw)
# # edge_index = imp[edge_index.flatten()].view_as(edge_index)
# # self.edge_index = edge_index
# # self.pgraph = g
# # self.ndata = ShrinkData(self.dst_mask, g.ndata)
# # self.edata = ShrinkData(self.edge_mask, g.edata)
\ No newline at end of file
# import torch
# from torch import Tensor
# from typing import *
# class EData:
# def __init__(self,
# edge_size: Optional[int] = None,
# p = None,
# ) -> None:
# if p is None:
# assert edge_size is not None
# self.edge_size = edge_size
# else:
# assert edge_size is None
# self.edge_size = p.edge_size
# self.prev_data = p
# self.data: Dict[str, Tensor] = {}
# def __getitem__(self, name: str) -> Tensor:
# t = self.get(name)
# if t is None:
# raise ValueError(f"not found '{name}' in data")
# return t
# def __setitem__(self, name: str, tensor: Tensor):
# if not isinstance(tensor, Tensor):
# raise ValueError(f"the second parameter's type must be Tensor")
# if tensor.size(0) == self.edge_size:
# self.data[name] = tensor
# else:
# raise ValueError(f"tensor's shape must match the edge_size")
# def __delitem__(self, name: str) -> None:
# self.pop(name)
# def get(self, name: str) -> Optional[Tensor]:
# p, t = self, None
# while p is not None:
# t = p.data.get(name)
# if t is not None:
# break
# p = p.prev_data
# return t
# def pop(self, name: str) -> Tensor:
# return self.data.pop(name)
# def permute_(self, perm: Tensor):
# p = self
# while p is not None:
# for key in list(p.data.keys()):
# val = p.data.get(key)
# p.data[key] = val[perm]
# p = p.prev_data
\ No newline at end of file
# import torch
# from torch import Tensor
# from typing import *
# class NData:
# def __init__(self,
# src_size: Optional[int] = None,
# dst_size: Optional[int] = None,
# p = None,
# ) -> None:
# if p is None:
# assert src_size is not None
# assert dst_size is not None
# self.src_size = src_size
# self.dst_size = dst_size
# else:
# assert src_size is None
# assert dst_size is None
# self.src_size = p.src_size
# self.dst_size = p.dst_size
# self.prev_data = p
# self.data: Dict[str, Tensor] = {}
# def __getitem__(self, name: str) -> Tensor:
# t = self.get(name)
# if t is None:
# raise ValueError(f"not found '{name}' in data")
# return t
# def __setitem__(self, name: str, tensor: Tensor):
# if not isinstance(tensor, Tensor):
# raise ValueError("the second parameter's type must be Tensor")
# if tensor.size(0) == self.src_size:
# self.data[name] = tensor
# elif tensor.size(0) == self.dst_size:
# self.data[name] = tensor
# else:
# raise ValueError("tensor's shape must match the src_size or dst_size")
# def __delitem__(self, name: str) -> None:
# self.pop(name)
# def get(self, name: str) -> Optional[Tensor]:
# p, t = self, None
# while p is not None:
# t = p.data.get(name)
# if t is not None:
# break
# p = p.prev_data
# return t
# def pop(self, name: str) -> Tensor:
# return self.data.pop(name)
# def permute_(self, perm: Tensor):
# p = self
# while p is not None:
# for key in list(p.data.keys()):
# val = p.data.get(key)
# p.data[key] = val[perm]
# p = p.prev_data
# def get_type(self, key: Union[str, Tensor]):
# if isinstance(key, Tensor):
# if key.size(0) == self.src_size:
# return "src"
# elif key.size(0) == self.dst_size:
# return "dst"
# else:
# raise RuntimeError
# t = self.__getitem__(key)
# if t.size(0) == self.src_size:
# return "src"
# elif t.size(0) == self.dst_size:
# return "dst"
# else:
# raise RuntimeError
\ No newline at end of file
import torch
import torch.nn as nn
import torch.distributed as dist
import torch.distributed.rpc as rpc
from torch.futures import Future
from torch import Tensor
from typing import *
from starrygl.parallel import all_gather_remote_objects
from .route import Route
from threading import Lock
class TensorBuffer(nn.Module):
def __init__(self,
channels: int,
num_nodes: int,
route: Route,
) -> None:
super().__init__()
self.channels = channels
self.num_nodes = num_nodes
self.route = route
self.local_lock = Lock()
self.register_buffer("_data", torch.zeros(num_nodes, channels), persistent=False)
self.rrefs = all_gather_remote_objects(self)
@property
def data(self) -> Tensor:
return self.get_buffer("_data")
@property
def device(self) -> torch.device:
return self.data.device
def local_get(self, index: Optional[Tensor] = None, lock: bool = True) -> Tensor:
if lock:
with self.local_lock:
return self.local_get(index, lock=False)
if index is None:
return self.data
else:
return self.data[index]
def local_set(self, value: Tensor, index: Optional[Tensor] = None, lock: bool = True):
if lock:
with self.local_lock:
return self.local_set(value, index, lock=False)
if index is None:
self.data.copy_(value)
else:
# value = value.to(self.device)
self.data[index] = value
def local_add(self, value: Tensor, index: Optional[Tensor] = None, lock: bool = True):
if lock:
with self.local_lock:
return self.local_add(value, index, lock=False)
if index is None:
self.data.add_(value)
else:
# value = value.to(self.device)
self.data[index] += value
def local_cls(self, index: Optional[Tensor] = None, lock: bool = True):
if lock:
with self.local_lock:
return self.local_cls(index, lock=False)
if index is None:
self.data.zero_()
else:
self.data[index] = 0
def remote_get(self, dst: int, index: Tensor, lock: bool = True):
return TensorBuffer._remote_call(TensorBuffer.local_get, self.rrefs[dst], index=index, lock=lock)
def remote_set(self, dst: int, value: Tensor, index: Tensor, lock: bool = True):
return TensorBuffer._remote_call(TensorBuffer.local_set, self.rrefs[dst], value, index=index, lock=lock)
def remote_add(self, dst: int, value: Tensor, index: Tensor, lock: bool = True):
return TensorBuffer._remote_call(TensorBuffer.local_add, self.rrefs[dst], value, index=index, lock=lock)
def all_remote_get(self, index: Tensor, lock: bool = True):
def cb0(idx):
def f(x: torch.futures.Future[Tensor]):
return x.value(), idx
return f
def cb1(buf):
def f(xs: torch.futures.Future[List[torch.futures.Future]]) -> Tensor:
for x in xs.value():
dat, idx = x.value()
# print(dat.size(), idx.size())
buf[idx] += dat
return buf
return f
futs = []
for i, (idx, remote_idx) in enumerate(self.route.parts_iter(index)):
futs.append(self.remote_get(i, remote_idx, lock=lock).then(cb0(idx)))
futs = torch.futures.collect_all(futs)
buf = torch.zeros(index.size(0), self.channels, dtype=self.data.dtype, device=self.data.device)
return futs.then(cb1(buf))
def all_remote_set(self, value: Tensor, index: Tensor, lock: bool = True):
futs = []
for i, (idx, remote_idx) in enumerate(self.route.parts_iter(index)):
futs.append(self.remote_set(i, value[idx], remote_idx, lock=lock))
return torch.futures.collect_all(futs)
def all_remote_add(self, value: Tensor, index: Tensor, lock: bool = True):
futs = []
for i, (idx, remote_idx) in enumerate(self.route.parts_iter(index)):
futs.append(self.remote_add(i, value[idx], remote_idx, lock=lock))
return torch.futures.collect_all(futs)
def broadcast(self, barrier: bool = True):
if barrier:
dist.barrier()
index = torch.arange(self.num_nodes, dtype=torch.long, device=self.data.device)
data = self.all_remote_get(index, lock=True).wait()
self.local_set(data, lock=True)
if barrier:
dist.barrier()
# def remote_get(self, dst: int, i: int, index: Optional[Tensor] = None, global_index: bool = False, async_op: bool = False) -> Union[Tensor, Future]:
# return TensorBuffer._remote_call(async_op, TensorBuffer.local_get, self.rrefs[dst], i, index, global_index = global_index)
# def remote_set(self, dst: int, i: int, value: Tensor, index: Optional[Tensor], global_index: bool = False, async_op: bool = False) -> Optional[Future]:
# return TensorBuffer._remote_call(async_op, TensorBuffer.local_set, self.rrefs[dst], i, value, index, global_index = global_index)
# def remote_add(self, dst: int, i: int, value: Tensor, index: Optional[Tensor] = None, global_index: bool = False, async_op: bool = False) -> Optional[Future]:
# return TensorBuffer._remote_call(async_op, TensorBuffer.local_add, self.rrefs[dst], i, value, index, global_index = global_index)
# def async_scatter_fw_set(self, i: int, value: Tensor, index: Optional[Tensor] = None) -> Tuple[Future]:
# futures: List[Future] = []
# for dst in range(self.world_size):
# val, ind = self.router.fw_value_index(dst, value, index)
# futures.append(self.remote_set(dst, i, val, ind, global_index=True, async_op=True))
# return tuple(futures)
# def async_scatter_fw_add(self, i: int, value: Tensor, index: Optional[Tensor] = None) -> Tuple[Future]:
# futures: List[Future] = []
# for dst in range(self.world_size):
# val, ind = self.router.fw_value_index(dst, value, index)
# futures.append(self.remote_add(dst, i, val, ind, global_index=True, async_op=True))
# return tuple(futures)
# def async_scatter_bw_set(self, i: int, value: Tensor, index: Optional[Tensor] = None) -> Tuple[Future]:
# futures: List[Future] = []
# for dst in range(self.world_size):
# val, ind = self.router.bw_value_index(dst, value, index)
# futures.append(self.remote_set(dst, i, val, ind, global_index=True, async_op=True))
# return tuple(futures)
# def async_scatter_bw_add(self, i: int, value: Tensor, index: Optional[Tensor] = None) -> Tuple[Future]:
# futures: List[Future] = []
# for dst in range(self.world_size):
# val, ind = self.router.bw_value_index(dst, value, index)
# futures.append(self.remote_add(dst, i, val, ind, global_index=True, async_op=True))
# return tuple(futures)
# def _idx_data(self, i: int) -> Tuple[int, Tensor]:
# assert -self.num_layers < i and i < self.num_layers
# i = (self.num_layers + i) % self.num_layers
# return i, self.get_buffer(f"data{i}")
@staticmethod
def _remote_call(method, rref: rpc.RRef, *args, **kwargs):
args = (method, rref) + args
return rpc.rpc_async(rref.owner(), TensorBuffer._method_call, args=args, kwargs=kwargs)
@staticmethod
def _method_call(method, rref: rpc.RRef, *args, **kwargs):
self: TensorBuffer = rref.local_value()
index = kwargs["index"]
kwargs["index"] = self.route.to_local_ids(index)
return method(self, *args, **kwargs)
\ No newline at end of file
import torch
from contextlib import contextmanager
from torch import Tensor
from typing import *
class RouteContext:
def __init__(self) -> None:
self._futs: List[torch.futures.Future] = []
def synchronize(self):
for fut in self._futs:
fut.wait()
self._futs = []
def add_futures(self, *futs):
for fut in futs:
assert isinstance(fut, torch.futures.Future)
self._futs.append(fut)
def __enter__(self):
return self
    def __exit__(self, exc_type, exc_value, traceback):
        if exc_type is not None:
            return False  # propagate the original exception unchanged
        self.synchronize()
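# Minimal illustrative sketch of how RouteContext is meant to be used. The
# future below is completed locally, so this runs without any RPC setup; in
# real code the futures would come from TensorBuffer.all_remote_set()/add().
if __name__ == "__main__":
    fut: torch.futures.Future = torch.futures.Future()
    fut.set_result(None)  # pretend an asynchronous remote write already finished
    with RouteContext() as ctx:
        ctx.add_futures(fut)  # register pending work
    # leaving the block calls synchronize(), which waits on every registered future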
import torch
import torch.nn as nn
import torch.distributed as dist
import torch.distributed.rpc as rpc
from torch import Tensor
from typing import *
from starrygl.parallel import all_gather_remote_objects
from .utils import init_local_edge_index
class Route(nn.Module):
def __init__(self,
src_ids: Tensor,
dst_size: int,
) -> None:
super().__init__()
self.register_buffer("_src_ids", src_ids, persistent=False)
self.dst_size = dst_size
self._init_nids_mapper()
self._init_part_mapper()
@staticmethod
def from_edge_index(dst_ids: Tensor, edge_index: Tensor):
src_ids, local_edge_index = init_local_edge_index(dst_ids, edge_index)
return Route(src_ids, dst_ids.size(0)), local_edge_index
@property
def src_ids(self) -> Tensor:
return self.get_buffer("_src_ids")
@property
def src_size(self) -> int:
return self.src_ids.size(0)
@property
def dst_ids(self) -> Tensor:
return self.src_ids[:self.dst_size]
@property
def ext_ids(self) -> Tensor:
return self.src_ids[self.dst_size:]
@property
def ext_size(self) -> int:
return self.src_size - self.dst_size
def parts_iter(self, local_ids: Tensor) -> Iterator[Tuple[Tensor, Tensor]]:
world_size = dist.get_world_size()
part_mapper = self.part_mapper[local_ids]
for i in range(world_size):
# part_ids = local_ids[part_mapper == i]
part_ids = torch.where(part_mapper == i)[0]
glob_ids = self.src_ids[part_ids]
yield part_ids, glob_ids
def to_local_ids(self, ids: Tensor) -> Tensor:
return self.nids_mapper[ids]
def _init_nids_mapper(self):
num_nodes: int = self.src_ids.max().item() + 1
device: torch.device = self.src_ids.device
mapper = torch.empty(num_nodes, dtype=torch.long, device=device).fill_((2**62-1)*2+1)
mapper[self.src_ids] = torch.arange(self.src_ids.size(0), dtype=torch.long, device=device)
self.register_buffer("nids_mapper", mapper, persistent=False)
def _init_part_mapper(self):
device: torch.device = self.src_ids.device
nids_mapper = self.get_buffer("nids_mapper")
mapper = torch.empty(self.src_size, dtype=torch.int32, device=device).fill_(-1)
for i, dst_ids in enumerate(all_gather_remote_objects(self.dst_ids)):
dst_ids: Tensor = dst_ids.to_here().to(device)
dst_ids = dst_ids[dst_ids < nids_mapper.size(0)]
dst_local_inds = nids_mapper[dst_ids]
dst_local_mask = dst_local_inds != ((2**62-1)*2+1)
dst_local_inds = dst_local_inds[dst_local_mask]
mapper[dst_local_inds] = i
assert (mapper >= 0).all()
self.register_buffer("part_mapper", mapper, persistent=False)
# class RouteTable(nn.Module):
# def __init__(self,
# src_ids: Tensor,
# dst_size: int,
# ) -> None:
# super().__init__()
# self.register_buffer("src_ids", src_ids)
# self.src_size: int = src_ids.size(0)
# self.dst_size = dst_size
# assert self.src_size >= self.dst_size
# self._init_mapper()
# rank, world_size = rank_world_size()
# rrefs = all_gather_remote_objects(self)
# gather_futures: List[torch.futures.Future] = []
# for i in range(world_size):
# rref = rrefs[i]
# fut = rpc.rpc_async(rref.owner(), RouteTable._get_dst_ids, args=(rref,))
# gather_futures.append(fut)
# max_src_ids: int = src_ids.max().item()
# smp = torch.empty(max_src_ids + 1, dtype=torch.long, device=src_ids.device).fill_((2**62-1)*2+1)
# smp[src_ids] = torch.arange(src_ids.size(0), dtype=smp.dtype, device=smp.device)
# self.fw_masker = RouteMasker(self.dst_size, world_size)
# self.bw_masker = RouteMasker(self.src_size, world_size)
# dist.barrier()
# scatter_futures: List[torch.futures.Future] = []
# for i in range(world_size):
# fut = gather_futures[i]
# s_ids: Tensor = src_ids
# d_ids: Tensor = fut.wait()
# num_ids: int = max(s_ids.max().item(), d_ids.max().item()) + 1
# imp = torch.zeros(num_ids, dtype=torch.long, device=self._get_device())
# imp[s_ids] += 1
# imp[d_ids] += 1
# ind = torch.where(imp > 1)[0]
# imp.fill_((2**62-1)*2+1)
# imp[d_ids] = torch.arange(d_ids.size(0), dtype=imp.dtype, device=imp.device)
# s_ind = smp[ind]
# d_ind = imp[ind]
# rref = rrefs[i]
# fut = rpc.rpc_async(rref.owner(), RouteTable._set_fw_mask, args=(rref, rank, d_ind))
# scatter_futures.append(fut)
# bw_mask = torch.zeros(self.src_size, dtype=torch.bool).index_fill_(0, s_ind, 1)
# self.bw_masker.set_mask(i, bw_mask)
# for fut in scatter_futures:
# fut.wait()
# dist.barrier()
# # def fw_index(self, dst: int, index: Tensor) -> Tensor:
# # mask = self.fw_masker.select(dst, index)
# # return self.get_global_index(index[mask])
# # def bw_index(self, dst: int, index: Tensor) -> Tensor:
# # mask = self.bw_masker.select(dst, index)
# # return self.get_global_index(index[mask])
# def fw_value_index(self, dst: int, value: Tensor, index: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]:
# if index is None:
# assert value.size(0) == self.dst_size
# mask = self.fw_masker.select(dst)
# return value[mask], self.get_buffer("src_ids")[:self.dst_size][mask]
# else:
# assert value.size(0) == index.size(0)
# mask = self.fw_masker.select(dst, index)
# value, index = value[mask], index[mask]
# return value, self.get_global_index(index)
# def bw_value_index(self, dst: int, value: Tensor, index: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]:
# if index is None:
# assert value.size(0) == self.src_size
# mask = self.bw_masker.select(dst)
# return value[mask], self.get_buffer("src_ids")[mask]
# else:
# assert value.size(0) == index.size(0)
# mask = self.bw_masker.select(dst, index)
# value, index = value[mask], index[mask]
# return value, self.get_global_index(index)
# def get_global_index(self, index: Tensor) -> Tensor:
# return self.get_buffer("src_ids")[index]
# def get_local_index(self, index: Tensor) -> Tensor:
# return self.get_buffer("mapper")[index]
# @staticmethod
# def _get_dst_ids(rref: rpc.RRef):
# self: RouteTable = rref.local_value()
# src_ids = self.get_buffer("src_ids")
# return src_ids[:self.dst_size]
# @staticmethod
# def _set_fw_mask(rref: rpc.RRef, dst: int, fw_ind: Tensor):
# self: RouteTable = rref.local_value()
# fw_mask = torch.zeros(self.dst_size, dtype=torch.bool).index_fill_(0, fw_ind, 1)
# self.fw_masker.set_mask(dst, fw_mask)
# def _get_device(self):
# return self.get_buffer("src_ids").device
# def _init_mapper(self):
# src_ids = self.get_buffer("src_ids")
# num_nodes: int = src_ids.max().item() + 1
# mapper = torch.empty(num_nodes, dtype=torch.long, device=src_ids.device).fill_((2**62-1)*2+1)
# mapper[src_ids] = torch.arange(src_ids.size(0), dtype=torch.long)
# self.register_buffer("mapper", mapper)
# class RouteMasker(nn.Module):
# def __init__(self,
# num_nodes: int,
# world_size: int,
# ) -> None:
# super().__init__()
# m = (world_size + 7) // 8
# self.num_nodes = num_nodes
# self.world_size = world_size
# self.register_buffer("data", torch.zeros(m, num_nodes, dtype=torch.uint8))
# def forward(self, i: int, index: Optional[Tensor] = None) -> Tensor:
# return self.select(i, index)
# def select(self, i: int, index: Optional[Tensor] = None) -> Tensor:
# i, data = self._idx_data(i)
# k, r = i // 8, i % 8
# if index is None:
# mask = data[k].bitwise_right_shift(r).bitwise_and_(1)
# else:
# mask = data[k][index].bitwise_right_shift_(r).bitwise_and_(1)
# return mask.type(dtype=torch.bool)
# def set_mask(self, i: int, mask: Tensor) -> Tensor:
# assert mask.size(0) == self.num_nodes
# i, data = self._idx_data(i)
# k, r = i // 8, i % 8
# data[k] &= ~(1<<r)
# data[k] |= mask.type(torch.uint8).bitwise_left_shift_(r)
# def _idx_data(self, i: int) -> Tuple[int, Tensor]:
# assert -self.world_size < i and i < self.world_size
# i = (i + self.world_size) % self.world_size
# return i, self.get_buffer("data")
\ No newline at end of file
import torch
from torch import Tensor
from typing import *
from starrygl.parallel import get_compute_device
from starrygl.core.route import Route
from starrygl.utils.partition import metis_partition
def init_local_edge_index(
dst_ids: Tensor,
edge_index: Tensor,
bipartite: bool = False,
) -> Tuple[Tensor, Tensor]:
max_ids = calc_max_ids(dst_ids, edge_index)
ikw = dict(dtype=torch.long, device=dst_ids.device)
xmp = torch.zeros(max_ids + 1, **ikw)
    # check that this is a vertex-cut partition and that every edge is assigned
    # to the partition owning its destination vertex
xmp[edge_index[1].unique()] += 0b01
xmp[dst_ids.unique()] += 0b10
if not (xmp != 0x01).all():
        raise RuntimeError("must be a vertex-cut partitioned graph")
if bipartite:
src_ids = edge_index[0].unique()
else:
        # assume a homogeneous graph:
        # src_ids equals [dst_ids, edge_index[0] except dst_ids]
xmp.fill_(0)
xmp[edge_index[0]] = 1
xmp[dst_ids] = 0
src_ids = torch.cat([dst_ids, torch.where(xmp > 0)[0]], dim=-1)
    # compute the local (per-partition) indices
xmp.fill_((2**62-1)*2+1)
xmp[src_ids] = torch.arange(src_ids.size(0), **ikw)
src = xmp[edge_index[0]]
xmp.fill_((2**62-1)*2+1)
xmp[dst_ids] = torch.arange(dst_ids.size(0), **ikw)
dst = xmp[edge_index[1]]
local_edge_index = torch.vstack([src, dst])
return src_ids, local_edge_index
def calc_max_ids(*ids: Tensor) -> int:
x = [t.max().item() if t.numel() > 0 else 0 for t in ids]
    return max(x)  # max(*x) would fail when only a single tensor is passed
def local_partition_fn(dst_size: Tensor, edge_index: Tensor, num_parts: int) -> Tensor:
edge_index = edge_index[:, edge_index[0] < dst_size]
return metis_partition(edge_index, dst_size, num_parts)[0]
\ No newline at end of file
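# Minimal illustrative walk-through of init_local_edge_index on a toy partition
# that owns destination nodes {2, 3, 4}. Every edge destination falls inside
# that set, so src_ids becomes [dst_ids, remote source ids] and the edge index
# is remapped to local offsets.
if __name__ == "__main__":
    dst_ids = torch.tensor([2, 3, 4])
    edge_index = torch.tensor([[0, 3, 1],    # global source ids
                               [2, 2, 3]])   # global destination ids
    src_ids, local_edge_index = init_local_edge_index(dst_ids, edge_index)
    print(src_ids.tolist())            # [2, 3, 4, 0, 1]
    print(local_edge_index.tolist())   # [[3, 1, 4], [0, 0, 1]]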
import torch
import torch.distributed as dist
from torch import Tensor, LongTensor
from typing import *
def _local_TP_FP_FN(pred: LongTensor, targ: LongTensor, num_classes: int) -> Tensor:
TP, FP, FN = 0, 1, 2
tmp = torch.empty(3, num_classes, dtype=torch.float32, device=pred.device)
for c in range(num_classes):
pred_c = (pred == c)
targ_c = (targ == c)
        # use elementwise &/~ here: Python `and`/`not` raise on multi-element tensors
        tmp[TP, c] = torch.count_nonzero(pred_c & targ_c)
        tmp[FP, c] = torch.count_nonzero(pred_c & ~targ_c)
        tmp[FN, c] = torch.count_nonzero(~pred_c & targ_c)
return tmp
def micro_f1(pred: LongTensor, targ: LongTensor, num_classes: int) -> float:
tmp = _local_TP_FP_FN(pred, targ, num_classes).sum(dim=-1)
dist.all_reduce(tmp)
TP, FP, FN = tmp.tolist()
precision = TP / (TP + FP)
recall = TP / (TP + FN)
return 2 * precision * recall / (precision + recall)
def macro_f1(pred: LongTensor, targ: LongTensor, num_classes: int) -> float:
tmp = _local_TP_FP_FN(pred, targ, num_classes)
dist.all_reduce(tmp)
TP, FP, FN = tmp
precision = TP / (TP + FP)
recall = TP / (TP + FN)
f1 = 2 * precision * recall / (precision + recall)
return f1.mean().item()
def accuracy(pred: LongTensor, targ: LongTensor) -> float:
tmp = torch.empty(2, dtype=torch.float32, device=pred.device)
tmp[0] = pred.eq(targ).count_nonzero()
tmp[1] = pred.size(0)
dist.all_reduce(tmp)
a, b = tmp.tolist()
return a / b
def all_reduce_loss(loss: Tensor, batch_size: int) -> float:
tmp = torch.tensor([
loss.item() * batch_size,
batch_size
], dtype=torch.float32, device=loss.device)
dist.all_reduce(tmp)
cum_loss, n = tmp.tolist()
return cum_loss / n
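# Minimal illustrative sketch: the metric helpers above all call
# dist.all_reduce, so they need an initialized process group. A single-process
# gloo group is enough to try them locally; 29500 is an arbitrary free port.
if __name__ == "__main__":
    dist.init_process_group(
        backend="gloo", init_method="tcp://127.0.0.1:29500",
        rank=0, world_size=1,
    )
    pred = torch.tensor([0, 1, 1, 0])
    targ = torch.tensor([0, 1, 0, 0])
    print(accuracy(pred, targ))                  # 0.75
    print(micro_f1(pred, targ, num_classes=2))   # 0.75
    dist.destroy_process_group()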
import torch
import torch.nn as nn
import torch.distributed as dist
import torch.distributed.rpc as rpc
import os
from typing import *
# from .degree import compute_in_degree, compute_out_degree, compute_gcn_norm
from .sync_bn import SyncBatchNorm
def convert_parallel_model(
net: nn.Module,
find_unused_parameters=False,
) -> nn.parallel.DistributedDataParallel:
net = SyncBatchNorm.convert_sync_batchnorm(net)
net = nn.parallel.DistributedDataParallel(net,
find_unused_parameters=find_unused_parameters,
)
return net
def init_process_group(backend: str = "gloo") -> torch.device:
rank = int(os.getenv("RANK") or os.getenv("OMPI_COMM_WORLD_RANK"))
world_size = int(os.getenv("WORLD_SIZE") or os.getenv("OMPI_COMM_WORLD_SIZE"))
dist.init_process_group(
backend=backend,
init_method=ccl_init_method(),
rank=rank, world_size=world_size,
)
rpc_backend_options = rpc.TensorPipeRpcBackendOptions()
rpc_backend_options.init_method = rpc_init_method()
for i in range(world_size):
rpc_backend_options.set_device_map(f"worker{i}", {rank: i})
rpc.init_rpc(
name=f"worker{rank}",
rank=rank, world_size=world_size,
rpc_backend_options=rpc_backend_options,
)
local_rank = os.getenv("LOCAL_RANK") or os.getenv("OMPI_COMM_WORLD_LOCAL_RANK")
if local_rank is not None:
local_rank = int(local_rank)
if backend == "nccl" or backend == "mpi":
        # note: `local_rank or rank` would silently ignore local_rank == 0
        device = torch.device(f"cuda:{rank if local_rank is None else local_rank}")
torch.cuda.set_device(device)
else:
device = torch.device("cpu")
global _COMPUTE_DEVICE
_COMPUTE_DEVICE = device
return device
def rank_world_size() -> Tuple[int, int]:
return dist.get_rank(), dist.get_world_size()
def get_worker_info(rank: Optional[int] = None) -> rpc.WorkerInfo:
rank = dist.get_rank() if rank is None else rank
return rpc.get_worker_info(f"worker{rank}")
_COMPUTE_DEVICE = torch.device("cpu")
def get_compute_device() -> torch.device:
global _COMPUTE_DEVICE
return _COMPUTE_DEVICE
_TEMP_AG_REMOTE_OBJECT = None
def _remote_object():
global _TEMP_AG_REMOTE_OBJECT
return _TEMP_AG_REMOTE_OBJECT
def all_gather_remote_objects(obj: Any) -> List[rpc.RRef]:
global _TEMP_AG_REMOTE_OBJECT
_TEMP_AG_REMOTE_OBJECT = rpc.RRef(obj)
dist.barrier()
world_size = dist.get_world_size()
futs: List[torch.futures.Future] = []
for i in range(world_size):
info = get_worker_info(i)
futs.append(rpc.rpc_async(info, _remote_object))
rrefs: List[rpc.RRef] = []
for f in futs:
f.wait()
rrefs.append(f.value())
dist.barrier()
_TEMP_AG_REMOTE_OBJECT = None
return rrefs
def ccl_init_method() -> str:
master_addr = os.environ["MASTER_ADDR"]
master_port = int(os.environ["MASTER_PORT"])
return f"tcp://{master_addr}:{master_port}"
def rpc_init_method() -> str:
master_addr = os.environ["MASTER_ADDR"]
master_port = int(os.environ["MASTER_PORT"])
return f"tcp://{master_addr}:{master_port+1}"
\ No newline at end of file
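# Minimal illustrative launch sketch: init_process_group() reads RANK/WORLD_SIZE
# (or their OpenMPI equivalents) plus MASTER_ADDR/MASTER_PORT, and it always
# registers an RPC device map, so this single-process smoke test assumes at
# least one visible CUDA device. Under torchrun these variables are set
# automatically; here they are filled in by hand.
if __name__ == "__main__":
    os.environ.setdefault("RANK", "0")
    os.environ.setdefault("WORLD_SIZE", "1")
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    device = init_process_group(backend="gloo")
    print(rank_world_size(), device)
    rpc.shutdown()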
import torch.distributed as dist
def sync_print(*args, **kwargs):
rank = dist.get_rank()
world_size = dist.get_world_size()
for i in range(world_size):
if i == rank:
print(f"rank {rank}:", *args, **kwargs)
dist.barrier()
def main_print(*args, **kwargs):
rank = dist.get_rank()
if rank == 0:
print(*args, **kwargs)
# dist.barrier()
\ No newline at end of file
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from typing import *
import os
import time
import psutil
from starrygl.core import NodeProbe
from starrygl.nn.convs.s_gat_conv import GATConv
from starrygl.graph import DistGraph
from starrygl.parallel import init_process_group, convert_parallel_model
from starrygl.parallel import compute_gcn_norm, SyncBatchNorm, with_nccl
from starrygl.utils import train_epoch, eval_epoch, partition_load, main_print, sync_print
class Net(nn.Module):
def __init__(self,
in_channels: int,
hidden_channels: int,
out_channels: int,
num_layers: int,
heads: int = 8,
) -> None:
super().__init__()
self.in_channels = in_channels
self.hidden_channels = hidden_channels
self.out_channels = out_channels
self.num_layers = num_layers
last_ch = in_channels
self.convs = nn.ModuleList()
for i in range(num_layers):
out_ch = out_channels if i == num_layers - 1 else hidden_channels
self.convs.append(GATConv(last_ch, out_ch, heads))
last_ch = out_ch
def forward(self, g: DistGraph, probe: NodeProbe):
for i in range(self.num_layers):
shrink = g.cache.layers[i].get_shrink_data()
if i == 0:
x = g.cache.fetch_pull_tensor(i)
else:
# x = torch.ones(dst_idx.size(0), self.hidden_channels, device=x.device)
x, _ = g.cache.update_cache(i, x, dst_idx)
dst_idx = shrink.dst_idx
# print(shrink.edge_index[1].unique().size(0), shrink.dst_size)
# print(x.size(0), shrink.edge_index[0].unique().size(0), shrink.src_size)
# e = x[shrink.edge_index[0]]
# from torch_scatter import scatter
from starrygl.utils.printer import main_print
main_print(x.requires_grad)
# print(e.size(0), shrink.edge_index[1].size(0), shrink.dst_size, shrink.edge_index[1].max().item())
# scatter(e, shrink.edge_index[1], dim=0, dim_size=shrink.dst_size)
# return
x = self.convs[i](shrink, x)
x = F.relu(x)
x = probe.apply(i, x, dst_idx)
return x
if __name__ == "__main__":
    # start the distributed process group and assign the compute device
device = init_process_group(backend="nccl")
    # load the dataset
pdata = partition_load("./cora", algo="metis").to(device)
g = DistGraph(
ids=pdata.ids,
edge_index=pdata.edge_index,
num_features=64,
num_layers=3,
cache_device="cpu"
).to(device)
cached_data_0, _ = g.route.forward_a2a(pdata.x).get()
g.cache.replace_layer_data(0, cached_data_0.to(g.cache.cache_device))
probe = NodeProbe(
g.dst_size,
num_layers=3,
num_samples=128,
).to(device)
probe.assign_message_cache(g.cache)
probe.layers[-1].warmup_sample()
net = Net(pdata.num_features, 64, pdata.num_classes, num_layers=3).to(device)
net(g, probe).sum().backward()
\ No newline at end of file
import torch
import torch.distributed as dist
from starrygl.parallel import init_process_group
from typing import *
if __name__ == "__main__":
device = init_process_group(backend="gloo")
rank = dist.get_rank()
world_size = dist.get_world_size()
assert world_size == 4
xs = [torch.randn(100+i*10, device=device) for i in range(world_size)]
send_sizes = torch.tensor([x.size(0) for x in xs], dtype=torch.long, device=device)
recv_sizes = torch.zeros_like(send_sizes)
dist.all_to_all_single(recv_sizes, send_sizes)
\ No newline at end of file
# import torch
# from starrygl.core.acopy import get_executor, AsyncCopyWorkBase, List
# if __name__ == "__main__":
# data = torch.arange(1024*1024*1024, dtype=torch.float32).view(-1, 1).repeat(1, 8)
# # buffer = torch.arange(1000, dtype=torch.float32).view(-1, 1).repeat(1, 8).pin_memory()
# index = torch.arange(1024*1024 * 256, dtype=torch.long).cuda()
# values = torch.ones(1024*1024 * 256, 8, dtype=torch.float32).cuda()
# pull_works: List[AsyncCopyWorkBase] = []
# push_works: List[AsyncCopyWorkBase] = []
# for s in range(0, index.size(0), 1024 * 1024 * 256):
# t = min(s + 1024 * 1024 * 256, index.size(0))
# idx = index[s:t]
# val = values[idx]
# pullw = get_executor().async_pull(data, idx)
# pushw = get_executor().async_push(data, idx, val)
# pull_works.append(pullw)
# push_works.append(pushw)
# pull_time = 0.0
# for w in pull_works:
# val = w.get()
# pull_time += w.time_used()
# # print(val)
# push_time = 0.0
# for w in push_works:
# w.get()
# push_time += w.time_used()
# # print(data)
# print(pull_time, push_time)
\ No newline at end of file
import torch
import torch.autograd as autograd
from torch import Tensor
from typing import *
class AFunction(autograd.Function):
@staticmethod
def forward(
ctx: autograd.function.FunctionCtx,
a: Tensor,
b: Tensor,
):
return a, b.float()
@staticmethod
def backward(
ctx: autograd.function.FunctionCtx,
a: Tensor,
b: Tensor,
):
print(a, b)
return a, b.bool()
if __name__ == "__main__":
x = torch.rand(10).requires_grad_()
y = torch.ones(10).bool()
a, b = AFunction.apply(x, y)
(a + b).sum().backward()
print(a, b)
print(x.grad, y.grad)
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
import torch.distributed.rpc as rpc
from torch import Tensor
from typing import *
import os
import time
import psutil
from starrygl.loader import NodeLoader, TensorBuffer, RouteContext
from starrygl.parallel import init_process_group
from starrygl.utils import partition_load, main_print, sync_print
if __name__ == "__main__":
    # start the distributed process group and assign the compute device
device = init_process_group(backend="nccl")
    # load the dataset
pdata = partition_load("./cora", algo="metis")
loader = NodeLoader(pdata.ids, pdata.edge_index, device)
hidden_size = 64
buffers: List[TensorBuffer] = [
TensorBuffer(pdata.num_features, loader.src_size, loader.route),
TensorBuffer(hidden_size, loader.src_size, loader.route),
]
buffers[0].data[:loader.dst_size] = pdata.x
buffers[0].broadcast()
for handle in loader.iter(128):
with RouteContext() as ctx:
sync_print(handle.src_size, handle.dst_size)
dst_feats = handle.get_dst_feats(buffers[0]).wait()
ext_fut = handle.get_ext_feats(buffers[1])
sync_print(dst_feats.size())
src_feats, fut = handle.push_and_pull(dst_feats[:,:hidden_size], ext_fut, buffers[1])
ctx.add_futures(fut)
sync_print(src_feats.size())
rpc.shutdown()
\ No newline at end of file
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
import torch.distributed.rpc as rpc
from torch import Tensor
from typing import *
import os
import time
import psutil
# from starrygl.loader import DataLoader, TensorBuffers
from starrygl.loader.route import Route
from starrygl.loader.buffer import TensorBuffer
from starrygl.loader.utils import init_local_edge_index, local_partition_fn
from starrygl.parallel import init_process_group
from starrygl.utils import partition_load, main_print, sync_print
all_nparts = [
[0, 1],
[2, 3, 4],
[5, 6],
]
all_eparts = [
[
[2, 0], [1, 0], [0, 1], [3, 1], [4, 1],
],
[
[0, 2], [3, 2], [5, 2], [1, 3], [2, 3],
[4, 3], [6, 3], [1, 4], [3, 4],
],
[
[2, 5], [6, 5], [5, 6], [3, 6],
],
]
# def get_route_table(device = "cpu") -> Tuple[RouteTable, Tensor]:
# assert dist.get_world_size() == 3
# rank = dist.get_rank()
# ids = torch.tensor(all_nparts[rank], dtype=torch.long, device=device)
# edge_index = torch.tensor(all_eparts[rank], dtype=torch.long, device=device).t()
# src_ids, local_edge_index = init_local_edge_index(ids, edge_index)
# return RouteTable(src_ids, ids.size(0)), local_edge_index
# def get_features(device = "cpu") -> Tensor:
# assert dist.get_world_size() == 3
# rank = dist.get_rank()
# return torch.tensor(all_nparts[rank], dtype=torch.float32, device=device) + 1.0
if __name__ == "__main__":
    # start the distributed process group and assign the compute device
device = init_process_group(backend="nccl")
    # load the dataset
pdata = partition_load("./cora", algo="metis").to(device)
route, local_edge_index = Route.from_edge_index(pdata.ids, pdata.edge_index)
# route_table = RouteTable(src_ids, dst_size)
# sync_print(route_table.fw_0)
# route_table = get_route_table()[0]
# for name, buffer in route_table.named_buffers():
# if name.startswith("fw_") or name.startswith("bw_"):
# sync_print(name, buffer.tolist())
tb = TensorBuffer(pdata.num_features, route.src_size, route).cpu()
print(tb.data.device, pdata.x.device)
tb.data[:pdata.num_nodes] = pdata.x
dist.barrier()
tb.broadcast()
print(tb.data[pdata.num_nodes:].sum(dim=-1))
# assert (tb.data[:pdata.num_nodes] == pdata.x).all()
# local_src_ids = torch.arange(route.src_size)
# for local_ids, remote_ids in route.parts_iter(local_src_ids):
# sync_print(local_ids.size(), remote_ids.size())
# router, edge_index = get_route_table()
# x = get_features()[:,None]
# buffer = TensorBuffers(router, channels=[1])
# buffer.local_set(0, x)
# sync_print(buffer.local_get(0).view(-1).tolist())
# buffer.broadcast()
# main_print("="*64)
# sync_print(buffer.local_get(0).view(-1).tolist())
# main_print("+"*64)
# buffer.zero_grad()
# buffer.local_set(0, x)
# sync_print(buffer.local_get(0).view(-1).tolist())
# buffer.broadcast2()
# main_print("="*64)
# sync_print(buffer.local_get(0).view(-1).tolist())
rpc.shutdown()
# buffer.remote_get(0, 0, async_op=True).then()
# num_parts = 50
# node_parts = local_partition_fn(dst_size, local_edge_index, num_parts)
# feat_buffer = TensorBuffers(src_ids.size(0), [pdata.num_features] * 4)
# grad_buffer = TensorBuffers(src_ids.size(0), [pdata.num_features] * 4)
# node_time = torch.rand(dst_size)
# edge_time = torch.rand(local_edge_index.size(1))
# edge_attr = torch.randn(local_edge_index.size(1), 10)
# loader = DataLoader(node_parts, feat_buffer, grad_buffer, local_edge_index, edge_attr, edge_time, node_time)
# feat_buffer.set(0, None, collect_feat0(src_ids, src_ids[:dst_size], pdata.x))
# for handle, edge_index, edge_attr in loader.iter(2, layer_id=0, device=device, filter=lambda x: x < 0.5):
# print(handle.feat0.sum(dim=-1))
# print(edge_index)
# print(edge_attr.sum(dim=-1))
# handle.push_and_pull(handle.feat0[:handle.dst_size], 0)
# assert (handle.fetch_feat() == handle.feat0).all()
\ No newline at end of file
import torch
import torch.distributed as dist
from torch import Tensor
from typing import *
from starrygl.parallel import init_process_group
from starrygl.utils.printer import main_print, sync_print
if __name__ == "__main__":
import os
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
device = init_process_group(backend="mpi")
sync_print(dist.get_backend())
\ No newline at end of file
import torch
import torch.distributed as dist
from torch import Tensor
from typing import *
from starrygl.parallel import init_process_group
from starrygl.utils.printer import main_print, sync_print
from multiprocessing.pool import ThreadPool
num_threads = 1
num_tasks = 5
pool = ThreadPool(num_threads)
if __name__ == "__main__":
import os
# os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
device = init_process_group(backend="nccl")
group = dist.new_group()
stream = torch.cuda.Stream(device)
handles = []
for i in range(num_tasks):
x = torch.ones(10, dtype=torch.float32, device=device)
def run(x):
# stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(stream):
group.allreduce(x)
return x
stream.wait_stream(torch.cuda.current_stream())
handles.append(pool.apply_async(run, (x,)))
x = torch.ones(3, dtype=torch.float32, device=device)
dist.all_reduce(x)
sync_print(x.tolist())
for i in range(num_tasks):
main_print("="*64)
x = handles[i].get()
assert (handles[i].get() == x).all()
sync_print(x.tolist())
import torch
import torch.distributed as dist
from torch import Tensor
from typing import *
from starrygl.graph import DistGraph
from starrygl.core.gather import Gather
from starrygl.core.straight import Straight
from starrygl.parallel import init_process_group, compute_degree
from starrygl.utils.printer import main_print, sync_print
from torch_scatter import scatter_sum
all_nparts = [
[0, 1],
[2, 3, 4],
[5, 6],
]
all_eparts = [
[
[2, 0], [1, 0], [0, 1], [3, 1], [4, 1],
],
[
[0, 2], [3, 2], [5, 2], [1, 3], [2, 3],
[4, 3], [6, 3], [1, 4], [3, 4],
],
[
[2, 5], [6, 5], [5, 6], [3, 6],
],
]
def get_distgraph(device = "cpu") -> DistGraph:
assert dist.get_world_size() == 3
rank = dist.get_rank()
ids = torch.tensor(all_nparts[rank], dtype=torch.long, device=device)
edge_index = torch.tensor(all_eparts[rank], dtype=torch.long, device=device).t()
return DistGraph(ids, edge_index)
def get_features(device = "cpu") -> Tensor:
assert dist.get_world_size() == 3
rank = dist.get_rank()
return torch.tensor(all_nparts[rank], dtype=torch.float32, device=device) + 1.0
def reduce(g: DistGraph, x: Tensor) -> Tensor:
edge_index = g.edge_index
x = x[edge_index[0]]
x = scatter_sum(x, edge_index[1], dim=0, dim_size=g.dst_size)
return x
if __name__ == "__main__":
import os
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
device = init_process_group(backend="nccl")
g = get_distgraph(device)
# sync_print(g.src_ids.tolist())
# main_print("="*64)
# sync_print(g.dst_ids.tolist())
# main_print("="*64)
# sync_print(g.route.forward_routes[0])
# main_print("="*64)
in_degree, out_degree = compute_degree(g)
# sync_print(in_degree.tolist())
# sync_print(out_degree.tolist())
x = get_features(device).requires_grad_()
# straight = Straight(g.dst_size, num_samples=1)
# gather = Gather(g.src_size).to(device)
# sync_print(g)
# sync_print(gather.get_buffer("last_embd").tolist())
# main_print("="*64)
sync_print(x.tolist(), g.route.src_size)
main_print("="*64)
a = torch.arange(g.route.src_size, dtype=torch.long, device=device)
# a = None
# z, a = gather(x, None, route=g.route)
z, a = g.route.apply(x, a, {}, async_op=False)
sync_print(z.tolist(), "" if a is None else a.tolist(), g.route.dst_size)
main_print("="*64)
sync_print(g.edge_index[0])
main_print("="*64)
z = reduce(g, z)
sync_print(z.tolist())
main_print("="*64)
# z = straight(x, a, g)
# sync_print(z.tolist())
# main_print("="*64)
loss = z.sum()
sync_print(loss.item())
main_print("="*64)
loss.backward()
sync_print(x.grad.tolist())
main_print("="*64)
sync_print(f"time_used: {g.route.total_time_used}ms")
\ No newline at end of file
import torch
import torch.distributed as dist
from torch import Tensor
from typing import *
from starrygl.core.route import Route
from starrygl.graph.utils import init_local_edge_index
from starrygl.parallel import init_process_group
from starrygl.utils.printer import main_print, sync_print
from torch_scatter import scatter_sum
all_nparts = [
[0, 1],
[2, 3, 4],
[5, 6],
]
all_eparts = [
[
[2, 0], [1, 0], [0, 1], [3, 1], [4, 1],
],
[
[0, 2], [3, 2], [5, 2], [1, 3], [2, 3],
[4, 3], [6, 3], [1, 4], [3, 4],
],
[
[2, 5], [6, 5], [5, 6], [3, 6],
],
]
def get_route(device = "cpu") -> Tuple[Route, Tensor]:
assert dist.get_world_size() == 3
rank = dist.get_rank()
ids = torch.tensor(all_nparts[rank], dtype=torch.long, device=device)
edge_index = torch.tensor(all_eparts[rank], dtype=torch.long, device=device).t()
src_ids, local_edge_index = init_local_edge_index(ids, edge_index)
return Route(ids, src_ids), local_edge_index
def get_features(device = "cpu") -> Tensor:
assert dist.get_world_size() == 3
rank = dist.get_rank()
return torch.tensor(all_nparts[rank], dtype=torch.float32, device=device) + 1.0
def reduce(x: Tensor, edge_index: Tensor, dim_size: int) -> Tensor:
x = x[edge_index[0]]
x = scatter_sum(x, edge_index[1], dim=0, dim_size=dim_size)
return x
if __name__ == "__main__":
device = init_process_group(backend="mpi")
route, edge_index = get_route(device)
src_size = route.dst_size
dst_size = route.src_size
# in_degree, out_degree = compute_degree(g)
# sync_print(in_degree.tolist())
# sync_print(out_degree.tolist())
x = get_features(device).requires_grad_()
sync_print(x.tolist())
main_print("="*64)
z, _ = route.apply(x, None)
sync_print(z.tolist())
main_print("="*64)
h = reduce(z, edge_index, dst_size)
sync_print(h.tolist())
main_print("="*64)
h.sum().backward()
sync_print(x.grad.tolist())
main_print("="*64)
import torch
import torch.distributed as dist
import torch.distributed.rpc as rpc
import os
from torch import Tensor
from typing import *
from starrygl.parallel import init_process_group, all_gather_remote_objects
from starrygl.utils.printer import main_print, sync_print
class A:
def __init__(self, device) -> None:
self.id = rpc.get_worker_info().id
self.data = torch.tensor([self.id], device=device)
@staticmethod
def get_data(rref: rpc.RRef) -> Tensor:
self: A = rref.local_value()
return self.data
def set_remote_a(self, *rrefs: rpc.RRef):
self.rrefs = tuple(rrefs)
def display(self):
sync_print(f"id = {self.id}")
def gather_data(self, device) -> List[Tensor]:
world_size = dist.get_world_size()
futs: List[torch.futures.Future] = []
for i in range(world_size):
f = rpc.rpc_async(self.rrefs[i].owner(), A.get_data, args=(self.rrefs[i],))
futs.append(f)
results: List[Tensor] = []
for f in futs:
f.wait()
results.append(f.value())
return results
if __name__ == "__main__":
device = init_process_group("nccl")
a = A(device)
a.set_remote_a(*all_gather_remote_objects(a))
a.display()
for t in a.gather_data(device):
sync_print(t, t.device)
rpc.shutdown()
\ No newline at end of file
import torch
import torch.nn as nn
import torch.distributed as dist
from torch import Tensor
from typing import *
from starrygl.core.gather import Gather
from starrygl.parallel import init_process_group
from starrygl.parallel.sync_bn import SyncBatchNorm
from starrygl.graph import DistGraph, EID, SID, DID
from starrygl.utils.printer import main_print, sync_print
if __name__ == "__main__":
device = init_process_group(backend="nccl")
bn = SyncBatchNorm(100).to(device)
bn.train()
x = torch.randn(10, 100, dtype=torch.float64).to(device).requires_grad_()
sync_print(x)
from torch.autograd import gradcheck
sync_print(gradcheck(lambda x: bn(x), x))
# sync_print(bn(x))
-r requirements.txt
# used by docs
sphinx-autobuild
sphinx_rtd_theme
matplotlib
ipykernel
\ No newline at end of file
--extra-index-url https://download.pytorch.org/whl/cu118
torch==2.1.1+cu118
torchvision==0.16.1+cu118
torchaudio==2.1.1+cu118
--extra-index-url https://data.pyg.org/whl/torch-2.1.0+cu118.html
torch_geometric==2.4.0
pyg_lib==0.3.1+pt21cu118
torch_scatter==2.1.2+pt21cu118
torch_sparse==0.6.18+pt21cu118
torch_cluster==1.6.3+pt21cu118
torch_spline_conv==1.2.2+pt21cu118
ogb
tqdm
\ No newline at end of file
......@@ -2,14 +2,17 @@ import torch
import logging
__version__ = "0.1.0"
try:
from .lib import libstarrygl_ops as ops
from .lib import libstarrygl as ops
except Exception as e:
logging.error(e)
logging.error("unable to import libstarrygl.so, some features may not be available.")
try:
from .lib import libstarrygl_ops_sampler as sampler_ops
from .lib import libstarrygl_sampler as sampler_ops
except Exception as e:
logging.error(e)
logging.error("unable to import libstarrygl_sampler.so, some features may not be available.")
\ No newline at end of file
import torch
from torch import Tensor
from contextlib import contextmanager
from typing import *
__all__ = [
"ABCStream",
"ABCEvent",
"phony_tensor",
"new_stream",
"current_stream",
"default_stream",
"use_stream",
"use_device",
"wait_stream",
"wait_event",
"record_stream",
]
class CPUStreamType:
def __init__(self) -> None:
self._device = torch.device("cpu")
@property
def device(self):
return self._device
def __call__(self):
return self
class CPUEventType:
def __init__(self) -> None:
self._device = torch.device("cpu")
@property
def device(self):
return self._device
def __call__(self):
return self
CPUStream = CPUStreamType()
ABCStream = Union[torch.cuda.Stream, CPUStreamType]
CPUEvent = CPUEventType()
ABCEvent = Union[torch.cuda.Event, CPUEventType]
def new_stream(device: Any) -> ABCStream:
device = torch.device(device)
if device.type != "cuda":
return CPUStream()
return torch.cuda.Stream(device)
_phonies: Dict[Tuple[torch.device, bool], Tensor] = {}
def phony_tensor(device: Any, requires_grad: bool = True):
device = torch.device(device)
key = (device, requires_grad)
if key not in _phonies:
with use_stream(default_stream(device)):
_phonies[key] = torch.empty(
0, device=device,
requires_grad=requires_grad,
)
return _phonies[key]
def current_stream(device: Any) -> ABCStream:
device = torch.device(device)
if device.type != "cuda":
return CPUStream()
return torch.cuda.current_stream(device)
def default_stream(device: Any) -> ABCStream:
device = torch.device(device)
if device.type != "cuda":
return CPUStream()
return torch.cuda.default_stream(device)
@contextmanager
def use_stream(stream: ABCStream, fence_event: bool = False):
if isinstance(stream, CPUStreamType):
if fence_event:
event = CPUEvent()
yield event
else:
yield
return
with torch.cuda.stream(stream):
if fence_event:
event = torch.cuda.Event()
yield event
event.record()
else:
yield
@contextmanager
def use_device(device: Any):
device = torch.device(device)
if device.type != "cuda":
yield
return
with torch.cuda.device(device):
yield
def wait_stream(source: ABCStream, target: ABCStream):
if isinstance(target, CPUStreamType):
return
if isinstance(source, CPUStreamType):
target.synchronize()
else:
source.wait_stream(target)
def wait_event(source: ABCStream, target: ABCEvent):
if isinstance(target, CPUEventType):
return
if isinstance(source, CPUStreamType):
target.synchronize()
else:
source.wait_event(target)
def record_stream(tensor: Tensor, stream: ABCStream):
if isinstance(stream, CPUStreamType):
return
storage = tensor.untyped_storage()
tensor = tensor.new_empty(0).set_(storage)
tensor.record_stream(stream)
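# Minimal illustrative sketch: the stream helpers above degrade to no-ops on
# non-CUDA devices, so the same code path also runs on CPU-only machines.
if __name__ == "__main__":
    s = new_stream("cpu")                  # returns the CPUStreamType singleton
    with use_stream(s):                    # no CUDA stream is entered on CPU
        x = torch.randn(4)
    record_stream(x, s)                    # no-op for CPU streams
    wait_stream(current_stream("cpu"), s)  # also a no-op
    print(x)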
......@@ -20,6 +20,9 @@ __all__ = [
]
Strings = Sequence[str]
OptStrings = Optional[Strings]
class GraphData:
def __init__(self,
edge_indices: Union[Tensor, Dict[Tuple[str, str, str], Tensor]],
......@@ -168,34 +171,100 @@ class GraphData:
return g
@staticmethod
def load_partition(root: str, part_id: int, num_parts: int, algo: str = "metis") -> 'GraphData':
p = Path(root).expanduser().resolve() / f"{algo}_{num_parts}" / f"{part_id:03d}"
def load_partition(
root: str,
part_id: int,
num_parts: int,
algorithm: str = "metis",
) -> 'GraphData':
p = Path(root).expanduser().resolve() / f"{algorithm}_{num_parts}" / f"{part_id:03d}"
return torch.load(p.__str__())
def save_partition(self, root: str, num_parts: int, algo: str = "metis"):
def save_partition(self,
root: str,
num_parts: int,
node_weight: Optional[str] = None,
edge_weight: Optional[str] = None,
include_node_attrs: Optional[Sequence[str]] = None,
include_edge_attrs: Optional[Sequence[str]] = None,
include_meta_attrs: Optional[Sequence[str]] = None,
ignore_node_attrs: Optional[Sequence[str]] = None,
ignore_edge_attrs: Optional[Sequence[str]] = None,
ignore_meta_attrs: Optional[Sequence[str]] = None,
algorithm: str = "metis",
partition_kwargs = None,
):
        assert not self.is_heterogeneous, "only supports homogeneous graphs"
num_nodes: int = self.node().num_nodes
edge_index: Tensor = self.edge_index()
        logging.info(f"running partition algorithm: {algo}")
if algo == "metis":
node_parts = metis_partition(edge_index, num_nodes, num_parts)
elif algo == "mt-metis":
node_parts = mt_metis_partition(edge_index, num_nodes, num_parts)
elif algo == "random":
node_parts = random_partition(edge_index, num_nodes, num_parts)
        logging.info(f"running partition algorithm: {algorithm}")
partition_kwargs = partition_kwargs or {}
if node_weight is not None:
node_weight = self.node()[node_weight]
if edge_weight is not None:
edge_weight = self.edge()[edge_weight]
if algorithm == "metis":
node_parts = metis_partition(
edge_index,
num_nodes, num_parts,
node_weight=node_weight,
edge_weight=edge_weight,
**partition_kwargs,
)
elif algorithm == "mt-metis":
node_parts = mt_metis_partition(
edge_index,
num_nodes, num_parts,
node_weight=node_weight,
edge_weight=edge_weight,
**partition_kwargs,
)
elif algorithm == "random":
node_parts = random_partition(
edge_index,
num_nodes, num_parts,
**partition_kwargs,
)
else:
raise ValueError(f"unknown partition algorithm: {algo}")
raise ValueError(f"unknown partition algorithm: {algorithm}")
root_path = Path(root).expanduser().resolve()
base_path = root_path / f"{algo}_{num_parts}"
base_path = root_path / f"{algorithm}_{num_parts}"
if base_path.exists():
logging.warning(f"directory '{base_path.__str__()}' exists, and will be removed.")
shutil.rmtree(base_path.__str__())
base_path.mkdir(parents=True)
if include_node_attrs is None:
include_node_attrs = self.node().keys()
if include_edge_attrs is None:
include_edge_attrs = self.edge().keys()
if include_meta_attrs is None:
include_meta_attrs = self.meta().keys()
if ignore_node_attrs is None:
ignore_node_attrs = set()
else:
ignore_node_attrs = set(ignore_node_attrs)
if ignore_edge_attrs is None:
ignore_edge_attrs = set()
else:
ignore_edge_attrs = set(ignore_edge_attrs)
if ignore_meta_attrs is None:
ignore_meta_attrs = set()
else:
ignore_meta_attrs = set(ignore_meta_attrs)
for i in range(num_parts):
npart_mask = node_parts == i
epart_mask = npart_mask[edge_index[1]]
......@@ -213,13 +282,19 @@ class GraphData:
raw_dst_ids=raw_dst_ids,
)
for key in self.node().keys():
for key in include_node_attrs:
if key in ignore_node_attrs:
continue
g.node("dst")[key] = self.node()[key][npart_mask]
for key in self.edge().keys():
for key in include_edge_attrs:
if key in ignore_edge_attrs:
continue
g.edge()[key] = self.edge()[key][epart_mask]
for key in self.meta().keys():
for key in include_meta_attrs:
if key in ignore_meta_attrs:
continue
g.meta()[key] = self.meta()[key]
logging.info(f"saving partition data: {i+1}/{num_parts}")
......
......@@ -152,8 +152,11 @@ def batch_send(
if len(tensors) == 0:
return BatchWork(None, None)
if group is None:
group = dist.GroupMember.WORLD
# tensors = tuple(t.data for t in tensors)
backend = dist.get_backend(group)
dst = dist.get_global_rank(group, dst)
if async_op:
works = []
......@@ -177,8 +180,11 @@ def batch_recv(
if len(tensors) == 0:
return BatchWork(None, None)
if group is None:
group = dist.GroupMember.WORLD
# tensors = tuple(t.data for t in tensors)
backend = dist.get_backend(group)
src = dist.get_global_rank(group, src)
if async_op:
works = []
......
......@@ -7,6 +7,7 @@ import os
from torch import Tensor
from typing import *
import socket
from contextlib import contextmanager
import logging
......@@ -24,6 +25,7 @@ class DistributedContext:
@staticmethod
def init(
backend: str,
use_rpc: bool = False,
use_gpu: Optional[bool] = None,
rpc_gpu: Optional[bool] = None,
) -> 'DistributedContext':
......@@ -63,7 +65,9 @@ class DistributedContext:
rpc_init_method=rpc_init_url,
rank=rank, world_size=world_size,
local_rank=local_rank,
use_gpu=use_gpu, rpc_gpu=rpc_gpu,
use_rpc=use_rpc,
use_gpu=use_gpu,
rpc_gpu=rpc_gpu,
)
_set_default_dist_context(ctx)
......@@ -86,7 +90,9 @@ class DistributedContext:
rpc_init_method: str,
rank: int, world_size: int,
local_rank: int,
use_gpu: bool, rpc_gpu: bool,
use_rpc: bool,
use_gpu: bool,
rpc_gpu: bool,
) -> None:
if use_gpu:
device = torch.device(f"cuda:{local_rank}")
......@@ -112,18 +118,33 @@ class DistributedContext:
device_map={local_rank: dev},
)
if use_rpc:
rpc.init_rpc(
name=f"worker{rank}",
rank=rank, world_size=world_size,
rpc_backend_options=rpc_backend_options,
)
self._use_rpc = use_rpc
self._local_rank = local_rank
self._compute_device = device
self._hostname = socket.gethostname()
if self.device.type == "cuda":
torch.cuda.set_device(self.device)
rank_to_host = [None] * self.world_size
dist.all_gather_object(rank_to_host, (self.hostname, self.local_rank))
self._rank_to_host: Tuple[Tuple[str, int], ...] = tuple(rank_to_host)
host_index = [h for h, _ in self.rank_to_host]
host_index.sort()
self._host_index: Dict[str, int] = {h:i for i, h in enumerate(host_index)}
self.__temp_ag_remote_object: Optional[rpc.RRef] = None
def shutdown(self):
if self._use_rpc:
rpc.shutdown()
@property
......@@ -139,15 +160,86 @@ class DistributedContext:
return self._local_rank
@property
def hostname(self) -> str:
return self._hostname
@property
def rank_to_host(self):
return self._rank_to_host
@property
def host_index(self):
return self._host_index
@property
def device(self) -> torch.device:
return self._compute_device
def get_default_group(self):
return dist.distributed_c10d._get_default_group()
# return dist.distributed_c10d._get_default_group()
return dist.GroupMember.WORLD
def get_default_store(self):
return dist.distributed_c10d._get_default_store()
def get_ranks_by_host(self, hostname: Optional[str] = None) -> Tuple[int,...]:
if hostname is None:
hostname = self.hostname
ranks: List[int] = []
for i, (h, r) in enumerate(self.rank_to_host):
if h == hostname:
ranks.append(i)
ranks.sort()
return tuple(ranks)
def get_ranks_by_local(self, local_rank: Optional[int] = None) -> Tuple[int,...]:
if local_rank is None:
local_rank = self.local_rank
ranks: List[Tuple[int, str]] = []
for i, (h, r) in enumerate(self.rank_to_host):
if r == local_rank:
ranks.append((i, h))
ranks.sort(key=lambda x: self.host_index[x[1]])
return tuple(i for i, h in ranks)
def get_hybrid_matrix(self) -> Tensor:
hosts = sorted(self.host_index.items(), key=lambda x: x[1])
matrix = []
for h, _ in hosts:
rs = self.get_ranks_by_host(h)
matrix.append(rs)
return torch.tensor(matrix, dtype=torch.long, device="cpu")
def new_hybrid_subgroups(self,
matrix: Optional[Tensor] = None,
backend: Any = None,
) -> Tuple[Any, Any]:
if matrix is None:
matrix = self.get_hybrid_matrix()
assert matrix.dim() == 2
row_group = None
col_group = None
for row in matrix.tolist():
if self.rank in row:
row_group = dist.new_group(
row, backend=backend,
use_local_synchronization=True,
)
break
for col in matrix.t().tolist():
if self.rank in col:
col_group = dist.new_group(
col, backend=backend,
use_local_synchronization=True,
)
break
assert row_group is not None
assert col_group is not None
return row_group, col_group
def get_worker_info(self, rank: Optional[int] = None) -> rpc.WorkerInfo:
rank = dist.get_rank() if rank is None else rank
return rpc.get_worker_info(f"worker{rank}")
......
from .convs import GCNConv, GATConv, GINConv
# from .convs import ShrinkGCNConv, ShrinkGATConv, ShrinkGINConv
# from .convs import ShrinkHelper
# from .basic_gnn import BasicGNN, BasicLayerOptions, BasicInputOptions, BasicStraightOptions
# class ShrinkGCN(BasicGNN):
# def init_conv(self, in_channels: int, out_channels: int, **kwargs):
# return ShrinkGCNConv(in_channels, out_channels, **kwargs)
# class ShrinkGAT(BasicGNN):
# def init_conv(self, in_channels: int, out_channels: int, **kwargs):
# return ShrinkGATConv(in_channels, out_channels, **kwargs)
# class ShrinkGIN(BasicGNN):
# def init_conv(self, in_channels: int, out_channels: int, **kwargs):
# return ShrinkGINConv(in_channels, out_channels, **kwargs)
\ No newline at end of file
# import torch
# import torch.nn as nn
# import torch.distributed as dist
# from torch import Tensor
# from typing import *
# from starrygl.loader import BatchHandle
# class BaseLayer(nn.Module):
# def __init__(self) -> None:
# super().__init__()
# def forward(self, x: Tensor, edge_index: Tensor, edge_attr: Optional[Tensor] = None) -> Tensor:
# return x
# def update_forward(self, handle: BatchHandle, edge_index: Tensor, edge_attr: Optional[Tensor] = None):
# x = handle.fetch_feat()
# with torch.no_grad():
# x = self.forward(x, edge_index, edge_attr)
# handle.update_feat(x)
# def block_backward(self, handle: BatchHandle, edge_index: Tensor, edge_attr: Optional[Tensor] = None):
# x = handle.fetch_feat().requires_grad_()
# g = handle.fetch_grad()
# self.forward(x, edge_index, edge_attr).backward(g)
# handle.accumulate_grad(x.grad)
# x.grad = None
# def all_reduce_grad(self):
# for p in self.parameters():
# if p.grad is not None:
# dist.all_reduce(p.grad, op=dist.ReduceOp.SUM)
# class BaseModel(nn.Module):
# def __init__(self,
# num_features: int,
# layers: List[int],
# prev_layer: bool = False,
# post_layer: bool = False,
# ) -> None:
# super().__init__()
# def init_prev_layer(self) -> Tensor:
# pass
# def init_post_layer(self) -> Tensor:
# pass
# def init_conv_layer(self) -> Tensor:
# pass
\ No newline at end of file
from .gcn_conv import GCNConv
from .gat_conv import GATConv
from .gin_conv import GINConv
# from .shrink_gcn_conv import ShrinkGCNConv
# from .shrink_gat_conv import ShrinkGATConv
# from .shrink_gin_conv import ShrinkGINConv
# from .utils import ShrinkHelper
\ No newline at end of file
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.utils import softmax
from torch_scatter import scatter_sum
from torch import Tensor
from typing import *
from starrygl.graph import DistGraph
class GATConv(nn.Module):
def __init__(self,
in_channels: int,
out_channels: int,
heads: int = 1,
concat: bool = False,
negative_slope: float = 0.2,
dropout: float = 0.0,
edge_dim: Optional[int] = None,
bias: bool = True,
**kwargs
) -> None:
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.heads = heads
self.concat = concat
self.negative_slope = negative_slope
self.dropout = dropout
self.edge_dim = edge_dim
self.weight = nn.Parameter(torch.Tensor(in_channels, heads * out_channels))
self.att_src = nn.Parameter(torch.Tensor(1, heads, out_channels))
self.att_dst = nn.Parameter(torch.Tensor(1, heads, out_channels))
if edge_dim is not None:
self.lin_edge = nn.Parameter(torch.Tensor(edge_dim, heads * out_channels))
self.att_edge = nn.Parameter(torch.Tensor(1, heads, out_channels))
if bias and concat:
self.bias = nn.Parameter(torch.Tensor(heads * out_channels))
elif bias and not concat:
self.bias = nn.Parameter(torch.Tensor(out_channels))
else:
self.bias = None
self.reset_parameters()
def reset_parameters(self):
nn.init.xavier_normal_(self.weight)
nn.init.xavier_normal_(self.att_src)
nn.init.xavier_normal_(self.att_dst)
if self.edge_dim is not None:
nn.init.xavier_normal_(self.lin_edge)
nn.init.xavier_normal_(self.att_edge)
if self.bias is not None:
nn.init.zeros_(self.bias)
    def forward(self, g: DistGraph) -> Tensor:
H, C = self.heads, self.out_channels
x = g.ndata["x"]
edge_index = g.edge_index
x = (x @ self.weight).view(-1, H, C)
alpha_j = (x * self.att_src).sum(dim=-1)
alpha_j = alpha_j[edge_index[0]]
alpha_i = (x * self.att_dst).sum(dim=-1)
alpha_i = alpha_i[edge_index[1]]
if self.edge_dim is not None:
edge_attr = g.edata["edge_attr"]
if edge_attr.dim() == 1:
edge_attr = edge_attr.view(-1, 1)
e = (edge_attr @ self.lin_edge).view(-1, H, C)
alpha_e = (e * self.att_edge).sum(dim=-1)
alpha = alpha_i + alpha_j + alpha_e
else:
alpha = alpha_i + alpha_j
alpha = F.leaky_relu(alpha, self.negative_slope)
alpha = softmax(
src=alpha,
index=edge_index[1],
num_nodes=g.dst_size,
)
alpha = F.dropout(alpha, p=self.dropout, training=self.training)
x = x[edge_index[0]] * alpha.view(-1, H, 1)
x = scatter_sum(x, edge_index[1], dim=0, dim_size=g.dst_size)
if self.concat:
x = x.view(-1, H * C)
else:
x = x.mean(dim=1)
if self.bias is not None:
x += self.bias
return x
import torch
import torch.nn as nn
from torch_scatter import scatter_sum
from torch import Tensor
from typing import *
from starrygl.graph import DistGraph
class GCNConv(nn.Module):
def __init__(self,
in_channels: int,
out_channels: int,
bias: bool = True,
**kwargs
) -> None:
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.weight = nn.Parameter(torch.Tensor(in_channels, out_channels))
if bias:
self.bias = nn.Parameter(torch.Tensor(out_channels))
else:
self.bias = None
self.reset_parameters()
def reset_parameters(self):
nn.init.xavier_normal_(self.weight)
if self.bias is not None:
nn.init.zeros_(self.bias)
def forward(self, g: DistGraph) -> Tensor:
x = g.ndata["x"]
gcn_norm = g.edata["gcn_norm"].view(-1, 1)
edge_index = g.edge_index
x = x @ self.weight
x = x[edge_index[0]] * gcn_norm
x = scatter_sum(x, edge_index[1], dim=0, dim_size=g.dst_size)
if self.bias is not None:
x += self.bias
return x
import torch
import torch.nn as nn
from torch_scatter import scatter_sum
from torch import Tensor
from typing import *
from starrygl.graph import DistGraph
class GINConv(nn.Module):
def __init__(self,
in_channels: int,
out_channels: int,
mlp_channels: Optional[int] = None,
eps: float = 0,
train_eps: bool = False,
**kwargs
) -> None:
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
if mlp_channels is None:
mlp_channels = in_channels + out_channels
self.mlp_channels = mlp_channels
self.initial_eps = eps
self.nn = nn.Sequential(
nn.Linear(in_channels, mlp_channels),
nn.ReLU(),
nn.Linear(mlp_channels, out_channels),
)
if train_eps:
            self.eps = torch.nn.Parameter(torch.tensor([eps]))
else:
self.register_buffer("eps", torch.tensor([eps]))
self.reset_parameters()
def reset_parameters(self):
for name, param in self.nn.named_parameters():
if "weight" in name:
nn.init.xavier_normal_(param)
if "bias" in name:
nn.init.zeros_(param)
nn.init.constant_(self.eps, self.initial_eps)
def forward(self, g: DistGraph) -> Tensor:
x = g.ndata["x"]
edge_index = g.edge_index
z = x[edge_index[0]]
z = scatter_sum(z, edge_index[1], dim=0, dim_size=g.dst_size)
x = z + (1 + self.eps) * x[:g.dst_size]
return self.nn(x)
\ No newline at end of file
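# --- Hedged sketch (not from the diff): the GIN update used above.
# GINConv.forward relies on the destination nodes being stored as the first
# g.dst_size rows of g.ndata["x"] (hence x[:g.dst_size]). The aggregation is
# h_v = MLP((1 + eps) * x_v + sum_{u in N(v)} x_u); a tensor-only version:
import torch
from torch_scatter import scatter_sum

def gin_update_sketch(x: torch.Tensor, edge_index: torch.Tensor,
                      dst_size: int, eps: float = 0.0) -> torch.Tensor:
    # x: (src_size, F) with the first dst_size rows being the destination nodes
    z = scatter_sum(x[edge_index[0]], edge_index[1], dim=0, dim_size=dst_size)
    return z + (1 + eps) * x[:dst_size]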
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.utils import softmax
from torch_scatter import scatter_sum
from torch import Tensor
from typing import *
from starrygl.graph import DistGraph
class GATConv(nn.Module):
def __init__(self,
in_channels: int,
out_channels: int,
heads: int = 1,
concat: bool = False,
negative_slope: float = 0.2,
dropout: float = 0.0,
edge_dim: Optional[int] = None,
bias: bool = True,
**kwargs
) -> None:
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.heads = heads
self.concat = concat
self.negative_slope = negative_slope
self.dropout = dropout
self.edge_dim = edge_dim
self.weight = nn.Parameter(torch.Tensor(in_channels, heads * out_channels))
self.att_src = nn.Parameter(torch.Tensor(1, heads, out_channels))
self.att_dst = nn.Parameter(torch.Tensor(1, heads, out_channels))
if edge_dim is not None:
self.lin_edge = nn.Parameter(torch.Tensor(edge_dim, heads * out_channels))
self.att_edge = nn.Parameter(torch.Tensor(1, heads, out_channels))
if bias and concat:
self.bias = nn.Parameter(torch.Tensor(heads * out_channels))
elif bias and not concat:
self.bias = nn.Parameter(torch.Tensor(out_channels))
else:
self.bias = None
self.reset_parameters()
def reset_parameters(self):
nn.init.xavier_normal_(self.weight)
nn.init.xavier_normal_(self.att_src)
nn.init.xavier_normal_(self.att_dst)
if self.edge_dim is not None:
nn.init.xavier_normal_(self.lin_edge)
nn.init.xavier_normal_(self.att_edge)
if self.bias is not None:
nn.init.zeros_(self.bias)
def forward(self, g: DistGraph, x: Tensor, edge_attr: Optional[Tensor] = None) -> Tensor:
H, C = self.heads, self.out_channels
edge_index = g.edge_index
x = (x @ self.weight).view(-1, H, C)
alpha_j = (x * self.att_src).sum(dim=-1)
alpha_j = alpha_j[edge_index[0]]
alpha_i = (x * self.att_dst).sum(dim=-1)
alpha_i = alpha_i[edge_index[1]]
if self.edge_dim is not None:
if edge_attr.dim() == 1:
edge_attr = edge_attr.view(-1, 1)
e = (edge_attr @ self.lin_edge).view(-1, H, C)
alpha_e = (e * self.att_edge).sum(dim=-1)
alpha = alpha_i + alpha_j + alpha_e
else:
alpha = alpha_i + alpha_j
alpha = F.leaky_relu(alpha, self.negative_slope)
alpha = softmax(
src=alpha,
index=edge_index[1],
num_nodes=g.dst_size,
)
alpha = F.dropout(alpha, p=self.dropout, training=self.training)
x = x[edge_index[0]] * alpha.view(-1, H, 1)
x = scatter_sum(x, edge_index[1], dim=0, dim_size=g.dst_size)
if self.concat:
x = x.view(-1, H * C)
else:
x = x.mean(dim=1)
if self.bias is not None:
x += self.bias
return x
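# --- Hedged sketch (not from the diff): the edge-feature attention term.
# This GATConv variant takes x and edge_attr explicitly instead of reading
# g.ndata/g.edata, so edge_attr must be provided whenever edge_dim is set.
# With plain tensors (all names here are illustrative):
import torch

def edge_attention_term_sketch(edge_attr: torch.Tensor, lin_edge: torch.Tensor,
                               att_edge: torch.Tensor, heads: int, channels: int):
    # edge_attr: (E, edge_dim); lin_edge: (edge_dim, heads*channels); att_edge: (1, heads, channels)
    e = (edge_attr @ lin_edge).view(-1, heads, channels)   # per-edge hidden features
    return (e * att_edge).sum(dim=-1)                      # extra logits alpha_e, shape (E, heads)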
# import torch
# import torch.nn as nn
# import torch.nn.functional as F
# from torch_geometric.utils import softmax
# from torch_scatter import scatter_sum
# from torch import Tensor
# from typing import *
# from starrygl.graph import DistGraph
# from .gat_conv import GATConv
# from .utils import ShrinkHelper
# class ShrinkGATConv(GATConv):
# def __init__(self,
# in_channels: int,
# out_channels: int,
# heads: int = 1,
# concat: bool = False,
# negative_slope: float = 0.2,
# dropout: float = 0.0,
# edge_dim: Optional[int] = None,
# bias: bool = True,
# **kwargs
# ) -> None:
# super().__init__(
# in_channels=in_channels,
# out_channels=out_channels,
# heads=heads,
# concat=concat,
# negative_slope=negative_slope,
# dropout=dropout,
# edge_dim=edge_dim,
# bias=bias,
# **kwargs
# )
# def forward(self,
# g: DistGraph,
# sh: Optional[ShrinkHelper] = None,
# dst_idx: Optional[Tensor] = None,
# ) -> Tensor:
# if sh is None and dst_idx is None:
# return super().forward(g)
# if sh is None:
# sh = ShrinkHelper(g, dst_idx)
# H, C = self.heads, self.out_channels
# x = g.ndata["x"]
# src_x = x[sh.src_idx]
# dst_x = x[sh.dst_idx]
# edge_index = sh.edges
# src_x = (src_x @ self.weight).view(-1, H, C)
# dst_x = (dst_x @ self.weight).view(-1, H, C)
# alpha_j = (src_x * self.att_src).sum(dim=-1)
# alpha_j = alpha_j[edge_index[0]]
# alpha_i = (dst_x * self.att_dst).sum(dim=-1)
# alpha_i = alpha_i[edge_index[1]]
# if self.edge_dim is not None:
# edge_attr = g.edata["edge_attr"]
# edge_attr = edge_attr[sh.edge_idx]
# if edge_attr.dim() == 1:
# edge_attr = edge_attr.view(-1, 1)
# e = (edge_attr @ self.lin_edge).view(-1, H, C)
# alpha_e = (e * self.att_edge).sum(dim=-1)
# alpha = alpha_i + alpha_j + alpha_e
# else:
# alpha = alpha_i + alpha_j
# alpha = F.leaky_relu(alpha, self.negative_slope)
# alpha = softmax(
# src=alpha,
# index=edge_index[1],
# num_nodes=sh.dst_size,
# )
# alpha = F.dropout(alpha, p=self.dropout, training=self.training)
# x = src_x[edge_index[0]] * alpha.view(-1, H, 1)
# x = scatter_sum(x, edge_index[1], dim=0, dim_size=sh.dst_size)
# if self.concat:
# x = x.view(-1, H * C)
# else:
# x = x.mean(dim=1)
# if self.bias is not None:
# x += self.bias
# return x
# import torch
# import torch.nn as nn
# from torch_scatter import scatter_sum
# from torch import Tensor
# from typing import *
# from starrygl.graph import DistGraph
# from .gcn_conv import GCNConv
# from .utils import ShrinkHelper
# class ShrinkGCNConv(GCNConv):
# def __init__(self,
# in_channels: int,
# out_channels: int,
# bias: bool = True,
# **kwargs
# ) -> None:
# super().__init__(
# in_channels=in_channels,
# out_channels=out_channels,
# bias=bias,
# **kwargs
# )
# def forward(self,
# g: DistGraph,
# sh: Optional[ShrinkHelper] = None,
# dst_idx: Optional[Tensor] = None,
# ) -> Tensor:
# if sh is None and dst_idx is None:
# return super().forward(g)
# if sh is None:
# sh = ShrinkHelper(g, dst_idx)
# x = g.ndata["x"]
# gcn_norm = g.edata["gcn_norm"].view(-1, 1)
# x = x[sh.src_idx]
# gcn_norm = gcn_norm[sh.edge_idx]
# edge_index = sh.edges
# x = x @ self.weight
# x = x[edge_index[0]] * gcn_norm
# x = scatter_sum(x, edge_index[1], dim=0, dim_size=sh.dst_size)
# if self.bias is not None:
# x += self.bias
# return x
# import torch
# import torch.nn as nn
# from torch_scatter import scatter_sum
# from torch import Tensor
# from typing import *
# from starrygl.graph import DistGraph
# from .gin_conv import GINConv
# from .utils import ShrinkHelper
# class ShrinkGINConv(GINConv):
# def __init__(self,
# in_channels: int,
# out_channels: int,
# mlp_channels: Optional[int] = None,
# eps: float = 0,
# train_eps: bool = False,
# **kwargs
# ) -> None:
# super().__init__(
# in_channels=in_channels,
# out_channels=out_channels,
# mlp_channels=mlp_channels,
# eps=eps,
# train_eps=train_eps,
# **kwargs
# )
# def forward(self,
# g: DistGraph,
# sh: Optional[ShrinkHelper] = None,
# dst_idx: Optional[Tensor] = None,
# ) -> Tensor:
# if sh is None and dst_idx is None:
# return super().forward(g)
# if sh is None:
# sh = ShrinkHelper(g, dst_idx)
# x = g.ndata["x"]
# src_x = x[sh.src_idx]
# dst_x = x[sh.dst_idx]
# edge_index = sh.edges
# z = scatter_sum(src_x[edge_index[0]], index=edge_index[1], dim=0, dim_size=sh.dst_size)
# x = z + (1 + self.eps) * dst_x
# return self.nn(x)
\ No newline at end of file
# import torch
# from torch import Tensor
# from typing import *
# class ShrinkHelper:
# def __init__(self, g, dst_idx: Tensor) -> None:
# from starrygl.graph import DistGraph
# g: DistGraph = g
# self.device = dst_idx.device
# dst_m = torch.zeros(g.dst_size, dtype=torch.bool, device=self.device)
# dst_m.index_fill_(0, dst_idx, 1)
# edge_idx = torch.where(dst_m[g.edge_index[1]])[0]
# edge_index = g.edge_index[:, edge_idx]
# src_idx = edge_index[0]
# src_m = torch.zeros(g.src_size, dtype=torch.bool, device=self.device)
# src_m.index_fill_(0, src_idx, 1)
# src_idx = torch.where(src_m)[0]
# imp = torch.empty(max(g.src_size, g.dst_size), dtype=torch.long, device=self.device)
# imp[src_idx] = torch.arange(src_idx.size(0), dtype=torch.long, device=self.device)
# src = imp[edge_index[0]]
# imp[dst_idx] = torch.arange(dst_idx.size(0), dtype=torch.long, device=self.device)
# dst = imp[edge_index[1]]
# self.src_idx = src_idx
# self.dst_idx = dst_idx
# self.edge_idx = edge_idx
# self.edges = torch.vstack([src, dst])
# @property
# def src_size(self) -> int:
# return self.src_idx.size(0)
# @property
# def dst_size(self) -> int:
# return self.dst_idx.size(0)
\ No newline at end of file
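# --- Hedged sketch (not from the diff): what the commented-out ShrinkHelper computes.
# Given destination ids dst_idx, it keeps only the edges whose destination is in
# dst_idx and relabels both endpoints with compact local ids. A minimal
# stand-alone version with illustrative names:
import torch

def shrink_edges_sketch(edge_index: torch.Tensor, dst_idx: torch.Tensor,
                        src_size: int, dst_size: int):
    dst_mask = torch.zeros(dst_size, dtype=torch.bool)
    dst_mask[dst_idx] = True
    edge_idx = torch.where(dst_mask[edge_index[1]])[0]        # kept edge positions
    kept = edge_index[:, edge_idx]
    src_idx = torch.unique(kept[0])                           # touched source nodes (sorted)
    remap = torch.full((max(src_size, dst_size),), -1, dtype=torch.long)
    remap[src_idx] = torch.arange(src_idx.numel())
    src_local = remap[kept[0]]                                # local source ids
    remap[dst_idx] = torch.arange(dst_idx.numel())
    dst_local = remap[kept[1]]                                # local destination ids
    return src_idx, edge_idx, torch.vstack([src_local, dst_local])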
ldg_partition @ da31e8bb