Ran black code formatter on everything.
[python_utils.git] / unittest_utils.py
1 #!/usr/bin/env python3
2
3 """Helpers for unittests.  Note that when you import this we
4    automatically wrap unittest.main() with a call to
5    bootstrap.initialize so that we getLogger config, commandline args,
6    logging control, etc... this works fine but it's a little hacky so
7    caveat emptor.
8 """
9
10 from abc import ABC, abstractmethod
11 import contextlib
12 import functools
13 import inspect
14 import logging
15 import os
16 import pickle
17 import random
18 import statistics
19 import time
20 import tempfile
21 from typing import Callable, Dict, List
22 import unittest
23 import warnings
24
25 import bootstrap
26 import config
27 import function_utils
28 import scott_secrets
29
30 import sqlalchemy as sa
31
32
33 logger = logging.getLogger(__name__)
34 cfg = config.add_commandline_args(
35     f'Logging ({__file__})', 'Args related to function decorators'
36 )
37 cfg.add_argument(
38     '--unittests_ignore_perf',
39     action='store_true',
40     default=False,
41     help='Ignore unittest perf regression in @check_method_for_perf_regressions',
42 )
43 cfg.add_argument(
44     '--unittests_num_perf_samples',
45     type=int,
46     default=50,
47     help='The count of perf timing samples we need to see before blocking slow runs on perf grounds',
48 )
49 cfg.add_argument(
50     '--unittests_drop_perf_traces',
51     type=str,
52     nargs=1,
53     default=None,
54     help='The identifier (i.e. file!test_fixture) for which we should drop all perf data',
55 )
56 cfg.add_argument(
57     '--unittests_persistance_strategy',
58     choices=['FILE', 'DATABASE'],
59     default='DATABASE',
60     help='Should we persist perf data in a file or db?',
61 )
62 cfg.add_argument(
63     '--unittests_perfdb_filename',
64     type=str,
65     metavar='FILENAME',
66     default=f'{os.environ["HOME"]}/.python_unittest_performance_db',
67     help='File in which to store perf data (iff --unittests_persistance_strategy is FILE)',
68 )
69 cfg.add_argument(
70     '--unittests_perfdb_spec',
71     type=str,
72     metavar='DBSPEC',
73     default='mariadb+pymysql://python_unittest:<PASSWORD>@db.house:3306/python_unittest_performance',
74     help='Db connection spec for perf data (iff --unittest_persistance_strategy is DATABASE)',
75 )
76
77 # >>> This is the hacky business, FYI. <<<
78 unittest.main = bootstrap.initialize(unittest.main)
79
80
81 class PerfRegressionDataPersister(ABC):
82     def __init__(self):
83         pass
84
85     @abstractmethod
86     def load_performance_data(self) -> Dict[str, List[float]]:
87         pass
88
89     @abstractmethod
90     def save_performance_data(
91         self, method_id: str, data: Dict[str, List[float]]
92     ):
93         pass
94
95     @abstractmethod
96     def delete_performance_data(self, method_id: str):
97         pass
98
99
100 class FileBasedPerfRegressionDataPersister(PerfRegressionDataPersister):
101     def __init__(self, filename: str):
102         self.filename = filename
103         self.traces_to_delete = []
104
105     def load_performance_data(self, method_id: str) -> Dict[str, List[float]]:
106         with open(self.filename, 'rb') as f:
107             return pickle.load(f)
108
109     def save_performance_data(
110         self, method_id: str, data: Dict[str, List[float]]
111     ):
112         for trace in self.traces_to_delete:
113             if trace in data:
114                 data[trace] = []
115
116         with open(self.filename, 'wb') as f:
117             pickle.dump(data, f, pickle.HIGHEST_PROTOCOL)
118
119     def delete_performance_data(self, method_id: str):
120         self.traces_to_delete.append(method_id)
121
122
123 class DatabasePerfRegressionDataPersister(PerfRegressionDataPersister):
124     def __init__(self, dbspec: str):
125         self.dbspec = dbspec
126         self.engine = sa.create_engine(self.dbspec)
127         self.conn = self.engine.connect()
128
129     def load_performance_data(self, method_id: str) -> Dict[str, List[float]]:
130         results = self.conn.execute(
131             sa.text(
132                 f'SELECT * FROM runtimes_by_function WHERE function = "{method_id}";'
133             )
134         )
135         ret = {method_id: []}
136         for result in results.all():
137             ret[method_id].append(result['runtime'])
138         results.close()
139         return ret
140
141     def save_performance_data(
142         self, method_id: str, data: Dict[str, List[float]]
143     ):
144         self.delete_performance_data(method_id)
145         for (method_id, perf_data) in data.items():
146             sql = 'INSERT INTO runtimes_by_function (function, runtime) VALUES '
147             for perf in perf_data:
148                 self.conn.execute(sql + f'("{method_id}", {perf});')
149
150     def delete_performance_data(self, method_id: str):
151         sql = f'DELETE FROM runtimes_by_function WHERE function = "{method_id}"'
152         self.conn.execute(sql)
153
154
155 def check_method_for_perf_regressions(func: Callable) -> Callable:
156     """
157     This is meant to be used on a method in a class that subclasses
158     unittest.TestCase.  When thus decorated it will time the execution
159     of the code in the method, compare it with a database of
160     historical perfmance, and fail the test with a perf-related
161     message if it has become too slow.
162
163     """
164
165     @functools.wraps(func)
166     def wrapper_perf_monitor(*args, **kwargs):
167         if config.config['unittests_persistance_strategy'] == 'FILE':
168             filename = config.config['unittests_perfdb_filename']
169             helper = FileBasedPerfRegressionDataPersister(filename)
170         elif config.config['unittests_persistance_strategy'] == 'DATABASE':
171             dbspec = config.config['unittests_perfdb_spec']
172             dbspec = dbspec.replace(
173                 '<PASSWORD>', scott_secrets.MARIADB_UNITTEST_PERF_PASSWORD
174             )
175             helper = DatabasePerfRegressionDataPersister(dbspec)
176         else:
177             raise Exception(
178                 'Unknown/unexpected --unittests_persistance_strategy value'
179             )
180
181         func_id = function_utils.function_identifier(func)
182         func_name = func.__name__
183         logger.debug(f'Watching {func_name}\'s performance...')
184         logger.debug(f'Canonical function identifier = {func_id}')
185
186         try:
187             perfdb = helper.load_performance_data(func_id)
188         except Exception as e:
189             logger.exception(e)
190             msg = 'Unable to load perfdb; skipping it...'
191             logger.warning(msg)
192             warnings.warn(msg)
193             perfdb = {}
194
195         # cmdline arg to forget perf traces for function
196         drop_id = config.config['unittests_drop_perf_traces']
197         if drop_id is not None:
198             helper.delete_performance_data(drop_id)
199
200         # Run the wrapped test paying attention to latency.
201         start_time = time.perf_counter()
202         value = func(*args, **kwargs)
203         end_time = time.perf_counter()
204         run_time = end_time - start_time
205
206         # See if it was unexpectedly slow.
207         hist = perfdb.get(func_id, [])
208         if len(hist) < config.config['unittests_num_perf_samples']:
209             hist.append(run_time)
210             logger.debug(f'Still establishing a perf baseline for {func_name}')
211         else:
212             stdev = statistics.stdev(hist)
213             logger.debug(f'For {func_name}, performance stdev={stdev}')
214             slowest = hist[-1]
215             logger.debug(
216                 f'For {func_name}, slowest perf on record is {slowest:f}s'
217             )
218             limit = slowest + stdev * 4
219             logger.debug(
220                 f'For {func_name}, max acceptable runtime is {limit:f}s'
221             )
222             logger.debug(
223                 f'For {func_name}, actual observed runtime was {run_time:f}s'
224             )
225             if run_time > limit and not config.config['unittests_ignore_perf']:
226                 msg = f'''{func_id} performance has regressed unacceptably.
227 {slowest:f}s is the slowest runtime on record in {len(hist)} perf samples.
228 It just ran in {run_time:f}s which is 4+ stdevs slower than the slowest.
229 Here is the current, full db perf timing distribution:
230
231 '''
232                 for x in hist:
233                     msg += f'{x:f}\n'
234                 logger.error(msg)
235                 slf = args[0]  # Peek at the wrapped function's self ref.
236                 slf.fail(msg)  # ...to fail the testcase.
237             else:
238                 hist.append(run_time)
239
240         # Don't spam the database with samples; just pick a random
241         # sample from what we have and store that back.
242         n = min(config.config['unittests_num_perf_samples'], len(hist))
243         hist = random.sample(hist, n)
244         hist.sort()
245         perfdb[func_id] = hist
246         helper.save_performance_data(func_id, perfdb)
247         return value
248
249     return wrapper_perf_monitor
250
251
252 def check_all_methods_for_perf_regressions(prefix='test_'):
253     """Decorate unittests with this to pay attention to the perf of the
254     testcode and flag perf regressions.  e.g.
255
256     import unittest_utils as uu
257
258     @uu.check_all_methods_for_perf_regressions()
259     class TestMyClass(unittest.TestCase):
260
261         def test_some_part_of_my_class(self):
262             ...
263
264     """
265
266     def decorate_the_testcase(cls):
267         if issubclass(cls, unittest.TestCase):
268             for name, m in inspect.getmembers(cls, inspect.isfunction):
269                 if name.startswith(prefix):
270                     setattr(cls, name, check_method_for_perf_regressions(m))
271                     logger.debug(f'Wrapping {cls.__name__}:{name}.')
272         return cls
273
274     return decorate_the_testcase
275
276
277 def breakpoint():
278     """Hard code a breakpoint somewhere; drop into pdb."""
279     import pdb
280
281     pdb.set_trace()
282
283
284 class RecordStdout(object):
285     """
286     Record what is emitted to stdout.
287
288     >>> with RecordStdout() as record:
289     ...     print("This is a test!")
290     >>> print({record().readline()})
291     {'This is a test!\\n'}
292     """
293
294     def __init__(self) -> None:
295         self.destination = tempfile.SpooledTemporaryFile(mode='r+')
296         self.recorder = None
297
298     def __enter__(self) -> Callable[[], tempfile.SpooledTemporaryFile]:
299         self.recorder = contextlib.redirect_stdout(self.destination)
300         self.recorder.__enter__()
301         return lambda: self.destination
302
303     def __exit__(self, *args) -> bool:
304         self.recorder.__exit__(*args)
305         self.destination.seek(0)
306         return None
307
308
309 class RecordStderr(object):
310     """
311     Record what is emitted to stderr.
312
313     >>> import sys
314     >>> with RecordStderr() as record:
315     ...     print("This is a test!", file=sys.stderr)
316     >>> print({record().readline()})
317     {'This is a test!\\n'}
318     """
319
320     def __init__(self) -> None:
321         self.destination = tempfile.SpooledTemporaryFile(mode='r+')
322         self.recorder = None
323
324     def __enter__(self) -> Callable[[], tempfile.SpooledTemporaryFile]:
325         self.recorder = contextlib.redirect_stderr(self.destination)
326         self.recorder.__enter__()
327         return lambda: self.destination
328
329     def __exit__(self, *args) -> bool:
330         self.recorder.__exit__(*args)
331         self.destination.seek(0)
332         return None
333
334
335 class RecordMultipleStreams(object):
336     """
337     Record the output to more than one stream.
338     """
339
340     def __init__(self, *files) -> None:
341         self.files = [*files]
342         self.destination = tempfile.SpooledTemporaryFile(mode='r+')
343         self.saved_writes = []
344
345     def __enter__(self) -> Callable[[], tempfile.SpooledTemporaryFile]:
346         for f in self.files:
347             self.saved_writes.append(f.write)
348             f.write = self.destination.write
349         return lambda: self.destination
350
351     def __exit__(self, *args) -> bool:
352         for f in self.files:
353             f.write = self.saved_writes.pop()
354         self.destination.seek(0)
355
356
357 if __name__ == '__main__':
358     import doctest
359
360     doctest.testmod()