From 6e8b2bebde52f9f22b0ed8f8399571341b445ed8 Mon Sep 17 00:00:00 2001 From: Antoine A <> Date: Fri, 6 Oct 2023 17:34:31 +0000 Subject: Add database pooling --- bank/build.gradle | 1 + .../src/main/kotlin/tech/libeufin/bank/Database.kt | 498 ++++++++++----------- bank/src/test/kotlin/TalerApiTest.kt | 31 +- 3 files changed, 243 insertions(+), 287 deletions(-) diff --git a/bank/build.gradle b/bank/build.gradle index 8af6f364..e344cd1b 100644 --- a/bank/build.gradle +++ b/bank/build.gradle @@ -54,6 +54,7 @@ dependencies { implementation "org.glassfish.jaxb:jaxb-runtime:2.3.1" implementation 'org.postgresql:postgresql:42.2.27' + implementation 'com.zaxxer:HikariCP:5.0.1' implementation group: 'org.apache.commons', name: 'commons-compress', version: '1.21' implementation('com.github.ajalt:clikt:2.8.0') diff --git a/bank/src/main/kotlin/tech/libeufin/bank/Database.kt b/bank/src/main/kotlin/tech/libeufin/bank/Database.kt index fecec14c..49406be4 100644 --- a/bank/src/main/kotlin/tech/libeufin/bank/Database.kt +++ b/bank/src/main/kotlin/tech/libeufin/bank/Database.kt @@ -32,6 +32,7 @@ import java.sql.* import java.time.Instant import java.util.* import kotlin.math.abs +import com.zaxxer.hikari.* private const val DB_CTR_LIMIT = 1000000 @@ -174,43 +175,26 @@ private fun PreparedStatement.executeUpdateViolation(): Boolean { } } -class Database(private val dbConfig: String, private val bankCurrency: String) { - private var dbConn: PgConnection? = null - private var dbCtr: Int = 0 - private val preparedStatements: MutableMap = mutableMapOf() +class Database(dbConfig: String, private val bankCurrency: String): java.io.Closeable { + private val dbPool: HikariDataSource init { - Class.forName("org.postgresql.Driver") + val config = HikariConfig(); + config.jdbcUrl = getJdbcConnectionFromPg(dbConfig) + config.driverClassName = "org.postgresql.Driver" + config.maximumPoolSize = 2 + config.connectionInitSql = "SET search_path TO libeufin_bank;" + config.validate() + dbPool = HikariDataSource(config); } - internal fun conn(): PgConnection? { - // Translate "normal" postgresql:// connection URI to something that JDBC likes. - val jdbcConnStr = getJdbcConnectionFromPg(dbConfig) - logger.info("connecting to database via JDBC string '$jdbcConnStr'") - val conn = DriverManager.getConnection(jdbcConnStr).unwrap(PgConnection::class.java) - conn?.execSQLUpdate("SET search_path TO libeufin_bank;") - return conn + override fun close() { + dbPool.close() } - private fun reconnect() { - dbCtr++ - val myDbConn = dbConn - if ((dbCtr < DB_CTR_LIMIT && myDbConn != null) && !(myDbConn.isClosed)) - return - dbConn?.close() - preparedStatements.clear() - dbConn = conn() - dbCtr = 0 - } - - private fun prepare(sql: String): PreparedStatement { - var ps = preparedStatements[sql] - if (ps != null) return ps - val myDbConn = dbConn - if (myDbConn == null) throw internalServerError("DB connection down") - ps = myDbConn.prepareStatement(sql) - preparedStatements[sql] = ps - return ps + private fun conn(lambda: (Connection) -> R): R { + val conn = dbPool.getConnection() + return conn.use(lambda) } // CUSTOMERS @@ -222,9 +206,8 @@ class Database(private val dbConfig: String, private val bankCurrency: String) { * * In case of conflict, this method returns null. */ - fun customerCreate(customer: Customer): Long? { - reconnect() - val stmt = prepare(""" + fun customerCreate(customer: Customer): Long? = conn { conn -> + val stmt = conn.prepareStatement(""" INSERT INTO customers ( login ,password_hash @@ -250,13 +233,14 @@ class Database(private val dbConfig: String, private val bankCurrency: String) { stmt.executeQuery() } catch (e: SQLException) { logger.error(e.message) - if (e.errorCode == 0) return null // unique constraint violation. + if (e.errorCode == 0) return@conn null // unique constraint violation. throw e // rethrow on other errors. } res.use { - if (!it.next()) - throw internalServerError("SQL RETURNING gave no customer_id.") - return it.getLong("customer_id") + when { + !it.next() -> throw internalServerError("SQL RETURNING gave no customer_id.") + else -> it.getLong("customer_id") + } } } @@ -264,16 +248,15 @@ class Database(private val dbConfig: String, private val bankCurrency: String) { * Deletes a customer (including its bank account row) from * the database. The bank account gets deleted by the cascade. */ - fun customerDeleteIfBalanceIsZero(login: String): CustomerDeletionResult { - reconnect() - val stmt = prepare(""" + fun customerDeleteIfBalanceIsZero(login: String): CustomerDeletionResult = conn { conn -> + val stmt = conn.prepareStatement(""" SELECT out_nx_customer, out_balance_not_zero FROM customer_delete(?); """) stmt.setString(1, login) - return stmt.executeQuery().use { + stmt.executeQuery().use { when { !it.next() -> throw internalServerError("Deletion returned nothing.") it.getBoolean("out_nx_customer") -> CustomerDeletionResult.CUSTOMER_NOT_FOUND @@ -284,9 +267,8 @@ class Database(private val dbConfig: String, private val bankCurrency: String) { } // Mostly used to get customers out of bearer tokens. - fun customerGetFromRowId(customer_id: Long): Customer? { - reconnect() - val stmt = prepare(""" + fun customerGetFromRowId(customer_id: Long): Customer? = conn { conn -> + val stmt = conn.prepareStatement(""" SELECT login, password_hash, @@ -299,7 +281,7 @@ class Database(private val dbConfig: String, private val bankCurrency: String) { WHERE customer_id=? """) stmt.setLong(1, customer_id) - return stmt.oneOrNull { + stmt.oneOrNull { Customer( login = it.getString("login"), passwordHash = it.getString("password_hash"), @@ -313,19 +295,17 @@ class Database(private val dbConfig: String, private val bankCurrency: String) { } } - fun customerChangePassword(customerName: String, passwordHash: String): Boolean { - reconnect() - val stmt = prepare(""" + fun customerChangePassword(customerName: String, passwordHash: String): Boolean = conn { conn -> + val stmt = conn.prepareStatement(""" UPDATE customers SET password_hash=? where login=? """) stmt.setString(1, passwordHash) stmt.setString(2, customerName) - return stmt.executeUpdateCheck() + stmt.executeUpdateCheck() } - fun customerGetFromLogin(login: String): Customer? { - reconnect() - val stmt = prepare(""" + fun customerGetFromLogin(login: String): Customer? = conn { conn -> + val stmt = conn.prepareStatement(""" SELECT customer_id, password_hash, @@ -338,7 +318,7 @@ class Database(private val dbConfig: String, private val bankCurrency: String) { WHERE login=? """) stmt.setString(1, login) - return stmt.oneOrNull { + stmt.oneOrNull { Customer( login = login, passwordHash = it.getString("password_hash"), @@ -355,9 +335,8 @@ class Database(private val dbConfig: String, private val bankCurrency: String) { // Possibly more "customerGetFrom*()" to come. // BEARER TOKEN - fun bearerTokenCreate(token: BearerToken): Boolean { - reconnect() - val stmt = prepare(""" + fun bearerTokenCreate(token: BearerToken): Boolean = conn { conn -> + val stmt = conn.prepareStatement(""" INSERT INTO bearer_tokens (content, creation_time, @@ -374,11 +353,10 @@ class Database(private val dbConfig: String, private val bankCurrency: String) { stmt.setString(4, token.scope.name) stmt.setLong(5, token.bankCustomer) stmt.setBoolean(6, token.isRefreshable) - return stmt.executeUpdateViolation() + stmt.executeUpdateViolation() } - fun bearerTokenGet(token: ByteArray): BearerToken? { - reconnect() - val stmt = prepare(""" + fun bearerTokenGet(token: ByteArray): BearerToken? = conn { conn -> + val stmt = conn.prepareStatement(""" SELECT expiration_time, creation_time, @@ -389,7 +367,7 @@ class Database(private val dbConfig: String, private val bankCurrency: String) { WHERE content=?; """) stmt.setBytes(1, token) - return stmt.oneOrNull { + stmt.oneOrNull { BearerToken( content = token, creationTime = it.getLong("creation_time").microsToJavaInstant() ?: throw faultyTimestampByBank(), @@ -409,15 +387,14 @@ class Database(private val dbConfig: String, private val bankCurrency: String) { * if deletion succeeds or false if the token could not be * deleted (= not found). */ - fun bearerTokenDelete(token: ByteArray): Boolean { - reconnect() - val stmt = prepare(""" + fun bearerTokenDelete(token: ByteArray): Boolean = conn { conn -> + val stmt = conn.prepareStatement(""" DELETE FROM bearer_tokens WHERE content = ? RETURNING bearer_token_id; """) stmt.setBytes(1, token) - return stmt.executeQueryCheck() + stmt.executeQueryCheck() } // MIXED CUSTOMER AND BANK ACCOUNT DATA @@ -443,9 +420,8 @@ class Database(private val dbConfig: String, private val bankCurrency: String) { phoneNumber: String?, emailAddress: String?, isTalerExchange: Boolean? - ): AccountReconfigDBResult { - reconnect() - val stmt = prepare(""" + ): AccountReconfigDBResult = conn { conn -> + val stmt = conn.prepareStatement(""" SELECT out_nx_customer, out_nx_bank_account @@ -461,7 +437,7 @@ class Database(private val dbConfig: String, private val bankCurrency: String) { stmt.setNull(6, Types.NULL) else stmt.setBoolean(6, isTalerExchange) - return stmt.executeQuery().use { + stmt.executeQuery().use { when { !it.next() -> throw internalServerError("accountReconfig() returned nothing") it.getBoolean("out_nx_customer") -> AccountReconfigDBResult.CUSTOMER_NOT_FOUND @@ -478,9 +454,8 @@ class Database(private val dbConfig: String, private val bankCurrency: String) { * * Returns an empty list, if no public account was found. */ - fun accountsGetPublic(internalCurrency: String, loginFilter: String = "%"): List { - reconnect() - val stmt = prepare(""" + fun accountsGetPublic(internalCurrency: String, loginFilter: String = "%"): List = conn { conn -> + val stmt = conn.prepareStatement(""" SELECT (balance).val AS balance_val, (balance).frac AS balance_frac, @@ -492,7 +467,7 @@ class Database(private val dbConfig: String, private val bankCurrency: String) { WHERE is_public=true AND c.login LIKE ?; """) stmt.setString(1, loginFilter) - return stmt.all { + stmt.all { PublicAccount( account_name = it.getString("login"), payto_uri = it.getString("internal_payto_uri"), @@ -518,9 +493,8 @@ class Database(private val dbConfig: String, private val bankCurrency: String) { * LIKE operator. If it's null, it defaults to the "%" wildcard, meaning * that it returns ALL the existing accounts. */ - fun accountsGetForAdmin(nameFilter: String = "%"): List { - reconnect() - val stmt = prepare(""" + fun accountsGetForAdmin(nameFilter: String = "%"): List = conn { conn -> + val stmt = conn.prepareStatement(""" SELECT login, name, @@ -534,7 +508,7 @@ class Database(private val dbConfig: String, private val bankCurrency: String) { WHERE name LIKE ?; """) stmt.setString(1, nameFilter) - return stmt.all { + stmt.all { AccountMinimalData( username = it.getString("login"), name = it.getString("name"), @@ -566,13 +540,12 @@ class Database(private val dbConfig: String, private val bankCurrency: String) { * row ID in the successful case. If of unique constrain violation, * it returns null and any other error will be thrown as 500. */ - fun bankAccountCreate(bankAccount: BankAccount): Long? { - reconnect() + fun bankAccountCreate(bankAccount: BankAccount): Long? = conn { conn -> if (bankAccount.balance != null) throw internalServerError( "Do not pass a balance upon bank account creation, do a wire transfer instead." ) - val stmt = prepare(""" + val stmt = conn.prepareStatement(""" INSERT INTO bank_accounts (internal_payto_uri ,owning_customer_id @@ -595,22 +568,22 @@ class Database(private val dbConfig: String, private val bankCurrency: String) { stmt.executeQuery() } catch (e: SQLException) { logger.error(e.message) - if (e.errorCode == 0) return null // unique constraint violation. + if (e.errorCode == 0) return@conn null // unique constraint violation. throw e // rethrow on other errors. } res.use { - if (!it.next()) - throw internalServerError("SQL RETURNING gave no bank_account_id.") - return it.getLong("bank_account_id") + when { + !it.next() -> throw internalServerError("SQL RETURNING gave no bank_account_id.") + else -> it.getLong("bank_account_id") + } } } fun bankAccountSetMaxDebt( owningCustomerId: Long, maxDebt: TalerAmount - ): Boolean { - reconnect() - val stmt = prepare(""" + ): Boolean = conn { conn -> + val stmt = conn.prepareStatement(""" UPDATE bank_accounts SET max_debt=(?,?)::taler_amount WHERE owning_customer_id=? @@ -618,16 +591,15 @@ class Database(private val dbConfig: String, private val bankCurrency: String) { stmt.setLong(1, maxDebt.value) stmt.setInt(2, maxDebt.frac) stmt.setLong(3, owningCustomerId) - return stmt.executeUpdateViolation() + stmt.executeUpdateViolation() } private fun getCurrency(): String { return bankCurrency } - fun bankAccountGetFromOwnerId(ownerId: Long): BankAccount? { - reconnect() - val stmt = prepare(""" + fun bankAccountGetFromOwnerId(ownerId: Long): BankAccount? = conn { conn -> + val stmt = conn.prepareStatement(""" SELECT internal_payto_uri ,owning_customer_id @@ -645,7 +617,7 @@ class Database(private val dbConfig: String, private val bankCurrency: String) { """) stmt.setLong(1, ownerId) - return stmt.oneOrNull { + stmt.oneOrNull { BankAccount( internalPaytoUri = it.getString("internal_payto_uri"), balance = TalerAmount( @@ -668,9 +640,8 @@ class Database(private val dbConfig: String, private val bankCurrency: String) { } } - fun bankAccountGetFromInternalPayto(internalPayto: String): BankAccount? { - reconnect() - val stmt = prepare(""" + fun bankAccountGetFromInternalPayto(internalPayto: String): BankAccount? = conn { conn -> + val stmt = conn.prepareStatement(""" SELECT bank_account_id ,owning_customer_id @@ -687,7 +658,7 @@ class Database(private val dbConfig: String, private val bankCurrency: String) { """) stmt.setString(1, internalPayto) - return stmt.oneOrNull { + stmt.oneOrNull { BankAccount( internalPaytoUri = internalPayto, balance = TalerAmount( @@ -713,9 +684,8 @@ class Database(private val dbConfig: String, private val bankCurrency: String) { fun bankTransactionCreate( tx: BankInternalTransaction - ): BankTransactionResult { - reconnect() - val stmt = prepare(""" + ): BankTransactionResult = conn { conn -> + val stmt = conn.prepareStatement(""" SELECT out_nx_creditor, out_nx_debtor, out_balance_insufficient FROM bank_wire_transfer(?,?,TEXT(?),(?,?)::taler_amount,?,TEXT(?),TEXT(?),TEXT(?)) """ @@ -729,7 +699,7 @@ class Database(private val dbConfig: String, private val bankCurrency: String) { stmt.setString(7, tx.accountServicerReference) stmt.setString(8, tx.paymentInformationId) stmt.setString(9, tx.endToEndId) - return stmt.executeQuery().use { + stmt.executeQuery().use { when { !it.next() -> throw internalServerError("Bank transaction didn't properly return") it.getBoolean("out_nx_debtor") -> { @@ -759,21 +729,19 @@ class Database(private val dbConfig: String, private val bankCurrency: String) { * * Returns the row ID if found, null otherwise. */ - fun bankTransactionCheckExists(subject: String): Long? { - reconnect() - val stmt = prepare(""" + fun bankTransactionCheckExists(subject: String): Long? = conn { conn -> + val stmt = conn.prepareStatement(""" SELECT bank_transaction_id FROM bank_account_transactions WHERE subject = ?; """) stmt.setString(1, subject) - return stmt.oneOrNull { it.getLong("bank_transaction_id") } + stmt.oneOrNull { it.getLong("bank_transaction_id") } } // Get the bank transaction whose row ID is rowId - fun bankTransactionGetFromInternalId(rowId: Long): BankAccountTransaction? { - reconnect() - val stmt = prepare(""" + fun bankTransactionGetFromInternalId(rowId: Long): BankAccountTransaction? = conn { conn -> + val stmt = conn.prepareStatement(""" SELECT creditor_payto_uri ,creditor_name @@ -792,7 +760,7 @@ class Database(private val dbConfig: String, private val bankCurrency: String) { WHERE bank_transaction_id=? """) stmt.setLong(1, rowId) - return stmt.oneOrNull { + stmt.oneOrNull { BankAccountTransaction( creditorPaytoUri = it.getString("creditor_payto_uri"), creditorName = it.getString("creditor_name"), @@ -825,133 +793,133 @@ class Database(private val dbConfig: String, private val bankCurrency: String) { bankAccountId: Long, direction: TransactionDirection, map: (BankAccountTransaction) -> T? - ): List { - val conn = conn() ?: throw internalServerError("DB connection down"); - conn.use { - val channel = "${direction.name}_$bankAccountId"; - var start = params.start - var delta = params.delta - var poll_ms = params.poll_ms; - - val (cmpOp, orderBy) = if (delta < 0) Pair("<", "DESC") else Pair(">", "ASC") - val stmt = conn.prepareStatement(""" - SELECT - creditor_payto_uri - ,creditor_name - ,debtor_payto_uri - ,debtor_name - ,subject - ,(amount).val AS amount_val - ,(amount).frac AS amount_frac - ,transaction_date - ,account_servicer_reference - ,payment_information_id - ,end_to_end_id - ,bank_account_id - ,bank_transaction_id - FROM bank_account_transactions - WHERE bank_transaction_id ${cmpOp} ? - AND bank_account_id=? - AND direction=?::direction_enum - ORDER BY bank_transaction_id ${orderBy} - LIMIT ? - """) - - // If going backward with a starting point, it is useless to poll - if (delta < 0 && start != Long.MAX_VALUE) { - poll_ms = 0; - } + ): List = conn { conn -> + // TODO listening for notification is blocking a connection and postgres support a limited amount of connections + // We should use a single connection to listen for notification and dispatch them with kotlin code + val pg = conn.unwrap(PgConnection::class.java) + val channel = "${direction.name}_$bankAccountId"; + var start = params.start + var delta = params.delta + var poll_ms = params.poll_ms; - // Only start expensive listening if we intend to poll - if (poll_ms > 0) { - conn.execSQLUpdate("LISTEN $channel"); - } + val (cmpOp, orderBy) = if (delta < 0) Pair("<", "DESC") else Pair(">", "ASC") + val stmt = conn.prepareStatement(""" + SELECT + creditor_payto_uri + ,creditor_name + ,debtor_payto_uri + ,debtor_name + ,subject + ,(amount).val AS amount_val + ,(amount).frac AS amount_frac + ,transaction_date + ,account_servicer_reference + ,payment_information_id + ,end_to_end_id + ,bank_account_id + ,bank_transaction_id + FROM bank_account_transactions + WHERE bank_transaction_id ${cmpOp} ? + AND bank_account_id=? + AND direction=?::direction_enum + ORDER BY bank_transaction_id ${orderBy} + LIMIT ? + """) - val items = mutableListOf() - - fun bankTransactionGetHistory(): List { - stmt.setLong(1, start) - stmt.setLong(2, bankAccountId) - stmt.setString(3, direction.name) - stmt.setLong(4, abs(delta)) - return stmt.all { - BankAccountTransaction( - creditorPaytoUri = it.getString("creditor_payto_uri"), - creditorName = it.getString("creditor_name"), - debtorPaytoUri = it.getString("debtor_payto_uri"), - debtorName = it.getString("debtor_name"), - amount = TalerAmount( - it.getLong("amount_val"), - it.getInt("amount_frac"), - getCurrency() - ), - accountServicerReference = it.getString("account_servicer_reference"), - endToEndId = it.getString("end_to_end_id"), - direction = direction, - bankAccountId = it.getLong("bank_account_id"), - paymentInformationId = it.getString("payment_information_id"), - subject = it.getString("subject"), - transactionDate = it.getLong("transaction_date").microsToJavaInstant() ?: throw faultyTimestampByBank(), - dbRowId = it.getLong("bank_transaction_id") - ) - } + // If going backward with a starting point, it is useless to poll + if (delta < 0 && start != Long.MAX_VALUE) { + poll_ms = 0; + } + + // Only start expensive listening if we intend to poll + if (poll_ms > 0) { + pg.execSQLUpdate("LISTEN $channel"); + } + + val items = mutableListOf() + + fun bankTransactionGetHistory(): List { + stmt.setLong(1, start) + stmt.setLong(2, bankAccountId) + stmt.setString(3, direction.name) + stmt.setLong(4, abs(delta)) + return stmt.all { + BankAccountTransaction( + creditorPaytoUri = it.getString("creditor_payto_uri"), + creditorName = it.getString("creditor_name"), + debtorPaytoUri = it.getString("debtor_payto_uri"), + debtorName = it.getString("debtor_name"), + amount = TalerAmount( + it.getLong("amount_val"), + it.getInt("amount_frac"), + getCurrency() + ), + accountServicerReference = it.getString("account_servicer_reference"), + endToEndId = it.getString("end_to_end_id"), + direction = direction, + bankAccountId = it.getLong("bank_account_id"), + paymentInformationId = it.getString("payment_information_id"), + subject = it.getString("subject"), + transactionDate = it.getLong("transaction_date").microsToJavaInstant() ?: throw faultyTimestampByBank(), + dbRowId = it.getLong("bank_transaction_id") + ) } + } - fun loadBankHistory() { - while (delta != 0L) { - val history = bankTransactionGetHistory() - if (history.isEmpty()) - break; - history.forEach { - val item = map(it); - // Advance cursor - start = it.expectRowId() - - if (item != null) { - items.add(item) - // Reduce delta - if (delta < 0) delta++ else delta--; - } + fun loadBankHistory() { + while (delta != 0L) { + val history = bankTransactionGetHistory() + if (history.isEmpty()) + break; + history.forEach { + val item = map(it); + // Advance cursor + start = it.expectRowId() + + if (item != null) { + items.add(item) + // Reduce delta + if (delta < 0) delta++ else delta--; } } } + } - loadBankHistory() - - // Long polling - while (delta != 0L && poll_ms > 0) { - var remaining = abs(delta); - do { - val pollStart = System.currentTimeMillis() - logger.debug("POOL") - conn.getNotifications(poll_ms.toInt()).forEach { - val id = it.parameter.toLong() - val new = when { - params.start == Long.MAX_VALUE -> true - delta < 0 -> id < start - else -> id > start - } - logger.debug("NOTIF $id $new") - if (new) remaining -= 1 + loadBankHistory() + + // Long polling + while (delta != 0L && poll_ms > 0) { + var remaining = abs(delta); + do { + val pollStart = System.currentTimeMillis() + logger.debug("POOL") + pg.getNotifications(poll_ms.toInt()).forEach { + val id = it.parameter.toLong() + val new = when { + params.start == Long.MAX_VALUE -> true + delta < 0 -> id < start + else -> id > start } - val pollEnd = System.currentTimeMillis() - poll_ms -= pollEnd - pollStart - } while (poll_ms > 0 && remaining > 0L) - - // If going backward without a starting point, we reset loading progress - if (params.start == Long.MAX_VALUE) { - start = params.start - delta = params.delta - items.clear() + logger.debug("NOTIF $id $new") + if (new) remaining -= 1 } - loadBankHistory() + val pollEnd = System.currentTimeMillis() + poll_ms -= pollEnd - pollStart + } while (poll_ms > 0 && remaining > 0L) + + // If going backward without a starting point, we reset loading progress + if (params.start == Long.MAX_VALUE) { + start = params.start + delta = params.delta + items.clear() } + loadBankHistory() + } - conn.execSQLUpdate("UNLISTEN $channel"); - conn.getNotifications(); // Clear pending notifications + pg.execSQLUpdate("UNLISTEN $channel"); + pg.getNotifications(); // Clear pending notifications - return items.toList(); - } + items.toList() } /** @@ -966,10 +934,9 @@ class Database(private val dbConfig: String, private val bankCurrency: String) { delta: Long, bankAccountId: Long, withDirection: TransactionDirection? = null - ): List { - reconnect() + ): List = conn { conn -> val (cmpOp, orderBy) = if (delta < 0) Pair("<", "DESC") else Pair(">", "ASC") - val stmt = prepare(""" + val stmt = conn.prepareStatement(""" SELECT creditor_payto_uri ,creditor_name @@ -1005,7 +972,7 @@ class Database(private val dbConfig: String, private val bankCurrency: String) { else 3 stmt.setLong(limitParamIndex, abs(delta)) - return stmt.all { + stmt.all { val direction = withDirection ?: when (it.getString("direction")) { "credit" -> TransactionDirection.credit "debit" -> TransactionDirection.debit @@ -1038,9 +1005,8 @@ class Database(private val dbConfig: String, private val bankCurrency: String) { opUUID: UUID, walletBankAccount: Long, amount: TalerAmount - ): Boolean { - reconnect() - val stmt = prepare(""" + ): Boolean = conn { conn -> + val stmt = conn.prepareStatement(""" INSERT INTO taler_withdrawal_operations (withdrawal_uuid, wallet_bank_account, amount) @@ -1050,11 +1016,10 @@ class Database(private val dbConfig: String, private val bankCurrency: String) { stmt.setLong(2, walletBankAccount) stmt.setLong(3, amount.value) stmt.setInt(4, amount.frac) - return stmt.executeUpdateViolation() + stmt.executeUpdateViolation() } - fun talerWithdrawalGet(opUUID: UUID): TalerWithdrawalOperation? { - reconnect() - val stmt = prepare(""" + fun talerWithdrawalGet(opUUID: UUID): TalerWithdrawalOperation? = conn { conn -> + val stmt = conn.prepareStatement(""" SELECT (amount).val as amount_val ,(amount).frac as amount_frac @@ -1069,7 +1034,7 @@ class Database(private val dbConfig: String, private val bankCurrency: String) { WHERE withdrawal_uuid=? """) stmt.setObject(1, opUUID) - return stmt.oneOrNull { + stmt.oneOrNull { TalerWithdrawalOperation( amount = TalerAmount( it.getLong("amount_val"), @@ -1091,9 +1056,8 @@ class Database(private val dbConfig: String, private val bankCurrency: String) { * Aborts one Taler withdrawal, only if it wasn't previously * confirmed. It returns false if the UPDATE didn't succeed. */ - fun talerWithdrawalAbort(opUUID: UUID): Boolean { - reconnect() - val stmt = prepare(""" + fun talerWithdrawalAbort(opUUID: UUID): Boolean = conn { conn -> + val stmt = conn.prepareStatement(""" UPDATE taler_withdrawal_operations SET aborted = true WHERE withdrawal_uuid=? AND confirmation_done = false @@ -1101,7 +1065,7 @@ class Database(private val dbConfig: String, private val bankCurrency: String) { """ ) stmt.setObject(1, opUUID) - return stmt.executeQueryCheck() + stmt.executeQueryCheck() } /** @@ -1115,9 +1079,8 @@ class Database(private val dbConfig: String, private val bankCurrency: String) { opUuid: UUID, exchangePayto: String, reservePub: String - ): Boolean { - reconnect() - val stmt = prepare(""" + ): Boolean = conn { conn -> + val stmt = conn.prepareStatement(""" UPDATE taler_withdrawal_operations SET selected_exchange_payto = ?, reserve_pub = ?, selection_done = true WHERE withdrawal_uuid=? @@ -1126,7 +1089,7 @@ class Database(private val dbConfig: String, private val bankCurrency: String) { stmt.setString(1, exchangePayto) stmt.setString(2, reservePub) stmt.setObject(3, opUuid) - return stmt.executeUpdateViolation() + stmt.executeUpdateViolation() } /** @@ -1139,9 +1102,8 @@ class Database(private val dbConfig: String, private val bankCurrency: String) { accountServicerReference: String = "NOT-USED", endToEndId: String = "NOT-USED", paymentInfId: String = "NOT-USED" - ): WithdrawalConfirmationResult { - reconnect() - val stmt = prepare(""" + ): WithdrawalConfirmationResult = conn { conn -> + val stmt = conn.prepareStatement(""" SELECT out_nx_op, out_nx_exchange, @@ -1155,7 +1117,7 @@ class Database(private val dbConfig: String, private val bankCurrency: String) { stmt.setString(3, accountServicerReference) stmt.setString(4, endToEndId) stmt.setString(5, paymentInfId) - return stmt.executeQuery().use { + stmt.executeQuery().use { when { !it.next() -> throw internalServerError("No result from DB procedure confirm_taler_withdrawal") @@ -1171,9 +1133,8 @@ class Database(private val dbConfig: String, private val bankCurrency: String) { /** * Creates a cashout operation in the database. */ - fun cashoutCreate(op: Cashout): Boolean { - reconnect() - val stmt = prepare(""" + fun cashoutCreate(op: Cashout): Boolean = conn { conn -> + val stmt = conn.prepareStatement(""" INSERT INTO cashout_operations ( cashout_uuid ,amount_debit @@ -1225,7 +1186,7 @@ class Database(private val dbConfig: String, private val bankCurrency: String) { stmt.setLong(16, op.bankAccount) stmt.setString(17, op.credit_payto_uri) stmt.setString(18, op.cashoutCurrency) - return stmt.executeUpdateViolation() + stmt.executeUpdateViolation() } /** @@ -1237,9 +1198,8 @@ class Database(private val dbConfig: String, private val bankCurrency: String) { opUuid: UUID, tanConfirmationTimestamp: Long, bankTransaction: Long // regional payment backing the operation - ): Boolean { - reconnect() - val stmt = prepare(""" + ): Boolean = conn { conn -> + val stmt = conn.prepareStatement(""" UPDATE cashout_operations SET tan_confirmation_time = ?, local_transaction = ? WHERE cashout_uuid=?; @@ -1247,7 +1207,7 @@ class Database(private val dbConfig: String, private val bankCurrency: String) { stmt.setLong(1, tanConfirmationTimestamp) stmt.setLong(2, bankTransaction) stmt.setObject(3, opUuid) - return stmt.executeUpdateViolation() + stmt.executeUpdateViolation() } /** @@ -1261,13 +1221,13 @@ class Database(private val dbConfig: String, private val bankCurrency: String) { /** * Deletes a cashout operation from the database. */ - fun cashoutDelete(opUuid: UUID): CashoutDeleteResult { - val stmt = prepare(""" + fun cashoutDelete(opUuid: UUID): CashoutDeleteResult = conn { conn -> + val stmt = conn.prepareStatement(""" SELECT out_already_confirmed FROM cashout_delete(?) """) stmt.setObject(1, opUuid) - return stmt.executeQuery().use { + stmt.executeQuery().use { when { !it.next() -> throw internalServerError("Cashout deletion gave no result") it.getBoolean("out_already_confirmed") -> CashoutDeleteResult.CONFLICT_ALREADY_CONFIRMED @@ -1280,8 +1240,8 @@ class Database(private val dbConfig: String, private val bankCurrency: String) { * Gets a cashout operation from the database, according * to its uuid. */ - fun cashoutGetFromUuid(opUuid: UUID): Cashout? { - val stmt = prepare(""" + fun cashoutGetFromUuid(opUuid: UUID): Cashout? = conn { conn -> + val stmt = conn.prepareStatement(""" SELECT (amount_debit).val as amount_debit_val ,(amount_debit).frac as amount_debit_frac @@ -1306,7 +1266,7 @@ class Database(private val dbConfig: String, private val bankCurrency: String) { WHERE cashout_uuid=?; """) stmt.setObject(1, opUuid) - return stmt.oneOrNull { + stmt.oneOrNull { Cashout( amountDebit = TalerAmount( value = it.getLong("amount_debit_val"), @@ -1368,9 +1328,8 @@ class Database(private val dbConfig: String, private val bankCurrency: String) { /** * Gets a Taler transfer request, given its UID. */ - fun talerTransferGetFromUid(requestUid: String): TalerTransferFromDb? { - reconnect() - val stmt = prepare(""" + fun talerTransferGetFromUid(requestUid: String): TalerTransferFromDb? = conn { conn -> + val stmt = conn.prepareStatement(""" SELECT wtid ,exchange_base_url @@ -1385,7 +1344,7 @@ class Database(private val dbConfig: String, private val bankCurrency: String) { WHERE request_uid = ?; """) stmt.setString(1, requestUid) - return stmt.oneOrNull { + stmt.oneOrNull { TalerTransferFromDb( wtid = it.getString("wtid"), amount = TalerAmount( @@ -1433,9 +1392,8 @@ class Database(private val dbConfig: String, private val bankCurrency: String) { acctSvcrRef: String = "not used", pmtInfId: String = "not used", endToEndId: String = "not used", - ): TalerTransferCreationResult { - reconnect() - val stmt = prepare(""" + ): TalerTransferCreationResult = conn { conn -> + val stmt = conn.prepareStatement(""" SELECT out_exchange_balance_insufficient ,out_nx_creditor @@ -1467,7 +1425,7 @@ class Database(private val dbConfig: String, private val bankCurrency: String) { stmt.setString(10, pmtInfId) stmt.setString(11, endToEndId) - return stmt.executeQuery().use { + stmt.executeQuery().use { when { !it.next() -> throw internalServerError("SQL function taler_transfer did not return anything.") @@ -1477,7 +1435,7 @@ class Database(private val dbConfig: String, private val bankCurrency: String) { TalerTransferCreationResult(BankTransactionResult.CONFLICT) else -> { val txRowId = it.getLong("out_tx_row_id") - return TalerTransferCreationResult( + TalerTransferCreationResult( txResult = BankTransactionResult.SUCCESS, txRowId = txRowId ) diff --git a/bank/src/test/kotlin/TalerApiTest.kt b/bank/src/test/kotlin/TalerApiTest.kt index 615b5f21..5c749988 100644 --- a/bank/src/test/kotlin/TalerApiTest.kt +++ b/bank/src/test/kotlin/TalerApiTest.kt @@ -67,15 +67,17 @@ class TalerApiTest { ) } - fun commonSetup(): Pair { + fun commonSetup(lambda: (Database, BankApplicationContext) -> Unit){ val db = initDb() val ctx = getTestContext() - // Creating the exchange and merchant accounts first. - assertNotNull(db.customerCreate(customerFoo)) - assertNotNull(db.bankAccountCreate(bankAccountFoo)) - assertNotNull(db.customerCreate(customerBar)) - assertNotNull(db.bankAccountCreate(bankAccountBar)) - return Pair(db, ctx) + db.use { + // Creating the exchange and merchant accounts first. + assertNotNull(db.customerCreate(customerFoo)) + assertNotNull(db.bankAccountCreate(bankAccountFoo)) + assertNotNull(db.customerCreate(customerBar)) + assertNotNull(db.bankAccountCreate(bankAccountBar)) + lambda(db, ctx) + } } // Test endpoint is correctly authenticated @@ -103,8 +105,7 @@ class TalerApiTest { // Testing the POST /transfer call from the TWG API. @Test - fun transfer() { - val (db, ctx) = commonSetup() + fun transfer() = commonSetup { db, ctx -> // Do POST /transfer. testApplication { application { @@ -212,8 +213,7 @@ class TalerApiTest { * Testing the /history/incoming call from the TWG API. */ @Test - fun historyIncoming() { - val (db, ctx) = commonSetup() + fun historyIncoming() = commonSetup { db, ctx -> // Give Foo reasonable debt allowance: assert( db.bankAccountSetMaxDebt( @@ -354,8 +354,7 @@ class TalerApiTest { * Testing the /history/outgoing call from the TWG API. */ @Test - fun historyOutgoing() { - val (db, ctx) = commonSetup() + fun historyOutgoing() = commonSetup { db, ctx -> // Give Bar reasonable debt allowance: assert( db.bankAccountSetMaxDebt( @@ -492,8 +491,7 @@ class TalerApiTest { // Testing the /admin/add-incoming call from the TWG API. @Test - fun addIncoming() { - val (db, ctx) = commonSetup() + fun addIncoming() = commonSetup { db, ctx -> // Give Bar reasonable debt allowance: assert(db.bankAccountSetMaxDebt( 2L, @@ -635,8 +633,7 @@ class TalerApiTest { } // Testing withdrawal confirmation @Test - fun withdrawalConfirmation() { - val (db, ctx) = commonSetup() + fun withdrawalConfirmation() = commonSetup { db, ctx -> // Artificially making a withdrawal operation for Foo. val uuid = UUID.randomUUID() -- cgit v1.2.3