commit 3ffb3b08aa99b34bdbe78c74196ea655c31fb9e0
parent 0eaf032b894ba985c49995e33a4d61cab5e8924e
Author: lash <dev@holbrook.no>
Date: Mon, 31 Jan 2022 12:10:04 +0000
Add generic persistence wrapper
Diffstat:
4 files changed, 144 insertions(+), 5 deletions(-)
diff --git a/shep/persist.py b/shep/persist.py
@@ -0,0 +1,48 @@
+# local imports
+from .state import State
+
+
+class PersistedState(State):
+
+ def __init__(self, factory, bits, logger=None):
+ super(PersistedState, self).__init__(bits, logger=logger)
+ self.__store_factory = factory
+ self.__stores = {}
+
+
+ def __ensure_store(self, k):
+ if self.__stores.get(k) == None:
+ self.__stores[k] = self.__store_factory(k)
+
+
+ def put(self, item, state=None):
+ k = self.name(state)
+ self.__ensure_store(k)
+ self.__stores[k].add(item)
+
+ super(PersistedState, self).put(item, state=state)
+
+
+ def move(self, item, to_state):
+ k_to = self.name(to_state)
+
+ from_state = self.state(item)
+ k_from = self.name(from_state)
+
+ self.__ensure_store(k_to)
+ self.__ensure_store(k_from)
+
+ self.__stores[k_to].add(item)
+ self.__stores[k_from].remove(item)
+
+ super(PersistedState, self).move(item, to_state)
+
+
+ def purge(self, item):
+ state = self.state(item)
+ k = self.name(state)
+
+ self.__ensure_store(k)
+
+ self.__stores[k].remove(item)
+ super(PersistedState, self).purge(item)
diff --git a/shep/state.py b/shep/state.py
@@ -9,13 +9,13 @@ from shep.error import (
class State:
- def __init__(self, bits, logger=None, store_factory=None):
+ def __init__(self, bits, logger=None):
self.__bits = bits
self.__limit = (1 << bits) - 1
self.__c = 0
- self.__reverse = {}
-
self.NEW = 0
+
+ self.__reverse = {0: self.NEW}
self.__items = {self.NEW: []}
self.__items_reverse = {}
@@ -124,6 +124,15 @@ class State:
return l
+ def name(self, v):
+ if v == None:
+ return self.NEW
+ k = self.__reverse.get(v)
+ if k == None:
+ raise StateInvalid(v)
+ return k
+
+
def match(self, v, pure=False):
alias = None
if not pure:
@@ -175,7 +184,6 @@ class State:
current_state_list.pop(idx)
-
def purge(self, item):
current_state = self.__items_reverse.get(item)
if current_state == None:
diff --git a/tests/test_item.py b/tests/test_item.py
@@ -74,6 +74,5 @@ class TestStateItems(unittest.TestCase):
self.states.state(item)
-
if __name__ == '__main__':
unittest.main()
diff --git a/tests/test_store.py b/tests/test_store.py
@@ -0,0 +1,84 @@
+# standard imports
+import unittest
+
+# local imports
+from shep.persist import PersistedState
+from shep.error import (
+ StateExists,
+ StateItemExists,
+ StateInvalid,
+ StateItemNotFound,
+ )
+
+class MockStore:
+
+ def __init__(self):
+ self.v = {}
+ self.for_state = 0
+
+
+ def add(self, k):
+ self.v[k] = True
+
+
+ def remove(self, k):
+ del self.v[k]
+
+
+class TestStateItems(unittest.TestCase):
+
+ def setUp(self):
+ self.mockstore = MockStore()
+
+ def mockstore_factory(v):
+ self.mockstore.for_state = v
+ return self.mockstore
+
+ self.states = PersistedState(mockstore_factory, 4)
+ self.states.add('foo')
+ self.states.add('bar')
+ self.states.add('baz')
+ self.states.alias('xyzzy', self.states.BAZ | self.states.BAR)
+ self.states.alias('plugh', self.states.FOO | self.states.BAR)
+
+
+ def test_persist_new(self):
+ item = b'foo'
+ self.states.put(item)
+ self.assertTrue(self.mockstore.v.get(item))
+
+
+ def test_persist_move(self):
+ item = b'foo'
+ self.states.put(item, self.states.FOO)
+ self.states.move(item, self.states.XYZZY)
+ self.assertEqual(self.mockstore.for_state, self.states.name(self.states.XYZZY))
+
+
+ def test_persist_move(self):
+ item = b'foo'
+ self.states.put(item, self.states.FOO)
+ self.states.move(item, self.states.XYZZY)
+ self.assertEqual(self.mockstore.for_state, self.states.name(self.states.XYZZY))
+ # TODO: cant check the add because remove happens after remove, need better mock
+ self.assertIsNone(self.mockstore.v.get(item))
+
+
+ def test_persist_purge(self):
+ item = b'foo'
+ self.states.put(item, self.states.FOO)
+ self.states.purge(item)
+ self.assertEqual(self.mockstore.for_state, self.states.name(self.states.FOO))
+ self.assertIsNone(self.mockstore.v.get(item))
+
+
+ def test_persist_move_new(self):
+ item = b'foo'
+ self.states.put(item)
+ self.states.move(item, self.states.BAZ)
+ self.assertEqual(self.mockstore.for_state, self.states.name(self.states.BAZ))
+ self.assertIsNone(self.mockstore.v.get(item))
+
+
+if __name__ == '__main__':
+ unittest.main()