from __future__ import annotations

import ast
from pathlib import Path
from unittest import TestCase

# Repository root: the grandparent of the directory containing this test file
# (``parents[2]`` of the resolved path). Production modules under test are
# located relative to this root.
_REPO_ROOT = Path(__file__).resolve().parents[2]


class TestReservationDateFilteringContracts(TestCase):
    """Contract tests for date-range filtering in the reservations router.

    Uses AST inspection and targeted source assertions instead of brittle
    raw-string matching. Verifies the structural contracts the frontend
    and API consumers depend on.

    The router source and its parsed AST are cached on the *class*, not on
    the instance: ``unittest`` creates a fresh TestCase instance for every
    test method, so an instance-level cache would re-read and re-parse the
    file once per test.
    """

    # Class-level caches shared by all test methods (see class docstring).
    _source: str | None = None
    _tree: ast.Module | None = None

    # ── Helpers ───────────────────────────────────────────────────────────

    def _router_source(self) -> str:
        """Return the router module's source text, reading it at most once per class."""
        cls = type(self)
        if cls._source is None:
            path = _REPO_ROOT / "backend/app/routers/reservations.py"
            cls._source = path.read_text(encoding="utf-8")
        return cls._source

    def _router_tree(self) -> ast.Module:
        """Return the parsed AST of the router module, parsing it at most once per class."""
        cls = type(self)
        if cls._tree is None:
            cls._tree = ast.parse(self._router_source())
        return cls._tree

    def _find_function(self, name: str) -> ast.FunctionDef | ast.AsyncFunctionDef | None:
        """Return the first (sync or async) function named *name* found by ``ast.walk``, or None."""
        for node in ast.walk(self._router_tree()):
            if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)) and node.name == name:
                return node
        return None

    def _require_function(self, name: str) -> ast.FunctionDef | ast.AsyncFunctionDef:
        """Like ``_find_function`` but fails the test with a clear message when missing."""
        fn = self._find_function(name)
        self.assertIsNotNone(fn, f"Function {name} not found in reservations router")
        assert fn is not None  # type narrowing only; assertIsNotNone already failed otherwise
        return fn

    def _function_param_names(self, name: str) -> list[str]:
        """Return every named parameter of function *name*, in declaration order.

        Positional-only, regular, and keyword-only parameters are all included.
        Endpoints frequently declare query parameters after a bare ``*``, and
        those live in ``kwonlyargs``, not ``args``. ``*args``/``**kwargs`` are
        excluded.
        """
        signature = self._require_function(name).args
        return [
            arg.arg
            for arg in (*signature.posonlyargs, *signature.args, *signature.kwonlyargs)
        ]

    def _function_source(self, name: str) -> str:
        """Return the exact source text of function *name* (empty string if unavailable)."""
        fn = self._require_function(name)
        return ast.get_source_segment(self._router_source(), fn) or ""

    # ── Endpoint signature contracts ──────────────────────────────────────

    def test_list_reservations_accepts_date_from_param(self) -> None:
        params = self._function_param_names("list_reservations")
        self.assertIn("date_from", params)

    def test_list_reservations_accepts_date_to_param(self) -> None:
        params = self._function_param_names("list_reservations")
        self.assertIn("date_to", params)

    def test_list_reservations_accepts_single_date_param(self) -> None:
        params = self._function_param_names("list_reservations")
        self.assertIn("date", params)

    def test_list_reservations_accepts_status_filter(self) -> None:
        params = self._function_param_names("list_reservations")
        self.assertIn("status", params)

    def test_list_reservations_accepts_customer_id_filter(self) -> None:
        params = self._function_param_names("list_reservations")
        self.assertIn("customer_id", params)

    def test_customer_id_filters_query(self) -> None:
        """customer_id query param must filter Reservation.customer_id."""
        source = self._router_source()
        self.assertIn("Reservation.customer_id == customer_id", source)

    def test_customer_id_skips_default_limit(self) -> None:
        """A customer-scoped query must not be capped at 100 — operators
        need the full reservation history for a single guest."""
        # Look for an elif customer_id branch in the date-handling chain,
        # scoped to the endpoint body rather than the whole module.
        source = self._function_source("list_reservations")
        self.assertIn("elif customer_id is not None", source)

    # ── Date filtering logic contracts ────────────────────────────────────

    def test_date_range_filters_reserved_at_with_naive_utc_window(self) -> None:
        """When date_from and date_to are provided, the query must filter
        Reservation.reserved_at using the restaurant-tz-aware naive-UTC
        day window from ``app.utils.tz.local_day_window``."""
        source = self._router_source()
        self.assertIn("local_day_window(from_date, tz)", source)
        self.assertIn("local_day_window(to_date, tz)", source)
        self.assertIn("Reservation.reserved_at >= date_from_start", source)
        # Half-open range: ``date_to_end`` is used as an *exclusive* upper
        # bound (``<``), so no reservation is double-counted across days.
        self.assertIn("Reservation.reserved_at < date_to_end", source)

    def test_date_range_resolves_tz_from_restaurant(self) -> None:
        """The boundaries must be tied to the authenticated restaurant's
        timezone, never the server timezone."""
        source = self._router_source()
        self.assertIn("resolve_tz(restaurant.timezone)", source)

    def test_single_date_filters_full_local_day(self) -> None:
        """Single date parameter also uses the naive-UTC ``local_day_window``."""
        source = self._router_source()
        self.assertIn("local_day_window(filter_date, tz)", source)
        self.assertIn("day_start", source)
        self.assertIn("day_end", source)

    def test_no_date_params_applies_default_limit(self) -> None:
        """When no date filters provided, a limit is applied to prevent unbounded queries."""
        source = self._router_source()
        # NOTE: this only checks that a ``.limit(100)`` call exists in the
        # module; it does not verify which branch of the filter chain applies it.
        self.assertIn(".limit(100)", source)

    # ── Date parsing uses stdlib fromisoformat ────────────────────────────

    def test_date_parsing_uses_fromisoformat(self) -> None:
        """Dates must be parsed via date.fromisoformat for ISO 8601 compliance."""
        # Count real call sites in the AST, so mentions inside comments or
        # string literals are not mistaken for parsing code.
        call_count = sum(
            1
            for node in ast.walk(self._router_tree())
            if isinstance(node, ast.Call)
            and (
                (isinstance(node.func, ast.Attribute) and node.func.attr == "fromisoformat")
                or (isinstance(node.func, ast.Name) and node.func.id == "fromisoformat")
            )
        )
        # Both date_from/date_to and single date paths use fromisoformat
        self.assertGreaterEqual(
            call_count,
            3,
            "Expected at least 3 fromisoformat calls (date_from, date_to, date)",
        )

    def test_imports_date_type(self) -> None:
        """The module must import ``date`` or ``datetime`` from the ``datetime`` module."""
        found = any(
            isinstance(node, ast.ImportFrom)
            and node.module == "datetime"
            and any(alias.name in ("date", "datetime") for alias in node.names)
            for node in ast.walk(self._router_tree())
        )
        self.assertTrue(found, "datetime imports not found")

    # ── Invalid date format handling ──────────────────────────────────────

    def test_invalid_date_returns_422(self) -> None:
        """Invalid date format must raise HTTPException with status 422."""
        source = self._router_source()
        # The ValueError handler must return 422
        self.assertIn("status_code=422", source)

    def test_invalid_date_error_message(self) -> None:
        """The error detail must instruct the caller to use YYYY-MM-DD format."""
        source = self._router_source()
        self.assertIn("Invalid date format. Use YYYY-MM-DD", source)

    def test_date_validation_catches_value_error(self) -> None:
        """fromisoformat is wrapped in try/except ValueError.

        Both ``except ValueError:`` and tuple forms such as
        ``except (ValueError, TypeError):`` are recognised.
        """
        fn = self._require_function("list_reservations")
        handler_types: set[str] = set()
        for node in ast.walk(fn):
            if not isinstance(node, ast.ExceptHandler) or node.type is None:
                continue
            caught = node.type.elts if isinstance(node.type, ast.Tuple) else [node.type]
            handler_types.update(exc.id for exc in caught if isinstance(exc, ast.Name))
        self.assertIn("ValueError", handler_types)

    # ── Tenant scoping ────────────────────────────────────────────────────

    def test_query_scoped_to_restaurant(self) -> None:
        """All reservation queries must be tenant-scoped."""
        source = self._router_source()
        self.assertIn("Reservation.restaurant_id == restaurant.id", source)
