More type annotations.
[python_utils.git] / decorator_utils.py
index 4d882bed7ac4486db741b76b587b402a6dac147e..cd69639448425ce3a47073c7e423ea98d6704b2b 100644 (file)
@@ -2,9 +2,9 @@
 
 """Decorators."""
 
-import datetime
 import enum
 import functools
+import inspect
 import logging
 import math
 import multiprocessing
@@ -14,16 +14,29 @@ import sys
 import threading
 import time
 import traceback
-from typing import Callable, Optional
+from typing import Any, Callable, Optional
 import warnings
 
-import thread_utils
+# This module is commonly used by others in here and should avoid
+# taking any unnecessary dependencies back on them.
+import exceptions
+
 
 logger = logging.getLogger(__name__)
 
 
 def timed(func: Callable) -> Callable:
-    """Print the runtime of the decorated function."""
+    """Print the runtime of the decorated function.
+
+    >>> @timed
+    ... def foo():
+    ...     import time
+    ...     time.sleep(0.1)
+
+    >>> foo()  # doctest: +ELLIPSIS
+    Finished foo in ...
+
+    """
 
     @functools.wraps(func)
     def wrapper_timer(*args, **kwargs):
@@ -31,53 +44,176 @@ def timed(func: Callable) -> Callable:
         value = func(*args, **kwargs)
         end_time = time.perf_counter()
         run_time = end_time - start_time
-        msg = f"Finished {func.__name__!r} in {run_time:.4f}s"
+        msg = f"Finished {func.__qualname__} in {run_time:.4f}s"
         print(msg)
         logger.info(msg)
         return value
+
     return wrapper_timer
 
 
 def invocation_logged(func: Callable) -> Callable:
-    """Log the call of a function."""
+    """Log the call of a function.
+
+    >>> @invocation_logged
+    ... def foo():
+    ...     print('Hello, world.')
+
+    >>> foo()
+    Entered foo
+    Hello, world.
+    Exited foo
+
+    """
 
     @functools.wraps(func)
     def wrapper_invocation_logged(*args, **kwargs):
-        now = datetime.datetime.now()
-        ts = now.strftime("%Y/%d/%b:%H:%M:%S%Z")
-        msg = f"[{ts}]: Entered {func.__name__}"
+        msg = f"Entered {func.__qualname__}"
         print(msg)
         logger.info(msg)
         ret = func(*args, **kwargs)
-        now = datetime.datetime.now()
-        ts = now.strftime("%Y/%d/%b:%H:%M:%S%Z")
-        msg = f"[{ts}]: Exited {func.__name__}"
+        msg = f"Exited {func.__qualname__}"
         print(msg)
         logger.info(msg)
         return ret
+
     return wrapper_invocation_logged
 
 
+def rate_limited(n_calls: int, *, per_period_in_seconds: float = 1.0) -> Callable:
+    """Limit invocation of a wrapped function to n calls per period.
+    Thread safe.  In testing this was relatively fair with multiple
+    threads using it though that hasn't been measured.
+
+    >>> import time
+    >>> import decorator_utils
+    >>> import thread_utils
+
+    >>> calls = 0
+
+    >>> @decorator_utils.rate_limited(10, per_period_in_seconds=1.0)
+    ... def limited(x: int):
+    ...     global calls
+    ...     calls += 1
+
+    >>> @thread_utils.background_thread
+    ... def a(stop):
+    ...     for _ in range(3):
+    ...         limited(_)
+
+    >>> @thread_utils.background_thread
+    ... def b(stop):
+    ...     for _ in range(3):
+    ...         limited(_)
+
+    >>> start = time.time()
+    >>> (t1, e1) = a()
+    >>> (t2, e2) = b()
+    >>> t1.join()
+    >>> t2.join()
+    >>> end = time.time()
+    >>> dur = end - start
+    >>> dur > 0.5
+    True
+
+    >>> calls
+    6
+
+    """
+    min_interval_seconds = per_period_in_seconds / float(n_calls)
+
+    def wrapper_rate_limited(func: Callable) -> Callable:
+        cv = threading.Condition()
+        last_invocation_timestamp = [0.0]
+
+        def may_proceed() -> float:
+            now = time.time()
+            last_invocation = last_invocation_timestamp[0]
+            if last_invocation != 0.0:
+                elapsed_since_last = now - last_invocation
+                wait_time = min_interval_seconds - elapsed_since_last
+            else:
+                wait_time = 0.0
+            logger.debug(f'@{time.time()}> wait_time = {wait_time}')
+            return wait_time
+
+        def wrapper_wrapper_rate_limited(*args, **kargs) -> Any:
+            with cv:
+                while True:
+                    if cv.wait_for(
+                        lambda: may_proceed() <= 0.0,
+                        timeout=may_proceed(),
+                    ):
+                        break
+            with cv:
+                logger.debug(f'@{time.time()}> calling it...')
+                ret = func(*args, **kargs)
+                last_invocation_timestamp[0] = time.time()
+                logger.debug(
+                    f'@{time.time()}> Last invocation <- {last_invocation_timestamp[0]}'
+                )
+                cv.notify()
+            return ret
+
+        return wrapper_wrapper_rate_limited
+
+    return wrapper_rate_limited
+
+
 def debug_args(func: Callable) -> Callable:
-    """Print the function signature and return value at each call."""
+    """Print the function signature and return value at each call.
+
+    >>> @debug_args
+    ... def foo(a, b, c):
+    ...     print(a)
+    ...     print(b)
+    ...     print(c)
+    ...     return (a + b, c)
+
+    >>> foo(1, 2.0, "test")
+    Calling foo(1:<class 'int'>, 2.0:<class 'float'>, 'test':<class 'str'>)
+    1
+    2.0
+    test
+    foo returned (3.0, 'test'):<class 'tuple'>
+    (3.0, 'test')
+    """
 
     @functools.wraps(func)
     def wrapper_debug_args(*args, **kwargs):
         args_repr = [f"{repr(a)}:{type(a)}" for a in args]
         kwargs_repr = [f"{k}={v!r}:{type(v)}" for k, v in kwargs.items()]
         signature = ", ".join(args_repr + kwargs_repr)
-        msg = f"Calling {func.__name__}({signature})"
+        msg = f"Calling {func.__qualname__}({signature})"
         print(msg)
         logger.info(msg)
         value = func(*args, **kwargs)
-        msg = f"{func.__name__!r} returned {value!r}:{type(value)}"
+        msg = f"{func.__qualname__} returned {value!r}:{type(value)}"
+        print(msg)
         logger.info(msg)
         return value
+
     return wrapper_debug_args
 
 
 def debug_count_calls(func: Callable) -> Callable:
-    """Count function invocations and print a message befor every call."""
+    """Count function invocations and print a message befor every call.
+
+    >>> @debug_count_calls
+    ... def factoral(x):
+    ...     if x == 1:
+    ...         return 1
+    ...     return x * factoral(x - 1)
+
+    >>> factoral(5)
+    Call #1 of 'factoral'
+    Call #2 of 'factoral'
+    Call #3 of 'factoral'
+    Call #4 of 'factoral'
+    Call #5 of 'factoral'
+    120
+
+    """
 
     @functools.wraps(func)
     def wrapper_debug_count_calls(*args, **kwargs):
@@ -86,11 +222,12 @@ def debug_count_calls(func: Callable) -> Callable:
         print(msg)
         logger.info(msg)
         return func(*args, **kwargs)
-    wrapper_debug_count_calls.num_calls = 0
+
+    wrapper_debug_count_calls.num_calls = 0  # type: ignore
     return wrapper_debug_count_calls
 
 
-class DelayWhen(enum.Enum):
+class DelayWhen(enum.IntEnum):
     BEFORE_CALL = 1
     AFTER_CALL = 2
     BEFORE_AND_AFTER = 3
@@ -106,23 +243,33 @@ def delay(
 
     Slow down a function by inserting a delay before and/or after its
     invocation.
+
+    >>> import time
+
+    >>> @delay(seconds=1.0)
+    ... def foo():
+    ...     pass
+
+    >>> start = time.time()
+    >>> foo()
+    >>> dur = time.time() - start
+    >>> dur >= 1.0
+    True
+
     """
 
     def decorator_delay(func: Callable) -> Callable:
         @functools.wraps(func)
         def wrapper_delay(*args, **kwargs):
             if when & DelayWhen.BEFORE_CALL:
-                logger.debug(
-                    f"@delay for {seconds}s BEFORE_CALL to {func.__name__}"
-                )
+                logger.debug(f"@delay for {seconds}s BEFORE_CALL to {func.__name__}")
                 time.sleep(seconds)
             retval = func(*args, **kwargs)
             if when & DelayWhen.AFTER_CALL:
-                logger.debug(
-                    f"@delay for {seconds}s AFTER_CALL to {func.__name__}"
-                )
+                logger.debug(f"@delay for {seconds}s AFTER_CALL to {func.__name__}")
                 time.sleep(seconds)
             return retval
+
         return wrapper_delay
 
     if _func is None:
@@ -135,6 +282,7 @@ class _SingletonWrapper:
     """
     A singleton wrapper class. Its instances would be created
     for each decorated class.
+
     """
 
     def __init__(self, cls):
@@ -156,6 +304,19 @@ def singleton(cls):
     A singleton decorator. Returns a wrapper objects. A call on that object
     returns a single instance object of decorated class. Use the __wrapped__
     attribute to access decorated class directly in unit tests
+
+    >>> @singleton
+    ... class foo(object):
+    ...     pass
+
+    >>> a = foo()
+    >>> b = foo()
+    >>> a is b
+    True
+
+    >>> id(a) == id(b)
+    True
+
     """
     return _SingletonWrapper(cls)
 
@@ -166,6 +327,32 @@ def memoized(func: Callable) -> Callable:
     The cache here is a dict with a key based on the arguments to the
     call.  Consider also: functools.lru_cache for a more advanced
     implementation.
+
+    >>> import time
+
+    >>> @memoized
+    ... def expensive(arg) -> int:
+    ...     # Simulate something slow to compute or lookup
+    ...     time.sleep(1.0)
+    ...     return arg * arg
+
+    >>> start = time.time()
+    >>> expensive(5)           # Takes about 1 sec
+    25
+
+    >>> expensive(3)           # Also takes about 1 sec
+    9
+
+    >>> expensive(5)           # Pulls from cache, fast
+    25
+
+    >>> expensive(3)           # Pulls from cache again, fast
+    9
+
+    >>> dur = time.time() - start
+    >>> dur < 3.0
+    True
+
     """
 
     @functools.wraps(func)
@@ -173,14 +360,13 @@ def memoized(func: Callable) -> Callable:
         cache_key = args + tuple(kwargs.items())
         if cache_key not in wrapper_memoized.cache:
             value = func(*args, **kwargs)
-            logger.debug(
-                f"Memoizing {cache_key} => {value} for {func.__name__}"
-            )
+            logger.debug(f"Memoizing {cache_key} => {value} for {func.__name__}")
             wrapper_memoized.cache[cache_key] = value
         else:
             logger.debug(f"Returning memoized value for {func.__name__}")
         return wrapper_memoized.cache[cache_key]
-    wrapper_memoized.cache = dict()
+
+    wrapper_memoized.cache = dict()  # type: ignore
     return wrapper_memoized
 
 
@@ -188,7 +374,7 @@ def retry_predicate(
     tries: int,
     *,
     predicate: Callable[..., bool],
-    delay_sec: float = 3,
+    delay_sec: float = 3.0,
     backoff: float = 2.0,
 ):
     """Retries a function or method up to a certain number of times
@@ -198,10 +384,11 @@ def retry_predicate(
     delay_sec sets the initial delay period in seconds.
     backoff is a multiplied (must be >1) used to modify the delay.
     predicate is a function that will be passed the retval of the
-      decorated function and must return True to stop or False to
-      retry.
+    decorated function and must return True to stop or False to
+    retry.
+
     """
-    if backoff < 1:
+    if backoff < 1.0:
         msg = f"backoff must be greater than or equal to 1, got {backoff}"
         logger.critical(msg)
         raise ValueError(msg)
@@ -221,9 +408,11 @@ def retry_predicate(
         @functools.wraps(f)
         def f_retry(*args, **kwargs):
             mtries, mdelay = tries, delay_sec  # make mutable
+            logger.debug(f'deco_retry: will make up to {mtries} attempts...')
             retval = f(*args, **kwargs)
             while mtries > 0:
                 if predicate(retval) is True:
+                    logger.debug('Predicate succeeded, deco_retry is done.')
                     return retval
                 logger.debug("Predicate failed, sleeping and retrying.")
                 mtries -= 1
@@ -231,11 +420,39 @@ def retry_predicate(
                 mdelay *= backoff
                 retval = f(*args, **kwargs)
             return retval
+
         return f_retry
+
     return deco_retry
 
 
 def retry_if_false(tries: int, *, delay_sec=3.0, backoff=2.0):
+    """A helper for @retry_predicate that retries a decorated
+    function as long as it keeps returning False.
+
+    >>> import time
+
+    >>> counter = 0
+
+    >>> @retry_if_false(5, delay_sec=1.0, backoff=1.1)
+    ... def foo():
+    ...     global counter
+    ...     counter += 1
+    ...     return counter >= 3
+
+    >>> start = time.time()
+    >>> foo()  # fail, delay 1.0, fail, delay 1.1, succeed
+    True
+
+    >>> dur = time.time() - start
+    >>> counter
+    3
+    >>> dur > 2.0
+    True
+    >>> dur < 2.3
+    True
+
+    """
     return retry_predicate(
         tries,
         predicate=lambda x: x is True,
@@ -245,6 +462,11 @@ def retry_if_false(tries: int, *, delay_sec=3.0, backoff=2.0):
 
 
 def retry_if_none(tries: int, *, delay_sec=3.0, backoff=2.0):
+    """Another helper for @retry_predicate above.  Retries up to N
+    times so long as the wrapped function returns None with a delay
+    between each retry and a backoff that can increase the delay.
+
+    """
     return retry_predicate(
         tries,
         predicate=lambda x: x is not None,
@@ -257,13 +479,15 @@ def deprecated(func):
     """This is a decorator which can be used to mark functions
     as deprecated. It will result in a warning being emitted
     when the function is used.
+
     """
 
     @functools.wraps(func)
     def wrapper_deprecated(*args, **kwargs):
-        msg = f"Call to deprecated function {func.__name__}"
+        msg = f"Call to deprecated function {func.__qualname__}"
         logger.warning(msg)
-        warnings.warn(msg, category=DeprecationWarning)
+        warnings.warn(msg, category=DeprecationWarning, stacklevel=2)
+        print(msg, file=sys.stderr)
         return func(*args, **kwargs)
 
     return wrapper_deprecated
@@ -292,7 +516,6 @@ def thunkify(func):
                 exc[1] = sys.exc_info()  # (type, value, traceback)
                 msg = f"Thunkify has thrown an exception (will be raised on thunk()):\n{traceback.format_exc()}"
                 logger.warning(msg)
-                print(msg)
             finally:
                 wait_event.set()
 
@@ -317,19 +540,11 @@ def thunkify(func):
 # in https://code.google.com/p/verse-quiz/source/browse/trunk/timeout.py
 
 
-class TimeoutError(AssertionError):
-    def __init__(self, value: str = "Timed Out"):
-        self.value = value
-
-    def __str__(self):
-        return repr(self.value)
-
-
 def _raise_exception(exception, error_message: Optional[str]):
     if error_message is None:
-        raise exception()
+        raise Exception()
     else:
-        raise exception(error_message)
+        raise Exception(error_message)
 
 
 def _target(queue, function, *args, **kwargs):
@@ -342,15 +557,15 @@ def _target(queue, function, *args, **kwargs):
     """
     try:
         queue.put((True, function(*args, **kwargs)))
-    except:
+    except Exception:
         queue.put((False, sys.exc_info()[1]))
 
 
 class _Timeout(object):
-    """Wrap a function and add a timeout (limit) attribute to it.
+    """Wrap a function and add a timeout to it.
 
     Instances of this class are automatically generated by the add_timeout
-    function defined below.
+    function defined below.  Do not use directly.
     """
 
     def __init__(
@@ -417,7 +632,7 @@ class _Timeout(object):
 def timeout(
     seconds: float = 1.0,
     use_signals: Optional[bool] = None,
-    timeout_exception=TimeoutError,
+    timeout_exception=exceptions.TimeoutError,
     error_message="Function call timed out",
 ):
     """Add a timeout parameter to a function and return the function.
@@ -427,16 +642,31 @@ def timeout(
     main thread).  When not using signals, timeout granularity will be
     rounded to the nearest 0.1s.
 
-    Raises an exception when the timeout is reached.
+    Raises an exception when/if the timeout is reached.
 
     It is illegal to pass anything other than a function as the first
     parameter.  The function is wrapped and returned to the caller.
+
+    >>> @timeout(0.2)
+    ... def foo(delay: float):
+    ...     time.sleep(delay)
+    ...     return "ok"
+
+    >>> foo(0)
+    'ok'
+
+    >>> foo(1.0)
+    Traceback (most recent call last):
+    ...
+    Exception: Function call timed out
+
     """
     if use_signals is None:
+        import thread_utils
+
         use_signals = thread_utils.is_current_thread_main_thread()
 
     def decorate(function):
-
         if use_signals:
 
             def handler(signum, frame):
@@ -474,37 +704,19 @@ def timeout(
     return decorate
 
 
-class non_reentrant_code(object):
-    def __init__(self):
-        self._lock = threading.RLock
-        self._entered = False
-
-    def __call__(self, f):
-        def _gatekeeper(*args, **kwargs):
-            with self._lock:
-                if self._entered:
-                    return
-                self._entered = True
-                f(*args, **kwargs)
-                self._entered = False
+def synchronized(lock):
+    def wrap(f):
+        @functools.wraps(f)
+        def _gatekeeper(*args, **kw):
+            lock.acquire()
+            try:
+                return f(*args, **kw)
+            finally:
+                lock.release()
 
         return _gatekeeper
 
-
-class rlocked(object):
-    def __init__(self):
-        self._lock = threading.RLock
-        self._entered = False
-
-    def __call__(self, f):
-        def _gatekeeper(*args, **kwargs):
-            with self._lock:
-                if self._entered:
-                    return
-                self._entered = True
-                f(*args, **kwargs)
-                self._entered = False
-        return _gatekeeper
+    return wrap
 
 
 def call_with_sample_rate(sample_rate: float) -> Callable:
@@ -519,8 +731,32 @@ def call_with_sample_rate(sample_rate: float) -> Callable:
             if random.uniform(0, 1) < sample_rate:
                 return f(*args, **kwargs)
             else:
-                logger.debug(
-                    f"@call_with_sample_rate skipping a call to {f.__name__}"
-                )
+                logger.debug(f"@call_with_sample_rate skipping a call to {f.__name__}")
+
         return _call_with_sample_rate
+
     return decorator
+
+
+def decorate_matching_methods_with(decorator, acl=None):
+    """Apply decorator to all methods in a class whose names begin with
+    prefix.  If prefix is None (default), decorate all methods in the
+    class.
+    """
+
+    def decorate_the_class(cls):
+        for name, m in inspect.getmembers(cls, inspect.isfunction):
+            if acl is None:
+                setattr(cls, name, decorator(m))
+            else:
+                if acl(name):
+                    setattr(cls, name, decorator(m))
+        return cls
+
+    return decorate_the_class
+
+
+if __name__ == '__main__':
+    import doctest
+
+    doctest.testmod()