3 # © Copyright 2021-2022, Scott Gasch
4 # Portions (marked) below retain the original author's copyright.
6 """Useful(?) decorators."""
13 import multiprocessing
21 from typing import Any, Callable, List, Optional
23 # This module is commonly used by others in here and should avoid
24 # taking any unnecessary dependencies back on them.
26 logger = logging.getLogger(__name__)
29 def timed(func: Callable) -> Callable:
30 """Print the runtime of the decorated function.
37 >>> foo() # doctest: +ELLIPSIS
42 @functools.wraps(func)
43 def wrapper_timer(*args, **kwargs):
44 start_time = time.perf_counter()
45 value = func(*args, **kwargs)
46 end_time = time.perf_counter()
47 run_time = end_time - start_time
48 msg = f"Finished {func.__qualname__} in {run_time:.4f}s"
56 def invocation_logged(func: Callable) -> Callable:
57 """Log the call of a function on stdout and the info log.
59 >>> @invocation_logged
61 ... print('Hello, world.')
70 @functools.wraps(func)
71 def wrapper_invocation_logged(*args, **kwargs):
72 msg = f"Entered {func.__qualname__}"
75 ret = func(*args, **kwargs)
76 msg = f"Exited {func.__qualname__}"
81 return wrapper_invocation_logged
84 def rate_limited(n_calls: int, *, per_period_in_seconds: float = 1.0) -> Callable:
85 """Limit invocation of a wrapped function to n calls per time period.
86 Thread safe. In testing this was relatively fair with multiple
87 threads using it though that hasn't been measured in detail.
90 >>> from pyutils import decorator_utils
91 >>> from pyutils.parallelize import thread_utils
95 >>> @decorator_utils.rate_limited(10, per_period_in_seconds=1.0)
96 ... def limited(x: int):
100 >>> @thread_utils.background_thread
102 ... for _ in range(3):
105 >>> @thread_utils.background_thread
107 ... for _ in range(3):
110 >>> start = time.time()
115 >>> end = time.time()
116 >>> dur = end - start
124 min_interval_seconds = per_period_in_seconds / float(n_calls)
126 def wrapper_rate_limited(func: Callable) -> Callable:
127 cv = threading.Condition()
128 last_invocation_timestamp = [0.0]
130 def may_proceed() -> float:
132 last_invocation = last_invocation_timestamp[0]
133 if last_invocation != 0.0:
134 elapsed_since_last = now - last_invocation
135 wait_time = min_interval_seconds - elapsed_since_last
138 logger.debug('@%.4f> wait_time = %.4f', time.time(), wait_time)
141 def wrapper_wrapper_rate_limited(*args, **kargs) -> Any:
145 lambda: may_proceed() <= 0.0,
146 timeout=may_proceed(),
150 logger.debug('@%.4f> calling it...', time.time())
151 ret = func(*args, **kargs)
152 last_invocation_timestamp[0] = time.time()
154 '@%.4f> Last invocation <- %.4f',
156 last_invocation_timestamp[0],
161 return wrapper_wrapper_rate_limited
163 return wrapper_rate_limited
166 def debug_args(func: Callable) -> Callable:
167 """Print the function signature and return value at each call.
170 ... def foo(a, b, c):
174 ... return (a + b, c)
176 >>> foo(1, 2.0, "test")
177 Calling foo(1:<class 'int'>, 2.0:<class 'float'>, 'test':<class 'str'>)
181 foo returned (3.0, 'test'):<class 'tuple'>
185 @functools.wraps(func)
186 def wrapper_debug_args(*args, **kwargs):
187 args_repr = [f"{repr(a)}:{type(a)}" for a in args]
188 kwargs_repr = [f"{k}={v!r}:{type(v)}" for k, v in kwargs.items()]
189 signature = ", ".join(args_repr + kwargs_repr)
190 msg = f"Calling {func.__qualname__}({signature})"
193 value = func(*args, **kwargs)
194 msg = f"{func.__qualname__} returned {value!r}:{type(value)}"
199 return wrapper_debug_args
202 def debug_count_calls(func: Callable) -> Callable:
203 """Count function invocations and print a message befor every call.
205 >>> @debug_count_calls
209 ... return x * factoral(x - 1)
212 Call #1 of 'factoral'
213 Call #2 of 'factoral'
214 Call #3 of 'factoral'
215 Call #4 of 'factoral'
216 Call #5 of 'factoral'
221 @functools.wraps(func)
222 def wrapper_debug_count_calls(*args, **kwargs):
223 wrapper_debug_count_calls.num_calls += 1
224 msg = f"Call #{wrapper_debug_count_calls.num_calls} of {func.__name__!r}"
227 return func(*args, **kwargs)
229 wrapper_debug_count_calls.num_calls = 0 # type: ignore
230 return wrapper_debug_count_calls
233 class DelayWhen(enum.IntEnum):
234 """When should we delay: before or after calling the function (or
245 _func: Callable = None,
247 seconds: float = 1.0,
248 when: DelayWhen = DelayWhen.BEFORE_CALL,
250 """Slow down a function by inserting a delay before and/or after its
255 >>> @delay(seconds=1.0)
259 >>> start = time.time()
261 >>> dur = time.time() - start
267 def decorator_delay(func: Callable) -> Callable:
268 @functools.wraps(func)
269 def wrapper_delay(*args, **kwargs):
270 if when & DelayWhen.BEFORE_CALL:
271 logger.debug("@delay for %fs BEFORE_CALL to %s", seconds, func.__name__)
273 retval = func(*args, **kwargs)
274 if when & DelayWhen.AFTER_CALL:
275 logger.debug("@delay for %fs AFTER_CALL to %s", seconds, func.__name__)
282 return decorator_delay
284 return decorator_delay(_func)
287 class _SingletonWrapper:
289 A singleton wrapper class. Its instances would be created
290 for each decorated class.
294 def __init__(self, cls):
295 self.__wrapped__ = cls
296 self._instance = None
298 def __call__(self, *args, **kwargs):
299 """Returns a single instance of decorated class"""
301 '@singleton returning global instance of %s', self.__wrapped__.__name__
303 if self._instance is None:
304 self._instance = self.__wrapped__(*args, **kwargs)
305 return self._instance
310 A singleton decorator. Returns a wrapper objects. A call on that object
311 returns a single instance object of decorated class. Use the __wrapped__
312 attribute to access decorated class directly in unit tests
315 ... class foo(object):
327 return _SingletonWrapper(cls)
330 def memoized(func: Callable) -> Callable:
331 """Keep a cache of previous function call results.
333 The cache here is a dict with a key based on the arguments to the
334 call. Consider also: functools.cache for a more advanced
336 https://docs.python.org/3/library/functools.html#functools.cache
341 ... def expensive(arg) -> int:
342 ... # Simulate something slow to compute or lookup
346 >>> start = time.time()
347 >>> expensive(5) # Takes about 1 sec
350 >>> expensive(3) # Also takes about 1 sec
353 >>> expensive(5) # Pulls from cache, fast
356 >>> expensive(3) # Pulls from cache again, fast
359 >>> dur = time.time() - start
365 @functools.wraps(func)
366 def wrapper_memoized(*args, **kwargs):
367 cache_key = args + tuple(kwargs.items())
368 if cache_key not in wrapper_memoized.cache:
369 value = func(*args, **kwargs)
370 logger.debug('Memoizing %s => %s for %s', cache_key, value, func.__name__)
371 wrapper_memoized.cache[cache_key] = value
373 logger.debug('Returning memoized value for %s', {func.__name__})
374 return wrapper_memoized.cache[cache_key]
376 wrapper_memoized.cache = {} # type: ignore
377 return wrapper_memoized
383 predicate: Callable[..., bool],
384 delay_sec: float = 3.0,
385 backoff: float = 2.0,
387 """Retries a function or method up to a certain number of times with a
388 prescribed initial delay period and backoff rate (multiplier).
391 tries: the maximum number of attempts to run the function
392 delay_sec: sets the initial delay period in seconds
393 backoff: a multiplier (must be >=1.0) used to modify the
394 delay at each subsequent invocation
395 predicate: a Callable that will be passed the retval of
396 the decorated function and must return True to indicate
397 that we should stop calling or False to indicate a retry
402 msg = f"backoff must be greater than or equal to 1, got {backoff}"
404 raise ValueError(msg)
406 tries = math.floor(tries)
408 msg = f"tries must be 0 or greater, got {tries}"
410 raise ValueError(msg)
413 msg = f"delay_sec must be greater than 0, got {delay_sec}"
415 raise ValueError(msg)
419 def f_retry(*args, **kwargs):
420 mtries, mdelay = tries, delay_sec # make mutable
421 logger.debug('deco_retry: will make up to %d attempts...', mtries)
422 retval = f(*args, **kwargs)
424 if predicate(retval) is True:
425 logger.debug('Predicate succeeded, deco_retry is done.')
427 logger.debug("Predicate failed, sleeping and retrying.")
431 retval = f(*args, **kwargs)
439 def retry_if_false(tries: int, *, delay_sec=3.0, backoff=2.0):
440 """A helper for @retry_predicate that retries a decorated
441 function as long as it keeps returning False.
446 >>> @retry_if_false(5, delay_sec=1.0, backoff=1.1)
450 ... return counter >= 3
452 >>> start = time.time()
453 >>> foo() # fail, delay 1.0, fail, delay 1.1, succeed
456 >>> dur = time.time() - start
465 return retry_predicate(
467 predicate=lambda x: x is True,
473 def retry_if_none(tries: int, *, delay_sec=3.0, backoff=2.0):
474 """Another helper for @retry_predicate above. Retries up to N
475 times so long as the wrapped function returns None with a delay
476 between each retry and a backoff that can increase the delay.
479 return retry_predicate(
481 predicate=lambda x: x is not None,
487 def deprecated(func):
488 """This is a decorator which can be used to mark functions
489 as deprecated. It will result in a warning being emitted
490 when the function is used.
493 ... def foo() -> None:
495 >>> foo() # prints + logs "Call to deprecated function foo"
498 @functools.wraps(func)
499 def wrapper_deprecated(*args, **kwargs):
500 msg = f"Call to deprecated function {func.__qualname__}"
502 warnings.warn(msg, category=DeprecationWarning, stacklevel=2)
503 print(msg, file=sys.stderr)
504 return func(*args, **kwargs)
506 return wrapper_deprecated
511 Make a function immediately return a function of no args which,
512 when called, waits for the result, which will start being
513 processed in another thread.
516 @functools.wraps(func)
517 def lazy_thunked(*args, **kwargs):
518 wait_event = threading.Event()
521 exc: List[Any] = [False, None]
525 func_result = func(*args, **kwargs)
526 result[0] = func_result
529 exc[1] = sys.exc_info() # (type, value, traceback)
530 msg = f"Thunkify has thrown an exception (will be raised on thunk()):\n{traceback.format_exc()}"
539 raise exc[1][0](exc[1][1])
542 threading.Thread(target=worker_func).start()
548 ############################################################
550 ############################################################
552 # http://www.saltycrane.com/blog/2010/04/using-python-timeout-decorator-uploading-s3/
554 # in https://code.google.com/p/verse-quiz/source/browse/trunk/timeout.py
556 # Original work is covered by PSF-2.0:
558 # 1. This LICENSE AGREEMENT is between the Python Software Foundation
559 # ("PSF"), and the Individual or Organization ("Licensee") accessing
560 # and otherwise using this software ("Python") in source or binary
561 # form and its associated documentation.
563 # 2. Subject to the terms and conditions of this License Agreement,
564 # PSF hereby grants Licensee a nonexclusive, royalty-free, world-wide
565 # license to reproduce, analyze, test, perform and/or display
566 # publicly, prepare derivative works, distribute, and otherwise use
567 # Python alone or in any derivative version, provided, however, that
568 # PSF's License Agreement and PSF's notice of copyright, i.e.,
569 # "Copyright (c) 2001, 2002, 2003, 2004, 2005, 2006 Python Software
570 # Foundation; All Rights Reserved" are retained in Python alone or in
571 # any derivative version prepared by Licensee.
573 # 3. In the event Licensee prepares a derivative work that is based on
574 # or incorporates Python or any part thereof, and wants to make the
575 # derivative work available to others as provided herein, then
576 # Licensee hereby agrees to include in any such work a brief summary
577 # of the changes made to Python.
579 # (N.B. See NOTICE file in the root of this module for a list of
582 # 4. PSF is making Python available to Licensee on an "AS IS"
583 # basis. PSF MAKES NO REPRESENTATIONS OR WARRANTIES, EXPRESS OR
584 # IMPLIED. BY WAY OF EXAMPLE, BUT NOT LIMITATION, PSF MAKES NO AND
585 # DISCLAIMS ANY REPRESENTATION OR WARRANTY OF MERCHANTABILITY OR
586 # FITNESS FOR ANY PARTICULAR PURPOSE OR THAT THE USE OF PYTHON WILL
587 # NOT INFRINGE ANY THIRD PARTY RIGHTS.
589 # 5. PSF SHALL NOT BE LIABLE TO LICENSEE OR ANY OTHER USERS OF PYTHON
590 # FOR ANY INCIDENTAL, SPECIAL, OR CONSEQUENTIAL DAMAGES OR LOSS AS A
591 # RESULT OF MODIFYING, DISTRIBUTING, OR OTHERWISE USING PYTHON, OR ANY
592 # DERIVATIVE THEREOF, EVEN IF ADVISED OF THE POSSIBILITY THEREOF.
594 # 6. This License Agreement will automatically terminate upon a
595 # material breach of its terms and conditions.
597 # 7. Nothing in this License Agreement shall be deemed to create any
598 # relationship of agency, partnership, or joint venture between PSF
599 # and Licensee. This License Agreement does not grant permission to
600 # use PSF trademarks or trade name in a trademark sense to endorse or
601 # promote products or services of Licensee, or any third party.
603 # 8. By copying, installing or otherwise using Python, Licensee agrees
604 # to be bound by the terms and conditions of this License Agreement.
607 def _raise_exception(exception, error_message: Optional[str]):
608 if error_message is None:
609 raise Exception(exception)
611 raise Exception(error_message)
614 def _target(queue, function, *args, **kwargs):
615 """Run a function with arguments and return output via a queue.
617 This is a helper function for the Process created in _Timeout. It runs
618 the function with positional arguments and keyword arguments and then
619 returns the function's output by way of a queue. If an exception gets
620 raised, it is returned to _Timeout to be raised by the value property.
623 queue.put((True, function(*args, **kwargs)))
625 queue.put((False, sys.exc_info()[1]))
628 class _Timeout(object):
629 """Wrap a function and add a timeout to it.
631 Instances of this class are automatically generated by the add_timeout
632 function defined below. Do not use directly.
638 timeout_exception: Exception,
642 self.__limit = seconds
643 self.__function = function
644 self.__timeout_exception = timeout_exception
645 self.__error_message = error_message
646 self.__name__ = function.__name__
647 self.__doc__ = function.__doc__
648 self.__timeout = time.time()
649 self.__process = multiprocessing.Process()
650 self.__queue: multiprocessing.queues.Queue = multiprocessing.Queue()
652 def __call__(self, *args, **kwargs):
653 """Execute the embedded function object asynchronously.
655 The function given to the constructor is transparently called and
656 requires that "ready" be intermittently polled. If and when it is
657 True, the "value" property may then be checked for returned data.
659 self.__limit = kwargs.pop("timeout", self.__limit)
660 self.__queue = multiprocessing.Queue(1)
661 args = (self.__queue, self.__function) + args
662 self.__process = multiprocessing.Process(
663 target=_target, args=args, kwargs=kwargs
665 self.__process.daemon = True
666 self.__process.start()
667 if self.__limit is not None:
668 self.__timeout = self.__limit + time.time()
669 while not self.ready:
674 """Terminate any possible execution of the embedded function."""
675 if self.__process.is_alive():
676 self.__process.terminate()
677 _raise_exception(self.__timeout_exception, self.__error_message)
681 """Read-only property indicating status of "value" property."""
682 if self.__limit and self.__timeout < time.time():
684 return self.__queue.full() and not self.__queue.empty()
688 """Read-only property containing data returned from function."""
689 if self.ready is True:
690 flag, load = self.__queue.get()
698 seconds: float = 1.0,
699 use_signals: Optional[bool] = None,
700 timeout_exception=TimeoutError,
701 error_message="Function call timed out",
703 """Add a timeout parameter to a function and return the function.
705 Note: the use_signals parameter is included in order to support
706 multiprocessing scenarios (signal can only be used from the process'
707 main thread). When not using signals, timeout granularity will be
708 rounded to the nearest 0.1s.
710 Beware that an @timeout on a function inside a module will be
711 evaluated at module load time and not when the wrapped function is
712 invoked. This can lead to problems when relying on the automatic
713 main thread detection code (use_signals=None, the default) since
714 the import probably happens on the main thread and the invocation
715 can happen on a different thread (which can't use signals).
717 Raises an exception when/if the timeout is reached.
719 It is illegal to pass anything other than a function as the first
720 parameter. The function is wrapped and returned to the caller.
723 ... def foo(delay: float):
724 ... time.sleep(delay)
731 Traceback (most recent call last):
733 Exception: Function call timed out
736 if use_signals is None:
737 import pyutils.parallelize.thread_utils as tu
739 use_signals = tu.is_current_thread_main_thread()
741 def decorate(function):
744 def handler(unused_signum, unused_frame):
745 _raise_exception(timeout_exception, error_message)
747 @functools.wraps(function)
748 def new_function(*args, **kwargs):
749 new_seconds = kwargs.pop("timeout", seconds)
751 old = signal.signal(signal.SIGALRM, handler)
752 signal.setitimer(signal.ITIMER_REAL, new_seconds)
755 return function(*args, **kwargs)
758 return function(*args, **kwargs)
761 signal.setitimer(signal.ITIMER_REAL, 0)
762 signal.signal(signal.SIGALRM, old)
767 @functools.wraps(function)
768 def new_function(*args, **kwargs):
769 timeout_wrapper = _Timeout(
770 function, timeout_exception, error_message, seconds
772 return timeout_wrapper(*args, **kwargs)
779 def synchronized(lock):
780 """Emulates java's synchronized keyword: given a lock, require that
781 threads take that lock (or wait) before invoking the wrapped
782 function and automatically releases the lock afterwards.
787 def _gatekeeper(*args, **kw):
790 return f(*args, **kw)
799 def call_with_sample_rate(sample_rate: float) -> Callable:
800 """Calls the wrapped function probabilistically given a rate between
801 0.0 and 1.0 inclusive (0% probability and 100% probability).
804 if not 0.0 <= sample_rate <= 1.0:
805 msg = f"sample_rate must be between [0, 1]. Got {sample_rate}."
807 raise ValueError(msg)
811 def _call_with_sample_rate(*args, **kwargs):
812 if random.uniform(0, 1) < sample_rate:
813 return f(*args, **kwargs)
815 logger.debug("@call_with_sample_rate skipping a call to %s", f.__name__)
818 return _call_with_sample_rate
823 def decorate_matching_methods_with(decorator, acl=None):
824 """Apply the given decorator to all methods in a class whose names
825 begin with prefix. If prefix is None (default), decorate all
826 methods in the class.
829 def decorate_the_class(cls):
830 for name, m in inspect.getmembers(cls, inspect.isfunction):
832 setattr(cls, name, decorator(m))
835 setattr(cls, name, decorator(m))
838 return decorate_the_class
841 if __name__ == '__main__':