Used isort to sort imports. Also added to the git pre-commit hook.
[python_utils.git] / decorator_utils.py
1 #!/usr/bin/env python3
2
3 """Decorators."""
4
5 import enum
6 import functools
7 import inspect
8 import logging
9 import math
10 import multiprocessing
11 import random
12 import signal
13 import sys
14 import threading
15 import time
16 import traceback
17 import warnings
18 from typing import Any, Callable, Optional
19
20 # This module is commonly used by others in here and should avoid
21 # taking any unnecessary dependencies back on them.
22 import exceptions
23
24 logger = logging.getLogger(__name__)
25
26
27 def timed(func: Callable) -> Callable:
28     """Print the runtime of the decorated function.
29
30     >>> @timed
31     ... def foo():
32     ...     import time
33     ...     time.sleep(0.1)
34
35     >>> foo()  # doctest: +ELLIPSIS
36     Finished foo in ...
37
38     """
39
40     @functools.wraps(func)
41     def wrapper_timer(*args, **kwargs):
42         start_time = time.perf_counter()
43         value = func(*args, **kwargs)
44         end_time = time.perf_counter()
45         run_time = end_time - start_time
46         msg = f"Finished {func.__qualname__} in {run_time:.4f}s"
47         print(msg)
48         logger.info(msg)
49         return value
50
51     return wrapper_timer
52
53
54 def invocation_logged(func: Callable) -> Callable:
55     """Log the call of a function.
56
57     >>> @invocation_logged
58     ... def foo():
59     ...     print('Hello, world.')
60
61     >>> foo()
62     Entered foo
63     Hello, world.
64     Exited foo
65
66     """
67
68     @functools.wraps(func)
69     def wrapper_invocation_logged(*args, **kwargs):
70         msg = f"Entered {func.__qualname__}"
71         print(msg)
72         logger.info(msg)
73         ret = func(*args, **kwargs)
74         msg = f"Exited {func.__qualname__}"
75         print(msg)
76         logger.info(msg)
77         return ret
78
79     return wrapper_invocation_logged
80
81
82 def rate_limited(n_calls: int, *, per_period_in_seconds: float = 1.0) -> Callable:
83     """Limit invocation of a wrapped function to n calls per period.
84     Thread safe.  In testing this was relatively fair with multiple
85     threads using it though that hasn't been measured.
86
87     >>> import time
88     >>> import decorator_utils
89     >>> import thread_utils
90
91     >>> calls = 0
92
93     >>> @decorator_utils.rate_limited(10, per_period_in_seconds=1.0)
94     ... def limited(x: int):
95     ...     global calls
96     ...     calls += 1
97
98     >>> @thread_utils.background_thread
99     ... def a(stop):
100     ...     for _ in range(3):
101     ...         limited(_)
102
103     >>> @thread_utils.background_thread
104     ... def b(stop):
105     ...     for _ in range(3):
106     ...         limited(_)
107
108     >>> start = time.time()
109     >>> (t1, e1) = a()
110     >>> (t2, e2) = b()
111     >>> t1.join()
112     >>> t2.join()
113     >>> end = time.time()
114     >>> dur = end - start
115     >>> dur > 0.5
116     True
117
118     >>> calls
119     6
120
121     """
122     min_interval_seconds = per_period_in_seconds / float(n_calls)
123
124     def wrapper_rate_limited(func: Callable) -> Callable:
125         cv = threading.Condition()
126         last_invocation_timestamp = [0.0]
127
128         def may_proceed() -> float:
129             now = time.time()
130             last_invocation = last_invocation_timestamp[0]
131             if last_invocation != 0.0:
132                 elapsed_since_last = now - last_invocation
133                 wait_time = min_interval_seconds - elapsed_since_last
134             else:
135                 wait_time = 0.0
136             logger.debug(f'@{time.time()}> wait_time = {wait_time}')
137             return wait_time
138
139         def wrapper_wrapper_rate_limited(*args, **kargs) -> Any:
140             with cv:
141                 while True:
142                     if cv.wait_for(
143                         lambda: may_proceed() <= 0.0,
144                         timeout=may_proceed(),
145                     ):
146                         break
147             with cv:
148                 logger.debug(f'@{time.time()}> calling it...')
149                 ret = func(*args, **kargs)
150                 last_invocation_timestamp[0] = time.time()
151                 logger.debug(
152                     f'@{time.time()}> Last invocation <- {last_invocation_timestamp[0]}'
153                 )
154                 cv.notify()
155             return ret
156
157         return wrapper_wrapper_rate_limited
158
159     return wrapper_rate_limited
160
161
162 def debug_args(func: Callable) -> Callable:
163     """Print the function signature and return value at each call.
164
165     >>> @debug_args
166     ... def foo(a, b, c):
167     ...     print(a)
168     ...     print(b)
169     ...     print(c)
170     ...     return (a + b, c)
171
172     >>> foo(1, 2.0, "test")
173     Calling foo(1:<class 'int'>, 2.0:<class 'float'>, 'test':<class 'str'>)
174     1
175     2.0
176     test
177     foo returned (3.0, 'test'):<class 'tuple'>
178     (3.0, 'test')
179     """
180
181     @functools.wraps(func)
182     def wrapper_debug_args(*args, **kwargs):
183         args_repr = [f"{repr(a)}:{type(a)}" for a in args]
184         kwargs_repr = [f"{k}={v!r}:{type(v)}" for k, v in kwargs.items()]
185         signature = ", ".join(args_repr + kwargs_repr)
186         msg = f"Calling {func.__qualname__}({signature})"
187         print(msg)
188         logger.info(msg)
189         value = func(*args, **kwargs)
190         msg = f"{func.__qualname__} returned {value!r}:{type(value)}"
191         print(msg)
192         logger.info(msg)
193         return value
194
195     return wrapper_debug_args
196
197
198 def debug_count_calls(func: Callable) -> Callable:
199     """Count function invocations and print a message befor every call.
200
201     >>> @debug_count_calls
202     ... def factoral(x):
203     ...     if x == 1:
204     ...         return 1
205     ...     return x * factoral(x - 1)
206
207     >>> factoral(5)
208     Call #1 of 'factoral'
209     Call #2 of 'factoral'
210     Call #3 of 'factoral'
211     Call #4 of 'factoral'
212     Call #5 of 'factoral'
213     120
214
215     """
216
217     @functools.wraps(func)
218     def wrapper_debug_count_calls(*args, **kwargs):
219         wrapper_debug_count_calls.num_calls += 1
220         msg = f"Call #{wrapper_debug_count_calls.num_calls} of {func.__name__!r}"
221         print(msg)
222         logger.info(msg)
223         return func(*args, **kwargs)
224
225     wrapper_debug_count_calls.num_calls = 0  # type: ignore
226     return wrapper_debug_count_calls
227
228
229 class DelayWhen(enum.IntEnum):
230     BEFORE_CALL = 1
231     AFTER_CALL = 2
232     BEFORE_AND_AFTER = 3
233
234
235 def delay(
236     _func: Callable = None,
237     *,
238     seconds: float = 1.0,
239     when: DelayWhen = DelayWhen.BEFORE_CALL,
240 ) -> Callable:
241     """Delay the execution of a function by sleeping before and/or after.
242
243     Slow down a function by inserting a delay before and/or after its
244     invocation.
245
246     >>> import time
247
248     >>> @delay(seconds=1.0)
249     ... def foo():
250     ...     pass
251
252     >>> start = time.time()
253     >>> foo()
254     >>> dur = time.time() - start
255     >>> dur >= 1.0
256     True
257
258     """
259
260     def decorator_delay(func: Callable) -> Callable:
261         @functools.wraps(func)
262         def wrapper_delay(*args, **kwargs):
263             if when & DelayWhen.BEFORE_CALL:
264                 logger.debug(f"@delay for {seconds}s BEFORE_CALL to {func.__name__}")
265                 time.sleep(seconds)
266             retval = func(*args, **kwargs)
267             if when & DelayWhen.AFTER_CALL:
268                 logger.debug(f"@delay for {seconds}s AFTER_CALL to {func.__name__}")
269                 time.sleep(seconds)
270             return retval
271
272         return wrapper_delay
273
274     if _func is None:
275         return decorator_delay
276     else:
277         return decorator_delay(_func)
278
279
280 class _SingletonWrapper:
281     """
282     A singleton wrapper class. Its instances would be created
283     for each decorated class.
284
285     """
286
287     def __init__(self, cls):
288         self.__wrapped__ = cls
289         self._instance = None
290
291     def __call__(self, *args, **kwargs):
292         """Returns a single instance of decorated class"""
293         logger.debug(
294             f"@singleton returning global instance of {self.__wrapped__.__name__}"
295         )
296         if self._instance is None:
297             self._instance = self.__wrapped__(*args, **kwargs)
298         return self._instance
299
300
301 def singleton(cls):
302     """
303     A singleton decorator. Returns a wrapper objects. A call on that object
304     returns a single instance object of decorated class. Use the __wrapped__
305     attribute to access decorated class directly in unit tests
306
307     >>> @singleton
308     ... class foo(object):
309     ...     pass
310
311     >>> a = foo()
312     >>> b = foo()
313     >>> a is b
314     True
315
316     >>> id(a) == id(b)
317     True
318
319     """
320     return _SingletonWrapper(cls)
321
322
323 def memoized(func: Callable) -> Callable:
324     """Keep a cache of previous function call results.
325
326     The cache here is a dict with a key based on the arguments to the
327     call.  Consider also: functools.lru_cache for a more advanced
328     implementation.
329
330     >>> import time
331
332     >>> @memoized
333     ... def expensive(arg) -> int:
334     ...     # Simulate something slow to compute or lookup
335     ...     time.sleep(1.0)
336     ...     return arg * arg
337
338     >>> start = time.time()
339     >>> expensive(5)           # Takes about 1 sec
340     25
341
342     >>> expensive(3)           # Also takes about 1 sec
343     9
344
345     >>> expensive(5)           # Pulls from cache, fast
346     25
347
348     >>> expensive(3)           # Pulls from cache again, fast
349     9
350
351     >>> dur = time.time() - start
352     >>> dur < 3.0
353     True
354
355     """
356
357     @functools.wraps(func)
358     def wrapper_memoized(*args, **kwargs):
359         cache_key = args + tuple(kwargs.items())
360         if cache_key not in wrapper_memoized.cache:
361             value = func(*args, **kwargs)
362             logger.debug(f"Memoizing {cache_key} => {value} for {func.__name__}")
363             wrapper_memoized.cache[cache_key] = value
364         else:
365             logger.debug(f"Returning memoized value for {func.__name__}")
366         return wrapper_memoized.cache[cache_key]
367
368     wrapper_memoized.cache = dict()  # type: ignore
369     return wrapper_memoized
370
371
372 def retry_predicate(
373     tries: int,
374     *,
375     predicate: Callable[..., bool],
376     delay_sec: float = 3.0,
377     backoff: float = 2.0,
378 ):
379     """Retries a function or method up to a certain number of times
380     with a prescribed initial delay period and backoff rate.
381
382     tries is the maximum number of attempts to run the function.
383     delay_sec sets the initial delay period in seconds.
384     backoff is a multiplied (must be >1) used to modify the delay.
385     predicate is a function that will be passed the retval of the
386     decorated function and must return True to stop or False to
387     retry.
388
389     """
390     if backoff < 1.0:
391         msg = f"backoff must be greater than or equal to 1, got {backoff}"
392         logger.critical(msg)
393         raise ValueError(msg)
394
395     tries = math.floor(tries)
396     if tries < 0:
397         msg = f"tries must be 0 or greater, got {tries}"
398         logger.critical(msg)
399         raise ValueError(msg)
400
401     if delay_sec <= 0:
402         msg = f"delay_sec must be greater than 0, got {delay_sec}"
403         logger.critical(msg)
404         raise ValueError(msg)
405
406     def deco_retry(f):
407         @functools.wraps(f)
408         def f_retry(*args, **kwargs):
409             mtries, mdelay = tries, delay_sec  # make mutable
410             logger.debug(f'deco_retry: will make up to {mtries} attempts...')
411             retval = f(*args, **kwargs)
412             while mtries > 0:
413                 if predicate(retval) is True:
414                     logger.debug('Predicate succeeded, deco_retry is done.')
415                     return retval
416                 logger.debug("Predicate failed, sleeping and retrying.")
417                 mtries -= 1
418                 time.sleep(mdelay)
419                 mdelay *= backoff
420                 retval = f(*args, **kwargs)
421             return retval
422
423         return f_retry
424
425     return deco_retry
426
427
428 def retry_if_false(tries: int, *, delay_sec=3.0, backoff=2.0):
429     """A helper for @retry_predicate that retries a decorated
430     function as long as it keeps returning False.
431
432     >>> import time
433
434     >>> counter = 0
435
436     >>> @retry_if_false(5, delay_sec=1.0, backoff=1.1)
437     ... def foo():
438     ...     global counter
439     ...     counter += 1
440     ...     return counter >= 3
441
442     >>> start = time.time()
443     >>> foo()  # fail, delay 1.0, fail, delay 1.1, succeed
444     True
445
446     >>> dur = time.time() - start
447     >>> counter
448     3
449     >>> dur > 2.0
450     True
451     >>> dur < 2.3
452     True
453
454     """
455     return retry_predicate(
456         tries,
457         predicate=lambda x: x is True,
458         delay_sec=delay_sec,
459         backoff=backoff,
460     )
461
462
463 def retry_if_none(tries: int, *, delay_sec=3.0, backoff=2.0):
464     """Another helper for @retry_predicate above.  Retries up to N
465     times so long as the wrapped function returns None with a delay
466     between each retry and a backoff that can increase the delay.
467
468     """
469     return retry_predicate(
470         tries,
471         predicate=lambda x: x is not None,
472         delay_sec=delay_sec,
473         backoff=backoff,
474     )
475
476
477 def deprecated(func):
478     """This is a decorator which can be used to mark functions
479     as deprecated. It will result in a warning being emitted
480     when the function is used.
481
482     """
483
484     @functools.wraps(func)
485     def wrapper_deprecated(*args, **kwargs):
486         msg = f"Call to deprecated function {func.__qualname__}"
487         logger.warning(msg)
488         warnings.warn(msg, category=DeprecationWarning, stacklevel=2)
489         print(msg, file=sys.stderr)
490         return func(*args, **kwargs)
491
492     return wrapper_deprecated
493
494
495 def thunkify(func):
496     """
497     Make a function immediately return a function of no args which,
498     when called, waits for the result, which will start being
499     processed in another thread.
500     """
501
502     @functools.wraps(func)
503     def lazy_thunked(*args, **kwargs):
504         wait_event = threading.Event()
505
506         result = [None]
507         exc = [False, None]
508
509         def worker_func():
510             try:
511                 func_result = func(*args, **kwargs)
512                 result[0] = func_result
513             except Exception:
514                 exc[0] = True
515                 exc[1] = sys.exc_info()  # (type, value, traceback)
516                 msg = f"Thunkify has thrown an exception (will be raised on thunk()):\n{traceback.format_exc()}"
517                 logger.warning(msg)
518             finally:
519                 wait_event.set()
520
521         def thunk():
522             wait_event.wait()
523             if exc[0]:
524                 raise exc[1][0](exc[1][1])
525             return result[0]
526
527         threading.Thread(target=worker_func).start()
528         return thunk
529
530     return lazy_thunked
531
532
533 ############################################################
534 # Timeout
535 ############################################################
536
537 # http://www.saltycrane.com/blog/2010/04/using-python-timeout-decorator-uploading-s3/
538 # Used work of Stephen "Zero" Chappell <[email protected]>
539 # in https://code.google.com/p/verse-quiz/source/browse/trunk/timeout.py
540
541
542 def _raise_exception(exception, error_message: Optional[str]):
543     if error_message is None:
544         raise Exception()
545     else:
546         raise Exception(error_message)
547
548
549 def _target(queue, function, *args, **kwargs):
550     """Run a function with arguments and return output via a queue.
551
552     This is a helper function for the Process created in _Timeout. It runs
553     the function with positional arguments and keyword arguments and then
554     returns the function's output by way of a queue. If an exception gets
555     raised, it is returned to _Timeout to be raised by the value property.
556     """
557     try:
558         queue.put((True, function(*args, **kwargs)))
559     except Exception:
560         queue.put((False, sys.exc_info()[1]))
561
562
563 class _Timeout(object):
564     """Wrap a function and add a timeout to it.
565
566     Instances of this class are automatically generated by the add_timeout
567     function defined below.  Do not use directly.
568     """
569
570     def __init__(
571         self,
572         function: Callable,
573         timeout_exception: Exception,
574         error_message: str,
575         seconds: float,
576     ):
577         self.__limit = seconds
578         self.__function = function
579         self.__timeout_exception = timeout_exception
580         self.__error_message = error_message
581         self.__name__ = function.__name__
582         self.__doc__ = function.__doc__
583         self.__timeout = time.time()
584         self.__process = multiprocessing.Process()
585         self.__queue: multiprocessing.queues.Queue = multiprocessing.Queue()
586
587     def __call__(self, *args, **kwargs):
588         """Execute the embedded function object asynchronously.
589
590         The function given to the constructor is transparently called and
591         requires that "ready" be intermittently polled. If and when it is
592         True, the "value" property may then be checked for returned data.
593         """
594         self.__limit = kwargs.pop("timeout", self.__limit)
595         self.__queue = multiprocessing.Queue(1)
596         args = (self.__queue, self.__function) + args
597         self.__process = multiprocessing.Process(
598             target=_target, args=args, kwargs=kwargs
599         )
600         self.__process.daemon = True
601         self.__process.start()
602         if self.__limit is not None:
603             self.__timeout = self.__limit + time.time()
604         while not self.ready:
605             time.sleep(0.1)
606         return self.value
607
608     def cancel(self):
609         """Terminate any possible execution of the embedded function."""
610         if self.__process.is_alive():
611             self.__process.terminate()
612         _raise_exception(self.__timeout_exception, self.__error_message)
613
614     @property
615     def ready(self):
616         """Read-only property indicating status of "value" property."""
617         if self.__limit and self.__timeout < time.time():
618             self.cancel()
619         return self.__queue.full() and not self.__queue.empty()
620
621     @property
622     def value(self):
623         """Read-only property containing data returned from function."""
624         if self.ready is True:
625             flag, load = self.__queue.get()
626             if flag:
627                 return load
628             raise load
629
630
631 def timeout(
632     seconds: float = 1.0,
633     use_signals: Optional[bool] = None,
634     timeout_exception=exceptions.TimeoutError,
635     error_message="Function call timed out",
636 ):
637     """Add a timeout parameter to a function and return the function.
638
639     Note: the use_signals parameter is included in order to support
640     multiprocessing scenarios (signal can only be used from the process'
641     main thread).  When not using signals, timeout granularity will be
642     rounded to the nearest 0.1s.
643
644     Raises an exception when/if the timeout is reached.
645
646     It is illegal to pass anything other than a function as the first
647     parameter.  The function is wrapped and returned to the caller.
648
649     >>> @timeout(0.2)
650     ... def foo(delay: float):
651     ...     time.sleep(delay)
652     ...     return "ok"
653
654     >>> foo(0)
655     'ok'
656
657     >>> foo(1.0)
658     Traceback (most recent call last):
659     ...
660     Exception: Function call timed out
661
662     """
663     if use_signals is None:
664         import thread_utils
665
666         use_signals = thread_utils.is_current_thread_main_thread()
667
668     def decorate(function):
669         if use_signals:
670
671             def handler(signum, frame):
672                 _raise_exception(timeout_exception, error_message)
673
674             @functools.wraps(function)
675             def new_function(*args, **kwargs):
676                 new_seconds = kwargs.pop("timeout", seconds)
677                 if new_seconds:
678                     old = signal.signal(signal.SIGALRM, handler)
679                     signal.setitimer(signal.ITIMER_REAL, new_seconds)
680
681                 if not seconds:
682                     return function(*args, **kwargs)
683
684                 try:
685                     return function(*args, **kwargs)
686                 finally:
687                     if new_seconds:
688                         signal.setitimer(signal.ITIMER_REAL, 0)
689                         signal.signal(signal.SIGALRM, old)
690
691             return new_function
692         else:
693
694             @functools.wraps(function)
695             def new_function(*args, **kwargs):
696                 timeout_wrapper = _Timeout(
697                     function, timeout_exception, error_message, seconds
698                 )
699                 return timeout_wrapper(*args, **kwargs)
700
701             return new_function
702
703     return decorate
704
705
706 def synchronized(lock):
707     def wrap(f):
708         @functools.wraps(f)
709         def _gatekeeper(*args, **kw):
710             lock.acquire()
711             try:
712                 return f(*args, **kw)
713             finally:
714                 lock.release()
715
716         return _gatekeeper
717
718     return wrap
719
720
721 def call_with_sample_rate(sample_rate: float) -> Callable:
722     if not 0.0 <= sample_rate <= 1.0:
723         msg = f"sample_rate must be between [0, 1]. Got {sample_rate}."
724         logger.critical(msg)
725         raise ValueError(msg)
726
727     def decorator(f):
728         @functools.wraps(f)
729         def _call_with_sample_rate(*args, **kwargs):
730             if random.uniform(0, 1) < sample_rate:
731                 return f(*args, **kwargs)
732             else:
733                 logger.debug(f"@call_with_sample_rate skipping a call to {f.__name__}")
734
735         return _call_with_sample_rate
736
737     return decorator
738
739
740 def decorate_matching_methods_with(decorator, acl=None):
741     """Apply decorator to all methods in a class whose names begin with
742     prefix.  If prefix is None (default), decorate all methods in the
743     class.
744     """
745
746     def decorate_the_class(cls):
747         for name, m in inspect.getmembers(cls, inspect.isfunction):
748             if acl is None:
749                 setattr(cls, name, decorator(m))
750             else:
751                 if acl(name):
752                     setattr(cls, name, decorator(m))
753         return cls
754
755     return decorate_the_class
756
757
758 if __name__ == '__main__':
759     import doctest
760
761     doctest.testmod()