10 import multiprocessing
18 from typing import Any, Callable, Optional
20 # This module is commonly used by others in here and should avoid
21 # taking any unnecessary dependencies back on them.
24 logger = logging.getLogger(__name__)
27 def timed(func: Callable) -> Callable:
28 """Print the runtime of the decorated function.
35 >>> foo() # doctest: +ELLIPSIS
40 @functools.wraps(func)
41 def wrapper_timer(*args, **kwargs):
42 start_time = time.perf_counter()
43 value = func(*args, **kwargs)
44 end_time = time.perf_counter()
45 run_time = end_time - start_time
46 msg = f"Finished {func.__qualname__} in {run_time:.4f}s"
54 def invocation_logged(func: Callable) -> Callable:
55 """Log the call of a function.
57 >>> @invocation_logged
59 ... print('Hello, world.')
68 @functools.wraps(func)
69 def wrapper_invocation_logged(*args, **kwargs):
70 msg = f"Entered {func.__qualname__}"
73 ret = func(*args, **kwargs)
74 msg = f"Exited {func.__qualname__}"
79 return wrapper_invocation_logged
82 def rate_limited(n_calls: int, *, per_period_in_seconds: float = 1.0) -> Callable:
83 """Limit invocation of a wrapped function to n calls per period.
84 Thread safe. In testing this was relatively fair with multiple
85 threads using it though that hasn't been measured.
88 >>> import decorator_utils
89 >>> import thread_utils
93 >>> @decorator_utils.rate_limited(10, per_period_in_seconds=1.0)
94 ... def limited(x: int):
98 >>> @thread_utils.background_thread
100 ... for _ in range(3):
103 >>> @thread_utils.background_thread
105 ... for _ in range(3):
108 >>> start = time.time()
113 >>> end = time.time()
114 >>> dur = end - start
122 min_interval_seconds = per_period_in_seconds / float(n_calls)
124 def wrapper_rate_limited(func: Callable) -> Callable:
125 cv = threading.Condition()
126 last_invocation_timestamp = [0.0]
128 def may_proceed() -> float:
130 last_invocation = last_invocation_timestamp[0]
131 if last_invocation != 0.0:
132 elapsed_since_last = now - last_invocation
133 wait_time = min_interval_seconds - elapsed_since_last
136 logger.debug(f'@{time.time()}> wait_time = {wait_time}')
139 def wrapper_wrapper_rate_limited(*args, **kargs) -> Any:
143 lambda: may_proceed() <= 0.0,
144 timeout=may_proceed(),
148 logger.debug(f'@{time.time()}> calling it...')
149 ret = func(*args, **kargs)
150 last_invocation_timestamp[0] = time.time()
151 logger.debug(f'@{time.time()}> Last invocation <- {last_invocation_timestamp[0]}')
155 return wrapper_wrapper_rate_limited
157 return wrapper_rate_limited
160 def debug_args(func: Callable) -> Callable:
161 """Print the function signature and return value at each call.
164 ... def foo(a, b, c):
168 ... return (a + b, c)
170 >>> foo(1, 2.0, "test")
171 Calling foo(1:<class 'int'>, 2.0:<class 'float'>, 'test':<class 'str'>)
175 foo returned (3.0, 'test'):<class 'tuple'>
179 @functools.wraps(func)
180 def wrapper_debug_args(*args, **kwargs):
181 args_repr = [f"{repr(a)}:{type(a)}" for a in args]
182 kwargs_repr = [f"{k}={v!r}:{type(v)}" for k, v in kwargs.items()]
183 signature = ", ".join(args_repr + kwargs_repr)
184 msg = f"Calling {func.__qualname__}({signature})"
187 value = func(*args, **kwargs)
188 msg = f"{func.__qualname__} returned {value!r}:{type(value)}"
193 return wrapper_debug_args
196 def debug_count_calls(func: Callable) -> Callable:
197 """Count function invocations and print a message befor every call.
199 >>> @debug_count_calls
203 ... return x * factoral(x - 1)
206 Call #1 of 'factoral'
207 Call #2 of 'factoral'
208 Call #3 of 'factoral'
209 Call #4 of 'factoral'
210 Call #5 of 'factoral'
215 @functools.wraps(func)
216 def wrapper_debug_count_calls(*args, **kwargs):
217 wrapper_debug_count_calls.num_calls += 1
218 msg = f"Call #{wrapper_debug_count_calls.num_calls} of {func.__name__!r}"
221 return func(*args, **kwargs)
223 wrapper_debug_count_calls.num_calls = 0 # type: ignore
224 return wrapper_debug_count_calls
227 class DelayWhen(enum.IntEnum):
234 _func: Callable = None,
236 seconds: float = 1.0,
237 when: DelayWhen = DelayWhen.BEFORE_CALL,
239 """Delay the execution of a function by sleeping before and/or after.
241 Slow down a function by inserting a delay before and/or after its
246 >>> @delay(seconds=1.0)
250 >>> start = time.time()
252 >>> dur = time.time() - start
258 def decorator_delay(func: Callable) -> Callable:
259 @functools.wraps(func)
260 def wrapper_delay(*args, **kwargs):
261 if when & DelayWhen.BEFORE_CALL:
262 logger.debug(f"@delay for {seconds}s BEFORE_CALL to {func.__name__}")
264 retval = func(*args, **kwargs)
265 if when & DelayWhen.AFTER_CALL:
266 logger.debug(f"@delay for {seconds}s AFTER_CALL to {func.__name__}")
273 return decorator_delay
275 return decorator_delay(_func)
278 class _SingletonWrapper:
280 A singleton wrapper class. Its instances would be created
281 for each decorated class.
285 def __init__(self, cls):
286 self.__wrapped__ = cls
287 self._instance = None
289 def __call__(self, *args, **kwargs):
290 """Returns a single instance of decorated class"""
291 logger.debug(f"@singleton returning global instance of {self.__wrapped__.__name__}")
292 if self._instance is None:
293 self._instance = self.__wrapped__(*args, **kwargs)
294 return self._instance
299 A singleton decorator. Returns a wrapper objects. A call on that object
300 returns a single instance object of decorated class. Use the __wrapped__
301 attribute to access decorated class directly in unit tests
304 ... class foo(object):
316 return _SingletonWrapper(cls)
319 def memoized(func: Callable) -> Callable:
320 """Keep a cache of previous function call results.
322 The cache here is a dict with a key based on the arguments to the
323 call. Consider also: functools.lru_cache for a more advanced
329 ... def expensive(arg) -> int:
330 ... # Simulate something slow to compute or lookup
334 >>> start = time.time()
335 >>> expensive(5) # Takes about 1 sec
338 >>> expensive(3) # Also takes about 1 sec
341 >>> expensive(5) # Pulls from cache, fast
344 >>> expensive(3) # Pulls from cache again, fast
347 >>> dur = time.time() - start
353 @functools.wraps(func)
354 def wrapper_memoized(*args, **kwargs):
355 cache_key = args + tuple(kwargs.items())
356 if cache_key not in wrapper_memoized.cache:
357 value = func(*args, **kwargs)
358 logger.debug(f"Memoizing {cache_key} => {value} for {func.__name__}")
359 wrapper_memoized.cache[cache_key] = value
361 logger.debug(f"Returning memoized value for {func.__name__}")
362 return wrapper_memoized.cache[cache_key]
364 wrapper_memoized.cache = dict() # type: ignore
365 return wrapper_memoized
371 predicate: Callable[..., bool],
372 delay_sec: float = 3.0,
373 backoff: float = 2.0,
375 """Retries a function or method up to a certain number of times
376 with a prescribed initial delay period and backoff rate.
378 tries is the maximum number of attempts to run the function.
379 delay_sec sets the initial delay period in seconds.
380 backoff is a multiplied (must be >1) used to modify the delay.
381 predicate is a function that will be passed the retval of the
382 decorated function and must return True to stop or False to
387 msg = f"backoff must be greater than or equal to 1, got {backoff}"
389 raise ValueError(msg)
391 tries = math.floor(tries)
393 msg = f"tries must be 0 or greater, got {tries}"
395 raise ValueError(msg)
398 msg = f"delay_sec must be greater than 0, got {delay_sec}"
400 raise ValueError(msg)
404 def f_retry(*args, **kwargs):
405 mtries, mdelay = tries, delay_sec # make mutable
406 logger.debug(f'deco_retry: will make up to {mtries} attempts...')
407 retval = f(*args, **kwargs)
409 if predicate(retval) is True:
410 logger.debug('Predicate succeeded, deco_retry is done.')
412 logger.debug("Predicate failed, sleeping and retrying.")
416 retval = f(*args, **kwargs)
424 def retry_if_false(tries: int, *, delay_sec=3.0, backoff=2.0):
425 """A helper for @retry_predicate that retries a decorated
426 function as long as it keeps returning False.
432 >>> @retry_if_false(5, delay_sec=1.0, backoff=1.1)
436 ... return counter >= 3
438 >>> start = time.time()
439 >>> foo() # fail, delay 1.0, fail, delay 1.1, succeed
442 >>> dur = time.time() - start
451 return retry_predicate(
453 predicate=lambda x: x is True,
459 def retry_if_none(tries: int, *, delay_sec=3.0, backoff=2.0):
460 """Another helper for @retry_predicate above. Retries up to N
461 times so long as the wrapped function returns None with a delay
462 between each retry and a backoff that can increase the delay.
465 return retry_predicate(
467 predicate=lambda x: x is not None,
473 def deprecated(func):
474 """This is a decorator which can be used to mark functions
475 as deprecated. It will result in a warning being emitted
476 when the function is used.
480 @functools.wraps(func)
481 def wrapper_deprecated(*args, **kwargs):
482 msg = f"Call to deprecated function {func.__qualname__}"
484 warnings.warn(msg, category=DeprecationWarning, stacklevel=2)
485 print(msg, file=sys.stderr)
486 return func(*args, **kwargs)
488 return wrapper_deprecated
493 Make a function immediately return a function of no args which,
494 when called, waits for the result, which will start being
495 processed in another thread.
498 @functools.wraps(func)
499 def lazy_thunked(*args, **kwargs):
500 wait_event = threading.Event()
507 func_result = func(*args, **kwargs)
508 result[0] = func_result
511 exc[1] = sys.exc_info() # (type, value, traceback)
512 msg = f"Thunkify has thrown an exception (will be raised on thunk()):\n{traceback.format_exc()}"
520 raise exc[1][0](exc[1][1])
523 threading.Thread(target=worker_func).start()
529 ############################################################
531 ############################################################
533 # http://www.saltycrane.com/blog/2010/04/using-python-timeout-decorator-uploading-s3/
535 # in https://code.google.com/p/verse-quiz/source/browse/trunk/timeout.py
538 def _raise_exception(exception, error_message: Optional[str]):
539 if error_message is None:
542 raise Exception(error_message)
545 def _target(queue, function, *args, **kwargs):
546 """Run a function with arguments and return output via a queue.
548 This is a helper function for the Process created in _Timeout. It runs
549 the function with positional arguments and keyword arguments and then
550 returns the function's output by way of a queue. If an exception gets
551 raised, it is returned to _Timeout to be raised by the value property.
554 queue.put((True, function(*args, **kwargs)))
556 queue.put((False, sys.exc_info()[1]))
559 class _Timeout(object):
560 """Wrap a function and add a timeout to it.
562 Instances of this class are automatically generated by the add_timeout
563 function defined below. Do not use directly.
569 timeout_exception: Exception,
573 self.__limit = seconds
574 self.__function = function
575 self.__timeout_exception = timeout_exception
576 self.__error_message = error_message
577 self.__name__ = function.__name__
578 self.__doc__ = function.__doc__
579 self.__timeout = time.time()
580 self.__process = multiprocessing.Process()
581 self.__queue: multiprocessing.queues.Queue = multiprocessing.Queue()
583 def __call__(self, *args, **kwargs):
584 """Execute the embedded function object asynchronously.
586 The function given to the constructor is transparently called and
587 requires that "ready" be intermittently polled. If and when it is
588 True, the "value" property may then be checked for returned data.
590 self.__limit = kwargs.pop("timeout", self.__limit)
591 self.__queue = multiprocessing.Queue(1)
592 args = (self.__queue, self.__function) + args
593 self.__process = multiprocessing.Process(target=_target, args=args, kwargs=kwargs)
594 self.__process.daemon = True
595 self.__process.start()
596 if self.__limit is not None:
597 self.__timeout = self.__limit + time.time()
598 while not self.ready:
603 """Terminate any possible execution of the embedded function."""
604 if self.__process.is_alive():
605 self.__process.terminate()
606 _raise_exception(self.__timeout_exception, self.__error_message)
610 """Read-only property indicating status of "value" property."""
611 if self.__limit and self.__timeout < time.time():
613 return self.__queue.full() and not self.__queue.empty()
617 """Read-only property containing data returned from function."""
618 if self.ready is True:
619 flag, load = self.__queue.get()
626 seconds: float = 1.0,
627 use_signals: Optional[bool] = None,
628 timeout_exception=exceptions.TimeoutError,
629 error_message="Function call timed out",
631 """Add a timeout parameter to a function and return the function.
633 Note: the use_signals parameter is included in order to support
634 multiprocessing scenarios (signal can only be used from the process'
635 main thread). When not using signals, timeout granularity will be
636 rounded to the nearest 0.1s.
638 Raises an exception when/if the timeout is reached.
640 It is illegal to pass anything other than a function as the first
641 parameter. The function is wrapped and returned to the caller.
644 ... def foo(delay: float):
645 ... time.sleep(delay)
652 Traceback (most recent call last):
654 Exception: Function call timed out
657 if use_signals is None:
660 use_signals = thread_utils.is_current_thread_main_thread()
662 def decorate(function):
665 def handler(signum, frame):
666 _raise_exception(timeout_exception, error_message)
668 @functools.wraps(function)
669 def new_function(*args, **kwargs):
670 new_seconds = kwargs.pop("timeout", seconds)
672 old = signal.signal(signal.SIGALRM, handler)
673 signal.setitimer(signal.ITIMER_REAL, new_seconds)
676 return function(*args, **kwargs)
679 return function(*args, **kwargs)
682 signal.setitimer(signal.ITIMER_REAL, 0)
683 signal.signal(signal.SIGALRM, old)
688 @functools.wraps(function)
689 def new_function(*args, **kwargs):
690 timeout_wrapper = _Timeout(function, timeout_exception, error_message, seconds)
691 return timeout_wrapper(*args, **kwargs)
698 def synchronized(lock):
701 def _gatekeeper(*args, **kw):
704 return f(*args, **kw)
713 def call_with_sample_rate(sample_rate: float) -> Callable:
714 if not 0.0 <= sample_rate <= 1.0:
715 msg = f"sample_rate must be between [0, 1]. Got {sample_rate}."
717 raise ValueError(msg)
721 def _call_with_sample_rate(*args, **kwargs):
722 if random.uniform(0, 1) < sample_rate:
723 return f(*args, **kwargs)
725 logger.debug(f"@call_with_sample_rate skipping a call to {f.__name__}")
727 return _call_with_sample_rate
732 def decorate_matching_methods_with(decorator, acl=None):
733 """Apply decorator to all methods in a class whose names begin with
734 prefix. If prefix is None (default), decorate all methods in the
738 def decorate_the_class(cls):
739 for name, m in inspect.getmembers(cls, inspect.isfunction):
741 setattr(cls, name, decorator(m))
744 setattr(cls, name, decorator(m))
747 return decorate_the_class
750 if __name__ == '__main__':