X-Git-Url: https://wannabe.guru.org/gitweb/?a=blobdiff_plain;f=decorator_utils.py;h=cd69639448425ce3a47073c7e423ea98d6704b2b;hb=a4bf4d05230474ad14243d67ac7f8c938f670e58;hp=2817239c88c2396b0e5dcc56e7c535b8afdd99d9;hpb=09e6d10face80d98a4578ff54192b5c8bec007d7;p=python_utils.git diff --git a/decorator_utils.py b/decorator_utils.py index 2817239..cd69639 100644 --- a/decorator_utils.py +++ b/decorator_utils.py @@ -2,7 +2,6 @@ """Decorators.""" -import datetime import enum import functools import inspect @@ -15,7 +14,7 @@ import sys import threading import time import traceback -from typing import Callable, Optional +from typing import Any, Callable, Optional import warnings # This module is commonly used by others in here and should avoid @@ -27,7 +26,17 @@ logger = logging.getLogger(__name__) def timed(func: Callable) -> Callable: - """Print the runtime of the decorated function.""" + """Print the runtime of the decorated function. + + >>> @timed + ... def foo(): + ... import time + ... time.sleep(0.1) + + >>> foo() # doctest: +ELLIPSIS + Finished foo in ... + + """ @functools.wraps(func) def wrapper_timer(*args, **kwargs): @@ -35,53 +44,176 @@ def timed(func: Callable) -> Callable: value = func(*args, **kwargs) end_time = time.perf_counter() run_time = end_time - start_time - msg = f"Finished {func.__name__!r} in {run_time:.4f}s" + msg = f"Finished {func.__qualname__} in {run_time:.4f}s" print(msg) logger.info(msg) return value + return wrapper_timer def invocation_logged(func: Callable) -> Callable: - """Log the call of a function.""" + """Log the call of a function. + + >>> @invocation_logged + ... def foo(): + ... print('Hello, world.') + + >>> foo() + Entered foo + Hello, world. + Exited foo + + """ @functools.wraps(func) def wrapper_invocation_logged(*args, **kwargs): - now = datetime.datetime.now() - ts = now.strftime("%Y/%d/%b:%H:%M:%S%Z") - msg = f"[{ts}]: Entered {func.__name__}" + msg = f"Entered {func.__qualname__}" print(msg) logger.info(msg) ret = func(*args, **kwargs) - now = datetime.datetime.now() - ts = now.strftime("%Y/%d/%b:%H:%M:%S%Z") - msg = f"[{ts}]: Exited {func.__name__}" + msg = f"Exited {func.__qualname__}" print(msg) logger.info(msg) return ret + return wrapper_invocation_logged +def rate_limited(n_calls: int, *, per_period_in_seconds: float = 1.0) -> Callable: + """Limit invocation of a wrapped function to n calls per period. + Thread safe. In testing this was relatively fair with multiple + threads using it though that hasn't been measured. + + >>> import time + >>> import decorator_utils + >>> import thread_utils + + >>> calls = 0 + + >>> @decorator_utils.rate_limited(10, per_period_in_seconds=1.0) + ... def limited(x: int): + ... global calls + ... calls += 1 + + >>> @thread_utils.background_thread + ... def a(stop): + ... for _ in range(3): + ... limited(_) + + >>> @thread_utils.background_thread + ... def b(stop): + ... for _ in range(3): + ... limited(_) + + >>> start = time.time() + >>> (t1, e1) = a() + >>> (t2, e2) = b() + >>> t1.join() + >>> t2.join() + >>> end = time.time() + >>> dur = end - start + >>> dur > 0.5 + True + + >>> calls + 6 + + """ + min_interval_seconds = per_period_in_seconds / float(n_calls) + + def wrapper_rate_limited(func: Callable) -> Callable: + cv = threading.Condition() + last_invocation_timestamp = [0.0] + + def may_proceed() -> float: + now = time.time() + last_invocation = last_invocation_timestamp[0] + if last_invocation != 0.0: + elapsed_since_last = now - last_invocation + wait_time = min_interval_seconds - elapsed_since_last + else: + wait_time = 0.0 + logger.debug(f'@{time.time()}> wait_time = {wait_time}') + return wait_time + + def wrapper_wrapper_rate_limited(*args, **kargs) -> Any: + with cv: + while True: + if cv.wait_for( + lambda: may_proceed() <= 0.0, + timeout=may_proceed(), + ): + break + with cv: + logger.debug(f'@{time.time()}> calling it...') + ret = func(*args, **kargs) + last_invocation_timestamp[0] = time.time() + logger.debug( + f'@{time.time()}> Last invocation <- {last_invocation_timestamp[0]}' + ) + cv.notify() + return ret + + return wrapper_wrapper_rate_limited + + return wrapper_rate_limited + + def debug_args(func: Callable) -> Callable: - """Print the function signature and return value at each call.""" + """Print the function signature and return value at each call. + + >>> @debug_args + ... def foo(a, b, c): + ... print(a) + ... print(b) + ... print(c) + ... return (a + b, c) + + >>> foo(1, 2.0, "test") + Calling foo(1:, 2.0:, 'test':) + 1 + 2.0 + test + foo returned (3.0, 'test'): + (3.0, 'test') + """ @functools.wraps(func) def wrapper_debug_args(*args, **kwargs): args_repr = [f"{repr(a)}:{type(a)}" for a in args] kwargs_repr = [f"{k}={v!r}:{type(v)}" for k, v in kwargs.items()] signature = ", ".join(args_repr + kwargs_repr) - msg = f"Calling {func.__name__}({signature})" + msg = f"Calling {func.__qualname__}({signature})" print(msg) logger.info(msg) value = func(*args, **kwargs) - msg = f"{func.__name__!r} returned {value!r}:{type(value)}" + msg = f"{func.__qualname__} returned {value!r}:{type(value)}" + print(msg) logger.info(msg) return value + return wrapper_debug_args def debug_count_calls(func: Callable) -> Callable: - """Count function invocations and print a message befor every call.""" + """Count function invocations and print a message befor every call. + + >>> @debug_count_calls + ... def factoral(x): + ... if x == 1: + ... return 1 + ... return x * factoral(x - 1) + + >>> factoral(5) + Call #1 of 'factoral' + Call #2 of 'factoral' + Call #3 of 'factoral' + Call #4 of 'factoral' + Call #5 of 'factoral' + 120 + + """ @functools.wraps(func) def wrapper_debug_count_calls(*args, **kwargs): @@ -90,11 +222,12 @@ def debug_count_calls(func: Callable) -> Callable: print(msg) logger.info(msg) return func(*args, **kwargs) - wrapper_debug_count_calls.num_calls = 0 + + wrapper_debug_count_calls.num_calls = 0 # type: ignore return wrapper_debug_count_calls -class DelayWhen(enum.Enum): +class DelayWhen(enum.IntEnum): BEFORE_CALL = 1 AFTER_CALL = 2 BEFORE_AND_AFTER = 3 @@ -110,23 +243,33 @@ def delay( Slow down a function by inserting a delay before and/or after its invocation. + + >>> import time + + >>> @delay(seconds=1.0) + ... def foo(): + ... pass + + >>> start = time.time() + >>> foo() + >>> dur = time.time() - start + >>> dur >= 1.0 + True + """ def decorator_delay(func: Callable) -> Callable: @functools.wraps(func) def wrapper_delay(*args, **kwargs): if when & DelayWhen.BEFORE_CALL: - logger.debug( - f"@delay for {seconds}s BEFORE_CALL to {func.__name__}" - ) + logger.debug(f"@delay for {seconds}s BEFORE_CALL to {func.__name__}") time.sleep(seconds) retval = func(*args, **kwargs) if when & DelayWhen.AFTER_CALL: - logger.debug( - f"@delay for {seconds}s AFTER_CALL to {func.__name__}" - ) + logger.debug(f"@delay for {seconds}s AFTER_CALL to {func.__name__}") time.sleep(seconds) return retval + return wrapper_delay if _func is None: @@ -139,6 +282,7 @@ class _SingletonWrapper: """ A singleton wrapper class. Its instances would be created for each decorated class. + """ def __init__(self, cls): @@ -160,6 +304,19 @@ def singleton(cls): A singleton decorator. Returns a wrapper objects. A call on that object returns a single instance object of decorated class. Use the __wrapped__ attribute to access decorated class directly in unit tests + + >>> @singleton + ... class foo(object): + ... pass + + >>> a = foo() + >>> b = foo() + >>> a is b + True + + >>> id(a) == id(b) + True + """ return _SingletonWrapper(cls) @@ -170,6 +327,32 @@ def memoized(func: Callable) -> Callable: The cache here is a dict with a key based on the arguments to the call. Consider also: functools.lru_cache for a more advanced implementation. + + >>> import time + + >>> @memoized + ... def expensive(arg) -> int: + ... # Simulate something slow to compute or lookup + ... time.sleep(1.0) + ... return arg * arg + + >>> start = time.time() + >>> expensive(5) # Takes about 1 sec + 25 + + >>> expensive(3) # Also takes about 1 sec + 9 + + >>> expensive(5) # Pulls from cache, fast + 25 + + >>> expensive(3) # Pulls from cache again, fast + 9 + + >>> dur = time.time() - start + >>> dur < 3.0 + True + """ @functools.wraps(func) @@ -177,14 +360,13 @@ def memoized(func: Callable) -> Callable: cache_key = args + tuple(kwargs.items()) if cache_key not in wrapper_memoized.cache: value = func(*args, **kwargs) - logger.debug( - f"Memoizing {cache_key} => {value} for {func.__name__}" - ) + logger.debug(f"Memoizing {cache_key} => {value} for {func.__name__}") wrapper_memoized.cache[cache_key] = value else: logger.debug(f"Returning memoized value for {func.__name__}") return wrapper_memoized.cache[cache_key] - wrapper_memoized.cache = dict() + + wrapper_memoized.cache = dict() # type: ignore return wrapper_memoized @@ -192,7 +374,7 @@ def retry_predicate( tries: int, *, predicate: Callable[..., bool], - delay_sec: float = 3, + delay_sec: float = 3.0, backoff: float = 2.0, ): """Retries a function or method up to a certain number of times @@ -202,10 +384,11 @@ def retry_predicate( delay_sec sets the initial delay period in seconds. backoff is a multiplied (must be >1) used to modify the delay. predicate is a function that will be passed the retval of the - decorated function and must return True to stop or False to - retry. + decorated function and must return True to stop or False to + retry. + """ - if backoff < 1: + if backoff < 1.0: msg = f"backoff must be greater than or equal to 1, got {backoff}" logger.critical(msg) raise ValueError(msg) @@ -225,9 +408,11 @@ def retry_predicate( @functools.wraps(f) def f_retry(*args, **kwargs): mtries, mdelay = tries, delay_sec # make mutable + logger.debug(f'deco_retry: will make up to {mtries} attempts...') retval = f(*args, **kwargs) while mtries > 0: if predicate(retval) is True: + logger.debug('Predicate succeeded, deco_retry is done.') return retval logger.debug("Predicate failed, sleeping and retrying.") mtries -= 1 @@ -235,11 +420,39 @@ def retry_predicate( mdelay *= backoff retval = f(*args, **kwargs) return retval + return f_retry + return deco_retry def retry_if_false(tries: int, *, delay_sec=3.0, backoff=2.0): + """A helper for @retry_predicate that retries a decorated + function as long as it keeps returning False. + + >>> import time + + >>> counter = 0 + + >>> @retry_if_false(5, delay_sec=1.0, backoff=1.1) + ... def foo(): + ... global counter + ... counter += 1 + ... return counter >= 3 + + >>> start = time.time() + >>> foo() # fail, delay 1.0, fail, delay 1.1, succeed + True + + >>> dur = time.time() - start + >>> counter + 3 + >>> dur > 2.0 + True + >>> dur < 2.3 + True + + """ return retry_predicate( tries, predicate=lambda x: x is True, @@ -249,6 +462,11 @@ def retry_if_false(tries: int, *, delay_sec=3.0, backoff=2.0): def retry_if_none(tries: int, *, delay_sec=3.0, backoff=2.0): + """Another helper for @retry_predicate above. Retries up to N + times so long as the wrapped function returns None with a delay + between each retry and a backoff that can increase the delay. + + """ return retry_predicate( tries, predicate=lambda x: x is not None, @@ -261,13 +479,15 @@ def deprecated(func): """This is a decorator which can be used to mark functions as deprecated. It will result in a warning being emitted when the function is used. + """ @functools.wraps(func) def wrapper_deprecated(*args, **kwargs): - msg = f"Call to deprecated function {func.__name__}" + msg = f"Call to deprecated function {func.__qualname__}" logger.warning(msg) - warnings.warn(msg, category=DeprecationWarning) + warnings.warn(msg, category=DeprecationWarning, stacklevel=2) + print(msg, file=sys.stderr) return func(*args, **kwargs) return wrapper_deprecated @@ -296,7 +516,6 @@ def thunkify(func): exc[1] = sys.exc_info() # (type, value, traceback) msg = f"Thunkify has thrown an exception (will be raised on thunk()):\n{traceback.format_exc()}" logger.warning(msg) - print(msg) finally: wait_event.set() @@ -323,9 +542,9 @@ def thunkify(func): def _raise_exception(exception, error_message: Optional[str]): if error_message is None: - raise exception() + raise Exception() else: - raise exception(error_message) + raise Exception(error_message) def _target(queue, function, *args, **kwargs): @@ -338,15 +557,15 @@ def _target(queue, function, *args, **kwargs): """ try: queue.put((True, function(*args, **kwargs))) - except: + except Exception: queue.put((False, sys.exc_info()[1])) class _Timeout(object): - """Wrap a function and add a timeout (limit) attribute to it. + """Wrap a function and add a timeout to it. Instances of this class are automatically generated by the add_timeout - function defined below. + function defined below. Do not use directly. """ def __init__( @@ -423,17 +642,31 @@ def timeout( main thread). When not using signals, timeout granularity will be rounded to the nearest 0.1s. - Raises an exception when the timeout is reached. + Raises an exception when/if the timeout is reached. It is illegal to pass anything other than a function as the first parameter. The function is wrapped and returned to the caller. + + >>> @timeout(0.2) + ... def foo(delay: float): + ... time.sleep(delay) + ... return "ok" + + >>> foo(0) + 'ok' + + >>> foo(1.0) + Traceback (most recent call last): + ... + Exception: Function call timed out + """ if use_signals is None: import thread_utils + use_signals = thread_utils.is_current_thread_main_thread() def decorate(function): - if use_signals: def handler(signum, frame): @@ -471,37 +704,19 @@ def timeout( return decorate -class non_reentrant_code(object): - def __init__(self): - self._lock = threading.RLock - self._entered = False - - def __call__(self, f): - def _gatekeeper(*args, **kwargs): - with self._lock: - if self._entered: - return - self._entered = True - f(*args, **kwargs) - self._entered = False +def synchronized(lock): + def wrap(f): + @functools.wraps(f) + def _gatekeeper(*args, **kw): + lock.acquire() + try: + return f(*args, **kw) + finally: + lock.release() return _gatekeeper - -class rlocked(object): - def __init__(self): - self._lock = threading.RLock - self._entered = False - - def __call__(self, f): - def _gatekeeper(*args, **kwargs): - with self._lock: - if self._entered: - return - self._entered = True - f(*args, **kwargs) - self._entered = False - return _gatekeeper + return wrap def call_with_sample_rate(sample_rate: float) -> Callable: @@ -516,10 +731,10 @@ def call_with_sample_rate(sample_rate: float) -> Callable: if random.uniform(0, 1) < sample_rate: return f(*args, **kwargs) else: - logger.debug( - f"@call_with_sample_rate skipping a call to {f.__name__}" - ) + logger.debug(f"@call_with_sample_rate skipping a call to {f.__name__}") + return _call_with_sample_rate + return decorator @@ -528,6 +743,7 @@ def decorate_matching_methods_with(decorator, acl=None): prefix. If prefix is None (default), decorate all methods in the class. """ + def decorate_the_class(cls): for name, m in inspect.getmembers(cls, inspect.isfunction): if acl is None: @@ -536,4 +752,11 @@ def decorate_matching_methods_with(decorator, acl=None): if acl(name): setattr(cls, name, decorator(m)) return cls + return decorate_the_class + + +if __name__ == '__main__': + import doctest + + doctest.testmod()