Skip to content

Commit 3ee344b

Browse files
committed
ENH(DAG): NEW SOLVER
+ Pruning behaves correctly also when outputs given; this happens by breaking incoming provide-links to any given intermedediate inputs. + Unsatisfied detection now includes those without outputs due to broken links (above). + Remove some uneeded "glue" from unsatisfied-detection code, leftover from previous compile() refactoring. + Renamed satisfiable --> satisfied. + Improved unknown output requested raise-message. + x2 TCs, in #24 and 1st in #25 now PASS. - 2x TCs in #25 still FAIL, and need "Pinning" of given-inputs (the operation MUST and MUST NOT run in these cases).
1 parent 2ce2a43 commit 3ee344b

File tree

2 files changed

+66
-60
lines changed

2 files changed

+66
-60
lines changed

graphkit/network.py

Lines changed: 64 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -187,45 +187,56 @@ def _build_execution_plan(self, dag):
187187

188188
return plan
189189

190-
def _collect_unsatisfiable_operations(self, necessary_nodes, inputs):
190+
def _collect_unsatisfied_operations(self, dag, inputs):
191191
"""
192-
Traverse ordered graph and mark satisfied needs on each operation,
192+
Traverse topologically sorted dag to collect un-satisfied operations.
193193
194-
collecting those missing at least one.
195-
Since the graph is ordered, as soon as we're on an operation,
196-
all its needs have been accounted, so we can get its satisfaction.
194+
Unsatisfied operations are those suffering from ANY of the following:
195+
196+
- They are missing at least one compulsory need-input.
197+
Since the dag is ordered, as soon as we're on an operation,
198+
all its needs have been accounted, so we can get its satisfaction.
199+
200+
- Their provided outputs are not linked to any data in the dag.
201+
An operation might not have any output link when :meth:`_solve_dag()`
202+
has broken them, due to given intermediate inputs.
197203
198-
:param necessary_nodes:
199-
the subset of the graph to consider but WITHOUT the initial data
200-
(because that is what :meth:`compile()` can gives us...)
204+
:param dag:
205+
the graph to consider
201206
:param inputs:
202207
an iterable of the names of the input values
203208
return:
204-
a list of unsatisfiable operations
209+
a list of unsatisfied operations to prune
205210
"""
206-
G = self.graph # shortcut
207-
ok_data = set(inputs) # to collect producible data
208-
op_satisfaction = defaultdict(set) # to collect operation satisfiable needs
209-
unsatisfiables = [] # to collect operations with partial needs
210-
# We also need inputs to mark op_satisfaction.
211-
nodes = chain(necessary_nodes, inputs) # note that `inputs` are plain strings
212-
for node in nx.topological_sort(G.subgraph(nodes)):
211+
# To collect data that will be produced.
212+
ok_data = set(inputs)
213+
# To colect the map of operations --> satisfied-needs.
214+
op_satisfaction = defaultdict(set)
215+
# To collect the operations to drop.
216+
unsatisfied = []
217+
for node in nx.topological_sort(dag):
213218
if isinstance(node, Operation):
214-
real_needs = set(n for n in node.needs if not isinstance(n, optional))
215-
if real_needs.issubset(op_satisfaction[node]):
216-
# mark all future data-provides as ok
217-
ok_data.update(G.adj[node])
219+
if not dag.adj[node]:
220+
# Prune operations ending up without any provided-outputs.
221+
unsatisfied.append(node)
218222
else:
219-
unsatisfiables.append(node)
223+
real_needs = set(n for n in node.needs if not isinstance(n, optional))
224+
if real_needs.issubset(op_satisfaction[node]):
225+
# We have a satisfied operation; mark its output-data
226+
# as ok.
227+
ok_data.update(dag.adj[node])
228+
else:
229+
# Prune operations with partial inputs.
230+
unsatisfied.append(node)
220231
elif isinstance(node, (DataPlaceholderNode, str)): # `str` are givens
221232
if node in ok_data:
222233
# mark satisfied-needs on all future operations
223-
for future_op in G.adj[node]:
234+
for future_op in dag.adj[node]:
224235
op_satisfaction[future_op].add(node)
225236
else:
226237
raise AssertionError("Unrecognized network graph node %r" % node)
227238

228-
return unsatisfiables
239+
return unsatisfied
229240

230241

231242
def _solve_dag(self, outputs, inputs):
@@ -246,50 +257,44 @@ def _solve_dag(self, outputs, inputs):
246257
247258
:return:
248259
the subgraph comprising the solution
249-
250260
"""
251-
graph = self.graph
252-
if not outputs:
261+
dag = self.graph
262+
263+
# Ignore input names that aren't in the graph.
264+
graph_inputs = iset(dag.nodes) & inputs # preserve order
253265

254-
# If caller requested all outputs, the necessary nodes are all
255-
# nodes that are reachable from one of the inputs. Ignore input
256-
# names that aren't in the graph.
257-
necessary_nodes = set() # unordered, not iterated
258-
for input_name in iter(inputs):
259-
if graph.has_node(input_name):
260-
necessary_nodes |= nx.descendants(graph, input_name)
266+
# Scream if some requested outputs aren't in the graph.
267+
unknown_outputs = iset(outputs) - dag.nodes
268+
if unknown_outputs:
269+
raise ValueError(
270+
"Unknown output node(s) requested: %s"
271+
% ", ".join(unknown_outputs))
261272

262-
else:
273+
dag = dag.copy() # preserve net's graph
274+
275+
# Break the incoming edges to all given inputs.
276+
#
277+
# Nodes producing any given intermediate inputs are unecessary
278+
# (unless they are also used elsewhere).
279+
# To discover which ones to prune, we break their incoming edges
280+
# and they will drop out while collecting ancestors from the outputs.
281+
for given in graph_inputs:
282+
dag.remove_edges_from(list(dag.in_edges(given)))
283+
284+
if outputs:
285+
# If caller requested specific outputs, we can prune any
286+
# unrelated nodes further up the dag.
287+
ending_in_outputs = set()
288+
for input_name in outputs:
289+
ending_in_outputs.update(nx.ancestors(dag, input_name))
290+
dag = dag.subgraph(ending_in_outputs | set(outputs))
263291

264-
# If the caller requested a subset of outputs, find any nodes that
265-
# are made unecessary because we were provided with an input that's
266-
# deeper into the network graph. Ignore input names that aren't
267-
# in the graph.
268-
unnecessary_nodes = set() # unordered, not iterated
269-
for input_name in iter(inputs):
270-
if graph.has_node(input_name):
271-
unnecessary_nodes |= nx.ancestors(graph, input_name)
272-
273-
# Find the nodes we need to be able to compute the requested
274-
# outputs. Raise an exception if a requested output doesn't
275-
# exist in the graph.
276-
necessary_nodes = set() # unordered, not iterated
277-
for output_name in outputs:
278-
if not graph.has_node(output_name):
279-
raise ValueError("graphkit graph does not have an output "
280-
"node named %s" % output_name)
281-
necessary_nodes |= nx.ancestors(graph, output_name)
282-
283-
# Get rid of the unnecessary nodes from the set of necessary ones.
284-
necessary_nodes -= unnecessary_nodes
285292

286293
# Drop (un-satifiable) operations with partial inputs.
287294
# See yahoo/graphkit#18
288295
#
289-
unsatisfiables = self._collect_unsatisfiable_operations(necessary_nodes, inputs)
290-
necessary_nodes -= set(unsatisfiables)
291-
292-
shrinked_dag = graph.subgraph(necessary_nodes)
296+
unsatisfied = self._collect_unsatisfied_operations(dag, inputs)
297+
shrinked_dag = dag.subgraph(dag.nodes - unsatisfied)
293298

294299
return shrinked_dag
295300

test/test_graphkit.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -265,12 +265,13 @@ def test_pruning_with_given_intermediate_and_asked_out():
265265
operation(name="good_op", needs=["a", "given-2"], provides=["asked"])(add),
266266
)
267267

268-
exp = {"given-1": 5, "b": 2, "given-2": 7, "a": 5, "asked": 12}
268+
exp = {"given-1": 5, "b": 2, "given-2": 2, "a": 5, "asked": 7}
269269
# v1.2.4 is ok
270270
assert netop({"given-1": 5, "b": 2, "given-2": 2}) == exp
271271
# FAILS
272272
# - on v1.2.4 with KeyError: 'a',
273273
# - on #18 (unsatisfied) with no result.
274+
# FIXED on #18+#26 (new dag solver).
274275
assert netop({"given-1": 5, "b": 2, "given-2": 2}, ["asked"]) == filtdict(exp, "asked")
275276

276277

0 commit comments

Comments
 (0)