10 import multiprocessing
17 from typing import Any, Callable, Optional
20 # This module is commonly used by others in here and should avoid
21 # taking any unnecessary dependencies back on them.
25 logger = logging.getLogger(__name__)
28 def timed(func: Callable) -> Callable:
29 """Print the runtime of the decorated function.
36 >>> foo() # doctest: +ELLIPSIS
41 @functools.wraps(func)
42 def wrapper_timer(*args, **kwargs):
43 start_time = time.perf_counter()
44 value = func(*args, **kwargs)
45 end_time = time.perf_counter()
46 run_time = end_time - start_time
47 msg = f"Finished {func.__qualname__} in {run_time:.4f}s"
55 def invocation_logged(func: Callable) -> Callable:
56 """Log the call of a function.
58 >>> @invocation_logged
60 ... print('Hello, world.')
69 @functools.wraps(func)
70 def wrapper_invocation_logged(*args, **kwargs):
71 msg = f"Entered {func.__qualname__}"
74 ret = func(*args, **kwargs)
75 msg = f"Exited {func.__qualname__}"
80 return wrapper_invocation_logged
83 def rate_limited(n_calls: int, *, per_period_in_seconds: float = 1.0) -> Callable:
84 """Limit invocation of a wrapped function to n calls per period.
85 Thread safe. In testing this was relatively fair with multiple
86 threads using it though that hasn't been measured.
89 >>> import decorator_utils
90 >>> import thread_utils
94 >>> @decorator_utils.rate_limited(10, per_period_in_seconds=1.0)
95 ... def limited(x: int):
99 >>> @thread_utils.background_thread
101 ... for _ in range(3):
104 >>> @thread_utils.background_thread
106 ... for _ in range(3):
109 >>> start = time.time()
114 >>> end = time.time()
115 >>> dur = end - start
123 min_interval_seconds = per_period_in_seconds / float(n_calls)
125 def wrapper_rate_limited(func: Callable) -> Callable:
126 cv = threading.Condition()
127 last_invocation_timestamp = [0.0]
129 def may_proceed() -> float:
131 last_invocation = last_invocation_timestamp[0]
132 if last_invocation != 0.0:
133 elapsed_since_last = now - last_invocation
134 wait_time = min_interval_seconds - elapsed_since_last
137 logger.debug(f'@{time.time()}> wait_time = {wait_time}')
140 def wrapper_wrapper_rate_limited(*args, **kargs) -> Any:
144 lambda: may_proceed() <= 0.0,
145 timeout=may_proceed(),
149 logger.debug(f'@{time.time()}> calling it...')
150 ret = func(*args, **kargs)
151 last_invocation_timestamp[0] = time.time()
153 f'@{time.time()}> Last invocation <- {last_invocation_timestamp[0]}'
158 return wrapper_wrapper_rate_limited
160 return wrapper_rate_limited
163 def debug_args(func: Callable) -> Callable:
164 """Print the function signature and return value at each call.
167 ... def foo(a, b, c):
171 ... return (a + b, c)
173 >>> foo(1, 2.0, "test")
174 Calling foo(1:<class 'int'>, 2.0:<class 'float'>, 'test':<class 'str'>)
178 foo returned (3.0, 'test'):<class 'tuple'>
182 @functools.wraps(func)
183 def wrapper_debug_args(*args, **kwargs):
184 args_repr = [f"{repr(a)}:{type(a)}" for a in args]
185 kwargs_repr = [f"{k}={v!r}:{type(v)}" for k, v in kwargs.items()]
186 signature = ", ".join(args_repr + kwargs_repr)
187 msg = f"Calling {func.__qualname__}({signature})"
190 value = func(*args, **kwargs)
191 msg = f"{func.__qualname__} returned {value!r}:{type(value)}"
196 return wrapper_debug_args
199 def debug_count_calls(func: Callable) -> Callable:
200 """Count function invocations and print a message before every call.
202 >>> @debug_count_calls
206 ... return x * factoral(x - 1)
209 Call #1 of 'factoral'
210 Call #2 of 'factoral'
211 Call #3 of 'factoral'
212 Call #4 of 'factoral'
213 Call #5 of 'factoral'
218 @functools.wraps(func)
219 def wrapper_debug_count_calls(*args, **kwargs):
220 wrapper_debug_count_calls.num_calls += 1
221 msg = f"Call #{wrapper_debug_count_calls.num_calls} of {func.__name__!r}"
224 return func(*args, **kwargs)
226 wrapper_debug_count_calls.num_calls = 0
227 return wrapper_debug_count_calls
230 class DelayWhen(enum.IntEnum):
237 _func: Callable = None,
239 seconds: float = 1.0,
240 when: DelayWhen = DelayWhen.BEFORE_CALL,
242 """Delay the execution of a function by sleeping before and/or after.
244 Slow down a function by inserting a delay before and/or after its
249 >>> @delay(seconds=1.0)
253 >>> start = time.time()
255 >>> dur = time.time() - start
261 def decorator_delay(func: Callable) -> Callable:
262 @functools.wraps(func)
263 def wrapper_delay(*args, **kwargs):
264 if when & DelayWhen.BEFORE_CALL:
265 logger.debug(f"@delay for {seconds}s BEFORE_CALL to {func.__name__}")
267 retval = func(*args, **kwargs)
268 if when & DelayWhen.AFTER_CALL:
269 logger.debug(f"@delay for {seconds}s AFTER_CALL to {func.__name__}")
276 return decorator_delay
278 return decorator_delay(_func)
281 class _SingletonWrapper:
283 A singleton wrapper class. Its instances would be created
284 for each decorated class.
288 def __init__(self, cls):
289 self.__wrapped__ = cls
290 self._instance = None
292 def __call__(self, *args, **kwargs):
293 """Returns a single instance of decorated class"""
295 f"@singleton returning global instance of {self.__wrapped__.__name__}"
297 if self._instance is None:
298 self._instance = self.__wrapped__(*args, **kwargs)
299 return self._instance
304 A singleton decorator. Returns a wrapper object. A call on that object
305 returns the single instance of the decorated class. Use the __wrapped__
306 attribute to access the decorated class directly in unit tests
309 ... class foo(object):
321 return _SingletonWrapper(cls)
324 def memoized(func: Callable) -> Callable:
325 """Keep a cache of previous function call results.
327 The cache here is a dict with a key based on the arguments to the
328 call. Consider also: functools.lru_cache for a more advanced
334 ... def expensive(arg) -> int:
335 ... # Simulate something slow to compute or lookup
339 >>> start = time.time()
340 >>> expensive(5) # Takes about 1 sec
343 >>> expensive(3) # Also takes about 1 sec
346 >>> expensive(5) # Pulls from cache, fast
349 >>> expensive(3) # Pulls from cache again, fast
352 >>> dur = time.time() - start
358 @functools.wraps(func)
359 def wrapper_memoized(*args, **kwargs):
360 cache_key = args + tuple(kwargs.items())
361 if cache_key not in wrapper_memoized.cache:
362 value = func(*args, **kwargs)
363 logger.debug(f"Memoizing {cache_key} => {value} for {func.__name__}")
364 wrapper_memoized.cache[cache_key] = value
366 logger.debug(f"Returning memoized value for {func.__name__}")
367 return wrapper_memoized.cache[cache_key]
369 wrapper_memoized.cache = dict()
370 return wrapper_memoized
376 predicate: Callable[..., bool],
377 delay_sec: float = 3.0,
378 backoff: float = 2.0,
380 """Retries a function or method up to a certain number of times
381 with a prescribed initial delay period and backoff rate.
383 tries is the maximum number of attempts to run the function.
384 delay_sec sets the initial delay period in seconds.
385 backoff is a multiplier (must be >= 1) used to modify the delay.
386 predicate is a function that will be passed the retval of the
387 decorated function and must return True to stop or False to
392 msg = f"backoff must be greater than or equal to 1, got {backoff}"
394 raise ValueError(msg)
396 tries = math.floor(tries)
398 msg = f"tries must be 0 or greater, got {tries}"
400 raise ValueError(msg)
403 msg = f"delay_sec must be greater than 0, got {delay_sec}"
405 raise ValueError(msg)
409 def f_retry(*args, **kwargs):
410 mtries, mdelay = tries, delay_sec # make mutable
411 logger.debug(f'deco_retry: will make up to {mtries} attempts...')
412 retval = f(*args, **kwargs)
414 if predicate(retval) is True:
415 logger.debug('Predicate succeeded, deco_retry is done.')
417 logger.debug("Predicate failed, sleeping and retrying.")
421 retval = f(*args, **kwargs)
429 def retry_if_false(tries: int, *, delay_sec=3.0, backoff=2.0):
430 """A helper for @retry_predicate that retries a decorated
431 function as long as it keeps returning False.
437 >>> @retry_if_false(5, delay_sec=1.0, backoff=1.1)
441 ... return counter >= 3
443 >>> start = time.time()
444 >>> foo() # fail, delay 1.0, fail, delay 1.1, succeed
447 >>> dur = time.time() - start
456 return retry_predicate(
458 predicate=lambda x: x is True,
464 def retry_if_none(tries: int, *, delay_sec=3.0, backoff=2.0):
465 """Another helper for @retry_predicate above. Retries up to N
466 times so long as the wrapped function returns None with a delay
467 between each retry and a backoff that can increase the delay.
470 return retry_predicate(
472 predicate=lambda x: x is not None,
478 def deprecated(func):
479 """This is a decorator which can be used to mark functions
480 as deprecated. It will result in a warning being emitted
481 when the function is used.
485 @functools.wraps(func)
486 def wrapper_deprecated(*args, **kwargs):
487 msg = f"Call to deprecated function {func.__qualname__}"
489 warnings.warn(msg, category=DeprecationWarning, stacklevel=2)
490 print(msg, file=sys.stderr)
491 return func(*args, **kwargs)
493 return wrapper_deprecated
498 Make a function immediately return a function of no args which,
499 when called, waits for the result, which will start being
500 processed in another thread.
503 @functools.wraps(func)
504 def lazy_thunked(*args, **kwargs):
505 wait_event = threading.Event()
512 func_result = func(*args, **kwargs)
513 result[0] = func_result
516 exc[1] = sys.exc_info() # (type, value, traceback)
517 msg = f"Thunkify has thrown an exception (will be raised on thunk()):\n{traceback.format_exc()}"
525 raise exc[1][0](exc[1][1])
528 threading.Thread(target=worker_func).start()
534 ############################################################
536 ############################################################
538 # http://www.saltycrane.com/blog/2010/04/using-python-timeout-decorator-uploading-s3/
540 # in https://code.google.com/p/verse-quiz/source/browse/trunk/timeout.py
543 def _raise_exception(exception, error_message: Optional[str]):
544 if error_message is None:
547 raise Exception(error_message)
550 def _target(queue, function, *args, **kwargs):
551 """Run a function with arguments and return output via a queue.
553 This is a helper function for the Process created in _Timeout. It runs
554 the function with positional arguments and keyword arguments and then
555 returns the function's output by way of a queue. If an exception gets
556 raised, it is returned to _Timeout to be raised by the value property.
559 queue.put((True, function(*args, **kwargs)))
561 queue.put((False, sys.exc_info()[1]))
564 class _Timeout(object):
565 """Wrap a function and add a timeout to it.
567 Instances of this class are automatically generated by the add_timeout
568 function defined below. Do not use directly.
574 timeout_exception: Exception,
578 self.__limit = seconds
579 self.__function = function
580 self.__timeout_exception = timeout_exception
581 self.__error_message = error_message
582 self.__name__ = function.__name__
583 self.__doc__ = function.__doc__
584 self.__timeout = time.time()
585 self.__process = multiprocessing.Process()
586 self.__queue: multiprocessing.queues.Queue = multiprocessing.Queue()
588 def __call__(self, *args, **kwargs):
589 """Execute the embedded function object asynchronously.
591 The function given to the constructor is transparently called and
592 requires that "ready" be intermittently polled. If and when it is
593 True, the "value" property may then be checked for returned data.
595 self.__limit = kwargs.pop("timeout", self.__limit)
596 self.__queue = multiprocessing.Queue(1)
597 args = (self.__queue, self.__function) + args
598 self.__process = multiprocessing.Process(
599 target=_target, args=args, kwargs=kwargs
601 self.__process.daemon = True
602 self.__process.start()
603 if self.__limit is not None:
604 self.__timeout = self.__limit + time.time()
605 while not self.ready:
610 """Terminate any possible execution of the embedded function."""
611 if self.__process.is_alive():
612 self.__process.terminate()
613 _raise_exception(self.__timeout_exception, self.__error_message)
617 """Read-only property indicating status of "value" property."""
618 if self.__limit and self.__timeout < time.time():
620 return self.__queue.full() and not self.__queue.empty()
624 """Read-only property containing data returned from function."""
625 if self.ready is True:
626 flag, load = self.__queue.get()
633 seconds: float = 1.0,
634 use_signals: Optional[bool] = None,
635 timeout_exception=exceptions.TimeoutError,
636 error_message="Function call timed out",
638 """Add a timeout parameter to a function and return the function.
640 Note: the use_signals parameter is included in order to support
641 multiprocessing scenarios (signal can only be used from the process'
642 main thread). When not using signals, timeout granularity will be
643 rounded to the nearest 0.1s.
645 Raises an exception when/if the timeout is reached.
647 It is illegal to pass anything other than a function as the first
648 parameter. The function is wrapped and returned to the caller.
651 ... def foo(delay: float):
652 ... time.sleep(delay)
659 Traceback (most recent call last):
661 Exception: Function call timed out
664 if use_signals is None:
667 use_signals = thread_utils.is_current_thread_main_thread()
669 def decorate(function):
672 def handler(signum, frame):
673 _raise_exception(timeout_exception, error_message)
675 @functools.wraps(function)
676 def new_function(*args, **kwargs):
677 new_seconds = kwargs.pop("timeout", seconds)
679 old = signal.signal(signal.SIGALRM, handler)
680 signal.setitimer(signal.ITIMER_REAL, new_seconds)
683 return function(*args, **kwargs)
686 return function(*args, **kwargs)
689 signal.setitimer(signal.ITIMER_REAL, 0)
690 signal.signal(signal.SIGALRM, old)
695 @functools.wraps(function)
696 def new_function(*args, **kwargs):
697 timeout_wrapper = _Timeout(
698 function, timeout_exception, error_message, seconds
700 return timeout_wrapper(*args, **kwargs)
707 def synchronized(lock):
710 def _gatekeeper(*args, **kw):
713 return f(*args, **kw)
722 def call_with_sample_rate(sample_rate: float) -> Callable:
723 if not 0.0 <= sample_rate <= 1.0:
724 msg = f"sample_rate must be between [0, 1]. Got {sample_rate}."
726 raise ValueError(msg)
730 def _call_with_sample_rate(*args, **kwargs):
731 if random.uniform(0, 1) < sample_rate:
732 return f(*args, **kwargs)
734 logger.debug(f"@call_with_sample_rate skipping a call to {f.__name__}")
736 return _call_with_sample_rate
741 def decorate_matching_methods_with(decorator, acl=None):
742 """Apply decorator to all methods in a class whose names match
743 acl. If acl is None (default), decorate all methods in the
747 def decorate_the_class(cls):
748 for name, m in inspect.getmembers(cls, inspect.isfunction):
750 setattr(cls, name, decorator(m))
753 setattr(cls, name, decorator(m))
756 return decorate_the_class
759 if __name__ == '__main__':