f229df75e8b88825d66ca227d7e907d3dc725e1a
[python_utils.git] / unittest_utils.py
1 #!/usr/bin/env python3
2
3 """Helpers for unittests.  Note that when you import this we
4    automatically wrap unittest.main() with a call to
5    bootstrap.initialize so that we getLogger config, commandline args,
6    logging control, etc... this works fine but it's a little hacky so
7    caveat emptor.
8 """
9
10 import contextlib
11 import functools
12 import inspect
13 import logging
14 import os
15 import pickle
16 import random
17 import statistics
18 import tempfile
19 import time
20 import unittest
21 import warnings
22 from abc import ABC, abstractmethod
23 from typing import Any, Callable, Dict, List, Optional
24
25 import sqlalchemy as sa
26
27 import bootstrap
28 import config
29 import function_utils
30 import scott_secrets
31
32 logger = logging.getLogger(__name__)
33 cfg = config.add_commandline_args(f'Logging ({__file__})', 'Args related to function decorators')
34 cfg.add_argument(
35     '--unittests_ignore_perf',
36     action='store_true',
37     default=False,
38     help='Ignore unittest perf regression in @check_method_for_perf_regressions',
39 )
40 cfg.add_argument(
41     '--unittests_num_perf_samples',
42     type=int,
43     default=50,
44     help='The count of perf timing samples we need to see before blocking slow runs on perf grounds',
45 )
46 cfg.add_argument(
47     '--unittests_drop_perf_traces',
48     type=str,
49     nargs=1,
50     default=None,
51     help='The identifier (i.e. file!test_fixture) for which we should drop all perf data',
52 )
53 cfg.add_argument(
54     '--unittests_persistance_strategy',
55     choices=['FILE', 'DATABASE'],
56     default='DATABASE',
57     help='Should we persist perf data in a file or db?',
58 )
59 cfg.add_argument(
60     '--unittests_perfdb_filename',
61     type=str,
62     metavar='FILENAME',
63     default=f'{os.environ["HOME"]}/.python_unittest_performance_db',
64     help='File in which to store perf data (iff --unittests_persistance_strategy is FILE)',
65 )
66 cfg.add_argument(
67     '--unittests_perfdb_spec',
68     type=str,
69     metavar='DBSPEC',
70     default='mariadb+pymysql://python_unittest:<PASSWORD>@db.house:3306/python_unittest_performance',
71     help='Db connection spec for perf data (iff --unittest_persistance_strategy is DATABASE)',
72 )
73
74 # >>> This is the hacky business, FYI. <<<
75 unittest.main = bootstrap.initialize(unittest.main)
76
77
78 class PerfRegressionDataPersister(ABC):
79     def __init__(self):
80         pass
81
82     @abstractmethod
83     def load_performance_data(self, method_id: str) -> Dict[str, List[float]]:
84         pass
85
86     @abstractmethod
87     def save_performance_data(self, method_id: str, data: Dict[str, List[float]]):
88         pass
89
90     @abstractmethod
91     def delete_performance_data(self, method_id: str):
92         pass
93
94
95 class FileBasedPerfRegressionDataPersister(PerfRegressionDataPersister):
96     def __init__(self, filename: str):
97         self.filename = filename
98         self.traces_to_delete: List[str] = []
99
100     def load_performance_data(self, method_id: str) -> Dict[str, List[float]]:
101         with open(self.filename, 'rb') as f:
102             return pickle.load(f)
103
104     def save_performance_data(self, method_id: str, data: Dict[str, List[float]]):
105         for trace in self.traces_to_delete:
106             if trace in data:
107                 data[trace] = []
108
109         with open(self.filename, 'wb') as f:
110             pickle.dump(data, f, pickle.HIGHEST_PROTOCOL)
111
112     def delete_performance_data(self, method_id: str):
113         self.traces_to_delete.append(method_id)
114
115
116 class DatabasePerfRegressionDataPersister(PerfRegressionDataPersister):
117     def __init__(self, dbspec: str):
118         self.dbspec = dbspec
119         self.engine = sa.create_engine(self.dbspec)
120         self.conn = self.engine.connect()
121
122     def load_performance_data(self, method_id: str) -> Dict[str, List[float]]:
123         results = self.conn.execute(
124             sa.text(f'SELECT * FROM runtimes_by_function WHERE function = "{method_id}";')
125         )
126         ret: Dict[str, List[float]] = {method_id: []}
127         for result in results.all():
128             ret[method_id].append(result['runtime'])
129         results.close()
130         return ret
131
132     def save_performance_data(self, method_id: str, data: Dict[str, List[float]]):
133         self.delete_performance_data(method_id)
134         for (method_id, perf_data) in data.items():
135             sql = 'INSERT INTO runtimes_by_function (function, runtime) VALUES '
136             for perf in perf_data:
137                 self.conn.execute(sql + f'("{method_id}", {perf});')
138
139     def delete_performance_data(self, method_id: str):
140         sql = f'DELETE FROM runtimes_by_function WHERE function = "{method_id}"'
141         self.conn.execute(sql)
142
143
144 def check_method_for_perf_regressions(func: Callable) -> Callable:
145     """
146     This is meant to be used on a method in a class that subclasses
147     unittest.TestCase.  When thus decorated it will time the execution
148     of the code in the method, compare it with a database of
149     historical perfmance, and fail the test with a perf-related
150     message if it has become too slow.
151
152     """
153
154     @functools.wraps(func)
155     def wrapper_perf_monitor(*args, **kwargs):
156         if config.config['unittests_ignore_perf']:
157             return func(*args, **kwargs)
158
159         if config.config['unittests_persistance_strategy'] == 'FILE':
160             filename = config.config['unittests_perfdb_filename']
161             helper = FileBasedPerfRegressionDataPersister(filename)
162         elif config.config['unittests_persistance_strategy'] == 'DATABASE':
163             dbspec = config.config['unittests_perfdb_spec']
164             dbspec = dbspec.replace('<PASSWORD>', scott_secrets.MARIADB_UNITTEST_PERF_PASSWORD)
165             helper = DatabasePerfRegressionDataPersister(dbspec)
166         else:
167             raise Exception('Unknown/unexpected --unittests_persistance_strategy value')
168
169         func_id = function_utils.function_identifier(func)
170         func_name = func.__name__
171         logger.debug(f'Watching {func_name}\'s performance...')
172         logger.debug(f'Canonical function identifier = {func_id}')
173
174         try:
175             perfdb = helper.load_performance_data(func_id)
176         except Exception as e:
177             logger.exception(e)
178             msg = 'Unable to load perfdb; skipping it...'
179             logger.warning(msg)
180             warnings.warn(msg)
181             perfdb = {}
182
183         # cmdline arg to forget perf traces for function
184         drop_id = config.config['unittests_drop_perf_traces']
185         if drop_id is not None:
186             helper.delete_performance_data(drop_id)
187
188         # Run the wrapped test paying attention to latency.
189         start_time = time.perf_counter()
190         value = func(*args, **kwargs)
191         end_time = time.perf_counter()
192         run_time = end_time - start_time
193
194         # See if it was unexpectedly slow.
195         hist = perfdb.get(func_id, [])
196         if len(hist) < config.config['unittests_num_perf_samples']:
197             hist.append(run_time)
198             logger.debug(f'Still establishing a perf baseline for {func_name}')
199         else:
200             stdev = statistics.stdev(hist)
201             logger.debug(f'For {func_name}, performance stdev={stdev}')
202             slowest = hist[-1]
203             logger.debug(f'For {func_name}, slowest perf on record is {slowest:f}s')
204             limit = slowest + stdev * 4
205             logger.debug(f'For {func_name}, max acceptable runtime is {limit:f}s')
206             logger.debug(f'For {func_name}, actual observed runtime was {run_time:f}s')
207             if run_time > limit:
208                 msg = f'''{func_id} performance has regressed unacceptably.
209 {slowest:f}s is the slowest runtime on record in {len(hist)} perf samples.
210 It just ran in {run_time:f}s which is 4+ stdevs slower than the slowest.
211 Here is the current, full db perf timing distribution:
212
213 '''
214                 for x in hist:
215                     msg += f'{x:f}\n'
216                 logger.error(msg)
217                 slf = args[0]  # Peek at the wrapped function's self ref.
218                 slf.fail(msg)  # ...to fail the testcase.
219             else:
220                 hist.append(run_time)
221
222         # Don't spam the database with samples; just pick a random
223         # sample from what we have and store that back.
224         n = min(config.config['unittests_num_perf_samples'], len(hist))
225         hist = random.sample(hist, n)
226         hist.sort()
227         perfdb[func_id] = hist
228         helper.save_performance_data(func_id, perfdb)
229         return value
230
231     return wrapper_perf_monitor
232
233
234 def check_all_methods_for_perf_regressions(prefix='test_'):
235     """Decorate unittests with this to pay attention to the perf of the
236     testcode and flag perf regressions.  e.g.
237
238     import unittest_utils as uu
239
240     @uu.check_all_methods_for_perf_regressions()
241     class TestMyClass(unittest.TestCase):
242
243         def test_some_part_of_my_class(self):
244             ...
245
246     """
247
248     def decorate_the_testcase(cls):
249         if issubclass(cls, unittest.TestCase):
250             for name, m in inspect.getmembers(cls, inspect.isfunction):
251                 if name.startswith(prefix):
252                     setattr(cls, name, check_method_for_perf_regressions(m))
253                     logger.debug(f'Wrapping {cls.__name__}:{name}.')
254         return cls
255
256     return decorate_the_testcase
257
258
259 def breakpoint():
260     """Hard code a breakpoint somewhere; drop into pdb."""
261     import pdb
262
263     pdb.set_trace()
264
265
266 class RecordStdout(object):
267     """
268     Record what is emitted to stdout.
269
270     >>> with RecordStdout() as record:
271     ...     print("This is a test!")
272     >>> print({record().readline()})
273     {'This is a test!\\n'}
274     >>> record().close()
275     """
276
277     def __init__(self) -> None:
278         self.destination = tempfile.SpooledTemporaryFile(mode='r+')
279         self.recorder: Optional[contextlib.redirect_stdout] = None
280
281     def __enter__(self) -> Callable[[], tempfile.SpooledTemporaryFile]:
282         self.recorder = contextlib.redirect_stdout(self.destination)
283         assert self.recorder is not None
284         self.recorder.__enter__()
285         return lambda: self.destination
286
287     def __exit__(self, *args) -> Optional[bool]:
288         assert self.recorder is not None
289         self.recorder.__exit__(*args)
290         self.destination.seek(0)
291         return None
292
293
294 class RecordStderr(object):
295     """
296     Record what is emitted to stderr.
297
298     >>> import sys
299     >>> with RecordStderr() as record:
300     ...     print("This is a test!", file=sys.stderr)
301     >>> print({record().readline()})
302     {'This is a test!\\n'}
303     >>> record().close()
304     """
305
306     def __init__(self) -> None:
307         self.destination = tempfile.SpooledTemporaryFile(mode='r+')
308         self.recorder: Optional[contextlib.redirect_stdout[Any]] = None
309
310     def __enter__(self) -> Callable[[], tempfile.SpooledTemporaryFile]:
311         self.recorder = contextlib.redirect_stderr(self.destination)  # type: ignore
312         assert self.recorder is not None
313         self.recorder.__enter__()
314         return lambda: self.destination
315
316     def __exit__(self, *args) -> Optional[bool]:
317         assert self.recorder is not None
318         self.recorder.__exit__(*args)
319         self.destination.seek(0)
320         return None
321
322
323 class RecordMultipleStreams(object):
324     """
325     Record the output to more than one stream.
326     """
327
328     def __init__(self, *files) -> None:
329         self.files = [*files]
330         self.destination = tempfile.SpooledTemporaryFile(mode='r+')
331         self.saved_writes: List[Callable[..., Any]] = []
332
333     def __enter__(self) -> Callable[[], tempfile.SpooledTemporaryFile]:
334         for f in self.files:
335             self.saved_writes.append(f.write)
336             f.write = self.destination.write
337         return lambda: self.destination
338
339     def __exit__(self, *args) -> Optional[bool]:
340         for f in self.files:
341             f.write = self.saved_writes.pop()
342         self.destination.seek(0)
343         return None
344
345
346 if __name__ == '__main__':
347     import doctest
348
349     doctest.testmod()