commit 0ff83d96b82649f06ee6c3a48090c4f9c646d088
parent d8e5f800f6adcb43822c779d71f8187003841217
Author: Antoine A <>
Date: Fri, 6 Dec 2024 16:09:41 +0100
utils: add decimal & better db utils
Diffstat:
4 files changed, 186 insertions(+), 109 deletions(-)
diff --git a/Cargo.lock b/Cargo.lock
@@ -2293,9 +2293,9 @@ dependencies = [
[[package]]
name = "tokio-stream"
-version = "0.1.16"
+version = "0.1.17"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "4f4e6ce100d0eb49a2734f8c0812bcd324cf357d21810932c5df6b96ef2b86f1"
+checksum = "eca58d7bba4a75707817a2c44174253f9236b2d5fbd055602e9d5c07c139a047"
dependencies = [
"futures-core",
"pin-project-lite",
diff --git a/taler-api/src/db.rs b/taler-api/src/db.rs
@@ -14,7 +14,7 @@
TALER; see the file COPYING. If not, see <http://www.gnu.org/licenses/>
*/
-use query_helper::{history, SqlQueryHelper};
+use query_helper::{history, page, SqlQueryHelper};
use sqlx::{
postgres::{PgListener, PgRow},
QueryBuilder, Row,
@@ -92,14 +92,13 @@ pub async fn transfer(db: &PgPool, transfer: TransferRequest) -> ApiResult<Trans
FROM taler_transfer(($1, $2)::taler_amount, $3, $4, $5, $6, $7, $8)
",
)
- .bind(transfer.amount.value as i64)
- .bind(transfer.amount.fraction as i32)
+ .bind_amount(&transfer.amount)
.bind(transfer.exchange_base_url.as_str())
.bind(format!("{} {}", transfer.wtid, transfer.exchange_base_url))
.bind(transfer.credit_account.as_str())
.bind(transfer.request_uid.as_slice())
.bind(transfer.wtid.as_slice())
- .bind(Timestamp::now().as_sql_micros())
+ .bind_timestamp(&Timestamp::now())
.try_map(|r: PgRow| {
Ok(if r.try_get("out_request_uid_reuse")? {
TransferResult::RequestUidReuse
@@ -117,28 +116,32 @@ pub async fn transfer(db: &PgPool, transfer: TransferRequest) -> ApiResult<Trans
pub async fn transfer_page(
db: &PgPool,
status: &Option<TransferState>,
- page: &Page,
+ params: &Page,
currency: &str,
) -> ApiResult<Vec<TransferListStatus>> {
- let mut builder = QueryBuilder::new(
- "
- SELECT
- transfer_id,
- status,
- (amount).val as amount_val,
- (amount).frac as amount_frac,
- credit_payto,
- transfer_time
- FROM transfers WHERE
- ",
- );
- if let Some(status) = status {
- builder.push(" status = ").push_bind(status).push(" AND ");
- }
- Ok(builder
- .page("transfer_id", page)
- .build()
- .try_map(|r: PgRow| {
+ Ok(page(
+ db,
+ "transfer_id",
+ params,
+ || {
+ let mut builder = QueryBuilder::new(
+ "
+ SELECT
+ transfer_id,
+ status,
+ (amount).val as amount_val,
+ (amount).frac as amount_frac,
+ credit_payto,
+ transfer_time
+ FROM transfers WHERE
+ ",
+ );
+ if let Some(status) = status {
+ builder.push(" status = ").push_bind(status).push(" AND ");
+ }
+ builder
+ },
+ |r: PgRow| {
Ok(TransferListStatus {
row_id: r.try_get_safeu64("transfer_id")?,
status: r.try_get("status")?,
@@ -146,9 +149,9 @@ pub async fn transfer_page(
credit_account: r.try_get_url("credit_payto")?,
timestamp: r.try_get_timestamp("transfer_time")?,
})
- })
- .fetch_all(db)
- .await?)
+ },
+ )
+ .await?)
}
pub async fn transfer_by_id(
@@ -248,10 +251,9 @@ pub async fn add_incoming(
)
.bind(key.as_slice())
.bind(subject)
- .bind(amount.value as i64)
- .bind(amount.fraction as i32)
+ .bind_amount(amount)
.bind(debit_account.as_str())
- .bind(timestamp.as_sql_micros())
+ .bind_timestamp(timestamp)
.bind(kind)
.try_map(|r: PgRow| {
Ok(if r.try_get("out_reserve_pub_reuse")? {
diff --git a/taler-api/src/db/query_helper.rs b/taler-api/src/db/query_helper.rs
@@ -16,33 +16,56 @@
use std::time::Duration;
-use sqlx::{postgres::PgRow, Error, PgPool, Postgres, QueryBuilder};
-use taler_common::api_params::{History, Page};
+use sqlx::{postgres::PgRow, query::Query, Error, PgPool, Postgres, QueryBuilder};
+use taler_common::{
+ amount::{Amount, Decimal}, api_common::Timestamp, api_params::{History, Page}
+};
use tokio::sync::watch::Receiver;
-use super::PgQueryBuilder;
-
pub trait SqlQueryHelper {
- fn page(&mut self, id_col: &str, params: &Page) -> &mut Self;
+ fn bind_amount(self, amount: &Amount) -> Self;
+ fn bind_decimal(self, decimal: &Decimal) -> Self;
+ fn bind_timestamp(self, timestamp: &Timestamp) -> Self;
+}
+
+impl<'q> SqlQueryHelper for Query<'q, Postgres, <Postgres as sqlx::Database>::Arguments<'q>> {
+ fn bind_amount(self, amount: &Amount) -> Self {
+ self.bind_decimal(&amount.decimal)
+ }
+
+ fn bind_decimal(self, decimal: &Decimal) -> Self {
+ self.bind(decimal.value as i64)
+ .bind(decimal.fraction as i32)
+ }
+
+ fn bind_timestamp(self, timestamp: &Timestamp) -> Self {
+ self.bind(timestamp.as_sql_micros())
+ }
}
-impl SqlQueryHelper for PgQueryBuilder<'_> {
- fn page(&mut self, id_col: &str, params: &Page) -> &mut Self {
- if let Some(offset) = params.offset {
- self.push(format_args!(
+pub async fn page<'a, R: Send + Unpin>(
+ pool: &PgPool,
+ id_col: &str,
+ params: &Page,
+ prepare: impl Fn() -> QueryBuilder<'a, Postgres>,
+ map: impl Fn(PgRow) -> Result<R, Error> + Send,
+) -> Result<Vec<R>, Error> {
+ let mut builder = prepare();
+ if let Some(offset) = params.offset {
+ builder
+ .push(format_args!(
" {id_col} {}",
if params.backward() { '<' } else { '>' }
))
.push_bind(offset);
- } else {
- self.push("TRUE");
- }
- self.push(format_args!(
- " ORDER BY {id_col} {} LIMIT ",
- if params.backward() { "DESC" } else { "ASC" }
- ))
- .push_bind(params.limit.abs())
+ } else {
+ builder.push("TRUE");
}
+ builder.push(format_args!(
+ " ORDER BY {id_col} {} LIMIT ",
+ if params.backward() { "DESC" } else { "ASC" }
+ ));
+ builder.push_bind(params.limit.abs()).build().try_map(map).fetch_all(pool).await
}
pub async fn history<'a, R: Send + Unpin>(
@@ -50,17 +73,10 @@ pub async fn history<'a, R: Send + Unpin>(
id_col: &str,
params: &History,
listen: impl FnOnce() -> Receiver<i64>,
- prepare: impl Fn() -> QueryBuilder<'a, Postgres>,
+ prepare: impl Fn() -> QueryBuilder<'a, Postgres> + Copy,
map: impl Fn(PgRow) -> Result<R, Error> + Send + Copy,
) -> Result<Vec<R>, Error> {
- let load = || async {
- prepare()
- .page(id_col, ¶ms.page)
- .build()
- .try_map(map)
- .fetch_all(pool)
- .await
- };
+ let load = || async { page(pool, id_col, ¶ms.page, prepare, map).await };
// When going backward there is always at least one transaction or none
if params.page.limit >= 0 && params.timeout_ms.is_some_and(|it| it > 0) {
diff --git a/taler-common/src/amount.rs b/taler-common/src/amount.rs
@@ -77,6 +77,89 @@ impl Display for Currency {
}
}
+#[derive(
+ Debug, Clone, PartialEq, Eq, serde_with::DeserializeFromStr, serde_with::SerializeDisplay,
+)]
+pub struct Decimal {
+ pub value: u64,
+ pub fraction: u32,
+}
+
+impl Decimal {
+ pub const fn max() -> Self {
+ Self {
+ value: MAX_VALUE,
+ fraction: FRACTION_BASE - 1,
+ }
+ }
+
+ pub const fn zero() -> Self {
+ Self {
+ value: 0,
+ fraction: 0,
+ }
+ }
+
+ fn normalize(mut self) -> Option<Self> {
+ self.value = self
+ .value
+ .checked_add((self.fraction / FRACTION_BASE) as u64)?;
+ self.fraction %= FRACTION_BASE;
+ if self.value > MAX_VALUE {
+ return None;
+ }
+ Some(self)
+ }
+
+ pub fn add(mut self, rhs: &Self) -> Option<Self> {
+ self.value = self.value.checked_add(rhs.value)?;
+ self.fraction = self
+ .fraction
+ .checked_add(rhs.fraction)
+ .expect("amount fraction overflow should never happen with normalized amounts");
+ self.normalize()
+ }
+}
+
+#[derive(Debug, thiserror::Error)]
+pub enum ParseDecimalError {
+ #[error("invalid amount format")]
+ Format,
+ #[error("amount overflow")]
+ Overflow,
+ #[error(transparent)]
+ Number(#[from] ParseIntError),
+}
+
+impl FromStr for Decimal {
+ type Err = ParseDecimalError;
+
+ fn from_str(s: &str) -> Result<Self, Self::Err> {
+ let (value, fraction) = s.split_once('.').unwrap_or((s, ""));
+
+ let value: u64 = value.parse()?;
+ if value > MAX_VALUE {
+ return Err(ParseDecimalError::Format);
+ }
+
+ if fraction.len() > 8 {
+ return Err(ParseDecimalError::Format);
+ }
+ let fraction: u32 = if fraction.is_empty() {
+ 0
+ } else {
+ fraction.parse::<u32>()? * 10_u32.pow((8 - fraction.len()) as u32)
+ };
+ Ok(Self { value, fraction })
+ }
+}
+
+impl Display for Decimal {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ f.write_fmt(format_args!("{}.{:08}", self.value, self.fraction))
+ }
+}
+
#[derive(Debug, thiserror::Error)]
pub enum ParseCurrencyError {
#[error("invalid currency")]
@@ -93,50 +176,48 @@ pub enum ParseCurrencyError {
)]
pub struct Amount {
pub currency: Currency,
- pub value: u64,
- pub fraction: u32,
+ pub decimal: Decimal,
}
impl Amount {
- pub fn new(currency: impl AsRef<str>, value: u64, fraction: u32) -> Self {
+ pub fn new_decimal(currency: impl AsRef<str>, decimal: Decimal) -> Self {
let currency = currency.as_ref().parse().expect("Invalid currency");
- Self {
- value,
- fraction,
- currency,
- }
+ (currency, decimal).into()
+ }
+
+ pub fn new(currency: impl AsRef<str>, value: u64, fraction: u32) -> Self {
+ Self::new_decimal(currency, Decimal { value, fraction })
}
pub fn max(currency: impl AsRef<str>) -> Self {
- Self::new(currency, MAX_VALUE, FRACTION_BASE - 1)
+ Self::new_decimal(currency, Decimal::max())
}
pub fn zero(currency: impl AsRef<str>) -> Self {
- Self::new(currency, 0, 0)
+ Self::new_decimal(currency, Decimal::zero())
}
fn normalize(mut self) -> Option<Self> {
- self.value = self
- .value
- .checked_add((self.fraction / FRACTION_BASE) as u64)?;
- self.fraction %= FRACTION_BASE;
- if self.value > MAX_VALUE {
- return None;
- }
+ self.decimal = self.decimal.normalize()?;
Some(self)
}
pub fn add(mut self, rhs: &Self) -> Option<Self> {
assert_eq!(self.currency, rhs.currency);
- self.value = self.value.checked_add(rhs.value)?;
- self.fraction = self
- .fraction
- .checked_add(rhs.fraction)
- .expect("amount fraction overflow should never happen with normalized amounts");
+ self.decimal = self.decimal.add(&rhs.decimal)?;
self.normalize()
}
}
+impl From<(Currency, Decimal)> for Amount {
+ fn from((currency, amount): (Currency, Decimal)) -> Self {
+ Self {
+ currency,
+ decimal: amount,
+ }
+ }
+}
+
#[track_caller]
pub fn amount(amount: impl AsRef<str>) -> Amount {
amount.as_ref().parse().expect("Invalid amount constant")
@@ -146,10 +227,10 @@ pub fn amount(amount: impl AsRef<str>) -> Amount {
pub enum ParseAmountError {
#[error("Invalid amount format")]
FormatAmount,
- #[error("Amount overflow")]
- AmountOverflow,
#[error(transparent)]
- Format(#[from] ParseIntError),
+ Currency(#[from] ParseCurrencyError),
+ #[error(transparent)]
+ Decimal(#[from] ParseDecimalError),
}
impl FromStr for Amount {
@@ -160,38 +241,16 @@ impl FromStr for Amount {
.trim()
.split_once(':')
.ok_or(ParseAmountError::FormatAmount)?;
- if currency.len() > CURRENCY_LEN {
- return Err(ParseAmountError::FormatAmount);
- }
- let (value, fraction) = amount.split_once('.').unwrap_or((amount, ""));
-
- let value: u64 = value.parse().map_err(|_| ParseAmountError::FormatAmount)?;
- if value > MAX_VALUE {
- return Err(ParseAmountError::FormatAmount);
- }
-
- if fraction.len() > 8 {
- return Err(ParseAmountError::FormatAmount);
- }
- let fraction: u32 = if fraction.is_empty() {
- 0
- } else {
- fraction
- .parse::<u32>()
- .map_err(|_| ParseAmountError::FormatAmount)?
- * 10_u32.pow((8 - fraction.len()) as u32)
- };
+ let currency = currency.parse()?;
+ let decimal = amount.parse()?;
- Ok(Self::new(currency, value, fraction))
+ Ok((currency, decimal).into())
}
}
impl Display for Amount {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
- f.write_fmt(format_args!(
- "{}:{}.{:08}",
- self.currency, self.value, self.fraction
- ))
+ f.write_fmt(format_args!("{}:{}", self.currency, self.decimal))
}
}