quickjs-tart

quickjs-based runtime for wallet-core logic
Log | Files | Refs | README | LICENSE

bignum_common.py (15228B)


      1 """Common features for bignum in test generation framework."""
      2 # Copyright The Mbed TLS Contributors
      3 # SPDX-License-Identifier: Apache-2.0 OR GPL-2.0-or-later
      4 #
      5 
      6 from abc import abstractmethod
      7 import enum
      8 from typing import Iterator, List, Tuple, TypeVar, Any
      9 from copy import deepcopy
     10 from itertools import chain
     11 from math import ceil
     12 
     13 from . import test_case
     14 from . import test_data_generation
     15 from .bignum_data import INPUTS_DEFAULT, MODULI_DEFAULT
     16 
     17 T = TypeVar('T') #pylint: disable=invalid-name
     18 
     19 def invmod(a: int, n: int) -> int:
     20     """Return inverse of a to modulo n.
     21 
     22     Equivalent to pow(a, -1, n) in Python 3.8+. Implementation is equivalent
     23     to long_invmod() in CPython.
     24     """
     25     b, c = 1, 0
     26     while n:
     27         q, r = divmod(a, n)
     28         a, b, c, n = n, c, b - q*c, r
     29     # at this point a is the gcd of the original inputs
     30     if a == 1:
     31         return b
     32     raise ValueError("Not invertible")
     33 
     34 def invmod_positive(a: int, n: int) -> int:
     35     """Return a non-negative inverse of a to modulo n."""
     36     inv = invmod(a, n)
     37     return inv if inv >= 0 else inv + n
     38 
     39 def hex_to_int(val: str) -> int:
     40     """Implement the syntax accepted by mbedtls_test_read_mpi().
     41 
     42     This is a superset of what is accepted by mbedtls_test_read_mpi_core().
     43     """
     44     if val in ['', '-']:
     45         return 0
     46     return int(val, 16)
     47 
     48 def quote_str(val: str) -> str:
     49     return "\"{}\"".format(val)
     50 
     51 def bound_mpi(val: int, bits_in_limb: int) -> int:
     52     """First number exceeding number of limbs needed for given input value."""
     53     return bound_mpi_limbs(limbs_mpi(val, bits_in_limb), bits_in_limb)
     54 
     55 def bound_mpi_limbs(limbs: int, bits_in_limb: int) -> int:
     56     """First number exceeding maximum of given number of limbs."""
     57     bits = bits_in_limb * limbs
     58     return 1 << bits
     59 
     60 def limbs_mpi(val: int, bits_in_limb: int) -> int:
     61     """Return the number of limbs required to store value."""
     62     bit_length = max(val.bit_length(), 1)
     63     return (bit_length + bits_in_limb - 1) // bits_in_limb
     64 
     65 def combination_pairs(values: List[T]) -> List[Tuple[T, T]]:
     66     """Return all pair combinations from input values."""
     67     return [(x, y) for x in values for y in values]
     68 
     69 def bits_to_limbs(bits: int, bits_in_limb: int) -> int:
     70     """ Return the appropriate ammount of limbs needed to store
     71         a number contained in input bits"""
     72     return ceil(bits / bits_in_limb)
     73 
     74 def hex_digits_for_limb(limbs: int, bits_in_limb: int) -> int:
     75     """ Return the hex digits need for a number of limbs. """
     76     return 2 * ((limbs * bits_in_limb) // 8)
     77 
     78 def hex_digits_max_int(val: str, bits_in_limb: int) -> int:
     79     """ Return the first number exceeding maximum  the limb space
     80     required to store the input hex-string value. This method
     81     weights on the input str_len rather than numerical value
     82     and works with zero-padded inputs"""
     83     n = ((1 << (len(val) * 4)) - 1)
     84     l = limbs_mpi(n, bits_in_limb)
     85     return bound_mpi_limbs(l, bits_in_limb)
     86 
     87 def zfill_match(reference: str, target: str) -> str:
     88     """ Zero pad target hex-string to match the limb size of
     89     the reference input """
     90     lt = len(target)
     91     lr = len(reference)
     92     target_len = lr if lt < lr else lt
     93     return "{:x}".format(int(target, 16)).zfill(target_len)
     94 
     95 class OperationCommon(test_data_generation.BaseTest):
     96     """Common features for bignum binary operations.
     97 
     98     This adds functionality common in binary operation tests.
     99 
    100     Attributes:
    101         symbol: Symbol to use for the operation in case description.
    102         input_values: List of values to use as test case inputs. These are
    103             combined to produce pairs of values.
    104         input_cases: List of tuples containing pairs of test case inputs. This
    105             can be used to implement specific pairs of inputs.
    106         unique_combinations_only: Boolean to select if test case combinations
    107             must be unique. If True, only A,B or B,A would be included as a test
    108             case. If False, both A,B and B,A would be included.
    109         input_style: Controls the way how test data is passed to the functions
    110             in the generated test cases. "variable" passes them as they are
    111             defined in the python source. "arch_split" pads the values with
    112             zeroes depending on the architecture/limb size. If this is set,
    113             test cases are generated for all architectures.
    114         arity: the number of operands for the operation. Currently supported
    115             values are 1 and 2.
    116     """
    117     symbol = ""
    118     input_values = INPUTS_DEFAULT # type: List[str]
    119     input_cases = [] # type: List[Any]
    120     dependencies = [] # type: List[Any]
    121     unique_combinations_only = False
    122     input_styles = ["variable", "fixed", "arch_split"] # type: List[str]
    123     input_style = "variable" # type: str
    124     limb_sizes = [32, 64] # type: List[int]
    125     arities = [1, 2]
    126     arity = 2
    127     suffix = False   # for arity = 1, symbol can be prefix (default) or suffix
    128 
    129     def __init__(self, val_a: str, val_b: str = "0", bits_in_limb: int = 32) -> None:
    130         self.val_a = val_a
    131         self.val_b = val_b
    132         # Setting the int versions here as opposed to making them @properties
    133         # provides earlier/more robust input validation.
    134         self.int_a = hex_to_int(val_a)
    135         self.int_b = hex_to_int(val_b)
    136         self.dependencies = deepcopy(self.dependencies)
    137         if bits_in_limb not in self.limb_sizes:
    138             raise ValueError("Invalid number of bits in limb!")
    139         if self.input_style == "arch_split":
    140             self.dependencies.append("MBEDTLS_HAVE_INT{:d}".format(bits_in_limb))
    141         self.bits_in_limb = bits_in_limb
    142 
    143     @property
    144     def boundary(self) -> int:
    145         if self.arity == 1:
    146             return self.int_a
    147         elif self.arity == 2:
    148             return max(self.int_a, self.int_b)
    149         raise ValueError("Unsupported number of operands!")
    150 
    151     @property
    152     def limb_boundary(self) -> int:
    153         return bound_mpi(self.boundary, self.bits_in_limb)
    154 
    155     @property
    156     def limbs(self) -> int:
    157         return limbs_mpi(self.boundary, self.bits_in_limb)
    158 
    159     @property
    160     def hex_digits(self) -> int:
    161         return hex_digits_for_limb(self.limbs, self.bits_in_limb)
    162 
    163     def format_arg(self, val: str) -> str:
    164         if self.input_style not in self.input_styles:
    165             raise ValueError("Unknown input style!")
    166         if self.input_style == "variable":
    167             return val
    168         else:
    169             return val.zfill(self.hex_digits)
    170 
    171     def format_result(self, res: int) -> str:
    172         res_str = '{:x}'.format(res)
    173         return quote_str(self.format_arg(res_str))
    174 
    175     @property
    176     def arg_a(self) -> str:
    177         return self.format_arg(self.val_a)
    178 
    179     @property
    180     def arg_b(self) -> str:
    181         if self.arity == 1:
    182             raise AttributeError("Operation is unary and doesn't have arg_b!")
    183         return self.format_arg(self.val_b)
    184 
    185     def arguments(self) -> List[str]:
    186         args = [quote_str(self.arg_a)]
    187         if self.arity == 2:
    188             args.append(quote_str(self.arg_b))
    189         return args + self.result()
    190 
    191     def description(self) -> str:
    192         """Generate a description for the test case.
    193 
    194         If not set, case_description uses the form A `symbol` B, where symbol
    195         is used to represent the operation. Descriptions of each value are
    196         generated to provide some context to the test case.
    197         """
    198         if not self.case_description:
    199             if self.arity == 1:
    200                 format_string = "{1:x} {0}" if self.suffix else "{0} {1:x}"
    201                 self.case_description = format_string.format(
    202                     self.symbol, self.int_a
    203                 )
    204             elif self.arity == 2:
    205                 self.case_description = "{:x} {} {:x}".format(
    206                     self.int_a, self.symbol, self.int_b
    207                 )
    208         return super().description()
    209 
    210     @property
    211     def is_valid(self) -> bool:
    212         return True
    213 
    214     @abstractmethod
    215     def result(self) -> List[str]:
    216         """Get the result of the operation.
    217 
    218         This could be calculated during initialization and stored as `_result`
    219         and then returned, or calculated when the method is called.
    220         """
    221         raise NotImplementedError
    222 
    223     @classmethod
    224     def get_value_pairs(cls) -> Iterator[Tuple[str, str]]:
    225         """Generator to yield pairs of inputs.
    226 
    227         Combinations are first generated from all input values, and then
    228         specific cases provided.
    229         """
    230         if cls.arity == 1:
    231             yield from ((a, "0") for a in cls.input_values)
    232         elif cls.arity == 2:
    233             if cls.unique_combinations_only:
    234                 yield from combination_pairs(cls.input_values)
    235             else:
    236                 yield from (
    237                     (a, b)
    238                     for a in cls.input_values
    239                     for b in cls.input_values
    240                 )
    241         else:
    242             raise ValueError("Unsupported number of operands!")
    243 
    244     @classmethod
    245     def generate_function_tests(cls) -> Iterator[test_case.TestCase]:
    246         if cls.input_style not in cls.input_styles:
    247             raise ValueError("Unknown input style!")
    248         if cls.arity not in cls.arities:
    249             raise ValueError("Unsupported number of operands!")
    250         if cls.input_style == "arch_split":
    251             test_objects = (cls(a, b, bits_in_limb=bil)
    252                             for a, b in cls.get_value_pairs()
    253                             for bil in cls.limb_sizes)
    254             special_cases = (cls(*args, bits_in_limb=bil) # type: ignore
    255                              for args in cls.input_cases
    256                              for bil in cls.limb_sizes)
    257         else:
    258             test_objects = (cls(a, b)
    259                             for a, b in cls.get_value_pairs())
    260             special_cases = (cls(*args) for args in cls.input_cases)
    261         yield from (valid_test_object.create_test_case()
    262                     for valid_test_object in filter(
    263                         lambda test_object: test_object.is_valid,
    264                         chain(test_objects, special_cases)
    265                         )
    266                     )
    267 
    268 
    269 class ModulusRepresentation(enum.Enum):
    270     """Representation selector of a modulus."""
    271     # Numerical values aligned with the type mbedtls_mpi_mod_rep_selector
    272     INVALID = 0
    273     MONTGOMERY = 2
    274     OPT_RED = 3
    275 
    276     def symbol(self) -> str:
    277         """The C symbol for this representation selector."""
    278         return 'MBEDTLS_MPI_MOD_REP_' + self.name
    279 
    280     @classmethod
    281     def supported_representations(cls) -> List['ModulusRepresentation']:
    282         """Return all representations that are supported in positive test cases."""
    283         return [cls.MONTGOMERY, cls.OPT_RED]
    284 
    285 
    286 class ModOperationCommon(OperationCommon):
    287     #pylint: disable=abstract-method
    288     """Target for bignum mod_raw test case generation."""
    289     moduli = MODULI_DEFAULT # type: List[str]
    290     montgomery_form_a = False
    291     disallow_zero_a = False
    292 
    293     def __init__(self, val_n: str, val_a: str, val_b: str = "0",
    294                  bits_in_limb: int = 64) -> None:
    295         super().__init__(val_a=val_a, val_b=val_b, bits_in_limb=bits_in_limb)
    296         self.val_n = val_n
    297         # Setting the int versions here as opposed to making them @properties
    298         # provides earlier/more robust input validation.
    299         self.int_n = hex_to_int(val_n)
    300 
    301     def to_montgomery(self, val: int) -> int:
    302         return (val * self.r) % self.int_n
    303 
    304     def from_montgomery(self, val: int) -> int:
    305         return (val * self.r_inv) % self.int_n
    306 
    307     def convert_from_canonical(self, canonical: int,
    308                                rep: ModulusRepresentation) -> int:
    309         """Convert values from canonical representation to the given representation."""
    310         if rep is ModulusRepresentation.MONTGOMERY:
    311             return self.to_montgomery(canonical)
    312         elif rep is ModulusRepresentation.OPT_RED:
    313             return canonical
    314         else:
    315             raise ValueError('Modulus representation not supported: {}'
    316                              .format(rep.name))
    317 
    318     @property
    319     def boundary(self) -> int:
    320         return self.int_n
    321 
    322     @property
    323     def arg_a(self) -> str:
    324         if self.montgomery_form_a:
    325             value_a = self.to_montgomery(self.int_a)
    326         else:
    327             value_a = self.int_a
    328         return self.format_arg('{:x}'.format(value_a))
    329 
    330     @property
    331     def arg_n(self) -> str:
    332         return self.format_arg(self.val_n)
    333 
    334     def format_arg(self, val: str) -> str:
    335         return super().format_arg(val).zfill(self.hex_digits)
    336 
    337     def arguments(self) -> List[str]:
    338         return [quote_str(self.arg_n)] + super().arguments()
    339 
    340     @property
    341     def r(self) -> int: # pylint: disable=invalid-name
    342         l = limbs_mpi(self.int_n, self.bits_in_limb)
    343         return bound_mpi_limbs(l, self.bits_in_limb)
    344 
    345     @property
    346     def r_inv(self) -> int:
    347         return invmod(self.r, self.int_n)
    348 
    349     @property
    350     def r2(self) -> int: # pylint: disable=invalid-name
    351         return pow(self.r, 2)
    352 
    353     @property
    354     def is_valid(self) -> bool:
    355         if self.int_a >= self.int_n:
    356             return False
    357         if self.disallow_zero_a and self.int_a == 0:
    358             return False
    359         if self.arity == 2 and self.int_b >= self.int_n:
    360             return False
    361         return True
    362 
    363     def description(self) -> str:
    364         """Generate a description for the test case.
    365 
    366         It uses the form A `symbol` B mod N, where symbol is used to represent
    367         the operation.
    368         """
    369 
    370         if not self.case_description:
    371             return super().description() + " mod {:x}".format(self.int_n)
    372         return super().description()
    373 
    374     @classmethod
    375     def input_cases_args(cls) -> Iterator[Tuple[Any, Any, Any]]:
    376         if cls.arity == 1:
    377             yield from ((n, a, "0") for a, n in cls.input_cases)
    378         elif cls.arity == 2:
    379             yield from ((n, a, b) for a, b, n in cls.input_cases)
    380         else:
    381             raise ValueError("Unsupported number of operands!")
    382 
    383     @classmethod
    384     def generate_function_tests(cls) -> Iterator[test_case.TestCase]:
    385         if cls.input_style not in cls.input_styles:
    386             raise ValueError("Unknown input style!")
    387         if cls.arity not in cls.arities:
    388             raise ValueError("Unsupported number of operands!")
    389         if cls.input_style == "arch_split":
    390             test_objects = (cls(n, a, b, bits_in_limb=bil)
    391                             for n in cls.moduli
    392                             for a, b in cls.get_value_pairs()
    393                             for bil in cls.limb_sizes)
    394             special_cases = (cls(*args, bits_in_limb=bil)
    395                              for args in cls.input_cases_args()
    396                              for bil in cls.limb_sizes)
    397         else:
    398             test_objects = (cls(n, a, b)
    399                             for n in cls.moduli
    400                             for a, b in cls.get_value_pairs())
    401             special_cases = (cls(*args) for args in cls.input_cases_args())
    402         yield from (valid_test_object.create_test_case()
    403                     for valid_test_object in filter(
    404                         lambda test_object: test_object.is_valid,
    405                         chain(test_objects, special_cases)
    406                         ))