"""Database session factories — one pool, two effective roles via `SET LOCAL ROLE`.

Architecture:

- The connection pool always logs in as `neondb_owner` (BYPASSRLS) via
  `NEON_DATABASE_URL`. Server-internal traffic (Restate handlers, cron
  sweeps, public/webhook slug lookups) runs as that role and filters by
  `restaurant_id` in application code.
- For user-facing tenant traffic, the `after_begin` hook drops the
  current transaction's effective role to `authenticated` (a Neon-
  provisioned NOBYPASSRLS role that `neondb_owner` is automatically a
  member of) and injects the application-validated Better Auth JWT
  claims into the `request.jwt.claims` GUC, which the `pg_session_jwt`
  extension reads in its PostgREST-compatible fallback mode. RLS
  policies on tenant tables then key on
  `auth.session()->>'activeRestaurantId'`.

The trust boundary is the FastAPI dependency that calls
`get_tenant_session`. By the time we set `session.info["jwt_claims"]`,
the JWT signature has already been verified against Better Auth's JWKS
via `app.auth.better_auth.get_current_user`. The DB-level RLS layer
exists as defense in depth: a future bug that omits a
`.where(Model.restaurant_id == …)` clause in any tenant route will still
fail closed because the policy filters every row that doesn't match the
claim.

The `after_begin` listener reapplies on every transaction begin so
mid-request commits and savepoints cannot leak cross-tenant rows.
"""

import json
import os
from collections.abc import AsyncGenerator
from typing import Any
from urllib.parse import parse_qs, urlencode, urlparse, urlunparse

from sqlalchemy import event, text
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.orm import Session
from sqlmodel import SQLModel

_ASYNCPG_UNSUPPORTED = {"sslmode", "channel_binding"}


def _build_db_url(raw: str) -> str:
    """Normalise a Neon connection string for asyncpg.

    asyncpg does not understand libpq query parameters such as ``sslmode``
    or ``channel_binding``. We strip them and pass SSL via ``connect_args``.
    """
    url = raw.replace("postgresql://", "postgresql+asyncpg://", 1)
    parsed = urlparse(url)
    qs = parse_qs(parsed.query, keep_blank_values=True)
    for param in _ASYNCPG_UNSUPPORTED:
        qs.pop(param, None)
    cleaned = parsed._replace(query=urlencode(qs, doseq=True))
    return urlunparse(cleaned)


_DATABASE_URL = _build_db_url(os.environ.get("NEON_DATABASE_URL", ""))

# One shared connection pool for every session factory in this module.
# TLS is requested via ``connect_args`` because ``_build_db_url`` strips the
# libpq ``sslmode``/``channel_binding`` query parameters asyncpg rejects.
engine = create_async_engine(
    _DATABASE_URL,
    echo=False,
    connect_args={"ssl": "require"},
    pool_size=5,
    max_overflow=10,
    pool_pre_ping=True,
)

# Sessions created here do not expire loaded attributes on commit
# (``expire_on_commit=False``).
async_session_factory = async_sessionmaker(
    bind=engine,
    class_=AsyncSession,
    expire_on_commit=False,
)


def _inject_tenant_jwt_claims(
    session: Session,
    transaction: Any,  # noqa: ARG001
    connection: Any,
) -> None:
    """SQLAlchemy ``after_begin`` event: scope the transaction to the
    `authenticated` role and feed pg_session_jwt the request's claims.

    Fires on every new transaction begin (including post-commit re-begins
    and savepoints). If ``session.info["jwt_claims"]`` is set (not
    ``None``), we:

      1. ``SET LOCAL ROLE authenticated`` — drops effective privileges to
         the Neon-provisioned NOBYPASSRLS role for the duration of the
         transaction. ``neondb_owner`` (our login role) is automatically
         a member of ``authenticated`` when the Data API is enabled on
         the branch, so the SET ROLE succeeds without any user-managed
         grants.
      2. ``set_config('request.jwt.claims', <json>, true)`` — populates
         the GUC that ``pg_session_jwt`` reads in fallback mode (when no
         JWK is configured at the cluster level). RLS policies then
         resolve ``auth.session()->>'activeRestaurantId'`` from the
         injected claims.

    Sessions with no ``"jwt_claims"`` entry in ``session.info`` (or one
    holding ``None``) are a no-op. The transaction stays on
    ``neondb_owner`` (BYPASSRLS), which is exactly what server-internal
    traffic needs.

    A *present but empty* value (``{}``, ``""``) is deliberately NOT
    treated as "no tenant context". The role is still dropped to
    ``authenticated`` with those empty claims. If a bug leaves a tenant
    session with empty claims, the transaction runs under RLS with no
    tenant id (failing closed). It does not silently keep BYPASSRLS
    privileges.

    Args:
        session: The session whose transaction just began; only its
            ``info`` dict is read.
        transaction: The transaction object SQLAlchemy passes to
            ``after_begin`` listeners; unused.
        connection: The connection the transaction was begun on; the role
            switch and claims injection are executed on it.

    Raises:
        TypeError: if ``jwt_claims`` is neither a ``str`` nor
            JSON-serialisable (raised by ``json.dumps``).
    """
    claims = session.info.get("jwt_claims")
    # Only an absent/None value means "server-internal session". Using a
    # truthiness check here would let empty claims fall through to the
    # BYPASSRLS login role (fail open).
    if claims is None:
        return
    # Combine the role drop and the JWT-claims injection into a single
    # SELECT: two `set_config()` calls in one statement cost one network
    # round-trip instead of two. On a typical Neon connection that saves
    # ~30-40 ms per transaction (each separate statement pays one full
    # request/response RTT). With 4-5 parallel API calls per page render,
    # the saving compounds to ~150 ms off the user-perceived page load.
    #
    # `set_config('role', 'authenticated', true)` is equivalent to
    # `SET LOCAL ROLE authenticated`: it sets the GUC that controls the
    # effective role for the rest of the transaction. neondb_owner is a
    # member of `authenticated`, so the role switch succeeds without any
    # extra grants. `is_local=true` confines both settings to the current
    # transaction; commit/rollback restores neondb_owner automatically.
    claims_text = claims if isinstance(claims, str) else json.dumps(claims)
    connection.execute(
        text(
            "SELECT set_config('role', 'authenticated', true), "
            "set_config('request.jwt.claims', :claims, true)"
        ),
        {"claims": claims_text},
    )


# Registered on the ``Session`` class rather than on a single instance, so
# the hook runs for every session. The hook itself decides, per session,
# whether tenant scoping applies.
event.listen(
    target=Session,
    identifier="after_begin",
    fn=_inject_tenant_jwt_claims,
)


async def get_session() -> AsyncGenerator[AsyncSession, None]:
    """Yield a session from the owner pool (BYPASSRLS by default).

    Intended for paths that genuinely need to see across tenants:
      - the ``get_current_restaurant`` membership lookup, whose Team /
        TeamMember / Restaurant joins touch tables that are not
        tenant-scoped;
      - public-by-slug resolvers;
      - admin tooling and health probes.

    Nothing is committed automatically. Write handlers call
    ``commit_write`` themselves, so the transaction owner is visible at the
    call site. The ``async with`` block releases the session once the
    generator finishes.
    """
    async with async_session_factory() as owner_session:
        yield owner_session


async def get_bypass_session() -> AsyncGenerator[AsyncSession, None]:
    """Yield an owner-pool session for deliberate cross-tenant access (BYPASSRLS).

    Reserve this for callers with a documented reason to cross tenant
    boundaries, e.g. Restate cleanup sweeps or Mollie/WhatsApp webhook
    lookups by id. Before mutating tenant data, each call site should
    narrow its queries with an explicit ``WHERE restaurant_id = …`` taken
    from the caller's payload.
    """
    async with async_session_factory() as bypass_session:
        yield bypass_session


async def get_restate_session() -> AsyncGenerator[AsyncSession, None]:
    """Yield an owner-pool session for Restate handlers (no tenant context).

    Handlers get ``restaurant_id`` in their payload and filter on it in
    application code. Their protection comes from the invocation envelope,
    not from RLS.
    """
    async with async_session_factory() as handler_session:
        yield handler_session


async def commit_write(session: AsyncSession) -> None:
    """Commit ``session``'s current transaction.

    A thin named wrapper, so that write flows state explicitly where they
    take ownership of the commit.
    """
    return await session.commit()


async def create_db_and_tables() -> None:
    """Create every table registered on ``SQLModel.metadata``, bypassing Alembic.

    Development convenience only; production schemas are not managed here.
    """
    metadata = SQLModel.metadata
    async with engine.begin() as connection:
        await connection.run_sync(metadata.create_all)
