commit 2fd8d17ec7010947d038d8a34bbf7aa9402ff214
parent a2436996041807e3bb136c436053f007b6340cea
Author: Antoine A <>
Date: Fri, 21 Jun 2024 00:37:31 +0200
common: close opened PreparedStatement and make some functions IMMUTABLE
Diffstat:
9 files changed, 392 insertions(+), 366 deletions(-)
diff --git a/bank/src/main/kotlin/tech/libeufin/bank/db/AccountDAO.kt b/bank/src/main/kotlin/tech/libeufin/bank/db/AccountDAO.kt
@@ -55,7 +55,7 @@ class AccountDAO(private val db: Database) {
ctx: BankPaytoCtx
): AccountCreationResult = db.serializableTransaction { conn ->
val timestamp = Instant.now().micros()
- val idempotent = conn.prepareStatement("""
+ val idempotent = conn.withStatement("""
SELECT password_hash, name=?
AND email IS NOT DISTINCT FROM ?
AND phone IS NOT DISTINCT FROM ?
@@ -69,7 +69,7 @@ class AccountDAO(private val db: Database) {
JOIN bank_accounts
ON customer_id=owning_customer_id
WHERE login=?
- """).run {
+ """) {
// TODO check max debt and min checkout ?
setString(1, name)
setString(2, email)
@@ -97,12 +97,12 @@ class AccountDAO(private val db: Database) {
}
} else {
if (internalPayto is IbanPayto)
- conn.prepareStatement("""
+ conn.withStatement ("""
INSERT INTO iban_history(
iban
,creation_time
) VALUES (?, ?)
- """).run {
+ """) {
setString(1, internalPayto.iban.value)
setLong(2, timestamp)
if (!executeUpdateViolation()) {
@@ -111,7 +111,7 @@ class AccountDAO(private val db: Database) {
}
}
- val customerId = conn.prepareStatement("""
+ val customerId = conn.withStatement("""
INSERT INTO customers (
login
,password_hash
@@ -123,7 +123,7 @@ class AccountDAO(private val db: Database) {
) VALUES (?, ?, ?, ?, ?, ?, ?::tan_enum)
RETURNING customer_id
"""
- ).run {
+ ) {
setString(1, login)
setString(2, PwCrypto.hashpw(password))
setString(3, name)
@@ -134,7 +134,7 @@ class AccountDAO(private val db: Database) {
oneOrNull { it.getLong("customer_id") }!!
}
- conn.prepareStatement("""
+ conn.withStatement("""
INSERT INTO bank_accounts(
internal_payto_uri
,owning_customer_id
@@ -143,7 +143,7 @@ class AccountDAO(private val db: Database) {
,max_debt
,min_cashout
) VALUES (?, ?, ?, ?, (?, ?)::taler_amount, ${if (minCashout == null) "NULL" else "(?, ?)::taler_amount"})
- """).run {
+ """) {
setString(1, internalPayto.canonical)
setLong(2, customerId)
setBoolean(3, isPublic)
@@ -161,10 +161,10 @@ class AccountDAO(private val db: Database) {
}
if (bonus.value != 0L || bonus.frac != 0) {
- conn.prepareStatement("""
+ conn.withStatement("""
SELECT out_balance_insufficient
FROM bank_transaction(?,'admin','bonus',(?,?)::taler_amount,?,true,NULL,(0, 0)::taler_amount)
- """).run {
+ """) {
setString(1, internalPayto.canonical)
setLong(2, bonus.value)
setInt(3, bonus.frac)
@@ -266,7 +266,7 @@ class AccountDAO(private val db: Database) {
)
// Get user ID and current data
- val curr = conn.prepareStatement("""
+ val curr = conn.withStatement("""
SELECT
customer_id, tan_channel, phone, email, name, cashout_payto
,(max_debt).val AS max_debt_val
@@ -277,7 +277,7 @@ class AccountDAO(private val db: Database) {
JOIN bank_accounts
ON customer_id=owning_customer_id
WHERE login=? AND deleted_at IS NULL
- """).run {
+ """) {
setString(1, login)
oneOrNull {
CurrentAccount(
@@ -339,9 +339,10 @@ class AccountDAO(private val db: Database) {
// Invalidate current challenges
if (patchChannel != null || patchInfo != null) {
- val stmt = conn.prepareStatement("UPDATE tan_challenges SET expiration_date=0 WHERE customer=?")
- stmt.setLong(1, curr.id)
- stmt.execute()
+ conn.withStatement("UPDATE tan_challenges SET expiration_date=0 WHERE customer=?") {
+ setLong(1, curr.id)
+ execute()
+ }
}
// Update bank info
@@ -412,10 +413,10 @@ class AccountDAO(private val db: Database) {
oldPw: String?,
is2fa: Boolean
): AccountPatchAuthResult = db.serializableTransaction { conn ->
- val (currentPwh, tanRequired) = conn.prepareStatement("""
+ val (currentPwh, tanRequired) = conn.withStatement("""
SELECT password_hash, (NOT ? AND tan_channel IS NOT NULL)
FROM customers WHERE login=? AND deleted_at IS NULL
- """).run {
+ """) {
setBoolean(1, is2fa)
setString(2, login)
oneOrNull {
@@ -427,12 +428,12 @@ class AccountDAO(private val db: Database) {
} else if (oldPw != null && !PwCrypto.checkpw(oldPw, currentPwh)) {
AccountPatchAuthResult.OldPasswordMismatch
} else {
- val stmt = conn.prepareStatement("""
- UPDATE customers SET password_hash=? where login=?
- """)
- stmt.setString(1, PwCrypto.hashpw(newPw))
- stmt.setString(2, login)
- stmt.executeUpdateCheck()
+ conn.withStatement("UPDATE customers SET password_hash=? where login=?") {
+ setString(1, PwCrypto.hashpw(newPw))
+ setString(2, login)
+ executeUpdateCheck()
+ }
+
AccountPatchAuthResult.Success
}
}
diff --git a/bank/src/main/kotlin/tech/libeufin/bank/db/ConversionDAO.kt b/bank/src/main/kotlin/tech/libeufin/bank/db/ConversionDAO.kt
@@ -28,43 +28,48 @@ import tech.libeufin.common.db.*
class ConversionDAO(private val db: Database) {
/** Update in-db conversion config */
suspend fun updateConfig(cfg: ConversionRate) = db.serializableTransaction { conn ->
- var stmt = conn.prepareStatement("CALL config_set_amount(?, (?, ?)::taler_amount)")
- for ((name, amount) in listOf(
- Pair("cashin_ratio", cfg.cashin_ratio),
- Pair("cashout_ratio", cfg.cashout_ratio),
- )) {
- stmt.setString(1, name)
- stmt.setLong(2, amount.value)
- stmt.setInt(3, amount.frac)
- stmt.executeUpdate()
- }
- for ((name, amount) in listOf(
- Pair("cashin_fee", cfg.cashin_fee),
- Pair("cashin_tiny_amount", cfg.cashin_tiny_amount),
- Pair("cashin_min_amount", cfg.cashin_min_amount),
- Pair("cashout_fee", cfg.cashout_fee),
- Pair("cashout_tiny_amount", cfg.cashout_tiny_amount),
- Pair("cashout_min_amount", cfg.cashout_min_amount),
- )) {
- stmt.setString(1, name)
- stmt.setLong(2, amount.value)
- stmt.setInt(3, amount.frac)
- stmt.executeUpdate()
+ conn.withStatement("CALL config_set_amount(?, (?, ?)::taler_amount)") {
+ for ((name, amount) in listOf(
+ Pair("cashin_ratio", cfg.cashin_ratio),
+ Pair("cashout_ratio", cfg.cashout_ratio),
+ )) {
+ setString(1, name)
+ setLong(2, amount.value)
+ setInt(3, amount.frac)
+ executeUpdate()
+ }
+ for ((name, amount) in listOf(
+ Pair("cashin_fee", cfg.cashin_fee),
+ Pair("cashin_tiny_amount", cfg.cashin_tiny_amount),
+ Pair("cashin_min_amount", cfg.cashin_min_amount),
+ Pair("cashout_fee", cfg.cashout_fee),
+ Pair("cashout_tiny_amount", cfg.cashout_tiny_amount),
+ Pair("cashout_min_amount", cfg.cashout_min_amount),
+ )) {
+ setString(1, name)
+ setLong(2, amount.value)
+ setInt(3, amount.frac)
+ executeUpdate()
+ }
}
- stmt = conn.prepareStatement("CALL config_set_rounding_mode(?, ?::rounding_mode)")
- for ((name, value) in listOf(
- Pair("cashin_rounding_mode", cfg.cashin_rounding_mode),
- Pair("cashout_rounding_mode", cfg.cashout_rounding_mode)
- )) {
- stmt.setString(1, name)
- stmt.setString(2, value.name)
- stmt.executeUpdate()
+
+ conn.withStatement("CALL config_set_rounding_mode(?, ?::rounding_mode)") {
+ for ((name, value) in listOf(
+ Pair("cashin_rounding_mode", cfg.cashin_rounding_mode),
+ Pair("cashout_rounding_mode", cfg.cashout_rounding_mode)
+ )) {
+ setString(1, name)
+ setString(2, value.name)
+ executeUpdate()
+ }
}
}
/** Get in-db conversion config */
suspend fun getConfig(regional: String, fiat: String): ConversionRate? = db.serializableTransaction { conn ->
- val check = conn.prepareStatement("select exists(select 1 from config where key='cashin_ratio')").oneOrNull { it.getBoolean(1) }!!
+ val check = conn.withStatement("select exists(select 1 from config where key='cashin_ratio')") {
+ one { it.getBoolean(1) }
+ }
if (!check) return@serializableTransaction null
val amount = conn.prepareStatement("SELECT (amount).val as amount_val, (amount).frac as amount_frac FROM config_get_amount(?) as amount")
val roundingMode = conn.prepareStatement("SELECT config_get_rounding_mode(?)")
@@ -77,7 +82,7 @@ class ConversionDAO(private val db: Database) {
roundingMode.setString(1, name)
return roundingMode.oneOrNull { RoundingMode.valueOf(it.getString(1)) }!!
}
- ConversionRate(
+ val rate = ConversionRate(
cashin_ratio = getRatio("cashin_ratio"),
cashin_fee = getAmount("cashin_fee", regional),
cashin_tiny_amount = getAmount("cashin_tiny_amount", regional),
@@ -89,6 +94,9 @@ class ConversionDAO(private val db: Database) {
cashout_rounding_mode = getMode("cashout_rounding_mode"),
cashout_min_amount = getAmount("cashout_min_amount", regional),
)
+ amount.close()
+ roundingMode.close()
+ rate
}
/** Clear in-db conversion config */
diff --git a/bank/src/main/kotlin/tech/libeufin/bank/db/GcDAO.kt b/bank/src/main/kotlin/tech/libeufin/bank/db/GcDAO.kt
@@ -20,6 +20,7 @@
package tech.libeufin.bank.db
import tech.libeufin.common.micros
+import tech.libeufin.common.db.withStatement
import java.time.Duration
import java.time.Instant
@@ -37,9 +38,9 @@ class GcDAO(private val db: Database) {
val deleteAfterMicro = timestamp.minus(deleteAfter).micros()
// Abort pending operations
- conn.prepareStatement(
+ conn.withStatement(
"UPDATE taler_withdrawal_operations SET aborted = true WHERE creation_date < ?"
- ).run {
+ ) {
setLong(1, abortAfterMicro)
execute()
}
@@ -50,27 +51,27 @@ class GcDAO(private val db: Database) {
"DELETE FROM tan_challenges WHERE expiration_date < ?",
"DELETE FROM bearer_tokens WHERE expiration_time < ?"
)) {
- conn.prepareStatement(smt).run {
+ conn.withStatement(smt) {
setLong(1, cleanAfterMicro)
execute()
}
}
// Delete old bank transactions, linked operations are deleted by CASCADE
- conn.prepareStatement(
+ conn.withStatement(
"DELETE FROM bank_account_transactions WHERE transaction_date < ?"
- ).run {
+ ) {
setLong(1, deleteAfterMicro)
execute()
}
// Hard delete soft deleted customer without bank transactions, bank account are deleted by CASCADE
- conn.prepareStatement("""
+ conn.withStatement("""
DELETE FROM customers WHERE deleted_at IS NOT NULL AND NOT EXISTS(
SELECT 1 FROM bank_account_transactions NATURAL JOIN bank_accounts
WHERE owning_customer_id=customer_id
)
- """).run {
+ """) {
execute()
}
diff --git a/bank/src/main/kotlin/tech/libeufin/bank/db/TransactionDAO.kt b/bank/src/main/kotlin/tech/libeufin/bank/db/TransactionDAO.kt
@@ -54,7 +54,7 @@ class TransactionDAO(private val db: Database) {
wireTransferFees: TalerAmount
): BankTransactionResult = db.serializableTransaction { conn ->
val timestamp = timestamp.micros()
- val stmt = conn.prepareStatement("""
+ conn.withStatement("""
SELECT
out_creditor_not_found
,out_debtor_not_found
@@ -72,77 +72,77 @@ class TransactionDAO(private val db: Database) {
,out_idempotent
FROM bank_transaction(?,?,?,(?,?)::taler_amount,?,?,?,(?,?)::taler_amount)
"""
- )
- stmt.setString(1, creditAccountPayto.canonical)
- stmt.setString(2, debitAccountUsername)
- stmt.setString(3, subject)
- stmt.setLong(4, amount.value)
- stmt.setInt(5, amount.frac)
- stmt.setLong(6, timestamp)
- stmt.setBoolean(7, is2fa)
- stmt.setBytes(8, requestUid?.raw)
- stmt.setLong(9, wireTransferFees.value)
- stmt.setInt(10, wireTransferFees.frac)
- stmt.one {
- when {
- it.getBoolean("out_creditor_not_found") -> BankTransactionResult.UnknownCreditor
- it.getBoolean("out_debtor_not_found") -> BankTransactionResult.UnknownDebtor
- it.getBoolean("out_same_account") -> BankTransactionResult.BothPartySame
- it.getBoolean("out_balance_insufficient") -> BankTransactionResult.BalanceInsufficient
- it.getBoolean("out_creditor_admin") -> BankTransactionResult.AdminCreditor
- it.getBoolean("out_request_uid_reuse") -> BankTransactionResult.RequestUidReuse
- it.getBoolean("out_idempotent") -> BankTransactionResult.Success(it.getLong("out_debit_row_id"))
- it.getBoolean("out_tan_required") -> BankTransactionResult.TanRequired
- else -> {
- val creditAccountId = it.getLong("out_credit_bank_account_id")
- val creditRowId = it.getLong("out_credit_row_id")
- val debitAccountId = it.getLong("out_debit_bank_account_id")
- val debitRowId = it.getLong("out_debit_row_id")
- val exchangeCreditor = it.getBoolean("out_creditor_is_exchange")
- val exchangeDebtor = it.getBoolean("out_debtor_is_exchange")
- if (exchangeCreditor && exchangeDebtor) {
- logger.warn("exchange account $exchangeDebtor sent a manual transaction to exchange account $exchangeCreditor, this should never happens and is not bounced to prevent bouncing loop, may fail in the future")
- } else if (exchangeCreditor) {
- val bounceCause = runCatching { parseIncomingTxMetadata(subject) }.fold(
- onSuccess = { reservePub ->
- val registered = conn.prepareStatement("CALL register_incoming(?, ?)").run {
- setBytes(1, reservePub.raw)
- setLong(2, creditRowId)
- executeProcedureViolation()
+ ) {
+ setString(1, creditAccountPayto.canonical)
+ setString(2, debitAccountUsername)
+ setString(3, subject)
+ setLong(4, amount.value)
+ setInt(5, amount.frac)
+ setLong(6, timestamp)
+ setBoolean(7, is2fa)
+ setBytes(8, requestUid?.raw)
+ setLong(9, wireTransferFees.value)
+ setInt(10, wireTransferFees.frac)
+ one {
+ when {
+ it.getBoolean("out_creditor_not_found") -> BankTransactionResult.UnknownCreditor
+ it.getBoolean("out_debtor_not_found") -> BankTransactionResult.UnknownDebtor
+ it.getBoolean("out_same_account") -> BankTransactionResult.BothPartySame
+ it.getBoolean("out_balance_insufficient") -> BankTransactionResult.BalanceInsufficient
+ it.getBoolean("out_creditor_admin") -> BankTransactionResult.AdminCreditor
+ it.getBoolean("out_request_uid_reuse") -> BankTransactionResult.RequestUidReuse
+ it.getBoolean("out_idempotent") -> BankTransactionResult.Success(it.getLong("out_debit_row_id"))
+ it.getBoolean("out_tan_required") -> BankTransactionResult.TanRequired
+ else -> {
+ val creditAccountId = it.getLong("out_credit_bank_account_id")
+ val creditRowId = it.getLong("out_credit_row_id")
+ val debitAccountId = it.getLong("out_debit_bank_account_id")
+ val debitRowId = it.getLong("out_debit_row_id")
+ val exchangeCreditor = it.getBoolean("out_creditor_is_exchange")
+ val exchangeDebtor = it.getBoolean("out_debtor_is_exchange")
+ if (exchangeCreditor && exchangeDebtor) {
+ logger.warn("exchange account $exchangeDebtor sent a manual transaction to exchange account $exchangeCreditor, this should never happens and is not bounced to prevent bouncing loop, may fail in the future")
+ } else if (exchangeCreditor) {
+ val bounceCause = runCatching { parseIncomingTxMetadata(subject) }.fold(
+ onSuccess = { reservePub ->
+ val registered = conn.withStatement("CALL register_incoming(?, ?)") {
+ setBytes(1, reservePub.raw)
+ setLong(2, creditRowId)
+ executeProcedureViolation()
+ }
+ if (!registered) {
+ logger.warn("exchange account $creditAccountId received an incoming taler transaction $creditRowId with an already used reserve public key")
+ "reserve public key reuse"
+ } else {
+ null
+ }
+ },
+ onFailure = { e ->
+ logger.warn("exchange account $creditAccountId received a manual transaction $creditRowId with malformed metadata: ${e.message}")
+ "malformed metadata: ${e.message}"
}
- if (!registered) {
- logger.warn("exchange account $creditAccountId received an incoming taler transaction $creditRowId with an already used reserve public key")
- "reserve public key reuse"
- } else {
- null
+ )
+ if (bounceCause != null) {
+ // No error can happens because an opposite transaction already took place in the same transaction
+ conn.withStatement("""
+ SELECT bank_wire_transfer(
+ ?, ?, ?, (?, ?)::taler_amount, ?, (0, 0)::taler_amount
+ );
+ """) {
+ setLong(1, debitAccountId)
+ setLong(2, creditAccountId)
+ setString(3, "Bounce $creditRowId: $bounceCause")
+ setLong(4, amount.value)
+ setInt(5, amount.frac)
+ setLong(6, timestamp)
+ executeQuery()
}
- },
- onFailure = { e ->
- logger.warn("exchange account $creditAccountId received a manual transaction $creditRowId with malformed metadata: ${e.message}")
- "malformed metadata: ${e.message}"
- }
- )
- if (bounceCause != null) {
- // No error can happens because an opposite transaction already took place in the same transaction
- conn.prepareStatement("""
- SELECT bank_wire_transfer(
- ?, ?, ?, (?, ?)::taler_amount, ?, (0, 0)::taler_amount
- );
- """
- ).run {
- setLong(1, debitAccountId)
- setLong(2, creditAccountId)
- setString(3, "Bounce $creditRowId: $bounceCause")
- setLong(4, amount.value)
- setInt(5, amount.frac)
- setLong(6, timestamp)
- executeQuery()
}
+ } else if (exchangeDebtor) {
+ logger.warn("exchange account $debitAccountId sent a manual transaction $debitRowId which will not be recorderd as a taler outgoing transaction, use the API instead")
}
- } else if (exchangeDebtor) {
- logger.warn("exchange account $debitAccountId sent a manual transaction $debitRowId which will not be recorderd as a taler outgoing transaction, use the API instead")
+ BankTransactionResult.Success(debitRowId)
}
- BankTransactionResult.Success(debitRowId)
}
}
}
diff --git a/common/src/main/kotlin/db/DbPool.kt b/common/src/main/kotlin/db/DbPool.kt
@@ -52,15 +52,18 @@ open class DbPool(cfg: DatabaseConfig, schema: String) : java.io.Closeable {
/** Executes a read-only query with automatic retry on serialization errors */
suspend fun <R> serializableRead(query: String, lambda: PreparedStatement.() -> R): R = conn { conn ->
- val stmt = conn.prepareStatement(query)
- // TODO explicit read only for better perf ?
- retrySerializationError { stmt.lambda() }
+ conn.withStatement(query) {
+ // TODO explicit read only for better perf ?
+ retrySerializationError { lambda() }
+ }
+
}
/** Executes a query with automatic retry on serialization errors */
suspend fun <R> serializableWrite(query: String, lambda: PreparedStatement.() -> R): R = conn { conn ->
- val stmt = conn.prepareStatement(query)
- retrySerializationError { stmt.lambda() }
+ conn.withStatement(query) {
+ retrySerializationError { lambda() }
+ }
}
/** Executes a transaction with automatic retry on serialization errors */
diff --git a/common/src/main/kotlin/db/transaction.kt b/common/src/main/kotlin/db/transaction.kt
@@ -43,6 +43,11 @@ suspend fun <R> retrySerializationError(lambda: suspend () -> R): R {
return lambda()
}
+/** Run a postgres query using a prepared statement */
+inline fun <R> PgConnection.withStatement(query: String, lambda: PreparedStatement.() -> R): R {
+ return prepareStatement(query).use(lambda)
+}
+
/** Run a postgres [transaction] */
fun <R> PgConnection.transaction(transaction: (PgConnection) -> R): R {
try {
@@ -140,7 +145,7 @@ fun PgConnection.dynamicUpdate(
) {
val sql = fields.joinToString()
if (sql.isEmpty()) return
- prepareStatement("UPDATE $table SET $sql $filter").run {
+ withStatement("UPDATE $table SET $sql $filter") {
for ((idx, value) in bind.withIndex()) {
setObject(idx + 1, value)
}
diff --git a/database-versioning/libeufin-bank-procedures.sql b/database-versioning/libeufin-bank-procedures.sql
@@ -28,7 +28,7 @@ CREATE FUNCTION amount_normalize(
IN amount taler_amount
,OUT normalized taler_amount
)
-LANGUAGE plpgsql AS $$
+LANGUAGE plpgsql IMMUTABLE AS $$
BEGIN
normalized.val = amount.val + amount.frac / 100000000;
IF (normalized.val > 1::INT8<<52) THEN
@@ -46,7 +46,7 @@ CREATE FUNCTION amount_add(
,IN r taler_amount
,OUT sum taler_amount
)
-LANGUAGE plpgsql AS $$
+LANGUAGE plpgsql IMMUTABLE AS $$
BEGIN
sum = (l.val + r.val, l.frac + r.frac);
SELECT normalized.val, normalized.frac INTO sum.val, sum.frac FROM amount_normalize(sum) as normalized;
@@ -60,7 +60,7 @@ CREATE FUNCTION amount_left_minus_right(
,OUT diff taler_amount
,OUT ok BOOLEAN
)
-LANGUAGE plpgsql AS $$
+LANGUAGE plpgsql IMMUTABLE AS $$
BEGIN
IF l.val > r.val THEN
ok = TRUE;
diff --git a/nexus/src/main/kotlin/tech/libeufin/nexus/db/InitiatedDAO.kt b/nexus/src/main/kotlin/tech/libeufin/nexus/db/InitiatedDAO.kt
@@ -38,25 +38,26 @@ class InitiatedDAO(private val db: Database) {
}
/** Register a new pending payment in the database */
- suspend fun create(paymentData: InitiatedPayment): PaymentInitiationResult = db.conn { conn ->
- val stmt = conn.prepareStatement("""
- INSERT INTO initiated_outgoing_transactions (
- amount
- ,wire_transfer_subject
- ,credit_payto_uri
- ,initiation_time
- ,request_uid
- ) VALUES ((?,?)::taler_amount,?,?,?,?)
- RETURNING initiated_outgoing_transaction_id
- """)
+ suspend fun create(paymentData: InitiatedPayment): PaymentInitiationResult = db.serializableWrite(
+ """
+ INSERT INTO initiated_outgoing_transactions (
+ amount
+ ,wire_transfer_subject
+ ,credit_payto_uri
+ ,initiation_time
+ ,request_uid
+ ) VALUES ((?,?)::taler_amount,?,?,?,?)
+ RETURNING initiated_outgoing_transaction_id
+ """
+ ) {
// TODO check payto uri
- stmt.setLong(1, paymentData.amount.value)
- stmt.setInt(2, paymentData.amount.frac)
- stmt.setString(3, paymentData.wireTransferSubject)
- stmt.setString(4, paymentData.creditPaytoUri)
- stmt.setLong(5, paymentData.initiationTime.micros())
- stmt.setString(6, paymentData.requestUid)
- stmt.oneUniqueViolation(PaymentInitiationResult.RequestUidReuse) {
+ setLong(1, paymentData.amount.value)
+ setInt(2, paymentData.amount.frac)
+ setString(3, paymentData.wireTransferSubject)
+ setString(4, paymentData.creditPaytoUri)
+ setLong(5, paymentData.initiationTime.micros())
+ setString(6, paymentData.requestUid)
+ oneUniqueViolation(PaymentInitiationResult.RequestUidReuse) {
PaymentInitiationResult.Success(it.getLong("initiated_outgoing_transaction_id"))
}
}
@@ -66,20 +67,21 @@ class InitiatedDAO(private val db: Database) {
id: Long,
timestamp: Instant,
orderId: String
- ) = db.conn { conn ->
- val stmt = conn.prepareStatement("""
- UPDATE initiated_outgoing_transactions SET
- submitted = 'success'::submission_state
- ,last_submission_time = ?
- ,failure_message = NULL
- ,order_id = ?
- ,submission_counter = submission_counter + 1
- WHERE initiated_outgoing_transaction_id = ?
- """)
- stmt.setLong(1, timestamp.micros())
- stmt.setString(2, orderId)
- stmt.setLong(3, id)
- stmt.execute()
+ ) = db.serializableWrite(
+ """
+ UPDATE initiated_outgoing_transactions SET
+ submitted = 'success'::submission_state
+ ,last_submission_time = ?
+ ,failure_message = NULL
+ ,order_id = ?
+ ,submission_counter = submission_counter + 1
+ WHERE initiated_outgoing_transaction_id = ?
+ """
+ ) {
+ setLong(1, timestamp.micros())
+ setString(2, orderId)
+ setLong(3, id)
+ execute()
}
/** Register EBICS submission failure */
@@ -87,106 +89,101 @@ class InitiatedDAO(private val db: Database) {
id: Long,
timestamp: Instant,
msg: String?
- ) = db.conn { conn ->
- val stmt = conn.prepareStatement("""
- UPDATE initiated_outgoing_transactions SET
- submitted = 'transient_failure'::submission_state
- ,last_submission_time = ?
- ,failure_message = ?
- ,submission_counter = submission_counter + 1
- WHERE initiated_outgoing_transaction_id = ?
- """)
- stmt.setLong(1, timestamp.micros())
- stmt.setString(2, msg)
- stmt.setLong(3, id)
- stmt.execute()
+ ) = db.serializableWrite(
+ """
+ UPDATE initiated_outgoing_transactions SET
+ submitted = 'transient_failure'::submission_state
+ ,last_submission_time = ?
+ ,failure_message = ?
+ ,submission_counter = submission_counter + 1
+ WHERE initiated_outgoing_transaction_id = ?
+ """
+ ) {
+ setLong(1, timestamp.micros())
+ setString(2, msg)
+ setLong(3, id)
+ execute()
}
/** Register EBICS log status message */
- suspend fun logMessage(orderId: String, msg: String) = db.conn { conn ->
- val stmt = conn.prepareStatement("""
- UPDATE initiated_outgoing_transactions SET failure_message = ?
- WHERE order_id = ?
- """)
- stmt.setString(1, msg)
- stmt.setString(2, orderId)
- stmt.execute()
+ suspend fun logMessage(orderId: String, msg: String) = db.serializableWrite(
+ """
+ UPDATE initiated_outgoing_transactions SET failure_message = ?
+ WHERE order_id = ?
+ """
+ ) {
+ setString(1, msg)
+ setString(2, orderId)
+ execute()
}
/** Register EBICS log success and return request_uid if found */
- suspend fun logSuccess(orderId: String): String? = db.conn { conn ->
- val stmt = conn.prepareStatement("""
- SELECT request_uid FROM initiated_outgoing_transactions
- WHERE order_id = ?
- """)
- stmt.setString(1, orderId)
- stmt.oneOrNull { it.getString(1) }
+ suspend fun logSuccess(orderId: String): String? = db.serializableWrite(
+ """
+ SELECT request_uid FROM initiated_outgoing_transactions
+ WHERE order_id = ?
+ """
+ ) {
+ setString(1, orderId)
+ oneOrNull { it.getString(1) }
}
/** Register EBICS log failure and return request_uid and previous message if found */
- suspend fun logFailure(orderId: String): Pair<String, String?>? = db.conn { conn ->
- val stmt = conn.prepareStatement("""
- UPDATE initiated_outgoing_transactions
- SET submitted = 'permanent_failure'::submission_state
- WHERE order_id = ?
- RETURNING request_uid, failure_message
- """)
- stmt.setString(1, orderId)
- stmt.oneOrNull { Pair(it.getString(1), it.getString(2)) }
+ suspend fun logFailure(orderId: String): Pair<String, String?>? = db.serializableWrite(
+ """
+ UPDATE initiated_outgoing_transactions
+ SET submitted = 'permanent_failure'::submission_state
+ WHERE order_id = ?
+ RETURNING request_uid, failure_message
+ """
+ ) {
+ setString(1, orderId)
+ oneOrNull { Pair(it.getString(1), it.getString(2)) }
}
/** Register bank status message */
- suspend fun bankMessage(requestUID: String, msg: String) = db.conn { conn ->
- val stmt = conn.prepareStatement("""
- UPDATE initiated_outgoing_transactions
- SET failure_message = ?
- WHERE request_uid = ?
- """)
- stmt.setString(1, msg)
- stmt.setString(2, requestUID)
- stmt.execute()
+ suspend fun bankMessage(requestUID: String, msg: String) = db.serializableWrite(
+ """
+ UPDATE initiated_outgoing_transactions
+ SET failure_message = ?
+ WHERE request_uid = ?
+ """
+ ) {
+ setString(1, msg)
+ setString(2, requestUID)
+ execute()
}
/** Register bank failure */
- suspend fun bankFailure(requestUID: String, msg: String) = db.conn { conn ->
- val stmt = conn.prepareStatement("""
- UPDATE initiated_outgoing_transactions SET
- submitted = 'permanent_failure'::submission_state
- ,failure_message = ?
- WHERE request_uid = ?
- """)
- stmt.setString(1, msg)
- stmt.setString(2, requestUID)
- stmt.execute()
+ suspend fun bankFailure(requestUID: String, msg: String) = db.serializableWrite(
+ """
+ UPDATE initiated_outgoing_transactions SET
+ submitted = 'permanent_failure'::submission_state
+ ,failure_message = ?
+ WHERE request_uid = ?
+ """
+ ) {
+ setString(1, msg)
+ setString(2, requestUID)
+ execute()
}
/** Register reversal */
- suspend fun reversal(requestUID: String, msg: String) = db.conn { conn ->
- val stmt = conn.prepareStatement("""
- UPDATE initiated_outgoing_transactions SET
- submitted = 'permanent_failure'::submission_state
- ,failure_message = ?
- WHERE request_uid = ?
- """)
- stmt.setString(1, msg)
- stmt.setString(2, requestUID)
- stmt.execute()
+ suspend fun reversal(requestUID: String, msg: String) = db.serializableWrite(
+ """
+ UPDATE initiated_outgoing_transactions SET
+ submitted = 'permanent_failure'::submission_state
+ ,failure_message = ?
+ WHERE request_uid = ?
+ """
+ ) {
+ setString(1, msg)
+ setString(2, requestUID)
+ execute()
}
/** List every initiated payment pending submission in the order they should be submitted */
- suspend fun submittable(currency: String): List<InitiatedPayment> = db.conn { conn ->
- fun extract(it: ResultSet): InitiatedPayment {
- val rowId = it.getLong("initiated_outgoing_transaction_id")
- val initiationTime = it.getLong("initiation_time").asInstant()
- return InitiatedPayment(
- id = it.getLong("initiated_outgoing_transaction_id"),
- amount = it.getAmount("amount", currency),
- creditPaytoUri = it.getString("credit_payto_uri"),
- wireTransferSubject = it.getString("wire_transfer_subject"),
- initiationTime = initiationTime,
- requestUid = it.getString("request_uid")
- )
- }
+ suspend fun submittable(currency: String): List<InitiatedPayment> {
val selectPart = """
SELECT
initiated_outgoing_transaction_id
@@ -206,12 +203,23 @@ class InitiatedDAO(private val db: Database) {
// Then we retry the failed transaction, starting with the oldest by submission time.
// This the bad path retrying each failed transaction applying a rotation based on
// resubmission time.
- val unsubmitted = conn.prepareStatement(
- "$selectPart WHERE submitted='unsubmitted' ORDER BY initiation_time"
- ).all(::extract)
- val failed = conn.prepareStatement(
- "$selectPart WHERE submitted='transient_failure' ORDER BY last_submission_time"
- ).all(::extract)
- unsubmitted + failed
+ return db.serializableRead(
+ """
+ ($selectPart WHERE submitted='unsubmitted' ORDER BY initiation_time)
+ UNION ALL
+ ($selectPart WHERE submitted='transient_failure' ORDER BY last_submission_time)
+ """
+ ) {
+ all {
+ InitiatedPayment(
+ id = it.getLong("initiated_outgoing_transaction_id"),
+ amount = it.getAmount("amount", currency),
+ creditPaytoUri = it.getString("credit_payto_uri"),
+ wireTransferSubject = it.getString("wire_transfer_subject"),
+ initiationTime = it.getLong("initiation_time").asInstant(),
+ requestUid = it.getString("request_uid")
+ )
+ }
+ }
}
}
\ No newline at end of file
diff --git a/nexus/src/main/kotlin/tech/libeufin/nexus/db/PaymentDAO.kt b/nexus/src/main/kotlin/tech/libeufin/nexus/db/PaymentDAO.kt
@@ -39,30 +39,24 @@ class PaymentDAO(private val db: Database) {
paymentData: OutgoingPayment,
wtid: ShortHashCode?,
baseUrl: ExchangeUrl?,
- ): OutgoingRegistrationResult = db.conn {
- val stmt = it.prepareStatement("""
- SELECT out_tx_id, out_initiated, out_found
- FROM register_outgoing((?,?)::taler_amount,?,?,?,?,?,?)
- """)
+ ): OutgoingRegistrationResult = db.serializableWrite(
+ """
+ SELECT out_tx_id, out_initiated, out_found
+ FROM register_outgoing((?,?)::taler_amount,?,?,?,?,?,?)
+ """
+ ) {
val executionTime = paymentData.executionTime.micros()
- stmt.setLong(1, paymentData.amount.value)
- stmt.setInt(2, paymentData.amount.frac)
- stmt.setString(3, paymentData.wireTransferSubject)
- stmt.setLong(4, executionTime)
- stmt.setString(5, paymentData.creditPaytoUri)
- stmt.setString(6, paymentData.messageId)
- if (wtid != null) {
- stmt.setBytes(7, wtid.raw)
- } else {
- stmt.setNull(7, java.sql.Types.NULL)
- }
- if (baseUrl != null) {
- stmt.setString(8, baseUrl.url)
- } else {
- stmt.setNull(8, java.sql.Types.NULL)
- }
+
+ setLong(1, paymentData.amount.value)
+ setInt(2, paymentData.amount.frac)
+ setString(3, paymentData.wireTransferSubject)
+ setLong(4, executionTime)
+ setString(5, paymentData.creditPaytoUri)
+ setString(6, paymentData.messageId)
+ setBytes(7, wtid?.raw)
+ setString(8, baseUrl?.url)
- stmt.one {
+ one {
OutgoingRegistrationResult(
it.getLong("out_tx_id"),
it.getBoolean("out_initiated"),
@@ -83,21 +77,22 @@ class PaymentDAO(private val db: Database) {
paymentData: IncomingPayment,
bounceAmount: TalerAmount,
timestamp: Instant
- ): IncomingBounceRegistrationResult = db.conn {
- val stmt = it.prepareStatement("""
- SELECT out_found, out_tx_id, out_bounce_id
- FROM register_incoming_and_bounce((?,?)::taler_amount,?,?,?,?,(?,?)::taler_amount,?)
- """)
- stmt.setLong(1, paymentData.amount.value)
- stmt.setInt(2, paymentData.amount.frac)
- stmt.setString(3, paymentData.wireTransferSubject)
- stmt.setLong(4, paymentData.executionTime.micros())
- stmt.setString(5, paymentData.debitPaytoUri)
- stmt.setString(6, paymentData.bankId)
- stmt.setLong(7, bounceAmount.value)
- stmt.setInt(8, bounceAmount.frac)
- stmt.setLong(9, timestamp.micros())
- stmt.one {
+ ): IncomingBounceRegistrationResult = db.serializableWrite(
+ """
+ SELECT out_found, out_tx_id, out_bounce_id
+ FROM register_incoming_and_bounce((?,?)::taler_amount,?,?,?,?,(?,?)::taler_amount,?)
+ """
+ ) {
+ setLong(1, paymentData.amount.value)
+ setInt(2, paymentData.amount.frac)
+ setString(3, paymentData.wireTransferSubject)
+ setLong(4, paymentData.executionTime.micros())
+ setString(5, paymentData.debitPaytoUri)
+ setString(6, paymentData.bankId)
+ setLong(7, bounceAmount.value)
+ setInt(8, bounceAmount.frac)
+ setLong(9, timestamp.micros())
+ one {
IncomingBounceRegistrationResult(
it.getLong("out_tx_id"),
it.getString("out_bounce_id"),
@@ -116,20 +111,21 @@ class PaymentDAO(private val db: Database) {
suspend fun registerTalerableIncoming(
paymentData: IncomingPayment,
reservePub: EddsaPublicKey
- ): IncomingRegistrationResult = db.conn { conn ->
- val stmt = conn.prepareStatement("""
- SELECT out_reserve_pub_reuse, out_found, out_tx_id
- FROM register_incoming_and_talerable((?,?)::taler_amount,?,?,?,?,?)
- """)
+ ): IncomingRegistrationResult = db.serializableWrite(
+ """
+ SELECT out_reserve_pub_reuse, out_found, out_tx_id
+ FROM register_incoming_and_talerable((?,?)::taler_amount,?,?,?,?,?)
+ """
+ ) {
val executionTime = paymentData.executionTime.micros()
- stmt.setLong(1, paymentData.amount.value)
- stmt.setInt(2, paymentData.amount.frac)
- stmt.setString(3, paymentData.wireTransferSubject)
- stmt.setLong(4, executionTime)
- stmt.setString(5, paymentData.debitPaytoUri)
- stmt.setString(6, paymentData.bankId)
- stmt.setBytes(7, reservePub.raw)
- stmt.one {
+ setLong(1, paymentData.amount.value)
+ setInt(2, paymentData.amount.frac)
+ setString(3, paymentData.wireTransferSubject)
+ setLong(4, executionTime)
+ setString(5, paymentData.debitPaytoUri)
+ setString(6, paymentData.bankId)
+ setBytes(7, reservePub.raw)
+ one {
when {
it.getBoolean("out_reserve_pub_reuse") -> IncomingRegistrationResult.ReservePubReuse
else -> IncomingRegistrationResult.Success(
@@ -143,19 +139,20 @@ class PaymentDAO(private val db: Database) {
/** Register an incoming payment */
suspend fun registerIncoming(
paymentData: IncomingPayment
- ): IncomingRegistrationResult.Success = db.conn { conn ->
- val stmt = conn.prepareStatement("""
- SELECT out_found, out_tx_id
- FROM register_incoming((?,?)::taler_amount,?,?,?,?)
- """)
+ ): IncomingRegistrationResult.Success = db.serializableWrite(
+ """
+ SELECT out_found, out_tx_id
+ FROM register_incoming((?,?)::taler_amount,?,?,?,?)
+ """
+ ) {
val executionTime = paymentData.executionTime.micros()
- stmt.setLong(1, paymentData.amount.value)
- stmt.setInt(2, paymentData.amount.frac)
- stmt.setString(3, paymentData.wireTransferSubject)
- stmt.setLong(4, executionTime)
- stmt.setString(5, paymentData.debitPaytoUri)
- stmt.setString(6, paymentData.bankId)
- stmt.one {
+ setLong(1, paymentData.amount.value)
+ setInt(2, paymentData.amount.frac)
+ setString(3, paymentData.wireTransferSubject)
+ setLong(4, executionTime)
+ setString(5, paymentData.debitPaytoUri)
+ setString(6, paymentData.bankId)
+ one {
IncomingRegistrationResult.Success(
it.getLong("out_tx_id"),
!it.getBoolean("out_found")
@@ -187,21 +184,22 @@ class PaymentDAO(private val db: Database) {
}
/** List incoming transaction metadata for debugging */
- suspend fun metadataIncoming(): List<IncomingTxMetadata> = db.conn { conn ->
- val stmt = conn.prepareStatement("""
- SELECT
- (amount).val as amount_val
- ,(amount).frac AS amount_frac
- ,wire_transfer_subject
- ,execution_time
- ,debit_payto_uri
- ,bank_id
- ,reserve_public_key
- FROM incoming_transactions
- LEFT OUTER JOIN talerable_incoming_transactions using (incoming_transaction_id)
- ORDER BY execution_time
- """)
- stmt.all {
+ suspend fun metadataIncoming(): List<IncomingTxMetadata> = db.serializableRead(
+ """
+ SELECT
+ (amount).val as amount_val
+ ,(amount).frac AS amount_frac
+ ,wire_transfer_subject
+ ,execution_time
+ ,debit_payto_uri
+ ,bank_id
+ ,reserve_public_key
+ FROM incoming_transactions
+ LEFT OUTER JOIN talerable_incoming_transactions using (incoming_transaction_id)
+ ORDER BY execution_time
+ """
+ ) {
+ all {
IncomingTxMetadata(
date = it.getLong("execution_time").asInstant(),
amount = it.getDecimal("amount"),
@@ -214,22 +212,23 @@ class PaymentDAO(private val db: Database) {
}
/** List outgoing transaction metadata for debugging */
- suspend fun metadataOutgoing(): List<OutgoingTxMetadata> = db.conn { conn ->
- val stmt = conn.prepareStatement("""
- SELECT
- (amount).val as amount_val
- ,(amount).frac AS amount_frac
- ,wire_transfer_subject
- ,execution_time
- ,credit_payto_uri
- ,message_id
- ,wtid
- ,exchange_base_url
- FROM outgoing_transactions
- LEFT OUTER JOIN talerable_outgoing_transactions using (outgoing_transaction_id)
- ORDER BY execution_time
- """)
- stmt.all {
+ suspend fun metadataOutgoing(): List<OutgoingTxMetadata> = db.serializableRead(
+ """
+ SELECT
+ (amount).val as amount_val
+ ,(amount).frac AS amount_frac
+ ,wire_transfer_subject
+ ,execution_time
+ ,credit_payto_uri
+ ,message_id
+ ,wtid
+ ,exchange_base_url
+ FROM outgoing_transactions
+ LEFT OUTER JOIN talerable_outgoing_transactions using (outgoing_transaction_id)
+ ORDER BY execution_time
+ """
+ ) {
+ all {
OutgoingTxMetadata(
date = it.getLong("execution_time").asInstant(),
amount = it.getDecimal("amount"),
@@ -243,23 +242,24 @@ class PaymentDAO(private val db: Database) {
}
/** List initiated transaction metadata for debugging */
- suspend fun metadataInitiated(): List<InitiatedTxMetadata> = db.conn { conn ->
- val stmt = conn.prepareStatement("""
- SELECT
- (amount).val as amount_val
- ,(amount).frac AS amount_frac
- ,wire_transfer_subject
- ,initiation_time
- ,last_submission_time
- ,submission_counter
- ,credit_payto_uri
- ,submitted
- ,request_uid
- ,failure_message
- FROM initiated_outgoing_transactions
- ORDER BY initiation_time
- """)
- stmt.all {
+ suspend fun metadataInitiated(): List<InitiatedTxMetadata> = db.serializableRead(
+ """
+ SELECT
+ (amount).val as amount_val
+ ,(amount).frac AS amount_frac
+ ,wire_transfer_subject
+ ,initiation_time
+ ,last_submission_time
+ ,submission_counter
+ ,credit_payto_uri
+ ,submitted
+ ,request_uid
+ ,failure_message
+ FROM initiated_outgoing_transactions
+ ORDER BY initiation_time
+ """
+ ) {
+ all {
InitiatedTxMetadata(
date = it.getLong("initiation_time").asInstant(),
amount = it.getDecimal("amount"),