quickjs-tart

quickjs-based runtime for wallet-core logic
Log | Files | Refs | README | LICENSE

macro_collector.py (23318B)


      1 """Collect macro definitions from header files.
      2 """
      3 
      4 # Copyright The Mbed TLS Contributors
      5 # SPDX-License-Identifier: Apache-2.0 OR GPL-2.0-or-later
      6 #
      7 
      8 import itertools
      9 import re
     10 from typing import Dict, IO, Iterable, Iterator, List, Optional, Pattern, Set, Tuple, Union
     11 
     12 
     13 class ReadFileLineException(Exception):
     14     def __init__(self, filename: str, line_number: Union[int, str]) -> None:
     15         message = 'in {} at {}'.format(filename, line_number)
     16         super(ReadFileLineException, self).__init__(message)
     17         self.filename = filename
     18         self.line_number = line_number
     19 
     20 
     21 class read_file_lines:
     22     # Dear Pylint, conventionally, a context manager class name is lowercase.
     23     # pylint: disable=invalid-name,too-few-public-methods
     24     """Context manager to read a text file line by line.
     25 
     26     ```
     27     with read_file_lines(filename) as lines:
     28         for line in lines:
     29             process(line)
     30     ```
     31     is equivalent to
     32     ```
     33     with open(filename, 'r') as input_file:
     34         for line in input_file:
     35             process(line)
     36     ```
     37     except that if process(line) raises an exception, then the read_file_lines
     38     snippet annotates the exception with the file name and line number.
     39     """
     40     def __init__(self, filename: str, binary: bool = False) -> None:
     41         self.filename = filename
     42         self.file = None #type: Optional[IO[str]]
     43         self.line_number = 'entry' #type: Union[int, str]
     44         self.generator = None #type: Optional[Iterable[Tuple[int, str]]]
     45         self.binary = binary
     46     def __enter__(self) -> 'read_file_lines':
     47         self.file = open(self.filename, 'rb' if self.binary else 'r')
     48         self.generator = enumerate(self.file)
     49         return self
     50     def __iter__(self) -> Iterator[str]:
     51         assert self.generator is not None
     52         for line_number, content in self.generator:
     53             self.line_number = line_number
     54             yield content
     55         self.line_number = 'exit'
     56     def __exit__(self, exc_type, exc_value, exc_traceback) -> None:
     57         if self.file is not None:
     58             self.file.close()
     59         if exc_type is not None:
     60             raise ReadFileLineException(self.filename, self.line_number) \
     61                 from exc_value
     62 
     63 
     64 class PSAMacroEnumerator:
     65     """Information about constructors of various PSA Crypto types.
     66 
     67     This includes macro names as well as information about their arguments
     68     when applicable.
     69 
     70     This class only provides ways to enumerate expressions that evaluate to
     71     values of the covered types. Derived classes are expected to populate
     72     the set of known constructors of each kind, as well as populate
     73     `self.arguments_for` for arguments that are not of a kind that is
     74     enumerated here.
     75     """
     76     #pylint: disable=too-many-instance-attributes
     77 
     78     def __init__(self) -> None:
     79         """Set up an empty set of known constructor macros.
     80         """
     81         self.statuses = set() #type: Set[str]
     82         self.lifetimes = set() #type: Set[str]
     83         self.locations = set() #type: Set[str]
     84         self.persistence_levels = set() #type: Set[str]
     85         self.algorithms = set() #type: Set[str]
     86         self.ecc_curves = set() #type: Set[str]
     87         self.dh_groups = set() #type: Set[str]
     88         self.key_types = set() #type: Set[str]
     89         self.key_usage_flags = set() #type: Set[str]
     90         self.hash_algorithms = set() #type: Set[str]
     91         self.mac_algorithms = set() #type: Set[str]
     92         self.ka_algorithms = set() #type: Set[str]
     93         self.kdf_algorithms = set() #type: Set[str]
     94         self.pake_algorithms = set() #type: Set[str]
     95         self.aead_algorithms = set() #type: Set[str]
     96         self.sign_algorithms = set() #type: Set[str]
     97         # macro name -> list of argument names
     98         self.argspecs = {} #type: Dict[str, List[str]]
     99         # argument name -> list of values
    100         self.arguments_for = {
    101             'mac_length': [],
    102             'min_mac_length': [],
    103             'tag_length': [],
    104             'min_tag_length': [],
    105         } #type: Dict[str, List[str]]
    106         # Whether to include intermediate macros in enumerations. Intermediate
    107         # macros serve as category headers and are not valid values of their
    108         # type. See `is_internal_name`.
    109         # Always false in this class, may be set to true in derived classes.
    110         self.include_intermediate = False
    111 
    112     def is_internal_name(self, name: str) -> bool:
    113         """Whether this is an internal macro. Internal macros will be skipped."""
    114         if not self.include_intermediate:
    115             if name.endswith('_BASE') or name.endswith('_NONE'):
    116                 return True
    117             if '_CATEGORY_' in name:
    118                 return True
    119         return name.endswith('_FLAG') or name.endswith('_MASK')
    120 
    121     def gather_arguments(self) -> None:
    122         """Populate the list of values for macro arguments.
    123 
    124         Call this after parsing all the inputs.
    125         """
    126         self.arguments_for['hash_alg'] = sorted(self.hash_algorithms)
    127         self.arguments_for['mac_alg'] = sorted(self.mac_algorithms)
    128         self.arguments_for['ka_alg'] = sorted(self.ka_algorithms)
    129         self.arguments_for['kdf_alg'] = sorted(self.kdf_algorithms)
    130         self.arguments_for['aead_alg'] = sorted(self.aead_algorithms)
    131         self.arguments_for['sign_alg'] = sorted(self.sign_algorithms)
    132         self.arguments_for['curve'] = sorted(self.ecc_curves)
    133         self.arguments_for['group'] = sorted(self.dh_groups)
    134         self.arguments_for['persistence'] = sorted(self.persistence_levels)
    135         self.arguments_for['location'] = sorted(self.locations)
    136         self.arguments_for['lifetime'] = sorted(self.lifetimes)
    137 
    138     @staticmethod
    139     def _format_arguments(name: str, arguments: Iterable[str]) -> str:
    140         """Format a macro call with arguments.
    141 
    142         The resulting format is consistent with
    143         `InputsForTest.normalize_argument`.
    144         """
    145         return name + '(' + ', '.join(arguments) + ')'
    146 
    147     _argument_split_re = re.compile(r' *, *')
    148     @classmethod
    149     def _argument_split(cls, arguments: str) -> List[str]:
    150         return re.split(cls._argument_split_re, arguments)
    151 
    152     def distribute_arguments(self, name: str) -> Iterator[str]:
    153         """Generate macro calls with each tested argument set.
    154 
    155         If name is a macro without arguments, just yield "name".
    156         If name is a macro with arguments, yield a series of
    157         "name(arg1,...,argN)" where each argument takes each possible
    158         value at least once.
    159         """
    160         try:
    161             if name not in self.argspecs:
    162                 yield name
    163                 return
    164             argspec = self.argspecs[name]
    165             if argspec == []:
    166                 yield name + '()'
    167                 return
    168             argument_lists = [self.arguments_for[arg] for arg in argspec]
    169             arguments = [values[0] for values in argument_lists]
    170             yield self._format_arguments(name, arguments)
    171             # Dear Pylint, enumerate won't work here since we're modifying
    172             # the array.
    173             # pylint: disable=consider-using-enumerate
    174             for i in range(len(arguments)):
    175                 for value in argument_lists[i][1:]:
    176                     arguments[i] = value
    177                     yield self._format_arguments(name, arguments)
    178                 arguments[i] = argument_lists[i][0]
    179         except BaseException as e:
    180             raise Exception('distribute_arguments({})'.format(name)) from e
    181 
    182     def distribute_arguments_without_duplicates(
    183             self, seen: Set[str], name: str
    184     ) -> Iterator[str]:
    185         """Same as `distribute_arguments`, but don't repeat seen results."""
    186         for result in self.distribute_arguments(name):
    187             if result not in seen:
    188                 seen.add(result)
    189                 yield result
    190 
    191     def generate_expressions(self, names: Iterable[str]) -> Iterator[str]:
    192         """Generate expressions covering values constructed from the given names.
    193 
    194         `names` can be any iterable collection of macro names.
    195 
    196         For example:
    197         * ``generate_expressions(['PSA_ALG_CMAC', 'PSA_ALG_HMAC'])``
    198           generates ``'PSA_ALG_CMAC'`` as well as ``'PSA_ALG_HMAC(h)'`` for
    199           every known hash algorithm ``h``.
    200         * ``macros.generate_expressions(macros.key_types)`` generates all
    201           key types.
    202         """
    203         seen = set() #type: Set[str]
    204         return itertools.chain(*(
    205             self.distribute_arguments_without_duplicates(seen, name)
    206             for name in names
    207         ))
    208 
    209 
    210 class PSAMacroCollector(PSAMacroEnumerator):
    211     """Collect PSA crypto macro definitions from C header files.
    212     """
    213 
    214     def __init__(self, include_intermediate: bool = False) -> None:
    215         """Set up an object to collect PSA macro definitions.
    216 
    217         Call the read_file method of the constructed object on each header file.
    218 
    219         * include_intermediate: if true, include intermediate macros such as
    220           PSA_XXX_BASE that do not designate semantic values.
    221         """
    222         super().__init__()
    223         self.include_intermediate = include_intermediate
    224         self.key_types_from_curve = {} #type: Dict[str, str]
    225         self.key_types_from_group = {} #type: Dict[str, str]
    226         self.algorithms_from_hash = {} #type: Dict[str, str]
    227 
    228     @staticmethod
    229     def algorithm_tester(name: str) -> str:
    230         """The predicate for whether an algorithm is built from the given constructor.
    231 
    232         The given name must be the name of an algorithm constructor of the
    233         form ``PSA_ALG_xxx`` which is used as ``PSA_ALG_xxx(yyy)`` to build
    234         an algorithm value. Return the corresponding predicate macro which
    235         is used as ``predicate(alg)`` to test whether ``alg`` can be built
    236         as ``PSA_ALG_xxx(yyy)``. The predicate is usually called
    237         ``PSA_ALG_IS_xxx``.
    238         """
    239         prefix = 'PSA_ALG_'
    240         assert name.startswith(prefix)
    241         midfix = 'IS_'
    242         suffix = name[len(prefix):]
    243         if suffix in ['DSA', 'ECDSA']:
    244             midfix += 'RANDOMIZED_'
    245         elif suffix == 'RSA_PSS':
    246             suffix += '_STANDARD_SALT'
    247         return prefix + midfix + suffix
    248 
    249     def record_algorithm_subtype(self, name: str, expansion: str) -> None:
    250         """Record the subtype of an algorithm constructor.
    251 
    252         Given a ``PSA_ALG_xxx`` macro name and its expansion, if the algorithm
    253         is of a subtype that is tracked in its own set, add it to the relevant
    254         set.
    255         """
    256         # This code is very ad hoc and fragile. It should be replaced by
    257         # something more robust.
    258         if re.match(r'MAC(?:_|\Z)', name):
    259             self.mac_algorithms.add(name)
    260         elif re.match(r'KDF(?:_|\Z)', name):
    261             self.kdf_algorithms.add(name)
    262         elif re.search(r'0x020000[0-9A-Fa-f]{2}', expansion):
    263             self.hash_algorithms.add(name)
    264         elif re.search(r'0x03[0-9A-Fa-f]{6}', expansion):
    265             self.mac_algorithms.add(name)
    266         elif re.search(r'0x05[0-9A-Fa-f]{6}', expansion):
    267             self.aead_algorithms.add(name)
    268         elif re.search(r'0x09[0-9A-Fa-f]{2}0000', expansion):
    269             self.ka_algorithms.add(name)
    270         elif re.search(r'0x08[0-9A-Fa-f]{6}', expansion):
    271             self.kdf_algorithms.add(name)
    272 
    273     # "#define" followed by a macro name with either no parameters
    274     # or a single parameter and a non-empty expansion.
    275     # Grab the macro name in group 1, the parameter name if any in group 2
    276     # and the expansion in group 3.
    277     _define_directive_re = re.compile(r'\s*#\s*define\s+(\w+)' +
    278                                       r'(?:\s+|\((\w+)\)\s*)' +
    279                                       r'(.+)')
    280     _deprecated_definition_re = re.compile(r'\s*MBEDTLS_DEPRECATED')
    281 
    282     def read_line(self, line):
    283         """Parse a C header line and record the PSA identifier it defines if any.
    284         This function analyzes lines that start with "#define PSA_"
    285         (up to non-significant whitespace) and skips all non-matching lines.
    286         """
    287         # pylint: disable=too-many-branches
    288         m = re.match(self._define_directive_re, line)
    289         if not m:
    290             return
    291         name, parameter, expansion = m.groups()
    292         expansion = re.sub(r'/\*.*?\*/|//.*', r' ', expansion)
    293         if parameter:
    294             self.argspecs[name] = [parameter]
    295         if re.match(self._deprecated_definition_re, expansion):
    296             # Skip deprecated values, which are assumed to be
    297             # backward compatibility aliases that share
    298             # numerical values with non-deprecated values.
    299             return
    300         if self.is_internal_name(name):
    301             # Macro only to build actual values
    302             return
    303         elif (name.startswith('PSA_ERROR_') or name == 'PSA_SUCCESS') \
    304            and not parameter:
    305             self.statuses.add(name)
    306         elif name.startswith('PSA_KEY_TYPE_') and not parameter:
    307             self.key_types.add(name)
    308         elif name.startswith('PSA_KEY_TYPE_') and parameter == 'curve':
    309             self.key_types_from_curve[name] = name[:13] + 'IS_' + name[13:]
    310         elif name.startswith('PSA_KEY_TYPE_') and parameter == 'group':
    311             self.key_types_from_group[name] = name[:13] + 'IS_' + name[13:]
    312         elif name.startswith('PSA_ECC_FAMILY_') and not parameter:
    313             self.ecc_curves.add(name)
    314         elif name.startswith('PSA_DH_FAMILY_') and not parameter:
    315             self.dh_groups.add(name)
    316         elif name.startswith('PSA_ALG_') and not parameter:
    317             if name in ['PSA_ALG_ECDSA_BASE',
    318                         'PSA_ALG_RSA_PKCS1V15_SIGN_BASE']:
    319                 # Ad hoc skipping of duplicate names for some numerical values
    320                 return
    321             self.algorithms.add(name)
    322             self.record_algorithm_subtype(name, expansion)
    323         elif name.startswith('PSA_ALG_') and parameter == 'hash_alg':
    324             self.algorithms_from_hash[name] = self.algorithm_tester(name)
    325         elif name.startswith('PSA_KEY_USAGE_') and not parameter:
    326             self.key_usage_flags.add(name)
    327         else:
    328             # Other macro without parameter
    329             return
    330 
    331     _nonascii_re = re.compile(rb'[^\x00-\x7f]+')
    332     _continued_line_re = re.compile(rb'\\\r?\n\Z')
    333     def read_file(self, header_file):
    334         for line in header_file:
    335             m = re.search(self._continued_line_re, line)
    336             while m:
    337                 cont = next(header_file)
    338                 line = line[:m.start(0)] + cont
    339                 m = re.search(self._continued_line_re, line)
    340             line = re.sub(self._nonascii_re, rb'', line).decode('ascii')
    341             self.read_line(line)
    342 
    343 
    344 class InputsForTest(PSAMacroEnumerator):
    345     # pylint: disable=too-many-instance-attributes
    346     """Accumulate information about macros to test.
    347 enumerate
    348     This includes macro names as well as information about their arguments
    349     when applicable.
    350     """
    351 
    352     def __init__(self) -> None:
    353         super().__init__()
    354         self.all_declared = set() #type: Set[str]
    355         # Identifier prefixes
    356         self.table_by_prefix = {
    357             'ERROR': self.statuses,
    358             'ALG': self.algorithms,
    359             'ECC_CURVE': self.ecc_curves,
    360             'DH_GROUP': self.dh_groups,
    361             'KEY_LIFETIME': self.lifetimes,
    362             'KEY_LOCATION': self.locations,
    363             'KEY_PERSISTENCE': self.persistence_levels,
    364             'KEY_TYPE': self.key_types,
    365             'KEY_USAGE': self.key_usage_flags,
    366         } #type: Dict[str, Set[str]]
    367         # Test functions
    368         self.table_by_test_function = {
    369             # Any function ending in _algorithm also gets added to
    370             # self.algorithms.
    371             'key_type': [self.key_types],
    372             'block_cipher_key_type': [self.key_types],
    373             'stream_cipher_key_type': [self.key_types],
    374             'ecc_key_family': [self.ecc_curves],
    375             'ecc_key_types': [self.ecc_curves],
    376             'dh_key_family': [self.dh_groups],
    377             'dh_key_types': [self.dh_groups],
    378             'hash_algorithm': [self.hash_algorithms],
    379             'mac_algorithm': [self.mac_algorithms],
    380             'cipher_algorithm': [],
    381             'hmac_algorithm': [self.mac_algorithms, self.sign_algorithms],
    382             'aead_algorithm': [self.aead_algorithms],
    383             'key_derivation_algorithm': [self.kdf_algorithms],
    384             'key_agreement_algorithm': [self.ka_algorithms],
    385             'asymmetric_signature_algorithm': [self.sign_algorithms],
    386             'asymmetric_signature_wildcard': [self.algorithms],
    387             'asymmetric_encryption_algorithm': [],
    388             'pake_algorithm': [self.pake_algorithms],
    389             'other_algorithm': [],
    390             'lifetime': [self.lifetimes],
    391         } #type: Dict[str, List[Set[str]]]
    392         mac_lengths = [str(n) for n in [
    393             1,  # minimum expressible
    394             4,  # minimum allowed by policy
    395             13, # an odd size in a plausible range
    396             14, # an even non-power-of-two size in a plausible range
    397             16, # same as full size for at least one algorithm
    398             63, # maximum expressible
    399         ]]
    400         self.arguments_for['mac_length'] += mac_lengths
    401         self.arguments_for['min_mac_length'] += mac_lengths
    402         aead_lengths = [str(n) for n in [
    403             1,  # minimum expressible
    404             4,  # minimum allowed by policy
    405             13, # an odd size in a plausible range
    406             14, # an even non-power-of-two size in a plausible range
    407             16, # same as full size for at least one algorithm
    408             63, # maximum expressible
    409         ]]
    410         self.arguments_for['tag_length'] += aead_lengths
    411         self.arguments_for['min_tag_length'] += aead_lengths
    412 
    413     def add_numerical_values(self) -> None:
    414         """Add numerical values that are not supported to the known identifiers."""
    415         # Sets of names per type
    416         self.algorithms.add('0xffffffff')
    417         self.ecc_curves.add('0xff')
    418         self.dh_groups.add('0xff')
    419         self.key_types.add('0xffff')
    420         self.key_usage_flags.add('0x80000000')
    421 
    422         # Hard-coded values for unknown algorithms
    423         #
    424         # These have to have values that are correct for their respective
    425         # PSA_ALG_IS_xxx macros, but are also not currently assigned and are
    426         # not likely to be assigned in the near future.
    427         self.hash_algorithms.add('0x020000fe') # 0x020000ff is PSA_ALG_ANY_HASH
    428         self.mac_algorithms.add('0x03007fff')
    429         self.ka_algorithms.add('0x09fc0000')
    430         self.kdf_algorithms.add('0x080000ff')
    431         self.pake_algorithms.add('0x0a0000ff')
    432         # For AEAD algorithms, the only variability is over the tag length,
    433         # and this only applies to known algorithms, so don't test an
    434         # unknown algorithm.
    435 
    436     def get_names(self, type_word: str) -> Set[str]:
    437         """Return the set of known names of values of the given type."""
    438         return {
    439             'status': self.statuses,
    440             'algorithm': self.algorithms,
    441             'ecc_curve': self.ecc_curves,
    442             'dh_group': self.dh_groups,
    443             'key_type': self.key_types,
    444             'key_usage': self.key_usage_flags,
    445         }[type_word]
    446 
    447     # Regex for interesting header lines.
    448     # Groups: 1=macro name, 2=type, 3=argument list (optional).
    449     _header_line_re = \
    450         re.compile(r'#define +' +
    451                    r'(PSA_((?:(?:DH|ECC|KEY)_)?[A-Z]+)_\w+)' +
    452                    r'(?:\(([^\n()]*)\))?')
    453     # Regex of macro names to exclude.
    454     _excluded_name_re = re.compile(r'_(?:GET|IS|OF)_|_(?:BASE|FLAG|MASK)\Z')
    455     # Additional excluded macros.
    456     _excluded_names = set([
    457         # Macros that provide an alternative way to build the same
    458         # algorithm as another macro.
    459         'PSA_ALG_AEAD_WITH_DEFAULT_LENGTH_TAG',
    460         'PSA_ALG_FULL_LENGTH_MAC',
    461         # Auxiliary macro whose name doesn't fit the usual patterns for
    462         # auxiliary macros.
    463         'PSA_ALG_AEAD_WITH_DEFAULT_LENGTH_TAG_CASE',
    464     ])
    465     def parse_header_line(self, line: str) -> None:
    466         """Parse a C header line, looking for "#define PSA_xxx"."""
    467         m = re.match(self._header_line_re, line)
    468         if not m:
    469             return
    470         name = m.group(1)
    471         self.all_declared.add(name)
    472         if re.search(self._excluded_name_re, name) or \
    473            name in self._excluded_names or \
    474            self.is_internal_name(name):
    475             return
    476         dest = self.table_by_prefix.get(m.group(2))
    477         if dest is None:
    478             return
    479         dest.add(name)
    480         if m.group(3):
    481             self.argspecs[name] = self._argument_split(m.group(3))
    482 
    483     _nonascii_re = re.compile(rb'[^\x00-\x7f]+') #type: Pattern
    484     def parse_header(self, filename: str) -> None:
    485         """Parse a C header file, looking for "#define PSA_xxx"."""
    486         with read_file_lines(filename, binary=True) as lines:
    487             for line in lines:
    488                 line = re.sub(self._nonascii_re, rb'', line).decode('ascii')
    489                 self.parse_header_line(line)
    490 
    491     _macro_identifier_re = re.compile(r'[A-Z]\w+')
    492     def generate_undeclared_names(self, expr: str) -> Iterable[str]:
    493         for name in re.findall(self._macro_identifier_re, expr):
    494             if name not in self.all_declared:
    495                 yield name
    496 
    497     def accept_test_case_line(self, function: str, argument: str) -> bool:
    498         #pylint: disable=unused-argument
    499         undeclared = list(self.generate_undeclared_names(argument))
    500         if undeclared:
    501             raise Exception('Undeclared names in test case', undeclared)
    502         return True
    503 
    504     @staticmethod
    505     def normalize_argument(argument: str) -> str:
    506         """Normalize whitespace in the given C expression.
    507 
    508         The result uses the same whitespace as
    509         ` PSAMacroEnumerator.distribute_arguments`.
    510         """
    511         return re.sub(r',', r', ', re.sub(r' +', r'', argument))
    512 
    513     def add_test_case_line(self, function: str, argument: str) -> None:
    514         """Parse a test case data line, looking for algorithm metadata tests."""
    515         sets = []
    516         if function.endswith('_algorithm'):
    517             sets.append(self.algorithms)
    518             if function == 'key_agreement_algorithm' and \
    519                argument.startswith('PSA_ALG_KEY_AGREEMENT('):
    520                 # We only want *raw* key agreement algorithms as such, so
    521                 # exclude ones that are already chained with a KDF.
    522                 # Keep the expression as one to test as an algorithm.
    523                 function = 'other_algorithm'
    524         sets += self.table_by_test_function[function]
    525         if self.accept_test_case_line(function, argument):
    526             for s in sets:
    527                 s.add(self.normalize_argument(argument))
    528 
    529     # Regex matching a *.data line containing a test function call and
    530     # its arguments. The actual definition is partly positional, but this
    531     # regex is good enough in practice.
    532     _test_case_line_re = re.compile(r'(?!depends_on:)(\w+):([^\n :][^:\n]*)')
    533     def parse_test_cases(self, filename: str) -> None:
    534         """Parse a test case file (*.data), looking for algorithm metadata tests."""
    535         with read_file_lines(filename) as lines:
    536             for line in lines:
    537                 m = re.match(self._test_case_line_re, line)
    538                 if m:
    539                     self.add_test_case_line(m.group(1), m.group(2))