from typing import Any

import logfire
from pydantic_ai import RunContext
from sqlmodel import select

from app.agents.deps import AgentDeps
from app.models.restaurant import Restaurant
from app.rag.embeddings import generate_embedding
from app.rag.search import hybrid_search

# Agent language code -> language name passed as ``language=`` to hybrid_search().
# Codes not listed here fall back to "dutch" at the call site.
LANGUAGE_FTS_MAP = dict(nl="dutch", en="english")


async def search_knowledge_base_impl(
    ctx: RunContext[AgentDeps],
    query: str,
    source: str | None = None,
) -> list[dict[str, Any]]:
    """Search the restaurant's knowledge base for passages matching ``query``.

    The query is embedded with ``generate_embedding`` and passed, together with
    the raw text, to ``hybrid_search`` (top 5 hits, scoped to
    ``ctx.deps.restaurant_id``). The agent language is mapped through
    ``LANGUAGE_FTS_MAP``; unknown languages use ``"dutch"``. ``source`` is
    forwarded unchanged to ``hybrid_search``.

    Returns:
        One dict per hit with ``content``, ``source``, ``score`` and
        ``metadata`` (``{}`` when the hit has no/falsy metadata).
    """
    query_embedding = await generate_embedding(query)
    deps = ctx.deps
    hits = await hybrid_search(
        session=deps.session,
        query_text=query,
        query_embedding=query_embedding,
        restaurant_id=deps.restaurant_id,
        limit=5,
        language=LANGUAGE_FTS_MAP.get(deps.language, "dutch"),
        source=source,
    )
    hit_count = len(hits)
    # NOTE(review): the threshold counters are hard-coded (all hits "above",
    # none "below"); presumably hybrid_search already drops low-score hits —
    # confirm, otherwise these metrics carry no information.
    logfire.info(
        "rag.retrieval_quality",
        results_returned=hit_count,
        results_above_threshold=hit_count,
        results_below_threshold=0,
    )
    passthrough_keys = ("content", "source", "score")
    return [
        {
            **{key: hit[key] for key in passthrough_keys},
            "metadata": hit.get("metadata") or {},
        }
        for hit in hits
    ]


async def get_restaurant_policies_impl(
    ctx: RunContext[AgentDeps],
) -> dict[str, Any]:
    """Return the restaurant's booking policies, filling gaps with defaults.

    The restaurant row for ``ctx.deps.restaurant_id`` is loaded and its
    ``settings`` mapping consulted. If the restaurant does not exist or has no
    settings, every key takes its default: ``min_advance_hours=1``,
    ``max_advance_days=90``, ``max_party_size=20``,
    ``cancellation_policy=None``.
    """
    stmt = select(Restaurant).where(Restaurant.id == ctx.deps.restaurant_id)
    restaurant = (await ctx.deps.session.execute(stmt)).scalar_one_or_none()

    if not restaurant:
        settings_data = {}
    else:
        settings_data = restaurant.settings or {}

    # (key, default) pairs; order determines the key order of the result.
    policy_defaults = (
        ("min_advance_hours", 1),
        ("max_advance_days", 90),
        ("max_party_size", 20),
        ("cancellation_policy", None),
    )
    return {key: settings_data.get(key, fallback) for key, fallback in policy_defaults}


async def is_open_now_impl(ctx: RunContext[AgentDeps]) -> dict[str, Any]:
    """Report whether the restaurant is open right now, in its local time.

    The restaurant's ``timezone`` is used when it names a loadable IANA zone.
    In these cases it falls back to ``Europe/Amsterdam``:

    - the restaurant is missing;
    - the timezone is blank;
    - the zone is unknown;
    - the key is malformed, such as an absolute path.

    Only today's service blocks whose ``block_type`` is ``"open"`` count.
    A block without a ``block_type`` attribute is treated as open.

    Returns:
        A dict with ``open`` (bool), ``date`` (``YYYY-MM-DD``) and ``time``
        (``HH:MM``), all in local time. When open, ``current_window``
        (``"HH:MM-HH:MM"``) gives the matched block. When closed, no
        further keys are included (the next opening is not computed).
    """
    from datetime import datetime
    from zoneinfo import ZoneInfo

    from app.services.service_blocks import find_block_for_time, resolve_service_blocks

    session = ctx.deps.session
    res = await session.execute(select(Restaurant).where(Restaurant.id == ctx.deps.restaurant_id))
    restaurant = res.scalar_one_or_none()
    tz_name = (getattr(restaurant, "timezone", None) or "").strip() or "Europe/Amsterdam"
    try:
        tz = ZoneInfo(tz_name)
    except (KeyError, ValueError):
        # ZoneInfoNotFoundError (a KeyError subclass) covers unknown zones;
        # ZoneInfo raises ValueError for malformed keys (absolute paths,
        # non-normalized paths, ".." segments). Both use the default zone.
        tz = ZoneInfo("Europe/Amsterdam")

    now_local = datetime.now(tz)
    # NOTE(review): only today's blocks are resolved, so a block that started
    # yesterday and runs past midnight is presumably not matched — confirm.
    blocks = await resolve_service_blocks(session, ctx.deps.restaurant_id, now_local.date())
    windows = [rb for rb in blocks if getattr(rb.block, "block_type", "open") == "open"]
    match = find_block_for_time(windows, now_local.time()) if windows else None

    result: dict[str, Any] = {
        "open": match is not None,
        "date": now_local.date().isoformat(),
        "time": now_local.strftime("%H:%M"),
    }
    if match is not None:
        start_s = match.block.start_time.strftime("%H:%M")
        end_s = match.block.end_time.strftime("%H:%M")
        result["current_window"] = f"{start_s}-{end_s}"
    return result


async def get_opening_hours_impl(
    ctx: RunContext[AgentDeps],
    start_date: str | None = None,
    days: int = 14,
) -> list[dict[str, Any]]:
    """Return opening hours per day for a date range, honoring overrides.

    - ``start_date`` is a ``YYYY-MM-DD`` string; surrounding whitespace is
      ignored. If it is None, empty, or unparseable, today's date in the
      restaurant's local timezone is used.
    - ``days`` is capped to [1, 31] to avoid excessive ranges. A value that
      cannot be converted with ``int()`` falls back to the default of 14.
    - The restaurant's ``timezone`` falls back to ``Europe/Amsterdam`` when
      it is missing, blank, unknown, or a malformed key.
    - Uses resolve_service_blocks() so holiday/override closures are reflected.

    Returns:
        One dict per day with these keys:

        - ``date``: ISO date.
        - ``windows``: sorted ``"HH:MM-HH:MM"`` strings for the day's
          ``"open"`` blocks, with overlapping or touching blocks merged.
        - ``closed``: True when no window remains.
    """
    from datetime import datetime, timedelta
    from zoneinfo import ZoneInfo

    from app.services.service_blocks import resolve_service_blocks

    session = ctx.deps.session
    # Load restaurant to determine timezone
    res = await session.execute(select(Restaurant).where(Restaurant.id == ctx.deps.restaurant_id))
    restaurant = res.scalar_one_or_none()
    tz_name = (getattr(restaurant, "timezone", None) or "").strip() or "Europe/Amsterdam"
    try:
        tz = ZoneInfo(tz_name)
    except (KeyError, ValueError):
        # ZoneInfoNotFoundError (a KeyError subclass) covers unknown zones;
        # ZoneInfo raises ValueError for malformed keys (absolute paths,
        # non-normalized paths, ".." segments). Both use the default zone.
        tz = ZoneInfo("Europe/Amsterdam")

    try:
        start = datetime.strptime((start_date or "").strip(), "%Y-%m-%d").date()
    except ValueError:
        # Missing, empty or unparseable input: default to today in local time.
        start = datetime.now(tz).date()

    try:
        days = int(days)
    except (TypeError, ValueError):
        # Mirror the best-effort handling of start_date instead of crashing.
        days = 14
    days = max(1, min(days, 31))

    out: list[dict[str, Any]] = []
    for offset in range(days):
        d = start + timedelta(days=offset)
        # Resolve blocks with overrides for this date
        blocks = await resolve_service_blocks(session, ctx.deps.restaurant_id, d)
        # Zero-padded "HH:MM" strings sort and compare chronologically.
        windows = sorted(
            (rb.block.start_time.strftime("%H:%M"), rb.block.end_time.strftime("%H:%M"))
            for rb in blocks
            if getattr(rb.block, "block_type", "open") == "open"
        )
        # Merge overlapping/duplicate/adjacent windows.
        # NOTE(review): assumes end >= start within the day; a block ending
        # after midnight (e.g. "22:00"-"02:00") would not merge correctly.
        merged: list[tuple[str, str]] = []
        for s, e in windows:
            if merged and s <= merged[-1][1]:
                prev_start, prev_end = merged[-1]
                merged[-1] = (prev_start, max(prev_end, e))
            else:
                merged.append((s, e))
        out.append(
            {
                "date": d.isoformat(),
                "windows": [f"{s}-{e}" for s, e in merged],
                "closed": not merged,
            }
        )
    return out
