Make rate_limited work and add doctest.
authorScott <[email protected]>
Mon, 10 Jan 2022 22:57:17 +0000 (14:57 -0800)
committerScott <[email protected]>
Mon, 10 Jan 2022 22:57:17 +0000 (14:57 -0800)
decorator_utils.py

index 70a88d37ad0dbad37edff45aba0130dcc5a26271..1e0fe18c4063b285e84d561b30b39e5b4245d001 100644 (file)
@@ -14,7 +14,7 @@ import sys
 import threading
 import time
 import traceback
-from typing import Any, Callable, Optional, Tuple
+from typing import Any, Callable, Optional
 import warnings
 
 # This module is commonly used by others in here and should avoid
@@ -62,6 +62,40 @@ def rate_limited(n_calls: int, *, per_period_in_seconds: float = 1.0) -> Callabl
     Thread safe.  In testing this was relatively fair with multiple
     threads using it though that hasn't been measured.
 
+    >>> import time
+    >>> import decorator_utils
+    >>> import thread_utils
+
+    >>> calls = 0
+
+    >>> @decorator_utils.rate_limited(1, per_period_in_seconds=1.0)
+    ... def limited(x: int):
+    ...     global calls
+    ...     calls += 1
+
+    >>> @thread_utils.background_thread
+    ... def a(stop):
+    ...     for _ in range(3):
+    ...         limited(_)
+
+    >>> @thread_utils.background_thread
+    ... def b(stop):
+    ...     for _ in range(3):
+    ...         limited(_)
+
+    >>> start = time.time()
+    >>> (t1, e1) = a()
+    >>> (t2, e2) = b()
+    >>> t1.join()
+    >>> t2.join()
+    >>> end = time.time()
+    >>> dur = end - start
+    >>> dur > 5.0
+    True
+
+    >>> calls
+    6
+
     """
     min_interval_seconds = per_period_in_seconds / float(n_calls)
 
@@ -77,19 +111,24 @@ def rate_limited(n_calls: int, *, per_period_in_seconds: float = 1.0) -> Callabl
                 wait_time = min_interval_seconds - elapsed_since_last
             else:
                 wait_time = 0.0
+            logger.debug(f'@{time.time()}> wait_time = {wait_time}')
             return wait_time
 
         def wrapper_wrapper_rate_limited(*args, **kargs) -> Any:
             with cv:
                 while True:
-                    cv.wait_for(
+                    if cv.wait_for(
                         lambda: may_proceed() <= 0.0,
                         timeout=may_proceed(),
-                    )
-                    break
-            ret = func(*args, **kargs)
+                    ):
+                        break
             with cv:
+                logger.debug(f'@{time.time()}> calling it...')
+                ret = func(*args, **kargs)
                 last_invocation_timestamp[0] = time.time()
+                logger.debug(
+                    f'@{time.time()}> Last invocation <- {last_invocation_timestamp[0]}'
+                )
                 cv.notify()
             return ret
         return wrapper_wrapper_rate_limited
@@ -572,3 +611,10 @@ def decorate_matching_methods_with(decorator, acl=None):
                     setattr(cls, name, decorator(m))
         return cls
     return decorate_the_class
+
+
+if __name__ == '__main__':
+    import doctest
+    doctest.testmod()
+
+