Make smart futures avoid polling.
diff --git a/unittest_utils.py b/unittest_utils.py
index 99ac81d32b3284fc8257d750b193fa57564cebb4..bb1a9b432f49d7886b99f39418339e9f62e0cee1 100644
@@ -1,11 +1,13 @@
 #!/usr/bin/env python3
 
 """Helpers for unittests.  Note that when you import this we
-automatically wrap unittest.main() with a call to bootstrap.initialize
-so that we getLogger config, commandline args, logging control,
-etc... this works fine but it's a little hacky so caveat emptor.
+   automatically wrap unittest.main() with a call to
+   bootstrap.initialize so that we get logger config, commandline
+   args, logging control, etc.  This works fine but it's a little
+   hacky, so caveat emptor.
 """
 
+import contextlib
 import functools
 import inspect
 import logging
@@ -13,6 +15,7 @@ import pickle
 import random
 import statistics
+import tempfile
 import time
 from typing import Callable
 import unittest
 
@@ -53,13 +56,14 @@ _db = '/home/scott/.python_unittest_performance_db'
 
 
 def check_method_for_perf_regressions(func: Callable) -> Callable:
-    """This is meant to be used on a method in a class that subclasses
+    """
+    This is meant to be used on a method in a class that subclasses
     unittest.TestCase.  When thus decorated it will time the execution
     of the code in the method, compare it with a database of
     historical performance, and fail the test with a perf-related
     message if it has become too slow.
-    """
 
+    """
     def load_known_test_performance_characteristics():
         with open(_db, 'rb') as f:
             return pickle.load(f)
@@ -104,7 +108,7 @@ def check_method_for_perf_regressions(func: Callable) -> Callable:
             )
         else:
             stdev = statistics.stdev(hist)
-            limit = hist[-1] + stdev * 3
+            limit = hist[-1] + stdev * 5
             logger.debug(
                 f'Max acceptable performance for {func.__name__} is {limit:f}s'
             )
@@ -114,12 +118,14 @@ def check_method_for_perf_regressions(func: Callable) -> Callable:
             ):
                 msg = f'''{func_id} performance has regressed unacceptably.
 {hist[-1]:f}s is the slowest record in {len(hist)} db perf samples.
-It just ran in {run_time:f}s which is >3 stdevs slower than the slowest sample.
+It just ran in {run_time:f}s which is >5 stdevs slower than the slowest sample.
 Here is the current, full db perf timing distribution:
 
-{hist}'''
-                slf = args[0]
+'''
+                for x in hist:
+                    msg += f'{x:f}\n'
                 logger.error(msg)
+                slf = args[0]
                 slf.fail(msg)
             else:
                 hist.append(run_time)
@@ -142,3 +148,87 @@ def check_all_methods_for_perf_regressions(prefix='test_'):
                     logger.debug(f'Wrapping {cls.__name__}:{name}.')
         return cls
     return decorate_the_testcase
+
+
+def breakpoint():
+    """Hard code a breakpoint somewhere; drop into pdb."""
+    import pdb
+    pdb.set_trace()
+
+
+class RecordStdout(object):
+    """
+    Record what is emitted to stdout.
+
+    >>> with RecordStdout() as record:
+    ...     print("This is a test!")
+    >>> print({record().readline()})
+    {'This is a test!\\n'}
+    """
+
+    def __init__(self) -> None:
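+        # A SpooledTemporaryFile keeps the recording in memory unless
+        # it grows large; mode='r+' lets callers read it back later.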
+        self.destination = tempfile.SpooledTemporaryFile(mode='r+')
+        self.recorder = None
+
+    def __enter__(self) -> Callable[[], tempfile.SpooledTemporaryFile]:
+        self.recorder = contextlib.redirect_stdout(self.destination)
+        self.recorder.__enter__()
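+        # Return a thunk rather than the file itself; the file is only
+        # rewound (and thus readable) once __exit__ has run.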
+        return lambda: self.destination
+
+    def __exit__(self, *args) -> None:
+        self.recorder.__exit__(*args)
+        self.destination.seek(0)
+        return None
+
+
+class RecordStderr(object):
+    """
+    Record what is emitted to stderr.
+
+    >>> import sys
+    >>> with RecordStderr() as record:
+    ...     print("This is a test!", file=sys.stderr)
+    >>> print({record().readline()})
+    {'This is a test!\\n'}
+    """
+
+    def __init__(self) -> None:
+        self.destination = tempfile.SpooledTemporaryFile(mode='r+')
+        self.recorder = None
+
+    def __enter__(self) -> Callable[[], tempfile.SpooledTemporaryFile]:
+        self.recorder = contextlib.redirect_stderr(self.destination)
+        self.recorder.__enter__()
+        return lambda: self.destination
+
+    def __exit__(self, *args) -> None:
+        self.recorder.__exit__(*args)
+        self.destination.seek(0)
+        return None
+
+
+class RecordMultipleStreams(object):
+    """
+    Record the output to more than one stream.
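+
+    A hypothetical usage sketch (assumes the wrapped streams, like
+    sys.stdout and sys.stderr here, allow their write attribute to be
+    reassigned):
+
+    >>> import sys
+    >>> with RecordMultipleStreams(sys.stdout, sys.stderr) as record:
+    ...     print("This is a stdout test!")
+    ...     print("This is a stderr test!", file=sys.stderr)
+    >>> record().readline()
+    'This is a stdout test!\\n'
+    >>> record().readline()
+    'This is a stderr test!\\n'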
+    """
+
+    def __init__(self, *files) -> None:
+        self.files = [*files]
+        self.destination = tempfile.SpooledTemporaryFile(mode='r+')
+        self.saved_writes = []
+
+    def __enter__(self) -> Callable[[], tempfile.SpooledTemporaryFile]:
+        for f in self.files:
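+            # Shadow each stream's write method with an instance
+            # attribute so print() et al. land in our spool file.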
+            self.saved_writes.append(f.write)
+            f.write = self.destination.write
+        return lambda: self.destination
+
+    def __exit__(self, *args) -> None:
+        # Restore writes in the order they were saved; pop() from the
+        # end would cross-wire the streams' original write methods.
+        for f in self.files:
+            f.write = self.saved_writes.pop(0)
+        self.destination.seek(0)
+
+
+if __name__ == '__main__':
+    import doctest
+    doctest.testmod()