from __future__ import annotations

import asyncio
import time
from datetime import UTC, datetime, timedelta
from typing import Any

import httpx
import logfire
from sqlalchemy.ext.asyncio import AsyncSession
from sqlmodel import select

from app.models.whatsapp import TemplateStatus, WhatsAppAccount, WhatsAppTemplate
from app.services.encryption import decrypt_token


class WhatsAppSendError(Exception):
    """Raised when an outbound WhatsApp message could not be delivered.

    Attributes:
        is_transient: ``True`` when the failure was classified as temporary,
            meaning the same send might succeed if attempted again later.
        meta_error_code: Numeric error code reported by the Meta API, or
            ``None`` when the failure did not carry one.
    """

    def __init__(
        self,
        message: str,
        *,
        is_transient: bool = False,
        meta_error_code: int | None = None,
    ) -> None:
        super().__init__(message)
        # Retry hint for callers; defaults to "do not retry".
        self.is_transient = is_transient
        self.meta_error_code = meta_error_code


class SessionWindowExpired(WhatsAppSendError):
    """Raised when a free-form (non-template) send falls outside the 24-hour window.

    The error is always marked non-transient: repeating the same free-form
    send will not help, and the message text directs callers to use
    ``send_template`` instead.
    """

    def __init__(self) -> None:
        reason = "24-hour session window expired; use send_template instead"
        super().__init__(reason, is_transient=False)


class WhatsAppSender:
    """Outbound WhatsApp message delivery via Meta Cloud API.

    Centralizes credential loading, session window enforcement, retry
    classification, payload construction, and Logfire telemetry for
    all outbound WhatsApp sends.
    """

    GRAPH_API_BASE = "https://graph.facebook.com"
    MAX_TEXT_LENGTH = 4096
    SESSION_WINDOW_HOURS = 24
    _TRANSIENT_STATUSES = {429, 500, 503}
    _MAX_RETRIES = 3
    _BASE_BACKOFF = 1.0
    _TRANSIENT_META_ERROR_CODES = {131030}
    _PERMANENT_META_ERROR_CODES = {131031}
    _SUPPORTED_MEDIA_TYPES = {"image", "video", "document", "audio"}

    def __init__(
        self,
        session: AsyncSession,
        api_version: str = "v21.0",
        client: httpx.AsyncClient | None = None,
    ) -> None:
        self.session = session
        self.api_version = api_version
        self._client = client or httpx.AsyncClient(timeout=30.0)
        self._owns_client = client is None

    async def aclose(self) -> None:
        if self._owns_client:
            await self._client.aclose()

    async def _load_credentials(self, phone_number_id: str) -> tuple[str, WhatsAppAccount]:
        result = await self.session.execute(
            select(WhatsAppAccount).where(WhatsAppAccount.phone_number_id == phone_number_id)
        )
        account = result.scalar_one_or_none()
        if account is None:
            raise ValueError(f"WhatsApp account not found for phone_number_id={phone_number_id}")
        if not account.is_active:
            raise ValueError(f"WhatsApp account is inactive for phone_number_id={phone_number_id}")
        access_token = decrypt_token(account.access_token_encrypted)
        return access_token, account

    async def _send_request(
        self,
        phone_number_id: str,
        payload: dict[str, Any],
        access_token: str,
    ) -> dict[str, Any]:
        url = f"{self.GRAPH_API_BASE}/{self.api_version}/{phone_number_id}/messages"
        headers = {
            "Authorization": f"Bearer {access_token}",
            "Content-Type": "application/json",
        }
        message_type = str(payload.get("type", "unknown"))

        for attempt in range(1, self._MAX_RETRIES + 1):
            started_at = time.monotonic()
            try:
                response = await self._client.post(url, json=payload, headers=headers)
            except httpx.RequestError as exc:
                latency_ms = round((time.monotonic() - started_at) * 1000, 2)
                is_last_attempt = attempt >= self._MAX_RETRIES
                log_method = logfire.error if is_last_attempt else logfire.warning
                log_method(
                    "whatsapp_send_request_error",
                    phone_number_id=phone_number_id,
                    message_type=message_type,
                    attempt=attempt,
                    latency_ms=latency_ms,
                    error=str(exc),
                    success=False,
                )
                if is_last_attempt:
                    raise WhatsAppSendError(
                        f"WhatsApp request failed after {attempt} attempts: {exc}",
                        is_transient=True,
                    ) from exc
                await asyncio.sleep(self._BASE_BACKOFF * (2 ** (attempt - 1)))
                continue

            latency_ms = round((time.monotonic() - started_at) * 1000, 2)
            response_json = self._safe_json(response)
            meta_error = response_json.get("error") if isinstance(response_json, dict) else None
            meta_error_code = self._meta_error_code(meta_error)
            is_transient = self._is_transient_error(response.status_code, meta_error_code)

            if response.is_success:
                wamid = self._extract_wamid(response_json)
                logfire.info(
                    "whatsapp_send_success",
                    phone_number_id=phone_number_id,
                    message_type=message_type,
                    attempt=attempt,
                    latency_ms=latency_ms,
                    wamid=wamid,
                    success=True,
                )
                if wamid is not None:
                    response_json.setdefault("wamid", wamid)
                return response_json

            error_message = self._format_error_message(response, meta_error)
            log_method = (
                logfire.warning if is_transient and attempt < self._MAX_RETRIES else logfire.error
            )
            log_method(
                "whatsapp_send_failed",
                phone_number_id=phone_number_id,
                message_type=message_type,
                attempt=attempt,
                latency_ms=latency_ms,
                status_code=response.status_code,
                meta_error_code=meta_error_code,
                error_message=error_message,
                success=False,
            )

            if is_transient and attempt < self._MAX_RETRIES:
                await asyncio.sleep(self._BASE_BACKOFF * (2 ** (attempt - 1)))
                continue

            raise WhatsAppSendError(
                error_message,
                is_transient=is_transient,
                meta_error_code=meta_error_code,
            )

        raise WhatsAppSendError(
            "WhatsApp request exhausted retries",
            is_transient=True,
        )

    def _check_session_window(self, last_customer_message_at: datetime) -> bool:
        last_message_at = self._as_utc_naive(last_customer_message_at)
        expires_at = last_message_at + timedelta(hours=self.SESSION_WINDOW_HOURS)
        return datetime.now(UTC).replace(tzinfo=None) <= expires_at

    async def mark_as_read(
        self,
        phone_number_id: str,
        message_id: str,
        *,
        typing_indicator: bool = True,
    ) -> None:
        """Mark an inbound message as read and optionally show a typing indicator.

        Best-effort: single attempt, never raises. Failures are logged
        but do not propagate — callers should not gate agent work on this.

        No session window enforcement — read receipts are not outbound
        messages and can be sent at any time.
        """
        try:
            access_token, account = await self._load_credentials(phone_number_id)
            payload: dict[str, Any] = {
                "messaging_product": "whatsapp",
                "status": "read",
                "message_id": message_id,
            }
            if typing_indicator:
                payload["typing_indicator"] = {"type": "text"}

            url = f"{self.GRAPH_API_BASE}/{self.api_version}/{phone_number_id}/messages"
            headers = {
                "Authorization": f"Bearer {access_token}",
                "Content-Type": "application/json",
            }
            with logfire.span(
                "whatsapp_mark_as_read",
                restaurant_id=account.restaurant_id,
                phone_number_id=phone_number_id,
                message_id=message_id,
                typing_indicator=typing_indicator,
            ):
                response = await self._client.post(url, json=payload, headers=headers)
                if not response.is_success:
                    body = self._safe_json(response)
                    logfire.warning(
                        "whatsapp_mark_read_failed",
                        phone_number_id=phone_number_id,
                        message_id=message_id,
                        status_code=response.status_code,
                        error=body.get("error"),
                    )
        except Exception as exc:  # noqa: BLE001
            logfire.warning(
                "whatsapp_mark_read_error",
                phone_number_id=phone_number_id,
                message_id=message_id,
                error=str(exc),
            )

    async def send_text(
        self,
        phone_number_id: str,
        to: str,
        text: str,
        *,
        last_customer_message_at: datetime | None = None,
    ) -> dict[str, Any]:
        self._ensure_session_window(last_customer_message_at)
        normalized_text = text.strip()
        if not normalized_text:
            raise ValueError("WhatsApp text message body cannot be empty")
        if len(normalized_text) > self.MAX_TEXT_LENGTH:
            raise ValueError(f"WhatsApp text message exceeds {self.MAX_TEXT_LENGTH} characters")

        access_token, account = await self._load_credentials(phone_number_id)
        payload: dict[str, Any] = {
            "messaging_product": "whatsapp",
            "recipient_type": "individual",
            "to": to,
            "type": "text",
            "text": {"body": normalized_text},
        }
        with logfire.span(
            "whatsapp_send_text",
            restaurant_id=account.restaurant_id,
            phone_number_id=phone_number_id,
            message_type="text",
        ):
            return await self._send_request(phone_number_id, payload, access_token)

    async def send_interactive(
        self,
        phone_number_id: str,
        to: str,
        body: str,
        buttons: list[dict[str, Any]] | None = None,
        sections: list[dict[str, Any]] | None = None,
        *,
        last_customer_message_at: datetime | None = None,
    ) -> dict[str, Any]:
        self._ensure_session_window(last_customer_message_at)
        if bool(buttons) == bool(sections):
            raise ValueError("Provide either buttons or sections for an interactive message")

        access_token, account = await self._load_credentials(phone_number_id)
        interactive: dict[str, Any] = {"body": {"text": body.strip()}}
        if buttons:
            interactive["type"] = "button"
            interactive["action"] = {"buttons": buttons}
        else:
            interactive["type"] = "list"
            interactive["action"] = {"sections": sections}

        payload: dict[str, Any] = {
            "messaging_product": "whatsapp",
            "recipient_type": "individual",
            "to": to,
            "type": "interactive",
            "interactive": interactive,
        }
        with logfire.span(
            "whatsapp_send_interactive",
            restaurant_id=account.restaurant_id,
            phone_number_id=phone_number_id,
            message_type="interactive",
        ):
            return await self._send_request(phone_number_id, payload, access_token)

    async def send_template(
        self,
        phone_number_id: str,
        to: str,
        template_name: str,
        language_code: str,
        components: list[dict[str, Any]] | None = None,
    ) -> dict[str, Any]:
        access_token, account = await self._load_credentials(phone_number_id)
        template_result = await self.session.execute(
            select(WhatsAppTemplate)
            .where(WhatsAppTemplate.whatsapp_account_id == account.id)
            .where(WhatsAppTemplate.name == template_name)
            .where(WhatsAppTemplate.language == language_code)
        )
        template = template_result.scalar_one_or_none()
        if template is None:
            raise ValueError(
                f"WhatsApp template not found: name={template_name}, language={language_code}"
            )
        if template.status != TemplateStatus.APPROVED.value:
            raise ValueError(
                f"WhatsApp template is not approved: name={template_name}, status={template.status}"
            )

        template_payload: dict[str, Any] = {
            "name": template_name,
            "language": {"code": language_code},
        }
        if components:
            template_payload["components"] = components
        payload: dict[str, Any] = {
            "messaging_product": "whatsapp",
            "recipient_type": "individual",
            "to": to,
            "type": "template",
            "template": template_payload,
        }

        with logfire.span(
            "whatsapp_send_template",
            restaurant_id=account.restaurant_id,
            phone_number_id=phone_number_id,
            message_type="template",
            template_name=template_name,
        ):
            return await self._send_request(phone_number_id, payload, access_token)

    async def send_media(
        self,
        phone_number_id: str,
        to: str,
        media_type: str,
        media_url_or_id: str,
        caption: str | None = None,
        *,
        last_customer_message_at: datetime | None = None,
    ) -> dict[str, Any]:
        self._ensure_session_window(last_customer_message_at)
        if media_type not in self._SUPPORTED_MEDIA_TYPES:
            raise ValueError(
                f"Unsupported WhatsApp media_type={media_type}. "
                f"Expected one of {sorted(self._SUPPORTED_MEDIA_TYPES)}"
            )

        access_token, account = await self._load_credentials(phone_number_id)
        media_payload: dict[str, Any] = self._build_media_payload(media_type, media_url_or_id)
        if caption:
            if media_type == "audio":
                raise ValueError("WhatsApp audio messages do not support captions")
            media_payload["caption"] = caption

        payload: dict[str, Any] = {
            "messaging_product": "whatsapp",
            "recipient_type": "individual",
            "to": to,
            "type": media_type,
            media_type: media_payload,
        }
        with logfire.span(
            "whatsapp_send_media",
            restaurant_id=account.restaurant_id,
            phone_number_id=phone_number_id,
            message_type=media_type,
        ):
            return await self._send_request(phone_number_id, payload, access_token)

    def _ensure_session_window(self, last_customer_message_at: datetime | None) -> None:
        if last_customer_message_at is None or not self._check_session_window(
            last_customer_message_at
        ):
            raise SessionWindowExpired()

    @staticmethod
    def _safe_json(response: httpx.Response) -> dict[str, Any]:
        try:
            payload = response.json()
        except ValueError:
            return {}
        return payload if isinstance(payload, dict) else {}

    @staticmethod
    def _extract_wamid(response_json: dict[str, Any]) -> str | None:
        messages = response_json.get("messages")
        if not isinstance(messages, list) or not messages:
            return None
        first_message = messages[0]
        if not isinstance(first_message, dict):
            return None
        wamid = first_message.get("id")
        return str(wamid) if isinstance(wamid, str) else None

    @classmethod
    def _is_transient_error(cls, status_code: int, meta_error_code: int | None) -> bool:
        if meta_error_code in cls._TRANSIENT_META_ERROR_CODES:
            return True
        if meta_error_code in cls._PERMANENT_META_ERROR_CODES:
            return False
        return status_code in cls._TRANSIENT_STATUSES

    @staticmethod
    def _meta_error_code(meta_error: Any) -> int | None:
        if not isinstance(meta_error, dict):
            return None
        code = meta_error.get("code")
        return code if isinstance(code, int) else None

    @staticmethod
    def _format_error_message(response: httpx.Response, meta_error: Any) -> str:
        if isinstance(meta_error, dict):
            message = meta_error.get("message")
            if isinstance(message, str) and message:
                return message
        body = response.text.strip()
        if body:
            return body[:500]
        return f"Meta API request failed with status {response.status_code}"

    @staticmethod
    def _build_media_payload(media_type: str, media_url_or_id: str) -> dict[str, str]:
        locator_key = "link" if media_url_or_id.startswith(("https://", "http://")) else "id"
        return {locator_key: media_url_or_id}

    @staticmethod
    def _as_utc_naive(value: datetime) -> datetime:
        if value.tzinfo is None:
            return value
        return value.astimezone(UTC).replace(tzinfo=None)
