@@ -34,6 +34,18 @@ def __repr__(self):
3434 return 'DeleteInstruction("%s")' % self
3535
3636
class PinInstruction(str):
    """
    Execution-plan marker telling the network to *pin* a value.

    When this instruction is met in the *execution plan*, the newly computed
    value must not be stored into the network's values-cache; the value is
    pinned to some given value instead.

    It is used to ensure that given intermediate values are not overwritten
    when the functions providing them could not be avoided, because their
    other outputs are needed elsewhere.
    """

    def __repr__(self):
        # Render the wrapped data-name inside double quotes.
        return 'PinInstruction("{0}")'.format(self)
47+
48+
3749class Network (object ):
3850 """
3951 This is the main network implementation. The class contains all of the
@@ -187,45 +199,56 @@ def _build_execution_plan(self, dag):
187199
188200 return plan
189201
190- def _collect_unsatisfiable_operations (self , necessary_nodes , inputs ):
202+ def _collect_unsatisfied_operations (self , dag , inputs ):
191203 """
192- Traverse ordered graph and mark satisfied needs on each operation,
204+ Traverse topologically sorted dag to collect un- satisfied operations.
193205
194- collecting those missing at least one.
195- Since the graph is ordered, as soon as we're on an operation,
196- all its needs have been accounted, so we can get its satisfaction.
206+ Unsatisfied operations are those suffering from ANY of the following:
207+
208+ - They are missing at least one compulsory need-input.
209+ Since the dag is ordered, as soon as we're on an operation,
210+ all its needs have been accounted, so we can get its satisfaction.
211+
212+ - Their provided outputs are not linked to any data in the dag.
213+ An operation might not have any output link when :meth:`_solve_dag()`
214+ has broken them, due to given intermediate inputs.
197215
198- :param necessary_nodes:
199- the subset of the graph to consider but WITHOUT the initial data
200- (because that is what :meth:`compile()` can gives us...)
216+ :param dag:
217+ the graph to consider
201218 :param inputs:
202219 an iterable of the names of the input values
203220 return:
204- a list of unsatisfiable operations
221+ a list of unsatisfied operations to prune
205222 """
206- G = self . graph # shortcut
207- ok_data = set (inputs ) # to collect producible data
208- op_satisfaction = defaultdict ( set ) # to collect operation satisfiable needs
209- unsatisfiables = [] # to collect operations with partial needs
210- # We also need inputs to mark op_satisfaction .
211- nodes = chain ( necessary_nodes , inputs ) # note that `inputs` are plain strings
212- for node in nx .topological_sort (G . subgraph ( nodes ) ):
223+ # To collect data that will be produced.
224+ ok_data = set (inputs )
225+ # To colect the map of operations --> satisfied- needs.
226+ op_satisfaction = defaultdict ( set )
227+ # To collect the operations to drop .
228+ unsatisfied = []
229+ for node in nx .topological_sort (dag ):
213230 if isinstance (node , Operation ):
214- real_needs = set (n for n in node .needs if not isinstance (n , optional ))
215- if real_needs .issubset (op_satisfaction [node ]):
216- # mark all future data-provides as ok
217- ok_data .update (G .adj [node ])
231+ if not dag .adj [node ]:
232+ # Prune operations ending up without any provided-outputs.
233+ unsatisfied .append (node )
218234 else :
219- unsatisfiables .append (node )
235+ real_needs = set (n for n in node .needs if not isinstance (n , optional ))
236+ if real_needs .issubset (op_satisfaction [node ]):
237+ # We have a satisfied operation; mark its output-data
238+ # as ok.
239+ ok_data .update (dag .adj [node ])
240+ else :
241+ # Prune operations with partial inputs.
242+ unsatisfied .append (node )
220243 elif isinstance (node , (DataPlaceholderNode , str )): # `str` are givens
221244 if node in ok_data :
222245 # mark satisfied-needs on all future operations
223- for future_op in G .adj [node ]:
246+ for future_op in dag .adj [node ]:
224247 op_satisfaction [future_op ].add (node )
225248 else :
226249 raise AssertionError ("Unrecognized network graph node %r" % node )
227250
228- return unsatisfiables
251+ return unsatisfied
229252
230253
231254 def _solve_dag (self , outputs , inputs ):
@@ -246,50 +269,44 @@ def _solve_dag(self, outputs, inputs):
246269
247270 :return:
248271 the subgraph comprising the solution
249-
250272 """
251- graph = self .graph
252- if not outputs :
273+ dag = self .graph
274+
275+ # Ignore input names that aren't in the graph.
276+ graph_inputs = iset (dag .nodes ) & inputs # preserve order
253277
254- # If caller requested all outputs, the necessary nodes are all
255- # nodes that are reachable from one of the inputs. Ignore input
256- # names that aren't in the graph.
257- necessary_nodes = set () # unordered, not iterated
258- for input_name in iter (inputs ):
259- if graph .has_node (input_name ):
260- necessary_nodes |= nx .descendants (graph , input_name )
278+ # Scream if some requested outputs aren't in the graph.
279+ unknown_outputs = iset (outputs ) - dag .nodes
280+ if unknown_outputs :
281+ raise ValueError (
282+ "Unknown output node(s) requested: %s"
283+ % ", " .join (unknown_outputs ))
261284
262- else :
285+ dag = dag .copy () # preserve net's graph
286+
287+ # Break the incoming edges to all given inputs.
288+ #
289+ # Nodes producing any given intermediate inputs are unnecessary
290+ # (unless they are also used elsewhere).
291+ # To discover which ones to prune, we break their incoming edges
292+ # and they will drop out while collecting ancestors from the outputs.
293+ for given in graph_inputs :
294+ dag .remove_edges_from (list (dag .in_edges (given )))
295+
296+ if outputs :
297+ # If caller requested specific outputs, we can prune any
298+ # unrelated nodes further up the dag.
299+ ending_in_outputs = set ()
300+ for input_name in outputs :
301+ ending_in_outputs .update (nx .ancestors (dag , input_name ))
302+ dag = dag .subgraph (ending_in_outputs | set (outputs ))
263303
264- # If the caller requested a subset of outputs, find any nodes that
265- # are made unecessary because we were provided with an input that's
266- # deeper into the network graph. Ignore input names that aren't
267- # in the graph.
268- unnecessary_nodes = set () # unordered, not iterated
269- for input_name in iter (inputs ):
270- if graph .has_node (input_name ):
271- unnecessary_nodes |= nx .ancestors (graph , input_name )
272-
273- # Find the nodes we need to be able to compute the requested
274- # outputs. Raise an exception if a requested output doesn't
275- # exist in the graph.
276- necessary_nodes = set () # unordered, not iterated
277- for output_name in outputs :
278- if not graph .has_node (output_name ):
279- raise ValueError ("graphkit graph does not have an output "
280- "node named %s" % output_name )
281- necessary_nodes |= nx .ancestors (graph , output_name )
282-
283- # Get rid of the unnecessary nodes from the set of necessary ones.
284- necessary_nodes -= unnecessary_nodes
285304
286305 # Drop (un-satisfiable) operations with partial inputs.
287306 # See yahoo/graphkit#18
288307 #
289- unsatisfiables = self ._collect_unsatisfiable_operations (necessary_nodes , inputs )
290- necessary_nodes -= set (unsatisfiables )
291-
292- shrinked_dag = graph .subgraph (necessary_nodes )
308+ unsatisfied = self ._collect_unsatisfied_operations (dag , inputs )
309+ shrinked_dag = dag .subgraph (dag .nodes - unsatisfied )
293310
294311 return shrinked_dag
295312
0 commit comments