libeufin

Integration and sandbox testing for FinTech APIs and data formats
Log | Files | Refs | Submodules | README | LICENSE

commit 2fd8d17ec7010947d038d8a34bbf7aa9402ff214
parent a2436996041807e3bb136c436053f007b6340cea
Author: Antoine A <>
Date:   Fri, 21 Jun 2024 00:37:31 +0200

common: close opened PreparedStatement and make some functions IMMUTABLE

Diffstat:
Mbank/src/main/kotlin/tech/libeufin/bank/db/AccountDAO.kt | 47++++++++++++++++++++++++-----------------------
Mbank/src/main/kotlin/tech/libeufin/bank/db/ConversionDAO.kt | 72++++++++++++++++++++++++++++++++++++++++--------------------------------
Mbank/src/main/kotlin/tech/libeufin/bank/db/GcDAO.kt | 15++++++++-------
Mbank/src/main/kotlin/tech/libeufin/bank/db/TransactionDAO.kt | 134++++++++++++++++++++++++++++++++++++++++----------------------------------------
Mcommon/src/main/kotlin/db/DbPool.kt | 13++++++++-----
Mcommon/src/main/kotlin/db/transaction.kt | 7++++++-
Mdatabase-versioning/libeufin-bank-procedures.sql | 6+++---
Mnexus/src/main/kotlin/tech/libeufin/nexus/db/InitiatedDAO.kt | 244+++++++++++++++++++++++++++++++++++++++++--------------------------------------
Mnexus/src/main/kotlin/tech/libeufin/nexus/db/PaymentDAO.kt | 220++++++++++++++++++++++++++++++++++++++++----------------------------------------
9 files changed, 392 insertions(+), 366 deletions(-)

diff --git a/bank/src/main/kotlin/tech/libeufin/bank/db/AccountDAO.kt b/bank/src/main/kotlin/tech/libeufin/bank/db/AccountDAO.kt @@ -55,7 +55,7 @@ class AccountDAO(private val db: Database) { ctx: BankPaytoCtx ): AccountCreationResult = db.serializableTransaction { conn -> val timestamp = Instant.now().micros() - val idempotent = conn.prepareStatement(""" + val idempotent = conn.withStatement(""" SELECT password_hash, name=? AND email IS NOT DISTINCT FROM ? AND phone IS NOT DISTINCT FROM ? @@ -69,7 +69,7 @@ class AccountDAO(private val db: Database) { JOIN bank_accounts ON customer_id=owning_customer_id WHERE login=? - """).run { + """) { // TODO check max debt and min checkout ? setString(1, name) setString(2, email) @@ -97,12 +97,12 @@ class AccountDAO(private val db: Database) { } } else { if (internalPayto is IbanPayto) - conn.prepareStatement(""" + conn.withStatement (""" INSERT INTO iban_history( iban ,creation_time ) VALUES (?, ?) - """).run { + """) { setString(1, internalPayto.iban.value) setLong(2, timestamp) if (!executeUpdateViolation()) { @@ -111,7 +111,7 @@ class AccountDAO(private val db: Database) { } } - val customerId = conn.prepareStatement(""" + val customerId = conn.withStatement(""" INSERT INTO customers ( login ,password_hash @@ -123,7 +123,7 @@ class AccountDAO(private val db: Database) { ) VALUES (?, ?, ?, ?, ?, ?, ?::tan_enum) RETURNING customer_id """ - ).run { + ) { setString(1, login) setString(2, PwCrypto.hashpw(password)) setString(3, name) @@ -134,7 +134,7 @@ class AccountDAO(private val db: Database) { oneOrNull { it.getLong("customer_id") }!! } - conn.prepareStatement(""" + conn.withStatement(""" INSERT INTO bank_accounts( internal_payto_uri ,owning_customer_id @@ -143,7 +143,7 @@ class AccountDAO(private val db: Database) { ,max_debt ,min_cashout ) VALUES (?, ?, ?, ?, (?, ?)::taler_amount, ${if (minCashout == null) "NULL" else "(?, ?)::taler_amount"}) - """).run { + """) { setString(1, internalPayto.canonical) setLong(2, customerId) setBoolean(3, isPublic) @@ -161,10 +161,10 @@ class AccountDAO(private val db: Database) { } if (bonus.value != 0L || bonus.frac != 0) { - conn.prepareStatement(""" + conn.withStatement(""" SELECT out_balance_insufficient FROM bank_transaction(?,'admin','bonus',(?,?)::taler_amount,?,true,NULL,(0, 0)::taler_amount) - """).run { + """) { setString(1, internalPayto.canonical) setLong(2, bonus.value) setInt(3, bonus.frac) @@ -266,7 +266,7 @@ class AccountDAO(private val db: Database) { ) // Get user ID and current data - val curr = conn.prepareStatement(""" + val curr = conn.withStatement(""" SELECT customer_id, tan_channel, phone, email, name, cashout_payto ,(max_debt).val AS max_debt_val @@ -277,7 +277,7 @@ class AccountDAO(private val db: Database) { JOIN bank_accounts ON customer_id=owning_customer_id WHERE login=? AND deleted_at IS NULL - """).run { + """) { setString(1, login) oneOrNull { CurrentAccount( @@ -339,9 +339,10 @@ class AccountDAO(private val db: Database) { // Invalidate current challenges if (patchChannel != null || patchInfo != null) { - val stmt = conn.prepareStatement("UPDATE tan_challenges SET expiration_date=0 WHERE customer=?") - stmt.setLong(1, curr.id) - stmt.execute() + conn.withStatement("UPDATE tan_challenges SET expiration_date=0 WHERE customer=?") { + setLong(1, curr.id) + execute() + } } // Update bank info @@ -412,10 +413,10 @@ class AccountDAO(private val db: Database) { oldPw: String?, is2fa: Boolean ): AccountPatchAuthResult = db.serializableTransaction { conn -> - val (currentPwh, tanRequired) = conn.prepareStatement(""" + val (currentPwh, tanRequired) = conn.withStatement(""" SELECT password_hash, (NOT ? AND tan_channel IS NOT NULL) FROM customers WHERE login=? AND deleted_at IS NULL - """).run { + """) { setBoolean(1, is2fa) setString(2, login) oneOrNull { @@ -427,12 +428,12 @@ class AccountDAO(private val db: Database) { } else if (oldPw != null && !PwCrypto.checkpw(oldPw, currentPwh)) { AccountPatchAuthResult.OldPasswordMismatch } else { - val stmt = conn.prepareStatement(""" - UPDATE customers SET password_hash=? where login=? - """) - stmt.setString(1, PwCrypto.hashpw(newPw)) - stmt.setString(2, login) - stmt.executeUpdateCheck() + conn.withStatement("UPDATE customers SET password_hash=? where login=?") { + setString(1, PwCrypto.hashpw(newPw)) + setString(2, login) + executeUpdateCheck() + } + AccountPatchAuthResult.Success } } diff --git a/bank/src/main/kotlin/tech/libeufin/bank/db/ConversionDAO.kt b/bank/src/main/kotlin/tech/libeufin/bank/db/ConversionDAO.kt @@ -28,43 +28,48 @@ import tech.libeufin.common.db.* class ConversionDAO(private val db: Database) { /** Update in-db conversion config */ suspend fun updateConfig(cfg: ConversionRate) = db.serializableTransaction { conn -> - var stmt = conn.prepareStatement("CALL config_set_amount(?, (?, ?)::taler_amount)") - for ((name, amount) in listOf( - Pair("cashin_ratio", cfg.cashin_ratio), - Pair("cashout_ratio", cfg.cashout_ratio), - )) { - stmt.setString(1, name) - stmt.setLong(2, amount.value) - stmt.setInt(3, amount.frac) - stmt.executeUpdate() - } - for ((name, amount) in listOf( - Pair("cashin_fee", cfg.cashin_fee), - Pair("cashin_tiny_amount", cfg.cashin_tiny_amount), - Pair("cashin_min_amount", cfg.cashin_min_amount), - Pair("cashout_fee", cfg.cashout_fee), - Pair("cashout_tiny_amount", cfg.cashout_tiny_amount), - Pair("cashout_min_amount", cfg.cashout_min_amount), - )) { - stmt.setString(1, name) - stmt.setLong(2, amount.value) - stmt.setInt(3, amount.frac) - stmt.executeUpdate() + conn.withStatement("CALL config_set_amount(?, (?, ?)::taler_amount)") { + for ((name, amount) in listOf( + Pair("cashin_ratio", cfg.cashin_ratio), + Pair("cashout_ratio", cfg.cashout_ratio), + )) { + setString(1, name) + setLong(2, amount.value) + setInt(3, amount.frac) + executeUpdate() + } + for ((name, amount) in listOf( + Pair("cashin_fee", cfg.cashin_fee), + Pair("cashin_tiny_amount", cfg.cashin_tiny_amount), + Pair("cashin_min_amount", cfg.cashin_min_amount), + Pair("cashout_fee", cfg.cashout_fee), + Pair("cashout_tiny_amount", cfg.cashout_tiny_amount), + Pair("cashout_min_amount", cfg.cashout_min_amount), + )) { + setString(1, name) + setLong(2, amount.value) + setInt(3, amount.frac) + executeUpdate() + } } - stmt = conn.prepareStatement("CALL config_set_rounding_mode(?, ?::rounding_mode)") - for ((name, value) in listOf( - Pair("cashin_rounding_mode", cfg.cashin_rounding_mode), - Pair("cashout_rounding_mode", cfg.cashout_rounding_mode) - )) { - stmt.setString(1, name) - stmt.setString(2, value.name) - stmt.executeUpdate() + + conn.withStatement("CALL config_set_rounding_mode(?, ?::rounding_mode)") { + for ((name, value) in listOf( + Pair("cashin_rounding_mode", cfg.cashin_rounding_mode), + Pair("cashout_rounding_mode", cfg.cashout_rounding_mode) + )) { + setString(1, name) + setString(2, value.name) + executeUpdate() + } } } /** Get in-db conversion config */ suspend fun getConfig(regional: String, fiat: String): ConversionRate? = db.serializableTransaction { conn -> - val check = conn.prepareStatement("select exists(select 1 from config where key='cashin_ratio')").oneOrNull { it.getBoolean(1) }!! + val check = conn.withStatement("select exists(select 1 from config where key='cashin_ratio')") { + one { it.getBoolean(1) } + } if (!check) return@serializableTransaction null val amount = conn.prepareStatement("SELECT (amount).val as amount_val, (amount).frac as amount_frac FROM config_get_amount(?) as amount") val roundingMode = conn.prepareStatement("SELECT config_get_rounding_mode(?)") @@ -77,7 +82,7 @@ class ConversionDAO(private val db: Database) { roundingMode.setString(1, name) return roundingMode.oneOrNull { RoundingMode.valueOf(it.getString(1)) }!! } - ConversionRate( + val rate = ConversionRate( cashin_ratio = getRatio("cashin_ratio"), cashin_fee = getAmount("cashin_fee", regional), cashin_tiny_amount = getAmount("cashin_tiny_amount", regional), @@ -89,6 +94,9 @@ class ConversionDAO(private val db: Database) { cashout_rounding_mode = getMode("cashout_rounding_mode"), cashout_min_amount = getAmount("cashout_min_amount", regional), ) + amount.close() + roundingMode.close() + rate } /** Clear in-db conversion config */ diff --git a/bank/src/main/kotlin/tech/libeufin/bank/db/GcDAO.kt b/bank/src/main/kotlin/tech/libeufin/bank/db/GcDAO.kt @@ -20,6 +20,7 @@ package tech.libeufin.bank.db import tech.libeufin.common.micros +import tech.libeufin.common.db.withStatement import java.time.Duration import java.time.Instant @@ -37,9 +38,9 @@ class GcDAO(private val db: Database) { val deleteAfterMicro = timestamp.minus(deleteAfter).micros() // Abort pending operations - conn.prepareStatement( + conn.withStatement( "UPDATE taler_withdrawal_operations SET aborted = true WHERE creation_date < ?" - ).run { + ) { setLong(1, abortAfterMicro) execute() } @@ -50,27 +51,27 @@ class GcDAO(private val db: Database) { "DELETE FROM tan_challenges WHERE expiration_date < ?", "DELETE FROM bearer_tokens WHERE expiration_time < ?" )) { - conn.prepareStatement(smt).run { + conn.withStatement(smt) { setLong(1, cleanAfterMicro) execute() } } // Delete old bank transactions, linked operations are deleted by CASCADE - conn.prepareStatement( + conn.withStatement( "DELETE FROM bank_account_transactions WHERE transaction_date < ?" - ).run { + ) { setLong(1, deleteAfterMicro) execute() } // Hard delete soft deleted customer without bank transactions, bank account are deleted by CASCADE - conn.prepareStatement(""" + conn.withStatement(""" DELETE FROM customers WHERE deleted_at IS NOT NULL AND NOT EXISTS( SELECT 1 FROM bank_account_transactions NATURAL JOIN bank_accounts WHERE owning_customer_id=customer_id ) - """).run { + """) { execute() } diff --git a/bank/src/main/kotlin/tech/libeufin/bank/db/TransactionDAO.kt b/bank/src/main/kotlin/tech/libeufin/bank/db/TransactionDAO.kt @@ -54,7 +54,7 @@ class TransactionDAO(private val db: Database) { wireTransferFees: TalerAmount ): BankTransactionResult = db.serializableTransaction { conn -> val timestamp = timestamp.micros() - val stmt = conn.prepareStatement(""" + conn.withStatement(""" SELECT out_creditor_not_found ,out_debtor_not_found @@ -72,77 +72,77 @@ class TransactionDAO(private val db: Database) { ,out_idempotent FROM bank_transaction(?,?,?,(?,?)::taler_amount,?,?,?,(?,?)::taler_amount) """ - ) - stmt.setString(1, creditAccountPayto.canonical) - stmt.setString(2, debitAccountUsername) - stmt.setString(3, subject) - stmt.setLong(4, amount.value) - stmt.setInt(5, amount.frac) - stmt.setLong(6, timestamp) - stmt.setBoolean(7, is2fa) - stmt.setBytes(8, requestUid?.raw) - stmt.setLong(9, wireTransferFees.value) - stmt.setInt(10, wireTransferFees.frac) - stmt.one { - when { - it.getBoolean("out_creditor_not_found") -> BankTransactionResult.UnknownCreditor - it.getBoolean("out_debtor_not_found") -> BankTransactionResult.UnknownDebtor - it.getBoolean("out_same_account") -> BankTransactionResult.BothPartySame - it.getBoolean("out_balance_insufficient") -> BankTransactionResult.BalanceInsufficient - it.getBoolean("out_creditor_admin") -> BankTransactionResult.AdminCreditor - it.getBoolean("out_request_uid_reuse") -> BankTransactionResult.RequestUidReuse - it.getBoolean("out_idempotent") -> BankTransactionResult.Success(it.getLong("out_debit_row_id")) - it.getBoolean("out_tan_required") -> BankTransactionResult.TanRequired - else -> { - val creditAccountId = it.getLong("out_credit_bank_account_id") - val creditRowId = it.getLong("out_credit_row_id") - val debitAccountId = it.getLong("out_debit_bank_account_id") - val debitRowId = it.getLong("out_debit_row_id") - val exchangeCreditor = it.getBoolean("out_creditor_is_exchange") - val exchangeDebtor = it.getBoolean("out_debtor_is_exchange") - if (exchangeCreditor && exchangeDebtor) { - logger.warn("exchange account $exchangeDebtor sent a manual transaction to exchange account $exchangeCreditor, this should never happens and is not bounced to prevent bouncing loop, may fail in the future") - } else if (exchangeCreditor) { - val bounceCause = runCatching { parseIncomingTxMetadata(subject) }.fold( - onSuccess = { reservePub -> - val registered = conn.prepareStatement("CALL register_incoming(?, ?)").run { - setBytes(1, reservePub.raw) - setLong(2, creditRowId) - executeProcedureViolation() + ) { + setString(1, creditAccountPayto.canonical) + setString(2, debitAccountUsername) + setString(3, subject) + setLong(4, amount.value) + setInt(5, amount.frac) + setLong(6, timestamp) + setBoolean(7, is2fa) + setBytes(8, requestUid?.raw) + setLong(9, wireTransferFees.value) + setInt(10, wireTransferFees.frac) + one { + when { + it.getBoolean("out_creditor_not_found") -> BankTransactionResult.UnknownCreditor + it.getBoolean("out_debtor_not_found") -> BankTransactionResult.UnknownDebtor + it.getBoolean("out_same_account") -> BankTransactionResult.BothPartySame + it.getBoolean("out_balance_insufficient") -> BankTransactionResult.BalanceInsufficient + it.getBoolean("out_creditor_admin") -> BankTransactionResult.AdminCreditor + it.getBoolean("out_request_uid_reuse") -> BankTransactionResult.RequestUidReuse + it.getBoolean("out_idempotent") -> BankTransactionResult.Success(it.getLong("out_debit_row_id")) + it.getBoolean("out_tan_required") -> BankTransactionResult.TanRequired + else -> { + val creditAccountId = it.getLong("out_credit_bank_account_id") + val creditRowId = it.getLong("out_credit_row_id") + val debitAccountId = it.getLong("out_debit_bank_account_id") + val debitRowId = it.getLong("out_debit_row_id") + val exchangeCreditor = it.getBoolean("out_creditor_is_exchange") + val exchangeDebtor = it.getBoolean("out_debtor_is_exchange") + if (exchangeCreditor && exchangeDebtor) { + logger.warn("exchange account $exchangeDebtor sent a manual transaction to exchange account $exchangeCreditor, this should never happens and is not bounced to prevent bouncing loop, may fail in the future") + } else if (exchangeCreditor) { + val bounceCause = runCatching { parseIncomingTxMetadata(subject) }.fold( + onSuccess = { reservePub -> + val registered = conn.withStatement("CALL register_incoming(?, ?)") { + setBytes(1, reservePub.raw) + setLong(2, creditRowId) + executeProcedureViolation() + } + if (!registered) { + logger.warn("exchange account $creditAccountId received an incoming taler transaction $creditRowId with an already used reserve public key") + "reserve public key reuse" + } else { + null + } + }, + onFailure = { e -> + logger.warn("exchange account $creditAccountId received a manual transaction $creditRowId with malformed metadata: ${e.message}") + "malformed metadata: ${e.message}" } - if (!registered) { - logger.warn("exchange account $creditAccountId received an incoming taler transaction $creditRowId with an already used reserve public key") - "reserve public key reuse" - } else { - null + ) + if (bounceCause != null) { + // No error can happens because an opposite transaction already took place in the same transaction + conn.withStatement(""" + SELECT bank_wire_transfer( + ?, ?, ?, (?, ?)::taler_amount, ?, (0, 0)::taler_amount + ); + """) { + setLong(1, debitAccountId) + setLong(2, creditAccountId) + setString(3, "Bounce $creditRowId: $bounceCause") + setLong(4, amount.value) + setInt(5, amount.frac) + setLong(6, timestamp) + executeQuery() } - }, - onFailure = { e -> - logger.warn("exchange account $creditAccountId received a manual transaction $creditRowId with malformed metadata: ${e.message}") - "malformed metadata: ${e.message}" - } - ) - if (bounceCause != null) { - // No error can happens because an opposite transaction already took place in the same transaction - conn.prepareStatement(""" - SELECT bank_wire_transfer( - ?, ?, ?, (?, ?)::taler_amount, ?, (0, 0)::taler_amount - ); - """ - ).run { - setLong(1, debitAccountId) - setLong(2, creditAccountId) - setString(3, "Bounce $creditRowId: $bounceCause") - setLong(4, amount.value) - setInt(5, amount.frac) - setLong(6, timestamp) - executeQuery() } + } else if (exchangeDebtor) { + logger.warn("exchange account $debitAccountId sent a manual transaction $debitRowId which will not be recorderd as a taler outgoing transaction, use the API instead") } - } else if (exchangeDebtor) { - logger.warn("exchange account $debitAccountId sent a manual transaction $debitRowId which will not be recorderd as a taler outgoing transaction, use the API instead") + BankTransactionResult.Success(debitRowId) } - BankTransactionResult.Success(debitRowId) } } } diff --git a/common/src/main/kotlin/db/DbPool.kt b/common/src/main/kotlin/db/DbPool.kt @@ -52,15 +52,18 @@ open class DbPool(cfg: DatabaseConfig, schema: String) : java.io.Closeable { /** Executes a read-only query with automatic retry on serialization errors */ suspend fun <R> serializableRead(query: String, lambda: PreparedStatement.() -> R): R = conn { conn -> - val stmt = conn.prepareStatement(query) - // TODO explicit read only for better perf ? - retrySerializationError { stmt.lambda() } + conn.withStatement(query) { + // TODO explicit read only for better perf ? + retrySerializationError { lambda() } + } + } /** Executes a query with automatic retry on serialization errors */ suspend fun <R> serializableWrite(query: String, lambda: PreparedStatement.() -> R): R = conn { conn -> - val stmt = conn.prepareStatement(query) - retrySerializationError { stmt.lambda() } + conn.withStatement(query) { + retrySerializationError { lambda() } + } } /** Executes a transaction with automatic retry on serialization errors */ diff --git a/common/src/main/kotlin/db/transaction.kt b/common/src/main/kotlin/db/transaction.kt @@ -43,6 +43,11 @@ suspend fun <R> retrySerializationError(lambda: suspend () -> R): R { return lambda() } +/** Run a postgres query using a prepared statement */ +inline fun <R> PgConnection.withStatement(query: String, lambda: PreparedStatement.() -> R): R { + return prepareStatement(query).use(lambda) +} + /** Run a postgres [transaction] */ fun <R> PgConnection.transaction(transaction: (PgConnection) -> R): R { try { @@ -140,7 +145,7 @@ fun PgConnection.dynamicUpdate( ) { val sql = fields.joinToString() if (sql.isEmpty()) return - prepareStatement("UPDATE $table SET $sql $filter").run { + withStatement("UPDATE $table SET $sql $filter") { for ((idx, value) in bind.withIndex()) { setObject(idx + 1, value) } diff --git a/database-versioning/libeufin-bank-procedures.sql b/database-versioning/libeufin-bank-procedures.sql @@ -28,7 +28,7 @@ CREATE FUNCTION amount_normalize( IN amount taler_amount ,OUT normalized taler_amount ) -LANGUAGE plpgsql AS $$ +LANGUAGE plpgsql IMMUTABLE AS $$ BEGIN normalized.val = amount.val + amount.frac / 100000000; IF (normalized.val > 1::INT8<<52) THEN @@ -46,7 +46,7 @@ CREATE FUNCTION amount_add( ,IN r taler_amount ,OUT sum taler_amount ) -LANGUAGE plpgsql AS $$ +LANGUAGE plpgsql IMMUTABLE AS $$ BEGIN sum = (l.val + r.val, l.frac + r.frac); SELECT normalized.val, normalized.frac INTO sum.val, sum.frac FROM amount_normalize(sum) as normalized; @@ -60,7 +60,7 @@ CREATE FUNCTION amount_left_minus_right( ,OUT diff taler_amount ,OUT ok BOOLEAN ) -LANGUAGE plpgsql AS $$ +LANGUAGE plpgsql IMMUTABLE AS $$ BEGIN IF l.val > r.val THEN ok = TRUE; diff --git a/nexus/src/main/kotlin/tech/libeufin/nexus/db/InitiatedDAO.kt b/nexus/src/main/kotlin/tech/libeufin/nexus/db/InitiatedDAO.kt @@ -38,25 +38,26 @@ class InitiatedDAO(private val db: Database) { } /** Register a new pending payment in the database */ - suspend fun create(paymentData: InitiatedPayment): PaymentInitiationResult = db.conn { conn -> - val stmt = conn.prepareStatement(""" - INSERT INTO initiated_outgoing_transactions ( - amount - ,wire_transfer_subject - ,credit_payto_uri - ,initiation_time - ,request_uid - ) VALUES ((?,?)::taler_amount,?,?,?,?) - RETURNING initiated_outgoing_transaction_id - """) + suspend fun create(paymentData: InitiatedPayment): PaymentInitiationResult = db.serializableWrite( + """ + INSERT INTO initiated_outgoing_transactions ( + amount + ,wire_transfer_subject + ,credit_payto_uri + ,initiation_time + ,request_uid + ) VALUES ((?,?)::taler_amount,?,?,?,?) + RETURNING initiated_outgoing_transaction_id + """ + ) { // TODO check payto uri - stmt.setLong(1, paymentData.amount.value) - stmt.setInt(2, paymentData.amount.frac) - stmt.setString(3, paymentData.wireTransferSubject) - stmt.setString(4, paymentData.creditPaytoUri) - stmt.setLong(5, paymentData.initiationTime.micros()) - stmt.setString(6, paymentData.requestUid) - stmt.oneUniqueViolation(PaymentInitiationResult.RequestUidReuse) { + setLong(1, paymentData.amount.value) + setInt(2, paymentData.amount.frac) + setString(3, paymentData.wireTransferSubject) + setString(4, paymentData.creditPaytoUri) + setLong(5, paymentData.initiationTime.micros()) + setString(6, paymentData.requestUid) + oneUniqueViolation(PaymentInitiationResult.RequestUidReuse) { PaymentInitiationResult.Success(it.getLong("initiated_outgoing_transaction_id")) } } @@ -66,20 +67,21 @@ class InitiatedDAO(private val db: Database) { id: Long, timestamp: Instant, orderId: String - ) = db.conn { conn -> - val stmt = conn.prepareStatement(""" - UPDATE initiated_outgoing_transactions SET - submitted = 'success'::submission_state - ,last_submission_time = ? - ,failure_message = NULL - ,order_id = ? - ,submission_counter = submission_counter + 1 - WHERE initiated_outgoing_transaction_id = ? - """) - stmt.setLong(1, timestamp.micros()) - stmt.setString(2, orderId) - stmt.setLong(3, id) - stmt.execute() + ) = db.serializableWrite( + """ + UPDATE initiated_outgoing_transactions SET + submitted = 'success'::submission_state + ,last_submission_time = ? + ,failure_message = NULL + ,order_id = ? + ,submission_counter = submission_counter + 1 + WHERE initiated_outgoing_transaction_id = ? + """ + ) { + setLong(1, timestamp.micros()) + setString(2, orderId) + setLong(3, id) + execute() } /** Register EBICS submission failure */ @@ -87,106 +89,101 @@ class InitiatedDAO(private val db: Database) { id: Long, timestamp: Instant, msg: String? - ) = db.conn { conn -> - val stmt = conn.prepareStatement(""" - UPDATE initiated_outgoing_transactions SET - submitted = 'transient_failure'::submission_state - ,last_submission_time = ? - ,failure_message = ? - ,submission_counter = submission_counter + 1 - WHERE initiated_outgoing_transaction_id = ? - """) - stmt.setLong(1, timestamp.micros()) - stmt.setString(2, msg) - stmt.setLong(3, id) - stmt.execute() + ) = db.serializableWrite( + """ + UPDATE initiated_outgoing_transactions SET + submitted = 'transient_failure'::submission_state + ,last_submission_time = ? + ,failure_message = ? + ,submission_counter = submission_counter + 1 + WHERE initiated_outgoing_transaction_id = ? + """ + ) { + setLong(1, timestamp.micros()) + setString(2, msg) + setLong(3, id) + execute() } /** Register EBICS log status message */ - suspend fun logMessage(orderId: String, msg: String) = db.conn { conn -> - val stmt = conn.prepareStatement(""" - UPDATE initiated_outgoing_transactions SET failure_message = ? - WHERE order_id = ? - """) - stmt.setString(1, msg) - stmt.setString(2, orderId) - stmt.execute() + suspend fun logMessage(orderId: String, msg: String) = db.serializableWrite( + """ + UPDATE initiated_outgoing_transactions SET failure_message = ? + WHERE order_id = ? + """ + ) { + setString(1, msg) + setString(2, orderId) + execute() } /** Register EBICS log success and return request_uid if found */ - suspend fun logSuccess(orderId: String): String? = db.conn { conn -> - val stmt = conn.prepareStatement(""" - SELECT request_uid FROM initiated_outgoing_transactions - WHERE order_id = ? - """) - stmt.setString(1, orderId) - stmt.oneOrNull { it.getString(1) } + suspend fun logSuccess(orderId: String): String? = db.serializableWrite( + """ + SELECT request_uid FROM initiated_outgoing_transactions + WHERE order_id = ? + """ + ) { + setString(1, orderId) + oneOrNull { it.getString(1) } } /** Register EBICS log failure and return request_uid and previous message if found */ - suspend fun logFailure(orderId: String): Pair<String, String?>? = db.conn { conn -> - val stmt = conn.prepareStatement(""" - UPDATE initiated_outgoing_transactions - SET submitted = 'permanent_failure'::submission_state - WHERE order_id = ? - RETURNING request_uid, failure_message - """) - stmt.setString(1, orderId) - stmt.oneOrNull { Pair(it.getString(1), it.getString(2)) } + suspend fun logFailure(orderId: String): Pair<String, String?>? = db.serializableWrite( + """ + UPDATE initiated_outgoing_transactions + SET submitted = 'permanent_failure'::submission_state + WHERE order_id = ? + RETURNING request_uid, failure_message + """ + ) { + setString(1, orderId) + oneOrNull { Pair(it.getString(1), it.getString(2)) } } /** Register bank status message */ - suspend fun bankMessage(requestUID: String, msg: String) = db.conn { conn -> - val stmt = conn.prepareStatement(""" - UPDATE initiated_outgoing_transactions - SET failure_message = ? - WHERE request_uid = ? - """) - stmt.setString(1, msg) - stmt.setString(2, requestUID) - stmt.execute() + suspend fun bankMessage(requestUID: String, msg: String) = db.serializableWrite( + """ + UPDATE initiated_outgoing_transactions + SET failure_message = ? + WHERE request_uid = ? + """ + ) { + setString(1, msg) + setString(2, requestUID) + execute() } /** Register bank failure */ - suspend fun bankFailure(requestUID: String, msg: String) = db.conn { conn -> - val stmt = conn.prepareStatement(""" - UPDATE initiated_outgoing_transactions SET - submitted = 'permanent_failure'::submission_state - ,failure_message = ? - WHERE request_uid = ? - """) - stmt.setString(1, msg) - stmt.setString(2, requestUID) - stmt.execute() + suspend fun bankFailure(requestUID: String, msg: String) = db.serializableWrite( + """ + UPDATE initiated_outgoing_transactions SET + submitted = 'permanent_failure'::submission_state + ,failure_message = ? + WHERE request_uid = ? + """ + ) { + setString(1, msg) + setString(2, requestUID) + execute() } /** Register reversal */ - suspend fun reversal(requestUID: String, msg: String) = db.conn { conn -> - val stmt = conn.prepareStatement(""" - UPDATE initiated_outgoing_transactions SET - submitted = 'permanent_failure'::submission_state - ,failure_message = ? - WHERE request_uid = ? - """) - stmt.setString(1, msg) - stmt.setString(2, requestUID) - stmt.execute() + suspend fun reversal(requestUID: String, msg: String) = db.serializableWrite( + """ + UPDATE initiated_outgoing_transactions SET + submitted = 'permanent_failure'::submission_state + ,failure_message = ? + WHERE request_uid = ? + """ + ) { + setString(1, msg) + setString(2, requestUID) + execute() } /** List every initiated payment pending submission in the order they should be submitted */ - suspend fun submittable(currency: String): List<InitiatedPayment> = db.conn { conn -> - fun extract(it: ResultSet): InitiatedPayment { - val rowId = it.getLong("initiated_outgoing_transaction_id") - val initiationTime = it.getLong("initiation_time").asInstant() - return InitiatedPayment( - id = it.getLong("initiated_outgoing_transaction_id"), - amount = it.getAmount("amount", currency), - creditPaytoUri = it.getString("credit_payto_uri"), - wireTransferSubject = it.getString("wire_transfer_subject"), - initiationTime = initiationTime, - requestUid = it.getString("request_uid") - ) - } + suspend fun submittable(currency: String): List<InitiatedPayment> { val selectPart = """ SELECT initiated_outgoing_transaction_id @@ -206,12 +203,23 @@ class InitiatedDAO(private val db: Database) { // Then we retry the failed transaction, starting with the oldest by submission time. // This the bad path retrying each failed transaction applying a rotation based on // resubmission time. - val unsubmitted = conn.prepareStatement( - "$selectPart WHERE submitted='unsubmitted' ORDER BY initiation_time" - ).all(::extract) - val failed = conn.prepareStatement( - "$selectPart WHERE submitted='transient_failure' ORDER BY last_submission_time" - ).all(::extract) - unsubmitted + failed + return db.serializableRead( + """ + ($selectPart WHERE submitted='unsubmitted' ORDER BY initiation_time) + UNION ALL + ($selectPart WHERE submitted='transient_failure' ORDER BY last_submission_time) + """ + ) { + all { + InitiatedPayment( + id = it.getLong("initiated_outgoing_transaction_id"), + amount = it.getAmount("amount", currency), + creditPaytoUri = it.getString("credit_payto_uri"), + wireTransferSubject = it.getString("wire_transfer_subject"), + initiationTime = it.getLong("initiation_time").asInstant(), + requestUid = it.getString("request_uid") + ) + } + } } } \ No newline at end of file diff --git a/nexus/src/main/kotlin/tech/libeufin/nexus/db/PaymentDAO.kt b/nexus/src/main/kotlin/tech/libeufin/nexus/db/PaymentDAO.kt @@ -39,30 +39,24 @@ class PaymentDAO(private val db: Database) { paymentData: OutgoingPayment, wtid: ShortHashCode?, baseUrl: ExchangeUrl?, - ): OutgoingRegistrationResult = db.conn { - val stmt = it.prepareStatement(""" - SELECT out_tx_id, out_initiated, out_found - FROM register_outgoing((?,?)::taler_amount,?,?,?,?,?,?) - """) + ): OutgoingRegistrationResult = db.serializableWrite( + """ + SELECT out_tx_id, out_initiated, out_found + FROM register_outgoing((?,?)::taler_amount,?,?,?,?,?,?) + """ + ) { val executionTime = paymentData.executionTime.micros() - stmt.setLong(1, paymentData.amount.value) - stmt.setInt(2, paymentData.amount.frac) - stmt.setString(3, paymentData.wireTransferSubject) - stmt.setLong(4, executionTime) - stmt.setString(5, paymentData.creditPaytoUri) - stmt.setString(6, paymentData.messageId) - if (wtid != null) { - stmt.setBytes(7, wtid.raw) - } else { - stmt.setNull(7, java.sql.Types.NULL) - } - if (baseUrl != null) { - stmt.setString(8, baseUrl.url) - } else { - stmt.setNull(8, java.sql.Types.NULL) - } + + setLong(1, paymentData.amount.value) + setInt(2, paymentData.amount.frac) + setString(3, paymentData.wireTransferSubject) + setLong(4, executionTime) + setString(5, paymentData.creditPaytoUri) + setString(6, paymentData.messageId) + setBytes(7, wtid?.raw) + setString(8, baseUrl?.url) - stmt.one { + one { OutgoingRegistrationResult( it.getLong("out_tx_id"), it.getBoolean("out_initiated"), @@ -83,21 +77,22 @@ class PaymentDAO(private val db: Database) { paymentData: IncomingPayment, bounceAmount: TalerAmount, timestamp: Instant - ): IncomingBounceRegistrationResult = db.conn { - val stmt = it.prepareStatement(""" - SELECT out_found, out_tx_id, out_bounce_id - FROM register_incoming_and_bounce((?,?)::taler_amount,?,?,?,?,(?,?)::taler_amount,?) - """) - stmt.setLong(1, paymentData.amount.value) - stmt.setInt(2, paymentData.amount.frac) - stmt.setString(3, paymentData.wireTransferSubject) - stmt.setLong(4, paymentData.executionTime.micros()) - stmt.setString(5, paymentData.debitPaytoUri) - stmt.setString(6, paymentData.bankId) - stmt.setLong(7, bounceAmount.value) - stmt.setInt(8, bounceAmount.frac) - stmt.setLong(9, timestamp.micros()) - stmt.one { + ): IncomingBounceRegistrationResult = db.serializableWrite( + """ + SELECT out_found, out_tx_id, out_bounce_id + FROM register_incoming_and_bounce((?,?)::taler_amount,?,?,?,?,(?,?)::taler_amount,?) + """ + ) { + setLong(1, paymentData.amount.value) + setInt(2, paymentData.amount.frac) + setString(3, paymentData.wireTransferSubject) + setLong(4, paymentData.executionTime.micros()) + setString(5, paymentData.debitPaytoUri) + setString(6, paymentData.bankId) + setLong(7, bounceAmount.value) + setInt(8, bounceAmount.frac) + setLong(9, timestamp.micros()) + one { IncomingBounceRegistrationResult( it.getLong("out_tx_id"), it.getString("out_bounce_id"), @@ -116,20 +111,21 @@ class PaymentDAO(private val db: Database) { suspend fun registerTalerableIncoming( paymentData: IncomingPayment, reservePub: EddsaPublicKey - ): IncomingRegistrationResult = db.conn { conn -> - val stmt = conn.prepareStatement(""" - SELECT out_reserve_pub_reuse, out_found, out_tx_id - FROM register_incoming_and_talerable((?,?)::taler_amount,?,?,?,?,?) - """) + ): IncomingRegistrationResult = db.serializableWrite( + """ + SELECT out_reserve_pub_reuse, out_found, out_tx_id + FROM register_incoming_and_talerable((?,?)::taler_amount,?,?,?,?,?) + """ + ) { val executionTime = paymentData.executionTime.micros() - stmt.setLong(1, paymentData.amount.value) - stmt.setInt(2, paymentData.amount.frac) - stmt.setString(3, paymentData.wireTransferSubject) - stmt.setLong(4, executionTime) - stmt.setString(5, paymentData.debitPaytoUri) - stmt.setString(6, paymentData.bankId) - stmt.setBytes(7, reservePub.raw) - stmt.one { + setLong(1, paymentData.amount.value) + setInt(2, paymentData.amount.frac) + setString(3, paymentData.wireTransferSubject) + setLong(4, executionTime) + setString(5, paymentData.debitPaytoUri) + setString(6, paymentData.bankId) + setBytes(7, reservePub.raw) + one { when { it.getBoolean("out_reserve_pub_reuse") -> IncomingRegistrationResult.ReservePubReuse else -> IncomingRegistrationResult.Success( @@ -143,19 +139,20 @@ class PaymentDAO(private val db: Database) { /** Register an incoming payment */ suspend fun registerIncoming( paymentData: IncomingPayment - ): IncomingRegistrationResult.Success = db.conn { conn -> - val stmt = conn.prepareStatement(""" - SELECT out_found, out_tx_id - FROM register_incoming((?,?)::taler_amount,?,?,?,?) - """) + ): IncomingRegistrationResult.Success = db.serializableWrite( + """ + SELECT out_found, out_tx_id + FROM register_incoming((?,?)::taler_amount,?,?,?,?) + """ + ) { val executionTime = paymentData.executionTime.micros() - stmt.setLong(1, paymentData.amount.value) - stmt.setInt(2, paymentData.amount.frac) - stmt.setString(3, paymentData.wireTransferSubject) - stmt.setLong(4, executionTime) - stmt.setString(5, paymentData.debitPaytoUri) - stmt.setString(6, paymentData.bankId) - stmt.one { + setLong(1, paymentData.amount.value) + setInt(2, paymentData.amount.frac) + setString(3, paymentData.wireTransferSubject) + setLong(4, executionTime) + setString(5, paymentData.debitPaytoUri) + setString(6, paymentData.bankId) + one { IncomingRegistrationResult.Success( it.getLong("out_tx_id"), !it.getBoolean("out_found") @@ -187,21 +184,22 @@ class PaymentDAO(private val db: Database) { } /** List incoming transaction metadata for debugging */ - suspend fun metadataIncoming(): List<IncomingTxMetadata> = db.conn { conn -> - val stmt = conn.prepareStatement(""" - SELECT - (amount).val as amount_val - ,(amount).frac AS amount_frac - ,wire_transfer_subject - ,execution_time - ,debit_payto_uri - ,bank_id - ,reserve_public_key - FROM incoming_transactions - LEFT OUTER JOIN talerable_incoming_transactions using (incoming_transaction_id) - ORDER BY execution_time - """) - stmt.all { + suspend fun metadataIncoming(): List<IncomingTxMetadata> = db.serializableRead( + """ + SELECT + (amount).val as amount_val + ,(amount).frac AS amount_frac + ,wire_transfer_subject + ,execution_time + ,debit_payto_uri + ,bank_id + ,reserve_public_key + FROM incoming_transactions + LEFT OUTER JOIN talerable_incoming_transactions using (incoming_transaction_id) + ORDER BY execution_time + """ + ) { + all { IncomingTxMetadata( date = it.getLong("execution_time").asInstant(), amount = it.getDecimal("amount"), @@ -214,22 +212,23 @@ class PaymentDAO(private val db: Database) { } /** List outgoing transaction metadata for debugging */ - suspend fun metadataOutgoing(): List<OutgoingTxMetadata> = db.conn { conn -> - val stmt = conn.prepareStatement(""" - SELECT - (amount).val as amount_val - ,(amount).frac AS amount_frac - ,wire_transfer_subject - ,execution_time - ,credit_payto_uri - ,message_id - ,wtid - ,exchange_base_url - FROM outgoing_transactions - LEFT OUTER JOIN talerable_outgoing_transactions using (outgoing_transaction_id) - ORDER BY execution_time - """) - stmt.all { + suspend fun metadataOutgoing(): List<OutgoingTxMetadata> = db.serializableRead( + """ + SELECT + (amount).val as amount_val + ,(amount).frac AS amount_frac + ,wire_transfer_subject + ,execution_time + ,credit_payto_uri + ,message_id + ,wtid + ,exchange_base_url + FROM outgoing_transactions + LEFT OUTER JOIN talerable_outgoing_transactions using (outgoing_transaction_id) + ORDER BY execution_time + """ + ) { + all { OutgoingTxMetadata( date = it.getLong("execution_time").asInstant(), amount = it.getDecimal("amount"), @@ -243,23 +242,24 @@ class PaymentDAO(private val db: Database) { } /** List initiated transaction metadata for debugging */ - suspend fun metadataInitiated(): List<InitiatedTxMetadata> = db.conn { conn -> - val stmt = conn.prepareStatement(""" - SELECT - (amount).val as amount_val - ,(amount).frac AS amount_frac - ,wire_transfer_subject - ,initiation_time - ,last_submission_time - ,submission_counter - ,credit_payto_uri - ,submitted - ,request_uid - ,failure_message - FROM initiated_outgoing_transactions - ORDER BY initiation_time - """) - stmt.all { + suspend fun metadataInitiated(): List<InitiatedTxMetadata> = db.serializableRead( + """ + SELECT + (amount).val as amount_val + ,(amount).frac AS amount_frac + ,wire_transfer_subject + ,initiation_time + ,last_submission_time + ,submission_counter + ,credit_payto_uri + ,submitted + ,request_uid + ,failure_message + FROM initiated_outgoing_transactions + ORDER BY initiation_time + """ + ) { + all { InitiatedTxMetadata( date = it.getLong("initiation_time").asInstant(), amount = it.getDecimal("amount"),