from typing import Any

import logfire
from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncSession


def _vector_literal(values: list[float]) -> str:
    """Serialize *values* as a pgvector text literal such as ``[0.1,0.2]``.

    Every element is passed through ``float()`` first. ``str(list)`` would
    instead use each element's ``repr``, which for NumPy 2 scalars is
    ``np.float64(...)`` and is not parseable as a vector.
    """
    return "[" + ",".join(str(float(value)) for value in values) + "]"


async def hybrid_search(
    session: AsyncSession,
    query_text: str,
    query_embedding: list[float],
    restaurant_id: str,
    limit: int = 5,
    min_score: float = 0.008,
    language: str = "dutch",
    source: str | None = None,
) -> list[dict[str, Any]]:
    """Retrieve knowledge documents by fusing vector and full-text rankings.

    Two rankings are computed over ``knowledge_document`` rows belonging to
    ``restaurant_id`` (optionally restricted to one ``source``):

    * a vector ranking by ascending pgvector ``<=>`` distance to
      ``query_embedding`` (rows with a NULL embedding are skipped), and
    * a full-text ranking by descending ``ts_rank`` for rows whose content
      matches ``plainto_tsquery(query_text)``.

    The rankings are merged with Reciprocal Rank Fusion:
    ``score = 1/(60 + vector_rank) + 1/(60 + fts_rank)``. A term is 0 when
    the row is absent from that ranking. Ties inside each ranking, and in
    the final ordering, are broken by ``id`` so scores are reproducible
    between calls.

    Args:
        session: Async SQLAlchemy session used to execute the query.
        query_text: Raw user query for full-text matching.
        query_embedding: Embedding of the query. Its length must match the
            ``embedding`` column.
        restaurant_id: Only documents of this restaurant are considered.
        limit: Maximum number of rows returned.
        min_score: Inclusive lower bound on the fused score. The default
            0.008 is about 1/(60 + 65), so a row found by only one
            ranking needs roughly rank <= 65 in it to pass.
        language: PostgreSQL text-search configuration. Only ``"dutch"``
            and ``"english"`` are accepted; anything else silently falls
            back to ``"dutch"``.
        source: If not None, only documents with this ``source`` value are
            considered.

    Returns:
        Dicts with keys ``id``, ``content``, ``source``, ``metadata`` and
        ``score`` (float), ordered by descending score. The list is empty
        when nothing clears ``min_score``; an empty result also emits a
        ``rag.hybrid_search.empty`` warning.
    """
    # Whitelist keeps the value to built-in PostgreSQL text-search configs.
    if language not in {"dutch", "english"}:
        language = "dutch"

    with logfire.span(
        "rag.hybrid_search",
        query_length=len(query_text),
        language=language,
        source_filter=source,
        restaurant_id=restaurant_id,
        limit=limit,
        # Recorded separately because the "min_score" span attribute set
        # below holds the lowest *returned* score, not this threshold.
        score_threshold=min_score,
    ) as span:
        # CAST(...) rather than `::` because `:name` is bind syntax in
        # text(). The explicit regconfig cast removes reliance on the
        # driver/server inferring the parameter type for to_tsvector.
        # The trailing `id` in each ORDER BY makes ROW_NUMBER deterministic
        # for tied distances/ranks, so tied rows don't swap RRF scores.
        sql = text(
            """
            WITH vector_ranked AS (
                SELECT id,
                       ROW_NUMBER() OVER (
                           ORDER BY embedding <=> CAST(:query_embedding AS vector), id
                       ) AS rank
                FROM knowledge_document
                WHERE restaurant_id = :restaurant_id
                  AND (CAST(:source AS text) IS NULL OR source = CAST(:source AS text))
                  AND embedding IS NOT NULL
            ),
            fts_ranked AS (
                SELECT id,
                       ROW_NUMBER() OVER (
                           ORDER BY ts_rank(
                               to_tsvector(CAST(:language AS regconfig), content),
                               plainto_tsquery(CAST(:language AS regconfig), :query_text)
                           ) DESC, id
                       ) AS rank
                FROM knowledge_document
                WHERE restaurant_id = :restaurant_id
                  AND (CAST(:source AS text) IS NULL OR source = CAST(:source AS text))
                  AND to_tsvector(CAST(:language AS regconfig), content)
                      @@ plainto_tsquery(CAST(:language AS regconfig), :query_text)
            ),
            combined AS (
                SELECT COALESCE(v.id, f.id) AS id,
                       COALESCE(1.0 / (60 + v.rank), 0) +
                       COALESCE(1.0 / (60 + f.rank), 0) AS rrf_score
                FROM vector_ranked v
                FULL OUTER JOIN fts_ranked f ON v.id = f.id
            )
            SELECT kd.id, kd.content, kd.source, kd.metadata, c.rrf_score
            FROM combined c
            JOIN knowledge_document kd ON kd.id = c.id
            WHERE c.rrf_score >= :min_score
            ORDER BY c.rrf_score DESC, kd.id
            LIMIT :limit
            """
        )

        result = await session.execute(
            sql,
            {
                "query_embedding": _vector_literal(query_embedding),
                "query_text": query_text,
                "restaurant_id": restaurant_id,
                "source": source,
                "limit": limit,
                "min_score": min_score,
                "language": language,
            },
        )

        results = [
            {
                "id": row.id,
                "content": row.content,
                "source": row.source,
                "metadata": row.metadata,
                # rrf_score arrives as a SQL numeric; expose a plain float.
                "score": float(row.rrf_score),
            }
            for row in result.fetchall()
        ]

        scores = [item["score"] for item in results]
        span.set_attribute("result_count", len(results))
        span.set_attribute("max_score", max(scores, default=0.0))
        span.set_attribute("min_score", min(scores, default=0.0))

        if not results:
            logfire.warn(
                "rag.hybrid_search.empty",
                restaurant_id=restaurant_id,
                query_length=len(query_text),
                language=language,
            )

        return results
