taler-rust

GNU Taler code in Rust. Largely core banking integrations.
Log | Files | Refs | Submodules | README | LICENSE

amount.rs (13300B)


      1 /*
      2   This file is part of TALER
      3   Copyright (C) 2024-2025 Taler Systems SA
      4 
      5   TALER is free software; you can redistribute it and/or modify it under the
      6   terms of the GNU Affero General Public License as published by the Free Software
      7   Foundation; either version 3, or (at your option) any later version.
      8 
      9   TALER is distributed in the hope that it will be useful, but WITHOUT ANY
     10   WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
     11   A PARTICULAR PURPOSE.  See the GNU Affero General Public License for more details.
     12 
     13   You should have received a copy of the GNU Affero General Public License along with
     14   TALER; see the file COPYING.  If not, see <http://www.gnu.org/licenses/>
     15 */
     16 
//! Types for Taler amounts (currency code, decimal value and amount) <https://docs.taler.net/core/api-common.html#tsref-type-Amount>
     18 
     19 use std::{
     20     fmt::{Debug, Display},
     21     num::ParseIntError,
     22     str::FromStr,
     23 };
     24 
     25 use super::utils::InlineStr;
     26 
     27 /** Number of characters we use to represent currency names */
     28 // We use the same value than the exchange -1 because we use a byte for the len instead of 0 termination
     29 pub const CURRENCY_LEN: usize = 11;
     30 
     31 /** Maximum legal value for an amount, based on IEEE double */
     32 pub const MAX_VALUE: u64 = 2 << 52;
     33 
     34 /** The number of digits in a fraction part of an amount */
     35 pub const FRAC_BASE_NB_DIGITS: u8 = 8;
     36 
     37 /** The fraction part of an amount represents which fraction of the value */
     38 pub const FRAC_BASE: u32 = 10u32.pow(FRAC_BASE_NB_DIGITS as u32);
     39 
     40 #[derive(Clone, PartialEq, Eq, serde_with::DeserializeFromStr, serde_with::SerializeDisplay)]
     41 /// Inlined ISO 4217 currency string
     42 pub struct Currency(InlineStr<CURRENCY_LEN>);
     43 
     44 impl AsRef<str> for Currency {
     45     fn as_ref(&self) -> &str {
     46         self.0.as_ref()
     47     }
     48 }
     49 
     50 #[derive(Debug, thiserror::Error)]
     51 pub enum CurrencyErrorKind {
     52     #[error("contains illegal characters (only A-Z allowed)")]
     53     Invalid,
     54     #[error("too long (max {CURRENCY_LEN} chars)")]
     55     Big,
     56     #[error("is empty")]
     57     Empty,
     58 }
     59 
     60 #[derive(Debug, thiserror::Error)]
     61 #[error("currency code name '{currency}' {kind}")]
     62 pub struct ParseCurrencyError {
     63     currency: String,
     64     pub kind: CurrencyErrorKind,
     65 }
     66 
     67 impl FromStr for Currency {
     68     type Err = ParseCurrencyError;
     69 
     70     fn from_str(s: &str) -> Result<Self, Self::Err> {
     71         let bytes = s.as_bytes();
     72         let len = bytes.len();
     73         if bytes.is_empty() {
     74             Err(CurrencyErrorKind::Empty)
     75         } else if len > CURRENCY_LEN {
     76             Err(CurrencyErrorKind::Big)
     77         } else if !bytes.iter().all(|c| c.is_ascii_uppercase()) {
     78             Err(CurrencyErrorKind::Invalid)
     79         } else {
     80             Ok(Self(InlineStr::copy_from_slice(bytes)))
     81         }
     82         .map_err(|kind| ParseCurrencyError {
     83             currency: s.to_owned(),
     84             kind,
     85         })
     86     }
     87 }
     88 
     89 impl Debug for Currency {
     90     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
     91         Debug::fmt(&self.as_ref(), f)
     92     }
     93 }
     94 
     95 impl Display for Currency {
     96     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
     97         Display::fmt(&self.as_ref(), f)
     98     }
     99 }
    100 
    101 #[derive(
    102     Debug, Clone, Copy, PartialEq, Eq, serde_with::DeserializeFromStr, serde_with::SerializeDisplay,
    103 )]
    104 pub struct Decimal {
    105     /** Integer part */
    106     pub val: u64,
    107     /** Factional part, multiple of FRAC_BASE */
    108     pub frac: u32,
    109 }
    110 
    111 impl Decimal {
    112     pub fn new(val: u64, frac: u32) -> Self {
    113         Self { val, frac }
    114     }
    115 
    116     pub const fn max() -> Self {
    117         Self {
    118             val: MAX_VALUE,
    119             frac: FRAC_BASE - 1,
    120         }
    121     }
    122 
    123     pub const fn zero() -> Self {
    124         Self { val: 0, frac: 0 }
    125     }
    126 
    127     fn normalize(mut self) -> Option<Self> {
    128         self.val = self.val.checked_add((self.frac / FRAC_BASE) as u64)?;
    129         self.frac %= FRAC_BASE;
    130         if self.val > MAX_VALUE {
    131             return None;
    132         }
    133         Some(self)
    134     }
    135 
    136     pub fn try_add(mut self, rhs: &Self) -> Option<Self> {
    137         self.val = self.val.checked_add(rhs.val)?;
    138         self.frac = self
    139             .frac
    140             .checked_add(rhs.frac)
    141             .expect("amount fraction overflow should never happen with normalized amounts");
    142         self.normalize()
    143     }
    144 
    145     pub fn try_sub(mut self, rhs: &Self) -> Option<Self> {
    146         if rhs.frac > self.frac {
    147             self.val = self.val.checked_sub(1)?;
    148             self.frac += FRAC_BASE;
    149         }
    150         self.val = self.val.checked_sub(rhs.val)?;
    151         self.frac = self.frac.checked_sub(rhs.frac)?;
    152         self.normalize()
    153     }
    154 }
    155 
    156 #[derive(Debug, thiserror::Error)]
    157 pub enum DecimalErrKind {
    158     #[error("value overflow (must be <= {MAX_VALUE})")]
    159     Overflow,
    160     #[error("invalid value ({0})")]
    161     InvalidValue(ParseIntError),
    162     #[error("invalid fraction ({0})")]
    163     InvalidFraction(ParseIntError),
    164     #[error("fraction overflow (max {FRAC_BASE_NB_DIGITS} digits)")]
    165     FractionOverflow,
    166 }
    167 
    168 #[derive(Debug, thiserror::Error)]
    169 #[error("decimal '{decimal}' {kind}")]
    170 pub struct ParseDecimalErr {
    171     decimal: String,
    172     pub kind: DecimalErrKind,
    173 }
    174 
    175 impl FromStr for Decimal {
    176     type Err = ParseDecimalErr;
    177 
    178     fn from_str(s: &str) -> Result<Self, Self::Err> {
    179         let (value, fraction) = s.split_once('.').unwrap_or((s, ""));
    180 
    181         // TODO use try block when stable
    182         (|| {
    183             let value: u64 = value.parse().map_err(DecimalErrKind::InvalidValue)?;
    184             if value > MAX_VALUE {
    185                 return Err(DecimalErrKind::Overflow);
    186             }
    187 
    188             if fraction.len() > FRAC_BASE_NB_DIGITS as usize {
    189                 return Err(DecimalErrKind::FractionOverflow);
    190             }
    191             let fraction: u32 = if fraction.is_empty() {
    192                 0
    193             } else {
    194                 fraction
    195                     .parse::<u32>()
    196                     .map_err(DecimalErrKind::InvalidFraction)?
    197                     * 10u32.pow(FRAC_BASE_NB_DIGITS as u32 - fraction.len() as u32)
    198             };
    199             Ok(Self {
    200                 val: value,
    201                 frac: fraction,
    202             })
    203         })()
    204         .map_err(|kind| ParseDecimalErr {
    205             decimal: s.to_owned(),
    206             kind,
    207         })
    208     }
    209 }
    210 
    211 impl Display for Decimal {
    212     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
    213         if self.frac == 0 {
    214             f.write_fmt(format_args!("{}", self.val))
    215         } else {
    216             let num = format!("{:08}", self.frac);
    217             f.write_fmt(format_args!("{}.{}", self.val, num.trim_end_matches('0')))
    218         }
    219     }
    220 }
    221 
    222 #[track_caller]
    223 pub fn decimal(decimal: impl AsRef<str>) -> Decimal {
    224     decimal.as_ref().parse().expect("Invalid decimal constant")
    225 }
    226 
    227 /// <https://docs.taler.net/core/api-common.html#tsref-type-Amount>
    228 #[derive(
    229     Debug, Clone, PartialEq, Eq, serde_with::DeserializeFromStr, serde_with::SerializeDisplay,
    230 )]
    231 pub struct Amount {
    232     pub currency: Currency,
    233     pub val: u64,
    234     pub frac: u32,
    235 }
    236 
    237 impl Amount {
    238     pub fn new_decimal(currency: &Currency, decimal: Decimal) -> Self {
    239         (currency.clone(), decimal).into()
    240     }
    241 
    242     pub fn new(currency: &Currency, val: u64, frac: u32) -> Self {
    243         Self::new_decimal(currency, Decimal { val, frac })
    244     }
    245 
    246     pub fn max(currency: &Currency) -> Self {
    247         Self::new_decimal(currency, Decimal::max())
    248     }
    249 
    250     pub fn zero(currency: &Currency) -> Self {
    251         Self::new_decimal(currency, Decimal::zero())
    252     }
    253 
    254     pub const fn decimal(&self) -> Decimal {
    255         Decimal {
    256             val: self.val,
    257             frac: self.frac,
    258         }
    259     }
    260 
    261     pub fn normalize(self) -> Option<Self> {
    262         let decimal = self.decimal().normalize()?;
    263         Some((self.currency, decimal).into())
    264     }
    265 
    266     pub fn try_add(self, rhs: &Self) -> Option<Self> {
    267         assert_eq!(self.currency, rhs.currency);
    268         let decimal = self.decimal().try_add(&rhs.decimal())?.normalize()?;
    269         Some((self.currency, decimal).into())
    270     }
    271 }
    272 
    273 impl From<(Currency, Decimal)> for Amount {
    274     fn from((currency, decimal): (Currency, Decimal)) -> Self {
    275         Self {
    276             currency,
    277             val: decimal.val,
    278             frac: decimal.frac,
    279         }
    280     }
    281 }
    282 
    283 #[track_caller]
    284 pub fn amount(amount: impl AsRef<str>) -> Amount {
    285     amount.as_ref().parse().expect("Invalid amount constant")
    286 }
    287 
    288 #[derive(Debug, thiserror::Error)]
    289 pub enum AmountErrKind {
    290     #[error("invalid format")]
    291     Format,
    292     #[error("currency {0}")]
    293     Currency(#[from] CurrencyErrorKind),
    294     #[error(transparent)]
    295     Decimal(#[from] DecimalErrKind),
    296 }
    297 
    298 #[derive(Debug, thiserror::Error)]
    299 #[error("amount '{amount}' {kind}")]
    300 pub struct ParseAmountErr {
    301     amount: String,
    302     pub kind: AmountErrKind,
    303 }
    304 
    305 impl FromStr for Amount {
    306     type Err = ParseAmountErr;
    307 
    308     fn from_str(s: &str) -> Result<Self, Self::Err> {
    309         // TODO use try block when stable
    310         (|| {
    311             let (currency, amount) = s.trim().split_once(':').ok_or(AmountErrKind::Format)?;
    312             let currency = currency.parse().map_err(|e: ParseCurrencyError| e.kind)?;
    313             let decimal = amount.parse().map_err(|e: ParseDecimalErr| e.kind)?;
    314             Ok((currency, decimal).into())
    315         })()
    316         .map_err(|kind| ParseAmountErr {
    317             amount: s.to_owned(),
    318             kind,
    319         })
    320     }
    321 }
    322 
    323 impl Display for Amount {
    324     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
    325         f.write_fmt(format_args!("{}:{}", self.currency, self.decimal()))
    326     }
    327 }
    328 
    329 #[test]
    330 fn test_amount_parse() {
    331     const TALER_AMOUNT_FRAC_BASE: u32 = 100000000;
    332     // https://git.taler.net/exchange.git/tree/src/util/test_amount.c
    333 
    334     const INVALID_AMOUNTS: [&str; 6] = [
    335         "EUR:4a",                                                                     // non-numeric,
    336         "EUR:4.4a",                                                                   // non-numeric
    337         "EUR:4.a4",                                                                   // non-numeric
    338         ":4.a4",                                                                      // no currency
    339         "EUR:4.123456789", // precision to high
    340         "EUR:1234567890123456789012345678901234567890123456789012345678901234567890", // value to big
    341     ];
    342 
    343     for str in INVALID_AMOUNTS {
    344         let amount = Amount::from_str(str);
    345         assert!(amount.is_err(), "invalid {} got {:?}", str, &amount);
    346     }
    347 
    348     let eur: Currency = "EUR".parse().unwrap();
    349     let local: Currency = "LOCAL".parse().unwrap();
    350     let valid_amounts: Vec<(&str, &str, Amount)> = vec![
    351         ("EUR:4", "EUR:4", Amount::new(&eur, 4, 0)), // without fraction
    352         (
    353             "EUR:0.02",
    354             "EUR:0.02",
    355             Amount::new(&eur, 0, TALER_AMOUNT_FRAC_BASE / 100 * 2),
    356         ), // leading zero fraction
    357         (
    358             " EUR:4.12",
    359             "EUR:4.12",
    360             Amount::new(&eur, 4, TALER_AMOUNT_FRAC_BASE / 100 * 12),
    361         ), // leading space and fraction
    362         (
    363             " LOCAL:4444.1000",
    364             "LOCAL:4444.1",
    365             Amount::new(&local, 4444, TALER_AMOUNT_FRAC_BASE / 10),
    366         ), // local currency
    367     ];
    368     for (raw, expected, goal) in valid_amounts {
    369         let amount = Amount::from_str(raw);
    370         assert!(amount.is_ok(), "Valid {} got {:?}", raw, amount);
    371         assert_eq!(
    372             *amount.as_ref().unwrap(),
    373             goal,
    374             "Expected {:?} got {:?} for {}",
    375             goal,
    376             amount,
    377             raw
    378         );
    379         let amount = amount.unwrap();
    380         let str = amount.to_string();
    381         assert_eq!(str, expected);
    382         assert_eq!(amount, Amount::from_str(&str).unwrap(), "{str}");
    383     }
    384 }
    385 
    386 #[test]
    387 fn test_amount_add() {
    388     let eur: Currency = "EUR".parse().unwrap();
    389     assert_eq!(
    390         Amount::max(&eur).try_add(&Amount::zero(&eur)),
    391         Some(Amount::max(&eur))
    392     );
    393     assert_eq!(
    394         Amount::zero(&eur).try_add(&Amount::zero(&eur)),
    395         Some(Amount::zero(&eur))
    396     );
    397     assert_eq!(
    398         amount("EUR:6.41").try_add(&amount("EUR:4.69")),
    399         Some(amount("EUR:11.1"))
    400     );
    401     assert_eq!(
    402         amount(format!("EUR:{MAX_VALUE}")).try_add(&amount("EUR:0.99999999")),
    403         Some(Amount::max(&eur))
    404     );
    405 
    406     assert_eq!(
    407         amount(format!("EUR:{}", MAX_VALUE - 5)).try_add(&amount("EUR:6")),
    408         None
    409     );
    410     assert_eq!(
    411         Amount::new(&eur, u64::MAX, 0).try_add(&amount("EUR:1")),
    412         None
    413     );
    414     assert_eq!(
    415         amount(format!("EUR:{}.{}", MAX_VALUE - 5, FRAC_BASE - 1))
    416             .try_add(&amount("EUR:5.00000002")),
    417         None
    418     );
    419 }
    420 
    421 #[test]
    422 fn test_amount_normalize() {
    423     let eur: Currency = "EUR".parse().unwrap();
    424     assert_eq!(
    425         Amount::new(&eur, 4, 2 * FRAC_BASE).normalize(),
    426         Some(amount("EUR:6"))
    427     );
    428     assert_eq!(
    429         Amount::new(&eur, 4, 2 * FRAC_BASE + 1).normalize(),
    430         Some(amount("EUR:6.00000001"))
    431     );
    432     assert_eq!(
    433         Amount::new(&eur, MAX_VALUE, FRAC_BASE - 1).normalize(),
    434         Some(Amount::new(&eur, MAX_VALUE, FRAC_BASE - 1))
    435     );
    436     assert_eq!(Amount::new(&eur, u64::MAX, FRAC_BASE).normalize(), None);
    437     assert_eq!(Amount::new(&eur, MAX_VALUE, FRAC_BASE).normalize(), None);
    438 
    439     for amount in [Amount::max(&eur), Amount::zero(&eur)] {
    440         assert_eq!(amount.clone().normalize(), Some(amount))
    441     }
    442 }