import asyncio
import time
from functools import wraps

from .log import log_request_builder, logger


class LifespanProtocol:
    """Drive the ASGI ``lifespan`` scope of an application.

    ``startup()`` launches the application's lifespan coroutine as a
    background task, sends ``lifespan.startup`` and waits for the answer.
    ``shutdown()`` sends ``lifespan.shutdown`` and waits for the answer.
    After either call, ``interrupt`` is ``True`` when the application
    reported a failure (via ``*.failed`` messages, possibly followed by an
    exception), signalling that the caller should stop.
    """

    error_transition = 'Invalid lifespan state transition'

    def __init__(self, callable):
        self.callable = callable
        # Messages handed to the application through ``receive``.
        self.event_queue = asyncio.Queue()
        # Set when startup / shutdown is acknowledged, or when the app exits.
        self.event_startup = asyncio.Event()
        self.event_shutdown = asyncio.Event()
        # True when the app raised without having reported a startup/shutdown failure.
        self.unsupported = False
        # True when the app's lifespan coroutine raised; the exception is kept in ``exc``.
        self.errored = False
        self.failure_startup = False
        self.failure_shutdown = False
        # Set by startup()/shutdown() when the caller should stop.
        self.interrupt = False
        self.exc = None
        # Shared lifespan state exposed to the app as ``scope['state']``.
        self.state = {}
        # Strong reference to the task running ``handle()``; asyncio only keeps
        # weak references to tasks, so the creator must hold on to it.
        self._handler_task = None

    async def handle(self):
        """Run the application's lifespan coroutine until it returns.

        If the application raises, the exception is recorded in ``errored``
        and ``exc``. Unless a startup or shutdown failure was already reported
        via ``send``, lifespan is flagged as ``unsupported`` and only a warning
        is logged. Both events are always set on exit, so ``startup()`` and
        ``shutdown()`` never wait forever.
        """
        try:
            await self.callable(
                {'type': 'lifespan', 'asgi': {'version': '3.0', 'spec_version': '2.3'}, 'state': self.state},
                self.receive,
                self.send,
            )
        except Exception as exc:
            self.errored = True
            self.exc = exc
            if self.failure_startup or self.failure_shutdown:
                return
            self.unsupported = True
            # ``Logger.warn`` is a deprecated alias; use ``warning``.
            logger.warning(
                'ASGI Lifespan errored, continuing without Lifespan support '
                '(to avoid Lifespan completely use "asginl" interface)'
            )
        finally:
            self.event_startup.set()
            self.event_shutdown.set()

    async def startup(self):
        """Start the lifespan task, send ``lifespan.startup`` and wait for it.

        Sets ``interrupt`` when the app reported a startup failure, or when it
        raised after having reported a failure.
        """
        # Keep the task referenced for the lifetime of the protocol.
        self._handler_task = asyncio.get_running_loop().create_task(self.handle())

        await self.event_queue.put({'type': 'lifespan.startup'})
        await self.event_startup.wait()

        if self.failure_startup or (self.errored and not self.unsupported):
            self.interrupt = True

    async def shutdown(self):
        """Send ``lifespan.shutdown`` and wait for the acknowledgement.

        The shared ``state`` is cleared first. Nothing is sent when the
        lifespan task has already errored. Sets ``interrupt`` when the app
        reported a shutdown failure, or raised after reporting a failure.
        """
        self.state.clear()

        if self.errored:
            return

        await self.event_queue.put({'type': 'lifespan.shutdown'})
        await self.event_shutdown.wait()

        if self.failure_shutdown or (self.errored and not self.unsupported):
            self.interrupt = True

    async def receive(self):
        """ASGI ``receive``: return the next queued lifespan message."""
        return await self.event_queue.get()

    def _handle_startup_complete(self, message):
        # Only valid before any startup/shutdown outcome has been recorded.
        assert not self.event_startup.is_set(), self.error_transition
        assert not self.event_shutdown.is_set(), self.error_transition
        self.event_startup.set()

    def _handle_startup_failed(self, message):
        assert not self.event_startup.is_set(), self.error_transition
        assert not self.event_shutdown.is_set(), self.error_transition
        self.event_startup.set()
        self.failure_startup = True
        if message.get('message'):
            logger.error(message['message'])

    def _handle_shutdown_complete(self, message):
        # Only valid after startup was acknowledged and before shutdown was.
        assert self.event_startup.is_set(), self.error_transition
        assert not self.event_shutdown.is_set(), self.error_transition
        self.event_shutdown.set()

    def _handle_shutdown_failed(self, message):
        assert self.event_startup.is_set(), self.error_transition
        assert not self.event_shutdown.is_set(), self.error_transition
        self.event_shutdown.set()
        self.failure_shutdown = True
        if message.get('message'):
            logger.error(message['message'])

    # Dispatch table from outgoing message type to handler (plain functions,
    # hence called with ``self`` explicitly in ``send``).
    _event_handlers = {
        'lifespan.startup.complete': _handle_startup_complete,
        'lifespan.startup.failed': _handle_startup_failed,
        'lifespan.shutdown.complete': _handle_shutdown_complete,
        'lifespan.shutdown.failed': _handle_shutdown_failed,
    }

    async def send(self, message):
        """ASGI ``send``: dispatch a lifespan message; unknown types raise ``KeyError``."""
        handler = self._event_handlers[message['type']]
        handler(self, message)


def _callback_wrapper(callback, scope_opts, state, access_log_fmt=None):
    """Adapt an ASGI application ``callback`` to a ``(scope, proto)`` callable.

    The returned callable sets ``root_path`` (from the ``url_path_prefix``
    option, or ``''``) and a shallow copy of ``state`` on each scope, then
    invokes ``callback(scope, proto.receive, proto.send)``.

    When ``access_log_fmt`` is truthy, requests are also access-logged:
    HTTP requests after the application finishes (even if it raised), using
    ``proto.sent_response_code``; other scopes up front with status 101.
    The result carries ``callback``'s metadata via ``functools.wraps``.
    """
    path_prefix = scope_opts.get('url_path_prefix') or ''

    def _runner(scope, proto):
        scope.update(root_path=path_prefix, state=state.copy())
        return callback(scope, proto.receive, proto.send)

    async def _http_logger(scope, proto):
        wall_start = time.time()
        perf_start = time.perf_counter()
        try:
            return await _runner(scope, proto)
        finally:
            # Log even when the application raised.
            emit(wall_start, perf_start, scope, proto.sent_response_code)

    def _ws_logger(scope, proto):
        # Logged before the application runs, with a fixed 101 status.
        emit(time.time(), time.perf_counter(), scope, 101)
        return _runner(scope, proto)

    def _logger(scope, proto):
        dispatch = _http_logger if scope['type'] == 'http' else _ws_logger
        return dispatch(scope, proto)

    emit = _build_access_logger(access_log_fmt)
    if access_log_fmt:
        chosen = _logger
    else:
        chosen = _runner
    return wraps(callback)(chosen)


def _build_access_logger(fmt):
    """Build an access-log callback for the format ``fmt``.

    The returned ``access_log(rt, mt, scope, resp_code)`` extracts request
    fields from an ASGI ``scope``. It forwards them, together with the two
    start timestamps and the response code, to the writer built by
    ``log_request_builder(fmt)``. In this file ``_callback_wrapper`` passes
    ``time.time()`` as ``rt`` and ``time.perf_counter()`` as ``mt``.
    """
    # Named so it does not shadow the module-level ``logger`` import.
    log_request = log_request_builder(fmt)

    def access_log(rt, mt, scope, resp_code):
        # ``client`` is optional in ASGI scopes (missing or None), so fall
        # back to '-' the same way ``method`` does below.
        client = scope.get('client')
        log_request(
            rt,
            mt,
            {
                'addr_remote': client[0] if client else '-',
                'protocol': 'HTTP/' + scope['http_version'],
                'path': scope['path'],
                'qs': scope['query_string'],
                'method': scope.get('method', '-'),
                'scheme': scope['scheme'],
            },
            resp_code,
        )

    return access_log
