Ran black code formatter on everything.
[python_utils.git] / decorator_utils.py
index 9b848ed792144919b863b20c82e846bcd509bbe8..daae64e75348e973dc8a27cf387faf7f404ef2b2 100644 (file)
@@ -48,6 +48,7 @@ def timed(func: Callable) -> Callable:
         print(msg)
         logger.info(msg)
         return value
+
     return wrapper_timer
 
 
@@ -75,10 +76,13 @@ def invocation_logged(func: Callable) -> Callable:
         print(msg)
         logger.info(msg)
         return ret
+
     return wrapper_invocation_logged
 
 
-def rate_limited(n_calls: int, *, per_period_in_seconds: float = 1.0) -> Callable:
+def rate_limited(
+    n_calls: int, *, per_period_in_seconds: float = 1.0
+) -> Callable:
     """Limit invocation of a wrapped function to n calls per period.
     Thread safe.  In testing this was relatively fair with multiple
     threads using it though that hasn't been measured.
@@ -152,7 +156,9 @@ def rate_limited(n_calls: int, *, per_period_in_seconds: float = 1.0) -> Callabl
                 )
                 cv.notify()
             return ret
+
         return wrapper_wrapper_rate_limited
+
     return wrapper_rate_limited
 
 
@@ -188,6 +194,7 @@ def debug_args(func: Callable) -> Callable:
         print(msg)
         logger.info(msg)
         return value
+
     return wrapper_debug_args
 
 
@@ -213,10 +220,13 @@ def debug_count_calls(func: Callable) -> Callable:
     @functools.wraps(func)
     def wrapper_debug_count_calls(*args, **kwargs):
         wrapper_debug_count_calls.num_calls += 1
-        msg = f"Call #{wrapper_debug_count_calls.num_calls} of {func.__name__!r}"
+        msg = (
+            f"Call #{wrapper_debug_count_calls.num_calls} of {func.__name__!r}"
+        )
         print(msg)
         logger.info(msg)
         return func(*args, **kwargs)
+
     wrapper_debug_count_calls.num_calls = 0
     return wrapper_debug_count_calls
 
@@ -251,6 +261,7 @@ def delay(
     True
 
     """
+
     def decorator_delay(func: Callable) -> Callable:
         @functools.wraps(func)
         def wrapper_delay(*args, **kwargs):
@@ -266,6 +277,7 @@ def delay(
                 )
                 time.sleep(seconds)
             return retval
+
         return wrapper_delay
 
     if _func is None:
@@ -350,6 +362,7 @@ def memoized(func: Callable) -> Callable:
     True
 
     """
+
     @functools.wraps(func)
     def wrapper_memoized(*args, **kwargs):
         cache_key = args + tuple(kwargs.items())
@@ -362,6 +375,7 @@ def memoized(func: Callable) -> Callable:
         else:
             logger.debug(f"Returning memoized value for {func.__name__}")
         return wrapper_memoized.cache[cache_key]
+
     wrapper_memoized.cache = dict()
     return wrapper_memoized
 
@@ -416,7 +430,9 @@ def retry_predicate(
                 mdelay *= backoff
                 retval = f(*args, **kwargs)
             return retval
+
         return f_retry
+
     return deco_retry
 
 
@@ -443,7 +459,7 @@ def retry_if_false(tries: int, *, delay_sec=3.0, backoff=2.0):
     3
     >>> dur > 2.0
     True
-    >>> dur < 2.2
+    >>> dur < 2.3
     True
 
     """
@@ -475,13 +491,15 @@ def deprecated(func):
     when the function is used.
 
     """
+
     @functools.wraps(func)
     def wrapper_deprecated(*args, **kwargs):
         msg = f"Call to deprecated function {func.__qualname__}"
         logger.warning(msg)
-        warnings.warn(msg, category=DeprecationWarning)
+        warnings.warn(msg, category=DeprecationWarning, stacklevel=2)
         print(msg, file=sys.stderr)
         return func(*args, **kwargs)
+
     return wrapper_deprecated
 
 
@@ -507,7 +525,6 @@ def thunkify(func):
                 exc[0] = True
                 exc[1] = sys.exc_info()  # (type, value, traceback)
                 msg = f"Thunkify has thrown an exception (will be raised on thunk()):\n{traceback.format_exc()}"
-                print(msg)
                 logger.warning(msg)
             finally:
                 wait_event.set()
@@ -535,9 +552,9 @@ def thunkify(func):
 
 def _raise_exception(exception, error_message: Optional[str]):
     if error_message is None:
-        raise exception()
+        raise Exception()
     else:
-        raise exception(error_message)
+        raise Exception(error_message)
 
 
 def _target(queue, function, *args, **kwargs):
@@ -555,10 +572,10 @@ def _target(queue, function, *args, **kwargs):
 
 
 class _Timeout(object):
-    """Wrap a function and add a timeout (limit) attribute to it.
+    """Wrap a function and add a timeout to it.
 
     Instances of this class are automatically generated by the add_timeout
-    function defined below.
+    function defined below.  Do not use directly.
     """
 
     def __init__(
@@ -635,13 +652,28 @@ def timeout(
     main thread).  When not using signals, timeout granularity will be
     rounded to the nearest 0.1s.
 
-    Raises an exception when the timeout is reached.
+    Raises an exception when/if the timeout is reached.
 
     It is illegal to pass anything other than a function as the first
     parameter.  The function is wrapped and returned to the caller.
+
+    >>> @timeout(0.2)
+    ... def foo(delay: float):
+    ...     time.sleep(delay)
+    ...     return "ok"
+
+    >>> foo(0)
+    'ok'
+
+    >>> foo(1.0)
+    Traceback (most recent call last):
+    ...
+    Exception: Function call timed out
+
     """
     if use_signals is None:
         import thread_utils
+
         use_signals = thread_utils.is_current_thread_main_thread()
 
     def decorate(function):
@@ -712,6 +744,7 @@ class rlocked(object):
                 self._entered = True
                 f(*args, **kwargs)
                 self._entered = False
+
         return _gatekeeper
 
 
@@ -730,7 +763,9 @@ def call_with_sample_rate(sample_rate: float) -> Callable:
                 logger.debug(
                     f"@call_with_sample_rate skipping a call to {f.__name__}"
                 )
+
         return _call_with_sample_rate
+
     return decorator
 
 
@@ -739,6 +774,7 @@ def decorate_matching_methods_with(decorator, acl=None):
     prefix.  If prefix is None (default), decorate all methods in the
     class.
     """
+
     def decorate_the_class(cls):
         for name, m in inspect.getmembers(cls, inspect.isfunction):
             if acl is None:
@@ -747,10 +783,11 @@ def decorate_matching_methods_with(decorator, acl=None):
                 if acl(name):
                     setattr(cls, name, decorator(m))
         return cls
+
     return decorate_the_class
 
 
 if __name__ == '__main__':
     import doctest
-    doctest.testmod()
 
+    doctest.testmod()