taler-rust

GNU Taler code in Rust. Largely core banking integrations.
Log | Files | Refs | Submodules | README | LICENSE

amount.rs (14731B)


      1 /*
      2   This file is part of TALER
      3   Copyright (C) 2024-2025 Taler Systems SA
      4 
      5   TALER is free software; you can redistribute it and/or modify it under the
      6   terms of the GNU Affero General Public License as published by the Free Software
      7   Foundation; either version 3, or (at your option) any later version.
      8 
      9   TALER is distributed in the hope that it will be useful, but WITHOUT ANY
     10   WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
     11   A PARTICULAR PURPOSE.  See the GNU Affero General Public License for more details.
     12 
     13   You should have received a copy of the GNU Affero General Public License along with
     14   TALER; see the file COPYING.  If not, see <http://www.gnu.org/licenses/>
     15 */
     16 
     17 //! Type for the Taler Amount <https://docs.taler.net/core/api-common.html#tsref-type-Amount>
     18 
     19 use std::{
     20     fmt::{Debug, Display},
     21     num::ParseIntError,
     22     str::FromStr,
     23 };
     24 
     25 use super::utils::InlineStr;
     26 
     27 /** Number of characters we use to represent currency names */
     28 // We use the same value than the exchange -1 because we use a byte for the len instead of 0 termination
     29 pub const CURRENCY_LEN: usize = 11;
     30 
     31 /** Maximum legal value for an amount, based on IEEE double */
     32 pub const MAX_VALUE: u64 = 2 << 52;
     33 
     34 /** The number of digits in a fraction part of an amount */
     35 pub const FRAC_BASE_NB_DIGITS: u8 = 8;
     36 
     37 /** The fraction part of an amount represents which fraction of the value */
     38 pub const FRAC_BASE: u32 = 10u32.pow(FRAC_BASE_NB_DIGITS as u32);
     39 
     40 #[derive(Clone, PartialEq, Eq, serde_with::DeserializeFromStr, serde_with::SerializeDisplay)]
     41 /// Inlined ISO 4217 currency string
     42 pub struct Currency(InlineStr<CURRENCY_LEN>);
     43 
     44 impl AsRef<str> for Currency {
     45     fn as_ref(&self) -> &str {
     46         self.0.as_ref()
     47     }
     48 }
     49 
/// Reason a currency code was rejected by `Currency::from_str`.
#[derive(Debug, thiserror::Error)]
pub enum CurrencyErrorKind {
    /// At least one byte is not an uppercase ASCII letter.
    #[error("contains illegal characters (only A-Z allowed)")]
    Invalid,
    /// The input is longer than `CURRENCY_LEN` bytes.
    #[error("too long (max {CURRENCY_LEN} chars)")]
    Big,
    /// The input is the empty string.
    #[error("is empty")]
    Empty,
}

/// Error returned by `Currency::from_str`: the rejected input and the reason.
#[derive(Debug, thiserror::Error)]
#[error("currency code name '{currency}' {kind}")]
pub struct ParseCurrencyError {
    // The original input, kept only for the error message.
    currency: String,
    pub kind: CurrencyErrorKind,
}
     66 
     67 impl FromStr for Currency {
     68     type Err = ParseCurrencyError;
     69 
     70     fn from_str(s: &str) -> Result<Self, Self::Err> {
     71         let bytes = s.as_bytes();
     72         let len = bytes.len();
     73         if bytes.is_empty() {
     74             Err(CurrencyErrorKind::Empty)
     75         } else if len > CURRENCY_LEN {
     76             Err(CurrencyErrorKind::Big)
     77         } else if !bytes.iter().all(|c| c.is_ascii_uppercase()) {
     78             Err(CurrencyErrorKind::Invalid)
     79         } else {
     80             Ok(Self(InlineStr::copy_from_slice(bytes)))
     81         }
     82         .map_err(|kind| ParseCurrencyError {
     83             currency: s.to_owned(),
     84             kind,
     85         })
     86     }
     87 }
     88 
     89 impl Debug for Currency {
     90     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
     91         Debug::fmt(&self.as_ref(), f)
     92     }
     93 }
     94 
     95 impl Display for Currency {
     96     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
     97         Display::fmt(&self.as_ref(), f)
     98     }
     99 }
    100 
/// Rust mirror of the PostgreSQL composite type `taler_amount`.
///
/// Used as the database representation by the sqlx `Type`/`Encode`/`Decode`
/// impls of `Decimal` and the `Type`/`Encode` impls of `Amount`.
#[derive(sqlx::Type)]
#[sqlx(type_name = "taler_amount")]
struct PgTalerAmount {
    // Holds `Decimal::val`, converted from u64 by the sqlx impls below.
    pub val: i64,
    // Holds `Decimal::frac`, converted from u32 by the sqlx impls below.
    pub frac: i32,
}
    107 
/// Unsigned fixed-point number with `FRAC_BASE_NB_DIGITS` fractional digits
/// and no currency.
///
/// Serde (de)serializes it as its `Display`/`FromStr` string form (e.g. `"4.12"`).
#[derive(
    Debug, Clone, Copy, PartialEq, Eq, serde_with::DeserializeFromStr, serde_with::SerializeDisplay,
)]
pub struct Decimal {
    /** Integer part */
    pub val: u64,
    /** Fractional part, in units of 1/FRAC_BASE (normalized values are < FRAC_BASE) */
    pub frac: u32,
}
    117 
    118 impl Decimal {
    119     pub fn new(val: u64, frac: u32) -> Self {
    120         Self { val, frac }
    121     }
    122 
    123     pub const fn max() -> Self {
    124         Self {
    125             val: MAX_VALUE,
    126             frac: FRAC_BASE - 1,
    127         }
    128     }
    129 
    130     pub const fn zero() -> Self {
    131         Self { val: 0, frac: 0 }
    132     }
    133 
    134     fn normalize(mut self) -> Option<Self> {
    135         self.val = self.val.checked_add((self.frac / FRAC_BASE) as u64)?;
    136         self.frac %= FRAC_BASE;
    137         if self.val > MAX_VALUE {
    138             return None;
    139         }
    140         Some(self)
    141     }
    142 
    143     pub fn try_add(mut self, rhs: &Self) -> Option<Self> {
    144         self.val = self.val.checked_add(rhs.val)?;
    145         self.frac = self
    146             .frac
    147             .checked_add(rhs.frac)
    148             .expect("amount fraction overflow should never happen with normalized amounts");
    149         self.normalize()
    150     }
    151 
    152     pub fn try_sub(mut self, rhs: &Self) -> Option<Self> {
    153         if rhs.frac > self.frac {
    154             self.val = self.val.checked_sub(1)?;
    155             self.frac += FRAC_BASE;
    156         }
    157         self.val = self.val.checked_sub(rhs.val)?;
    158         self.frac = self.frac.checked_sub(rhs.frac)?;
    159         self.normalize()
    160     }
    161 
    162     pub fn to_amount(self, currency: &Currency) -> Amount {
    163         Amount::new_decimal(currency, self)
    164     }
    165 }
    166 
/// Reason a decimal string was rejected by `Decimal::from_str`.
#[derive(Debug, thiserror::Error)]
pub enum DecimalErrKind {
    /// The integer part is greater than `MAX_VALUE`.
    #[error("value overflow (must be <= {MAX_VALUE})")]
    Overflow,
    /// The integer part could not be parsed as an integer.
    #[error("invalid value ({0})")]
    InvalidValue(ParseIntError),
    /// The fractional part could not be parsed as an integer.
    #[error("invalid fraction ({0})")]
    InvalidFraction(ParseIntError),
    /// The fractional part is longer than `FRAC_BASE_NB_DIGITS` bytes.
    #[error("fraction overflow (max {FRAC_BASE_NB_DIGITS} digits)")]
    FractionOverflow,
}

/// Error returned by `Decimal::from_str`: the rejected input and the reason.
#[derive(Debug, thiserror::Error)]
#[error("decimal '{decimal}' {kind}")]
pub struct ParseDecimalErr {
    // The original input, kept only for the error message.
    decimal: String,
    pub kind: DecimalErrKind,
}
    185 
    186 impl FromStr for Decimal {
    187     type Err = ParseDecimalErr;
    188 
    189     fn from_str(s: &str) -> Result<Self, Self::Err> {
    190         let (value, fraction) = s.split_once('.').unwrap_or((s, ""));
    191 
    192         // TODO use try block when stable
    193         (|| {
    194             let value: u64 = value.parse().map_err(DecimalErrKind::InvalidValue)?;
    195             if value > MAX_VALUE {
    196                 return Err(DecimalErrKind::Overflow);
    197             }
    198 
    199             if fraction.len() > FRAC_BASE_NB_DIGITS as usize {
    200                 return Err(DecimalErrKind::FractionOverflow);
    201             }
    202             let fraction: u32 = if fraction.is_empty() {
    203                 0
    204             } else {
    205                 fraction
    206                     .parse::<u32>()
    207                     .map_err(DecimalErrKind::InvalidFraction)?
    208                     * 10u32.pow(FRAC_BASE_NB_DIGITS as u32 - fraction.len() as u32)
    209             };
    210             Ok(Self {
    211                 val: value,
    212                 frac: fraction,
    213             })
    214         })()
    215         .map_err(|kind| ParseDecimalErr {
    216             decimal: s.to_owned(),
    217             kind,
    218         })
    219     }
    220 }
    221 
    222 impl Display for Decimal {
    223     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
    224         if self.frac == 0 {
    225             f.write_fmt(format_args!("{}", self.val))
    226         } else {
    227             let num = format!("{:08}", self.frac);
    228             f.write_fmt(format_args!("{}.{}", self.val, num.trim_end_matches('0')))
    229         }
    230     }
    231 }
    232 
    233 impl sqlx::Type<sqlx::Postgres> for Decimal {
    234     fn type_info() -> sqlx::postgres::PgTypeInfo {
    235         PgTalerAmount::type_info()
    236     }
    237 }
    238 
    239 impl<'q> sqlx::Encode<'q, sqlx::Postgres> for Decimal {
    240     fn encode_by_ref(
    241         &self,
    242         buf: &mut sqlx::postgres::PgArgumentBuffer,
    243     ) -> Result<sqlx::encode::IsNull, sqlx::error::BoxDynError> {
    244         PgTalerAmount {
    245             val: self.val as i64,
    246             frac: self.frac as i32,
    247         }
    248         .encode_by_ref(buf)
    249     }
    250 }
    251 
    252 impl<'r> sqlx::Decode<'r, sqlx::Postgres> for Decimal {
    253     fn decode(value: sqlx::postgres::PgValueRef<'r>) -> Result<Self, sqlx::error::BoxDynError> {
    254         let pg = PgTalerAmount::decode(value)?;
    255         Ok(Self {
    256             val: pg.val as u64,
    257             frac: pg.frac as u32,
    258         })
    259     }
    260 }
    261 
    262 #[track_caller]
    263 pub fn decimal(decimal: impl AsRef<str>) -> Decimal {
    264     decimal.as_ref().parse().expect("Invalid decimal constant")
    265 }
    266 
/// <https://docs.taler.net/core/api-common.html#tsref-type-Amount>
///
/// Serde (de)serializes it as the string `CURRENCY:VALUE[.FRACTION]` produced by
/// `Display` and accepted by `FromStr`.
#[derive(
    Debug, Clone, PartialEq, Eq, serde_with::DeserializeFromStr, serde_with::SerializeDisplay,
)]
pub struct Amount {
    pub currency: Currency,
    /** Integer part */
    pub val: u64,
    /** Fractional part, in units of 1/FRAC_BASE (normalized values are < FRAC_BASE) */
    pub frac: u32,
}
    276 
    277 impl Amount {
    278     pub fn new_decimal(currency: &Currency, decimal: Decimal) -> Self {
    279         (currency.clone(), decimal).into()
    280     }
    281 
    282     pub fn new(currency: &Currency, val: u64, frac: u32) -> Self {
    283         Self::new_decimal(currency, Decimal { val, frac })
    284     }
    285 
    286     pub fn max(currency: &Currency) -> Self {
    287         Self::new_decimal(currency, Decimal::max())
    288     }
    289 
    290     pub fn zero(currency: &Currency) -> Self {
    291         Self::new_decimal(currency, Decimal::zero())
    292     }
    293 
    294     pub const fn decimal(&self) -> Decimal {
    295         Decimal {
    296             val: self.val,
    297             frac: self.frac,
    298         }
    299     }
    300 
    301     pub fn normalize(self) -> Option<Self> {
    302         let decimal = self.decimal().normalize()?;
    303         Some((self.currency, decimal).into())
    304     }
    305 
    306     pub fn try_add(self, rhs: &Self) -> Option<Self> {
    307         assert_eq!(self.currency, rhs.currency);
    308         let decimal = self.decimal().try_add(&rhs.decimal())?.normalize()?;
    309         Some((self.currency, decimal).into())
    310     }
    311 }
    312 
    313 impl From<(Currency, Decimal)> for Amount {
    314     fn from((currency, decimal): (Currency, Decimal)) -> Self {
    315         Self {
    316             currency,
    317             val: decimal.val,
    318             frac: decimal.frac,
    319         }
    320     }
    321 }
    322 
    323 #[track_caller]
    324 pub fn amount(amount: impl AsRef<str>) -> Amount {
    325     amount.as_ref().parse().expect("Invalid amount constant")
    326 }
    327 
/// Reason an amount string was rejected by `Amount::from_str`.
#[derive(Debug, thiserror::Error)]
pub enum AmountErrKind {
    /// No `:` separator between the currency and the value.
    #[error("invalid format")]
    Format,
    /// The currency part (before `:`) is invalid.
    #[error("currency {0}")]
    Currency(#[from] CurrencyErrorKind),
    /// The numeric part (after `:`) is invalid.
    #[error(transparent)]
    Decimal(#[from] DecimalErrKind),
}

/// Error returned by `Amount::from_str`: the rejected input and the reason.
#[derive(Debug, thiserror::Error)]
#[error("amount '{amount}' {kind}")]
pub struct ParseAmountErr {
    // The original (untrimmed) input, kept only for the error message.
    amount: String,
    pub kind: AmountErrKind,
}
    344 
    345 impl FromStr for Amount {
    346     type Err = ParseAmountErr;
    347 
    348     fn from_str(s: &str) -> Result<Self, Self::Err> {
    349         // TODO use try block when stable
    350         (|| {
    351             let (currency, amount) = s.trim().split_once(':').ok_or(AmountErrKind::Format)?;
    352             let currency = currency.parse().map_err(|e: ParseCurrencyError| e.kind)?;
    353             let decimal = amount.parse().map_err(|e: ParseDecimalErr| e.kind)?;
    354             Ok((currency, decimal).into())
    355         })()
    356         .map_err(|kind| ParseAmountErr {
    357             amount: s.to_owned(),
    358             kind,
    359         })
    360     }
    361 }
    362 
    363 impl Display for Amount {
    364     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
    365         f.write_fmt(format_args!("{}:{}", self.currency, self.decimal()))
    366     }
    367 }
    368 
    369 impl sqlx::Type<sqlx::Postgres> for Amount {
    370     fn type_info() -> sqlx::postgres::PgTypeInfo {
    371         PgTalerAmount::type_info()
    372     }
    373 }
    374 
    375 impl<'q> sqlx::Encode<'q, sqlx::Postgres> for Amount {
    376     fn encode_by_ref(
    377         &self,
    378         buf: &mut sqlx::postgres::PgArgumentBuffer,
    379     ) -> Result<sqlx::encode::IsNull, sqlx::error::BoxDynError> {
    380         self.decimal().encode_by_ref(buf)
    381     }
    382 }
    383 
/// Parsing and `Display` round-trip cases, adapted from the exchange's C amount tests.
#[test]
fn test_amount_parse() {
    const TALER_AMOUNT_FRAC_BASE: u32 = 100000000;
    // https://git.taler.net/exchange.git/tree/src/util/test_amount.c

    const INVALID_AMOUNTS: [&str; 6] = [
        "EUR:4a",                                                                     // non-numeric
        "EUR:4.4a",                                                                   // non-numeric
        "EUR:4.a4",                                                                   // non-numeric
        ":4.a4",                                                                      // no currency
        "EUR:4.123456789", // precision too high
        "EUR:1234567890123456789012345678901234567890123456789012345678901234567890", // value too big
    ];

    for str in INVALID_AMOUNTS {
        let amount = Amount::from_str(str);
        assert!(amount.is_err(), "invalid {} got {:?}", str, &amount);
    }

    let eur: Currency = "EUR".parse().unwrap();
    let local: Currency = "LOCAL".parse().unwrap();
    // (raw input, expected canonical `Display` output, expected parsed value)
    let valid_amounts: Vec<(&str, &str, Amount)> = vec![
        ("EUR:4", "EUR:4", Amount::new(&eur, 4, 0)), // without fraction
        (
            "EUR:0.02",
            "EUR:0.02",
            Amount::new(&eur, 0, TALER_AMOUNT_FRAC_BASE / 100 * 2),
        ), // leading zero fraction
        (
            " EUR:4.12",
            "EUR:4.12",
            Amount::new(&eur, 4, TALER_AMOUNT_FRAC_BASE / 100 * 12),
        ), // leading space and fraction
        (
            " LOCAL:4444.1000",
            "LOCAL:4444.1",
            Amount::new(&local, 4444, TALER_AMOUNT_FRAC_BASE / 10),
        ), // local currency
    ];
    for (raw, expected, goal) in valid_amounts {
        let amount = Amount::from_str(raw);
        assert!(amount.is_ok(), "Valid {} got {:?}", raw, amount);
        assert_eq!(
            *amount.as_ref().unwrap(),
            goal,
            "Expected {:?} got {:?} for {}",
            goal,
            amount,
            raw
        );
        let amount = amount.unwrap();
        // Formatting yields the canonical form, which parses back to the same value
        let str = amount.to_string();
        assert_eq!(str, expected);
        assert_eq!(amount, Amount::from_str(&str).unwrap(), "{str}");
    }
}
    440 
/// `Amount::try_add`: exact sums, fraction carries, and overflow past `MAX_VALUE`.
#[test]
fn test_amount_add() {
    let eur: Currency = "EUR".parse().unwrap();
    assert_eq!(
        Amount::max(&eur).try_add(&Amount::zero(&eur)),
        Some(Amount::max(&eur))
    );
    assert_eq!(
        Amount::zero(&eur).try_add(&Amount::zero(&eur)),
        Some(Amount::zero(&eur))
    );
    // Fraction carry: .41 + .69 = 1.10
    assert_eq!(
        amount("EUR:6.41").try_add(&amount("EUR:4.69")),
        Some(amount("EUR:11.1"))
    );
    // Reaching exactly the maximum is allowed
    assert_eq!(
        amount(format!("EUR:{MAX_VALUE}")).try_add(&amount("EUR:0.99999999")),
        Some(Amount::max(&eur))
    );

    // Integer part exceeds MAX_VALUE
    assert_eq!(
        amount(format!("EUR:{}", MAX_VALUE - 5)).try_add(&amount("EUR:6")),
        None
    );
    // Integer part overflows u64 itself
    assert_eq!(
        Amount::new(&eur, u64::MAX, 0).try_add(&amount("EUR:1")),
        None
    );
    // Only the fraction carry pushes the integer part past MAX_VALUE
    assert_eq!(
        amount(format!("EUR:{}.{}", MAX_VALUE - 5, FRAC_BASE - 1))
            .try_add(&amount("EUR:5.00000002")),
        None
    );
}
    475 
/// `Amount::normalize`: carrying the fraction into the integer part and rejecting overflow.
#[test]
fn test_amount_normalize() {
    let eur: Currency = "EUR".parse().unwrap();
    // Two whole units are carried out of frac
    assert_eq!(
        Amount::new(&eur, 4, 2 * FRAC_BASE).normalize(),
        Some(amount("EUR:6"))
    );
    // The remainder of the carry stays in frac
    assert_eq!(
        Amount::new(&eur, 4, 2 * FRAC_BASE + 1).normalize(),
        Some(amount("EUR:6.00000001"))
    );
    // An already-normalized maximum is unchanged
    assert_eq!(
        Amount::new(&eur, MAX_VALUE, FRAC_BASE - 1).normalize(),
        Some(Amount::new(&eur, MAX_VALUE, FRAC_BASE - 1))
    );
    // Carry overflows u64
    assert_eq!(Amount::new(&eur, u64::MAX, FRAC_BASE).normalize(), None);
    // Carry pushes the integer part past MAX_VALUE
    assert_eq!(Amount::new(&eur, MAX_VALUE, FRAC_BASE).normalize(), None);

    // Normalization is the identity on normalized values
    for amount in [Amount::max(&eur), Amount::zero(&eur)] {
        assert_eq!(amount.clone().normalize(), Some(amount))
    }
}
    497 }