Better logging + cleanup.
[python_utils.git] / unittest_utils.py
1 #!/usr/bin/env python3
2
3 # © Copyright 2021-2022, Scott Gasch
4
5 """Helpers for unittests.  Note that when you import this we
6 automatically wrap unittest.main() with a call to bootstrap.initialize
7 so that we getLogger config, commandline args, logging control,
8 etc... this works fine but it's a little hacky so caveat emptor.
9
10 """
11
12 import contextlib
13 import functools
14 import inspect
15 import logging
16 import os
17 import pickle
18 import random
19 import statistics
20 import tempfile
21 import time
22 import unittest
23 import warnings
24 from abc import ABC, abstractmethod
25 from typing import Any, Callable, Dict, List, Literal, Optional
26
27 import sqlalchemy as sa
28
29 import bootstrap
30 import config
31 import function_utils
32 import scott_secrets
33
34 logger = logging.getLogger(__name__)
35 cfg = config.add_commandline_args(f'Logging ({__file__})', 'Args related to function decorators')
36 cfg.add_argument(
37     '--unittests_ignore_perf',
38     action='store_true',
39     default=False,
40     help='Ignore unittest perf regression in @check_method_for_perf_regressions',
41 )
42 cfg.add_argument(
43     '--unittests_num_perf_samples',
44     type=int,
45     default=50,
46     help='The count of perf timing samples we need to see before blocking slow runs on perf grounds',
47 )
48 cfg.add_argument(
49     '--unittests_drop_perf_traces',
50     type=str,
51     nargs=1,
52     default=None,
53     help='The identifier (i.e. file!test_fixture) for which we should drop all perf data',
54 )
55 cfg.add_argument(
56     '--unittests_persistance_strategy',
57     choices=['FILE', 'DATABASE'],
58     default='DATABASE',
59     help='Should we persist perf data in a file or db?',
60 )
61 cfg.add_argument(
62     '--unittests_perfdb_filename',
63     type=str,
64     metavar='FILENAME',
65     default=f'{os.environ["HOME"]}/.python_unittest_performance_db',
66     help='File in which to store perf data (iff --unittests_persistance_strategy is FILE)',
67 )
68 cfg.add_argument(
69     '--unittests_perfdb_spec',
70     type=str,
71     metavar='DBSPEC',
72     default='mariadb+pymysql://python_unittest:<PASSWORD>@db.house:3306/python_unittest_performance',
73     help='Db connection spec for perf data (iff --unittest_persistance_strategy is DATABASE)',
74 )
75
76 # >>> This is the hacky business, FYI. <<<
77 unittest.main = bootstrap.initialize(unittest.main)
78
79
80 class PerfRegressionDataPersister(ABC):
81     """A base class for a signature dealing with persisting perf
82     regression data."""
83
84     def __init__(self):
85         pass
86
87     @abstractmethod
88     def load_performance_data(self, method_id: str) -> Dict[str, List[float]]:
89         pass
90
91     @abstractmethod
92     def save_performance_data(self, method_id: str, data: Dict[str, List[float]]):
93         pass
94
95     @abstractmethod
96     def delete_performance_data(self, method_id: str):
97         pass
98
99
100 class FileBasedPerfRegressionDataPersister(PerfRegressionDataPersister):
101     """A perf regression data persister that uses files."""
102
103     def __init__(self, filename: str):
104         super().__init__()
105         self.filename = filename
106         self.traces_to_delete: List[str] = []
107
108     def load_performance_data(self, method_id: str) -> Dict[str, List[float]]:
109         with open(self.filename, 'rb') as f:
110             return pickle.load(f)
111
112     def save_performance_data(self, method_id: str, data: Dict[str, List[float]]):
113         for trace in self.traces_to_delete:
114             if trace in data:
115                 data[trace] = []
116
117         with open(self.filename, 'wb') as f:
118             pickle.dump(data, f, pickle.HIGHEST_PROTOCOL)
119
120     def delete_performance_data(self, method_id: str):
121         self.traces_to_delete.append(method_id)
122
123
124 class DatabasePerfRegressionDataPersister(PerfRegressionDataPersister):
125     """A perf regression data persister that uses a database backend."""
126
127     def __init__(self, dbspec: str):
128         super().__init__()
129         self.dbspec = dbspec
130         self.engine = sa.create_engine(self.dbspec)
131         self.conn = self.engine.connect()
132
133     def load_performance_data(self, method_id: str) -> Dict[str, List[float]]:
134         results = self.conn.execute(
135             sa.text(f'SELECT * FROM runtimes_by_function WHERE function = "{method_id}";')
136         )
137         ret: Dict[str, List[float]] = {method_id: []}
138         for result in results.all():
139             ret[method_id].append(result['runtime'])
140         results.close()
141         return ret
142
143     def save_performance_data(self, method_id: str, data: Dict[str, List[float]]):
144         self.delete_performance_data(method_id)
145         for (mid, perf_data) in data.items():
146             sql = 'INSERT INTO runtimes_by_function (function, runtime) VALUES '
147             for perf in perf_data:
148                 self.conn.execute(sql + f'("{mid}", {perf});')
149
150     def delete_performance_data(self, method_id: str):
151         sql = f'DELETE FROM runtimes_by_function WHERE function = "{method_id}"'
152         self.conn.execute(sql)
153
154
155 def check_method_for_perf_regressions(func: Callable) -> Callable:
156     """
157     This is meant to be used on a method in a class that subclasses
158     unittest.TestCase.  When thus decorated it will time the execution
159     of the code in the method, compare it with a database of
160     historical perfmance, and fail the test with a perf-related
161     message if it has become too slow.
162
163     """
164
165     @functools.wraps(func)
166     def wrapper_perf_monitor(*args, **kwargs):
167         if config.config['unittests_ignore_perf']:
168             return func(*args, **kwargs)
169
170         if config.config['unittests_persistance_strategy'] == 'FILE':
171             filename = config.config['unittests_perfdb_filename']
172             helper = FileBasedPerfRegressionDataPersister(filename)
173         elif config.config['unittests_persistance_strategy'] == 'DATABASE':
174             dbspec = config.config['unittests_perfdb_spec']
175             dbspec = dbspec.replace('<PASSWORD>', scott_secrets.MARIADB_UNITTEST_PERF_PASSWORD)
176             helper = DatabasePerfRegressionDataPersister(dbspec)
177         else:
178             raise Exception('Unknown/unexpected --unittests_persistance_strategy value')
179
180         func_id = function_utils.function_identifier(func)
181         func_name = func.__name__
182         logger.debug('Watching %s\'s performance...', func_name)
183         logger.debug('Canonical function identifier = "%s"', func_id)
184
185         try:
186             perfdb = helper.load_performance_data(func_id)
187         except Exception as e:
188             logger.exception(e)
189             msg = 'Unable to load perfdb; skipping it...'
190             logger.warning(msg)
191             warnings.warn(msg)
192             perfdb = {}
193
194         # cmdline arg to forget perf traces for function
195         drop_id = config.config['unittests_drop_perf_traces']
196         if drop_id is not None:
197             helper.delete_performance_data(drop_id)
198
199         # Run the wrapped test paying attention to latency.
200         start_time = time.perf_counter()
201         value = func(*args, **kwargs)
202         end_time = time.perf_counter()
203         run_time = end_time - start_time
204
205         # See if it was unexpectedly slow.
206         hist = perfdb.get(func_id, [])
207         if len(hist) < config.config['unittests_num_perf_samples']:
208             hist.append(run_time)
209             logger.debug('Still establishing a perf baseline for %s', func_name)
210         else:
211             stdev = statistics.stdev(hist)
212             logger.debug('For %s, performance stdev=%.2f', func_name, stdev)
213             slowest = hist[-1]
214             logger.debug('For %s, slowest perf on record is %.2fs', func_name, slowest)
215             limit = slowest + stdev * 4
216             logger.debug('For %s, max acceptable runtime is %.2fs', func_name, limit)
217             logger.debug('For %s, actual observed runtime was %.2fs', func_name, run_time)
218             if run_time > limit:
219                 msg = f'''{func_id} performance has regressed unacceptably.
220 {slowest:f}s is the slowest runtime on record in {len(hist)} perf samples.
221 It just ran in {run_time:f}s which is 4+ stdevs slower than the slowest.
222 Here is the current, full db perf timing distribution:
223
224 '''
225                 for x in hist:
226                     msg += f'{x:f}\n'
227                 logger.error(msg)
228                 slf = args[0]  # Peek at the wrapped function's self ref.
229                 slf.fail(msg)  # ...to fail the testcase.
230             else:
231                 hist.append(run_time)
232
233         # Don't spam the database with samples; just pick a random
234         # sample from what we have and store that back.
235         n = min(config.config['unittests_num_perf_samples'], len(hist))
236         hist = random.sample(hist, n)
237         hist.sort()
238         perfdb[func_id] = hist
239         helper.save_performance_data(func_id, perfdb)
240         return value
241
242     return wrapper_perf_monitor
243
244
245 def check_all_methods_for_perf_regressions(prefix='test_'):
246     """Decorate unittests with this to pay attention to the perf of the
247     testcode and flag perf regressions.  e.g.
248
249     import unittest_utils as uu
250
251     @uu.check_all_methods_for_perf_regressions()
252     class TestMyClass(unittest.TestCase):
253
254         def test_some_part_of_my_class(self):
255             ...
256
257     """
258
259     def decorate_the_testcase(cls):
260         if issubclass(cls, unittest.TestCase):
261             for name, m in inspect.getmembers(cls, inspect.isfunction):
262                 if name.startswith(prefix):
263                     setattr(cls, name, check_method_for_perf_regressions(m))
264                     logger.debug('Wrapping %s:%s.', cls.__name__, name)
265         return cls
266
267     return decorate_the_testcase
268
269
270 class RecordStdout(contextlib.AbstractContextManager):
271     """
272     Record what is emitted to stdout.
273
274     >>> with RecordStdout() as record:
275     ...     print("This is a test!")
276     >>> print({record().readline()})
277     {'This is a test!\\n'}
278     >>> record().close()
279     """
280
281     def __init__(self) -> None:
282         super().__init__()
283         self.destination = tempfile.SpooledTemporaryFile(mode='r+')
284         self.recorder: Optional[contextlib.redirect_stdout] = None
285
286     def __enter__(self) -> Callable[[], tempfile.SpooledTemporaryFile]:
287         self.recorder = contextlib.redirect_stdout(self.destination)
288         assert self.recorder is not None
289         self.recorder.__enter__()
290         return lambda: self.destination
291
292     def __exit__(self, *args) -> Literal[False]:
293         assert self.recorder is not None
294         self.recorder.__exit__(*args)
295         self.destination.seek(0)
296         return False
297
298
299 class RecordStderr(contextlib.AbstractContextManager):
300     """
301     Record what is emitted to stderr.
302
303     >>> import sys
304     >>> with RecordStderr() as record:
305     ...     print("This is a test!", file=sys.stderr)
306     >>> print({record().readline()})
307     {'This is a test!\\n'}
308     >>> record().close()
309     """
310
311     def __init__(self) -> None:
312         super().__init__()
313         self.destination = tempfile.SpooledTemporaryFile(mode='r+')
314         self.recorder: Optional[contextlib.redirect_stdout[Any]] = None
315
316     def __enter__(self) -> Callable[[], tempfile.SpooledTemporaryFile]:
317         self.recorder = contextlib.redirect_stderr(self.destination)  # type: ignore
318         assert self.recorder is not None
319         self.recorder.__enter__()
320         return lambda: self.destination
321
322     def __exit__(self, *args) -> Literal[False]:
323         assert self.recorder is not None
324         self.recorder.__exit__(*args)
325         self.destination.seek(0)
326         return False
327
328
329 class RecordMultipleStreams(contextlib.AbstractContextManager):
330     """
331     Record the output to more than one stream.
332     """
333
334     def __init__(self, *files) -> None:
335         super().__init__()
336         self.files = [*files]
337         self.destination = tempfile.SpooledTemporaryFile(mode='r+')
338         self.saved_writes: List[Callable[..., Any]] = []
339
340     def __enter__(self) -> Callable[[], tempfile.SpooledTemporaryFile]:
341         for f in self.files:
342             self.saved_writes.append(f.write)
343             f.write = self.destination.write
344         return lambda: self.destination
345
346     def __exit__(self, *args) -> Literal[False]:
347         for f in self.files:
348             f.write = self.saved_writes.pop()
349         self.destination.seek(0)
350         return False
351
352
353 if __name__ == '__main__':
354     import doctest
355
356     doctest.testmod()