taler-rust

GNU Taler code in Rust. Largely core banking integrations.
Log | Files | Refs | Submodules | README | LICENSE

amount.rs (16472B)


      1 /*
      2   This file is part of TALER
      3   Copyright (C) 2024, 2025, 2026 Taler Systems SA
      4 
      5   TALER is free software; you can redistribute it and/or modify it under the
      6   terms of the GNU Affero General Public License as published by the Free Software
      7   Foundation; either version 3, or (at your option) any later version.
      8 
      9   TALER is distributed in the hope that it will be useful, but WITHOUT ANY
     10   WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
     11   A PARTICULAR PURPOSE.  See the GNU Affero General Public License for more details.
     12 
     13   You should have received a copy of the GNU Affero General Public License along with
     14   TALER; see the file COPYING.  If not, see <http://www.gnu.org/licenses/>
     15 */
     16 
     17 //! Type for the Taler Amount <https://docs.taler.net/core/api-common.html#tsref-type-Amount>
     18 
     19 use std::{
     20     fmt::{Debug, Display},
     21     num::ParseIntError,
     22     str::FromStr,
     23 };
     24 
     25 use super::utils::InlineStr;
     26 
     27 /** Number of characters we use to represent currency names */
     28 // We use the same value than the exchange -1 because we use a byte for the len instead of 0 termination
     29 pub const CURRENCY_LEN: usize = 11;
     30 
     31 /** Maximum legal value for an amount, based on IEEE double */
     32 pub const MAX_VALUE: u64 = 2 << 52;
     33 
     34 /** The number of digits in a fraction part of an amount */
     35 pub const FRAC_BASE_NB_DIGITS: u8 = 8;
     36 
     37 /** The fraction part of an amount represents which fraction of the value */
     38 pub const FRAC_BASE: u32 = 10u32.pow(FRAC_BASE_NB_DIGITS as u32);
     39 
     40 const CENT_FRACTION: u32 = 10u32.pow((FRAC_BASE_NB_DIGITS - 2) as u32);
     41 
#[derive(
    Clone, Copy, PartialEq, Eq, serde_with::DeserializeFromStr, serde_with::SerializeDisplay,
)]
/// Currency code stored inline in an [`InlineStr`].
///
/// A valid code is 1 to [`CURRENCY_LEN`] uppercase ASCII letters (`A-Z`). This covers
/// ISO 4217 codes such as `EUR` as well as non-ISO ones such as `KUDOS` or `TEST`.
/// (De)serialized through its `FromStr` / `Display` implementations.
pub struct Currency(InlineStr<CURRENCY_LEN>);
     47 
     48 impl AsRef<str> for Currency {
     49     fn as_ref(&self) -> &str {
     50         self.0.as_ref()
     51     }
     52 }
     53 
/// Reason why a string is not a valid [`Currency`].
#[derive(Debug, thiserror::Error)]
pub enum CurrencyErrorKind {
    /// At least one byte is not an uppercase ASCII letter.
    #[error("contains illegal characters (only A-Z allowed)")]
    Invalid,
    /// Longer than [`CURRENCY_LEN`] bytes.
    #[error("too long (max {CURRENCY_LEN} chars)")]
    Big,
    /// The empty string.
    #[error("is empty")]
    Empty,
}

/// Error returned when parsing a [`Currency`], carrying the rejected input.
#[derive(Debug, thiserror::Error)]
#[error("currency code name '{currency}' {kind}")]
pub struct ParseCurrencyError {
    /// The string that failed to parse.
    currency: String,
    /// Why it was rejected.
    pub kind: CurrencyErrorKind,
}
     70 
     71 impl Currency {
     72     pub const TEST: Self = Self::const_parse("TEST");
     73     pub const KUDOS: Self = Self::const_parse("KUDOS");
     74     pub const EUR: Self = Self::const_parse("EUR");
     75     pub const CHF: Self = Self::const_parse("CHF");
     76     pub const HUF: Self = Self::const_parse("HUF");
     77 
     78     pub const fn const_parse(s: &str) -> Currency {
     79         let bytes = s.as_bytes();
     80         let len = bytes.len();
     81 
     82         if bytes.is_empty() {
     83             panic!("empty")
     84         } else if len > CURRENCY_LEN {
     85             panic!("too big")
     86         }
     87         let mut i = 0;
     88         while i < bytes.len() {
     89             if !bytes[i].is_ascii_uppercase() {
     90                 panic!("invalid")
     91             }
     92             i += 1;
     93         }
     94         Self(InlineStr::copy_from_slice(bytes))
     95     }
     96 }
     97 
     98 impl FromStr for Currency {
     99     type Err = ParseCurrencyError;
    100 
    101     fn from_str(s: &str) -> Result<Self, Self::Err> {
    102         let bytes = s.as_bytes();
    103         let len = bytes.len();
    104         if bytes.is_empty() {
    105             Err(CurrencyErrorKind::Empty)
    106         } else if len > CURRENCY_LEN {
    107             Err(CurrencyErrorKind::Big)
    108         } else if !bytes.iter().all(|c| c.is_ascii_uppercase()) {
    109             Err(CurrencyErrorKind::Invalid)
    110         } else {
    111             Ok(Self(InlineStr::copy_from_slice(bytes)))
    112         }
    113         .map_err(|kind| ParseCurrencyError {
    114             currency: s.to_owned(),
    115             kind,
    116         })
    117     }
    118 }
    119 
    120 impl Debug for Currency {
    121     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
    122         Debug::fmt(&self.as_ref(), f)
    123     }
    124 }
    125 
    126 impl Display for Currency {
    127     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
    128         Display::fmt(&self.as_ref(), f)
    129     }
    130 }
    131 
/// Row form of the Postgres `taler_amount` composite type, used by the sqlx
/// `Type`/`Encode`/`Decode` impls of [`Decimal`] and [`Amount`].
///
/// Components are signed (`i64`/`i32`); conversion from/to the unsigned Rust
/// fields happens in those impls.
#[derive(sqlx::Type)]
#[sqlx(type_name = "taler_amount")]
struct PgTalerAmount {
    /// Integer part.
    pub val: i64,
    /// Fractional part, in units of `1 / FRAC_BASE`.
    pub frac: i32,
}
    138 
/// Fixed-point decimal number equal to `val + frac / FRAC_BASE`.
///
/// A value is normalized when `frac < FRAC_BASE` and `val <= MAX_VALUE`. Fields are
/// public and [`Decimal::new`] does not check them, so non-normalized values can
/// exist. The derived ordering compares `(val, frac)` field by field and is only
/// numerically meaningful for normalized values.
#[derive(
    Clone,
    Copy,
    PartialEq,
    Eq,
    PartialOrd,
    Ord,
    serde_with::DeserializeFromStr,
    serde_with::SerializeDisplay,
)]
pub struct Decimal {
    /// Integer part
    pub val: u64,
    /// Fractional part, in units of `1 / FRAC_BASE` (normalized values have `frac < FRAC_BASE`)
    pub frac: u32,
}
    155 
    156 impl Decimal {
    157     pub const fn new(val: u64, frac: u32) -> Self {
    158         Self { val, frac }
    159     }
    160 
    161     pub const fn max() -> Self {
    162         Self {
    163             val: MAX_VALUE,
    164             frac: FRAC_BASE - 1,
    165         }
    166     }
    167 
    168     pub const fn zero() -> Self {
    169         Self { val: 0, frac: 0 }
    170     }
    171 
    172     const fn normalize(mut self) -> Option<Self> {
    173         let Some(val) = self.val.checked_add((self.frac / FRAC_BASE) as u64) else {
    174             return None;
    175         };
    176         self.val = val;
    177         self.frac %= FRAC_BASE;
    178         if self.val > MAX_VALUE {
    179             return None;
    180         }
    181         Some(self)
    182     }
    183 
    184     pub fn try_add(mut self, rhs: &Self) -> Option<Self> {
    185         self.val = self.val.checked_add(rhs.val)?;
    186         self.frac = self
    187             .frac
    188             .checked_add(rhs.frac)
    189             .expect("amount fraction overflow should never happen with normalized amounts");
    190         self.normalize()
    191     }
    192 
    193     pub fn try_sub(mut self, rhs: &Self) -> Option<Self> {
    194         if rhs.frac > self.frac {
    195             self.val = self.val.checked_sub(1)?;
    196             self.frac += FRAC_BASE;
    197         }
    198         self.val = self.val.checked_sub(rhs.val)?;
    199         self.frac = self.frac.checked_sub(rhs.frac)?;
    200         self.normalize()
    201     }
    202 
    203     pub const fn to_amount(self, currency: &Currency) -> Amount {
    204         Amount::new_decimal(currency, self)
    205     }
    206 }
    207 
/// Reason why a string is not a valid [`Decimal`].
#[derive(Debug, thiserror::Error)]
pub enum DecimalErrKind {
    /// The integer part is greater than [`MAX_VALUE`].
    #[error("value overflow (must be <= {MAX_VALUE})")]
    Overflow,
    /// The integer part (before `.`) is not a valid unsigned integer.
    #[error("invalid value ({0})")]
    InvalidValue(ParseIntError),
    /// The fractional part (after `.`) is not a valid unsigned integer.
    #[error("invalid fraction ({0})")]
    InvalidFraction(ParseIntError),
    /// The fractional part is longer than [`FRAC_BASE_NB_DIGITS`] characters.
    #[error("fraction overflow (max {FRAC_BASE_NB_DIGITS} digits)")]
    FractionOverflow,
}

/// Error returned when parsing a [`Decimal`], carrying the rejected input.
#[derive(Debug, thiserror::Error)]
#[error("decimal '{decimal}' {kind}")]
pub struct ParseDecimalErr {
    /// The string that failed to parse.
    decimal: String,
    /// Why it was rejected.
    pub kind: DecimalErrKind,
}
    226 
    227 impl FromStr for Decimal {
    228     type Err = ParseDecimalErr;
    229 
    230     fn from_str(s: &str) -> Result<Self, Self::Err> {
    231         let (value, fraction) = s.split_once('.').unwrap_or((s, ""));
    232 
    233         // TODO use try block when stable
    234         (|| {
    235             let value: u64 = value.parse().map_err(DecimalErrKind::InvalidValue)?;
    236             if value > MAX_VALUE {
    237                 return Err(DecimalErrKind::Overflow);
    238             }
    239 
    240             if fraction.len() > FRAC_BASE_NB_DIGITS as usize {
    241                 return Err(DecimalErrKind::FractionOverflow);
    242             }
    243             let fraction: u32 = if fraction.is_empty() {
    244                 0
    245             } else {
    246                 fraction
    247                     .parse::<u32>()
    248                     .map_err(DecimalErrKind::InvalidFraction)?
    249                     * 10u32.pow(FRAC_BASE_NB_DIGITS as u32 - fraction.len() as u32)
    250             };
    251             Ok(Self {
    252                 val: value,
    253                 frac: fraction,
    254             })
    255         })()
    256         .map_err(|kind| ParseDecimalErr {
    257             decimal: s.to_owned(),
    258             kind,
    259         })
    260     }
    261 }
    262 
    263 impl Display for Decimal {
    264     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
    265         if self.frac == 0 {
    266             f.write_fmt(format_args!("{}", self.val))
    267         } else {
    268             let num = format!("{:08}", self.frac);
    269             f.write_fmt(format_args!("{}.{}", self.val, num.trim_end_matches('0')))
    270         }
    271     }
    272 }
    273 
    274 impl Debug for Decimal {
    275     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
    276         Display::fmt(&self, f)
    277     }
    278 }
    279 
    280 impl sqlx::Type<sqlx::Postgres> for Decimal {
    281     fn type_info() -> sqlx::postgres::PgTypeInfo {
    282         PgTalerAmount::type_info()
    283     }
    284 }
    285 
    286 impl<'q> sqlx::Encode<'q, sqlx::Postgres> for Decimal {
    287     fn encode_by_ref(
    288         &self,
    289         buf: &mut sqlx::postgres::PgArgumentBuffer,
    290     ) -> Result<sqlx::encode::IsNull, sqlx::error::BoxDynError> {
    291         PgTalerAmount {
    292             val: self.val as i64,
    293             frac: self.frac as i32,
    294         }
    295         .encode_by_ref(buf)
    296     }
    297 }
    298 
    299 impl<'r> sqlx::Decode<'r, sqlx::Postgres> for Decimal {
    300     fn decode(value: sqlx::postgres::PgValueRef<'r>) -> Result<Self, sqlx::error::BoxDynError> {
    301         let pg = PgTalerAmount::decode(value)?;
    302         Ok(Self {
    303             val: pg.val as u64,
    304             frac: pg.frac as u32,
    305         })
    306     }
    307 }
    308 
    309 #[track_caller]
    310 pub fn decimal(decimal: impl AsRef<str>) -> Decimal {
    311     decimal.as_ref().parse().expect("Invalid decimal constant")
    312 }
    313 
/// Amount of money: `val + frac / FRAC_BASE` units of `currency`.
///
/// <https://docs.taler.net/core/api-common.html#tsref-type-Amount>
///
/// (De)serialized as `CURRENCY:VALUE[.FRACTION]` through `FromStr` / `Display`.
/// Fields are not validated on construction; `Amount::normalize` carries an
/// overlarge fraction into `val`.
#[derive(
    Clone, Copy, PartialEq, Eq, serde_with::DeserializeFromStr, serde_with::SerializeDisplay,
)]
pub struct Amount {
    /// Currency the amount is expressed in.
    pub currency: Currency,
    /// Integer part.
    pub val: u64,
    /// Fractional part, in units of `1 / FRAC_BASE`.
    pub frac: u32,
}
    323 
    324 impl Amount {
    325     pub const fn new_decimal(currency: &Currency, decimal: Decimal) -> Self {
    326         Self {
    327             currency: *currency,
    328             val: decimal.val,
    329             frac: decimal.frac,
    330         }
    331     }
    332 
    333     pub const fn new(currency: &Currency, val: u64, frac: u32) -> Self {
    334         Self::new_decimal(currency, Decimal { val, frac })
    335     }
    336 
    337     pub const fn max(currency: &Currency) -> Self {
    338         Self::new_decimal(currency, Decimal::max())
    339     }
    340 
    341     pub const fn zero(currency: &Currency) -> Self {
    342         Self::new_decimal(currency, Decimal::zero())
    343     }
    344 
    345     pub fn is_zero(&self) -> bool {
    346         self.decimal() == Decimal::zero()
    347     }
    348 
    349     /* Check is amount has fractional amount < 0.01 */
    350     pub const fn is_sub_cent(&self) -> bool {
    351         !self.frac.is_multiple_of(CENT_FRACTION)
    352     }
    353 
    354     pub const fn decimal(&self) -> Decimal {
    355         Decimal {
    356             val: self.val,
    357             frac: self.frac,
    358         }
    359     }
    360 
    361     pub fn normalize(self) -> Option<Self> {
    362         let decimal = self.decimal().normalize()?;
    363         Some((self.currency, decimal).into())
    364     }
    365 
    366     pub fn try_add(self, rhs: &Self) -> Option<Self> {
    367         assert_eq!(self.currency, rhs.currency);
    368         let decimal = self.decimal().try_add(&rhs.decimal())?.normalize()?;
    369         Some((self.currency, decimal).into())
    370     }
    371 
    372     pub fn try_sub(self, rhs: &Self) -> Option<Self> {
    373         assert_eq!(self.currency, rhs.currency);
    374         let decimal = self.decimal().try_sub(&rhs.decimal())?.normalize()?;
    375         Some((self.currency, decimal).into())
    376     }
    377 }
    378 
    379 impl From<(Currency, Decimal)> for Amount {
    380     fn from((currency, decimal): (Currency, Decimal)) -> Self {
    381         Self::new_decimal(&currency, decimal)
    382     }
    383 }
    384 
    385 #[track_caller]
    386 pub fn amount(amount: impl AsRef<str>) -> Amount {
    387     amount.as_ref().parse().expect("Invalid amount constant")
    388 }
    389 
/// Reason why a string is not a valid [`Amount`].
#[derive(Debug, thiserror::Error)]
pub enum AmountErrKind {
    /// No `:` separates the currency from the value.
    #[error("invalid format")]
    Format,
    /// The part before `:` is not a valid [`Currency`].
    #[error("currency {0}")]
    Currency(#[from] CurrencyErrorKind),
    /// The part after `:` is not a valid [`Decimal`].
    #[error(transparent)]
    Decimal(#[from] DecimalErrKind),
}

/// Error returned when parsing an [`Amount`], carrying the rejected input.
#[derive(Debug, thiserror::Error)]
#[error("amount '{amount}' {kind}")]
pub struct ParseAmountErr {
    /// The string that failed to parse (before trimming).
    amount: String,
    /// Why it was rejected.
    pub kind: AmountErrKind,
}
    406 
    407 impl FromStr for Amount {
    408     type Err = ParseAmountErr;
    409 
    410     fn from_str(s: &str) -> Result<Self, Self::Err> {
    411         // TODO use try block when stable
    412         (|| {
    413             let (currency, amount) = s.trim().split_once(':').ok_or(AmountErrKind::Format)?;
    414             let currency = currency.parse().map_err(|e: ParseCurrencyError| e.kind)?;
    415             let decimal = amount.parse().map_err(|e: ParseDecimalErr| e.kind)?;
    416             Ok((currency, decimal).into())
    417         })()
    418         .map_err(|kind| ParseAmountErr {
    419             amount: s.to_owned(),
    420             kind,
    421         })
    422     }
    423 }
    424 
    425 impl Display for Amount {
    426     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
    427         f.write_fmt(format_args!("{}:{}", self.currency, self.decimal()))
    428     }
    429 }
    430 
    431 impl Debug for Amount {
    432     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
    433         Display::fmt(&self, f)
    434     }
    435 }
    436 
    437 impl sqlx::Type<sqlx::Postgres> for Amount {
    438     fn type_info() -> sqlx::postgres::PgTypeInfo {
    439         PgTalerAmount::type_info()
    440     }
    441 }
    442 
    443 impl<'q> sqlx::Encode<'q, sqlx::Postgres> for Amount {
    444     fn encode_by_ref(
    445         &self,
    446         buf: &mut sqlx::postgres::PgArgumentBuffer,
    447     ) -> Result<sqlx::encode::IsNull, sqlx::error::BoxDynError> {
    448         self.decimal().encode_by_ref(buf)
    449     }
    450 }
    451 
/// Parsing of invalid and valid amounts, and `Display` round-tripping.
///
/// Cases are adapted from the exchange's C test suite (link below).
#[test]
fn test_amount_parse() {
    const TALER_AMOUNT_FRAC_BASE: u32 = 100000000;
    // https://git.taler.net/exchange.git/tree/src/util/test_amount.c

    const INVALID_AMOUNTS: [&str; 6] = [
        "EUR:4a",                                                                     // non-numeric
        "EUR:4.4a",                                                                   // non-numeric
        "EUR:4.a4",                                                                   // non-numeric
        ":4.a4",                                                                      // no currency
        "EUR:4.123456789", // precision too high
        "EUR:1234567890123456789012345678901234567890123456789012345678901234567890", // value too big
    ];

    for str in INVALID_AMOUNTS {
        let amount = Amount::from_str(str);
        assert!(amount.is_err(), "invalid {} got {:?}", str, &amount);
    }

    let eur: Currency = Currency::EUR;
    let local: Currency = Currency::CHF;
    // (input, expected Display output, expected parsed value)
    let valid_amounts: Vec<(&str, &str, Amount)> = vec![
        ("EUR:4", "EUR:4", Amount::new(&eur, 4, 0)), // without fraction
        (
            "EUR:0.02",
            "EUR:0.02",
            Amount::new(&eur, 0, TALER_AMOUNT_FRAC_BASE / 100 * 2),
        ), // leading zero fraction
        (
            " EUR:4.12",
            "EUR:4.12",
            Amount::new(&eur, 4, TALER_AMOUNT_FRAC_BASE / 100 * 12),
        ), // leading space and fraction
        (
            " CHF:4444.1000",
            "CHF:4444.1",
            Amount::new(&local, 4444, TALER_AMOUNT_FRAC_BASE / 10),
        ), // local currency
    ];
    for (raw, expected, goal) in valid_amounts {
        let amount = Amount::from_str(raw);
        assert!(amount.is_ok(), "Valid {} got {:?}", raw, amount);
        assert_eq!(
            *amount.as_ref().unwrap(),
            goal,
            "Expected {:?} got {:?} for {}",
            goal,
            amount,
            raw
        );
        let amount = amount.unwrap();
        // Formatting must drop trailing zeros and re-parse to the same value.
        let str = amount.to_string();
        assert_eq!(str, expected);
        assert_eq!(amount, Amount::from_str(&str).unwrap(), "{str}");
    }
}
    508 
    509 #[test]
    510 fn test_amount_add() {
    511     let eur: Currency = Currency::EUR;
    512     assert_eq!(
    513         Amount::max(&eur).try_add(&Amount::zero(&eur)),
    514         Some(Amount::max(&eur))
    515     );
    516     assert_eq!(
    517         Amount::zero(&eur).try_add(&Amount::zero(&eur)),
    518         Some(Amount::zero(&eur))
    519     );
    520     assert_eq!(
    521         amount("EUR:6.41").try_add(&amount("EUR:4.69")),
    522         Some(amount("EUR:11.1"))
    523     );
    524     assert_eq!(
    525         amount(format!("EUR:{MAX_VALUE}")).try_add(&amount("EUR:0.99999999")),
    526         Some(Amount::max(&eur))
    527     );
    528 
    529     assert_eq!(
    530         amount(format!("EUR:{}", MAX_VALUE - 5)).try_add(&amount("EUR:6")),
    531         None
    532     );
    533     assert_eq!(
    534         Amount::new(&eur, u64::MAX, 0).try_add(&amount("EUR:1")),
    535         None
    536     );
    537     assert_eq!(
    538         amount(format!("EUR:{}.{}", MAX_VALUE - 5, FRAC_BASE - 1))
    539             .try_add(&amount("EUR:5.00000002")),
    540         None
    541     );
    542 }
    543 
    544 #[test]
    545 fn test_amount_normalize() {
    546     let eur: Currency = "EUR".parse().unwrap();
    547     assert_eq!(
    548         Amount::new(&eur, 4, 2 * FRAC_BASE).normalize(),
    549         Some(amount("EUR:6"))
    550     );
    551     assert_eq!(
    552         Amount::new(&eur, 4, 2 * FRAC_BASE + 1).normalize(),
    553         Some(amount("EUR:6.00000001"))
    554     );
    555     assert_eq!(
    556         Amount::new(&eur, MAX_VALUE, FRAC_BASE - 1).normalize(),
    557         Some(Amount::new(&eur, MAX_VALUE, FRAC_BASE - 1))
    558     );
    559     assert_eq!(Amount::new(&eur, u64::MAX, FRAC_BASE).normalize(), None);
    560     assert_eq!(Amount::new(&eur, MAX_VALUE, FRAC_BASE).normalize(), None);
    561 
    562     for amount in [Amount::max(&eur), Amount::zero(&eur)] {
    563         assert_eq!(amount.normalize(), Some(amount))
    564     }
    565 }