commit d59f467b1d8400783a0c7537302a609755a98046
parent 4ae495fe324f7ae4d88057fe18c483a0d9ad97bc
Author: lash <dev@holbrook.no>
Date: Wed, 11 May 2022 16:25:47 +0000
Add long only alternative, arg match and val return
Diffstat:
3 files changed, 125 insertions(+), 21 deletions(-)
diff --git a/aiee/arg.py b/aiee/arg.py
@@ -22,6 +22,11 @@ class ArgFlag:
self.__all = 0
+ @property
+ def all(self):
+ return self.__all
+
+
def val(self, v):
if isinstance(v, int):
if self.__reverse.get(v) == None:
@@ -43,11 +48,8 @@ class ArgFlag:
def have_all(self, v):
- c = 1
- while c < self.__c:
- if v & c == 0:
- raise ValueError('bit {} not found in {}'.format(c, v))
- c <<= 1
+ if v & self.__all != v:
+ raise ValueError('missing flag {} in {}'.format(v, self.__all))
return v
@@ -109,19 +111,22 @@ class Arg:
self.__v = {}
self.__long = {}
self.__k = []
+ self.__l = []
self.__dest = {}
self.__x = {}
self.__crsr = 0
self.__typ = {}
+ self.__z = 0
- def add(self, k, v, typ=str, dest=None, **kwargs):
- if len(k) != 1:
+ def add(self, k, v, check=True, typ=str, dest=None, **kwargs):
+ if len(k) != 1 and check:
raise ValueError('short flag must have length 1, got "{}"'.format(k))
v = self.__flags.val(v)
if self.__v.get(v) == None:
self.__v[v] = []
self.__v[v].append(k)
+ self.__z |= v
self.__k.append(k)
if dest != None:
if not re.match(re_dest, dest):
@@ -134,9 +139,22 @@ class Arg:
self.__typ[k] = typ
+ def add_long(self, k, v, typ=str, dest=None, **kwargs):
+ v = self.__flags.val(v)
+ if self.__v.get(v) == None:
+ self.__v[v] = []
+ self.__v[v].append(k)
+ self.__l.append(k)
+ if dest != None:
+ if not re.match(re_dest, dest):
+ raise ValueError('invalid destination name: {}'.format(dest))
+ self.__dest[k] = dest
+
+ self.__x[k] = kwargs
+ self.__typ[k] = typ
+
+
def set_long(self, short, long, dest=None):
- if len(short) != 1:
- raise ValueError('short flag must have length 1, got "{}"'.format(short))
if not re.match(re_arg, long):
raise ValueError('invalid flag name: {}'.format(long))
if short not in self.__k:
@@ -152,17 +170,31 @@ class Arg:
def get(self, k):
k = self.__flags.val(k)
r = []
- for short in self.__v[k]:
- long = self.__long.get(short)
- if long != None:
- long = '--' + long
- dest = self.__dest[short]
- typ = self.__typ[short]
- short = '-' + short
+ for v in self.__v[k]:
+ long = None
+ short = None
+ if v in self.__l:
+ long = '--' + v
+ else:
+ long = self.__long.get(v)
+ if long != None:
+ long = '--' + long
+ short = '-' + v
+ dest = self.__dest[v]
+ typ = self.__typ[v]
r.append((short, long, dest, typ,))
return r
+ def val(self, k):
+ return self.__flags.get(k)
+
+
+ def match(self, v):
+ v = self.__flags.val(v)
+ return v & self.__z == v
+
+
def __iter__(self):
self.__crsr = 1
return self
@@ -196,7 +228,11 @@ def process_args(argparser, args, flags):
for (short, long, dest, typ,) in args.get(flag):
- kw = args.kwargs(short[1])
+ kw = {}
+ try:
+ kw = args.kwargs(short[1:])
+ except TypeError:
+ kw = args.kwargs(long[2:])
if typ == bool:
kw['action'] = 'store_true'
@@ -206,6 +242,9 @@ def process_args(argparser, args, flags):
if long == None:
argparser.add_argument(short, **kw)
+ elif short == None:
+ argparser.add_argument(long, **kw)
else:
argparser.add_argument(short, long, **kw)
+
return argparser
diff --git a/tests/test_arg.py b/tests/test_arg.py
@@ -3,7 +3,7 @@ import unittest
import argparse
# local imports
-from aiee.flag import (
+from aiee.arg import (
ArgFlag,
Arg,
process_args,
@@ -59,6 +59,20 @@ class TestArg(unittest.TestCase):
self.assertEqual(r[1][0], '-y')
+ def test_arg_longonly(self):
+ flag = ArgFlag()
+ flag.add('foo')
+
+ arg = Arg(flag)
+ arg.add('x', 'foo')
+ arg.add_long('yyy', 'foo')
+
+ r = arg.get('foo')
+ self.assertEqual(len(r), 2)
+ self.assertEqual(r[0][0], '-x')
+ self.assertEqual(r[1][1], '--yyy')
+
+
def test_arg_iter(self):
flags = ArgFlag()
flags.add('foo')
@@ -84,6 +98,26 @@ class TestArg(unittest.TestCase):
self.assertListEqual(r, [flags.FOO, flags.BAR])
+ def test_arg_iter_mix(self):
+ flags = ArgFlag()
+ flags.add('foo')
+ flags.add('bar')
+
+ arg = Arg(flags)
+
+ r = []
+ arg.add('y', 'bar')
+ for flag in arg:
+ r.append(flag)
+ self.assertListEqual(r, [flags.BAR])
+
+ r = []
+ arg.add_long('xxx', 'foo')
+ for flag in arg:
+ r.append(flag)
+ self.assertListEqual(r, [flags.FOO, flags.BAR])
+
+
def test_process_argparser(self):
flags = ArgFlag()
flags.add('foo')
@@ -105,13 +139,15 @@ class TestArg(unittest.TestCase):
arg = Arg(flags)
arg.add('x', 'foo')
arg.add('y', 'foo')
+ arg.add_long('zzz', 'foo')
argparser = argparse.ArgumentParser()
argparser = process_args(argparser, arg, flags.FOO)
- r = argparser.parse_args(['-x', '13', '-y', '42'])
+ r = argparser.parse_args(['-x', '13', '-y', '42', '--zzz', '666'])
self.assertEqual(r.x, '13')
self.assertEqual(r.y, '42')
+ self.assertEqual(r.zzz, '666')
def test_process_argparser_multi_alias(self):
@@ -202,7 +238,36 @@ class TestArg(unittest.TestCase):
r = argparser.parse_args(['-x'])
self.assertIsInstance(r.x, bool)
-
+
+
+ def test_val(self):
+ flags = ArgFlag()
+ flags.add('foo')
+
+ args = Arg(flags)
+ args.add('x', 'foo', typ=bool)
+
+ r = args.val('foo')
+ self.assertEqual(r, flags.FOO)
+
+
+ def test_match(self):
+ flags = ArgFlag()
+ flags.add('foo')
+ flags.add('bar')
+ flags.alias('baz', 'foo', 'bar')
+
+ args = Arg(flags)
+ args.add('x', 'foo', typ=bool)
+ self.assertTrue(args.match('foo'))
+ self.assertFalse(args.match('bar'))
+ self.assertFalse(args.match('baz'))
+
+ args.add('y', 'bar', typ=bool)
+ self.assertTrue(args.match('foo'))
+ self.assertTrue(args.match('bar'))
+ self.assertTrue(args.match('baz'))
+
if __name__ == '__main__':
unittest.main()
diff --git a/tests/test_flag.py b/tests/test_flag.py
@@ -2,7 +2,7 @@
import unittest
# local imports
-from aiee.flag import ArgFlag
+from aiee.arg import ArgFlag
class TestAieeFlag(unittest.TestCase):