test_store.py (2250B)
# standard imports
import unittest
import logging

# local imports
from shep.persist import PersistedState
from shep.error import (
    StateExists,
    StateItemExists,
    StateInvalid,
    StateItemNotFound,
    )

logging.basicConfig(level=logging.DEBUG)
logg = logging.getLogger()


class MockStore:
    """In-memory stand-in for the store objects PersistedState writes through.

    All keys live in one dict, ``v``, so tests can inspect what was written.
    """

    def __init__(self):
        # key -> contents, as last written through put()
        self.v = {}
        # Meant to record which state this store was created for; nothing
        # currently assigns it (the assignment in mockstore_factory is
        # commented out), so it stays 0.
        self.for_state = 0


    def put(self, k, contents=None):
        """Store *contents* under key *k*, overwriting any previous value."""
        self.v[k] = contents


    def remove(self, k):
        """Delete key *k*; raises KeyError if it is not present."""
        del self.v[k]


    def get(self, k):
        """Return the contents stored under *k*; raises KeyError if absent."""
        return self.v[k]


    def list(self):
        """Return all stored keys as a new list."""
        return list(self.v.keys())


class TestStateItems(unittest.TestCase):
    """Tests for PersistedState writing item changes through to a store."""

    def setUp(self):
        self.mockstore = MockStore()

        def mockstore_factory(v):
            # Every state gets this same MockStore instance, so writes for
            # different states are not kept apart.
            #self.mockstore.for_state = v
            return self.mockstore

        self.states = PersistedState(mockstore_factory, 4)
        self.states.add('foo')
        self.states.add('bar')
        self.states.add('baz')
        self.states.alias('xyzzy', self.states.BAZ | self.states.BAR)
        self.states.alias('plugh', self.states.FOO | self.states.BAR)


    def test_persist_new(self):
        """Putting a new item makes it visible in the backing store."""
        item = b'foo'
        self.states.put(item, True)
        self.assertTrue(self.mockstore.v.get(item))


    # This test previously shared its name with the one below, so it was
    # silently shadowed and never ran. It is kept under its own name but
    # skipped, because MockStore never records the state it was created for.
    @unittest.skip('MockStore does not record its target state (for_state is never set)')
    def test_persist_move_for_state(self):
        """Moving an item should reach the store created for the target state."""
        item = b'foo'
        self.states.put(item, self.states.FOO)
        self.states.move(item, self.states.XYZZY)
        self.assertEqual(self.mockstore.for_state, self.states.name(self.states.XYZZY))


    def test_persist_move(self):
        """After a move, the item is no longer present in the shared mock store."""
        item = b'foo'
        self.states.put(item, self.states.FOO, True)
        self.states.move(item, self.states.XYZZY)
        #self.assertEqual(self.mockstore.for_state, self.states.name(self.states.XYZZY))
        # TODO: can't check the add, because the remove (from the old state)
        # happens after it on the same shared mock store; needs a better mock
        # with one store per state.
        self.assertIsNone(self.mockstore.v.get(item))


    def test_persist_move_new(self):
        """Moving an item put without contents also leaves the shared store empty."""
        item = b'foo'
        self.states.put(item)
        self.states.move(item, self.states.BAZ)
        #self.assertEqual(self.mockstore.for_state, self.states.name(self.states.BAZ))
        self.assertIsNone(self.mockstore.v.get(item))



if __name__ == '__main__':
    unittest.main()