X-Git-Url: https://wannabe.guru.org/gitweb/?a=blobdiff_plain;f=decorator_utils.py;h=a5c5afecb34c005d16a351cc703f44dc561b567d;hb=31c81f6539969a5eba864d3305f9fb7bf716a367;hp=9b848ed792144919b863b20c82e846bcd509bbe8;hpb=d742c4a0f3a177e3ab55a9eb2d30e0e37af2f044;p=python_utils.git diff --git a/decorator_utils.py b/decorator_utils.py index 9b848ed..a5c5afe 100644 --- a/decorator_utils.py +++ b/decorator_utils.py @@ -14,14 +14,13 @@ import sys import threading import time import traceback -from typing import Any, Callable, Optional import warnings +from typing import Any, Callable, Optional # This module is commonly used by others in here and should avoid # taking any unnecessary dependencies back on them. import exceptions - logger = logging.getLogger(__name__) @@ -48,6 +47,7 @@ def timed(func: Callable) -> Callable: print(msg) logger.info(msg) return value + return wrapper_timer @@ -75,6 +75,7 @@ def invocation_logged(func: Callable) -> Callable: print(msg) logger.info(msg) return ret + return wrapper_invocation_logged @@ -152,7 +153,9 @@ def rate_limited(n_calls: int, *, per_period_in_seconds: float = 1.0) -> Callabl ) cv.notify() return ret + return wrapper_wrapper_rate_limited + return wrapper_rate_limited @@ -188,6 +191,7 @@ def debug_args(func: Callable) -> Callable: print(msg) logger.info(msg) return value + return wrapper_debug_args @@ -217,7 +221,8 @@ def debug_count_calls(func: Callable) -> Callable: print(msg) logger.info(msg) return func(*args, **kwargs) - wrapper_debug_count_calls.num_calls = 0 + + wrapper_debug_count_calls.num_calls = 0 # type: ignore return wrapper_debug_count_calls @@ -251,21 +256,19 @@ def delay( True """ + def decorator_delay(func: Callable) -> Callable: @functools.wraps(func) def wrapper_delay(*args, **kwargs): if when & DelayWhen.BEFORE_CALL: - logger.debug( - f"@delay for {seconds}s BEFORE_CALL to {func.__name__}" - ) + logger.debug(f"@delay for {seconds}s BEFORE_CALL to {func.__name__}") time.sleep(seconds) retval = func(*args, **kwargs) if when & DelayWhen.AFTER_CALL: - logger.debug( - f"@delay for {seconds}s AFTER_CALL to {func.__name__}" - ) + logger.debug(f"@delay for {seconds}s AFTER_CALL to {func.__name__}") time.sleep(seconds) return retval + return wrapper_delay if _func is None: @@ -350,19 +353,19 @@ def memoized(func: Callable) -> Callable: True """ + @functools.wraps(func) def wrapper_memoized(*args, **kwargs): cache_key = args + tuple(kwargs.items()) if cache_key not in wrapper_memoized.cache: value = func(*args, **kwargs) - logger.debug( - f"Memoizing {cache_key} => {value} for {func.__name__}" - ) + logger.debug(f"Memoizing {cache_key} => {value} for {func.__name__}") wrapper_memoized.cache[cache_key] = value else: logger.debug(f"Returning memoized value for {func.__name__}") return wrapper_memoized.cache[cache_key] - wrapper_memoized.cache = dict() + + wrapper_memoized.cache = dict() # type: ignore return wrapper_memoized @@ -416,7 +419,9 @@ def retry_predicate( mdelay *= backoff retval = f(*args, **kwargs) return retval + return f_retry + return deco_retry @@ -443,7 +448,7 @@ def retry_if_false(tries: int, *, delay_sec=3.0, backoff=2.0): 3 >>> dur > 2.0 True - >>> dur < 2.2 + >>> dur < 2.3 True """ @@ -475,13 +480,15 @@ def deprecated(func): when the function is used. """ + @functools.wraps(func) def wrapper_deprecated(*args, **kwargs): msg = f"Call to deprecated function {func.__qualname__}" logger.warning(msg) - warnings.warn(msg, category=DeprecationWarning) + warnings.warn(msg, category=DeprecationWarning, stacklevel=2) print(msg, file=sys.stderr) return func(*args, **kwargs) + return wrapper_deprecated @@ -507,7 +514,6 @@ def thunkify(func): exc[0] = True exc[1] = sys.exc_info() # (type, value, traceback) msg = f"Thunkify has thrown an exception (will be raised on thunk()):\n{traceback.format_exc()}" - print(msg) logger.warning(msg) finally: wait_event.set() @@ -535,9 +541,9 @@ def thunkify(func): def _raise_exception(exception, error_message: Optional[str]): if error_message is None: - raise exception() + raise Exception() else: - raise exception(error_message) + raise Exception(error_message) def _target(queue, function, *args, **kwargs): @@ -555,10 +561,10 @@ def _target(queue, function, *args, **kwargs): class _Timeout(object): - """Wrap a function and add a timeout (limit) attribute to it. + """Wrap a function and add a timeout to it. Instances of this class are automatically generated by the add_timeout - function defined below. + function defined below. Do not use directly. """ def __init__( @@ -635,13 +641,28 @@ def timeout( main thread). When not using signals, timeout granularity will be rounded to the nearest 0.1s. - Raises an exception when the timeout is reached. + Raises an exception when/if the timeout is reached. It is illegal to pass anything other than a function as the first parameter. The function is wrapped and returned to the caller. + + >>> @timeout(0.2) + ... def foo(delay: float): + ... time.sleep(delay) + ... return "ok" + + >>> foo(0) + 'ok' + + >>> foo(1.0) + Traceback (most recent call last): + ... + Exception: Function call timed out + """ if use_signals is None: import thread_utils + use_signals = thread_utils.is_current_thread_main_thread() def decorate(function): @@ -682,37 +703,19 @@ def timeout( return decorate -class non_reentrant_code(object): - def __init__(self): - self._lock = threading.RLock - self._entered = False - - def __call__(self, f): - def _gatekeeper(*args, **kwargs): - with self._lock: - if self._entered: - return - self._entered = True - f(*args, **kwargs) - self._entered = False +def synchronized(lock): + def wrap(f): + @functools.wraps(f) + def _gatekeeper(*args, **kw): + lock.acquire() + try: + return f(*args, **kw) + finally: + lock.release() return _gatekeeper - -class rlocked(object): - def __init__(self): - self._lock = threading.RLock - self._entered = False - - def __call__(self, f): - def _gatekeeper(*args, **kwargs): - with self._lock: - if self._entered: - return - self._entered = True - f(*args, **kwargs) - self._entered = False - return _gatekeeper + return wrap def call_with_sample_rate(sample_rate: float) -> Callable: @@ -727,10 +730,10 @@ def call_with_sample_rate(sample_rate: float) -> Callable: if random.uniform(0, 1) < sample_rate: return f(*args, **kwargs) else: - logger.debug( - f"@call_with_sample_rate skipping a call to {f.__name__}" - ) + logger.debug(f"@call_with_sample_rate skipping a call to {f.__name__}") + return _call_with_sample_rate + return decorator @@ -739,6 +742,7 @@ def decorate_matching_methods_with(decorator, acl=None): prefix. If prefix is None (default), decorate all methods in the class. """ + def decorate_the_class(cls): for name, m in inspect.getmembers(cls, inspect.isfunction): if acl is None: @@ -747,10 +751,11 @@ def decorate_matching_methods_with(decorator, acl=None): if acl(name): setattr(cls, name, decorator(m)) return cls + return decorate_the_class if __name__ == '__main__': import doctest - doctest.testmod() + doctest.testmod()