from typing import Any

import sqlalchemy as sa
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy.ext.asyncio import AsyncSession
from sqlmodel import col, select

from app.auth.tenant import get_current_restaurant, get_tenant_session
from app.models.reservation import Reservation
from app.models.restaurant import Restaurant
from app.models.service_block import ServiceBlockOverride
from app.models.table import FloorTable
from app.models.zone import Zone
from app.schemas.service_block import (
    ServiceBlockOverrideCreate,
    ServiceBlockOverrideRead,
    ServiceBlockOverrideUpdate,
)

# Router for the service-block override endpoints defined in this module;
# every route registered on it is served under the /service-block-overrides prefix.
router = APIRouter(prefix="/service-block-overrides", tags=["service-block-overrides"])


async def _validate_block_zone_ids(
    session: AsyncSession,
    restaurant_id: str,
    blocks: list[dict[str, Any]] | None,
) -> None:
    """Reject with 422 if any block references a zone the restaurant does not have.

    Collects every ID listed under ``open_zone_ids`` across all blocks and
    looks them up in a single query restricted to ``restaurant_id``.

    Args:
        session: Database session used for the zone lookup.
        restaurant_id: Restaurant whose zones are considered valid.
        blocks: Override block definitions; ``None`` or empty means nothing
            to validate.

    Raises:
        HTTPException: 422 listing (sorted) the zone IDs that were not found.
    """
    requested: set[str] = set()
    for block_def in blocks or []:
        # A missing, None or empty "open_zone_ids" contributes nothing.
        requested.update(block_def.get("open_zone_ids") or ())
    if not requested:
        return

    owned_zones_query = select(Zone.id).where(
        col(Zone.id).in_(list(requested)),
        Zone.restaurant_id == restaurant_id,
    )
    rows = await session.execute(owned_zones_query)
    known = set(rows.scalars().all())

    unknown = requested - known
    if not unknown:
        return
    raise HTTPException(
        status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
        detail=f"Invalid zone IDs in override blocks: {', '.join(sorted(unknown))}",
    )


async def _check_override_zone_deselection_guard(
    session: AsyncSession,
    restaurant_id: str,
    start_date: Any,
    end_date: Any,
    old_blocks: list[dict[str, Any]] | None,
    new_blocks: list[dict[str, Any]] | None,
) -> None:
    """Reject with 409 if a zone dropped from a block still has upcoming reservations.

    Old and new blocks are compared pairwise by position; surplus blocks on
    either side are ignored. For each pair, the zones present in the old
    ``open_zone_ids`` but absent from the new one are checked for reservations
    that are pending or confirmed, not earlier than the restaurant-local
    current time, dated within ``start_date``..``end_date`` (inclusive) and
    whose time of day falls in the old block's ``[start_time, end_time)``.

    Args:
        session: Database session used for the reservation lookup.
        restaurant_id: Restaurant whose reservations are inspected.
        start_date: First date covered by the override.
        end_date: Last date covered by the override.
        old_blocks: Block definitions currently stored on the override.
        new_blocks: Block definitions about to replace them.

    Raises:
        HTTPException: 409 with the block index and a per-zone count of
            conflicting reservations.
    """
    if not (old_blocks and new_blocks):
        return

    # Restaurant-local "now", evaluated by the database for each joined row.
    restaurant_now = sa.func.timezone(Restaurant.timezone, sa.func.now())
    reserved_day = sa.cast(Reservation.reserved_at, sa.Date)
    reserved_time = sa.cast(Reservation.reserved_at, sa.Time)

    paired_blocks = zip(old_blocks, new_blocks, strict=False)
    for index, (previous, updated) in enumerate(paired_blocks):
        zones_before = set(previous.get("open_zone_ids") or [])
        zones_after = set(updated.get("open_zone_ids") or [])
        # An empty zone list means "all zones open": with an empty old list the
        # removed set cannot be computed, and an empty new list removes nothing.
        if not zones_before or not zones_after:
            continue
        dropped = zones_before - zones_after
        if not dropped:
            continue

        window_start = previous.get("start_time")
        window_end = previous.get("end_time")
        if not (window_start and window_end):
            continue

        query = (
            select(Reservation.id, FloorTable.zone_id)
            .join(FloorTable, Reservation.table_id == FloorTable.id)  # type: ignore[arg-type]
            .join(Restaurant, Reservation.restaurant_id == Restaurant.id)  # type: ignore[arg-type]
            .where(
                Reservation.restaurant_id == restaurant_id,
                Reservation.status.in_(["pending", "confirmed"]),  # type: ignore[attr-defined]
                Reservation.reserved_at >= restaurant_now,
                col(FloorTable.zone_id).in_(list(dropped)),
                reserved_day >= start_date,
                reserved_day <= end_date,
                reserved_time >= window_start,
                reserved_time < window_end,
            )
        )
        clashes = (await session.execute(query)).all()
        if not clashes:
            continue

        # Tally conflicting reservations per removed zone for the error payload.
        per_zone: dict[str, int] = {}
        for _reservation_id, zone in clashes:
            per_zone.setdefault(zone, 0)
            per_zone[zone] += 1
        raise HTTPException(
            status_code=status.HTTP_409_CONFLICT,
            detail={
                "message": f"Cannot remove zones from block {index} with future reservations",
                "conflicting_zones": per_zone,
            },
        )


def _validate_override_intervals(blocks: list[dict[str, Any]] | None) -> None:
    """Reject with 422 if any override block has an invalid slot_interval_minutes."""
    if not blocks:
        return
    for block_def in blocks:
        interval = block_def.get("slot_interval_minutes")
        if interval is not None and (interval <= 0 or interval % 15 != 0):
            raise HTTPException(
                status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
                detail="slot_interval_minutes in override blocks must be a positive multiple of 15",
            )


async def _check_date_overlap(
    session: AsyncSession,
    restaurant_id: str,
    start_date: Any,
    end_date: Any,
    exclude_id: str | None = None,
) -> None:
    """Reject with 409 if a date range overlaps an existing active override.

    Ranges are treated as inclusive: an existing override conflicts when it
    starts on or before ``end_date`` and ends on or after ``start_date``.

    Args:
        session: Database session used for the lookup.
        restaurant_id: Restaurant whose overrides are checked.
        start_date: First date of the candidate range.
        end_date: Last date of the candidate range.
        exclude_id: Override ID to leave out of the check (the one being
            updated), or ``None`` to consider every override.

    Raises:
        HTTPException: 409 when at least one active override overlaps.
    """
    query = select(ServiceBlockOverride.id).where(
        ServiceBlockOverride.restaurant_id == restaurant_id,
        ServiceBlockOverride.is_active == True,  # noqa: E712
        ServiceBlockOverride.start_date <= end_date,
        ServiceBlockOverride.end_date >= start_date,
    )
    if exclude_id is not None:
        query = query.where(ServiceBlockOverride.id != exclude_id)
    # Only existence matters. The candidate range may overlap several existing
    # overrides at once, so fetch at most one row instead of using
    # scalar_one_or_none(), which raises MultipleResultsFound in that case.
    result = await session.execute(query.limit(1))
    if result.first() is not None:
        raise HTTPException(
            status_code=status.HTTP_409_CONFLICT,
            detail="Date range overlaps an existing active override",
        )


@router.get("/", response_model=list[ServiceBlockOverrideRead])
async def list_overrides(
    session: AsyncSession = Depends(get_tenant_session),
    restaurant: Restaurant = Depends(get_current_restaurant),
) -> list[Any]:
    """Return all overrides of the current restaurant, ordered by start date."""
    statement = select(ServiceBlockOverride).where(
        ServiceBlockOverride.restaurant_id == restaurant.id
    )
    statement = statement.order_by(ServiceBlockOverride.start_date)  # type: ignore[arg-type]
    rows = await session.execute(statement)
    overrides = rows.scalars().all()
    return list(overrides)


@router.post("/", response_model=ServiceBlockOverrideRead, status_code=status.HTTP_201_CREATED)
async def create_override(
    payload: ServiceBlockOverrideCreate,
    session: AsyncSession = Depends(get_tenant_session),
    restaurant: Restaurant = Depends(get_current_restaurant),
) -> Any:
    """Create an override for the current restaurant after validating it.

    Validation runs in this order: slot intervals, date-range overlap with
    other active overrides, then existence of the referenced zone IDs. Each
    check raises its own HTTPException on failure.
    """
    owner_id = restaurant.id
    _validate_override_intervals(payload.blocks)
    await _check_date_overlap(session, owner_id, payload.start_date, payload.end_date)
    await _validate_block_zone_ids(session, owner_id, payload.blocks)

    fields = payload.model_dump()
    new_override = ServiceBlockOverride(**fields, restaurant_id=owner_id)
    session.add(new_override)
    await session.commit()
    # Reload so the response carries the values as stored after the commit.
    await session.refresh(new_override)
    return new_override


@router.put("/{override_id}", response_model=ServiceBlockOverrideRead)
async def update_override(
    override_id: str,
    payload: ServiceBlockOverrideUpdate,
    session: AsyncSession = Depends(get_tenant_session),
    restaurant: Restaurant = Depends(get_current_restaurant),
) -> Any:
    """Replace the fields of an existing override of the current restaurant.

    Responds 404 when the override does not exist for this restaurant. The
    payload is then validated (slot intervals, date overlap excluding this
    override, zone-removal guard against the stored dates and blocks, zone ID
    existence) before every dumped payload field is copied onto the row.
    """
    lookup = select(ServiceBlockOverride).where(
        ServiceBlockOverride.id == override_id,
        ServiceBlockOverride.restaurant_id == restaurant.id,
    )
    existing = (await session.execute(lookup)).scalar_one_or_none()
    if existing is None:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Override not found")

    # Capture the stored blocks before anything is overwritten below.
    previous_blocks = existing.blocks
    incoming_blocks = payload.blocks

    _validate_override_intervals(incoming_blocks)
    await _check_date_overlap(
        session,
        restaurant.id,
        payload.start_date,
        payload.end_date,
        exclude_id=override_id,
    )
    await _check_override_zone_deselection_guard(
        session,
        restaurant.id,
        existing.start_date,
        existing.end_date,
        previous_blocks,
        incoming_blocks,
    )
    await _validate_block_zone_ids(session, restaurant.id, incoming_blocks)

    for field_name, new_value in payload.model_dump().items():
        setattr(existing, field_name, new_value)

    session.add(existing)
    await session.commit()
    await session.refresh(existing)
    return existing


@router.delete("/{override_id}", status_code=status.HTTP_200_OK)
async def delete_override(
    override_id: str,
    session: AsyncSession = Depends(get_tenant_session),
    restaurant: Restaurant = Depends(get_current_restaurant),
) -> dict[str, str]:
    """Delete an override of the current restaurant.

    Responds 404 when no override with ``override_id`` exists for this
    restaurant; otherwise deletes it and returns ``{"status": "deleted"}``.
    """
    lookup = select(ServiceBlockOverride).where(
        ServiceBlockOverride.id == override_id,
        ServiceBlockOverride.restaurant_id == restaurant.id,
    )
    target = (await session.execute(lookup)).scalar_one_or_none()
    if target is None:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Override not found")

    await session.delete(target)
    await session.commit()
    return {"status": "deleted"}
