libeufin

Integration and sandbox testing for FinTech APIs and data formats
Log | Files | Refs | Submodules | README | LICENSE

commit 9f7a78d2cd08459ec22fadbf57c379ec2269204e
parent 61ae182a17a32f3ee0e42380ef31849a40395886
Author: Antoine A <>
Date:   Fri, 23 May 2025 15:24:54 +0200

common: refactor SQL args binding

Diffstat:
Mbank/src/main/kotlin/tech/libeufin/bank/db/AccountDAO.kt | 125++++++++++++++++++++++++++++++++++---------------------------------------------
Mbank/src/main/kotlin/tech/libeufin/bank/db/CashoutDAO.kt | 27+++++++++++----------------
Mbank/src/main/kotlin/tech/libeufin/bank/db/ConversionDAO.kt | 29+++++++++++++----------------
Mbank/src/main/kotlin/tech/libeufin/bank/db/Database.kt | 6+++---
Mbank/src/main/kotlin/tech/libeufin/bank/db/ExchangeDAO.kt | 45++++++++++++++++++++-------------------------
Mbank/src/main/kotlin/tech/libeufin/bank/db/GcDAO.kt | 16++++++++--------
Mbank/src/main/kotlin/tech/libeufin/bank/db/TanDAO.kt | 58+++++++++++++++++++++++++++++-----------------------------
Mbank/src/main/kotlin/tech/libeufin/bank/db/TokenDAO.kt | 29++++++++++++++---------------
Mbank/src/main/kotlin/tech/libeufin/bank/db/TransactionDAO.kt | 47+++++++++++++++++++++--------------------------
Mbank/src/main/kotlin/tech/libeufin/bank/db/WithdrawalDAO.kt | 99+++++++++++++++++++++++++++++--------------------------------------------------
Mbank/src/test/kotlin/AmountTest.kt | 87+++++++++++++++++++++++++++++++++----------------------------------------------
Mbank/src/test/kotlin/CoreBankApiTest.kt | 13++++++-------
Mbank/src/test/kotlin/DatabaseTest.kt | 41++++++++++++++++++++---------------------
Mbank/src/test/kotlin/GcTest.kt | 4++--
Mbank/src/test/kotlin/StatsTest.kt | 22++++++++++------------
Mcommon/src/main/kotlin/db/DbPool.kt | 6+++---
Mcommon/src/main/kotlin/db/helpers.kt | 13++++++-------
Mcommon/src/main/kotlin/db/schema.kt | 16++++++++--------
Acommon/src/main/kotlin/db/statement.kt | 204+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
Mcommon/src/main/kotlin/db/transaction.kt | 96+++++--------------------------------------------------------------------------
Mcommon/src/main/kotlin/db/types.kt | 12+++++++++++-
Mnexus/src/main/kotlin/tech/libeufin/nexus/db/EbicsDAO.kt | 8+++-----
Mnexus/src/main/kotlin/tech/libeufin/nexus/db/ExchangeDAO.kt | 29+++++++++++++----------------
Mnexus/src/main/kotlin/tech/libeufin/nexus/db/InitiatedDAO.kt | 91+++++++++++++++++++++++++++++++++++++++----------------------------------------
Mnexus/src/main/kotlin/tech/libeufin/nexus/db/KvDAO.kt | 10+++++-----
Mnexus/src/main/kotlin/tech/libeufin/nexus/db/PaymentDAO.kt | 86+++++++++++++++++++++++++++++++++++--------------------------------------------
Mnexus/src/test/kotlin/DatabaseTest.kt | 22+++++++++++-----------
Mtestbench/src/test/kotlin/IntegrationTest.kt | 1-
28 files changed, 636 insertions(+), 606 deletions(-)

diff --git a/bank/src/main/kotlin/tech/libeufin/bank/db/AccountDAO.kt b/bank/src/main/kotlin/tech/libeufin/bank/db/AccountDAO.kt @@ -54,7 +54,7 @@ class AccountDAO(private val db: Database) { checkPaytoIdempotent: Boolean, pwCrypto: PwCrypto ): AccountCreationResult = db.serializableTransaction { conn -> - val timestamp = Instant.now().micros() + val timestamp = Instant.now() val idempotent = conn.withStatement(""" SELECT password_hash, name=? AND email IS NOT DISTINCT FROM ? @@ -65,31 +65,25 @@ class AccountDAO(private val db: Database) { AND is_public=? AND is_taler_exchange=? AND max_debt=(?,?)::taler_amount - AND ${if (minCashout == null) "min_cashout IS NULL" else "min_cashout IS NOT DISTINCT FROM (?,?)::taler_amount"} + AND min_cashout IS NOT DISTINCT FROM ${optAmount(minCashout)} ,internal_payto, name FROM customers JOIN bank_accounts ON customer_id=owning_customer_id WHERE username=? """) { - setString(1, name) - setString(2, email) - setString(3, phone) - setString(4, cashoutPayto?.simple()) - setString(5, tanChannel?.name) - setBoolean(6, checkPaytoIdempotent) - setString(7, internalPayto.canonical) - setBoolean(8, isPublic) - setBoolean(9, isTalerExchange) - setLong(10, maxDebt.value) - setInt(11, maxDebt.frac) - if (minCashout != null) { - setLong(12, minCashout.value) - setInt(13, minCashout.frac) - setString(14, username) - } else { - setString(12, username) - } + bind(name) + bind(email) + bind(phone) + bind(cashoutPayto?.simple()) + bind(tanChannel) + bind(checkPaytoIdempotent) + bind(internalPayto.canonical) + bind(isPublic) + bind(isTalerExchange) + bind(maxDebt) + bind(minCashout) + bind(username) oneOrNull { Pair( pwCrypto.checkpw(password, it.getString(1)).match && it.getBoolean(2), @@ -112,8 +106,8 @@ class AccountDAO(private val db: Database) { ,creation_time ) VALUES (?, ?) """) { - setString(1, internalPayto.iban.value) - setLong(2, timestamp) + bind(internalPayto.iban.value) + bind(timestamp) if (!executeUpdateViolation()) { conn.rollback() return@serializableTransaction AccountCreationResult.PayToReuse @@ -133,13 +127,13 @@ class AccountDAO(private val db: Database) { RETURNING customer_id """ ) { - setString(1, username) - setString(2, pwCrypto.hashpw(password)) - setString(3, name) - setString(4, email) - setString(5, phone) - setString(6, cashoutPayto?.simple()) - setString(7, tanChannel?.name) + bind(username) + bind(pwCrypto.hashpw(password)) + bind(name) + bind(email) + bind(phone) + bind(cashoutPayto?.simple()) + bind(tanChannel) one { it.getLong("customer_id") } } @@ -151,18 +145,14 @@ class AccountDAO(private val db: Database) { ,is_taler_exchange ,max_debt ,min_cashout - ) VALUES (?, ?, ?, ?, (?, ?)::taler_amount, ${if (minCashout == null) "NULL" else "(?, ?)::taler_amount"}) + ) VALUES (?, ?, ?, ?, (?, ?)::taler_amount, ${optAmount(minCashout)}) """) { - setString(1, internalPayto.canonical) - setLong(2, customerId) - setBoolean(3, isPublic) - setBoolean(4, isTalerExchange) - setLong(5, maxDebt.value) - setInt(6, maxDebt.frac) - if (minCashout != null) { - setLong(7, minCashout.value) - setInt(8, minCashout.frac) - } + bind(internalPayto.canonical) + bind(customerId) + bind(isPublic) + bind(isTalerExchange) + bind(maxDebt) + bind(minCashout) if (!executeUpdateViolation()) { conn.rollback() return@serializableTransaction AccountCreationResult.PayToReuse @@ -174,10 +164,9 @@ class AccountDAO(private val db: Database) { SELECT out_balance_insufficient FROM bank_transaction(?,'admin','bonus',(?,?)::taler_amount,?,true,NULL,NULL,NULL,NULL) """) { - setString(1, internalPayto.canonical) - setLong(2, bonus.value) - setInt(3, bonus.frac) - setLong(4, timestamp) + bind(internalPayto.canonical) + bind(bonus) + bind(timestamp) one { when { it.getBoolean("out_balance_insufficient") -> { @@ -215,9 +204,9 @@ class AccountDAO(private val db: Database) { FROM account_delete(?,?,?) """ ) { - setString(1, username) - setLong(2, Instant.now().micros()) - setBoolean(3, is2fa) + bind(username) + bind(Instant.now()) + bind(is2fa) one { when { it.getBoolean("out_not_found") -> AccountDeletionResult.UnknownAccount @@ -287,7 +276,7 @@ class AccountDAO(private val db: Database) { ON customer_id=owning_customer_id WHERE username=? AND deleted_at IS NULL """) { - setString(1, username) + bind(username) oneOrNull { CurrentAccount( id = it.getLong("customer_id"), @@ -349,8 +338,8 @@ class AccountDAO(private val db: Database) { // Invalidate current challenges if (patchChannel != null || patchInfo != null) { conn.withStatement("UPDATE tan_challenges SET expiration_date=0 WHERE customer=?") { - setLong(1, curr.id) - execute() + bind(curr.id) + executeUpdate() } } @@ -427,8 +416,8 @@ class AccountDAO(private val db: Database) { SELECT customer_id, password_hash, (NOT ? AND tan_channel IS NOT NULL) FROM customers WHERE username=? AND deleted_at IS NULL """) { - setBoolean(1, is2fa) - setString(2, username) + bind(is2fa) + bind(username) oneOrNull { Triple(it.getLong(1), it.getString(2), it.getBoolean(3)) } ?: return@serializableTransaction AccountPatchAuthResult.UnknownAccount @@ -440,8 +429,8 @@ class AccountDAO(private val db: Database) { } else { val newPwh = pwCrypto.hashpw(newPw.pw) conn.withStatement("UPDATE customers SET password_hash=?, token_creation_counter=0 WHERE customer_id=?") { - setString(1, newPwh) - setLong(2, customerId) + bind(newPwh) + bind(customerId) executeUpdate() } @@ -463,7 +452,7 @@ class AccountDAO(private val db: Database) { val info = db.serializable( "SELECT customer_id, password_hash, token_creation_counter FROM customers WHERE username=? AND deleted_at IS NULL" ) { - setString(1, username) + bind(username) oneOrNull { Triple(it.getLong(1), it.getString(2), it.getInt(3)) } @@ -484,9 +473,9 @@ class AccountDAO(private val db: Database) { db.serializable( "UPDATE customers SET password_hash=? where customer_id=? AND password_hash=?" ) { - setString(1, newPwh) - setLong(2, customerId) - setString(3, currentPwh) + bind(newPwh) + bind(customerId) + bind(currentPwh) executeUpdate() } } @@ -507,7 +496,7 @@ class AccountDAO(private val db: Database) { WHERE username=? """ ) { - setString(1, username) + bind(username) oneOrNull { BankInfo( payto = it.getBankPayto("internal_payto", "name", db.ctx), @@ -523,7 +512,7 @@ class AccountDAO(private val db: Database) { SELECT FROM bank_accounts WHERE internal_payto=? """ ) { - setString(1, payto.canonical) + bind(payto.canonical) oneOrNull { AccountInfo() } @@ -559,8 +548,8 @@ class AccountDAO(private val db: Database) { WHERE username=? """ ) { - setInt(1, MAX_TOKEN_CREATION_ATTEMPTS) - setString(2, username) + bind(MAX_TOKEN_CREATION_ATTEMPTS) + bind(username) oneOrNull { val name = it.getString("name") val status: AccountStatus = it.getEnum("status") @@ -615,10 +604,7 @@ class AccountDAO(private val db: Database) { """, { if (params.usernameFilter != null) { - setString(1, params.usernameFilter) - 1 - } else { - 0 + bind(params.usernameFilter) } } ) { @@ -668,12 +654,9 @@ class AccountDAO(private val db: Database) { WHERE ${if (params.usernameFilter != null) "name LIKE ? AND" else ""} """, { - setInt(1, MAX_TOKEN_CREATION_ATTEMPTS) + bind(MAX_TOKEN_CREATION_ATTEMPTS) if (params.usernameFilter != null) { - setString(2, params.usernameFilter) - 2 - } else { - 1 + bind(params.usernameFilter) } } ) { diff --git a/bank/src/main/kotlin/tech/libeufin/bank/db/CashoutDAO.kt b/bank/src/main/kotlin/tech/libeufin/bank/db/CashoutDAO.kt @@ -1,6 +1,6 @@ /* * This file is part of LibEuFin. - * Copyright (C) 2023-2024 Taler Systems S.A. + * Copyright (C) 2023-2025 Taler Systems S.A. * LibEuFin is free software; you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as @@ -63,15 +63,13 @@ class CashoutDAO(private val db: Database) { FROM cashout_create(?,?,(?,?)::taler_amount,(?,?)::taler_amount,?,?,?) """ ) { - setString(1, username) - setBytes(2, requestUid.raw) - setLong(3, amountDebit.value) - setInt(4, amountDebit.frac) - setLong(5, amountCredit.value) - setInt(6, amountCredit.frac) - setString(7, subject) - setLong(8, timestamp.micros()) - setBoolean(9, is2fa) + bind(username) + bind(requestUid) + bind(amountDebit) + bind(amountCredit) + bind(subject) + bind(timestamp) + bind(is2fa) one { when { it.getBoolean("out_under_min") -> CashoutCreationResult.UnderMin @@ -110,8 +108,8 @@ class CashoutDAO(private val db: Database) { WHERE cashout_id=? AND username=? """ ) { - setLong(1, id) - setString(2, username) + bind(id) + bind(username) oneOrNull { CashoutStatusResponse( status = CashoutStatus.confirmed, @@ -156,10 +154,7 @@ class CashoutDAO(private val db: Database) { WHERE deleted_at IS NULL AND username = ? ) AND """, - bind = { - setString(1, username) - 1 - } + args = { bind(username) } ) { CashoutInfo( cashout_id = it.getLong("cashout_id"), diff --git a/bank/src/main/kotlin/tech/libeufin/bank/db/ConversionDAO.kt b/bank/src/main/kotlin/tech/libeufin/bank/db/ConversionDAO.kt @@ -1,6 +1,6 @@ /* * This file is part of LibEuFin. - * Copyright (C) 2023-2024 Taler Systems S.A. + * Copyright (C) 2023-2025 Taler Systems S.A. * LibEuFin is free software; you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as @@ -34,9 +34,8 @@ class ConversionDAO(private val db: Database) { Pair("cashin_ratio", cfg.cashin_ratio), Pair("cashout_ratio", cfg.cashout_ratio), )) { - setString(1, name) - setLong(2, amount.value) - setInt(3, amount.frac) + bind(name) + bind(amount) executeUpdate() } for ((name, amount) in listOf( @@ -47,9 +46,8 @@ class ConversionDAO(private val db: Database) { Pair("cashout_tiny_amount", cfg.cashout_tiny_amount), Pair("cashout_min_amount", cfg.cashout_min_amount), )) { - setString(1, name) - setLong(2, amount.value) - setInt(3, amount.frac) + bind(name) + bind(amount) executeUpdate() } } @@ -59,8 +57,8 @@ class ConversionDAO(private val db: Database) { Pair("cashin_rounding_mode", cfg.cashin_rounding_mode), Pair("cashout_rounding_mode", cfg.cashout_rounding_mode) )) { - setString(1, name) - setString(2, value.name) + bind(name) + bind(value) executeUpdate() } } @@ -72,15 +70,15 @@ class ConversionDAO(private val db: Database) { one { it.getBoolean(1) } } if (!check) return@serializableTransaction null - val amount = conn.prepareStatement("SELECT (amount).val as amount_val, (amount).frac as amount_frac FROM config_get_amount(?) as amount") - val roundingMode = conn.prepareStatement("SELECT config_get_rounding_mode(?)") + val amount = conn.talerStatement("SELECT (amount).val as amount_val, (amount).frac as amount_frac FROM config_get_amount(?) as amount") + val roundingMode = conn.talerStatement("SELECT config_get_rounding_mode(?)") fun getAmount(name: String, currency: String): TalerAmount { - amount.setString(1, name) + amount.bind(name) return amount.one { it.getAmount("amount", currency) } } fun getRatio(name: String): DecimalNumber = getAmount(name, "").run { DecimalNumber(value, frac) } fun getMode(name: String): RoundingMode { - roundingMode.setString(1, name) + roundingMode.bind(name) return roundingMode.one { it.getEnum<RoundingMode>(1) } } val rate = ConversionRate( @@ -118,9 +116,8 @@ class ConversionDAO(private val db: Database) { private suspend fun conversion(amount: TalerAmount, direction: String, function: String): ConversionResult = db.serializable( "SELECT too_small, no_config, (converted).val AS amount_val, (converted).frac AS amount_frac FROM $function((?, ?)::taler_amount, ?, (0, 0)::taler_amount)" ) { - setLong(1, amount.value) - setInt(2, amount.frac) - setString(3, direction) + bind(amount) + bind(direction) one { when { it.getBoolean("no_config") -> ConversionResult.MissingConfig diff --git a/bank/src/main/kotlin/tech/libeufin/bank/db/Database.kt b/bank/src/main/kotlin/tech/libeufin/bank/db/Database.kt @@ -1,6 +1,6 @@ /* * This file is part of LibEuFin. - * Copyright (C) 2023-2024 Taler Systems S.A. + * Copyright (C) 2023-2025 Taler Systems S.A. * LibEuFin is free software; you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as @@ -132,8 +132,8 @@ class Database( FROM stats_get_frame(?::timestamp, ?::stat_timeframe_enum) """ ) { - setObject(1, params.date) - setString(2, params.timeframe.name) + bind(params.date) + bind(params.timeframe) oneOrNull { fiatCurrency?.run { MonitorWithConversion( diff --git a/bank/src/main/kotlin/tech/libeufin/bank/db/ExchangeDAO.kt b/bank/src/main/kotlin/tech/libeufin/bank/db/ExchangeDAO.kt @@ -134,15 +134,14 @@ class ExchangeDAO(private val db: Database) { ) { val subject = "${req.wtid} ${req.exchange_base_url.url}" - setBytes(1, req.request_uid.raw) - setBytes(2, req.wtid.raw) - setString(3, subject) - setLong(4, req.amount.value) - setInt(5, req.amount.frac) - setString(6, req.exchange_base_url.url) - setString(7, req.credit_account.canonical) - setString(8, username) - setLong(9, timestamp.micros()) + bind(req.request_uid) + bind(req.wtid) + bind(subject) + bind(req.amount) + bind(req.exchange_base_url.url) + bind(req.credit_account.canonical) + bind(username) + bind(timestamp) one { when { @@ -179,8 +178,8 @@ class ExchangeDAO(private val db: Database) { WHERE transfer_operation_id=? AND exchange_id=? """ ) { - setLong(1, txId) - setLong(2, exchangeId) + bind(txId) + bind(exchangeId) oneOrNull { TransferStatus( status = it.getEnum<TransferStatusState>("status"), @@ -219,12 +218,9 @@ class ExchangeDAO(private val db: Database) { } """, { - setLong(1, exchangeId) - if (status == null) { - 1 - } else { - setString(2, status.name) - 2 + bind(exchangeId) + if (status != null) { + bind(status) } } ) { @@ -274,14 +270,13 @@ class ExchangeDAO(private val db: Database) { ); """ ) { - setBytes(1, metadata.key.raw) - setString(2, subject) - setLong(3, amount.value) - setInt(4, amount.frac) - setString(5, debitAccount.canonical) - setString(6, username) - setLong(7, timestamp.micros()) - setString(8, metadata.type.name) + bind(metadata.key) + bind(subject) + bind(amount) + bind(debitAccount.canonical) + bind(username) + bind(timestamp) + bind(metadata.type) one { when { diff --git a/bank/src/main/kotlin/tech/libeufin/bank/db/GcDAO.kt b/bank/src/main/kotlin/tech/libeufin/bank/db/GcDAO.kt @@ -1,6 +1,6 @@ /* * This file is part of LibEuFin. - * Copyright (C) 2024 Taler Systems S.A. + * Copyright (C) 2024-2025 Taler Systems S.A. * LibEuFin is free software; you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as @@ -41,8 +41,8 @@ class GcDAO(private val db: Database) { conn.withStatement( "UPDATE taler_withdrawal_operations SET aborted = true WHERE creation_date < ?" ) { - setLong(1, abortAfterMicro) - execute() + bind(abortAfterMicro) + executeUpdate() } // Clean aborted operations, expired challenges and expired tokens @@ -52,8 +52,8 @@ class GcDAO(private val db: Database) { "DELETE FROM bearer_tokens WHERE expiration_time < ?" )) { conn.withStatement(smt) { - setLong(1, cleanAfterMicro) - execute() + bind(cleanAfterMicro) + executeUpdate() } } @@ -61,8 +61,8 @@ class GcDAO(private val db: Database) { conn.withStatement( "DELETE FROM bank_account_transactions WHERE transaction_date < ?" ) { - setLong(1, deleteAfterMicro) - execute() + bind(deleteAfterMicro) + executeUpdate() } // Hard delete soft deleted customer without bank transactions, bank account are deleted by CASCADE @@ -72,7 +72,7 @@ class GcDAO(private val db: Database) { WHERE owning_customer_id=customer_id ) """) { - execute() + executeUpdate() } // TODO clean stats diff --git a/bank/src/main/kotlin/tech/libeufin/bank/db/TanDAO.kt b/bank/src/main/kotlin/tech/libeufin/bank/db/TanDAO.kt @@ -1,6 +1,6 @@ /* * This file is part of LibEuFin. - * Copyright (C) 2023-2024 Taler Systems S.A. + * Copyright (C) 2023-2025 Taler Systems S.A. * LibEuFin is free software; you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as @@ -44,15 +44,15 @@ class TanDAO(private val db: Database) { ): Long = db.serializable( "SELECT tan_challenge_create(?,?::op_enum,?,?,?,?,?,?::tan_enum,?)" ) { - setString(1, body) - setString(2, op.name) - setString(3, code) - setLong(4, timestamp.micros()) - setLong(5, TimeUnit.MICROSECONDS.convert(validityPeriod)) - setInt(6, retryCounter) - setString(7, username) - setString(8, channel?.name) - setString(9, info) + bind(body) + bind(op) + bind(code) + bind(timestamp) + bind(TimeUnit.MICROSECONDS.convert(validityPeriod)) + bind(retryCounter) + bind(username) + bind(channel) + bind(info) oneOrNull { it.getLong(1) } ?: throw internalServerError("TAN challenge returned nothing.") @@ -82,14 +82,14 @@ class TanDAO(private val db: Database) { FROM tan_challenge_send(?,?,?,?,?,?,?,?) """ ) { - setLong(1, id) - setString(2, username) - setString(3, code) - setLong(4, timestamp.micros()) - setLong(5, TimeUnit.MICROSECONDS.convert(validityPeriod)) - setInt(6, retryCounter) - setBoolean(7, isAuth) - setInt(8, maxActive) + bind(id) + bind(username) + bind(code) + bind(timestamp) + bind(TimeUnit.MICROSECONDS.convert(validityPeriod)) + bind(retryCounter) + bind(isAuth) + bind(maxActive) one { when { it.getBoolean("out_no_op") -> TanSendResult.NotFound @@ -112,9 +112,9 @@ class TanDAO(private val db: Database) { ) = db.serializable( "SELECT tan_challenge_mark_sent(?,?,?)" ) { - setLong(1, id) - setLong(2, timestamp.micros()) - setLong(3, TimeUnit.MICROSECONDS.convert(retransmissionPeriod)) + bind(id) + bind(timestamp) + bind(TimeUnit.MICROSECONDS.convert(retransmissionPeriod)) executeQuery() } @@ -144,11 +144,11 @@ class TanDAO(private val db: Database) { FROM tan_challenge_try(?,?,?,?,?) """ ) { - setLong(1, id) - setString(2, username) - setString(3, code) - setLong(4, timestamp.micros()) - setBoolean(5, isAuth) + bind(id) + bind(username) + bind(code) + bind(timestamp.micros()) + bind(isAuth) one { when { it.getBoolean("out_ok") -> TanSolveResult.Success( @@ -185,9 +185,9 @@ class TanDAO(private val db: Database) { WHERE challenge_id=? AND op=?::op_enum AND username=? AND deleted_at IS NULL """ ) { - setLong(1, id) - setString(2, op.name) - setString(3, username) + bind(id) + bind(op) + bind(username) oneOrNull { Challenge( body = it.getString("body"), diff --git a/bank/src/main/kotlin/tech/libeufin/bank/db/TokenDAO.kt b/bank/src/main/kotlin/tech/libeufin/bank/db/TokenDAO.kt @@ -1,6 +1,6 @@ /* * This file is part of LibEuFin. - * Copyright (C) 2023-2024 Taler Systems S.A. + * Copyright (C) 2023-2025 Taler Systems S.A. * LibEuFin is free software; you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as @@ -53,14 +53,14 @@ class TokenDAO(private val db: Database) { ) """ ) { - setString(1, username) - setBytes(2, content) - setLong(3, creationTime.micros()) - setLong(4, expirationTime.micros()) - setString(5, scope.name) - setBoolean(6, isRefreshable) - setString(7, description) - setBoolean(8, is2fa) + bind(username) + bind(content) + bind(creationTime) + bind(expirationTime) + bind(scope) + bind(isRefreshable) + bind(description) + bind(is2fa) one { when { it.getBoolean("out_tan_required") -> TokenCreationResult.TanRequired @@ -84,8 +84,8 @@ class TokenDAO(private val db: Database) { is_refreshable """ ) { - setLong(1, accessTime.micros()) - setBytes(2, token) + bind(accessTime) + bind(token) oneOrNull { BearerToken( creationTime = it.getLong("creation_time").asInstant(), @@ -101,8 +101,8 @@ class TokenDAO(private val db: Database) { suspend fun delete(token: ByteArray) = db.serializable( "DELETE FROM bearer_tokens WHERE content = ?" ) { - setBytes(1, token) - execute() + bind(token) + executeUpdate() } /** Get a page of all tokens of [username] accounts */ @@ -125,8 +125,7 @@ class TokenDAO(private val db: Database) { AND """, { - setString(1, username) - 1 + bind(username) } ) { TokenInfo( diff --git a/bank/src/main/kotlin/tech/libeufin/bank/db/TransactionDAO.kt b/bank/src/main/kotlin/tech/libeufin/bank/db/TransactionDAO.kt @@ -78,20 +78,16 @@ class TransactionDAO(private val db: Database) { FROM bank_transaction(?,?,?,(?,?)::taler_amount,?,?,?,(?,?)::taler_amount,(?,?)::taler_amount,(?,?)::taler_amount) """ ) { - setString(1, creditAccountPayto.canonical) - setString(2, debitAccountUsername) - setString(3, subject) - setLong(4, amount.value) - setInt(5, amount.frac) - setLong(6, timestamp) - setBoolean(7, is2fa) - setBytes(8, requestUid?.raw) - setLong(9, wireTransferFees.value) - setInt(10, wireTransferFees.frac) - setLong(11, minAmount.value) - setInt(12, minAmount.frac) - setLong(13, maxAmount.value) - setInt(14, maxAmount.frac) + bind(creditAccountPayto.canonical) + bind(debitAccountUsername) + bind(subject) + bind(amount) + bind(timestamp) + bind(is2fa) + bind(requestUid) + bind(wireTransferFees) + bind(minAmount) + bind(maxAmount) one { when { it.getBoolean("out_creditor_not_found") -> BankTransactionResult.UnknownCreditor @@ -119,10 +115,10 @@ class TransactionDAO(private val db: Database) { "unsupported admin balance adjust" } else { val registered = conn.withStatement("CALL register_incoming(?, ?::taler_incoming_type, ?, ?)") { - setLong(1, creditRowId) - setString(2, metadata.type.name) - setBytes(3, metadata.key.raw) - setLong(4, creditAccountId) + bind(creditRowId) + bind(metadata.type) + bind(metadata.key) + bind(creditAccountId) executeProcedureViolation() } if (!registered) { @@ -145,12 +141,11 @@ class TransactionDAO(private val db: Database) { ?, ?, ?, (?, ?)::taler_amount, ?, NULL,NULL,NULL ); """) { - setLong(1, debitAccountId) - setLong(2, creditAccountId) - setString(3, "Bounce $creditRowId: $bounceCause") - setLong(4, amount.value) - setInt(5, amount.frac) - setLong(6, timestamp) + bind(debitAccountId) + bind(creditAccountId) + bind("Bounce $creditRowId: $bounceCause") + bind(amount) + bind(timestamp) executeQuery() } } @@ -184,8 +179,8 @@ class TransactionDAO(private val db: Database) { WHERE bank_transaction_id=? AND username=? """ ) { - setLong(1, rowId) - setString(2, username) + bind(rowId) + bind(username) oneOrNull { BankAccountTransactionInfo( creditor_payto_uri = it.getBankPayto("creditor_payto", "creditor_name", db.ctx), diff --git a/bank/src/main/kotlin/tech/libeufin/bank/db/WithdrawalDAO.kt b/bank/src/main/kotlin/tech/libeufin/bank/db/WithdrawalDAO.kt @@ -63,33 +63,22 @@ class WithdrawalDAO(private val db: Database) { out_bad_amount FROM create_taler_withdrawal( ?,?, - ${if (amount != null) "(?,?)::taler_amount" else "NULL"}, - ${if (suggested_amount != null) "(?,?)::taler_amount" else "NULL"}, + ${optAmount(amount)}, + ${optAmount(suggested_amount)}, ?, ?, (?, ?)::taler_amount, (?, ?)::taler_amount, (?, ?)::taler_amount ); """ ) { - setString(1, username) - setObject(2, uuid) - var id = 3 - if (amount != null) { - setLong(id, amount.value) - setInt(id+1, amount.frac) - id += 2 - } - if (suggested_amount != null) { - setLong(id, suggested_amount.value) - setInt(id+1, suggested_amount.frac) - id += 2 - } - setBoolean(id, no_amount_to_wallet) - setLong(id+1, timestamp.micros()) - setLong(id+2, wireTransferFees.value) - setInt(id+3, wireTransferFees.frac) - setLong(id+4, minAmount.value) - setInt(id+5, minAmount.frac) - setLong(id+6, maxAmount.value) - setInt(id+7, maxAmount.frac) + + bind(username) + bind(uuid) + bind(amount) + bind(suggested_amount) + bind(no_amount_to_wallet) + bind(timestamp.micros()) + bind(wireTransferFees) + bind(minAmount) + bind(maxAmount) one { when { it.getBoolean("out_account_not_found") -> WithdrawalCreationResult.UnknownAccount @@ -113,8 +102,8 @@ class WithdrawalDAO(private val db: Database) { FROM abort_taler_withdrawal(?, ?) """ ) { - setObject(1, uuid) - setString(2, username) + bind(uuid) + bind(username) one { when { it.getBoolean("out_no_op") -> AbortResult.UnknownOperation @@ -162,27 +151,19 @@ class WithdrawalDAO(private val db: Database) { out_aborted FROM select_taler_withdrawal( ?, ?, ?, ?, - ${if (amount != null) "(?, ?)::taler_amount" else "NULL"}, + ${optAmount(amount)}, (?,?)::taler_amount, (?,?)::taler_amount, (?,?)::taler_amount ); """ ) { - setObject(1, uuid) - setBytes(2, reservePub.raw) - setString(3, "Taler withdrawal $reservePub") - setString(4, exchangePayto.canonical) - var id = 5 - if (amount != null) { - setLong(id, amount.value) - setInt(id+1, amount.frac) - id += 2 - } - setLong(id, wireTransferFees.value) - setInt(id+1, wireTransferFees.frac) - setLong(id+2, minAmount.value) - setInt(id+3, minAmount.frac) - setLong(id+4, maxAmount.value) - setInt(id+5, maxAmount.frac) + bind(uuid) + bind(reservePub) + bind("Taler withdrawal $reservePub") + bind(exchangePayto.canonical) + bind(amount) + bind(wireTransferFees) + bind(minAmount) + bind(maxAmount) one { when { it.getBoolean("out_aborted") -> WithdrawalSelectionResult.AlreadyAborted @@ -236,25 +217,18 @@ class WithdrawalDAO(private val db: Database) { out_missing_amount, out_amount_differs FROM confirm_taler_withdrawal( - ?,?,?,?,(?,?)::taler_amount,(?,?)::taler_amount,(?,?)::taler_amount, - ${if (amount != null) "(?, ?)::taler_amount" else "NULL"} + ?,?,?,?,(?,?)::taler_amount,(?,?)::taler_amount,(?,?)::taler_amount, ${optAmount(amount)} ); """ ) { - setString(1, username) - setObject(2, uuid) - setLong(3, timestamp.micros()) - setBoolean(4, is2fa) - setLong(5, wireTransferFees.value) - setInt(6, wireTransferFees.frac) - setLong(7, minAmount.value) - setInt(8, minAmount.frac) - setLong(9, maxAmount.value) - setInt(10, maxAmount.frac) - if (amount != null) { - setLong(11, amount.value) - setInt(12, amount.frac) - } + bind(username) + bind(uuid) + bind(timestamp) + bind(is2fa) + bind(wireTransferFees) + bind(minAmount) + bind(maxAmount) + bind(amount) one { when { it.getBoolean("out_no_op") -> WithdrawalConfirmationResult.UnknownOperation @@ -281,7 +255,7 @@ class WithdrawalDAO(private val db: Database) { WHERE withdrawal_uuid=? """ ) { - setObject(1, uuid) + bind(uuid) oneOrNull { it.getString(1) } } @@ -346,7 +320,7 @@ class WithdrawalDAO(private val db: Database) { WHERE withdrawal_uuid=? """ ) { - setObject(1, uuid) + bind(uuid) oneOrNull { WithdrawalPublicInfo( status = it.getEnum("status"), @@ -399,9 +373,8 @@ class WithdrawalDAO(private val db: Database) { WHERE withdrawal_uuid=? """ ) { - setLong(1, maxAmount.value) - setInt(2, maxAmount.frac) - setObject(3, uuid) + bind(maxAmount) + bind(uuid) oneOrNull { BankWithdrawalOperationStatus( status = it.getEnum("status"), diff --git a/bank/src/test/kotlin/AmountTest.kt b/bank/src/test/kotlin/AmountTest.kt @@ -1,6 +1,6 @@ /* * This file is part of LibEuFin. - * Copyright (C) 2023 Taler Systems S.A. + * Copyright (C) 2023-2025 Taler Systems S.A. * LibEuFin is free software; you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as @@ -29,7 +29,7 @@ class AmountTest { @Test fun computationTest() = bankSetup { db -> db.conn { conn -> conn.execSQLUpdate("UPDATE libeufin_bank.bank_accounts SET balance.val = 100000 WHERE internal_payto = '${customerPayto.canonical}'") - val stmt = conn.prepareStatement(""" + val stmt = conn.talerStatement(""" UPDATE libeufin_bank.bank_accounts SET balance = (?, ?)::taler_amount ,has_debt = ? @@ -42,11 +42,9 @@ class AmountTest { maxDebt: TalerAmount, amount: TalerAmount ): Boolean { - stmt.setLong(1, balance.value) - stmt.setInt(2, balance.frac) - stmt.setBoolean(3, hasDebt) - stmt.setLong(4, maxDebt.value) - stmt.setInt(5, maxDebt.frac) + stmt.bind(balance) + stmt.bind(hasDebt) + stmt.bind(maxDebt) // Check bank transaction stmt.executeUpdate() @@ -63,6 +61,9 @@ class AmountTest { } // Check whithdraw + stmt.bind(balance) + stmt.bind(hasDebt) + stmt.bind(maxDebt) stmt.executeUpdate() for ((amount, suggested) in listOf(Pair(amount, null), Pair(null, amount), Pair(amount, amount))) { val wRes = client.postA("/accounts/merchant/withdrawals") { @@ -130,37 +131,31 @@ class AmountTest { // Max withdrawal amount computation in db @Test fun maxComputationTest() = bankSetup { db -> db.conn { conn -> - val update = conn.prepareStatement(""" + val update = conn.talerStatement(""" UPDATE libeufin_bank.bank_accounts SET balance = (?, ?)::taler_amount ,has_debt = ? ,max_debt = (?, ?)::taler_amount WHERE bank_account_id = 1 """) - val select = conn.prepareStatement(""" + val select = conn.talerStatement(""" SELECT (max_amount).val as max_amount_val ,(max_amount).frac as max_amount_frac FROM account_max_amount(1, (?, ?)::taler_amount) AS max_amount """) - select.apply { - val max = TalerAmount.max("KUDOS") - setLong(1, max.value) - setInt(2, max.frac) - } suspend fun routine( balance: TalerAmount, hasDebt: Boolean, maxDebt: TalerAmount ): TalerAmount { update.apply { - setLong(1, balance.value) - setInt(2, balance.frac) - setBoolean(3, hasDebt) - setLong(4, maxDebt.value) - setInt(5, maxDebt.frac) + bind(balance) + bind(hasDebt) + bind(maxDebt) executeUpdate() } + select.bind(TalerAmount.max("KUDOS")) return select.one { it.getAmount("max_amount", "KUDOS") } } @@ -188,10 +183,10 @@ class AmountTest { @Test fun normalize() = dbSetup { db -> db.conn { conn -> - val stmt = conn.prepareStatement("SELECT normalized.val, normalized.frac FROM amount_normalize((?, ?)::taler_amount) as normalized") + val stmt = conn.talerStatement("SELECT normalized.val, normalized.frac FROM amount_normalize((?, ?)::taler_amount) as normalized") fun TalerAmount.db(): TalerAmount { - stmt.setLong(1, value) - stmt.setInt(2, frac) + stmt.bind(value) + stmt.bind(frac) return stmt.one { TalerAmount( it.getLong(1), @@ -228,12 +223,11 @@ class AmountTest { @Test fun add() = dbSetup { db -> db.conn { conn -> - val stmt = conn.prepareStatement("SELECT sum.val, sum.frac FROM amount_add((?, ?)::taler_amount, (?, ?)::taler_amount) as sum") + val stmt = conn.talerStatement("SELECT sum.val, sum.frac FROM amount_add((?, ?)::taler_amount, (?, ?)::taler_amount) as sum") fun TalerAmount.db(increment: TalerAmount): TalerAmount { - stmt.setLong(1, value) - stmt.setInt(2, frac) - stmt.setLong(3, increment.value) - stmt.setInt(4, increment.frac) + stmt.bind(value) + stmt.bind(frac) + stmt.bind(increment) return stmt.one { TalerAmount( it.getLong(1), @@ -270,14 +264,11 @@ class AmountTest { fun conversionApply() = dbSetup { db -> db.conn { conn -> fun apply(nb: TalerAmount, times: DecimalNumber, tiny: DecimalNumber = DecimalNumber("0.00000001"), roundingMode: String = "zero"): TalerAmount { - val stmt = conn.prepareStatement("SELECT (result).val, (result).frac FROM conversion_apply_ratio((?, ?)::taler_amount, (?, ?)::taler_amount, (0, 0)::taler_amount, (?, ?)::taler_amount, ?::rounding_mode)") - stmt.setLong(1, nb.value) - stmt.setInt(2, nb.frac) - stmt.setLong(3, times.value) - stmt.setInt(4, times.frac) - stmt.setLong(5, tiny.value) - stmt.setInt(6, tiny.frac) - stmt.setString(7, roundingMode) + val stmt = conn.talerStatement("SELECT (result).val, (result).frac FROM conversion_apply_ratio((?, ?)::taler_amount, (?, ?)::taler_amount, (0, 0)::taler_amount, (?, ?)::taler_amount, ?::rounding_mode)") + stmt.bind(nb) + stmt.bind(times) + stmt.bind(tiny) + stmt.bind(roundingMode) return stmt.one { TalerAmount( it.getLong(1), @@ -333,15 +324,12 @@ class AmountTest { @Test fun conversionRevert() = dbSetup { db -> db.conn { conn -> - val applyStmt = conn.prepareStatement("SELECT (result).val, (result).frac FROM conversion_apply_ratio((?, ?)::taler_amount, (?, ?)::taler_amount, (0, 0)::taler_amount, (?, ?)::taler_amount, ?::rounding_mode)") + val applyStmt = conn.talerStatement("SELECT (result).val, (result).frac FROM conversion_apply_ratio((?, ?)::taler_amount, (?, ?)::taler_amount, (0, 0)::taler_amount, (?, ?)::taler_amount, ?::rounding_mode)") fun TalerAmount.apply(ratio: DecimalNumber, tiny: DecimalNumber = DecimalNumber("0.00000001"), roundingMode: String = "zero"): TalerAmount { - applyStmt.setLong(1, this.value) - applyStmt.setInt(2, this.frac) - applyStmt.setLong(3, ratio.value) - applyStmt.setInt(4, ratio.frac) - applyStmt.setLong(5, tiny.value) - applyStmt.setInt(6, tiny.frac) - applyStmt.setString(7, roundingMode) + applyStmt.bind(this) + applyStmt.bind(ratio) + applyStmt.bind(tiny) + applyStmt.bind(roundingMode) return applyStmt.one { TalerAmount( it.getLong(1), @@ -351,15 +339,12 @@ class AmountTest { } } - val revertStmt = conn.prepareStatement("SELECT (result).val, (result).frac FROM conversion_revert_ratio((?, ?)::taler_amount, (?, ?)::taler_amount, (0, 0)::taler_amount, (?, ?)::taler_amount, ?::rounding_mode)") + val revertStmt = conn.talerStatement("SELECT (result).val, (result).frac FROM conversion_revert_ratio((?, ?)::taler_amount, (?, ?)::taler_amount, (0, 0)::taler_amount, (?, ?)::taler_amount, ?::rounding_mode)") fun TalerAmount.revert(ratio: DecimalNumber, tiny: DecimalNumber = DecimalNumber("0.00000001"), roundingMode: String = "zero"): TalerAmount { - revertStmt.setLong(1, this.value) - revertStmt.setInt(2, this.frac) - revertStmt.setLong(3, ratio.value) - revertStmt.setInt(4, ratio.frac) - revertStmt.setLong(5, tiny.value) - revertStmt.setInt(6, tiny.frac) - revertStmt.setString(7, roundingMode) + revertStmt.bind(this) + revertStmt.bind(ratio) + revertStmt.bind(tiny) + revertStmt.bind(roundingMode) return revertStmt.one { TalerAmount( it.getLong(1), diff --git a/bank/src/test/kotlin/CoreBankApiTest.kt b/bank/src/test/kotlin/CoreBankApiTest.kt @@ -27,7 +27,7 @@ import tech.libeufin.bank.* import tech.libeufin.bank.auth.TOKEN_PREFIX import tech.libeufin.common.* import tech.libeufin.common.crypto.CryptoUtil -import tech.libeufin.common.db.one +import tech.libeufin.common.db.* import tech.libeufin.common.test.* import java.time.Duration import java.time.Instant @@ -54,7 +54,7 @@ class CoreBankSecurityTest { db.serializable( "UPDATE customers SET password_hash=? WHERE username='customer'" ) { - setString(1, hash) + bind(hash) executeUpdate() } assertEquals(hash, currentHash()) @@ -1766,11 +1766,10 @@ class CoreBankWithdrawalApiTest { INSERT INTO taler_withdrawal_operations(withdrawal_uuid,amount,selected_exchange_payto,selection_done,wallet_bank_account,creation_date) VALUES (?, (?, ?)::taler_amount, ?, true, 3, 0) """) { - setObject(1, uuid) - setLong(2, amount.value) - setInt(3, amount.frac) - setString(4, exchangePayto.canonical) - execute() + bind(uuid) + bind(amount) + bind(exchangePayto.canonical) + executeUpdate() } return client.postA("/accounts/customer/withdrawals/$uuid/confirm") diff --git a/bank/src/test/kotlin/DatabaseTest.kt b/bank/src/test/kotlin/DatabaseTest.kt @@ -25,8 +25,7 @@ import tech.libeufin.bank.db.AccountDAO.AccountCreationResult import tech.libeufin.common.TalerError import tech.libeufin.common.TalerErrorCode import tech.libeufin.common.assertOk -import tech.libeufin.common.db.one -import tech.libeufin.common.db.oneOrNull +import tech.libeufin.common.db.* import tech.libeufin.common.json import tech.libeufin.common.test.* import java.time.Duration @@ -50,45 +49,45 @@ class DatabaseTest { @Test fun tanChallenge() = bankSetup { db -> db.conn { conn -> - val createStmt = conn.prepareStatement("SELECT tan_challenge_create('','account_reconfig'::op_enum,?,?,?,?,'customer',NULL,NULL)") - val markSentStmt = conn.prepareStatement("SELECT tan_challenge_mark_sent(?,?,?)") - val tryStmt = conn.prepareStatement("SELECT out_ok, out_no_retry, out_expired FROM tan_challenge_try(?,'customer',?,?,true)") - val sendStmt = conn.prepareStatement("SELECT out_tan_code FROM tan_challenge_send(?,'customer',?,?,?,?,true,10)") + val createStmt = conn.talerStatement("SELECT tan_challenge_create('','account_reconfig'::op_enum,?,?,?,?,'customer',NULL,NULL)") + val markSentStmt = conn.talerStatement("SELECT tan_challenge_mark_sent(?,?,?)") + val tryStmt = conn.talerStatement("SELECT out_ok, out_no_retry, out_expired FROM tan_challenge_try(?,'customer',?,?,true)") + val sendStmt = conn.talerStatement("SELECT out_tan_code FROM tan_challenge_send(?,'customer',?,?,?,?,true,10)") val validityPeriod = Duration.ofHours(1) val retransmissionPeriod: Duration = Duration.ofMinutes(1) val retryCounter = 3 fun create(code: String, timestamp: Instant): Long { - createStmt.setString(1, code) - createStmt.setLong(2, ChronoUnit.MICROS.between(Instant.EPOCH, timestamp)) - createStmt.setLong(3, TimeUnit.MICROSECONDS.convert(validityPeriod)) - createStmt.setInt(4, retryCounter) + createStmt.bind(code) + createStmt.bind(ChronoUnit.MICROS.between(Instant.EPOCH, timestamp)) + createStmt.bind(TimeUnit.MICROSECONDS.convert(validityPeriod)) + createStmt.bind(retryCounter) return createStmt.one { it.getLong(1) } } fun markSent(id: Long, timestamp: Instant) { - markSentStmt.setLong(1, id) - markSentStmt.setLong(2, ChronoUnit.MICROS.between(Instant.EPOCH, timestamp)) - markSentStmt.setLong(3, TimeUnit.MICROSECONDS.convert(retransmissionPeriod)) + markSentStmt.bind(id) + markSentStmt.bind(ChronoUnit.MICROS.between(Instant.EPOCH, timestamp)) + markSentStmt.bind(TimeUnit.MICROSECONDS.convert(retransmissionPeriod)) return markSentStmt.one { } } fun cTry(id: Long, code: String, timestamp: Instant): Triple<Boolean, Boolean, Boolean> { - tryStmt.setLong(1, id) - tryStmt.setString(2, code) - tryStmt.setLong(3, ChronoUnit.MICROS.between(Instant.EPOCH, timestamp)) + tryStmt.bind(id) + tryStmt.bind(code) + tryStmt.bind(ChronoUnit.MICROS.between(Instant.EPOCH, timestamp)) return tryStmt.one { Triple(it.getBoolean(1), it.getBoolean(2), it.getBoolean(3)) } } fun send(id: Long, code: String, timestamp: Instant): String? { - sendStmt.setLong(1, id) - sendStmt.setString(2, code) - sendStmt.setLong(3, ChronoUnit.MICROS.between(Instant.EPOCH, timestamp)) - sendStmt.setLong(4, TimeUnit.MICROSECONDS.convert(validityPeriod)) - sendStmt.setInt(5, retryCounter) + sendStmt.bind(id) + sendStmt.bind(code) + sendStmt.bind(ChronoUnit.MICROS.between(Instant.EPOCH, timestamp)) + sendStmt.bind(TimeUnit.MICROSECONDS.convert(validityPeriod)) + sendStmt.bind(retryCounter) return sendStmt.oneOrNull { it.getString(1) } diff --git a/bank/src/test/kotlin/GcTest.kt b/bank/src/test/kotlin/GcTest.kt @@ -28,7 +28,7 @@ import tech.libeufin.bank.db.TransactionDAO.BankTransactionResult import tech.libeufin.bank.db.WithdrawalDAO.* import tech.libeufin.common.* import tech.libeufin.common.test.* -import tech.libeufin.common.db.one +import tech.libeufin.common.db.* import java.time.Duration import java.time.Instant import java.util.* @@ -39,7 +39,7 @@ class GcTest { @Test fun gc() = bankSetup { db -> db.conn { conn -> fun assertNb(nb: Int, stmt: String) { - assertEquals(nb, conn.prepareStatement(stmt).one { it.getInt(1) }) + assertEquals(nb, conn.talerStatement(stmt).one { it.getInt(1) }) } fun assertNbAccount(nb: Int) = assertNb(nb, "SELECT count(*) from bank_accounts") fun assertNbTokens(nb: Int) = assertNb(nb, "SELECT count(*) from bearer_tokens") diff --git a/bank/src/test/kotlin/StatsTest.kt b/bank/src/test/kotlin/StatsTest.kt @@ -1,6 +1,6 @@ /* * This file is part of LibEuFin. - * Copyright (C) 2024 Taler Systems S.A. + * Copyright (C) 2024-2025 Taler Systems S.A. * LibEuFin is free software; you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as @@ -26,7 +26,7 @@ import tech.libeufin.bank.Timeframe import tech.libeufin.common.ShortHashCode import tech.libeufin.common.TalerAmount import tech.libeufin.common.assertOkJson -import tech.libeufin.common.db.executeQueryCheck +import tech.libeufin.common.db.* import tech.libeufin.common.micros import tech.libeufin.common.test.* import java.time.Instant @@ -43,13 +43,12 @@ class StatsTest { suspend fun cashin(amount: String) { db.conn { conn -> - val stmt = conn.prepareStatement("SELECT 0 FROM cashin(?, ?, (?, ?)::taler_amount, ?)") - stmt.setLong(1, Instant.now().micros()) - stmt.setBytes(2, ShortHashCode.rand().raw) + val stmt = conn.talerStatement("SELECT 0 FROM cashin(?, ?, (?, ?)::taler_amount, ?)") + stmt.bind(Instant.now()) + stmt.bind(ShortHashCode.rand()) val amount = TalerAmount(amount) - stmt.setLong(3, amount.value) - stmt.setInt(4, amount.frac) - stmt.setString(5, "") + stmt.bind(amount) + stmt.bind("") stmt.executeQueryCheck() } } @@ -131,12 +130,11 @@ class StatsTest { fun timeframe() = bankSetup { db -> db.conn { conn -> fun register(timestamp: LocalDateTime, amount: TalerAmount) { - val stmt = conn.prepareStatement( + val stmt = conn.talerStatement( "CALL stats_register_payment('taler_out', ?::timestamp, (?, ?)::taler_amount, null)" ) - stmt.setObject(1, timestamp) - stmt.setLong(2, amount.value) - stmt.setInt(3, amount.frac) + stmt.bind(timestamp) + stmt.bind(amount) stmt.executeUpdate() } diff --git a/common/src/main/kotlin/db/DbPool.kt b/common/src/main/kotlin/db/DbPool.kt @@ -1,6 +1,6 @@ /* * This file is part of LibEuFin. - * Copyright (C) 2024 Taler Systems S.A. + * Copyright (C) 2024-2025 Taler Systems S.A. * * LibEuFin is free software; you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as @@ -44,12 +44,12 @@ open class DbPool(cfg: DatabaseConfig, schema: String) : java.io.Closeable { require(majorVersion >= MIN_VERSION) { "postgres version must be at least $MIN_VERSION.0 got $majorVersion.$minorVersion" } - checkMigrations(con, cfg, schema) + checkMigrations(con.unwrap(PgConnection::class.java), cfg, schema) } } /** Executes a query with automatic retry on serialization errors */ - suspend fun <R> serializable(query: String, lambda: PreparedStatement.() -> R): R = conn { conn -> + suspend fun <R> serializable(query: String, lambda: TalerStatement.() -> R): R = conn { conn -> // We could explicitly tell Postgres when a request is read-only, // but the performance improvement isn't obvious, it doesn't prevent // stored procedures from modifying the database and it adds a diff --git a/common/src/main/kotlin/db/helpers.kt b/common/src/main/kotlin/db/helpers.kt @@ -1,6 +1,6 @@ /* * This file is part of LibEuFin. - * Copyright (C) 2024 Taler Systems S.A. + * Copyright (C) 2024-2025 Taler Systems S.A. * * LibEuFin is free software; you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as @@ -35,7 +35,7 @@ suspend fun <T> DbPool.page( params: PageParams, idName: String, query: String, - bind: PreparedStatement.() -> Int = { 0 }, + args: TalerStatement.() -> Unit = {}, map: (ResultSet) -> T ): List<T> { val backward = params.limit < 0 @@ -46,9 +46,9 @@ suspend fun <T> DbPool.page( LIMIT ? """ return serializable(pageQuery) { - val pad = bind() - setLong(pad + 1, params.offset) - setInt(pad + 2, abs(params.limit)) + args() + bind(params.offset) + bind(abs(params.limit)) all { map(it) } } } @@ -71,8 +71,7 @@ suspend fun <T> DbPool.poolHistory( "bank_transaction_id", "$query $accountColumn=? AND", { - setLong(1, bankAccountId) - 1 + bind(bankAccountId) }, map ) diff --git a/common/src/main/kotlin/db/schema.kt b/common/src/main/kotlin/db/schema.kt @@ -1,6 +1,6 @@ /* * This file is part of LibEuFin. - * Copyright (C) 2024 Taler Systems S.A. + * Copyright (C) 2024-2025 Taler Systems S.A. * * LibEuFin is free software; you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as @@ -19,7 +19,7 @@ package tech.libeufin.common.db -import org.postgresql.ds.PGSimpleDataSource +import org.postgresql.ds.* import org.postgresql.jdbc.PgConnection import java.sql.Connection import kotlin.io.path.Path @@ -34,7 +34,7 @@ import kotlin.io.path.readText */ private fun maybeApplyV(conn: PgConnection, cfg: DatabaseConfig) { conn.transaction { - val checkVSchema = conn.prepareStatement( + val checkVSchema = conn.talerStatement( "SELECT schema_name FROM information_schema.schemata WHERE schema_name = '_v'" ) if (!checkVSchema.executeQueryCheck()) { @@ -57,10 +57,10 @@ private fun initializeDatabaseTables(conn: PgConnection, cfg: DatabaseConfig, sq logger.info("doing DB initialization, sqldir ${cfg.sqlDir}") maybeApplyV(conn, cfg) conn.transaction { - val checkStmt = conn.prepareStatement("SELECT EXISTS(SELECT FROM _v.patches where patch_name = ?)") + val checkStmt = conn.talerStatement("SELECT EXISTS(SELECT FROM _v.patches where patch_name = ?)") for (patchName in migrationsPath(sqlFilePrefix)) { - checkStmt.setString(1, patchName) + checkStmt.bind(patchName) val applied = checkStmt.one { it.getBoolean(1) } if (applied) { logger.debug("patch $patchName already applied") @@ -85,11 +85,11 @@ private fun initializeDatabaseTables(conn: PgConnection, cfg: DatabaseConfig, sq } } -internal fun checkMigrations(conn: Connection, cfg: DatabaseConfig, sqlFilePrefix: String) { - val checkStmt = conn.prepareStatement("SELECT EXISTS(SELECT FROM _v.patches where patch_name = ?)") +internal fun checkMigrations(conn: PgConnection, cfg: DatabaseConfig, sqlFilePrefix: String) { + val checkStmt = conn.talerStatement("SELECT EXISTS(SELECT FROM _v.patches where patch_name = ?)") for (patchName in migrationsPath(sqlFilePrefix)) { - checkStmt.setString(1, patchName) + checkStmt.bind(patchName) val path = Path("${cfg.sqlDir}/$patchName.sql") if (!path.exists()) break val applied = checkStmt.one { it.getBoolean(1) } diff --git a/common/src/main/kotlin/db/statement.kt b/common/src/main/kotlin/db/statement.kt @@ -0,0 +1,203 @@ +/* + * This file is part of LibEuFin. + * Copyright (C) 2025 Taler Systems S.A. + * + * LibEuFin is free software; you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as + * published by the Free Software Foundation; either version 3, or + * (at your option) any later version. + * + * LibEuFin is distributed in the hope that it will be useful, but + * WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY + * or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Affero General + * Public License for more details. + * + * You should have received a copy of the GNU Affero General Public + * License along with LibEuFin; see the file COPYING. If not, see + * <http://www.gnu.org/licenses/> + */ + +package tech.libeufin.common.db + +import org.slf4j.Logger +import org.slf4j.LoggerFactory +import org.postgresql.util.PSQLState +import tech.libeufin.common.* +import java.sql.* +import java.time.* +import java.util.* + +internal val logger: Logger = LoggerFactory.getLogger("libeufin-db") + +class TalerStatement(internal val stmt: PreparedStatement): java.io.Closeable { + override fun close() { + // Log warnings + var current = stmt.getWarnings() + while (current != null) { + //logger.warning(current.message) + current = current.getNextWarning() + } + // Close inner statement + stmt.close() + } + + /* ----- Bindings helpers ----- */ + + private var idx = 1; + + fun bind(string: String?) { + stmt.setString(idx, string) + idx+=1; + } + + fun bind(bool: Boolean) { + stmt.setBoolean(idx, bool) + idx+=1; + } + + fun bind(nb: Long) { + stmt.setLong(idx, nb) + idx+=1; + } + + fun bind(nb: Int) { + stmt.setInt(idx, nb) + idx+=1; + } + + fun bind(amount: TalerAmount?) { + bind(amount?.number()) + } + + fun bind(nb: DecimalNumber?) { + if (nb != null) { + stmt.setLong(idx, nb.value) + stmt.setInt(idx+1, nb.frac) + idx+=2 + } + } + + fun bind(timestamp: Instant) { + stmt.setLong(idx, timestamp.micros()) + idx+=1 + } + + fun bind(bytes: Base32Crockford64B?) { + stmt.setBytes(idx, bytes?.raw) + idx+=1 + } + + fun bind(bytes: Base32Crockford32B?) { + stmt.setBytes(idx, bytes?.raw) + idx+=1 + } + + fun bind(bytes: ByteArray?) { + stmt.setBytes(idx, bytes) + idx+=1 + } + + fun <T : kotlin.Enum<T>> bind(enum: T?) { + bind(enum?.name) + } + + fun bind(date: LocalDateTime) { + stmt.setObject(idx, date) + idx+=1 + } + + fun bind(uuid: UUID?) { + stmt.setObject(idx, uuid) + idx+=1 + } + + /* ----- Transaction helpers ----- */ + + fun executeQuery(): ResultSet { + return try { + stmt.executeQuery() + } finally { + stmt.clearParameters() + idx=1 + } + } + + fun executeUpdate(): Int { + return try { + stmt.executeUpdate() + } finally { + stmt.clearParameters() + idx=1 + } + } + + /** Read one row or null if none */ + fun <T> oneOrNull(lambda: (ResultSet) -> T): T? { + return executeQuery().use { + if (it.next()) lambda(it) else null + } + } + + /** Read one row or throw if none */ + fun <T> one(lambda: (ResultSet) -> T): T = + requireNotNull(oneOrNull(lambda)) { "Missing result to database query" } + + /** Read one row or throw [err] in case or unique violation error */ + fun <T> oneUniqueViolation(err: T, lambda: (ResultSet) -> T): T { + return try { + one(lambda) + } catch (e: SQLException) { + if (e.sqlState == PSQLState.UNIQUE_VIOLATION.state) return err + throw e // rethrowing, not to hide other types of errors. + } + } + + /** Read all rows */ + fun <T> all(lambda: (ResultSet) -> T): List<T> { + return executeQuery().use { + val ret = mutableListOf<T>() + while (it.next()) { + ret.add(lambda(it)) + } + ret + } + } + + /** Execute a query checking it return a least one row */ + fun executeQueryCheck(): Boolean { + return executeQuery().use { + it.next() + } + } + + /** Execute an update checking it update at least one row */ + fun executeUpdateCheck(): Boolean { + executeUpdate() + return stmt.updateCount > 0 + } + + /** Execute an update checking if fail because of unique violation error */ + fun executeUpdateViolation(): Boolean { + return try { + executeUpdateCheck() + } catch (e: SQLException) { + logger.debug(e.message) + if (e.sqlState == PSQLState.UNIQUE_VIOLATION.state) return false + throw e // rethrowing, not to hide other types of errors. + } + } + + /** Execute an update checking if fail because of unique violation error and resetting state */ + fun executeProcedureViolation(): Boolean { + val savepoint = stmt.connection.setSavepoint() + return try { + executeUpdate() + stmt.connection.releaseSavepoint(savepoint) + true + } catch (e: SQLException) { + stmt.connection.rollback(savepoint) + if (e.sqlState == PSQLState.UNIQUE_VIOLATION.state) return false + throw e // rethrowing, not to hide other types of errors. + } + } +} +\ No newline at end of file diff --git a/common/src/main/kotlin/db/transaction.kt b/common/src/main/kotlin/db/transaction.kt @@ -21,15 +21,11 @@ package tech.libeufin.common.db import org.postgresql.jdbc.PgConnection import org.postgresql.util.PSQLState -import org.slf4j.Logger -import org.slf4j.LoggerFactory import tech.libeufin.common.SERIALIZATION_RETRY import java.sql.PreparedStatement import java.sql.ResultSet import java.sql.SQLException -internal val logger: Logger = LoggerFactory.getLogger("libeufin-db") - val SERIALIZATION_ERROR = setOf( "40001", // serialization_failure "40P01", // deadlock_detected @@ -48,21 +44,11 @@ suspend fun <R> retrySerializationError(lambda: suspend () -> R): R { return lambda() } +fun PgConnection.talerStatement(query: String): TalerStatement = TalerStatement(prepareStatement(query)) + /** Run a postgres query using a prepared statement */ -inline fun <R> PgConnection.withStatement(query: String, lambda: PreparedStatement.() -> R): R { - val stmt = prepareStatement(query) - return stmt.use { - val res = stmt.lambda() - // Log warnings - var warning = stmt.getWarnings() - while (warning != null) { - logger.warning(warning.message) - warning = warning.getNextWarning() - } - stmt.clearWarnings() - res - } -} +inline fun <R> PgConnection.withStatement(query: String, lambda: TalerStatement.() -> R): R = + talerStatement(query).use { it.lambda() } /** Run a postgres [transaction] */ fun <R> PgConnection.transaction(transaction: (PgConnection) -> R): R { @@ -79,76 +65,6 @@ fun <R> PgConnection.transaction(transaction: (PgConnection) -> R): R { } } -/** Read one row or null if none */ -fun <T> PreparedStatement.oneOrNull(lambda: (ResultSet) -> T): T? { - executeQuery().use { - return if (it.next()) lambda(it) else null - } -} - -/** Read one row or throw if none */ -fun <T> PreparedStatement.one(lambda: (ResultSet) -> T): T = - requireNotNull(oneOrNull(lambda)) { "Missing result to database query" } - -/** Read one row or throw [err] in case or unique violation error */ -fun <T> PreparedStatement.oneUniqueViolation(err: T, lambda: (ResultSet) -> T): T { - return try { - one(lambda) - } catch (e: SQLException) { - if (e.sqlState == PSQLState.UNIQUE_VIOLATION.state) return err - throw e // rethrowing, not to hide other types of errors. - } -} - -/** Read all rows */ -fun <T> PreparedStatement.all(lambda: (ResultSet) -> T): List<T> { - executeQuery().use { - val ret = mutableListOf<T>() - while (it.next()) { - ret.add(lambda(it)) - } - return ret - } -} - -/** Execute a query checking it return a least one row */ -fun PreparedStatement.executeQueryCheck(): Boolean { - executeQuery().use { - return it.next() - } -} - -/** Execute an update checking it update at least one row */ -fun PreparedStatement.executeUpdateCheck(): Boolean { - executeUpdate() - return updateCount > 0 -} - -/** Execute an update checking if fail because of unique violation error */ -fun PreparedStatement.executeUpdateViolation(): Boolean { - return try { - executeUpdateCheck() - } catch (e: SQLException) { - logger.debug(e.message) - if (e.sqlState == PSQLState.UNIQUE_VIOLATION.state) return false - throw e // rethrowing, not to hide other types of errors. - } -} - -/** Execute an update checking if fail because of unique violation error and resetting state */ -fun PreparedStatement.executeProcedureViolation(): Boolean { - val savepoint = connection.setSavepoint() - return try { - executeUpdate() - connection.releaseSavepoint(savepoint) - true - } catch (e: SQLException) { - connection.rollback(savepoint) - if (e.sqlState == PSQLState.UNIQUE_VIOLATION.state) return false - throw e // rethrowing, not to hide other types of errors. - } -} - /** * Execute an update of [table] with a dynamic query generated at runtime. * Every [fields] in each row matching [filter] are updated using values from [bind]. @@ -163,8 +79,8 @@ fun PgConnection.dynamicUpdate( if (sql.isEmpty()) return withStatement("UPDATE $table SET $sql $filter") { for ((idx, value) in bind.withIndex()) { - setObject(idx + 1, value) + stmt.setObject(idx + 1, value) } - executeUpdate() + stmt.executeUpdate() } } \ No newline at end of file diff --git a/common/src/main/kotlin/db/types.kt b/common/src/main/kotlin/db/types.kt @@ -20,7 +20,17 @@ package tech.libeufin.common.db import tech.libeufin.common.* -import java.sql.ResultSet +import java.sql.* +import java.time.* +import java.util.* + +fun optAmount(amount: TalerAmount?): String { + if (amount != null) { + return "(?,?)::taler_amount" + } else { + return "NULL" + } +} inline fun <reified T : Enum<T>> ResultSet.getEnum(name: String): T = java.lang.Enum.valueOf(T::class.java, getString(name)) diff --git a/nexus/src/main/kotlin/tech/libeufin/nexus/db/EbicsDAO.kt b/nexus/src/main/kotlin/tech/libeufin/nexus/db/EbicsDAO.kt @@ -1,6 +1,6 @@ /* * This file is part of LibEuFin. - * Copyright (C) 2024 Taler Systems S.A. + * Copyright (C) 2024-2025 Taler Systems S.A. * LibEuFin is free software; you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as @@ -19,15 +19,13 @@ package tech.libeufin.nexus.db -import tech.libeufin.common.db.oneOrNull - /** Data access logic for EBICS transaction */ class EbicsDAO(private val db: Database) { /** Register a pending transaction */ suspend fun register(id: String) = db.serializable( "INSERT INTO pending_ebics_transactions (tx_id) VALUES (?) ON CONFLICT DO NOTHING" ) { - setString(1, id) + bind(id) executeUpdate() } @@ -35,7 +33,7 @@ class EbicsDAO(private val db: Database) { suspend fun remove(id: String) = db.serializable( "DELETE FROM pending_ebics_transactions WHERE tx_id = ?" ) { - setString(1, id) + bind(id) executeUpdate() } diff --git a/nexus/src/main/kotlin/tech/libeufin/nexus/db/ExchangeDAO.kt b/nexus/src/main/kotlin/tech/libeufin/nexus/db/ExchangeDAO.kt @@ -122,15 +122,14 @@ class ExchangeDAO(private val db: Database) { ) { val subject = "${req.wtid} ${req.exchange_base_url.url}" - setBytes(1, req.request_uid.raw) - setBytes(2, req.wtid.raw) - setString(3, subject) - setLong(4, req.amount.value) - setInt(5, req.amount.frac) - setString(6, req.exchange_base_url.url) - setString(7, req.credit_account.canonical) - setString(8, endToEndId) - setLong(9, timestamp.micros()) + bind(req.request_uid) + bind(req.wtid) + bind(subject) + bind(req.amount) + bind(req.exchange_base_url.url) + bind(req.credit_account.canonical) + bind(endToEndId) + bind(timestamp) one { when { it.getBoolean("out_request_uid_reuse") -> TransferResult.RequestUidReuse @@ -162,7 +161,7 @@ class ExchangeDAO(private val db: Database) { WHERE initiated_outgoing_transaction_id=? """ ) { - setLong(1, id) + bind(id) oneOrNull { TransferStatus( status = it.getEnum<SubmissionState>("status").toTransferStatus(), @@ -202,15 +201,13 @@ class ExchangeDAO(private val db: Database) { """, { when (params.status) { - null -> 0 + null -> {} TransferStatusState.pending -> { - setString(1, SubmissionState.pending.name) - setString(2, SubmissionState.unsubmitted.name) - 2 + bind(SubmissionState.pending) + bind(SubmissionState.unsubmitted) } else -> { - setString(1, params.status?.name) - 1 + bind(params.status) } } } diff --git a/nexus/src/main/kotlin/tech/libeufin/nexus/db/InitiatedDAO.kt b/nexus/src/main/kotlin/tech/libeufin/nexus/db/InitiatedDAO.kt @@ -52,12 +52,11 @@ class InitiatedDAO(private val db: Database) { """ ) { // TODO check payto uri - setLong(1, payment.amount.value) - setInt(2, payment.amount.frac) - setString(3, payment.subject) - setString(4, payment.creditor.toString()) - setLong(5, payment.initiationTime.micros()) - setString(6, payment.endToEndId) + bind(payment.amount) + bind(payment.subject) + bind(payment.creditor.toString()) + bind(payment.initiationTime) + bind(payment.endToEndId) oneUniqueViolation(PaymentInitiationResult.RequestUidReuse) { PaymentInitiationResult.Success(it.getLong("initiated_outgoing_transaction_id")) } @@ -81,9 +80,9 @@ class InitiatedDAO(private val db: Database) { WHERE initiated_outgoing_batch_id = ? AND order_id IS NULL """ ) { - setLong(1, timestamp.micros()) - setString(2, orderId) - setLong(3, id) + bind(timestamp) + bind(orderId) + bind(id) executeUpdate() } if (updated > 0) { @@ -95,8 +94,8 @@ class InitiatedDAO(private val db: Database) { WHERE initiated_outgoing_batch_id = ? AND $UNSETTLED_FILTER """ ) { - setLong(1, id) - execute() + bind(id) + executeUpdate() } } } @@ -120,14 +119,14 @@ class InitiatedDAO(private val db: Database) { """ ) { if (permanent) { - setString(1, StatusUpdate.permanent_failure.name) + bind(StatusUpdate.permanent_failure) } else { - setString(1, StatusUpdate.transient_failure.name) + bind(StatusUpdate.transient_failure) } - setLong(2, timestamp.micros()) - setString(3, msg) - setLong(4, id) - execute() + bind(timestamp) + bind(msg) + bind(id) + executeUpdate() } // Update unsettled batch's transaction status tx.withStatement( @@ -138,13 +137,13 @@ class InitiatedDAO(private val db: Database) { """ ) { if (permanent) { - setString(1, StatusUpdate.permanent_failure.name) + bind(StatusUpdate.permanent_failure) } else { - setString(1, StatusUpdate.transient_failure.name) + bind(StatusUpdate.transient_failure) } - setString(2, msg) - setLong(3, id) - execute() + bind(msg) + bind(id) + executeUpdate() } } @@ -159,8 +158,8 @@ class InitiatedDAO(private val db: Database) { RETURNING initiated_outgoing_batch_id """ ) { - setString(1, msg) - setString(2, orderId) + bind(msg) + bind(orderId) oneOrNull { it.getLong(1) } } if (batchId != null) { @@ -172,9 +171,9 @@ class InitiatedDAO(private val db: Database) { WHERE initiated_outgoing_batch_id = ? AND $PENDING_FILTER """ ) { - setString(1, msg) - setLong(2, batchId) - execute() + bind(msg) + bind(batchId) + executeUpdate() } } } @@ -190,7 +189,7 @@ class InitiatedDAO(private val db: Database) { RETURNING initiated_outgoing_batch_id, message_id """ ) { - setString(1, orderId) + bind(orderId) oneOrNull { Pair( it.getLong("initiated_outgoing_batch_id"), @@ -208,8 +207,8 @@ class InitiatedDAO(private val db: Database) { WHERE initiated_outgoing_batch_id = ? AND $UNSETTLED_FILTER """ ) { - setLong(1, batchId) - execute() + bind(batchId) + executeUpdate() } messageId } @@ -225,7 +224,7 @@ class InitiatedDAO(private val db: Database) { RETURNING initiated_outgoing_batch_id, message_id, status_msg """ ) { - setString(1, orderId) + bind(orderId) oneOrNull { Triple( it.getLong("initiated_outgoing_batch_id"), @@ -244,8 +243,8 @@ class InitiatedDAO(private val db: Database) { WHERE initiated_outgoing_batch_id = ? """ ) { - setLong(1, batchId) - execute() + bind(batchId) + executeUpdate() } Pair(messageId, msg) } @@ -254,21 +253,21 @@ class InitiatedDAO(private val db: Database) { suspend fun batchStatusUpdate(msgId: String, state: StatusUpdate, msg: String?) = db.serializable( "SELECT FROM batch_status_update(?,?::submission_state,?)" ) { - setString(1, msgId) - setString(2, state.name) - setString(3, msg) - execute() + bind(msgId) + bind(state) + bind(msg) + executeQuery() } /** Register payment status [state] with [msg] for transaction [endToEndId] in batch [msgId] */ suspend fun txStatusUpdate(endToEndId: String, msgId: String?, state: StatusUpdate, msg: String?) = db.serializable( "SELECT FROM tx_status_update(?,?,?::submission_state,?)" ) { - setString(1, endToEndId) - setString(2, msgId) - setString(3, state.name) - setString(4, msg) - execute() + bind(endToEndId) + bind(msgId) + bind(state) + bind(msg) + executeQuery() } /** Unsettled initiated payment in batch [msgId] */ @@ -285,7 +284,7 @@ class InitiatedDAO(private val db: Database) { AND initiated_outgoing_transactions.$UNSETTLED_FILTER """ ) { - setString(1, msgId) + bind(msgId) all { OutgoingPayment( id = OutgoingId( @@ -304,9 +303,9 @@ class InitiatedDAO(private val db: Database) { /** Group unbatched transaction into a single batch */ suspend fun batch(timestamp: Instant, ebicsId: String) { db.serializable("SELECT FROM batch_outgoing_transactions(?, ?)") { - setLong(1, timestamp.micros()) - setString(2, ebicsId) - execute() + bind(timestamp) + bind(ebicsId) + executeQuery() } } diff --git a/nexus/src/main/kotlin/tech/libeufin/nexus/db/KvDAO.kt b/nexus/src/main/kotlin/tech/libeufin/nexus/db/KvDAO.kt @@ -1,6 +1,6 @@ /* * This file is part of LibEuFin. - * Copyright (C) 2024 Taler Systems S.A. + * Copyright (C) 2024-2025 Taler Systems S.A. * LibEuFin is free software; you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as @@ -58,7 +58,7 @@ class KvDAO( val db: Database) { suspend inline fun <reified T> get(key: String): T? = db.serializable( "SELECT value FROM kv WHERE key=?" ) { - setString(1, key) + bind(key) oneOrNull { val encoded = it.getString(1) JSON.decodeFromString<T>(encoded) @@ -70,8 +70,8 @@ class KvDAO( val db: Database) { "INSERT INTO kv (key, value) VALUES (?, ?::jsonb) ON CONFLICT (key) DO UPDATE SET value=EXCLUDED.value" ) { val encoded = JSON.encodeToString<T>(value) - setString(1, key) - setString(2, encoded) - execute() + bind(key) + bind(encoded) + executeUpdate() } } \ No newline at end of file diff --git a/nexus/src/main/kotlin/tech/libeufin/nexus/db/PaymentDAO.kt b/nexus/src/main/kotlin/tech/libeufin/nexus/db/PaymentDAO.kt @@ -48,16 +48,15 @@ class PaymentDAO(private val db: Database) { ) { val executionTime = payment.executionTime.micros() - setLong(1, payment.amount.value) - setInt(2, payment.amount.frac) - setString(3, payment.subject) - setLong(4, executionTime) - setString(5, payment.creditor?.toString()) - setString(6, payment.id.endToEndId) - setString(7, payment.id.msgId) - setString(8, payment.id.acctSvcrRef) - setBytes(9, wtid?.raw) - setString(10, baseUrl?.url) + bind(payment.amount) + bind(payment.subject) + bind(executionTime) + bind(payment.creditor?.toString()) + bind(payment.id.endToEndId) + bind(payment.id.msgId) + bind(payment.id.acctSvcrRef) + bind(wtid) + bind(baseUrl?.url) one { OutgoingRegistrationResult( @@ -93,20 +92,17 @@ class PaymentDAO(private val db: Database) { FROM register_and_bounce_incoming((?,?)::taler_amount,(?,?)::taler_amount,?,?,?,?,?,?,(?,?)::taler_amount,?,?) """ ) { - setLong(1, payment.amount.value) - setInt(2, payment.amount.frac) - setLong(3, payment.creditFee?.value ?: 0L) - setInt(4, payment.creditFee?.frac ?: 0) - setString(5, payment.subject) - setLong(6, payment.executionTime.micros()) - setString(7, payment.debtor?.toString()) - setObject(8, payment.id.uetr) - setString(9, payment.id.txId) - setString(10, payment.id.acctSvcrRef) - setLong(11, bounceAmount.value) - setInt(12, bounceAmount.frac) - setLong(13, timestamp.micros()) - setString(14, bounceEndToEndId) + bind(payment.amount) + bind(payment.creditFee ?: TalerAmount.zero(db.currency)) + bind(payment.subject) + bind(payment.executionTime) + bind(payment.debtor?.toString()) + bind(payment.id.uetr) + bind(payment.id.txId) + bind(payment.id.acctSvcrRef) + bind(bounceAmount) + bind(timestamp) + bind(bounceEndToEndId) one { IncomingBounceRegistrationResult( it.getLong("out_tx_id"), @@ -133,19 +129,16 @@ class PaymentDAO(private val db: Database) { FROM register_incoming((?,?)::taler_amount,(?,?)::taler_amount,?,?,?,?,?,?,?::taler_incoming_type,?) """ ) { - val executionTime = payment.executionTime.micros() - setLong(1, payment.amount.value) - setInt(2, payment.amount.frac) - setLong(3, payment.creditFee?.value ?: 0L) - setInt(4, payment.creditFee?.frac ?: 0) - setString(5, payment.subject) - setLong(6, executionTime) - setString(7, payment.debtor?.toString()) - setObject(8, payment.id.uetr) - setString(9, payment.id.txId) - setString(10, payment.id.acctSvcrRef) - setString(11, metadata.type.name) - setBytes(12, metadata.key.raw) + bind(payment.amount) + bind(payment.creditFee ?: TalerAmount.zero(db.currency)) + bind(payment.subject) + bind(payment.executionTime) + bind(payment.debtor?.toString()) + bind(payment.id.uetr) + bind(payment.id.txId) + bind(payment.id.acctSvcrRef) + bind(metadata.type) + bind(metadata.key) one { when { it.getBoolean("out_reserve_pub_reuse") -> IncomingRegistrationResult.ReservePubReuse @@ -167,17 +160,14 @@ class PaymentDAO(private val db: Database) { FROM register_incoming((?,?)::taler_amount,(?,?)::taler_amount,?,?,?,?,?,?,NULL,NULL) """ ) { - val executionTime = payment.executionTime.micros() - setLong(1, payment.amount.value) - setInt(2, payment.amount.frac) - setLong(3, payment.creditFee?.value ?: 0L) - setInt(4, payment.creditFee?.frac ?: 0) - setString(5, payment.subject) - setLong(6, executionTime) - setString(7, payment.debtor?.toString()) - setObject(8, payment.id.uetr) - setString(9, payment.id.txId) - setString(10, payment.id.acctSvcrRef) + bind(payment.amount) + bind(payment.creditFee ?: TalerAmount.zero(db.currency)) + bind(payment.subject) + bind(payment.executionTime) + bind(payment.debtor?.toString()) + bind(payment.id.uetr) + bind(payment.id.txId) + bind(payment.id.acctSvcrRef) one { IncomingRegistrationResult.Success( it.getLong("out_tx_id"), diff --git a/nexus/src/test/kotlin/DatabaseTest.kt b/nexus/src/test/kotlin/DatabaseTest.kt @@ -171,7 +171,7 @@ class IncomingPaymentsTest { } db.conn { // Checking one incoming got created - val checkIncoming = it.prepareStatement(""" + val checkIncoming = it.talerStatement(""" SELECT (amount).val as amount_value, (amount).frac as amount_frac FROM incoming_transactions WHERE incoming_transaction_id = 1 """).executeQuery() @@ -179,13 +179,13 @@ class IncomingPaymentsTest { assertEquals(payment.amount.value, checkIncoming.getLong("amount_value")) assertEquals(payment.amount.frac, checkIncoming.getInt("amount_frac")) // Checking the bounced table got its row. - val checkBounced = it.prepareStatement(""" + val checkBounced = it.talerStatement(""" SELECT 1 FROM bounced_transactions WHERE incoming_transaction_id = 1 AND initiated_outgoing_transaction_id = 1 """).executeQuery() assertTrue(checkBounced.next()) // check the related initiated payment exists. - val checkInitiated = it.prepareStatement(""" + val checkInitiated = it.talerStatement(""" SELECT (amount).val as amount_value ,(amount).frac as amount_frac @@ -297,11 +297,11 @@ class IncomingPaymentsTest { FROM incoming_transactions ORDER BY incoming_transaction_id DESC LIMIT 1 """ ) { - setObject(1, payment.id.uetr) - setString(2, payment.id.txId) - setString(3, payment.id.acctSvcrRef) - setString(4, payment.subject) - setString(5, payment.debtor?.toString()) + bind(payment.id.uetr) + bind(payment.id.txId) + bind(payment.id.acctSvcrRef) + bind(payment.subject) + bind(payment.debtor?.toString()) one { assertTrue(it.getBoolean(1)) } @@ -409,7 +409,7 @@ class PaymentInitiationsTest { SELECT message_id, status, status_msg FROM initiated_outgoing_batches WHERE initiated_outgoing_batch_id=? """ ) { - setLong(1, batchId) + bind(batchId) one { val msgId = it.getString("message_id") assertEquals( @@ -426,7 +426,7 @@ class PaymentInitiationsTest { SELECT end_to_end_id, status, status_msg FROM initiated_outgoing_transactions WHERE initiated_outgoing_batch_id=? """ ) { - setLong(1, batchId) + bind(batchId) all { val endToEndId = it.getString("end_to_end_id") val expected = when (endToEndId) { @@ -450,7 +450,7 @@ class PaymentInitiationsTest { val batchId = db.serializable( "SELECT initiated_outgoing_batch_id FROM initiated_outgoing_batches WHERE order_id=?" ) { - setString(1, orderId) + bind(orderId) one { it.getLong("initiated_outgoing_batch_id") } diff --git a/testbench/src/test/kotlin/IntegrationTest.kt b/testbench/src/test/kotlin/IntegrationTest.kt @@ -35,7 +35,6 @@ import tech.libeufin.bank.cli.LibeufinBank import tech.libeufin.common.* import tech.libeufin.common.test.* import tech.libeufin.common.api.engine -import tech.libeufin.common.db.one import tech.libeufin.nexus.* import tech.libeufin.nexus.cli.LibeufinNexus import tech.libeufin.nexus.cli.registerIncomingPayment