"""Tests for structured message history helpers."""

from __future__ import annotations

from datetime import UTC, datetime, timedelta

import pytest
from pydantic_ai.messages import (
    ModelMessage,
    ModelRequest,
    ModelResponse,
    TextPart,
    ToolCallPart,
    ToolReturnPart,
    UserPromptPart,
)

from app.agents.history import (
    SESSION_WINDOW_HOURS,
    deserialize_messages,
    is_conversation_stale,
    keep_recent_messages,
    serialize_messages,
)


class TestSerializeDeserialize:
    """Round-trip and fallback behaviour of ``serialize_messages`` / ``deserialize_messages``."""

    def test_round_trip_with_user_and_text(self) -> None:
        """A user prompt and a text reply come back with the same types and content."""
        original = [
            ModelRequest(parts=[UserPromptPart(content="hello")]),
            ModelResponse(parts=[TextPart(content="hi there")]),
        ]

        restored = deserialize_messages(serialize_messages(original))

        assert len(restored) == 2
        request, response = restored

        assert isinstance(request, ModelRequest)
        user_part = request.parts[0]
        assert isinstance(user_part, UserPromptPart)
        assert user_part.content == "hello"

        assert isinstance(response, ModelResponse)
        text_part = response.parts[0]
        assert isinstance(text_part, TextPart)
        assert text_part.content == "hi there"

    def test_deserialize_none_returns_empty(self) -> None:
        """A missing (``None``) payload yields an empty history."""
        restored = deserialize_messages(None)
        assert restored == []

    def test_deserialize_invalid_json_returns_empty(self) -> None:
        """Unparseable input is treated as no history rather than raising."""
        assert deserialize_messages("not valid json") == []

    def test_empty_list_round_trip(self) -> None:
        """An empty history survives serialization unchanged."""
        payload = serialize_messages([])
        restored = deserialize_messages(payload)
        assert restored == []


class TestKeepRecentMessages:
    """Trimming behaviour of the ``keep_recent_messages`` history processor."""

    @pytest.mark.anyio
    async def test_under_limit_unchanged(self) -> None:
        """A history shorter than ``limit`` is returned unchanged."""
        messages: list[ModelMessage] = [
            ModelRequest(parts=[UserPromptPart(content="a")]),
            ModelResponse(parts=[TextPart(content="b")]),
        ]
        result = await keep_recent_messages(messages, limit=10)
        assert result == messages

    @pytest.mark.anyio
    async def test_over_limit_trims(self) -> None:
        """Only the newest ``limit`` messages are kept, in their original order."""
        messages: list[ModelMessage] = [
            ModelRequest(parts=[UserPromptPart(content=f"msg-{i}")]) for i in range(10)
        ]
        result = await keep_recent_messages(messages, limit=3)

        assert len(result) == 3
        # Check every kept message (not just the two ends) so a dropped or
        # reordered middle element is caught as well.
        for message, expected in zip(result, ["msg-7", "msg-8", "msg-9"], strict=True):
            part = message.parts[0]
            assert isinstance(part, UserPromptPart)
            assert part.content == expected

    @pytest.mark.anyio
    async def test_preserves_tool_pair(self) -> None:
        """When the trim boundary splits a tool-call/tool-return pair, include both.

        Distinct leading/trailing filler objects are used so the identity
        checks can tell which filler survived the trim. With a single shared
        filler object, ``result[2] is filler`` would pass no matter which
        position it came from.
        """
        tool_response = ModelResponse(
            parts=[ToolCallPart(tool_name="search", args='{"q":"x"}', tool_call_id="c1")]
        )
        tool_request = ModelRequest(
            parts=[ToolReturnPart(tool_name="search", content="result", tool_call_id="c1")]
        )
        leading_filler = ModelRequest(parts=[UserPromptPart(content="leading filler")])
        trailing_filler = ModelRequest(parts=[UserPromptPart(content="trailing filler")])

        # [leading_filler, tool_response, tool_request, trailing_filler]
        # limit=2 → last 2 = [tool_request, trailing_filler]
        # tool_request starts with ToolReturnPart → tool_response must be
        # pulled in too, while leading_filler is still dropped.
        messages: list[ModelMessage] = [
            leading_filler,
            tool_response,
            tool_request,
            trailing_filler,
        ]
        result = await keep_recent_messages(messages, limit=2)

        assert len(result) == 3
        assert result[0] is tool_response
        assert result[1] is tool_request
        assert result[2] is trailing_filler

    @pytest.mark.anyio
    async def test_empty_input(self) -> None:
        """An empty history stays empty when the default ``limit`` is used."""
        result = await keep_recent_messages([])
        assert result == []


class TestTwoTurnConversation:
    """A realistic two-turn exchange, including a tool round-trip, survives serialization."""

    @staticmethod
    def _assert_user_prompt(message: ModelMessage, expected: str) -> None:
        """Assert *message* is a request whose first part is a user prompt equal to *expected*."""
        assert isinstance(message, ModelRequest)
        prompt = message.parts[0]
        assert isinstance(prompt, UserPromptPart)
        assert prompt.content == expected

    @staticmethod
    def _assert_text_reply(message: ModelMessage, expected: str) -> None:
        """Assert *message* is a response whose first part is text equal to *expected*."""
        assert isinstance(message, ModelResponse)
        reply = message.parts[0]
        assert isinstance(reply, TextPart)
        assert reply.content == expected

    def test_round_trip_with_tool_usage(self) -> None:
        tool_name = "check_availability"
        call_id = "call-1"
        call_args = '{"date":"2026-03-12","party_size":2}'
        tool_result = '{"available":true}'

        history: list[ModelMessage] = [
            ModelRequest(parts=[UserPromptPart(content="Book a table for two tonight")]),
            ModelResponse(
                parts=[ToolCallPart(tool_name=tool_name, args=call_args, tool_call_id=call_id)]
            ),
            ModelRequest(
                parts=[
                    ToolReturnPart(
                        tool_name=tool_name,
                        content=tool_result,
                        tool_call_id=call_id,
                    )
                ]
            ),
            ModelResponse(parts=[TextPart(content="I found a table for two tonight.")]),
            ModelRequest(parts=[UserPromptPart(content="Please confirm it under Alex")]),
            ModelResponse(parts=[TextPart(content="Your reservation is confirmed under Alex.")]),
        ]

        restored = deserialize_messages(serialize_messages(history))
        assert len(restored) == 6

        # Turn 1: user prompt, tool call, tool return, text answer.
        self._assert_user_prompt(restored[0], "Book a table for two tonight")

        call_message = restored[1]
        assert isinstance(call_message, ModelResponse)
        call = call_message.parts[0]
        assert isinstance(call, ToolCallPart)
        assert call.tool_name == tool_name
        assert call.args == call_args
        assert call.tool_call_id == call_id

        return_message = restored[2]
        assert isinstance(return_message, ModelRequest)
        returned = return_message.parts[0]
        assert isinstance(returned, ToolReturnPart)
        assert returned.tool_name == tool_name
        assert returned.content == tool_result
        # The tool return must still be linked to the call that produced it.
        assert returned.tool_call_id == call.tool_call_id

        self._assert_text_reply(restored[3], "I found a table for two tonight.")

        # Turn 2: follow-up prompt and final answer.
        self._assert_user_prompt(restored[4], "Please confirm it under Alex")
        self._assert_text_reply(restored[5], "Your reservation is confirmed under Alex.")


class TestNullMessageHistoryFallback:
    """A missing or empty stored history must degrade to an empty message list."""

    def test_deserialize_none_returns_empty(self) -> None:
        """``None`` deserializes to an empty list."""
        restored = deserialize_messages(None)
        assert restored == []

    def test_deserialize_blank_string_returns_empty(self) -> None:
        """An empty string deserializes to an empty list."""
        blank = ""
        assert deserialize_messages(blank) == []

    @pytest.mark.anyio
    async def test_keep_recent_messages_empty_list_returns_empty(self) -> None:
        """Trimming an empty history yields an empty history."""
        trimmed = await keep_recent_messages([])
        assert trimmed == []

    def test_none_survives_empty_round_trip(self) -> None:
        """``None`` → ``[]`` → serialized → deserialized is still ``[]``."""
        empty_history = deserialize_messages(None)
        payload = serialize_messages(empty_history)
        assert deserialize_messages(payload) == []


class TestIsConversationStale:
    """Session-window staleness checks for ``is_conversation_stale``.

    Timestamps are built as naive datetimes holding UTC wall-clock time.
    NOTE(review): presumably this mirrors how the last-message timestamp is
    stored and compared (naive UTC). Confirm against ``is_conversation_stale``.
    """

    @staticmethod
    def _utc_now_naive() -> datetime:
        """Return the current UTC time with ``tzinfo`` stripped."""
        return datetime.now(UTC).replace(tzinfo=None)

    def test_none_is_not_stale(self) -> None:
        """Brand-new conversations with no messages are not stale."""
        assert is_conversation_stale(None) is False

    def test_recent_message_is_not_stale(self) -> None:
        """A message well inside the window is not stale."""
        recent = self._utc_now_naive() - timedelta(hours=SESSION_WINDOW_HOURS - 1)
        assert is_conversation_stale(recent) is False

    def test_old_message_is_stale(self) -> None:
        """A message well outside the window is stale."""
        old = self._utc_now_naive() - timedelta(hours=SESSION_WINDOW_HOURS + 1)
        assert is_conversation_stale(old) is True

    def test_exactly_at_boundary_is_not_stale(self) -> None:
        """A message one second inside the window is not stale.

        Despite the name, this does not hit the boundary exactly. Doing that
        would require freezing the clock, because ``is_conversation_stale``
        presumably reads its own "now" slightly after this test does. The
        one-second margin absorbs that execution delay and avoids flakiness.
        As a consequence, this test does not pin whether the comparison at the
        exact cutoff is strict (``<`` vs ``<=``). Together with
        ``test_just_past_boundary_is_stale`` it brackets the cutoff to within
        a second on either side.
        """
        near_boundary = (
            self._utc_now_naive()
            - timedelta(hours=SESSION_WINDOW_HOURS)
            + timedelta(seconds=1)
        )
        assert is_conversation_stale(near_boundary) is False

    def test_just_past_boundary_is_stale(self) -> None:
        """A message one second outside the window is stale."""
        past = self._utc_now_naive() - timedelta(hours=SESSION_WINDOW_HOURS, seconds=1)
        assert is_conversation_stale(past) is True

    def test_session_window_is_12_hours(self) -> None:
        """Guard against accidental changes to the window."""
        assert SESSION_WINDOW_HOURS == 12
