Commit 10c38111 by Wenjie Huang

add GraphData & metis & mt-metis

parent b84f3d4c
......@@ -159,6 +159,7 @@ cython_debug/
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
/cora
/dataset
/test_*
/*.ipynb
......
......@@ -3,6 +3,6 @@
"cmake.configureSettings": {
"CMAKE_PREFIX_PATH": "/home/hwj/.miniconda3/envs/sgl/lib/python3.10/site-packages",
"Python3_ROOT_DIR": "/home/hwj/.miniconda3/envs/sgl",
"CUDA_TOOLKIT_ROOT_DIR": "/home/hwj/.local/cuda-11.7"
"CUDA_TOOLKIT_ROOT_DIR": "/home/hwj/.local/cuda-11.8"
},
}
\ No newline at end of file
......@@ -14,6 +14,13 @@ set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
set(CMAKE_CUDA_STANDARD 14)
set(CMAKE_CUDA_STANDARD_REQUIRED ON)
find_package(OpenMP REQUIRED)
link_libraries(OpenMP::OpenMP_CXX)
find_package(Torch REQUIRED)
include_directories(${TORCH_INCLUDE_DIRS})
add_compile_options(${TORCH_CXX_FLAGS})
if(WITH_PYTHON)
add_definitions(-DWITH_PYTHON)
find_package(Python3 COMPONENTS Interpreter Development REQUIRED)
......@@ -22,9 +29,16 @@ endif()
if(WITH_CUDA)
add_definitions(-DWITH_CUDA)
add_definitions(-DWITH_UVM)
find_package(CUDA REQUIRED)
include_directories(${CUDA_INCLUDE_DIRS})
set(CUDA_LIBRARIES "${CUDA_TOOLKIT_ROOT_DIR}/lib64/libcudart.so")
file(GLOB_RECURSE UVM_SRCS "csrc/uvm/*.cpp")
add_library(uvm SHARED ${UVM_SRCS})
target_link_libraries(uvm PRIVATE ${TORCH_LIBRARIES})
endif()
if(WITH_METIS)
......@@ -33,13 +47,17 @@ if(WITH_METIS)
set(METIS_DIR "${CMAKE_SOURCE_DIR}/third_party/METIS")
set(GKLIB_INCLUDE_DIRS "${GKLIB_DIR}/include")
set(GKLIB_LIBRARIES "${GKLIB_DIR}/lib")
file(GLOB_RECURSE GKLIB_LIBRARIES "${GKLIB_DIR}/lib/lib*.a")
set(METIS_INCLUDE_DIRS "${METIS_DIR}/include")
set(METIS_LIBRARIES "${METIS_DIR}/lib")
file(GLOB_RECURSE METIS_LIBRARIES "${METIS_DIR}/lib/lib*.a")
include_directories(${METIS_INCLUDE_DIRS})
link_libraries(${METIS_LIBRARIES})
add_library(metis SHARED "csrc/partition/metis.cpp")
target_link_libraries(metis PRIVATE ${TORCH_LIBRARIES})
target_link_libraries(metis PRIVATE ${GKLIB_LIBRARIES})
target_link_libraries(metis PRIVATE ${METIS_LIBRARIES})
endif()
if(WITH_MTMETIS)
......@@ -47,33 +65,38 @@ if(WITH_MTMETIS)
set(MTMETIS_DIR "${CMAKE_SOURCE_DIR}/third_party/mt-metis")
set(MTMETIS_INCLUDE_DIRS "${MTMETIS_DIR}/include")
set(MTMETIS_LIBRARIES "${MTMETIS_DIR}/lib")
file(GLOB_RECURSE MTMETIS_LIBRARIES "${MTMETIS_DIR}/lib/lib*.a")
include_directories(${MTMETIS_INCLUDE_DIRS})
link_libraries(${MTMETIS_LIBRARIES})
add_library(mtmetis SHARED "csrc/partition/mtmetis.cpp")
target_link_libraries(mtmetis PRIVATE ${TORCH_LIBRARIES})
target_link_libraries(mtmetis PRIVATE ${MTMETIS_LIBRARIES})
target_compile_definitions(mtmetis PRIVATE -DMTMETIS_64BIT_VERTICES)
target_compile_definitions(mtmetis PRIVATE -DMTMETIS_64BIT_EDGES)
target_compile_definitions(mtmetis PRIVATE -DMTMETIS_64BIT_WEIGHTS)
target_compile_definitions(mtmetis PRIVATE -DMTMETIS_64BIT_PARTITIONS)
endif()
find_package(OpenMP REQUIRED)
link_libraries(OpenMP::OpenMP_CXX)
find_package(Torch REQUIRED)
include_directories(${TORCH_INCLUDE_DIRS})
add_compile_options(${TORCH_CXX_FLAGS})
include_directories("csrc/include")
add_library(${PROJECT_NAME} SHARED csrc/export.cpp)
file(GLOB_RECURSE UVM_SRCS "csrc/uvm/*.cpp")
file(GLOB_RECURSE METIS_SRCS "csrc/partition/*.cpp")
add_library(${PROJECT_NAME} SHARED csrc/export.cpp ${UVM_SRCS} ${METIS_SRCS})
target_link_libraries(${PROJECT_NAME} PRIVATE ${TORCH_LIBRARIES})
target_compile_definitions(${PROJECT_NAME} PRIVATE -DTORCH_EXTENSION_NAME=lib${PROJECT_NAME})
if(WITH_PYTHON)
find_library(TORCH_PYTHON_LIBRARY torch_python PATHS "${TORCH_INSTALL_PREFIX}/lib")
target_link_libraries(${PROJECT_NAME} PRIVATE ${TORCH_PYTHON_LIBRARY})
endif()
target_link_libraries(${PROJECT_NAME} PRIVATE ${TORCH_LIBRARIES})
target_compile_definitions(${PROJECT_NAME} PRIVATE -DTORCH_EXTENSION_NAME=lib${PROJECT_NAME})
if (WITH_CUDA)
target_link_libraries(${PROJECT_NAME} PRIVATE uvm)
endif()
# set_target_properties(${PROJECT_NAME} PROPERTIES PREFIX "" OUTPUT_NAME "_C")
# install(TARGETS ${PROJECT_NAME} DESTINATION "${CMAKE_SOURCE_DIR}/starrygl/lib")
if (WITH_METIS)
target_link_libraries(${PROJECT_NAME} PRIVATE metis)
endif()
if (WITH_MTMETIS)
target_link_libraries(${PROJECT_NAME} PRIVATE mtmetis)
endif()
......@@ -4,7 +4,10 @@ from torch_geometric.utils import add_remaining_self_loops, to_undirected
import os.path as osp
import sys
from starrygl.utils.partition import partition_save
from starrygl.utils.data import partition_pyg
import logging
logging.getLogger().setLevel(logging.INFO)
if __name__ == "__main__":
data = Planetoid("/mnt/nfs/hwj/pyg_datasets/Planetoid/Cora", "Cora")[0]
......@@ -17,11 +20,11 @@ if __name__ == "__main__":
print(f"num_features: {data.num_features}")
num_parts_list = [1, 2, 3, 5, 7, 9, 11]
algos = ["metis", "random"]
algos = ["metis", 'mt-metis', "random"]
root = osp.splitext(osp.abspath(__file__))[0]
print(f"root: {root}")
for num_parts in num_parts_list:
for algo in algos:
print(f"======== {num_parts} + {algo} ========")
partition_save(root, data, num_parts, algo)
\ No newline at end of file
partition_pyg(root, data, num_parts, algo)
\ No newline at end of file
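
For reference, a hedged load-side sketch of what this script produces: partition_pyg (added in starrygl/utils/data.py below) writes one GraphData per partition under {root}/{algo}_{num_parts}/{part_id:03d}. The root path here is hypothetical; the script above derives it from its own filename.

from starrygl.utils.data import partition_load

root = "./examples/cora"  # hypothetical root
g = partition_load(root, part_id=0, num_parts=2, algo="metis")  # reads {root}/metis_2/000
print(g.node("dst").num_nodes)        # nodes owned by this partition
print(g.node("dst")["raw_ids"][:10])  # local -> global node id mapping
if "num_classes" in g.meta().keys():
    print(g.meta()["num_classes"])
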
#include "extension.h"
#include "uvm.h"
#include "partition.h"
torch::Tensor add(torch::Tensor a, torch::Tensor b) {
return a + b;
......@@ -13,6 +14,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("uvm_storage_advise", &uvm_storage_advise, "apply cudaMemAdvise() to uvm storage");
m.def("uvm_storage_prefetch", &uvm_storage_prefetch, "apply cudaMemPrefetchAsync() to uvm storage");
m.def("metis_partition", &metis_partition, "metis graph partition");
m.def("mt_metis_partition", &mt_metis_partition, "multi-threaded metis graph partition");
py::enum_<cudaMemoryAdvise>(m, "cudaMemoryAdvise")
.value("cudaMemAdviseSetAccessedBy", cudaMemoryAdvise::cudaMemAdviseSetAccessedBy)
.value("cudaMemAdviseUnsetAccessedBy", cudaMemoryAdvise::cudaMemAdviseUnsetAccessedBy)
......
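
Functions registered with m.def surface as plain attributes of the imported module; a minimal sanity check, assuming the extension was built and starrygl/__init__.py bound it as starrygl.ops:

import starrygl  # starrygl/__init__.py binds libstarrygl_ops as starrygl.ops

for name in ("metis_partition", "mt_metis_partition", "uvm_storage_advise"):
    assert hasattr(starrygl.ops, name), f"{name} missing from the extension"
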
#pragma once
#include "extension.h"
at::Tensor metis_partition(
at::Tensor rowptr,
at::Tensor col,
at::optional<at::Tensor> opt_value,
at::optional<at::Tensor> opt_vtx_w,
at::optional<at::Tensor> opt_vtx_s,
int64_t num_parts,
bool recursive,
bool min_edge_cut
);
at::Tensor mt_metis_partition(
at::Tensor rowptr,
at::Tensor col,
at::optional<at::Tensor> opt_value,
at::optional<at::Tensor> opt_vtx_w,
int64_t num_parts,
int64_t num_workers,
bool recursive
);
\ No newline at end of file
#pragma once
#include "extension.h"
#define CHECK_CPU(x) AT_ASSERTM(x.device().is_cpu(), #x " must be CPU tensor")
#define CHECK_CUDA(x) AT_ASSERTM(x.device().is_cuda(), #x " must be CUDA tensor")
\ No newline at end of file
#include <torch/all.h>
#include <metis.h>
#include <mtmetis.h>
// at::Tensor metis_partition(
// at::Tensor rowptr,
// at::Tensor col,
// at::optional<at::Tensor> optional_value,
// ) {
// }
// at::Tensor metis_mt_partition() {
// }
\ No newline at end of file
#include <torch/all.h>
#include <metis.h>
#include <mtmetis.h>
#include "utils.h"
at::Tensor metis_partition(
at::Tensor rowptr,
at::Tensor col,
at::optional<at::Tensor> opt_value,
at::optional<at::Tensor> opt_vtx_w,
at::optional<at::Tensor> opt_vtx_s,
int64_t num_parts,
bool recursive,
bool min_edge_cut
) {
static_assert(sizeof(idx_t) == sizeof(int64_t));
CHECK_CPU(rowptr);
AT_ASSERT(rowptr.dim() == 1);
AT_ASSERT(rowptr.is_contiguous());
CHECK_CPU(col);
AT_ASSERT(col.dim() == 1);
AT_ASSERT(col.is_contiguous());
idx_t nvtxs = rowptr.numel() - 1;
idx_t ncon = 1;
idx_t *xadj = rowptr.data_ptr<idx_t>();
idx_t *adjncy = col.data_ptr<idx_t>();
// edge weights
idx_t *adjwgt = nullptr;
if (opt_value.has_value()) {
CHECK_CPU(opt_value.value());
AT_ASSERT(opt_value.value().dim() == 1);
AT_ASSERT(opt_value.value().numel() == col.numel());
AT_ASSERT(opt_value.value().is_contiguous());
adjwgt = opt_value.value().data_ptr<idx_t>();
}
// node weights
idx_t *vwgt = nullptr;
if (opt_vtx_w.has_value()) {
CHECK_CPU(opt_vtx_w.value());
AT_ASSERT(opt_vtx_w.value().dim() <= 2);
AT_ASSERT(opt_vtx_w.value().size(0) == nvtxs);
AT_ASSERT(opt_vtx_w.value().is_contiguous());
vwgt = opt_vtx_w.value().data_ptr<idx_t>();
if (opt_vtx_w.value().dim() == 2) {
ncon = opt_vtx_w.value().size(1);
}
}
idx_t *vsize = nullptr;
if (opt_vtx_s.has_value()) {
CHECK_CPU(opt_vtx_s.value());
AT_ASSERT(opt_vtx_s.value().dim() == 1);
AT_ASSERT(opt_vtx_s.value().numel() == nvtxs);
AT_ASSERT(opt_vtx_s.value().is_contiguous());
vsize = opt_vtx_s.value().data_ptr<idx_t>();
}
idx_t nparts = num_parts;
idx_t objval = -1;
auto part = at::empty({nvtxs}, rowptr.options());
idx_t *part_data = part.data_ptr<idx_t>();
idx_t options[METIS_NOPTIONS];
METIS_SetDefaultOptions(options);
if (min_edge_cut) {
options[METIS_OPTION_OBJTYPE] = METIS_OBJTYPE_CUT;
} else {
options[METIS_OPTION_OBJTYPE] = METIS_OBJTYPE_VOL;
}
int ret;
if (recursive) {
ret = METIS_PartGraphRecursive(&nvtxs, &ncon, xadj, adjncy, vwgt, vsize, adjwgt,
&nparts, NULL, NULL, options, &objval, part_data);
} else {
ret = METIS_PartGraphKway(&nvtxs, &ncon, xadj, adjncy, vwgt, vsize, adjwgt,
&nparts, NULL, NULL, options, &objval, part_data);
}
AT_ASSERT(ret == METIS_OK);
return part;
}
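
To make the expected inputs concrete, a hedged sketch that builds an int64 CSR with plain torch and calls the binding directly. It assumes the extension loads as starrygl.ops and that the bundled METIS was built with 64-bit idx_t, which the static_assert above requires:

import torch
import starrygl  # assumed to expose the compiled extension as starrygl.ops

# tiny undirected 4-cycle 0-1-2-3-0, both directions listed explicitly
edge_index = torch.tensor([[0, 1, 1, 2, 2, 3, 3, 0],
                           [1, 0, 2, 1, 3, 2, 0, 3]], dtype=torch.int64)
num_nodes = 4

# CSR construction with plain torch: sort by source, prefix-sum the out-degrees
perm = edge_index[0].argsort()
col = edge_index[1][perm].contiguous()
rowptr = torch.zeros(num_nodes + 1, dtype=torch.int64)
rowptr[1:] = torch.bincount(edge_index[0], minlength=num_nodes).cumsum(0)

parts = starrygl.ops.metis_partition(
    rowptr, col,
    None, None, None,  # opt_value, opt_vtx_w, opt_vtx_s: all unweighted
    2,                 # num_parts
    False,             # recursive
    True,              # min_edge_cut -> METIS_OBJTYPE_CUT
)
print(parts)  # one partition id per node, e.g. tensor([0, 0, 1, 1])
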
#include <torch/all.h>
#include <metis.h>
#include <mtmetis.h>
#include "utils.h"
at::Tensor mt_metis_partition(
at::Tensor rowptr,
at::Tensor col,
at::optional<at::Tensor> opt_value,
at::optional<at::Tensor> opt_vtx_w,
int64_t num_parts,
int64_t num_workers,
bool recursive
) {
static_assert(sizeof(mtmetis_vtx_type) == sizeof(int64_t));
static_assert(sizeof(mtmetis_adj_type) == sizeof(int64_t));
static_assert(sizeof(mtmetis_wgt_type) == sizeof(int64_t));
static_assert(sizeof(mtmetis_pid_type) == sizeof(int64_t));
CHECK_CPU(rowptr);
AT_ASSERT(rowptr.dim() == 1);
AT_ASSERT(rowptr.is_contiguous());
CHECK_CPU(col);
AT_ASSERT(col.dim() == 1);
AT_ASSERT(col.is_contiguous());
mtmetis_vtx_type nvtxs = rowptr.numel() - 1;
mtmetis_vtx_type ncon = 1;
mtmetis_adj_type *xadj = reinterpret_cast<mtmetis_adj_type*>(rowptr.data_ptr<int64_t>());
mtmetis_vtx_type *adjncy = reinterpret_cast<mtmetis_vtx_type*>(col.data_ptr<int64_t>());
// edge weights
mtmetis_wgt_type *adjwgt = NULL;
if (opt_value.has_value()) {
CHECK_CPU(opt_value.value());
AT_ASSERT(opt_value.value().dim() == 1);
AT_ASSERT(opt_value.value().numel() == col.numel());
AT_ASSERT(opt_value.value().is_contiguous());
adjwgt = reinterpret_cast<mtmetis_wgt_type*>(opt_value.value().data_ptr<int64_t>());
}
// node weights
mtmetis_wgt_type *vwgt = NULL;
if (opt_vtx_w.has_value()) {
CHECK_CPU(opt_vtx_w.value());
AT_ASSERT(opt_vtx_w.value().dim() <= 2);
AT_ASSERT(opt_vtx_w.value().size(0) == nvtxs);
AT_ASSERT(opt_vtx_w.value().is_contiguous());
vwgt = reinterpret_cast<mtmetis_wgt_type*>(opt_vtx_w.value().data_ptr<int64_t>());
if (opt_vtx_w.value().dim() == 2) {
ncon = opt_vtx_w.value().size(1);
}
}
mtmetis_pid_type nparts = num_parts;
mtmetis_wgt_type objval = -1;
auto part = at::empty({static_cast<int64_t>(nvtxs)}, rowptr.options());
mtmetis_pid_type *part_data = reinterpret_cast<mtmetis_pid_type*>(part.data_ptr<int64_t>());
double *options = mtmetis_init_options();
options[MTMETIS_OPTION_NTHREADS] = num_workers;
int ret;
if (recursive) {
ret = MTMETIS_PartGraphRecursive(&nvtxs, &ncon, xadj, adjncy, vwgt, NULL, adjwgt,
&nparts, NULL, NULL, options, &objval, part_data);
} else {
ret = MTMETIS_PartGraphKway(&nvtxs, &ncon, xadj, adjncy, vwgt, NULL, adjwgt,
&nparts, NULL, NULL, options, &objval, part_data);
}
AT_ASSERT(ret == MTMETIS_SUCCESS);
return part;
}
\ No newline at end of file
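
The mt-metis variant takes the same CSR plus a worker count; a hedged sketch, assuming the MTMETIS_64BIT_* definitions from CMakeLists.txt so that every tensor is contiguous int64:

import torch
import starrygl  # assumed import path, as in the previous sketch

# CSR of the same undirected 4-cycle
rowptr = torch.tensor([0, 2, 4, 6, 8], dtype=torch.int64)
col = torch.tensor([1, 3, 0, 2, 1, 3, 0, 2], dtype=torch.int64)

parts = starrygl.ops.mt_metis_partition(
    rowptr, col,
    None, None,  # opt_value, opt_vtx_w: unweighted
    2,           # num_parts
    4,           # num_workers -> MTMETIS_OPTION_NTHREADS
    False,       # recursive
)
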
#!/bin/bash
mkdir -p build starrygl/lib && cd build
mkdir -p build && cd build
cmake .. \
-DCMAKE_EXPORT_COMPILE_COMMANDS=ON \
-DCMAKE_PREFIX_PATH="/home/hwj/.miniconda3/envs/sgl/lib/python3.10/site-packages" \
-DPython3_ROOT_DIR="/home/hwj/.miniconda3/envs/sgl" \
-DCUDA_TOOLKIT_ROOT_DIR="/home/hwj/.local/cuda-11.7" \
-DCUDA_TOOLKIT_ROOT_DIR="/home/hwj/.local/cuda-11.8" \
&& make -j32 \
&& cp libstarrygl_ops.so ../starrygl/lib/
\ No newline at end of file
&& rm -rf ../starrygl/lib \
&& mkdir ../starrygl/lib \
&& cp lib*.so ../starrygl/lib/ \
&& patchelf --set-rpath '$ORIGIN:$ORIGIN/lib' --force-rpath ../starrygl/lib/*.so
\ No newline at end of file
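
A quick, hedged check that the build and the patchelf rpath step worked (the metis/mtmetis/uvm helper libraries must resolve relative to libstarrygl_ops.so):

# run after ./mk.sh from the repository root; names per this commit's build script
import starrygl

ops = getattr(starrygl, "ops", None)  # left unset if the try/except in __init__.py failed
print("extension loaded" if ops is not None else "extension missing: check rpath and deps")
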
......@@ -4,5 +4,6 @@ import logging
try:
from .lib import libstarrygl_ops as ops
except:
except Exception as e:
logging.error(e)
logging.error("unable to import libstarrygl.so, some features may not be available.")
\ No newline at end of file
# from .functional import train_epoch, eval_epoch
from .partition import partition_load, partition_save, partition_data
# from .partition import partition_load, partition_save, partition_data
from .printer import sync_print, main_print
from .metrics import all_reduce_loss, accuracy
\ No newline at end of file
import torch
from torch import Tensor
from typing import *
import os
import os.path as osp
import shutil
from pathlib import Path
from torch_sparse import SparseTensor
from .partition import *
import logging
__all__ = [
"GraphData",
"partition_pyg",
"partition_load",
]
class GraphData:
def __init__(self,
edge_indices: Union[Tensor, Dict[Tuple[str, str, str], Tensor]],
num_nodes: Union[int, Dict[str, int]],
) -> None:
if isinstance(edge_indices, Tensor):
self._heterogeneous = False
edge_indices = {("#", "@", "#"): edge_indices}
num_nodes = {"#": int(num_nodes)}
else:
self._heterogeneous = True
self._num_nodes: Dict[str, int] = {}
self._node_data: Dict[str, 'NodeData'] = {}
for ntype, num in num_nodes.items():
ntype, num = str(ntype), int(num)
self._num_nodes[ntype] = num
self._node_data[ntype] = NodeData(ntype, num)
self._edge_indices: Dict[Tuple[str, str, str], Tensor] = {}
self._edge_data: Dict[Tuple[str, str, str], 'EdgeData'] = {}
for (es, et, ed), edge_index in edge_indices.items():
assert isinstance(edge_index, Tensor), f"edge_index must be a tensor, got {type(edge_index)}"
assert edge_index.dim() == 2 and edge_index.size(0) == 2
es, et, ed = str(es), str(et), str(ed)
assert es in self._num_nodes, f"unknown node type '{es}', should be one of {list(self._num_nodes.keys())}."
assert ed in self._num_nodes, f"unknown node type '{ed}', should be one of {list(self._num_nodes.keys())}."
etype = (es, et, ed)
self._edge_indices[etype] = edge_index
self._edge_data[etype] = EdgeData(etype, edge_index.size(1))
self._meta = MetaData()
def meta(self) -> 'MetaData':
return self._meta
def node(self, node_type: Optional[str] = None) -> 'NodeData':
if len(self._node_data) == 1:
for data in self._node_data.values():
return data
return self._node_data[node_type]
def edge(self, edge_type: Optional[Tuple[str, str, str]] = None) -> 'EdgeData':
if len(self._edge_data) == 1:
for data in self._edge_data.values():
return data
return self._edge_data[edge_type]
def edge_index(self, edge_type: Optional[Tuple[str, str, str]] = None) -> Tensor:
if len(self._edge_indices) == 1:
for data in self._edge_indices.values():
return data
return self._edge_indices[edge_type]
@property
def is_heterogeneous(self) -> bool:
return self._heterogeneous
def to(self, device: Any) -> 'GraphData':
self._meta.to(device)
for ndata in self._node_data.values():
ndata.to(device)
for edata in self._edge_data.values():
edata.to(device)
self._edge_indices = {k:v.to(device) for k,v in self._edge_indices.items()}
return self
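
A usage sketch of the two construction paths above, with toy tensors and invented type names:

import torch

# homogeneous path: one edge_index tensor plus a node count,
# stored internally under the ("#", "@", "#") edge type
ei = torch.tensor([[0, 1, 2], [1, 2, 0]])
g = GraphData(ei, num_nodes=3)
assert not g.is_heterogeneous
g.node()["x"] = torch.randn(3, 16)  # first dim must equal num_nodes

# heterogeneous path: dicts keyed by node type and (src, rel, dst) triples
g2 = GraphData(
    edge_indices={("user", "clicks", "item"): torch.tensor([[0, 1], [1, 0]])},
    num_nodes={"user": 2, "item": 2},
)
assert g2.is_heterogeneous
g2.edge(("user", "clicks", "item"))["w"] = torch.ones(2)  # first dim = num_edges
g2 = g2.to("cpu")  # moves meta, node, and edge tensors together
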
class MetaData:
def __init__(self) -> None:
self._data: Dict[str, Any] = {}
def keys(self) -> List[str]:
return list(self._data.keys())
def __getitem__(self, key: str) -> Any:
return self._data[key]
def __setitem__(self, key: str, val: Any):
assert isinstance(key, str)
self._data[key] = val
    def pop(self, key: str) -> Optional[Any]:
if key in self._data:
return self._data.pop(key)
def to(self, device: Any) -> 'MetaData':
for k in self.keys():
v = self._data[k]
if isinstance(v, Tensor):
self._data[k] = v.to(device)
return self
class NodeData:
def __init__(self,
node_type: str,
num_nodes: int,
) -> None:
self._node_type = str(node_type)
self._num_nodes = int(num_nodes)
self._data: Dict[str, Tensor] = {}
@property
def node_type(self) -> str:
return self._node_type
@property
def num_nodes(self) -> int:
return self._num_nodes
def keys(self) -> List[str]:
return list(self._data.keys())
def __getitem__(self, key: str) -> Tensor:
return self._data[key]
def __setitem__(self, key: str, val: Tensor):
assert isinstance(key, str)
assert val.size(0) == self._num_nodes
self._data[key] = val
    def pop(self, key: str) -> Optional[Tensor]:
if key in self._data:
return self._data.pop(key)
def to(self, device: Any) -> 'NodeData':
self._data = {k:v.to(device) for k,v in self._data.items()}
return self
class EdgeData:
def __init__(self,
edge_type: Tuple[str, str, str],
num_edges: int,
) -> None:
self._edge_type = tuple(str(t) for t in edge_type)
self._num_edges = num_edges
assert len(self._edge_type) == 3
self._data: Dict[str, Tensor] = {}
@property
def edge_type(self) -> Tuple[str, str, str]:
return self._edge_type
@property
def num_edges(self) -> int:
return self._num_edges
def keys(self) -> List[str]:
return list(self._data.keys())
def __getitem__(self, key: str) -> Tensor:
return self._data[key]
    def __setitem__(self, key: str, val: Tensor):
assert isinstance(key, str)
assert val.size(0) == self._num_edges
self._data[key] = val
    def pop(self, key: str) -> Optional[Tensor]:
if key in self._data:
return self._data.pop(key)
def to(self, device: Any) -> 'EdgeData':
self._data = {k:v.to(device) for k,v in self._data.items()}
return self
def partition_load(root: str, part_id: int, num_parts: int, algo: str = "metis") -> GraphData:
p = Path(root).expanduser().resolve() / f"{algo}_{num_parts}" / f"{part_id:03d}"
return torch.load(p.__str__())
def partition_pyg(root: str, data, num_parts: int, algo: str = "metis"):
root_path = Path(root).expanduser().resolve()
base_path = root_path / f"{algo}_{num_parts}"
if base_path.exists():
shutil.rmtree(base_path.__str__())
base_path.mkdir(parents=True, exist_ok=True)
for i, g in enumerate(partition_pyg_data(data, num_parts, algo)):
logging.info(f"saving partition data: {i+1}/{num_parts}")
torch.save(g, (base_path / f"{i:03d}").__str__())
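
An end-to-end sketch with a toy pyg graph; the path is invented, and algo="random" avoids needing the compiled extension:

import torch
from torch_geometric.data import Data
from starrygl.utils.data import partition_pyg

n = 100
ring = torch.arange(n)
data = Data(
    edge_index=torch.stack([ring, (ring + 1) % n]),  # directed ring
    x=torch.randn(n, 8),
    y=torch.randint(3, (n,)),
)
partition_pyg("/tmp/toy_graph", data, num_parts=2, algo="random")
# expected layout: /tmp/toy_graph/random_2/000 and .../001, one GraphData each
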
def partition_pyg_data(data, num_parts: int, algo: str = "metis") -> Iterator[GraphData]:
from torch_geometric.data import Data
    assert isinstance(data, Data), "data must be an instance of pyg's Data class"
    logging.info(f"running partition algorithm: {algo}")
num_nodes: int = data.num_nodes
num_edges: int = data.num_edges
edge_index: Tensor = data.edge_index
if algo == "metis":
node_parts = metis_partition(edge_index, num_nodes, num_parts)
elif algo == "mt-metis":
node_parts = mt_metis_partition(edge_index, num_nodes, num_parts)
elif algo == "random":
node_parts = random_partition(edge_index, num_nodes, num_parts)
else:
raise ValueError(f"unknown partition algorithm: {algo}")
if data.y.dtype == torch.long:
if data.y.dim() == 1:
num_classes = data.y.max().item() + 1
else:
num_classes = data.y.size(1)
else:
num_classes = None
for i in range(num_parts):
npart_mask = node_parts == i
epart_mask = npart_mask[edge_index[1]]
local_edges = edge_index[:, epart_mask]
raw_src_ids: Tensor = local_edges[0].unique()
raw_dst_ids: Tensor = torch.where(npart_mask)[0]
M: int = raw_src_ids.max().item() + 1
        imap = torch.full((M,), (2**62-1)*2+1).type_as(raw_src_ids)  # fill with 2**63-1 (int64 max) as an "absent" sentinel
imap[raw_src_ids] = torch.arange(raw_src_ids.numel()).type_as(raw_src_ids)
local_src = imap[local_edges[0]]
M: int = raw_dst_ids.max().item() + 1
imap = torch.full((M,), (2**62-1)*2+1).type_as(raw_dst_ids)
imap[raw_dst_ids] = torch.arange(raw_dst_ids.numel()).type_as(raw_dst_ids)
local_dst = imap[local_edges[1]]
local_edges = torch.vstack([local_src, local_dst])
g = GraphData(
edge_indices={
("src", "@", "dst"): local_edges,
},
num_nodes={
"src": raw_src_ids.numel(),
"dst": raw_dst_ids.numel(),
},
)
g.node("src")["raw_ids"] = raw_src_ids
g.node("dst")["raw_ids"] = raw_dst_ids
if num_classes is not None:
g.meta()["num_classes"] = num_classes
for key, val in data:
if key == "edge_index":
continue
elif isinstance(val, Tensor):
if val.size(0) == num_nodes:
g.node("dst")[key] = val[npart_mask]
elif val.size(0) == num_edges:
g.edge()[key] = val[epart_mask]
elif isinstance(val, SparseTensor):
pass
else:
g.meta()[key] = val
yield g
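
The remapping above relies on an inverse-lookup table: fill with the int64 maximum ((2**62-1)*2+1 == 2**63-1) as an "absent" sentinel, then write local indices at the kept global ids. A standalone illustration:

import torch

raw_ids = torch.tensor([2, 5, 7])      # global ids kept in this partition
edges = torch.tensor([5, 2, 7, 5])     # global ids referenced by local edges
M = int(raw_ids.max()) + 1
imap = torch.full((M,), 2**63 - 1)     # int64 sentinel marks "not in this partition"
imap[raw_ids] = torch.arange(raw_ids.numel())
print(imap[edges])                     # tensor([1, 0, 2, 1]) -> local ids
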
import torch
import torch.distributed as dist
import starrygl
from torch import Tensor, LongTensor
from torch import Tensor
from torch_sparse import SparseTensor
from torch_geometric.data import Data
from torch_geometric.utils import degree
import os
import os.path as osp
import shutil
from typing import *
def partition_save(root: str, data: Data, num_parts: int, algo: str = "metis"):
root = osp.abspath(root)
if osp.exists(root) and not osp.isdir(root):
raise ValueError(f"path '{root}' should be a directory")
path = osp.join(root, f"{algo}_{num_parts}")
if osp.exists(path) and not osp.isdir(path):
raise ValueError(f"path '{path}' should be a directory")
if osp.exists(path) and os.listdir(path):
print(f"directory '{path}' not empty and cleared")
for p in os.listdir(path):
p = osp.join(path, p)
if osp.isdir(p):
shutil.rmtree(osp.join(path, p))
else:
os.remove(p)
if not osp.exists(path):
print(f"creating directory '{path}'")
os.makedirs(path)
for i, pdata in enumerate(partition_data(data, num_parts, algo, verbose=True)):
print(f"saving partition data: {i+1}/{num_parts}")
fn = osp.join(path, f"{i:03d}")
torch.save(pdata, fn)
def partition_load(root: str, algo: str = "metis") -> Data:
rank = dist.get_rank()
world_size = dist.get_world_size()
fn = osp.join(root, f"{algo}_{world_size}", f"{rank:03d}")
return torch.load(fn)
def partition_data(data: Data, num_parts: int, algo: str, verbose: bool = False) -> List[Data]:
if algo == "metis":
part_fn = metis_partition
elif algo == "random":
part_fn = random_partition
else:
raise ValueError(f"invalid algorithm: {algo}")
num_nodes = data.num_nodes
num_edges = data.num_edges
edge_index = data.edge_index
if verbose: print(f"running partition algorithm: {algo}")
node_parts, edge_parts = part_fn(edge_index, num_nodes, num_parts)
if verbose: print("computing GCN normalized factor")
gcn_norm = compute_gcn_norm(edge_index, num_nodes)
if data.y.dtype == torch.long:
if verbose: print("compute num_classes")
num_classes = data.y.max().item() + 1
else:
num_classes = None
for i in range(num_parts):
npart_i = torch.where(node_parts == i)[0]
epart_i = torch.where(edge_parts == i)[0]
npart = npart_i
epart = edge_index[:,epart_i]
pdata = {
"ids": npart,
"edge_index": epart,
"gcn_norm": gcn_norm[epart_i],
}
if num_classes is not None:
pdata["num_classes"] = num_classes
__all__ = [
"metis_partition",
"mt_metis_partition",
"random_partition",
]
def _nopart(edge_index: Tensor, num_nodes: int):
return torch.zeros(num_nodes).type_as(edge_index)
def metis_partition(
edge_index: Tensor,
num_nodes: int,
num_parts: int,
node_weight: Optional[Tensor] = None,
edge_weight: Optional[Tensor] = None,
node_sizes: Optional[Tensor] = None,
recursive: bool = False,
min_edge_cut: bool = False,
) -> Tensor:
if num_parts <= 1:
return _nopart(edge_index, num_nodes)
for key, val in data:
if key == "edge_index":
continue
if isinstance(val, Tensor):
if val.size(0) == num_nodes:
pdata[key] = val[npart_i]
elif val.size(0) == num_edges:
pdata[key] = val[epart_i]
# else:
# pdata[key] = val
elif isinstance(val, SparseTensor):
pass
else:
pdata[key] = val
pdata = Data(**pdata)
yield pdata
def compute_gcn_norm(edge_index: LongTensor, num_nodes: int) -> Tensor:
deg_j = degree(edge_index[0], num_nodes).pow(-0.5)
deg_i = degree(edge_index[1], num_nodes).pow(-0.5)
deg_i[deg_i.isinf() | deg_i.isnan()] = 0.0
deg_j[deg_j.isinf() | deg_j.isnan()] = 0.0
return deg_j[edge_index[0]] * deg_i[edge_index[1]]
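
compute_gcn_norm yields the symmetric GCN coefficient deg(src)^(-1/2) * deg(dst)^(-1/2) per edge, zeroing the inf/nan produced by isolated nodes; a toy check, for illustration only:

import torch
from torch_geometric.utils import degree

edge_index = torch.tensor([[0, 1, 1], [1, 0, 2]])   # edges 0->1, 1->0, 1->2
deg_j = degree(edge_index[0], 3).pow(-0.5)          # out-degrees [1, 2, 0]
deg_i = degree(edge_index[1], 3).pow(-0.5)          # in-degrees  [1, 1, 1]
deg_i[deg_i.isinf() | deg_i.isnan()] = 0.0
deg_j[deg_j.isinf() | deg_j.isnan()] = 0.0
print(deg_j[edge_index[0]] * deg_i[edge_index[1]])  # tensor([1.0000, 0.7071, 0.7071])
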
def _nopart(edge_index: LongTensor, num_nodes: int) -> Tuple[LongTensor, LongTensor]:
node_parts = torch.zeros(num_nodes, dtype=torch.long)
edge_parts = torch.zeros(edge_index.size(1), dtype=torch.long)
return node_parts, edge_parts
def metis_partition(edge_index: LongTensor, num_nodes: int, num_parts: int) -> Tuple[LongTensor, LongTensor]:
adj_t = SparseTensor.from_edge_index(edge_index, edge_weight, sparse_sizes=(num_nodes, num_nodes))
rowptr, col, value = adj_t.coalesce().to_symmetric().csr()
node_parts = starrygl.ops.metis_partition(
rowptr, col, value, node_weight, node_sizes, num_parts, recursive, min_edge_cut)
return node_parts
def mt_metis_partition(
edge_index: Tensor,
num_nodes: int,
num_parts: int,
node_weight: Optional[Tensor] = None,
edge_weight: Optional[Tensor] = None,
num_workers: int = 8,
recursive: bool = False,
) -> Tensor:
if num_parts <= 1:
return _nopart(edge_index, num_nodes)
adj_t = SparseTensor.from_edge_index(edge_index, sparse_sizes=(num_nodes, num_nodes)).to_symmetric()
rowptr, col, _ = adj_t.csr()
node_parts = torch.ops.torch_sparse.partition(rowptr, col, None, num_parts, num_parts < 8)
edge_parts = node_parts[edge_index[1]]
return node_parts, edge_parts
adj_t = SparseTensor.from_edge_index(edge_index, edge_weight, sparse_sizes=(num_nodes, num_nodes))
rowptr, col, value = adj_t.coalesce().to_symmetric().csr()
node_parts = starrygl.ops.mt_metis_partition(
rowptr, col, value, node_weight, num_parts, num_workers, recursive)
return node_parts
def random_partition(edge_index: LongTensor, num_nodes: int, num_parts: int) -> Tuple[LongTensor, LongTensor]:
def random_partition(edge_index: Tensor, num_nodes: int, num_parts: int) -> Tensor:
if num_parts <= 1:
return _nopart(edge_index, num_nodes)
node_parts = torch.randint(num_parts, size=(num_nodes,), dtype=edge_index.dtype)
edge_parts = node_parts[edge_index[1]]
return node_parts, edge_parts
\ No newline at end of file
return torch.randint(num_parts, size=(num_nodes,)).type_as(edge_index)
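
Putting the new wrapper API together, a hedged sketch; the two METIS variants need the compiled starrygl.ops extension, while random_partition is pure torch:

import torch
from starrygl.utils.partition import metis_partition, mt_metis_partition, random_partition

edge_index = torch.tensor([[0, 1, 1, 2, 2, 3, 3, 0],
                           [1, 0, 2, 1, 3, 2, 0, 3]])
n = 4

parts = random_partition(edge_index, n, num_parts=2)  # pure torch
parts = metis_partition(edge_index, n, num_parts=2)   # calls starrygl.ops.metis_partition
parts = mt_metis_partition(edge_index, n, num_parts=2,
                           num_workers=4)             # calls starrygl.ops.mt_metis_partition
# each returns one int64 partition id per node
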