summaryrefslogtreecommitdiff
path: root/src/androidMain/kotlin/net/taler/wallet/kotlin/crypto/RsaBlinding.kt
blob: 6158c527c58c32e1282311562793b33fe22e1154 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
package net.taler.wallet.kotlin.crypto

import java.math.BigInteger
import kotlin.math.ceil
import kotlin.math.floor

/**
 * RSA blind-signature primitives: blinding a message hash, unblinding a signature over
 * the blinded message, and verifying the resulting signature.
 *
 * RSA public keys are passed in the encoding produced by [rsaPubEncode]: a 2-byte
 * big-endian modulus length, a 2-byte big-endian exponent length, then the unsigned
 * big-endian modulus bytes followed by the unsigned big-endian exponent bytes.
 *
 * The public functions throw [Error] when the public key encoding is malformed or when
 * the full-domain hash reveals a malicious key (see [rsaGcdValidate]).
 */
@OptIn(ExperimentalStdlibApi::class)
internal object RsaBlinding {

    /**
     * Blinds the message hash [hm] for an RSA blind signature.
     *
     * Computes `r^e * FDH(hm) mod n`, where the blinding factor `r` is derived from [bks].
     *
     * @param hm message hash to blind
     * @param bks blinding key secret; the same value must later be passed to [rsaUnblind]
     *   so that the same `r` is removed again
     * @param rsaPubEnc encoded RSA public key of the signer
     * @return the blinded message as unsigned big-endian bytes (minimal length, not padded
     *   to the modulus size)
     */
    fun rsaBlind(hm: ByteArray, bks: ByteArray, rsaPubEnc: ByteArray): ByteArray {
        val rsaPub = rsaPubDecode(rsaPubEnc)
        val data = rsaFullDomainHash(hm, rsaPub)
        val r = rsaBlindingKeyDerive(rsaPub, bks)
        val rE = r.modPow(rsaPub.e, rsaPub.n)
        val bm = rE.multiply(data).mod(rsaPub.n)
        return bm.toByteArrayWithoutSign()
    }

    /**
     * Removes the blinding factor from a signature over a blinded message.
     *
     * Computes `sig * r^-1 mod n`, with `r` derived from [bks] exactly as in [rsaBlind].
     *
     * @param sig blinded signature as unsigned big-endian bytes
     * @param rsaPubEnc encoded RSA public key of the signer
     * @param bks blinding key secret that was used for [rsaBlind]
     * @return the unblinded signature as unsigned big-endian bytes (minimal length)
     * @throws ArithmeticException if the derived blinding factor has no inverse modulo `n`
     */
    fun rsaUnblind(sig: ByteArray, rsaPubEnc: ByteArray, bks: ByteArray): ByteArray {
        val rsaPub = rsaPubDecode(rsaPubEnc)
        val blindedSig = BigInteger(1, sig)
        val r = rsaBlindingKeyDerive(rsaPub, bks)
        val rInv = r.modInverse(rsaPub.n)
        val s = blindedSig.multiply(rInv).mod(rsaPub.n)
        return s.toByteArrayWithoutSign()
    }

    /**
     * Checks whether [rsaSig] is a valid RSA signature over the full-domain hash of [hm],
     * i.e. whether `sig^e mod n == FDH(hm)`.
     *
     * @param hm message hash that was signed
     * @param rsaSig (unblinded) signature as unsigned big-endian bytes
     * @param rsaPubEnc encoded RSA public key of the signer
     * @return `true` if the signature matches, `false` otherwise
     */
    fun rsaVerify(hm: ByteArray, rsaSig: ByteArray, rsaPubEnc: ByteArray): Boolean {
        val rsaPub = rsaPubDecode(rsaPubEnc)
        val d = rsaFullDomainHash(hm, rsaPub)
        val sig = BigInteger(1, rsaSig)
        val sigE = sig.modPow(rsaPub.e, rsaPub.n)
        return sigE == d
    }

    /** Derives the blinding factor `r` in `[0, n)` from the blinding key secret [bks]. */
    private fun rsaBlindingKeyDerive(rsaPub: RsaPublicKey, bks: ByteArray): BigInteger {
        // The spelling "extrator" is intentional and must stay byte-for-byte identical:
        // any change to the salt changes every derived blinding factor.
        val salt = "Blinding KDF extrator HMAC key".encodeToByteArray()
        val info = "Blinding KDF".encodeToByteArray()
        return kdfMod(rsaPub.n, bks, salt, info)
    }

    /**
     * Decodes a public key from the length-prefixed encoding described on [RsaBlinding].
     *
     * @throws Error if the input is too short or the declared lengths do not add up to its size
     */
    private fun rsaPubDecode(publicKey: ByteArray): RsaPublicKey {
        if (publicKey.size < 4) {
            throw Error("invalid RSA public key (format wrong)")
        }
        // Kotlin bytes are signed; readUInt16 masks them so that lengths whose low byte
        // is >= 0x80 (e.g. 128 = 0x0080 or 384 = 0x0180) decode correctly.
        val modulusLength = readUInt16(publicKey, 0)
        val exponentLength = readUInt16(publicKey, 2)
        if (4 + exponentLength + modulusLength != publicKey.size) {
            throw Error("invalid RSA public key (format wrong)")
        }
        val modulus = publicKey.copyOfRange(4, 4 + modulusLength)
        val exponent = publicKey.copyOfRange(
            4 + modulusLength,
            4 + modulusLength + exponentLength,
        )
        return RsaPublicKey(BigInteger(1, modulus), BigInteger(1, exponent))
    }

    /** Reads an unsigned 16-bit big-endian integer from [bytes] at [offset]. */
    private fun readUInt16(bytes: ByteArray, offset: Int): Int =
        ((bytes[offset].toInt() and 0xff) shl 8) or (bytes[offset + 1].toInt() and 0xff)

    /**
     * Full-domain hash of [hm] into `[0, n)`. The re-encoded public key is used as the KDF salt.
     *
     * @throws Error if the result shares a factor with `n` (see [rsaGcdValidate])
     */
    private fun rsaFullDomainHash(hm: ByteArray, rsaPublicKey: RsaPublicKey): BigInteger {
        val info = "RSA-FDA FTpsW!".encodeToByteArray()
        val salt = rsaPubEncode(rsaPublicKey)
        val r = kdfMod(rsaPublicKey.n, hm, salt, info)
        rsaGcdValidate(r, rsaPublicKey.n)
        return r
    }

    /**
     * Encodes a public key as `[modulus length: u16 BE][exponent length: u16 BE][modulus][exponent]`,
     * with modulus and exponent as minimal unsigned big-endian bytes.
     */
    private fun rsaPubEncode(rsaPublicKey: RsaPublicKey): ByteArray {
        val mb = rsaPublicKey.n.toByteArrayWithoutSign()
        val eb = rsaPublicKey.e.toByteArrayWithoutSign()
        val out = ByteArray(4 + mb.size + eb.size)
        out[0] = ((mb.size ushr 8) and 0xff).toByte()
        out[1] = (mb.size and 0xff).toByte()
        out[2] = ((eb.size ushr 8) and 0xff).toByte()
        out[3] = (eb.size and 0xff).toByte()
        mb.copyInto(out, destinationOffset = 4)
        eb.copyInto(out, destinationOffset = 4 + mb.size)
        return out
    }

    /**
     * Derives a value in `[0, n)` from KDF output by rejection sampling.
     *
     * Each attempt requests just enough bytes to hold `n`'s bit length and clears the surplus
     * high bits of the first byte. If the candidate is not below [n], it retries with a counter
     * that is appended to [info] as two big-endian bytes and incremented on every attempt.
     */
    private fun kdfMod(n: BigInteger, ikm: ByteArray, salt: ByteArray, info: ByteArray): BigInteger {
        val nBits = n.bitLength()
        // Bytes needed for nBits bits; exact integer form of floor((nBits - 1) / 8 + 1).
        val bufLen = (nBits + 7) / 8
        // Keeps only the low (8 - surplus) bits of the first byte, so the candidate has at most nBits bits.
        val mask = (1 shl (8 - (bufLen * 8 - nBits))) - 1
        var counter = 0
        while (true) {
            // info followed by two zero bytes that receive the counter.
            val ctx = info.copyOf(info.size + 2)
            ctx[ctx.size - 2] = ((counter ushr 8) and 0xff).toByte()
            ctx[ctx.size - 1] = (counter and 0xff).toByte()
            // Copy so that the buffer returned by the KDF is never mutated.
            val candidateBytes = CryptoJvmImpl.kdf(bufLen, ikm, salt, ctx).copyOf()
            candidateBytes[0] = (candidateBytes[0].toInt() and mask).toByte()
            val candidate = BigInteger(1, candidateBytes)
            if (candidate < n) return candidate
            counter++
        }
    }

    /**
     * Test for a malicious RSA key.
     *
     * Assuming n is an RSA modulus and r is generated using a call to
     * GNUNET_CRYPTO_kdf_mod_mpi (mirrored here by [kdfMod]), if gcd(r, n) != 1 then n
     * must be a malicious RSA key designed to deanonymize the user.
     *
     * @param r KDF result
     * @param n RSA modulus of the public key
     * @throws Error if `gcd(r, n) != 1`
     */
    private fun rsaGcdValidate(r: BigInteger, n: BigInteger) {
        if (r.gcd(n) != BigInteger.ONE) throw Error("malicious RSA public key")
    }

    /**
     * Returns the big-endian bytes of this value without the extra leading sign byte that
     * [BigInteger.toByteArray] emits when the top bit of the most significant byte is set.
     * Intended for non-negative values; zero yields an empty array.
     */
    private fun BigInteger.toByteArrayWithoutSign(): ByteArray {
        val bytes = toByteArray()
        // toByteArray() always returns ceil((bitLength + 1) / 8) bytes, while the magnitude only
        // needs ceil(bitLength / 8). Dropping the difference strips the leading sign byte, if any.
        val magnitudeLength = (bitLength() + 7) / 8
        return bytes.copyOfRange(bytes.size - magnitudeLength, bytes.size)
    }

}

/**
 * Decoded RSA public key.
 *
 * Declared as a data class so that two keys with the same modulus and exponent compare equal.
 *
 * @property n modulus
 * @property e public exponent
 */
internal data class RsaPublicKey(val n: BigInteger, val e: BigInteger)