"""Tests for the walk-in availability endpoint and walk_in source validation."""

from __future__ import annotations

import os
import sys
from datetime import datetime, timedelta
from pathlib import Path
from unittest import IsolatedAsyncioTestCase, TestCase
from unittest.mock import AsyncMock, MagicMock, patch

_ = os.environ.setdefault("NEON_DATABASE_URL", "postgresql://user:pass@localhost:5432/testdb")

# Put the backend root (the parent of this test file's directory) on sys.path
# so the ``app`` package is importable.
_BACKEND_PATH = Path(__file__).resolve().parents[1]
if str(_BACKEND_PATH) not in sys.path:
    sys.path.insert(0, str(_BACKEND_PATH))

# All ``app.*`` imports must come after the setup above. The sys.path entry is
# what makes ``app`` importable, and the NEON_DATABASE_URL default should exist
# before any app module is loaded (presumably the app's DB/config setup reads
# it at import time — confirm).
from app.models.restaurant import Restaurant  # noqa: E402
from app.models.table_combination import TableCombination  # noqa: E402
from app.routers.tables import get_table_availability  # noqa: E402

# Module whose private conflict helpers the tests patch.
_ALLOC_MODULE = "app.routers.tables"
# Restaurant id used by the mock restaurant and combination fixtures.
_RID = "rest-1"

# Default booking window passed to the availability handler.
_START = "2026-03-15T19:00:00"
_DURATION = 90
_START_DT = datetime.fromisoformat(_START)
_END_DT = _START_DT + timedelta(minutes=_DURATION)


def _mock_restaurant(restaurant_id: str = _RID) -> Restaurant:
    """Return a minimal Restaurant to pass as the handler's ``restaurant`` argument."""
    restaurant = Restaurant(team_id="team-1", slug="test", name="Test", id=restaurant_id)
    return restaurant


def _table_row(
    table_id: str, label: str, capacity: int, combo_membership: int = 0
) -> tuple[str, str, int, int]:
    """Build one fake table row as ``(id, label, capacity, combo_membership)``.

    The positional order presumably mirrors the columns of the handler's table
    query — confirm against ``app.routers.tables``.
    """
    row = (table_id, label, capacity, combo_membership)
    return row


def _mock_combo(
    combo_id: str, name: str, table_ids: list[str], combined_capacity: int
) -> TableCombination:
    """Build a TableCombination that belongs to the shared test restaurant ``_RID``."""
    combo = TableCombination(
        restaurant_id=_RID,
        id=combo_id,
        name=name,
        combined_capacity=combined_capacity,
        table_ids=table_ids,
    )
    return combo


class _FakeResult:
    """Stand-in for the object ``session.execute()`` returns; supports only ``.all()``."""

    def __init__(self, rows: list[tuple[str, str, int, int]]) -> None:
        # Held by reference: ``all()`` hands back this exact list object.
        self._data = rows

    def all(self) -> list[tuple[str, str, int, int]]:
        """Return the wrapped rows unchanged."""
        return self._data


class _FakeScalarsResult:
    """Stand-in for the ``.scalars()`` view of an execute() result; supports only ``.all()``."""

    def __init__(self, combos: list[TableCombination]) -> None:
        # Held by reference: ``all()`` hands back this exact list object.
        self._items = combos

    def all(self) -> list[TableCombination]:
        """Return the wrapped combinations unchanged."""
        return self._items


def _build_session(
    table_rows: list[tuple[str, str, int, int]],
    combos: list[TableCombination],
) -> AsyncMock:
    """Return a mock session whose ``execute()`` answers two queries in order.

    The first call yields ``table_rows`` through ``.all()``. Every later call
    yields ``combos`` through ``.scalars().all()``.
    """
    calls_made = [0]

    async def _execute(_stmt: object, _params: object = None) -> object:
        calls_made[0] += 1
        if calls_made[0] > 1:
            # Combination query: the handler reads it via .scalars().all().
            combo_result = MagicMock()
            combo_result.scalars.return_value = _FakeScalarsResult(combos)
            return combo_result
        return _FakeResult(table_rows)

    session = AsyncMock()
    session.execute = _execute
    return session


# ═══════════════════════════════════════════════════════════════════════════════
# TestAvailabilityEndpoint
# ═══════════════════════════════════════════════════════════════════════════════


class TestAvailabilityEndpoint(IsolatedAsyncioTestCase):
    """Test get_table_availability handler with mocked session and restaurant.

    The router's private conflict helpers are patched with explicit
    ``AsyncMock``s, so each test controls which tables and combinations count
    as booked.
    """

    async def test_all_tables_available_when_no_conflicts(self) -> None:
        rows = [_table_row("t1", "T1", 4), _table_row("t2", "T2", 6)]
        session = _build_session(rows, [])

        with patch(f"{_ALLOC_MODULE}._table_has_time_conflict", AsyncMock(return_value=False)):
            resp = await get_table_availability(
                party_size=2,
                start=_START,
                duration_minutes=_DURATION,
                session=session,
                restaurant=_mock_restaurant(),
            )

        assert len(resp.tables) == 2
        assert all(t.available for t in resp.tables)

    async def test_occupied_table_marked_unavailable(self) -> None:
        rows = [_table_row("t1", "T1", 4), _table_row("t2", "T2", 6)]
        session = _build_session(rows, [])

        async def _conflict(_s: object, table_id: str, *_a: object, **_kw: object) -> bool:
            return table_id == "t1"

        # Explicit AsyncMock instead of relying on patch() to infer the mock type.
        with patch(
            f"{_ALLOC_MODULE}._table_has_time_conflict", AsyncMock(side_effect=_conflict)
        ):
            resp = await get_table_availability(
                party_size=2,
                start=_START,
                duration_minutes=_DURATION,
                session=session,
                restaurant=_mock_restaurant(),
            )

        by_id = {t.table_id: t for t in resp.tables}
        assert by_id["t1"].available is False
        assert by_id["t2"].available is True

    async def test_undersized_table_unavailable(self) -> None:
        rows = [_table_row("t1", "T1", 2), _table_row("t2", "T2", 6)]
        session = _build_session(rows, [])
        mock_conflict = AsyncMock(return_value=False)

        with patch(f"{_ALLOC_MODULE}._table_has_time_conflict", mock_conflict):
            resp = await get_table_availability(
                party_size=4,
                start=_START,
                duration_minutes=_DURATION,
                session=session,
                restaurant=_mock_restaurant(),
            )

        by_id = {t.table_id: t for t in resp.tables}
        assert by_id["t1"].available is False
        assert by_id["t2"].available is True
        # Conflict check should NOT have been called for the undersized table
        conflict_table_ids = [c.args[1] for c in mock_conflict.call_args_list]
        assert "t1" not in conflict_table_ids

    async def test_combinations_included_with_availability(self) -> None:
        rows = [_table_row("t1", "T1", 2)]
        combo_avail = _mock_combo("c1", "Combo 1", ["t1", "t2"], 8)
        combo_unavail = _mock_combo("c2", "Combo 2", ["t3", "t4"], 10)
        session = _build_session(rows, [combo_avail, combo_unavail])

        async def _combo_conflict(
            _s: object, combo: TableCombination, *_a: object, **_kw: object
        ) -> bool:
            return combo.id == "c2"

        with (
            patch(f"{_ALLOC_MODULE}._table_has_time_conflict", AsyncMock(return_value=False)),
            patch(
                f"{_ALLOC_MODULE}._combo_has_time_conflict",
                AsyncMock(side_effect=_combo_conflict),
            ),
        ):
            resp = await get_table_availability(
                party_size=4,
                start=_START,
                duration_minutes=_DURATION,
                session=session,
                restaurant=_mock_restaurant(),
            )

        assert len(resp.combinations) == 2
        by_id = {c.id: c for c in resp.combinations}
        assert by_id["c1"].available is True
        assert by_id["c2"].available is False

    async def test_duration_change_affects_availability(self) -> None:
        """Different duration_minutes values change the end_time passed to conflict checks."""
        rows = [_table_row("t1", "T1", 4)]

        for duration in (30, _DURATION):
            # Fresh session per call: _build_session serves table rows only on
            # the first execute().
            session = _build_session(rows, [])
            mock_conflict = AsyncMock(return_value=False)
            with patch(f"{_ALLOC_MODULE}._table_has_time_conflict", mock_conflict):
                await get_table_availability(
                    party_size=2,
                    start=_START,
                    duration_minutes=duration,
                    session=session,
                    restaurant=_mock_restaurant(),
                )

            # Fail clearly (instead of a TypeError on a None call_args) if the
            # conflict helper was never awaited.
            mock_conflict.assert_awaited()
            # _table_has_time_conflict(session, tid, start_dt, end_dt)
            call_args = mock_conflict.call_args.args
            assert call_args[2] == _START_DT
            assert call_args[3] == _START_DT + timedelta(minutes=duration)

    async def test_undersized_combo_unavailable(self) -> None:
        """Combination with combined_capacity < party_size is marked unavailable."""
        rows: list[tuple[str, str, int, int]] = []
        small_combo = _mock_combo("c1", "Small", ["t1", "t2"], 3)
        session = _build_session(rows, [small_combo])
        mock_combo_conflict = AsyncMock(return_value=False)

        with (
            patch(f"{_ALLOC_MODULE}._table_has_time_conflict", AsyncMock(return_value=False)),
            patch(f"{_ALLOC_MODULE}._combo_has_time_conflict", mock_combo_conflict),
        ):
            resp = await get_table_availability(
                party_size=5,
                start=_START,
                duration_minutes=_DURATION,
                session=session,
                restaurant=_mock_restaurant(),
            )

        assert len(resp.combinations) == 1
        assert resp.combinations[0].available is False
        # Should not have called conflict check for undersized combo
        mock_combo_conflict.assert_not_called()

    async def test_tables_sorted_best_fit(self) -> None:
        """Tables sorted: available first, fewest combo memberships, smallest capacity."""
        rows = [
            _table_row("t1", "Big-in-combos", 6, combo_membership=3),
            _table_row("t2", "Small-standalone", 2, combo_membership=0),
            _table_row("t3", "Medium-one-combo", 4, combo_membership=1),
            _table_row("t4", "Undersized", 1, combo_membership=0),
        ]
        session = _build_session(rows, [])

        with patch(f"{_ALLOC_MODULE}._table_has_time_conflict", AsyncMock(return_value=False)):
            resp = await get_table_availability(
                party_size=2,
                start=_START,
                duration_minutes=_DURATION,
                session=session,
                restaurant=_mock_restaurant(),
            )

        labels = [t.table_label for t in resp.tables]
        # Available tables first (t2, t3, t1), sorted by (membership, capacity).
        # t4 is undersized → unavailable, pushed to end.
        assert labels == ["Small-standalone", "Medium-one-combo", "Big-in-combos", "Undersized"]

    async def test_tables_available_before_unavailable(self) -> None:
        """Unavailable tables are always sorted after available ones."""
        rows = [
            _table_row("t1", "Available-big", 8, combo_membership=0),
            _table_row("t2", "Occupied-small", 4, combo_membership=0),
        ]
        session = _build_session(rows, [])

        async def _conflict(_s: object, table_id: str, *_a: object, **_kw: object) -> bool:
            return table_id == "t2"

        with patch(
            f"{_ALLOC_MODULE}._table_has_time_conflict", AsyncMock(side_effect=_conflict)
        ):
            resp = await get_table_availability(
                party_size=2,
                start=_START,
                duration_minutes=_DURATION,
                session=session,
                restaurant=_mock_restaurant(),
            )

        assert len(resp.tables) == 2
        assert resp.tables[0].table_label == "Available-big"
        assert resp.tables[0].available is True
        assert resp.tables[1].table_label == "Occupied-small"
        assert resp.tables[1].available is False

    async def test_combos_sorted_by_overlap_then_capacity(self) -> None:
        """Combos sorted: available first, lowest overlap score, smallest capacity."""
        rows: list[tuple[str, str, int, int]] = []
        # c1 and c2 both contain t1, so each overlaps exactly one other combo.
        # c3 shares no table with any other combo (zero overlap).
        c1 = _mock_combo("c1", "SharedBig", ["t1", "t2"], 8)
        c2 = _mock_combo("c2", "SharedSmall", ["t1", "t3"], 6)
        c3 = _mock_combo("c3", "Isolated", ["t4", "t5"], 10)
        session = _build_session(rows, [c1, c2, c3])

        with (
            patch(f"{_ALLOC_MODULE}._table_has_time_conflict", AsyncMock(return_value=False)),
            patch(f"{_ALLOC_MODULE}._combo_has_time_conflict", AsyncMock(return_value=False)),
        ):
            resp = await get_table_availability(
                party_size=4,
                start=_START,
                duration_minutes=_DURATION,
                session=session,
                restaurant=_mock_restaurant(),
            )

        names = [c.name for c in resp.combinations]
        # Overlap decides first: c3 (overlap 0) leads despite the largest capacity.
        # c1 and c2 tie on overlap (1 each), so the smaller capacity (c2=6) wins.
        assert names == ["Isolated", "SharedSmall", "SharedBig"]


# ═══════════════════════════════════════════════════════════════════════════════
# TestWalkInSource
# ═══════════════════════════════════════════════════════════════════════════════


class TestWalkInSource(TestCase):
    """Verify source='walk_in' works with the ReservationCreate schema.

    These are synchronous schema checks, so a plain ``TestCase`` suffices; an
    ``IsolatedAsyncioTestCase`` would create and tear down an event loop per
    test for no benefit.
    """

    def test_reservation_create_schema_accepts_walk_in_source(self) -> None:
        """``source="walk_in"`` is kept on the model and in ``model_dump()``."""
        from app.models.reservation import ReservationCreate

        data = ReservationCreate(
            guest_name="Walk-in Guest",
            party_size=3,
            reserved_at=datetime(2026, 3, 15, 19, 0),
            end_time=datetime(2026, 3, 15, 21, 0),
            source="walk_in",
        )
        assert data.source == "walk_in"
        dumped = data.model_dump()
        assert dumped["source"] == "walk_in"

    def test_reservation_create_schema_includes_combination_id(self) -> None:
        """``combination_id`` is kept on the model and in ``model_dump()``."""
        from app.models.reservation import ReservationCreate

        data = ReservationCreate(
            guest_name="Combo Guest",
            party_size=6,
            reserved_at=datetime(2026, 3, 15, 20, 0),
            end_time=datetime(2026, 3, 15, 22, 0),
            source="walk_in",
            combination_id="combo-123",
        )
        assert data.combination_id == "combo-123"
        dumped = data.model_dump()
        assert dumped["combination_id"] == "combo-123"

    def test_reservation_create_default_source_is_manual(self) -> None:
        """Omitting ``source`` yields ``"manual"``."""
        from app.models.reservation import ReservationCreate

        data = ReservationCreate(
            guest_name="Default Source",
            party_size=2,
            reserved_at=datetime(2026, 3, 15, 18, 0),
            end_time=datetime(2026, 3, 15, 20, 0),
        )
        assert data.source == "manual"
