Used isort to sort imports. Also added to the git pre-commit hook.
[python_utils.git] / decorator_utils.py
index 9b848ed792144919b863b20c82e846bcd509bbe8..a5c5afecb34c005d16a351cc703f44dc561b567d 100644 (file)
@@ -14,14 +14,13 @@ import sys
 import threading
 import time
 import traceback
 import threading
 import time
 import traceback
-from typing import Any, Callable, Optional
 import warnings
 import warnings
+from typing import Any, Callable, Optional
 
 # This module is commonly used by others in here and should avoid
 # taking any unnecessary dependencies back on them.
 import exceptions
 
 
 # This module is commonly used by others in here and should avoid
 # taking any unnecessary dependencies back on them.
 import exceptions
 
-
 logger = logging.getLogger(__name__)
 
 
 logger = logging.getLogger(__name__)
 
 
@@ -48,6 +47,7 @@ def timed(func: Callable) -> Callable:
         print(msg)
         logger.info(msg)
         return value
         print(msg)
         logger.info(msg)
         return value
+
     return wrapper_timer
 
 
     return wrapper_timer
 
 
@@ -75,6 +75,7 @@ def invocation_logged(func: Callable) -> Callable:
         print(msg)
         logger.info(msg)
         return ret
         print(msg)
         logger.info(msg)
         return ret
+
     return wrapper_invocation_logged
 
 
     return wrapper_invocation_logged
 
 
@@ -152,7 +153,9 @@ def rate_limited(n_calls: int, *, per_period_in_seconds: float = 1.0) -> Callabl
                 )
                 cv.notify()
             return ret
                 )
                 cv.notify()
             return ret
+
         return wrapper_wrapper_rate_limited
         return wrapper_wrapper_rate_limited
+
     return wrapper_rate_limited
 
 
     return wrapper_rate_limited
 
 
@@ -188,6 +191,7 @@ def debug_args(func: Callable) -> Callable:
         print(msg)
         logger.info(msg)
         return value
         print(msg)
         logger.info(msg)
         return value
+
     return wrapper_debug_args
 
 
     return wrapper_debug_args
 
 
@@ -217,7 +221,8 @@ def debug_count_calls(func: Callable) -> Callable:
         print(msg)
         logger.info(msg)
         return func(*args, **kwargs)
         print(msg)
         logger.info(msg)
         return func(*args, **kwargs)
-    wrapper_debug_count_calls.num_calls = 0
+
+    wrapper_debug_count_calls.num_calls = 0  # type: ignore
     return wrapper_debug_count_calls
 
 
     return wrapper_debug_count_calls
 
 
@@ -251,21 +256,19 @@ def delay(
     True
 
     """
     True
 
     """
+
     def decorator_delay(func: Callable) -> Callable:
         @functools.wraps(func)
         def wrapper_delay(*args, **kwargs):
             if when & DelayWhen.BEFORE_CALL:
     def decorator_delay(func: Callable) -> Callable:
         @functools.wraps(func)
         def wrapper_delay(*args, **kwargs):
             if when & DelayWhen.BEFORE_CALL:
-                logger.debug(
-                    f"@delay for {seconds}s BEFORE_CALL to {func.__name__}"
-                )
+                logger.debug(f"@delay for {seconds}s BEFORE_CALL to {func.__name__}")
                 time.sleep(seconds)
             retval = func(*args, **kwargs)
             if when & DelayWhen.AFTER_CALL:
                 time.sleep(seconds)
             retval = func(*args, **kwargs)
             if when & DelayWhen.AFTER_CALL:
-                logger.debug(
-                    f"@delay for {seconds}s AFTER_CALL to {func.__name__}"
-                )
+                logger.debug(f"@delay for {seconds}s AFTER_CALL to {func.__name__}")
                 time.sleep(seconds)
             return retval
                 time.sleep(seconds)
             return retval
+
         return wrapper_delay
 
     if _func is None:
         return wrapper_delay
 
     if _func is None:
@@ -350,19 +353,19 @@ def memoized(func: Callable) -> Callable:
     True
 
     """
     True
 
     """
+
     @functools.wraps(func)
     def wrapper_memoized(*args, **kwargs):
         cache_key = args + tuple(kwargs.items())
         if cache_key not in wrapper_memoized.cache:
             value = func(*args, **kwargs)
     @functools.wraps(func)
     def wrapper_memoized(*args, **kwargs):
         cache_key = args + tuple(kwargs.items())
         if cache_key not in wrapper_memoized.cache:
             value = func(*args, **kwargs)
-            logger.debug(
-                f"Memoizing {cache_key} => {value} for {func.__name__}"
-            )
+            logger.debug(f"Memoizing {cache_key} => {value} for {func.__name__}")
             wrapper_memoized.cache[cache_key] = value
         else:
             logger.debug(f"Returning memoized value for {func.__name__}")
         return wrapper_memoized.cache[cache_key]
             wrapper_memoized.cache[cache_key] = value
         else:
             logger.debug(f"Returning memoized value for {func.__name__}")
         return wrapper_memoized.cache[cache_key]
-    wrapper_memoized.cache = dict()
+
+    wrapper_memoized.cache = dict()  # type: ignore
     return wrapper_memoized
 
 
     return wrapper_memoized
 
 
@@ -416,7 +419,9 @@ def retry_predicate(
                 mdelay *= backoff
                 retval = f(*args, **kwargs)
             return retval
                 mdelay *= backoff
                 retval = f(*args, **kwargs)
             return retval
+
         return f_retry
         return f_retry
+
     return deco_retry
 
 
     return deco_retry
 
 
@@ -443,7 +448,7 @@ def retry_if_false(tries: int, *, delay_sec=3.0, backoff=2.0):
     3
     >>> dur > 2.0
     True
     3
     >>> dur > 2.0
     True
-    >>> dur < 2.2
+    >>> dur < 2.3
     True
 
     """
     True
 
     """
@@ -475,13 +480,15 @@ def deprecated(func):
     when the function is used.
 
     """
     when the function is used.
 
     """
+
     @functools.wraps(func)
     def wrapper_deprecated(*args, **kwargs):
         msg = f"Call to deprecated function {func.__qualname__}"
         logger.warning(msg)
     @functools.wraps(func)
     def wrapper_deprecated(*args, **kwargs):
         msg = f"Call to deprecated function {func.__qualname__}"
         logger.warning(msg)
-        warnings.warn(msg, category=DeprecationWarning)
+        warnings.warn(msg, category=DeprecationWarning, stacklevel=2)
         print(msg, file=sys.stderr)
         return func(*args, **kwargs)
         print(msg, file=sys.stderr)
         return func(*args, **kwargs)
+
     return wrapper_deprecated
 
 
     return wrapper_deprecated
 
 
@@ -507,7 +514,6 @@ def thunkify(func):
                 exc[0] = True
                 exc[1] = sys.exc_info()  # (type, value, traceback)
                 msg = f"Thunkify has thrown an exception (will be raised on thunk()):\n{traceback.format_exc()}"
                 exc[0] = True
                 exc[1] = sys.exc_info()  # (type, value, traceback)
                 msg = f"Thunkify has thrown an exception (will be raised on thunk()):\n{traceback.format_exc()}"
-                print(msg)
                 logger.warning(msg)
             finally:
                 wait_event.set()
                 logger.warning(msg)
             finally:
                 wait_event.set()
@@ -535,9 +541,9 @@ def thunkify(func):
 
 def _raise_exception(exception, error_message: Optional[str]):
     if error_message is None:
 
 def _raise_exception(exception, error_message: Optional[str]):
     if error_message is None:
-        raise exception()
+        raise Exception()
     else:
     else:
-        raise exception(error_message)
+        raise Exception(error_message)
 
 
 def _target(queue, function, *args, **kwargs):
 
 
 def _target(queue, function, *args, **kwargs):
@@ -555,10 +561,10 @@ def _target(queue, function, *args, **kwargs):
 
 
 class _Timeout(object):
 
 
 class _Timeout(object):
-    """Wrap a function and add a timeout (limit) attribute to it.
+    """Wrap a function and add a timeout to it.
 
     Instances of this class are automatically generated by the add_timeout
 
     Instances of this class are automatically generated by the add_timeout
-    function defined below.
+    function defined below.  Do not use directly.
     """
 
     def __init__(
     """
 
     def __init__(
@@ -635,13 +641,28 @@ def timeout(
     main thread).  When not using signals, timeout granularity will be
     rounded to the nearest 0.1s.
 
     main thread).  When not using signals, timeout granularity will be
     rounded to the nearest 0.1s.
 
-    Raises an exception when the timeout is reached.
+    Raises an exception when/if the timeout is reached.
 
     It is illegal to pass anything other than a function as the first
     parameter.  The function is wrapped and returned to the caller.
 
     It is illegal to pass anything other than a function as the first
     parameter.  The function is wrapped and returned to the caller.
+
+    >>> @timeout(0.2)
+    ... def foo(delay: float):
+    ...     time.sleep(delay)
+    ...     return "ok"
+
+    >>> foo(0)
+    'ok'
+
+    >>> foo(1.0)
+    Traceback (most recent call last):
+    ...
+    Exception: Function call timed out
+
     """
     if use_signals is None:
         import thread_utils
     """
     if use_signals is None:
         import thread_utils
+
         use_signals = thread_utils.is_current_thread_main_thread()
 
     def decorate(function):
         use_signals = thread_utils.is_current_thread_main_thread()
 
     def decorate(function):
@@ -682,37 +703,19 @@ def timeout(
     return decorate
 
 
     return decorate
 
 
-class non_reentrant_code(object):
-    def __init__(self):
-        self._lock = threading.RLock
-        self._entered = False
-
-    def __call__(self, f):
-        def _gatekeeper(*args, **kwargs):
-            with self._lock:
-                if self._entered:
-                    return
-                self._entered = True
-                f(*args, **kwargs)
-                self._entered = False
+def synchronized(lock):
+    def wrap(f):
+        @functools.wraps(f)
+        def _gatekeeper(*args, **kw):
+            lock.acquire()
+            try:
+                return f(*args, **kw)
+            finally:
+                lock.release()
 
         return _gatekeeper
 
 
         return _gatekeeper
 
-
-class rlocked(object):
-    def __init__(self):
-        self._lock = threading.RLock
-        self._entered = False
-
-    def __call__(self, f):
-        def _gatekeeper(*args, **kwargs):
-            with self._lock:
-                if self._entered:
-                    return
-                self._entered = True
-                f(*args, **kwargs)
-                self._entered = False
-        return _gatekeeper
+    return wrap
 
 
 def call_with_sample_rate(sample_rate: float) -> Callable:
 
 
 def call_with_sample_rate(sample_rate: float) -> Callable:
@@ -727,10 +730,10 @@ def call_with_sample_rate(sample_rate: float) -> Callable:
             if random.uniform(0, 1) < sample_rate:
                 return f(*args, **kwargs)
             else:
             if random.uniform(0, 1) < sample_rate:
                 return f(*args, **kwargs)
             else:
-                logger.debug(
-                    f"@call_with_sample_rate skipping a call to {f.__name__}"
-                )
+                logger.debug(f"@call_with_sample_rate skipping a call to {f.__name__}")
+
         return _call_with_sample_rate
         return _call_with_sample_rate
+
     return decorator
 
 
     return decorator
 
 
@@ -739,6 +742,7 @@ def decorate_matching_methods_with(decorator, acl=None):
     prefix.  If prefix is None (default), decorate all methods in the
     class.
     """
     prefix.  If prefix is None (default), decorate all methods in the
     class.
     """
+
     def decorate_the_class(cls):
         for name, m in inspect.getmembers(cls, inspect.isfunction):
             if acl is None:
     def decorate_the_class(cls):
         for name, m in inspect.getmembers(cls, inspect.isfunction):
             if acl is None:
@@ -747,10 +751,11 @@ def decorate_matching_methods_with(decorator, acl=None):
                 if acl(name):
                     setattr(cls, name, decorator(m))
         return cls
                 if acl(name):
                     setattr(cls, name, decorator(m))
         return cls
+
     return decorate_the_class
 
 
 if __name__ == '__main__':
     import doctest
     return decorate_the_class
 
 
 if __name__ == '__main__':
     import doctest
-    doctest.testmod()
 
 
+    doctest.testmod()