X-Git-Url: https://wannabe.guru.org/gitweb/?a=blobdiff_plain;f=decorator_utils.py;h=a5c5afecb34c005d16a351cc703f44dc561b567d;hb=31c81f6539969a5eba864d3305f9fb7bf716a367;hp=752fb919df35dc8ff5f93aba672b9d10bc1efb19;hpb=55a3172e37855f388b9ba0dfc91641a6c9ad1376;p=python_utils.git diff --git a/decorator_utils.py b/decorator_utils.py index 752fb91..a5c5afe 100644 --- a/decorator_utils.py +++ b/decorator_utils.py @@ -14,14 +14,13 @@ import sys import threading import time import traceback -from typing import Any, Callable, Optional import warnings +from typing import Any, Callable, Optional # This module is commonly used by others in here and should avoid # taking any unnecessary dependencies back on them. import exceptions - logger = logging.getLogger(__name__) @@ -48,6 +47,7 @@ def timed(func: Callable) -> Callable: print(msg) logger.info(msg) return value + return wrapper_timer @@ -75,6 +75,7 @@ def invocation_logged(func: Callable) -> Callable: print(msg) logger.info(msg) return ret + return wrapper_invocation_logged @@ -152,7 +153,9 @@ def rate_limited(n_calls: int, *, per_period_in_seconds: float = 1.0) -> Callabl ) cv.notify() return ret + return wrapper_wrapper_rate_limited + return wrapper_rate_limited @@ -188,6 +191,7 @@ def debug_args(func: Callable) -> Callable: print(msg) logger.info(msg) return value + return wrapper_debug_args @@ -217,7 +221,8 @@ def debug_count_calls(func: Callable) -> Callable: print(msg) logger.info(msg) return func(*args, **kwargs) - wrapper_debug_count_calls.num_calls = 0 + + wrapper_debug_count_calls.num_calls = 0 # type: ignore return wrapper_debug_count_calls @@ -251,21 +256,19 @@ def delay( True """ + def decorator_delay(func: Callable) -> Callable: @functools.wraps(func) def wrapper_delay(*args, **kwargs): if when & DelayWhen.BEFORE_CALL: - logger.debug( - f"@delay for {seconds}s BEFORE_CALL to {func.__name__}" - ) + logger.debug(f"@delay for {seconds}s BEFORE_CALL to {func.__name__}") time.sleep(seconds) retval = func(*args, **kwargs) if when & DelayWhen.AFTER_CALL: - logger.debug( - f"@delay for {seconds}s AFTER_CALL to {func.__name__}" - ) + logger.debug(f"@delay for {seconds}s AFTER_CALL to {func.__name__}") time.sleep(seconds) return retval + return wrapper_delay if _func is None: @@ -350,19 +353,19 @@ def memoized(func: Callable) -> Callable: True """ + @functools.wraps(func) def wrapper_memoized(*args, **kwargs): cache_key = args + tuple(kwargs.items()) if cache_key not in wrapper_memoized.cache: value = func(*args, **kwargs) - logger.debug( - f"Memoizing {cache_key} => {value} for {func.__name__}" - ) + logger.debug(f"Memoizing {cache_key} => {value} for {func.__name__}") wrapper_memoized.cache[cache_key] = value else: logger.debug(f"Returning memoized value for {func.__name__}") return wrapper_memoized.cache[cache_key] - wrapper_memoized.cache = dict() + + wrapper_memoized.cache = dict() # type: ignore return wrapper_memoized @@ -416,7 +419,9 @@ def retry_predicate( mdelay *= backoff retval = f(*args, **kwargs) return retval + return f_retry + return deco_retry @@ -443,7 +448,7 @@ def retry_if_false(tries: int, *, delay_sec=3.0, backoff=2.0): 3 >>> dur > 2.0 True - >>> dur < 2.2 + >>> dur < 2.3 True """ @@ -475,6 +480,7 @@ def deprecated(func): when the function is used. """ + @functools.wraps(func) def wrapper_deprecated(*args, **kwargs): msg = f"Call to deprecated function {func.__qualname__}" @@ -482,6 +488,7 @@ def deprecated(func): warnings.warn(msg, category=DeprecationWarning, stacklevel=2) print(msg, file=sys.stderr) return func(*args, **kwargs) + return wrapper_deprecated @@ -655,6 +662,7 @@ def timeout( """ if use_signals is None: import thread_utils + use_signals = thread_utils.is_current_thread_main_thread() def decorate(function): @@ -695,36 +703,19 @@ def timeout( return decorate -class non_reentrant_code(object): - def __init__(self): - self._lock = threading.RLock - self._entered = False +def synchronized(lock): + def wrap(f): + @functools.wraps(f) + def _gatekeeper(*args, **kw): + lock.acquire() + try: + return f(*args, **kw) + finally: + lock.release() - def __call__(self, f): - def _gatekeeper(*args, **kwargs): - with self._lock: - if self._entered: - return - self._entered = True - f(*args, **kwargs) - self._entered = False return _gatekeeper - -class rlocked(object): - def __init__(self): - self._lock = threading.RLock - self._entered = False - - def __call__(self, f): - def _gatekeeper(*args, **kwargs): - with self._lock: - if self._entered: - return - self._entered = True - f(*args, **kwargs) - self._entered = False - return _gatekeeper + return wrap def call_with_sample_rate(sample_rate: float) -> Callable: @@ -739,10 +730,10 @@ def call_with_sample_rate(sample_rate: float) -> Callable: if random.uniform(0, 1) < sample_rate: return f(*args, **kwargs) else: - logger.debug( - f"@call_with_sample_rate skipping a call to {f.__name__}" - ) + logger.debug(f"@call_with_sample_rate skipping a call to {f.__name__}") + return _call_with_sample_rate + return decorator @@ -751,6 +742,7 @@ def decorate_matching_methods_with(decorator, acl=None): prefix. If prefix is None (default), decorate all methods in the class. """ + def decorate_the_class(cls): for name, m in inspect.getmembers(cls, inspect.isfunction): if acl is None: @@ -759,10 +751,11 @@ def decorate_matching_methods_with(decorator, acl=None): if acl(name): setattr(cls, name, decorator(m)) return cls + return decorate_the_class if __name__ == '__main__': import doctest - doctest.testmod() + doctest.testmod()