[python_utils.git] / unittest_utils.py
#!/usr/bin/env python3

"""Helpers for unittests.  Note that when you import this we
automatically wrap unittest.main() with a call to bootstrap.initialize
so that we get logger config, commandline args, logging control,
etc.  This works fine but it's a little hacky, so caveat emptor.
"""

import contextlib
import functools
import inspect
import logging
import pickle
import random
import statistics
import tempfile
import time
from typing import Callable
import unittest

import bootstrap
import config


logger = logging.getLogger(__name__)
cfg = config.add_commandline_args(
    f'Unit Tests ({__file__})',
    'Args related to unittest helpers')
cfg.add_argument(
    '--unittests_ignore_perf',
    action='store_true',
    default=False,
    help='Ignore unittest perf regressions in @check_method_for_perf_regressions',
)
cfg.add_argument(
    '--unittests_num_perf_samples',
    type=int,
    default=20,
    help='The count of perf timing samples we need to see before blocking slow runs on perf grounds',
)
cfg.add_argument(
    '--unittests_drop_perf_traces',
    type=str,
    nargs=1,
    default=None,
    help='The identifier (i.e. file!test_fixture) for which we should drop all perf data',
)


# >>> This is the hacky business, FYI. <<<
unittest.main = bootstrap.initialize(unittest.main)


_db = '/home/scott/.python_unittest_performance_db'
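# The perf db is a pickled dict mapping a test's canonical identifier (a
# 'filepath!function' string; see func_id below) to a sorted list of recent
# wall-clock runtimes in seconds, capped at --unittests_num_perf_samples
# entries.  Illustrative shape only:
#
#   {
#       '/home/scott/tests/foo_test.py!test_bar': [0.012, 0.013, 0.017],
#   }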


def check_method_for_perf_regressions(func: Callable) -> Callable:
    """This is meant to be used on a method in a class that subclasses
    unittest.TestCase.  When thus decorated it will time the execution
    of the code in the method, compare it with a database of
    historical performance, and fail the test with a perf-related
    message if it has become too slow.
    """

    def load_known_test_performance_characteristics():
        with open(_db, 'rb') as f:
            return pickle.load(f)

    def save_known_test_performance_characteristics(perfdb):
        with open(_db, 'wb') as f:
            pickle.dump(perfdb, f, pickle.HIGHEST_PROTOCOL)

    @functools.wraps(func)
    def wrapper_perf_monitor(*args, **kwargs):
        try:
            perfdb = load_known_test_performance_characteristics()
        except Exception as e:
            logger.exception(e)
            logger.warning(f'Unable to load perfdb from {_db}')
            perfdb = {}

        # This is a unique identifier for a test: filepath!function
        logger.debug(f'Watching {func.__name__}\'s performance...')
        func_id = f'{func.__globals__["__file__"]}!{func.__name__}'
        logger.debug(f'Canonical function identifier = {func_id}')

        # cmdline arg to forget perf traces for a function.  The flag is
        # declared with nargs=1, so argparse stores a one-element list.
        drop_id = config.config['unittests_drop_perf_traces']
        if drop_id is not None:
            if isinstance(drop_id, list):
                drop_id = drop_id[0]
            if drop_id in perfdb:
                perfdb[drop_id] = []

        # Run the wrapped test paying attention to latency.
        start_time = time.perf_counter()
        value = func(*args, **kwargs)
        end_time = time.perf_counter()
        run_time = end_time - start_time
        logger.debug(f'{func.__name__} executed in {run_time:f}s.')

        # Check the db; see if it was unexpectedly slow.
        hist = perfdb.get(func_id, [])
        if len(hist) < config.config['unittests_num_perf_samples']:
            hist.append(run_time)
            logger.debug(
                f'Still establishing a perf baseline for {func.__name__}'
            )
        else:
            stdev = statistics.stdev(hist)
            limit = hist[-1] + stdev * 3
            logger.debug(
                f'Max acceptable performance for {func.__name__} is {limit:f}s'
            )
            if (
                run_time > limit and
                not config.config['unittests_ignore_perf']
            ):
                msg = f'''{func_id} performance has regressed unacceptably.
{hist[-1]:f}s is the slowest record in {len(hist)} db perf samples.
It just ran in {run_time:f}s which is >3 stdevs slower than the slowest sample.
Here is the current, full db perf timing distribution:

{hist}'''
                slf = args[0]
                logger.error(msg)
                slf.fail(msg)
            else:
                hist.append(run_time)

        n = min(config.config['unittests_num_perf_samples'], len(hist))
        hist = random.sample(hist, n)
        hist.sort()
        perfdb[func_id] = hist
        save_known_test_performance_characteristics(perfdb)
        return value
    return wrapper_perf_monitor
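
# Example use on a single test method (a sketch; assumes this module is
# imported as uu, and MyTest / test_parse are hypothetical names):
#
#   class MyTest(unittest.TestCase):
#       @uu.check_method_for_perf_regressions
#       def test_parse(self):
#           ...
#
# Pass --unittests_ignore_perf to keep a slow run from failing, or
# --unittests_drop_perf_traces <file!test> to forget its recorded samples.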


def check_all_methods_for_perf_regressions(prefix='test_'):
    """Returns a class decorator that wraps every method whose name
    starts with prefix in check_method_for_perf_regressions."""
    def decorate_the_testcase(cls):
        if issubclass(cls, unittest.TestCase):
            for name, m in inspect.getmembers(cls, inspect.isfunction):
                if name.startswith(prefix):
                    setattr(cls, name, check_method_for_perf_regressions(m))
                    logger.debug(f'Wrapping {cls.__name__}:{name}.')
        return cls
    return decorate_the_testcase
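
# Example use on a whole TestCase (a sketch; MyTest is a hypothetical name).
# Every method whose name starts with the prefix gets wrapped:
#
#   @uu.check_all_methods_for_perf_regressions()
#   class MyTest(unittest.TestCase):
#       def test_one(self):
#           ...
#
#       def test_two(self):
#           ...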


def breakpoint():
    """Hard code a breakpoint somewhere; drop into pdb.  Note that this
    shadows Python 3.7+'s builtin breakpoint() for code that imports it
    from this module."""
    import pdb
    pdb.set_trace()


class RecordStdout(object):
    """Record what is written to stdout, e.g.:

        with uu.RecordStdout() as record:
            print("This is a test!")
        print(record().readline())
    """

    def __init__(self) -> None:
        self.destination = tempfile.SpooledTemporaryFile(mode='r+')
        self.recorder = None

    def __enter__(self) -> Callable[[], tempfile.SpooledTemporaryFile]:
        self.recorder = contextlib.redirect_stdout(self.destination)
        self.recorder.__enter__()
        return lambda: self.destination

    def __exit__(self, *args) -> bool:
        self.recorder.__exit__(*args)
        self.destination.seek(0)
        # Return False so exceptions raised inside the with block
        # (e.g. test assertion failures) are not swallowed.
        return False


class RecordStderr(object):
    """Record what is written to stderr, e.g.:

        with uu.RecordStderr() as record:
            print("This is a test!", file=sys.stderr)
        print(record().readline())
    """

    def __init__(self) -> None:
        self.destination = tempfile.SpooledTemporaryFile(mode='r+')
        self.recorder = None

    def __enter__(self) -> Callable[[], tempfile.SpooledTemporaryFile]:
        self.recorder = contextlib.redirect_stderr(self.destination)
        self.recorder.__enter__()
        return lambda: self.destination

    def __exit__(self, *args) -> bool:
        self.recorder.__exit__(*args)
        self.destination.seek(0)
        # Return False so exceptions raised inside the with block
        # (e.g. test assertion failures) are not swallowed.
        return False


class RecordMultipleStreams(object):
    """Record what is written to several streams at once, e.g.:

        with uu.RecordMultipleStreams(sys.stderr, sys.stdout) as record:
            print("This is a test!")
            print("This is one too.", file=sys.stderr)

        print(record().readlines())
    """

    def __init__(self, *files) -> None:
        self.files = [*files]
        self.destination = tempfile.SpooledTemporaryFile(mode='r+')
        self.saved_writes = []

    def __enter__(self) -> Callable[[], tempfile.SpooledTemporaryFile]:
        for f in self.files:
            self.saved_writes.append(f.write)
            f.write = self.destination.write
        return lambda: self.destination

    def __exit__(self, *args) -> bool:
        # Restore each stream's original write method in order (pop()
        # would pair them up in reverse when recording several streams).
        for f, w in zip(self.files, self.saved_writes):
            f.write = w
        self.destination.seek(0)
        return False
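

# Example of the Record* helpers inside a test (a sketch; OutputTest is a
# hypothetical name):
#
#   import sys
#   import unittest
#   import unittest_utils as uu
#
#   class OutputTest(unittest.TestCase):
#       def test_prints_greeting(self):
#           with uu.RecordStdout() as record:
#               print("hello")
#           self.assertEqual("hello\n", record().readline())
#
#   if __name__ == '__main__':
#       unittest.main()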