from __future__ import annotations

import asyncio
import json
from datetime import UTC, datetime, timedelta
from typing import Any
from unittest.mock import patch

import jwt
import pytest
from cryptography.hazmat.primitives.asymmetric import rsa
from fastapi import HTTPException
from fastapi.security import HTTPAuthorizationCredentials

from app.auth import better_auth
from app.config import Settings


@pytest.fixture(autouse=True)
def reset_jwks_cache():
    """Isolate every test from the module-level JWKS state in ``better_auth``.

    Resets ``_jwks_cache`` and ``_jwks_refresh_task`` to ``None`` both before the
    test body runs and after it finishes, so no test sees keys fetched by another.
    """

    def _clear_state() -> None:
        better_auth._jwks_cache = None
        better_auth._jwks_refresh_task = None

    _clear_state()
    yield
    _clear_state()


def _settings() -> Settings:
    """Return a ``Settings`` instance populated with placeholder test values.

    ``BETTER_AUTH_URL`` is the same origin that ``_token`` writes into the ``iss``
    claim; the database and Redis URLs are localhost placeholders.
    """
    auth_origin = "https://auth.example.com"
    return Settings(
        BETTER_AUTH_URL=auth_origin,
        NEON_DATABASE_URL="postgresql://user:pass@localhost:5432/testdb",
        REDIS_URL="redis://localhost:6379/0",
    )


def _credentials(token: str) -> HTTPAuthorizationCredentials:
    """Wrap *token* in ``HTTPAuthorizationCredentials`` using the ``Bearer`` scheme."""
    scheme = "Bearer"
    return HTTPAuthorizationCredentials(credentials=token, scheme=scheme)


def _rsa_keypair(*, kid: str) -> tuple[Any, dict[str, object]]:
    """Generate a fresh 2048-bit RSA key pair for signing test tokens.

    Returns ``(private_key, jwk)`` where ``jwk`` is the public key exported by
    PyJWT as a JSON Web Key, tagged with the given ``kid`` and with
    ``use="sig"`` / ``alg="RS256"`` so it can be served from a fake JWKS endpoint.
    """
    signing_key = rsa.generate_private_key(public_exponent=65537, key_size=2048)
    exported = jwt.algorithms.RSAAlgorithm.to_jwk(signing_key.public_key())  # type: ignore[attr-defined]
    public_jwk: dict[str, object] = {
        **json.loads(exported),
        "kid": kid,
        "use": "sig",
        "alg": "RS256",
    }
    return signing_key, public_jwk


def _token(
    private_key: Any,
    *,
    kid: str,
    subject: str = "user-123",
    email: str = "owner@example.com",
    email_verified: bool = True,
    active_org_id: str | None = "org-123",
    active_team_id: str | None = "team-123",
    active_restaurant_id: str | None = "restaurant-123",
    role: str | None = "owner",
    expires_at: datetime | None = None,
) -> str:
    """Sign an RS256 JWT shaped like a Better Auth session token.

    The header carries ``kid`` so the validator can select a key from the JWKS.
    ``iss``/``aud`` are fixed to the test auth/frontend origins, ``iat`` is now,
    and ``exp`` defaults to five minutes from now unless *expires_at* is given.
    """

    def _epoch(moment: datetime) -> int:
        return int(moment.timestamp())

    issued_at = datetime.now(UTC)
    expiry = expires_at if expires_at is not None else issued_at + timedelta(minutes=5)

    claims: dict[str, object] = {
        "sub": subject,
        "email": email,
        "emailVerified": email_verified,
        "activeOrganizationId": active_org_id,
        "activeTeamId": active_team_id,
        "activeRestaurantId": active_restaurant_id,
        "role": role,
        "iss": "https://auth.example.com",
        "aud": "https://frontend.example.com",
    }
    # Registered time claims go last to keep the same claim order as before.
    claims["iat"] = _epoch(issued_at)
    claims["exp"] = _epoch(expiry)
    return jwt.encode(claims, private_key, algorithm="RS256", headers={"kid": kid})


class _Response:
    """Tiny stand-in for an HTTP response that always succeeds.

    Exposes only the two methods the code under test is exercised with here:
    ``raise_for_status`` (a no-op) and ``json`` (returns the stored payload).
    """

    def __init__(self, payload: dict[str, object]) -> None:
        self._payload = payload

    def raise_for_status(self) -> None:
        """Do nothing: the fake response never represents an HTTP error."""

    def json(self) -> dict[str, object]:
        """Return the payload object given at construction (not a copy)."""
        return self._payload


class _AsyncClientFactory:
    """Fake replacement for ``httpx.AsyncClient`` used via ``unittest.mock.patch``.

    Calling the instance (as ``httpx.AsyncClient(...)`` would be called) returns
    the instance itself, so every "client" created during a test shares one
    ``calls`` counter. ``get`` serves *payloads* in order and keeps returning the
    last payload once the list is exhausted, optionally sleeping *delay* seconds
    first so concurrent callers overlap.
    """

    def __init__(self, payloads: list[dict[str, object]], *, delay: float = 0.0) -> None:
        self._payloads = payloads
        self._delay = delay
        self.calls = 0  # total get() invocations across all "clients"

    def __call__(self, *args: object, **kwargs: object) -> _AsyncClientFactory:
        # Constructor arguments (timeouts etc.) are accepted and ignored.
        return self

    async def __aenter__(self) -> _AsyncClientFactory:
        return self

    async def __aexit__(self, *_args: object) -> bool:
        # Returning False lets any exception raised in the ``async with`` propagate.
        return False

    async def get(self, _url: str) -> _Response:
        self.calls += 1
        if self._delay:
            await asyncio.sleep(self._delay)
        # Read the counter only after the sleep; the last payload repeats forever.
        position = min(self.calls, len(self._payloads)) - 1
        return _Response(self._payloads[position])


@pytest.mark.asyncio
async def test_valid_jwt_populates_current_user() -> None:
    """A well-formed, unexpired token yields a fully populated ``CurrentUser``.

    The happy path should need exactly one JWKS fetch.
    """
    signing_key, public_jwk = _rsa_keypair(kid="kid-1")
    fake_client = _AsyncClientFactory([{"keys": [public_jwk]}])
    bearer = _credentials(_token(signing_key, kid="kid-1"))
    expected = better_auth.CurrentUser(
        id="user-123",
        email="owner@example.com",
        email_verified=True,
        active_org_id="org-123",
        active_team_id="team-123",
        active_restaurant_id="restaurant-123",
        role="owner",
    )

    with patch("app.auth.better_auth.httpx.AsyncClient", fake_client):
        user = await better_auth.get_current_user(credentials=bearer, settings=_settings())

    assert user == expected
    assert fake_client.calls == 1


@pytest.mark.asyncio
async def test_expired_jwt_returns_401() -> None:
    """An expired token is rejected with a 401 and a generic message.

    The test expects two JWKS fetches before the rejection, so the validator
    refreshes the key set once before giving up.
    """
    signing_key, public_jwk = _rsa_keypair(kid="kid-1")
    one_second_ago = datetime.now(UTC) - timedelta(seconds=1)
    expired_token = _token(signing_key, kid="kid-1", expires_at=one_second_ago)
    fake_client = _AsyncClientFactory([{"keys": [public_jwk]}] * 2)

    with (
        patch("app.auth.better_auth.httpx.AsyncClient", fake_client),
        pytest.raises(HTTPException) as raised,
    ):
        await better_auth.get_current_user(
            credentials=_credentials(expired_token),
            settings=_settings(),
        )

    error = raised.value
    assert error.status_code == 401
    assert error.detail == "Invalid or expired token"
    assert fake_client.calls == 2


@pytest.mark.asyncio
async def test_unknown_kid_triggers_refresh_then_401_not_500() -> None:
    """Regression: a JWT signed with a `kid` missing from the cached JWKS must
    flow through the refresh-then-401 path, not raise a bare KeyError/500.

    Only kid-a's private key (to sign) and kid-b's public JWK (to serve) are
    needed; the other halves of the two key pairs are discarded.
    """
    private_key_a, _ = _rsa_keypair(kid="kid-a")
    _, jwk_b = _rsa_keypair(kid="kid-b")
    # Token signed with kid-a, but JWKS will only contain kid-b for both fetches.
    token = _token(private_key_a, kid="kid-a")
    client_factory = _AsyncClientFactory([{"keys": [jwk_b]}, {"keys": [jwk_b]}])

    with (
        patch("app.auth.better_auth.httpx.AsyncClient", client_factory),
        pytest.raises(HTTPException) as exc_info,
    ):
        await better_auth.get_current_user(
            credentials=_credentials(token),
            settings=_settings(),
        )

    assert exc_info.value.status_code == 401
    # One initial fetch plus one refresh attempt after the kid lookup misses.
    assert client_factory.calls == 2


@pytest.mark.asyncio
async def test_missing_sub_claim_returns_401() -> None:
    """A correctly signed token without a ``sub`` claim is rejected with 401.

    The token carries the same ``iss`` and ``aud`` that ``_token`` stamps on the
    tokens accepted in ``test_valid_jwt_populates_current_user``. Without them the
    token could be rejected for a missing issuer/audience instead, and this test
    would pass even if the subject were never checked.
    """
    private_key, jwk = _rsa_keypair(kid="kid-1")
    now = datetime.now(UTC)
    token = jwt.encode(
        {
            "email": "owner@example.com",
            "emailVerified": True,
            "iss": "https://auth.example.com",
            "aud": "https://frontend.example.com",
            "iat": int(now.timestamp()),
            "exp": int((now + timedelta(minutes=5)).timestamp()),
        },
        private_key,
        algorithm="RS256",
        headers={"kid": "kid-1"},
    )
    client_factory = _AsyncClientFactory([{"keys": [jwk]}, {"keys": [jwk]}])

    with (
        patch("app.auth.better_auth.httpx.AsyncClient", client_factory),
        pytest.raises(HTTPException) as exc_info,
    ):
        await better_auth.get_current_user(
            credentials=_credentials(token),
            settings=_settings(),
        )

    assert exc_info.value.status_code == 401
    assert exc_info.value.detail == "Invalid or expired token"
    assert client_factory.calls == 2


@pytest.mark.asyncio
async def test_concurrent_cold_start_shares_single_jwks_fetch() -> None:
    """Eight simultaneous requests against an empty cache share one JWKS fetch.

    The fake client sleeps 50 ms inside each fetch so the requests overlap while
    the first fetch is still in flight.
    """
    signing_key, public_jwk = _rsa_keypair(kid="kid-1")
    fake_client = _AsyncClientFactory([{"keys": [public_jwk]}], delay=0.05)
    bearer = _credentials(_token(signing_key, kid="kid-1"))
    settings = _settings()
    concurrent_requests = 8

    with patch("app.auth.better_auth.httpx.AsyncClient", fake_client):
        results = await asyncio.gather(
            *(
                better_auth.get_current_user(credentials=bearer, settings=settings)
                for _ in range(concurrent_requests)
            )
        )

    assert {result.id for result in results} == {"user-123"}
    assert fake_client.calls == 1


@pytest.mark.asyncio
async def test_key_rotation_refreshes_and_retries() -> None:
    """A token signed after a key rotation that reused the same ``kid`` still validates.

    Both key pairs use ``kid="shared-kid"``. The first JWKS fetch serves the old
    public key, which cannot verify the token. The second serves the new one, so
    the test expects success after exactly two fetches. Only the old pair's public
    JWK is needed.
    """
    _, old_jwk = _rsa_keypair(kid="shared-kid")
    new_private_key, new_jwk = _rsa_keypair(kid="shared-kid")
    token = _token(new_private_key, kid="shared-kid", role="admin")
    client_factory = _AsyncClientFactory([{"keys": [old_jwk]}, {"keys": [new_jwk]}])

    with patch("app.auth.better_auth.httpx.AsyncClient", client_factory):
        user = await better_auth.get_current_user(
            credentials=_credentials(token),
            settings=_settings(),
        )

    assert user.role == "admin"
    assert user.active_team_id == "team-123"
    assert client_factory.calls == 2
