taler-rust

GNU Taler code in Rust. Largely core banking integrations.
Log | Files | Refs | Submodules | README | LICENSE

commit 00a8009bd60e382c0de1cd054e9b21bfc8707001
parent 387ba0e5b8876decdaed97f2333f033f5707e52b
Author: Antoine A <>
Date:   Sat, 18 Apr 2026 13:55:05 +0200

common: constify amount

Diffstat:
Mcommon/taler-api/tests/security.rs | 1-
Mcommon/taler-common/src/config.rs | 6++----
Mcommon/taler-common/src/types/amount.rs | 70++++++++++++++++++++++++++++++++++++++++++++++++++--------------------
Mcommon/taler-common/src/types/utils.rs | 10++++++++--
Mtaler-cyclos/src/api.rs | 2+-
Mtaler-cyclos/src/db.rs | 36++++++++++++++++--------------------
Mtaler-magnet-bank/src/api.rs | 4++--
Mtaler-magnet-bank/src/constants.rs | 4+---
Mtaler-magnet-bank/src/db.rs | 18+++++++++---------
9 files changed, 89 insertions(+), 62 deletions(-)

diff --git a/common/taler-api/tests/security.rs b/common/taler-api/tests/security.rs @@ -34,7 +34,6 @@ mod common; #[tokio::test] async fn body_parsing() { let (server, _) = setup().await; - let eur: Currency = "EUR".parse().unwrap(); let normal_body = TransferRequest { request_uid: Base32::rand(), amount: Amount::zero(&eur), diff --git a/common/taler-common/src/config.rs b/common/taler-common/src/config.rs @@ -958,8 +958,7 @@ mod test { use std::{ fmt::{Debug, Display}, fs::{File, Permissions}, - os::unix::fs::PermissionsExt, - str::FromStr, + os::unix::fs::PermissionsExt }; use tracing::error; @@ -1231,10 +1230,9 @@ mod test { #[test] fn amount() { - let currency = Currency::from_str("KUDOS").unwrap(); routine( "amount", - |sect, value| sect.amount(value, &currency), + |sect, value| sect.amount(value, &Currency::KUDOS), &[( &["KUDOS:12", "KUDOS:12.0", "KUDOS:012.0"], amount::amount("KUDOS:12"), diff --git a/common/taler-common/src/types/amount.rs b/common/taler-common/src/types/amount.rs @@ -68,6 +68,33 @@ pub struct ParseCurrencyError { pub kind: CurrencyErrorKind, } +impl Currency { + pub const TEST: Self = Self::const_parse("TEST"); + pub const KUDOS: Self = Self::const_parse("KUDOS"); + pub const EUR: Self = Self::const_parse("EUR"); + pub const CHF: Self = Self::const_parse("CHF"); + pub const HUF: Self = Self::const_parse("HUF"); + + pub const fn const_parse(s: &str) -> Currency { + let bytes = s.as_bytes(); + let len = bytes.len(); + + if bytes.is_empty() { + panic!("empty") + } else if len > CURRENCY_LEN { + panic!("too big") + } + let mut i = 0; + while i < bytes.len() { + if !bytes[i].is_ascii_uppercase() { + panic!("invalid") + } + i += 1; + } + Self(InlineStr::copy_from_slice(bytes)) + } +} + impl FromStr for Currency { type Err = ParseCurrencyError; @@ -127,7 +154,7 @@ pub struct Decimal { } impl Decimal { - pub fn new(val: u64, frac: u32) -> Self { + pub const fn new(val: u64, frac: u32) -> Self { Self { val, frac } } @@ -142,8 +169,11 @@ impl Decimal { Self { val: 0, frac: 0 } } - fn normalize(mut self) -> Option<Self> { - self.val = self.val.checked_add((self.frac / FRAC_BASE) as u64)?; + const fn normalize(mut self) -> Option<Self> { + let Some(val) = self.val.checked_add((self.frac / FRAC_BASE) as u64) else { + return None; + }; + self.val = val; self.frac %= FRAC_BASE; if self.val > MAX_VALUE { return None; @@ -170,7 +200,7 @@ impl Decimal { self.normalize() } - pub fn to_amount(self, currency: &Currency) -> Amount { + pub const fn to_amount(self, currency: &Currency) -> Amount { Amount::new_decimal(currency, self) } } @@ -292,19 +322,23 @@ pub struct Amount { } impl Amount { - pub fn new_decimal(currency: &Currency, decimal: Decimal) -> Self { - (*currency, decimal).into() + pub const fn new_decimal(currency: &Currency, decimal: Decimal) -> Self { + Self { + currency: *currency, + val: decimal.val, + frac: decimal.frac, + } } - pub fn new(currency: &Currency, val: u64, frac: u32) -> Self { + pub const fn new(currency: &Currency, val: u64, frac: u32) -> Self { Self::new_decimal(currency, Decimal { val, frac }) } - pub fn max(currency: &Currency) -> Self { + pub const fn max(currency: &Currency) -> Self { Self::new_decimal(currency, Decimal::max()) } - pub fn zero(currency: &Currency) -> Self { + pub const fn zero(currency: &Currency) -> Self { Self::new_decimal(currency, Decimal::zero()) } @@ -313,7 +347,7 @@ impl Amount { } /* Check is amount has fractional amount < 0.01 */ - pub fn is_sub_cent(&self) -> bool { + pub const fn is_sub_cent(&self) -> bool { !self.frac.is_multiple_of(CENT_FRACTION) } @@ -344,11 +378,7 @@ impl Amount { impl From<(Currency, Decimal)> for Amount { fn from((currency, decimal): (Currency, Decimal)) -> Self { - Self { - currency, - val: decimal.val, - frac: decimal.frac, - } + Self::new_decimal(&currency, decimal) } } @@ -438,8 +468,8 @@ fn test_amount_parse() { assert!(amount.is_err(), "invalid {} got {:?}", str, &amount); } - let eur: Currency = "EUR".parse().unwrap(); - let local: Currency = "LOCAL".parse().unwrap(); + let eur: Currency = Currency::EUR; + let local: Currency = Currency::CHF; let valid_amounts: Vec<(&str, &str, Amount)> = vec![ ("EUR:4", "EUR:4", Amount::new(&eur, 4, 0)), // without fraction ( @@ -453,8 +483,8 @@ fn test_amount_parse() { Amount::new(&eur, 4, TALER_AMOUNT_FRAC_BASE / 100 * 12), ), // leading space and fraction ( - " LOCAL:4444.1000", - "LOCAL:4444.1", + " CHF:4444.1000", + "CHF:4444.1", Amount::new(&local, 4444, TALER_AMOUNT_FRAC_BASE / 10), ), // local currency ]; @@ -478,7 +508,7 @@ fn test_amount_parse() { #[test] fn test_amount_add() { - let eur: Currency = "EUR".parse().unwrap(); + let eur: Currency = Currency::EUR; assert_eq!( Amount::max(&eur).try_add(&Amount::zero(&eur)), Some(Amount::max(&eur)) diff --git a/common/taler-common/src/types/utils.rs b/common/taler-common/src/types/utils.rs @@ -34,10 +34,16 @@ pub struct InlineStr<const LEN: usize> { impl<const LEN: usize> InlineStr<LEN> { /// Create an inlined string from a slice #[inline] - pub fn copy_from_slice(slice: &[u8]) -> Self { + pub const fn copy_from_slice(slice: &[u8]) -> Self { let len = slice.len(); let mut buf = [0; LEN]; - buf[..len].copy_from_slice(slice); + // TODO use buf[..len].copy_from_slice(slice); once allowed in const + let mut i = 0; + while i < len { + buf[i] = slice[i]; + i += 1; + } + debug_assert!(buf.is_ascii()); Self { len: len as u8, diff --git a/taler-cyclos/src/api.rs b/taler-cyclos/src/api.rs @@ -376,7 +376,7 @@ mod test { pool.clone(), CompactString::const_new("localhost"), ACCOUNT.clone(), - Currency::from_str("TEST").unwrap(), + Currency::TEST, )); let server = Router::new() .wire_gateway(api.clone(), AuthMethod::None) diff --git a/taler-cyclos/src/db.rs b/taler-cyclos/src/db.rs @@ -911,8 +911,6 @@ impl CyclosTypeHelper for PgRow { #[cfg(test)] mod test { - use std::sync::LazyLock; - use compact_str::CompactString; use jiff::{Span, Timestamp}; use serde_json::json; @@ -941,7 +939,7 @@ mod test { }, }; - pub static CURRENCY: LazyLock<Currency> = LazyLock::new(|| "TEST".parse().unwrap()); + pub const CURR: Currency = Currency::TEST; pub const ROOT: CompactString = CompactString::const_new("localhost"); async fn setup() -> (PoolConnection<Postgres>, PgPool) { @@ -1052,13 +1050,13 @@ mod test { // Empty db assert_eq!( - db::revenue_history(&pool, &History::default(), &CURRENCY, &ROOT, dummy_listen) + db::revenue_history(&pool, &History::default(), &CURR, &ROOT, dummy_listen) .await .unwrap(), Vec::new() ); assert_eq!( - db::incoming_history(&pool, &History::default(), &CURRENCY, &ROOT, dummy_listen) + db::incoming_history(&pool, &History::default(), &CURR, &ROOT, dummy_listen) .await .unwrap(), Vec::new() @@ -1083,14 +1081,14 @@ mod test { // History assert_eq!( - db::revenue_history(&pool, &History::default(), &CURRENCY, &ROOT, dummy_listen) + db::revenue_history(&pool, &History::default(), &CURR, &ROOT, dummy_listen) .await .unwrap() .len(), 6 ); assert_eq!( - db::incoming_history(&pool, &History::default(), &CURRENCY, &ROOT, dummy_listen) + db::incoming_history(&pool, &History::default(), &CURR, &ROOT, dummy_listen) .await .unwrap() .len(), @@ -1104,7 +1102,7 @@ mod test { // Empty db assert_eq!( - db::incoming_history(&pool, &History::default(), &CURRENCY, &ROOT, dummy_listen) + db::incoming_history(&pool, &History::default(), &CURR, &ROOT, dummy_listen) .await .unwrap(), Vec::new() @@ -1154,7 +1152,7 @@ mod test { // History assert_eq!( - db::incoming_history(&pool, &History::default(), &CURRENCY, &ROOT, dummy_listen) + db::incoming_history(&pool, &History::default(), &CURR, &ROOT, dummy_listen) .await .unwrap() .len(), @@ -1257,7 +1255,7 @@ mod test { // Empty db assert_eq!( - db::outgoing_history(&pool, &History::default(), &CURRENCY, &ROOT, dummy_listen) + db::outgoing_history(&pool, &History::default(), &CURR, &ROOT, dummy_listen) .await .unwrap(), Vec::new() @@ -1278,7 +1276,7 @@ mod test { // History assert_eq!( - db::outgoing_history(&pool, &History::default(), &CURRENCY, &ROOT, dummy_listen) + db::outgoing_history(&pool, &History::default(), &CURR, &ROOT, dummy_listen) .await .unwrap() .len(), @@ -1294,13 +1292,11 @@ mod test { // Empty db assert_eq!( - db::transfer_by_id(&pool, 0, &CURRENCY, &ROOT) - .await - .unwrap(), + db::transfer_by_id(&pool, 0, &CURR, &ROOT).await.unwrap(), None ); assert_eq!( - db::transfer_page(&pool, &None, &CURRENCY, &ROOT, &Page::default()) + db::transfer_page(&pool, &None, &CURR, &ROOT, &Page::default()) .await .unwrap(), Vec::new() @@ -1386,25 +1382,25 @@ mod test { // Get assert!( - db::transfer_by_id(&pool, 1, &CURRENCY, &ROOT) + db::transfer_by_id(&pool, 1, &CURR, &ROOT) .await .unwrap() .is_some() ); assert!( - db::transfer_by_id(&pool, 2, &CURRENCY, &ROOT) + db::transfer_by_id(&pool, 2, &CURR, &ROOT) .await .unwrap() .is_some() ); assert!( - db::transfer_by_id(&pool, 3, &CURRENCY, &ROOT) + db::transfer_by_id(&pool, 3, &CURR, &ROOT) .await .unwrap() .is_none() ); assert_eq!( - db::transfer_page(&pool, &None, &CURRENCY, &ROOT, &Page::default()) + db::transfer_page(&pool, &None, &CURR, &ROOT, &Page::default()) .await .unwrap() .len(), @@ -1499,7 +1495,7 @@ mod test { let (mut db, pool) = setup().await; let check_status = async |id: u64, status: TransferState, msg: Option<&str>| { - let transfer = db::transfer_by_id(&pool, id, &CURRENCY, &ROOT) + let transfer = db::transfer_by_id(&pool, id, &CURR, &ROOT) .await .unwrap() .unwrap(); diff --git a/taler-magnet-bank/src/api.rs b/taler-magnet-bank/src/api.rs @@ -40,7 +40,7 @@ use tokio::sync::watch::Sender; use crate::{ FullHuPayto, - constants::CURRENCY, + constants::CURR, db::{self, AddIncomingResult, Transfer, TxInAdmin}, }; @@ -80,7 +80,7 @@ impl MagnetApi { impl TalerApi for MagnetApi { fn currency(&self) -> &str { - CURRENCY.as_ref() + CURR.as_ref() } fn implementation(&self) -> &'static str { diff --git a/taler-magnet-bank/src/constants.rs b/taler-magnet-bank/src/constants.rs @@ -14,12 +14,10 @@ TALER; see the file COPYING. If not, see <http://www.gnu.org/licenses/> */ -use std::sync::LazyLock; - use aws_lc_rs::signature::{ECDSA_P256_SHA256_ASN1_SIGNING, EcdsaSigningAlgorithm}; use taler_common::{config::parser::ConfigSource, types::amount::Currency}; -pub static CURRENCY: LazyLock<Currency> = LazyLock::new(|| "HUF".parse().unwrap()); +pub const CURR: Currency = Currency::HUF; pub const MAX_MAGNET_BBAN_SIZE: usize = 24; pub const CONFIG_SOURCE: ConfigSource = ConfigSource::new("taler-magnet-bank", "magnet-bank", "taler-magnet-bank"); diff --git a/taler-magnet-bank/src/db.rs b/taler-magnet-bank/src/db.rs @@ -44,7 +44,7 @@ use taler_common::{ use tokio::sync::watch::{Receiver, Sender}; use url::Url; -use crate::{FullHuPayto, config::parse_db_cfg, constants::CURRENCY, magnet_api::types::TxStatus}; +use crate::{FullHuPayto, config::parse_db_cfg, constants::CURR, magnet_api::types::TxStatus}; const SCHEMA: &str = "magnet_bank"; @@ -512,7 +512,7 @@ pub async fn transfer_page( Ok(TransferListStatus { row_id: r.try_get_safeu64(0)?, status: r.try_get(1)?, - amount: r.try_get_amount(2, &CURRENCY)?, + amount: r.try_get_amount(2, &CURR)?, credit_account: r.try_get_iban(3)?.as_full_payto(r.try_get(4)?), timestamp: r.try_get_timestamp(5)?.into(), }) @@ -552,7 +552,7 @@ pub async fn outgoing_history( |r: PgRow| { Ok(OutgoingBankTransaction { row_id: r.try_get_safeu64(0)?, - amount: r.try_get_amount(1, &CURRENCY)?, + amount: r.try_get_amount(1, &CURR)?, debit_fee: None, credit_account: r.try_get_iban(2)?.as_full_payto(r.try_get(3)?), date: r.try_get_timestamp(4)?.into(), @@ -598,7 +598,7 @@ pub async fn incoming_history( Ok(match r.try_get(0)? { IncomingType::reserve => IncomingBankTransaction::Reserve { row_id: r.try_get_safeu64(1)?, - amount: r.try_get_amount(2, &CURRENCY)?, + amount: r.try_get_amount(2, &CURR)?, credit_fee: None, debit_account: r.try_get_iban(3)?.as_full_payto(r.try_get(4)?), date: r.try_get_timestamp(5)?.into(), @@ -608,7 +608,7 @@ pub async fn incoming_history( }, IncomingType::kyc => IncomingBankTransaction::Kyc { row_id: r.try_get_safeu64(1)?, - amount: r.try_get_amount(2, &CURRENCY)?, + amount: r.try_get_amount(2, &CURR)?, credit_fee: None, debit_account: r.try_get_iban(3)?.as_full_payto(r.try_get(4)?), date: r.try_get_timestamp(5)?.into(), @@ -652,7 +652,7 @@ pub async fn revenue_history( Ok(RevenueIncomingBankTransaction { row_id: r.try_get_safeu64(0)?, date: r.try_get_timestamp(1)?.into(), - amount: r.try_get_amount(2, &CURRENCY)?, + amount: r.try_get_amount(2, &CURR)?, credit_fee: None, debit_account: r.try_get_iban(3)?.as_full_payto(r.try_get(4)?), subject: r.try_get(5)?, @@ -686,7 +686,7 @@ pub async fn transfer_by_id(db: &PgPool, id: u64) -> sqlx::Result<Option<Transfe Ok(TransferStatus { status: r.try_get(0)?, status_msg: r.try_get(1)?, - amount: r.try_get_amount(2, &CURRENCY)?, + amount: r.try_get_amount(2, &CURR)?, origin_exchange_url: r.try_get(3)?, metadata: r.try_get(4)?, wtid: r.try_get(5)?, @@ -718,7 +718,7 @@ pub async fn pending_batch( .try_map(|r: PgRow| { Ok(Initiated { id: r.try_get_u64(0)?, - amount: r.try_get_amount(1, &CURRENCY)?, + amount: r.try_get_amount(1, &CURR)?, subject: r.try_get(2)?, creditor: FullHuPayto::new(r.try_get_parse(3)?, r.try_get(4)?), }) @@ -744,7 +744,7 @@ pub async fn initiated_by_code( .try_map(|r: PgRow| { Ok(Initiated { id: r.try_get_u64(0)?, - amount: r.try_get_amount(1, &CURRENCY)?, + amount: r.try_get_amount(1, &CURR)?, subject: r.try_get(2)?, creditor: FullHuPayto::new(r.try_get_parse(3)?, r.try_get(4)?), })