summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMarcello Stanisci <stanisci.m@gmail.com>2017-12-06 17:51:55 +0100
committerMarcello Stanisci <stanisci.m@gmail.com>2017-12-06 17:51:55 +0100
commit4efff7788f052cfdd2949af9198eb2c18a4cd4e8 (patch)
tree04cf19060f3c3d029af5290e2165b0b33432c695
parente10115be96dfcbbfae2c0d004e5498a224ba9a3b (diff)
downloadbank-4efff7788f052cfdd2949af9198eb2c18a4cd4e8.tar.gz
bank-4efff7788f052cfdd2949af9198eb2c18a4cd4e8.tar.bz2
bank-4efff7788f052cfdd2949af9198eb2c18a4cd4e8.zip
annotating types for config parser
-rw-r--r--talerbank/__init__.py8
-rw-r--r--talerbank/app/amount.py29
-rw-r--r--talerbank/app/models.py25
-rw-r--r--talerbank/app/types.py4
-rw-r--r--talerbank/app/views.py2
-rw-r--r--talerbank/talerconfig.py56
6 files changed, 58 insertions, 66 deletions
diff --git a/talerbank/__init__.py b/talerbank/__init__.py
index ca35fb3..6c4125c 100644
--- a/talerbank/__init__.py
+++ b/talerbank/__init__.py
@@ -1,8 +1,4 @@
import logging
-LOG_CONF = {
- 'format': '%(asctime)-15s %(module)s %(levelname)s %(message)s',
- 'level': logging.WARNING
-}
-
-logging.basicConfig(**LOG_CONF)
+FMT = '%(asctime)-15s %(module)s %(levelname)s %(message)s'
+logging.basicConfig(format=FMT, level=logging.WARNING)
diff --git a/talerbank/app/amount.py b/talerbank/app/amount.py
index 46e3446..45e306e 100644
--- a/talerbank/app/amount.py
+++ b/talerbank/app/amount.py
@@ -22,13 +22,15 @@
# mentioned above, and it is meant to be manually copied into any project
# which might need it.
+from typing import Type
+
class CurrencyMismatch(Exception):
- def __init__(self, curr1, curr2):
+ def __init__(self, curr1, curr2) -> None:
super(CurrencyMismatch, self).__init__(
"%s vs %s" % (curr1, curr2))
class BadFormatAmount(Exception):
- def __init__(self, faulty_str):
+ def __init__(self, faulty_str) -> None:
super(BadFormatAmount, self).__init__(
"Bad format amount: " + faulty_str)
@@ -36,15 +38,14 @@ class Amount:
# How many "fraction" units make one "value" unit of currency
# (Taler requires 10^8). Do not change this 'constant'.
@staticmethod
- def _fraction():
+ def _fraction() -> int:
return 10 ** 8
@staticmethod
- def _max_value():
+ def _max_value() -> int:
return (2 ** 53) - 1
- def __init__(self, currency, value=0, fraction=0):
- # type: (str, int, int) -> Amount
+ def __init__(self, currency, value=0, fraction=0) -> None:
assert value >= 0 and fraction >= 0
self.value = value
self.fraction = fraction
@@ -53,7 +54,7 @@ class Amount:
assert self.value <= Amount._max_value()
# Normalize amount
- def __normalize(self):
+ def __normalize(self) -> None:
if self.fraction >= Amount._fraction():
self.value += int(self.fraction / Amount._fraction())
self.fraction = self.fraction % Amount._fraction()
@@ -61,7 +62,7 @@ class Amount:
# Parse a string matching the format "A:B.C"
# instantiating an amount object.
@classmethod
- def parse(cls, amount_str):
+ def parse(cls: Type[Amount], amount_str: str) -> Amount:
exp = r'^\s*([-_*A-Za-z0-9]+):([0-9]+)\.([0-9]+)\s*$'
import re
parsed = re.search(exp, amount_str)
@@ -78,7 +79,7 @@ class Amount:
# 0 if a == b
# 1 if a > b
@staticmethod
- def cmp(am1, am2):
+ def cmp(am1: Amount, am2: Amount) -> int:
if am1.currency != am2.currency:
raise CurrencyMismatch(am1.currency, am2.currency)
if am1.value == am2.value:
@@ -91,13 +92,13 @@ class Amount:
return -1
return 1
- def set(self, currency, value=0, fraction=0):
+ def set(self, currency: str, value=0, fraction=0) -> None:
self.currency = currency
self.value = value
self.fraction = fraction
# Add the given amount to this one
- def add(self, amount):
+ def add(self, amount: Amount) -> None:
if self.currency != amount.currency:
raise CurrencyMismatch(self.currency, amount.currency)
self.value += amount.value
@@ -105,7 +106,7 @@ class Amount:
self.__normalize()
# Subtract passed amount from this one
- def subtract(self, amount):
+ def subtract(self, amount: Amount) -> None:
if self.currency != amount.currency:
raise CurrencyMismatch(self.currency, amount.currency)
if self.fraction < amount.fraction:
@@ -118,7 +119,7 @@ class Amount:
# Dump string from this amount, will put 'ndigits' numbers
# after the dot.
- def stringify(self, ndigits):
+ def stringify(self, ndigits: int) -> str:
assert ndigits > 0
ret = '%s:%s.' % (self.currency, str(self.value))
fraction = self.fraction
@@ -129,7 +130,7 @@ class Amount:
return ret
# Dump the Taler-compliant 'dict' amount
- def dump(self):
+ def dump(self) -> dict:
return dict(value=self.value,
fraction=self.fraction,
currency=self.currency)
diff --git a/talerbank/app/models.py b/talerbank/app/models.py
index 97da159..b0d0ca4 100644
--- a/talerbank/app/models.py
+++ b/talerbank/app/models.py
@@ -21,8 +21,7 @@ from django.contrib.auth.models import User
from django.db import models
from django.conf import settings
from django.core.exceptions import ValidationError
-from . import amount
-from .types import TA
+from .amount import Amount, BadFormatAmount
class AmountField(models.Field):
@@ -36,30 +35,30 @@ class AmountField(models.Field):
return "varchar"
# Pass stringified object to db connector
- def get_prep_value(self, value: TA) -> str:
+ def get_prep_value(self, value: Amount) -> str:
if not value:
return "%s:0.0" % settings.TALER_CURRENCY
return value.stringify(settings.TALER_DIGITS)
@staticmethod
- def from_db_value(value: str, *args) -> TA:
+ def from_db_value(value: str, *args) -> Amount:
del args # pacify PEP checkers
if value is None:
- return amount.Amount.parse(settings.TALER_CURRENCY)
- return amount.Amount.parse(value)
+ return Amount.parse(settings.TALER_CURRENCY)
+ return Amount.parse(value)
- def to_python(self, value: Any) -> TA:
- if isinstance(value, amount.Amount):
+ def to_python(self, value: Any) -> Amount:
+ if isinstance(value, Amount):
return value
try:
if value is None:
- return amount.Amount.parse(settings.TALER_CURRENCY)
- return amount.Amount.parse(value)
- except amount.BadFormatAmount:
+ return Amount.parse(settings.TALER_CURRENCY)
+ return Amount.parse(value)
+ except BadFormatAmount:
raise ValidationError("Invalid input for an amount string: %s" % value)
-def get_zero_amount() -> TA:
- return amount.Amount(settings.TALER_CURRENCY)
+def get_zero_amount() -> Amount:
+ return Amount(settings.TALER_CURRENCY)
class BankAccount(models.Model):
is_public = models.BooleanField(default=False)
diff --git a/talerbank/app/types.py b/talerbank/app/types.py
deleted file mode 100644
index d103340..0000000
--- a/talerbank/app/types.py
+++ /dev/null
@@ -1,4 +0,0 @@
-from typing import TypeVar
-from .amount import Amount
-
-TA = TypeVar('TA', Amount)
diff --git a/talerbank/app/views.py b/talerbank/app/views.py
index 8e2e452..a18c684 100644
--- a/talerbank/app/views.py
+++ b/talerbank/app/views.py
@@ -49,7 +49,7 @@ from .schemas import (validate_pin_tan_args, check_withdraw_session,
LOGGER = logging.getLogger(__name__)
class DebtLimitExceededException(Exception):
- def __init__(self):
+ def __init__(self) -> None:
super().__init__("Debt limit exceeded")
class SameAccountException(Exception):
diff --git a/talerbank/talerconfig.py b/talerbank/talerconfig.py
index a7ca065..41ebf44 100644
--- a/talerbank/talerconfig.py
+++ b/talerbank/talerconfig.py
@@ -24,6 +24,7 @@ import os
import weakref
import sys
import re
+from typing import Callable, Any
LOGGER = logging.getLogger(__name__)
@@ -44,7 +45,7 @@ class ExpansionSyntaxError(Exception):
pass
-def expand(var, getter):
+def expand(var: str, getter: Callable[[str], str]) -> str:
"""
Do shell-style parameter expansion.
Supported syntax:
@@ -93,17 +94,17 @@ def expand(var, getter):
class OptionDict(collections.defaultdict):
- def __init__(self, config, section_name):
+ def __init__(self, config: SectionDict, section_name: str) -> None:
self.config = weakref.ref(config)
self.section_name = section_name
super().__init__()
- def __missing__(self, key):
+ def __missing__(self, key: str) -> Entry:
entry = Entry(self.config(), self.section_name, key)
self[key] = entry
return entry
- def __getitem__(self, chunk):
+ def __getitem__(self, chunk: str) -> Entry:
return super().__getitem__(chunk.lower())
- def __setitem__(self, chunk, value):
+ def __setitem__(self, chunk: str, value: Entry) -> None:
super().__setitem__(chunk.lower(), value)
@@ -112,14 +113,13 @@ class SectionDict(collections.defaultdict):
value = OptionDict(self, key)
self[key] = value
return value
- def __getitem__(self, chunk):
+ def __getitem__(self, chunk: str) -> OptionDict:
return super().__getitem__(chunk.lower())
- def __setitem__(self, chunk, value):
+ def __setitem__(self, chunk: str, value: OptionDict) -> None:
super().__setitem__(chunk.lower(), value)
-
class Entry:
- def __init__(self, config, section, option, **kwargs):
+ def __init__(self, config: SectionDict, section: str, option: str, **kwargs) -> None:
self.value = kwargs.get("value")
self.filename = kwargs.get("filename")
self.lineno = kwargs.get("lineno")
@@ -127,14 +127,14 @@ class Entry:
self.option = option
self.config = weakref.ref(config)
- def __repr__(self):
+ def __repr__(self) -> str:
return "<Entry section=%s, option=%s, value=%s>" \
% (self.section, self.option, repr(self.value),)
- def __str__(self):
+ def __str__(self) -> Any:
return self.value
- def value_string(self, default=None, required=False, warn=False):
+ def value_string(self, default=None, required=False, warn=False) -> str:
if required and self.value is None:
raise ConfigurationError("Missing required option '%s' in section '%s'" \
% (self.option.upper(), self.section.upper()))
@@ -149,7 +149,7 @@ class Entry:
return default
return self.value
- def value_int(self, default=None, required=False, warn=False):
+ def value_int(self, default=None, required=False, warn=False) -> int:
value = self.value_string(default, warn, required)
if value is None:
return None
@@ -159,7 +159,7 @@ class Entry:
raise ConfigurationError("Expected number for option '%s' in section '%s'" \
% (self.option.upper(), self.section.upper()))
- def _getsubst(self, key):
+ def _getsubst(self, key: str) -> Any:
value = self.config()["paths"][key].value
if value is not None:
return value
@@ -168,13 +168,13 @@ class Entry:
return value
return None
- def value_filename(self, default=None, required=False, warn=False):
+ def value_filename(self, default=None, required=False, warn=False) -> str:
value = self.value_string(default, required, warn)
if value is None:
return None
return expand(value, self._getsubst)
- def location(self):
+ def location(self) -> str:
if self.filename is None or self.lineno is None:
return "<unknown>"
return "%s:%s" % (self.filename, self.lineno)
@@ -185,16 +185,16 @@ class TalerConfig:
One loaded taler configuration, including base configuration
files and included files.
"""
- def __init__(self):
+ def __init__(self) -> None:
"""
Initialize an empty configuration
"""
- self.sections = SectionDict()
+ self.sections = SectionDict() # just plain dict
# defaults != config file: the first is the 'base'
# whereas the second overrides things from the first.
@staticmethod
- def from_file(filename=None, load_defaults=True):
+ def from_file(filename=None, load_defaults=True) -> TalerConfig:
cfg = TalerConfig()
if filename is None:
xdg = os.environ.get("XDG_CONFIG_HOME")
@@ -207,19 +207,19 @@ class TalerConfig:
cfg.load_file(filename)
return cfg
- def value_string(self, section, option, **kwargs):
+ def value_string(self, section, option, **kwargs) -> str:
return self.sections[section][option].value_string(
kwargs.get("default"), kwargs.get("required"), kwargs.get("warn"))
- def value_filename(self, section, option, **kwargs):
+ def value_filename(self, section, option, **kwargs) -> str:
return self.sections[section][option].value_filename(
kwargs.get("default"), kwargs.get("required"), kwargs.get("warn"))
- def value_int(self, section, option, **kwargs):
+ def value_int(self, section, option, **kwargs) -> int:
return self.sections[section][option].value_int(
kwargs.get("default"), kwargs.get("required"), kwargs.get("warn"))
- def load_defaults(self):
+ def load_defaults(self) -> None:
base_dir = os.environ.get("TALER_BASE_CONFIG")
if base_dir:
self.load_dir(base_dir)
@@ -237,7 +237,7 @@ class TalerConfig:
LOGGER.warning("no base directory found")
@staticmethod
- def from_env(*args, **kwargs):
+ def from_env(*args, **kwargs) -> TalerConfig:
"""
Load configuration from environment variable TALER_CONFIG_FILE
or from default location if the variable is not set.
@@ -245,7 +245,7 @@ class TalerConfig:
filename = os.environ.get("TALER_CONFIG_FILE")
return TalerConfig.from_file(filename, *args, **kwargs)
- def load_dir(self, dirname):
+ def load_dir(self, dirname) -> None:
try:
files = os.listdir(dirname)
except FileNotFoundError:
@@ -256,7 +256,7 @@ class TalerConfig:
continue
self.load_file(os.path.join(dirname, file))
- def load_file(self, filename):
+ def load_file(self, filename) -> None:
sections = self.sections
try:
with open(filename, "r") as file:
@@ -300,7 +300,7 @@ class TalerConfig:
sys.exit(3)
- def dump(self):
+ def dump(self) -> None:
for kv_section in self.sections.items():
print("[%s]" % (kv_section[1].section_name,))
for kv_option in kv_section[1].items():
@@ -309,7 +309,7 @@ class TalerConfig:
kv_option[1].value,
kv_option[1].location()))
- def __getitem__(self, chunk):
+ def __getitem__(self, chunk: str) -> OptionDict:
if isinstance(chunk, str):
return self.sections[chunk]
raise TypeError("index must be string")