projects
/
python_utils.git
/ blobdiff
commit
grep
author
committer
pickaxe
?
search:
re
summary
|
shortlog
|
log
|
commit
|
commitdiff
|
tree
raw
|
inline
| side by side
Cleanup unittest_utils trying to get clean lint.
[python_utils.git]
/
unittest_utils.py
diff --git
a/unittest_utils.py
b/unittest_utils.py
index 70e588e2fa8025b2a70941b9837c78fb3f65421c..88e41954811de26629b405cb4ee7255d8bbebc62 100644
(file)
--- a/
unittest_utils.py
+++ b/
unittest_utils.py
@@
-20,7
+20,7
@@
import time
import unittest
import warnings
from abc import ABC, abstractmethod
import unittest
import warnings
from abc import ABC, abstractmethod
-from typing import Any, Callable, Dict, List, Optional
+from typing import Any, Callable, Dict, List,
Literal,
Optional
import sqlalchemy as sa
import sqlalchemy as sa
@@
-265,14
+265,7
@@
def check_all_methods_for_perf_regressions(prefix='test_'):
return decorate_the_testcase
return decorate_the_testcase
-def breakpoint():
- """Hard code a breakpoint somewhere; drop into pdb."""
- import pdb
-
- pdb.set_trace()
-
-
-class RecordStdout(object):
+class RecordStdout(contextlib.AbstractContextManager):
"""
Record what is emitted to stdout.
"""
Record what is emitted to stdout.
@@
-284,6
+277,7
@@
class RecordStdout(object):
"""
def __init__(self) -> None:
"""
def __init__(self) -> None:
+ super().__init__()
self.destination = tempfile.SpooledTemporaryFile(mode='r+')
self.recorder: Optional[contextlib.redirect_stdout] = None
self.destination = tempfile.SpooledTemporaryFile(mode='r+')
self.recorder: Optional[contextlib.redirect_stdout] = None
@@
-293,14
+287,14
@@
class RecordStdout(object):
self.recorder.__enter__()
return lambda: self.destination
self.recorder.__enter__()
return lambda: self.destination
- def __exit__(self, *args) ->
Optional[bool
]:
+ def __exit__(self, *args) ->
Literal[False
]:
assert self.recorder is not None
self.recorder.__exit__(*args)
self.destination.seek(0)
assert self.recorder is not None
self.recorder.__exit__(*args)
self.destination.seek(0)
- return
Non
e
+ return
Fals
e
-class RecordStderr(
object
):
+class RecordStderr(
contextlib.AbstractContextManager
):
"""
Record what is emitted to stderr.
"""
Record what is emitted to stderr.
@@
-313,6
+307,7
@@
class RecordStderr(object):
"""
def __init__(self) -> None:
"""
def __init__(self) -> None:
+ super().__init__()
self.destination = tempfile.SpooledTemporaryFile(mode='r+')
self.recorder: Optional[contextlib.redirect_stdout[Any]] = None
self.destination = tempfile.SpooledTemporaryFile(mode='r+')
self.recorder: Optional[contextlib.redirect_stdout[Any]] = None
@@
-322,19
+317,20
@@
class RecordStderr(object):
self.recorder.__enter__()
return lambda: self.destination
self.recorder.__enter__()
return lambda: self.destination
- def __exit__(self, *args) ->
Optional[bool
]:
+ def __exit__(self, *args) ->
Literal[False
]:
assert self.recorder is not None
self.recorder.__exit__(*args)
self.destination.seek(0)
assert self.recorder is not None
self.recorder.__exit__(*args)
self.destination.seek(0)
- return
Non
e
+ return
Fals
e
-class RecordMultipleStreams(
object
):
+class RecordMultipleStreams(
contextlib.AbstractContextManager
):
"""
Record the output to more than one stream.
"""
def __init__(self, *files) -> None:
"""
Record the output to more than one stream.
"""
def __init__(self, *files) -> None:
+ super().__init__()
self.files = [*files]
self.destination = tempfile.SpooledTemporaryFile(mode='r+')
self.saved_writes: List[Callable[..., Any]] = []
self.files = [*files]
self.destination = tempfile.SpooledTemporaryFile(mode='r+')
self.saved_writes: List[Callable[..., Any]] = []
@@
-345,11
+341,11
@@
class RecordMultipleStreams(object):
f.write = self.destination.write
return lambda: self.destination
f.write = self.destination.write
return lambda: self.destination
- def __exit__(self, *args) ->
Optional[bool
]:
+ def __exit__(self, *args) ->
Literal[False
]:
for f in self.files:
f.write = self.saved_writes.pop()
self.destination.seek(0)
for f in self.files:
f.write = self.saved_writes.pop()
self.destination.seek(0)
- return
Non
e
+ return
Fals
e
if __name__ == '__main__':
if __name__ == '__main__':