@@ -162,9 +162,12 @@ public void computeRsaPssSignature(
162162 // Generate padding string PS, which is a string of zero bytes
163163 int emBits = computations .getModulus ().getValue ().bitLength () - 1 ;
164164 int emLength = (emBits + 7 ) / 8 ;
165-
166- byte [] psValue =
167- new byte [emLength - computations .getSalt ().getValue ().length - hValue .length - 2 ];
165+ int psLenght = emLength - computations .getSalt ().getValue ().length - hValue .length - 2 ;
166+ if (psLenght < 0 ) {
167+ LOGGER .warn ("PS length is negative. Overwritting with 0" );
168+ psLenght = 0 ;
169+ }
170+ byte [] psValue = new byte [psLenght ];
168171 computations .setPsValue (psValue );
169172 psValue = computations .getPsValue ().getValue ();
170173 LOGGER .debug ("Ps value: {}" , psValue );
@@ -178,15 +181,16 @@ public void computeRsaPssSignature(
178181 // Mask generation function (MGF1)
179182 byte [] dbMask = maskGeneratorFunction1 (hValue , mgf1Algorithm , emLength - hValue .length - 1 );
180183 LOGGER .debug ("DB mask: {}" , dbMask );
181- assert (db .length == dbMask .length );
182184 byte [] maskedDB = mask (db , dbMask );
183185 computations .setMaskedDb (maskedDB );
184186 maskedDB = computations .getMaskedDb ().getValue ();
185187 LOGGER .debug ("Masked DB: {}" , maskedDB );
186188 computations .setTfValue (new byte [] {(byte ) 0xBC });
187189
188190 int firstByteMask = 0xff >>> ((emLength * 8 ) - emBits );
189- maskedDB [0 ] &= firstByteMask ;
191+ if (maskedDB .length > 0 ) {
192+ maskedDB [0 ] &= firstByteMask ;
193+ }
190194 // Construct the encoded message EM = maskedDB || H || 0xBC
191195 byte [] em =
192196 ArrayConverter .concatenate (maskedDB , hValue , computations .getTfValue ().getValue ());
@@ -206,8 +210,12 @@ public void computeRsaPssSignature(
206210 }
207211
208212 private byte [] mask (byte [] value , byte [] mask ) {
213+ // Usually value and mask will be of equal length, but invalid values may cause this to not
214+ // hold
215+ // that is why we take the minimum of both lengths here.
216+ int length = Math .min (value .length , mask .length );
209217 byte [] maskedValue = new byte [value .length ];
210- for (int i = 0 ; i < value . length ; i ++) {
218+ for (int i = 0 ; i < length ; i ++) {
211219 maskedValue [i ] = (byte ) (value [i ] ^ mask [i ]);
212220 }
213221 return maskedValue ;
0 commit comments