taler-rust

GNU Taler code in Rust. Largely core banking integrations.
Log | Files | Refs | Submodules | README | LICENSE

db.rs (8136B)


      1 /*
      2   This file is part of TALER
      3   Copyright (C) 2024, 2025, 2026 Taler Systems SA
      4 
      5   TALER is free software; you can redistribute it and/or modify it under the
      6   terms of the GNU Affero General Public License as published by the Free Software
      7   Foundation; either version 3, or (at your option) any later version.
      8 
      9   TALER is distributed in the hope that it will be useful, but WITHOUT ANY
     10   WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
     11   A PARTICULAR PURPOSE.  See the GNU Affero General Public License for more details.
     12 
     13   You should have received a copy of the GNU Affero General Public License along with
     14   TALER; see the file COPYING.  If not, see <http://www.gnu.org/licenses/>
     15 */
     16 
     17 use std::{str::FromStr, time::Duration};
     18 
     19 use jiff::{
     20     Timestamp,
     21     civil::{Date, Time},
     22     tz::TimeZone,
     23 };
     24 use sqlx::{
     25     Decode, Error, PgExecutor, PgPool, QueryBuilder, Type, error::BoxDynError, postgres::PgRow,
     26     query::Query,
     27 };
     28 use sqlx::{Postgres, Row};
     29 use taler_common::{
     30     api_common::SafeU64,
     31     api_params::{History, Page},
     32     types::{
     33         amount::{Amount, Currency, Decimal},
     34         iban::IBAN,
     35         payto::PaytoURI,
     36         utils::date_to_utc_timestamp,
     37     },
     38 };
     39 use tokio::sync::watch::Receiver;
     40 use url::Url;
     41 
/// [`QueryBuilder`] specialised for the PostgreSQL backend.
pub type PgQueryBuilder<'b> = QueryBuilder<'b, Postgres>;
     43 
/// Rust mirror of the PostgreSQL enum type `incoming_type` (see the `sqlx(type_name)` attribute).
///
/// NOTE(review): the meaning of each variant (reserve / KYC / wad transfers) is inferred from the
/// names only — confirm against the schema and the Taler wire gateway specification.
#[derive(Debug, Clone, Copy, PartialEq, Eq, sqlx::Type)]
// Variant names are kept lowercase on purpose: without `rename_all`, the sqlx derive encodes
// each variant by its exact name, which must match the SQL enum labels.
#[allow(non_camel_case_types)]
#[sqlx(type_name = "incoming_type")]
pub enum IncomingType {
    reserve,
    kyc,
    wad,
}
     52 
     53 /* ----- Routines ------ */
     54 
     55 pub async fn page<'a, 'b, R: Send + Unpin>(
     56     pool: impl PgExecutor<'b>,
     57     id_col: &str,
     58     params: &Page,
     59     prepare: impl Fn() -> QueryBuilder<'a, Postgres>,
     60     map: impl Fn(PgRow) -> Result<R, Error> + Send,
     61 ) -> Result<Vec<R>, Error> {
     62     let mut builder = prepare();
     63     if let Some(offset) = params.offset {
     64         builder
     65             .push(format_args!(
     66                 " {id_col} {}",
     67                 if params.backward() { '<' } else { '>' }
     68             ))
     69             .push_bind(offset);
     70     } else {
     71         builder.push("TRUE");
     72     }
     73     builder.push(format_args!(
     74         " ORDER BY {id_col} {} LIMIT ",
     75         if params.backward() { "DESC" } else { "ASC" }
     76     ));
     77     builder
     78         .push_bind(params.limit.abs())
     79         .build()
     80         .try_map(map)
     81         .fetch_all(pool)
     82         .await
     83 }
     84 
     85 pub async fn history<'a, 'b, R: Send + Unpin>(
     86     pool: &PgPool,
     87     id_col: &str,
     88     params: &History,
     89     listen: impl FnOnce() -> Receiver<i64>,
     90     prepare: impl Fn() -> QueryBuilder<'a, Postgres> + Copy,
     91     map: impl Fn(PgRow) -> Result<R, Error> + Send + Copy,
     92 ) -> Result<Vec<R>, Error> {
     93     let load = || async { page(pool, id_col, &params.page, prepare, map).await };
     94 
     95     // When going backward there is always at least one transaction or none
     96     if params.page.limit >= 0 && params.timeout_ms.is_some_and(|it| it > 0) {
     97         let mut listener = listen();
     98         let init = load().await?;
     99         // Long polling if we found no transactions
    100         if init.is_empty() {
    101             let pooling = tokio::time::timeout(
    102                 Duration::from_millis(params.timeout_ms.unwrap_or(0)),
    103                 async {
    104                     listener
    105                         .wait_for(|id| params.page.offset.is_none_or(|offset| *id > offset))
    106                         .await
    107                         .ok();
    108                 },
    109             )
    110             .await;
    111             match pooling {
    112                 Ok(_) => load().await,
    113                 Err(_) => Ok(init),
    114             }
    115         } else {
    116             Ok(init)
    117         }
    118     } else {
    119         load().await
    120     }
    121 }
    122 
    123 /* ----- Bind ----- */
    124 
/// Extension methods for binding `jiff` date/time values as query parameters.
pub trait BindHelper {
    /// Bind `timestamp` as a query parameter (the Postgres `Query` implementation in this
    /// module binds `timestamp.as_microsecond()`).
    fn bind_timestamp(self, timestamp: &Timestamp) -> Self;
    /// Bind `date` as the timestamp returned by `date_to_utc_timestamp(date)`.
    fn bind_date(self, date: &Date) -> Self;
}
    129 
    130 impl<'q> BindHelper for Query<'q, Postgres, <Postgres as sqlx::Database>::Arguments<'q>> {
    131     fn bind_timestamp(self, timestamp: &Timestamp) -> Self {
    132         self.bind(timestamp.as_microsecond())
    133     }
    134 
    135     fn bind_date(self, date: &Date) -> Self {
    136         self.bind_timestamp(&date_to_utc_timestamp(date))
    137     }
    138 }
    139 
    140 /* ----- Get ----- */
    141 
    142 pub trait TypeHelper {
    143     fn try_get_map<
    144         'r,
    145         I: sqlx::ColumnIndex<Self>,
    146         T: Decode<'r, Postgres> + Type<Postgres>,
    147         E: Into<BoxDynError>,
    148         R,
    149         M: FnOnce(T) -> Result<R, E>,
    150     >(
    151         &'r self,
    152         index: I,
    153         map: M,
    154     ) -> sqlx::Result<R>;
    155     fn try_get_parse<I: sqlx::ColumnIndex<Self>, E: Into<BoxDynError>, T: FromStr<Err = E>>(
    156         &self,
    157         index: I,
    158     ) -> sqlx::Result<T>;
    159     fn try_get_opt_parse<I: sqlx::ColumnIndex<Self>, E: Into<BoxDynError>, T: FromStr<Err = E>>(
    160         &self,
    161         index: I,
    162     ) -> sqlx::Result<Option<T>>;
    163     fn try_get_timestamp<I: sqlx::ColumnIndex<Self>>(&self, index: I) -> sqlx::Result<Timestamp> {
    164         self.try_get_map(index, |micros| {
    165             jiff::Timestamp::from_microsecond(micros)
    166                 .map_err(|e| format!("expected timestamp micros got overflowing {micros}: {e}"))
    167         })
    168     }
    169     fn try_get_date<I: sqlx::ColumnIndex<Self>>(&self, index: I) -> sqlx::Result<Date> {
    170         let timestamp = self.try_get_timestamp(index)?;
    171         let zoned = timestamp.to_zoned(TimeZone::UTC);
    172         assert_eq!(zoned.time(), Time::midnight());
    173         Ok(zoned.date())
    174     }
    175     fn try_get_u32<I: sqlx::ColumnIndex<Self>>(&self, index: I) -> sqlx::Result<u32> {
    176         self.try_get_map(index, |signed: i32| signed.try_into())
    177     }
    178     fn try_get_u64<I: sqlx::ColumnIndex<Self>>(&self, index: I) -> sqlx::Result<u64> {
    179         self.try_get_map(index, |signed: i64| signed.try_into())
    180     }
    181     fn try_get_safeu64<I: sqlx::ColumnIndex<Self>>(&self, index: I) -> sqlx::Result<SafeU64> {
    182         self.try_get_map(index, |signed: i64| SafeU64::try_from(signed))
    183     }
    184     fn try_get_url<I: sqlx::ColumnIndex<Self>>(&self, index: I) -> sqlx::Result<Url> {
    185         self.try_get_parse(index)
    186     }
    187     fn try_get_payto<I: sqlx::ColumnIndex<Self>>(&self, index: I) -> sqlx::Result<PaytoURI> {
    188         self.try_get_parse(index)
    189     }
    190     fn try_get_opt_payto<I: sqlx::ColumnIndex<Self>>(
    191         &self,
    192         index: I,
    193     ) -> sqlx::Result<Option<PaytoURI>> {
    194         self.try_get_opt_parse(index)
    195     }
    196     fn try_get_iban<I: sqlx::ColumnIndex<Self>>(&self, index: I) -> sqlx::Result<IBAN> {
    197         self.try_get_parse(index)
    198     }
    199     fn try_get_amount<I: sqlx::ColumnIndex<Self>>(
    200         &self,
    201         index: I,
    202         currency: &Currency,
    203     ) -> sqlx::Result<Amount>;
    204 
    205     /** Flag consider NULL and false to be the same */
    206     fn try_get_flag<I: sqlx::ColumnIndex<Self>>(&self, index: I) -> sqlx::Result<bool>;
    207 }
    208 
    209 impl TypeHelper for PgRow {
    210     fn try_get_map<
    211         'r,
    212         I: sqlx::ColumnIndex<Self>,
    213         T: Decode<'r, Postgres> + Type<Postgres>,
    214         E: Into<BoxDynError>,
    215         R,
    216         M: FnOnce(T) -> Result<R, E>,
    217     >(
    218         &'r self,
    219         index: I,
    220         map: M,
    221     ) -> sqlx::Result<R> {
    222         let primitive: T = self.try_get(&index)?;
    223         map(primitive).map_err(|source| sqlx::Error::ColumnDecode {
    224             index: format!("{index:?}"),
    225             source: source.into(),
    226         })
    227     }
    228 
    229     fn try_get_parse<I: sqlx::ColumnIndex<Self>, E: Into<BoxDynError>, T: FromStr<Err = E>>(
    230         &self,
    231         index: I,
    232     ) -> sqlx::Result<T> {
    233         self.try_get_map(index, |s: &str| s.parse())
    234     }
    235 
    236     fn try_get_opt_parse<I: sqlx::ColumnIndex<Self>, E: Into<BoxDynError>, T: FromStr<Err = E>>(
    237         &self,
    238         index: I,
    239     ) -> sqlx::Result<Option<T>> {
    240         self.try_get_map(index, |s: Option<&str>| s.map(|s| s.parse()).transpose())
    241     }
    242 
    243     fn try_get_amount<I: sqlx::ColumnIndex<Self>>(
    244         &self,
    245         index: I,
    246         currency: &Currency,
    247     ) -> sqlx::Result<Amount> {
    248         let decimal: Decimal = self.try_get(index)?;
    249         Ok(Amount::new_decimal(currency, decimal))
    250     }
    251 
    252     fn try_get_flag<I: sqlx::ColumnIndex<Self>>(&self, index: I) -> sqlx::Result<bool> {
    253         let opt_bool: Option<bool> = self.try_get(index)?;
    254         Ok(opt_bool.unwrap_or(false))
    255     }
    256 }