Optionally surface exceptions that happen under executors by reading
[python_utils.git] / unittest_utils.py
1 #!/usr/bin/env python3
2
3 """Helpers for unittests.  Note that when you import this we
4    automatically wrap unittest.main() with a call to
5    bootstrap.initialize so that we getLogger config, commandline args,
6    logging control, etc... this works fine but it's a little hacky so
7    caveat emptor.
8 """
9
10 from abc import ABC, abstractmethod
11 import contextlib
12 import functools
13 import inspect
14 import logging
15 import os
16 import pickle
17 import random
18 import statistics
19 import time
20 import tempfile
21 from typing import Callable, Dict, List
22 import unittest
23 import warnings
24
25 import bootstrap
26 import config
27 import function_utils
28 import scott_secrets
29
30 import sqlalchemy as sa
31
32
33 logger = logging.getLogger(__name__)
34 cfg = config.add_commandline_args(
35     f'Logging ({__file__})', 'Args related to function decorators'
36 )
37 cfg.add_argument(
38     '--unittests_ignore_perf',
39     action='store_true',
40     default=False,
41     help='Ignore unittest perf regression in @check_method_for_perf_regressions',
42 )
43 cfg.add_argument(
44     '--unittests_num_perf_samples',
45     type=int,
46     default=50,
47     help='The count of perf timing samples we need to see before blocking slow runs on perf grounds',
48 )
49 cfg.add_argument(
50     '--unittests_drop_perf_traces',
51     type=str,
52     nargs=1,
53     default=None,
54     help='The identifier (i.e. file!test_fixture) for which we should drop all perf data',
55 )
56 cfg.add_argument(
57     '--unittests_persistance_strategy',
58     choices=['FILE', 'DATABASE'],
59     default='DATABASE',
60     help='Should we persist perf data in a file or db?',
61 )
62 cfg.add_argument(
63     '--unittests_perfdb_filename',
64     type=str,
65     metavar='FILENAME',
66     default=f'{os.environ["HOME"]}/.python_unittest_performance_db',
67     help='File in which to store perf data (iff --unittests_persistance_strategy is FILE)',
68 )
69 cfg.add_argument(
70     '--unittests_perfdb_spec',
71     type=str,
72     metavar='DBSPEC',
73     default='mariadb+pymysql://python_unittest:<PASSWORD>@db.house:3306/python_unittest_performance',
74     help='Db connection spec for perf data (iff --unittest_persistance_strategy is DATABASE)',
75 )
76
77 # >>> This is the hacky business, FYI. <<<
78 unittest.main = bootstrap.initialize(unittest.main)
79
80
81 class PerfRegressionDataPersister(ABC):
82     def __init__(self):
83         pass
84
85     @abstractmethod
86     def load_performance_data(self) -> Dict[str, List[float]]:
87         pass
88
89     @abstractmethod
90     def save_performance_data(self, method_id: str, data: Dict[str, List[float]]):
91         pass
92
93     @abstractmethod
94     def delete_performance_data(self, method_id: str):
95         pass
96
97
98 class FileBasedPerfRegressionDataPersister(PerfRegressionDataPersister):
99     def __init__(self, filename: str):
100         self.filename = filename
101         self.traces_to_delete = []
102
103     def load_performance_data(self, method_id: str) -> Dict[str, List[float]]:
104         with open(self.filename, 'rb') as f:
105             return pickle.load(f)
106
107     def save_performance_data(self, method_id: str, data: Dict[str, List[float]]):
108         for trace in self.traces_to_delete:
109             if trace in data:
110                 data[trace] = []
111
112         with open(self.filename, 'wb') as f:
113             pickle.dump(data, f, pickle.HIGHEST_PROTOCOL)
114
115     def delete_performance_data(self, method_id: str):
116         self.traces_to_delete.append(method_id)
117
118
119 class DatabasePerfRegressionDataPersister(PerfRegressionDataPersister):
120     def __init__(self, dbspec: str):
121         self.dbspec = dbspec
122         self.engine = sa.create_engine(self.dbspec)
123         self.conn = self.engine.connect()
124
125     def load_performance_data(self, method_id: str) -> Dict[str, List[float]]:
126         results = self.conn.execute(
127             sa.text(
128                 f'SELECT * FROM runtimes_by_function WHERE function = "{method_id}";'
129             )
130         )
131         ret = {method_id: []}
132         for result in results.all():
133             ret[method_id].append(result['runtime'])
134         results.close()
135         return ret
136
137     def save_performance_data(self, method_id: str, data: Dict[str, List[float]]):
138         self.delete_performance_data(method_id)
139         for (method_id, perf_data) in data.items():
140             sql = 'INSERT INTO runtimes_by_function (function, runtime) VALUES '
141             for perf in perf_data:
142                 self.conn.execute(sql + f'("{method_id}", {perf});')
143
144     def delete_performance_data(self, method_id: str):
145         sql = f'DELETE FROM runtimes_by_function WHERE function = "{method_id}"'
146         self.conn.execute(sql)
147
148
149 def check_method_for_perf_regressions(func: Callable) -> Callable:
150     """
151     This is meant to be used on a method in a class that subclasses
152     unittest.TestCase.  When thus decorated it will time the execution
153     of the code in the method, compare it with a database of
154     historical perfmance, and fail the test with a perf-related
155     message if it has become too slow.
156
157     """
158
159     @functools.wraps(func)
160     def wrapper_perf_monitor(*args, **kwargs):
161         if config.config['unittests_persistance_strategy'] == 'FILE':
162             filename = config.config['unittests_perfdb_filename']
163             helper = FileBasedPerfRegressionDataPersister(filename)
164         elif config.config['unittests_persistance_strategy'] == 'DATABASE':
165             dbspec = config.config['unittests_perfdb_spec']
166             dbspec = dbspec.replace(
167                 '<PASSWORD>', scott_secrets.MARIADB_UNITTEST_PERF_PASSWORD
168             )
169             helper = DatabasePerfRegressionDataPersister(dbspec)
170         else:
171             raise Exception('Unknown/unexpected --unittests_persistance_strategy value')
172
173         func_id = function_utils.function_identifier(func)
174         func_name = func.__name__
175         logger.debug(f'Watching {func_name}\'s performance...')
176         logger.debug(f'Canonical function identifier = {func_id}')
177
178         try:
179             perfdb = helper.load_performance_data(func_id)
180         except Exception as e:
181             logger.exception(e)
182             msg = 'Unable to load perfdb; skipping it...'
183             logger.warning(msg)
184             warnings.warn(msg)
185             perfdb = {}
186
187         # cmdline arg to forget perf traces for function
188         drop_id = config.config['unittests_drop_perf_traces']
189         if drop_id is not None:
190             helper.delete_performance_data(drop_id)
191
192         # Run the wrapped test paying attention to latency.
193         start_time = time.perf_counter()
194         value = func(*args, **kwargs)
195         end_time = time.perf_counter()
196         run_time = end_time - start_time
197
198         # See if it was unexpectedly slow.
199         hist = perfdb.get(func_id, [])
200         if len(hist) < config.config['unittests_num_perf_samples']:
201             hist.append(run_time)
202             logger.debug(f'Still establishing a perf baseline for {func_name}')
203         else:
204             stdev = statistics.stdev(hist)
205             logger.debug(f'For {func_name}, performance stdev={stdev}')
206             slowest = hist[-1]
207             logger.debug(f'For {func_name}, slowest perf on record is {slowest:f}s')
208             limit = slowest + stdev * 4
209             logger.debug(f'For {func_name}, max acceptable runtime is {limit:f}s')
210             logger.debug(f'For {func_name}, actual observed runtime was {run_time:f}s')
211             if run_time > limit and not config.config['unittests_ignore_perf']:
212                 msg = f'''{func_id} performance has regressed unacceptably.
213 {slowest:f}s is the slowest runtime on record in {len(hist)} perf samples.
214 It just ran in {run_time:f}s which is 4+ stdevs slower than the slowest.
215 Here is the current, full db perf timing distribution:
216
217 '''
218                 for x in hist:
219                     msg += f'{x:f}\n'
220                 logger.error(msg)
221                 slf = args[0]  # Peek at the wrapped function's self ref.
222                 slf.fail(msg)  # ...to fail the testcase.
223             else:
224                 hist.append(run_time)
225
226         # Don't spam the database with samples; just pick a random
227         # sample from what we have and store that back.
228         n = min(config.config['unittests_num_perf_samples'], len(hist))
229         hist = random.sample(hist, n)
230         hist.sort()
231         perfdb[func_id] = hist
232         helper.save_performance_data(func_id, perfdb)
233         return value
234
235     return wrapper_perf_monitor
236
237
238 def check_all_methods_for_perf_regressions(prefix='test_'):
239     """Decorate unittests with this to pay attention to the perf of the
240     testcode and flag perf regressions.  e.g.
241
242     import unittest_utils as uu
243
244     @uu.check_all_methods_for_perf_regressions()
245     class TestMyClass(unittest.TestCase):
246
247         def test_some_part_of_my_class(self):
248             ...
249
250     """
251
252     def decorate_the_testcase(cls):
253         if issubclass(cls, unittest.TestCase):
254             for name, m in inspect.getmembers(cls, inspect.isfunction):
255                 if name.startswith(prefix):
256                     setattr(cls, name, check_method_for_perf_regressions(m))
257                     logger.debug(f'Wrapping {cls.__name__}:{name}.')
258         return cls
259
260     return decorate_the_testcase
261
262
263 def breakpoint():
264     """Hard code a breakpoint somewhere; drop into pdb."""
265     import pdb
266
267     pdb.set_trace()
268
269
270 class RecordStdout(object):
271     """
272     Record what is emitted to stdout.
273
274     >>> with RecordStdout() as record:
275     ...     print("This is a test!")
276     >>> print({record().readline()})
277     {'This is a test!\\n'}
278     >>> record().close()
279     """
280
281     def __init__(self) -> None:
282         self.destination = tempfile.SpooledTemporaryFile(mode='r+')
283         self.recorder = None
284
285     def __enter__(self) -> Callable[[], tempfile.SpooledTemporaryFile]:
286         self.recorder = contextlib.redirect_stdout(self.destination)
287         self.recorder.__enter__()
288         return lambda: self.destination
289
290     def __exit__(self, *args) -> bool:
291         self.recorder.__exit__(*args)
292         self.destination.seek(0)
293         return None
294
295
296 class RecordStderr(object):
297     """
298     Record what is emitted to stderr.
299
300     >>> import sys
301     >>> with RecordStderr() as record:
302     ...     print("This is a test!", file=sys.stderr)
303     >>> print({record().readline()})
304     {'This is a test!\\n'}
305     >>> record().close()
306     """
307
308     def __init__(self) -> None:
309         self.destination = tempfile.SpooledTemporaryFile(mode='r+')
310         self.recorder = None
311
312     def __enter__(self) -> Callable[[], tempfile.SpooledTemporaryFile]:
313         self.recorder = contextlib.redirect_stderr(self.destination)
314         self.recorder.__enter__()
315         return lambda: self.destination
316
317     def __exit__(self, *args) -> bool:
318         self.recorder.__exit__(*args)
319         self.destination.seek(0)
320         return None
321
322
323 class RecordMultipleStreams(object):
324     """
325     Record the output to more than one stream.
326     """
327
328     def __init__(self, *files) -> None:
329         self.files = [*files]
330         self.destination = tempfile.SpooledTemporaryFile(mode='r+')
331         self.saved_writes = []
332
333     def __enter__(self) -> Callable[[], tempfile.SpooledTemporaryFile]:
334         for f in self.files:
335             self.saved_writes.append(f.write)
336             f.write = self.destination.write
337         return lambda: self.destination
338
339     def __exit__(self, *args) -> bool:
340         for f in self.files:
341             f.write = self.saved_writes.pop()
342         self.destination.seek(0)
343
344
345 if __name__ == '__main__':
346     import doctest
347
348     doctest.testmod()