Ran black code formatter on everything.
[python_utils.git] / decorator_utils.py
index 9b848ed792144919b863b20c82e846bcd509bbe8..daae64e75348e973dc8a27cf387faf7f404ef2b2 100644 (file)
@@ -48,6 +48,7 @@ def timed(func: Callable) -> Callable:
         print(msg)
         logger.info(msg)
         return value
         print(msg)
         logger.info(msg)
         return value
+
     return wrapper_timer
 
 
     return wrapper_timer
 
 
@@ -75,10 +76,13 @@ def invocation_logged(func: Callable) -> Callable:
         print(msg)
         logger.info(msg)
         return ret
         print(msg)
         logger.info(msg)
         return ret
+
     return wrapper_invocation_logged
 
 
     return wrapper_invocation_logged
 
 
-def rate_limited(n_calls: int, *, per_period_in_seconds: float = 1.0) -> Callable:
+def rate_limited(
+    n_calls: int, *, per_period_in_seconds: float = 1.0
+) -> Callable:
     """Limit invocation of a wrapped function to n calls per period.
     Thread safe.  In testing this was relatively fair with multiple
     threads using it though that hasn't been measured.
     """Limit invocation of a wrapped function to n calls per period.
     Thread safe.  In testing this was relatively fair with multiple
     threads using it though that hasn't been measured.
@@ -152,7 +156,9 @@ def rate_limited(n_calls: int, *, per_period_in_seconds: float = 1.0) -> Callabl
                 )
                 cv.notify()
             return ret
                 )
                 cv.notify()
             return ret
+
         return wrapper_wrapper_rate_limited
         return wrapper_wrapper_rate_limited
+
     return wrapper_rate_limited
 
 
     return wrapper_rate_limited
 
 
@@ -188,6 +194,7 @@ def debug_args(func: Callable) -> Callable:
         print(msg)
         logger.info(msg)
         return value
         print(msg)
         logger.info(msg)
         return value
+
     return wrapper_debug_args
 
 
     return wrapper_debug_args
 
 
@@ -213,10 +220,13 @@ def debug_count_calls(func: Callable) -> Callable:
     @functools.wraps(func)
     def wrapper_debug_count_calls(*args, **kwargs):
         wrapper_debug_count_calls.num_calls += 1
     @functools.wraps(func)
     def wrapper_debug_count_calls(*args, **kwargs):
         wrapper_debug_count_calls.num_calls += 1
-        msg = f"Call #{wrapper_debug_count_calls.num_calls} of {func.__name__!r}"
+        msg = (
+            f"Call #{wrapper_debug_count_calls.num_calls} of {func.__name__!r}"
+        )
         print(msg)
         logger.info(msg)
         return func(*args, **kwargs)
         print(msg)
         logger.info(msg)
         return func(*args, **kwargs)
+
     wrapper_debug_count_calls.num_calls = 0
     return wrapper_debug_count_calls
 
     wrapper_debug_count_calls.num_calls = 0
     return wrapper_debug_count_calls
 
@@ -251,6 +261,7 @@ def delay(
     True
 
     """
     True
 
     """
+
     def decorator_delay(func: Callable) -> Callable:
         @functools.wraps(func)
         def wrapper_delay(*args, **kwargs):
     def decorator_delay(func: Callable) -> Callable:
         @functools.wraps(func)
         def wrapper_delay(*args, **kwargs):
@@ -266,6 +277,7 @@ def delay(
                 )
                 time.sleep(seconds)
             return retval
                 )
                 time.sleep(seconds)
             return retval
+
         return wrapper_delay
 
     if _func is None:
         return wrapper_delay
 
     if _func is None:
@@ -350,6 +362,7 @@ def memoized(func: Callable) -> Callable:
     True
 
     """
     True
 
     """
+
     @functools.wraps(func)
     def wrapper_memoized(*args, **kwargs):
         cache_key = args + tuple(kwargs.items())
     @functools.wraps(func)
     def wrapper_memoized(*args, **kwargs):
         cache_key = args + tuple(kwargs.items())
@@ -362,6 +375,7 @@ def memoized(func: Callable) -> Callable:
         else:
             logger.debug(f"Returning memoized value for {func.__name__}")
         return wrapper_memoized.cache[cache_key]
         else:
             logger.debug(f"Returning memoized value for {func.__name__}")
         return wrapper_memoized.cache[cache_key]
+
     wrapper_memoized.cache = dict()
     return wrapper_memoized
 
     wrapper_memoized.cache = dict()
     return wrapper_memoized
 
@@ -416,7 +430,9 @@ def retry_predicate(
                 mdelay *= backoff
                 retval = f(*args, **kwargs)
             return retval
                 mdelay *= backoff
                 retval = f(*args, **kwargs)
             return retval
+
         return f_retry
         return f_retry
+
     return deco_retry
 
 
     return deco_retry
 
 
@@ -443,7 +459,7 @@ def retry_if_false(tries: int, *, delay_sec=3.0, backoff=2.0):
     3
     >>> dur > 2.0
     True
     3
     >>> dur > 2.0
     True
-    >>> dur < 2.2
+    >>> dur < 2.3
     True
 
     """
     True
 
     """
@@ -475,13 +491,15 @@ def deprecated(func):
     when the function is used.
 
     """
     when the function is used.
 
     """
+
     @functools.wraps(func)
     def wrapper_deprecated(*args, **kwargs):
         msg = f"Call to deprecated function {func.__qualname__}"
         logger.warning(msg)
     @functools.wraps(func)
     def wrapper_deprecated(*args, **kwargs):
         msg = f"Call to deprecated function {func.__qualname__}"
         logger.warning(msg)
-        warnings.warn(msg, category=DeprecationWarning)
+        warnings.warn(msg, category=DeprecationWarning, stacklevel=2)
         print(msg, file=sys.stderr)
         return func(*args, **kwargs)
         print(msg, file=sys.stderr)
         return func(*args, **kwargs)
+
     return wrapper_deprecated
 
 
     return wrapper_deprecated
 
 
@@ -507,7 +525,6 @@ def thunkify(func):
                 exc[0] = True
                 exc[1] = sys.exc_info()  # (type, value, traceback)
                 msg = f"Thunkify has thrown an exception (will be raised on thunk()):\n{traceback.format_exc()}"
                 exc[0] = True
                 exc[1] = sys.exc_info()  # (type, value, traceback)
                 msg = f"Thunkify has thrown an exception (will be raised on thunk()):\n{traceback.format_exc()}"
-                print(msg)
                 logger.warning(msg)
             finally:
                 wait_event.set()
                 logger.warning(msg)
             finally:
                 wait_event.set()
@@ -535,9 +552,9 @@ def thunkify(func):
 
 def _raise_exception(exception, error_message: Optional[str]):
     if error_message is None:
 
 def _raise_exception(exception, error_message: Optional[str]):
     if error_message is None:
-        raise exception()
+        raise Exception()
     else:
     else:
-        raise exception(error_message)
+        raise Exception(error_message)
 
 
 def _target(queue, function, *args, **kwargs):
 
 
 def _target(queue, function, *args, **kwargs):
@@ -555,10 +572,10 @@ def _target(queue, function, *args, **kwargs):
 
 
 class _Timeout(object):
 
 
 class _Timeout(object):
-    """Wrap a function and add a timeout (limit) attribute to it.
+    """Wrap a function and add a timeout to it.
 
     Instances of this class are automatically generated by the add_timeout
 
     Instances of this class are automatically generated by the add_timeout
-    function defined below.
+    function defined below.  Do not use directly.
     """
 
     def __init__(
     """
 
     def __init__(
@@ -635,13 +652,28 @@ def timeout(
     main thread).  When not using signals, timeout granularity will be
     rounded to the nearest 0.1s.
 
     main thread).  When not using signals, timeout granularity will be
     rounded to the nearest 0.1s.
 
-    Raises an exception when the timeout is reached.
+    Raises an exception when/if the timeout is reached.
 
     It is illegal to pass anything other than a function as the first
     parameter.  The function is wrapped and returned to the caller.
 
     It is illegal to pass anything other than a function as the first
     parameter.  The function is wrapped and returned to the caller.
+
+    >>> @timeout(0.2)
+    ... def foo(delay: float):
+    ...     time.sleep(delay)
+    ...     return "ok"
+
+    >>> foo(0)
+    'ok'
+
+    >>> foo(1.0)
+    Traceback (most recent call last):
+    ...
+    Exception: Function call timed out
+
     """
     if use_signals is None:
         import thread_utils
     """
     if use_signals is None:
         import thread_utils
+
         use_signals = thread_utils.is_current_thread_main_thread()
 
     def decorate(function):
         use_signals = thread_utils.is_current_thread_main_thread()
 
     def decorate(function):
@@ -712,6 +744,7 @@ class rlocked(object):
                 self._entered = True
                 f(*args, **kwargs)
                 self._entered = False
                 self._entered = True
                 f(*args, **kwargs)
                 self._entered = False
+
         return _gatekeeper
 
 
         return _gatekeeper
 
 
@@ -730,7 +763,9 @@ def call_with_sample_rate(sample_rate: float) -> Callable:
                 logger.debug(
                     f"@call_with_sample_rate skipping a call to {f.__name__}"
                 )
                 logger.debug(
                     f"@call_with_sample_rate skipping a call to {f.__name__}"
                 )
+
         return _call_with_sample_rate
         return _call_with_sample_rate
+
     return decorator
 
 
     return decorator
 
 
@@ -739,6 +774,7 @@ def decorate_matching_methods_with(decorator, acl=None):
     prefix.  If prefix is None (default), decorate all methods in the
     class.
     """
     prefix.  If prefix is None (default), decorate all methods in the
     class.
     """
+
     def decorate_the_class(cls):
         for name, m in inspect.getmembers(cls, inspect.isfunction):
             if acl is None:
     def decorate_the_class(cls):
         for name, m in inspect.getmembers(cls, inspect.isfunction):
             if acl is None:
@@ -747,10 +783,11 @@ def decorate_matching_methods_with(decorator, acl=None):
                 if acl(name):
                     setattr(cls, name, decorator(m))
         return cls
                 if acl(name):
                     setattr(cls, name, decorator(m))
         return cls
+
     return decorate_the_class
 
 
 if __name__ == '__main__':
     import doctest
     return decorate_the_class
 
 
 if __name__ == '__main__':
     import doctest
-    doctest.testmod()
 
 
+    doctest.testmod()