Skip to content

Commit 1519baa

Browse files
committed
ENH(net,#18): ignore UN-SATISFIABLE operations with partial inputs
Usefull when 2 (or more) operations provifing the same output, and only one has fully satisfied inputs. Before it would fail trying to evaluate the un-satisfied ones. + New TC added. .
1 parent 617e577 commit 1519baa

File tree

2 files changed

+86
-1
lines changed

2 files changed

+86
-1
lines changed

graphkit/network.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from io import StringIO
99

1010
from .base import Operation
11+
from .modifiers import optional
1112

1213

1314
class DataPlaceholderNode(str):
@@ -141,6 +142,67 @@ def compile(self):
141142
raise TypeError("Unrecognized network graph node")
142143

143144

145+
def _collect_satisfiable_needs(self, operation, inputs, satisfiables, visited):
146+
"""
147+
Recusrively check if operation inputs are given/calculated (satisfied), or not.
148+
149+
:param satisfiables:
150+
the set to populate with satisfiable operations
151+
152+
:param visited:
153+
a cache of operations & needs, not to visit them again
154+
:return:
155+
true if opearation is satisfiable
156+
"""
157+
assert isinstance(operation, Operation), (
158+
"Expected Operation, got:",
159+
type(operation),
160+
)
161+
162+
if operation in visited:
163+
return visited[operation]
164+
165+
166+
def is_need_satisfiable(need):
167+
if need in visited:
168+
return visited[need]
169+
170+
if need in inputs:
171+
satisfied = True
172+
else:
173+
need_providers = list(self.graph.predecessors(need))
174+
satisfied = bool(need_providers) and any(
175+
self._collect_satisfiable_needs(op, inputs, satisfiables, visited)
176+
for op in need_providers
177+
)
178+
visited[need] = satisfied
179+
180+
return satisfied
181+
182+
satisfied = all(
183+
is_need_satisfiable(need)
184+
for need in operation.needs
185+
if not isinstance(need, optional)
186+
)
187+
if satisfied:
188+
satisfiables.add(operation)
189+
visited[operation] = satisfied
190+
191+
return satisfied
192+
193+
194+
def _collect_satisfiable_operations(self, nodes, inputs):
195+
satisfiables = set()
196+
from unittest import mock
197+
198+
visited = {}
199+
for node in nodes:
200+
if node not in visited and isinstance(node, Operation):
201+
self._collect_satisfiable_needs(node, inputs, satisfiables, visited)
202+
203+
return satisfiables
204+
205+
144206
def _find_necessary_steps(self, outputs, inputs):
145207
"""
146208
Determines what graph steps need to pe run to get to the requested
@@ -204,6 +266,13 @@ def _find_necessary_steps(self, outputs, inputs):
204266
# Get rid of the unnecessary nodes from the set of necessary ones.
205267
necessary_nodes -= unnecessary_nodes
206268

269+
# Drop (un-satifiable) operations with partial inputs.
270+
# See https://github.com/yahoo/graphkit/pull/18
271+
#
272+
satisfiables = self._collect_satisfiable_operations(necessary_nodes, inputs)
273+
for node in list(necessary_nodes):
274+
if isinstance(node, Operation) and node not in satisfiables:
275+
necessary_nodes.remove(node)
207276

208277
necessary_steps = [step for step in self.steps if step in necessary_nodes]
209278

test/test_graphkit.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import pickle
66

77
from pprint import pprint
8-
from operator import add
8+
from operator import add, mul, floordiv
99
from numpy.testing import assert_raises
1010

1111
import graphkit.network as network
@@ -69,6 +69,22 @@ def pow_op1(a, exponent=2):
6969
# net.plot(show=True)
7070

7171

72+
def test_operations_with_partial_inputs_ignored():
73+
graph = compose(name="graph")(
74+
operation(name="mul", needs=["a", "b1"], provides=["ab"])(mul),
75+
operation(name="div", needs=["a", "b2"], provides=["ab"])(floordiv),
76+
operation(name="add", needs=["ab", "c"], provides=["ab_plus_c"])(add),
77+
)
78+
79+
exp = {"a": 10, "b1": 2, "c": 1, "ab": 20, "ab_plus_c": 21}
80+
assert graph({"a": 10, "b1": 2, "c": 1}) == exp
81+
assert graph({"a": 10, "b1": 2, "c": 1}, outputs=["ab_plus_c"]) == {"ab_plus_c": 21}
82+
83+
exp = {"a": 10, "b2": 2, "c": 1, "ab": 5, "ab_plus_c": 6}
84+
assert graph({"a": 10, "b2": 2, "c": 1}) == exp
85+
assert graph({"a": 10, "b2": 2, "c": 1}, outputs=["ab_plus_c"]) == {"ab_plus_c": 6}
86+
87+
7288
def test_network_simple_merge():
7389

7490
sum_op1 = operation(name='sum_op1', needs=['a', 'b'], provides='sum1')(add)

0 commit comments

Comments
 (0)