Make unittest_utils log its perf data to a database, optionally.
[python_utils.git] / unittest_utils.py
1 #!/usr/bin/env python3
2
3 """Helpers for unittests.  Note that when you import this we
4    automatically wrap unittest.main() with a call to
5    bootstrap.initialize so that we getLogger config, commandline args,
6    logging control, etc... this works fine but it's a little hacky so
7    caveat emptor.
8 """
9
10 from abc import ABC, abstractmethod
11 import contextlib
12 import functools
13 import inspect
14 import logging
15 import os
16 import pickle
17 import random
18 import statistics
19 import time
20 import tempfile
21 from typing import Callable, Dict, List
22 import unittest
23 import warnings
24
25 import bootstrap
26 import config
27 import scott_secrets
28
29 import sqlalchemy as sa
30
31
32 logger = logging.getLogger(__name__)
33 cfg = config.add_commandline_args(
34     f'Logging ({__file__})',
35     'Args related to function decorators')
36 cfg.add_argument(
37     '--unittests_ignore_perf',
38     action='store_true',
39     default=False,
40     help='Ignore unittest perf regression in @check_method_for_perf_regressions',
41 )
42 cfg.add_argument(
43     '--unittests_num_perf_samples',
44     type=int,
45     default=50,
46     help='The count of perf timing samples we need to see before blocking slow runs on perf grounds'
47 )
48 cfg.add_argument(
49     '--unittests_drop_perf_traces',
50     type=str,
51     nargs=1,
52     default=None,
53     help='The identifier (i.e. file!test_fixture) for which we should drop all perf data'
54 )
55 cfg.add_argument(
56     '--unittests_persistance_strategy',
57     choices=['FILE', 'DATABASE'],
58     default='DATABASE',
59     help='Should we persist perf data in a file or db?'
60 )
61 cfg.add_argument(
62     '--unittests_perfdb_filename',
63     type=str,
64     metavar='FILENAME',
65     default=f'{os.environ["HOME"]}/.python_unittest_performance_db',
66     help='File in which to store perf data (iff --unittests_persistance_strategy is FILE)'
67 )
68 cfg.add_argument(
69     '--unittests_perfdb_spec',
70     type=str,
71     metavar='DBSPEC',
72     default='mariadb+pymysql://python_unittest:<PASSWORD>@db.house:3306/python_unittest_performance',
73     help='Db connection spec for perf data (iff --unittest_persistance_strategy is DATABASE)'
74 )
75
76 # >>> This is the hacky business, FYI. <<<
77 unittest.main = bootstrap.initialize(unittest.main)
78
79
80 class PerfRegressionDataPersister(ABC):
81     def __init__(self):
82         pass
83
84     @abstractmethod
85     def load_performance_data(self) -> Dict[str, List[float]]:
86         pass
87
88     @abstractmethod
89     def save_performance_data(self, method_id: str, data: Dict[str, List[float]]):
90         pass
91
92     @abstractmethod
93     def delete_performance_data(self, method_id: str):
94         pass
95
96
97 class FileBasedPerfRegressionDataPersister(PerfRegressionDataPersister):
98     def __init__(self, filename: str):
99         self.filename = filename
100         self.traces_to_delete = []
101
102     def load_performance_data(self, method_id: str) -> Dict[str, List[float]]:
103         with open(self.filename, 'rb') as f:
104             return pickle.load(f)
105
106     def save_performance_data(self, method_id: str, data: Dict[str, List[float]]):
107         for trace in self.traces_to_delete:
108             if trace in data:
109                 data[trace] = []
110
111         with open(self.filename, 'wb') as f:
112             pickle.dump(data, f, pickle.HIGHEST_PROTOCOL)
113
114     def delete_performance_data(self, method_id: str):
115         self.traces_to_delete.append(method_id)
116
117
118 class DatabasePerfRegressionDataPersister(PerfRegressionDataPersister):
119     def __init__(self, dbspec: str):
120         self.dbspec = dbspec
121         self.engine = sa.create_engine(self.dbspec)
122         self.conn = self.engine.connect()
123
124     def load_performance_data(self, method_id: str) -> Dict[str, List[float]]:
125         results = self.conn.execute(
126             sa.text(
127                 f'SELECT * FROM runtimes_by_function WHERE function = "{method_id}";'
128             )
129         )
130         ret = {method_id: []}
131         for result in results.all():
132             ret[method_id].append(result['runtime'])
133         results.close()
134         return ret
135
136     def save_performance_data(self, method_id: str, data: Dict[str, List[float]]):
137         self.delete_performance_data(method_id)
138         for (method_id, perf_data) in data.items():
139             sql = 'INSERT INTO runtimes_by_function (function, runtime) VALUES '
140             for perf in perf_data:
141                 self.conn.execute(sql + f'("{method_id}", {perf});')
142
143     def delete_performance_data(self, method_id: str):
144         sql = f'DELETE FROM runtimes_by_function WHERE function = "{method_id}"'
145         self.conn.execute(sql)
146
147
148 def function_identifier(f: Callable) -> str:
149     """
150     Given a callable function, return a string that identifies it.
151     Usually that string is just __module__:__name__ but there's a
152     corner case: when __module__ is __main__ (i.e. the callable is
153     defined in the same module as __main__).  In this case,
154     f.__module__ returns "__main__" instead of the file that it is
155     defined in.  Work around this using pathlib.Path (see below).
156
157     >>> function_identifier(function_identifier)
158     'unittest_utils:function_identifier'
159
160     """
161     if f.__module__ == '__main__':
162         from pathlib import Path
163         import __main__
164         module = __main__.__file__
165         module = Path(module).stem
166         return f'{module}:{f.__name__}'
167     else:
168         return f'{f.__module__}:{f.__name__}'
169
170
171 def check_method_for_perf_regressions(func: Callable) -> Callable:
172     """
173     This is meant to be used on a method in a class that subclasses
174     unittest.TestCase.  When thus decorated it will time the execution
175     of the code in the method, compare it with a database of
176     historical perfmance, and fail the test with a perf-related
177     message if it has become too slow.
178
179     """
180     @functools.wraps(func)
181     def wrapper_perf_monitor(*args, **kwargs):
182         if config.config['unittests_persistance_strategy'] == 'FILE':
183             filename = config.config['unittests_perfdb_filename']
184             helper = FileBasedPerfRegressionDataPersister(filename)
185         elif config.config['unittests_persistance_strategy'] == 'DATABASE':
186             dbspec = config.config['unittests_perfdb_spec']
187             dbspec = dbspec.replace('<PASSWORD>', scott_secrets.MARIADB_UNITTEST_PERF_PASSWORD)
188             helper = DatabasePerfRegressionDataPersister(dbspec)
189         else:
190             raise Exception(
191                 'Unknown/unexpected --unittests_persistance_strategy value'
192             )
193
194         logger.debug(f'Watching {func.__name__}\'s performance...')
195         func_id = function_identifier(func)
196         logger.debug(f'Canonical function identifier = {func_id}')
197
198         try:
199             perfdb = helper.load_performance_data(func_id)
200         except Exception as e:
201             logger.exception(e)
202             msg = 'Unable to load perfdb; skipping it...'
203             logger.warning(msg)
204             warnings.warn(msg)
205             perfdb = {}
206
207         # cmdline arg to forget perf traces for function
208         drop_id = config.config['unittests_drop_perf_traces']
209         if drop_id is not None:
210             helper.delete_performance_data(drop_id)
211
212         # Run the wrapped test paying attention to latency.
213         start_time = time.perf_counter()
214         value = func(*args, **kwargs)
215         end_time = time.perf_counter()
216         run_time = end_time - start_time
217
218         # See if it was unexpectedly slow.
219         hist = perfdb.get(func_id, [])
220         if len(hist) < config.config['unittests_num_perf_samples']:
221             hist.append(run_time)
222             logger.debug(
223                 f'Still establishing a perf baseline for {func.__name__}'
224             )
225         else:
226             stdev = statistics.stdev(hist)
227             logger.debug(f'For {func.__name__}, performance stdev={stdev}')
228             slowest = hist[-1]
229             logger.debug(f'For {func.__name__}, slowest perf on record is {slowest:f}s')
230             limit = slowest + stdev * 4
231             logger.debug(
232                 f'For {func.__name__}, max acceptable runtime is {limit:f}s'
233             )
234             logger.debug(
235                 f'For {func.__name__}, actual observed runtime was {run_time:f}s'
236             )
237             if (
238                 run_time > limit and
239                 not config.config['unittests_ignore_perf']
240             ):
241                 msg = f'''{func_id} performance has regressed unacceptably.
242 {hist[-1]:f}s is the slowest record in {len(hist)} db perf samples.
243 It just ran in {run_time:f}s which is >5 stdevs slower than the slowest sample.
244 Here is the current, full db perf timing distribution:
245
246 '''
247                 for x in hist:
248                     msg += f'{x:f}\n'
249                 logger.error(msg)
250                 slf = args[0]
251                 slf.fail(msg)
252             else:
253                 hist.append(run_time)
254
255         n = min(config.config['unittests_num_perf_samples'], len(hist))
256         hist = random.sample(hist, n)
257         hist.sort()
258         perfdb[func_id] = hist
259         helper.save_performance_data(func_id, perfdb)
260         return value
261     return wrapper_perf_monitor
262
263
264 def check_all_methods_for_perf_regressions(prefix='test_'):
265     def decorate_the_testcase(cls):
266         if issubclass(cls, unittest.TestCase):
267             for name, m in inspect.getmembers(cls, inspect.isfunction):
268                 if name.startswith(prefix):
269                     setattr(cls, name, check_method_for_perf_regressions(m))
270                     logger.debug(f'Wrapping {cls.__name__}:{name}.')
271         return cls
272     return decorate_the_testcase
273
274
275 def breakpoint():
276     """Hard code a breakpoint somewhere; drop into pdb."""
277     import pdb
278     pdb.set_trace()
279
280
281 class RecordStdout(object):
282     """
283     Record what is emitted to stdout.
284
285     >>> with RecordStdout() as record:
286     ...     print("This is a test!")
287     >>> print({record().readline()})
288     {'This is a test!\\n'}
289     """
290
291     def __init__(self) -> None:
292         self.destination = tempfile.SpooledTemporaryFile(mode='r+')
293         self.recorder = None
294
295     def __enter__(self) -> Callable[[], tempfile.SpooledTemporaryFile]:
296         self.recorder = contextlib.redirect_stdout(self.destination)
297         self.recorder.__enter__()
298         return lambda: self.destination
299
300     def __exit__(self, *args) -> bool:
301         self.recorder.__exit__(*args)
302         self.destination.seek(0)
303         return None
304
305
306 class RecordStderr(object):
307     """
308     Record what is emitted to stderr.
309
310     >>> import sys
311     >>> with RecordStderr() as record:
312     ...     print("This is a test!", file=sys.stderr)
313     >>> print({record().readline()})
314     {'This is a test!\\n'}
315     """
316
317     def __init__(self) -> None:
318         self.destination = tempfile.SpooledTemporaryFile(mode='r+')
319         self.recorder = None
320
321     def __enter__(self) -> Callable[[], tempfile.SpooledTemporaryFile]:
322         self.recorder = contextlib.redirect_stderr(self.destination)
323         self.recorder.__enter__()
324         return lambda: self.destination
325
326     def __exit__(self, *args) -> bool:
327         self.recorder.__exit__(*args)
328         self.destination.seek(0)
329         return None
330
331
332 class RecordMultipleStreams(object):
333     """
334     Record the output to more than one stream.
335     """
336
337     def __init__(self, *files) -> None:
338         self.files = [*files]
339         self.destination = tempfile.SpooledTemporaryFile(mode='r+')
340         self.saved_writes = []
341
342     def __enter__(self) -> Callable[[], tempfile.SpooledTemporaryFile]:
343         for f in self.files:
344             self.saved_writes.append(f.write)
345             f.write = self.destination.write
346         return lambda: self.destination
347
348     def __exit__(self, *args) -> bool:
349         for f in self.files:
350             f.write = self.saved_writes.pop()
351         self.destination.seek(0)
352
353
354 if __name__ == '__main__':
355     import doctest
356     doctest.testmod()