commit 661e1b32344c1d837cb7cced498f3cbde1cebe16 parent b88be146117455f50bfc2c1b084278aefa1dcc78 Author: Antoine A <> Date: Wed, 3 Apr 2024 15:15:40 +0200 Refactor db logic Diffstat:
32 files changed, 587 insertions(+), 362 deletions(-)
diff --git a/bank/src/main/kotlin/tech/libeufin/bank/Config.kt b/bank/src/main/kotlin/tech/libeufin/bank/Config.kt @@ -23,6 +23,7 @@ import kotlinx.serialization.json.Json import org.slf4j.Logger import org.slf4j.LoggerFactory import tech.libeufin.common.* +import tech.libeufin.common.db.* import java.nio.file.Path import java.time.Duration diff --git a/bank/src/main/kotlin/tech/libeufin/bank/Main.kt b/bank/src/main/kotlin/tech/libeufin/bank/Main.kt @@ -54,6 +54,7 @@ import org.slf4j.event.Level import tech.libeufin.bank.db.AccountDAO.* import tech.libeufin.bank.db.Database import tech.libeufin.common.* +import tech.libeufin.common.db.* import java.net.InetAddress import java.sql.SQLException import java.util.zip.DataFormatException diff --git a/bank/src/main/kotlin/tech/libeufin/bank/db/AccountDAO.kt b/bank/src/main/kotlin/tech/libeufin/bank/db/AccountDAO.kt @@ -21,6 +21,7 @@ package tech.libeufin.bank.db import tech.libeufin.bank.* import tech.libeufin.common.* +import tech.libeufin.common.db.* import tech.libeufin.common.crypto.* import java.time.Instant diff --git a/bank/src/main/kotlin/tech/libeufin/bank/db/CashoutDAO.kt b/bank/src/main/kotlin/tech/libeufin/bank/db/CashoutDAO.kt @@ -21,6 +21,7 @@ package tech.libeufin.bank.db import tech.libeufin.bank.* import tech.libeufin.common.* +import tech.libeufin.common.db.* import java.time.Instant /** Data access logic for cashout operations */ diff --git a/bank/src/main/kotlin/tech/libeufin/bank/db/ConversionDAO.kt b/bank/src/main/kotlin/tech/libeufin/bank/db/ConversionDAO.kt @@ -24,9 +24,7 @@ import tech.libeufin.bank.DecimalNumber import tech.libeufin.bank.RoundingMode import tech.libeufin.bank.internalServerError import tech.libeufin.common.TalerAmount -import tech.libeufin.common.getAmount -import tech.libeufin.common.oneOrNull -import tech.libeufin.common.transaction +import tech.libeufin.common.db.* /** Data access logic for conversion */ class ConversionDAO(private val db: Database) { diff --git a/bank/src/main/kotlin/tech/libeufin/bank/db/Database.kt b/bank/src/main/kotlin/tech/libeufin/bank/db/Database.kt @@ -28,6 +28,7 @@ import org.slf4j.Logger import org.slf4j.LoggerFactory import tech.libeufin.bank.* import tech.libeufin.common.* +import tech.libeufin.common.db.* import java.sql.PreparedStatement import java.sql.ResultSet import java.sql.Types diff --git a/bank/src/main/kotlin/tech/libeufin/bank/db/ExchangeDAO.kt b/bank/src/main/kotlin/tech/libeufin/bank/db/ExchangeDAO.kt @@ -21,6 +21,7 @@ package tech.libeufin.bank.db import tech.libeufin.bank.* import tech.libeufin.common.* +import tech.libeufin.common.db.* import java.time.Instant /** Data access logic for exchange specific logic */ diff --git a/bank/src/main/kotlin/tech/libeufin/bank/db/GcDAO.kt b/bank/src/main/kotlin/tech/libeufin/bank/db/GcDAO.kt @@ -21,6 +21,7 @@ package tech.libeufin.bank.db import tech.libeufin.bank.* import tech.libeufin.common.* +import tech.libeufin.common.db.* import tech.libeufin.common.crypto.* import java.time.Instant import java.time.Duration diff --git a/bank/src/main/kotlin/tech/libeufin/bank/db/NotificationWatcher.kt b/bank/src/main/kotlin/tech/libeufin/bank/db/NotificationWatcher.kt @@ -26,6 +26,7 @@ import org.slf4j.Logger import org.slf4j.LoggerFactory import tech.libeufin.bank.* import tech.libeufin.common.* +import tech.libeufin.common.db.* import java.util.* import java.util.concurrent.ConcurrentHashMap diff --git a/bank/src/main/kotlin/tech/libeufin/bank/db/TanDAO.kt b/bank/src/main/kotlin/tech/libeufin/bank/db/TanDAO.kt @@ -22,8 +22,8 @@ package tech.libeufin.bank.db import tech.libeufin.bank.Operation import tech.libeufin.bank.TanChannel import tech.libeufin.bank.internalServerError -import tech.libeufin.common.oneOrNull import tech.libeufin.common.micros +import tech.libeufin.common.db.* import java.time.Duration import java.time.Instant import java.util.concurrent.TimeUnit diff --git a/bank/src/main/kotlin/tech/libeufin/bank/db/TokenDAO.kt b/bank/src/main/kotlin/tech/libeufin/bank/db/TokenDAO.kt @@ -21,10 +21,9 @@ package tech.libeufin.bank.db import tech.libeufin.bank.BearerToken import tech.libeufin.bank.TokenScope -import tech.libeufin.common.executeUpdateViolation import tech.libeufin.common.asInstant -import tech.libeufin.common.oneOrNull import tech.libeufin.common.micros +import tech.libeufin.common.db.* import java.time.Instant /** Data access logic for auth tokens */ diff --git a/bank/src/main/kotlin/tech/libeufin/bank/db/TransactionDAO.kt b/bank/src/main/kotlin/tech/libeufin/bank/db/TransactionDAO.kt @@ -23,6 +23,7 @@ import org.slf4j.Logger import org.slf4j.LoggerFactory import tech.libeufin.bank.* import tech.libeufin.common.* +import tech.libeufin.common.db.* import java.time.Instant private val logger: Logger = LoggerFactory.getLogger("libeufin-bank-tx-dao") diff --git a/bank/src/main/kotlin/tech/libeufin/bank/db/WithdrawalDAO.kt b/bank/src/main/kotlin/tech/libeufin/bank/db/WithdrawalDAO.kt @@ -25,6 +25,7 @@ import kotlinx.coroutines.launch import kotlinx.coroutines.withTimeoutOrNull import tech.libeufin.bank.* import tech.libeufin.common.* +import tech.libeufin.common.db.* import java.time.Instant import java.util.* diff --git a/bank/src/test/kotlin/AmountTest.kt b/bank/src/test/kotlin/AmountTest.kt @@ -22,6 +22,7 @@ import tech.libeufin.bank.DecimalNumber import tech.libeufin.bank.db.TransactionDAO.BankTransactionResult import tech.libeufin.bank.db.WithdrawalDAO.WithdrawalCreationResult import tech.libeufin.common.* +import tech.libeufin.common.db.* import java.time.Instant import java.util.* import kotlin.test.assertEquals diff --git a/bank/src/test/kotlin/DatabaseTest.kt b/bank/src/test/kotlin/DatabaseTest.kt @@ -23,6 +23,7 @@ import org.junit.Test import tech.libeufin.bank.createAdminAccount import tech.libeufin.bank.db.AccountDAO.AccountCreationResult import tech.libeufin.common.* +import tech.libeufin.common.db.* import java.time.Duration import java.time.Instant import java.time.temporal.ChronoUnit diff --git a/bank/src/test/kotlin/GcTest.kt b/bank/src/test/kotlin/GcTest.kt @@ -25,6 +25,7 @@ import tech.libeufin.bank.db.TransactionDAO.* import tech.libeufin.bank.db.CashoutDAO.CashoutCreationResult import tech.libeufin.bank.db.ExchangeDAO.TransferResult import tech.libeufin.common.* +import tech.libeufin.common.db.* import io.ktor.client.request.* import io.ktor.client.statement.* import io.ktor.http.* diff --git a/bank/src/test/kotlin/StatsTest.kt b/bank/src/test/kotlin/StatsTest.kt @@ -23,6 +23,7 @@ import tech.libeufin.bank.MonitorResponse import tech.libeufin.bank.MonitorWithConversion import tech.libeufin.bank.Timeframe import tech.libeufin.common.* +import tech.libeufin.common.db.* import java.time.Instant import java.time.LocalDateTime import kotlin.test.assertEquals diff --git a/bank/src/test/kotlin/helpers.kt b/bank/src/test/kotlin/helpers.kt @@ -27,6 +27,7 @@ import tech.libeufin.bank.* import tech.libeufin.bank.db.AccountDAO.AccountCreationResult import tech.libeufin.bank.db.Database import tech.libeufin.common.* +import tech.libeufin.common.db.* import java.nio.file.NoSuchFileException import kotlin.io.path.Path import kotlin.io.path.deleteExisting diff --git a/common/src/main/kotlin/DB.kt b/common/src/main/kotlin/DB.kt @@ -1,355 +0,0 @@ -/* - * This file is part of LibEuFin. - * Copyright (C) 2023 Taler Systems S.A. - * - * LibEuFin is free software; you can redistribute it and/or modify - * it under the terms of the GNU Affero General Public License as - * published by the Free Software Foundation; either version 3, or - * (at your option) any later version. - * - * LibEuFin is distributed in the hope that it will be useful, but - * WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY - * or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Affero General - * Public License for more details. - * - * You should have received a copy of the GNU Affero General Public - * License along with LibEuFin; see the file COPYING. If not, see - * <http://www.gnu.org/licenses/> - */ - -package tech.libeufin.common - -import com.zaxxer.hikari.HikariConfig -import com.zaxxer.hikari.HikariDataSource -import kotlinx.coroutines.Dispatchers -import kotlinx.coroutines.withContext -import org.postgresql.ds.PGSimpleDataSource -import org.postgresql.jdbc.PgConnection -import org.postgresql.util.PSQLState -import org.slf4j.Logger -import org.slf4j.LoggerFactory -import java.net.URI -import java.nio.file.Path -import java.sql.PreparedStatement -import java.sql.ResultSet -import java.sql.SQLException -import kotlin.io.path.Path -import kotlin.io.path.exists -import kotlin.io.path.readText - -fun getCurrentUser(): String = System.getProperty("user.name") - -private val logger: Logger = LoggerFactory.getLogger("libeufin-db") - -// Check GANA (https://docs.gnunet.org/gana/index.html) for numbers allowance. - -/** - * This function converts postgresql:// URIs to JDBC URIs. - * - * URIs that are already jdbc: URIs are passed through. - * - * This avoids the user having to create complex JDBC URIs for postgres connections. - * They are especially complex when using unix domain sockets, as they're not really - * supported natively by JDBC. - */ -fun getJdbcConnectionFromPg(pgConn: String): String { - // Pass through jdbc URIs. - if (pgConn.startsWith("jdbc:")) { - return pgConn - } - if (!pgConn.startsWith("postgresql://") && !pgConn.startsWith("postgres://")) { - throw Exception("Not a Postgres connection string: $pgConn") - } - var maybeUnixSocket = false - val parsed = URI(pgConn) - var hostAsParam: String? = if (parsed.query != null) { - getQueryParam(parsed.query, "host") - } else { - null - } - var pgHost = System.getenv("PGHOST") - if (null == pgHost) - pgHost = parsed.host - var pgPort = System.getenv("PGPORT") - if (null == pgPort) { - if (-1 == parsed.port) - pgPort = "5432" - else - pgPort = parsed.port.toString() - } - - /** - * In some cases, it is possible to leave the hostname empty - * and specify it via a query param, therefore a "postgresql:///"-starting - * connection string does NOT always mean Unix domain socket. - * https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-CONNSTRING - */ - if (pgHost == null && - (hostAsParam == null || hostAsParam.startsWith('/')) - ) { - maybeUnixSocket = true - } - if (pgHost != null && - (pgHost.startsWith('/')) - ) { - maybeUnixSocket = true - } - if (maybeUnixSocket) { - // Check whether the database user should differ from the process user. - var pgUser = getCurrentUser() - if (parsed.query != null) { - val maybeUserParam = getQueryParam(parsed.query, "user") - if (maybeUserParam != null) pgUser = maybeUserParam - } - // Check whether the Unix domain socket location was given non-standard. - if ( (null == hostAsParam) && (null != pgHost) ) - hostAsParam = pgHost + "/.s.PGSQL." + pgPort - val socketLocation = hostAsParam ?: "/var/run/postgresql/.s.PGSQL." + pgPort - if (!socketLocation.startsWith('/')) { - throw Exception("PG connection wants Unix domain socket, but non-null host doesn't start with slash") - } - return "jdbc:postgresql://localhost${parsed.path}?user=$pgUser&socketFactory=org.newsclub.net.unix." + - "AFUNIXSocketFactory\$FactoryArg&socketFactoryArg=$socketLocation" - } - if (pgConn.startsWith("postgres://")) { - // The JDBC driver doesn't like postgres://, only postgresql://. - // For consistency with other components, we normalize the postgres:// URI - // into one that the JDBC driver likes. - return "jdbc:postgresql://" + pgConn.removePrefix("postgres://") - } - logger.info("connecting to database via JDBC string '$pgConn'") - return "jdbc:$pgConn" -} - -data class DatabaseConfig( - val dbConnStr: String, - val sqlDir: Path -) - -fun pgDataSource(dbConfig: String): PGSimpleDataSource { - val jdbcConnStr = getJdbcConnectionFromPg(dbConfig) - logger.debug("connecting to database via JDBC string '$jdbcConnStr'") - val pgSource = PGSimpleDataSource() - pgSource.setUrl(jdbcConnStr) - pgSource.prepareThreshold = 1 - return pgSource -} - -fun PGSimpleDataSource.pgConnection(schema: String? = null): PgConnection { - val conn = connection.unwrap(PgConnection::class.java) - if (schema != null) conn.execSQLUpdate("SET search_path TO $schema") - return conn -} - -fun <R> PgConnection.transaction(lambda: (PgConnection) -> R): R { - try { - autoCommit = false - val result = lambda(this) - commit() - autoCommit = true - return result - } catch (e: Exception) { - rollback() - autoCommit = true - throw e - } -} - -fun <T> PreparedStatement.oneOrNull(lambda: (ResultSet) -> T): T? { - executeQuery().use { - return if (it.next()) lambda(it) else null - } -} - -fun <T> PreparedStatement.one(lambda: (ResultSet) -> T): T = - requireNotNull(oneOrNull(lambda)) { "Missing result to database query" } - -fun <T> PreparedStatement.all(lambda: (ResultSet) -> T): List<T> { - executeQuery().use { - val ret = mutableListOf<T>() - while (it.next()) { - ret.add(lambda(it)) - } - return ret - } -} - -fun PreparedStatement.executeQueryCheck(): Boolean { - executeQuery().use { - return it.next() - } -} - -fun PreparedStatement.executeUpdateCheck(): Boolean { - executeUpdate() - return updateCount > 0 -} - -/** - * Helper that returns false if the row to be inserted - * hits a unique key constraint violation, true when it - * succeeds. Any other error (re)throws exception. - */ -fun PreparedStatement.executeUpdateViolation(): Boolean { - return try { - executeUpdateCheck() - } catch (e: SQLException) { - logger.debug(e.message) - if (e.sqlState == PSQLState.UNIQUE_VIOLATION.state) return false - throw e // rethrowing, not to hide other types of errors. - } -} - -fun PreparedStatement.executeProcedureViolation(): Boolean { - val savepoint = connection.setSavepoint() - return try { - executeUpdate() - connection.releaseSavepoint(savepoint) - true - } catch (e: SQLException) { - connection.rollback(savepoint) - if (e.sqlState == PSQLState.UNIQUE_VIOLATION.state) return false - throw e // rethrowing, not to hide other types of errors. - } -} - -// TODO comment -fun PgConnection.dynamicUpdate( - table: String, - fields: Sequence<String>, - filter: String, - bind: Sequence<Any?>, -) { - val sql = fields.joinToString() - if (sql.isEmpty()) return - prepareStatement("UPDATE $table SET $sql $filter").run { - for ((idx, value) in bind.withIndex()) { - setObject(idx + 1, value) - } - executeUpdate() - } -} - -/** - * Only runs versioning.sql if the _v schema is not found. - * - * @param conn database connection - * @param cfg database configuration - */ -fun maybeApplyV(conn: PgConnection, cfg: DatabaseConfig) { - conn.transaction { - val checkVSchema = conn.prepareStatement( - "SELECT schema_name FROM information_schema.schemata WHERE schema_name = '_v'" - ) - if (!checkVSchema.executeQueryCheck()) { - logger.debug("_v schema not found, applying versioning.sql") - val sqlVersioning = Path("${cfg.sqlDir}/versioning.sql").readText() - conn.execSQLUpdate(sqlVersioning) - } - } -} - -// sqlFilePrefix is, for example, "libeufin-bank" or "libeufin-nexus" (no trailing dash). -fun initializeDatabaseTables(conn: PgConnection, cfg: DatabaseConfig, sqlFilePrefix: String) { - logger.info("doing DB initialization, sqldir ${cfg.sqlDir}") - maybeApplyV(conn, cfg) - conn.transaction { - val checkStmt = conn.prepareStatement("SELECT count(*) as n FROM _v.patches where patch_name = ?") - - for (n in 1..9999) { - val numStr = n.toString().padStart(4, '0') - val patchName = "$sqlFilePrefix-$numStr" - - checkStmt.setString(1, patchName) - val patchCount = checkStmt.oneOrNull { it.getInt(1) } ?: throw Exception("unable to query patches") - if (patchCount >= 1) { - logger.debug("patch $patchName already applied") - continue - } - - val path = Path("${cfg.sqlDir}/$sqlFilePrefix-$numStr.sql") - if (!path.exists()) { - logger.debug("path $path doesn't exist anymore, stopping") - break - } - logger.info("applying patch $path") - val sqlPatchText = path.readText() - conn.execSQLUpdate(sqlPatchText) - } - val sqlProcedures = Path("${cfg.sqlDir}/$sqlFilePrefix-procedures.sql") - if (!sqlProcedures.exists()) { - logger.warn("no procedures.sql for the SQL collection: $sqlFilePrefix") - return@transaction - } - logger.info("run procedure.sql") - conn.execSQLUpdate(sqlProcedures.readText()) - } -} - -// sqlFilePrefix is, for example, "libeufin-bank" or "libeufin-nexus" (no trailing dash). -fun resetDatabaseTables(conn: PgConnection, cfg: DatabaseConfig, sqlFilePrefix: String) { - logger.info("reset DB, sqldir ${cfg.sqlDir}") - val sqlDrop = Path("${cfg.sqlDir}/$sqlFilePrefix-drop.sql").readText() - conn.execSQLUpdate(sqlDrop) -} - -open class DbPool(cfg: String, schema: String) : java.io.Closeable { - val pgSource = pgDataSource(cfg) - private val pool: HikariDataSource - - init { - val config = HikariConfig() - config.dataSource = pgSource - config.schema = schema - config.transactionIsolation = "TRANSACTION_SERIALIZABLE" - pool = HikariDataSource(config) - pool.connection.use { con -> - val meta = con.metaData - val majorVersion = meta.databaseMajorVersion - val minorVersion = meta.databaseMinorVersion - if (majorVersion < MIN_VERSION) { - throw Exception("postgres version must be at least $MIN_VERSION.0 got $majorVersion.$minorVersion") - } - } - } - - suspend fun <R> conn(lambda: suspend (PgConnection) -> R): R { - // Use a coroutine dispatcher that we can block as JDBC API is blocking - return withContext(Dispatchers.IO) { - pool.connection.use { lambda(it.unwrap(PgConnection::class.java)) } - } - } - - suspend fun <R> serializable(lambda: suspend (PgConnection) -> R): R = conn { conn -> - repeat(SERIALIZATION_RETRY) { - try { - return@conn lambda(conn) - } catch (e: SQLException) { - if (e.sqlState != PSQLState.SERIALIZATION_FAILURE.state) - throw e - } - } - try { - return@conn lambda(conn) - } catch (e: SQLException) { - logger.warn("Serialization failure after $SERIALIZATION_RETRY retry") - throw e - } - } - - override fun close() { - pool.close() - } -} - -fun ResultSet.getAmount(name: String, currency: String): TalerAmount { - return TalerAmount( - getLong("${name}_val"), - getInt("${name}_frac"), - currency - ) -} - -fun ResultSet.getBankPayto(payto: String, name: String, ctx: BankPaytoCtx): String { - return Payto.parse(getString(payto)).bank(getString(name), ctx) -} -\ No newline at end of file diff --git a/common/src/main/kotlin/db/DbPool.kt b/common/src/main/kotlin/db/DbPool.kt @@ -0,0 +1,78 @@ +/* + * This file is part of LibEuFin. + * Copyright (C) 2024 Taler Systems S.A. + * + * LibEuFin is free software; you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as + * published by the Free Software Foundation; either version 3, or + * (at your option) any later version. + * + * LibEuFin is distributed in the hope that it will be useful, but + * WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY + * or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Affero General + * Public License for more details. + * + * You should have received a copy of the GNU Affero General Public + * License along with LibEuFin; see the file COPYING. If not, see + * <http://www.gnu.org/licenses/> + */ + +package tech.libeufin.common.db + +import tech.libeufin.common.* +import org.postgresql.jdbc.PgConnection +import org.postgresql.util.PSQLState +import java.sql.SQLException +import com.zaxxer.hikari.HikariConfig +import com.zaxxer.hikari.HikariDataSource +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.withContext + +open class DbPool(cfg: String, schema: String) : java.io.Closeable { + val pgSource = pgDataSource(cfg) + private val pool: HikariDataSource + + init { + val config = HikariConfig() + config.dataSource = pgSource + config.schema = schema + config.transactionIsolation = "TRANSACTION_SERIALIZABLE" + pool = HikariDataSource(config) + pool.connection.use { con -> + val meta = con.metaData + val majorVersion = meta.databaseMajorVersion + val minorVersion = meta.databaseMinorVersion + if (majorVersion < MIN_VERSION) { + throw Exception("postgres version must be at least $MIN_VERSION.0 got $majorVersion.$minorVersion") + } + } + } + + suspend fun <R> conn(lambda: suspend (PgConnection) -> R): R { + // Use a coroutine dispatcher that we can block as JDBC API is blocking + return withContext(Dispatchers.IO) { + pool.connection.use { lambda(it.unwrap(PgConnection::class.java)) } + } + } + + suspend fun <R> serializable(lambda: suspend (PgConnection) -> R): R = conn { conn -> + repeat(SERIALIZATION_RETRY) { + try { + return@conn lambda(conn) + } catch (e: SQLException) { + if (e.sqlState != PSQLState.SERIALIZATION_FAILURE.state) + throw e + } + } + try { + return@conn lambda(conn) + } catch (e: SQLException) { + logger.warn("Serialization failure after $SERIALIZATION_RETRY retry") + throw e + } + } + + override fun close() { + pool.close() + } +} +\ No newline at end of file diff --git a/common/src/main/kotlin/db/config.kt b/common/src/main/kotlin/db/config.kt @@ -0,0 +1,133 @@ +/* + * This file is part of LibEuFin. + * Copyright (C) 2024 Taler Systems S.A. + * + * LibEuFin is free software; you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as + * published by the Free Software Foundation; either version 3, or + * (at your option) any later version. + * + * LibEuFin is distributed in the hope that it will be useful, but + * WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY + * or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Affero General + * Public License for more details. + * + * You should have received a copy of the GNU Affero General Public + * License along with LibEuFin; see the file COPYING. If not, see + * <http://www.gnu.org/licenses/> + */ + +package tech.libeufin.common.db + +import tech.libeufin.common.* +import org.postgresql.ds.PGSimpleDataSource +import org.postgresql.jdbc.PgConnection +import org.postgresql.util.PSQLState +import org.slf4j.Logger +import org.slf4j.LoggerFactory +import java.net.URI +import java.nio.file.Path +import java.sql.PreparedStatement +import java.sql.ResultSet +import java.sql.SQLException +import kotlin.io.path.Path + +fun getCurrentUser(): String = System.getProperty("user.name") + +/** + * This function converts postgresql:// URIs to JDBC URIs. + * + * URIs that are already jdbc: URIs are passed through. + * + * This avoids the user having to create complex JDBC URIs for postgres connections. + * They are especially complex when using unix domain sockets, as they're not really + * supported natively by JDBC. + */ +fun getJdbcConnectionFromPg(pgConn: String): String { + // Pass through jdbc URIs. + if (pgConn.startsWith("jdbc:")) { + return pgConn + } + if (!pgConn.startsWith("postgresql://") && !pgConn.startsWith("postgres://")) { + throw Exception("Not a Postgres connection string: $pgConn") + } + var maybeUnixSocket = false + val parsed = URI(pgConn) + var hostAsParam: String? = if (parsed.query != null) { + getQueryParam(parsed.query, "host") + } else { + null + } + var pgHost = System.getenv("PGHOST") + if (null == pgHost) + pgHost = parsed.host + var pgPort = System.getenv("PGPORT") + if (null == pgPort) { + if (-1 == parsed.port) + pgPort = "5432" + else + pgPort = parsed.port.toString() + } + + /** + * In some cases, it is possible to leave the hostname empty + * and specify it via a query param, therefore a "postgresql:///"-starting + * connection string does NOT always mean Unix domain socket. + * https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-CONNSTRING + */ + if (pgHost == null && + (hostAsParam == null || hostAsParam.startsWith('/')) + ) { + maybeUnixSocket = true + } + if (pgHost != null && + (pgHost.startsWith('/')) + ) { + maybeUnixSocket = true + } + if (maybeUnixSocket) { + // Check whether the database user should differ from the process user. + var pgUser = getCurrentUser() + if (parsed.query != null) { + val maybeUserParam = getQueryParam(parsed.query, "user") + if (maybeUserParam != null) pgUser = maybeUserParam + } + // Check whether the Unix domain socket location was given non-standard. + if ( (null == hostAsParam) && (null != pgHost) ) + hostAsParam = pgHost + "/.s.PGSQL." + pgPort + val socketLocation = hostAsParam ?: "/var/run/postgresql/.s.PGSQL." + pgPort + if (!socketLocation.startsWith('/')) { + throw Exception("PG connection wants Unix domain socket, but non-null host doesn't start with slash") + } + return "jdbc:postgresql://localhost${parsed.path}?user=$pgUser&socketFactory=org.newsclub.net.unix." + + "AFUNIXSocketFactory\$FactoryArg&socketFactoryArg=$socketLocation" + } + if (pgConn.startsWith("postgres://")) { + // The JDBC driver doesn't like postgres://, only postgresql://. + // For consistency with other components, we normalize the postgres:// URI + // into one that the JDBC driver likes. + return "jdbc:postgresql://" + pgConn.removePrefix("postgres://") + } + logger.info("connecting to database via JDBC string '$pgConn'") + return "jdbc:$pgConn" +} + +data class DatabaseConfig( + val dbConnStr: String, + val sqlDir: Path +) + +fun pgDataSource(dbConfig: String): PGSimpleDataSource { + val jdbcConnStr = getJdbcConnectionFromPg(dbConfig) + logger.debug("connecting to database via JDBC string '$jdbcConnStr'") + val pgSource = PGSimpleDataSource() + pgSource.setUrl(jdbcConnStr) + pgSource.prepareThreshold = 1 + return pgSource +} + +fun PGSimpleDataSource.pgConnection(schema: String? = null): PgConnection { + val conn = connection.unwrap(PgConnection::class.java) + if (schema != null) conn.execSQLUpdate("SET search_path TO $schema") + return conn +} diff --git a/common/src/main/kotlin/db/schema.kt b/common/src/main/kotlin/db/schema.kt @@ -0,0 +1,89 @@ +/* + * This file is part of LibEuFin. + * Copyright (C) 2024 Taler Systems S.A. + * + * LibEuFin is free software; you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as + * published by the Free Software Foundation; either version 3, or + * (at your option) any later version. + * + * LibEuFin is distributed in the hope that it will be useful, but + * WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY + * or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Affero General + * Public License for more details. + * + * You should have received a copy of the GNU Affero General Public + * License along with LibEuFin; see the file COPYING. If not, see + * <http://www.gnu.org/licenses/> + */ + +package tech.libeufin.common.db + +import tech.libeufin.common.* +import org.postgresql.jdbc.PgConnection +import kotlin.io.path.Path +import kotlin.io.path.exists +import kotlin.io.path.readText + +/** + * Only runs versioning.sql if the _v schema is not found. + * + * @param conn database connection + * @param cfg database configuration + */ +fun maybeApplyV(conn: PgConnection, cfg: DatabaseConfig) { + conn.transaction { + val checkVSchema = conn.prepareStatement( + "SELECT schema_name FROM information_schema.schemata WHERE schema_name = '_v'" + ) + if (!checkVSchema.executeQueryCheck()) { + logger.debug("_v schema not found, applying versioning.sql") + val sqlVersioning = Path("${cfg.sqlDir}/versioning.sql").readText() + conn.execSQLUpdate(sqlVersioning) + } + } +} + +// sqlFilePrefix is, for example, "libeufin-bank" or "libeufin-nexus" (no trailing dash). +fun initializeDatabaseTables(conn: PgConnection, cfg: DatabaseConfig, sqlFilePrefix: String) { + logger.info("doing DB initialization, sqldir ${cfg.sqlDir}") + maybeApplyV(conn, cfg) + conn.transaction { + val checkStmt = conn.prepareStatement("SELECT count(*) as n FROM _v.patches where patch_name = ?") + + for (n in 1..9999) { + val numStr = n.toString().padStart(4, '0') + val patchName = "$sqlFilePrefix-$numStr" + + checkStmt.setString(1, patchName) + val patchCount = checkStmt.oneOrNull { it.getInt(1) } ?: throw Exception("unable to query patches") + if (patchCount >= 1) { + logger.debug("patch $patchName already applied") + continue + } + + val path = Path("${cfg.sqlDir}/$sqlFilePrefix-$numStr.sql") + if (!path.exists()) { + logger.debug("path $path doesn't exist anymore, stopping") + break + } + logger.info("applying patch $path") + val sqlPatchText = path.readText() + conn.execSQLUpdate(sqlPatchText) + } + val sqlProcedures = Path("${cfg.sqlDir}/$sqlFilePrefix-procedures.sql") + if (!sqlProcedures.exists()) { + logger.warn("no procedures.sql for the SQL collection: $sqlFilePrefix") + return@transaction + } + logger.info("run procedure.sql") + conn.execSQLUpdate(sqlProcedures.readText()) + } +} + +// sqlFilePrefix is, for example, "libeufin-bank" or "libeufin-nexus" (no trailing dash). +fun resetDatabaseTables(conn: PgConnection, cfg: DatabaseConfig, sqlFilePrefix: String) { + logger.info("reset DB, sqldir ${cfg.sqlDir}") + val sqlDrop = Path("${cfg.sqlDir}/$sqlFilePrefix-drop.sql").readText() + conn.execSQLUpdate(sqlDrop) +} +\ No newline at end of file diff --git a/common/src/main/kotlin/db/types.kt b/common/src/main/kotlin/db/types.kt @@ -0,0 +1,35 @@ +/* + * This file is part of LibEuFin. + * Copyright (C) 2024 Taler Systems S.A. + * + * LibEuFin is free software; you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as + * published by the Free Software Foundation; either version 3, or + * (at your option) any later version. + * + * LibEuFin is distributed in the hope that it will be useful, but + * WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY + * or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Affero General + * Public License for more details. + * + * You should have received a copy of the GNU Affero General Public + * License along with LibEuFin; see the file COPYING. If not, see + * <http://www.gnu.org/licenses/> + */ + +package tech.libeufin.common.db + +import tech.libeufin.common.* +import java.sql.ResultSet + +fun ResultSet.getAmount(name: String, currency: String): TalerAmount { + return TalerAmount( + getLong("${name}_val"), + getInt("${name}_frac"), + currency + ) +} + +fun ResultSet.getBankPayto(payto: String, name: String, ctx: BankPaytoCtx): String { + return Payto.parse(getString(payto)).bank(getString(name), ctx) +} +\ No newline at end of file diff --git a/common/src/main/kotlin/db/utils.kt b/common/src/main/kotlin/db/utils.kt @@ -0,0 +1,222 @@ +/* + * This file is part of LibEuFin. + * Copyright (C) 2024 Taler Systems S.A. + * + * LibEuFin is free software; you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as + * published by the Free Software Foundation; either version 3, or + * (at your option) any later version. + * + * LibEuFin is distributed in the hope that it will be useful, but + * WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY + * or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Affero General + * Public License for more details. + * + * You should have received a copy of the GNU Affero General Public + * License along with LibEuFin; see the file COPYING. If not, see + * <http://www.gnu.org/licenses/> + */ + +package tech.libeufin.common.db + +import tech.libeufin.common.* +import org.postgresql.ds.PGSimpleDataSource +import org.postgresql.jdbc.PgConnection +import org.postgresql.util.PSQLState +import org.slf4j.Logger +import org.slf4j.LoggerFactory +import java.net.URI +import java.nio.file.Path +import java.sql.PreparedStatement +import java.sql.ResultSet +import java.sql.SQLException +import kotlin.io.path.Path + +internal val logger: Logger = LoggerFactory.getLogger("libeufin-db") + +/** + * This function converts postgresql:// URIs to JDBC URIs. + * + * URIs that are already jdbc: URIs are passed through. + * + * This avoids the user having to create complex JDBC URIs for postgres connections. + * They are especially complex when using unix domain sockets, as they're not really + * supported natively by JDBC. + */ +fun getJdbcConnectionFromPg(pgConn: String): String { + // Pass through jdbc URIs. + if (pgConn.startsWith("jdbc:")) { + return pgConn + } + if (!pgConn.startsWith("postgresql://") && !pgConn.startsWith("postgres://")) { + throw Exception("Not a Postgres connection string: $pgConn") + } + var maybeUnixSocket = false + val parsed = URI(pgConn) + var hostAsParam: String? = if (parsed.query != null) { + getQueryParam(parsed.query, "host") + } else { + null + } + var pgHost = System.getenv("PGHOST") + if (null == pgHost) + pgHost = parsed.host + var pgPort = System.getenv("PGPORT") + if (null == pgPort) { + if (-1 == parsed.port) + pgPort = "5432" + else + pgPort = parsed.port.toString() + } + + /** + * In some cases, it is possible to leave the hostname empty + * and specify it via a query param, therefore a "postgresql:///"-starting + * connection string does NOT always mean Unix domain socket. + * https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-CONNSTRING + */ + if (pgHost == null && + (hostAsParam == null || hostAsParam.startsWith('/')) + ) { + maybeUnixSocket = true + } + if (pgHost != null && + (pgHost.startsWith('/')) + ) { + maybeUnixSocket = true + } + if (maybeUnixSocket) { + // Check whether the database user should differ from the process user. + var pgUser = getCurrentUser() + if (parsed.query != null) { + val maybeUserParam = getQueryParam(parsed.query, "user") + if (maybeUserParam != null) pgUser = maybeUserParam + } + // Check whether the Unix domain socket location was given non-standard. + if ( (null == hostAsParam) && (null != pgHost) ) + hostAsParam = pgHost + "/.s.PGSQL." + pgPort + val socketLocation = hostAsParam ?: "/var/run/postgresql/.s.PGSQL." + pgPort + if (!socketLocation.startsWith('/')) { + throw Exception("PG connection wants Unix domain socket, but non-null host doesn't start with slash") + } + return "jdbc:postgresql://localhost${parsed.path}?user=$pgUser&socketFactory=org.newsclub.net.unix." + + "AFUNIXSocketFactory\$FactoryArg&socketFactoryArg=$socketLocation" + } + if (pgConn.startsWith("postgres://")) { + // The JDBC driver doesn't like postgres://, only postgresql://. + // For consistency with other components, we normalize the postgres:// URI + // into one that the JDBC driver likes. + return "jdbc:postgresql://" + pgConn.removePrefix("postgres://") + } + logger.info("connecting to database via JDBC string '$pgConn'") + return "jdbc:$pgConn" +} + +data class DatabaseConfig( + val dbConnStr: String, + val sqlDir: Path +) + +fun pgDataSource(dbConfig: String): PGSimpleDataSource { + val jdbcConnStr = getJdbcConnectionFromPg(dbConfig) + logger.debug("connecting to database via JDBC string '$jdbcConnStr'") + val pgSource = PGSimpleDataSource() + pgSource.setUrl(jdbcConnStr) + pgSource.prepareThreshold = 1 + return pgSource +} + +fun PGSimpleDataSource.pgConnection(schema: String? = null): PgConnection { + val conn = connection.unwrap(PgConnection::class.java) + if (schema != null) conn.execSQLUpdate("SET search_path TO $schema") + return conn +} + +fun <R> PgConnection.transaction(lambda: (PgConnection) -> R): R { + try { + autoCommit = false + val result = lambda(this) + commit() + autoCommit = true + return result + } catch (e: Exception) { + rollback() + autoCommit = true + throw e + } +} + +fun <T> PreparedStatement.oneOrNull(lambda: (ResultSet) -> T): T? { + executeQuery().use { + return if (it.next()) lambda(it) else null + } +} + +fun <T> PreparedStatement.one(lambda: (ResultSet) -> T): T = + requireNotNull(oneOrNull(lambda)) { "Missing result to database query" } + +fun <T> PreparedStatement.all(lambda: (ResultSet) -> T): List<T> { + executeQuery().use { + val ret = mutableListOf<T>() + while (it.next()) { + ret.add(lambda(it)) + } + return ret + } +} + +fun PreparedStatement.executeQueryCheck(): Boolean { + executeQuery().use { + return it.next() + } +} + +fun PreparedStatement.executeUpdateCheck(): Boolean { + executeUpdate() + return updateCount > 0 +} + +/** + * Helper that returns false if the row to be inserted + * hits a unique key constraint violation, true when it + * succeeds. Any other error (re)throws exception. + */ +fun PreparedStatement.executeUpdateViolation(): Boolean { + return try { + executeUpdateCheck() + } catch (e: SQLException) { + logger.debug(e.message) + if (e.sqlState == PSQLState.UNIQUE_VIOLATION.state) return false + throw e // rethrowing, not to hide other types of errors. + } +} + +fun PreparedStatement.executeProcedureViolation(): Boolean { + val savepoint = connection.setSavepoint() + return try { + executeUpdate() + connection.releaseSavepoint(savepoint) + true + } catch (e: SQLException) { + connection.rollback(savepoint) + if (e.sqlState == PSQLState.UNIQUE_VIOLATION.state) return false + throw e // rethrowing, not to hide other types of errors. + } +} + +// TODO comment +fun PgConnection.dynamicUpdate( + table: String, + fields: Sequence<String>, + filter: String, + bind: Sequence<Any?>, +) { + val sql = fields.joinToString() + if (sql.isEmpty()) return + prepareStatement("UPDATE $table SET $sql $filter").run { + for ((idx, value) in bind.withIndex()) { + setObject(idx + 1, value) + } + executeUpdate() + } +} +\ No newline at end of file diff --git a/nexus/src/main/kotlin/tech/libeufin/nexus/DbInit.kt b/nexus/src/main/kotlin/tech/libeufin/nexus/DbInit.kt @@ -23,6 +23,7 @@ import com.github.ajalt.clikt.parameters.groups.provideDelegate import com.github.ajalt.clikt.parameters.options.flag import com.github.ajalt.clikt.parameters.options.option import tech.libeufin.common.* +import tech.libeufin.common.db.* /** * This subcommand tries to load the SQL files that define diff --git a/nexus/src/main/kotlin/tech/libeufin/nexus/Main.kt b/nexus/src/main/kotlin/tech/libeufin/nexus/Main.kt @@ -36,6 +36,7 @@ import kotlinx.coroutines.* import org.slf4j.Logger import org.slf4j.LoggerFactory import tech.libeufin.common.* +import tech.libeufin.common.db.* import tech.libeufin.nexus.ebics.* import tech.libeufin.nexus.db.* import java.nio.file.Path diff --git a/nexus/src/main/kotlin/tech/libeufin/nexus/db/Database.kt b/nexus/src/main/kotlin/tech/libeufin/nexus/db/Database.kt @@ -20,6 +20,7 @@ package tech.libeufin.nexus.db import org.postgresql.util.PSQLState import tech.libeufin.common.* +import tech.libeufin.common.db.* import tech.libeufin.nexus.* import java.sql.PreparedStatement import java.sql.SQLException diff --git a/nexus/src/main/kotlin/tech/libeufin/nexus/db/InitiatedDAO.kt b/nexus/src/main/kotlin/tech/libeufin/nexus/db/InitiatedDAO.kt @@ -21,6 +21,7 @@ package tech.libeufin.nexus.db import tech.libeufin.nexus.* import tech.libeufin.common.* +import tech.libeufin.common.db.* import java.time.Instant import java.sql.ResultSet diff --git a/nexus/src/main/kotlin/tech/libeufin/nexus/db/PaymentDAO.kt b/nexus/src/main/kotlin/tech/libeufin/nexus/db/PaymentDAO.kt @@ -21,6 +21,7 @@ package tech.libeufin.nexus.db import tech.libeufin.nexus.* import tech.libeufin.common.* +import tech.libeufin.common.db.* import java.time.Instant /** Data access logic for incoming & outgoing payments */ diff --git a/nexus/src/test/kotlin/helpers.kt b/nexus/src/test/kotlin/helpers.kt @@ -22,6 +22,7 @@ import io.ktor.client.engine.mock.* import io.ktor.client.request.* import kotlinx.coroutines.runBlocking import tech.libeufin.common.* +import tech.libeufin.common.db.* import tech.libeufin.nexus.* import tech.libeufin.nexus.db.* import java.time.Instant diff --git a/testbench/src/test/kotlin/IntegrationTest.kt b/testbench/src/test/kotlin/IntegrationTest.kt @@ -23,6 +23,7 @@ import tech.libeufin.nexus.* import tech.libeufin.nexus.db.Database as NexusDb import tech.libeufin.bank.db.AccountDAO.* import tech.libeufin.common.* +import tech.libeufin.common.db.* import java.time.Instant import java.util.Arrays import java.sql.SQLException diff --git a/testbench/src/test/kotlin/MigrationTest.kt b/testbench/src/test/kotlin/MigrationTest.kt @@ -23,6 +23,7 @@ import tech.libeufin.bank.db.WithdrawalDAO.WithdrawalCreationResult import tech.libeufin.bank.db.* import tech.libeufin.bank.* import tech.libeufin.common.* +import tech.libeufin.common.db.* import java.time.Instant import java.util.* import org.postgresql.jdbc.PgConnection