More type annotations.
[python_utils.git] / decorator_utils.py
index daae64e75348e973dc8a27cf387faf7f404ef2b2..cd69639448425ce3a47073c7e423ea98d6704b2b 100644 (file)
@@ -80,9 +80,7 @@ def invocation_logged(func: Callable) -> Callable:
     return wrapper_invocation_logged
 
 
-def rate_limited(
-    n_calls: int, *, per_period_in_seconds: float = 1.0
-) -> Callable:
+def rate_limited(n_calls: int, *, per_period_in_seconds: float = 1.0) -> Callable:
     """Limit invocation of a wrapped function to n calls per period.
     Thread safe.  In testing this was relatively fair with multiple
     threads using it though that hasn't been measured.
@@ -220,14 +218,12 @@ def debug_count_calls(func: Callable) -> Callable:
     @functools.wraps(func)
     def wrapper_debug_count_calls(*args, **kwargs):
         wrapper_debug_count_calls.num_calls += 1
-        msg = (
-            f"Call #{wrapper_debug_count_calls.num_calls} of {func.__name__!r}"
-        )
+        msg = f"Call #{wrapper_debug_count_calls.num_calls} of {func.__name__!r}"
         print(msg)
         logger.info(msg)
         return func(*args, **kwargs)
 
-    wrapper_debug_count_calls.num_calls = 0
+    wrapper_debug_count_calls.num_calls = 0  # type: ignore
     return wrapper_debug_count_calls
 
 
@@ -266,15 +262,11 @@ def delay(
         @functools.wraps(func)
         def wrapper_delay(*args, **kwargs):
             if when & DelayWhen.BEFORE_CALL:
-                logger.debug(
-                    f"@delay for {seconds}s BEFORE_CALL to {func.__name__}"
-                )
+                logger.debug(f"@delay for {seconds}s BEFORE_CALL to {func.__name__}")
                 time.sleep(seconds)
             retval = func(*args, **kwargs)
             if when & DelayWhen.AFTER_CALL:
-                logger.debug(
-                    f"@delay for {seconds}s AFTER_CALL to {func.__name__}"
-                )
+                logger.debug(f"@delay for {seconds}s AFTER_CALL to {func.__name__}")
                 time.sleep(seconds)
             return retval
 
@@ -368,15 +360,13 @@ def memoized(func: Callable) -> Callable:
         cache_key = args + tuple(kwargs.items())
         if cache_key not in wrapper_memoized.cache:
             value = func(*args, **kwargs)
-            logger.debug(
-                f"Memoizing {cache_key} => {value} for {func.__name__}"
-            )
+            logger.debug(f"Memoizing {cache_key} => {value} for {func.__name__}")
             wrapper_memoized.cache[cache_key] = value
         else:
             logger.debug(f"Returning memoized value for {func.__name__}")
         return wrapper_memoized.cache[cache_key]
 
-    wrapper_memoized.cache = dict()
+    wrapper_memoized.cache = dict()  # type: ignore
     return wrapper_memoized
 
 
@@ -714,38 +704,19 @@ def timeout(
     return decorate
 
 
-class non_reentrant_code(object):
-    def __init__(self):
-        self._lock = threading.RLock
-        self._entered = False
-
-    def __call__(self, f):
-        def _gatekeeper(*args, **kwargs):
-            with self._lock:
-                if self._entered:
-                    return
-                self._entered = True
-                f(*args, **kwargs)
-                self._entered = False
+def synchronized(lock):
+    def wrap(f):
+        @functools.wraps(f)
+        def _gatekeeper(*args, **kw):
+            lock.acquire()
+            try:
+                return f(*args, **kw)
+            finally:
+                lock.release()
 
         return _gatekeeper
 
-
-class rlocked(object):
-    def __init__(self):
-        self._lock = threading.RLock
-        self._entered = False
-
-    def __call__(self, f):
-        def _gatekeeper(*args, **kwargs):
-            with self._lock:
-                if self._entered:
-                    return
-                self._entered = True
-                f(*args, **kwargs)
-                self._entered = False
-
-        return _gatekeeper
+    return wrap
 
 
 def call_with_sample_rate(sample_rate: float) -> Callable:
@@ -760,9 +731,7 @@ def call_with_sample_rate(sample_rate: float) -> Callable:
             if random.uniform(0, 1) < sample_rate:
                 return f(*args, **kwargs)
             else:
-                logger.debug(
-                    f"@call_with_sample_rate skipping a call to {f.__name__}"
-                )
+                logger.debug(f"@call_with_sample_rate skipping a call to {f.__name__}")
 
         return _call_with_sample_rate