52f4f5319ac68e569dba7f12bb3c72506811839f
[pyutils.git] / src / pyutils / unittest_utils.py
1 #!/usr/bin/env python3
2
3 # © Copyright 2021-2022, Scott Gasch
4
5 """Helpers for unittests.
6
7 .. warning::
8
9     When you import this we automatically wrap the standard Python
10     `unittest.main` with a call to :meth:`pyutils.bootstrap.initialize`
11     so that we get logger config, commandline args, logging control,
12     etc... this works fine but may be unexpected behavior.
13 """
14
15 import contextlib
16 import functools
17 import inspect
18 import logging
19 import os
20 import pickle
21 import random
22 import statistics
23 import tempfile
24 import time
25 import unittest
26 import warnings
27 from abc import ABC, abstractmethod
28 from typing import Any, Callable, Dict, List, Literal, Optional
29
30 from pyutils import bootstrap, config, function_utils
31
32 logger = logging.getLogger(__name__)
33 cfg = config.add_commandline_args(
34     f'Logging ({__file__})', 'Args related to function decorators'
35 )
36 cfg.add_argument(
37     '--unittests_ignore_perf',
38     action='store_true',
39     default=False,
40     help='Ignore unittest perf regression in @check_method_for_perf_regressions',
41 )
42 cfg.add_argument(
43     '--unittests_num_perf_samples',
44     type=int,
45     default=50,
46     help='The count of perf timing samples we need to see before blocking slow runs on perf grounds',
47 )
48 cfg.add_argument(
49     '--unittests_drop_perf_traces',
50     type=str,
51     nargs=1,
52     default=None,
53     help='The identifier (i.e. file!test_fixture) for which we should drop all perf data',
54 )
55 cfg.add_argument(
56     '--unittests_persistance_strategy',
57     choices=['FILE', 'DATABASE'],
58     default='FILE',
59     help='Should we persist perf data in a file or db?',
60 )
61 cfg.add_argument(
62     '--unittests_perfdb_filename',
63     type=str,
64     metavar='FILENAME',
65     default=f'{os.environ["HOME"]}/.python_unittest_performance_db',
66     help='File in which to store perf data (iff --unittests_persistance_strategy is FILE)',
67 )
68 cfg.add_argument(
69     '--unittests_perfdb_spec',
70     type=str,
71     metavar='DBSPEC',
72     default='mariadb+pymysql://python_unittest:<PASSWORD>@db.house:3306/python_unittest_performance',
73     help='Db connection spec for perf data (iff --unittest_persistance_strategy is DATABASE)',
74 )
75
76 unittest.main = bootstrap.initialize(unittest.main)
77
78
79 class PerfRegressionDataPersister(ABC):
80     """A base class that defines an interface for dealing with
81     persisting perf regression data.
82     """
83
84     @abstractmethod
85     def load_performance_data(self, method_id: str) -> Dict[str, List[float]]:
86         """Load the historical performance data for the supplied method.
87
88         Args:
89             method_id: the method for which we want historical perf data.
90         """
91         pass
92
93     @abstractmethod
94     def save_performance_data(self, method_id: str, data: Dict[str, List[float]]):
95         """Save the historical performance data of the supplied method.
96
97         Args:
98             method_id: the method whose historical perf data we're saving.
99             data: the historical performance data being persisted.
100         """
101         pass
102
103     @abstractmethod
104     def delete_performance_data(self, method_id: str):
105         """Delete the historical performance data of the supplied method.
106
107         Args:
108             method_id: the method whose data should be erased.
109         """
110         pass
111
112
113 class FileBasedPerfRegressionDataPersister(PerfRegressionDataPersister):
114     """A perf regression data persister that uses files."""
115
116     def __init__(self, filename: str):
117         """
118         Args:
119             filename: the filename to save/load historical performance data
120         """
121         super().__init__()
122         self.filename = filename
123         self.traces_to_delete: List[str] = []
124
125     def load_performance_data(self, method_id: str) -> Dict[str, List[float]]:
126         with open(self.filename, 'rb') as f:
127             return pickle.load(f)
128
129     def save_performance_data(self, method_id: str, data: Dict[str, List[float]]):
130         for trace in self.traces_to_delete:
131             if trace in data:
132                 data[trace] = []
133
134         with open(self.filename, 'wb') as f:
135             pickle.dump(data, f, pickle.HIGHEST_PROTOCOL)
136
137     def delete_performance_data(self, method_id: str):
138         self.traces_to_delete.append(method_id)
139
140
141 # class DatabasePerfRegressionDataPersister(PerfRegressionDataPersister):
142 #    """A perf regression data persister that uses a database backend."""
143 #
144 #    def __init__(self, dbspec: str):
145 #        super().__init__()
146 #        self.dbspec = dbspec
147 #        self.engine = sa.create_engine(self.dbspec)
148 #        self.conn = self.engine.connect()
149 #
150 #    def load_performance_data(self, method_id: str) -> Dict[str, List[float]]:
151 #        results = self.conn.execute(
152 #            sa.text(f'SELECT * FROM runtimes_by_function WHERE function = "{method_id}";')
153 #        )
154 #        ret: Dict[str, List[float]] = {method_id: []}
155 #        for result in results.all():
156 #            ret[method_id].append(result['runtime'])
157 #        results.close()
158 #        return ret
159 #
160 #    def save_performance_data(self, method_id: str, data: Dict[str, List[float]]):
161 #        self.delete_performance_data(method_id)
162 #        for (mid, perf_data) in data.items():
163 #            sql = 'INSERT INTO runtimes_by_function (function, runtime) VALUES '
164 #            for perf in perf_data:
165 #                self.conn.execute(sql + f'("{mid}", {perf});')
166 #
167 #    def delete_performance_data(self, method_id: str):
168 #        sql = f'DELETE FROM runtimes_by_function WHERE function = "{method_id}"'
169 #        self.conn.execute(sql)
170
171
172 def check_method_for_perf_regressions(func: Callable) -> Callable:
173     """This decorator is meant to be used on a method in a class that
174     subclasses :class:`unittest.TestCase`.  When decorated, method
175     execution timing (i.e. performance) will be measured and compared
176     with a database of historical performance for the same method.
177     The wrapper will then fail the test with a perf-related message if
178     it has become too slow.
179
180     See also :meth:`check_all_methods_for_perf_regressions`.
181
182     Example usage::
183
184         class TestMyClass(unittest.TestCase):
185
186             @check_method_for_perf_regressions
187             def test_some_part_of_my_class(self):
188                 ...
189
190     """
191
192     @functools.wraps(func)
193     def wrapper_perf_monitor(*args, **kwargs):
194         if config.config['unittests_ignore_perf']:
195             return func(*args, **kwargs)
196
197         if config.config['unittests_persistance_strategy'] == 'FILE':
198             filename = config.config['unittests_perfdb_filename']
199             helper = FileBasedPerfRegressionDataPersister(filename)
200         elif config.config['unittests_persistance_strategy'] == 'DATABASE':
201             raise NotImplementedError(
202                 'Persisting to a database is not implemented in this version'
203             )
204         else:
205             raise Exception('Unknown/unexpected --unittests_persistance_strategy value')
206
207         func_id = function_utils.function_identifier(func)
208         func_name = func.__name__
209         logger.debug('Watching %s\'s performance...', func_name)
210         logger.debug('Canonical function identifier = "%s"', func_id)
211
212         try:
213             perfdb = helper.load_performance_data(func_id)
214         except Exception as e:
215             logger.exception(e)
216             msg = 'Unable to load perfdb; skipping it...'
217             logger.warning(msg)
218             warnings.warn(msg)
219             perfdb = {}
220
221         # cmdline arg to forget perf traces for function
222         drop_id = config.config['unittests_drop_perf_traces']
223         if drop_id is not None:
224             helper.delete_performance_data(drop_id)
225
226         # Run the wrapped test paying attention to latency.
227         start_time = time.perf_counter()
228         value = func(*args, **kwargs)
229         end_time = time.perf_counter()
230         run_time = end_time - start_time
231
232         # See if it was unexpectedly slow.
233         hist = perfdb.get(func_id, [])
234         if len(hist) < config.config['unittests_num_perf_samples']:
235             hist.append(run_time)
236             logger.debug('Still establishing a perf baseline for %s', func_name)
237         else:
238             stdev = statistics.stdev(hist)
239             logger.debug('For %s, performance stdev=%.2f', func_name, stdev)
240             slowest = hist[-1]
241             logger.debug('For %s, slowest perf on record is %.2fs', func_name, slowest)
242             limit = slowest + stdev * 4
243             logger.debug('For %s, max acceptable runtime is %.2fs', func_name, limit)
244             logger.debug(
245                 'For %s, actual observed runtime was %.2fs', func_name, run_time
246             )
247             if run_time > limit:
248                 msg = f'''{func_id} performance has regressed unacceptably.
249 {slowest:f}s is the slowest runtime on record in {len(hist)} perf samples.
250 It just ran in {run_time:f}s which is 4+ stdevs slower than the slowest.
251 Here is the current, full db perf timing distribution:
252
253 '''
254                 for x in hist:
255                     msg += f'{x:f}\n'
256                 logger.error(msg)
257                 slf = args[0]  # Peek at the wrapped function's self ref.
258                 slf.fail(msg)  # ...to fail the testcase.
259             else:
260                 hist.append(run_time)
261
262         # Don't spam the database with samples; just pick a random
263         # sample from what we have and store that back.
264         n = min(config.config['unittests_num_perf_samples'], len(hist))
265         hist = random.sample(hist, n)
266         hist.sort()
267         perfdb[func_id] = hist
268         helper.save_performance_data(func_id, perfdb)
269         return value
270
271     return wrapper_perf_monitor
272
273
274 def check_all_methods_for_perf_regressions(prefix='test_'):
275     """This decorator is meant to apply to classes that subclass from
276     :class:`unittest.TestCase` and, when applied, has the affect of
277     decorating each method that matches the `prefix` given with the
278     :meth:`check_method_for_perf_regressions` wrapper (see above).
279     This wrapper causes us to measure perf and fail tests that regress
280     perf dramatically.
281
282     Args:
283         prefix: the prefix of method names to check for regressions
284
285     See also :meth:`check_method_for_perf_regressions` to check only
286     a single method.
287
288     Example usage.  By decorating the class, all methods with names
289     that begin with `test_` will be perf monitored::
290
291         import pyutils.unittest_utils as uu
292
293         @uu.check_all_methods_for_perf_regressions()
294         class TestMyClass(unittest.TestCase):
295
296             def test_some_part_of_my_class(self):
297                 ...
298
299             def test_som_other_part_of_my_class(self):
300                 ...
301     """
302
303     def decorate_the_testcase(cls):
304         if issubclass(cls, unittest.TestCase):
305             for name, m in inspect.getmembers(cls, inspect.isfunction):
306                 if name.startswith(prefix):
307                     setattr(cls, name, check_method_for_perf_regressions(m))
308                     logger.debug('Wrapping %s:%s.', cls.__name__, name)
309         return cls
310
311     return decorate_the_testcase
312
313
314 class RecordStdout(contextlib.AbstractContextManager):
315     """
316     Records what is emitted to stdout into a buffer instead.
317
318     >>> with RecordStdout() as record:
319     ...     print("This is a test!")
320     >>> print({record().readline()})
321     {'This is a test!\\n'}
322     >>> record().close()
323     """
324
325     def __init__(self) -> None:
326         super().__init__()
327         self.destination = tempfile.SpooledTemporaryFile(mode='r+')
328         self.recorder: Optional[contextlib.redirect_stdout] = None
329
330     def __enter__(self) -> Callable[[], tempfile.SpooledTemporaryFile]:
331         self.recorder = contextlib.redirect_stdout(self.destination)
332         assert self.recorder is not None
333         self.recorder.__enter__()
334         return lambda: self.destination
335
336     def __exit__(self, *args) -> Literal[False]:
337         assert self.recorder is not None
338         self.recorder.__exit__(*args)
339         self.destination.seek(0)
340         return False
341
342
343 class RecordStderr(contextlib.AbstractContextManager):
344     """
345     Record what is emitted to stderr.
346
347     >>> import sys
348     >>> with RecordStderr() as record:
349     ...     print("This is a test!", file=sys.stderr)
350     >>> print({record().readline()})
351     {'This is a test!\\n'}
352     >>> record().close()
353     """
354
355     def __init__(self) -> None:
356         super().__init__()
357         self.destination = tempfile.SpooledTemporaryFile(mode='r+')
358         self.recorder: Optional[contextlib.redirect_stdout[Any]] = None
359
360     def __enter__(self) -> Callable[[], tempfile.SpooledTemporaryFile]:
361         self.recorder = contextlib.redirect_stderr(self.destination)  # type: ignore
362         assert self.recorder is not None
363         self.recorder.__enter__()
364         return lambda: self.destination
365
366     def __exit__(self, *args) -> Literal[False]:
367         assert self.recorder is not None
368         self.recorder.__exit__(*args)
369         self.destination.seek(0)
370         return False
371
372
373 class RecordMultipleStreams(contextlib.AbstractContextManager):
374     """
375     Record the output to more than one stream.
376
377     Example usage::
378
379         with RecordMultipleStreams(sys.stdout, sys.stderr) as r:
380             print("This is a test!", file=sys.stderr)
381             print("This is too", file=sys.stdout)
382
383         print(r().readlines())
384         r().close()
385
386     """
387
388     def __init__(self, *files) -> None:
389         super().__init__()
390         self.files = [*files]
391         self.destination = tempfile.SpooledTemporaryFile(mode='r+')
392         self.saved_writes: List[Callable[..., Any]] = []
393
394     def __enter__(self) -> Callable[[], tempfile.SpooledTemporaryFile]:
395         for f in self.files:
396             self.saved_writes.append(f.write)
397             f.write = self.destination.write
398         return lambda: self.destination
399
400     def __exit__(self, *args) -> Literal[False]:
401         for f in self.files:
402             f.write = self.saved_writes.pop()
403         self.destination.seek(0)
404         return False
405
406
407 if __name__ == '__main__':
408     import doctest
409
410     doctest.testmod()