#!/usr/bin/env python3
-"""Decorators."""
+# © Copyright 2021-2022, Scott Gasch
+# Portions (marked) below retain the original author's copyright.
+
+"""Useful(?) decorators."""
-import datetime
import enum
import functools
import inspect
import threading
import time
import traceback
-from typing import Callable, Optional
import warnings
+from typing import Any, Callable, List, Optional
# This module is commonly used by others in here and should avoid
# taking any unnecessary dependencies back on them.
import exceptions
-
logger = logging.getLogger(__name__)
def timed(func: Callable) -> Callable:
- """Print the runtime of the decorated function."""
+ """Print the runtime of the decorated function.
+
+ >>> @timed
+ ... def foo():
+ ... import time
+ ... time.sleep(0.01)
+
+ >>> foo() # doctest: +ELLIPSIS
+ Finished foo in ...
+
+ """
@functools.wraps(func)
def wrapper_timer(*args, **kwargs):
value = func(*args, **kwargs)
end_time = time.perf_counter()
run_time = end_time - start_time
- msg = f"Finished {func.__name__!r} in {run_time:.4f}s"
+ msg = f"Finished {func.__qualname__} in {run_time:.4f}s"
print(msg)
logger.info(msg)
return value
+
return wrapper_timer
def invocation_logged(func: Callable) -> Callable:
- """Log the call of a function."""
+ """Log the call of a function on stdout and the info log.
+
+ >>> @invocation_logged
+ ... def foo():
+ ... print('Hello, world.')
+
+ >>> foo()
+ Entered foo
+ Hello, world.
+ Exited foo
+
+ """
@functools.wraps(func)
def wrapper_invocation_logged(*args, **kwargs):
print(msg)
logger.info(msg)
return ret
+
return wrapper_invocation_logged
+def rate_limited(n_calls: int, *, per_period_in_seconds: float = 1.0) -> Callable:
+    """Limit invocation of a wrapped function to n calls per time period.
+    Thread safe. In testing this was relatively fair with multiple
+    threads using it though that hasn't been measured in detail.
+
+    >>> import time
+    >>> import decorator_utils
+    >>> import thread_utils
+
+    >>> calls = 0
+
+    >>> @decorator_utils.rate_limited(10, per_period_in_seconds=1.0)
+    ... def limited(x: int):
+    ...     global calls
+    ...     calls += 1
+
+    >>> @thread_utils.background_thread
+    ... def a(stop):
+    ...     for _ in range(3):
+    ...         limited(_)
+
+    >>> @thread_utils.background_thread
+    ... def b(stop):
+    ...     for _ in range(3):
+    ...         limited(_)
+
+    >>> start = time.time()
+    >>> (t1, e1) = a()
+    >>> (t2, e2) = b()
+    >>> t1.join()
+    >>> t2.join()
+    >>> end = time.time()
+    >>> dur = end - start
+    >>> dur > 0.5
+    True
+
+    >>> calls
+    6
+
+    """
+    # Minimum number of seconds that must elapse between two consecutive
+    # invocations to keep the rate at or under n_calls per period.
+    min_interval_seconds = per_period_in_seconds / float(n_calls)
+
+    def wrapper_rate_limited(func: Callable) -> Callable:
+        cv = threading.Condition()
+        # Timestamp of the most recent completed invocation, boxed in a
+        # one-element list so the nested closures below can mutate it.
+        last_invocation_timestamp = [0.0]
+
+        def may_proceed() -> float:
+            # Return the remaining wait time in seconds; <= 0.0 means the
+            # next call may proceed immediately.
+            now = time.time()
+            last_invocation = last_invocation_timestamp[0]
+            if last_invocation != 0.0:
+                elapsed_since_last = now - last_invocation
+                wait_time = min_interval_seconds - elapsed_since_last
+            else:
+                # 0.0 sentinel: no call has happened yet, no need to wait.
+                wait_time = 0.0
+            logger.debug('@%.4f> wait_time = %.4f', time.time(), wait_time)
+            return wait_time
+
+        def wrapper_wrapper_rate_limited(*args, **kargs) -> Any:
+            with cv:
+                while True:
+                    # NOTE(review): may_proceed() is evaluated twice here
+                    # (predicate and timeout), so the timeout value can be
+                    # slightly stale; harmless because we loop until the
+                    # predicate itself holds.
+                    if cv.wait_for(
+                        lambda: may_proceed() <= 0.0,
+                        timeout=may_proceed(),
+                    ):
+                        break
+            # NOTE(review): the condition lock is dropped and reacquired
+            # between these two blocks, so two waiters can occasionally
+            # pass the gate close together -- confirm this fairness
+            # tradeoff is acceptable (the docstring says "relatively fair").
+            with cv:
+                logger.debug('@%.4f> calling it...', time.time())
+                ret = func(*args, **kargs)
+                last_invocation_timestamp[0] = time.time()
+                logger.debug(
+                    '@%.4f> Last invocation <- %.4f', time.time(), last_invocation_timestamp[0]
+                )
+                cv.notify()
+            return ret
+
+        return wrapper_wrapper_rate_limited
+
+    return wrapper_rate_limited
+
+
def debug_args(func: Callable) -> Callable:
- """Print the function signature and return value at each call."""
+ """Print the function signature and return value at each call.
+
+ >>> @debug_args
+ ... def foo(a, b, c):
+ ... print(a)
+ ... print(b)
+ ... print(c)
+ ... return (a + b, c)
+
+ >>> foo(1, 2.0, "test")
+ Calling foo(1:<class 'int'>, 2.0:<class 'float'>, 'test':<class 'str'>)
+ 1
+ 2.0
+ test
+ foo returned (3.0, 'test'):<class 'tuple'>
+ (3.0, 'test')
+ """
@functools.wraps(func)
def wrapper_debug_args(*args, **kwargs):
args_repr = [f"{repr(a)}:{type(a)}" for a in args]
kwargs_repr = [f"{k}={v!r}:{type(v)}" for k, v in kwargs.items()]
signature = ", ".join(args_repr + kwargs_repr)
- msg = f"Calling {func.__name__}({signature})"
+ msg = f"Calling {func.__qualname__}({signature})"
print(msg)
logger.info(msg)
value = func(*args, **kwargs)
- msg = f"{func.__name__!r} returned {value!r}:{type(value)}"
+ msg = f"{func.__qualname__} returned {value!r}:{type(value)}"
+ print(msg)
logger.info(msg)
return value
+
return wrapper_debug_args
def debug_count_calls(func: Callable) -> Callable:
- """Count function invocations and print a message befor every call."""
+    """Count function invocations and print a message before every call.
+
+ >>> @debug_count_calls
+    ... def factorial(x):
+    ...     if x == 1:
+    ...         return 1
+    ...     return x * factorial(x - 1)
+
+    >>> factorial(5)
+    Call #1 of 'factorial'
+    Call #2 of 'factorial'
+    Call #3 of 'factorial'
+    Call #4 of 'factorial'
+    Call #5 of 'factorial'
+ 120
+
+ """
@functools.wraps(func)
def wrapper_debug_count_calls(*args, **kwargs):
print(msg)
logger.info(msg)
return func(*args, **kwargs)
- wrapper_debug_count_calls.num_calls = 0
+
+ wrapper_debug_count_calls.num_calls = 0 # type: ignore
return wrapper_debug_count_calls
-class DelayWhen(enum.Enum):
+class DelayWhen(enum.IntEnum):
+ """When should we delay: before or after calling the function (or
+ both)?
+
+ """
+
BEFORE_CALL = 1
AFTER_CALL = 2
BEFORE_AND_AFTER = 3
seconds: float = 1.0,
when: DelayWhen = DelayWhen.BEFORE_CALL,
) -> Callable:
- """Delay the execution of a function by sleeping before and/or after.
-
- Slow down a function by inserting a delay before and/or after its
+ """Slow down a function by inserting a delay before and/or after its
invocation.
+
+ >>> import time
+
+ >>> @delay(seconds=1.0)
+ ... def foo():
+ ... pass
+
+ >>> start = time.time()
+ >>> foo()
+ >>> dur = time.time() - start
+ >>> dur >= 1.0
+ True
+
"""
def decorator_delay(func: Callable) -> Callable:
@functools.wraps(func)
def wrapper_delay(*args, **kwargs):
if when & DelayWhen.BEFORE_CALL:
- logger.debug(
- f"@delay for {seconds}s BEFORE_CALL to {func.__name__}"
- )
+ logger.debug("@delay for %fs BEFORE_CALL to %s", seconds, func.__name__)
time.sleep(seconds)
retval = func(*args, **kwargs)
if when & DelayWhen.AFTER_CALL:
- logger.debug(
- f"@delay for {seconds}s AFTER_CALL to {func.__name__}"
- )
+ logger.debug("@delay for %fs AFTER_CALL to %s", seconds, func.__name__)
time.sleep(seconds)
return retval
+
return wrapper_delay
if _func is None:
"""
A singleton wrapper class. Its instances would be created
for each decorated class.
+
"""
def __init__(self, cls):
def __call__(self, *args, **kwargs):
"""Returns a single instance of decorated class"""
- logger.debug(
- f"@singleton returning global instance of {self.__wrapped__.__name__}"
- )
+ logger.debug('@singleton returning global instance of %s', self.__wrapped__.__name__)
if self._instance is None:
self._instance = self.__wrapped__(*args, **kwargs)
return self._instance
A singleton decorator. Returns a wrapper objects. A call on that object
returns a single instance object of decorated class. Use the __wrapped__
attribute to access decorated class directly in unit tests
+
+ >>> @singleton
+ ... class foo(object):
+ ... pass
+
+ >>> a = foo()
+ >>> b = foo()
+ >>> a is b
+ True
+
+ >>> id(a) == id(b)
+ True
+
"""
return _SingletonWrapper(cls)
"""Keep a cache of previous function call results.
The cache here is a dict with a key based on the arguments to the
- call. Consider also: functools.lru_cache for a more advanced
- implementation.
+ call. Consider also: functools.cache for a more advanced
+ implementation. See:
+ https://docs.python.org/3/library/functools.html#functools.cache
+
+ >>> import time
+
+ >>> @memoized
+ ... def expensive(arg) -> int:
+ ... # Simulate something slow to compute or lookup
+ ... time.sleep(1.0)
+ ... return arg * arg
+
+ >>> start = time.time()
+ >>> expensive(5) # Takes about 1 sec
+ 25
+
+ >>> expensive(3) # Also takes about 1 sec
+ 9
+
+ >>> expensive(5) # Pulls from cache, fast
+ 25
+
+ >>> expensive(3) # Pulls from cache again, fast
+ 9
+
+ >>> dur = time.time() - start
+ >>> dur < 3.0
+ True
+
"""
@functools.wraps(func)
cache_key = args + tuple(kwargs.items())
if cache_key not in wrapper_memoized.cache:
value = func(*args, **kwargs)
- logger.debug(
- f"Memoizing {cache_key} => {value} for {func.__name__}"
- )
+ logger.debug('Memoizing %s => %s for %s', cache_key, value, func.__name__)
wrapper_memoized.cache[cache_key] = value
else:
- logger.debug(f"Returning memoized value for {func.__name__}")
+            logger.debug('Returning memoized value for %s', func.__name__)
return wrapper_memoized.cache[cache_key]
- wrapper_memoized.cache = dict()
+
+ wrapper_memoized.cache = {} # type: ignore
return wrapper_memoized
delay_sec: float = 3.0,
backoff: float = 2.0,
):
- """Retries a function or method up to a certain number of times
- with a prescribed initial delay period and backoff rate.
-
- tries is the maximum number of attempts to run the function.
- delay_sec sets the initial delay period in seconds.
- backoff is a multiplied (must be >1) used to modify the delay.
- predicate is a function that will be passed the retval of the
- decorated function and must return True to stop or False to
- retry.
+ """Retries a function or method up to a certain number of times with a
+ prescribed initial delay period and backoff rate (multiplier).
+
+ Args:
+ tries: the maximum number of attempts to run the function
+ delay_sec: sets the initial delay period in seconds
+ backoff: a multiplier (must be >=1.0) used to modify the
+ delay at each subsequent invocation
+ predicate: a Callable that will be passed the retval of
+ the decorated function and must return True to indicate
+ that we should stop calling or False to indicate a retry
+ is necessary
"""
+
if backoff < 1.0:
msg = f"backoff must be greater than or equal to 1, got {backoff}"
logger.critical(msg)
@functools.wraps(f)
def f_retry(*args, **kwargs):
mtries, mdelay = tries, delay_sec # make mutable
- logger.debug(f'deco_retry: will make up to {mtries} attempts...')
+ logger.debug('deco_retry: will make up to %d attempts...', mtries)
retval = f(*args, **kwargs)
while mtries > 0:
if predicate(retval) is True:
mdelay *= backoff
retval = f(*args, **kwargs)
return retval
+
return f_retry
+
return deco_retry
def retry_if_false(tries: int, *, delay_sec=3.0, backoff=2.0):
+ """A helper for @retry_predicate that retries a decorated
+ function as long as it keeps returning False.
+
+ >>> import time
+
+ >>> counter = 0
+ >>> @retry_if_false(5, delay_sec=1.0, backoff=1.1)
+ ... def foo():
+ ... global counter
+ ... counter += 1
+ ... return counter >= 3
+
+ >>> start = time.time()
+ >>> foo() # fail, delay 1.0, fail, delay 1.1, succeed
+ True
+
+ >>> dur = time.time() - start
+ >>> counter
+ 3
+ >>> dur > 2.0
+ True
+ >>> dur < 2.3
+ True
+
+ """
return retry_predicate(
tries,
predicate=lambda x: x is True,
def retry_if_none(tries: int, *, delay_sec=3.0, backoff=2.0):
+ """Another helper for @retry_predicate above. Retries up to N
+ times so long as the wrapped function returns None with a delay
+ between each retry and a backoff that can increase the delay.
+ """
+
return retry_predicate(
tries,
predicate=lambda x: x is not None,
@functools.wraps(func)
def wrapper_deprecated(*args, **kwargs):
- msg = f"Call to deprecated function {func.__name__}"
+ msg = f"Call to deprecated function {func.__qualname__}"
logger.warning(msg)
- warnings.warn(msg, category=DeprecationWarning)
+ warnings.warn(msg, category=DeprecationWarning, stacklevel=2)
+ print(msg, file=sys.stderr)
return func(*args, **kwargs)
return wrapper_deprecated
wait_event = threading.Event()
result = [None]
- exc = [False, None]
+ exc: List[Any] = [False, None]
def worker_func():
try:
exc[0] = True
exc[1] = sys.exc_info() # (type, value, traceback)
msg = f"Thunkify has thrown an exception (will be raised on thunk()):\n{traceback.format_exc()}"
- print(msg)
logger.warning(msg)
finally:
wait_event.set()
def thunk():
wait_event.wait()
if exc[0]:
+ assert exc[1]
raise exc[1][0](exc[1][1])
return result[0]
def _raise_exception(exception, error_message: Optional[str]):
+    # Private helper used by @timeout to signal expiration.
+    # NOTE(review): the original raised `exception()` / `exception(error_message)`
+    # directly; this patch wraps both in a generic Exception instead,
+    # discarding the caller-supplied exception type.  That appears deliberate
+    # (the @timeout doctest expects "Exception: Function call timed out")
+    # but callers catching a specific timeout_exception type will no longer
+    # match -- confirm intended.
    if error_message is None:
-        raise exception()
+        raise Exception(exception)
    else:
-        raise exception(error_message)
+        raise Exception(error_message)
def _target(queue, function, *args, **kwargs):
class _Timeout(object):
- """Wrap a function and add a timeout (limit) attribute to it.
+ """Wrap a function and add a timeout to it.
Instances of this class are automatically generated by the add_timeout
- function defined below.
+ function defined below. Do not use directly.
"""
def __init__(
self.__limit = kwargs.pop("timeout", self.__limit)
self.__queue = multiprocessing.Queue(1)
args = (self.__queue, self.__function) + args
- self.__process = multiprocessing.Process(
- target=_target, args=args, kwargs=kwargs
- )
+ self.__process = multiprocessing.Process(target=_target, args=args, kwargs=kwargs)
self.__process.daemon = True
self.__process.start()
if self.__limit is not None:
if flag:
return load
raise load
+ return None
def timeout(
main thread). When not using signals, timeout granularity will be
rounded to the nearest 0.1s.
- Raises an exception when the timeout is reached.
+ Beware that an @timeout on a function inside a module will be
+ evaluated at module load time and not when the wrapped function is
+ invoked. This can lead to problems when relying on the automatic
+ main thread detection code (use_signals=None, the default) since
+ the import probably happens on the main thread and the invocation
+ can happen on a different thread (which can't use signals).
+
+ Raises an exception when/if the timeout is reached.
It is illegal to pass anything other than a function as the first
parameter. The function is wrapped and returned to the caller.
+
+ >>> @timeout(0.2)
+ ... def foo(delay: float):
+ ... time.sleep(delay)
+ ... return "ok"
+
+ >>> foo(0)
+ 'ok'
+
+ >>> foo(1.0)
+ Traceback (most recent call last):
+ ...
+ Exception: Function call timed out
+
"""
if use_signals is None:
import thread_utils
+
use_signals = thread_utils.is_current_thread_main_thread()
def decorate(function):
if use_signals:
- def handler(signum, frame):
+ def handler(unused_signum, unused_frame):
_raise_exception(timeout_exception, error_message)
@functools.wraps(function)
@functools.wraps(function)
def new_function(*args, **kwargs):
- timeout_wrapper = _Timeout(
- function, timeout_exception, error_message, seconds
- )
+ timeout_wrapper = _Timeout(function, timeout_exception, error_message, seconds)
return timeout_wrapper(*args, **kwargs)
return new_function
return decorate
-class non_reentrant_code(object):
- def __init__(self):
- self._lock = threading.RLock
- self._entered = False
+def synchronized(lock):
+ """Emulates java's synchronized keyword: given a lock, require that
+ threads take that lock (or wait) before invoking the wrapped
+ function and automatically releases the lock afterwards.
+ """
- def __call__(self, f):
- def _gatekeeper(*args, **kwargs):
- with self._lock:
- if self._entered:
- return
- self._entered = True
- f(*args, **kwargs)
- self._entered = False
+ def wrap(f):
+ @functools.wraps(f)
+ def _gatekeeper(*args, **kw):
+ lock.acquire()
+ try:
+ return f(*args, **kw)
+ finally:
+ lock.release()
return _gatekeeper
-
-class rlocked(object):
- def __init__(self):
- self._lock = threading.RLock
- self._entered = False
-
- def __call__(self, f):
- def _gatekeeper(*args, **kwargs):
- with self._lock:
- if self._entered:
- return
- self._entered = True
- f(*args, **kwargs)
- self._entered = False
- return _gatekeeper
+ return wrap
def call_with_sample_rate(sample_rate: float) -> Callable:
+ """Calls the wrapped function probabilistically given a rate between
+ 0.0 and 1.0 inclusive (0% probability and 100% probability).
+ """
+
if not 0.0 <= sample_rate <= 1.0:
msg = f"sample_rate must be between [0, 1]. Got {sample_rate}."
logger.critical(msg)
if random.uniform(0, 1) < sample_rate:
return f(*args, **kwargs)
else:
- logger.debug(
- f"@call_with_sample_rate skipping a call to {f.__name__}"
- )
+ logger.debug("@call_with_sample_rate skipping a call to %s", f.__name__)
+ return None
+
return _call_with_sample_rate
+
return decorator
def decorate_matching_methods_with(decorator, acl=None):
- """Apply decorator to all methods in a class whose names begin with
- prefix. If prefix is None (default), decorate all methods in the
- class.
+ """Apply the given decorator to all methods in a class whose names
+ begin with prefix. If prefix is None (default), decorate all
+ methods in the class.
"""
+
def decorate_the_class(cls):
for name, m in inspect.getmembers(cls, inspect.isfunction):
if acl is None:
if acl(name):
setattr(cls, name, decorator(m))
return cls
+
return decorate_the_class
+
+
+# Run this module's doctests when it is executed directly.
+if __name__ == '__main__':
+    import doctest
+
+    doctest.testmod()