Skip to content

Commit 7e851b1

Browse files
committed
ENH(DAG): NEW SOLVER
+ Pruning behaves correctly also when outputs given; this happens by breaking incoming provide-links to any given intermedediate inputs. + Unsatisfied detection now includes those without outputs due to broken links (above). + Remove some uneeded "glue" from unsatisfied-detection code, leftover from previous compile() refactoring. + Renamed satisfiable --> satisfied. + Improved unknown output requested raise-message. + x2 TCs, in #24 and 1st in #25 now PASS. - 2x TCs in #25 still FAIL, and need "Pinning" of given-inputs (the operation MUST and MUST NOT run in these cases).
1 parent 2ce2a43 commit 7e851b1

File tree

2 files changed

+78
-60
lines changed

2 files changed

+78
-60
lines changed

graphkit/network.py

Lines changed: 76 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,18 @@ def __repr__(self):
3434
return 'DeleteInstruction("%s")' % self
3535

3636

37+
class PinInstruction(str):
38+
"""
39+
An instruction in the *execution plan* not to store the newly compute value
40+
into network's values-cache but to pin it instead to some given value.
41+
It is used ensure that given intermediate values are not overwritten when
42+
their providing functions could not be avoided, because their other outputs
43+
are needed elesewhere.
44+
"""
45+
def __repr__(self):
46+
return 'PinInstruction("%s")' % self
47+
48+
3749
class Network(object):
3850
"""
3951
This is the main network implementation. The class contains all of the
@@ -187,45 +199,56 @@ def _build_execution_plan(self, dag):
187199

188200
return plan
189201

190-
def _collect_unsatisfiable_operations(self, necessary_nodes, inputs):
202+
def _collect_unsatisfied_operations(self, dag, inputs):
191203
"""
192-
Traverse ordered graph and mark satisfied needs on each operation,
204+
Traverse topologically sorted dag to collect un-satisfied operations.
193205
194-
collecting those missing at least one.
195-
Since the graph is ordered, as soon as we're on an operation,
196-
all its needs have been accounted, so we can get its satisfaction.
206+
Unsatisfied operations are those suffering from ANY of the following:
207+
208+
- They are missing at least one compulsory need-input.
209+
Since the dag is ordered, as soon as we're on an operation,
210+
all its needs have been accounted, so we can get its satisfaction.
211+
212+
- Their provided outputs are not linked to any data in the dag.
213+
An operation might not have any output link when :meth:`_solve_dag()`
214+
has broken them, due to given intermediate inputs.
197215
198-
:param necessary_nodes:
199-
the subset of the graph to consider but WITHOUT the initial data
200-
(because that is what :meth:`compile()` can gives us...)
216+
:param dag:
217+
the graph to consider
201218
:param inputs:
202219
an iterable of the names of the input values
203220
return:
204-
a list of unsatisfiable operations
221+
a list of unsatisfied operations to prune
205222
"""
206-
G = self.graph # shortcut
207-
ok_data = set(inputs) # to collect producible data
208-
op_satisfaction = defaultdict(set) # to collect operation satisfiable needs
209-
unsatisfiables = [] # to collect operations with partial needs
210-
# We also need inputs to mark op_satisfaction.
211-
nodes = chain(necessary_nodes, inputs) # note that `inputs` are plain strings
212-
for node in nx.topological_sort(G.subgraph(nodes)):
223+
# To collect data that will be produced.
224+
ok_data = set(inputs)
225+
# To colect the map of operations --> satisfied-needs.
226+
op_satisfaction = defaultdict(set)
227+
# To collect the operations to drop.
228+
unsatisfied = []
229+
for node in nx.topological_sort(dag):
213230
if isinstance(node, Operation):
214-
real_needs = set(n for n in node.needs if not isinstance(n, optional))
215-
if real_needs.issubset(op_satisfaction[node]):
216-
# mark all future data-provides as ok
217-
ok_data.update(G.adj[node])
231+
if not dag.adj[node]:
232+
# Prune operations ending up without any provided-outputs.
233+
unsatisfied.append(node)
218234
else:
219-
unsatisfiables.append(node)
235+
real_needs = set(n for n in node.needs if not isinstance(n, optional))
236+
if real_needs.issubset(op_satisfaction[node]):
237+
# We have a satisfied operation; mark its output-data
238+
# as ok.
239+
ok_data.update(dag.adj[node])
240+
else:
241+
# Prune operations with partial inputs.
242+
unsatisfied.append(node)
220243
elif isinstance(node, (DataPlaceholderNode, str)): # `str` are givens
221244
if node in ok_data:
222245
# mark satisfied-needs on all future operations
223-
for future_op in G.adj[node]:
246+
for future_op in dag.adj[node]:
224247
op_satisfaction[future_op].add(node)
225248
else:
226249
raise AssertionError("Unrecognized network graph node %r" % node)
227250

228-
return unsatisfiables
251+
return unsatisfied
229252

230253

231254
def _solve_dag(self, outputs, inputs):
@@ -246,50 +269,44 @@ def _solve_dag(self, outputs, inputs):
246269
247270
:return:
248271
the subgraph comprising the solution
249-
250272
"""
251-
graph = self.graph
252-
if not outputs:
273+
dag = self.graph
274+
275+
# Ignore input names that aren't in the graph.
276+
graph_inputs = iset(dag.nodes) & inputs # preserve order
253277

254-
# If caller requested all outputs, the necessary nodes are all
255-
# nodes that are reachable from one of the inputs. Ignore input
256-
# names that aren't in the graph.
257-
necessary_nodes = set() # unordered, not iterated
258-
for input_name in iter(inputs):
259-
if graph.has_node(input_name):
260-
necessary_nodes |= nx.descendants(graph, input_name)
278+
# Scream if some requested outputs aren't in the graph.
279+
unknown_outputs = iset(outputs) - dag.nodes
280+
if unknown_outputs:
281+
raise ValueError(
282+
"Unknown output node(s) requested: %s"
283+
% ", ".join(unknown_outputs))
261284

262-
else:
285+
dag = dag.copy() # preserve net's graph
286+
287+
# Break the incoming edges to all given inputs.
288+
#
289+
# Nodes producing any given intermediate inputs are unecessary
290+
# (unless they are also used elsewhere).
291+
# To discover which ones to prune, we break their incoming edges
292+
# and they will drop out while collecting ancestors from the outputs.
293+
for given in graph_inputs:
294+
dag.remove_edges_from(list(dag.in_edges(given)))
295+
296+
if outputs:
297+
# If caller requested specific outputs, we can prune any
298+
# unrelated nodes further up the dag.
299+
ending_in_outputs = set()
300+
for input_name in outputs:
301+
ending_in_outputs.update(nx.ancestors(dag, input_name))
302+
dag = dag.subgraph(ending_in_outputs | set(outputs))
263303

264-
# If the caller requested a subset of outputs, find any nodes that
265-
# are made unecessary because we were provided with an input that's
266-
# deeper into the network graph. Ignore input names that aren't
267-
# in the graph.
268-
unnecessary_nodes = set() # unordered, not iterated
269-
for input_name in iter(inputs):
270-
if graph.has_node(input_name):
271-
unnecessary_nodes |= nx.ancestors(graph, input_name)
272-
273-
# Find the nodes we need to be able to compute the requested
274-
# outputs. Raise an exception if a requested output doesn't
275-
# exist in the graph.
276-
necessary_nodes = set() # unordered, not iterated
277-
for output_name in outputs:
278-
if not graph.has_node(output_name):
279-
raise ValueError("graphkit graph does not have an output "
280-
"node named %s" % output_name)
281-
necessary_nodes |= nx.ancestors(graph, output_name)
282-
283-
# Get rid of the unnecessary nodes from the set of necessary ones.
284-
necessary_nodes -= unnecessary_nodes
285304

286305
# Drop (un-satifiable) operations with partial inputs.
287306
# See yahoo/graphkit#18
288307
#
289-
unsatisfiables = self._collect_unsatisfiable_operations(necessary_nodes, inputs)
290-
necessary_nodes -= set(unsatisfiables)
291-
292-
shrinked_dag = graph.subgraph(necessary_nodes)
308+
unsatisfied = self._collect_unsatisfied_operations(dag, inputs)
309+
shrinked_dag = dag.subgraph(dag.nodes - unsatisfied)
293310

294311
return shrinked_dag
295312

test/test_graphkit.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -265,12 +265,13 @@ def test_pruning_with_given_intermediate_and_asked_out():
265265
operation(name="good_op", needs=["a", "given-2"], provides=["asked"])(add),
266266
)
267267

268-
exp = {"given-1": 5, "b": 2, "given-2": 7, "a": 5, "asked": 12}
268+
exp = {"given-1": 5, "b": 2, "given-2": 2, "a": 5, "asked": 7}
269269
# v1.2.4 is ok
270270
assert netop({"given-1": 5, "b": 2, "given-2": 2}) == exp
271271
# FAILS
272272
# - on v1.2.4 with KeyError: 'a',
273273
# - on #18 (unsatisfied) with no result.
274+
# FIXED on #18+#26 (new dag solver).
274275
assert netop({"given-1": 5, "b": 2, "given-2": 2}, ["asked"]) == filtdict(exp, "asked")
275276

276277

0 commit comments

Comments
 (0)