"""Tests for the table-conflict guard (_guard_table_conflict).

The guard uses SELECT … FOR UPDATE on FloorTable rows to serialise concurrent
bookings for the same physical table, closing the TOCTOU gap between the
availability check and the INSERT/UPDATE.
"""

from __future__ import annotations

import asyncio
import importlib
import os
import sys
from datetime import datetime
from pathlib import Path
from unittest import IsolatedAsyncioTestCase
from unittest.mock import AsyncMock, MagicMock, patch

# Supply a placeholder database URL unless the environment already defines one.
# NOTE(review): presumably needed because a project module reads it at import
# time — confirm.
_ = os.environ.setdefault(
    "NEON_DATABASE_URL", "postgresql://user:pass@localhost:5432/testdb"
)

# Make ``backend`` and ``restate_services`` importable without installing them.
# Each directory is prepended to sys.path unless already present, in this
# order, so when both are newly added restate_services is searched first.
_REPO_ROOT = Path(__file__).resolve().parents[2]
_BACKEND_PATH = _REPO_ROOT / "backend"
_RESTATE_SERVICES_PATH = _REPO_ROOT / "restate_services"
for _extra_path in map(str, (_BACKEND_PATH, _RESTATE_SERVICES_PATH)):
    if _extra_path not in sys.path:
        sys.path.insert(0, _extra_path)
del _extra_path

reservation_mod = importlib.import_module("objects.reservation")
restate_mod = importlib.import_module("restate")

# Fixed booking window shared by every test: 18:00–20:00 on 2026-02-27.
_RESERVED_AT = datetime(2026, 2, 27, hour=18, minute=0)
_END_TIME = datetime(2026, 2, 27, hour=20, minute=0)


def _mock_session() -> AsyncMock:
    """Return a mock async DB session for the single-table guard path.

    ``await session.execute(...)`` resolves to a result object whose
    ``scalar_one_or_none()`` returns ``"table-1"``, standing in for the locked
    FloorTable row lookup.
    """
    session = AsyncMock()
    session.execute = AsyncMock(
        return_value=MagicMock(scalar_one_or_none=MagicMock(return_value="table-1"))
    )
    return session


# ── Unit tests ──────────────────────────────────────────────────────────────


class TestGuardTableConflict(IsolatedAsyncioTestCase):
    """Direct tests for _guard_table_conflict.

    The session is mocked and the conflict checkers in
    ``app.services.table_allocation`` are patched, so these tests pin the
    guard's own decisions:
    - it raises ``TerminalError`` when a checker reports a conflict;
    - it returns quietly (after consulting the checker) when none is reported;
    - it does not touch the session when neither a table nor a combination
      is given.
    """

    @staticmethod
    def _combo_session() -> AsyncMock:
        """Return a mock session whose ``execute()`` yields a two-table combo.

        The result's ``scalar_one_or_none()`` returns a combination mock with
        ``id == "combo-1"`` and ``table_ids == ["t-1", "t-2"]``.
        """
        combo = MagicMock()
        combo.table_ids = ["t-1", "t-2"]
        combo.id = "combo-1"

        session = AsyncMock()
        session.execute = AsyncMock(
            return_value=MagicMock(scalar_one_or_none=MagicMock(return_value=combo))
        )
        return session

    async def test_raises_on_table_conflict(self) -> None:
        session = _mock_session()
        with patch(
            "app.services.table_allocation._table_has_time_conflict",
            new_callable=AsyncMock,
            return_value=True,
        ):
            with self.assertRaises(restate_mod.TerminalError) as ctx:
                await reservation_mod._guard_table_conflict(
                    session,
                    table_id="table-1",
                    combination_id=None,
                    reserved_at=_RESERVED_AT,
                    end_time=_END_TIME,
                )
            self.assertIn("already booked", str(ctx.exception))

    async def test_passes_when_table_free(self) -> None:
        session = _mock_session()
        with patch(
            "app.services.table_allocation._table_has_time_conflict",
            new_callable=AsyncMock,
            return_value=False,
        ) as mock_check:
            await reservation_mod._guard_table_conflict(
                session,
                table_id="table-1",
                combination_id=None,
                reserved_at=_RESERVED_AT,
                end_time=_END_TIME,
            )  # should not raise
            # Guard against a vacuous pass: the guard must actually have
            # consulted the checker rather than short-circuiting.
            mock_check.assert_awaited_once()

    async def test_raises_on_combo_conflict(self) -> None:
        session = self._combo_session()
        with patch(
            "app.services.table_allocation._combo_has_time_conflict",
            new_callable=AsyncMock,
            return_value=True,
        ):
            with self.assertRaises(restate_mod.TerminalError) as ctx:
                await reservation_mod._guard_table_conflict(
                    session,
                    table_id=None,
                    combination_id="combo-1",
                    reserved_at=_RESERVED_AT,
                    end_time=_END_TIME,
                )
            self.assertIn("combination", str(ctx.exception).lower())

    async def test_passes_when_combo_free(self) -> None:
        session = self._combo_session()
        with patch(
            "app.services.table_allocation._combo_has_time_conflict",
            new_callable=AsyncMock,
            return_value=False,
        ) as mock_check:
            await reservation_mod._guard_table_conflict(
                session,
                table_id=None,
                combination_id="combo-1",
                reserved_at=_RESERVED_AT,
                end_time=_END_TIME,
            )  # should not raise
            # As above: a "free" result only means something if it was asked.
            mock_check.assert_awaited()

    async def test_noop_when_no_table_or_combo(self) -> None:
        session = AsyncMock()
        await reservation_mod._guard_table_conflict(
            session,
            table_id=None,
            combination_id=None,
            reserved_at=_RESERVED_AT,
            end_time=_END_TIME,
        )
        session.execute.assert_not_called()

    async def test_exclude_reservation_id_passed_through(self) -> None:
        """The exclude_reservation_id kwarg is forwarded to the conflict checker."""
        session = _mock_session()
        with patch(
            "app.services.table_allocation._table_has_time_conflict",
            new_callable=AsyncMock,
            return_value=False,
        ) as mock_check:
            await reservation_mod._guard_table_conflict(
                session,
                table_id="table-1",
                combination_id=None,
                reserved_at=_RESERVED_AT,
                end_time=_END_TIME,
                exclude_reservation_id="res-99",
            )
            mock_check.assert_awaited_once_with(
                session, "table-1", _RESERVED_AT, _END_TIME, "res-99"
            )


# ── Concurrency tests ──────────────────────────────────────────────────────


class TestTableConflictConcurrency(IsolatedAsyncioTestCase):
    """Simulate concurrent guard calls against a modelled row lock.

    An ``asyncio.Lock`` per table models PostgreSQL's ``FOR UPDATE`` row lock.
    The mocked ``session.execute`` acquires it, and it is released only after
    the caller has "inserted" into the shared ``booked`` list (simulating
    COMMIT).

    ``asyncio.Lock.acquire()`` on a free lock does not suspend. So an explicit
    ``await asyncio.sleep(0)`` between the guard and the insert stands in for
    the INSERT round-trip. This makes the tasks actually interleave inside the
    check-then-insert window the guard is meant to close; without it they would
    run one after the other and the lock would be untested.

    The conflict checker is patched once, around ``asyncio.gather``, rather
    than inside each task. Overlapping ``patch`` contexts in concurrent tasks
    restore the attribute out of order and can leave it patched (or
    unpatched) at the wrong time.
    """

    @staticmethod
    async def _book(
        reservation_id: str,
        table_id: str,
        row_lock: asyncio.Lock,
        booked: list[str],
        trace: list[str],
    ) -> str:
        """Run one booking attempt: guard (takes the row lock), insert, commit.

        Appends ``"checked:<id>"`` / ``"inserted:<id>"`` to *trace* as the
        attempt progresses. Returns ``"ok"`` when the guard passed and the
        reservation was added to *booked*, or ``"conflict"`` when the guard
        raised ``TerminalError``.
        """
        holds_lock = False

        async def locking_execute(_stmt: object) -> MagicMock:
            nonlocal holds_lock
            await row_lock.acquire()
            holds_lock = True
            return MagicMock(scalar_one_or_none=MagicMock(return_value=table_id))

        session = MagicMock()
        session.execute = locking_execute
        try:
            await reservation_mod._guard_table_conflict(
                session,
                table_id=table_id,
                combination_id=None,
                reserved_at=_RESERVED_AT,
                end_time=_END_TIME,
                exclude_reservation_id=reservation_id,
            )
        except restate_mod.TerminalError:
            return "conflict"
        else:
            trace.append(f"checked:{reservation_id}")
            # Yield as a real INSERT round-trip would. Without the row lock, the
            # other attempt could pass its own check right here (TOCTOU).
            await asyncio.sleep(0)
            booked.append(reservation_id)
            trace.append(f"inserted:{reservation_id}")
            return "ok"
        finally:
            # asyncio.Lock has no owner, so only release a lock this attempt
            # actually acquired; otherwise we could free the other task's lock.
            if holds_lock:
                row_lock.release()

    async def test_concurrent_same_table_only_one_wins(self) -> None:
        row_lock = asyncio.Lock()
        booked: list[str] = []
        trace: list[str] = []

        async def check_conflict(
            _session: object,
            _table_id: str,
            _reserved_at: datetime,
            _end_time: datetime,
            exclude_id: str | None = None,
        ) -> bool:
            # Every booking targets the same table and slot, so any committed
            # reservation other than our own is a conflict.
            return any(r != exclude_id for r in booked)

        with patch(
            "app.services.table_allocation._table_has_time_conflict",
            new_callable=AsyncMock,
            side_effect=check_conflict,
        ):
            results = await asyncio.gather(
                self._book("res-1", "table-1", row_lock, booked, trace),
                self._book("res-2", "table-1", row_lock, booked, trace),
            )

        self.assertEqual(sorted(results), ["conflict", "ok"])
        self.assertEqual(len(booked), 1)
        self.assertFalse(row_lock.locked())

    async def test_different_tables_both_succeed(self) -> None:
        """Non-overlapping tables should not block each other."""
        row_locks = {"t-1": asyncio.Lock(), "t-2": asyncio.Lock()}
        booked: list[str] = []
        trace: list[str] = []

        with patch(
            "app.services.table_allocation._table_has_time_conflict",
            new_callable=AsyncMock,
            return_value=False,
        ):
            results = await asyncio.gather(
                self._book("res-1", "t-1", row_locks["t-1"], booked, trace),
                self._book("res-2", "t-2", row_locks["t-2"], booked, trace),
            )

        self.assertEqual(results, ["ok", "ok"])
        self.assertEqual(len(booked), 2)
        # Both attempts passed the guard before either committed: separate row
        # locks did not serialise them.
        self.assertEqual(trace[:2], ["checked:res-1", "checked:res-2"])
        self.assertFalse(any(lock.locked() for lock in row_locks.values()))
