X-Git-Url: https://wannabe.guru.org/gitweb/?a=blobdiff_plain;f=decorator_utils.py;h=cd69639448425ce3a47073c7e423ea98d6704b2b;hb=a4bf4d05230474ad14243d67ac7f8c938f670e58;hp=07ad881f63a613de38d82d9a54babce92127b1b5;hpb=b454ad295eb3024a238d32bf2aef1ebc3c496b44;p=python_utils.git diff --git a/decorator_utils.py b/decorator_utils.py index 07ad881..cd69639 100644 --- a/decorator_utils.py +++ b/decorator_utils.py @@ -48,6 +48,7 @@ def timed(func: Callable) -> Callable: print(msg) logger.info(msg) return value + return wrapper_timer @@ -75,6 +76,7 @@ def invocation_logged(func: Callable) -> Callable: print(msg) logger.info(msg) return ret + return wrapper_invocation_logged @@ -152,7 +154,9 @@ def rate_limited(n_calls: int, *, per_period_in_seconds: float = 1.0) -> Callabl ) cv.notify() return ret + return wrapper_wrapper_rate_limited + return wrapper_rate_limited @@ -188,6 +192,7 @@ def debug_args(func: Callable) -> Callable: print(msg) logger.info(msg) return value + return wrapper_debug_args @@ -217,7 +222,8 @@ def debug_count_calls(func: Callable) -> Callable: print(msg) logger.info(msg) return func(*args, **kwargs) - wrapper_debug_count_calls.num_calls = 0 + + wrapper_debug_count_calls.num_calls = 0 # type: ignore return wrapper_debug_count_calls @@ -251,21 +257,19 @@ def delay( True """ + def decorator_delay(func: Callable) -> Callable: @functools.wraps(func) def wrapper_delay(*args, **kwargs): if when & DelayWhen.BEFORE_CALL: - logger.debug( - f"@delay for {seconds}s BEFORE_CALL to {func.__name__}" - ) + logger.debug(f"@delay for {seconds}s BEFORE_CALL to {func.__name__}") time.sleep(seconds) retval = func(*args, **kwargs) if when & DelayWhen.AFTER_CALL: - logger.debug( - f"@delay for {seconds}s AFTER_CALL to {func.__name__}" - ) + logger.debug(f"@delay for {seconds}s AFTER_CALL to {func.__name__}") time.sleep(seconds) return retval + return wrapper_delay if _func is None: @@ -350,19 +354,19 @@ def memoized(func: Callable) -> Callable: True """ + @functools.wraps(func) def wrapper_memoized(*args, **kwargs): cache_key = args + tuple(kwargs.items()) if cache_key not in wrapper_memoized.cache: value = func(*args, **kwargs) - logger.debug( - f"Memoizing {cache_key} => {value} for {func.__name__}" - ) + logger.debug(f"Memoizing {cache_key} => {value} for {func.__name__}") wrapper_memoized.cache[cache_key] = value else: logger.debug(f"Returning memoized value for {func.__name__}") return wrapper_memoized.cache[cache_key] - wrapper_memoized.cache = dict() + + wrapper_memoized.cache = dict() # type: ignore return wrapper_memoized @@ -416,7 +420,9 @@ def retry_predicate( mdelay *= backoff retval = f(*args, **kwargs) return retval + return f_retry + return deco_retry @@ -443,7 +449,7 @@ def retry_if_false(tries: int, *, delay_sec=3.0, backoff=2.0): 3 >>> dur > 2.0 True - >>> dur < 2.2 + >>> dur < 2.3 True """ @@ -475,6 +481,7 @@ def deprecated(func): when the function is used. """ + @functools.wraps(func) def wrapper_deprecated(*args, **kwargs): msg = f"Call to deprecated function {func.__qualname__}" @@ -482,6 +489,7 @@ def deprecated(func): warnings.warn(msg, category=DeprecationWarning, stacklevel=2) print(msg, file=sys.stderr) return func(*args, **kwargs) + return wrapper_deprecated @@ -508,7 +516,6 @@ def thunkify(func): exc[1] = sys.exc_info() # (type, value, traceback) msg = f"Thunkify has thrown an exception (will be raised on thunk()):\n{traceback.format_exc()}" logger.warning(msg) - warnings.warn(msg) finally: wait_event.set() @@ -656,6 +663,7 @@ def timeout( """ if use_signals is None: import thread_utils + use_signals = thread_utils.is_current_thread_main_thread() def decorate(function): @@ -696,36 +704,19 @@ def timeout( return decorate -class non_reentrant_code(object): - def __init__(self): - self._lock = threading.RLock - self._entered = False +def synchronized(lock): + def wrap(f): + @functools.wraps(f) + def _gatekeeper(*args, **kw): + lock.acquire() + try: + return f(*args, **kw) + finally: + lock.release() - def __call__(self, f): - def _gatekeeper(*args, **kwargs): - with self._lock: - if self._entered: - return - self._entered = True - f(*args, **kwargs) - self._entered = False return _gatekeeper - -class rlocked(object): - def __init__(self): - self._lock = threading.RLock - self._entered = False - - def __call__(self, f): - def _gatekeeper(*args, **kwargs): - with self._lock: - if self._entered: - return - self._entered = True - f(*args, **kwargs) - self._entered = False - return _gatekeeper + return wrap def call_with_sample_rate(sample_rate: float) -> Callable: @@ -740,10 +731,10 @@ def call_with_sample_rate(sample_rate: float) -> Callable: if random.uniform(0, 1) < sample_rate: return f(*args, **kwargs) else: - logger.debug( - f"@call_with_sample_rate skipping a call to {f.__name__}" - ) + logger.debug(f"@call_with_sample_rate skipping a call to {f.__name__}") + return _call_with_sample_rate + return decorator @@ -752,6 +743,7 @@ def decorate_matching_methods_with(decorator, acl=None): prefix. If prefix is None (default), decorate all methods in the class. """ + def decorate_the_class(cls): for name, m in inspect.getmembers(cls, inspect.isfunction): if acl is None: @@ -760,10 +752,11 @@ def decorate_matching_methods_with(decorator, acl=None): if acl(name): setattr(cls, name, decorator(m)) return cls + return decorate_the_class if __name__ == '__main__': import doctest - doctest.testmod() + doctest.testmod()