taler-rust

GNU Taler code in Rust. Largely core banking integrations.
Log | Files | Refs | Submodules | README | LICENSE

amount.rs (16637B)


      1 /*
      2   This file is part of TALER
      3   Copyright (C) 2024, 2025, 2026 Taler Systems SA
      4 
      5   TALER is free software; you can redistribute it and/or modify it under the
      6   terms of the GNU Affero General Public License as published by the Free Software
      7   Foundation; either version 3, or (at your option) any later version.
      8 
      9   TALER is distributed in the hope that it will be useful, but WITHOUT ANY
     10   WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
     11   A PARTICULAR PURPOSE.  See the GNU Affero General Public License for more details.
     12 
     13   You should have received a copy of the GNU Affero General Public License along with
     14   TALER; see the file COPYING.  If not, see <http://www.gnu.org/licenses/>
     15 */
     16 
     17 //! Type for the Taler Amount <https://docs.taler.net/core/api-common.html#tsref-type-Amount>
     18 
     19 use std::{
     20     fmt::{Debug, Display},
     21     num::ParseIntError,
     22     str::FromStr,
     23 };
     24 
     25 use compact_str::format_compact;
     26 
     27 use super::utils::InlineStr;
     28 
     29 /** Number of characters we use to represent currency names */
     30 // We use the same value than the exchange -1 because we use a byte for the len instead of 0 termination
     31 pub const CURRENCY_LEN: usize = 11;
     32 
     33 /** Maximum legal value for an amount, based on IEEE double */
     34 pub const MAX_VALUE: u64 = 2 << 51;
     35 
     36 /** The number of digits in a fraction part of an amount */
     37 pub const FRAC_BASE_NB_DIGITS: u8 = 8;
     38 
     39 /** The fraction part of an amount represents which fraction of the value */
     40 pub const FRAC_BASE: u32 = 10u32.pow(FRAC_BASE_NB_DIGITS as u32);
     41 
     42 const CENT_FRACTION: u32 = 10u32.pow((FRAC_BASE_NB_DIGITS - 2) as u32);
     43 
     44 #[derive(
     45     Clone, Copy, PartialEq, Eq, serde_with::DeserializeFromStr, serde_with::SerializeDisplay,
     46 )]
     47 /// Inlined ISO 4217 currency string
     48 pub struct Currency(InlineStr<CURRENCY_LEN>);
     49 
     50 impl AsRef<str> for Currency {
     51     fn as_ref(&self) -> &str {
     52         self.0.as_ref()
     53     }
     54 }
     55 
     56 #[derive(Debug, thiserror::Error)]
     57 pub enum CurrencyErrorKind {
     58     #[error("contains illegal characters (only A-Z allowed)")]
     59     Invalid,
     60     #[error("too long (max {CURRENCY_LEN} chars)")]
     61     Big,
     62     #[error("is empty")]
     63     Empty,
     64 }
     65 
     66 #[derive(Debug, thiserror::Error)]
     67 #[error("currency code name '{currency}' {kind}")]
     68 pub struct ParseCurrencyError {
     69     currency: String,
     70     pub kind: CurrencyErrorKind,
     71 }
     72 
     73 impl Currency {
     74     pub const TEST: Self = Self::const_parse("TEST");
     75     pub const KUDOS: Self = Self::const_parse("KUDOS");
     76     pub const EUR: Self = Self::const_parse("EUR");
     77     pub const CHF: Self = Self::const_parse("CHF");
     78     pub const HUF: Self = Self::const_parse("HUF");
     79 
     80     pub const fn const_parse(s: &str) -> Currency {
     81         let bytes = s.as_bytes();
     82         let len = bytes.len();
     83 
     84         if bytes.is_empty() {
     85             panic!("empty")
     86         } else if len > CURRENCY_LEN {
     87             panic!("too big")
     88         }
     89         let mut i = 0;
     90         while i < bytes.len() {
     91             if !bytes[i].is_ascii_uppercase() {
     92                 panic!("invalid")
     93             }
     94             i += 1;
     95         }
     96         Self(InlineStr::copy_from_slice(bytes))
     97     }
     98 }
     99 
    100 impl FromStr for Currency {
    101     type Err = ParseCurrencyError;
    102 
    103     fn from_str(s: &str) -> Result<Self, Self::Err> {
    104         let bytes = s.as_bytes();
    105         let len = bytes.len();
    106         if bytes.is_empty() {
    107             Err(CurrencyErrorKind::Empty)
    108         } else if len > CURRENCY_LEN {
    109             Err(CurrencyErrorKind::Big)
    110         } else if !bytes.iter().all(|c| c.is_ascii_uppercase()) {
    111             Err(CurrencyErrorKind::Invalid)
    112         } else {
    113             Ok(Self(InlineStr::copy_from_slice(bytes)))
    114         }
    115         .map_err(|kind| ParseCurrencyError {
    116             currency: s.to_owned(),
    117             kind,
    118         })
    119     }
    120 }
    121 
    122 impl Debug for Currency {
    123     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
    124         Debug::fmt(&self.as_ref(), f)
    125     }
    126 }
    127 
    128 impl Display for Currency {
    129     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
    130         Display::fmt(&self.as_ref(), f)
    131     }
    132 }
    133 
    134 #[derive(sqlx::Type)]
    135 #[sqlx(type_name = "taler_amount")]
    136 struct PgTalerAmount {
    137     pub val: i64,
    138     pub frac: i32,
    139 }
    140 
    141 #[derive(
    142     Clone,
    143     Copy,
    144     PartialEq,
    145     Eq,
    146     PartialOrd,
    147     Ord,
    148     serde_with::DeserializeFromStr,
    149     serde_with::SerializeDisplay,
    150 )]
    151 pub struct Decimal {
    152     /** Integer part */
    153     pub val: u64,
    154     /** Factional part, multiple of FRAC_BASE */
    155     pub frac: u32,
    156 }
    157 
    158 impl Decimal {
    159     pub const fn new(val: u64, frac: u32) -> Self {
    160         Self { val, frac }
    161     }
    162 
    163     pub const ZERO: Self = Self::new(0, 0);
    164     pub const MAX: Self = Self::new(MAX_VALUE, FRAC_BASE - 1);
    165 
    166     const fn normalize(mut self) -> Option<Self> {
    167         let Some(val) = self.val.checked_add((self.frac / FRAC_BASE) as u64) else {
    168             return None;
    169         };
    170         self.val = val;
    171         self.frac %= FRAC_BASE;
    172         if self.val > MAX_VALUE {
    173             return None;
    174         }
    175         Some(self)
    176     }
    177 
    178     pub fn try_add(mut self, rhs: &Self) -> Option<Self> {
    179         self.val = self.val.checked_add(rhs.val)?;
    180         self.frac = self
    181             .frac
    182             .checked_add(rhs.frac)
    183             .expect("amount fraction overflow should never happen with normalized amounts");
    184         self.normalize()
    185     }
    186 
    187     pub fn try_sub(mut self, rhs: &Self) -> Option<Self> {
    188         if rhs.frac > self.frac {
    189             self.val = self.val.checked_sub(1)?;
    190             self.frac += FRAC_BASE;
    191         }
    192         self.val = self.val.checked_sub(rhs.val)?;
    193         self.frac = self.frac.checked_sub(rhs.frac)?;
    194         self.normalize()
    195     }
    196 
    197     pub const fn to_amount(self, currency: &Currency) -> Amount {
    198         Amount::new_decimal(currency, self)
    199     }
    200 }
    201 
    202 #[derive(Debug, thiserror::Error)]
    203 pub enum DecimalErrKind {
    204     #[error("value overflow (must be <= {MAX_VALUE})")]
    205     Overflow,
    206     #[error("invalid value ({0})")]
    207     InvalidValue(ParseIntError),
    208     #[error("invalid fraction ({0})")]
    209     InvalidFraction(ParseIntError),
    210     #[error("fraction overflow (max {FRAC_BASE_NB_DIGITS} digits)")]
    211     FractionOverflow,
    212 }
    213 
    214 #[derive(Debug, thiserror::Error)]
    215 #[error("decimal '{decimal}' {kind}")]
    216 pub struct ParseDecimalErr {
    217     decimal: String,
    218     pub kind: DecimalErrKind,
    219 }
    220 
    221 impl FromStr for Decimal {
    222     type Err = ParseDecimalErr;
    223 
    224     fn from_str(s: &str) -> Result<Self, Self::Err> {
    225         let (value, fraction) = s.split_once('.').unwrap_or((s, ""));
    226 
    227         // TODO use try block when stable
    228         (|| {
    229             let value: u64 = value.parse().map_err(DecimalErrKind::InvalidValue)?;
    230             if value > MAX_VALUE {
    231                 return Err(DecimalErrKind::Overflow);
    232             }
    233 
    234             if fraction.len() > FRAC_BASE_NB_DIGITS as usize {
    235                 return Err(DecimalErrKind::FractionOverflow);
    236             }
    237             let fraction: u32 = if fraction.is_empty() {
    238                 0
    239             } else {
    240                 fraction
    241                     .parse::<u32>()
    242                     .map_err(DecimalErrKind::InvalidFraction)?
    243                     * 10u32.pow(FRAC_BASE_NB_DIGITS as u32 - fraction.len() as u32)
    244             };
    245             Ok(Self {
    246                 val: value,
    247                 frac: fraction,
    248             })
    249         })()
    250         .map_err(|kind| ParseDecimalErr {
    251             decimal: s.to_owned(),
    252             kind,
    253         })
    254     }
    255 }
    256 
    257 impl Display for Decimal {
    258     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
    259         if self.frac == 0 {
    260             f.write_fmt(format_args!("{}", self.val))
    261         } else {
    262             let num = format_compact!("{:08}", self.frac);
    263             f.write_fmt(format_args!("{}.{}", self.val, num.trim_end_matches('0')))
    264         }
    265     }
    266 }
    267 
    268 impl Debug for Decimal {
    269     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
    270         Display::fmt(&self, f)
    271     }
    272 }
    273 
    274 impl sqlx::Type<sqlx::Postgres> for Decimal {
    275     fn type_info() -> sqlx::postgres::PgTypeInfo {
    276         PgTalerAmount::type_info()
    277     }
    278 }
    279 
    280 impl<'q> sqlx::Encode<'q, sqlx::Postgres> for Decimal {
    281     fn encode_by_ref(
    282         &self,
    283         buf: &mut sqlx::postgres::PgArgumentBuffer,
    284     ) -> Result<sqlx::encode::IsNull, sqlx::error::BoxDynError> {
    285         PgTalerAmount {
    286             val: self.val as i64,
    287             frac: self.frac as i32,
    288         }
    289         .encode_by_ref(buf)
    290     }
    291 }
    292 
    293 impl<'r> sqlx::Decode<'r, sqlx::Postgres> for Decimal {
    294     fn decode(value: sqlx::postgres::PgValueRef<'r>) -> Result<Self, sqlx::error::BoxDynError> {
    295         let pg = PgTalerAmount::decode(value)?;
    296         Ok(Self {
    297             val: pg.val as u64,
    298             frac: pg.frac as u32,
    299         })
    300     }
    301 }
    302 
    303 #[track_caller]
    304 pub fn decimal(decimal: impl AsRef<str>) -> Decimal {
    305     decimal.as_ref().parse().expect("Invalid decimal constant")
    306 }
    307 
    308 /// <https://docs.taler.net/core/api-common.html#tsref-type-Amount>
    309 #[derive(
    310     Clone, Copy, PartialEq, Eq, serde_with::DeserializeFromStr, serde_with::SerializeDisplay,
    311 )]
    312 pub struct Amount {
    313     pub currency: Currency,
    314     pub val: u64,
    315     pub frac: u32,
    316 }
    317 
    318 impl Amount {
    319     pub const fn new_decimal(currency: &Currency, decimal: Decimal) -> Self {
    320         Self {
    321             currency: *currency,
    322             val: decimal.val,
    323             frac: decimal.frac,
    324         }
    325     }
    326 
    327     pub const fn new(currency: &Currency, val: u64, frac: u32) -> Self {
    328         Self::new_decimal(currency, Decimal { val, frac })
    329     }
    330 
    331     pub const fn max(currency: &Currency) -> Self {
    332         Self::new_decimal(currency, Decimal::MAX)
    333     }
    334 
    335     pub const fn zero(currency: &Currency) -> Self {
    336         Self::new_decimal(currency, Decimal::ZERO)
    337     }
    338 
    339     pub fn is_zero(&self) -> bool {
    340         self.decimal() == Decimal::ZERO
    341     }
    342 
    343     /* Check is amount has fractional amount < 0.01 */
    344     pub const fn is_sub_cent(&self) -> bool {
    345         !self.frac.is_multiple_of(CENT_FRACTION)
    346     }
    347 
    348     pub const fn decimal(&self) -> Decimal {
    349         Decimal {
    350             val: self.val,
    351             frac: self.frac,
    352         }
    353     }
    354 
    355     pub fn normalize(self) -> Option<Self> {
    356         let decimal = self.decimal().normalize()?;
    357         Some((self.currency, decimal).into())
    358     }
    359 
    360     pub fn try_add(self, rhs: &Self) -> Option<Self> {
    361         assert_eq!(self.currency, rhs.currency);
    362         let decimal = self.decimal().try_add(&rhs.decimal())?.normalize()?;
    363         Some((self.currency, decimal).into())
    364     }
    365 
    366     pub fn try_sub(self, rhs: &Self) -> Option<Self> {
    367         assert_eq!(self.currency, rhs.currency);
    368         let decimal = self.decimal().try_sub(&rhs.decimal())?.normalize()?;
    369         Some((self.currency, decimal).into())
    370     }
    371 }
    372 
    373 impl From<(Currency, Decimal)> for Amount {
    374     fn from((currency, decimal): (Currency, Decimal)) -> Self {
    375         Self::new_decimal(&currency, decimal)
    376     }
    377 }
    378 
    379 #[track_caller]
    380 pub fn amount(amount: impl AsRef<str>) -> Amount {
    381     amount.as_ref().parse().expect("Invalid amount constant")
    382 }
    383 
    384 #[derive(Debug, thiserror::Error)]
    385 pub enum AmountErrKind {
    386     #[error("invalid format")]
    387     Format,
    388     #[error("currency {0}")]
    389     Currency(#[from] CurrencyErrorKind),
    390     #[error(transparent)]
    391     Decimal(#[from] DecimalErrKind),
    392 }
    393 
    394 #[derive(Debug, thiserror::Error)]
    395 #[error("amount '{amount}' {kind}")]
    396 pub struct ParseAmountErr {
    397     amount: String,
    398     pub kind: AmountErrKind,
    399 }
    400 
    401 impl FromStr for Amount {
    402     type Err = ParseAmountErr;
    403 
    404     fn from_str(s: &str) -> Result<Self, Self::Err> {
    405         // TODO use try block when stable
    406         (|| {
    407             let (currency, amount) = s.trim().split_once(':').ok_or(AmountErrKind::Format)?;
    408             let currency = currency.parse().map_err(|e: ParseCurrencyError| e.kind)?;
    409             let decimal = amount.parse().map_err(|e: ParseDecimalErr| e.kind)?;
    410             Ok((currency, decimal).into())
    411         })()
    412         .map_err(|kind| ParseAmountErr {
    413             amount: s.to_owned(),
    414             kind,
    415         })
    416     }
    417 }
    418 
    419 impl Display for Amount {
    420     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
    421         f.write_fmt(format_args!("{}:{}", self.currency, self.decimal()))
    422     }
    423 }
    424 
    425 impl Debug for Amount {
    426     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
    427         Display::fmt(&self, f)
    428     }
    429 }
    430 
    431 impl sqlx::Type<sqlx::Postgres> for Amount {
    432     fn type_info() -> sqlx::postgres::PgTypeInfo {
    433         PgTalerAmount::type_info()
    434     }
    435 }
    436 
    437 impl<'q> sqlx::Encode<'q, sqlx::Postgres> for Amount {
    438     fn encode_by_ref(
    439         &self,
    440         buf: &mut sqlx::postgres::PgArgumentBuffer,
    441     ) -> Result<sqlx::encode::IsNull, sqlx::error::BoxDynError> {
    442         self.decimal().encode_by_ref(buf)
    443     }
    444 }
    445 
    446 #[test]
    447 fn constants() {
    448     assert_eq!(format!("{}", Amount::zero(&Currency::KUDOS)), "KUDOS:0");
    449     assert_eq!(
    450         format!("{}", Amount::max(&Currency::KUDOS)),
    451         "KUDOS:4503599627370496.99999999"
    452     );
    453 }
    454 
    455 #[test]
    456 fn test_amount_parse() {
    457     const TALER_AMOUNT_FRAC_BASE: u32 = 100000000;
    458     // https://git.taler.net/exchange.git/tree/src/util/test_amount.c
    459 
    460     const INVALID_AMOUNTS: [&str; 6] = [
    461         "EUR:4a",                                                                     // non-numeric,
    462         "EUR:4.4a",                                                                   // non-numeric
    463         "EUR:4.a4",                                                                   // non-numeric
    464         ":4.a4",                                                                      // no currency
    465         "EUR:4.123456789", // precision to high
    466         "EUR:1234567890123456789012345678901234567890123456789012345678901234567890", // value to big
    467     ];
    468 
    469     for str in INVALID_AMOUNTS {
    470         let amount = Amount::from_str(str);
    471         assert!(amount.is_err(), "invalid {} got {:?}", str, &amount);
    472     }
    473 
    474     let eur: Currency = Currency::EUR;
    475     let local: Currency = Currency::CHF;
    476     let valid_amounts: Vec<(&str, &str, Amount)> = vec![
    477         ("EUR:4", "EUR:4", Amount::new(&eur, 4, 0)), // without fraction
    478         (
    479             "EUR:0.02",
    480             "EUR:0.02",
    481             Amount::new(&eur, 0, TALER_AMOUNT_FRAC_BASE / 100 * 2),
    482         ), // leading zero fraction
    483         (
    484             " EUR:4.12",
    485             "EUR:4.12",
    486             Amount::new(&eur, 4, TALER_AMOUNT_FRAC_BASE / 100 * 12),
    487         ), // leading space and fraction
    488         (
    489             " CHF:4444.1000",
    490             "CHF:4444.1",
    491             Amount::new(&local, 4444, TALER_AMOUNT_FRAC_BASE / 10),
    492         ), // local currency
    493     ];
    494     for (raw, expected, goal) in valid_amounts {
    495         let amount = Amount::from_str(raw);
    496         assert!(amount.is_ok(), "Valid {} got {:?}", raw, amount);
    497         assert_eq!(
    498             *amount.as_ref().unwrap(),
    499             goal,
    500             "Expected {:?} got {:?} for {}",
    501             goal,
    502             amount,
    503             raw
    504         );
    505         let amount = amount.unwrap();
    506         let str = amount.to_string();
    507         assert_eq!(str, expected);
    508         assert_eq!(amount, Amount::from_str(&str).unwrap(), "{str}");
    509     }
    510 }
    511 
    512 #[test]
    513 fn test_amount_add() {
    514     let eur: Currency = Currency::EUR;
    515     assert_eq!(
    516         Amount::max(&eur).try_add(&Amount::zero(&eur)),
    517         Some(Amount::max(&eur))
    518     );
    519     assert_eq!(
    520         Amount::zero(&eur).try_add(&Amount::zero(&eur)),
    521         Some(Amount::zero(&eur))
    522     );
    523     assert_eq!(
    524         amount("EUR:6.41").try_add(&amount("EUR:4.69")),
    525         Some(amount("EUR:11.1"))
    526     );
    527     assert_eq!(
    528         amount(format!("EUR:{MAX_VALUE}")).try_add(&amount("EUR:0.99999999")),
    529         Some(Amount::max(&eur))
    530     );
    531 
    532     assert_eq!(
    533         amount(format!("EUR:{}", MAX_VALUE - 5)).try_add(&amount("EUR:6")),
    534         None
    535     );
    536     assert_eq!(
    537         Amount::new(&eur, u64::MAX, 0).try_add(&amount("EUR:1")),
    538         None
    539     );
    540     assert_eq!(
    541         amount(format!("EUR:{}.{}", MAX_VALUE - 5, FRAC_BASE - 1))
    542             .try_add(&amount("EUR:5.00000002")),
    543         None
    544     );
    545 }
    546 
    547 #[test]
    548 fn test_amount_normalize() {
    549     let eur: Currency = "EUR".parse().unwrap();
    550     assert_eq!(
    551         Amount::new(&eur, 4, 2 * FRAC_BASE).normalize(),
    552         Some(amount("EUR:6"))
    553     );
    554     assert_eq!(
    555         Amount::new(&eur, 4, 2 * FRAC_BASE + 1).normalize(),
    556         Some(amount("EUR:6.00000001"))
    557     );
    558     assert_eq!(
    559         Amount::new(&eur, MAX_VALUE, FRAC_BASE - 1).normalize(),
    560         Some(Amount::new(&eur, MAX_VALUE, FRAC_BASE - 1))
    561     );
    562     assert_eq!(Amount::new(&eur, u64::MAX, FRAC_BASE).normalize(), None);
    563     assert_eq!(Amount::new(&eur, MAX_VALUE, FRAC_BASE).normalize(), None);
    564 
    565     for amount in [Amount::max(&eur), Amount::zero(&eur)] {
    566         assert_eq!(amount.normalize(), Some(amount))
    567     }
    568 }