from __future__ import annotations

import importlib
import os
import sys
from collections.abc import Awaitable, Callable
from pathlib import Path
from typing import Any, Protocol, cast
from unittest import IsolatedAsyncioTestCase
from unittest.mock import AsyncMock, patch

# Placeholder connection string. Presumably the service modules read
# NEON_DATABASE_URL at import time (TODO confirm), so it is set before the
# imports below; a value already exported in the environment is kept.
_ = os.environ.setdefault(
    "NEON_DATABASE_URL",
    "postgresql://user:pass@localhost:5432/testdb",
)

# Put ``restate_services`` (under the repo root, taken to be ``parents[2]`` of
# this file) at the front of sys.path so its packages import as top-level
# modules; skip the insert if it is already present.
_REPO_ROOT = Path(__file__).resolve().parents[2]
_RESTATE_SERVICES_PATH = _REPO_ROOT / "restate_services"
_restate_services_entry = str(_RESTATE_SERVICES_PATH)
if _restate_services_entry not in sys.path:
    sys.path.insert(0, _restate_services_entry)

# Loaded dynamically, only after the environment and path setup above.
reservation_object_module = importlib.import_module("objects.reservation")
reservation_saga_module = importlib.import_module("reservation_saga")


class _RunCompensationsFn(Protocol):
    async def __call__(
        self,
        ctx: _CompensationContext,
        reservation_id: str,
        compensations: list[tuple[str, dict[str, Any]]],
    ) -> list[dict[str, Any]]: ...


# ``reservation_saga`` is loaded through importlib, so its attributes are
# untyped; narrow the private helper under test to the expected call shape.
_run_compensations: _RunCompensationsFn = cast(
    _RunCompensationsFn, reservation_saga_module._run_compensations
)


class _ReservationRow:
    id: str
    status: str
    restaurant_id: str

    def __init__(self, reservation_id: str, status: str) -> None:
        self.id = reservation_id
        self.status = status
        self.restaurant_id = "restaurant-test-001"


class _ExecuteResult:
    _reservation: _ReservationRow

    def __init__(self, reservation: _ReservationRow) -> None:
        self._reservation = reservation

    def scalar_one_or_none(self) -> _ReservationRow:
        return self._reservation


class _FakeSession:
    _reservation: _ReservationRow
    commit_calls: int

    def __init__(self, reservation: _ReservationRow) -> None:
        self._reservation = reservation
        self.commit_calls = 0

    async def execute(self, *_args: object, **_kwargs: object) -> _ExecuteResult:
        return _ExecuteResult(self._reservation)

    async def commit(self) -> None:
        self.commit_calls += 1


class _SessionContext:
    _session: _FakeSession

    def __init__(self, session: _FakeSession) -> None:
        self._session = session

    async def __aenter__(self) -> _FakeSession:
        return self._session

    async def __aexit__(self, *_args: object) -> bool:
        return False


class _DurableStatusContext:
    _reservation_id: str
    run_calls: list[str]

    def __init__(self, reservation_id: str) -> None:
        self._reservation_id = reservation_id
        self.run_calls = []

    def key(self) -> str:
        return self._reservation_id

    async def run(
        self,
        step_name: str,
        fn: Callable[[], Awaitable[dict[str, str]]],
    ) -> dict[str, str]:
        self.run_calls.append(step_name)
        return await fn()


class _CompensationContext:
    def __init__(self, *, fail_refund: bool = False) -> None:
        self.fail_refund = fail_refund
        self.calls: list[tuple[str, str, dict[str, Any]]] = []
        self.durable_contexts: list[_DurableStatusContext] = []

    async def object_call(
        self,
        handler: Callable[[object, dict[str, str]], Awaitable[dict[str, str]]],
        *,
        key: str,
        arg: dict[str, str],
    ) -> dict[str, str]:
        durable_ctx = _DurableStatusContext(key)
        self.durable_contexts.append(durable_ctx)
        self.calls.append(("object_call", getattr(handler, "__name__", str(handler)), dict(arg)))
        return await handler(durable_ctx, arg)

    async def service_call(
        self,
        handler: Callable[[object, dict[str, Any]], Awaitable[dict[str, Any]]],
        *,
        arg: dict[str, Any],
    ) -> dict[str, Any]:
        self.calls.append(("service_call", getattr(handler, "__name__", str(handler)), dict(arg)))
        if self.fail_refund and getattr(handler, "__name__", "") == "refund_deposit":
            raise RuntimeError("refund service unavailable")
        return {"status": "refund_requested"}

    async def run(
        self,
        step_name: str,
        fn: Callable[[], Awaitable[dict[str, str]]],
    ) -> dict[str, str]:
        self.calls.append(("run", step_name, {}))
        return await fn()


class TestReservationCompensationDurability(IsolatedAsyncioTestCase):
    """Run ``_run_compensations`` against fake saga/object contexts and a fake DB.

    Each test patches ``get_db_session`` to serve a single ``_FakeSession`` and
    replaces event publishing and in-app notification creation with
    ``AsyncMock`` objects.
    """

    @staticmethod
    def _patch_db_session(session: _FakeSession) -> Any:
        """Patch ``objects.reservation.get_db_session`` so every call yields ``session``."""
        return patch(
            "objects.reservation.get_db_session",
            side_effect=lambda *args, **kwargs: _SessionContext(session),
        )

    @staticmethod
    def _patch_publish_event() -> Any:
        """Replace ``event_publisher.publish_event`` with an ``AsyncMock``."""
        return patch("event_publisher.publish_event", new_callable=AsyncMock)

    @staticmethod
    def _patch_in_app_notification() -> Any:
        """Replace ``objects.reservation._create_in_app_notification`` with an ``AsyncMock``."""
        return patch(
            "objects.reservation._create_in_app_notification",
            new_callable=AsyncMock,
        )

    async def test_compensation_routes_to_durable_status_handler_and_persists_cancelled(
        self,
    ) -> None:
        """A lone ``cancel_reservation`` step is routed via ``object_call`` to ``cancel``."""
        res_id = "res-compensate-1"
        row = _ReservationRow(reservation_id=res_id, status="pending")
        fake_session = _FakeSession(row)
        saga_ctx = _CompensationContext()

        with (
            self._patch_db_session(fake_session),
            self._patch_publish_event(),
            self._patch_in_app_notification(),
        ):
            results = await _run_compensations(
                saga_ctx, res_id, [("cancel_reservation", {})]
            )

        # Exactly one recorded call, addressed to the object's ``cancel`` handler.
        self.assertEqual(len(saga_ctx.calls), 1)
        recorded = saga_ctx.calls[0]
        self.assertEqual(recorded, ("object_call", "cancel", {}))

        first_outcome = results[0]
        self.assertEqual(first_outcome["compensation"], "cancel_reservation")
        self.assertEqual(first_outcome["status"], "ok")

        # The new status reached the row and was committed through the fake session.
        self.assertEqual(row.status, "cancelled")
        self.assertEqual(fake_session.commit_calls, 1)
        self.assertIn("cancel_reservation", saga_ctx.durable_contexts[0].run_calls)
        self.assertEqual(recorded[1], reservation_object_module.cancel.__name__)

    async def test_failure_injection_keeps_running_remaining_compensations(self) -> None:
        """A failing refund is reported as retryable and the cancel step still runs."""
        res_id = "res-compensate-2"
        row = _ReservationRow(reservation_id=res_id, status="pending")
        fake_session = _FakeSession(row)
        saga_ctx = _CompensationContext(fail_refund=True)
        steps: list[tuple[str, dict[str, Any]]] = [
            ("cancel_reservation", {}),
            ("refund_deposit", {"deposit_cents": 900}),
        ]

        with (
            self._patch_db_session(fake_session),
            self._patch_publish_event(),
            self._patch_in_app_notification(),
        ):
            results = await _run_compensations(saga_ctx, res_id, steps)

        # Outcomes come back in reverse registration order: refund first.
        refund_outcome = results[0]
        self.assertEqual(refund_outcome["compensation"], "refund_deposit")
        self.assertEqual(refund_outcome["status"], "failed")
        self.assertEqual(refund_outcome["retryable"], True)

        cancel_outcome = results[1]
        self.assertEqual(cancel_outcome["compensation"], "cancel_reservation")
        self.assertEqual(cancel_outcome["status"], "ok")
        self.assertEqual(row.status, "cancelled")

    async def test_compensation_order_is_reverse_of_registered_steps(self) -> None:
        """Compensations execute last-registered-first, across all three call kinds."""
        res_id = "res-compensate-3"
        row = _ReservationRow(reservation_id=res_id, status="pending")
        fake_session = _FakeSession(row)
        saga_ctx = _CompensationContext()

        async def _stub_email(_reservation_id: str) -> dict[str, str]:
            return {"status": "sent"}

        steps: list[tuple[str, dict[str, Any]]] = [
            ("cancel_reservation", {}),
            ("refund_deposit", {"deposit_cents": 500}),
            ("send_cancellation_email", {}),
        ]

        with (
            self._patch_db_session(fake_session),
            patch(
                "reservation_saga._send_cancellation_email",
                side_effect=_stub_email,
            ),
            self._patch_publish_event(),
            self._patch_in_app_notification(),
        ):
            results = await _run_compensations(saga_ctx, res_id, steps)

        self.assertEqual(
            [entry["compensation"] for entry in results],
            ["send_cancellation_email", "refund_deposit", "cancel_reservation"],
        )
        # Email runs as a ``run`` step, refund as a service call, cancel as an object call.
        for position, expected_kind in enumerate(("run", "service_call", "object_call")):
            self.assertEqual(saga_ctx.calls[position][0], expected_kind)
