Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 23 additions & 1 deletion docs/examples/model/graph_examples.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,28 @@
"print(f\"Shortest path: {path}, Length: {length}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Branches on the Shortest Path\n",
"\n",
"`Grid.iter_branches_in_shortest_path` walks the same nodes returned by `get_shortest_path` but exposes the actual `BranchArray` records for each edge. Iterate the result to inspect branch IDs, statuses, or any other metadata without recomputing the path."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from power_grid_model_ds import Grid\n",
"\n",
"grid = Grid.from_txt(\"S1 101\", \"101 102\", \"102 103\")\n",
"for branch in grid.iter_branches_in_shortest_path(101, 103):\n",
" print(f\"Branch {branch.id.item()} runs {branch.from_node.item()} → {branch.to_node.item()}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
Expand Down Expand Up @@ -181,7 +203,7 @@
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"display_name": ".venv (3.12.6)",
"language": "python",
"name": "python3"
},
Expand Down
42 changes: 41 additions & 1 deletion src/power_grid_model_ds/_core/model/grids/_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
# SPDX-License-Identifier: MPL-2.0

import dataclasses
from collections.abc import Iterator
from itertools import pairwise
from typing import TYPE_CHECKING

import numpy as np
Expand All @@ -12,8 +14,9 @@
from power_grid_model_ds._core.model.arrays.base.array import FancyArray
from power_grid_model_ds._core.model.arrays.base.errors import RecordDoesNotExist
from power_grid_model_ds._core.model.enums.nodes import NodeType
from power_grid_model_ds._core.model.graphs.errors import MissingBranchError
from power_grid_model_ds._core.utils.misc import find_diff_masks_with_equal_nan
from power_grid_model_ds.arrays import BranchArray
from power_grid_model_ds.arrays import Branch3Array, BranchArray

if TYPE_CHECKING:
from power_grid_model_ds._core.model.grids.base import Grid
Expand Down Expand Up @@ -75,6 +78,43 @@ def get_downstream_nodes(grid: "Grid", node_id: int, inclusive: bool = False):
)


def iter_branches_in_shortest_path(
grid: "Grid", from_node_id: int, to_node_id: int
) -> Iterator[BranchArray | Branch3Array]:
"""See Grid.iter_branches_in_shortest_path()."""

path, _ = grid.graphs.active_graph.get_shortest_path(from_node_id, to_node_id)

for current_node, next_node in pairwise(path):
branches = _get_branches(grid, current_node, next_node)
if branches.size == 0:
raise MissingBranchError(
f"No active branch connects nodes {current_node} -> {next_node} even though a path exists."
)
branch_ids = branches.id.tolist()
try:
typed_branches = grid.get_typed_branches(branch_ids)
except RecordDoesNotExist:
typed_branches = grid.three_winding_transformer.filter(branch_ids)
yield typed_branches


def _get_branches(grid: "Grid", from_node: int, to_node: int) -> BranchArray:
"""Return active branch records and an index filtered to the requested path nodes."""

active_branches = grid.branches.filter(from_status=1, to_status=1).filter(
from_node=from_node, to_node=to_node, mode_="AND"
)
if grid.three_winding_transformer.size:
three_winding_active = grid.three_winding_transformer.as_branches().filter(
from_status=1, to_status=1, from_node=from_node, to_node=to_node, mode_="AND"
)
if three_winding_active.size:
active_branches = fp.concatenate(active_branches, three_winding_active)

return active_branches


def find_differences_between_grids(
grid1: "Grid", grid2: "Grid", print_diff: bool = False
) -> dict[str, dict[str, object]]:
Expand Down
21 changes: 21 additions & 0 deletions src/power_grid_model_ds/_core/model/grids/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
"""Base grid classes"""

import warnings
from collections.abc import Iterator
from dataclasses import dataclass, fields
from pathlib import Path
from typing import Literal, Self, TypeVar, overload
Expand Down Expand Up @@ -45,6 +46,7 @@
get_downstream_nodes,
get_nearest_substation_node,
get_typed_branches,
iter_branches_in_shortest_path,
)
from power_grid_model_ds._core.model.grids.serialization.json import deserialize_from_json, serialize_to_json
from power_grid_model_ds._core.model.grids.serialization.pickle import load_grid_from_pickle, save_grid_to_pickle
Expand Down Expand Up @@ -386,6 +388,25 @@ def get_branches_in_path(self, nodes_in_path: list[int]) -> BranchArray:
"""
return self.branches.filter(from_node=nodes_in_path, to_node=nodes_in_path, from_status=1, to_status=1)

def iter_branches_in_shortest_path(
self, from_node_id: int, to_node_id: int
) -> Iterator[BranchArray | Branch3Array]:
"""Returns the ordered active branches that form the shortest path between two nodes. When parallel active edges
are in the path all these branches will be returned for the same from_node and to_node.

Args:
from_node_id (int): External id of the path start node.
to_node_id (int): External id of the path end node.

Yields:
BranchArray: branch arrays for each active branch on the path.

Raises:
MissingBranchError: If the graph reports an edge on the shortest path but no active branch is found.
"""

return iter_branches_in_shortest_path(self, from_node_id, to_node_id)

def get_nearest_substation_node(self, node_id: int):
"""Find the nearest substation node.

Expand Down
3 changes: 3 additions & 0 deletions src/power_grid_model_ds/arrays.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,18 +145,21 @@ class Branch3Array(IdArray, Branch3):
def as_branches(self) -> BranchArray:
"""Convert Branch3Array to BranchArray."""
branches_1_2 = BranchArray.empty(self.size)
branches_1_2.id = self.id
branches_1_2.from_node = self.node_1
branches_1_2.to_node = self.node_2
branches_1_2.from_status = self.status_1
branches_1_2.to_status = self.status_2

branches_1_3 = BranchArray.empty(self.size)
branches_1_3.id = self.id
branches_1_3.from_node = self.node_1
branches_1_3.to_node = self.node_3
branches_1_3.from_status = self.status_1
branches_1_3.to_status = self.status_3

branches_2_3 = BranchArray.empty(self.size)
branches_2_3.id = self.id
branches_2_3.from_node = self.node_2
branches_2_3.to_node = self.node_3
branches_2_3.from_status = self.status_2
Expand Down
13 changes: 13 additions & 0 deletions tests/unit/model/grids/test_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,19 @@ def test_get_branches_in_path_empty_path(self, basic_grid):
assert branches.size == 0


class TestIterBranchesInShortestPath:
def test_iter_branches_in_shortest_path(self, basic_grid):
branches = list(basic_grid.iter_branches_in_shortest_path(101, 106))
assert branches == [basic_grid.line.filter(id=201), basic_grid.transformer.filter(id=301)]

def test_iter_branches_same_node_returns_empty(self, basic_grid):
assert list(basic_grid.iter_branches_in_shortest_path(101, 101)) == []

def test_iter_branches_in_shortest_path_three_winding_transformer_typed(self, grid_with_3wt):
branches = list(grid_with_3wt.iter_branches_in_shortest_path(101, 104))
assert branches == [grid_with_3wt.three_winding_transformer.filter(id=301), grid_with_3wt.line.filter(id=201)]


Comment on lines +87 to +88
Copy link

Copilot AI Feb 15, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider adding test coverage for the MissingBranchError case that would be raised at line 142 of _search.py. This would happen when the graph reports a path but no active branch exists between consecutive nodes, which could occur if the graph state is inconsistent with the branch data. A test would help ensure this error is raised with a clear message in such edge cases.

Suggested change
def test_iter_branches_in_shortest_path_missing_branch_raises(self, basic_grid, monkeypatch):
# Simulate an inconsistent state where the graph reports a path but
# there is no active branch between consecutive nodes in that path.
def fake_get_shortest_path(*args, **kwargs):
# Return a path with a node pair that has no active branch
return [101, 999], 1
monkeypatch.object(basic_grid.graphs.active_graph, "get_shortest_path", fake_get_shortest_path)
with pytest.raises(Exception):
list(basic_grid.iter_branches_in_shortest_path(101, 999))

Copilot uses AI. Check for mistakes.
def test_component_three_winding_transformer(grid_with_3wt):
substation_nodes = grid_with_3wt.node.filter(node_type=NodeType.SUBSTATION_NODE.value).id
with grid_with_3wt.graphs.active_graph.tmp_remove_nodes(substation_nodes):
Expand Down
Loading