X-Git-Url: https://wannabe.guru.org/gitweb/?a=blobdiff_plain;f=unittest_utils.py;fp=unittest_utils.py;h=88e41954811de26629b405cb4ee7255d8bbebc62;hb=4b04fd1d5a14c5c4c7e0985e5376b4e2f879ef06;hp=70e588e2fa8025b2a70941b9837c78fb3f65421c;hpb=b63d4b5f98d9eec29e923b62d4f63c9b63f13927;p=python_utils.git diff --git a/unittest_utils.py b/unittest_utils.py index 70e588e..88e4195 100644 --- a/unittest_utils.py +++ b/unittest_utils.py @@ -20,7 +20,7 @@ import time import unittest import warnings from abc import ABC, abstractmethod -from typing import Any, Callable, Dict, List, Optional +from typing import Any, Callable, Dict, List, Literal, Optional import sqlalchemy as sa @@ -265,14 +265,7 @@ def check_all_methods_for_perf_regressions(prefix='test_'): return decorate_the_testcase -def breakpoint(): - """Hard code a breakpoint somewhere; drop into pdb.""" - import pdb - - pdb.set_trace() - - -class RecordStdout(object): +class RecordStdout(contextlib.AbstractContextManager): """ Record what is emitted to stdout. @@ -284,6 +277,7 @@ class RecordStdout(object): """ def __init__(self) -> None: + super().__init__() self.destination = tempfile.SpooledTemporaryFile(mode='r+') self.recorder: Optional[contextlib.redirect_stdout] = None @@ -293,14 +287,14 @@ class RecordStdout(object): self.recorder.__enter__() return lambda: self.destination - def __exit__(self, *args) -> Optional[bool]: + def __exit__(self, *args) -> Literal[False]: assert self.recorder is not None self.recorder.__exit__(*args) self.destination.seek(0) - return None + return False -class RecordStderr(object): +class RecordStderr(contextlib.AbstractContextManager): """ Record what is emitted to stderr. @@ -313,6 +307,7 @@ class RecordStderr(object): """ def __init__(self) -> None: + super().__init__() self.destination = tempfile.SpooledTemporaryFile(mode='r+') self.recorder: Optional[contextlib.redirect_stdout[Any]] = None @@ -322,19 +317,20 @@ class RecordStderr(object): self.recorder.__enter__() return lambda: self.destination - def __exit__(self, *args) -> Optional[bool]: + def __exit__(self, *args) -> Literal[False]: assert self.recorder is not None self.recorder.__exit__(*args) self.destination.seek(0) - return None + return False -class RecordMultipleStreams(object): +class RecordMultipleStreams(contextlib.AbstractContextManager): """ Record the output to more than one stream. """ def __init__(self, *files) -> None: + super().__init__() self.files = [*files] self.destination = tempfile.SpooledTemporaryFile(mode='r+') self.saved_writes: List[Callable[..., Any]] = [] @@ -345,11 +341,11 @@ class RecordMultipleStreams(object): f.write = self.destination.write return lambda: self.destination - def __exit__(self, *args) -> Optional[bool]: + def __exit__(self, *args) -> Literal[False]: for f in self.files: f.write = self.saved_writes.pop() self.destination.seek(0) - return None + return False if __name__ == '__main__':