import abc
import datetime
import json
from collections import Counter, defaultdict
from typing import Any, Generic, Optional, Sequence, TypeVar, Union

import grpc
from pydantic import BaseModel
from typing_extensions import Self

from .meta import ProtoDecorator
from .proto import chat_pb2, chat_pb2_grpc, image_pb2, sample_pb2, usage_pb2
from .search import SearchParameters
from .telemetry import should_disable_sensitive_attributes
from .types import (
    ChatModel,
    Content,
    ImageDetail,
    IncludeOption,
    IncludeOptionMap,
    ReasoningEffort,
    ResponseFormat,
    ToolMode,
)

# Generic type of the object returned by `BaseClient.create` (produced by the subclass's `_make_chat`).
T = TypeVar("T")


class BaseClient(abc.ABC, Generic[T]):
    """Base Client for interacting with the `Chat` API.

    Subclasses implement `_make_chat`, which decides what kind of object (`T`) `create` returns.
    """

    _stub: chat_pb2_grpc.ChatStub

    def __init__(self, channel: Union[grpc.Channel, grpc.aio.Channel]):
        """Creates a new client based on a gRPC channel.

        Args:
            channel: A synchronous (`grpc.Channel`) or asyncio (`grpc.aio.Channel`) channel used to build the
                `Chat` service stub.
        """
        self._stub = chat_pb2_grpc.ChatStub(channel)

    def create(
        self,
        model: Union[ChatModel, str],
        *,
        conversation_id: Optional[str] = None,
        messages: Optional[Sequence[chat_pb2.Message]] = None,
        user: Optional[str] = None,
        max_tokens: Optional[int] = None,
        seed: Optional[int] = None,
        stop: Optional[Sequence[str]] = None,
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        logprobs: Optional[bool] = None,
        top_logprobs: Optional[int] = None,
        tools: Optional[Sequence[chat_pb2.Tool]] = None,
        tool_choice: Optional[Union[ToolMode, chat_pb2.ToolChoice]] = None,
        parallel_tool_calls: Optional[bool] = None,
        response_format: Optional[Union[ResponseFormat, chat_pb2.ResponseFormat, type[BaseModel]]] = None,
        frequency_penalty: Optional[float] = None,
        presence_penalty: Optional[float] = None,
        reasoning_effort: Optional[Union[ReasoningEffort, "chat_pb2.ReasoningEffort"]] = None,
        search_parameters: Optional[Union[SearchParameters, chat_pb2.SearchParameters]] = None,
        store_messages: Optional[bool] = None,
        previous_response_id: Optional[str] = None,
        use_encrypted_content: Optional[bool] = None,
        max_turns: Optional[int] = None,
        include: Optional[Sequence[Union[IncludeOption, "chat_pb2.IncludeOption"]]] = None,
        batch_request_id: Optional[str] = None,
    ) -> T:
        """Creates a new chat conversation.

        This function does not immediately perform an RPC. It only initializes a mutable request
        instance.

        Examples:
            ```
            chat = client.chat.create(
                model="grok-3-latest",
                messages=[
                    system("You are a pirate"),
                    user("How are you?"),
                ]
            )
            response = chat.sample()
            chat.append(response)
            print(response)

            chat.append(user("Tell me a joke"))
            response = chat.sample()
            print(response)
            ```

        Args:
            model: Model to use, e.g. "grok-3-latest".
            conversation_id: Optional ID to group all messages in this chat instance.
                When provided, this ID is added as a span attribute ("gen_ai.conversation.id") to all OpenTelemetry
                spans generated by this chat instance (e.g., calls to `sample`). This enables easy
                identification and grouping of spans from the same conversation for monitoring and debugging.
            messages: A list of messages that make up the chat conversation. Different models support different
                message types, such as image and text.
            user: A unique identifier representing your end-user, which can help xAI to monitor and detect abuse.
            max_tokens: The maximum number of tokens that can be generated in the chat completion.
            seed: If specified, our system will make a best effort to sample deterministically, such that repeated
                requests with the same seed and parameters should return the same result.
            stop: Up to 4 sequences where the API will stop generating further tokens.
            temperature: What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output
                more random, while lower values like 0.2 will make it more focused and deterministic.
            top_p: An alternative to sampling with temperature, called nucleus sampling, where the model considers the
                results of the tokens with `top_p` probability mass.
            logprobs: Whether to return log probabilities of the output tokens or not. If true, returns the log
                probabilities of each output token returned in the content of message.
            top_logprobs: An integer between 0 and 8 specifying the number of most likely tokens to return at each token
                position, each with an associated log probability. logprobs must be set to true if this parameter is
                used.
            tools: A list of tools the model may call in JSON-schema. Currently, only functions are supported as a tool.
                Use this to provide a list of functions the model may generate JSON inputs for. A max of 128 functions
                are supported.
            tool_choice: Controls which (if any) tool is called by the model. `none` means the model will not call any
                tool and instead generates a message. `auto` means the model can pick between generating a message or
                calling one or more tools. `required` means the model must call one or more tools.
                `auto` is the default if tools are present.
                To force the model to always invoke a specific tool, use the `required_tool(name)` function to create a
                `ToolChoice` with mode `required` and specify the tool's name. For example:
                ```
                tool_choice=required_tool("get_weather")
                ```
                This ensures the model must call the specified tool from the list provided in `tools`.
            parallel_tool_calls: If set to false, the model can perform maximum one tool call per response.
                Defaults to true.
            response_format: An object specifying the format that the model must output.
                `json_object` means the model must output a JSON object (although the shape is arbitrary).
                `text` means the model must output a text string.
                A Pydantic model class is converted into a JSON-schema response format built from its
                `model_json_schema()`. To force the model to output a JSON object adhering to a specific JSON Schema
                you can also use the client's `parse` method instead.
            frequency_penalty: Positive values penalize new tokens based on their existing frequency in the text so far,
                decreasing the model's likelihood to repeat the same line verbatim.
            presence_penalty: Positive values penalize new tokens based on whether they appear in the text so far,
                increasing the model's likelihood to talk about new topics.
            reasoning_effort: Constrains how hard a reasoning model thinks before responding. Possible values are `low`
                (uses fewer reasoning tokens) and `high` (uses more reasoning tokens). Defaults to `low`.
            search_parameters: The parameters that control search behavior.
                This includes settings like search mode, date range, sources (e.g., web, news, or X), and whether
                to return citations. See `SearchParameters` for detailed configuration options.
            store_messages: Whether to store responses generated by this chat instance on the xAI backend.
                When set to True, each response is persisted with a unique response ID (accessible via
                `response.id`), enabling retrieval via `client.chat.get_stored_completion()` and deletion
                via `client.chat.delete_stored_completion()`. Stored responses can be referenced by
                `previous_response_id` in new chat instances to branch conversations from specific points.
                Defaults to False. Note: Teams with ZDR enabled will not be able to use this feature.
            previous_response_id: The ID of a previously stored response (i.e., `response.id`) to use as the starting
                point for this conversation. When provided, the entire conversation history up to that response is
                prepended to the current chat instance, enabling conversation branching and continuation from any
                stored point. This prepending happens automatically on the server side and is opaque to the user.
                Note: Adding `previous_response_id` to a chat instance will not update the local message list of the
                chat instance (i.e., `chat.messages`).
            use_encrypted_content: Whether to return encrypted reasoning content from the model response.
                When enabled, encrypted reasoning content is included in responses, enabling optimal hydration
                of reasoning traces in follow-up conversations. Calling `append(response)` automatically passes
                this encrypted content back for subsequent requests. This is particularly useful for users with
                zero data retention (ZDR) enabled who cannot use `store_messages` and `previous_response_id`
                for conversation continuity. Defaults to False.
            max_turns: The maximum number of agentic turns the model can take. When set, the model will automatically
                iterate up to this many turns, calling tools and processing their results until it reaches a final
                answer or hits the turn limit. Defaults to server-side maximum. Note: This parameter has no effect
                on non-agentic requests (i.e. requests that do not use server-side tools). With parallel tool calls
                enabled, multiple tool calls can occur within a single turn, so max_turns does not necessarily equal
                the total number of tool calls.
            include: A list of output options to include in the response.
                Check the `IncludeOption` enum for all possible values.
                Defaults to None.
            batch_request_id: An optional user-provided identifier for the batch request. **If provided, it must be
                unique within the batch.** Used to identify the corresponding result when the response is returned
                to the user.

        Returns:
            A new chat request bound to a client (whatever the subclass's `_make_chat` builds).
        """
        # Convert the user-friendly forms (string literals, SDK helper objects, Pydantic classes) into their proto
        # equivalents. Values that are already protos, and None, are passed through unchanged.
        tool_choice_pb: Optional[chat_pb2.ToolChoice] = (
            chat_pb2.ToolChoice(mode=_tool_mode_to_proto(tool_choice)) if isinstance(tool_choice, str) else tool_choice
        )

        response_format_pb: Optional[chat_pb2.ResponseFormat]
        if isinstance(response_format, str):
            response_format_pb = chat_pb2.ResponseFormat(format_type=_format_type_to_proto(response_format))
        elif isinstance(response_format, type) and issubclass(response_format, BaseModel):
            # A Pydantic model class becomes a JSON-schema constrained response format.
            response_format_pb = chat_pb2.ResponseFormat(
                format_type=chat_pb2.FORMAT_TYPE_JSON_SCHEMA,
                schema=json.dumps(response_format.model_json_schema()),
            )
        else:
            response_format_pb = response_format

        reasoning_effort_pb: Optional[chat_pb2.ReasoningEffort] = (
            _reasoning_effort_to_proto(reasoning_effort) if isinstance(reasoning_effort, str) else reasoning_effort
        )

        search_parameters_pb: Optional[chat_pb2.SearchParameters] = (
            search_parameters._to_proto() if isinstance(search_parameters, SearchParameters) else search_parameters
        )

        # Each include entry may independently be a string literal or an already-converted proto enum value.
        include_pb: Optional[Sequence[chat_pb2.IncludeOption]] = (
            None
            if include is None
            else [_include_option_to_proto(option) if isinstance(option, str) else option for option in include]
        )

        return self._make_chat(
            conversation_id=conversation_id,
            batch_request_id=batch_request_id,
            model=model,
            messages=messages,
            user=user,
            max_tokens=max_tokens,
            seed=seed,
            stop=stop,
            temperature=temperature,
            top_p=top_p,
            logprobs=logprobs,
            top_logprobs=top_logprobs,
            tools=tools,
            tool_choice=tool_choice_pb,
            parallel_tool_calls=parallel_tool_calls,
            response_format=response_format_pb,
            frequency_penalty=frequency_penalty,
            presence_penalty=presence_penalty,
            reasoning_effort=reasoning_effort_pb,
            search_parameters=search_parameters_pb,
            store_messages=store_messages,
            previous_response_id=previous_response_id,
            use_encrypted_content=use_encrypted_content,
            max_turns=max_turns,
            include=include_pb,
        )

    @abc.abstractmethod
    def _make_chat(self, conversation_id: Optional[str], batch_request_id: Optional[str], **settings) -> T:
        """Creates the proto wrapper for chat requests.

        Called by `create` once all options have been normalized.

        Args:
            conversation_id: Optional ID used to group telemetry spans of this chat instance.
            batch_request_id: Optional user-provided ID of the request within a batch.
            **settings: The remaining keyword options forwarded from `create` (model, messages, sampling options,
                ...), with string/Pydantic forms already converted to their proto equivalents.

        Returns:
            The chat object that `create` hands back to the caller.
        """


class BaseChat(ProtoDecorator[chat_pb2.GetCompletionsRequest]):
    """Utility class for simplifying the interaction with Chat requests and responses.

    Wraps a mutable `chat_pb2.GetCompletionsRequest` (handed to `ProtoDecorator` and accessed as `self._proto`)
    together with the gRPC stub used to send it. Provides helpers to extend the conversation, build per-call
    request protos, and derive OpenTelemetry span attributes from requests and responses.
    """

    _stub: chat_pb2_grpc.ChatStub

    def __init__(
        self,
        stub: chat_pb2_grpc.ChatStub,
        conversation_id: Optional[str],
        batch_request_id: Optional[str],
        **settings,
    ) -> None:
        """Prepares a new chat request.

        Args:
            stub: gRPC stub used to connect to the server.
            conversation_id: The ID of the conversation. When set, it is recorded on request spans as
                "gen_ai.conversation.id".
            batch_request_id: The ID of the batch request, should only be set when creating a chat completion
                for a batch request.
            **settings: See `chat_pb2.GetCompletionsRequest`.
        """
        super().__init__(chat_pb2.GetCompletionsRequest(**settings))
        self._stub = stub
        self._conversation_id = conversation_id
        self._batch_request_id = batch_request_id

    def append(self, message: Union[chat_pb2.Message, "Response"]) -> Self:
        """Adds a new message to the conversation history, enabling multi-turn interactions.

        This method appends a message to the chat's message sequence, which can be a user input,
        system prompt, assistant response, or tool result. It supports both `chat_pb2.Message`
        objects (created via helper functions like `user`, `system`, `assistant`, or `tool_result`) and
        `Response` objects from previous chat interactions. The method returns the chat object
        itself, allowing for method chaining.

        Examples:
            ```
            # Adding a simple user message
            chat = client.chat.create(model="grok-3")
            chat.append(user("How are you?"))
            ```

            ```
            # Adding a system prompt with an image
            chat = client.chat.create(model="grok-3")
            chat.append(system(
                "Analyze the following image: ",
                image("https://example.com/image.jpg", detail="high"),
                "Provide a detailed description."
            ))
            ```

            ```
            # Appending an assistant's response in a multi-turn conversation
            chat = client.chat.create(model="grok-3")
            chat.append(user("Tell me a pirate joke."))
            response = chat.sample()  # Get assistant's response
            print(f"Grok: {response.content}")
            chat.append(response)  # Add assistant's response to history
            chat.append(user("Another one, please!"))
            ```

            ```
            # Multi-turn chat loop (inspired by sync/chat.py)
            chat = client.chat.create(
                model="grok-3",
                messages=[system("You talk like a pirate.")]
            )
            while True:
                prompt = input("You: ")
                if prompt.lower() == "exit":
                    break
                chat.append(user(prompt))
                response = chat.sample()
                print(f"Grok: {response.content}")
                chat.append(response)
            ```

        Args:
            message: The message to append, either a `chat_pb2.Message` (e.g., created by `user`,
                `system`, `assistant`, or `tool_result`) or a `Response` object from a previous
                chat interaction.

        Returns:
            Self: The chat object, enabling method chaining.

        Raises:
            ValueError: If `message` is neither a `chat_pb2.Message` nor a `Response`.
        """
        if isinstance(message, chat_pb2.Message):
            self._proto.messages.append(message)
        elif isinstance(message, Response):
            if message._index is None:
                # Every single output should be appended for agentic tool call responses.
                for output in message.proto.outputs:
                    self._proto.messages.append(
                        chat_pb2.Message(
                            role=output.message.role,
                            content=[text(output.message.content)],
                            reasoning_content=output.message.reasoning_content,
                            encrypted_content=output.message.encrypted_content,
                            tool_calls=output.message.tool_calls,
                        )
                    )
            else:
                # Single-output response: append only the output this Response is bound to. The encrypted
                # content is carried over so it is passed back on the next request.
                self._proto.messages.append(
                    chat_pb2.Message(
                        role=message._get_output().message.role,
                        content=[text(message.content)],
                        reasoning_content=message.reasoning_content,
                        encrypted_content=message.encrypted_content,
                        tool_calls=message.tool_calls,
                    )
                )
        else:
            raise ValueError("Unrecognized message type.")
        return self

    def _make_request(self, n: int) -> chat_pb2.GetCompletionsRequest:
        """Creates a request proto.

        The returned proto is an independent copy of the accumulated chat state, so setting `n` on it does not
        modify `self._proto`.

        Args:
            n: Number of completions to generate.

        Returns:
            A copy of the wrapped request with `n` set.

        Raises:
            ValueError: If the conversation contains no messages.
        """
        # Prevent requests with no messages; validate before allocating the copy.
        if not self._proto.messages:
            raise ValueError(
                "Cannot create a completion request: No messages provided. Please include at least one "
                "message (e.g., using user(), system(), or assistant()) in the request."
            )

        request = chat_pb2.GetCompletionsRequest()
        request.CopyFrom(self._proto)
        request.n = n
        return request

    def _make_span_request_attributes(self) -> dict[str, Any]:  # noqa: C901, PLR0912
        """Creates a dictionary with all relevant request attributes to be set on the span as it is created.

        Only the non-sensitive base attributes (operation, provider, model, server) are returned when
        `should_disable_sensitive_attributes()` is true.
        """
        attributes: dict[str, Any] = {
            "gen_ai.operation.name": "chat",
            "gen_ai.provider.name": "xai",
            "gen_ai.output.type": "text",
            "gen_ai.request.model": self._proto.model,
            "server.port": 443,
            "server.address": "api.x.ai",
        }

        if should_disable_sensitive_attributes():
            return attributes

        # Initialize optional fields to their default values (as set server-side).
        # Override with user-set values only if they are provided by the user.
        attributes["gen_ai.request.frequency_penalty"] = 0.0
        attributes["gen_ai.request.presence_penalty"] = 0.0
        attributes["gen_ai.request.temperature"] = 1.0
        attributes["gen_ai.request.parallel_tool_calls"] = True
        attributes["gen_ai.request.store_messages"] = False
        attributes["gen_ai.request.use_encrypted_content"] = False

        attributes["gen_ai.request.logprobs"] = self._proto.logprobs

        # Float fields that need rounding
        float_fields = [
            ("frequency_penalty", "gen_ai.request.frequency_penalty"),
            ("presence_penalty", "gen_ai.request.presence_penalty"),
            ("temperature", "gen_ai.request.temperature"),
            ("top_p", "gen_ai.request.top_p"),
        ]

        # Integer fields.
        # NOTE(review): `_make_request` sets `n` only on a copy of the request, so "gen_ai.request.choice.count" is
        # recorded here only if `n` was set on `self._proto` itself — confirm this is intended.
        int_fields = [
            ("n", "gen_ai.request.choice.count"),
            ("seed", "gen_ai.request.seed"),
            ("max_tokens", "gen_ai.request.max_tokens"),
            ("top_logprobs", "gen_ai.request.top_logprobs"),
        ]

        # Set float fields with rounding
        for proto_field, attr_name in float_fields:
            if self._proto.HasField(proto_field):
                attributes[attr_name] = round(getattr(self._proto, proto_field), 6)

        # Set integer fields
        for proto_field, attr_name in int_fields:
            if self._proto.HasField(proto_field):
                attributes[attr_name] = getattr(self._proto, proto_field)

        # Special cases
        if self._conversation_id:
            attributes["gen_ai.conversation.id"] = self._conversation_id
        if len(self._proto.stop) > 0:
            attributes["gen_ai.request.stop_sequences"] = list(self._proto.stop)
        if self._proto.HasField("response_format"):
            # e.g. FORMAT_TYPE_JSON_SCHEMA -> "json_schema"
            attributes["gen_ai.output.type"] = (
                chat_pb2.FormatType.Name(self._proto.response_format.format_type).removeprefix("FORMAT_TYPE_").lower()
            )
        if self._proto.HasField("reasoning_effort"):
            attributes["gen_ai.request.reasoning_effort"] = (
                chat_pb2.ReasoningEffort.Name(self._proto.reasoning_effort).removeprefix("EFFORT_").lower()
            )
        if self._proto.user:
            attributes["user_id"] = self._proto.user
        if self._proto.HasField("parallel_tool_calls"):
            attributes["gen_ai.request.parallel_tool_calls"] = self._proto.parallel_tool_calls
        if self._proto.store_messages:
            attributes["gen_ai.request.store_messages"] = self._proto.store_messages
        if self._proto.previous_response_id:
            attributes["gen_ai.request.previous_response_id"] = self._proto.previous_response_id
        if self._proto.use_encrypted_content:
            attributes["gen_ai.request.use_encrypted_content"] = self._proto.use_encrypted_content

        prompt_attributes = self._get_span_prompt_attributes()
        attributes.update(prompt_attributes)

        return attributes

    def _get_span_prompt_attributes(self) -> dict[str, Any]:
        """Creates a dictionary with prompt message attributes for span telemetry.

        Emits "gen_ai.prompt.{i}.role" and "gen_ai.prompt.{i}.content" for each user, assistant, system,
        developer and tool message. Assistant messages with tool calls additionally get "gen_ai.prompt.{i}.tool_calls".
        Messages with any other role are skipped, but their index is still consumed.
        """
        prompt_attributes: dict[str, Any] = {}

        # Skip collecting sensitive attributes if disabled
        if should_disable_sensitive_attributes():
            return prompt_attributes

        role_names = {
            chat_pb2.MessageRole.ROLE_USER: "user",
            chat_pb2.MessageRole.ROLE_ASSISTANT: "assistant",
            chat_pb2.MessageRole.ROLE_SYSTEM: "system",
            chat_pb2.MessageRole.ROLE_DEVELOPER: "developer",
            chat_pb2.MessageRole.ROLE_TOOL: "tool",
        }

        # Only text content is included in span attributes.
        for index, message in enumerate(self._proto.messages):
            role_name = role_names.get(message.role)
            if role_name is None:
                continue
            prefix = f"gen_ai.prompt.{index}"
            prompt_attributes[f"{prefix}.role"] = role_name
            prompt_attributes[f"{prefix}.content"] = "".join(c.text for c in message.content)
            if message.role == chat_pb2.MessageRole.ROLE_ASSISTANT and message.tool_calls:
                prompt_attributes[f"{prefix}.tool_calls"] = self._tool_calls_to_json(message.tool_calls)

        return prompt_attributes

    def _make_span_response_attributes(self, responses: Sequence["Response"]) -> dict[str, Any]:
        """Creates a dictionary with response metadata and completion attributes for span telemetry.

        Returns an empty dictionary when sensitive attributes are disabled or when `responses` is empty.
        """
        attributes: dict[str, Any] = {}

        if should_disable_sensitive_attributes() or not responses:
            return attributes

        # All of these attributes are the same for all responses, so we can just use the first response to access them.
        first = responses[0]
        attributes["gen_ai.response.id"] = first.id
        attributes["gen_ai.response.model"] = first._proto.model
        attributes["gen_ai.usage.input_tokens"] = first.usage.prompt_tokens
        attributes["gen_ai.usage.output_tokens"] = first.usage.completion_tokens
        attributes["gen_ai.usage.total_tokens"] = first.usage.total_tokens
        attributes["gen_ai.usage.reasoning_tokens"] = first.usage.reasoning_tokens
        attributes["gen_ai.usage.cached_prompt_text_tokens"] = first.usage.cached_prompt_text_tokens
        attributes["gen_ai.usage.prompt_text_tokens"] = first.usage.prompt_text_tokens
        attributes["gen_ai.usage.prompt_image_tokens"] = first.usage.prompt_image_tokens
        attributes["gen_ai.response.system_fingerprint"] = first.system_fingerprint

        # Only finish reasons are different for each response.
        attributes["gen_ai.response.finish_reasons"] = [r.finish_reason for r in responses]

        completion_attributes = self._get_span_completion_attributes(responses)
        attributes.update(completion_attributes)

        return attributes

    def _get_span_completion_attributes(self, responses: Sequence["Response"]) -> dict[str, Any]:
        """Creates a dictionary with completion content attributes for span telemetry."""
        completion_attributes: dict[str, Any] = {}

        # Skip collecting sensitive attributes if disabled
        if should_disable_sensitive_attributes():
            return completion_attributes

        for index, response in enumerate(responses):
            prefix = f"gen_ai.completion.{index}"
            completion_attributes[f"{prefix}.role"] = response.role.removeprefix("ROLE_").lower()
            completion_attributes[f"{prefix}.content"] = response.content
            if response.reasoning_content:
                completion_attributes[f"{prefix}.reasoning_content"] = response.reasoning_content
            if response.tool_calls:
                completion_attributes[f"{prefix}.tool_calls"] = self._tool_calls_to_json(response.tool_calls)

        return completion_attributes

    @staticmethod
    def _tool_calls_to_json(tool_calls: Sequence[Any]) -> str:
        """Serializes tool calls into the JSON string stored on span attributes.

        Tool call arguments are embedded as parsed JSON when possible. Model-generated arguments are not guaranteed
        to be valid JSON (they may be malformed or empty). In that case the raw argument string is embedded instead,
        so telemetry collection never fails a request.
        """

        def parse_arguments(raw: str) -> Any:
            try:
                return json.loads(raw)
            except json.JSONDecodeError:
                return raw

        return json.dumps(
            [
                {
                    "id": tool_call.id,
                    "type": "function",
                    "function": {
                        "name": tool_call.function.name,
                        "arguments": parse_arguments(tool_call.function.arguments),
                    },
                }
                for tool_call in tool_calls
            ]
        )

    def _uses_server_side_tools(self) -> bool:
        """Returns True if any tool in the request sets a `tool` oneof member other than `function`."""
        return any(tool.WhichOneof("tool") != "function" for tool in self._proto.tools)

    def _auto_detect_multi_output_mode(
        self, index: Optional[int], outputs: Sequence[Union[chat_pb2.CompletionOutput, chat_pb2.CompletionOutputChunk]]
    ) -> Optional[int]:
        """Auto-detects if the server is using multi-output mode and updates the index accordingly.

        When we expect single-output mode (index=0) but the server returns multiple outputs
        (likely because it added tools implicitly), this method switches to multi-output mode
        (index=None) to properly handle all outputs.

        Args:
            index: The current index value (0 for single-output, None for multi-output).
            outputs: The outputs from the response or chunk to check.

        Returns:
            The potentially updated index value (None if multi-output mode is detected, otherwise unchanged).
        """
        if index == 0 and outputs:
            max_output_index = max(output.index for output in outputs)
            if max_output_index > 0:
                # Server is using multi-output mode (likely added tools implicitly)
                return None
        return index

    @property
    def messages(self) -> Sequence[chat_pb2.Message]:
        """Returns the messages in the conversation (the live repeated field of the wrapped request)."""
        return self._proto.messages


def user(*args: Content) -> chat_pb2.Message:
    """Builds a chat message with the "user" role.

    Args:
        *args: Content parts of the message; each one is normalized with `_process_content`.

    Returns:
        A `chat_pb2.Message` with `ROLE_USER` and the normalized content parts, in argument order.
    """
    parts = [_process_content(part) for part in args]
    return chat_pb2.Message(role=chat_pb2.MessageRole.ROLE_USER, content=parts)


def assistant(*args: Content) -> chat_pb2.Message:
    """Builds a chat message with the "assistant" role.

    Args:
        *args: Content parts of the message; each one is normalized with `_process_content`.

    Returns:
        A `chat_pb2.Message` with `ROLE_ASSISTANT` and the normalized content parts, in argument order.
    """
    role = chat_pb2.MessageRole.ROLE_ASSISTANT
    return chat_pb2.Message(role=role, content=list(map(_process_content, args)))


def system(*args: Content) -> chat_pb2.Message:
    """Builds a chat message with the "system" role.

    Args:
        *args: Content parts of the message; each one is normalized with `_process_content`.

    Returns:
        A `chat_pb2.Message` with `ROLE_SYSTEM` and the normalized content parts, in argument order.
    """
    normalized = [_process_content(item) for item in args]
    message = chat_pb2.Message(role=chat_pb2.MessageRole.ROLE_SYSTEM, content=normalized)
    return message


def developer(*args: Content) -> chat_pb2.Message:
    """Builds a chat message with the "developer" role.

    Note: The `developer` role is only supported by models newer than `grok-4.1`/`grok-4-1` (exclusive). For
    `grok-4.1` and older models, the backend converts `developer` messages into `system` messages.

    Args:
        *args: Content parts of the message; each one is normalized with `_process_content`.

    Returns:
        A `chat_pb2.Message` with `ROLE_DEVELOPER` and the normalized content parts, in argument order.
    """
    contents = [_process_content(arg) for arg in args]
    return chat_pb2.Message(
        role=chat_pb2.MessageRole.ROLE_DEVELOPER,
        content=contents,
    )


def tool_result(result: str, tool_call_id: Optional[str] = None) -> chat_pb2.Message:
    """Builds a message with the "tool" role that carries the output of a client-side tool.

    Append the returned message to the conversation after executing a tool the model asked for. This lets the model
    continue from the tool's output in multi-turn and agentic workflows: the model requests a tool call, you run
    it, you append the result, and you sample again.

    Args:
        result: The string produced by your tool. It becomes the message's single text content part.
        tool_call_id: Optional ID of the tool call this result answers. It should match `tool_call.id` from the
            assistant's `tool_calls`. Supplying it is strongly recommended when several tool calls are in flight
            (e.g. with parallel tool calls), so each result can be matched to its call. When omitted, the message
            does not reference a specific tool call.

    Examples:
        Basic function calling loop:
        ```python
        from xai_sdk.chat import tool_result
        import json

        # ... after chat.sample() or in stream
        if response.tool_calls:
            for tool_call in response.tool_calls:
                # Parse and execute
                args = json.loads(tool_call.function.arguments)
                tool_output = get_weather(args["city"])  # Your tool function
                # Append with tool_call_id for proper linking
                chat.append(tool_result(tool_output, tool_call_id=tool_call.id))
            # Continue conversation
            response = chat.sample()
        ```

        See `examples/sync/function_calling.py` and `examples/sync/server_side_tools.py` (for mixed client/server
        tools) for complete patterns.

    Returns:
        A `chat_pb2.Message` with `ROLE_TOOL`, ready to be appended to a chat.
    """
    content_parts = [text(result)]
    return chat_pb2.Message(
        role=chat_pb2.MessageRole.ROLE_TOOL,
        content=content_parts,
        tool_call_id=tool_call_id,
    )


def tool(name: str, description: str, parameters: dict[str, Any]) -> chat_pb2.Tool:
    """Creates a new tool for function calling in chat conversations.

    This function defines a tool that the model can call to perform specific tasks, such as executing a function
    with provided arguments. The tool is represented as a `chat_pb2.Tool` object, which includes a function
    specification with a name, description, and JSON schema for the parameters. The model uses this schema to
    generate valid JSON inputs for the function when it decides to call the tool. Tools are typically passed to
    the `create` method of a chat client to enable function calling in a conversation.

    Examples:
        Using a Pydantic model to define the parameter schema:
        ```python
        from typing import Literal

        from pydantic import BaseModel, Field
        from xai_sdk import Client
        from xai_sdk.chat import system, tool

        class GetWeatherRequest(BaseModel):
            city: str = Field(description="The name of the city to get the weather for.")
            units: Literal["C", "F"] = Field(description="The units to use for the temperature.")

        client = Client()

        weather_tool = tool(
            name="get_weather",
            description="Get the weather for a given city.",
            parameters=GetWeatherRequest.model_json_schema(),
        )

        conversation = client.chat.create(
            model="grok-3",
            messages=[system("You are a helpful assistant.")],
            tools=[weather_tool],
        )
        ```

        Using an explicit JSON schema definition:
        ```python
        from xai_sdk import Client
        from xai_sdk.chat import system, tool

        client = Client()

        weather_tool = tool(
            name="get_weather",
            description="Get the weather for a given city.",
            parameters={
                "type": "object",
                "properties": {
                    "city": {"type": "string", "description": "The name of the city to get the weather for."},
                    "units": {
                        "type": "string",
                        "description": "The units to use for the temperature.",
                        "enum": ["C", "F"],
                    },
                },
                "required": ["city", "units"],
            },
        )

        conversation = client.chat.create(
            model="grok-3",
            messages=[system("You are a helpful assistant.")],
            tools=[weather_tool],
        )
        ```

    Args:
        name: The name of the function that the model can call. This should be unique and descriptive
            (e.g., "get_weather").
        description: A brief description of what the function does, helping the model understand when to call it
            (e.g., "Get the weather for a given city.").
        parameters: A JSON schema dictionary, or a dictionary derived from a Pydantic model's
            `model_json_schema()`, that defines the structure and types of the function's input parameters.

    Returns:
        A `chat_pb2.Tool` object representing the function, which can be passed to the `tools` parameter of a
        chat client's `create` method.

    Raises:
        TypeError: If `parameters` contains values that `json.dumps` cannot serialize.

    Note:
        - The `parameters` dictionary is serialized to a JSON string internally, so it must be JSON-serializable.
        - A maximum of 128 tools can be provided to a chat conversation.
        - The model decides whether to call the tool based on the conversation context and the tool's description.
    """
    # The proto stores the schema as a JSON string, not as a structured message.
    return chat_pb2.Tool(
        function=chat_pb2.Function(
            name=name,
            description=description,
            parameters=json.dumps(parameters),
        )
    )


def required_tool(name: str) -> chat_pb2.ToolChoice:
    """Builds a tool choice that forces the model to call the tool named `name`.

    Use it when the model must always invoke one specific tool. `name` has to match one of the tools passed via
    the `tools` argument of a chat client's `create` method.
    """
    forced_choice = chat_pb2.ToolChoice(function_name=name)
    return forced_choice


def text(content: str) -> chat_pb2.Content:
    """Wraps a plain string in a text-typed `chat_pb2.Content` object."""
    text_content = chat_pb2.Content(text=content)
    return text_content


def image(image_url: str, *, detail: Optional[ImageDetail] = "auto") -> chat_pb2.Content:
    """Creates a new content object of type image for use in chat messages.

    Args:
        image_url: The URL or base64-encoded string of the image. Supported formats are PNG and JPG.
            If a URL is provided, the image is fetched for each API request without caching.
            Fetching uses the "XaiImageApiFetch/1.0" user agent with a 5-second timeout.
            The maximum image size is 10 MiB; larger images or failed fetches will cause the API request to fail.
        detail: Specifies the image resolution for model processing. One of:
        - `"auto"`: The system selects an appropriate resolution (default).
        - `"low"`: Uses a low-resolution image, reducing token usage and increasing speed.
        - `"high"`: Uses a high-resolution image, increasing token usage and processing time
            but capturing more detail.

    Returns:
        A `chat_pb2.Content` object representing the image content.
    """
    pb_detail = image_pb2.ImageDetail.DETAIL_AUTO
    if detail == "low":
        pb_detail = image_pb2.ImageDetail.DETAIL_LOW
    elif detail == "high":
        pb_detail = image_pb2.ImageDetail.DETAIL_HIGH

    return chat_pb2.Content(image_url=image_pb2.ImageUrlContent(image_url=image_url, detail=pb_detail))


def file(file_id: str) -> chat_pb2.Content:
    """Builds a file-typed content object that points at an already uploaded file.

    Use it to reference a previously uploaded file inside a chat conversation so the model can read and analyze
    its contents.

    Args:
        file_id: Identifier of a file uploaded earlier, e.g. the ID returned by the Files API
            (`client.files.upload(...)`).

    Returns:
        A `chat_pb2.Content` object holding the file reference.
    """
    file_reference = chat_pb2.FileContent(file_id=file_id)
    return chat_pb2.Content(file=file_reference)


def _process_content(content: Content) -> chat_pb2.Content:
    """Normalizes a `Content` value to a proto: strings become text content, anything else is returned as-is."""
    return text(content) if isinstance(content, str) else content


def _reasoning_effort_to_proto(effort: ReasoningEffort) -> chat_pb2.ReasoningEffort:
    """Maps a `ReasoningEffort` literal onto the matching proto enum value.

    Raises:
        ValueError: If `effort` is neither "low" nor "high".
    """
    if effort == "low":
        return chat_pb2.ReasoningEffort.EFFORT_LOW
    if effort == "high":
        return chat_pb2.ReasoningEffort.EFFORT_HIGH
    raise ValueError(f"Invalid reasoning effort: {effort}. Must be one of: {ReasoningEffort.__args__}")


def _include_option_to_proto(include_option: IncludeOption) -> chat_pb2.IncludeOption:
    """Converts an `IncludeOption` literal to its proto enum value using `IncludeOptionMap`.

    Raises:
        ValueError: If `include_option` is not a key of `IncludeOptionMap`.
    """
    try:
        return IncludeOptionMap[include_option]
    except KeyError:
        # List the accepted keys as a tuple so the message reads like the other `_*_to_proto` converters
        # (which print `Literal.__args__`) instead of exposing a `dict_keys([...])` repr.
        raise ValueError(
            f"Invalid include option: {include_option}. Must be one of: {tuple(IncludeOptionMap)}"
        ) from None


def _tool_mode_to_proto(mode: ToolMode) -> chat_pb2.ToolMode:
    """Maps a `ToolMode` literal ("auto", "none", "required") onto the matching proto enum value.

    Raises:
        ValueError: If `mode` is not one of the supported literals.
    """
    known_modes = (
        ("auto", chat_pb2.ToolMode.TOOL_MODE_AUTO),
        ("none", chat_pb2.ToolMode.TOOL_MODE_NONE),
        ("required", chat_pb2.ToolMode.TOOL_MODE_REQUIRED),
    )
    for literal, proto_mode in known_modes:
        if mode == literal:
            return proto_mode
    raise ValueError(f"Invalid tool mode: {mode}. Must be one of: {ToolMode.__args__}")


def _format_type_to_proto(format_type: ResponseFormat) -> chat_pb2.FormatType:
    """Maps a `ResponseFormat` literal onto the matching `chat_pb2.FormatType` value.

    Raises:
        ValueError: If `format_type` is not "text", "json_object" or "json_schema".
    """
    if format_type == "text":
        return chat_pb2.FORMAT_TYPE_TEXT
    elif format_type == "json_object":
        return chat_pb2.FORMAT_TYPE_JSON_OBJECT
    elif format_type == "json_schema":
        return chat_pb2.FORMAT_TYPE_JSON_SCHEMA
    raise ValueError(f"Invalid response format: {format_type}. Must be one of: {ResponseFormat.__args__}")


class Chunk(ProtoDecorator[chat_pb2.GetChatCompletionChunk]):
    """Wraps a streamed `GetChatCompletionChunk` proto with convenience accessors."""

    # Output index this view is restricted to; None means outputs of every index are exposed.
    _index: int | None

    def __init__(self, proto: chat_pb2.GetChatCompletionChunk, index: int | None):
        """Wraps `proto`.

        Args:
            proto: The chunk proto to decorate.
            index: Output index to expose. Pass None to expose the assistant outputs of every index.
        """
        super().__init__(proto)
        self._index = index

    def _wrap_outputs(self, role: int) -> list["CompletionOutputChunk"]:
        """Wraps each output whose delta has `role` and whose index matches the tracked one (any index if None)."""
        wrapped = []
        for out in self.proto.outputs:
            if out.delta.role != role:
                continue
            if self._index is not None and out.index != self._index:
                continue
            wrapped.append(CompletionOutputChunk(out))
        return wrapped

    @property
    def choices(self) -> Sequence["CompletionOutputChunk"]:
        """Assistant output chunks that belong to the tracked index."""
        return self._wrap_outputs(chat_pb2.MessageRole.ROLE_ASSISTANT)

    @property
    def created(self) -> datetime.datetime:
        """Creation time of this chunk as a `datetime`."""
        return self.proto.created.ToDatetime()

    @property
    def output(self) -> str:
        """Content followed by reasoning content of each choice, all joined into one string."""
        pieces = []
        for choice in self.choices:
            pieces.append(choice.content)
            pieces.append(choice.reasoning_content)
        return "".join(pieces)

    @property
    def content(self) -> str:
        """Content deltas of all choices joined into one string."""
        return "".join(choice.content for choice in self.choices)

    @property
    def reasoning_content(self) -> str:
        """Reasoning deltas of all choices joined into one string."""
        return "".join(choice.reasoning_content for choice in self.choices)

    @property
    def tool_calls(self) -> Sequence[chat_pb2.ToolCall]:
        """Tool calls carried by the choices of this chunk, in choice order."""
        return [call for choice in self.choices for call in choice.tool_calls]

    @property
    def server_side_tool_usage(self) -> dict[str, int]:
        """Counts of server-side tools used, keyed by the `ServerSideTool` enum name."""
        counts = Counter(usage_pb2.ServerSideTool.Name(used) for used in self.proto.usage.server_side_tools_used)
        return dict(counts)

    @property
    def citations(self) -> Sequence[str]:
        """Citations attached to this chunk."""
        return self.proto.citations

    @property
    def inline_citations(self) -> Sequence[chat_pb2.InlineCitation]:
        """Inline citations carried by the choices of this chunk.

        Inline citations provide structured citation metadata with position information,
        enabling you to know exactly where in the response text each citation appears.

        Each InlineCitation contains:
        - id: Display number as a string (e.g., "1", "2")
        - start_index: Character position where the citation starts in the response text
        - end_index: Character position where the citation ends (exclusive)
        - web_citation: Present if the citation is from a web source
        - x_citation: Present if the citation is from an X/Twitter source
        - collections_citation: Present if the citation is from a collections search

        Note: Inline citations are only populated when `include=["inline_citations"]`
        is passed when creating the chat.
        """
        return [citation for choice in self.choices for citation in choice.proto.delta.citations]

    @property
    def tool_outputs(self) -> Sequence["CompletionOutputChunk"]:
        """Tool-role output chunks that belong to the tracked index."""
        return self._wrap_outputs(chat_pb2.MessageRole.ROLE_TOOL)

    @property
    def debug_output(self) -> chat_pb2.DebugOutput:
        """Debug output attached to this chunk."""
        return self.proto.debug_output

    def __str__(self):
        """Same string as `output`: content plus reasoning content of every choice."""
        return self.output


class CompletionOutputChunk(ProtoDecorator[chat_pb2.CompletionOutputChunk]):
    """Convenience view over one output entry of a streamed chunk."""

    @property
    def content(self) -> str:
        """Answer text carried by this output's delta."""
        delta = self.proto.delta
        return delta.content

    @property
    def reasoning_content(self) -> str:
        """Reasoning text carried by this output's delta."""
        delta = self.proto.delta
        return delta.reasoning_content

    @property
    def role(self) -> str:
        """Name of the `MessageRole` set on this output's delta."""
        role_value = self.proto.delta.role
        return chat_pb2.MessageRole.Name(role_value)

    @property
    def tool_calls(self) -> Sequence[chat_pb2.ToolCall]:
        """Tool calls carried by this output's delta."""
        delta = self.proto.delta
        return delta.tool_calls

    @property
    def finish_reason(self) -> sample_pb2.FinishReason:
        """Raw `FinishReason` enum value of this output chunk."""
        raw_proto = self.proto
        return raw_proto.finish_reason


class _ResponseProtoDecorator(ProtoDecorator[chat_pb2.GetChatCompletionResponse]):
    """Accumulates streamed chunks into a `GetChatCompletionResponse`, buffering text deltas lazily."""

    def __init__(self, proto: chat_pb2.GetChatCompletionResponse) -> None:
        """Wraps `proto` and sets up per-output text buffers.

        Text deltas are collected in lists keyed by output index. They are only joined into the proto when it is
        read, which avoids repeated string concatenation while streaming.
        """
        super().__init__(proto)
        # output index -> text fragments received so far for that output, in arrival order.
        self._content_buffers: dict[int, list[str]] = defaultdict(list)
        self._reasoning_content_buffers: dict[int, list[str]] = defaultdict(list)
        self._encrypted_content_buffers: dict[int, list[str]] = defaultdict(list)
        # False while some buffer holds fragments that have not been written into the proto yet.
        self._proto_in_sync = True

    def _sync_buffers_to_proto(self) -> None:
        """Writes the joined buffer contents into the matching message fields of the proto."""
        if self._proto_in_sync:
            return

        outputs = self._proto.outputs
        pending = (
            ("content", self._content_buffers),
            ("reasoning_content", self._reasoning_content_buffers),
            ("encrypted_content", self._encrypted_content_buffers),
        )
        for field_name, buffers in pending:
            for idx, fragments in buffers.items():
                if not fragments or idx >= len(outputs):
                    continue
                message = outputs[idx].message
                setattr(message, field_name, "".join(fragments))
                # Collapse the buffer to the joined value so later deltas are appended after it.
                buffers[idx] = [getattr(message, field_name)]

        self._proto_in_sync = True

    @property
    def proto(self) -> chat_pb2.GetChatCompletionResponse:
        """The wrapped proto, with any buffered text flushed into it first."""
        self._sync_buffers_to_proto()
        return self._proto

    def process_chunk(self, chunk: chat_pb2.GetChatCompletionChunk):
        """Merges one streamed chunk into the accumulated response."""
        response = self._proto

        # Response-level fields: the latest chunk overwrites them, citations accumulate.
        response.usage.CopyFrom(chunk.usage)
        response.created.CopyFrom(chunk.created)
        response.id = chunk.id
        response.model = chunk.model
        response.system_fingerprint = chunk.system_fingerprint
        response.citations.extend(chunk.citations)

        # Grow `outputs` so that every index referenced by this chunk can be addressed directly.
        if chunk.outputs:
            missing = max(out.index for out in chunk.outputs) + 1 - len(response.outputs)
            if missing > 0:
                response.outputs.extend([chat_pb2.CompletionOutput() for _ in range(missing)])

        for out in chunk.outputs:
            delta = out.delta
            target = response.outputs[out.index]
            target.index = out.index
            target.message.role = delta.role
            target.message.tool_calls.extend(delta.tool_calls)
            # delta.citations are the inline citations carried by this chunk.
            target.message.citations.extend(delta.citations)
            target.finish_reason = out.finish_reason

            # Text is buffered rather than concatenated onto the proto; see `_sync_buffers_to_proto`.
            for fragment, buffers in (
                (delta.content, self._content_buffers),
                (delta.reasoning_content, self._reasoning_content_buffers),
                (delta.encrypted_content, self._encrypted_content_buffers),
            ):
                if fragment:
                    buffers[out.index].append(fragment)
                    self._proto_in_sync = False


class Response(_ResponseProtoDecorator):
    """Response of a chat request."""

    # One request can produce several answers inside a single response proto. `_index` selects the answer that
    # the convenience accessors read from; None disables the index filter.
    _index: int | None

    def __init__(self, response: chat_pb2.GetChatCompletionResponse, index: int | None) -> None:
        """Wraps a chat completion response proto.

        Args:
            response: The response proto, which may hold several answers.
            index: Index of the answer exposed through the convenience accessors. When None, outputs of every
                index are considered, and content and reasoning content come only from the assistant response.
        """
        super().__init__(response)
        self._index = index

    def _get_output(self, *, sync: bool = False) -> chat_pb2.CompletionOutput:
        """Returns the last assistant output matching `_index`, or an empty `CompletionOutput` if none matches.

        Args:
            sync: When True, buffered streaming text is flushed into the proto first. Only the accessors that
                read text fields need this.
        """
        if sync:
            self._sync_buffers_to_proto()

        last_match = None
        for candidate in self._proto.outputs:
            if candidate.message.role != chat_pb2.MessageRole.ROLE_ASSISTANT:
                continue
            if self._index is not None and candidate.index != self._index:
                continue
            last_match = candidate
        if last_match is None:
            return chat_pb2.CompletionOutput()
        return last_match

    @property
    def id(self) -> str:
        """Identifier of this response."""
        return self._proto.id

    @property
    def created(self) -> datetime.datetime:
        """Creation time of this response as a `datetime`."""
        return self._proto.created.ToDatetime()

    @property
    def content(self) -> str:
        """Answer text of the selected assistant output."""
        selected = self._get_output(sync=True)
        return selected.message.content

    @property
    def encrypted_content(self) -> str:
        """Encrypted reasoning content of the selected assistant output."""
        selected = self._get_output(sync=True)
        return selected.message.encrypted_content

    @property
    def role(self) -> str:
        """Name of the `MessageRole` of the selected output."""
        selected = self._get_output(sync=False)
        return chat_pb2.MessageRole.Name(selected.message.role)

    @property
    def usage(self) -> usage_pb2.SamplingUsage:
        """Sampling usage reported for this response."""
        return self._proto.usage

    @property
    def reasoning_content(self) -> str:
        """Reasoning trace generated by the model.

        Only populated for models that support reasoning.
        """
        selected = self._get_output(sync=True)
        return selected.message.reasoning_content

    @property
    def finish_reason(self) -> str:
        """Name of the `FinishReason` of the selected output."""
        selected = self._get_output(sync=False)
        return sample_pb2.FinishReason.Name(selected.finish_reason)

    @property
    def logprobs(self) -> chat_pb2.LogProbs:
        """Log probabilities of the selected output."""
        selected = self._get_output(sync=False)
        return selected.logprobs

    @property
    def system_fingerprint(self) -> str:
        """System fingerprint reported for this response."""
        return self.proto.system_fingerprint

    @property
    def tool_calls(self) -> Sequence[chat_pb2.ToolCall]:
        """All tool calls of every assistant output in this response (not filtered by index)."""
        calls = []
        for out in self.proto.outputs:
            if out.message.role == chat_pb2.MessageRole.ROLE_ASSISTANT:
                calls.extend(out.message.tool_calls)
        return calls

    @property
    def citations(self) -> Sequence[str]:
        """Citations attached to this response."""
        return self.proto.citations

    @property
    def inline_citations(self) -> Sequence[chat_pb2.InlineCitation]:
        """Inline citations of every assistant output in this response.

        Inline citations provide structured citation metadata with position information,
        enabling you to know exactly where in the response text each citation appears.

        Each InlineCitation contains:
        - id: Display number as a string (e.g., "1", "2")
        - start_index: Character position where the citation starts in the response text
        - end_index: Character position where the citation ends (exclusive)
        - web_citation: Present if the citation is from a web source
        - x_citation: Present if the citation is from an X/Twitter source
        - collections_citation: Present if the citation is from a collections search

        Note: Inline citations are only populated when `include=["inline_citations"]`
        is passed when creating the chat.
        """
        collected = []
        for out in self.proto.outputs:
            if out.message.role == chat_pb2.MessageRole.ROLE_ASSISTANT:
                collected.extend(out.message.citations)
        return collected

    @property
    def tool_outputs(self) -> Sequence[chat_pb2.CompletionOutput]:
        """Output entries whose message role is ROLE_TOOL (not filtered by index)."""
        tool_entries = []
        for out in self.proto.outputs:
            if out.message.role == chat_pb2.MessageRole.ROLE_TOOL:
                tool_entries.append(out)
        return tool_entries

    @property
    def server_side_tool_usage(self) -> dict[str, int]:
        """Counts of server-side tools used, keyed by the `ServerSideTool` enum name."""
        counts = Counter(usage_pb2.ServerSideTool.Name(used) for used in self.proto.usage.server_side_tools_used)
        return dict(counts)

    @property
    def request_settings(self) -> chat_pb2.RequestSettings:
        """Model parameters that were set on the request that produced this response."""
        return self.proto.settings

    @property
    def debug_output(self) -> chat_pb2.DebugOutput:
        """Debug output of this response. Only available to trusted testers."""
        return self.proto.debug_output
