Scale back warnings.warn and add stacklevels= where appropriate.
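For context, a minimal sketch (not taken from this repository; old_api is a made-up name) of what "add stacklevels=" means in practice: passing stacklevel= to warnings.warn attributes the warning to the caller of the helper that raises it, rather than to the helper itself, which makes the warning actionable for whoever reads the test output.

    import warnings

    def old_api():
        # stacklevel=2 points the reported warning at old_api's caller.
        warnings.warn('old_api() is deprecated', DeprecationWarning, stacklevel=2)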
[python_utils.git] / unittest_utils.py
#!/usr/bin/env python3

"""Helpers for unittests.  Note that when you import this we
   automatically wrap unittest.main() with a call to
   bootstrap.initialize so that we get logger config, commandline
   args, logging control, etc...  This works fine, but it's a little
   hacky, so caveat emptor.
"""

import contextlib
import functools
import inspect
import logging
import pickle
import random
import statistics
import time
import tempfile
from typing import Callable
import unittest

import bootstrap
import config


logger = logging.getLogger(__name__)
cfg = config.add_commandline_args(
    f'Logging ({__file__})',
    'Args related to function decorators')
cfg.add_argument(
    '--unittests_ignore_perf',
    action='store_true',
    default=False,
    help='Ignore unittest perf regression in @check_method_for_perf_regressions',
)
cfg.add_argument(
    '--unittests_num_perf_samples',
    type=int,
    default=20,
    help='The count of perf timing samples we need to see before blocking slow runs on perf grounds'
)
cfg.add_argument(
    '--unittests_drop_perf_traces',
    type=str,
    nargs=1,
    default=None,
    help='The identifier (i.e. file!test_fixture) for which we should drop all perf data'
)


# >>> This is the hacky business, FYI. <<<
unittest.main = bootstrap.initialize(unittest.main)
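# Illustrative sketch of what the wrap above buys a hypothetical test module
# (not part of this file): importing unittest_utils and then calling
# unittest.main() picks up bootstrap's logging configuration and commandline
# argument handling for free.
#
#     import unittest
#
#     import unittest_utils  # side effect: unittest.main is now wrapped
#
#     class TrivialTest(unittest.TestCase):
#         def test_arithmetic(self):
#             self.assertEqual(1 + 1, 2)
#
#     if __name__ == '__main__':
#         unittest.main()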


_db = '/home/scott/.python_unittest_performance_db'
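# The pickle at _db holds a dict mapping a test identifier string
# ('<filepath>!<test_function_name>') to a sorted list of runtimes in
# seconds, capped at --unittests_num_perf_samples entries per test.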


def check_method_for_perf_regressions(func: Callable) -> Callable:
    """
    This is meant to be used on a method in a class that subclasses
    unittest.TestCase.  When thus decorated it will time the execution
    of the code in the method, compare it with a database of
    historical performance, and fail the test with a perf-related
    message if it has become too slow.

    """
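    # Illustrative usage sketch (hypothetical test case, not from this file):
    #
    #     class ParserTest(unittest.TestCase):
    #         @check_method_for_perf_regressions
    #         def test_parse_is_fast_enough(self):
    #             self.assertTrue(parse_big_file())   # parse_big_file is made up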
    def load_known_test_performance_characteristics():
        with open(_db, 'rb') as f:
            return pickle.load(f)

    def save_known_test_performance_characteristics(perfdb):
        with open(_db, 'wb') as f:
            pickle.dump(perfdb, f, pickle.HIGHEST_PROTOCOL)

    @functools.wraps(func)
    def wrapper_perf_monitor(*args, **kwargs):
        try:
            perfdb = load_known_test_performance_characteristics()
        except Exception as e:
            logger.exception(e)
            msg = f'Unable to load perfdb from {_db}'
            logger.warning(msg)
            perfdb = {}

        # This is a unique identifier for a test: filepath!function
        logger.debug(f'Watching {func.__name__}\'s performance...')
        func_id = f'{func.__globals__["__file__"]}!{func.__name__}'
        logger.debug(f'Canonical function identifier = {func_id}')

        # cmdline arg to forget perf traces for function
        drop_id = config.config['unittests_drop_perf_traces']
        if drop_id is not None:
            if drop_id in perfdb:
                perfdb[drop_id] = []

        # Run the wrapped test paying attention to latency.
        start_time = time.perf_counter()
        value = func(*args, **kwargs)
        end_time = time.perf_counter()
        run_time = end_time - start_time
        logger.debug(f'{func.__name__} executed in {run_time:f}s.')

        # Check the db; see if it was unexpectedly slow.
        hist = perfdb.get(func_id, [])
        if len(hist) < config.config['unittests_num_perf_samples']:
            hist.append(run_time)
            logger.debug(
                f'Still establishing a perf baseline for {func.__name__}'
            )
        else:
            stdev = statistics.stdev(hist)
            limit = hist[-1] + stdev * 5
            logger.debug(
                f'Max acceptable performance for {func.__name__} is {limit:f}s'
            )
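            # e.g. if the slowest stored sample is 0.80s and the sample
            # stdev is 0.05s, then limit = 0.80 + 5 * 0.05 = 1.05s and any
            # run slower than that fails the test (unless perf is ignored).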
            if (
                run_time > limit and
                not config.config['unittests_ignore_perf']
            ):
                msg = f'''{func_id} performance has regressed unacceptably.
{hist[-1]:f}s is the slowest record in {len(hist)} db perf samples.
It just ran in {run_time:f}s which is >5 stdevs slower than the slowest sample.
Here is the current, full db perf timing distribution:

'''
                for x in hist:
                    msg += f'{x:f}\n'
                logger.error(msg)
                slf = args[0]
                slf.fail(msg)
            else:
                hist.append(run_time)

        n = min(config.config['unittests_num_perf_samples'], len(hist))
        hist = random.sample(hist, n)
        hist.sort()
        perfdb[func_id] = hist
        save_known_test_performance_characteristics(perfdb)
        return value
    return wrapper_perf_monitor


def check_all_methods_for_perf_regressions(prefix='test_'):
    """Decorate a unittest.TestCase subclass so that every method whose
    name starts with prefix is wrapped by check_method_for_perf_regressions.
    """
    def decorate_the_testcase(cls):
        if issubclass(cls, unittest.TestCase):
            for name, m in inspect.getmembers(cls, inspect.isfunction):
                if name.startswith(prefix):
                    setattr(cls, name, check_method_for_perf_regressions(m))
                    logger.debug(f'Wrapping {cls.__name__}:{name}.')
        return cls
    return decorate_the_testcase
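# Illustrative sketch (hypothetical class, not from this file); note the
# trailing parentheses, since this is a decorator factory:
#
#     @check_all_methods_for_perf_regressions()
#     class StringUtilsTest(unittest.TestCase):
#         def test_upper(self):
#             self.assertEqual('abc'.upper(), 'ABC')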


def breakpoint():
    """Hard code a breakpoint somewhere; drop into pdb."""
    import pdb
    pdb.set_trace()


class RecordStdout(object):
    """
    Record what is emitted to stdout.

    >>> with RecordStdout() as record:
    ...     print("This is a test!")
    >>> print({record().readline()})
    {'This is a test!\\n'}
    """

    def __init__(self) -> None:
        self.destination = tempfile.SpooledTemporaryFile(mode='r+')
        self.recorder = None

    def __enter__(self) -> Callable[[], tempfile.SpooledTemporaryFile]:
        self.recorder = contextlib.redirect_stdout(self.destination)
        self.recorder.__enter__()
        return lambda: self.destination

    def __exit__(self, *args) -> bool:
        self.recorder.__exit__(*args)
        self.destination.seek(0)
        return None


class RecordStderr(object):
    """
    Record what is emitted to stderr.

    >>> import sys
    >>> with RecordStderr() as record:
    ...     print("This is a test!", file=sys.stderr)
    >>> print({record().readline()})
    {'This is a test!\\n'}
    """

    def __init__(self) -> None:
        self.destination = tempfile.SpooledTemporaryFile(mode='r+')
        self.recorder = None

    def __enter__(self) -> Callable[[], tempfile.SpooledTemporaryFile]:
        self.recorder = contextlib.redirect_stderr(self.destination)
        self.recorder.__enter__()
        return lambda: self.destination

    def __exit__(self, *args) -> bool:
        self.recorder.__exit__(*args)
        self.destination.seek(0)
        return None


class RecordMultipleStreams(object):
    """
    Record the output to more than one stream.
    """
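    # Illustrative usage sketch (kept as a comment rather than a doctest,
    # since it only works for streams whose write attribute is assignable):
    #
    #     import sys
    #     with RecordMultipleStreams(sys.stdout, sys.stderr) as record:
    #         print('to stdout')
    #         print('to stderr', file=sys.stderr)
    #     captured = record().readlines()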

    def __init__(self, *files) -> None:
        self.files = [*files]
        self.destination = tempfile.SpooledTemporaryFile(mode='r+')
        self.saved_writes = []

    def __enter__(self) -> Callable[[], tempfile.SpooledTemporaryFile]:
        for f in self.files:
            self.saved_writes.append(f.write)
            f.write = self.destination.write
        return lambda: self.destination

    def __exit__(self, *args) -> bool:
        for f in self.files:
            # Restore the saved write methods in their original order
            # (pop(0), not pop(), so that with multiple streams each one
            # gets its own write method back).
            f.write = self.saved_writes.pop(0)
        self.destination.seek(0)


if __name__ == '__main__':
    import doctest
    doctest.testmod()