taler-rust

GNU Taler code in Rust. Largely core banking integrations.
Log | Files | Refs | Submodules | README | LICENSE

commit 90d4608cf2f4a73cf13cb252e8892a27fa090e62
parent 6b86233c35e7d0d3d065ad8195512d8691e93fcf
Author: Antoine A <>
Date:   Tue, 10 Dec 2024 12:59:41 +0100

utils: improve amount logic

Diffstat:
Mtaler-api/src/db.rs | 3+--
Mtaler-common/src/amount.rs | 215++++++++++++++++++++++++++++++++++++++++++++++++-------------------------------
2 files changed, 131 insertions(+), 87 deletions(-)

diff --git a/taler-api/src/db.rs b/taler-api/src/db.rs @@ -121,8 +121,7 @@ impl<'q> BindHelper for Query<'q, Postgres, <Postgres as sqlx::Database>::Argume } fn bind_decimal(self, decimal: &Decimal) -> Self { - self.bind(decimal.value as i64) - .bind(decimal.fraction as i32) + self.bind(decimal.val as i64).bind(decimal.frac as i32) } fn bind_timestamp(self, timestamp: &Timestamp) -> Self { diff --git a/taler-common/src/amount.rs b/taler-common/src/amount.rs @@ -22,9 +22,18 @@ use std::{ str::FromStr, }; -const CURRENCY_LEN: usize = 12; -const MAX_VALUE: u64 = 2 << 52; -const FRACTION_BASE: u32 = 100_000_000; +/** Number of characters we use to represent currency names */ +// We use the same value than the exchange -1 because we use a byte for the len instead of 0 termination +pub const CURRENCY_LEN: usize = 11; + +/** Maximum legal value for an amount, based on IEEE double */ +pub const MAX_VALUE: u64 = 2 << 52; + +/** The number of digits in a fraction part of an amount */ +pub const FRAC_BASE_NB_DIGITS: u8 = 8; + +/** The fraction part of an amount represents which fraction of the value */ +pub const FRAC_BASE: u32 = 10u32.pow(FRAC_BASE_NB_DIGITS as u32); #[derive(Clone, PartialEq, Eq, serde_with::DeserializeFromStr, serde_with::SerializeDisplay)] /// Inlined ISO 4217 currency string @@ -32,6 +41,7 @@ pub struct Currency { /// Len of currency string in buf len: u8, /// Buffer of currency bytes, left adjusted and zero padded + // TODO use std::ascii::Char when stable buf: [u8; CURRENCY_LEN], } @@ -42,6 +52,23 @@ impl AsRef<str> for Currency { } } +#[derive(Debug, thiserror::Error)] +pub enum CurrencyErrorKind { + #[error("contains illegal characters (only A-Z allowed)")] + Invalid, + #[error("too long (max {CURRENCY_LEN} chars)")] + Big, + #[error("is empty")] + Empty, +} + +#[derive(Debug, thiserror::Error)] +#[error("currency code name '{currency}' {kind}")] +pub struct ParseCurrencyError { + currency: String, + pub kind: CurrencyErrorKind, +} + impl FromStr for Currency { type Err = ParseCurrencyError; @@ -49,11 +76,11 @@ impl FromStr for Currency { let bytes = s.as_bytes(); let len = bytes.len(); if bytes.is_empty() { - Err(ParseCurrencyError::Empty) + Err(CurrencyErrorKind::Empty) } else if len > CURRENCY_LEN { - Err(ParseCurrencyError::Big) + Err(CurrencyErrorKind::Big) } else if !bytes.iter().all(|c| c.is_ascii_uppercase()) { - Err(ParseCurrencyError::Invalid) + Err(CurrencyErrorKind::Invalid) } else { let mut buf = [0; CURRENCY_LEN]; buf[..len].copy_from_slice(bytes); @@ -62,6 +89,10 @@ impl FromStr for Currency { buf, }) } + .map_err(|kind| ParseCurrencyError { + currency: s.to_owned(), + kind, + }) } } @@ -81,54 +112,60 @@ impl Display for Currency { Debug, Clone, PartialEq, Eq, serde_with::DeserializeFromStr, serde_with::SerializeDisplay, )] pub struct Decimal { - pub value: u64, - pub fraction: u32, + /** Integer part */ + pub val: u64, + /** Factional part, multiple of FRAC_BASE */ + pub frac: u32, } impl Decimal { pub const fn max() -> Self { Self { - value: MAX_VALUE, - fraction: FRACTION_BASE - 1, + val: MAX_VALUE, + frac: FRAC_BASE - 1, } } pub const fn zero() -> Self { - Self { - value: 0, - fraction: 0, - } + Self { val: 0, frac: 0 } } fn normalize(mut self) -> Option<Self> { - self.value = self - .value - .checked_add((self.fraction / FRACTION_BASE) as u64)?; - self.fraction %= FRACTION_BASE; - if self.value > MAX_VALUE { + self.val = self.val.checked_add((self.frac / FRAC_BASE) as u64)?; + self.frac %= FRAC_BASE; + if self.val > MAX_VALUE { return None; } Some(self) } pub fn add(mut self, rhs: &Self) -> Option<Self> { - self.value = self.value.checked_add(rhs.value)?; - self.fraction = self - .fraction - .checked_add(rhs.fraction) + self.val = self.val.checked_add(rhs.val)?; + self.frac = self + .frac + .checked_add(rhs.frac) .expect("amount fraction overflow should never happen with normalized amounts"); self.normalize() } } #[derive(Debug, thiserror::Error)] -pub enum ParseDecimalError { - #[error("invalid amount format")] - Format, - #[error("amount overflow")] +pub enum DecimalErrorKind { + #[error("value specified is too large (must be <= {MAX_VALUE})")] Overflow, - #[error(transparent)] - Number(#[from] ParseIntError), + #[error("invalid value: {0}")] + InvalidValue(ParseIntError), + #[error("invalid fraction: {0}")] + InvalidFraction(ParseIntError), + #[error("fractional value overflow (max {FRAC_BASE_NB_DIGITS} digits)")] + FractionOverflow, +} + +#[derive(Debug, thiserror::Error)] +#[error("decimal '{decimal}' {kind}")] +pub struct ParseDecimalError { + decimal: String, + pub kind: DecimalErrorKind, } impl FromStr for Decimal { @@ -137,39 +174,42 @@ impl FromStr for Decimal { fn from_str(s: &str) -> Result<Self, Self::Err> { let (value, fraction) = s.split_once('.').unwrap_or((s, "")); - let value: u64 = value.parse()?; - if value > MAX_VALUE { - return Err(ParseDecimalError::Format); - } - - if fraction.len() > 8 { - return Err(ParseDecimalError::Format); - } - let fraction: u32 = if fraction.is_empty() { - 0 - } else { - fraction.parse::<u32>()? * 10_u32.pow((8 - fraction.len()) as u32) - }; - Ok(Self { value, fraction }) + // TODO use try block when stable + (|| { + let value: u64 = value.parse().map_err(DecimalErrorKind::InvalidValue)?; + if value > MAX_VALUE { + return Err(DecimalErrorKind::Overflow); + } + + if fraction.len() > FRAC_BASE_NB_DIGITS as usize { + return Err(DecimalErrorKind::FractionOverflow); + } + let fraction: u32 = if fraction.is_empty() { + 0 + } else { + fraction + .parse::<u32>() + .map_err(DecimalErrorKind::InvalidFraction)? + * 10u32.pow(FRAC_BASE_NB_DIGITS as u32 - fraction.len() as u32) + }; + Ok(Self { + val: value, + frac: fraction, + }) + })() + .map_err(|kind| ParseDecimalError { + decimal: s.to_owned(), + kind, + }) } } impl Display for Decimal { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.write_fmt(format_args!("{}.{:08}", self.value, self.fraction)) + f.write_fmt(format_args!("{}.{:08}", self.val, self.frac)) } } -#[derive(Debug, thiserror::Error)] -pub enum ParseCurrencyError { - #[error("invalid currency")] - Invalid, - #[error("currency is longer than {CURRENCY_LEN} chars")] - Big, - #[error("currency is empty")] - Empty, -} - /// <https://docs.taler.net/core/api-common.html#tsref-type-Amount> #[derive( Debug, Clone, PartialEq, Eq, serde_with::DeserializeFromStr, serde_with::SerializeDisplay, @@ -185,8 +225,8 @@ impl Amount { (currency, decimal).into() } - pub fn new(currency: impl AsRef<str>, value: u64, fraction: u32) -> Self { - Self::new_decimal(currency, Decimal { value, fraction }) + pub fn new(currency: impl AsRef<str>, val: u64, frac: u32) -> Self { + Self::new_decimal(currency, Decimal { val, frac }) } pub fn max(currency: impl AsRef<str>) -> Self { @@ -224,27 +264,37 @@ pub fn amount(amount: impl AsRef<str>) -> Amount { } #[derive(Debug, thiserror::Error)] -pub enum ParseAmountError { - #[error("Invalid amount format")] - FormatAmount, - #[error(transparent)] - Currency(#[from] ParseCurrencyError), +pub enum AmountErrorKind { + #[error("invalid format")] + Format, + #[error("currency {0}")] + Currency(#[from] CurrencyErrorKind), #[error(transparent)] - Decimal(#[from] ParseDecimalError), + Decimal(#[from] DecimalErrorKind), +} + +#[derive(Debug, thiserror::Error)] +#[error("amount '{amount}' {kind}")] +pub struct ParseAmountError { + amount: String, + kind: AmountErrorKind, } impl FromStr for Amount { type Err = ParseAmountError; fn from_str(s: &str) -> Result<Self, Self::Err> { - let (currency, amount) = s - .trim() - .split_once(':') - .ok_or(ParseAmountError::FormatAmount)?; - let currency = currency.parse()?; - let decimal = amount.parse()?; - - Ok((currency, decimal).into()) + // TODO use try block when stable + (|| { + let (currency, amount) = s.trim().split_once(':').ok_or(AmountErrorKind::Format)?; + let currency = currency.parse().map_err(|e: ParseCurrencyError| e.kind)?; + let decimal = amount.parse().map_err(|e: ParseDecimalError| e.kind)?; + Ok((currency, decimal).into()) + })() + .map_err(|kind| ParseAmountError { + amount: s.to_owned(), + kind, + }) } } @@ -270,7 +320,9 @@ fn test_amount_parse() { for str in INVALID_AMOUNTS { let amount = Amount::from_str(str); - assert!(amount.is_err(), "invalid {} got {:?}", str, amount); + assert!(amount.is_err(), "invalid {} got {:?}", str, &amount); + dbg!(&amount); + dbg!(amount.unwrap_err().to_string()); } let valid_amounts: Vec<(&str, Amount)> = vec![ @@ -330,8 +382,7 @@ fn test_amount_add() { ); assert_eq!(Amount::new("EUR", u64::MAX, 0).add(&amount("EUR:1")), None); assert_eq!( - amount(format!("EUR:{}.{}", MAX_VALUE - 5, FRACTION_BASE - 1)) - .add(&amount("EUR:5.00000002")), + amount(format!("EUR:{}.{}", MAX_VALUE - 5, FRAC_BASE - 1)).add(&amount("EUR:5.00000002")), None ); } @@ -339,25 +390,19 @@ fn test_amount_add() { #[test] fn test_amount_normalize() { assert_eq!( - Amount::new("EUR", 4, 2 * FRACTION_BASE).normalize(), + Amount::new("EUR", 4, 2 * FRAC_BASE).normalize(), Some(amount("EUR:6")) ); assert_eq!( - Amount::new("EUR", 4, 2 * FRACTION_BASE + 1).normalize(), + Amount::new("EUR", 4, 2 * FRAC_BASE + 1).normalize(), Some(amount("EUR:6.00000001")) ); assert_eq!( - Amount::new("EUR", MAX_VALUE, FRACTION_BASE - 1).normalize(), - Some(Amount::new("EUR", MAX_VALUE, FRACTION_BASE - 1)) - ); - assert_eq!( - Amount::new("EUR", u64::MAX, FRACTION_BASE).normalize(), - None - ); - assert_eq!( - Amount::new("EUR", MAX_VALUE, FRACTION_BASE).normalize(), - None + Amount::new("EUR", MAX_VALUE, FRAC_BASE - 1).normalize(), + Some(Amount::new("EUR", MAX_VALUE, FRAC_BASE - 1)) ); + assert_eq!(Amount::new("EUR", u64::MAX, FRAC_BASE).normalize(), None); + assert_eq!(Amount::new("EUR", MAX_VALUE, FRAC_BASE).normalize(), None); for amount in [Amount::max("EUR"), Amount::zero("EUR")] { assert_eq!(amount.clone().normalize(), Some(amount))