10 import multiprocessing
17 from typing import Any, Callable, Optional
20 # This module is commonly used by others in here and should avoid
21 # taking any unnecessary dependencies back on them.
25 logger = logging.getLogger(__name__)
28 def timed(func: Callable) -> Callable:
29 """Print the runtime of the decorated function.
36 >>> foo() # doctest: +ELLIPSIS
41 @functools.wraps(func)
42 def wrapper_timer(*args, **kwargs):
43 start_time = time.perf_counter()
44 value = func(*args, **kwargs)
45 end_time = time.perf_counter()
46 run_time = end_time - start_time
47 msg = f"Finished {func.__qualname__} in {run_time:.4f}s"
54 def invocation_logged(func: Callable) -> Callable:
55 """Log the call of a function.
57 >>> @invocation_logged
59 ... print('Hello, world.')
68 @functools.wraps(func)
69 def wrapper_invocation_logged(*args, **kwargs):
70 msg = f"Entered {func.__qualname__}"
73 ret = func(*args, **kwargs)
74 msg = f"Exited {func.__qualname__}"
78 return wrapper_invocation_logged
81 def rate_limited(n_calls: int, *, per_period_in_seconds: float = 1.0) -> Callable:
82 """Limit invocation of a wrapped function to n calls per period.
83 Thread safe. In testing this was relatively fair with multiple
84 threads using it though that hasn't been measured.
87 >>> import decorator_utils
88 >>> import thread_utils
92 >>> @decorator_utils.rate_limited(10, per_period_in_seconds=1.0)
93 ... def limited(x: int):
97 >>> @thread_utils.background_thread
99 ... for _ in range(3):
102 >>> @thread_utils.background_thread
104 ... for _ in range(3):
107 >>> start = time.time()
112 >>> end = time.time()
113 >>> dur = end - start
121 min_interval_seconds = per_period_in_seconds / float(n_calls)
123 def wrapper_rate_limited(func: Callable) -> Callable:
124 cv = threading.Condition()
125 last_invocation_timestamp = [0.0]
127 def may_proceed() -> float:
129 last_invocation = last_invocation_timestamp[0]
130 if last_invocation != 0.0:
131 elapsed_since_last = now - last_invocation
132 wait_time = min_interval_seconds - elapsed_since_last
135 logger.debug(f'@{time.time()}> wait_time = {wait_time}')
138 def wrapper_wrapper_rate_limited(*args, **kargs) -> Any:
142 lambda: may_proceed() <= 0.0,
143 timeout=may_proceed(),
147 logger.debug(f'@{time.time()}> calling it...')
148 ret = func(*args, **kargs)
149 last_invocation_timestamp[0] = time.time()
151 f'@{time.time()}> Last invocation <- {last_invocation_timestamp[0]}'
155 return wrapper_wrapper_rate_limited
156 return wrapper_rate_limited
159 def debug_args(func: Callable) -> Callable:
160 """Print the function signature and return value at each call.
163 ... def foo(a, b, c):
167 ... return (a + b, c)
169 >>> foo(1, 2.0, "test")
170 Calling foo(1:<class 'int'>, 2.0:<class 'float'>, 'test':<class 'str'>)
174 foo returned (3.0, 'test'):<class 'tuple'>
178 @functools.wraps(func)
179 def wrapper_debug_args(*args, **kwargs):
180 args_repr = [f"{repr(a)}:{type(a)}" for a in args]
181 kwargs_repr = [f"{k}={v!r}:{type(v)}" for k, v in kwargs.items()]
182 signature = ", ".join(args_repr + kwargs_repr)
183 msg = f"Calling {func.__qualname__}({signature})"
186 value = func(*args, **kwargs)
187 msg = f"{func.__qualname__} returned {value!r}:{type(value)}"
191 return wrapper_debug_args
194 def debug_count_calls(func: Callable) -> Callable:
195 """Count function invocations and print a message befor every call.
197 >>> @debug_count_calls
201 ... return x * factoral(x - 1)
204 Call #1 of 'factoral'
205 Call #2 of 'factoral'
206 Call #3 of 'factoral'
207 Call #4 of 'factoral'
208 Call #5 of 'factoral'
213 @functools.wraps(func)
214 def wrapper_debug_count_calls(*args, **kwargs):
215 wrapper_debug_count_calls.num_calls += 1
216 msg = f"Call #{wrapper_debug_count_calls.num_calls} of {func.__name__!r}"
219 return func(*args, **kwargs)
220 wrapper_debug_count_calls.num_calls = 0
221 return wrapper_debug_count_calls
224 class DelayWhen(enum.IntEnum):
231 _func: Callable = None,
233 seconds: float = 1.0,
234 when: DelayWhen = DelayWhen.BEFORE_CALL,
236 """Delay the execution of a function by sleeping before and/or after.
238 Slow down a function by inserting a delay before and/or after its
243 >>> @delay(seconds=1.0)
247 >>> start = time.time()
249 >>> dur = time.time() - start
254 def decorator_delay(func: Callable) -> Callable:
255 @functools.wraps(func)
256 def wrapper_delay(*args, **kwargs):
257 if when & DelayWhen.BEFORE_CALL:
259 f"@delay for {seconds}s BEFORE_CALL to {func.__name__}"
262 retval = func(*args, **kwargs)
263 if when & DelayWhen.AFTER_CALL:
265 f"@delay for {seconds}s AFTER_CALL to {func.__name__}"
272 return decorator_delay
274 return decorator_delay(_func)
277 class _SingletonWrapper:
279 A singleton wrapper class. Its instances would be created
280 for each decorated class.
284 def __init__(self, cls):
285 self.__wrapped__ = cls
286 self._instance = None
288 def __call__(self, *args, **kwargs):
289 """Returns a single instance of decorated class"""
291 f"@singleton returning global instance of {self.__wrapped__.__name__}"
293 if self._instance is None:
294 self._instance = self.__wrapped__(*args, **kwargs)
295 return self._instance
300 A singleton decorator. Returns a wrapper objects. A call on that object
301 returns a single instance object of decorated class. Use the __wrapped__
302 attribute to access decorated class directly in unit tests
305 ... class foo(object):
317 return _SingletonWrapper(cls)
320 def memoized(func: Callable) -> Callable:
321 """Keep a cache of previous function call results.
323 The cache here is a dict with a key based on the arguments to the
324 call. Consider also: functools.lru_cache for a more advanced
330 ... def expensive(arg) -> int:
331 ... # Simulate something slow to compute or lookup
335 >>> start = time.time()
336 >>> expensive(5) # Takes about 1 sec
339 >>> expensive(3) # Also takes about 1 sec
342 >>> expensive(5) # Pulls from cache, fast
345 >>> expensive(3) # Pulls from cache again, fast
348 >>> dur = time.time() - start
353 @functools.wraps(func)
354 def wrapper_memoized(*args, **kwargs):
355 cache_key = args + tuple(kwargs.items())
356 if cache_key not in wrapper_memoized.cache:
357 value = func(*args, **kwargs)
359 f"Memoizing {cache_key} => {value} for {func.__name__}"
361 wrapper_memoized.cache[cache_key] = value
363 logger.debug(f"Returning memoized value for {func.__name__}")
364 return wrapper_memoized.cache[cache_key]
365 wrapper_memoized.cache = dict()
366 return wrapper_memoized
372 predicate: Callable[..., bool],
373 delay_sec: float = 3.0,
374 backoff: float = 2.0,
376 """Retries a function or method up to a certain number of times
377 with a prescribed initial delay period and backoff rate.
379 tries is the maximum number of attempts to run the function.
380 delay_sec sets the initial delay period in seconds.
381 backoff is a multiplied (must be >1) used to modify the delay.
382 predicate is a function that will be passed the retval of the
383 decorated function and must return True to stop or False to
388 msg = f"backoff must be greater than or equal to 1, got {backoff}"
390 raise ValueError(msg)
392 tries = math.floor(tries)
394 msg = f"tries must be 0 or greater, got {tries}"
396 raise ValueError(msg)
399 msg = f"delay_sec must be greater than 0, got {delay_sec}"
401 raise ValueError(msg)
405 def f_retry(*args, **kwargs):
406 mtries, mdelay = tries, delay_sec # make mutable
407 logger.debug(f'deco_retry: will make up to {mtries} attempts...')
408 retval = f(*args, **kwargs)
410 if predicate(retval) is True:
411 logger.debug('Predicate succeeded, deco_retry is done.')
413 logger.debug("Predicate failed, sleeping and retrying.")
417 retval = f(*args, **kwargs)
423 def retry_if_false(tries: int, *, delay_sec=3.0, backoff=2.0):
424 """A helper for @retry_predicate that retries a decorated
425 function as long as it keeps returning False.
431 >>> @retry_if_false(5, delay_sec=1.0, backoff=1.1)
435 ... return counter >= 3
437 >>> start = time.time()
438 >>> foo() # fail, delay 1.0, fail, delay 1.1, succeed
441 >>> dur = time.time() - start
450 return retry_predicate(
452 predicate=lambda x: x is True,
458 def retry_if_none(tries: int, *, delay_sec=3.0, backoff=2.0):
459 """Another helper for @retry_predicate above. Retries up to N
460 times so long as the wrapped function returns None with a delay
461 between each retry and a backoff that can increase the delay.
464 return retry_predicate(
466 predicate=lambda x: x is not None,
472 def deprecated(func):
473 """This is a decorator which can be used to mark functions
474 as deprecated. It will result in a warning being emitted
475 when the function is used.
478 @functools.wraps(func)
479 def wrapper_deprecated(*args, **kwargs):
480 msg = f"Call to deprecated function {func.__qualname__}"
482 warnings.warn(msg, category=DeprecationWarning, stacklevel=2)
483 print(msg, file=sys.stderr)
484 return func(*args, **kwargs)
485 return wrapper_deprecated
490 Make a function immediately return a function of no args which,
491 when called, waits for the result, which will start being
492 processed in another thread.
495 @functools.wraps(func)
496 def lazy_thunked(*args, **kwargs):
497 wait_event = threading.Event()
504 func_result = func(*args, **kwargs)
505 result[0] = func_result
508 exc[1] = sys.exc_info() # (type, value, traceback)
509 msg = f"Thunkify has thrown an exception (will be raised on thunk()):\n{traceback.format_exc()}"
517 raise exc[1][0](exc[1][1])
520 threading.Thread(target=worker_func).start()
526 ############################################################
528 ############################################################
530 # http://www.saltycrane.com/blog/2010/04/using-python-timeout-decorator-uploading-s3/
532 # in https://code.google.com/p/verse-quiz/source/browse/trunk/timeout.py
535 def _raise_exception(exception, error_message: Optional[str]):
536 if error_message is None:
539 raise Exception(error_message)
542 def _target(queue, function, *args, **kwargs):
543 """Run a function with arguments and return output via a queue.
545 This is a helper function for the Process created in _Timeout. It runs
546 the function with positional arguments and keyword arguments and then
547 returns the function's output by way of a queue. If an exception gets
548 raised, it is returned to _Timeout to be raised by the value property.
551 queue.put((True, function(*args, **kwargs)))
553 queue.put((False, sys.exc_info()[1]))
556 class _Timeout(object):
557 """Wrap a function and add a timeout to it.
559 Instances of this class are automatically generated by the add_timeout
560 function defined below. Do not use directly.
566 timeout_exception: Exception,
570 self.__limit = seconds
571 self.__function = function
572 self.__timeout_exception = timeout_exception
573 self.__error_message = error_message
574 self.__name__ = function.__name__
575 self.__doc__ = function.__doc__
576 self.__timeout = time.time()
577 self.__process = multiprocessing.Process()
578 self.__queue: multiprocessing.queues.Queue = multiprocessing.Queue()
580 def __call__(self, *args, **kwargs):
581 """Execute the embedded function object asynchronously.
583 The function given to the constructor is transparently called and
584 requires that "ready" be intermittently polled. If and when it is
585 True, the "value" property may then be checked for returned data.
587 self.__limit = kwargs.pop("timeout", self.__limit)
588 self.__queue = multiprocessing.Queue(1)
589 args = (self.__queue, self.__function) + args
590 self.__process = multiprocessing.Process(
591 target=_target, args=args, kwargs=kwargs
593 self.__process.daemon = True
594 self.__process.start()
595 if self.__limit is not None:
596 self.__timeout = self.__limit + time.time()
597 while not self.ready:
602 """Terminate any possible execution of the embedded function."""
603 if self.__process.is_alive():
604 self.__process.terminate()
605 _raise_exception(self.__timeout_exception, self.__error_message)
609 """Read-only property indicating status of "value" property."""
610 if self.__limit and self.__timeout < time.time():
612 return self.__queue.full() and not self.__queue.empty()
616 """Read-only property containing data returned from function."""
617 if self.ready is True:
618 flag, load = self.__queue.get()
625 seconds: float = 1.0,
626 use_signals: Optional[bool] = None,
627 timeout_exception=exceptions.TimeoutError,
628 error_message="Function call timed out",
630 """Add a timeout parameter to a function and return the function.
632 Note: the use_signals parameter is included in order to support
633 multiprocessing scenarios (signal can only be used from the process'
634 main thread). When not using signals, timeout granularity will be
635 rounded to the nearest 0.1s.
637 Raises an exception when/if the timeout is reached.
639 It is illegal to pass anything other than a function as the first
640 parameter. The function is wrapped and returned to the caller.
643 ... def foo(delay: float):
644 ... time.sleep(delay)
651 Traceback (most recent call last):
653 Exception: Function call timed out
656 if use_signals is None:
658 use_signals = thread_utils.is_current_thread_main_thread()
660 def decorate(function):
663 def handler(signum, frame):
664 _raise_exception(timeout_exception, error_message)
666 @functools.wraps(function)
667 def new_function(*args, **kwargs):
668 new_seconds = kwargs.pop("timeout", seconds)
670 old = signal.signal(signal.SIGALRM, handler)
671 signal.setitimer(signal.ITIMER_REAL, new_seconds)
674 return function(*args, **kwargs)
677 return function(*args, **kwargs)
680 signal.setitimer(signal.ITIMER_REAL, 0)
681 signal.signal(signal.SIGALRM, old)
686 @functools.wraps(function)
687 def new_function(*args, **kwargs):
688 timeout_wrapper = _Timeout(
689 function, timeout_exception, error_message, seconds
691 return timeout_wrapper(*args, **kwargs)
698 class non_reentrant_code(object):
700 self._lock = threading.RLock
701 self._entered = False
703 def __call__(self, f):
704 def _gatekeeper(*args, **kwargs):
710 self._entered = False
714 class rlocked(object):
716 self._lock = threading.RLock
717 self._entered = False
719 def __call__(self, f):
720 def _gatekeeper(*args, **kwargs):
726 self._entered = False
730 def call_with_sample_rate(sample_rate: float) -> Callable:
731 if not 0.0 <= sample_rate <= 1.0:
732 msg = f"sample_rate must be between [0, 1]. Got {sample_rate}."
734 raise ValueError(msg)
738 def _call_with_sample_rate(*args, **kwargs):
739 if random.uniform(0, 1) < sample_rate:
740 return f(*args, **kwargs)
743 f"@call_with_sample_rate skipping a call to {f.__name__}"
745 return _call_with_sample_rate
749 def decorate_matching_methods_with(decorator, acl=None):
750 """Apply decorator to all methods in a class whose names begin with
751 prefix. If prefix is None (default), decorate all methods in the
754 def decorate_the_class(cls):
755 for name, m in inspect.getmembers(cls, inspect.isfunction):
757 setattr(cls, name, decorator(m))
760 setattr(cls, name, decorator(m))
762 return decorate_the_class
765 if __name__ == '__main__':