10 import multiprocessing
18 from typing import Any, Callable, Optional
20 # This module is commonly used by others in here and should avoid
21 # taking any unnecessary dependencies back on them.
24 logger = logging.getLogger(__name__)
27 def timed(func: Callable) -> Callable:
28 """Print the runtime of the decorated function.
35 >>> foo() # doctest: +ELLIPSIS
40 @functools.wraps(func)
41 def wrapper_timer(*args, **kwargs):
42 start_time = time.perf_counter()
43 value = func(*args, **kwargs)
44 end_time = time.perf_counter()
45 run_time = end_time - start_time
46 msg = f"Finished {func.__qualname__} in {run_time:.4f}s"
54 def invocation_logged(func: Callable) -> Callable:
55 """Log the call of a function.
57 >>> @invocation_logged
59 ... print('Hello, world.')
68 @functools.wraps(func)
69 def wrapper_invocation_logged(*args, **kwargs):
70 msg = f"Entered {func.__qualname__}"
73 ret = func(*args, **kwargs)
74 msg = f"Exited {func.__qualname__}"
79 return wrapper_invocation_logged
82 def rate_limited(n_calls: int, *, per_period_in_seconds: float = 1.0) -> Callable:
83 """Limit invocation of a wrapped function to n calls per period.
84 Thread safe. In testing this was relatively fair with multiple
85 threads using it though that hasn't been measured.
88 >>> import decorator_utils
89 >>> import thread_utils
93 >>> @decorator_utils.rate_limited(10, per_period_in_seconds=1.0)
94 ... def limited(x: int):
98 >>> @thread_utils.background_thread
100 ... for _ in range(3):
103 >>> @thread_utils.background_thread
105 ... for _ in range(3):
108 >>> start = time.time()
113 >>> end = time.time()
114 >>> dur = end - start
122 min_interval_seconds = per_period_in_seconds / float(n_calls)
124 def wrapper_rate_limited(func: Callable) -> Callable:
125 cv = threading.Condition()
126 last_invocation_timestamp = [0.0]
128 def may_proceed() -> float:
130 last_invocation = last_invocation_timestamp[0]
131 if last_invocation != 0.0:
132 elapsed_since_last = now - last_invocation
133 wait_time = min_interval_seconds - elapsed_since_last
136 logger.debug('@%.4f> wait_time = %.4f', time.time(), wait_time)
139 def wrapper_wrapper_rate_limited(*args, **kargs) -> Any:
143 lambda: may_proceed() <= 0.0,
144 timeout=may_proceed(),
148 logger.debug('@%.4f> calling it...', time.time())
149 ret = func(*args, **kargs)
150 last_invocation_timestamp[0] = time.time()
152 '@%.4f> Last invocation <- %.4f', time.time(), last_invocation_timestamp[0]
157 return wrapper_wrapper_rate_limited
159 return wrapper_rate_limited
162 def debug_args(func: Callable) -> Callable:
163 """Print the function signature and return value at each call.
166 ... def foo(a, b, c):
170 ... return (a + b, c)
172 >>> foo(1, 2.0, "test")
173 Calling foo(1:<class 'int'>, 2.0:<class 'float'>, 'test':<class 'str'>)
177 foo returned (3.0, 'test'):<class 'tuple'>
181 @functools.wraps(func)
182 def wrapper_debug_args(*args, **kwargs):
183 args_repr = [f"{repr(a)}:{type(a)}" for a in args]
184 kwargs_repr = [f"{k}={v!r}:{type(v)}" for k, v in kwargs.items()]
185 signature = ", ".join(args_repr + kwargs_repr)
186 msg = f"Calling {func.__qualname__}({signature})"
189 value = func(*args, **kwargs)
190 msg = f"{func.__qualname__} returned {value!r}:{type(value)}"
195 return wrapper_debug_args
198 def debug_count_calls(func: Callable) -> Callable:
199 """Count function invocations and print a message befor every call.
201 >>> @debug_count_calls
205 ... return x * factoral(x - 1)
208 Call #1 of 'factoral'
209 Call #2 of 'factoral'
210 Call #3 of 'factoral'
211 Call #4 of 'factoral'
212 Call #5 of 'factoral'
217 @functools.wraps(func)
218 def wrapper_debug_count_calls(*args, **kwargs):
219 wrapper_debug_count_calls.num_calls += 1
220 msg = f"Call #{wrapper_debug_count_calls.num_calls} of {func.__name__!r}"
223 return func(*args, **kwargs)
225 wrapper_debug_count_calls.num_calls = 0 # type: ignore
226 return wrapper_debug_count_calls
229 class DelayWhen(enum.IntEnum):
230 """When should we delay: before or after calling the function (or
241 _func: Callable = None,
243 seconds: float = 1.0,
244 when: DelayWhen = DelayWhen.BEFORE_CALL,
246 """Delay the execution of a function by sleeping before and/or after.
248 Slow down a function by inserting a delay before and/or after its
253 >>> @delay(seconds=1.0)
257 >>> start = time.time()
259 >>> dur = time.time() - start
265 def decorator_delay(func: Callable) -> Callable:
266 @functools.wraps(func)
267 def wrapper_delay(*args, **kwargs):
268 if when & DelayWhen.BEFORE_CALL:
269 logger.debug("@delay for %fs BEFORE_CALL to %s", seconds, func.__name__)
271 retval = func(*args, **kwargs)
272 if when & DelayWhen.AFTER_CALL:
273 logger.debug("@delay for %fs AFTER_CALL to %s", seconds, func.__name__)
280 return decorator_delay
282 return decorator_delay(_func)
285 class _SingletonWrapper:
287 A singleton wrapper class. Its instances would be created
288 for each decorated class.
292 def __init__(self, cls):
293 self.__wrapped__ = cls
294 self._instance = None
296 def __call__(self, *args, **kwargs):
297 """Returns a single instance of decorated class"""
298 logger.debug('@singleton returning global instance of %s', self.__wrapped__.__name__)
299 if self._instance is None:
300 self._instance = self.__wrapped__(*args, **kwargs)
301 return self._instance
306 A singleton decorator. Returns a wrapper objects. A call on that object
307 returns a single instance object of decorated class. Use the __wrapped__
308 attribute to access decorated class directly in unit tests
311 ... class foo(object):
323 return _SingletonWrapper(cls)
326 def memoized(func: Callable) -> Callable:
327 """Keep a cache of previous function call results.
329 The cache here is a dict with a key based on the arguments to the
330 call. Consider also: functools.lru_cache for a more advanced
336 ... def expensive(arg) -> int:
337 ... # Simulate something slow to compute or lookup
341 >>> start = time.time()
342 >>> expensive(5) # Takes about 1 sec
345 >>> expensive(3) # Also takes about 1 sec
348 >>> expensive(5) # Pulls from cache, fast
351 >>> expensive(3) # Pulls from cache again, fast
354 >>> dur = time.time() - start
360 @functools.wraps(func)
361 def wrapper_memoized(*args, **kwargs):
362 cache_key = args + tuple(kwargs.items())
363 if cache_key not in wrapper_memoized.cache:
364 value = func(*args, **kwargs)
365 logger.debug('Memoizing %s => %s for %s', cache_key, value, func.__name__)
366 wrapper_memoized.cache[cache_key] = value
368 logger.debug('Returning memoized value for %s', {func.__name__})
369 return wrapper_memoized.cache[cache_key]
371 wrapper_memoized.cache = {} # type: ignore
372 return wrapper_memoized
378 predicate: Callable[..., bool],
379 delay_sec: float = 3.0,
380 backoff: float = 2.0,
382 """Retries a function or method up to a certain number of times
383 with a prescribed initial delay period and backoff rate.
385 tries is the maximum number of attempts to run the function.
386 delay_sec sets the initial delay period in seconds.
387 backoff is a multiplied (must be >1) used to modify the delay.
388 predicate is a function that will be passed the retval of the
389 decorated function and must return True to stop or False to
394 msg = f"backoff must be greater than or equal to 1, got {backoff}"
396 raise ValueError(msg)
398 tries = math.floor(tries)
400 msg = f"tries must be 0 or greater, got {tries}"
402 raise ValueError(msg)
405 msg = f"delay_sec must be greater than 0, got {delay_sec}"
407 raise ValueError(msg)
411 def f_retry(*args, **kwargs):
412 mtries, mdelay = tries, delay_sec # make mutable
413 logger.debug('deco_retry: will make up to %d attempts...', mtries)
414 retval = f(*args, **kwargs)
416 if predicate(retval) is True:
417 logger.debug('Predicate succeeded, deco_retry is done.')
419 logger.debug("Predicate failed, sleeping and retrying.")
423 retval = f(*args, **kwargs)
431 def retry_if_false(tries: int, *, delay_sec=3.0, backoff=2.0):
432 """A helper for @retry_predicate that retries a decorated
433 function as long as it keeps returning False.
439 >>> @retry_if_false(5, delay_sec=1.0, backoff=1.1)
443 ... return counter >= 3
445 >>> start = time.time()
446 >>> foo() # fail, delay 1.0, fail, delay 1.1, succeed
449 >>> dur = time.time() - start
458 return retry_predicate(
460 predicate=lambda x: x is True,
466 def retry_if_none(tries: int, *, delay_sec=3.0, backoff=2.0):
467 """Another helper for @retry_predicate above. Retries up to N
468 times so long as the wrapped function returns None with a delay
469 between each retry and a backoff that can increase the delay.
472 return retry_predicate(
474 predicate=lambda x: x is not None,
480 def deprecated(func):
481 """This is a decorator which can be used to mark functions
482 as deprecated. It will result in a warning being emitted
483 when the function is used.
487 @functools.wraps(func)
488 def wrapper_deprecated(*args, **kwargs):
489 msg = f"Call to deprecated function {func.__qualname__}"
491 warnings.warn(msg, category=DeprecationWarning, stacklevel=2)
492 print(msg, file=sys.stderr)
493 return func(*args, **kwargs)
495 return wrapper_deprecated
500 Make a function immediately return a function of no args which,
501 when called, waits for the result, which will start being
502 processed in another thread.
505 @functools.wraps(func)
506 def lazy_thunked(*args, **kwargs):
507 wait_event = threading.Event()
514 func_result = func(*args, **kwargs)
515 result[0] = func_result
518 exc[1] = sys.exc_info() # (type, value, traceback)
519 msg = f"Thunkify has thrown an exception (will be raised on thunk()):\n{traceback.format_exc()}"
527 raise exc[1][0](exc[1][1])
530 threading.Thread(target=worker_func).start()
536 ############################################################
538 ############################################################
540 # http://www.saltycrane.com/blog/2010/04/using-python-timeout-decorator-uploading-s3/
542 # in https://code.google.com/p/verse-quiz/source/browse/trunk/timeout.py
545 def _raise_exception(exception, error_message: Optional[str]):
546 if error_message is None:
547 raise Exception(exception)
549 raise Exception(error_message)
552 def _target(queue, function, *args, **kwargs):
553 """Run a function with arguments and return output via a queue.
555 This is a helper function for the Process created in _Timeout. It runs
556 the function with positional arguments and keyword arguments and then
557 returns the function's output by way of a queue. If an exception gets
558 raised, it is returned to _Timeout to be raised by the value property.
561 queue.put((True, function(*args, **kwargs)))
563 queue.put((False, sys.exc_info()[1]))
566 class _Timeout(object):
567 """Wrap a function and add a timeout to it.
569 Instances of this class are automatically generated by the add_timeout
570 function defined below. Do not use directly.
576 timeout_exception: Exception,
580 self.__limit = seconds
581 self.__function = function
582 self.__timeout_exception = timeout_exception
583 self.__error_message = error_message
584 self.__name__ = function.__name__
585 self.__doc__ = function.__doc__
586 self.__timeout = time.time()
587 self.__process = multiprocessing.Process()
588 self.__queue: multiprocessing.queues.Queue = multiprocessing.Queue()
590 def __call__(self, *args, **kwargs):
591 """Execute the embedded function object asynchronously.
593 The function given to the constructor is transparently called and
594 requires that "ready" be intermittently polled. If and when it is
595 True, the "value" property may then be checked for returned data.
597 self.__limit = kwargs.pop("timeout", self.__limit)
598 self.__queue = multiprocessing.Queue(1)
599 args = (self.__queue, self.__function) + args
600 self.__process = multiprocessing.Process(target=_target, args=args, kwargs=kwargs)
601 self.__process.daemon = True
602 self.__process.start()
603 if self.__limit is not None:
604 self.__timeout = self.__limit + time.time()
605 while not self.ready:
610 """Terminate any possible execution of the embedded function."""
611 if self.__process.is_alive():
612 self.__process.terminate()
613 _raise_exception(self.__timeout_exception, self.__error_message)
617 """Read-only property indicating status of "value" property."""
618 if self.__limit and self.__timeout < time.time():
620 return self.__queue.full() and not self.__queue.empty()
624 """Read-only property containing data returned from function."""
625 if self.ready is True:
626 flag, load = self.__queue.get()
634 seconds: float = 1.0,
635 use_signals: Optional[bool] = None,
636 timeout_exception=exceptions.TimeoutError,
637 error_message="Function call timed out",
639 """Add a timeout parameter to a function and return the function.
641 Note: the use_signals parameter is included in order to support
642 multiprocessing scenarios (signal can only be used from the process'
643 main thread). When not using signals, timeout granularity will be
644 rounded to the nearest 0.1s.
646 Raises an exception when/if the timeout is reached.
648 It is illegal to pass anything other than a function as the first
649 parameter. The function is wrapped and returned to the caller.
652 ... def foo(delay: float):
653 ... time.sleep(delay)
660 Traceback (most recent call last):
662 Exception: Function call timed out
665 if use_signals is None:
668 use_signals = thread_utils.is_current_thread_main_thread()
670 def decorate(function):
673 def handler(signum, frame):
674 _raise_exception(timeout_exception, error_message)
676 @functools.wraps(function)
677 def new_function(*args, **kwargs):
678 new_seconds = kwargs.pop("timeout", seconds)
680 old = signal.signal(signal.SIGALRM, handler)
681 signal.setitimer(signal.ITIMER_REAL, new_seconds)
684 return function(*args, **kwargs)
687 return function(*args, **kwargs)
690 signal.setitimer(signal.ITIMER_REAL, 0)
691 signal.signal(signal.SIGALRM, old)
696 @functools.wraps(function)
697 def new_function(*args, **kwargs):
698 timeout_wrapper = _Timeout(function, timeout_exception, error_message, seconds)
699 return timeout_wrapper(*args, **kwargs)
706 def synchronized(lock):
709 def _gatekeeper(*args, **kw):
712 return f(*args, **kw)
721 def call_with_sample_rate(sample_rate: float) -> Callable:
722 if not 0.0 <= sample_rate <= 1.0:
723 msg = f"sample_rate must be between [0, 1]. Got {sample_rate}."
725 raise ValueError(msg)
729 def _call_with_sample_rate(*args, **kwargs):
730 if random.uniform(0, 1) < sample_rate:
731 return f(*args, **kwargs)
733 logger.debug("@call_with_sample_rate skipping a call to %s", f.__name__)
736 return _call_with_sample_rate
741 def decorate_matching_methods_with(decorator, acl=None):
742 """Apply decorator to all methods in a class whose names begin with
743 prefix. If prefix is None (default), decorate all methods in the
747 def decorate_the_class(cls):
748 for name, m in inspect.getmembers(cls, inspect.isfunction):
750 setattr(cls, name, decorator(m))
753 setattr(cls, name, decorator(m))
756 return decorate_the_class
759 if __name__ == '__main__':