quickjs-tart

quickjs-based runtime for wallet-core logic
Log | Files | Refs | README | LICENSE

test_psa_constant_names.py (8645B)


      1 #!/usr/bin/env python3
      2 """Test the program psa_constant_names.
      3 Gather constant names from header files and test cases. Compile a C program
      4 to print out their numerical values, feed these numerical values to
      5 psa_constant_names, and check that the output is the original name.
      6 Return 0 if all test cases pass, 1 if the output was not always as expected,
      7 or 1 (with a Python backtrace) if there was an operational error.
      8 """
      9 
     10 # Copyright The Mbed TLS Contributors
     11 # SPDX-License-Identifier: Apache-2.0 OR GPL-2.0-or-later
     12 
     13 import argparse
     14 from collections import namedtuple
     15 import os
     16 import re
     17 import subprocess
     18 import sys
     19 from typing import Iterable, List, Optional, Tuple
     20 
     21 from mbedtls_framework import build_tree
     22 from mbedtls_framework import c_build_helper
     23 from mbedtls_framework.macro_collector import InputsForTest, PSAMacroEnumerator
     24 from mbedtls_framework import typing_util
     25 
     26 def gather_inputs(headers: Iterable[str],
     27                   test_suites: Iterable[str],
     28                   inputs_class=InputsForTest) -> PSAMacroEnumerator:
     29     """Read the list of inputs to test psa_constant_names with."""
     30     inputs = inputs_class()
     31     for header in headers:
     32         inputs.parse_header(header)
     33     for test_cases in test_suites:
     34         inputs.parse_test_cases(test_cases)
     35     inputs.add_numerical_values()
     36     inputs.gather_arguments()
     37     return inputs
     38 
     39 def run_c(type_word: str,
     40           expressions: Iterable[str],
     41           include_path: Optional[str] = None,
     42           keep_c: bool = False) -> List[str]:
     43     """Generate and run a program to print out numerical values of C expressions."""
     44     if type_word == 'status':
     45         cast_to = 'long'
     46         printf_format = '%ld'
     47     else:
     48         cast_to = 'unsigned long'
     49         printf_format = '0x%08lx'
     50     return c_build_helper.get_c_expression_values(
     51         cast_to, printf_format,
     52         expressions,
     53         caller='test_psa_constant_names.py for {} values'.format(type_word),
     54         file_label=type_word,
     55         header='#include <psa/crypto.h>',
     56         include_path=include_path,
     57         keep_c=keep_c
     58     )
     59 
     60 NORMALIZE_STRIP_RE = re.compile(r'\s+')
     61 def normalize(expr: str) -> str:
     62     """Normalize the C expression so as not to care about trivial differences.
     63 
     64     Currently "trivial differences" means whitespace.
     65     """
     66     return re.sub(NORMALIZE_STRIP_RE, '', expr)
     67 
     68 ALG_TRUNCATED_TO_SELF_RE = \
     69     re.compile(r'PSA_ALG_AEAD_WITH_SHORTENED_TAG\('
     70                r'PSA_ALG_(?:CCM|CHACHA20_POLY1305|GCM)'
     71                r', *16\)\Z')
     72 
     73 def is_simplifiable(expr: str) -> bool:
     74     """Determine whether an expression is simplifiable.
     75 
     76     Simplifiable expressions can't be output in their input form, since
     77     the output will be the simple form. Therefore they must be excluded
     78     from testing.
     79     """
     80     if ALG_TRUNCATED_TO_SELF_RE.match(expr):
     81         return True
     82     return False
     83 
     84 def collect_values(inputs: InputsForTest,
     85                    type_word: str,
     86                    include_path: Optional[str] = None,
     87                    keep_c: bool = False) -> Tuple[List[str], List[str]]:
     88     """Generate expressions using known macro names and calculate their values.
     89 
     90     Return a list of pairs of (expr, value) where expr is an expression and
     91     value is a string representation of its integer value.
     92     """
     93     names = inputs.get_names(type_word)
     94     expressions = sorted(expr
     95                          for expr in inputs.generate_expressions(names)
     96                          if not is_simplifiable(expr))
     97     values = run_c(type_word, expressions,
     98                    include_path=include_path, keep_c=keep_c)
     99     return expressions, values
    100 
    101 class Tests:
    102     """An object representing tests and their results."""
    103 
    104     Error = namedtuple('Error',
    105                        ['type', 'expression', 'value', 'output'])
    106 
    107     def __init__(self, options) -> None:
    108         self.options = options
    109         self.count = 0
    110         self.errors = [] #type: List[Tests.Error]
    111 
    112     def run_one(self, inputs: InputsForTest, type_word: str) -> None:
    113         """Test psa_constant_names for the specified type.
    114 
    115         Run the program on the names for this type.
    116         Use the inputs to figure out what arguments to pass to macros that
    117         take arguments.
    118         """
    119         expressions, values = collect_values(inputs, type_word,
    120                                              include_path=self.options.include,
    121                                              keep_c=self.options.keep_c)
    122         output_bytes = subprocess.check_output([self.options.program,
    123                                                 type_word] + values)
    124         output = output_bytes.decode('ascii')
    125         outputs = output.strip().split('\n')
    126         self.count += len(expressions)
    127         for expr, value, output in zip(expressions, values, outputs):
    128             if self.options.show:
    129                 sys.stdout.write('{} {}\t{}\n'.format(type_word, value, output))
    130             if normalize(expr) != normalize(output):
    131                 self.errors.append(self.Error(type=type_word,
    132                                               expression=expr,
    133                                               value=value,
    134                                               output=output))
    135 
    136     def run_all(self, inputs: InputsForTest) -> None:
    137         """Run psa_constant_names on all the gathered inputs."""
    138         for type_word in ['status', 'algorithm', 'ecc_curve', 'dh_group',
    139                           'key_type', 'key_usage']:
    140             self.run_one(inputs, type_word)
    141 
    142     def report(self, out: typing_util.Writable) -> None:
    143         """Describe each case where the output is not as expected.
    144 
    145         Write the errors to ``out``.
    146         Also write a total.
    147         """
    148         for error in self.errors:
    149             out.write('For {} "{}", got "{}" (value: {})\n'
    150                       .format(error.type, error.expression,
    151                               error.output, error.value))
    152         out.write('{} test cases'.format(self.count))
    153         if self.errors:
    154             out.write(', {} FAIL\n'.format(len(self.errors)))
    155         else:
    156             out.write(' PASS\n')
    157 
    158 HEADERS = ['psa/crypto.h', 'psa/crypto_extra.h', 'psa/crypto_values.h']
    159 
    160 if build_tree.is_mbedtls_3_6():
    161     TEST_SUITES = ['tests/suites/test_suite_psa_crypto_metadata.data']
    162 else:
    163     TEST_SUITES = ['tf-psa-crypto/tests/suites/test_suite_psa_crypto_metadata.data']
    164 
    165 def main():
    166     parser = argparse.ArgumentParser(description=globals()['__doc__'])
    167     if build_tree.is_mbedtls_3_6():
    168         parser.add_argument('--include', '-I',
    169                             action='append', default=['include'],
    170                             help='Directory for header files')
    171     else:
    172         parser.add_argument('--include', '-I',
    173                             action='append', default=['tf-psa-crypto/include',
    174                                                       'tf-psa-crypto/drivers/builtin/include',
    175                                                       'tf-psa-crypto/drivers/everest/include',
    176                                                       'tf-psa-crypto/drivers/everest/include/' +
    177                                                       'tf-psa-crypto/private',
    178                                                       'include'],
    179                             help='Directory for header files')
    180     parser.add_argument('--keep-c',
    181                         action='store_true', dest='keep_c', default=False,
    182                         help='Keep the intermediate C file')
    183     parser.add_argument('--no-keep-c',
    184                         action='store_false', dest='keep_c',
    185                         help='Don\'t keep the intermediate C file (default)')
    186     if build_tree.is_mbedtls_3_6():
    187         parser.add_argument('--program',
    188                             default='programs/psa/psa_constant_names',
    189                             help='Program to test')
    190     else:
    191         parser.add_argument('--program',
    192                             default='tf-psa-crypto/programs/psa/psa_constant_names',
    193                             help='Program to test')
    194     parser.add_argument('--show',
    195                         action='store_true',
    196                         help='Show tested values on stdout')
    197     parser.add_argument('--no-show',
    198                         action='store_false', dest='show',
    199                         help='Don\'t show tested values (default)')
    200     options = parser.parse_args()
    201     headers = [os.path.join(options.include[0], h) for h in HEADERS]
    202     inputs = gather_inputs(headers, TEST_SUITES)
    203     tests = Tests(options)
    204     tests.run_all(inputs)
    205     tests.report(sys.stdout)
    206     if tests.errors:
    207         sys.exit(1)
    208 
    209 if __name__ == '__main__':
    210     main()