@@ -542,14 +542,14 @@ def remove_node(self, node, rewire=True):
542542 if len (inputs ) > 1 or len (outputs ) > 1 :
543543 raise Exception ('Cannot delete a node with multiple inputs/outputs' )
544544
545- if len (inputs ) == 1 :
545+ if len (outputs ) == 1 and len (inputs ) == 1 :
546+
546547 # Connect inputs -> $outputs
547- if node .name in self .outputs :
548+ if node .outputs [ 0 ] in self .outputs :
548549 msg = f'Remove leaf node { node .name } will connect its input node { inputs [0 ]} to output, but it already is.'
549550 assert inputs [0 ] not in self .outputs , msg
550- self .outputs = [inputs [0 ] if name == node .name else name for name in self .outputs ]
551+ self .outputs = [inputs [0 ] if name == node .outputs [ 0 ] else name for name in self .outputs ]
551552
552- if len (outputs ) == 1 and len (inputs ) == 1 :
553553 inp_var = node .get_input_variable ()
554554 out_var = node .get_output_variable ()
555555
@@ -565,9 +565,6 @@ def remove_node(self, node, rewire=True):
565565 if outputs [0 ] == nxt_inp :
566566 next_node .inputs [i ] = inputs [0 ]
567567
568- if node .outputs [0 ] in self .outputs :
569- prev_node = node .get_input_node (node .inputs [0 ])
570- self .outputs [self .outputs .index (node .outputs [0 ])] = prev_node .outputs [0 ]
571568 del self .output_vars [node .outputs [0 ]]
572569 del self .graph [node .name ]
573570
0 commit comments