From d742c4a0f3a177e3ab55a9eb2d30e0e37af2f044 Mon Sep 17 00:00:00 2001 From: Scott Date: Mon, 10 Jan 2022 15:31:11 -0800 Subject: [PATCH] Adds some doctests to decorators. --- decorator_utils.py | 168 ++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 152 insertions(+), 16 deletions(-) diff --git a/decorator_utils.py b/decorator_utils.py index 1e0fe18..9b848ed 100644 --- a/decorator_utils.py +++ b/decorator_utils.py @@ -26,7 +26,17 @@ logger = logging.getLogger(__name__) def timed(func: Callable) -> Callable: - """Print the runtime of the decorated function.""" + """Print the runtime of the decorated function. + + >>> @timed + ... def foo(): + ... import time + ... time.sleep(0.1) + + >>> foo() # doctest: +ELLIPSIS + Finished foo in ... + + """ @functools.wraps(func) def wrapper_timer(*args, **kwargs): @@ -34,7 +44,7 @@ def timed(func: Callable) -> Callable: value = func(*args, **kwargs) end_time = time.perf_counter() run_time = end_time - start_time - msg = f"Finished {func.__name__!r} in {run_time:.4f}s" + msg = f"Finished {func.__qualname__} in {run_time:.4f}s" print(msg) logger.info(msg) return value @@ -42,7 +52,18 @@ def timed(func: Callable) -> Callable: def invocation_logged(func: Callable) -> Callable: - """Log the call of a function.""" + """Log the call of a function. + + >>> @invocation_logged + ... def foo(): + ... print('Hello, world.') + + >>> foo() + Entered foo + Hello, world. + Exited foo + + """ @functools.wraps(func) def wrapper_invocation_logged(*args, **kwargs): @@ -68,7 +89,7 @@ def rate_limited(n_calls: int, *, per_period_in_seconds: float = 1.0) -> Callabl >>> calls = 0 - >>> @decorator_utils.rate_limited(1, per_period_in_seconds=1.0) + >>> @decorator_utils.rate_limited(10, per_period_in_seconds=1.0) ... def limited(x: int): ... global calls ... calls += 1 @@ -90,7 +111,7 @@ def rate_limited(n_calls: int, *, per_period_in_seconds: float = 1.0) -> Callabl >>> t2.join() >>> end = time.time() >>> dur = end - start - >>> dur > 5.0 + >>> dur > 0.5 True >>> calls @@ -136,25 +157,58 @@ def rate_limited(n_calls: int, *, per_period_in_seconds: float = 1.0) -> Callabl def debug_args(func: Callable) -> Callable: - """Print the function signature and return value at each call.""" + """Print the function signature and return value at each call. + + >>> @debug_args + ... def foo(a, b, c): + ... print(a) + ... print(b) + ... print(c) + ... return (a + b, c) + + >>> foo(1, 2.0, "test") + Calling foo(1:, 2.0:, 'test':) + 1 + 2.0 + test + foo returned (3.0, 'test'): + (3.0, 'test') + """ @functools.wraps(func) def wrapper_debug_args(*args, **kwargs): args_repr = [f"{repr(a)}:{type(a)}" for a in args] kwargs_repr = [f"{k}={v!r}:{type(v)}" for k, v in kwargs.items()] signature = ", ".join(args_repr + kwargs_repr) - msg = f"Calling {func.__name__}({signature})" + msg = f"Calling {func.__qualname__}({signature})" print(msg) logger.info(msg) value = func(*args, **kwargs) - msg = f"{func.__name__!r} returned {value!r}:{type(value)}" + msg = f"{func.__qualname__} returned {value!r}:{type(value)}" + print(msg) logger.info(msg) return value return wrapper_debug_args def debug_count_calls(func: Callable) -> Callable: - """Count function invocations and print a message befor every call.""" + """Count function invocations and print a message befor every call. + + >>> @debug_count_calls + ... def factoral(x): + ... if x == 1: + ... return 1 + ... return x * factoral(x - 1) + + >>> factoral(5) + Call #1 of 'factoral' + Call #2 of 'factoral' + Call #3 of 'factoral' + Call #4 of 'factoral' + Call #5 of 'factoral' + 120 + + """ @functools.wraps(func) def wrapper_debug_count_calls(*args, **kwargs): @@ -167,7 +221,7 @@ def debug_count_calls(func: Callable) -> Callable: return wrapper_debug_count_calls -class DelayWhen(enum.Enum): +class DelayWhen(enum.IntEnum): BEFORE_CALL = 1 AFTER_CALL = 2 BEFORE_AND_AFTER = 3 @@ -183,8 +237,20 @@ def delay( Slow down a function by inserting a delay before and/or after its invocation. - """ + >>> import time + + >>> @delay(seconds=1.0) + ... def foo(): + ... pass + + >>> start = time.time() + >>> foo() + >>> dur = time.time() - start + >>> dur >= 1.0 + True + + """ def decorator_delay(func: Callable) -> Callable: @functools.wraps(func) def wrapper_delay(*args, **kwargs): @@ -212,6 +278,7 @@ class _SingletonWrapper: """ A singleton wrapper class. Its instances would be created for each decorated class. + """ def __init__(self, cls): @@ -233,6 +300,19 @@ def singleton(cls): A singleton decorator. Returns a wrapper objects. A call on that object returns a single instance object of decorated class. Use the __wrapped__ attribute to access decorated class directly in unit tests + + >>> @singleton + ... class foo(object): + ... pass + + >>> a = foo() + >>> b = foo() + >>> a is b + True + + >>> id(a) == id(b) + True + """ return _SingletonWrapper(cls) @@ -243,8 +323,33 @@ def memoized(func: Callable) -> Callable: The cache here is a dict with a key based on the arguments to the call. Consider also: functools.lru_cache for a more advanced implementation. - """ + >>> import time + + >>> @memoized + ... def expensive(arg) -> int: + ... # Simulate something slow to compute or lookup + ... time.sleep(1.0) + ... return arg * arg + + >>> start = time.time() + >>> expensive(5) # Takes about 1 sec + 25 + + >>> expensive(3) # Also takes about 1 sec + 9 + + >>> expensive(5) # Pulls from cache, fast + 25 + + >>> expensive(3) # Pulls from cache again, fast + 9 + + >>> dur = time.time() - start + >>> dur < 3.0 + True + + """ @functools.wraps(func) def wrapper_memoized(*args, **kwargs): cache_key = args + tuple(kwargs.items()) @@ -277,6 +382,7 @@ def retry_predicate( predicate is a function that will be passed the retval of the decorated function and must return True to stop or False to retry. + """ if backoff < 1.0: msg = f"backoff must be greater than or equal to 1, got {backoff}" @@ -315,6 +421,32 @@ def retry_predicate( def retry_if_false(tries: int, *, delay_sec=3.0, backoff=2.0): + """A helper for @retry_predicate that retries a decorated + function as long as it keeps returning False. + + >>> import time + + >>> counter = 0 + + >>> @retry_if_false(5, delay_sec=1.0, backoff=1.1) + ... def foo(): + ... global counter + ... counter += 1 + ... return counter >= 3 + + >>> start = time.time() + >>> foo() # fail, delay 1.0, fail, delay 1.1, succeed + True + + >>> dur = time.time() - start + >>> counter + 3 + >>> dur > 2.0 + True + >>> dur < 2.2 + True + + """ return retry_predicate( tries, predicate=lambda x: x is True, @@ -324,6 +456,11 @@ def retry_if_false(tries: int, *, delay_sec=3.0, backoff=2.0): def retry_if_none(tries: int, *, delay_sec=3.0, backoff=2.0): + """Another helper for @retry_predicate above. Retries up to N + times so long as the wrapped function returns None with a delay + between each retry and a backoff that can increase the delay. + + """ return retry_predicate( tries, predicate=lambda x: x is not None, @@ -336,15 +473,15 @@ def deprecated(func): """This is a decorator which can be used to mark functions as deprecated. It will result in a warning being emitted when the function is used. - """ + """ @functools.wraps(func) def wrapper_deprecated(*args, **kwargs): - msg = f"Call to deprecated function {func.__name__}" + msg = f"Call to deprecated function {func.__qualname__}" logger.warning(msg) warnings.warn(msg, category=DeprecationWarning) + print(msg, file=sys.stderr) return func(*args, **kwargs) - return wrapper_deprecated @@ -617,4 +754,3 @@ if __name__ == '__main__': import doctest doctest.testmod() - -- 2.47.1