mypy clean!
[python_utils.git] / unittest_utils.py
1 #!/usr/bin/env python3
2
3 """Helpers for unittests.  Note that when you import this we
4    automatically wrap unittest.main() with a call to
5    bootstrap.initialize so that we getLogger config, commandline args,
6    logging control, etc... this works fine but it's a little hacky so
7    caveat emptor.
8 """
9
10 from abc import ABC, abstractmethod
11 import contextlib
12 import functools
13 import inspect
14 import logging
15 import os
16 import pickle
17 import random
18 import statistics
19 import time
20 import tempfile
21 from typing import Any, Callable, Dict, List, Optional
22 import unittest
23 import warnings
24
25 import bootstrap
26 import config
27 import function_utils
28 import scott_secrets
29
30 import sqlalchemy as sa
31
32
33 logger = logging.getLogger(__name__)
34 cfg = config.add_commandline_args(
35     f'Logging ({__file__})', 'Args related to function decorators'
36 )
37 cfg.add_argument(
38     '--unittests_ignore_perf',
39     action='store_true',
40     default=False,
41     help='Ignore unittest perf regression in @check_method_for_perf_regressions',
42 )
43 cfg.add_argument(
44     '--unittests_num_perf_samples',
45     type=int,
46     default=50,
47     help='The count of perf timing samples we need to see before blocking slow runs on perf grounds',
48 )
49 cfg.add_argument(
50     '--unittests_drop_perf_traces',
51     type=str,
52     nargs=1,
53     default=None,
54     help='The identifier (i.e. file!test_fixture) for which we should drop all perf data',
55 )
56 cfg.add_argument(
57     '--unittests_persistance_strategy',
58     choices=['FILE', 'DATABASE'],
59     default='DATABASE',
60     help='Should we persist perf data in a file or db?',
61 )
62 cfg.add_argument(
63     '--unittests_perfdb_filename',
64     type=str,
65     metavar='FILENAME',
66     default=f'{os.environ["HOME"]}/.python_unittest_performance_db',
67     help='File in which to store perf data (iff --unittests_persistance_strategy is FILE)',
68 )
69 cfg.add_argument(
70     '--unittests_perfdb_spec',
71     type=str,
72     metavar='DBSPEC',
73     default='mariadb+pymysql://python_unittest:<PASSWORD>@db.house:3306/python_unittest_performance',
74     help='Db connection spec for perf data (iff --unittest_persistance_strategy is DATABASE)',
75 )
76
77 # >>> This is the hacky business, FYI. <<<
78 unittest.main = bootstrap.initialize(unittest.main)
79
80
81 class PerfRegressionDataPersister(ABC):
82     def __init__(self):
83         pass
84
85     @abstractmethod
86     def load_performance_data(self, method_id: str) -> Dict[str, List[float]]:
87         pass
88
89     @abstractmethod
90     def save_performance_data(self, method_id: str, data: Dict[str, List[float]]):
91         pass
92
93     @abstractmethod
94     def delete_performance_data(self, method_id: str):
95         pass
96
97
98 class FileBasedPerfRegressionDataPersister(PerfRegressionDataPersister):
99     def __init__(self, filename: str):
100         self.filename = filename
101         self.traces_to_delete: List[str] = []
102
103     def load_performance_data(self, method_id: str) -> Dict[str, List[float]]:
104         with open(self.filename, 'rb') as f:
105             return pickle.load(f)
106
107     def save_performance_data(self, method_id: str, data: Dict[str, List[float]]):
108         for trace in self.traces_to_delete:
109             if trace in data:
110                 data[trace] = []
111
112         with open(self.filename, 'wb') as f:
113             pickle.dump(data, f, pickle.HIGHEST_PROTOCOL)
114
115     def delete_performance_data(self, method_id: str):
116         self.traces_to_delete.append(method_id)
117
118
119 class DatabasePerfRegressionDataPersister(PerfRegressionDataPersister):
120     def __init__(self, dbspec: str):
121         self.dbspec = dbspec
122         self.engine = sa.create_engine(self.dbspec)
123         self.conn = self.engine.connect()
124
125     def load_performance_data(self, method_id: str) -> Dict[str, List[float]]:
126         results = self.conn.execute(
127             sa.text(
128                 f'SELECT * FROM runtimes_by_function WHERE function = "{method_id}";'
129             )
130         )
131         ret: Dict[str, List[float]] = {method_id: []}
132         for result in results.all():
133             ret[method_id].append(result['runtime'])
134         results.close()
135         return ret
136
137     def save_performance_data(self, method_id: str, data: Dict[str, List[float]]):
138         self.delete_performance_data(method_id)
139         for (method_id, perf_data) in data.items():
140             sql = 'INSERT INTO runtimes_by_function (function, runtime) VALUES '
141             for perf in perf_data:
142                 self.conn.execute(sql + f'("{method_id}", {perf});')
143
144     def delete_performance_data(self, method_id: str):
145         sql = f'DELETE FROM runtimes_by_function WHERE function = "{method_id}"'
146         self.conn.execute(sql)
147
148
149 def check_method_for_perf_regressions(func: Callable) -> Callable:
150     """
151     This is meant to be used on a method in a class that subclasses
152     unittest.TestCase.  When thus decorated it will time the execution
153     of the code in the method, compare it with a database of
154     historical perfmance, and fail the test with a perf-related
155     message if it has become too slow.
156
157     """
158
159     @functools.wraps(func)
160     def wrapper_perf_monitor(*args, **kwargs):
161         if config.config['unittests_ignore_perf']:
162             return func(*args, **kwargs)
163
164         if config.config['unittests_persistance_strategy'] == 'FILE':
165             filename = config.config['unittests_perfdb_filename']
166             helper = FileBasedPerfRegressionDataPersister(filename)
167         elif config.config['unittests_persistance_strategy'] == 'DATABASE':
168             dbspec = config.config['unittests_perfdb_spec']
169             dbspec = dbspec.replace(
170                 '<PASSWORD>', scott_secrets.MARIADB_UNITTEST_PERF_PASSWORD
171             )
172             helper = DatabasePerfRegressionDataPersister(dbspec)
173         else:
174             raise Exception('Unknown/unexpected --unittests_persistance_strategy value')
175
176         func_id = function_utils.function_identifier(func)
177         func_name = func.__name__
178         logger.debug(f'Watching {func_name}\'s performance...')
179         logger.debug(f'Canonical function identifier = {func_id}')
180
181         try:
182             perfdb = helper.load_performance_data(func_id)
183         except Exception as e:
184             logger.exception(e)
185             msg = 'Unable to load perfdb; skipping it...'
186             logger.warning(msg)
187             warnings.warn(msg)
188             perfdb = {}
189
190         # cmdline arg to forget perf traces for function
191         drop_id = config.config['unittests_drop_perf_traces']
192         if drop_id is not None:
193             helper.delete_performance_data(drop_id)
194
195         # Run the wrapped test paying attention to latency.
196         start_time = time.perf_counter()
197         value = func(*args, **kwargs)
198         end_time = time.perf_counter()
199         run_time = end_time - start_time
200
201         # See if it was unexpectedly slow.
202         hist = perfdb.get(func_id, [])
203         if len(hist) < config.config['unittests_num_perf_samples']:
204             hist.append(run_time)
205             logger.debug(f'Still establishing a perf baseline for {func_name}')
206         else:
207             stdev = statistics.stdev(hist)
208             logger.debug(f'For {func_name}, performance stdev={stdev}')
209             slowest = hist[-1]
210             logger.debug(f'For {func_name}, slowest perf on record is {slowest:f}s')
211             limit = slowest + stdev * 4
212             logger.debug(f'For {func_name}, max acceptable runtime is {limit:f}s')
213             logger.debug(f'For {func_name}, actual observed runtime was {run_time:f}s')
214             if run_time > limit:
215                 msg = f'''{func_id} performance has regressed unacceptably.
216 {slowest:f}s is the slowest runtime on record in {len(hist)} perf samples.
217 It just ran in {run_time:f}s which is 4+ stdevs slower than the slowest.
218 Here is the current, full db perf timing distribution:
219
220 '''
221                 for x in hist:
222                     msg += f'{x:f}\n'
223                 logger.error(msg)
224                 slf = args[0]  # Peek at the wrapped function's self ref.
225                 slf.fail(msg)  # ...to fail the testcase.
226             else:
227                 hist.append(run_time)
228
229         # Don't spam the database with samples; just pick a random
230         # sample from what we have and store that back.
231         n = min(config.config['unittests_num_perf_samples'], len(hist))
232         hist = random.sample(hist, n)
233         hist.sort()
234         perfdb[func_id] = hist
235         helper.save_performance_data(func_id, perfdb)
236         return value
237
238     return wrapper_perf_monitor
239
240
241 def check_all_methods_for_perf_regressions(prefix='test_'):
242     """Decorate unittests with this to pay attention to the perf of the
243     testcode and flag perf regressions.  e.g.
244
245     import unittest_utils as uu
246
247     @uu.check_all_methods_for_perf_regressions()
248     class TestMyClass(unittest.TestCase):
249
250         def test_some_part_of_my_class(self):
251             ...
252
253     """
254
255     def decorate_the_testcase(cls):
256         if issubclass(cls, unittest.TestCase):
257             for name, m in inspect.getmembers(cls, inspect.isfunction):
258                 if name.startswith(prefix):
259                     setattr(cls, name, check_method_for_perf_regressions(m))
260                     logger.debug(f'Wrapping {cls.__name__}:{name}.')
261         return cls
262
263     return decorate_the_testcase
264
265
266 def breakpoint():
267     """Hard code a breakpoint somewhere; drop into pdb."""
268     import pdb
269
270     pdb.set_trace()
271
272
273 class RecordStdout(object):
274     """
275     Record what is emitted to stdout.
276
277     >>> with RecordStdout() as record:
278     ...     print("This is a test!")
279     >>> print({record().readline()})
280     {'This is a test!\\n'}
281     >>> record().close()
282     """
283
284     def __init__(self) -> None:
285         self.destination = tempfile.SpooledTemporaryFile(mode='r+')
286         self.recorder: Optional[contextlib.redirect_stdout] = None
287
288     def __enter__(self) -> Callable[[], tempfile.SpooledTemporaryFile]:
289         self.recorder = contextlib.redirect_stdout(self.destination)
290         assert self.recorder
291         self.recorder.__enter__()
292         return lambda: self.destination
293
294     def __exit__(self, *args) -> Optional[bool]:
295         assert self.recorder
296         self.recorder.__exit__(*args)
297         self.destination.seek(0)
298         return None
299
300
301 class RecordStderr(object):
302     """
303     Record what is emitted to stderr.
304
305     >>> import sys
306     >>> with RecordStderr() as record:
307     ...     print("This is a test!", file=sys.stderr)
308     >>> print({record().readline()})
309     {'This is a test!\\n'}
310     >>> record().close()
311     """
312
313     def __init__(self) -> None:
314         self.destination = tempfile.SpooledTemporaryFile(mode='r+')
315         self.recorder: Optional[contextlib.redirect_stdout[Any]] = None
316
317     def __enter__(self) -> Callable[[], tempfile.SpooledTemporaryFile]:
318         self.recorder = contextlib.redirect_stderr(self.destination)  # type: ignore
319         assert self.recorder
320         self.recorder.__enter__()
321         return lambda: self.destination
322
323     def __exit__(self, *args) -> Optional[bool]:
324         assert self.recorder
325         self.recorder.__exit__(*args)
326         self.destination.seek(0)
327         return None
328
329
330 class RecordMultipleStreams(object):
331     """
332     Record the output to more than one stream.
333     """
334
335     def __init__(self, *files) -> None:
336         self.files = [*files]
337         self.destination = tempfile.SpooledTemporaryFile(mode='r+')
338         self.saved_writes: List[Callable[..., Any]] = []
339
340     def __enter__(self) -> Callable[[], tempfile.SpooledTemporaryFile]:
341         for f in self.files:
342             self.saved_writes.append(f.write)
343             f.write = self.destination.write
344         return lambda: self.destination
345
346     def __exit__(self, *args) -> Optional[bool]:
347         for f in self.files:
348             f.write = self.saved_writes.pop()
349         self.destination.seek(0)
350         return None
351
352
353 if __name__ == '__main__':
354     import doctest
355
356     doctest.testmod()