# standard imports
import enum
import json
import logging
import re
import socket
from urllib.parse import urlparse
from urllib.request import (
        Request,
        urlopen,
        )

# local imports
from .jsonrpc import (
        jsonrpc_template,
        jsonrpc_result,
        DefaultErrorParser,
        )

# module-level logger
logg = logging.getLogger(__name__)

# shared default error parser; used as the default `error_parser` argument of
# the JSON-RPC connection classes' do() methods below
error_parser = DefaultErrorParser()


class ConnType(enum.Enum):
    """Transport kinds that a connection location string can resolve to.

    Values are grouped by transport family (0x100 HTTP, 0x200 websocket,
    0x1000 unix); each SSL variant is its plain counterpart plus one.
    """

    CUSTOM = 0x00
    HTTP = 0x100
    HTTP_SSL = 0x101
    WEBSOCKET = 0x200
    WEBSOCKET_SSL = 0x201
    UNIX = 0x1000


# Location prefixes recognised by str_to_connspec. In the http/ws patterns the
# optional capture group matches the trailing 's' of the TLS scheme variant.
re_http = r'^http(s)?://'
re_ws = r'^ws(s)?://'
re_unix = r'^ipc://'

def str_to_connspec(s):
    """Classify a connection location string as a ConnType member.

    :param s: Location string, e.g. ``'http://host:port'``, ``'wss://host'``,
        ``'ipc:///path/to/socket'``, or the literal ``'custom'``.
    :returns: The matching ConnType member.
    :raises ValueError: If ``s`` matches none of the known schemes.
    """
    if s == 'custom':
        return ConnType.CUSTOM

    m = re.match(re_http, s)
    if m is not None:
        # group(1) is 's' for https://, None for http://
        if m.group(1) is not None:
            return ConnType.HTTP_SSL
        return ConnType.HTTP

    m = re.match(re_ws, s)
    if m is not None:
        # group(1) is 's' for wss://, None for ws://
        if m.group(1) is not None:
            return ConnType.WEBSOCKET_SSL
        return ConnType.WEBSOCKET

    if re.match(re_unix, s) is not None:
        return ConnType.UNIX

    raise ValueError('unknown connection type {}'.format(s))


def from_conntype(t):
    """Map a ConnType to the default connection class for it.

    :param t: ConnType member.
    :returns: HTTPConnection for HTTP/HTTP_SSL, UnixConnection for UNIX.
    :raises NotImplementedError: For any other value (no default class).
    """
    is_http = t == ConnType.HTTP or t == ConnType.HTTP_SSL
    if is_http:
        return HTTPConnection
    if t == ConnType.UNIX:
        return UnixConnection
    raise NotImplementedError(t)



class RPCConnection():
    """Base class for RPC connections, plus a class-level location registry.

    Locations are registered per chain (keyed by ``str(chain_spec)``) and
    tag via register_location(), and later instantiated via connect().
    """

    # chain_str -> {tag: (ConnType, location handed to the constructor)}
    __locations = {}
    # chain_str -> {tag: custom constructor callable}
    __constructors = {}

    def __init__(self, location=None):
        logg.debug('creating connection {}'.format(location))
        self.location = location


    @staticmethod
    def register_location(location, chain_spec, tag='default', constructor=None, exist_ok=False):
        """Register a connection location for a chain under a tag.

        :param location: Location string accepted by str_to_connspec
            (``http(s)://``, ``ws(s)://``, ``ipc://`` or ``'custom'``).
        :param chain_spec: Chain identifier; only its ``str()`` is used as key.
        :param tag: Name distinguishing several locations for the same chain.
        :param constructor: Optional callable that connect() uses instead of
            the default class for the connection type; it receives the
            stored location.
        :param exist_ok: If False, registering an already-used chain/tag raises.
        :raises ValueError: On duplicate registration or unrecognised location.
        """
        chain_str = str(chain_spec)
        if not exist_ok:
            v = RPCConnection.__locations.get(chain_str, {}).get(tag)
            if v is not None:
                raise ValueError('duplicate registration of tag {}:{}, requested {} already had {}'.format(chain_str, tag, location, v))

        # validate before touching the registry so a bad location leaves no trace
        conntype = str_to_connspec(location)
        if conntype == ConnType.UNIX:
            # ipc:///path/to/socket -> /path/to/socket (the filesystem path)
            conn_location = urlparse(location).path
        else:
            # network transports need the full URL (scheme, host, port);
            # urlparse().path alone would drop them
            conn_location = location

        locations = RPCConnection.__locations.setdefault(chain_str, {})
        constructors = RPCConnection.__constructors.setdefault(chain_str, {})
        locations[tag] = (conntype, conn_location)
        if constructor is not None:
            constructors[tag] = constructor
            logg.info('registered rpc connection {} ({}:{}) as {} with custom constructor {}'.format(location, chain_str, tag, conntype, constructor))
        else:
            # drop any constructor left over from a previous registration of this tag
            constructors.pop(tag, None)
            logg.info('registered rpc connection {} ({}:{}) as {}'.format(location, chain_str, tag, conntype))


    @staticmethod
    def connect(chain_spec, tag='default'):
        """Instantiate the connection registered for a chain and tag.

        :param chain_spec: Chain identifier used at registration time.
        :param tag: Tag used at registration time.
        :returns: Result of calling the custom constructor, or the default
            class from from_conntype, with the stored location.
        :raises KeyError: If nothing is registered for the chain/tag.
        :raises NotImplementedError: If no custom constructor was given and
            the connection type has no default class.
        """
        chain_str = str(chain_spec)
        conntype, location = RPCConnection.__locations[chain_str][tag]
        constructor = RPCConnection.__constructors[chain_str].get(tag)
        if constructor is None:
            constructor = from_conntype(conntype)
        return constructor(location)


class HTTPConnection(RPCConnection):
    """Connection class for http:// and https:// locations.

    from_conntype returns this for ConnType.HTTP and ConnType.HTTP_SSL; it
    adds no behaviour of its own to RPCConnection.
    """


class UnixConnection(RPCConnection):
    """Connection class for ipc:// (unix socket) locations.

    from_conntype returns this for ConnType.UNIX; it adds no behaviour of
    its own to RPCConnection.
    """


class JSONRPCHTTPConnection(HTTPConnection):
    """JSON-RPC transport that POSTs request objects to ``self.location``."""

    def do(self, o, error_parser=error_parser):
        """Send one JSON-RPC request over HTTP and return the parsed result.

        :param o: JSON-RPC request object; must contain an ``'id'`` key.
        :param error_parser: Passed through to ``jsonrpc_result``.
        :returns: Whatever ``jsonrpc_result`` returns for the response.
        :raises ValueError: If the response id differs from the request id.
        """
        req = Request(
                self.location,
                method='POST',
                )
        req.add_header('Content-Type', 'application/json')
        data = json.dumps(o)
        logg.debug('(HTTP) send {}'.format(data))
        # close the HTTP response even if JSON decoding fails
        with urlopen(req, data=data.encode('utf-8')) as r:
            result = json.load(r)
        logg.debug('(HTTP) recv {}'.format(result))
        if o['id'] != result['id']:
            raise ValueError('RPC id mismatch; sent {} received {}'.format(o['id'], result['id']))
        return jsonrpc_result(result, error_parser)


class JSONRPCUnixConnection(UnixConnection):
    """JSON-RPC transport over a unix stream socket at ``self.location``."""

    def do(self, o, error_parser=error_parser):
        """Send one JSON-RPC request over the unix socket and return the result.

        :param o: JSON-RPC request object; must contain an ``'id'`` key.
        :param error_parser: Passed through to ``jsonrpc_result``.
        :returns: Whatever ``jsonrpc_result`` returns for the response.
        :raises IOError: If the socket stops accepting data mid-request.
        :raises ValueError: If the response id differs from the request id.
        """
        # the context manager closes the socket on success and on every error path
        with socket.socket(family=socket.AF_UNIX, type=socket.SOCK_STREAM, proto=0) as conn:
            conn.connect(self.location)
            data = json.dumps(o)
            payload = memoryview(data.encode('utf-8'))

            logg.debug('unix socket send {}'.format(data))
            l = len(payload)
            n = 0
            while n < l:
                # send() may accept only part of the buffer; resume at offset n
                c = conn.send(payload[n:])
                if c == 0:
                    raise IOError('unix socket ({}/{}) {}'.format(n, l, data))
                n += c

            # read until the peer closes its end of the connection
            # NOTE(review): this blocks if the server keeps the socket open after replying
            chunks = []
            while True:
                b = conn.recv(4096)
                if len(b) == 0:
                    break
                chunks.append(b)
        r = b''.join(chunks)
        logg.debug('unix socket recv {}'.format(r.decode('utf-8')))
        result = json.loads(r)
        if result['id'] != o['id']:
            raise ValueError('RPC id mismatch; sent {} received {}'.format(o['id'], result['id']))

        return jsonrpc_result(result, error_parser)
