# standard imports
import os
import unittest
import logging
import hashlib

# third-party imports
import gnupg
import eth_keys
import confini

# local imports
import ecuth
from ecuth.error import ChallengeError
from ecuth.error import TokenExpiredError
from ecuth.error import SessionExpiredError
from ecuth.error import SessionError
from ecuth.challenge import source_hash
from ecuth.filter.eip712 import EIP712Filter

# Verbose logging for the whole test run; logg is the root logger.
logging.basicConfig(level=logging.DEBUG)
logg = logging.getLogger()


# Repository root: the parent of the directory containing this file
# (symlinks resolved via realpath).
root_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
# GnuPG home directory used by decrypter() below.
gpg_dir = os.path.join(root_dir, '.gnupg')
logg.debug('gpg dir {}'.format(gpg_dir))
# Created once at import time and shared by every test in this module.
gpg = gnupg.GPG(gnupghome=gpg_dir)


def sha1filter(s):
    """Challenge filter that replaces its input with the raw SHA-1 digest.

    :param s: challenge bytes to hash
    :returns: the 20-byte SHA-1 digest of ``s``
    """
    logg.debug('object {}'.format(s))
    digest = hashlib.sha1(s).digest()
    logg.debug('shafilter {} -> {}'.format(s.hex(), digest.hex()))
    return digest


# fetch from file for test
def mock_fetch(address=None):
    """Test stand-in for the retriever's fetcher: read fixture data from disk.

    Opens ``<root_dir>/test/data/<address>`` in binary mode and returns at
    most 1 MiB of its contents.

    :param address: fixture file name to read (presumably the account
        address the retriever resolved -- confirm against ecuth). The
        ``None`` default is not usable: ``os.path.join`` raises TypeError.
    :raises FileNotFoundError: if no fixture file exists for ``address``
    """
    path = os.path.join(root_dir, 'test', 'data', address)
    # Context manager closes the file even if read() raises.
    with open(path, 'rb') as f:
        return f.read(1024 * 1024)


# GNUPG data pluggable test decrypter
def decrypter(data):
    """Pluggable test decrypter: GnuPG-decrypt ``data`` and require full trust.

    :param data: GnuPG message to decrypt with the test GnuPG home
    :returns: the decrypted payload as ``str``
    :raises ValueError: if the reported trust level is missing or below
        ``TRUST_FULLY``
    """
    d = gpg.decrypt(data, passphrase='tralala')
    # Logged before the check so rejected data can be diagnosed too.
    logg.debug('trust {}'.format(d.trust_level))
    # python-gnupg leaves trust_level as None when no trust status was
    # reported (e.g. unsigned data); comparing None with an int would raise
    # TypeError instead of the intended ValueError.
    if d.trust_level is None or d.trust_level < d.TRUST_FULLY:
        raise ValueError('untrusted data')
    return str(d)


class TestCore(unittest.TestCase):
    """End-to-end tests of ecuth.SimpleRetriever challenge/response flows.

    Relies on the module-level fixtures ``decrypter`` (GnuPG home under the
    repository root) and ``mock_fetch`` (files under ``test/data``).
    """

    # Shared by all tests; re-processed and re-validated in setUp().
    config = confini.Config(os.path.join(root_dir, 'config'))

    # Deterministic, test-only private key (scalar 5).
    _PK_HEX = '0000000000000000000000000000000000000000000000000000000000000005'
    _IP = '127.0.0.1'

    def setUp(self):
        self.config.process()
        self.config.require('NAME', 'EIP712')
        self.config.require('VERSION', 'EIP712')
        self.config.require('BASE_URL', 'ECUTH')
        self.config.require('CHAIN_ID', 'ECUTH')
        self.config.validate()

    def _retriever(self):
        """Return a SimpleRetriever using the GnuPG decrypter and file-backed fetch."""
        r = ecuth.SimpleRetriever(self.config, decrypter)
        r._fetch = mock_fetch
        return r

    def _eip712_filter(self):
        """Return an EIP712Filter configured from the EIP712/ECUTH settings."""
        return EIP712Filter(
            self.config.get('EIP712_NAME'),
            self.config.get('EIP712_VERSION'),
            self.config.get('ECUTH_CHAIN_ID'),
        )

    def _private_key(self):
        """Return the fixed test signing key."""
        return eth_keys.keys.PrivateKey(bytes.fromhex(self._PK_HEX))

    # NOTE(review): the load() unpacking and source_hash() call were aligned
    # with the other tests, but the read/write/check/renew/session calls are
    # not exercised anywhere else in this file -- verify them against
    # SimpleRetriever before re-enabling.
    @unittest.skip('pending update to the current SimpleRetriever API')
    def test_basic(self):
        r = self._retriever()
        pk_bytes = bytes.fromhex(self._PK_HEX)
        pk = eth_keys.keys.PrivateKey(pk_bytes)
        address = pk.public_key.to_checksum_address()
        ip = self._IP

        # Overwriting the stored challenge must make load() reject the signature.
        (c, _) = r.challenge(ip)
        challenge_key = source_hash(ip.encode('utf-8'), c)
        r.auth[challenge_key].challenge = pk_bytes
        signature = pk.sign_msg(c)
        with self.assertRaises(ChallengeError):
            r.load(ip, c, signature)

        (c, _) = r.challenge(ip)
        signature = pk.sign_msg(c)
        (refresh, auth, auth_expire) = r.load(ip, c, signature)
        self.assertTrue(r.read(address, 'ussd.session'))
        self.assertFalse(r.write(address, 'ussd.session'))
        self.assertFalse(r.read(address, 'ussd.pin'))
        self.assertTrue(r.write(address, 'ussd.pin'))

        # verify reverse lookup
        r.check(auth)

        # Expire the auth token; reads must then fail.
        r.session[address].auth_expire = 0
        with self.assertRaises(TokenExpiredError):
            r.read(address, 'ussd.session')

        # The refresh token restores access.
        r.renew(address, refresh)
        self.assertTrue(r.read(address, 'ussd.session'))

        # Expire the refresh token as well; renewal must then fail.
        r.session[address].refresh_expire = 0
        r.session[address].auth_expire = 0
        with self.assertRaises(SessionExpiredError):
            r.renew(address, refresh)

    def test_eip712(self):
        r = self._retriever()
        eip712_filter = self._eip712_filter()
        r.add_challenge_filter(eip712_filter.filter, 'eip712')
        pk = self._private_key()
        ip = self._IP

        # Signing the bare nonce must not authenticate. load() is expected to
        # raise FileNotFoundError (presumably because mock_fetch has no
        # fixture for the address derived from the wrong signature).
        (c, _) = r.challenge(ip)
        signature = pk.sign_msg(c)
        with self.assertRaises(FileNotFoundError):
            r.load(ip, c, signature)

        # Signing the EIP712-filtered challenge authenticates.
        (c, _) = r.challenge(ip)
        signature = pk.sign_msg(eip712_filter.filter(c))
        (refresh, auth, auth_expire) = r.load(ip, c, signature)

    def test_multi_filters(self):
        r = self._retriever()
        r.add_challenge_filter(sha1filter)
        eip712_filter = self._eip712_filter()
        r.add_challenge_filter(eip712_filter.filter, 'eip712')
        pk = self._private_key()
        ip = self._IP

        # This test expects the filters to be chained in registration order
        # (sha1, then eip712). Signing the raw challenge must fail ...
        (c, _) = r.challenge(ip)
        signature = pk.sign_msg(c)
        with self.assertRaises(FileNotFoundError):
            r.load(ip, c, signature)

        # ... as must signing only the first filter's output ...
        (c, _) = r.challenge(ip)
        logg.debug('challenge {}'.format(c))
        signature = pk.sign_msg(sha1filter(c))
        with self.assertRaises(FileNotFoundError):
            r.load(ip, c, signature)

        # ... while signing the output of the full chain succeeds.
        (c, _) = r.challenge(ip)
        logg.debug('challenge {}'.format(c))
        signature = pk.sign_msg(eip712_filter.filter(sha1filter(c)))
        (refresh, auth, expire) = r.load(ip, c, signature)


# Allow running this module directly as a test script.
if __name__ == '__main__':
    unittest.main()
