385 lines
10 KiB
Python
385 lines
10 KiB
Python
import argparse
|
|
import inspect
|
|
import logging
|
|
import re
|
|
import sys
|
|
import textwrap
|
|
from typing import Any, Callable, Iterable, List, Optional, Sequence
|
|
|
|
# True on Python 3.9+, where builtin collections can be subscripted
# (e.g. ``list[str]``, PEP 585).
PY39PLUS = sys.version_info >= (3, 9)
|
|
|
|
|
|
def _module_version(func):
|
|
version = None
|
|
for v in "__version__ VERSION version".split():
|
|
version = func.__globals__.get(v)
|
|
if version:
|
|
break
|
|
return version
|
|
|
|
|
|
class _ParagraphPreservingArgParseFormatter(argparse.HelpFormatter):
|
|
def __init__(self, *args, **kwargs):
|
|
super(_ParagraphPreservingArgParseFormatter, self).__init__(*args, **kwargs)
|
|
self._long_break_matcher = argparse._re.compile(r"\n\n+")
|
|
|
|
def _fill_text(self, text, width, indent):
|
|
output = []
|
|
for block in self._long_break_matcher.split(text.strip()):
|
|
output.append(
|
|
textwrap.fill(
|
|
block, width, initial_indent=indent, subsequent_indent=indent
|
|
)
|
|
)
|
|
return "\n\n".join(output + [""])
|
|
|
|
|
|
def _parse_doc(docs):
|
|
"""
|
|
Converts a well-formed docstring into documentation
|
|
to be fed into argparse.
|
|
|
|
See signature_parser for details.
|
|
|
|
shorts: (-k for --keyword -k, or "from" for "frm/from")
|
|
metavars: (FILE for --input=FILE)
|
|
helps: (docs for --keyword: docs)
|
|
description: the stuff before
|
|
epilog: the stuff after
|
|
"""
|
|
|
|
name = "(?:[a-zA-Z][a-zA-Z0-9-_]*)"
|
|
|
|
re_var = re.compile(r"^ *(%s)(?: */(%s))? *:(.*)$" % (name, name))
|
|
re_opt = re.compile(
|
|
r"^ *(?:(-[a-zA-Z0-9]),? +)?--(%s)(?: *=(%s))? *:(.*)$" % (name, name)
|
|
)
|
|
|
|
shorts, metavars, helps, description, epilog = {}, {}, {}, "", ""
|
|
|
|
if docs:
|
|
prev = ""
|
|
for line in docs.split("\n"):
|
|
|
|
line = line.strip()
|
|
|
|
# remove starting ':param'
|
|
if line.startswith(":param"):
|
|
line = line[len(":param") :]
|
|
|
|
# skip ':rtype:' row
|
|
if line.startswith(":rtype:"):
|
|
continue
|
|
|
|
if line.strip() == "----":
|
|
break
|
|
|
|
m = re_var.match(line)
|
|
if m:
|
|
if epilog:
|
|
helps[prev] += epilog.strip()
|
|
epilog = ""
|
|
|
|
if m.group(2):
|
|
shorts[m.group(1)] = m.group(2)
|
|
|
|
helps[m.group(1)] = m.group(3).strip()
|
|
prev = m.group(1)
|
|
previndent = len(line) - len(line.lstrip())
|
|
continue
|
|
|
|
m = re_opt.match(line)
|
|
if m:
|
|
if epilog:
|
|
helps[prev] += epilog.strip()
|
|
epilog = ""
|
|
name = m.group(2).replace("-", "_")
|
|
helps[name] = m.group(4)
|
|
prev = name
|
|
|
|
if m.group(1):
|
|
shorts[name] = m.group(1)
|
|
if m.group(3):
|
|
metavars[name] = m.group(3)
|
|
|
|
previndent = len(line) - len(line.lstrip())
|
|
continue
|
|
|
|
if helps:
|
|
if line.startswith(" " * (previndent + 1)):
|
|
helps[prev] += "\n" + line.strip()
|
|
else:
|
|
epilog += "\n" + line.strip()
|
|
else:
|
|
description += "\n" + line.strip()
|
|
|
|
if line.strip():
|
|
previndent = len(line) - len(line.lstrip())
|
|
|
|
return shorts, metavars, helps, description, epilog
|
|
|
|
|
|
def _listLike(ann, t):
|
|
ret = ann is List[t] or ann is Sequence[t] or ann is Iterable[t]
|
|
if PY39PLUS:
|
|
ret = ret or ann == list[t]
|
|
return ret
|
|
|
|
|
|
def _toStr(x):
|
|
return x
|
|
|
|
|
|
def _toBytes(x):
|
|
return bytes(x, "utf-8")
|
|
|
|
|
|
def _toBool(x):
|
|
return x.strip().lower() not in ["false", "0", "no", ""]
|
|
|
|
|
|
def _useAnnotation(ann, positional=False):
    """
    Translate a type annotation into argparse settings.

    Mappings:

    * Scalar ``str``/``bytes``/``int``/``float``/``complex``/``bool``, their
      ``Optional[...]`` forms and their list-like forms (see ``_listLike``)
      map to the matching converter.
    * ``Any`` maps to ``_toStr``.
    * Any other annotation is used directly as the argparse ``type``
      callable.

    Args:
        ann: the parameter's annotation.
        positional: True when building a positional argument.

    Returns:
        ``(action, type, nargs)``. For list-like annotations, a positional
        argument gets ``nargs="*"`` and an option gets ``action="append"``.
        Otherwise ``action`` is ``"store"`` and ``nargs`` is ``None``.
    """
    # https://stackoverflow.com/questions/48572831/how-to-access-the-type-arguments-of-typing-generic
    # (element type, converter) pairs tried in order.
    element_converters = (
        (str, _toStr),
        (bytes, _toBytes),
        (int, int),
        (float, float),
        (complex, complex),
        (bool, _toBool),
    )

    action = "store"
    converter = ann  # unknown annotations are used as the converter itself
    islist = False

    if ann is Any:
        converter = _toStr
    else:
        for elem_type, elem_converter in element_converters:
            # Optional[...] is compared by equality, not identity: typing only
            # returns the same object for a repeated subscription while its
            # internal cache holds it, and e.g. Union[None, int] equals, but
            # is not the same object as, Optional[int].
            if ann is elem_type or ann == Optional[elem_type]:
                converter = elem_converter
                break
            if _listLike(ann, elem_type):
                converter = elem_converter
                islist = True
                break

    nargs = None
    if islist:
        if positional:
            nargs = "*"
        else:
            action = "append"

    return action, converter, nargs
|
|
|
|
|
|
def _signature_parser(func):
    """
    Build an ``argparse.ArgumentParser`` for *func* from its signature and
    its docstring (parsed by ``_parse_doc``).

    * Positional parameters without a default become required positional
      arguments. A non-dash ``name/alias`` docstring alias, or a documented
      metavar, only changes the displayed name.
    * Positional parameters with a default become ``--long-name`` options.
      ``True``/``False`` defaults become ``store_false``/``store_true`` flags.
      Otherwise the annotation (via ``_useAnnotation``) or the default's type
      picks the action and converter.
    * ``*args`` becomes a trailing ``nargs="*"`` argument stored as
      ``__args``.
    * ``--debug`` is always added. ``--version`` is added when
      ``_module_version(func)`` finds a version.

    Short flags come from the docstring. Otherwise the first letter of a
    multi-character option name is used, if no other option has claimed it.
    Keyword-only parameters are not exposed on the command line.

    Raises:
        ValueError: if *func* accepts ``**kwargs``.
    """
    spec = inspect.getfullargspec(func)
    if spec.varkw:
        raise ValueError("Can't wrap a function with **kwargs")

    arg_names = list(spec.args or [])
    defaults = list(spec.defaults or [])
    varargs = spec.varargs
    annotations = spec.annotations

    # Compulsory positional options
    needed = arg_names[: len(arg_names) - len(defaults)]

    # Optional flag options
    params = arg_names[len(needed):]

    shorts, metavars, helps, description, epilog = _parse_doc(func.__doc__)

    parser = argparse.ArgumentParser(
        description=description,
        epilog=epilog,
        formatter_class=_ParagraphPreservingArgParseFormatter,
    )

    # Special flags go after the real options; each defaults to False.
    version = _module_version(func)
    special_flags = ["debug"]
    helps["debug"] = "set logging level to DEBUG"
    if version:
        special_flags.append("version")
        helps["version"] = "show program's version number and exit"
    params = params + special_flags
    defaults = defaults + [False] * len(special_flags)

    # Reserve explicitly documented short flags (e.g. "-f") up front, so an
    # auto-generated short cannot duplicate one; argparse rejects duplicate
    # option strings.
    used_shorts = {
        alias[1:] for alias in shorts.values() if alias.startswith("-")
    }

    # Optional flag options f(p=1)
    for param, default in zip(params, defaults):
        option_strings = ["--%s" % param.replace("_", "-")]
        metavar = metavars.get(param)
        short = None
        if param in shorts:
            alias = shorts[param]
            if alias.startswith("-"):
                short = alias
            elif metavar is None:
                # A "frm/from"-style alias is a display name, not a flag;
                # argparse would reject it as an option string.
                metavar = alias
        elif param not in special_flags and len(param) > 1:
            first_char = param[0]
            if first_char not in used_shorts:
                used_shorts.add(first_char)
                short = "-" + first_char

        # -h conflicts with 'help'
        if short and short != "-h":
            option_strings.insert(0, short)

        d = {"default": default, "dest": param.replace("-", "_")}

        ann = annotations.get(param)
        if param == "version":
            d["action"] = "version"
            d["version"] = version
        elif default is True:
            d["action"] = "store_false"
        elif default is False:
            d["action"] = "store_true"
        elif ann:
            d["action"], d["type"], _ = _useAnnotation(ann)
        elif isinstance(default, list):
            d["action"] = "append"
            d["type"] = _toStr
        elif isinstance(default, bytes):
            d["action"] = "store"
            d["type"] = _toBytes
        elif isinstance(default, str) or default is None:
            d["action"] = "store"
            d["type"] = _toStr
        else:
            d["action"] = "store"
            d["type"] = type(default)

        if param in helps:
            d["help"] = helps[param]

        # A metavar only applies to actions that consume a value. Passing one
        # to store_true/store_false/version makes argparse raise TypeError.
        if metavar is not None and d["action"] in ("store", "append"):
            d["metavar"] = metavar

        parser.add_argument(*option_strings, **d)

    # Compulsory positional options f(p1, p2)
    for need in needed:
        ann = annotations.get(need)
        d = {"action": "store"}
        if ann:
            d["action"], d["type"], nargs = _useAnnotation(ann, positional=True)
            if nargs:
                d["nargs"] = nargs
        else:
            d["type"] = _toStr

        if need in helps:
            d["help"] = helps[need]

        # Always register under the parameter's own name, so the parsed value
        # is stored under a key the function accepts. An alias only changes
        # the displayed name, as for *args below.
        alias = shorts.get(need)
        if alias is not None and not alias.startswith("-"):
            d["metavar"] = alias
        elif need in metavars:
            d["metavar"] = metavars[need]

        parser.add_argument(need, **d)

    # The trailing arguments f(*args)
    if varargs:
        d = {"action": "store", "type": _toStr, "nargs": "*"}
        if varargs in helps:
            d["help"] = helps[varargs]
        d["metavar"] = shorts.get(varargs, varargs)
        parser.add_argument("__args", **d)

    return parser
|
|
|
|
|
|
def _correct_args(func, kwargs):
|
|
"""
|
|
Convert a dictionary of arguments including __argv into a list
|
|
for passing to the function.
|
|
"""
|
|
args = inspect.getfullargspec(func)[0]
|
|
return [kwargs[arg] for arg in args] + kwargs["__args"]
|
|
|
|
|
|
def entrypoint(func: Callable) -> Callable:
    """
    Decorator that runs *func* as a command-line program when it is applied
    in the ``__main__`` module.

    When the calling namespace has ``__name__ == "__main__"``, the decorator:

    1. builds a parser with ``_signature_parser(func)``;
    2. parses ``sys.argv[1:]``;
    3. calls *func* with the parsed values and returns its result.

    ``--debug`` enables DEBUG logging through ``logging.basicConfig``.
    ``--version`` is handled by argparse itself.

    Otherwise *func* is returned unchanged.
    """
    # Inspect the namespace in which the decorator is being applied.
    frame_locals = sys._getframe(1).f_locals
    if frame_locals.get("__name__") != "__main__":
        return func

    parser = _signature_parser(func)
    kwargs = vars(parser.parse_args(sys.argv[1:]))

    # --version is handled by argparse (it prints the version and exits);
    # drop the placeholder value so it is not passed on to func.
    kwargs.pop("version", None)

    # --debug is added by the wrapper, not by func's signature. Remove it
    # unconditionally, whether or not the flag was given, so it never
    # reaches func as an unexpected keyword argument.
    if kwargs.pop("debug", False):
        logging.basicConfig(
            level=logging.DEBUG,
            format="%(asctime)-6s: %(name)s - %(levelname)s - %(message)s",
        )

    if "__args" in kwargs:
        return func(*_correct_args(func, kwargs))
    return func(**kwargs)
|