5d1e779deeeaa3df34547f853ce5e4fc0cdcb87b
[python_utils.git] / decorator_utils.py
1 #!/usr/bin/env python3
2
3 """Decorators."""
4
5 import enum
6 import functools
7 import inspect
8 import logging
9 import math
10 import multiprocessing
11 import random
12 import signal
13 import sys
14 import threading
15 import time
16 import traceback
17 import warnings
18 from typing import Any, Callable, Optional
19
20 # This module is commonly used by others in here and should avoid
21 # taking any unnecessary dependencies back on them.
22 import exceptions
23
24 logger = logging.getLogger(__name__)
25
26
27 def timed(func: Callable) -> Callable:
28     """Print the runtime of the decorated function.
29
30     >>> @timed
31     ... def foo():
32     ...     import time
33     ...     time.sleep(0.1)
34
35     >>> foo()  # doctest: +ELLIPSIS
36     Finished foo in ...
37
38     """
39
40     @functools.wraps(func)
41     def wrapper_timer(*args, **kwargs):
42         start_time = time.perf_counter()
43         value = func(*args, **kwargs)
44         end_time = time.perf_counter()
45         run_time = end_time - start_time
46         msg = f"Finished {func.__qualname__} in {run_time:.4f}s"
47         print(msg)
48         logger.info(msg)
49         return value
50
51     return wrapper_timer
52
53
54 def invocation_logged(func: Callable) -> Callable:
55     """Log the call of a function.
56
57     >>> @invocation_logged
58     ... def foo():
59     ...     print('Hello, world.')
60
61     >>> foo()
62     Entered foo
63     Hello, world.
64     Exited foo
65
66     """
67
68     @functools.wraps(func)
69     def wrapper_invocation_logged(*args, **kwargs):
70         msg = f"Entered {func.__qualname__}"
71         print(msg)
72         logger.info(msg)
73         ret = func(*args, **kwargs)
74         msg = f"Exited {func.__qualname__}"
75         print(msg)
76         logger.info(msg)
77         return ret
78
79     return wrapper_invocation_logged
80
81
82 def rate_limited(n_calls: int, *, per_period_in_seconds: float = 1.0) -> Callable:
83     """Limit invocation of a wrapped function to n calls per period.
84     Thread safe.  In testing this was relatively fair with multiple
85     threads using it though that hasn't been measured.
86
87     >>> import time
88     >>> import decorator_utils
89     >>> import thread_utils
90
91     >>> calls = 0
92
93     >>> @decorator_utils.rate_limited(10, per_period_in_seconds=1.0)
94     ... def limited(x: int):
95     ...     global calls
96     ...     calls += 1
97
98     >>> @thread_utils.background_thread
99     ... def a(stop):
100     ...     for _ in range(3):
101     ...         limited(_)
102
103     >>> @thread_utils.background_thread
104     ... def b(stop):
105     ...     for _ in range(3):
106     ...         limited(_)
107
108     >>> start = time.time()
109     >>> (t1, e1) = a()
110     >>> (t2, e2) = b()
111     >>> t1.join()
112     >>> t2.join()
113     >>> end = time.time()
114     >>> dur = end - start
115     >>> dur > 0.5
116     True
117
118     >>> calls
119     6
120
121     """
122     min_interval_seconds = per_period_in_seconds / float(n_calls)
123
124     def wrapper_rate_limited(func: Callable) -> Callable:
125         cv = threading.Condition()
126         last_invocation_timestamp = [0.0]
127
128         def may_proceed() -> float:
129             now = time.time()
130             last_invocation = last_invocation_timestamp[0]
131             if last_invocation != 0.0:
132                 elapsed_since_last = now - last_invocation
133                 wait_time = min_interval_seconds - elapsed_since_last
134             else:
135                 wait_time = 0.0
136             logger.debug('@%.4f> wait_time = %.4f', time.time(), wait_time)
137             return wait_time
138
139         def wrapper_wrapper_rate_limited(*args, **kargs) -> Any:
140             with cv:
141                 while True:
142                     if cv.wait_for(
143                         lambda: may_proceed() <= 0.0,
144                         timeout=may_proceed(),
145                     ):
146                         break
147             with cv:
148                 logger.debug('@%.4f> calling it...', time.time())
149                 ret = func(*args, **kargs)
150                 last_invocation_timestamp[0] = time.time()
151                 logger.debug(
152                     '@%.4f> Last invocation <- %.4f', time.time(), last_invocation_timestamp[0]
153                 )
154                 cv.notify()
155             return ret
156
157         return wrapper_wrapper_rate_limited
158
159     return wrapper_rate_limited
160
161
162 def debug_args(func: Callable) -> Callable:
163     """Print the function signature and return value at each call.
164
165     >>> @debug_args
166     ... def foo(a, b, c):
167     ...     print(a)
168     ...     print(b)
169     ...     print(c)
170     ...     return (a + b, c)
171
172     >>> foo(1, 2.0, "test")
173     Calling foo(1:<class 'int'>, 2.0:<class 'float'>, 'test':<class 'str'>)
174     1
175     2.0
176     test
177     foo returned (3.0, 'test'):<class 'tuple'>
178     (3.0, 'test')
179     """
180
181     @functools.wraps(func)
182     def wrapper_debug_args(*args, **kwargs):
183         args_repr = [f"{repr(a)}:{type(a)}" for a in args]
184         kwargs_repr = [f"{k}={v!r}:{type(v)}" for k, v in kwargs.items()]
185         signature = ", ".join(args_repr + kwargs_repr)
186         msg = f"Calling {func.__qualname__}({signature})"
187         print(msg)
188         logger.info(msg)
189         value = func(*args, **kwargs)
190         msg = f"{func.__qualname__} returned {value!r}:{type(value)}"
191         print(msg)
192         logger.info(msg)
193         return value
194
195     return wrapper_debug_args
196
197
198 def debug_count_calls(func: Callable) -> Callable:
199     """Count function invocations and print a message befor every call.
200
201     >>> @debug_count_calls
202     ... def factoral(x):
203     ...     if x == 1:
204     ...         return 1
205     ...     return x * factoral(x - 1)
206
207     >>> factoral(5)
208     Call #1 of 'factoral'
209     Call #2 of 'factoral'
210     Call #3 of 'factoral'
211     Call #4 of 'factoral'
212     Call #5 of 'factoral'
213     120
214
215     """
216
217     @functools.wraps(func)
218     def wrapper_debug_count_calls(*args, **kwargs):
219         wrapper_debug_count_calls.num_calls += 1
220         msg = f"Call #{wrapper_debug_count_calls.num_calls} of {func.__name__!r}"
221         print(msg)
222         logger.info(msg)
223         return func(*args, **kwargs)
224
225     wrapper_debug_count_calls.num_calls = 0  # type: ignore
226     return wrapper_debug_count_calls
227
228
229 class DelayWhen(enum.IntEnum):
230     """When should we delay: before or after calling the function (or
231     both)?
232
233     """
234
235     BEFORE_CALL = 1
236     AFTER_CALL = 2
237     BEFORE_AND_AFTER = 3
238
239
240 def delay(
241     _func: Callable = None,
242     *,
243     seconds: float = 1.0,
244     when: DelayWhen = DelayWhen.BEFORE_CALL,
245 ) -> Callable:
246     """Delay the execution of a function by sleeping before and/or after.
247
248     Slow down a function by inserting a delay before and/or after its
249     invocation.
250
251     >>> import time
252
253     >>> @delay(seconds=1.0)
254     ... def foo():
255     ...     pass
256
257     >>> start = time.time()
258     >>> foo()
259     >>> dur = time.time() - start
260     >>> dur >= 1.0
261     True
262
263     """
264
265     def decorator_delay(func: Callable) -> Callable:
266         @functools.wraps(func)
267         def wrapper_delay(*args, **kwargs):
268             if when & DelayWhen.BEFORE_CALL:
269                 logger.debug("@delay for %fs BEFORE_CALL to %s", seconds, func.__name__)
270                 time.sleep(seconds)
271             retval = func(*args, **kwargs)
272             if when & DelayWhen.AFTER_CALL:
273                 logger.debug("@delay for %fs AFTER_CALL to %s", seconds, func.__name__)
274                 time.sleep(seconds)
275             return retval
276
277         return wrapper_delay
278
279     if _func is None:
280         return decorator_delay
281     else:
282         return decorator_delay(_func)
283
284
285 class _SingletonWrapper:
286     """
287     A singleton wrapper class. Its instances would be created
288     for each decorated class.
289
290     """
291
292     def __init__(self, cls):
293         self.__wrapped__ = cls
294         self._instance = None
295
296     def __call__(self, *args, **kwargs):
297         """Returns a single instance of decorated class"""
298         logger.debug('@singleton returning global instance of %s', self.__wrapped__.__name__)
299         if self._instance is None:
300             self._instance = self.__wrapped__(*args, **kwargs)
301         return self._instance
302
303
304 def singleton(cls):
305     """
306     A singleton decorator. Returns a wrapper objects. A call on that object
307     returns a single instance object of decorated class. Use the __wrapped__
308     attribute to access decorated class directly in unit tests
309
310     >>> @singleton
311     ... class foo(object):
312     ...     pass
313
314     >>> a = foo()
315     >>> b = foo()
316     >>> a is b
317     True
318
319     >>> id(a) == id(b)
320     True
321
322     """
323     return _SingletonWrapper(cls)
324
325
326 def memoized(func: Callable) -> Callable:
327     """Keep a cache of previous function call results.
328
329     The cache here is a dict with a key based on the arguments to the
330     call.  Consider also: functools.lru_cache for a more advanced
331     implementation.
332
333     >>> import time
334
335     >>> @memoized
336     ... def expensive(arg) -> int:
337     ...     # Simulate something slow to compute or lookup
338     ...     time.sleep(1.0)
339     ...     return arg * arg
340
341     >>> start = time.time()
342     >>> expensive(5)           # Takes about 1 sec
343     25
344
345     >>> expensive(3)           # Also takes about 1 sec
346     9
347
348     >>> expensive(5)           # Pulls from cache, fast
349     25
350
351     >>> expensive(3)           # Pulls from cache again, fast
352     9
353
354     >>> dur = time.time() - start
355     >>> dur < 3.0
356     True
357
358     """
359
360     @functools.wraps(func)
361     def wrapper_memoized(*args, **kwargs):
362         cache_key = args + tuple(kwargs.items())
363         if cache_key not in wrapper_memoized.cache:
364             value = func(*args, **kwargs)
365             logger.debug('Memoizing %s => %s for %s', cache_key, value, func.__name__)
366             wrapper_memoized.cache[cache_key] = value
367         else:
368             logger.debug('Returning memoized value for %s', {func.__name__})
369         return wrapper_memoized.cache[cache_key]
370
371     wrapper_memoized.cache = {}  # type: ignore
372     return wrapper_memoized
373
374
375 def retry_predicate(
376     tries: int,
377     *,
378     predicate: Callable[..., bool],
379     delay_sec: float = 3.0,
380     backoff: float = 2.0,
381 ):
382     """Retries a function or method up to a certain number of times
383     with a prescribed initial delay period and backoff rate.
384
385     tries is the maximum number of attempts to run the function.
386     delay_sec sets the initial delay period in seconds.
387     backoff is a multiplied (must be >1) used to modify the delay.
388     predicate is a function that will be passed the retval of the
389     decorated function and must return True to stop or False to
390     retry.
391
392     """
393     if backoff < 1.0:
394         msg = f"backoff must be greater than or equal to 1, got {backoff}"
395         logger.critical(msg)
396         raise ValueError(msg)
397
398     tries = math.floor(tries)
399     if tries < 0:
400         msg = f"tries must be 0 or greater, got {tries}"
401         logger.critical(msg)
402         raise ValueError(msg)
403
404     if delay_sec <= 0:
405         msg = f"delay_sec must be greater than 0, got {delay_sec}"
406         logger.critical(msg)
407         raise ValueError(msg)
408
409     def deco_retry(f):
410         @functools.wraps(f)
411         def f_retry(*args, **kwargs):
412             mtries, mdelay = tries, delay_sec  # make mutable
413             logger.debug('deco_retry: will make up to %d attempts...', mtries)
414             retval = f(*args, **kwargs)
415             while mtries > 0:
416                 if predicate(retval) is True:
417                     logger.debug('Predicate succeeded, deco_retry is done.')
418                     return retval
419                 logger.debug("Predicate failed, sleeping and retrying.")
420                 mtries -= 1
421                 time.sleep(mdelay)
422                 mdelay *= backoff
423                 retval = f(*args, **kwargs)
424             return retval
425
426         return f_retry
427
428     return deco_retry
429
430
431 def retry_if_false(tries: int, *, delay_sec=3.0, backoff=2.0):
432     """A helper for @retry_predicate that retries a decorated
433     function as long as it keeps returning False.
434
435     >>> import time
436
437     >>> counter = 0
438
439     >>> @retry_if_false(5, delay_sec=1.0, backoff=1.1)
440     ... def foo():
441     ...     global counter
442     ...     counter += 1
443     ...     return counter >= 3
444
445     >>> start = time.time()
446     >>> foo()  # fail, delay 1.0, fail, delay 1.1, succeed
447     True
448
449     >>> dur = time.time() - start
450     >>> counter
451     3
452     >>> dur > 2.0
453     True
454     >>> dur < 2.3
455     True
456
457     """
458     return retry_predicate(
459         tries,
460         predicate=lambda x: x is True,
461         delay_sec=delay_sec,
462         backoff=backoff,
463     )
464
465
466 def retry_if_none(tries: int, *, delay_sec=3.0, backoff=2.0):
467     """Another helper for @retry_predicate above.  Retries up to N
468     times so long as the wrapped function returns None with a delay
469     between each retry and a backoff that can increase the delay.
470
471     """
472     return retry_predicate(
473         tries,
474         predicate=lambda x: x is not None,
475         delay_sec=delay_sec,
476         backoff=backoff,
477     )
478
479
480 def deprecated(func):
481     """This is a decorator which can be used to mark functions
482     as deprecated. It will result in a warning being emitted
483     when the function is used.
484
485     """
486
487     @functools.wraps(func)
488     def wrapper_deprecated(*args, **kwargs):
489         msg = f"Call to deprecated function {func.__qualname__}"
490         logger.warning(msg)
491         warnings.warn(msg, category=DeprecationWarning, stacklevel=2)
492         print(msg, file=sys.stderr)
493         return func(*args, **kwargs)
494
495     return wrapper_deprecated
496
497
498 def thunkify(func):
499     """
500     Make a function immediately return a function of no args which,
501     when called, waits for the result, which will start being
502     processed in another thread.
503     """
504
505     @functools.wraps(func)
506     def lazy_thunked(*args, **kwargs):
507         wait_event = threading.Event()
508
509         result = [None]
510         exc = [False, None]
511
512         def worker_func():
513             try:
514                 func_result = func(*args, **kwargs)
515                 result[0] = func_result
516             except Exception:
517                 exc[0] = True
518                 exc[1] = sys.exc_info()  # (type, value, traceback)
519                 msg = f"Thunkify has thrown an exception (will be raised on thunk()):\n{traceback.format_exc()}"
520                 logger.warning(msg)
521             finally:
522                 wait_event.set()
523
524         def thunk():
525             wait_event.wait()
526             if exc[0]:
527                 raise exc[1][0](exc[1][1])
528             return result[0]
529
530         threading.Thread(target=worker_func).start()
531         return thunk
532
533     return lazy_thunked
534
535
536 ############################################################
537 # Timeout
538 ############################################################
539
540 # http://www.saltycrane.com/blog/2010/04/using-python-timeout-decorator-uploading-s3/
541 # Used work of Stephen "Zero" Chappell <[email protected]>
542 # in https://code.google.com/p/verse-quiz/source/browse/trunk/timeout.py
543
544
545 def _raise_exception(exception, error_message: Optional[str]):
546     if error_message is None:
547         raise Exception(exception)
548     else:
549         raise Exception(error_message)
550
551
552 def _target(queue, function, *args, **kwargs):
553     """Run a function with arguments and return output via a queue.
554
555     This is a helper function for the Process created in _Timeout. It runs
556     the function with positional arguments and keyword arguments and then
557     returns the function's output by way of a queue. If an exception gets
558     raised, it is returned to _Timeout to be raised by the value property.
559     """
560     try:
561         queue.put((True, function(*args, **kwargs)))
562     except Exception:
563         queue.put((False, sys.exc_info()[1]))
564
565
566 class _Timeout(object):
567     """Wrap a function and add a timeout to it.
568
569     Instances of this class are automatically generated by the add_timeout
570     function defined below.  Do not use directly.
571     """
572
573     def __init__(
574         self,
575         function: Callable,
576         timeout_exception: Exception,
577         error_message: str,
578         seconds: float,
579     ):
580         self.__limit = seconds
581         self.__function = function
582         self.__timeout_exception = timeout_exception
583         self.__error_message = error_message
584         self.__name__ = function.__name__
585         self.__doc__ = function.__doc__
586         self.__timeout = time.time()
587         self.__process = multiprocessing.Process()
588         self.__queue: multiprocessing.queues.Queue = multiprocessing.Queue()
589
590     def __call__(self, *args, **kwargs):
591         """Execute the embedded function object asynchronously.
592
593         The function given to the constructor is transparently called and
594         requires that "ready" be intermittently polled. If and when it is
595         True, the "value" property may then be checked for returned data.
596         """
597         self.__limit = kwargs.pop("timeout", self.__limit)
598         self.__queue = multiprocessing.Queue(1)
599         args = (self.__queue, self.__function) + args
600         self.__process = multiprocessing.Process(target=_target, args=args, kwargs=kwargs)
601         self.__process.daemon = True
602         self.__process.start()
603         if self.__limit is not None:
604             self.__timeout = self.__limit + time.time()
605         while not self.ready:
606             time.sleep(0.1)
607         return self.value
608
609     def cancel(self):
610         """Terminate any possible execution of the embedded function."""
611         if self.__process.is_alive():
612             self.__process.terminate()
613         _raise_exception(self.__timeout_exception, self.__error_message)
614
615     @property
616     def ready(self):
617         """Read-only property indicating status of "value" property."""
618         if self.__limit and self.__timeout < time.time():
619             self.cancel()
620         return self.__queue.full() and not self.__queue.empty()
621
622     @property
623     def value(self):
624         """Read-only property containing data returned from function."""
625         if self.ready is True:
626             flag, load = self.__queue.get()
627             if flag:
628                 return load
629             raise load
630         return None
631
632
633 def timeout(
634     seconds: float = 1.0,
635     use_signals: Optional[bool] = None,
636     timeout_exception=exceptions.TimeoutError,
637     error_message="Function call timed out",
638 ):
639     """Add a timeout parameter to a function and return the function.
640
641     Note: the use_signals parameter is included in order to support
642     multiprocessing scenarios (signal can only be used from the process'
643     main thread).  When not using signals, timeout granularity will be
644     rounded to the nearest 0.1s.
645
646     Raises an exception when/if the timeout is reached.
647
648     It is illegal to pass anything other than a function as the first
649     parameter.  The function is wrapped and returned to the caller.
650
651     >>> @timeout(0.2)
652     ... def foo(delay: float):
653     ...     time.sleep(delay)
654     ...     return "ok"
655
656     >>> foo(0)
657     'ok'
658
659     >>> foo(1.0)
660     Traceback (most recent call last):
661     ...
662     Exception: Function call timed out
663
664     """
665     if use_signals is None:
666         import thread_utils
667
668         use_signals = thread_utils.is_current_thread_main_thread()
669
670     def decorate(function):
671         if use_signals:
672
673             def handler(signum, frame):
674                 _raise_exception(timeout_exception, error_message)
675
676             @functools.wraps(function)
677             def new_function(*args, **kwargs):
678                 new_seconds = kwargs.pop("timeout", seconds)
679                 if new_seconds:
680                     old = signal.signal(signal.SIGALRM, handler)
681                     signal.setitimer(signal.ITIMER_REAL, new_seconds)
682
683                 if not seconds:
684                     return function(*args, **kwargs)
685
686                 try:
687                     return function(*args, **kwargs)
688                 finally:
689                     if new_seconds:
690                         signal.setitimer(signal.ITIMER_REAL, 0)
691                         signal.signal(signal.SIGALRM, old)
692
693             return new_function
694         else:
695
696             @functools.wraps(function)
697             def new_function(*args, **kwargs):
698                 timeout_wrapper = _Timeout(function, timeout_exception, error_message, seconds)
699                 return timeout_wrapper(*args, **kwargs)
700
701             return new_function
702
703     return decorate
704
705
706 def synchronized(lock):
707     def wrap(f):
708         @functools.wraps(f)
709         def _gatekeeper(*args, **kw):
710             lock.acquire()
711             try:
712                 return f(*args, **kw)
713             finally:
714                 lock.release()
715
716         return _gatekeeper
717
718     return wrap
719
720
721 def call_with_sample_rate(sample_rate: float) -> Callable:
722     if not 0.0 <= sample_rate <= 1.0:
723         msg = f"sample_rate must be between [0, 1]. Got {sample_rate}."
724         logger.critical(msg)
725         raise ValueError(msg)
726
727     def decorator(f):
728         @functools.wraps(f)
729         def _call_with_sample_rate(*args, **kwargs):
730             if random.uniform(0, 1) < sample_rate:
731                 return f(*args, **kwargs)
732             else:
733                 logger.debug("@call_with_sample_rate skipping a call to %s", f.__name__)
734                 return None
735
736         return _call_with_sample_rate
737
738     return decorator
739
740
741 def decorate_matching_methods_with(decorator, acl=None):
742     """Apply decorator to all methods in a class whose names begin with
743     prefix.  If prefix is None (default), decorate all methods in the
744     class.
745     """
746
747     def decorate_the_class(cls):
748         for name, m in inspect.getmembers(cls, inspect.isfunction):
749             if acl is None:
750                 setattr(cls, name, decorator(m))
751             else:
752                 if acl(name):
753                     setattr(cls, name, decorator(m))
754         return cls
755
756     return decorate_the_class
757
758
759 if __name__ == '__main__':
760     import doctest
761
762     doctest.testmod()