Commit 57d75031 by Wenjie Huang

Refactor

parent e4ac1c7a
......@@ -160,6 +160,8 @@ cython_debug/
#.idea/
cora/
/dataset
/test_*
/*.ipynb
/s.py
/third_party
{
"cmake.configureOnOpen": true,
"cmake.configureSettings": {
"CMAKE_PREFIX_PATH": "/home/hwj/.miniconda3/envs/sgl/lib/python3.10/site-packages",
"Python3_ROOT_DIR": "/home/hwj/.miniconda3/envs/sgl",
"CUDA_TOOLKIT_ROOT_DIR": "/home/hwj/.local/cuda-11.7"
},
}
\ No newline at end of file
cmake_minimum_required(VERSION 3.15)
project(starrygl_ops VERSION 0.1)
option(WITH_PYTHON "Link to Python when building" ON)
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
set(CMAKE_CUDA_STANDARD 14)
set(CMAKE_CUDA_STANDARD_REQUIRED ON)
if(WITH_PYTHON)
add_definitions(-DWITH_PYTHON)
find_package(Python3 COMPONENTS Interpreter Development REQUIRED)
include_directories(${Python3_INCLUDE_DIRS})
endif()
find_package(Torch REQUIRED)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}")
find_package(OpenMP REQUIRED)
# add_subdirectory(third_party/cccl-2.2.0)
# add_subdirectory(third_party/cutlass-3.2.1)
# add_library(${PROJECT_NAME} SHARED csrc/custom_ops.cpp csrc/custom_kernel.cu)
add_library(${PROJECT_NAME} SHARED csrc/custom_ops.cpp)
if(WITH_PYTHON)
find_library(TORCH_PYTHON_LIBRARY torch_python PATHS "${TORCH_INSTALL_PREFIX}/lib")
target_link_libraries(${PROJECT_NAME} PRIVATE ${TORCH_PYTHON_LIBRARY})
endif()
target_link_libraries(${PROJECT_NAME} PRIVATE ${TORCH_LIBRARIES})
target_link_libraries(${PROJECT_NAME} PRIVATE OpenMP::OpenMP_CXX)
# target_link_libraries(${PROJECT_NAME} PRIVATE CCCL::CCCL)
target_compile_definitions(${PROJECT_NAME} PRIVATE -DTORCH_EXTENSION_NAME=lib${PROJECT_NAME})
# set_target_properties(${PROJECT_NAME} PROPERTIES PREFIX "" OUTPUT_NAME "_C")
install(TARGETS ${PROJECT_NAME} DESTINATION "${CMAKE_SOURCE_DIR}/starrygl/lib")
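# Configure sketch (assumption, not enforced by this file): point CMAKE_PREFIX_PATH at
# Torch's CMake package so find_package(Torch) succeeds, e.g.
#   cmake -S . -B build -DCMAKE_PREFIX_PATH="$(python -c 'import torch; print(torch.utils.cmake_prefix_path)')"
#   cmake --build build -j && cmake --install build
# The install() rule above then drops libstarrygl_ops into starrygl/lib for the Python side.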
// #include <cub/block/block_reduce.cuh>
\ No newline at end of file
#include <torch/extension.h>
#include <omp.h>
torch::Tensor add(torch::Tensor a, torch::Tensor b) {
return a + b;
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("add", &add, "a function implemented using pybind11");
}
\ No newline at end of file
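# Usage sketch (assumption: the installed starrygl/lib directory is importable as a
# package, as the commented "from .lib import libstarrygl_ops as ops" later in this
# commit suggests; the module name comes from TORCH_EXTENSION_NAME above).
import torch
from starrygl.lib import libstarrygl_ops as ops

a, b = torch.randn(4), torch.randn(4)
assert torch.equal(ops.add(a, b), a + b)  # add() is the pybind11 binding defined above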
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/hwj/.miniconda3/envs/pyg/lib/python3.8/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n"
]
}
],
"source": [
"import torch\n",
"\n",
"from torch import Tensor\n",
"from typing import *\n",
"\n",
"from torch_scatter import scatter"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"def act_blks(act_mask: Tensor, blk_size: int) -> Tensor:\n",
" assert act_mask.dtype == torch.bool\n",
" assert act_mask.dim() == 1\n",
" n = act_mask.size(0)\n",
" m = (n + blk_size - 1) // blk_size * blk_size\n",
"\n",
" blk_mask = act_mask.detach()\n",
" if n != m:\n",
" blk_mask = torch.zeros(m, dtype=act_mask.dtype, device=act_mask.device)\n",
" blk_mask[:n] = act_mask.detach()\n",
" return blk_mask.view(-1, blk_size).max(dim=-1).values"
]
},
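{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Worked example (added sketch): with blk_size=4 every block that contains at\n",
"# least one active element is marked active, so the 10-element mask below\n",
"# (blocks 0001 | 0000 | 10) yields tensor([ True, False,  True]).\n",
"act_blks(torch.tensor([0, 0, 0, 1,  0, 0, 0, 0,  1, 0], dtype=torch.bool), 4)"
]
},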
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"act_mask = torch.rand(1000) < 0.01\n",
"blk_mask = act_blks(act_mask, 64)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"cached_blocks = torch.randn(blk_mask.size(0), 64, 8)\n",
"blks = cached_blocks[blk_mask]"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"def gather_block(blks: Tensor, blk_mask: Tensor, index: Tensor) -> Tensor:\n",
" blk_size = blks.size(1)\n",
" blk_mask = blk_mask.type(torch.long).cumsum_(dim=0) * blk_mask - 1\n",
" \n",
" idx0 = index.div(blk_size, rounding_mode=\"floor\")\n",
" ind_mask = blk_mask[idx0] >= 0\n",
" \n",
" idx0 = blk_mask[idx0][ind_mask]\n",
" idx1 = index[ind_mask] % blk_size\n",
"\n",
" s = index.shape[:1] + blks.shape[2:]\n",
" data = torch.zeros(s, dtype=blks.dtype, device=blks.device)\n",
" data[ind_mask] = blks[idx0, idx1, ...]\n",
" return data"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"def scatter_block(x: Tensor, index: Tensor, blk_mask: Tensor, blk_size: int, reduce: str = \"sum\") -> Tensor:\n",
" num_blks = blk_mask.count_nonzero()\n",
" blk_mask = blk_mask.type(torch.long).cumsum_(dim=0) * blk_mask - 1\n",
"\n",
" idx0 = index.div(blk_size, rounding_mode=\"floor\")\n",
" ind_mask = blk_mask[idx0] >= 0\n",
"\n",
" idx0 = blk_mask[idx0][ind_mask]\n",
" idx1 = index[ind_mask] % blk_size\n",
"\n",
" data = scatter(\n",
" src=x[ind_mask],\n",
" index=idx0*blk_size+idx1,\n",
" dim=0,\n",
" dim_size=num_blks*blk_size,\n",
" reduce=reduce,\n",
" )\n",
" return data.view(num_blks, blk_size, *x.shape[1:])"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"# def index_block(index: Tensor, blk_mask: Tensor, blk_size: int) -> Tensor:\n",
" "
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([False, True, True, False, True, False, False, True, False, True,\n",
" False, True, True, False, False, False, False, False, False, False,\n",
" False, True, True, False, False, False, True, False, False, True,\n",
" True, False, False, True, True, False, True, True, False, False,\n",
" True, False, True, False, False, False, False, False, False, False,\n",
" True, True, False, False, False, True, True, True, False, False,\n",
" True, False, True, False, True, False, True, False, False, True,\n",
" False, True, True, True, False, False, False, False, False, True,\n",
" False, True, False, True, True, True, False, True, False, True,\n",
" False, True, False, True, True, False, False, True, False, True])"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"index = torch.randint(1000, size=(100,), dtype=torch.long)\n",
"(gather_block(blks, blk_mask, index)**2).sum(dim=-1) != 0"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([8, 64, 8])"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"x = torch.randn(100, 8)\n",
"scatter_block(x, index, blk_mask, 64).size()"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([True, True, True, True, True, True, True, True])"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"t = scatter_block(x, index, blk_mask, 64)\n",
"t = t.view(-1, *t.shape[2:])\n",
"act_blks((t**2).sum(dim=-1) != 0, 64)"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.return_types.sort(\n",
"values=tensor([ 5, 15, 18, 19, 25, 48, 62, 64, 68, 76, 88, 92, 94, 106,\n",
" 114, 115, 120, 122, 126, 131, 135, 165, 169, 169, 177, 192, 214, 215,\n",
" 245, 249, 255, 269, 282, 286, 294, 299, 334, 358, 359, 372, 401, 403,\n",
" 430, 437, 455, 459, 459, 459, 471, 484, 490, 504, 505, 516, 528, 529,\n",
" 559, 584, 613, 614, 614, 624, 629, 637, 638, 674, 679, 701, 706, 735,\n",
" 742, 746, 748, 750, 759, 774, 774, 776, 778, 787, 803, 825, 831, 834,\n",
" 861, 862, 868, 877, 882, 889, 899, 909, 921, 933, 938, 954, 960, 969,\n",
" 981, 986]),\n",
"indices=tensor([ 6, 77, 14, 41, 43, 0, 13, 53, 38, 54, 15, 76, 98, 47, 75, 80, 8, 95,\n",
" 17, 34, 83, 87, 62, 93, 81, 20, 67, 39, 32, 5, 70, 4, 89, 72, 37, 64,\n",
" 94, 42, 57, 29, 56, 21, 60, 7, 45, 59, 86, 96, 49, 92, 28, 88, 24, 68,\n",
" 3, 61, 44, 2, 79, 55, 85, 33, 12, 1, 11, 97, 50, 9, 84, 40, 71, 22,\n",
" 26, 30, 73, 16, 90, 27, 19, 35, 48, 18, 46, 74, 65, 78, 52, 31, 23, 58,\n",
" 91, 66, 99, 69, 51, 36, 82, 63, 25, 10]))"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"index.sort()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "mpi",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.13"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}
{
"cells": [
{
"cell_type": "code",
"execution_count": 383,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"import torch.autograd as autograd\n",
"\n",
"from torch import Tensor\n",
"from typing import *\n",
"\n",
"from torch_scatter import scatter\n",
"from torch_geometric.utils import softmax\n",
"\n",
"import triton\n",
"import triton.language as tl\n",
"\n",
"from copy import deepcopy"
]
},
{
"cell_type": "code",
"execution_count": 319,
"metadata": {},
"outputs": [],
"source": [
"# class SoftmaxSumFunction(autograd.Function):\n",
"# @staticmethod\n",
"# def forward(\n",
"# ctx: autograd.function.FunctionCtx,\n",
"# x: Tensor,\n",
"# a: Tensor,\n",
"# _x: Tensor,\n",
"# index: Tensor,\n",
"# mask: Tensor,\n",
"# exp_a: Tensor,\n",
"# sum_exp_a: Tensor,\n",
"# sum_exp_ah: Tensor,\n",
"# training: bool,\n",
"# ):\n",
"# dlt_exp_ah = exp_a[mask, None] * (x - _x)\n",
"# sum_exp_ah = scatter(dlt_exp_ah, index, dim=0, out=sum_exp_ah.clone(), reduce=\"sum\")\n",
"# out = sum_exp_ah / sum_exp_a[:, None]\n",
"# if training:\n",
"# w = exp_a[mask] / sum_exp_a.index_select(0, index)\n",
"# ctx.save_for_backward(out, x, w, index)\n",
"# return out\n",
" \n",
"# @staticmethod\n",
"# def backward(\n",
"# ctx: autograd.function.FunctionCtx,\n",
"# grad: Tensor,\n",
"# ):\n",
"# out, x, w, index = ctx.saved_tensors\n",
"# g = grad.index_select(0, index)\n",
"# o = out.index_select(0, index)\n",
"\n",
"# grad_x = g * w[:, None]\n",
"# grad_a = (g * (x - o)).sum(dim=-1) * w\n",
"# return grad_x, grad_a, None, None, None, None, None, None, None"
]
},
{
"cell_type": "code",
"execution_count": 320,
"metadata": {},
"outputs": [],
"source": [
"# class SoftmaxSum(nn.Module):\n",
"# def __init__(self,\n",
"# num_nodes: int,\n",
"# num_edges: int,\n",
"# num_features: int,\n",
"# dtype: Optional[torch.dtype] = None,\n",
"# ) -> None:\n",
"# super().__init__()\n",
"# self.num_nodes = num_nodes\n",
"# self.num_edges = num_edges\n",
"\n",
"# self.register_buffer(\"max_hist_a\", torch.zeros(num_nodes, dtype=dtype))\n",
"# self.register_buffer(\"exp_hist_a\", torch.zeros(num_edges, dtype=dtype))\n",
"# self.register_buffer(\"sum_exp_a\", torch.zeros(num_nodes, dtype=dtype))\n",
"# self.register_buffer(\"sum_exp_ah\", torch.zeros(num_nodes, num_features, dtype=dtype))\n",
" \n",
"# @torch.no_grad()\n",
"# def update_snapshot(self, x: Tensor, a: Tensor, index: Tensor):\n",
"# max_a = scatter(a, index, dim=0, dim_size=self.num_nodes, reduce=\"max\")\n",
"# exp_a = (a - max_a.index_select(0, index)).exp_()\n",
"# self.get_buffer(\"max_hist_a\")[:] = max_a\n",
"# self.get_buffer(\"exp_hist_a\")[:] = exp_a\n",
"\n",
"# sum_exp_a = scatter(exp_a, index, dim=0, dim_size=self.num_nodes, reduce=\"sum\") + 1e-10\n",
"# sum_exp_ah = scatter(exp_a[:,None] * x, index, dim=0, dim_size=self.num_nodes, reduce=\"sum\")\n",
"# self.get_buffer(\"sum_exp_a\")[:] = sum_exp_a\n",
"# self.get_buffer(\"sum_exp_ah\")[:] = sum_exp_ah\n",
"# return self\n",
" \n",
"# @torch.no_grad()\n",
"# def update_attention(self, x: Tensor, a: Tensor, _x: Tensor, index: Tensor, mask: Tensor):\n",
"# exp_a = (a - self.get_buffer(\"max_hist_a\").index_select(0, index)).exp_()\n",
"# exp_ah = exp_a[:, None] * x\n",
"\n",
"# exp_hist_a = self.get_buffer(\"exp_hist_a\")[mask]\n",
"# exp_hist_ah = exp_hist_a[:, None] * _x\n",
"\n",
"# self.get_buffer(\"exp_hist_a\")[mask] = exp_a\n",
"# scatter(exp_a - exp_hist_a, index, dim=0, out=self.get_buffer(\"sum_exp_a\"), reduce=\"sum\")\n",
"# scatter(exp_ah - exp_hist_ah, index, dim=0, out=self.get_buffer(\"sum_exp_ah\"), reduce=\"sum\")\n",
"# return self\n",
" \n",
"# def forward(self, x: Tensor, a: Tensor, _x: Tensor, index: Tensor, mask: Tensor):\n",
"# return SoftmaxSumFunction.apply(x, a, _x, index, mask,\n",
"# self.get_buffer(\"exp_hist_a\"),\n",
"# self.get_buffer(\"sum_exp_a\"),\n",
"# self.get_buffer(\"sum_exp_ah\"),\n",
"# self.training)"
]
},
{
"cell_type": "code",
"execution_count": 321,
"metadata": {},
"outputs": [],
"source": [
"# num_nodes = 200\n",
"# num_edges = 100\n",
"# num_features = 8\n",
"# for i in range(10):\n",
"# x = torch.randn(num_edges, num_features)\n",
"# a = torch.randn(num_edges)\n",
"# index = torch.randint(num_nodes, size=(num_edges,), dtype=torch.long)\n",
"# mask = torch.ones(num_edges, dtype=torch.bool)\n",
"\n",
"# ss = SoftmaxSum(num_nodes, num_edges, num_features)\n",
"# ss.update_snapshot(x, a, index)\n",
"# ss.update_attention(x, a, x, index, mask)\n",
"\n",
"# x.requires_grad_()\n",
"# a.requires_grad_()\n",
"# ss(x, a, x, index, mask).sum().backward()\n",
"# grad_x0 = x.grad.clone()\n",
"# grad_a0 = a.grad.clone()\n",
"\n",
"# x.grad = None\n",
"# a.grad = None\n",
"# t = softmax(a, index, num_nodes=num_nodes)[:,None] * x\n",
"# scatter(t, index, dim=0, dim_size=num_nodes, reduce=\"sum\").sum().backward()\n",
"\n",
"# # print(torch.norm(grad_x0 - x.grad), torch.norm(grad_a0 - a.grad))\n",
"# # print(torch.abs(grad_x0 - x.grad).max(), torch.abs(grad_a0 - a.grad).max())\n",
"# # print(grad_x0)\n",
"# # print(grad_a0)\n",
"# print(grad_a0 - a.grad)"
]
},
{
"cell_type": "code",
"execution_count": 355,
"metadata": {},
"outputs": [],
"source": [
"class SoftmaxSumFunction(autograd.Function):\n",
" @staticmethod\n",
" def forward(\n",
" ctx: autograd.function.FunctionCtx,\n",
" x1: Tensor,\n",
" a1: Tensor,\n",
" x0: Tensor,\n",
" a0: Tensor,\n",
" index: Tensor,\n",
" max_a: Tensor,\n",
" sum_exp_a: Tensor,\n",
" sum_exp_ah: Tensor,\n",
" training: bool,\n",
" ):\n",
" exp_a = (a0 - max_a.index_select(0, index)).exp_()\n",
" dlt_exp_ah = exp_a[:, :, None] * (x1 - x0)\n",
" sum_exp_ah = scatter(dlt_exp_ah, index, dim=0, out=sum_exp_ah.clone(), reduce=\"sum\")\n",
" out = sum_exp_ah / sum_exp_a[:, :, None]\n",
" if training:\n",
" w = exp_a / sum_exp_a.index_select(0, index)\n",
" ctx.save_for_backward(out, x1, w, index)\n",
" return out\n",
" \n",
" @staticmethod\n",
" def backward(\n",
" ctx: autograd.function.FunctionCtx,\n",
" grad: Tensor,\n",
" ):\n",
" out, x, w, index = ctx.saved_tensors\n",
" g = grad.index_select(0, index)\n",
" o = out.index_select(0, index)\n",
"\n",
" grad_x = g * w[:, :, None]\n",
" grad_a = (g * (x - o)).sum(dim=-1) * w\n",
" return grad_x, grad_a, None, None, None, None, None, None, None\n",
"\n",
"class SoftmaxSum(nn.Module):\n",
" def __init__(self,\n",
" num_nodes: int,\n",
" num_edges: int,\n",
" num_heads: int,\n",
" num_features: int,\n",
" dtype: Optional[torch.dtype] = None,\n",
" ) -> None:\n",
" super().__init__()\n",
" self.num_nodes = num_nodes\n",
" self.num_edges = num_edges\n",
" self.num_heads = num_heads\n",
"\n",
" self.register_buffer(\"max_a\", torch.zeros(num_nodes, num_heads, dtype=dtype))\n",
" self.register_buffer(\"sum_exp_a\", torch.zeros(num_nodes, num_heads, dtype=dtype))\n",
" self.register_buffer(\"sum_exp_ah\", torch.zeros(num_nodes, num_heads, num_features, dtype=dtype))\n",
" \n",
" @torch.no_grad()\n",
" def update_snapshot(self, x: Tensor, a: Tensor, index: Tensor):\n",
" max_a = scatter(a, index, dim=0, dim_size=self.num_nodes, reduce=\"max\")\n",
" exp_a = (a - max_a.index_select(0, index)).exp_()\n",
" self.get_buffer(\"max_a\")[:] = max_a\n",
"\n",
" sum_exp_a = scatter(exp_a, index, dim=0, dim_size=self.num_nodes, reduce=\"sum\") + 1e-10\n",
" sum_exp_ah = scatter(exp_a[:, :, None] * x, index, dim=0, dim_size=self.num_nodes, reduce=\"sum\")\n",
" self.get_buffer(\"sum_exp_a\")[:] = sum_exp_a\n",
" self.get_buffer(\"sum_exp_ah\")[:] = sum_exp_ah\n",
" return self\n",
" \n",
" @torch.no_grad()\n",
" def update_attention(self, x1: Tensor, a1: Tensor, x0: Tensor, a0: Tensor, index: Tensor):\n",
" max_hist_a = self.get_buffer(\"max_a\").index_select(0, index)\n",
" exp_a0 = (a0 - max_hist_a).exp_()\n",
" exp_a1 = (a1 - max_hist_a).exp_()\n",
" scatter(exp_a1 - exp_a0, index, dim=0, out=self.get_buffer(\"sum_exp_a\"), reduce=\"sum\")\n",
"\n",
" exp_ah0 = exp_a0[:, :, None] * x0\n",
" exp_ah1 = exp_a1[:, :, None] * x1\n",
" scatter(exp_ah1 - exp_ah0, index, dim=0, out=self.get_buffer(\"sum_exp_ah\"), reduce=\"sum\")\n",
" return self\n",
" \n",
" def forward(self, x1: Tensor, a1: Tensor, x0: Tensor, a0: Tensor, index: Tensor) -> Tensor:\n",
" return SoftmaxSumFunction.apply(x1, a1, x0, a0, index,\n",
" self.get_buffer(\"max_a\"),\n",
" self.get_buffer(\"sum_exp_a\"),\n",
" self.get_buffer(\"sum_exp_ah\"),\n",
" self.training)"
]
},
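{
"cell_type": "markdown",
"metadata": {},
"source": [
"Added note (a sketch inferred from the code above): per destination node $v$ and per head, `SoftmaxSum` appears to cache\n",
"\n",
"$$m_v=\\max_{e\\to v} a_e,\\qquad S_v=\\sum_{e\\to v} e^{a_e-m_v},\\qquad H_v=\\sum_{e\\to v} e^{a_e-m_v}\\,x_e,$$\n",
"\n",
"so the attention-weighted aggregation is $\\mathrm{out}_v = H_v / S_v$. `update_attention` folds the changes of touched edges into $S_v$ and $H_v$ (e.g. $e^{a^1_e-m_v}-e^{a^0_e-m_v}$), and `forward` only adds the residual term $e^{a^0_e-m_v}\\,(x^1_e-x^0_e)$ on top of the cached $H_v$, instead of recomputing the full sums over all edges."
]
},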
{
"cell_type": "code",
"execution_count": 359,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor(0.) tensor(6.2413e-06)\n",
"tensor([[-2.1367, -3.5242],\n",
" [ 0.0000, 0.0000],\n",
" [ 0.0000, 0.0000],\n",
" [ 0.0000, 0.0000],\n",
" [ 0.0000, 0.0000],\n",
" [ 0.0000, 0.0000],\n",
" [ 0.0000, 0.0000],\n",
" [ 0.8726, -6.0982]])\n",
"tensor([[-2.1367, -3.5242],\n",
" [ 0.0000, 0.0000],\n",
" [ 0.0000, 0.0000],\n",
" [ 0.0000, 0.0000],\n",
" [ 0.0000, 0.0000],\n",
" [ 0.0000, 0.0000],\n",
" [ 0.0000, 0.0000],\n",
" [ 0.8726, -6.0982]])\n",
"tensor(0.) tensor(7.5973e-06)\n",
"tensor([[ 4.1189, -1.3866],\n",
" [ 6.2586, -4.8426],\n",
" [ 0.0000, 0.0000],\n",
" [ 2.7955, 8.4197],\n",
" [ 0.0000, 0.0000],\n",
" [ 0.0000, 0.0000],\n",
" [ 0.0000, 0.0000],\n",
" [ 0.0000, 0.0000]])\n",
"tensor([[ 4.1189, -1.3866],\n",
" [ 6.2586, -4.8426],\n",
" [ 0.0000, 0.0000],\n",
" [ 2.7955, 8.4197],\n",
" [ 0.0000, 0.0000],\n",
" [ 0.0000, 0.0000],\n",
" [ 0.0000, 0.0000],\n",
" [ 0.0000, 0.0000]])\n",
"tensor(0.) tensor(6.0887e-06)\n",
"tensor([[ 0.0000, 0.0000],\n",
" [-1.7955, 4.7750],\n",
" [ 0.0000, 0.0000],\n",
" [ 0.0000, 0.0000],\n",
" [ 0.0000, 0.0000],\n",
" [-3.7791, -3.4205],\n",
" [ 3.1130, 3.6583],\n",
" [ 0.0000, 0.0000]])\n",
"tensor([[ 0.0000, 0.0000],\n",
" [-1.7955, 4.7750],\n",
" [ 0.0000, 0.0000],\n",
" [ 0.0000, 0.0000],\n",
" [ 0.0000, 0.0000],\n",
" [-3.7791, -3.4205],\n",
" [ 3.1130, 3.6583],\n",
" [ 0.0000, 0.0000]])\n",
"tensor(0.) tensor(5.9715e-06)\n",
"tensor([[ 0.0000, 0.0000],\n",
" [ 0.0000, 0.0000],\n",
" [ 9.4426, -2.2261],\n",
" [ 0.0000, 0.0000],\n",
" [ 0.0000, 0.0000],\n",
" [ 0.0000, 0.0000],\n",
" [ 2.7483, 3.0900],\n",
" [-9.4426, 2.2261]])\n",
"tensor([[ 0.0000, 0.0000],\n",
" [ 0.0000, 0.0000],\n",
" [ 9.4426, -2.2261],\n",
" [ 0.0000, 0.0000],\n",
" [ 0.0000, 0.0000],\n",
" [ 0.0000, 0.0000],\n",
" [ 2.7483, 3.0900],\n",
" [-9.4426, 2.2261]])\n",
"tensor(0.) tensor(7.1696e-06)\n",
"tensor([[ 0.0000, 0.0000],\n",
" [ 0.0000, 0.0000],\n",
" [ 2.2263, 0.6597],\n",
" [ 0.0000, 0.0000],\n",
" [ 0.0000, 0.0000],\n",
" [-2.2893, -4.1325],\n",
" [-3.9234, -3.9451],\n",
" [ 0.0000, 0.0000]])\n",
"tensor([[ 0.0000, 0.0000],\n",
" [ 0.0000, 0.0000],\n",
" [ 2.2263, 0.6597],\n",
" [ 0.0000, 0.0000],\n",
" [ 0.0000, 0.0000],\n",
" [-2.2893, -4.1325],\n",
" [-3.9234, -3.9451],\n",
" [ 0.0000, 0.0000]])\n",
"tensor(0.) tensor(6.6798e-06)\n",
"tensor([[ 0.0000, 0.0000],\n",
" [ 0.0000, 0.0000],\n",
" [ 0.0000, 0.0000],\n",
" [ 1.8056, 3.9442],\n",
" [ 0.0000, 0.0000],\n",
" [ 0.0000, 0.0000],\n",
" [ 0.0000, 0.0000],\n",
" [-2.4050, -5.5951]])\n",
"tensor([[ 0.0000, 0.0000],\n",
" [ 0.0000, 0.0000],\n",
" [ 0.0000, 0.0000],\n",
" [ 1.8056, 3.9442],\n",
" [ 0.0000, 0.0000],\n",
" [ 0.0000, 0.0000],\n",
" [ 0.0000, 0.0000],\n",
" [-2.4050, -5.5951]])\n",
"tensor(0.) tensor(4.7071e-06)\n",
"tensor([[ 0.0000, 0.0000],\n",
" [ -3.6159, -10.4370],\n",
" [ 0.0000, 0.0000],\n",
" [ 0.0000, 0.0000],\n",
" [ 0.0000, 0.0000],\n",
" [ 0.0000, 0.0000],\n",
" [ 2.1697, 4.4867],\n",
" [ 0.8153, 9.0103]])\n",
"tensor([[ 0.0000, 0.0000],\n",
" [ -3.6159, -10.4370],\n",
" [ 0.0000, 0.0000],\n",
" [ 0.0000, 0.0000],\n",
" [ 0.0000, 0.0000],\n",
" [ 0.0000, 0.0000],\n",
" [ 2.1697, 4.4867],\n",
" [ 0.8153, 9.0103]])\n",
"tensor(0.) tensor(6.2858e-06)\n",
"tensor([[ 0.0000, 0.0000],\n",
" [ 0.0000, 0.0000],\n",
" [ 0.0000, 0.0000],\n",
" [ 0.0000, 0.0000],\n",
" [ 5.5491, -3.5294],\n",
" [ 2.2001, -5.5199],\n",
" [ 1.6753, 1.5286],\n",
" [ 5.7522, -2.5286]])\n",
"tensor([[ 0.0000, 0.0000],\n",
" [ 0.0000, 0.0000],\n",
" [ 0.0000, 0.0000],\n",
" [ 0.0000, 0.0000],\n",
" [ 5.5491, -3.5294],\n",
" [ 2.2001, -5.5199],\n",
" [ 1.6753, 1.5286],\n",
" [ 5.7522, -2.5286]])\n",
"tensor(0.) tensor(5.8452e-06)\n",
"tensor([[ 0.0000, 0.0000],\n",
" [ 0.0000, 0.0000],\n",
" [ 0.0000, 0.0000],\n",
" [ 0.0000, 0.0000],\n",
" [-0.3675, 3.0714],\n",
" [ 0.4008, -9.1039],\n",
" [ 0.0000, 0.0000],\n",
" [ 0.0000, 0.0000]])\n",
"tensor([[ 0.0000, 0.0000],\n",
" [ 0.0000, 0.0000],\n",
" [ 0.0000, 0.0000],\n",
" [ 0.0000, 0.0000],\n",
" [-0.3675, 3.0714],\n",
" [ 0.4008, -9.1039],\n",
" [ 0.0000, 0.0000],\n",
" [ 0.0000, 0.0000]])\n",
"tensor(0.) tensor(6.5390e-06)\n",
"tensor([[ 0.0000, 0.0000],\n",
" [ 0.0000, 0.0000],\n",
" [ 4.4908, 0.8768],\n",
" [ 0.1810, 2.1797],\n",
" [ 0.0000, 0.0000],\n",
" [ 0.0000, 0.0000],\n",
" [ 0.0000, 0.0000],\n",
" [-4.4908, -0.8768]])\n",
"tensor([[ 0.0000, 0.0000],\n",
" [ 0.0000, 0.0000],\n",
" [ 4.4908, 0.8768],\n",
" [ 0.1810, 2.1797],\n",
" [ 0.0000, 0.0000],\n",
" [ 0.0000, 0.0000],\n",
" [ 0.0000, 0.0000],\n",
" [-4.4908, -0.8768]])\n"
]
}
],
"source": [
"num_nodes = 200\n",
"num_edges = 100\n",
"num_heads = 2\n",
"num_features = 256\n",
"for i in range(10):\n",
" x = torch.randn(num_edges, num_heads, num_features)\n",
" a = torch.randn(num_edges, num_heads)\n",
" index = torch.randint(num_nodes, size=(num_edges,), dtype=torch.long)\n",
"\n",
" ss = SoftmaxSum(num_nodes, num_edges, num_heads, num_features)\n",
" ss.update_snapshot(x, a, index)\n",
" ss.update_attention(x, a, x, a, index)\n",
"\n",
" x.requires_grad_()\n",
" a.requires_grad_()\n",
" ss(x, a, x, a, index).sum().backward()\n",
" grad_x0 = x.grad.clone()\n",
" grad_a0 = a.grad.clone()\n",
"\n",
" x.grad = None\n",
" a.grad = None\n",
" t = softmax(a, index, num_nodes=num_nodes)[:, :, None] * x\n",
" scatter(t, index, dim=0, dim_size=num_nodes, reduce=\"sum\").sum().backward()\n",
"\n",
" print(torch.norm(grad_x0 - x.grad), torch.norm(grad_a0 - a.grad))\n",
" # print(torch.abs(grad_x0 - x.grad).max(), torch.abs(grad_a0 - a.grad).max())\n",
" # print(grad_x0)\n",
" print(grad_a0[:8, :])\n",
" print(a.grad[:8, :])\n",
" # print(grad_a0 - a.grad)"
]
},
{
"cell_type": "code",
"execution_count": 399,
"metadata": {},
"outputs": [],
"source": [
"class GATConv(nn.Module):\n",
" def __init__(self,\n",
" in_channels: int,\n",
" out_channels: int,\n",
" heads: int = 1,\n",
" concat: bool = False,\n",
" negative_slope: float = 0.2,\n",
" bias: bool = True,\n",
" **kwargs\n",
" ) -> None:\n",
" super().__init__()\n",
" \n",
" self.in_channels = in_channels\n",
" self.out_channels = out_channels\n",
" self.heads = heads\n",
" self.concat = concat\n",
" self.negative_slope = negative_slope\n",
" self._softmax_sum_list = [None]\n",
" \n",
" self.weight = nn.Parameter(torch.Tensor(in_channels, heads * out_channels))\n",
" self.att_src = nn.Parameter(torch.Tensor(1, heads, out_channels))\n",
" self.att_dst = nn.Parameter(torch.Tensor(1, heads, out_channels))\n",
" \n",
" if bias and concat:\n",
" self.bias = nn.Parameter(torch.Tensor(heads * out_channels))\n",
" elif bias and not concat:\n",
" self.bias = nn.Parameter(torch.Tensor(out_channels))\n",
" else:\n",
" self.bias = None\n",
"\n",
" self.reset_parameters()\n",
" \n",
" def reset_parameters(self):\n",
" nn.init.xavier_normal_(self.weight)\n",
" nn.init.xavier_normal_(self.att_src)\n",
" nn.init.xavier_normal_(self.att_dst)\n",
" if self.bias is not None:\n",
" nn.init.zeros_(self.bias)\n",
" \n",
" @property\n",
" def softmax_sum(self) -> Optional[SoftmaxSum]:\n",
" return self._softmax_sum_list[0]\n",
" \n",
" def set_softmax_sum(self, m: Optional[SoftmaxSum]):\n",
" self._softmax_sum_list = [m]\n",
" return self\n",
" \n",
" def train(self, mode: bool = True):\n",
" if self.softmax_sum is not None:\n",
" self.softmax_sum.train(mode)\n",
" return super().train(mode)\n",
" \n",
" def update_softmax_snapshot(self,\n",
" x: Tensor,\n",
" edge_index: Tensor,\n",
" snap_net: nn.Module,\n",
" ):\n",
" snap_net: GATConv = snap_net.eval()\n",
" with torch.no_grad():\n",
" x, a = snap_net._att_x(x, edge_index)\n",
" self.softmax_sum.update_snapshot(x, a, edge_index[1])\n",
"\n",
" def _att_x(self, x: Tensor, edge_index: Tensor) -> Tuple[Tensor, Tensor]:\n",
" H, C = self.heads, self.out_channels\n",
" x = (x @ self.weight).view(-1, H, C)\n",
"\n",
" src_a = (x * self.att_src).sum(dim=-1)\n",
" src_a = src_a[edge_index[0]]\n",
" dst_a = (x * self.att_dst).sum(dim=-1)\n",
" dst_a = dst_a[edge_index[1]]\n",
"\n",
" x = x[edge_index[0]]\n",
" a = F.leaky_relu(src_a + dst_a, self.negative_slope)\n",
" return x, a\n",
" \n",
" def forward(self,\n",
" x1: Tensor,\n",
" x0: Tensor,\n",
" edge_index: Tensor,\n",
" snap_net: nn.Module,\n",
" update_att: bool,\n",
" ) -> Tensor:\n",
" H, C = self.heads, self.out_channels\n",
" snap_net: GATConv = snap_net.eval()\n",
"\n",
" index = edge_index[1]\n",
" with torch.no_grad():\n",
" x0, a0 = snap_net._att_x(x0, edge_index)\n",
" if update_att:\n",
" x, a = snap_net._att_x(x1, edge_index)\n",
" self.softmax_sum.update_attention(x, a, x0, a0, index)\n",
" x0, a0 = x, a\n",
" \n",
" x1, a1 = self._att_x(x1, edge_index)\n",
" out = self.softmax_sum.forward(x1, a1, x0, a0, index)\n",
" \n",
" if self.concat:\n",
" out = out.view(-1, H * C)\n",
" else:\n",
" out = out.mean(dim=1)\n",
"\n",
" if self.bias is not None:\n",
" out += self.bias\n",
" return out\n",
" "
]
},
{
"cell_type": "code",
"execution_count": 400,
"metadata": {},
"outputs": [],
"source": [
"@torch.no_grad()\n",
"def slim_local_graph(src_mask: Tensor, edge_index: Tensor) -> Tuple[Tensor, Tensor, Tensor]:\n",
" edge_mask = src_mask[edge_index[0]]\n",
" slim_edge_index = edge_index[:, edge_mask]\n",
" \n",
" node_mask = src_mask.clone().index_fill_(0, slim_edge_index[1], 1)\n",
"\n",
" idx = torch.where(node_mask)[0]\n",
" imp = torch.empty_like(src_mask, dtype=torch.long).fill_((2**62-1)*2+1)\n",
" imp[idx] = torch.arange(idx.size(0))\n",
" \n",
" slim_edge_index = imp[slim_edge_index.flatten()].view_as(slim_edge_index)\n",
" return node_mask, edge_mask, slim_edge_index"
]
},
{
"cell_type": "code",
"execution_count": 401,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([200])\n",
"torch.Size([100])\n",
"torch.Size([2, 35])\n"
]
}
],
"source": [
"src_mask = torch.rand(num_nodes) < 0.3\n",
"edge_index = torch.randint(num_nodes, size=(2, num_edges), dtype=torch.long)\n",
"node_mask, edge_mask, slim_edge_index = slim_local_graph(src_mask, edge_index)\n",
"print(node_mask.size())\n",
"print(edge_mask.size())\n",
"print(slim_edge_index.size())"
]
},
{
"cell_type": "code",
"execution_count": 484,
"metadata": {},
"outputs": [],
"source": [
"net = GATConv(3, num_features, num_heads)\n",
"att = SoftmaxSum(num_nodes, num_edges, num_heads, num_features)\n",
"net.set_softmax_sum(att)\n",
"\n",
"snap_net = deepcopy(net)"
]
},
{
"cell_type": "code",
"execution_count": 485,
"metadata": {},
"outputs": [],
"source": [
"x1 = torch.randn(num_nodes, 3)\n",
"x0 = torch.randn(num_nodes, 3)\n",
"edge_index = torch.randint(num_nodes, size=(2, num_edges), dtype=torch.long)\n",
"\n",
"net.update_softmax_snapshot(x0, edge_index, snap_net)"
]
},
{
"cell_type": "code",
"execution_count": 511,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[ 0.1482, -0.0744, -0.0213, ..., -0.0421, -0.1849, -0.1039],\n",
" [-0.0075, 0.0321, -0.0046, ..., 0.0154, 0.0506, -0.0110],\n",
" [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],\n",
" ...,\n",
" [ 0.0878, -0.0887, -0.0550, ..., -0.0982, -0.1763, -0.0175],\n",
" [ 0.1193, 0.0695, -0.0637, ..., 0.0060, 0.0391, -0.1508],\n",
" [ 0.0294, 0.0186, 0.0140, ..., 0.0327, 0.0128, -0.0487]],\n",
" grad_fn=<AddBackward0>)"
]
},
"execution_count": 511,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"x = net.forward(x1, x0, edge_index, snap_net, update_att=True)\n",
"x0 = x1\n",
"x"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "mpi",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.13"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from typing import *
import os
import time
import psutil
from starrygl.core import NodeProbe
from starrygl.nn.convs.s_gat_conv import GATConv
from starrygl.graph import DistGraph
from starrygl.parallel import init_process_group, convert_parallel_model
from starrygl.parallel import compute_gcn_norm, SyncBatchNorm, with_nccl
from starrygl.utils import train_epoch, eval_epoch, partition_load, main_print, sync_print
class Net(nn.Module):
def __init__(self,
in_channels: int,
hidden_channels: int,
out_channels: int,
num_layers: int,
heads: int = 8,
) -> None:
super().__init__()
self.in_channels = in_channels
self.hidden_channels = hidden_channels
self.out_channels = out_channels
self.num_layers = num_layers
last_ch = in_channels
self.convs = nn.ModuleList()
for i in range(num_layers):
out_ch = out_channels if i == num_layers - 1 else hidden_channels
self.convs.append(GATConv(last_ch, out_ch, heads))
last_ch = out_ch
def forward(self, g: DistGraph, probe: NodeProbe):
for i in range(self.num_layers):
shrink = g.cache.layers[i].get_shrink_data()
if i == 0:
x = g.cache.fetch_pull_tensor(i)
else:
# x = torch.ones(dst_idx.size(0), self.hidden_channels, device=x.device)
x, _ = g.cache.update_cache(i, x, dst_idx)
dst_idx = shrink.dst_idx
# print(shrink.edge_index[1].unique().size(0), shrink.dst_size)
# print(x.size(0), shrink.edge_index[0].unique().size(0), shrink.src_size)
# e = x[shrink.edge_index[0]]
# from torch_scatter import scatter
from starrygl.utils.printer import main_print
main_print(x.requires_grad)
# print(e.size(0), shrink.edge_index[1].size(0), shrink.dst_size, shrink.edge_index[1].max().item())
# scatter(e, shrink.edge_index[1], dim=0, dim_size=shrink.dst_size)
# return
x = self.convs[i](shrink, x)
x = F.relu(x)
x = probe.apply(i, x, dst_idx)
return x
if __name__ == "__main__":
# Start the distributed process group and assign the compute device
device = init_process_group(backend="nccl")
# Load the dataset
pdata = partition_load("./cora", algo="metis").to(device)
g = DistGraph(
ids=pdata.ids,
edge_index=pdata.edge_index,
num_features=64,
num_layers=3,
cache_device="cpu"
).to(device)
cached_data_0, _ = g.route.forward_a2a(pdata.x).get()
g.cache.replace_layer_data(0, cached_data_0.to(g.cache.cache_device))
probe = NodeProbe(
g.dst_size,
num_layers=3,
num_samples=128,
).to(device)
probe.assign_message_cache(g.cache)
probe.layers[-1].warmup_sample()
net = Net(pdata.num_features, 64, pdata.num_classes, num_layers=3).to(device)
net(g, probe).sum().backward()
\ No newline at end of file
import torch
import torch.distributed as dist
from starrygl.parallel import init_process_group
from typing import *
if __name__ == "__main__":
device = init_process_group(backend="gloo")
rank = dist.get_rank()
world_size = dist.get_world_size()
assert world_size == 4
xs = [torch.randn(100+i*10, device=device) for i in range(world_size)]
send_sizes = torch.tensor([x.size(0) for x in xs], dtype=torch.long, device=device)
recv_sizes = torch.zeros_like(send_sizes)
dist.all_to_all_single(recv_sizes, send_sizes)
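# Sketch of the next step (assumption): with the per-peer sizes exchanged, receive
# buffers can be allocated and the payload exchanged pairwise; gloo has no native
# all_to_all for this, so the isend/irecv pattern mirrors all_to_all_v in this commit.
recv_list = [torch.empty(int(n), device=device) for n in recv_sizes.tolist()]
works = []
for i in range(world_size):
    if i == rank:
        recv_list[i].copy_(xs[i])
        continue
    works.append(dist.isend(xs[i], i))
    works.append(dist.irecv(recv_list[i], i))
for w in works:
    w.wait()
print(rank, [t.size(0) for t in recv_list])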
\ No newline at end of file
# import torch
# from starrygl.core.acopy import get_executor, AsyncCopyWorkBase, List
# if __name__ == "__main__":
# data = torch.arange(1024*1024*1024, dtype=torch.float32).view(-1, 1).repeat(1, 8)
# # buffer = torch.arange(1000, dtype=torch.float32).view(-1, 1).repeat(1, 8).pin_memory()
# index = torch.arange(1024*1024 * 256, dtype=torch.long).cuda()
# values = torch.ones(1024*1024 * 256, 8, dtype=torch.float32).cuda()
# pull_works: List[AsyncCopyWorkBase] = []
# push_works: List[AsyncCopyWorkBase] = []
# for s in range(0, index.size(0), 1024 * 1024 * 256):
# t = min(s + 1024 * 1024 * 256, index.size(0))
# idx = index[s:t]
# val = values[idx]
# pullw = get_executor().async_pull(data, idx)
# pushw = get_executor().async_push(data, idx, val)
# pull_works.append(pullw)
# push_works.append(pushw)
# pull_time = 0.0
# for w in pull_works:
# val = w.get()
# pull_time += w.time_used()
# # print(val)
# push_time = 0.0
# for w in push_works:
# w.get()
# push_time += w.time_used()
# # print(data)
# print(pull_time, push_time)
\ No newline at end of file
import torch
import torch.autograd as autograd
from torch import Tensor
from typing import *
class AFunction(autograd.Function):
@staticmethod
def forward(
ctx: autograd.function.FunctionCtx,
a: Tensor,
b: Tensor,
):
return a, b.float()
@staticmethod
def backward(
ctx: autograd.function.FunctionCtx,
a: Tensor,
b: Tensor,
):
print(a, b)
return a, b.bool()
if __name__ == "__main__":
x = torch.rand(10).requires_grad_()
y = torch.ones(10).bool()
a, b = AFunction.apply(x, y)
(a + b).sum().backward()
print(a, b)
print(x.grad, y.grad)
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
import torch.distributed.rpc as rpc
from torch import Tensor
from typing import *
import os
import time
import psutil
from starrygl.loader import NodeLoader, TensorBuffer, RouteContext
from starrygl.parallel import init_process_group
from starrygl.utils import partition_load, main_print, sync_print
if __name__ == "__main__":
# Start the distributed process group and assign the compute device
device = init_process_group(backend="nccl")
# Load the dataset
pdata = partition_load("./cora", algo="metis")
loader = NodeLoader(pdata.ids, pdata.edge_index, device)
hidden_size = 64
buffers: List[TensorBuffer] = [
TensorBuffer(pdata.num_features, loader.src_size, loader.route),
TensorBuffer(hidden_size, loader.src_size, loader.route),
]
buffers[0].data[:loader.dst_size] = pdata.x
buffers[0].broadcast()
for handle in loader.iter(128):
with RouteContext() as ctx:
sync_print(handle.src_size, handle.dst_size)
dst_feats = handle.get_dst_feats(buffers[0]).wait()
ext_fut = handle.get_ext_feats(buffers[1])
sync_print(dst_feats.size())
src_feats, fut = handle.push_and_pull(dst_feats[:,:hidden_size], ext_fut, buffers[1])
ctx.add_futures(fut)
sync_print(src_feats.size())
rpc.shutdown()
\ No newline at end of file
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
import torch.distributed.rpc as rpc
from torch import Tensor
from typing import *
import os
import time
import psutil
# from starrygl.loader import DataLoader, TensorBuffers
from starrygl.loader.route import Route
from starrygl.loader.buffer import TensorBuffer
from starrygl.loader.utils import init_local_edge_index, local_partition_fn
from starrygl.parallel import init_process_group
from starrygl.utils import partition_load, main_print, sync_print
all_nparts = [
[0, 1],
[2, 3, 4],
[5, 6],
]
all_eparts = [
[
[2, 0], [1, 0], [0, 1], [3, 1], [4, 1],
],
[
[0, 2], [3, 2], [5, 2], [1, 3], [2, 3],
[4, 3], [6, 3], [1, 4], [3, 4],
],
[
[2, 5], [6, 5], [5, 6], [3, 6],
],
]
# def get_route_table(device = "cpu") -> Tuple[RouteTable, Tensor]:
# assert dist.get_world_size() == 3
# rank = dist.get_rank()
# ids = torch.tensor(all_nparts[rank], dtype=torch.long, device=device)
# edge_index = torch.tensor(all_eparts[rank], dtype=torch.long, device=device).t()
# src_ids, local_edge_index = init_local_edge_index(ids, edge_index)
# return RouteTable(src_ids, ids.size(0)), local_edge_index
# def get_features(device = "cpu") -> Tensor:
# assert dist.get_world_size() == 3
# rank = dist.get_rank()
# return torch.tensor(all_nparts[rank], dtype=torch.float32, device=device) + 1.0
if __name__ == "__main__":
# Start the distributed process group and assign the compute device
device = init_process_group(backend="nccl")
# Load the dataset
pdata = partition_load("./cora", algo="metis").to(device)
route, local_edge_index = Route.from_edge_index(pdata.ids, pdata.edge_index)
# route_table = RouteTable(src_ids, dst_size)
# sync_print(route_table.fw_0)
# route_table = get_route_table()[0]
# for name, buffer in route_table.named_buffers():
# if name.startswith("fw_") or name.startswith("bw_"):
# sync_print(name, buffer.tolist())
tb = TensorBuffer(pdata.num_features, route.src_size, route).cpu()
print(tb.data.device, pdata.x.device)
tb.data[:pdata.num_nodes] = pdata.x
dist.barrier()
tb.broadcast()
print(tb.data[pdata.num_nodes:].sum(dim=-1))
# assert (tb.data[:pdata.num_nodes] == pdata.x).all()
# local_src_ids = torch.arange(route.src_size)
# for local_ids, remote_ids in route.parts_iter(local_src_ids):
# sync_print(local_ids.size(), remote_ids.size())
# router, edge_index = get_route_table()
# x = get_features()[:,None]
# buffer = TensorBuffers(router, channels=[1])
# buffer.local_set(0, x)
# sync_print(buffer.local_get(0).view(-1).tolist())
# buffer.broadcast()
# main_print("="*64)
# sync_print(buffer.local_get(0).view(-1).tolist())
# main_print("+"*64)
# buffer.zero_grad()
# buffer.local_set(0, x)
# sync_print(buffer.local_get(0).view(-1).tolist())
# buffer.broadcast2()
# main_print("="*64)
# sync_print(buffer.local_get(0).view(-1).tolist())
rpc.shutdown()
# buffer.remote_get(0, 0, async_op=True).then()
# num_parts = 50
# node_parts = local_partition_fn(dst_size, local_edge_index, num_parts)
# feat_buffer = TensorBuffers(src_ids.size(0), [pdata.num_features] * 4)
# grad_buffer = TensorBuffers(src_ids.size(0), [pdata.num_features] * 4)
# node_time = torch.rand(dst_size)
# edge_time = torch.rand(local_edge_index.size(1))
# edge_attr = torch.randn(local_edge_index.size(1), 10)
# loader = DataLoader(node_parts, feat_buffer, grad_buffer, local_edge_index, edge_attr, edge_time, node_time)
# feat_buffer.set(0, None, collect_feat0(src_ids, src_ids[:dst_size], pdata.x))
# for handle, edge_index, edge_attr in loader.iter(2, layer_id=0, device=device, filter=lambda x: x < 0.5):
# print(handle.feat0.sum(dim=-1))
# print(edge_index)
# print(edge_attr.sum(dim=-1))
# handle.push_and_pull(handle.feat0[:handle.dst_size], 0)
# assert (handle.fetch_feat() == handle.feat0).all()
\ No newline at end of file
import torch
import torch.distributed as dist
from torch import Tensor
from typing import *
from starrygl.parallel import init_process_group
from starrygl.utils.printer import main_print, sync_print
if __name__ == "__main__":
import os
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
device = init_process_group(backend="mpi")
sync_print(dist.get_backend())
\ No newline at end of file
import torch
import torch.distributed as dist
from torch import Tensor
from typing import *
from starrygl.parallel import init_process_group
from starrygl.utils.printer import main_print, sync_print
from multiprocessing.pool import ThreadPool
num_threads = 1
num_tasks = 5
pool = ThreadPool(num_threads)
if __name__ == "__main__":
import os
# os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
device = init_process_group(backend="nccl")
group = dist.new_group()
stream = torch.cuda.Stream(device)
handles = []
for i in range(num_tasks):
x = torch.ones(10, dtype=torch.float32, device=device)
def run(x):
# stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(stream):
group.allreduce(x)
return x
stream.wait_stream(torch.cuda.current_stream())
handles.append(pool.apply_async(run, (x,)))
x = torch.ones(3, dtype=torch.float32, device=device)
dist.all_reduce(x)
sync_print(x.tolist())
for i in range(num_tasks):
main_print("="*64)
x = handles[i].get()
assert (handles[i].get() == x).all()
sync_print(x.tolist())
import torch
import torch.distributed as dist
from torch import Tensor
from typing import *
from starrygl.graph import DistGraph
from starrygl.core.gather import Gather
from starrygl.core.straight import Straight
from starrygl.parallel import init_process_group, compute_degree
from starrygl.utils.printer import main_print, sync_print
from torch_scatter import scatter_sum
all_nparts = [
[0, 1],
[2, 3, 4],
[5, 6],
]
all_eparts = [
[
[2, 0], [1, 0], [0, 1], [3, 1], [4, 1],
],
[
[0, 2], [3, 2], [5, 2], [1, 3], [2, 3],
[4, 3], [6, 3], [1, 4], [3, 4],
],
[
[2, 5], [6, 5], [5, 6], [3, 6],
],
]
def get_distgraph(device = "cpu") -> DistGraph:
assert dist.get_world_size() == 3
rank = dist.get_rank()
ids = torch.tensor(all_nparts[rank], dtype=torch.long, device=device)
edge_index = torch.tensor(all_eparts[rank], dtype=torch.long, device=device).t()
return DistGraph(ids, edge_index)
def get_features(device = "cpu") -> Tensor:
assert dist.get_world_size() == 3
rank = dist.get_rank()
return torch.tensor(all_nparts[rank], dtype=torch.float32, device=device) + 1.0
def reduce(g: DistGraph, x: Tensor) -> Tensor:
edge_index = g.edge_index
x = x[edge_index[0]]
x = scatter_sum(x, edge_index[1], dim=0, dim_size=g.dst_size)
return x
if __name__ == "__main__":
import os
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
device = init_process_group(backend="nccl")
g = get_distgraph(device)
# sync_print(g.src_ids.tolist())
# main_print("="*64)
# sync_print(g.dst_ids.tolist())
# main_print("="*64)
# sync_print(g.route.forward_routes[0])
# main_print("="*64)
in_degree, out_degree = compute_degree(g)
# sync_print(in_degree.tolist())
# sync_print(out_degree.tolist())
x = get_features(device).requires_grad_()
# straight = Straight(g.dst_size, num_samples=1)
# gather = Gather(g.src_size).to(device)
# sync_print(g)
# sync_print(gather.get_buffer("last_embd").tolist())
# main_print("="*64)
sync_print(x.tolist(), g.route.src_size)
main_print("="*64)
a = torch.arange(g.route.src_size, dtype=torch.long, device=device)
# a = None
# z, a = gather(x, None, route=g.route)
z, a = g.route.apply(x, a, {}, async_op=False)
sync_print(z.tolist(), "" if a is None else a.tolist(), g.route.dst_size)
main_print("="*64)
sync_print(g.edge_index[0])
main_print("="*64)
z = reduce(g, z)
sync_print(z.tolist())
main_print("="*64)
# z = straight(x, a, g)
# sync_print(z.tolist())
# main_print("="*64)
loss = z.sum()
sync_print(loss.item())
main_print("="*64)
loss.backward()
sync_print(x.grad.tolist())
main_print("="*64)
sync_print(f"time_used: {g.route.total_time_used}ms")
\ No newline at end of file
import torch
import torch.distributed as dist
from torch import Tensor
from typing import *
from starrygl.core.route import Route
from starrygl.graph.utils import init_local_edge_index
from starrygl.parallel import init_process_group
from starrygl.utils.printer import main_print, sync_print
from torch_scatter import scatter_sum
all_nparts = [
[0, 1],
[2, 3, 4],
[5, 6],
]
all_eparts = [
[
[2, 0], [1, 0], [0, 1], [3, 1], [4, 1],
],
[
[0, 2], [3, 2], [5, 2], [1, 3], [2, 3],
[4, 3], [6, 3], [1, 4], [3, 4],
],
[
[2, 5], [6, 5], [5, 6], [3, 6],
],
]
def get_route(device = "cpu") -> Tuple[Route, Tensor]:
assert dist.get_world_size() == 3
rank = dist.get_rank()
ids = torch.tensor(all_nparts[rank], dtype=torch.long, device=device)
edge_index = torch.tensor(all_eparts[rank], dtype=torch.long, device=device).t()
src_ids, local_edge_index = init_local_edge_index(ids, edge_index)
return Route(ids, src_ids), local_edge_index
def get_features(device = "cpu") -> Tensor:
assert dist.get_world_size() == 3
rank = dist.get_rank()
return torch.tensor(all_nparts[rank], dtype=torch.float32, device=device) + 1.0
def reduce(x: Tensor, edge_index: Tensor, dim_size: int) -> Tensor:
x = x[edge_index[0]]
x = scatter_sum(x, edge_index[1], dim=0, dim_size=dim_size)
return x
if __name__ == "__main__":
device = init_process_group(backend="mpi")
route, edge_index = get_route(device)
src_size = route.dst_size
dst_size = route.src_size
# in_degree, out_degree = compute_degree(g)
# sync_print(in_degree.tolist())
# sync_print(out_degree.tolist())
x = get_features(device).requires_grad_()
sync_print(x.tolist())
main_print("="*64)
z, _ = route.apply(x, None)
sync_print(z.tolist())
main_print("="*64)
h = reduce(z, edge_index, dst_size)
sync_print(h.tolist())
main_print("="*64)
h.sum().backward()
sync_print(x.grad.tolist())
main_print("="*64)
import torch
import torch.distributed as dist
import torch.distributed.rpc as rpc
import os
from torch import Tensor
from typing import *
from starrygl.parallel import init_process_group, all_gather_remote_objects
from starrygl.utils.printer import main_print, sync_print
class A:
def __init__(self, device) -> None:
self.id = rpc.get_worker_info().id
self.data = torch.tensor([self.id], device=device)
@staticmethod
def get_data(rref: rpc.RRef) -> Tensor:
self: A = rref.local_value()
return self.data
def set_remote_a(self, *rrefs: rpc.RRef):
self.rrefs = tuple(rrefs)
def display(self):
sync_print(f"id = {self.id}")
def gather_data(self, device) -> List[Tensor]:
world_size = dist.get_world_size()
futs: List[torch.futures.Future] = []
for i in range(world_size):
f = rpc.rpc_async(self.rrefs[i].owner(), A.get_data, args=(self.rrefs[i],))
futs.append(f)
results: List[Tensor] = []
for f in futs:
f.wait()
results.append(f.value())
return results
if __name__ == "__main__":
device = init_process_group("nccl")
a = A(device)
a.set_remote_a(*all_gather_remote_objects(a))
a.display()
for t in a.gather_data(device):
sync_print(t, t.device)
rpc.shutdown()
\ No newline at end of file
import torch
import torch.nn as nn
import torch.distributed as dist
from torch import Tensor
from typing import *
from starrygl.core.gather import Gather
from starrygl.parallel import init_process_group
from starrygl.parallel.sync_bn import SyncBatchNorm
from starrygl.graph import DistGraph, EID, SID, DID
from starrygl.utils.printer import main_print, sync_print
if __name__ == "__main__":
device = init_process_group(backend="nccl")
bn = SyncBatchNorm(100).to(device)
bn.train()
x = torch.randn(10, 100, dtype=torch.float64).to(device).requires_grad_()
sync_print(x)
from torch.autograd import gradcheck
sync_print(gradcheck(lambda x: bn(x), x))
# sync_print(bn(x))
import torch
# from .lib import libstarrygl_ops as ops
\ No newline at end of file
......@@ -11,7 +11,11 @@ def all_to_all(
group: Optional[Any] = None,
):
assert len(output_tensor_list) == len(input_tensor_list)
backend = dist.get_backend()
if group is None:
group = dist.distributed_c10d._get_default_group()
backend = dist.get_backend(group)
if backend == "nccl":
dist.all_to_all(
output_tensor_list=output_tensor_list,
......@@ -25,7 +29,7 @@ def all_to_all(
group=group,
)
else:
assert backend == "gloo"
assert backend == "gloo", f"backend must be nccl, mpi or gloo"
rank = dist.get_rank()
world_size = dist.get_world_size()
......@@ -41,58 +45,3 @@ def all_to_all(
dist.batch_isend_irecv(p2p_op_list)
output_tensor_list[rank][:] = input_tensor_list[rank]
\ No newline at end of file
# def with_nccl() -> bool:
# return dist.get_backend() == "nccl"
# def with_gloo() -> bool:
# return dist.get_backend() == "gloo"
# class Works:
# def __init__(self, *works) -> None:
# self._works = list(works)
# self._waited = False
# def wait(self) -> None:
# assert not self._waited
# for w in self._works:
# if w is None:
# continue
# w.wait()
# self._waited = True
# def push(self, *works) -> None:
# assert not self._waited
# self._works.extend(works)
# def all_to_all(
# output_tensor_list: List[Tensor],
# input_tensor_list: List[Tensor],
# group: Optional[Any] = None,
# ) -> Works:
# assert len(output_tensor_list) == len(input_tensor_list)
# if with_nccl():
# work = dist.all_to_all(
# output_tensor_list=output_tensor_list,
# input_tensor_list=input_tensor_list,
# group=group,
# async_op=True,
# )
# return Works(work)
# elif with_gloo():
# rank = dist.get_rank()
# world_size = dist.get_world_size()
# works = Works()
# for i in range(1, world_size):
# send_i = (rank + i) % world_size
# recv_i = (rank - i + world_size) % world_size
# send_w = dist.isend(input_tensor_list[send_i], send_i, group=group)
# recv_w = dist.irecv(output_tensor_list[recv_i], recv_i, group=group)
# works.push(recv_w, send_w)
# output_tensor_list[rank][:] = input_tensor_list[rank]
# return works
# else:
# backend = dist.get_backend()
# raise RuntimeError(f"unsupported backend: {backend}")
import torch
import torch.nn as nn
import torch.distributed as dist
import torch.distributed.rpc as rpc
import os
from torch import Tensor
from typing import *
from .context import DistributedContext
from .utils import (
TensorAccessor,
DistributedTensor,
DistIndex,
)
def init_distributed_context(backend: str = "gloo") -> DistributedContext:
return DistributedContext.init(backend)
import torch
import torch.distributed as dist
from torch import Tensor
from typing import *
class BatchWork:
def __init__(self, works, buffer_tensor_list) -> None:
if isinstance(works, (list, tuple)):
assert len(works) // 2 == len(buffer_tensor_list)
self._works = works
self._buffer_tensor_list = buffer_tensor_list
else:
assert buffer_tensor_list is None
self._works = works
self._buffer_tensor_list = None
def wait(self):
if self._works is None:
return
if isinstance(self._works, (list, tuple)):
for i, w in enumerate(self._works):
w.wait()
if i % 2 == 0:
continue
out, buf = self._buffer_tensor_list[i // 2]
if buf is not None:
out.copy_(buf)
else:
self._works.wait()
self._works = None
self._buffer_tensor_list = None
def all_to_all_v(
output_tensor_list: List[Tensor],
input_tensor_list: List[Tensor],
group: Optional[Any] = None,
async_op: bool = False,
):
rank = dist.get_rank(group)
world_size = dist.get_world_size(group)
assert len(output_tensor_list) == world_size
assert len(input_tensor_list) == world_size
if group is None:
group = dist.distributed_c10d._get_default_group()
backend = dist.get_backend(group)
if backend == "nccl":
work = dist.all_to_all(
output_tensor_list=output_tensor_list,
input_tensor_list=input_tensor_list,
group=group,
async_op=async_op,
)
return BatchWork(work, None) if async_op else None
elif backend == "mpi":
work = dist.all_to_all(
output_tensor_list=output_tensor_list,
input_tensor_list=input_tensor_list,
group=group,
async_op=async_op,
)
return BatchWork(work, None) if async_op else None
else:
assert backend == "gloo", f"backend must be nccl, mpi or gloo"
p2p_op_works = []
buffer_tensor_list = []
for i in range(1, world_size):
send_i = (rank + i) % world_size
recv_i = (rank - i + world_size) % world_size
send_t = input_tensor_list[send_i]
recv_t = output_tensor_list[recv_i]
if send_t.is_cuda:
send_t = send_t.cpu()
if recv_t.is_cuda:
recv_b = torch.empty_like(recv_t, device="cpu")
buffer_tensor_list.append((recv_t, recv_b))
else:
recv_b = recv_t
buffer_tensor_list.append((recv_t, None))
p2p_op_works.extend([
dist.isend(send_t, send_i, group=group),
dist.irecv(recv_b, recv_i, group=group),
])
work = BatchWork(p2p_op_works, buffer_tensor_list)
output_tensor_list[rank].copy_(input_tensor_list[rank])
if async_op:
return work
work.wait()
# else:
# assert backend == "gloo", f"backend must be nccl, mpi or gloo"
# rank = dist.get_rank(group)
# world_size = dist.get_world_size(group)
# p2p_op_list: List[dist.P2POp] = []
# for i in range(1, world_size):
# send_i = (rank + i) % world_size
# recv_i = (rank - i + world_size) % world_size
# p2p_op_list.extend([
# dist.P2POp(dist.isend, input_tensor_list[send_i], send_i, group=group),
# dist.P2POp(dist.irecv, output_tensor_list[recv_i], recv_i, group=group),
# ])
# dist.batch_isend_irecv(p2p_op_list)
# output_tensor_list[rank][:] = input_tensor_list[rank]
\ No newline at end of file
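# Usage sketch for all_to_all_v (assumptions: launched with torchrun so init_process_group
# can read the usual env vars; every per-peer tensor is pre-allocated with a size agreed on
# beforehand, which is what the size-exchange helpers in this commit are for).
if __name__ == "__main__":
    dist.init_process_group("gloo")
    rank, ws = dist.get_rank(), dist.get_world_size()
    send = [torch.full((2,), float(rank)) for _ in range(ws)]
    recv = [torch.empty(2) for _ in range(ws)]
    all_to_all_v(recv, send)  # blocking form; async_op=True returns a BatchWork to wait on
    assert all(recv[i].eq(float(i)).all() for i in range(ws))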
import torch
import torch.nn as nn
import torch.distributed as dist
import torch.distributed.rpc as rpc
import os
from torch import Tensor
from typing import *
import logging
from .cclib import all_to_all_v
from .rpc import rpc_remote_call
__all__ = [
"DistributedContext",
]
class DistributedContext:
@staticmethod
def init(
backend: str,
use_gpu: Optional[bool] = None,
rpc_gpu: Optional[bool] = None,
) -> 'DistributedContext':
if DistributedContext.is_initialized():
raise RuntimeError("not allowed to call init method twice.")
rank = int(os.getenv("RANK") or os.getenv("OMPI_COMM_WORLD_RANK"))
world_size = int(os.getenv("WORLD_SIZE") or os.getenv("OMPI_COMM_WORLD_SIZE"))
local_rank = os.getenv("LOCAL_RANK") or os.getenv("OMPI_COMM_WORLD_LOCAL_RANK")
if local_rank is None:
logging.warning(f"LOCAL_RANK has not been set, using the default value 0.")
os.environ["LOCAL_RANK"] = local_rank = "0"
local_rank = int(local_rank)
backend = backend.lower()
if use_gpu is None:
use_gpu = False
if backend == "nccl" or backend == "mpi":
use_gpu = True
else:
use_gpu = bool(use_gpu)
if rpc_gpu is None:
rpc_gpu = use_gpu
else:
rpc_gpu = bool(rpc_gpu)
master_addr = os.environ["MASTER_ADDR"]
master_port = int(os.environ["MASTER_PORT"])
ccl_init_url = f"tcp://{master_addr}:{master_port}"
rpc_init_url = f"tcp://{master_addr}:{master_port + 1}"
ctx = DistributedContext(
backend=backend,
ccl_init_method=ccl_init_url,
rpc_init_method=rpc_init_url,
rank=rank, world_size=world_size,
local_rank=local_rank,
use_gpu=use_gpu, rpc_gpu=rpc_gpu,
)
_set_default_dist_context(ctx)
return ctx
@staticmethod
def get_default_context() -> 'DistributedContext':
ctx = _get_default_dist_context()
if ctx is None:
raise RuntimeError("please call the init method first.")
return ctx
@staticmethod
def is_initialized() -> bool:
return _get_default_dist_context() is not None
def __init__(self,
backend: str,
ccl_init_method: str,
rpc_init_method: str,
rank: int, world_size: int,
local_rank: int,
use_gpu: bool, rpc_gpu: bool,
) -> None:
if use_gpu:
device = torch.device(f"cuda:{local_rank}")
else:
device = torch.device("cpu")
dist.init_process_group(
backend=backend,
init_method=ccl_init_method,
rank=rank, world_size=world_size,
)
rpc_backend_options = rpc.TensorPipeRpcBackendOptions()
rpc_backend_options.init_method = rpc_init_method
if use_gpu and rpc_gpu:
device_maps = torch.zeros(world_size, dtype=torch.long, device=device)
device_maps[rank] = local_rank
dist.all_reduce(device_maps, op=dist.ReduceOp.SUM)
for i, dev in enumerate(device_maps.tolist()):
rpc_backend_options.set_device_map(
to=f"worker{i}",
device_map={local_rank: dev},
)
rpc.init_rpc(
name=f"worker{rank}",
rank=rank, world_size=world_size,
rpc_backend_options=rpc_backend_options,
)
self._local_rank = local_rank
self._compute_device = device
self.__temp_ag_remote_object: Optional[rpc.RRef] = None
def shutdown(self):
rpc.shutdown()
@property
def rank(self) -> int:
return dist.get_rank()
@property
def world_size(self) -> int:
return dist.get_world_size()
@property
def local_rank(self) -> int:
return self._local_rank
@property
def device(self) -> torch.device:
return self._compute_device
def get_default_group(self):
return dist.distributed_c10d._get_default_group()
def get_default_store(self):
return dist.distributed_c10d._get_default_store()
def get_worker_info(self, rank: Optional[int] = None) -> rpc.WorkerInfo:
rank = dist.get_rank() if rank is None else rank
return rpc.get_worker_info(f"worker{rank}")
def remote_call(self, method, rref: rpc.RRef, *args, **kwargs):
return rpc_remote_call(method, rref, *args, **kwargs)
def all_to_all_v(self,
output_tensor_list: List[Tensor],
input_tensor_list: List[Tensor],
group: Optional[Any] = None,
async_op: bool = False,
):
return all_to_all_v(
output_tensor_list,
input_tensor_list,
group=group,
async_op=async_op,
)
def all_to_all_g(self,
input_tensor_list: List[Tensor],
group: Optional[Any] = None,
async_op: bool = False,
):
send_sizes = [t.size(0) for t in input_tensor_list]
recv_sizes = self.get_all_to_all_recv_sizes(send_sizes, group)
output_tensor_list: List[Tensor] = []
for s, t in zip(recv_sizes, input_tensor_list):
output_tensor_list.append(
torch.empty(s, *t.shape[1:], dtype=t.dtype, device=t.device),
)
work = all_to_all_v(
output_tensor_list,
input_tensor_list,
group=group,
async_op=async_op,
)
if async_op:
assert work is not None
return output_tensor_list, work
else:
return output_tensor_list
def get_all_to_all_recv_sizes(self,
send_sizes: List[int],
group: Optional[Any] = None,
) -> List[int]:
world_size = dist.get_world_size(group)
assert len(send_sizes) == world_size
if dist.get_backend(group) == "gloo":
send_t = torch.tensor(send_sizes, dtype=torch.long)
else:
send_t = torch.tensor(send_sizes, dtype=torch.long, device=self.device)
recv_t = torch.empty_like(send_t)
dist.all_to_all_single(recv_t, send_t, group=group)
return recv_t.tolist()
def all_gather_remote_objects(self, obj: Any) -> List[rpc.RRef]:
if not isinstance(obj, rpc.RRef):
obj = rpc.RRef(obj)
self.__temp_ag_remote_object = obj
dist.barrier()
futs: List[torch.futures.Future] = []
for i in range(self.world_size):
info = rpc.get_worker_info(f"worker{i}")
futs.append(rpc.rpc_async(info, DistributedContext._remote_object))
rrefs: List[rpc.RRef] = []
for f in futs:
f.wait()
rrefs.append(f.value())
dist.barrier()
self.__temp_ag_remote_object = None
return rrefs
@staticmethod
def _remote_object():
ctx = DistributedContext.get_default_context()
return ctx.__temp_ag_remote_object
def sync_print(self, *args, **kwargs):
for i in range(self.world_size):
if i == self.rank:
print(f"rank {self.rank}:", *args, **kwargs)
dist.barrier()
def main_print(self, *args, **kwargs):
if self.rank == 0:
print(*args, **kwargs)
dist.barrier()
_DEFAULT_DIST_CONTEXT: Optional['DistributedContext'] = None
def _get_default_dist_context():
global _DEFAULT_DIST_CONTEXT
return _DEFAULT_DIST_CONTEXT
def _set_default_dist_context(ctx):
global _DEFAULT_DIST_CONTEXT
_DEFAULT_DIST_CONTEXT = ctx
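# --- hedged usage sketch (not part of the original file) ---
# Illustrates exchanging variable-length tensors with all_to_all_g once the
# default context exists (created elsewhere via the init method). The tensor
# contents and sizes below are arbitrary placeholders.
def _example_all_to_all_g() -> List[Tensor]:
    ctx = DistributedContext.get_default_context()
    # every rank sends a differently sized chunk to each peer
    send = [
        torch.arange(i + 1, dtype=torch.float32, device=ctx.device)
        for i in range(ctx.world_size)
    ]
    recv = ctx.all_to_all_g(send)  # recv[i] was sent by rank i
    ctx.sync_print("received sizes:", [t.numel() for t in recv])
    return recv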
import torch
import torch.distributed.rpc as rpc
from torch import Tensor
from typing import *
def rpc_remote_call(method, rref: rpc.RRef, *args, **kwargs):
args = (method, rref) + args
return rpc.rpc_async(rref.owner(), rpc_method_call, args=args, kwargs=kwargs)
def rpc_method_call(method, rref: rpc.RRef, *args, **kwargs):
self = rref.local_value()
return method(self, *args, **kwargs)
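# --- hedged usage sketch (not part of the original file) ---
# rpc_remote_call ships the unbound method plus the RRef to the RRef's owner,
# where rpc_method_call rebinds it to the local value. Here we ask for the
# first dimension of a remote tensor; `rref` is assumed to wrap a Tensor.
def _example_remote_size(rref: rpc.RRef) -> int:
    fut = rpc_remote_call(Tensor.size, rref, dim=0)
    return fut.wait()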
from typing import Any
import torch
import torch.distributed.rpc as rpc
from torch import Tensor
from torch.types import Number
from typing import *
class TensorAccessor:
def __init__(self, data: Tensor) -> None:
from .context import DistributedContext
self._data = data
self._ctx = DistributedContext.get_default_context()
self._rref = rpc.RRef(data)
@property
def data(self):
return self._data
@property
def rref(self):
return self._rref
@property
def ctx(self):
return self._ctx
def all_gather_rrefs(self) -> List[rpc.RRef]:
return self.ctx.all_gather_remote_objects(self.rref)
def async_index_select(self, dim: int, index: Tensor, rref: Optional[rpc.RRef] = None):
if rref is None:
rref = self.rref
return self.ctx.remote_call(Tensor.index_select, rref, dim=dim, index=index)
def async_index_copy_(self, dim: int, index: Tensor, source: Tensor, rref: Optional[rpc.RRef] = None):
if rref is None:
rref = self.rref
return self.ctx.remote_call(TensorAccessor._index_copy_, rref, dim=dim, index=index, source=source)
def async_index_add_(self, dim: int, index: Tensor, source: Tensor, rref: Optional[rpc.RRef] = None):
if rref is None:
rref = self.rref
return self.ctx.remote_call(TensorAccessor._index_add_, rref, dim=dim, index=index, source=source)
def async_index_fill_(self, dim: int, index: Tensor, value: Number, rref: Optional[rpc.RRef] = None):
if rref is None:
rref = self.rref
return self.ctx.remote_call(TensorAccessor._index_fill_, rref, dim=dim, index=index, value=value)
def async_fill_(self, value: Number, rref: Optional[rpc.RRef] = None):
if rref is None:
rref = self.rref
return self.ctx.remote_call(TensorAccessor._fill_, rref, value=value)
def async_zero_(self, rref: Optional[rpc.RRef] = None):
if rref is None:
rref = self.rref
return self.ctx.remote_call(TensorAccessor._zero_, rref)
@staticmethod
def _index_copy_(self: Tensor, dim: int, index: Tensor, source: Tensor):
self.index_copy_(dim=dim, index=index, source=source)
@staticmethod
def _index_add_(self: Tensor, dim: int, index: Tensor, source: Tensor):
self.index_add_(dim=dim, index=index, source=source)
@staticmethod
def _index_fill_(self: Tensor, dim: int, index: Tensor, value: Number):
self.index_fill_(dim=dim, index=index, value=value)
@staticmethod
def _fill_(self: Tensor, value: Number):
self.fill_(value=value)
@staticmethod
def _zero_(self: Tensor):
self.zero_()
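# --- hedged usage sketch (not part of the original file) ---
# Wrap a local tensor, gather every worker's RRef (a collective call), and
# issue asynchronous index ops; all async_* helpers return torch futures.
# The indices below are placeholders, and data is assumed to be a CPU tensor
# so that plain CPU index tensors are valid on every worker.
def _example_tensor_accessor(data: Tensor, index: Tensor) -> Tensor:
    acc = TensorAccessor(data)
    rrefs = acc.all_gather_rrefs()                              # one RRef per worker
    rows = acc.async_index_select(0, index, rrefs[0]).wait()    # read from worker 0
    acc.async_index_fill_(0, index, 0.0).wait()                 # overwrite the same rows locally
    return rows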
class DistInt:
def __init__(self, sizes: List[int]) -> None:
self._data = tuple([int(t) for t in sizes])
self._total = sum(self._data)
def __getitem__(self, idx: int) -> int:
return self._data[idx]
def __call__(self) -> int:
return self._total
class DistIndex:
def __init__(self, index: Tensor, part_ids: Optional[Tensor] = None) -> None:
if part_ids is None:
self.data = index.long()
else:
index, part_ids = index.long(), part_ids.long()
self.data = (index & 0xFFFFFFFFFFFF) | ((part_ids & 0xFFFF) << 48)
@property
def loc(self) -> Tensor:
return self.data & 0xFFFFFFFFFFFF
@property
def part(self) -> Tensor:
return (self.data >> 48).int() & 0xFFFF
@property
def dist(self) -> Tensor:
return self.data
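# --- hedged sketch (not part of the original file) ---
# DistIndex packs a 48-bit local index and a 16-bit partition id into one
# int64; .loc and .part unpack them again.
def _example_dist_index() -> None:
    local = torch.tensor([0, 5, 7])
    parts = torch.tensor([2, 0, 1])
    di = DistIndex(local, parts)
    assert (di.loc == local).all()
    assert (di.part == parts.int()).all()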
class DistributedTensor:
def __init__(self, data: Tensor) -> None:
self.accessor = TensorAccessor(data)
self.rrefs = self.accessor.all_gather_rrefs()
# self.num_parts = len(self.rrefs)
local_sizes = []
for rref in self.rrefs:
n = self.ctx.remote_call(Tensor.size, rref, dim=0).wait()
local_sizes.append(n)
self.num_nodes = DistInt(local_sizes)
self.num_parts = DistInt([1] * len(self.rrefs))
@property
def ctx(self):
return self.accessor.ctx
def index_select(self, dist_index: Union[Tensor, DistIndex]):
if isinstance(dist_index, Tensor):
dist_index = DistIndex(dist_index)
part_idx = dist_index.part
index = dist_index.loc
futs: List[torch.futures.Future] = []
for i in range(self.num_parts()):
f = self.accessor.async_index_select(0, index[part_idx == i], self.rrefs[i])
futs.append(f)
def callback(fs: torch.futures.Future[List[torch.futures.Future]]) -> Tensor:
result: Optional[Tensor] = None
for i, f in enumerate(fs.value()):
t: Tensor = f.value()
if result is None:
result = torch.empty(
part_idx.size(0), *t.shape[1:], dtype=t.dtype, device=t.device,
)
result[part_idx == i] = t
return result
return torch.futures.collect_all(futs).then(callback)
def index_copy_(self, dist_index: Union[Tensor, DistIndex], source: Tensor):
if isinstance(dist_index, Tensor):
dist_index = DistIndex(dist_index)
part_idx = dist_index.part
index = dist_index.loc
futs: List[torch.futures.Future] = []
for i in range(self.num_parts()):
mask = part_idx == i
f = self.accessor.async_index_copy_(0, index[mask], source[mask], self.rrefs[i])
futs.append(f)
return torch.futures.collect_all(futs)
def index_add_(self, dist_index: Union[Tensor, DistIndex], source: Tensor):
if isinstance(dist_index, Tensor):
dist_index = DistIndex(dist_index)
part_idx = dist_index.part
index = dist_index.loc
futs: List[torch.futures.Future] = []
for i in range(self.num_parts()):
mask = part_idx == i
f = self.accessor.async_index_add_(0, index[mask], source[mask], self.rrefs[i])
futs.append(f)
return torch.futures.collect_all(futs)
def index_fill_(self, dist_index: Union[Tensor, DistIndex], value: Number):
if isinstance(dist_index, Tensor):
dist_index = DistIndex(dist_index)
part_idx = dist_index.part
index = dist_index.loc
futs: List[torch.futures.Future] = []
for i in range(self.num_parts()):
mask = part_idx == i
f = self.accessor.async_index_fill_(0, index[mask], value, self.rrefs[i])
futs.append(f)
return torch.futures.collect_all(futs)
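# --- hedged usage sketch (not part of the original file) ---
# Every worker contributes its local shard (the constructor is collective),
# then rows are fetched by packed (local index, partition id) pairs. Shards
# are assumed to live on CPU so the CPU index tensors below are valid.
def _example_distributed_tensor(local_shard: Tensor) -> Tensor:
    dtensor = DistributedTensor(local_shard)
    # fetch row 0 of every partition's shard
    loc = torch.zeros(dtensor.num_parts(), dtype=torch.long)
    part = torch.arange(dtensor.num_parts(), dtype=torch.long)
    rows = dtensor.index_select(DistIndex(loc, part)).wait()
    dtensor.ctx.main_print("gathered shape:", rows.shape)
    return rows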
# from .distgraph import DistGraph
\ No newline at end of file
from .route import Route
from .utils import init_vc_edge_index
\ No newline at end of file
import torch
import torch.autograd as autograd
import torch.distributed as dist
from multiprocessing.pool import ThreadPool
from torch import Tensor
from typing import *
from starrygl.distributed import DistributedContext
class Route:
@staticmethod
def from_raw_indices(
src_ids: Tensor,
dst_ids: Tensor,
bipartite: bool = True,
group: Any = None,
) -> 'Route':
fw_tables, bw_tables = Route._build_route_tables(
src_ids=src_ids, dst_ids=dst_ids,
bipartite=bipartite, group=group,
)
return Route(
src_len=src_ids.size(0),
dst_len=dst_ids.size(0),
**Route.__tables_to_indptr(fw_tables, bw_tables),
group=group,
)
def new_subroute(self, dst_mask: Tensor):
if dst_mask.dtype != torch.bool:
dst_mask = Route.__expand_index_to_mask(dst_mask, self.dst_len)
assert dst_mask.size(0) == self.dst_len
fw_dst_masks, work = self.all_to_all_fw(
input_tensor_list=[dst_mask[self.fw_table(i)[0]] for i in range(self.num_parts)],
async_op=True,
)
fw_tables: List[Tensor] = []
for i in range(self.num_parts):
m = dst_mask[self.fw_table(i)[0]]
fw_tables.append(self.fw_table(i)[:,m])
work.wait()
bw_tables: List[Tensor] = []
for i, m in enumerate(fw_dst_masks):
bw_tables.append(self.bw_table(i)[:,m])
return Route(
src_len=self.src_len,
dst_len=self.dst_len,
**Route.__tables_to_indptr(fw_tables, bw_tables),
group=self.group,
)
def reverse_route(self):
return Route(
src_len=self.dst_len,
dst_len=self.src_len,
fw_ptr=self._bw_ptr, fw_ind=self._bw_ind,
bw_ptr=self._fw_ptr, bw_ind=self._fw_ind,
group=self.group,
)
def __init__(self,
src_len: int, dst_len: int,
fw_ptr: List[int], fw_ind: Tensor,
bw_ptr: List[int], bw_ind: Tensor,
group: Any,
) -> None:
self._ctx = DistributedContext.get_default_context()
self._src_len = src_len
self._dst_len = dst_len
self._fw_ptr = fw_ptr
self._fw_ind = fw_ind
self._bw_ptr = bw_ptr
self._bw_ind = bw_ind
self._group = group
@property
def ctx(self):
return self._ctx
@property
def group(self):
return self._group
@property
def src_len(self):
return self._src_len
@property
def dst_len(self):
return self._dst_len
@property
def part_id(self) -> int:
return dist.get_rank(self.group)
@property
def num_parts(self) -> int:
return dist.get_world_size(self.group)
def to(self, device: Any):
self._fw_ind = self._fw_ind.to(device)
self._bw_ind = self._bw_ind.to(device)
return self
def fw_table(self, i: int):
return self._fw_ind[:,self._fw_ptr[i]:self._fw_ptr[i+1]]
def bw_table(self, i: int):
return self._bw_ind[:,self._bw_ptr[i]:self._bw_ptr[i+1]]
def apply(self, data: Tensor) -> Tensor:
return RouteAlltoAll.apply(data, self)
def fw_tensor(self, data: Tensor) -> Tensor:
assert data.size(0) == self.dst_len
xs = self.all_to_all_fw(
[data[self.fw_table(i)[0]] for i in range(self.num_parts)],
async_op=False,
)
out = torch.zeros(
self.src_len, *data.shape[1:],
dtype=data.dtype, device=data.device,
)
for i, t in enumerate(xs):
out[self.bw_table(i)[0]] += t
return out
def bw_tensor(self, data: Tensor) -> Tensor:
assert data.size(0) == self.src_len
xs = self.all_to_all_bw(
[data[self.bw_table(i)[0]] for i in range(self.num_parts)],
async_op=False,
)
out = torch.zeros(
self.dst_len, *data.shape[1:],
dtype=data.dtype, device=data.device,
)
for i, t in enumerate(xs):
out[self.fw_table(i)[0]] += t
return out
def all_to_all_fw(self, input_tensor_list: List[Tensor], async_op: bool = False):
output_tensor_list: List[Tensor] = []
for i in range(self.num_parts):
t = input_tensor_list[i]
assert t.size(0) == self._fw_ptr[i+1] - self._fw_ptr[i]
s = self._bw_ptr[i+1] - self._bw_ptr[i]
output_tensor_list.append(
torch.empty(s, *t.shape[1:], dtype=t.dtype, device=t.device)
)
work = self.ctx.all_to_all_v(
output_tensor_list,
input_tensor_list,
group=self.group, async_op=async_op,
)
if async_op:
return output_tensor_list, work
else:
return output_tensor_list
def all_to_all_bw(self, input_tensor_list: List[Tensor], async_op: bool = False):
output_tensor_list: List[Tensor] = []
for i in range(self.num_parts):
t = input_tensor_list[i]
assert t.size(0) == self._bw_ptr[i+1] - self._bw_ptr[i]
s = self._fw_ptr[i+1] - self._fw_ptr[i]
output_tensor_list.append(
torch.empty(s, *t.shape[1:], dtype=t.dtype, device=t.device)
)
work = self.ctx.all_to_all_v(
output_tensor_list,
input_tensor_list,
group=self.group, async_op=async_op,
)
if async_op:
return output_tensor_list, work
else:
return output_tensor_list
@staticmethod
def _build_route_tables(
src_ids: Tensor,
dst_ids: Tensor,
bipartite: bool,
group: Any,
) -> Tuple[List[Tensor], List[Tensor]]:
assert src_ids.dtype == torch.long
assert dst_ids.dtype == torch.long
assert src_ids.dim() == 1
assert dst_ids.dim() == 1
ctx = DistributedContext.get_default_context()
src_len = src_ids.size(0)
dst_len = dst_ids.size(0)
if not bipartite:
assert dst_len <= src_len
assert (src_ids[:dst_len] == dst_ids).all()
rank = dist.get_rank(group)
world_size = dist.get_world_size(group)
ikw = dict(dtype=torch.long, device=dst_ids.device)
all_dst_lens: List[int] = [None] * world_size
dist.all_gather_object(all_dst_lens, dst_len, group=group)
num_nodes = torch.zeros(1, **ikw)
if src_ids.numel() > 0:
num_nodes = src_ids.max().max(num_nodes)
if dst_ids.numel() > 0:
num_nodes = dst_ids.max().max(num_nodes)
dist.all_reduce(num_nodes, op=dist.ReduceOp.MAX, group=group)
num_nodes = num_nodes.item() + 1
# src_ids -> local index
smp: Tensor = torch.empty(num_nodes, **ikw).fill_((2**62-1)*2+1)
smp[src_ids] = torch.arange(src_len, **ikw)
# async fetch dst_ids from other partitions
all_dst_ids: List[Tensor] = [None] * world_size
all_dst_get = [None] * world_size
def async_fetch(i: int):
if i == rank:
all_dst_ids[i] = dst_ids
else:
all_dst_ids[i] = torch.empty(all_dst_lens[i], **ikw)
all_dst_get[i] = dist.broadcast(
all_dst_ids[i], src=i, async_op=True, group=group
)
fw_tables: List[Tensor] = []
bw_tables: List[Tensor] = []
xmp = torch.empty_like(smp)
for i in range(world_size):
# prefetch dst_ids
if i == 0:
async_fetch(i)
if i + 1 < world_size:
async_fetch(i + 1)
all_dst_get[i].wait()
ids = all_dst_ids[i]
xmp.fill_(0)
xmp[ids] += 1
xmp[src_ids] += 1
ind = torch.where(xmp > 1)[0]
# dst_ids -> local index
xmp.fill_((2**62-1)*2+1)
xmp[ids] = torch.arange(ids.size(0), **ikw)
# remap src_ids and dst_ids
src_ind = smp[ind]
dst_ind = xmp[ind]
# at this point only the bw_route tables are correct; the fw_route tables must be sent to the partitions that own the corresponding src_ids
fw_tables.append(torch.vstack([dst_ind, src_ind]))
bw_tables.append(torch.vstack([src_ind, dst_ind]))
fw_tables = [t.t().contiguous() for t in fw_tables]
fw_tables = ctx.all_to_all_g(fw_tables, group=group)
fw_tables = [t.t().contiguous() for t in fw_tables]
# non-bipartite graph: add a self-loop entry for every destination vertex
if not bipartite:
rank_ind = torch.arange(dst_len, **ikw)
fw_tables[rank] = bw_tables[rank] = torch.vstack([rank_ind, rank_ind])
return fw_tables, bw_tables
@staticmethod
def __expand_index_to_mask(index: Tensor, len: int) -> Tensor:
mask = torch.zeros(len, dtype=torch.bool, device=index.device)
mask[index] = True
return mask
@staticmethod
def __tables_to_indptr(
fw_tables: List[Tensor],
bw_tables: List[Tensor],
):
fw_ptr: List[int] = [0]
for t in fw_tables:
last_n = fw_ptr[-1]
fw_ptr.append(last_n + t.size(1))
fw_ind = torch.cat(fw_tables, dim=1)
bw_ptr: List[int] = [0]
for t in bw_tables:
last_n = bw_ptr[-1]
bw_ptr.append(last_n + t.size(1))
bw_ind = torch.cat(bw_tables, dim=1)
return {
"fw_ptr": fw_ptr, "bw_ptr": bw_ptr,
"fw_ind": fw_ind, "bw_ind": bw_ind,
}
class RouteAlltoAll(autograd.Function):
@staticmethod
def forward(
ctx: autograd.function.FunctionCtx,
x: Tensor,
route: Route,
):
ctx.saved_route = route
return route.fw_tensor(x)
@staticmethod
def backward(
ctx: autograd.function.FunctionCtx,
grad: Tensor,
) -> Tensor:
route: Route = ctx.saved_route
return route.bw_tensor(grad)
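# --- hedged usage sketch (not part of the original file) ---
# Build a Route from the global ids this rank references (src_ids) and the
# global ids it owns (dst_ids), then exchange features through it. The id
# tensors are assumed to come from the partitioner; route.apply is
# differentiable (forward uses fw_tensor, backward uses bw_tensor).
def _example_route(src_ids: Tensor, dst_ids: Tensor, dst_feat: Tensor) -> Tensor:
    route = Route.from_raw_indices(src_ids, dst_ids, bipartite=True)
    assert dst_feat.size(0) == route.dst_len
    return route.apply(dst_feat)   # output has route.src_len rows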
\ No newline at end of file
# import torch
# from torch import Tensor
# from typing import *
# def init_local_edge_index(
# dst_ids: Tensor,
# edge_index: Tensor,
# bipartite: bool = False,
# ) -> Tuple[Tensor, Tensor]:
# max_ids = calc_max_ids(dst_ids, edge_index)
# ikw = dict(dtype=torch.long, device=dst_ids.device)
# xmp = torch.zeros(max_ids + 1, **ikw)
# # check that this is a vertex-cut partition and that every edge lives in the partition owning its destination
# xmp[edge_index[1].unique()] += 0b01
# xmp[dst_ids.unique()] += 0b10
# if not (xmp != 0x01).all():
# raise RuntimeError(f"must be vertex-cut partition graph")
# if bipartite:
# src_ids = edge_index[0].unique()
# else:
# # assume a homogeneous graph
# # src_ids equals [dst_ids, edge_index[0] except dst_ids]
# xmp.fill_(0)
# xmp[edge_index[0]] = 1
# xmp[dst_ids] = 0
# src_ids = torch.cat([dst_ids, torch.where(xmp > 0)[0]], dim=-1)
# # compute local indices
# xmp.fill_((2**62-1)*2+1)
# xmp[src_ids] = torch.arange(src_ids.size(0), **ikw)
# src = xmp[edge_index[0]]
# xmp.fill_((2**62-1)*2+1)
# xmp[dst_ids] = torch.arange(dst_ids.size(0), **ikw)
# dst = xmp[edge_index[1]]
# local_edge_index = torch.vstack([src, dst])
# return src_ids, local_edge_index
# def calc_max_ids(*ids: Tensor) -> int:
# x = [t.max().item() if t.numel() > 0 else 0 for t in ids]
# return max(*x)
\ No newline at end of file
import torch
import torch.distributed as dist
from torch import Tensor
from typing import *
def init_vc_edge_index(
dst_ids: Tensor,
edge_index: Tensor,
bipartite: bool = True,
) -> Tuple[Tensor, Tensor]:
ikw = dict(dtype=torch.long, device=dst_ids.device)
local_num_nodes = torch.zeros(1, **ikw)
if dst_ids.numel() > 0:
local_num_nodes = dst_ids.max().max(local_num_nodes)
if edge_index.numel() > 0:
local_num_nodes = edge_index.max().max(local_num_nodes)
local_num_nodes = local_num_nodes.item() + 1
xmp: Tensor = torch.zeros(local_num_nodes, **ikw)
# bit0 = vertex appears as an edge destination, bit1 = vertex is owned locally
xmp[edge_index[1].unique()] += 0b01
xmp[dst_ids.unique()] += 0b10
# a destination that is not owned locally (value 0b01) violates the vertex-cut assumption
if (xmp == 0b01).any():
    raise RuntimeError("graph must be vertex-cut partitioned")
if bipartite:
src_ids = edge_index[0].unique()
else:
xmp.fill_(0)
xmp[edge_index[0]] = 1
xmp[dst_ids] = 0
src_ids = torch.cat([dst_ids, torch.where(xmp > 0)[0]], dim=-1)
xmp.fill_((2**62-1)*2+1)
xmp[src_ids] = torch.arange(src_ids.size(0), **ikw)
src = xmp[edge_index[0]]
xmp.fill_((2**62-1)*2+1)
xmp[dst_ids] = torch.arange(dst_ids.size(0), **ikw)
dst = xmp[edge_index[1]]
local_edge_index = torch.vstack([src, dst])
return src_ids, local_edge_index
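# --- hedged sketch (not part of the original file) ---
# Toy vertex-cut partition: this rank owns destination vertices {0, 3} and
# stores every edge pointing into them; the ids are global and hypothetical.
def _example_init_vc_edge_index() -> None:
    dst_ids = torch.tensor([0, 3])
    edge_index = torch.tensor([[5, 7, 0],    # global source ids
                               [0, 3, 3]])   # global destination ids (owned here)
    src_ids, local_edge_index = init_vc_edge_index(dst_ids, edge_index, bipartite=True)
    # src_ids -> tensor([0, 5, 7]); local_edge_index -> tensor([[1, 2, 0], [0, 1, 1]])
    print(src_ids, local_edge_index)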