test_psa_constant_names.py (8645B)
#!/usr/bin/env python3
"""Test the program psa_constant_names.
Gather constant names from header files and test cases. Compile a C program
to print out their numerical values, feed these numerical values to
psa_constant_names, and check that the output is the original name.
Return 0 if all test cases pass, 1 if the output was not always as expected,
or 1 (with a Python backtrace) if there was an operational error.
"""

# Copyright The Mbed TLS Contributors
# SPDX-License-Identifier: Apache-2.0 OR GPL-2.0-or-later

import argparse
from collections import namedtuple
import os
import re
import subprocess
import sys
from typing import Iterable, List, Optional, Tuple

from mbedtls_framework import build_tree
from mbedtls_framework import c_build_helper
from mbedtls_framework.macro_collector import InputsForTest, PSAMacroEnumerator
from mbedtls_framework import typing_util

def gather_inputs(headers: Iterable[str],
                  test_suites: Iterable[str],
                  inputs_class=InputsForTest) -> PSAMacroEnumerator:
    """Read the list of inputs to test psa_constant_names with.

    Create an ``inputs_class`` instance, feed it every file in ``headers``
    (through ``parse_header``) and every file in ``test_suites`` (through
    ``parse_test_cases``), then call ``add_numerical_values`` and
    ``gather_arguments`` on it. Return the populated instance.
    """
    inputs = inputs_class()
    for header in headers:
        inputs.parse_header(header)
    for test_cases in test_suites:
        inputs.parse_test_cases(test_cases)
    inputs.add_numerical_values()
    inputs.gather_arguments()
    return inputs

def run_c(type_word: str,
          expressions: Iterable[str],
          include_path: Optional[List[str]] = None,
          keep_c: bool = False) -> List[str]:
    """Generate and run a program to print out numerical values of C expressions.

    ``type_word`` selects the output format: ``'status'`` expressions are
    cast to ``long`` and printed with ``%ld``; all other types are cast to
    ``unsigned long`` and printed with ``0x%08lx``.

    ``include_path`` is the list of header directories and ``keep_c`` the
    flag to keep the generated C file; both are forwarded unchanged to
    ``c_build_helper.get_c_expression_values``, whose result is returned.
    """
    if type_word == 'status':
        cast_to = 'long'
        printf_format = '%ld'
    else:
        cast_to = 'unsigned long'
        printf_format = '0x%08lx'
    return c_build_helper.get_c_expression_values(
        cast_to, printf_format,
        expressions,
        caller='test_psa_constant_names.py for {} values'.format(type_word),
        file_label=type_word,
        header='#include <psa/crypto.h>',
        include_path=include_path,
        keep_c=keep_c
    )

NORMALIZE_STRIP_RE = re.compile(r'\s+')
def normalize(expr: str) -> str:
    """Normalize the C expression so as not to care about trivial differences.

    Currently "trivial differences" means whitespace.
    """
    return NORMALIZE_STRIP_RE.sub('', expr)

# Matches one of the listed AEAD algorithms wrapped in
# PSA_ALG_AEAD_WITH_SHORTENED_TAG with a tag length of 16.
# NOTE(review): presumably 16 is the full tag length of these algorithms, so
# the wrapped form equals the bare algorithm -- confirm against the PSA headers.
ALG_TRUNCATED_TO_SELF_RE = \
    re.compile(r'PSA_ALG_AEAD_WITH_SHORTENED_TAG\('
               r'PSA_ALG_(?:CCM|CHACHA20_POLY1305|GCM)'
               r', *16\)\Z')

def is_simplifiable(expr: str) -> bool:
    """Determine whether an expression is simplifiable.

    Simplifiable expressions can't be output in their input form, since
    the output will be the simple form. Therefore they must be excluded
    from testing.
    """
    return bool(ALG_TRUNCATED_TO_SELF_RE.match(expr))

def collect_values(inputs: InputsForTest,
                   type_word: str,
                   include_path: Optional[List[str]] = None,
                   keep_c: bool = False) -> Tuple[List[str], List[str]]:
    """Generate expressions using known macro names and calculate their values.

    Return a pair ``(expressions, values)``: ``expressions`` is the sorted
    list of generated expressions that are not simplifiable, and ``values``
    is the list returned by ``run_c`` for those expressions (string
    representations of their integer values).
    """
    names = inputs.get_names(type_word)
    expressions = sorted(expr
                         for expr in inputs.generate_expressions(names)
                         if not is_simplifiable(expr))
    values = run_c(type_word, expressions,
                   include_path=include_path, keep_c=keep_c)
    return expressions, values

class Tests:
    """An object representing tests and their results."""

    # One record per mismatch: the type word, the input expression, its
    # numerical value and what psa_constant_names printed for that value.
    Error = namedtuple('Error',
                       ['type', 'expression', 'value', 'output'])

    def __init__(self, options) -> None:
        """Store the parsed command line ``options`` and reset the counters."""
        self.options = options
        self.count = 0
        self.errors = [] #type: List[Tests.Error]

    def run_one(self, inputs: InputsForTest, type_word: str) -> None:
        """Test psa_constant_names for the specified type.

        Run the program on the names for this type.
        Use the inputs to figure out what arguments to pass to macros that
        take arguments.

        Add the number of tested expressions to ``self.count`` and append an
        ``Error`` to ``self.errors`` for each expression whose output, after
        normalization, differs from the expression itself. A value for which
        the program printed no line is reported as an error with an empty
        output.

        ``subprocess.check_output`` raises ``CalledProcessError`` if the
        program exits with a nonzero status.
        """
        expressions, values = collect_values(inputs, type_word,
                                             include_path=self.options.include,
                                             keep_c=self.options.keep_c)
        output_bytes = subprocess.check_output([self.options.program,
                                                type_word] + values)
        outputs = output_bytes.decode('ascii').strip().split('\n')
        # zip() stops at its shortest argument. Pad the program output with
        # empty strings so that a value with no corresponding output line is
        # compared (and fails) instead of being silently skipped. A negative
        # difference yields an empty padding list.
        outputs.extend([''] * (len(values) - len(outputs)))
        self.count += len(expressions)
        for expr, value, actual in zip(expressions, values, outputs):
            if self.options.show:
                sys.stdout.write('{} {}\t{}\n'.format(type_word, value, actual))
            if normalize(expr) != normalize(actual):
                self.errors.append(self.Error(type=type_word,
                                              expression=expr,
                                              value=value,
                                              output=actual))

    def run_all(self, inputs: InputsForTest) -> None:
        """Run psa_constant_names on all the gathered inputs."""
        for type_word in ['status', 'algorithm', 'ecc_curve', 'dh_group',
                          'key_type', 'key_usage']:
            self.run_one(inputs, type_word)

    def report(self, out: typing_util.Writable) -> None:
        """Describe each case where the output is not as expected.

        Write the errors to ``out``.
        Also write a total.
        """
        for error in self.errors:
            out.write('For {} "{}", got "{}" (value: {})\n'
                      .format(error.type, error.expression,
                              error.output, error.value))
        out.write('{} test cases'.format(self.count))
        if self.errors:
            out.write(', {} FAIL\n'.format(len(self.errors)))
        else:
            out.write(' PASS\n')

# Header files to parse, relative to the first include directory.
HEADERS = ['psa/crypto.h', 'psa/crypto_extra.h', 'psa/crypto_values.h']

# Test data files to parse; the location depends on the kind of source tree
# as reported by build_tree.is_mbedtls_3_6().
if build_tree.is_mbedtls_3_6():
    TEST_SUITES = ['tests/suites/test_suite_psa_crypto_metadata.data']
else:
    TEST_SUITES = ['tf-psa-crypto/tests/suites/test_suite_psa_crypto_metadata.data']

def main() -> None:
    """Parse the command line, run all tests and print a report.

    Exit with status 1 if at least one test case failed.
    """
    # Default paths depend on the kind of source tree.
    if build_tree.is_mbedtls_3_6():
        default_include = ['include']
        default_program = 'programs/psa/psa_constant_names'
    else:
        default_include = ['tf-psa-crypto/include',
                           'tf-psa-crypto/drivers/builtin/include',
                           'tf-psa-crypto/drivers/everest/include',
                           'tf-psa-crypto/drivers/everest/include/' +
                           'tf-psa-crypto/private',
                           'include']
        default_program = 'tf-psa-crypto/programs/psa/psa_constant_names'

    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument('--include', '-I',
                        action='append', default=default_include,
                        help='Directory for header files')
    parser.add_argument('--keep-c',
                        action='store_true', dest='keep_c', default=False,
                        help='Keep the intermediate C file')
    parser.add_argument('--no-keep-c',
                        action='store_false', dest='keep_c',
                        help='Don\'t keep the intermediate C file (default)')
    parser.add_argument('--program',
                        default=default_program,
                        help='Program to test')
    parser.add_argument('--show',
                        action='store_true',
                        help='Show tested values on stdout')
    parser.add_argument('--no-show',
                        action='store_false', dest='show',
                        help='Don\'t show tested values (default)')
    options = parser.parse_args()
    # With action='append' and a non-empty default, argparse appends
    # command line values after the defaults, so include[0] is always the
    # first default directory.
    headers = [os.path.join(options.include[0], h) for h in HEADERS]
    inputs = gather_inputs(headers, TEST_SUITES)
    tests = Tests(options)
    tests.run_all(inputs)
    tests.report(sys.stdout)
    if tests.errors:
        sys.exit(1)

if __name__ == '__main__':
    main()