commit a72483ef8fb3847ac729ed5ef662b2e3fc77d281
parent 8cc2dd0d49cad641eea9c7bc90cfcad79ab36475
Author: Antoine A <>
Date: Mon, 9 Oct 2023 00:03:44 +0000
Reuse database connection logic and fix connection leak
Diffstat:
1 file changed, 53 insertions(+), 62 deletions(-)
diff --git a/bank/src/main/kotlin/tech/libeufin/bank/Database.kt b/bank/src/main/kotlin/tech/libeufin/bank/Database.kt
@@ -21,12 +21,14 @@
package tech.libeufin.bank
import org.postgresql.jdbc.PgConnection
+import org.postgresql.ds.PGSimpleDataSource
import org.slf4j.Logger
import org.slf4j.LoggerFactory
import tech.libeufin.util.getJdbcConnectionFromPg
import tech.libeufin.util.microsToJavaInstant
import tech.libeufin.util.stripIbanPayto
import tech.libeufin.util.toDbMicros
+import tech.libeufin.util.XMLUtil
import java.io.File
import java.sql.*
import java.time.Instant
@@ -62,73 +64,68 @@ private val logger: Logger = LoggerFactory.getLogger("tech.libeufin.bank.Databas
private fun faultyTimestampByBank() = internalServerError("Bank took overflowing timestamp")
private fun faultyDurationByClient() = badRequest("Overflowing duration, please specify 'forever' instead.")
-fun initializeDatabaseTables(dbConfig: String, sqlDir: String) {
- logger.info("doing DB initialization, sqldir $sqlDir, dbConfig $dbConfig")
+private fun pgDataSource(dbConfig: String): PGSimpleDataSource {
val jdbcConnStr = getJdbcConnectionFromPg(dbConfig)
logger.info("connecting to database via JDBC string '$jdbcConnStr'")
- val dbConn = DriverManager.getConnection(jdbcConnStr).unwrap(PgConnection::class.java)
- if (dbConn == null) {
- throw Error("could not open database")
- }
- val sqlVersioning = File("$sqlDir/versioning.sql").readText()
- dbConn.execSQLUpdate(sqlVersioning)
+ val pgSource = PGSimpleDataSource()
+ pgSource.setUrl(jdbcConnStr)
+ pgSource.prepareThreshold = 1
+ return pgSource
+}
- val checkStmt = dbConn.prepareStatement("SELECT count(*) as n FROM _v.patches where patch_name = ?")
+private fun PGSimpleDataSource.pgConnection(): PgConnection {
+ val conn = getConnection().unwrap(PgConnection::class.java)
+ conn.execSQLUpdate("SET search_path TO libeufin_bank;")
+ return conn
+}
- for (n in 1..9999) {
- val numStr = n.toString().padStart(4, '0')
- val patchName = "libeufin-bank-$numStr"
+fun initializeDatabaseTables(dbConfig: String, sqlDir: String) {
+ logger.info("doing DB initialization, sqldir $sqlDir, dbConfig $dbConfig")
+ pgDataSource(dbConfig).pgConnection().use { conn ->
+ val sqlVersioning = File("$sqlDir/versioning.sql").readText()
+ conn.execSQLUpdate(sqlVersioning)
- checkStmt.setString(1, patchName)
- val res = checkStmt.executeQuery()
- if (!res.next()) {
- throw Error("unable to query patches")
- }
+ val checkStmt = conn.prepareStatement("SELECT count(*) as n FROM _v.patches where patch_name = ?")
- val patchCount = res.getInt("n")
- if (patchCount >= 1) {
- logger.info("patch $patchName already applied")
- continue
- }
+ for (n in 1..9999) {
+ val numStr = n.toString().padStart(4, '0')
+ val patchName = "libeufin-bank-$numStr"
- val path = File("$sqlDir/libeufin-bank-$numStr.sql")
- if (!path.exists()) {
- logger.info("path $path doesn't exist anymore, stopping")
- break
- }
- logger.info("applying patch $path")
- val sqlPatchText = path.readText()
- dbConn.execSQLUpdate(sqlPatchText)
- }
- val sqlProcedures = File("$sqlDir/procedures.sql").readText()
- dbConn.execSQLUpdate(sqlProcedures)
-}
+ checkStmt.setString(1, patchName)
+ val patchCount = checkStmt.oneOrNull { it.getInt(1) } ?: throw Error("unable to query patches");
+ if (patchCount >= 1) {
+ logger.info("patch $patchName already applied")
+ continue
+ }
-private fun countRows(rs: ResultSet): Int {
- var size = 0
- while (rs.next()) {
- size++
+ val path = File("$sqlDir/libeufin-bank-$numStr.sql")
+ if (!path.exists()) {
+ logger.info("path $path doesn't exist anymore, stopping")
+ break
+ }
+ logger.info("applying patch $path")
+ val sqlPatchText = path.readText()
+ conn.execSQLUpdate(sqlPatchText)
+ }
+ val sqlProcedures = File("$sqlDir/procedures.sql").readText()
+ conn.execSQLUpdate(sqlProcedures)
}
- return size
}
fun resetDatabaseTables(dbConfig: String, sqlDir: String) {
logger.info("doing DB initialization, sqldir $sqlDir, dbConfig $dbConfig")
- val jdbcConnStr = getJdbcConnectionFromPg(dbConfig)
- logger.info("connecting to database via JDBC string '$jdbcConnStr'")
- val dbConn = DriverManager.getConnection(jdbcConnStr).unwrap(PgConnection::class.java)
- if (dbConn == null) {
- throw Error("could not open database")
- }
+ pgDataSource(dbConfig).pgConnection().use { conn ->
+ val count = conn.prepareStatement("SELECT count(*) FROM information_schema.schemata WHERE schema_name='_v'").oneOrNull {
+ it.getInt(1)
+ } ?: 0
+ if (count == 0) {
+ logger.info("versioning schema not present, not running drop sql")
+ return
+ }
- val queryRes = dbConn.execSQLQuery("SELECT schema_name FROM information_schema.schemata WHERE schema_name='_v'")
- if (countRows(queryRes) == 0) {
- logger.info("versioning schema not present, not running drop sql")
- return
+ val sqlDrop = File("$sqlDir/libeufin-bank-drop.sql").readText()
+ conn.execSQLUpdate(sqlDrop)
}
-
- val sqlDrop = File("$sqlDir/libeufin-bank-drop.sql").readText()
- dbConn.execSQLUpdate(sqlDrop)
}
@@ -176,12 +173,13 @@ private fun PreparedStatement.executeUpdateViolation(): Boolean {
}
class Database(dbConfig: String, private val bankCurrency: String): java.io.Closeable {
+ private val pgSource: PGSimpleDataSource
private val dbPool: HikariDataSource
- private val jdbcConnStr = getJdbcConnectionFromPg(dbConfig)
init {
+ pgSource = pgDataSource(dbConfig)
val config = HikariConfig();
- config.jdbcUrl = jdbcConnStr
+ config.dataSource = pgSource
config.connectionInitSql = "SET search_path TO libeufin_bank;"
config.validate()
dbPool = HikariDataSource(config);
@@ -196,13 +194,6 @@ class Database(dbConfig: String, private val bankCurrency: String): java.io.Clos
return conn.use(lambda)
}
- /** Create new connection outside the pool */
- private fun freshConn(): PgConnection {
- val conn = DriverManager.getConnection(jdbcConnStr).unwrap(PgConnection::class.java)
- conn?.execSQLUpdate("SET search_path TO libeufin_bank;")
- return conn
- }
-
// CUSTOMERS
/**
* This method INSERTs a new customer into the database and
@@ -819,7 +810,7 @@ class Database(dbConfig: String, private val bankCurrency: String): java.io.Clos
// Only start expensive listening and connection creation if we intend to poll
if (poll_ms > 0) {
- pg = freshConn()
+ pg = pgSource.pgConnection()
conn = pg
pg.execSQLUpdate("LISTEN $channel");
} else {