from __future__ import annotations

import re
import time
from dataclasses import dataclass
from typing import Literal

import logfire
from pydantic import BaseModel
from pydantic_ai import Agent, RunContext
from pydantic_ai.models.anthropic import AnthropicModel
from pydantic_ai.providers.anthropic import AnthropicProvider

from app.agents.deps import AgentDeps
from app.config import get_settings

# Application settings, loaded once at import time. router_agent takes its model
# name (ROUTER_MODEL) and Anthropic API key (APP_ANTHROPIC_API_KEY) from here.
settings = get_settings()

# ---------- Keyword sets ----------

# Dutch function words that detect_language looks for among the message's words.
# All keyword collections below are frozensets: they are shared module-level
# constants, and nothing should be able to mutate routing behaviour at runtime.
_DUTCH_STOP_WORDS = frozenset(
    {"de", "een", "voor", "van", "het", "op", "is", "zijn", "worden", "dat"}
)

# Lower-cased Dutch/English stems hinting at a table reservation.
# classify_agent_heuristic counts how many occur anywhere in the lower-cased
# message (plain substring test, so "reserv" matches "reserveren" and
# "reservation").
_RESERVATION_SIGNALS = frozenset(
    {
        "reserv",
        "boek",
        "tafel",
        "book",
        "date",
        "avond",
        "middag",
        "lunch",
        "plek",
        "personen",
        "persoon",
        "beschikbaar",
        "seats",
        "seat",
    }
)
# Lower-cased stems hinting at a takeaway/delivery order, matched the same way.
# NOTE(review): "menu" counts toward takeaway here, while the router system
# prompt sends menu questions to the faq agent; confirm which is intended.
_TAKEAWAY_SIGNALS = frozenset(
    {
        "bestell",
        "order",
        "afhaal",
        "menu",
        "pizza",
        "burger",
        "kip",
        "betaal",
        "levering",
        "eten",
        "delivery",
    }
)

# ---------- Keyword heuristic (fallback) ----------


def classify_agent_heuristic(message: str) -> Literal["reservation", "faq", "takeaway"]:
    """Pick a specialist agent from keyword hits alone (no LLM call).

    Each entry of ``_RESERVATION_SIGNALS`` and ``_TAKEAWAY_SIGNALS`` is tested as
    a plain substring of the lower-cased message, and the side with strictly
    more hits wins. A tie, including zero hits on both sides, routes to "faq".

    NOTE(review): substring matching also fires inside unrelated words, e.g.
    "date" inside "update" or "order" inside "border".
    """
    text = message.lower()
    reservation_hits = sum(keyword in text for keyword in _RESERVATION_SIGNALS)
    takeaway_hits = sum(keyword in text for keyword in _TAKEAWAY_SIGNALS)

    if reservation_hits == takeaway_hits:
        return "faq"
    return "reservation" if reservation_hits > takeaway_hits else "takeaway"


# Backward-compatible alias: `classify_agent` is the old name of the keyword
# heuristic. chat.py callers have moved to classify_with_router; this alias only
# keeps any remaining importers of `classify_agent` working.
classify_agent = classify_agent_heuristic


def detect_language(message: str) -> Literal["nl", "en"]:
    """Guess whether *message* is Dutch ("nl") or English ("en").

    The lower-cased message is split into word tokens with punctuation removed,
    so "dat?" or "het," still count as Dutch stop words.

    Any hit on ``_DUTCH_STOP_WORDS`` other than "is" means Dutch. "is" is also
    an everyday English word, so on its own it counts as Dutch only when none of
    a handful of English-only function words is present. For example, "Is the
    kitchen open?" and "What is your address?" stay English, while "Hoe laat is
    jullie keuken open?" stays Dutch.

    Everything else defaults to English.
    """
    # Dutch stop words spelled identically to common English words: weak evidence.
    shared_with_english = {"is"}
    # Common English function words that are not Dutch words.
    english_only = {"the", "what", "you", "your", "are", "there", "it", "and", "for", "to"}

    words = set(re.findall(r"\w+", message.lower()))
    dutch_hits = words & _DUTCH_STOP_WORDS
    if dutch_hits - shared_with_english:
        return "nl"
    if dutch_hits and not (words & english_only):
        return "nl"
    return "en"


# ---------- LLM Router ----------


@dataclass(frozen=True)
class AgentResult:
    """Classification result from the router.

    Returned by the ``hand_off_to_*`` output functions; classify_with_router reads
    ``agent_type`` from it. Frozen because it is a plain value object that nothing
    needs to modify after creation.
    """

    # Specialist agent to route to; the same three values that
    # classify_agent_heuristic can return.
    agent_type: Literal["reservation", "faq", "takeaway"]


# NOTE: listed in router_agent's output_type, so the model may return this shape.
# The class docstring and field names are presumably part of the schema sent to
# the model; change their wording with care.
class RouterFailure(BaseModel):
    """Returned when the router cannot confidently classify."""

    # The router's reason for not classifying; classify_with_router logs it on
    # the "agent_router_failure" warning.
    explanation: str
    # Reply text proposed by the router.
    # NOTE(review): classify_with_router does not read this field (a failure just
    # routes to "faq"); confirm whether anything else uses it.
    suggested_response: str


async def hand_off_to_reservation(ctx: RunContext[AgentDeps], message: str) -> AgentResult:
    """Route to reservation agent for booking, modifying, or cancelling reservations."""
    # Output function of router_agent: the model selects it to pick this agent.
    # ctx and message are unused here; message is an argument the model fills in
    # (the system prompt asks it to pass the customer's original message).
    # NOTE(review): the docstring is presumably sent to the model as the tool
    # description, so keep its wording in line with the router system prompt.
    return AgentResult(agent_type="reservation")


async def hand_off_to_faq(ctx: RunContext[AgentDeps], message: str) -> AgentResult:
    """Route to FAQ agent for restaurant, menu, hours, location, policies."""
    # Output function of router_agent: the model selects it to pick this agent.
    # ctx and message are unused here; message is an argument the model fills in
    # (the system prompt asks it to pass the customer's original message).
    # NOTE(review): the docstring is presumably sent to the model as the tool
    # description, so keep its wording in line with the router system prompt.
    return AgentResult(agent_type="faq")


async def hand_off_to_takeaway(ctx: RunContext[AgentDeps], message: str) -> AgentResult:
    """Route to takeaway agent for placing, modifying, or paying for orders."""
    # Output function of router_agent: the model selects it to pick this agent.
    # ctx and message are unused here; message is an argument the model fills in
    # (the system prompt asks it to pass the customer's original message).
    # NOTE(review): the docstring is presumably sent to the model as the tool
    # description, so keep its wording in line with the router system prompt.
    return AgentResult(agent_type="takeaway")


# System prompt for the classification-only router. Kept as a named constant so
# it can be inspected on its own; the text must stay in line with the
# hand_off_* docstrings and RouterFailure.
_ROUTER_SYSTEM_PROMPT = (
    "You are a restaurant assistant router. Your ONLY job is to classify "
    "the customer's message and hand off to the correct specialist agent.\n\n"
    "Available agents:\n"
    "- reservation: Booking, modifying, finding, or cancelling table reservations\n"
    "- faq: Questions about the restaurant, menu, opening hours, location, policies\n"
    "- takeaway: Placing, modifying, or paying for takeaway/delivery orders\n\n"
    "Rules:\n"
    "- NEVER answer the customer directly — always hand off\n"
    "- For ambiguous messages, prefer FAQ (safest default)\n"
    "- Pick the PRIMARY intent if the message mentions multiple topics\n"
    "- Pass the customer's original message unchanged to the handoff"
)

# Anthropic model for the router, configured from settings.
_router_model = AnthropicModel(
    settings.ROUTER_MODEL,
    provider=AnthropicProvider(api_key=settings.APP_ANTHROPIC_API_KEY),
)

# The router's possible outputs: one of the three hand-off functions, or a
# RouterFailure when it cannot classify.
router_agent = Agent(
    model=_router_model,
    system_prompt=_ROUTER_SYSTEM_PROMPT,
    deps_type=AgentDeps,
    output_type=[
        hand_off_to_reservation,
        hand_off_to_faq,
        hand_off_to_takeaway,
        RouterFailure,
    ],
)


async def classify_with_router(
    message: str,
    deps: AgentDeps,
) -> str:
    """Classify *message* with the LLM router, falling back to the keyword heuristic.

    Returns one of "reservation", "faq" or "takeaway":

    * router hands off through a ``hand_off_to_*`` output function -> that agent;
    * router returns :class:`RouterFailure` -> "faq";
    * router returns anything else, or ``router_agent.run`` raises an
      ``Exception`` -> result of :func:`classify_agent_heuristic`.

    Exceptions from the router call are logged with their traceback and not
    re-raised, so a router outage degrades to keyword routing. Every outcome is
    logged with ``restaurant_id``, ``router_latency_ms`` and ``fallback_used``
    (True when the keyword heuristic chose the agent).
    """
    start = time.monotonic()
    try:
        result = await router_agent.run(message, deps=deps)
    except Exception as exc:
        # Deliberate best-effort boundary: any model/network/validation failure
        # falls back to keyword routing instead of surfacing to the customer.
        latency_ms = round((time.monotonic() - start) * 1000, 1)
        agent_type = classify_agent_heuristic(message)
        logfire.error(
            "agent_router_error",
            error=str(exc),
            agent_type=agent_type,
            router_latency_ms=latency_ms,
            restaurant_id=deps.restaurant_id,
            fallback_used=True,
            _exc_info=True,
        )
        return agent_type

    latency_ms = round((time.monotonic() - start) * 1000, 1)
    output = result.output

    if isinstance(output, AgentResult):
        logfire.info(
            "agent_router_classify",
            agent_type=output.agent_type,
            router_latency_ms=latency_ms,
            restaurant_id=deps.restaurant_id,
            fallback_used=False,
        )
        return output.agent_type

    if isinstance(output, RouterFailure):
        # The system prompt tells the router to prefer FAQ when unsure, so an
        # explicit failure maps to the same safe default.
        logfire.warning(
            "agent_router_failure",
            explanation=output.explanation,
            agent_type="faq",
            router_latency_ms=latency_ms,
            restaurant_id=deps.restaurant_id,
            fallback_used=False,
        )
        return "faq"

    # Unexpected output type: defer to the keyword heuristic.
    agent_type = classify_agent_heuristic(message)
    logfire.warning(
        "agent_router_unexpected_output",
        output_type=type(output).__name__,
        agent_type=agent_type,
        router_latency_ms=latency_ms,
        restaurant_id=deps.restaurant_id,
        fallback_used=True,
    )
    return agent_type
