commit e72293d63a232ccb2c513d40ad7b06cac583a974
parent fb7b82b429a95bdc80846c293a3eed409d25e920
Author: lash <dev@holbrook.no>
Date: Wed, 11 May 2022 11:51:55 +0000
Implement multi arg
Diffstat:
2 files changed, 62 insertions(+), 36 deletions(-)
diff --git a/aiee/flag.py b/aiee/flag.py
@@ -115,11 +115,13 @@ class Arg:
self.__typ = {}
- def set(self, k, v, typ=str, dest=None, **kwargs):
+ def add(self, k, v, typ=str, dest=None, **kwargs):
if len(k) != 1:
raise ValueError('short flag must have length 1, got "{}"'.format(k))
v = self.__flags.val(v)
- self.__v[v] = k
+ if self.__v.get(v) == None:
+ self.__v[v] = []
+ self.__v[v].append(k)
self.__k.append(k)
if dest != None:
if not re.match(re_dest, dest):
@@ -149,14 +151,16 @@ class Arg:
def get(self, k):
k = self.__flags.val(k)
- short = self.__v[k]
- long = self.__long.get(short)
- if long != None:
- long = '--' + long
- dest = self.__dest[short]
- typ = self.__typ[short]
- short = '-' + short
- return (short, long, dest, typ,)
+ r = []
+ for short in self.__v[k]:
+ long = self.__long.get(short)
+ if long != None:
+ long = '--' + long
+ dest = self.__dest[short]
+ typ = self.__typ[short]
+ short = '-' + short
+ r.append((short, long, dest, typ,))
+ return r
def __iter__(self):
@@ -184,18 +188,18 @@ class Arg:
return self.__x[k]
-def process_args(argparser, arg, flags):
- for flag in arg:
- (short, long, dest, typ) = arg.get(flag)
- kw = arg.kwargs(short[1])
- if long == None:
- if kw == None:
- argparser.add_argument(short, type=typ, dest=dest)
- else:
- argparser.add_argument(short, type=typ, dest=dest, **kw)
- else:
- if kw == None:
- argparser.add_argument(short, long, type=typ, dest=dest)
+def process_args(argparser, args, flags):
+ for flag in args:
+ for (short, long, dest, typ,) in args.get(flag):
+ kw = args.kwargs(short[1])
+ if long == None:
+ if kw == None:
+ argparser.add_argument(short, type=typ, dest=dest)
+ else:
+ argparser.add_argument(short, type=typ, dest=dest, **kw)
else:
- argparser.add_argument(short, long, type=typ, dest=dest, **kw)
+ if kw == None:
+ argparser.add_argument(short, long, type=typ, dest=dest)
+ else:
+ argparser.add_argument(short, long, type=typ, dest=dest, **kw)
return argparser
diff --git a/tests/test_arg.py b/tests/test_arg.py
@@ -19,27 +19,27 @@ class TestArg(unittest.TestCase):
flag.alias('baz', flag.FOO, flag.BAR)
arg = Arg(flag)
- arg.set('x', 'baz')
- r = arg.get('baz')
+ arg.add('x', 'baz')
+ r = arg.get('baz')[0]
self.assertEqual(r[0], '-x')
self.assertIsNone(r[1])
self.assertEqual(r[2], 'x')
arg.set_long('x', 'xyzzy')
- r = arg.get('baz')
+ r = arg.get('baz')[0]
self.assertEqual(r[0], '-x')
self.assertEqual(r[1], '--xyzzy')
self.assertEqual(r[2], 'x')
- arg.set('y', 'foo', dest='yyy')
- r = arg.get('foo')
+ arg.add('y', 'foo', dest='yyy')
+ r = arg.get('foo')[0]
self.assertEqual(r[0], '-y')
self.assertIsNone(r[1])
self.assertEqual(r[2], 'yyy')
arg.set_long('y', 'yy', dest='yyyyy')
- r = arg.get('foo')
+ r = arg.get('foo')[0]
self.assertEqual(r[0], '-y')
self.assertEqual(r[1], '--yy')
self.assertEqual(r[2], 'yyyyy')
@@ -50,8 +50,14 @@ class TestArg(unittest.TestCase):
flag.add('foo')
arg = Arg(flag)
- arg.set('x', 'foo')
- arg.set('y', 'foo')
+ arg.add('x', 'foo')
+ arg.add('y', 'foo')
+
+ r = arg.get('foo')
+ self.assertEqual(len(r), 2)
+ self.assertEqual(r[0][0], '-x')
+ self.assertEqual(r[1][0], '-y')
+
def test_arg_iter(self):
flags = ArgFlag()
@@ -66,13 +72,13 @@ class TestArg(unittest.TestCase):
self.assertEqual(len(r), 0)
r = []
- arg.set('y', 'bar')
+ arg.add('y', 'bar')
for flag in arg:
r.append(flag)
self.assertListEqual(r, [flags.BAR])
r = []
- arg.set('x', 'foo')
+ arg.add('x', 'foo')
for flag in arg:
r.append(flag)
self.assertListEqual(r, [flags.FOO, flags.BAR])
@@ -83,7 +89,7 @@ class TestArg(unittest.TestCase):
flags.add('foo')
arg = Arg(flags)
- arg.set('x', 'foo')
+ arg.add('x', 'foo')
argparser = argparse.ArgumentParser()
argparser = process_args(argparser, arg, flags)
@@ -91,13 +97,29 @@ class TestArg(unittest.TestCase):
self.assertEqual(r.x, '13')
+
+ def test_process_argparser_multi(self):
+ flags = ArgFlag()
+ flags.add('foo')
+
+ arg = Arg(flags)
+ arg.add('x', 'foo')
+ arg.add('y', 'foo')
+
+ argparser = argparse.ArgumentParser()
+ argparser = process_args(argparser, arg, flags)
+ r = argparser.parse_args(['-x', '13', '-y', '42'])
+
+ self.assertEqual(r.x, '13')
+ self.assertEqual(r.y, '42')
+
def test_process_argparser_typ(self):
flags = ArgFlag()
flags.add('foo')
arg = Arg(flags)
- arg.set('x', 'foo', typ=int)
+ arg.add('x', 'foo', typ=int)
argparser = argparse.ArgumentParser()
argparser = process_args(argparser, arg, flags)
@@ -112,7 +134,7 @@ class TestArg(unittest.TestCase):
flag.add('foo')
arg = Arg(flag)
- arg.set('x', 'foo', help='foo', me=42)
+ arg.add('x', 'foo', help='foo', me=42)
if __name__ == '__main__':