commit 90d4608cf2f4a73cf13cb252e8892a27fa090e62
parent 6b86233c35e7d0d3d065ad8195512d8691e93fcf
Author: Antoine A <>
Date: Tue, 10 Dec 2024 12:59:41 +0100
utils: improve amount logic
Diffstat:
2 files changed, 131 insertions(+), 87 deletions(-)
diff --git a/taler-api/src/db.rs b/taler-api/src/db.rs
@@ -121,8 +121,7 @@ impl<'q> BindHelper for Query<'q, Postgres, <Postgres as sqlx::Database>::Argume
}
fn bind_decimal(self, decimal: &Decimal) -> Self {
- self.bind(decimal.value as i64)
- .bind(decimal.fraction as i32)
+ self.bind(decimal.val as i64).bind(decimal.frac as i32)
}
fn bind_timestamp(self, timestamp: &Timestamp) -> Self {
diff --git a/taler-common/src/amount.rs b/taler-common/src/amount.rs
@@ -22,9 +22,18 @@ use std::{
str::FromStr,
};
-const CURRENCY_LEN: usize = 12;
-const MAX_VALUE: u64 = 2 << 52;
-const FRACTION_BASE: u32 = 100_000_000;
+/** Number of characters we use to represent currency names */
+// We use the same value than the exchange -1 because we use a byte for the len instead of 0 termination
+pub const CURRENCY_LEN: usize = 11;
+
+/** Maximum legal value for an amount, based on IEEE double */
+pub const MAX_VALUE: u64 = 2 << 52;
+
+/** The number of digits in a fraction part of an amount */
+pub const FRAC_BASE_NB_DIGITS: u8 = 8;
+
+/** The fraction part of an amount represents which fraction of the value */
+pub const FRAC_BASE: u32 = 10u32.pow(FRAC_BASE_NB_DIGITS as u32);
#[derive(Clone, PartialEq, Eq, serde_with::DeserializeFromStr, serde_with::SerializeDisplay)]
/// Inlined ISO 4217 currency string
@@ -32,6 +41,7 @@ pub struct Currency {
/// Len of currency string in buf
len: u8,
/// Buffer of currency bytes, left adjusted and zero padded
+ // TODO use std::ascii::Char when stable
buf: [u8; CURRENCY_LEN],
}
@@ -42,6 +52,23 @@ impl AsRef<str> for Currency {
}
}
+#[derive(Debug, thiserror::Error)]
+pub enum CurrencyErrorKind {
+ #[error("contains illegal characters (only A-Z allowed)")]
+ Invalid,
+ #[error("too long (max {CURRENCY_LEN} chars)")]
+ Big,
+ #[error("is empty")]
+ Empty,
+}
+
+#[derive(Debug, thiserror::Error)]
+#[error("currency code name '{currency}' {kind}")]
+pub struct ParseCurrencyError {
+ currency: String,
+ pub kind: CurrencyErrorKind,
+}
+
impl FromStr for Currency {
type Err = ParseCurrencyError;
@@ -49,11 +76,11 @@ impl FromStr for Currency {
let bytes = s.as_bytes();
let len = bytes.len();
if bytes.is_empty() {
- Err(ParseCurrencyError::Empty)
+ Err(CurrencyErrorKind::Empty)
} else if len > CURRENCY_LEN {
- Err(ParseCurrencyError::Big)
+ Err(CurrencyErrorKind::Big)
} else if !bytes.iter().all(|c| c.is_ascii_uppercase()) {
- Err(ParseCurrencyError::Invalid)
+ Err(CurrencyErrorKind::Invalid)
} else {
let mut buf = [0; CURRENCY_LEN];
buf[..len].copy_from_slice(bytes);
@@ -62,6 +89,10 @@ impl FromStr for Currency {
buf,
})
}
+ .map_err(|kind| ParseCurrencyError {
+ currency: s.to_owned(),
+ kind,
+ })
}
}
@@ -81,54 +112,60 @@ impl Display for Currency {
Debug, Clone, PartialEq, Eq, serde_with::DeserializeFromStr, serde_with::SerializeDisplay,
)]
pub struct Decimal {
- pub value: u64,
- pub fraction: u32,
+ /** Integer part */
+ pub val: u64,
+ /** Factional part, multiple of FRAC_BASE */
+ pub frac: u32,
}
impl Decimal {
pub const fn max() -> Self {
Self {
- value: MAX_VALUE,
- fraction: FRACTION_BASE - 1,
+ val: MAX_VALUE,
+ frac: FRAC_BASE - 1,
}
}
pub const fn zero() -> Self {
- Self {
- value: 0,
- fraction: 0,
- }
+ Self { val: 0, frac: 0 }
}
fn normalize(mut self) -> Option<Self> {
- self.value = self
- .value
- .checked_add((self.fraction / FRACTION_BASE) as u64)?;
- self.fraction %= FRACTION_BASE;
- if self.value > MAX_VALUE {
+ self.val = self.val.checked_add((self.frac / FRAC_BASE) as u64)?;
+ self.frac %= FRAC_BASE;
+ if self.val > MAX_VALUE {
return None;
}
Some(self)
}
pub fn add(mut self, rhs: &Self) -> Option<Self> {
- self.value = self.value.checked_add(rhs.value)?;
- self.fraction = self
- .fraction
- .checked_add(rhs.fraction)
+ self.val = self.val.checked_add(rhs.val)?;
+ self.frac = self
+ .frac
+ .checked_add(rhs.frac)
.expect("amount fraction overflow should never happen with normalized amounts");
self.normalize()
}
}
#[derive(Debug, thiserror::Error)]
-pub enum ParseDecimalError {
- #[error("invalid amount format")]
- Format,
- #[error("amount overflow")]
+pub enum DecimalErrorKind {
+ #[error("value specified is too large (must be <= {MAX_VALUE})")]
Overflow,
- #[error(transparent)]
- Number(#[from] ParseIntError),
+ #[error("invalid value: {0}")]
+ InvalidValue(ParseIntError),
+ #[error("invalid fraction: {0}")]
+ InvalidFraction(ParseIntError),
+ #[error("fractional value overflow (max {FRAC_BASE_NB_DIGITS} digits)")]
+ FractionOverflow,
+}
+
+#[derive(Debug, thiserror::Error)]
+#[error("decimal '{decimal}' {kind}")]
+pub struct ParseDecimalError {
+ decimal: String,
+ pub kind: DecimalErrorKind,
}
impl FromStr for Decimal {
@@ -137,39 +174,42 @@ impl FromStr for Decimal {
fn from_str(s: &str) -> Result<Self, Self::Err> {
let (value, fraction) = s.split_once('.').unwrap_or((s, ""));
- let value: u64 = value.parse()?;
- if value > MAX_VALUE {
- return Err(ParseDecimalError::Format);
- }
-
- if fraction.len() > 8 {
- return Err(ParseDecimalError::Format);
- }
- let fraction: u32 = if fraction.is_empty() {
- 0
- } else {
- fraction.parse::<u32>()? * 10_u32.pow((8 - fraction.len()) as u32)
- };
- Ok(Self { value, fraction })
+ // TODO use try block when stable
+ (|| {
+ let value: u64 = value.parse().map_err(DecimalErrorKind::InvalidValue)?;
+ if value > MAX_VALUE {
+ return Err(DecimalErrorKind::Overflow);
+ }
+
+ if fraction.len() > FRAC_BASE_NB_DIGITS as usize {
+ return Err(DecimalErrorKind::FractionOverflow);
+ }
+ let fraction: u32 = if fraction.is_empty() {
+ 0
+ } else {
+ fraction
+ .parse::<u32>()
+ .map_err(DecimalErrorKind::InvalidFraction)?
+ * 10u32.pow(FRAC_BASE_NB_DIGITS as u32 - fraction.len() as u32)
+ };
+ Ok(Self {
+ val: value,
+ frac: fraction,
+ })
+ })()
+ .map_err(|kind| ParseDecimalError {
+ decimal: s.to_owned(),
+ kind,
+ })
}
}
impl Display for Decimal {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
- f.write_fmt(format_args!("{}.{:08}", self.value, self.fraction))
+ f.write_fmt(format_args!("{}.{:08}", self.val, self.frac))
}
}
-#[derive(Debug, thiserror::Error)]
-pub enum ParseCurrencyError {
- #[error("invalid currency")]
- Invalid,
- #[error("currency is longer than {CURRENCY_LEN} chars")]
- Big,
- #[error("currency is empty")]
- Empty,
-}
-
/// <https://docs.taler.net/core/api-common.html#tsref-type-Amount>
#[derive(
Debug, Clone, PartialEq, Eq, serde_with::DeserializeFromStr, serde_with::SerializeDisplay,
@@ -185,8 +225,8 @@ impl Amount {
(currency, decimal).into()
}
- pub fn new(currency: impl AsRef<str>, value: u64, fraction: u32) -> Self {
- Self::new_decimal(currency, Decimal { value, fraction })
+ pub fn new(currency: impl AsRef<str>, val: u64, frac: u32) -> Self {
+ Self::new_decimal(currency, Decimal { val, frac })
}
pub fn max(currency: impl AsRef<str>) -> Self {
@@ -224,27 +264,37 @@ pub fn amount(amount: impl AsRef<str>) -> Amount {
}
#[derive(Debug, thiserror::Error)]
-pub enum ParseAmountError {
- #[error("Invalid amount format")]
- FormatAmount,
- #[error(transparent)]
- Currency(#[from] ParseCurrencyError),
+pub enum AmountErrorKind {
+ #[error("invalid format")]
+ Format,
+ #[error("currency {0}")]
+ Currency(#[from] CurrencyErrorKind),
#[error(transparent)]
- Decimal(#[from] ParseDecimalError),
+ Decimal(#[from] DecimalErrorKind),
+}
+
+#[derive(Debug, thiserror::Error)]
+#[error("amount '{amount}' {kind}")]
+pub struct ParseAmountError {
+ amount: String,
+ kind: AmountErrorKind,
}
impl FromStr for Amount {
type Err = ParseAmountError;
fn from_str(s: &str) -> Result<Self, Self::Err> {
- let (currency, amount) = s
- .trim()
- .split_once(':')
- .ok_or(ParseAmountError::FormatAmount)?;
- let currency = currency.parse()?;
- let decimal = amount.parse()?;
-
- Ok((currency, decimal).into())
+ // TODO use try block when stable
+ (|| {
+ let (currency, amount) = s.trim().split_once(':').ok_or(AmountErrorKind::Format)?;
+ let currency = currency.parse().map_err(|e: ParseCurrencyError| e.kind)?;
+ let decimal = amount.parse().map_err(|e: ParseDecimalError| e.kind)?;
+ Ok((currency, decimal).into())
+ })()
+ .map_err(|kind| ParseAmountError {
+ amount: s.to_owned(),
+ kind,
+ })
}
}
@@ -270,7 +320,9 @@ fn test_amount_parse() {
for str in INVALID_AMOUNTS {
let amount = Amount::from_str(str);
- assert!(amount.is_err(), "invalid {} got {:?}", str, amount);
+ assert!(amount.is_err(), "invalid {} got {:?}", str, &amount);
+ dbg!(&amount);
+ dbg!(amount.unwrap_err().to_string());
}
let valid_amounts: Vec<(&str, Amount)> = vec![
@@ -330,8 +382,7 @@ fn test_amount_add() {
);
assert_eq!(Amount::new("EUR", u64::MAX, 0).add(&amount("EUR:1")), None);
assert_eq!(
- amount(format!("EUR:{}.{}", MAX_VALUE - 5, FRACTION_BASE - 1))
- .add(&amount("EUR:5.00000002")),
+ amount(format!("EUR:{}.{}", MAX_VALUE - 5, FRAC_BASE - 1)).add(&amount("EUR:5.00000002")),
None
);
}
@@ -339,25 +390,19 @@ fn test_amount_add() {
#[test]
fn test_amount_normalize() {
assert_eq!(
- Amount::new("EUR", 4, 2 * FRACTION_BASE).normalize(),
+ Amount::new("EUR", 4, 2 * FRAC_BASE).normalize(),
Some(amount("EUR:6"))
);
assert_eq!(
- Amount::new("EUR", 4, 2 * FRACTION_BASE + 1).normalize(),
+ Amount::new("EUR", 4, 2 * FRAC_BASE + 1).normalize(),
Some(amount("EUR:6.00000001"))
);
assert_eq!(
- Amount::new("EUR", MAX_VALUE, FRACTION_BASE - 1).normalize(),
- Some(Amount::new("EUR", MAX_VALUE, FRACTION_BASE - 1))
- );
- assert_eq!(
- Amount::new("EUR", u64::MAX, FRACTION_BASE).normalize(),
- None
- );
- assert_eq!(
- Amount::new("EUR", MAX_VALUE, FRACTION_BASE).normalize(),
- None
+ Amount::new("EUR", MAX_VALUE, FRAC_BASE - 1).normalize(),
+ Some(Amount::new("EUR", MAX_VALUE, FRAC_BASE - 1))
);
+ assert_eq!(Amount::new("EUR", u64::MAX, FRAC_BASE).normalize(), None);
+ assert_eq!(Amount::new("EUR", MAX_VALUE, FRAC_BASE).normalize(), None);
for amount in [Amount::max("EUR"), Amount::zero("EUR")] {
assert_eq!(amount.clone().normalize(), Some(amount))