Ran black code formatter on everything.
[python_utils.git] / decorator_utils.py
index d5349cc31aed2e74352822c3175ace022ab20e74..daae64e75348e973dc8a27cf387faf7f404ef2b2 100644 (file)
@@ -48,6 +48,7 @@ def timed(func: Callable) -> Callable:
         print(msg)
         logger.info(msg)
         return value
+
     return wrapper_timer
 
 
@@ -75,10 +76,13 @@ def invocation_logged(func: Callable) -> Callable:
         print(msg)
         logger.info(msg)
         return ret
+
     return wrapper_invocation_logged
 
 
-def rate_limited(n_calls: int, *, per_period_in_seconds: float = 1.0) -> Callable:
+def rate_limited(
+    n_calls: int, *, per_period_in_seconds: float = 1.0
+) -> Callable:
     """Limit invocation of a wrapped function to n calls per period.
     Thread safe.  In testing this was relatively fair with multiple
     threads using it though that hasn't been measured.
@@ -152,7 +156,9 @@ def rate_limited(n_calls: int, *, per_period_in_seconds: float = 1.0) -> Callabl
                 )
                 cv.notify()
             return ret
+
         return wrapper_wrapper_rate_limited
+
     return wrapper_rate_limited
 
 
@@ -188,6 +194,7 @@ def debug_args(func: Callable) -> Callable:
         print(msg)
         logger.info(msg)
         return value
+
     return wrapper_debug_args
 
 
@@ -213,10 +220,13 @@ def debug_count_calls(func: Callable) -> Callable:
     @functools.wraps(func)
     def wrapper_debug_count_calls(*args, **kwargs):
         wrapper_debug_count_calls.num_calls += 1
-        msg = f"Call #{wrapper_debug_count_calls.num_calls} of {func.__name__!r}"
+        msg = (
+            f"Call #{wrapper_debug_count_calls.num_calls} of {func.__name__!r}"
+        )
         print(msg)
         logger.info(msg)
         return func(*args, **kwargs)
+
     wrapper_debug_count_calls.num_calls = 0
     return wrapper_debug_count_calls
 
@@ -251,6 +261,7 @@ def delay(
     True
 
     """
+
     def decorator_delay(func: Callable) -> Callable:
         @functools.wraps(func)
         def wrapper_delay(*args, **kwargs):
@@ -266,6 +277,7 @@ def delay(
                 )
                 time.sleep(seconds)
             return retval
+
         return wrapper_delay
 
     if _func is None:
@@ -350,6 +362,7 @@ def memoized(func: Callable) -> Callable:
     True
 
     """
+
     @functools.wraps(func)
     def wrapper_memoized(*args, **kwargs):
         cache_key = args + tuple(kwargs.items())
@@ -362,6 +375,7 @@ def memoized(func: Callable) -> Callable:
         else:
             logger.debug(f"Returning memoized value for {func.__name__}")
         return wrapper_memoized.cache[cache_key]
+
     wrapper_memoized.cache = dict()
     return wrapper_memoized
 
@@ -416,7 +430,9 @@ def retry_predicate(
                 mdelay *= backoff
                 retval = f(*args, **kwargs)
             return retval
+
         return f_retry
+
     return deco_retry
 
 
@@ -443,7 +459,7 @@ def retry_if_false(tries: int, *, delay_sec=3.0, backoff=2.0):
     3
     >>> dur > 2.0
     True
-    >>> dur < 2.2
+    >>> dur < 2.3
     True
 
     """
@@ -475,13 +491,15 @@ def deprecated(func):
     when the function is used.
 
     """
+
     @functools.wraps(func)
     def wrapper_deprecated(*args, **kwargs):
         msg = f"Call to deprecated function {func.__qualname__}"
         logger.warning(msg)
-        warnings.warn(msg, category=DeprecationWarning)
+        warnings.warn(msg, category=DeprecationWarning, stacklevel=2)
         print(msg, file=sys.stderr)
         return func(*args, **kwargs)
+
     return wrapper_deprecated
 
 
@@ -507,7 +525,6 @@ def thunkify(func):
                 exc[0] = True
                 exc[1] = sys.exc_info()  # (type, value, traceback)
                 msg = f"Thunkify has thrown an exception (will be raised on thunk()):\n{traceback.format_exc()}"
-                print(msg)
                 logger.warning(msg)
             finally:
                 wait_event.set()
@@ -656,6 +673,7 @@ def timeout(
     """
     if use_signals is None:
         import thread_utils
+
         use_signals = thread_utils.is_current_thread_main_thread()
 
     def decorate(function):
@@ -709,6 +727,7 @@ class non_reentrant_code(object):
                 self._entered = True
                 f(*args, **kwargs)
                 self._entered = False
+
         return _gatekeeper
 
 
@@ -725,6 +744,7 @@ class rlocked(object):
                 self._entered = True
                 f(*args, **kwargs)
                 self._entered = False
+
         return _gatekeeper
 
 
@@ -743,7 +763,9 @@ def call_with_sample_rate(sample_rate: float) -> Callable:
                 logger.debug(
                     f"@call_with_sample_rate skipping a call to {f.__name__}"
                 )
+
         return _call_with_sample_rate
+
     return decorator
 
 
@@ -752,6 +774,7 @@ def decorate_matching_methods_with(decorator, acl=None):
     prefix.  If prefix is None (default), decorate all methods in the
     class.
     """
+
     def decorate_the_class(cls):
         for name, m in inspect.getmembers(cls, inspect.isfunction):
             if acl is None:
@@ -760,10 +783,11 @@ def decorate_matching_methods_with(decorator, acl=None):
                 if acl(name):
                     setattr(cls, name, decorator(m))
         return cls
+
     return decorate_the_class
 
 
 if __name__ == '__main__':
     import doctest
-    doctest.testmod()
 
+    doctest.testmod()