Make rate_limited work and add doctest.
[python_utils.git] / decorator_utils.py
index 70a88d37ad0dbad37edff45aba0130dcc5a26271..1e0fe18c4063b285e84d561b30b39e5b4245d001 100644 (file)
@@ -14,7 +14,7 @@ import sys
 import threading
 import time
 import traceback
 import threading
 import time
 import traceback
-from typing import Any, Callable, Optional, Tuple
+from typing import Any, Callable, Optional
 import warnings
 
 # This module is commonly used by others in here and should avoid
 import warnings
 
 # This module is commonly used by others in here and should avoid
@@ -62,6 +62,40 @@ def rate_limited(n_calls: int, *, per_period_in_seconds: float = 1.0) -> Callabl
     Thread safe.  In testing this was relatively fair with multiple
     threads using it though that hasn't been measured.
 
     Thread safe.  In testing this was relatively fair with multiple
     threads using it though that hasn't been measured.
 
+    >>> import time
+    >>> import decorator_utils
+    >>> import thread_utils
+
+    >>> calls = 0
+
+    >>> @decorator_utils.rate_limited(1, per_period_in_seconds=1.0)
+    ... def limited(x: int):
+    ...     global calls
+    ...     calls += 1
+
+    >>> @thread_utils.background_thread
+    ... def a(stop):
+    ...     for _ in range(3):
+    ...         limited(_)
+
+    >>> @thread_utils.background_thread
+    ... def b(stop):
+    ...     for _ in range(3):
+    ...         limited(_)
+
+    >>> start = time.time()
+    >>> (t1, e1) = a()
+    >>> (t2, e2) = b()
+    >>> t1.join()
+    >>> t2.join()
+    >>> end = time.time()
+    >>> dur = end - start
+    >>> dur > 5.0
+    True
+
+    >>> calls
+    6
+
     """
     min_interval_seconds = per_period_in_seconds / float(n_calls)
 
     """
     min_interval_seconds = per_period_in_seconds / float(n_calls)
 
@@ -77,19 +111,24 @@ def rate_limited(n_calls: int, *, per_period_in_seconds: float = 1.0) -> Callabl
                 wait_time = min_interval_seconds - elapsed_since_last
             else:
                 wait_time = 0.0
                 wait_time = min_interval_seconds - elapsed_since_last
             else:
                 wait_time = 0.0
+            logger.debug(f'@{time.time()}> wait_time = {wait_time}')
             return wait_time
 
         def wrapper_wrapper_rate_limited(*args, **kargs) -> Any:
             with cv:
                 while True:
             return wait_time
 
         def wrapper_wrapper_rate_limited(*args, **kargs) -> Any:
             with cv:
                 while True:
-                    cv.wait_for(
+                    if cv.wait_for(
                         lambda: may_proceed() <= 0.0,
                         timeout=may_proceed(),
                         lambda: may_proceed() <= 0.0,
                         timeout=may_proceed(),
-                    )
-                    break
-            ret = func(*args, **kargs)
+                    ):
+                        break
             with cv:
             with cv:
+                logger.debug(f'@{time.time()}> calling it...')
+                ret = func(*args, **kargs)
                 last_invocation_timestamp[0] = time.time()
                 last_invocation_timestamp[0] = time.time()
+                logger.debug(
+                    f'@{time.time()}> Last invocation <- {last_invocation_timestamp[0]}'
+                )
                 cv.notify()
             return ret
         return wrapper_wrapper_rate_limited
                 cv.notify()
             return ret
         return wrapper_wrapper_rate_limited
@@ -572,3 +611,10 @@ def decorate_matching_methods_with(decorator, acl=None):
                     setattr(cls, name, decorator(m))
         return cls
     return decorate_the_class
                     setattr(cls, name, decorator(m))
         return cls
     return decorate_the_class
+
+
+if __name__ == '__main__':
+    import doctest
+    doctest.testmod()
+
+