Scale back warnings.warn and add stacklevels= where appropriate.
[python_utils.git] / decorator_utils.py
1 #!/usr/bin/env python3
2
3 """Decorators."""
4
5 import enum
6 import functools
7 import inspect
8 import logging
9 import math
10 import multiprocessing
11 import random
12 import signal
13 import sys
14 import threading
15 import time
16 import traceback
17 from typing import Any, Callable, Optional
18 import warnings
19
20 # This module is commonly used by others in here and should avoid
21 # taking any unnecessary dependencies back on them.
22 import exceptions
23
24
25 logger = logging.getLogger(__name__)
26
27
28 def timed(func: Callable) -> Callable:
29     """Print the runtime of the decorated function.
30
31     >>> @timed
32     ... def foo():
33     ...     import time
34     ...     time.sleep(0.1)
35
36     >>> foo()  # doctest: +ELLIPSIS
37     Finished foo in ...
38
39     """
40
41     @functools.wraps(func)
42     def wrapper_timer(*args, **kwargs):
43         start_time = time.perf_counter()
44         value = func(*args, **kwargs)
45         end_time = time.perf_counter()
46         run_time = end_time - start_time
47         msg = f"Finished {func.__qualname__} in {run_time:.4f}s"
48         print(msg)
49         logger.info(msg)
50         return value
51     return wrapper_timer
52
53
54 def invocation_logged(func: Callable) -> Callable:
55     """Log the call of a function.
56
57     >>> @invocation_logged
58     ... def foo():
59     ...     print('Hello, world.')
60
61     >>> foo()
62     Entered foo
63     Hello, world.
64     Exited foo
65
66     """
67
68     @functools.wraps(func)
69     def wrapper_invocation_logged(*args, **kwargs):
70         msg = f"Entered {func.__qualname__}"
71         print(msg)
72         logger.info(msg)
73         ret = func(*args, **kwargs)
74         msg = f"Exited {func.__qualname__}"
75         print(msg)
76         logger.info(msg)
77         return ret
78     return wrapper_invocation_logged
79
80
81 def rate_limited(n_calls: int, *, per_period_in_seconds: float = 1.0) -> Callable:
82     """Limit invocation of a wrapped function to n calls per period.
83     Thread safe.  In testing this was relatively fair with multiple
84     threads using it though that hasn't been measured.
85
86     >>> import time
87     >>> import decorator_utils
88     >>> import thread_utils
89
90     >>> calls = 0
91
92     >>> @decorator_utils.rate_limited(10, per_period_in_seconds=1.0)
93     ... def limited(x: int):
94     ...     global calls
95     ...     calls += 1
96
97     >>> @thread_utils.background_thread
98     ... def a(stop):
99     ...     for _ in range(3):
100     ...         limited(_)
101
102     >>> @thread_utils.background_thread
103     ... def b(stop):
104     ...     for _ in range(3):
105     ...         limited(_)
106
107     >>> start = time.time()
108     >>> (t1, e1) = a()
109     >>> (t2, e2) = b()
110     >>> t1.join()
111     >>> t2.join()
112     >>> end = time.time()
113     >>> dur = end - start
114     >>> dur > 0.5
115     True
116
117     >>> calls
118     6
119
120     """
121     min_interval_seconds = per_period_in_seconds / float(n_calls)
122
123     def wrapper_rate_limited(func: Callable) -> Callable:
124         cv = threading.Condition()
125         last_invocation_timestamp = [0.0]
126
127         def may_proceed() -> float:
128             now = time.time()
129             last_invocation = last_invocation_timestamp[0]
130             if last_invocation != 0.0:
131                 elapsed_since_last = now - last_invocation
132                 wait_time = min_interval_seconds - elapsed_since_last
133             else:
134                 wait_time = 0.0
135             logger.debug(f'@{time.time()}> wait_time = {wait_time}')
136             return wait_time
137
138         def wrapper_wrapper_rate_limited(*args, **kargs) -> Any:
139             with cv:
140                 while True:
141                     if cv.wait_for(
142                         lambda: may_proceed() <= 0.0,
143                         timeout=may_proceed(),
144                     ):
145                         break
146             with cv:
147                 logger.debug(f'@{time.time()}> calling it...')
148                 ret = func(*args, **kargs)
149                 last_invocation_timestamp[0] = time.time()
150                 logger.debug(
151                     f'@{time.time()}> Last invocation <- {last_invocation_timestamp[0]}'
152                 )
153                 cv.notify()
154             return ret
155         return wrapper_wrapper_rate_limited
156     return wrapper_rate_limited
157
158
159 def debug_args(func: Callable) -> Callable:
160     """Print the function signature and return value at each call.
161
162     >>> @debug_args
163     ... def foo(a, b, c):
164     ...     print(a)
165     ...     print(b)
166     ...     print(c)
167     ...     return (a + b, c)
168
169     >>> foo(1, 2.0, "test")
170     Calling foo(1:<class 'int'>, 2.0:<class 'float'>, 'test':<class 'str'>)
171     1
172     2.0
173     test
174     foo returned (3.0, 'test'):<class 'tuple'>
175     (3.0, 'test')
176     """
177
178     @functools.wraps(func)
179     def wrapper_debug_args(*args, **kwargs):
180         args_repr = [f"{repr(a)}:{type(a)}" for a in args]
181         kwargs_repr = [f"{k}={v!r}:{type(v)}" for k, v in kwargs.items()]
182         signature = ", ".join(args_repr + kwargs_repr)
183         msg = f"Calling {func.__qualname__}({signature})"
184         print(msg)
185         logger.info(msg)
186         value = func(*args, **kwargs)
187         msg = f"{func.__qualname__} returned {value!r}:{type(value)}"
188         print(msg)
189         logger.info(msg)
190         return value
191     return wrapper_debug_args
192
193
194 def debug_count_calls(func: Callable) -> Callable:
195     """Count function invocations and print a message befor every call.
196
197     >>> @debug_count_calls
198     ... def factoral(x):
199     ...     if x == 1:
200     ...         return 1
201     ...     return x * factoral(x - 1)
202
203     >>> factoral(5)
204     Call #1 of 'factoral'
205     Call #2 of 'factoral'
206     Call #3 of 'factoral'
207     Call #4 of 'factoral'
208     Call #5 of 'factoral'
209     120
210
211     """
212
213     @functools.wraps(func)
214     def wrapper_debug_count_calls(*args, **kwargs):
215         wrapper_debug_count_calls.num_calls += 1
216         msg = f"Call #{wrapper_debug_count_calls.num_calls} of {func.__name__!r}"
217         print(msg)
218         logger.info(msg)
219         return func(*args, **kwargs)
220     wrapper_debug_count_calls.num_calls = 0
221     return wrapper_debug_count_calls
222
223
224 class DelayWhen(enum.IntEnum):
225     BEFORE_CALL = 1
226     AFTER_CALL = 2
227     BEFORE_AND_AFTER = 3
228
229
230 def delay(
231     _func: Callable = None,
232     *,
233     seconds: float = 1.0,
234     when: DelayWhen = DelayWhen.BEFORE_CALL,
235 ) -> Callable:
236     """Delay the execution of a function by sleeping before and/or after.
237
238     Slow down a function by inserting a delay before and/or after its
239     invocation.
240
241     >>> import time
242
243     >>> @delay(seconds=1.0)
244     ... def foo():
245     ...     pass
246
247     >>> start = time.time()
248     >>> foo()
249     >>> dur = time.time() - start
250     >>> dur >= 1.0
251     True
252
253     """
254     def decorator_delay(func: Callable) -> Callable:
255         @functools.wraps(func)
256         def wrapper_delay(*args, **kwargs):
257             if when & DelayWhen.BEFORE_CALL:
258                 logger.debug(
259                     f"@delay for {seconds}s BEFORE_CALL to {func.__name__}"
260                 )
261                 time.sleep(seconds)
262             retval = func(*args, **kwargs)
263             if when & DelayWhen.AFTER_CALL:
264                 logger.debug(
265                     f"@delay for {seconds}s AFTER_CALL to {func.__name__}"
266                 )
267                 time.sleep(seconds)
268             return retval
269         return wrapper_delay
270
271     if _func is None:
272         return decorator_delay
273     else:
274         return decorator_delay(_func)
275
276
277 class _SingletonWrapper:
278     """
279     A singleton wrapper class. Its instances would be created
280     for each decorated class.
281
282     """
283
284     def __init__(self, cls):
285         self.__wrapped__ = cls
286         self._instance = None
287
288     def __call__(self, *args, **kwargs):
289         """Returns a single instance of decorated class"""
290         logger.debug(
291             f"@singleton returning global instance of {self.__wrapped__.__name__}"
292         )
293         if self._instance is None:
294             self._instance = self.__wrapped__(*args, **kwargs)
295         return self._instance
296
297
298 def singleton(cls):
299     """
300     A singleton decorator. Returns a wrapper objects. A call on that object
301     returns a single instance object of decorated class. Use the __wrapped__
302     attribute to access decorated class directly in unit tests
303
304     >>> @singleton
305     ... class foo(object):
306     ...     pass
307
308     >>> a = foo()
309     >>> b = foo()
310     >>> a is b
311     True
312
313     >>> id(a) == id(b)
314     True
315
316     """
317     return _SingletonWrapper(cls)
318
319
320 def memoized(func: Callable) -> Callable:
321     """Keep a cache of previous function call results.
322
323     The cache here is a dict with a key based on the arguments to the
324     call.  Consider also: functools.lru_cache for a more advanced
325     implementation.
326
327     >>> import time
328
329     >>> @memoized
330     ... def expensive(arg) -> int:
331     ...     # Simulate something slow to compute or lookup
332     ...     time.sleep(1.0)
333     ...     return arg * arg
334
335     >>> start = time.time()
336     >>> expensive(5)           # Takes about 1 sec
337     25
338
339     >>> expensive(3)           # Also takes about 1 sec
340     9
341
342     >>> expensive(5)           # Pulls from cache, fast
343     25
344
345     >>> expensive(3)           # Pulls from cache again, fast
346     9
347
348     >>> dur = time.time() - start
349     >>> dur < 3.0
350     True
351
352     """
353     @functools.wraps(func)
354     def wrapper_memoized(*args, **kwargs):
355         cache_key = args + tuple(kwargs.items())
356         if cache_key not in wrapper_memoized.cache:
357             value = func(*args, **kwargs)
358             logger.debug(
359                 f"Memoizing {cache_key} => {value} for {func.__name__}"
360             )
361             wrapper_memoized.cache[cache_key] = value
362         else:
363             logger.debug(f"Returning memoized value for {func.__name__}")
364         return wrapper_memoized.cache[cache_key]
365     wrapper_memoized.cache = dict()
366     return wrapper_memoized
367
368
369 def retry_predicate(
370     tries: int,
371     *,
372     predicate: Callable[..., bool],
373     delay_sec: float = 3.0,
374     backoff: float = 2.0,
375 ):
376     """Retries a function or method up to a certain number of times
377     with a prescribed initial delay period and backoff rate.
378
379     tries is the maximum number of attempts to run the function.
380     delay_sec sets the initial delay period in seconds.
381     backoff is a multiplied (must be >1) used to modify the delay.
382     predicate is a function that will be passed the retval of the
383     decorated function and must return True to stop or False to
384     retry.
385
386     """
387     if backoff < 1.0:
388         msg = f"backoff must be greater than or equal to 1, got {backoff}"
389         logger.critical(msg)
390         raise ValueError(msg)
391
392     tries = math.floor(tries)
393     if tries < 0:
394         msg = f"tries must be 0 or greater, got {tries}"
395         logger.critical(msg)
396         raise ValueError(msg)
397
398     if delay_sec <= 0:
399         msg = f"delay_sec must be greater than 0, got {delay_sec}"
400         logger.critical(msg)
401         raise ValueError(msg)
402
403     def deco_retry(f):
404         @functools.wraps(f)
405         def f_retry(*args, **kwargs):
406             mtries, mdelay = tries, delay_sec  # make mutable
407             logger.debug(f'deco_retry: will make up to {mtries} attempts...')
408             retval = f(*args, **kwargs)
409             while mtries > 0:
410                 if predicate(retval) is True:
411                     logger.debug('Predicate succeeded, deco_retry is done.')
412                     return retval
413                 logger.debug("Predicate failed, sleeping and retrying.")
414                 mtries -= 1
415                 time.sleep(mdelay)
416                 mdelay *= backoff
417                 retval = f(*args, **kwargs)
418             return retval
419         return f_retry
420     return deco_retry
421
422
423 def retry_if_false(tries: int, *, delay_sec=3.0, backoff=2.0):
424     """A helper for @retry_predicate that retries a decorated
425     function as long as it keeps returning False.
426
427     >>> import time
428
429     >>> counter = 0
430
431     >>> @retry_if_false(5, delay_sec=1.0, backoff=1.1)
432     ... def foo():
433     ...     global counter
434     ...     counter += 1
435     ...     return counter >= 3
436
437     >>> start = time.time()
438     >>> foo()  # fail, delay 1.0, fail, delay 1.1, succeed
439     True
440
441     >>> dur = time.time() - start
442     >>> counter
443     3
444     >>> dur > 2.0
445     True
446     >>> dur < 2.2
447     True
448
449     """
450     return retry_predicate(
451         tries,
452         predicate=lambda x: x is True,
453         delay_sec=delay_sec,
454         backoff=backoff,
455     )
456
457
458 def retry_if_none(tries: int, *, delay_sec=3.0, backoff=2.0):
459     """Another helper for @retry_predicate above.  Retries up to N
460     times so long as the wrapped function returns None with a delay
461     between each retry and a backoff that can increase the delay.
462
463     """
464     return retry_predicate(
465         tries,
466         predicate=lambda x: x is not None,
467         delay_sec=delay_sec,
468         backoff=backoff,
469     )
470
471
472 def deprecated(func):
473     """This is a decorator which can be used to mark functions
474     as deprecated. It will result in a warning being emitted
475     when the function is used.
476
477     """
478     @functools.wraps(func)
479     def wrapper_deprecated(*args, **kwargs):
480         msg = f"Call to deprecated function {func.__qualname__}"
481         logger.warning(msg)
482         warnings.warn(msg, category=DeprecationWarning, stacklevel=2)
483         print(msg, file=sys.stderr)
484         return func(*args, **kwargs)
485     return wrapper_deprecated
486
487
488 def thunkify(func):
489     """
490     Make a function immediately return a function of no args which,
491     when called, waits for the result, which will start being
492     processed in another thread.
493     """
494
495     @functools.wraps(func)
496     def lazy_thunked(*args, **kwargs):
497         wait_event = threading.Event()
498
499         result = [None]
500         exc = [False, None]
501
502         def worker_func():
503             try:
504                 func_result = func(*args, **kwargs)
505                 result[0] = func_result
506             except Exception:
507                 exc[0] = True
508                 exc[1] = sys.exc_info()  # (type, value, traceback)
509                 msg = f"Thunkify has thrown an exception (will be raised on thunk()):\n{traceback.format_exc()}"
510                 logger.warning(msg)
511             finally:
512                 wait_event.set()
513
514         def thunk():
515             wait_event.wait()
516             if exc[0]:
517                 raise exc[1][0](exc[1][1])
518             return result[0]
519
520         threading.Thread(target=worker_func).start()
521         return thunk
522
523     return lazy_thunked
524
525
526 ############################################################
527 # Timeout
528 ############################################################
529
530 # http://www.saltycrane.com/blog/2010/04/using-python-timeout-decorator-uploading-s3/
531 # Used work of Stephen "Zero" Chappell <[email protected]>
532 # in https://code.google.com/p/verse-quiz/source/browse/trunk/timeout.py
533
534
535 def _raise_exception(exception, error_message: Optional[str]):
536     if error_message is None:
537         raise Exception()
538     else:
539         raise Exception(error_message)
540
541
542 def _target(queue, function, *args, **kwargs):
543     """Run a function with arguments and return output via a queue.
544
545     This is a helper function for the Process created in _Timeout. It runs
546     the function with positional arguments and keyword arguments and then
547     returns the function's output by way of a queue. If an exception gets
548     raised, it is returned to _Timeout to be raised by the value property.
549     """
550     try:
551         queue.put((True, function(*args, **kwargs)))
552     except Exception:
553         queue.put((False, sys.exc_info()[1]))
554
555
556 class _Timeout(object):
557     """Wrap a function and add a timeout to it.
558
559     Instances of this class are automatically generated by the add_timeout
560     function defined below.  Do not use directly.
561     """
562
563     def __init__(
564         self,
565         function: Callable,
566         timeout_exception: Exception,
567         error_message: str,
568         seconds: float,
569     ):
570         self.__limit = seconds
571         self.__function = function
572         self.__timeout_exception = timeout_exception
573         self.__error_message = error_message
574         self.__name__ = function.__name__
575         self.__doc__ = function.__doc__
576         self.__timeout = time.time()
577         self.__process = multiprocessing.Process()
578         self.__queue: multiprocessing.queues.Queue = multiprocessing.Queue()
579
580     def __call__(self, *args, **kwargs):
581         """Execute the embedded function object asynchronously.
582
583         The function given to the constructor is transparently called and
584         requires that "ready" be intermittently polled. If and when it is
585         True, the "value" property may then be checked for returned data.
586         """
587         self.__limit = kwargs.pop("timeout", self.__limit)
588         self.__queue = multiprocessing.Queue(1)
589         args = (self.__queue, self.__function) + args
590         self.__process = multiprocessing.Process(
591             target=_target, args=args, kwargs=kwargs
592         )
593         self.__process.daemon = True
594         self.__process.start()
595         if self.__limit is not None:
596             self.__timeout = self.__limit + time.time()
597         while not self.ready:
598             time.sleep(0.1)
599         return self.value
600
601     def cancel(self):
602         """Terminate any possible execution of the embedded function."""
603         if self.__process.is_alive():
604             self.__process.terminate()
605         _raise_exception(self.__timeout_exception, self.__error_message)
606
607     @property
608     def ready(self):
609         """Read-only property indicating status of "value" property."""
610         if self.__limit and self.__timeout < time.time():
611             self.cancel()
612         return self.__queue.full() and not self.__queue.empty()
613
614     @property
615     def value(self):
616         """Read-only property containing data returned from function."""
617         if self.ready is True:
618             flag, load = self.__queue.get()
619             if flag:
620                 return load
621             raise load
622
623
624 def timeout(
625     seconds: float = 1.0,
626     use_signals: Optional[bool] = None,
627     timeout_exception=exceptions.TimeoutError,
628     error_message="Function call timed out",
629 ):
630     """Add a timeout parameter to a function and return the function.
631
632     Note: the use_signals parameter is included in order to support
633     multiprocessing scenarios (signal can only be used from the process'
634     main thread).  When not using signals, timeout granularity will be
635     rounded to the nearest 0.1s.
636
637     Raises an exception when/if the timeout is reached.
638
639     It is illegal to pass anything other than a function as the first
640     parameter.  The function is wrapped and returned to the caller.
641
642     >>> @timeout(0.2)
643     ... def foo(delay: float):
644     ...     time.sleep(delay)
645     ...     return "ok"
646
647     >>> foo(0)
648     'ok'
649
650     >>> foo(1.0)
651     Traceback (most recent call last):
652     ...
653     Exception: Function call timed out
654
655     """
656     if use_signals is None:
657         import thread_utils
658         use_signals = thread_utils.is_current_thread_main_thread()
659
660     def decorate(function):
661         if use_signals:
662
663             def handler(signum, frame):
664                 _raise_exception(timeout_exception, error_message)
665
666             @functools.wraps(function)
667             def new_function(*args, **kwargs):
668                 new_seconds = kwargs.pop("timeout", seconds)
669                 if new_seconds:
670                     old = signal.signal(signal.SIGALRM, handler)
671                     signal.setitimer(signal.ITIMER_REAL, new_seconds)
672
673                 if not seconds:
674                     return function(*args, **kwargs)
675
676                 try:
677                     return function(*args, **kwargs)
678                 finally:
679                     if new_seconds:
680                         signal.setitimer(signal.ITIMER_REAL, 0)
681                         signal.signal(signal.SIGALRM, old)
682
683             return new_function
684         else:
685
686             @functools.wraps(function)
687             def new_function(*args, **kwargs):
688                 timeout_wrapper = _Timeout(
689                     function, timeout_exception, error_message, seconds
690                 )
691                 return timeout_wrapper(*args, **kwargs)
692
693             return new_function
694
695     return decorate
696
697
698 class non_reentrant_code(object):
699     def __init__(self):
700         self._lock = threading.RLock
701         self._entered = False
702
703     def __call__(self, f):
704         def _gatekeeper(*args, **kwargs):
705             with self._lock:
706                 if self._entered:
707                     return
708                 self._entered = True
709                 f(*args, **kwargs)
710                 self._entered = False
711         return _gatekeeper
712
713
714 class rlocked(object):
715     def __init__(self):
716         self._lock = threading.RLock
717         self._entered = False
718
719     def __call__(self, f):
720         def _gatekeeper(*args, **kwargs):
721             with self._lock:
722                 if self._entered:
723                     return
724                 self._entered = True
725                 f(*args, **kwargs)
726                 self._entered = False
727         return _gatekeeper
728
729
730 def call_with_sample_rate(sample_rate: float) -> Callable:
731     if not 0.0 <= sample_rate <= 1.0:
732         msg = f"sample_rate must be between [0, 1]. Got {sample_rate}."
733         logger.critical(msg)
734         raise ValueError(msg)
735
736     def decorator(f):
737         @functools.wraps(f)
738         def _call_with_sample_rate(*args, **kwargs):
739             if random.uniform(0, 1) < sample_rate:
740                 return f(*args, **kwargs)
741             else:
742                 logger.debug(
743                     f"@call_with_sample_rate skipping a call to {f.__name__}"
744                 )
745         return _call_with_sample_rate
746     return decorator
747
748
749 def decorate_matching_methods_with(decorator, acl=None):
750     """Apply decorator to all methods in a class whose names begin with
751     prefix.  If prefix is None (default), decorate all methods in the
752     class.
753     """
754     def decorate_the_class(cls):
755         for name, m in inspect.getmembers(cls, inspect.isfunction):
756             if acl is None:
757                 setattr(cls, name, decorator(m))
758             else:
759                 if acl(name):
760                     setattr(cls, name, decorator(m))
761         return cls
762     return decorate_the_class
763
764
765 if __name__ == '__main__':
766     import doctest
767     doctest.testmod()
768