X-Git-Url: https://wannabe.guru.org/gitweb/?a=blobdiff_plain;f=decorator_utils.py;h=daae64e75348e973dc8a27cf387faf7f404ef2b2;hb=36fea7f15ed17150691b5b3ead75450e575229ef;hp=07ad881f63a613de38d82d9a54babce92127b1b5;hpb=b454ad295eb3024a238d32bf2aef1ebc3c496b44;p=python_utils.git diff --git a/decorator_utils.py b/decorator_utils.py index 07ad881..daae64e 100644 --- a/decorator_utils.py +++ b/decorator_utils.py @@ -48,6 +48,7 @@ def timed(func: Callable) -> Callable: print(msg) logger.info(msg) return value + return wrapper_timer @@ -75,10 +76,13 @@ def invocation_logged(func: Callable) -> Callable: print(msg) logger.info(msg) return ret + return wrapper_invocation_logged -def rate_limited(n_calls: int, *, per_period_in_seconds: float = 1.0) -> Callable: +def rate_limited( + n_calls: int, *, per_period_in_seconds: float = 1.0 +) -> Callable: """Limit invocation of a wrapped function to n calls per period. Thread safe. In testing this was relatively fair with multiple threads using it though that hasn't been measured. @@ -152,7 +156,9 @@ def rate_limited(n_calls: int, *, per_period_in_seconds: float = 1.0) -> Callabl ) cv.notify() return ret + return wrapper_wrapper_rate_limited + return wrapper_rate_limited @@ -188,6 +194,7 @@ def debug_args(func: Callable) -> Callable: print(msg) logger.info(msg) return value + return wrapper_debug_args @@ -213,10 +220,13 @@ def debug_count_calls(func: Callable) -> Callable: @functools.wraps(func) def wrapper_debug_count_calls(*args, **kwargs): wrapper_debug_count_calls.num_calls += 1 - msg = f"Call #{wrapper_debug_count_calls.num_calls} of {func.__name__!r}" + msg = ( + f"Call #{wrapper_debug_count_calls.num_calls} of {func.__name__!r}" + ) print(msg) logger.info(msg) return func(*args, **kwargs) + wrapper_debug_count_calls.num_calls = 0 return wrapper_debug_count_calls @@ -251,6 +261,7 @@ def delay( True """ + def decorator_delay(func: Callable) -> Callable: @functools.wraps(func) def wrapper_delay(*args, **kwargs): @@ -266,6 +277,7 @@ def delay( ) time.sleep(seconds) return retval + return wrapper_delay if _func is None: @@ -350,6 +362,7 @@ def memoized(func: Callable) -> Callable: True """ + @functools.wraps(func) def wrapper_memoized(*args, **kwargs): cache_key = args + tuple(kwargs.items()) @@ -362,6 +375,7 @@ def memoized(func: Callable) -> Callable: else: logger.debug(f"Returning memoized value for {func.__name__}") return wrapper_memoized.cache[cache_key] + wrapper_memoized.cache = dict() return wrapper_memoized @@ -416,7 +430,9 @@ def retry_predicate( mdelay *= backoff retval = f(*args, **kwargs) return retval + return f_retry + return deco_retry @@ -443,7 +459,7 @@ def retry_if_false(tries: int, *, delay_sec=3.0, backoff=2.0): 3 >>> dur > 2.0 True - >>> dur < 2.2 + >>> dur < 2.3 True """ @@ -475,6 +491,7 @@ def deprecated(func): when the function is used. """ + @functools.wraps(func) def wrapper_deprecated(*args, **kwargs): msg = f"Call to deprecated function {func.__qualname__}" @@ -482,6 +499,7 @@ def deprecated(func): warnings.warn(msg, category=DeprecationWarning, stacklevel=2) print(msg, file=sys.stderr) return func(*args, **kwargs) + return wrapper_deprecated @@ -508,7 +526,6 @@ def thunkify(func): exc[1] = sys.exc_info() # (type, value, traceback) msg = f"Thunkify has thrown an exception (will be raised on thunk()):\n{traceback.format_exc()}" logger.warning(msg) - warnings.warn(msg) finally: wait_event.set() @@ -656,6 +673,7 @@ def timeout( """ if use_signals is None: import thread_utils + use_signals = thread_utils.is_current_thread_main_thread() def decorate(function): @@ -709,6 +727,7 @@ class non_reentrant_code(object): self._entered = True f(*args, **kwargs) self._entered = False + return _gatekeeper @@ -725,6 +744,7 @@ class rlocked(object): self._entered = True f(*args, **kwargs) self._entered = False + return _gatekeeper @@ -743,7 +763,9 @@ def call_with_sample_rate(sample_rate: float) -> Callable: logger.debug( f"@call_with_sample_rate skipping a call to {f.__name__}" ) + return _call_with_sample_rate + return decorator @@ -752,6 +774,7 @@ def decorate_matching_methods_with(decorator, acl=None): prefix. If prefix is None (default), decorate all methods in the class. """ + def decorate_the_class(cls): for name, m in inspect.getmembers(cls, inspect.isfunction): if acl is None: @@ -760,10 +783,11 @@ def decorate_matching_methods_with(decorator, acl=None): if acl(name): setattr(cls, name, decorator(m)) return cls + return decorate_the_class if __name__ == '__main__': import doctest - doctest.testmod() + doctest.testmod()