Commit 542831ac by Wenjie Huang

add constraint params to GraphData.save_partition

parent 150eafe6
......@@ -20,6 +20,9 @@ __all__ = [
]
Strings = Sequence[str]
OptStrings = Optional[Strings]
class GraphData:
def __init__(self,
edge_indices: Union[Tensor, Dict[Tuple[str, str, str], Tensor]],
......@@ -168,34 +171,100 @@ class GraphData:
return g
@staticmethod
def load_partition(root: str, part_id: int, num_parts: int, algo: str = "metis") -> 'GraphData':
p = Path(root).expanduser().resolve() / f"{algo}_{num_parts}" / f"{part_id:03d}"
def load_partition(
root: str,
part_id: int,
num_parts: int,
algorithm: str = "metis",
) -> 'GraphData':
p = Path(root).expanduser().resolve() / f"{algorithm}_{num_parts}" / f"{part_id:03d}"
return torch.load(p.__str__())
def save_partition(self, root: str, num_parts: int, algo: str = "metis"):
def save_partition(self,
root: str,
num_parts: int,
node_weight: Optional[str] = None,
edge_weight: Optional[str] = None,
include_node_attrs: Optional[Sequence[str]] = None,
include_edge_attrs: Optional[Sequence[str]] = None,
include_meta_attrs: Optional[Sequence[str]] = None,
ignore_node_attrs: Optional[Sequence[str]] = None,
ignore_edge_attrs: Optional[Sequence[str]] = None,
ignore_meta_attrs: Optional[Sequence[str]] = None,
algorithm: str = "metis",
partition_kwargs = None,
):
assert not self.is_heterogeneous, "only support homomorphic graph"
num_nodes: int = self.node().num_nodes
edge_index: Tensor = self.edge_index()
logging.info(f"running partition aglorithm: {algo}")
if algo == "metis":
node_parts = metis_partition(edge_index, num_nodes, num_parts)
elif algo == "mt-metis":
node_parts = mt_metis_partition(edge_index, num_nodes, num_parts)
elif algo == "random":
node_parts = random_partition(edge_index, num_nodes, num_parts)
logging.info(f"running partition aglorithm: {algorithm}")
partition_kwargs = partition_kwargs or {}
if node_weight is not None:
node_weight = self.node()[node_weight]
if edge_weight is not None:
edge_weight = self.edge()[edge_weight]
if algorithm == "metis":
node_parts = metis_partition(
edge_index,
num_nodes, num_parts,
node_weight=node_weight,
edge_weight=edge_weight,
**partition_kwargs,
)
elif algorithm == "mt-metis":
node_parts = mt_metis_partition(
edge_index,
num_nodes, num_parts,
node_weight=node_weight,
edge_weight=edge_weight,
**partition_kwargs,
)
elif algorithm == "random":
node_parts = random_partition(
edge_index,
num_nodes, num_parts,
**partition_kwargs,
)
else:
raise ValueError(f"unknown partition algorithm: {algo}")
raise ValueError(f"unknown partition algorithm: {algorithm}")
root_path = Path(root).expanduser().resolve()
base_path = root_path / f"{algo}_{num_parts}"
base_path = root_path / f"{algorithm}_{num_parts}"
if base_path.exists():
logging.warning(f"directory '{base_path.__str__()}' exists, and will be removed.")
shutil.rmtree(base_path.__str__())
base_path.mkdir(parents=True)
if include_node_attrs is None:
include_node_attrs = self.node().keys()
if include_edge_attrs is None:
include_edge_attrs = self.edge().keys()
if include_meta_attrs is None:
include_meta_attrs = self.meta().keys()
if ignore_node_attrs is None:
ignore_node_attrs = set()
else:
ignore_node_attrs = set(ignore_node_attrs)
if ignore_edge_attrs is None:
ignore_edge_attrs = set()
else:
ignore_edge_attrs = set(ignore_edge_attrs)
if ignore_meta_attrs is None:
ignore_meta_attrs = set()
else:
ignore_meta_attrs = set(ignore_meta_attrs)
for i in range(num_parts):
npart_mask = node_parts == i
epart_mask = npart_mask[edge_index[1]]
......@@ -213,13 +282,19 @@ class GraphData:
raw_dst_ids=raw_dst_ids,
)
for key in self.node().keys():
for key in include_node_attrs:
if key in ignore_node_attrs:
continue
g.node("dst")[key] = self.node()[key][npart_mask]
for key in self.edge().keys():
for key in include_edge_attrs:
if key in ignore_edge_attrs:
continue
g.edge()[key] = self.edge()[key][epart_mask]
for key in self.meta().keys():
for key in include_meta_attrs:
if key in ignore_meta_attrs:
continue
g.meta()[key] = self.meta()[key]
logging.info(f"saving partition data: {i+1}/{num_parts}")
......
......@@ -39,7 +39,7 @@ def prepare_data(root: str, num_parts, part_algo: str = "metis"):
logging.info(f"GraphData.node().keys(): {g.node().keys()}")
logging.info(f"GraphData.edge().keys(): {g.edge().keys()}")
g.save_partition(root, num_parts, part_algo)
g.save_partition(root, num_parts, algorithm=part_algo)
return g
class SimpleConv(pyg_nn.MessagePassing):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment