3 # © Copyright 2021-2022, Scott Gasch
4 # Portions (marked) below retain the original author's copyright.
6 """Useful(?) decorators."""
13 import multiprocessing
21 from typing import Any, Callable, List, Optional
23 # This module is commonly used by others in here and should avoid
24 # taking any unnecessary dependencies back on them.
27 logger = logging.getLogger(__name__)
def timed(func: Callable) -> Callable:
    """Print the runtime of the decorated function.

    >>> @timed
    ... def foo():
    ...     import time
    ...     time.sleep(0.1)

    >>> foo() # doctest: +ELLIPSIS
    Finished foo in ...
    """

    @functools.wraps(func)
    def wrapper_timer(*args, **kwargs):
        # perf_counter is a monotonic, high resolution clock: the right
        # tool for measuring elapsed time.
        start_time = time.perf_counter()
        value = func(*args, **kwargs)
        end_time = time.perf_counter()
        run_time = end_time - start_time
        msg = f"Finished {func.__qualname__} in {run_time:.4f}s"
        print(msg)
        # Propagate the wrapped function's return value to the caller.
        return value

    return wrapper_timer
def invocation_logged(func: Callable) -> Callable:
    """Log the call of a function on stdout and the info log.

    >>> @invocation_logged
    ... def foo():
    ...     print('Hello, world.')

    >>> foo()
    Entered foo
    Hello, world.
    Exited foo
    """

    @functools.wraps(func)
    def wrapper_invocation_logged(*args, **kwargs):
        msg = f"Entered {func.__qualname__}"
        print(msg)
        logger.info(msg)
        ret = func(*args, **kwargs)
        msg = f"Exited {func.__qualname__}"
        print(msg)
        logger.info(msg)
        # Bug fix: the wrapped function's result was computed but dropped;
        # return it so the decorator is transparent to callers.
        return ret

    return wrapper_invocation_logged
def rate_limited(n_calls: int, *, per_period_in_seconds: float = 1.0) -> Callable:
    """Limit invocation of a wrapped function to n calls per time period.
    Thread safe. In testing this was relatively fair with multiple
    threads using it though that hasn't been measured in detail.

    >>> import decorator_utils
    >>> import thread_utils
    >>> @decorator_utils.rate_limited(10, per_period_in_seconds=1.0)
    ... def limited(x: int):
    >>> @thread_utils.background_thread
    ... for _ in range(3):
    >>> @thread_utils.background_thread
    ... for _ in range(3):
    >>> start = time.time()
    >>> end = time.time()
    >>> dur = end - start
    """
    # Spread n_calls evenly across the period: the minimum allowed gap
    # between two consecutive invocations.
    min_interval_seconds = per_period_in_seconds / float(n_calls)

    def wrapper_rate_limited(func: Callable) -> Callable:
        # The condition variable serializes callers; the timestamp of the
        # most recent invocation lives in a one-element list so the nested
        # closures below can mutate it.
        cv = threading.Condition()
        last_invocation_timestamp = [0.0]

        def may_proceed() -> float:
            # Computes the remaining wait time (<= 0.0 means "go now").
            # NOTE(review): 'now' appears to be a timestamp captured just
            # above in an elided line (presumably time.time()) -- confirm
            # against the full source.
            last_invocation = last_invocation_timestamp[0]
            if last_invocation != 0.0:
                elapsed_since_last = now - last_invocation
                wait_time = min_interval_seconds - elapsed_since_last
            logger.debug('@%.4f> wait_time = %.4f', time.time(), wait_time)

        def wrapper_wrapper_rate_limited(*args, **kargs) -> Any:
            # Waits on the condition variable until may_proceed() reports
            # the minimum interval has elapsed, invokes func, then records
            # the invocation timestamp for the next caller.
                    lambda: may_proceed() <= 0.0,
                    timeout=may_proceed(),
                logger.debug('@%.4f> calling it...', time.time())
                ret = func(*args, **kargs)
                last_invocation_timestamp[0] = time.time()
                    '@%.4f> Last invocation <- %.4f', time.time(), last_invocation_timestamp[0]
        return wrapper_wrapper_rate_limited

    return wrapper_rate_limited
def debug_args(func: Callable) -> Callable:
    """Print the function signature and return value at each call.

    >>> @debug_args
    ... def foo(a, b, c):
    ...     print(a + b, c)
    ...     return (a + b, c)

    >>> foo(1, 2.0, "test")
    Calling foo(1:<class 'int'>, 2.0:<class 'float'>, 'test':<class 'str'>)
    3.0 test
    foo returned (3.0, 'test'):<class 'tuple'>
    (3.0, 'test')
    """

    @functools.wraps(func)
    def wrapper_debug_args(*args, **kwargs):
        # Render every positional and keyword argument with its type.
        args_repr = [f"{repr(a)}:{type(a)}" for a in args]
        kwargs_repr = [f"{k}={v!r}:{type(v)}" for k, v in kwargs.items()]
        signature = ", ".join(args_repr + kwargs_repr)
        msg = f"Calling {func.__qualname__}({signature})"
        print(msg)
        value = func(*args, **kwargs)
        msg = f"{func.__qualname__} returned {value!r}:{type(value)}"
        print(msg)
        # Bug fix: return the wrapped function's result to the caller.
        return value

    return wrapper_debug_args
def debug_count_calls(func: Callable) -> Callable:
    """Count function invocations and print a message before every call.

    >>> @debug_count_calls
    ... def factoral(x):
    ...     if x == 1:
    ...         return 1
    ...     return x * factoral(x - 1)

    >>> factoral(5)
    Call #1 of 'factoral'
    Call #2 of 'factoral'
    Call #3 of 'factoral'
    Call #4 of 'factoral'
    Call #5 of 'factoral'
    120
    """

    @functools.wraps(func)
    def wrapper_debug_count_calls(*args, **kwargs):
        # The running count lives on the wrapper itself so each decorated
        # function gets its own independent counter.
        wrapper_debug_count_calls.num_calls += 1
        msg = f"Call #{wrapper_debug_count_calls.num_calls} of {func.__name__!r}"
        # Bug fix: the message was built but never printed; the docstring's
        # doctest shows it is expected on stdout.
        print(msg)
        return func(*args, **kwargs)

    wrapper_debug_count_calls.num_calls = 0  # type: ignore
    return wrapper_debug_count_calls
class DelayWhen(enum.IntEnum):
    """When should we delay: before or after calling the function (or
    both)?  Members are bit flags so BEFORE_CALL and AFTER_CALL can be
    combined; used as the `when` argument of @delay below.
    """

    BEFORE_CALL = 1
    AFTER_CALL = 2
    BEFORE_AND_AFTER = 3


def delay(
    _func: Optional[Callable] = None,
    *,
    seconds: float = 1.0,
    when: DelayWhen = DelayWhen.BEFORE_CALL,
) -> Callable:
    """Slow down a function by inserting a delay before and/or after its
    invocation.

    Args:
        _func: the function being decorated (only bound when the
            decorator is used without arguments, i.e. bare @delay).
        seconds: how long to sleep, in (possibly fractional) seconds.
        when: delay before the call, after it, or both.

    >>> import time
    >>> @delay(seconds=1.0)
    ... def foo():
    ...     pass

    >>> start = time.time()
    >>> foo()
    >>> dur = time.time() - start
    >>> dur >= 1.0
    True
    """

    def decorator_delay(func: Callable) -> Callable:
        @functools.wraps(func)
        def wrapper_delay(*args, **kwargs):
            # `when` is a bit mask; test each flag independently.
            if when & DelayWhen.BEFORE_CALL:
                logger.debug("@delay for %fs BEFORE_CALL to %s", seconds, func.__name__)
                time.sleep(seconds)
            retval = func(*args, **kwargs)
            if when & DelayWhen.AFTER_CALL:
                logger.debug("@delay for %fs AFTER_CALL to %s", seconds, func.__name__)
                time.sleep(seconds)
            return retval

        return wrapper_delay

    # Support both the bare @delay and the @delay(...) invocation styles.
    if _func is None:
        return decorator_delay
    return decorator_delay(_func)
286 class _SingletonWrapper:
288 A singleton wrapper class. Its instances would be created
289 for each decorated class.
293 def __init__(self, cls):
294 self.__wrapped__ = cls
295 self._instance = None
297 def __call__(self, *args, **kwargs):
298 """Returns a single instance of decorated class"""
299 logger.debug('@singleton returning global instance of %s', self.__wrapped__.__name__)
300 if self._instance is None:
301 self._instance = self.__wrapped__(*args, **kwargs)
302 return self._instance
def singleton(cls):
    """A singleton decorator. Returns a wrapper object. A call on that object
    returns a single instance object of decorated class. Use the __wrapped__
    attribute to access decorated class directly in unit tests.

    >>> @singleton
    ... class foo(object):
    ...     pass

    >>> a = foo()
    >>> b = foo()
    >>> a is b
    True
    """
    return _SingletonWrapper(cls)
def memoized(func: Callable) -> Callable:
    """Keep a cache of previous function call results.

    The cache here is a dict with a key based on the arguments to the
    call.  Consider also: functools.cache for a more advanced
    alternative:
    https://docs.python.org/3/library/functools.html#functools.cache

    >>> import time
    >>> @memoized
    ... def expensive(arg) -> int:
    ...     # Simulate something slow to compute or lookup
    ...     time.sleep(1.0)
    ...     return arg * arg

    >>> start = time.time()
    >>> expensive(5) # Takes about 1 sec
    25
    >>> expensive(3) # Also takes about 1 sec
    9
    >>> expensive(5) # Pulls from cache, fast
    25
    >>> expensive(3) # Pulls from cache again, fast
    9
    >>> dur = time.time() - start
    >>> dur < 3.0
    True
    """

    @functools.wraps(func)
    def wrapper_memoized(*args, **kwargs):
        # NOTE: this simple key requires every argument to be hashable and
        # treats the same value passed positionally vs. by keyword as two
        # distinct calls.
        cache_key = args + tuple(kwargs.items())
        if cache_key not in wrapper_memoized.cache:
            value = func(*args, **kwargs)
            logger.debug('Memoizing %s => %s for %s', cache_key, value, func.__name__)
            wrapper_memoized.cache[cache_key] = value
        else:
            # Bug fix: this previously logged a one-element set literal
            # (e.g. "{'expensive'}") because of spurious braces around
            # func.__name__; and the "returning memoized" message should
            # only appear on a cache hit.
            logger.debug('Returning memoized value for %s', func.__name__)
        return wrapper_memoized.cache[cache_key]

    wrapper_memoized.cache = {}  # type: ignore
    return wrapper_memoized
def retry_predicate(
    tries: int,
    *,
    predicate: Callable[..., bool],
    delay_sec: float = 3.0,
    backoff: float = 2.0,
):
    """Retries a function or method up to a certain number of times with a
    prescribed initial delay period and backoff rate (multiplier).

    Args:
        tries: the maximum number of attempts to run the function
        delay_sec: sets the initial delay period in seconds
        backoff: a multiplier (must be >=1.0) used to modify the
            delay at each subsequent invocation
        predicate: a Callable that will be passed the retval of
            the decorated function and must return True to indicate
            that we should stop calling or False to indicate a retry

    Raises:
        ValueError: on an invalid tries, delay_sec or backoff argument.
    """
    if backoff < 1.0:
        msg = f"backoff must be greater than or equal to 1, got {backoff}"
        logger.error(msg)
        raise ValueError(msg)

    tries = math.floor(tries)
    if tries < 0:
        msg = f"tries must be 0 or greater, got {tries}"
        logger.error(msg)
        raise ValueError(msg)

    if delay_sec <= 0:
        msg = f"delay_sec must be greater than 0, got {delay_sec}"
        logger.error(msg)
        raise ValueError(msg)

    def deco_retry(f):
        @functools.wraps(f)
        def f_retry(*args, **kwargs):
            mtries, mdelay = tries, delay_sec  # make mutable
            logger.debug('deco_retry: will make up to %d attempts...', mtries)
            retval = f(*args, **kwargs)
            while mtries > 0:
                if predicate(retval) is True:
                    logger.debug('Predicate succeeded, deco_retry is done.')
                    return retval
                logger.debug("Predicate failed, sleeping and retrying.")
                mtries -= 1
                time.sleep(mdelay)
                # Each retry waits `backoff` times longer than the last.
                mdelay *= backoff
                retval = f(*args, **kwargs)
            # Out of attempts: hand back whatever we got last.
            return retval

        return f_retry

    return deco_retry
def retry_if_false(tries: int, *, delay_sec=3.0, backoff=2.0):
    """A helper for @retry_predicate that retries a decorated
    function as long as it keeps returning False.

    >>> import time
    >>> counter = 0
    >>> @retry_if_false(5, delay_sec=1.0, backoff=1.1)
    ... def foo():
    ...     global counter
    ...     counter += 1
    ...     return counter >= 3

    >>> start = time.time()
    >>> foo() # fail, delay 1.0, fail, delay 1.1, succeed
    True
    >>> dur = time.time() - start
    """
    # Delegate to the general mechanism; "success" means literally True.
    return retry_predicate(
        tries,
        predicate=lambda x: x is True,
        delay_sec=delay_sec,
        backoff=backoff,
    )
def retry_if_none(tries: int, *, delay_sec=3.0, backoff=2.0):
    """Another helper for @retry_predicate above.  Retries up to N
    times so long as the wrapped function returns None with a delay
    between each retry and a backoff that can increase the delay.
    """
    # Delegate to the general mechanism; "success" means any non-None value.
    return retry_predicate(
        tries,
        predicate=lambda x: x is not None,
        delay_sec=delay_sec,
        backoff=backoff,
    )
def deprecated(func):
    """Mark a function as deprecated.  Every call emits a
    DeprecationWarning and also prints a complaint to stderr before
    invoking the wrapped function normally.
    """

    @functools.wraps(func)
    def _complain_then_call(*args, **kwargs):
        msg = f"Call to deprecated function {func.__qualname__}"
        warnings.warn(msg, category=DeprecationWarning, stacklevel=2)
        print(msg, file=sys.stderr)
        return func(*args, **kwargs)

    return _complain_then_call
    Make a function immediately return a function of no args which,
    when called, waits for the result, which will start being
    processed in another thread.
    """

    @functools.wraps(func)
    def lazy_thunked(*args, **kwargs):
        # Signals that the background computation has completed.
        wait_event = threading.Event()

        # exc[0]: did the worker raise?  exc[1]: the sys.exc_info() triple
        # captured in the worker, if it did.
        exc: List[Any] = [False, None]

                # Worker body: compute the result and stash it for the
                # thunk to pick up later.
                func_result = func(*args, **kwargs)
                result[0] = func_result
                # Capture the failure so it can be re-raised in the caller
                # of thunk() instead of dying silently in this thread.
                exc[1] = sys.exc_info()  # (type, value, traceback)
                msg = f"Thunkify has thrown an exception (will be raised on thunk()):\n{traceback.format_exc()}"
                # Re-raise the worker's exception type/value on thunk().
                raise exc[1][0](exc[1][1])
        # Kick off the computation immediately in a background thread.
        threading.Thread(target=worker_func).start()
540 ############################################################
542 ############################################################
544 # http://www.saltycrane.com/blog/2010/04/using-python-timeout-decorator-uploading-s3/
546 # in https://code.google.com/p/verse-quiz/source/browse/trunk/timeout.py
549 def _raise_exception(exception, error_message: Optional[str]):
550 if error_message is None:
551 raise Exception(exception)
553 raise Exception(error_message)
556 def _target(queue, function, *args, **kwargs):
557 """Run a function with arguments and return output via a queue.
559 This is a helper function for the Process created in _Timeout. It runs
560 the function with positional arguments and keyword arguments and then
561 returns the function's output by way of a queue. If an exception gets
562 raised, it is returned to _Timeout to be raised by the value property.
565 queue.put((True, function(*args, **kwargs)))
567 queue.put((False, sys.exc_info()[1]))
class _Timeout(object):
    """Wrap a function and add a timeout to it.

    Instances of this class are automatically generated by the add_timeout
    function defined below. Do not use directly.
    """

        timeout_exception: Exception,
        # Stash the callable, its timeout parameters, and placeholder
        # queue/process machinery; real ones are created per-call.
        self.__limit = seconds
        self.__function = function
        self.__timeout_exception = timeout_exception
        self.__error_message = error_message
        self.__name__ = function.__name__
        self.__doc__ = function.__doc__
        self.__timeout = time.time()
        self.__process = multiprocessing.Process()
        self.__queue: multiprocessing.queues.Queue = multiprocessing.Queue()

    def __call__(self, *args, **kwargs):
        """Execute the embedded function object asynchronously.

        The function given to the constructor is transparently called and
        requires that "ready" be intermittently polled. If and when it is
        True, the "value" property may then be checked for returned data.
        """
        # A per-call "timeout" kwarg overrides the default limit.
        self.__limit = kwargs.pop("timeout", self.__limit)
        # Fresh single-slot queue and a daemonized worker process running
        # _target, which puts a (success, payload) pair on the queue.
        self.__queue = multiprocessing.Queue(1)
        args = (self.__queue, self.__function) + args
        self.__process = multiprocessing.Process(target=_target, args=args, kwargs=kwargs)
        self.__process.daemon = True
        self.__process.start()
        if self.__limit is not None:
            # Absolute wall-clock deadline for this invocation.
            self.__timeout = self.__limit + time.time()
        while not self.ready:

        """Terminate any possible execution of the embedded function."""
        if self.__process.is_alive():
            self.__process.terminate()
        # Surface the configured timeout exception/message to the caller.
        _raise_exception(self.__timeout_exception, self.__error_message)

        """Read-only property indicating status of "value" property."""
        if self.__limit and self.__timeout < time.time():
        # The one-slot queue is "full" exactly when the worker delivered.
        return self.__queue.full() and not self.__queue.empty()

        """Read-only property containing data returned from function."""
        if self.ready is True:
            flag, load = self.__queue.get()
    seconds: float = 1.0,
    use_signals: Optional[bool] = None,
    timeout_exception=exceptions.TimeoutError,
    error_message="Function call timed out",
    """Add a timeout parameter to a function and return the function.

    Note: the use_signals parameter is included in order to support
    multiprocessing scenarios (signal can only be used from the process'
    main thread). When not using signals, timeout granularity will be
    rounded to the nearest 0.1s.

    Raises an exception when/if the timeout is reached.

    It is illegal to pass anything other than a function as the first
    parameter. The function is wrapped and returned to the caller.

    ... def foo(delay: float):
    ...     time.sleep(delay)
    Traceback (most recent call last):
    Exception: Function call timed out
    """
    if use_signals is None:
        # SIGALRM only works on the main thread; auto-detect whether the
        # signal-based fast path is available.
        use_signals = thread_utils.is_current_thread_main_thread()

    def decorate(function):
        # Signal path: the itimer fires SIGALRM when time is up and this
        # handler raises the configured timeout exception.
        def handler(unused_signum, unused_frame):
            _raise_exception(timeout_exception, error_message)

        @functools.wraps(function)
        def new_function(*args, **kwargs):
            # A per-call "timeout" kwarg overrides the decorator default.
            new_seconds = kwargs.pop("timeout", seconds)
            old = signal.signal(signal.SIGALRM, handler)
            signal.setitimer(signal.ITIMER_REAL, new_seconds)
            return function(*args, **kwargs)
            return function(*args, **kwargs)
            # Cancel the itimer and restore the previous SIGALRM handler.
            signal.setitimer(signal.ITIMER_REAL, 0)
            signal.signal(signal.SIGALRM, old)

        # Non-signal path: run the function in a subprocess via _Timeout.
        @functools.wraps(function)
        def new_function(*args, **kwargs):
            timeout_wrapper = _Timeout(function, timeout_exception, error_message, seconds)
            return timeout_wrapper(*args, **kwargs)
def synchronized(lock):
    """Emulates java's synchronized keyword: given a lock, require that
    threads take that lock (or wait) before invoking the wrapped
    function and automatically releases the lock afterwards.

    Args:
        lock: the lock object (e.g. threading.Lock) guarding entry.
    """

    def wrap(f):
        @functools.wraps(f)
        def _gatekeeper(*args, **kw):
            # Hold the lock for the duration of the call; the context
            # manager releases it even if f raises.
            with lock:
                return f(*args, **kw)

        return _gatekeeper

    return wrap
def call_with_sample_rate(sample_rate: float) -> Callable:
    """Calls the wrapped function probabilistically given a rate between
    0.0 and 1.0 inclusive (0% probability and 100% probability).

    Args:
        sample_rate: the probability with which to invoke the wrapped
            function; skipped calls return None.

    Raises:
        ValueError: if sample_rate is outside [0, 1].
    """
    if not 0.0 <= sample_rate <= 1.0:
        msg = f"sample_rate must be between [0, 1]. Got {sample_rate}."
        logger.error(msg)
        raise ValueError(msg)

    def decorator(f):
        @functools.wraps(f)
        def _call_with_sample_rate(*args, **kwargs):
            # Draw once per call; invoke only when under the threshold.
            if random.uniform(0, 1) < sample_rate:
                return f(*args, **kwargs)
            logger.debug("@call_with_sample_rate skipping a call to %s", f.__name__)
            return None

        return _call_with_sample_rate

    return decorator
def decorate_matching_methods_with(decorator, acl=None):
    """Apply the given decorator to all methods in a class whose names
    are approved by the acl predicate.  If acl is None (default), decorate
    all methods in the class.

    Args:
        decorator: the decorator to apply to matching methods.
        acl: optional predicate invoked with each method name; only
            methods for which it returns True are decorated.
    """

    def decorate_the_class(cls):
        for name, m in inspect.getmembers(cls, inspect.isfunction):
            # Bug fix: previously every method was decorated twice and the
            # acl argument was ignored; also return the class so this works
            # as a class decorator.
            if acl is None:
                setattr(cls, name, decorator(m))
            elif acl(name):
                setattr(cls, name, decorator(m))
        return cls

    return decorate_the_class
772 if __name__ == '__main__':