libeufin

Integration and sandbox testing for FinTech APIs and data formats
Log | Files | Refs | Submodules | README | LICENSE

server.kt (13196B)


      1 /*
      2  * This file is part of LibEuFin.
      3  * Copyright (C) 2024-2025 Taler Systems S.A.
      4 
      5  * LibEuFin is free software; you can redistribute it and/or modify
      6  * it under the terms of the GNU Affero General Public License as
      7  * published by the Free Software Foundation; either version 3, or
      8  * (at your option) any later version.
      9 
     10  * LibEuFin is distributed in the hope that it will be useful, but
     11  * WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
     12  * or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU Affero General
     13  * Public License for more details.
     14 
     15  * You should have received a copy of the GNU Affero General Public
     16  * License along with LibEuFin; see the file COPYING.  If not, see
     17  * <http://www.gnu.org/licenses/>
     18  */
     19 
     20 package tech.libeufin.common.api
     21 
     22 import io.ktor.http.*
     23 import io.ktor.serialization.kotlinx.json.*
     24 import io.ktor.server.application.*
     25 import io.ktor.server.engine.*
     26 import io.ktor.server.cio.*
     27 import io.ktor.server.plugins.*
     28 import io.ktor.server.plugins.calllogging.*
     29 import io.ktor.server.plugins.contentnegotiation.*
     30 import io.ktor.server.plugins.forwardedheaders.*
     31 import io.ktor.server.plugins.statuspages.*
     32 import io.ktor.server.plugins.callid.*
     33 import io.ktor.server.request.*
     34 import io.ktor.server.response.*
     35 import io.ktor.server.routing.*
     36 import io.ktor.utils.io.*
     37 import io.ktor.util.*
     38 import io.ktor.util.pipeline.*
     39 import io.ktor.http.content.*
     40 import kotlinx.serialization.ExperimentalSerializationApi
     41 import kotlinx.serialization.json.Json
     42 import org.postgresql.util.PSQLState
     43 import org.slf4j.Logger
     44 import org.slf4j.event.Level
     45 import tech.libeufin.common.*
     46 import tech.libeufin.common.db.SERIALIZATION_ERROR
     47 import java.net.InetAddress
     48 import java.sql.SQLException
     49 import java.util.zip.DataFormatException
     50 import java.util.zip.Inflater
     51 
     52 /** Used to store the raw body */
     53 private val RAW_BODY = AttributeKey<ByteArray>("RAW_BODY")
     54 
     55 /** Used to set custom body limit */
     56 val BODY_LIMIT = AttributeKey<Int>("BODY_LIMIT")
     57 
     58 /** Get call raw body */
     59 val ApplicationCall.rawBody: ByteArray get() = attributes.getOrNull(RAW_BODY) ?: ByteArray(0)
     60 
     61 /**
     62  * This plugin apply Taler specific logic
     63  * It checks for body length limit and inflates the requests that have "Content-Encoding: deflate"
     64  * It logs incoming requests and their details
     65  */
     66 fun talerPlugin(logger: Logger): ApplicationPlugin<Unit> {
     67     return createApplicationPlugin("TalerPlugin") {
     68         onCall { call ->
     69             // Handle CORS
     70             call.response.header(HttpHeaders.AccessControlAllowOrigin, "*")
     71             // Handle CORS preflight
     72             if (call.request.httpMethod == HttpMethod.Options) {
     73                 call.response.header(HttpHeaders.AccessControlAllowHeaders, "*")
     74                 call.response.header(HttpHeaders.AccessControlAllowMethods, "*")
     75                 call.respond(HttpStatusCode.NoContent)
     76                 return@onCall
     77             }
     78 
     79             // Log incoming transaction
     80             val requestCall = buildString {
     81                 val path = call.request.path()
     82                 append(call.request.httpMethod.value)
     83                 append(' ')
     84                 append(call.request.path())
     85                 val query = call.request.queryString()
     86                 if (query.isNotEmpty()) {
     87                     append('?')
     88                     append(query)
     89                 }
     90             }
     91             logger.info(requestCall)
     92         }
     93         onCallReceive { call ->
     94             val bodyLimit = call.attributes.getOrNull(BODY_LIMIT) ?: MAX_BODY_LENGTH
     95             // Check content length if present and wellformed
     96             val contentLenght = call.request.headers[HttpHeaders.ContentLength]?.toIntOrNull()
     97             if (contentLenght != null && contentLenght > bodyLimit)
     98                 throw bodyOverflow("Body is suspiciously big > ${bodyLimit}B")
     99 
    100             // Else check while reading and decompressing the body
    101             transformBody { body ->
    102                 val bytes = ByteArray(bodyLimit + 1)
    103                 var read = 0
    104                 when (val encoding = call.request.headers[HttpHeaders.ContentEncoding])  {
    105                     "deflate" -> {
    106                         // Decompress and check decompressed length
    107                         val inflater = Inflater()
    108                         while (!body.isClosedForRead) {
    109                             body.read { buf ->
    110                                 inflater.setInput(buf)
    111                                 try {
    112                                     read += inflater.inflate(bytes, read, bytes.size - read)
    113                                 } catch (e: DataFormatException) {
    114                                     logger.error("Deflated request failed to inflate: ${e.message}")
    115                                     throw badRequest(
    116                                         "Could not inflate request",
    117                                         TalerErrorCode.GENERIC_COMPRESSION_INVALID
    118                                     )
    119                                 }
    120                             }
    121                             if (read > bodyLimit)
    122                                 throw bodyOverflow("Decompressed body is suspiciously big > ${bodyLimit}B")
    123                         }
    124                     }
    125                     null -> {
    126                         // Check body length
    127                         while (true) {
    128                             val new = body.readAvailable(bytes, read, bytes.size - read)
    129                             if (new == -1) break // Channel is closed
    130                             read += new
    131                             if (read > bodyLimit)
    132                                 throw bodyOverflow("Body is suspiciously big > ${bodyLimit}B")
    133                         }
    134                     } 
    135                     else -> throw unsupportedMediaType(
    136                         "Content encoding '$encoding' not supported, expected plain or deflate",
    137                         TalerErrorCode.GENERIC_COMPRESSION_INVALID
    138                     )
    139                 }
    140                 logger.trace {
    141                     "request ${bytes.sliceArray(0 until read).asUtf8()}"
    142                 }
    143                 call.attributes.put(RAW_BODY, bytes)
    144                 ByteReadChannel(bytes, 0, read)
    145             }
    146         }
    147     }
    148 }
    149 
    150 /** Set up web server handlers for a Taler API */
    151 fun Application.talerApi(logger: Logger, routes: Routing.() -> Unit) {
    152     install(CallId) {
    153         generate(10, "ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789")
    154         verify { true }
    155     }
    156     install(CallLogging) {
    157         callIdMdc("call-id")
    158         level = Level.INFO
    159         this.logger = logger
    160         format { call ->
    161             val status = call.response.status()
    162             val msg = call.logMsg()
    163             if (msg != null) {
    164                 "${status?.value} ${call.processingTimeMillis()}ms: $msg"
    165             } else {
    166                 "${status?.value} ${call.processingTimeMillis()}ms"
    167             }
    168         }
    169     }
    170     install(XForwardedHeaders)
    171     install(talerPlugin(logger))
    172     install(IgnoreTrailingSlash)
    173     install(ContentNegotiation) {
    174         json(Json {
    175             @OptIn(ExperimentalSerializationApi::class)
    176             explicitNulls = false
    177             encodeDefaults = true
    178             ignoreUnknownKeys = true
    179         })
    180     }
    181     install(StatusPages) {
    182         status(HttpStatusCode.NotFound) { call, status ->
    183             call.err(
    184                 status,
    185                 "There is no endpoint defined for the URL provided by the client. Check if you used the correct URL and/or file a report with the developers of the client software.",
    186                 TalerErrorCode.GENERIC_ENDPOINT_UNKNOWN,
    187                 null
    188             )
    189         }
    190         status(HttpStatusCode.MethodNotAllowed) { call, status ->
    191             call.err(
    192                 status,
    193                 "The HTTP method used is invalid for this endpoint. This is likely a bug in the client implementation. Check if you are using the latest available version and/or file a report with the developers.",
    194                 TalerErrorCode.GENERIC_METHOD_INVALID,
    195                 null
    196             )
    197         }
    198         exception<Exception> { call, cause ->
    199             logger.debug("", cause)
    200             when (cause) {
    201                 is ApiException -> call.err(cause, null)
    202                 is SQLException -> {
    203                     if (SERIALIZATION_ERROR.contains(cause.sqlState)) {
    204                         call.err(
    205                             HttpStatusCode.InternalServerError,
    206                             "Transaction serialization failure",
    207                             TalerErrorCode.BANK_SOFT_EXCEPTION,
    208                             cause
    209                         )
    210                     } else {
    211                         call.err(
    212                             HttpStatusCode.InternalServerError,
    213                             "Unexpected sql error with state ${cause.sqlState}",
    214                             TalerErrorCode.BANK_UNMANAGED_EXCEPTION,
    215                             cause
    216                         )
    217                     }
    218                 }
    219                 is BadRequestException -> {
    220                     /**
    221                      * NOTE: extracting the root cause helps with JSON error messages,
    222                      * because they mention the particular way they are invalid, but OTOH
    223                      * it loses (by getting null) other error messages, like for example
    224                      * the one from MissingRequestParameterException.  Therefore, in order
    225                      * to get the most detailed message, we must consider BOTH sides:
    226                      * the 'cause' AND its root cause!
    227                      */
    228                     var rootCause: Throwable? = cause.cause
    229                     while (rootCause?.cause != null)
    230                         rootCause = rootCause.cause
    231                     // Telling apart invalid JSON vs missing parameter vs invalid parameter.
    232                     val errorCode = when {
    233                         cause is MissingRequestParameterException ->
    234                             TalerErrorCode.GENERIC_PARAMETER_MISSING
    235                         cause is ParameterConversionException ->
    236                             TalerErrorCode.GENERIC_PARAMETER_MALFORMED
    237                         rootCause is CommonError -> when (rootCause) {
    238                             is CommonError.AmountFormat -> TalerErrorCode.BANK_BAD_FORMAT_AMOUNT
    239                             is CommonError.AmountNumberTooBig -> TalerErrorCode.BANK_NUMBER_TOO_BIG
    240                             is CommonError.Payto -> TalerErrorCode.GENERIC_JSON_INVALID
    241                         }
    242                         else -> TalerErrorCode.GENERIC_JSON_INVALID
    243                     }
    244                     call.err(
    245                         HttpStatusCode.BadRequest,
    246                         rootCause?.message,
    247                         errorCode,
    248                         null
    249                     )
    250                 }
    251                 is CommonError -> {
    252                     val errorCode = when (cause) {
    253                         is CommonError.AmountFormat -> TalerErrorCode.BANK_BAD_FORMAT_AMOUNT
    254                         is CommonError.AmountNumberTooBig -> TalerErrorCode.BANK_NUMBER_TOO_BIG
    255                         is CommonError.Payto -> TalerErrorCode.GENERIC_JSON_INVALID
    256                     }
    257                     call.err(
    258                         HttpStatusCode.BadRequest,
    259                         cause.message,
    260                         errorCode,
    261                         null
    262                     )
    263                 }
    264                 else -> {
    265                     call.err(
    266                         HttpStatusCode.InternalServerError,
    267                         cause.message,
    268                         TalerErrorCode.BANK_UNMANAGED_EXCEPTION,
    269                         cause
    270                     )
    271                 }
    272             }
    273         }
    274     }
    275     val phase = PipelinePhase("phase")
    276     sendPipeline.insertPhaseBefore(ApplicationSendPipeline.Engine, phase)
    277     sendPipeline.intercept(phase) { response ->
    278         if (logger.isTraceEnabled) {
    279             if (response is OutgoingContent.ByteArrayContent) {
    280                 logger.trace("response ${String(response.bytes())}")
    281             }
    282         }
    283         
    284     }
    285     routing { routes() }
    286 }
    287 
    288 // Dirty local variable to stop the server in test TODO remove this ugly hack
    289 var engine: ApplicationEngine? = null
    290 
    291 fun serve(cfg: tech.libeufin.common.ServerConfig, logger: Logger, api: Application.() -> Unit) {
    292     val server = embeddedServer(CIO,
    293         configure = {
    294             when (cfg) {
    295                 is ServerConfig.Tcp -> {
    296                     for (addr in InetAddress.getAllByName(cfg.addr)) {
    297                         logger.info("Listening on ${addr.hostAddress}:${cfg.port}")
    298                         connector {
    299                             port = cfg.port
    300                             host = addr.hostAddress
    301                         }
    302                     }
    303                 }
    304                 is ServerConfig.Unix -> {
    305                     logger.info("Listening on ${cfg.path}")
    306                     unixConnector(cfg.path)
    307                 }
    308             }
    309         },
    310         module = api
    311     )
    312     engine = server.engine
    313     server.start(wait = true)
    314 }