"""Test that concurrent reservation creation enforces max_covers.

Two concurrent creates with the same service block — only one should succeed
when total covers would exceed the block's max_covers limit.

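In this test, serialisation comes from the asyncio.Lock held on _SharedStore
(acquired by _FakeSession.execute and held until commit or session exit),
which simulates PostgreSQL's row-level locking (SELECT ... FOR UPDATE) behaviour.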
"""

from __future__ import annotations

import asyncio
import importlib
import os
import sys
from collections.abc import Awaitable, Callable
from datetime import datetime, time
from pathlib import Path
from types import SimpleNamespace
from unittest import IsolatedAsyncioTestCase
from unittest.mock import AsyncMock, patch

# Environment and import-path setup must run BEFORE any project module is
# imported: the backend/ and restate_services/ directories have to be on
# sys.path, and project modules may read NEON_DATABASE_URL at import time.
_ = os.environ.setdefault("NEON_DATABASE_URL", "postgresql://user:pass@localhost:5432/testdb")

_REPO_ROOT = Path(__file__).resolve().parents[2]
_BACKEND_PATH = _REPO_ROOT / "backend"
_RESTATE_SERVICES_PATH = _REPO_ROOT / "restate_services"
if str(_BACKEND_PATH) not in sys.path:
    sys.path.insert(0, str(_BACKEND_PATH))
if str(_RESTATE_SERVICES_PATH) not in sys.path:
    sys.path.insert(0, str(_RESTATE_SERVICES_PATH))

# Project imports deliberately follow the setup above.
from app.models.service_block import ServiceBlock  # noqa: E402
from app.services.service_blocks import ResolvedBlock  # noqa: E402

reservation_object_module = importlib.import_module("objects.reservation")
restate_module = importlib.import_module("restate")


class _ScalarsResult:
    def __init__(self, values: list[object]) -> None:
        self._values = values

    def scalars(self) -> _ScalarsResult:
        return self

    def all(self) -> list[object]:
        return self._values


class _SharedStore:
    def __init__(self) -> None:
        self.slot = SimpleNamespace(
            id="slot-1",
            start_time=time(18, 0),
            end_time=time(19, 0),
            max_covers=5,
            slot_duration_minutes=15,
        )
        self.bookings: list[object] = [SimpleNamespace(party_size=3)]
        self.slot_lock = asyncio.Lock()


class _FakeSession:
    """Simulates a DB session with locking for the booking-count query.

    The only query _insert issues through the session (after service-block
    functions are patched out) is the booking-count SELECT.  We serialise
    it behind slot_lock to model FOR UPDATE behaviour.
    """

    def __init__(self, store: _SharedStore) -> None:
        self.store = store
        self._locked = False
        self._pending: list[object] = []

    async def execute(self, _statement: object, _params: dict[str, object] | None = None) -> object:
        await self.store.slot_lock.acquire()
        self._locked = True
        return _ScalarsResult(list(self.store.bookings))

    def add(self, row: object) -> None:
        self._pending.append(row)

    async def commit(self) -> None:
        for row in self._pending:
            self.store.bookings.append(SimpleNamespace(party_size=getattr(row, "party_size", 0)))
        self._pending.clear()
        if self._locked:
            self.store.slot_lock.release()
            self._locked = False

    async def refresh(self, _row: object) -> None:
        return None


class _SessionContext:
    def __init__(self, store: _SharedStore) -> None:
        self._session = _FakeSession(store)

    async def __aenter__(self) -> _FakeSession:
        return self._session

    async def __aexit__(self, *_args: object) -> bool:
        if self._session._locked:
            self._session.store.slot_lock.release()
            self._session._locked = False
        return False


class _CreateContext:
    def __init__(self, reservation_id: str) -> None:
        self._reservation_id = reservation_id

    def key(self) -> str:
        return self._reservation_id

    async def run(self, step_name: str, fn: Callable[[], Awaitable[object]]) -> object:
        if step_name == "insert_reservation":
            return await fn()
        return None  # Skip events, notifications, allocation


# Fixed return values for the patched service-block helpers.  restaurant_id
# matches the test payload's restaurant so the fixture block belongs to the
# restaurant the reservation is being made for.
_SLOT = ServiceBlock(
    restaurant_id="resto-1",
    day_of_week=0,
    name="Dinner",
    block_type="open",
    start_time=time(18, 0),
    end_time=time(19, 0),
    max_covers=5,
    slot_interval_minutes=15,
)


class TestReservationCreateConcurrency(IsolatedAsyncioTestCase):
    """Concurrency tests for ``objects.reservation.create``."""

    # Upper bound for the concurrent attempts.  A lock-handling bug in the
    # fake session would otherwise hang the test run instead of failing it.
    _GATHER_TIMEOUT_SECONDS = 10.0

    async def test_concurrent_create_allows_at_most_one_success_when_covers_exhausted(self) -> None:
        """Two concurrent 2-cover creates; only one fits the remaining covers.

        The store is seeded with 3 booked covers against max_covers=5, so once
        one create succeeds the other must be rejected as fully booked.
        """
        payload = {
            "restaurant_id": "resto-1",
            "guest_name": "Ada",
            "guest_email": "ada@example.com",
            "guest_phone": "+3200000000",
            "party_size": 2,
            "reserved_at": "2026-02-22T18:15:00+00:00",
            "notes": None,
            "status": "pending",
            "source": "agent",
        }

        store = _SharedStore()

        async def _attempt(reservation_id: str) -> tuple[str, object]:
            """Run one create; return ("ok", result) or ("error", exception)."""
            try:
                result = await reservation_object_module.create(
                    _CreateContext(reservation_id),
                    dict(payload),
                )
                return ("ok", result)
            except Exception as exc:  # collected and asserted on below
                return ("error", exc)

        with (
            patch(
                "objects.reservation.get_db_session",
                side_effect=lambda *args, **kwargs: _SessionContext(store),
            ),
            patch(
                "objects.reservation._guard_table_conflict",
                new_callable=AsyncMock,
            ),
            patch(
                "app.services.service_blocks.resolve_service_blocks",
                new_callable=AsyncMock,
                return_value=[ResolvedBlock(block=_SLOT)],
            ),
            patch(
                "app.services.service_blocks.find_block_for_time",
                return_value=ResolvedBlock(block=_SLOT),
            ),
            patch(
                "app.services.service_blocks.snap_to_interval",
                return_value=time(18, 15),
            ),
            patch(
                "app.services.service_blocks.calculate_end_time",
                return_value=datetime(2026, 2, 22, 19, 0),
            ),
        ):
            # Bounded wait: a deadlock surfaces as a TimeoutError, not a hang.
            results = await asyncio.wait_for(
                asyncio.gather(_attempt("res-1"), _attempt("res-2")),
                timeout=self._GATHER_TIMEOUT_SECONDS,
            )

        successes = [result for state, result in results if state == "ok"]
        failures = [result for state, result in results if state == "error"]

        self.assertEqual(len(successes), 1)
        self.assertEqual(len(failures), 1)
        self.assertIsInstance(failures[0], restate_module.TerminalError)
        self.assertEqual(str(failures[0]), "Time slot is fully booked")
        self.assertEqual(sum(int(getattr(b, "party_size", 0) or 0) for b in store.bookings), 5)
