Reduce the doctest lease duration...
[python_utils.git] / decorator_utils.py
index daae64e75348e973dc8a27cf387faf7f404ef2b2..4615fec6f8960e0083ce48546ba9421c25243d42 100644 (file)
@@ -1,6 +1,9 @@
 #!/usr/bin/env python3
 
 #!/usr/bin/env python3
 
-"""Decorators."""
+# © Copyright 2021-2022, Scott Gasch
+# Portions (marked) below retain the original author's copyright.
+
+"""Useful(?) decorators."""
 
 import enum
 import functools
 
 import enum
 import functools
@@ -14,14 +17,13 @@ import sys
 import threading
 import time
 import traceback
 import threading
 import time
 import traceback
-from typing import Any, Callable, Optional
 import warnings
 import warnings
+from typing import Any, Callable, List, Optional
 
 # This module is commonly used by others in here and should avoid
 # taking any unnecessary dependencies back on them.
 import exceptions
 
 
 # This module is commonly used by others in here and should avoid
 # taking any unnecessary dependencies back on them.
 import exceptions
 
-
 logger = logging.getLogger(__name__)
 
 
 logger = logging.getLogger(__name__)
 
 
@@ -31,7 +33,7 @@ def timed(func: Callable) -> Callable:
     >>> @timed
     ... def foo():
     ...     import time
     >>> @timed
     ... def foo():
     ...     import time
-    ...     time.sleep(0.1)
+    ...     time.sleep(0.01)
 
     >>> foo()  # doctest: +ELLIPSIS
     Finished foo in ...
 
     >>> foo()  # doctest: +ELLIPSIS
     Finished foo in ...
@@ -53,7 +55,7 @@ def timed(func: Callable) -> Callable:
 
 
 def invocation_logged(func: Callable) -> Callable:
 
 
 def invocation_logged(func: Callable) -> Callable:
-    """Log the call of a function.
+    """Log the call of a function on stdout and the info log.
 
     >>> @invocation_logged
     ... def foo():
 
     >>> @invocation_logged
     ... def foo():
@@ -80,12 +82,10 @@ def invocation_logged(func: Callable) -> Callable:
     return wrapper_invocation_logged
 
 
     return wrapper_invocation_logged
 
 
-def rate_limited(
-    n_calls: int, *, per_period_in_seconds: float = 1.0
-) -> Callable:
-    """Limit invocation of a wrapped function to n calls per period.
+def rate_limited(n_calls: int, *, per_period_in_seconds: float = 1.0) -> Callable:
+    """Limit invocation of a wrapped function to n calls per time period.
     Thread safe.  In testing this was relatively fair with multiple
     Thread safe.  In testing this was relatively fair with multiple
-    threads using it though that hasn't been measured.
+    threads using it though that hasn't been measured in detail.
 
     >>> import time
     >>> import decorator_utils
 
     >>> import time
     >>> import decorator_utils
@@ -136,7 +136,7 @@ def rate_limited(
                 wait_time = min_interval_seconds - elapsed_since_last
             else:
                 wait_time = 0.0
                 wait_time = min_interval_seconds - elapsed_since_last
             else:
                 wait_time = 0.0
-            logger.debug(f'@{time.time()}> wait_time = {wait_time}')
+            logger.debug('@%.4f> wait_time = %.4f', time.time(), wait_time)
             return wait_time
 
         def wrapper_wrapper_rate_limited(*args, **kargs) -> Any:
             return wait_time
 
         def wrapper_wrapper_rate_limited(*args, **kargs) -> Any:
@@ -148,11 +148,11 @@ def rate_limited(
                     ):
                         break
             with cv:
                     ):
                         break
             with cv:
-                logger.debug(f'@{time.time()}> calling it...')
+                logger.debug('@%.4f> calling it...', time.time())
                 ret = func(*args, **kargs)
                 last_invocation_timestamp[0] = time.time()
                 logger.debug(
                 ret = func(*args, **kargs)
                 last_invocation_timestamp[0] = time.time()
                 logger.debug(
-                    f'@{time.time()}> Last invocation <- {last_invocation_timestamp[0]}'
+                    '@%.4f> Last invocation <- %.4f', time.time(), last_invocation_timestamp[0]
                 )
                 cv.notify()
             return ret
                 )
                 cv.notify()
             return ret
@@ -220,18 +220,21 @@ def debug_count_calls(func: Callable) -> Callable:
     @functools.wraps(func)
     def wrapper_debug_count_calls(*args, **kwargs):
         wrapper_debug_count_calls.num_calls += 1
     @functools.wraps(func)
     def wrapper_debug_count_calls(*args, **kwargs):
         wrapper_debug_count_calls.num_calls += 1
-        msg = (
-            f"Call #{wrapper_debug_count_calls.num_calls} of {func.__name__!r}"
-        )
+        msg = f"Call #{wrapper_debug_count_calls.num_calls} of {func.__name__!r}"
         print(msg)
         logger.info(msg)
         return func(*args, **kwargs)
 
         print(msg)
         logger.info(msg)
         return func(*args, **kwargs)
 
-    wrapper_debug_count_calls.num_calls = 0
+    wrapper_debug_count_calls.num_calls = 0  # type: ignore
     return wrapper_debug_count_calls
 
 
 class DelayWhen(enum.IntEnum):
     return wrapper_debug_count_calls
 
 
 class DelayWhen(enum.IntEnum):
+    """When should we delay: before or after calling the function (or
+    both)?
+
+    """
+
     BEFORE_CALL = 1
     AFTER_CALL = 2
     BEFORE_AND_AFTER = 3
     BEFORE_CALL = 1
     AFTER_CALL = 2
     BEFORE_AND_AFTER = 3
@@ -243,9 +246,7 @@ def delay(
     seconds: float = 1.0,
     when: DelayWhen = DelayWhen.BEFORE_CALL,
 ) -> Callable:
     seconds: float = 1.0,
     when: DelayWhen = DelayWhen.BEFORE_CALL,
 ) -> Callable:
-    """Delay the execution of a function by sleeping before and/or after.
-
-    Slow down a function by inserting a delay before and/or after its
+    """Slow down a function by inserting a delay before and/or after its
     invocation.
 
     >>> import time
     invocation.
 
     >>> import time
@@ -266,15 +267,11 @@ def delay(
         @functools.wraps(func)
         def wrapper_delay(*args, **kwargs):
             if when & DelayWhen.BEFORE_CALL:
         @functools.wraps(func)
         def wrapper_delay(*args, **kwargs):
             if when & DelayWhen.BEFORE_CALL:
-                logger.debug(
-                    f"@delay for {seconds}s BEFORE_CALL to {func.__name__}"
-                )
+                logger.debug("@delay for %fs BEFORE_CALL to %s", seconds, func.__name__)
                 time.sleep(seconds)
             retval = func(*args, **kwargs)
             if when & DelayWhen.AFTER_CALL:
                 time.sleep(seconds)
             retval = func(*args, **kwargs)
             if when & DelayWhen.AFTER_CALL:
-                logger.debug(
-                    f"@delay for {seconds}s AFTER_CALL to {func.__name__}"
-                )
+                logger.debug("@delay for %fs AFTER_CALL to %s", seconds, func.__name__)
                 time.sleep(seconds)
             return retval
 
                 time.sleep(seconds)
             return retval
 
@@ -299,9 +296,7 @@ class _SingletonWrapper:
 
     def __call__(self, *args, **kwargs):
         """Returns a single instance of decorated class"""
 
     def __call__(self, *args, **kwargs):
         """Returns a single instance of decorated class"""
-        logger.debug(
-            f"@singleton returning global instance of {self.__wrapped__.__name__}"
-        )
+        logger.debug('@singleton returning global instance of %s', self.__wrapped__.__name__)
         if self._instance is None:
             self._instance = self.__wrapped__(*args, **kwargs)
         return self._instance
         if self._instance is None:
             self._instance = self.__wrapped__(*args, **kwargs)
         return self._instance
@@ -333,8 +328,9 @@ def memoized(func: Callable) -> Callable:
     """Keep a cache of previous function call results.
 
     The cache here is a dict with a key based on the arguments to the
     """Keep a cache of previous function call results.
 
     The cache here is a dict with a key based on the arguments to the
-    call.  Consider also: functools.lru_cache for a more advanced
-    implementation.
+    call.  Consider also: functools.cache for a more advanced
+    implementation.  See:
+    https://docs.python.org/3/library/functools.html#functools.cache
 
     >>> import time
 
 
     >>> import time
 
@@ -368,15 +364,13 @@ def memoized(func: Callable) -> Callable:
         cache_key = args + tuple(kwargs.items())
         if cache_key not in wrapper_memoized.cache:
             value = func(*args, **kwargs)
         cache_key = args + tuple(kwargs.items())
         if cache_key not in wrapper_memoized.cache:
             value = func(*args, **kwargs)
-            logger.debug(
-                f"Memoizing {cache_key} => {value} for {func.__name__}"
-            )
+            logger.debug('Memoizing %s => %s for %s', cache_key, value, func.__name__)
             wrapper_memoized.cache[cache_key] = value
         else:
             wrapper_memoized.cache[cache_key] = value
         else:
-            logger.debug(f"Returning memoized value for {func.__name__}")
+            logger.debug('Returning memoized value for %s', {func.__name__})
         return wrapper_memoized.cache[cache_key]
 
         return wrapper_memoized.cache[cache_key]
 
-    wrapper_memoized.cache = dict()
+    wrapper_memoized.cache = {}  # type: ignore
     return wrapper_memoized
 
 
     return wrapper_memoized
 
 
@@ -387,17 +381,20 @@ def retry_predicate(
     delay_sec: float = 3.0,
     backoff: float = 2.0,
 ):
     delay_sec: float = 3.0,
     backoff: float = 2.0,
 ):
-    """Retries a function or method up to a certain number of times
-    with a prescribed initial delay period and backoff rate.
-
-    tries is the maximum number of attempts to run the function.
-    delay_sec sets the initial delay period in seconds.
-    backoff is a multiplied (must be >1) used to modify the delay.
-    predicate is a function that will be passed the retval of the
-    decorated function and must return True to stop or False to
-    retry.
-
+    """Retries a function or method up to a certain number of times with a
+    prescribed initial delay period and backoff rate (multiplier).
+
+    Args:
+        tries: the maximum number of attempts to run the function
+        delay_sec: sets the initial delay period in seconds
+        backoff: a multiplier (must be >=1.0) used to modify the
+            delay at each subsequent invocation
+        predicate: a Callable that will be passed the retval of
+            the decorated function and must return True to indicate
+            that we should stop calling or False to indicate a retry
+            is necessary
     """
     """
+
     if backoff < 1.0:
         msg = f"backoff must be greater than or equal to 1, got {backoff}"
         logger.critical(msg)
     if backoff < 1.0:
         msg = f"backoff must be greater than or equal to 1, got {backoff}"
         logger.critical(msg)
@@ -418,7 +415,7 @@ def retry_predicate(
         @functools.wraps(f)
         def f_retry(*args, **kwargs):
             mtries, mdelay = tries, delay_sec  # make mutable
         @functools.wraps(f)
         def f_retry(*args, **kwargs):
             mtries, mdelay = tries, delay_sec  # make mutable
-            logger.debug(f'deco_retry: will make up to {mtries} attempts...')
+            logger.debug('deco_retry: will make up to %d attempts...', mtries)
             retval = f(*args, **kwargs)
             while mtries > 0:
                 if predicate(retval) is True:
             retval = f(*args, **kwargs)
             while mtries > 0:
                 if predicate(retval) is True:
@@ -443,7 +440,6 @@ def retry_if_false(tries: int, *, delay_sec=3.0, backoff=2.0):
     >>> import time
 
     >>> counter = 0
     >>> import time
 
     >>> counter = 0
-
     >>> @retry_if_false(5, delay_sec=1.0, backoff=1.1)
     ... def foo():
     ...     global counter
     >>> @retry_if_false(5, delay_sec=1.0, backoff=1.1)
     ... def foo():
     ...     global counter
@@ -475,8 +471,8 @@ def retry_if_none(tries: int, *, delay_sec=3.0, backoff=2.0):
     """Another helper for @retry_predicate above.  Retries up to N
     times so long as the wrapped function returns None with a delay
     between each retry and a backoff that can increase the delay.
     """Another helper for @retry_predicate above.  Retries up to N
     times so long as the wrapped function returns None with a delay
     between each retry and a backoff that can increase the delay.
-
     """
     """
+
     return retry_predicate(
         tries,
         predicate=lambda x: x is not None,
     return retry_predicate(
         tries,
         predicate=lambda x: x is not None,
@@ -489,7 +485,6 @@ def deprecated(func):
     """This is a decorator which can be used to mark functions
     as deprecated. It will result in a warning being emitted
     when the function is used.
     """This is a decorator which can be used to mark functions
     as deprecated. It will result in a warning being emitted
     when the function is used.
-
     """
 
     @functools.wraps(func)
     """
 
     @functools.wraps(func)
@@ -515,7 +510,7 @@ def thunkify(func):
         wait_event = threading.Event()
 
         result = [None]
         wait_event = threading.Event()
 
         result = [None]
-        exc = [False, None]
+        exc: List[Any] = [False, None]
 
         def worker_func():
             try:
 
         def worker_func():
             try:
@@ -532,6 +527,7 @@ def thunkify(func):
         def thunk():
             wait_event.wait()
             if exc[0]:
         def thunk():
             wait_event.wait()
             if exc[0]:
+                assert exc[1]
                 raise exc[1][0](exc[1][1])
             return result[0]
 
                 raise exc[1][0](exc[1][1])
             return result[0]
 
@@ -552,7 +548,7 @@ def thunkify(func):
 
 def _raise_exception(exception, error_message: Optional[str]):
     if error_message is None:
 
 def _raise_exception(exception, error_message: Optional[str]):
     if error_message is None:
-        raise Exception()
+        raise Exception(exception)
     else:
         raise Exception(error_message)
 
     else:
         raise Exception(error_message)
 
@@ -605,9 +601,7 @@ class _Timeout(object):
         self.__limit = kwargs.pop("timeout", self.__limit)
         self.__queue = multiprocessing.Queue(1)
         args = (self.__queue, self.__function) + args
         self.__limit = kwargs.pop("timeout", self.__limit)
         self.__queue = multiprocessing.Queue(1)
         args = (self.__queue, self.__function) + args
-        self.__process = multiprocessing.Process(
-            target=_target, args=args, kwargs=kwargs
-        )
+        self.__process = multiprocessing.Process(target=_target, args=args, kwargs=kwargs)
         self.__process.daemon = True
         self.__process.start()
         if self.__limit is not None:
         self.__process.daemon = True
         self.__process.start()
         if self.__limit is not None:
@@ -637,6 +631,7 @@ class _Timeout(object):
             if flag:
                 return load
             raise load
             if flag:
                 return load
             raise load
+        return None
 
 
 def timeout(
 
 
 def timeout(
@@ -652,6 +647,13 @@ def timeout(
     main thread).  When not using signals, timeout granularity will be
     rounded to the nearest 0.1s.
 
     main thread).  When not using signals, timeout granularity will be
     rounded to the nearest 0.1s.
 
+    Beware that an @timeout on a function inside a module will be
+    evaluated at module load time and not when the wrapped function is
+    invoked.  This can lead to problems when relying on the automatic
+    main thread detection code (use_signals=None, the default) since
+    the import probably happens on the main thread and the invocation
+    can happen on a different thread (which can't use signals).
+
     Raises an exception when/if the timeout is reached.
 
     It is illegal to pass anything other than a function as the first
     Raises an exception when/if the timeout is reached.
 
     It is illegal to pass anything other than a function as the first
@@ -679,7 +681,7 @@ def timeout(
     def decorate(function):
         if use_signals:
 
     def decorate(function):
         if use_signals:
 
-            def handler(signum, frame):
+            def handler(unused_signum, unused_frame):
                 _raise_exception(timeout_exception, error_message)
 
             @functools.wraps(function)
                 _raise_exception(timeout_exception, error_message)
 
             @functools.wraps(function)
@@ -704,9 +706,7 @@ def timeout(
 
             @functools.wraps(function)
             def new_function(*args, **kwargs):
 
             @functools.wraps(function)
             def new_function(*args, **kwargs):
-                timeout_wrapper = _Timeout(
-                    function, timeout_exception, error_message, seconds
-                )
+                timeout_wrapper = _Timeout(function, timeout_exception, error_message, seconds)
                 return timeout_wrapper(*args, **kwargs)
 
             return new_function
                 return timeout_wrapper(*args, **kwargs)
 
             return new_function
@@ -714,41 +714,31 @@ def timeout(
     return decorate
 
 
     return decorate
 
 
-class non_reentrant_code(object):
-    def __init__(self):
-        self._lock = threading.RLock
-        self._entered = False
+def synchronized(lock):
+    """Emulates java's synchronized keyword: given a lock, require that
+    threads take that lock (or wait) before invoking the wrapped
+    function and automatically releases the lock afterwards.
+    """
 
 
-    def __call__(self, f):
-        def _gatekeeper(*args, **kwargs):
-            with self._lock:
-                if self._entered:
-                    return
-                self._entered = True
-                f(*args, **kwargs)
-                self._entered = False
+    def wrap(f):
+        @functools.wraps(f)
+        def _gatekeeper(*args, **kw):
+            lock.acquire()
+            try:
+                return f(*args, **kw)
+            finally:
+                lock.release()
 
         return _gatekeeper
 
 
         return _gatekeeper
 
-
-class rlocked(object):
-    def __init__(self):
-        self._lock = threading.RLock
-        self._entered = False
-
-    def __call__(self, f):
-        def _gatekeeper(*args, **kwargs):
-            with self._lock:
-                if self._entered:
-                    return
-                self._entered = True
-                f(*args, **kwargs)
-                self._entered = False
-
-        return _gatekeeper
+    return wrap
 
 
 def call_with_sample_rate(sample_rate: float) -> Callable:
 
 
 def call_with_sample_rate(sample_rate: float) -> Callable:
+    """Calls the wrapped function probabilistically given a rate between
+    0.0 and 1.0 inclusive (0% probability and 100% probability).
+    """
+
     if not 0.0 <= sample_rate <= 1.0:
         msg = f"sample_rate must be between [0, 1]. Got {sample_rate}."
         logger.critical(msg)
     if not 0.0 <= sample_rate <= 1.0:
         msg = f"sample_rate must be between [0, 1]. Got {sample_rate}."
         logger.critical(msg)
@@ -760,9 +750,8 @@ def call_with_sample_rate(sample_rate: float) -> Callable:
             if random.uniform(0, 1) < sample_rate:
                 return f(*args, **kwargs)
             else:
             if random.uniform(0, 1) < sample_rate:
                 return f(*args, **kwargs)
             else:
-                logger.debug(
-                    f"@call_with_sample_rate skipping a call to {f.__name__}"
-                )
+                logger.debug("@call_with_sample_rate skipping a call to %s", f.__name__)
+                return None
 
         return _call_with_sample_rate
 
 
         return _call_with_sample_rate
 
@@ -770,9 +759,9 @@ def call_with_sample_rate(sample_rate: float) -> Callable:
 
 
 def decorate_matching_methods_with(decorator, acl=None):
 
 
 def decorate_matching_methods_with(decorator, acl=None):
-    """Apply decorator to all methods in a class whose names begin with
-    prefix.  If prefix is None (default), decorate all methods in the
-    class.
+    """Apply the given decorator to all methods in a class whose names
+    begin with prefix.  If prefix is None (default), decorate all
+    methods in the class.
     """
 
     def decorate_the_class(cls):
     """
 
     def decorate_the_class(cls):