3 # © Copyright 2021-2022, Scott Gasch
4 # Portions (marked) below retain the original author's copyright.
13 import multiprocessing
21 from typing import Any, Callable, Optional
23 # This module is commonly used by others in here and should avoid
24 # taking any unnecessary dependencies back on them.
27 logger = logging.getLogger(__name__)
30 def timed(func: Callable) -> Callable:
31 """Print the runtime of the decorated function.
38 >>> foo() # doctest: +ELLIPSIS
43 @functools.wraps(func)
44 def wrapper_timer(*args, **kwargs):
45 start_time = time.perf_counter()
46 value = func(*args, **kwargs)
47 end_time = time.perf_counter()
48 run_time = end_time - start_time
49 msg = f"Finished {func.__qualname__} in {run_time:.4f}s"
57 def invocation_logged(func: Callable) -> Callable:
58 """Log the call of a function.
60 >>> @invocation_logged
62 ... print('Hello, world.')
71 @functools.wraps(func)
72 def wrapper_invocation_logged(*args, **kwargs):
73 msg = f"Entered {func.__qualname__}"
76 ret = func(*args, **kwargs)
77 msg = f"Exited {func.__qualname__}"
82 return wrapper_invocation_logged
85 def rate_limited(n_calls: int, *, per_period_in_seconds: float = 1.0) -> Callable:
86 """Limit invocation of a wrapped function to n calls per period.
87 Thread safe. In testing this was relatively fair with multiple
88 threads using it though that hasn't been measured.
91 >>> import decorator_utils
92 >>> import thread_utils
96 >>> @decorator_utils.rate_limited(10, per_period_in_seconds=1.0)
97 ... def limited(x: int):
101 >>> @thread_utils.background_thread
103 ... for _ in range(3):
106 >>> @thread_utils.background_thread
108 ... for _ in range(3):
111 >>> start = time.time()
116 >>> end = time.time()
117 >>> dur = end - start
125 min_interval_seconds = per_period_in_seconds / float(n_calls)
127 def wrapper_rate_limited(func: Callable) -> Callable:
128 cv = threading.Condition()
129 last_invocation_timestamp = [0.0]
131 def may_proceed() -> float:
133 last_invocation = last_invocation_timestamp[0]
134 if last_invocation != 0.0:
135 elapsed_since_last = now - last_invocation
136 wait_time = min_interval_seconds - elapsed_since_last
139 logger.debug('@%.4f> wait_time = %.4f', time.time(), wait_time)
142 def wrapper_wrapper_rate_limited(*args, **kargs) -> Any:
146 lambda: may_proceed() <= 0.0,
147 timeout=may_proceed(),
151 logger.debug('@%.4f> calling it...', time.time())
152 ret = func(*args, **kargs)
153 last_invocation_timestamp[0] = time.time()
155 '@%.4f> Last invocation <- %.4f', time.time(), last_invocation_timestamp[0]
160 return wrapper_wrapper_rate_limited
162 return wrapper_rate_limited
165 def debug_args(func: Callable) -> Callable:
166 """Print the function signature and return value at each call.
169 ... def foo(a, b, c):
173 ... return (a + b, c)
175 >>> foo(1, 2.0, "test")
176 Calling foo(1:<class 'int'>, 2.0:<class 'float'>, 'test':<class 'str'>)
180 foo returned (3.0, 'test'):<class 'tuple'>
184 @functools.wraps(func)
185 def wrapper_debug_args(*args, **kwargs):
186 args_repr = [f"{repr(a)}:{type(a)}" for a in args]
187 kwargs_repr = [f"{k}={v!r}:{type(v)}" for k, v in kwargs.items()]
188 signature = ", ".join(args_repr + kwargs_repr)
189 msg = f"Calling {func.__qualname__}({signature})"
192 value = func(*args, **kwargs)
193 msg = f"{func.__qualname__} returned {value!r}:{type(value)}"
198 return wrapper_debug_args
201 def debug_count_calls(func: Callable) -> Callable:
202 """Count function invocations and print a message befor every call.
204 >>> @debug_count_calls
208 ... return x * factoral(x - 1)
211 Call #1 of 'factoral'
212 Call #2 of 'factoral'
213 Call #3 of 'factoral'
214 Call #4 of 'factoral'
215 Call #5 of 'factoral'
220 @functools.wraps(func)
221 def wrapper_debug_count_calls(*args, **kwargs):
222 wrapper_debug_count_calls.num_calls += 1
223 msg = f"Call #{wrapper_debug_count_calls.num_calls} of {func.__name__!r}"
226 return func(*args, **kwargs)
228 wrapper_debug_count_calls.num_calls = 0 # type: ignore
229 return wrapper_debug_count_calls
232 class DelayWhen(enum.IntEnum):
233 """When should we delay: before or after calling the function (or
244 _func: Callable = None,
246 seconds: float = 1.0,
247 when: DelayWhen = DelayWhen.BEFORE_CALL,
249 """Delay the execution of a function by sleeping before and/or after.
251 Slow down a function by inserting a delay before and/or after its
256 >>> @delay(seconds=1.0)
260 >>> start = time.time()
262 >>> dur = time.time() - start
268 def decorator_delay(func: Callable) -> Callable:
269 @functools.wraps(func)
270 def wrapper_delay(*args, **kwargs):
271 if when & DelayWhen.BEFORE_CALL:
272 logger.debug("@delay for %fs BEFORE_CALL to %s", seconds, func.__name__)
274 retval = func(*args, **kwargs)
275 if when & DelayWhen.AFTER_CALL:
276 logger.debug("@delay for %fs AFTER_CALL to %s", seconds, func.__name__)
283 return decorator_delay
285 return decorator_delay(_func)
288 class _SingletonWrapper:
290 A singleton wrapper class. Its instances would be created
291 for each decorated class.
295 def __init__(self, cls):
# Store the decorated class under __wrapped__ so it can still be reached
# directly (e.g. from unit tests) without going through the singleton.
296 self.__wrapped__ = cls
# No instance exists yet; __call__ constructs it lazily on first use.
297 self._instance = None
299 def __call__(self, *args, **kwargs):
300 """Return the one shared instance of the decorated class, creating it on the first call."""
# This debug message is logged on every call, including the first call
# that actually constructs the instance.
301 logger.debug('@singleton returning global instance of %s', self.__wrapped__.__name__)
# Lazily construct on first use. Once the instance exists, any arguments
# passed on later calls are ignored.
# NOTE(review): no lock is taken here, so two threads making the first
# call concurrently could each construct an instance — confirm whether
# callers need thread safety.
302 if self._instance is None:
303 self._instance = self.__wrapped__(*args, **kwargs)
304 return self._instance
309 A singleton decorator. Returns a wrapper objects. A call on that object
310 returns a single instance object of decorated class. Use the __wrapped__
311 attribute to access decorated class directly in unit tests
314 ... class foo(object):
326 return _SingletonWrapper(cls)
329 def memoized(func: Callable) -> Callable:
330 """Keep a cache of previous function call results.
332 The cache here is a dict with a key based on the arguments to the
333 call. Consider also: functools.lru_cache for a more advanced
339 ... def expensive(arg) -> int:
340 ... # Simulate something slow to compute or lookup
344 >>> start = time.time()
345 >>> expensive(5) # Takes about 1 sec
348 >>> expensive(3) # Also takes about 1 sec
351 >>> expensive(5) # Pulls from cache, fast
354 >>> expensive(3) # Pulls from cache again, fast
357 >>> dur = time.time() - start
363 @functools.wraps(func)
364 def wrapper_memoized(*args, **kwargs):
365 cache_key = args + tuple(kwargs.items())
366 if cache_key not in wrapper_memoized.cache:
367 value = func(*args, **kwargs)
368 logger.debug('Memoizing %s => %s for %s', cache_key, value, func.__name__)
369 wrapper_memoized.cache[cache_key] = value
371 logger.debug('Returning memoized value for %s', {func.__name__})
372 return wrapper_memoized.cache[cache_key]
374 wrapper_memoized.cache = {} # type: ignore
375 return wrapper_memoized
381 predicate: Callable[..., bool],
382 delay_sec: float = 3.0,
383 backoff: float = 2.0,
385 """Retries a function or method up to a certain number of times
386 with a prescribed initial delay period and backoff rate.
388 tries is the maximum number of attempts to run the function.
389 delay_sec sets the initial delay period in seconds.
390 backoff is a multiplied (must be >1) used to modify the delay.
391 predicate is a function that will be passed the retval of the
392 decorated function and must return True to stop or False to
397 msg = f"backoff must be greater than or equal to 1, got {backoff}"
399 raise ValueError(msg)
401 tries = math.floor(tries)
403 msg = f"tries must be 0 or greater, got {tries}"
405 raise ValueError(msg)
408 msg = f"delay_sec must be greater than 0, got {delay_sec}"
410 raise ValueError(msg)
414 def f_retry(*args, **kwargs):
415 mtries, mdelay = tries, delay_sec # make mutable
416 logger.debug('deco_retry: will make up to %d attempts...', mtries)
417 retval = f(*args, **kwargs)
419 if predicate(retval) is True:
420 logger.debug('Predicate succeeded, deco_retry is done.')
422 logger.debug("Predicate failed, sleeping and retrying.")
426 retval = f(*args, **kwargs)
434 def retry_if_false(tries: int, *, delay_sec=3.0, backoff=2.0):
435 """A helper for @retry_predicate that retries a decorated
436 function as long as it keeps returning False.
442 >>> @retry_if_false(5, delay_sec=1.0, backoff=1.1)
446 ... return counter >= 3
448 >>> start = time.time()
449 >>> foo() # fail, delay 1.0, fail, delay 1.1, succeed
452 >>> dur = time.time() - start
461 return retry_predicate(
463 predicate=lambda x: x is True,
469 def retry_if_none(tries: int, *, delay_sec=3.0, backoff=2.0):
470 """Another helper for @retry_predicate above. Retries up to N
471 times so long as the wrapped function returns None with a delay
472 between each retry and a backoff that can increase the delay.
475 return retry_predicate(
477 predicate=lambda x: x is not None,
483 def deprecated(func):
484 """This is a decorator which can be used to mark functions
485 as deprecated. It will result in a warning being emitted
486 when the function is used.
490 @functools.wraps(func)
491 def wrapper_deprecated(*args, **kwargs):
492 msg = f"Call to deprecated function {func.__qualname__}"
494 warnings.warn(msg, category=DeprecationWarning, stacklevel=2)
495 print(msg, file=sys.stderr)
496 return func(*args, **kwargs)
498 return wrapper_deprecated
503 Make a function immediately return a function of no args which,
504 when called, waits for the result, which will start being
505 processed in another thread.
508 @functools.wraps(func)
509 def lazy_thunked(*args, **kwargs):
510 wait_event = threading.Event()
517 func_result = func(*args, **kwargs)
518 result[0] = func_result
521 exc[1] = sys.exc_info() # (type, value, traceback)
522 msg = f"Thunkify has thrown an exception (will be raised on thunk()):\n{traceback.format_exc()}"
530 raise exc[1][0](exc[1][1])
533 threading.Thread(target=worker_func).start()
539 ############################################################
541 ############################################################
543 # http://www.saltycrane.com/blog/2010/04/using-python-timeout-decorator-uploading-s3/
545 # in https://code.google.com/p/verse-quiz/source/browse/trunk/timeout.py
548 def _raise_exception(exception, error_message: Optional[str]):
549 if error_message is None:
550 raise Exception(exception)
552 raise Exception(error_message)
555 def _target(queue, function, *args, **kwargs):
556 """Run a function with arguments and return output via a queue.
558 This is a helper function for the Process created in _Timeout. It runs
559 the function with positional arguments and keyword arguments and then
560 returns the function's output by way of a queue. If an exception gets
561 raised, it is returned to _Timeout to be raised by the value property.
564 queue.put((True, function(*args, **kwargs)))
566 queue.put((False, sys.exc_info()[1]))
569 class _Timeout(object):
570 """Wrap a function and add a timeout to it.
572 Instances of this class are automatically generated by the add_timeout
573 function defined below. Do not use directly.
579 timeout_exception: Exception,
583 self.__limit = seconds
584 self.__function = function
585 self.__timeout_exception = timeout_exception
586 self.__error_message = error_message
587 self.__name__ = function.__name__
588 self.__doc__ = function.__doc__
589 self.__timeout = time.time()
590 self.__process = multiprocessing.Process()
591 self.__queue: multiprocessing.queues.Queue = multiprocessing.Queue()
593 def __call__(self, *args, **kwargs):
594 """Execute the embedded function object asynchronously.
596 The function given to the constructor is transparently called and
597 requires that "ready" be intermittently polled. If and when it is
598 True, the "value" property may then be checked for returned data.
600 self.__limit = kwargs.pop("timeout", self.__limit)
601 self.__queue = multiprocessing.Queue(1)
602 args = (self.__queue, self.__function) + args
603 self.__process = multiprocessing.Process(target=_target, args=args, kwargs=kwargs)
604 self.__process.daemon = True
605 self.__process.start()
606 if self.__limit is not None:
607 self.__timeout = self.__limit + time.time()
608 while not self.ready:
613 """Terminate any possible execution of the embedded function."""
614 if self.__process.is_alive():
615 self.__process.terminate()
616 _raise_exception(self.__timeout_exception, self.__error_message)
620 """Read-only property indicating status of "value" property."""
621 if self.__limit and self.__timeout < time.time():
623 return self.__queue.full() and not self.__queue.empty()
627 """Read-only property containing data returned from function."""
628 if self.ready is True:
629 flag, load = self.__queue.get()
637 seconds: float = 1.0,
638 use_signals: Optional[bool] = None,
639 timeout_exception=exceptions.TimeoutError,
640 error_message="Function call timed out",
642 """Add a timeout parameter to a function and return the function.
644 Note: the use_signals parameter is included in order to support
645 multiprocessing scenarios (signal can only be used from the process'
646 main thread). When not using signals, timeout granularity will be
647 rounded to the nearest 0.1s.
649 Raises an exception when/if the timeout is reached.
651 It is illegal to pass anything other than a function as the first
652 parameter. The function is wrapped and returned to the caller.
655 ... def foo(delay: float):
656 ... time.sleep(delay)
663 Traceback (most recent call last):
665 Exception: Function call timed out
668 if use_signals is None:
671 use_signals = thread_utils.is_current_thread_main_thread()
673 def decorate(function):
676 def handler(signum, frame):
677 _raise_exception(timeout_exception, error_message)
679 @functools.wraps(function)
680 def new_function(*args, **kwargs):
681 new_seconds = kwargs.pop("timeout", seconds)
683 old = signal.signal(signal.SIGALRM, handler)
684 signal.setitimer(signal.ITIMER_REAL, new_seconds)
687 return function(*args, **kwargs)
690 return function(*args, **kwargs)
693 signal.setitimer(signal.ITIMER_REAL, 0)
694 signal.signal(signal.SIGALRM, old)
699 @functools.wraps(function)
700 def new_function(*args, **kwargs):
701 timeout_wrapper = _Timeout(function, timeout_exception, error_message, seconds)
702 return timeout_wrapper(*args, **kwargs)
709 def synchronized(lock):
712 def _gatekeeper(*args, **kw):
715 return f(*args, **kw)
724 def call_with_sample_rate(sample_rate: float) -> Callable:
725 if not 0.0 <= sample_rate <= 1.0:
726 msg = f"sample_rate must be between [0, 1]. Got {sample_rate}."
728 raise ValueError(msg)
732 def _call_with_sample_rate(*args, **kwargs):
733 if random.uniform(0, 1) < sample_rate:
734 return f(*args, **kwargs)
736 logger.debug("@call_with_sample_rate skipping a call to %s", f.__name__)
739 return _call_with_sample_rate
744 def decorate_matching_methods_with(decorator, acl=None):
745 """Apply decorator to all methods in a class whose names begin with
746 prefix. If prefix is None (default), decorate all methods in the
750 def decorate_the_class(cls):
751 for name, m in inspect.getmembers(cls, inspect.isfunction):
753 setattr(cls, name, decorator(m))
756 setattr(cls, name, decorator(m))
759 return decorate_the_class
762 if __name__ == '__main__':