quickjs-tart

quickjs-based runtime for wallet-core logic
Log | Files | Refs | README | LICENSE

audit-validity-dates.py (16998B)


      1 #!/usr/bin/env python3
      2 #
      3 # Copyright The Mbed TLS Contributors
      4 # SPDX-License-Identifier: Apache-2.0 OR GPL-2.0-or-later
      5 
      6 """Audit validity date of X509 crt/crl/csr.
      7 
      8 This script is used to audit the validity date of crt/crl/csr used for testing.
      9 It prints the information about X.509 objects excluding the objects that
     10 are valid throughout the desired validity period. The data are collected
     11 from framework/data_files/ and tests/suites/*.data files by default.
     12 """
     13 
     14 import os
     15 import re
     16 import typing
     17 import argparse
     18 import datetime
     19 import glob
     20 import logging
     21 import hashlib
     22 from enum import Enum
     23 
     24 # The script requires cryptography >= 35.0.0 which is only available
     25 # for Python >= 3.6.
     26 import cryptography
     27 from cryptography import x509
     28 
     29 from generate_test_code import FileWrapper
     30 
     31 import scripts_path # pylint: disable=unused-import
     32 from mbedtls_framework import build_tree
     33 from mbedtls_framework import logging_util
     34 
     35 def check_cryptography_version():
     36     match = re.match(r'^[0-9]+', cryptography.__version__)
     37     if match is None or int(match.group(0)) < 35:
     38         raise Exception("audit-validity-dates requires cryptography >= 35.0.0"
     39                         + "({} is too old)".format(cryptography.__version__))
     40 
class DataType(Enum):
    """Kind of X.509 structure a piece of data represents."""
    CRT = 1 # Certificate
    CRL = 2 # Certificate Revocation List
    CSR = 3 # Certificate Signing Request
     46 
class DataFormat(Enum):
    """Serialization format of an X.509 object."""
    PEM = 1 # Privacy-Enhanced Mail
    DER = 2 # Distinguished Encoding Rules
     51 
     52 class AuditData:
     53     """Store data location, type and validity period of X.509 objects."""
     54     #pylint: disable=too-few-public-methods
     55     def __init__(self, data_type: DataType, x509_obj):
     56         self.data_type = data_type
     57         # the locations that the x509 object could be found
     58         self.locations = [] # type: typing.List[str]
     59         self.fill_validity_duration(x509_obj)
     60         self._obj = x509_obj
     61         encoding = cryptography.hazmat.primitives.serialization.Encoding.DER
     62         self._identifier = hashlib.sha1(self._obj.public_bytes(encoding)).hexdigest()
     63 
     64     @property
     65     def identifier(self):
     66         """
     67         Identifier of the underlying X.509 object, which is consistent across
     68         different runs.
     69         """
     70         return self._identifier
     71 
     72     def fill_validity_duration(self, x509_obj):
     73         """Read validity period from an X.509 object."""
     74         # Certificate expires after "not_valid_after"
     75         # Certificate is invalid before "not_valid_before"
     76         if self.data_type == DataType.CRT:
     77             self.not_valid_after = x509_obj.not_valid_after
     78             self.not_valid_before = x509_obj.not_valid_before
     79         # CertificateRevocationList expires after "next_update"
     80         # CertificateRevocationList is invalid before "last_update"
     81         elif self.data_type == DataType.CRL:
     82             self.not_valid_after = x509_obj.next_update
     83             self.not_valid_before = x509_obj.last_update
     84         # CertificateSigningRequest is always valid.
     85         elif self.data_type == DataType.CSR:
     86             self.not_valid_after = datetime.datetime.max
     87             self.not_valid_before = datetime.datetime.min
     88         else:
     89             raise ValueError("Unsupported file_type: {}".format(self.data_type))
     90 
     91 
     92 class X509Parser:
     93     """A parser class to parse crt/crl/csr file or data in PEM/DER format."""
     94     PEM_REGEX = br'-{5}BEGIN (?P<type>.*?)-{5}(?P<data>.*?)-{5}END (?P=type)-{5}'
     95     PEM_TAG_REGEX = br'-{5}BEGIN (?P<type>.*?)-{5}\n'
     96     PEM_TAGS = {
     97         DataType.CRT: 'CERTIFICATE',
     98         DataType.CRL: 'X509 CRL',
     99         DataType.CSR: 'CERTIFICATE REQUEST'
    100     }
    101 
    102     def __init__(self,
    103                  backends:
    104                  typing.Dict[DataType,
    105                              typing.Dict[DataFormat,
    106                                          typing.Callable[[bytes], object]]]) \
    107     -> None:
    108         self.backends = backends
    109         self.__generate_parsers()
    110 
    111     def __generate_parser(self, data_type: DataType):
    112         """Parser generator for a specific DataType"""
    113         tag = self.PEM_TAGS[data_type]
    114         pem_loader = self.backends[data_type][DataFormat.PEM]
    115         der_loader = self.backends[data_type][DataFormat.DER]
    116         def wrapper(data: bytes):
    117             pem_type = X509Parser.pem_data_type(data)
    118             # It is in PEM format with target tag
    119             if pem_type == tag:
    120                 return pem_loader(data)
    121             # It is in PEM format without target tag
    122             if pem_type:
    123                 return None
    124             # It might be in DER format
    125             try:
    126                 result = der_loader(data)
    127             except ValueError:
    128                 result = None
    129             return result
    130         wrapper.__name__ = "{}.parser[{}]".format(type(self).__name__, tag)
    131         return wrapper
    132 
    133     def __generate_parsers(self):
    134         """Generate parsers for all support DataType"""
    135         self.parsers = {}
    136         for data_type, _ in self.PEM_TAGS.items():
    137             self.parsers[data_type] = self.__generate_parser(data_type)
    138 
    139     def __getitem__(self, item):
    140         return self.parsers[item]
    141 
    142     @staticmethod
    143     def pem_data_type(data: bytes) -> typing.Optional[str]:
    144         """Get the tag from the data in PEM format
    145 
    146         :param data: data to be checked in binary mode.
    147         :return: PEM tag or "" when no tag detected.
    148         """
    149         m = re.search(X509Parser.PEM_TAG_REGEX, data)
    150         if m is not None:
    151             return m.group('type').decode('UTF-8')
    152         else:
    153             return None
    154 
    155     @staticmethod
    156     def check_hex_string(hex_str: str) -> bool:
    157         """Check if the hex string is possibly DER data."""
    158         hex_len = len(hex_str)
    159         # At least 6 hex char for 3 bytes: Type + Length + Content
    160         if hex_len < 6:
    161             return False
    162         # Check if Type (1 byte) is SEQUENCE.
    163         if hex_str[0:2] != '30':
    164             return False
    165         # Check LENGTH (1 byte) value
    166         content_len = int(hex_str[2:4], base=16)
    167         consumed = 4
    168         if content_len in (128, 255):
    169             # Indefinite or Reserved
    170             return False
    171         elif content_len > 127:
    172             # Definite, Long
    173             length_len = (content_len - 128) * 2
    174             content_len = int(hex_str[consumed:consumed+length_len], base=16)
    175             consumed += length_len
    176         # Check LENGTH
    177         if hex_len != content_len * 2 + consumed:
    178             return False
    179         return True
    180 
    181 
    182 class Auditor:
    183     """
    184     A base class that uses X509Parser to parse files to a list of AuditData.
    185 
    186     A subclass must implement the following methods:
    187       - collect_default_files: Return a list of file names that are defaultly
    188         used for parsing (auditing). The list will be stored in
    189         Auditor.default_files.
    190       - parse_file: Method that parses a single file to a list of AuditData.
    191 
    192     A subclass may override the following methods:
    193       - parse_bytes: Defaultly, it parses `bytes` that contains only one valid
    194         X.509 data(DER/PEM format) to an X.509 object.
    195       - walk_all: Defaultly, it iterates over all the files in the provided
    196         file name list, calls `parse_file` for each file and stores the results
    197         by extending the `results` passed to the function.
    198     """
    199     def __init__(self, logger):
    200         self.logger = logger
    201         self.default_files = self.collect_default_files()
    202         self.parser = X509Parser({
    203             DataType.CRT: {
    204                 DataFormat.PEM: x509.load_pem_x509_certificate,
    205                 DataFormat.DER: x509.load_der_x509_certificate
    206             },
    207             DataType.CRL: {
    208                 DataFormat.PEM: x509.load_pem_x509_crl,
    209                 DataFormat.DER: x509.load_der_x509_crl
    210             },
    211             DataType.CSR: {
    212                 DataFormat.PEM: x509.load_pem_x509_csr,
    213                 DataFormat.DER: x509.load_der_x509_csr
    214             },
    215         })
    216 
    217     def collect_default_files(self) -> typing.List[str]:
    218         """Collect the default files for parsing."""
    219         raise NotImplementedError
    220 
    221     def parse_file(self, filename: str) -> typing.List[AuditData]:
    222         """
    223         Parse a list of AuditData from file.
    224 
    225         :param filename: name of the file to parse.
    226         :return list of AuditData parsed from the file.
    227         """
    228         raise NotImplementedError
    229 
    230     def parse_bytes(self, data: bytes):
    231         """Parse AuditData from bytes."""
    232         for data_type in list(DataType):
    233             try:
    234                 result = self.parser[data_type](data)
    235             except ValueError as val_error:
    236                 result = None
    237                 self.logger.warning(val_error)
    238             if result is not None:
    239                 audit_data = AuditData(data_type, result)
    240                 return audit_data
    241         return None
    242 
    243     def walk_all(self,
    244                  results: typing.Dict[str, AuditData],
    245                  file_list: typing.Optional[typing.List[str]] = None) \
    246         -> None:
    247         """
    248         Iterate over all the files in the list and get audit data. The
    249         results will be written to `results` passed to this function.
    250 
    251         :param results: The dictionary used to store the parsed
    252                         AuditData. The keys of this dictionary should
    253                         be the identifier of the AuditData.
    254         """
    255         if file_list is None:
    256             file_list = self.default_files
    257         for filename in file_list:
    258             data_list = self.parse_file(filename)
    259             for d in data_list:
    260                 if d.identifier in results:
    261                     results[d.identifier].locations.extend(d.locations)
    262                 else:
    263                     results[d.identifier] = d
    264 
    265     @staticmethod
    266     def find_test_dir():
    267         """Get the relative path for the Mbed TLS test directory."""
    268         return os.path.relpath(build_tree.guess_mbedtls_root() + '/tests')
    269 
    270 
    271 class TestDataAuditor(Auditor):
    272     """Class for auditing files in `framework/data_files/`"""
    273 
    274     def collect_default_files(self):
    275         """Collect all files in `framework/data_files/`"""
    276         test_data_glob = os.path.join(build_tree.guess_mbedtls_root(),
    277                                       'framework', 'data_files/**')
    278         data_files = [f for f in glob.glob(test_data_glob, recursive=True)
    279                       if os.path.isfile(f)]
    280         return data_files
    281 
    282     def parse_file(self, filename: str) -> typing.List[AuditData]:
    283         """
    284         Parse a list of AuditData from data file.
    285 
    286         :param filename: name of the file to parse.
    287         :return list of AuditData parsed from the file.
    288         """
    289         with open(filename, 'rb') as f:
    290             data = f.read()
    291 
    292         results = []
    293         # Try to parse all PEM blocks.
    294         is_pem = False
    295         for idx, m in enumerate(re.finditer(X509Parser.PEM_REGEX, data, flags=re.S), 1):
    296             is_pem = True
    297             result = self.parse_bytes(data[m.start():m.end()])
    298             if result is not None:
    299                 result.locations.append("{}#{}".format(filename, idx))
    300                 results.append(result)
    301 
    302         # Might be DER format.
    303         if not is_pem:
    304             result = self.parse_bytes(data)
    305             if result is not None:
    306                 result.locations.append("{}".format(filename))
    307                 results.append(result)
    308 
    309         return results
    310 
    311 
    312 def parse_suite_data(data_f):
    313     """
    314     Parses .data file for test arguments that possiblly have a
    315     valid X.509 data. If you need a more precise parser, please
    316     use generate_test_code.parse_test_data instead.
    317 
    318     :param data_f: file object of the data file.
    319     :return: Generator that yields test function argument list.
    320     """
    321     for line in data_f:
    322         line = line.strip()
    323         # Skip comments
    324         if line.startswith('#'):
    325             continue
    326 
    327         # Check parameters line
    328         match = re.search(r'\A\w+(.*:)?\"', line)
    329         if match:
    330             # Read test vectors
    331             parts = re.split(r'(?<!\\):', line)
    332             parts = [x for x in parts if x]
    333             args = parts[1:]
    334             yield args
    335 
    336 
    337 class SuiteDataAuditor(Auditor):
    338     """Class for auditing files in `tests/suites/*.data`"""
    339 
    340     def collect_default_files(self):
    341         """Collect all files in `tests/suites/*.data`"""
    342         test_dir = self.find_test_dir()
    343         suites_data_folder = os.path.join(test_dir, 'suites')
    344         data_files = glob.glob(os.path.join(suites_data_folder, '*.data'))
    345         return data_files
    346 
    347     def parse_file(self, filename: str):
    348         """
    349         Parse a list of AuditData from test suite data file.
    350 
    351         :param filename: name of the file to parse.
    352         :return list of AuditData parsed from the file.
    353         """
    354         audit_data_list = []
    355         data_f = FileWrapper(filename)
    356         for test_args in parse_suite_data(data_f):
    357             for idx, test_arg in enumerate(test_args):
    358                 match = re.match(r'"(?P<data>[0-9a-fA-F]+)"', test_arg)
    359                 if not match:
    360                     continue
    361                 if not X509Parser.check_hex_string(match.group('data')):
    362                     continue
    363                 audit_data = self.parse_bytes(bytes.fromhex(match.group('data')))
    364                 if audit_data is None:
    365                     continue
    366                 audit_data.locations.append("{}:{}:#{}".format(filename,
    367                                                                data_f.line_no,
    368                                                                idx + 1))
    369                 audit_data_list.append(audit_data)
    370 
    371         return audit_data_list
    372 
    373 
    374 def list_all(audit_data: AuditData):
    375     for loc in audit_data.locations:
    376         print("{}\t{:20}\t{:20}\t{:3}\t{}".format(
    377             audit_data.identifier,
    378             audit_data.not_valid_before.isoformat(timespec='seconds'),
    379             audit_data.not_valid_after.isoformat(timespec='seconds'),
    380             audit_data.data_type.name,
    381             loc))
    382 
    383 
    384 def main():
    385     """
    386     Perform argument parsing.
    387     """
    388     parser = argparse.ArgumentParser(description=__doc__)
    389 
    390     parser.add_argument('-a', '--all',
    391                         action='store_true',
    392                         help='list the information of all the files')
    393     parser.add_argument('-v', '--verbose',
    394                         action='store_true', dest='verbose',
    395                         help='show logs')
    396     parser.add_argument('--from', dest='start_date',
    397                         help=('Start of desired validity period (UTC, YYYY-MM-DD). '
    398                               'Default: today'),
    399                         metavar='DATE')
    400     parser.add_argument('--to', dest='end_date',
    401                         help=('End of desired validity period (UTC, YYYY-MM-DD). '
    402                               'Default: --from'),
    403                         metavar='DATE')
    404     parser.add_argument('--data-files', action='append', nargs='*',
    405                         help='data files to audit',
    406                         metavar='FILE')
    407     parser.add_argument('--suite-data-files', action='append', nargs='*',
    408                         help='suite data files to audit',
    409                         metavar='FILE')
    410 
    411     args = parser.parse_args()
    412 
    413     # start main routine
    414     # setup logger
    415     logger = logging.getLogger()
    416     logging_util.configure_logger(logger)
    417     logger.setLevel(logging.DEBUG if args.verbose else logging.ERROR)
    418 
    419     td_auditor = TestDataAuditor(logger)
    420     sd_auditor = SuiteDataAuditor(logger)
    421 
    422     data_files = []
    423     suite_data_files = []
    424     if args.data_files is None and args.suite_data_files is None:
    425         data_files = td_auditor.default_files
    426         suite_data_files = sd_auditor.default_files
    427     else:
    428         if args.data_files is not None:
    429             data_files = [x for l in args.data_files for x in l]
    430         if args.suite_data_files is not None:
    431             suite_data_files = [x for l in args.suite_data_files for x in l]
    432 
    433     # validity period start date
    434     if args.start_date:
    435         start_date = datetime.datetime.fromisoformat(args.start_date)
    436     else:
    437         start_date = datetime.datetime.today()
    438     # validity period end date
    439     if args.end_date:
    440         end_date = datetime.datetime.fromisoformat(args.end_date)
    441     else:
    442         end_date = start_date
    443 
    444     # go through all the files
    445     audit_results = {}
    446     td_auditor.walk_all(audit_results, data_files)
    447     sd_auditor.walk_all(audit_results, suite_data_files)
    448 
    449     logger.info("Total: {} objects found!".format(len(audit_results)))
    450 
    451     # we filter out the files whose validity duration covers the provided
    452     # duration.
    453     filter_func = lambda d: (start_date < d.not_valid_before) or \
    454                             (d.not_valid_after < end_date)
    455 
    456     sortby_end = lambda d: d.not_valid_after
    457 
    458     if args.all:
    459         filter_func = None
    460 
    461     # filter and output the results
    462     for d in sorted(filter(filter_func, audit_results.values()), key=sortby_end):
    463         list_all(d)
    464 
    465     logger.debug("Done!")
    466 
# Runs at import time as well, so importers fail early on an
# incompatible cryptography package.
check_cryptography_version()
if __name__ == "__main__":
    main()