"""Tests for combo-aware conflict detection in table_allocation.

Covers two bugs fixed in _table_has_time_conflict / _combo_has_time_conflict:

1. _table_has_time_conflict now detects reservations via a combination that
   includes the checked table (walk-in with combination_id, table_id=NULL).
2. _combo_has_time_conflict no longer discards the combo's own ID, so
   a reservation directly on the combo is detected.
"""

import os
import sys
from datetime import datetime
from pathlib import Path
from typing import Any, cast
from unittest import IsolatedAsyncioTestCase
from unittest.mock import AsyncMock, patch

# Some app modules presumably need a database URL at import time; supply a
# placeholder unless the environment already defines one (``setdefault``
# never overwrites an existing value).  Assigning to ``_`` discards the
# returned string.
_ = os.environ.setdefault(
    "NEON_DATABASE_URL",
    "postgresql://user:pass@localhost:5432/testdb",
)

# Put the backend root (the parent of this file's directory) at the front of
# ``sys.path`` so the ``app.*`` imports below resolve, without adding it twice.
_BACKEND_PATH = Path(__file__).resolve().parents[1]
if str(_BACKEND_PATH) not in sys.path:
    sys.path[:0] = [str(_BACKEND_PATH)]

from app.models.table_combination import TableCombination  # noqa: E402
from app.services.table_allocation import (  # noqa: E402
    _combo_has_time_conflict,
    _table_has_time_conflict,
)

_MODULE = "app.services.table_allocation"
_RESERVED_AT = datetime(2026, 3, 15, 19, 0)
_END_TIME = datetime(2026, 3, 15, 20, 30)


# ── Mock helpers ─────────────────────────────────────────────────────────────


class _MockResult:
    """Minimal SQLAlchemy result proxy supporting both scalar and row access."""

    def __init__(
        self,
        *,
        scalar: object = None,
        rows: list[tuple[str, ...]] | None = None,
    ) -> None:
        self._scalar = scalar
        self._rows = rows or []

    def scalar_one_or_none(self) -> object:
        return self._scalar

    def all(self) -> list[tuple[str, ...]]:
        return self._rows


def _combo(combo_id: str, table_ids: list[str], capacity: int = 4) -> TableCombination:
    """Build a ``TableCombination`` for restaurant ``"r1"``.

    The combination's name is set to its id; ``capacity`` becomes
    ``combined_capacity``.  The object is only constructed, never persisted.
    """
    fields: dict[str, Any] = {
        "id": combo_id,
        "name": combo_id,
        "restaurant_id": "r1",
        "table_ids": table_ids,
        "combined_capacity": capacity,
    }
    return TableCombination(**fields)


class _SeqSession:
    """Async session returning pre-programmed results by execute order.

    The N-th ``execute()`` call receives the N-th programmed result.  Calls
    past the end of the list receive an empty ``_MockResult()`` (no scalar,
    no rows).  The SQL statement and parameters are ignored.

    ``_idx`` counts EVERY ``execute()`` call, including ones past the end of
    the programmed results.  Tests can therefore assert exactly how many
    queries were issued.  Previously the counter stopped once the results ran
    out, so e.g. ``_idx == 0`` on an empty session could never detect a
    stray query.
    """

    def __init__(self, results: list[_MockResult]) -> None:
        self._results = list(results)
        self._idx = 0  # number of execute() calls made so far

    async def execute(self, _stmt: object, _params: object = None) -> _MockResult:
        position = self._idx
        # Count the call unconditionally, before deciding what to return.
        self._idx += 1
        if position < len(self._results):
            return self._results[position]
        # Unprogrammed query: behave like "no match / no rows".
        return _MockResult()


# ═════════════════════════════════════════════════════════════════════════════
# _table_has_time_conflict — combo-aware detection
# ═════════════════════════════════════════════════════════════════════════════


class TestTableConflictViaCombo(IsolatedAsyncioTestCase):
    """_table_has_time_conflict detects reservations through combinations.

    The session mock answers queries in order.  The first ``execute`` is the
    direct ``table_id`` lookup; the second is the combination-based lookup.
    """

    async def test_detects_reservation_via_combination(self) -> None:
        """Table T1 is in combo C1.  Walk-in reservation stored as
        {combination_id=C1, table_id=NULL}.  T1 must be seen as occupied."""
        session = cast(
            Any,
            _SeqSession(
                [
                    # Query 1: direct table_id check → no direct reservation
                    _MockResult(scalar=None),
                    # Query 2: combo-based check → combo subquery finds C1 which has a
                    # reservation; the outer query matches that reservation
                    _MockResult(scalar="res-walk-in"),
                ]
            ),
        )
        result = await _table_has_time_conflict(session, "t1", _RESERVED_AT, _END_TIME)
        self.assertTrue(result)
        # The hit must come from the combo-based (second) query.
        self.assertEqual(session._idx, 2)

    async def test_no_false_positive_from_unrelated_combo(self) -> None:
        """Table T1 is NOT in any combo with a reservation → no conflict."""
        session = cast(
            Any,
            _SeqSession(
                [
                    # Query 1: no direct reservation
                    _MockResult(scalar=None),
                    # Query 2: subquery finds no combos containing T1, so no match
                    _MockResult(scalar=None),
                ]
            ),
        )
        result = await _table_has_time_conflict(session, "t1", _RESERVED_AT, _END_TIME)
        self.assertFalse(result)
        # Both the direct and the combo-based query must have run.
        self.assertEqual(session._idx, 2)

    async def test_direct_conflict_still_short_circuits(self) -> None:
        """Direct table_id reservation is detected on the first query;
        combo query is never reached."""
        session = cast(
            Any,
            _SeqSession(
                [
                    # Query 1: direct hit
                    _MockResult(scalar="res-direct"),
                ]
            ),
        )
        result = await _table_has_time_conflict(session, "t1", _RESERVED_AT, _END_TIME)
        self.assertTrue(result)
        # Only the first query was issued — short-circuited
        self.assertEqual(session._idx, 1)

    async def test_exclude_reservation_id_applies_to_both_queries(self) -> None:
        """With exclude_reservation_id set, both the direct and the
        combo-based query still run, and no conflict is reported when
        neither matches.

        NOTE: the mock session ignores the SQL it receives.  This test does
        NOT verify that the exclusion filter is present in either query.
        """
        session = cast(
            Any,
            _SeqSession(
                [
                    # Query 1: direct check excluding res-1 → no match
                    _MockResult(scalar=None),
                    # Query 2: combo check excluding res-1 → no match
                    _MockResult(scalar=None),
                ]
            ),
        )
        result = await _table_has_time_conflict(
            session, "t1", _RESERVED_AT, _END_TIME, exclude_reservation_id="res-1"
        )
        self.assertFalse(result)
        # Both queries consumed
        self.assertEqual(session._idx, 2)


# ═════════════════════════════════════════════════════════════════════════════
# _combo_has_time_conflict — self-inclusion (no longer discards own combo ID)
# ═════════════════════════════════════════════════════════════════════════════


class TestComboSelfConflict(IsolatedAsyncioTestCase):
    """_combo_has_time_conflict detects a reservation on the combo itself.

    ``_table_has_time_conflict`` (step 1, awaited once per constituent table)
    is patched out.  The session mock then answers the remaining queries in
    order:
    - one combo-gathering query per constituent table;
    - then the final reservation check over the gathered combo IDs (step 2).
    """

    async def test_direct_reservation_on_combo_detected(self) -> None:
        """Combo C1 has a walk-in reservation.  Even with no individual-table
        conflicts, the combo must report a conflict (bug: combo.id was
        discarded from the check set)."""
        combo = _combo("c1", ["t1", "t2"])

        # Every constituent table is individually free, so step 1 cannot
        # short-circuit; step 2 must be the one that catches it.
        table_check = AsyncMock(return_value=False)
        with patch(f"{_MODULE}._table_has_time_conflict", table_check):
            session = cast(
                Any,
                _SeqSession(
                    [
                        # Combo gathering for T1 → combo C1 contains T1
                        _MockResult(rows=[("c1",)]),
                        # Combo gathering for T2 → combo C1 contains T2
                        _MockResult(rows=[("c1",)]),
                        # Final reservation check: combination_id IN ['c1'] → hit
                        _MockResult(scalar="res-walk-in"),
                    ]
                ),
            )
            result = await _combo_has_time_conflict(session, combo, _RESERVED_AT, _END_TIME)
        self.assertTrue(result)
        # Step 1 ran for both tables without finding anything ...
        self.assertEqual(table_check.await_count, 2)
        # ... and the hit came from the final (third) query.
        self.assertEqual(session._idx, 3)

    async def test_no_conflict_when_combo_free(self) -> None:
        """Combo C1 has no reservations — no conflict even though combo.id is
        now included in the check set."""
        combo = _combo("c1", ["t1", "t2"])

        with patch(
            f"{_MODULE}._table_has_time_conflict",
            AsyncMock(return_value=False),
        ):
            session = cast(
                Any,
                _SeqSession(
                    [
                        # Combo gathering for T1 → C1
                        _MockResult(rows=[("c1",)]),
                        # Combo gathering for T2 → C1
                        _MockResult(rows=[("c1",)]),
                        # Final reservation check → no match
                        _MockResult(scalar=None),
                    ]
                ),
            )
            result = await _combo_has_time_conflict(session, combo, _RESERVED_AT, _END_TIME)
        self.assertFalse(result)

    async def test_conflict_via_other_combo_sharing_table(self) -> None:
        """Combo C1 (T1, T2) and combo C2 (T1, T3) share table T1.
        A reservation on C2 means C1's table T1 is occupied → conflict."""
        combo = _combo("c1", ["t1", "t2"])

        with patch(
            f"{_MODULE}._table_has_time_conflict",
            AsyncMock(return_value=False),
        ):
            session = cast(
                Any,
                _SeqSession(
                    [
                        # Combo gathering for T1 → both C1 and C2 contain T1
                        _MockResult(rows=[("c1",), ("c2",)]),
                        # Combo gathering for T2 → only C1
                        _MockResult(rows=[("c1",)]),
                        # Final reservation check: combination_id IN ['c1','c2'] → hit via C2
                        _MockResult(scalar="res-on-c2"),
                    ]
                ),
            )
            result = await _combo_has_time_conflict(session, combo, _RESERVED_AT, _END_TIME)
        self.assertTrue(result)
        # The hit came from the final (third) query.
        self.assertEqual(session._idx, 3)

    async def test_constituent_table_conflict_short_circuits(self) -> None:
        """If _table_has_time_conflict finds a conflict on a constituent table,
        combo check returns True without reaching step 2."""
        combo = _combo("c1", ["t1", "t2"])

        def _conflict_on_t1(
            _session: object, tid: str, *_args: object, **_kwargs: object
        ) -> bool:
            # Only T1 reports a direct conflict.
            return tid == "t1"

        # An AsyncMock records awaits itself, so no hand-rolled counter is needed.
        table_check = AsyncMock(side_effect=_conflict_on_t1)
        with patch(f"{_MODULE}._table_has_time_conflict", table_check):
            session = cast(Any, _SeqSession([]))  # No session calls expected
            result = await _combo_has_time_conflict(session, combo, _RESERVED_AT, _END_TIME)
        self.assertTrue(result)
        # Only checked T1 (first hit) — didn't continue to T2 or step 2
        self.assertEqual(table_check.await_count, 1)
        self.assertEqual(session._idx, 0)

    async def test_exclude_reservation_id_propagated(self) -> None:
        """exclude_reservation_id is forwarded to _table_has_time_conflict.

        NOTE: the mock session ignores the SQL it receives.  Forwarding to the
        combo reservation query is NOT verified here.
        """
        combo = _combo("c1", ["t1"])
        exclude_id = "res-being-updated"

        mock_conflict = AsyncMock(return_value=False)
        with patch(f"{_MODULE}._table_has_time_conflict", mock_conflict):
            session = cast(
                Any,
                _SeqSession(
                    [
                        # Combo gathering for T1
                        _MockResult(rows=[("c1",)]),
                        # Final reservation check → no match (excluded)
                        _MockResult(scalar=None),
                    ]
                ),
            )
            result = await _combo_has_time_conflict(
                session,
                combo,
                _RESERVED_AT,
                _END_TIME,
                exclude_reservation_id=exclude_id,
            )
        self.assertFalse(result)
        # assert_awaited_* (unlike assert_called_*) also fails if the coroutine
        # was created but never awaited.
        mock_conflict.assert_awaited_once_with(
            session,
            "t1",
            _RESERVED_AT,
            _END_TIME,
            exclude_id,
        )
