"""Tests for generate_embeddings_batch and knowledge backfill batching."""

from __future__ import annotations

from unittest.mock import AsyncMock, MagicMock, patch

import pytest

from app.rag.embeddings import generate_embeddings_batch
from app.routers.knowledge import backfill_embeddings

# Mark every test in this module with ``anyio`` so the ``async def`` tests are
# driven by the anyio pytest plugin.
pytestmark = pytest.mark.anyio

# Dotted module paths used as ``patch()`` targets, so names are patched where
# the code under test looks them up.
EMBEDDINGS_MOD = "app.rag.embeddings"
KNOWLEDGE_MOD = "app.routers.knowledge"


def _mock_embedding_item(index: int, embedding: list[float]) -> MagicMock:
    item = MagicMock()
    item.index = index
    item.embedding = embedding
    return item


def _scalars_result(values: list[MagicMock]) -> MagicMock:
    result = MagicMock()
    scalars_mock = MagicMock()
    scalars_mock.all.return_value = values
    result.scalars.return_value = scalars_mock
    return result


def _knowledge_doc(content: str) -> MagicMock:
    doc = MagicMock()
    doc.content = content
    doc.embedding = None
    return doc


async def test_batch_empty_input_returns_empty() -> None:
    """Embedding zero texts yields an empty list."""
    embeddings = await generate_embeddings_batch([])
    assert embeddings == []


async def test_batch_returns_embeddings_in_order() -> None:
    """Results follow input order even when the API returns items shuffled.

    The mocked response lists index 1 before index 0. The function must
    reorder by each item's ``index`` so that ``result[i]`` belongs to
    ``texts[i]``. It must also send every text in a single API request, which
    is the point of batching.
    """
    mock_response = MagicMock()
    # Deliberately out of order: the item for index 1 comes first.
    mock_response.data = [
        _mock_embedding_item(1, [0.2] * 1536),
        _mock_embedding_item(0, [0.1] * 1536),
    ]
    mock_client = AsyncMock()
    mock_client.embeddings.create.return_value = mock_response

    with (
        patch(f"{EMBEDDINGS_MOD}.get_settings") as mock_settings,
        patch(f"{EMBEDDINGS_MOD}.AsyncOpenAI", return_value=mock_client),
    ):
        mock_settings.return_value.APP_OPENAI_API_KEY = "test-key"
        result = await generate_embeddings_batch(["text1", "text2"])

    assert len(result) == 2
    assert result[0] == [0.1] * 1536
    assert result[1] == [0.2] * 1536
    # One request for the whole batch, carrying the texts in input order.
    # The OpenAI client's ``embeddings.create`` takes keyword-only arguments,
    # so the texts are found under ``kwargs["input"]``.
    mock_client.embeddings.create.assert_awaited_once()
    assert mock_client.embeddings.create.await_args.kwargs["input"] == ["text1", "text2"]


async def test_batch_no_api_key_raises_value_error() -> None:
    """A blank ``APP_OPENAI_API_KEY`` makes the call raise ``ValueError``."""
    with (
        patch(f"{EMBEDDINGS_MOD}.get_settings") as settings_factory,
        pytest.raises(ValueError, match="APP_OPENAI_API_KEY not configured"),
    ):
        settings_factory.return_value.APP_OPENAI_API_KEY = ""
        await generate_embeddings_batch(["some text"])


async def test_batch_api_failure_propagates() -> None:
    """An exception raised by the embeddings client reaches the caller unchanged."""
    failing_client = AsyncMock()
    failing_client.embeddings.create.side_effect = RuntimeError("API error")

    with patch(f"{EMBEDDINGS_MOD}.get_settings") as settings_factory:
        settings_factory.return_value.APP_OPENAI_API_KEY = "test-key"
        with (
            patch(f"{EMBEDDINGS_MOD}.AsyncOpenAI", return_value=failing_client),
            pytest.raises(RuntimeError, match="API error"),
        ):
            await generate_embeddings_batch(["text"])


async def test_backfill_endpoint_batches_updates_and_commits() -> None:
    """Backfill embeds 101 docs in two batches (100 + 1) and commits once.

    The test checks four things:

    - each doc receives the embedding at its position in its batch's result;
    - texts are sent to ``generate_embeddings_batch`` in document order;
    - no failure is counted or logged;
    - the session is committed exactly once.
    """
    docs = [_knowledge_doc(f"doc-{index}") for index in range(101)]
    session = AsyncMock()
    session.execute.return_value = _scalars_result(docs)
    restaurant = MagicMock()
    restaurant.id = "restaurant-1"

    first_batch_embeddings = [[float(index)] for index in range(100)]
    second_batch_embeddings = [[100.0]]

    with (
        patch(
            f"{KNOWLEDGE_MOD}.generate_embeddings_batch",
            new=AsyncMock(side_effect=[first_batch_embeddings, second_batch_embeddings]),
        ) as mock_batch,
        # Patched so the success path can prove nothing was reported as failed.
        patch(f"{KNOWLEDGE_MOD}.logfire.error") as mock_logfire_error,
    ):
        result = await backfill_embeddings(session=session, restaurant=restaurant)

    assert result.processed == 101
    assert result.failed == 0
    assert docs[0].embedding == [0.0]
    assert docs[99].embedding == [99.0]
    assert docs[100].embedding == [100.0]
    assert mock_batch.await_count == 2
    assert mock_batch.await_args_list[0].args == ([f"doc-{index}" for index in range(100)],)
    assert mock_batch.await_args_list[1].args == (["doc-100"],)
    mock_logfire_error.assert_not_called()
    session.commit.assert_awaited_once_with()


async def test_backfill_endpoint_counts_failed_batches_and_continues() -> None:
    """A failing batch is logged and counted, and later batches still run.

    The first batch (docs 0-99) raises, so those docs keep ``embedding is None``
    and are counted as failed. The second batch (doc 100) succeeds. The error
    log is expected to carry the explicit ``restaurant_id`` argument rather
    than ``restaurant.id``, and the session is still committed once.
    """
    override_id = "restaurant-override"
    docs = [_knowledge_doc(f"doc-{i}") for i in range(101)]

    session = AsyncMock()
    session.execute.return_value = _scalars_result(docs)
    restaurant = MagicMock()
    restaurant.id = "restaurant-1"

    batch_stub = AsyncMock(side_effect=[RuntimeError("boom"), [[100.0]]])
    with (
        patch(f"{KNOWLEDGE_MOD}.generate_embeddings_batch", new=batch_stub),
        patch(f"{KNOWLEDGE_MOD}.logfire.error") as log_error,
    ):
        result = await backfill_embeddings(
            restaurant_id=override_id,
            session=session,
            restaurant=restaurant,
        )

    assert (result.processed, result.failed) == (1, 100)
    # The first and last docs of the failed batch stay unembedded.
    for untouched in (docs[0], docs[99]):
        assert untouched.embedding is None
    assert docs[100].embedding == [100.0]
    assert batch_stub.await_count == 2
    log_error.assert_called_once_with(
        "knowledge_backfill_batch_failed",
        error="boom",
        restaurant_id=override_id,
        batch_size=100,
        retriable=True,
    )
    session.commit.assert_awaited_once_with()
