1212import de .rub .nds .modifiablevariable .util .ArrayConverter ;
1313import de .rub .nds .tlsattacker .core .constants .AlgorithmResolver ;
1414import de .rub .nds .tlsattacker .core .constants .HKDFAlgorithm ;
15- import de .rub .nds .tlsattacker .core .constants .KeyUpdateRequest ;
1615import de .rub .nds .tlsattacker .core .constants .Tls13KeySetType ;
1716import de .rub .nds .tlsattacker .core .crypto .HKDFunction ;
1817import de .rub .nds .tlsattacker .core .exceptions .AdjustmentException ;
@@ -47,35 +46,27 @@ public KeyUpdateHandler(TlsContext tlsContext) {
4746
4847 @ Override
4948 public void adjustTLSContext (KeyUpdateMessage message ) {
50-
49+ if (tlsContext .getChooser ().getTalkingConnectionEnd () != tlsContext .getChooser ().getConnectionEndType ()) {
50+ adjustApplicationTrafficSecrets ();
51+ setRecordCipher (Tls13KeySetType .APPLICATION_TRAFFIC_SECRETS );
52+ }
5153 }
5254
5355 @ Override
5456 public void adjustTlsContextAfterSerialize (KeyUpdateMessage message ) {
55-
56- if (message .getRequestUpdate () == KeyUpdateRequest .UPDATE_REQUESTED ) {
57- adjustApplicationTrafficSecrets ();
58- }
57+ adjustApplicationTrafficSecrets ();
5958 setRecordCipher (Tls13KeySetType .APPLICATION_TRAFFIC_SECRETS );
60-
6159 }
6260
6361 @ Override
6462 public ProtocolMessageParser getParser (byte [] message , int pointer ) {
65-
6663 return new KeyUpdateParser (pointer , message , tlsContext .getChooser ().getSelectedProtocolVersion (),
6764 tlsContext .getConfig ());
6865
6966 }
7067
7168 @ Override
7269 public ProtocolMessagePreparator getPreparator (KeyUpdateMessage message ) {
73- if (tlsContext .getChooser ().getTalkingConnectionEnd () != tlsContext .getChooser ().getConnectionEndType ()) {
74- if (message .getRequestUpdate () == KeyUpdateRequest .UPDATE_REQUESTED ) {
75- adjustApplicationTrafficSecrets ();
76- }
77- setRecordCipher (Tls13KeySetType .APPLICATION_TRAFFIC_SECRETS );
78- }
7970 return new KeyUpdatePreparator (tlsContext .getChooser (), message );
8071 }
8172
@@ -89,23 +80,29 @@ private void adjustApplicationTrafficSecrets() {
8980 .getSelectedCipherSuite ());
9081
9182 try {
92-
9383 Mac mac = Mac .getInstance (hkdfAlgortihm .getMacAlgorithm ().getJavaName ());
94- byte [] clientApplicationTrafficSecret = HKDFunction .expandLabel (hkdfAlgortihm ,
95- tlsContext .getClientApplicationTrafficSecret (), HKDFunction .TRAFFICUPD , new byte [0 ],
96- mac .getMacLength ());
9784
98- tlsContext .setClientApplicationTrafficSecret (clientApplicationTrafficSecret );
99- LOGGER .debug ("Set clientApplicationTrafficSecret in Context to "
100- + ArrayConverter .bytesToHexString (clientApplicationTrafficSecret ));
85+ if (tlsContext .getChooser ().getTalkingConnectionEnd () == ConnectionEndType .CLIENT ) {
86+
87+ byte [] clientApplicationTrafficSecret = HKDFunction .expandLabel (hkdfAlgortihm ,
88+ tlsContext .getClientApplicationTrafficSecret (), HKDFunction .TRAFFICUPD , new byte [0 ],
89+ mac .getMacLength ());
90+
91+ tlsContext .setClientApplicationTrafficSecret (clientApplicationTrafficSecret );
92+ LOGGER .debug ("Set clientApplicationTrafficSecret in Context to "
93+ + ArrayConverter .bytesToHexString (clientApplicationTrafficSecret ));
10194
102- byte [] serverApplicationTrafficSecret = HKDFunction .expandLabel (hkdfAlgortihm ,
103- tlsContext .getServerApplicationTrafficSecret (), HKDFunction .TRAFFICUPD , new byte [0 ],
104- mac .getMacLength ());
95+ } else {
10596
106- tlsContext .setServerApplicationTrafficSecret (serverApplicationTrafficSecret );
107- LOGGER .debug ("Set serverApplicationTrafficSecret in Context to "
108- + ArrayConverter .bytesToHexString (serverApplicationTrafficSecret ));
97+ byte [] serverApplicationTrafficSecret = HKDFunction .expandLabel (hkdfAlgortihm ,
98+ tlsContext .getServerApplicationTrafficSecret (), HKDFunction .TRAFFICUPD , new byte [0 ],
99+ mac .getMacLength ());
100+
101+ tlsContext .setServerApplicationTrafficSecret (serverApplicationTrafficSecret );
102+ LOGGER .debug ("Set serverApplicationTrafficSecret in Context to "
103+ + ArrayConverter .bytesToHexString (serverApplicationTrafficSecret ));
104+
105+ }
109106
110107 } catch (NoSuchAlgorithmException | CryptoException ex ) {
111108 throw new AdjustmentException (ex );
@@ -127,19 +124,26 @@ private KeySet getKeySet(TlsContext context, Tls13KeySetType keySetType) {
127124 private void setRecordCipher (Tls13KeySetType keySetType ) {
128125 try {
129126 int AEAD_IV_LENGTH = 12 ;
127+ KeySet keySet ;
130128 HKDFAlgorithm hkdfAlgortihm = AlgorithmResolver .getHKDFAlgorithm (tlsContext .getChooser ()
131129 .getSelectedCipherSuite ());
132130
133- tlsContext .setActiveClientKeySetType (keySetType );
134- LOGGER .debug ("Setting cipher for client to use " + keySetType );
135- KeySet keySet = getKeySet (tlsContext , tlsContext .getActiveClientKeySetType ());
131+ if (tlsContext .getChooser ().getTalkingConnectionEnd () == ConnectionEndType .CLIENT ) {
136132
137- if (tlsContext .getChooser ().getTalkingConnectionEnd () == ConnectionEndType .CLIENT
138- && tlsContext .getChooser ().getConnectionEndType () == ConnectionEndType .CLIENT
139- || tlsContext .getChooser ().getTalkingConnectionEnd () == ConnectionEndType .SERVER
140- && tlsContext .getChooser ().getConnectionEndType () == ConnectionEndType .SERVER ) {
133+ tlsContext .setActiveClientKeySetType (keySetType );
134+ LOGGER .debug ("Setting cipher for client to use " + keySetType );
135+ keySet = getKeySet (tlsContext , tlsContext .getActiveClientKeySetType ());
136+
137+ } else {
138+ tlsContext .setActiveServerKeySetType (keySetType );
139+ LOGGER .debug ("Setting cipher for server to use " + keySetType );
140+ keySet = getKeySet (tlsContext , tlsContext .getActiveServerKeySetType ());
141+ }
142+
143+ if (tlsContext .getChooser ().getTalkingConnectionEnd () == tlsContext .getChooser ().getConnectionEndType ()) {
141144
142145 if (tlsContext .getChooser ().getConnectionEndType () == ConnectionEndType .CLIENT ) {
146+
143147 keySet .setClientWriteIv (HKDFunction .expandLabel (hkdfAlgortihm ,
144148 tlsContext .getClientApplicationTrafficSecret (), HKDFunction .IV , new byte [0 ], AEAD_IV_LENGTH ));
145149
@@ -156,10 +160,8 @@ private void setRecordCipher(Tls13KeySetType keySetType) {
156160 AlgorithmResolver .getCipher (tlsContext .getChooser ().getSelectedCipherSuite ()).getKeySize ()));
157161 }
158162
159- } else if (tlsContext .getChooser ().getTalkingConnectionEnd () == ConnectionEndType .SERVER
160- && tlsContext .getChooser ().getConnectionEndType () == ConnectionEndType .CLIENT
161- || tlsContext .getChooser ().getTalkingConnectionEnd () == ConnectionEndType .CLIENT
162- && tlsContext .getChooser ().getConnectionEndType () == ConnectionEndType .SERVER ) {
163+ } else if (tlsContext .getChooser ().getTalkingConnectionEnd () != tlsContext .getChooser ()
164+ .getConnectionEndType ()) {
163165
164166 if (tlsContext .getChooser ().getConnectionEndType () == ConnectionEndType .SERVER ) {
165167
@@ -186,18 +188,11 @@ private void setRecordCipher(Tls13KeySetType keySetType) {
186188 .getChooser ().getSelectedCipherSuite ());
187189 tlsContext .getRecordLayer ().setRecordCipher (recordCipherClient );
188190
189- if (tlsContext .getChooser ().getTalkingConnectionEnd () == ConnectionEndType .CLIENT
190- && tlsContext .getChooser ().getConnectionEndType () == ConnectionEndType .CLIENT
191- || tlsContext .getChooser ().getTalkingConnectionEnd () == ConnectionEndType .SERVER
192- && tlsContext .getChooser ().getConnectionEndType () == ConnectionEndType .SERVER ) {
193-
191+ if (tlsContext .getChooser ().getTalkingConnectionEnd () == tlsContext .getChooser ().getConnectionEndType ()) {
194192 tlsContext .setWriteSequenceNumber (0 );
195193 tlsContext .getRecordLayer ().updateEncryptionCipher ();
196- } else if (tlsContext .getChooser ().getTalkingConnectionEnd () == ConnectionEndType .SERVER
197- && tlsContext .getChooser ().getConnectionEndType () == ConnectionEndType .CLIENT
198- || tlsContext .getChooser ().getTalkingConnectionEnd () == ConnectionEndType .CLIENT
199- && tlsContext .getChooser ().getConnectionEndType () == ConnectionEndType .SERVER ) {
200-
194+ } else if (tlsContext .getChooser ().getTalkingConnectionEnd () != tlsContext .getChooser ()
195+ .getConnectionEndType ()) {
201196 tlsContext .setReadSequenceNumber (0 );
202197 tlsContext .getRecordLayer ().updateDecryptionCipher ();
203198 }
0 commit comments