|
8 | 8 | from io import StringIO |
9 | 9 |
|
10 | 10 | from .base import Operation |
| 11 | +from .modifiers import optional |
11 | 12 |
|
12 | 13 |
|
13 | 14 | class DataPlaceholderNode(str): |
@@ -141,6 +142,67 @@ def compile(self): |
141 | 142 | raise TypeError("Unrecognized network graph node") |
142 | 143 |
|
143 | 144 |
|
| 145 | + def _collect_satisfiable_needs(self, operation, inputs, satisfiables, visited): |
| 146 | + """ |
| 147 | + Recusrively check if operation inputs are given/calculated (satisfied), or not. |
| 148 | +
|
| 149 | + :param satisfiables: |
| 150 | + the set to populate with satisfiable operations |
| 151 | +
|
| 152 | + :param visited: |
| 153 | + a cache of operations & needs, not to visit them again |
| 154 | + :return: |
| 155 | + true if opearation is satisfiable |
| 156 | + """ |
| 157 | + assert isinstance(operation, Operation), ( |
| 158 | + "Expected Operation, got:", |
| 159 | + type(operation), |
| 160 | + ) |
| 161 | + |
| 162 | + if operation in visited: |
| 163 | + return visited[operation] |
| 164 | + |
| 165 | + |
| 166 | + def is_need_satisfiable(need): |
| 167 | + if need in visited: |
| 168 | + return visited[need] |
| 169 | + |
| 170 | + if need in inputs: |
| 171 | + satisfied = True |
| 172 | + else: |
| 173 | + need_providers = list(self.graph.predecessors(need)) |
| 174 | + satisfied = bool(need_providers) and any( |
| 175 | + self._collect_satisfiable_needs(op, inputs, satisfiables, visited) |
| 176 | + for op in need_providers |
| 177 | + ) |
| 178 | + visited[need] = satisfied |
| 179 | + |
| 180 | + return satisfied |
| 181 | + |
| 182 | + satisfied = all( |
| 183 | + is_need_satisfiable(need) |
| 184 | + for need in operation.needs |
| 185 | + if not isinstance(need, optional) |
| 186 | + ) |
| 187 | + if satisfied: |
| 188 | + satisfiables.add(operation) |
| 189 | + visited[operation] = satisfied |
| 190 | + |
| 191 | + return satisfied |
| 192 | + |
| 193 | + |
| 194 | + def _collect_satisfiable_operations(self, nodes, inputs): |
| 195 | + satisfiables = set() |
| 196 | + from unittest import mock |
| 197 | + |
| 198 | + visited = {} |
| 199 | + for node in nodes: |
| 200 | + if node not in visited and isinstance(node, Operation): |
| 201 | + self._collect_satisfiable_needs(node, inputs, satisfiables, visited) |
| 202 | + |
| 203 | + return satisfiables |
| 204 | + |
| 205 | + |
144 | 206 | def _find_necessary_steps(self, outputs, inputs): |
145 | 207 | """ |
146 | 208 | Determines what graph steps need to pe run to get to the requested |
@@ -204,6 +266,13 @@ def _find_necessary_steps(self, outputs, inputs): |
204 | 266 | # Get rid of the unnecessary nodes from the set of necessary ones. |
205 | 267 | necessary_nodes -= unnecessary_nodes |
206 | 268 |
|
| 269 | + # Drop (un-satifiable) operations with partial inputs. |
| 270 | + # See https://github.com/yahoo/graphkit/pull/18 |
| 271 | + # |
| 272 | + satisfiables = self._collect_satisfiable_operations(necessary_nodes, inputs) |
| 273 | + for node in list(necessary_nodes): |
| 274 | + if isinstance(node, Operation) and node not in satisfiables: |
| 275 | + necessary_nodes.remove(node) |
207 | 276 |
|
208 | 277 | necessary_steps = [step for step in self.steps if step in necessary_nodes] |
209 | 278 |
|
|
0 commit comments