X-Git-Url: https://wannabe.guru.org/gitweb/?a=blobdiff_plain;ds=sidebyside;f=unittest_utils.py;h=28b577e2086af4ff20647d05cd9be24761839d6d;hb=532df2c5b57c7517dfb3dddd8c1358fbadf8baf3;hp=4a9669d3a21f66e35004e1968cc85b65d711fd5c;hpb=36fea7f15ed17150691b5b3ead75450e575229ef;p=python_utils.git diff --git a/unittest_utils.py b/unittest_utils.py index 4a9669d..28b577e 100644 --- a/unittest_utils.py +++ b/unittest_utils.py @@ -1,13 +1,14 @@ #!/usr/bin/env python3 +# © Copyright 2021-2022, Scott Gasch + """Helpers for unittests. Note that when you import this we - automatically wrap unittest.main() with a call to - bootstrap.initialize so that we getLogger config, commandline args, - logging control, etc... this works fine but it's a little hacky so - caveat emptor. +automatically wrap unittest.main() with a call to bootstrap.initialize +so that we getLogger config, commandline args, logging control, +etc... this works fine but it's a little hacky so caveat emptor. + """ -from abc import ABC, abstractmethod import contextlib import functools import inspect @@ -16,24 +17,22 @@ import os import pickle import random import statistics -import time import tempfile -from typing import Callable, Dict, List +import time import unittest import warnings +from abc import ABC, abstractmethod +from typing import Any, Callable, Dict, List, Literal, Optional + +import sqlalchemy as sa import bootstrap import config import function_utils import scott_secrets -import sqlalchemy as sa - - logger = logging.getLogger(__name__) -cfg = config.add_commandline_args( - f'Logging ({__file__})', 'Args related to function decorators' -) +cfg = config.add_commandline_args(f'Logging ({__file__})', 'Args related to function decorators') cfg.add_argument( '--unittests_ignore_perf', action='store_true', @@ -79,17 +78,18 @@ unittest.main = bootstrap.initialize(unittest.main) class PerfRegressionDataPersister(ABC): + """A base class for a signature dealing with persisting perf + regression data.""" + def __init__(self): pass @abstractmethod - def load_performance_data(self) -> Dict[str, List[float]]: + def load_performance_data(self, method_id: str) -> Dict[str, List[float]]: pass @abstractmethod - def save_performance_data( - self, method_id: str, data: Dict[str, List[float]] - ): + def save_performance_data(self, method_id: str, data: Dict[str, List[float]]): pass @abstractmethod @@ -98,17 +98,18 @@ class PerfRegressionDataPersister(ABC): class FileBasedPerfRegressionDataPersister(PerfRegressionDataPersister): + """A perf regression data persister that uses files.""" + def __init__(self, filename: str): + super().__init__() self.filename = filename - self.traces_to_delete = [] + self.traces_to_delete: List[str] = [] def load_performance_data(self, method_id: str) -> Dict[str, List[float]]: with open(self.filename, 'rb') as f: return pickle.load(f) - def save_performance_data( - self, method_id: str, data: Dict[str, List[float]] - ): + def save_performance_data(self, method_id: str, data: Dict[str, List[float]]): for trace in self.traces_to_delete: if trace in data: data[trace] = [] @@ -121,31 +122,30 @@ class FileBasedPerfRegressionDataPersister(PerfRegressionDataPersister): class DatabasePerfRegressionDataPersister(PerfRegressionDataPersister): + """A perf regression data persister that uses a database backend.""" + def __init__(self, dbspec: str): + super().__init__() self.dbspec = dbspec self.engine = sa.create_engine(self.dbspec) self.conn = self.engine.connect() def load_performance_data(self, method_id: str) -> Dict[str, List[float]]: results = self.conn.execute( - sa.text( - f'SELECT * FROM runtimes_by_function WHERE function = "{method_id}";' - ) + sa.text(f'SELECT * FROM runtimes_by_function WHERE function = "{method_id}";') ) - ret = {method_id: []} + ret: Dict[str, List[float]] = {method_id: []} for result in results.all(): ret[method_id].append(result['runtime']) results.close() return ret - def save_performance_data( - self, method_id: str, data: Dict[str, List[float]] - ): + def save_performance_data(self, method_id: str, data: Dict[str, List[float]]): self.delete_performance_data(method_id) - for (method_id, perf_data) in data.items(): + for (mid, perf_data) in data.items(): sql = 'INSERT INTO runtimes_by_function (function, runtime) VALUES ' for perf in perf_data: - self.conn.execute(sql + f'("{method_id}", {perf});') + self.conn.execute(sql + f'("{mid}", {perf});') def delete_performance_data(self, method_id: str): sql = f'DELETE FROM runtimes_by_function WHERE function = "{method_id}"' @@ -164,24 +164,23 @@ def check_method_for_perf_regressions(func: Callable) -> Callable: @functools.wraps(func) def wrapper_perf_monitor(*args, **kwargs): + if config.config['unittests_ignore_perf']: + return func(*args, **kwargs) + if config.config['unittests_persistance_strategy'] == 'FILE': filename = config.config['unittests_perfdb_filename'] helper = FileBasedPerfRegressionDataPersister(filename) elif config.config['unittests_persistance_strategy'] == 'DATABASE': dbspec = config.config['unittests_perfdb_spec'] - dbspec = dbspec.replace( - '', scott_secrets.MARIADB_UNITTEST_PERF_PASSWORD - ) + dbspec = dbspec.replace('', scott_secrets.MARIADB_UNITTEST_PERF_PASSWORD) helper = DatabasePerfRegressionDataPersister(dbspec) else: - raise Exception( - 'Unknown/unexpected --unittests_persistance_strategy value' - ) + raise Exception('Unknown/unexpected --unittests_persistance_strategy value') func_id = function_utils.function_identifier(func) func_name = func.__name__ - logger.debug(f'Watching {func_name}\'s performance...') - logger.debug(f'Canonical function identifier = {func_id}') + logger.debug('Watching %s\'s performance...', func_name) + logger.debug('Canonical function identifier = "%s"', func_id) try: perfdb = helper.load_performance_data(func_id) @@ -207,22 +206,16 @@ def check_method_for_perf_regressions(func: Callable) -> Callable: hist = perfdb.get(func_id, []) if len(hist) < config.config['unittests_num_perf_samples']: hist.append(run_time) - logger.debug(f'Still establishing a perf baseline for {func_name}') + logger.debug('Still establishing a perf baseline for %s', func_name) else: stdev = statistics.stdev(hist) - logger.debug(f'For {func_name}, performance stdev={stdev}') + logger.debug('For %s, performance stdev=%.2f', func_name, stdev) slowest = hist[-1] - logger.debug( - f'For {func_name}, slowest perf on record is {slowest:f}s' - ) + logger.debug('For %s, slowest perf on record is %.2fs', func_name, slowest) limit = slowest + stdev * 4 - logger.debug( - f'For {func_name}, max acceptable runtime is {limit:f}s' - ) - logger.debug( - f'For {func_name}, actual observed runtime was {run_time:f}s' - ) - if run_time > limit and not config.config['unittests_ignore_perf']: + logger.debug('For %s, max acceptable runtime is %.2fs', func_name, limit) + logger.debug('For %s, actual observed runtime was %.2fs', func_name, run_time) + if run_time > limit: msg = f'''{func_id} performance has regressed unacceptably. {slowest:f}s is the slowest runtime on record in {len(hist)} perf samples. It just ran in {run_time:f}s which is 4+ stdevs slower than the slowest. @@ -268,20 +261,13 @@ def check_all_methods_for_perf_regressions(prefix='test_'): for name, m in inspect.getmembers(cls, inspect.isfunction): if name.startswith(prefix): setattr(cls, name, check_method_for_perf_regressions(m)) - logger.debug(f'Wrapping {cls.__name__}:{name}.') + logger.debug('Wrapping %s:%s.', cls.__name__, name) return cls return decorate_the_testcase -def breakpoint(): - """Hard code a breakpoint somewhere; drop into pdb.""" - import pdb - - pdb.set_trace() - - -class RecordStdout(object): +class RecordStdout(contextlib.AbstractContextManager): """ Record what is emitted to stdout. @@ -289,24 +275,28 @@ class RecordStdout(object): ... print("This is a test!") >>> print({record().readline()}) {'This is a test!\\n'} + >>> record().close() """ def __init__(self) -> None: + super().__init__() self.destination = tempfile.SpooledTemporaryFile(mode='r+') - self.recorder = None + self.recorder: Optional[contextlib.redirect_stdout] = None def __enter__(self) -> Callable[[], tempfile.SpooledTemporaryFile]: self.recorder = contextlib.redirect_stdout(self.destination) + assert self.recorder is not None self.recorder.__enter__() return lambda: self.destination - def __exit__(self, *args) -> bool: + def __exit__(self, *args) -> Literal[False]: + assert self.recorder is not None self.recorder.__exit__(*args) self.destination.seek(0) - return None + return False -class RecordStderr(object): +class RecordStderr(contextlib.AbstractContextManager): """ Record what is emitted to stderr. @@ -315,32 +305,37 @@ class RecordStderr(object): ... print("This is a test!", file=sys.stderr) >>> print({record().readline()}) {'This is a test!\\n'} + >>> record().close() """ def __init__(self) -> None: + super().__init__() self.destination = tempfile.SpooledTemporaryFile(mode='r+') - self.recorder = None + self.recorder: Optional[contextlib.redirect_stdout[Any]] = None def __enter__(self) -> Callable[[], tempfile.SpooledTemporaryFile]: - self.recorder = contextlib.redirect_stderr(self.destination) + self.recorder = contextlib.redirect_stderr(self.destination) # type: ignore + assert self.recorder is not None self.recorder.__enter__() return lambda: self.destination - def __exit__(self, *args) -> bool: + def __exit__(self, *args) -> Literal[False]: + assert self.recorder is not None self.recorder.__exit__(*args) self.destination.seek(0) - return None + return False -class RecordMultipleStreams(object): +class RecordMultipleStreams(contextlib.AbstractContextManager): """ Record the output to more than one stream. """ def __init__(self, *files) -> None: + super().__init__() self.files = [*files] self.destination = tempfile.SpooledTemporaryFile(mode='r+') - self.saved_writes = [] + self.saved_writes: List[Callable[..., Any]] = [] def __enter__(self) -> Callable[[], tempfile.SpooledTemporaryFile]: for f in self.files: @@ -348,10 +343,11 @@ class RecordMultipleStreams(object): f.write = self.destination.write return lambda: self.destination - def __exit__(self, *args) -> bool: + def __exit__(self, *args) -> Literal[False]: for f in self.files: f.write = self.saved_writes.pop() self.destination.seek(0) + return False if __name__ == '__main__':