import contextlib
import json
import uuid
from collections.abc import AsyncGenerator
from typing import Any

import httpx
import logfire
from fastapi import APIRouter, Depends, HTTPException, Query, Request, status
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from pydantic_ai import AgentRunResultEvent
from pydantic_ai.exceptions import UsageLimitExceeded
from pydantic_ai.messages import (
    FunctionToolCallEvent,
    FunctionToolResultEvent,
    ModelMessage,
    PartDeltaEvent,
    PartStartEvent,
    TextPart,
    TextPartDelta,
)
from slowapi import Limiter
from slowapi.util import get_remote_address
from sqlalchemy import func
from sqlalchemy.ext.asyncio import AsyncSession
from sqlmodel import select

from app.agents.deps import AgentDeps, CallerIdentity
from app.agents.errors import classify_error, get_error_message
from app.agents.history import deserialize_messages, is_conversation_stale, serialize_messages
from app.agents.restaurant import detect_language, restaurant_agent
from app.agents.runner import AGENT_USAGE_LIMITS, build_agent_deps
from app.auth.better_auth import CurrentUser, get_current_user
from app.auth.tenant import get_current_restaurant, get_tenant_session
from app.db.base import utcnow
from app.db.session import commit_write, get_session
from app.dependencies import get_restate_client
from app.models.conversation import Conversation, Message
from app.models.conversation_verification import ConversationVerification
from app.models.customer import Customer
from app.models.restaurant import Restaurant

# Dashboard (tenant-scoped) chat endpoints: preview chat stream, session
# listing/patching, per-session messages and aggregate stats.
router = APIRouter(prefix="/chat", tags=["chat"])
# Public website-widget endpoints; the restaurant is resolved from a slug in
# the request body instead of the caller's dashboard session.
public_router = APIRouter(prefix="/public/chat", tags=["public-chat"])
# Rate limiter keyed by client address; per-route limits are attached with
# ``@limiter.limit(...)`` on the streaming endpoints.
# NOTE(review): get_remote_address keys on the direct peer address — behind a
# reverse proxy this may be the proxy's IP; confirm forwarded-header handling.
limiter = Limiter(key_func=get_remote_address)


class ChatRequest(BaseModel):
    """Request body for the dashboard preview chat stream (``POST /chat/stream``)."""

    # The guest/operator utterance sent to the agent.
    message: str
    # Looked up as the conversation's identity key (falling back to the legacy
    # row id); when empty/None a brand-new conversation is created.
    conversation_id: str | None = None
    agent_type: str | None = None  # unused, kept for API compatibility


class PublicChatRequest(BaseModel):
    """Request body for the public website widget stream (``POST /public/chat/stream``)."""

    # Public slug used to resolve the target restaurant (404 if unknown).
    restaurant_slug: str
    # The guest utterance sent to the agent.
    message: str
    # Looked up as the conversation's identity key (falling back to the legacy
    # row id); when empty/None a brand-new conversation is created.
    conversation_id: str | None = None
    agent_type: str | None = None  # unused, kept for API compatibility


class ChatSessionSummary(BaseModel):
    """Dashboard-facing summary of one conversation row.

    Timestamps are ISO-8601 strings produced with ``datetime.isoformat()``.
    """

    id: str
    channel: str
    status: str
    # Empty string (not None) when the row has no creation timestamp.
    started_at: str
    ended_at: str | None = None
    message_count: int = 0
    last_message_preview: str | None = None
    last_message_at: str | None = None
    # False once an operator has "forgotten" the conversation's context.
    context_enabled: bool = True
    # Name of the linked customer, if any (outer join; may be None).
    customer_name: str | None = None
    # Client-facing conversation key (what callers send as ``conversation_id``).
    identity_key: str | None = None


class ChatSessionsResponse(BaseModel):
    """One page of chat sessions for ``GET /chat/sessions``."""

    sessions: list[ChatSessionSummary]
    # ISO ``created_at`` of the last session on this page when more pages
    # exist; pass it back as ``cursor`` to fetch the next (older) page.
    next_cursor: str | None = None
    # Number of sessions matching the filters, independent of the cursor.
    total: int = 0


class ChatMessageRead(BaseModel):
    """A single stored chat message as returned to the dashboard."""

    id: str
    # "user" or "assistant" for messages written by this module.
    role: str
    content: str
    # ISO-8601 timestamp, or "" when the row has none.
    created_at: str


class ChatStatsResponse(BaseModel):
    """Aggregate chat statistics for ``GET /chat/stats``."""

    total_sessions: int = 0
    # Mean of ``Conversation.message_count`` over the filtered rows, rounded to 1 dp.
    avg_messages_per_session: float = 0.0
    # One ``{"channel": <name>, "count": <int>}`` entry per channel.
    channel_distribution: list[dict[str, Any]] = []


async def _build_website_caller(
    session: AsyncSession,
    conversation: Conversation,
    restaurant_id: str,
) -> CallerIdentity:
    """Build the ``CallerIdentity`` shared by the public widget and dashboard preview.

    Both surfaces MUST behave identically conversationally: same channel,
    same verification gate, same tool surface, same prompt. The only
    difference is how the restaurant is resolved (slug for public, active
    org for dashboard). Anything channel-driven (tool filtering, system
    prompt verification block, name-search restrictions, orphan-reservation
    guard) flows from this single helper.

    The caller starts unverified. If the conversation has a completed,
    unexpired ``ConversationVerification``, the caller is marked verified. The
    most recent such verification's email is then matched against this
    restaurant's customers to attach the customer id, name and email.
    """
    identity = CallerIdentity(
        channel="website",
        identity_key=conversation.identity_key,
        customer_id=None,
        verified=False,
        conversation_id=conversation.id,
    )

    # Newest verification that was actually completed and has not expired.
    latest_verification_stmt = (
        select(ConversationVerification)
        .where(
            ConversationVerification.conversation_id == conversation.id,
            ConversationVerification.verified_at.isnot(None),  # type: ignore[union-attr]
            ConversationVerification.expires_at > utcnow(),
        )
        .order_by(ConversationVerification.verified_at.desc())  # type: ignore[union-attr]
        .limit(1)
    )
    latest = (await session.execute(latest_verification_stmt)).scalar_one_or_none()

    if latest is not None:
        identity.verified = True
        customer_stmt = select(Customer).where(
            Customer.email == latest.email,
            Customer.restaurant_id == restaurant_id,
        )
        matched = (await session.execute(customer_stmt)).scalar_one_or_none()
        if matched is not None:
            identity.customer_id = matched.id
            identity.customer_name = matched.name
            identity.customer_email = matched.email

    return identity


async def _get_or_create_conversation(
    session: AsyncSession,
    conversation_id: str | None,
    restaurant_id: str,
    agent_type: str,
    channel: str = "dashboard",
    language: str | None = None,
) -> Conversation:
    """Return this restaurant's conversation for ``conversation_id``, creating it if absent.

    Resolution order for a non-empty ``conversation_id``:

    1. a row whose ``identity_key`` equals it;
    2. a legacy row whose primary key ``id`` equals it;
    3. otherwise a new row with a fresh UUID ``id`` and that ``identity_key``.
       If the commit fails (e.g. a concurrent request inserted the same key
       first), the session is rolled back and the winner's row is returned.
       If no winner is found, the original error is re-raised.

    An empty/None ``conversation_id`` always creates a new conversation whose
    ``id`` and ``identity_key`` are the same fresh UUID.
    """

    async def find(column: Any) -> Conversation | None:
        # Tenant-scoped lookup of ``conversation_id`` against the given column.
        found = await session.execute(
            select(Conversation).where(
                column == conversation_id,
                Conversation.restaurant_id == restaurant_id,
            )
        )
        return found.scalar_one_or_none()

    def build(row_id: str, identity: str) -> Conversation:
        # Fresh, active conversation with the caller-supplied attributes.
        return Conversation(
            id=row_id,
            identity_key=identity,
            restaurant_id=restaurant_id,
            channel=channel,
            agent_type=agent_type,
            status="active",
            language=language,
        )

    if not conversation_id:
        fresh_id = str(uuid.uuid4())
        created = build(fresh_id, fresh_id)
        session.add(created)
        await commit_write(session)
        return created

    # Identity key first; row id second for older rows where they coincide.
    for column in (Conversation.identity_key, Conversation.id):
        match = await find(column)
        if match:
            return match

    candidate = build(str(uuid.uuid4()), conversation_id)
    session.add(candidate)
    try:
        await commit_write(session)
    except Exception:
        # Concurrent insert winner already created this identity key.
        await session.rollback()
        winner = await find(Conversation.identity_key)
        if winner:
            return winner
        raise
    return candidate


async def _load_structured_history(
    session: AsyncSession,
    conversation_id: str,
    restaurant_id: str,
) -> list[ModelMessage]:
    """Return the stored structured message history for a tenant's conversation.

    An empty list is returned when:
    - the conversation does not exist for this restaurant;
    - its context has been disabled;
    - ``is_conversation_stale`` reports its last message as stale.

    The last two cases are logged.
    """
    lookup = select(Conversation).where(
        Conversation.id == conversation_id,
        Conversation.restaurant_id == restaurant_id,
    )
    conversation = (await session.execute(lookup)).scalar_one_or_none()
    if conversation is None:
        return []

    # The staleness check only runs when context is still enabled.
    skip_reason: str | None = None
    if not conversation.context_enabled:
        skip_reason = "conversation_context_disabled"
    elif is_conversation_stale(conversation.last_message_at):
        skip_reason = "conversation_history_expired"

    if skip_reason is not None:
        logfire.info(skip_reason, conversation_id=conversation_id)
        return []
    return deserialize_messages(conversation.message_history_json)


async def _save_messages(
    session: AsyncSession,
    conversation_id: str,
    restaurant_id: str,
    user_text: str,
    assistant_text: str,
    tool_calls: list[dict[str, Any]] | None,
    agent_type: str = "",
    token_count: int | None = None,
    messages_json: str | None = None,
) -> None:
    """Persist one user/assistant exchange and refresh the conversation's metadata.

    Adds a ``user`` and an ``assistant`` ``Message`` row and bumps
    ``message_count`` by two. Updates the preview and the
    ``last_message_at``/``ended_at`` timestamps. Optionally records
    ``last_agent_type`` and the serialized structured history. Finally commits
    via ``commit_write``. Does nothing if no conversation with
    ``conversation_id`` exists.

    Args:
        session: Active DB session; committed on success.
        conversation_id: Primary key of the conversation row.
        restaurant_id: Currently unused. Both messages inherit the
            conversation row's own ``restaurant_id``.
        user_text: The inbound user message.
        assistant_text: The assistant's full reply (its first 200 chars
            become the preview).
        tool_calls: Tool-call events to store on the assistant message;
            empty/None is stored as None.
        agent_type: When non-empty, written to ``last_agent_type``.
        token_count: Token usage stored on the assistant message.
        messages_json: Serialized structured history. ``None`` leaves the
            stored history untouched.
    """
    # Single lookup: the same ORM object is used for the messages' tenant and
    # for the metadata update (previously this query was issued twice).
    result = await session.execute(select(Conversation).where(Conversation.id == conversation_id))
    conv = result.scalar_one_or_none()
    if conv is None:
        return

    session.add(
        Message(
            conversation_id=conversation_id,
            role="user",
            content=user_text,
            restaurant_id=conv.restaurant_id,
        )
    )
    session.add(
        Message(
            conversation_id=conversation_id,
            role="assistant",
            content=assistant_text,
            tool_calls=tool_calls or None,
            token_count=token_count,
            restaurant_id=conv.restaurant_id,
        )
    )

    now = utcnow()
    # Older rows may have a NULL message_count; treat it as 0.
    conv.message_count = (getattr(conv, "message_count", None) or 0) + 2
    conv.last_message_preview = assistant_text[:200] if assistant_text else None
    conv.last_message_at = now
    conv.ended_at = now
    if agent_type:
        conv.last_agent_type = agent_type
    if messages_json is not None:
        conv.message_history_json = messages_json
    session.add(conv)

    await commit_write(session)


async def _stream_agent_response(
    message: str,
    message_history: list[ModelMessage],
    deps: AgentDeps,
    conversation_id: str,
) -> AsyncGenerator[str, None]:
    """Stream an agent response via SSE.

    The unified restaurant_agent handles all capabilities (FAQ, reservation,
    takeaway).  Which tools are available is controlled by
    ``deps.enabled_capabilities``, filtered at runtime via ``prepare_tools``.
    Write durability is achieved at the tool level: reservation and order
    mutations route through ``restate_proxy()`` to Restate Virtual Objects
    for exactly-once semantics.  FAQ tools are read-only.

    Yields, in order:

    * an ``agent_routed`` SSE event;
    * ``tool_call`` / ``tool_result`` / ``token`` SSE events as the run progresses;
    * at most one ``error`` SSE event if the run fails. Errors are reported
      in-band and are not raised to the caller;
    * one internal ``__meta__:<json>`` line. It is NOT SSE-formatted and the
      caller must intercept it. It carries ``full_text``, ``tool_calls``,
      ``token_count`` and the serialized ``all_messages`` history;
    * a ``done`` SSE event and the ``data: [DONE]`` sentinel.

    ``all_messages`` is an empty string unless the run produced text AND
    completed (an ``AgentRunResultEvent`` was received). A run that fails
    after streaming partial text therefore never hands callers an empty
    history that would overwrite the stored one.
    """
    routed_evt = {"type": "agent_routed", "agent_type": "restaurant", "language": deps.language}
    yield f"data: {json.dumps(routed_evt)}\n\n"

    full_text = ""
    collected_tool_calls: list[dict[str, Any]] = []
    token_count: int | None = None
    # Stays empty unless the run completes with an AgentRunResultEvent.
    all_messages: list[ModelMessage] = []

    try:
        logfire.info(
            "llm_stream_started",
            conversation_id=conversation_id,
            history_items=len(message_history),
        )
        with logfire.span(
            "agent_conversation",
            restaurant_id=deps.restaurant_id,
            conversation_id=conversation_id,
            language=deps.language,
        ):
            async for event in restaurant_agent.run_stream_events(
                message,
                deps=deps,
                message_history=message_history,
                usage_limits=AGENT_USAGE_LIMITS,
            ):
                if isinstance(event, FunctionToolCallEvent):
                    tool_call_event = {
                        "type": "tool_call",
                        "name": event.part.tool_name,
                        "args": event.part.args_as_dict(),
                    }
                    collected_tool_calls.append(tool_call_event)
                    yield f"data: {json.dumps(tool_call_event)}\n\n"
                    continue

                if isinstance(event, FunctionToolResultEvent):
                    tool_name = event.result.tool_name or ""
                    yield f"data: {json.dumps({'type': 'tool_result', 'name': tool_name})}\n\n"
                    continue

                if isinstance(event, PartStartEvent) and isinstance(event.part, TextPart):
                    initial_text = event.part.content
                    if initial_text:
                        full_text += initial_text
                        yield f"data: {json.dumps({'type': 'token', 'content': initial_text})}\n\n"
                    continue

                if isinstance(event, PartDeltaEvent) and isinstance(event.delta, TextPartDelta):
                    text_delta = event.delta.content_delta
                    if text_delta:
                        full_text += text_delta
                        yield f"data: {json.dumps({'type': 'token', 'content': text_delta})}\n\n"
                    continue

                if isinstance(event, AgentRunResultEvent):
                    all_messages = event.result.all_messages()
                    token_count = getattr(event.result.usage(), "total_tokens", None)

    except UsageLimitExceeded as exc:
        error_code = "context_overflow"
        error_message = get_error_message(error_code, deps.language)
        logfire.warning(
            "agent_usage_limit_exceeded",
            conversation_id=conversation_id,
            restaurant_id=deps.restaurant_id,
            error=str(exc),
        )
        error_evt = json.dumps({"type": "error", "code": error_code, "message": error_message})
        yield f"data: {error_evt}\n\n"

    except Exception as exc:
        import traceback

        error_code = classify_error(exc)
        error_message = get_error_message(error_code, deps.language)
        logfire.error(
            "agent_stream_error",
            conversation_id=conversation_id,
            restaurant_id=deps.restaurant_id,
            error=str(exc),
            error_code=error_code,
            traceback=traceback.format_exc(),
        )
        error_evt = json.dumps({"type": "error", "code": error_code, "message": error_message})
        yield f"data: {error_evt}\n\n"

    logfire.info(
        "llm_stream_completed",
        conversation_id=conversation_id,
        token_chars=len(full_text),
        tool_calls=len(collected_tool_calls),
        token_count=token_count,
    )

    # Emit internal meta BEFORE done event so event_generator can
    # persist messages before signaling completion to the client.
    # Only ship a history when the run actually completed. Serializing an
    # empty list after a mid-stream failure would wipe the stored history.
    meta_payload = {
        "full_text": full_text,
        "tool_calls": collected_tool_calls,
        "token_count": token_count,
        "all_messages": serialize_messages(all_messages) if full_text and all_messages else "",
    }
    yield f"__meta__:{json.dumps(meta_payload)}\n"

    logfire.info("sse_done_event_emitting", conversation_id=conversation_id)
    yield f"data: {json.dumps({'type': 'done', 'conversation_id': conversation_id})}\n\n"
    yield "data: [DONE]\n\n"


@router.post("/stream")
@limiter.limit("60/minute")
async def chat_stream(
    body: ChatRequest,
    request: Request,
    # ``_current_user`` is referenced purely for the auth-gate side effect of
    # the dependency; the preview chat itself does not read user attributes.
    _current_user: CurrentUser = Depends(get_current_user),
    restaurant: Restaurant = Depends(get_current_restaurant),
    session: AsyncSession = Depends(get_tenant_session),
    client: httpx.AsyncClient = Depends(get_restate_client),
) -> StreamingResponse:
    """Stream one dashboard preview chat turn as SSE.

    Resolves (or creates) the conversation and builds the website caller
    identity and agent deps. It then streams the agent's events, persisting
    the exchange before the ``done`` event is sent. The response carries the
    conversation row id in ``X-Conversation-Id``.
    """
    language = detect_language(body.message)
    with logfire.span(
        "chat_stream",
        restaurant_id=restaurant.id,
        language=language,
    ):
        logfire.info(
            "llm_observability_baseline",
            conversation_id=body.conversation_id,
        )
        # Dashboard preview chat MUST mimic the public widget conversation
        # 1:1 — same agent channel, same verification gate, same tool surface,
        # same prompt — so operators can validate the real guest experience.
        # We tag the Conversation row with channel="preview" so operator
        # sessions are filterable separately from real widget traffic, but
        # the CallerIdentity passed to the agent is identical to the widget
        # (see ``_build_website_caller``).
        conv = await _get_or_create_conversation(
            session,
            body.conversation_id,
            restaurant.id,
            "restaurant",
            channel="preview",
            language=language,
        )
        caller = await _build_website_caller(session, conv, restaurant.id)
        deps = build_agent_deps(
            session=session,
            restaurant=restaurant,
            http_client=client,
            caller=caller,
            language=language,
        )
        conversation_key = conv.id
        structured_history = await _load_structured_history(
            session, conversation_key, restaurant.id
        )

        meta_prefix = "__meta__:"

        async def event_generator() -> AsyncGenerator[str, None]:
            async for chunk in _stream_agent_response(
                body.message,
                structured_history,
                deps,
                conversation_key,
            ):
                if not chunk.startswith(meta_prefix):
                    yield chunk
                    continue

                # Internal bookkeeping line: consumed here, never sent to the client.
                meta = json.loads(chunk[len(meta_prefix) :])
                # Persist messages NOW, before the done event reaches
                # the client, so the next request sees full history.
                if meta["full_text"]:
                    await _save_messages(
                        session,
                        conversation_key,
                        restaurant.id,
                        body.message,
                        meta["full_text"],
                        meta["tool_calls"],
                        agent_type="restaurant",
                        token_count=meta.get("token_count"),
                        messages_json=meta.get("all_messages") or None,
                    )

            # Mark conversation incomplete on client disconnect.
            # NOTE(review): on a real disconnect the ASGI server may cancel or
            # close this generator at a ``yield``, so this post-loop check may
            # never run — confirm it is reachable.
            if await request.is_disconnected():
                result = await session.execute(
                    select(Conversation).where(Conversation.id == conversation_key)
                )
                stored = result.scalar_one_or_none()
                if stored and stored.status == "active":
                    stored.status = "incomplete"
                    session.add(stored)
                    # Same commit path as every other write in this module.
                    await commit_write(session)

        from app.realtime import sse_response

        return sse_response(
            event_generator(),
            extra_headers={"X-Conversation-Id": conversation_key},
        )


@router.get("/sessions", response_model=ChatSessionsResponse)
async def list_chat_sessions(
    limit: int = Query(default=20, ge=1, le=100),
    cursor: str | None = Query(default=None),
    date_from: str | None = Query(default=None),
    date_to: str | None = Query(default=None),
    channel: str | None = Query(default=None),
    # agent_type filter removed — single unified agent now
    session: AsyncSession = Depends(get_tenant_session),
    restaurant: Restaurant = Depends(get_current_restaurant),
) -> Any:
    """List this restaurant's chat sessions, newest first, with cursor pagination.

    ``date_from``/``date_to`` bound ``created_at`` (inclusive) and ``channel``
    filters by channel. These filters apply to both the page and ``total``.
    ``cursor`` is the ISO ``created_at`` of the previous page's last row and
    narrows only the page. Unparseable timestamps are silently ignored.
    """
    from datetime import datetime

    def _iso_or_none(raw: str | None) -> datetime | None:
        # Empty or malformed timestamps mean "no filter" rather than an error.
        if not raw:
            return None
        try:
            return datetime.fromisoformat(raw)
        except ValueError:
            return None

    def _summary(conv: Any, customer_name: str | None) -> ChatSessionSummary:
        created = conv.created_at if hasattr(conv, "created_at") else None
        return ChatSessionSummary(
            id=conv.id,
            channel=getattr(conv, "channel", "") or "",
            status=getattr(conv, "status", ""),
            started_at=created.isoformat() if created else "",
            ended_at=conv.ended_at.isoformat() if conv.ended_at else None,
            message_count=getattr(conv, "message_count", 0) or 0,
            last_message_preview=getattr(conv, "last_message_preview", None),
            last_message_at=(
                conv.last_message_at.isoformat() if conv.last_message_at else None
            ),
            context_enabled=getattr(conv, "context_enabled", True),
            customer_name=customer_name,
            identity_key=conv.identity_key,
        )

    # Filters shared by the page query and the total count.
    shared_filters = [Conversation.restaurant_id == restaurant.id]
    lower_bound = _iso_or_none(date_from)
    if lower_bound is not None:
        shared_filters.append(Conversation.created_at >= lower_bound)
    upper_bound = _iso_or_none(date_to)
    if upper_bound is not None:
        shared_filters.append(Conversation.created_at <= upper_bound)
    if channel:
        shared_filters.append(Conversation.channel == channel)

    count_stmt = select(func.count()).select_from(Conversation).where(*shared_filters)
    page_stmt = (
        select(Conversation, Customer.name.label("customer_name"))  # type: ignore[union-attr]
        .outerjoin(Customer, Conversation.customer_id == Customer.id)
        .where(*shared_filters)
    )
    # The cursor narrows the page only; ``total`` stays cursor-independent.
    cursor_dt = _iso_or_none(cursor)
    if cursor_dt is not None:
        page_stmt = page_stmt.where(Conversation.created_at < cursor_dt)

    total = (await session.execute(count_stmt)).scalar() or 0

    # Fetch one extra row to learn whether another page exists.
    page_stmt = page_stmt.order_by(Conversation.created_at.desc()).limit(limit + 1)  # type: ignore[attr-defined]
    fetched = list((await session.execute(page_stmt)).all())
    has_more = len(fetched) > limit
    page = fetched[:limit]

    sessions = [_summary(conv, customer_name) for conv, customer_name in page]

    next_cursor = None
    if has_more and page:
        oldest = page[-1][0]
        oldest_created = oldest.created_at if hasattr(oldest, "created_at") else None
        if oldest_created:
            next_cursor = oldest_created.isoformat()

    return {"sessions": sessions, "next_cursor": next_cursor, "total": total}


class PatchConversationRequest(BaseModel):
    """Body for ``PATCH /chat/sessions/{session_id}``."""

    # False "forgets" the conversation: it is ended and its stored history cleared.
    context_enabled: bool


@router.patch("/sessions/{session_id}")
async def patch_conversation(
    session_id: str,
    body: PatchConversationRequest,
    session: AsyncSession = Depends(get_tenant_session),
    restaurant: Restaurant = Depends(get_current_restaurant),
) -> ChatSessionSummary:
    """Set ``context_enabled`` on one of this restaurant's chat sessions.

    Disabling context also marks the conversation ``ended``, stamps
    ``ended_at`` and clears the stored structured history. Returns the updated
    session summary.

    Raises:
        HTTPException: 404 when the session does not exist for this restaurant.
    """
    lookup = await session.execute(
        select(Conversation, Customer.name.label("customer_name"))  # type: ignore[union-attr]
        .outerjoin(Customer, Conversation.customer_id == Customer.id)
        .where(
            Conversation.id == session_id,
            Conversation.restaurant_id == restaurant.id,
        )
    )
    found = lookup.one_or_none()
    if found is None:
        raise HTTPException(status_code=404, detail="Chat session not found")
    conversation, customer_name = found

    conversation.context_enabled = body.context_enabled
    if not body.context_enabled:
        # Intended semantics of the dashboard "forget conversation" toggle:
        # the next inbound message should start fresh.
        # NOTE(review): _get_or_create_conversation matches on identity_key /
        # id without checking ``status``, so a later message with the same key
        # appears to reuse this ended row (with context disabled) rather than
        # create a new one — confirm against other channel handlers.
        conversation.status = "ended"
        conversation.ended_at = utcnow()
        conversation.message_history_json = None
    session.add(conversation)
    await commit_write(session)

    created = conversation.created_at if hasattr(conversation, "created_at") else None
    return ChatSessionSummary(
        id=conversation.id,
        channel=getattr(conversation, "channel", "") or "",
        status=getattr(conversation, "status", ""),
        started_at=created.isoformat() if created else "",
        ended_at=conversation.ended_at.isoformat() if conversation.ended_at else None,
        message_count=getattr(conversation, "message_count", 0) or 0,
        last_message_preview=getattr(conversation, "last_message_preview", None),
        last_message_at=(
            conversation.last_message_at.isoformat() if conversation.last_message_at else None
        ),
        context_enabled=conversation.context_enabled,
        customer_name=customer_name,
        identity_key=conversation.identity_key,
    )


@router.get("/sessions/{session_id}/messages", response_model=list[ChatMessageRead])
async def get_session_messages(
    session_id: str,
    session: AsyncSession = Depends(get_tenant_session),
    restaurant: Restaurant = Depends(get_current_restaurant),
) -> list[Any]:
    """Return every stored message of one chat session, oldest first.

    Ownership is checked on the conversation (tenant-scoped) before its
    messages are read.

    Raises:
        HTTPException: 404 when the session does not exist for this restaurant.
    """
    owning_conversation = (
        await session.execute(
            select(Conversation).where(
                Conversation.id == session_id,
                Conversation.restaurant_id == restaurant.id,
            )
        )
    ).scalar_one_or_none()
    if owning_conversation is None:
        raise HTTPException(status_code=404, detail="Chat session not found")

    ordered = (
        (
            await session.execute(
                select(Message)
                .where(Message.conversation_id == session_id)
                .order_by(Message.created_at.asc())  # type: ignore[attr-defined]
            )
        )
        .scalars()
        .all()
    )

    payload: list[Any] = []
    for msg in ordered:
        created = msg.created_at if hasattr(msg, "created_at") else None
        payload.append(
            ChatMessageRead(
                id=getattr(msg, "id", ""),
                role=msg.role,
                content=msg.content,
                created_at=created.isoformat() if created else "",
            )
        )
    return payload


@router.get("/stats", response_model=ChatStatsResponse)
async def get_chat_stats(
    date_from: str | None = Query(default=None),
    date_to: str | None = Query(default=None),
    session: AsyncSession = Depends(get_tenant_session),
    restaurant: Restaurant = Depends(get_current_restaurant),
) -> Any:
    """Return session count, mean messages per session and per-channel counts.

    ``date_from``/``date_to`` bound ``created_at`` (inclusive); unparseable
    values are silently ignored.
    """
    from datetime import datetime

    def _iso_or_none(raw: str | None) -> datetime | None:
        # Empty or malformed timestamps mean "no filter".
        if not raw:
            return None
        try:
            return datetime.fromisoformat(raw)
        except ValueError:
            return None

    filters = [Conversation.restaurant_id == restaurant.id]
    start = _iso_or_none(date_from)
    if start is not None:
        filters.append(Conversation.created_at >= start)
    end = _iso_or_none(date_to)
    if end is not None:
        filters.append(Conversation.created_at <= end)

    total_sessions = (
        await session.execute(select(func.count()).select_from(Conversation).where(*filters))
    ).scalar() or 0

    mean_messages = (
        await session.execute(select(func.avg(Conversation.message_count)).where(*filters))
    ).scalar()
    avg_messages = float(mean_messages or 0.0)

    grouped = await session.execute(
        select(Conversation.channel, func.count().label("count"))
        .where(*filters)
        .group_by(Conversation.channel)
    )
    channel_dist = [{"channel": name, "count": hits} for name, hits in grouped.all()]

    return {
        "total_sessions": total_sessions,
        "avg_messages_per_session": round(avg_messages, 1),
        "channel_distribution": channel_dist,
    }


async def _resolve_restaurant_by_slug(slug: str, session: AsyncSession) -> Restaurant:
    """Load a restaurant by its public slug.

    This helper only performs the lookup; it does NOT set any tenant/RLS
    context on ``session``. Callers must scope follow-up queries themselves;
    ``public_chat_stream`` does so by passing ``restaurant.id`` explicitly.

    Raises:
        HTTPException: 404 when no restaurant has this slug.
    """
    result = await session.execute(select(Restaurant).where(Restaurant.slug == slug))
    restaurant = result.scalar_one_or_none()
    if restaurant is None:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Restaurant not found")
    return restaurant


@public_router.post("/stream")
@limiter.limit("30/minute")
async def public_chat_stream(
    body: PublicChatRequest,
    request: Request,
    session: AsyncSession = Depends(get_session),
    client: httpx.AsyncClient = Depends(get_restate_client),
) -> StreamingResponse:
    """Stream one public website-widget chat turn as SSE.

    Resolves the restaurant from ``body.restaurant_slug`` (404 if unknown) and
    resolves (or creates) the ``website`` conversation. It then streams the
    agent's events, persisting the exchange before the ``done`` event is sent.
    The response carries the conversation row id in ``X-Conversation-Id``.
    """
    language = detect_language(body.message)
    restaurant = await _resolve_restaurant_by_slug(body.restaurant_slug, session)
    with logfire.span(
        "public_chat_stream",
        restaurant_id=restaurant.id,
        restaurant_slug=body.restaurant_slug,
        language=language,
    ):
        logfire.info(
            "public_chat_stream_request",
            conversation_id=body.conversation_id,
            slug=body.restaurant_slug,
        )
        conv = await _get_or_create_conversation(
            session,
            body.conversation_id,
            restaurant.id,
            "restaurant",
            channel="website",
            language=language,
        )
        caller = await _build_website_caller(session, conv, restaurant.id)
        deps = build_agent_deps(
            session=session,
            restaurant=restaurant,
            http_client=client,
            caller=caller,
            language=language,
        )
        conversation_key = conv.id
        history = await _load_structured_history(session, conversation_key, restaurant.id)

        meta_prefix = "__meta__:"

        async def event_generator() -> AsyncGenerator[str, None]:
            async for chunk in _stream_agent_response(
                body.message,
                history,
                deps,
                conversation_key,
            ):
                if not chunk.startswith(meta_prefix):
                    yield chunk
                    continue

                # Internal bookkeeping line: consumed here, never sent to the client.
                meta = json.loads(chunk[len(meta_prefix) :])
                # Persist messages NOW, before the done event reaches
                # the client, so the next request sees full history.
                if meta["full_text"]:
                    await _save_messages(
                        session,
                        conversation_key,
                        restaurant.id,
                        body.message,
                        meta["full_text"],
                        meta["tool_calls"],
                        agent_type="restaurant",
                        token_count=meta.get("token_count"),
                        messages_json=meta.get("all_messages") or None,
                    )

            # Mark conversation incomplete on client disconnect.
            # NOTE(review): on a real disconnect the ASGI server may cancel or
            # close this generator at a ``yield``, so this post-loop check may
            # never run — confirm it is reachable.
            if await request.is_disconnected():
                result = await session.execute(
                    select(Conversation).where(Conversation.id == conversation_key)
                )
                stored = result.scalar_one_or_none()
                if stored and stored.status == "active":
                    stored.status = "incomplete"
                    session.add(stored)
                    # Same commit path as every other write in this module.
                    await commit_write(session)

        from app.realtime import sse_response

        return sse_response(
            event_generator(),
            extra_headers={"X-Conversation-Id": conversation_key},
        )
