# Copyright 2019-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you
# may not use this file except in compliance with the License.  You
# may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied.  See the License for the specific language governing
# permissions and limitations under the License.

"""A CPython compatible SSLContext implementation wrapping PyOpenSSL's
context.

Due to limitations of the CPython asyncio.Protocol implementation for SSL, the async API does not support PyOpenSSL.
"""
from __future__ import annotations

import socket as _socket
import ssl as _stdlibssl
import sys as _sys
import time as _time
from errno import EINTR as _EINTR
from ipaddress import ip_address as _ip_address
from typing import TYPE_CHECKING, Any, Callable, Optional, TypeVar, Union

import cryptography.x509 as x509
import service_identity
from OpenSSL import SSL as _SSL
from OpenSSL import crypto as _crypto

from pymongo.errors import ConfigurationError as _ConfigurationError
from pymongo.errors import _CertificateError  # type:ignore[attr-defined]
from pymongo.ocsp_cache import _OCSPCache
from pymongo.ocsp_support import _load_trusted_ca_certs, _ocsp_callback
from pymongo.socket_checker import SocketChecker as _SocketChecker
from pymongo.socket_checker import _errno_from_exception
from pymongo.write_concern import validate_boolean

if TYPE_CHECKING:
    from ssl import VerifyMode


# Generic return type used to annotate _sslConn._call, which returns whatever
# the wrapped callable returns.
_T = TypeVar("_T")

# certifi is an optional dependency: remember whether the import succeeded so
# SSLContext._load_certifi can raise a ConfigurationError when it is missing.
try:
    import certifi

    _HAVE_CERTIFI = True
except ImportError:
    _HAVE_CERTIFI = False

# Protocol selector and option flags taken from PyOpenSSL and bound to
# module-level names for use by the rest of this file and its importers.
PROTOCOL_SSLv23 = _SSL.SSLv23_METHOD
# Always available
OP_NO_SSLv2 = _SSL.OP_NO_SSLv2
OP_NO_SSLv3 = _SSL.OP_NO_SSLv3
OP_NO_COMPRESSION = _SSL.OP_NO_COMPRESSION
# This isn't currently documented for PyOpenSSL, so look it up defensively
# and fall back to 0 (no option bits) when the attribute is absent.
OP_NO_RENEGOTIATION = getattr(_SSL, "OP_NO_RENEGOTIATION", 0)

# Always available
HAS_SNI = True
# Marker identifying this module as the PyOpenSSL-based implementation.
IS_PYOPENSSL = True

# Base Exception class
SSLError = _SSL.Error

# Maps the standard library's ssl.CERT_* verify modes to the PyOpenSSL
# verify flags passed to Context.set_verify.
# https://github.com/python/cpython/blob/v3.8.0/Modules/_ssl.c#L2995-L3002
_VERIFY_MAP = {
    _stdlibssl.CERT_NONE: _SSL.VERIFY_NONE,
    _stdlibssl.CERT_OPTIONAL: _SSL.VERIFY_PEER,
    _stdlibssl.CERT_REQUIRED: _SSL.VERIFY_PEER | _SSL.VERIFY_FAIL_IF_NO_PEER_CERT,
}

# Inverse of _VERIFY_MAP (PyOpenSSL flags -> ssl.CERT_*), used by the
# SSLContext.verify_mode getter.
_REVERSE_VERIFY_MAP = {value: key for key, value in _VERIFY_MAP.items()}


# For SNI support. According to RFC6066, section 3, IPv4 and IPv6 literals are
# not permitted for SNI hostname.
def _is_ip_address(address: Any) -> bool:
    """Return True when *address* is accepted by ``ipaddress.ip_address``.

    Any ``ValueError`` or ``UnicodeError`` raised while parsing means the
    value is not an IP literal and yields False.
    """
    try:
        _ip_address(address)
    except (ValueError, UnicodeError):
        return False
    else:
        return True


# According to the docs for socket.send it can raise
# WantX509LookupError and should be retried.
# These are the exceptions that _sslConn._call catches and retries after
# waiting on the socket.
BLOCKING_IO_ERRORS = (_SSL.WantReadError, _SSL.WantWriteError, _SSL.WantX509LookupError)
# Individual aliases for each member of BLOCKING_IO_ERRORS.
BLOCKING_IO_READ_ERROR = _SSL.WantReadError
BLOCKING_IO_WRITE_ERROR = _SSL.WantWriteError
BLOCKING_IO_LOOKUP_ERROR = _SSL.WantX509LookupError


def _ragged_eof(exc: BaseException) -> bool:
    """Tell whether *exc* carries the argument tuple PyOpenSSL uses for a
    ragged EOF on an OpenSSL.SSL.SysCallError, i.e. ``(-1, "Unexpected EOF")``.
    """
    ragged_eof_args = (-1, "Unexpected EOF")
    return exc.args == ragged_eof_args


# https://github.com/pyca/pyopenssl/issues/168
# https://github.com/pyca/pyopenssl/issues/176
# https://docs.python.org/3/library/ssl.html#notes-on-non-blocking-sockets
class _sslConn(_SSL.Connection):
    """PyOpenSSL connection whose handshake, receive and send calls are
    retried when OpenSSL reports that it needs to wait on the socket.
    """

    def __init__(
        self,
        ctx: _SSL.Context,
        sock: Optional[_socket.socket],
        suppress_ragged_eofs: bool,
    ):
        """Store the ragged-EOF policy and a socket checker, then initialise
        the underlying PyOpenSSL connection with *ctx* and *sock*.
        """
        self.suppress_ragged_eofs = suppress_ragged_eofs
        self.socket_checker = _SocketChecker()
        super().__init__(ctx, sock)

    def _call(self, call: Callable[..., _T], *args: Any, **kwargs: Any) -> _T:
        """Run ``call(*args, **kwargs)``, retrying on BLOCKING_IO_ERRORS.

        With a timeout of 0 (non-blocking) the blocking error is re-raised
        immediately. Otherwise the socket is polled via ``socket_checker`` and
        the call repeated; ``socket.timeout`` is raised once more than
        ``timeout`` seconds have passed since the first attempt, and SSLError
        is raised if the file descriptor is reported as -1.
        """
        timeout = self.gettimeout()
        # Only meaningful when a truthy timeout is configured; every read of
        # it below is guarded by the same truthiness test.
        started_at = _time.monotonic() if timeout else 0.0
        while True:
            try:
                return call(*args, **kwargs)
            except BLOCKING_IO_ERRORS as exc:
                # Do not retry if the connection is in non-blocking mode.
                if timeout == 0:
                    raise exc
                # Check for closed socket.
                if self.fileno() == -1:
                    if timeout and _time.monotonic() - started_at > timeout:
                        raise _socket.timeout("timed out") from None
                    raise SSLError("Underlying socket has been closed") from None
                # WantRead -> wait for readability only, WantWrite -> wait
                # for writability only, anything else -> wait for either.
                read_only = isinstance(exc, _SSL.WantReadError)
                write_only = not read_only and isinstance(exc, _SSL.WantWriteError)
                want_read = not write_only
                want_write = not read_only
                self.socket_checker.select(self, want_read, want_write, timeout)
                if timeout and _time.monotonic() - started_at > timeout:
                    raise _socket.timeout("timed out") from None

    def do_handshake(self, *args: Any, **kwargs: Any) -> None:
        """Perform the TLS handshake through the retrying ``_call`` wrapper."""
        return self._call(super().do_handshake, *args, **kwargs)

    def recv(self, *args: Any, **kwargs: Any) -> bytes:
        """Receive data, returning ``b""`` for a ragged EOF when
        ``suppress_ragged_eofs`` is set (matching the stdlib).
        """
        try:
            return self._call(super().recv, *args, **kwargs)
        except _SSL.SysCallError as exc:
            if not (self.suppress_ragged_eofs and _ragged_eof(exc)):
                raise
            return b""

    def recv_into(self, *args: Any, **kwargs: Any) -> int:
        """Receive into a buffer, returning 0 for a ragged EOF when
        ``suppress_ragged_eofs`` is set (matching the stdlib).
        """
        try:
            return self._call(super().recv_into, *args, **kwargs)
        except _SSL.SysCallError as exc:
            if not (self.suppress_ragged_eofs and _ragged_eof(exc)):
                raise
            return 0

    def sendall(self, buf: bytes, flags: int = 0) -> None:  # type: ignore[override]
        """Send all of *buf*, looping over ``send`` until every byte is written.

        Raises OSError("connection closed") if a send reports that no bytes
        were written.
        """
        data = memoryview(buf)
        size = len(buf)
        offset = 0
        while offset < size:
            try:
                written = self._call(super().send, data[offset:], flags)
            # XXX: It's not clear if this can actually happen. PyOpenSSL
            # doesn't appear to have any interrupt handling, nor any interrupt
            # errors for OpenSSL connections.
            except OSError as exc:
                if _errno_from_exception(exc) != _EINTR:
                    raise
                # Interrupted: retry the same slice.
                continue
            # https://github.com/pyca/pyopenssl/blob/19.1.0/src/OpenSSL/SSL.py#L1756
            # https://www.openssl.org/docs/man1.0.2/man3/SSL_write.html
            if written <= 0:
                raise OSError("connection closed")
            offset += written


class _CallbackData:
    """Holder for the state handed to the OCSP callback."""

    def __init__(self) -> None:
        # Fresh OCSP response cache for this holder.
        self.ocsp_response_cache = _OCSPCache()
        # Unset until the owning context configures them.
        self.check_ocsp_endpoint: Optional[bool] = None
        self.trusted_ca_certs: Optional[list[x509.Certificate]] = None


class SSLContext:
    """A CPython compatible SSLContext implementation wrapping PyOpenSSL's
    context.

    The wrapped ``OpenSSL.SSL.Context`` lives in ``_ctx``; ``_callback_data``
    is the object registered as the data argument of the OCSP client callback.
    """

    __slots__ = ("_protocol", "_ctx", "_callback_data", "_check_hostname")

    def __init__(self, protocol: int):
        """Create the underlying PyOpenSSL context for *protocol*.

        Hostname checking and OCSP endpoint checking both start enabled, and
        the OCSP client callback is registered on the new context.
        """
        self._protocol = protocol
        self._ctx = _SSL.Context(self._protocol)
        self._callback_data = _CallbackData()
        self._check_hostname = True
        # OCSP
        # XXX: Find a better place to do this someday, since this is client
        # side configuration and wrap_socket tries to support both client and
        # server side sockets.
        self._callback_data.check_ocsp_endpoint = True
        self._ctx.set_ocsp_client_callback(callback=_ocsp_callback, data=self._callback_data)

    @property
    def protocol(self) -> int:
        """The protocol version chosen when constructing the context.
        This attribute is read-only.
        """
        return self._protocol

    def __get_verify_mode(self) -> VerifyMode:
        """Whether to try to verify other peers' certificates and how to
        behave if verification fails. This attribute must be one of
        ssl.CERT_NONE, ssl.CERT_OPTIONAL or ssl.CERT_REQUIRED.
        """
        return _REVERSE_VERIFY_MAP[self._ctx.get_verify_mode()]

    def __set_verify_mode(self, value: VerifyMode) -> None:
        """Setter for verify_mode.

        Translates *value* through ``_VERIFY_MAP`` (a KeyError propagates for
        values that are not in the map) and installs it on the context.
        """

        def _cb(
            _connobj: _SSL.Connection,
            _x509obj: _crypto.X509,
            _errnum: int,
            _errdepth: int,
            retcode: int,
        ) -> bool:
            # It seems we don't need to do anything here. Twisted doesn't,
            # and OpenSSL's SSL_CTX_set_verify lets you pass NULL
            # for the callback option. It's weird that PyOpenSSL requires
            # this.
            # This is optional in pyopenssl >= 20 and can be removed once minimum
            # supported version is bumped
            # See: pyopenssl.org/en/latest/changelog.html#id47
            return bool(retcode)

        self._ctx.set_verify(_VERIFY_MAP[value], _cb)

    verify_mode = property(__get_verify_mode, __set_verify_mode)

    def __get_check_hostname(self) -> bool:
        """Whether wrap_socket verifies the peer certificate against the
        server hostname after the handshake.
        """
        return self._check_hostname

    def __set_check_hostname(self, value: Any) -> None:
        """Setter for check_hostname; *value* is passed through
        ``validate_boolean`` before being stored.
        """
        validate_boolean("check_hostname", value)
        self._check_hostname = value

    check_hostname = property(__get_check_hostname, __set_check_hostname)

    def __get_check_ocsp_endpoint(self) -> Optional[bool]:
        """The ``check_ocsp_endpoint`` flag stored on the OCSP callback data."""
        return self._callback_data.check_ocsp_endpoint

    def __set_check_ocsp_endpoint(self, value: bool) -> None:
        """Setter for check_ocsp_endpoint; *value* is passed through
        ``validate_boolean`` before being stored on the OCSP callback data.
        """
        validate_boolean("check_ocsp", value)
        self._callback_data.check_ocsp_endpoint = value

    check_ocsp_endpoint = property(__get_check_ocsp_endpoint, __set_check_ocsp_endpoint)

    def __get_options(self) -> int:
        """The option bitmask currently set on the underlying context."""
        # Calling set_options adds the option to the existing bitmask and
        # returns the new bitmask.
        # https://www.pyopenssl.org/en/stable/api/ssl.html#OpenSSL.SSL.Context.set_options
        return self._ctx.set_options(0)

    def __set_options(self, value: int) -> None:
        """Add the option bits in *value* to the underlying context."""
        # Explicitly convert to int, since newer CPython versions
        # use enum.IntFlag for options. The values are the same
        # regardless of implementation.
        self._ctx.set_options(int(value))

    options = property(__get_options, __set_options)

    def load_cert_chain(
        self,
        certfile: Union[str, bytes],
        keyfile: Union[str, bytes, None] = None,
        password: Union[str, bytes, bytearray, None] = None,
    ) -> None:
        """Load a private key and the corresponding certificate. The certfile
        string must be the path to a single file in PEM format containing the
        certificate as well as any number of CA certificates needed to
        establish the certificate's authenticity. The keyfile string, if
        present, must point to a file containing the private key. Otherwise
        the private key will be taken from certfile as well.

        The password, if given, is used to decrypt the private key. A ``str``
        is encoded as UTF-8; ``bytes`` and ``bytearray`` values are passed to
        OpenSSL unchanged, as CPython's ``load_cert_chain`` allows. A falsy
        password (``None`` or empty) registers no password callback.
        """
        # Match CPython behavior
        # https://github.com/python/cpython/blob/v3.8.0/Modules/_ssl.c#L3930-L3971
        # Password callback MUST be set first or it will be ignored.
        if password:
            # Bind to a fresh local so the closure captures a value that is
            # known not to be None.
            secret = password

            def _pwcb(_max_length: int, _prompt_twice: bool, _user_data: Optional[bytes]) -> bytes:
                # XXX:We could check the password length against what OpenSSL
                # tells us is the max, but we can't raise an exception, so...
                # warn?
                if isinstance(secret, (bytes, bytearray)):
                    # Already binary: calling .encode() here would fail.
                    return bytes(secret)
                return secret.encode("utf-8")

            self._ctx.set_passwd_cb(_pwcb)
        self._ctx.use_certificate_chain_file(certfile)
        self._ctx.use_privatekey_file(keyfile or certfile)
        self._ctx.check_privatekey()

    def load_verify_locations(
        self, cafile: Optional[str] = None, capath: Optional[str] = None
    ) -> None:
        """Load a set of "certification authority"(CA) certificates used to
        validate other peers' certificates when `~verify_mode` is other than
        ssl.CERT_NONE.
        """
        self._ctx.load_verify_locations(cafile, capath)
        # Manually load the CA certs when get_verified_chain is not available (pyopenssl<20).
        if not hasattr(_SSL.Connection, "get_verified_chain"):
            # On this code path a cafile is required; the loaded certificates
            # are stored on the OCSP callback data.
            assert cafile is not None
            self._callback_data.trusted_ca_certs = _load_trusted_ca_certs(cafile)

    def _load_certifi(self) -> None:
        """Attempt to load CA certs from certifi.

        Raises ConfigurationError when certifi could not be imported.
        """
        if _HAVE_CERTIFI:
            self.load_verify_locations(certifi.where())
        else:
            raise _ConfigurationError(
                "tlsAllowInvalidCertificates is False but no system "
                "CA certificates could be loaded. Please install the "
                "certifi package, or provide a path to a CA file using "
                "the tlsCAFile option"
            )

    def _load_wincerts(self, store: str) -> None:
        """Attempt to load CA certs from Windows trust store.

        Adds every "x509_asn"-encoded certificate from *store* that is either
        trusted for all purposes or lists the server-auth OID to the
        context's certificate store.
        """
        cert_store = self._ctx.get_cert_store()
        assert cert_store is not None
        oid = _stdlibssl.Purpose.SERVER_AUTH.oid

        for cert, encoding, trust in _stdlibssl.enum_certificates(store):  # type: ignore
            if encoding == "x509_asn":
                if trust is True or oid in trust:
                    cert_store.add_cert(
                        _crypto.X509.from_cryptography(x509.load_der_x509_certificate(cert))
                    )

    def load_default_certs(self) -> None:
        """A PyOpenSSL version of load_default_certs from CPython."""
        # PyOpenSSL is incapable of loading CA certs from Windows, and mostly
        # incapable on macOS.
        # https://www.pyopenssl.org/en/stable/api/ssl.html#OpenSSL.SSL.Context.set_default_verify_paths
        if _sys.platform == "win32":
            try:
                for storename in ("CA", "ROOT"):
                    self._load_wincerts(storename)
            except PermissionError:
                # Fall back to certifi
                self._load_certifi()
        elif _sys.platform == "darwin":
            self._load_certifi()
        # Always ask OpenSSL for its default verify paths as well.
        self._ctx.set_default_verify_paths()

    def set_default_verify_paths(self) -> None:
        """Specify that the platform provided CA certificates are to be used
        for verification purposes.
        """
        # Note: See PyOpenSSL's docs for limitations, which are similar
        # but not the same as CPython's.
        self._ctx.set_default_verify_paths()

    def wrap_socket(
        self,
        sock: _socket.socket,
        server_side: bool = False,
        do_handshake_on_connect: bool = True,
        suppress_ragged_eofs: bool = True,
        server_hostname: Optional[str] = None,
        session: Optional[_SSL.Session] = None,
    ) -> _sslConn:
        """Wrap an existing Python socket connection and return a TLS socket
        object.

        For client sockets the SNI hostname is set when *server_hostname* is
        not an IP literal, and a stapled OCSP response is requested unless
        ``verify_mode`` is ssl.CERT_NONE. The handshake, and the hostname/IP
        verification that follows it, only run here when
        *do_handshake_on_connect* is true. Verification failures are raised
        as ``_CertificateError``.
        """
        ssl_conn = _sslConn(self._ctx, sock, suppress_ragged_eofs)
        if session:
            ssl_conn.set_session(session)
        # Plain truthiness (rather than an identity test against True) so any
        # truthy flag selects the server role, as with CPython's wrap_socket.
        if server_side:
            ssl_conn.set_accept_state()
        else:
            # SNI
            if server_hostname and not _is_ip_address(server_hostname):
                # XXX: Do this in a callback registered with
                # SSLContext.set_info_callback? See Twisted for an example.
                ssl_conn.set_tlsext_host_name(server_hostname.encode("idna"))
            if self.verify_mode != _stdlibssl.CERT_NONE:
                # Request a stapled OCSP response.
                ssl_conn.request_ocsp()
            ssl_conn.set_connect_state()
        # If this wasn't true the caller of wrap_socket would call
        # do_handshake()
        if do_handshake_on_connect:
            # XXX: If we do hostname checking in a callback we can get rid
            # of this call to do_handshake() since the handshake
            # will happen automatically later.
            ssl_conn.do_handshake()
            # XXX: Do this in a callback registered with
            # SSLContext.set_info_callback? See Twisted for an example.
            if self.check_hostname and server_hostname is not None:
                from service_identity import pyopenssl

                try:
                    if _is_ip_address(server_hostname):
                        pyopenssl.verify_ip_address(ssl_conn, server_hostname)
                    else:
                        pyopenssl.verify_hostname(ssl_conn, server_hostname)
                except (  # type:ignore[misc]
                    service_identity.SICertificateError,
                    service_identity.SIVerificationError,
                ) as exc:
                    raise _CertificateError(str(exc)) from None
        return ssl_conn
