from __future__ import annotations

import ast
from pathlib import Path
from unittest import TestCase

# Repository root: the third ancestor directory of this (symlink-resolved)
# test file — parents[0] is the file's own directory, so parents[2] is two
# levels above it. Note parents[2] raises IndexError if the file sits fewer
# than three directories below the filesystem root.
_REPO_ROOT = Path(__file__).resolve().parents[2]


class TestChatHistoryEndpointContracts(TestCase):
    """Contract tests for chat dashboard endpoints in backend/app/routers/chat.py.

    These are static checks: they read the router (and ``backend/app/main.py``)
    as text and inspect either the parsed AST or the raw source. Nothing from
    the backend application is imported or executed.
    """

    _CHAT_ROUTER_RELPATH = "backend/app/routers/chat.py"

    # ── Source / AST helpers ───────────────────────────────────────────────

    def _read_chat_router(self) -> str:
        """Return the text of the chat router module."""
        path = _REPO_ROOT / self._CHAT_ROUTER_RELPATH
        return path.read_text(encoding="utf-8")

    def _read_main(self) -> str:
        """Return the text of the application entry module ``main.py``."""
        path = _REPO_ROOT / "backend/app/main.py"
        return path.read_text(encoding="utf-8")

    def _parse_chat_module(self) -> ast.Module:
        """Parse the chat router into an ``ast.Module``."""
        source = self._read_chat_router()
        # Passing the filename makes any SyntaxError point at the router file
        # instead of "<unknown>".
        return ast.parse(source, filename=self._CHAT_ROUTER_RELPATH)

    def _find_function_def(self, name: str) -> ast.FunctionDef | ast.AsyncFunctionDef | None:
        """Return the first sync/async function named *name*, or ``None``.

        ``ast.walk`` visits the whole module, so nested functions and methods
        are matched as well as top-level definitions.
        """
        tree = self._parse_chat_module()
        for node in ast.walk(tree):
            if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)) and node.name == name:
                return node
        return None

    def _require_function_def(self, name: str) -> ast.FunctionDef | ast.AsyncFunctionDef:
        """Like ``_find_function_def`` but fail the test when *name* is missing.

        Uses ``self.fail`` rather than a bare ``assert`` so the check is not
        stripped under ``python -O`` and the failure names the function.
        """
        fn = self._find_function_def(name)
        if fn is None:
            self.fail(f"{name} function not found")
        return fn

    @staticmethod
    def _param_names(fn: ast.FunctionDef | ast.AsyncFunctionDef) -> list[str]:
        """Return all named parameters of *fn*.

        Includes positional-only, regular and keyword-only parameters.
        Reading only ``fn.args.args`` would miss parameters declared before
        ``/`` or after a bare ``*``. A signature refactor of that kind should
        not break these contract tests.
        """
        args = fn.args
        return [arg.arg for arg in (*args.posonlyargs, *args.args, *args.kwonlyargs)]

    def _function_source(self, name: str) -> str:
        """Return the exact source text of function *name* in the chat router.

        Fails the test if the function is missing or its segment cannot be
        recovered.
        """
        fn = self._require_function_def(name)
        segment = ast.get_source_segment(self._read_chat_router(), fn)
        if segment is None:
            self.fail(f"could not extract source for {name}")
        return segment

    def _class_field_names(self, class_name: str) -> set[str]:
        """Return names of annotated fields declared directly in *class_name*.

        Only simple ``name: type`` statements in the class body count. Fails
        the test if no class with that name exists in the router.
        """
        for node in ast.walk(self._parse_chat_module()):
            if isinstance(node, ast.ClassDef) and node.name == class_name:
                return {
                    stmt.target.id
                    for stmt in node.body
                    if isinstance(stmt, ast.AnnAssign) and isinstance(stmt.target, ast.Name)
                }
        self.fail(f"{class_name} class not found")

    # ── Router registration ────────────────────────────────────────────────

    def test_chat_router_registered_in_main(self) -> None:
        source = self._read_main()
        self.assertIn("chat.router", source)

    # ── GET /sessions (list_chat_sessions) ─────────────────────────────────

    def test_list_sessions_endpoint_exists(self) -> None:
        fn = self._find_function_def("list_chat_sessions")
        self.assertIsNotNone(fn, "list_chat_sessions function not found")

    def test_list_sessions_accepts_pagination_params(self) -> None:
        param_names = self._param_names(self._require_function_def("list_chat_sessions"))
        self.assertIn("limit", param_names)
        self.assertIn("cursor", param_names)

    def test_list_sessions_accepts_date_filters(self) -> None:
        param_names = self._param_names(self._require_function_def("list_chat_sessions"))
        self.assertIn("date_from", param_names)
        self.assertIn("date_to", param_names)

    def test_list_sessions_accepts_channel_filter(self) -> None:
        param_names = self._param_names(self._require_function_def("list_chat_sessions"))
        self.assertIn("channel", param_names)

    def test_list_sessions_depends_on_restaurant_tenant(self) -> None:
        param_names = self._param_names(self._require_function_def("list_chat_sessions"))
        self.assertIn("restaurant", param_names)

    def test_list_sessions_scoped_to_tenant(self) -> None:
        source = self._read_chat_router()
        # The query must filter by restaurant_id to enforce tenant isolation.
        # NOTE(review): this scans the whole module, so the same expression in
        # get_session_messages alone satisfies it. Consider scoping it to the
        # list_chat_sessions source (or its query helper) once that structure
        # is confirmed.
        self.assertIn("Conversation.restaurant_id == restaurant.id", source)

    def test_list_sessions_returns_response_model(self) -> None:
        source = self._read_chat_router()
        self.assertIn("response_model=ChatSessionsResponse", source)

    def test_list_sessions_has_cursor_pagination(self) -> None:
        source = self._read_chat_router()
        # The implementation fetches limit+1 rows and derives next_cursor
        self.assertIn("limit + 1", source)
        self.assertIn("next_cursor", source)

    # ── GET /sessions/{session_id}/messages ─────────────────────────────────

    def test_get_messages_endpoint_exists(self) -> None:
        fn = self._find_function_def("get_session_messages")
        self.assertIsNotNone(fn, "get_session_messages function not found")

    def test_get_messages_requires_session_id(self) -> None:
        param_names = self._param_names(self._require_function_def("get_session_messages"))
        self.assertIn("session_id", param_names)

    def test_get_messages_returns_chat_message_list(self) -> None:
        source = self._read_chat_router()
        self.assertIn("response_model=list[ChatMessageRead]", source)

    def test_get_messages_cross_tenant_returns_404(self) -> None:
        """The endpoint must check Conversation.restaurant_id == restaurant.id
        and raise 404 if the session belongs to another restaurant."""
        # Restrict the assertions to get_session_messages' own source text so
        # other endpoints cannot satisfy them.
        segment = self._function_source("get_session_messages")
        self.assertIn("Conversation.restaurant_id == restaurant.id", segment)
        self.assertIn("Chat session not found", segment)

    def test_get_messages_orders_by_created_at_asc(self) -> None:
        source = self._read_chat_router()
        self.assertIn("Message.created_at.asc()", source)

    # ── GET /stats (get_chat_stats) ─────────────────────────────────────────

    def test_stats_endpoint_exists(self) -> None:
        fn = self._find_function_def("get_chat_stats")
        self.assertIsNotNone(fn, "get_chat_stats function not found")

    def test_stats_returns_response_model(self) -> None:
        source = self._read_chat_router()
        self.assertIn("response_model=ChatStatsResponse", source)

    def test_stats_includes_total_sessions(self) -> None:
        source = self._read_chat_router()
        self.assertIn("total_sessions", source)

    def test_stats_includes_avg_messages(self) -> None:
        source = self._read_chat_router()
        self.assertIn("avg_messages_per_session", source)
        self.assertIn("func.avg(Conversation.message_count)", source)

    def test_stats_includes_channel_distribution(self) -> None:
        source = self._read_chat_router()
        self.assertIn("channel_distribution", source)
        self.assertIn("group_by(Conversation.channel)", source)

    def test_stats_scoped_to_tenant(self) -> None:
        param_names = self._param_names(self._require_function_def("get_chat_stats"))
        self.assertIn("restaurant", param_names)

    def test_stats_accepts_date_range(self) -> None:
        param_names = self._param_names(self._require_function_def("get_chat_stats"))
        self.assertIn("date_from", param_names)
        self.assertIn("date_to", param_names)

    # ── Response models ────────────────────────────────────────────────────

    def test_chat_session_summary_model_fields(self) -> None:
        field_names = self._class_field_names("ChatSessionSummary")
        expected = {"id", "channel", "status", "started_at", "message_count"}
        self.assertTrue(
            expected.issubset(field_names),
            f"Missing fields in ChatSessionSummary: {expected - field_names}",
        )

    def test_chat_stats_response_model_fields(self) -> None:
        field_names = self._class_field_names("ChatStatsResponse")
        expected = {
            "total_sessions",
            "avg_messages_per_session",
            "channel_distribution",
        }
        self.assertTrue(
            expected.issubset(field_names),
            f"Missing fields in ChatStatsResponse: {expected - field_names}",
        )

    # ── Dashboard preview chat parity with public widget ───────────────────

    def test_chat_stream_uses_shared_website_caller_builder(self) -> None:
        """``chat_stream`` (authenticated dashboard preview) MUST route caller
        construction through ``_build_website_caller`` so its conversational
        behaviour is identical to the public widget endpoint."""
        body = self._function_source("chat_stream")
        self.assertIn("_build_website_caller", body)
        # The legacy `channel="dashboard", verified=True` shortcut must be gone:
        # operators previously bypassed verification, which made the dashboard
        # preview diverge from the real guest experience.
        self.assertNotIn('channel="dashboard"', body)
        self.assertNotIn("verified=True", body)

    def test_public_chat_stream_uses_shared_website_caller_builder(self) -> None:
        """``public_chat_stream`` MUST use the same helper — the two endpoints
        derive caller identity from a single source of truth."""
        body = self._function_source("public_chat_stream")
        self.assertIn("_build_website_caller", body)

    def test_build_website_caller_produces_unverified_website_identity(self) -> None:
        """The helper must default to ``channel="website"``, ``verified=False``,
        ``customer_id=None`` — the verification flow then lifts those if a
        matching, non-expired ``ConversationVerification`` row exists."""
        body = self._function_source("_build_website_caller")
        self.assertIn('channel="website"', body)
        self.assertIn("verified=False", body)
        self.assertIn("ConversationVerification", body)
