"""Table allocation service — find the best-fit table or combination for a
reservation based on party size, capacity, and time-slot availability.

Ordering heuristics:
- Individual tables are sorted by (combo_membership_count ASC, capacity ASC)
  so tables NOT in any combination are preferred — this preserves combinations
  for larger parties that actually need them.
- Combinations are sorted by (overlap_score ASC, combined_capacity ASC) where
  overlap_score is, summed over the combination's constituent tables, the
  number of *other* combinations each table also belongs to (so a combination
  sharing two tables with the candidate counts twice).  Picking the combo with
  the lowest score leaves the most alternative combinations available for
  future reservations.
"""

from __future__ import annotations

from dataclasses import dataclass
from datetime import datetime

from sqlalchemy import cast, func
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.ext.asyncio import AsyncSession
from sqlmodel import select

from app.models.chair import Chair
from app.models.reservation import Reservation
from app.models.table import FloorTable
from app.models.table_combination import TableCombination


@dataclass
class TableAllocationResult:
    """Outcome of a successful allocation returned by ``find_best_table``.

    Exactly one of ``table_id`` / ``combination_id`` is populated: an
    individual table leaves ``combination_id`` as ``None``; a combination
    leaves ``table_id`` as ``None``.
    """

    # Allocated individual table, or None when a combination was chosen.
    table_id: str | None
    # Allocated combination, or None when an individual table was chosen.
    combination_id: str | None
    # Seats provided: enabled-chair count of the table, or the combination's
    # combined_capacity.
    capacity: int


# ── Candidate queries ─────────────────────────────────────────────────────────


async def _get_table_candidates(
    session: AsyncSession,
    restaurant_id: str,
    party_size: int,
) -> list[tuple[str, int]]:
    """List the restaurant's tables that can seat *party_size* on their own.

    A table's capacity is its number of enabled chairs (tables with no
    enabled chairs count as 0 via the outer join).  Results are
    ``(table_id, capacity)`` pairs ordered first by how many combinations the
    table belongs to, then by capacity, both ascending — so stand-alone
    tables are offered before ones that would break up a combination.
    """
    # Per-row correlated count of combinations whose table_ids JSONB array
    # contains this FloorTable's id.
    membership_count = (
        select(func.count())
        .where(
            TableCombination.restaurant_id == restaurant_id,
            TableCombination.table_ids.op("@>")(  # type: ignore[attr-defined]
                func.jsonb_build_array(FloorTable.id)
            ),
        )
        .correlate(FloorTable)
        .scalar_subquery()
        .label("combo_membership")
    )

    chair_count = func.count(Chair.id).label("capacity")
    # Only enabled chairs contribute; the condition lives in the ON clause so
    # tables without any enabled chair still appear (with a count of 0).
    enabled_chair_join = (Chair.table_id == FloorTable.id) & (Chair.enabled == True)  # noqa: E712

    query = (
        select(FloorTable.id, chair_count)
        .outerjoin(Chair, enabled_chair_join)
        .where(FloorTable.restaurant_id == restaurant_id)
        .group_by(FloorTable.id)
        .having(chair_count >= party_size)
        .order_by(membership_count.asc(), chair_count.asc())
    )
    rows = (await session.execute(query)).all()
    return [(table_id, capacity) for table_id, capacity in rows]


async def _get_combo_candidates(
    session: AsyncSession,
    restaurant_id: str,
    party_size: int,
) -> list[TableCombination]:
    """List combinations that can seat *party_size*, best candidates first.

    Ranking is delegated to ``_rank_combo_candidates``: lowest overlap score
    (summed per constituent table over the other combinations containing it)
    first, then smallest ``combined_capacity``.
    """
    # The overlap score depends on every combination in the restaurant, not
    # only the large-enough ones, so load them all and filter in Python.
    query = select(TableCombination).where(
        TableCombination.restaurant_id == restaurant_id,
    )
    everything = list((await session.execute(query)).scalars().all())

    big_enough = [combo for combo in everything if combo.combined_capacity >= party_size]
    return _rank_combo_candidates(big_enough, everything)


def _rank_combo_candidates(
    candidates: list[TableCombination],
    all_combos: list[TableCombination],
) -> list[TableCombination]:
    """Return *candidates* sorted by ``(overlap_score, combined_capacity)``.

    ``overlap_score`` for a combination is the sum, over each of its
    ``table_ids``, of how many combinations in *all_combos* other than itself
    contain that table.  Lower scores mean picking the combination blocks
    fewer alternatives.  The sort is stable, so full ties keep input order.
    """
    # Index: table id -> ids of every combination that uses the table.
    members_by_table: dict[str, set[str]] = {}
    for other in all_combos:
        for table_id in other.table_ids:
            if table_id not in members_by_table:
                members_by_table[table_id] = set()
            members_by_table[table_id].add(other.id)

    def overlap_score(combo: TableCombination) -> int:
        own = {combo.id}
        total = 0
        for table_id in combo.table_ids:
            total += len(members_by_table.get(table_id, set()) - own)
        return total

    return sorted(candidates, key=lambda combo: (overlap_score(combo), combo.combined_capacity))


# ── Conflict checks ──────────────────────────────────────────────────────────


async def _table_has_time_conflict(
    session: AsyncSession,
    table_id: str,
    reserved_at: datetime,
    end_time: datetime,
    exclude_reservation_id: str | None = None,
) -> bool:
    """Report whether *table_id* is already taken during the requested window.

    The table is taken when an active (pending/confirmed/seated) reservation
    overlaps the window — ``existing.reserved_at < end_time`` and
    ``existing.end_time > reserved_at``, so back-to-back bookings do not
    clash — and that reservation either names the table directly via
    ``Reservation.table_id`` or names a combination whose ``table_ids``
    contain it.  When *exclude_reservation_id* is truthy that reservation is
    ignored (e.g. so a reservation does not conflict with itself).

    Issues at most two queries; the combination query is skipped when a
    direct conflict is found.
    """
    active_statuses = ["pending", "confirmed", "seated"]

    def _first_overlap(*criteria):
        # Existence check only, hence LIMIT 1.
        query = (
            select(Reservation.id)
            .where(
                *criteria,
                Reservation.status.in_(active_statuses),  # type: ignore[attr-defined]
                Reservation.reserved_at < end_time,
                Reservation.end_time > reserved_at,
            )
            .limit(1)
        )
        if exclude_reservation_id:
            query = query.where(Reservation.id != exclude_reservation_id)
        return query

    direct = await session.execute(_first_overlap(Reservation.table_id == table_id))
    if direct.scalar_one_or_none() is not None:
        return True

    # NOTE(review): this lookup is not scoped by restaurant — presumably table
    # ids are globally unique; confirm.
    combos_with_table = (
        select(TableCombination.id)
        .where(
            TableCombination.table_ids.op("@>")(cast([table_id], JSONB)),  # type: ignore[attr-defined]
        )
        .scalar_subquery()
    )
    via_combo = await session.execute(
        _first_overlap(Reservation.combination_id.in_(combos_with_table))  # type: ignore[union-attr]
    )
    return via_combo.scalar_one_or_none() is not None


async def _combo_has_time_conflict(
    session: AsyncSession,
    combo: TableCombination,
    reserved_at: datetime,
    end_time: datetime,
    exclude_reservation_id: str | None = None,
) -> bool:
    """Return ``True`` if any table in *combo* is occupied during the window.

    Each constituent table is checked with ``_table_has_time_conflict``. That
    function already covers both direct reservations on the table and
    reservations made through *any* combination containing the table. This
    includes *combo* itself and every other combination sharing a table with
    it. Stops at the first conflicting table. A combination with no tables
    never conflicts.
    """
    # A separate pass over "combinations sharing a table" would be redundant:
    # if every per-table check below comes back clean, no reservation exists
    # on any combination containing any of these tables.
    for tid in combo.table_ids:
        if await _table_has_time_conflict(
            session, tid, reserved_at, end_time, exclude_reservation_id
        ):
            return True
    return False


# ── Public API ────────────────────────────────────────────────────────────────


async def find_best_table(
    session: AsyncSession,
    restaurant_id: str,
    party_size: int,
    reserved_at: datetime,
    end_time: datetime,
    exclude_reservation_id: str | None = None,
) -> TableAllocationResult | None:
    """Allocate the first free table for the party, falling back to a combination.

    Individual tables are tried first, in ``_get_table_candidates`` order
    (fewest combination memberships, then smallest capacity).  Only when all
    of them conflict are combinations tried, in ``_get_combo_candidates``
    order (lowest overlap score, then smallest capacity).  The first
    candidate without an overlapping active reservation wins.

    Returns ``None`` when nothing is available.
    """
    table_rows = await _get_table_candidates(session, restaurant_id, party_size)
    for candidate_id, seats in table_rows:
        busy = await _table_has_time_conflict(
            session, candidate_id, reserved_at, end_time, exclude_reservation_id
        )
        if busy:
            continue
        return TableAllocationResult(table_id=candidate_id, combination_id=None, capacity=seats)

    # Combinations are only fetched once every individual table has failed.
    ranked = await _get_combo_candidates(session, restaurant_id, party_size)
    for candidate in ranked:
        busy = await _combo_has_time_conflict(
            session, candidate, reserved_at, end_time, exclude_reservation_id
        )
        if busy:
            continue
        return TableAllocationResult(
            table_id=None,
            combination_id=candidate.id,
            capacity=candidate.combined_capacity,
        )

    return None


async def has_any_available_table(
    session: AsyncSession,
    restaurant_id: str,
    party_size: int,
    reserved_at: datetime,
    end_time: datetime,
    exclude_reservation_id: str | None = None,
) -> bool:
    """Return ``True`` if at least one table or combination can seat the party
    during the requested time window.

    Thin wrapper over ``find_best_table``, so both functions share one
    candidate order and one conflict definition.  The search (and its
    queries) stops at the first free candidate.

    *exclude_reservation_id* (optional, default ``None``) is forwarded to
    ``find_best_table``.  A truthy value makes the conflict checks ignore that
    reservation, e.g. when checking whether an existing booking could move.
    """
    allocation = await find_best_table(
        session,
        restaurant_id,
        party_size,
        reserved_at,
        end_time,
        exclude_reservation_id,
    )
    return allocation is not None


# ── Bulk availability (pre-fetched, in-memory) ────────────────────────────────


@dataclass
class BulkAvailabilityData:
    """Everything needed to check many slots in memory, fetched up front.

    Produced by ``preload_slot_availability`` so candidate slots can be
    evaluated without issuing queries per slot.

    ``occupied`` holds intervals only for active statuses
    (pending/confirmed/seated), the same statuses ``_table_has_time_conflict``
    treats as blocking.  It accounts for both direct table reservations and
    reservations made through a combination.

    ``all_reservations`` keeps every non-cancelled reservation in the
    preloaded window, whatever its status.  Per the original design it is
    used for block-level cover counts.
    """

    # (table_id, capacity) pairs, ordered by (combo_membership ASC, capacity ASC).
    table_candidates: list[tuple[str, int]]
    # Combinations with enough capacity, sorted by overlap score then capacity.
    ranked_combos: list[TableCombination]
    # Every combination belonging to the restaurant, unfiltered.
    all_combos: list[TableCombination]
    # table_id -> (start, end) intervals of active reservations on that table.
    occupied: dict[str, list[tuple[datetime, datetime]]]
    # All non-cancelled reservations overlapping the preloaded window.
    all_reservations: list[Reservation]


async def preload_slot_availability(
    session: AsyncSession,
    restaurant_id: str,
    party_size: int,
    window_start: datetime,
    window_end: datetime,
) -> BulkAvailabilityData:
    """Load, in three queries, everything bulk slot checks need.

    Use this instead of calling ``has_any_available_table`` once per candidate
    slot.  That approach re-fetches tables, combinations and reservations
    every time.  Table candidates and combination ranking follow the same
    rules as the per-slot path.
    """
    active_statuses = {"pending", "confirmed", "seated"}

    # Query 1: individual table candidates (same ordering as the per-slot path).
    table_candidates = await _get_table_candidates(session, restaurant_id, party_size)

    # Query 2: every combination of the restaurant. This is needed both for
    # overlap ranking and for expanding combination reservations into tables.
    combos_stmt = select(TableCombination).where(
        TableCombination.restaurant_id == restaurant_id,
    )
    all_combos = list((await session.execute(combos_stmt)).scalars().all())
    combos_by_id: dict[str, TableCombination] = {combo.id: combo for combo in all_combos}
    ranked_combos = _rank_combo_candidates(
        [combo for combo in all_combos if combo.combined_capacity >= party_size],
        all_combos,
    )

    # Query 3: every non-cancelled reservation overlapping the window.
    reservations_stmt = select(Reservation).where(
        Reservation.restaurant_id == restaurant_id,
        Reservation.status != "cancelled",
        Reservation.reserved_at < window_end,
        Reservation.end_time > window_start,
    )
    all_reservations = list((await session.execute(reservations_stmt)).scalars().all())

    # Expand active reservations into per-table busy intervals. A direct table
    # reservation blocks that table; a combination reservation blocks each of
    # its tables (combinations not in this restaurant's set are skipped).
    occupied: dict[str, list[tuple[datetime, datetime]]] = {}
    for reservation in all_reservations:
        if reservation.status not in active_statuses:
            continue
        blocked_tables: list[str] = []
        if reservation.table_id:
            blocked_tables.append(reservation.table_id)
        if reservation.combination_id:
            combo = combos_by_id.get(reservation.combination_id)
            if combo:
                blocked_tables.extend(combo.table_ids)
        if not blocked_tables:
            continue
        interval = (reservation.reserved_at, reservation.end_time)
        for blocked in blocked_tables:
            occupied.setdefault(blocked, []).append(interval)

    return BulkAvailabilityData(
        table_candidates=table_candidates,
        ranked_combos=ranked_combos,
        all_combos=all_combos,
        occupied=occupied,
        all_reservations=all_reservations,
    )


def table_free_in_window(
    table_id: str,
    slot_start: datetime,
    slot_end: datetime,
    occupied: dict[str, list[tuple[datetime, datetime]]],
) -> bool:
    """Return whether *table_id* has no busy interval overlapping the slot.

    Works purely on the ``occupied`` map built by
    ``preload_slot_availability``, mirroring ``_table_has_time_conflict``'s
    overlap rule: an interval clashes when it starts before the slot ends and
    ends after the slot starts (touching endpoints do not clash).  Tables
    absent from the map are free.
    """
    intervals = occupied.get(table_id, [])
    return not any(start < slot_end and end > slot_start for start, end in intervals)


def slot_has_availability(
    party_size: int,
    slot_start: datetime,
    slot_end: datetime,
    data: BulkAvailabilityData,
) -> bool:
    """In-memory equivalent of ``has_any_available_table(...)``.

    Checks individual tables first (preferred), then combinations. A
    combination is available only when every one of its tables is free.

    Candidates are re-filtered against *party_size*, so a party is never
    reported as seatable at a table or combination that is too small for it.
    This is a no-op when *data* was preloaded for the same party size. Data
    preloaded for a *larger* party cannot recover smaller tables excluded at
    preload time, so preload with a party size no larger than this one.
    """
    for table_id, capacity in data.table_candidates:
        if capacity < party_size:
            continue
        if table_free_in_window(table_id, slot_start, slot_end, data.occupied):
            return True
    for combo in data.ranked_combos:
        if combo.combined_capacity < party_size:
            continue
        if all(
            table_free_in_window(tid, slot_start, slot_end, data.occupied)
            for tid in combo.table_ids
        ):
            return True
    return False
