10 import multiprocessing
17 from typing import Any, Callable, Optional
20 # This module is commonly used by others in here and should avoid
21 # taking any unnecessary dependencies back on them.
25 logger = logging.getLogger(__name__)
def timed(func: Callable) -> Callable:
    """Report how long each call to the decorated function takes.

    The elapsed time is measured with time.perf_counter and reported
    once the wrapped call has returned.  The wrapped function's return
    value is passed through untouched.

    >>> @timed
    ... def foo():
    ...     pass

    >>> foo()  # doctest: +ELLIPSIS
    Finished foo in ...

    """

    @functools.wraps(func)
    def wrapper_timer(*args, **kwargs):
        began = time.perf_counter()
        result = func(*args, **kwargs)
        elapsed = time.perf_counter() - began
        report = f"Finished {func.__qualname__} in {elapsed:.4f}s"
        print(report)
        logger.info(report)
        return result

    return wrapper_timer
def invocation_logged(func: Callable) -> Callable:
    """Announce entry to and exit from the decorated function.

    A message is emitted just before the wrapped function runs and
    another just after it returns; the return value is passed through.

    >>> @invocation_logged
    ... def foo():
    ...     print('Hello, world.')

    >>> foo()
    Entered foo
    Hello, world.
    Exited foo

    """

    def _announce(text: str) -> None:
        # Every message goes both to stdout and to the module logger.
        print(text)
        logger.info(text)

    @functools.wraps(func)
    def wrapper_invocation_logged(*args, **kwargs):
        _announce(f"Entered {func.__qualname__}")
        outcome = func(*args, **kwargs)
        _announce(f"Exited {func.__qualname__}")
        return outcome

    return wrapper_invocation_logged
def rate_limited(n_calls: int, *, per_period_in_seconds: float = 1.0) -> Callable:
    """Limit invocation of a wrapped function to n calls per period.

    The limit is enforced by spacing calls at least
    per_period_in_seconds / n_calls seconds apart.  Callers that arrive
    too early block on a condition variable until the interval since
    the previous invocation has elapsed.  The wrapped function runs
    while the condition's lock is held, so invocations are serialized.

    Args:
        n_calls: number of calls allowed per period; must be positive.
        per_period_in_seconds: length of the period, in seconds.

    Returns:
        A decorator that applies the limit to the function it wraps.

    Raises:
        ValueError: if n_calls is zero or negative.

    >>> @rate_limited(1000, per_period_in_seconds=0.001)
    ... def limited(x: int):
    ...     return x * 2

    >>> [limited(i) for i in range(3)]
    [0, 2, 4]

    """
    if n_calls <= 0:
        # Fail with a clear message rather than a ZeroDivisionError below.
        raise ValueError(f"n_calls must be greater than 0, got {n_calls}")
    min_interval_seconds = per_period_in_seconds / float(n_calls)

    def wrapper_rate_limited(func: Callable) -> Callable:
        cv = threading.Condition()
        # One-element list so the nested functions can update it in place;
        # 0.0 means "never invoked yet".
        last_invocation_timestamp = [0.0]

        def may_proceed() -> float:
            """Return the seconds still to wait; <= 0.0 means go now."""
            now = time.time()
            last_invocation = last_invocation_timestamp[0]
            if last_invocation != 0.0:
                elapsed_since_last = now - last_invocation
                wait_time = min_interval_seconds - elapsed_since_last
            else:
                wait_time = 0.0
            logger.debug(f'@{time.time()}> wait_time = {wait_time}')
            return wait_time

        # Preserve the wrapped function's name and docstring.
        @functools.wraps(func)
        def wrapper_wrapper_rate_limited(*args, **kargs) -> Any:
            with cv:
                # wait_for releases the lock while sleeping; loop in case
                # the timeout expires a hair before the interval is up.
                while not cv.wait_for(
                    lambda: may_proceed() <= 0.0,
                    timeout=may_proceed(),
                ):
                    pass
                logger.debug(f'@{time.time()}> calling it...')
                ret = func(*args, **kargs)
                last_invocation_timestamp[0] = time.time()
                logger.debug(
                    f'@{time.time()}> Last invocation <- {last_invocation_timestamp[0]}'
                )
                # Wake a waiter so it can recompute its remaining delay.
                cv.notify()
            return ret

        return wrapper_wrapper_rate_limited

    return wrapper_rate_limited
def debug_args(func: Callable) -> Callable:
    """Show the arguments and the result of every call.

    Each positional argument is rendered as ``repr:type`` and each
    keyword argument as ``name=repr:type``.  One message is emitted
    before the call and one after it with the returned value.

    >>> @debug_args
    ... def foo(a, b, c):
    ...     return (a + b, c)

    >>> foo(1, 2.0, "test")
    Calling foo(1:<class 'int'>, 2.0:<class 'float'>, 'test':<class 'str'>)
    foo returned (3.0, 'test'):<class 'tuple'>
    (3.0, 'test')

    """

    @functools.wraps(func)
    def wrapper_debug_args(*args, **kwargs):
        described = [f"{repr(a)}:{type(a)}" for a in args]
        described.extend(f"{k}={v!r}:{type(v)}" for k, v in kwargs.items())
        signature = ", ".join(described)

        before = f"Calling {func.__qualname__}({signature})"
        print(before)
        logger.info(before)

        value = func(*args, **kwargs)

        after = f"{func.__qualname__} returned {value!r}:{type(value)}"
        print(after)
        logger.info(after)
        return value

    return wrapper_debug_args
def debug_count_calls(func: Callable) -> Callable:
    """Count function invocations and print a message before every call.

    The running total is kept on the returned wrapper as the attribute
    ``num_calls``, which starts at zero.

    >>> @debug_count_calls
    ... def factoral(x):
    ...     if x <= 1:
    ...         return 1
    ...     return x * factoral(x - 1)

    >>> factoral(5)
    Call #1 of 'factoral'
    Call #2 of 'factoral'
    Call #3 of 'factoral'
    Call #4 of 'factoral'
    Call #5 of 'factoral'
    120

    """

    @functools.wraps(func)
    def wrapper_debug_count_calls(*args, **kwargs):
        tally = wrapper_debug_count_calls.num_calls + 1
        wrapper_debug_count_calls.num_calls = tally
        note = f"Call #{tally} of {func.__name__!r}"
        print(note)
        logger.info(note)
        return func(*args, **kwargs)

    # The counter lives on the wrapper itself so callers can inspect it.
    wrapper_debug_count_calls.num_calls = 0
    return wrapper_debug_count_calls
224 class DelayWhen(enum.IntEnum):
231 _func: Callable = None,
233 seconds: float = 1.0,
234 when: DelayWhen = DelayWhen.BEFORE_CALL,
236 """Delay the execution of a function by sleeping before and/or after.
238 Slow down a function by inserting a delay before and/or after its
243 >>> @delay(seconds=1.0)
247 >>> start = time.time()
249 >>> dur = time.time() - start
254 def decorator_delay(func: Callable) -> Callable:
255 @functools.wraps(func)
256 def wrapper_delay(*args, **kwargs):
257 if when & DelayWhen.BEFORE_CALL:
259 f"@delay for {seconds}s BEFORE_CALL to {func.__name__}"
262 retval = func(*args, **kwargs)
263 if when & DelayWhen.AFTER_CALL:
265 f"@delay for {seconds}s AFTER_CALL to {func.__name__}"
272 return decorator_delay
274 return decorator_delay(_func)
class _SingletonWrapper:
    """
    Callable stand-in for a decorated class.  One wrapper is created per
    decorated class; calling it hands back the same instance every time.
    """

    def __init__(self, cls):
        # The decorated class stays reachable through __wrapped__.
        self.__wrapped__ = cls
        self._instance = None

    def __call__(self, *args, **kwargs):
        """Return the one shared instance, constructing it on first use.

        Arguments are forwarded to the class only on the call that
        builds the instance; later calls ignore them.
        """
        logger.debug(
            f"@singleton returning global instance of {self.__wrapped__.__name__}"
        )
        existing = self._instance
        if existing is not None:
            return existing
        self._instance = self.__wrapped__(*args, **kwargs)
        return self._instance
    A singleton decorator. Returns a wrapper object. A call on that object
301 returns a single instance object of decorated class. Use the __wrapped__
302 attribute to access decorated class directly in unit tests
305 ... class foo(object):
317 return _SingletonWrapper(cls)
def memoized(func: Callable) -> Callable:
    """Keep a cache of previous function call results.

    The cache here is a dict with a key based on the arguments to the
    call, so every argument must be hashable.  The key keeps positional
    and keyword arguments apart and ignores the order in which keyword
    arguments were written, so ``f(a=1, b=2)`` and ``f(b=2, a=1)``
    share one entry while ``f(("a", 1))`` and ``f(a=1)`` do not.
    The cache is unbounded and is exposed on the wrapper as the
    attribute ``cache``.  Consider also: functools.lru_cache for a more
    advanced implementation.

    >>> calls = []
    >>> @memoized
    ... def expensive(arg) -> int:
    ...     calls.append(arg)
    ...     return arg * arg

    >>> expensive(5)
    25
    >>> expensive(5)  # Pulls from cache
    25
    >>> calls
    [5]

    """

    @functools.wraps(func)
    def wrapper_memoized(*args, **kwargs):
        # Nesting args and kwargs separately stops a positional
        # (name, value) tuple from colliding with a keyword argument;
        # sorting makes the key independent of keyword order.
        cache_key = (args, tuple(sorted(kwargs.items())))
        if cache_key not in wrapper_memoized.cache:
            value = func(*args, **kwargs)
            logger.debug(
                f"Memoizing {cache_key} => {value} for {func.__name__}"
            )
            wrapper_memoized.cache[cache_key] = value
        else:
            logger.debug(f"Returning memoized value for {func.__name__}")
        return wrapper_memoized.cache[cache_key]

    wrapper_memoized.cache = dict()
    return wrapper_memoized
372 predicate: Callable[..., bool],
373 delay_sec: float = 3.0,
374 backoff: float = 2.0,
376 """Retries a function or method up to a certain number of times
377 with a prescribed initial delay period and backoff rate.
379 tries is the maximum number of attempts to run the function.
380 delay_sec sets the initial delay period in seconds.
    backoff is a multiplier (must be >= 1) used to modify the delay.
382 predicate is a function that will be passed the retval of the
383 decorated function and must return True to stop or False to
388 msg = f"backoff must be greater than or equal to 1, got {backoff}"
390 raise ValueError(msg)
392 tries = math.floor(tries)
394 msg = f"tries must be 0 or greater, got {tries}"
396 raise ValueError(msg)
399 msg = f"delay_sec must be greater than 0, got {delay_sec}"
401 raise ValueError(msg)
405 def f_retry(*args, **kwargs):
406 mtries, mdelay = tries, delay_sec # make mutable
407 logger.debug(f'deco_retry: will make up to {mtries} attempts...')
408 retval = f(*args, **kwargs)
410 if predicate(retval) is True:
411 logger.debug('Predicate succeeded, deco_retry is done.')
413 logger.debug("Predicate failed, sleeping and retrying.")
417 retval = f(*args, **kwargs)
def retry_if_false(tries: int, *, delay_sec=3.0, backoff=2.0):
    """Retry the decorated function until it returns True.

    A thin convenience layer over @retry_predicate: the attempt counts
    as a success only when the return value ``is True``; any other
    return value leads to another attempt.

    tries is the maximum number of attempts, delay_sec the initial
    delay between attempts and backoff the factor applied to that
    delay; all three are handed to retry_predicate unchanged.

    >>> counter = 0
    >>> @retry_if_false(5, delay_sec=1.0, backoff=1.1)
    ... def foo():
    ...     global counter
    ...     counter += 1
    ...     return counter >= 3

    >>> foo()  # fail, delay 1.0, fail, delay 1.1, succeed
    True

    """

    def _returned_true(outcome) -> bool:
        return outcome is True

    return retry_predicate(
        tries,
        predicate=_returned_true,
        delay_sec=delay_sec,
        backoff=backoff,
    )
def retry_if_none(tries: int, *, delay_sec=3.0, backoff=2.0):
    """Retry the decorated function while it keeps returning None.

    A thin convenience layer over @retry_predicate: the attempt counts
    as a success as soon as the return value is anything other than
    None.

    tries is the maximum number of attempts, delay_sec the initial
    delay between attempts and backoff the factor applied to that
    delay; all three are handed to retry_predicate unchanged.
    """

    def _returned_something(outcome) -> bool:
        return outcome is not None

    return retry_predicate(
        tries,
        predicate=_returned_something,
        delay_sec=delay_sec,
        backoff=backoff,
    )
def deprecated(func):
    """Mark a function as deprecated.

    Every call to the decorated function issues a DeprecationWarning
    (attributed to the caller via stacklevel=2) and writes the same
    message to stderr before the function itself runs.
    """

    @functools.wraps(func)
    def wrapper_deprecated(*args, **kwargs):
        notice = f"Call to deprecated function {func.__qualname__}"
        logger.warning(notice)
        warnings.warn(notice, category=DeprecationWarning, stacklevel=2)
        print(notice, file=sys.stderr)
        return func(*args, **kwargs)

    return wrapper_deprecated
490 Make a function immediately return a function of no args which,
491 when called, waits for the result, which will start being
492 processed in another thread.
495 @functools.wraps(func)
496 def lazy_thunked(*args, **kwargs):
497 wait_event = threading.Event()
504 func_result = func(*args, **kwargs)
505 result[0] = func_result
508 exc[1] = sys.exc_info() # (type, value, traceback)
509 msg = f"Thunkify has thrown an exception (will be raised on thunk()):\n{traceback.format_exc()}"
518 raise exc[1][0](exc[1][1])
521 threading.Thread(target=worker_func).start()
527 ############################################################
529 ############################################################
531 # http://www.saltycrane.com/blog/2010/04/using-python-timeout-decorator-uploading-s3/
533 # in https://code.google.com/p/verse-quiz/source/browse/trunk/timeout.py
536 def _raise_exception(exception, error_message: Optional[str]):
537 if error_message is None:
540 raise Exception(error_message)
543 def _target(queue, function, *args, **kwargs):
544 """Run a function with arguments and return output via a queue.
546 This is a helper function for the Process created in _Timeout. It runs
547 the function with positional arguments and keyword arguments and then
548 returns the function's output by way of a queue. If an exception gets
549 raised, it is returned to _Timeout to be raised by the value property.
552 queue.put((True, function(*args, **kwargs)))
554 queue.put((False, sys.exc_info()[1]))
557 class _Timeout(object):
558 """Wrap a function and add a timeout to it.
560 Instances of this class are automatically generated by the add_timeout
561 function defined below. Do not use directly.
567 timeout_exception: Exception,
571 self.__limit = seconds
572 self.__function = function
573 self.__timeout_exception = timeout_exception
574 self.__error_message = error_message
575 self.__name__ = function.__name__
576 self.__doc__ = function.__doc__
577 self.__timeout = time.time()
578 self.__process = multiprocessing.Process()
579 self.__queue: multiprocessing.queues.Queue = multiprocessing.Queue()
581 def __call__(self, *args, **kwargs):
582 """Execute the embedded function object asynchronously.
584 The function given to the constructor is transparently called and
585 requires that "ready" be intermittently polled. If and when it is
586 True, the "value" property may then be checked for returned data.
588 self.__limit = kwargs.pop("timeout", self.__limit)
589 self.__queue = multiprocessing.Queue(1)
590 args = (self.__queue, self.__function) + args
591 self.__process = multiprocessing.Process(
592 target=_target, args=args, kwargs=kwargs
594 self.__process.daemon = True
595 self.__process.start()
596 if self.__limit is not None:
597 self.__timeout = self.__limit + time.time()
598 while not self.ready:
603 """Terminate any possible execution of the embedded function."""
604 if self.__process.is_alive():
605 self.__process.terminate()
606 _raise_exception(self.__timeout_exception, self.__error_message)
610 """Read-only property indicating status of "value" property."""
611 if self.__limit and self.__timeout < time.time():
613 return self.__queue.full() and not self.__queue.empty()
617 """Read-only property containing data returned from function."""
618 if self.ready is True:
619 flag, load = self.__queue.get()
626 seconds: float = 1.0,
627 use_signals: Optional[bool] = None,
628 timeout_exception=exceptions.TimeoutError,
629 error_message="Function call timed out",
631 """Add a timeout parameter to a function and return the function.
633 Note: the use_signals parameter is included in order to support
634 multiprocessing scenarios (signal can only be used from the process'
635 main thread). When not using signals, timeout granularity will be
636 rounded to the nearest 0.1s.
638 Raises an exception when/if the timeout is reached.
640 It is illegal to pass anything other than a function as the first
641 parameter. The function is wrapped and returned to the caller.
644 ... def foo(delay: float):
645 ... time.sleep(delay)
652 Traceback (most recent call last):
654 Exception: Function call timed out
657 if use_signals is None:
659 use_signals = thread_utils.is_current_thread_main_thread()
661 def decorate(function):
664 def handler(signum, frame):
665 _raise_exception(timeout_exception, error_message)
667 @functools.wraps(function)
668 def new_function(*args, **kwargs):
669 new_seconds = kwargs.pop("timeout", seconds)
671 old = signal.signal(signal.SIGALRM, handler)
672 signal.setitimer(signal.ITIMER_REAL, new_seconds)
675 return function(*args, **kwargs)
678 return function(*args, **kwargs)
681 signal.setitimer(signal.ITIMER_REAL, 0)
682 signal.signal(signal.SIGALRM, old)
687 @functools.wraps(function)
688 def new_function(*args, **kwargs):
689 timeout_wrapper = _Timeout(
690 function, timeout_exception, error_message, seconds
692 return timeout_wrapper(*args, **kwargs)
class non_reentrant_code(object):
    """Decorator that stops the wrapped callable from re-entering itself.

    Use as ``@non_reentrant_code()``.  While a decorated call is in
    progress, a nested call made from the same thread (recursion, a
    callback, ...) is skipped and returns None.  Calls from other
    threads wait on the lock until the running call has finished.
    Everything decorated by the same instance shares one lock and one
    "entered" flag.

    >>> seen = []
    >>> @non_reentrant_code()
    ... def visit(depth):
    ...     seen.append(depth)
    ...     visit(depth + 1)

    >>> visit(0)
    >>> seen
    [0]

    """

    def __init__(self):
        # This has to be an RLock *instance*: the bare threading.RLock
        # factory cannot be used in a with statement.
        self._lock = threading.RLock()
        self._entered = False

    def __call__(self, f):
        @functools.wraps(f)
        def _gatekeeper(*args, **kwargs):
            with self._lock:
                if self._entered:
                    # Re-entrant call from the thread holding the lock.
                    return None
                self._entered = True
                try:
                    return f(*args, **kwargs)
                finally:
                    # Reset even on an exception, otherwise every later
                    # call would be silently skipped.
                    self._entered = False

        return _gatekeeper
class rlocked(object):
    """Decorator that runs the wrapped callable under a re-entrant lock.

    Use as ``@rlocked()``.  The lock is held for the whole duration of
    the call, so calls from other threads wait their turn.  A nested
    call made from the thread that already holds the lock is skipped
    and returns None.  Everything decorated by the same instance shares
    one lock and one "entered" flag.

    >>> seen = []
    >>> @rlocked()
    ... def visit(depth):
    ...     seen.append(depth)
    ...     visit(depth + 1)

    >>> visit(0)
    >>> seen
    [0]

    """

    def __init__(self):
        # This has to be an RLock *instance*: the bare threading.RLock
        # factory cannot be used in a with statement.
        self._lock = threading.RLock()
        self._entered = False

    def __call__(self, f):
        @functools.wraps(f)
        def _gatekeeper(*args, **kwargs):
            with self._lock:
                if self._entered:
                    # Nested call from the thread holding the lock.
                    return None
                self._entered = True
                try:
                    return f(*args, **kwargs)
                finally:
                    # Reset even on an exception, otherwise every later
                    # call would be silently skipped.
                    self._entered = False

        return _gatekeeper
def call_with_sample_rate(sample_rate: float) -> Callable:
    """Let only a random fraction of calls reach the decorated function.

    On every call a number is drawn with random.uniform(0, 1); the
    wrapped function runs only when the draw is below sample_rate.
    A skipped call returns None.

    Raises:
        ValueError: if sample_rate is not within [0, 1].
    """
    in_range = 0.0 <= sample_rate <= 1.0
    if not in_range:
        msg = f"sample_rate must be between [0, 1]. Got {sample_rate}."
        logger.critical(msg)
        raise ValueError(msg)

    def decorator(f):
        @functools.wraps(f)
        def _call_with_sample_rate(*args, **kwargs):
            draw = random.uniform(0, 1)
            if draw >= sample_rate:
                logger.debug(
                    f"@call_with_sample_rate skipping a call to {f.__name__}"
                )
                return None
            return f(*args, **kwargs)

        return _call_with_sample_rate

    return decorator
def decorate_matching_methods_with(decorator, acl=None):
    """Apply decorator to the methods of a class that acl selects.

    Intended for use as a class decorator.  Each function found on the
    class by inspect.getmembers(cls, inspect.isfunction) is replaced by
    decorator(function) when it is selected.

    Args:
        decorator: callable applied to each selected method.
        acl: optional callable that is given a method name and returns
            a truthy value for names that should be decorated.  If acl
            is None (default), decorate all methods in the class.

    Returns:
        A class decorator; it modifies the class in place and returns it.

    >>> def shout(method):
    ...     return lambda self: method(self).upper()

    >>> @decorate_matching_methods_with(shout, acl=lambda n: n.startswith('say'))
    ... class Greeter:
    ...     def say_hi(self):
    ...         return 'hi'
    ...     def whisper(self):
    ...         return 'psst'

    >>> Greeter().say_hi(), Greeter().whisper()
    ('HI', 'psst')

    """

    def decorate_the_class(cls):
        for name, m in inspect.getmembers(cls, inspect.isfunction):
            if acl is None or acl(name):
                setattr(cls, name, decorator(m))
        return cls

    return decorate_the_class
766 if __name__ == '__main__':