import asyncio
import multiprocessing
import sys
import time
from collections.abc import Callable, Sequence
from functools import wraps
from pathlib import Path
from typing import Any

from .._futures import _future_watcher_wrapper, _new_cbscheduler
from .._granian import ASGIWorker, RSGIWorker, WorkerSignal
from .._imports import dotenv
from .._internal import load_env
from .._types import SSLCtx
from ..asgi import LifespanProtocol, _callback_wrapper as _asgi_call_wrap
from ..errors import ConfigurationError, FatalError
from ..rsgi import _callback_wrapper as _rsgi_call_wrap, _callbacks_from_target as _rsgi_cbs_from_target
from .common import (
    _PY_312,
    _PYV,
    AbstractServer,
    AbstractWorker,
    HTTP1Settings,
    HTTP2Settings,
    HTTPModes,
    Interfaces,
    LogLevels,
    SSLProtocols,
    TaskImpl,
    logger,
)


class AsyncWorker(AbstractWorker):
    """Worker that runs as an asyncio task instead of a separate process/thread.

    The serving coroutine is scheduled with ``create_task`` on the event loop
    obtained at construction time, and a second "watcher" task reports the
    worker's exit to the parent server unless the parent itself requested it.
    """

    def __init__(self, parent, idx, target, args, sig):
        # `sig` is the signal object handed to the serving coroutine; setting it
        # asks the worker to stop gracefully (see `terminate`).
        self._sig = sig
        self._loop = asyncio.get_event_loop()
        self._task = None
        self._wtask = None
        # Flipped to True by `_spawn` once the serving task exists.
        self._alive = False
        super().__init__(parent, idx, target, args)

    @staticmethod
    def wrap_target(target):
        """Adapt a spawn coroutine so it receives the current event loop.

        The returned callable has the signature ``(worker_id, sig, callback, sock,
        *args, **kwargs)`` and forwards to ``target`` with the loop inserted right
        after ``sock``.
        """

        @wraps(target)
        def wrapped(worker_id, sig, callback, sock, *args, **kwargs):
            loop = asyncio.get_event_loop()
            return target(worker_id, sig, callback, sock, loop, *args, **kwargs)

        return wrapped

    def _spawn(self, target, args):
        # `target(*args)` returns a coroutine; schedule it as the worker task.
        self._task = self._loop.create_task(target(*args))
        self._alive = True

    def _id(self):
        return id(self._task)

    async def _watcher(self):
        """Wait for the serving task to end and report unrequested exits to the parent."""
        exc = None
        try:
            await self._task
        except asyncio.CancelledError:
            pass
        except BaseException as err:
            # Keep the failure so it is not lost silently when reporting below.
            exc = err
        if not self.interrupt_by_parent:
            logger.error(f'Unexpected exit from worker-{self.idx + 1}', exc_info=exc)
            self.parent.interrupt_children.append(self.idx)
            self.parent.main_loop_interrupt.set()

    def _watch(self):
        self._wtask = self._loop.create_task(self._watcher())

    def start(self):
        """Log the spawn and start the exit watcher (the serving task already exists)."""
        logger.info(f'Spawning worker-{self.idx + 1} with {self._idl}: {self._id()}')
        self._watch()

    def is_alive(self):
        """Return False once killed or once the serving task has finished."""
        if not self._alive:
            return False
        return not self._task.done()

    def terminate(self):
        """Request a graceful stop by setting the worker signal.

        The task keeps running until it honours the signal, so `is_alive` keeps
        reporting the real task state; this lets callers escalate to `kill`
        after a timeout.
        """
        self.interrupt_by_parent = True
        self._sig.set()

    def kill(self):
        """Forcefully stop the worker by cancelling its serving task."""
        self._alive = False
        self.interrupt_by_parent = True
        self._task.cancel()

    async def join(self, timeout=None):
        """Wait up to ``timeout`` seconds (forever when None) for the serving task to end.

        Unlike ``asyncio.wait_for``, this neither cancels the task on timeout nor
        raises (timeout, task failure or cancellation). Callers check
        `is_alive()` afterwards, mirroring ``multiprocessing.Process.join``.
        """
        if self._task is None:
            return
        await asyncio.wait({self._task}, timeout=timeout)


class Server(AbstractServer[AsyncWorker]):
    """Granian server meant to run inside an existing asyncio event loop.

    Workers are `AsyncWorker` tasks on the running loop rather than separate
    processes or threads. Only the ASGI (with or without lifespan) and RSGI
    interfaces are accepted, and `serve` rejects the features that are not
    available in this mode.
    """

    def __init__(
        self,
        target: Any,
        address: str = '127.0.0.1',
        port: int = 8000,
        uds: Path | None = None,
        interface: Interfaces = Interfaces.RSGI,
        blocking_threads: int | None = None,
        blocking_threads_idle_timeout: int = 30,
        runtime_threads: int = 1,
        runtime_blocking_threads: int | None = None,
        task_impl: TaskImpl = TaskImpl.asyncio,
        http: HTTPModes = HTTPModes.auto,
        websockets: bool = True,
        backlog: int = 128,
        backpressure: int | None = None,
        http1_settings: HTTP1Settings | None = None,
        http2_settings: HTTP2Settings | None = None,
        log_enabled: bool = True,
        log_level: LogLevels = LogLevels.info,
        log_dictconfig: dict[str, Any] | None = None,
        log_access: bool = False,
        log_access_format: str | None = None,
        ssl_cert: Path | None = None,
        ssl_key: Path | None = None,
        ssl_key_password: str | None = None,
        ssl_protocol_min: SSLProtocols = SSLProtocols.tls13,
        ssl_ca: Path | None = None,
        ssl_crl: list[Path] | None = None,
        ssl_client_verify: bool = False,
        url_path_prefix: str | None = None,
        factory: bool = False,
        static_path_route: Sequence[str] | None = None,
        static_path_mount: Sequence[Path] | None = None,
        static_path_dir_to_file: str | None = None,
        static_path_expires: int = 86400,
    ):
        super().__init__(
            target=target,
            address=address,
            port=port,
            uds=uds,
            interface=interface,
            blocking_threads=blocking_threads,
            blocking_threads_idle_timeout=blocking_threads_idle_timeout,
            runtime_threads=runtime_threads,
            runtime_blocking_threads=runtime_blocking_threads,
            task_impl=task_impl,
            http=http,
            websockets=websockets,
            backlog=backlog,
            backpressure=backpressure,
            http1_settings=http1_settings,
            http2_settings=http2_settings,
            log_enabled=log_enabled,
            log_level=log_level,
            log_dictconfig=log_dictconfig,
            log_access=log_access,
            log_access_format=log_access_format,
            ssl_cert=ssl_cert,
            ssl_key=ssl_key,
            ssl_key_password=ssl_key_password,
            ssl_protocol_min=ssl_protocol_min,
            ssl_ca=ssl_ca,
            ssl_crl=ssl_crl,
            ssl_client_verify=ssl_client_verify,
            url_path_prefix=url_path_prefix,
            factory=factory,
            static_path_route=static_path_route,
            static_path_mount=static_path_mount,
            static_path_dir_to_file=static_path_dir_to_file,
            static_path_expires=static_path_expires,
        )
        # Wakes `_serve_loop`. AsyncWorker watchers set it on unexpected exits;
        # presumably the inherited signal handlers set it too — NOTE(review):
        # confirm against AbstractServer.
        self.main_loop_interrupt = asyncio.Event()

    def _spawn_worker(self, idx, target, callback_loader) -> AsyncWorker:
        """Create an `AsyncWorker` for slot ``idx`` running ``target``.

        The positional ``args`` tuple must line up with the parameters of the
        ``_spawn_*_worker`` coroutines, with ``loop`` injected after ``sock`` by
        `AsyncWorker.wrap_target`.
        """
        sig = WorkerSignal()

        return AsyncWorker(
            parent=self,
            idx=idx,
            target=target,
            args=(
                idx + 1,
                sig,
                callback_loader,
                self._shd,
                self.runtime_threads,
                self.runtime_blocking_threads,
                self.blocking_threads,
                self.blocking_threads_idle_timeout,
                self.backpressure,
                self.task_impl,
                self.http,
                self.http1_settings,
                self.http2_settings,
                self.websockets,
                self.static_path,
                self.log_access_format if self.log_access else None,
                self.ssl_ctx,
                {'url_path_prefix': self.url_path_prefix},
            ),
            sig=sig,
        )

    @staticmethod
    @AsyncWorker.wrap_target
    async def _spawn_asgi_worker(
        worker_id: int,
        shutdown_event: Any,
        callback: Any,
        sock: Any,
        loop: Any,
        runtime_threads: int,
        runtime_blocking_threads: int | None,
        blocking_threads: int,
        blocking_threads_idle_timeout: int,
        backpressure: int,
        task_impl: TaskImpl,
        http_mode: HTTPModes,
        http1_settings: HTTP1Settings | None,
        http2_settings: HTTP2Settings | None,
        websockets: bool,
        static_path: tuple[str, str, str | None, str | None] | None,
        log_access_fmt: str | None,
        ssl_ctx: SSLCtx,
        scope_opts: dict[str, Any],
    ):
        """Serve an ASGI app without running the lifespan protocol.

        ``shutdown_event`` is passed to the Rust worker's serve call; presumably
        serving ends once it is set.
        """
        wcallback = _future_watcher_wrapper(_asgi_call_wrap(callback, scope_opts, {}, log_access_fmt))

        worker = ASGIWorker(
            worker_id,
            sock,
            None,
            runtime_threads,
            runtime_blocking_threads,
            blocking_threads,
            blocking_threads_idle_timeout,
            backpressure,
            http_mode,
            http1_settings,
            http2_settings,
            websockets,
            static_path,
            *ssl_ctx,
            (None, None),
        )
        serve = worker.serve_async_uds if sock.is_uds() else worker.serve_async
        scheduler = _new_cbscheduler(loop, wcallback, impl_asyncio=task_impl == TaskImpl.asyncio)
        await serve(scheduler, loop, shutdown_event)

    @staticmethod
    @AsyncWorker.wrap_target
    async def _spawn_asgi_lifespan_worker(
        worker_id: int,
        shutdown_event: Any,
        callback: Any,
        sock: Any,
        loop: Any,
        runtime_threads: int,
        runtime_blocking_threads: int | None,
        blocking_threads: int,
        blocking_threads_idle_timeout: int,
        backpressure: int,
        task_impl: TaskImpl,
        http_mode: HTTPModes,
        http1_settings: HTTP1Settings | None,
        http2_settings: HTTP2Settings | None,
        websockets: bool,
        static_path: tuple[str, str, str | None, str | None] | None,
        log_access_fmt: str | None,
        ssl_ctx: SSLCtx,
        scope_opts: dict[str, Any],
    ):
        """Run ASGI lifespan startup, serve the app, then run lifespan shutdown.

        Raises:
            FatalError: when the lifespan startup reports an interrupt.
        """
        lifespan_handler = LifespanProtocol(callback)
        # The lifespan state is shared with every request scope.
        wcallback = _future_watcher_wrapper(
            _asgi_call_wrap(callback, scope_opts, lifespan_handler.state, log_access_fmt)
        )

        await lifespan_handler.startup()
        if lifespan_handler.interrupt:
            logger.error('ASGI lifespan startup failed', exc_info=lifespan_handler.exc)
            raise FatalError('ASGI lifespan startup')

        worker = ASGIWorker(
            worker_id,
            sock,
            None,
            runtime_threads,
            runtime_blocking_threads,
            blocking_threads,
            blocking_threads_idle_timeout,
            backpressure,
            http_mode,
            http1_settings,
            http2_settings,
            websockets,
            static_path,
            *ssl_ctx,
            (None, None),
        )
        serve = worker.serve_async_uds if sock.is_uds() else worker.serve_async
        scheduler = _new_cbscheduler(loop, wcallback, impl_asyncio=task_impl == TaskImpl.asyncio)
        await serve(scheduler, loop, shutdown_event)
        await lifespan_handler.shutdown()

    @staticmethod
    @AsyncWorker.wrap_target
    async def _spawn_rsgi_worker(
        worker_id: int,
        shutdown_event: Any,
        callback: Any,
        sock: Any,
        loop: Any,
        runtime_threads: int,
        runtime_blocking_threads: int | None,
        blocking_threads: int,
        blocking_threads_idle_timeout: int,
        backpressure: int,
        task_impl: TaskImpl,
        http_mode: HTTPModes,
        http1_settings: HTTP1Settings | None,
        http2_settings: HTTP2Settings | None,
        websockets: bool,
        static_path: tuple[str, str, str | None, str | None] | None,
        log_access_fmt: str | None,
        ssl_ctx: SSLCtx,
        scope_opts: dict[str, Any],
    ):
        """Serve an RSGI app, calling its init hook before and its del hook after serving."""
        callback, callback_init, callback_del = _rsgi_cbs_from_target(callback)
        wcallback = _future_watcher_wrapper(_rsgi_call_wrap(callback, log_access_fmt))
        callback_init(loop)

        worker = RSGIWorker(
            worker_id,
            sock,
            None,
            runtime_threads,
            runtime_blocking_threads,
            blocking_threads,
            blocking_threads_idle_timeout,
            backpressure,
            http_mode,
            http1_settings,
            http2_settings,
            websockets,
            static_path,
            *ssl_ctx,
            (None, None),
        )
        serve = worker.serve_async_uds if sock.is_uds() else worker.serve_async
        scheduler = _new_cbscheduler(loop, wcallback, impl_asyncio=task_impl == TaskImpl.asyncio)
        await serve(scheduler, loop, shutdown_event)
        callback_del(loop)

    async def _respawn_workers(self, workers, spawn_target, target_loader, delay: float = 0):
        """Replace the workers at the given indexes one at a time.

        For each index: start the replacement first, wait ``delay`` seconds,
        then gracefully stop the old worker. If ``workers_kill_timeout`` is set
        and the old worker is still alive after it, kill it.
        """
        for idx in workers:
            self.respawned_wrks[idx] = time.monotonic()
            logger.info(f'Respawning worker-{idx + 1}')
            old_wrk = self.wrks.pop(idx)
            wrk = self._spawn_worker(idx=idx, target=spawn_target, callback_loader=target_loader)
            wrk.start()
            self.wrks.insert(idx, wrk)
            await asyncio.sleep(delay)
            logger.info(f'Stopping old worker-{idx + 1}')
            old_wrk.terminate()
            await old_wrk.join(self.workers_kill_timeout)
            if self.workers_kill_timeout:
                # the worker might still be reported alive after `join`, let's context switch
                if old_wrk.is_alive():
                    await asyncio.sleep(0.001)
                if old_wrk.is_alive():
                    logger.warning(f'Killing old worker-{idx + 1} after it refused to gracefully stop')
                    old_wrk.kill()
                    await old_wrk.join()

    async def _stop_workers(self):
        """Gracefully stop every worker, killing stragglers when a kill timeout is set."""
        # Signal all workers first so they shut down concurrently.
        for wrk in self.wrks:
            wrk.terminate()

        for wrk in self.wrks:
            await wrk.join(self.workers_kill_timeout)
            if self.workers_kill_timeout:
                if wrk.is_alive():
                    # Worker labels are 1-based everywhere else in the logs.
                    logger.warning(f'Killing worker-{wrk.idx + 1} after it refused to gracefully stop')
                    wrk.kill()

        self.wrks.clear()

    def startup(self, spawn_target, target_loader):
        """Bind the shared socket, load env files, run startup hooks and spawn workers.

        In this class ``target_loader`` is the already-loaded application (see
        `_serve`), which is handed to each worker as its callback.
        """
        logger.info('Starting granian (embedded)')
        self._init_shared_socket()
        proto = 'https' if self.ssl_ctx[0] else 'http'
        logger.info(f'Listening at: {proto}://{self._bind_addr_fmt}')

        load_env(self.env_files)
        self._call_hooks(self.hooks_startup)
        self._spawn_workers(spawn_target, target_loader)

    async def _serve_loop(self, spawn_target, target_loader):
        """Wait for interrupts until a stop signal arrives or a worker exits unexpectedly.

        A pending reload request respawns the workers and keeps looping.
        """
        while True:
            await self.main_loop_interrupt.wait()
            if self.interrupt_signal:
                break

            if self.interrupt_children:
                break

            if self.reload_signal:
                await self._reload(spawn_target, target_loader)

    async def shutdown(self, exit_code=0):
        """Stop workers, run shutdown hooks and remove the Unix socket file if any.

        ``exit_code`` is accepted but not used by this implementation.
        """
        logger.info('Shutting down granian')
        await self._stop_workers()
        self._call_hooks(self.hooks_shutdown)
        if self.bind_uds and self.bind_uds.exists():
            self.bind_uds.unlink()

    async def _serve(self, spawn_target, target_loader):
        """Load the application once, then run startup, the serve loop and shutdown."""
        target = target_loader()
        self.startup(spawn_target, target)
        await self._serve_loop(spawn_target, target)
        await self.shutdown()

    async def serve(self, spawn_target: Callable[..., None] | None = None):
        """Validate the configuration for embedded mode and serve until stopped.

        Args:
            spawn_target: worker coroutine factory; defaults to the built-in
                spawner for the configured interface.

        Raises:
            ConfigurationError: for options unsupported in embedded mode or
                out-of-range values.
        """

        def target_loader(*args, **kwargs):
            # With `factory`, the configured target is called to build the app.
            if self.factory:
                return self.target()
            return self.target

        default_spawners = {
            Interfaces.ASGI: self._spawn_asgi_lifespan_worker,
            Interfaces.ASGINL: self._spawn_asgi_worker,
            Interfaces.RSGI: self._spawn_rsgi_worker,
        }

        logger.warning('Embedded server is experimental!')

        if self.interface == Interfaces.WSGI:
            logger.error('WSGI is not supported in embedded mode')
            raise ConfigurationError('interface')

        if self.reload_on_changes:
            logger.error('The changes reloader is not supported in embedded mode')
            raise ConfigurationError('reload')

        if self.workers_rss:
            logger.error('The resource monitor is not supported in embedded mode')
            raise ConfigurationError('workers_max_rss')

        if self.metrics_enabled:
            logger.error('Metrics are not available in embedded mode')
            raise ConfigurationError('metrics_enabled')

        if not spawn_target:
            spawn_target = default_spawners[self.interface]

        if self.bind_uds and sys.platform == 'win32':
            logger.error('Unix Domain sockets are not available on Windows')
            raise ConfigurationError('uds')

        if self.blocking_threads > 1:
            logger.error('Blocking threads > 1 is not supported on ASGI and RSGI')
            raise ConfigurationError('blocking_threads')

        if self.websockets:
            if self.http == HTTPModes.http2:
                logger.info('Websockets are not supported on HTTP/2 only, ignoring')

        if self.env_files and dotenv is None:
            logger.error('Environment file(s) usage requires the granian[dotenv] extra')
            raise ConfigurationError('env_files')

        # The message must state the range actually enforced here (5..600).
        if self.blocking_threads_idle_timeout < 5 or self.blocking_threads_idle_timeout > 600:
            logger.error('Blocking threads idle timeout must be between 5 and 600 seconds')
            raise ConfigurationError('blocking_threads_idle_timeout')

        cpus = multiprocessing.cpu_count()
        if self.runtime_threads > cpus:
            logger.warning(
                'Configured number of Rust threads appears to be too high given the amount of CPU cores available. '
                'Mind that Rust threads are not involved in Python code execution, and they almost never be the '
                'limiting factor in scaling. Consider configuring the amount of blocking threads instead'
            )

        if self.task_impl == TaskImpl.rust:
            if _PYV >= _PY_312:
                self.task_impl = TaskImpl.asyncio
                logger.warning('Rust task implementation is not available on Python >= 3.12, falling back to asyncio')
            else:
                logger.warning('Rust task implementation is experimental!')

        await self._serve(spawn_target, target_loader)

    def stop(self):
        """Request a shutdown through the inherited interrupt signal handler."""
        self.signal_handler_interrupt()

    def reload(self):
        """Request a worker respawn through the inherited reload signal handler."""
        self.signal_handler_reload()
