commit a40e057fbae58ad1716b601369592eba2aa53956
parent ba2df8529ea3b98ea491a336ccec1606da553cca
Author: nolash <dev@holbrook.no>
Date:   Sun, 27 Jun 2021 09:54:30 +0200
Handle invalid thresholds in numdir
Diffstat:
2 files changed, 21 insertions(+), 3 deletions(-)
diff --git a/hexdir/numeric.py b/hexdir/numeric.py
@@ -9,12 +9,23 @@ from .base import LevelDir
 class NumDir(LevelDir):
 
     def __init__(self, root_path, thresholds=[1000]):
+        thresholds = self.__thresholds_sanity(thresholds)
         super(NumDir, self).__init__(root_path, len(thresholds), 8)
-        fi = os.stat(self.master_file)
         self.thresholds = thresholds
+        fi = os.stat(self.master_file)
         self.entry_length = 8
 
 
+    def __thresholds_sanity(self, thresholds):
+        if len(thresholds) == 0:
+            raise ValueError('thresholds must have at least one value')
+        last_t = thresholds[0]
+        for i in range(len(thresholds) - 1):
+            if thresholds[i+1] > last_t:
+                raise ValueError('thresholds must have diminishing order')
+        return thresholds
+
+
     def to_dirpath(self, n): 
         c = n 
         x = 0
diff --git a/tests/test_numdir.py b/tests/test_numdir.py
@@ -18,14 +18,14 @@ class NumDirTest(unittest.TestCase):
 
     def setUp(self):
         self.dir = tempfile.mkdtemp() 
-        self.numdir = NumDir(os.path.join(self.dir, 'n'), [1000, 100])
-        logg.debug('setup numdir root {}'.format(self.dir))
 
 #    def tearDown(self):
 #        shutil.rmtree(self.dir)
 #        logg.debug('cleaned numdir root {}'.format(self.dir))
 
     def test_path(self):
+        self.numdir = NumDir(os.path.join(self.dir, 'n'), [1000, 100])
+        logg.debug('setup numdir root {}'.format(self.dir))
         path = self.numdir.to_filepath(1337)
         path_parts = []
         logg.debug(path)
@@ -37,6 +37,13 @@ class NumDirTest(unittest.TestCase):
         self.assertEqual(one, '1000')
 
 
+    def test_invalid_thresholds(self):
+        with self.assertRaises(ValueError):
+            self.numdir = NumDir(os.path.join(self.dir, 'n'), [100, 1000])
+
+        with self.assertRaises(ValueError):
+            self.numdir = NumDir(os.path.join(self.dir, 'n'), [])
+
 
 
 if __name__ == '__main__':