libeufin

Integration and sandbox testing for FinTech APIs and data formats
Log | Files | Refs | Submodules | README | LICENSE

commit 2b0447dd2e3ed8f7d8c2e4e163e26c0bfd78103a
parent a2b41d28c5759bf01744de3cffc95ef8226aadc7
Author: Antoine A <>
Date:   Thu, 14 Nov 2024 17:00:55 +0100

common: kotlin amount math and code cleanup

Diffstat:
Mbank/src/test/kotlin/AmountTest.kt | 59++++++++++++++++++++++++++++++++++++++++++-----------------
Mcommon/src/main/kotlin/Encoding.kt | 6++----
Mcommon/src/main/kotlin/TalerCommon.kt | 14++++++++++++++
Mnexus/src/main/kotlin/tech/libeufin/nexus/cli/EbicsSetup.kt | 2+-
4 files changed, 59 insertions(+), 22 deletions(-)

diff --git a/bank/src/test/kotlin/AmountTest.kt b/bank/src/test/kotlin/AmountTest.kt @@ -22,7 +22,7 @@ import org.junit.Test import tech.libeufin.common.* import tech.libeufin.common.db.* import tech.libeufin.common.test.* -import kotlin.test.assertEquals +import kotlin.test.* class AmountTest { // Test amount computation in db @@ -189,7 +189,7 @@ class AmountTest { fun normalize() = dbSetup { db -> db.conn { conn -> val stmt = conn.prepareStatement("SELECT normalized.val, normalized.frac FROM amount_normalize((?, ?)::taler_amount) as normalized") - fun TalerAmount.normalize(): TalerAmount { + fun TalerAmount.db(): TalerAmount { stmt.setLong(1, value) stmt.setInt(2, frac) return stmt.one { @@ -200,15 +200,27 @@ class AmountTest { ) } } + + fun assertNormalize(from: TalerAmount, to: TalerAmount) { + val normalized = from.normalize() + assertEquals(to, normalized, "Bad normalization") + val dbNormalized = from.db() + assertEquals(normalized, dbNormalized, "DB vs code behavior") + } + + fun assertErr(from: TalerAmount, msg: String) { + assertFails { from.normalize() } + assertException(msg) { from.db() } + } - assertEquals(TalerAmount("EUR:6"), TalerAmount(4L, 2 * TalerAmount.FRACTION_BASE, "EUR").normalize()) - assertEquals(TalerAmount("EUR:6.00000001"), TalerAmount(4L, 2 * TalerAmount.FRACTION_BASE + 1, "EUR").normalize()) - assertEquals(TalerAmount("EUR:${TalerAmount.MAX_VALUE}.99999999"), TalerAmount("EUR:${TalerAmount.MAX_VALUE}.99999999").normalize()) - assertException("ERROR: bigint out of range") { TalerAmount(Long.MAX_VALUE, TalerAmount.FRACTION_BASE, "EUR").normalize() } - assertException("ERROR: amount value overflowed") { TalerAmount(TalerAmount.MAX_VALUE, TalerAmount.FRACTION_BASE , "EUR").normalize() } + assertNormalize(TalerAmount(4L, 2 * TalerAmount.FRACTION_BASE, "EUR"), TalerAmount("EUR:6")) + assertNormalize(TalerAmount(4L, 2 * TalerAmount.FRACTION_BASE + 1, "EUR"), TalerAmount("EUR:6.00000001")) + assertNormalize(TalerAmount("EUR:${TalerAmount.MAX_VALUE}.99999999"), TalerAmount("EUR:${TalerAmount.MAX_VALUE}.99999999")) + assertErr(TalerAmount(Long.MAX_VALUE, TalerAmount.FRACTION_BASE, "EUR"), "ERROR: bigint out of range") + assertErr(TalerAmount(TalerAmount.MAX_VALUE, TalerAmount.FRACTION_BASE , "EUR"), "ERROR: amount value overflowed") for (amount in listOf(TalerAmount.max("EUR"), TalerAmount.zero("EUR"))) { - assertEquals(amount, amount.normalize()) + assertNormalize(amount, amount) } } } @@ -217,7 +229,7 @@ class AmountTest { fun add() = dbSetup { db -> db.conn { conn -> val stmt = conn.prepareStatement("SELECT sum.val, sum.frac FROM amount_add((?, ?)::taler_amount, (?, ?)::taler_amount) as sum") - operator fun TalerAmount.plus(increment: TalerAmount): TalerAmount { + fun TalerAmount.db(increment: TalerAmount): TalerAmount { stmt.setLong(1, value) stmt.setInt(2, frac) stmt.setLong(3, increment.value) @@ -230,14 +242,27 @@ class AmountTest { ) } } - assertEquals(TalerAmount.max("EUR"), TalerAmount.max("EUR") + TalerAmount.zero("EUR")) - assertEquals(TalerAmount.zero("EUR"), TalerAmount.zero("EUR") + TalerAmount.zero("EUR")) - assertEquals(TalerAmount("EUR:6.41") + TalerAmount("EUR:4.69"), TalerAmount("EUR:11.1")) - assertEquals(TalerAmount("EUR:${TalerAmount.MAX_VALUE}") + TalerAmount("EUR:0.99999999"), TalerAmount("EUR:${TalerAmount.MAX_VALUE}.99999999")) - assertException("ERROR: amount value overflowed") { TalerAmount(TalerAmount.MAX_VALUE - 5, 0, "EUR") + TalerAmount(6, 0, "EUR") } - assertException("ERROR: bigint out of range") { TalerAmount(Long.MAX_VALUE, 0, "EUR") + TalerAmount(1, 0, "EUR") } - assertException("ERROR: amount value overflowed") { TalerAmount(TalerAmount.MAX_VALUE - 5, TalerAmount.FRACTION_BASE - 1, "EUR") + TalerAmount(5, 2, "EUR") } - assertException("ERROR: integer out of range") { TalerAmount(0, Int.MAX_VALUE, "EUR") + TalerAmount(0, 1, "EUR") } + + fun assertAdd(a: TalerAmount, b: TalerAmount, sum: TalerAmount) { + val codeSum = a + b + assertEquals(sum, codeSum, "Bad sum") + val dbSum = a.db(b) + assertEquals(codeSum, dbSum, "DB vs code behavior") + } + + fun assertErr(a: TalerAmount, b: TalerAmount, msg: String) { + assertFails { a + b } + assertException(msg) { a.db(b) } + } + + assertAdd(TalerAmount.max("EUR"), TalerAmount.zero("EUR"), TalerAmount.max("EUR")) + assertAdd(TalerAmount.zero("EUR"), TalerAmount.zero("EUR"), TalerAmount.zero("EUR")) + assertAdd(TalerAmount("EUR:6.41"), TalerAmount("EUR:4.69"), TalerAmount("EUR:11.1")) + assertAdd(TalerAmount("EUR:${TalerAmount.MAX_VALUE}"), TalerAmount("EUR:0.99999999"), TalerAmount("EUR:${TalerAmount.MAX_VALUE}.99999999")) + assertErr(TalerAmount(TalerAmount.MAX_VALUE - 5, 0, "EUR"), TalerAmount(6, 0, "EUR"), "ERROR: amount value overflowed") + assertErr(TalerAmount(Long.MAX_VALUE, 0, "EUR"), TalerAmount(1, 0, "EUR"), "ERROR: bigint out of range") + assertErr(TalerAmount(TalerAmount.MAX_VALUE - 5, TalerAmount.FRACTION_BASE - 1, "EUR"), TalerAmount(5, 2, "EUR"), "ERROR: amount value overflowed") + assertErr(TalerAmount(0, Int.MAX_VALUE, "EUR"), TalerAmount(0, 1, "EUR"), "ERROR: integer out of range") } } diff --git a/common/src/main/kotlin/Encoding.kt b/common/src/main/kotlin/Encoding.kt @@ -64,11 +64,9 @@ object Base32Crockford { for (char in encoded) { // Read input val index = char - '0' - if (index < 0 || index > INV.size) - throw IllegalArgumentException("invalid Base32 character: $char") + require(index in 0..INV.size) { "invalid Base32 character: $char" } val decoded = INV[index] - if (decoded == -1) - throw IllegalArgumentException("invalid Base32 character: $char") + require(decoded != -1) { "invalid Base32 character: $char" } buffer = (buffer shl 5) or decoded bitsLeft += 5 // Write bytes diff --git a/common/src/main/kotlin/TalerCommon.kt b/common/src/main/kotlin/TalerCommon.kt @@ -220,6 +220,20 @@ class TalerAmount { } } + fun normalize(): TalerAmount { + val value = Math.addExact(this.value, (this.frac / FRACTION_BASE).toLong()) + val frac = this.frac % FRACTION_BASE + println("${this.value}+${this.frac / FRACTION_BASE}=${value} ${MAX_VALUE}") + if (value > MAX_VALUE) throw ArithmeticException("amount value overflowed") + return TalerAmount(value, frac, currency) + } + + operator fun plus(increment: TalerAmount): TalerAmount { + val value = Math.addExact(this.value, increment.value) + val frac = Math.addExact(this.frac, increment.frac) + return TalerAmount(value, frac, currency).normalize() + } + internal object Serializer : KSerializer<TalerAmount> { override val descriptor: SerialDescriptor = PrimitiveSerialDescriptor("TalerAmount", PrimitiveKind.STRING) diff --git a/nexus/src/main/kotlin/tech/libeufin/nexus/cli/EbicsSetup.kt b/nexus/src/main/kotlin/tech/libeufin/nexus/cli/EbicsSetup.kt @@ -74,7 +74,7 @@ private suspend fun submitClientKeys( ebicsLogger: EbicsLogger, order: EbicsKeyMng.Order ) { - if (order == HPB) throw IllegalArgumentException("Only INI & HIA are supported for client keys") + require(order != HPB) { "Only INI & HIA are supported for client keys" } val resp = keyManagement(cfg, privs, client, ebicsLogger, order) if (resp.technicalCode == EbicsReturnCode.EBICS_INVALID_USER_OR_USER_STATE) { throw Exception("$order status code ${resp.technicalCode}: either your IDs are incorrect, or you already have keys registered with this bank")