Make rate_limited use cvs.
[python_utils.git] / decorator_utils.py
index 480543ae97e26498b9d4314c927c5ff85b076757..70a88d37ad0dbad37edff45aba0130dcc5a26271 100644 (file)
@@ -14,7 +14,7 @@ import sys
 import threading
 import time
 import traceback
-from typing import Any, Callable, Optional
+from typing import Any, Callable, Optional, Tuple
 import warnings
 
 # This module is commonly used by others in here and should avoid
@@ -57,26 +57,40 @@ def invocation_logged(func: Callable) -> Callable:
     return wrapper_invocation_logged
 
 
-def rate_limited(n_per_second: int) -> Callable:
-    """Limit invocation of a wrapped function to n calls per second.
-    Thread safe.
+def rate_limited(n_calls: int, *, per_period_in_seconds: float = 1.0) -> Callable:
+    """Limit invocation of a wrapped function to n calls per period.
+    Thread safe.  In testing this was relatively fair with multiple
+    threads using it though that hasn't been measured.
 
     """
-    min_interval = 1.0 / float(n_per_second)
+    min_interval_seconds = per_period_in_seconds / float(n_calls)
 
     def wrapper_rate_limited(func: Callable) -> Callable:
-        last_invocation_time = [0.0]
+        cv = threading.Condition()
+        last_invocation_timestamp = [0.0]
+
+        def may_proceed() -> float:
+            now = time.time()
+            last_invocation = last_invocation_timestamp[0]
+            if last_invocation != 0.0:
+                elapsed_since_last = now - last_invocation
+                wait_time = min_interval_seconds - elapsed_since_last
+            else:
+                wait_time = 0.0
+            return wait_time
 
         def wrapper_wrapper_rate_limited(*args, **kargs) -> Any:
-            while True:
-                elapsed = time.clock_gettime(0) - last_invocation_time[0]
-                wait_time = min_interval - elapsed
-                if wait_time > 0.0:
-                    time.sleep(wait_time)
-                else:
+            with cv:
+                while True:
+                    cv.wait_for(
+                        lambda: may_proceed() <= 0.0,
+                        timeout=may_proceed(),
+                    )
                     break
             ret = func(*args, **kargs)
-            last_invocation_time[0] = time.clock_gettime(0)
+            with cv:
+                last_invocation_timestamp[0] = time.time()
+                cv.notify()
             return ret
         return wrapper_wrapper_rate_limited
     return wrapper_rate_limited