"""RLS isolation integration tests.

Exercises the JWT-claim-based tenant isolation introduced by Alembic
revision ``0030_rls_via_pg_session_jwt`` and the
``_inject_tenant_jwt_claims`` event hook in ``app.db.session``.

Architecture under test:
- Backend connects with one pool as ``neondb_owner`` (BYPASSRLS).
- The ``after_begin`` SQLAlchemy hook runs ``SET LOCAL ROLE
  authenticated`` plus ``set_config('request.jwt.claims', …, true)`` on
  every transaction begin for a user-facing request.
- ``pg_session_jwt`` reads the claims from the GUC (PostgREST-compatible
  fallback mode); RLS policies key on
  ``auth.session() ->> 'activeRestaurantId'``.

Requirements:
- ``RUN_RLS_TESTS=1`` (opt-in; the suite isn't free — it talks to Neon).
- ``NEON_DATABASE_URL_DIRECT`` set to the owner-pool direct DSN.
- ``TEST_TENANT_ID`` + ``TEST_TENANT_ID_2`` — two real restaurant ids on
  the branch.
- Branch is at alembic ``0030_rls_via_pg_session_jwt`` or later.
- Branch has Data API enabled (so ``authenticated`` role + membership
  exists). See ``backend/OPERATIONS.md`` → "Neon RLS" §1.
"""

import json
import os
import uuid
from datetime import UTC, datetime

import asyncpg
import pytest

# Opt-in switch plus the connection / tenant prerequisites, all read from the
# process environment at import time.
RUN = os.environ.get("RUN_RLS_TESTS") == "1"
OWNER_DSN = os.environ.get("NEON_DATABASE_URL_DIRECT")
RID1 = os.environ.get("TEST_TENANT_ID")
RID2 = os.environ.get("TEST_TENANT_ID_2")

# Module-wide skip: every test here is skipped unless the suite is enabled,
# the owner DSN is set, and two *different* tenant ids are provided.
pytestmark = pytest.mark.skipif(
    not (RUN and OWNER_DSN and RID1 and RID2 and RID1 != RID2),
    reason=(
        "RLS tests require RUN_RLS_TESTS=1 + NEON_DATABASE_URL_DIRECT + two "
        "distinct TEST_TENANT_ID / TEST_TENANT_ID_2 env vars"
    ),
)


def _claims(restaurant_id: str | None, sub: str = "test-user") -> str:
    payload: dict[str, object] = {"sub": sub, "role": "owner"}
    if restaurant_id is not None:
        payload["activeRestaurantId"] = restaurant_id
    return json.dumps(payload)


@pytest.mark.asyncio
async def test_tenant_isolation_read_write() -> None:
    """SET LOCAL ROLE authenticated + claim for A → only A's rows visible.

    Seeds one ``common_question`` row for tenant ``RID1`` over the owner
    connection, then reads it back by id in two separate transactions that
    switch to the ``authenticated`` role: once with claims scoped to
    ``RID1`` (row must be returned) and once with claims scoped to ``RID2``
    (nothing must be returned). The seeded row is deleted afterwards.
    """
    assert RID1 and RID2 and OWNER_DSN

    conn = await asyncpg.connect(OWNER_DSN)
    row_id = str(uuid.uuid4())
    try:
        # Seed a row under tenant A as the owner role (BYPASSRLS).
        await conn.execute(
            "INSERT INTO common_question (id, restaurant_id, question_key, "
            "question_text, is_answered, created_at) "
            "VALUES ($1, $2, $3, $4, $5, $6)",
            row_id,
            RID1,
            "rls-test-" + row_id[:8],
            "Isolation probe",
            False,
            # Naive UTC timestamp — presumably created_at is a
            # timezone-less column; confirm against the schema.
            datetime.now(UTC).replace(tzinfo=None),
        )

        # Read under SET LOCAL ROLE authenticated + claim scoped to A.
        # Both settings are transaction-local, so they end with the block.
        async with conn.transaction():
            await conn.execute("SET LOCAL ROLE authenticated")
            await conn.execute(
                "SELECT set_config('request.jwt.claims', $1, true)",
                _claims(RID1),
            )
            visible = await conn.fetch("SELECT id FROM common_question WHERE id = $1", row_id)
        assert len(visible) == 1, "tenant A should see its own row"

        # Read under SET LOCAL ROLE authenticated + claim scoped to B.
        async with conn.transaction():
            await conn.execute("SET LOCAL ROLE authenticated")
            await conn.execute(
                "SELECT set_config('request.jwt.claims', $1, true)",
                _claims(RID2),
            )
            leaked = await conn.fetch("SELECT id FROM common_question WHERE id = $1", row_id)
        assert leaked == [], "tenant B must NOT see tenant A's row"
    finally:
        # Cleanup runs outside any transaction, i.e. back on the connection's
        # login role. Close the connection even if the DELETE itself fails,
        # so a cleanup error cannot leak the connection.
        try:
            await conn.execute("DELETE FROM common_question WHERE id = $1", row_id)
        finally:
            await conn.close()


@pytest.mark.asyncio
async def test_fail_closed_without_claims() -> None:
    """Missing `activeRestaurantId` claim → policy returns zero rows.

    Seeds one row for ``RID1`` over the owner connection, then probes it as
    ``authenticated`` in two transactions: one whose claims object lacks
    ``activeRestaurantId``, and one that sets no claims in the transaction
    at all. Both must return nothing. The seeded row is deleted afterwards.
    """
    assert RID1 and OWNER_DSN

    conn = await asyncpg.connect(OWNER_DSN)
    row_id = str(uuid.uuid4())
    try:
        await conn.execute(
            "INSERT INTO common_question (id, restaurant_id, question_key, "
            "question_text, is_answered, created_at) "
            "VALUES ($1, $2, $3, $4, $5, $6)",
            row_id,
            RID1,
            "rls-fail-closed-" + row_id[:8],
            "Fail-closed probe",
            False,
            datetime.now(UTC).replace(tzinfo=None),
        )

        # Claims set but no activeRestaurantId.
        async with conn.transaction():
            await conn.execute("SET LOCAL ROLE authenticated")
            await conn.execute(
                "SELECT set_config('request.jwt.claims', $1, true)",
                _claims(None),
            )
            rows = await conn.fetch("SELECT id FROM common_question WHERE id = $1", row_id)
        assert rows == [], "missing claim must return zero rows (fail-closed)"

        # No claims object at all.
        # NOTE(review): this session already ran a transaction-local
        # set_config for request.jwt.claims above, so the GUC is presumably
        # reset to an empty string here rather than being truly undefined —
        # confirm whether a never-set session needs its own probe.
        async with conn.transaction():
            await conn.execute("SET LOCAL ROLE authenticated")
            rows = await conn.fetch("SELECT id FROM common_question WHERE id = $1", row_id)
        assert rows == [], "no JWT context at all must also return zero rows (fail-closed)"
    finally:
        # Close the connection even if the cleanup DELETE fails, so a
        # cleanup error cannot leak the connection.
        try:
            await conn.execute("DELETE FROM common_question WHERE id = $1", row_id)
        finally:
            await conn.close()


@pytest.mark.asyncio
async def test_with_check_rejects_mismatched_insert() -> None:
    """INSERT with restaurant_id ≠ claim → WITH CHECK rejects.

    Runs as ``authenticated`` with claims scoped to ``RID1`` and tries to
    insert a row whose ``restaurant_id`` is ``RID2``. The insert must fail
    with a row-level-security violation specifically; any other database
    error (constraint violation, missing table, missing grant) fails the
    test instead of being mistaken for a policy rejection.
    """
    assert RID1 and RID2 and OWNER_DSN

    conn = await asyncpg.connect(OWNER_DSN)
    bad_id = str(uuid.uuid4())
    try:
        # No cleanup is needed: if the insert unexpectedly succeeds,
        # pytest.raises fails inside the transaction block and the
        # resulting exception rolls the transaction back.
        async with conn.transaction():
            await conn.execute("SET LOCAL ROLE authenticated")
            await conn.execute(
                "SELECT set_config('request.jwt.claims', $1, true)",
                _claims(RID1),
            )
            # PostgreSQL reports policy violations as SQLSTATE 42501
            # (insufficient_privilege). A plain "permission denied" shares
            # that code, so the message is matched as well.
            with pytest.raises(
                asyncpg.InsufficientPrivilegeError,
                match="row-level security",
            ):
                await conn.execute(
                    "INSERT INTO common_question "
                    "(id, restaurant_id, question_key, question_text, "
                    "is_answered, created_at) "
                    "VALUES ($1, $2, $3, $4, $5, $6)",
                    bad_id,
                    RID2,  # mismatched: claim says RID1, row says RID2
                    "rls-with-check-" + bad_id[:8],
                    "should be blocked",
                    False,
                    datetime.now(UTC).replace(tzinfo=None),
                )
    finally:
        await conn.close()


@pytest.mark.asyncio
async def test_owner_pool_bypasses_rls_for_internal_traffic() -> None:
    """Without SET LOCAL ROLE, queries run as neondb_owner (BYPASSRLS).

    This is the intended path for Restate handlers, cron sweeps, and
    webhook-by-id lookups — they have no user JWT and rely on
    application-layer ``WHERE restaurant_id = …`` filters.

    The test seeds one row for ``RID1``, checks that the connection's
    current role is ``neondb_owner`` with ``rolbypassrls`` set, and then
    reads the row back without setting any role or claims. The seeded row
    is deleted afterwards.
    """
    assert RID1 and OWNER_DSN

    conn = await asyncpg.connect(OWNER_DSN)
    row_id = str(uuid.uuid4())
    try:
        await conn.execute(
            "INSERT INTO common_question (id, restaurant_id, question_key, "
            "question_text, is_answered, created_at) "
            "VALUES ($1, $2, $3, $4, $5, $6)",
            row_id,
            RID1,
            "rls-bypass-" + row_id[:8],
            "Bypass probe",
            False,
            datetime.now(UTC).replace(tzinfo=None),
        )

        # Default role is neondb_owner — BYPASSRLS in effect.
        rolname = await conn.fetchval("SELECT current_user")
        bypass = await conn.fetchval("SELECT rolbypassrls FROM pg_roles WHERE rolname=current_user")
        assert rolname == "neondb_owner"
        assert bypass is True

        rows = await conn.fetch("SELECT id FROM common_question WHERE id = $1", row_id)
        assert len(rows) == 1, "owner pool must see the row regardless of claims"
    finally:
        # Close the connection even if the cleanup DELETE fails, so a
        # cleanup error cannot leak the connection.
        try:
            await conn.execute("DELETE FROM common_question WHERE id = $1", row_id)
        finally:
            await conn.close()


@pytest.mark.asyncio
async def test_authenticated_role_is_nobypassrls_and_member_of_owner() -> None:
    """Sanity check: the SET ROLE target must be NOBYPASSRLS, and
    neondb_owner must be a member of it (the Data API enable step does
    both — this test catches a branch where step 1 was skipped).

    Three conditions are checked in order, each with its own remediation
    message: the ``authenticated`` role exists, its ``rolbypassrls`` flag
    is false, and ``neondb_owner`` has a direct membership row for it in
    ``pg_auth_members``.
    """
    assert OWNER_DSN

    conn = await asyncpg.connect(OWNER_DSN)
    try:
        bypass = await conn.fetchval(
            "SELECT rolbypassrls FROM pg_roles WHERE rolname = 'authenticated'"
        )
        # fetchval yields None when no pg_roles row matched; report that as
        # a missing role rather than as a wrong BYPASSRLS attribute.
        assert bypass is not None, (
            "authenticated role does not exist on this branch — SET LOCAL ROLE "
            "will fail. Enable the Data API on this branch (Neon Console → Data "
            "API → Enable). Setup guide: backend/OPERATIONS.md → 'Neon RLS' §1."
        )
        assert bypass is False, (
            "authenticated role has rolbypassrls=true — RLS will NOT enforce. "
            "Open a Neon support ticket — the role attribute is platform-managed."
        )

        # Direct membership only: looks for a pg_auth_members row whose
        # granted role is authenticated and whose member is neondb_owner.
        is_member = await conn.fetchval(
            """
            SELECT EXISTS (
                SELECT 1
                  FROM pg_auth_members am
                  JOIN pg_roles r ON r.oid = am.roleid
                  JOIN pg_roles m ON m.oid = am.member
                 WHERE m.rolname = 'neondb_owner'
                   AND r.rolname = 'authenticated'
            )
            """
        )
        assert is_member, (
            "neondb_owner is not a member of authenticated — SET LOCAL ROLE will "
            "fail. Enable the Data API on this branch (Neon Console → Data API "
            "→ Enable). Setup guide: backend/OPERATIONS.md → 'Neon RLS' §1."
        )
    finally:
        await conn.close()
