statement.kt (6205B)
/*
 * This file is part of LibEuFin.
 * Copyright (C) 2025 Taler Systems S.A.
 *
 * LibEuFin is free software; you can redistribute it and/or modify
 * it under the terms of the GNU Affero General Public License as
 * published by the Free Software Foundation; either version 3, or
 * (at your option) any later version.
 *
 * LibEuFin is distributed in the hope that it will be useful, but
 * WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
 * or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU Affero General
 * Public License for more details.
 *
 * You should have received a copy of the GNU Affero General Public
 * License along with LibEuFin; see the file COPYING.  If not, see
 * <http://www.gnu.org/licenses/>
 */

package tech.libeufin.common.db

import org.slf4j.Logger
import org.slf4j.LoggerFactory
import org.postgresql.util.PSQLState
import tech.libeufin.common.*
import java.sql.*
import java.time.*
import java.util.*

internal val logger: Logger = LoggerFactory.getLogger("libeufin-db")

/**
 * Thin wrapper around a JDBC [PreparedStatement] that binds parameters
 * positionally and resets its state after every execution.
 *
 * Each `bind` overload writes to the current placeholder position and then
 * advances it, so calls must be made in the same order as the `?`
 * placeholders appear in the SQL text. After any execution (successful or
 * not) the parameters are cleared and the position goes back to the first
 * placeholder.
 */
class TalerStatement(internal val stmt: PreparedStatement): java.io.Closeable {
    /** 1-based index of the next placeholder to bind. */
    private var idx = 1

    /** Close the wrapped statement. */
    override fun close() {
        stmt.close()
    }

    /** Log pending SQL warnings, then clear parameters and rewind the bind position. */
    private fun consume() {
        // Log warnings
        var current = stmt.getWarnings()
        while (current != null) {
            logger.warn(current.message)
            current = current.getNextWarning()
        }

        // Reset params
        stmt.clearParameters()
        idx = 1
    }

    /* ----- Bindings helpers ----- */

    /** Bind a nullable string to the next placeholder. */
    fun bind(string: String?) {
        stmt.setString(idx, string)
        idx += 1
    }

    /** Bind a boolean to the next placeholder. */
    fun bind(bool: Boolean) {
        stmt.setBoolean(idx, bool)
        idx += 1
    }

    /** Bind a nullable 64-bit integer to the next placeholder. */
    fun bind(nb: Long?) {
        if (nb != null) {
            stmt.setLong(idx, nb)
        } else {
            // Same SQL type as the one setLong uses for non-null values
            stmt.setNull(idx, Types.BIGINT)
        }
        idx += 1
    }

    /** Bind a 32-bit integer to the next placeholder. */
    fun bind(nb: Int) {
        stmt.setInt(idx, nb)
        idx += 1
    }

    /**
     * Bind a nullable amount by delegating to the [DecimalNumber] overload
     * with the result of `number()`; it therefore uses two placeholders.
     */
    fun bind(amount: TalerAmount?) {
        bind(amount?.number())
    }

    /**
     * Bind a nullable decimal number to the next TWO placeholders:
     * first its `value` (as a long), then its `frac` (as an int).
     *
     * A null number binds SQL NULL to both placeholders, so the bind
     * position always advances by two and later parameters stay aligned
     * with their placeholders.
     */
    fun bind(nb: DecimalNumber?) {
        if (nb != null) {
            stmt.setLong(idx, nb.value)
            stmt.setInt(idx + 1, nb.frac)
        } else {
            stmt.setNull(idx, Types.BIGINT)
            stmt.setNull(idx + 1, Types.INTEGER)
        }
        idx += 2
    }

    /** Bind a timestamp as the long returned by `micros()`. */
    fun bind(timestamp: Instant) {
        stmt.setLong(idx, timestamp.micros())
        idx += 1
    }

    /** Bind the raw bytes of a nullable [Base32Crockford64B]. */
    fun bind(bytes: Base32Crockford64B?) {
        stmt.setBytes(idx, bytes?.raw)
        idx += 1
    }

    /** Bind the raw bytes of a nullable [Base32Crockford32B]. */
    fun bind(bytes: Base32Crockford32B?) {
        stmt.setBytes(idx, bytes?.raw)
        idx += 1
    }

    /** Bind the raw bytes of a nullable [Base32Crockford16B]. */
    fun bind(bytes: Base32Crockford16B?) {
        stmt.setBytes(idx, bytes?.raw)
        idx += 1
    }

    /** Bind a nullable byte array to the next placeholder. */
    fun bind(bytes: ByteArray?) {
        stmt.setBytes(idx, bytes)
        idx += 1
    }

    /** Bind a nullable enum constant by its [Enum.name]. */
    fun <T : kotlin.Enum<T>> bind(enum: T?) {
        bind(enum?.name)
    }

    /** Bind a local date-time through [PreparedStatement.setObject]. */
    fun bind(date: LocalDateTime) {
        stmt.setObject(idx, date)
        idx += 1
    }

    /** Bind a nullable UUID through [PreparedStatement.setObject]. */
    fun bind(uuid: UUID?) {
        stmt.setObject(idx, uuid)
        idx += 1
    }

    /** Bind an array of enum constants as an SQL `text` array. */
    fun <T : Enum<T>> bind(array: Array<T>) {
        bindArray("text", array)
    }

    /** Bind an array of strings as an SQL `text` array. */
    fun bind(array: Array<String>) {
        bindArray("text", array)
    }

    /** Bind an array of UUIDs as an SQL `uuid` array. */
    fun bind(array: Array<UUID>) {
        bindArray("uuid", array)
    }

    /** Create an SQL array of [typeName] from [elements] and bind it to the next placeholder. */
    private fun bindArray(typeName: String, elements: Array<out Any>) {
        val sqlArray = stmt.connection.createArrayOf(typeName, elements)
        stmt.setArray(idx, sqlArray)
        idx += 1
    }

    /* ----- Transaction helpers ----- */

    /** Execute the statement as a query; bindings are reset whether or not it succeeds. */
    fun executeQuery(): ResultSet {
        return try {
            stmt.executeQuery()
        } finally {
            consume()
        }
    }

    /** Execute the statement as an update and return the update count; bindings are reset whether or not it succeeds. */
    fun executeUpdate(): Int {
        return try {
            stmt.executeUpdate()
        } finally {
            consume()
        }
    }

    /** Read one row or null if none */
    fun <T> oneOrNull(lambda: (ResultSet) -> T): T? {
        return executeQuery().use {
            if (it.next()) lambda(it) else null
        }
    }

    /** Read one row or throw if none */
    fun <T> one(lambda: (ResultSet) -> T): T =
        requireNotNull(oneOrNull(lambda)) { "Missing result to database query" }

    /** Read one row or return [err] in case of unique violation error */
    fun <T> oneUniqueViolation(err: T, lambda: (ResultSet) -> T): T {
        return try {
            one(lambda)
        } catch (e: SQLException) {
            if (e.sqlState == PSQLState.UNIQUE_VIOLATION.state) return err
            throw e // rethrowing, not to hide other types of errors.
        }
    }

    /** Read all rows */
    fun <T> all(lambda: (ResultSet) -> T): List<T> {
        return executeQuery().use {
            val ret = mutableListOf<T>()
            while (it.next()) {
                ret.add(lambda(it))
            }
            ret
        }
    }

    /** Execute a query checking it returns at least one row */
    fun executeQueryCheck(): Boolean {
        return executeQuery().use {
            it.next()
        }
    }

    /** Execute an update checking it updates at least one row */
    fun executeUpdateCheck(): Boolean {
        // Use the count returned by the execution itself instead of
        // re-reading it from the statement after it has been reset.
        return executeUpdate() > 0
    }

    /**
     * Execute an update checking if it fails because of a unique violation error.
     *
     * Returns false on unique violation, otherwise whether at least one row
     * was updated. Any other [SQLException] is rethrown.
     */
    fun executeUpdateViolation(): Boolean {
        return try {
            executeUpdateCheck()
        } catch (e: SQLException) {
            logger.debug(e.message)
            if (e.sqlState == PSQLState.UNIQUE_VIOLATION.state) return false
            throw e // rethrowing, not to hide other types of errors.
        }
    }

    /**
     * Execute an update checking if it fails because of a unique violation
     * error and resetting state.
     *
     * The execution is wrapped in a savepoint: it is released on success and
     * rolled back to on any [SQLException]. Returns true on success, false on
     * unique violation; any other [SQLException] is rethrown after the rollback.
     */
    fun executeProcedureViolation(): Boolean {
        val savepoint = stmt.connection.setSavepoint()
        return try {
            executeUpdate()
            stmt.connection.releaseSavepoint(savepoint)
            true
        } catch (e: SQLException) {
            stmt.connection.rollback(savepoint)
            if (e.sqlState == PSQLState.UNIQUE_VIOLATION.state) return false
            throw e // rethrowing, not to hide other types of errors.
        }
    }
}