Cleanup config in preparation for zookeeper-based dynamic configs.
[python_utils.git] / unittest_utils.py
index 70e588e2fa8025b2a70941b9837c78fb3f65421c..a41aeb5d02108b7cd836cd0c9972262b114d0b76 100644 (file)
@@ -1,10 +1,15 @@
 #!/usr/bin/env python3
 
-"""Helpers for unittests.  Note that when you import this we
-   automatically wrap unittest.main() with a call to
-   bootstrap.initialize so that we getLogger config, commandline args,
-   logging control, etc... this works fine but it's a little hacky so
-   caveat emptor.
+# © Copyright 2021-2022, Scott Gasch
+
+"""Helpers for unittests.
+
+.. note::
+
+    When you import this we automatically wrap unittest.main()
+    with a call to bootstrap.initialize so that we getLogger
+    config, commandline args, logging control, etc... this works
+    fine but it's a little hacky so caveat emptor.
 """
 
 import contextlib
@@ -20,7 +25,7 @@ import time
 import unittest
 import warnings
 from abc import ABC, abstractmethod
-from typing import Any, Callable, Dict, List, Optional
+from typing import Any, Callable, Dict, List, Literal, Optional
 
 import sqlalchemy as sa
 
@@ -265,14 +270,7 @@ def check_all_methods_for_perf_regressions(prefix='test_'):
     return decorate_the_testcase
 
 
-def breakpoint():
-    """Hard code a breakpoint somewhere; drop into pdb."""
-    import pdb
-
-    pdb.set_trace()
-
-
-class RecordStdout(object):
+class RecordStdout(contextlib.AbstractContextManager):
     """
     Record what is emitted to stdout.
 
@@ -284,6 +282,7 @@ class RecordStdout(object):
     """
 
     def __init__(self) -> None:
+        super().__init__()
         self.destination = tempfile.SpooledTemporaryFile(mode='r+')
         self.recorder: Optional[contextlib.redirect_stdout] = None
 
@@ -293,14 +292,14 @@ class RecordStdout(object):
         self.recorder.__enter__()
         return lambda: self.destination
 
-    def __exit__(self, *args) -> Optional[bool]:
+    def __exit__(self, *args) -> Literal[False]:
         assert self.recorder is not None
         self.recorder.__exit__(*args)
         self.destination.seek(0)
-        return None
+        return False
 
 
-class RecordStderr(object):
+class RecordStderr(contextlib.AbstractContextManager):
     """
     Record what is emitted to stderr.
 
@@ -313,6 +312,7 @@ class RecordStderr(object):
     """
 
     def __init__(self) -> None:
+        super().__init__()
         self.destination = tempfile.SpooledTemporaryFile(mode='r+')
         self.recorder: Optional[contextlib.redirect_stdout[Any]] = None
 
@@ -322,19 +322,20 @@ class RecordStderr(object):
         self.recorder.__enter__()
         return lambda: self.destination
 
-    def __exit__(self, *args) -> Optional[bool]:
+    def __exit__(self, *args) -> Literal[False]:
         assert self.recorder is not None
         self.recorder.__exit__(*args)
         self.destination.seek(0)
-        return None
+        return False
 
 
-class RecordMultipleStreams(object):
+class RecordMultipleStreams(contextlib.AbstractContextManager):
     """
     Record the output to more than one stream.
     """
 
     def __init__(self, *files) -> None:
+        super().__init__()
         self.files = [*files]
         self.destination = tempfile.SpooledTemporaryFile(mode='r+')
         self.saved_writes: List[Callable[..., Any]] = []
@@ -345,11 +346,11 @@ class RecordMultipleStreams(object):
             f.write = self.destination.write
         return lambda: self.destination
 
-    def __exit__(self, *args) -> Optional[bool]:
+    def __exit__(self, *args) -> Literal[False]:
         for f in self.files:
             f.write = self.saved_writes.pop()
         self.destination.seek(0)
-        return None
+        return False
 
 
 if __name__ == '__main__':