Adds some doctests to decorators.
authorScott <[email protected]>
Mon, 10 Jan 2022 23:31:11 +0000 (15:31 -0800)
committerScott <[email protected]>
Mon, 10 Jan 2022 23:31:11 +0000 (15:31 -0800)
decorator_utils.py

index 1e0fe18c4063b285e84d561b30b39e5b4245d001..9b848ed792144919b863b20c82e846bcd509bbe8 100644 (file)
@@ -26,7 +26,17 @@ logger = logging.getLogger(__name__)
 
 
 def timed(func: Callable) -> Callable:
-    """Print the runtime of the decorated function."""
+    """Print the runtime of the decorated function.
+
+    >>> @timed
+    ... def foo():
+    ...     import time
+    ...     time.sleep(0.1)
+
+    >>> foo()  # doctest: +ELLIPSIS
+    Finished foo in ...
+
+    """
 
     @functools.wraps(func)
     def wrapper_timer(*args, **kwargs):
@@ -34,7 +44,7 @@ def timed(func: Callable) -> Callable:
         value = func(*args, **kwargs)
         end_time = time.perf_counter()
         run_time = end_time - start_time
-        msg = f"Finished {func.__name__!r} in {run_time:.4f}s"
+        msg = f"Finished {func.__qualname__} in {run_time:.4f}s"
         print(msg)
         logger.info(msg)
         return value
@@ -42,7 +52,18 @@ def timed(func: Callable) -> Callable:
 
 
 def invocation_logged(func: Callable) -> Callable:
-    """Log the call of a function."""
+    """Log the call of a function.
+
+    >>> @invocation_logged
+    ... def foo():
+    ...     print('Hello, world.')
+
+    >>> foo()
+    Entered foo
+    Hello, world.
+    Exited foo
+
+    """
 
     @functools.wraps(func)
     def wrapper_invocation_logged(*args, **kwargs):
@@ -68,7 +89,7 @@ def rate_limited(n_calls: int, *, per_period_in_seconds: float = 1.0) -> Callabl
 
     >>> calls = 0
 
-    >>> @decorator_utils.rate_limited(1, per_period_in_seconds=1.0)
+    >>> @decorator_utils.rate_limited(10, per_period_in_seconds=1.0)
     ... def limited(x: int):
     ...     global calls
     ...     calls += 1
@@ -90,7 +111,7 @@ def rate_limited(n_calls: int, *, per_period_in_seconds: float = 1.0) -> Callabl
     >>> t2.join()
     >>> end = time.time()
     >>> dur = end - start
-    >>> dur > 5.0
+    >>> dur > 0.5
     True
 
     >>> calls
@@ -136,25 +157,58 @@ def rate_limited(n_calls: int, *, per_period_in_seconds: float = 1.0) -> Callabl
 
 
 def debug_args(func: Callable) -> Callable:
-    """Print the function signature and return value at each call."""
+    """Print the function signature and return value at each call.
+
+    >>> @debug_args
+    ... def foo(a, b, c):
+    ...     print(a)
+    ...     print(b)
+    ...     print(c)
+    ...     return (a + b, c)
+
+    >>> foo(1, 2.0, "test")
+    Calling foo(1:<class 'int'>, 2.0:<class 'float'>, 'test':<class 'str'>)
+    1
+    2.0
+    test
+    foo returned (3.0, 'test'):<class 'tuple'>
+    (3.0, 'test')
+    """
 
     @functools.wraps(func)
     def wrapper_debug_args(*args, **kwargs):
         args_repr = [f"{repr(a)}:{type(a)}" for a in args]
         kwargs_repr = [f"{k}={v!r}:{type(v)}" for k, v in kwargs.items()]
         signature = ", ".join(args_repr + kwargs_repr)
-        msg = f"Calling {func.__name__}({signature})"
+        msg = f"Calling {func.__qualname__}({signature})"
         print(msg)
         logger.info(msg)
         value = func(*args, **kwargs)
-        msg = f"{func.__name__!r} returned {value!r}:{type(value)}"
+        msg = f"{func.__qualname__} returned {value!r}:{type(value)}"
+        print(msg)
         logger.info(msg)
         return value
     return wrapper_debug_args
 
 
 def debug_count_calls(func: Callable) -> Callable:
-    """Count function invocations and print a message befor every call."""
+    """Count function invocations and print a message befor every call.
+
+    >>> @debug_count_calls
+    ... def factoral(x):
+    ...     if x == 1:
+    ...         return 1
+    ...     return x * factoral(x - 1)
+
+    >>> factoral(5)
+    Call #1 of 'factoral'
+    Call #2 of 'factoral'
+    Call #3 of 'factoral'
+    Call #4 of 'factoral'
+    Call #5 of 'factoral'
+    120
+
+    """
 
     @functools.wraps(func)
     def wrapper_debug_count_calls(*args, **kwargs):
@@ -167,7 +221,7 @@ def debug_count_calls(func: Callable) -> Callable:
     return wrapper_debug_count_calls
 
 
-class DelayWhen(enum.Enum):
+class DelayWhen(enum.IntEnum):
     BEFORE_CALL = 1
     AFTER_CALL = 2
     BEFORE_AND_AFTER = 3
@@ -183,8 +237,20 @@ def delay(
 
     Slow down a function by inserting a delay before and/or after its
     invocation.
-    """
 
+    >>> import time
+
+    >>> @delay(seconds=1.0)
+    ... def foo():
+    ...     pass
+
+    >>> start = time.time()
+    >>> foo()
+    >>> dur = time.time() - start
+    >>> dur >= 1.0
+    True
+
+    """
     def decorator_delay(func: Callable) -> Callable:
         @functools.wraps(func)
         def wrapper_delay(*args, **kwargs):
@@ -212,6 +278,7 @@ class _SingletonWrapper:
     """
     A singleton wrapper class. Its instances would be created
     for each decorated class.
+
     """
 
     def __init__(self, cls):
@@ -233,6 +300,19 @@ def singleton(cls):
     A singleton decorator. Returns a wrapper objects. A call on that object
     returns a single instance object of decorated class. Use the __wrapped__
     attribute to access decorated class directly in unit tests
+
+    >>> @singleton
+    ... class foo(object):
+    ...     pass
+
+    >>> a = foo()
+    >>> b = foo()
+    >>> a is b
+    True
+
+    >>> id(a) == id(b)
+    True
+
     """
     return _SingletonWrapper(cls)
 
@@ -243,8 +323,33 @@ def memoized(func: Callable) -> Callable:
     The cache here is a dict with a key based on the arguments to the
     call.  Consider also: functools.lru_cache for a more advanced
     implementation.
-    """
 
+    >>> import time
+
+    >>> @memoized
+    ... def expensive(arg) -> int:
+    ...     # Simulate something slow to compute or lookup
+    ...     time.sleep(1.0)
+    ...     return arg * arg
+
+    >>> start = time.time()
+    >>> expensive(5)           # Takes about 1 sec
+    25
+
+    >>> expensive(3)           # Also takes about 1 sec
+    9
+
+    >>> expensive(5)           # Pulls from cache, fast
+    25
+
+    >>> expensive(3)           # Pulls from cache again, fast
+    9
+
+    >>> dur = time.time() - start
+    >>> dur < 3.0
+    True
+
+    """
     @functools.wraps(func)
     def wrapper_memoized(*args, **kwargs):
         cache_key = args + tuple(kwargs.items())
@@ -277,6 +382,7 @@ def retry_predicate(
     predicate is a function that will be passed the retval of the
     decorated function and must return True to stop or False to
     retry.
+
     """
     if backoff < 1.0:
         msg = f"backoff must be greater than or equal to 1, got {backoff}"
@@ -315,6 +421,32 @@ def retry_predicate(
 
 
 def retry_if_false(tries: int, *, delay_sec=3.0, backoff=2.0):
+    """A helper for @retry_predicate that retries a decorated
+    function as long as it keeps returning False.
+
+    >>> import time
+
+    >>> counter = 0
+
+    >>> @retry_if_false(5, delay_sec=1.0, backoff=1.1)
+    ... def foo():
+    ...     global counter
+    ...     counter += 1
+    ...     return counter >= 3
+
+    >>> start = time.time()
+    >>> foo()  # fail, delay 1.0, fail, delay 1.1, succeed
+    True
+
+    >>> dur = time.time() - start
+    >>> counter
+    3
+    >>> dur > 2.0
+    True
+    >>> dur < 2.2
+    True
+
+    """
     return retry_predicate(
         tries,
         predicate=lambda x: x is True,
@@ -324,6 +456,11 @@ def retry_if_false(tries: int, *, delay_sec=3.0, backoff=2.0):
 
 
 def retry_if_none(tries: int, *, delay_sec=3.0, backoff=2.0):
+    """Another helper for @retry_predicate above.  Retries up to N
+    times so long as the wrapped function returns None with a delay
+    between each retry and a backoff that can increase the delay.
+
+    """
     return retry_predicate(
         tries,
         predicate=lambda x: x is not None,
@@ -336,15 +473,15 @@ def deprecated(func):
     """This is a decorator which can be used to mark functions
     as deprecated. It will result in a warning being emitted
     when the function is used.
-    """
 
+    """
     @functools.wraps(func)
     def wrapper_deprecated(*args, **kwargs):
-        msg = f"Call to deprecated function {func.__name__}"
+        msg = f"Call to deprecated function {func.__qualname__}"
         logger.warning(msg)
         warnings.warn(msg, category=DeprecationWarning)
+        print(msg, file=sys.stderr)
         return func(*args, **kwargs)
-
     return wrapper_deprecated
 
 
@@ -617,4 +754,3 @@ if __name__ == '__main__':
     import doctest
     doctest.testmod()
 
-