Make smart futures avoid polling.
[python_utils.git] / unittest_utils.py
1 #!/usr/bin/env python3
2
3 """Helpers for unittests.  Note that when you import this we
4    automatically wrap unittest.main() with a call to
5    bootstrap.initialize so that we getLogger config, commandline args,
6    logging control, etc... this works fine but it's a little hacky so
7    caveat emptor.
8 """
9
10 import contextlib
11 import functools
12 import inspect
13 import logging
14 import pickle
15 import random
16 import statistics
17 import time
18 import tempfile
19 from typing import Callable
20 import unittest
21
22 import bootstrap
23 import config
24
25
26 logger = logging.getLogger(__name__)
27 cfg = config.add_commandline_args(
28     f'Logging ({__file__})',
29     'Args related to function decorators')
30 cfg.add_argument(
31     '--unittests_ignore_perf',
32     action='store_true',
33     default=False,
34     help='Ignore unittest perf regression in @check_method_for_perf_regressions',
35 )
36 cfg.add_argument(
37     '--unittests_num_perf_samples',
38     type=int,
39     default=20,
40     help='The count of perf timing samples we need to see before blocking slow runs on perf grounds'
41 )
42 cfg.add_argument(
43     '--unittests_drop_perf_traces',
44     type=str,
45     nargs=1,
46     default=None,
47     help='The identifier (i.e. file!test_fixture) for which we should drop all perf data'
48 )
49
50
51 # >>> This is the hacky business, FYI. <<<
52 unittest.main = bootstrap.initialize(unittest.main)
53
54
55 _db = '/home/scott/.python_unittest_performance_db'
56
57
58 def check_method_for_perf_regressions(func: Callable) -> Callable:
59     """
60     This is meant to be used on a method in a class that subclasses
61     unittest.TestCase.  When thus decorated it will time the execution
62     of the code in the method, compare it with a database of
63     historical perfmance, and fail the test with a perf-related
64     message if it has become too slow.
65
66     """
67     def load_known_test_performance_characteristics():
68         with open(_db, 'rb') as f:
69             return pickle.load(f)
70
71     def save_known_test_performance_characteristics(perfdb):
72         with open(_db, 'wb') as f:
73             pickle.dump(perfdb, f, pickle.HIGHEST_PROTOCOL)
74
75     @functools.wraps(func)
76     def wrapper_perf_monitor(*args, **kwargs):
77         try:
78             perfdb = load_known_test_performance_characteristics()
79         except Exception as e:
80             logger.exception(e)
81             logger.warning(f'Unable to load perfdb from {_db}')
82             perfdb = {}
83
84         # This is a unique identifier for a test: filepath!function
85         logger.debug(f'Watching {func.__name__}\'s performance...')
86         func_id = f'{func.__globals__["__file__"]}!{func.__name__}'
87         logger.debug(f'Canonical function identifier = {func_id}')
88
89         # cmdline arg to forget perf traces for function
90         drop_id = config.config['unittests_drop_perf_traces']
91         if drop_id is not None:
92             if drop_id in perfdb:
93                 perfdb[drop_id] = []
94
95         # Run the wrapped test paying attention to latency.
96         start_time = time.perf_counter()
97         value = func(*args, **kwargs)
98         end_time = time.perf_counter()
99         run_time = end_time - start_time
100         logger.debug(f'{func.__name__} executed in {run_time:f}s.')
101
102         # Check the db; see if it was unexpectedly slow.
103         hist = perfdb.get(func_id, [])
104         if len(hist) < config.config['unittests_num_perf_samples']:
105             hist.append(run_time)
106             logger.debug(
107                 f'Still establishing a perf baseline for {func.__name__}'
108             )
109         else:
110             stdev = statistics.stdev(hist)
111             limit = hist[-1] + stdev * 5
112             logger.debug(
113                 f'Max acceptable performace for {func.__name__} is {limit:f}s'
114             )
115             if (
116                 run_time > limit and
117                 not config.config['unittests_ignore_perf']
118             ):
119                 msg = f'''{func_id} performance has regressed unacceptably.
120 {hist[-1]:f}s is the slowest record in {len(hist)} db perf samples.
121 It just ran in {run_time:f}s which is >5 stdevs slower than the slowest sample.
122 Here is the current, full db perf timing distribution:
123
124 '''
125                 for x in hist:
126                     msg += f'{x:f}\n'
127                 logger.error(msg)
128                 slf = args[0]
129                 slf.fail(msg)
130             else:
131                 hist.append(run_time)
132
133         n = min(config.config['unittests_num_perf_samples'], len(hist))
134         hist = random.sample(hist, n)
135         hist.sort()
136         perfdb[func_id] = hist
137         save_known_test_performance_characteristics(perfdb)
138         return value
139     return wrapper_perf_monitor
140
141
142 def check_all_methods_for_perf_regressions(prefix='test_'):
143     def decorate_the_testcase(cls):
144         if issubclass(cls, unittest.TestCase):
145             for name, m in inspect.getmembers(cls, inspect.isfunction):
146                 if name.startswith(prefix):
147                     setattr(cls, name, check_method_for_perf_regressions(m))
148                     logger.debug(f'Wrapping {cls.__name__}:{name}.')
149         return cls
150     return decorate_the_testcase
151
152
153 def breakpoint():
154     """Hard code a breakpoint somewhere; drop into pdb."""
155     import pdb
156     pdb.set_trace()
157
158
159 class RecordStdout(object):
160     """
161     Record what is emitted to stdout.
162
163     >>> with RecordStdout() as record:
164     ...     print("This is a test!")
165     >>> print({record().readline()})
166     {'This is a test!\\n'}
167     """
168
169     def __init__(self) -> None:
170         self.destination = tempfile.SpooledTemporaryFile(mode='r+')
171         self.recorder = None
172
173     def __enter__(self) -> Callable[[], tempfile.SpooledTemporaryFile]:
174         self.recorder = contextlib.redirect_stdout(self.destination)
175         self.recorder.__enter__()
176         return lambda: self.destination
177
178     def __exit__(self, *args) -> bool:
179         self.recorder.__exit__(*args)
180         self.destination.seek(0)
181         return None
182
183
184 class RecordStderr(object):
185     """
186     Record what is emitted to stderr.
187
188     >>> import sys
189     >>> with RecordStderr() as record:
190     ...     print("This is a test!", file=sys.stderr)
191     >>> print({record().readline()})
192     {'This is a test!\\n'}
193     """
194
195     def __init__(self) -> None:
196         self.destination = tempfile.SpooledTemporaryFile(mode='r+')
197         self.recorder = None
198
199     def __enter__(self) -> Callable[[], tempfile.SpooledTemporaryFile]:
200         self.recorder = contextlib.redirect_stderr(self.destination)
201         self.recorder.__enter__()
202         return lambda: self.destination
203
204     def __exit__(self, *args) -> bool:
205         self.recorder.__exit__(*args)
206         self.destination.seek(0)
207         return None
208
209
210 class RecordMultipleStreams(object):
211     """
212     Record the output to more than one stream.
213     """
214
215     def __init__(self, *files) -> None:
216         self.files = [*files]
217         self.destination = tempfile.SpooledTemporaryFile(mode='r+')
218         self.saved_writes = []
219
220     def __enter__(self) -> Callable[[], tempfile.SpooledTemporaryFile]:
221         for f in self.files:
222             self.saved_writes.append(f.write)
223             f.write = self.destination.write
224         return lambda: self.destination
225
226     def __exit__(self, *args) -> bool:
227         for f in self.files:
228             f.write = self.saved_writes.pop()
229         self.destination.seek(0)
230
231
232 if __name__ == '__main__':
233     import doctest
234     doctest.testmod()