@@ -87,8 +87,10 @@ struct NegotiationContext {
8787 holder_is_initiator : bool ,
8888 received_tx_add_input_count : u16 ,
8989 received_tx_add_output_count : u16 ,
90- /// The inputs to be contributed by the holder.
91- inputs : HashMap < SerialId , InteractiveTxInput > ,
90+ /// The inputs contributed by the holder
91+ local_inputs : HashMap < SerialId , InteractiveTxInput > ,
92+ /// The inputs contributed by the counterparty
93+ remote_inputs : HashMap < SerialId , InteractiveTxInput > ,
9294 /// The output intended to be the new funding output.
9395 /// When an output added to the same pubkey, it will be treated as the shared output.
9496 /// The script pubkey is used to discriminate which output is the funding output.
@@ -108,8 +110,10 @@ struct NegotiationContext {
108110 /// Note: this output is also included in `outputs`.
109111 actual_new_funding_output : Option < SharedOutput > ,
110112 prevtx_outpoints : HashSet < OutPoint > ,
111- /// The outputs to be contributed by the holder (excluding the funding output)
112- outputs : HashMap < SerialId , InteractiveTxOutput > ,
113+ /// The outputs contributed by the holder
114+ local_outputs : HashMap < SerialId , InteractiveTxOutput > ,
115+ /// The outputs contributed by the counterparty
116+ remote_outputs : HashMap < SerialId , InteractiveTxOutput > ,
113117 /// The locktime of the funding transaction.
114118 tx_locktime : AbsoluteLockTime ,
115119 /// The fee rate used for the transaction
@@ -129,12 +133,14 @@ impl NegotiationContext {
129133 holder_is_initiator,
130134 received_tx_add_input_count : 0 ,
131135 received_tx_add_output_count : 0 ,
132- inputs : new_hash_map ( ) ,
136+ local_inputs : new_hash_map ( ) ,
137+ remote_inputs : new_hash_map ( ) ,
133138 intended_new_funding_output,
134139 intended_local_contribution_satoshis,
135140 actual_new_funding_output : None ,
136141 prevtx_outpoints : new_hash_set ( ) ,
137- outputs : new_hash_map ( ) ,
142+ local_outputs : new_hash_map ( ) ,
143+ remote_outputs : new_hash_map ( ) ,
138144 tx_locktime,
139145 feerate_sat_per_kw,
140146 }
@@ -191,24 +197,16 @@ impl NegotiationContext {
191197 self . holder_is_initiator == serial_id. is_for_non_initiator ( )
192198 }
193199
194- fn total_input_and_output_count ( & self ) -> usize {
195- self . inputs . len ( ) . saturating_add ( self . outputs . len ( ) )
200+ fn total_input_count ( & self ) -> usize {
201+ self . local_inputs . len ( ) . saturating_add ( self . remote_inputs . len ( ) )
196202 }
197203
198- fn counterparty_inputs_contributed ( & self ) -> impl Iterator < Item = & InteractiveTxInput > + Clone {
199- self . inputs
200- . iter ( )
201- . filter ( move |( serial_id, _) | self . is_serial_id_valid_for_counterparty ( serial_id) )
202- . map ( |( _, input_with_prevout) | input_with_prevout)
204+ fn total_output_count ( & self ) -> usize {
205+ self . local_outputs . len ( ) . saturating_add ( self . remote_outputs . len ( ) )
203206 }
204207
205- fn counterparty_outputs_contributed (
206- & self ,
207- ) -> impl Iterator < Item = & InteractiveTxOutput > + Clone {
208- self . outputs
209- . iter ( )
210- . filter ( move |( serial_id, _) | self . is_serial_id_valid_for_counterparty ( serial_id) )
211- . map ( |( _, output) | output)
208+ fn total_input_and_output_count ( & self ) -> usize {
209+ self . total_input_count ( ) . saturating_add ( self . total_output_count ( ) )
212210 }
213211
214212 fn received_tx_add_input ( & mut self , msg : & msgs:: TxAddInput ) -> Result < ( ) , AbortReason > {
@@ -265,7 +263,7 @@ impl NegotiationContext {
265263 }
266264
267265 let prev_outpoint = OutPoint { txid, vout : msg. prevtx_out } ;
268- match self . inputs . entry ( msg. serial_id ) {
266+ match self . remote_inputs . entry ( msg. serial_id ) {
269267 hash_map:: Entry :: Occupied ( _) => {
270268 // The receiving node:
271269 // - MUST fail the negotiation if:
@@ -303,7 +301,7 @@ impl NegotiationContext {
303301 return Err ( AbortReason :: IncorrectSerialIdParity ) ;
304302 }
305303
306- self . inputs
304+ self . remote_inputs
307305 . remove ( & msg. serial_id )
308306 // The receiving node:
309307 // - MUST fail the negotiation if:
@@ -339,7 +337,7 @@ impl NegotiationContext {
339337 // Check that adding this output would not cause the total output value to exceed the total
340338 // bitcoin supply.
341339 let mut outputs_value: u64 = 0 ;
342- for output in self . outputs . iter ( ) {
340+ for output in self . local_outputs . iter ( ) . chain ( self . remote_outputs . iter ( ) ) {
343341 outputs_value = outputs_value. saturating_add ( output. 1 . value ( ) ) ;
344342 }
345343 if outputs_value. saturating_add ( msg. sats ) > TOTAL_BITCOIN_SUPPLY_SATOSHIS {
@@ -377,7 +375,7 @@ impl NegotiationContext {
377375 } else {
378376 InteractiveTxOutput :: Remote ( RemoteOutput { serial_id : msg. serial_id , txout } )
379377 } ;
380- match self . outputs . entry ( msg. serial_id ) {
378+ match self . remote_outputs . entry ( msg. serial_id ) {
381379 hash_map:: Entry :: Occupied ( _) => {
382380 // The receiving node:
383381 // - MUST fail the negotiation if:
@@ -395,7 +393,7 @@ impl NegotiationContext {
395393 if !self . is_serial_id_valid_for_counterparty ( & msg. serial_id ) {
396394 return Err ( AbortReason :: IncorrectSerialIdParity ) ;
397395 }
398- if let Some ( _) = self . outputs . remove ( & msg. serial_id ) {
396+ if let Some ( _) = self . remote_outputs . remove ( & msg. serial_id ) {
399397 Ok ( ( ) )
400398 } else {
401399 // The receiving node:
@@ -430,7 +428,7 @@ impl NegotiationContext {
430428 . ok_or ( AbortReason :: PrevTxOutInvalid ) ?
431429 . value ,
432430 } ) ;
433- self . inputs . insert ( msg. serial_id , input) ;
431+ self . local_inputs . insert ( msg. serial_id , input) ;
434432 Ok ( ( ) )
435433 }
436434
@@ -443,17 +441,17 @@ impl NegotiationContext {
443441 } else {
444442 InteractiveTxOutput :: Local ( LocalOutput { serial_id : msg. serial_id , txout } )
445443 } ;
446- self . outputs . insert ( msg. serial_id , output) ;
444+ self . local_outputs . insert ( msg. serial_id , output) ;
447445 Ok ( ( ) )
448446 }
449447
450448 fn sent_tx_remove_input ( & mut self , msg : & msgs:: TxRemoveInput ) -> Result < ( ) , AbortReason > {
451- self . inputs . remove ( & msg. serial_id ) ;
449+ self . local_inputs . remove ( & msg. serial_id ) ;
452450 Ok ( ( ) )
453451 }
454452
455453 fn sent_tx_remove_output ( & mut self , msg : & msgs:: TxRemoveOutput ) -> Result < ( ) , AbortReason > {
456- self . outputs . remove ( & msg. serial_id ) ;
454+ self . local_outputs . remove ( & msg. serial_id ) ;
457455 Ok ( ( ) )
458456 }
459457
@@ -464,11 +462,19 @@ impl NegotiationContext {
464462 // - the peer's total input satoshis with its part of any shared input is less than their outputs
465463 // and proportion of any shared output
466464 let mut counterparty_value_in: u64 = 0 ;
467- for ( _, input) in & self . inputs {
465+ // Consider remote and local also, due to possible shared inputs
466+ for ( _, input) in & self . remote_inputs {
467+ counterparty_value_in = counterparty_value_in. saturating_add ( input. remote_value ( ) ) ;
468+ }
469+ for ( _, input) in & self . local_inputs {
468470 counterparty_value_in = counterparty_value_in. saturating_add ( input. remote_value ( ) ) ;
469471 }
470472 let mut counterparty_value_out: u64 = 0 ;
471- for ( _, output) in & self . outputs {
473+ // Consider both local and remote, due to possible shared inputs
474+ for ( _, output) in & self . remote_outputs {
475+ counterparty_value_out = counterparty_value_out. saturating_add ( output. remote_value ( ) ) ;
476+ }
477+ for ( _, output) in & self . local_outputs {
472478 counterparty_value_out = counterparty_value_out. saturating_add ( output. remote_value ( ) ) ;
473479 }
474480 if counterparty_value_in < counterparty_value_out {
@@ -477,8 +483,8 @@ impl NegotiationContext {
477483
478484 // - there are more than 252 inputs
479485 // - there are more than 252 outputs
480- if self . inputs . len ( ) > MAX_INPUTS_OUTPUTS_COUNT
481- || self . outputs . len ( ) > MAX_INPUTS_OUTPUTS_COUNT
486+ if self . total_input_count ( ) > MAX_INPUTS_OUTPUTS_COUNT
487+ || self . total_output_count ( ) > MAX_INPUTS_OUTPUTS_COUNT
482488 {
483489 return Err ( AbortReason :: ExceededNumberOfInputsOrOutputs ) ;
484490 }
@@ -488,14 +494,14 @@ impl NegotiationContext {
488494
489495 // - the peer's paid feerate does not meet or exceed the agreed feerate (based on the minimum fee).
490496 let mut counterparty_weight_contributed: u64 = self
491- . counterparty_outputs_contributed ( )
492- . map ( |output| {
497+ . remote_outputs
498+ . iter ( )
499+ . map ( |( _, output) | {
493500 ( 8 /* value */ + output. script_pubkey ( ) . consensus_encode ( & mut sink ( ) ) . unwrap ( ) as u64 )
494501 * WITNESS_SCALE_FACTOR as u64
495502 } )
496503 . sum ( ) ;
497- counterparty_weight_contributed +=
498- self . counterparty_inputs_contributed ( ) . count ( ) as u64 * INPUT_WEIGHT ;
504+ counterparty_weight_contributed += self . remote_inputs . len ( ) as u64 * INPUT_WEIGHT ;
499505 let counterparty_fees_contributed =
500506 counterparty_value_in. saturating_sub ( counterparty_value_out) ;
501507 let mut required_counterparty_contribution_fee =
@@ -516,8 +522,13 @@ impl NegotiationContext {
516522 }
517523
518524 // Inputs and outputs must be sorted by serial_id
519- let mut inputs = self . inputs . into_iter ( ) . collect :: < Vec < _ > > ( ) ;
520- let mut outputs = self . outputs . into_iter ( ) . collect :: < Vec < _ > > ( ) ;
525+ let mut inputs =
526+ self . local_inputs . into_iter ( ) . chain ( self . remote_inputs . into_iter ( ) ) . collect :: < Vec < _ > > ( ) ;
527+ let mut outputs = self
528+ . local_outputs
529+ . into_iter ( )
530+ . chain ( self . remote_outputs . into_iter ( ) )
531+ . collect :: < Vec < _ > > ( ) ;
521532 inputs. sort_unstable_by_key ( |( serial_id, _) | * serial_id) ;
522533 outputs. sort_unstable_by_key ( |( serial_id, _) | * serial_id) ;
523534
0 commit comments