"""Shared agent execution infrastructure for all channels (web, WhatsApp).

This module is the single source of truth for:
- Building ``AgentDeps`` (capability filtering, config)
- Agent invocation constants (usage limits)
- Non-streaming agent execution with tracing and intermediate-text callbacks

Web/SSE streaming stays in ``routers/chat.py`` because the streaming loop is
tightly coupled to the SSE framing.  Both channels use ``build_agent_deps``
and ``AGENT_USAGE_LIMITS`` from here.
"""

from __future__ import annotations

from collections.abc import Awaitable, Callable
from dataclasses import dataclass, field

import httpx
import logfire
from pydantic_ai import (
    Agent,
    FinalResultEvent,
    PartDeltaEvent,
    PartStartEvent,
    TextPartDelta,
    UsageLimits,
)
from pydantic_ai.messages import ModelMessage, TextPart
from sqlalchemy.ext.asyncio import AsyncSession

from app.agents.deps import AgentDeps, CallerIdentity
from app.agents.history import serialize_messages
from app.agents.restaurant import restaurant_agent
from app.models.restaurant import Restaurant
from app.schemas.agent_config import get_agent_config, get_enabled_agents

# ------------------------------------------------------------------
# Constants
# ------------------------------------------------------------------

# Per-run guard rails applied to every agent invocation: at most 25 model
# requests and 50,000 total tokens for a single run.
AGENT_USAGE_LIMITS = UsageLimits(
    request_limit=25,
    total_tokens_limit=50_000,
)

# Fallback timezone name used when a restaurant has no timezone configured.
DEFAULT_TIMEZONE = "Europe/Brussels"

# ------------------------------------------------------------------
# Deps construction
# ------------------------------------------------------------------


def build_agent_deps(
    *,
    session: AsyncSession,
    restaurant: Restaurant,
    http_client: httpx.AsyncClient,
    caller: CallerIdentity,
    language: str,
) -> AgentDeps:
    """Assemble the ``AgentDeps`` for one agent run from a restaurant's settings.

    Capability filtering (``enabled_capabilities``) and the agent config are
    both derived from ``restaurant.settings``. This is intended to be the
    **only** place where ``get_enabled_agents`` and ``get_agent_config`` are
    called during agent setup, so the web and WhatsApp channels share the same
    capability logic.

    Args:
        session: Database session handed through to the agent's tools.
        restaurant: Restaurant whose id, timezone and settings configure the run.
        http_client: Shared HTTP client handed through to the agent's tools.
        caller: Identity of the party talking to the agent.
        language: Conversation language.

    Returns:
        A fully populated ``AgentDeps``; the timezone falls back to
        ``DEFAULT_TIMEZONE`` when the restaurant has none set.
    """
    restaurant_id = restaurant.id
    timezone = restaurant.timezone or DEFAULT_TIMEZONE
    capabilities = get_enabled_agents(restaurant.settings)
    # NOTE(review): the agent config is always loaded for the "faq" agent,
    # whatever capabilities are enabled — confirm this is intended.
    faq_config = get_agent_config(restaurant.settings, "faq")

    return AgentDeps(
        session=session,
        restaurant_id=restaurant_id,
        http_client=http_client,
        caller=caller,
        language=language,
        timezone=timezone,
        enabled_capabilities=capabilities,
        agent_config=faq_config,
    )


# ------------------------------------------------------------------
# Non-streaming execution (WhatsApp, future batch channels)
# ------------------------------------------------------------------


@dataclass
class AgentRunResult:
    """Outcome of a non-streaming agent run."""

    output_text: str = ""
    all_messages_json: str = ""
    token_count: int | None = None
    intermediate_texts: list[str] = field(default_factory=list)


async def _stream_model_request(node, ctx) -> tuple[str, bool]:
    """Drain one model-request node's event stream.

    Returns the text the model streamed for this request (part starts plus
    text deltas, concatenated) and whether a ``FinalResultEvent`` was seen.
    """
    chunks: list[str] = []
    is_final = False
    async with node.stream(ctx) as stream:
        async for event in stream:
            if isinstance(event, PartStartEvent) and isinstance(event.part, TextPart):
                chunks.append(event.part.content)
            elif isinstance(event, PartDeltaEvent) and isinstance(
                event.delta, TextPartDelta
            ):
                chunks.append(event.delta.content_delta)
            elif isinstance(event, FinalResultEvent):
                is_final = True
    return "".join(chunks), is_final


async def run_agent_to_result(
    *,
    message: str,
    history: list[ModelMessage],
    deps: AgentDeps,
    conversation_id: str,
    on_intermediate_text: Callable[[str], Awaitable[None]] | None = None,
) -> AgentRunResult:
    """Run ``restaurant_agent`` to completion with tracing and optional callbacks.

    Used by the WhatsApp channel (and any future non-streaming channel).
    When the model produces text in a request that does not yield the final
    result (e.g. text *before* making tool calls), that text is recorded in
    ``AgentRunResult.intermediate_texts``. If ``on_intermediate_text`` is
    given, it is also awaited with each such text as soon as it is available,
    so the channel can deliver it immediately (e.g. as a separate WhatsApp
    message).

    The web/SSE channel uses ``run_stream_events`` directly because it
    needs per-token deltas — that path lives in ``routers/chat.py``.

    Args:
        message: The user's new message.
        history: Prior conversation messages passed as message history.
        deps: Dependencies for the run (see ``build_agent_deps``).
        conversation_id: Identifier attached to log records and the span.
        on_intermediate_text: Optional async callback for non-final texts.

    Returns:
        An ``AgentRunResult`` with the stripped final output, the serialized
        full message list, the total token count (if reported) and all
        intermediate texts.

    Raises:
        No exceptions are caught here. Errors from the agent run (for example
        when ``AGENT_USAGE_LIMITS`` is exceeded) or from
        ``on_intermediate_text`` propagate to the caller.
    """
    intermediate_texts: list[str] = []

    logfire.info(
        "llm_run_started",
        conversation_id=conversation_id,
        history_items=len(history),
    )
    with logfire.span(
        "agent_conversation",
        restaurant_id=deps.restaurant_id,
        conversation_id=conversation_id,
        language=deps.language,
    ):
        async with restaurant_agent.iter(
            message,
            deps=deps,
            message_history=history,
            usage_limits=AGENT_USAGE_LIMITS,
        ) as run:
            async for node in run:
                if not Agent.is_model_request_node(node):
                    continue
                model_text, is_final = await _stream_model_request(node, run.ctx)
                text = model_text.strip()
                if is_final or not text:
                    continue
                # Record non-final text unconditionally so the result reflects
                # the run regardless of whether a callback was supplied.
                intermediate_texts.append(text)
                if on_intermediate_text is not None:
                    await on_intermediate_text(text)

            result = run.result
            output_text = str(result.output).strip()
            usage = result.usage()
            token_count = usage.total_tokens if hasattr(usage, "total_tokens") else None
            all_messages_json = serialize_messages(result.all_messages())

    return AgentRunResult(
        output_text=output_text,
        all_messages_json=all_messages_json,
        token_count=token_count,
        intermediate_texts=intermediate_texts,
    )
