commit 2356ebc08f221d87131085941d807f17594ed6fd
parent 8ccc89b4a538bc351fe25edea74158ed06fc0149
Author: lash <dev@holbrook.no>
Date: Thu, 17 Mar 2022 21:36:07 +0000
Pure-only all, faulty peek check, update persist init
Diffstat:
5 files changed, 22 insertions(+), 11 deletions(-)
diff --git a/shep/persist.py b/shep/persist.py
@@ -17,8 +17,8 @@ class PersistedState(State):
:type logger: object
"""
- def __init__(self, factory, bits, logger=None, verifier=None):
- super(PersistedState, self).__init__(bits, logger=logger, verifier=verifier)
+ def __init__(self, factory, bits, logger=None, verifier=None, check_alias=True, event_callback=None):
+ super(PersistedState, self).__init__(bits, logger=logger, verifier=verifier, check_alias=check_alias, event_callback=event_callback)
self.__store_factory = factory
self.__stores = {}
diff --git a/shep/state.py b/shep/state.py
@@ -203,7 +203,7 @@ class State:
self.__set(k, v)
- def all(self):
+ def all(self, pure=False):
"""Return list of all unique atomic and alias states.
:rtype: list of ints
@@ -215,6 +215,10 @@ class State:
continue
if k.upper() != k:
continue
+ if pure:
+ state = self.from_name(k)
+ if not self.__is_pure(state):
+ continue
l.append(k)
l.sort()
return l
@@ -349,7 +353,7 @@ class State:
raise StateItemNotFound(key)
new_state = self.__reverse.get(to_state)
- if new_state == None:
+ if new_state == None and self.check_alias:
raise StateInvalid(to_state)
return self.__move(key, current_state, to_state)
@@ -549,7 +553,7 @@ class State:
state = 1
else:
state <<= 1
- if state > self.__c:
+ if state > self.__limit:
raise StateInvalid('unknown state {}'.format(state))
return state
diff --git a/shep/store/file.py b/shep/store/file.py
@@ -27,7 +27,10 @@ class SimpleFileStore:
"""
fp = os.path.join(self.__path, k)
if contents == None:
- contents = ''
+ if self.__m[1] == 'wb':
+ contents = b''
+ else:
+ contents = ''
f = open(fp, self.__m[1])
f.write(contents)
diff --git a/tests/test_file.py b/tests/test_file.py
@@ -200,6 +200,9 @@ class TestStateReport(unittest.TestCase):
self.states.next('abcd')
self.assertEqual(self.states.state('abcd'), self.states.BAR)
+ self.states.next('abcd')
+ self.assertEqual(self.states.state('abcd'), self.states.BAZ)
+
with self.assertRaises(StateInvalid):
self.states.next('abcd')
@@ -207,7 +210,7 @@ class TestStateReport(unittest.TestCase):
with self.assertRaises(FileNotFoundError):
os.stat(fp)
- fp = os.path.join(self.d, 'BAR', 'abcd')
+ fp = os.path.join(self.d, 'BAZ', 'abcd')
os.stat(fp)
diff --git a/tests/test_state.py b/tests/test_state.py
@@ -48,11 +48,12 @@ class TestState(unittest.TestCase):
def test_limit(self):
- states = State(2)
+ states = State(3)
states.add('foo')
states.add('bar')
+ states.add('baz')
with self.assertRaises(OverflowError):
- states.add('baz')
+ states.add('gaz')
def test_dup(self):
@@ -122,7 +123,7 @@ class TestState(unittest.TestCase):
def test_peek(self):
- states = State(3)
+ states = State(2)
states.add('foo')
states.add('bar')
@@ -135,7 +136,7 @@ class TestState(unittest.TestCase):
states.move('abcd', states.BAR)
with self.assertRaises(StateInvalid):
- self.assertEqual(states.peek('abcd'))
+ states.peek('abcd')
def test_from_name(self):