Creates a function_utils and pull a function_identifer method into
[python_utils.git] / unittest_utils.py
1 #!/usr/bin/env python3
2
3 """Helpers for unittests.  Note that when you import this we
4    automatically wrap unittest.main() with a call to
5    bootstrap.initialize so that we getLogger config, commandline args,
6    logging control, etc... this works fine but it's a little hacky so
7    caveat emptor.
8 """
9
10 from abc import ABC, abstractmethod
11 import contextlib
12 import functools
13 import inspect
14 import logging
15 import os
16 import pickle
17 import random
18 import statistics
19 import time
20 import tempfile
21 from typing import Callable, Dict, List
22 import unittest
23 import warnings
24
25 import bootstrap
26 import config
27 import function_utils
28 import scott_secrets
29
30 import sqlalchemy as sa
31
32
33 logger = logging.getLogger(__name__)
34 cfg = config.add_commandline_args(
35     f'Logging ({__file__})',
36     'Args related to function decorators')
37 cfg.add_argument(
38     '--unittests_ignore_perf',
39     action='store_true',
40     default=False,
41     help='Ignore unittest perf regression in @check_method_for_perf_regressions',
42 )
43 cfg.add_argument(
44     '--unittests_num_perf_samples',
45     type=int,
46     default=50,
47     help='The count of perf timing samples we need to see before blocking slow runs on perf grounds'
48 )
49 cfg.add_argument(
50     '--unittests_drop_perf_traces',
51     type=str,
52     nargs=1,
53     default=None,
54     help='The identifier (i.e. file!test_fixture) for which we should drop all perf data'
55 )
56 cfg.add_argument(
57     '--unittests_persistance_strategy',
58     choices=['FILE', 'DATABASE'],
59     default='DATABASE',
60     help='Should we persist perf data in a file or db?'
61 )
62 cfg.add_argument(
63     '--unittests_perfdb_filename',
64     type=str,
65     metavar='FILENAME',
66     default=f'{os.environ["HOME"]}/.python_unittest_performance_db',
67     help='File in which to store perf data (iff --unittests_persistance_strategy is FILE)'
68 )
69 cfg.add_argument(
70     '--unittests_perfdb_spec',
71     type=str,
72     metavar='DBSPEC',
73     default='mariadb+pymysql://python_unittest:<PASSWORD>@db.house:3306/python_unittest_performance',
74     help='Db connection spec for perf data (iff --unittest_persistance_strategy is DATABASE)'
75 )
76
77 # >>> This is the hacky business, FYI. <<<
78 unittest.main = bootstrap.initialize(unittest.main)
79
80
81 class PerfRegressionDataPersister(ABC):
82     def __init__(self):
83         pass
84
85     @abstractmethod
86     def load_performance_data(self) -> Dict[str, List[float]]:
87         pass
88
89     @abstractmethod
90     def save_performance_data(self, method_id: str, data: Dict[str, List[float]]):
91         pass
92
93     @abstractmethod
94     def delete_performance_data(self, method_id: str):
95         pass
96
97
98 class FileBasedPerfRegressionDataPersister(PerfRegressionDataPersister):
99     def __init__(self, filename: str):
100         self.filename = filename
101         self.traces_to_delete = []
102
103     def load_performance_data(self, method_id: str) -> Dict[str, List[float]]:
104         with open(self.filename, 'rb') as f:
105             return pickle.load(f)
106
107     def save_performance_data(self, method_id: str, data: Dict[str, List[float]]):
108         for trace in self.traces_to_delete:
109             if trace in data:
110                 data[trace] = []
111
112         with open(self.filename, 'wb') as f:
113             pickle.dump(data, f, pickle.HIGHEST_PROTOCOL)
114
115     def delete_performance_data(self, method_id: str):
116         self.traces_to_delete.append(method_id)
117
118
119 class DatabasePerfRegressionDataPersister(PerfRegressionDataPersister):
120     def __init__(self, dbspec: str):
121         self.dbspec = dbspec
122         self.engine = sa.create_engine(self.dbspec)
123         self.conn = self.engine.connect()
124
125     def load_performance_data(self, method_id: str) -> Dict[str, List[float]]:
126         results = self.conn.execute(
127             sa.text(
128                 f'SELECT * FROM runtimes_by_function WHERE function = "{method_id}";'
129             )
130         )
131         ret = {method_id: []}
132         for result in results.all():
133             ret[method_id].append(result['runtime'])
134         results.close()
135         return ret
136
137     def save_performance_data(self, method_id: str, data: Dict[str, List[float]]):
138         self.delete_performance_data(method_id)
139         for (method_id, perf_data) in data.items():
140             sql = 'INSERT INTO runtimes_by_function (function, runtime) VALUES '
141             for perf in perf_data:
142                 self.conn.execute(sql + f'("{method_id}", {perf});')
143
144     def delete_performance_data(self, method_id: str):
145         sql = f'DELETE FROM runtimes_by_function WHERE function = "{method_id}"'
146         self.conn.execute(sql)
147
148
149 def check_method_for_perf_regressions(func: Callable) -> Callable:
150     """
151     This is meant to be used on a method in a class that subclasses
152     unittest.TestCase.  When thus decorated it will time the execution
153     of the code in the method, compare it with a database of
154     historical perfmance, and fail the test with a perf-related
155     message if it has become too slow.
156
157     """
158     @functools.wraps(func)
159     def wrapper_perf_monitor(*args, **kwargs):
160         if config.config['unittests_persistance_strategy'] == 'FILE':
161             filename = config.config['unittests_perfdb_filename']
162             helper = FileBasedPerfRegressionDataPersister(filename)
163         elif config.config['unittests_persistance_strategy'] == 'DATABASE':
164             dbspec = config.config['unittests_perfdb_spec']
165             dbspec = dbspec.replace('<PASSWORD>', scott_secrets.MARIADB_UNITTEST_PERF_PASSWORD)
166             helper = DatabasePerfRegressionDataPersister(dbspec)
167         else:
168             raise Exception(
169                 'Unknown/unexpected --unittests_persistance_strategy value'
170             )
171
172         func_id = function_utils.function_identifier(func)
173         func_name = func.__name__
174         logger.debug(f'Watching {func_name}\'s performance...')
175         logger.debug(f'Canonical function identifier = {func_id}')
176
177         try:
178             perfdb = helper.load_performance_data(func_id)
179         except Exception as e:
180             logger.exception(e)
181             msg = 'Unable to load perfdb; skipping it...'
182             logger.warning(msg)
183             warnings.warn(msg)
184             perfdb = {}
185
186         # cmdline arg to forget perf traces for function
187         drop_id = config.config['unittests_drop_perf_traces']
188         if drop_id is not None:
189             helper.delete_performance_data(drop_id)
190
191         # Run the wrapped test paying attention to latency.
192         start_time = time.perf_counter()
193         value = func(*args, **kwargs)
194         end_time = time.perf_counter()
195         run_time = end_time - start_time
196
197         # See if it was unexpectedly slow.
198         hist = perfdb.get(func_id, [])
199         if len(hist) < config.config['unittests_num_perf_samples']:
200             hist.append(run_time)
201             logger.debug(
202                 f'Still establishing a perf baseline for {func_name}'
203             )
204         else:
205             stdev = statistics.stdev(hist)
206             logger.debug(f'For {func_name}, performance stdev={stdev}')
207             slowest = hist[-1]
208             logger.debug(f'For {func_name}, slowest perf on record is {slowest:f}s')
209             limit = slowest + stdev * 4
210             logger.debug(
211                 f'For {func_name}, max acceptable runtime is {limit:f}s'
212             )
213             logger.debug(
214                 f'For {func_name}, actual observed runtime was {run_time:f}s'
215             )
216             if (
217                 run_time > limit and
218                 not config.config['unittests_ignore_perf']
219             ):
220                 msg = f'''{func_id} performance has regressed unacceptably.
221 {hist[-1]:f}s is the slowest record in {len(hist)} db perf samples.
222 It just ran in {run_time:f}s which is >5 stdevs slower than the slowest sample.
223 Here is the current, full db perf timing distribution:
224
225 '''
226                 for x in hist:
227                     msg += f'{x:f}\n'
228                 logger.error(msg)
229                 slf = args[0]
230                 slf.fail(msg)
231             else:
232                 hist.append(run_time)
233
234         n = min(config.config['unittests_num_perf_samples'], len(hist))
235         hist = random.sample(hist, n)
236         hist.sort()
237         perfdb[func_id] = hist
238         helper.save_performance_data(func_id, perfdb)
239         return value
240     return wrapper_perf_monitor
241
242
243 def check_all_methods_for_perf_regressions(prefix='test_'):
244     def decorate_the_testcase(cls):
245         if issubclass(cls, unittest.TestCase):
246             for name, m in inspect.getmembers(cls, inspect.isfunction):
247                 if name.startswith(prefix):
248                     setattr(cls, name, check_method_for_perf_regressions(m))
249                     logger.debug(f'Wrapping {cls.__name__}:{name}.')
250         return cls
251     return decorate_the_testcase
252
253
254 def breakpoint():
255     """Hard code a breakpoint somewhere; drop into pdb."""
256     import pdb
257     pdb.set_trace()
258
259
260 class RecordStdout(object):
261     """
262     Record what is emitted to stdout.
263
264     >>> with RecordStdout() as record:
265     ...     print("This is a test!")
266     >>> print({record().readline()})
267     {'This is a test!\\n'}
268     """
269
270     def __init__(self) -> None:
271         self.destination = tempfile.SpooledTemporaryFile(mode='r+')
272         self.recorder = None
273
274     def __enter__(self) -> Callable[[], tempfile.SpooledTemporaryFile]:
275         self.recorder = contextlib.redirect_stdout(self.destination)
276         self.recorder.__enter__()
277         return lambda: self.destination
278
279     def __exit__(self, *args) -> bool:
280         self.recorder.__exit__(*args)
281         self.destination.seek(0)
282         return None
283
284
285 class RecordStderr(object):
286     """
287     Record what is emitted to stderr.
288
289     >>> import sys
290     >>> with RecordStderr() as record:
291     ...     print("This is a test!", file=sys.stderr)
292     >>> print({record().readline()})
293     {'This is a test!\\n'}
294     """
295
296     def __init__(self) -> None:
297         self.destination = tempfile.SpooledTemporaryFile(mode='r+')
298         self.recorder = None
299
300     def __enter__(self) -> Callable[[], tempfile.SpooledTemporaryFile]:
301         self.recorder = contextlib.redirect_stderr(self.destination)
302         self.recorder.__enter__()
303         return lambda: self.destination
304
305     def __exit__(self, *args) -> bool:
306         self.recorder.__exit__(*args)
307         self.destination.seek(0)
308         return None
309
310
311 class RecordMultipleStreams(object):
312     """
313     Record the output to more than one stream.
314     """
315
316     def __init__(self, *files) -> None:
317         self.files = [*files]
318         self.destination = tempfile.SpooledTemporaryFile(mode='r+')
319         self.saved_writes = []
320
321     def __enter__(self) -> Callable[[], tempfile.SpooledTemporaryFile]:
322         for f in self.files:
323             self.saved_writes.append(f.write)
324             f.write = self.destination.write
325         return lambda: self.destination
326
327     def __exit__(self, *args) -> bool:
328         for f in self.files:
329             f.write = self.saved_writes.pop()
330         self.destination.seek(0)
331
332
333 if __name__ == '__main__':
334     import doctest
335     doctest.testmod()