import unittest
import warnings
from abc import ABC, abstractmethod
-from typing import Any, Callable, Dict, List, Optional
+from typing import Any, Callable, Dict, List, Literal, Optional
import sqlalchemy as sa
return decorate_the_testcase
-def breakpoint():
- """Hard code a breakpoint somewhere; drop into pdb."""
- import pdb
-
- pdb.set_trace()
-
-
-class RecordStdout(object):
+class RecordStdout(contextlib.AbstractContextManager):
"""
Record what is emitted to stdout.
"""
def __init__(self) -> None:
+ super().__init__()
self.destination = tempfile.SpooledTemporaryFile(mode='r+')
self.recorder: Optional[contextlib.redirect_stdout] = None
self.recorder.__enter__()
return lambda: self.destination
- def __exit__(self, *args) -> Optional[bool]:
+ def __exit__(self, *args) -> Literal[False]:
assert self.recorder is not None
self.recorder.__exit__(*args)
self.destination.seek(0)
- return None
+ return False
-class RecordStderr(object):
+class RecordStderr(contextlib.AbstractContextManager):
"""
Record what is emitted to stderr.
"""
def __init__(self) -> None:
+ super().__init__()
self.destination = tempfile.SpooledTemporaryFile(mode='r+')
self.recorder: Optional[contextlib.redirect_stdout[Any]] = None
self.recorder.__enter__()
return lambda: self.destination
- def __exit__(self, *args) -> Optional[bool]:
+ def __exit__(self, *args) -> Literal[False]:
assert self.recorder is not None
self.recorder.__exit__(*args)
self.destination.seek(0)
- return None
+ return False
-class RecordMultipleStreams(object):
+class RecordMultipleStreams(contextlib.AbstractContextManager):
"""
Record the output to more than one stream.
"""
def __init__(self, *files) -> None:
+ super().__init__()
self.files = [*files]
self.destination = tempfile.SpooledTemporaryFile(mode='r+')
self.saved_writes: List[Callable[..., Any]] = []
f.write = self.destination.write
return lambda: self.destination
- def __exit__(self, *args) -> Optional[bool]:
+ def __exit__(self, *args) -> Literal[False]:
for f in self.files:
f.write = self.saved_writes.pop()
self.destination.seek(0)
- return None
+ return False
if __name__ == '__main__':