#!/usr/bin/env python3

"""Helpers for unittests.

Note that when you import this module we automatically wrap
unittest.main() with a call to bootstrap.initialize so that we get
logger config, commandline args, logging control, etc.  This works
fine but it's a little hacky, so caveat emptor.
"""

import contextlib
import functools
import inspect
import logging
import pickle
import random
import statistics
import tempfile
import time
from typing import Callable
import unittest

import bootstrap
import config

logger = logging.getLogger(__name__)

cfg = config.add_commandline_args(
    f'Unit tests ({__file__})',
    'Args related to function decorators')
cfg.add_argument(
    '--unittests_ignore_perf',
    action='store_true',
    default=False,
    help='Ignore unittest perf regression in @check_method_for_perf_regressions',
)
cfg.add_argument(
    '--unittests_num_perf_samples',
    type=int,
    default=20,
    help='The count of perf timing samples we need to see before blocking slow runs on perf grounds',
)
cfg.add_argument(
    '--unittests_drop_perf_traces',
    type=str,
    nargs=1,
    default=None,
    help='The identifier (e.g. file!test_fixture) for which we should drop all perf data',
)

# >>> This is the hacky business, FYI. <<<
unittest.main = bootstrap.initialize(unittest.main)

_db = '/home/scott/.python_unittest_performance_db'


def check_method_for_perf_regressions(func: Callable) -> Callable:
    """This is meant to be used on a method in a class that subclasses
    unittest.TestCase.  When thus decorated it will time the execution
    of the code in the method, compare it with a database of historical
    performance, and fail the test with a perf-related message if it
    has become too slow.
    """

    def load_known_test_performance_characteristics():
        with open(_db, 'rb') as f:
            return pickle.load(f)

    def save_known_test_performance_characteristics(perfdb):
        with open(_db, 'wb') as f:
            pickle.dump(perfdb, f, pickle.HIGHEST_PROTOCOL)

    @functools.wraps(func)
    def wrapper_perf_monitor(*args, **kwargs):
        try:
            perfdb = load_known_test_performance_characteristics()
        except Exception as e:
            logger.exception(e)
            logger.warning(f'Unable to load perfdb from {_db}')
            perfdb = {}

        # This is a unique identifier for a test: filepath!function
        logger.debug(f'Watching {func.__name__}\'s performance...')
        func_id = f'{func.__globals__["__file__"]}!{func.__name__}'
        logger.debug(f'Canonical function identifier = {func_id}')

        # cmdline arg to forget perf traces for function
        drop_id = config.config['unittests_drop_perf_traces']
        if drop_id is not None:
            if drop_id in perfdb:
                perfdb[drop_id] = []

        # Run the wrapped test paying attention to latency.
        start_time = time.perf_counter()
        value = func(*args, **kwargs)
        end_time = time.perf_counter()
        run_time = end_time - start_time
        logger.debug(f'{func.__name__} executed in {run_time:f}s.')

        # Check the db; see if it was unexpectedly slow.
        hist = perfdb.get(func_id, [])
        if len(hist) < config.config['unittests_num_perf_samples']:
            hist.append(run_time)
            logger.debug(
                f'Still establishing a perf baseline for {func.__name__}'
            )
        else:
            stdev = statistics.stdev(hist)
            limit = hist[-1] + stdev * 3
            logger.debug(
                f'Max acceptable performance for {func.__name__} is {limit:f}s'
            )
            if (
                run_time > limit
                and not config.config['unittests_ignore_perf']
            ):
                msg = f'''{func_id} performance has regressed unacceptably.
{hist[-1]:f}s is the slowest record in {len(hist)} db perf samples.
It just ran in {run_time:f}s which is >3 stdevs slower than the slowest sample.
Here is the current, full db perf timing distribution:
{hist}'''
                slf = args[0]
                logger.error(msg)
                slf.fail(msg)
            else:
                hist.append(run_time)

        n = min(config.config['unittests_num_perf_samples'], len(hist))
        hist = random.sample(hist, n)
        hist.sort()
        perfdb[func_id] = hist
        save_known_test_performance_characteristics(perfdb)
        return value
    return wrapper_perf_monitor


def check_all_methods_for_perf_regressions(prefix='test_'):
    """Class decorator that wraps every method whose name starts with
    prefix (in a unittest.TestCase subclass) with
    check_method_for_perf_regressions.
    """
    def decorate_the_testcase(cls):
        if issubclass(cls, unittest.TestCase):
            for name, m in inspect.getmembers(cls, inspect.isfunction):
                if name.startswith(prefix):
                    setattr(cls, name, check_method_for_perf_regressions(m))
                    logger.debug(f'Wrapping {cls.__name__}:{name}.')
        return cls
    return decorate_the_testcase


def breakpoint():
    """Hard code a breakpoint somewhere; drop into pdb."""
    import pdb
    pdb.set_trace()


class RecordStdout(object):
    """
    Record what is emitted to stdout.

    >>> with RecordStdout() as record:
    ...     print("This is a test!")
    >>> print({record().readline()})
    {'This is a test!\\n'}
    """

    def __init__(self) -> None:
        self.destination = tempfile.SpooledTemporaryFile(mode='r+')
        self.recorder = None

    def __enter__(self) -> Callable[[], tempfile.SpooledTemporaryFile]:
        self.recorder = contextlib.redirect_stdout(self.destination)
        self.recorder.__enter__()
        return lambda: self.destination

    def __exit__(self, *args) -> bool:
        self.recorder.__exit__(*args)
        self.destination.seek(0)
        return None


class RecordStderr(object):
    """
    Record what is emitted to stderr.

    >>> import sys
    >>> with RecordStderr() as record:
    ...     print("This is a test!", file=sys.stderr)
    >>> print({record().readline()})
    {'This is a test!\\n'}
    """

    def __init__(self) -> None:
        self.destination = tempfile.SpooledTemporaryFile(mode='r+')
        self.recorder = None

    def __enter__(self) -> Callable[[], tempfile.SpooledTemporaryFile]:
        self.recorder = contextlib.redirect_stderr(self.destination)
        self.recorder.__enter__()
        return lambda: self.destination

    def __exit__(self, *args) -> bool:
        self.recorder.__exit__(*args)
        self.destination.seek(0)
        return None


class RecordMultipleStreams(object):
    """
    Record the output to more than one stream.
    """

    def __init__(self, *files) -> None:
        self.files = [*files]
        self.destination = tempfile.SpooledTemporaryFile(mode='r+')
        self.saved_writes = []

    def __enter__(self) -> Callable[[], tempfile.SpooledTemporaryFile]:
        for f in self.files:
            self.saved_writes.append(f.write)
            f.write = self.destination.write
        return lambda: self.destination

    def __exit__(self, *args) -> bool:
        for f in self.files:
            # Pop from the front so each stream gets back its own
            # original write method, in the same order as self.files.
            f.write = self.saved_writes.pop(0)
        self.destination.seek(0)


if __name__ == '__main__':
    import doctest
    doctest.testmod()
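

# A minimal usage sketch (a hypothetical test file, not part of this
# module), assuming this helper module is importable as `unittest_utils`;
# the class and test names below are invented for illustration only:
#
#     import unittest
#
#     import unittest_utils
#
#     @unittest_utils.check_all_methods_for_perf_regressions()
#     class MyTest(unittest.TestCase):
#         def test_something_fast(self):
#             self.assertEqual(4, 2 + 2)
#
#     if __name__ == '__main__':
#         unittest.main()  # already wrapped by bootstrap.initialize above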