X-Git-Url: https://wannabe.guru.org/gitweb/?a=blobdiff_plain;f=decorator_utils.py;h=4615fec6f8960e0083ce48546ba9421c25243d42;hb=e46158e49121b8a955bb07b73f5bcf9928b79c90;hp=752fb919df35dc8ff5f93aba672b9d10bc1efb19;hpb=55a3172e37855f388b9ba0dfc91641a6c9ad1376;p=python_utils.git diff --git a/decorator_utils.py b/decorator_utils.py index 752fb91..4615fec 100644 --- a/decorator_utils.py +++ b/decorator_utils.py @@ -1,6 +1,9 @@ #!/usr/bin/env python3 -"""Decorators.""" +# © Copyright 2021-2022, Scott Gasch +# Portions (marked) below retain the original author's copyright. + +"""Useful(?) decorators.""" import enum import functools @@ -14,14 +17,13 @@ import sys import threading import time import traceback -from typing import Any, Callable, Optional import warnings +from typing import Any, Callable, List, Optional # This module is commonly used by others in here and should avoid # taking any unnecessary dependencies back on them. import exceptions - logger = logging.getLogger(__name__) @@ -31,7 +33,7 @@ def timed(func: Callable) -> Callable: >>> @timed ... def foo(): ... import time - ... time.sleep(0.1) + ... time.sleep(0.01) >>> foo() # doctest: +ELLIPSIS Finished foo in ... @@ -48,11 +50,12 @@ def timed(func: Callable) -> Callable: print(msg) logger.info(msg) return value + return wrapper_timer def invocation_logged(func: Callable) -> Callable: - """Log the call of a function. + """Log the call of a function on stdout and the info log. >>> @invocation_logged ... def foo(): @@ -75,13 +78,14 @@ def invocation_logged(func: Callable) -> Callable: print(msg) logger.info(msg) return ret + return wrapper_invocation_logged def rate_limited(n_calls: int, *, per_period_in_seconds: float = 1.0) -> Callable: - """Limit invocation of a wrapped function to n calls per period. + """Limit invocation of a wrapped function to n calls per time period. Thread safe. In testing this was relatively fair with multiple - threads using it though that hasn't been measured. + threads using it though that hasn't been measured in detail. >>> import time >>> import decorator_utils @@ -132,7 +136,7 @@ def rate_limited(n_calls: int, *, per_period_in_seconds: float = 1.0) -> Callabl wait_time = min_interval_seconds - elapsed_since_last else: wait_time = 0.0 - logger.debug(f'@{time.time()}> wait_time = {wait_time}') + logger.debug('@%.4f> wait_time = %.4f', time.time(), wait_time) return wait_time def wrapper_wrapper_rate_limited(*args, **kargs) -> Any: @@ -144,15 +148,17 @@ def rate_limited(n_calls: int, *, per_period_in_seconds: float = 1.0) -> Callabl ): break with cv: - logger.debug(f'@{time.time()}> calling it...') + logger.debug('@%.4f> calling it...', time.time()) ret = func(*args, **kargs) last_invocation_timestamp[0] = time.time() logger.debug( - f'@{time.time()}> Last invocation <- {last_invocation_timestamp[0]}' + '@%.4f> Last invocation <- %.4f', time.time(), last_invocation_timestamp[0] ) cv.notify() return ret + return wrapper_wrapper_rate_limited + return wrapper_rate_limited @@ -188,6 +194,7 @@ def debug_args(func: Callable) -> Callable: print(msg) logger.info(msg) return value + return wrapper_debug_args @@ -217,11 +224,17 @@ def debug_count_calls(func: Callable) -> Callable: print(msg) logger.info(msg) return func(*args, **kwargs) - wrapper_debug_count_calls.num_calls = 0 + + wrapper_debug_count_calls.num_calls = 0 # type: ignore return wrapper_debug_count_calls class DelayWhen(enum.IntEnum): + """When should we delay: before or after calling the function (or + both)? + + """ + BEFORE_CALL = 1 AFTER_CALL = 2 BEFORE_AND_AFTER = 3 @@ -233,9 +246,7 @@ def delay( seconds: float = 1.0, when: DelayWhen = DelayWhen.BEFORE_CALL, ) -> Callable: - """Delay the execution of a function by sleeping before and/or after. - - Slow down a function by inserting a delay before and/or after its + """Slow down a function by inserting a delay before and/or after its invocation. >>> import time @@ -251,21 +262,19 @@ def delay( True """ + def decorator_delay(func: Callable) -> Callable: @functools.wraps(func) def wrapper_delay(*args, **kwargs): if when & DelayWhen.BEFORE_CALL: - logger.debug( - f"@delay for {seconds}s BEFORE_CALL to {func.__name__}" - ) + logger.debug("@delay for %fs BEFORE_CALL to %s", seconds, func.__name__) time.sleep(seconds) retval = func(*args, **kwargs) if when & DelayWhen.AFTER_CALL: - logger.debug( - f"@delay for {seconds}s AFTER_CALL to {func.__name__}" - ) + logger.debug("@delay for %fs AFTER_CALL to %s", seconds, func.__name__) time.sleep(seconds) return retval + return wrapper_delay if _func is None: @@ -287,9 +296,7 @@ class _SingletonWrapper: def __call__(self, *args, **kwargs): """Returns a single instance of decorated class""" - logger.debug( - f"@singleton returning global instance of {self.__wrapped__.__name__}" - ) + logger.debug('@singleton returning global instance of %s', self.__wrapped__.__name__) if self._instance is None: self._instance = self.__wrapped__(*args, **kwargs) return self._instance @@ -321,8 +328,9 @@ def memoized(func: Callable) -> Callable: """Keep a cache of previous function call results. The cache here is a dict with a key based on the arguments to the - call. Consider also: functools.lru_cache for a more advanced - implementation. + call. Consider also: functools.cache for a more advanced + implementation. See: + https://docs.python.org/3/library/functools.html#functools.cache >>> import time @@ -350,19 +358,19 @@ def memoized(func: Callable) -> Callable: True """ + @functools.wraps(func) def wrapper_memoized(*args, **kwargs): cache_key = args + tuple(kwargs.items()) if cache_key not in wrapper_memoized.cache: value = func(*args, **kwargs) - logger.debug( - f"Memoizing {cache_key} => {value} for {func.__name__}" - ) + logger.debug('Memoizing %s => %s for %s', cache_key, value, func.__name__) wrapper_memoized.cache[cache_key] = value else: - logger.debug(f"Returning memoized value for {func.__name__}") + logger.debug('Returning memoized value for %s', {func.__name__}) return wrapper_memoized.cache[cache_key] - wrapper_memoized.cache = dict() + + wrapper_memoized.cache = {} # type: ignore return wrapper_memoized @@ -373,17 +381,20 @@ def retry_predicate( delay_sec: float = 3.0, backoff: float = 2.0, ): - """Retries a function or method up to a certain number of times - with a prescribed initial delay period and backoff rate. - - tries is the maximum number of attempts to run the function. - delay_sec sets the initial delay period in seconds. - backoff is a multiplied (must be >1) used to modify the delay. - predicate is a function that will be passed the retval of the - decorated function and must return True to stop or False to - retry. - + """Retries a function or method up to a certain number of times with a + prescribed initial delay period and backoff rate (multiplier). + + Args: + tries: the maximum number of attempts to run the function + delay_sec: sets the initial delay period in seconds + backoff: a multiplier (must be >=1.0) used to modify the + delay at each subsequent invocation + predicate: a Callable that will be passed the retval of + the decorated function and must return True to indicate + that we should stop calling or False to indicate a retry + is necessary """ + if backoff < 1.0: msg = f"backoff must be greater than or equal to 1, got {backoff}" logger.critical(msg) @@ -404,7 +415,7 @@ def retry_predicate( @functools.wraps(f) def f_retry(*args, **kwargs): mtries, mdelay = tries, delay_sec # make mutable - logger.debug(f'deco_retry: will make up to {mtries} attempts...') + logger.debug('deco_retry: will make up to %d attempts...', mtries) retval = f(*args, **kwargs) while mtries > 0: if predicate(retval) is True: @@ -416,7 +427,9 @@ def retry_predicate( mdelay *= backoff retval = f(*args, **kwargs) return retval + return f_retry + return deco_retry @@ -427,7 +440,6 @@ def retry_if_false(tries: int, *, delay_sec=3.0, backoff=2.0): >>> import time >>> counter = 0 - >>> @retry_if_false(5, delay_sec=1.0, backoff=1.1) ... def foo(): ... global counter @@ -443,7 +455,7 @@ def retry_if_false(tries: int, *, delay_sec=3.0, backoff=2.0): 3 >>> dur > 2.0 True - >>> dur < 2.2 + >>> dur < 2.3 True """ @@ -459,8 +471,8 @@ def retry_if_none(tries: int, *, delay_sec=3.0, backoff=2.0): """Another helper for @retry_predicate above. Retries up to N times so long as the wrapped function returns None with a delay between each retry and a backoff that can increase the delay. - """ + return retry_predicate( tries, predicate=lambda x: x is not None, @@ -473,8 +485,8 @@ def deprecated(func): """This is a decorator which can be used to mark functions as deprecated. It will result in a warning being emitted when the function is used. - """ + @functools.wraps(func) def wrapper_deprecated(*args, **kwargs): msg = f"Call to deprecated function {func.__qualname__}" @@ -482,6 +494,7 @@ def deprecated(func): warnings.warn(msg, category=DeprecationWarning, stacklevel=2) print(msg, file=sys.stderr) return func(*args, **kwargs) + return wrapper_deprecated @@ -497,7 +510,7 @@ def thunkify(func): wait_event = threading.Event() result = [None] - exc = [False, None] + exc: List[Any] = [False, None] def worker_func(): try: @@ -514,6 +527,7 @@ def thunkify(func): def thunk(): wait_event.wait() if exc[0]: + assert exc[1] raise exc[1][0](exc[1][1]) return result[0] @@ -534,7 +548,7 @@ def thunkify(func): def _raise_exception(exception, error_message: Optional[str]): if error_message is None: - raise Exception() + raise Exception(exception) else: raise Exception(error_message) @@ -587,9 +601,7 @@ class _Timeout(object): self.__limit = kwargs.pop("timeout", self.__limit) self.__queue = multiprocessing.Queue(1) args = (self.__queue, self.__function) + args - self.__process = multiprocessing.Process( - target=_target, args=args, kwargs=kwargs - ) + self.__process = multiprocessing.Process(target=_target, args=args, kwargs=kwargs) self.__process.daemon = True self.__process.start() if self.__limit is not None: @@ -619,6 +631,7 @@ class _Timeout(object): if flag: return load raise load + return None def timeout( @@ -634,6 +647,13 @@ def timeout( main thread). When not using signals, timeout granularity will be rounded to the nearest 0.1s. + Beware that an @timeout on a function inside a module will be + evaluated at module load time and not when the wrapped function is + invoked. This can lead to problems when relying on the automatic + main thread detection code (use_signals=None, the default) since + the import probably happens on the main thread and the invocation + can happen on a different thread (which can't use signals). + Raises an exception when/if the timeout is reached. It is illegal to pass anything other than a function as the first @@ -655,12 +675,13 @@ def timeout( """ if use_signals is None: import thread_utils + use_signals = thread_utils.is_current_thread_main_thread() def decorate(function): if use_signals: - def handler(signum, frame): + def handler(unused_signum, unused_frame): _raise_exception(timeout_exception, error_message) @functools.wraps(function) @@ -685,9 +706,7 @@ def timeout( @functools.wraps(function) def new_function(*args, **kwargs): - timeout_wrapper = _Timeout( - function, timeout_exception, error_message, seconds - ) + timeout_wrapper = _Timeout(function, timeout_exception, error_message, seconds) return timeout_wrapper(*args, **kwargs) return new_function @@ -695,39 +714,31 @@ def timeout( return decorate -class non_reentrant_code(object): - def __init__(self): - self._lock = threading.RLock - self._entered = False - - def __call__(self, f): - def _gatekeeper(*args, **kwargs): - with self._lock: - if self._entered: - return - self._entered = True - f(*args, **kwargs) - self._entered = False - return _gatekeeper - +def synchronized(lock): + """Emulates java's synchronized keyword: given a lock, require that + threads take that lock (or wait) before invoking the wrapped + function and automatically releases the lock afterwards. + """ -class rlocked(object): - def __init__(self): - self._lock = threading.RLock - self._entered = False + def wrap(f): + @functools.wraps(f) + def _gatekeeper(*args, **kw): + lock.acquire() + try: + return f(*args, **kw) + finally: + lock.release() - def __call__(self, f): - def _gatekeeper(*args, **kwargs): - with self._lock: - if self._entered: - return - self._entered = True - f(*args, **kwargs) - self._entered = False return _gatekeeper + return wrap + def call_with_sample_rate(sample_rate: float) -> Callable: + """Calls the wrapped function probabilistically given a rate between + 0.0 and 1.0 inclusive (0% probability and 100% probability). + """ + if not 0.0 <= sample_rate <= 1.0: msg = f"sample_rate must be between [0, 1]. Got {sample_rate}." logger.critical(msg) @@ -739,18 +750,20 @@ def call_with_sample_rate(sample_rate: float) -> Callable: if random.uniform(0, 1) < sample_rate: return f(*args, **kwargs) else: - logger.debug( - f"@call_with_sample_rate skipping a call to {f.__name__}" - ) + logger.debug("@call_with_sample_rate skipping a call to %s", f.__name__) + return None + return _call_with_sample_rate + return decorator def decorate_matching_methods_with(decorator, acl=None): - """Apply decorator to all methods in a class whose names begin with - prefix. If prefix is None (default), decorate all methods in the - class. + """Apply the given decorator to all methods in a class whose names + begin with prefix. If prefix is None (default), decorate all + methods in the class. """ + def decorate_the_class(cls): for name, m in inspect.getmembers(cls, inspect.isfunction): if acl is None: @@ -759,10 +772,11 @@ def decorate_matching_methods_with(decorator, acl=None): if acl(name): setattr(cls, name, decorator(m)) return cls + return decorate_the_class if __name__ == '__main__': import doctest - doctest.testmod() + doctest.testmod()