Reduce the doctest lease duration...
[python_utils.git] / decorator_utils.py
index eb5a0c9b9ff581fc3c4398e1e1ab186ec3d00978..4615fec6f8960e0083ce48546ba9421c25243d42 100644 (file)
@@ -1,6 +1,9 @@
 #!/usr/bin/env python3
 
 #!/usr/bin/env python3
 
-"""Decorators."""
+# © Copyright 2021-2022, Scott Gasch
+# Portions (marked) below retain the original author's copyright.
+
+"""Useful(?) decorators."""
 
 import enum
 import functools
 
 import enum
 import functools
@@ -14,14 +17,13 @@ import sys
 import threading
 import time
 import traceback
 import threading
 import time
 import traceback
-from typing import Any, Callable, Optional
 import warnings
 import warnings
+from typing import Any, Callable, List, Optional
 
 # This module is commonly used by others in here and should avoid
 # taking any unnecessary dependencies back on them.
 import exceptions
 
 
 # This module is commonly used by others in here and should avoid
 # taking any unnecessary dependencies back on them.
 import exceptions
 
-
 logger = logging.getLogger(__name__)
 
 
 logger = logging.getLogger(__name__)
 
 
@@ -31,7 +33,7 @@ def timed(func: Callable) -> Callable:
     >>> @timed
     ... def foo():
     ...     import time
     >>> @timed
     ... def foo():
     ...     import time
-    ...     time.sleep(0.1)
+    ...     time.sleep(0.01)
 
     >>> foo()  # doctest: +ELLIPSIS
     Finished foo in ...
 
     >>> foo()  # doctest: +ELLIPSIS
     Finished foo in ...
@@ -48,11 +50,12 @@ def timed(func: Callable) -> Callable:
         print(msg)
         logger.info(msg)
         return value
         print(msg)
         logger.info(msg)
         return value
+
     return wrapper_timer
 
 
 def invocation_logged(func: Callable) -> Callable:
     return wrapper_timer
 
 
 def invocation_logged(func: Callable) -> Callable:
-    """Log the call of a function.
+    """Log the call of a function on stdout and the info log.
 
     >>> @invocation_logged
     ... def foo():
 
     >>> @invocation_logged
     ... def foo():
@@ -75,13 +78,14 @@ def invocation_logged(func: Callable) -> Callable:
         print(msg)
         logger.info(msg)
         return ret
         print(msg)
         logger.info(msg)
         return ret
+
     return wrapper_invocation_logged
 
 
 def rate_limited(n_calls: int, *, per_period_in_seconds: float = 1.0) -> Callable:
     return wrapper_invocation_logged
 
 
 def rate_limited(n_calls: int, *, per_period_in_seconds: float = 1.0) -> Callable:
-    """Limit invocation of a wrapped function to n calls per period.
+    """Limit invocation of a wrapped function to n calls per time period.
     Thread safe.  In testing this was relatively fair with multiple
     Thread safe.  In testing this was relatively fair with multiple
-    threads using it though that hasn't been measured.
+    threads using it though that hasn't been measured in detail.
 
     >>> import time
     >>> import decorator_utils
 
     >>> import time
     >>> import decorator_utils
@@ -132,7 +136,7 @@ def rate_limited(n_calls: int, *, per_period_in_seconds: float = 1.0) -> Callabl
                 wait_time = min_interval_seconds - elapsed_since_last
             else:
                 wait_time = 0.0
                 wait_time = min_interval_seconds - elapsed_since_last
             else:
                 wait_time = 0.0
-            logger.debug(f'@{time.time()}> wait_time = {wait_time}')
+            logger.debug('@%.4f> wait_time = %.4f', time.time(), wait_time)
             return wait_time
 
         def wrapper_wrapper_rate_limited(*args, **kargs) -> Any:
             return wait_time
 
         def wrapper_wrapper_rate_limited(*args, **kargs) -> Any:
@@ -144,15 +148,17 @@ def rate_limited(n_calls: int, *, per_period_in_seconds: float = 1.0) -> Callabl
                     ):
                         break
             with cv:
                     ):
                         break
             with cv:
-                logger.debug(f'@{time.time()}> calling it...')
+                logger.debug('@%.4f> calling it...', time.time())
                 ret = func(*args, **kargs)
                 last_invocation_timestamp[0] = time.time()
                 logger.debug(
                 ret = func(*args, **kargs)
                 last_invocation_timestamp[0] = time.time()
                 logger.debug(
-                    f'@{time.time()}> Last invocation <- {last_invocation_timestamp[0]}'
+                    '@%.4f> Last invocation <- %.4f', time.time(), last_invocation_timestamp[0]
                 )
                 cv.notify()
             return ret
                 )
                 cv.notify()
             return ret
+
         return wrapper_wrapper_rate_limited
         return wrapper_wrapper_rate_limited
+
     return wrapper_rate_limited
 
 
     return wrapper_rate_limited
 
 
@@ -188,6 +194,7 @@ def debug_args(func: Callable) -> Callable:
         print(msg)
         logger.info(msg)
         return value
         print(msg)
         logger.info(msg)
         return value
+
     return wrapper_debug_args
 
 
     return wrapper_debug_args
 
 
@@ -217,11 +224,17 @@ def debug_count_calls(func: Callable) -> Callable:
         print(msg)
         logger.info(msg)
         return func(*args, **kwargs)
         print(msg)
         logger.info(msg)
         return func(*args, **kwargs)
-    wrapper_debug_count_calls.num_calls = 0
+
+    wrapper_debug_count_calls.num_calls = 0  # type: ignore
     return wrapper_debug_count_calls
 
 
 class DelayWhen(enum.IntEnum):
     return wrapper_debug_count_calls
 
 
 class DelayWhen(enum.IntEnum):
+    """When should we delay: before or after calling the function (or
+    both)?
+
+    """
+
     BEFORE_CALL = 1
     AFTER_CALL = 2
     BEFORE_AND_AFTER = 3
     BEFORE_CALL = 1
     AFTER_CALL = 2
     BEFORE_AND_AFTER = 3
@@ -233,9 +246,7 @@ def delay(
     seconds: float = 1.0,
     when: DelayWhen = DelayWhen.BEFORE_CALL,
 ) -> Callable:
     seconds: float = 1.0,
     when: DelayWhen = DelayWhen.BEFORE_CALL,
 ) -> Callable:
-    """Delay the execution of a function by sleeping before and/or after.
-
-    Slow down a function by inserting a delay before and/or after its
+    """Slow down a function by inserting a delay before and/or after its
     invocation.
 
     >>> import time
     invocation.
 
     >>> import time
@@ -251,21 +262,19 @@ def delay(
     True
 
     """
     True
 
     """
+
     def decorator_delay(func: Callable) -> Callable:
         @functools.wraps(func)
         def wrapper_delay(*args, **kwargs):
             if when & DelayWhen.BEFORE_CALL:
     def decorator_delay(func: Callable) -> Callable:
         @functools.wraps(func)
         def wrapper_delay(*args, **kwargs):
             if when & DelayWhen.BEFORE_CALL:
-                logger.debug(
-                    f"@delay for {seconds}s BEFORE_CALL to {func.__name__}"
-                )
+                logger.debug("@delay for %fs BEFORE_CALL to %s", seconds, func.__name__)
                 time.sleep(seconds)
             retval = func(*args, **kwargs)
             if when & DelayWhen.AFTER_CALL:
                 time.sleep(seconds)
             retval = func(*args, **kwargs)
             if when & DelayWhen.AFTER_CALL:
-                logger.debug(
-                    f"@delay for {seconds}s AFTER_CALL to {func.__name__}"
-                )
+                logger.debug("@delay for %fs AFTER_CALL to %s", seconds, func.__name__)
                 time.sleep(seconds)
             return retval
                 time.sleep(seconds)
             return retval
+
         return wrapper_delay
 
     if _func is None:
         return wrapper_delay
 
     if _func is None:
@@ -287,9 +296,7 @@ class _SingletonWrapper:
 
     def __call__(self, *args, **kwargs):
         """Returns a single instance of decorated class"""
 
     def __call__(self, *args, **kwargs):
         """Returns a single instance of decorated class"""
-        logger.debug(
-            f"@singleton returning global instance of {self.__wrapped__.__name__}"
-        )
+        logger.debug('@singleton returning global instance of %s', self.__wrapped__.__name__)
         if self._instance is None:
             self._instance = self.__wrapped__(*args, **kwargs)
         return self._instance
         if self._instance is None:
             self._instance = self.__wrapped__(*args, **kwargs)
         return self._instance
@@ -321,8 +328,9 @@ def memoized(func: Callable) -> Callable:
     """Keep a cache of previous function call results.
 
     The cache here is a dict with a key based on the arguments to the
     """Keep a cache of previous function call results.
 
     The cache here is a dict with a key based on the arguments to the
-    call.  Consider also: functools.lru_cache for a more advanced
-    implementation.
+    call.  Consider also: functools.cache for a more advanced
+    implementation.  See:
+    https://docs.python.org/3/library/functools.html#functools.cache
 
     >>> import time
 
 
     >>> import time
 
@@ -350,19 +358,19 @@ def memoized(func: Callable) -> Callable:
     True
 
     """
     True
 
     """
+
     @functools.wraps(func)
     def wrapper_memoized(*args, **kwargs):
         cache_key = args + tuple(kwargs.items())
         if cache_key not in wrapper_memoized.cache:
             value = func(*args, **kwargs)
     @functools.wraps(func)
     def wrapper_memoized(*args, **kwargs):
         cache_key = args + tuple(kwargs.items())
         if cache_key not in wrapper_memoized.cache:
             value = func(*args, **kwargs)
-            logger.debug(
-                f"Memoizing {cache_key} => {value} for {func.__name__}"
-            )
+            logger.debug('Memoizing %s => %s for %s', cache_key, value, func.__name__)
             wrapper_memoized.cache[cache_key] = value
         else:
             wrapper_memoized.cache[cache_key] = value
         else:
-            logger.debug(f"Returning memoized value for {func.__name__}")
+            logger.debug('Returning memoized value for %s', {func.__name__})
         return wrapper_memoized.cache[cache_key]
         return wrapper_memoized.cache[cache_key]
-    wrapper_memoized.cache = dict()
+
+    wrapper_memoized.cache = {}  # type: ignore
     return wrapper_memoized
 
 
     return wrapper_memoized
 
 
@@ -373,17 +381,20 @@ def retry_predicate(
     delay_sec: float = 3.0,
     backoff: float = 2.0,
 ):
     delay_sec: float = 3.0,
     backoff: float = 2.0,
 ):
-    """Retries a function or method up to a certain number of times
-    with a prescribed initial delay period and backoff rate.
-
-    tries is the maximum number of attempts to run the function.
-    delay_sec sets the initial delay period in seconds.
-    backoff is a multiplied (must be >1) used to modify the delay.
-    predicate is a function that will be passed the retval of the
-    decorated function and must return True to stop or False to
-    retry.
-
+    """Retries a function or method up to a certain number of times with a
+    prescribed initial delay period and backoff rate (multiplier).
+
+    Args:
+        tries: the maximum number of attempts to run the function
+        delay_sec: sets the initial delay period in seconds
+        backoff: a multiplier (must be >=1.0) used to modify the
+            delay at each subsequent invocation
+        predicate: a Callable that will be passed the retval of
+            the decorated function and must return True to indicate
+            that we should stop calling or False to indicate a retry
+            is necessary
     """
     """
+
     if backoff < 1.0:
         msg = f"backoff must be greater than or equal to 1, got {backoff}"
         logger.critical(msg)
     if backoff < 1.0:
         msg = f"backoff must be greater than or equal to 1, got {backoff}"
         logger.critical(msg)
@@ -404,7 +415,7 @@ def retry_predicate(
         @functools.wraps(f)
         def f_retry(*args, **kwargs):
             mtries, mdelay = tries, delay_sec  # make mutable
         @functools.wraps(f)
         def f_retry(*args, **kwargs):
             mtries, mdelay = tries, delay_sec  # make mutable
-            logger.debug(f'deco_retry: will make up to {mtries} attempts...')
+            logger.debug('deco_retry: will make up to %d attempts...', mtries)
             retval = f(*args, **kwargs)
             while mtries > 0:
                 if predicate(retval) is True:
             retval = f(*args, **kwargs)
             while mtries > 0:
                 if predicate(retval) is True:
@@ -416,7 +427,9 @@ def retry_predicate(
                 mdelay *= backoff
                 retval = f(*args, **kwargs)
             return retval
                 mdelay *= backoff
                 retval = f(*args, **kwargs)
             return retval
+
         return f_retry
         return f_retry
+
     return deco_retry
 
 
     return deco_retry
 
 
@@ -427,7 +440,6 @@ def retry_if_false(tries: int, *, delay_sec=3.0, backoff=2.0):
     >>> import time
 
     >>> counter = 0
     >>> import time
 
     >>> counter = 0
-
     >>> @retry_if_false(5, delay_sec=1.0, backoff=1.1)
     ... def foo():
     ...     global counter
     >>> @retry_if_false(5, delay_sec=1.0, backoff=1.1)
     ... def foo():
     ...     global counter
@@ -459,8 +471,8 @@ def retry_if_none(tries: int, *, delay_sec=3.0, backoff=2.0):
     """Another helper for @retry_predicate above.  Retries up to N
     times so long as the wrapped function returns None with a delay
     between each retry and a backoff that can increase the delay.
     """Another helper for @retry_predicate above.  Retries up to N
     times so long as the wrapped function returns None with a delay
     between each retry and a backoff that can increase the delay.
-
     """
     """
+
     return retry_predicate(
         tries,
         predicate=lambda x: x is not None,
     return retry_predicate(
         tries,
         predicate=lambda x: x is not None,
@@ -473,8 +485,8 @@ def deprecated(func):
     """This is a decorator which can be used to mark functions
     as deprecated. It will result in a warning being emitted
     when the function is used.
     """This is a decorator which can be used to mark functions
     as deprecated. It will result in a warning being emitted
     when the function is used.
-
     """
     """
+
     @functools.wraps(func)
     def wrapper_deprecated(*args, **kwargs):
         msg = f"Call to deprecated function {func.__qualname__}"
     @functools.wraps(func)
     def wrapper_deprecated(*args, **kwargs):
         msg = f"Call to deprecated function {func.__qualname__}"
@@ -482,6 +494,7 @@ def deprecated(func):
         warnings.warn(msg, category=DeprecationWarning, stacklevel=2)
         print(msg, file=sys.stderr)
         return func(*args, **kwargs)
         warnings.warn(msg, category=DeprecationWarning, stacklevel=2)
         print(msg, file=sys.stderr)
         return func(*args, **kwargs)
+
     return wrapper_deprecated
 
 
     return wrapper_deprecated
 
 
@@ -497,7 +510,7 @@ def thunkify(func):
         wait_event = threading.Event()
 
         result = [None]
         wait_event = threading.Event()
 
         result = [None]
-        exc = [False, None]
+        exc: List[Any] = [False, None]
 
         def worker_func():
             try:
 
         def worker_func():
             try:
@@ -514,6 +527,7 @@ def thunkify(func):
         def thunk():
             wait_event.wait()
             if exc[0]:
         def thunk():
             wait_event.wait()
             if exc[0]:
+                assert exc[1]
                 raise exc[1][0](exc[1][1])
             return result[0]
 
                 raise exc[1][0](exc[1][1])
             return result[0]
 
@@ -534,7 +548,7 @@ def thunkify(func):
 
 def _raise_exception(exception, error_message: Optional[str]):
     if error_message is None:
 
 def _raise_exception(exception, error_message: Optional[str]):
     if error_message is None:
-        raise Exception()
+        raise Exception(exception)
     else:
         raise Exception(error_message)
 
     else:
         raise Exception(error_message)
 
@@ -587,9 +601,7 @@ class _Timeout(object):
         self.__limit = kwargs.pop("timeout", self.__limit)
         self.__queue = multiprocessing.Queue(1)
         args = (self.__queue, self.__function) + args
         self.__limit = kwargs.pop("timeout", self.__limit)
         self.__queue = multiprocessing.Queue(1)
         args = (self.__queue, self.__function) + args
-        self.__process = multiprocessing.Process(
-            target=_target, args=args, kwargs=kwargs
-        )
+        self.__process = multiprocessing.Process(target=_target, args=args, kwargs=kwargs)
         self.__process.daemon = True
         self.__process.start()
         if self.__limit is not None:
         self.__process.daemon = True
         self.__process.start()
         if self.__limit is not None:
@@ -619,6 +631,7 @@ class _Timeout(object):
             if flag:
                 return load
             raise load
             if flag:
                 return load
             raise load
+        return None
 
 
 def timeout(
 
 
 def timeout(
@@ -634,6 +647,13 @@ def timeout(
     main thread).  When not using signals, timeout granularity will be
     rounded to the nearest 0.1s.
 
     main thread).  When not using signals, timeout granularity will be
     rounded to the nearest 0.1s.
 
+    Beware that an @timeout on a function inside a module will be
+    evaluated at module load time and not when the wrapped function is
+    invoked.  This can lead to problems when relying on the automatic
+    main thread detection code (use_signals=None, the default) since
+    the import probably happens on the main thread and the invocation
+    can happen on a different thread (which can't use signals).
+
     Raises an exception when/if the timeout is reached.
 
     It is illegal to pass anything other than a function as the first
     Raises an exception when/if the timeout is reached.
 
     It is illegal to pass anything other than a function as the first
@@ -655,12 +675,13 @@ def timeout(
     """
     if use_signals is None:
         import thread_utils
     """
     if use_signals is None:
         import thread_utils
+
         use_signals = thread_utils.is_current_thread_main_thread()
 
     def decorate(function):
         if use_signals:
 
         use_signals = thread_utils.is_current_thread_main_thread()
 
     def decorate(function):
         if use_signals:
 
-            def handler(signum, frame):
+            def handler(unused_signum, unused_frame):
                 _raise_exception(timeout_exception, error_message)
 
             @functools.wraps(function)
                 _raise_exception(timeout_exception, error_message)
 
             @functools.wraps(function)
@@ -685,9 +706,7 @@ def timeout(
 
             @functools.wraps(function)
             def new_function(*args, **kwargs):
 
             @functools.wraps(function)
             def new_function(*args, **kwargs):
-                timeout_wrapper = _Timeout(
-                    function, timeout_exception, error_message, seconds
-                )
+                timeout_wrapper = _Timeout(function, timeout_exception, error_message, seconds)
                 return timeout_wrapper(*args, **kwargs)
 
             return new_function
                 return timeout_wrapper(*args, **kwargs)
 
             return new_function
@@ -695,39 +714,31 @@ def timeout(
     return decorate
 
 
     return decorate
 
 
-class non_reentrant_code(object):
-    def __init__(self):
-        self._lock = threading.RLock
-        self._entered = False
-
-    def __call__(self, f):
-        def _gatekeeper(*args, **kwargs):
-            with self._lock:
-                if self._entered:
-                    return
-                self._entered = True
-                f(*args, **kwargs)
-                self._entered = False
-        return _gatekeeper
-
+def synchronized(lock):
+    """Emulates java's synchronized keyword: given a lock, require that
+    threads take that lock (or wait) before invoking the wrapped
+    function and automatically releases the lock afterwards.
+    """
 
 
-class rlocked(object):
-    def __init__(self):
-        self._lock = threading.RLock
-        self._entered = False
+    def wrap(f):
+        @functools.wraps(f)
+        def _gatekeeper(*args, **kw):
+            lock.acquire()
+            try:
+                return f(*args, **kw)
+            finally:
+                lock.release()
 
 
-    def __call__(self, f):
-        def _gatekeeper(*args, **kwargs):
-            with self._lock:
-                if self._entered:
-                    return
-                self._entered = True
-                f(*args, **kwargs)
-                self._entered = False
         return _gatekeeper
 
         return _gatekeeper
 
+    return wrap
+
 
 def call_with_sample_rate(sample_rate: float) -> Callable:
 
 def call_with_sample_rate(sample_rate: float) -> Callable:
+    """Calls the wrapped function probabilistically given a rate between
+    0.0 and 1.0 inclusive (0% probability and 100% probability).
+    """
+
     if not 0.0 <= sample_rate <= 1.0:
         msg = f"sample_rate must be between [0, 1]. Got {sample_rate}."
         logger.critical(msg)
     if not 0.0 <= sample_rate <= 1.0:
         msg = f"sample_rate must be between [0, 1]. Got {sample_rate}."
         logger.critical(msg)
@@ -739,18 +750,20 @@ def call_with_sample_rate(sample_rate: float) -> Callable:
             if random.uniform(0, 1) < sample_rate:
                 return f(*args, **kwargs)
             else:
             if random.uniform(0, 1) < sample_rate:
                 return f(*args, **kwargs)
             else:
-                logger.debug(
-                    f"@call_with_sample_rate skipping a call to {f.__name__}"
-                )
+                logger.debug("@call_with_sample_rate skipping a call to %s", f.__name__)
+                return None
+
         return _call_with_sample_rate
         return _call_with_sample_rate
+
     return decorator
 
 
 def decorate_matching_methods_with(decorator, acl=None):
     return decorator
 
 
 def decorate_matching_methods_with(decorator, acl=None):
-    """Apply decorator to all methods in a class whose names begin with
-    prefix.  If prefix is None (default), decorate all methods in the
-    class.
+    """Apply the given decorator to all methods in a class whose names
+    begin with prefix.  If prefix is None (default), decorate all
+    methods in the class.
     """
     """
+
     def decorate_the_class(cls):
         for name, m in inspect.getmembers(cls, inspect.isfunction):
             if acl is None:
     def decorate_the_class(cls):
         for name, m in inspect.getmembers(cls, inspect.isfunction):
             if acl is None:
@@ -759,10 +772,11 @@ def decorate_matching_methods_with(decorator, acl=None):
                 if acl(name):
                     setattr(cls, name, decorator(m))
         return cls
                 if acl(name):
                     setattr(cls, name, decorator(m))
         return cls
+
     return decorate_the_class
 
 
 if __name__ == '__main__':
     import doctest
     return decorate_the_class
 
 
 if __name__ == '__main__':
     import doctest
-    doctest.testmod()
 
 
+    doctest.testmod()