Migration from old pyutilz package name (which, in turn, came from
[pyutils.git] / src / pyutils / unittest_utils.py
1 #!/usr/bin/env python3
2
3 # © Copyright 2021-2022, Scott Gasch
4
5 """Helpers for unittests.
6
7 .. note::
8
9     When you import this we automatically wrap unittest.main()
10     with a call to bootstrap.initialize so that we getLogger
11     config, commandline args, logging control, etc... this works
12     fine but it's a little hacky so caveat emptor.
13
14 """
15
16 import contextlib
17 import functools
18 import inspect
19 import logging
20 import os
21 import pickle
22 import random
23 import statistics
24 import tempfile
25 import time
26 import unittest
27 import warnings
28 from abc import ABC, abstractmethod
29 from typing import Any, Callable, Dict, List, Literal, Optional
30
31 from pyutils import bootstrap, config, function_utils
32
33 logger = logging.getLogger(__name__)
34 cfg = config.add_commandline_args(
35     f'Logging ({__file__})', 'Args related to function decorators'
36 )
37 cfg.add_argument(
38     '--unittests_ignore_perf',
39     action='store_true',
40     default=False,
41     help='Ignore unittest perf regression in @check_method_for_perf_regressions',
42 )
43 cfg.add_argument(
44     '--unittests_num_perf_samples',
45     type=int,
46     default=50,
47     help='The count of perf timing samples we need to see before blocking slow runs on perf grounds',
48 )
49 cfg.add_argument(
50     '--unittests_drop_perf_traces',
51     type=str,
52     nargs=1,
53     default=None,
54     help='The identifier (i.e. file!test_fixture) for which we should drop all perf data',
55 )
56 cfg.add_argument(
57     '--unittests_persistance_strategy',
58     choices=['FILE', 'DATABASE'],
59     default='FILE',
60     help='Should we persist perf data in a file or db?',
61 )
62 cfg.add_argument(
63     '--unittests_perfdb_filename',
64     type=str,
65     metavar='FILENAME',
66     default=f'{os.environ["HOME"]}/.python_unittest_performance_db',
67     help='File in which to store perf data (iff --unittests_persistance_strategy is FILE)',
68 )
69 cfg.add_argument(
70     '--unittests_perfdb_spec',
71     type=str,
72     metavar='DBSPEC',
73     default='mariadb+pymysql://python_unittest:<PASSWORD>@db.house:3306/python_unittest_performance',
74     help='Db connection spec for perf data (iff --unittest_persistance_strategy is DATABASE)',
75 )
76
77 # >>> This is the hacky business, FYI. <<<
78 unittest.main = bootstrap.initialize(unittest.main)
79
80
81 class PerfRegressionDataPersister(ABC):
82     """A base class for a signature dealing with persisting perf
83     regression data."""
84
85     def __init__(self):
86         pass
87
88     @abstractmethod
89     def load_performance_data(self, method_id: str) -> Dict[str, List[float]]:
90         pass
91
92     @abstractmethod
93     def save_performance_data(self, method_id: str, data: Dict[str, List[float]]):
94         pass
95
96     @abstractmethod
97     def delete_performance_data(self, method_id: str):
98         pass
99
100
101 class FileBasedPerfRegressionDataPersister(PerfRegressionDataPersister):
102     """A perf regression data persister that uses files."""
103
104     def __init__(self, filename: str):
105         super().__init__()
106         self.filename = filename
107         self.traces_to_delete: List[str] = []
108
109     def load_performance_data(self, method_id: str) -> Dict[str, List[float]]:
110         with open(self.filename, 'rb') as f:
111             return pickle.load(f)
112
113     def save_performance_data(self, method_id: str, data: Dict[str, List[float]]):
114         for trace in self.traces_to_delete:
115             if trace in data:
116                 data[trace] = []
117
118         with open(self.filename, 'wb') as f:
119             pickle.dump(data, f, pickle.HIGHEST_PROTOCOL)
120
121     def delete_performance_data(self, method_id: str):
122         self.traces_to_delete.append(method_id)
123
124
125 # class DatabasePerfRegressionDataPersister(PerfRegressionDataPersister):
126 #    """A perf regression data persister that uses a database backend."""
127 #
128 #    def __init__(self, dbspec: str):
129 #        super().__init__()
130 #        self.dbspec = dbspec
131 #        self.engine = sa.create_engine(self.dbspec)
132 #        self.conn = self.engine.connect()
133 #
134 #    def load_performance_data(self, method_id: str) -> Dict[str, List[float]]:
135 #        results = self.conn.execute(
136 #            sa.text(f'SELECT * FROM runtimes_by_function WHERE function = "{method_id}";')
137 #        )
138 #        ret: Dict[str, List[float]] = {method_id: []}
139 #        for result in results.all():
140 #            ret[method_id].append(result['runtime'])
141 #        results.close()
142 #        return ret
143 #
144 #    def save_performance_data(self, method_id: str, data: Dict[str, List[float]]):
145 #        self.delete_performance_data(method_id)
146 #        for (mid, perf_data) in data.items():
147 #            sql = 'INSERT INTO runtimes_by_function (function, runtime) VALUES '
148 #            for perf in perf_data:
149 #                self.conn.execute(sql + f'("{mid}", {perf});')
150 #
151 #    def delete_performance_data(self, method_id: str):
152 #        sql = f'DELETE FROM runtimes_by_function WHERE function = "{method_id}"'
153 #        self.conn.execute(sql)
154
155
156 def check_method_for_perf_regressions(func: Callable) -> Callable:
157     """
158     This is meant to be used on a method in a class that subclasses
159     unittest.TestCase.  When thus decorated it will time the execution
160     of the code in the method, compare it with a database of
161     historical perfmance, and fail the test with a perf-related
162     message if it has become too slow.
163
164     """
165
166     @functools.wraps(func)
167     def wrapper_perf_monitor(*args, **kwargs):
168         if config.config['unittests_ignore_perf']:
169             return func(*args, **kwargs)
170
171         if config.config['unittests_persistance_strategy'] == 'FILE':
172             filename = config.config['unittests_perfdb_filename']
173             helper = FileBasedPerfRegressionDataPersister(filename)
174         elif config.config['unittests_persistance_strategy'] == 'DATABASE':
175             raise NotImplementedError(
176                 'Persisting to a database is not implemented in this version'
177             )
178         else:
179             raise Exception('Unknown/unexpected --unittests_persistance_strategy value')
180
181         func_id = function_utils.function_identifier(func)
182         func_name = func.__name__
183         logger.debug('Watching %s\'s performance...', func_name)
184         logger.debug('Canonical function identifier = "%s"', func_id)
185
186         try:
187             perfdb = helper.load_performance_data(func_id)
188         except Exception as e:
189             logger.exception(e)
190             msg = 'Unable to load perfdb; skipping it...'
191             logger.warning(msg)
192             warnings.warn(msg)
193             perfdb = {}
194
195         # cmdline arg to forget perf traces for function
196         drop_id = config.config['unittests_drop_perf_traces']
197         if drop_id is not None:
198             helper.delete_performance_data(drop_id)
199
200         # Run the wrapped test paying attention to latency.
201         start_time = time.perf_counter()
202         value = func(*args, **kwargs)
203         end_time = time.perf_counter()
204         run_time = end_time - start_time
205
206         # See if it was unexpectedly slow.
207         hist = perfdb.get(func_id, [])
208         if len(hist) < config.config['unittests_num_perf_samples']:
209             hist.append(run_time)
210             logger.debug('Still establishing a perf baseline for %s', func_name)
211         else:
212             stdev = statistics.stdev(hist)
213             logger.debug('For %s, performance stdev=%.2f', func_name, stdev)
214             slowest = hist[-1]
215             logger.debug('For %s, slowest perf on record is %.2fs', func_name, slowest)
216             limit = slowest + stdev * 4
217             logger.debug('For %s, max acceptable runtime is %.2fs', func_name, limit)
218             logger.debug(
219                 'For %s, actual observed runtime was %.2fs', func_name, run_time
220             )
221             if run_time > limit:
222                 msg = f'''{func_id} performance has regressed unacceptably.
223 {slowest:f}s is the slowest runtime on record in {len(hist)} perf samples.
224 It just ran in {run_time:f}s which is 4+ stdevs slower than the slowest.
225 Here is the current, full db perf timing distribution:
226
227 '''
228                 for x in hist:
229                     msg += f'{x:f}\n'
230                 logger.error(msg)
231                 slf = args[0]  # Peek at the wrapped function's self ref.
232                 slf.fail(msg)  # ...to fail the testcase.
233             else:
234                 hist.append(run_time)
235
236         # Don't spam the database with samples; just pick a random
237         # sample from what we have and store that back.
238         n = min(config.config['unittests_num_perf_samples'], len(hist))
239         hist = random.sample(hist, n)
240         hist.sort()
241         perfdb[func_id] = hist
242         helper.save_performance_data(func_id, perfdb)
243         return value
244
245     return wrapper_perf_monitor
246
247
248 def check_all_methods_for_perf_regressions(prefix='test_'):
249     """Decorate unittests with this to pay attention to the perf of the
250     testcode and flag perf regressions.  e.g.
251
252     import pyutils.unittest_utils as uu
253
254     @uu.check_all_methods_for_perf_regressions()
255     class TestMyClass(unittest.TestCase):
256
257         def test_some_part_of_my_class(self):
258             ...
259
260     """
261
262     def decorate_the_testcase(cls):
263         if issubclass(cls, unittest.TestCase):
264             for name, m in inspect.getmembers(cls, inspect.isfunction):
265                 if name.startswith(prefix):
266                     setattr(cls, name, check_method_for_perf_regressions(m))
267                     logger.debug('Wrapping %s:%s.', cls.__name__, name)
268         return cls
269
270     return decorate_the_testcase
271
272
273 class RecordStdout(contextlib.AbstractContextManager):
274     """
275     Record what is emitted to stdout.
276
277     >>> with RecordStdout() as record:
278     ...     print("This is a test!")
279     >>> print({record().readline()})
280     {'This is a test!\\n'}
281     >>> record().close()
282     """
283
284     def __init__(self) -> None:
285         super().__init__()
286         self.destination = tempfile.SpooledTemporaryFile(mode='r+')
287         self.recorder: Optional[contextlib.redirect_stdout] = None
288
289     def __enter__(self) -> Callable[[], tempfile.SpooledTemporaryFile]:
290         self.recorder = contextlib.redirect_stdout(self.destination)
291         assert self.recorder is not None
292         self.recorder.__enter__()
293         return lambda: self.destination
294
295     def __exit__(self, *args) -> Literal[False]:
296         assert self.recorder is not None
297         self.recorder.__exit__(*args)
298         self.destination.seek(0)
299         return False
300
301
302 class RecordStderr(contextlib.AbstractContextManager):
303     """
304     Record what is emitted to stderr.
305
306     >>> import sys
307     >>> with RecordStderr() as record:
308     ...     print("This is a test!", file=sys.stderr)
309     >>> print({record().readline()})
310     {'This is a test!\\n'}
311     >>> record().close()
312     """
313
314     def __init__(self) -> None:
315         super().__init__()
316         self.destination = tempfile.SpooledTemporaryFile(mode='r+')
317         self.recorder: Optional[contextlib.redirect_stdout[Any]] = None
318
319     def __enter__(self) -> Callable[[], tempfile.SpooledTemporaryFile]:
320         self.recorder = contextlib.redirect_stderr(self.destination)  # type: ignore
321         assert self.recorder is not None
322         self.recorder.__enter__()
323         return lambda: self.destination
324
325     def __exit__(self, *args) -> Literal[False]:
326         assert self.recorder is not None
327         self.recorder.__exit__(*args)
328         self.destination.seek(0)
329         return False
330
331
332 class RecordMultipleStreams(contextlib.AbstractContextManager):
333     """
334     Record the output to more than one stream.
335     """
336
337     def __init__(self, *files) -> None:
338         super().__init__()
339         self.files = [*files]
340         self.destination = tempfile.SpooledTemporaryFile(mode='r+')
341         self.saved_writes: List[Callable[..., Any]] = []
342
343     def __enter__(self) -> Callable[[], tempfile.SpooledTemporaryFile]:
344         for f in self.files:
345             self.saved_writes.append(f.write)
346             f.write = self.destination.write
347         return lambda: self.destination
348
349     def __exit__(self, *args) -> Literal[False]:
350         for f in self.files:
351             f.write = self.saved_writes.pop()
352         self.destination.seek(0)
353         return False
354
355
356 if __name__ == '__main__':
357     import doctest
358
359     doctest.testmod()