diff --git a/docs/examples/model/graph_examples.ipynb b/docs/examples/model/graph_examples.ipynb index 832ff5da..2dfb887f 100644 --- a/docs/examples/model/graph_examples.ipynb +++ b/docs/examples/model/graph_examples.ipynb @@ -105,6 +105,28 @@ "print(f\"Shortest path: {path}, Length: {length}\")" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Branches on the Shortest Path\n", + "\n", + "`Grid.iter_branches_in_shortest_path` walks the same nodes returned by `get_shortest_path` but exposes the actual `BranchArray` records for each edge. Iterate the result to inspect branch IDs, statuses, or any other metadata without recomputing the path." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from power_grid_model_ds import Grid\n", + "\n", + "grid = Grid.from_txt(\"S1 101\", \"101 102\", \"102 103\")\n", + "for branch in grid.iter_branches_in_shortest_path(101, 103):\n", + " print(f\"Branch {branch.id.item()} runs {branch.from_node.item()} → {branch.to_node.item()}\")" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -181,7 +203,7 @@ ], "metadata": { "kernelspec": { - "display_name": ".venv", + "display_name": ".venv (3.12.6)", "language": "python", "name": "python3" }, diff --git a/src/power_grid_model_ds/_core/model/grids/_search.py b/src/power_grid_model_ds/_core/model/grids/_search.py index fdfd14b5..da49c3fd 100644 --- a/src/power_grid_model_ds/_core/model/grids/_search.py +++ b/src/power_grid_model_ds/_core/model/grids/_search.py @@ -3,6 +3,8 @@ # SPDX-License-Identifier: MPL-2.0 import dataclasses +from collections.abc import Iterator +from itertools import pairwise from typing import TYPE_CHECKING import numpy as np @@ -12,8 +14,9 @@ from power_grid_model_ds._core.model.arrays.base.array import FancyArray from power_grid_model_ds._core.model.arrays.base.errors import RecordDoesNotExist from power_grid_model_ds._core.model.enums.nodes import NodeType +from power_grid_model_ds._core.model.graphs.errors import MissingBranchError from power_grid_model_ds._core.utils.misc import find_diff_masks_with_equal_nan -from power_grid_model_ds.arrays import BranchArray +from power_grid_model_ds.arrays import Branch3Array, BranchArray if TYPE_CHECKING: from power_grid_model_ds._core.model.grids.base import Grid @@ -75,6 +78,43 @@ def get_downstream_nodes(grid: "Grid", node_id: int, inclusive: bool = False): ) +def iter_branches_in_shortest_path( + grid: "Grid", from_node_id: int, to_node_id: int +) -> Iterator[BranchArray | Branch3Array]: + """See Grid.iter_branches_in_shortest_path().""" + + path, _ = grid.graphs.active_graph.get_shortest_path(from_node_id, to_node_id) + + for current_node, next_node in pairwise(path): + branches = _get_branches(grid, current_node, next_node) + if branches.size == 0: + raise MissingBranchError( + f"No active branch connects nodes {current_node} -> {next_node} even though a path exists." + ) + branch_ids = branches.id.tolist() + try: + typed_branches = grid.get_typed_branches(branch_ids) + except RecordDoesNotExist: + typed_branches = grid.three_winding_transformer.filter(branch_ids) + yield typed_branches + + +def _get_branches(grid: "Grid", from_node: int, to_node: int) -> BranchArray: + """Return active branch records and an index filtered to the requested path nodes.""" + + active_branches = grid.branches.filter(from_status=1, to_status=1).filter( + from_node=from_node, to_node=to_node, mode_="AND" + ) + if grid.three_winding_transformer.size: + three_winding_active = grid.three_winding_transformer.as_branches().filter( + from_status=1, to_status=1, from_node=from_node, to_node=to_node, mode_="AND" + ) + if three_winding_active.size: + active_branches = fp.concatenate(active_branches, three_winding_active) + + return active_branches + + def find_differences_between_grids( grid1: "Grid", grid2: "Grid", print_diff: bool = False ) -> dict[str, dict[str, object]]: diff --git a/src/power_grid_model_ds/_core/model/grids/base.py b/src/power_grid_model_ds/_core/model/grids/base.py index d008acd2..9d9254dd 100644 --- a/src/power_grid_model_ds/_core/model/grids/base.py +++ b/src/power_grid_model_ds/_core/model/grids/base.py @@ -5,6 +5,7 @@ """Base grid classes""" import warnings +from collections.abc import Iterator from dataclasses import dataclass, fields from pathlib import Path from typing import Literal, Self, TypeVar, overload @@ -45,6 +46,7 @@ get_downstream_nodes, get_nearest_substation_node, get_typed_branches, + iter_branches_in_shortest_path, ) from power_grid_model_ds._core.model.grids.serialization.json import deserialize_from_json, serialize_to_json from power_grid_model_ds._core.model.grids.serialization.pickle import load_grid_from_pickle, save_grid_to_pickle @@ -386,6 +388,25 @@ def get_branches_in_path(self, nodes_in_path: list[int]) -> BranchArray: """ return self.branches.filter(from_node=nodes_in_path, to_node=nodes_in_path, from_status=1, to_status=1) + def iter_branches_in_shortest_path( + self, from_node_id: int, to_node_id: int + ) -> Iterator[BranchArray | Branch3Array]: + """Returns the ordered active branches that form the shortest path between two nodes. When parallel active edges + are in the path all these branches will be returned for the same from_node and to_node. + + Args: + from_node_id (int): External id of the path start node. + to_node_id (int): External id of the path end node. + + Yields: + BranchArray: branch arrays for each active branch on the path. + + Raises: + MissingBranchError: If the graph reports an edge on the shortest path but no active branch is found. + """ + + return iter_branches_in_shortest_path(self, from_node_id, to_node_id) + def get_nearest_substation_node(self, node_id: int): """Find the nearest substation node. diff --git a/src/power_grid_model_ds/arrays.py b/src/power_grid_model_ds/arrays.py index 74e98acc..180d3960 100644 --- a/src/power_grid_model_ds/arrays.py +++ b/src/power_grid_model_ds/arrays.py @@ -145,18 +145,21 @@ class Branch3Array(IdArray, Branch3): def as_branches(self) -> BranchArray: """Convert Branch3Array to BranchArray.""" branches_1_2 = BranchArray.empty(self.size) + branches_1_2.id = self.id branches_1_2.from_node = self.node_1 branches_1_2.to_node = self.node_2 branches_1_2.from_status = self.status_1 branches_1_2.to_status = self.status_2 branches_1_3 = BranchArray.empty(self.size) + branches_1_3.id = self.id branches_1_3.from_node = self.node_1 branches_1_3.to_node = self.node_3 branches_1_3.from_status = self.status_1 branches_1_3.to_status = self.status_3 branches_2_3 = BranchArray.empty(self.size) + branches_2_3.id = self.id branches_2_3.from_node = self.node_2 branches_2_3.to_node = self.node_3 branches_2_3.from_status = self.status_2 diff --git a/tests/unit/model/grids/test_search.py b/tests/unit/model/grids/test_search.py index f6372b70..c0c92793 100644 --- a/tests/unit/model/grids/test_search.py +++ b/tests/unit/model/grids/test_search.py @@ -73,6 +73,19 @@ def test_get_branches_in_path_empty_path(self, basic_grid): assert branches.size == 0 +class TestIterBranchesInShortestPath: + def test_iter_branches_in_shortest_path(self, basic_grid): + branches = list(basic_grid.iter_branches_in_shortest_path(101, 106)) + assert branches == [basic_grid.line.filter(id=201), basic_grid.transformer.filter(id=301)] + + def test_iter_branches_same_node_returns_empty(self, basic_grid): + assert list(basic_grid.iter_branches_in_shortest_path(101, 101)) == [] + + def test_iter_branches_in_shortest_path_three_winding_transformer_typed(self, grid_with_3wt): + branches = list(grid_with_3wt.iter_branches_in_shortest_path(101, 104)) + assert branches == [grid_with_3wt.three_winding_transformer.filter(id=301), grid_with_3wt.line.filter(id=201)] + + def test_component_three_winding_transformer(grid_with_3wt): substation_nodes = grid_with_3wt.node.filter(node_type=NodeType.SUBSTATION_NODE.value).id with grid_with_3wt.graphs.active_graph.tmp_remove_nodes(substation_nodes):