diff options
author | Marcello Stanisci <stanisci.m@gmail.com> | 2017-12-06 17:51:55 +0100 |
---|---|---|
committer | Marcello Stanisci <stanisci.m@gmail.com> | 2017-12-06 17:51:55 +0100 |
commit | 4efff7788f052cfdd2949af9198eb2c18a4cd4e8 (patch) | |
tree | 04cf19060f3c3d029af5290e2165b0b33432c695 | |
parent | e10115be96dfcbbfae2c0d004e5498a224ba9a3b (diff) | |
download | bank-4efff7788f052cfdd2949af9198eb2c18a4cd4e8.tar.gz bank-4efff7788f052cfdd2949af9198eb2c18a4cd4e8.tar.bz2 bank-4efff7788f052cfdd2949af9198eb2c18a4cd4e8.zip |
annotating types for config parser
-rw-r--r-- | talerbank/__init__.py | 8 | ||||
-rw-r--r-- | talerbank/app/amount.py | 29 | ||||
-rw-r--r-- | talerbank/app/models.py | 25 | ||||
-rw-r--r-- | talerbank/app/types.py | 4 | ||||
-rw-r--r-- | talerbank/app/views.py | 2 | ||||
-rw-r--r-- | talerbank/talerconfig.py | 56 |
6 files changed, 58 insertions, 66 deletions
diff --git a/talerbank/__init__.py b/talerbank/__init__.py index ca35fb3..6c4125c 100644 --- a/talerbank/__init__.py +++ b/talerbank/__init__.py @@ -1,8 +1,4 @@ import logging -LOG_CONF = { - 'format': '%(asctime)-15s %(module)s %(levelname)s %(message)s', - 'level': logging.WARNING -} - -logging.basicConfig(**LOG_CONF) +FMT = '%(asctime)-15s %(module)s %(levelname)s %(message)s' +logging.basicConfig(format=FMT, level=logging.WARNING) diff --git a/talerbank/app/amount.py b/talerbank/app/amount.py index 46e3446..45e306e 100644 --- a/talerbank/app/amount.py +++ b/talerbank/app/amount.py @@ -22,13 +22,15 @@ # mentioned above, and it is meant to be manually copied into any project # which might need it. +from typing import Type + class CurrencyMismatch(Exception): - def __init__(self, curr1, curr2): + def __init__(self, curr1, curr2) -> None: super(CurrencyMismatch, self).__init__( "%s vs %s" % (curr1, curr2)) class BadFormatAmount(Exception): - def __init__(self, faulty_str): + def __init__(self, faulty_str) -> None: super(BadFormatAmount, self).__init__( "Bad format amount: " + faulty_str) @@ -36,15 +38,14 @@ class Amount: # How many "fraction" units make one "value" unit of currency # (Taler requires 10^8). Do not change this 'constant'. @staticmethod - def _fraction(): + def _fraction() -> int: return 10 ** 8 @staticmethod - def _max_value(): + def _max_value() -> int: return (2 ** 53) - 1 - def __init__(self, currency, value=0, fraction=0): - # type: (str, int, int) -> Amount + def __init__(self, currency, value=0, fraction=0) -> None: assert value >= 0 and fraction >= 0 self.value = value self.fraction = fraction @@ -53,7 +54,7 @@ class Amount: assert self.value <= Amount._max_value() # Normalize amount - def __normalize(self): + def __normalize(self) -> None: if self.fraction >= Amount._fraction(): self.value += int(self.fraction / Amount._fraction()) self.fraction = self.fraction % Amount._fraction() @@ -61,7 +62,7 @@ class Amount: # Parse a string matching the format "A:B.C" # instantiating an amount object. @classmethod - def parse(cls, amount_str): + def parse(cls: Type[Amount], amount_str: str) -> Amount: exp = r'^\s*([-_*A-Za-z0-9]+):([0-9]+)\.([0-9]+)\s*$' import re parsed = re.search(exp, amount_str) @@ -78,7 +79,7 @@ class Amount: # 0 if a == b # 1 if a > b @staticmethod - def cmp(am1, am2): + def cmp(am1: Amount, am2: Amount) -> int: if am1.currency != am2.currency: raise CurrencyMismatch(am1.currency, am2.currency) if am1.value == am2.value: @@ -91,13 +92,13 @@ class Amount: return -1 return 1 - def set(self, currency, value=0, fraction=0): + def set(self, currency: str, value=0, fraction=0) -> None: self.currency = currency self.value = value self.fraction = fraction # Add the given amount to this one - def add(self, amount): + def add(self, amount: Amount) -> None: if self.currency != amount.currency: raise CurrencyMismatch(self.currency, amount.currency) self.value += amount.value @@ -105,7 +106,7 @@ class Amount: self.__normalize() # Subtract passed amount from this one - def subtract(self, amount): + def subtract(self, amount: Amount) -> None: if self.currency != amount.currency: raise CurrencyMismatch(self.currency, amount.currency) if self.fraction < amount.fraction: @@ -118,7 +119,7 @@ class Amount: # Dump string from this amount, will put 'ndigits' numbers # after the dot. - def stringify(self, ndigits): + def stringify(self, ndigits: int) -> str: assert ndigits > 0 ret = '%s:%s.' % (self.currency, str(self.value)) fraction = self.fraction @@ -129,7 +130,7 @@ class Amount: return ret # Dump the Taler-compliant 'dict' amount - def dump(self): + def dump(self) -> dict: return dict(value=self.value, fraction=self.fraction, currency=self.currency) diff --git a/talerbank/app/models.py b/talerbank/app/models.py index 97da159..b0d0ca4 100644 --- a/talerbank/app/models.py +++ b/talerbank/app/models.py @@ -21,8 +21,7 @@ from django.contrib.auth.models import User from django.db import models from django.conf import settings from django.core.exceptions import ValidationError -from . import amount -from .types import TA +from .amount import Amount, BadFormatAmount class AmountField(models.Field): @@ -36,30 +35,30 @@ class AmountField(models.Field): return "varchar" # Pass stringified object to db connector - def get_prep_value(self, value: TA) -> str: + def get_prep_value(self, value: Amount) -> str: if not value: return "%s:0.0" % settings.TALER_CURRENCY return value.stringify(settings.TALER_DIGITS) @staticmethod - def from_db_value(value: str, *args) -> TA: + def from_db_value(value: str, *args) -> Amount: del args # pacify PEP checkers if value is None: - return amount.Amount.parse(settings.TALER_CURRENCY) - return amount.Amount.parse(value) + return Amount.parse(settings.TALER_CURRENCY) + return Amount.parse(value) - def to_python(self, value: Any) -> TA: - if isinstance(value, amount.Amount): + def to_python(self, value: Any) -> Amount: + if isinstance(value, Amount): return value try: if value is None: - return amount.Amount.parse(settings.TALER_CURRENCY) - return amount.Amount.parse(value) - except amount.BadFormatAmount: + return Amount.parse(settings.TALER_CURRENCY) + return Amount.parse(value) + except BadFormatAmount: raise ValidationError("Invalid input for an amount string: %s" % value) -def get_zero_amount() -> TA: - return amount.Amount(settings.TALER_CURRENCY) +def get_zero_amount() -> Amount: + return Amount(settings.TALER_CURRENCY) class BankAccount(models.Model): is_public = models.BooleanField(default=False) diff --git a/talerbank/app/types.py b/talerbank/app/types.py deleted file mode 100644 index d103340..0000000 --- a/talerbank/app/types.py +++ /dev/null @@ -1,4 +0,0 @@ -from typing import TypeVar -from .amount import Amount - -TA = TypeVar('TA', Amount) diff --git a/talerbank/app/views.py b/talerbank/app/views.py index 8e2e452..a18c684 100644 --- a/talerbank/app/views.py +++ b/talerbank/app/views.py @@ -49,7 +49,7 @@ from .schemas import (validate_pin_tan_args, check_withdraw_session, LOGGER = logging.getLogger(__name__) class DebtLimitExceededException(Exception): - def __init__(self): + def __init__(self) -> None: super().__init__("Debt limit exceeded") class SameAccountException(Exception): diff --git a/talerbank/talerconfig.py b/talerbank/talerconfig.py index a7ca065..41ebf44 100644 --- a/talerbank/talerconfig.py +++ b/talerbank/talerconfig.py @@ -24,6 +24,7 @@ import os import weakref import sys import re +from typing import Callable, Any LOGGER = logging.getLogger(__name__) @@ -44,7 +45,7 @@ class ExpansionSyntaxError(Exception): pass -def expand(var, getter): +def expand(var: str, getter: Callable[[str], str]) -> str: """ Do shell-style parameter expansion. Supported syntax: @@ -93,17 +94,17 @@ def expand(var, getter): class OptionDict(collections.defaultdict): - def __init__(self, config, section_name): + def __init__(self, config: SectionDict, section_name: str) -> None: self.config = weakref.ref(config) self.section_name = section_name super().__init__() - def __missing__(self, key): + def __missing__(self, key: str) -> Entry: entry = Entry(self.config(), self.section_name, key) self[key] = entry return entry - def __getitem__(self, chunk): + def __getitem__(self, chunk: str) -> Entry: return super().__getitem__(chunk.lower()) - def __setitem__(self, chunk, value): + def __setitem__(self, chunk: str, value: Entry) -> None: super().__setitem__(chunk.lower(), value) @@ -112,14 +113,13 @@ class SectionDict(collections.defaultdict): value = OptionDict(self, key) self[key] = value return value - def __getitem__(self, chunk): + def __getitem__(self, chunk: str) -> OptionDict: return super().__getitem__(chunk.lower()) - def __setitem__(self, chunk, value): + def __setitem__(self, chunk: str, value: OptionDict) -> None: super().__setitem__(chunk.lower(), value) - class Entry: - def __init__(self, config, section, option, **kwargs): + def __init__(self, config: SectionDict, section: str, option: str, **kwargs) -> None: self.value = kwargs.get("value") self.filename = kwargs.get("filename") self.lineno = kwargs.get("lineno") @@ -127,14 +127,14 @@ class Entry: self.option = option self.config = weakref.ref(config) - def __repr__(self): + def __repr__(self) -> str: return "<Entry section=%s, option=%s, value=%s>" \ % (self.section, self.option, repr(self.value),) - def __str__(self): + def __str__(self) -> Any: return self.value - def value_string(self, default=None, required=False, warn=False): + def value_string(self, default=None, required=False, warn=False) -> str: if required and self.value is None: raise ConfigurationError("Missing required option '%s' in section '%s'" \ % (self.option.upper(), self.section.upper())) @@ -149,7 +149,7 @@ class Entry: return default return self.value - def value_int(self, default=None, required=False, warn=False): + def value_int(self, default=None, required=False, warn=False) -> int: value = self.value_string(default, warn, required) if value is None: return None @@ -159,7 +159,7 @@ class Entry: raise ConfigurationError("Expected number for option '%s' in section '%s'" \ % (self.option.upper(), self.section.upper())) - def _getsubst(self, key): + def _getsubst(self, key: str) -> Any: value = self.config()["paths"][key].value if value is not None: return value @@ -168,13 +168,13 @@ class Entry: return value return None - def value_filename(self, default=None, required=False, warn=False): + def value_filename(self, default=None, required=False, warn=False) -> str: value = self.value_string(default, required, warn) if value is None: return None return expand(value, self._getsubst) - def location(self): + def location(self) -> str: if self.filename is None or self.lineno is None: return "<unknown>" return "%s:%s" % (self.filename, self.lineno) @@ -185,16 +185,16 @@ class TalerConfig: One loaded taler configuration, including base configuration files and included files. """ - def __init__(self): + def __init__(self) -> None: """ Initialize an empty configuration """ - self.sections = SectionDict() + self.sections = SectionDict() # just plain dict # defaults != config file: the first is the 'base' # whereas the second overrides things from the first. @staticmethod - def from_file(filename=None, load_defaults=True): + def from_file(filename=None, load_defaults=True) -> TalerConfig: cfg = TalerConfig() if filename is None: xdg = os.environ.get("XDG_CONFIG_HOME") @@ -207,19 +207,19 @@ class TalerConfig: cfg.load_file(filename) return cfg - def value_string(self, section, option, **kwargs): + def value_string(self, section, option, **kwargs) -> str: return self.sections[section][option].value_string( kwargs.get("default"), kwargs.get("required"), kwargs.get("warn")) - def value_filename(self, section, option, **kwargs): + def value_filename(self, section, option, **kwargs) -> str: return self.sections[section][option].value_filename( kwargs.get("default"), kwargs.get("required"), kwargs.get("warn")) - def value_int(self, section, option, **kwargs): + def value_int(self, section, option, **kwargs) -> int: return self.sections[section][option].value_int( kwargs.get("default"), kwargs.get("required"), kwargs.get("warn")) - def load_defaults(self): + def load_defaults(self) -> None: base_dir = os.environ.get("TALER_BASE_CONFIG") if base_dir: self.load_dir(base_dir) @@ -237,7 +237,7 @@ class TalerConfig: LOGGER.warning("no base directory found") @staticmethod - def from_env(*args, **kwargs): + def from_env(*args, **kwargs) -> TalerConfig: """ Load configuration from environment variable TALER_CONFIG_FILE or from default location if the variable is not set. @@ -245,7 +245,7 @@ class TalerConfig: filename = os.environ.get("TALER_CONFIG_FILE") return TalerConfig.from_file(filename, *args, **kwargs) - def load_dir(self, dirname): + def load_dir(self, dirname) -> None: try: files = os.listdir(dirname) except FileNotFoundError: @@ -256,7 +256,7 @@ class TalerConfig: continue self.load_file(os.path.join(dirname, file)) - def load_file(self, filename): + def load_file(self, filename) -> None: sections = self.sections try: with open(filename, "r") as file: @@ -300,7 +300,7 @@ class TalerConfig: sys.exit(3) - def dump(self): + def dump(self) -> None: for kv_section in self.sections.items(): print("[%s]" % (kv_section[1].section_name,)) for kv_option in kv_section[1].items(): @@ -309,7 +309,7 @@ class TalerConfig: kv_option[1].value, kv_option[1].location())) - def __getitem__(self, chunk): + def __getitem__(self, chunk: str) -> OptionDict: if isinstance(chunk, str): return self.sections[chunk] raise TypeError("index must be string") |