commit f7ec82c08944ec2b14637968fe01c65d3e87876f
parent 62d322a7029786a19c319917aeec56a2954d0faa
Author: lash <dev@holbrook.no>
Date: Wed, 11 May 2022 08:32:25 +0000
Add bitfield flag specifier, short to long mapping
Diffstat:
5 files changed, 268 insertions(+), 3 deletions(-)
diff --git a/aiee/flag.py b/aiee/flag.py
@@ -0,0 +1,153 @@
+# standard import
+import re
+
+
+re_name = r'^[a-zA-Z_\.]+$'
+re_arg = r'^[a-zA-Z][a-zA-z\-]+$'
+re_dest = r'^[a-zA-Z_]+$'
+
+def to_key(v):
+ if not re.match(re_name, v):
+ raise ValueError('invalid key {}'.format(v))
+ return v.upper()
+
+
+class ArgFlag:
+
+ def __init__(self):
+ self.__pure = []
+ self.__alias = []
+ self.__reverse = {}
+ self.__c = 1
+ self.__all = 0
+
+
+ def val(self, v):
+ if isinstance(v, int):
+ if self.__reverse.get(v) == None:
+ raise ValueError('not a valid flag value: {}'.format(v))
+ return v
+ k = to_key(v)
+ return getattr(self, k)
+
+
+ def add(self, k):
+ k = to_key(k)
+ if getattr(self, k, False):
+ raise ValueError('key exists: {}'.format(k))
+ setattr(self, k, self.__c)
+ self.__pure.append(k)
+ self.__reverse[self.__c] = k
+ self.__c <<= 1
+ self.__all = self.__c - 1
+
+
+ def have_all(self, v):
+ c = 1
+ while c < self.__c:
+ if v & c == 0:
+ raise ValueError('bit {} not found in {}'.format(c, v))
+ c <<= 1
+ return v
+
+
+ def match(self, k, v):
+ k = to_key(k)
+ return getattr(self, k) & v > 0
+
+
+ def alias(self, k, *args):
+ k = to_key(k)
+ if getattr(self, k, False):
+ raise ValueError('key exists: {}'.format(k))
+ r = 0
+ for v in args:
+ r |= self.val(v)
+ r = self.have_all(r)
+ setattr(self, k, r)
+ self.__alias.append(k)
+ self.__reverse[r] = k
+
+
+ def less(self, k, v):
+ k = to_key(k)
+ flags = getattr(self, k)
+ mask = ~(self.__all & v)
+ r = flags & mask
+ return r
+
+
+ def more(self, k, v):
+ k = to_key(k)
+ flags = getattr(self, k)
+ return flags | v
+
+
+ def names(self, flags):
+ flags_debug = []
+ c = 1
+ i = 0
+ while c < self.__c:
+ if flags & c > 0:
+ k = self.__pure[i]
+ flags_debug.append(k)
+ c <<= 1
+ i += 1
+ return flags_debug
+
+
+ def get(self, k):
+ k = to_key(k)
+ v = getattr(self, k)
+ return self.val(v)
+
+
+
+class Arg:
+
+ def __init__(self, flags):
+ self.__flags = flags
+ self.__v = {}
+ self.__long = {}
+ self.__k = []
+ self.__dest = {}
+
+
+ def set(self, k, v, dest=None):
+ if len(k) != 1:
+ raise ValueError('short flag must have length 1, got "{}"'.format(k))
+ v = self.__flags.val(v)
+ self.__v[v] = k
+ self.__k.append(k)
+ if dest != None:
+ if not re.match(re_dest, dest):
+ raise ValueError('invalid destination name: {}'.format(dest))
+ else:
+ dest = k
+ self.__dest[k] = dest
+
+
+ def set_long(self, short, long, dest=None):
+ if len(short) != 1:
+ raise ValueError('short flag must have length 1, got "{}"'.format(short))
+ if not re.match(re_arg, long):
+ raise ValueError('invalid flag name: {}'.format(long))
+ if short not in self.__k:
+ raise ValueError('unknown short flag: {}'.format(long))
+ self.__long[short] = long
+
+ if dest != None:
+ if not re.match(re_dest, dest):
+ raise ValueError('invalid destination name: {}'.format(dest))
+ self.__dest[short] = dest
+
+
+ def get(self, k):
+ k = self.__flags.val(k)
+ short = self.__v[k]
+ long = self.__long.get(short)
+ if long != None:
+ long = '--' + long
+ dest = self.__dest[short]
+ short = '-' + short
+ return (short, long, dest,)
diff --git a/setup.cfg b/setup.cfg
@@ -1,10 +1,10 @@
[metadata]
name = aiee
-version = 0.3.0
+version = 0.3.1
description = Common command line interfacing utils
author = Louis Holbrook
author_email = dev@holbrook.no
-url = https://gitlab.com/chaintools/chainlib
+url = https://git.defalsify.org/aiee
keywords =
cli
classifiers =
diff --git a/tests/test_arg.py b/tests/test_arg.py
@@ -0,0 +1,48 @@
+# standard imports
+import unittest
+
+# local imports
+from aiee.flag import (
+ ArgFlag,
+ Arg,
+ )
+
+
+class TestArg(unittest.TestCase):
+
+ def test_arg(self):
+ flag = ArgFlag()
+ flag.add('foo')
+ flag.add('bar')
+ flag.alias('baz', flag.FOO, flag.BAR)
+
+ arg = Arg(flag)
+ arg.set('x', 'baz')
+ r = arg.get('baz')
+
+ self.assertEqual(r[0], '-x')
+ self.assertIsNone(r[1])
+ self.assertEqual(r[2], 'x')
+
+ arg.set_long('x', 'xyzzy')
+ r = arg.get('baz')
+ self.assertEqual(r[0], '-x')
+ self.assertEqual(r[1], '--xyzzy')
+ self.assertEqual(r[2], 'x')
+
+ arg.set('y', 'foo', dest='yyy')
+ r = arg.get('foo')
+ self.assertEqual(r[0], '-y')
+ self.assertIsNone(r[1])
+ self.assertEqual(r[2], 'yyy')
+
+ arg.set_long('y', 'yy', dest='yyyyy')
+ r = arg.get('foo')
+ self.assertEqual(r[0], '-y')
+ self.assertEqual(r[1], '--yy')
+ self.assertEqual(r[2], 'yyyyy')
+
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/tests/test_flag.py b/tests/test_flag.py
@@ -0,0 +1,64 @@
+# standard imports
+import unittest
+
+# local imports
+from aiee.flag import ArgFlag
+
+
+class TestAieeFlag(unittest.TestCase):
+
+ def test_basic_flag(self):
+ arg = ArgFlag()
+ arg.add('foo')
+ arg.add('bar')
+
+ self.assertEqual(arg.FOO, 1)
+ self.assertEqual(arg.BAR, 2)
+
+ self.assertEqual(arg.get('foo'), arg.FOO)
+ self.assertEqual(arg.get('bar'), arg.BAR)
+
+ arg.alias('baz', arg.FOO, arg.BAR)
+ self.assertEqual(arg.BAZ, 3)
+
+ arg.alias('barbarbar', 'foo', arg.BAR)
+ self.assertEqual(arg.BARBARBAR, 3)
+
+ self.assertTrue(arg.match('foo', arg.FOO))
+ self.assertFalse(arg.match('foo', arg.BAR))
+ self.assertTrue(arg.match('foo', arg.BAZ))
+
+ with self.assertRaises(ValueError):
+ arg.add('foo')
+
+ with self.assertRaises(ValueError):
+ arg.alias('xyzzy', 5)
+
+ self.assertEqual(arg.val('foo'), arg.FOO)
+ self.assertEqual(arg.val(arg.FOO), arg.FOO)
+ self.assertEqual(arg.val(arg.BAZ), arg.BAZ)
+ with self.assertRaises(ValueError):
+ arg.val(4)
+
+
+ def test_name_set(self):
+ arg = ArgFlag()
+ arg.add('foo')
+ arg.add('bar')
+ arg.alias('baz', arg.FOO, arg.BAR)
+
+ self.assertEqual(arg.less('baz', arg.FOO), arg.BAR)
+ self.assertEqual(arg.more('foo', arg.BAR), arg.BAZ)
+
+
+ def test_name_flag(self):
+ arg = ArgFlag()
+ arg.add('foo')
+ arg.add('bar')
+ arg.alias('baz', arg.FOO, arg.BAR)
+
+ self.assertListEqual(arg.names(arg.BAZ), ['FOO', 'BAR'])
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/tests/test_numbers.py b/tests/test_numbers.py
@@ -31,7 +31,7 @@ class TestNumbers(unittest.TestCase):
]):
r = postfix_to_int(s + p)
x = v[1] * (10 ** ((i * 3) - v[2]))
- self.assertEqual(x, r)
+ self.assertEqual(int(x), r)
r = postfix_to_int('42E11')
x = 42 * (10 ** 11)