psa_test_case.py (9099B)
"""Generate test cases for PSA API calls, with automatic dependencies.
"""

# Copyright The Mbed TLS Contributors
# SPDX-License-Identifier: Apache-2.0 OR GPL-2.0-or-later
#

import os
import re
from typing import FrozenSet, List, Optional, Set

from . import build_tree
from . import psa_information
from . import test_case


# Skip test cases for which the dependency symbols are not defined.
# We assume that this means that a required mechanism is not implemented.
# Note that if we erroneously skip generating test cases because we wrongly
# believe that a mechanism is not implemented, this should be caught
# by the NOT_SUPPORTED test cases generated by generate_psa_tests.py
# in test_suite_psa_crypto_not_supported and test_suite_psa_crypto_op_fail:
# those emit tests with negative dependencies, which will not be skipped here.

def read_implemented_dependencies(acc: Set[str], filename: str) -> None:
    """Add every ``PSA_WANT_xxx`` symbol mentioned in `filename` to `acc`.

    Any textual occurrence counts (comments are not filtered out), so a
    commented-out ``#define`` still marks the symbol as known.
    """
    with open(filename) as input_stream:
        for line in input_stream:
            acc.update(re.findall(r'\bPSA_WANT_\w+\b', line))

_implemented_dependencies = None #type: Optional[FrozenSet[str]] #pylint: disable=invalid-name

def find_dependencies_not_implemented(dependencies: List[str]) -> List[str]:
    """List the dependencies that are not implemented.

    Only entries starting with ``PSA_WANT`` are checked. Any other entry,
    including a negated ``!PSA_WANT_xxx``, is never reported.

    The set of known ``PSA_WANT_xxx`` symbols is read from the PSA config
    headers on the first call and cached in a module-level variable for
    later calls. The header paths are relative to the current directory,
    which must be the root of a source tree.

    Raises RuntimeError if the current directory does not look like the
    root of a source tree.
    """
    global _implemented_dependencies #pylint: disable=global-statement,invalid-name
    if _implemented_dependencies is None:
        # Temporary, while Mbed TLS does not just rely on the TF-PSA-Crypto
        # build system to build its crypto library. When it does, the first
        # case can just be removed.
        if not build_tree.looks_like_root('.'):
            # Without this check, include_dir would be unbound below and
            # the caller would get an obscure UnboundLocalError.
            raise RuntimeError('Cannot locate the PSA config headers: '
                               'the current directory does not look like '
                               'the root of a source tree')
        if build_tree.looks_like_mbedtls_root('.') and \
           (not build_tree.is_mbedtls_3_6()):
            include_dir = 'tf-psa-crypto/include'
        else:
            include_dir = 'include'

        acc = set() #type: Set[str]
        for filename in [
                os.path.join(include_dir, 'psa/crypto_config.h'),
                os.path.join(include_dir, 'psa/crypto_adjust_config_synonyms.h'),
        ]:
            read_implemented_dependencies(acc, filename)
        _implemented_dependencies = frozenset(acc)
    return [dep
            for dep in dependencies
            if (dep not in _implemented_dependencies and
                dep.startswith('PSA_WANT'))]


class TestCase(test_case.TestCase):
    """A PSA test case with automatically inferred dependencies.

    For mechanisms like ECC curves where the support status includes
    the key bit-size, this class assumes that only one bit-size is
    involved in a given test case.
    """

    def __init__(self, dependency_prefix: Optional[str] = None) -> None:
        """Construct a test case for a PSA Crypto API call.

        `dependency_prefix`: prefix to use in dependencies. Defaults to
        ``'PSA_WANT_'``. Use ``'MBEDTLS_PSA_BUILTIN_'``
        when specifically testing builtin implementations.
        """
        super().__init__()
        # Dependencies are tracked in manual_dependencies and
        # automatic_dependencies and combined by get_dependencies(). The
        # inherited attribute is removed, presumably so that any code still
        # reading it fails loudly instead of seeing stale data.
        del self.dependencies
        self.manual_dependencies = [] #type: List[str]
        self.automatic_dependencies = set() #type: Set[str]
        self.dependency_prefix = dependency_prefix #type: Optional[str]
        # Dependency symbols that set_arguments() emits with a '!' prefix.
        self.negated_dependencies = set() #type: Set[str]
        self.key_bits = None #type: Optional[int]
        self.key_pair_usage = None #type: Optional[List[str]]

    def set_key_bits(self, key_bits: Optional[int]) -> None:
        """Use the given key size for automatic dependency generation.

        Call this function before set_arguments() if relevant.

        This is only relevant for ECC and DH keys. For other key types,
        this information is ignored.
        """
        self.key_bits = key_bits

    def set_key_pair_usage(self, key_pair_usage: Optional[List[str]]) -> None:
        """Use the given suffixes for key pair dependencies.

        Call this function before set_arguments() if relevant.

        This is only relevant for key pair types. For other key types,
        this information is ignored.
        """
        self.key_pair_usage = key_pair_usage

    def infer_dependencies(self, arguments: List[str]) -> List[str]:
        """Infer dependencies based on the test case arguments.

        The key size (see set_key_bits()) and key pair usage (see
        set_key_pair_usage()) refine the result when they have been set.
        When RSA key pair generation is required, is not negated, and a key
        size is known, a dependency on
        ``PSA_VENDOR_RSA_GENERATE_MIN_KEY_BITS <= key_bits`` is added too.
        """
        dependencies = psa_information.automatic_dependencies(*arguments,
                                                              prefix=self.dependency_prefix)
        if self.key_bits is not None:
            dependencies = psa_information.finish_family_dependencies(dependencies,
                                                                      self.key_bits)
        if self.key_pair_usage is not None:
            dependencies = psa_information.fix_key_pair_dependencies(dependencies,
                                                                     self.key_pair_usage)
        if 'PSA_WANT_KEY_TYPE_RSA_KEY_PAIR_GENERATE' in dependencies and \
           'PSA_WANT_KEY_TYPE_RSA_KEY_PAIR_GENERATE' not in self.negated_dependencies and \
           self.key_bits is not None:
            # RSA key generation may be limited to a minimum key size.
            size_dependency = ('PSA_VENDOR_RSA_GENERATE_MIN_KEY_BITS <= ' +
                               str(self.key_bits))
            dependencies.append(size_dependency)
        return dependencies

    def assumes_not_supported(self, name: str) -> None:
        """Negate the given mechanism for automatic dependency generation.

        `name` can be either a dependency symbol (``PSA_WANT_xxx``) or
        a mechanism name (``PSA_KEY_TYPE_xxx``, etc.).

        Call this function before set_arguments() for a test case that should
        run if the given mechanism is not supported.

        Call modifiers such as set_key_bits() and set_key_pair_usage() before
        calling this method, if applicable.

        A mechanism is a PSA_XXX symbol, e.g. PSA_KEY_TYPE_AES, PSA_ALG_HMAC,
        etc. For mechanisms like ECC curves where the support status includes
        the key bit-size, this class assumes that only one bit-size is
        involved in a given test case.
        """
        if name.startswith('PSA_WANT_'):
            self.negated_dependencies.add(name)
            return
        if name == 'PSA_KEY_TYPE_RSA_KEY_PAIR' and \
           self.key_bits is not None and \
           self.key_pair_usage == ['GENERATE']:
            # When RSA key pair generation is not supported, it could be
            # due to the specific key size is out of range, or because
            # RSA key pair generation itself is not supported. Assume the
            # latter.
            dep = psa_information.psa_want_symbol(name, prefix=self.dependency_prefix)

            self.negated_dependencies.add(dep + '_GENERATE')
            return
        dependencies = self.infer_dependencies([name])
        # * If we have more than one dependency to negate, the result would
        #   say that all of the dependencies are disabled, which is not
        #   a desirable outcome: the negation of (A and B) is (!A or !B),
        #   not (!A and !B).
        # * If we have no dependency to negate, the result wouldn't be a
        #   not-supported case.
        # Assert that we don't reach either such case.
        assert len(dependencies) == 1
        self.negated_dependencies.add(dependencies[0])

    def set_arguments(self, arguments: List[str]) -> None:
        """Set test case arguments and automatically infer dependencies.

        Inferred dependencies that were negated with assumes_not_supported()
        are recorded with a ``!`` prefix. The test case is marked as skipped
        if any non-negated ``PSA_WANT`` dependency is not implemented.
        """
        super().set_arguments(arguments)
        # Build a new list rather than rewriting the inferred list in place.
        dependencies = ['!' + dep if dep in self.negated_dependencies else dep
                        for dep in self.infer_dependencies(arguments)]
        self.skip_if_any_not_implemented(dependencies)
        self.automatic_dependencies.update(dependencies)

    def set_dependencies(self, dependencies: List[str]) -> None:
        """Override any previously added automatic or manual dependencies.

        Also override any previous instruction to skip the test case.
        """
        # Copy, so that a later add_dependencies() (which extends the list in
        # place) cannot modify the list owned by the caller.
        self.manual_dependencies = list(dependencies)
        self.automatic_dependencies.clear()
        self.skip_reasons = []

    def add_dependencies(self, dependencies: List[str]) -> None:
        """Add manual dependencies."""
        self.manual_dependencies += dependencies

    def get_dependencies(self) -> List[str]:
        """Return the manual and automatic dependencies, deduplicated and sorted."""
        # Make the output independent of the order in which the dependencies
        # are calculated by the script. Also avoid duplicates. This makes
        # the output robust with respect to refactoring of the scripts.
        dependencies = set(self.manual_dependencies)
        dependencies.update(self.automatic_dependencies)
        return sorted(dependencies)

    def skip_if_any_not_implemented(self, dependencies: List[str]) -> None:
        """Skip the test case if any of the given dependencies is not implemented."""
        not_implemented = find_dependencies_not_implemented(dependencies)
        for dep in not_implemented:
            self.skip_because('not implemented: ' + dep)