Commit 542831ac by Wenjie Huang

add constraint params to GraphData.save_partition

parent 150eafe6
...@@ -20,6 +20,9 @@ __all__ = [ ...@@ -20,6 +20,9 @@ __all__ = [
] ]
Strings = Sequence[str]
OptStrings = Optional[Strings]
class GraphData: class GraphData:
def __init__(self, def __init__(self,
edge_indices: Union[Tensor, Dict[Tuple[str, str, str], Tensor]], edge_indices: Union[Tensor, Dict[Tuple[str, str, str], Tensor]],
...@@ -168,33 +171,99 @@ class GraphData: ...@@ -168,33 +171,99 @@ class GraphData:
return g return g
@staticmethod @staticmethod
def load_partition(root: str, part_id: int, num_parts: int, algo: str = "metis") -> 'GraphData': def load_partition(
p = Path(root).expanduser().resolve() / f"{algo}_{num_parts}" / f"{part_id:03d}" root: str,
part_id: int,
num_parts: int,
algorithm: str = "metis",
) -> 'GraphData':
p = Path(root).expanduser().resolve() / f"{algorithm}_{num_parts}" / f"{part_id:03d}"
return torch.load(p.__str__()) return torch.load(p.__str__())
def save_partition(self, root: str, num_parts: int, algo: str = "metis"): def save_partition(self,
root: str,
num_parts: int,
node_weight: Optional[str] = None,
edge_weight: Optional[str] = None,
include_node_attrs: Optional[Sequence[str]] = None,
include_edge_attrs: Optional[Sequence[str]] = None,
include_meta_attrs: Optional[Sequence[str]] = None,
ignore_node_attrs: Optional[Sequence[str]] = None,
ignore_edge_attrs: Optional[Sequence[str]] = None,
ignore_meta_attrs: Optional[Sequence[str]] = None,
algorithm: str = "metis",
partition_kwargs = None,
):
assert not self.is_heterogeneous, "only support homomorphic graph" assert not self.is_heterogeneous, "only support homomorphic graph"
num_nodes: int = self.node().num_nodes num_nodes: int = self.node().num_nodes
edge_index: Tensor = self.edge_index() edge_index: Tensor = self.edge_index()
logging.info(f"running partition aglorithm: {algo}") logging.info(f"running partition aglorithm: {algorithm}")
if algo == "metis": partition_kwargs = partition_kwargs or {}
node_parts = metis_partition(edge_index, num_nodes, num_parts)
elif algo == "mt-metis": if node_weight is not None:
node_parts = mt_metis_partition(edge_index, num_nodes, num_parts) node_weight = self.node()[node_weight]
elif algo == "random":
node_parts = random_partition(edge_index, num_nodes, num_parts) if edge_weight is not None:
edge_weight = self.edge()[edge_weight]
if algorithm == "metis":
node_parts = metis_partition(
edge_index,
num_nodes, num_parts,
node_weight=node_weight,
edge_weight=edge_weight,
**partition_kwargs,
)
elif algorithm == "mt-metis":
node_parts = mt_metis_partition(
edge_index,
num_nodes, num_parts,
node_weight=node_weight,
edge_weight=edge_weight,
**partition_kwargs,
)
elif algorithm == "random":
node_parts = random_partition(
edge_index,
num_nodes, num_parts,
**partition_kwargs,
)
else: else:
raise ValueError(f"unknown partition algorithm: {algo}") raise ValueError(f"unknown partition algorithm: {algorithm}")
root_path = Path(root).expanduser().resolve() root_path = Path(root).expanduser().resolve()
base_path = root_path / f"{algo}_{num_parts}" base_path = root_path / f"{algorithm}_{num_parts}"
if base_path.exists(): if base_path.exists():
logging.warning(f"directory '{base_path.__str__()}' exists, and will be removed.") logging.warning(f"directory '{base_path.__str__()}' exists, and will be removed.")
shutil.rmtree(base_path.__str__()) shutil.rmtree(base_path.__str__())
base_path.mkdir(parents=True) base_path.mkdir(parents=True)
if include_node_attrs is None:
include_node_attrs = self.node().keys()
if include_edge_attrs is None:
include_edge_attrs = self.edge().keys()
if include_meta_attrs is None:
include_meta_attrs = self.meta().keys()
if ignore_node_attrs is None:
ignore_node_attrs = set()
else:
ignore_node_attrs = set(ignore_node_attrs)
if ignore_edge_attrs is None:
ignore_edge_attrs = set()
else:
ignore_edge_attrs = set(ignore_edge_attrs)
if ignore_meta_attrs is None:
ignore_meta_attrs = set()
else:
ignore_meta_attrs = set(ignore_meta_attrs)
for i in range(num_parts): for i in range(num_parts):
npart_mask = node_parts == i npart_mask = node_parts == i
...@@ -213,13 +282,19 @@ class GraphData: ...@@ -213,13 +282,19 @@ class GraphData:
raw_dst_ids=raw_dst_ids, raw_dst_ids=raw_dst_ids,
) )
for key in self.node().keys(): for key in include_node_attrs:
if key in ignore_node_attrs:
continue
g.node("dst")[key] = self.node()[key][npart_mask] g.node("dst")[key] = self.node()[key][npart_mask]
for key in self.edge().keys(): for key in include_edge_attrs:
if key in ignore_edge_attrs:
continue
g.edge()[key] = self.edge()[key][epart_mask] g.edge()[key] = self.edge()[key][epart_mask]
for key in self.meta().keys(): for key in include_meta_attrs:
if key in ignore_meta_attrs:
continue
g.meta()[key] = self.meta()[key] g.meta()[key] = self.meta()[key]
logging.info(f"saving partition data: {i+1}/{num_parts}") logging.info(f"saving partition data: {i+1}/{num_parts}")
......
...@@ -39,7 +39,7 @@ def prepare_data(root: str, num_parts, part_algo: str = "metis"): ...@@ -39,7 +39,7 @@ def prepare_data(root: str, num_parts, part_algo: str = "metis"):
logging.info(f"GraphData.node().keys(): {g.node().keys()}") logging.info(f"GraphData.node().keys(): {g.node().keys()}")
logging.info(f"GraphData.edge().keys(): {g.edge().keys()}") logging.info(f"GraphData.edge().keys(): {g.edge().keys()}")
g.save_partition(root, num_parts, part_algo) g.save_partition(root, num_parts, algorithm=part_algo)
return g return g
class SimpleConv(pyg_nn.MessagePassing): class SimpleConv(pyg_nn.MessagePassing):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment