Easier and more self documenting patterns for loading/saving Persistent
[python_utils.git] / decorator_utils.py
index 07ad881f63a613de38d82d9a54babce92127b1b5..4615fec6f8960e0083ce48546ba9421c25243d42 100644 (file)
@@ -1,6 +1,9 @@
 #!/usr/bin/env python3
 
-"""Decorators."""
+# © Copyright 2021-2022, Scott Gasch
+# Portions (marked) below retain the original author's copyright.
+
+"""Useful(?) decorators."""
 
 import enum
 import functools
@@ -14,14 +17,13 @@ import sys
 import threading
 import time
 import traceback
-from typing import Any, Callable, Optional
 import warnings
+from typing import Any, Callable, List, Optional
 
 # This module is commonly used by others in here and should avoid
 # taking any unnecessary dependencies back on them.
 import exceptions
 
-
 logger = logging.getLogger(__name__)
 
 
@@ -31,7 +33,7 @@ def timed(func: Callable) -> Callable:
     >>> @timed
     ... def foo():
     ...     import time
-    ...     time.sleep(0.1)
+    ...     time.sleep(0.01)
 
     >>> foo()  # doctest: +ELLIPSIS
     Finished foo in ...
@@ -48,11 +50,12 @@ def timed(func: Callable) -> Callable:
         print(msg)
         logger.info(msg)
         return value
+
     return wrapper_timer
 
 
 def invocation_logged(func: Callable) -> Callable:
-    """Log the call of a function.
+    """Log the call of a function on stdout and the info log.
 
     >>> @invocation_logged
     ... def foo():
@@ -75,13 +78,14 @@ def invocation_logged(func: Callable) -> Callable:
         print(msg)
         logger.info(msg)
         return ret
+
     return wrapper_invocation_logged
 
 
 def rate_limited(n_calls: int, *, per_period_in_seconds: float = 1.0) -> Callable:
-    """Limit invocation of a wrapped function to n calls per period.
+    """Limit invocation of a wrapped function to n calls per time period.
     Thread safe.  In testing this was relatively fair with multiple
-    threads using it though that hasn't been measured.
+    threads using it though that hasn't been measured in detail.
 
     >>> import time
     >>> import decorator_utils
@@ -132,7 +136,7 @@ def rate_limited(n_calls: int, *, per_period_in_seconds: float = 1.0) -> Callabl
                 wait_time = min_interval_seconds - elapsed_since_last
             else:
                 wait_time = 0.0
-            logger.debug(f'@{time.time()}> wait_time = {wait_time}')
+            logger.debug('@%.4f> wait_time = %.4f', time.time(), wait_time)
             return wait_time
 
         def wrapper_wrapper_rate_limited(*args, **kargs) -> Any:
@@ -144,15 +148,17 @@ def rate_limited(n_calls: int, *, per_period_in_seconds: float = 1.0) -> Callabl
                     ):
                         break
             with cv:
-                logger.debug(f'@{time.time()}> calling it...')
+                logger.debug('@%.4f> calling it...', time.time())
                 ret = func(*args, **kargs)
                 last_invocation_timestamp[0] = time.time()
                 logger.debug(
-                    f'@{time.time()}> Last invocation <- {last_invocation_timestamp[0]}'
+                    '@%.4f> Last invocation <- %.4f', time.time(), last_invocation_timestamp[0]
                 )
                 cv.notify()
             return ret
+
         return wrapper_wrapper_rate_limited
+
     return wrapper_rate_limited
 
 
@@ -188,6 +194,7 @@ def debug_args(func: Callable) -> Callable:
         print(msg)
         logger.info(msg)
         return value
+
     return wrapper_debug_args
 
 
@@ -217,11 +224,17 @@ def debug_count_calls(func: Callable) -> Callable:
         print(msg)
         logger.info(msg)
         return func(*args, **kwargs)
-    wrapper_debug_count_calls.num_calls = 0
+
+    wrapper_debug_count_calls.num_calls = 0  # type: ignore
     return wrapper_debug_count_calls
 
 
 class DelayWhen(enum.IntEnum):
+    """When should we delay: before or after calling the function (or
+    both)?
+
+    """
+
     BEFORE_CALL = 1
     AFTER_CALL = 2
     BEFORE_AND_AFTER = 3
@@ -233,9 +246,7 @@ def delay(
     seconds: float = 1.0,
     when: DelayWhen = DelayWhen.BEFORE_CALL,
 ) -> Callable:
-    """Delay the execution of a function by sleeping before and/or after.
-
-    Slow down a function by inserting a delay before and/or after its
+    """Slow down a function by inserting a delay before and/or after its
     invocation.
 
     >>> import time
@@ -251,21 +262,19 @@ def delay(
     True
 
     """
+
     def decorator_delay(func: Callable) -> Callable:
         @functools.wraps(func)
         def wrapper_delay(*args, **kwargs):
             if when & DelayWhen.BEFORE_CALL:
-                logger.debug(
-                    f"@delay for {seconds}s BEFORE_CALL to {func.__name__}"
-                )
+                logger.debug("@delay for %fs BEFORE_CALL to %s", seconds, func.__name__)
                 time.sleep(seconds)
             retval = func(*args, **kwargs)
             if when & DelayWhen.AFTER_CALL:
-                logger.debug(
-                    f"@delay for {seconds}s AFTER_CALL to {func.__name__}"
-                )
+                logger.debug("@delay for %fs AFTER_CALL to %s", seconds, func.__name__)
                 time.sleep(seconds)
             return retval
+
         return wrapper_delay
 
     if _func is None:
@@ -287,9 +296,7 @@ class _SingletonWrapper:
 
     def __call__(self, *args, **kwargs):
         """Returns a single instance of decorated class"""
-        logger.debug(
-            f"@singleton returning global instance of {self.__wrapped__.__name__}"
-        )
+        logger.debug('@singleton returning global instance of %s', self.__wrapped__.__name__)
         if self._instance is None:
             self._instance = self.__wrapped__(*args, **kwargs)
         return self._instance
@@ -321,8 +328,9 @@ def memoized(func: Callable) -> Callable:
     """Keep a cache of previous function call results.
 
     The cache here is a dict with a key based on the arguments to the
-    call.  Consider also: functools.lru_cache for a more advanced
-    implementation.
+    call.  Consider also: functools.cache for a more advanced
+    implementation.  See:
+    https://docs.python.org/3/library/functools.html#functools.cache
 
     >>> import time
 
@@ -350,19 +358,19 @@ def memoized(func: Callable) -> Callable:
     True
 
     """
+
     @functools.wraps(func)
     def wrapper_memoized(*args, **kwargs):
         cache_key = args + tuple(kwargs.items())
         if cache_key not in wrapper_memoized.cache:
             value = func(*args, **kwargs)
-            logger.debug(
-                f"Memoizing {cache_key} => {value} for {func.__name__}"
-            )
+            logger.debug('Memoizing %s => %s for %s', cache_key, value, func.__name__)
             wrapper_memoized.cache[cache_key] = value
         else:
-            logger.debug(f"Returning memoized value for {func.__name__}")
+            logger.debug('Returning memoized value for %s', {func.__name__})
         return wrapper_memoized.cache[cache_key]
-    wrapper_memoized.cache = dict()
+
+    wrapper_memoized.cache = {}  # type: ignore
     return wrapper_memoized
 
 
@@ -373,17 +381,20 @@ def retry_predicate(
     delay_sec: float = 3.0,
     backoff: float = 2.0,
 ):
-    """Retries a function or method up to a certain number of times
-    with a prescribed initial delay period and backoff rate.
-
-    tries is the maximum number of attempts to run the function.
-    delay_sec sets the initial delay period in seconds.
-    backoff is a multiplied (must be >1) used to modify the delay.
-    predicate is a function that will be passed the retval of the
-    decorated function and must return True to stop or False to
-    retry.
-
+    """Retries a function or method up to a certain number of times with a
+    prescribed initial delay period and backoff rate (multiplier).
+
+    Args:
+        tries: the maximum number of attempts to run the function
+        delay_sec: sets the initial delay period in seconds
+        backoff: a multiplier (must be >=1.0) used to modify the
+            delay at each subsequent invocation
+        predicate: a Callable that will be passed the retval of
+            the decorated function and must return True to indicate
+            that we should stop calling or False to indicate a retry
+            is necessary
     """
+
     if backoff < 1.0:
         msg = f"backoff must be greater than or equal to 1, got {backoff}"
         logger.critical(msg)
@@ -404,7 +415,7 @@ def retry_predicate(
         @functools.wraps(f)
         def f_retry(*args, **kwargs):
             mtries, mdelay = tries, delay_sec  # make mutable
-            logger.debug(f'deco_retry: will make up to {mtries} attempts...')
+            logger.debug('deco_retry: will make up to %d attempts...', mtries)
             retval = f(*args, **kwargs)
             while mtries > 0:
                 if predicate(retval) is True:
@@ -416,7 +427,9 @@ def retry_predicate(
                 mdelay *= backoff
                 retval = f(*args, **kwargs)
             return retval
+
         return f_retry
+
     return deco_retry
 
 
@@ -427,7 +440,6 @@ def retry_if_false(tries: int, *, delay_sec=3.0, backoff=2.0):
     >>> import time
 
     >>> counter = 0
-
     >>> @retry_if_false(5, delay_sec=1.0, backoff=1.1)
     ... def foo():
     ...     global counter
@@ -443,7 +455,7 @@ def retry_if_false(tries: int, *, delay_sec=3.0, backoff=2.0):
     3
     >>> dur > 2.0
     True
-    >>> dur < 2.2
+    >>> dur < 2.3
     True
 
     """
@@ -459,8 +471,8 @@ def retry_if_none(tries: int, *, delay_sec=3.0, backoff=2.0):
     """Another helper for @retry_predicate above.  Retries up to N
     times so long as the wrapped function returns None with a delay
     between each retry and a backoff that can increase the delay.
-
     """
+
     return retry_predicate(
         tries,
         predicate=lambda x: x is not None,
@@ -473,8 +485,8 @@ def deprecated(func):
     """This is a decorator which can be used to mark functions
     as deprecated. It will result in a warning being emitted
     when the function is used.
-
     """
+
     @functools.wraps(func)
     def wrapper_deprecated(*args, **kwargs):
         msg = f"Call to deprecated function {func.__qualname__}"
@@ -482,6 +494,7 @@ def deprecated(func):
         warnings.warn(msg, category=DeprecationWarning, stacklevel=2)
         print(msg, file=sys.stderr)
         return func(*args, **kwargs)
+
     return wrapper_deprecated
 
 
@@ -497,7 +510,7 @@ def thunkify(func):
         wait_event = threading.Event()
 
         result = [None]
-        exc = [False, None]
+        exc: List[Any] = [False, None]
 
         def worker_func():
             try:
@@ -508,13 +521,13 @@ def thunkify(func):
                 exc[1] = sys.exc_info()  # (type, value, traceback)
                 msg = f"Thunkify has thrown an exception (will be raised on thunk()):\n{traceback.format_exc()}"
                 logger.warning(msg)
-                warnings.warn(msg)
             finally:
                 wait_event.set()
 
         def thunk():
             wait_event.wait()
             if exc[0]:
+                assert exc[1]
                 raise exc[1][0](exc[1][1])
             return result[0]
 
@@ -535,7 +548,7 @@ def thunkify(func):
 
 def _raise_exception(exception, error_message: Optional[str]):
     if error_message is None:
-        raise Exception()
+        raise Exception(exception)
     else:
         raise Exception(error_message)
 
@@ -588,9 +601,7 @@ class _Timeout(object):
         self.__limit = kwargs.pop("timeout", self.__limit)
         self.__queue = multiprocessing.Queue(1)
         args = (self.__queue, self.__function) + args
-        self.__process = multiprocessing.Process(
-            target=_target, args=args, kwargs=kwargs
-        )
+        self.__process = multiprocessing.Process(target=_target, args=args, kwargs=kwargs)
         self.__process.daemon = True
         self.__process.start()
         if self.__limit is not None:
@@ -620,6 +631,7 @@ class _Timeout(object):
             if flag:
                 return load
             raise load
+        return None
 
 
 def timeout(
@@ -635,6 +647,13 @@ def timeout(
     main thread).  When not using signals, timeout granularity will be
     rounded to the nearest 0.1s.
 
+    Beware that an @timeout on a function inside a module will be
+    evaluated at module load time and not when the wrapped function is
+    invoked.  This can lead to problems when relying on the automatic
+    main thread detection code (use_signals=None, the default) since
+    the import probably happens on the main thread and the invocation
+    can happen on a different thread (which can't use signals).
+
     Raises an exception when/if the timeout is reached.
 
     It is illegal to pass anything other than a function as the first
@@ -656,12 +675,13 @@ def timeout(
     """
     if use_signals is None:
         import thread_utils
+
         use_signals = thread_utils.is_current_thread_main_thread()
 
     def decorate(function):
         if use_signals:
 
-            def handler(signum, frame):
+            def handler(unused_signum, unused_frame):
                 _raise_exception(timeout_exception, error_message)
 
             @functools.wraps(function)
@@ -686,9 +706,7 @@ def timeout(
 
             @functools.wraps(function)
             def new_function(*args, **kwargs):
-                timeout_wrapper = _Timeout(
-                    function, timeout_exception, error_message, seconds
-                )
+                timeout_wrapper = _Timeout(function, timeout_exception, error_message, seconds)
                 return timeout_wrapper(*args, **kwargs)
 
             return new_function
@@ -696,39 +714,31 @@ def timeout(
     return decorate
 
 
-class non_reentrant_code(object):
-    def __init__(self):
-        self._lock = threading.RLock
-        self._entered = False
-
-    def __call__(self, f):
-        def _gatekeeper(*args, **kwargs):
-            with self._lock:
-                if self._entered:
-                    return
-                self._entered = True
-                f(*args, **kwargs)
-                self._entered = False
-        return _gatekeeper
-
+def synchronized(lock):
+    """Emulates java's synchronized keyword: given a lock, require that
+    threads take that lock (or wait) before invoking the wrapped
+    function and automatically releases the lock afterwards.
+    """
 
-class rlocked(object):
-    def __init__(self):
-        self._lock = threading.RLock
-        self._entered = False
+    def wrap(f):
+        @functools.wraps(f)
+        def _gatekeeper(*args, **kw):
+            lock.acquire()
+            try:
+                return f(*args, **kw)
+            finally:
+                lock.release()
 
-    def __call__(self, f):
-        def _gatekeeper(*args, **kwargs):
-            with self._lock:
-                if self._entered:
-                    return
-                self._entered = True
-                f(*args, **kwargs)
-                self._entered = False
         return _gatekeeper
 
+    return wrap
+
 
 def call_with_sample_rate(sample_rate: float) -> Callable:
+    """Calls the wrapped function probabilistically given a rate between
+    0.0 and 1.0 inclusive (0% probability and 100% probability).
+    """
+
     if not 0.0 <= sample_rate <= 1.0:
         msg = f"sample_rate must be between [0, 1]. Got {sample_rate}."
         logger.critical(msg)
@@ -740,18 +750,20 @@ def call_with_sample_rate(sample_rate: float) -> Callable:
             if random.uniform(0, 1) < sample_rate:
                 return f(*args, **kwargs)
             else:
-                logger.debug(
-                    f"@call_with_sample_rate skipping a call to {f.__name__}"
-                )
+                logger.debug("@call_with_sample_rate skipping a call to %s", f.__name__)
+                return None
+
         return _call_with_sample_rate
+
     return decorator
 
 
 def decorate_matching_methods_with(decorator, acl=None):
-    """Apply decorator to all methods in a class whose names begin with
-    prefix.  If prefix is None (default), decorate all methods in the
-    class.
+    """Apply the given decorator to all methods in a class whose names
+    begin with prefix.  If prefix is None (default), decorate all
+    methods in the class.
     """
+
     def decorate_the_class(cls):
         for name, m in inspect.getmembers(cls, inspect.isfunction):
             if acl is None:
@@ -760,10 +772,11 @@ def decorate_matching_methods_with(decorator, acl=None):
                 if acl(name):
                     setattr(cls, name, decorator(m))
         return cls
+
     return decorate_the_class
 
 
 if __name__ == '__main__':
     import doctest
-    doctest.testmod()
 
+    doctest.testmod()