|
| 1 | +""" generic A-Star path searching algorithm (https://github.com/jrialland/python-astar/blob/master/tests/basic/test_basic.py) """ |
| 2 | +# @TODO make it so we return the best path as discussed in class |
| 3 | +# In class, we discussed that we should have a terminal function |
| 4 | +# which approximates the cost_to_go and returns the path to the node |
| 5 | +# with the least cost_to_go + terminal_cost |
| 6 | + |
| 7 | +from abc import ABC, abstractmethod |
| 8 | +from typing import Callable, Dict, Iterable, Union, TypeVar, Generic |
| 9 | +from math import inf as infinity |
| 10 | +from operator import attrgetter |
| 11 | +import heapq |
| 12 | + |
| 13 | +# introduce generic type |
| 14 | +T = TypeVar("T") |
| 15 | + |
| 16 | + |
| 17 | +################################################################################ |
| 18 | +class SearchNode(Generic[T]): |
| 19 | + """Representation of a search node""" |
| 20 | + |
| 21 | + __slots__ = ("data", "gscore", "fscore", "tscore", "closed", "came_from", "in_openset", "cache") |
| 22 | + |
| 23 | + def __init__( |
| 24 | + self, data: T, gscore: float = infinity, fscore: float = infinity, tscore:float = infinity |
| 25 | + ) -> None: |
| 26 | + self.data = data |
| 27 | + self.gscore = gscore |
| 28 | + self.fscore = fscore |
| 29 | + self.tscore = tscore |
| 30 | + self.closed = False |
| 31 | + self.in_openset = False |
| 32 | + self.came_from: Union[None, SearchNode[T]] = None |
| 33 | + self.cache: Any = None |
| 34 | + |
| 35 | + def __lt__(self, b: "SearchNode[T]") -> bool: |
| 36 | + """Natural order is based on the fscore value & is used by heapq operations""" |
| 37 | + return self.fscore < b.fscore |
| 38 | + |
| 39 | + |
| 40 | +################################################################################ |
| 41 | +class SearchNodeDict(Dict[T, SearchNode[T]]): |
| 42 | + """A dict that returns a new SearchNode when a key is missing""" |
| 43 | + |
| 44 | + def __missing__(self, k) -> SearchNode[T]: |
| 45 | + v = SearchNode(k) |
| 46 | + self.__setitem__(k, v) |
| 47 | + return v |
| 48 | + |
| 49 | + |
| 50 | +################################################################################ |
| 51 | +SNType = TypeVar("SNType", bound=SearchNode) |
| 52 | + |
| 53 | + |
| 54 | +class OpenSet(Generic[SNType]): |
| 55 | + def __init__(self) -> None: |
| 56 | + self.heap: list[SNType] = [] |
| 57 | + |
| 58 | + def push(self, item: SNType) -> None: |
| 59 | + item.in_openset = True |
| 60 | + heapq.heappush(self.heap, item) |
| 61 | + |
| 62 | + def pop(self) -> SNType: |
| 63 | + item = heapq.heappop(self.heap) |
| 64 | + item.in_openset = False |
| 65 | + return item |
| 66 | + |
| 67 | + def remove(self, item: SNType) -> None: |
| 68 | + idx = self.heap.index(item) |
| 69 | + item.in_openset = False |
| 70 | + item = self.heap.pop() |
| 71 | + if idx < len(self.heap): |
| 72 | + self.heap[idx] = item |
| 73 | + # Fix heap invariants |
| 74 | + heapq._siftup(self.heap, idx) |
| 75 | + heapq._siftdown(self.heap, 0, idx) |
| 76 | + |
| 77 | + def __len__(self) -> int: |
| 78 | + return len(self.heap) |
| 79 | + |
| 80 | + |
| 81 | +################################################################################* |
| 82 | + |
| 83 | + |
| 84 | +class AStar(ABC, Generic[T]): |
| 85 | + __slots__ = () |
| 86 | + |
| 87 | + @abstractmethod |
| 88 | + def heuristic_cost_estimate(self, current: T, goal: T) -> float: |
| 89 | + """ |
| 90 | + Computes the estimated (rough) distance between a node and the goal. |
| 91 | + The second parameter is always the goal. |
| 92 | +
|
| 93 | + This method must be implemented in a subclass. |
| 94 | + """ |
| 95 | + raise NotImplementedError |
| 96 | + |
| 97 | + @abstractmethod |
| 98 | + def terminal_cost_estimate(self, current: T, goal: T) -> float: |
| 99 | + """Computes the estimated distance between a node and the goal. |
| 100 | + This function is called after all iterations of A* have been run |
| 101 | + and is used to determine the closest node to the goal found so far. |
| 102 | +
|
| 103 | + This method must be implemented in a subclass. |
| 104 | +
|
| 105 | + Args: |
| 106 | + current (T): Current T |
| 107 | + goal (T): goal T |
| 108 | +
|
| 109 | + Returns: |
| 110 | + float: _description_ |
| 111 | + """ |
| 112 | + raise NotImplementedError |
| 113 | + |
| 114 | + def distance_between(self, n1: T, n2: T) -> float: |
| 115 | + """ |
| 116 | + Gives the real distance between two adjacent nodes n1 and n2 (i.e n2 |
| 117 | + belongs to the list of n1's neighbors). |
| 118 | + n2 is guaranteed to belong to the list returned by the call to neighbors(n1). |
| 119 | +
|
| 120 | + This method (or "path_distance_between") must be implemented in a subclass. |
| 121 | + """ |
| 122 | + raise NotImplementedError |
| 123 | + |
| 124 | + def path_distance_between(self, n1: SearchNode[T], n2: SearchNode[T]) -> float: |
| 125 | + """ |
| 126 | + Gives the real distance between the node n1 and its neighbor n2. |
| 127 | + n2 is guaranteed to belong to the list returned by the call to |
| 128 | + path_neighbors(n1). |
| 129 | +
|
| 130 | + Calls "distance_between"`by default. |
| 131 | + """ |
| 132 | + return self.distance_between(n1.data, n2.data) |
| 133 | + |
| 134 | + def neighbors(self, node: T) -> Iterable[T]: |
| 135 | + """ |
| 136 | + For a given node, returns (or yields) the list of its neighbors. |
| 137 | +
|
| 138 | + This method (or "path_neighbors") must be implemented in a subclass. |
| 139 | + """ |
| 140 | + raise NotImplementedError |
| 141 | + |
| 142 | + def path_neighbors(self, node: SearchNode[T]) -> Iterable[T]: |
| 143 | + """ |
| 144 | + For a given node, returns (or yields) the list of its reachable neighbors. |
| 145 | + Calls "neighbors" by default. |
| 146 | + """ |
| 147 | + return self.neighbors(node.data) |
| 148 | + |
| 149 | + def _neighbors(self, current: SearchNode[T], search_nodes: SearchNodeDict[T]) -> Iterable[SearchNode]: |
| 150 | + return (search_nodes[n] for n in self.path_neighbors(current)) |
| 151 | + |
| 152 | + def is_goal_reached(self, current: T, goal: T) -> bool: |
| 153 | + """ |
| 154 | + Returns true when we can consider that 'current' is the goal. |
| 155 | + The default implementation simply compares `current == goal`, but this |
| 156 | + method can be overwritten in a subclass to provide more refined checks. |
| 157 | + """ |
| 158 | + return current == goal |
| 159 | + |
| 160 | + def reconstruct_path(self, last: SearchNode, reversePath=False) -> Iterable[T]: |
| 161 | + def _gen(): |
| 162 | + current = last |
| 163 | + while current: |
| 164 | + yield current.data |
| 165 | + current = current.came_from |
| 166 | + |
| 167 | + if reversePath: |
| 168 | + return _gen() |
| 169 | + else: |
| 170 | + return reversed(list(_gen())) |
| 171 | + |
| 172 | + def astar( |
| 173 | + self, start: T, goal: T, reversePath: bool = False, iterations: int = 5000 |
| 174 | + ) -> Union[Iterable[T], None]: |
| 175 | + if self.is_goal_reached(start, goal): |
| 176 | + return [start] |
| 177 | + |
| 178 | + openSet: OpenSet[SearchNode[T]] = OpenSet() |
| 179 | + searchNodes: SearchNodeDict[T] = SearchNodeDict() |
| 180 | + startNode = searchNodes[start] = SearchNode( |
| 181 | + start, gscore=0.0, fscore=self.heuristic_cost_estimate(start, goal) |
| 182 | + ) |
| 183 | + openSet.push(startNode) |
| 184 | + bestNode = startNode |
| 185 | + |
| 186 | + iteration = 0 |
| 187 | + |
| 188 | + while openSet and iteration < iterations: |
| 189 | + current = openSet.pop() |
| 190 | + |
| 191 | + if self.is_goal_reached(current.data, goal): |
| 192 | + return self.reconstruct_path(current, reversePath) |
| 193 | + |
| 194 | + current.closed = True |
| 195 | + |
| 196 | + for neighbor in self._neighbors(current, searchNodes): |
| 197 | + if neighbor.closed: |
| 198 | + continue |
| 199 | + |
| 200 | + gscore = current.gscore + self.path_distance_between(current, neighbor) |
| 201 | + |
| 202 | + if gscore >= neighbor.gscore: |
| 203 | + continue |
| 204 | + |
| 205 | + fscore = gscore + self.heuristic_cost_estimate( |
| 206 | + neighbor.data, goal |
| 207 | + ) |
| 208 | + tscore = self.terminal_cost_estimate( |
| 209 | + neighbor.data, goal |
| 210 | + ) |
| 211 | + |
| 212 | + # print(f"Checking node: {neighbor.data} with tscore {tscore}") |
| 213 | + if tscore < bestNode.tscore: |
| 214 | + # print(f"Found a better node: {neighbor.data} with tscore {tscore}") |
| 215 | + bestNode = neighbor |
| 216 | + |
| 217 | + if neighbor.in_openset: |
| 218 | + if neighbor.fscore < fscore: |
| 219 | + # the new path to this node isn't better |
| 220 | + continue |
| 221 | + |
| 222 | + # we have to remove the item from the heap, as its score has changed |
| 223 | + openSet.remove(neighbor) |
| 224 | + |
| 225 | + # update the node |
| 226 | + neighbor.came_from = current |
| 227 | + neighbor.gscore = gscore |
| 228 | + neighbor.fscore = fscore |
| 229 | + neighbor.tscore = tscore |
| 230 | + |
| 231 | + openSet.push(neighbor) |
| 232 | + |
| 233 | + iteration += 1 |
| 234 | + |
| 235 | + # print("Warning: A* search failed to find a path") |
| 236 | + return self.reconstruct_path(bestNode, reversePath) |
| 237 | + |
| 238 | + |
| 239 | +################################################################################ |
| 240 | +U = TypeVar("U") |
| 241 | + |
| 242 | + |
| 243 | +def find_path( |
| 244 | + start: U, |
| 245 | + goal: U, |
| 246 | + neighbors_fnct: Callable[[U], Iterable[U]], |
| 247 | + reversePath=False, |
| 248 | + heuristic_cost_estimate_fnct: Callable[[U, U], float] = lambda a, b: infinity, |
| 249 | + distance_between_fnct: Callable[[U, U], float] = lambda a, b: 1.0, |
| 250 | + is_goal_reached_fnct: Callable[[U, U], bool] = lambda a, b: a == b, |
| 251 | +) -> Union[Iterable[U], None]: |
| 252 | + """A non-class version of the path finding algorithm""" |
| 253 | + |
| 254 | + class FindPath(AStar): |
| 255 | + def heuristic_cost_estimate(self, current: U, goal: U) -> float: |
| 256 | + return heuristic_cost_estimate_fnct(current, goal) # type: ignore |
| 257 | + |
| 258 | + def distance_between(self, n1: U, n2: U) -> float: |
| 259 | + return distance_between_fnct(n1, n2) |
| 260 | + |
| 261 | + def neighbors(self, node) -> Iterable[U]: |
| 262 | + return neighbors_fnct(node) # type: ignore |
| 263 | + |
| 264 | + def is_goal_reached(self, current: U, goal: U) -> bool: |
| 265 | + return is_goal_reached_fnct(current, goal) |
| 266 | + |
| 267 | + return FindPath().astar(start, goal, reversePath) |
0 commit comments