#!/usr/bin/env python3

"""Helpers for unittests.  Note that when you import this we
   automatically wrap unittest.main() with a call to
   bootstrap.initialize so that we get logger config, commandline args,
   logging control, etc... this works fine but it's a little hacky so
   caveat emptor.
"""

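# A minimal usage sketch (hedged: the test file, class, and method names
# below are illustrative, not part of this repo):
#
#     import unittest
#     import unittest_utils as uu
#
#     @uu.check_all_methods_for_perf_regressions()
#     class MyTests(unittest.TestCase):
#         def test_something(self):
#             self.assertTrue(True)
#
#     if __name__ == '__main__':
#         unittest.main()    # already wrapped by bootstrap.initialize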
import contextlib
import functools
import inspect
import logging
import pickle
import random
import statistics
import tempfile
import time
from typing import Callable
import unittest

import bootstrap
import config

logger = logging.getLogger(__name__)
cfg = config.add_commandline_args(
    f'Logging ({__file__})',
    'Args related to function decorators')
cfg.add_argument(
    '--unittests_ignore_perf',
    action='store_true',
    default=False,
    help='Ignore unittest perf regression in @check_method_for_perf_regressions',
)
cfg.add_argument(
    '--unittests_num_perf_samples',
    type=int,
    default=20,
    help='The count of perf timing samples we need to see before blocking slow runs on perf grounds'
)
cfg.add_argument(
    '--unittests_drop_perf_traces',
    type=str,
    default=None,
    help='The identifier (i.e. file!test_fixture) for which we should drop all perf data'
)
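# Because unittest.main is wrapped by bootstrap.initialize below, these
# flags can be passed on a test's command line; hedged examples (the test
# file name and identifier are hypothetical):
#
#     python3 sometest.py --unittests_ignore_perf
#     python3 sometest.py --unittests_drop_perf_traces sometest.py!test_foo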


# >>> This is the hacky business, FYI. <<<
unittest.main = bootstrap.initialize(unittest.main)


_db = '/home/scott/.python_unittest_performance_db'


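# The perf "database" at _db is just a pickled dict mapping a test's
# canonical identifier (file!function) to a list of its recent runtimes
# in seconds, e.g. (illustrative values only):
#
#     {'/home/scott/sometest.py!test_foo': [0.012, 0.013, 0.025]}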
def check_method_for_perf_regressions(func: Callable) -> Callable:
    """This is meant to be used on a method in a class that subclasses
    unittest.TestCase.  When thus decorated it will time the execution
    of the code in the method, compare it with a database of
    historical performance, and fail the test with a perf-related
    message if it has become too slow.
    """
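    # A hedged example of decorating a single test method by hand (the
    # class and method names are illustrative):
    #
    #     class MyTests(unittest.TestCase):
    #         @check_method_for_perf_regressions
    #         def test_hot_path(self):
    #             ...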

    def load_known_test_performance_characteristics():
        with open(_db, 'rb') as f:
            return pickle.load(f)

    def save_known_test_performance_characteristics(perfdb):
        with open(_db, 'wb') as f:
            pickle.dump(perfdb, f, pickle.HIGHEST_PROTOCOL)

    @functools.wraps(func)
    def wrapper_perf_monitor(*args, **kwargs):
        try:
            perfdb = load_known_test_performance_characteristics()
        except Exception as e:
            logger.exception(e)
            logger.warning(f'Unable to load perfdb from {_db}')
            perfdb = {}

        # This is a unique identifier for a test: filepath!function
        logger.debug(f'Watching {func.__name__}\'s performance...')
        func_id = f'{func.__globals__["__file__"]}!{func.__name__}'
        logger.debug(f'Canonical function identifier = {func_id}')

        # cmdline arg to forget perf traces for a function
        drop_id = config.config['unittests_drop_perf_traces']
        if drop_id is not None:
            if drop_id in perfdb:
                perfdb[drop_id] = []

        # Run the wrapped test paying attention to latency.
        start_time = time.perf_counter()
        value = func(*args, **kwargs)
        end_time = time.perf_counter()
        run_time = end_time - start_time
        logger.debug(f'{func.__name__} executed in {run_time:f}s.')

        # Check the db; see if it was unexpectedly slow.
        hist = perfdb.get(func_id, [])
        if len(hist) < config.config['unittests_num_perf_samples']:
            hist.append(run_time)
            logger.debug(
                f'Still establishing a perf baseline for {func.__name__}'
            )
        else:
            # hist is persisted sorted (see below) so hist[-1] is the
            # slowest recorded sample; flag runs >3 stdevs beyond it.
            stdev = statistics.stdev(hist)
            limit = hist[-1] + stdev * 3
            logger.debug(
                f'Max acceptable performance for {func.__name__} is {limit:f}s'
            )
            if (
                run_time > limit and
                not config.config['unittests_ignore_perf']
            ):
                msg = f'''{func_id} performance has regressed unacceptably.
{hist[-1]:f}s is the slowest record in {len(hist)} db perf samples.
It just ran in {run_time:f}s which is >3 stdevs slower than the slowest sample.
Here is the current, full db perf timing distribution:

{hist}'''
                slf = args[0]
                logger.error(msg)
                slf.fail(msg)
            else:
                hist.append(run_time)

        # Keep at most the configured number of samples, sorted, and persist.
        n = min(config.config['unittests_num_perf_samples'], len(hist))
        hist = random.sample(hist, n)
        hist.sort()
        perfdb[func_id] = hist
        save_known_test_performance_characteristics(perfdb)
        return value
    return wrapper_perf_monitor


def check_all_methods_for_perf_regressions(prefix='test_'):
    """Decorate a unittest.TestCase subclass so that every method whose
    name starts with prefix gets wrapped by
    check_method_for_perf_regressions.
    """
    def decorate_the_testcase(cls):
        if issubclass(cls, unittest.TestCase):
            for name, m in inspect.getmembers(cls, inspect.isfunction):
                if name.startswith(prefix):
                    setattr(cls, name, check_method_for_perf_regressions(m))
                    logger.debug(f'Wrapping {cls.__name__}:{name}.')
        return cls
    return decorate_the_testcase


def breakpoint():
    """Hard code a breakpoint.  NB: this shadows python's builtin
    breakpoint() within this module."""
    import pdb
    pdb.set_trace()


class RecordStdout(object):
    """Record what is printed to stdout, e.g.:

        with uu.RecordStdout() as record:
            print("This is a test!")
        print(record().readline())
    """

    def __init__(self) -> None:
        self.destination = tempfile.SpooledTemporaryFile(mode='r+')
        self.recorder = None

    def __enter__(self) -> Callable[[], tempfile.SpooledTemporaryFile]:
        self.recorder = contextlib.redirect_stdout(self.destination)
        self.recorder.__enter__()
        return lambda: self.destination

    def __exit__(self, *args) -> bool:
        self.recorder.__exit__(*args)
        self.destination.seek(0)
        return False
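# A hedged example of asserting on recorded output from inside a
# unittest.TestCase (class, method, and expected text are illustrative):
#
#     class OutputTest(unittest.TestCase):
#         def test_banner(self):
#             with RecordStdout() as record:
#                 print("hello")
#             self.assertEqual("hello\n", record().readline())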


class RecordStderr(object):
    """Record what is printed to stderr, e.g.:

        with uu.RecordStderr() as record:
            print("This is a test!", file=sys.stderr)
        print(record().readline())
    """

    def __init__(self) -> None:
        self.destination = tempfile.SpooledTemporaryFile(mode='r+')
        self.recorder = None

    def __enter__(self) -> Callable[[], tempfile.SpooledTemporaryFile]:
        self.recorder = contextlib.redirect_stderr(self.destination)
        self.recorder.__enter__()
        return lambda: self.destination

    def __exit__(self, *args) -> bool:
        self.recorder.__exit__(*args)
        self.destination.seek(0)
        return False


class RecordMultipleStreams(object):
    """Record what is printed to several streams at once, e.g.:

        with uu.RecordMultipleStreams(sys.stderr, sys.stdout) as record:
            print("This is a test!")
            print("This is one too.", file=sys.stderr)

        print(record().readlines())
    """
    def __init__(self, *files) -> None:
        self.files = [*files]
        self.destination = tempfile.SpooledTemporaryFile(mode='r+')
        self.saved_writes = []

    def __enter__(self) -> Callable[[], tempfile.SpooledTemporaryFile]:
        for f in self.files:
            self.saved_writes.append(f.write)
            f.write = self.destination.write
        return lambda: self.destination

    def __exit__(self, *args) -> bool:
        # Restore each stream's original write method; saved_writes is in
        # the same order as self.files.
        for f, w in zip(self.files, self.saved_writes):
            f.write = w
        self.destination.seek(0)