Adding more tests, working on the test harness.
[python_utils.git] / unittest_utils.py
#!/usr/bin/env python3

"""Helpers for unittests.  Note that when you import this module we
   automatically wrap unittest.main() with a call to
   bootstrap.initialize so that we get logger config, commandline args,
   logging control, etc.  This works fine but it's a little hacky, so
   caveat emptor.
"""

import contextlib
import functools
import inspect
import logging
import pickle
import random
import statistics
import tempfile
import time
from typing import Callable
import unittest

import bootstrap
import config


logger = logging.getLogger(__name__)
cfg = config.add_commandline_args(
    f'Unit test helpers ({__file__})',
    'Args related to unit test helpers')
cfg.add_argument(
    '--unittests_ignore_perf',
    action='store_true',
    default=False,
    help='Ignore unittest perf regression in @check_method_for_perf_regressions',
)
cfg.add_argument(
    '--unittests_num_perf_samples',
    type=int,
    default=20,
    help='The count of perf timing samples we need to see before blocking slow runs on perf grounds'
)
cfg.add_argument(
    '--unittests_drop_perf_traces',
    type=str,
    nargs=1,
    default=None,
    help='The identifier (e.g. file!test_fixture) for which we should drop all perf data'
)
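
# The flags above are registered with the config module and, per the
# module docstring, presumably parsed when bootstrap.initialize wraps
# unittest.main.  A consuming test (hypothetical file name) might be
# invoked like:
#
#   python my_test.py --unittests_num_perf_samples 30
#   python my_test.py --unittests_ignore_perf
#   python my_test.py --unittests_drop_perf_traces my_test.py!test_foo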


# >>> This is the hacky business, FYI. <<<
unittest.main = bootstrap.initialize(unittest.main)
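
# Minimal sketch of a consuming test module (hypothetical my_test.py);
# because unittest.main is wrapped above, calling it also runs
# bootstrap.initialize (logging config, commandline args, etc.):
#
#   import unittest
#   import unittest_utils
#
#   class MyTest(unittest.TestCase):
#       def test_something(self):
#           self.assertTrue(True)
#
#   if __name__ == '__main__':
#       unittest.main()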


_db = '/home/scott/.python_unittest_performance_db'


def check_method_for_perf_regressions(func: Callable) -> Callable:
    """This is meant to be used on a method in a class that subclasses
    unittest.TestCase.  When thus decorated it will time the execution
    of the code in the method, compare it with a database of
    historical performance, and fail the test with a perf-related
    message if it has become too slow.
    """

    def load_known_test_performance_characteristics():
        with open(_db, 'rb') as f:
            return pickle.load(f)

    def save_known_test_performance_characteristics(perfdb):
        with open(_db, 'wb') as f:
            pickle.dump(perfdb, f, pickle.HIGHEST_PROTOCOL)

    @functools.wraps(func)
    def wrapper_perf_monitor(*args, **kwargs):
        try:
            perfdb = load_known_test_performance_characteristics()
        except Exception as e:
            logger.exception(e)
            logger.warning(f'Unable to load perfdb from {_db}')
            perfdb = {}

        # This is a unique identifier for a test: filepath!function
        logger.debug(f'Watching {func.__name__}\'s performance...')
        func_id = f'{func.__globals__["__file__"]}!{func.__name__}'
        logger.debug(f'Canonical function identifier = {func_id}')

        # cmdline arg to forget perf traces for a function
        drop_id = config.config['unittests_drop_perf_traces']
        if drop_id is not None:
            # nargs=1 produces a single-element list; accept either form.
            if isinstance(drop_id, list):
                drop_id = drop_id[0]
            if drop_id in perfdb:
                perfdb[drop_id] = []

        # Run the wrapped test paying attention to latency.
        start_time = time.perf_counter()
        value = func(*args, **kwargs)
        end_time = time.perf_counter()
        run_time = end_time - start_time
        logger.debug(f'{func.__name__} executed in {run_time:f}s.')

        # Check the db; see if it was unexpectedly slow.
        hist = perfdb.get(func_id, [])
        if len(hist) < config.config['unittests_num_perf_samples']:
            hist.append(run_time)
            logger.debug(
                f'Still establishing a perf baseline for {func.__name__}'
            )
        else:
            stdev = statistics.stdev(hist)
            # hist is persisted sorted, so hist[-1] is the slowest sample.
            limit = hist[-1] + stdev * 3
            logger.debug(
                f'Max acceptable performance for {func.__name__} is {limit:f}s'
            )
            if (
                run_time > limit and
                not config.config['unittests_ignore_perf']
            ):
                msg = f'''{func_id} performance has regressed unacceptably.
{hist[-1]:f}s is the slowest record in {len(hist)} db perf samples.
It just ran in {run_time:f}s which is >3 stdevs slower than the slowest sample.
Here is the current, full db perf timing distribution:

{hist}'''
                slf = args[0]
                logger.error(msg)
                slf.fail(msg)
            else:
                hist.append(run_time)

        # Keep at most unittests_num_perf_samples samples; store them
        # sorted so that hist[-1] is always the slowest.
        n = min(config.config['unittests_num_perf_samples'], len(hist))
        hist = random.sample(hist, n)
        hist.sort()
        perfdb[func_id] = hist
        save_known_test_performance_characteristics(perfdb)
        return value
    return wrapper_perf_monitor
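
# Hypothetical example of decorating a single test method; the first
# --unittests_num_perf_samples runs just build a baseline, after which a
# run slower than (slowest sample + 3 stdevs) fails the test unless
# --unittests_ignore_perf is passed:
#
#   class MyPerfTest(unittest.TestCase):
#       @check_method_for_perf_regressions
#       def test_hot_path(self):
#           do_something_expensive()   # hypothetical workload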


def check_all_methods_for_perf_regressions(prefix='test_'):
    """Class decorator: wrap every method of a unittest.TestCase whose
    name starts with prefix in @check_method_for_perf_regressions."""
    def decorate_the_testcase(cls):
        if issubclass(cls, unittest.TestCase):
            for name, m in inspect.getmembers(cls, inspect.isfunction):
                if name.startswith(prefix):
                    setattr(cls, name, check_method_for_perf_regressions(m))
                    logger.debug(f'Wrapping {cls.__name__}:{name}.')
        return cls
    return decorate_the_testcase
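
# Hypothetical example of the class-level variant; every method whose
# name starts with `prefix` gets wrapped automatically:
#
#   @check_all_methods_for_perf_regressions()
#   class MyPerfTests(unittest.TestCase):
#       def test_one(self):
#           ...
#       def test_two(self):
#           ...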


def breakpoint():
    """Hard code a breakpoint somewhere; drop into pdb.

    Note: this shadows the builtin breakpoint() added in Python 3.7.
    """
    import pdb
    pdb.set_trace()


class RecordStdout(object):
    """
    Record what is emitted to stdout.

    >>> with RecordStdout() as record:
    ...     print("This is a test!")
    >>> print({record().readline()})
    {'This is a test!\\n'}
    """

    def __init__(self) -> None:
        self.destination = tempfile.SpooledTemporaryFile(mode='r+')
        self.recorder = None

    def __enter__(self) -> Callable[[], tempfile.SpooledTemporaryFile]:
        self.recorder = contextlib.redirect_stdout(self.destination)
        self.recorder.__enter__()
        return lambda: self.destination

    def __exit__(self, *args) -> bool:
        self.recorder.__exit__(*args)
        self.destination.seek(0)
        return False


class RecordStderr(object):
    """
    Record what is emitted to stderr.

    >>> import sys
    >>> with RecordStderr() as record:
    ...     print("This is a test!", file=sys.stderr)
    >>> print({record().readline()})
    {'This is a test!\\n'}
    """

    def __init__(self) -> None:
        self.destination = tempfile.SpooledTemporaryFile(mode='r+')
        self.recorder = None

    def __enter__(self) -> Callable[[], tempfile.SpooledTemporaryFile]:
        self.recorder = contextlib.redirect_stderr(self.destination)
        self.recorder.__enter__()
        return lambda: self.destination

    def __exit__(self, *args) -> bool:
        self.recorder.__exit__(*args)
        self.destination.seek(0)
        return False


class RecordMultipleStreams(object):
    """
    Record the output to more than one stream.
    """

    def __init__(self, *files) -> None:
        self.files = [*files]
        self.destination = tempfile.SpooledTemporaryFile(mode='r+')
        self.saved_writes = []

    def __enter__(self) -> Callable[[], tempfile.SpooledTemporaryFile]:
        for f in self.files:
            self.saved_writes.append(f.write)
            f.write = self.destination.write
        return lambda: self.destination

    def __exit__(self, *args) -> bool:
        # Restore in reverse order so each stream gets back its own
        # original .write (saved_writes is popped from the end).
        for f in reversed(self.files):
            f.write = self.saved_writes.pop()
        self.destination.seek(0)
        return False
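
# Sketch of RecordMultipleStreams usage (kept as a comment rather than a
# doctest: it works by reassigning .write on the given stream objects,
# which assumes those objects permit attribute assignment; C-level
# streams like the real sys.stdout may not):
#
#   class FakeStream:                      # hypothetical stream type
#       def write(self, s):
#           pass
#   a, b = FakeStream(), FakeStream()
#   with RecordMultipleStreams(a, b) as record:
#       a.write('hello ')
#       b.write('world\n')
#   print(record().read())                 # -> hello world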


if __name__ == '__main__':
    import doctest
    doctest.testmod()