6055f1ac0334f15286ee2d5884863e42b5eb7746
[python_utils.git] / argparse_utils.py
1 #!/usr/bin/python3
2
3 # © Copyright 2021-2022, Scott Gasch
4
5 """Helpers for commandline argument parsing."""
6
7 import argparse
8 import datetime
9 import logging
10 import os
11 from typing import Any
12
13 from overrides import overrides
14
15 # This module is commonly used by others in here and should avoid
16 # taking any unnecessary dependencies back on them.
17
18 logger = logging.getLogger(__name__)
19
20
21 class ActionNoYes(argparse.Action):
22     """An argparse Action that allows for commandline arguments like this:
23
24         cfg.add_argument(
25             '--enable_the_thing',
26             action=ActionNoYes,
27             default=False,
28             help='Should we enable the thing?'
29         )
30
31     This creates cmdline arguments:
32
33         --enable_the_thing
34         --no_enable_the_thing
35
36     These arguments can be used to indicate the inclusion or exclusion of
37     binary exclusive behaviors.
38     """
39
40     def __init__(self, option_strings, dest, default=None, required=False, help=None):
41         if default is None:
42             msg = 'You must provide a default with Yes/No action'
43             logger.critical(msg)
44             raise ValueError(msg)
45         if len(option_strings) != 1:
46             msg = 'Only single argument is allowed with NoYes action'
47             logger.critical(msg)
48             raise ValueError(msg)
49         opt = option_strings[0]
50         if not opt.startswith('--'):
51             msg = 'Yes/No arguments must be prefixed with --'
52             logger.critical(msg)
53             raise ValueError(msg)
54
55         opt = opt[2:]
56         opts = ['--' + opt, '--no_' + opt]
57         super().__init__(
58             opts,
59             dest,
60             nargs=0,
61             const=None,
62             default=default,
63             required=required,
64             help=help,
65         )
66
67     @overrides
68     def __call__(self, parser, namespace, values, option_strings=None):
69         if option_strings.startswith('--no-') or option_strings.startswith('--no_'):
70             setattr(namespace, self.dest, False)
71         else:
72             setattr(namespace, self.dest, True)
73
74
75 def valid_bool(v: Any) -> bool:
76     """
77     If the string is a valid bool, return its value.
78
79     >>> valid_bool(True)
80     True
81
82     >>> valid_bool("true")
83     True
84
85     >>> valid_bool("yes")
86     True
87
88     >>> valid_bool("on")
89     True
90
91     >>> valid_bool("1")
92     True
93
94     >>> valid_bool(12345)
95     Traceback (most recent call last):
96     ...
97     argparse.ArgumentTypeError: 12345
98
99     """
100     if isinstance(v, bool):
101         return v
102     from string_utils import to_bool
103
104     try:
105         return to_bool(v)
106     except Exception as e:
107         raise argparse.ArgumentTypeError(v) from e
108
109
110 def valid_ip(ip: str) -> str:
111     """
112     If the string is a valid IPv4 address, return it.  Otherwise raise
113     an ArgumentTypeError.
114
115     >>> valid_ip("1.2.3.4")
116     '1.2.3.4'
117
118     >>> valid_ip("localhost")
119     Traceback (most recent call last):
120     ...
121     argparse.ArgumentTypeError: localhost is an invalid IP address
122
123     """
124     from string_utils import extract_ip_v4
125
126     s = extract_ip_v4(ip.strip())
127     if s is not None:
128         return s
129     msg = f"{ip} is an invalid IP address"
130     logger.error(msg)
131     raise argparse.ArgumentTypeError(msg)
132
133
134 def valid_mac(mac: str) -> str:
135     """
136     If the string is a valid MAC address, return it.  Otherwise raise
137     an ArgumentTypeError.
138
139     >>> valid_mac('12:23:3A:4F:55:66')
140     '12:23:3A:4F:55:66'
141
142     >>> valid_mac('12-23-3A-4F-55-66')
143     '12-23-3A-4F-55-66'
144
145     >>> valid_mac('big')
146     Traceback (most recent call last):
147     ...
148     argparse.ArgumentTypeError: big is an invalid MAC address
149
150     """
151     from string_utils import extract_mac_address
152
153     s = extract_mac_address(mac)
154     if s is not None:
155         return s
156     msg = f"{mac} is an invalid MAC address"
157     logger.error(msg)
158     raise argparse.ArgumentTypeError(msg)
159
160
161 def valid_percentage(num: str) -> float:
162     """
163     If the string is a valid percentage, return it.  Otherwise raise
164     an ArgumentTypeError.
165
166     >>> valid_percentage("15%")
167     15.0
168
169     >>> valid_percentage('40')
170     40.0
171
172     >>> valid_percentage('115')
173     Traceback (most recent call last):
174     ...
175     argparse.ArgumentTypeError: 115 is an invalid percentage; expected 0 <= n <= 100.0
176
177     """
178     num = num.strip('%')
179     n = float(num)
180     if 0.0 <= n <= 100.0:
181         return n
182     msg = f"{num} is an invalid percentage; expected 0 <= n <= 100.0"
183     logger.error(msg)
184     raise argparse.ArgumentTypeError(msg)
185
186
187 def valid_filename(filename: str) -> str:
188     """
189     If the string is a valid filename, return it.  Otherwise raise
190     an ArgumentTypeError.
191
192     >>> valid_filename('/tmp')
193     '/tmp'
194
195     >>> valid_filename('wfwefwefwefwefwefwefwefwef')
196     Traceback (most recent call last):
197     ...
198     argparse.ArgumentTypeError: wfwefwefwefwefwefwefwefwef was not found and is therefore invalid.
199
200     """
201     s = filename.strip()
202     if os.path.exists(s):
203         return s
204     msg = f"{filename} was not found and is therefore invalid."
205     logger.error(msg)
206     raise argparse.ArgumentTypeError(msg)
207
208
209 def valid_date(txt: str) -> datetime.date:
210     """If the string is a valid date, return it.  Otherwise raise
211     an ArgumentTypeError.
212
213     >>> valid_date('6/5/2021')
214     datetime.date(2021, 6, 5)
215
216     # Note: dates like 'next wednesday' work fine, they are just
217     # hard to test for without knowing when the testcase will be
218     # executed...
219     >>> valid_date('next wednesday') # doctest: +ELLIPSIS
220     -ANYTHING-
221     """
222     from string_utils import to_date
223
224     date = to_date(txt)
225     if date is not None:
226         return date
227     msg = f'Cannot parse argument as a date: {txt}'
228     logger.error(msg)
229     raise argparse.ArgumentTypeError(msg)
230
231
232 def valid_datetime(txt: str) -> datetime.datetime:
233     """If the string is a valid datetime, return it.  Otherwise raise
234     an ArgumentTypeError.
235
236     >>> valid_datetime('6/5/2021 3:01:02')
237     datetime.datetime(2021, 6, 5, 3, 1, 2)
238
239     # Again, these types of expressions work fine but are
240     # difficult to test with doctests because the answer is
241     # relative to the time the doctest is executed.
242     >>> valid_datetime('next christmas at 4:15am') # doctest: +ELLIPSIS
243     -ANYTHING-
244     """
245     from string_utils import to_datetime
246
247     dt = to_datetime(txt)
248     if dt is not None:
249         return dt
250     msg = f'Cannot parse argument as datetime: {txt}'
251     logger.error(msg)
252     raise argparse.ArgumentTypeError(msg)
253
254
255 def valid_duration(txt: str) -> datetime.timedelta:
256     """If the string is a valid time duration, return a
257     datetime.timedelta representing the period of time.  Otherwise
258     maybe raise an ArgumentTypeError or potentially just treat the
259     time window as zero in length.
260
261     >>> valid_duration('3m')
262     datetime.timedelta(seconds=180)
263
264     >>> valid_duration('your mom')
265     datetime.timedelta(0)
266
267     """
268     from datetime_utils import parse_duration
269
270     try:
271         secs = parse_duration(txt)
272         return datetime.timedelta(seconds=secs)
273     except Exception as e:
274         logger.exception(e)
275         raise argparse.ArgumentTypeError(e) from e
276
277
278 if __name__ == '__main__':
279     import doctest
280
281     doctest.ELLIPSIS_MARKER = '-ANYTHING-'
282     doctest.testmod()