from __future__ import annotations

import asyncio
import os
from datetime import UTC, datetime
from unittest import IsolatedAsyncioTestCase, TestCase

# Fallback settings for the test process. ``setdefault`` leaves any value that
# is already present in the environment untouched. These run before the
# ``app.realtime`` imports below -- presumably because the application reads
# this configuration at import time (TODO confirm against the app settings).
for _env_name, _env_fallback in (
    ("NEON_DATABASE_URL", "postgresql://user:pass@localhost:5432/testdb"),
    ("BETTER_AUTH_URL", "http://localhost:3000"),
    ("INTERNAL_EMAIL_SHARED_SECRET", "test-secret"),
    ("REDIS_URL", "redis://localhost:6379/0"),
):
    # The return value is bound to ``_`` only to mark it as deliberately unused.
    _ = os.environ.setdefault(_env_name, _env_fallback)

from app.realtime import SSE_HEADERS, sse_response  # noqa: E402
from app.realtime.connection_manager import MAX_QUEUE_SIZE, ConnectionManager  # noqa: E402
from app.realtime.events import DomainEvent, EventType  # noqa: E402


class TestDomainEventModel(TestCase):
    """Checks on the ``EventType`` enum values and the ``DomainEvent`` model."""

    def test_event_type_enum_values_are_stable(self) -> None:
        """Every listed ``EventType`` member has a value equal to its own name."""
        # Kept in the original assertion order so the first reported failure
        # is the same member as before.
        member_names = (
            "reservation_created",
            "reservation_updated",
            "reservation_status_changed",
            "reservation_cancelled",
            "order_created",
            "order_status_changed",
            "order_cancelled",
            "customer_updated",
            "customer_visit_recorded",
            "customer_noshow_recorded",
        )
        for member_name in member_names:
            self.assertEqual(getattr(EventType, member_name).value, member_name)

    def test_occurred_at_defaults_to_current_utc_time(self) -> None:
        """An event built without ``occurred_at`` gets an aware timestamp taken during construction."""
        lower_bound = datetime.now(UTC)
        created = DomainEvent(
            type=EventType.reservation_created,
            restaurant_id="rest-1",
            entity_id="res-1",
        )
        upper_bound = datetime.now(UTC)

        stamp = created.occurred_at
        # The timestamp must carry a tzinfo and fall between the two readings
        # taken immediately around the constructor call.
        self.assertIsNotNone(stamp.tzinfo)
        self.assertGreaterEqual(stamp, lower_bound)
        self.assertLessEqual(stamp, upper_bound)

    def test_payload_serialization_and_json_round_trip(self) -> None:
        """A nested payload survives ``model_dump_json`` followed by ``model_validate_json``."""
        nested_payload = {
            "reservation_number": "R-42",
            "covers": 4,
            "meta": {"source": "unit-test", "tags": ["vip", "window"]},
        }
        original = DomainEvent(
            type=EventType.reservation_updated,
            restaurant_id="rest-1",
            entity_id="res-42",
            payload=nested_payload,
        )

        restored = DomainEvent.model_validate_json(original.model_dump_json())

        # Field-by-field checks first, then the whole JSON-mode dump.
        self.assertEqual(restored.type, EventType.reservation_updated)
        self.assertEqual(restored.restaurant_id, "rest-1")
        self.assertEqual(restored.entity_id, "res-42")
        self.assertEqual(restored.payload, nested_payload)
        self.assertEqual(
            restored.model_dump(mode="json"),
            original.model_dump(mode="json"),
        )


class TestConnectionManager(IsolatedAsyncioTestCase):
    """Exercises ``ConnectionManager`` registration, disconnection and event enqueueing."""

    def test_register_creates_connection_and_increments_active_count(self) -> None:
        """``register`` returns a connection for the restaurant and counts it as active."""
        manager = ConnectionManager()

        registered = manager.register(restaurant_id="rest-1")

        self.assertEqual(registered.restaurant_id, "rest-1")
        self.assertIsInstance(registered.queue, asyncio.Queue)
        self.assertEqual(manager.active_count, 1)

    def test_disconnect_removes_connection_and_sends_sentinel(self) -> None:
        """``disconnect`` drops the active count to zero and leaves ``None`` in the queue."""
        manager = ConnectionManager()
        registered = manager.register(restaurant_id="rest-1")

        manager.disconnect(registered.id)

        self.assertEqual(manager.active_count, 0)
        # The item waiting in the queue after disconnect is expected to be None.
        sentinel = registered.queue.get_nowait()
        self.assertIsNone(sentinel)

    async def test_enqueue_delivers_only_to_matching_restaurant_id(self) -> None:
        """An event for ``rest-1`` reaches the ``rest-1`` connection and not the ``rest-2`` one."""
        manager = ConnectionManager()
        same_restaurant = manager.register(restaurant_id="rest-1")
        different_restaurant = manager.register(restaurant_id="rest-2")
        order_event = DomainEvent(
            type=EventType.order_created,
            restaurant_id="rest-1",
            entity_id="order-123",
        )

        delivered_count = await manager.enqueue(order_event)

        self.assertEqual(delivered_count, 1)
        self.assertEqual(same_restaurant.queue.get_nowait(), order_event)
        self.assertTrue(different_restaurant.queue.empty())

    async def test_enqueue_drops_slow_clients_and_evicts_them(self) -> None:
        """With one pre-filled queue, only the other connection receives the event and stays active."""
        manager = ConnectionManager()
        saturated = manager.register(restaurant_id="rest-1")
        responsive = manager.register(restaurant_id="rest-1")

        # Pre-load MAX_QUEUE_SIZE events into the first connection's queue
        # (presumably its full capacity -- confirm against ConnectionManager).
        for _slot in range(MAX_QUEUE_SIZE):
            filler = DomainEvent(
                type=EventType.customer_updated,
                restaurant_id="rest-1",
                entity_id="customer-1",
            )
            saturated.queue.put_nowait(filler)

        visit_event = DomainEvent(
            type=EventType.customer_visit_recorded,
            restaurant_id="rest-1",
            entity_id="customer-2",
        )

        delivered_count = await manager.enqueue(visit_event)

        self.assertEqual(delivered_count, 1)
        self.assertEqual(manager.active_count, 1)
        self.assertEqual(responsive.queue.get_nowait(), visit_event)

    def test_active_count_property_reflects_current_connections(self) -> None:
        """``active_count`` goes 2 -> 1 -> 0 as two registered connections are disconnected."""
        manager = ConnectionManager()
        connections = [manager.register(restaurant_id="rest-1") for _ in range(2)]

        self.assertEqual(manager.active_count, 2)
        # Disconnect in registration order, checking the count after each step.
        for expected_remaining, connection in zip((1, 0), connections):
            manager.disconnect(connection.id)
            self.assertEqual(manager.active_count, expected_remaining)


class TestSSEHelpers(IsolatedAsyncioTestCase):
    """Checks the ``SSE_HEADERS`` constant and the ``sse_response`` helper."""

    def test_sse_headers_have_expected_defaults(self) -> None:
        """``SSE_HEADERS`` equals exactly the two expected header entries."""
        expected_headers = {
            "Cache-Control": "no-cache",
            "X-Accel-Buffering": "no",
        }
        self.assertEqual(SSE_HEADERS, expected_headers)

    async def test_sse_response_returns_streaming_response_with_expected_headers(self) -> None:
        """``sse_response`` sets the event-stream media type plus the default and extra headers."""

        async def _single_frame():
            # One SSE frame is enough; the stream is never consumed here.
            yield "data: hello\n\n"

        response = sse_response(_single_frame(), extra_headers={"X-Test": "1"})

        self.assertEqual(response.media_type, "text/event-stream")
        # Headers are looked up in lower case, in the original assertion order.
        for header_name, expected_value in (
            ("cache-control", "no-cache"),
            ("x-accel-buffering", "no"),
            ("x-test", "1"),
        ):
            self.assertEqual(response.headers[header_name], expected_value)
