"""Centralized service block resolution — single source of truth for all availability paths.

All four availability check paths (public endpoint, agent tool, reservation saga,
Restate handler) MUST use these helpers instead of querying ServiceBlock directly.
"""

from dataclasses import dataclass
from datetime import date, datetime, time, timedelta

from sqlalchemy.ext.asyncio import AsyncSession
from sqlmodel import col, select

from app.models.service_block import ServiceBlock, ServiceBlockOverride, ServiceBlockZone

# Reservation length in minutes assumed when a block has no (or a zero)
# ``default_duration_minutes``, or when no block applies at all.
_DEFAULT_DURATION_MINUTES: int = 90


@dataclass
class ResolvedBlock:
    """A service block paired with the zones that are open during it.

    Produced by the resolution helpers in this module, both for stored
    ServiceBlock rows and for blocks synthesised from an override.
    """

    # The underlying block: a persisted row or an in-memory override block.
    block: ServiceBlock
    # Zone IDs open for this block. ``None`` means there is no explicit
    # assignment, i.e. every zone is open.
    open_zone_ids: list[str] | None = None


def _synthesise_override_blocks(
    override: ServiceBlockOverride,
    restaurant_id: str,
    target_date: date,
) -> list[ResolvedBlock]:
    """Turn an override's JSON ``blocks`` array into in-memory ServiceBlock objects.

    Each entry gets a synthetic id ``override-<override.id>-<index>`` and uses
    its list index as ``display_order``. Absent keys fall back to the
    override's ``name``, ``block_type="open"`` and a 30-minute slot interval.
    An absent or empty ``open_zone_ids`` becomes ``None`` (all zones open).
    The objects are only constructed here; they are never added to a session.

    Raises:
        KeyError: if an entry lacks ``start_time`` or ``end_time``.
        ValueError: if either time string is rejected by ``time.fromisoformat``.
    """
    synthesised: list[ResolvedBlock] = []
    for idx, block_def in enumerate(override.blocks):
        sb = ServiceBlock(
            id=f"override-{override.id}-{idx}",
            restaurant_id=restaurant_id,
            day_of_week=target_date.weekday(),
            name=block_def.get("name", override.name),
            block_type=block_def.get("block_type", "open"),
            start_time=time.fromisoformat(block_def["start_time"]),
            end_time=time.fromisoformat(block_def["end_time"]),
            max_covers=block_def.get("max_covers"),
            default_duration_minutes=block_def.get("default_duration_minutes"),
            is_active=True,
            display_order=idx,
            slot_interval_minutes=block_def.get("slot_interval_minutes", 30),
        )
        synthesised.append(
            ResolvedBlock(block=sb, open_zone_ids=block_def.get("open_zone_ids") or None)
        )
    return synthesised


async def resolve_service_blocks(
    session: AsyncSession,
    restaurant_id: str,
    target_date: date,
) -> list[ResolvedBlock]:
    """Return the applicable service blocks for a restaurant on a given date.

    Resolution order:
    1. Look for active ServiceBlockOverride rows whose inclusive
       ``[start_date, end_date]`` range covers ``target_date``. The first one
       that defines at least one block wins. Its JSON ``blocks`` array is
       turned into ServiceBlock objects, and normal day-of-week blocks are
       ignored entirely. When several overrides overlap, the one with the
       latest ``start_date`` is preferred (ties broken by ``id``), so the
       choice is deterministic.
    2. Otherwise (no covering override, or only ones with empty ``blocks``),
       return the active ServiceBlock rows for ``target_date.weekday()``,
       ordered by ``display_order`` then ``start_time``. Each is paired with
       its ServiceBlockZone assignments.

    Args:
        session: Async database session used for all queries.
        restaurant_id: Restaurant whose blocks are resolved.
        target_date: Date being resolved.

    Returns:
        A list of ResolvedBlock. ``open_zone_ids`` is ``None`` when a block
        has no explicit zone assignment.

    Raises:
        KeyError, ValueError: propagated from malformed override JSON.
    """
    # Step 1: active overrides covering the target date.
    # scalar_one_or_none() would raise MultipleResultsFound when overrides
    # overlap, so fetch every candidate in a deterministic order instead.
    # A later start usually means a narrower override nested inside a broader
    # one. NOTE(review): confirm this precedence rule with product.
    override_result = await session.execute(
        select(ServiceBlockOverride)
        .where(
            ServiceBlockOverride.restaurant_id == restaurant_id,
            ServiceBlockOverride.is_active == True,  # noqa: E712
            ServiceBlockOverride.start_date <= target_date,
            ServiceBlockOverride.end_date >= target_date,
        )
        .order_by(
            col(ServiceBlockOverride.start_date).desc(),
            col(ServiceBlockOverride.id),
        )
    )
    # Same first-match-with-blocks rule as resolve_blocks_in_memory. An
    # override with an empty blocks list does not suppress normal blocks.
    for override in override_result.scalars().all():
        if override.blocks:
            return _synthesise_override_blocks(override, restaurant_id, target_date)

    # Step 2: fall back to normal day-of-week blocks.
    day_of_week = target_date.weekday()
    result = await session.execute(
        select(ServiceBlock)
        .where(
            ServiceBlock.restaurant_id == restaurant_id,
            ServiceBlock.day_of_week == day_of_week,
            ServiceBlock.is_active == True,  # noqa: E712
        )
        .order_by(ServiceBlock.display_order, ServiceBlock.start_time)  # type: ignore[arg-type]
    )
    blocks = list(result.scalars().all())
    if not blocks:
        return []

    # Fetch zone assignments for all blocks in one query (avoids N+1).
    zone_result = await session.execute(
        select(ServiceBlockZone).where(
            col(ServiceBlockZone.service_block_id).in_([b.id for b in blocks])
        )
    )
    zone_map: dict[str, list[str]] = {}
    for row in zone_result.scalars().all():
        zone_map.setdefault(row.service_block_id, []).append(row.zone_id)

    return [
        ResolvedBlock(block=block, open_zone_ids=zone_map.get(block.id) or None) for block in blocks
    ]


def find_block_for_time(
    blocks: list[ResolvedBlock],
    target_time: time,
) -> ResolvedBlock | None:
    """Return the first open block whose window covers *target_time*.

    Blocks are scanned in list order, and any block whose ``block_type`` is
    not ``"open"`` is skipped. A block matches when
    ``start_time <= target_time < end_time``, so the end time itself is
    excluded. Returns ``None`` when no open block covers the time.

    NOTE: a block whose ``end_time`` is not after its ``start_time`` (e.g. one
    that runs past or ends at midnight, 00:00) can never match.
    """
    open_blocks = (candidate for candidate in blocks if candidate.block.block_type == "open")
    return next(
        (
            candidate
            for candidate in open_blocks
            if candidate.block.start_time <= target_time < candidate.block.end_time
        ),
        None,
    )


def calculate_end_time(
    reserved_at: datetime,
    resolved: ResolvedBlock | None,
) -> datetime:
    """Compute when a reservation starting at *reserved_at* ends.

    The length is the block's ``default_duration_minutes``. When *resolved* is
    ``None``, or that value is unset or zero, ``_DEFAULT_DURATION_MINUTES``
    (90) is used instead.
    """
    configured = resolved.block.default_duration_minutes if resolved is not None else None
    duration = configured or _DEFAULT_DURATION_MINUTES
    return reserved_at + timedelta(minutes=duration)


def generate_available_slots(resolved: "ResolvedBlock | ServiceBlock") -> list[time]:
    """Generate all valid reservation start times within an open block.

    Candidates start at ``block.start_time`` and advance by
    ``block.slot_interval_minutes``. A candidate is kept only while
    ``candidate + default_duration_minutes <= block.end_time``, so every
    returned slot leaves room for a full-length reservation before the block
    ends. Returns an empty list for non-``"open"`` blocks.

    Missing, zero or negative values fall back to a 30-minute interval and a
    ``_DEFAULT_DURATION_MINUTES`` duration respectively. Times are handled at
    minute resolution (seconds are ignored). A block ending at 00:00 yields
    no slots.

    Args:
        resolved: Either a ResolvedBlock or a bare ServiceBlock.

    Returns:
        Slot start times in ascending order.
    """
    block = resolved.block if isinstance(resolved, ResolvedBlock) else resolved

    if block.block_type != "open":
        return []

    # Non-positive values (possible via unvalidated override JSON) are treated
    # like missing ones. A negative interval would step away from the block
    # end and never terminate the loop below. A negative duration would emit
    # slots past the block end, and possibly time(hour=24), which raises.
    interval = block.slot_interval_minutes
    if interval is None or interval <= 0:
        interval = 30
    duration = block.default_duration_minutes
    if duration is None or duration <= 0:
        duration = _DEFAULT_DURATION_MINUTES

    # Convert times to minutes-since-midnight for arithmetic.
    start_mins = block.start_time.hour * 60 + block.start_time.minute
    end_mins = block.end_time.hour * 60 + block.end_time.minute

    slots: list[time] = []
    current = start_mins
    # Keep a candidate only if a full-length reservation fits before the end.
    while current + duration <= end_mins:
        slots.append(time(hour=current // 60, minute=current % 60))
        current += interval

    return slots


def snap_to_interval(
    requested_time: time,
    block: "ResolvedBlock | ServiceBlock",
    *,
    max_distance_minutes: int | None = None,
) -> time | None:
    """Return the valid slot in *block* closest to *requested_time*.

    Candidate slots come from :func:`generate_available_slots`. Distance is
    the absolute difference in whole minutes (seconds are ignored). On a tie,
    the earlier slot wins.

    Args:
        requested_time: The time the guest asked for.
        block: The block to snap within (ResolvedBlock or ServiceBlock).
        max_distance_minutes: Optional tolerance, keyword-only. When given,
            ``None`` is returned if the nearest slot is more than this many
            minutes away. When omitted (the default), the nearest slot is
            returned however far away it is.

    Returns:
        The nearest slot. Returns ``None`` when the block has no valid slots
        (closed, or too short for a full reservation), or when the nearest
        slot exceeds *max_distance_minutes*.
    """
    valid_slots = generate_available_slots(block)
    if not valid_slots:
        return None

    req_mins = requested_time.hour * 60 + requested_time.minute

    def distance(slot: time) -> int:
        return abs(req_mins - (slot.hour * 60 + slot.minute))

    # min() returns the first minimal element. Slots are ascending, so ties
    # resolve to the earlier slot.
    best = min(valid_slots, key=distance)
    if max_distance_minutes is not None and distance(best) > max_distance_minutes:
        return None
    return best


def resolve_blocks_in_memory(
    target_date: date,
    blocks_by_dow: dict[int, list[ServiceBlock]],
    active_overrides: list["ServiceBlockOverride"],
    restaurant_id: str,
    zone_assignments: dict[str, list[str]] | None = None,
) -> list[ResolvedBlock]:
    """Resolve blocks for *target_date* from pre-fetched data, issuing no queries.

    This is the in-memory counterpart of :func:`resolve_service_blocks`.
    Callers covering a whole date range can load blocks, overrides and zone
    assignments once and avoid N+1 queries.

    Args:
        target_date: Date being resolved.
        blocks_by_dow: Normal blocks keyed by ``date.weekday()`` value.
        active_overrides: Overrides to consider, in precedence order. Only
            their date range and ``blocks`` are checked here (no
            ``is_active`` or restaurant filtering). Presumably the caller
            passes only active overrides for *restaurant_id*; verify at the
            call sites.
        restaurant_id: Copied onto blocks synthesised from an override.
        zone_assignments: Optional mapping of block id to open zone IDs for
            normal blocks. Missing or empty entries become ``None`` (all
            zones open).

    Returns:
        The synthesised blocks of the first override that covers the date and
        defines blocks. Otherwise, the normal blocks for the weekday (possibly
        empty).

    Raises:
        KeyError, ValueError: propagated from malformed override JSON.
    """
    weekday = target_date.weekday()

    override = next(
        (
            candidate
            for candidate in active_overrides
            if candidate.start_date <= target_date <= candidate.end_date and candidate.blocks
        ),
        None,
    )

    if override is None:
        # No applicable override: use the normal weekday blocks.
        assignments = zone_assignments or {}
        return [
            ResolvedBlock(block=normal, open_zone_ids=assignments.get(normal.id) or None)
            for normal in blocks_by_dow.get(weekday, [])
        ]

    # Each JSON entry becomes an in-memory ServiceBlock with a synthetic id.
    resolved: list[ResolvedBlock] = []
    for position, spec in enumerate(override.blocks):
        synthetic = ServiceBlock(
            id=f"override-{override.id}-{position}",
            restaurant_id=restaurant_id,
            day_of_week=weekday,
            name=spec.get("name", override.name),
            block_type=spec.get("block_type", "open"),
            start_time=time.fromisoformat(spec["start_time"]),
            end_time=time.fromisoformat(spec["end_time"]),
            max_covers=spec.get("max_covers"),
            default_duration_minutes=spec.get("default_duration_minutes"),
            is_active=True,
            display_order=position,
            slot_interval_minutes=spec.get("slot_interval_minutes", 30),
        )
        resolved.append(
            ResolvedBlock(block=synthetic, open_zone_ids=spec.get("open_zone_ids") or None)
        )
    return resolved
