package net.taler.wallet.kotlin.crypto

import java.math.BigInteger
import kotlin.math.ceil
import kotlin.math.floor

/**
 * RSA blind-signature primitives: blinding, unblinding and verification based on an
 * RSA full-domain hash (FDH) derived with `CryptoJvmImpl.kdf`.
 *
 * Numbers returned by this object are unsigned big-endian byte arrays without a leading
 * sign byte. Numbers passed in are parsed as unsigned big-endian magnitudes.
 *
 * Public keys use the encoding `len(n):u16 || len(e):u16 || n || e` (big-endian lengths).
 */
@OptIn(ExperimentalStdlibApi::class)
internal object RsaBlinding {

    /**
     * Blinds the message hash [hm] for signing under the encoded RSA public key [rsaPubEnc].
     *
     * Computes `r^e * FDH(hm) mod n`, where the blinding factor `r` is derived from the
     * blinding key secret [bks].
     *
     * @return the blinded message as unsigned big-endian bytes
     * @throws Error if [rsaPubEnc] is malformed or the FDH fails the GCD check
     */
    fun rsaBlind(hm: ByteArray, bks: ByteArray, rsaPubEnc: ByteArray): ByteArray {
        val rsaPub = rsaPubDecode(rsaPubEnc)
        val data = rsaFullDomainHash(hm, rsaPub)
        val r = rsaBlindingKeyDerive(rsaPub, bks)
        val rE = r.modPow(rsaPub.e, rsaPub.n)
        val bm = rE.multiply(data).mod(rsaPub.n)
        return bm.toByteArrayWithoutSign()
    }

    /**
     * Removes the blinding from a blind signature [sig].
     *
     * Computes `sig * r^-1 mod n`, with `r` re-derived from [bks] exactly as in [rsaBlind].
     *
     * @return the unblinded signature as unsigned big-endian bytes
     * @throws Error if [rsaPubEnc] is malformed
     * @throws ArithmeticException if `r` has no inverse modulo `n`
     */
    fun rsaUnblind(sig: ByteArray, rsaPubEnc: ByteArray, bks: ByteArray): ByteArray {
        val rsaPub = rsaPubDecode(rsaPubEnc)
        val blindedSig = BigInteger(1, sig)
        val r = rsaBlindingKeyDerive(rsaPub, bks)
        val rInv = r.modInverse(rsaPub.n)
        val s = blindedSig.multiply(rInv).mod(rsaPub.n)
        return s.toByteArrayWithoutSign()
    }

    /**
     * Checks an (unblinded) RSA signature.
     *
     * @return `true` iff `rsaSig^e mod n` equals the full-domain hash of [hm]
     * @throws Error if [rsaPubEnc] is malformed or the FDH fails the GCD check
     */
    fun rsaVerify(hm: ByteArray, rsaSig: ByteArray, rsaPubEnc: ByteArray): Boolean {
        val rsaPub = rsaPubDecode(rsaPubEnc)
        val d = rsaFullDomainHash(hm, rsaPub)
        val sig = BigInteger(1, rsaSig)
        val sigE = sig.modPow(rsaPub.e, rsaPub.n)
        return sigE == d
    }

    /** Derives the blinding factor `r` (in `[0, n)`) from the blinding key secret [bks]. */
    private fun rsaBlindingKeyDerive(rsaPub: RsaPublicKey, bks: ByteArray): BigInteger {
        // These strings are KDF inputs: the spelling (including "extrator") must stay
        // byte-identical, otherwise a different blinding factor is derived.
        val salt = "Blinding KDF extrator HMAC key".encodeToByteArray()
        val info = "Blinding KDF".encodeToByteArray()
        return kdfMod(rsaPub.n, bks, salt, info)
    }

    /**
     * Parses an encoded RSA public key: two big-endian 16-bit lengths (modulus, then
     * exponent) followed by the modulus bytes and the exponent bytes.
     *
     * @throws Error if the buffer is too short or the lengths do not match its size
     */
    private fun rsaPubDecode(publicKey: ByteArray): RsaPublicKey {
        if (publicKey.size < 4) throw Error("invalid RSA public key (format wrong)")
        // Bytes are masked to unsigned: Byte.toInt() sign-extends, which would turn
        // lengths such as 128 (0x0080) or 384 (0x0180) into negative numbers.
        val modulusLength = readUint16(publicKey, 0)
        val exponentLength = readUint16(publicKey, 2)
        if (4 + exponentLength + modulusLength != publicKey.size) {
            throw Error("invalid RSA public key (format wrong)")
        }
        val modulusEnd = 4 + modulusLength
        val modulus = publicKey.copyOfRange(4, modulusEnd)
        val exponent = publicKey.copyOfRange(modulusEnd, modulusEnd + exponentLength)
        return RsaPublicKey(BigInteger(1, modulus), BigInteger(1, exponent))
    }

    /**
     * Full-domain hash of [hm] into `[0, n)`, salted with the re-encoded public key.
     *
     * @throws Error if the result shares a factor with `n` (see [rsaGcdValidate])
     */
    private fun rsaFullDomainHash(hm: ByteArray, rsaPublicKey: RsaPublicKey): BigInteger {
        val info = "RSA-FDA FTpsW!".encodeToByteArray()
        // Salt comes from re-encoding the decoded key, so it never contains leading zero
        // bytes even if the caller's encoding did.
        val salt = rsaPubEncode(rsaPublicKey)
        val r = kdfMod(rsaPublicKey.n, hm, salt, info)
        rsaGcdValidate(r, rsaPublicKey.n)
        return r
    }

    /** Encodes a public key in the same layout [rsaPubDecode] reads. */
    private fun rsaPubEncode(rsaPublicKey: RsaPublicKey): ByteArray {
        val mb = rsaPublicKey.n.toByteArrayWithoutSign()
        val eb = rsaPublicKey.e.toByteArrayWithoutSign()
        val out = ByteArray(4 + mb.size + eb.size)
        writeUint16(out, 0, mb.size)
        writeUint16(out, 2, eb.size)
        mb.copyInto(out, destinationOffset = 4)
        eb.copyInto(out, destinationOffset = 4 + mb.size)
        return out
    }

    /**
     * Derives a uniformly distributed value in `[0, n)` by rejection sampling.
     *
     * Each attempt runs the KDF over `info || counter` (counter as a big-endian 16-bit
     * value, starting at 0), truncates the output to the bit length of [n], and accepts it
     * if it is below [n].
     */
    private fun kdfMod(n: BigInteger, ikm: ByteArray, salt: ByteArray, info: ByteArray): BigInteger {
        val nBits = n.bitLength()
        // ceil(nBits / 8): bytes needed to hold nBits bits.
        val bufLen = (nBits + 7) / 8
        // Keeps only the significant bits of the top byte (0xff when nBits % 8 == 0).
        val mask = (1 shl (8 - (bufLen * 8 - nBits))) - 1
        var counter = 0
        while (true) {
            // copyOf(newSize) zero-pads; the two trailing bytes carry the counter.
            val ctx = info.copyOf(info.size + 2)
            writeUint16(ctx, info.size, counter)
            val candidate = CryptoJvmImpl.kdf(bufLen, ikm, salt, ctx).copyOf()
            candidate[0] = (candidate[0].toInt() and mask).toByte()
            val r = BigInteger(1, candidate)
            if (r < n) return r
            counter++
        }
    }

    /**
     * Test for malicious RSA key.
     *
     * Assuming n is an RSA modulus and r is generated using a call to
     * GNUNET_CRYPTO_kdf_mod_mpi, if gcd(r,n) != 1 then n must be a
     * malicious RSA key designed to deanonymize the user.
     *
     * @param r KDF result
     * @param n RSA modulus of the public key
     * @throws Error if `gcd(r, n) != 1`
     */
    private fun rsaGcdValidate(r: BigInteger, n: BigInteger) {
        if (r.gcd(n) != BigInteger.ONE) throw Error("malicious RSA public key")
    }

    /**
     * Encodes a non-negative integer as unsigned big-endian bytes of minimal length.
     *
     * [BigInteger.toByteArray] produces two's-complement output with an extra leading
     * 0x00 byte whenever the top bit of the magnitude is set; that byte is dropped here.
     * Zero encodes to an empty array.
     */
    private fun BigInteger.toByteArrayWithoutSign(): ByteArray {
        require(signum() >= 0) { "cannot encode a negative value without its sign" }
        val bytes = toByteArray()
        val byteLength = (bitLength() + 7) / 8
        return bytes.copyOfRange(bytes.size - byteLength, bytes.size)
    }

    /** Reads an unsigned big-endian 16-bit value at [offset]. */
    private fun readUint16(bytes: ByteArray, offset: Int): Int =
        ((bytes[offset].toInt() and 0xff) shl 8) or (bytes[offset + 1].toInt() and 0xff)

    /** Writes the low 16 bits of [value] big-endian at [offset]. */
    private fun writeUint16(dest: ByteArray, offset: Int, value: Int) {
        dest[offset] = ((value ushr 8) and 0xff).toByte()
        dest[offset + 1] = (value and 0xff).toByte()
    }
}

/** RSA public key with modulus [n] and public exponent [e]. */
internal class RsaPublicKey(val n: BigInteger, val e: BigInteger)