commit b69d3c141ba465afd383946749dd1a9f1c034300
parent f5c2f7313c2c3bea20af5544a97d76fafad190de
Author: Antoine A <>
Date: Fri, 20 Oct 2023 12:23:03 +0000
Fix TalerAmount logic and add amount_mul function in preparation for conversion logic
Diffstat:
5 files changed, 131 insertions(+), 27 deletions(-)
diff --git a/bank/src/main/kotlin/tech/libeufin/bank/Database.kt b/bank/src/main/kotlin/tech/libeufin/bank/Database.kt
@@ -37,9 +37,7 @@ import tech.libeufin.util.*
private const val DB_CTR_LIMIT = 1000000
fun Customer.expectRowId(): Long = this.dbRowId ?: throw internalServerError("Cutsomer '$login' had no DB row ID.")
-fun BankAccount.expectBalance(): TalerAmount = this.balance ?: throw internalServerError("Bank account '${this.internalPaytoUri}' lacks balance.")
fun BankAccount.expectRowId(): Long = this.bankAccountId ?: throw internalServerError("Bank account '${this.internalPaytoUri}' lacks database row ID.")
-fun BankAccountTransaction.expectRowId(): Long = this.dbRowId ?: throw internalServerError("Bank account transaction (${this.subject}) lacks database row ID.")
private val logger: Logger = LoggerFactory.getLogger("tech.libeufin.bank.Database")
@@ -116,7 +114,7 @@ class Database(dbConfig: String, private val bankCurrency: String): java.io.Clos
dbPool.close()
}
- private suspend fun <R> conn(lambda: suspend (PgConnection) -> R): R {
+ suspend fun <R> conn(lambda: suspend (PgConnection) -> R): R {
// Use a coroutine dispatcher that we can block as JDBC API is blocking
return withContext(Dispatchers.IO) {
val conn = dbPool.getConnection()
diff --git a/bank/src/main/kotlin/tech/libeufin/bank/TalerCommon.kt b/bank/src/main/kotlin/tech/libeufin/bank/TalerCommon.kt
@@ -241,7 +241,7 @@ class TalerAmount {
val (currency, value, frac) = match.destructured
this.currency = currency
this.value = value.toLongOrNull() ?: throw badAmount("Invalid value")
- if (this.value > MAX_SAFE_INTEGER) throw badAmount("Value specified in amount is too large")
+ if (this.value > MAX_VALUE) throw badAmount("Value specified in amount is too large")
this.frac = if (frac.isEmpty()) {
0
} else {
@@ -284,7 +284,8 @@ class TalerAmount {
companion object {
const val FRACTION_BASE = 100000000
- private val PATTERN = Regex("([A-Z]{1,11}):([0-9]+)(?:\\.([0-9]{0,8}))?");
+ const val MAX_VALUE = 4503599627370496L; // 2^52
+ private val PATTERN = Regex("([A-Z]{1,11}):([0-9]+)(?:\\.([0-9]{1,8}))?");
}
}
diff --git a/bank/src/test/kotlin/AmountTest.kt b/bank/src/test/kotlin/AmountTest.kt
@@ -17,10 +17,10 @@
* <http://www.gnu.org/licenses/>
*/
-
import org.junit.Test
import org.postgresql.jdbc.PgConnection
import tech.libeufin.bank.*
+import tech.libeufin.util.*
import kotlin.test.*
import java.time.Instant
import java.util.*
@@ -126,19 +126,19 @@ class AmountTest {
}
@Test
- fun parseValid() {
+ fun parse() {
assertEquals(TalerAmount("EUR:4"), TalerAmount(4L, 0, "EUR"))
assertEquals(TalerAmount("EUR:0.02"), TalerAmount(0L, 2000000, "EUR"))
assertEquals(TalerAmount("EUR:4.12"), TalerAmount(4L, 12000000, "EUR"))
assertEquals(TalerAmount("LOCAL:4444.1000"), TalerAmount(4444L, 10000000, "LOCAL"))
- }
+ assertEquals(TalerAmount("EUR:${TalerAmount.MAX_VALUE}.99999999"), TalerAmount(TalerAmount.MAX_VALUE, 99999999, "EUR"))
- @Test
- fun parseInvalid() {
assertException("Invalid amount format") {TalerAmount("")}
assertException("Invalid amount format") {TalerAmount("EUR")}
assertException("Invalid amount format") {TalerAmount("eur:12")}
assertException("Invalid amount format") {TalerAmount(" EUR:12")}
+ assertException("Invalid amount format") {TalerAmount("EUR:1.")}
+ assertException("Invalid amount format") {TalerAmount("EUR:.1")}
assertException("Invalid amount format") {TalerAmount("AZERTYUIOPQSD:12")}
assertException("Value specified in amount is too large") {TalerAmount("EUR:${Long.MAX_VALUE}")}
assertException("Invalid amount format") {TalerAmount("EUR:4.000000000")}
@@ -151,4 +151,84 @@ class AmountTest {
assertEquals(amount, TalerAmount(amount).toString())
}
}
+
+ @Test
+ fun normalize() = dbSetup { db ->
+ db.conn { conn ->
+ val stmt = conn.prepareStatement("SELECT normalized.val, normalized.frac FROM amount_normalize((?, ?)::taler_amount) as normalized")
+ fun TalerAmount.normalize(): TalerAmount? {
+ stmt.setLong(1, value)
+ stmt.setInt(2, frac)
+ return stmt.oneOrNull {
+ TalerAmount(
+ it.getLong(1),
+ it.getInt(2),
+ "EUR"
+ )
+ }!!
+ }
+
+ assertEquals(TalerAmount("EUR:6"), TalerAmount(4L, 2 * TalerAmount.FRACTION_BASE, "EUR").normalize())
+ assertEquals(TalerAmount("EUR:6.00000001"), TalerAmount(4L, 2 * TalerAmount.FRACTION_BASE + 1, "EUR").normalize())
+ assertEquals(TalerAmount("EUR:${TalerAmount.MAX_VALUE}.99999999"), TalerAmount("EUR:${TalerAmount.MAX_VALUE}.99999999").normalize())
+ assertException("ERROR: bigint out of range") { TalerAmount(Long.MAX_VALUE, TalerAmount.FRACTION_BASE, "EUR").normalize() }
+ assertException("ERROR: amount value overflowed") { TalerAmount(TalerAmount.MAX_VALUE, TalerAmount.FRACTION_BASE , "EUR").normalize() }
+ }
+ }
+
+ @Test
+ fun add() = dbSetup { db ->
+ db.conn { conn ->
+ val stmt = conn.prepareStatement("SELECT sum.val, sum.frac FROM amount_add((?, ?)::taler_amount, (?, ?)::taler_amount) as sum")
+ operator fun TalerAmount.plus(increment: TalerAmount): TalerAmount? {
+ stmt.setLong(1, value)
+ stmt.setInt(2, frac)
+ stmt.setLong(3, increment.value)
+ stmt.setInt(4, increment.frac)
+ return stmt.oneOrNull {
+ TalerAmount(
+ it.getLong(1),
+ it.getInt(2),
+ "EUR"
+ )
+ }!!
+ }
+
+ assertEquals(TalerAmount("EUR:6.41") + TalerAmount("EUR:4.69"), TalerAmount("EUR:11.1"))
+ assertEquals(TalerAmount("EUR:${TalerAmount.MAX_VALUE}") + TalerAmount("EUR:0.99999999"), TalerAmount("EUR:${TalerAmount.MAX_VALUE}.99999999"))
+ assertException("ERROR: amount value overflowed") { TalerAmount(TalerAmount.MAX_VALUE - 5, 0, "EUR") + TalerAmount(6, 0, "EUR") }
+ assertException("ERROR: bigint out of range") { TalerAmount(Long.MAX_VALUE, 0, "EUR") + TalerAmount(1, 0, "EUR") }
+ assertException("ERROR: amount value overflowed") { TalerAmount(TalerAmount.MAX_VALUE - 5, TalerAmount.FRACTION_BASE - 1, "EUR") + TalerAmount(5, 2, "EUR") }
+ assertException("ERROR: integer out of range") { TalerAmount(0, Int.MAX_VALUE, "EUR") + TalerAmount(0, 1, "EUR") }
+ }
+ }
+
+ @Test
+ fun mul() = dbSetup { db ->
+ db.conn { conn ->
+ val stmt = conn.prepareStatement("SELECT product.val, product.frac FROM amount_mul((?, ?)::taler_amount, (?, ?)::taler_amount) as product")
+ operator fun TalerAmount.times(increment: TalerAmount): TalerAmount? {
+ stmt.setLong(1, value)
+ stmt.setInt(2, frac)
+ stmt.setLong(3, increment.value)
+ stmt.setInt(4, increment.frac)
+ return stmt.oneOrNull {
+ TalerAmount(
+ it.getLong(1),
+ it.getInt(2),
+ "EUR"
+ )
+ }!!
+ }
+
+ assertEquals(TalerAmount("EUR:6.41") * TalerAmount("EUR:4.69"), TalerAmount("EUR:30.0629"))
+ assertEquals(TalerAmount("EUR:6.41") * TalerAmount("EUR:1.000001"), TalerAmount("EUR:6.41000641"))
+ assertEquals(TalerAmount("EUR:0.99999999") * TalerAmount("EUR:2.5"), TalerAmount("EUR:2.49999998"))
+ assertEquals(TalerAmount("EUR:${TalerAmount.MAX_VALUE}.99999999") * TalerAmount("EUR:1"), TalerAmount("EUR:${TalerAmount.MAX_VALUE}.99999999"))
+ assertEquals(TalerAmount("EUR:${TalerAmount.MAX_VALUE/4}") * TalerAmount("EUR:4"), TalerAmount("EUR:${TalerAmount.MAX_VALUE}"))
+ assertException("ERROR: amount value overflowed") { TalerAmount(TalerAmount.MAX_VALUE/3, 0, "EUR") * TalerAmount(3, 1, "EUR") }
+ assertException("ERROR: amount value overflowed") { TalerAmount((TalerAmount.MAX_VALUE+2)/2, 0, "EUR") * TalerAmount("EUR:2") }
+ assertException("ERROR: numeric field overflow") { TalerAmount(Long.MAX_VALUE, 0, "EUR") * TalerAmount("EUR:1") }
+ }
+ }
}
\ No newline at end of file
diff --git a/bank/src/test/kotlin/helpers.kt b/bank/src/test/kotlin/helpers.kt
@@ -137,7 +137,7 @@ fun assertException(msg: String, lambda: () -> Unit) {
lambda()
throw Exception("Expected failure")
} catch (e: Exception) {
- assertEquals(msg, e.message)
+ assert(e.message!!.startsWith(msg)) { "${e.message}" }
}
}
diff --git a/database-versioning/procedures.sql b/database-versioning/procedures.sql
@@ -1,34 +1,55 @@
BEGIN;
SET search_path TO libeufin_bank;
-CREATE OR REPLACE PROCEDURE amount_normalize(
+CREATE OR REPLACE FUNCTION amount_normalize(
IN amount taler_amount
- ,INOUT normalized taler_amount
+ ,OUT normalized taler_amount
)
LANGUAGE plpgsql AS $$
BEGIN
normalized.val = amount.val + amount.frac / 100000000;
+ IF (normalized.val > 1::bigint<<52) THEN
+ RAISE EXCEPTION 'amount value overflowed';
+ END IF;
normalized.frac = amount.frac % 100000000;
+
END $$;
-COMMENT ON PROCEDURE amount_normalize
- IS 'Returns the normalized amount by adding to the .val the value of (.frac / 100000000) and removing the modulus 100000000 from .frac.';
+COMMENT ON FUNCTION amount_normalize
+ IS 'Returns the normalized amount by adding to the .val the value of (.frac / 100000000) and removing the modulus 100000000 from .frac.'
+ 'It raises an exception when the resulting .val is larger than 2^52';
-CREATE OR REPLACE PROCEDURE amount_add(
+CREATE OR REPLACE FUNCTION amount_add(
IN a taler_amount
,IN b taler_amount
- ,INOUT sum taler_amount
+ ,OUT sum taler_amount
)
LANGUAGE plpgsql AS $$
BEGIN
sum = (a.val + b.val, a.frac + b.frac);
- CALL amount_normalize(sum ,sum);
- IF sum.val > (1<<52) THEN
- RAISE EXCEPTION 'addition overflow';
- END IF;
+ SELECT normalized.val, normalized.frac INTO sum.val, sum.frac FROM amount_normalize(sum) as normalized;
END $$;
-COMMENT ON PROCEDURE amount_add
+COMMENT ON FUNCTION amount_add
IS 'Returns the normalized sum of two amounts. It raises an exception when the resulting .val is larger than 2^52';
+CREATE OR REPLACE FUNCTION amount_mul(
+ IN a taler_amount
+ ,IN b taler_amount
+ ,OUT product taler_amount
+)
+LANGUAGE plpgsql AS $$
+DECLARE
+tmp NUMERIC(24, 8); -- 16 digit for val and 8 for frac
+BEGIN
+ -- TODO write custom multiplication logic to get more control over rounding
+ tmp = (a.val::numeric(24, 8) + a.frac::numeric(24, 8) / 100000000) * (b.val::numeric(24, 8) + b.frac::numeric(24, 8) / 100000000);
+ product = (trunc(tmp)::bigint, (tmp * 100000000 % 100000000)::int);
+ IF (product.val > 1::bigint<<52) THEN
+ RAISE EXCEPTION 'amount value overflowed';
+ END IF;
+END $$;
+COMMENT ON FUNCTION amount_mul -- TODO document rounding
+ IS 'Returns the product of two amounts. It raises an exception when the resulting .val is larger than 2^52';
+
CREATE OR REPLACE FUNCTION amount_left_minus_right(
IN l taler_amount
,IN r taler_amount
@@ -552,7 +573,9 @@ END IF;
-- check enough funds
IF account_has_debt THEN
-- debt case: simply checking against the max debt allowed.
- CALL amount_add(account_balance, in_amount, account_balance);
+ SELECT sum.val, sum.frac
+ INTO account_balance.val, account_balance.frac
+ FROM amount_add(account_balance, in_amount) as sum;
SELECT NOT ok
INTO out_balance_insufficient
FROM amount_left_minus_right(account_max_debt, account_balance);
@@ -826,9 +849,9 @@ out_creditor_not_found=FALSE;
-- check debtor has enough funds.
IF debtor_has_debt THEN
-- debt case: simply checking against the max debt allowed.
- CALL amount_add(debtor_balance,
- in_amount,
- potential_balance);
+ SELECT sum.val, sum.frac
+ INTO potential_balance.val, potential_balance.frac
+ FROM amount_add(debtor_balance, in_amount) as sum;
SELECT NOT ok
INTO out_balance_insufficient
FROM amount_left_minus_right(debtor_max_debt,
@@ -875,7 +898,9 @@ out_balance_insufficient=FALSE;
-- from debit to a credit situation, and adjust the balance
-- accordingly.
IF NOT creditor_has_debt THEN -- easy case.
- CALL amount_add(creditor_balance, in_amount, new_creditor_balance);
+ SELECT sum.val, sum.frac
+ INTO new_creditor_balance.val, new_creditor_balance.frac
+ FROM amount_add(creditor_balance, in_amount) as sum;
will_creditor_have_debt=FALSE;
ELSE -- creditor had debit but MIGHT switch to credit.
SELECT