Commit 6cb3e218 by zlj

add cache

parent 54424543
a.out
third_party
install.sh
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
......@@ -26,6 +30,10 @@ share/python-wheels/
*.egg
MANIFEST
# IDE temporary files (generated by IDEs like CLion, etc.)
.idea/
cmake-build-*/
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
......@@ -162,8 +170,8 @@ cython_debug/
*.pt
/nohup.out
/cora
/dataset
/test_*
/*.ipynb
/third_party
!/third_party/ldg_partition
[submodule "third_party/ldg_partition"]
path = third_party/ldg_partition
url = https://gitee.com/onlynagesha/graph-partition-v4
set(CMAKE_SOURCE_DIR "/mnt/data/")
cmake_minimum_required(VERSION 3.15)
project(starrygl_ops VERSION 0.1)
project(starrygl VERSION 0.1)
option(WITH_PYTHON "Link to Python when building" ON)
option(WITH_CUDA "Link to CUDA when building" ON)
option(WITH_METIS "Link to METIS when building" ON)
option(WITH_MTMETIS "Link to multi-threaded METIS when building" ON)
option(WITH_LDG "Link to (multi-threaded optionally) LDG when building" OFF)
set(WITH_LDG OFF)
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
......@@ -38,15 +38,15 @@ if(WITH_CUDA)
file(GLOB_RECURSE UVM_SRCS "csrc/uvm/*.cpp")
add_library(uvm SHARED ${UVM_SRCS})
target_link_libraries(uvm PRIVATE ${TORCH_LIBRARIES})
add_library(uvm_ops SHARED ${UVM_SRCS})
target_link_libraries(uvm_ops PRIVATE ${TORCH_LIBRARIES})
endif()
if(WITH_METIS)
add_definitions(-DWITH_METIS)
set(GKLIB_DIR "${CMAKE_SOURCE_DIR}/third_party/GKlib")
set(METIS_DIR "${CMAKE_SOURCE_DIR}/third_party/METIS")
message(${METIS_DIR})
set(GKLIB_INCLUDE_DIRS "${GKLIB_DIR}/include")
file(GLOB_RECURSE GKLIB_LIBRARIES "${GKLIB_DIR}/lib/lib*.a")
......@@ -55,10 +55,10 @@ if(WITH_METIS)
include_directories(${METIS_INCLUDE_DIRS})
add_library(metis SHARED "csrc/partition/metis.cpp")
target_link_libraries(metis PRIVATE ${TORCH_LIBRARIES})
target_link_libraries(metis PRIVATE ${GKLIB_LIBRARIES})
target_link_libraries(metis PRIVATE ${METIS_LIBRARIES})
add_library(metis_partition SHARED "csrc/partition/metis.cpp")
target_link_libraries(metis_partition PRIVATE ${TORCH_LIBRARIES})
target_link_libraries(metis_partition PRIVATE ${GKLIB_LIBRARIES})
target_link_libraries(metis_partition PRIVATE ${METIS_LIBRARIES})
endif()
if(WITH_MTMETIS)
......@@ -69,15 +69,28 @@ if(WITH_MTMETIS)
file(GLOB_RECURSE MTMETIS_LIBRARIES "${MTMETIS_DIR}/lib/lib*.a")
include_directories(${MTMETIS_INCLUDE_DIRS})
add_library(mtmetis SHARED "csrc/partition/mtmetis.cpp")
target_link_libraries(mtmetis PRIVATE ${TORCH_LIBRARIES})
target_link_libraries(mtmetis PRIVATE ${MTMETIS_LIBRARIES})
target_compile_definitions(mtmetis PRIVATE -DMTMETIS_64BIT_VERTICES)
target_compile_definitions(mtmetis PRIVATE -DMTMETIS_64BIT_EDGES)
target_compile_definitions(mtmetis PRIVATE -DMTMETIS_64BIT_WEIGHTS)
target_compile_definitions(mtmetis PRIVATE -DMTMETIS_64BIT_PARTITIONS)
add_library(mtmetis_partition SHARED "csrc/partition/mtmetis.cpp")
target_link_libraries(mtmetis_partition PRIVATE ${TORCH_LIBRARIES})
target_link_libraries(mtmetis_partition PRIVATE ${MTMETIS_LIBRARIES})
target_compile_definitions(mtmetis_partition PRIVATE -DMTMETIS_64BIT_VERTICES)
target_compile_definitions(mtmetis_partition PRIVATE -DMTMETIS_64BIT_EDGES)
target_compile_definitions(mtmetis_partition PRIVATE -DMTMETIS_64BIT_WEIGHTS)
target_compile_definitions(mtmetis_partition PRIVATE -DMTMETIS_64BIT_PARTITIONS)
endif()
message("${WITH_LDG}")
if (WITH_LDG)
# Import the neighbor-clustering-based graph partitioning implementation (e.g. the LDG algorithm)
add_definitions(-DWITH_LDG)
set(LDG_DIR "third_party/ldg_partition")
add_library(ldg_partition SHARED "csrc/partition/ldg.cpp")
target_link_libraries(ldg_partition PRIVATE ${TORCH_LIBRARIES})
add_subdirectory(${LDG_DIR})
target_include_directories(ldg_partition PRIVATE ${LDG_DIR})
target_link_libraries(ldg_partition PRIVATE ldg-vertex-partition)
endif ()
include_directories("csrc/include")
add_library(${PROJECT_NAME} SHARED csrc/export.cpp)
......@@ -91,24 +104,32 @@ if(WITH_PYTHON)
endif()
if (WITH_CUDA)
target_link_libraries(${PROJECT_NAME} PRIVATE uvm)
target_link_libraries(${PROJECT_NAME} PRIVATE uvm_ops)
endif()
if (WITH_METIS)
target_link_libraries(${PROJECT_NAME} PRIVATE metis)
message(STATUS "Current project '${PROJECT_NAME}' uses METIS graph partitioning algorithm.")
target_link_libraries(${PROJECT_NAME} PRIVATE metis_partition)
endif()
if (WITH_MTMETIS)
target_link_libraries(${PROJECT_NAME} PRIVATE mtmetis)
message(STATUS "Current project '${PROJECT_NAME}' uses multi-threaded METIS graph partitioning algorithm.")
target_link_libraries(${PROJECT_NAME} PRIVATE mtmetis_partition)
endif()
if (WITH_LDG)
message(STATUS "Current project '${PROJECT_NAME}' uses LDG graph partitioning algorithm.")
target_link_libraries(${PROJECT_NAME} PRIVATE ldg_partition)
endif()
# add libsampler.so
set(SAMLPER_NAME "${PROJECT_NAME}_sampler")
set(BOOST_INCLUDE_DIRS "${CMAKE_SOURCE_DIR}/third_party/boost_1_83_0")
include_directories(${BOOST_INCLUDE_DIRS})
file(GLOB_RECURSE SAMPLER_SRCS "csrc/sampler/*.cpp")
add_library(${SAMLPER_NAME} SHARED ${SAMPLER_SRCS})
target_include_directories(${SAMLPER_NAME} PRIVATE "csrc/sampler/include")
target_compile_options(${SAMLPER_NAME} PRIVATE -O3)
......
sampling:
- layer: 1
neighbor:
- 10
strategy: 'recent'
prop_time: False
history: 1
duration: 0
num_thread: 32
no_neg: True
memory:
- type: 'node'
dim_time: 100
deliver_to: 'neighbors'
mail_combine: 'last'
memory_update: 'transformer'
attention_head: 2
mailbox_size: 10
combine_node_feature: False
dim_out: 100
gnn:
- arch: 'identity'
train:
- epoch: 100
batch_size: 600
lr: 0.0001
dropout: 0.1
att_dropout: 0.1
# all_on_gpu: True
\ No newline at end of file
sampling:
- layer: 2
neighbor:
- 10
- 10
strategy: 'uniform'
prop_time: True
history: 3
duration: 10000
num_thread: 32
memory:
- type: 'none'
dim_out: 0
gnn:
- arch: 'transformer_attention'
layer: 2
att_head: 2
dim_time: 0
dim_out: 100
combine: 'rnn'
train:
- epoch: 50
batch_size: 600
lr: 0.0001
dropout: 0.1
att_dropout: 0.1
all_on_gpu: True
\ No newline at end of file
sampling:
- no_sample: True
history: 1
memory:
- type: 'node'
dim_time: 100
deliver_to: 'self'
mail_combine: 'last'
memory_update: 'rnn'
mailbox_size: 1
combine_node_feature: True
dim_out: 100
gnn:
- arch: 'identity'
time_transform: 'JODIE'
train:
- epoch: 100
batch_size: 600
lr: 0.0001
dropout: 0.1
all_on_gpu: True
\ No newline at end of file
sampling:
- layer: 2
neighbor:
- 10
- 10
strategy: 'uniform'
prop_time: False
history: 1
duration: 0
num_thread: 32
memory:
- type: 'none'
dim_out: 0
gnn:
- arch: 'transformer_attention'
layer: 2
att_head: 2
dim_time: 100
dim_out: 100
train:
- epoch: 10
batch_size: 600
lr: 0.0001
dropout: 0.1
att_dropout: 0.1
all_on_gpu: True
\ No newline at end of file
sampling:
- layer: <number of layers to sample>
neighbor: <a list of integers indicating how many neighbors are sampled in each layer>
strategy: <'recent' that samples the most recent neighbors or 'uniform' that uniformly samples neighbors from the past>
prop_time: <False or True that specifies whether to use the timestamp of the root nodes when sampling their multi-hop neighbors>
history: <number of snapshots to sample on>
duration: <length in time of each snapshot, 0 for infinite length (used in non-snapshot-based methods)>
num_thread: <number of threads of the sampler>
memory:
- type: <'node', we only support node memory now>
dim_time: <an integer, the dimension of the time embedding>
deliver_to: <'self' that delivers the mails only to involved nodes or 'neighbors' that delivers the mails to neighbors>
mail_combine: <'last' that uses the latest mail as the input to the memory updater>
memory_update: <'gru' or 'rnn'>
mailbox_size: <an integer, the size of the mailbox for each node>
combine_node_feature: <False or True that specifies whether to combine node features (with the updated memory) as the input to the GNN>
dim_out: <an integer, the dimension of the output node memory>
gnn:
- arch: <'transformer_attention' or 'identity' (no GNN)>
layer: <an integer, number of layers>
att_head: <an integer, number of attention heads>
dim_time: <an integer, the dimension of the time embedding>
dim_out: <an integer, the dimension of the output dynamic node embedding>
train:
- epoch: <an integer, number of epochs to train>
batch_size: <an integer, the batch size (of edges); for multi-gpu training, this is the local batch size>
reorder: <(optional) an integer that is divisible by the batch size and specifies how many chunks per batch are used in the random chunk scheduling>
lr: <floating point, learning rate>
dropout: <floating point, dropout>
att_dropout: <floating point, dropout for attention>
all_on_gpu: <False or True that decides if the node/edge features and node memory are completely stored on GPU>
\ No newline at end of file
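For reference, a minimal sketch (not part of this commit) of loading a config that follows the template above with PyYAML; the file path is hypothetical, and each top-level key holds a one-element list as in the example configs:

import yaml

with open("config/example.yml", "r") as f:  # hypothetical path
    cfg = yaml.safe_load(f)

sample_cfg = cfg["sampling"][0]  # layer count, per-layer neighbor fanouts, strategy, ...
memory_cfg = cfg["memory"][0]    # memory type, updater, mailbox settings, ...
train_cfg = cfg["train"][0]      # epoch, batch_size, lr, dropout, ...
print(sample_cfg.get("strategy"), memory_cfg.get("type"), train_cfg.get("batch_size"))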
......@@ -4,15 +4,13 @@
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
#ifdef WITH_CUDA
m.def("uvm_storage_new", &uvm_storage_new, "return storage of unified virtual memory");
m.def("uvm_storage_to_cuda", &uvm_storage_to_cuda, "share uvm storage with another cuda device");
m.def("uvm_storage_to_cpu", &uvm_storage_to_cpu, "share uvm storage with cpu");
m.def("uvm_storage_advise", &uvm_storage_advise, "apply cudaMemAdvise() to uvm storage");
m.def("uvm_storage_prefetch", &uvm_storage_prefetch, "apply cudaMemPrefetchAsync() to uvm storage");
m.def("metis_partition", &metis_partition, "metis graph partition");
m.def("mt_metis_partition", &mt_metis_partition, "multi-threaded metis graph partition");
py::enum_<cudaMemoryAdvise>(m, "cudaMemoryAdvise")
.value("cudaMemAdviseSetAccessedBy", cudaMemoryAdvise::cudaMemAdviseSetAccessedBy)
.value("cudaMemAdviseUnsetAccessedBy", cudaMemoryAdvise::cudaMemAdviseUnsetAccessedBy)
......@@ -20,4 +18,19 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
.value("cudaMemAdviseUnsetPreferredLocation", cudaMemoryAdvise::cudaMemAdviseUnsetPreferredLocation)
.value("cudaMemAdviseSetReadMostly", cudaMemoryAdvise::cudaMemAdviseSetReadMostly)
.value("cudaMemAdviseUnsetReadMostly", cudaMemoryAdvise::cudaMemAdviseUnsetReadMostly);
#endif
#ifdef WITH_METIS
m.def("metis_partition", &metis_partition, "metis graph partition");
#endif
#ifdef WITH_MTMETIS
m.def("mt_metis_partition", &mt_metis_partition, "multi-threaded metis graph partition");
#endif
#ifdef WITH_LDG
// Note: the switch WITH_MULTITHREADING=ON must be turned on during compilation
// to enable multi-threading functionality.
m.def("ldg_partition", &ldg_partition, "(multi-threaded optionally) LDG graph partition");
#endif
}
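For reference, a minimal sketch of probing which partition bindings were compiled into the extension; the module name "starrygl_ops" is purely an assumption (the actual name is whatever TORCH_EXTENSION_NAME resolves to at build time):

import importlib

ops = None
try:
    ops = importlib.import_module("starrygl_ops")  # hypothetical module name
except ImportError:
    pass

if ops is not None:
    for name in ("metis_partition", "mt_metis_partition", "ldg_partition"):
        # each binding exists only if the matching WITH_* option was enabled at build time
        print(name, "available" if hasattr(ops, name) else "not compiled in")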
......@@ -22,3 +22,11 @@ at::Tensor mt_metis_partition(
int64_t num_workers,
bool recursive
);
at::Tensor ldg_partition(
at::Tensor edges,
at::optional<at::Tensor> vertex_weights,
at::optional<at::Tensor> initial_partition,
int64_t num_parts,
int64_t num_workers
);
#include <torch/all.h>
#include "vertex_partition/vertex_partition.h"
#include "vertex_partition/params.h"
at::Tensor ldg_partition(at::Tensor edges,
at::optional<at::Tensor> vertex_weights,
at::optional<at::Tensor> initial_partition,
int64_t num_parts,
int64_t num_workers) {
AT_ASSERT(edges.dim() == 2);
auto edges_n_cols = edges.size(1);
AT_ASSERT(edges_n_cols >= 2 && edges_n_cols <= 3);
// Note: other checks are performed in the implementation of vp::ldg_partition function series.
auto n = edges.slice(1, 0, 2).max().item<int64_t>() + 1;
auto params = vp::LDGParams{.N = n, .K = num_parts, .openmp_n_threads = static_cast<int>(num_workers)};
auto edges_clone = edges.clone();
if (vertex_weights.has_value()) {
auto vertex_weights_clone = vertex_weights->clone();
if (initial_partition.has_value()) {
auto initial_partition_clone = initial_partition->clone();
vp::ldg_partition_v_init(edges_clone, vertex_weights_clone, initial_partition_clone, params);
return initial_partition_clone;
} else {
return vp::ldg_partition_v(edges_clone, vertex_weights_clone, params);
}
} else {
if (initial_partition.has_value()) {
auto initial_partition_clone = initial_partition->clone();
vp::ldg_partition_init(edges_clone, initial_partition_clone, params);
return initial_partition_clone;
} else {
return vp::ldg_partition(edges_clone, params);
}
}
}
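For reference, a minimal Python-side sketch of the inputs this entry point expects (the call itself is hypothetical and assumes the binding is exposed on the compiled extension): edges must be an integer tensor with 2 or 3 columns, the vertex count is inferred from the largest id in the first two columns plus one, and when an initial_partition is supplied it is updated in a clone and returned.

import torch

# one (src, dst) pair per row; a third column is also accepted
edges = torch.tensor([[0, 1], [1, 2], [2, 3], [3, 0]], dtype=torch.long)
# hypothetical call through the compiled extension:
# parts = ext.ldg_partition(edges, None, None, 2, 4)  # num_parts=2, num_workers=4
# parts would then hold one partition id per vertex (max id + 1 entries)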
......@@ -258,19 +258,16 @@ TemporalNeighborBlock& get_neighbors(
if(is_distinct){
for(int64_t i=0; i<num_nodes; i++){
// collect per-node degrees after deduplicating edges
// collect deduplicated neighbors
phmap::parallel_flat_hash_set<NodeIDType> temp_s;
temp_s.insert(tnb.neighbors[i].begin(), tnb.neighbors[i].end());
tnb.neighbors_set.emplace_back(temp_s);
tnb.deg[i] = tnb.neighbors_set[i].size();
}
}
else{
for(int64_t i=0; i<num_nodes; i++){
// collect per-node degrees
tnb.deg[i] = tnb.neighbors[i].size();
}
}
double end_time = omp_get_wtime();
cout<<"get_neighbors consume: "<<end_time-start_time<<"s"<<endl;
return tnb;
......
......@@ -33,16 +33,16 @@ class ParallelSampler
ret.resize(num_layers);
}
void neighbor_sample_from_nodes(th::Tensor nodes, optional<th::Tensor> root_ts);
void neighbor_sample_from_nodes_static(th::Tensor nodes);
void neighbor_sample_from_nodes_static_layer(th::Tensor nodes, int cur_layer);
void neighbor_sample_from_nodes(th::Tensor nodes, optional<th::Tensor> root_ts, optional<bool> part_unique);
void neighbor_sample_from_nodes_static(th::Tensor nodes, bool part_unique);
void neighbor_sample_from_nodes_static_layer(th::Tensor nodes, int cur_layer, bool part_unique);
void neighbor_sample_from_nodes_with_before(th::Tensor nodes, th::Tensor root_ts);
void neighbor_sample_from_nodes_with_before_layer(th::Tensor nodes, th::Tensor root_ts, int cur_layer);
};
void ParallelSampler :: neighbor_sample_from_nodes(th::Tensor nodes, optional<th::Tensor> root_ts)
void ParallelSampler :: neighbor_sample_from_nodes(th::Tensor nodes, optional<th::Tensor> root_ts, optional<bool> part_unique)
{
omp_set_num_threads(threads);
if(policy == "weighted")
......@@ -60,11 +60,12 @@ void ParallelSampler :: neighbor_sample_from_nodes(th::Tensor nodes, optional<th
neighbor_sample_from_nodes_with_before(nodes, root_ts.value());
}
else{
neighbor_sample_from_nodes_static(nodes);
bool flag = part_unique.has_value() ? part_unique.value() : true;
neighbor_sample_from_nodes_static(nodes, flag);
}
}
void ParallelSampler :: neighbor_sample_from_nodes_static_layer(th::Tensor nodes, int cur_layer){
void ParallelSampler :: neighbor_sample_from_nodes_static_layer(th::Tensor nodes, int cur_layer, bool part_unique){
py::gil_scoped_release release;
double tot_start_time = omp_get_wtime();
......@@ -73,24 +74,25 @@ void ParallelSampler :: neighbor_sample_from_nodes_static_layer(th::Tensor nodes
ret[cur_layer] = TemporalGraphBlock();
auto nodes_data = get_data_ptr<NodeIDType>(nodes);
vector<phmap::parallel_flat_hash_set<NodeIDType>> node_s_threads(threads);
// vector<vector<NodeIDType>> node_threads(threads);
vector<vector<NodeIDType>> node_threads(threads);
phmap::parallel_flat_hash_set<NodeIDType> node_s;
vector<vector<NodeIDType>> eid_threads(threads);//row_threads(threads),col_threads(threads);
vector<vector<NodeIDType>> eid_threads(threads);
vector<vector<NodeIDType>> src_index_threads(threads);
AT_ASSERTM(tnb.with_eid, "Tnb has no eid information! We need eid!");
// double start_time = omp_get_wtime();
int reserve_capacity = int(ceil(nodes.size(0) / threads)) * fanout;
#pragma omp parallel
{
int tid = omp_get_thread_num();
unsigned int loc_seed = tid;
eid_threads[tid].reserve(reserve_capacity);
src_index_threads[tid].reserve(reserve_capacity);
// node_threads[tid].reserve(reserve_capacity);
if(!part_unique)
node_threads[tid].reserve(reserve_capacity);
#pragma omp for schedule(static, int(ceil(static_cast<float>(nodes.size(0)) / threads)))
for(int64_t i=0; i<nodes.size(0); i++){
// int tid = omp_get_thread_num();
NodeIDType node = nodes_data[i];
vector<NodeIDType>& nei = tnb.neighbors[node];
vector<EdgeIDType> edge;
......@@ -98,38 +100,47 @@ void ParallelSampler :: neighbor_sample_from_nodes_static_layer(th::Tensor nodes
double s_start_time = omp_get_wtime();
if(tnb.deg[node]>fanout){
src_index_threads[tid].insert(src_index_threads[tid].end(), fanout, i);
phmap::flat_hash_set<NodeIDType> temp_s;
default_random_engine e(8);//(time(0));
uniform_int_distribution<> u(0, tnb.deg[node]-1);
while(temp_s.size()!=fanout){
// for(int i=0;i<fanout;i++){
// loop to select fanout neighbors
// uniform_int_distribution<> u(0, tnb.deg[node]-1);
// while(temp_s.size()!=fanout && temp_s.size()<tnb.neighbors_set[node].size()){
for(int i=0;i<fanout;i++){
// loop to select fanout neighbors
NodeIDType indice;
if(policy == "weighted"){//考虑边权重信
if(policy == "weighted"){//���DZ�Ȩ����Ϣ
const vector<WeightType>& ew = tnb.edge_weight[node];
indice = sample_multinomial(ew, e);
}
else if(policy == "uniform"){//均匀采样
else if(policy == "uniform"){//���Ȳ���
// indice = u(e);
indice = rand_r(&loc_seed) % (tnb.deg[node]);
indice = rand_r(&loc_seed) % (nei.size());
}
auto chosen_n_iter = nei.begin() + indice;
// auto chosen_e_iter = edge.begin() + indice;
// eid_threads[tid].emplace_back(*chosen_e_iter);
// node_threads[tid].emplace_back(*chosen_n_iter);
auto rst = temp_s.insert(*chosen_n_iter);
if(rst.second){ // not a duplicate
auto chosen_e_iter = edge.begin() + indice;
if(part_unique){
auto rst = temp_s.insert(*chosen_n_iter);
if(rst.second){ // not a duplicate
eid_threads[tid].emplace_back(*chosen_e_iter);
node_s_threads[tid].insert(*chosen_n_iter);
if(!tnb.neighbors_set.empty() && temp_s.size()<fanout && temp_s.size()<tnb.neighbors_set[node].size()) fanout++;
}
}
else{
eid_threads[tid].emplace_back(*chosen_e_iter);
node_threads[tid].emplace_back(*chosen_n_iter);
}
}
if(part_unique)
src_index_threads[tid].insert(src_index_threads[tid].end(), temp_s.size(), i);
else
src_index_threads[tid].insert(src_index_threads[tid].end(), fanout, i);
}
else{
src_index_threads[tid].insert(src_index_threads[tid].end(), tnb.deg[node], i);
// node_threads[tid].insert(node_threads[tid].end(), nei.begin(), nei.end());
if(part_unique)
node_s_threads[tid].insert(nei.begin(), nei.end());
else
node_threads[tid].insert(node_threads[tid].end(), nei.begin(), nei.end());
eid_threads[tid].insert(eid_threads[tid].end(),edge.begin(), edge.end());
}
if(tid==0)
......@@ -153,22 +164,25 @@ void ParallelSampler :: neighbor_sample_from_nodes_static_layer(th::Tensor nodes
#pragma omp parallel for schedule(static, 1)
for(int i = 0; i<threads; i++){
copy(eid_threads[i].begin(), eid_threads[i].end(), ret[cur_layer].eid.begin()+each_begin[i]);
// copy(node_threads[i].begin(), node_threads[i].end(), ret[cur_layer].sample_nodes.begin()+each_begin[i]);
if(!part_unique)
copy(node_threads[i].begin(), node_threads[i].end(), ret[cur_layer].sample_nodes.begin()+each_begin[i]);
copy(src_index_threads[i].begin(), src_index_threads[i].end(), ret[cur_layer].src_index.begin()+each_begin[i]);
}
if(part_unique){
for(int i = 0; i<threads; i++)
node_s.insert(node_s_threads[i].begin(), node_s_threads[i].end());
ret[cur_layer].sample_nodes.assign(node_s.begin(), node_s.end());
}
ret[0].tot_time += omp_get_wtime() - tot_start_time;
ret[0].sample_edge_num += ret[cur_layer].eid.size();
py::gil_scoped_acquire acquire;
}
void ParallelSampler :: neighbor_sample_from_nodes_static(th::Tensor nodes){
void ParallelSampler :: neighbor_sample_from_nodes_static(th::Tensor nodes, bool part_unique){
for(int i=0;i<num_layers;i++){
if(i==0) neighbor_sample_from_nodes_static_layer(nodes, i);
else neighbor_sample_from_nodes_static_layer(vecToTensor<NodeIDType>(ret[i-1].sample_nodes), i);
if(i==0) neighbor_sample_from_nodes_static_layer(nodes, i, part_unique);
else neighbor_sample_from_nodes_static_layer(vecToTensor<NodeIDType>(ret[i-1].sample_nodes), i, part_unique);
}
}
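For reference, a small pure-Python sketch (illustration only, not the C++ implementation) of the part_unique semantics added above: with part_unique the per-layer output keeps each sampled neighbor at most once, without it duplicates are preserved.

import random

def sample_layer(neighbors, fanout, part_unique):
    # mirrors the dedup-vs-duplicates behaviour controlled by part_unique
    if part_unique:
        picked = set()
        while len(picked) < min(fanout, len(set(neighbors))):
            picked.add(random.choice(neighbors))
        return sorted(picked)
    return [random.choice(neighbors) for _ in range(fanout)]

print(sample_layer([1, 1, 2, 2, 3], fanout=4, part_unique=True))   # unique node ids
print(sample_layer([1, 1, 2, 2, 3], fanout=4, part_unique=False))  # ids may repeat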
......@@ -215,7 +229,7 @@ void ParallelSampler :: neighbor_sample_from_nodes_with_before_layer(
}
}
else{
// if there are more candidate neighbor edges than the fanout, randomly select fanout of them
tgb_i[tid].src_index.insert(tgb_i[tid].src_index.end(), fanout, i);
uniform_int_distribution<> u(0, end_index-1);
//cout<<end_index<<endl;
......
# Minimal makefile for Sphinx documentation
#
# You can set these variables from the command line, and also
# from the environment for the first two.
SPHINXOPTS ?=
SPHINXBUILD ?= sphinx-build
SOURCEDIR = source
BUILDDIR = build
# Put it first so that "make" without argument is like "make help".
help:
@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
.PHONY: help Makefile
# Catch-all target: route all unknown targets to Sphinx using the new
# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
%: Makefile
@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
@ECHO OFF
pushd %~dp0
REM Command file for Sphinx documentation
if "%SPHINXBUILD%" == "" (
set SPHINXBUILD=sphinx-build
)
set SOURCEDIR=source
set BUILDDIR=build
%SPHINXBUILD% >NUL 2>NUL
if errorlevel 9009 (
echo.
echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
echo.installed, then set the SPHINXBUILD environment variable to point
echo.to the full path of the 'sphinx-build' executable. Alternatively you
echo.may add the Sphinx directory to PATH.
echo.
echo.If you don't have Sphinx installed, grab it from
echo.https://www.sphinx-doc.org/
exit /b 1
)
if "%1" == "" goto help
%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
goto end
:help
%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
:end
popd
Advanced Data Preprocessing
===========================
.. note::
Describes in detail how to use StarryGL's data management classes (e.g. GraphData), the design of their internal index structures, and the underlying operations.
\ No newline at end of file
Advanced Concepts
=================
.. toctree::
sampling_parallel/index
partition_parallel/index
timeline_parallel/index
\ No newline at end of file
Distributed Partition Parallel
==============================
.. note::
Distributed partition-parallel training.
\ No newline at end of file
Distributed Feature Fetching
============================
\ No newline at end of file
Distributed Sampling Parallel
=============================
.. note::
Training mode based on distributed temporal graph sampling.
.. toctree::
sampler
features
memory
Distributed Memory Updater
==========================
\ No newline at end of file
Distributed Temporal Sampling
=============================
\ No newline at end of file
Distributed Timeline Parallel
=============================
.. note::
Distributed timeline parallelism.
\ No newline at end of file
starrygl.distributed
====================
.. note::
Automatically generated API documentation; in-code docstrings still need to be added in the starrygl source directory.
.. currentmodule:: starrygl.distributed
.. autosummary::
DistributedContext
\ No newline at end of file
Package References
==================
.. toctree::
distributed
\ No newline at end of file
Cheatsheets
===========
.. note::
List the framework's built-in GNN models, the catalog of GNN datasets, and the default interface conventions; for reference see
`pyg cheatsheets <https://pytorch-geometric.readthedocs.io/en/latest/cheatsheet/data_cheatsheet.html>`_
\ No newline at end of file
import os
import sys
sys.path.insert(0, os.path.abspath("../.."))
import starrygl
project = 'StarryGL'
copyright = '2023, StarryGL Team'
author = 'StarryGL Team'
version = starrygl.__version__
release = starrygl.__version__
extensions = [
"sphinx.ext.autodoc",
"sphinx.ext.autosummary",
"sphinx.ext.duration",
]
templates_path = ['_templates']
exclude_patterns = []
# -- Options for HTML output -------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output
html_theme = 'sphinx_rtd_theme'
html_static_path = ['_static']
External Resources
==================
.. note::
Collect StarryGL-related papers, GNN-related projects, and resources that help with using StarryGL; for reference see:
`PyG Resources <https://pytorch-geometric.readthedocs.io/en/latest/external/resources.html>`_
\ No newline at end of file
Get Started
===========
.. toctree::
install_guide
intro_example
\ No newline at end of file
Installation
============
.. note::
Installation instructions.
\ No newline at end of file
Introduction by Example
=======================
.. note::
After installation, quickly start training a temporal graph model. Show basic usage first, possibly on a single machine.
\ No newline at end of file
StarryGL Documentation
======================
.. toctree::
guide/index
tutorial/index
advanced/index
api/python/index
cheatsheets/index
external/index
.. Indices and tables
.. ==================
.. * :ref:`genindex`
.. * :ref:`modindex`
.. * :ref:`search`
User Cases and Applications
===========================
.. note::
Use cases and applications of temporal GNNs.
\ No newline at end of file
Preparing the Temporal Graph Dataset
====================================
.. note::
Covers the data cleaning and preprocessing steps starting from raw data, ending with data files that StarryGL can use.
\ No newline at end of file
Distributed Training
====================
.. note::
Introduces the methods related to distributed training and shows a minimal distributed DDP training workflow.
\ No newline at end of file
Tutorials
===============
.. toctree::
intro
module
dataset
application
distributed
\ No newline at end of file
Introduction to Temporal GNN
==============================================
.. note::
Briefly introduces temporal GNNs, their application scenarios, the problems to be solved, etc.; serves as an overall introduction.
Creating Temporal GNN Models
============================
.. note::
Describes how to create GNN models using the two most classic and concise examples, covering how to build a **discrete-time dynamic graph model** and a **continuous-time dynamic graph model**.
\ No newline at end of file
......@@ -3,7 +3,7 @@
mkdir -p build && cd build
cmake .. \
-DCMAKE_EXPORT_COMPILE_COMMANDS=ON \
-DCMAKE_PREFIX_PATH="/home/zlj/.miniconda3/envs/dgnn/lib/python3.10/site-packages" \
-DCMAKE_PREFIX_PATH="/home/zlj/.miniconda3/envs/dgnn/lib/python3.8/site-packages" \
-DPython3_ROOT_DIR="/home/zlj/.miniconda3/envs/dgnn" \
-DCUDA_TOOLKIT_ROOT_DIR="/home/zlj/local/cuda-12.2" \
&& make -j32 \
......
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/hwj/.miniconda3/envs/pyg/lib/python3.8/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n"
]
}
],
"source": [
"import torch\n",
"\n",
"from torch import Tensor\n",
"from typing import *\n",
"\n",
"from torch_scatter import scatter"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"def act_blks(act_mask: Tensor, blk_size: int) -> Tensor:\n",
" assert act_mask.dtype == torch.bool\n",
" assert act_mask.dim() == 1\n",
" n = act_mask.size(0)\n",
" m = (n + blk_size - 1) // blk_size * blk_size\n",
"\n",
" blk_mask = act_mask.detach()\n",
" if n != m:\n",
" blk_mask = torch.zeros(m, dtype=act_mask.dtype, device=act_mask.device)\n",
" blk_mask[:n] = act_mask.detach()\n",
" return blk_mask.view(-1, blk_size).max(dim=-1).values"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"act_mask = torch.rand(1000) < 0.01\n",
"blk_mask = act_blks(act_mask, 64)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"cached_blocks = torch.randn(blk_mask.size(0), 64, 8)\n",
"blks = cached_blocks[blk_mask]"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"def gather_block(blks: Tensor, blk_mask: Tensor, index: Tensor) -> Tensor:\n",
" blk_size = blks.size(1)\n",
" blk_mask = blk_mask.type(torch.long).cumsum_(dim=0) * blk_mask - 1\n",
" \n",
" idx0 = index.div(blk_size, rounding_mode=\"floor\")\n",
" ind_mask = blk_mask[idx0] >= 0\n",
" \n",
" idx0 = blk_mask[idx0][ind_mask]\n",
" idx1 = index[ind_mask] % blk_size\n",
"\n",
" s = index.shape[:1] + blks.shape[2:]\n",
" data = torch.zeros(s, dtype=blks.dtype, device=blks.device)\n",
" data[ind_mask] = blks[idx0, idx1, ...]\n",
" return data"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"def scatter_block(x: Tensor, index: Tensor, blk_mask: Tensor, blk_size: int, reduce: str = \"sum\") -> Tensor:\n",
" num_blks = blk_mask.count_nonzero()\n",
" blk_mask = blk_mask.type(torch.long).cumsum_(dim=0) * blk_mask - 1\n",
"\n",
" idx0 = index.div(blk_size, rounding_mode=\"floor\")\n",
" ind_mask = blk_mask[idx0] >= 0\n",
"\n",
" idx0 = blk_mask[idx0][ind_mask]\n",
" idx1 = index[ind_mask] % blk_size\n",
"\n",
" data = scatter(\n",
" src=x[ind_mask],\n",
" index=idx0*blk_size+idx1,\n",
" dim=0,\n",
" dim_size=num_blks*blk_size,\n",
" reduce=reduce,\n",
" )\n",
" return data.view(num_blks, blk_size, *x.shape[1:])"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"# def index_block(index: Tensor, blk_mask: Tensor, blk_size: int) -> Tensor:\n",
" "
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([False, True, True, False, True, False, False, True, False, True,\n",
" False, True, True, False, False, False, False, False, False, False,\n",
" False, True, True, False, False, False, True, False, False, True,\n",
" True, False, False, True, True, False, True, True, False, False,\n",
" True, False, True, False, False, False, False, False, False, False,\n",
" True, True, False, False, False, True, True, True, False, False,\n",
" True, False, True, False, True, False, True, False, False, True,\n",
" False, True, True, True, False, False, False, False, False, True,\n",
" False, True, False, True, True, True, False, True, False, True,\n",
" False, True, False, True, True, False, False, True, False, True])"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"index = torch.randint(1000, size=(100,), dtype=torch.long)\n",
"(gather_block(blks, blk_mask, index)**2).sum(dim=-1) != 0"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([8, 64, 8])"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"x = torch.randn(100, 8)\n",
"scatter_block(x, index, blk_mask, 64).size()"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([True, True, True, True, True, True, True, True])"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"t = scatter_block(x, index, blk_mask, 64)\n",
"t = t.view(-1, *t.shape[2:])\n",
"act_blks((t**2).sum(dim=-1) != 0, 64)"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.return_types.sort(\n",
"values=tensor([ 5, 15, 18, 19, 25, 48, 62, 64, 68, 76, 88, 92, 94, 106,\n",
" 114, 115, 120, 122, 126, 131, 135, 165, 169, 169, 177, 192, 214, 215,\n",
" 245, 249, 255, 269, 282, 286, 294, 299, 334, 358, 359, 372, 401, 403,\n",
" 430, 437, 455, 459, 459, 459, 471, 484, 490, 504, 505, 516, 528, 529,\n",
" 559, 584, 613, 614, 614, 624, 629, 637, 638, 674, 679, 701, 706, 735,\n",
" 742, 746, 748, 750, 759, 774, 774, 776, 778, 787, 803, 825, 831, 834,\n",
" 861, 862, 868, 877, 882, 889, 899, 909, 921, 933, 938, 954, 960, 969,\n",
" 981, 986]),\n",
"indices=tensor([ 6, 77, 14, 41, 43, 0, 13, 53, 38, 54, 15, 76, 98, 47, 75, 80, 8, 95,\n",
" 17, 34, 83, 87, 62, 93, 81, 20, 67, 39, 32, 5, 70, 4, 89, 72, 37, 64,\n",
" 94, 42, 57, 29, 56, 21, 60, 7, 45, 59, 86, 96, 49, 92, 28, 88, 24, 68,\n",
" 3, 61, 44, 2, 79, 55, 85, 33, 12, 1, 11, 97, 50, 9, 84, 40, 71, 22,\n",
" 26, 30, 73, 16, 90, 27, 19, 35, 48, 18, 46, 74, 65, 78, 52, 31, 23, 58,\n",
" 91, 66, 99, 69, 51, 36, 82, 63, 25, 10]))"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"index.sort()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "mpi",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.13"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}
# from .a2a import all_to_all, with_gloo, with_nccl
# from .cache import EmbeddingCache
# from .gather import Gather
# from .route import Route, GatherWork
from .route import Route, RouteWorkBase
# from .cache import MessageCache, NodeProbe
\ No newline at end of file
import torch
import torch.distributed as dist
from torch import Tensor
from typing import *
def all_to_all(
output_tensor_list: List[Tensor],
input_tensor_list: List[Tensor],
group: Optional[Any] = None,
):
assert len(output_tensor_list) == len(input_tensor_list)
if group is None:
group = dist.distributed_c10d._get_default_group()
backend = dist.get_backend(group)
if backend == "nccl":
dist.all_to_all(
output_tensor_list=output_tensor_list,
input_tensor_list=input_tensor_list,
group=group,
)
elif backend == "mpi":
dist.all_to_all(
output_tensor_list=output_tensor_list,
input_tensor_list=input_tensor_list,
group=group,
)
else:
assert backend == "gloo", f"backend must be nccl, mpi or gloo"
rank = dist.get_rank()
world_size = dist.get_world_size()
p2p_op_list: List[dist.P2POp] = []
for i in range(1, world_size):
send_i = (rank + i) % world_size
recv_i = (rank - i + world_size) % world_size
p2p_op_list.extend([
dist.P2POp(dist.isend, input_tensor_list[send_i], send_i, group=group),
dist.P2POp(dist.irecv, output_tensor_list[recv_i], recv_i, group=group),
])
p2p_work_list = dist.batch_isend_irecv(p2p_op_list)
output_tensor_list[rank][:] = input_tensor_list[rank]
# wait for all point-to-point sends/receives to complete before returning
for work in p2p_work_list:
    work.wait()
\ No newline at end of file
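For reference, a minimal usage sketch of the all_to_all helper above; it assumes torch.distributed has already been initialized (e.g. via torchrun) and that the function is in scope, and the demo name is hypothetical:

import torch
import torch.distributed as dist

def demo_all_to_all():
    # every rank sends one small tensor to every other rank
    rank, world_size = dist.get_rank(), dist.get_world_size()
    inputs = [torch.full((2,), float(rank)) for _ in range(world_size)]
    outputs = [torch.empty(2) for _ in range(world_size)]
    all_to_all(outputs, inputs)
    # afterwards outputs[i] holds the tensor sent by rank i
    print(rank, [float(t[0]) for t in outputs])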
# import torch
# from torch import Tensor
# from typing import *
# from multiprocessing.pool import ThreadPool
# from .event import Event
# class AsyncCopyWorkBase:
# def __init__(self) -> None:
# self._events: Optional[Tuple[Event, Event]] = None
# def wait(self):
# raise NotImplementedError
# def get(self) -> Tensor:
# raise NotImplementedError
# def has_events(self) -> bool:
# return self._events is not None
# def set_events(self, start, end):
# self._events = (start, end)
# return self
# def time_used(self) -> float:
# if self._events is None:
# raise RuntimeError("not found events")
# start, end = self._events
# return start.elapsed_time(end)
# class AsyncPushWork(AsyncCopyWorkBase):
# def __init__(self, data: Tensor, index: Tensor, values: Tensor) -> None:
# super().__init__()
# assert data.device == index.device
# self.set_events(
# Event(use_cuda=index.is_cuda),
# Event(use_cuda=index.is_cuda),
# )
# self._events[0].record()
# data.index_copy_(0, index, values)
# self._events[1].record()
# def wait(self):
# pass
# def get(self):
# pass
# class AsyncPullWork(AsyncCopyWorkBase):
# def __init__(self, data: Tensor, index: Tensor) -> None:
# super().__init__()
# assert data.device == index.device
# self.set_events(
# Event(use_cuda=index.is_cuda),
# Event(use_cuda=index.is_cuda),
# )
# self._events[0].record()
# self._val = data.index_select(0, index)
# self._events[1].record()
# def wait(self):
# pass
# def get(self) -> Tensor:
# return self._val
# class AsyncOffloadWork(AsyncCopyWorkBase):
# def __init__(self, handle) -> None:
# super().__init__()
# self._handle = handle
# def wait(self):
# if self._events is not None:
# self._events[1].wait(torch.cuda.current_stream())
# self._handle.wait()
# def get(self) -> Tensor:
# if self._events is not None:
# self._events[1].wait(torch.cuda.current_stream())
# return self._handle.get()
# class AsyncCopyExecutor:
# def __init__(self) -> None:
# self._stream = torch.cuda.Stream()
# self._executor = ThreadPool(processes=1)
# @torch.no_grad()
# def async_pull(self, data: Tensor, index: Tensor) -> AsyncCopyWorkBase:
# # Ideally all of this code would run in a dedicated thread, which makes timing easier and keeps it asynchronous
# if data.device != index.device:
# assert not data.is_cuda
# assert index.is_cuda
# start = Event(index.is_cuda)
# end = Event(index.is_cuda)
# stream: Optional[torch.cuda.Stream] = self._stream
# def run():
# start.wait(stream)
# with torch.cuda.stream(stream):
# idx = index.to(data.device)
# dst = torch.zeros(
# size=(index.size(0),) + data.shape[1:],
# dtype=data.dtype,
# device=index.device,
# )
# dst.copy_(data[idx])
# end.record()
# return dst
# start.record()
# handle = self._executor.apply_async(run)
# return AsyncOffloadWork(handle).set_events(start, end)
# else:
# # data and index in the same device
# return AsyncPullWork(data, index)
# @torch.no_grad()
# def async_push(self, data: Tensor, index: Tensor, values: Tensor) -> AsyncCopyWorkBase:
# assert index.device == values.device
# if data.device != index.device:
# assert not data.is_cuda
# assert index.is_cuda
# start = Event(index.is_cuda)
# end = Event(index.is_cuda)
# stream: Optional[torch.cuda.Stream] = self._stream
# def run():
# start.wait(stream)
# with torch.cuda.stream(stream):
# idx = index.to(data.device)
# val = values.to(data.device)
# data[idx] = val
# end.record()
# start.record()
# handle = self._executor.apply_async(run)
# return AsyncOffloadWork(handle).set_events(start, end)
# else:
# return AsyncPushWork(data, index, values)
# _THREAD_EXEC: Optional[AsyncCopyExecutor] = None
# def get_executor() -> AsyncCopyExecutor:
# global _THREAD_EXEC
# if _THREAD_EXEC is None:
# _THREAD_EXEC = AsyncCopyExecutor()
# return _THREAD_EXEC
import torch
import time
from torch import Tensor
from typing import *
class Event:
def __init__(self, use_cuda: bool = True) -> None:
self._use_cuda = use_cuda
if use_cuda:
self._event = torch.cuda.Event(enable_timing=True)
else:
self._time: Optional[float] = None
def record(self):
if self._use_cuda:
self._event.record()
else:
self._time = time.time()
def wait(self, stream: Optional[torch.cuda.Stream] = None):
# waiting is only meaningful for CUDA events; the CPU path returns immediately
if not self._use_cuda:
return
self._event.wait(stream)
def elapsed_time(self, other) -> float:
if self._use_cuda:
return self._event.elapsed_time(other._event)
else:
return (other._time - self._time) * 1000.0
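For reference, a minimal usage sketch of the Event timer above (CPU path shown; the CUDA path records CUDA events on the current stream instead), assuming the class is in scope:

# time a small CPU-side computation in milliseconds with an Event pair
start, end = Event(use_cuda=False), Event(use_cuda=False)
start.record()
_ = sum(i * i for i in range(100_000))
end.record()
print(f"elapsed: {start.elapsed_time(end):.3f} ms")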
# import torch
# import torch.nn as nn
# import torch.autograd as autograd
# from torch import Tensor
# from contextlib import contextmanager
# from typing import *
# from .route import Route, GatherWork
# # from .cache import CachedEmbeddings
# class GatherContext:
# def __init__(self,
# this,
# route: Route,
# async_op: bool,
# ) -> None:
# self.this = this
# self.route = route
# self.async_op = async_op
# class Gather(nn.Module):
# def __init__(self,
# num_nodes: int,
# num_features: Optional[int] = None,
# beta: float = 1.0,
# ) -> None:
# super().__init__()
# self.beta = beta
# if num_features is None:
# self.register_buffer("last_embd", torch.zeros(num_nodes))
# else:
# self.register_buffer("last_embd", torch.zeros(num_nodes, num_features))
# self.last_fw_work: Optional[GatherWork] = None
# self.last_bw_work: Optional[GatherWork] = None
# self.reset_parameters()
# def forward(self,
# val: Tensor,
# idx: Optional[Tensor],
# route: Route,
# async_op: bool = False,
# ) -> Tuple[Tensor, Tensor]:
# with self._manager(route=route, async_op=async_op):
# return GatherFunction.apply(val, idx)
# def reset_parameters(self):
# last_embd = self.get_buffer("last_embd")
# nn.init.normal_(last_embd, mean=0, std=1)
# def fuse_embeddings(self,
# val: Tensor,
# idx: Optional[Tensor] = None,
# inplace: bool = False,
# ) -> Tensor:
# last_embd = self.get_buffer("last_embd")
# return GatherFuseFunction.apply(val, idx, last_embd, self.beta, inplace, self.training)
# @contextmanager
# def _manager(self, route: Route, async_op: bool):
# global _global_gather_context
# stacked = _global_gather_context
# try:
# _global_gather_context = GatherContext(
# this=self,
# route=route,
# async_op=async_op,
# )
# yield _global_gather_context
# finally:
# _global_gather_context = stacked
# class GatherFunction(autograd.Function):
# @staticmethod
# def forward(
# ctx: autograd.function.FunctionCtx,
# val: Tensor,
# idx: Optional[Tensor] = None,
# ):
# gather_ctx = _last_global_gather_context()
# this: Gather = gather_ctx.this
# route: Route = gather_ctx.route
# async_op: bool = gather_ctx.async_op
# return_idx: bool = idx is not None
# current_work = route.gather_forward(val, idx, async_op=async_op, return_idx=return_idx)
# if async_op:
# work = this.last_fw_work or current_work
# this.last_fw_work = current_work
# else:
# work = current_work
# recv_val, recv_idx = work.get()
# recv_val = recv_val if recv_idx is None else recv_val[recv_idx]
# recv_val = this.fuse_embeddings(recv_val, recv_idx, inplace=True)
# if this.training:
# ctx.save_for_backward(idx, recv_idx)
# ctx.this = this
# ctx.route = route
# ctx.async_op = async_op
# return recv_val, recv_idx
# @staticmethod
# def backward(
# ctx: autograd.function.FunctionCtx,
# val_grad: Tensor,
# idx_grad: Optional[Tensor],
# ):
# this: Gather = ctx.this
# route: Route = ctx.route
# async_op: bool = ctx.async_op
# with torch.no_grad():
# recv_idx, idx_grad = ctx.saved_tensors
# if idx_grad is not None:
# val_grad = val_grad[idx_grad]
# current_work = route.gather_backward(val_grad, idx_grad, async_op=async_op, return_idx=False)
# if async_op:
# work = this.last_bw_work or current_work
# this.last_bw_work = current_work
# else:
# work = current_work
# recv_val = work.get_val()
# if recv_idx is not None:
# recv_val = recv_val[recv_idx]
# return recv_val, None
# class GatherFuseFunction(autograd.Function):
# @staticmethod
# def forward(
# ctx: autograd.function.FunctionCtx,
# val: Tensor,
# idx: Optional[Tensor],
# last_embd: Tensor,
# beta: float,
# inplace: bool,
# training: bool,
# ):
# if not inplace:
# last_embd = last_embd.clone()
# ctx.beta = beta
# if idx is None:
# assert val.size(0) == last_embd.size(0)
# if beta != 1.0:
# last_embd.mul_(1 - beta).add_(val * beta)
# else:
# last_embd[:] = (val)
# else:
# assert val.size(0) == idx.size(0)
# if beta != 1.0:
# last_embd[idx] = last_embd[idx] * (1 - beta) + val * beta
# else:
# last_embd[idx] = val
# if training:
# ctx.beta = beta
# ctx.save_for_backward(idx)
# return last_embd
# @staticmethod
# def backward(
# ctx: autograd.function.FunctionCtx,
# grad: Tensor,
# ):
# beta: float = ctx.beta
# idx, = ctx.saved_tensors
# if idx is not None:
# grad = grad[idx]
# if beta != 1.0:
# grad = grad * beta
# return grad, None, None, None, None, None
# #### private functions
# _global_gather_context: Optional[GatherContext] = None
# def _last_global_gather_context() -> GatherContext:
# global _global_gather_context
# assert _global_gather_context is not None
# return _global_gather_context
# import torch
# from torch import Tensor
# from typing import *
# from multiprocessing.pool import ThreadPool
# from .event import Event
# class AsyncCopyWorkBase:
# def __init__(self) -> None:
# self._events: Optional[Tuple[Event, Event]] = None
# def wait(self):
# raise NotImplementedError
# def get(self) -> Tensor:
# raise NotImplementedError
# def has_events(self) -> bool:
# return self._events is not None
# def set_events(self, start, end):
# self._events = (start, end)
# return self
# def time_used(self) -> float:
# if self._events is None:
# raise RuntimeError("not found events")
# start, end = self._events
# return start.elapsed_time(end)
# class AsyncPushWork(AsyncCopyWorkBase):
# def __init__(self, data: Tensor, index: Tensor, values: Tensor) -> None:
# super().__init__()
# assert data.device == index.device
# self.set_events(
# Event(use_cuda=index.is_cuda),
# Event(use_cuda=index.is_cuda),
# )
# self._events[0].record()
# data.index_copy_(0, index, values)
# self._events[1].record()
# def wait(self):
# pass
# def get(self):
# pass
# class AsyncPullWork(AsyncCopyWorkBase):
# def __init__(self, data: Tensor, index: Tensor) -> None:
# super().__init__()
# assert data.device == index.device
# self.set_events(
# Event(use_cuda=index.is_cuda),
# Event(use_cuda=index.is_cuda),
# )
# self._events[0].record()
# self._val = data.index_select(0, index)
# self._events[1].record()
# def wait(self):
# pass
# def get(self) -> Tensor:
# return self._val
# class AsyncOffloadWork(AsyncCopyWorkBase):
# def __init__(self, handle) -> None:
# super().__init__()
# self._handle = handle
# def wait(self):
# if self._events is not None:
# self._events[1].wait(torch.cuda.current_stream())
# self._handle.wait()
# def get(self) -> Tensor:
# if self._events is not None:
# self._events[1].wait(torch.cuda.current_stream())
# return self._handle.get()
# class AsyncCopyExecutor:
# def __init__(self) -> None:
# self._stream = torch.cuda.Stream()
# self._executor = ThreadPool(processes=1)
# @torch.no_grad()
# def async_pull(self, data: Tensor, index: Tensor) -> AsyncCopyWorkBase:
# # Ideally all of this code would run in a dedicated thread, which makes timing easier and keeps it asynchronous
# if data.device != index.device:
# assert not data.is_cuda
# assert index.is_cuda
# start = Event(index.is_cuda)
# end = Event(index.is_cuda)
# stream: Optional[torch.cuda.Stream] = self._stream
# def run():
# start.wait(stream)
# with torch.cuda.stream(stream):
# idx = index.to(data.device)
# dst = torch.zeros(
# size=(index.size(0),) + data.shape[1:],
# dtype=data.dtype,
# device=index.device,
# )
# dst.copy_(data[idx])
# end.record()
# return dst
# start.record()
# handle = self._executor.apply_async(run)
# return AsyncOffloadWork(handle).set_events(start, end)
# else:
# # data and index in the same device
# return AsyncPullWork(data, index)
# @torch.no_grad()
# def async_push(self, data: Tensor, index: Tensor, values: Tensor) -> AsyncCopyWorkBase:
# assert index.device == values.device
# if data.device != index.device:
# assert not data.is_cuda
# assert index.is_cuda
# start = Event(index.is_cuda)
# end = Event(index.is_cuda)
# stream: Optional[torch.cuda.Stream] = self._stream
# def run():
# start.wait(stream)
# with torch.cuda.stream(stream):
# idx = index.to(data.device)
# val = values.to(data.device)
# data[idx] = val
# end.record()
# start.record()
# handle = self._executor.apply_async(run)
# return AsyncOffloadWork(handle).set_events(start, end)
# else:
# return AsyncPushWork(data, index, values)
# class Lache:
# def __init__(self,
# cache_size: int,
# data: Tensor,
# ) -> None:
# assert not data.is_cuda, "data must be in CPU Memory"
# cache_size = (cache_size,) + data.shape[1:]
# self.fdata = data
# self.cache = torch.zeros(cache_size, dtype=data.dtype)
# self.no_idx: int = (2**62-1)*2+1
# self.cached_idx = torch.empty(cache_size, dtype=torch.long).fill_(self.no_idx)
# self.read_count = torch.zeros(data.size(0), dtype=torch.long)
# def to(self, device):
# self.cache = self.cache.to(device)
# self.cached_idx = self.cached_idx.to(device)
# self.read_count = self.read_count.to(device)
# return self
# def _push_impl(self, value: Tensor, index: Tensor):
# # self.read_count[index] += 1
# imp = torch.zeros_like(self.read_count)
# def _pull_impl(self, index: Tensor):
# assert index.device == self.cache.device
# self.read_count[index] += 1
# s = index.shape[:1] + self.cache.shape[1:]
# x = torch.empty(s, dtype=self.cache.dtype, device=self.cache.device)
# cache_mask = torch.zeros_like(self.read_count, dtype=torch.bool).index_fill_(0, self.cached_idx, 1)[index]
# cache_index = index[cache_mask]
# x[cache_mask] =
# no_cache_index = index[~cache_mask]
# rt_index = index[~lc_mask].to(self.fdata.device)
# rt_data = self.fdata[rt_index].to(index.device)
# _THREAD_EXEC: Optional[AsyncCopyExecutor] = None
# def get_executor() -> AsyncCopyExecutor:
# global _THREAD_EXEC
# if _THREAD_EXEC is None:
# _THREAD_EXEC = AsyncCopyExecutor()
# return _THREAD_EXEC
# import torch
# from torch import Tensor
# from typing import *
# class ShrinkData:
# no_idx: int = (2**62-1)*2+1
# def __init__(self,
# src_size: int,
# dst_size: int,
# dst_idx: Tensor,
# edge_index: Tensor,
# bipartite: bool = False,
# ) -> None:
# device = dst_idx.device
# tmp = torch.empty(max(src_size, dst_size), dtype=torch.bool, device=device)
# tmp.fill_(0)
# tmp.index_fill_(0, dst_idx, 1)
# edge_idx = torch.where(tmp[edge_index[1]])[0]
# edge_index = edge_index[:, edge_idx]
# if bipartite:
# tmp.fill_(0)
# tmp.index_fill_(0, edge_index[0], 1)
# src_idx = torch.where(tmp)[0]
# imp = torch.empty(max(src_size, dst_size), dtype=torch.long, device=device)
# imp[dst_idx] = torch.arange(dst_idx.size(0), dtype=torch.long, device=device)
# dst = imp[edge_index[1]]
# imp.fill_(self.no_idx)
# imp[src_idx] = torch.arange(src_idx.size(0), dtype=torch.long, device=device)
# src = imp[edge_index[0]]
# edge_index = torch.vstack([src, dst])
# else:
# tmp.index_fill_(0, edge_index[0], 1)
# tmp.index_fill_(0, dst_idx, 0)
# src_idx = torch.cat([dst_idx, torch.where(tmp)[0]], dim=0)
# imp = torch.empty(max(src_size, dst_size), dtype=torch.long, device=device)
# imp.fill_(self.no_idx)
# imp[src_idx] = torch.arange(src_idx.size(0), dtype=torch.long, device=device)
# edge_index = imp[edge_index.flatten()].view_as(edge_index)
# self.src_idx = src_idx
# self.dst_idx = dst_idx
# self.edge_idx = edge_idx
# self.edge_index = edge_index
# self._src_imp = imp[:src_size]
# def to(self, device):
# self.src_idx = self.src_idx.to(device)
# self.dst_idx = self.dst_idx.to(device)
# self.edge_idx = self.edge_idx.to(device)
# self.edge_index = self.edge_index.to(device)
# self._src_imp = self._src_imp.to(device)
# return self
# def shrink_src_val_and_idx(self, val: Tensor, idx: Tensor) -> Tuple[Tensor, Tensor]:
# idx = self._src_imp[idx]
# m = (idx != self.no_idx)
# return val[m], idx[m]
# @property
# def src_size(self) -> int:
# return self.src_idx.size(0)
# @property
# def dst_size(self) -> int:
# return self.dst_idx.size(0)
# @property
# def edge_size(self) -> int:
# return self.edge_idx.size(0)
# import torch
# import torch.nn as nn
# import torch.autograd as autograd
# from torch import Tensor
# from contextlib import contextmanager
# from typing import *
# class StraightContext:
# def __init__(self, this, g) -> None:
# self.this = this
# self.g = g
# class Straight(nn.Module):
# def __init__(self,
# num_nodes: int,
# num_samples: int,
# norm_kwargs: Optional[Dict[str, Any]] = None,
# beta: float = 1.0,
# prev: Optional[List[Any]] = None,
# ) -> None:
# super().__init__()
# assert num_samples <= num_nodes
# self.num_nodes = num_nodes
# self.num_samples = num_samples
# self.norm_kwargs = norm_kwargs or dict(p=2, dim=-1)
# self.beta = beta
# self.prev = prev
# self.register_buffer("last_w", torch.ones(num_nodes))
# self._next_idx = None
# self._next_shrink_helper = None
# self.reset_parameters()
# def reset_parameters(self):
# last_w = self.get_buffer("last_w")
# nn.init.constant_(last_w, 1.0)
# def forward(self,
# val: Tensor,
# idx: Optional[Tensor],
# g,
# ) -> Tensor:
# with self._manager(self, g):
# return StraightFunction.apply(val, idx)
# def pop_next_shrink_helper(self) -> Tuple[Optional[Tensor], Any]:
# if not self.training:
# return None, None
# next_idx = self._next_idx
# self._next_idx = None
# next_sh = self._next_shrink_helper
# self._next_shrink_helper = None
# return next_idx, next_sh
# def _sample_next(self) -> Tensor:
# w = self.get_buffer("last_w")
# if self._next_idx is None:
# if self.num_samples < w.size(0):
# self._next_idx = self.sample_impl(w)
# else:
# self._next_idx = torch.arange(w.size(0), dtype=torch.long, device=w.device)
# elif self.num_samples < self._next_idx.size(0):
# idx = self.sample_impl(w[self._next_idx])
# self._next_idx = self._next_idx[idx]
# return self._next_idx
# def sample_impl(self, w: Tensor) -> Tensor:
# w = w / w.sum()
# return torch.multinomial(w, num_samples=self.num_samples, replacement=False)
# # def multinomial(self,
# # num_samples: int,
# # replacement: bool = False,
# # ) -> Tensor:
# # w = self.get_buffer("last_w")
# # if num_samples <= 0:
# # return torch.arange(self.num_nodes, dtype=torch.long, device=w.device)
# # w = w / w.sum()
# # return torch.multinomial(w, num_samples=num_samples, replacement=replacement)
# # def multinomial_mask(self,
# # num_samples: int,
# # replacement: bool = False,
# # ) -> Tensor:
# # w = self.get_buffer("last_w")
# # if num_samples <= 0:
# # return torch.ones(self.num_nodes, dtype=torch.bool, device=w.device)
# # w = self.multinomial(num_samples, replacement)
# # m = torch.zeros(self.num_nodes, dtype=torch.bool, device=w.device)
# # m[w] = True
# # return m
# @contextmanager
# def _manager(self, this, g):
# global _global_straight_context
# stacked = _global_straight_context
# try:
# _global_straight_context = StraightContext(this=this, g=g)
# yield _global_straight_context
# finally:
# _global_straight_context = stacked
# class StraightFunction(autograd.Function):
# @staticmethod
# def forward(
# ctx: autograd.function.FunctionCtx,
# val: Tensor,
# idx: Optional[Tensor],
# ):
# from ..graph import DistGraph
# stx = _last_global_straight_context()
# this: Straight = stx.this
# g: DistGraph = stx.g
# last_w = this.get_buffer("last_w")
# if idx is None:
# assert val.size(0) == last_w.size(0)
# else:
# assert val.size(0) == idx.size(0)
# if this.training:
# ctx.this = this
# ctx.g = g
# ctx.save_for_backward(idx)
# return val
# @staticmethod
# def backward(
# ctx: autograd.function.FunctionCtx,
# grad: Tensor,
# ):
# from ..graph import DistGraph
# this: Straight = ctx.this
# g: DistGraph = ctx.g
# idx, = ctx.saved_tensors
# last_w = this.get_buffer("last_w")
# if this.beta != 1.0:
# last_w.mul_(this.beta)
# norm = grad.norm(**this.norm_kwargs)
# if idx is None:
# last_w[:] = norm
# else:
# last_w[idx] = norm
# if this.prev is not None or this._next_idx is None:
# from ..nn.convs.utils import ShrinkHelper
# dst_idx = this._sample_next()
# if this.prev is not None and this._next_idx is not None:
# this._next_shrink_helper = ShrinkHelper(g, dst_idx)
# src_idx = this._next_shrink_helper.src_idx
# work = g.route.gather_backward(
# src_idx, src_idx, async_op=False, return_idx=True)
# prev_dst_idx = work.get_idx()
# for p in this.prev:
# assert isinstance(p, Straight)
# p._next_idx = prev_dst_idx
# return grad, None
# #### private functions
# _global_straight_context: Optional[StraightContext] = None
# def _last_global_straight_context() -> StraightContext:
# global _global_straight_context
# assert _global_straight_context is not None
# return _global_straight_context
# if __name__ == "__main__":
# s = Straight(3, beta=1.1)
# x = torch.rand(3, 10).requires_grad_()
# m = torch.tensor([0, 1, 0], dtype=torch.bool)
# s(x).sum().backward()
# print(s.grad_norm)
# print(s.multinomial(2))
# s(x, m).sum().backward()
# print(s.grad_norm)
# print(s.multinomial(2))
# s(x[m], m).sum().backward()
# print(s.grad_norm)
# print(s.multinomial(2))
# print(s.multinomial_mask(2))
\ No newline at end of file
# import torch
# from torch import Tensor
# from typing import *
# from torch_scatter import scatter_sum
# from starrygl.core.route import Route
# def compute_in_degree(edge_index: Tensor, route: Route) -> Tensor:
# dst_size = route.src_size
# x = torch.ones(edge_index.size(1), dtype=torch.long, device=edge_index.device)
# in_deg = scatter_sum(x, edge_index[1], dim=0, dim_size=dst_size)
# in_deg, _ = route.forward_a2a(in_deg)
# return in_deg
# def compute_out_degree(edge_index: Tensor, route: Route) -> Tensor:
# src_size = route.dst_size
# x = torch.ones(edge_index.size(1), dtype=torch.long, device=edge_index.device)
# out_deg = scatter_sum(x, edge_index[0], dim=0, dim_size=src_size)
# out_deg, _ = route.backward_a2a(out_deg)
# out_deg, _ = route.forward_a2a(out_deg)
# return out_deg
# def compute_gcn_norm(edge_index: Tensor, route: Route) -> Tensor:
# in_deg = compute_in_degree(edge_index, route)
# out_deg = compute_out_degree(edge_index, route)
# a = in_deg[edge_index[0]].pow(-0.5)
# b = out_deg[edge_index[0]].pow(-0.5)
# x = a * b
# x[x.isinf()] = 0.0
# x[x.isnan()] = 0.0
# return x
# import torch
# import torch.nn as nn
# from torch import Tensor
# from ..graph import DistGraph
# from typing import *
# from .metrics import *
# def train_epoch(
# model: nn.Module,
# opt: torch.optim.Optimizer,
# g: DistGraph,
# mask: Optional[Tensor] = None,
# ) -> float:
# model.train()
# criterion = nn.CrossEntropyLoss()
# pred: Tensor = model(g)
# targ: Tensor = g.ndata["y"]
# if mask is not None:
# pred = pred[mask]
# targ = targ[mask]
# loss: Tensor = criterion(pred, targ)
# opt.zero_grad()
# loss.backward()
# opt.step()
# with torch.no_grad():
# train_loss = all_reduce_loss(loss, targ.size(0))
# train_acc = accuracy(pred.argmax(dim=-1), targ)
# return train_loss, train_acc
# @torch.no_grad()
# def eval_epoch(
# model: nn.Module,
# g: DistGraph,
# mask: Optional[Tensor] = None,
# ) -> Tuple[float, float]:
# model.eval()
# criterion = nn.CrossEntropyLoss()
# pred: Tensor = model(g)
# targ: Tensor = g.ndata["y"]
# if mask is not None:
# pred = pred[mask]
# targ = targ[mask]
# loss = criterion(pred, targ)
# eval_loss = all_reduce_loss(loss, targ.size(0))
# eval_acc = accuracy(pred.argmax(dim=-1), targ)
# return eval_loss, eval_acc
\ No newline at end of file
# import torch
# import torch.nn as nn
# from torch import Tensor
# from typing import *
# from contextlib import contextmanager
# # import torch_sparse
# from .ndata import NData
# from .edata import EData
# from .utils import init_local_edge_index
# from ..core import MessageCache, Route
# class DistGraph:
# def __init__(self,
# ids: Tensor,
# edge_index: Tensor,
# num_features: int,
# num_layers: int,
# cache_device: str = "cpu",
# **args: Dict[str, Any],
# ):
# # build local_edge_index
# dst_ids = ids
# src_ids, local_edge_index = init_local_edge_index(
# dst_ids=dst_ids,
# edge_index=edge_index,
# )
# self._src_ids = src_ids
# self._dst_ids = dst_ids
# self._message_cache = MessageCache(
# src_ids=src_ids,
# dst_ids=dst_ids,
# edge_index=local_edge_index,
# num_features=num_features,
# num_layers=num_layers,
# cache_device=cache_device,
# bipartite=False,
# )
# # node's attributes
# self.ndata = NData(
# src_size=src_ids.size(0),
# dst_size=dst_ids.size(0),
# )
# # edge's attributes
# self.edata = EData(
# edge_size=local_edge_index.size(1),
# )
# # graph's attributes
# self.args = dict(args)
# def to(self, device):
# self._message_cache.to(device)
# return self
# def cache_data_to(self, device):
# self._message_cache.cached_data_to(device)
# @property
# def cache(self) -> MessageCache:
# return self._message_cache
# @property
# def route(self) -> Route:
# return self._message_cache.route
# @property
# def device(self) -> torch.device:
# return self.edge_index.device
# @property
# def edge_index(self) -> Tensor:
# return self._message_cache.edge_index
# @property
# def src_ids(self) -> Tensor:
# if self._src_ids.device != self.device:
# self._src_ids = self._src_ids.to(self.device)
# return self._src_ids
# @property
# def src_size(self) -> int:
# return self._src_ids.size(0)
# @property
# def dst_ids(self) -> Tensor:
# if self._dst_ids.device != self.device:
# self._dst_ids = self._dst_ids.to(self.device)
# return self._dst_ids
# @property
# def dst_size(self) -> int:
# return self._dst_ids.size(0)
# @contextmanager
# def scoped_manager(self):
# stacked_ndata = self.ndata
# stacked_edata = self.edata
# try:
# self.ndata = NData(
# p=stacked_ndata,
# )
# self.edata = EData(
# p=stacked_edata,
# )
# yield self
# finally:
# self.ndata = stacked_ndata
# self.edata = stacked_edata
# # def permute_edge_(self):
# # perm = self._local_edge_index[1].argsort()
# # self._local_edge_index = self._local_edge_index[:,perm]
# # self._local_edge_ptr = torch.ops.torch_sparse.ind2ptr(self._local_edge_index[1], self.dst_size)
# # self.edata.permute_(perm)
# # @torch.no_grad()
# # def shrink(self,
# # src_mask: Optional[Tensor] = None,
# # dst_mask: Optional[Tensor] = None,
# # ):
# # edge_index = self.edge_index
# # device = edge_index.device
# # ikw = dict(dtype=torch.long, device=device)
# # bkw = dict(dtype=torch.bool, device=device)
# # if src_mask is None and dst_mask is None:
# # return self
# # else:
# # if dst_mask is None:
# # dst_mask = torch.zeros(self.ndata.dst_size, **bkw)
# # else:
# # assert dst_mask.size(0) == self.ndata.dst_size
# # dst_mask = dst_mask.clone()
# # if src_mask is not None:
# # m = src_mask[edge_index[0]]
# # m = edge_index[1][m]
# # dst_mask[m] = True
# # edge_mask = dst_mask[edge_index[1]]
# # edge_index = edge_index[:,edge_mask]
# # src_mask = torch.zeros(self.ndata.src_size, **bkw)
# # src_mask[:self.ndata.dst_size] = dst_mask
# # src_mask[edge_index[0]] = True
# # dst_mask = src_mask[:self.ndata.dst_size]
# # renumber edge_index into local indices
# # imp = torch.empty(self.ndata.src_size, **ikw).fill_((2**62-1)*2+1)
# # idx = torch.where(src_mask)[0]
# # imp[idx] = torch.arange(idx.size(0), **ikw)
# # edge_index = imp[edge_index.flatten()].view_as(edge_index)
# # # assert dst_mask.count_nonzero() > edge_index[1].max().item()
# # ndata = NData(
# # src_size=None,
# # dst_size=None,
# # p=self.ndata,
# # src_mask=src_mask,
# # dst_mask=dst_mask,
# # )
# # edata = EData(
# # edge_size=None,
# # p=self.edata,
# # edge_mask=edge_mask,
# # )
# # edata[EID] = edge_index
# # return DistGraph(self.args, ndata, edata, self.route)
# # class ShrinkGraph:
# # def __init__(self,
# # g: DistGraph,
# # src_mask: Optional[Tensor] = None,
# # dst_mask: Optional[Tensor] = None,
# # ) -> None:
# # device = g.edge_index.device
# # ikw = dict(dtype=torch.long, device=device)
# # bkw = dict(dtype=torch.bool, device=device)
# # if src_mask is None and dst_mask is None:
# # If neither src_mask nor dst_mask is given,
# # ShrinkGraph is just a copy of DistGraph
# # self.src_ids = g.sr
# # self.dst_size = g.dst_size
# # self.src_mask = torch.ones(self.src_size, **bkw)
# # self.dst_mask_size = g.dst_size
# # self.edge_index = g.edge_index
# # self.pgraph = g
# # else:
# # if src_mask is not None:
# # tmp_mask = torch.zeros(g.dst_size, **bkw)
# # compute the directly activated edges
# # m = src_mask[g.edge_index[0]]
# # m = g.edge_index[1][m]
# # compute the directly activated dst_ids
# # tmp_mask[m] = True
# # merge with the already activated dst_ids
# # if dst_mask is not None:
# # tmp_mask |= dst_mask
# # dst_mask = tmp_mask
# # compute the indirectly activated edges
# # edge_mask = dst_mask[g.edge_index[1]]
# # edge_index = g.edge_index[:,edge_mask]
# # compute the indirectly activated src_ids
# # src_mask = torch.zeros(g.src_size, **bkw)
# # src_mask[edge_index[0]] = True
# # src_mask[:g.dst_size] |= dst_mask
# # self.src_ids = g.src_ids[src_mask]
# # self.dst_size = src_mask[:g.dst_size].count_nonzero().item()
# # self.src_mask = src_mask
# # self.dst_mask_size = g.dst_size
# # renumber edge_index into local indices
# # imp = torch.empty(g.src_size, **ikw).fill_((2**62-1)*2+1)
# # idx = torch.where(src_mask)[0]
# # imp[idx] = torch.arange(idx.size(0), **ikw)
# # edge_index = imp[edge_index.flatten()].view_as(edge_index)
# # self.edge_index = edge_index
# # self.pgraph = g
# # self.ndata = ShrinkData(self.dst_mask, g.ndata)
# # self.edata = ShrinkData(self.edge_mask, g.edata)
\ No newline at end of file
# import torch
# from torch import Tensor
# from typing import *
# class EData:
# def __init__(self,
# edge_size: Optional[int] = None,
# p = None,
# ) -> None:
# if p is None:
# assert edge_size is not None
# self.edge_size = edge_size
# else:
# assert edge_size is None
# self.edge_size = p.edge_size
# self.prev_data = p
# self.data: Dict[str, Tensor] = {}
# def __getitem__(self, name: str) -> Tensor:
# t = self.get(name)
# if t is None:
# raise ValueError(f"not found '{name}' in data")
# return t
# def __setitem__(self, name: str, tensor: Tensor):
# if not isinstance(tensor, Tensor):
# raise ValueError(f"the second parameter's type must be Tensor")
# if tensor.size(0) == self.edge_size:
# self.data[name] = tensor
# else:
# raise ValueError(f"tensor's shape must match the edge_size")
# def __delitem__(self, name: str) -> None:
# self.pop(name)
# def get(self, name: str) -> Optional[Tensor]:
# p, t = self, None
# while p is not None:
# t = p.data.get(name)
# if t is not None:
# break
# p = p.prev_data
# return t
# def pop(self, name: str) -> Tensor:
# return self.data.pop(name)
# def permute_(self, perm: Tensor):
# p = self
# while p is not None:
# for key in list(p.data.keys()):
# val = p.data.get(key)
# p.data[key] = val[perm]
# p = p.prev_data
\ No newline at end of file
# import torch
# from torch import Tensor
# from typing import *
# class NData:
# def __init__(self,
# src_size: Optional[int] = None,
# dst_size: Optional[int] = None,
# p = None,
# ) -> None:
# if p is None:
# assert src_size is not None
# assert dst_size is not None
# self.src_size = src_size
# self.dst_size = dst_size
# else:
# assert src_size is None
# assert dst_size is None
# self.src_size = p.src_size
# self.dst_size = p.dst_size
# self.prev_data = p
# self.data: Dict[str, Tensor] = {}
# def __getitem__(self, name: str) -> Tensor:
# t = self.get(name)
# if t is None:
# raise ValueError(f"not found '{name}' in data")
# return t
# def __setitem__(self, name: str, tensor: Tensor):
# if not isinstance(tensor, Tensor):
# raise ValueError("the second parameter's type must be Tensor")
# if tensor.size(0) == self.src_size:
# self.data[name] = tensor
# elif tensor.size(0) == self.dst_size:
# self.data[name] = tensor
# else:
# raise ValueError("tensor's shape must match the src_size or dst_size")
# def __delitem__(self, name: str) -> None:
# self.pop(name)
# def get(self, name: str) -> Optional[Tensor]:
# p, t = self, None
# while p is not None:
# t = p.data.get(name)
# if t is not None:
# break
# p = p.prev_data
# return t
# def pop(self, name: str) -> Tensor:
# return self.data.pop(name)
# def permute_(self, perm: Tensor):
# p = self
# while p is not None:
# for key in list(p.data.keys()):
# val = p.data.get(key)
# p.data[key] = val[perm]
# p = p.prev_data
# def get_type(self, key: Union[str, Tensor]):
# if isinstance(key, Tensor):
# if key.size(0) == self.src_size:
# return "src"
# elif key.size(0) == self.dst_size:
# return "dst"
# else:
# raise RuntimeError
# t = self.__getitem__(key)
# if t.size(0) == self.src_size:
# return "src"
# elif t.size(0) == self.dst_size:
# return "dst"
# else:
# raise RuntimeError
\ No newline at end of file
import torch
import torch.nn as nn
import torch.distributed as dist
import torch.distributed.rpc as rpc
from torch.futures import Future
from torch import Tensor
from typing import *
from starrygl.parallel import all_gather_remote_objects
from .route import Route
from threading import Lock
class TensorBuffer(nn.Module):
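# A partition-local tensor store with RPC access to the buffers of other ranks.
# Each rank holds a (num_nodes, channels) buffer for the nodes it owns; the
# remote_* / all_remote_* helpers use torch.distributed.rpc and the Route's
# partition mapping to read from or write into rows owned by other ranks.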
def __init__(self,
channels: int,
num_nodes: int,
route: Route,
) -> None:
super().__init__()
self.channels = channels
self.num_nodes = num_nodes
self.route = route
self.local_lock = Lock()
self.register_buffer("_data", torch.zeros(num_nodes, channels), persistent=False)
self.rrefs = all_gather_remote_objects(self)
@property
def data(self) -> Tensor:
return self.get_buffer("_data")
@property
def device(self) -> torch.device:
return self.data.device
def local_get(self, index: Optional[Tensor] = None, lock: bool = True) -> Tensor:
if lock:
with self.local_lock:
return self.local_get(index, lock=False)
if index is None:
return self.data
else:
return self.data[index]
def local_set(self, value: Tensor, index: Optional[Tensor] = None, lock: bool = True):
if lock:
with self.local_lock:
return self.local_set(value, index, lock=False)
if index is None:
self.data.copy_(value)
else:
# value = value.to(self.device)
self.data[index] = value
def local_add(self, value: Tensor, index: Optional[Tensor] = None, lock: bool = True):
if lock:
with self.local_lock:
return self.local_add(value, index, lock=False)
if index is None:
self.data.add_(value)
else:
# value = value.to(self.device)
self.data[index] += value
def local_cls(self, index: Optional[Tensor] = None, lock: bool = True):
if lock:
with self.local_lock:
return self.local_cls(index, lock=False)
if index is None:
self.data.zero_()
else:
self.data[index] = 0
def remote_get(self, dst: int, index: Tensor, lock: bool = True):
return TensorBuffer._remote_call(TensorBuffer.local_get, self.rrefs[dst], index=index, lock=lock)
def remote_set(self, dst: int, value: Tensor, index: Tensor, lock: bool = True):
return TensorBuffer._remote_call(TensorBuffer.local_set, self.rrefs[dst], value, index=index, lock=lock)
def remote_add(self, dst: int, value: Tensor, index: Tensor, lock: bool = True):
return TensorBuffer._remote_call(TensorBuffer.local_add, self.rrefs[dst], value, index=index, lock=lock)
def all_remote_get(self, index: Tensor, lock: bool = True):
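# Split `index` by owning rank, fetch every shard asynchronously, and assemble
# the replies back into a single (index.size(0), channels) buffer.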
def cb0(idx):
def f(x: torch.futures.Future[Tensor]):
return x.value(), idx
return f
def cb1(buf):
def f(xs: torch.futures.Future[List[torch.futures.Future]]) -> Tensor:
for x in xs.value():
dat, idx = x.value()
# print(dat.size(), idx.size())
buf[idx] += dat
return buf
return f
futs = []
for i, (idx, remote_idx) in enumerate(self.route.parts_iter(index)):
futs.append(self.remote_get(i, remote_idx, lock=lock).then(cb0(idx)))
futs = torch.futures.collect_all(futs)
buf = torch.zeros(index.size(0), self.channels, dtype=self.data.dtype, device=self.data.device)
return futs.then(cb1(buf))
def all_remote_set(self, value: Tensor, index: Tensor, lock: bool = True):
futs = []
for i, (idx, remote_idx) in enumerate(self.route.parts_iter(index)):
futs.append(self.remote_set(i, value[idx], remote_idx, lock=lock))
return torch.futures.collect_all(futs)
def all_remote_add(self, value: Tensor, index: Tensor, lock: bool = True):
futs = []
for i, (idx, remote_idx) in enumerate(self.route.parts_iter(index)):
futs.append(self.remote_add(i, value[idx], remote_idx, lock=lock))
return torch.futures.collect_all(futs)
def broadcast(self, barrier: bool = True):
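# Pull every node's row from the rank that owns it so the local buffer holds a
# complete copy; barriers keep all ranks in step before and after the exchange.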
if barrier:
dist.barrier()
index = torch.arange(self.num_nodes, dtype=torch.long, device=self.data.device)
data = self.all_remote_get(index, lock=True).wait()
self.local_set(data, lock=True)
if barrier:
dist.barrier()
# def remote_get(self, dst: int, i: int, index: Optional[Tensor] = None, global_index: bool = False, async_op: bool = False) -> Union[Tensor, Future]:
# return TensorBuffer._remote_call(async_op, TensorBuffer.local_get, self.rrefs[dst], i, index, global_index = global_index)
# def remote_set(self, dst: int, i: int, value: Tensor, index: Optional[Tensor], global_index: bool = False, async_op: bool = False) -> Optional[Future]:
# return TensorBuffer._remote_call(async_op, TensorBuffer.local_set, self.rrefs[dst], i, value, index, global_index = global_index)
# def remote_add(self, dst: int, i: int, value: Tensor, index: Optional[Tensor] = None, global_index: bool = False, async_op: bool = False) -> Optional[Future]:
# return TensorBuffer._remote_call(async_op, TensorBuffer.local_add, self.rrefs[dst], i, value, index, global_index = global_index)
# def async_scatter_fw_set(self, i: int, value: Tensor, index: Optional[Tensor] = None) -> Tuple[Future]:
# futures: List[Future] = []
# for dst in range(self.world_size):
# val, ind = self.router.fw_value_index(dst, value, index)
# futures.append(self.remote_set(dst, i, val, ind, global_index=True, async_op=True))
# return tuple(futures)
# def async_scatter_fw_add(self, i: int, value: Tensor, index: Optional[Tensor] = None) -> Tuple[Future]:
# futures: List[Future] = []
# for dst in range(self.world_size):
# val, ind = self.router.fw_value_index(dst, value, index)
# futures.append(self.remote_add(dst, i, val, ind, global_index=True, async_op=True))
# return tuple(futures)
# def async_scatter_bw_set(self, i: int, value: Tensor, index: Optional[Tensor] = None) -> Tuple[Future]:
# futures: List[Future] = []
# for dst in range(self.world_size):
# val, ind = self.router.bw_value_index(dst, value, index)
# futures.append(self.remote_set(dst, i, val, ind, global_index=True, async_op=True))
# return tuple(futures)
# def async_scatter_bw_add(self, i: int, value: Tensor, index: Optional[Tensor] = None) -> Tuple[Future]:
# futures: List[Future] = []
# for dst in range(self.world_size):
# val, ind = self.router.bw_value_index(dst, value, index)
# futures.append(self.remote_add(dst, i, val, ind, global_index=True, async_op=True))
# return tuple(futures)
# def _idx_data(self, i: int) -> Tuple[int, Tensor]:
# assert -self.num_layers < i and i < self.num_layers
# i = (self.num_layers + i) % self.num_layers
# return i, self.get_buffer(f"data{i}")
@staticmethod
def _remote_call(method, rref: rpc.RRef, *args, **kwargs):
args = (method, rref) + args
return rpc.rpc_async(rref.owner(), TensorBuffer._method_call, args=args, kwargs=kwargs)
@staticmethod
def _method_call(method, rref: rpc.RRef, *args, **kwargs):
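# Runs on the rank that owns `rref`: translate the caller's global node ids in
# kwargs["index"] into this rank's local indices before invoking the method.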
self: TensorBuffer = rref.local_value()
index = kwargs["index"]
kwargs["index"] = self.route.to_local_ids(index)
return method(self, *args, **kwargs)
\ No newline at end of file
import torch
from contextlib import contextmanager
from torch import Tensor
from typing import *
class RouteContext:
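# Collects futures produced by asynchronous route operations and waits for all
# of them when the `with` block exits without an exception.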
def __init__(self) -> None:
self._futs: List[torch.futures.Future] = []
def synchronize(self):
for fut in self._futs:
fut.wait()
self._futs = []
def add_futures(self, *futs):
for fut in futs:
assert isinstance(fut, torch.futures.Future)
self._futs.append(fut)
def __enter__(self):
return self
def __exit__(self, exc_type, exc_value, traceback):
if exc_type is not None:
# let the original exception propagate with its traceback intact
return False
self.synchronize()
import torch
import torch.nn as nn
import torch.distributed as dist
import torch.distributed.rpc as rpc
from torch import Tensor
from typing import *
from starrygl.parallel import all_gather_remote_objects
from .utils import init_local_edge_index
class Route(nn.Module):
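# Maps between global node ids and partition-local indices. `nids_mapper` maps
# global ids to local positions, and `part_mapper` records, for every local src
# node, the rank that owns it (built by gathering each rank's dst_ids over RPC).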
def __init__(self,
src_ids: Tensor,
dst_size: int,
) -> None:
super().__init__()
self.register_buffer("_src_ids", src_ids, persistent=False)
self.dst_size = dst_size
self._init_nids_mapper()
self._init_part_mapper()
@staticmethod
def from_edge_index(dst_ids: Tensor, edge_index: Tensor):
src_ids, local_edge_index = init_local_edge_index(dst_ids, edge_index)
return Route(src_ids, dst_ids.size(0)), local_edge_index
@property
def src_ids(self) -> Tensor:
return self.get_buffer("_src_ids")
@property
def src_size(self) -> int:
return self.src_ids.size(0)
@property
def dst_ids(self) -> Tensor:
return self.src_ids[:self.dst_size]
@property
def ext_ids(self) -> Tensor:
return self.src_ids[self.dst_size:]
@property
def ext_size(self) -> int:
return self.src_size - self.dst_size
def parts_iter(self, local_ids: Tensor) -> Iterator[Tuple[Tensor, Tensor]]:
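# For each rank i, yield the positions in `local_ids` whose nodes are mapped to
# rank i, together with src_ids looked up at those positions (the global ids).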
world_size = dist.get_world_size()
part_mapper = self.part_mapper[local_ids]
for i in range(world_size):
# part_ids = local_ids[part_mapper == i]
part_ids = torch.where(part_mapper == i)[0]
glob_ids = self.src_ids[part_ids]
yield part_ids, glob_ids
def to_local_ids(self, ids: Tensor) -> Tensor:
return self.nids_mapper[ids]
def _init_nids_mapper(self):
num_nodes: int = self.src_ids.max().item() + 1
device: torch.device = self.src_ids.device
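# (2**62-1)*2+1 == 2**63-1, the largest int64, used as an "unmapped" sentinel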
mapper = torch.empty(num_nodes, dtype=torch.long, device=device).fill_((2**62-1)*2+1)
mapper[self.src_ids] = torch.arange(self.src_ids.size(0), dtype=torch.long, device=device)
self.register_buffer("nids_mapper", mapper, persistent=False)
def _init_part_mapper(self):
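# Gather every rank's dst_ids over RPC and record, for each local src node,
# the rank that owns it; every src node must end up assigned to some rank.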
device: torch.device = self.src_ids.device
nids_mapper = self.get_buffer("nids_mapper")
mapper = torch.empty(self.src_size, dtype=torch.int32, device=device).fill_(-1)
for i, dst_ids in enumerate(all_gather_remote_objects(self.dst_ids)):
dst_ids: Tensor = dst_ids.to_here().to(device)
dst_ids = dst_ids[dst_ids < nids_mapper.size(0)]
dst_local_inds = nids_mapper[dst_ids]
dst_local_mask = dst_local_inds != ((2**62-1)*2+1)
dst_local_inds = dst_local_inds[dst_local_mask]
mapper[dst_local_inds] = i
assert (mapper >= 0).all()
self.register_buffer("part_mapper", mapper, persistent=False)
# class RouteTable(nn.Module):
# def __init__(self,
# src_ids: Tensor,
# dst_size: int,
# ) -> None:
# super().__init__()
# self.register_buffer("src_ids", src_ids)
# self.src_size: int = src_ids.size(0)
# self.dst_size = dst_size
# assert self.src_size >= self.dst_size
# self._init_mapper()
# rank, world_size = rank_world_size()
# rrefs = all_gather_remote_objects(self)
# gather_futures: List[torch.futures.Future] = []
# for i in range(world_size):
# rref = rrefs[i]
# fut = rpc.rpc_async(rref.owner(), RouteTable._get_dst_ids, args=(rref,))
# gather_futures.append(fut)
# max_src_ids: int = src_ids.max().item()
# smp = torch.empty(max_src_ids + 1, dtype=torch.long, device=src_ids.device).fill_((2**62-1)*2+1)
# smp[src_ids] = torch.arange(src_ids.size(0), dtype=smp.dtype, device=smp.device)
# self.fw_masker = RouteMasker(self.dst_size, world_size)
# self.bw_masker = RouteMasker(self.src_size, world_size)
# dist.barrier()
# scatter_futures: List[torch.futures.Future] = []
# for i in range(world_size):
# fut = gather_futures[i]
# s_ids: Tensor = src_ids
# d_ids: Tensor = fut.wait()
# num_ids: int = max(s_ids.max().item(), d_ids.max().item()) + 1
# imp = torch.zeros(num_ids, dtype=torch.long, device=self._get_device())
# imp[s_ids] += 1
# imp[d_ids] += 1
# ind = torch.where(imp > 1)[0]
# imp.fill_((2**62-1)*2+1)
# imp[d_ids] = torch.arange(d_ids.size(0), dtype=imp.dtype, device=imp.device)
# s_ind = smp[ind]
# d_ind = imp[ind]
# rref = rrefs[i]
# fut = rpc.rpc_async(rref.owner(), RouteTable._set_fw_mask, args=(rref, rank, d_ind))
# scatter_futures.append(fut)
# bw_mask = torch.zeros(self.src_size, dtype=torch.bool).index_fill_(0, s_ind, 1)
# self.bw_masker.set_mask(i, bw_mask)
# for fut in scatter_futures:
# fut.wait()
# dist.barrier()
# # def fw_index(self, dst: int, index: Tensor) -> Tensor:
# # mask = self.fw_masker.select(dst, index)
# # return self.get_global_index(index[mask])
# # def bw_index(self, dst: int, index: Tensor) -> Tensor:
# # mask = self.bw_masker.select(dst, index)
# # return self.get_global_index(index[mask])
# def fw_value_index(self, dst: int, value: Tensor, index: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]:
# if index is None:
# assert value.size(0) == self.dst_size
# mask = self.fw_masker.select(dst)
# return value[mask], self.get_buffer("src_ids")[:self.dst_size][mask]
# else:
# assert value.size(0) == index.size(0)
# mask = self.fw_masker.select(dst, index)
# value, index = value[mask], index[mask]
# return value, self.get_global_index(index)
# def bw_value_index(self, dst: int, value: Tensor, index: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]:
# if index is None:
# assert value.size(0) == self.src_size
# mask = self.bw_masker.select(dst)
# return value[mask], self.get_buffer("src_ids")[mask]
# else:
# assert value.size(0) == index.size(0)
# mask = self.bw_masker.select(dst, index)
# value, index = value[mask], index[mask]
# return value, self.get_global_index(index)
# def get_global_index(self, index: Tensor) -> Tensor:
# return self.get_buffer("src_ids")[index]
# def get_local_index(self, index: Tensor) -> Tensor:
# return self.get_buffer("mapper")[index]
# @staticmethod
# def _get_dst_ids(rref: rpc.RRef):
# self: RouteTable = rref.local_value()
# src_ids = self.get_buffer("src_ids")
# return src_ids[:self.dst_size]
# @staticmethod
# def _set_fw_mask(rref: rpc.RRef, dst: int, fw_ind: Tensor):
# self: RouteTable = rref.local_value()
# fw_mask = torch.zeros(self.dst_size, dtype=torch.bool).index_fill_(0, fw_ind, 1)
# self.fw_masker.set_mask(dst, fw_mask)
# def _get_device(self):
# return self.get_buffer("src_ids").device
# def _init_mapper(self):
# src_ids = self.get_buffer("src_ids")
# num_nodes: int = src_ids.max().item() + 1
# mapper = torch.empty(num_nodes, dtype=torch.long, device=src_ids.device).fill_((2**62-1)*2+1)
# mapper[src_ids] = torch.arange(src_ids.size(0), dtype=torch.long)
# self.register_buffer("mapper", mapper)
# class RouteMasker(nn.Module):
# def __init__(self,
# num_nodes: int,
# world_size: int,
# ) -> None:
# super().__init__()
# m = (world_size + 7) // 8
# self.num_nodes = num_nodes
# self.world_size = world_size
# self.register_buffer("data", torch.zeros(m, num_nodes, dtype=torch.uint8))
# def forward(self, i: int, index: Optional[Tensor] = None) -> Tensor:
# return self.select(i, index)
# def select(self, i: int, index: Optional[Tensor] = None) -> Tensor:
# i, data = self._idx_data(i)
# k, r = i // 8, i % 8
# if index is None:
# mask = data[k].bitwise_right_shift(r).bitwise_and_(1)
# else:
# mask = data[k][index].bitwise_right_shift_(r).bitwise_and_(1)
# return mask.type(dtype=torch.bool)
# def set_mask(self, i: int, mask: Tensor) -> Tensor:
# assert mask.size(0) == self.num_nodes
# i, data = self._idx_data(i)
# k, r = i // 8, i % 8
# data[k] &= ~(1<<r)
# data[k] |= mask.type(torch.uint8).bitwise_left_shift_(r)
# def _idx_data(self, i: int) -> Tuple[int, Tensor]:
# assert -self.world_size < i and i < self.world_size
# i = (i + self.world_size) % self.world_size
# return i, self.get_buffer("data")
\ No newline at end of file
import torch
from torch import Tensor
from typing import *
from starrygl.parallel import get_compute_device
from starrygl.core.route import Route
from starrygl.utils.partition import metis_partition
def init_local_edge_index(
dst_ids: Tensor,
edge_index: Tensor,
bipartite: bool = False,
) -> Tuple[Tensor, Tensor]:
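# Build the partition-local node set and edge index: returns src_ids (this
# partition's dst nodes followed by their remote neighbours) and the input
# edges renumbered so sources index into src_ids and destinations into dst_ids.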
max_ids = calc_max_ids(dst_ids, edge_index)
ikw = dict(dtype=torch.long, device=dst_ids.device)
xmp = torch.zeros(max_ids + 1, **ikw)
# Check that this is a vertex-cut partition: every edge must be assigned to the partition that owns its destination vertex
xmp[edge_index[1].unique()] += 0b01
xmp[dst_ids.unique()] += 0b10
if not (xmp != 0x01).all():
raise RuntimeError(f"must be vertex-cut partition graph")
if bipartite:
src_ids = edge_index[0].unique()
else:
# Assume a homogeneous graph:
# src_ids is [dst_ids, then the nodes of edge_index[0] not already in dst_ids]
xmp.fill_(0)
xmp[edge_index[0]] = 1
xmp[dst_ids] = 0
src_ids = torch.cat([dst_ids, torch.where(xmp > 0)[0]], dim=-1)
# Renumber both endpoints of every edge into partition-local indices
xmp.fill_((2**62-1)*2+1)
xmp[src_ids] = torch.arange(src_ids.size(0), **ikw)
src = xmp[edge_index[0]]
xmp.fill_((2**62-1)*2+1)
xmp[dst_ids] = torch.arange(dst_ids.size(0), **ikw)
dst = xmp[edge_index[1]]
local_edge_index = torch.vstack([src, dst])
return src_ids, local_edge_index
def calc_max_ids(*ids: Tensor) -> int:
x = [t.max().item() if t.numel() > 0 else 0 for t in ids]
return max(x)
def local_partition_fn(dst_size: int, edge_index: Tensor, num_parts: int) -> Tensor:
edge_index = edge_index[:, edge_index[0] < dst_size]
return metis_partition(edge_index, dst_size, num_parts)[0]
\ No newline at end of file
import torch
import torch.distributed as dist
from torch import Tensor, LongTensor
from typing import *
def _local_TP_FP_FN(pred: LongTensor, targ: LongTensor, num_classes: int) -> Tensor:
TP, FP, FN = 0, 1, 2
tmp = torch.empty(3, num_classes, dtype=torch.float32, device=pred.device)
for c in range(num_classes):
pred_c = (pred == c)
targ_c = (targ == c)
tmp[TP, c] = torch.count_nonzero(pred_c & targ_c)
tmp[FP, c] = torch.count_nonzero(pred_c & ~targ_c)
tmp[FN, c] = torch.count_nonzero(~pred_c & targ_c)
return tmp
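# micro-F1 pools TP/FP/FN over all classes (and all ranks, via all_reduce)
# before computing precision/recall; macro-F1 computes per-class F1 scores
# first and then averages them.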
def micro_f1(pred: LongTensor, targ: LongTensor, num_classes: int) -> float:
tmp = _local_TP_FP_FN(pred, targ, num_classes).sum(dim=-1)
dist.all_reduce(tmp)
TP, FP, FN = tmp.tolist()
precision = TP / (TP + FP)
recall = TP / (TP + FN)
return 2 * precision * recall / (precision + recall)
def macro_f1(pred: LongTensor, targ: LongTensor, num_classes: int) -> float:
tmp = _local_TP_FP_FN(pred, targ, num_classes)
dist.all_reduce(tmp)
TP, FP, FN = tmp
precision = TP / (TP + FP)
recall = TP / (TP + FN)
f1 = 2 * precision * recall / (precision + recall)
return f1.mean().item()
def accuracy(pred: LongTensor, targ: LongTensor) -> float:
tmp = torch.empty(2, dtype=torch.float32, device=pred.device)
tmp[0] = pred.eq(targ).count_nonzero()
tmp[1] = pred.size(0)
dist.all_reduce(tmp)
a, b = tmp.tolist()
return a / b
def all_reduce_loss(loss: Tensor, batch_size: int) -> float:
tmp = torch.tensor([
loss.item() * batch_size,
batch_size
], dtype=torch.float32, device=loss.device)
dist.all_reduce(tmp)
cum_loss, n = tmp.tolist()
return cum_loss / n
import torch
import torch.nn as nn
import torch.distributed as dist
import torch.distributed.rpc as rpc
import os
from typing import *
# from .degree import compute_in_degree, compute_out_degree, compute_gcn_norm
from .sync_bn import SyncBatchNorm
def convert_parallel_model(
net: nn.Module,
find_unused_parameters=False,
) -> nn.parallel.DistributedDataParallel:
net = SyncBatchNorm.convert_sync_batchnorm(net)
net = nn.parallel.DistributedDataParallel(net,
find_unused_parameters=find_unused_parameters,
)
return net
def init_process_group(backend: str = "gloo") -> torch.device:
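# Read RANK / WORLD_SIZE (or their OpenMPI equivalents) from the environment,
# initialize both the collective backend and the RPC framework, and return the
# compute device selected for this rank.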
rank = int(os.getenv("RANK") or os.getenv("OMPI_COMM_WORLD_RANK"))
world_size = int(os.getenv("WORLD_SIZE") or os.getenv("OMPI_COMM_WORLD_SIZE"))
dist.init_process_group(
backend=backend,
init_method=ccl_init_method(),
rank=rank, world_size=world_size,
)
rpc_backend_options = rpc.TensorPipeRpcBackendOptions()
rpc_backend_options.init_method = rpc_init_method()
for i in range(world_size):
rpc_backend_options.set_device_map(f"worker{i}", {rank: i})
rpc.init_rpc(
name=f"worker{rank}",
rank=rank, world_size=world_size,
rpc_backend_options=rpc_backend_options,
)
local_rank = os.getenv("LOCAL_RANK") or os.getenv("OMPI_COMM_WORLD_LOCAL_RANK")
if local_rank is not None:
local_rank = int(local_rank)
if backend == "nccl" or backend == "mpi":
device = torch.device(f"cuda:{local_rank or rank}")
torch.cuda.set_device(device)
else:
device = torch.device("cpu")
global _COMPUTE_DEVICE
_COMPUTE_DEVICE = device
return device
def rank_world_size() -> Tuple[int, int]:
return dist.get_rank(), dist.get_world_size()
def get_worker_info(rank: Optional[int] = None) -> rpc.WorkerInfo:
rank = dist.get_rank() if rank is None else rank
return rpc.get_worker_info(f"worker{rank}")
_COMPUTE_DEVICE = torch.device("cpu")
def get_compute_device() -> torch.device:
global _COMPUTE_DEVICE
return _COMPUTE_DEVICE
_TEMP_AG_REMOTE_OBJECT = None
def _remote_object():
global _TEMP_AG_REMOTE_OBJECT
return _TEMP_AG_REMOTE_OBJECT
def all_gather_remote_objects(obj: Any) -> List[rpc.RRef]:
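# Temporarily stash an RRef to `obj` in a module-level global so that every
# peer can fetch it via RPC; returns one RRef per rank, indexed by rank.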
global _TEMP_AG_REMOTE_OBJECT
_TEMP_AG_REMOTE_OBJECT = rpc.RRef(obj)
dist.barrier()
world_size = dist.get_world_size()
futs: List[torch.futures.Future] = []
for i in range(world_size):
info = get_worker_info(i)
futs.append(rpc.rpc_async(info, _remote_object))
rrefs: List[rpc.RRef] = []
for f in futs:
f.wait()
rrefs.append(f.value())
dist.barrier()
_TEMP_AG_REMOTE_OBJECT = None
return rrefs
def ccl_init_method() -> str:
master_addr = os.environ["MASTER_ADDR"]
master_port = int(os.environ["MASTER_PORT"])
return f"tcp://{master_addr}:{master_port}"
def rpc_init_method() -> str:
master_addr = os.environ["MASTER_ADDR"]
master_port = int(os.environ["MASTER_PORT"])
return f"tcp://{master_addr}:{master_port+1}"
\ No newline at end of file
import torch.distributed as dist
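# sync_print serializes output across ranks (each rank prints in turn between
# barriers); main_print prints only on rank 0.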
def sync_print(*args, **kwargs):
rank = dist.get_rank()
world_size = dist.get_world_size()
for i in range(world_size):
if i == rank:
print(f"rank {rank}:", *args, **kwargs)
dist.barrier()
def main_print(*args, **kwargs):
rank = dist.get_rank()
if rank == 0:
print(*args, **kwargs)
# dist.barrier()
\ No newline at end of file
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from typing import *
import os
import time
import psutil
from starrygl.core import NodeProbe
from starrygl.nn.convs.s_gat_conv import GATConv
from starrygl.graph import DistGraph
from starrygl.parallel import init_process_group, convert_parallel_model
from starrygl.parallel import compute_gcn_norm, SyncBatchNorm, with_nccl
from starrygl.utils import train_epoch, eval_epoch, partition_load, main_print, sync_print
class Net(nn.Module):
def __init__(self,
in_channels: int,
hidden_channels: int,
out_channels: int,
num_layers: int,
heads: int = 8,
) -> None:
super().__init__()
self.in_channels = in_channels
self.hidden_channels = hidden_channels
self.out_channels = out_channels
self.num_layers = num_layers
last_ch = in_channels
self.convs = nn.ModuleList()
for i in range(num_layers):
out_ch = out_channels if i == num_layers - 1 else hidden_channels
self.convs.append(GATConv(last_ch, out_ch, heads))
last_ch = out_ch
def forward(self, g: DistGraph, probe: NodeProbe):
for i in range(self.num_layers):
shrink = g.cache.layers[i].get_shrink_data()
if i == 0:
x = g.cache.fetch_pull_tensor(i)
else:
# x = torch.ones(dst_idx.size(0), self.hidden_channels, device=x.device)
x, _ = g.cache.update_cache(i, x, dst_idx)
dst_idx = shrink.dst_idx
# print(shrink.edge_index[1].unique().size(0), shrink.dst_size)
# print(x.size(0), shrink.edge_index[0].unique().size(0), shrink.src_size)
# e = x[shrink.edge_index[0]]
# from torch_scatter import scatter
from starrygl.utils.printer import main_print
main_print(x.requires_grad)
# print(e.size(0), shrink.edge_index[1].size(0), shrink.dst_size, shrink.edge_index[1].max().item())
# scatter(e, shrink.edge_index[1], dim=0, dim_size=shrink.dst_size)
# return
x = self.convs[i](shrink, x)
x = F.relu(x)
x = probe.apply(i, x, dst_idx)
return x
if __name__ == "__main__":
# Initialize the distributed process group and select the compute device
device = init_process_group(backend="nccl")
# Load the partitioned dataset
pdata = partition_load("./cora", algo="metis").to(device)
g = DistGraph(
ids=pdata.ids,
edge_index=pdata.edge_index,
num_features=64,
num_layers=3,
cache_device="cpu"
).to(device)
cached_data_0, _ = g.route.forward_a2a(pdata.x).get()
g.cache.replace_layer_data(0, cached_data_0.to(g.cache.cache_device))
probe = NodeProbe(
g.dst_size,
num_layers=3,
num_samples=128,
).to(device)
probe.assign_message_cache(g.cache)
probe.layers[-1].warmup_sample()
net = Net(pdata.num_features, 64, pdata.num_classes, num_layers=3).to(device)
net(g, probe).sum().backward()
\ No newline at end of file
import torch
import torch.distributed as dist
from starrygl.parallel import init_process_group
from typing import *
if __name__ == "__main__":
device = init_process_group(backend="gloo")
rank = dist.get_rank()
world_size = dist.get_world_size()
assert world_size == 4
xs = [torch.randn(100+i*10, device=device) for i in range(world_size)]
send_sizes = torch.tensor([x.size(0) for x in xs], dtype=torch.long, device=device)
recv_sizes = torch.zeros_like(send_sizes)
dist.all_to_all_single(recv_sizes, send_sizes)
\ No newline at end of file
# import torch
# from starrygl.core.acopy import get_executor, AsyncCopyWorkBase, List
# if __name__ == "__main__":
# data = torch.arange(1024*1024*1024, dtype=torch.float32).view(-1, 1).repeat(1, 8)
# # buffer = torch.arange(1000, dtype=torch.float32).view(-1, 1).repeat(1, 8).pin_memory()
# index = torch.arange(1024*1024 * 256, dtype=torch.long).cuda()
# values = torch.ones(1024*1024 * 256, 8, dtype=torch.float32).cuda()
# pull_works: List[AsyncCopyWorkBase] = []
# push_works: List[AsyncCopyWorkBase] = []
# for s in range(0, index.size(0), 1024 * 1024 * 256):
# t = min(s + 1024 * 1024 * 256, index.size(0))
# idx = index[s:t]
# val = values[idx]
# pullw = get_executor().async_pull(data, idx)
# pushw = get_executor().async_push(data, idx, val)
# pull_works.append(pullw)
# push_works.append(pushw)
# pull_time = 0.0
# for w in pull_works:
# val = w.get()
# pull_time += w.time_used()
# # print(val)
# push_time = 0.0
# for w in push_works:
# w.get()
# push_time += w.time_used()
# # print(data)
# print(pull_time, push_time)
\ No newline at end of file
import torch
import torch.autograd as autograd
from torch import Tensor
from typing import *
class AFunction(autograd.Function):
@staticmethod
def forward(
ctx: autograd.function.FunctionCtx,
a: Tensor,
b: Tensor,
):
return a, b.float()
@staticmethod
def backward(
ctx: autograd.function.FunctionCtx,
a: Tensor,
b: Tensor,
):
print(a, b)
return a, b.bool()
if __name__ == "__main__":
x = torch.rand(10).requires_grad_()
y = torch.ones(10).bool()
a, b = AFunction.apply(x, y)
(a + b).sum().backward()
print(a, b)
print(x.grad, y.grad)
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
import torch.distributed.rpc as rpc
from torch import Tensor
from typing import *
import os
import time
import psutil
from starrygl.loader import NodeLoader, TensorBuffer, RouteContext
from starrygl.parallel import init_process_group
from starrygl.utils import partition_load, main_print, sync_print
if __name__ == "__main__":
# Initialize the distributed process group and select the compute device
device = init_process_group(backend="nccl")
# Load the partitioned dataset
pdata = partition_load("./cora", algo="metis")
loader = NodeLoader(pdata.ids, pdata.edge_index, device)
hidden_size = 64
buffers: List[TensorBuffer] = [
TensorBuffer(pdata.num_features, loader.src_size, loader.route),
TensorBuffer(hidden_size, loader.src_size, loader.route),
]
buffers[0].data[:loader.dst_size] = pdata.x
buffers[0].broadcast()
for handle in loader.iter(128):
with RouteContext() as ctx:
sync_print(handle.src_size, handle.dst_size)
dst_feats = handle.get_dst_feats(buffers[0]).wait()
ext_fut = handle.get_ext_feats(buffers[1])
sync_print(dst_feats.size())
src_feats, fut = handle.push_and_pull(dst_feats[:,:hidden_size], ext_fut, buffers[1])
ctx.add_futures(fut)
sync_print(src_feats.size())
rpc.shutdown()
\ No newline at end of file
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
import torch.distributed.rpc as rpc
from torch import Tensor
from typing import *
import os
import time
import psutil
# from starrygl.loader import DataLoader, TensorBuffers
from starrygl.loader.route import Route
from starrygl.loader.buffer import TensorBuffer
from starrygl.loader.utils import init_local_edge_index, local_partition_fn
from starrygl.parallel import init_process_group
from starrygl.utils import partition_load, main_print, sync_print
all_nparts = [
[0, 1],
[2, 3, 4],
[5, 6],
]
all_eparts = [
[
[2, 0], [1, 0], [0, 1], [3, 1], [4, 1],
],
[
[0, 2], [3, 2], [5, 2], [1, 3], [2, 3],
[4, 3], [6, 3], [1, 4], [3, 4],
],
[
[2, 5], [6, 5], [5, 6], [3, 6],
],
]
# def get_route_table(device = "cpu") -> Tuple[RouteTable, Tensor]:
# assert dist.get_world_size() == 3
# rank = dist.get_rank()
# ids = torch.tensor(all_nparts[rank], dtype=torch.long, device=device)
# edge_index = torch.tensor(all_eparts[rank], dtype=torch.long, device=device).t()
# src_ids, local_edge_index = init_local_edge_index(ids, edge_index)
# return RouteTable(src_ids, ids.size(0)), local_edge_index
# def get_features(device = "cpu") -> Tensor:
# assert dist.get_world_size() == 3
# rank = dist.get_rank()
# return torch.tensor(all_nparts[rank], dtype=torch.float32, device=device) + 1.0
if __name__ == "__main__":
# Initialize the distributed process group and select the compute device
device = init_process_group(backend="nccl")
# Load the partitioned dataset
pdata = partition_load("./cora", algo="metis").to(device)
route, local_edge_index = Route.from_edge_index(pdata.ids, pdata.edge_index)
# route_table = RouteTable(src_ids, dst_size)
# sync_print(route_table.fw_0)
# route_table = get_route_table()[0]
# for name, buffer in route_table.named_buffers():
# if name.startswith("fw_") or name.startswith("bw_"):
# sync_print(name, buffer.tolist())
tb = TensorBuffer(pdata.num_features, route.src_size, route).cpu()
print(tb.data.device, pdata.x.device)
tb.data[:pdata.num_nodes] = pdata.x
dist.barrier()
tb.broadcast()
print(tb.data[pdata.num_nodes:].sum(dim=-1))
# assert (tb.data[:pdata.num_nodes] == pdata.x).all()
# local_src_ids = torch.arange(route.src_size)
# for local_ids, remote_ids in route.parts_iter(local_src_ids):
# sync_print(local_ids.size(), remote_ids.size())
# router, edge_index = get_route_table()
# x = get_features()[:,None]
# buffer = TensorBuffers(router, channels=[1])
# buffer.local_set(0, x)
# sync_print(buffer.local_get(0).view(-1).tolist())
# buffer.broadcast()
# main_print("="*64)
# sync_print(buffer.local_get(0).view(-1).tolist())
# main_print("+"*64)
# buffer.zero_grad()
# buffer.local_set(0, x)
# sync_print(buffer.local_get(0).view(-1).tolist())
# buffer.broadcast2()
# main_print("="*64)
# sync_print(buffer.local_get(0).view(-1).tolist())
rpc.shutdown()
# buffer.remote_get(0, 0, async_op=True).then()
# num_parts = 50
# node_parts = local_partition_fn(dst_size, local_edge_index, num_parts)
# feat_buffer = TensorBuffers(src_ids.size(0), [pdata.num_features] * 4)
# grad_buffer = TensorBuffers(src_ids.size(0), [pdata.num_features] * 4)
# node_time = torch.rand(dst_size)
# edge_time = torch.rand(local_edge_index.size(1))
# edge_attr = torch.randn(local_edge_index.size(1), 10)
# loader = DataLoader(node_parts, feat_buffer, grad_buffer, local_edge_index, edge_attr, edge_time, node_time)
# feat_buffer.set(0, None, collect_feat0(src_ids, src_ids[:dst_size], pdata.x))
# for handle, edge_index, edge_attr in loader.iter(2, layer_id=0, device=device, filter=lambda x: x < 0.5):
# print(handle.feat0.sum(dim=-1))
# print(edge_index)
# print(edge_attr.sum(dim=-1))
# handle.push_and_pull(handle.feat0[:handle.dst_size], 0)
# assert (handle.fetch_feat() == handle.feat0).all()
\ No newline at end of file
import torch
import torch.distributed as dist
from torch import Tensor
from typing import *
from starrygl.parallel import init_process_group
from starrygl.utils.printer import main_print, sync_print
if __name__ == "__main__":
import os
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
device = init_process_group(backend="mpi")
sync_print(dist.get_backend())
\ No newline at end of file
import torch
import torch.distributed as dist
from torch import Tensor
from typing import *
from starrygl.parallel import init_process_group
from starrygl.utils.printer import main_print, sync_print
from multiprocessing.pool import ThreadPool
num_threads = 1
num_tasks = 5
pool = ThreadPool(num_threads)
if __name__ == "__main__":
import os
# os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
device = init_process_group(backend="nccl")
group = dist.new_group()
stream = torch.cuda.Stream(device)
handles = []
for i in range(num_tasks):
x = torch.ones(10, dtype=torch.float32, device=device)
def run(x):
# stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(stream):
group.allreduce(x)
return x
stream.wait_stream(torch.cuda.current_stream())
handles.append(pool.apply_async(run, (x,)))
x = torch.ones(3, dtype=torch.float32, device=device)
dist.all_reduce(x)
sync_print(x.tolist())
for i in range(num_tasks):
main_print("="*64)
x = handles[i].get()
assert (handles[i].get() == x).all()
sync_print(x.tolist())
import torch
import torch.distributed as dist
from torch import Tensor
from typing import *
from starrygl.graph import DistGraph
from starrygl.core.gather import Gather
from starrygl.core.straight import Straight
from starrygl.parallel import init_process_group, compute_degree
from starrygl.utils.printer import main_print, sync_print
from torch_scatter import scatter_sum
all_nparts = [
[0, 1],
[2, 3, 4],
[5, 6],
]
all_eparts = [
[
[2, 0], [1, 0], [0, 1], [3, 1], [4, 1],
],
[
[0, 2], [3, 2], [5, 2], [1, 3], [2, 3],
[4, 3], [6, 3], [1, 4], [3, 4],
],
[
[2, 5], [6, 5], [5, 6], [3, 6],
],
]
def get_distgraph(device = "cpu") -> DistGraph:
assert dist.get_world_size() == 3
rank = dist.get_rank()
ids = torch.tensor(all_nparts[rank], dtype=torch.long, device=device)
edge_index = torch.tensor(all_eparts[rank], dtype=torch.long, device=device).t()
return DistGraph(ids, edge_index)
def get_features(device = "cpu") -> Tensor:
assert dist.get_world_size() == 3
rank = dist.get_rank()
return torch.tensor(all_nparts[rank], dtype=torch.float32, device=device) + 1.0
def reduce(g: DistGraph, x: Tensor) -> Tensor:
edge_index = g.edge_index
x = x[edge_index[0]]
x = scatter_sum(x, edge_index[1], dim=0, dim_size=g.dst_size)
return x
if __name__ == "__main__":
import os
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
device = init_process_group(backend="nccl")
g = get_distgraph(device)
# sync_print(g.src_ids.tolist())
# main_print("="*64)
# sync_print(g.dst_ids.tolist())
# main_print("="*64)
# sync_print(g.route.forward_routes[0])
# main_print("="*64)
in_degree, out_degree = compute_degree(g)
# sync_print(in_degree.tolist())
# sync_print(out_degree.tolist())
x = get_features(device).requires_grad_()
# straight = Straight(g.dst_size, num_samples=1)
# gather = Gather(g.src_size).to(device)
# sync_print(g)
# sync_print(gather.get_buffer("last_embd").tolist())
# main_print("="*64)
sync_print(x.tolist(), g.route.src_size)
main_print("="*64)
a = torch.arange(g.route.src_size, dtype=torch.long, device=device)
# a = None
# z, a = gather(x, None, route=g.route)
z, a = g.route.apply(x, a, {}, async_op=False)
sync_print(z.tolist(), "" if a is None else a.tolist(), g.route.dst_size)
main_print("="*64)
sync_print(g.edge_index[0])
main_print("="*64)
z = reduce(g, z)
sync_print(z.tolist())
main_print("="*64)
# z = straight(x, a, g)
# sync_print(z.tolist())
# main_print("="*64)
loss = z.sum()
sync_print(loss.item())
main_print("="*64)
loss.backward()
sync_print(x.grad.tolist())
main_print("="*64)
sync_print(f"time_used: {g.route.total_time_used}ms")
\ No newline at end of file
import torch
import torch.distributed as dist
from torch import Tensor
from typing import *
from starrygl.core.route import Route
from starrygl.graph.utils import init_local_edge_index
from starrygl.parallel import init_process_group
from starrygl.utils.printer import main_print, sync_print
from torch_scatter import scatter_sum
all_nparts = [
[0, 1],
[2, 3, 4],
[5, 6],
]
all_eparts = [
[
[2, 0], [1, 0], [0, 1], [3, 1], [4, 1],
],
[
[0, 2], [3, 2], [5, 2], [1, 3], [2, 3],
[4, 3], [6, 3], [1, 4], [3, 4],
],
[
[2, 5], [6, 5], [5, 6], [3, 6],
],
]
def get_route(device = "cpu") -> Tuple[Route, Tensor]:
assert dist.get_world_size() == 3
rank = dist.get_rank()
ids = torch.tensor(all_nparts[rank], dtype=torch.long, device=device)
edge_index = torch.tensor(all_eparts[rank], dtype=torch.long, device=device).t()
src_ids, local_edge_index = init_local_edge_index(ids, edge_index)
return Route(ids, src_ids), local_edge_index
def get_features(device = "cpu") -> Tensor:
assert dist.get_world_size() == 3
rank = dist.get_rank()
return torch.tensor(all_nparts[rank], dtype=torch.float32, device=device) + 1.0
def reduce(x: Tensor, edge_index: Tensor, dim_size: int) -> Tensor:
x = x[edge_index[0]]
x = scatter_sum(x, edge_index[1], dim=0, dim_size=dim_size)
return x
if __name__ == "__main__":
device = init_process_group(backend="mpi")
route, edge_index = get_route(device)
src_size = route.dst_size
dst_size = route.src_size
# in_degree, out_degree = compute_degree(g)
# sync_print(in_degree.tolist())
# sync_print(out_degree.tolist())
x = get_features(device).requires_grad_()
sync_print(x.tolist())
main_print("="*64)
z, _ = route.apply(x, None)
sync_print(z.tolist())
main_print("="*64)
h = reduce(z, edge_index, dst_size)
sync_print(h.tolist())
main_print("="*64)
h.sum().backward()
sync_print(x.grad.tolist())
main_print("="*64)
import torch
import torch.distributed as dist
import torch.distributed.rpc as rpc
import os
from torch import Tensor
from typing import *
from starrygl.parallel import init_process_group, all_gather_remote_objects
from starrygl.utils.printer import main_print, sync_print
class A:
def __init__(self, device) -> None:
self.id = rpc.get_worker_info().id
self.data = torch.tensor([self.id], device=device)
@staticmethod
def get_data(rref: rpc.RRef) -> Tensor:
self: A = rref.local_value()
return self.data
def set_remote_a(self, *rrefs: rpc.RRef):
self.rrefs = tuple(rrefs)
def display(self):
sync_print(f"id = {self.id}")
def gather_data(self, device) -> List[Tensor]:
world_size = dist.get_world_size()
futs: List[torch.futures.Future] = []
for i in range(world_size):
f = rpc.rpc_async(self.rrefs[i].owner(), A.get_data, args=(self.rrefs[i],))
futs.append(f)
results: List[Tensor] = []
for f in futs:
f.wait()
results.append(f.value())
return results
if __name__ == "__main__":
device = init_process_group("nccl")
a = A(device)
a.set_remote_a(*all_gather_remote_objects(a))
a.display()
for t in a.gather_data(device):
sync_print(t, t.device)
rpc.shutdown()
\ No newline at end of file
import torch
import torch.nn as nn
import torch.distributed as dist
from torch import Tensor
from typing import *
from starrygl.core.gather import Gather
from starrygl.parallel import init_process_group
from starrygl.parallel.sync_bn import SyncBatchNorm
from starrygl.graph import DistGraph, EID, SID, DID
from starrygl.utils.printer import main_print, sync_print
if __name__ == "__main__":
device = init_process_group(backend="nccl")
bn = SyncBatchNorm(100).to(device)
bn.train()
x = torch.randn(10, 100, dtype=torch.float64).to(device).requires_grad_()
sync_print(x)
from torch.autograd import gradcheck
sync_print(gradcheck(lambda x: bn(x), x))
# sync_print(bn(x))
-r requirements.txt
# used by docs
sphinx-autobuild
sphinx_rtd_theme
matplotlib
ipykernel
\ No newline at end of file
--extra-index-url https://download.pytorch.org/whl/cu118
torch==2.1.1+cu118
torchvision==0.16.1+cu118
torchaudio==2.1.1+cu118
--extra-index-url https://data.pyg.org/whl/torch-2.1.0+cu118.html
torch_geometric==2.4.0
pyg_lib==0.3.1+pt21cu118
torch_scatter==2.1.2+pt21cu118
torch_sparse==0.6.18+pt21cu118
torch_cluster==1.6.3+pt21cu118
torch_spline_conv==1.2.2+pt21cu118
ogb
tqdm
\ No newline at end of file
......@@ -2,14 +2,17 @@ import torch
import logging
__version__ = "0.1.0"
try:
from .lib import libstarrygl_ops as ops
from .lib import libstarrygl as ops
except Exception as e:
logging.error(e)
logging.error("unable to import libstarrygl.so, some features may not be available.")
try:
from .lib import libstarrygl_ops_sampler as sampler_ops
from .lib import libstarrygl_sampler as sampler_ops
except Exception as e:
logging.error(e)
logging.error("unable to import libstarrygl_sampler.so, some features may not be available.")
\ No newline at end of file
import torch
from torch import Tensor
from contextlib import contextmanager
from typing import *
__all__ = [
"ABCStream",
"ABCEvent",
"phony_tensor",
"new_stream",
"current_stream",
"default_stream",
"use_stream",
"use_device",
"wait_stream",
"wait_event",
"record_stream",
]
class CPUStreamType:
def __init__(self) -> None:
self._device = torch.device("cpu")
@property
def device(self):
return self._device
def __call__(self):
return self
class CPUEventType:
def __init__(self) -> None:
self._device = torch.device("cpu")
@property
def device(self):
return self._device
def __call__(self):
return self
CPUStream = CPUStreamType()
ABCStream = Union[torch.cuda.Stream, CPUStreamType]
CPUEvent = CPUEventType()
ABCEvent = Union[torch.cuda.Event, CPUEventType]
def new_stream(device: Any) -> ABCStream:
device = torch.device(device)
if device.type != "cuda":
return CPUStream()
return torch.cuda.Stream(device)
_phonies: Dict[Tuple[torch.device, bool], Tensor] = {}
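# A "phony" is a cached zero-element tensor used only to thread autograd
# dependencies between operations (for example across streams) without
# moving any data.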
def phony_tensor(device: Any, requires_grad: bool = True):
device = torch.device(device)
key = (device, requires_grad)
if key not in _phonies:
with use_stream(default_stream(device)):
_phonies[key] = torch.empty(
0, device=device,
requires_grad=requires_grad,
)
return _phonies[key]
def current_stream(device: Any) -> ABCStream:
device = torch.device(device)
if device.type != "cuda":
return CPUStream()
return torch.cuda.current_stream(device)
def default_stream(device: Any) -> ABCStream:
device = torch.device(device)
if device.type != "cuda":
return CPUStream()
return torch.cuda.default_stream(device)
@contextmanager
def use_stream(stream: ABCStream, fence_event: bool = False):
if isinstance(stream, CPUStreamType):
if fence_event:
event = CPUEvent()
yield event
else:
yield
return
with torch.cuda.stream(stream):
if fence_event:
event = torch.cuda.Event()
yield event
event.record()
else:
yield
@contextmanager
def use_device(device: Any):
device = torch.device(device)
if device.type != "cuda":
yield
return
with torch.cuda.device(device):
yield
def wait_stream(source: ABCStream, target: ABCStream):
if isinstance(target, CPUStreamType):
return
if isinstance(source, CPUStreamType):
target.synchronize()
else:
source.wait_stream(target)
def wait_event(source: ABCStream, target: ABCEvent):
if isinstance(target, CPUEventType):
return
if isinstance(source, CPUStreamType):
target.synchronize()
else:
source.wait_event(target)
def record_stream(tensor: Tensor, stream: ABCStream):
if isinstance(stream, CPUStreamType):
return
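# Re-wrap the full storage in a fresh tensor so the stream is recorded on the
# whole allocation, not only on the view that was passed in.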
storage = tensor.untyped_storage()
tensor = tensor.new_empty(0).set_(storage)
tensor.record_stream(stream)
......@@ -58,7 +58,6 @@ class DistributedContext:
master_port = int(os.environ["MASTER_PORT"])
ccl_init_url = f"tcp://{master_addr}:{master_port}"
rpc_init_url = f"tcp://{master_addr}:{master_port + 1}"
ctx = DistributedContext(
backend=backend,
ccl_init_method=ccl_init_url,
......@@ -69,6 +68,7 @@ class DistributedContext:
use_gpu=use_gpu,
rpc_gpu=rpc_gpu,
)
_set_default_dist_context(ctx)
return ctx
......
......@@ -16,7 +16,10 @@ class TensorAccessor:
self._data = data
self._ctx = DistributedContext.get_default_context()
if self._ctx._use_rpc is True:
self._rref = rpc.RRef(data)
else:
self._rref = None
self.stream = torch.cuda.Stream()
@property
......@@ -141,6 +144,7 @@ class DistIndex:
class DistributedTensor:
def __init__(self, data: Tensor) -> None:
self.accessor = TensorAccessor(data)
if self.accessor.rref is not None:
self.rrefs = self.accessor.all_gather_rrefs()
local_sizes = []
......@@ -149,11 +153,21 @@ class DistributedTensor:
local_sizes.append(n)
self._num_nodes: int = sum(local_sizes)
self._num_part_nodes: Tuple[int,...] = tuple(int(s) for s in local_sizes)
else:
self.rrefs = None
self._num_nodes: int = dist.get_world_size()
self._num_part_nodes:List = [torch.tensor(data.size(0),device = data.device) for _ in range(self._num_nodes)]
dist.all_gather(self._num_part_nodes,torch.tensor(data.size(0),device = data.device))
self._num_nodes = sum(self._num_part_nodes)
self._part_id: int = self.accessor.ctx.rank
self._num_parts: int = self.accessor.ctx.world_size
@property
def shape(self):
return self.accessor.data.shape
@property
def dtype(self):
return self.accessor.data.dtype
......@@ -166,7 +180,7 @@ class DistributedTensor:
return self._num_nodes
@property
def num_part_nodes(self) -> tuple[int,...]:
def num_part_nodes(self):# -> tuple[int,...]:
return self._num_part_nodes
@property
......@@ -190,7 +204,7 @@ class DistributedTensor:
def all_to_all_ind2ptr(self, dist_index: Union[Tensor, DistIndex],group = None) -> Dict[str, Union[List[int], Tensor]]:
if isinstance(dist_index, Tensor):
dist_index = DistIndex(dist_index)
send_ptr = torch.ops.torch_sparse.ind2ptr(dist_index.part, self.num_parts())
send_ptr = torch.ops.torch_sparse.ind2ptr(dist_index.part, self.num_parts)
send_sizes = send_ptr[1:] - send_ptr[:-1]
recv_sizes = torch.empty_like(send_sizes)
......@@ -202,7 +216,7 @@ class DistributedTensor:
send_ptr = send_ptr.tolist()
recv_ptr = recv_ptr.tolist()
recv_ind = torch.full((recv_ptr[-1],), (2**62-1)*2+1, dtype=dist_index.dtype, device=dist_index.device)
recv_ind = torch.full((recv_ptr[-1],), (2**62-1)*2+1, dtype=dist_index.dtype, device=self.device)
all_to_all_s(recv_ind, dist_index.loc, recv_ptr, send_ptr,group=group)
return {
......@@ -225,7 +239,7 @@ class DistributedTensor:
recv_ind = dist_dict["recv_ind"]
data = self.accessor.data[recv_ind]
recv = torch.empty(send_ptr[-1], *data.shape[1:], dtype=data.dtype, device=data.device)
recv = torch.empty(send_ptr[-1], *data.shape[1:], dtype=data.dtype, device=self.device)
all_to_all_s(recv, data, send_ptr, recv_ptr,group=group)
return recv
......
......@@ -32,8 +32,8 @@ class EdgePredictor(torch.nn.Module):
def forward(self, h, neg_samples=1):
num_edge = h.shape[0] // (neg_samples + 2)
h_src = self.src_fc(h[num_edge:2 * num_edge])#self.src_fc(h[:num_edge])
h_pos_dst = self.dst_fc(h[:num_edge]) #
h_src = self.src_fc(h[:num_edge])
h_pos_dst = self.dst_fc(h[num_edge:num_edge*2]) #
h_neg_src = self.src_fc(h[2 * num_edge:])
h_pos_edge = torch.nn.functional.relu(h_src + h_pos_dst)
h_neg_edge = torch.nn.functional.relu(h_neg_src + h_pos_dst.tile(neg_samples, 1))
......
......@@ -21,7 +21,7 @@ class GeneralModel(torch.nn.Module):
self.train_param = train_param
if memory_param['type'] == 'node':
if memory_param['memory_update'] == 'gru':
self.memory_updater = RNNMemeoryUpdater(memory_param, 2 * memory_param['dim_out'] + dim_edge, memory_param['dim_out'], memory_param['dim_time'], dim_node)
self.memory_updater = GRUMemeoryUpdater(memory_param, 2 * memory_param['dim_out'] + dim_edge, memory_param['dim_out'], memory_param['dim_time'], dim_node)
elif memory_param['memory_update'] == 'rnn':
self.memory_updater = RNNMemeoryUpdater(memory_param, 2 * memory_param['dim_out'] + dim_edge, memory_param['dim_out'], memory_param['dim_time'], dim_node)
elif memory_param['memory_update'] == 'transformer':
......
from .convs import GCNConv, GATConv, GINConv
# from .convs import ShrinkGCNConv, ShrinkGATConv, ShrinkGINConv
# from .convs import ShrinkHelper
# from .basic_gnn import BasicGNN, BasicLayerOptions, BasicInputOptions, BasicStraightOptions
# class ShrinkGCN(BasicGNN):
# def init_conv(self, in_channels: int, out_channels: int, **kwargs):
# return ShrinkGCNConv(in_channels, out_channels, **kwargs)
# class ShrinkGAT(BasicGNN):
# def init_conv(self, in_channels: int, out_channels: int, **kwargs):
# return ShrinkGATConv(in_channels, out_channels, **kwargs)
# class ShrinkGIN(BasicGNN):
# def init_conv(self, in_channels: int, out_channels: int, **kwargs):
# return ShrinkGINConv(in_channels, out_channels, **kwargs)
\ No newline at end of file
# import torch
# import torch.nn as nn
# import torch.distributed as dist
# from torch import Tensor
# from typing import *
# from starrygl.loader import BatchHandle
# class BaseLayer(nn.Module):
# def __init__(self) -> None:
# super().__init__()
# def forward(self, x: Tensor, edge_index: Tensor, edge_attr: Optional[Tensor] = None) -> Tensor:
# return x
# def update_forward(self, handle: BatchHandle, edge_index: Tensor, edge_attr: Optional[Tensor] = None):
# x = handle.fetch_feat()
# with torch.no_grad():
# x = self.forward(x, edge_index, edge_attr)
# handle.update_feat(x)
# def block_backward(self, handle: BatchHandle, edge_index: Tensor, edge_attr: Optional[Tensor] = None):
# x = handle.fetch_feat().requires_grad_()
# g = handle.fetch_grad()
# self.forward(x, edge_index, edge_attr).backward(g)
# handle.accumulate_grad(x.grad)
# x.grad = None
# def all_reduce_grad(self):
# for p in self.parameters():
# if p.grad is not None:
# dist.all_reduce(p.grad, op=dist.ReduceOp.SUM)
# class BaseModel(nn.Module):
# def __init__(self,
# num_features: int,
# layers: List[int],
# prev_layer: bool = False,
# post_layer: bool = False,
# ) -> None:
# super().__init__()
# def init_prev_layer(self) -> Tensor:
# pass
# def init_post_layer(self) -> Tensor:
# pass
# def init_conv_layer(self) -> Tensor:
# pass
\ No newline at end of file
from .gcn_conv import GCNConv
from .gat_conv import GATConv
from .gin_conv import GINConv
# from .shrink_gcn_conv import ShrinkGCNConv
# from .shrink_gat_conv import ShrinkGATConv
# from .shrink_gin_conv import ShrinkGINConv
# from .utils import ShrinkHelper
\ No newline at end of file
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.utils import softmax
from torch_scatter import scatter_sum
from torch import Tensor
from typing import *
from starrygl.graph import DistGraph
class GATConv(nn.Module):
def __init__(self,
in_channels: int,
out_channels: int,
heads: int = 1,
concat: bool = False,
negative_slope: float = 0.2,
dropout: float = 0.0,
edge_dim: Optional[int] = None,
bias: bool = True,
**kwargs
) -> None:
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.heads = heads
self.concat = concat
self.negative_slope = negative_slope
self.dropout = dropout
self.edge_dim = edge_dim
self.weight = nn.Parameter(torch.Tensor(in_channels, heads * out_channels))
self.att_src = nn.Parameter(torch.Tensor(1, heads, out_channels))
self.att_dst = nn.Parameter(torch.Tensor(1, heads, out_channels))
if edge_dim is not None:
self.lin_edge = nn.Parameter(torch.Tensor(edge_dim, heads * out_channels))
self.att_edge = nn.Parameter(torch.Tensor(1, heads, out_channels))
if bias and concat:
self.bias = nn.Parameter(torch.Tensor(heads * out_channels))
elif bias and not concat:
self.bias = nn.Parameter(torch.Tensor(out_channels))
else:
self.bias = None
self.reset_parameters()
def reset_parameters(self):
nn.init.xavier_normal_(self.weight)
nn.init.xavier_normal_(self.att_src)
nn.init.xavier_normal_(self.att_dst)
if self.edge_dim is not None:
nn.init.xavier_normal_(self.lin_edge)
nn.init.xavier_normal_(self.att_edge)
if self.bias is not None:
nn.init.zeros_(self.bias)
def forward(self, g: DistGraph) -> Tensor:
H, C = self.heads, self.out_channels
x = g.ndata["x"]
edge_index = g.edge_index
x = (x @ self.weight).view(-1, H, C)
alpha_j = (x * self.att_src).sum(dim=-1)
alpha_j = alpha_j[edge_index[0]]
alpha_i = (x * self.att_dst).sum(dim=-1)
alpha_i = alpha_i[edge_index[1]]
if self.edge_dim is not None:
edge_attr = g.edata["edge_attr"]
if edge_attr.dim() == 1:
edge_attr = edge_attr.view(-1, 1)
e = (edge_attr @ self.lin_edge).view(-1, H, C)
alpha_e = (e * self.att_edge).sum(dim=-1)
alpha = alpha_i + alpha_j + alpha_e
else:
alpha = alpha_i + alpha_j
alpha = F.leaky_relu(alpha, self.negative_slope)
alpha = softmax(
src=alpha,
index=edge_index[1],
num_nodes=g.dst_size,
)
alpha = F.dropout(alpha, p=self.dropout, training=self.training)
x = x[edge_index[0]] * alpha.view(-1, H, 1)
x = scatter_sum(x, edge_index[1], dim=0, dim_size=g.dst_size)
if self.concat:
x = x.view(-1, H * C)
else:
x = x.mean(dim=1)
if self.bias is not None:
x += self.bias
return x
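# Hedged smoke test for the GATConv above (illustration only): the layer only
# touches g.ndata["x"], g.edata, g.edge_index and g.dst_size, so a simple
# namespace stands in for DistGraph here; shapes and values are made up.
from types import SimpleNamespace
import torch
num_src, num_dst, num_edges = 10, 4, 20
g_stub = SimpleNamespace(
    ndata={"x": torch.randn(num_src, 16)},
    edata={},
    edge_index=torch.stack([
        torch.randint(0, num_src, (num_edges,)),
        torch.randint(0, num_dst, (num_edges,)),
    ]),
    dst_size=num_dst,
)
conv = GATConv(in_channels=16, out_channels=8, heads=2)
out = conv(g_stub)  # expected shape: (num_dst, 8); concat=False averages the heads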
import torch
import torch.nn as nn
from torch_scatter import scatter_sum
from torch import Tensor
from typing import *
from starrygl.graph import DistGraph
class GCNConv(nn.Module):
def __init__(self,
in_channels: int,
out_channels: int,
bias: bool = True,
**kwargs
) -> None:
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.weight = nn.Parameter(torch.Tensor(in_channels, out_channels))
if bias:
self.bias = nn.Parameter(torch.Tensor(out_channels))
else:
self.bias = None
self.reset_parameters()
def reset_parameters(self):
nn.init.xavier_normal_(self.weight)
if self.bias is not None:
nn.init.zeros_(self.bias)
def forward(self, g: DistGraph) -> Tensor:
x = g.ndata["x"]
gcn_norm = g.edata["gcn_norm"].view(-1, 1)
edge_index = g.edge_index
x = x @ self.weight
x = x[edge_index[0]] * gcn_norm
x = scatter_sum(x, edge_index[1], dim=0, dim_size=g.dst_size)
if self.bias is not None:
x += self.bias
return x
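# Hedged helper sketch (illustration only): GCNConv above reads a per-edge
# normalisation from g.edata["gcn_norm"]. A common choice is the symmetric
# 1/sqrt(deg_dst * deg_src) weighting, computed here from edge_index alone;
# how the repo actually fills "gcn_norm" is not shown in this diff.
import torch
from torch_scatter import scatter_sum

def compute_gcn_norm(edge_index, num_src, num_dst):
    ones = torch.ones(edge_index.shape[1])
    deg_src = scatter_sum(ones, edge_index[0], dim=0, dim_size=num_src).clamp(min=1)
    deg_dst = scatter_sum(ones, edge_index[1], dim=0, dim_size=num_dst).clamp(min=1)
    # one weight per edge: 1 / sqrt(deg(src) * deg(dst))
    return (deg_src[edge_index[0]] * deg_dst[edge_index[1]]).rsqrt()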
import torch
import torch.nn as nn
from torch_scatter import scatter_sum
from torch import Tensor
from typing import *
from starrygl.graph import DistGraph
class GINConv(nn.Module):
def __init__(self,
in_channels: int,
out_channels: int,
mlp_channels: Optional[int] = None,
eps: float = 0,
train_eps: bool = False,
**kwargs
) -> None:
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
if mlp_channels is None:
mlp_channels = in_channels + out_channels
self.mlp_channels = mlp_channels
self.initial_eps = eps
self.nn = nn.Sequential(
nn.Linear(in_channels, mlp_channels),
nn.ReLU(),
nn.Linear(mlp_channels, out_channels),
)
if train_eps:
self.eps = torch.nn.Parameter(torch.tensor([eps]))
else:
self.register_buffer("eps", torch.tensor([eps]))
self.reset_parameters()
def reset_parameters(self):
for name, param in self.nn.named_parameters():
if "weight" in name:
nn.init.xavier_normal_(param)
if "bias" in name:
nn.init.zeros_(param)
nn.init.constant_(self.eps, self.initial_eps)
def forward(self, g: DistGraph) -> Tensor:
x = g.ndata["x"]
edge_index = g.edge_index
z = x[edge_index[0]]
z = scatter_sum(z, edge_index[1], dim=0, dim_size=g.dst_size)
x = z + (1 + self.eps) * x[:g.dst_size]
return self.nn(x)
\ No newline at end of file
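# Hedged note (illustration only): the GINConv above implements
#     out = MLP( (1 + eps) * x[:dst_size] + sum_{j in N(i)} x_j )
# assuming the first dst_size rows of g.ndata["x"] are the destination nodes.
# A quick shape check with a stand-in graph object (values made up):
from types import SimpleNamespace
import torch
g_stub = SimpleNamespace(
    ndata={"x": torch.randn(10, 16)},
    edge_index=torch.stack([torch.randint(0, 10, (30,)), torch.randint(0, 4, (30,))]),
    dst_size=4,
)
gin = GINConv(in_channels=16, out_channels=8)
out = gin(g_stub)  # expected shape: (4, 8)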
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.utils import softmax
from torch_scatter import scatter_sum
from torch import Tensor
from typing import *
from starrygl.graph import DistGraph
class GATConv(nn.Module):
def __init__(self,
in_channels: int,
out_channels: int,
heads: int = 1,
concat: bool = False,
negative_slope: float = 0.2,
dropout: float = 0.0,
edge_dim: Optional[int] = None,
bias: bool = True,
**kwargs
) -> None:
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.heads = heads
self.concat = concat
self.negative_slope = negative_slope
self.dropout = dropout
self.edge_dim = edge_dim
self.weight = nn.Parameter(torch.Tensor(in_channels, heads * out_channels))
self.att_src = nn.Parameter(torch.Tensor(1, heads, out_channels))
self.att_dst = nn.Parameter(torch.Tensor(1, heads, out_channels))
if edge_dim is not None:
self.lin_edge = nn.Parameter(torch.Tensor(edge_dim, heads * out_channels))
self.att_edge = nn.Parameter(torch.Tensor(1, heads, out_channels))
if bias and concat:
self.bias = nn.Parameter(torch.Tensor(heads * out_channels))
elif bias and not concat:
self.bias = nn.Parameter(torch.Tensor(out_channels))
else:
self.bias = None
self.reset_parameters()
def reset_parameters(self):
nn.init.xavier_normal_(self.weight)
nn.init.xavier_normal_(self.att_src)
nn.init.xavier_normal_(self.att_dst)
if self.edge_dim is not None:
nn.init.xavier_normal_(self.lin_edge)
nn.init.xavier_normal_(self.att_edge)
if self.bias is not None:
nn.init.zeros_(self.bias)
def forward(self, g, x: Tensor, edge_attr: Optional[Tensor] = None):
H, C = self.heads, self.out_channels
edge_index = g.edge_index
x = (x @ self.weight).view(-1, H, C)
alpha_j = (x * self.att_src).sum(dim=-1)
alpha_j = alpha_j[edge_index[0]]
alpha_i = (x * self.att_dst).sum(dim=-1)
alpha_i = alpha_i[edge_index[1]]
if self.edge_dim is not None:
if edge_attr.dim() == 1:
edge_attr = edge_attr.view(-1, 1)
e = (edge_attr @ self.lin_edge).view(-1, H, C)
alpha_e = (e * self.att_edge).sum(dim=-1)
alpha = alpha_i + alpha_j + alpha_e
else:
alpha = alpha_i + alpha_j
alpha = F.leaky_relu(alpha, self.negative_slope)
alpha = softmax(
src=alpha,
index=edge_index[1],
num_nodes=g.dst_size,
)
alpha = F.dropout(alpha, p=self.dropout, training=self.training)
x = x[edge_index[0]] * alpha.view(-1, H, 1)
x = scatter_sum(x, edge_index[1], dim=0, dim_size=g.dst_size)
if self.concat:
x = x.view(-1, H * C)
else:
x = x.mean(dim=1)
if self.bias is not None:
x += self.bias
return x
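# Hedged note (illustration only): unlike the GATConv earlier in this diff, this
# variant takes the node features `x` and optional `edge_attr` as explicit
# arguments instead of reading them from g.ndata / g.edata, so only g.edge_index
# and g.dst_size are needed from the graph object, e.g.
#   out = conv(g, x)                    # no edge features
#   out = conv(g, x, edge_attr=feats)   # when edge_dim was set at construction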
ldg_partition @ da31e8bb