Easier and more self documenting patterns for loading/saving Persistent
[python_utils.git] / unittest_utils.py
1 #!/usr/bin/env python3
2
3 # © Copyright 2021-2022, Scott Gasch
4
5 """Helpers for unittests.
6
7 .. note::
8
9     When you import this we automatically wrap unittest.main()
10     with a call to bootstrap.initialize so that we getLogger
11     config, commandline args, logging control, etc... this works
12     fine but it's a little hacky so caveat emptor.
13 """
14
15 import contextlib
16 import functools
17 import inspect
18 import logging
19 import os
20 import pickle
21 import random
22 import statistics
23 import tempfile
24 import time
25 import unittest
26 import warnings
27 from abc import ABC, abstractmethod
28 from typing import Any, Callable, Dict, List, Literal, Optional
29
30 import sqlalchemy as sa
31
32 import bootstrap
33 import config
34 import function_utils
35 import scott_secrets
36
37 logger = logging.getLogger(__name__)
38 cfg = config.add_commandline_args(f'Logging ({__file__})', 'Args related to function decorators')
39 cfg.add_argument(
40     '--unittests_ignore_perf',
41     action='store_true',
42     default=False,
43     help='Ignore unittest perf regression in @check_method_for_perf_regressions',
44 )
45 cfg.add_argument(
46     '--unittests_num_perf_samples',
47     type=int,
48     default=50,
49     help='The count of perf timing samples we need to see before blocking slow runs on perf grounds',
50 )
51 cfg.add_argument(
52     '--unittests_drop_perf_traces',
53     type=str,
54     nargs=1,
55     default=None,
56     help='The identifier (i.e. file!test_fixture) for which we should drop all perf data',
57 )
58 cfg.add_argument(
59     '--unittests_persistance_strategy',
60     choices=['FILE', 'DATABASE'],
61     default='DATABASE',
62     help='Should we persist perf data in a file or db?',
63 )
64 cfg.add_argument(
65     '--unittests_perfdb_filename',
66     type=str,
67     metavar='FILENAME',
68     default=f'{os.environ["HOME"]}/.python_unittest_performance_db',
69     help='File in which to store perf data (iff --unittests_persistance_strategy is FILE)',
70 )
71 cfg.add_argument(
72     '--unittests_perfdb_spec',
73     type=str,
74     metavar='DBSPEC',
75     default='mariadb+pymysql://python_unittest:<PASSWORD>@db.house:3306/python_unittest_performance',
76     help='Db connection spec for perf data (iff --unittest_persistance_strategy is DATABASE)',
77 )
78
79 # >>> This is the hacky business, FYI. <<<
80 unittest.main = bootstrap.initialize(unittest.main)
81
82
83 class PerfRegressionDataPersister(ABC):
84     """A base class for a signature dealing with persisting perf
85     regression data."""
86
87     def __init__(self):
88         pass
89
90     @abstractmethod
91     def load_performance_data(self, method_id: str) -> Dict[str, List[float]]:
92         pass
93
94     @abstractmethod
95     def save_performance_data(self, method_id: str, data: Dict[str, List[float]]):
96         pass
97
98     @abstractmethod
99     def delete_performance_data(self, method_id: str):
100         pass
101
102
103 class FileBasedPerfRegressionDataPersister(PerfRegressionDataPersister):
104     """A perf regression data persister that uses files."""
105
106     def __init__(self, filename: str):
107         super().__init__()
108         self.filename = filename
109         self.traces_to_delete: List[str] = []
110
111     def load_performance_data(self, method_id: str) -> Dict[str, List[float]]:
112         with open(self.filename, 'rb') as f:
113             return pickle.load(f)
114
115     def save_performance_data(self, method_id: str, data: Dict[str, List[float]]):
116         for trace in self.traces_to_delete:
117             if trace in data:
118                 data[trace] = []
119
120         with open(self.filename, 'wb') as f:
121             pickle.dump(data, f, pickle.HIGHEST_PROTOCOL)
122
123     def delete_performance_data(self, method_id: str):
124         self.traces_to_delete.append(method_id)
125
126
127 class DatabasePerfRegressionDataPersister(PerfRegressionDataPersister):
128     """A perf regression data persister that uses a database backend."""
129
130     def __init__(self, dbspec: str):
131         super().__init__()
132         self.dbspec = dbspec
133         self.engine = sa.create_engine(self.dbspec)
134         self.conn = self.engine.connect()
135
136     def load_performance_data(self, method_id: str) -> Dict[str, List[float]]:
137         results = self.conn.execute(
138             sa.text(f'SELECT * FROM runtimes_by_function WHERE function = "{method_id}";')
139         )
140         ret: Dict[str, List[float]] = {method_id: []}
141         for result in results.all():
142             ret[method_id].append(result['runtime'])
143         results.close()
144         return ret
145
146     def save_performance_data(self, method_id: str, data: Dict[str, List[float]]):
147         self.delete_performance_data(method_id)
148         for (mid, perf_data) in data.items():
149             sql = 'INSERT INTO runtimes_by_function (function, runtime) VALUES '
150             for perf in perf_data:
151                 self.conn.execute(sql + f'("{mid}", {perf});')
152
153     def delete_performance_data(self, method_id: str):
154         sql = f'DELETE FROM runtimes_by_function WHERE function = "{method_id}"'
155         self.conn.execute(sql)
156
157
158 def check_method_for_perf_regressions(func: Callable) -> Callable:
159     """
160     This is meant to be used on a method in a class that subclasses
161     unittest.TestCase.  When thus decorated it will time the execution
162     of the code in the method, compare it with a database of
163     historical perfmance, and fail the test with a perf-related
164     message if it has become too slow.
165
166     """
167
168     @functools.wraps(func)
169     def wrapper_perf_monitor(*args, **kwargs):
170         if config.config['unittests_ignore_perf']:
171             return func(*args, **kwargs)
172
173         if config.config['unittests_persistance_strategy'] == 'FILE':
174             filename = config.config['unittests_perfdb_filename']
175             helper = FileBasedPerfRegressionDataPersister(filename)
176         elif config.config['unittests_persistance_strategy'] == 'DATABASE':
177             dbspec = config.config['unittests_perfdb_spec']
178             dbspec = dbspec.replace('<PASSWORD>', scott_secrets.MARIADB_UNITTEST_PERF_PASSWORD)
179             helper = DatabasePerfRegressionDataPersister(dbspec)
180         else:
181             raise Exception('Unknown/unexpected --unittests_persistance_strategy value')
182
183         func_id = function_utils.function_identifier(func)
184         func_name = func.__name__
185         logger.debug('Watching %s\'s performance...', func_name)
186         logger.debug('Canonical function identifier = "%s"', func_id)
187
188         try:
189             perfdb = helper.load_performance_data(func_id)
190         except Exception as e:
191             logger.exception(e)
192             msg = 'Unable to load perfdb; skipping it...'
193             logger.warning(msg)
194             warnings.warn(msg)
195             perfdb = {}
196
197         # cmdline arg to forget perf traces for function
198         drop_id = config.config['unittests_drop_perf_traces']
199         if drop_id is not None:
200             helper.delete_performance_data(drop_id)
201
202         # Run the wrapped test paying attention to latency.
203         start_time = time.perf_counter()
204         value = func(*args, **kwargs)
205         end_time = time.perf_counter()
206         run_time = end_time - start_time
207
208         # See if it was unexpectedly slow.
209         hist = perfdb.get(func_id, [])
210         if len(hist) < config.config['unittests_num_perf_samples']:
211             hist.append(run_time)
212             logger.debug('Still establishing a perf baseline for %s', func_name)
213         else:
214             stdev = statistics.stdev(hist)
215             logger.debug('For %s, performance stdev=%.2f', func_name, stdev)
216             slowest = hist[-1]
217             logger.debug('For %s, slowest perf on record is %.2fs', func_name, slowest)
218             limit = slowest + stdev * 4
219             logger.debug('For %s, max acceptable runtime is %.2fs', func_name, limit)
220             logger.debug('For %s, actual observed runtime was %.2fs', func_name, run_time)
221             if run_time > limit:
222                 msg = f'''{func_id} performance has regressed unacceptably.
223 {slowest:f}s is the slowest runtime on record in {len(hist)} perf samples.
224 It just ran in {run_time:f}s which is 4+ stdevs slower than the slowest.
225 Here is the current, full db perf timing distribution:
226
227 '''
228                 for x in hist:
229                     msg += f'{x:f}\n'
230                 logger.error(msg)
231                 slf = args[0]  # Peek at the wrapped function's self ref.
232                 slf.fail(msg)  # ...to fail the testcase.
233             else:
234                 hist.append(run_time)
235
236         # Don't spam the database with samples; just pick a random
237         # sample from what we have and store that back.
238         n = min(config.config['unittests_num_perf_samples'], len(hist))
239         hist = random.sample(hist, n)
240         hist.sort()
241         perfdb[func_id] = hist
242         helper.save_performance_data(func_id, perfdb)
243         return value
244
245     return wrapper_perf_monitor
246
247
248 def check_all_methods_for_perf_regressions(prefix='test_'):
249     """Decorate unittests with this to pay attention to the perf of the
250     testcode and flag perf regressions.  e.g.
251
252     import unittest_utils as uu
253
254     @uu.check_all_methods_for_perf_regressions()
255     class TestMyClass(unittest.TestCase):
256
257         def test_some_part_of_my_class(self):
258             ...
259
260     """
261
262     def decorate_the_testcase(cls):
263         if issubclass(cls, unittest.TestCase):
264             for name, m in inspect.getmembers(cls, inspect.isfunction):
265                 if name.startswith(prefix):
266                     setattr(cls, name, check_method_for_perf_regressions(m))
267                     logger.debug('Wrapping %s:%s.', cls.__name__, name)
268         return cls
269
270     return decorate_the_testcase
271
272
273 class RecordStdout(contextlib.AbstractContextManager):
274     """
275     Record what is emitted to stdout.
276
277     >>> with RecordStdout() as record:
278     ...     print("This is a test!")
279     >>> print({record().readline()})
280     {'This is a test!\\n'}
281     >>> record().close()
282     """
283
284     def __init__(self) -> None:
285         super().__init__()
286         self.destination = tempfile.SpooledTemporaryFile(mode='r+')
287         self.recorder: Optional[contextlib.redirect_stdout] = None
288
289     def __enter__(self) -> Callable[[], tempfile.SpooledTemporaryFile]:
290         self.recorder = contextlib.redirect_stdout(self.destination)
291         assert self.recorder is not None
292         self.recorder.__enter__()
293         return lambda: self.destination
294
295     def __exit__(self, *args) -> Literal[False]:
296         assert self.recorder is not None
297         self.recorder.__exit__(*args)
298         self.destination.seek(0)
299         return False
300
301
302 class RecordStderr(contextlib.AbstractContextManager):
303     """
304     Record what is emitted to stderr.
305
306     >>> import sys
307     >>> with RecordStderr() as record:
308     ...     print("This is a test!", file=sys.stderr)
309     >>> print({record().readline()})
310     {'This is a test!\\n'}
311     >>> record().close()
312     """
313
314     def __init__(self) -> None:
315         super().__init__()
316         self.destination = tempfile.SpooledTemporaryFile(mode='r+')
317         self.recorder: Optional[contextlib.redirect_stdout[Any]] = None
318
319     def __enter__(self) -> Callable[[], tempfile.SpooledTemporaryFile]:
320         self.recorder = contextlib.redirect_stderr(self.destination)  # type: ignore
321         assert self.recorder is not None
322         self.recorder.__enter__()
323         return lambda: self.destination
324
325     def __exit__(self, *args) -> Literal[False]:
326         assert self.recorder is not None
327         self.recorder.__exit__(*args)
328         self.destination.seek(0)
329         return False
330
331
332 class RecordMultipleStreams(contextlib.AbstractContextManager):
333     """
334     Record the output to more than one stream.
335     """
336
337     def __init__(self, *files) -> None:
338         super().__init__()
339         self.files = [*files]
340         self.destination = tempfile.SpooledTemporaryFile(mode='r+')
341         self.saved_writes: List[Callable[..., Any]] = []
342
343     def __enter__(self) -> Callable[[], tempfile.SpooledTemporaryFile]:
344         for f in self.files:
345             self.saved_writes.append(f.write)
346             f.write = self.destination.write
347         return lambda: self.destination
348
349     def __exit__(self, *args) -> Literal[False]:
350         for f in self.files:
351             f.write = self.saved_writes.pop()
352         self.destination.seek(0)
353         return False
354
355
356 if __name__ == '__main__':
357     import doctest
358
359     doctest.testmod()