quickjs-tart

quickjs-based runtime for wallet-core logic
Log | Files | Refs | README | LICENSE

psa_storage.py (8944B)


      1 """Knowledge about the PSA key store as implemented in Mbed TLS.
      2 
      3 Note that if you need to make a change that affects how keys are
      4 stored, this may indicate that the key store is changing in a
      5 backward-incompatible way! Think carefully about backward compatibility
      6 before changing how test data is constructed or validated.
      7 """
      8 
      9 # Copyright The Mbed TLS Contributors
     10 # SPDX-License-Identifier: Apache-2.0 OR GPL-2.0-or-later
     11 #
     12 
     13 import re
     14 import struct
     15 from typing import Dict, List, Optional, Set, Union
     16 import unittest
     17 
     18 from . import c_build_helper
     19 from . import build_tree
     20 
     21 
     22 class Expr:
     23     """Representation of a C expression with a known or knowable numerical value."""
     24 
     25     def __init__(self, content: Union[int, str]):
     26         if isinstance(content, int):
     27             digits = 8 if content > 0xffff else 4
     28             self.string = '{0:#0{1}x}'.format(content, digits + 2)
     29             self.value_if_known = content #type: Optional[int]
     30         else:
     31             self.string = content
     32             self.unknown_values.add(self.normalize(content))
     33             self.value_if_known = None
     34 
     35     value_cache = {} #type: Dict[str, int]
     36     """Cache of known values of expressions."""
     37 
     38     unknown_values = set() #type: Set[str]
     39     """Expressions whose values are not present in `value_cache` yet."""
     40 
     41     def update_cache(self) -> None:
     42         """Update `value_cache` for expressions registered in `unknown_values`."""
     43         expressions = sorted(self.unknown_values)
     44         # Temporary, while Mbed TLS does not just rely on the TF-PSA-Crypto
     45         # build system to build its crypto library. When it does, the first
     46         # case can just be removed.
     47 
     48         if build_tree.looks_like_root('.'):
     49             includes = ['include']
     50             if build_tree.looks_like_tf_psa_crypto_root('.'):
     51                 includes.append('drivers/builtin/include')
     52                 includes.append('drivers/everest/include')
     53                 includes.append('drivers/everest/include/tf-psa-crypto/private/')
     54             elif not build_tree.is_mbedtls_3_6():
     55                 includes.append('tf-psa-crypto/include')
     56                 includes.append('tf-psa-crypto/drivers/builtin/include')
     57                 includes.append('tf-psa-crypto/drivers/everest/include')
     58                 includes.append('tf-psa-crypto/drivers/everest/include/tf-psa-crypto/private/')
     59 
     60         values = c_build_helper.get_c_expression_values(
     61             'unsigned long', '%lu',
     62             expressions,
     63             header="""
     64             #include <psa/crypto.h>
     65             """,
     66             include_path=includes) #type: List[str]
     67         for e, v in zip(expressions, values):
     68             self.value_cache[e] = int(v, 0)
     69         self.unknown_values.clear()
     70 
     71     @staticmethod
     72     def normalize(string: str) -> str:
     73         """Put the given C expression in a canonical form.
     74 
     75         This function is only intended to give correct results for the
     76         relatively simple kind of C expression typically used with this
     77         module.
     78         """
     79         return re.sub(r'\s+', r'', string)
     80 
     81     def value(self) -> int:
     82         """Return the numerical value of the expression."""
     83         if self.value_if_known is None:
     84             if re.match(r'([0-9]+|0x[0-9a-f]+)\Z', self.string, re.I):
     85                 return int(self.string, 0)
     86             normalized = self.normalize(self.string)
     87             if normalized not in self.value_cache:
     88                 self.update_cache()
     89             self.value_if_known = self.value_cache[normalized]
     90         return self.value_if_known
     91 
# Type alias for values accepted wherever an `Expr` is expected: an existing
# `Expr`, an integer, or a string holding a C expression (see `as_expr`).
Exprable = Union[str, int, Expr]
"""Something that can be converted to a C expression with a known numerical value."""
     94 
     95 def as_expr(thing: Exprable) -> Expr:
     96     """Return an `Expr` object for `thing`.
     97 
     98     If `thing` is already an `Expr` object, return it. Otherwise build a new
     99     `Expr` object from `thing`. `thing` can be an integer or a string that
    100     contains a C expression.
    101     """
    102     if isinstance(thing, Expr):
    103         return thing
    104     else:
    105         return Expr(thing)
    106 
    107 
    108 class Key:
    109     """Representation of a PSA crypto key object and its storage encoding.
    110     """
    111 
    112     LATEST_VERSION = 0
    113     """The latest version of the storage format."""
    114 
    115     def __init__(self, *,
    116                  version: Optional[int] = None,
    117                  id: Optional[int] = None, #pylint: disable=redefined-builtin
    118                  lifetime: Exprable = 'PSA_KEY_LIFETIME_PERSISTENT',
    119                  type: Exprable, #pylint: disable=redefined-builtin
    120                  bits: int,
    121                  usage: Exprable, alg: Exprable, alg2: Exprable,
    122                  material: bytes #pylint: disable=used-before-assignment
    123                 ) -> None:
    124         self.version = self.LATEST_VERSION if version is None else version
    125         self.id = id #pylint: disable=invalid-name #type: Optional[int]
    126         self.lifetime = as_expr(lifetime) #type: Expr
    127         self.type = as_expr(type) #type: Expr
    128         self.bits = bits #type: int
    129         self.usage = as_expr(usage) #type: Expr
    130         self.alg = as_expr(alg) #type: Expr
    131         self.alg2 = as_expr(alg2) #type: Expr
    132         self.material = material #type: bytes
    133 
    134     MAGIC = b'PSA\000KEY\000'
    135 
    136     @staticmethod
    137     def pack(
    138             fmt: str,
    139             *args: Union[int, Expr]
    140     ) -> bytes: #pylint: disable=used-before-assignment
    141         """Pack the given arguments into a byte string according to the given format.
    142 
    143         This function is similar to `struct.pack`, but with the following differences:
    144         * All integer values are encoded with standard sizes and in
    145           little-endian representation. `fmt` must not include an endianness
    146           prefix.
    147         * Arguments can be `Expr` objects instead of integers.
    148         * Only integer-valued elements are supported.
    149         """
    150         return struct.pack('<' + fmt, # little-endian, standard sizes
    151                            *[arg.value() if isinstance(arg, Expr) else arg
    152                              for arg in args])
    153 
    154     def bytes(self) -> bytes:
    155         """Return the representation of the key in storage as a byte array.
    156 
    157         This is the content of the PSA storage file. When PSA storage is
    158         implemented over stdio files, this does not include any wrapping made
    159         by the PSA-storage-over-stdio-file implementation.
    160 
    161         Note that if you need to make a change in this function,
    162         this may indicate that the key store is changing in a
    163         backward-incompatible way! Think carefully about backward
    164         compatibility before making any change here.
    165         """
    166         header = self.MAGIC + self.pack('L', self.version)
    167         if self.version == 0:
    168             attributes = self.pack('LHHLLL',
    169                                    self.lifetime, self.type, self.bits,
    170                                    self.usage, self.alg, self.alg2)
    171             material = self.pack('L', len(self.material)) + self.material
    172         else:
    173             raise NotImplementedError
    174         return header + attributes + material
    175 
    176     def hex(self) -> str:
    177         """Return the representation of the key as a hexadecimal string.
    178 
    179         This is the hexadecimal representation of `self.bytes`.
    180         """
    181         return self.bytes().hex()
    182 
    183     def location_value(self) -> int:
    184         """The numerical value of the location encoded in the key's lifetime."""
    185         return self.lifetime.value() >> 8
    186 
    187 
    188 class TestKey(unittest.TestCase):
    189     # pylint: disable=line-too-long
    190     """A few smoke tests for the functionality of the `Key` class."""
    191 
    192     def test_numerical(self):
    193         key = Key(version=0,
    194                   id=1, lifetime=0x00000001,
    195                   type=0x2400, bits=128,
    196                   usage=0x00000300, alg=0x05500200, alg2=0x04c01000,
    197                   material=b'@ABCDEFGHIJKLMNO')
    198         expected_hex = '505341004b45590000000000010000000024800000030000000250050010c00410000000404142434445464748494a4b4c4d4e4f'
    199         self.assertEqual(key.bytes(), bytes.fromhex(expected_hex))
    200         self.assertEqual(key.hex(), expected_hex)
    201 
    202     def test_names(self):
    203         length = 0xfff8 // 8 # PSA_MAX_KEY_BITS in bytes
    204         key = Key(version=0,
    205                   id=1, lifetime='PSA_KEY_LIFETIME_PERSISTENT',
    206                   type='PSA_KEY_TYPE_RAW_DATA', bits=length*8,
    207                   usage=0, alg=0, alg2=0,
    208                   material=b'\x00' * length)
    209         expected_hex = '505341004b45590000000000010000000110f8ff000000000000000000000000ff1f0000' + '00' * length
    210         self.assertEqual(key.bytes(), bytes.fromhex(expected_hex))
    211         self.assertEqual(key.hex(), expected_hex)
    212 
    213     def test_defaults(self):
    214         key = Key(type=0x1001, bits=8,
    215                   usage=0, alg=0, alg2=0,
    216                   material=b'\x2a')
    217         expected_hex = '505341004b455900000000000100000001100800000000000000000000000000010000002a'
    218         self.assertEqual(key.bytes(), bytes.fromhex(expected_hex))
    219         self.assertEqual(key.hex(), expected_hex)