X-Git-Url: https://wannabe.guru.org/gitweb/?a=blobdiff_plain;f=unittest_utils.py;h=2dc8cfe231e5aad03d877fac8116472e0c4b6c3d;hb=6f132c0342ab7aa438ed88d7c5f987cb52d8ca05;hp=99ac81d32b3284fc8257d750b193fa57564cebb4;hpb=3bc4daf1edc121cd633429187392227f2fa61885;p=python_utils.git diff --git a/unittest_utils.py b/unittest_utils.py index 99ac81d..2dc8cfe 100644 --- a/unittest_utils.py +++ b/unittest_utils.py @@ -1,11 +1,13 @@ #!/usr/bin/env python3 """Helpers for unittests. Note that when you import this we -automatically wrap unittest.main() with a call to bootstrap.initialize -so that we getLogger config, commandline args, logging control, -etc... this works fine but it's a little hacky so caveat emptor. + automatically wrap unittest.main() with a call to + bootstrap.initialize so that we getLogger config, commandline args, + logging control, etc... this works fine but it's a little hacky so + caveat emptor. """ +import contextlib import functools import inspect import logging @@ -13,6 +15,7 @@ import pickle import random import statistics import time +import tempfile from typing import Callable import unittest @@ -142,3 +145,77 @@ def check_all_methods_for_perf_regressions(prefix='test_'): logger.debug(f'Wrapping {cls.__name__}:{name}.') return cls return decorate_the_testcase + + +def breakpoint(): + import pdb + pdb.set_trace() + + +class RecordStdout(object): + """ + with uu.RecordStdout() as record: + print("This is a test!") + print({record().readline()}) + """ + + def __init__(self) -> None: + self.destination = tempfile.SpooledTemporaryFile(mode='r+') + self.recorder = None + + def __enter__(self) -> Callable[[], tempfile.SpooledTemporaryFile]: + self.recorder = contextlib.redirect_stdout(self.destination) + self.recorder.__enter__() + return lambda: self.destination + + def __exit__(self, *args) -> bool: + self.recorder.__exit__(*args) + self.destination.seek(0) + return None + + +class RecordStderr(object): + """ + with uu.RecordStderr() as record: + print("This is a test!", file=sys.stderr) + print({record().readline()}) + """ + + def __init__(self) -> None: + self.destination = tempfile.SpooledTemporaryFile(mode='r+') + self.recorder = None + + def __enter__(self) -> Callable[[], tempfile.SpooledTemporaryFile]: + self.recorder = contextlib.redirect_stderr(self.destination) + self.recorder.__enter__() + return lambda: self.destination + + def __exit__(self, *args) -> bool: + self.recorder.__exit__(*args) + self.destination.seek(0) + return None + + +class RecordMultipleStreams(object): + """ + with uu.RecordStreams(sys.stderr, sys.stdout) as record: + print("This is a test!") + print("This is one too.", file=sys.stderr) + + print(record().readlines()) + """ + def __init__(self, *files) -> None: + self.files = [*files] + self.destination = tempfile.SpooledTemporaryFile(mode='r+') + self.saved_writes = [] + + def __enter__(self) -> Callable[[], tempfile.SpooledTemporaryFile]: + for f in self.files: + self.saved_writes.append(f.write) + f.write = self.destination.write + return lambda: self.destination + + def __exit__(self, *args) -> bool: + for f in self.files: + f.write = self.saved_writes.pop() + self.destination.seek(0)