88e41954811de26629b405cb4ee7255d8bbebc62
[python_utils.git] / unittest_utils.py
1 #!/usr/bin/env python3
2
3 """Helpers for unittests.  Note that when you import this we
4    automatically wrap unittest.main() with a call to
5    bootstrap.initialize so that we getLogger config, commandline args,
6    logging control, etc... this works fine but it's a little hacky so
7    caveat emptor.
8 """
9
10 import contextlib
11 import functools
12 import inspect
13 import logging
14 import os
15 import pickle
16 import random
17 import statistics
18 import tempfile
19 import time
20 import unittest
21 import warnings
22 from abc import ABC, abstractmethod
23 from typing import Any, Callable, Dict, List, Literal, Optional
24
25 import sqlalchemy as sa
26
27 import bootstrap
28 import config
29 import function_utils
30 import scott_secrets
31
32 logger = logging.getLogger(__name__)
33 cfg = config.add_commandline_args(f'Logging ({__file__})', 'Args related to function decorators')
34 cfg.add_argument(
35     '--unittests_ignore_perf',
36     action='store_true',
37     default=False,
38     help='Ignore unittest perf regression in @check_method_for_perf_regressions',
39 )
40 cfg.add_argument(
41     '--unittests_num_perf_samples',
42     type=int,
43     default=50,
44     help='The count of perf timing samples we need to see before blocking slow runs on perf grounds',
45 )
46 cfg.add_argument(
47     '--unittests_drop_perf_traces',
48     type=str,
49     nargs=1,
50     default=None,
51     help='The identifier (i.e. file!test_fixture) for which we should drop all perf data',
52 )
53 cfg.add_argument(
54     '--unittests_persistance_strategy',
55     choices=['FILE', 'DATABASE'],
56     default='DATABASE',
57     help='Should we persist perf data in a file or db?',
58 )
59 cfg.add_argument(
60     '--unittests_perfdb_filename',
61     type=str,
62     metavar='FILENAME',
63     default=f'{os.environ["HOME"]}/.python_unittest_performance_db',
64     help='File in which to store perf data (iff --unittests_persistance_strategy is FILE)',
65 )
66 cfg.add_argument(
67     '--unittests_perfdb_spec',
68     type=str,
69     metavar='DBSPEC',
70     default='mariadb+pymysql://python_unittest:<PASSWORD>@db.house:3306/python_unittest_performance',
71     help='Db connection spec for perf data (iff --unittest_persistance_strategy is DATABASE)',
72 )
73
74 # >>> This is the hacky business, FYI. <<<
75 unittest.main = bootstrap.initialize(unittest.main)
76
77
78 class PerfRegressionDataPersister(ABC):
79     """A base class for a signature dealing with persisting perf
80     regression data."""
81
82     def __init__(self):
83         pass
84
85     @abstractmethod
86     def load_performance_data(self, method_id: str) -> Dict[str, List[float]]:
87         pass
88
89     @abstractmethod
90     def save_performance_data(self, method_id: str, data: Dict[str, List[float]]):
91         pass
92
93     @abstractmethod
94     def delete_performance_data(self, method_id: str):
95         pass
96
97
98 class FileBasedPerfRegressionDataPersister(PerfRegressionDataPersister):
99     """A perf regression data persister that uses files."""
100
101     def __init__(self, filename: str):
102         super().__init__()
103         self.filename = filename
104         self.traces_to_delete: List[str] = []
105
106     def load_performance_data(self, method_id: str) -> Dict[str, List[float]]:
107         with open(self.filename, 'rb') as f:
108             return pickle.load(f)
109
110     def save_performance_data(self, method_id: str, data: Dict[str, List[float]]):
111         for trace in self.traces_to_delete:
112             if trace in data:
113                 data[trace] = []
114
115         with open(self.filename, 'wb') as f:
116             pickle.dump(data, f, pickle.HIGHEST_PROTOCOL)
117
118     def delete_performance_data(self, method_id: str):
119         self.traces_to_delete.append(method_id)
120
121
122 class DatabasePerfRegressionDataPersister(PerfRegressionDataPersister):
123     """A perf regression data persister that uses a database backend."""
124
125     def __init__(self, dbspec: str):
126         super().__init__()
127         self.dbspec = dbspec
128         self.engine = sa.create_engine(self.dbspec)
129         self.conn = self.engine.connect()
130
131     def load_performance_data(self, method_id: str) -> Dict[str, List[float]]:
132         results = self.conn.execute(
133             sa.text(f'SELECT * FROM runtimes_by_function WHERE function = "{method_id}";')
134         )
135         ret: Dict[str, List[float]] = {method_id: []}
136         for result in results.all():
137             ret[method_id].append(result['runtime'])
138         results.close()
139         return ret
140
141     def save_performance_data(self, method_id: str, data: Dict[str, List[float]]):
142         self.delete_performance_data(method_id)
143         for (mid, perf_data) in data.items():
144             sql = 'INSERT INTO runtimes_by_function (function, runtime) VALUES '
145             for perf in perf_data:
146                 self.conn.execute(sql + f'("{mid}", {perf});')
147
148     def delete_performance_data(self, method_id: str):
149         sql = f'DELETE FROM runtimes_by_function WHERE function = "{method_id}"'
150         self.conn.execute(sql)
151
152
153 def check_method_for_perf_regressions(func: Callable) -> Callable:
154     """
155     This is meant to be used on a method in a class that subclasses
156     unittest.TestCase.  When thus decorated it will time the execution
157     of the code in the method, compare it with a database of
158     historical perfmance, and fail the test with a perf-related
159     message if it has become too slow.
160
161     """
162
163     @functools.wraps(func)
164     def wrapper_perf_monitor(*args, **kwargs):
165         if config.config['unittests_ignore_perf']:
166             return func(*args, **kwargs)
167
168         if config.config['unittests_persistance_strategy'] == 'FILE':
169             filename = config.config['unittests_perfdb_filename']
170             helper = FileBasedPerfRegressionDataPersister(filename)
171         elif config.config['unittests_persistance_strategy'] == 'DATABASE':
172             dbspec = config.config['unittests_perfdb_spec']
173             dbspec = dbspec.replace('<PASSWORD>', scott_secrets.MARIADB_UNITTEST_PERF_PASSWORD)
174             helper = DatabasePerfRegressionDataPersister(dbspec)
175         else:
176             raise Exception('Unknown/unexpected --unittests_persistance_strategy value')
177
178         func_id = function_utils.function_identifier(func)
179         func_name = func.__name__
180         logger.debug('Watching %s\'s performance...', func_name)
181         logger.debug('Canonical function identifier = "%s"', func_id)
182
183         try:
184             perfdb = helper.load_performance_data(func_id)
185         except Exception as e:
186             logger.exception(e)
187             msg = 'Unable to load perfdb; skipping it...'
188             logger.warning(msg)
189             warnings.warn(msg)
190             perfdb = {}
191
192         # cmdline arg to forget perf traces for function
193         drop_id = config.config['unittests_drop_perf_traces']
194         if drop_id is not None:
195             helper.delete_performance_data(drop_id)
196
197         # Run the wrapped test paying attention to latency.
198         start_time = time.perf_counter()
199         value = func(*args, **kwargs)
200         end_time = time.perf_counter()
201         run_time = end_time - start_time
202
203         # See if it was unexpectedly slow.
204         hist = perfdb.get(func_id, [])
205         if len(hist) < config.config['unittests_num_perf_samples']:
206             hist.append(run_time)
207             logger.debug('Still establishing a perf baseline for %s', func_name)
208         else:
209             stdev = statistics.stdev(hist)
210             logger.debug('For %s, performance stdev=%.2f', func_name, stdev)
211             slowest = hist[-1]
212             logger.debug('For %s, slowest perf on record is %.2fs', func_name, slowest)
213             limit = slowest + stdev * 4
214             logger.debug('For %s, max acceptable runtime is %.2fs', func_name, limit)
215             logger.debug('For %s, actual observed runtime was %.2fs', func_name, run_time)
216             if run_time > limit:
217                 msg = f'''{func_id} performance has regressed unacceptably.
218 {slowest:f}s is the slowest runtime on record in {len(hist)} perf samples.
219 It just ran in {run_time:f}s which is 4+ stdevs slower than the slowest.
220 Here is the current, full db perf timing distribution:
221
222 '''
223                 for x in hist:
224                     msg += f'{x:f}\n'
225                 logger.error(msg)
226                 slf = args[0]  # Peek at the wrapped function's self ref.
227                 slf.fail(msg)  # ...to fail the testcase.
228             else:
229                 hist.append(run_time)
230
231         # Don't spam the database with samples; just pick a random
232         # sample from what we have and store that back.
233         n = min(config.config['unittests_num_perf_samples'], len(hist))
234         hist = random.sample(hist, n)
235         hist.sort()
236         perfdb[func_id] = hist
237         helper.save_performance_data(func_id, perfdb)
238         return value
239
240     return wrapper_perf_monitor
241
242
243 def check_all_methods_for_perf_regressions(prefix='test_'):
244     """Decorate unittests with this to pay attention to the perf of the
245     testcode and flag perf regressions.  e.g.
246
247     import unittest_utils as uu
248
249     @uu.check_all_methods_for_perf_regressions()
250     class TestMyClass(unittest.TestCase):
251
252         def test_some_part_of_my_class(self):
253             ...
254
255     """
256
257     def decorate_the_testcase(cls):
258         if issubclass(cls, unittest.TestCase):
259             for name, m in inspect.getmembers(cls, inspect.isfunction):
260                 if name.startswith(prefix):
261                     setattr(cls, name, check_method_for_perf_regressions(m))
262                     logger.debug('Wrapping %s:%s.', cls.__name__, name)
263         return cls
264
265     return decorate_the_testcase
266
267
268 class RecordStdout(contextlib.AbstractContextManager):
269     """
270     Record what is emitted to stdout.
271
272     >>> with RecordStdout() as record:
273     ...     print("This is a test!")
274     >>> print({record().readline()})
275     {'This is a test!\\n'}
276     >>> record().close()
277     """
278
279     def __init__(self) -> None:
280         super().__init__()
281         self.destination = tempfile.SpooledTemporaryFile(mode='r+')
282         self.recorder: Optional[contextlib.redirect_stdout] = None
283
284     def __enter__(self) -> Callable[[], tempfile.SpooledTemporaryFile]:
285         self.recorder = contextlib.redirect_stdout(self.destination)
286         assert self.recorder is not None
287         self.recorder.__enter__()
288         return lambda: self.destination
289
290     def __exit__(self, *args) -> Literal[False]:
291         assert self.recorder is not None
292         self.recorder.__exit__(*args)
293         self.destination.seek(0)
294         return False
295
296
297 class RecordStderr(contextlib.AbstractContextManager):
298     """
299     Record what is emitted to stderr.
300
301     >>> import sys
302     >>> with RecordStderr() as record:
303     ...     print("This is a test!", file=sys.stderr)
304     >>> print({record().readline()})
305     {'This is a test!\\n'}
306     >>> record().close()
307     """
308
309     def __init__(self) -> None:
310         super().__init__()
311         self.destination = tempfile.SpooledTemporaryFile(mode='r+')
312         self.recorder: Optional[contextlib.redirect_stdout[Any]] = None
313
314     def __enter__(self) -> Callable[[], tempfile.SpooledTemporaryFile]:
315         self.recorder = contextlib.redirect_stderr(self.destination)  # type: ignore
316         assert self.recorder is not None
317         self.recorder.__enter__()
318         return lambda: self.destination
319
320     def __exit__(self, *args) -> Literal[False]:
321         assert self.recorder is not None
322         self.recorder.__exit__(*args)
323         self.destination.seek(0)
324         return False
325
326
327 class RecordMultipleStreams(contextlib.AbstractContextManager):
328     """
329     Record the output to more than one stream.
330     """
331
332     def __init__(self, *files) -> None:
333         super().__init__()
334         self.files = [*files]
335         self.destination = tempfile.SpooledTemporaryFile(mode='r+')
336         self.saved_writes: List[Callable[..., Any]] = []
337
338     def __enter__(self) -> Callable[[], tempfile.SpooledTemporaryFile]:
339         for f in self.files:
340             self.saved_writes.append(f.write)
341             f.write = self.destination.write
342         return lambda: self.destination
343
344     def __exit__(self, *args) -> Literal[False]:
345         for f in self.files:
346             f.write = self.saved_writes.pop()
347         self.destination.seek(0)
348         return False
349
350
351 if __name__ == '__main__':
352     import doctest
353
354     doctest.testmod()