commit 2b0447dd2e3ed8f7d8c2e4e163e26c0bfd78103a
parent a2b41d28c5759bf01744de3cffc95ef8226aadc7
Author: Antoine A <>
Date: Thu, 14 Nov 2024 17:00:55 +0100
common: kotlin amount math and code cleanup
Diffstat:
4 files changed, 59 insertions(+), 22 deletions(-)
diff --git a/bank/src/test/kotlin/AmountTest.kt b/bank/src/test/kotlin/AmountTest.kt
@@ -22,7 +22,7 @@ import org.junit.Test
import tech.libeufin.common.*
import tech.libeufin.common.db.*
import tech.libeufin.common.test.*
-import kotlin.test.assertEquals
+import kotlin.test.*
class AmountTest {
// Test amount computation in db
@@ -189,7 +189,7 @@ class AmountTest {
fun normalize() = dbSetup { db ->
db.conn { conn ->
val stmt = conn.prepareStatement("SELECT normalized.val, normalized.frac FROM amount_normalize((?, ?)::taler_amount) as normalized")
- fun TalerAmount.normalize(): TalerAmount {
+ fun TalerAmount.db(): TalerAmount {
stmt.setLong(1, value)
stmt.setInt(2, frac)
return stmt.one {
@@ -200,15 +200,27 @@ class AmountTest {
)
}
}
+
+ fun assertNormalize(from: TalerAmount, to: TalerAmount) {
+ val normalized = from.normalize()
+ assertEquals(to, normalized, "Bad normalization")
+ val dbNormalized = from.db()
+ assertEquals(normalized, dbNormalized, "DB vs code behavior")
+ }
+
+ fun assertErr(from: TalerAmount, msg: String) {
+ assertFails { from.normalize() }
+ assertException(msg) { from.db() }
+ }
- assertEquals(TalerAmount("EUR:6"), TalerAmount(4L, 2 * TalerAmount.FRACTION_BASE, "EUR").normalize())
- assertEquals(TalerAmount("EUR:6.00000001"), TalerAmount(4L, 2 * TalerAmount.FRACTION_BASE + 1, "EUR").normalize())
- assertEquals(TalerAmount("EUR:${TalerAmount.MAX_VALUE}.99999999"), TalerAmount("EUR:${TalerAmount.MAX_VALUE}.99999999").normalize())
- assertException("ERROR: bigint out of range") { TalerAmount(Long.MAX_VALUE, TalerAmount.FRACTION_BASE, "EUR").normalize() }
- assertException("ERROR: amount value overflowed") { TalerAmount(TalerAmount.MAX_VALUE, TalerAmount.FRACTION_BASE , "EUR").normalize() }
+ assertNormalize(TalerAmount(4L, 2 * TalerAmount.FRACTION_BASE, "EUR"), TalerAmount("EUR:6"))
+ assertNormalize(TalerAmount(4L, 2 * TalerAmount.FRACTION_BASE + 1, "EUR"), TalerAmount("EUR:6.00000001"))
+ assertNormalize(TalerAmount("EUR:${TalerAmount.MAX_VALUE}.99999999"), TalerAmount("EUR:${TalerAmount.MAX_VALUE}.99999999"))
+ assertErr(TalerAmount(Long.MAX_VALUE, TalerAmount.FRACTION_BASE, "EUR"), "ERROR: bigint out of range")
+ assertErr(TalerAmount(TalerAmount.MAX_VALUE, TalerAmount.FRACTION_BASE , "EUR"), "ERROR: amount value overflowed")
for (amount in listOf(TalerAmount.max("EUR"), TalerAmount.zero("EUR"))) {
- assertEquals(amount, amount.normalize())
+ assertNormalize(amount, amount)
}
}
}
@@ -217,7 +229,7 @@ class AmountTest {
fun add() = dbSetup { db ->
db.conn { conn ->
val stmt = conn.prepareStatement("SELECT sum.val, sum.frac FROM amount_add((?, ?)::taler_amount, (?, ?)::taler_amount) as sum")
- operator fun TalerAmount.plus(increment: TalerAmount): TalerAmount {
+ fun TalerAmount.db(increment: TalerAmount): TalerAmount {
stmt.setLong(1, value)
stmt.setInt(2, frac)
stmt.setLong(3, increment.value)
@@ -230,14 +242,27 @@ class AmountTest {
)
}
}
- assertEquals(TalerAmount.max("EUR"), TalerAmount.max("EUR") + TalerAmount.zero("EUR"))
- assertEquals(TalerAmount.zero("EUR"), TalerAmount.zero("EUR") + TalerAmount.zero("EUR"))
- assertEquals(TalerAmount("EUR:6.41") + TalerAmount("EUR:4.69"), TalerAmount("EUR:11.1"))
- assertEquals(TalerAmount("EUR:${TalerAmount.MAX_VALUE}") + TalerAmount("EUR:0.99999999"), TalerAmount("EUR:${TalerAmount.MAX_VALUE}.99999999"))
- assertException("ERROR: amount value overflowed") { TalerAmount(TalerAmount.MAX_VALUE - 5, 0, "EUR") + TalerAmount(6, 0, "EUR") }
- assertException("ERROR: bigint out of range") { TalerAmount(Long.MAX_VALUE, 0, "EUR") + TalerAmount(1, 0, "EUR") }
- assertException("ERROR: amount value overflowed") { TalerAmount(TalerAmount.MAX_VALUE - 5, TalerAmount.FRACTION_BASE - 1, "EUR") + TalerAmount(5, 2, "EUR") }
- assertException("ERROR: integer out of range") { TalerAmount(0, Int.MAX_VALUE, "EUR") + TalerAmount(0, 1, "EUR") }
+
+ fun assertAdd(a: TalerAmount, b: TalerAmount, sum: TalerAmount) {
+ val codeSum = a + b
+ assertEquals(sum, codeSum, "Bad sum")
+ val dbSum = a.db(b)
+ assertEquals(codeSum, dbSum, "DB vs code behavior")
+ }
+
+ fun assertErr(a: TalerAmount, b: TalerAmount, msg: String) {
+ assertFails { a + b }
+ assertException(msg) { a.db(b) }
+ }
+
+ assertAdd(TalerAmount.max("EUR"), TalerAmount.zero("EUR"), TalerAmount.max("EUR"))
+ assertAdd(TalerAmount.zero("EUR"), TalerAmount.zero("EUR"), TalerAmount.zero("EUR"))
+ assertAdd(TalerAmount("EUR:6.41"), TalerAmount("EUR:4.69"), TalerAmount("EUR:11.1"))
+ assertAdd(TalerAmount("EUR:${TalerAmount.MAX_VALUE}"), TalerAmount("EUR:0.99999999"), TalerAmount("EUR:${TalerAmount.MAX_VALUE}.99999999"))
+ assertErr(TalerAmount(TalerAmount.MAX_VALUE - 5, 0, "EUR"), TalerAmount(6, 0, "EUR"), "ERROR: amount value overflowed")
+ assertErr(TalerAmount(Long.MAX_VALUE, 0, "EUR"), TalerAmount(1, 0, "EUR"), "ERROR: bigint out of range")
+ assertErr(TalerAmount(TalerAmount.MAX_VALUE - 5, TalerAmount.FRACTION_BASE - 1, "EUR"), TalerAmount(5, 2, "EUR"), "ERROR: amount value overflowed")
+ assertErr(TalerAmount(0, Int.MAX_VALUE, "EUR"), TalerAmount(0, 1, "EUR"), "ERROR: integer out of range")
}
}
diff --git a/common/src/main/kotlin/Encoding.kt b/common/src/main/kotlin/Encoding.kt
@@ -64,11 +64,9 @@ object Base32Crockford {
for (char in encoded) {
// Read input
val index = char - '0'
- if (index < 0 || index > INV.size)
- throw IllegalArgumentException("invalid Base32 character: $char")
+ require(index in 0..INV.size) { "invalid Base32 character: $char" }
val decoded = INV[index]
- if (decoded == -1)
- throw IllegalArgumentException("invalid Base32 character: $char")
+ require(decoded != -1) { "invalid Base32 character: $char" }
buffer = (buffer shl 5) or decoded
bitsLeft += 5
// Write bytes
diff --git a/common/src/main/kotlin/TalerCommon.kt b/common/src/main/kotlin/TalerCommon.kt
@@ -220,6 +220,20 @@ class TalerAmount {
}
}
+ fun normalize(): TalerAmount {
+ val value = Math.addExact(this.value, (this.frac / FRACTION_BASE).toLong())
+ val frac = this.frac % FRACTION_BASE
+ println("${this.value}+${this.frac / FRACTION_BASE}=${value} ${MAX_VALUE}")
+ if (value > MAX_VALUE) throw ArithmeticException("amount value overflowed")
+ return TalerAmount(value, frac, currency)
+ }
+
+ operator fun plus(increment: TalerAmount): TalerAmount {
+ val value = Math.addExact(this.value, increment.value)
+ val frac = Math.addExact(this.frac, increment.frac)
+ return TalerAmount(value, frac, currency).normalize()
+ }
+
internal object Serializer : KSerializer<TalerAmount> {
override val descriptor: SerialDescriptor =
PrimitiveSerialDescriptor("TalerAmount", PrimitiveKind.STRING)
diff --git a/nexus/src/main/kotlin/tech/libeufin/nexus/cli/EbicsSetup.kt b/nexus/src/main/kotlin/tech/libeufin/nexus/cli/EbicsSetup.kt
@@ -74,7 +74,7 @@ private suspend fun submitClientKeys(
ebicsLogger: EbicsLogger,
order: EbicsKeyMng.Order
) {
- if (order == HPB) throw IllegalArgumentException("Only INI & HIA are supported for client keys")
+ require(order != HPB) { "Only INI & HIA are supported for client keys" }
val resp = keyManagement(cfg, privs, client, ebicsLogger, order)
if (resp.technicalCode == EbicsReturnCode.EBICS_INVALID_USER_OR_USER_STATE) {
throw Exception("$order status code ${resp.technicalCode}: either your IDs are incorrect, or you already have keys registered with this bank")