# standard imports
import os
import hashlib
import logging

# external imports
import gnupg

# local imports
from clicada.error import AuthError

# Module-level logger, named after this module's import path.
logg = logging.getLogger(__name__)


class PGPAuthCrypt:
    """Manage a local encryption secret that is protected by a GnuPG key.

    The secret is stored GPG-encrypted to ``auth_key`` in
    ``<db_dir>/.secret``. It is created on first use and decrypted into
    ``self.secret`` by get_secret().
    """

    typ = 'gnupg'

    def __init__(self, db_dir, auth_key, pgp_dir=None):
        """
        :param db_dir: Directory holding the encrypted ``.secret`` file.
        :param auth_key: Hex string identifying the GnuPG key; used as the
            encryption recipient and mixed into the derived secret.
        :param pgp_dir: Passed to ``gnupg.GPG`` as ``gnupghome``.
        :raises AuthError: If ``auth_key`` is not a valid hex string.
        """
        self.db_dir = db_dir
        try:
            bytes.fromhex(auth_key)
        except (TypeError, ValueError) as e:
            # TypeError: not a str (e.g. None); ValueError: not valid hex.
            raise AuthError('invalid key {}'.format(auth_key)) from e
        self.auth_key = auth_key
        self.gpg = gnupg.GPG(gnupghome=pgp_dir)
        self.secret = None
        # Populated by get_secret(). It is initialised here so that
        # get_passphrase() returns None instead of raising AttributeError
        # when the secret has not been unlocked yet.
        self.__passphrase = None

    def get_secret(self, passphrase=''):
        """Load the encryption secret, creating it if needed, and decrypt it.

        On success, stores the decrypted bytes in ``self.secret`` and
        remembers ``passphrase`` for get_passphrase().

        :param passphrase: Passphrase for the GnuPG private key; ``None`` is
            treated as the empty string.
        :raises AuthError: If the secret cannot be encrypted (when it is first
            created) or cannot be decrypted.
        """
        if passphrase is None:
            passphrase = ''
        p = os.path.join(self.db_dir, '.secret')
        try:
            f = open(p, 'rb')
        except FileNotFoundError:
            self._create_secret(p, passphrase)
            f = open(p, 'rb')
        # The context manager closes the file even if decryption raises.
        with f:
            secret = self.gpg.decrypt_file(f, passphrase=passphrase)
        if not secret.ok:
            raise AuthError('could not decrypt encryption secret. wrong password?')
        self.secret = secret.data
        self.__passphrase = passphrase

    def _create_secret(self, path, passphrase):
        """Derive sha256(key bytes + passphrase), GPG-encrypt it to auth_key, and write it to path."""
        h = hashlib.sha256()
        h.update(bytes.fromhex(self.auth_key))
        h.update(passphrase.encode('utf-8'))
        secret = self.gpg.encrypt(h.digest(), [self.auth_key], always_trust=True)
        if not secret.ok:
            raise AuthError('could not encrypt secret for {}'.format(self.auth_key))

        d = os.path.dirname(path)
        # If db_dir is '', dirname is '' and os.makedirs('') would fail.
        if d:
            os.makedirs(d, exist_ok=True)
        with open(path, 'wb') as f:
            f.write(secret.data)

    def get_passphrase(self):
        """Return the passphrase last accepted by get_secret(), or None if get_secret() has not succeeded yet."""
        return self.__passphrase

    def fingerprint(self):
        """Return the key identifier this instance was constructed with."""
        return self.auth_key

    def sign(self, plaintext, encoding, passphrase='', detach=True):
        """Sign ``plaintext`` with GnuPG.

        No ``keyid`` is passed to ``gpg.sign``.

        :param plaintext: Data to sign.
        :param encoding: If ``'base64'``, only the signature bytes are
            returned. Otherwise the full python-gnupg sign result is returned.
            NOTE(review): the bytes are whatever ``gpg.sign`` produced
            (ASCII-armored by default), not necessarily raw base64. Confirm
            this against callers.
        :param passphrase: Passphrase for the signing key.
        :param detach: Whether to produce a detached signature.
        :raises AuthError: If gpg produced no signature data.
        """
        r = self.gpg.sign(plaintext, passphrase=passphrase, detach=detach)
        if not r.data:
            # r.status may be None. Format it instead of concatenating, which
            # would raise TypeError and hide the signing failure.
            raise AuthError('signing failed: {}'.format(r.status))

        if encoding == 'base64':
            return r.data

        return r
