diff options
Diffstat (limited to 'common/src/main/kotlin/db/config.kt')
-rw-r--r-- | common/src/main/kotlin/db/config.kt | 64 |
1 files changed, 13 insertions, 51 deletions
diff --git a/common/src/main/kotlin/db/config.kt b/common/src/main/kotlin/db/config.kt index e7f46263..6f225f05 100644 --- a/common/src/main/kotlin/db/config.kt +++ b/common/src/main/kotlin/db/config.kt @@ -25,6 +25,7 @@ import org.postgresql.jdbc.PgConnection import org.postgresql.util.PSQLState import org.slf4j.Logger import org.slf4j.LoggerFactory +import io.ktor.http.parseQueryString import java.net.URI import java.nio.file.Path import java.sql.PreparedStatement @@ -32,7 +33,7 @@ import java.sql.ResultSet import java.sql.SQLException import kotlin.io.path.Path -fun getCurrentUser(): String = System.getProperty("user.name") +fun currentUser(): String = System.getProperty("user.name") /** * This function converts postgresql:// URIs to JDBC URIs. @@ -43,7 +44,7 @@ fun getCurrentUser(): String = System.getProperty("user.name") * They are especially complex when using unix domain sockets, as they're not really * supported natively by JDBC. */ -fun getJdbcConnectionFromPg(pgConn: String): String { +fun jdbcFromPg(pgConn: String): String { // Pass through jdbc URIs. if (pgConn.startsWith("jdbc:")) { return pgConn @@ -52,55 +53,16 @@ fun getJdbcConnectionFromPg(pgConn: String): String { throw Exception("Not a Postgres connection string: $pgConn") } var maybeUnixSocket = false - val parsed = URI(pgConn) - var hostAsParam: String? = if (parsed.query != null) { - getQueryParam(parsed.query, "host") - } else { - null - } - var pgHost = System.getenv("PGHOST") - if (null == pgHost) - pgHost = parsed.host - var pgPort = System.getenv("PGPORT") - if (null == pgPort) { - if (-1 == parsed.port) - pgPort = "5432" - else - pgPort = parsed.port.toString() - } + val uri = URI(pgConn) + val params = parseQueryString(uri.query ?: "", decode = false) - /** - * In some cases, it is possible to leave the hostname empty - * and specify it via a query param, therefore a "postgresql:///"-starting - * connection string does NOT always mean Unix domain socket. - * https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-CONNSTRING - */ - if (pgHost == null && - (hostAsParam == null || hostAsParam.startsWith('/')) - ) { - maybeUnixSocket = true - } - if (pgHost != null && - (pgHost.startsWith('/')) - ) { - maybeUnixSocket = true - } - if (maybeUnixSocket) { - // Check whether the database user should differ from the process user. - var pgUser = getCurrentUser() - if (parsed.query != null) { - val maybeUserParam = getQueryParam(parsed.query, "user") - if (maybeUserParam != null) pgUser = maybeUserParam - } - // Check whether the Unix domain socket location was given non-standard. - if ( (null == hostAsParam) && (null != pgHost) ) - hostAsParam = pgHost + "/.s.PGSQL." + pgPort - val socketLocation = hostAsParam ?: "/var/run/postgresql/.s.PGSQL." + pgPort - if (!socketLocation.startsWith('/')) { - throw Exception("PG connection wants Unix domain socket, but non-null host doesn't start with slash") - } - return "jdbc:postgresql://localhost${parsed.path}?user=$pgUser&socketFactory=org.newsclub.net.unix." + - "AFUNIXSocketFactory\$FactoryArg&socketFactoryArg=$socketLocation" + val host = uri.host ?: params["host"] ?: System.getenv("PGHOST") + if (host == null || host.startsWith('/')) { + val port = (if (uri.port == -1) null else uri.port.toString()) ?: params["port"] ?: System.getenv("PGPORT") ?: "5432" + val user = params["user"] ?: currentUser() + val unixPath = (host ?:"/var/run/postgresql") + "/.s.PGSQL.$port" + return "jdbc:postgresql://localhost${uri.path}?user=$user&socketFactory=org.newsclub.net.unix." + + "AFUNIXSocketFactory\$FactoryArg&socketFactoryArg=$unixPath" } if (pgConn.startsWith("postgres://")) { // The JDBC driver doesn't like postgres://, only postgresql://. @@ -118,7 +80,7 @@ data class DatabaseConfig( ) fun pgDataSource(dbConfig: String): PGSimpleDataSource { - val jdbcConnStr = getJdbcConnectionFromPg(dbConfig) + val jdbcConnStr = jdbcFromPg(dbConfig) logger.debug("connecting to database via JDBC string '$jdbcConnStr'") val pgSource = PGSimpleDataSource() pgSource.setUrl(jdbcConnStr) |