Record streams.
authorScott Gasch <[email protected]>
Sat, 17 Jul 2021 21:47:01 +0000 (14:47 -0700)
committerScott Gasch <[email protected]>
Sat, 17 Jul 2021 21:47:01 +0000 (14:47 -0700)
logging_utils.py
tests/logging_utils_test.py
unittest_utils.py

index b7fd11fd8168cfbc985fa1fa60e28704b86075bd..0c7d19362d7ed59cf9009053465763e47e6e4709 100644 (file)
@@ -5,11 +5,13 @@
 import contextlib
 import datetime
 import enum
+import io
 import logging
 from logging.handlers import RotatingFileHandler, SysLogHandler
 import os
 import pytz
 import sys
+from typing import Iterable, Optional
 
 # This module is commonly used by others in here and should avoid
 # taking any unnecessary dependencies back on them.
@@ -239,46 +241,46 @@ class OutputMultiplexer(object):
     class Destination(enum.IntEnum):
         """Bits in the destination_bitv bitvector.  Used to indicate the
         output destination."""
-        STDOUT = 0x1
-        STDERR = 0x2
-        LOG_DEBUG = 0x4          # -\
-        LOG_INFO = 0x8           #  |
-        LOG_WARNING = 0x10       #   > Should provide logger to the c'tor.
-        LOG_ERROR = 0x20         #  |
-        LOG_CRITICAL = 0x40      # _/
-        FILENAME = 0x80          # Must provide a filename to the c'tor.
-        FILEHANDLE = 0x100       # Must provide a handle to the c'tor.
-        HLOG = 0x200
+        LOG_DEBUG = 0x01         # -\
+        LOG_INFO = 0x02          #  |
+        LOG_WARNING = 0x04       #   > Should provide logger to the c'tor.
+        LOG_ERROR = 0x08         #  |
+        LOG_CRITICAL = 0x10      # _/
+        FILENAMES = 0x20         # Must provide a filename to the c'tor.
+        FILEHANDLES = 0x40       # Must provide a handle to the c'tor.
+        HLOG = 0x80
         ALL_LOG_DESTINATIONS = (
             LOG_DEBUG | LOG_INFO | LOG_WARNING | LOG_ERROR | LOG_CRITICAL
         )
-        ALL_OUTPUT_DESTINATIONS = 0x2FF
+        ALL_OUTPUT_DESTINATIONS = 0x8F
 
     def __init__(self,
                  destination_bitv: int,
                  *,
                  logger=None,
-                 filename=None,
-                 handle=None):
+                 filenames: Optional[Iterable[str]] = None,
+                 handles: Optional[Iterable[io.TextIOWrapper]] = None):
         if logger is None:
             logger = logging.getLogger(None)
         self.logger = logger
 
-        if filename is not None:
-            self.f = open(filename, "wb", buffering=0)
+        if filenames is not None:
+            self.f = [
+                open(filename, 'wb', buffering=0) for filename in filenames
+            ]
         else:
-            if self.destination_bitv & OutputMultiplexer.FILENAME:
+            if self.destination_bitv & OutputMultiplexer.FILENAMES:
                 raise ValueError(
-                    "Filename argument is required if bitv & FILENAME"
+                    "Filenames argument is required if bitv & FILENAMES"
                 )
             self.f = None
 
-        if handle is not None:
-            self.h = handle
+        if handles is not None:
+            self.h = [handle for handle in handles]
         else:
-            if self.destination_bitv & OutputMultiplexer.FILEHANDLE:
+            if self.destination_bitv & OutputMultiplexer.FILEHANDLES:
                 raise ValueError(
-                    "Handle argument is required if bitv & FILEHANDLE"
+                    "Handle argument is required if bitv & FILEHANDLES"
                 )
             self.h = None
 
@@ -288,13 +290,13 @@ class OutputMultiplexer(object):
         return self.destination_bitv
 
     def set_destination_bitv(self, destination_bitv: int):
-        if destination_bitv & self.Destination.FILENAME and self.f is None:
+        if destination_bitv & self.Destination.FILENAMES and self.f is None:
             raise ValueError(
-                "Filename argument is required if bitv & FILENAME"
+                "Filename argument is required if bitv & FILENAMES"
             )
-        if destination_bitv & self.Destination.FILEHANDLE and self.h is None:
+        if destination_bitv & self.Destination.FILEHANDLES and self.h is None:
             raise ValueError(
-                    "Handle argument is required if bitv & FILEHANDLE"
+                    "Handle argument is required if bitv & FILEHANDLES"
                 )
         self.destination_bitv = destination_bitv
 
@@ -315,25 +317,23 @@ class OutputMultiplexer(object):
             sep = " "
         if end is None:
             end = "\n"
-        if self.destination_bitv & self.Destination.STDOUT:
-            print(buf, file=sys.stdout, sep=sep, end=end)
-        if self.destination_bitv & self.Destination.STDERR:
-            print(buf, file=sys.stderr, sep=sep, end=end)
         if end == '\n':
             buf += '\n'
         if (
-                self.destination_bitv & self.Destination.FILENAME and
+                self.destination_bitv & self.Destination.FILENAMES and
                 self.f is not None
         ):
-            self.f.write(buf.encode('utf-8'))
-            self.f.flush()
+            for _ in self.f:
+                _.write(buf.encode('utf-8'))
+                _.flush()
 
         if (
-                self.destination_bitv & self.Destination.FILEHANDLE and
+                self.destination_bitv & self.Destination.FILEHANDLES and
                 self.h is not None
         ):
-            self.h.write(buf)
-            self.h.flush()
+            for _ in self.h:
+                _.write(buf)
+                _.flush()
 
         buf = strip_escape_sequences(buf)
         if self.logger is not None:
@@ -352,21 +352,22 @@ class OutputMultiplexer(object):
 
     def close(self):
         if self.f is not None:
-            self.f.close()
+            for _ in self.f:
+                _.close()
 
 
-class OutputContext(OutputMultiplexer, contextlib.ContextDecorator):
+class OutputMultiplexerContext(OutputMultiplexer, contextlib.ContextDecorator):
     def __init__(self,
                  destination_bitv: OutputMultiplexer.Destination,
                  *,
-                 logger=None,
-                 filename=None,
-                 handle=None):
+                 logger = None,
+                 filenames = None,
+                 handles = None):
         super().__init__(
             destination_bitv,
             logger=logger,
-            filename=filename,
-            handle=handle)
+            filenames=filenames,
+            handles=handles)
 
     def __enter__(self):
         return self
index 48e5015f12f1f0af1595c3e1aa026bdcb9236123..87c00d64e30e3a9b0a90e08658742ad0ae2739aa 100755 (executable)
@@ -1,12 +1,13 @@
 #!/usr/bin/env python3
 
-import contextlib
+import os
+import sys
 import tempfile
 import unittest
 
-import bootstrap
 import logging_utils as lutils
 import string_utils as sutils
+import unittest_utils as uu
 
 
 class TestLoggingUtils(unittest.TestCase):
@@ -17,26 +18,37 @@ class TestLoggingUtils(unittest.TestCase):
         secret_message = f'This is a test, {unique_suffix}.'
 
         with tempfile.SpooledTemporaryFile(mode='r+') as tmpfile1:
-            with tempfile.SpooledTemporaryFile(mode='r+') as tmpfile2:
-                with contextlib.redirect_stdout(tmpfile1):
-                    with lutils.OutputContext(
-                            lutils.OutputMultiplexer.Destination.FILENAME |
-                            lutils.OutputMultiplexer.Destination.FILEHANDLE |
-                            lutils.OutputMultiplexer.Destination.STDOUT,
-                            filename = filename,
-                            handle = tmpfile2,
-                    ) as mplex:
-                        mplex.print(secret_message, end='')
-                    with open(filename, 'r') as rf:
-                        self.assertEqual(rf.readline(), secret_message)
-                tmpfile2.seek(0)
-                tmp = tmpfile2.readline()
-                self.assertEqual(tmp, secret_message)
+            with uu.RecordStdout() as record:
+                with lutils.OutputMultiplexerContext(
+                        lutils.OutputMultiplexer.Destination.FILENAMES |
+                        lutils.OutputMultiplexer.Destination.FILEHANDLES |
+                        lutils.OutputMultiplexer.Destination.LOG_INFO,
+                        filenames = [filename, '/dev/null'],
+                        handles = [tmpfile1, sys.stdout],
+                ) as mplex:
+                    mplex.print(secret_message, end='')
+
+                # Make sure it was written to the filename.
+                with open(filename, 'r') as rf:
+                    self.assertEqual(rf.readline(), secret_message)
+                os.unlink(filename)
+
+            # Make sure it was written to stdout.
+            tmp = record().readline()
+            self.assertEqual(tmp, secret_message)
+
+            # Make sure it was written to the filehandle.
             tmpfile1.seek(0)
             tmp = tmpfile1.readline()
             self.assertEqual(tmp, secret_message)
 
+    def test_record_streams(self):
+        with uu.RecordMultipleStreams(sys.stderr, sys.stdout) as record:
+            print("This is a test!")
+            print("This is one too.", file=sys.stderr)
+        self.assertEqual(record().readlines(),
+                         ["This is a test!\n", "This is one too.\n"])
+
 
 if __name__ == '__main__':
-    unittest.main = bootstrap.initialize(unittest.main)
     unittest.main()
index e7090bcc6237e50d9d1f982c5db52f0e6f77bd0f..5987da6ba38194bf639b7382347ed599ff5bc0e9 100644 (file)
@@ -6,17 +6,16 @@ so that we getLogger config, commandline args, logging control,
 etc... this works fine but it's a little hacky so caveat emptor.
 """
 
+import contextlib
 import functools
 import inspect
-import io
 import logging
 import pickle
 import random
 import statistics
-import sys
 import time
 import tempfile
-from typing import Callable, Iterable
+from typing import Callable
 import unittest
 
 import bootstrap
@@ -152,3 +151,70 @@ def breakpoint():
     pdb.set_trace()
 
 
+class RecordStdout(object):
+    """
+        with uu.RecordStdout() as record:
+            print("This is a test!")
+        print({record().readline()})
+    """
+
+    def __init__(self) -> None:
+        self.destination = tempfile.SpooledTemporaryFile(mode='r+')
+        self.recorder = None
+
+    def __enter__(self) -> Callable[[], tempfile.SpooledTemporaryFile]:
+        self.recorder = contextlib.redirect_stdout(self.destination)
+        self.recorder.__enter__()
+        return lambda: self.destination
+
+    def __exit__(self, *args) -> bool:
+        self.recorder.__exit__(*args)
+        self.destination.seek(0)
+        return True
+
+
+class RecordStderr(object):
+    """
+        with uu.RecordStderr() as record:
+            print("This is a test!", file=sys.stderr)
+        print({record().readline()})
+    """
+
+    def __init__(self) -> None:
+        self.destination = tempfile.SpooledTemporaryFile(mode='r+')
+        self.recorder = None
+
+    def __enter__(self) -> Callable[[], tempfile.SpooledTemporaryFile]:
+        self.recorder = contextlib.redirect_stderr(self.destination)
+        self.recorder.__enter__()
+        return lambda: self.destination
+
+    def __exit__(self, *args) -> bool:
+        self.recorder.__exit__(*args)
+        self.destination.seek(0)
+        return True
+
+
+class RecordMultipleStreams(object):
+    """
+        with uu.RecordStreams(sys.stderr, sys.stdout) as record:
+            print("This is a test!")
+            print("This is one too.", file=sys.stderr)
+
+        print(record().readlines())
+    """
+    def __init__(self, *files) -> None:
+        self.files = [*files]
+        self.destination = tempfile.SpooledTemporaryFile(mode='r+')
+        self.saved_writes = []
+
+    def __enter__(self) -> Callable[[], tempfile.SpooledTemporaryFile]:
+        for f in self.files:
+            self.saved_writes.append(f.write)
+            f.write = self.destination.write
+        return lambda: self.destination
+
+    def __exit__(self, *args) -> bool:
+        for f in self.files:
+            f.write = self.saved_writes.pop()
+        self.destination.seek(0)