X-Git-Url: https://wannabe.guru.org/gitweb/?a=blobdiff_plain;ds=sidebyside;f=unittest_utils.py;h=2dc8cfe231e5aad03d877fac8116472e0c4b6c3d;hb=b10d30a46e601c9ee1f843241f2d69a1f90f7a94;hp=e7090bcc6237e50d9d1f982c5db52f0e6f77bd0f;hpb=1574e8a3a8982fab9278ad534f9427d464e4bffb;p=python_utils.git diff --git a/unittest_utils.py b/unittest_utils.py index e7090bc..2dc8cfe 100644 --- a/unittest_utils.py +++ b/unittest_utils.py @@ -1,22 +1,22 @@ #!/usr/bin/env python3 """Helpers for unittests. Note that when you import this we -automatically wrap unittest.main() with a call to bootstrap.initialize -so that we getLogger config, commandline args, logging control, -etc... this works fine but it's a little hacky so caveat emptor. + automatically wrap unittest.main() with a call to + bootstrap.initialize so that we getLogger config, commandline args, + logging control, etc... this works fine but it's a little hacky so + caveat emptor. """ +import contextlib import functools import inspect -import io import logging import pickle import random import statistics -import sys import time import tempfile -from typing import Callable, Iterable +from typing import Callable import unittest import bootstrap @@ -152,3 +152,70 @@ def breakpoint(): pdb.set_trace() +class RecordStdout(object): + """ + with uu.RecordStdout() as record: + print("This is a test!") + print({record().readline()}) + """ + + def __init__(self) -> None: + self.destination = tempfile.SpooledTemporaryFile(mode='r+') + self.recorder = None + + def __enter__(self) -> Callable[[], tempfile.SpooledTemporaryFile]: + self.recorder = contextlib.redirect_stdout(self.destination) + self.recorder.__enter__() + return lambda: self.destination + + def __exit__(self, *args) -> bool: + self.recorder.__exit__(*args) + self.destination.seek(0) + return None + + +class RecordStderr(object): + """ + with uu.RecordStderr() as record: + print("This is a test!", file=sys.stderr) + print({record().readline()}) + """ + + def __init__(self) -> None: + self.destination = tempfile.SpooledTemporaryFile(mode='r+') + self.recorder = None + + def __enter__(self) -> Callable[[], tempfile.SpooledTemporaryFile]: + self.recorder = contextlib.redirect_stderr(self.destination) + self.recorder.__enter__() + return lambda: self.destination + + def __exit__(self, *args) -> bool: + self.recorder.__exit__(*args) + self.destination.seek(0) + return None + + +class RecordMultipleStreams(object): + """ + with uu.RecordStreams(sys.stderr, sys.stdout) as record: + print("This is a test!") + print("This is one too.", file=sys.stderr) + + print(record().readlines()) + """ + def __init__(self, *files) -> None: + self.files = [*files] + self.destination = tempfile.SpooledTemporaryFile(mode='r+') + self.saved_writes = [] + + def __enter__(self) -> Callable[[], tempfile.SpooledTemporaryFile]: + for f in self.files: + self.saved_writes.append(f.write) + f.write = self.destination.write + return lambda: self.destination + + def __exit__(self, *args) -> bool: + for f in self.files: + f.write = self.saved_writes.pop() + self.destination.seek(0)