libeufin

Integration and sandbox testing for FinTech APIs and data formats
Log | Files | Refs | Submodules | README | LICENSE

commit ac8a642657d8b346a5cf83d3864f6426b1733cc0
parent 5239ccfb2ca324170330cbce64c22a3bfc2af0b6
Author: Antoine A <>
Date:   Thu, 20 Jun 2024 17:35:24 +0200

bank: handle serialization errors even for read-only queries and improve db utilities

Diffstat:
Mbank/src/main/kotlin/tech/libeufin/bank/db/AccountDAO.kt | 426+++++++++++++++++++++++++++++++++++++++----------------------------------------
Mbank/src/main/kotlin/tech/libeufin/bank/db/CashoutDAO.kt | 107++++++++++++++++++++++++++++++++++++++-----------------------------------------
Mbank/src/main/kotlin/tech/libeufin/bank/db/ConversionDAO.kt | 141+++++++++++++++++++++++++++++++++++++++----------------------------------------
Mbank/src/main/kotlin/tech/libeufin/bank/db/Database.kt | 49+++++++++++++++++++++++++------------------------
Mbank/src/main/kotlin/tech/libeufin/bank/db/ExchangeDAO.kt | 103+++++++++++++++++++++++++++++++++++++++----------------------------------------
Mbank/src/main/kotlin/tech/libeufin/bank/db/TanDAO.kt | 109+++++++++++++++++++++++++++++++++++++++++--------------------------------------
Mbank/src/main/kotlin/tech/libeufin/bank/db/TokenDAO.kt | 98++++++++++++++++++++++++++++++++++++++++---------------------------------------
Mbank/src/main/kotlin/tech/libeufin/bank/db/TransactionDAO.kt | 216+++++++++++++++++++++++++++++++++++++++----------------------------------------
Mbank/src/main/kotlin/tech/libeufin/bank/db/WithdrawalDAO.kt | 291+++++++++++++++++++++++++++++++++++++++----------------------------------------
Mbank/src/test/kotlin/DatabaseTest.kt | 43++++++++++++++++++++++++++++++++++++++++---
Mcommon/src/main/kotlin/db/DbPool.kt | 40++++++++++++++++++++++------------------
Mcommon/src/main/kotlin/db/helpers.kt | 4++--
Mcommon/src/main/kotlin/db/transaction.kt | 43+++++++++++++++++++++++++++++++++++--------
Mnexus/src/main/kotlin/tech/libeufin/nexus/db/ExchangeDAO.kt | 47++++++++++++++++++++++++-----------------------
14 files changed, 889 insertions(+), 828 deletions(-)

diff --git a/bank/src/main/kotlin/tech/libeufin/bank/db/AccountDAO.kt b/bank/src/main/kotlin/tech/libeufin/bank/db/AccountDAO.kt @@ -1,6 +1,6 @@ /* * This file is part of LibEuFin. - * Copyright (C) 2023 Taler Systems S.A. + * Copyright (C) 2023-2024 Taler Systems S.A. * LibEuFin is free software; you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as @@ -53,137 +53,134 @@ class AccountDAO(private val db: Database) { // Whether to check [internalPaytoUri] for idempotency checkPaytoIdempotent: Boolean, ctx: BankPaytoCtx - ): AccountCreationResult = db.serializable { it -> + ): AccountCreationResult = db.serializableTransaction { conn -> val timestamp = Instant.now().micros() - it.transaction { conn -> - val idempotent = conn.prepareStatement(""" - SELECT password_hash, name=? - AND email IS NOT DISTINCT FROM ? - AND phone IS NOT DISTINCT FROM ? - AND cashout_payto IS NOT DISTINCT FROM ? - AND (NOT ? OR internal_payto_uri=?) - AND is_public=? - AND is_taler_exchange=? - AND tan_channel IS NOT DISTINCT FROM ?::tan_enum - ,internal_payto_uri, name - FROM customers - JOIN bank_accounts - ON customer_id=owning_customer_id - WHERE login=? - """).run { - // TODO check max debt and min checkout ? - setString(1, name) - setString(2, email) - setString(3, phone) - setString(4, cashoutPayto?.full(name)) - setBoolean(5, checkPaytoIdempotent) - setString(6, internalPayto.canonical) - setBoolean(7, isPublic) - setBoolean(8, isTalerExchange) - setString(9, tanChannel?.name) - setString(10, login) - oneOrNull { - Pair( - PwCrypto.checkpw(password, it.getString(1)) && it.getBoolean(2), - it.getBankPayto("internal_payto_uri", "name", ctx) - ) - } - } - - if (idempotent != null) { - if (idempotent.first) { - AccountCreationResult.Success(idempotent.second) - } else { - AccountCreationResult.LoginReuse - } + val idempotent = conn.prepareStatement(""" + SELECT password_hash, name=? + AND email IS NOT DISTINCT FROM ? + AND phone IS NOT DISTINCT FROM ? + AND cashout_payto IS NOT DISTINCT FROM ? + AND (NOT ? OR internal_payto_uri=?) + AND is_public=? + AND is_taler_exchange=? + AND tan_channel IS NOT DISTINCT FROM ?::tan_enum + ,internal_payto_uri, name + FROM customers + JOIN bank_accounts + ON customer_id=owning_customer_id + WHERE login=? + """).run { + // TODO check max debt and min checkout ? + setString(1, name) + setString(2, email) + setString(3, phone) + setString(4, cashoutPayto?.full(name)) + setBoolean(5, checkPaytoIdempotent) + setString(6, internalPayto.canonical) + setBoolean(7, isPublic) + setBoolean(8, isTalerExchange) + setString(9, tanChannel?.name) + setString(10, login) + oneOrNull { + Pair( + PwCrypto.checkpw(password, it.getString(1)) && it.getBoolean(2), + it.getBankPayto("internal_payto_uri", "name", ctx) + ) + } + } + + if (idempotent != null) { + if (idempotent.first) { + AccountCreationResult.Success(idempotent.second) } else { - if (internalPayto is IbanPayto) - conn.prepareStatement(""" - INSERT INTO iban_history( - iban - ,creation_time - ) VALUES (?, ?) - """).run { - setString(1, internalPayto.iban.value) - setLong(2, timestamp) - if (!executeUpdateViolation()) { - conn.rollback() - return@transaction AccountCreationResult.PayToReuse - } + AccountCreationResult.LoginReuse + } + } else { + if (internalPayto is IbanPayto) + conn.prepareStatement(""" + INSERT INTO iban_history( + iban + ,creation_time + ) VALUES (?, ?) + """).run { + setString(1, internalPayto.iban.value) + setLong(2, timestamp) + if (!executeUpdateViolation()) { + conn.rollback() + return@serializableTransaction AccountCreationResult.PayToReuse } + } - val customerId = conn.prepareStatement(""" - INSERT INTO customers ( - login - ,password_hash - ,name - ,email - ,phone - ,cashout_payto - ,tan_channel - ) VALUES (?, ?, ?, ?, ?, ?, ?::tan_enum) - RETURNING customer_id - """ - ).run { - setString(1, login) - setString(2, PwCrypto.hashpw(password)) - setString(3, name) - setString(4, email) - setString(5, phone) - setString(6, cashoutPayto?.full(name)) - setString(7, tanChannel?.name) - oneOrNull { it.getLong("customer_id") }!! + val customerId = conn.prepareStatement(""" + INSERT INTO customers ( + login + ,password_hash + ,name + ,email + ,phone + ,cashout_payto + ,tan_channel + ) VALUES (?, ?, ?, ?, ?, ?, ?::tan_enum) + RETURNING customer_id + """ + ).run { + setString(1, login) + setString(2, PwCrypto.hashpw(password)) + setString(3, name) + setString(4, email) + setString(5, phone) + setString(6, cashoutPayto?.full(name)) + setString(7, tanChannel?.name) + oneOrNull { it.getLong("customer_id") }!! + } + + conn.prepareStatement(""" + INSERT INTO bank_accounts( + internal_payto_uri + ,owning_customer_id + ,is_public + ,is_taler_exchange + ,max_debt + ,min_cashout + ) VALUES (?, ?, ?, ?, (?, ?)::taler_amount, ${if (minCashout == null) "NULL" else "(?, ?)::taler_amount"}) + """).run { + setString(1, internalPayto.canonical) + setLong(2, customerId) + setBoolean(3, isPublic) + setBoolean(4, isTalerExchange) + setLong(5, maxDebt.value) + setInt(6, maxDebt.frac) + if (minCashout != null) { + setLong(7, minCashout.value) + setInt(8, minCashout.frac) } + if (!executeUpdateViolation()) { + conn.rollback() + return@serializableTransaction AccountCreationResult.PayToReuse + } + } + if (bonus.value != 0L || bonus.frac != 0) { conn.prepareStatement(""" - INSERT INTO bank_accounts( - internal_payto_uri - ,owning_customer_id - ,is_public - ,is_taler_exchange - ,max_debt - ,min_cashout - ) VALUES (?, ?, ?, ?, (?, ?)::taler_amount, ${if (minCashout == null) "NULL" else "(?, ?)::taler_amount"}) + SELECT out_balance_insufficient + FROM bank_transaction(?,'admin','bonus',(?,?)::taler_amount,?,true,NULL,(0, 0)::taler_amount) """).run { setString(1, internalPayto.canonical) - setLong(2, customerId) - setBoolean(3, isPublic) - setBoolean(4, isTalerExchange) - setLong(5, maxDebt.value) - setInt(6, maxDebt.frac) - if (minCashout != null) { - setLong(7, minCashout.value) - setInt(8, minCashout.frac) - } - if (!executeUpdateViolation()) { - conn.rollback() - return@transaction AccountCreationResult.PayToReuse - } - } - - if (bonus.value != 0L || bonus.frac != 0) { - conn.prepareStatement(""" - SELECT out_balance_insufficient - FROM bank_transaction(?,'admin','bonus',(?,?)::taler_amount,?,true,NULL,(0, 0)::taler_amount) - """).run { - setString(1, internalPayto.canonical) - setLong(2, bonus.value) - setInt(3, bonus.frac) - setLong(4, timestamp) - executeQuery().use { - when { - !it.next() -> throw internalServerError("Bank transaction didn't properly return") - it.getBoolean("out_balance_insufficient") -> { - conn.rollback() - AccountCreationResult.BonusBalanceInsufficient - } - else -> AccountCreationResult.Success(internalPayto.bank(name, ctx)) + setLong(2, bonus.value) + setInt(3, bonus.frac) + setLong(4, timestamp) + one { + when { + it.getBoolean("out_balance_insufficient") -> { + conn.rollback() + AccountCreationResult.BonusBalanceInsufficient } + else -> AccountCreationResult.Success(internalPayto.bank(name, ctx)) } } - } else { - AccountCreationResult.Success(internalPayto.bank(name, ctx)) } + } else { + AccountCreationResult.Success(internalPayto.bank(name, ctx)) } } } @@ -200,20 +197,20 @@ class AccountDAO(private val db: Database) { suspend fun delete( login: String, is2fa: Boolean - ): AccountDeletionResult = db.serializable { conn -> - val stmt = conn.prepareStatement(""" - SELECT - out_not_found, - out_balance_not_zero, - out_tan_required - FROM account_delete(?,?,?) - """) - stmt.setString(1, login) - stmt.setLong(2, Instant.now().micros()) - stmt.setBoolean(3, is2fa) - stmt.executeQuery().use { + ): AccountDeletionResult = db.serializableWrite( + """ + SELECT + out_not_found, + out_balance_not_zero, + out_tan_required + FROM account_delete(?,?,?) + """ + ) { + setString(1, login) + setLong(2, Instant.now().micros()) + setBoolean(3, is2fa) + one { when { - !it.next() -> throw internalServerError("Deletion returned nothing.") it.getBoolean("out_not_found") -> AccountDeletionResult.UnknownAccount it.getBoolean("out_balance_not_zero") -> AccountDeletionResult.BalanceNotZero it.getBoolean("out_tan_required") -> AccountDeletionResult.TanRequired @@ -251,7 +248,7 @@ class AccountDAO(private val db: Database) { faInfo: String?, allowEditName: Boolean, allowEditCashout: Boolean, - ): AccountPatchResult = db.serializable { it.transaction { conn -> + ): AccountPatchResult = db.serializableTransaction { conn -> val checkName = !isAdmin && !allowEditName && name != null val checkCashout = !isAdmin && !allowEditCashout && cashoutPayto.isSome() val checkDebtLimit = !isAdmin && debtLimit != null @@ -293,7 +290,7 @@ class AccountDAO(private val db: Database) { debtLimit = it.getAmount("max_debt", db.bankCurrency), minCashout = it.getOptAmount("min_cashout", db.bankCurrency), ) - } ?: return@transaction AccountPatchResult.UnknownAccount + } ?: return@serializableTransaction AccountPatchResult.UnknownAccount } // Patched TAN channel @@ -317,26 +314,26 @@ class AccountDAO(private val db: Database) { // Check reconfig rights if (checkName && name != curr.name) - return@transaction AccountPatchResult.NonAdminName + return@serializableTransaction AccountPatchResult.NonAdminName if (checkCashout && fullCashoutPayto != curr.cashoutPayTo) - return@transaction AccountPatchResult.NonAdminCashout + return@serializableTransaction AccountPatchResult.NonAdminCashout if (checkDebtLimit && debtLimit != curr.debtLimit) - return@transaction AccountPatchResult.NonAdminDebtLimit + return@serializableTransaction AccountPatchResult.NonAdminDebtLimit if (checkMinCashout && minCashout.get() != curr.minCashout) - return@transaction AccountPatchResult.NonAdminMinCashout + return@serializableTransaction AccountPatchResult.NonAdminMinCashout if (patchChannel != null && newInfo == null) - return@transaction AccountPatchResult.MissingTanInfo + return@serializableTransaction AccountPatchResult.MissingTanInfo // Tan channel verification if (!isAdmin) { // Check performed 2fa check if (curr.channel != null && !is2fa) { // Perform challenge with current settings - return@transaction AccountPatchResult.TanRequired(channel = null, info = null) + return@serializableTransaction AccountPatchResult.TanRequired(channel = null, info = null) } // If channel or info changed and the 2fa challenge is performed with old settings perform a new challenge with new settings if ((patchChannel != null && patchChannel != faChannel) || (patchInfo != null && patchInfo != faInfo)) { - return@transaction AccountPatchResult.TanRequired(channel = newChannel, info = newInfo) + return@serializableTransaction AccountPatchResult.TanRequired(channel = newChannel, info = newInfo) } } @@ -397,7 +394,7 @@ class AccountDAO(private val db: Database) { ) AccountPatchResult.Success - }} + } /** Result status of customer account auth patch */ @@ -414,59 +411,57 @@ class AccountDAO(private val db: Database) { newPw: String, oldPw: String?, is2fa: Boolean - ): AccountPatchAuthResult = db.serializable { - it.transaction { conn -> - val (currentPwh, tanRequired) = conn.prepareStatement(""" - SELECT password_hash, (NOT ? AND tan_channel IS NOT NULL) - FROM customers WHERE login=? AND deleted_at IS NULL - """).run { - setBoolean(1, is2fa) - setString(2, login) - oneOrNull { - Pair(it.getString(1), it.getBoolean(2)) - } ?: return@transaction AccountPatchAuthResult.UnknownAccount - } - if (tanRequired) { - AccountPatchAuthResult.TanRequired - } else if (oldPw != null && !PwCrypto.checkpw(oldPw, currentPwh)) { - AccountPatchAuthResult.OldPasswordMismatch - } else { - val stmt = conn.prepareStatement(""" - UPDATE customers SET password_hash=? where login=? - """) - stmt.setString(1, PwCrypto.hashpw(newPw)) - stmt.setString(2, login) - stmt.executeUpdateCheck() - AccountPatchAuthResult.Success - } + ): AccountPatchAuthResult = db.serializableTransaction { conn -> + val (currentPwh, tanRequired) = conn.prepareStatement(""" + SELECT password_hash, (NOT ? AND tan_channel IS NOT NULL) + FROM customers WHERE login=? AND deleted_at IS NULL + """).run { + setBoolean(1, is2fa) + setString(2, login) + oneOrNull { + Pair(it.getString(1), it.getBoolean(2)) + } ?: return@serializableTransaction AccountPatchAuthResult.UnknownAccount + } + if (tanRequired) { + AccountPatchAuthResult.TanRequired + } else if (oldPw != null && !PwCrypto.checkpw(oldPw, currentPwh)) { + AccountPatchAuthResult.OldPasswordMismatch + } else { + val stmt = conn.prepareStatement(""" + UPDATE customers SET password_hash=? where login=? + """) + stmt.setString(1, PwCrypto.hashpw(newPw)) + stmt.setString(2, login) + stmt.executeUpdateCheck() + AccountPatchAuthResult.Success } } /** Get password hash of account [login] */ - suspend fun passwordHash(login: String): String? = db.conn { conn -> - val stmt = conn.prepareStatement(""" - SELECT password_hash FROM customers WHERE login=? AND deleted_at IS NULL - """) - stmt.setString(1, login) - stmt.oneOrNull { + suspend fun passwordHash(login: String): String? = db.serializableRead( + "SELECT password_hash FROM customers WHERE login=? AND deleted_at IS NULL" + ) { + setString(1, login) + oneOrNull { it.getString(1) } } /** Get bank info of account [login] */ - suspend fun bankInfo(login: String, ctx: BankPaytoCtx): BankInfo? = db.conn { conn -> - val stmt = conn.prepareStatement(""" - SELECT - bank_account_id - ,internal_payto_uri - ,name - ,is_taler_exchange - FROM bank_accounts - JOIN customers ON customer_id=owning_customer_id - WHERE login=? - """) - stmt.setString(1, login) - stmt.oneOrNull { + suspend fun bankInfo(login: String, ctx: BankPaytoCtx): BankInfo? = db.serializableRead( + """ + SELECT + bank_account_id + ,internal_payto_uri + ,name + ,is_taler_exchange + FROM bank_accounts + JOIN customers ON customer_id=owning_customer_id + WHERE login=? + """ + ) { + setString(1, login) + oneOrNull { BankInfo( payto = it.getBankPayto("internal_payto_uri", "name", ctx), isTalerExchange = it.getBoolean("is_taler_exchange"), @@ -476,35 +471,36 @@ class AccountDAO(private val db: Database) { } /** Get data of account [login] */ - suspend fun get(login: String, ctx: BankPaytoCtx): AccountData? = db.conn { conn -> - val stmt = conn.prepareStatement(""" - SELECT - name - ,email - ,phone - ,tan_channel - ,cashout_payto - ,internal_payto_uri - ,(balance).val AS balance_val - ,(balance).frac AS balance_frac - ,has_debt - ,(max_debt).val AS max_debt_val - ,(max_debt).frac AS max_debt_frac - ,(min_cashout).val AS min_cashout_val - ,(min_cashout).frac AS min_cashout_frac - ,is_public - ,is_taler_exchange - ,CASE - WHEN deleted_at IS NOT NULL THEN 'deleted' - ELSE 'active' - END as status - FROM customers - JOIN bank_accounts - ON customer_id=owning_customer_id - WHERE login=? - """) - stmt.setString(1, login) - stmt.oneOrNull { + suspend fun get(login: String, ctx: BankPaytoCtx): AccountData? = db.serializableRead( + """ + SELECT + name + ,email + ,phone + ,tan_channel + ,cashout_payto + ,internal_payto_uri + ,(balance).val AS balance_val + ,(balance).frac AS balance_frac + ,has_debt + ,(max_debt).val AS max_debt_val + ,(max_debt).frac AS max_debt_frac + ,(min_cashout).val AS min_cashout_val + ,(min_cashout).frac AS min_cashout_frac + ,is_public + ,is_taler_exchange + ,CASE + WHEN deleted_at IS NOT NULL THEN 'deleted' + ELSE 'active' + END as status + FROM customers + JOIN bank_accounts + ON customer_id=owning_customer_id + WHERE login=? + """ + ) { + setString(1, login) + oneOrNull { AccountData( name = it.getString("name"), contact_data = ChallengeContactData( diff --git a/bank/src/main/kotlin/tech/libeufin/bank/db/CashoutDAO.kt b/bank/src/main/kotlin/tech/libeufin/bank/db/CashoutDAO.kt @@ -1,6 +1,6 @@ /* * This file is part of LibEuFin. - * Copyright (C) 2023 Taler Systems S.A. + * Copyright (C) 2023-2024 Taler Systems S.A. * LibEuFin is free software; you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as @@ -21,10 +21,7 @@ package tech.libeufin.bank.db import tech.libeufin.bank.* import tech.libeufin.common.* -import tech.libeufin.common.db.getAmount -import tech.libeufin.common.db.getTalerTimestamp -import tech.libeufin.common.db.oneOrNull -import tech.libeufin.common.db.page +import tech.libeufin.common.db.* import java.time.Instant /** Data access logic for cashout operations */ @@ -51,33 +48,32 @@ class CashoutDAO(private val db: Database) { subject: String, timestamp: Instant, is2fa: Boolean - ): CashoutCreationResult = db.serializable { conn -> - val stmt = conn.prepareStatement(""" - SELECT - out_bad_conversion, - out_account_not_found, - out_account_is_exchange, - out_balance_insufficient, - out_request_uid_reuse, - out_no_cashout_payto, - out_tan_required, - out_cashout_id, - out_under_min - FROM cashout_create(?,?,(?,?)::taler_amount,(?,?)::taler_amount,?,?,?) - """) - stmt.setString(1, login) - stmt.setBytes(2, requestUid.raw) - stmt.setLong(3, amountDebit.value) - stmt.setInt(4, amountDebit.frac) - stmt.setLong(5, amountCredit.value) - stmt.setInt(6, amountCredit.frac) - stmt.setString(7, subject) - stmt.setLong(8, timestamp.micros()) - stmt.setBoolean(9, is2fa) - stmt.executeQuery().use { + ): CashoutCreationResult = db.serializableWrite( + """ + SELECT + out_bad_conversion, + out_account_not_found, + out_account_is_exchange, + out_balance_insufficient, + out_request_uid_reuse, + out_no_cashout_payto, + out_tan_required, + out_cashout_id, + out_under_min + FROM cashout_create(?,?,(?,?)::taler_amount,(?,?)::taler_amount,?,?,?) + """ + ) { + setString(1, login) + setBytes(2, requestUid.raw) + setLong(3, amountDebit.value) + setInt(4, amountDebit.frac) + setLong(5, amountCredit.value) + setInt(6, amountCredit.frac) + setString(7, subject) + setLong(8, timestamp.micros()) + setBoolean(9, is2fa) + one { when { - !it.next() -> - throw internalServerError("No result from DB procedure cashout_create") it.getBoolean("out_under_min") -> CashoutCreationResult.UnderMin it.getBoolean("out_bad_conversion") -> CashoutCreationResult.BadConversion it.getBoolean("out_account_not_found") -> CashoutCreationResult.AccountNotFound @@ -92,30 +88,31 @@ class CashoutDAO(private val db: Database) { } /** Get status of cashout operation [id] owned by [login] */ - suspend fun get(id: Long, login: String): CashoutStatusResponse? = db.conn { conn -> - val stmt = conn.prepareStatement(""" - SELECT - (amount_debit).val as amount_debit_val - ,(amount_debit).frac as amount_debit_frac - ,(amount_credit).val as amount_credit_val - ,(amount_credit).frac as amount_credit_frac - ,cashout_operations.subject - ,creation_time - ,transaction_date as confirmation_date - ,tan_channel - ,CASE tan_channel - WHEN 'sms' THEN phone - WHEN 'email' THEN email - END as tan_info - FROM cashout_operations - JOIN bank_accounts ON bank_account=bank_account_id - JOIN customers ON owning_customer_id=customer_id - LEFT JOIN bank_account_transactions ON local_transaction=bank_transaction_id - WHERE cashout_id=? AND login=? - """) - stmt.setLong(1, id) - stmt.setString(2, login) - stmt.oneOrNull { + suspend fun get(id: Long, login: String): CashoutStatusResponse? = db.serializableRead( + """ + SELECT + (amount_debit).val as amount_debit_val + ,(amount_debit).frac as amount_debit_frac + ,(amount_credit).val as amount_credit_val + ,(amount_credit).frac as amount_credit_frac + ,cashout_operations.subject + ,creation_time + ,transaction_date as confirmation_date + ,tan_channel + ,CASE tan_channel + WHEN 'sms' THEN phone + WHEN 'email' THEN email + END as tan_info + FROM cashout_operations + JOIN bank_accounts ON bank_account=bank_account_id + JOIN customers ON owning_customer_id=customer_id + LEFT JOIN bank_account_transactions ON local_transaction=bank_transaction_id + WHERE cashout_id=? AND login=? + """ + ) { + setLong(1, id) + setString(2, login) + oneOrNull { CashoutStatusResponse( status = CashoutStatus.confirmed, amount_debit = it.getAmount("amount_debit", db.bankCurrency), diff --git a/bank/src/main/kotlin/tech/libeufin/bank/db/ConversionDAO.kt b/bank/src/main/kotlin/tech/libeufin/bank/db/ConversionDAO.kt @@ -1,6 +1,6 @@ /* * This file is part of LibEuFin. - * Copyright (C) 2023 Taler Systems S.A. + * Copyright (C) 2023-2024 Taler Systems S.A. * LibEuFin is free software; you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as @@ -27,77 +27,75 @@ import tech.libeufin.common.db.* /** Data access logic for conversion */ class ConversionDAO(private val db: Database) { /** Update in-db conversion config */ - suspend fun updateConfig(cfg: ConversionRate) = db.serializable { - it.transaction { conn -> - var stmt = conn.prepareStatement("CALL config_set_amount(?, (?, ?)::taler_amount)") - for ((name, amount) in listOf( - Pair("cashin_ratio", cfg.cashin_ratio), - Pair("cashout_ratio", cfg.cashout_ratio), - )) { - stmt.setString(1, name) - stmt.setLong(2, amount.value) - stmt.setInt(3, amount.frac) - stmt.executeUpdate() - } - for ((name, amount) in listOf( - Pair("cashin_fee", cfg.cashin_fee), - Pair("cashin_tiny_amount", cfg.cashin_tiny_amount), - Pair("cashin_min_amount", cfg.cashin_min_amount), - Pair("cashout_fee", cfg.cashout_fee), - Pair("cashout_tiny_amount", cfg.cashout_tiny_amount), - Pair("cashout_min_amount", cfg.cashout_min_amount), - )) { - stmt.setString(1, name) - stmt.setLong(2, amount.value) - stmt.setInt(3, amount.frac) - stmt.executeUpdate() - } - stmt = conn.prepareStatement("CALL config_set_rounding_mode(?, ?::rounding_mode)") - for ((name, value) in listOf( - Pair("cashin_rounding_mode", cfg.cashin_rounding_mode), - Pair("cashout_rounding_mode", cfg.cashout_rounding_mode) - )) { - stmt.setString(1, name) - stmt.setString(2, value.name) - stmt.executeUpdate() - } + suspend fun updateConfig(cfg: ConversionRate) = db.serializableTransaction { conn -> + var stmt = conn.prepareStatement("CALL config_set_amount(?, (?, ?)::taler_amount)") + for ((name, amount) in listOf( + Pair("cashin_ratio", cfg.cashin_ratio), + Pair("cashout_ratio", cfg.cashout_ratio), + )) { + stmt.setString(1, name) + stmt.setLong(2, amount.value) + stmt.setInt(3, amount.frac) + stmt.executeUpdate() + } + for ((name, amount) in listOf( + Pair("cashin_fee", cfg.cashin_fee), + Pair("cashin_tiny_amount", cfg.cashin_tiny_amount), + Pair("cashin_min_amount", cfg.cashin_min_amount), + Pair("cashout_fee", cfg.cashout_fee), + Pair("cashout_tiny_amount", cfg.cashout_tiny_amount), + Pair("cashout_min_amount", cfg.cashout_min_amount), + )) { + stmt.setString(1, name) + stmt.setLong(2, amount.value) + stmt.setInt(3, amount.frac) + stmt.executeUpdate() + } + stmt = conn.prepareStatement("CALL config_set_rounding_mode(?, ?::rounding_mode)") + for ((name, value) in listOf( + Pair("cashin_rounding_mode", cfg.cashin_rounding_mode), + Pair("cashout_rounding_mode", cfg.cashout_rounding_mode) + )) { + stmt.setString(1, name) + stmt.setString(2, value.name) + stmt.executeUpdate() } } /** Get in-db conversion config */ - suspend fun getConfig(regional: String, fiat: String): ConversionRate? = db.conn { - it.transaction { conn -> - val check = conn.prepareStatement("select exists(select 1 from config where key='cashin_ratio')").oneOrNull { it.getBoolean(1) }!! - if (!check) return@transaction null - val amount = conn.prepareStatement("SELECT (amount).val as amount_val, (amount).frac as amount_frac FROM config_get_amount(?) as amount") - val roundingMode = conn.prepareStatement("SELECT config_get_rounding_mode(?)") - fun getAmount(name: String, currency: String): TalerAmount { - amount.setString(1, name) - return amount.oneOrNull { it.getAmount("amount", currency) }!! - } - fun getRatio(name: String): DecimalNumber = getAmount(name, "").run { DecimalNumber(value, frac) } - fun getMode(name: String): RoundingMode { - roundingMode.setString(1, name) - return roundingMode.oneOrNull { RoundingMode.valueOf(it.getString(1)) }!! - } - ConversionRate( - cashin_ratio = getRatio("cashin_ratio"), - cashin_fee = getAmount("cashin_fee", regional), - cashin_tiny_amount = getAmount("cashin_tiny_amount", regional), - cashin_rounding_mode = getMode("cashin_rounding_mode"), - cashin_min_amount = getAmount("cashin_min_amount", fiat), - cashout_ratio = getRatio("cashout_ratio"), - cashout_fee = getAmount("cashout_fee", fiat), - cashout_tiny_amount = getAmount("cashout_tiny_amount", fiat), - cashout_rounding_mode = getMode("cashout_rounding_mode"), - cashout_min_amount = getAmount("cashout_min_amount", regional), - ) + suspend fun getConfig(regional: String, fiat: String): ConversionRate? = db.serializableTransaction { conn -> + val check = conn.prepareStatement("select exists(select 1 from config where key='cashin_ratio')").oneOrNull { it.getBoolean(1) }!! + if (!check) return@serializableTransaction null + val amount = conn.prepareStatement("SELECT (amount).val as amount_val, (amount).frac as amount_frac FROM config_get_amount(?) as amount") + val roundingMode = conn.prepareStatement("SELECT config_get_rounding_mode(?)") + fun getAmount(name: String, currency: String): TalerAmount { + amount.setString(1, name) + return amount.oneOrNull { it.getAmount("amount", currency) }!! + } + fun getRatio(name: String): DecimalNumber = getAmount(name, "").run { DecimalNumber(value, frac) } + fun getMode(name: String): RoundingMode { + roundingMode.setString(1, name) + return roundingMode.oneOrNull { RoundingMode.valueOf(it.getString(1)) }!! } + ConversionRate( + cashin_ratio = getRatio("cashin_ratio"), + cashin_fee = getAmount("cashin_fee", regional), + cashin_tiny_amount = getAmount("cashin_tiny_amount", regional), + cashin_rounding_mode = getMode("cashin_rounding_mode"), + cashin_min_amount = getAmount("cashin_min_amount", fiat), + cashout_ratio = getRatio("cashout_ratio"), + cashout_fee = getAmount("cashout_fee", fiat), + cashout_tiny_amount = getAmount("cashout_tiny_amount", fiat), + cashout_rounding_mode = getMode("cashout_rounding_mode"), + cashout_min_amount = getAmount("cashout_min_amount", regional), + ) } /** Clear in-db conversion config */ - suspend fun clearConfig() = db.serializable { conn -> - conn.prepareStatement("DELETE FROM config WHERE key LIKE 'cashin%' OR key like 'cashout%'").executeUpdate() + suspend fun clearConfig() = db.serializableWrite( + "DELETE FROM config WHERE key LIKE 'cashin%' OR key like 'cashout%'" + ) { + executeUpdate() } /** Result of conversions operations */ @@ -108,15 +106,14 @@ class ConversionDAO(private val db: Database) { } /** Perform [direction] conversion of [amount] using in-db [function] */ - private suspend fun conversion(amount: TalerAmount, direction: String, function: String): ConversionResult = db.conn { conn -> - val stmt = conn.prepareStatement("SELECT too_small, no_config, (converted).val AS amount_val, (converted).frac AS amount_frac FROM $function((?, ?)::taler_amount, ?, (0, 0)::taler_amount)") - stmt.setLong(1, amount.value) - stmt.setInt(2, amount.frac) - stmt.setString(3, direction) - stmt.executeQuery().use { + private suspend fun conversion(amount: TalerAmount, direction: String, function: String): ConversionResult = db.serializableRead( + "SELECT too_small, no_config, (converted).val AS amount_val, (converted).frac AS amount_frac FROM $function((?, ?)::taler_amount, ?, (0, 0)::taler_amount)" + ) { + setLong(1, amount.value) + setInt(2, amount.frac) + setString(3, direction) + one { when { - !it.next() -> - throw internalServerError("No result from DB procedure $function") it.getBoolean("no_config") -> ConversionResult.MissingConfig it.getBoolean("too_small") -> ConversionResult.ToSmall else -> ConversionResult.Success( diff --git a/bank/src/main/kotlin/tech/libeufin/bank/db/Database.kt b/bank/src/main/kotlin/tech/libeufin/bank/db/Database.kt @@ -105,30 +105,31 @@ class Database(dbConfig: DatabaseConfig, internal val bankCurrency: String, inte suspend fun monitor( params: MonitorParams - ): MonitorResponse = conn { conn -> - val stmt = conn.prepareStatement(""" - SELECT - cashin_count - ,(cashin_regional_volume).val as cashin_regional_volume_val - ,(cashin_regional_volume).frac as cashin_regional_volume_frac - ,(cashin_fiat_volume).val as cashin_fiat_volume_val - ,(cashin_fiat_volume).frac as cashin_fiat_volume_frac - ,cashout_count - ,(cashout_regional_volume).val as cashout_regional_volume_val - ,(cashout_regional_volume).frac as cashout_regional_volume_frac - ,(cashout_fiat_volume).val as cashout_fiat_volume_val - ,(cashout_fiat_volume).frac as cashout_fiat_volume_frac - ,taler_in_count - ,(taler_in_volume).val as taler_in_volume_val - ,(taler_in_volume).frac as taler_in_volume_frac - ,taler_out_count - ,(taler_out_volume).val as taler_out_volume_val - ,(taler_out_volume).frac as taler_out_volume_frac - FROM stats_get_frame(?::timestamp, ?::stat_timeframe_enum) - """) - stmt.setObject(1, params.date) - stmt.setString(2, params.timeframe.name) - stmt.oneOrNull { + ): MonitorResponse = serializableRead( + """ + SELECT + cashin_count + ,(cashin_regional_volume).val as cashin_regional_volume_val + ,(cashin_regional_volume).frac as cashin_regional_volume_frac + ,(cashin_fiat_volume).val as cashin_fiat_volume_val + ,(cashin_fiat_volume).frac as cashin_fiat_volume_frac + ,cashout_count + ,(cashout_regional_volume).val as cashout_regional_volume_val + ,(cashout_regional_volume).frac as cashout_regional_volume_frac + ,(cashout_fiat_volume).val as cashout_fiat_volume_val + ,(cashout_fiat_volume).frac as cashout_fiat_volume_frac + ,taler_in_count + ,(taler_in_volume).val as taler_in_volume_val + ,(taler_in_volume).frac as taler_in_volume_frac + ,taler_out_count + ,(taler_out_volume).val as taler_out_volume_val + ,(taler_out_volume).frac as taler_out_volume_frac + FROM stats_get_frame(?::timestamp, ?::stat_timeframe_enum) + """ + ) { + setObject(1, params.date) + setString(2, params.timeframe.name) + oneOrNull { fiatCurrency?.run { MonitorWithConversion( cashinCount = it.getLong("cashin_count"), diff --git a/bank/src/main/kotlin/tech/libeufin/bank/db/ExchangeDAO.kt b/bank/src/main/kotlin/tech/libeufin/bank/db/ExchangeDAO.kt @@ -102,40 +102,39 @@ class ExchangeDAO(private val db: Database) { req: TransferRequest, login: String, timestamp: Instant - ): TransferResult = db.serializable { conn -> + ): TransferResult = db.serializableWrite( + """ + SELECT + out_debtor_not_found + ,out_debtor_not_exchange + ,out_creditor_not_found + ,out_both_exchanges + ,out_request_uid_reuse + ,out_exchange_balance_insufficient + ,out_tx_row_id + ,out_timestamp + FROM + taler_transfer ( + ?, ?, ?, + (?,?)::taler_amount, + ?, ?, ?, ? + ); + """ + ) { val subject = "${req.wtid} ${req.exchange_base_url.url}" - val stmt = conn.prepareStatement(""" - SELECT - out_debtor_not_found - ,out_debtor_not_exchange - ,out_creditor_not_found - ,out_both_exchanges - ,out_request_uid_reuse - ,out_exchange_balance_insufficient - ,out_tx_row_id - ,out_timestamp - FROM - taler_transfer ( - ?, ?, ?, - (?,?)::taler_amount, - ?, ?, ?, ? - ); - """) - stmt.setBytes(1, req.request_uid.raw) - stmt.setBytes(2, req.wtid.raw) - stmt.setString(3, subject) - stmt.setLong(4, req.amount.value) - stmt.setInt(5, req.amount.frac) - stmt.setString(6, req.exchange_base_url.url) - stmt.setString(7, req.credit_account.canonical) - stmt.setString(8, login) - stmt.setLong(9, timestamp.micros()) + setBytes(1, req.request_uid.raw) + setBytes(2, req.wtid.raw) + setString(3, subject) + setLong(4, req.amount.value) + setInt(5, req.amount.frac) + setString(6, req.exchange_base_url.url) + setString(7, req.credit_account.canonical) + setString(8, login) + setLong(9, timestamp.micros()) - stmt.executeQuery().use { + one { when { - !it.next() -> - throw internalServerError("SQL function taler_transfer did not return anything.") it.getBoolean("out_debtor_not_found") -> TransferResult.UnknownExchange it.getBoolean("out_debtor_not_exchange") -> TransferResult.NotAnExchange it.getBoolean("out_creditor_not_found") -> TransferResult.UnknownCreditor @@ -167,33 +166,33 @@ class ExchangeDAO(private val db: Database) { req: AddIncomingRequest, login: String, timestamp: Instant - ): AddIncomingResult = db.serializable { conn -> - val stmt = conn.prepareStatement(""" - SELECT - out_creditor_not_found - ,out_creditor_not_exchange - ,out_debtor_not_found - ,out_both_exchanges - ,out_reserve_pub_reuse - ,out_debitor_balance_insufficient - ,out_tx_row_id - FROM + ): AddIncomingResult = db.serializableWrite( + """ + SELECT + out_creditor_not_found + ,out_creditor_not_exchange + ,out_debtor_not_found + ,out_both_exchanges + ,out_reserve_pub_reuse + ,out_debitor_balance_insufficient + ,out_tx_row_id + FROM taler_add_incoming ( ?, ?, (?,?)::taler_amount, ?, ?, ? - ); - """) - - stmt.setBytes(1, req.reserve_pub.raw) - stmt.setString(2, "Manual incoming ${req.reserve_pub}") - stmt.setLong(3, req.amount.value) - stmt.setInt(4, req.amount.frac) - stmt.setString(5, req.debit_account.canonical) - stmt.setString(6, login) - stmt.setLong(7, timestamp.micros()) + ); + """ + ) { + setBytes(1, req.reserve_pub.raw) + setString(2, "Manual incoming ${req.reserve_pub}") + setLong(3, req.amount.value) + setInt(4, req.amount.frac) + setString(5, req.debit_account.canonical) + setString(6, login) + setLong(7, timestamp.micros()) - stmt.one { + one { when { it.getBoolean("out_creditor_not_found") -> AddIncomingResult.UnknownExchange it.getBoolean("out_creditor_not_exchange") -> AddIncomingResult.NotAnExchange diff --git a/bank/src/main/kotlin/tech/libeufin/bank/db/TanDAO.kt b/bank/src/main/kotlin/tech/libeufin/bank/db/TanDAO.kt @@ -1,6 +1,6 @@ /* * This file is part of LibEuFin. - * Copyright (C) 2023 Taler Systems S.A. + * Copyright (C) 2023-2024 Taler Systems S.A. * LibEuFin is free software; you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as @@ -22,6 +22,7 @@ package tech.libeufin.bank.db import tech.libeufin.bank.Operation import tech.libeufin.bank.TanChannel import tech.libeufin.common.db.oneOrNull +import tech.libeufin.common.db.one import tech.libeufin.common.internalServerError import tech.libeufin.common.micros import java.time.Duration @@ -41,18 +42,19 @@ class TanDAO(private val db: Database) { validityPeriod: Duration, channel: TanChannel? = null, info: String? = null - ): Long = db.serializable { conn -> - val stmt = conn.prepareStatement("SELECT tan_challenge_create(?,?::op_enum,?,?,?,?,?,?::tan_enum,?)") - stmt.setString(1, body) - stmt.setString(2, op.name) - stmt.setString(3, code) - stmt.setLong(4, timestamp.micros()) - stmt.setLong(5, TimeUnit.MICROSECONDS.convert(validityPeriod)) - stmt.setInt(6, retryCounter) - stmt.setString(7, login) - stmt.setString(8, channel?.name) - stmt.setString(9, info) - stmt.oneOrNull { + ): Long = db.serializableWrite( + "SELECT tan_challenge_create(?,?::op_enum,?,?,?,?,?,?::tan_enum,?)" + ) { + setString(1, body) + setString(2, op.name) + setString(3, code) + setLong(4, timestamp.micros()) + setLong(5, TimeUnit.MICROSECONDS.convert(validityPeriod)) + setInt(6, retryCounter) + setString(7, login) + setString(8, channel?.name) + setString(9, info) + oneOrNull { it.getLong(1) } ?: throw internalServerError("TAN challenge returned nothing.") } @@ -71,17 +73,17 @@ class TanDAO(private val db: Database) { timestamp: Instant, retryCounter: Int, validityPeriod: Duration - ) = db.serializable { conn -> - val stmt = conn.prepareStatement("SELECT out_no_op, out_tan_code, out_tan_channel, out_tan_info FROM tan_challenge_send(?,?,?,?,?,?)") - stmt.setLong(1, id) - stmt.setString(2, login) - stmt.setString(3, code) - stmt.setLong(4, timestamp.micros()) - stmt.setLong(5, TimeUnit.MICROSECONDS.convert(validityPeriod)) - stmt.setInt(6, retryCounter) - stmt.executeQuery().use { + ) = db.serializableWrite( + "SELECT out_no_op, out_tan_code, out_tan_channel, out_tan_info FROM tan_challenge_send(?,?,?,?,?,?)" + ) { + setLong(1, id) + setString(2, login) + setString(3, code) + setLong(4, timestamp.micros()) + setLong(5, TimeUnit.MICROSECONDS.convert(validityPeriod)) + setInt(6, retryCounter) + one { when { - !it.next() -> throw internalServerError("TAN send returned nothing.") it.getBoolean("out_no_op") -> TanSendResult.NotFound else -> TanSendResult.Success( tanInfo = it.getString("out_tan_info"), @@ -97,12 +99,13 @@ class TanDAO(private val db: Database) { id: Long, timestamp: Instant, retransmissionPeriod: Duration - ) = db.serializable { conn -> - val stmt = conn.prepareStatement("SELECT tan_challenge_mark_sent(?,?,?)") - stmt.setLong(1, id) - stmt.setLong(2, timestamp.micros()) - stmt.setLong(3, TimeUnit.MICROSECONDS.convert(retransmissionPeriod)) - stmt.executeQuery() + ) = db.serializableWrite( + "SELECT tan_challenge_mark_sent(?,?,?)" + ) { + setLong(1, id) + setLong(2, timestamp.micros()) + setLong(3, TimeUnit.MICROSECONDS.convert(retransmissionPeriod)) + executeQuery() } /** Result of TAN challenge solution */ @@ -120,19 +123,20 @@ class TanDAO(private val db: Database) { login: String, code: String, timestamp: Instant - ) = db.serializable { conn -> - val stmt = conn.prepareStatement(""" - SELECT - out_ok, out_no_op, out_no_retry, out_expired, - out_body, out_op, out_channel, out_info - FROM tan_challenge_try(?,?,?,?)""") - stmt.setLong(1, id) - stmt.setString(2, login) - stmt.setString(3, code) - stmt.setLong(4, timestamp.micros()) - stmt.executeQuery().use { + ) = db.serializableWrite( + """ + SELECT + out_ok, out_no_op, out_no_retry, out_expired, + out_body, out_op, out_channel, out_info + FROM tan_challenge_try(?,?,?,?) + """ + ) { + setLong(1, id) + setString(2, login) + setString(3, code) + setLong(4, timestamp.micros()) + one { when { - !it.next() -> throw internalServerError("TAN try returned nothing") it.getBoolean("out_ok") -> TanSolveResult.Success( body = it.getString("out_body"), op = Operation.valueOf(it.getString("out_op")), @@ -158,17 +162,18 @@ class TanDAO(private val db: Database) { id: Long, login: String, op: Operation - ) = db.serializable { conn -> - val stmt = conn.prepareStatement(""" - SELECT body, tan_challenges.tan_channel, tan_info - FROM tan_challenges - JOIN customers ON customer=customer_id - WHERE challenge_id=? AND op=?::op_enum AND login=? AND deleted_at IS NULL - """) - stmt.setLong(1, id) - stmt.setString(2, op.name) - stmt.setString(3, login) - stmt.oneOrNull { + ) = db.serializableWrite( + """ + SELECT body, tan_challenges.tan_channel, tan_info + FROM tan_challenges + JOIN customers ON customer=customer_id + WHERE challenge_id=? AND op=?::op_enum AND login=? AND deleted_at IS NULL + """ + ) { + setLong(1, id) + setString(2, op.name) + setString(3, login) + oneOrNull { Challenge( body = it.getString("body"), channel = it.getString("tan_channel")?.run { TanChannel.valueOf(this) }, diff --git a/bank/src/main/kotlin/tech/libeufin/bank/db/TokenDAO.kt b/bank/src/main/kotlin/tech/libeufin/bank/db/TokenDAO.kt @@ -1,6 +1,6 @@ /* * This file is part of LibEuFin. - * Copyright (C) 2023 Taler Systems S.A. + * Copyright (C) 2023-2024 Taler Systems S.A. * LibEuFin is free software; you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as @@ -42,50 +42,53 @@ class TokenDAO(private val db: Database) { scope: TokenScope, isRefreshable: Boolean, description: String? - ): Boolean = db.serializable { conn -> - val stmt = conn.prepareStatement(""" - INSERT INTO bearer_tokens ( - content, - creation_time, - expiration_time, - scope, - bank_customer, - is_refreshable, - description, - last_access - ) VALUES ( - ?,?,?,?::token_scope_enum, - (SELECT customer_id FROM customers WHERE login=? AND deleted_at IS NULL), - ?,?,?) - """) - stmt.setBytes(1, content) - stmt.setLong(2, creationTime.micros()) - stmt.setLong(3, expirationTime.micros()) - stmt.setString(4, scope.name) - stmt.setString(5, login) - stmt.setBoolean(6, isRefreshable) - stmt.setString(7, description) - stmt.setLong(8, creationTime.micros()) - stmt.executeUpdateViolation() + ): Boolean = db.serializableWrite( + """ + INSERT INTO bearer_tokens ( + content, + creation_time, + expiration_time, + scope, + bank_customer, + is_refreshable, + description, + last_access + ) VALUES ( + ?,?,?,?::token_scope_enum, + (SELECT customer_id FROM customers WHERE login=? AND deleted_at IS NULL), + ?,?,? + ) + """ + ) { + setBytes(1, content) + setLong(2, creationTime.micros()) + setLong(3, expirationTime.micros()) + setString(4, scope.name) + setString(5, login) + setBoolean(6, isRefreshable) + setString(7, description) + setLong(8, creationTime.micros()) + executeUpdateViolation() } /** Get info for [token] */ - suspend fun access(token: ByteArray, accessTime: Instant): BearerToken? = db.conn { conn -> - val stmt = conn.prepareStatement(""" - UPDATE bearer_tokens - SET last_access=? - FROM customers - WHERE bank_customer=customer_id AND content=? AND deleted_at IS NULL - RETURNING - creation_time, - expiration_time, - login, - scope, - is_refreshable - """) - stmt.setLong(1, accessTime.micros()) - stmt.setBytes(2, token) - stmt.oneOrNull { + suspend fun access(token: ByteArray, accessTime: Instant): BearerToken? = db.serializableWrite( + """ + UPDATE bearer_tokens + SET last_access=? + FROM customers + WHERE bank_customer=customer_id AND content=? AND deleted_at IS NULL + RETURNING + creation_time, + expiration_time, + login, + scope, + is_refreshable + """ + ) { + setLong(1, accessTime.micros()) + setBytes(2, token) + oneOrNull { BearerToken( creationTime = it.getLong("creation_time").asInstant(), expirationTime = it.getLong("expiration_time").asInstant(), @@ -97,12 +100,11 @@ class TokenDAO(private val db: Database) { } /** Delete token [token] */ - suspend fun delete(token: ByteArray) = db.serializable { conn -> - val stmt = conn.prepareStatement(""" - DELETE FROM bearer_tokens WHERE content = ? - """) - stmt.setBytes(1, token) - stmt.execute() + suspend fun delete(token: ByteArray) = db.serializableWrite( + "DELETE FROM bearer_tokens WHERE content = ?" + ) { + setBytes(1, token) + execute() } /** Get a page of all public accounts */ diff --git a/bank/src/main/kotlin/tech/libeufin/bank/db/TransactionDAO.kt b/bank/src/main/kotlin/tech/libeufin/bank/db/TransactionDAO.kt @@ -52,127 +52,125 @@ class TransactionDAO(private val db: Database) { is2fa: Boolean, requestUid: ShortHashCode?, wireTransferFees: TalerAmount - ): BankTransactionResult = db.serializable { conn -> + ): BankTransactionResult = db.serializableTransaction { conn -> val timestamp = timestamp.micros() - conn.transaction { - val stmt = conn.prepareStatement(""" - SELECT - out_creditor_not_found - ,out_debtor_not_found - ,out_same_account - ,out_balance_insufficient - ,out_request_uid_reuse - ,out_tan_required - ,out_credit_bank_account_id - ,out_debit_bank_account_id - ,out_credit_row_id - ,out_debit_row_id - ,out_creditor_is_exchange - ,out_debtor_is_exchange - ,out_creditor_admin - ,out_idempotent - FROM bank_transaction(?,?,?,(?,?)::taler_amount,?,?,?,(?,?)::taler_amount) - """ - ) - stmt.setString(1, creditAccountPayto.canonical) - stmt.setString(2, debitAccountUsername) - stmt.setString(3, subject) - stmt.setLong(4, amount.value) - stmt.setInt(5, amount.frac) - stmt.setLong(6, timestamp) - stmt.setBoolean(7, is2fa) - stmt.setBytes(8, requestUid?.raw) - stmt.setLong(9, wireTransferFees.value) - stmt.setInt(10, wireTransferFees.frac) - stmt.executeQuery().use { - when { - !it.next() -> throw internalServerError("Bank transaction didn't properly return") - it.getBoolean("out_creditor_not_found") -> BankTransactionResult.UnknownCreditor - it.getBoolean("out_debtor_not_found") -> BankTransactionResult.UnknownDebtor - it.getBoolean("out_same_account") -> BankTransactionResult.BothPartySame - it.getBoolean("out_balance_insufficient") -> BankTransactionResult.BalanceInsufficient - it.getBoolean("out_creditor_admin") -> BankTransactionResult.AdminCreditor - it.getBoolean("out_request_uid_reuse") -> BankTransactionResult.RequestUidReuse - it.getBoolean("out_idempotent") -> BankTransactionResult.Success(it.getLong("out_debit_row_id")) - it.getBoolean("out_tan_required") -> BankTransactionResult.TanRequired - else -> { - val creditAccountId = it.getLong("out_credit_bank_account_id") - val creditRowId = it.getLong("out_credit_row_id") - val debitAccountId = it.getLong("out_debit_bank_account_id") - val debitRowId = it.getLong("out_debit_row_id") - val exchangeCreditor = it.getBoolean("out_creditor_is_exchange") - val exchangeDebtor = it.getBoolean("out_debtor_is_exchange") - if (exchangeCreditor && exchangeDebtor) { - logger.warn("exchange account $exchangeDebtor sent a manual transaction to exchange account $exchangeCreditor, this should never happens and is not bounced to prevent bouncing loop, may fail in the future") - } else if (exchangeCreditor) { - val bounceCause = runCatching { parseIncomingTxMetadata(subject) }.fold( - onSuccess = { reservePub -> - val registered = conn.prepareStatement("CALL register_incoming(?, ?)").run { - setBytes(1, reservePub.raw) - setLong(2, creditRowId) - executeProcedureViolation() - } - if (!registered) { - logger.warn("exchange account $creditAccountId received an incoming taler transaction $creditRowId with an already used reserve public key") - "reserve public key reuse" - } else { - null - } - }, - onFailure = { e -> - logger.warn("exchange account $creditAccountId received a manual transaction $creditRowId with malformed metadata: ${e.message}") - "malformed metadata: ${e.message}" + val stmt = conn.prepareStatement(""" + SELECT + out_creditor_not_found + ,out_debtor_not_found + ,out_same_account + ,out_balance_insufficient + ,out_request_uid_reuse + ,out_tan_required + ,out_credit_bank_account_id + ,out_debit_bank_account_id + ,out_credit_row_id + ,out_debit_row_id + ,out_creditor_is_exchange + ,out_debtor_is_exchange + ,out_creditor_admin + ,out_idempotent + FROM bank_transaction(?,?,?,(?,?)::taler_amount,?,?,?,(?,?)::taler_amount) + """ + ) + stmt.setString(1, creditAccountPayto.canonical) + stmt.setString(2, debitAccountUsername) + stmt.setString(3, subject) + stmt.setLong(4, amount.value) + stmt.setInt(5, amount.frac) + stmt.setLong(6, timestamp) + stmt.setBoolean(7, is2fa) + stmt.setBytes(8, requestUid?.raw) + stmt.setLong(9, wireTransferFees.value) + stmt.setInt(10, wireTransferFees.frac) + stmt.one { + when { + it.getBoolean("out_creditor_not_found") -> BankTransactionResult.UnknownCreditor + it.getBoolean("out_debtor_not_found") -> BankTransactionResult.UnknownDebtor + it.getBoolean("out_same_account") -> BankTransactionResult.BothPartySame + it.getBoolean("out_balance_insufficient") -> BankTransactionResult.BalanceInsufficient + it.getBoolean("out_creditor_admin") -> BankTransactionResult.AdminCreditor + it.getBoolean("out_request_uid_reuse") -> BankTransactionResult.RequestUidReuse + it.getBoolean("out_idempotent") -> BankTransactionResult.Success(it.getLong("out_debit_row_id")) + it.getBoolean("out_tan_required") -> BankTransactionResult.TanRequired + else -> { + val creditAccountId = it.getLong("out_credit_bank_account_id") + val creditRowId = it.getLong("out_credit_row_id") + val debitAccountId = it.getLong("out_debit_bank_account_id") + val debitRowId = it.getLong("out_debit_row_id") + val exchangeCreditor = it.getBoolean("out_creditor_is_exchange") + val exchangeDebtor = it.getBoolean("out_debtor_is_exchange") + if (exchangeCreditor && exchangeDebtor) { + logger.warn("exchange account $exchangeDebtor sent a manual transaction to exchange account $exchangeCreditor, this should never happens and is not bounced to prevent bouncing loop, may fail in the future") + } else if (exchangeCreditor) { + val bounceCause = runCatching { parseIncomingTxMetadata(subject) }.fold( + onSuccess = { reservePub -> + val registered = conn.prepareStatement("CALL register_incoming(?, ?)").run { + setBytes(1, reservePub.raw) + setLong(2, creditRowId) + executeProcedureViolation() } - ) - if (bounceCause != null) { - // No error can happens because an opposite transaction already took place in the same transaction - conn.prepareStatement(""" - SELECT bank_wire_transfer( - ?, ?, ?, (?, ?)::taler_amount, ?, (0, 0)::taler_amount - ); - """ - ).run { - setLong(1, debitAccountId) - setLong(2, creditAccountId) - setString(3, "Bounce $creditRowId: $bounceCause") - setLong(4, amount.value) - setInt(5, amount.frac) - setLong(6, timestamp) - executeQuery() + if (!registered) { + logger.warn("exchange account $creditAccountId received an incoming taler transaction $creditRowId with an already used reserve public key") + "reserve public key reuse" + } else { + null } + }, + onFailure = { e -> + logger.warn("exchange account $creditAccountId received a manual transaction $creditRowId with malformed metadata: ${e.message}") + "malformed metadata: ${e.message}" + } + ) + if (bounceCause != null) { + // No error can happens because an opposite transaction already took place in the same transaction + conn.prepareStatement(""" + SELECT bank_wire_transfer( + ?, ?, ?, (?, ?)::taler_amount, ?, (0, 0)::taler_amount + ); + """ + ).run { + setLong(1, debitAccountId) + setLong(2, creditAccountId) + setString(3, "Bounce $creditRowId: $bounceCause") + setLong(4, amount.value) + setInt(5, amount.frac) + setLong(6, timestamp) + executeQuery() } - } else if (exchangeDebtor) { - logger.warn("exchange account $debitAccountId sent a manual transaction $debitRowId which will not be recorderd as a taler outgoing transaction, use the API instead") } - BankTransactionResult.Success(debitRowId) + } else if (exchangeDebtor) { + logger.warn("exchange account $debitAccountId sent a manual transaction $debitRowId which will not be recorderd as a taler outgoing transaction, use the API instead") } + BankTransactionResult.Success(debitRowId) } } } } /** Get transaction [rowId] owned by [login] */ - suspend fun get(rowId: Long, login: String, ctx: BankPaytoCtx): BankAccountTransactionInfo? = db.conn { conn -> - val stmt = conn.prepareStatement(""" - SELECT - creditor_payto_uri - ,creditor_name - ,debtor_payto_uri - ,debtor_name - ,subject - ,(amount).val AS amount_val - ,(amount).frac AS amount_frac - ,transaction_date - ,direction - ,bank_transaction_id - FROM bank_account_transactions - JOIN bank_accounts ON bank_account_transactions.bank_account_id=bank_accounts.bank_account_id - JOIN customers ON customer_id=owning_customer_id - WHERE bank_transaction_id=? AND login=? - """) - stmt.setLong(1, rowId) - stmt.setString(2, login) - stmt.oneOrNull { + suspend fun get(rowId: Long, login: String, ctx: BankPaytoCtx): BankAccountTransactionInfo? = db.serializableRead( + """ + SELECT + creditor_payto_uri + ,creditor_name + ,debtor_payto_uri + ,debtor_name + ,subject + ,(amount).val AS amount_val + ,(amount).frac AS amount_frac + ,transaction_date + ,direction + ,bank_transaction_id + FROM bank_account_transactions + JOIN bank_accounts ON bank_account_transactions.bank_account_id=bank_accounts.bank_account_id + JOIN customers ON customer_id=owning_customer_id + WHERE bank_transaction_id=? AND login=? + """ + ) { + setLong(1, rowId) + setString(2, login) + oneOrNull { BankAccountTransactionInfo( creditor_payto_uri = it.getBankPayto("creditor_payto_uri", "creditor_name", ctx), debtor_payto_uri = it.getBankPayto("debtor_payto_uri", "debtor_name", ctx), diff --git a/bank/src/main/kotlin/tech/libeufin/bank/db/WithdrawalDAO.kt b/bank/src/main/kotlin/tech/libeufin/bank/db/WithdrawalDAO.kt @@ -1,6 +1,6 @@ /* * This file is part of LibEuFin. - * Copyright (C) 2023 Taler Systems S.A. + * Copyright (C) 2023-2024 Taler Systems S.A. * LibEuFin is free software; you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as @@ -47,39 +47,38 @@ class WithdrawalDAO(private val db: Database) { suggested_amount: TalerAmount?, timestamp: Instant, wireTransferFees: TalerAmount, - ): WithdrawalCreationResult = db.serializable { conn -> - val stmt = conn.prepareStatement(""" - SELECT - out_account_not_found, - out_account_is_exchange, - out_balance_insufficient - FROM create_taler_withdrawal( - ?,?, - ${if (amount != null) "(?,?)::taler_amount" else "NULL"}, - ${if (suggested_amount != null) "(?,?)::taler_amount" else "NULL"}, - ?, (?, ?)::taler_amount - ); - """) - stmt.setString(1, login) - stmt.setObject(2, uuid) + ): WithdrawalCreationResult = db.serializableWrite( + """ + SELECT + out_account_not_found, + out_account_is_exchange, + out_balance_insufficient + FROM create_taler_withdrawal( + ?,?, + ${if (amount != null) "(?,?)::taler_amount" else "NULL"}, + ${if (suggested_amount != null) "(?,?)::taler_amount" else "NULL"}, + ?, (?, ?)::taler_amount + ); + """ + ) { + setString(1, login) + setObject(2, uuid) var id = 3 if (amount != null) { - stmt.setLong(id, amount.value) - stmt.setInt(id+1, amount.frac) + setLong(id, amount.value) + setInt(id+1, amount.frac) id += 2 } if (suggested_amount != null) { - stmt.setLong(id, suggested_amount.value) - stmt.setInt(id+1, suggested_amount.frac) + setLong(id, suggested_amount.value) + setInt(id+1, suggested_amount.frac) id += 2 } - stmt.setLong(id, timestamp.micros()) - stmt.setLong(id+1, wireTransferFees.value) - stmt.setInt(id+2, wireTransferFees.frac) - stmt.executeQuery().use { + setLong(id, timestamp.micros()) + setLong(id+1, wireTransferFees.value) + setInt(id+2, wireTransferFees.frac) + one { when { - !it.next() -> - throw internalServerError("No result from DB procedure create_taler_withdrawal") it.getBoolean("out_account_not_found") -> WithdrawalCreationResult.UnknownAccount it.getBoolean("out_account_is_exchange") -> WithdrawalCreationResult.AccountIsExchange it.getBoolean("out_balance_insufficient") -> WithdrawalCreationResult.BalanceInsufficient @@ -89,18 +88,17 @@ class WithdrawalDAO(private val db: Database) { } /** Abort withdrawal operation [uuid] */ - suspend fun abort(uuid: UUID): AbortResult = db.serializable { conn -> - val stmt = conn.prepareStatement(""" - SELECT - out_no_op, - out_already_confirmed - FROM abort_taler_withdrawal(?) - """) - stmt.setObject(1, uuid) - stmt.executeQuery().use { + suspend fun abort(uuid: UUID): AbortResult = db.serializableWrite( + """ + SELECT + out_no_op, + out_already_confirmed + FROM abort_taler_withdrawal(?) + """ + ) { + setObject(1, uuid) + one { when { - !it.next() -> - throw internalServerError("No result from DB procedure abort_taler_withdrawal") it.getBoolean("out_no_op") -> AbortResult.UnknownOperation it.getBoolean("out_already_confirmed") -> AbortResult.AlreadyConfirmed else -> AbortResult.Success @@ -128,41 +126,39 @@ class WithdrawalDAO(private val db: Database) { reservePub: EddsaPublicKey, amount: TalerAmount?, wireTransferFees: TalerAmount, - ): WithdrawalSelectionResult = db.serializable { conn -> - val stmt = conn.prepareStatement(""" - SELECT - out_no_op, - out_already_selected, - out_reserve_pub_reuse, - out_account_not_found, - out_account_is_not_exchange, - out_status, - out_missing_amount, - out_amount_differs, - out_balance_insufficient - FROM select_taler_withdrawal( - ?, ?, ?, ?, - ${if (amount != null) "(?, ?)::taler_amount" else "NULL"}, - (?,?)::taler_amount - ); + ): WithdrawalSelectionResult = db.serializableWrite( """ - ) - stmt.setObject(1, uuid) - stmt.setBytes(2, reservePub.raw) - stmt.setString(3, "Taler withdrawal $reservePub") - stmt.setString(4, exchangePayto.canonical) + SELECT + out_no_op, + out_already_selected, + out_reserve_pub_reuse, + out_account_not_found, + out_account_is_not_exchange, + out_status, + out_missing_amount, + out_amount_differs, + out_balance_insufficient + FROM select_taler_withdrawal( + ?, ?, ?, ?, + ${if (amount != null) "(?, ?)::taler_amount" else "NULL"}, + (?,?)::taler_amount + ); + """ + ) { + setObject(1, uuid) + setBytes(2, reservePub.raw) + setString(3, "Taler withdrawal $reservePub") + setString(4, exchangePayto.canonical) var id = 5 if (amount != null) { - stmt.setLong(id, amount.value) - stmt.setInt(id+1, amount.frac) + setLong(id, amount.value) + setInt(id+1, amount.frac) id += 2 } - stmt.setLong(id, wireTransferFees.value) - stmt.setInt(id+1, wireTransferFees.frac) - stmt.executeQuery().use { + setLong(id, wireTransferFees.value) + setInt(id+1, wireTransferFees.frac) + one { when { - !it.next() -> - throw internalServerError("No result from DB procedure select_taler_withdrawal") it.getBoolean("out_balance_insufficient") -> WithdrawalSelectionResult.BalanceInsufficient it.getBoolean("out_no_op") -> WithdrawalSelectionResult.UnknownOperation it.getBoolean("out_already_selected") -> WithdrawalSelectionResult.AlreadySelected @@ -194,28 +190,26 @@ class WithdrawalDAO(private val db: Database) { wireTransferFees: TalerAmount, timestamp: Instant, is2fa: Boolean - ): WithdrawalConfirmationResult = db.serializable { conn -> - val stmt = conn.prepareStatement(""" - SELECT - out_no_op, - out_exchange_not_found, - out_balance_insufficient, - out_not_selected, - out_aborted, - out_tan_required - FROM confirm_taler_withdrawal(?,?,?,?,(?,?)::taler_amount); + ): WithdrawalConfirmationResult = db.serializableWrite( + """ + SELECT + out_no_op, + out_exchange_not_found, + out_balance_insufficient, + out_not_selected, + out_aborted, + out_tan_required + FROM confirm_taler_withdrawal(?,?,?,?,(?,?)::taler_amount); """ - ) - stmt.setString(1, login) - stmt.setObject(2, uuid) - stmt.setLong(3, timestamp.micros()) - stmt.setBoolean(4, is2fa) - stmt.setLong(5, wireTransferFees.value) - stmt.setInt(6, wireTransferFees.frac) - stmt.executeQuery().use { + ) { + setString(1, login) + setObject(2, uuid) + setLong(3, timestamp.micros()) + setBoolean(4, is2fa) + setLong(5, wireTransferFees.value) + setInt(6, wireTransferFees.frac) + one { when { - !it.next() -> - throw internalServerError("No result from DB procedure confirm_taler_withdrawal") it.getBoolean("out_no_op") -> WithdrawalConfirmationResult.UnknownOperation it.getBoolean("out_exchange_not_found") -> WithdrawalConfirmationResult.UnknownExchange it.getBoolean("out_balance_insufficient") -> WithdrawalConfirmationResult.BalanceInsufficient @@ -228,16 +222,17 @@ class WithdrawalDAO(private val db: Database) { } /** Get withdrawal operation [uuid] linked account username */ - suspend fun getUsername(uuid: UUID): String? = db.conn { conn -> - val stmt = conn.prepareStatement(""" - SELECT login - FROM taler_withdrawal_operations - JOIN bank_accounts ON wallet_bank_account=bank_account_id - JOIN customers ON customer_id=owning_customer_id - WHERE withdrawal_uuid=? - """) - stmt.setObject(1, uuid) - stmt.oneOrNull { it.getString(1) } + suspend fun getUsername(uuid: UUID): String? = db.serializableRead( + """ + SELECT login + FROM taler_withdrawal_operations + JOIN bank_accounts ON wallet_bank_account=bank_account_id + JOIN customers ON customer_id=owning_customer_id + WHERE withdrawal_uuid=? + """ + ) { + setObject(1, uuid) + oneOrNull { it.getString(1) } } private suspend fun <T> poll( @@ -275,32 +270,33 @@ class WithdrawalDAO(private val db: Database) { /** Pool public info of operation [uuid] */ suspend fun pollInfo(uuid: UUID, params: StatusParams): WithdrawalPublicInfo? = poll(uuid, params, status = { it.status }) { - db.conn { conn -> - val stmt = conn.prepareStatement(""" - SELECT - CASE - WHEN confirmation_done THEN 'confirmed' - WHEN aborted THEN 'aborted' - WHEN selection_done THEN 'selected' - ELSE 'pending' - END as status - ,(amount).val as amount_val - ,(amount).frac as amount_frac - ,(suggested_amount).val as suggested_amount_val - ,(suggested_amount).frac as suggested_amount_frac - ,selection_done - ,aborted - ,confirmation_done - ,reserve_pub - ,selected_exchange_payto - ,login - FROM taler_withdrawal_operations - JOIN bank_accounts ON wallet_bank_account=bank_account_id - JOIN customers ON customer_id=owning_customer_id - WHERE withdrawal_uuid=? - """) - stmt.setObject(1, uuid) - stmt.oneOrNull { + db.serializableRead( + """ + SELECT + CASE + WHEN confirmation_done THEN 'confirmed' + WHEN aborted THEN 'aborted' + WHEN selection_done THEN 'selected' + ELSE 'pending' + END as status + ,(amount).val as amount_val + ,(amount).frac as amount_frac + ,(suggested_amount).val as suggested_amount_val + ,(suggested_amount).frac as suggested_amount_frac + ,selection_done + ,aborted + ,confirmation_done + ,reserve_pub + ,selected_exchange_payto + ,login + FROM taler_withdrawal_operations + JOIN bank_accounts ON wallet_bank_account=bank_account_id + JOIN customers ON customer_id=owning_customer_id + WHERE withdrawal_uuid=? + """ + ) { + setObject(1, uuid) + oneOrNull { WithdrawalPublicInfo( status = WithdrawalStatus.valueOf(it.getString("status")), amount = it.getOptAmount("amount", db.bankCurrency), @@ -316,31 +312,32 @@ class WithdrawalDAO(private val db: Database) { /** Pool public status of operation [uuid] */ suspend fun pollStatus(uuid: UUID, params: StatusParams, wire: WireMethod): BankWithdrawalOperationStatus? = poll(uuid, params, status = { it.status }) { - db.conn { conn -> - val stmt = conn.prepareStatement(""" - SELECT - CASE - WHEN confirmation_done THEN 'confirmed' - WHEN aborted THEN 'aborted' - WHEN selection_done THEN 'selected' - ELSE 'pending' - END as status - ,(amount).val as amount_val - ,(amount).frac as amount_frac - ,(suggested_amount).val as suggested_amount_val - ,(suggested_amount).frac as suggested_amount_frac - ,selection_done - ,aborted - ,confirmation_done - ,internal_payto_uri - ,reserve_pub - ,selected_exchange_payto - FROM taler_withdrawal_operations - JOIN bank_accounts ON (wallet_bank_account=bank_account_id) - WHERE withdrawal_uuid=? - """) - stmt.setObject(1, uuid) - stmt.oneOrNull { + db.serializableRead( + """ + SELECT + CASE + WHEN confirmation_done THEN 'confirmed' + WHEN aborted THEN 'aborted' + WHEN selection_done THEN 'selected' + ELSE 'pending' + END as status + ,(amount).val as amount_val + ,(amount).frac as amount_frac + ,(suggested_amount).val as suggested_amount_val + ,(suggested_amount).frac as suggested_amount_frac + ,selection_done + ,aborted + ,confirmation_done + ,internal_payto_uri + ,reserve_pub + ,selected_exchange_payto + FROM taler_withdrawal_operations + JOIN bank_accounts ON (wallet_bank_account=bank_account_id) + WHERE withdrawal_uuid=? + """ + ) { + setObject(1, uuid) + oneOrNull { BankWithdrawalOperationStatus( status = WithdrawalStatus.valueOf(it.getString("status")), amount = it.getOptAmount("amount", db.bankCurrency), diff --git a/bank/src/test/kotlin/DatabaseTest.kt b/bank/src/test/kotlin/DatabaseTest.kt @@ -17,6 +17,7 @@ * <http://www.gnu.org/licenses/> */ +import io.ktor.http.* import kotlinx.coroutines.coroutineScope import kotlinx.coroutines.launch import org.junit.Test @@ -44,9 +45,11 @@ class DatabaseTest { } @Test - fun serialisation() = bankSetup { + fun serialisation() = bankSetup { db -> assertBalance("customer", "+KUDOS:0") assertBalance("merchant", "+KUDOS:0") + + // Generate concurrent write transactions and check that they are all successful coroutineScope { repeat(10) { launch { @@ -56,16 +59,50 @@ class DatabaseTest { } assertBalance("customer", "-KUDOS:4.5") assertBalance("merchant", "+KUDOS:4.5") + + // Generate concurrent write and read transactions and check that they only write transaction sometimes fails coroutineScope { - repeat(5) { + repeat(100) { + // Write transaction launch { - tx("customer", "KUDOS:0.0$it", "merchant", "concurrent 0$it") + while (true) { + val result = client.postA("/accounts/customer/transactions") { + json { + "payto_uri" to "$merchantPayto?message=${"concurrent 0$it".encodeURLQueryComponent()}&amount=KUDOS:0.0$it" + } + } + if (result.status == HttpStatusCode.InternalServerError) { + val body = result.json<TalerError>() + assertEquals(TalerErrorCode.BANK_SOFT_EXCEPTION.code, body.code) + continue // retry + } else { + result.assertOk() + break + } + } } + // Simple read transaction :SELECT launch { client.getA("/accounts/merchant/transactions").assertOk() } + // Complex read transaction: stored procedure + launch { + client.getA("/monitor") { + pwAuth("admin") + }.assertOk() + } + // GC logic + launch { + try { + db.gc.collect(Instant.now(), Duration.ofMillis(20), Duration.ofMillis(20), Duration.ofMillis(20)) + } catch (e: Exception) { + // Check only serialization exception + } + } } } + assertBalance("customer", "-KUDOS:9.855") + assertBalance("merchant", "+KUDOS:9.855") } @Test diff --git a/common/src/main/kotlin/db/DbPool.kt b/common/src/main/kotlin/db/DbPool.kt @@ -26,8 +26,8 @@ import kotlinx.coroutines.withContext import org.postgresql.jdbc.PgConnection import org.postgresql.util.PSQLState import tech.libeufin.common.MIN_VERSION -import tech.libeufin.common.SERIALIZATION_RETRY import java.sql.SQLException +import java.sql.PreparedStatement open class DbPool(cfg: DatabaseConfig, schema: String) : java.io.Closeable { val pgSource = pgDataSource(cfg.dbConnStr) @@ -50,6 +50,27 @@ open class DbPool(cfg: DatabaseConfig, schema: String) : java.io.Closeable { } } + /** Executes a read-only query with automatic retry on serialization errors */ + suspend fun <R> serializableRead(query: String, lambda: PreparedStatement.() -> R): R = conn { conn -> + val stmt = conn.prepareStatement(query) + // TODO explicit read only for better perf ? + retrySerializationError { stmt.lambda() } + } + + /** Executes a query with automatic retry on serialization errors */ + suspend fun <R> serializableWrite(query: String, lambda: PreparedStatement.() -> R): R = conn { conn -> + val stmt = conn.prepareStatement(query) + retrySerializationError { stmt.lambda() } + } + + /** Executes a transaction with automatic retry on serialization errors */ + suspend fun <R> serializableTransaction(transaction: (PgConnection) -> R): R = conn { conn -> + retrySerializationError { + conn.transaction(transaction) + } + } + + /** Run db logic using a connection from the pool */ suspend fun <R> conn(lambda: suspend (PgConnection) -> R): R { // Use a coroutine dispatcher that we can block as JDBC API is blocking return withContext(Dispatchers.IO) { @@ -57,23 +78,6 @@ open class DbPool(cfg: DatabaseConfig, schema: String) : java.io.Closeable { } } - suspend fun <R> serializable(lambda: suspend (PgConnection) -> R): R = conn { conn -> - repeat(SERIALIZATION_RETRY) { - try { - return@conn lambda(conn) - } catch (e: SQLException) { - if (e.sqlState != PSQLState.SERIALIZATION_FAILURE.state) - throw e - } - } - try { - return@conn lambda(conn) - } catch (e: SQLException) { - logger.warn("Serialization failure after $SERIALIZATION_RETRY retry") - throw e - } - } - override fun close() { pool.close() } diff --git a/common/src/main/kotlin/db/helpers.kt b/common/src/main/kotlin/db/helpers.kt @@ -37,7 +37,7 @@ suspend fun <T> DbPool.page( query: String, bind: PreparedStatement.() -> Int = { 0 }, map: (ResultSet) -> T -): List<T> = conn { conn -> +): List<T> { val backward = params.delta < 0 val pageQuery = """ $query @@ -45,7 +45,7 @@ suspend fun <T> DbPool.page( ORDER BY $idName ${if (backward) "DESC" else "ASC"} LIMIT ? """ - conn.prepareStatement(pageQuery).run { + return serializableRead(pageQuery) { val pad = bind() setLong(pad + 1, params.start) setInt(pad + 2, abs(params.delta)) diff --git a/common/src/main/kotlin/db/transaction.kt b/common/src/main/kotlin/db/transaction.kt @@ -26,13 +26,34 @@ import org.slf4j.LoggerFactory import java.sql.PreparedStatement import java.sql.ResultSet import java.sql.SQLException +import tech.libeufin.common.SERIALIZATION_RETRY internal val logger: Logger = LoggerFactory.getLogger("libeufin-db") -fun <R> PgConnection.transaction(lambda: (PgConnection) -> R): R { +/** Executes db logic with automatic retry on serialization errors */ +suspend fun <R> retrySerializationError(lambda: suspend () -> R): R { + repeat(SERIALIZATION_RETRY) { + try { + return lambda() + } catch (e: SQLException) { + println("${e.sqlState} ?? ${PSQLState.SERIALIZATION_FAILURE.state}") + if (e.sqlState != PSQLState.SERIALIZATION_FAILURE.state) + throw e + } + } + try { + return lambda() + } catch (e: SQLException) { + logger.warn("Serialization failure after $SERIALIZATION_RETRY retry") + throw e + } +} + +/** Run a postgres [transaction] */ +fun <R> PgConnection.transaction(transaction: (PgConnection) -> R): R { try { autoCommit = false - val result = lambda(this) + val result = transaction(this) commit() autoCommit = true return result @@ -43,15 +64,18 @@ fun <R> PgConnection.transaction(lambda: (PgConnection) -> R): R { } } +/** Read one row or null if none */ fun <T> PreparedStatement.oneOrNull(lambda: (ResultSet) -> T): T? { executeQuery().use { return if (it.next()) lambda(it) else null } } +/** Read one row or throw if none */ fun <T> PreparedStatement.one(lambda: (ResultSet) -> T): T = requireNotNull(oneOrNull(lambda)) { "Missing result to database query" } +/** Read one row or throw [err] in case or unique violation error */ fun <T> PreparedStatement.oneUniqueViolation(err: T, lambda: (ResultSet) -> T): T { return try { one(lambda) @@ -61,6 +85,7 @@ fun <T> PreparedStatement.oneUniqueViolation(err: T, lambda: (ResultSet) -> T): } } +/** Read all rows */ fun <T> PreparedStatement.all(lambda: (ResultSet) -> T): List<T> { executeQuery().use { val ret = mutableListOf<T>() @@ -71,22 +96,20 @@ fun <T> PreparedStatement.all(lambda: (ResultSet) -> T): List<T> { } } +/** Execute a query checking it return a least one row */ fun PreparedStatement.executeQueryCheck(): Boolean { executeQuery().use { return it.next() } } +/** Execute an update checking it update at least one row */ fun PreparedStatement.executeUpdateCheck(): Boolean { executeUpdate() return updateCount > 0 } -/** - * Helper that returns false if the row to be inserted - * hits a unique key constraint violation, true when it - * succeeds. Any other error (re)throws exception. - */ +/** Execute an update checking if fail because of unique violation error */ fun PreparedStatement.executeUpdateViolation(): Boolean { return try { executeUpdateCheck() @@ -97,6 +120,7 @@ fun PreparedStatement.executeUpdateViolation(): Boolean { } } +/** Execute an update checking if fail because of unique violation error and reseting state */ fun PreparedStatement.executeProcedureViolation(): Boolean { val savepoint = connection.setSavepoint() return try { @@ -110,7 +134,10 @@ fun PreparedStatement.executeProcedureViolation(): Boolean { } } -// TODO comment +/** + * Execute an update of [table] with a dynamic query generated at runtime. + * Every [fields] in each row matching [filter] are updated using values from [bind]. + **/ fun PgConnection.dynamicUpdate( table: String, fields: Sequence<String>, diff --git a/nexus/src/main/kotlin/tech/libeufin/nexus/db/ExchangeDAO.kt b/nexus/src/main/kotlin/tech/libeufin/nexus/db/ExchangeDAO.kt @@ -92,31 +92,32 @@ class ExchangeDAO(private val db: Database) { req: TransferRequest, bankId: String, timestamp: Instant - ): TransferResult = db.serializable { conn -> + ): TransferResult = db.serializableWrite( + """ + SELECT + out_request_uid_reuse + ,out_tx_row_id + ,out_timestamp + FROM + taler_transfer ( + ?, ?, ?, + (?,?)::taler_amount, + ?, ?, ?, ? + ); + """ + ) { val subject = "${req.wtid} ${req.exchange_base_url.url}" - val stmt = conn.prepareStatement(""" - SELECT - out_request_uid_reuse - ,out_tx_row_id - ,out_timestamp - FROM - taler_transfer ( - ?, ?, ?, - (?,?)::taler_amount, - ?, ?, ?, ? - ); - """) - stmt.setBytes(1, req.request_uid.raw) - stmt.setBytes(2, req.wtid.raw) - stmt.setString(3, subject) - stmt.setLong(4, req.amount.value) - stmt.setInt(5, req.amount.frac) - stmt.setString(6, req.exchange_base_url.url) - stmt.setString(7, req.credit_account.canonical) - stmt.setString(8, bankId) - stmt.setLong(9, timestamp.micros()) - stmt.one { + setBytes(1, req.request_uid.raw) + setBytes(2, req.wtid.raw) + setString(3, subject) + setLong(4, req.amount.value) + setInt(5, req.amount.frac) + setString(6, req.exchange_base_url.url) + setString(7, req.credit_account.canonical) + setString(8, bankId) + setLong(9, timestamp.micros()) + one { when { it.getBoolean("out_request_uid_reuse") -> TransferResult.RequestUidReuse else -> TransferResult.Success(