@@ -187,45 +187,56 @@ def _build_execution_plan(self, dag):
187187
188188 return plan
189189
190- def _collect_unsatisfiable_operations (self , necessary_nodes , inputs ):
190+ def _collect_unsatisfied_operations (self , dag , inputs ):
191191 """
192- Traverse ordered graph and mark satisfied needs on each operation,
192+ Traverse topologically sorted dag to collect unsatisfied operations.
193193
194- collecting those missing at least one.
195- Since the graph is ordered, as soon as we're on an operation,
196- all its needs have been accounted, so we can get its satisfaction.
194+ Unsatisfied operations are those suffering from ANY of the following:
195+
196+ - They are missing at least one compulsory need-input.
197+ Since the dag is ordered, as soon as we're on an operation,
198+ all its needs have been accounted, so we can get its satisfaction.
199+
200+ - Their provided outputs are not linked to any data in the dag.
201+ An operation might not have any output link when :meth:`_solve_dag()`
202+ has broken them, due to given intermediate inputs.
197203
198- :param necessary_nodes:
199- the subset of the graph to consider but WITHOUT the initial data
200- (because that is what :meth:`compile()` can gives us...)
204+ :param dag:
205+ the graph to consider
201206 :param inputs:
202207 an iterable of the names of the input values
203208 :return:
204- a list of unsatisfiable operations
209+ a list of unsatisfied operations to prune
205210 """
206- G = self . graph # shortcut
207- ok_data = set (inputs ) # to collect producible data
208- op_satisfaction = defaultdict ( set ) # to collect operation satisfiable needs
209- unsatisfiables = [] # to collect operations with partial needs
210- # We also need inputs to mark op_satisfaction .
211- nodes = chain ( necessary_nodes , inputs ) # note that `inputs` are plain strings
212- for node in nx .topological_sort (G . subgraph ( nodes ) ):
211+ # To collect data that will be produced.
212+ ok_data = set (inputs )
213+ # To collect the map of operations --> satisfied needs.
214+ op_satisfaction = defaultdict ( set )
215+ # To collect the operations to drop .
216+ unsatisfied = []
217+ for node in nx .topological_sort (dag ):
213218 if isinstance (node , Operation ):
214- real_needs = set (n for n in node .needs if not isinstance (n , optional ))
215- if real_needs .issubset (op_satisfaction [node ]):
216- # mark all future data-provides as ok
217- ok_data .update (G .adj [node ])
219+ if not dag .adj [node ]:
220+ # Prune operations ending up without any provided-outputs.
221+ unsatisfied .append (node )
218222 else :
219- unsatisfiables .append (node )
223+ real_needs = set (n for n in node .needs if not isinstance (n , optional ))
224+ if real_needs .issubset (op_satisfaction [node ]):
225+ # We have a satisfied operation; mark its output-data
226+ # as ok.
227+ ok_data .update (dag .adj [node ])
228+ else :
229+ # Prune operations with partial inputs.
230+ unsatisfied .append (node )
220231 elif isinstance (node , (DataPlaceholderNode , str )): # `str` are givens
221232 if node in ok_data :
222233 # mark satisfied-needs on all future operations
223- for future_op in G .adj [node ]:
234+ for future_op in dag .adj [node ]:
224235 op_satisfaction [future_op ].add (node )
225236 else :
226237 raise AssertionError ("Unrecognized network graph node %r" % node )
227238
228- return unsatisfiables
239+ return unsatisfied
229240
230241
231242 def _solve_dag (self , outputs , inputs ):
@@ -246,50 +257,44 @@ def _solve_dag(self, outputs, inputs):
246257
247258 :return:
248259 the subgraph comprising the solution
249-
250260 """
251- graph = self .graph
252- if not outputs :
261+ dag = self .graph
262+
263+ # Ignore input names that aren't in the graph.
264+ graph_inputs = iset (dag .nodes ) & inputs # preserve order
253265
254- # If caller requested all outputs, the necessary nodes are all
255- # nodes that are reachable from one of the inputs. Ignore input
256- # names that aren't in the graph.
257- necessary_nodes = set () # unordered, not iterated
258- for input_name in iter (inputs ):
259- if graph .has_node (input_name ):
260- necessary_nodes |= nx .descendants (graph , input_name )
266+ # Scream if some requested outputs aren't in the graph.
267+ unknown_outputs = iset (outputs ) - dag .nodes
268+ if unknown_outputs :
269+ raise ValueError (
270+ "Unknown output node(s) requested: %s"
271+ % ", " .join (unknown_outputs ))
261272
262- else :
273+ dag = dag .copy () # preserve net's graph
274+
275+ # Break the incoming edges to all given inputs.
276+ #
277+ # Nodes producing any given intermediate inputs are unnecessary
278+ # (unless they are also used elsewhere).
279+ # To discover which ones to prune, we break their incoming edges
280+ # and they will drop out while collecting ancestors from the outputs.
281+ for given in graph_inputs :
282+ dag .remove_edges_from (list (dag .in_edges (given )))
283+
284+ if outputs :
285+ # If caller requested specific outputs, we can prune any
286+ # unrelated nodes further up the dag.
287+ ending_in_outputs = set ()
288+ for input_name in outputs :
289+ ending_in_outputs .update (nx .ancestors (dag , input_name ))
290+ dag = dag .subgraph (ending_in_outputs | set (outputs ))
263291
264- # If the caller requested a subset of outputs, find any nodes that
265- # are made unecessary because we were provided with an input that's
266- # deeper into the network graph. Ignore input names that aren't
267- # in the graph.
268- unnecessary_nodes = set () # unordered, not iterated
269- for input_name in iter (inputs ):
270- if graph .has_node (input_name ):
271- unnecessary_nodes |= nx .ancestors (graph , input_name )
272-
273- # Find the nodes we need to be able to compute the requested
274- # outputs. Raise an exception if a requested output doesn't
275- # exist in the graph.
276- necessary_nodes = set () # unordered, not iterated
277- for output_name in outputs :
278- if not graph .has_node (output_name ):
279- raise ValueError ("graphkit graph does not have an output "
280- "node named %s" % output_name )
281- necessary_nodes |= nx .ancestors (graph , output_name )
282-
283- # Get rid of the unnecessary nodes from the set of necessary ones.
284- necessary_nodes -= unnecessary_nodes
285292
286293 # Drop (unsatisfiable) operations with partial inputs.
287294 # See yahoo/graphkit#18
288295 #
289- unsatisfiables = self ._collect_unsatisfiable_operations (necessary_nodes , inputs )
290- necessary_nodes -= set (unsatisfiables )
291-
292- shrinked_dag = graph .subgraph (necessary_nodes )
296+ unsatisfied = self ._collect_unsatisfied_operations (dag , inputs )
297+ shrinked_dag = dag .subgraph (dag .nodes - unsatisfied )
293298
294299 return shrinked_dag
295300
0 commit comments