#!/usr/bin/env python3
import asyncio
import os
import sys

import asyncpg

# Tables in the ``public`` schema that main() audits. Each must have row-level
# security enabled, FORCE ROW LEVEL SECURITY set, and a policy named
# ``tenant_isolation``. The list order is the order in which tables missing
# that policy are reported. Names are passed to PostgreSQL as a text[]
# parameter and compared against catalog names, so SQL keywords such as
# "order" need no quoting here.
TENANT_TABLES = [
    "zone",
    "floor_table",
    "reservation",
    "customer",
    "order",
    "menu_item",
    "service_block",
    "service_block_override",
    "table_combination",
    "conversation",
    "knowledge_document",
    "faq_entry",
    "common_question",
    "chair",
    "order_item",
    "message",
    "service_block_zones",
    "combined_chair_config",
    "notification",
    "whatsapp_account",
    "whatsapp_template",
]


async def main() -> int:
    """Audit row-level security (RLS) on every table in ``TENANT_TABLES``.

    For each tenant table in the ``public`` schema, verify that:

    * the table exists (as an ordinary or a partitioned table),
    * RLS is enabled on it (``pg_class.relrowsecurity``),
    * FORCE ROW LEVEL SECURITY is set (``pg_class.relforcerowsecurity``),
    * a policy named ``tenant_isolation`` is defined on it (``pg_policies``).

    The DSN is read from the ``NEON_DATABASE_URL_DIRECT`` environment
    variable, falling back to the first command-line argument.

    Returns:
        Process exit code: 0 if every check passed, 1 if any tenant table
        failed a check, 2 if no DSN was supplied.
    """
    dsn = os.getenv("NEON_DATABASE_URL_DIRECT") or (sys.argv[1] if len(sys.argv) > 1 else None)
    if not dsn:
        print("NEON_DATABASE_URL_DIRECT env var or DSN argument is required", file=sys.stderr)
        return 2

    conn = await asyncpg.connect(dsn)
    try:
        # Read the RLS flags straight from the catalog. relkind 'p' (partitioned
        # table) is included alongside 'r' (ordinary table): RLS can be enabled
        # on a partitioned parent, and filtering it out would silently skip the
        # RLS/FORCE checks for that table.
        rows = await conn.fetch(
            """
            SELECT c.relname, c.relrowsecurity, c.relforcerowsecurity
            FROM pg_class c
            JOIN pg_namespace n ON n.oid = c.relnamespace
            WHERE n.nspname = 'public' AND c.relkind IN ('r', 'p')
              AND c.relname = ANY($1::text[])
            ORDER BY c.relname
            """,
            TENANT_TABLES,
        )
        found = {r["relname"] for r in rows}
        # A tenant table that is absent (dropped, renamed, or mistyped in the
        # list) must fail the audit explicitly instead of being skipped.
        missing_tables = [t for t in TENANT_TABLES if t not in found]
        missing_rls = [r["relname"] for r in rows if not r["relrowsecurity"]]
        missing_force = [r["relname"] for r in rows if not r["relforcerowsecurity"]]

        # Collect every policy name defined on each tenant table.
        pol_rows = await conn.fetch(
            """
            SELECT tablename, policyname
            FROM pg_policies
            WHERE schemaname = 'public' AND tablename = ANY($1::text[])
            ORDER BY tablename
            """,
            TENANT_TABLES,
        )
        policy_by_table: dict[str, set[str]] = {}
        for r in pol_rows:
            policy_by_table.setdefault(r["tablename"], set()).add(r["policyname"])
        # Tables already reported as not found are excluded here so each
        # problem is reported once.
        missing_policy = [
            t
            for t in TENANT_TABLES
            if t in found and "tenant_isolation" not in policy_by_table.get(t, set())
        ]

        failures = [
            ("Tables not found in schema public:", missing_tables),
            ("Tables missing RLS enabled:", missing_rls),
            ("Tables missing FORCE RLS:", missing_force),
            ("Tables missing tenant_isolation policy:", missing_policy),
        ]
        exit_code = 0
        for label, tables in failures:
            if tables:
                print(label, ", ".join(tables))
                exit_code = 1

        if exit_code == 0:
            print("RLS audit passed: all tenant tables have RLS + FORCE + policy")
        return exit_code
    finally:
        await conn.close()


if __name__ == "__main__":
    # Run the async audit and hand its integer result to the interpreter as
    # the process exit status (sys.exit raises SystemExit with that value).
    sys.exit(asyncio.run(main()))
