I guess it's 2023 now...
[pyutils.git] / src / pyutils / unittest_utils.py
1 #!/usr/bin/env python3
2
3 # © Copyright 2021-2023, Scott Gasch
4
5 """Helpers for unittests.
6
7 .. warning::
8
9     When you import this we automatically wrap the standard Python
10     `unittest.main` with a call to :meth:`pyutils.bootstrap.initialize`
11     so that we get logger config, commandline args, logging control,
12     etc... this works fine but may be unexpected behavior.
13 """
14
15 import contextlib
16 import functools
17 import inspect
18 import logging
19 import os
20 import pickle
21 import random
22 import statistics
23 import tempfile
24 import time
25 import unittest
26 import warnings
27 from abc import ABC, abstractmethod
28 from typing import Any, Callable, Dict, List, Literal, Optional
29
30 from pyutils import bootstrap, config, function_utils
31
32 logger = logging.getLogger(__name__)
33 cfg = config.add_commandline_args(
34     f"Logging ({__file__})", "Args related to function decorators"
35 )
36 cfg.add_argument(
37     "--unittests_ignore_perf",
38     action="store_true",
39     default=False,
40     help="Ignore unittest perf regression in @check_method_for_perf_regressions",
41 )
42 cfg.add_argument(
43     "--unittests_num_perf_samples",
44     type=int,
45     default=50,
46     help="The count of perf timing samples we need to see before blocking slow runs on perf grounds",
47 )
48 cfg.add_argument(
49     "--unittests_drop_perf_traces",
50     type=str,
51     nargs=1,
52     default=None,
53     help="The identifier (i.e. file!test_fixture) for which we should drop all perf data",
54 )
55 cfg.add_argument(
56     "--unittests_persistance_strategy",
57     choices=["FILE", "DATABASE"],
58     default="FILE",
59     help="Should we persist perf data in a file or db?",
60 )
61 cfg.add_argument(
62     "--unittests_perfdb_filename",
63     type=str,
64     metavar="FILENAME",
65     default=f'{os.environ["HOME"]}/.python_unittest_performance_db',
66     help="File in which to store perf data (iff --unittests_persistance_strategy is FILE)",
67 )
68 cfg.add_argument(
69     "--unittests_perfdb_spec",
70     type=str,
71     metavar="DBSPEC",
72     default="mariadb+pymysql://python_unittest:<PASSWORD>@db.house:3306/python_unittest_performance",
73     help="Db connection spec for perf data (iff --unittest_persistance_strategy is DATABASE)",
74 )
75
76 unittest.main = bootstrap.initialize(unittest.main)
77
78
79 class PerfRegressionDataPersister(ABC):
80     """A base class that defines an interface for dealing with
81     persisting perf regression data.
82     """
83
84     @abstractmethod
85     def load_performance_data(self, method_id: str) -> Dict[str, List[float]]:
86         """Load the historical performance data for the supplied method.
87
88         Args:
89             method_id: the method for which we want historical perf data.
90         """
91         pass
92
93     @abstractmethod
94     def save_performance_data(self, method_id: str, data: Dict[str, List[float]]):
95         """Save the historical performance data of the supplied method.
96
97         Args:
98             method_id: the method whose historical perf data we're saving.
99             data: the historical performance data being persisted.
100         """
101         pass
102
103     @abstractmethod
104     def delete_performance_data(self, method_id: str):
105         """Delete the historical performance data of the supplied method.
106
107         Args:
108             method_id: the method whose data should be erased.
109         """
110         pass
111
112
113 class FileBasedPerfRegressionDataPersister(PerfRegressionDataPersister):
114     """A perf regression data persister that uses files."""
115
116     def __init__(self, filename: str):
117         """
118         Args:
119             filename: the filename to save/load historical performance data
120         """
121         super().__init__()
122         self.filename = filename
123         self.traces_to_delete: List[str] = []
124
125     def load_performance_data(self, method_id: str) -> Dict[str, List[float]]:
126         with open(self.filename, "rb") as f:
127             return pickle.load(f)
128
129     def save_performance_data(self, method_id: str, data: Dict[str, List[float]]):
130         for trace in self.traces_to_delete:
131             if trace in data:
132                 data[trace] = []
133
134         with open(self.filename, "wb") as f:
135             pickle.dump(data, f, pickle.HIGHEST_PROTOCOL)
136
137     def delete_performance_data(self, method_id: str):
138         self.traces_to_delete.append(method_id)
139
140
141 # class DatabasePerfRegressionDataPersister(PerfRegressionDataPersister):
142 #    """A perf regression data persister that uses a database backend."""
143 #
144 #    def __init__(self, dbspec: str):
145 #        super().__init__()
146 #        self.dbspec = dbspec
147 #        self.engine = sa.create_engine(self.dbspec)
148 #        self.conn = self.engine.connect()
149 #
150 #    def load_performance_data(self, method_id: str) -> Dict[str, List[float]]:
151 #        results = self.conn.execute(
152 #            sa.text(f'SELECT * FROM runtimes_by_function WHERE function = "{method_id}";')
153 #        )
154 #        ret: Dict[str, List[float]] = {method_id: []}
155 #        for result in results.all():
156 #            ret[method_id].append(result['runtime'])
157 #        results.close()
158 #        return ret
159 #
160 #    def save_performance_data(self, method_id: str, data: Dict[str, List[float]]):
161 #        self.delete_performance_data(method_id)
162 #        for (mid, perf_data) in data.items():
163 #            sql = 'INSERT INTO runtimes_by_function (function, runtime) VALUES '
164 #            for perf in perf_data:
165 #                self.conn.execute(sql + f'("{mid}", {perf});')
166 #
167 #    def delete_performance_data(self, method_id: str):
168 #        sql = f'DELETE FROM runtimes_by_function WHERE function = "{method_id}"'
169 #        self.conn.execute(sql)
170
171
172 def check_method_for_perf_regressions(func: Callable) -> Callable:
173     """This decorator is meant to be used on a method in a class that
174     subclasses :class:`unittest.TestCase`.  When decorated, method
175     execution timing (i.e. performance) will be measured and compared
176     with a database of historical performance for the same method.
177     The wrapper will then fail the test with a perf-related message if
178     it has become too slow.
179
180     See also :meth:`check_all_methods_for_perf_regressions`.
181
182     Example usage::
183
184         class TestMyClass(unittest.TestCase):
185
186             @check_method_for_perf_regressions
187             def test_some_part_of_my_class(self):
188                 ...
189
190     """
191
192     @functools.wraps(func)
193     def wrapper_perf_monitor(*args, **kwargs):
194         if config.config["unittests_ignore_perf"]:
195             return func(*args, **kwargs)
196
197         if config.config["unittests_persistance_strategy"] == "FILE":
198             filename = config.config["unittests_perfdb_filename"]
199             helper = FileBasedPerfRegressionDataPersister(filename)
200         elif config.config["unittests_persistance_strategy"] == "DATABASE":
201             raise NotImplementedError(
202                 "Persisting to a database is not implemented in this version"
203             )
204         else:
205             raise Exception("Unknown/unexpected --unittests_persistance_strategy value")
206
207         func_id = function_utils.function_identifier(func)
208         func_name = func.__name__
209         logger.debug("Watching %s's performance...", func_name)
210         logger.debug('Canonical function identifier = "%s"', func_id)
211
212         try:
213             perfdb = helper.load_performance_data(func_id)
214         except Exception:
215             msg = "Unable to load perfdb; skipping it..."
216             logger.exception(msg)
217             warnings.warn(msg)
218             perfdb = {}
219
220         # cmdline arg to forget perf traces for function
221         drop_id = config.config["unittests_drop_perf_traces"]
222         if drop_id is not None:
223             helper.delete_performance_data(drop_id)
224
225         # Run the wrapped test paying attention to latency.
226         start_time = time.perf_counter()
227         value = func(*args, **kwargs)
228         end_time = time.perf_counter()
229         run_time = end_time - start_time
230
231         # See if it was unexpectedly slow.
232         hist = perfdb.get(func_id, [])
233         if len(hist) < config.config["unittests_num_perf_samples"]:
234             hist.append(run_time)
235             logger.debug("Still establishing a perf baseline for %s", func_name)
236         else:
237             stdev = statistics.stdev(hist)
238             logger.debug("For %s, performance stdev=%.2f", func_name, stdev)
239             slowest = hist[-1]
240             logger.debug("For %s, slowest perf on record is %.2fs", func_name, slowest)
241             limit = slowest + stdev * 4
242             logger.debug("For %s, max acceptable runtime is %.2fs", func_name, limit)
243             logger.debug(
244                 "For %s, actual observed runtime was %.2fs", func_name, run_time
245             )
246             if run_time > limit:
247                 msg = f"""{func_id} performance has regressed unacceptably.
248 {slowest:f}s is the slowest runtime on record in {len(hist)} perf samples.
249 It just ran in {run_time:f}s which is 4+ stdevs slower than the slowest.
250 Here is the current, full db perf timing distribution:
251
252 """
253                 for x in hist:
254                     msg += f"{x:f}\n"
255                 logger.error(msg)
256                 slf = args[0]  # Peek at the wrapped function's self ref.
257                 slf.fail(msg)  # ...to fail the testcase.
258             else:
259                 hist.append(run_time)
260
261         # Don't spam the database with samples; just pick a random
262         # sample from what we have and store that back.
263         n = min(config.config["unittests_num_perf_samples"], len(hist))
264         hist = random.sample(hist, n)
265         hist.sort()
266         perfdb[func_id] = hist
267         helper.save_performance_data(func_id, perfdb)
268         return value
269
270     return wrapper_perf_monitor
271
272
273 def check_all_methods_for_perf_regressions(prefix="test_"):
274     """This decorator is meant to apply to classes that subclass from
275     :class:`unittest.TestCase` and, when applied, has the affect of
276     decorating each method that matches the `prefix` given with the
277     :meth:`check_method_for_perf_regressions` wrapper (see above).
278     This wrapper causes us to measure perf and fail tests that regress
279     perf dramatically.
280
281     Args:
282         prefix: the prefix of method names to check for regressions
283
284     See also :meth:`check_method_for_perf_regressions` to check only
285     a single method.
286
287     Example usage.  By decorating the class, all methods with names
288     that begin with `test_` will be perf monitored::
289
290         import pyutils.unittest_utils as uu
291
292         @uu.check_all_methods_for_perf_regressions()
293         class TestMyClass(unittest.TestCase):
294
295             def test_some_part_of_my_class(self):
296                 ...
297
298             def test_som_other_part_of_my_class(self):
299                 ...
300     """
301
302     def decorate_the_testcase(cls):
303         if issubclass(cls, unittest.TestCase):
304             for name, m in inspect.getmembers(cls, inspect.isfunction):
305                 if name.startswith(prefix):
306                     setattr(cls, name, check_method_for_perf_regressions(m))
307                     logger.debug("Wrapping %s:%s.", cls.__name__, name)
308         return cls
309
310     return decorate_the_testcase
311
312
313 class RecordStdout(contextlib.AbstractContextManager):
314     """
315     Records what is emitted to stdout into a buffer instead.
316
317     >>> with RecordStdout() as record:
318     ...     print("This is a test!")
319     >>> print({record().readline()})
320     {'This is a test!\\n'}
321     >>> record().close()
322     """
323
324     def __init__(self) -> None:
325         super().__init__()
326         self.destination = tempfile.SpooledTemporaryFile(mode="r+")
327         self.recorder: Optional[contextlib.redirect_stdout] = None
328
329     def __enter__(self) -> Callable[[], tempfile.SpooledTemporaryFile]:
330         self.recorder = contextlib.redirect_stdout(self.destination)
331         assert self.recorder is not None
332         self.recorder.__enter__()
333         return lambda: self.destination
334
335     def __exit__(self, *args) -> Literal[False]:
336         assert self.recorder is not None
337         self.recorder.__exit__(*args)
338         self.destination.seek(0)
339         return False
340
341
342 class RecordStderr(contextlib.AbstractContextManager):
343     """
344     Record what is emitted to stderr.
345
346     >>> import sys
347     >>> with RecordStderr() as record:
348     ...     print("This is a test!", file=sys.stderr)
349     >>> print({record().readline()})
350     {'This is a test!\\n'}
351     >>> record().close()
352     """
353
354     def __init__(self) -> None:
355         super().__init__()
356         self.destination = tempfile.SpooledTemporaryFile(mode="r+")
357         self.recorder: Optional[contextlib.redirect_stdout[Any]] = None
358
359     def __enter__(self) -> Callable[[], tempfile.SpooledTemporaryFile]:
360         self.recorder = contextlib.redirect_stderr(self.destination)  # type: ignore
361         assert self.recorder is not None
362         self.recorder.__enter__()
363         return lambda: self.destination
364
365     def __exit__(self, *args) -> Literal[False]:
366         assert self.recorder is not None
367         self.recorder.__exit__(*args)
368         self.destination.seek(0)
369         return False
370
371
372 class RecordMultipleStreams(contextlib.AbstractContextManager):
373     """
374     Record the output to more than one stream.
375
376     Example usage::
377
378         with RecordMultipleStreams(sys.stdout, sys.stderr) as r:
379             print("This is a test!", file=sys.stderr)
380             print("This is too", file=sys.stdout)
381
382         print(r().readlines())
383         r().close()
384
385     """
386
387     def __init__(self, *files) -> None:
388         super().__init__()
389         self.files = [*files]
390         self.destination = tempfile.SpooledTemporaryFile(mode="r+")
391         self.saved_writes: List[Callable[..., Any]] = []
392
393     def __enter__(self) -> Callable[[], tempfile.SpooledTemporaryFile]:
394         for f in self.files:
395             self.saved_writes.append(f.write)
396             f.write = self.destination.write
397         return lambda: self.destination
398
399     def __exit__(self, *args) -> Literal[False]:
400         for f in self.files:
401             f.write = self.saved_writes.pop()
402         self.destination.seek(0)
403         return False
404
405
406 if __name__ == "__main__":
407     import doctest
408
409     doctest.testmod()