commit 798262f00f3af1cd59320ae489274f1cf999d099
parent af8ce95e22407bc22082fffecd3cf125dd833abc
Author: lash <dev@holbrook.no>
Date: Wed, 16 Mar 2022 17:13:05 +0000
State change event emitter
Diffstat:
2 files changed, 39 insertions(+), 1 deletion(-)
diff --git a/shep/state.py b/shep/state.py
@@ -30,7 +30,7 @@ class State:
base_state_name = 'NEW'
- def __init__(self, bits, logger=None, verifier=None, check_alias=True):
+ def __init__(self, bits, logger=None, verifier=None, check_alias=True, event_callback=None):
self.__bits = bits
self.__limit = (1 << bits) - 1
self.__c = 0
@@ -43,6 +43,7 @@ class State:
self.modified_last = {}
self.verifier = verifier
self.check_alias = check_alias
+ self.event_callback = event_callback
@classmethod
@@ -320,6 +321,9 @@ class State:
self.__contents[key] = contents
self.register_modify(key)
+
+ if self.event_callback != None:
+ self.event_callback(key, state)
return state
@@ -369,6 +373,9 @@ class State:
self.register_modify(key)
+ if self.event_callback != None:
+ self.event_callback(key, to_state)
+
return to_state
diff --git a/tests/test_state.py b/tests/test_state.py
@@ -13,6 +13,18 @@ logging.basicConfig(level=logging.DEBUG)
logg = logging.getLogger()
+class MockCallback:
+
+ def __init__(self):
+ self.items = {}
+
+
+ def add(self, k, v):
+ if self.items.get(k) == None:
+ self.items[k] = []
+ self.items[k].append(v)
+
+
class TestState(unittest.TestCase):
def test_key_check(self):
@@ -177,5 +189,24 @@ class TestState(unittest.TestCase):
self.assertGreater(a, b)
+ def test_event_callback(self):
+ cb = MockCallback()
+ states = State(3, event_callback=cb.add)
+ states.add('foo')
+ states.add('bar')
+ states.add('baz')
+ states.alias('xyzzy', states.FOO | states.BAR)
+ states.put('abcd')
+ states.set('abcd', states.FOO)
+ states.set('abcd', states.BAR)
+ states.change('abcd', states.BAZ, states.XYZZY)
+ events = cb.items['abcd']
+ self.assertEqual(len(events), 4)
+ self.assertEqual(events[0], states.NEW)
+ self.assertEqual(events[1], states.FOO)
+ self.assertEqual(events[2], states.XYZZY)
+ self.assertEqual(events[3], states.BAZ)
+
+
if __name__ == '__main__':
unittest.main()