Update tests / test harness.
[python_utils.git] / unittest_utils.py
index 99ac81d32b3284fc8257d750b193fa57564cebb4..2dc8cfe231e5aad03d877fac8116472e0c4b6c3d 100644 (file)
@@ -1,11 +1,13 @@
 #!/usr/bin/env python3
 
 """Helpers for unittests.  Note that when you import this we
-automatically wrap unittest.main() with a call to bootstrap.initialize
-so that we getLogger config, commandline args, logging control,
-etc... this works fine but it's a little hacky so caveat emptor.
+   automatically wrap unittest.main() with a call to
+   bootstrap.initialize so that we getLogger config, commandline args,
+   logging control, etc... this works fine but it's a little hacky so
+   caveat emptor.
 """
 
+import contextlib
 import functools
 import inspect
 import logging
@@ -13,6 +15,7 @@ import pickle
 import random
 import statistics
 import time
+import tempfile
 from typing import Callable
 import unittest
 
@@ -142,3 +145,77 @@ def check_all_methods_for_perf_regressions(prefix='test_'):
                     logger.debug(f'Wrapping {cls.__name__}:{name}.')
         return cls
     return decorate_the_testcase
+
+
+def breakpoint():
+    import pdb
+    pdb.set_trace()
+
+
+class RecordStdout(object):
+    """
+        with uu.RecordStdout() as record:
+            print("This is a test!")
+        print({record().readline()})
+    """
+
+    def __init__(self) -> None:
+        self.destination = tempfile.SpooledTemporaryFile(mode='r+')
+        self.recorder = None
+
+    def __enter__(self) -> Callable[[], tempfile.SpooledTemporaryFile]:
+        self.recorder = contextlib.redirect_stdout(self.destination)
+        self.recorder.__enter__()
+        return lambda: self.destination
+
+    def __exit__(self, *args) -> bool:
+        self.recorder.__exit__(*args)
+        self.destination.seek(0)
+        return None
+
+
+class RecordStderr(object):
+    """
+        with uu.RecordStderr() as record:
+            print("This is a test!", file=sys.stderr)
+        print({record().readline()})
+    """
+
+    def __init__(self) -> None:
+        self.destination = tempfile.SpooledTemporaryFile(mode='r+')
+        self.recorder = None
+
+    def __enter__(self) -> Callable[[], tempfile.SpooledTemporaryFile]:
+        self.recorder = contextlib.redirect_stderr(self.destination)
+        self.recorder.__enter__()
+        return lambda: self.destination
+
+    def __exit__(self, *args) -> bool:
+        self.recorder.__exit__(*args)
+        self.destination.seek(0)
+        return None
+
+
+class RecordMultipleStreams(object):
+    """
+        with uu.RecordStreams(sys.stderr, sys.stdout) as record:
+            print("This is a test!")
+            print("This is one too.", file=sys.stderr)
+
+        print(record().readlines())
+    """
+    def __init__(self, *files) -> None:
+        self.files = [*files]
+        self.destination = tempfile.SpooledTemporaryFile(mode='r+')
+        self.saved_writes = []
+
+    def __enter__(self) -> Callable[[], tempfile.SpooledTemporaryFile]:
+        for f in self.files:
+            self.saved_writes.append(f.write)
+            f.write = self.destination.write
+        return lambda: self.destination
+
+    def __exit__(self, *args) -> bool:
+        for f in self.files:
+            f.write = self.saved_writes.pop()
+        self.destination.seek(0)