#!/usr/bin/env python

"""Unit tests for M2Crypto.X509.

Contributed by Toby Allsopp <toby@MI6.GEN.NZ> under M2Crypto's license.

Portions created by Open Source Applications Foundation (OSAF) are
Copyright (C) 2004-2005 OSAF. All Rights Reserved.
Author: Heikki Toivonen
"""

import base64
import logging
import os
import platform
import textwrap
import time
import warnings

from M2Crypto import ASN1, BIO, EVP, RSA, Rand, X509, m2  # noqa
from M2Crypto.util import expectedFailureIf
from tests import unittest

# Module-level logger; test_mkcert emits its debug traces through it.
log = logging.getLogger(__name__)


class X509TestCase(unittest.TestCase):
    """Tests for certificates, requests, names and extensions in M2Crypto.X509."""

    def callback(self, *args):
        """No-op progress callback handed to ``RSA.gen_key``."""
        pass

    def setUp(self):
        # SHA-256 fingerprint of tests/x509.pem (and of tests/x509.der).
        self.expected_hash = (
            "1A041EA7A3E77809868B8620B89A246DCAE229A5FC830CF5C26BB479F4CC1D8A"
        )

    def mkreq(self, bits, ca=0):
        """Create and sign a certificate request with a freshly generated RSA key.

        The request subject is ``C=UK, CN=OpenSSL Group``.

        :param bits: RSA modulus size in bits (public exponent is 65537).
        :param ca: when falsy, subjectAltName and nsComment extensions are
            added to the request; when truthy, no extensions are added.
        :return: ``(request, pkey)``; *pkey* is the ``EVP.PKey`` whose key
            signed the request.
        """
        pk = EVP.PKey()
        x = X509.Request()
        rsa = RSA.gen_key(bits, 65537, self.callback)
        pk.assign_rsa(rsa)
        rsa = None  # should not be freed here
        x.set_pubkey(pk)
        name = x.get_subject()
        name.C = "UK"
        name.CN = "OpenSSL Group"
        if not ca:
            ext1 = X509.new_extension("subjectAltName", "DNS:foobar.example.com")
            ext2 = X509.new_extension("nsComment", "Hello there")
            extstack = X509.X509_Extension_Stack()
            extstack.push(ext1)
            extstack.push(ext2)
            x.add_extensions(extstack)

        # An unknown digest name must be rejected.
        with self.assertRaises(ValueError):
            x.sign(pk, "sha513")

        x.sign(pk, "sha256")
        self.assertTrue(x.verify(pk))
        pk2 = x.get_pubkey()
        self.assertTrue(x.verify(pk2))
        return x, pk

    def test_ext(self):
        """new_extension() value formatting and the subjectKeyIdentifier guard."""
        # "hash" is not accepted as a subjectKeyIdentifier value.
        with self.assertRaises(ValueError):
            X509.new_extension("subjectKeyIdentifier", "hash")

        ext = X509.new_extension("subjectAltName", "DNS:foobar.example.com")
        self.assertEqual(ext.get_value(), "DNS:foobar.example.com")
        self.assertEqual(ext.get_value(indent=2), "  DNS:foobar.example.com")
        self.assertEqual(
            ext.get_value(flag=m2.X509V3_EXT_PARSE_UNKNOWN),
            "DNS:foobar.example.com",
        )

    def test_ext_error(self):
        """An unknown extension name raises X509Error."""
        with self.assertRaises(X509.X509Error):
            X509.new_extension("nonsensicalName", "blabla")

    def test_extstack(self):
        """Push, index, iterate and pop on X509_Extension_Stack."""
        # new
        ext1 = X509.new_extension("subjectAltName", "DNS:foobar.example.com")
        ext2 = X509.new_extension("nsComment", "Hello there")
        extstack = X509.X509_Extension_Stack()

        # push
        extstack.push(ext1)
        extstack.push(ext2)
        self.assertEqual(extstack[1].get_name(), "nsComment")
        self.assertEqual(len(extstack), 2)

        # iterator
        count = 0
        for e in extstack:
            count += 1
            self.assertGreater(len(e.get_name()), 0)
        self.assertEqual(count, 2)

        # pop (removes the most recently pushed extension)
        ext3 = extstack.pop()
        self.assertEqual(len(extstack), 1)
        self.assertEqual(extstack[0].get_name(), "subjectAltName")
        extstack.push(ext3)
        self.assertEqual(len(extstack), 2)
        self.assertEqual(extstack[1].get_name(), "nsComment")

        # Popping an empty stack returns None.
        self.assertIsNotNone(extstack.pop())
        self.assertIsNotNone(extstack.pop())
        self.assertIsNone(extstack.pop())

    def test_x509_name(self):
        """Attribute access, entries, text rendering and hashing of X509_Name."""
        n = X509.X509_Name()
        # It seems this actually needs to be a real 2 letter country code
        n.C = "US"
        self.assertEqual(n.C, "US")
        n.SP = "State or Province"
        self.assertEqual(n.SP, "State or Province")
        n.L = "locality name"
        self.assertEqual(n.L, "locality name")
        # Yes, 'orhanization' is a typo, I know it and you're smart.
        # However, fixing this typo would break later hashes.
        # I don't think it is worthy of troubles.
        n.O = "orhanization name"
        self.assertEqual(n.O, "orhanization name")
        n.OU = "org unit"
        self.assertEqual(n.OU, "org unit")
        n.CN = "common name"
        self.assertEqual(n.CN, "common name")
        n.Email = "bob@example.com"
        self.assertEqual(n.Email, "bob@example.com")
        n.serialNumber = "1234"
        self.assertEqual(n.serialNumber, "1234")
        n.SN = "surname"
        self.assertEqual(n.SN, "surname")
        n.GN = "given name"
        self.assertEqual(n.GN, "given name")
        self.assertEqual(
            n.as_text(),
            "C=US, ST=State or Province, "
            + "L=locality name, O=orhanization name, "
            + "OU=org unit, CN=common "
            + "name/emailAddress=bob@example.com"
            + "/serialNumber=1234, "
            + "SN=surname, GN=given name",
        )
        self.assertEqual(
            len(n),
            10,
            "X509_Name has inappropriate length %d " % len(n),
        )
        n.givenName = "name given"
        self.assertEqual(n.GN, "given name")  # Just gets the first
        self.assertEqual(
            n.as_text(),
            "C=US, ST=State or Province, "
            + "L=locality name, O=orhanization name, "
            + "OU=org unit, "
            + "CN=common name/emailAddress=bob@example.com"
            + "/serialNumber=1234, "
            + "SN=surname, GN=given name, GN=name given",
        )
        self.assertEqual(
            len(n),
            11,
            "After adding one more attribute X509_Name should "
            + "have 11 and not %d attributes." % len(n),
        )
        n.add_entry_by_txt(
            field="CN",
            type=ASN1.MBSTRING_ASC,
            entry="Proxy",
            len=-1,
            loc=-1,
            set=0,
        )
        self.assertEqual(
            len(n),
            12,
            "After adding one more attribute X509_Name should "
            + "have 12 and not %d attributes." % len(n),
        )
        self.assertEqual(n.entry_count(), 12, n.entry_count())
        self.assertEqual(
            n.as_text(),
            "C=US, ST=State or Province, "
            + "L=locality name, O=orhanization name, "
            + "OU=org unit, "
            + "CN=common name/emailAddress=bob@example.com"
            + "/serialNumber=1234, "
            + "SN=surname, GN=given name, GN=name given, "
            + "CN=Proxy",
        )

        with self.assertRaises(AttributeError):
            n.__getattr__("foobar")
        n.foobar = 1
        self.assertEqual(n.foobar, 1)

        # X509_Name_Entry tests
        count = 0
        for entry in n:
            self.assertIsInstance(entry, X509.X509_Name_Entry)
            self.assertIsInstance(entry.get_object(), ASN1.ASN1_Object)
            self.assertIsInstance(entry.get_data(), ASN1.ASN1_String)
            count += 1
        self.assertEqual(count, 12, count)

        count = 0
        for cn in n.get_entries_by_nid(m2.NID_commonName):
            self.assertIsInstance(cn, X509.X509_Name_Entry)
            self.assertIsInstance(cn.get_object(), ASN1.ASN1_Object)
            data = cn.get_data()
            self.assertIsInstance(data, ASN1.ASN1_String)
            t = data.as_text()
            self.assertIn(
                t,
                (
                    "common name",
                    "Proxy",
                ),
            )
            count += 1
        self.assertEqual(
            count,
            2,
            "X509_Name has %d commonName entries instead " "of expected 2" % count,
        )

        # The target list is not deleted when the loop is finished
        # https://docs.python.org/2.7/reference\
        #        /compound_stmts.html#the-for-statement
        # so this checks what are the attributes of the last value of
        # ``cn`` variable.
        cn.set_data(b"Hello There!")
        self.assertEqual(cn.get_data().as_text(), "Hello There!")

        # OpenSSL 1.0.1h switched from encoding strings as PRINTABLESTRING (the
        # first hash value) to UTF8STRING (the second one).
        # NOTE(review): the origin of the third accepted value is not
        # documented here -- presumably another OpenSSL version; TODO confirm.
        self.assertIn(
            n.as_hash(),
            (1697185131, 1370641112, 333998119),
            "Unexpected value of the X509_Name hash %s" % n.as_hash(),
        )

        self.assertRaises(IndexError, lambda: n[100])
        self.assertIsNotNone(n[10])

    def test_mkreq(self):
        """Round-trip a request through PEM and DER files and compare."""
        (req, _) = self.mkreq(1024)
        pem_path = "tests/tmp_request.pem"
        der_path = "tests/tmp_request.der"
        try:
            req.save_pem(pem_path)
            req2 = X509.load_request(pem_path)
            os.remove(pem_path)
            req.save(pem_path)
            req3 = X509.load_request(pem_path)
            req.save(der_path, format=X509.FORMAT_DER)
            req4 = X509.load_request(der_path, format=X509.FORMAT_DER)
        finally:
            # Remove the temporary files even when saving/loading fails.
            for path in (pem_path, der_path):
                if os.path.exists(path):
                    os.remove(path)

        def loaded_text(loaded):
            """Return *loaded*'s text, normalised for comparison with *req*."""
            text = loaded.as_text()
            if m2.OPENSSL_VERSION_NUMBER >= 0x30000000:
                # On OpenSSL >= 3 the loaded request prints "Public-Key" where
                # the in-memory one prints "RSA Public-Key"; normalise it.
                text = text.replace(
                    " Public-Key: (1024 bit)",
                    " RSA Public-Key: (1024 bit)",
                )
            return text

        # Every reloaded copy (PEM via save_pem, PEM via save, DER) must match.
        for loaded in (req2, req3, req4):
            self.assertEqual(req.as_pem(), loaded.as_pem())
            self.assertEqual(req.as_text(), loaded_text(loaded))
            self.assertEqual(req.as_der(), loaded.as_der())
        self.assertEqual(req.get_version(), 0)

        if m2.OPENSSL_VERSION_NUMBER < 0x30400000:
            req.set_version(1)
            self.assertEqual(req.get_version(), 1)
        req.set_version(0)
        self.assertEqual(req.get_version(), 0)

    @unittest.skipIf(platform.system() == "Windows", "Skip on Windows. TODO")
    def test_mkcert(self):
        """Build, sign and inspect an end-entity certificate (UTC and generic time)."""
        for utc in (True, False):
            req, pk = self.mkreq(1024)
            pkey = req.get_pubkey()
            self.assertTrue(req.verify(pkey))
            sub = req.get_subject()
            self.assertEqual(
                len(sub),
                2,
                "Subject should be long 2 items not %d" % len(sub),
            )

            cert = X509.X509()
            cert.set_serial_number(1)
            cert.set_version(2)
            cert.set_subject(sub)
            # NOTE(review): time.timezone (seconds west of UTC) is added to
            # the epoch timestamp; the intent is not stated -- TODO confirm.
            t = int(time.time()) + time.timezone
            log.debug(
                "t = %s",
                time.strftime("%a, %d %b %Y %H:%M:%S %z", time.localtime(t)),
            )
            if utc:
                now = ASN1.ASN1_UTCTIME()
            else:
                now = ASN1.ASN1_TIME()
            now.set_time(t)
            log.debug("now = %s", now)
            now_plus_year = ASN1.ASN1_TIME()
            now_plus_year.set_time(t + 60 * 60 * 24 * 365)
            log.debug("now_plus_year = %s", now_plus_year)
            cert.set_not_before(now)
            cert.set_not_after(now_plus_year)
            log.debug("cert = %s", cert.get_not_before())
            self.assertEqual(str(cert.get_not_before()), str(now))
            self.assertEqual(str(cert.get_not_after()), str(now_plus_year))

            issuer = X509.X509_Name()
            issuer.CN = "The Issuer Monkey"
            issuer.O = "The Organization Otherwise Known as My CA, Inc."
            cert.set_issuer(issuer)
            cert.set_pubkey(pkey)
            cert.set_pubkey(cert.get_pubkey())  # Make sure get/set work

            ext = X509.new_extension("subjectAltName", "DNS:foobar.example.com")
            ext.set_critical(0)
            self.assertEqual(ext.get_critical(), 0)
            cert.add_ext(ext)

            cert.sign(pk, "sha256")
            with self.assertRaises(ValueError):
                cert.sign(pk, "nosuchalgo")

            # assertEqual, not assertTrue: the second argument of assertTrue
            # is only a failure message, which made these checks vacuous.
            self.assertEqual(
                cert.get_ext("subjectAltName").get_name(),
                "subjectAltName",
            )
            self.assertEqual(cert.get_ext_at(0).get_name(), "subjectAltName")
            self.assertEqual(
                cert.get_ext_at(0).get_value(),
                "DNS:foobar.example.com",
            )
            self.assertEqual(
                cert.get_ext_count(),
                1,
                "Certificate should have now 1 extension not %d" % cert.get_ext_count(),
            )
            with self.assertRaises(IndexError):
                cert.get_ext_at(1)
            self.assertTrue(cert.verify())
            self.assertTrue(cert.verify(pkey))
            self.assertTrue(cert.verify(cert.get_pubkey()))
            self.assertEqual(cert.get_version(), 2)
            self.assertEqual(cert.get_serial_number(), 1)
            self.assertEqual(cert.get_issuer().CN, "The Issuer Monkey")

            self.assertFalse(cert.check_ca())
            self.assertFalse(cert.check_purpose(m2.X509_PURPOSE_SSL_SERVER, 1))
            self.assertFalse(cert.check_purpose(m2.X509_PURPOSE_NS_SSL_SERVER, 1))
            self.assertTrue(cert.check_purpose(m2.X509_PURPOSE_SSL_SERVER, 0))
            self.assertTrue(cert.check_purpose(m2.X509_PURPOSE_NS_SSL_SERVER, 0))
            self.assertTrue(cert.check_purpose(m2.X509_PURPOSE_ANY, 0))

    def mkcacert(self, utc):
        """Create a CA certificate (basicConstraints CA:TRUE) valid for a year.

        The issuer name (``C=UK, CN=OpenSSL Group``) matches the subject
        produced by :meth:`mkreq`, and the certificate is signed with the
        request's key.

        :param utc: use ``ASN1_UTCTIME`` for notBefore when true, otherwise
            ``ASN1_TIME``.
        :return: ``(cert, pk, pkey)`` -- the certificate, the signing
            ``EVP.PKey`` and the request's public key.
        """
        req, pk = self.mkreq(1024, ca=1)
        pkey = req.get_pubkey()
        sub = req.get_subject()
        cert = X509.X509()
        cert.set_serial_number(1)
        cert.set_version(2)
        cert.set_subject(sub)
        t = int(time.time()) + time.timezone
        if utc:
            now = ASN1.ASN1_UTCTIME()
        else:
            now = ASN1.ASN1_TIME()
        now.set_time(t)
        now_plus_year = ASN1.ASN1_TIME()
        now_plus_year.set_time(t + 60 * 60 * 24 * 365)
        cert.set_not_before(now)
        cert.set_not_after(now_plus_year)
        issuer = X509.X509_Name()
        issuer.C = "UK"
        issuer.CN = "OpenSSL Group"
        cert.set_issuer(issuer)
        cert.set_pubkey(pkey)
        ext = X509.new_extension("basicConstraints", "CA:TRUE")
        cert.add_ext(ext)
        cert.sign(pk, "sha256")

        self.assertTrue(cert.check_ca())
        self.assertTrue(cert.check_purpose(m2.X509_PURPOSE_SSL_SERVER, 1))
        self.assertTrue(cert.check_purpose(m2.X509_PURPOSE_NS_SSL_SERVER, 1))
        self.assertTrue(cert.check_purpose(m2.X509_PURPOSE_ANY, 1))
        self.assertTrue(cert.check_purpose(m2.X509_PURPOSE_SSL_SERVER, 0))
        self.assertTrue(cert.check_purpose(m2.X509_PURPOSE_NS_SSL_SERVER, 0))
        self.assertTrue(cert.check_purpose(m2.X509_PURPOSE_ANY, 0))

        return cert, pk, pkey

    def test_mkcacert(self):
        """A CA certificate verifies against its own public key."""
        for utc in (True, False):
            cacert, _, pkey = self.mkcacert(utc)
            self.assertTrue(cacert.verify(pkey))

    def test_mkproxycert(self):
        """Build CA -> end-entity -> proxy chain and inspect the proxy cert."""
        for utc in (True, False):
            cacert, pk1, _ = self.mkcacert(utc)
            end_entity_cert_req, pk2 = self.mkreq(1024)
            end_entity_cert = self.make_eecert(cacert, utc)
            end_entity_cert.set_subject(end_entity_cert_req.get_subject())
            end_entity_cert.set_pubkey(end_entity_cert_req.get_pubkey())
            end_entity_cert.sign(pk1, "sha256")
            proxycert = self.make_proxycert(end_entity_cert, utc)
            # The proxy certificate is signed with the end-entity key (pk2).
            proxycert.sign(pk2, "sha256")
            self.assertTrue(proxycert.verify(pk2))
            self.assertEqual(proxycert.get_ext_at(0).get_name(), "proxyCertInfo")
            self.assertEqual(
                proxycert.get_ext_at(0).get_value().strip(),
                "Path Length Constraint: infinite\n" + "Policy Language: Inherit all",
            )
            self.assertEqual(
                proxycert.get_ext_count(),
                1,
                proxycert.get_ext_count(),
            )
            self.assertEqual(
                proxycert.get_subject().as_text(),
                "C=UK, CN=OpenSSL Group, CN=Proxy",
            )
            self.assertEqual(
                proxycert.get_subject().as_text(indent=2, flags=m2.XN_FLAG_RFC2253),
                "  CN=Proxy,CN=OpenSSL Group,C=UK",
            )

    @staticmethod
    def make_eecert(cacert, utc):
        """Return an unsigned end-entity certificate issued by *cacert*.

        Serial 2, version 2, valid for one year. Subject and public key are
        NOT set; the caller must set them and sign.
        """
        eecert = X509.X509()
        eecert.set_serial_number(2)
        eecert.set_version(2)
        t = int(time.time()) + time.timezone
        if utc:
            now = ASN1.ASN1_UTCTIME()
        else:
            now = ASN1.ASN1_TIME()
        now.set_time(t)
        now_plus_year = ASN1.ASN1_TIME()
        now_plus_year.set_time(t + 60 * 60 * 24 * 365)
        eecert.set_not_before(now)
        eecert.set_not_after(now_plus_year)
        eecert.set_issuer(cacert.get_subject())
        return eecert

    def make_proxycert(self, eecert, utc):
        """Return an unsigned proxy certificate derived from *eecert*.

        A new 1024-bit RSA key is generated for the proxy. The subject is
        *eecert*'s subject plus ``CN=Proxy``, validity is 12 hours and a
        critical proxyCertInfo extension is added. The caller must sign it.
        """
        proxycert = X509.X509()
        pk2 = EVP.PKey()
        proxykey = RSA.gen_key(1024, 65537, self.callback)
        pk2.assign_rsa(proxykey)
        proxycert.set_pubkey(pk2)
        proxycert.set_version(2)
        if utc:
            not_before = ASN1.ASN1_UTCTIME()
            not_after = ASN1.ASN1_UTCTIME()
        else:
            not_before = ASN1.ASN1_TIME()
            not_after = ASN1.ASN1_TIME()
        not_before.set_time(int(time.time()))
        offset = 12 * 3600
        not_after.set_time(int(time.time()) + offset)
        proxycert.set_not_before(not_before)
        proxycert.set_not_after(not_after)
        proxycert.set_issuer_name(eecert.get_subject())
        proxycert.set_serial_number(12345678)
        issuer_name_string = eecert.get_subject().as_text()
        seq = issuer_name_string.split(",")

        # Rebuild the issuer's name entry by entry from its "K=V, K=V" text
        # form (assumes no value contains "," or "=").
        subject_name = X509.X509_Name()
        for entry in seq:
            parts = entry.split("=")
            subject_name.add_entry_by_txt(
                field=parts[0].strip(),
                type=ASN1.MBSTRING_ASC,
                entry=parts[1],
                len=-1,
                loc=-1,
                set=0,
            )

        subject_name.add_entry_by_txt(
            field="CN",
            type=ASN1.MBSTRING_ASC,
            entry="Proxy",
            len=-1,
            loc=-1,
            set=0,
        )

        proxycert.set_subject_name(subject_name)
        # XXX leaks 8 bytes
        pci_ext = X509.new_extension(
            "proxyCertInfo", "critical,language:Inherit all", 1
        )
        proxycert.add_ext(pci_ext)
        return proxycert

    def test_fingerprint(self):
        """SHA-256 fingerprint of the PEM fixture."""
        x509 = X509.load_cert("tests/x509.pem")
        fp = x509.get_fingerprint("sha256")
        self.assertEqual(fp, self.expected_hash)

    def test_load_der_string(self):
        """Loading from DER bytes yields the same fingerprint."""
        with open("tests/x509.der", "rb") as f:
            x509 = X509.load_cert_der_string(f.read())

        fp = x509.get_fingerprint("sha256")
        self.assertEqual(fp, self.expected_hash)

    def test_save_der_string(self):
        """as_der() of the PEM fixture equals the DER fixture."""
        x509 = X509.load_cert("tests/x509.pem")
        s = x509.as_der()
        with open("tests/x509.der", "rb") as f:
            s2 = f.read()

        self.assertEqual(s, s2)

    def test_load(self):
        """PEM and DER files load to identical certificates."""
        x509 = X509.load_cert("tests/x509.pem")
        x5092 = X509.load_cert("tests/x509.der", format=X509.FORMAT_DER)
        self.assertEqual(x509.as_text(), x5092.as_text())
        self.assertEqual(x509.as_pem(), x5092.as_pem())
        self.assertEqual(x509.as_der(), x5092.as_der())

    def test_load_bio(self):
        """Loading from BIOs (PEM and DER) and rejecting an unknown format."""
        with BIO.openfile("tests/x509.pem") as bio:
            with BIO.openfile("tests/x509.der") as bio2:
                x509 = X509.load_cert_bio(bio)
                x5092 = X509.load_cert_bio(bio2, format=X509.FORMAT_DER)

        # NOTE(review): bio2 is already closed here; this relies on the
        # format being validated before the BIO is read -- TODO confirm.
        with self.assertRaises(ValueError):
            X509.load_cert_bio(bio2, format=45678)

        self.assertEqual(x509.as_text(), x5092.as_text())
        self.assertEqual(x509.as_pem(), x5092.as_pem())
        self.assertEqual(x509.as_der(), x5092.as_der())

    def test_load_string(self):
        """Loading from a PEM str and from DER bytes gives the same cert."""
        with open("tests/x509.pem") as f:
            s = f.read()

        with open("tests/x509.der", "rb") as f2:
            s2 = f2.read()

        x509 = X509.load_cert_string(s)
        x5092 = X509.load_cert_string(s2, X509.FORMAT_DER)
        self.assertEqual(x509.as_text(), x5092.as_text())
        self.assertEqual(x509.as_pem(), x5092.as_pem())
        self.assertEqual(x509.as_der(), x5092.as_der())

    def test_load_request_bio(self):
        """All request loaders (string, DER string, BIO) round-trip."""
        (req, _) = self.mkreq(1024)

        r1 = X509.load_request_der_string(req.as_der())
        r2 = X509.load_request_string(req.as_der(), X509.FORMAT_DER)
        r3 = X509.load_request_string(req.as_pem(), X509.FORMAT_PEM)

        r4 = X509.load_request_bio(BIO.MemoryBuffer(req.as_der()), X509.FORMAT_DER)
        r5 = X509.load_request_bio(BIO.MemoryBuffer(req.as_pem()), X509.FORMAT_PEM)

        for r in [r1, r2, r3, r4, r5]:
            self.assertEqual(req.as_der(), r.as_der())

        with self.assertRaises(ValueError):
            X509.load_request_bio(BIO.MemoryBuffer(req.as_pem()), 345678)

    def test_save(self):
        """save() in PEM and DER format reproduces the fixture files."""
        x509 = X509.load_cert("tests/x509.pem")
        with open("tests/x509.pem", "r") as f:
            l_tmp = f.readlines()
        # Keep only -----BEGIN CERTIFICATE----- ... -----END CERTIFICATE-----;
        # any other lines in the fixture are dropped.
        beg_idx = l_tmp.index("-----BEGIN CERTIFICATE-----\n")
        end_idx = l_tmp.index("-----END CERTIFICATE-----\n")
        x509_pem = "".join(l_tmp[beg_idx : end_idx + 1])

        with open("tests/x509.der", "rb") as f:
            x509_der = f.read()

        pem_path = "tests/tmpcert.pem"
        der_path = "tests/tmpcert.der"
        try:
            x509.save(pem_path)
            with open(pem_path) as f:
                s = f.read()
            self.assertEqual(s, x509_pem)

            x509.save(der_path, format=X509.FORMAT_DER)
            with open(der_path, "rb") as f:
                s = f.read()
            self.assertEqual(s, x509_der)
        finally:
            # Remove the temporary files even when an assertion fails.
            for path in (pem_path, der_path):
                if os.path.exists(path):
                    os.remove(path)

    def test_malformed_data(self):
        """Loaders raise X509Error on garbage input."""
        # NOTE(review): a SystemError from any call below ends the test early
        # and it passes; why that is tolerated is not shown here -- TODO confirm.
        try:
            with self.assertRaises(X509.X509Error):
                X509.load_cert_string("Hello")
            with self.assertRaises(X509.X509Error):
                X509.load_cert_der_string("Hello")
            with self.assertRaises(X509.X509Error):
                X509.new_stack_from_der(b"Hello")
            with self.assertRaises(X509.X509Error):
                X509.load_cert("tests/__init__.py")
            with self.assertRaises(X509.X509Error):
                X509.load_request("tests/__init__.py")
            with self.assertRaises(X509.X509Error):
                X509.load_request_string("Hello")
            with self.assertRaises(X509.X509Error):
                X509.load_request_der_string("Hello")
            with self.assertRaises(X509.X509Error):
                X509.load_crl("tests/__init__.py")
        except SystemError:
            pass

    def test_long_serial(self):
        """Serial numbers wider than 64 bits are read correctly."""
        cert = X509.load_cert("tests/long_serial_cert.pem")
        self.assertEqual(cert.get_serial_number(), 17616841808974579194)

        cert = X509.load_cert("tests/thawte.pem")
        self.assertEqual(
            cert.get_serial_number(),
            127614157056681299805556476275995414779,
        )

    def test_set_long_serial(self):
        """A serial number wider than 64 bits survives set/get."""
        cert = X509.X509()
        cert.set_serial_number(127614157056681299805556476275995414779)
        self.assertEqual(
            cert.get_serial_number(),
            127614157056681299805556476275995414779,
        )

    @unittest.skipIf(platform.system() == "Windows", "Skip on Windows. TODO")
    def test_date_after_2050_working(self):
        """A notAfter date past 2050 is rendered correctly."""
        cert = X509.load_cert("tests/bad_date_cert.crt")
        self.assertEqual(str(cert.get_not_after()), "Feb  9 14:57:46 2116 GMT")

    @unittest.skipIf(platform.system() == "Windows", "Skip on Windows. TODO")
    def test_date_reference_counting(self):
        """x509_get_not_before() and x509_get_not_after() return internal
        pointers into X509. As the returned ASN1_TIME objects do not store any
        reference to the X509 itself, they become invalid when the last
        reference to X509 goes out of scope and the underlying memory is freed.

        https://todo.sr.ht/~mcepl/m2crypto/325
        """
        cert = X509.load_cert("tests/bad_date_cert.crt")
        not_before = cert.get_not_before()
        not_after = cert.get_not_after()
        del cert
        self.assertEqual(str(not_before), "Mar  4 14:57:46 2016 GMT")
        self.assertEqual(str(not_after), "Feb  9 14:57:46 2116 GMT")

    def test_easy_rsa_generated(self):
        """Test loading a cert generated by easy RSA.

        https://github.com/fedora-infra/fedmsg/pull/389
        """
        # Does this raise an exception?
        X509.load_cert("tests/easy_rsa.pem")

    def test_add_subject_key_identifier(self):
        """add_subject_key_identifier() adds a well-formed SKID extension."""
        # 1. Create a certificate (it needs a key pair and subject/issuer set up)
        req, pk = self.mkreq(1024)
        cert = X509.X509()
        cert.set_serial_number(3)
        cert.set_version(2)
        cert.set_subject(req.get_subject())
        cert.set_issuer(req.get_subject())  # Self-signed for simplicity
        cert.set_pubkey(pk)

        # 2. Add the Subject Key Identifier; success is reported as 1.
        result = cert.add_subject_key_identifier()
        self.assertEqual(result, 1, "Failed to add subjectKeyIdentifier extension")

        # 3. Verify the extension exists and its value format is correct
        skid_ext = cert.get_ext("subjectKeyIdentifier")

        # Check the name
        self.assertEqual(skid_ext.get_name(), "subjectKeyIdentifier")

        # Check that the value is non-empty and has the correct colon-separated hex format
        value = skid_ext.get_value()
        self.assertGreater(len(value), 10, "SKID value is too short or empty")
        self.assertTrue(
            all(c in "0123456789ABCDEF:" for c in value.upper()),
            "SKID value contains invalid characters",
        )
        self.assertIn(":", value, "SKID value should be colon-separated hex string")

        # 4. Sign the certificate and check the signature still verifies.
        cert.sign(pk, "sha256")
        self.assertTrue(
            cert.verify(pk), "Certificate signature failed after adding SKID"
        )


class X509StackTestCase(unittest.TestCase):
    """Tests for X509.X509_Stack and X509.new_stack_from_der()."""

    def setUp(self):
        # With OpenSSL >= 3 the "/" inside the CN value is rendered escaped
        # ("\/"); older versions print it verbatim.
        if m2.OPENSSL_VERSION_NUMBER >= 0x30000000:
            self.expected_subject = (
                "/DC=org/DC=doegrids/OU=Services/CN=host\\/bosshog.lbl.gov"
            )
        else:
            self.expected_subject = (
                "/DC=org/DC=doegrids/OU=Services/CN=host/bosshog.lbl.gov"
            )

    @staticmethod
    def _load_der_seq():
        """Return the decoded bytes of the tests/der_encoded_seq.b64 fixture."""
        with open("tests/der_encoded_seq.b64", "rb") as f:
            b64 = f.read()

        with warnings.catch_warnings():
            warnings.simplefilter("ignore", DeprecationWarning)
            return base64.decodebytes(b64)

    def test_make_stack_from_der(self):
        """A DER sequence with one certificate decodes into a one-item stack."""
        stack = X509.new_stack_from_der(self._load_der_seq())
        cert = stack.pop()
        self.assertIsNone(stack.pop())

        cert.foobar = 1
        self.assertEqual(cert.foobar, 1)

        subject = cert.get_subject()
        self.assertEqual(str(subject), self.expected_subject)

    def test_make_stack_check_num(self):
        """len() of a decoded stack tracks pops."""
        stack = X509.new_stack_from_der(self._load_der_seq())
        self.assertEqual(len(stack), 1)
        cert = stack.pop()
        self.assertEqual(len(stack), 0)
        subject = cert.get_subject()
        self.assertEqual(str(subject), self.expected_subject)

    def test_make_stack(self):
        """Push, iterate and pop certificates (pop returns the last pushed)."""
        stack = X509.X509_Stack()
        cert = X509.load_cert("tests/x509.pem")
        issuer = X509.load_cert("tests/ca.pem")
        cert_subject1 = cert.get_subject()
        issuer_subject1 = issuer.get_subject()
        stack.push(cert)
        stack.push(issuer)

        # Test stack iterator
        count = 0
        for c in stack:
            count += 1
            self.assertGreater(len(c.get_subject().CN), 0)
        self.assertEqual(count, 2)

        # Compare the popped certificates themselves; comparing `issuer`
        # with itself would make the issuer check vacuous.
        issuer_pop = stack.pop()
        cert_pop = stack.pop()
        self.assertEqual(str(cert_subject1), str(cert_pop.get_subject()))
        self.assertEqual(str(issuer_subject1), str(issuer_pop.get_subject()))

    def test_as_der(self):
        """A stack survives an as_der() / new_stack_from_der() round trip."""
        stack = X509.X509_Stack()
        cert = X509.load_cert("tests/x509.pem")
        issuer = X509.load_cert("tests/ca.pem")
        cert_subject1 = cert.get_subject()
        issuer_subject1 = issuer.get_subject()
        stack.push(cert)
        stack.push(issuer)
        der_seq = stack.as_der()
        stack2 = X509.new_stack_from_der(der_seq)
        # Check the decoded certificates, not the originals.
        issuer_pop = stack2.pop()
        cert_pop = stack2.pop()
        self.assertEqual(str(cert_subject1), str(cert_pop.get_subject()))
        self.assertEqual(str(issuer_subject1), str(issuer_pop.get_subject()))


class X509StackDegenerateTestCase(unittest.TestCase):
    """Tests for X509_Stack.create_degenerate() and save_degenerate()."""

    def setUp(self):
        # Two-certificate stack shared by the tests below.
        self.cert1 = X509.load_cert("tests/signer.pem")
        self.cert2 = X509.load_cert("tests/server.pem")
        self.stack = X509.X509_Stack()
        self.stack.push(self.cert1)
        self.stack.push(self.cert2)

    def test_create_degenerate_method(self):
        """Test X509_Stack.create_degenerate() method."""
        bio = BIO.MemoryBuffer()
        ret = self.stack.create_degenerate(bio)
        self.assertEqual(ret, 1)

        output = bio.read()
        # assertGreater reports the actual length on failure.
        self.assertGreater(len(output), 0)
        # NOTE(review): the output encoding is not checked; a former
        # assertIn(b"-----BEGIN PKCS7-----", output) was left commented out.

    def test_save_degenerate_method(self):
        """Test X509_Stack.save_degenerate() method."""
        filename = "tests/test_stack_degenerate.p7c"
        try:
            ret = self.stack.save_degenerate(filename)
            self.assertEqual(ret, 1)
            self.assertTrue(os.path.exists(filename))
        finally:
            if os.path.exists(filename):
                os.unlink(filename)

    def test_create_degenerate_empty_stack(self):
        """Test error handling for empty stack."""
        empty_stack = X509.X509_Stack()
        bio = BIO.MemoryBuffer()

        with self.assertRaises(X509.X509Error):
            empty_stack.create_degenerate(bio)

    def test_create_degenerate_invalid_bio(self):
        """Test error handling for invalid BIO parameter."""
        with self.assertRaises(X509.X509Error):
            self.stack.create_degenerate("not_a_bio")


class X509ExtTestCase(unittest.TestCase):
    """Tests for creating X.509v3 extensions."""

    def test_ext(self):
        """Wrap a raw x509v3_ext_conf() result in X509_Extension.

        The ``if 0`` toggle is kept for manual leak hunting: flipping it to 1
        selects the extension that, per the notes below, leaks 8 bytes.
        """
        name: str
        value: str
        if 0:  # XXX: flip to 1 to reproduce the leak
            # With this leaks 8 bytes:
            name = "proxyCertInfo"
            value = "critical,language:Inherit all"
        else:
            # With this there are no leaks:
            name = "nsComment"
            value = "Hello"

        ctx = m2.x509v3_set_nconf()
        x509_ext_ptr = m2.x509v3_ext_conf(None, ctx, name, value)
        # Fail clearly here instead of wrapping a NULL pointer.
        self.assertIsNotNone(x509_ext_ptr)
        ext = X509.X509_Extension(x509_ext_ptr, 1)
        self.assertEqual(ext.get_name(), name)

    def test_multiple_extensions(self):
        # Testing for https://todo.sr.ht/~mcepl/m2crypto/9
        # Test creating multiple X509 extensions

        # These are extensions that should work without special context
        ext1 = X509.new_extension("basicConstraints", "CA:TRUE")
        self.assertIsNotNone(ext1)

        ext2 = X509.new_extension("subjectAltName", "DNS:example.com")
        self.assertIsNotNone(ext2)

        ext3 = X509.new_extension("keyUsage", "digitalSignature,keyEncipherment")
        self.assertIsNotNone(ext3)

        # Test subjectKeyIdentifier with explicit hex value (not "hash")
        sub_key_id = "1CE6F0585832BC7BBA8EE0231BFF1799B04DCF64"
        ext4 = X509.new_extension("subjectKeyIdentifier", sub_key_id)
        self.assertIsNotNone(ext4)

        # Test with colons in the hex string (should also work)
        sub_key_id_with_colons = "1C:E6:F0:58:58:32:BC:7B:BA:8E:E0:23:1B:FF:17:99:B0:4D:CF:64"
        ext5 = X509.new_extension("subjectKeyIdentifier", sub_key_id_with_colons)
        self.assertIsNotNone(ext5)

        # Verify the extensions have expected values
        self.assertEqual(ext1.get_name(), "basicConstraints")
        self.assertEqual(ext2.get_name(), "subjectAltName")
        self.assertEqual(ext3.get_name(), "keyUsage")
        self.assertEqual(ext4.get_name(), "subjectKeyIdentifier")
        self.assertEqual(ext5.get_name(), "subjectKeyIdentifier")

    @unittest.skip("requires functionality we don't support yet")
    def test_multiple_extensions_older_version(self):
        """Re-create SKI/AKI extensions matching those of an existing CA cert."""
        # Testing for https://todo.sr.ht/~mcepl/m2crypto/9
        sub_key_id = "1C:E6:F0:58:58:32:BC:7B:BA:8E:E0:23:1B:FF:17:99:B0:4D:CF:64"
        auth_id = "1C:E6:F0:58:58:32:BC:7B:BA:8E:E0:23:1B:FF:17:99:B0:4D:CF:64"

        cert_pem_string = textwrap.dedent(
            """\
        -----BEGIN CERTIFICATE-----
        MIIGFjCCA/6gAwIBAgIJAO7rHaO9YDQDMA0GCSqGSIb3DQEBCwUAMHsxCzAJBgNV
        BAYTAlVTMQswCQYDVQQIDAJDQTESMBAGA1UEBwwJTG9zIEdhdG9zMRMwEQYDVQQK
        DApOYUplRGEgTExDMSMwIQYDVQQLDBpOYUplRGEgR2VvY2FjaGluZyBTZXJ2aWNl
        czERMA8GA1UEAwwIbG9jYWxfY2EwHhcNMTgwNTMxMTgwOTMwWhcNMzcxMjMxMTgw
        OTMwWjB7MQswCQYDVQQGEwJVUzELMAkGA1UECAwCQ0ExEjAQBgNVBAcMCUxvcyBH
        YXRvczETMBEGA1UECgwKTmFKZURhIExMQzEjMCEGA1UECwwaTmFKZURhIEdlb2Nh
        Y2hpbmcgU2VydmljZXMxETAPBgNVBAMMCGxvY2FsX2NhMIICIjANBgkqhkiG9w0B
        AQEFAAOCAg8AMIICCgKCAgEAwL4VBqghrv9DdUq+63Yty/kaNINIO+ldhY8GxrZd
        KXdJqanZN0nMZaW4lys7OTGKml2TzL1JiOueChky5H+8vbmXF8Mp2j3DIQRlYQae
        m/cijW4Q8QRiUNdsIcB6pB7Oa7JvyTxMsbwQC2MlE9ItNR2zJ1RMpzGvRoO+wheZ
        8zPtXquo+/rJfzoxd2G6/L9Rrwo3Izgwb6NiXbQadg675o/0shmhD1LJT5DcmvjL
        0shj+VasUKAOwgt/5GtkjjeE53VsExOkJKrH/RUodl6dXiXBq4ehtwdFGQnTDpyg
        bKSXI8M3FPU/zYjt4HYDW5R+VlkKYEdMaOwQ4b9waArCZPSh6SoSZ8SyjRN1wKqn
        da+vl5MrzPGTbnN8CXvs89+ti+iT+pdsCc/L3kdwqaV3HNE0pjTg3bChWJ/iNltF
        E3lqTmnUcMfvhHpAcj6txB4YqzvQgh7DQ4KnKwFgXvrS7t2fdgFxVe/Bl7ApXi1t
        Eg8AEjasuLqb/sTWLyvoWog1iJg7uWsv7F3DXloc7q80eAh610KtPTSzcCvOfd9I
        i6+P9yxcOZv4vCbX7rrCJt/scL2/Hz1qqQYA/DcKtvitA2hUAx6AJlMvW14Dw3tc
        nHny5lJ4Ty7ZiGN7Bg9Jj1uJusa0le2dkFwx5WjXZu2QMtjgJT8aBl3iM36jNtY7
        v6cCAwEAAaOBnDCBmTAdBgNVHQ4EFgQUHObwWFgyvHu6juAjG/8XmbBNz2QwHwYD
        VR0jBBgwFoAUHObwWFgyvHu6juAjG/8XmbBNz2QwDwYDVR0TAQH/BAUwAwEB/zAO
        BgNVHQ8BAf8EBAMCAQYwNgYDVR0fBC8wLTAroCmgJ4YlaHR0cHM6Ly9jcmwuc29t
        ZS51cmwuY29tL2xvY2FsX2NhLmNybDANBgkqhkiG9w0BAQsFAAOCAgEAepM/VwzI
        N3aWc08IgF0+J3wYAjDzq2y/ixDXwL/B/XOHElySaiDakiT6HM52Ek/LkFK67Llp
        TZIxCwViBxkkcTBS10ymGfsYY5R7lOx14SUIXPOS/Pvht1IZBuSp5J9woZjEZitk
        InmWYSmA2Q85JtFs86pNQD9gCOCd5hnKK2LqOwrPAcnOJ06FhZFT/psI5MR8XFjD
        /dJUfnkxbK6S77sCslALdsaYdWp6B4gnmZWF3tTxq1IkNKVJuGdcPLg33zFAXzmo
        POzjrTmr+1DUEahBbY/9oGcQQh0Ir9lTdd0Uym40FN/7jDA8G1CeK8lsL+TZ5dTU
        BPCI2LLd2p7c3SddMNM/GUZdoJ4LXKx3JnDu8lYpjOcL21QHjqfNSmAHzX8skery
        jawwi0yijJYwsyrB629ek0p/v16uTojs6JGddOmnz9z1/pBctRw6w83d0jNQc2yY
        g89xOl9q0Z7G6rThyNuJwworN5FaPJB2Pl7pHf2uJZEp0mq1SN3Lcfre2yXig8Tc
        rWjDY8k9VrEJG0G3n9FVv9hKvob9ngUMkmyxE5E4VWyab5gVt2m0XXJjz5Sc3530
        dQ9SXhFS7s3060/yl0BBWnTtfu9zGdKaz4lWo25Q0r7HD5y/MwUCbqpRVqXJxGHY
        d3PEYaXkdwhAi3EbarF7R8r3hKzYCpXxfI4=
        -----END CERTIFICATE-----"""
        )
        m2_x509_cert = X509.load_cert_string(cert_pem_string)

        # The certificate's own SKI/AKI, which the extensions built below
        # are meant to reproduce.
        local_ski = m2_x509_cert.get_ext("subjectKeyIdentifier")
        local_aki = m2_x509_cert.get_ext("authorityKeyIdentifier")
        self.assertEqual(local_ski.get_value(), sub_key_id)
        self.assertEqual(local_aki.get_name(), "authorityKeyIdentifier")

        X509.new_extension("subjectKeyIdentifier", sub_key_id)
        X509.new_extension("authorityKeyIdentifier", "keyid:" + auth_id)


class X509_StoreContextTestCase(unittest.TestCase):
    """Tests for certificate verification via X509_Store_Context.

    Success is checked with ``assertEqual(..., 1)`` rather than
    ``assertTrue``.  Assuming verify_cert() passes through OpenSSL's
    X509_verify_cert() result, a negative return signals an internal error,
    and a negative value is truthy, so ``assertTrue`` would let it pass.
    """

    @staticmethod
    def _make_store(ca, crl=None, crl_check=False):
        """Return a new X509_Store that trusts *ca*.

        :param ca: CA certificate added with ``add_x509``.
        :param crl: optional CRL added to the store with ``add_crl``.
        :param crl_check: when true, set X509_V_FLAG_CRL_CHECK |
            X509_V_FLAG_CRL_CHECK_ALL on the store.
        :return: the configured store.

        NOTE(review): X509_Store_Context.init() does not appear to keep a
        Python reference to the store, so callers keep the returned store in
        a local for as long as the context is used.
        """
        store = X509.X509_Store()
        store.add_x509(ca)
        if crl is not None:
            store.add_crl(crl)
        if crl_check:
            store.set_flags(
                X509.m2.X509_V_FLAG_CRL_CHECK | X509.m2.X509_V_FLAG_CRL_CHECK_ALL
            )
        return store

    def test_verify_cert(self):
        cert = X509.load_cert("tests/x509.pem")

        # Test with the CA that signed tests/x509.pem
        ca = X509.load_cert("tests/ca.pem")
        store = self._make_store(ca)
        store_ctx = X509.X509_Store_Context()
        store_ctx.init(store, cert)
        self.assertEqual(store_ctx.verify_cert(), 1)

        # Test with the wrong CA, this CA did not sign tests/x509.pem
        wrong_ca = X509.load_cert("tests/crl_data/certs/revoking_ca.pem")
        store = self._make_store(wrong_ca)
        store_ctx = X509.X509_Store_Context()
        store_ctx.init(store, cert)
        self.assertFalse(store_ctx.verify_cert())

    def test_verify_with_add_crl(self):
        """CRL installed on the store via X509_Store.add_crl()."""
        ca = X509.load_cert("tests/crl_data/certs/revoking_ca.pem")
        valid_cert = X509.load_cert("tests/crl_data/certs/valid_cert.pem")
        revoked_cert = X509.load_cert("tests/crl_data/certs/revoked_cert.pem")
        crl = X509.load_crl("tests/crl_data/certs/revoking_crl.pem")

        # Verify that a good cert is verified OK
        store = self._make_store(ca, crl=crl, crl_check=True)
        store_ctx = X509.X509_Store_Context()
        store_ctx.init(store, valid_cert)
        self.assertEqual(store_ctx.verify_cert(), 1)

        # Verify that a revoked cert is not verified
        store = self._make_store(ca, crl=crl, crl_check=True)
        store_ctx = X509.X509_Store_Context()
        store_ctx.init(store, revoked_cert)
        self.assertFalse(store_ctx.verify_cert())

    def test_verify_with_add_crls(self):
        """CRL supplied to the context via X509_Store_Context.add_crls()."""
        ca = X509.load_cert("tests/crl_data/certs/revoking_ca.pem")
        valid_cert = X509.load_cert("tests/crl_data/certs/valid_cert.pem")
        revoked_cert = X509.load_cert("tests/crl_data/certs/revoked_cert.pem")
        crl = X509.load_crl("tests/crl_data/certs/revoking_crl.pem")

        # Verify that a good cert is verified OK
        store = self._make_store(ca, crl_check=True)
        crl_stack = X509.CRL_Stack()
        crl_stack.push(crl)
        store_ctx = X509.X509_Store_Context()
        store_ctx.init(store, valid_cert)
        store_ctx.add_crls(crl_stack)
        self.assertEqual(store_ctx.verify_cert(), 1)

        # Verify that a revoked cert is not verified
        store = self._make_store(ca, crl_check=True)
        crl_stack = X509.CRL_Stack()
        crl_stack.push(crl)
        store_ctx = X509.X509_Store_Context()
        store_ctx.init(store, revoked_cert)
        store_ctx.add_crls(crl_stack)
        self.assertFalse(store_ctx.verify_cert())


class CRL_StackTestCase(unittest.TestCase):
    """Tests for the CRL_Stack container."""

    def test_new(self):
        """A freshly created stack exists and holds no CRLs."""
        stack = X509.CRL_Stack()
        self.assertIsNotNone(stack)
        self.assertEqual(len(stack), 0)

    def test_push_and_pop(self):
        """pop() hands back the pushed CRLs, most recently pushed first."""
        stack = X509.CRL_Stack()
        first, second = X509.CRL(), X509.CRL()
        self.assertNotEqual(first, second)

        for crl in (first, second):
            stack.push(crl)
        self.assertEqual(len(stack), 2)

        # Each pop must return the expected CRL and shrink the stack by one.
        for expected, remaining in ((second, 1), (first, 0)):
            popped = stack.pop()
            self.assertEqual(expected, popped)
            self.assertEqual(len(stack), remaining)


class CRLTestCase(unittest.TestCase):
    def test_new(self):
        crl = X509.CRL()
        self.assertEqual(crl.as_text()[:34], "Certificate Revocation List (CRL):")

    def test_verify(self):
        ca = X509.load_cert("tests/crl_data/certs/revoking_ca.pem")
        crl = X509.load_crl("tests/crl_data/certs/revoking_crl.pem")
        self.assertTrue(crl.verify(ca.get_pubkey()))

        wrong_ca = X509.load_cert("tests/ca.pem")
        self.assertFalse(crl.verify(wrong_ca.get_pubkey()))

    def test_get_issuer(self):
        ca = X509.load_cert("tests/crl_data/certs/revoking_ca.pem")
        crl = X509.load_crl("tests/crl_data/certs/revoking_crl.pem")
        ca_issuer = ca.get_issuer()
        crl_issuer = crl.get_issuer()
        self.assertEqual(ca_issuer.as_hash(), crl_issuer.as_hash())

        wrong_ca = X509.load_cert("tests/ca.pem")
        wrong_ca_issuer = wrong_ca.get_issuer()
        self.assertNotEqual(wrong_ca_issuer.as_hash(), crl_issuer.as_hash())

    def test_load_crl(self):
        crl = X509.load_crl("tests/crl_data/certs/revoking_crl.pem")
        self.assertIsNotNone(crl)
        self.assertIsInstance(crl, X509.CRL)

    def test_load_crl_string(self):
        f = open("tests/crl_data/certs/revoking_crl.pem")
        data = f.read()
        f.close()
        crl = X509.load_crl_string(data)
        self.assertIsInstance(crl, X509.CRL)

        ca = X509.load_cert("tests/crl_data/certs/revoking_ca.pem")
        ca_issuer = ca.get_issuer()
        crl_issuer = crl.get_issuer()
        self.assertEqual(ca_issuer.as_hash(), crl_issuer.as_hash())

    def test_get_last_updated(self):
        crl = X509.load_crl("tests/crl_data/certs/revoking_crl.pem")
        last_update_dt = crl.get_lastUpdate().get_datetime()
        # Format the datetime object to the expected string format "Nov 26 10:50:25 2025 GMT"
        # The %Z directive does not output "GMT" directly for UTC, so we append it.
        expected_lastUpdate = last_update_dt.strftime("%b %d %H:%M:%S %Y UTC").replace("UTC", "GMT")
        self.assertEqual(str(crl.get_lastUpdate()), expected_lastUpdate)

    def test_get_next_update(self):
        crl = X509.load_crl("tests/crl_data/certs/revoking_crl.pem")
        next_update_dt = crl.get_nextUpdate().get_datetime()
        # Format the datetime object to the expected string format "Nov 26 10:50:25 2025 GMT"
        # The %Z directive does not output "GMT" directly for UTC, so we append it.
        expected_nextUpdate = next_update_dt.strftime("%b %d %H:%M:%S %Y UTC").replace("UTC", "GMT")
        self.assertEqual(str(crl.get_nextUpdate()), expected_nextUpdate)


def suite():
    """Return a TestSuite holding the test case classes listed below."""
    loader = unittest.TestLoader()
    test_cases = (
        X509TestCase,
        X509StackTestCase,
        X509StackDegenerateTestCase,
        X509ExtTestCase,
        X509_StoreContextTestCase,
        CRLTestCase,
        CRL_StackTestCase,
    )
    return unittest.TestSuite(
        loader.loadTestsFromTestCase(case) for case in test_cases
    )


if __name__ == "__main__":
    # Load randpool.dat via Rand before running the suite, and write it back
    # with Rand.save_file afterwards.
    pool_file = "randpool.dat"
    Rand.load_file(pool_file, -1)
    runner = unittest.TextTestRunner()
    runner.run(suite())
    Rand.save_file(pool_file)
