commit 94cb68f30e1b0a6d5fdb61d1917c8f08337b1f7a
parent 12cbce2f9bdadb00e3a4e60ab4fc861ad5426ae8
Author: MS <ms@taler.net>
Date: Wed, 18 Oct 2023 16:15:28 +0200
Moving DB init/drop logic to util/.
Diffstat:
6 files changed, 110 insertions(+), 102 deletions(-)
diff --git a/bank/build.gradle b/bank/build.gradle
@@ -76,6 +76,8 @@ dependencies {
testImplementation 'org.jetbrains.kotlin:kotlin-test:1.5.21'
testImplementation group: "junit", name: "junit", version: '4.13.2'
+ testImplementation project(":util")
+
// UNIX domain sockets support (used to connect to PostgreSQL)
implementation 'com.kohlschutter.junixsocket:junixsocket-core:2.6.2'
}
diff --git a/bank/src/main/kotlin/tech/libeufin/bank/Config.kt b/bank/src/main/kotlin/tech/libeufin/bank/Config.kt
@@ -24,15 +24,11 @@ import TalerConfigError
import org.slf4j.Logger
import org.slf4j.LoggerFactory
import kotlinx.serialization.json.Json
+import tech.libeufin.util.DatabaseConfig
private val logger: Logger = LoggerFactory.getLogger("tech.libeufin.bank.Config")
private val BANK_CONFIG_SOURCE = ConfigSource("libeufin-bank", "libeufin-bank")
-data class DatabaseConfig(
- val dbConnStr: String,
- val sqlDir: String
-)
-
data class ServerConfig(
val method: String,
val port: Int
diff --git a/bank/src/main/kotlin/tech/libeufin/bank/Database.kt b/bank/src/main/kotlin/tech/libeufin/bank/Database.kt
@@ -23,11 +23,6 @@ import org.postgresql.jdbc.PgConnection
import org.postgresql.ds.PGSimpleDataSource
import org.slf4j.Logger
import org.slf4j.LoggerFactory
-import tech.libeufin.util.getJdbcConnectionFromPg
-import tech.libeufin.util.microsToJavaInstant
-import tech.libeufin.util.stripIbanPayto
-import tech.libeufin.util.toDbMicros
-import tech.libeufin.util.XMLUtil
import java.io.File
import java.sql.*
import java.time.Instant
@@ -37,6 +32,7 @@ import kotlin.math.abs
import kotlinx.coroutines.flow.*
import kotlinx.coroutines.*
import com.zaxxer.hikari.*
+import tech.libeufin.util.*
private const val DB_CTR_LIMIT = 1000000
@@ -66,94 +62,6 @@ private val logger: Logger = LoggerFactory.getLogger("tech.libeufin.bank.Databas
private fun faultyTimestampByBank() = internalServerError("Bank took overflowing timestamp")
private fun faultyDurationByClient() = badRequest("Overflowing duration, please specify 'forever' instead.")
-private fun pgDataSource(dbConfig: String): PGSimpleDataSource {
- val jdbcConnStr = getJdbcConnectionFromPg(dbConfig)
- logger.info("connecting to database via JDBC string '$jdbcConnStr'")
- val pgSource = PGSimpleDataSource()
- pgSource.setUrl(jdbcConnStr)
- pgSource.prepareThreshold = 1
- return pgSource
-}
-
-private fun PGSimpleDataSource.pgConnection(): PgConnection {
- val conn = connection.unwrap(PgConnection::class.java)
- conn.execSQLUpdate("SET search_path TO libeufin_bank;")
- return conn
-}
-
-private fun <R> PgConnection.transaction(lambda: (PgConnection) -> R): R {
- try {
- setAutoCommit(false);
- val result = lambda(this)
- commit();
- setAutoCommit(true);
- return result
- } catch(e: Exception){
- rollback();
- setAutoCommit(true);
- throw e;
- }
-}
-
-fun initializeDatabaseTables(cfg: DatabaseConfig) {
- logger.info("doing DB initialization, sqldir ${cfg.sqlDir}, dbConnStr ${cfg.dbConnStr}")
- pgDataSource(cfg.dbConnStr).pgConnection().use { conn ->
- conn.transaction {
- val sqlVersioning = File("${cfg.sqlDir}/versioning.sql").readText()
- conn.execSQLUpdate(sqlVersioning)
-
- val checkStmt = conn.prepareStatement("SELECT count(*) as n FROM _v.patches where patch_name = ?")
-
- for (n in 1..9999) {
- val numStr = n.toString().padStart(4, '0')
- val patchName = "libeufin-bank-$numStr"
-
- checkStmt.setString(1, patchName)
- val patchCount = checkStmt.oneOrNull { it.getInt(1) } ?: throw Error("unable to query patches");
- if (patchCount >= 1) {
- logger.info("patch $patchName already applied")
- continue
- }
-
- val path = File("${cfg.sqlDir}/libeufin-bank-$numStr.sql")
- if (!path.exists()) {
- logger.info("path $path doesn't exist anymore, stopping")
- break
- }
- logger.info("applying patch $path")
- val sqlPatchText = path.readText()
- conn.execSQLUpdate(sqlPatchText)
- }
- val sqlProcedures = File("${cfg.sqlDir}/procedures.sql").readText()
- conn.execSQLUpdate(sqlProcedures)
- }
- }
-}
-
-fun resetDatabaseTables(cfg: DatabaseConfig) {
- logger.info("reset DB, sqldir ${cfg.sqlDir}, dbConnStr ${cfg.dbConnStr}")
- pgDataSource(cfg.dbConnStr).pgConnection().use { conn ->
- val count = conn.prepareStatement("SELECT count(*) FROM information_schema.schemata WHERE schema_name='_v'").oneOrNull {
- it.getInt(1)
- } ?: 0
- if (count == 0) {
- logger.info("versioning schema not present, not running drop sql")
- return
- }
-
- val sqlDrop = File("${cfg.sqlDir}/libeufin-bank-drop.sql").readText()
- conn.execSQLUpdate(sqlDrop) // TODO can fail ?
- }
-}
-
-
-private fun <T> PreparedStatement.oneOrNull(lambda: (ResultSet) -> T): T? {
- executeQuery().use {
- if (!it.next()) return null
- return lambda(it)
- }
-}
-
private fun <T> PreparedStatement.all(lambda: (ResultSet) -> T): List<T> {
executeQuery().use {
val ret = mutableListOf<T>()
diff --git a/bank/src/main/kotlin/tech/libeufin/bank/Main.kt b/bank/src/main/kotlin/tech/libeufin/bank/Main.kt
@@ -58,6 +58,8 @@ import org.slf4j.LoggerFactory
import org.slf4j.event.Level
import tech.libeufin.util.CryptoUtil
import tech.libeufin.util.getVersion
+import tech.libeufin.util.initializeDatabaseTables
+import tech.libeufin.util.resetDatabaseTables
import java.time.Duration
import java.time.Instant
import java.time.temporal.ChronoUnit
@@ -345,9 +347,9 @@ class BankDbInit : CliktCommand("Initialize the libeufin-bank database", name =
override fun run() {
val cfg = talerConfig(configFile).loadDbConfig()
if (requestReset) {
- resetDatabaseTables(cfg)
+ resetDatabaseTables(cfg, sqlFilePrefix = "libeufin-bank")
}
- initializeDatabaseTables(cfg)
+ initializeDatabaseTables(cfg, sqlFilePrefix = "libeufin-bank")
}
}
diff --git a/bank/src/test/kotlin/helpers.kt b/bank/src/test/kotlin/helpers.kt
@@ -10,6 +10,7 @@ import tech.libeufin.bank.*
import java.io.ByteArrayOutputStream
import java.util.zip.DeflaterOutputStream
import tech.libeufin.util.CryptoUtil
+import tech.libeufin.util.*
/* ----- Setup ----- */
@@ -81,8 +82,8 @@ fun setup(
){
val config = talerConfig("conf/$conf")
val dbCfg = config.loadDbConfig()
- resetDatabaseTables(dbCfg)
- initializeDatabaseTables(dbCfg)
+ resetDatabaseTables(dbCfg, "libeufin-bank")
+ initializeDatabaseTables(dbCfg, "libeufin-bank")
val ctx = config.loadBankApplicationContext()
Database(dbCfg.dbConnStr, ctx.currency).use {
runBlocking {
diff --git a/util/src/main/kotlin/DB.kt b/util/src/main/kotlin/DB.kt
@@ -28,10 +28,14 @@ import org.jetbrains.exposed.sql.Transaction
import org.jetbrains.exposed.sql.name
import org.jetbrains.exposed.sql.transactions.TransactionManager
import org.jetbrains.exposed.sql.transactions.transaction
+import org.postgresql.ds.PGSimpleDataSource
import org.postgresql.jdbc.PgConnection
import org.slf4j.Logger
import org.slf4j.LoggerFactory
+import java.io.File
import java.net.URI
+import java.sql.PreparedStatement
+import java.sql.ResultSet
fun getCurrentUser(): String = System.getProperty("user.name")
@@ -323,3 +327,98 @@ fun getJdbcConnectionFromPg(pgConn: String): String {
}
return "jdbc:$pgConn"
}
+
+
+data class DatabaseConfig(
+ val dbConnStr: String,
+ val sqlDir: String
+)
+
+fun pgDataSource(dbConfig: String): PGSimpleDataSource {
+ val jdbcConnStr = getJdbcConnectionFromPg(dbConfig)
+ logger.info("connecting to database via JDBC string '$jdbcConnStr'")
+ val pgSource = PGSimpleDataSource()
+ pgSource.setUrl(jdbcConnStr)
+ pgSource.prepareThreshold = 1
+ return pgSource
+}
+
+fun PGSimpleDataSource.pgConnection(): PgConnection {
+ val conn = connection.unwrap(PgConnection::class.java)
+ conn.execSQLUpdate("SET search_path TO libeufin_bank;")
+ return conn
+}
+
+fun <R> PgConnection.transaction(lambda: (PgConnection) -> R): R {
+ try {
+ setAutoCommit(false);
+ val result = lambda(this)
+ commit();
+ setAutoCommit(true);
+ return result
+ } catch(e: Exception){
+ rollback();
+ setAutoCommit(true);
+ throw e;
+ }
+}
+
+fun <T> PreparedStatement.oneOrNull(lambda: (ResultSet) -> T): T? {
+ executeQuery().use {
+ if (!it.next()) return null
+ return lambda(it)
+ }
+}
+
+// sqlFilePrefix is, for example, "libeufin-bank" or "libeufin-nexus" (no trailing dash).
+fun initializeDatabaseTables(cfg: DatabaseConfig, sqlFilePrefix: String) {
+ logger.info("doing DB initialization, sqldir ${cfg.sqlDir}, dbConnStr ${cfg.dbConnStr}")
+ pgDataSource(cfg.dbConnStr).pgConnection().use { conn ->
+ conn.transaction {
+ val sqlVersioning = File("${cfg.sqlDir}/versioning.sql").readText()
+ conn.execSQLUpdate(sqlVersioning)
+
+ val checkStmt = conn.prepareStatement("SELECT count(*) as n FROM _v.patches where patch_name = ?")
+
+ for (n in 1..9999) {
+ val numStr = n.toString().padStart(4, '0')
+ val patchName = "$sqlFilePrefix-$numStr"
+
+ checkStmt.setString(1, patchName)
+ val patchCount = checkStmt.oneOrNull { it.getInt(1) } ?: throw Error("unable to query patches");
+ if (patchCount >= 1) {
+ logger.info("patch $patchName already applied")
+ continue
+ }
+
+ val path = File("${cfg.sqlDir}/$sqlFilePrefix-$numStr.sql")
+ if (!path.exists()) {
+ logger.info("path $path doesn't exist anymore, stopping")
+ break
+ }
+ logger.info("applying patch $path")
+ val sqlPatchText = path.readText()
+ conn.execSQLUpdate(sqlPatchText)
+ }
+ val sqlProcedures = File("${cfg.sqlDir}/procedures.sql").readText()
+ conn.execSQLUpdate(sqlProcedures)
+ }
+ }
+}
+
+// sqlFilePrefix is, for example, "libeufin-bank" or "libeufin-nexus" (no trailing dash).
+fun resetDatabaseTables(cfg: DatabaseConfig, sqlFilePrefix: String) {
+ logger.info("reset DB, sqldir ${cfg.sqlDir}, dbConnStr ${cfg.dbConnStr}")
+ pgDataSource(cfg.dbConnStr).pgConnection().use { conn ->
+ val count = conn.prepareStatement("SELECT count(*) FROM information_schema.schemata WHERE schema_name='_v'").oneOrNull {
+ it.getInt(1)
+ } ?: 0
+ if (count == 0) {
+ logger.info("versioning schema not present, not running drop sql")
+ return
+ }
+
+ val sqlDrop = File("${cfg.sqlDir}/$sqlFilePrefix-drop.sql").readText()
+ conn.execSQLUpdate(sqlDrop) // TODO can fail ?
+ }
+}