Let's be explicit with asserts; there was a bug in histogram
[python_utils.git] / unittest_utils.py
1 #!/usr/bin/env python3
2
3 """Helpers for unittests.  Note that when you import this we
4    automatically wrap unittest.main() with a call to
5    bootstrap.initialize so that we getLogger config, commandline args,
6    logging control, etc... this works fine but it's a little hacky so
7    caveat emptor.
8 """
9
10 import contextlib
11 import functools
12 import inspect
13 import logging
14 import os
15 import pickle
16 import random
17 import statistics
18 import tempfile
19 import time
20 import unittest
21 import warnings
22 from abc import ABC, abstractmethod
23 from typing import Any, Callable, Dict, List, Optional
24
25 import sqlalchemy as sa
26
27 import bootstrap
28 import config
29 import function_utils
30 import scott_secrets
31
32 logger = logging.getLogger(__name__)
33 cfg = config.add_commandline_args(
34     f'Logging ({__file__})', 'Args related to function decorators'
35 )
36 cfg.add_argument(
37     '--unittests_ignore_perf',
38     action='store_true',
39     default=False,
40     help='Ignore unittest perf regression in @check_method_for_perf_regressions',
41 )
42 cfg.add_argument(
43     '--unittests_num_perf_samples',
44     type=int,
45     default=50,
46     help='The count of perf timing samples we need to see before blocking slow runs on perf grounds',
47 )
48 cfg.add_argument(
49     '--unittests_drop_perf_traces',
50     type=str,
51     nargs=1,
52     default=None,
53     help='The identifier (i.e. file!test_fixture) for which we should drop all perf data',
54 )
55 cfg.add_argument(
56     '--unittests_persistance_strategy',
57     choices=['FILE', 'DATABASE'],
58     default='DATABASE',
59     help='Should we persist perf data in a file or db?',
60 )
61 cfg.add_argument(
62     '--unittests_perfdb_filename',
63     type=str,
64     metavar='FILENAME',
65     default=f'{os.environ["HOME"]}/.python_unittest_performance_db',
66     help='File in which to store perf data (iff --unittests_persistance_strategy is FILE)',
67 )
68 cfg.add_argument(
69     '--unittests_perfdb_spec',
70     type=str,
71     metavar='DBSPEC',
72     default='mariadb+pymysql://python_unittest:<PASSWORD>@db.house:3306/python_unittest_performance',
73     help='Db connection spec for perf data (iff --unittest_persistance_strategy is DATABASE)',
74 )
75
76 # >>> This is the hacky business, FYI. <<<
77 unittest.main = bootstrap.initialize(unittest.main)
78
79
80 class PerfRegressionDataPersister(ABC):
81     def __init__(self):
82         pass
83
84     @abstractmethod
85     def load_performance_data(self, method_id: str) -> Dict[str, List[float]]:
86         pass
87
88     @abstractmethod
89     def save_performance_data(self, method_id: str, data: Dict[str, List[float]]):
90         pass
91
92     @abstractmethod
93     def delete_performance_data(self, method_id: str):
94         pass
95
96
97 class FileBasedPerfRegressionDataPersister(PerfRegressionDataPersister):
98     def __init__(self, filename: str):
99         self.filename = filename
100         self.traces_to_delete: List[str] = []
101
102     def load_performance_data(self, method_id: str) -> Dict[str, List[float]]:
103         with open(self.filename, 'rb') as f:
104             return pickle.load(f)
105
106     def save_performance_data(self, method_id: str, data: Dict[str, List[float]]):
107         for trace in self.traces_to_delete:
108             if trace in data:
109                 data[trace] = []
110
111         with open(self.filename, 'wb') as f:
112             pickle.dump(data, f, pickle.HIGHEST_PROTOCOL)
113
114     def delete_performance_data(self, method_id: str):
115         self.traces_to_delete.append(method_id)
116
117
118 class DatabasePerfRegressionDataPersister(PerfRegressionDataPersister):
119     def __init__(self, dbspec: str):
120         self.dbspec = dbspec
121         self.engine = sa.create_engine(self.dbspec)
122         self.conn = self.engine.connect()
123
124     def load_performance_data(self, method_id: str) -> Dict[str, List[float]]:
125         results = self.conn.execute(
126             sa.text(
127                 f'SELECT * FROM runtimes_by_function WHERE function = "{method_id}";'
128             )
129         )
130         ret: Dict[str, List[float]] = {method_id: []}
131         for result in results.all():
132             ret[method_id].append(result['runtime'])
133         results.close()
134         return ret
135
136     def save_performance_data(self, method_id: str, data: Dict[str, List[float]]):
137         self.delete_performance_data(method_id)
138         for (method_id, perf_data) in data.items():
139             sql = 'INSERT INTO runtimes_by_function (function, runtime) VALUES '
140             for perf in perf_data:
141                 self.conn.execute(sql + f'("{method_id}", {perf});')
142
143     def delete_performance_data(self, method_id: str):
144         sql = f'DELETE FROM runtimes_by_function WHERE function = "{method_id}"'
145         self.conn.execute(sql)
146
147
148 def check_method_for_perf_regressions(func: Callable) -> Callable:
149     """
150     This is meant to be used on a method in a class that subclasses
151     unittest.TestCase.  When thus decorated it will time the execution
152     of the code in the method, compare it with a database of
153     historical perfmance, and fail the test with a perf-related
154     message if it has become too slow.
155
156     """
157
158     @functools.wraps(func)
159     def wrapper_perf_monitor(*args, **kwargs):
160         if config.config['unittests_ignore_perf']:
161             return func(*args, **kwargs)
162
163         if config.config['unittests_persistance_strategy'] == 'FILE':
164             filename = config.config['unittests_perfdb_filename']
165             helper = FileBasedPerfRegressionDataPersister(filename)
166         elif config.config['unittests_persistance_strategy'] == 'DATABASE':
167             dbspec = config.config['unittests_perfdb_spec']
168             dbspec = dbspec.replace(
169                 '<PASSWORD>', scott_secrets.MARIADB_UNITTEST_PERF_PASSWORD
170             )
171             helper = DatabasePerfRegressionDataPersister(dbspec)
172         else:
173             raise Exception('Unknown/unexpected --unittests_persistance_strategy value')
174
175         func_id = function_utils.function_identifier(func)
176         func_name = func.__name__
177         logger.debug(f'Watching {func_name}\'s performance...')
178         logger.debug(f'Canonical function identifier = {func_id}')
179
180         try:
181             perfdb = helper.load_performance_data(func_id)
182         except Exception as e:
183             logger.exception(e)
184             msg = 'Unable to load perfdb; skipping it...'
185             logger.warning(msg)
186             warnings.warn(msg)
187             perfdb = {}
188
189         # cmdline arg to forget perf traces for function
190         drop_id = config.config['unittests_drop_perf_traces']
191         if drop_id is not None:
192             helper.delete_performance_data(drop_id)
193
194         # Run the wrapped test paying attention to latency.
195         start_time = time.perf_counter()
196         value = func(*args, **kwargs)
197         end_time = time.perf_counter()
198         run_time = end_time - start_time
199
200         # See if it was unexpectedly slow.
201         hist = perfdb.get(func_id, [])
202         if len(hist) < config.config['unittests_num_perf_samples']:
203             hist.append(run_time)
204             logger.debug(f'Still establishing a perf baseline for {func_name}')
205         else:
206             stdev = statistics.stdev(hist)
207             logger.debug(f'For {func_name}, performance stdev={stdev}')
208             slowest = hist[-1]
209             logger.debug(f'For {func_name}, slowest perf on record is {slowest:f}s')
210             limit = slowest + stdev * 4
211             logger.debug(f'For {func_name}, max acceptable runtime is {limit:f}s')
212             logger.debug(f'For {func_name}, actual observed runtime was {run_time:f}s')
213             if run_time > limit:
214                 msg = f'''{func_id} performance has regressed unacceptably.
215 {slowest:f}s is the slowest runtime on record in {len(hist)} perf samples.
216 It just ran in {run_time:f}s which is 4+ stdevs slower than the slowest.
217 Here is the current, full db perf timing distribution:
218
219 '''
220                 for x in hist:
221                     msg += f'{x:f}\n'
222                 logger.error(msg)
223                 slf = args[0]  # Peek at the wrapped function's self ref.
224                 slf.fail(msg)  # ...to fail the testcase.
225             else:
226                 hist.append(run_time)
227
228         # Don't spam the database with samples; just pick a random
229         # sample from what we have and store that back.
230         n = min(config.config['unittests_num_perf_samples'], len(hist))
231         hist = random.sample(hist, n)
232         hist.sort()
233         perfdb[func_id] = hist
234         helper.save_performance_data(func_id, perfdb)
235         return value
236
237     return wrapper_perf_monitor
238
239
240 def check_all_methods_for_perf_regressions(prefix='test_'):
241     """Decorate unittests with this to pay attention to the perf of the
242     testcode and flag perf regressions.  e.g.
243
244     import unittest_utils as uu
245
246     @uu.check_all_methods_for_perf_regressions()
247     class TestMyClass(unittest.TestCase):
248
249         def test_some_part_of_my_class(self):
250             ...
251
252     """
253
254     def decorate_the_testcase(cls):
255         if issubclass(cls, unittest.TestCase):
256             for name, m in inspect.getmembers(cls, inspect.isfunction):
257                 if name.startswith(prefix):
258                     setattr(cls, name, check_method_for_perf_regressions(m))
259                     logger.debug(f'Wrapping {cls.__name__}:{name}.')
260         return cls
261
262     return decorate_the_testcase
263
264
265 def breakpoint():
266     """Hard code a breakpoint somewhere; drop into pdb."""
267     import pdb
268
269     pdb.set_trace()
270
271
272 class RecordStdout(object):
273     """
274     Record what is emitted to stdout.
275
276     >>> with RecordStdout() as record:
277     ...     print("This is a test!")
278     >>> print({record().readline()})
279     {'This is a test!\\n'}
280     >>> record().close()
281     """
282
283     def __init__(self) -> None:
284         self.destination = tempfile.SpooledTemporaryFile(mode='r+')
285         self.recorder: Optional[contextlib.redirect_stdout] = None
286
287     def __enter__(self) -> Callable[[], tempfile.SpooledTemporaryFile]:
288         self.recorder = contextlib.redirect_stdout(self.destination)
289         assert self.recorder is not None
290         self.recorder.__enter__()
291         return lambda: self.destination
292
293     def __exit__(self, *args) -> Optional[bool]:
294         assert self.recorder is not None
295         self.recorder.__exit__(*args)
296         self.destination.seek(0)
297         return None
298
299
300 class RecordStderr(object):
301     """
302     Record what is emitted to stderr.
303
304     >>> import sys
305     >>> with RecordStderr() as record:
306     ...     print("This is a test!", file=sys.stderr)
307     >>> print({record().readline()})
308     {'This is a test!\\n'}
309     >>> record().close()
310     """
311
312     def __init__(self) -> None:
313         self.destination = tempfile.SpooledTemporaryFile(mode='r+')
314         self.recorder: Optional[contextlib.redirect_stdout[Any]] = None
315
316     def __enter__(self) -> Callable[[], tempfile.SpooledTemporaryFile]:
317         self.recorder = contextlib.redirect_stderr(self.destination)  # type: ignore
318         assert self.recorder is not None
319         self.recorder.__enter__()
320         return lambda: self.destination
321
322     def __exit__(self, *args) -> Optional[bool]:
323         assert self.recorder is not None
324         self.recorder.__exit__(*args)
325         self.destination.seek(0)
326         return None
327
328
329 class RecordMultipleStreams(object):
330     """
331     Record the output to more than one stream.
332     """
333
334     def __init__(self, *files) -> None:
335         self.files = [*files]
336         self.destination = tempfile.SpooledTemporaryFile(mode='r+')
337         self.saved_writes: List[Callable[..., Any]] = []
338
339     def __enter__(self) -> Callable[[], tempfile.SpooledTemporaryFile]:
340         for f in self.files:
341             self.saved_writes.append(f.write)
342             f.write = self.destination.write
343         return lambda: self.destination
344
345     def __exit__(self, *args) -> Optional[bool]:
346         for f in self.files:
347             f.write = self.saved_writes.pop()
348         self.destination.seek(0)
349         return None
350
351
352 if __name__ == '__main__':
353     import doctest
354
355     doctest.testmod()