Make rate_limited work and add doctest.
[python_utils.git] / decorator_utils.py
index 375cbad436feca20d57750b8276537c94f66f060..1e0fe18c4063b285e84d561b30b39e5b4245d001 100644 (file)
@@ -2,7 +2,6 @@
 
 """Decorators."""
 
-import datetime
 import enum
 import functools
 import inspect
@@ -15,9 +14,11 @@ import sys
 import threading
 import time
 import traceback
-from typing import Callable, Optional
+from typing import Any, Callable, Optional
 import warnings
 
+# This module is commonly used by others in here and should avoid
+# taking any unnecessary dependencies back on them.
 import exceptions
 
 
@@ -45,21 +46,95 @@ def invocation_logged(func: Callable) -> Callable:
 
     @functools.wraps(func)
     def wrapper_invocation_logged(*args, **kwargs):
-        now = datetime.datetime.now()
-        ts = now.strftime("%Y/%d/%b:%H:%M:%S%Z")
-        msg = f"[{ts}]: Entered {func.__name__}"
+        msg = f"Entered {func.__qualname__}"
         print(msg)
         logger.info(msg)
         ret = func(*args, **kwargs)
-        now = datetime.datetime.now()
-        ts = now.strftime("%Y/%d/%b:%H:%M:%S%Z")
-        msg = f"[{ts}]: Exited {func.__name__}"
+        msg = f"Exited {func.__qualname__}"
         print(msg)
         logger.info(msg)
         return ret
     return wrapper_invocation_logged
 
 
+def rate_limited(n_calls: int, *, per_period_in_seconds: float = 1.0) -> Callable:
+    """Limit invocation of a wrapped function to n calls per period.
+    Thread safe.  In testing this was relatively fair with multiple
+    threads using it though that hasn't been measured.
+
+    >>> import time
+    >>> import decorator_utils
+    >>> import thread_utils
+
+    >>> calls = 0
+
+    >>> @decorator_utils.rate_limited(1, per_period_in_seconds=1.0)
+    ... def limited(x: int):
+    ...     global calls
+    ...     calls += 1
+
+    >>> @thread_utils.background_thread
+    ... def a(stop):
+    ...     for _ in range(3):
+    ...         limited(_)
+
+    >>> @thread_utils.background_thread
+    ... def b(stop):
+    ...     for _ in range(3):
+    ...         limited(_)
+
+    >>> start = time.time()
+    >>> (t1, e1) = a()
+    >>> (t2, e2) = b()
+    >>> t1.join()
+    >>> t2.join()
+    >>> end = time.time()
+    >>> dur = end - start
+    >>> dur > 5.0
+    True
+
+    >>> calls
+    6
+
+    """
+    min_interval_seconds = per_period_in_seconds / float(n_calls)
+
+    def wrapper_rate_limited(func: Callable) -> Callable:
+        cv = threading.Condition()
+        last_invocation_timestamp = [0.0]
+
+        def may_proceed() -> float:
+            now = time.time()
+            last_invocation = last_invocation_timestamp[0]
+            if last_invocation != 0.0:
+                elapsed_since_last = now - last_invocation
+                wait_time = min_interval_seconds - elapsed_since_last
+            else:
+                wait_time = 0.0
+            logger.debug(f'@{time.time()}> wait_time = {wait_time}')
+            return wait_time
+
+        def wrapper_wrapper_rate_limited(*args, **kargs) -> Any:
+            with cv:
+                while True:
+                    if cv.wait_for(
+                        lambda: may_proceed() <= 0.0,
+                        timeout=may_proceed(),
+                    ):
+                        break
+            with cv:
+                logger.debug(f'@{time.time()}> calling it...')
+                ret = func(*args, **kargs)
+                last_invocation_timestamp[0] = time.time()
+                logger.debug(
+                    f'@{time.time()}> Last invocation <- {last_invocation_timestamp[0]}'
+                )
+                cv.notify()
+            return ret
+        return wrapper_wrapper_rate_limited
+    return wrapper_rate_limited
+
+
 def debug_args(func: Callable) -> Callable:
     """Print the function signature and return value at each call."""
 
@@ -190,7 +265,7 @@ def retry_predicate(
     tries: int,
     *,
     predicate: Callable[..., bool],
-    delay_sec: float = 3,
+    delay_sec: float = 3.0,
     backoff: float = 2.0,
 ):
     """Retries a function or method up to a certain number of times
@@ -200,10 +275,10 @@ def retry_predicate(
     delay_sec sets the initial delay period in seconds.
     backoff is a multiplied (must be >1) used to modify the delay.
     predicate is a function that will be passed the retval of the
-      decorated function and must return True to stop or False to
-      retry.
+    decorated function and must return True to stop or False to
+    retry.
     """
-    if backoff < 1:
+    if backoff < 1.0:
         msg = f"backoff must be greater than or equal to 1, got {backoff}"
         logger.critical(msg)
         raise ValueError(msg)
@@ -223,9 +298,11 @@ def retry_predicate(
         @functools.wraps(f)
         def f_retry(*args, **kwargs):
             mtries, mdelay = tries, delay_sec  # make mutable
+            logger.debug(f'deco_retry: will make up to {mtries} attempts...')
             retval = f(*args, **kwargs)
             while mtries > 0:
                 if predicate(retval) is True:
+                    logger.debug('Predicate succeeded, deco_retry is done.')
                     return retval
                 logger.debug("Predicate failed, sleeping and retrying.")
                 mtries -= 1
@@ -293,8 +370,8 @@ def thunkify(func):
                 exc[0] = True
                 exc[1] = sys.exc_info()  # (type, value, traceback)
                 msg = f"Thunkify has thrown an exception (will be raised on thunk()):\n{traceback.format_exc()}"
-                logger.warning(msg)
                 print(msg)
+                logger.warning(msg)
             finally:
                 wait_event.set()
 
@@ -336,7 +413,7 @@ def _target(queue, function, *args, **kwargs):
     """
     try:
         queue.put((True, function(*args, **kwargs)))
-    except:
+    except Exception:
         queue.put((False, sys.exc_info()[1]))
 
 
@@ -431,7 +508,6 @@ def timeout(
         use_signals = thread_utils.is_current_thread_main_thread()
 
     def decorate(function):
-
         if use_signals:
 
             def handler(signum, frame):
@@ -535,3 +611,10 @@ def decorate_matching_methods_with(decorator, acl=None):
                     setattr(cls, name, decorator(m))
         return cls
     return decorate_the_class
+
+
+if __name__ == '__main__':
+    import doctest
+    doctest.testmod()
+
+