Skip to content

Commit 1967995

Browse files
committed
ENH(net,yahoo#18): ignore UN-SATISFIABLE operations with partial inputs
+ The x2 TCs added just before are now passing.
1 parent f316494 commit 1967995

File tree

1 file changed

+67
-0
lines changed

1 file changed

+67
-0
lines changed

graphkit/network.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from io import StringIO
99

1010
from .base import Operation
11+
from .modifiers import optional
1112

1213

1314
class DataPlaceholderNode(str):
@@ -141,6 +142,65 @@ def compile(self):
141142
raise TypeError("Unrecognized network graph node")
142143

143144

145+
def _collect_satisfiable_needs(self, operation, inputs, satisfiables, visited):
146+
"""
147+
Recusrively check if operation inputs are given/calculated (satisfied), or not.
148+
149+
:param satisfiables:
150+
the set to populate with satisfiable operations
151+
152+
:param visited:
153+
a cache of operations & needs, not to visit them again
154+
:return:
155+
true if opearation is satisfiable
156+
"""
157+
assert isinstance(operation, Operation), (
158+
"Expected Operation, got:",
159+
type(operation),
160+
)
161+
162+
if operation in visited:
163+
return visited[operation]
164+
165+
166+
def is_need_satisfiable(need):
167+
if need in visited:
168+
return visited[need]
169+
170+
if need in inputs:
171+
satisfied = True
172+
else:
173+
need_providers = list(self.graph.predecessors(need))
174+
satisfied = bool(need_providers) and any(
175+
self._collect_satisfiable_needs(op, inputs, satisfiables, visited)
176+
for op in need_providers
177+
)
178+
visited[need] = satisfied
179+
180+
return satisfied
181+
182+
satisfied = all(
183+
is_need_satisfiable(need)
184+
for need in operation.needs
185+
if not isinstance(need, optional)
186+
)
187+
if satisfied:
188+
satisfiables.add(operation)
189+
visited[operation] = satisfied
190+
191+
return satisfied
192+
193+
194+
def _collect_satisfiable_operations(self, nodes, inputs):
195+
satisfiables = set()
196+
visited = {}
197+
for node in nodes:
198+
if node not in visited and isinstance(node, Operation):
199+
self._collect_satisfiable_needs(node, inputs, satisfiables, visited)
200+
201+
return satisfiables
202+
203+
144204
def _find_necessary_steps(self, outputs, inputs):
145205
"""
146206
Determines what graph steps need to pe run to get to the requested
@@ -204,6 +264,13 @@ def _find_necessary_steps(self, outputs, inputs):
204264
# Get rid of the unnecessary nodes from the set of necessary ones.
205265
necessary_nodes -= unnecessary_nodes
206266

267+
# Drop (un-satifiable) operations with partial inputs.
268+
# See https://github.com/yahoo/graphkit/pull/18
269+
#
270+
satisfiables = self._collect_satisfiable_operations(necessary_nodes, inputs)
271+
for node in list(necessary_nodes):
272+
if isinstance(node, Operation) and node not in satisfiables:
273+
necessary_nodes.remove(node)
207274

208275
necessary_steps = [step for step in self.steps if step in necessary_nodes]
209276

0 commit comments

Comments
 (0)