amount.rs (16637B)
1 /* 2 This file is part of TALER 3 Copyright (C) 2024, 2025, 2026 Taler Systems SA 4 5 TALER is free software; you can redistribute it and/or modify it under the 6 terms of the GNU Affero General Public License as published by the Free Software 7 Foundation; either version 3, or (at your option) any later version. 8 9 TALER is distributed in the hope that it will be useful, but WITHOUT ANY 10 WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR 11 A PARTICULAR PURPOSE. See the GNU Affero General Public License for more details. 12 13 You should have received a copy of the GNU Affero General Public License along with 14 TALER; see the file COPYING. If not, see <http://www.gnu.org/licenses/> 15 */ 16 17 //! Type for the Taler Amount <https://docs.taler.net/core/api-common.html#tsref-type-Amount> 18 19 use std::{ 20 fmt::{Debug, Display}, 21 num::ParseIntError, 22 str::FromStr, 23 }; 24 25 use compact_str::format_compact; 26 27 use super::utils::InlineStr; 28 29 /** Number of characters we use to represent currency names */ 30 // We use the same value than the exchange -1 because we use a byte for the len instead of 0 termination 31 pub const CURRENCY_LEN: usize = 11; 32 33 /** Maximum legal value for an amount, based on IEEE double */ 34 pub const MAX_VALUE: u64 = 2 << 51; 35 36 /** The number of digits in a fraction part of an amount */ 37 pub const FRAC_BASE_NB_DIGITS: u8 = 8; 38 39 /** The fraction part of an amount represents which fraction of the value */ 40 pub const FRAC_BASE: u32 = 10u32.pow(FRAC_BASE_NB_DIGITS as u32); 41 42 const CENT_FRACTION: u32 = 10u32.pow((FRAC_BASE_NB_DIGITS - 2) as u32); 43 44 #[derive( 45 Clone, Copy, PartialEq, Eq, serde_with::DeserializeFromStr, serde_with::SerializeDisplay, 46 )] 47 /// Inlined ISO 4217 currency string 48 pub struct Currency(InlineStr<CURRENCY_LEN>); 49 50 impl AsRef<str> for Currency { 51 fn as_ref(&self) -> &str { 52 self.0.as_ref() 53 } 54 } 55 56 #[derive(Debug, thiserror::Error)] 57 pub enum CurrencyErrorKind { 58 #[error("contains illegal characters (only A-Z allowed)")] 59 Invalid, 60 #[error("too long (max {CURRENCY_LEN} chars)")] 61 Big, 62 #[error("is empty")] 63 Empty, 64 } 65 66 #[derive(Debug, thiserror::Error)] 67 #[error("currency code name '{currency}' {kind}")] 68 pub struct ParseCurrencyError { 69 currency: String, 70 pub kind: CurrencyErrorKind, 71 } 72 73 impl Currency { 74 pub const TEST: Self = Self::const_parse("TEST"); 75 pub const KUDOS: Self = Self::const_parse("KUDOS"); 76 pub const EUR: Self = Self::const_parse("EUR"); 77 pub const CHF: Self = Self::const_parse("CHF"); 78 pub const HUF: Self = Self::const_parse("HUF"); 79 80 pub const fn const_parse(s: &str) -> Currency { 81 let bytes = s.as_bytes(); 82 let len = bytes.len(); 83 84 if bytes.is_empty() { 85 panic!("empty") 86 } else if len > CURRENCY_LEN { 87 panic!("too big") 88 } 89 let mut i = 0; 90 while i < bytes.len() { 91 if !bytes[i].is_ascii_uppercase() { 92 panic!("invalid") 93 } 94 i += 1; 95 } 96 Self(InlineStr::copy_from_slice(bytes)) 97 } 98 } 99 100 impl FromStr for Currency { 101 type Err = ParseCurrencyError; 102 103 fn from_str(s: &str) -> Result<Self, Self::Err> { 104 let bytes = s.as_bytes(); 105 let len = bytes.len(); 106 if bytes.is_empty() { 107 Err(CurrencyErrorKind::Empty) 108 } else if len > CURRENCY_LEN { 109 Err(CurrencyErrorKind::Big) 110 } else if !bytes.iter().all(|c| c.is_ascii_uppercase()) { 111 Err(CurrencyErrorKind::Invalid) 112 } else { 113 Ok(Self(InlineStr::copy_from_slice(bytes))) 114 } 115 .map_err(|kind| ParseCurrencyError { 116 currency: s.to_owned(), 117 kind, 118 }) 119 } 120 } 121 122 impl Debug for Currency { 123 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 124 Debug::fmt(&self.as_ref(), f) 125 } 126 } 127 128 impl Display for Currency { 129 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 130 Display::fmt(&self.as_ref(), f) 131 } 132 } 133 134 #[derive(sqlx::Type)] 135 #[sqlx(type_name = "taler_amount")] 136 struct PgTalerAmount { 137 pub val: i64, 138 pub frac: i32, 139 } 140 141 #[derive( 142 Clone, 143 Copy, 144 PartialEq, 145 Eq, 146 PartialOrd, 147 Ord, 148 serde_with::DeserializeFromStr, 149 serde_with::SerializeDisplay, 150 )] 151 pub struct Decimal { 152 /** Integer part */ 153 pub val: u64, 154 /** Factional part, multiple of FRAC_BASE */ 155 pub frac: u32, 156 } 157 158 impl Decimal { 159 pub const fn new(val: u64, frac: u32) -> Self { 160 Self { val, frac } 161 } 162 163 pub const ZERO: Self = Self::new(0, 0); 164 pub const MAX: Self = Self::new(MAX_VALUE, FRAC_BASE - 1); 165 166 const fn normalize(mut self) -> Option<Self> { 167 let Some(val) = self.val.checked_add((self.frac / FRAC_BASE) as u64) else { 168 return None; 169 }; 170 self.val = val; 171 self.frac %= FRAC_BASE; 172 if self.val > MAX_VALUE { 173 return None; 174 } 175 Some(self) 176 } 177 178 pub fn try_add(mut self, rhs: &Self) -> Option<Self> { 179 self.val = self.val.checked_add(rhs.val)?; 180 self.frac = self 181 .frac 182 .checked_add(rhs.frac) 183 .expect("amount fraction overflow should never happen with normalized amounts"); 184 self.normalize() 185 } 186 187 pub fn try_sub(mut self, rhs: &Self) -> Option<Self> { 188 if rhs.frac > self.frac { 189 self.val = self.val.checked_sub(1)?; 190 self.frac += FRAC_BASE; 191 } 192 self.val = self.val.checked_sub(rhs.val)?; 193 self.frac = self.frac.checked_sub(rhs.frac)?; 194 self.normalize() 195 } 196 197 pub const fn to_amount(self, currency: &Currency) -> Amount { 198 Amount::new_decimal(currency, self) 199 } 200 } 201 202 #[derive(Debug, thiserror::Error)] 203 pub enum DecimalErrKind { 204 #[error("value overflow (must be <= {MAX_VALUE})")] 205 Overflow, 206 #[error("invalid value ({0})")] 207 InvalidValue(ParseIntError), 208 #[error("invalid fraction ({0})")] 209 InvalidFraction(ParseIntError), 210 #[error("fraction overflow (max {FRAC_BASE_NB_DIGITS} digits)")] 211 FractionOverflow, 212 } 213 214 #[derive(Debug, thiserror::Error)] 215 #[error("decimal '{decimal}' {kind}")] 216 pub struct ParseDecimalErr { 217 decimal: String, 218 pub kind: DecimalErrKind, 219 } 220 221 impl FromStr for Decimal { 222 type Err = ParseDecimalErr; 223 224 fn from_str(s: &str) -> Result<Self, Self::Err> { 225 let (value, fraction) = s.split_once('.').unwrap_or((s, "")); 226 227 // TODO use try block when stable 228 (|| { 229 let value: u64 = value.parse().map_err(DecimalErrKind::InvalidValue)?; 230 if value > MAX_VALUE { 231 return Err(DecimalErrKind::Overflow); 232 } 233 234 if fraction.len() > FRAC_BASE_NB_DIGITS as usize { 235 return Err(DecimalErrKind::FractionOverflow); 236 } 237 let fraction: u32 = if fraction.is_empty() { 238 0 239 } else { 240 fraction 241 .parse::<u32>() 242 .map_err(DecimalErrKind::InvalidFraction)? 243 * 10u32.pow(FRAC_BASE_NB_DIGITS as u32 - fraction.len() as u32) 244 }; 245 Ok(Self { 246 val: value, 247 frac: fraction, 248 }) 249 })() 250 .map_err(|kind| ParseDecimalErr { 251 decimal: s.to_owned(), 252 kind, 253 }) 254 } 255 } 256 257 impl Display for Decimal { 258 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 259 if self.frac == 0 { 260 f.write_fmt(format_args!("{}", self.val)) 261 } else { 262 let num = format_compact!("{:08}", self.frac); 263 f.write_fmt(format_args!("{}.{}", self.val, num.trim_end_matches('0'))) 264 } 265 } 266 } 267 268 impl Debug for Decimal { 269 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 270 Display::fmt(&self, f) 271 } 272 } 273 274 impl sqlx::Type<sqlx::Postgres> for Decimal { 275 fn type_info() -> sqlx::postgres::PgTypeInfo { 276 PgTalerAmount::type_info() 277 } 278 } 279 280 impl<'q> sqlx::Encode<'q, sqlx::Postgres> for Decimal { 281 fn encode_by_ref( 282 &self, 283 buf: &mut sqlx::postgres::PgArgumentBuffer, 284 ) -> Result<sqlx::encode::IsNull, sqlx::error::BoxDynError> { 285 PgTalerAmount { 286 val: self.val as i64, 287 frac: self.frac as i32, 288 } 289 .encode_by_ref(buf) 290 } 291 } 292 293 impl<'r> sqlx::Decode<'r, sqlx::Postgres> for Decimal { 294 fn decode(value: sqlx::postgres::PgValueRef<'r>) -> Result<Self, sqlx::error::BoxDynError> { 295 let pg = PgTalerAmount::decode(value)?; 296 Ok(Self { 297 val: pg.val as u64, 298 frac: pg.frac as u32, 299 }) 300 } 301 } 302 303 #[track_caller] 304 pub fn decimal(decimal: impl AsRef<str>) -> Decimal { 305 decimal.as_ref().parse().expect("Invalid decimal constant") 306 } 307 308 /// <https://docs.taler.net/core/api-common.html#tsref-type-Amount> 309 #[derive( 310 Clone, Copy, PartialEq, Eq, serde_with::DeserializeFromStr, serde_with::SerializeDisplay, 311 )] 312 pub struct Amount { 313 pub currency: Currency, 314 pub val: u64, 315 pub frac: u32, 316 } 317 318 impl Amount { 319 pub const fn new_decimal(currency: &Currency, decimal: Decimal) -> Self { 320 Self { 321 currency: *currency, 322 val: decimal.val, 323 frac: decimal.frac, 324 } 325 } 326 327 pub const fn new(currency: &Currency, val: u64, frac: u32) -> Self { 328 Self::new_decimal(currency, Decimal { val, frac }) 329 } 330 331 pub const fn max(currency: &Currency) -> Self { 332 Self::new_decimal(currency, Decimal::MAX) 333 } 334 335 pub const fn zero(currency: &Currency) -> Self { 336 Self::new_decimal(currency, Decimal::ZERO) 337 } 338 339 pub fn is_zero(&self) -> bool { 340 self.decimal() == Decimal::ZERO 341 } 342 343 /* Check is amount has fractional amount < 0.01 */ 344 pub const fn is_sub_cent(&self) -> bool { 345 !self.frac.is_multiple_of(CENT_FRACTION) 346 } 347 348 pub const fn decimal(&self) -> Decimal { 349 Decimal { 350 val: self.val, 351 frac: self.frac, 352 } 353 } 354 355 pub fn normalize(self) -> Option<Self> { 356 let decimal = self.decimal().normalize()?; 357 Some((self.currency, decimal).into()) 358 } 359 360 pub fn try_add(self, rhs: &Self) -> Option<Self> { 361 assert_eq!(self.currency, rhs.currency); 362 let decimal = self.decimal().try_add(&rhs.decimal())?.normalize()?; 363 Some((self.currency, decimal).into()) 364 } 365 366 pub fn try_sub(self, rhs: &Self) -> Option<Self> { 367 assert_eq!(self.currency, rhs.currency); 368 let decimal = self.decimal().try_sub(&rhs.decimal())?.normalize()?; 369 Some((self.currency, decimal).into()) 370 } 371 } 372 373 impl From<(Currency, Decimal)> for Amount { 374 fn from((currency, decimal): (Currency, Decimal)) -> Self { 375 Self::new_decimal(¤cy, decimal) 376 } 377 } 378 379 #[track_caller] 380 pub fn amount(amount: impl AsRef<str>) -> Amount { 381 amount.as_ref().parse().expect("Invalid amount constant") 382 } 383 384 #[derive(Debug, thiserror::Error)] 385 pub enum AmountErrKind { 386 #[error("invalid format")] 387 Format, 388 #[error("currency {0}")] 389 Currency(#[from] CurrencyErrorKind), 390 #[error(transparent)] 391 Decimal(#[from] DecimalErrKind), 392 } 393 394 #[derive(Debug, thiserror::Error)] 395 #[error("amount '{amount}' {kind}")] 396 pub struct ParseAmountErr { 397 amount: String, 398 pub kind: AmountErrKind, 399 } 400 401 impl FromStr for Amount { 402 type Err = ParseAmountErr; 403 404 fn from_str(s: &str) -> Result<Self, Self::Err> { 405 // TODO use try block when stable 406 (|| { 407 let (currency, amount) = s.trim().split_once(':').ok_or(AmountErrKind::Format)?; 408 let currency = currency.parse().map_err(|e: ParseCurrencyError| e.kind)?; 409 let decimal = amount.parse().map_err(|e: ParseDecimalErr| e.kind)?; 410 Ok((currency, decimal).into()) 411 })() 412 .map_err(|kind| ParseAmountErr { 413 amount: s.to_owned(), 414 kind, 415 }) 416 } 417 } 418 419 impl Display for Amount { 420 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 421 f.write_fmt(format_args!("{}:{}", self.currency, self.decimal())) 422 } 423 } 424 425 impl Debug for Amount { 426 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 427 Display::fmt(&self, f) 428 } 429 } 430 431 impl sqlx::Type<sqlx::Postgres> for Amount { 432 fn type_info() -> sqlx::postgres::PgTypeInfo { 433 PgTalerAmount::type_info() 434 } 435 } 436 437 impl<'q> sqlx::Encode<'q, sqlx::Postgres> for Amount { 438 fn encode_by_ref( 439 &self, 440 buf: &mut sqlx::postgres::PgArgumentBuffer, 441 ) -> Result<sqlx::encode::IsNull, sqlx::error::BoxDynError> { 442 self.decimal().encode_by_ref(buf) 443 } 444 } 445 446 #[test] 447 fn constants() { 448 assert_eq!(format!("{}", Amount::zero(&Currency::KUDOS)), "KUDOS:0"); 449 assert_eq!( 450 format!("{}", Amount::max(&Currency::KUDOS)), 451 "KUDOS:4503599627370496.99999999" 452 ); 453 } 454 455 #[test] 456 fn test_amount_parse() { 457 const TALER_AMOUNT_FRAC_BASE: u32 = 100000000; 458 // https://git.taler.net/exchange.git/tree/src/util/test_amount.c 459 460 const INVALID_AMOUNTS: [&str; 6] = [ 461 "EUR:4a", // non-numeric, 462 "EUR:4.4a", // non-numeric 463 "EUR:4.a4", // non-numeric 464 ":4.a4", // no currency 465 "EUR:4.123456789", // precision to high 466 "EUR:1234567890123456789012345678901234567890123456789012345678901234567890", // value to big 467 ]; 468 469 for str in INVALID_AMOUNTS { 470 let amount = Amount::from_str(str); 471 assert!(amount.is_err(), "invalid {} got {:?}", str, &amount); 472 } 473 474 let eur: Currency = Currency::EUR; 475 let local: Currency = Currency::CHF; 476 let valid_amounts: Vec<(&str, &str, Amount)> = vec![ 477 ("EUR:4", "EUR:4", Amount::new(&eur, 4, 0)), // without fraction 478 ( 479 "EUR:0.02", 480 "EUR:0.02", 481 Amount::new(&eur, 0, TALER_AMOUNT_FRAC_BASE / 100 * 2), 482 ), // leading zero fraction 483 ( 484 " EUR:4.12", 485 "EUR:4.12", 486 Amount::new(&eur, 4, TALER_AMOUNT_FRAC_BASE / 100 * 12), 487 ), // leading space and fraction 488 ( 489 " CHF:4444.1000", 490 "CHF:4444.1", 491 Amount::new(&local, 4444, TALER_AMOUNT_FRAC_BASE / 10), 492 ), // local currency 493 ]; 494 for (raw, expected, goal) in valid_amounts { 495 let amount = Amount::from_str(raw); 496 assert!(amount.is_ok(), "Valid {} got {:?}", raw, amount); 497 assert_eq!( 498 *amount.as_ref().unwrap(), 499 goal, 500 "Expected {:?} got {:?} for {}", 501 goal, 502 amount, 503 raw 504 ); 505 let amount = amount.unwrap(); 506 let str = amount.to_string(); 507 assert_eq!(str, expected); 508 assert_eq!(amount, Amount::from_str(&str).unwrap(), "{str}"); 509 } 510 } 511 512 #[test] 513 fn test_amount_add() { 514 let eur: Currency = Currency::EUR; 515 assert_eq!( 516 Amount::max(&eur).try_add(&Amount::zero(&eur)), 517 Some(Amount::max(&eur)) 518 ); 519 assert_eq!( 520 Amount::zero(&eur).try_add(&Amount::zero(&eur)), 521 Some(Amount::zero(&eur)) 522 ); 523 assert_eq!( 524 amount("EUR:6.41").try_add(&amount("EUR:4.69")), 525 Some(amount("EUR:11.1")) 526 ); 527 assert_eq!( 528 amount(format!("EUR:{MAX_VALUE}")).try_add(&amount("EUR:0.99999999")), 529 Some(Amount::max(&eur)) 530 ); 531 532 assert_eq!( 533 amount(format!("EUR:{}", MAX_VALUE - 5)).try_add(&amount("EUR:6")), 534 None 535 ); 536 assert_eq!( 537 Amount::new(&eur, u64::MAX, 0).try_add(&amount("EUR:1")), 538 None 539 ); 540 assert_eq!( 541 amount(format!("EUR:{}.{}", MAX_VALUE - 5, FRAC_BASE - 1)) 542 .try_add(&amount("EUR:5.00000002")), 543 None 544 ); 545 } 546 547 #[test] 548 fn test_amount_normalize() { 549 let eur: Currency = "EUR".parse().unwrap(); 550 assert_eq!( 551 Amount::new(&eur, 4, 2 * FRAC_BASE).normalize(), 552 Some(amount("EUR:6")) 553 ); 554 assert_eq!( 555 Amount::new(&eur, 4, 2 * FRAC_BASE + 1).normalize(), 556 Some(amount("EUR:6.00000001")) 557 ); 558 assert_eq!( 559 Amount::new(&eur, MAX_VALUE, FRAC_BASE - 1).normalize(), 560 Some(Amount::new(&eur, MAX_VALUE, FRAC_BASE - 1)) 561 ); 562 assert_eq!(Amount::new(&eur, u64::MAX, FRAC_BASE).normalize(), None); 563 assert_eq!(Amount::new(&eur, MAX_VALUE, FRAC_BASE).normalize(), None); 564 565 for amount in [Amount::max(&eur), Amount::zero(&eur)] { 566 assert_eq!(amount.normalize(), Some(amount)) 567 } 568 }