# standard imports
import enum
import re
import logging

# external imports
from hexathon import (
        strip_0x,
        pad,
        )

# local imports
from chainlib.hash import keccak256_string_to_hex
from .address import to_checksum_address

logg = logging.getLogger(__name__)


# Pattern a contract method name must match in ABIContractEncoder.method().
# NOTE(review): this also accepts names starting with a digit, which are not
# valid Solidity identifiers — confirm whether that is intended.
re_method = r'^[a-zA-Z0-9_]+$'


class ABIContractDecoder:
    """Decode a sequence of ABI-encoded values according to declared types.

    Register value types with :meth:`typ` and the matching encoded values
    with :meth:`val`, in the same order. :meth:`decode` then applies the
    decoder method named after each type (e.g. ``uint256``) to the value at
    the same position.
    """

    def __init__(self):
        # Type name strings (ABIContractType values), in registration order.
        self.types = []
        # Encoded values, matched positionally with self.types.
        self.contents = []


    def typ(self, v):
        """Register the type of the next value.

        :param v: Type of the next value.
        :type v: ABIContractType
        :raises TypeError: If v is not an ABIContractType member.
        """
        if not isinstance(v, ABIContractType):
            raise TypeError('method type not valid; expected {}, got {}'.format(ABIContractType.__name__, type(v).__name__))
        self.types.append(v.value)
        self.__log_typ()


    def val(self, v):
        """Register the next encoded value (matched positionally with types)."""
        self.contents.append(v)


    def uint256(self, v):
        """Decode a hex string to an int.

        :param v: Hex string; int(v, 16) also accepts a 0x prefix.
        :type v: str
        :rtype: int
        """
        return int(v, 16)


    def bytes32(self, v):
        """Decode a bytes32 value; the encoded value is returned unchanged."""
        return v


    def address(self, v):
        """Decode an address and return it in checksum form.

        Keeps the hex characters from offset 24 onward, i.e. the last 40
        characters (20 bytes) of a 64-character word.

        :param v: Hex string, optionally 0x-prefixed.
        :type v: str
        """
        a = strip_0x(v)[64-40:]
        return to_checksum_address(a)


    def __log_typ(self):
        logg.debug('types set to ({})'.format(','.join(self.types)))


    def decode(self):
        """Decode every registered value with the decoder for its type.

        :raises IndexError: If fewer values than types were registered.
        :returns: Decoded values, in registration order.
        :rtype: list
        """
        r = []
        for i, typ in enumerate(self.types):
            # Dispatch on the type name: 'uint256' -> self.uint256, etc.
            m = getattr(self, typ)
            r.append(m(self.contents[i]))
        return r


    def get(self):
        """Alias for :meth:`decode`."""
        return self.decode()


    def __str__(self):
        # __str__ must return a str; decode() returns a list.
        return str(self.decode())


class ABIContractType(enum.Enum):
    """ABI value types supported by the encoder and decoder in this module.

    Each member's value is the ABI type name string. The encoder joins these
    names into the method signature string, and the decoder uses them to
    look up its decoding method of the same name via getattr().
    """

    BYTES32 = 'bytes32'
    UINT256 = 'uint256'
    ADDRESS = 'address'


class ABIContractEncoder:
    """Build ABI-encoded call data for a contract method.

    Set the method name with :meth:`method` and declare its parameter types
    with :meth:`typ`. Then append argument values with the per-type methods
    (:meth:`uint256`, :meth:`address`, :meth:`bytes32`). :meth:`encode`
    returns the method signature hash prefix followed by the concatenated
    hex-encoded arguments.
    """


    def __init__(self):
        # Argument types (ABIContractType members), one per appended value.
        self.types = []
        # Hex-encoded argument values, in the order they were appended.
        self.contents = []
        # Method name; None until method() is called.
        self.method_name = None
        # Parameter type names used to build the method signature string.
        self.method_contents = []


    def method(self, m):
        """Set the contract method name.

        :param m: Method name; must match ``re_method``.
        :type m: str
        :raises ValueError: If m does not match ``re_method``.
        """
        if re.match(re_method, m) is None:
            raise ValueError('Invalid method {}, must match regular expression {}'.format(m, re_method))
        self.method_name = m
        self.__log_method()


    def typ(self, v):
        """Append a parameter type to the method signature.

        :param v: Parameter type.
        :type v: ABIContractType
        :raises AttributeError: If the method name has not been set yet.
        :raises TypeError: If v is not an ABIContractType member.
        """
        if self.method_name is None:
            raise AttributeError('method name must be set before adding types')
        if not isinstance(v, ABIContractType):
            raise TypeError('method type not valid; expected {}, got {}'.format(ABIContractType.__name__, type(v).__name__))
        self.method_contents.append(v.value)
        self.__log_method()


    def __log_method(self):
        logg.debug('method set to {}'.format(self.get_method()))


    def __log_latest(self, v):
        # Log the most recently appended value with its encoding and type.
        logg.debug('Encoder added {} -> {} ({})'.format(v, self.contents[-1], self.types[-1].value))


    def uint256(self, v):
        """Append an unsigned integer argument as a 32-byte big-endian word.

        :param v: Value; anything int() accepts.
        :raises OverflowError: If the value is negative or does not fit in 32 bytes.
        """
        v = int(v)
        b = v.to_bytes(32, 'big')
        self.contents.append(b.hex())
        self.types.append(ABIContractType.UINT256)
        self.__log_latest(v)


    def address(self, v):
        """Append an address argument (exactly 20 bytes) in a 32-byte word.

        :param v: Address as a hex string (optionally 0x-prefixed) or bytes.
        :raises ValueError: If v is not exactly 20 bytes or has an unsupported type.
        """
        self.bytes_fixed(32, v, 20)
        self.types.append(ABIContractType.ADDRESS)
        self.__log_latest(v)


    def bytes32(self, v):
        """Append a bytes32 argument (at most 32 bytes).

        :param v: Value as a hex string (optionally 0x-prefixed) or bytes.
        :raises ValueError: If v is longer than 32 bytes or has an unsupported type.
        """
        self.bytes_fixed(32, v)
        self.types.append(ABIContractType.BYTES32)
        self.__log_latest(v)


    def bytes_fixed(self, mx, v, exact=0):
        """Append a fixed-size value padded to mx bytes.

        Bytes input is left-padded with zero bytes. Hex string input is
        padded by hexathon's pad().

        :param mx: Size of the encoded word, in bytes.
        :type mx: int
        :param v: Value, as a hex string (a 0x prefix is stripped) or bytes.
        :type v: str, bytes or bytearray
        :param exact: If greater than zero, the required size of v in bytes.
        :type exact: int
        :raises ValueError: If v has the wrong size, is longer than mx bytes,
            or has an unsupported type.
        """
        if isinstance(v, str):
            v = strip_0x(v)
            l = len(v)
            # l counts hex characters: two per byte.
            if exact > 0 and l != exact * 2:
                raise ValueError('value wrong size; expected {}, got {}'.format(exact * 2, l))
            if l > mx * 2:
                raise ValueError('value too long ({})'.format(l))
            v = pad(v, mx)
        elif isinstance(v, (bytes, bytearray)):
            l = len(v)
            if exact > 0 and l != exact:
                raise ValueError('value wrong size; expected {}, got {}'.format(exact, l))
            # Without this check the slice assignment below would grow the
            # buffer and silently emit an oversized word.
            if l > mx:
                raise ValueError('value too long ({})'.format(l))
            b = bytearray(mx)
            b[mx-l:] = v
            v = pad(b.hex(), mx)
        else:
            raise ValueError('invalid input {}'.format(type(v).__name__))
        self.contents.append(v)



    def get_method(self):
        """Return the method signature string, e.g. ``name(type1,type2)``.

        Returns an empty string if no method name has been set.
        """
        if self.method_name is None:
            return ''
        return '{}({})'.format(self.method_name, ','.join(self.method_contents))


    def get_method_signature(self):
        """Return the first 8 characters of the keccak256 hex digest of the signature.

        NOTE(review): this is the 4-byte selector only if
        keccak256_string_to_hex returns hex without a 0x prefix — confirm.
        """
        s = self.get_method()
        return keccak256_string_to_hex(s)[:8]


    def get_contents(self):
        """Return all appended argument encodings concatenated."""
        return ''.join(self.contents)


    def get(self):
        """Alias for :meth:`encode`."""
        return self.encode()


    def encode(self):
        """Return the method signature prefix followed by the argument encodings."""
        m = self.get_method_signature()
        c = self.get_contents()
        return m + c


    def __str__(self):
        return self.encode()



def abi_decode_single(typ, v):
    """Decode one ABI-encoded value of the given type.

    :param typ: Type of the value.
    :type typ: ABIContractType
    :param v: Encoded value.
    :returns: The single decoded value produced by ABIContractDecoder.
    """
    decoder = ABIContractDecoder()
    decoder.typ(typ)
    decoder.val(v)
    return decoder.decode()[0]
