"""Better Auth JWT authentication dependency for FastAPI.

Validates Bearer JWTs issued by Better Auth (`jwt` plugin) against its JWKS
endpoint at `${BETTER_AUTH_URL}/api/auth/jwks`. JWKS is cached in memory and
refreshed atomically on key-mismatch (handles BA key rotation).

JWT payload shape (set by `jwt.definePayload` in `frontend/src/lib/auth.config.ts`):
    {
      "sub": "<user id>",
      "email": "...",
      "emailVerified": true,
      "activeOrganizationId": "..." | null,
      "activeTeamId": "..." | null,
      "activeRestaurantId": "..." | null,
      "role": "owner" | "admin" | "member" | null,
      "iat", "exp", "iss", "aud"
    }

Unlike the Clerk integration this replaces, there is NO user upsert path:
Better Auth writes user rows directly during sign-up; the Python backend
only reads. The previous `_upsert_user` helper is gone.
"""

import asyncio
from dataclasses import dataclass
from typing import Annotated, cast

import httpx
import jwt
from fastapi import Depends, HTTPException, status
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer

from app.config import Settings, get_settings

# FastAPI security scheme used by `get_current_user` to extract the token
# from an `Authorization: Bearer <token>` header.
_bearer: HTTPBearer = HTTPBearer()

# Module-level JWKS cache — avoids per-request fetches.
# `None` until the first successful fetch populates it.
_jwks_cache: dict[str, object] | None = None
# Guards the cache re-check and refresh-task bookkeeping in `_get_jwks`.
_jwks_refresh_lock: asyncio.Lock = asyncio.Lock()
# The current/most recent JWKS fetch task, if any (see `_get_jwks`).
_jwks_refresh_task: asyncio.Task[dict[str, object]] | None = None

# Better Auth's jwt() plugin defaults to EdDSA / Ed25519. Configured to be
# defensive: accept the small set of asymmetric algorithms BA may issue.
# Symmetric (HS*) algorithms and "none" are not listed, so `jwt.decode`
# rejects tokens that claim them.
_ACCEPTED_ALGORITHMS: tuple[str, ...] = ("EdDSA", "RS256", "ES256")


@dataclass(slots=True)
class CurrentUser:
    """Authenticated user + active tenancy state extracted from the JWT.

    `active_org_id`, `active_team_id`, `active_restaurant_id`, and `role`
    are `None` for users who have not yet completed onboarding (they have
    no organization).

    `active_restaurant_id` is the restaurant_id of the active team — added
    to the JWT so RLS policies key on it directly via
    `auth.session()->>'activeRestaurantId'`.
    """

    id: str
    email: str
    email_verified: bool
    active_org_id: str | None
    active_team_id: str | None
    active_restaurant_id: str | None
    role: str | None


async def _fetch_jwks(jwks_url: str) -> dict[str, object]:
    """Download and decode the JWKS document published at `jwks_url`.

    Args:
        jwks_url: Absolute URL of the JWKS endpoint.

    Returns:
        The decoded JSON object (expected to carry a ``keys`` array).

    Raises:
        httpx.HTTPError: on transport errors, timeouts, or a non-2xx status.
        ValueError: if the body is not valid JSON, or is valid JSON but not
            a JSON object (a JWKS document is always an object).
    """
    async with httpx.AsyncClient(timeout=10.0) as client:
        response = await client.get(jwks_url)
        _ = response.raise_for_status()
        body: object = response.json()
    # Validate here rather than trusting a `cast`. A list/string body would
    # otherwise fail later inside `PyJWKSet.from_dict` with an AttributeError.
    if not isinstance(body, dict):
        raise ValueError(f"JWKS response from {jwks_url} is not a JSON object")
    return cast(dict[str, object], body)


async def _refresh_jwks_cache(jwks_url: str) -> dict[str, object]:
    """Fetch the JWKS and publish it to the module-level cache.

    This runs as the single shared refresh task. The cache is therefore
    written in the order refreshes complete, not in the order waiting
    callers happen to resume.
    """
    global _jwks_cache  # noqa: PLW0603

    jwks = await _fetch_jwks(jwks_url)
    _jwks_cache = jwks
    return jwks


async def _get_jwks(*, jwks_url: str, force_refresh: bool = False) -> dict[str, object]:
    """Return the cached JWKS, fetching it on first use or when forced.

    At most one fetch is in flight at a time. Concurrent callers, including
    those passing ``force_refresh=True``, join the running refresh task instead
    of starting their own (per the existing `auth-integration`
    concurrency-safe-refresh requirement).

    Args:
        jwks_url: Absolute URL of the JWKS endpoint.
        force_refresh: Bypass the cache, e.g. after a signing-key mismatch.

    Returns:
        The JWKS document as a dict.

    Raises:
        Whatever `_fetch_jwks` raises when the shared fetch fails. Every
        caller waiting on that fetch receives the same exception.
    """
    global _jwks_refresh_task  # noqa: PLW0603

    cached = _jwks_cache
    if cached is not None and not force_refresh:
        return cached

    async with _jwks_refresh_lock:
        cached = _jwks_cache
        if cached is not None and not force_refresh:
            return cached
        task = _jwks_refresh_task
        # Start a new fetch only if none is running. A forced refresh that
        # arrives mid-fetch joins it instead of spawning a duplicate request.
        if task is None or task.done():
            task = asyncio.create_task(_refresh_jwks_cache(jwks_url))
            _jwks_refresh_task = task

    # shield(): the task is shared, so cancelling one waiting request (e.g. a
    # client disconnect) must not cancel the fetch for every other waiter.
    return await asyncio.shield(task)


def _parse_payload(payload: dict[str, object]) -> CurrentUser:
    """Translate verified JWT claims into a `CurrentUser`.

    `sub` must be a non-empty string and `email` must be a string. Otherwise
    `jwt.InvalidTokenError` is raised (checked in that order).
    `emailVerified` counts only when it is an actual boolean `true`.
    Tenancy claims and `role` are kept only when they are strings; any other
    value (including a missing claim or JSON null) becomes `None`.
    """

    def _optional_str(claim: str) -> str | None:
        # Keep string claims as-is; anything else collapses to None.
        value = payload.get(claim)
        return value if isinstance(value, str) else None

    user_id = payload.get("sub")
    if not (isinstance(user_id, str) and user_id):
        raise jwt.InvalidTokenError("Missing sub claim")

    email = payload.get("email")
    if not isinstance(email, str):
        raise jwt.InvalidTokenError("Missing email claim")

    # `is True` accepts only the bool singleton, so truthy non-bools such as
    # "true" or 1 are treated as unverified.
    verified = payload.get("emailVerified", False) is True

    return CurrentUser(
        id=user_id,
        email=email,
        email_verified=verified,
        active_org_id=_optional_str("activeOrganizationId"),
        active_team_id=_optional_str("activeTeamId"),
        active_restaurant_id=_optional_str("activeRestaurantId"),
        role=_optional_str("role"),
    )


# Verification failures that a stale JWKS can cause (Better Auth rotated its
# signing keys after we cached them). Only these justify a forced refresh.
# Expired, malformed, or claim-invalid tokens fail the same way regardless
# of which keys we hold.
_KEY_MISMATCH_ERRORS: tuple[type[jwt.PyJWTError], ...] = (
    jwt.InvalidKeyError,
    jwt.InvalidSignatureError,
    jwt.exceptions.PyJWKError,
    jwt.exceptions.PyJWKSetError,
)


def _unauthorized() -> HTTPException:
    """Build the 401 response used for every rejected token."""
    return HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Invalid or expired token",
        headers={"WWW-Authenticate": "Bearer"},
    )


async def _load_jwks_or_503(
    jwks_url: str, *, force_refresh: bool = False
) -> dict[str, object]:
    """Return the JWKS via `_get_jwks`, mapping fetch failures to HTTP 503.

    A transport error, timeout, non-2xx status, or undecodable body from the
    JWKS endpoint is a server-side outage, not a bad client token. It should
    surface as 503 rather than as an unhandled 500.
    """
    try:
        return await _get_jwks(jwks_url=jwks_url, force_refresh=force_refresh)
    except (httpx.HTTPError, ValueError) as exc:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Authentication service unavailable",
        ) from exc


def _decode_bearer_token(token: str, jwks: dict[str, object]) -> CurrentUser:
    """Verify `token` against `jwks` and map its claims to a `CurrentUser`.

    Raises:
        jwt.PyJWTError: on any verification or claim failure. An unknown
            `kid` or an empty key set surfaces as `jwt.InvalidKeyError`.
    """
    header = jwt.get_unverified_header(token)
    kid = header.get("kid")
    jwk_set = jwt.PyJWKSet.from_dict(jwks)
    try:
        # Without a `kid`, fall back to the first published key. Index the
        # public `keys` list rather than relying on PyJWKSet iteration.
        signing_key = jwk_set[kid] if kid else jwk_set.keys[0]
    except (KeyError, IndexError) as exc:
        # Unknown kid (rotation) or empty JWKS: surface as a JWT error so the
        # caller's refresh-then-401 path handles it instead of bubbling a 500.
        raise jwt.InvalidKeyError(f"Signing key not found: {exc}") from exc
    payload = cast(
        dict[str, object],
        jwt.decode(
            token,
            signing_key,
            algorithms=list(_ACCEPTED_ALGORITHMS),
            options={
                # Audience check intentionally skipped: BA defaults `aud` to
                # the same value as `iss` (the SvelteKit origin), which we
                # may or may not be (FastAPI is on a different origin).
                "verify_aud": False,
                # PyJWT only checks `exp` when present. Requiring it means a
                # token minted without an expiry is rejected, not accepted
                # forever.
                "require": ["exp"],
            },
        ),
    )
    return _parse_payload(payload)


async def get_current_user(
    credentials: Annotated[HTTPAuthorizationCredentials, Depends(_bearer)],
    settings: Annotated[Settings, Depends(get_settings)],
) -> CurrentUser:
    """FastAPI dependency: validate a Better Auth JWT and return `CurrentUser`.

    Behavior:
    - Verifies the signature against the cached JWKS.
    - On a signing-key mismatch only (unknown `kid`, bad signature, unusable
      key set), force-refreshes the JWKS once and retries. This handles BA key
      rotation.
    - Expired, malformed, or claim-invalid tokens are rejected immediately,
      so they cannot trigger a JWKS fetch on every request.
    - Returns 401 on invalid / expired / unsigned tokens.
    - Returns 503 when the JWKS endpoint cannot be fetched or returns an
      unusable body.
    """
    token = credentials.credentials
    jwks_url = settings.BETTER_AUTH_JWKS_URL

    jwks = await _load_jwks_or_503(jwks_url)
    try:
        return _decode_bearer_token(token, jwks)
    except _KEY_MISMATCH_ERRORS:
        # Possibly a key rotation since the JWKS was cached: refresh below.
        pass
    except jwt.PyJWTError as exc:
        # A fresh JWKS cannot fix an expired or malformed token.
        raise _unauthorized() from exc

    refreshed_jwks = await _load_jwks_or_503(jwks_url, force_refresh=True)
    try:
        return _decode_bearer_token(token, refreshed_jwks)
    except jwt.PyJWTError as exc:
        raise _unauthorized() from exc
