taler-rust

GNU Taler code in Rust. Largely core banking integrations.
Log | Files | Refs | Submodules | README | LICENSE

db.rs (8420B)


      1 /*
      2   This file is part of TALER
      3   Copyright (C) 2024-2025 Taler Systems SA
      4 
      5   TALER is free software; you can redistribute it and/or modify it under the
      6   terms of the GNU Affero General Public License as published by the Free Software
      7   Foundation; either version 3, or (at your option) any later version.
      8 
      9   TALER is distributed in the hope that it will be useful, but WITHOUT ANY
     10   WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
     11   A PARTICULAR PURPOSE.  See the GNU Affero General Public License for more details.
     12 
     13   You should have received a copy of the GNU Affero General Public License along with
     14   TALER; see the file COPYING.  If not, see <http://www.gnu.org/licenses/>
     15 */
     16 
     17 use std::{str::FromStr, time::Duration};
     18 
     19 use jiff::{
     20     Timestamp,
     21     civil::{Date, Time},
     22     tz::TimeZone,
     23 };
     24 use sqlx::{
     25     Decode, Error, PgExecutor, PgPool, QueryBuilder, Type, error::BoxDynError, postgres::PgRow,
     26     query::Query,
     27 };
     28 use sqlx::{Postgres, Row};
     29 use taler_common::{
     30     api_common::SafeU64,
     31     api_params::{History, Page},
     32     types::{
     33         amount::{Amount, Currency, Decimal},
     34         base32::Base32,
     35         iban::IBAN,
     36         payto::PaytoURI,
     37         utils::date_to_utc_timestamp,
     38     },
     39 };
     40 use tokio::sync::watch::Receiver;
     41 use url::Url;
     42 
/// Shorthand for a sqlx [`QueryBuilder`] specialised to the [`Postgres`] backend.
pub type PgQueryBuilder<'b> = QueryBuilder<'b, Postgres>;
     44 
/// Rust counterpart of the Postgres type named `incoming_type`.
///
/// Variant names are lowercase (hence the `non_camel_case_types` allowance),
/// presumably so they match the SQL enum labels verbatim — TODO confirm
/// against the database schema.
#[derive(Debug, Clone, Copy, PartialEq, Eq, sqlx::Type)]
#[allow(non_camel_case_types)]
#[sqlx(type_name = "incoming_type")]
pub enum IncomingType {
    // NOTE(review): presumably an incoming transfer funding a reserve — confirm.
    reserve,
    // NOTE(review): presumably an incoming transfer used for KYC — confirm.
    kyc,
    // NOTE(review): presumably an incoming wad transfer — confirm.
    wad,
}
     53 
     54 /* ----- Routines ------ */
     55 
     56 pub async fn page<'a, 'b, R: Send + Unpin>(
     57     pool: impl PgExecutor<'b>,
     58     id_col: &str,
     59     params: &Page,
     60     prepare: impl Fn() -> QueryBuilder<'a, Postgres>,
     61     map: impl Fn(PgRow) -> Result<R, Error> + Send,
     62 ) -> Result<Vec<R>, Error> {
     63     let mut builder = prepare();
     64     if let Some(offset) = params.offset {
     65         builder
     66             .push(format_args!(
     67                 " {id_col} {}",
     68                 if params.backward() { '<' } else { '>' }
     69             ))
     70             .push_bind(offset);
     71     } else {
     72         builder.push("TRUE");
     73     }
     74     builder.push(format_args!(
     75         " ORDER BY {id_col} {} LIMIT ",
     76         if params.backward() { "DESC" } else { "ASC" }
     77     ));
     78     builder
     79         .push_bind(params.limit.abs())
     80         .build()
     81         .try_map(map)
     82         .fetch_all(pool)
     83         .await
     84 }
     85 
     86 pub async fn history<'a, 'b, R: Send + Unpin>(
     87     pool: &PgPool,
     88     id_col: &str,
     89     params: &History,
     90     listen: impl FnOnce() -> Receiver<i64>,
     91     prepare: impl Fn() -> QueryBuilder<'a, Postgres> + Copy,
     92     map: impl Fn(PgRow) -> Result<R, Error> + Send + Copy,
     93 ) -> Result<Vec<R>, Error> {
     94     let load = || async { page(pool, id_col, &params.page, prepare, map).await };
     95 
     96     // When going backward there is always at least one transaction or none
     97     if params.page.limit >= 0 && params.timeout_ms.is_some_and(|it| it > 0) {
     98         let mut listener = listen();
     99         let init = load().await?;
    100         // Long polling if we found no transactions
    101         if init.is_empty() {
    102             let pooling = tokio::time::timeout(
    103                 Duration::from_millis(params.timeout_ms.unwrap_or(0)),
    104                 async {
    105                     listener
    106                         .wait_for(|id| params.page.offset.is_none_or(|offset| *id > offset))
    107                         .await
    108                         .ok();
    109                 },
    110             )
    111             .await;
    112             match pooling {
    113                 Ok(_) => load().await,
    114                 Err(_) => Ok(init),
    115             }
    116         } else {
    117             Ok(init)
    118         }
    119     } else {
    120         load().await
    121     }
    122 }
    123 
    124 /* ----- Bind ----- */
    125 
/// Builder-style helpers to bind domain values as query parameters.
///
/// Every method consumes `self` and returns it, so calls can be chained like
/// the regular `bind` method.
pub trait BindHelper {
    /// Bind the decimal part of `amount`, as [`BindHelper::bind_decimal`] does.
    ///
    /// In the `Query` implementation in this file only `amount.decimal()` is
    /// bound; the currency is not.
    fn bind_amount(self, amount: &Amount) -> Self;
    /// Bind `decimal` as two consecutive parameters.
    ///
    /// In the `Query` implementation in this file these are `val` cast to
    /// `i64` followed by `frac` cast to `i32`.
    fn bind_decimal(self, decimal: &Decimal) -> Self;
    /// Bind `timestamp` as a single parameter.
    ///
    /// In the `Query` implementation in this file this is the value of
    /// `timestamp.as_microsecond()`.
    fn bind_timestamp(self, timestamp: &Timestamp) -> Self;
    /// Bind `date` as a single timestamp parameter.
    ///
    /// In the `Query` implementation in this file the date is first converted
    /// with `date_to_utc_timestamp`.
    fn bind_date(self, date: &Date) -> Self;
}
    132 
    133 impl<'q> BindHelper for Query<'q, Postgres, <Postgres as sqlx::Database>::Arguments<'q>> {
    134     fn bind_amount(self, amount: &Amount) -> Self {
    135         self.bind_decimal(&amount.decimal())
    136     }
    137 
    138     fn bind_decimal(self, decimal: &Decimal) -> Self {
    139         self.bind(decimal.val as i64).bind(decimal.frac as i32)
    140     }
    141 
    142     fn bind_timestamp(self, timestamp: &Timestamp) -> Self {
    143         self.bind(timestamp.as_microsecond())
    144     }
    145 
    146     fn bind_date(self, date: &Date) -> Self {
    147         self.bind_timestamp(&date_to_utc_timestamp(date))
    148     }
    149 }
    150 
    151 /* ----- Get ----- */
    152 
    153 pub trait TypeHelper {
    154     fn try_get_map<
    155         'r,
    156         I: sqlx::ColumnIndex<Self>,
    157         T: Decode<'r, Postgres> + Type<Postgres>,
    158         E: Into<BoxDynError>,
    159         R,
    160         M: FnOnce(T) -> Result<R, E>,
    161     >(
    162         &'r self,
    163         index: I,
    164         map: M,
    165     ) -> sqlx::Result<R>;
    166     fn try_get_parse<I: sqlx::ColumnIndex<Self>, E: Into<BoxDynError>, T: FromStr<Err = E>>(
    167         &self,
    168         index: I,
    169     ) -> sqlx::Result<T>;
    170     fn try_get_timestamp<I: sqlx::ColumnIndex<Self>>(&self, index: I) -> sqlx::Result<Timestamp> {
    171         self.try_get_map(index, |micros| {
    172             jiff::Timestamp::from_microsecond(micros)
    173                 .map_err(|e| format!("expected timestamp micros got overflowing {micros}: {e}"))
    174         })
    175     }
    176     fn try_get_date<I: sqlx::ColumnIndex<Self>>(&self, index: I) -> sqlx::Result<Date> {
    177         let timestamp = self.try_get_timestamp(index)?;
    178         let zoned = timestamp.to_zoned(TimeZone::UTC);
    179         assert_eq!(zoned.time(), Time::midnight());
    180         Ok(zoned.date())
    181     }
    182     fn try_get_u32<I: sqlx::ColumnIndex<Self>>(&self, index: I) -> sqlx::Result<u32> {
    183         self.try_get_map(index, |signed: i32| signed.try_into())
    184     }
    185     fn try_get_u64<I: sqlx::ColumnIndex<Self>>(&self, index: I) -> sqlx::Result<u64> {
    186         self.try_get_map(index, |signed: i64| signed.try_into())
    187     }
    188     fn try_get_safeu64<I: sqlx::ColumnIndex<Self>>(&self, index: I) -> sqlx::Result<SafeU64> {
    189         self.try_get_map(index, |signed: i64| SafeU64::try_from(signed))
    190     }
    191     fn try_get_base32<I: sqlx::ColumnIndex<Self>, const L: usize>(
    192         &self,
    193         index: I,
    194     ) -> sqlx::Result<Base32<L>> {
    195         self.try_get_map(index, |slice: &[u8]| Base32::try_from(slice))
    196     }
    197     fn try_get_url<I: sqlx::ColumnIndex<Self>>(&self, index: I) -> sqlx::Result<Url> {
    198         self.try_get_parse(index)
    199     }
    200     fn try_get_payto<I: sqlx::ColumnIndex<Self>>(&self, index: I) -> sqlx::Result<PaytoURI> {
    201         self.try_get_parse(index)
    202     }
    203     fn try_get_iban<I: sqlx::ColumnIndex<Self>>(&self, index: I) -> sqlx::Result<IBAN> {
    204         self.try_get_parse(index)
    205     }
    206     fn try_get_decimal<I: sqlx::ColumnIndex<Self>>(
    207         &self,
    208         val: I,
    209         frac: I,
    210     ) -> sqlx::Result<Decimal> {
    211         let val = self.try_get_u64(val)?;
    212         let frac = self.try_get_u32(frac)?;
    213         Ok(Decimal::new(val, frac))
    214     }
    215     fn try_get_amount(&self, index: &str, currency: &Currency) -> sqlx::Result<Amount>;
    216     fn try_get_amount_i(&self, index: usize, currency: &Currency) -> sqlx::Result<Amount>;
    217 }
    218 
    219 impl TypeHelper for PgRow {
    220     fn try_get_map<
    221         'r,
    222         I: sqlx::ColumnIndex<Self>,
    223         T: Decode<'r, Postgres> + Type<Postgres>,
    224         E: Into<BoxDynError>,
    225         R,
    226         M: FnOnce(T) -> Result<R, E>,
    227     >(
    228         &'r self,
    229         index: I,
    230         map: M,
    231     ) -> sqlx::Result<R> {
    232         let primitive: T = self.try_get(&index)?;
    233         map(primitive).map_err(|source| sqlx::Error::ColumnDecode {
    234             index: format!("{index:?}"),
    235             source: source.into(),
    236         })
    237     }
    238 
    239     fn try_get_parse<I: sqlx::ColumnIndex<Self>, E: Into<BoxDynError>, T: FromStr<Err = E>>(
    240         &self,
    241         index: I,
    242     ) -> sqlx::Result<T> {
    243         self.try_get_map(index, |s: &str| s.parse())
    244     }
    245 
    246     fn try_get_amount(&self, index: &str, currency: &Currency) -> sqlx::Result<Amount> {
    247         let val_idx = format!("{index}_val");
    248         let frac_idx = format!("{index}_frac");
    249         let val_idx = val_idx.as_str();
    250         let frac_idx = frac_idx.as_str();
    251 
    252         Ok(Amount::new_decimal(
    253             currency,
    254             self.try_get_decimal(val_idx, frac_idx)?,
    255         ))
    256     }
    257 
    258     fn try_get_amount_i(&self, index: usize, currency: &Currency) -> sqlx::Result<Amount> {
    259         Ok(Amount::new_decimal(
    260             currency,
    261             self.try_get_decimal(index, index + 1)?,
    262         ))
    263     }
    264 }