# Copyright 2015-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Internal network layer helper methods."""
from __future__ import annotations

import asyncio
import collections
import errno
import socket
import struct
import sys
import time
from asyncio import AbstractEventLoop, BaseTransport, BufferedProtocol, Future, Transport
from typing import (
    TYPE_CHECKING,
    Any,
    Optional,
    Union,
)

from pymongo import _csot, ssl_support
from pymongo._asyncio_task import create_task
from pymongo.common import MAX_MESSAGE_SIZE
from pymongo.compression_support import decompress
from pymongo.errors import ProtocolError, _OperationCancelled
from pymongo.message import _UNPACK_REPLY, _OpMsg, _OpReply
from pymongo.socket_checker import _errno_from_exception

try:
    from ssl import SSLError, SSLSocket

    _HAVE_SSL = True
except ImportError:
    _HAVE_SSL = False

try:
    from pymongo.pyopenssl_context import _sslConn

    _HAVE_PYOPENSSL = True
except ImportError:
    _HAVE_PYOPENSSL = False
    _sslConn = SSLSocket  # type: ignore[assignment, misc]

from pymongo.ssl_support import (
    BLOCKING_IO_LOOKUP_ERROR,
    BLOCKING_IO_READ_ERROR,
    BLOCKING_IO_WRITE_ERROR,
)

if TYPE_CHECKING:
    from pymongo.asynchronous.pool import AsyncConnection
    from pymongo.synchronous.pool import Connection

_UNPACK_HEADER = struct.Struct("<iiii").unpack
_UNPACK_COMPRESSION_HEADER = struct.Struct("<iiB").unpack
_POLL_TIMEOUT = 0.5
# Errors raised by sockets (and TLS sockets) when in non-blocking mode.
BLOCKING_IO_ERRORS = (BlockingIOError, *BLOCKING_IO_LOOKUP_ERROR, *ssl_support.BLOCKING_IO_ERRORS)


# These socket-based I/O methods are for KMS requests and any other network
# operations that do not use the MongoDB wire protocol.
async def async_socket_sendall(sock: Union[socket.socket, _sslConn], buf: bytes) -> None:
    timeout = sock.gettimeout()
    sock.settimeout(0.0)
    loop = asyncio.get_running_loop()
    try:
        if _HAVE_SSL and isinstance(sock, (SSLSocket, _sslConn)):
            await asyncio.wait_for(_async_socket_sendall_ssl(sock, buf, loop), timeout=timeout)
        else:
            await asyncio.wait_for(loop.sock_sendall(sock, buf), timeout=timeout)  # type: ignore[arg-type]
    except asyncio.TimeoutError as exc:
        # Convert the asyncio.wait_for timeout error to socket.timeout, which pool.py understands.
        raise socket.timeout("timed out") from exc
    finally:
        sock.settimeout(timeout)
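
# Illustrative usage (hypothetical names): these helpers let KMS/CSFLE code
# drive a plain or TLS socket through the running event loop, for example:
#
#   await async_socket_sendall(kms_sock, request_bytes)
#   reply = await async_receive_data_socket(kms_sock, 4096)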


if sys.platform != "win32":

    async def _async_socket_sendall_ssl(
        sock: Union[socket.socket, _sslConn], buf: bytes, loop: AbstractEventLoop
    ) -> None:
        view = memoryview(buf)
        sent = 0

        def _is_ready(fut: Future) -> None:
            if fut.done():
                return
            fut.set_result(None)

        while sent < len(buf):
            try:
                sent += sock.send(view[sent:])
            except BLOCKING_IO_ERRORS as exc:
                fd = sock.fileno()
                # Check for closed socket.
                if fd == -1:
                    raise SSLError("Underlying socket has been closed") from None
                if isinstance(exc, BLOCKING_IO_READ_ERROR):
                    fut = loop.create_future()
                    loop.add_reader(fd, _is_ready, fut)
                    try:
                        await fut
                    finally:
                        loop.remove_reader(fd)
                if isinstance(exc, BLOCKING_IO_WRITE_ERROR):
                    fut = loop.create_future()
                    loop.add_writer(fd, _is_ready, fut)
                    try:
                        await fut
                    finally:
                        loop.remove_writer(fd)
                if _HAVE_PYOPENSSL and isinstance(exc, BLOCKING_IO_LOOKUP_ERROR):
                    fut = loop.create_future()
                    loop.add_reader(fd, _is_ready, fut)
                    try:
                        loop.add_writer(fd, _is_ready, fut)
                        await fut
                    finally:
                        loop.remove_reader(fd)
                        loop.remove_writer(fd)

    async def _async_socket_receive_ssl(
        conn: _sslConn, length: int, loop: AbstractEventLoop, once: bool = False
    ) -> memoryview:
        mv = memoryview(bytearray(length))
        total_read = 0

        def _is_ready(fut: Future) -> None:
            if fut.done():
                return
            fut.set_result(None)

        while total_read < length:
            try:
                read = conn.recv_into(mv[total_read:])
                if read == 0:
                    raise OSError("connection closed")
                # KMS responses update their expected size after the first batch,
                # so stop reading after a single pass.
                if once:
                    return mv[:read]
                total_read += read
            except BLOCKING_IO_ERRORS as exc:
                fd = conn.fileno()
                # Check for closed socket.
                if fd == -1:
                    raise SSLError("Underlying socket has been closed") from None
                if isinstance(exc, BLOCKING_IO_READ_ERROR):
                    fut = loop.create_future()
                    loop.add_reader(fd, _is_ready, fut)
                    try:
                        await fut
                    finally:
                        loop.remove_reader(fd)
                if isinstance(exc, BLOCKING_IO_WRITE_ERROR):
                    fut = loop.create_future()
                    loop.add_writer(fd, _is_ready, fut)
                    try:
                        await fut
                    finally:
                        loop.remove_writer(fd)
                if _HAVE_PYOPENSSL and isinstance(exc, BLOCKING_IO_LOOKUP_ERROR):
                    fut = loop.create_future()
                    loop.add_reader(fd, _is_ready, fut)
                    try:
                        loop.add_writer(fd, _is_ready, fut)
                        await fut
                    finally:
                        loop.remove_reader(fd)
                        loop.remove_writer(fd)
        return mv

else:
    # The default Windows asyncio event loop does not support loop.add_reader/add_writer:
    # https://docs.python.org/3/library/asyncio-platforms.html#asyncio-platform-support
    # Note: In PYTHON-4493 we plan to replace this code with asyncio streams.
    async def _async_socket_sendall_ssl(
        sock: Union[socket.socket, _sslConn], buf: bytes, dummy: AbstractEventLoop
    ) -> None:
        view = memoryview(buf)
        total_length = len(buf)
        total_sent = 0
        # Backoff starts at 1ms, doubles on timeout up to 512ms, and halves on success
        # down to 1ms.
        backoff = 0.001
        while total_sent < total_length:
            try:
                sent = sock.send(view[total_sent:])
            except BLOCKING_IO_ERRORS:
                await asyncio.sleep(backoff)
                sent = 0
            if sent > 0:
                backoff = max(backoff / 2, 0.001)
            else:
                backoff = min(backoff * 2, 0.512)
            total_sent += sent

    async def _async_socket_receive_ssl(
        conn: _sslConn, length: int, dummy: AbstractEventLoop, once: bool = False
    ) -> memoryview:
        mv = memoryview(bytearray(length))
        total_read = 0
        # Backoff starts at 1ms, doubles on timeout up to 512ms, and halves on success
        # down to 1ms.
        backoff = 0.001
        while total_read < length:
            try:
                read = conn.recv_into(mv[total_read:])
                if read == 0:
                    raise OSError("connection closed")
                # KMS responses update their expected size after the first batch,
                # so stop reading after a single pass.
                if once:
                    return mv[:read]
            except BLOCKING_IO_ERRORS:
                await asyncio.sleep(backoff)
                read = 0
            if read > 0:
                backoff = max(backoff / 2, 0.001)
            else:
                backoff = min(backoff * 2, 0.512)
            total_read += read
        return mv
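
    # Illustrative backoff progression under repeated would-block errors
    # (hypothetical timings): 0.001 -> 0.002 -> 0.004 -> ... -> 0.512 (capped),
    # with each successful send()/recv_into() halving the delay back toward
    # the 0.001s floor.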


def sendall(sock: Union[socket.socket, _sslConn], buf: bytes) -> None:
    sock.sendall(buf)


async def _poll_cancellation(conn: AsyncConnection) -> None:
    while True:
        if conn.cancel_context.cancelled:
            return

        await asyncio.sleep(_POLL_TIMEOUT)


async def async_receive_data_socket(
    sock: Union[socket.socket, _sslConn], length: int
) -> memoryview:
    sock_timeout = sock.gettimeout()
    timeout = sock_timeout

    sock.settimeout(0.0)
    loop = asyncio.get_running_loop()
    try:
        if _HAVE_SSL and isinstance(sock, (SSLSocket, _sslConn)):
            return await asyncio.wait_for(
                _async_socket_receive_ssl(sock, length, loop, once=True),  # type: ignore[arg-type]
                timeout=timeout,
            )
        else:
            return await asyncio.wait_for(
                _async_socket_receive(sock, length, loop),  # type: ignore[arg-type]
                timeout=timeout,
            )
    except asyncio.TimeoutError as err:
        raise socket.timeout("timed out") from err
    finally:
        sock.settimeout(sock_timeout)


async def _async_socket_receive(
    conn: socket.socket, length: int, loop: AbstractEventLoop
) -> memoryview:
    mv = memoryview(bytearray(length))
    bytes_read = 0
    while bytes_read < length:
        chunk_length = await loop.sock_recv_into(conn, mv[bytes_read:])
        if chunk_length == 0:
            raise OSError("connection closed")
        bytes_read += chunk_length
    return mv


_PYPY = "PyPy" in sys.version


def wait_for_read(conn: Connection, deadline: Optional[float]) -> None:
    """Block until at least one byte is read, or a timeout, or a cancel."""
    sock = conn.conn.sock
    timed_out = False
    # Check if the connection's socket has been manually closed
    if sock.fileno() == -1:
        return
    while True:
        # SSLSocket can have buffered data which won't be caught by select.
        if hasattr(sock, "pending") and sock.pending() > 0:
            readable = True
        else:
            # Wait up to 500ms for the socket to become readable and then
            # check for cancellation.
            if deadline:
                remaining = deadline - time.monotonic()
                # When the timeout has expired perform one final check to
                # see if the socket is readable. This helps avoid spurious
                # timeouts on AWS Lambda and other FaaS environments.
                if remaining <= 0:
                    timed_out = True
                timeout = max(min(remaining, _POLL_TIMEOUT), 0)
            else:
                timeout = _POLL_TIMEOUT
            readable = conn.socket_checker.select(sock, read=True, timeout=timeout)
        if conn.cancel_context.cancelled:
            raise _OperationCancelled("operation cancelled")
        if readable:
            return
        if timed_out:
            raise socket.timeout("timed out")


def receive_data(conn: Connection, length: int, deadline: Optional[float]) -> memoryview:
    buf = bytearray(length)
    mv = memoryview(buf)
    bytes_read = 0
    # To support cancelling a network read, we shorten the socket timeout and
    # check for the cancellation signal after each timeout. Alternatively we
    # could close the socket but that does not reliably cancel recv() calls
    # on all OSes.
    # When the timeout has expired we perform one final non-blocking recv.
    # This helps avoid spurious timeouts when the response is actually already
    # buffered on the client.
    orig_timeout = conn.conn.gettimeout()
    try:
        while bytes_read < length:
            try:
                # Use the legacy wait_for_read cancellation approach on PyPy due to PYTHON-5011.
                if _PYPY:
                    wait_for_read(conn, deadline)
                    if _csot.get_timeout() and deadline is not None:
                        conn.set_conn_timeout(max(deadline - time.monotonic(), 0))
                else:
                    if deadline is not None:
                        short_timeout = min(max(deadline - time.monotonic(), 0), _POLL_TIMEOUT)
                    else:
                        short_timeout = _POLL_TIMEOUT
                    conn.set_conn_timeout(short_timeout)

                chunk_length = conn.conn.recv_into(mv[bytes_read:])
            except BLOCKING_IO_ERRORS:
                if conn.cancel_context.cancelled:
                    raise _OperationCancelled("operation cancelled") from None
                # We reached the true deadline.
                raise socket.timeout("timed out") from None
            except socket.timeout:
                if conn.cancel_context.cancelled:
                    raise _OperationCancelled("operation cancelled") from None
                if _PYPY:
                    # We reached the true deadline.
                    raise
                continue
            except OSError as exc:
                if conn.cancel_context.cancelled:
                    raise _OperationCancelled("operation cancelled") from None
                if _errno_from_exception(exc) == errno.EINTR:
                    continue
                raise
            if chunk_length == 0:
                raise OSError("connection closed")

            bytes_read += chunk_length
    finally:
        conn.set_conn_timeout(orig_timeout)

    return mv
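

# Illustrative call (hypothetical deadline): reading a 16-byte wire header with
# a five-second budget, as receive_message() below does:
#
#   header = receive_data(conn, 16, deadline=time.monotonic() + 5.0)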


class NetworkingInterfaceBase:
    def __init__(self, conn: Any):
        self.conn = conn

    @property
    def gettimeout(self) -> Any:
        raise NotImplementedError

    def settimeout(self, timeout: float | None) -> None:
        raise NotImplementedError

    def close(self) -> Any:
        raise NotImplementedError

    def is_closing(self) -> bool:
        raise NotImplementedError

    @property
    def get_conn(self) -> Any:
        raise NotImplementedError

    @property
    def sock(self) -> Any:
        raise NotImplementedError


class AsyncNetworkingInterface(NetworkingInterfaceBase):
    def __init__(self, conn: tuple[Transport, PyMongoProtocol]):
        super().__init__(conn)

    @property
    def gettimeout(self) -> float | None:
        return self.conn[1].gettimeout

    def settimeout(self, timeout: float | None) -> None:
        self.conn[1].settimeout(timeout)

    async def close(self) -> None:
        self.conn[1].close()
        await self.conn[1].wait_closed()

    def is_closing(self) -> bool:
        return self.conn[0].is_closing()

    @property
    def get_conn(self) -> PyMongoProtocol:
        return self.conn[1]

    @property
    def sock(self) -> socket.socket:
        return self.conn[0].get_extra_info("socket")


class NetworkingInterface(NetworkingInterfaceBase):
    def __init__(self, conn: Union[socket.socket, _sslConn]):
        super().__init__(conn)

    def gettimeout(self) -> float | None:
        return self.conn.gettimeout()

    def settimeout(self, timeout: float | None) -> None:
        self.conn.settimeout(timeout)

    def close(self) -> None:
        self.conn.close()

    def is_closing(self) -> bool:
        return self.conn.is_closing()

    @property
    def get_conn(self) -> Union[socket.socket, _sslConn]:
        return self.conn

    @property
    def sock(self) -> Union[socket.socket, _sslConn]:
        return self.conn

    def fileno(self) -> int:
        return self.conn.fileno()

    def recv_into(self, buffer: memoryview) -> int:
        return self.conn.recv_into(buffer)


class PyMongoProtocol(BufferedProtocol):
    def __init__(self, timeout: Optional[float] = None):
        self.transport: Transport = None  # type: ignore[assignment]
        # Each message is read in 2-3 parts: header, compression header, and message body.
        # The message buffer is allocated after the header is read.
        self._header = memoryview(bytearray(16))
        self._header_index = 0
        self._compression_header = memoryview(bytearray(9))
        self._compression_index = 0
        self._message: Optional[memoryview] = None
        self._message_index = 0
        # State. TODO: replace booleans with an enum?
        self._expecting_header = True
        self._expecting_compression = False
        self._message_size = 0
        self._op_code = 0
        self._connection_lost = False
        self._read_waiter: Optional[Future] = None
        self._timeout = timeout
        self._is_compressed = False
        self._compressor_id: Optional[int] = None
        self._max_message_size = MAX_MESSAGE_SIZE
        self._response_to: Optional[int] = None
        self._closed = asyncio.get_running_loop().create_future()
        self._pending_messages: collections.deque[Future] = collections.deque()
        self._done_messages: collections.deque[Future] = collections.deque()

    def settimeout(self, timeout: float | None) -> None:
        self._timeout = timeout

    @property
    def gettimeout(self) -> float | None:
        """The configured timeout for the socket that underlies our protocol pair."""
        return self._timeout

    def connection_made(self, transport: BaseTransport) -> None:
        """Called exactly once when a connection is made.
        The transport argument is the transport representing the write side of the connection.
        """
        self.transport = transport  # type: ignore[assignment]
        self.transport.set_write_buffer_limits(MAX_MESSAGE_SIZE, MAX_MESSAGE_SIZE)

    async def write(self, message: bytes) -> None:
        """Write a message to this connection's transport."""
        if self.transport.is_closing():
            raise OSError("Connection is closed")
        self.transport.write(message)
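        # Reading may have been paused after the previous reply was delivered
        # (see buffer_updated); resume it so the server's response can arrive.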
        self.transport.resume_reading()

    async def read(self, request_id: Optional[int], max_message_size: int) -> tuple[bytes, int]:
        """Read a single MongoDB Wire Protocol message from this connection."""
        if self.transport:
            try:
                self.transport.resume_reading()
            # Known bug in asyncio's SSLProtocol, fixed in Python 3.11: https://github.com/python/cpython/issues/89322
            except AttributeError:
                raise OSError("connection is already closed") from None
        self._max_message_size = max_message_size
        if self._done_messages:
            message = await self._done_messages.popleft()
        else:
            if self.transport and self.transport.is_closing():
                raise OSError("connection is already closed")
            read_waiter = asyncio.get_running_loop().create_future()
            self._pending_messages.append(read_waiter)
            try:
                message = await read_waiter
            finally:
                if read_waiter in self._done_messages:
                    self._done_messages.remove(read_waiter)
        if message:
            op_code, compressor_id, response_to, data = message
            # No request_id for exhaust cursor "getMore".
            if request_id is not None:
                if request_id != response_to:
                    raise ProtocolError(
                        f"Got response id {response_to!r} but expected {request_id!r}"
                    )
            if compressor_id is not None:
                data = decompress(data, compressor_id)
            return data, op_code
        raise OSError("connection closed")

    def get_buffer(self, sizehint: int) -> memoryview:
        """Called to allocate a new receive buffer.
        The asyncio loop calls this method expecting to receive a non-empty buffer to fill with data.
        If any data does not fit into the returned buffer, this method will be called again until
        either no data remains or an empty buffer is returned.
        """
        # Due to a bug, Python <=3.11 will call get_buffer() even after we raise
        # ProtocolError in buffer_updated() and call connection_lost(). We allocate
        # a temp buffer to drain the waiting data.
        if self._connection_lost:
            if not self._message:
                self._message = memoryview(bytearray(2**14))
            return self._message
        # TODO: optimize this by caching pointers to the buffers.
        # return self._buffer[self._index:]
        if self._expecting_header:
            return self._header[self._header_index :]
        if self._expecting_compression:
            return self._compression_header[self._compression_index :]
        return self._message[self._message_index :]  # type: ignore[index]

    def buffer_updated(self, nbytes: int) -> None:
        """Called when the buffer was updated with the received data"""
        # Wrote 0 bytes into a non-empty buffer, signal connection closed
        if nbytes == 0:
            self.close(OSError("connection closed"))
            return
        if self._connection_lost:
            return
        if self._expecting_header:
            self._header_index += nbytes
            if self._header_index >= 16:
                self._expecting_header = False
                try:
                    (
                        self._message_size,
                        self._op_code,
                        self._response_to,
                        self._expecting_compression,
                    ) = self.process_header()
                except ProtocolError as exc:
                    self.close(exc)
                    return
                self._message = memoryview(bytearray(self._message_size))
            return
        if self._expecting_compression:
            self._compression_index += nbytes
            if self._compression_index >= 9:
                self._expecting_compression = False
                self._op_code, self._compressor_id = self.process_compression_header()
            return

        self._message_index += nbytes
        if self._message_index >= self._message_size:
            self._expecting_header = True
            # Pause reading to avoid storing an arbitrary number of messages in memory.
            self.transport.pause_reading()
            if self._pending_messages:
                result = self._pending_messages.popleft()
            else:
                result = asyncio.get_running_loop().create_future()
            # The future has been cancelled; close this connection.
            if result.done():
                self.close(None)
                return
            # Values needed to reconstruct and verify the message.
            result.set_result(
                (self._op_code, self._compressor_id, self._response_to, self._message)
            )
            self._done_messages.append(result)
            # Reset internal state to expect a new message
            self._header_index = 0
            self._compression_index = 0
            self._message_index = 0
            self._message_size = 0
            self._message = None
            self._op_code = 0
            self._compressor_id = None
            self._response_to = None
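
    # Illustrative flow for one uncompressed reply: get_buffer() first hands out
    # the 16-byte header buffer; once buffer_updated() has consumed all 16 bytes,
    # process_header() sizes self._message; subsequent calls fill the body, and
    # the completed (op_code, compressor_id, response_to, message) tuple resolves
    # the next pending read() future before the state is reset.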

    def process_header(self) -> tuple[int, int, int, bool]:
        """Unpack a MongoDB Wire Protocol header."""
        length, _, response_to, op_code = _UNPACK_HEADER(self._header)
        expecting_compression = False
        if op_code == 2012:  # OP_COMPRESSED
            if length <= 25:
                raise ProtocolError(
                    f"Message length ({length!r}) not longer than standard OP_COMPRESSED message header size (25)"
                )
            expecting_compression = True
            length -= 9
        if length <= 16:
            raise ProtocolError(
                f"Message length ({length!r}) not longer than standard message header size (16)"
            )
        if length > self._max_message_size:
            raise ProtocolError(
                f"Message length ({length!r}) is larger than server max "
                f"message size ({self._max_message_size!r})"
            )

        return length - 16, op_code, response_to, expecting_compression
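
    # Example header (illustrative, not used at runtime): a 26-byte OP_MSG reply
    # with request id 1 answering request 42 arrives as
    # struct.pack("<iiii", 26, 1, 42, 2013), and process_header() returns
    # (10, 2013, 42, False): a 10-byte body, OP_MSG, response_to=42, uncompressed.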

    def process_compression_header(self) -> tuple[int, int]:
        """Unpack a MongoDB Wire Protocol compression header."""
        op_code, _, compressor_id = _UNPACK_COMPRESSION_HEADER(self._compression_header)
        return op_code, compressor_id
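
    # Example compression header (illustrative): struct.pack("<iiB", 2013, 100, 2)
    # describes an OP_MSG that was 100 bytes before compression with zlib
    # (compressor id 2); process_compression_header() returns (2013, 2).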

    def _resolve_pending_messages(self, exc: Optional[Exception] = None) -> None:
        pending = list(self._pending_messages)
        for msg in pending:
            if not msg.done():
                if exc is None:
                    msg.set_result(None)
                else:
                    msg.set_exception(exc)
            self._done_messages.append(msg)

    def close(self, exc: Optional[Exception] = None) -> None:
        self.transport.abort()
        self._resolve_pending_messages(exc)
        self._connection_lost = True

    def connection_lost(self, exc: Optional[Exception] = None) -> None:
        self._resolve_pending_messages(exc)
        if not self._closed.done():
            self._closed.set_result(None)

    async def wait_closed(self) -> None:
        await self._closed


async def async_sendall(conn: PyMongoProtocol, buf: bytes) -> None:
    try:
        await asyncio.wait_for(conn.write(buf), timeout=conn.gettimeout)
    except asyncio.TimeoutError as exc:
        # Convert the asyncio.wait_for timeout error to socket.timeout, which pool.py understands.
        raise socket.timeout("timed out") from exc


async def async_receive_message(
    conn: AsyncConnection,
    request_id: Optional[int],
    max_message_size: int = MAX_MESSAGE_SIZE,
) -> Union[_OpReply, _OpMsg]:
    """Receive a raw BSON message or raise socket.error."""
    timeout: Optional[Union[float, int]]
    timeout = conn.conn.gettimeout
    if _csot.get_timeout():
        deadline = _csot.get_deadline()
    else:
        if timeout:
            deadline = time.monotonic() + timeout
        else:
            deadline = None
    if deadline:
        # When the timeout has expired perform one final check to
        # see if the socket is readable. This helps avoid spurious
        # timeouts on AWS Lambda and other FaaS environments.
        timeout = max(deadline - time.monotonic(), 0)

    cancellation_task = create_task(_poll_cancellation(conn))
    read_task = create_task(conn.conn.get_conn.read(request_id, max_message_size))
    tasks = [read_task, cancellation_task]
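    # Race the read against cancellation polling: whichever task finishes first
    # wins, and the loser is cancelled and awaited below so neither leaks.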
    try:
        done, pending = await asyncio.wait(
            tasks, timeout=timeout, return_when=asyncio.FIRST_COMPLETED
        )
        for task in pending:
            task.cancel()
        if pending:
            await asyncio.wait(pending)
        if len(done) == 0:
            raise socket.timeout("timed out")
        if read_task in done:
            data, op_code = read_task.result()
            try:
                unpack_reply = _UNPACK_REPLY[op_code]
            except KeyError:
                raise ProtocolError(
                    f"Got opcode {op_code!r} but expected {_UNPACK_REPLY.keys()!r}"
                ) from None
            return unpack_reply(data)
        raise _OperationCancelled("operation cancelled")
    except asyncio.CancelledError:
        for task in tasks:
            task.cancel()
        await asyncio.wait(tasks)
        raise


def receive_message(
    conn: Connection, request_id: Optional[int], max_message_size: int = MAX_MESSAGE_SIZE
) -> Union[_OpReply, _OpMsg]:
    """Receive a raw BSON message or raise socket.error."""
    if _csot.get_timeout():
        deadline = _csot.get_deadline()
    else:
        timeout = conn.conn.gettimeout()
        if timeout:
            deadline = time.monotonic() + timeout
        else:
            deadline = None
    # Ignore the response's request id.
    length, _, response_to, op_code = _UNPACK_HEADER(receive_data(conn, 16, deadline))
    # No request_id for exhaust cursor "getMore".
    if request_id is not None:
        if request_id != response_to:
            raise ProtocolError(f"Got response id {response_to!r} but expected {request_id!r}")
    if length <= 16:
        raise ProtocolError(
            f"Message length ({length!r}) not longer than standard message header size (16)"
        )
    if length > max_message_size:
        raise ProtocolError(
            f"Message length ({length!r}) is larger than server max "
            f"message size ({max_message_size!r})"
        )
    if op_code == 2012:  # OP_COMPRESSED
        op_code, _, compressor_id = _UNPACK_COMPRESSION_HEADER(receive_data(conn, 9, deadline))
        data = decompress(receive_data(conn, length - 25, deadline), compressor_id)
    else:
        data = receive_data(conn, length - 16, deadline)

    try:
        unpack_reply = _UNPACK_REPLY[op_code]
    except KeyError:
        raise ProtocolError(
            f"Got opcode {op_code!r} but expected {_UNPACK_REPLY.keys()!r}"
        ) from None
    return unpack_reply(data)
