taler-rust

GNU Taler code in Rust. Largely core banking integrations.
Log | Files | Refs | Submodules | README | LICENSE

commit 0ff83d96b82649f06ee6c3a48090c4f9c646d088
parent d8e5f800f6adcb43822c779d71f8187003841217
Author: Antoine A <>
Date:   Fri,  6 Dec 2024 16:09:41 +0100

utils: add decimal & better db utils

Diffstat:
MCargo.lock | 4++--
Mtaler-api/src/db.rs | 62++++++++++++++++++++++++++++++++------------------------------
Mtaler-api/src/db/query_helper.rs | 68++++++++++++++++++++++++++++++++++++++++++--------------------------
Mtaler-common/src/amount.rs | 161++++++++++++++++++++++++++++++++++++++++++++++++++++++-------------------------
4 files changed, 186 insertions(+), 109 deletions(-)

diff --git a/Cargo.lock b/Cargo.lock @@ -2293,9 +2293,9 @@ dependencies = [ [[package]] name = "tokio-stream" -version = "0.1.16" +version = "0.1.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4f4e6ce100d0eb49a2734f8c0812bcd324cf357d21810932c5df6b96ef2b86f1" +checksum = "eca58d7bba4a75707817a2c44174253f9236b2d5fbd055602e9d5c07c139a047" dependencies = [ "futures-core", "pin-project-lite", diff --git a/taler-api/src/db.rs b/taler-api/src/db.rs @@ -14,7 +14,7 @@ TALER; see the file COPYING. If not, see <http://www.gnu.org/licenses/> */ -use query_helper::{history, SqlQueryHelper}; +use query_helper::{history, page, SqlQueryHelper}; use sqlx::{ postgres::{PgListener, PgRow}, QueryBuilder, Row, @@ -92,14 +92,13 @@ pub async fn transfer(db: &PgPool, transfer: TransferRequest) -> ApiResult<Trans FROM taler_transfer(($1, $2)::taler_amount, $3, $4, $5, $6, $7, $8) ", ) - .bind(transfer.amount.value as i64) - .bind(transfer.amount.fraction as i32) + .bind_amount(&transfer.amount) .bind(transfer.exchange_base_url.as_str()) .bind(format!("{} {}", transfer.wtid, transfer.exchange_base_url)) .bind(transfer.credit_account.as_str()) .bind(transfer.request_uid.as_slice()) .bind(transfer.wtid.as_slice()) - .bind(Timestamp::now().as_sql_micros()) + .bind_timestamp(&Timestamp::now()) .try_map(|r: PgRow| { Ok(if r.try_get("out_request_uid_reuse")? { TransferResult::RequestUidReuse @@ -117,28 +116,32 @@ pub async fn transfer(db: &PgPool, transfer: TransferRequest) -> ApiResult<Trans pub async fn transfer_page( db: &PgPool, status: &Option<TransferState>, - page: &Page, + params: &Page, currency: &str, ) -> ApiResult<Vec<TransferListStatus>> { - let mut builder = QueryBuilder::new( - " - SELECT - transfer_id, - status, - (amount).val as amount_val, - (amount).frac as amount_frac, - credit_payto, - transfer_time - FROM transfers WHERE - ", - ); - if let Some(status) = status { - builder.push(" status = ").push_bind(status).push(" AND "); - } - Ok(builder - .page("transfer_id", page) - .build() - .try_map(|r: PgRow| { + Ok(page( + db, + "transfer_id", + params, + || { + let mut builder = QueryBuilder::new( + " + SELECT + transfer_id, + status, + (amount).val as amount_val, + (amount).frac as amount_frac, + credit_payto, + transfer_time + FROM transfers WHERE + ", + ); + if let Some(status) = status { + builder.push(" status = ").push_bind(status).push(" AND "); + } + builder + }, + |r: PgRow| { Ok(TransferListStatus { row_id: r.try_get_safeu64("transfer_id")?, status: r.try_get("status")?, @@ -146,9 +149,9 @@ pub async fn transfer_page( credit_account: r.try_get_url("credit_payto")?, timestamp: r.try_get_timestamp("transfer_time")?, }) - }) - .fetch_all(db) - .await?) + }, + ) + .await?) } pub async fn transfer_by_id( @@ -248,10 +251,9 @@ pub async fn add_incoming( ) .bind(key.as_slice()) .bind(subject) - .bind(amount.value as i64) - .bind(amount.fraction as i32) + .bind_amount(amount) .bind(debit_account.as_str()) - .bind(timestamp.as_sql_micros()) + .bind_timestamp(timestamp) .bind(kind) .try_map(|r: PgRow| { Ok(if r.try_get("out_reserve_pub_reuse")? { diff --git a/taler-api/src/db/query_helper.rs b/taler-api/src/db/query_helper.rs @@ -16,33 +16,56 @@ use std::time::Duration; -use sqlx::{postgres::PgRow, Error, PgPool, Postgres, QueryBuilder}; -use taler_common::api_params::{History, Page}; +use sqlx::{postgres::PgRow, query::Query, Error, PgPool, Postgres, QueryBuilder}; +use taler_common::{ + amount::{Amount, Decimal}, api_common::Timestamp, api_params::{History, Page} +}; use tokio::sync::watch::Receiver; -use super::PgQueryBuilder; - pub trait SqlQueryHelper { - fn page(&mut self, id_col: &str, params: &Page) -> &mut Self; + fn bind_amount(self, amount: &Amount) -> Self; + fn bind_decimal(self, decimal: &Decimal) -> Self; + fn bind_timestamp(self, timestamp: &Timestamp) -> Self; +} + +impl<'q> SqlQueryHelper for Query<'q, Postgres, <Postgres as sqlx::Database>::Arguments<'q>> { + fn bind_amount(self, amount: &Amount) -> Self { + self.bind_decimal(&amount.decimal) + } + + fn bind_decimal(self, decimal: &Decimal) -> Self { + self.bind(decimal.value as i64) + .bind(decimal.fraction as i32) + } + + fn bind_timestamp(self, timestamp: &Timestamp) -> Self { + self.bind(timestamp.as_sql_micros()) + } } -impl SqlQueryHelper for PgQueryBuilder<'_> { - fn page(&mut self, id_col: &str, params: &Page) -> &mut Self { - if let Some(offset) = params.offset { - self.push(format_args!( +pub async fn page<'a, R: Send + Unpin>( + pool: &PgPool, + id_col: &str, + params: &Page, + prepare: impl Fn() -> QueryBuilder<'a, Postgres>, + map: impl Fn(PgRow) -> Result<R, Error> + Send, +) -> Result<Vec<R>, Error> { + let mut builder = prepare(); + if let Some(offset) = params.offset { + builder + .push(format_args!( " {id_col} {}", if params.backward() { '<' } else { '>' } )) .push_bind(offset); - } else { - self.push("TRUE"); - } - self.push(format_args!( - " ORDER BY {id_col} {} LIMIT ", - if params.backward() { "DESC" } else { "ASC" } - )) - .push_bind(params.limit.abs()) + } else { + builder.push("TRUE"); } + builder.push(format_args!( + " ORDER BY {id_col} {} LIMIT ", + if params.backward() { "DESC" } else { "ASC" } + )); + builder.push_bind(params.limit.abs()).build().try_map(map).fetch_all(pool).await } pub async fn history<'a, R: Send + Unpin>( @@ -50,17 +73,10 @@ pub async fn history<'a, R: Send + Unpin>( id_col: &str, params: &History, listen: impl FnOnce() -> Receiver<i64>, - prepare: impl Fn() -> QueryBuilder<'a, Postgres>, + prepare: impl Fn() -> QueryBuilder<'a, Postgres> + Copy, map: impl Fn(PgRow) -> Result<R, Error> + Send + Copy, ) -> Result<Vec<R>, Error> { - let load = || async { - prepare() - .page(id_col, &params.page) - .build() - .try_map(map) - .fetch_all(pool) - .await - }; + let load = || async { page(pool, id_col, &params.page, prepare, map).await }; // When going backward there is always at least one transaction or none if params.page.limit >= 0 && params.timeout_ms.is_some_and(|it| it > 0) { diff --git a/taler-common/src/amount.rs b/taler-common/src/amount.rs @@ -77,6 +77,89 @@ impl Display for Currency { } } +#[derive( + Debug, Clone, PartialEq, Eq, serde_with::DeserializeFromStr, serde_with::SerializeDisplay, +)] +pub struct Decimal { + pub value: u64, + pub fraction: u32, +} + +impl Decimal { + pub const fn max() -> Self { + Self { + value: MAX_VALUE, + fraction: FRACTION_BASE - 1, + } + } + + pub const fn zero() -> Self { + Self { + value: 0, + fraction: 0, + } + } + + fn normalize(mut self) -> Option<Self> { + self.value = self + .value + .checked_add((self.fraction / FRACTION_BASE) as u64)?; + self.fraction %= FRACTION_BASE; + if self.value > MAX_VALUE { + return None; + } + Some(self) + } + + pub fn add(mut self, rhs: &Self) -> Option<Self> { + self.value = self.value.checked_add(rhs.value)?; + self.fraction = self + .fraction + .checked_add(rhs.fraction) + .expect("amount fraction overflow should never happen with normalized amounts"); + self.normalize() + } +} + +#[derive(Debug, thiserror::Error)] +pub enum ParseDecimalError { + #[error("invalid amount format")] + Format, + #[error("amount overflow")] + Overflow, + #[error(transparent)] + Number(#[from] ParseIntError), +} + +impl FromStr for Decimal { + type Err = ParseDecimalError; + + fn from_str(s: &str) -> Result<Self, Self::Err> { + let (value, fraction) = s.split_once('.').unwrap_or((s, "")); + + let value: u64 = value.parse()?; + if value > MAX_VALUE { + return Err(ParseDecimalError::Format); + } + + if fraction.len() > 8 { + return Err(ParseDecimalError::Format); + } + let fraction: u32 = if fraction.is_empty() { + 0 + } else { + fraction.parse::<u32>()? * 10_u32.pow((8 - fraction.len()) as u32) + }; + Ok(Self { value, fraction }) + } +} + +impl Display for Decimal { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_fmt(format_args!("{}.{:08}", self.value, self.fraction)) + } +} + #[derive(Debug, thiserror::Error)] pub enum ParseCurrencyError { #[error("invalid currency")] @@ -93,50 +176,48 @@ pub enum ParseCurrencyError { )] pub struct Amount { pub currency: Currency, - pub value: u64, - pub fraction: u32, + pub decimal: Decimal, } impl Amount { - pub fn new(currency: impl AsRef<str>, value: u64, fraction: u32) -> Self { + pub fn new_decimal(currency: impl AsRef<str>, decimal: Decimal) -> Self { let currency = currency.as_ref().parse().expect("Invalid currency"); - Self { - value, - fraction, - currency, - } + (currency, decimal).into() + } + + pub fn new(currency: impl AsRef<str>, value: u64, fraction: u32) -> Self { + Self::new_decimal(currency, Decimal { value, fraction }) } pub fn max(currency: impl AsRef<str>) -> Self { - Self::new(currency, MAX_VALUE, FRACTION_BASE - 1) + Self::new_decimal(currency, Decimal::max()) } pub fn zero(currency: impl AsRef<str>) -> Self { - Self::new(currency, 0, 0) + Self::new_decimal(currency, Decimal::zero()) } fn normalize(mut self) -> Option<Self> { - self.value = self - .value - .checked_add((self.fraction / FRACTION_BASE) as u64)?; - self.fraction %= FRACTION_BASE; - if self.value > MAX_VALUE { - return None; - } + self.decimal = self.decimal.normalize()?; Some(self) } pub fn add(mut self, rhs: &Self) -> Option<Self> { assert_eq!(self.currency, rhs.currency); - self.value = self.value.checked_add(rhs.value)?; - self.fraction = self - .fraction - .checked_add(rhs.fraction) - .expect("amount fraction overflow should never happen with normalized amounts"); + self.decimal = self.decimal.add(&rhs.decimal)?; self.normalize() } } +impl From<(Currency, Decimal)> for Amount { + fn from((currency, amount): (Currency, Decimal)) -> Self { + Self { + currency, + decimal: amount, + } + } +} + #[track_caller] pub fn amount(amount: impl AsRef<str>) -> Amount { amount.as_ref().parse().expect("Invalid amount constant") @@ -146,10 +227,10 @@ pub fn amount(amount: impl AsRef<str>) -> Amount { pub enum ParseAmountError { #[error("Invalid amount format")] FormatAmount, - #[error("Amount overflow")] - AmountOverflow, #[error(transparent)] - Format(#[from] ParseIntError), + Currency(#[from] ParseCurrencyError), + #[error(transparent)] + Decimal(#[from] ParseDecimalError), } impl FromStr for Amount { @@ -160,38 +241,16 @@ impl FromStr for Amount { .trim() .split_once(':') .ok_or(ParseAmountError::FormatAmount)?; - if currency.len() > CURRENCY_LEN { - return Err(ParseAmountError::FormatAmount); - } - let (value, fraction) = amount.split_once('.').unwrap_or((amount, "")); - - let value: u64 = value.parse().map_err(|_| ParseAmountError::FormatAmount)?; - if value > MAX_VALUE { - return Err(ParseAmountError::FormatAmount); - } - - if fraction.len() > 8 { - return Err(ParseAmountError::FormatAmount); - } - let fraction: u32 = if fraction.is_empty() { - 0 - } else { - fraction - .parse::<u32>() - .map_err(|_| ParseAmountError::FormatAmount)? - * 10_u32.pow((8 - fraction.len()) as u32) - }; + let currency = currency.parse()?; + let decimal = amount.parse()?; - Ok(Self::new(currency, value, fraction)) + Ok((currency, decimal).into()) } } impl Display for Amount { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.write_fmt(format_args!( - "{}:{}.{:08}", - self.currency, self.value, self.fraction - )) + f.write_fmt(format_args!("{}:{}", self.currency, self.decimal)) } }