"""Table combination business logic — chair config generation, joining-side
suggestions, capacity computation, and availability helpers."""

from __future__ import annotations

import math
from datetime import datetime

from sqlalchemy import cast, func
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.ext.asyncio import AsyncSession
from sqlmodel import select

from app.models.chair import ChairSide
from app.models.reservation import Reservation
from app.models.table import FloorTable
from app.models.table_combination import CombinedChairConfig, TableCombination

# ── Chair config generation ──────────────────────────────────────────────────


def generate_combined_chair_configs(
    combo: TableCombination,
    tables: list[FloorTable],
    suggested_disables: dict[str, list[int]] | None = None,
) -> list[CombinedChairConfig]:
    """Build one ``CombinedChairConfig`` per chair slot of every table in *combo*.

    Chair slots are derived from each table's shape and ``slot_count`` via
    ``_chair_definitions_for_table``.  Rows are emitted table by table, in the
    order of *tables*, and within a table in slot-definition order.

    Args:
        combo: The combination the rows belong to; its ``id`` and
            ``restaurant_id`` are copied onto every row.
        tables: The member tables whose chairs should be configured.
        suggested_disables: Optional ``{table_id: [slot_index, ...]}`` map.
            Listed slots are created with ``enabled=False``; every other slot
            is created enabled.

    Returns:
        The new (unsaved) config rows.
    """
    overrides = {} if suggested_disables is None else suggested_disables

    rows: list[CombinedChairConfig] = []
    for member in tables:
        off_slots = set(overrides.get(member.id, []))
        # Slot layout mirrors the tables service's chair generation.
        rows.extend(
            CombinedChairConfig(
                combination_id=combo.id,
                table_id=member.id,
                slot_index=slot,
                side=chair_side,
                enabled=slot not in off_slots,
                restaurant_id=combo.restaurant_id,
            )
            for slot, chair_side in _chair_definitions_for_table(member)
        )
    return rows


def _chair_definitions_for_table(table: FloorTable) -> list[tuple[int, str]]:
    """List the ``(slot_index, side)`` pairs for *table*'s chairs.

    Layout by shape (intended to match ``app.services.tables.generate_chairs``):

    - SQUARE: slots cycle TOP, RIGHT, BOTTOM, LEFT.
    - RECTANGLE: the first ``ceil(slot_count / 2)`` slots are TOP, the rest
      BOTTOM.
    - ROUND, or any shape not listed above: every slot is AROUND.
    """
    from app.models.table import TableShape

    seat_count = table.slot_count

    if table.shape == TableShape.SQUARE:
        rotation = (ChairSide.TOP, ChairSide.RIGHT, ChairSide.BOTTOM, ChairSide.LEFT)
        return [(slot, rotation[slot % 4]) for slot in range(seat_count)]

    if table.shape == TableShape.RECTANGLE:
        # The odd chair, if any, goes on the top edge.
        bottom_seats = seat_count // 2
        top_seats = seat_count - bottom_seats
        return [
            (slot, ChairSide.TOP if slot < top_seats else ChairSide.BOTTOM)
            for slot in range(top_seats + bottom_seats)
        ]

    # ROUND tables and unrecognised shapes both seat everyone "around".
    return [(slot, ChairSide.AROUND) for slot in range(seat_count)]


# ── Joining-side suggestion ──────────────────────────────────────────────────


def suggest_joining_side_disables(
    tables: list[FloorTable],
) -> dict[str, list[int]]:
    """Suggest which chairs to disable where tables in a combination meet.

    Every unordered pair of tables is examined.  The vector between their
    centres picks the facing sides: horizontal when ``|dx| >= |dy|``,
    otherwise vertical.  Non-round tables contribute every slot on the
    facing side.  Round tables contribute the slot(s) angularly closest to
    the other table's centre.

    NOTE(review): pairs are not filtered for physical adjacency, so in
    non-linear layouts (e.g. an L shape) two diagonal tables can still
    suggest disabling a side that does not actually touch another table.

    Returns:
        ``{table_id: sorted unique slot indexes}`` to pre-disable.  Empty when
        fewer than two tables are given.
    """
    from app.models.table import TableShape

    if len(tables) < 2:
        return {}

    def _centre(t: FloorTable) -> tuple[float, float]:
        return t.x + t.width / 2, t.y + t.height / 2

    def _facing_sides(dx: float, dy: float) -> tuple[str, str]:
        # Returns (side of the first table, side of the second table).
        if abs(dx) >= abs(dy):
            if dx > 0:
                return ChairSide.RIGHT, ChairSide.LEFT
            return ChairSide.LEFT, ChairSide.RIGHT
        if dy > 0:
            return ChairSide.BOTTOM, ChairSide.TOP
        return ChairSide.TOP, ChairSide.BOTTOM

    def _slots_facing(t: FloorTable, side: str, toward_x: float, toward_y: float) -> list[int]:
        if t.shape == TableShape.ROUND:
            return _round_table_joining_slots(t, toward_x, toward_y)
        return [slot for slot, s in _chair_definitions_for_table(t) if s == side]

    collected: dict[str, list[int]] = {}
    for pos, first in enumerate(tables):
        first_cx, first_cy = _centre(first)
        for second in tables[pos + 1 :]:
            second_cx, second_cy = _centre(second)
            first_side, second_side = _facing_sides(
                second_cx - first_cx, second_cy - first_cy
            )
            collected.setdefault(first.id, []).extend(
                _slots_facing(first, first_side, second_cx, second_cy)
            )
            collected.setdefault(second.id, []).extend(
                _slots_facing(second, second_side, first_cx, first_cy)
            )

    # A table joined to several neighbours may list a slot more than once.
    return {tid: sorted(set(slots)) for tid, slots in collected.items()}


def _round_table_joining_slots(
    table: FloorTable,
    target_x: float,
    target_y: float,
) -> list[int]:
    """Pick the slot(s) of a round *table* that face the point (target_x, target_y).

    Slot ``i`` is assumed to sit at angle ``2π·i / slot_count`` (slot 0 at
    angle 0, i.e. toward +x).  Slots are ranked by their absolute angular
    distance to the direction from the table's centre to the target, with
    ties going to the lower index.  One slot is returned, or two when the
    table has six or more slots.  A table with zero slots yields ``[]``.
    """
    centre_x = table.x + table.width / 2
    centre_y = table.y + table.height / 2
    heading = math.atan2(target_y - centre_y, target_x - centre_x)

    slot_total = table.slot_count
    if slot_total == 0:
        return []

    def _angular_gap(slot: int) -> float:
        return abs(_normalize_angle(2 * math.pi * slot / slot_total - heading))

    # sorted() is stable, so equal gaps keep ascending slot order.
    ranked = sorted(range(slot_total), key=_angular_gap)
    how_many = 2 if slot_total >= 6 else 1
    return ranked[:how_many]


def _normalize_angle(angle: float) -> float:
    """Normalize angle to [-π, π]."""
    while angle > math.pi:
        angle -= 2 * math.pi
    while angle < -math.pi:
        angle += 2 * math.pi
    return angle


# ── Capacity computation ─────────────────────────────────────────────────────


async def compute_combined_capacity(session: AsyncSession, combo_id: str) -> int:
    """Return the seat capacity of combination *combo_id*.

    Capacity is the number of ``CombinedChairConfig`` rows for the combination
    whose ``enabled`` flag is true (a single ``COUNT(*)`` query).
    """
    enabled_chairs = (
        select(func.count())
        .select_from(CombinedChairConfig)
        .where(
            CombinedChairConfig.combination_id == combo_id,
            CombinedChairConfig.enabled == True,  # noqa: E712
        )
    )
    counted = await session.execute(enabled_chairs)
    return counted.scalar_one()


# ── Combination membership check ─────────────────────────────────────────────


async def is_in_combination(session: AsyncSession, table_id: str) -> bool:
    """Report whether any ``TableCombination`` lists *table_id* as a member.

    Uses the PostgreSQL JSONB containment operator (``@>``) against the
    combination's ``table_ids`` array and stops at the first match.
    """
    contains_table = TableCombination.table_ids.op("@>")(cast([table_id], JSONB))  # type: ignore[attr-defined]
    first_match = await session.execute(
        select(TableCombination.id).where(contains_table).limit(1)
    )
    return first_match.scalar_one_or_none() is not None


# ── Combo availability ───────────────────────────────────────────────────────


async def check_combo_availability(
    session: AsyncSession,
    combo_id: str,
    reserved_at: datetime,
    end_time: datetime,
) -> bool:
    """Decide whether combination *combo_id* can be booked for a time window.

    The combination is available only if every member table is free between
    *reserved_at* and *end_time*.  A table counts as busy if it has an
    overlapping reservation, either on its own or through any combination
    that includes it.  Checking stops at the first busy table.

    Returns ``False`` when the combination does not exist.

    NOTE(review): a combination whose ``table_ids`` is empty is reported as
    available — confirm this is intended.
    """
    lookup = await session.execute(
        select(TableCombination).where(TableCombination.id == combo_id)
    )
    combo = lookup.scalar_one_or_none()
    if combo is None:
        return False

    for member_id in combo.table_ids:
        if await _table_has_conflict(session, member_id, reserved_at, end_time):
            return False
    return True


async def _table_has_conflict(
    session: AsyncSession,
    table_id: str,
    reserved_at: datetime,
    end_time: datetime,
) -> bool:
    """Check if a single table has conflicting reservations in the time window.

    A conflict is any reservation whose status is ``pending``, ``confirmed``
    or ``seated`` and that overlaps the half-open window
    [reserved_at, end_time), i.e. it starts before *end_time* and ends after
    *reserved_at*.  Two kinds of reservation are considered:

    - Direct table reservations (``Reservation.table_id`` matches)
    - Combo reservations whose combination's ``table_ids`` contains the table

    Issues at most two queries: the direct check first (cheap, exits early),
    then a single combo check that resolves the containing combinations in a
    subquery.
    """
    # Statuses for which the table is considered occupied.
    active_statuses = ["pending", "confirmed", "seated"]

    # Direct table conflict
    direct = await session.execute(
        select(Reservation.id)
        .where(
            Reservation.table_id == table_id,
            Reservation.status.in_(active_statuses),  # type: ignore[attr-defined]
            Reservation.reserved_at < end_time,
            Reservation.end_time > reserved_at,
        )
        .limit(1)
    )
    if direct.scalar_one_or_none() is not None:
        return True

    # Combo-based conflict: any active reservation on a combination that
    # includes this table.  Resolving the combinations as a subquery avoids a
    # round trip that would pull their IDs into Python only to send them back
    # as an IN list; when no combination contains the table, the subquery is
    # empty and nothing matches.
    combos_with_table = select(TableCombination.id).where(
        TableCombination.table_ids.op("@>")(cast([table_id], JSONB)),  # type: ignore[attr-defined]
    )
    combo_conflict = await session.execute(
        select(Reservation.id)
        .where(
            Reservation.combination_id.in_(combos_with_table),  # type: ignore[union-attr]
            Reservation.status.in_(active_statuses),  # type: ignore[attr-defined]
            Reservation.reserved_at < end_time,
            Reservation.end_time > reserved_at,
        )
        .limit(1)
    )
    return combo_conflict.scalar_one_or_none() is not None
