From c79ecbf708a63a54a9c3e8d189b65d4794930082 Mon Sep 17 00:00:00 2001 From: Scott Gasch Date: Sat, 17 Jul 2021 14:47:01 -0700 Subject: [PATCH] Record streams. --- logging_utils.py | 85 +++++++++++++++++++------------------ tests/logging_utils_test.py | 48 +++++++++++++-------- unittest_utils.py | 72 +++++++++++++++++++++++++++++-- 3 files changed, 142 insertions(+), 63 deletions(-) diff --git a/logging_utils.py b/logging_utils.py index b7fd11f..0c7d193 100644 --- a/logging_utils.py +++ b/logging_utils.py @@ -5,11 +5,13 @@ import contextlib import datetime import enum +import io import logging from logging.handlers import RotatingFileHandler, SysLogHandler import os import pytz import sys +from typing import Iterable, Optional # This module is commonly used by others in here and should avoid # taking any unnecessary dependencies back on them. @@ -239,46 +241,46 @@ class OutputMultiplexer(object): class Destination(enum.IntEnum): """Bits in the destination_bitv bitvector. Used to indicate the output destination.""" - STDOUT = 0x1 - STDERR = 0x2 - LOG_DEBUG = 0x4 # -\ - LOG_INFO = 0x8 # | - LOG_WARNING = 0x10 # > Should provide logger to the c'tor. - LOG_ERROR = 0x20 # | - LOG_CRITICAL = 0x40 # _/ - FILENAME = 0x80 # Must provide a filename to the c'tor. - FILEHANDLE = 0x100 # Must provide a handle to the c'tor. - HLOG = 0x200 + LOG_DEBUG = 0x01 # -\ + LOG_INFO = 0x02 # | + LOG_WARNING = 0x04 # > Should provide logger to the c'tor. + LOG_ERROR = 0x08 # | + LOG_CRITICAL = 0x10 # _/ + FILENAMES = 0x20 # Must provide a filename to the c'tor. + FILEHANDLES = 0x40 # Must provide a handle to the c'tor. + HLOG = 0x80 ALL_LOG_DESTINATIONS = ( LOG_DEBUG | LOG_INFO | LOG_WARNING | LOG_ERROR | LOG_CRITICAL ) - ALL_OUTPUT_DESTINATIONS = 0x2FF + ALL_OUTPUT_DESTINATIONS = 0x8F def __init__(self, destination_bitv: int, *, logger=None, - filename=None, - handle=None): + filenames: Optional[Iterable[str]] = None, + handles: Optional[Iterable[io.TextIOWrapper]] = None): if logger is None: logger = logging.getLogger(None) self.logger = logger - if filename is not None: - self.f = open(filename, "wb", buffering=0) + if filenames is not None: + self.f = [ + open(filename, 'wb', buffering=0) for filename in filenames + ] else: - if self.destination_bitv & OutputMultiplexer.FILENAME: + if self.destination_bitv & OutputMultiplexer.FILENAMES: raise ValueError( - "Filename argument is required if bitv & FILENAME" + "Filenames argument is required if bitv & FILENAMES" ) self.f = None - if handle is not None: - self.h = handle + if handles is not None: + self.h = [handle for handle in handles] else: - if self.destination_bitv & OutputMultiplexer.FILEHANDLE: + if self.destination_bitv & OutputMultiplexer.FILEHANDLES: raise ValueError( - "Handle argument is required if bitv & FILEHANDLE" + "Handle argument is required if bitv & FILEHANDLES" ) self.h = None @@ -288,13 +290,13 @@ class OutputMultiplexer(object): return self.destination_bitv def set_destination_bitv(self, destination_bitv: int): - if destination_bitv & self.Destination.FILENAME and self.f is None: + if destination_bitv & self.Destination.FILENAMES and self.f is None: raise ValueError( - "Filename argument is required if bitv & FILENAME" + "Filename argument is required if bitv & FILENAMES" ) - if destination_bitv & self.Destination.FILEHANDLE and self.h is None: + if destination_bitv & self.Destination.FILEHANDLES and self.h is None: raise ValueError( - "Handle argument is required if bitv & FILEHANDLE" + "Handle argument is required if bitv & FILEHANDLES" ) self.destination_bitv = destination_bitv @@ -315,25 +317,23 @@ class OutputMultiplexer(object): sep = " " if end is None: end = "\n" - if self.destination_bitv & self.Destination.STDOUT: - print(buf, file=sys.stdout, sep=sep, end=end) - if self.destination_bitv & self.Destination.STDERR: - print(buf, file=sys.stderr, sep=sep, end=end) if end == '\n': buf += '\n' if ( - self.destination_bitv & self.Destination.FILENAME and + self.destination_bitv & self.Destination.FILENAMES and self.f is not None ): - self.f.write(buf.encode('utf-8')) - self.f.flush() + for _ in self.f: + _.write(buf.encode('utf-8')) + _.flush() if ( - self.destination_bitv & self.Destination.FILEHANDLE and + self.destination_bitv & self.Destination.FILEHANDLES and self.h is not None ): - self.h.write(buf) - self.h.flush() + for _ in self.h: + _.write(buf) + _.flush() buf = strip_escape_sequences(buf) if self.logger is not None: @@ -352,21 +352,22 @@ class OutputMultiplexer(object): def close(self): if self.f is not None: - self.f.close() + for _ in self.f: + _.close() -class OutputContext(OutputMultiplexer, contextlib.ContextDecorator): +class OutputMultiplexerContext(OutputMultiplexer, contextlib.ContextDecorator): def __init__(self, destination_bitv: OutputMultiplexer.Destination, *, - logger=None, - filename=None, - handle=None): + logger = None, + filenames = None, + handles = None): super().__init__( destination_bitv, logger=logger, - filename=filename, - handle=handle) + filenames=filenames, + handles=handles) def __enter__(self): return self diff --git a/tests/logging_utils_test.py b/tests/logging_utils_test.py index 48e5015..87c00d6 100755 --- a/tests/logging_utils_test.py +++ b/tests/logging_utils_test.py @@ -1,12 +1,13 @@ #!/usr/bin/env python3 -import contextlib +import os +import sys import tempfile import unittest -import bootstrap import logging_utils as lutils import string_utils as sutils +import unittest_utils as uu class TestLoggingUtils(unittest.TestCase): @@ -17,26 +18,37 @@ class TestLoggingUtils(unittest.TestCase): secret_message = f'This is a test, {unique_suffix}.' with tempfile.SpooledTemporaryFile(mode='r+') as tmpfile1: - with tempfile.SpooledTemporaryFile(mode='r+') as tmpfile2: - with contextlib.redirect_stdout(tmpfile1): - with lutils.OutputContext( - lutils.OutputMultiplexer.Destination.FILENAME | - lutils.OutputMultiplexer.Destination.FILEHANDLE | - lutils.OutputMultiplexer.Destination.STDOUT, - filename = filename, - handle = tmpfile2, - ) as mplex: - mplex.print(secret_message, end='') - with open(filename, 'r') as rf: - self.assertEqual(rf.readline(), secret_message) - tmpfile2.seek(0) - tmp = tmpfile2.readline() - self.assertEqual(tmp, secret_message) + with uu.RecordStdout() as record: + with lutils.OutputMultiplexerContext( + lutils.OutputMultiplexer.Destination.FILENAMES | + lutils.OutputMultiplexer.Destination.FILEHANDLES | + lutils.OutputMultiplexer.Destination.LOG_INFO, + filenames = [filename, '/dev/null'], + handles = [tmpfile1, sys.stdout], + ) as mplex: + mplex.print(secret_message, end='') + + # Make sure it was written to the filename. + with open(filename, 'r') as rf: + self.assertEqual(rf.readline(), secret_message) + os.unlink(filename) + + # Make sure it was written to stdout. + tmp = record().readline() + self.assertEqual(tmp, secret_message) + + # Make sure it was written to the filehandle. tmpfile1.seek(0) tmp = tmpfile1.readline() self.assertEqual(tmp, secret_message) + def test_record_streams(self): + with uu.RecordMultipleStreams(sys.stderr, sys.stdout) as record: + print("This is a test!") + print("This is one too.", file=sys.stderr) + self.assertEqual(record().readlines(), + ["This is a test!\n", "This is one too.\n"]) + if __name__ == '__main__': - unittest.main = bootstrap.initialize(unittest.main) unittest.main() diff --git a/unittest_utils.py b/unittest_utils.py index e7090bc..5987da6 100644 --- a/unittest_utils.py +++ b/unittest_utils.py @@ -6,17 +6,16 @@ so that we getLogger config, commandline args, logging control, etc... this works fine but it's a little hacky so caveat emptor. """ +import contextlib import functools import inspect -import io import logging import pickle import random import statistics -import sys import time import tempfile -from typing import Callable, Iterable +from typing import Callable import unittest import bootstrap @@ -152,3 +151,70 @@ def breakpoint(): pdb.set_trace() +class RecordStdout(object): + """ + with uu.RecordStdout() as record: + print("This is a test!") + print({record().readline()}) + """ + + def __init__(self) -> None: + self.destination = tempfile.SpooledTemporaryFile(mode='r+') + self.recorder = None + + def __enter__(self) -> Callable[[], tempfile.SpooledTemporaryFile]: + self.recorder = contextlib.redirect_stdout(self.destination) + self.recorder.__enter__() + return lambda: self.destination + + def __exit__(self, *args) -> bool: + self.recorder.__exit__(*args) + self.destination.seek(0) + return True + + +class RecordStderr(object): + """ + with uu.RecordStderr() as record: + print("This is a test!", file=sys.stderr) + print({record().readline()}) + """ + + def __init__(self) -> None: + self.destination = tempfile.SpooledTemporaryFile(mode='r+') + self.recorder = None + + def __enter__(self) -> Callable[[], tempfile.SpooledTemporaryFile]: + self.recorder = contextlib.redirect_stderr(self.destination) + self.recorder.__enter__() + return lambda: self.destination + + def __exit__(self, *args) -> bool: + self.recorder.__exit__(*args) + self.destination.seek(0) + return True + + +class RecordMultipleStreams(object): + """ + with uu.RecordStreams(sys.stderr, sys.stdout) as record: + print("This is a test!") + print("This is one too.", file=sys.stderr) + + print(record().readlines()) + """ + def __init__(self, *files) -> None: + self.files = [*files] + self.destination = tempfile.SpooledTemporaryFile(mode='r+') + self.saved_writes = [] + + def __enter__(self) -> Callable[[], tempfile.SpooledTemporaryFile]: + for f in self.files: + self.saved_writes.append(f.write) + f.write = self.destination.write + return lambda: self.destination + + def __exit__(self, *args) -> bool: + for f in self.files: + f.write = self.saved_writes.pop() + self.destination.seek(0) -- 2.45.2