"""Tests for ResolvedBlock, zone assignment resolution, and service block helpers."""

import os
import sys
from datetime import date, datetime, time
from pathlib import Path
from unittest import IsolatedAsyncioTestCase
from unittest.mock import AsyncMock, MagicMock, patch

from fastapi import HTTPException, status

# Supply a placeholder database URL unless the environment already defines one.
# NOTE(review): presumably the ``app.*`` modules imported below read this
# variable at import time, which is why it is set before them — confirm.
os.environ.setdefault("NEON_DATABASE_URL", "postgresql://user:pass@localhost:5432/testdb")

# Put the directory two levels up from this file (parent of this file's folder)
# at the front of sys.path if it is not there yet.
# NOTE(review): presumably this is what makes the ``app`` package importable
# when the tests are run from another working directory — confirm.
_BACKEND_PATH = Path(__file__).resolve().parents[1]
if str(_BACKEND_PATH) not in sys.path:
    sys.path.insert(0, str(_BACKEND_PATH))

from app.models.service_block import (  # noqa: E402
    ServiceBlock,
    ServiceBlockOverride,
    ServiceBlockZone,
)
from app.routers.service_block_overrides import (  # noqa: E402
    _check_override_zone_deselection_guard,
)
from app.routers.service_blocks import (  # noqa: E402
    _check_zone_deselection_guard,
    _validate_zone_ids,
    create_service_block,
    update_service_block,
)
from app.schemas.service_block import ServiceBlockCreate, ServiceBlockUpdate  # noqa: E402
from app.services.service_blocks import (  # noqa: E402
    ResolvedBlock,
    calculate_end_time,
    find_block_for_time,
    generate_available_slots,
    resolve_blocks_in_memory,
    snap_to_interval,
)


def _make_block(
    *,
    id: str = "block-1",
    restaurant_id: str = "r1",
    day_of_week: int = 0,
    name: str = "Lunch",
    block_type: str = "open",
    start_time: time = time(12, 0),
    end_time: time = time(14, 0),
    max_covers: int | None = 40,
    default_duration_minutes: int | None = 90,
    slot_interval_minutes: int = 30,
    is_active: bool = True,
    display_order: int = 0,
) -> ServiceBlock:
    """Create a ServiceBlock test fixture.

    Every argument is keyword-only and has a default, so a test only spells
    out the fields it cares about. The values are forwarded to the
    ``ServiceBlock`` constructor unchanged.
    """
    # Collect the field values first, then hand them over in a single call.
    field_values = {
        "id": id,
        "restaurant_id": restaurant_id,
        "day_of_week": day_of_week,
        "name": name,
        "block_type": block_type,
        "start_time": start_time,
        "end_time": end_time,
        "max_covers": max_covers,
        "default_duration_minutes": default_duration_minutes,
        "slot_interval_minutes": slot_interval_minutes,
        "is_active": is_active,
        "display_order": display_order,
    }
    return ServiceBlock(**field_values)


# ── ResolvedBlock dataclass ──────────────────────────────────────────────


class TestResolvedBlock(IsolatedAsyncioTestCase):
    """Checks on the ``open_zone_ids`` attribute of ResolvedBlock."""

    def test_resolved_block_none_zones(self):
        """Constructing without ``open_zone_ids`` leaves the attribute as None."""
        resolved = ResolvedBlock(block=_make_block())
        self.assertIsNone(resolved.open_zone_ids)

    def test_resolved_block_explicit_zones(self):
        """Zone IDs passed to the constructor are readable back unchanged."""
        resolved = ResolvedBlock(block=_make_block(), open_zone_ids=["z1", "z2"])
        self.assertEqual(resolved.open_zone_ids, ["z1", "z2"])


# ── find_block_for_time ──────────────────────────────────────────────────


class TestFindBlockForTime(IsolatedAsyncioTestCase):
    """Exercise ``find_block_for_time`` against lists of ResolvedBlock objects."""

    def test_returns_resolved_block(self):
        """A time inside the only open block returns that block with its zones."""
        lunch = ResolvedBlock(
            block=_make_block(start_time=time(12, 0), end_time=time(14, 0)),
            open_zone_ids=["z1"],
        )

        match = find_block_for_time([lunch], time(12, 30))

        self.assertIsNotNone(match)
        assert match is not None  # lets static type checkers treat match as non-None
        self.assertIsInstance(match, ResolvedBlock)
        self.assertEqual(match.open_zone_ids, ["z1"])

    def test_returns_none_when_no_match(self):
        """A time before the only block's window yields None."""
        lunch = ResolvedBlock(
            block=_make_block(start_time=time(12, 0), end_time=time(14, 0)),
        )

        self.assertIsNone(find_block_for_time([lunch], time(10, 0)))

    def test_skips_closed_blocks(self):
        """A block of type "closed" is not returned even when the time falls inside it."""
        closed_block = _make_block(
            id="closed-1",
            block_type="closed",
            start_time=time(12, 0),
            end_time=time(14, 0),
        )
        candidates = [ResolvedBlock(block=closed_block, open_zone_ids=["z1"])]

        self.assertIsNone(find_block_for_time(candidates, time(12, 30)))

    def test_returns_correct_block_among_multiple(self):
        """With a morning and a lunch block, 12:30 selects the lunch block."""
        morning = ResolvedBlock(
            block=_make_block(id="morning", start_time=time(9, 0), end_time=time(11, 0)),
            open_zone_ids=["z-morning"],
        )
        lunch = ResolvedBlock(
            block=_make_block(id="lunch", start_time=time(12, 0), end_time=time(14, 0)),
            open_zone_ids=["z-lunch"],
        )

        match = find_block_for_time([morning, lunch], time(12, 30))

        self.assertIsNotNone(match)
        assert match is not None  # lets static type checkers treat match as non-None
        self.assertEqual(match.block.id, "lunch")
        self.assertEqual(match.open_zone_ids, ["z-lunch"])


# ── generate_available_slots ─────────────────────────────────────────────


class TestGenerateAvailableSlots(IsolatedAsyncioTestCase):
    """Slot generation for ResolvedBlock wrappers and plain ServiceBlock objects."""

    # Settings shared by the open-block cases: a 12:00-14:00 window with a
    # 30-minute slot interval and a 90-minute default duration.
    _LUNCH_SETTINGS = {
        "start_time": time(12, 0),
        "end_time": time(14, 0),
        "slot_interval_minutes": 30,
        "default_duration_minutes": 90,
    }

    def test_with_resolved_block(self):
        """Only start times whose 90-minute sitting ends by 14:00 are expected."""
        resolved = ResolvedBlock(block=_make_block(**self._LUNCH_SETTINGS))

        # Expected reasoning:
        #   12:00 -> ends 13:30, fits
        #   12:30 -> ends 14:00, fits exactly
        #   13:00 -> ends 14:30, past the block end, so generation stops
        self.assertEqual(
            generate_available_slots(resolved),
            [time(12, 0), time(12, 30)],
        )

    def test_with_plain_service_block(self):
        """An unwrapped ServiceBlock is accepted and gives the same slots."""
        plain = _make_block(**self._LUNCH_SETTINGS)

        self.assertEqual(
            generate_available_slots(plain),
            [time(12, 0), time(12, 30)],
        )

    def test_closed_block_returns_empty(self):
        """A block of type "closed" produces no slots."""
        closed = ResolvedBlock(block=_make_block(block_type="closed"))

        self.assertEqual(generate_available_slots(closed), [])


# ── calculate_end_time ───────────────────────────────────────────────────


class TestCalculateEndTime(IsolatedAsyncioTestCase):
    """End-time calculation from a start datetime and an optional block."""

    # Every case starts at noon on the same day.
    _NOON = datetime(2026, 3, 15, 12, 0)

    def test_with_resolved_block(self):
        """A 120-minute default duration moves the end to 14:00."""
        two_hour = ResolvedBlock(block=_make_block(default_duration_minutes=120))

        self.assertEqual(
            calculate_end_time(self._NOON, two_hour),
            datetime(2026, 3, 15, 14, 0),
        )

    def test_with_none(self):
        """With no block at all the expected end is 90 minutes after the start."""
        self.assertEqual(
            calculate_end_time(self._NOON, None),
            datetime(2026, 3, 15, 13, 30),
        )

    def test_with_block_missing_duration(self):
        """A block whose default duration is None also yields the 90-minute end."""
        no_duration = ResolvedBlock(block=_make_block(default_duration_minutes=None))

        self.assertEqual(
            calculate_end_time(self._NOON, no_duration),
            datetime(2026, 3, 15, 13, 30),
        )


# ── snap_to_interval ────────────────────────────────────────────────────


class TestSnapToInterval(IsolatedAsyncioTestCase):
    """Snapping a requested time onto a block's slot grid."""

    @staticmethod
    def _lunch() -> ResolvedBlock:
        """Return a resolved 12:00-14:00 block with 30-minute slots and 90-minute duration."""
        return ResolvedBlock(
            block=_make_block(
                start_time=time(12, 0),
                end_time=time(14, 0),
                slot_interval_minutes=30,
                default_duration_minutes=90,
            ),
        )

    def test_with_resolved_block(self):
        """12:10 is expected to snap back to 12:00."""
        snapped = snap_to_interval(time(12, 10), self._lunch())
        self.assertEqual(snapped, time(12, 0))

    def test_snaps_to_nearest(self):
        """12:20 is expected to snap forward to 12:30."""
        snapped = snap_to_interval(time(12, 20), self._lunch())
        self.assertEqual(snapped, time(12, 30))

    def test_closed_block_returns_none(self):
        """A block of type "closed" gives None."""
        closed = ResolvedBlock(block=_make_block(block_type="closed"))
        self.assertIsNone(snap_to_interval(time(12, 0), closed))


# ── resolve_blocks_in_memory ─────────────────────────────────────────────


class TestResolveBlocksInMemory(IsolatedAsyncioTestCase):
    """Resolution of day-of-week blocks and date overrides into ResolvedBlock lists."""

    # 2026-03-09 is a Monday, i.e. day_of_week 0 in the fixtures below.
    _MONDAY = date(2026, 3, 9)

    @staticmethod
    def _single_block_override(
        override_id: str,
        name: str,
        first_day: date,
        last_day: date,
        block_spec: dict,
    ) -> ServiceBlockOverride:
        """Build an active override for restaurant "r1" holding one block dict."""
        return ServiceBlockOverride(
            id=override_id,
            restaurant_id="r1",
            start_date=first_day,
            end_date=last_day,
            is_active=True,
            name=name,
            blocks=[block_spec],
        )

    def test_normal_blocks_with_zone_assignments(self):
        """Each regular block is paired with the zone list stored under its id."""
        first = _make_block(id="b1", day_of_week=0)
        second = _make_block(id="b2", day_of_week=0)

        resolved = resolve_blocks_in_memory(
            target_date=self._MONDAY,
            blocks_by_dow={0: [first, second]},
            active_overrides=[],
            restaurant_id="r1",
            zone_assignments={"b1": ["z1", "z2"], "b2": ["z3"]},
        )

        self.assertEqual(len(resolved), 2)
        head, tail = resolved[0], resolved[1]
        self.assertIsInstance(head, ResolvedBlock)
        self.assertEqual(head.block.id, "b1")
        self.assertEqual(head.open_zone_ids, ["z1", "z2"])
        self.assertEqual(tail.block.id, "b2")
        self.assertEqual(tail.open_zone_ids, ["z3"])

    def test_normal_blocks_without_zone_assignments(self):
        """Passing ``zone_assignments=None`` leaves ``open_zone_ids`` as None."""
        only = _make_block(id="b1", day_of_week=0)

        resolved = resolve_blocks_in_memory(
            target_date=self._MONDAY,
            blocks_by_dow={0: [only]},
            active_overrides=[],
            restaurant_id="r1",
            zone_assignments=None,
        )

        self.assertEqual(len(resolved), 1)
        self.assertIsNone(resolved[0].open_zone_ids)

    def test_override_blocks_with_open_zone_ids(self):
        """An override block's ``open_zone_ids`` and parsed fields are carried over."""
        holiday = self._single_block_override(
            "ov-1",
            "Holiday",
            date(2026, 3, 15),
            date(2026, 3, 20),
            {
                "name": "Special",
                "block_type": "open",
                "start_time": "18:00",
                "end_time": "22:00",
                "max_covers": 30,
                "open_zone_ids": ["z1", "z2"],
            },
        )

        resolved = resolve_blocks_in_memory(
            target_date=date(2026, 3, 16),
            blocks_by_dow={},
            active_overrides=[holiday],
            restaurant_id="r1",
        )

        self.assertEqual(len(resolved), 1)
        entry = resolved[0]
        self.assertIsInstance(entry, ResolvedBlock)
        self.assertEqual(entry.open_zone_ids, ["z1", "z2"])
        self.assertEqual(entry.block.name, "Special")
        self.assertEqual(entry.block.start_time, time(18, 0))

    def test_override_blocks_without_open_zone_ids(self):
        """An override block dict lacking ``open_zone_ids`` resolves to None."""
        weekend = self._single_block_override(
            "ov-2",
            "Weekend",
            date(2026, 3, 15),
            date(2026, 3, 20),
            {
                "name": "Brunch",
                "block_type": "open",
                "start_time": "10:00",
                "end_time": "14:00",
                "max_covers": 20,
            },
        )

        resolved = resolve_blocks_in_memory(
            target_date=date(2026, 3, 16),
            blocks_by_dow={},
            active_overrides=[weekend],
            restaurant_id="r1",
        )

        self.assertEqual(len(resolved), 1)
        self.assertIsNone(resolved[0].open_zone_ids)

    def test_override_blocks_with_empty_open_zone_ids(self):
        """An empty ``open_zone_ids`` list on an override block resolves to None."""
        midweek = self._single_block_override(
            "ov-3",
            "Midweek",
            date(2026, 3, 15),
            date(2026, 3, 20),
            {
                "name": "Dinner",
                "block_type": "open",
                "start_time": "18:00",
                "end_time": "22:00",
                "max_covers": 50,
                "open_zone_ids": [],
            },
        )

        resolved = resolve_blocks_in_memory(
            target_date=date(2026, 3, 16),
            blocks_by_dow={},
            active_overrides=[midweek],
            restaurant_id="r1",
        )

        self.assertEqual(len(resolved), 1)
        # An empty list is expected to mean "all zones open", reported as None.
        self.assertIsNone(resolved[0].open_zone_ids)

    def test_override_takes_precedence_over_normal(self):
        """An override covering the date replaces that day's regular blocks."""
        regular = _make_block(id="normal-1", day_of_week=0)
        special_monday = self._single_block_override(
            "ov-4",
            "Special Monday",
            self._MONDAY,
            self._MONDAY,
            {
                "name": "Override Block",
                "block_type": "open",
                "start_time": "11:00",
                "end_time": "15:00",
                "max_covers": 25,
                "open_zone_ids": ["z-special"],
            },
        )

        resolved = resolve_blocks_in_memory(
            target_date=self._MONDAY,
            blocks_by_dow={0: [regular]},
            active_overrides=[special_monday],
            restaurant_id="r1",
            zone_assignments={"normal-1": ["z1"]},
        )

        self.assertEqual(len(resolved), 1)
        self.assertEqual(resolved[0].block.name, "Override Block")
        self.assertEqual(resolved[0].open_zone_ids, ["z-special"])

    def test_normal_block_with_empty_zone_list(self):
        """An empty zone list assigned to a regular block resolves to None (all zones)."""
        only = _make_block(id="b1", day_of_week=0)

        resolved = resolve_blocks_in_memory(
            target_date=self._MONDAY,
            blocks_by_dow={0: [only]},
            active_overrides=[],
            restaurant_id="r1",
            zone_assignments={"b1": []},
        )

        self.assertEqual(len(resolved), 1)
        self.assertIsNone(resolved[0].open_zone_ids)


def _mock_scalars_result(values: list[str]) -> MagicMock:
    result = MagicMock()
    scalars = MagicMock()
    scalars.all.return_value = values
    result.scalars.return_value = scalars
    return result


def _mock_rows_result(rows: list[tuple[str, str]]) -> MagicMock:
    result = MagicMock()
    result.all.return_value = rows
    return result


def _mock_scalar_one_or_none_result(value: object) -> MagicMock:
    result = MagicMock()
    result.scalar_one_or_none.return_value = value
    return result


class TestValidateZoneIds(IsolatedAsyncioTestCase):
    """Behaviour of ``_validate_zone_ids`` against a mocked async session."""

    async def test_all_zone_ids_found_passes(self):
        """No error is raised when the query returns every requested zone ID."""
        db = AsyncMock()
        db.execute.return_value = _mock_scalars_result(["z1", "z2"])

        await _validate_zone_ids(db, "rest-1", ["z1", "z2"])

        db.execute.assert_awaited_once()

    async def test_missing_zone_ids_raise_422(self):
        """Zone IDs absent from the query result are reported in a 422 error."""
        db = AsyncMock()
        db.execute.return_value = _mock_scalars_result(["z1"])

        with self.assertRaises(HTTPException) as caught:
            await _validate_zone_ids(db, "rest-1", ["z1", "z2", "z3"])

        error = caught.exception
        self.assertEqual(error.status_code, status.HTTP_422_UNPROCESSABLE_ENTITY)
        self.assertEqual(error.detail, "Invalid zone IDs: z2, z3")

    async def test_empty_zone_ids_short_circuit(self):
        """An empty ID list returns without touching the session."""
        db = AsyncMock()

        await _validate_zone_ids(db, "rest-1", [])

        db.execute.assert_not_called()


class TestZoneDeselectionGuard(IsolatedAsyncioTestCase):
    """Behaviour of ``_check_zone_deselection_guard`` against a mocked session."""

    @staticmethod
    async def _run_guard(db: AsyncMock, removed: set[str]) -> None:
        """Invoke the guard for restaurant "rest-1", day 0, 12:00-14:00."""
        await _check_zone_deselection_guard(
            db,
            "rest-1",
            0,
            time(12, 0),
            time(14, 0),
            removed,
        )

    async def test_no_conflicts_passes(self):
        """An empty query result lets the guard return after one execute."""
        db = AsyncMock()
        db.execute.return_value = _mock_rows_result([])

        await self._run_guard(db, {"z1"})

        db.execute.assert_awaited_once()

    async def test_conflicts_raise_409_with_zone_counts(self):
        """Returned rows are tallied per zone and reported in a 409 error."""
        db = AsyncMock()
        db.execute.return_value = _mock_rows_result(
            [("res-1", "z1"), ("res-2", "z1"), ("res-3", "z2")]
        )

        with self.assertRaises(HTTPException) as caught:
            await self._run_guard(db, {"z1", "z2"})

        error = caught.exception
        self.assertEqual(error.status_code, status.HTTP_409_CONFLICT)
        expected_detail = {
            "message": "Cannot remove zones with future reservations",
            "conflicting_zones": {"z1": 2, "z2": 1},
        }
        self.assertEqual(error.detail, expected_detail)

    async def test_empty_removed_zone_ids_short_circuit(self):
        """An empty set of removed zones returns without touching the session."""
        db = AsyncMock()

        await self._run_guard(db, set())

        db.execute.assert_not_called()


class TestOverrideZoneDeselectionGuard(IsolatedAsyncioTestCase):
    """Behaviour of ``_check_override_zone_deselection_guard`` with a mocked session."""

    @staticmethod
    def _dinner_block(zone_ids: list[str]) -> dict:
        """Return an 18:00-22:00 block dict carrying the given ``open_zone_ids``."""
        return {
            "start_time": time(18, 0),
            "end_time": time(22, 0),
            "open_zone_ids": zone_ids,
        }

    async def _run_guard(self, db: AsyncMock) -> None:
        """Invoke the guard for 2026-03-15..20 with two one-block lists.

        The first list opens zones z1 and z2, the second only z2.
        NOTE(review): presumably these are the previous and the new block
        lists respectively, so z1 is the zone being removed — confirm.
        """
        await _check_override_zone_deselection_guard(
            db,
            "rest-1",
            date(2026, 3, 15),
            date(2026, 3, 20),
            [self._dinner_block(["z1", "z2"])],
            [self._dinner_block(["z2"])],
        )

    async def test_no_conflicts_passes(self):
        """An empty query result lets the guard return after one execute."""
        db = AsyncMock()
        db.execute.return_value = _mock_rows_result([])

        await self._run_guard(db)

        db.execute.assert_awaited_once()

    async def test_conflicts_raise_409_with_zone_counts(self):
        """Returned rows are tallied per zone and reported in a 409 error."""
        db = AsyncMock()
        db.execute.return_value = _mock_rows_result([("res-1", "z1"), ("res-2", "z1")])

        with self.assertRaises(HTTPException) as caught:
            await self._run_guard(db)

        error = caught.exception
        self.assertEqual(error.status_code, status.HTTP_409_CONFLICT)
        expected_detail = {
            "message": "Cannot remove zones from block 0 with future reservations",
            "conflicting_zones": {"z1": 2},
        }
        self.assertEqual(error.detail, expected_detail)


class TestServiceBlockZoneCascade(IsolatedAsyncioTestCase):
    """Schema checks on the ServiceBlockZone join table."""

    def test_table_and_foreign_keys_use_cascade_delete(self):
        """Both columns are primary-key foreign keys declared with ON DELETE CASCADE."""
        join_table = ServiceBlockZone.__table__  # type: ignore[attr-defined]

        for table_name in (ServiceBlockZone.__tablename__, join_table.name):
            self.assertEqual(table_name, "service_block_zones")

        # Column name -> fully qualified column its foreign key must point at.
        expected_targets = (
            ("service_block_id", "service_block.id"),
            ("zone_id", "zone.id"),
        )
        for column_name, target in expected_targets:
            foreign_key = next(iter(getattr(join_table.c, column_name).foreign_keys))
            self.assertEqual(foreign_key.target_fullname, target)
            self.assertEqual(foreign_key.ondelete, "CASCADE")

        for column_name, _ in expected_targets:
            self.assertTrue(getattr(join_table.c, column_name).primary_key)


class TestServiceBlockZoneCRUD(IsolatedAsyncioTestCase):
    """Create/update router handlers and the zone join rows they add."""

    # Payload fields that are identical in the create and update scenarios.
    _SHARED_FIELDS = {
        "day_of_week": 0,
        "block_type": "open",
        "start_time": time(12, 0),
        "end_time": time(14, 0),
        "default_duration_minutes": 90,
        "is_active": True,
        "display_order": 0,
        "slot_interval_minutes": 30,
    }

    @staticmethod
    def _build_session(execute: AsyncMock) -> MagicMock:
        """Return a session double: awaitable execute/commit/refresh, plain ``add``."""
        db = MagicMock()
        db.execute = execute
        db.commit = AsyncMock()
        db.refresh = AsyncMock()
        db.add = MagicMock()
        return db

    async def test_create_with_zone_ids_adds_join_rows(self):
        """Creating with two zone IDs adds the block followed by one join row per zone."""
        db = self._build_session(AsyncMock(return_value=MagicMock()))
        owner = MagicMock(id="rest-1")
        body = ServiceBlockCreate(
            name="Lunch",
            max_covers=40,
            open_zone_ids=["z1", "z2"],
            **self._SHARED_FIELDS,
        )

        overlap_patch = patch("app.routers.service_blocks._check_overlap", AsyncMock())
        validate_patch = patch("app.routers.service_blocks._validate_zone_ids", AsyncMock())
        with overlap_patch as mock_overlap, validate_patch as mock_validate:
            response = await create_service_block(
                payload=body,
                session=db,
                restaurant=owner,
            )

        mock_overlap.assert_awaited_once_with(
            db,
            "rest-1",
            body.day_of_week,
            body.start_time,
            body.end_time,
        )
        mock_validate.assert_awaited_once_with(db, "rest-1", ["z1", "z2"])
        self.assertEqual(db.commit.await_count, 2)
        db.refresh.assert_awaited_once()

        # First ``add`` is the block itself; the remainder are the join rows.
        created, *join_rows = [call.args[0] for call in db.add.call_args_list]
        self.assertIsInstance(created, ServiceBlock)
        self.assertEqual(created.restaurant_id, "rest-1")
        self.assertEqual(
            [(row.service_block_id, row.zone_id) for row in join_rows],
            [(created.id, "z1"), (created.id, "z2")],
        )
        self.assertEqual(response["open_zone_ids"], ["z1", "z2"])

    async def test_update_removing_zones_checks_guard_and_resyncs_rows(self):
        """Dropping z1 calls the guard with {"z1"}; join rows are deleted and re-added."""
        stored = _make_block(id="block-1", restaurant_id="rest-1")
        # The first execute returns the stored block, the second is the delete.
        db = self._build_session(
            AsyncMock(side_effect=[_mock_scalar_one_or_none_result(stored), MagicMock()])
        )
        owner = MagicMock(id="rest-1")
        body = ServiceBlockUpdate(
            name="Updated Lunch",
            max_covers=50,
            open_zone_ids=["z2", "z3"],
            **self._SHARED_FIELDS,
        )

        overlap_patch = patch("app.routers.service_blocks._check_overlap", AsyncMock())
        current_zones_patch = patch(
            "app.routers.service_blocks._get_zone_ids_for_block",
            AsyncMock(return_value=["z1", "z2"]),
        )
        guard_patch = patch(
            "app.routers.service_blocks._check_zone_deselection_guard",
            AsyncMock(),
        )
        validate_patch = patch("app.routers.service_blocks._validate_zone_ids", AsyncMock())
        with (
            overlap_patch as mock_overlap,
            current_zones_patch,
            guard_patch as mock_guard,
            validate_patch as mock_validate,
        ):
            response = await update_service_block(
                block_id="block-1",
                payload=body,
                session=db,
                restaurant=owner,
            )

        mock_overlap.assert_awaited_once_with(
            db,
            "rest-1",
            body.day_of_week,
            body.start_time,
            body.end_time,
            exclude_id="block-1",
        )
        mock_guard.assert_awaited_once_with(
            db,
            "rest-1",
            body.day_of_week,
            body.start_time,
            body.end_time,
            {"z1"},
        )

        validate_args = mock_validate.await_args.args  # type: ignore[reportAttributeAccessIssue]
        self.assertEqual(validate_args[0], db)
        self.assertEqual(validate_args[1], "rest-1")
        # Only the membership of the validated IDs is checked, not their order.
        self.assertCountEqual(validate_args[2], ["z2", "z3"])

        self.assertEqual(db.execute.await_count, 2)
        delete_stmt = db.execute.await_args_list[1].args[0]
        self.assertEqual(delete_stmt.table.name, ServiceBlockZone.__tablename__)
        self.assertEqual(db.commit.await_count, 1)
        db.refresh.assert_awaited_once_with(stored)

        # First ``add`` is the existing block; the remainder are the new join rows.
        updated, *join_rows = [call.args[0] for call in db.add.call_args_list]
        self.assertIs(updated, stored)
        self.assertEqual(
            [(row.service_block_id, row.zone_id) for row in join_rows],
            [("block-1", "z2"), ("block-1", "z3")],
        )
        self.assertEqual(stored.name, "Updated Lunch")
        self.assertEqual(stored.max_covers, 50)
        self.assertEqual(response["open_zone_ids"], ["z2", "z3"])
