Cleanup unittest_utils trying to get clean lint.
authorScott Gasch <[email protected]>
Wed, 9 Feb 2022 03:59:23 +0000 (19:59 -0800)
committerScott Gasch <[email protected]>
Wed, 9 Feb 2022 03:59:23 +0000 (19:59 -0800)
unittest_utils.py

index 70e588e2fa8025b2a70941b9837c78fb3f65421c..88e41954811de26629b405cb4ee7255d8bbebc62 100644 (file)
@@ -20,7 +20,7 @@ import time
 import unittest
 import warnings
 from abc import ABC, abstractmethod
-from typing import Any, Callable, Dict, List, Optional
+from typing import Any, Callable, Dict, List, Literal, Optional
 
 import sqlalchemy as sa
 
@@ -265,14 +265,7 @@ def check_all_methods_for_perf_regressions(prefix='test_'):
     return decorate_the_testcase
 
 
-def breakpoint():
-    """Hard code a breakpoint somewhere; drop into pdb."""
-    import pdb
-
-    pdb.set_trace()
-
-
-class RecordStdout(object):
+class RecordStdout(contextlib.AbstractContextManager):
     """
     Record what is emitted to stdout.
 
@@ -284,6 +277,7 @@ class RecordStdout(object):
     """
 
     def __init__(self) -> None:
+        super().__init__()
         self.destination = tempfile.SpooledTemporaryFile(mode='r+')
         self.recorder: Optional[contextlib.redirect_stdout] = None
 
@@ -293,14 +287,14 @@ class RecordStdout(object):
         self.recorder.__enter__()
         return lambda: self.destination
 
-    def __exit__(self, *args) -> Optional[bool]:
+    def __exit__(self, *args) -> Literal[False]:
         assert self.recorder is not None
         self.recorder.__exit__(*args)
         self.destination.seek(0)
-        return None
+        return False
 
 
-class RecordStderr(object):
+class RecordStderr(contextlib.AbstractContextManager):
     """
     Record what is emitted to stderr.
 
@@ -313,6 +307,7 @@ class RecordStderr(object):
     """
 
     def __init__(self) -> None:
+        super().__init__()
         self.destination = tempfile.SpooledTemporaryFile(mode='r+')
         self.recorder: Optional[contextlib.redirect_stdout[Any]] = None
 
@@ -322,19 +317,20 @@ class RecordStderr(object):
         self.recorder.__enter__()
         return lambda: self.destination
 
-    def __exit__(self, *args) -> Optional[bool]:
+    def __exit__(self, *args) -> Literal[False]:
         assert self.recorder is not None
         self.recorder.__exit__(*args)
         self.destination.seek(0)
-        return None
+        return False
 
 
-class RecordMultipleStreams(object):
+class RecordMultipleStreams(contextlib.AbstractContextManager):
     """
     Record the output to more than one stream.
     """
 
     def __init__(self, *files) -> None:
+        super().__init__()
         self.files = [*files]
         self.destination = tempfile.SpooledTemporaryFile(mode='r+')
         self.saved_writes: List[Callable[..., Any]] = []
@@ -345,11 +341,11 @@ class RecordMultipleStreams(object):
             f.write = self.destination.write
         return lambda: self.destination
 
-    def __exit__(self, *args) -> Optional[bool]:
+    def __exit__(self, *args) -> Literal[False]:
         for f in self.files:
             f.write = self.saved_writes.pop()
         self.destination.seek(0)
-        return None
+        return False
 
 
 if __name__ == '__main__':