Start using warnings from stdlib.
[python_utils.git] / unittest_utils.py
1 #!/usr/bin/env python3
2
3 """Helpers for unittests.  Note that when you import this we
4    automatically wrap unittest.main() with a call to
5    bootstrap.initialize so that we getLogger config, commandline args,
6    logging control, etc... this works fine but it's a little hacky so
7    caveat emptor.
8 """
9
10 import contextlib
11 import functools
12 import inspect
13 import logging
14 import pickle
15 import random
16 import statistics
17 import time
18 import tempfile
19 from typing import Callable
20 import unittest
21 import warnings
22
23 import bootstrap
24 import config
25
26
27 logger = logging.getLogger(__name__)
28 cfg = config.add_commandline_args(
29     f'Logging ({__file__})',
30     'Args related to function decorators')
31 cfg.add_argument(
32     '--unittests_ignore_perf',
33     action='store_true',
34     default=False,
35     help='Ignore unittest perf regression in @check_method_for_perf_regressions',
36 )
37 cfg.add_argument(
38     '--unittests_num_perf_samples',
39     type=int,
40     default=20,
41     help='The count of perf timing samples we need to see before blocking slow runs on perf grounds'
42 )
43 cfg.add_argument(
44     '--unittests_drop_perf_traces',
45     type=str,
46     nargs=1,
47     default=None,
48     help='The identifier (i.e. file!test_fixture) for which we should drop all perf data'
49 )
50
51
52 # >>> This is the hacky business, FYI. <<<
53 unittest.main = bootstrap.initialize(unittest.main)
54
55
56 _db = '/home/scott/.python_unittest_performance_db'
57
58
59 def check_method_for_perf_regressions(func: Callable) -> Callable:
60     """
61     This is meant to be used on a method in a class that subclasses
62     unittest.TestCase.  When thus decorated it will time the execution
63     of the code in the method, compare it with a database of
64     historical perfmance, and fail the test with a perf-related
65     message if it has become too slow.
66
67     """
68     def load_known_test_performance_characteristics():
69         with open(_db, 'rb') as f:
70             return pickle.load(f)
71
72     def save_known_test_performance_characteristics(perfdb):
73         with open(_db, 'wb') as f:
74             pickle.dump(perfdb, f, pickle.HIGHEST_PROTOCOL)
75
76     @functools.wraps(func)
77     def wrapper_perf_monitor(*args, **kwargs):
78         try:
79             perfdb = load_known_test_performance_characteristics()
80         except Exception as e:
81             logger.exception(e)
82             msg = f'Unable to load perfdb from {_db}'
83             logger.warning(msg)
84             warnings.warn(msg)
85             perfdb = {}
86
87         # This is a unique identifier for a test: filepath!function
88         logger.debug(f'Watching {func.__name__}\'s performance...')
89         func_id = f'{func.__globals__["__file__"]}!{func.__name__}'
90         logger.debug(f'Canonical function identifier = {func_id}')
91
92         # cmdline arg to forget perf traces for function
93         drop_id = config.config['unittests_drop_perf_traces']
94         if drop_id is not None:
95             if drop_id in perfdb:
96                 perfdb[drop_id] = []
97
98         # Run the wrapped test paying attention to latency.
99         start_time = time.perf_counter()
100         value = func(*args, **kwargs)
101         end_time = time.perf_counter()
102         run_time = end_time - start_time
103         logger.debug(f'{func.__name__} executed in {run_time:f}s.')
104
105         # Check the db; see if it was unexpectedly slow.
106         hist = perfdb.get(func_id, [])
107         if len(hist) < config.config['unittests_num_perf_samples']:
108             hist.append(run_time)
109             logger.debug(
110                 f'Still establishing a perf baseline for {func.__name__}'
111             )
112         else:
113             stdev = statistics.stdev(hist)
114             limit = hist[-1] + stdev * 5
115             logger.debug(
116                 f'Max acceptable performace for {func.__name__} is {limit:f}s'
117             )
118             if (
119                 run_time > limit and
120                 not config.config['unittests_ignore_perf']
121             ):
122                 msg = f'''{func_id} performance has regressed unacceptably.
123 {hist[-1]:f}s is the slowest record in {len(hist)} db perf samples.
124 It just ran in {run_time:f}s which is >5 stdevs slower than the slowest sample.
125 Here is the current, full db perf timing distribution:
126
127 '''
128                 for x in hist:
129                     msg += f'{x:f}\n'
130                 logger.error(msg)
131                 slf = args[0]
132                 slf.fail(msg)
133             else:
134                 hist.append(run_time)
135
136         n = min(config.config['unittests_num_perf_samples'], len(hist))
137         hist = random.sample(hist, n)
138         hist.sort()
139         perfdb[func_id] = hist
140         save_known_test_performance_characteristics(perfdb)
141         return value
142     return wrapper_perf_monitor
143
144
145 def check_all_methods_for_perf_regressions(prefix='test_'):
146     def decorate_the_testcase(cls):
147         if issubclass(cls, unittest.TestCase):
148             for name, m in inspect.getmembers(cls, inspect.isfunction):
149                 if name.startswith(prefix):
150                     setattr(cls, name, check_method_for_perf_regressions(m))
151                     logger.debug(f'Wrapping {cls.__name__}:{name}.')
152         return cls
153     return decorate_the_testcase
154
155
156 def breakpoint():
157     """Hard code a breakpoint somewhere; drop into pdb."""
158     import pdb
159     pdb.set_trace()
160
161
162 class RecordStdout(object):
163     """
164     Record what is emitted to stdout.
165
166     >>> with RecordStdout() as record:
167     ...     print("This is a test!")
168     >>> print({record().readline()})
169     {'This is a test!\\n'}
170     """
171
172     def __init__(self) -> None:
173         self.destination = tempfile.SpooledTemporaryFile(mode='r+')
174         self.recorder = None
175
176     def __enter__(self) -> Callable[[], tempfile.SpooledTemporaryFile]:
177         self.recorder = contextlib.redirect_stdout(self.destination)
178         self.recorder.__enter__()
179         return lambda: self.destination
180
181     def __exit__(self, *args) -> bool:
182         self.recorder.__exit__(*args)
183         self.destination.seek(0)
184         return None
185
186
187 class RecordStderr(object):
188     """
189     Record what is emitted to stderr.
190
191     >>> import sys
192     >>> with RecordStderr() as record:
193     ...     print("This is a test!", file=sys.stderr)
194     >>> print({record().readline()})
195     {'This is a test!\\n'}
196     """
197
198     def __init__(self) -> None:
199         self.destination = tempfile.SpooledTemporaryFile(mode='r+')
200         self.recorder = None
201
202     def __enter__(self) -> Callable[[], tempfile.SpooledTemporaryFile]:
203         self.recorder = contextlib.redirect_stderr(self.destination)
204         self.recorder.__enter__()
205         return lambda: self.destination
206
207     def __exit__(self, *args) -> bool:
208         self.recorder.__exit__(*args)
209         self.destination.seek(0)
210         return None
211
212
213 class RecordMultipleStreams(object):
214     """
215     Record the output to more than one stream.
216     """
217
218     def __init__(self, *files) -> None:
219         self.files = [*files]
220         self.destination = tempfile.SpooledTemporaryFile(mode='r+')
221         self.saved_writes = []
222
223     def __enter__(self) -> Callable[[], tempfile.SpooledTemporaryFile]:
224         for f in self.files:
225             self.saved_writes.append(f.write)
226             f.write = self.destination.write
227         return lambda: self.destination
228
229     def __exit__(self, *args) -> bool:
230         for f in self.files:
231             f.write = self.saved_writes.pop()
232         self.destination.seek(0)
233
234
235 if __name__ == '__main__':
236     import doctest
237     doctest.testmod()