""" Websocket handler """
import base64
import functools
import hmac
import sys
import uuid
from typing import Any, Awaitable, Union, Tuple, Optional

import aiohttp
from absl import logging
from aiohttp import WSMessage
from aiohttp.web_exceptions import HTTPForbidden
from aiohttp.web_request import Request
from aiohttp.web_response import Response
from aiohttp.web_ws import WebSocketResponse
from multidict import CIMultiDictProxy

from ...base_handler import BaseHandler


class AuthType:
    """ Authorization schemes recognised by ``authenticate`` (lower-case). """
    BASIC, BEARER = "basic", "bearer"


async def auth_basic(credentials: str,
                     expected_login: str = "kraken",
                     expected_password: str = "kraken") -> bool:
    """ Auth using HTTP Basic credentials.

    :param credentials: base64-encoded ``login:password`` string (the part of
        the ``Authorization`` header that follows the ``Basic`` scheme).
    :param expected_login: login to accept (defaults to the previously
        hard-coded value).
    :param expected_password: password to accept (defaults to the previously
        hard-coded value).
    :return: True when the decoded pair matches the expected one; False on
        mismatch or on any malformed input (errors are logged, not raised).
    """
    # noinspection PyBroadException
    try:
        data = base64.b64decode(credentials.encode("utf-8")).decode("utf-8")
    except Exception:
        logging.error("auth_basic error:", exc_info=sys.exc_info())
        return False

    # RFC 7617: the user-id cannot contain ':' but the password may, so only
    # the first colon separates them.
    login, sep, password = data.partition(":")
    if not sep:
        logging.error("auth_basic error: credentials contain no ':' separator")
        return False

    # Constant-time comparisons so response timing does not reveal how much
    # of the secret matched; both are computed before combining.
    login_ok = hmac.compare_digest(login.encode("utf-8"),
                                   expected_login.encode("utf-8"))
    password_ok = hmac.compare_digest(password.encode("utf-8"),
                                      expected_password.encode("utf-8"))
    return login_ok and password_ok


def parse_auth_header(headers: "CIMultiDictProxy[str]")\
        -> Tuple[Optional[str], Optional[str]]:
    """ Parse the ``Authorization`` header into (auth type, auth data).

    :param headers: request headers; only ``.get("Authorization")`` is used.
    :return: ``(scheme, credentials)`` when the header consists of exactly two
        whitespace-separated tokens, otherwise ``(None, None)``.
    """
    auth_header = headers.get("Authorization")
    if not auth_header:
        return None, None

    # split() without arguments tolerates repeated / leading / trailing
    # whitespace, while still rejecting headers with extra tokens.
    parts = auth_header.split()
    if len(parts) != 2:
        # Deliberately do not log the header value: it carries credentials.
        logging.error("Could not parse Auth Header")
        return None, None

    auth_type, auth_data = parts
    return auth_type, auth_data


async def authenticate(request: Request) -> bool:
    """ Authenticate a request from its ``Authorization`` header.

    Basic credentials are checked with ``auth_basic``. Bearer is recognised
    but not implemented, and any other scheme is unsupported; both of these
    result in False.

    :param request: incoming request; only ``request.headers`` is read.
    :return: True if the request is authenticated, False otherwise.
    """
    auth_type, auth_data = parse_auth_header(request.headers)
    if auth_type is None or auth_data is None:
        logging.warning(f"Either, auth_type or auth_data is null")
        return False

    logging.info(f"auth_type and auth_data are fine, handling")
    # Scheme names are case-insensitive; AuthType values are lower-case.
    scheme = auth_type.lower()
    if scheme == AuthType.BASIC:
        logging.info(f"Handling Basic auth")
        return await auth_basic(auth_data)
    if scheme == AuthType.BEARER:
        logging.warning(f"{AuthType.BEARER} is not implemented!")
        return False

    logging.warning(f"{auth_type} is not supported!")
    return False


def auth_required(f) -> Any:
    """ Decorator: only run the wrapped handler for authorised requests.

    Wraps an ``async`` handler method ``f(self, request, *args, **kwargs)``.
    If the request is not authorised, ``HTTPForbidden`` is raised instead of
    calling ``f``.

    :param f: the async handler method to protect.
    :return: the wrapping coroutine function (metadata copied from ``f``).
    """
    @functools.wraps(f)
    async def wr(self, request: "Request", *args, **kwargs)\
            -> "Union[WebSocketResponse, Response]":
        """ wrapper """
        # Handlers whose PATH contains '?' skip authentication entirely.
        # It is checked first so authenticate() (and its logging) is not run
        # when its result would be ignored anyway.
        # NOTE(review): the reason for this bypass is not visible here —
        # confirm it is intended.
        if "?" in self.PATH or await authenticate(request):
            return await f(self, request, *args, **kwargs)
        raise HTTPForbidden()
    return wr


class WSHandler(BaseHandler):
    """ WebSocket endpoint handler (``GET /api/v1/ws``).

    Text frames are logged and echoed back; binary frames are only logged.
    Access is guarded by ``auth_required``.
    """
    METHOD = "GET"
    PATH = "/api/v1/ws"

    # Set by handler() to the currently open WebSocketResponse.
    # NOTE(review): this is per-instance state; if one handler instance
    # serves several connections concurrently, it is overwritten by each new
    # connection — confirm how BaseHandler instantiates handlers.
    ws = None
    # Not assigned anywhere in this class; handler() uses a local id.
    connection_id = None

    async def handle_text(self, data: str) -> None:
        """ Log a text frame and echo it back over ``self.ws``. """
        logging.info(f"Text data received: {data}")
        await self.ws.send_str(data)

    @staticmethod
    async def handle_binary(data: bytes) -> None:
        """ Log a binary frame; nothing is sent back. """
        logging.info(f"Binary data received: {data}")

    async def handle_message(self, msg: WSMessage) -> None:
        """ Dispatch one incoming websocket frame by its type.

        ERROR frames are logged (with traceback when an exception object is
        available), BINARY and TEXT frames go to their handlers, and all other
        frame types are ignored.
        """
        if msg.type == aiohttp.WSMsgType.ERROR:
            ex = self.ws.exception()
            if ex is None and isinstance(msg.data, BaseException):
                # Fall back to the exception carried by the ERROR frame.
                ex = msg.data
            if ex is None:
                # Building exc_info from None would raise AttributeError.
                logging.error("Ws error")
            else:
                logging.error("Ws error",
                              exc_info=(type(ex), ex, ex.__traceback__))
        elif msg.type == aiohttp.WSMsgType.BINARY:
            await self.handle_binary(msg.data)
        elif msg.type == aiohttp.WSMsgType.TEXT:
            await self.handle_text(msg.data)

    @auth_required
    async def handler(self, request: Request) -> WebSocketResponse:
        """ Upgrade the request to a websocket and process frames until close.

        :param request: the HTTP request to upgrade.
        :return: the (closed) WebSocketResponse once the message loop ends.
        """
        connection_id = uuid.uuid4()

        logging.info(f"WS opened. connection_id: '{connection_id}'")

        self.ws = ws = WebSocketResponse()
        await ws.prepare(request)

        logging.info(f"WS prepared. connection_id: '{connection_id}'")

        try:
            async for msg in ws:
                # noinspection PyTypeChecker
                await self.handle_message(msg)
        finally:
            # Log the close even when message handling raises.
            logging.info(f"WS closed. connection_id: '{connection_id}'")
        return ws
