amount.rs (16472B)
1 /* 2 This file is part of TALER 3 Copyright (C) 2024, 2025, 2026 Taler Systems SA 4 5 TALER is free software; you can redistribute it and/or modify it under the 6 terms of the GNU Affero General Public License as published by the Free Software 7 Foundation; either version 3, or (at your option) any later version. 8 9 TALER is distributed in the hope that it will be useful, but WITHOUT ANY 10 WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR 11 A PARTICULAR PURPOSE. See the GNU Affero General Public License for more details. 12 13 You should have received a copy of the GNU Affero General Public License along with 14 TALER; see the file COPYING. If not, see <http://www.gnu.org/licenses/> 15 */ 16 17 //! Type for the Taler Amount <https://docs.taler.net/core/api-common.html#tsref-type-Amount> 18 19 use std::{ 20 fmt::{Debug, Display}, 21 num::ParseIntError, 22 str::FromStr, 23 }; 24 25 use super::utils::InlineStr; 26 27 /** Number of characters we use to represent currency names */ 28 // We use the same value than the exchange -1 because we use a byte for the len instead of 0 termination 29 pub const CURRENCY_LEN: usize = 11; 30 31 /** Maximum legal value for an amount, based on IEEE double */ 32 pub const MAX_VALUE: u64 = 2 << 52; 33 34 /** The number of digits in a fraction part of an amount */ 35 pub const FRAC_BASE_NB_DIGITS: u8 = 8; 36 37 /** The fraction part of an amount represents which fraction of the value */ 38 pub const FRAC_BASE: u32 = 10u32.pow(FRAC_BASE_NB_DIGITS as u32); 39 40 const CENT_FRACTION: u32 = 10u32.pow((FRAC_BASE_NB_DIGITS - 2) as u32); 41 42 #[derive( 43 Clone, Copy, PartialEq, Eq, serde_with::DeserializeFromStr, serde_with::SerializeDisplay, 44 )] 45 /// Inlined ISO 4217 currency string 46 pub struct Currency(InlineStr<CURRENCY_LEN>); 47 48 impl AsRef<str> for Currency { 49 fn as_ref(&self) -> &str { 50 self.0.as_ref() 51 } 52 } 53 54 #[derive(Debug, thiserror::Error)] 55 pub enum CurrencyErrorKind { 56 #[error("contains illegal characters (only A-Z allowed)")] 57 Invalid, 58 #[error("too long (max {CURRENCY_LEN} chars)")] 59 Big, 60 #[error("is empty")] 61 Empty, 62 } 63 64 #[derive(Debug, thiserror::Error)] 65 #[error("currency code name '{currency}' {kind}")] 66 pub struct ParseCurrencyError { 67 currency: String, 68 pub kind: CurrencyErrorKind, 69 } 70 71 impl Currency { 72 pub const TEST: Self = Self::const_parse("TEST"); 73 pub const KUDOS: Self = Self::const_parse("KUDOS"); 74 pub const EUR: Self = Self::const_parse("EUR"); 75 pub const CHF: Self = Self::const_parse("CHF"); 76 pub const HUF: Self = Self::const_parse("HUF"); 77 78 pub const fn const_parse(s: &str) -> Currency { 79 let bytes = s.as_bytes(); 80 let len = bytes.len(); 81 82 if bytes.is_empty() { 83 panic!("empty") 84 } else if len > CURRENCY_LEN { 85 panic!("too big") 86 } 87 let mut i = 0; 88 while i < bytes.len() { 89 if !bytes[i].is_ascii_uppercase() { 90 panic!("invalid") 91 } 92 i += 1; 93 } 94 Self(InlineStr::copy_from_slice(bytes)) 95 } 96 } 97 98 impl FromStr for Currency { 99 type Err = ParseCurrencyError; 100 101 fn from_str(s: &str) -> Result<Self, Self::Err> { 102 let bytes = s.as_bytes(); 103 let len = bytes.len(); 104 if bytes.is_empty() { 105 Err(CurrencyErrorKind::Empty) 106 } else if len > CURRENCY_LEN { 107 Err(CurrencyErrorKind::Big) 108 } else if !bytes.iter().all(|c| c.is_ascii_uppercase()) { 109 Err(CurrencyErrorKind::Invalid) 110 } else { 111 Ok(Self(InlineStr::copy_from_slice(bytes))) 112 } 113 .map_err(|kind| ParseCurrencyError { 114 currency: s.to_owned(), 115 kind, 116 }) 117 } 118 } 119 120 impl Debug for Currency { 121 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 122 Debug::fmt(&self.as_ref(), f) 123 } 124 } 125 126 impl Display for Currency { 127 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 128 Display::fmt(&self.as_ref(), f) 129 } 130 } 131 132 #[derive(sqlx::Type)] 133 #[sqlx(type_name = "taler_amount")] 134 struct PgTalerAmount { 135 pub val: i64, 136 pub frac: i32, 137 } 138 139 #[derive( 140 Clone, 141 Copy, 142 PartialEq, 143 Eq, 144 PartialOrd, 145 Ord, 146 serde_with::DeserializeFromStr, 147 serde_with::SerializeDisplay, 148 )] 149 pub struct Decimal { 150 /** Integer part */ 151 pub val: u64, 152 /** Factional part, multiple of FRAC_BASE */ 153 pub frac: u32, 154 } 155 156 impl Decimal { 157 pub const fn new(val: u64, frac: u32) -> Self { 158 Self { val, frac } 159 } 160 161 pub const fn max() -> Self { 162 Self { 163 val: MAX_VALUE, 164 frac: FRAC_BASE - 1, 165 } 166 } 167 168 pub const fn zero() -> Self { 169 Self { val: 0, frac: 0 } 170 } 171 172 const fn normalize(mut self) -> Option<Self> { 173 let Some(val) = self.val.checked_add((self.frac / FRAC_BASE) as u64) else { 174 return None; 175 }; 176 self.val = val; 177 self.frac %= FRAC_BASE; 178 if self.val > MAX_VALUE { 179 return None; 180 } 181 Some(self) 182 } 183 184 pub fn try_add(mut self, rhs: &Self) -> Option<Self> { 185 self.val = self.val.checked_add(rhs.val)?; 186 self.frac = self 187 .frac 188 .checked_add(rhs.frac) 189 .expect("amount fraction overflow should never happen with normalized amounts"); 190 self.normalize() 191 } 192 193 pub fn try_sub(mut self, rhs: &Self) -> Option<Self> { 194 if rhs.frac > self.frac { 195 self.val = self.val.checked_sub(1)?; 196 self.frac += FRAC_BASE; 197 } 198 self.val = self.val.checked_sub(rhs.val)?; 199 self.frac = self.frac.checked_sub(rhs.frac)?; 200 self.normalize() 201 } 202 203 pub const fn to_amount(self, currency: &Currency) -> Amount { 204 Amount::new_decimal(currency, self) 205 } 206 } 207 208 #[derive(Debug, thiserror::Error)] 209 pub enum DecimalErrKind { 210 #[error("value overflow (must be <= {MAX_VALUE})")] 211 Overflow, 212 #[error("invalid value ({0})")] 213 InvalidValue(ParseIntError), 214 #[error("invalid fraction ({0})")] 215 InvalidFraction(ParseIntError), 216 #[error("fraction overflow (max {FRAC_BASE_NB_DIGITS} digits)")] 217 FractionOverflow, 218 } 219 220 #[derive(Debug, thiserror::Error)] 221 #[error("decimal '{decimal}' {kind}")] 222 pub struct ParseDecimalErr { 223 decimal: String, 224 pub kind: DecimalErrKind, 225 } 226 227 impl FromStr for Decimal { 228 type Err = ParseDecimalErr; 229 230 fn from_str(s: &str) -> Result<Self, Self::Err> { 231 let (value, fraction) = s.split_once('.').unwrap_or((s, "")); 232 233 // TODO use try block when stable 234 (|| { 235 let value: u64 = value.parse().map_err(DecimalErrKind::InvalidValue)?; 236 if value > MAX_VALUE { 237 return Err(DecimalErrKind::Overflow); 238 } 239 240 if fraction.len() > FRAC_BASE_NB_DIGITS as usize { 241 return Err(DecimalErrKind::FractionOverflow); 242 } 243 let fraction: u32 = if fraction.is_empty() { 244 0 245 } else { 246 fraction 247 .parse::<u32>() 248 .map_err(DecimalErrKind::InvalidFraction)? 249 * 10u32.pow(FRAC_BASE_NB_DIGITS as u32 - fraction.len() as u32) 250 }; 251 Ok(Self { 252 val: value, 253 frac: fraction, 254 }) 255 })() 256 .map_err(|kind| ParseDecimalErr { 257 decimal: s.to_owned(), 258 kind, 259 }) 260 } 261 } 262 263 impl Display for Decimal { 264 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 265 if self.frac == 0 { 266 f.write_fmt(format_args!("{}", self.val)) 267 } else { 268 let num = format!("{:08}", self.frac); 269 f.write_fmt(format_args!("{}.{}", self.val, num.trim_end_matches('0'))) 270 } 271 } 272 } 273 274 impl Debug for Decimal { 275 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 276 Display::fmt(&self, f) 277 } 278 } 279 280 impl sqlx::Type<sqlx::Postgres> for Decimal { 281 fn type_info() -> sqlx::postgres::PgTypeInfo { 282 PgTalerAmount::type_info() 283 } 284 } 285 286 impl<'q> sqlx::Encode<'q, sqlx::Postgres> for Decimal { 287 fn encode_by_ref( 288 &self, 289 buf: &mut sqlx::postgres::PgArgumentBuffer, 290 ) -> Result<sqlx::encode::IsNull, sqlx::error::BoxDynError> { 291 PgTalerAmount { 292 val: self.val as i64, 293 frac: self.frac as i32, 294 } 295 .encode_by_ref(buf) 296 } 297 } 298 299 impl<'r> sqlx::Decode<'r, sqlx::Postgres> for Decimal { 300 fn decode(value: sqlx::postgres::PgValueRef<'r>) -> Result<Self, sqlx::error::BoxDynError> { 301 let pg = PgTalerAmount::decode(value)?; 302 Ok(Self { 303 val: pg.val as u64, 304 frac: pg.frac as u32, 305 }) 306 } 307 } 308 309 #[track_caller] 310 pub fn decimal(decimal: impl AsRef<str>) -> Decimal { 311 decimal.as_ref().parse().expect("Invalid decimal constant") 312 } 313 314 /// <https://docs.taler.net/core/api-common.html#tsref-type-Amount> 315 #[derive( 316 Clone, Copy, PartialEq, Eq, serde_with::DeserializeFromStr, serde_with::SerializeDisplay, 317 )] 318 pub struct Amount { 319 pub currency: Currency, 320 pub val: u64, 321 pub frac: u32, 322 } 323 324 impl Amount { 325 pub const fn new_decimal(currency: &Currency, decimal: Decimal) -> Self { 326 Self { 327 currency: *currency, 328 val: decimal.val, 329 frac: decimal.frac, 330 } 331 } 332 333 pub const fn new(currency: &Currency, val: u64, frac: u32) -> Self { 334 Self::new_decimal(currency, Decimal { val, frac }) 335 } 336 337 pub const fn max(currency: &Currency) -> Self { 338 Self::new_decimal(currency, Decimal::max()) 339 } 340 341 pub const fn zero(currency: &Currency) -> Self { 342 Self::new_decimal(currency, Decimal::zero()) 343 } 344 345 pub fn is_zero(&self) -> bool { 346 self.decimal() == Decimal::zero() 347 } 348 349 /* Check is amount has fractional amount < 0.01 */ 350 pub const fn is_sub_cent(&self) -> bool { 351 !self.frac.is_multiple_of(CENT_FRACTION) 352 } 353 354 pub const fn decimal(&self) -> Decimal { 355 Decimal { 356 val: self.val, 357 frac: self.frac, 358 } 359 } 360 361 pub fn normalize(self) -> Option<Self> { 362 let decimal = self.decimal().normalize()?; 363 Some((self.currency, decimal).into()) 364 } 365 366 pub fn try_add(self, rhs: &Self) -> Option<Self> { 367 assert_eq!(self.currency, rhs.currency); 368 let decimal = self.decimal().try_add(&rhs.decimal())?.normalize()?; 369 Some((self.currency, decimal).into()) 370 } 371 372 pub fn try_sub(self, rhs: &Self) -> Option<Self> { 373 assert_eq!(self.currency, rhs.currency); 374 let decimal = self.decimal().try_sub(&rhs.decimal())?.normalize()?; 375 Some((self.currency, decimal).into()) 376 } 377 } 378 379 impl From<(Currency, Decimal)> for Amount { 380 fn from((currency, decimal): (Currency, Decimal)) -> Self { 381 Self::new_decimal(¤cy, decimal) 382 } 383 } 384 385 #[track_caller] 386 pub fn amount(amount: impl AsRef<str>) -> Amount { 387 amount.as_ref().parse().expect("Invalid amount constant") 388 } 389 390 #[derive(Debug, thiserror::Error)] 391 pub enum AmountErrKind { 392 #[error("invalid format")] 393 Format, 394 #[error("currency {0}")] 395 Currency(#[from] CurrencyErrorKind), 396 #[error(transparent)] 397 Decimal(#[from] DecimalErrKind), 398 } 399 400 #[derive(Debug, thiserror::Error)] 401 #[error("amount '{amount}' {kind}")] 402 pub struct ParseAmountErr { 403 amount: String, 404 pub kind: AmountErrKind, 405 } 406 407 impl FromStr for Amount { 408 type Err = ParseAmountErr; 409 410 fn from_str(s: &str) -> Result<Self, Self::Err> { 411 // TODO use try block when stable 412 (|| { 413 let (currency, amount) = s.trim().split_once(':').ok_or(AmountErrKind::Format)?; 414 let currency = currency.parse().map_err(|e: ParseCurrencyError| e.kind)?; 415 let decimal = amount.parse().map_err(|e: ParseDecimalErr| e.kind)?; 416 Ok((currency, decimal).into()) 417 })() 418 .map_err(|kind| ParseAmountErr { 419 amount: s.to_owned(), 420 kind, 421 }) 422 } 423 } 424 425 impl Display for Amount { 426 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 427 f.write_fmt(format_args!("{}:{}", self.currency, self.decimal())) 428 } 429 } 430 431 impl Debug for Amount { 432 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 433 Display::fmt(&self, f) 434 } 435 } 436 437 impl sqlx::Type<sqlx::Postgres> for Amount { 438 fn type_info() -> sqlx::postgres::PgTypeInfo { 439 PgTalerAmount::type_info() 440 } 441 } 442 443 impl<'q> sqlx::Encode<'q, sqlx::Postgres> for Amount { 444 fn encode_by_ref( 445 &self, 446 buf: &mut sqlx::postgres::PgArgumentBuffer, 447 ) -> Result<sqlx::encode::IsNull, sqlx::error::BoxDynError> { 448 self.decimal().encode_by_ref(buf) 449 } 450 } 451 452 #[test] 453 fn test_amount_parse() { 454 const TALER_AMOUNT_FRAC_BASE: u32 = 100000000; 455 // https://git.taler.net/exchange.git/tree/src/util/test_amount.c 456 457 const INVALID_AMOUNTS: [&str; 6] = [ 458 "EUR:4a", // non-numeric, 459 "EUR:4.4a", // non-numeric 460 "EUR:4.a4", // non-numeric 461 ":4.a4", // no currency 462 "EUR:4.123456789", // precision to high 463 "EUR:1234567890123456789012345678901234567890123456789012345678901234567890", // value to big 464 ]; 465 466 for str in INVALID_AMOUNTS { 467 let amount = Amount::from_str(str); 468 assert!(amount.is_err(), "invalid {} got {:?}", str, &amount); 469 } 470 471 let eur: Currency = Currency::EUR; 472 let local: Currency = Currency::CHF; 473 let valid_amounts: Vec<(&str, &str, Amount)> = vec![ 474 ("EUR:4", "EUR:4", Amount::new(&eur, 4, 0)), // without fraction 475 ( 476 "EUR:0.02", 477 "EUR:0.02", 478 Amount::new(&eur, 0, TALER_AMOUNT_FRAC_BASE / 100 * 2), 479 ), // leading zero fraction 480 ( 481 " EUR:4.12", 482 "EUR:4.12", 483 Amount::new(&eur, 4, TALER_AMOUNT_FRAC_BASE / 100 * 12), 484 ), // leading space and fraction 485 ( 486 " CHF:4444.1000", 487 "CHF:4444.1", 488 Amount::new(&local, 4444, TALER_AMOUNT_FRAC_BASE / 10), 489 ), // local currency 490 ]; 491 for (raw, expected, goal) in valid_amounts { 492 let amount = Amount::from_str(raw); 493 assert!(amount.is_ok(), "Valid {} got {:?}", raw, amount); 494 assert_eq!( 495 *amount.as_ref().unwrap(), 496 goal, 497 "Expected {:?} got {:?} for {}", 498 goal, 499 amount, 500 raw 501 ); 502 let amount = amount.unwrap(); 503 let str = amount.to_string(); 504 assert_eq!(str, expected); 505 assert_eq!(amount, Amount::from_str(&str).unwrap(), "{str}"); 506 } 507 } 508 509 #[test] 510 fn test_amount_add() { 511 let eur: Currency = Currency::EUR; 512 assert_eq!( 513 Amount::max(&eur).try_add(&Amount::zero(&eur)), 514 Some(Amount::max(&eur)) 515 ); 516 assert_eq!( 517 Amount::zero(&eur).try_add(&Amount::zero(&eur)), 518 Some(Amount::zero(&eur)) 519 ); 520 assert_eq!( 521 amount("EUR:6.41").try_add(&amount("EUR:4.69")), 522 Some(amount("EUR:11.1")) 523 ); 524 assert_eq!( 525 amount(format!("EUR:{MAX_VALUE}")).try_add(&amount("EUR:0.99999999")), 526 Some(Amount::max(&eur)) 527 ); 528 529 assert_eq!( 530 amount(format!("EUR:{}", MAX_VALUE - 5)).try_add(&amount("EUR:6")), 531 None 532 ); 533 assert_eq!( 534 Amount::new(&eur, u64::MAX, 0).try_add(&amount("EUR:1")), 535 None 536 ); 537 assert_eq!( 538 amount(format!("EUR:{}.{}", MAX_VALUE - 5, FRAC_BASE - 1)) 539 .try_add(&amount("EUR:5.00000002")), 540 None 541 ); 542 } 543 544 #[test] 545 fn test_amount_normalize() { 546 let eur: Currency = "EUR".parse().unwrap(); 547 assert_eq!( 548 Amount::new(&eur, 4, 2 * FRAC_BASE).normalize(), 549 Some(amount("EUR:6")) 550 ); 551 assert_eq!( 552 Amount::new(&eur, 4, 2 * FRAC_BASE + 1).normalize(), 553 Some(amount("EUR:6.00000001")) 554 ); 555 assert_eq!( 556 Amount::new(&eur, MAX_VALUE, FRAC_BASE - 1).normalize(), 557 Some(Amount::new(&eur, MAX_VALUE, FRAC_BASE - 1)) 558 ); 559 assert_eq!(Amount::new(&eur, u64::MAX, FRAC_BASE).normalize(), None); 560 assert_eq!(Amount::new(&eur, MAX_VALUE, FRAC_BASE).normalize(), None); 561 562 for amount in [Amount::max(&eur), Amount::zero(&eur)] { 563 assert_eq!(amount.normalize(), Some(amount)) 564 } 565 }