Various changes.
[python_utils.git] / unittest_utils.py
index e7090bcc6237e50d9d1f982c5db52f0e6f77bd0f..2dc8cfe231e5aad03d877fac8116472e0c4b6c3d 100644 (file)
@@ -1,22 +1,22 @@
 #!/usr/bin/env python3
 
 """Helpers for unittests.  Note that when you import this we
-automatically wrap unittest.main() with a call to bootstrap.initialize
-so that we getLogger config, commandline args, logging control,
-etc... this works fine but it's a little hacky so caveat emptor.
+   automatically wrap unittest.main() with a call to
+   bootstrap.initialize so that we getLogger config, commandline args,
+   logging control, etc... this works fine but it's a little hacky so
+   caveat emptor.
 """
 
+import contextlib
 import functools
 import inspect
-import io
 import logging
 import pickle
 import random
 import statistics
-import sys
 import time
 import tempfile
-from typing import Callable, Iterable
+from typing import Callable
 import unittest
 
 import bootstrap
@@ -152,3 +152,70 @@ def breakpoint():
     pdb.set_trace()
 
 
+class RecordStdout(object):
+    """
+        with uu.RecordStdout() as record:
+            print("This is a test!")
+        print({record().readline()})
+    """
+
+    def __init__(self) -> None:
+        self.destination = tempfile.SpooledTemporaryFile(mode='r+')
+        self.recorder = None
+
+    def __enter__(self) -> Callable[[], tempfile.SpooledTemporaryFile]:
+        self.recorder = contextlib.redirect_stdout(self.destination)
+        self.recorder.__enter__()
+        return lambda: self.destination
+
+    def __exit__(self, *args) -> bool:
+        self.recorder.__exit__(*args)
+        self.destination.seek(0)
+        return None
+
+
+class RecordStderr(object):
+    """
+        with uu.RecordStderr() as record:
+            print("This is a test!", file=sys.stderr)
+        print({record().readline()})
+    """
+
+    def __init__(self) -> None:
+        self.destination = tempfile.SpooledTemporaryFile(mode='r+')
+        self.recorder = None
+
+    def __enter__(self) -> Callable[[], tempfile.SpooledTemporaryFile]:
+        self.recorder = contextlib.redirect_stderr(self.destination)
+        self.recorder.__enter__()
+        return lambda: self.destination
+
+    def __exit__(self, *args) -> bool:
+        self.recorder.__exit__(*args)
+        self.destination.seek(0)
+        return None
+
+
+class RecordMultipleStreams(object):
+    """
+        with uu.RecordStreams(sys.stderr, sys.stdout) as record:
+            print("This is a test!")
+            print("This is one too.", file=sys.stderr)
+
+        print(record().readlines())
+    """
+    def __init__(self, *files) -> None:
+        self.files = [*files]
+        self.destination = tempfile.SpooledTemporaryFile(mode='r+')
+        self.saved_writes = []
+
+    def __enter__(self) -> Callable[[], tempfile.SpooledTemporaryFile]:
+        for f in self.files:
+            self.saved_writes.append(f.write)
+            f.write = self.destination.write
+        return lambda: self.destination
+
+    def __exit__(self, *args) -> bool:
+        for f in self.files:
+            f.write = self.saved_writes.pop()
+        self.destination.seek(0)