commit a730ebdfdb33975f51aefe8994afb5e9dac5e627
parent 54d344351aa2dd91355b2d0b7307a9dcf13055fa
Author: Louis Holbrook <accounts-gitlab@holbrook.no>
Date: Fri, 16 Apr 2021 10:49:50 +0000
Associate accounts through transaction adds
Diffstat:
10 files changed, 758 insertions(+), 50 deletions(-)
diff --git a/crypto_account_cache/account.py b/crypto_account_cache/account.py
@@ -0,0 +1,70 @@
+# standard imports
+import uuid
+import os
+import hashlib
+
+# local imports
+from .tag import TagPool
+from .crypto import Salter
+
+
+def sprinkle(data, salt):
+ h = hashlib.new('sha256')
+ h.update(data)
+ h.update(salt)
+ return h.digest()
+
+
+class Account(Salter):
+
+ def __init__(self, account, label=None, tags=[], create_digest=True):
+ if label == None:
+ label = str(account)
+ self.label = label
+ self.account_src = None
+ if create_digest:
+ self.account_src = account
+ self.account = sprinkle(self.account_src, self.salt)
+ else:
+ self.account = account
+ self.tags = TagPool()
+ for tag in tags:
+ self.tags.create(tag)
+
+
+ def connect(self, account):
+ if not isinstance(account, Account):
+ raise TypeError('account must be type crypto_account_cache.account.Account')
+ self.tags.merge(account.tags)
+
+
+ def is_same(self, account):
+ if not isinstance(account, Account):
+ raise TypeError('account must be type crypto_account_cache.account.Account')
+ return self.account == account.account
+
+
+ def is_account(self, account):
+ return sprinkle(account, self.salt) == self.account
+
+
+ def serialize(self):
+ b = self.tags.serialize() + self.account
+ return b
+
+
+ @staticmethod
+ def from_serialized(b, label=None):
+ l = len(b)
+ if l % 32 > 0:
+ raise ValueError('invalid data length; remainder {} of 32'.format(l % 32))
+ if l < 64:
+ raise ValueError('invalid data length; expected minimum 64, got {}'.format(l))
+
+ a = Account(b[-32:], label=label, create_digest=False)
+ a.tags.deserialize(b[:-32])
+ return a
+
+
+ def __str__(self):
+ return '{} [{}]'.format(self.account.hex(), str(self.tags))
diff --git a/crypto_account_cache/cache.py b/crypto_account_cache/cache.py
@@ -1,5 +1,6 @@
# standard imports
import os
+import logging
# external imports
from moolb import Bloom
@@ -7,27 +8,231 @@ from moolb import Bloom
# local imports
from .name import for_label
from .store import FileStore
+from .account import Account
+logg = logging.getLogger().getChild(__name__)
-class CryptoCache:
+
+def to_index(block_height, tx_index):
+ b = block_height.to_bytes(12, 'big')
+ b += tx_index.to_bytes(4, 'big')
+ return b
+
+
+def from_index(b):
+ block_height = int.from_bytes(b[:12], 'big')
+ tx_index = int.from_bytes(b[12:], 'big')
+ return (block_height, tx_index)
+
+
+class CacheBloom:
rounds = 3
+
+ def __init__(self, bits_size):
+ self.bits_size = bits_size
+ self.filter = {
+ 'subject': None,
+ 'object': None,
+ 'cache': None,
+ 'extra': None,
+ }
+
+
+ def reset(self):
+ self.filter['subject'] = Bloom(self.bits_size, CacheBloom.rounds)
+ self.filter['object'] = Bloom(self.bits_size, CacheBloom.rounds)
+ self.filter['cache'] = Bloom(self.bits_size, CacheBloom.rounds)
+ self.filter['extra'] = Bloom(self.bits_size, CacheBloom.rounds)
+
+
+ def add_raw(self, v, label):
+ logg.debug('foo')
+ self.filter[label].add(v)
+
+
+ def serialize(self):
+ if self.filter['subject'] == None:
+ logg.warning('serialize called on uninitialized cache bloom')
+ return b''
+
+ b = self.filter['subject'].to_bytes()
+ b += self.filter['object'].to_bytes()
+ b += self.filter['cache'].to_bytes()
+ b += self.filter['extra'].to_bytes()
+ return b
+
+
+ def deserialize(self, b):
+ byte_size = int(self.bits_size / 8)
+ length_expect = byte_size * 4
+ length_data = len(b)
+ if length_data != length_expect:
+ raise ValueError('data size mismatch; expected {}, got {}'.format(length_expect, length_data))
+
+ cursor = 0
+ self.filter['subject'] = Bloom(self.bits_size, CacheBloom.rounds, default_data=b[cursor:cursor+byte_size])
+
+ cursor += byte_size
+ self.filter['object'] = Bloom(self.bits_size, CacheBloom.rounds, default_data=b[cursor:cursor+byte_size])
+
+ cursor += byte_size
+ self.filter['cache'] = Bloom(self.bits_size, CacheBloom.rounds, default_data=b[cursor:cursor+byte_size])
+
+ cursor += byte_size
+ self.filter['extra'] = Bloom(self.bits_size, CacheBloom.rounds, default_data=b[cursor:cursor+byte_size])
+
+
+ @staticmethod
+ def from_serialized(b):
+ if len(b) % 4 > 0:
+ raise ValueError('invalid data length, remainder {} of 4'.format(len(b) % 32))
+
+ bits_size = int((len(b) * 8) / 4)
+ bloom = CacheBloom(bits_size)
+ bloom.deserialize(b)
+ return bloom
+
+
+ def have(self, data, label):
+ return self.filter[label].check(data)
+
+
+ def have_index(self, block_height, tx_index):
+ b = to_index(block_height, tx_index)
+ if self.have(b, 'cache'):
+ return True
+ return self.have(b, 'extra')
+
+
+ def register(self, accounts, block_height, tx_index):
+ subject_match = False
+ object_match = False
+ for account in accounts:
+ if self.have(account, 'subject'):
+ subject_match = True
+ elif self.have(account, 'object'):
+ object_match = True
+
+ if not subject_match and not object_match:
+ return False
+
+ b = to_index(block_height, tx_index)
+ if subject_match:
+ self.add_raw(b, 'cache')
+ if object_match:
+ self.add_raw(b, 'extra')
+
+ return True
- def __init__(self, chain_spec, bits_size, bloom=None):
- if bloom == None:
- bloom = CryptoCache.__bloom_for_size(bits_size)
- self.chain_spec = chain_spec
- self.bloom = bloom
+
+class Cache:
+
+ def __init__(self, bits_size, store=None, cache_bloom=None):
self.bits_size = bits_size
- self.store = None
+ self.store = store
+
+ if cache_bloom == None:
+ cache_bloom = CacheBloom(bits_size)
+ cache_bloom.reset()
+
+ self.cache_bloom = cache_bloom
+ self.subjects = {}
+ self.objects = {}
+
+ self.first_block_height = -1
+ self.first_tx_index = 0
+ self.last_block_height = 0
+ self.last_tx_index = 0
+
+
+ def serialize(self):
+ if self.first_block_height < 0:
+ raise AttributeError('no content to serialize')
+
+ b = to_index(self.first_block_height, self.first_tx_index)
+ b += to_index(self.last_block_height, self.last_tx_index)
+ bb = self.cache_bloom.serialize()
+ return bb + b
+
+
+ @staticmethod
+ def from_serialized(b):
+ cursor = len(b)-32
+ bloom = CacheBloom.from_serialized(b[:cursor])
+ c = Cache(bloom.bits_size, cache_bloom=bloom)
+
+ (c.first_block_height, c.first_tx_index) = from_index(b[cursor:cursor+16])
+ cursor += 16
+ (c.last_block_height, c.last_tx_index) = from_index(b[cursor:cursor+16])
+
+ return c
def set_store(self, store):
self.store = store
- if not store.initd:
- self.store.save(self.bloom.to_bytes())
+ if not store.initd and self.cache_bloom:
+ self.store.save(self.cache_bloom.serialize())
-
- @staticmethod
- def __bloom_for_size(bits_size):
- return Bloom(bits_size, CryptoCache.rounds)
+
+ def divide(self, accounts):
+ subjects = []
+ objects = []
+
+ for account in accounts:
+ if self.cache_bloom.have(account, 'subject'):
+ subject = self.subjects[account]
+ subjects.append(subject)
+ elif self.cache_bloom.have(account, 'object'):
+ objct = self.objects[account]
+ objects.append(objct)
+
+ return (subjects, objects)
+
+
+ def add_subject(self, account):
+ if not isinstance(account, Account):
+ raise TypeError('subject must be type crypto_account_cache.account.Account')
+ self.cache_bloom.add_raw(account.account, 'subject')
+ logg.debug('added subject {}'.format(account))
+ self.subjects[account.account] = account
+
+
+ def add_object(self, account):
+ if not isinstance(account, Account):
+ raise TypeError('subject must be type crypto_account_cache.account.Account')
+ self.cache_bloom.add_raw(account.account, 'object')
+ logg.debug('added object {}'.format(account))
+ self.objects[account.account] = account
+
+
+ def add_tx(self, sender, recipient, block_height, tx_index, relays=[]):
+ accounts = [sender, recipient] + relays
+ match = self.cache_bloom.register(accounts, block_height, tx_index)
+
+ if not match:
+ return False
+
+ if self.first_block_height == -1:
+ self.first_block_height = block_height
+ self.first_tx_index = tx_index
+ self.last_block_height = block_height
+ self.last_tx_index = tx_index
+
+ logg.info('match in {}:{}'.format(block_height, tx_index))
+
+ # TODO: watch out, this currently scales geometrically
+ (subjects, objects) = self.divide(accounts)
+ for subject in subjects:
+ for objct in objects:
+ subject.connect(objct)
+ for other_subject in subjects:
+ if subject.is_same(other_subject):
+ continue
+ subject.connect(other_subject)
+
+ return True
+
+
+ def have(self, block_height, tx_index):
+ return self.cache_bloom.have_index(block_height, tx_index)
diff --git a/crypto_account_cache/store.py b/crypto_account_cache/store.py
@@ -0,0 +1,17 @@
+class FileStore:
+
+ def __init__(self, path):
+ self.path = path
+ self.initd = False
+
+
+ def save(self, data):
+ f = open(self.path, 'wb')
+
+ l = len(data)
+ c = 0
+ while c < l:
+ c += f.write(data[c:])
+ f.close()
+
+ self.initd = True
diff --git a/crypto_account_cache/tag.py b/crypto_account_cache/tag.py
@@ -0,0 +1,92 @@
+# standard imports
+import hashlib
+import logging
+
+logg = logging.getLogger().getChild(__name__)
+
+
+class TagPool:
+
+ def __init__(self):
+ self.tags = []
+ self.tag_values = {}
+ self.sum = b'\x00' * 32
+ self.dirty = False
+
+
+ def get(self):
+ if self.dirty:
+ self.tags.sort()
+ h = hashlib.new('sha256')
+ for tag in self.tags:
+ h.update(tag)
+ self.sum = h.digest()
+ return self.sum
+
+
+ def add(self, tag, value=None):
+ if tag in self.tags:
+ return False
+ self.tags.append(tag)
+ self.tag_values[tag] = value
+ self.dirty = True
+ return True
+
+
+ def create(self, value):
+ h = hashlib.new('sha256')
+ h.update(value)
+ tag = h.digest()
+ self.add(tag, value)
+ return tag
+
+
+ def merge(self, tags):
+ if not isinstance(tags, TagPool):
+ raise TypeError('tags must be type crypto_account_type.tag.TagPool')
+ for tag in tags.tags:
+ self.add(tag)
+ self.tag_values[tag] = tags.tag_values[tag]
+
+ for tag in self.tags:
+ tags.add(tag)
+ tags.tag_values[tag] = self.tag_values[tag]
+
+
+ def serialize(self):
+ b = self.get()
+ for tag in self.tags:
+ b += tag
+ return b
+
+
+ def deserialize(self, b):
+ if len(b) % 32 > 0:
+ raise ValueError('invalid data length; remainder {} from 32'.format(len(b) % 32))
+ cursor = 32
+ z = b[:cursor]
+
+ for i in range(cursor, len(b), 32):
+ tag = b[i:i+32]
+ logg.debug('deserialize add {}'.format(tag))
+ self.add(tag)
+
+ zz = self.get()
+ if z != zz:
+ raise ValueError('data sum does not match content; expected {}, found {}'.format(zz.hex(), z.hex()))
+
+
+ def __str__(self):
+ tag_list = []
+ for tag in self.tags:
+ v = self.tag_values[tag]
+ if v == None:
+ v = tag.hex()
+ else:
+ try:
+ v = v.decode('utf-8')
+ except UnicodeDecodeError:
+ v = v.hex()
+ tag_list.append(v)
+ tag_list.sort()
+ return ','.join(tag_list)
diff --git a/tests/base.py b/tests/base.py
@@ -5,6 +5,10 @@ import logging
import tempfile
import shutil
+# local imports
+from crypto_account_cache.account import Account
+
+
BLOOM_BITS = 1024 * 1024 * 8
script_dir = os.path.realpath(os.path.dirname(__file__))
@@ -17,12 +21,16 @@ class TestBase(unittest.TestCase):
def setUp(self):
os.makedirs(data_dir, exist_ok=True)
- self.salt = os.urandom(32)
+ self.salt = Account.salt
self.session_data_dir = tempfile.mkdtemp(dir=data_dir)
self.bits_size = BLOOM_BITS
self.bytes_size = (BLOOM_BITS - 1) / 8 + 1
-
+ self.alice = Account(os.urandom(20), label='alice', tags=[b'inky'])
+ self.bob = Account(os.urandom(20), label='bob', tags=[b'pinky'])
+ self.eve = Account(os.urandom(20), label='eve', tags=[b'blinky'])
+ self.mallory = Account(os.urandom(20), label='mallory', tags=[b'sue'])
+
def tearDown(self):
shutil.rmtree(self.session_data_dir)
diff --git a/tests/test_account.py b/tests/test_account.py
@@ -0,0 +1,44 @@
+# standard imports
+import unittest
+import copy
+
+# local imports
+from crypto_account_cache.account import Account
+from crypto_account_cache.tag import TagPool
+
+# test imports
+from tests.base import TestBase
+
+
+class TestAccount(TestBase):
+
+ def test_account_compare(self):
+ alice = self.alice
+ alice_again = copy.copy(self.alice)
+ self.assertTrue(alice.is_same(alice_again))
+
+ alice_alias = Account(self.alice.account_src)
+ self.assertTrue(alice.is_same(alice_alias))
+
+ self.assertFalse(alice.is_same(self.bob))
+
+
+ def test_connect_accounts(self):
+ self.alice.connect(self.bob)
+ self.assertEqual(self.alice.tags.get(), self.bob.tags.get())
+
+
+
+ def test_serialize(self):
+ self.alice.tags.create(b'xyzzy')
+ z = self.alice.tags.get()
+
+ b = self.alice.serialize()
+ self.assertEqual(b[len(b)-32:], self.alice.account)
+
+ new_alice = Account.from_serialized(b)
+ self.assertTrue(new_alice.is_same(self.alice))
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/tests/test_basic.py b/tests/test_basic.py
@@ -1,35 +0,0 @@
-# standard imports
-import unittest
-import os
-
-# external imports
-from chainlib.chain import ChainSpec
-
-# local imports
-from crypto_account_cache.name import for_label
-from crypto_account_cache.cache import CryptoCache
-from crypto_account_cache.store import FileStore
-
-# test imports
-from tests.base import TestBase
-
-class TestBasic(TestBase):
-
- def setUp(self):
- super(TestBasic, self).setUp()
-
- self.chain_spec = ChainSpec('foo', 'bar', 42, 'baz')
- self.account = os.urandom(20)
-
- self.filename = for_label(self.chain_spec, self.account, self.salt)
- self.filepath = os.path.join(self.session_data_dir, self.filename)
- self.store = FileStore(self.filepath)
-
- def test_create_cache(self):
- cache = CryptoCache(self.chain_spec, self.bits_size)
- cache.set_store(self.store)
-
-
-
-if __name__ == '__main__':
- unittest.main()
diff --git a/tests/test_bloom.py b/tests/test_bloom.py
@@ -0,0 +1,111 @@
+# standard imports
+import os
+import unittest
+import copy
+
+# local imports
+from crypto_account_cache.cache import (
+ CacheBloom,
+ to_index,
+ )
+
+
+class TestBloom(unittest.TestCase):
+
+ def setUp(self):
+ self.size = 1024
+ self.bloom = CacheBloom(self.size)
+ self.bloom.reset()
+ self.alice = os.urandom(20)
+ self.bob = os.urandom(20)
+
+ self.bloom.add_raw(self.alice, 'subject')
+ self.bloom.add_raw(self.bob, 'object')
+
+
+ def reset_with_accounts(self):
+ self.bloom.reset()
+ self.bloom.add_raw(self.alice, 'subject')
+ self.bloom.add_raw(self.bob, 'object')
+
+
+ def test_bloom(self):
+
+ orig_serial = self.bloom.serialize()
+
+ self.bloom.add_raw(b'\x01', 'subject')
+ self.bloom.add_raw(b'\x01', 'object')
+ self.bloom.add_raw(b'\x01', 'cache')
+ self.bloom.add_raw(b'\x01', 'extra')
+
+ b = self.bloom.serialize()
+ byte_size = int(1024 / 8)
+ self.assertNotEqual(orig_serial, self.bloom.serialize())
+
+ self.reset_with_accounts()
+ self.assertEqual(orig_serial, self.bloom.serialize())
+
+ bloom_recovered = CacheBloom(self.size)
+ self.assertNotEqual(b, bloom_recovered.serialize())
+
+ bloom_recovered.deserialize(b)
+ self.assertEqual(b, bloom_recovered.serialize())
+
+ bloom_recovered = CacheBloom.from_serialized(b)
+ self.assertEqual(b, bloom_recovered.serialize())
+
+
+ def test_bloom_index(self):
+ block_height = 42
+ tx_index = 13
+ index = to_index(42, 13)
+ self.assertEqual(block_height, int.from_bytes(index[:12], 'big'))
+ self.assertEqual(tx_index, int.from_bytes(index[12:], 'big'))
+
+
+ def test_add(self):
+ block_height = 42
+ tx_index = 13
+
+ orig_cache = copy.copy(self.bloom.filter['cache'].to_bytes())
+ orig_extra = copy.copy(self.bloom.filter['extra'].to_bytes())
+
+ r = self.bloom.register([self.alice, self.bob], block_height, tx_index)
+ self.assertTrue(r)
+ self.assertNotEqual(self.bloom.filter['cache'].to_bytes(), orig_cache)
+ self.assertNotEqual(self.bloom.filter['extra'].to_bytes(), orig_extra)
+
+ self.reset_with_accounts()
+ r = self.bloom.register([self.alice], block_height, tx_index)
+ self.assertTrue(r)
+ self.assertNotEqual(self.bloom.filter['cache'].to_bytes(), orig_cache)
+ self.assertEqual(self.bloom.filter['extra'].to_bytes(), orig_extra)
+
+ self.reset_with_accounts()
+ r = self.bloom.register([self.bob], block_height, tx_index)
+ self.assertTrue(r)
+ self.assertEqual(self.bloom.filter['cache'].to_bytes(), orig_cache)
+ self.assertNotEqual(self.bloom.filter['extra'].to_bytes(), orig_extra)
+
+
+ def test_check(self):
+ block_height = 42
+ tx_index = 13
+
+ self.assertFalse(self.bloom.have_index(block_height, tx_index))
+
+ r = self.bloom.register([self.alice], block_height, tx_index)
+ self.assertTrue(self.bloom.have_index(block_height, tx_index))
+
+ r = self.reset_with_accounts()
+ r = self.bloom.register([self.bob], block_height, tx_index)
+ self.assertTrue(self.bloom.have_index(block_height, tx_index))
+
+ r = self.reset_with_accounts()
+ someaccount = os.urandom(20)
+ r = self.bloom.register([someaccount], block_height, tx_index)
+ self.assertFalse(self.bloom.have_index(block_height, tx_index))
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/tests/test_cache.py b/tests/test_cache.py
@@ -0,0 +1,147 @@
+# standard imports
+import unittest
+import os
+import copy
+import logging
+
+# external imports
+from chainlib.chain import ChainSpec
+
+# local imports
+from crypto_account_cache.name import for_label
+from crypto_account_cache.cache import (
+ Cache,
+ from_index,
+ )
+from crypto_account_cache.store import FileStore
+from crypto_account_cache.account import Account
+
+# test imports
+from tests.base import TestBase
+
+logging.basicConfig(level=logging.DEBUG)
+logg = logging.getLogger()
+
+
+class TestBasic(TestBase):
+
+ def setUp(self):
+ super(TestBasic, self).setUp()
+
+ self.chain_spec = ChainSpec('foo', 'bar', 42, 'baz')
+ self.account = os.urandom(20)
+
+ self.filename = for_label(self.chain_spec, self.account, self.salt)
+ self.filepath = os.path.join(self.session_data_dir, self.filename)
+ self.store = FileStore(self.filepath)
+
+
+ def test_divide(self):
+ cache = Cache(self.bits_size, None)
+ cache.add_subject(self.alice)
+ cache.add_subject(self.bob)
+ cache.add_object(self.eve)
+ cache.add_object(self.mallory)
+ someaccount = Account(os.urandom(20))
+
+ (subjects, objects) = cache.divide([self.alice.account, self.bob.account, self.eve.account, self.mallory.account, someaccount.account])
+
+ self.assertTrue(self.alice in subjects)
+ self.assertFalse(self.alice in objects)
+
+ self.assertTrue(self.bob in subjects)
+ self.assertFalse(self.bob in objects)
+
+ self.assertFalse(self.eve in subjects)
+ self.assertTrue(self.eve in objects)
+
+ self.assertFalse(self.mallory in subjects)
+ self.assertTrue(self.mallory in objects)
+
+ self.assertFalse(someaccount in subjects)
+ self.assertFalse(someaccount in objects)
+
+
+ def test_create_cache(self):
+ cache = Cache(self.bits_size)
+ cache.add_subject(self.alice)
+ cache.add_object(self.bob)
+
+ block_height = 42
+ tx_index = 13
+
+ self.assertNotEqual(self.alice.tags.get(), self.bob.tags.get())
+
+ match = cache.add_tx(self.alice.account, self.bob.account, block_height, tx_index)
+ self.assertTrue(match)
+
+ self.assertEqual(self.alice.tags.get(), self.bob.tags.get())
+
+
+ def test_state(self):
+ cache = Cache(self.bits_size)
+
+ cache.add_subject(self.alice)
+ cache.add_subject(self.bob)
+ cache.add_object(self.eve)
+
+ first_block_height = 42
+ first_tx_index = 13
+ match = cache.add_tx(self.alice.account, self.bob.account, first_block_height, first_tx_index)
+
+ new_block_height = 666
+ new_tx_index = 1337
+ match = cache.add_tx(self.alice.account, self.eve.account, new_block_height, new_tx_index)
+
+ cache.first_block_height == first_block_height
+ cache.first_tx_index == first_tx_index
+ cache.last_block_height == new_block_height
+ cache.last_tx_index == new_tx_index
+
+
+ def test_recover(self):
+ cache = Cache(self.bits_size)
+
+ cache.add_subject(self.alice)
+ cache.add_subject(self.bob)
+ cache.add_object(self.eve)
+
+ first_block_height = 42
+ first_tx_index = 13
+ match = cache.add_tx(self.alice.account, self.bob.account, first_block_height, first_tx_index)
+
+ new_block_height = 666
+ new_tx_index = 1337
+ match = cache.add_tx(self.alice.account, self.eve.account, new_block_height, new_tx_index)
+
+ self.assertTrue(cache.have(first_block_height, first_tx_index))
+ self.assertTrue(cache.have(new_block_height, new_tx_index))
+
+ cache = Cache(self.bits_size, cache_bloom=copy.copy(cache.cache_bloom))
+ self.assertTrue(cache.have(first_block_height, first_tx_index))
+ self.assertTrue(cache.have(new_block_height, new_tx_index))
+
+
+ def test_serialize(self):
+ cache = Cache(self.bits_size)
+
+ cache.add_subject(self.alice)
+ cache.add_subject(self.bob)
+ cache.add_object(self.eve)
+
+ first_block_height = 42
+ first_tx_index = 13
+ match = cache.add_tx(self.alice.account, self.bob.account, first_block_height, first_tx_index)
+
+ new_block_height = 666
+ new_tx_index = 1337
+ match = cache.add_tx(self.alice.account, self.eve.account, new_block_height, new_tx_index)
+
+ b = cache.serialize()
+ cache_recovered = cache.from_serialized(b)
+ self.assertEqual(cache_recovered.first_block_height, first_block_height)
+ self.assertEqual(cache_recovered.first_tx_index, first_tx_index)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/tests/test_tag.py b/tests/test_tag.py
@@ -0,0 +1,49 @@
+# standard imports
+import unittest
+import logging
+
+# local imports
+from crypto_account_cache.tag import TagPool
+
+logging.basicConfig(level=logging.DEBUG)
+
+
+class TestTag(unittest.TestCase):
+
+ def test_tag_add(self):
+ tag = TagPool()
+ self.assertEqual(tag.get(), b'\x00' * 32)
+
+ a = tag.create(b'foo')
+ b = tag.create(b'bar')
+ self.assertNotEqual(a, b)
+
+ self.assertFalse(tag.add(a))
+ self.assertFalse(tag.add(b))
+
+ z_one = tag.get()
+
+ tag = TagPool()
+ tag.create(b'foo')
+ tag.create(b'bar')
+
+ z_two = tag.get()
+
+ self.assertEqual(z_one, z_two)
+
+
+ def test_tag_serialize(self):
+ tag = TagPool()
+
+ tag.create(b'foo')
+ tag.create(b'bar')
+
+ s = tag.serialize()
+ self.assertEqual(len(s), 32 * 3)
+
+ tag_recovered = TagPool()
+ tag_recovered.deserialize(s)
+
+
+if __name__ == '__main__':
+ unittest.main()