Reduce the doctest lease duration...
[python_utils.git] / decorator_utils.py
index 1e0fe18c4063b285e84d561b30b39e5b4245d001..4615fec6f8960e0083ce48546ba9421c25243d42 100644 (file)
@@ -1,6 +1,9 @@
 #!/usr/bin/env python3
 
 #!/usr/bin/env python3
 
-"""Decorators."""
+# © Copyright 2021-2022, Scott Gasch
+# Portions (marked) below retain the original author's copyright.
+
+"""Useful(?) decorators."""
 
 import enum
 import functools
 
 import enum
 import functools
@@ -14,19 +17,28 @@ import sys
 import threading
 import time
 import traceback
 import threading
 import time
 import traceback
-from typing import Any, Callable, Optional
 import warnings
 import warnings
+from typing import Any, Callable, List, Optional
 
 # This module is commonly used by others in here and should avoid
 # taking any unnecessary dependencies back on them.
 import exceptions
 
 
 # This module is commonly used by others in here and should avoid
 # taking any unnecessary dependencies back on them.
 import exceptions
 
-
 logger = logging.getLogger(__name__)
 
 
 def timed(func: Callable) -> Callable:
 logger = logging.getLogger(__name__)
 
 
 def timed(func: Callable) -> Callable:
-    """Print the runtime of the decorated function."""
+    """Print the runtime of the decorated function.
+
+    >>> @timed
+    ... def foo():
+    ...     import time
+    ...     time.sleep(0.01)
+
+    >>> foo()  # doctest: +ELLIPSIS
+    Finished foo in ...
+
+    """
 
     @functools.wraps(func)
     def wrapper_timer(*args, **kwargs):
 
     @functools.wraps(func)
     def wrapper_timer(*args, **kwargs):
@@ -34,15 +46,27 @@ def timed(func: Callable) -> Callable:
         value = func(*args, **kwargs)
         end_time = time.perf_counter()
         run_time = end_time - start_time
         value = func(*args, **kwargs)
         end_time = time.perf_counter()
         run_time = end_time - start_time
-        msg = f"Finished {func.__name__!r} in {run_time:.4f}s"
+        msg = f"Finished {func.__qualname__} in {run_time:.4f}s"
         print(msg)
         logger.info(msg)
         return value
         print(msg)
         logger.info(msg)
         return value
+
     return wrapper_timer
 
 
 def invocation_logged(func: Callable) -> Callable:
     return wrapper_timer
 
 
 def invocation_logged(func: Callable) -> Callable:
-    """Log the call of a function."""
+    """Log the call of a function on stdout and the info log.
+
+    >>> @invocation_logged
+    ... def foo():
+    ...     print('Hello, world.')
+
+    >>> foo()
+    Entered foo
+    Hello, world.
+    Exited foo
+
+    """
 
     @functools.wraps(func)
     def wrapper_invocation_logged(*args, **kwargs):
 
     @functools.wraps(func)
     def wrapper_invocation_logged(*args, **kwargs):
@@ -54,13 +78,14 @@ def invocation_logged(func: Callable) -> Callable:
         print(msg)
         logger.info(msg)
         return ret
         print(msg)
         logger.info(msg)
         return ret
+
     return wrapper_invocation_logged
 
 
 def rate_limited(n_calls: int, *, per_period_in_seconds: float = 1.0) -> Callable:
     return wrapper_invocation_logged
 
 
 def rate_limited(n_calls: int, *, per_period_in_seconds: float = 1.0) -> Callable:
-    """Limit invocation of a wrapped function to n calls per period.
+    """Limit invocation of a wrapped function to n calls per time period.
     Thread safe.  In testing this was relatively fair with multiple
     Thread safe.  In testing this was relatively fair with multiple
-    threads using it though that hasn't been measured.
+    threads using it though that hasn't been measured in detail.
 
     >>> import time
     >>> import decorator_utils
 
     >>> import time
     >>> import decorator_utils
@@ -68,7 +93,7 @@ def rate_limited(n_calls: int, *, per_period_in_seconds: float = 1.0) -> Callabl
 
     >>> calls = 0
 
 
     >>> calls = 0
 
-    >>> @decorator_utils.rate_limited(1, per_period_in_seconds=1.0)
+    >>> @decorator_utils.rate_limited(10, per_period_in_seconds=1.0)
     ... def limited(x: int):
     ...     global calls
     ...     calls += 1
     ... def limited(x: int):
     ...     global calls
     ...     calls += 1
@@ -90,7 +115,7 @@ def rate_limited(n_calls: int, *, per_period_in_seconds: float = 1.0) -> Callabl
     >>> t2.join()
     >>> end = time.time()
     >>> dur = end - start
     >>> t2.join()
     >>> end = time.time()
     >>> dur = end - start
-    >>> dur > 5.0
+    >>> dur > 0.5
     True
 
     >>> calls
     True
 
     >>> calls
@@ -111,7 +136,7 @@ def rate_limited(n_calls: int, *, per_period_in_seconds: float = 1.0) -> Callabl
                 wait_time = min_interval_seconds - elapsed_since_last
             else:
                 wait_time = 0.0
                 wait_time = min_interval_seconds - elapsed_since_last
             else:
                 wait_time = 0.0
-            logger.debug(f'@{time.time()}> wait_time = {wait_time}')
+            logger.debug('@%.4f> wait_time = %.4f', time.time(), wait_time)
             return wait_time
 
         def wrapper_wrapper_rate_limited(*args, **kargs) -> Any:
             return wait_time
 
         def wrapper_wrapper_rate_limited(*args, **kargs) -> Any:
@@ -123,38 +148,74 @@ def rate_limited(n_calls: int, *, per_period_in_seconds: float = 1.0) -> Callabl
                     ):
                         break
             with cv:
                     ):
                         break
             with cv:
-                logger.debug(f'@{time.time()}> calling it...')
+                logger.debug('@%.4f> calling it...', time.time())
                 ret = func(*args, **kargs)
                 last_invocation_timestamp[0] = time.time()
                 logger.debug(
                 ret = func(*args, **kargs)
                 last_invocation_timestamp[0] = time.time()
                 logger.debug(
-                    f'@{time.time()}> Last invocation <- {last_invocation_timestamp[0]}'
+                    '@%.4f> Last invocation <- %.4f', time.time(), last_invocation_timestamp[0]
                 )
                 cv.notify()
             return ret
                 )
                 cv.notify()
             return ret
+
         return wrapper_wrapper_rate_limited
         return wrapper_wrapper_rate_limited
+
     return wrapper_rate_limited
 
 
 def debug_args(func: Callable) -> Callable:
     return wrapper_rate_limited
 
 
 def debug_args(func: Callable) -> Callable:
-    """Print the function signature and return value at each call."""
+    """Print the function signature and return value at each call.
+
+    >>> @debug_args
+    ... def foo(a, b, c):
+    ...     print(a)
+    ...     print(b)
+    ...     print(c)
+    ...     return (a + b, c)
+
+    >>> foo(1, 2.0, "test")
+    Calling foo(1:<class 'int'>, 2.0:<class 'float'>, 'test':<class 'str'>)
+    1
+    2.0
+    test
+    foo returned (3.0, 'test'):<class 'tuple'>
+    (3.0, 'test')
+    """
 
     @functools.wraps(func)
     def wrapper_debug_args(*args, **kwargs):
         args_repr = [f"{repr(a)}:{type(a)}" for a in args]
         kwargs_repr = [f"{k}={v!r}:{type(v)}" for k, v in kwargs.items()]
         signature = ", ".join(args_repr + kwargs_repr)
 
     @functools.wraps(func)
     def wrapper_debug_args(*args, **kwargs):
         args_repr = [f"{repr(a)}:{type(a)}" for a in args]
         kwargs_repr = [f"{k}={v!r}:{type(v)}" for k, v in kwargs.items()]
         signature = ", ".join(args_repr + kwargs_repr)
-        msg = f"Calling {func.__name__}({signature})"
+        msg = f"Calling {func.__qualname__}({signature})"
         print(msg)
         logger.info(msg)
         value = func(*args, **kwargs)
         print(msg)
         logger.info(msg)
         value = func(*args, **kwargs)
-        msg = f"{func.__name__!r} returned {value!r}:{type(value)}"
+        msg = f"{func.__qualname__} returned {value!r}:{type(value)}"
+        print(msg)
         logger.info(msg)
         return value
         logger.info(msg)
         return value
+
     return wrapper_debug_args
 
 
 def debug_count_calls(func: Callable) -> Callable:
     return wrapper_debug_args
 
 
 def debug_count_calls(func: Callable) -> Callable:
-    """Count function invocations and print a message befor every call."""
+    """Count function invocations and print a message befor every call.
+
+    >>> @debug_count_calls
+    ... def factoral(x):
+    ...     if x == 1:
+    ...         return 1
+    ...     return x * factoral(x - 1)
+
+    >>> factoral(5)
+    Call #1 of 'factoral'
+    Call #2 of 'factoral'
+    Call #3 of 'factoral'
+    Call #4 of 'factoral'
+    Call #5 of 'factoral'
+    120
+
+    """
 
     @functools.wraps(func)
     def wrapper_debug_count_calls(*args, **kwargs):
 
     @functools.wraps(func)
     def wrapper_debug_count_calls(*args, **kwargs):
@@ -163,11 +224,17 @@ def debug_count_calls(func: Callable) -> Callable:
         print(msg)
         logger.info(msg)
         return func(*args, **kwargs)
         print(msg)
         logger.info(msg)
         return func(*args, **kwargs)
-    wrapper_debug_count_calls.num_calls = 0
+
+    wrapper_debug_count_calls.num_calls = 0  # type: ignore
     return wrapper_debug_count_calls
 
 
     return wrapper_debug_count_calls
 
 
-class DelayWhen(enum.Enum):
+class DelayWhen(enum.IntEnum):
+    """When should we delay: before or after calling the function (or
+    both)?
+
+    """
+
     BEFORE_CALL = 1
     AFTER_CALL = 2
     BEFORE_AND_AFTER = 3
     BEFORE_CALL = 1
     AFTER_CALL = 2
     BEFORE_AND_AFTER = 3
@@ -179,27 +246,35 @@ def delay(
     seconds: float = 1.0,
     when: DelayWhen = DelayWhen.BEFORE_CALL,
 ) -> Callable:
     seconds: float = 1.0,
     when: DelayWhen = DelayWhen.BEFORE_CALL,
 ) -> Callable:
-    """Delay the execution of a function by sleeping before and/or after.
-
-    Slow down a function by inserting a delay before and/or after its
+    """Slow down a function by inserting a delay before and/or after its
     invocation.
     invocation.
+
+    >>> import time
+
+    >>> @delay(seconds=1.0)
+    ... def foo():
+    ...     pass
+
+    >>> start = time.time()
+    >>> foo()
+    >>> dur = time.time() - start
+    >>> dur >= 1.0
+    True
+
     """
 
     def decorator_delay(func: Callable) -> Callable:
         @functools.wraps(func)
         def wrapper_delay(*args, **kwargs):
             if when & DelayWhen.BEFORE_CALL:
     """
 
     def decorator_delay(func: Callable) -> Callable:
         @functools.wraps(func)
         def wrapper_delay(*args, **kwargs):
             if when & DelayWhen.BEFORE_CALL:
-                logger.debug(
-                    f"@delay for {seconds}s BEFORE_CALL to {func.__name__}"
-                )
+                logger.debug("@delay for %fs BEFORE_CALL to %s", seconds, func.__name__)
                 time.sleep(seconds)
             retval = func(*args, **kwargs)
             if when & DelayWhen.AFTER_CALL:
                 time.sleep(seconds)
             retval = func(*args, **kwargs)
             if when & DelayWhen.AFTER_CALL:
-                logger.debug(
-                    f"@delay for {seconds}s AFTER_CALL to {func.__name__}"
-                )
+                logger.debug("@delay for %fs AFTER_CALL to %s", seconds, func.__name__)
                 time.sleep(seconds)
             return retval
                 time.sleep(seconds)
             return retval
+
         return wrapper_delay
 
     if _func is None:
         return wrapper_delay
 
     if _func is None:
@@ -212,6 +287,7 @@ class _SingletonWrapper:
     """
     A singleton wrapper class. Its instances would be created
     for each decorated class.
     """
     A singleton wrapper class. Its instances would be created
     for each decorated class.
+
     """
 
     def __init__(self, cls):
     """
 
     def __init__(self, cls):
@@ -220,9 +296,7 @@ class _SingletonWrapper:
 
     def __call__(self, *args, **kwargs):
         """Returns a single instance of decorated class"""
 
     def __call__(self, *args, **kwargs):
         """Returns a single instance of decorated class"""
-        logger.debug(
-            f"@singleton returning global instance of {self.__wrapped__.__name__}"
-        )
+        logger.debug('@singleton returning global instance of %s', self.__wrapped__.__name__)
         if self._instance is None:
             self._instance = self.__wrapped__(*args, **kwargs)
         return self._instance
         if self._instance is None:
             self._instance = self.__wrapped__(*args, **kwargs)
         return self._instance
@@ -233,6 +307,19 @@ def singleton(cls):
     A singleton decorator. Returns a wrapper objects. A call on that object
     returns a single instance object of decorated class. Use the __wrapped__
     attribute to access decorated class directly in unit tests
     A singleton decorator. Returns a wrapper objects. A call on that object
     returns a single instance object of decorated class. Use the __wrapped__
     attribute to access decorated class directly in unit tests
+
+    >>> @singleton
+    ... class foo(object):
+    ...     pass
+
+    >>> a = foo()
+    >>> b = foo()
+    >>> a is b
+    True
+
+    >>> id(a) == id(b)
+    True
+
     """
     return _SingletonWrapper(cls)
 
     """
     return _SingletonWrapper(cls)
 
@@ -241,8 +328,35 @@ def memoized(func: Callable) -> Callable:
     """Keep a cache of previous function call results.
 
     The cache here is a dict with a key based on the arguments to the
     """Keep a cache of previous function call results.
 
     The cache here is a dict with a key based on the arguments to the
-    call.  Consider also: functools.lru_cache for a more advanced
-    implementation.
+    call.  Consider also: functools.cache for a more advanced
+    implementation.  See:
+    https://docs.python.org/3/library/functools.html#functools.cache
+
+    >>> import time
+
+    >>> @memoized
+    ... def expensive(arg) -> int:
+    ...     # Simulate something slow to compute or lookup
+    ...     time.sleep(1.0)
+    ...     return arg * arg
+
+    >>> start = time.time()
+    >>> expensive(5)           # Takes about 1 sec
+    25
+
+    >>> expensive(3)           # Also takes about 1 sec
+    9
+
+    >>> expensive(5)           # Pulls from cache, fast
+    25
+
+    >>> expensive(3)           # Pulls from cache again, fast
+    9
+
+    >>> dur = time.time() - start
+    >>> dur < 3.0
+    True
+
     """
 
     @functools.wraps(func)
     """
 
     @functools.wraps(func)
@@ -250,14 +364,13 @@ def memoized(func: Callable) -> Callable:
         cache_key = args + tuple(kwargs.items())
         if cache_key not in wrapper_memoized.cache:
             value = func(*args, **kwargs)
         cache_key = args + tuple(kwargs.items())
         if cache_key not in wrapper_memoized.cache:
             value = func(*args, **kwargs)
-            logger.debug(
-                f"Memoizing {cache_key} => {value} for {func.__name__}"
-            )
+            logger.debug('Memoizing %s => %s for %s', cache_key, value, func.__name__)
             wrapper_memoized.cache[cache_key] = value
         else:
             wrapper_memoized.cache[cache_key] = value
         else:
-            logger.debug(f"Returning memoized value for {func.__name__}")
+            logger.debug('Returning memoized value for %s', {func.__name__})
         return wrapper_memoized.cache[cache_key]
         return wrapper_memoized.cache[cache_key]
-    wrapper_memoized.cache = dict()
+
+    wrapper_memoized.cache = {}  # type: ignore
     return wrapper_memoized
 
 
     return wrapper_memoized
 
 
@@ -268,16 +381,20 @@ def retry_predicate(
     delay_sec: float = 3.0,
     backoff: float = 2.0,
 ):
     delay_sec: float = 3.0,
     backoff: float = 2.0,
 ):
-    """Retries a function or method up to a certain number of times
-    with a prescribed initial delay period and backoff rate.
-
-    tries is the maximum number of attempts to run the function.
-    delay_sec sets the initial delay period in seconds.
-    backoff is a multiplied (must be >1) used to modify the delay.
-    predicate is a function that will be passed the retval of the
-    decorated function and must return True to stop or False to
-    retry.
+    """Retries a function or method up to a certain number of times with a
+    prescribed initial delay period and backoff rate (multiplier).
+
+    Args:
+        tries: the maximum number of attempts to run the function
+        delay_sec: sets the initial delay period in seconds
+        backoff: a multiplier (must be >=1.0) used to modify the
+            delay at each subsequent invocation
+        predicate: a Callable that will be passed the retval of
+            the decorated function and must return True to indicate
+            that we should stop calling or False to indicate a retry
+            is necessary
     """
     """
+
     if backoff < 1.0:
         msg = f"backoff must be greater than or equal to 1, got {backoff}"
         logger.critical(msg)
     if backoff < 1.0:
         msg = f"backoff must be greater than or equal to 1, got {backoff}"
         logger.critical(msg)
@@ -298,7 +415,7 @@ def retry_predicate(
         @functools.wraps(f)
         def f_retry(*args, **kwargs):
             mtries, mdelay = tries, delay_sec  # make mutable
         @functools.wraps(f)
         def f_retry(*args, **kwargs):
             mtries, mdelay = tries, delay_sec  # make mutable
-            logger.debug(f'deco_retry: will make up to {mtries} attempts...')
+            logger.debug('deco_retry: will make up to %d attempts...', mtries)
             retval = f(*args, **kwargs)
             while mtries > 0:
                 if predicate(retval) is True:
             retval = f(*args, **kwargs)
             while mtries > 0:
                 if predicate(retval) is True:
@@ -310,11 +427,38 @@ def retry_predicate(
                 mdelay *= backoff
                 retval = f(*args, **kwargs)
             return retval
                 mdelay *= backoff
                 retval = f(*args, **kwargs)
             return retval
+
         return f_retry
         return f_retry
+
     return deco_retry
 
 
 def retry_if_false(tries: int, *, delay_sec=3.0, backoff=2.0):
     return deco_retry
 
 
 def retry_if_false(tries: int, *, delay_sec=3.0, backoff=2.0):
+    """A helper for @retry_predicate that retries a decorated
+    function as long as it keeps returning False.
+
+    >>> import time
+
+    >>> counter = 0
+    >>> @retry_if_false(5, delay_sec=1.0, backoff=1.1)
+    ... def foo():
+    ...     global counter
+    ...     counter += 1
+    ...     return counter >= 3
+
+    >>> start = time.time()
+    >>> foo()  # fail, delay 1.0, fail, delay 1.1, succeed
+    True
+
+    >>> dur = time.time() - start
+    >>> counter
+    3
+    >>> dur > 2.0
+    True
+    >>> dur < 2.3
+    True
+
+    """
     return retry_predicate(
         tries,
         predicate=lambda x: x is True,
     return retry_predicate(
         tries,
         predicate=lambda x: x is True,
@@ -324,6 +468,11 @@ def retry_if_false(tries: int, *, delay_sec=3.0, backoff=2.0):
 
 
 def retry_if_none(tries: int, *, delay_sec=3.0, backoff=2.0):
 
 
 def retry_if_none(tries: int, *, delay_sec=3.0, backoff=2.0):
+    """Another helper for @retry_predicate above.  Retries up to N
+    times so long as the wrapped function returns None with a delay
+    between each retry and a backoff that can increase the delay.
+    """
+
     return retry_predicate(
         tries,
         predicate=lambda x: x is not None,
     return retry_predicate(
         tries,
         predicate=lambda x: x is not None,
@@ -340,9 +489,10 @@ def deprecated(func):
 
     @functools.wraps(func)
     def wrapper_deprecated(*args, **kwargs):
 
     @functools.wraps(func)
     def wrapper_deprecated(*args, **kwargs):
-        msg = f"Call to deprecated function {func.__name__}"
+        msg = f"Call to deprecated function {func.__qualname__}"
         logger.warning(msg)
         logger.warning(msg)
-        warnings.warn(msg, category=DeprecationWarning)
+        warnings.warn(msg, category=DeprecationWarning, stacklevel=2)
+        print(msg, file=sys.stderr)
         return func(*args, **kwargs)
 
     return wrapper_deprecated
         return func(*args, **kwargs)
 
     return wrapper_deprecated
@@ -360,7 +510,7 @@ def thunkify(func):
         wait_event = threading.Event()
 
         result = [None]
         wait_event = threading.Event()
 
         result = [None]
-        exc = [False, None]
+        exc: List[Any] = [False, None]
 
         def worker_func():
             try:
 
         def worker_func():
             try:
@@ -370,7 +520,6 @@ def thunkify(func):
                 exc[0] = True
                 exc[1] = sys.exc_info()  # (type, value, traceback)
                 msg = f"Thunkify has thrown an exception (will be raised on thunk()):\n{traceback.format_exc()}"
                 exc[0] = True
                 exc[1] = sys.exc_info()  # (type, value, traceback)
                 msg = f"Thunkify has thrown an exception (will be raised on thunk()):\n{traceback.format_exc()}"
-                print(msg)
                 logger.warning(msg)
             finally:
                 wait_event.set()
                 logger.warning(msg)
             finally:
                 wait_event.set()
@@ -378,6 +527,7 @@ def thunkify(func):
         def thunk():
             wait_event.wait()
             if exc[0]:
         def thunk():
             wait_event.wait()
             if exc[0]:
+                assert exc[1]
                 raise exc[1][0](exc[1][1])
             return result[0]
 
                 raise exc[1][0](exc[1][1])
             return result[0]
 
@@ -398,9 +548,9 @@ def thunkify(func):
 
 def _raise_exception(exception, error_message: Optional[str]):
     if error_message is None:
 
 def _raise_exception(exception, error_message: Optional[str]):
     if error_message is None:
-        raise exception()
+        raise Exception(exception)
     else:
     else:
-        raise exception(error_message)
+        raise Exception(error_message)
 
 
 def _target(queue, function, *args, **kwargs):
 
 
 def _target(queue, function, *args, **kwargs):
@@ -418,10 +568,10 @@ def _target(queue, function, *args, **kwargs):
 
 
 class _Timeout(object):
 
 
 class _Timeout(object):
-    """Wrap a function and add a timeout (limit) attribute to it.
+    """Wrap a function and add a timeout to it.
 
     Instances of this class are automatically generated by the add_timeout
 
     Instances of this class are automatically generated by the add_timeout
-    function defined below.
+    function defined below.  Do not use directly.
     """
 
     def __init__(
     """
 
     def __init__(
@@ -451,9 +601,7 @@ class _Timeout(object):
         self.__limit = kwargs.pop("timeout", self.__limit)
         self.__queue = multiprocessing.Queue(1)
         args = (self.__queue, self.__function) + args
         self.__limit = kwargs.pop("timeout", self.__limit)
         self.__queue = multiprocessing.Queue(1)
         args = (self.__queue, self.__function) + args
-        self.__process = multiprocessing.Process(
-            target=_target, args=args, kwargs=kwargs
-        )
+        self.__process = multiprocessing.Process(target=_target, args=args, kwargs=kwargs)
         self.__process.daemon = True
         self.__process.start()
         if self.__limit is not None:
         self.__process.daemon = True
         self.__process.start()
         if self.__limit is not None:
@@ -483,6 +631,7 @@ class _Timeout(object):
             if flag:
                 return load
             raise load
             if flag:
                 return load
             raise load
+        return None
 
 
 def timeout(
 
 
 def timeout(
@@ -498,19 +647,41 @@ def timeout(
     main thread).  When not using signals, timeout granularity will be
     rounded to the nearest 0.1s.
 
     main thread).  When not using signals, timeout granularity will be
     rounded to the nearest 0.1s.
 
-    Raises an exception when the timeout is reached.
+    Beware that an @timeout on a function inside a module will be
+    evaluated at module load time and not when the wrapped function is
+    invoked.  This can lead to problems when relying on the automatic
+    main thread detection code (use_signals=None, the default) since
+    the import probably happens on the main thread and the invocation
+    can happen on a different thread (which can't use signals).
+
+    Raises an exception when/if the timeout is reached.
 
     It is illegal to pass anything other than a function as the first
     parameter.  The function is wrapped and returned to the caller.
 
     It is illegal to pass anything other than a function as the first
     parameter.  The function is wrapped and returned to the caller.
+
+    >>> @timeout(0.2)
+    ... def foo(delay: float):
+    ...     time.sleep(delay)
+    ...     return "ok"
+
+    >>> foo(0)
+    'ok'
+
+    >>> foo(1.0)
+    Traceback (most recent call last):
+    ...
+    Exception: Function call timed out
+
     """
     if use_signals is None:
         import thread_utils
     """
     if use_signals is None:
         import thread_utils
+
         use_signals = thread_utils.is_current_thread_main_thread()
 
     def decorate(function):
         if use_signals:
 
         use_signals = thread_utils.is_current_thread_main_thread()
 
     def decorate(function):
         if use_signals:
 
-            def handler(signum, frame):
+            def handler(unused_signum, unused_frame):
                 _raise_exception(timeout_exception, error_message)
 
             @functools.wraps(function)
                 _raise_exception(timeout_exception, error_message)
 
             @functools.wraps(function)
@@ -535,9 +706,7 @@ def timeout(
 
             @functools.wraps(function)
             def new_function(*args, **kwargs):
 
             @functools.wraps(function)
             def new_function(*args, **kwargs):
-                timeout_wrapper = _Timeout(
-                    function, timeout_exception, error_message, seconds
-                )
+                timeout_wrapper = _Timeout(function, timeout_exception, error_message, seconds)
                 return timeout_wrapper(*args, **kwargs)
 
             return new_function
                 return timeout_wrapper(*args, **kwargs)
 
             return new_function
@@ -545,40 +714,31 @@ def timeout(
     return decorate
 
 
     return decorate
 
 
-class non_reentrant_code(object):
-    def __init__(self):
-        self._lock = threading.RLock
-        self._entered = False
+def synchronized(lock):
+    """Emulates java's synchronized keyword: given a lock, require that
+    threads take that lock (or wait) before invoking the wrapped
+    function and automatically releases the lock afterwards.
+    """
 
 
-    def __call__(self, f):
-        def _gatekeeper(*args, **kwargs):
-            with self._lock:
-                if self._entered:
-                    return
-                self._entered = True
-                f(*args, **kwargs)
-                self._entered = False
+    def wrap(f):
+        @functools.wraps(f)
+        def _gatekeeper(*args, **kw):
+            lock.acquire()
+            try:
+                return f(*args, **kw)
+            finally:
+                lock.release()
 
         return _gatekeeper
 
 
         return _gatekeeper
 
-
-class rlocked(object):
-    def __init__(self):
-        self._lock = threading.RLock
-        self._entered = False
-
-    def __call__(self, f):
-        def _gatekeeper(*args, **kwargs):
-            with self._lock:
-                if self._entered:
-                    return
-                self._entered = True
-                f(*args, **kwargs)
-                self._entered = False
-        return _gatekeeper
+    return wrap
 
 
 def call_with_sample_rate(sample_rate: float) -> Callable:
 
 
 def call_with_sample_rate(sample_rate: float) -> Callable:
+    """Calls the wrapped function probabilistically given a rate between
+    0.0 and 1.0 inclusive (0% probability and 100% probability).
+    """
+
     if not 0.0 <= sample_rate <= 1.0:
         msg = f"sample_rate must be between [0, 1]. Got {sample_rate}."
         logger.critical(msg)
     if not 0.0 <= sample_rate <= 1.0:
         msg = f"sample_rate must be between [0, 1]. Got {sample_rate}."
         logger.critical(msg)
@@ -590,18 +750,20 @@ def call_with_sample_rate(sample_rate: float) -> Callable:
             if random.uniform(0, 1) < sample_rate:
                 return f(*args, **kwargs)
             else:
             if random.uniform(0, 1) < sample_rate:
                 return f(*args, **kwargs)
             else:
-                logger.debug(
-                    f"@call_with_sample_rate skipping a call to {f.__name__}"
-                )
+                logger.debug("@call_with_sample_rate skipping a call to %s", f.__name__)
+                return None
+
         return _call_with_sample_rate
         return _call_with_sample_rate
+
     return decorator
 
 
 def decorate_matching_methods_with(decorator, acl=None):
     return decorator
 
 
 def decorate_matching_methods_with(decorator, acl=None):
-    """Apply decorator to all methods in a class whose names begin with
-    prefix.  If prefix is None (default), decorate all methods in the
-    class.
+    """Apply the given decorator to all methods in a class whose names
+    begin with prefix.  If prefix is None (default), decorate all
+    methods in the class.
     """
     """
+
     def decorate_the_class(cls):
         for name, m in inspect.getmembers(cls, inspect.isfunction):
             if acl is None:
     def decorate_the_class(cls):
         for name, m in inspect.getmembers(cls, inspect.isfunction):
             if acl is None:
@@ -610,11 +772,11 @@ def decorate_matching_methods_with(decorator, acl=None):
                 if acl(name):
                     setattr(cls, name, decorator(m))
         return cls
                 if acl(name):
                     setattr(cls, name, decorator(m))
         return cls
+
     return decorate_the_class
 
 
 if __name__ == '__main__':
     import doctest
     return decorate_the_class
 
 
 if __name__ == '__main__':
     import doctest
-    doctest.testmod()
-
 
 
+    doctest.testmod()