Since this thing is on the innerwebs I suppose it should have a
[python_utils.git] / unittest_utils.py
index 70e588e2fa8025b2a70941b9837c78fb3f65421c..28b577e2086af4ff20647d05cd9be24761839d6d 100644 (file)
@@ -1,10 +1,12 @@
 #!/usr/bin/env python3
 
+# © Copyright 2021-2022, Scott Gasch
+
 """Helpers for unittests.  Note that when you import this we
-   automatically wrap unittest.main() with a call to
-   bootstrap.initialize so that we getLogger config, commandline args,
-   logging control, etc... this works fine but it's a little hacky so
-   caveat emptor.
+automatically wrap unittest.main() with a call to bootstrap.initialize
+so that we getLogger config, commandline args, logging control,
+etc... this works fine but it's a little hacky so caveat emptor.
+
 """
 
 import contextlib
@@ -20,7 +22,7 @@ import time
 import unittest
 import warnings
 from abc import ABC, abstractmethod
-from typing import Any, Callable, Dict, List, Optional
+from typing import Any, Callable, Dict, List, Literal, Optional
 
 import sqlalchemy as sa
 
@@ -265,14 +267,7 @@ def check_all_methods_for_perf_regressions(prefix='test_'):
     return decorate_the_testcase
 
 
-def breakpoint():
-    """Hard code a breakpoint somewhere; drop into pdb."""
-    import pdb
-
-    pdb.set_trace()
-
-
-class RecordStdout(object):
+class RecordStdout(contextlib.AbstractContextManager):
     """
     Record what is emitted to stdout.
 
@@ -284,6 +279,7 @@ class RecordStdout(object):
     """
 
     def __init__(self) -> None:
+        super().__init__()
         self.destination = tempfile.SpooledTemporaryFile(mode='r+')
         self.recorder: Optional[contextlib.redirect_stdout] = None
 
@@ -293,14 +289,14 @@ class RecordStdout(object):
         self.recorder.__enter__()
         return lambda: self.destination
 
-    def __exit__(self, *args) -> Optional[bool]:
+    def __exit__(self, *args) -> Literal[False]:
         assert self.recorder is not None
         self.recorder.__exit__(*args)
         self.destination.seek(0)
-        return None
+        return False
 
 
-class RecordStderr(object):
+class RecordStderr(contextlib.AbstractContextManager):
     """
     Record what is emitted to stderr.
 
@@ -313,6 +309,7 @@ class RecordStderr(object):
     """
 
     def __init__(self) -> None:
+        super().__init__()
         self.destination = tempfile.SpooledTemporaryFile(mode='r+')
         self.recorder: Optional[contextlib.redirect_stdout[Any]] = None
 
@@ -322,19 +319,20 @@ class RecordStderr(object):
         self.recorder.__enter__()
         return lambda: self.destination
 
-    def __exit__(self, *args) -> Optional[bool]:
+    def __exit__(self, *args) -> Literal[False]:
         assert self.recorder is not None
         self.recorder.__exit__(*args)
         self.destination.seek(0)
-        return None
+        return False
 
 
-class RecordMultipleStreams(object):
+class RecordMultipleStreams(contextlib.AbstractContextManager):
     """
     Record the output to more than one stream.
     """
 
     def __init__(self, *files) -> None:
+        super().__init__()
         self.files = [*files]
         self.destination = tempfile.SpooledTemporaryFile(mode='r+')
         self.saved_writes: List[Callable[..., Any]] = []
@@ -345,11 +343,11 @@ class RecordMultipleStreams(object):
             f.write = self.destination.write
         return lambda: self.destination
 
-    def __exit__(self, *args) -> Optional[bool]:
+    def __exit__(self, *args) -> Literal[False]:
         for f in self.files:
             f.write = self.saved_writes.pop()
         self.destination.seek(0)
-        return None
+        return False
 
 
 if __name__ == '__main__':