# Copyright 2009-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tools for creating `messages
<https://www.mongodb.com/docs/manual/reference/mongodb-wire-protocol/>`_ to be sent to
MongoDB.

.. note:: This module is for internal use and is generally not needed by
   application developers.
"""
from __future__ import annotations

import datetime
import random
import struct
from io import BytesIO as _BytesIO
from typing import (
    TYPE_CHECKING,
    Any,
    Callable,
    Iterable,
    Mapping,
    MutableMapping,
    NoReturn,
    Optional,
    Union,
)

import bson
from bson import CodecOptions, _dict_to_bson, _make_c_string
from bson.int64 import Int64
from bson.raw_bson import (
    _RAW_ARRAY_BSON_OPTIONS,
    DEFAULT_RAW_BSON_OPTIONS,
    RawBSONDocument,
    _inflate_bson,
)
from pymongo.hello import HelloCompat
from pymongo.monitoring import _EventListeners

try:
    from pymongo import _cmessage  # type: ignore[attr-defined]

    _use_c = True
except ImportError:
    _use_c = False
from pymongo.errors import (
    ConfigurationError,
    CursorNotFound,
    DocumentTooLarge,
    ExecutionTimeout,
    InvalidOperation,
    NotPrimaryError,
    OperationFailure,
    ProtocolError,
)
from pymongo.read_preferences import ReadPreference, _ServerMode

if TYPE_CHECKING:
    from pymongo.compression_support import SnappyContext, ZlibContext, ZstdContext
    from pymongo.read_concern import ReadConcern
    from pymongo.typings import (
        _Address,
        _AgnosticClientSession,
        _AgnosticConnection,
        _AgnosticMongoClient,
        _DocumentOut,
    )


# Bounds of a signed 32-bit integer; _randint draws request ids from this range.
MAX_INT32 = 2147483647
MIN_INT32 = -2147483648

# Overhead allowed for encoded command documents.
_COMMAND_OVERHEAD = 16382

# Write operation codes. They double as keys of _OP_MAP/_OP_MSG_MAP and as
# positional indexes into _FIELD_MAP's keys (see _batched_op_msg_impl).
_INSERT = 0
_UPDATE = 1
_DELETE = 2

# Frequently used raw byte strings.
_EMPTY = b""
_BSONOBJ = b"\x03"
_ZERO_8 = b"\x00"
_ZERO_16 = b"\x00\x00"
_ZERO_32 = b"\x00\x00\x00\x00"
_ZERO_64 = b"\x00\x00\x00\x00\x00\x00\x00\x00"
# A zero int32 followed by int32 -1 (little-endian); presumably skip=0 and
# limit=-1 as the name suggests -- not used in this part of the file.
_SKIPLIM = b"\x00\x00\x00\x00\xff\xff\xff\xff"
# Per operation: BSON array element type byte (0x04), the field name cstring,
# then four zero bytes (presumably a length placeholder patched later).
_OP_MAP = {
    _INSERT: b"\x04documents\x00\x00\x00\x00\x00",
    _UPDATE: b"\x04updates\x00\x00\x00\x00\x00",
    _DELETE: b"\x04deletes\x00\x00\x00\x00\x00",
}
# Command name -> name of the field carrying its document list (sent as an
# OP_MSG type 1 section). Insertion order MUST stay insert/update/delete:
# _batched_op_msg_impl indexes list(_FIELD_MAP.keys()) by operation code.
_FIELD_MAP = {
    "insert": "documents",
    "update": "updates",
    "delete": "deletes",
    "bulkWrite": "ops",
}

# Codec options that decode invalid UTF-8 with the "replace" error handler.
_UNICODE_REPLACE_CODEC_OPTIONS: CodecOptions[Mapping[str, Any]] = CodecOptions(
    unicode_decode_error_handler="replace"
)


def _randint() -> int:
    """Return a pseudo-random signed 32-bit integer, used for request ids."""
    lower, upper = MIN_INT32, MAX_INT32
    return random.randint(lower, upper)  # noqa: S311


def _maybe_add_read_preference(
    spec: MutableMapping[str, Any], read_preference: _ServerMode
) -> MutableMapping[str, Any]:
    """Add $readPreference to spec when appropriate."""
    mode = read_preference.mode
    document = read_preference.document
    # Only add $readPreference if it's something other than primary to avoid
    # problems with mongos versions that don't support read preferences. Also,
    # for maximum backwards compatibility, don't add $readPreference for
    # secondaryPreferred unless tags or maxStalenessSeconds are in use (setting
    # the secondaryOkay bit has the same effect).
    if mode and (mode != ReadPreference.SECONDARY_PREFERRED.mode or len(document) > 1):
        if "$query" not in spec:
            spec = {"$query": spec}
        spec["$readPreference"] = document
    return spec


def _convert_exception(exception: Exception) -> dict[str, Any]:
    """Convert an Exception into a failure document for publishing."""
    return {"errmsg": str(exception), "errtype": exception.__class__.__name__}


def _convert_client_bulk_exception(exception: Exception) -> dict[str, Any]:
    """Convert an Exception into a failure document for publishing,
    for use in client-level bulk write API.
    """
    return {
        "errmsg": str(exception),
        "code": exception.code,  # type: ignore[attr-defined]
        "errtype": exception.__class__.__name__,
    }


def _convert_write_result(
    operation: str, command: Mapping[str, Any], result: Mapping[str, Any]
) -> dict[str, Any]:
    """Convert a legacy write result to write command format."""
    # Based on _merge_legacy from bulk.py
    affected = result.get("n", 0)
    res = {"ok": 1, "n": affected}
    errmsg = result.get("errmsg", result.get("err", ""))
    if errmsg:
        # The write was successful on at least the primary so don't return.
        if result.get("wtimeout"):
            res["writeConcernError"] = {"errmsg": errmsg, "code": 64, "errInfo": {"wtimeout": True}}
        else:
            # The write failed.
            error = {"index": 0, "code": result.get("code", 8), "errmsg": errmsg}
            if "errInfo" in result:
                error["errInfo"] = result["errInfo"]
            res["writeErrors"] = [error]
            return res
    if operation == "insert":
        # GLE result for insert is always 0 in most MongoDB versions.
        res["n"] = len(command["documents"])
    elif operation == "update":
        if "upserted" in result:
            res["upserted"] = [{"index": 0, "_id": result["upserted"]}]
        # Versions of MongoDB before 2.6 don't return the _id for an
        # upsert if _id is not an ObjectId.
        elif result.get("updatedExisting") is False and affected == 1:
            # If _id is in both the update document *and* the query spec
            # the update document _id takes precedence.
            update = command["updates"][0]
            _id = update["u"].get("_id", update["q"].get("_id"))
            res["upserted"] = [{"index": 0, "_id": _id}]
    return res


# find-command option name -> legacy OP_QUERY flag bit. _gen_find_command
# turns each bit set in its integer ``options`` argument into ``name: True``.
_OPTIONS = {
    "tailable": 2,
    "oplogReplay": 8,
    "noCursorTimeout": 16,
    "awaitData": 32,
    "allowPartialResults": 128,
}


# Legacy "$"-prefixed query modifier -> find-command field name. Used by
# _gen_find_command when the query spec is wrapped in ``$query``.
_MODIFIERS = {
    "$query": "filter",
    "$orderby": "sort",
    "$hint": "hint",
    "$comment": "comment",
    "$maxScan": "maxScan",
    "$maxTimeMS": "maxTimeMS",
    "$max": "max",
    "$min": "min",
    "$returnKey": "returnKey",
    "$showRecordId": "showRecordId",
    "$showDiskLoc": "showRecordId",  # <= MongoDb 3.0 (legacy name for $showRecordId)
    "$snapshot": "snapshot",
}


def _gen_find_command(
    coll: str,
    spec: Mapping[str, Any],
    projection: Optional[Union[Mapping[str, Any], Iterable[str]]],
    skip: int,
    limit: int,
    batch_size: Optional[int],
    options: Optional[int],
    read_concern: ReadConcern,
    collation: Optional[Mapping[str, Any]] = None,
    session: Optional[_AgnosticClientSession] = None,
    allow_disk_use: Optional[bool] = None,
) -> dict[str, Any]:
    """Generate a find command document."""
    cmd: dict[str, Any] = {"find": coll}
    if "$query" in spec:
        cmd.update(
            [
                (_MODIFIERS[key], val) if key in _MODIFIERS else (key, val)
                for key, val in spec.items()
            ]
        )
        if "$explain" in cmd:
            cmd.pop("$explain")
        if "$readPreference" in cmd:
            cmd.pop("$readPreference")
    else:
        cmd["filter"] = spec

    if projection:
        cmd["projection"] = projection
    if skip:
        cmd["skip"] = skip
    if limit:
        cmd["limit"] = abs(limit)
        if limit < 0:
            cmd["singleBatch"] = True
    if batch_size:
        # When limit and batchSize are equal we increase batchSize by 1 to
        # avoid an unnecessary killCursors.
        if limit == batch_size:
            batch_size += 1
        cmd["batchSize"] = batch_size
    if read_concern.level and not (session and session.in_transaction):
        cmd["readConcern"] = read_concern.document
    if collation:
        cmd["collation"] = collation
    if allow_disk_use is not None:
        cmd["allowDiskUse"] = allow_disk_use
    if options:
        cmd.update([(opt, True) for opt, val in _OPTIONS.items() if options & val])

    return cmd


def _gen_get_more_command(
    cursor_id: Optional[int],
    coll: str,
    batch_size: Optional[int],
    max_await_time_ms: Optional[int],
    comment: Optional[Any],
    conn: _AgnosticConnection,
) -> dict[str, Any]:
    """Generate a getMore command document."""
    cmd: dict[str, Any] = {"getMore": cursor_id, "collection": coll}
    if batch_size:
        cmd["batchSize"] = batch_size
    if max_await_time_ms is not None:
        cmd["maxTimeMS"] = max_await_time_ms
    if comment is not None and conn.max_wire_version >= 9:
        cmd["comment"] = comment
    return cmd


# OP_COMPRESSED header: six int32 fields plus a uint8 compressor id (25 bytes).
_pack_compression_header = struct.Struct("<iiiiiiB").pack
_COMPRESSION_HEADER_SIZE = 25


def _compress(
    operation: int, data: bytes, ctx: Union[SnappyContext, ZlibContext, ZstdContext]
) -> tuple[int, bytes]:
    """Compress *data* with *ctx* and prepend an OP_COMPRESSED (2012) header.

    :param operation: opcode of the wrapped, uncompressed message.
    :param data: the message body to compress (no message header).
    :param ctx: compression context providing ``compress`` and ``compressor_id``.
    :return: ``(request_id, message)`` where *message* is the header followed
        by the compressed bytes.
    """
    payload = ctx.compress(data)
    request_id = _randint()
    header_fields = (
        _COMPRESSION_HEADER_SIZE + len(payload),  # total message length
        request_id,  # request id
        0,  # responseTo
        2012,  # opcode: OP_COMPRESSED
        operation,  # original opcode
        len(data),  # uncompressed size
        ctx.compressor_id,  # compressor id
    )
    return request_id, _pack_compression_header(*header_fields) + payload


# Standard message header: messageLength, requestID, responseTo, opCode (int32 each).
_pack_header = struct.Struct("<iiii").pack


def __pack_message(operation: int, data: bytes) -> tuple[int, bytes]:
    """Prepend a standard 16-byte message header to *data*.

    :param operation: the wire-protocol opcode written into the header.
    :param data: the message body.
    :return: ``(request_id, message)`` -- the randomly generated request id
        and the header followed by *data* (not a string).
    """
    rid = _randint()
    # messageLength counts the 16-byte header itself; responseTo is 0.
    message = _pack_header(16 + len(data), rid, 0, operation)
    return rid, message + data


# Precompiled little-endian packers: int32; OP_MSG flagBits (uint32) followed
# by the first section's kind byte (uint8); and a single uint8.
_pack_int = struct.Struct("<i").pack
_pack_op_msg_flags_type = struct.Struct("<IB").pack
_pack_byte = struct.Struct("<B").pack


def _op_msg_no_header(
    flags: int,
    command: Mapping[str, Any],
    identifier: str,
    docs: Optional[list[Mapping[str, Any]]],
    opts: CodecOptions,
) -> tuple[bytes, int, int]:
    """Encode an OP_MSG body (without the 16-byte message header).

    Note: this method handles multiple documents in a type one payload but
    it does not perform batch splitting and the total message size is
    only checked *after* generating the entire message.

    :param flags: OP_MSG flagBits.
    :param command: the command document, sent as the type 0 section.
    :param identifier: name of the type 1 document sequence. An empty string
        means no type 1 section is written.
    :param docs: documents for the type 1 section, or None for none.
    :param opts: codec options used to encode every document.
    :return: ``(body, total_size, max_doc_size)``. *total_size* is the encoded
        command length plus the type 1 section size. *max_doc_size* is the
        largest encoded document in *docs* (0 when there are none).
    """
    # Encode the command document in payload 0 without checking keys.
    encoded = _dict_to_bson(command, False, opts)
    flags_type = _pack_op_msg_flags_type(flags, 0)
    total_size = len(encoded)
    max_doc_size = 0
    if identifier and docs is not None:
        type_one = _pack_byte(1)
        cstring = _make_c_string(identifier)
        encoded_docs = [_dict_to_bson(doc, False, opts) for doc in docs]
        # The section size covers its own int32 size field, the identifier
        # cstring and every document (but not the kind byte).
        size = len(cstring) + sum(len(doc) for doc in encoded_docs) + 4
        encoded_size = _pack_int(size)
        total_size += size
        # default=0: an empty sequence (e.g. ``documents: []``) must not make
        # max() raise ValueError on an empty iterable.
        max_doc_size = max((len(doc) for doc in encoded_docs), default=0)
        data = [flags_type, encoded, type_one, encoded_size, cstring, *encoded_docs]
    else:
        data = [flags_type, encoded]
    return b"".join(data), total_size, max_doc_size


def _op_msg_compressed(
    flags: int,
    command: Mapping[str, Any],
    identifier: str,
    docs: Optional[list[Mapping[str, Any]]],
    opts: CodecOptions,
    ctx: Union[SnappyContext, ZlibContext, ZstdContext],
) -> tuple[int, bytes, int, int]:
    """Build an OP_MSG and wrap it in OP_COMPRESSED using *ctx*.

    Returns ``(request_id, message, total_size, max_doc_size)``. The two sizes
    describe the uncompressed body as reported by ``_op_msg_no_header``.
    """
    body, total_size, max_bson_size = _op_msg_no_header(flags, command, identifier, docs, opts)
    # 2013 is the OP_MSG opcode recorded as the original opcode.
    request_id, compressed = _compress(2013, body, ctx)
    return request_id, compressed, total_size, max_bson_size


def _op_msg_uncompressed(
    flags: int,
    command: Mapping[str, Any],
    identifier: str,
    docs: Optional[list[Mapping[str, Any]]],
    opts: CodecOptions,
) -> tuple[int, bytes, int, int]:
    """Internal uncompressed OP_MSG message helper.

    Encodes the OP_MSG body and prepends a standard message header with
    opcode 2013 (OP_MSG). Returns ``(request_id, message, total_size,
    max_doc_size)``.
    """
    data, total_size, max_bson_size = _op_msg_no_header(flags, command, identifier, docs, opts)
    request_id, op_message = __pack_message(2013, data)
    return request_id, op_message, total_size, max_bson_size


# Use the C extension's implementation when it was importable.
if _use_c:
    _op_msg_uncompressed = _cmessage._op_msg


def _op_msg(
    flags: int,
    command: MutableMapping[str, Any],
    dbname: str,
    read_preference: Optional[_ServerMode],
    opts: CodecOptions,
    ctx: Union[SnappyContext, ZlibContext, ZstdContext, None] = None,
) -> tuple[int, bytes, int, int]:
    """Build an OP_MSG for *command*, compressed when *ctx* is truthy.

    Side effects on *command*: ``$db`` is set to *dbname*. ``$readPreference``
    is added for a non-primary *read_preference* unless already present.
    For a command named in ``_FIELD_MAP`` whose document list is present,
    that list is popped, sent as a type 1 section, and put back afterwards
    (even if encoding raises).

    Returns ``(request_id, message, total_size, max_doc_size)``.
    """
    command["$db"] = dbname
    # getMore commands do not send $readPreference, and primary (a falsy
    # mode, the default) is never sent.
    if read_preference is not None and "$readPreference" not in command and read_preference.mode:
        command["$readPreference"] = read_preference.document

    first_key = next(iter(command))
    identifier = _FIELD_MAP.get(first_key, "")
    if identifier and identifier in command:
        docs = command.pop(identifier)
    else:
        identifier, docs = "", None

    try:
        if ctx:
            return _op_msg_compressed(flags, command, identifier, docs, opts, ctx)
        return _op_msg_uncompressed(flags, command, identifier, docs, opts)
    finally:
        # Restore the document list popped above.
        if identifier:
            command[identifier] = docs


def _query_impl(
    options: int,
    collection_name: str,
    num_to_skip: int,
    num_to_return: int,
    query: Mapping[str, Any],
    field_selector: Optional[Mapping[str, Any]],
    opts: CodecOptions,
) -> tuple[bytes, int]:
    """Encode an OP_QUERY body (everything after the message header).

    Layout: flags int32, collection name cstring, numberToSkip int32,
    numberToReturn int32, the query document, then the projection document
    when *field_selector* is truthy.

    Returns ``(body, max_bson_size)``. *max_bson_size* is the longer of the
    two encoded documents.
    """
    encoded_query = _dict_to_bson(query, False, opts)
    encoded_fields = _dict_to_bson(field_selector, False, opts) if field_selector else b""
    parts = (
        _pack_int(options),
        bson._make_c_string(collection_name),
        _pack_int(num_to_skip),
        _pack_int(num_to_return),
        encoded_query,
        encoded_fields,
    )
    return b"".join(parts), max(len(encoded_query), len(encoded_fields))


def _query_compressed(
    options: int,
    collection_name: str,
    num_to_skip: int,
    num_to_return: int,
    query: Mapping[str, Any],
    field_selector: Optional[Mapping[str, Any]],
    opts: CodecOptions,
    ctx: Union[SnappyContext, ZlibContext, ZstdContext],
) -> tuple[int, bytes, int]:
    """Build an OP_QUERY (opcode 2004) message compressed with *ctx*.

    Returns ``(request_id, message, max_bson_size)``.
    """
    body, max_bson_size = _query_impl(
        options, collection_name, num_to_skip, num_to_return, query, field_selector, opts
    )
    request_id, message = _compress(2004, body, ctx)
    return request_id, message, max_bson_size


def _query_uncompressed(
    options: int,
    collection_name: str,
    num_to_skip: int,
    num_to_return: int,
    query: Mapping[str, Any],
    field_selector: Optional[Mapping[str, Any]],
    opts: CodecOptions,
) -> tuple[int, bytes, int]:
    """Build an uncompressed OP_QUERY (opcode 2004) message.

    Returns ``(request_id, message, max_bson_size)``.
    """
    body, max_bson_size = _query_impl(
        options, collection_name, num_to_skip, num_to_return, query, field_selector, opts
    )
    request_id, message = __pack_message(2004, body)
    return request_id, message, max_bson_size


# Use the C extension's implementation when it was importable.
if _use_c:
    _query_uncompressed = _cmessage._query_message


def _query(
    options: int,
    collection_name: str,
    num_to_skip: int,
    num_to_return: int,
    query: Mapping[str, Any],
    field_selector: Optional[Mapping[str, Any]],
    opts: CodecOptions,
    ctx: Union[SnappyContext, ZlibContext, ZstdContext, None] = None,
) -> tuple[int, bytes, int]:
    """Build an OP_QUERY message, compressed when *ctx* is truthy.

    Returns ``(request_id, message, max_bson_size)``.
    """
    args = (options, collection_name, num_to_skip, num_to_return, query, field_selector, opts)
    if not ctx:
        return _query_uncompressed(*args)
    return _query_compressed(*args, ctx)


# Little-endian int64 packer (cursor ids).
_pack_long_long = struct.Struct("<q").pack


def _get_more_impl(collection_name: str, num_to_return: int, cursor_id: int) -> bytes:
    """Encode an OP_GET_MORE body (everything after the message header).

    Layout: a zero int32, the collection name cstring, numberToReturn int32
    and the cursor id as int64.
    """
    return (
        _ZERO_32  # reserved / zero field
        + bson._make_c_string(collection_name)
        + _pack_int(num_to_return)
        + _pack_long_long(cursor_id)
    )


def _get_more_compressed(
    collection_name: str,
    num_to_return: int,
    cursor_id: int,
    ctx: Union[SnappyContext, ZlibContext, ZstdContext],
) -> tuple[int, bytes]:
    """Build an OP_GET_MORE (opcode 2005) message compressed with *ctx*.

    Returns ``(request_id, message)``.
    """
    body = _get_more_impl(collection_name, num_to_return, cursor_id)
    return _compress(2005, body, ctx)


def _get_more_uncompressed(
    collection_name: str, num_to_return: int, cursor_id: int
) -> tuple[int, bytes]:
    """Build an uncompressed OP_GET_MORE (opcode 2005) message.

    Returns ``(request_id, message)``.
    """
    body = _get_more_impl(collection_name, num_to_return, cursor_id)
    return __pack_message(2005, body)


# Use the C extension's implementation when it was importable.
if _use_c:
    _get_more_uncompressed = _cmessage._get_more_message


def _get_more(
    collection_name: str,
    num_to_return: int,
    cursor_id: int,
    ctx: Union[SnappyContext, ZlibContext, ZstdContext, None] = None,
) -> tuple[int, bytes]:
    """Build an OP_GET_MORE message, compressed when *ctx* is truthy.

    Returns ``(request_id, message)``.
    """
    if not ctx:
        return _get_more_uncompressed(collection_name, num_to_return, cursor_id)
    return _get_more_compressed(collection_name, num_to_return, cursor_id, ctx)


# OP_MSG -------------------------------------------------------------


# OP_MSG type 1 section identifiers (NUL-terminated), keyed by write operation
# code. _batched_op_msg_impl writes the matching one after the section size.
_OP_MSG_MAP = {
    _INSERT: b"documents\x00",
    _UPDATE: b"updates\x00",
    _DELETE: b"deletes\x00",
}


class _BulkWriteContextBase:
    """Private base class for wrapping around AsyncConnection to use with write splitting functions."""

    __slots__ = (
        "db_name",
        "conn",
        "op_id",
        "name",
        "field",
        "publish",
        "start_time",
        "listeners",
        "session",
        "compress",
        "op_type",
        "codec",
    )

    def __init__(
        self,
        database_name: str,
        cmd_name: str,
        conn: _AgnosticConnection,
        operation_id: int,
        listeners: _EventListeners,
        session: Optional[_AgnosticClientSession],
        op_type: int,
        codec: CodecOptions,
    ):
        self.db_name = database_name
        self.conn = conn
        self.op_id = operation_id
        self.listeners = listeners
        self.publish = listeners.enabled_for_commands
        self.name = cmd_name
        self.field = _FIELD_MAP[self.name]
        self.start_time = datetime.datetime.now()
        self.session = session
        self.compress = bool(conn.compression_context)
        self.op_type = op_type
        self.codec = codec

    @property
    def max_bson_size(self) -> int:
        """A proxy for SockInfo.max_bson_size."""
        return self.conn.max_bson_size

    @property
    def max_message_size(self) -> int:
        """A proxy for SockInfo.max_message_size."""
        if self.compress:
            # Subtract 16 bytes for the message header.
            return self.conn.max_message_size - 16
        return self.conn.max_message_size

    @property
    def max_write_batch_size(self) -> int:
        """A proxy for SockInfo.max_write_batch_size."""
        return self.conn.max_write_batch_size

    @property
    def max_split_size(self) -> int:
        """The maximum size of a BSON command before batch splitting."""
        return self.max_bson_size

    def _succeed(self, request_id: int, reply: _DocumentOut, duration: datetime.timedelta) -> None:
        """Publish a CommandSucceededEvent."""
        self.listeners.publish_command_success(
            duration,
            reply,
            self.name,
            request_id,
            self.conn.address,
            self.conn.server_connection_id,
            self.op_id,
            self.conn.service_id,
            database_name=self.db_name,
        )

    def _fail(self, request_id: int, failure: _DocumentOut, duration: datetime.timedelta) -> None:
        """Publish a CommandFailedEvent."""
        self.listeners.publish_command_failure(
            duration,
            failure,
            self.name,
            request_id,
            self.conn.address,
            self.conn.server_connection_id,
            self.op_id,
            self.conn.service_id,
            database_name=self.db_name,
        )


class _BulkWriteContext(_BulkWriteContextBase):
    """Bulk write context for the collection-level bulk write API.

    Splits insert/update/delete document lists into OP_MSG batches and
    publishes a started event per batch. The constructor is inherited
    unchanged from _BulkWriteContextBase.
    """

    __slots__ = ()

    def batch_command(
        self, cmd: MutableMapping[str, Any], docs: list[Mapping[str, Any]]
    ) -> tuple[int, Union[bytes, dict[str, Any]], list[Mapping[str, Any]]]:
        """Encode the next batch of *docs* for *cmd* as an OP_MSG.

        Returns ``(request_id, message, docs_in_batch)``.

        :raises InvalidOperation: if the batch contains no documents.
        """
        request_id, msg, to_send = _do_batched_op_msg(
            self.db_name + ".$cmd", self.op_type, cmd, docs, self.codec, self
        )
        if to_send:
            return request_id, msg, to_send
        raise InvalidOperation("cannot do an empty bulk write")

    def _start(
        self, cmd: MutableMapping[str, Any], request_id: int, docs: list[Mapping[str, Any]]
    ) -> MutableMapping[str, Any]:
        """Attach *docs* to *cmd* and publish a CommandStartedEvent.

        *cmd* is mutated in place and returned.
        """
        cmd[self.field] = docs
        conn = self.conn
        self.listeners.publish_command_start(
            cmd,
            self.db_name,
            request_id,
            conn.address,
            conn.server_connection_id,
            self.op_id,
            conn.service_id,
        )
        return cmd


class _EncryptedBulkWriteContext(_BulkWriteContext):
    """Bulk write context whose batches are returned as a command document.

    Instead of wire bytes, batch_command returns the batched command inflated
    with DEFAULT_RAW_BSON_OPTIONS, and batches are split at
    _MAX_SPLIT_SIZE_ENC rather than the server's max BSON size.
    NOTE(review): presumably used for automatic (client-side) encryption,
    given the split-size constant's origin -- confirm at call sites.
    """

    __slots__ = ()

    def batch_command(
        self, cmd: MutableMapping[str, Any], docs: list[Mapping[str, Any]]
    ) -> tuple[int, dict[str, Any], list[Mapping[str, Any]]]:
        """Encode the next batch and return it as a command document.

        Returns ``(-1, command, docs_in_batch)``. -1 occupies the request-id
        slot because no wire message is sent from here.

        :raises InvalidOperation: if the batch contains no documents.
        """
        namespace = self.db_name + ".$cmd"
        msg, to_send = _encode_batched_write_command(
            namespace, self.op_type, cmd, docs, self.codec, self
        )
        if not to_send:
            raise InvalidOperation("cannot do an empty bulk write")

        # Chop off the OP_QUERY header to get a properly batched write command.
        # NOTE(review): assumes msg is laid out like _query_impl's output:
        # int32 flags, NUL-terminated namespace, int32 skip, int32 limit, then
        # the command BSON. The command therefore starts 9 bytes (the NUL plus
        # two int32s) after the first NUL at or beyond offset 4.
        # TODO confirm against _encode_batched_write_command.
        cmd_start = msg.index(b"\x00", 4) + 9
        outgoing = _inflate_bson(memoryview(msg)[cmd_start:], DEFAULT_RAW_BSON_OPTIONS)
        return -1, outgoing, to_send

    @property
    def max_split_size(self) -> int:
        """Reduce the batch splitting size (to _MAX_SPLIT_SIZE_ENC)."""
        return _MAX_SPLIT_SIZE_ENC


def _raise_document_too_large(operation: str, doc_size: int, max_size: int) -> NoReturn:
    """Raise DocumentTooLarge for an oversized document or command.

    Only ``insert`` gets a message quoting the actual and maximum sizes. For
    any other operation a generic message naming the operation is used.
    """
    if operation != "insert":
        # There's nothing intelligent we can say about size for update and delete.
        raise DocumentTooLarge(f"{operation!r} command document too large")
    message = (
        "BSON document too large (%d bytes)"
        " - the connected server supports"
        " BSON document sizes up to %d"
        " bytes."
    ) % (doc_size, max_size)
    raise DocumentTooLarge(message)


# From the Client Side Encryption spec:
# Because automatic encryption increases the size of commands, the driver
# MUST split bulk writes at a reduced size limit before undergoing automatic
# encryption. The write payload MUST be split at 2MiB (2097152 bytes).
# Returned by _EncryptedBulkWriteContext.max_split_size.
_MAX_SPLIT_SIZE_ENC = 2097152


def _batched_op_msg_impl(
    operation: int,
    command: Mapping[str, Any],
    docs: list[Mapping[str, Any]],
    ack: bool,
    opts: CodecOptions,
    ctx: _BulkWriteContext,
    buf: _BytesIO,
) -> tuple[list[Mapping[str, Any]], int]:
    """Create a batched OP_MSG write.

    Starting at the current position of *buf*, writes the OP_MSG flags, a
    type 0 section holding *command*, and a type 1 section holding as many
    encoded *docs* as fit within ctx.max_message_size and
    ctx.max_write_batch_size.

    :param operation: _INSERT, _UPDATE or _DELETE; selects the type 1 identifier.
    :param command: the write command document.
    :param docs: candidate documents; only a prefix may be included.
    :param ack: False for unacknowledged writes. This sets flag bit 1
        (moreToCome) and enables client-side document size checks.
    :param opts: codec options used for encoding.
    :param ctx: supplies max_bson_size, max_write_batch_size and max_message_size.
    :param buf: output buffer. On return it is positioned just after the type 1
        size field, not at the end of the data.
    :return: ``(docs_included, length)``, where *length* is the buffer offset
        just past the last byte written.
    :raises InvalidOperation: for an unknown *operation*.
    :raises DocumentTooLarge: if the first document alone would exceed the max
        message size, or (unacknowledged only) any document exceeds the max
        BSON size.
    """
    max_bson_size = ctx.max_bson_size
    max_write_batch_size = ctx.max_write_batch_size
    max_message_size = ctx.max_message_size

    # Bit 1 (moreToCome) is set for unacknowledged writes.
    flags = b"\x00\x00\x00\x00" if ack else b"\x02\x00\x00\x00"
    # Flags
    buf.write(flags)

    # Type 0 Section
    buf.write(b"\x00")
    buf.write(_dict_to_bson(command, False, opts))

    # Type 1 Section
    buf.write(b"\x01")
    size_location = buf.tell()
    # Save space for size
    buf.write(b"\x00\x00\x00\x00")
    try:
        buf.write(_OP_MSG_MAP[operation])
    except KeyError:
        raise InvalidOperation("Unknown command") from None

    to_send: list[Mapping[str, Any]] = []
    idx = 0
    for doc in docs:
        # Encode the current operation
        value = _dict_to_bson(doc, False, opts)
        doc_length = len(value)
        # buf.tell() includes anything written to buf before this call (e.g. a
        # reserved message header), so this is the total size if appended.
        new_message_size = buf.tell() + doc_length
        # Does first document exceed max_message_size?
        doc_too_large = idx == 0 and (new_message_size > max_message_size)
        # When OP_MSG is used unacknowledged we have to check
        # document size client side or applications won't be notified.
        # Otherwise we let the server deal with documents that are too large
        # since ordered=False causes those documents to be skipped instead of
        # halting the bulk write operation.
        unacked_doc_too_large = not ack and (doc_length > max_bson_size)
        if doc_too_large or unacked_doc_too_large:
            # Relies on _FIELD_MAP's insertion order (insert, update, delete)
            # matching the operation codes 0, 1, 2.
            write_op = list(_FIELD_MAP.keys())[operation]
            _raise_document_too_large(write_op, len(value), max_bson_size)
        # We have enough data, return this batch.
        if new_message_size > max_message_size:
            break
        buf.write(value)
        to_send.append(doc)
        idx += 1
        # We have enough documents, return this batch.
        if idx == max_write_batch_size:
            break

    # Write type 1 section size
    length = buf.tell()
    buf.seek(size_location)
    buf.write(_pack_int(length - size_location))

    return to_send, length


def _encode_batched_op_msg(
    operation: int,
    command: Mapping[str, Any],
    docs: list[Mapping[str, Any]],
    ack: bool,
    opts: CodecOptions,
    ctx: _BulkWriteContext,
) -> tuple[bytes, list[Mapping[str, Any]]]:
    """Encode the next insert, update, or delete batch as an OP_MSG body.

    The body has no message header (the compressed path wraps it itself).
    Returns ``(body, docs_in_batch)``.
    """
    out = _BytesIO()
    batch, _end = _batched_op_msg_impl(operation, command, docs, ack, opts, ctx, out)
    return out.getvalue(), batch


# Use the C extension's implementation when it was importable.
if _use_c:
    _encode_batched_op_msg = _cmessage._encode_batched_op_msg


def _batched_op_msg_compressed(
    operation: int,
    command: Mapping[str, Any],
    docs: list[Mapping[str, Any]],
    ack: bool,
    opts: CodecOptions,
    ctx: _BulkWriteContext,
) -> tuple[int, bytes, list[Mapping[str, Any]]]:
    """Build the next insert/update/delete batch as an OP_MSG inside OP_COMPRESSED.

    Returns ``(request_id, message, docs_in_batch)``.
    """
    body, batch = _encode_batched_op_msg(operation, command, docs, ack, opts, ctx)
    compression_ctx = ctx.conn.compression_context
    assert compression_ctx is not None
    request_id, message = _compress(2013, body, compression_ctx)
    return request_id, message, batch


def _batched_op_msg(
    operation: int,
    command: Mapping[str, Any],
    docs: list[Mapping[str, Any]],
    ack: bool,
    opts: CodecOptions,
    ctx: _BulkWriteContext,
) -> tuple[int, bytes, list[Mapping[str, Any]]]:
    """Build the next insert/update/delete batch as a complete uncompressed OP_MSG.

    A 16-byte header is reserved up front and its length and request id are
    patched in once the body is written.
    Returns ``(request_id, message, docs_in_batch)``.
    """
    buf = _BytesIO()
    # Placeholder messageLength + requestID, then responseTo=0 and opCode=2013 (OP_MSG).
    buf.write(_ZERO_64 + b"\x00\x00\x00\x00\xdd\x07\x00\x00")

    batch, total_length = _batched_op_msg_impl(operation, command, docs, ack, opts, ctx, buf)

    request_id = _randint()
    # messageLength (offset 0) and requestID (offset 4) are adjacent int32s.
    buf.seek(0)
    buf.write(_pack_int(total_length) + _pack_int(request_id))
    return request_id, buf.getvalue(), batch


# Use the C extension's implementation when it was importable.
if _use_c:
    _batched_op_msg = _cmessage._batched_op_msg


def _do_batched_op_msg(
    namespace: str,
    operation: int,
    command: MutableMapping[str, Any],
    docs: list[Mapping[str, Any]],
    opts: CodecOptions,
    ctx: _BulkWriteContext,
) -> tuple[int, bytes, list[Mapping[str, Any]]]:
    """Create the next insert, update, or delete batch using OP_MSG.

    Sets ``$db`` on *command* to the part of *namespace* before the first
    dot. The write is unacknowledged only when ``writeConcern.w`` is falsy.
    The message is compressed when the connection has a compression context.

    Returns ``(request_id, message, docs_in_batch)``.
    """
    db_name, _, _ = namespace.partition(".")
    command["$db"] = db_name
    ack = bool(command["writeConcern"].get("w", 1)) if "writeConcern" in command else True
    build = _batched_op_msg_compressed if ctx.conn.compression_context else _batched_op_msg
    return build(operation, command, docs, ack, opts, ctx)


class _ClientBulkWriteContext(_BulkWriteContextBase):
    """Bulk write context for the client-level ``bulkWrite`` API.

    Each batch carries two document sequences, ``ops`` and ``nsInfo``.
    """

    __slots__ = ()

    def __init__(
        self,
        database_name: str,
        cmd_name: str,
        conn: _AgnosticConnection,
        operation_id: int,
        listeners: _EventListeners,
        session: Optional[_AgnosticClientSession],
        codec: CodecOptions,
    ):
        # op_type is always 0 for client-level bulk writes.
        super().__init__(
            database_name, cmd_name, conn, operation_id, listeners, session, 0, codec
        )

    def batch_command(
        self,
        cmd: MutableMapping[str, Any],
        operations: list[tuple[str, Mapping[str, Any]]],
        namespaces: list[str],
    ) -> tuple[int, Union[bytes, dict[str, Any]], list[Mapping[str, Any]], list[Mapping[str, Any]]]:
        """Encode the next batch of *operations* for *cmd*.

        Returns ``(request_id, message, ops_in_batch, ns_info_in_batch)``.

        :raises InvalidOperation: if the batch contains no operations.
        """
        request_id, msg, ops_batch, ns_batch = _client_do_batched_op_msg(
            cmd, operations, namespaces, self.codec, self
        )
        if ops_batch:
            return request_id, msg, ops_batch, ns_batch
        raise InvalidOperation("cannot do an empty bulk write")

    def _start(
        self,
        cmd: MutableMapping[str, Any],
        request_id: int,
        op_docs: list[Mapping[str, Any]],
        ns_docs: list[Mapping[str, Any]],
    ) -> MutableMapping[str, Any]:
        """Attach ``ops``/``nsInfo`` to *cmd* and publish a CommandStartedEvent.

        *cmd* is mutated in place and returned.
        """
        cmd.update(ops=op_docs, nsInfo=ns_docs)
        conn = self.conn
        self.listeners.publish_command_start(
            cmd,
            self.db_name,
            request_id,
            conn.address,
            conn.server_connection_id,
            self.op_id,
            conn.service_id,
        )
        return cmd


# Bytes reserved for OP_MSG framing when computing how much room the
# client-level bulkWrite ``ops``/``nsInfo`` document sequences may use.
_OP_MSG_OVERHEAD = 1000


def _client_construct_op_msg(
    command_encoded: bytes,
    to_send_ops_encoded: list[bytes],
    to_send_ns_encoded: list[bytes],
    ack: bool,
    buf: _BytesIO,
) -> int:
    """Write a client-level bulkWrite OP_MSG body into *buf*.

    Writes the flags (flag bit 1 set when *ack* is False), a type 0 section
    with the pre-encoded command, then the ``ops`` and ``nsInfo`` type 1
    sections. Returns the offset just past the last byte written. *buf* is
    left positioned just after the ``nsInfo`` section's size field.
    """

    def write_sequence(identifier: bytes, encoded_docs: list[bytes]) -> int:
        # Type 1 section: kind byte, int32 size (patched below), cstring, docs.
        buf.write(b"\x01")
        size_offset = buf.tell()
        buf.write(b"\x00\x00\x00\x00")
        buf.write(identifier)
        for encoded in encoded_docs:
            buf.write(encoded)
        end = buf.tell()
        buf.seek(size_offset)
        buf.write(_pack_int(end - size_offset))
        return end

    buf.write(b"\x02\x00\x00\x00" if not ack else b"\x00\x00\x00\x00")
    # Type 0 section.
    buf.write(b"\x00")
    buf.write(command_encoded)
    ops_end = write_sequence(b"ops\x00", to_send_ops_encoded)
    # Resume writing after the ops section before adding nsInfo.
    buf.seek(ops_end)
    return write_sequence(b"nsInfo\x00", to_send_ns_encoded)


def _client_batched_op_msg_impl(
    command: Mapping[str, Any],
    operations: list[tuple[str, Mapping[str, Any]]],
    namespaces: list[str],
    ack: bool,
    opts: CodecOptions,
    ctx: _ClientBulkWriteContext,
    buf: _BytesIO,
) -> tuple[list[Mapping[str, Any]], list[Mapping[str, Any]], int]:
    """Create a batched OP_MSG write for client-level bulk write.

    Walks ``operations`` (paired one-to-one with ``namespaces``) in order,
    adding each to the batch until either the combined size of the ``ops``
    and ``nsInfo`` document sequences would exceed the message budget, or
    ``ctx.max_write_batch_size`` operations have been added. The OP_MSG body
    for that batch is then written to ``buf`` by ``_client_construct_op_msg``.

    Note: every operation document that is examined is mutated in place --
    its ``op_type`` key is set to the index of its namespace within this
    batch's ``nsInfo`` list (``"replace"`` operations use the key
    ``"update"``).

    :param command: the bulkWrite command document.
    :param operations: ``(op_type, op_doc)`` pairs.
    :param namespaces: the namespace of each entry in ``operations``.
    :param ack: whether the write is acknowledged; when False, command and
        document sizes are checked client-side.
    :param opts: codec options used to encode documents.
    :param ctx: supplies max_bson_size, max_write_batch_size and
        max_message_size.
    :param buf: buffer the OP_MSG body is written to.

    :return: ``(to_send_ops, to_send_ns, length)`` -- the operation and
        namespace documents included in this batch, and the value returned
        by ``_client_construct_op_msg``.

    Size violations are reported through ``_raise_document_too_large``: when
    the first operation alone does not fit in the budget, and, for
    unacknowledged writes, when the command or a document is too large.
    """

    def _check_doc_size_limits(
        op_type: str,
        doc_size: int,
        limit: int,
    ) -> None:
        # Report a document whose encoded size exceeds ``limit`` bytes.
        if doc_size > limit:
            _raise_document_too_large(op_type, doc_size, limit)

    max_bson_size = ctx.max_bson_size
    max_write_batch_size = ctx.max_write_batch_size
    max_message_size = ctx.max_message_size

    command_encoded = _dict_to_bson(command, False, opts)
    # When OP_MSG is used unacknowledged we have to check command
    # document size client-side or applications won't be notified.
    if not ack:
        _check_doc_size_limits("bulkWrite", len(command_encoded), max_bson_size + _COMMAND_OVERHEAD)

    # Don't include bulkWrite-command-agnostic fields in batch-splitting calculations.
    abridged_keys = ["bulkWrite", "errorsOnly", "ordered"]
    if command.get("bypassDocumentValidation"):
        abridged_keys.append("bypassDocumentValidation")
    if command.get("comment"):
        abridged_keys.append("comment")
    if command.get("let"):
        abridged_keys.append("let")
    command_abridged = {key: command[key] for key in abridged_keys}
    command_len_abridged = len(_dict_to_bson(command_abridged, False, opts))

    # Maximum combined size of the ops and nsInfo document sequences.
    max_doc_sequences_bytes = max_message_size - (_OP_MSG_OVERHEAD + command_len_abridged)

    # Namespace -> index of its document in to_send_ns, for this batch only.
    ns_info = {}
    to_send_ops: list[Mapping[str, Any]] = []
    to_send_ns: list[Mapping[str, str]] = []
    to_send_ops_encoded: list[bytes] = []
    to_send_ns_encoded: list[bytes] = []
    # Running byte totals of the encoded documents added so far.
    total_ops_length = 0
    total_ns_length = 0
    # Number of operations added to this batch.
    idx = 0

    for (real_op_type, op_doc), namespace in zip(operations, namespaces):
        op_type = real_op_type
        # Check insert/replace document size if unacknowledged.
        if real_op_type == "insert":
            if not ack:
                doc_size = len(_dict_to_bson(op_doc["document"], False, opts))
                _check_doc_size_limits(real_op_type, doc_size, max_bson_size)
        if real_op_type == "replace":
            # Replacements are keyed as "update" in the operation document.
            op_type = "update"
            if not ack:
                doc_size = len(_dict_to_bson(op_doc["updateMods"], False, opts))
                _check_doc_size_limits(real_op_type, doc_size, max_bson_size)

        ns_doc = None
        ns_length = 0

        # Only the first operation on a namespace emits its nsInfo document.
        # If the batch is cut off below, this entry may have no matching
        # to_send_ns document; that is harmless because the batch ends there.
        if namespace not in ns_info:
            ns_doc = {"ns": namespace}
            new_ns_index = len(to_send_ns)
            ns_info[namespace] = new_ns_index

        # First entry in the operation doc has the operation type as its
        # key and the index of its namespace within ns_info as its value.
        op_doc[op_type] = ns_info[namespace]  # type: ignore[index]

        # Encode current operation doc and, if newly added, namespace doc.
        op_doc_encoded = _dict_to_bson(op_doc, False, opts)
        op_length = len(op_doc_encoded)
        if ns_doc:
            ns_doc_encoded = _dict_to_bson(ns_doc, False, opts)
            ns_length = len(ns_doc_encoded)

        # Check operation document size if unacknowledged.
        if not ack:
            _check_doc_size_limits(op_type, op_length, max_bson_size + _COMMAND_OVERHEAD)

        new_message_size = total_ops_length + total_ns_length + op_length + ns_length
        # We have enough data, return this batch.
        if new_message_size > max_doc_sequences_bytes:
            # Not even the first operation fits: raise instead of sending an
            # empty batch.
            if idx == 0:
                _raise_document_too_large(op_type, op_length, max_bson_size + _COMMAND_OVERHEAD)
            break

        # Add op and ns documents to this batch.
        to_send_ops.append(op_doc)
        to_send_ops_encoded.append(op_doc_encoded)
        total_ops_length += op_length
        if ns_doc:
            to_send_ns.append(ns_doc)
            to_send_ns_encoded.append(ns_doc_encoded)
            total_ns_length += ns_length

        idx += 1

        # We have enough documents, return this batch.
        if idx == max_write_batch_size:
            break

    # Construct the entire OP_MSG.
    length = _client_construct_op_msg(
        command_encoded, to_send_ops_encoded, to_send_ns_encoded, ack, buf
    )

    return to_send_ops, to_send_ns, length


def _client_encode_batched_op_msg(
    command: Mapping[str, Any],
    operations: list[tuple[str, Mapping[str, Any]]],
    namespaces: list[str],
    ack: bool,
    opts: CodecOptions,
    ctx: _ClientBulkWriteContext,
) -> tuple[bytes, list[Mapping[str, Any]], list[Mapping[str, Any]]]:
    """Encode the body of the next client-level bulkWrite OP_MSG batch.

    No message header is written -- only what ``_client_batched_op_msg_impl``
    produces into a fresh buffer.

    :return: ``(body_bytes, ops_in_batch, ns_docs_in_batch)``.
    """
    out = _BytesIO()
    batch_ops, batch_ns, _body_length = _client_batched_op_msg_impl(
        command, operations, namespaces, ack, opts, ctx, out
    )
    return out.getvalue(), batch_ops, batch_ns


def _client_batched_op_msg_compressed(
    command: Mapping[str, Any],
    operations: list[tuple[str, Mapping[str, Any]]],
    namespaces: list[str],
    ack: bool,
    opts: CodecOptions,
    ctx: _ClientBulkWriteContext,
) -> tuple[int, bytes, list[Mapping[str, Any]], list[Mapping[str, Any]]]:
    """Build the next client-level bulkWrite batch as a compressed OP_MSG.

    The connection must have a compression context.

    :return: ``(request_id, message_bytes, ops_in_batch, ns_docs_in_batch)``.
    """
    body, batch_ops, batch_ns = _client_encode_batched_op_msg(
        command, operations, namespaces, ack, opts, ctx
    )

    compression_ctx = ctx.conn.compression_context
    assert compression_ctx is not None
    # 2013 is the OP_MSG opcode (see _OpMsg.OP_CODE).
    request_id, compressed = _compress(2013, body, compression_ctx)
    return request_id, compressed, batch_ops, batch_ns


def _client_batched_op_msg(
    command: Mapping[str, Any],
    operations: list[tuple[str, Mapping[str, Any]]],
    namespaces: list[str],
    ack: bool,
    opts: CodecOptions,
    ctx: _ClientBulkWriteContext,
) -> tuple[int, bytes, list[Mapping[str, Any]], list[Mapping[str, Any]]]:
    """Build the next client-level bulkWrite batch as an uncompressed OP_MSG.

    A 16-byte message header is reserved, the OP_MSG body is written after
    it, and the header's messageLength and requestID fields are then filled
    in. responseTo is 0 and opCode is 2013 (OP_MSG).

    :return: ``(request_id, message_bytes, ops_in_batch, ns_docs_in_batch)``.
    """
    out = _BytesIO()
    # messageLength and requestID placeholders (patched below) ...
    out.write(_ZERO_64)
    # ... then responseTo = 0 and opCode = 2013 (0x07dd, little-endian).
    out.write(b"\x00\x00\x00\x00\xdd\x07\x00\x00")

    batch_ops, batch_ns, total_length = _client_batched_op_msg_impl(
        command, operations, namespaces, ack, opts, ctx, out
    )

    request_id = _randint()
    # messageLength occupies bytes 0-3 and requestID bytes 4-7, so both can
    # be written back to back from the start of the buffer.
    out.seek(0)
    out.write(_pack_int(total_length))
    out.write(_pack_int(request_id))

    return request_id, out.getvalue(), batch_ops, batch_ns


def _client_do_batched_op_msg(
    command: MutableMapping[str, Any],
    operations: list[tuple[str, Mapping[str, Any]]],
    namespaces: list[str],
    opts: CodecOptions,
    ctx: _ClientBulkWriteContext,
) -> tuple[int, bytes, list[Mapping[str, Any]], list[Mapping[str, Any]]]:
    """Build the next client-level bulkWrite batch as an OP_MSG.

    Sets ``command["$db"]`` to ``"admin"`` (mutating ``command``), derives
    acknowledgement from the write concern, and compresses the message when
    the connection has a compression context.
    """
    command["$db"] = "admin"
    # Unacknowledged only when an explicit writeConcern carries a falsy "w".
    ack = bool(command["writeConcern"].get("w", 1)) if "writeConcern" in command else True
    build = (
        _client_batched_op_msg_compressed
        if ctx.conn.compression_context
        else _client_batched_op_msg
    )
    return build(command, operations, namespaces, ack, opts, ctx)


# End OP_MSG -----------------------------------------------------


def _encode_batched_write_command(
    namespace: str,
    operation: int,
    command: MutableMapping[str, Any],
    docs: list[Mapping[str, Any]],
    opts: CodecOptions,
    ctx: _BulkWriteContext,
) -> tuple[bytes, list[Mapping[str, Any]]]:
    """Encode the next batched insert, update, or delete command.

    Pure-Python implementation; the module swaps in the C extension's version
    when that extension is importable.

    :return: ``(message_bytes, docs_in_batch)``.
    """
    out = _BytesIO()
    batch, _message_length = _batched_write_command_impl(
        namespace, operation, command, docs, opts, ctx, out
    )
    return out.getvalue(), batch


# Prefer the C extension's encoder over the pure-Python one above when the
# extension imported successfully. ``_cmessage`` is only evaluated in that case.
_encode_batched_write_command = (
    _cmessage._encode_batched_write_command if _use_c else _encode_batched_write_command
)


def _batched_write_command_impl(
    namespace: str,
    operation: int,
    command: MutableMapping[str, Any],
    docs: list[Mapping[str, Any]],
    opts: CodecOptions,
    ctx: _BulkWriteContext,
    buf: _BytesIO,
) -> tuple[list[Mapping[str, Any]], int]:
    """Create a batched OP_QUERY write command.

    Writes into ``buf``: flags, the namespace as a C string, skip/limit, then
    ``command`` with a field selected by ``operation`` from ``_OP_MAP``
    appended inside it. Documents from ``docs`` are encoded into that field
    under the keys ``"0"``, ``"1"``, ... until the batch is full.

    The batch is full when, after at least one document has been added, the
    next document would bring the buffer to ``ctx.max_split_size`` bytes, or
    when ``ctx.max_write_batch_size`` documents have been added.

    :return: ``(to_send, length)`` -- the documents included in the batch,
        and ``buf.tell()`` once the message has been finalized.

    Raises InvalidOperation for an ``operation`` not in ``_OP_MAP``. A
    document whose encoding exceeds ``max_bson_size + _COMMAND_OVERHEAD`` is
    reported through ``_raise_document_too_large``.
    """
    max_bson_size = ctx.max_bson_size
    max_write_batch_size = ctx.max_write_batch_size
    # Max BSON object size + 16k - 2 bytes for ending NUL bytes.
    # Server guarantees there is enough room: SERVER-10643.
    max_cmd_size = max_bson_size + _COMMAND_OVERHEAD
    max_split_size = ctx.max_split_size

    # No options
    buf.write(_ZERO_32)
    # Namespace as C string
    buf.write(namespace.encode("utf8"))
    buf.write(_ZERO_8)
    # Skip: 0, Limit: -1
    buf.write(_SKIPLIM)

    # Where to write command document length
    command_start = buf.tell()
    buf.write(bson.encode(command))

    # Start of payload
    # Step back over the command document's trailing NUL so the batch field
    # is appended inside the command document; the closing bytes are written
    # again once the batch is complete.
    buf.seek(-1, 2)
    # Work around some Jython weirdness.
    buf.truncate()
    try:
        buf.write(_OP_MAP[operation])
    except KeyError:
        raise InvalidOperation("Unknown command") from None

    # Where to write list document length
    # NOTE(review): assumes every _OP_MAP entry ends with a 4-byte length
    # placeholder for the list -- confirm against _OP_MAP's definition.
    list_start = buf.tell() - 4
    to_send = []
    idx = 0
    for doc in docs:
        # Encode the current operation
        # List element keys are the decimal indices "0", "1", ...
        key = str(idx).encode("utf8")
        value = _dict_to_bson(doc, False, opts)
        # Is there enough room to add this document? max_cmd_size accounts for
        # the two trailing null bytes.
        doc_too_large = len(value) > max_cmd_size
        if doc_too_large:
            write_op = list(_FIELD_MAP.keys())[operation]
            _raise_document_too_large(write_op, len(value), max_bson_size)
        # The first document is always admitted; after that, stop once the
        # next one would bring the buffer to max_split_size, or once the
        # batch holds max_write_batch_size documents.
        enough_data = idx >= 1 and (buf.tell() + len(key) + len(value)) >= max_split_size
        enough_documents = idx >= max_write_batch_size
        if enough_data or enough_documents:
            break
        buf.write(_BSONOBJ)
        buf.write(key)
        buf.write(_ZERO_8)
        buf.write(value)
        to_send.append(doc)
        idx += 1

    # Finalize the current OP_QUERY message.
    # Close list and command documents
    buf.write(_ZERO_16)

    # Backfill the list and command document lengths (no request id is
    # written here).
    length = buf.tell()
    buf.seek(list_start)
    # The list spans from its length field through the first of the two
    # closing NUL bytes, hence the "- 1".
    buf.write(_pack_int(length - list_start - 1))
    buf.seek(command_start)
    buf.write(_pack_int(length - command_start))

    return to_send, length


class _OpReply:
    """A MongoDB OP_REPLY response message."""

    __slots__ = ("flags", "cursor_id", "number_returned", "documents")

    # responseFlags, cursorID, startingFrom, numberReturned (little-endian).
    UNPACK_FROM = struct.Struct("<iqii").unpack_from
    OP_CODE = 1

    def __init__(self, flags: int, cursor_id: int, number_returned: int, documents: bytes):
        self.flags = flags
        self.cursor_id = Int64(cursor_id)
        self.number_returned = number_returned
        self.documents = documents

    def raw_response(
        self, cursor_id: Optional[int] = None, user_fields: Optional[Mapping[str, Any]] = None
    ) -> list[bytes]:
        """Check the response header from the database, without decoding BSON.

        Check the response for errors and unpack.

        Can raise CursorNotFound, NotPrimaryError, ExecutionTimeout, or
        OperationFailure.

        :param cursor_id: cursor_id we sent to get this response -
            used for raising an informative exception when we get cursor id not
            valid at server response.
        :param user_fields: Unused; accepted for signature compatibility with
            ``_OpMsg.raw_response``.

        :return: a one-element list holding the raw documents bytes, or an
            empty list when the reply carries no documents.
        """
        if self.flags & 1:
            # Cursor-not-found bit. Shouldn't get this response if we aren't
            # doing a getMore.
            if cursor_id is None:
                raise ProtocolError("No cursor id for getMore operation")

            # Fake a getMore command response. OP_GET_MORE provides no
            # document.
            msg = "Cursor not found, cursor id: %d" % (cursor_id,)
            errobj = {"ok": 0, "errmsg": msg, "code": 43}
            raise CursorNotFound(msg, 43, errobj)
        elif self.flags & 2:
            # Query-failure bit: the reply's document describes the error.
            error_object: dict = bson.BSON(self.documents).decode()
            # Fake the ok field if it doesn't exist.
            error_object.setdefault("ok", 0)
            # Read "$err" with .get() so a reply lacking it is reported as an
            # OperationFailure/ExecutionTimeout below instead of escaping as
            # a KeyError (which also made the ExecutionTimeout default
            # message unreachable).
            errmsg = error_object.get("$err")
            if isinstance(errmsg, str) and errmsg.startswith(HelloCompat.LEGACY_ERROR):
                raise NotPrimaryError(errmsg, error_object)
            elif error_object.get("code") == 50:
                default_msg = "operation exceeded time limit"
                raise ExecutionTimeout(
                    error_object.get("$err", default_msg), error_object.get("code"), error_object
                )
            raise OperationFailure(
                "database error: %s" % error_object.get("$err"),
                error_object.get("code"),
                error_object,
            )
        if self.documents:
            return [self.documents]
        return []

    def unpack_response(
        self,
        cursor_id: Optional[int] = None,
        codec_options: CodecOptions = _UNICODE_REPLACE_CODEC_OPTIONS,
        user_fields: Optional[Mapping[str, Any]] = None,
        legacy_response: bool = False,
    ) -> list[dict[str, Any]]:
        """Unpack a response from the database and decode the BSON document(s).

        Check the response for errors and unpack, returning a dictionary
        containing the response data.

        Can raise CursorNotFound, NotPrimaryError, ExecutionTimeout, or
        OperationFailure.

        :param cursor_id: cursor_id we sent to get this response -
            used for raising an informative exception when we get cursor id not
            valid at server response
        :param codec_options: an instance of
            :class:`~bson.codec_options.CodecOptions`
        :param user_fields: Response fields that should be decoded
            using the TypeDecoders from codec_options, passed to
            bson._decode_all_selective.
        :param legacy_response: when True, decode every document with
            ``bson.decode_all`` instead of decoding selectively.
        """
        self.raw_response(cursor_id)
        if legacy_response:
            return bson.decode_all(self.documents, codec_options)
        return bson._decode_all_selective(self.documents, codec_options, user_fields)

    def command_response(self, codec_options: CodecOptions) -> dict[str, Any]:
        """Unpack a command response."""
        docs = self.unpack_response(codec_options=codec_options)
        assert self.number_returned == 1
        return docs[0]

    def raw_command_response(self) -> NoReturn:
        """Return the bytes of the command response."""
        # This should never be called on _OpReply.
        raise NotImplementedError

    @property
    def more_to_come(self) -> bool:
        """Is the moreToCome bit set on this response?"""
        return False

    @classmethod
    def unpack(cls, msg: bytes) -> _OpReply:
        """Construct an _OpReply from raw bytes."""
        # PYTHON-945: ignore starting_from field.
        flags, cursor_id, _, number_returned = cls.UNPACK_FROM(msg)

        # The four fixed fields parsed above occupy the first 20 bytes.
        documents = msg[20:]
        return cls(flags, cursor_id, number_returned, documents)


class _OpMsg:
    """A MongoDB OP_MSG response message."""

    __slots__ = ("flags", "cursor_id", "number_returned", "payload_document")

    # flagBits (uint32), kind of the first section (uint8), and the int32
    # length prefix of that section's BSON document.
    UNPACK_FROM = struct.Struct("<IBi").unpack_from
    OP_CODE = 2013

    # Flag bits.
    CHECKSUM_PRESENT = 1
    MORE_TO_COME = 1 << 1
    EXHAUST_ALLOWED = 1 << 16  # Only present on requests.

    def __init__(self, flags: int, payload_document: bytes):
        self.flags = flags
        self.payload_document = payload_document

    def raw_response(
        self,
        cursor_id: Optional[int] = None,
        user_fields: Optional[Mapping[str, Any]] = None,
    ) -> list[Mapping[str, Any]]:
        """Return the reply document, selectively decoded.

        :param cursor_id: Ignored, for compatibility with _OpReply.
        :param user_fields: Passed to ``bson._decode_selective`` to select
            which fields of the raw reply are decoded. ``None`` (the default)
            is treated as an empty mapping.

        :return: a one-element list holding the selectively decoded reply.
        """
        # Avoid a shared mutable default, and accept an explicit None (which
        # the annotation allows) instead of failing inside _decode_selective.
        fields: Mapping[str, Any] = {} if user_fields is None else user_fields
        inflated_response = bson._decode_selective(
            RawBSONDocument(self.payload_document), fields, _RAW_ARRAY_BSON_OPTIONS
        )
        return [inflated_response]

    def unpack_response(
        self,
        cursor_id: Optional[int] = None,
        codec_options: CodecOptions = _UNICODE_REPLACE_CODEC_OPTIONS,
        user_fields: Optional[Mapping[str, Any]] = None,
        legacy_response: bool = False,
    ) -> list[dict[str, Any]]:
        """Unpack a OP_MSG command response.

        :param cursor_id: Ignored, for compatibility with _OpReply.
        :param codec_options: an instance of
            :class:`~bson.codec_options.CodecOptions`
        :param user_fields: Response fields that should be decoded
            using the TypeDecoders from codec_options, passed to
            bson._decode_all_selective.
        :param legacy_response: must be False; OP_MSG replies are never
            legacy responses.
        """
        # If _OpMsg is in-use, this cannot be a legacy response.
        assert not legacy_response
        return bson._decode_all_selective(self.payload_document, codec_options, user_fields)

    def command_response(self, codec_options: CodecOptions) -> dict[str, Any]:
        """Unpack a command response."""
        return self.unpack_response(codec_options=codec_options)[0]

    def raw_command_response(self) -> bytes:
        """Return the bytes of the command response."""
        return self.payload_document

    @property
    def more_to_come(self) -> bool:
        """Is the moreToCome bit set on this response?"""
        return bool(self.flags & self.MORE_TO_COME)

    @classmethod
    def unpack(cls, msg: bytes) -> _OpMsg:
        """Construct an _OpMsg from raw bytes.

        Only replies consisting of a single type 0 section are accepted, with
        no flag bit set other than moreToCome.

        :raises ProtocolError: if checksumPresent or any other unsupported
            flag is set, if the first section is not type 0, or if the reply
            contains more than one section.
        """
        flags, first_payload_type, first_payload_size = cls.UNPACK_FROM(msg)
        if flags != 0:
            if flags & cls.CHECKSUM_PRESENT:
                raise ProtocolError(f"Unsupported OP_MSG flag checksumPresent: 0x{flags:x}")

            # With checksumPresent ruled out, moreToCome is the only flag a
            # reply may carry.
            if flags != cls.MORE_TO_COME:
                raise ProtocolError(f"Unsupported OP_MSG flags: 0x{flags:x}")
        if first_payload_type != 0:
            raise ProtocolError(f"Unsupported OP_MSG payload type: 0x{first_payload_type:x}")

        # 4 flag bytes and 1 kind byte precede the document; any bytes after
        # the document would belong to additional sections.
        if len(msg) != first_payload_size + 5:
            raise ProtocolError("Unsupported OP_MSG reply: >1 section")

        payload_document = msg[5:]
        return cls(flags, payload_document)


# Maps a reply's wire-protocol opcode to the classmethod that parses its body
# (1 -> OP_REPLY, 2013 -> OP_MSG).
_UNPACK_REPLY: dict[int, Callable[[bytes], Union[_OpReply, _OpMsg]]] = {}
_UNPACK_REPLY[_OpReply.OP_CODE] = _OpReply.unpack
_UNPACK_REPLY[_OpMsg.OP_CODE] = _OpMsg.unpack


class _Query:
    """A query operation."""

    __slots__ = (
        "flags",
        "db",
        "coll",
        "ntoskip",
        "spec",
        "fields",
        "codec_options",
        "read_preference",
        "limit",
        "batch_size",
        "name",
        "read_concern",
        "collation",
        "session",
        "client",
        "allow_disk_use",
        "_as_command",
        "exhaust",
    )

    # For compatibility with the _GetMore class.
    conn_mgr = None
    cursor_id = None

    def __init__(
        self,
        flags: int,
        db: str,
        coll: str,
        ntoskip: int,
        spec: Mapping[str, Any],
        fields: Optional[Mapping[str, Any]],
        codec_options: CodecOptions,
        read_preference: _ServerMode,
        limit: int,
        batch_size: int,
        read_concern: ReadConcern,
        collation: Optional[Mapping[str, Any]],
        session: Optional[_AgnosticClientSession],
        client: _AgnosticMongoClient,
        allow_disk_use: Optional[bool],
        exhaust: bool,
    ):
        self.flags = flags
        self.db = db
        self.coll = coll
        self.ntoskip = ntoskip
        self.spec = spec
        self.fields = fields
        self.codec_options = codec_options
        self.read_preference = read_preference
        self.read_concern = read_concern
        self.limit = limit
        self.batch_size = batch_size
        self.collation = collation
        self.session = session
        self.client = client
        self.allow_disk_use = allow_disk_use
        self.name = "find"
        self._as_command: Optional[tuple[dict[str, Any], str]] = None
        self.exhaust = exhaust

    def reset(self) -> None:
        self._as_command = None

    def namespace(self) -> str:
        return f"{self.db}.{self.coll}"

    def use_command(self, conn: _AgnosticConnection) -> bool:
        use_find_cmd = False
        if not self.exhaust:
            use_find_cmd = True
        elif conn.max_wire_version >= 8:
            # OP_MSG supports exhaust on MongoDB 4.2+
            use_find_cmd = True
        elif not self.read_concern.ok_for_legacy:
            raise ConfigurationError(
                "read concern level of %s is not valid "
                "with a max wire version of %d." % (self.read_concern.level, conn.max_wire_version)
            )

        conn.validate_session(self.client, self.session)  # type: ignore[arg-type]
        return use_find_cmd

    def update_command(self, cmd: dict[str, Any]) -> None:
        self._as_command = cmd, self.db

    def as_command(
        self, conn: _AgnosticConnection, apply_timeout: bool = False
    ) -> tuple[dict[str, Any], str]:
        """Return a find command document for this query."""
        # We use the command twice: on the wire and for command monitoring.
        # Generate it once, for speed and to avoid repeating side-effects.
        if self._as_command is not None:
            return self._as_command

        explain = "$explain" in self.spec
        cmd: dict[str, Any] = _gen_find_command(
            self.coll,
            self.spec,
            self.fields,
            self.ntoskip,
            self.limit,
            self.batch_size,
            self.flags,
            self.read_concern,
            self.collation,
            self.session,
            self.allow_disk_use,
        )
        if explain:
            self.name = "explain"
            cmd = {"explain": cmd}
        conn.add_server_api(cmd)
        if self.session:
            self.session._apply_to(cmd, False, self.read_preference, conn)  # type: ignore[arg-type]
            # Explain does not support readConcern.
            if not explain and not self.session.in_transaction:
                self.session._update_read_concern(cmd, conn)  # type: ignore[arg-type]
        conn.send_cluster_time(cmd, self.session, self.client)  # type: ignore[arg-type]
        # Support CSOT
        if apply_timeout:
            conn.apply_timeout(self.client, cmd=cmd)  # type: ignore[arg-type]
        self._as_command = cmd, self.db
        return self._as_command

    def get_message(
        self, read_preference: _ServerMode, conn: _AgnosticConnection, use_cmd: bool = False
    ) -> tuple[int, bytes, int]:
        """Get a query message, possibly setting the secondaryOk bit."""
        # Use the read_preference decided by _socket_from_server.
        self.read_preference = read_preference
        if read_preference.mode:
            # Set the secondaryOk bit.
            flags = self.flags | 4
        else:
            flags = self.flags

        ns = self.namespace()
        spec = self.spec

        if use_cmd:
            spec = self.as_command(conn)[0]
            request_id, msg, size, _ = _op_msg(
                0,
                spec,
                self.db,
                read_preference,
                self.codec_options,
                ctx=conn.compression_context,
            )
            return request_id, msg, size

        # OP_QUERY treats ntoreturn of -1 and 1 the same, return
        # one document and close the cursor. We have to use 2 for
        # batch size if 1 is specified.
        ntoreturn = self.batch_size == 1 and 2 or self.batch_size
        if self.limit:
            if ntoreturn:
                ntoreturn = min(self.limit, ntoreturn)
            else:
                ntoreturn = self.limit

        if conn.is_mongos:
            assert isinstance(spec, MutableMapping)
            spec = _maybe_add_read_preference(spec, read_preference)

        return _query(
            flags,
            ns,
            self.ntoskip,
            ntoreturn,
            spec,
            None if use_cmd else self.fields,
            self.codec_options,
            ctx=conn.compression_context,
        )


class _GetMore:
    """A getmore operation."""

    __slots__ = (
        "db",
        "coll",
        "ntoreturn",
        "cursor_id",
        "max_await_time_ms",
        "codec_options",
        "read_preference",
        "session",
        "client",
        "conn_mgr",
        "_as_command",
        "exhaust",
        "comment",
    )

    name = "getMore"

    def __init__(
        self,
        db: str,
        coll: str,
        ntoreturn: int,
        cursor_id: int,
        codec_options: CodecOptions,
        read_preference: _ServerMode,
        session: Optional[_AgnosticClientSession],
        client: _AgnosticMongoClient,
        max_await_time_ms: Optional[int],
        conn_mgr: Any,
        exhaust: bool,
        comment: Any,
    ):
        self.db = db
        self.coll = coll
        self.ntoreturn = ntoreturn
        self.cursor_id = cursor_id
        self.codec_options = codec_options
        self.read_preference = read_preference
        self.session = session
        self.client = client
        self.max_await_time_ms = max_await_time_ms
        self.conn_mgr = conn_mgr
        self._as_command: Optional[tuple[dict[str, Any], str]] = None
        self.exhaust = exhaust
        self.comment = comment

    def reset(self) -> None:
        self._as_command = None

    def namespace(self) -> str:
        return f"{self.db}.{self.coll}"

    def use_command(self, conn: _AgnosticConnection) -> bool:
        use_cmd = False
        if not self.exhaust:
            use_cmd = True
        elif conn.max_wire_version >= 8:
            # OP_MSG supports exhaust on MongoDB 4.2+
            use_cmd = True

        conn.validate_session(self.client, self.session)  # type: ignore[arg-type]
        return use_cmd

    def update_command(self, cmd: dict[str, Any]) -> None:
        self._as_command = cmd, self.db

    def as_command(
        self, conn: _AgnosticConnection, apply_timeout: bool = False
    ) -> tuple[dict[str, Any], str]:
        """Return a getMore command document for this query."""
        # See _Query.as_command for an explanation of this caching.
        if self._as_command is not None:
            return self._as_command

        cmd: dict[str, Any] = _gen_get_more_command(
            self.cursor_id,
            self.coll,
            self.ntoreturn,
            self.max_await_time_ms,
            self.comment,
            conn,
        )
        if self.session:
            self.session._apply_to(cmd, False, self.read_preference, conn)  # type: ignore[arg-type]
        conn.add_server_api(cmd)
        conn.send_cluster_time(cmd, self.session, self.client)  # type: ignore[arg-type]
        # Support CSOT
        if apply_timeout:
            conn.apply_timeout(self.client, cmd=None)  # type: ignore[arg-type]
        self._as_command = cmd, self.db
        return self._as_command

    def get_message(
        self, dummy0: Any, conn: _AgnosticConnection, use_cmd: bool = False
    ) -> Union[tuple[int, bytes, int], tuple[int, bytes]]:
        """Get a getmore message."""
        ns = self.namespace()
        ctx = conn.compression_context

        if use_cmd:
            spec = self.as_command(conn)[0]
            if self.conn_mgr and self.exhaust:
                flags = _OpMsg.EXHAUST_ALLOWED
            else:
                flags = 0
            request_id, msg, size, _ = _op_msg(
                flags, spec, self.db, None, self.codec_options, ctx=conn.compression_context
            )
            return request_id, msg, size

        return _get_more(ns, self.ntoreturn, self.cursor_id, ctx)


class _RawBatchQuery(_Query):
    """``_Query`` variant with its own ``use_command`` decision."""

    def use_command(self, conn: _AgnosticConnection) -> bool:
        """Return True to use the find command, False for legacy OP_QUERY."""
        # Run the base-class compatibility checks (they may raise); their
        # return value is recomputed below.
        super().use_command(conn)
        # MongoDB 4.2+ (wire version 8) supports exhaust over OP_MSG.
        return conn.max_wire_version >= 8 or not self.exhaust


class _RawBatchGetMore(_GetMore):
    """``_GetMore`` variant with its own ``use_command`` decision."""

    def use_command(self, conn: _AgnosticConnection) -> bool:
        """Return True to use the getMore command, False for the legacy form."""
        # Run the base-class compatibility checks (e.g. session validation);
        # their return value is recomputed below.
        super().use_command(conn)
        # MongoDB 4.2+ (wire version 8) supports exhaust over OP_MSG.
        return conn.max_wire_version >= 8 or not self.exhaust


class _CursorAddress(tuple):
    """The server address (host, port) of a cursor, with namespace property."""

    __namespace: Any

    def __new__(cls, address: _Address, namespace: str) -> _CursorAddress:
        self = tuple.__new__(cls, address)
        self.__namespace = namespace
        return self

    @property
    def namespace(self) -> str:
        """The namespace this cursor."""
        return self.__namespace

    def __hash__(self) -> int:
        # Two _CursorAddress instances with different namespaces
        # must not hash the same.
        return ((*self, self.__namespace)).__hash__()

    def __eq__(self, other: object) -> bool:
        if isinstance(other, _CursorAddress):
            return tuple(self) == tuple(other) and self.namespace == other.namespace
        return NotImplemented

    def __ne__(self, other: object) -> bool:
        return not self == other
