quickjs-tart

quickjs-based runtime for wallet-core logic
Log | Files | Refs | README | LICENSE

certs.py (22230B)


      1 #!/usr/bin/env python3
      2 # -*- coding: utf-8 -*-
      3 #***************************************************************************
      4 #                                  _   _ ____  _
      5 #  Project                     ___| | | |  _ \| |
      6 #                             / __| | | | |_) | |
      7 #                            | (__| |_| |  _ <| |___
      8 #                             \___|\___/|_| \_\_____|
      9 #
     10 # Copyright (C) Daniel Stenberg, <daniel@haxx.se>, et al.
     11 #
     12 # This software is licensed as described in the file COPYING, which
     13 # you should have received as part of this distribution. The terms
     14 # are also available at https://curl.se/docs/copyright.html.
     15 #
     16 # You may opt to use, copy, modify, merge, publish, distribute and/or sell
     17 # copies of the Software, and permit persons to whom the Software is
     18 # furnished to do so, under the terms of the COPYING file.
     19 #
     20 # This software is distributed on an "AS IS" basis, WITHOUT WARRANTY OF ANY
     21 # KIND, either express or implied.
     22 #
     23 # SPDX-License-Identifier: curl
     24 #
     25 ###########################################################################
     26 #
     27 import base64
     28 import ipaddress
     29 import os
     30 import re
     31 from datetime import timedelta, datetime, timezone
     32 from typing import List, Any, Optional
     33 
     34 from cryptography import x509
     35 from cryptography.hazmat.backends import default_backend
     36 from cryptography.hazmat.primitives import hashes
     37 from cryptography.hazmat.primitives._serialization import PublicFormat
     38 from cryptography.hazmat.primitives.asymmetric import ec, rsa
     39 from cryptography.hazmat.primitives.asymmetric.ec import EllipticCurvePrivateKey
     40 from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey
     41 from cryptography.hazmat.primitives.serialization import Encoding, PrivateFormat, NoEncryption, load_pem_private_key
     42 from cryptography.x509 import ExtendedKeyUsageOID, NameOID
     43 
     44 
     45 EC_SUPPORTED = {}
     46 EC_SUPPORTED.update([(curve.name.upper(), curve) for curve in [
     47     ec.SECP192R1,
     48     ec.SECP224R1,
     49     ec.SECP256R1,
     50     ec.SECP384R1,
     51 ]])
     52 
     53 
     54 def _private_key(key_type):
     55     if isinstance(key_type, str):
     56         key_type = key_type.upper()
     57         m = re.match(r'^(RSA)?(\d+)$', key_type)
     58         if m:
     59             key_type = int(m.group(2))
     60 
     61     if isinstance(key_type, int):
     62         return rsa.generate_private_key(
     63             public_exponent=65537,
     64             key_size=key_type,
     65             backend=default_backend()
     66         )
     67     if not isinstance(key_type, ec.EllipticCurve) and key_type in EC_SUPPORTED:
     68         key_type = EC_SUPPORTED[key_type]
     69     return ec.generate_private_key(
     70         curve=key_type,
     71         backend=default_backend()
     72     )
     73 
     74 
     75 class CertificateSpec:
     76 
     77     def __init__(self, name: Optional[str] = None,
     78                  domains: Optional[List[str]] = None,
     79                  email: Optional[str] = None,
     80                  key_type: Optional[str] = None,
     81                  single_file: bool = False,
     82                  valid_from: timedelta = timedelta(days=-1),
     83                  valid_to: timedelta = timedelta(days=89),
     84                  client: bool = False,
     85                  check_valid: bool = True,
     86                  sub_specs: Optional[List['CertificateSpec']] = None):
     87         self._name = name
     88         self.domains = domains
     89         self.client = client
     90         self.email = email
     91         self.key_type = key_type
     92         self.single_file = single_file
     93         self.valid_from = valid_from
     94         self.valid_to = valid_to
     95         self.sub_specs = sub_specs
     96         self.check_valid = check_valid
     97 
     98     @property
     99     def name(self) -> Optional[str]:
    100         if self._name:
    101             return self._name
    102         elif self.domains:
    103             return self.domains[0]
    104         return None
    105 
    106     @property
    107     def type(self) -> Optional[str]:
    108         if self.domains and len(self.domains):
    109             return "server"
    110         elif self.client:
    111             return "client"
    112         elif self.name:
    113             return "ca"
    114         return None
    115 
    116 
    117 class Credentials:
    118 
    119     def __init__(self,
    120                  name: str,
    121                  cert: Any,
    122                  pkey: Any,
    123                  issuer: Optional['Credentials'] = None):
    124         self._name = name
    125         self._cert = cert
    126         self._pkey = pkey
    127         self._issuer = issuer
    128         self._cert_file = None
    129         self._pkey_file = None
    130         self._store = None
    131         self._combined_file = None
    132 
    133     @property
    134     def name(self) -> str:
    135         return self._name
    136 
    137     @property
    138     def subject(self) -> x509.Name:
    139         return self._cert.subject
    140 
    141     @property
    142     def key_type(self):
    143         if isinstance(self._pkey, RSAPrivateKey):
    144             return f"rsa{self._pkey.key_size}"
    145         elif isinstance(self._pkey, EllipticCurvePrivateKey):
    146             return f"{self._pkey.curve.name}"
    147         else:
    148             raise Exception(f"unknown key type: {self._pkey}")
    149 
    150     @property
    151     def private_key(self) -> Any:
    152         return self._pkey
    153 
    154     def pub_sha256_b64(self) -> Any:
    155         pubkey = self._pkey.public_key()
    156         sha256 = hashes.Hash(algorithm=hashes.SHA256())
    157         sha256.update(pubkey.public_bytes(
    158             encoding=Encoding.DER,
    159             format=PublicFormat.SubjectPublicKeyInfo
    160         ))
    161         return base64.b64encode(sha256.finalize()).decode('utf8')
    162 
    163     @property
    164     def certificate(self) -> Any:
    165         return self._cert
    166 
    167     @property
    168     def cert_pem(self) -> bytes:
    169         return self._cert.public_bytes(Encoding.PEM)
    170 
    171     @property
    172     def pkey_pem(self) -> bytes:
    173         return self._pkey.private_bytes(
    174             Encoding.PEM,
    175             PrivateFormat.TraditionalOpenSSL if self.key_type.startswith('rsa') else PrivateFormat.PKCS8,
    176             NoEncryption())
    177 
    178     @property
    179     def issuer(self) -> Optional['Credentials']:
    180         return self._issuer
    181 
    182     def set_store(self, store: 'CertStore'):
    183         self._store = store
    184 
    185     def set_files(self, cert_file: str, pkey_file: Optional[str] = None,
    186                   combined_file: Optional[str] = None):
    187         self._cert_file = cert_file
    188         self._pkey_file = pkey_file
    189         self._combined_file = combined_file
    190 
    191     @property
    192     def cert_file(self) -> str:
    193         return self._cert_file
    194 
    195     @property
    196     def pkey_file(self) -> Optional[str]:
    197         return self._pkey_file
    198 
    199     @property
    200     def combined_file(self) -> Optional[str]:
    201         return self._combined_file
    202 
    203     def get_first(self, name) -> Optional['Credentials']:
    204         creds = self._store.get_credentials_for_name(name) if self._store else []
    205         return creds[0] if len(creds) else None
    206 
    207     def get_credentials_for_name(self, name) -> List['Credentials']:
    208         return self._store.get_credentials_for_name(name) if self._store else []
    209 
    210     def issue_certs(self, specs: List[CertificateSpec],
    211                     chain: Optional[List['Credentials']] = None) -> List['Credentials']:
    212         return [self.issue_cert(spec=spec, chain=chain) for spec in specs]
    213 
    214     def issue_cert(self, spec: CertificateSpec,
    215                    chain: Optional[List['Credentials']] = None) -> 'Credentials':
    216         key_type = spec.key_type if spec.key_type else self.key_type
    217         creds = None
    218         if self._store:
    219             creds = self._store.load_credentials(
    220                 name=spec.name, key_type=key_type, single_file=spec.single_file,
    221                 issuer=self, check_valid=spec.check_valid)
    222         if creds is None:
    223             creds = TestCA.create_credentials(spec=spec, issuer=self, key_type=key_type,
    224                                               valid_from=spec.valid_from, valid_to=spec.valid_to)
    225             if self._store:
    226                 self._store.save(creds, single_file=spec.single_file)
    227                 if spec.type == "ca":
    228                     self._store.save_chain(creds, "ca", with_root=True)
    229 
    230         if spec.sub_specs:
    231             if self._store:
    232                 sub_store = CertStore(fpath=os.path.join(self._store.path, creds.name))
    233                 creds.set_store(sub_store)
    234             subchain = chain.copy() if chain else []
    235             subchain.append(self)
    236             creds.issue_certs(spec.sub_specs, chain=subchain)
    237         return creds
    238 
    239 
    240 class CertStore:
    241 
    242     def __init__(self, fpath: str):
    243         self._store_dir = fpath
    244         if not os.path.exists(self._store_dir):
    245             os.makedirs(self._store_dir)
    246         self._creds_by_name = {}
    247 
    248     @property
    249     def path(self) -> str:
    250         return self._store_dir
    251 
    252     def save(self, creds: Credentials, name: Optional[str] = None,
    253              chain: Optional[List[Credentials]] = None,
    254              single_file: bool = False) -> None:
    255         name = name if name is not None else creds.name
    256         cert_file = self.get_cert_file(name=name, key_type=creds.key_type)
    257         pkey_file = self.get_pkey_file(name=name, key_type=creds.key_type)
    258         comb_file = self.get_combined_file(name=name, key_type=creds.key_type)
    259         if single_file:
    260             pkey_file = None
    261         with open(cert_file, "wb") as fd:
    262             fd.write(creds.cert_pem)
    263             if chain:
    264                 for c in chain:
    265                     fd.write(c.cert_pem)
    266             if pkey_file is None:
    267                 fd.write(creds.pkey_pem)
    268         if pkey_file is not None:
    269             with open(pkey_file, "wb") as fd:
    270                 fd.write(creds.pkey_pem)
    271         with open(comb_file, "wb") as fd:
    272             fd.write(creds.cert_pem)
    273             if chain:
    274                 for c in chain:
    275                     fd.write(c.cert_pem)
    276             fd.write(creds.pkey_pem)
    277         creds.set_files(cert_file, pkey_file, comb_file)
    278         self._add_credentials(name, creds)
    279 
    280     def save_chain(self, creds: Credentials, infix: str, with_root=False):
    281         name = creds.name
    282         chain = [creds]
    283         while creds.issuer is not None:
    284             creds = creds.issuer
    285             chain.append(creds)
    286         if not with_root and len(chain) > 1:
    287             chain = chain[:-1]
    288         chain_file = os.path.join(self._store_dir, f'{name}-{infix}.pem')
    289         with open(chain_file, "wb") as fd:
    290             for c in chain:
    291                 fd.write(c.cert_pem)
    292 
    293     def _add_credentials(self, name: str, creds: Credentials):
    294         if name not in self._creds_by_name:
    295             self._creds_by_name[name] = []
    296         self._creds_by_name[name].append(creds)
    297 
    298     def get_credentials_for_name(self, name) -> List[Credentials]:
    299         return self._creds_by_name[name] if name in self._creds_by_name else []
    300 
    301     def get_cert_file(self, name: str, key_type=None) -> str:
    302         key_infix = ".{0}".format(key_type) if key_type is not None else ""
    303         return os.path.join(self._store_dir, f'{name}{key_infix}.cert.pem')
    304 
    305     def get_pkey_file(self, name: str, key_type=None) -> str:
    306         key_infix = ".{0}".format(key_type) if key_type is not None else ""
    307         return os.path.join(self._store_dir, f'{name}{key_infix}.pkey.pem')
    308 
    309     def get_combined_file(self, name: str, key_type=None) -> str:
    310         return os.path.join(self._store_dir, f'{name}.pem')
    311 
    312     def load_pem_cert(self, fpath: str) -> x509.Certificate:
    313         with open(fpath) as fd:
    314             return x509.load_pem_x509_certificate("".join(fd.readlines()).encode())
    315 
    316     def load_pem_pkey(self, fpath: str):
    317         with open(fpath) as fd:
    318             return load_pem_private_key("".join(fd.readlines()).encode(), password=None)
    319 
    320     def load_credentials(self, name: str, key_type=None,
    321                          single_file: bool = False,
    322                          issuer: Optional[Credentials] = None,
    323                          check_valid: bool = False):
    324         cert_file = self.get_cert_file(name=name, key_type=key_type)
    325         pkey_file = cert_file if single_file else self.get_pkey_file(name=name, key_type=key_type)
    326         comb_file = self.get_combined_file(name=name, key_type=key_type)
    327         if os.path.isfile(cert_file) and os.path.isfile(pkey_file):
    328             cert = self.load_pem_cert(cert_file)
    329             pkey = self.load_pem_pkey(pkey_file)
    330             try:
    331                 now = datetime.now(tz=timezone.utc)
    332                 if check_valid and \
    333                     ((cert.not_valid_after_utc < now) or
    334                      (cert.not_valid_before_utc > now)):
    335                     return None
    336             except AttributeError:  # older python
    337                 now = datetime.now()
    338                 if check_valid and \
    339                         ((cert.not_valid_after < now) or
    340                          (cert.not_valid_before > now)):
    341                     return None
    342             creds = Credentials(name=name, cert=cert, pkey=pkey, issuer=issuer)
    343             creds.set_store(self)
    344             creds.set_files(cert_file, pkey_file, comb_file)
    345             self._add_credentials(name, creds)
    346             return creds
    347         return None
    348 
    349 
    350 class TestCA:
    351 
    352     @classmethod
    353     def create_root(cls, name: str, store_dir: str, key_type: str = "rsa2048") -> Credentials:
    354         store = CertStore(fpath=store_dir)
    355         creds = store.load_credentials(name="ca", key_type=key_type, issuer=None)
    356         if creds is None:
    357             creds = TestCA._make_ca_credentials(name=name, key_type=key_type)
    358             store.save(creds, name="ca")
    359             creds.set_store(store)
    360         return creds
    361 
    362     @staticmethod
    363     def create_credentials(spec: CertificateSpec, issuer: Credentials, key_type: Any,
    364                            valid_from: timedelta = timedelta(days=-1),
    365                            valid_to: timedelta = timedelta(days=89),
    366                            ) -> Credentials:
    367         """
    368         Create a certificate signed by this CA for the given domains.
    369 
    370         :returns: the certificate and private key PEM file paths
    371         """
    372         if spec.domains and len(spec.domains):
    373             creds = TestCA._make_server_credentials(name=spec.name, domains=spec.domains,
    374                                                     issuer=issuer, valid_from=valid_from,
    375                                                     valid_to=valid_to, key_type=key_type)
    376         elif spec.client:
    377             creds = TestCA._make_client_credentials(name=spec.name, issuer=issuer,
    378                                                     email=spec.email, valid_from=valid_from,
    379                                                     valid_to=valid_to, key_type=key_type)
    380         elif spec.name:
    381             creds = TestCA._make_ca_credentials(name=spec.name, issuer=issuer,
    382                                                 valid_from=valid_from, valid_to=valid_to,
    383                                                 key_type=key_type)
    384         else:
    385             raise Exception(f"unrecognized certificate specification: {spec}")
    386         return creds
    387 
    388     @staticmethod
    389     def _make_x509_name(org_name: Optional[str] = None, common_name: Optional[str] = None, parent: x509.Name = None) -> x509.Name:
    390         name_pieces = []
    391         if org_name:
    392             oid = NameOID.ORGANIZATIONAL_UNIT_NAME if parent else NameOID.ORGANIZATION_NAME
    393             name_pieces.append(x509.NameAttribute(oid, org_name))
    394         elif common_name:
    395             name_pieces.append(x509.NameAttribute(NameOID.COMMON_NAME, common_name))
    396         if parent:
    397             name_pieces.extend(list(parent))
    398         return x509.Name(name_pieces)
    399 
    400     @staticmethod
    401     def _make_csr(
    402             subject: x509.Name,
    403             pkey: Any,
    404             issuer_subject: Optional[Credentials],
    405             valid_from_delta: Optional[timedelta] = None,
    406             valid_until_delta: Optional[timedelta] = None
    407     ) -> x509.CertificateBuilder:
    408         pubkey = pkey.public_key()
    409         issuer_subject = issuer_subject if issuer_subject is not None else subject
    410 
    411         valid_from = datetime.now()
    412         if valid_until_delta is not None:
    413             valid_from += valid_from_delta
    414         valid_until = datetime.now()
    415         if valid_until_delta is not None:
    416             valid_until += valid_until_delta
    417 
    418         return (
    419             x509.CertificateBuilder()
    420             .subject_name(subject)
    421             .issuer_name(issuer_subject)
    422             .public_key(pubkey)
    423             .not_valid_before(valid_from)
    424             .not_valid_after(valid_until)
    425             .serial_number(x509.random_serial_number())
    426             .add_extension(
    427                 x509.SubjectKeyIdentifier.from_public_key(pubkey),
    428                 critical=False,
    429             )
    430         )
    431 
    432     @staticmethod
    433     def _add_ca_usages(csr: Any) -> Any:
    434         return csr.add_extension(
    435             x509.BasicConstraints(ca=True, path_length=9),
    436             critical=True,
    437         ).add_extension(
    438             x509.KeyUsage(
    439                 digital_signature=True,
    440                 content_commitment=False,
    441                 key_encipherment=False,
    442                 data_encipherment=False,
    443                 key_agreement=False,
    444                 key_cert_sign=True,
    445                 crl_sign=True,
    446                 encipher_only=False,
    447                 decipher_only=False),
    448             critical=True
    449         ).add_extension(
    450             x509.ExtendedKeyUsage([
    451                 ExtendedKeyUsageOID.CLIENT_AUTH,
    452                 ExtendedKeyUsageOID.SERVER_AUTH,
    453                 ExtendedKeyUsageOID.CODE_SIGNING,
    454             ]),
    455             critical=True
    456         )
    457 
    458     @staticmethod
    459     def _add_leaf_usages(csr: Any, domains: List[str], issuer: Credentials) -> Any:
    460         names = []
    461         for name in domains:
    462             try:
    463                 names.append(x509.IPAddress(ipaddress.ip_address(name)))
    464             # TODO: specify specific exceptions here
    465             except:  # noqa: E722
    466                 names.append(x509.DNSName(name))
    467 
    468         return csr.add_extension(
    469             x509.BasicConstraints(ca=False, path_length=None),
    470             critical=True,
    471         ).add_extension(
    472             x509.AuthorityKeyIdentifier.from_issuer_subject_key_identifier(
    473                 issuer.certificate.extensions.get_extension_for_class(
    474                     x509.SubjectKeyIdentifier).value),
    475             critical=False
    476         ).add_extension(
    477             x509.SubjectAlternativeName(names), critical=True,
    478         ).add_extension(
    479             x509.ExtendedKeyUsage([
    480                 ExtendedKeyUsageOID.SERVER_AUTH,
    481             ]),
    482             critical=False
    483         )
    484 
    485     @staticmethod
    486     def _add_client_usages(csr: Any, issuer: Credentials, rfc82name: Optional[str] = None) -> Any:
    487         cert = csr.add_extension(
    488             x509.BasicConstraints(ca=False, path_length=None),
    489             critical=True,
    490         ).add_extension(
    491             x509.AuthorityKeyIdentifier.from_issuer_subject_key_identifier(
    492                 issuer.certificate.extensions.get_extension_for_class(
    493                     x509.SubjectKeyIdentifier).value),
    494             critical=False
    495         )
    496         if rfc82name:
    497             cert.add_extension(
    498                 x509.SubjectAlternativeName([x509.RFC822Name(rfc82name)]),
    499                 critical=True,
    500             )
    501         cert.add_extension(
    502             x509.ExtendedKeyUsage([
    503                 ExtendedKeyUsageOID.CLIENT_AUTH,
    504             ]),
    505             critical=True
    506         )
    507         return cert
    508 
    509     @staticmethod
    510     def _make_ca_credentials(name, key_type: Any,
    511                              issuer: Optional[Credentials] = None,
    512                              valid_from: timedelta = timedelta(days=-1),
    513                              valid_to: timedelta = timedelta(days=89),
    514                              ) -> Credentials:
    515         pkey = _private_key(key_type=key_type)
    516         if issuer is not None:
    517             issuer_subject = issuer.certificate.subject
    518             issuer_key = issuer.private_key
    519         else:
    520             issuer_subject = None
    521             issuer_key = pkey
    522         subject = TestCA._make_x509_name(org_name=name, parent=issuer.subject if issuer else None)
    523         csr = TestCA._make_csr(subject=subject,
    524                                issuer_subject=issuer_subject, pkey=pkey,
    525                                valid_from_delta=valid_from, valid_until_delta=valid_to)
    526         csr = TestCA._add_ca_usages(csr)
    527         cert = csr.sign(private_key=issuer_key,
    528                         algorithm=hashes.SHA256(),
    529                         backend=default_backend())
    530         return Credentials(name=name, cert=cert, pkey=pkey, issuer=issuer)
    531 
    532     @staticmethod
    533     def _make_server_credentials(name: str, domains: List[str], issuer: Credentials,
    534                                  key_type: Any,
    535                                  valid_from: timedelta = timedelta(days=-1),
    536                                  valid_to: timedelta = timedelta(days=89),
    537                                  ) -> Credentials:
    538         pkey = _private_key(key_type=key_type)
    539         subject = TestCA._make_x509_name(common_name=name, parent=issuer.subject)
    540         csr = TestCA._make_csr(subject=subject,
    541                                issuer_subject=issuer.certificate.subject, pkey=pkey,
    542                                valid_from_delta=valid_from, valid_until_delta=valid_to)
    543         csr = TestCA._add_leaf_usages(csr, domains=domains, issuer=issuer)
    544         cert = csr.sign(private_key=issuer.private_key,
    545                         algorithm=hashes.SHA256(),
    546                         backend=default_backend())
    547         return Credentials(name=name, cert=cert, pkey=pkey, issuer=issuer)
    548 
    549     @staticmethod
    550     def _make_client_credentials(name: str,
    551                                  issuer: Credentials, email: Optional[str],
    552                                  key_type: Any,
    553                                  valid_from: timedelta = timedelta(days=-1),
    554                                  valid_to: timedelta = timedelta(days=89),
    555                                  ) -> Credentials:
    556         pkey = _private_key(key_type=key_type)
    557         subject = TestCA._make_x509_name(common_name=name, parent=issuer.subject)
    558         csr = TestCA._make_csr(subject=subject,
    559                                issuer_subject=issuer.certificate.subject, pkey=pkey,
    560                                valid_from_delta=valid_from, valid_until_delta=valid_to)
    561         csr = TestCA._add_client_usages(csr, issuer=issuer, rfc82name=email)
    562         cert = csr.sign(private_key=issuer.private_key,
    563                         algorithm=hashes.SHA256(),
    564                         backend=default_backend())
    565         return Credentials(name=name, cert=cert, pkey=pkey, issuer=issuer)