X-Git-Url: https://wannabe.guru.org/gitweb/?a=blobdiff_plain;f=unittest_utils.py;h=f229df75e8b88825d66ca227d7e907d3dc725e1a;hb=dfc2136113428b99719c49a57d3ce68391dcb307;hp=f4fed35f09fdf29970820bef8566652825327634;hpb=e6f32fdd9b373dfcd100c7accb41f57d83c2f0a1;p=python_utils.git

diff --git a/unittest_utils.py b/unittest_utils.py
index f4fed35..f229df7 100644
--- a/unittest_utils.py
+++ b/unittest_utils.py
@@ -7,7 +7,6 @@ caveat emptor.
 
 """
 
-from abc import ABC, abstractmethod
 import contextlib
 import functools
 import inspect
@@ -16,24 +15,22 @@ import os
 import pickle
 import random
 import statistics
-import time
 import tempfile
-from typing import Callable, Dict, List
+import time
 import unittest
 import warnings
+from abc import ABC, abstractmethod
+from typing import Any, Callable, Dict, List, Optional
+
+import sqlalchemy as sa
 
 import bootstrap
 import config
 import function_utils
 import scott_secrets
-import sqlalchemy as sa
-
-
 logger = logging.getLogger(__name__)
 
-cfg = config.add_commandline_args(
-    f'Logging ({__file__})', 'Args related to function decorators'
-)
+cfg = config.add_commandline_args(f'Logging ({__file__})', 'Args related to function decorators')
 cfg.add_argument(
     '--unittests_ignore_perf',
     action='store_true',
@@ -83,7 +80,7 @@ class PerfRegressionDataPersister(ABC):
         pass
 
     @abstractmethod
-    def load_performance_data(self) -> Dict[str, List[float]]:
+    def load_performance_data(self, method_id: str) -> Dict[str, List[float]]:
         pass
 
     @abstractmethod
@@ -98,7 +95,7 @@ class PerfRegressionDataPersister(ABC):
 class FileBasedPerfRegressionDataPersister(PerfRegressionDataPersister):
     def __init__(self, filename: str):
         self.filename = filename
-        self.traces_to_delete = []
+        self.traces_to_delete: List[str] = []
 
     def load_performance_data(self, method_id: str) -> Dict[str, List[float]]:
         with open(self.filename, 'rb') as f:
@@ -124,11 +121,9 @@ class DatabasePerfRegressionDataPersister(PerfRegressionDataPersister):
 
     def load_performance_data(self, method_id: str) -> Dict[str, List[float]]:
         results = self.conn.execute(
-            sa.text(
-                f'SELECT * FROM runtimes_by_function WHERE function = "{method_id}";'
-            )
+            sa.text(f'SELECT * FROM runtimes_by_function WHERE function = "{method_id}";')
         )
-        ret = {method_id: []}
+        ret: Dict[str, List[float]] = {method_id: []}
         for result in results.all():
             ret[method_id].append(result['runtime'])
         results.close()
@@ -158,14 +153,15 @@ def check_method_for_perf_regressions(func: Callable) -> Callable:
 
     @functools.wraps(func)
     def wrapper_perf_monitor(*args, **kwargs):
+        if config.config['unittests_ignore_perf']:
+            return func(*args, **kwargs)
+
         if config.config['unittests_persistance_strategy'] == 'FILE':
             filename = config.config['unittests_perfdb_filename']
             helper = FileBasedPerfRegressionDataPersister(filename)
         elif config.config['unittests_persistance_strategy'] == 'DATABASE':
             dbspec = config.config['unittests_perfdb_spec']
-            dbspec = dbspec.replace(
-                '', scott_secrets.MARIADB_UNITTEST_PERF_PASSWORD
-            )
+            dbspec = dbspec.replace('', scott_secrets.MARIADB_UNITTEST_PERF_PASSWORD)
             helper = DatabasePerfRegressionDataPersister(dbspec)
         else:
             raise Exception('Unknown/unexpected --unittests_persistance_strategy value')
@@ -208,7 +204,7 @@ def check_method_for_perf_regressions(func: Callable) -> Callable:
             limit = slowest + stdev * 4
             logger.debug(f'For {func_name}, max acceptable runtime is {limit:f}s')
             logger.debug(f'For {func_name}, actual observed runtime was {run_time:f}s')
-            if run_time > limit and not config.config['unittests_ignore_perf']:
+            if run_time > limit:
                 msg = f'''{func_id} performance has regressed unacceptably.
 {slowest:f}s is the slowest runtime on record in {len(hist)} perf samples.
 It just ran in {run_time:f}s which is 4+ stdevs slower than the slowest.
@@ -275,18 +271,21 @@ class RecordStdout(object):
     ...     print("This is a test!")
     >>> print({record().readline()})
     {'This is a test!\\n'}
+    >>> record().close()
     """
 
     def __init__(self) -> None:
         self.destination = tempfile.SpooledTemporaryFile(mode='r+')
-        self.recorder = None
+        self.recorder: Optional[contextlib.redirect_stdout] = None
 
     def __enter__(self) -> Callable[[], tempfile.SpooledTemporaryFile]:
         self.recorder = contextlib.redirect_stdout(self.destination)
+        assert self.recorder is not None
         self.recorder.__enter__()
         return lambda: self.destination
 
-    def __exit__(self, *args) -> bool:
+    def __exit__(self, *args) -> Optional[bool]:
+        assert self.recorder is not None
         self.recorder.__exit__(*args)
         self.destination.seek(0)
         return None
@@ -301,18 +300,21 @@ class RecordStderr(object):
     ...     print("This is a test!", file=sys.stderr)
     >>> print({record().readline()})
     {'This is a test!\\n'}
+    >>> record().close()
     """
 
     def __init__(self) -> None:
         self.destination = tempfile.SpooledTemporaryFile(mode='r+')
-        self.recorder = None
+        self.recorder: Optional[contextlib.redirect_stdout[Any]] = None
 
     def __enter__(self) -> Callable[[], tempfile.SpooledTemporaryFile]:
-        self.recorder = contextlib.redirect_stderr(self.destination)
+        self.recorder = contextlib.redirect_stderr(self.destination)  # type: ignore
+        assert self.recorder is not None
         self.recorder.__enter__()
         return lambda: self.destination
 
-    def __exit__(self, *args) -> bool:
+    def __exit__(self, *args) -> Optional[bool]:
+        assert self.recorder is not None
         self.recorder.__exit__(*args)
         self.destination.seek(0)
         return None
@@ -326,7 +328,7 @@ class RecordMultipleStreams(object):
     def __init__(self, *files) -> None:
         self.files = [*files]
         self.destination = tempfile.SpooledTemporaryFile(mode='r+')
-        self.saved_writes = []
+        self.saved_writes: List[Callable[..., Any]] = []
 
     def __enter__(self) -> Callable[[], tempfile.SpooledTemporaryFile]:
         for f in self.files:
@@ -334,10 +336,11 @@ class RecordMultipleStreams(object):
             f.write = self.destination.write
         return lambda: self.destination
 
-    def __exit__(self, *args) -> bool:
+    def __exit__(self, *args) -> Optional[bool]:
         for f in self.files:
             f.write = self.saved_writes.pop()
         self.destination.seek(0)
+        return None
 
 
 if __name__ == '__main__':
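
Note: the main behavioral change above is that wrapper_perf_monitor now checks --unittests_ignore_perf up front and returns the wrapped test's result immediately, before any FILE or DATABASE persistence helper is constructed; previously the flag was only consulted at the final comparison, after the runtime had already been measured. The following is a rough usage sketch, not part of this patch: it assumes the python_utils modules are importable and that command-line config (including --unittests_ignore_perf and --unittests_persistance_strategy) has been parsed the way the repo's unittest bootstrap normally arranges, and the TestSomething class and test_fast_path method are made-up names for illustration.

# Illustrative sketch only -- not part of this patch.  Assumes python_utils
# is on sys.path and that config has already been initialized/parsed.
import unittest

import unittest_utils


class TestSomething(unittest.TestCase):  # made-up test case for the sketch
    @unittest_utils.check_method_for_perf_regressions
    def test_fast_path(self) -> None:
        # With --unittests_ignore_perf the decorator returns the test result
        # right away; otherwise it times this method and compares the runtime
        # against the recorded history before deciding whether to fail.
        self.assertEqual(sum(range(10)), 45)


if __name__ == '__main__':
    unittest.main()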