3 # © Copyright 2021-2022, Scott Gasch
4 # Portions (marked) below retain the original author's copyright.
13 import multiprocessing
21 from typing import Any, Callable, List, Optional
23 # This module is commonly used by others in here and should avoid
24 # taking any unnecessary dependencies back on them.
27 logger = logging.getLogger(__name__)
30 def timed(func: Callable) -> Callable:
31 """Print the runtime of the decorated function.
38 >>> foo() # doctest: +ELLIPSIS
43 @functools.wraps(func)
44 def wrapper_timer(*args, **kwargs):
45 start_time = time.perf_counter()
46 value = func(*args, **kwargs)
47 end_time = time.perf_counter()
48 run_time = end_time - start_time
49 msg = f"Finished {func.__qualname__} in {run_time:.4f}s"
57 def invocation_logged(func: Callable) -> Callable:
58 """Log the call of a function.
60 >>> @invocation_logged
62 ... print('Hello, world.')
71 @functools.wraps(func)
72 def wrapper_invocation_logged(*args, **kwargs):
73 msg = f"Entered {func.__qualname__}"
76 ret = func(*args, **kwargs)
77 msg = f"Exited {func.__qualname__}"
82 return wrapper_invocation_logged
85 def rate_limited(n_calls: int, *, per_period_in_seconds: float = 1.0) -> Callable:
86 """Limit invocation of a wrapped function to n calls per period.
87 Thread safe. In testing this was relatively fair with multiple
88 threads using it though that hasn't been measured.
91 >>> import decorator_utils
92 >>> import thread_utils
96 >>> @decorator_utils.rate_limited(10, per_period_in_seconds=1.0)
97 ... def limited(x: int):
101 >>> @thread_utils.background_thread
103 ... for _ in range(3):
106 >>> @thread_utils.background_thread
108 ... for _ in range(3):
111 >>> start = time.time()
116 >>> end = time.time()
117 >>> dur = end - start
125 min_interval_seconds = per_period_in_seconds / float(n_calls)
127 def wrapper_rate_limited(func: Callable) -> Callable:
128 cv = threading.Condition()
129 last_invocation_timestamp = [0.0]
131 def may_proceed() -> float:
133 last_invocation = last_invocation_timestamp[0]
134 if last_invocation != 0.0:
135 elapsed_since_last = now - last_invocation
136 wait_time = min_interval_seconds - elapsed_since_last
139 logger.debug('@%.4f> wait_time = %.4f', time.time(), wait_time)
142 def wrapper_wrapper_rate_limited(*args, **kargs) -> Any:
146 lambda: may_proceed() <= 0.0,
147 timeout=may_proceed(),
151 logger.debug('@%.4f> calling it...', time.time())
152 ret = func(*args, **kargs)
153 last_invocation_timestamp[0] = time.time()
155 '@%.4f> Last invocation <- %.4f', time.time(), last_invocation_timestamp[0]
160 return wrapper_wrapper_rate_limited
162 return wrapper_rate_limited
165 def debug_args(func: Callable) -> Callable:
166 """Print the function signature and return value at each call.
169 ... def foo(a, b, c):
173 ... return (a + b, c)
175 >>> foo(1, 2.0, "test")
176 Calling foo(1:<class 'int'>, 2.0:<class 'float'>, 'test':<class 'str'>)
180 foo returned (3.0, 'test'):<class 'tuple'>
184 @functools.wraps(func)
185 def wrapper_debug_args(*args, **kwargs):
186 args_repr = [f"{repr(a)}:{type(a)}" for a in args]
187 kwargs_repr = [f"{k}={v!r}:{type(v)}" for k, v in kwargs.items()]
188 signature = ", ".join(args_repr + kwargs_repr)
189 msg = f"Calling {func.__qualname__}({signature})"
192 value = func(*args, **kwargs)
193 msg = f"{func.__qualname__} returned {value!r}:{type(value)}"
198 return wrapper_debug_args
201 def debug_count_calls(func: Callable) -> Callable:
202 """Count function invocations and print a message before every call.
204 >>> @debug_count_calls
208 ... return x * factoral(x - 1)
211 Call #1 of 'factoral'
212 Call #2 of 'factoral'
213 Call #3 of 'factoral'
214 Call #4 of 'factoral'
215 Call #5 of 'factoral'
220 @functools.wraps(func)
221 def wrapper_debug_count_calls(*args, **kwargs):
222 wrapper_debug_count_calls.num_calls += 1
223 msg = f"Call #{wrapper_debug_count_calls.num_calls} of {func.__name__!r}"
226 return func(*args, **kwargs)
228 wrapper_debug_count_calls.num_calls = 0 # type: ignore
229 return wrapper_debug_count_calls
232 class DelayWhen(enum.IntEnum):
233 """When should we delay: before or after calling the function (or
244 _func: Callable = None,
246 seconds: float = 1.0,
247 when: DelayWhen = DelayWhen.BEFORE_CALL,
249 """Delay the execution of a function by sleeping before and/or after.
251 Slow down a function by inserting a delay before and/or after its
256 >>> @delay(seconds=1.0)
260 >>> start = time.time()
262 >>> dur = time.time() - start
268 def decorator_delay(func: Callable) -> Callable:
269 @functools.wraps(func)
270 def wrapper_delay(*args, **kwargs):
271 if when & DelayWhen.BEFORE_CALL:
272 logger.debug("@delay for %fs BEFORE_CALL to %s", seconds, func.__name__)
274 retval = func(*args, **kwargs)
275 if when & DelayWhen.AFTER_CALL:
276 logger.debug("@delay for %fs AFTER_CALL to %s", seconds, func.__name__)
283 return decorator_delay
285 return decorator_delay(_func)
288 class _SingletonWrapper:
290 A singleton wrapper class. Its instances would be created
291 for each decorated class.
295 def __init__(self, cls):
296 self.__wrapped__ = cls
297 self._instance = None
299 def __call__(self, *args, **kwargs):
300 """Returns a single instance of decorated class"""
301 logger.debug('@singleton returning global instance of %s', self.__wrapped__.__name__)
302 if self._instance is None:
303 self._instance = self.__wrapped__(*args, **kwargs)
304 return self._instance
309 A singleton decorator. Returns a wrapper object. A call on that object
310 returns a single instance object of decorated class. Use the __wrapped__
311 attribute to access decorated class directly in unit tests
314 ... class foo(object):
326 return _SingletonWrapper(cls)
329 def memoized(func: Callable) -> Callable:
330 """Keep a cache of previous function call results.
332 The cache here is a dict with a key based on the arguments to the
333 call. Consider also: functools.lru_cache for a more advanced
339 ... def expensive(arg) -> int:
340 ... # Simulate something slow to compute or lookup
344 >>> start = time.time()
345 >>> expensive(5) # Takes about 1 sec
348 >>> expensive(3) # Also takes about 1 sec
351 >>> expensive(5) # Pulls from cache, fast
354 >>> expensive(3) # Pulls from cache again, fast
357 >>> dur = time.time() - start
363 @functools.wraps(func)
364 def wrapper_memoized(*args, **kwargs):
365 cache_key = args + tuple(kwargs.items())
366 if cache_key not in wrapper_memoized.cache:
367 value = func(*args, **kwargs)
368 logger.debug('Memoizing %s => %s for %s', cache_key, value, func.__name__)
369 wrapper_memoized.cache[cache_key] = value
371 logger.debug('Returning memoized value for %s', {func.__name__})
372 return wrapper_memoized.cache[cache_key]
374 wrapper_memoized.cache = {} # type: ignore
375 return wrapper_memoized
381 predicate: Callable[..., bool],
382 delay_sec: float = 3.0,
383 backoff: float = 2.0,
385 """Retries a function or method up to a certain number of times
386 with a prescribed initial delay period and backoff rate.
388 tries is the maximum number of attempts to run the function.
389 delay_sec sets the initial delay period in seconds.
390 backoff is a multiplier (must be >= 1) used to modify the delay.
391 predicate is a function that will be passed the retval of the
392 decorated function and must return True to stop or False to
397 msg = f"backoff must be greater than or equal to 1, got {backoff}"
399 raise ValueError(msg)
401 tries = math.floor(tries)
403 msg = f"tries must be 0 or greater, got {tries}"
405 raise ValueError(msg)
408 msg = f"delay_sec must be greater than 0, got {delay_sec}"
410 raise ValueError(msg)
414 def f_retry(*args, **kwargs):
415 mtries, mdelay = tries, delay_sec # make mutable
416 logger.debug('deco_retry: will make up to %d attempts...', mtries)
417 retval = f(*args, **kwargs)
419 if predicate(retval) is True:
420 logger.debug('Predicate succeeded, deco_retry is done.')
422 logger.debug("Predicate failed, sleeping and retrying.")
426 retval = f(*args, **kwargs)
434 def retry_if_false(tries: int, *, delay_sec=3.0, backoff=2.0):
435 """A helper for @retry_predicate that retries a decorated
436 function as long as it keeps returning False.
442 >>> @retry_if_false(5, delay_sec=1.0, backoff=1.1)
446 ... return counter >= 3
448 >>> start = time.time()
449 >>> foo() # fail, delay 1.0, fail, delay 1.1, succeed
452 >>> dur = time.time() - start
461 return retry_predicate(
463 predicate=lambda x: x is True,
469 def retry_if_none(tries: int, *, delay_sec=3.0, backoff=2.0):
470 """Another helper for @retry_predicate above. Retries up to N
471 times so long as the wrapped function returns None with a delay
472 between each retry and a backoff that can increase the delay.
475 return retry_predicate(
477 predicate=lambda x: x is not None,
483 def deprecated(func):
484 """This is a decorator which can be used to mark functions
485 as deprecated. It will result in a warning being emitted
486 when the function is used.
490 @functools.wraps(func)
491 def wrapper_deprecated(*args, **kwargs):
492 msg = f"Call to deprecated function {func.__qualname__}"
494 warnings.warn(msg, category=DeprecationWarning, stacklevel=2)
495 print(msg, file=sys.stderr)
496 return func(*args, **kwargs)
498 return wrapper_deprecated
503 Make a function immediately return a function of no args which,
504 when called, waits for the result, which will start being
505 processed in another thread.
508 @functools.wraps(func)
509 def lazy_thunked(*args, **kwargs):
510 wait_event = threading.Event()
513 exc: List[Any] = [False, None]
517 func_result = func(*args, **kwargs)
518 result[0] = func_result
521 exc[1] = sys.exc_info() # (type, value, traceback)
522 msg = f"Thunkify has thrown an exception (will be raised on thunk()):\n{traceback.format_exc()}"
531 raise exc[1][0](exc[1][1])
534 threading.Thread(target=worker_func).start()
540 ############################################################
542 ############################################################
544 # http://www.saltycrane.com/blog/2010/04/using-python-timeout-decorator-uploading-s3/
546 # in https://code.google.com/p/verse-quiz/source/browse/trunk/timeout.py
549 def _raise_exception(exception, error_message: Optional[str]):
550 if error_message is None:
551 raise Exception(exception)
553 raise Exception(error_message)
556 def _target(queue, function, *args, **kwargs):
557 """Run a function with arguments and return output via a queue.
559 This is a helper function for the Process created in _Timeout. It runs
560 the function with positional arguments and keyword arguments and then
561 returns the function's output by way of a queue. If an exception gets
562 raised, it is returned to _Timeout to be raised by the value property.
565 queue.put((True, function(*args, **kwargs)))
567 queue.put((False, sys.exc_info()[1]))
570 class _Timeout(object):
571 """Wrap a function and add a timeout to it.
573 Instances of this class are automatically generated by the add_timeout
574 function defined below. Do not use directly.
580 timeout_exception: Exception,
584 self.__limit = seconds
585 self.__function = function
586 self.__timeout_exception = timeout_exception
587 self.__error_message = error_message
588 self.__name__ = function.__name__
589 self.__doc__ = function.__doc__
590 self.__timeout = time.time()
591 self.__process = multiprocessing.Process()
592 self.__queue: multiprocessing.queues.Queue = multiprocessing.Queue()
594 def __call__(self, *args, **kwargs):
595 """Execute the embedded function object asynchronously.
597 The function given to the constructor is transparently called and
598 requires that "ready" be intermittently polled. If and when it is
599 True, the "value" property may then be checked for returned data.
601 self.__limit = kwargs.pop("timeout", self.__limit)
602 self.__queue = multiprocessing.Queue(1)
603 args = (self.__queue, self.__function) + args
604 self.__process = multiprocessing.Process(target=_target, args=args, kwargs=kwargs)
605 self.__process.daemon = True
606 self.__process.start()
607 if self.__limit is not None:
608 self.__timeout = self.__limit + time.time()
609 while not self.ready:
614 """Terminate any possible execution of the embedded function."""
615 if self.__process.is_alive():
616 self.__process.terminate()
617 _raise_exception(self.__timeout_exception, self.__error_message)
621 """Read-only property indicating status of "value" property."""
622 if self.__limit and self.__timeout < time.time():
624 return self.__queue.full() and not self.__queue.empty()
628 """Read-only property containing data returned from function."""
629 if self.ready is True:
630 flag, load = self.__queue.get()
638 seconds: float = 1.0,
639 use_signals: Optional[bool] = None,
640 timeout_exception=exceptions.TimeoutError,
641 error_message="Function call timed out",
643 """Add a timeout parameter to a function and return the function.
645 Note: the use_signals parameter is included in order to support
646 multiprocessing scenarios (signal can only be used from the process'
647 main thread). When not using signals, timeout granularity will be
648 rounded to the nearest 0.1s.
650 Raises an exception when/if the timeout is reached.
652 It is illegal to pass anything other than a function as the first
653 parameter. The function is wrapped and returned to the caller.
656 ... def foo(delay: float):
657 ... time.sleep(delay)
664 Traceback (most recent call last):
666 Exception: Function call timed out
669 if use_signals is None:
672 use_signals = thread_utils.is_current_thread_main_thread()
674 def decorate(function):
677 def handler(unused_signum, unused_frame):
678 _raise_exception(timeout_exception, error_message)
680 @functools.wraps(function)
681 def new_function(*args, **kwargs):
682 new_seconds = kwargs.pop("timeout", seconds)
684 old = signal.signal(signal.SIGALRM, handler)
685 signal.setitimer(signal.ITIMER_REAL, new_seconds)
688 return function(*args, **kwargs)
691 return function(*args, **kwargs)
694 signal.setitimer(signal.ITIMER_REAL, 0)
695 signal.signal(signal.SIGALRM, old)
700 @functools.wraps(function)
701 def new_function(*args, **kwargs):
702 timeout_wrapper = _Timeout(function, timeout_exception, error_message, seconds)
703 return timeout_wrapper(*args, **kwargs)
710 def synchronized(lock):
713 def _gatekeeper(*args, **kw):
716 return f(*args, **kw)
725 def call_with_sample_rate(sample_rate: float) -> Callable:
726 if not 0.0 <= sample_rate <= 1.0:
727 msg = f"sample_rate must be between [0, 1]. Got {sample_rate}."
729 raise ValueError(msg)
733 def _call_with_sample_rate(*args, **kwargs):
734 if random.uniform(0, 1) < sample_rate:
735 return f(*args, **kwargs)
737 logger.debug("@call_with_sample_rate skipping a call to %s", f.__name__)
740 return _call_with_sample_rate
745 def decorate_matching_methods_with(decorator, acl=None):
746 """Apply decorator to all methods in a class whose names begin with
747 prefix. If prefix is None (default), decorate all methods in the
751 def decorate_the_class(cls):
752 for name, m in inspect.getmembers(cls, inspect.isfunction):
754 setattr(cls, name, decorator(m))
757 setattr(cls, name, decorator(m))
760 return decorate_the_class
763 if __name__ == '__main__':