Add another doctest.
[python_utils.git] / decorator_utils.py
1 #!/usr/bin/env python3
2
3 """Decorators."""
4
5 import enum
6 import functools
7 import inspect
8 import logging
9 import math
10 import multiprocessing
11 import random
12 import signal
13 import sys
14 import threading
15 import time
16 import traceback
17 from typing import Any, Callable, Optional
18 import warnings
19
20 # This module is commonly used by others in here and should avoid
21 # taking any unnecessary dependencies back on them.
22 import exceptions
23
24
25 logger = logging.getLogger(__name__)
26
27
28 def timed(func: Callable) -> Callable:
29     """Print the runtime of the decorated function.
30
31     >>> @timed
32     ... def foo():
33     ...     import time
34     ...     time.sleep(0.1)
35
36     >>> foo()  # doctest: +ELLIPSIS
37     Finished foo in ...
38
39     """
40
41     @functools.wraps(func)
42     def wrapper_timer(*args, **kwargs):
43         start_time = time.perf_counter()
44         value = func(*args, **kwargs)
45         end_time = time.perf_counter()
46         run_time = end_time - start_time
47         msg = f"Finished {func.__qualname__} in {run_time:.4f}s"
48         print(msg)
49         logger.info(msg)
50         return value
51     return wrapper_timer
52
53
54 def invocation_logged(func: Callable) -> Callable:
55     """Log the call of a function.
56
57     >>> @invocation_logged
58     ... def foo():
59     ...     print('Hello, world.')
60
61     >>> foo()
62     Entered foo
63     Hello, world.
64     Exited foo
65
66     """
67
68     @functools.wraps(func)
69     def wrapper_invocation_logged(*args, **kwargs):
70         msg = f"Entered {func.__qualname__}"
71         print(msg)
72         logger.info(msg)
73         ret = func(*args, **kwargs)
74         msg = f"Exited {func.__qualname__}"
75         print(msg)
76         logger.info(msg)
77         return ret
78     return wrapper_invocation_logged
79
80
81 def rate_limited(n_calls: int, *, per_period_in_seconds: float = 1.0) -> Callable:
82     """Limit invocation of a wrapped function to n calls per period.
83     Thread safe.  In testing this was relatively fair with multiple
84     threads using it though that hasn't been measured.
85
86     >>> import time
87     >>> import decorator_utils
88     >>> import thread_utils
89
90     >>> calls = 0
91
92     >>> @decorator_utils.rate_limited(10, per_period_in_seconds=1.0)
93     ... def limited(x: int):
94     ...     global calls
95     ...     calls += 1
96
97     >>> @thread_utils.background_thread
98     ... def a(stop):
99     ...     for _ in range(3):
100     ...         limited(_)
101
102     >>> @thread_utils.background_thread
103     ... def b(stop):
104     ...     for _ in range(3):
105     ...         limited(_)
106
107     >>> start = time.time()
108     >>> (t1, e1) = a()
109     >>> (t2, e2) = b()
110     >>> t1.join()
111     >>> t2.join()
112     >>> end = time.time()
113     >>> dur = end - start
114     >>> dur > 0.5
115     True
116
117     >>> calls
118     6
119
120     """
121     min_interval_seconds = per_period_in_seconds / float(n_calls)
122
123     def wrapper_rate_limited(func: Callable) -> Callable:
124         cv = threading.Condition()
125         last_invocation_timestamp = [0.0]
126
127         def may_proceed() -> float:
128             now = time.time()
129             last_invocation = last_invocation_timestamp[0]
130             if last_invocation != 0.0:
131                 elapsed_since_last = now - last_invocation
132                 wait_time = min_interval_seconds - elapsed_since_last
133             else:
134                 wait_time = 0.0
135             logger.debug(f'@{time.time()}> wait_time = {wait_time}')
136             return wait_time
137
138         def wrapper_wrapper_rate_limited(*args, **kargs) -> Any:
139             with cv:
140                 while True:
141                     if cv.wait_for(
142                         lambda: may_proceed() <= 0.0,
143                         timeout=may_proceed(),
144                     ):
145                         break
146             with cv:
147                 logger.debug(f'@{time.time()}> calling it...')
148                 ret = func(*args, **kargs)
149                 last_invocation_timestamp[0] = time.time()
150                 logger.debug(
151                     f'@{time.time()}> Last invocation <- {last_invocation_timestamp[0]}'
152                 )
153                 cv.notify()
154             return ret
155         return wrapper_wrapper_rate_limited
156     return wrapper_rate_limited
157
158
159 def debug_args(func: Callable) -> Callable:
160     """Print the function signature and return value at each call.
161
162     >>> @debug_args
163     ... def foo(a, b, c):
164     ...     print(a)
165     ...     print(b)
166     ...     print(c)
167     ...     return (a + b, c)
168
169     >>> foo(1, 2.0, "test")
170     Calling foo(1:<class 'int'>, 2.0:<class 'float'>, 'test':<class 'str'>)
171     1
172     2.0
173     test
174     foo returned (3.0, 'test'):<class 'tuple'>
175     (3.0, 'test')
176     """
177
178     @functools.wraps(func)
179     def wrapper_debug_args(*args, **kwargs):
180         args_repr = [f"{repr(a)}:{type(a)}" for a in args]
181         kwargs_repr = [f"{k}={v!r}:{type(v)}" for k, v in kwargs.items()]
182         signature = ", ".join(args_repr + kwargs_repr)
183         msg = f"Calling {func.__qualname__}({signature})"
184         print(msg)
185         logger.info(msg)
186         value = func(*args, **kwargs)
187         msg = f"{func.__qualname__} returned {value!r}:{type(value)}"
188         print(msg)
189         logger.info(msg)
190         return value
191     return wrapper_debug_args
192
193
194 def debug_count_calls(func: Callable) -> Callable:
195     """Count function invocations and print a message befor every call.
196
197     >>> @debug_count_calls
198     ... def factoral(x):
199     ...     if x == 1:
200     ...         return 1
201     ...     return x * factoral(x - 1)
202
203     >>> factoral(5)
204     Call #1 of 'factoral'
205     Call #2 of 'factoral'
206     Call #3 of 'factoral'
207     Call #4 of 'factoral'
208     Call #5 of 'factoral'
209     120
210
211     """
212
213     @functools.wraps(func)
214     def wrapper_debug_count_calls(*args, **kwargs):
215         wrapper_debug_count_calls.num_calls += 1
216         msg = f"Call #{wrapper_debug_count_calls.num_calls} of {func.__name__!r}"
217         print(msg)
218         logger.info(msg)
219         return func(*args, **kwargs)
220     wrapper_debug_count_calls.num_calls = 0
221     return wrapper_debug_count_calls
222
223
224 class DelayWhen(enum.IntEnum):
225     BEFORE_CALL = 1
226     AFTER_CALL = 2
227     BEFORE_AND_AFTER = 3
228
229
230 def delay(
231     _func: Callable = None,
232     *,
233     seconds: float = 1.0,
234     when: DelayWhen = DelayWhen.BEFORE_CALL,
235 ) -> Callable:
236     """Delay the execution of a function by sleeping before and/or after.
237
238     Slow down a function by inserting a delay before and/or after its
239     invocation.
240
241     >>> import time
242
243     >>> @delay(seconds=1.0)
244     ... def foo():
245     ...     pass
246
247     >>> start = time.time()
248     >>> foo()
249     >>> dur = time.time() - start
250     >>> dur >= 1.0
251     True
252
253     """
254     def decorator_delay(func: Callable) -> Callable:
255         @functools.wraps(func)
256         def wrapper_delay(*args, **kwargs):
257             if when & DelayWhen.BEFORE_CALL:
258                 logger.debug(
259                     f"@delay for {seconds}s BEFORE_CALL to {func.__name__}"
260                 )
261                 time.sleep(seconds)
262             retval = func(*args, **kwargs)
263             if when & DelayWhen.AFTER_CALL:
264                 logger.debug(
265                     f"@delay for {seconds}s AFTER_CALL to {func.__name__}"
266                 )
267                 time.sleep(seconds)
268             return retval
269         return wrapper_delay
270
271     if _func is None:
272         return decorator_delay
273     else:
274         return decorator_delay(_func)
275
276
277 class _SingletonWrapper:
278     """
279     A singleton wrapper class. Its instances would be created
280     for each decorated class.
281
282     """
283
284     def __init__(self, cls):
285         self.__wrapped__ = cls
286         self._instance = None
287
288     def __call__(self, *args, **kwargs):
289         """Returns a single instance of decorated class"""
290         logger.debug(
291             f"@singleton returning global instance of {self.__wrapped__.__name__}"
292         )
293         if self._instance is None:
294             self._instance = self.__wrapped__(*args, **kwargs)
295         return self._instance
296
297
298 def singleton(cls):
299     """
300     A singleton decorator. Returns a wrapper objects. A call on that object
301     returns a single instance object of decorated class. Use the __wrapped__
302     attribute to access decorated class directly in unit tests
303
304     >>> @singleton
305     ... class foo(object):
306     ...     pass
307
308     >>> a = foo()
309     >>> b = foo()
310     >>> a is b
311     True
312
313     >>> id(a) == id(b)
314     True
315
316     """
317     return _SingletonWrapper(cls)
318
319
320 def memoized(func: Callable) -> Callable:
321     """Keep a cache of previous function call results.
322
323     The cache here is a dict with a key based on the arguments to the
324     call.  Consider also: functools.lru_cache for a more advanced
325     implementation.
326
327     >>> import time
328
329     >>> @memoized
330     ... def expensive(arg) -> int:
331     ...     # Simulate something slow to compute or lookup
332     ...     time.sleep(1.0)
333     ...     return arg * arg
334
335     >>> start = time.time()
336     >>> expensive(5)           # Takes about 1 sec
337     25
338
339     >>> expensive(3)           # Also takes about 1 sec
340     9
341
342     >>> expensive(5)           # Pulls from cache, fast
343     25
344
345     >>> expensive(3)           # Pulls from cache again, fast
346     9
347
348     >>> dur = time.time() - start
349     >>> dur < 3.0
350     True
351
352     """
353     @functools.wraps(func)
354     def wrapper_memoized(*args, **kwargs):
355         cache_key = args + tuple(kwargs.items())
356         if cache_key not in wrapper_memoized.cache:
357             value = func(*args, **kwargs)
358             logger.debug(
359                 f"Memoizing {cache_key} => {value} for {func.__name__}"
360             )
361             wrapper_memoized.cache[cache_key] = value
362         else:
363             logger.debug(f"Returning memoized value for {func.__name__}")
364         return wrapper_memoized.cache[cache_key]
365     wrapper_memoized.cache = dict()
366     return wrapper_memoized
367
368
369 def retry_predicate(
370     tries: int,
371     *,
372     predicate: Callable[..., bool],
373     delay_sec: float = 3.0,
374     backoff: float = 2.0,
375 ):
376     """Retries a function or method up to a certain number of times
377     with a prescribed initial delay period and backoff rate.
378
379     tries is the maximum number of attempts to run the function.
380     delay_sec sets the initial delay period in seconds.
381     backoff is a multiplied (must be >1) used to modify the delay.
382     predicate is a function that will be passed the retval of the
383     decorated function and must return True to stop or False to
384     retry.
385
386     """
387     if backoff < 1.0:
388         msg = f"backoff must be greater than or equal to 1, got {backoff}"
389         logger.critical(msg)
390         raise ValueError(msg)
391
392     tries = math.floor(tries)
393     if tries < 0:
394         msg = f"tries must be 0 or greater, got {tries}"
395         logger.critical(msg)
396         raise ValueError(msg)
397
398     if delay_sec <= 0:
399         msg = f"delay_sec must be greater than 0, got {delay_sec}"
400         logger.critical(msg)
401         raise ValueError(msg)
402
403     def deco_retry(f):
404         @functools.wraps(f)
405         def f_retry(*args, **kwargs):
406             mtries, mdelay = tries, delay_sec  # make mutable
407             logger.debug(f'deco_retry: will make up to {mtries} attempts...')
408             retval = f(*args, **kwargs)
409             while mtries > 0:
410                 if predicate(retval) is True:
411                     logger.debug('Predicate succeeded, deco_retry is done.')
412                     return retval
413                 logger.debug("Predicate failed, sleeping and retrying.")
414                 mtries -= 1
415                 time.sleep(mdelay)
416                 mdelay *= backoff
417                 retval = f(*args, **kwargs)
418             return retval
419         return f_retry
420     return deco_retry
421
422
423 def retry_if_false(tries: int, *, delay_sec=3.0, backoff=2.0):
424     """A helper for @retry_predicate that retries a decorated
425     function as long as it keeps returning False.
426
427     >>> import time
428
429     >>> counter = 0
430
431     >>> @retry_if_false(5, delay_sec=1.0, backoff=1.1)
432     ... def foo():
433     ...     global counter
434     ...     counter += 1
435     ...     return counter >= 3
436
437     >>> start = time.time()
438     >>> foo()  # fail, delay 1.0, fail, delay 1.1, succeed
439     True
440
441     >>> dur = time.time() - start
442     >>> counter
443     3
444     >>> dur > 2.0
445     True
446     >>> dur < 2.2
447     True
448
449     """
450     return retry_predicate(
451         tries,
452         predicate=lambda x: x is True,
453         delay_sec=delay_sec,
454         backoff=backoff,
455     )
456
457
458 def retry_if_none(tries: int, *, delay_sec=3.0, backoff=2.0):
459     """Another helper for @retry_predicate above.  Retries up to N
460     times so long as the wrapped function returns None with a delay
461     between each retry and a backoff that can increase the delay.
462
463     """
464     return retry_predicate(
465         tries,
466         predicate=lambda x: x is not None,
467         delay_sec=delay_sec,
468         backoff=backoff,
469     )
470
471
472 def deprecated(func):
473     """This is a decorator which can be used to mark functions
474     as deprecated. It will result in a warning being emitted
475     when the function is used.
476
477     """
478     @functools.wraps(func)
479     def wrapper_deprecated(*args, **kwargs):
480         msg = f"Call to deprecated function {func.__qualname__}"
481         logger.warning(msg)
482         warnings.warn(msg, category=DeprecationWarning)
483         print(msg, file=sys.stderr)
484         return func(*args, **kwargs)
485     return wrapper_deprecated
486
487
488 def thunkify(func):
489     """
490     Make a function immediately return a function of no args which,
491     when called, waits for the result, which will start being
492     processed in another thread.
493     """
494
495     @functools.wraps(func)
496     def lazy_thunked(*args, **kwargs):
497         wait_event = threading.Event()
498
499         result = [None]
500         exc = [False, None]
501
502         def worker_func():
503             try:
504                 func_result = func(*args, **kwargs)
505                 result[0] = func_result
506             except Exception:
507                 exc[0] = True
508                 exc[1] = sys.exc_info()  # (type, value, traceback)
509                 msg = f"Thunkify has thrown an exception (will be raised on thunk()):\n{traceback.format_exc()}"
510                 print(msg)
511                 logger.warning(msg)
512             finally:
513                 wait_event.set()
514
515         def thunk():
516             wait_event.wait()
517             if exc[0]:
518                 raise exc[1][0](exc[1][1])
519             return result[0]
520
521         threading.Thread(target=worker_func).start()
522         return thunk
523
524     return lazy_thunked
525
526
527 ############################################################
528 # Timeout
529 ############################################################
530
531 # http://www.saltycrane.com/blog/2010/04/using-python-timeout-decorator-uploading-s3/
532 # Used work of Stephen "Zero" Chappell <[email protected]>
533 # in https://code.google.com/p/verse-quiz/source/browse/trunk/timeout.py
534
535
536 def _raise_exception(exception, error_message: Optional[str]):
537     if error_message is None:
538         raise Exception()
539     else:
540         raise Exception(error_message)
541
542
543 def _target(queue, function, *args, **kwargs):
544     """Run a function with arguments and return output via a queue.
545
546     This is a helper function for the Process created in _Timeout. It runs
547     the function with positional arguments and keyword arguments and then
548     returns the function's output by way of a queue. If an exception gets
549     raised, it is returned to _Timeout to be raised by the value property.
550     """
551     try:
552         queue.put((True, function(*args, **kwargs)))
553     except Exception:
554         queue.put((False, sys.exc_info()[1]))
555
556
557 class _Timeout(object):
558     """Wrap a function and add a timeout to it.
559
560     Instances of this class are automatically generated by the add_timeout
561     function defined below.  Do not use directly.
562     """
563
564     def __init__(
565         self,
566         function: Callable,
567         timeout_exception: Exception,
568         error_message: str,
569         seconds: float,
570     ):
571         self.__limit = seconds
572         self.__function = function
573         self.__timeout_exception = timeout_exception
574         self.__error_message = error_message
575         self.__name__ = function.__name__
576         self.__doc__ = function.__doc__
577         self.__timeout = time.time()
578         self.__process = multiprocessing.Process()
579         self.__queue: multiprocessing.queues.Queue = multiprocessing.Queue()
580
581     def __call__(self, *args, **kwargs):
582         """Execute the embedded function object asynchronously.
583
584         The function given to the constructor is transparently called and
585         requires that "ready" be intermittently polled. If and when it is
586         True, the "value" property may then be checked for returned data.
587         """
588         self.__limit = kwargs.pop("timeout", self.__limit)
589         self.__queue = multiprocessing.Queue(1)
590         args = (self.__queue, self.__function) + args
591         self.__process = multiprocessing.Process(
592             target=_target, args=args, kwargs=kwargs
593         )
594         self.__process.daemon = True
595         self.__process.start()
596         if self.__limit is not None:
597             self.__timeout = self.__limit + time.time()
598         while not self.ready:
599             time.sleep(0.1)
600         return self.value
601
602     def cancel(self):
603         """Terminate any possible execution of the embedded function."""
604         if self.__process.is_alive():
605             self.__process.terminate()
606         _raise_exception(self.__timeout_exception, self.__error_message)
607
608     @property
609     def ready(self):
610         """Read-only property indicating status of "value" property."""
611         if self.__limit and self.__timeout < time.time():
612             self.cancel()
613         return self.__queue.full() and not self.__queue.empty()
614
615     @property
616     def value(self):
617         """Read-only property containing data returned from function."""
618         if self.ready is True:
619             flag, load = self.__queue.get()
620             if flag:
621                 return load
622             raise load
623
624
625 def timeout(
626     seconds: float = 1.0,
627     use_signals: Optional[bool] = None,
628     timeout_exception=exceptions.TimeoutError,
629     error_message="Function call timed out",
630 ):
631     """Add a timeout parameter to a function and return the function.
632
633     Note: the use_signals parameter is included in order to support
634     multiprocessing scenarios (signal can only be used from the process'
635     main thread).  When not using signals, timeout granularity will be
636     rounded to the nearest 0.1s.
637
638     Raises an exception when/if the timeout is reached.
639
640     It is illegal to pass anything other than a function as the first
641     parameter.  The function is wrapped and returned to the caller.
642
643     >>> @timeout(0.2)
644     ... def foo(delay: float):
645     ...     time.sleep(delay)
646     ...     return "ok"
647
648     >>> foo(0)
649     'ok'
650
651     >>> foo(1.0)
652     Traceback (most recent call last):
653     ...
654     Exception: Function call timed out
655
656     """
657     if use_signals is None:
658         import thread_utils
659         use_signals = thread_utils.is_current_thread_main_thread()
660
661     def decorate(function):
662         if use_signals:
663
664             def handler(signum, frame):
665                 _raise_exception(timeout_exception, error_message)
666
667             @functools.wraps(function)
668             def new_function(*args, **kwargs):
669                 new_seconds = kwargs.pop("timeout", seconds)
670                 if new_seconds:
671                     old = signal.signal(signal.SIGALRM, handler)
672                     signal.setitimer(signal.ITIMER_REAL, new_seconds)
673
674                 if not seconds:
675                     return function(*args, **kwargs)
676
677                 try:
678                     return function(*args, **kwargs)
679                 finally:
680                     if new_seconds:
681                         signal.setitimer(signal.ITIMER_REAL, 0)
682                         signal.signal(signal.SIGALRM, old)
683
684             return new_function
685         else:
686
687             @functools.wraps(function)
688             def new_function(*args, **kwargs):
689                 timeout_wrapper = _Timeout(
690                     function, timeout_exception, error_message, seconds
691                 )
692                 return timeout_wrapper(*args, **kwargs)
693
694             return new_function
695
696     return decorate
697
698
699 class non_reentrant_code(object):
700     def __init__(self):
701         self._lock = threading.RLock
702         self._entered = False
703
704     def __call__(self, f):
705         def _gatekeeper(*args, **kwargs):
706             with self._lock:
707                 if self._entered:
708                     return
709                 self._entered = True
710                 f(*args, **kwargs)
711                 self._entered = False
712         return _gatekeeper
713
714
715 class rlocked(object):
716     def __init__(self):
717         self._lock = threading.RLock
718         self._entered = False
719
720     def __call__(self, f):
721         def _gatekeeper(*args, **kwargs):
722             with self._lock:
723                 if self._entered:
724                     return
725                 self._entered = True
726                 f(*args, **kwargs)
727                 self._entered = False
728         return _gatekeeper
729
730
731 def call_with_sample_rate(sample_rate: float) -> Callable:
732     if not 0.0 <= sample_rate <= 1.0:
733         msg = f"sample_rate must be between [0, 1]. Got {sample_rate}."
734         logger.critical(msg)
735         raise ValueError(msg)
736
737     def decorator(f):
738         @functools.wraps(f)
739         def _call_with_sample_rate(*args, **kwargs):
740             if random.uniform(0, 1) < sample_rate:
741                 return f(*args, **kwargs)
742             else:
743                 logger.debug(
744                     f"@call_with_sample_rate skipping a call to {f.__name__}"
745                 )
746         return _call_with_sample_rate
747     return decorator
748
749
750 def decorate_matching_methods_with(decorator, acl=None):
751     """Apply decorator to all methods in a class whose names begin with
752     prefix.  If prefix is None (default), decorate all methods in the
753     class.
754     """
755     def decorate_the_class(cls):
756         for name, m in inspect.getmembers(cls, inspect.isfunction):
757             if acl is None:
758                 setattr(cls, name, decorator(m))
759             else:
760                 if acl(name):
761                     setattr(cls, name, decorator(m))
762         return cls
763     return decorate_the_class
764
765
766 if __name__ == '__main__':
767     import doctest
768     doctest.testmod()
769