Used isort to sort imports. Also added to the git pre-commit hook.
[python_utils.git] / argparse_utils.py
1 #!/usr/bin/python3
2
3 import argparse
4 import datetime
5 import logging
6 import os
7 from typing import Any
8
9 from overrides import overrides
10
11 # This module is commonly used by others in here and should avoid
12 # taking any unnecessary dependencies back on them.
13
14 logger = logging.getLogger(__name__)
15
16
17 class ActionNoYes(argparse.Action):
18     def __init__(self, option_strings, dest, default=None, required=False, help=None):
19         if default is None:
20             msg = 'You must provide a default with Yes/No action'
21             logger.critical(msg)
22             raise ValueError(msg)
23         if len(option_strings) != 1:
24             msg = 'Only single argument is allowed with YesNo action'
25             logger.critical(msg)
26             raise ValueError(msg)
27         opt = option_strings[0]
28         if not opt.startswith('--'):
29             msg = 'Yes/No arguments must be prefixed with --'
30             logger.critical(msg)
31             raise ValueError(msg)
32
33         opt = opt[2:]
34         opts = ['--' + opt, '--no_' + opt]
35         super().__init__(
36             opts,
37             dest,
38             nargs=0,
39             const=None,
40             default=default,
41             required=required,
42             help=help,
43         )
44
45     @overrides
46     def __call__(self, parser, namespace, values, option_strings=None):
47         if option_strings.startswith('--no-') or option_strings.startswith('--no_'):
48             setattr(namespace, self.dest, False)
49         else:
50             setattr(namespace, self.dest, True)
51
52
53 def valid_bool(v: Any) -> bool:
54     """
55     If the string is a valid bool, return its value.
56
57     >>> valid_bool(True)
58     True
59
60     >>> valid_bool("true")
61     True
62
63     >>> valid_bool("yes")
64     True
65
66     >>> valid_bool("on")
67     True
68
69     >>> valid_bool("1")
70     True
71
72     >>> valid_bool(12345)
73     Traceback (most recent call last):
74     ...
75     argparse.ArgumentTypeError: 12345
76
77     """
78     if isinstance(v, bool):
79         return v
80     from string_utils import to_bool
81
82     try:
83         return to_bool(v)
84     except Exception:
85         raise argparse.ArgumentTypeError(v)
86
87
88 def valid_ip(ip: str) -> str:
89     """
90     If the string is a valid IPv4 address, return it.  Otherwise raise
91     an ArgumentTypeError.
92
93     >>> valid_ip("1.2.3.4")
94     '1.2.3.4'
95
96     >>> valid_ip("localhost")
97     Traceback (most recent call last):
98     ...
99     argparse.ArgumentTypeError: localhost is an invalid IP address
100
101     """
102     from string_utils import extract_ip_v4
103
104     s = extract_ip_v4(ip.strip())
105     if s is not None:
106         return s
107     msg = f"{ip} is an invalid IP address"
108     logger.error(msg)
109     raise argparse.ArgumentTypeError(msg)
110
111
112 def valid_mac(mac: str) -> str:
113     """
114     If the string is a valid MAC address, return it.  Otherwise raise
115     an ArgumentTypeError.
116
117     >>> valid_mac('12:23:3A:4F:55:66')
118     '12:23:3A:4F:55:66'
119
120     >>> valid_mac('12-23-3A-4F-55-66')
121     '12-23-3A-4F-55-66'
122
123     >>> valid_mac('big')
124     Traceback (most recent call last):
125     ...
126     argparse.ArgumentTypeError: big is an invalid MAC address
127
128     """
129     from string_utils import extract_mac_address
130
131     s = extract_mac_address(mac)
132     if s is not None:
133         return s
134     msg = f"{mac} is an invalid MAC address"
135     logger.error(msg)
136     raise argparse.ArgumentTypeError(msg)
137
138
139 def valid_percentage(num: str) -> float:
140     """
141     If the string is a valid percentage, return it.  Otherwise raise
142     an ArgumentTypeError.
143
144     >>> valid_percentage("15%")
145     15.0
146
147     >>> valid_percentage('40')
148     40.0
149
150     >>> valid_percentage('115')
151     Traceback (most recent call last):
152     ...
153     argparse.ArgumentTypeError: 115 is an invalid percentage; expected 0 <= n <= 100.0
154
155     """
156     num = num.strip('%')
157     n = float(num)
158     if 0.0 <= n <= 100.0:
159         return n
160     msg = f"{num} is an invalid percentage; expected 0 <= n <= 100.0"
161     logger.error(msg)
162     raise argparse.ArgumentTypeError(msg)
163
164
165 def valid_filename(filename: str) -> str:
166     """
167     If the string is a valid filename, return it.  Otherwise raise
168     an ArgumentTypeError.
169
170     >>> valid_filename('/tmp')
171     '/tmp'
172
173     >>> valid_filename('wfwefwefwefwefwefwefwefwef')
174     Traceback (most recent call last):
175     ...
176     argparse.ArgumentTypeError: wfwefwefwefwefwefwefwefwef was not found and is therefore invalid.
177
178     """
179     s = filename.strip()
180     if os.path.exists(s):
181         return s
182     msg = f"{filename} was not found and is therefore invalid."
183     logger.error(msg)
184     raise argparse.ArgumentTypeError(msg)
185
186
187 def valid_date(txt: str) -> datetime.date:
188     """If the string is a valid date, return it.  Otherwise raise
189     an ArgumentTypeError.
190
191     >>> valid_date('6/5/2021')
192     datetime.date(2021, 6, 5)
193
194     # Note: dates like 'next wednesday' work fine, they are just
195     # hard to test for without knowing when the testcase will be
196     # executed...
197     >>> valid_date('next wednesday') # doctest: +ELLIPSIS
198     -ANYTHING-
199     """
200     from string_utils import to_date
201
202     date = to_date(txt)
203     if date is not None:
204         return date
205     msg = f'Cannot parse argument as a date: {txt}'
206     logger.error(msg)
207     raise argparse.ArgumentTypeError(msg)
208
209
210 def valid_datetime(txt: str) -> datetime.datetime:
211     """If the string is a valid datetime, return it.  Otherwise raise
212     an ArgumentTypeError.
213
214     >>> valid_datetime('6/5/2021 3:01:02')
215     datetime.datetime(2021, 6, 5, 3, 1, 2)
216
217     # Again, these types of expressions work fine but are
218     # difficult to test with doctests because the answer is
219     # relative to the time the doctest is executed.
220     >>> valid_datetime('next christmas at 4:15am') # doctest: +ELLIPSIS
221     -ANYTHING-
222     """
223     from string_utils import to_datetime
224
225     dt = to_datetime(txt)
226     if dt is not None:
227         return dt
228     msg = f'Cannot parse argument as datetime: {txt}'
229     logger.error(msg)
230     raise argparse.ArgumentTypeError(msg)
231
232
233 def valid_duration(txt: str) -> datetime.timedelta:
234     """If the string is a valid time duration, return a
235     datetime.timedelta representing the period of time.  Otherwise
236     maybe raise an ArgumentTypeError or potentially just treat the
237     time window as zero in length.
238
239     >>> valid_duration('3m')
240     datetime.timedelta(seconds=180)
241
242     >>> valid_duration('your mom')
243     datetime.timedelta(0)
244
245     """
246     from datetime_utils import parse_duration
247
248     try:
249         secs = parse_duration(txt)
250     except Exception as e:
251         raise argparse.ArgumentTypeError(e)
252     finally:
253         return datetime.timedelta(seconds=secs)
254
255
256 if __name__ == '__main__':
257     import doctest
258
259     doctest.ELLIPSIS_MARKER = '-ANYTHING-'
260     doctest.testmod()