More type annotations.
[python_utils.git] / decorator_utils.py
index 70a88d37ad0dbad37edff45aba0130dcc5a26271..cd69639448425ce3a47073c7e423ea98d6704b2b 100644 (file)
@@ -14,7 +14,7 @@ import sys
 import threading
 import time
 import traceback
 import threading
 import time
 import traceback
-from typing import Any, Callable, Optional, Tuple
+from typing import Any, Callable, Optional
 import warnings
 
 # This module is commonly used by others in here and should avoid
 import warnings
 
 # This module is commonly used by others in here and should avoid
@@ -26,7 +26,17 @@ logger = logging.getLogger(__name__)
 
 
 def timed(func: Callable) -> Callable:
 
 
 def timed(func: Callable) -> Callable:
-    """Print the runtime of the decorated function."""
+    """Print the runtime of the decorated function.
+
+    >>> @timed
+    ... def foo():
+    ...     import time
+    ...     time.sleep(0.1)
+
+    >>> foo()  # doctest: +ELLIPSIS
+    Finished foo in ...
+
+    """
 
     @functools.wraps(func)
     def wrapper_timer(*args, **kwargs):
 
     @functools.wraps(func)
     def wrapper_timer(*args, **kwargs):
@@ -34,15 +44,27 @@ def timed(func: Callable) -> Callable:
         value = func(*args, **kwargs)
         end_time = time.perf_counter()
         run_time = end_time - start_time
         value = func(*args, **kwargs)
         end_time = time.perf_counter()
         run_time = end_time - start_time
-        msg = f"Finished {func.__name__!r} in {run_time:.4f}s"
+        msg = f"Finished {func.__qualname__} in {run_time:.4f}s"
         print(msg)
         logger.info(msg)
         return value
         print(msg)
         logger.info(msg)
         return value
+
     return wrapper_timer
 
 
 def invocation_logged(func: Callable) -> Callable:
     return wrapper_timer
 
 
 def invocation_logged(func: Callable) -> Callable:
-    """Log the call of a function."""
+    """Log the call of a function.
+
+    >>> @invocation_logged
+    ... def foo():
+    ...     print('Hello, world.')
+
+    >>> foo()
+    Entered foo
+    Hello, world.
+    Exited foo
+
+    """
 
     @functools.wraps(func)
     def wrapper_invocation_logged(*args, **kwargs):
 
     @functools.wraps(func)
     def wrapper_invocation_logged(*args, **kwargs):
@@ -54,6 +76,7 @@ def invocation_logged(func: Callable) -> Callable:
         print(msg)
         logger.info(msg)
         return ret
         print(msg)
         logger.info(msg)
         return ret
+
     return wrapper_invocation_logged
 
 
     return wrapper_invocation_logged
 
 
@@ -62,6 +85,40 @@ def rate_limited(n_calls: int, *, per_period_in_seconds: float = 1.0) -> Callabl
     Thread safe.  In testing this was relatively fair with multiple
     threads using it though that hasn't been measured.
 
     Thread safe.  In testing this was relatively fair with multiple
     threads using it though that hasn't been measured.
 
+    >>> import time
+    >>> import decorator_utils
+    >>> import thread_utils
+
+    >>> calls = 0
+
+    >>> @decorator_utils.rate_limited(10, per_period_in_seconds=1.0)
+    ... def limited(x: int):
+    ...     global calls
+    ...     calls += 1
+
+    >>> @thread_utils.background_thread
+    ... def a(stop):
+    ...     for _ in range(3):
+    ...         limited(_)
+
+    >>> @thread_utils.background_thread
+    ... def b(stop):
+    ...     for _ in range(3):
+    ...         limited(_)
+
+    >>> start = time.time()
+    >>> (t1, e1) = a()
+    >>> (t2, e2) = b()
+    >>> t1.join()
+    >>> t2.join()
+    >>> end = time.time()
+    >>> dur = end - start
+    >>> dur > 0.5
+    True
+
+    >>> calls
+    6
+
     """
     min_interval_seconds = per_period_in_seconds / float(n_calls)
 
     """
     min_interval_seconds = per_period_in_seconds / float(n_calls)
 
@@ -77,45 +134,86 @@ def rate_limited(n_calls: int, *, per_period_in_seconds: float = 1.0) -> Callabl
                 wait_time = min_interval_seconds - elapsed_since_last
             else:
                 wait_time = 0.0
                 wait_time = min_interval_seconds - elapsed_since_last
             else:
                 wait_time = 0.0
+            logger.debug(f'@{time.time()}> wait_time = {wait_time}')
             return wait_time
 
         def wrapper_wrapper_rate_limited(*args, **kargs) -> Any:
             with cv:
                 while True:
             return wait_time
 
         def wrapper_wrapper_rate_limited(*args, **kargs) -> Any:
             with cv:
                 while True:
-                    cv.wait_for(
+                    if cv.wait_for(
                         lambda: may_proceed() <= 0.0,
                         timeout=may_proceed(),
                         lambda: may_proceed() <= 0.0,
                         timeout=may_proceed(),
-                    )
-                    break
-            ret = func(*args, **kargs)
+                    ):
+                        break
             with cv:
             with cv:
+                logger.debug(f'@{time.time()}> calling it...')
+                ret = func(*args, **kargs)
                 last_invocation_timestamp[0] = time.time()
                 last_invocation_timestamp[0] = time.time()
+                logger.debug(
+                    f'@{time.time()}> Last invocation <- {last_invocation_timestamp[0]}'
+                )
                 cv.notify()
             return ret
                 cv.notify()
             return ret
+
         return wrapper_wrapper_rate_limited
         return wrapper_wrapper_rate_limited
+
     return wrapper_rate_limited
 
 
 def debug_args(func: Callable) -> Callable:
     return wrapper_rate_limited
 
 
 def debug_args(func: Callable) -> Callable:
-    """Print the function signature and return value at each call."""
+    """Print the function signature and return value at each call.
+
+    >>> @debug_args
+    ... def foo(a, b, c):
+    ...     print(a)
+    ...     print(b)
+    ...     print(c)
+    ...     return (a + b, c)
+
+    >>> foo(1, 2.0, "test")
+    Calling foo(1:<class 'int'>, 2.0:<class 'float'>, 'test':<class 'str'>)
+    1
+    2.0
+    test
+    foo returned (3.0, 'test'):<class 'tuple'>
+    (3.0, 'test')
+    """
 
     @functools.wraps(func)
     def wrapper_debug_args(*args, **kwargs):
         args_repr = [f"{repr(a)}:{type(a)}" for a in args]
         kwargs_repr = [f"{k}={v!r}:{type(v)}" for k, v in kwargs.items()]
         signature = ", ".join(args_repr + kwargs_repr)
 
     @functools.wraps(func)
     def wrapper_debug_args(*args, **kwargs):
         args_repr = [f"{repr(a)}:{type(a)}" for a in args]
         kwargs_repr = [f"{k}={v!r}:{type(v)}" for k, v in kwargs.items()]
         signature = ", ".join(args_repr + kwargs_repr)
-        msg = f"Calling {func.__name__}({signature})"
+        msg = f"Calling {func.__qualname__}({signature})"
         print(msg)
         logger.info(msg)
         value = func(*args, **kwargs)
         print(msg)
         logger.info(msg)
         value = func(*args, **kwargs)
-        msg = f"{func.__name__!r} returned {value!r}:{type(value)}"
+        msg = f"{func.__qualname__} returned {value!r}:{type(value)}"
+        print(msg)
         logger.info(msg)
         return value
         logger.info(msg)
         return value
+
     return wrapper_debug_args
 
 
 def debug_count_calls(func: Callable) -> Callable:
     return wrapper_debug_args
 
 
 def debug_count_calls(func: Callable) -> Callable:
-    """Count function invocations and print a message befor every call."""
+    """Count function invocations and print a message befor every call.
+
+    >>> @debug_count_calls
+    ... def factoral(x):
+    ...     if x == 1:
+    ...         return 1
+    ...     return x * factoral(x - 1)
+
+    >>> factoral(5)
+    Call #1 of 'factoral'
+    Call #2 of 'factoral'
+    Call #3 of 'factoral'
+    Call #4 of 'factoral'
+    Call #5 of 'factoral'
+    120
+
+    """
 
     @functools.wraps(func)
     def wrapper_debug_count_calls(*args, **kwargs):
 
     @functools.wraps(func)
     def wrapper_debug_count_calls(*args, **kwargs):
@@ -124,11 +222,12 @@ def debug_count_calls(func: Callable) -> Callable:
         print(msg)
         logger.info(msg)
         return func(*args, **kwargs)
         print(msg)
         logger.info(msg)
         return func(*args, **kwargs)
-    wrapper_debug_count_calls.num_calls = 0
+
+    wrapper_debug_count_calls.num_calls = 0  # type: ignore
     return wrapper_debug_count_calls
 
 
     return wrapper_debug_count_calls
 
 
-class DelayWhen(enum.Enum):
+class DelayWhen(enum.IntEnum):
     BEFORE_CALL = 1
     AFTER_CALL = 2
     BEFORE_AND_AFTER = 3
     BEFORE_CALL = 1
     AFTER_CALL = 2
     BEFORE_AND_AFTER = 3
@@ -144,23 +243,33 @@ def delay(
 
     Slow down a function by inserting a delay before and/or after its
     invocation.
 
     Slow down a function by inserting a delay before and/or after its
     invocation.
+
+    >>> import time
+
+    >>> @delay(seconds=1.0)
+    ... def foo():
+    ...     pass
+
+    >>> start = time.time()
+    >>> foo()
+    >>> dur = time.time() - start
+    >>> dur >= 1.0
+    True
+
     """
 
     def decorator_delay(func: Callable) -> Callable:
         @functools.wraps(func)
         def wrapper_delay(*args, **kwargs):
             if when & DelayWhen.BEFORE_CALL:
     """
 
     def decorator_delay(func: Callable) -> Callable:
         @functools.wraps(func)
         def wrapper_delay(*args, **kwargs):
             if when & DelayWhen.BEFORE_CALL:
-                logger.debug(
-                    f"@delay for {seconds}s BEFORE_CALL to {func.__name__}"
-                )
+                logger.debug(f"@delay for {seconds}s BEFORE_CALL to {func.__name__}")
                 time.sleep(seconds)
             retval = func(*args, **kwargs)
             if when & DelayWhen.AFTER_CALL:
                 time.sleep(seconds)
             retval = func(*args, **kwargs)
             if when & DelayWhen.AFTER_CALL:
-                logger.debug(
-                    f"@delay for {seconds}s AFTER_CALL to {func.__name__}"
-                )
+                logger.debug(f"@delay for {seconds}s AFTER_CALL to {func.__name__}")
                 time.sleep(seconds)
             return retval
                 time.sleep(seconds)
             return retval
+
         return wrapper_delay
 
     if _func is None:
         return wrapper_delay
 
     if _func is None:
@@ -173,6 +282,7 @@ class _SingletonWrapper:
     """
     A singleton wrapper class. Its instances would be created
     for each decorated class.
     """
     A singleton wrapper class. Its instances would be created
     for each decorated class.
+
     """
 
     def __init__(self, cls):
     """
 
     def __init__(self, cls):
@@ -194,6 +304,19 @@ def singleton(cls):
     A singleton decorator. Returns a wrapper objects. A call on that object
     returns a single instance object of decorated class. Use the __wrapped__
     attribute to access decorated class directly in unit tests
     A singleton decorator. Returns a wrapper objects. A call on that object
     returns a single instance object of decorated class. Use the __wrapped__
     attribute to access decorated class directly in unit tests
+
+    >>> @singleton
+    ... class foo(object):
+    ...     pass
+
+    >>> a = foo()
+    >>> b = foo()
+    >>> a is b
+    True
+
+    >>> id(a) == id(b)
+    True
+
     """
     return _SingletonWrapper(cls)
 
     """
     return _SingletonWrapper(cls)
 
@@ -204,6 +327,32 @@ def memoized(func: Callable) -> Callable:
     The cache here is a dict with a key based on the arguments to the
     call.  Consider also: functools.lru_cache for a more advanced
     implementation.
     The cache here is a dict with a key based on the arguments to the
     call.  Consider also: functools.lru_cache for a more advanced
     implementation.
+
+    >>> import time
+
+    >>> @memoized
+    ... def expensive(arg) -> int:
+    ...     # Simulate something slow to compute or lookup
+    ...     time.sleep(1.0)
+    ...     return arg * arg
+
+    >>> start = time.time()
+    >>> expensive(5)           # Takes about 1 sec
+    25
+
+    >>> expensive(3)           # Also takes about 1 sec
+    9
+
+    >>> expensive(5)           # Pulls from cache, fast
+    25
+
+    >>> expensive(3)           # Pulls from cache again, fast
+    9
+
+    >>> dur = time.time() - start
+    >>> dur < 3.0
+    True
+
     """
 
     @functools.wraps(func)
     """
 
     @functools.wraps(func)
@@ -211,14 +360,13 @@ def memoized(func: Callable) -> Callable:
         cache_key = args + tuple(kwargs.items())
         if cache_key not in wrapper_memoized.cache:
             value = func(*args, **kwargs)
         cache_key = args + tuple(kwargs.items())
         if cache_key not in wrapper_memoized.cache:
             value = func(*args, **kwargs)
-            logger.debug(
-                f"Memoizing {cache_key} => {value} for {func.__name__}"
-            )
+            logger.debug(f"Memoizing {cache_key} => {value} for {func.__name__}")
             wrapper_memoized.cache[cache_key] = value
         else:
             logger.debug(f"Returning memoized value for {func.__name__}")
         return wrapper_memoized.cache[cache_key]
             wrapper_memoized.cache[cache_key] = value
         else:
             logger.debug(f"Returning memoized value for {func.__name__}")
         return wrapper_memoized.cache[cache_key]
-    wrapper_memoized.cache = dict()
+
+    wrapper_memoized.cache = dict()  # type: ignore
     return wrapper_memoized
 
 
     return wrapper_memoized
 
 
@@ -238,6 +386,7 @@ def retry_predicate(
     predicate is a function that will be passed the retval of the
     decorated function and must return True to stop or False to
     retry.
     predicate is a function that will be passed the retval of the
     decorated function and must return True to stop or False to
     retry.
+
     """
     if backoff < 1.0:
         msg = f"backoff must be greater than or equal to 1, got {backoff}"
     """
     if backoff < 1.0:
         msg = f"backoff must be greater than or equal to 1, got {backoff}"
@@ -271,11 +420,39 @@ def retry_predicate(
                 mdelay *= backoff
                 retval = f(*args, **kwargs)
             return retval
                 mdelay *= backoff
                 retval = f(*args, **kwargs)
             return retval
+
         return f_retry
         return f_retry
+
     return deco_retry
 
 
 def retry_if_false(tries: int, *, delay_sec=3.0, backoff=2.0):
     return deco_retry
 
 
 def retry_if_false(tries: int, *, delay_sec=3.0, backoff=2.0):
+    """A helper for @retry_predicate that retries a decorated
+    function as long as it keeps returning False.
+
+    >>> import time
+
+    >>> counter = 0
+
+    >>> @retry_if_false(5, delay_sec=1.0, backoff=1.1)
+    ... def foo():
+    ...     global counter
+    ...     counter += 1
+    ...     return counter >= 3
+
+    >>> start = time.time()
+    >>> foo()  # fail, delay 1.0, fail, delay 1.1, succeed
+    True
+
+    >>> dur = time.time() - start
+    >>> counter
+    3
+    >>> dur > 2.0
+    True
+    >>> dur < 2.3
+    True
+
+    """
     return retry_predicate(
         tries,
         predicate=lambda x: x is True,
     return retry_predicate(
         tries,
         predicate=lambda x: x is True,
@@ -285,6 +462,11 @@ def retry_if_false(tries: int, *, delay_sec=3.0, backoff=2.0):
 
 
 def retry_if_none(tries: int, *, delay_sec=3.0, backoff=2.0):
 
 
 def retry_if_none(tries: int, *, delay_sec=3.0, backoff=2.0):
+    """Another helper for @retry_predicate above.  Retries up to N
+    times so long as the wrapped function returns None with a delay
+    between each retry and a backoff that can increase the delay.
+
+    """
     return retry_predicate(
         tries,
         predicate=lambda x: x is not None,
     return retry_predicate(
         tries,
         predicate=lambda x: x is not None,
@@ -297,13 +479,15 @@ def deprecated(func):
     """This is a decorator which can be used to mark functions
     as deprecated. It will result in a warning being emitted
     when the function is used.
     """This is a decorator which can be used to mark functions
     as deprecated. It will result in a warning being emitted
     when the function is used.
+
     """
 
     @functools.wraps(func)
     def wrapper_deprecated(*args, **kwargs):
     """
 
     @functools.wraps(func)
     def wrapper_deprecated(*args, **kwargs):
-        msg = f"Call to deprecated function {func.__name__}"
+        msg = f"Call to deprecated function {func.__qualname__}"
         logger.warning(msg)
         logger.warning(msg)
-        warnings.warn(msg, category=DeprecationWarning)
+        warnings.warn(msg, category=DeprecationWarning, stacklevel=2)
+        print(msg, file=sys.stderr)
         return func(*args, **kwargs)
 
     return wrapper_deprecated
         return func(*args, **kwargs)
 
     return wrapper_deprecated
@@ -331,7 +515,6 @@ def thunkify(func):
                 exc[0] = True
                 exc[1] = sys.exc_info()  # (type, value, traceback)
                 msg = f"Thunkify has thrown an exception (will be raised on thunk()):\n{traceback.format_exc()}"
                 exc[0] = True
                 exc[1] = sys.exc_info()  # (type, value, traceback)
                 msg = f"Thunkify has thrown an exception (will be raised on thunk()):\n{traceback.format_exc()}"
-                print(msg)
                 logger.warning(msg)
             finally:
                 wait_event.set()
                 logger.warning(msg)
             finally:
                 wait_event.set()
@@ -359,9 +542,9 @@ def thunkify(func):
 
 def _raise_exception(exception, error_message: Optional[str]):
     if error_message is None:
 
 def _raise_exception(exception, error_message: Optional[str]):
     if error_message is None:
-        raise exception()
+        raise Exception()
     else:
     else:
-        raise exception(error_message)
+        raise Exception(error_message)
 
 
 def _target(queue, function, *args, **kwargs):
 
 
 def _target(queue, function, *args, **kwargs):
@@ -379,10 +562,10 @@ def _target(queue, function, *args, **kwargs):
 
 
 class _Timeout(object):
 
 
 class _Timeout(object):
-    """Wrap a function and add a timeout (limit) attribute to it.
+    """Wrap a function and add a timeout to it.
 
     Instances of this class are automatically generated by the add_timeout
 
     Instances of this class are automatically generated by the add_timeout
-    function defined below.
+    function defined below.  Do not use directly.
     """
 
     def __init__(
     """
 
     def __init__(
@@ -459,13 +642,28 @@ def timeout(
     main thread).  When not using signals, timeout granularity will be
     rounded to the nearest 0.1s.
 
     main thread).  When not using signals, timeout granularity will be
     rounded to the nearest 0.1s.
 
-    Raises an exception when the timeout is reached.
+    Raises an exception when/if the timeout is reached.
 
     It is illegal to pass anything other than a function as the first
     parameter.  The function is wrapped and returned to the caller.
 
     It is illegal to pass anything other than a function as the first
     parameter.  The function is wrapped and returned to the caller.
+
+    >>> @timeout(0.2)
+    ... def foo(delay: float):
+    ...     time.sleep(delay)
+    ...     return "ok"
+
+    >>> foo(0)
+    'ok'
+
+    >>> foo(1.0)
+    Traceback (most recent call last):
+    ...
+    Exception: Function call timed out
+
     """
     if use_signals is None:
         import thread_utils
     """
     if use_signals is None:
         import thread_utils
+
         use_signals = thread_utils.is_current_thread_main_thread()
 
     def decorate(function):
         use_signals = thread_utils.is_current_thread_main_thread()
 
     def decorate(function):
@@ -506,37 +704,19 @@ def timeout(
     return decorate
 
 
     return decorate
 
 
-class non_reentrant_code(object):
-    def __init__(self):
-        self._lock = threading.RLock
-        self._entered = False
-
-    def __call__(self, f):
-        def _gatekeeper(*args, **kwargs):
-            with self._lock:
-                if self._entered:
-                    return
-                self._entered = True
-                f(*args, **kwargs)
-                self._entered = False
+def synchronized(lock):
+    def wrap(f):
+        @functools.wraps(f)
+        def _gatekeeper(*args, **kw):
+            lock.acquire()
+            try:
+                return f(*args, **kw)
+            finally:
+                lock.release()
 
         return _gatekeeper
 
 
         return _gatekeeper
 
-
-class rlocked(object):
-    def __init__(self):
-        self._lock = threading.RLock
-        self._entered = False
-
-    def __call__(self, f):
-        def _gatekeeper(*args, **kwargs):
-            with self._lock:
-                if self._entered:
-                    return
-                self._entered = True
-                f(*args, **kwargs)
-                self._entered = False
-        return _gatekeeper
+    return wrap
 
 
 def call_with_sample_rate(sample_rate: float) -> Callable:
 
 
 def call_with_sample_rate(sample_rate: float) -> Callable:
@@ -551,10 +731,10 @@ def call_with_sample_rate(sample_rate: float) -> Callable:
             if random.uniform(0, 1) < sample_rate:
                 return f(*args, **kwargs)
             else:
             if random.uniform(0, 1) < sample_rate:
                 return f(*args, **kwargs)
             else:
-                logger.debug(
-                    f"@call_with_sample_rate skipping a call to {f.__name__}"
-                )
+                logger.debug(f"@call_with_sample_rate skipping a call to {f.__name__}")
+
         return _call_with_sample_rate
         return _call_with_sample_rate
+
     return decorator
 
 
     return decorator
 
 
@@ -563,6 +743,7 @@ def decorate_matching_methods_with(decorator, acl=None):
     prefix.  If prefix is None (default), decorate all methods in the
     class.
     """
     prefix.  If prefix is None (default), decorate all methods in the
     class.
     """
+
     def decorate_the_class(cls):
         for name, m in inspect.getmembers(cls, inspect.isfunction):
             if acl is None:
     def decorate_the_class(cls):
         for name, m in inspect.getmembers(cls, inspect.isfunction):
             if acl is None:
@@ -571,4 +752,11 @@ def decorate_matching_methods_with(decorator, acl=None):
                 if acl(name):
                     setattr(cls, name, decorator(m))
         return cls
                 if acl(name):
                     setattr(cls, name, decorator(m))
         return cls
+
     return decorate_the_class
     return decorate_the_class
+
+
+if __name__ == '__main__':
+    import doctest
+
+    doctest.testmod()