Adds some doctests to decorators.
[python_utils.git] / decorator_utils.py
1 #!/usr/bin/env python3
2
3 """Decorators."""
4
5 import enum
6 import functools
7 import inspect
8 import logging
9 import math
10 import multiprocessing
11 import random
12 import signal
13 import sys
14 import threading
15 import time
16 import traceback
17 from typing import Any, Callable, Optional
18 import warnings
19
20 # This module is commonly used by others in here and should avoid
21 # taking any unnecessary dependencies back on them.
22 import exceptions
23
24
25 logger = logging.getLogger(__name__)
26
27
28 def timed(func: Callable) -> Callable:
29     """Print the runtime of the decorated function.
30
31     >>> @timed
32     ... def foo():
33     ...     import time
34     ...     time.sleep(0.1)
35
36     >>> foo()  # doctest: +ELLIPSIS
37     Finished foo in ...
38
39     """
40
41     @functools.wraps(func)
42     def wrapper_timer(*args, **kwargs):
43         start_time = time.perf_counter()
44         value = func(*args, **kwargs)
45         end_time = time.perf_counter()
46         run_time = end_time - start_time
47         msg = f"Finished {func.__qualname__} in {run_time:.4f}s"
48         print(msg)
49         logger.info(msg)
50         return value
51     return wrapper_timer
52
53
54 def invocation_logged(func: Callable) -> Callable:
55     """Log the call of a function.
56
57     >>> @invocation_logged
58     ... def foo():
59     ...     print('Hello, world.')
60
61     >>> foo()
62     Entered foo
63     Hello, world.
64     Exited foo
65
66     """
67
68     @functools.wraps(func)
69     def wrapper_invocation_logged(*args, **kwargs):
70         msg = f"Entered {func.__qualname__}"
71         print(msg)
72         logger.info(msg)
73         ret = func(*args, **kwargs)
74         msg = f"Exited {func.__qualname__}"
75         print(msg)
76         logger.info(msg)
77         return ret
78     return wrapper_invocation_logged
79
80
81 def rate_limited(n_calls: int, *, per_period_in_seconds: float = 1.0) -> Callable:
82     """Limit invocation of a wrapped function to n calls per period.
83     Thread safe.  In testing this was relatively fair with multiple
84     threads using it though that hasn't been measured.
85
86     >>> import time
87     >>> import decorator_utils
88     >>> import thread_utils
89
90     >>> calls = 0
91
92     >>> @decorator_utils.rate_limited(10, per_period_in_seconds=1.0)
93     ... def limited(x: int):
94     ...     global calls
95     ...     calls += 1
96
97     >>> @thread_utils.background_thread
98     ... def a(stop):
99     ...     for _ in range(3):
100     ...         limited(_)
101
102     >>> @thread_utils.background_thread
103     ... def b(stop):
104     ...     for _ in range(3):
105     ...         limited(_)
106
107     >>> start = time.time()
108     >>> (t1, e1) = a()
109     >>> (t2, e2) = b()
110     >>> t1.join()
111     >>> t2.join()
112     >>> end = time.time()
113     >>> dur = end - start
114     >>> dur > 0.5
115     True
116
117     >>> calls
118     6
119
120     """
121     min_interval_seconds = per_period_in_seconds / float(n_calls)
122
123     def wrapper_rate_limited(func: Callable) -> Callable:
124         cv = threading.Condition()
125         last_invocation_timestamp = [0.0]
126
127         def may_proceed() -> float:
128             now = time.time()
129             last_invocation = last_invocation_timestamp[0]
130             if last_invocation != 0.0:
131                 elapsed_since_last = now - last_invocation
132                 wait_time = min_interval_seconds - elapsed_since_last
133             else:
134                 wait_time = 0.0
135             logger.debug(f'@{time.time()}> wait_time = {wait_time}')
136             return wait_time
137
138         def wrapper_wrapper_rate_limited(*args, **kargs) -> Any:
139             with cv:
140                 while True:
141                     if cv.wait_for(
142                         lambda: may_proceed() <= 0.0,
143                         timeout=may_proceed(),
144                     ):
145                         break
146             with cv:
147                 logger.debug(f'@{time.time()}> calling it...')
148                 ret = func(*args, **kargs)
149                 last_invocation_timestamp[0] = time.time()
150                 logger.debug(
151                     f'@{time.time()}> Last invocation <- {last_invocation_timestamp[0]}'
152                 )
153                 cv.notify()
154             return ret
155         return wrapper_wrapper_rate_limited
156     return wrapper_rate_limited
157
158
159 def debug_args(func: Callable) -> Callable:
160     """Print the function signature and return value at each call.
161
162     >>> @debug_args
163     ... def foo(a, b, c):
164     ...     print(a)
165     ...     print(b)
166     ...     print(c)
167     ...     return (a + b, c)
168
169     >>> foo(1, 2.0, "test")
170     Calling foo(1:<class 'int'>, 2.0:<class 'float'>, 'test':<class 'str'>)
171     1
172     2.0
173     test
174     foo returned (3.0, 'test'):<class 'tuple'>
175     (3.0, 'test')
176     """
177
178     @functools.wraps(func)
179     def wrapper_debug_args(*args, **kwargs):
180         args_repr = [f"{repr(a)}:{type(a)}" for a in args]
181         kwargs_repr = [f"{k}={v!r}:{type(v)}" for k, v in kwargs.items()]
182         signature = ", ".join(args_repr + kwargs_repr)
183         msg = f"Calling {func.__qualname__}({signature})"
184         print(msg)
185         logger.info(msg)
186         value = func(*args, **kwargs)
187         msg = f"{func.__qualname__} returned {value!r}:{type(value)}"
188         print(msg)
189         logger.info(msg)
190         return value
191     return wrapper_debug_args
192
193
194 def debug_count_calls(func: Callable) -> Callable:
195     """Count function invocations and print a message befor every call.
196
197     >>> @debug_count_calls
198     ... def factoral(x):
199     ...     if x == 1:
200     ...         return 1
201     ...     return x * factoral(x - 1)
202
203     >>> factoral(5)
204     Call #1 of 'factoral'
205     Call #2 of 'factoral'
206     Call #3 of 'factoral'
207     Call #4 of 'factoral'
208     Call #5 of 'factoral'
209     120
210
211     """
212
213     @functools.wraps(func)
214     def wrapper_debug_count_calls(*args, **kwargs):
215         wrapper_debug_count_calls.num_calls += 1
216         msg = f"Call #{wrapper_debug_count_calls.num_calls} of {func.__name__!r}"
217         print(msg)
218         logger.info(msg)
219         return func(*args, **kwargs)
220     wrapper_debug_count_calls.num_calls = 0
221     return wrapper_debug_count_calls
222
223
224 class DelayWhen(enum.IntEnum):
225     BEFORE_CALL = 1
226     AFTER_CALL = 2
227     BEFORE_AND_AFTER = 3
228
229
230 def delay(
231     _func: Callable = None,
232     *,
233     seconds: float = 1.0,
234     when: DelayWhen = DelayWhen.BEFORE_CALL,
235 ) -> Callable:
236     """Delay the execution of a function by sleeping before and/or after.
237
238     Slow down a function by inserting a delay before and/or after its
239     invocation.
240
241     >>> import time
242
243     >>> @delay(seconds=1.0)
244     ... def foo():
245     ...     pass
246
247     >>> start = time.time()
248     >>> foo()
249     >>> dur = time.time() - start
250     >>> dur >= 1.0
251     True
252
253     """
254     def decorator_delay(func: Callable) -> Callable:
255         @functools.wraps(func)
256         def wrapper_delay(*args, **kwargs):
257             if when & DelayWhen.BEFORE_CALL:
258                 logger.debug(
259                     f"@delay for {seconds}s BEFORE_CALL to {func.__name__}"
260                 )
261                 time.sleep(seconds)
262             retval = func(*args, **kwargs)
263             if when & DelayWhen.AFTER_CALL:
264                 logger.debug(
265                     f"@delay for {seconds}s AFTER_CALL to {func.__name__}"
266                 )
267                 time.sleep(seconds)
268             return retval
269         return wrapper_delay
270
271     if _func is None:
272         return decorator_delay
273     else:
274         return decorator_delay(_func)
275
276
277 class _SingletonWrapper:
278     """
279     A singleton wrapper class. Its instances would be created
280     for each decorated class.
281
282     """
283
284     def __init__(self, cls):
285         self.__wrapped__ = cls
286         self._instance = None
287
288     def __call__(self, *args, **kwargs):
289         """Returns a single instance of decorated class"""
290         logger.debug(
291             f"@singleton returning global instance of {self.__wrapped__.__name__}"
292         )
293         if self._instance is None:
294             self._instance = self.__wrapped__(*args, **kwargs)
295         return self._instance
296
297
298 def singleton(cls):
299     """
300     A singleton decorator. Returns a wrapper objects. A call on that object
301     returns a single instance object of decorated class. Use the __wrapped__
302     attribute to access decorated class directly in unit tests
303
304     >>> @singleton
305     ... class foo(object):
306     ...     pass
307
308     >>> a = foo()
309     >>> b = foo()
310     >>> a is b
311     True
312
313     >>> id(a) == id(b)
314     True
315
316     """
317     return _SingletonWrapper(cls)
318
319
320 def memoized(func: Callable) -> Callable:
321     """Keep a cache of previous function call results.
322
323     The cache here is a dict with a key based on the arguments to the
324     call.  Consider also: functools.lru_cache for a more advanced
325     implementation.
326
327     >>> import time
328
329     >>> @memoized
330     ... def expensive(arg) -> int:
331     ...     # Simulate something slow to compute or lookup
332     ...     time.sleep(1.0)
333     ...     return arg * arg
334
335     >>> start = time.time()
336     >>> expensive(5)           # Takes about 1 sec
337     25
338
339     >>> expensive(3)           # Also takes about 1 sec
340     9
341
342     >>> expensive(5)           # Pulls from cache, fast
343     25
344
345     >>> expensive(3)           # Pulls from cache again, fast
346     9
347
348     >>> dur = time.time() - start
349     >>> dur < 3.0
350     True
351
352     """
353     @functools.wraps(func)
354     def wrapper_memoized(*args, **kwargs):
355         cache_key = args + tuple(kwargs.items())
356         if cache_key not in wrapper_memoized.cache:
357             value = func(*args, **kwargs)
358             logger.debug(
359                 f"Memoizing {cache_key} => {value} for {func.__name__}"
360             )
361             wrapper_memoized.cache[cache_key] = value
362         else:
363             logger.debug(f"Returning memoized value for {func.__name__}")
364         return wrapper_memoized.cache[cache_key]
365     wrapper_memoized.cache = dict()
366     return wrapper_memoized
367
368
369 def retry_predicate(
370     tries: int,
371     *,
372     predicate: Callable[..., bool],
373     delay_sec: float = 3.0,
374     backoff: float = 2.0,
375 ):
376     """Retries a function or method up to a certain number of times
377     with a prescribed initial delay period and backoff rate.
378
379     tries is the maximum number of attempts to run the function.
380     delay_sec sets the initial delay period in seconds.
381     backoff is a multiplied (must be >1) used to modify the delay.
382     predicate is a function that will be passed the retval of the
383     decorated function and must return True to stop or False to
384     retry.
385
386     """
387     if backoff < 1.0:
388         msg = f"backoff must be greater than or equal to 1, got {backoff}"
389         logger.critical(msg)
390         raise ValueError(msg)
391
392     tries = math.floor(tries)
393     if tries < 0:
394         msg = f"tries must be 0 or greater, got {tries}"
395         logger.critical(msg)
396         raise ValueError(msg)
397
398     if delay_sec <= 0:
399         msg = f"delay_sec must be greater than 0, got {delay_sec}"
400         logger.critical(msg)
401         raise ValueError(msg)
402
403     def deco_retry(f):
404         @functools.wraps(f)
405         def f_retry(*args, **kwargs):
406             mtries, mdelay = tries, delay_sec  # make mutable
407             logger.debug(f'deco_retry: will make up to {mtries} attempts...')
408             retval = f(*args, **kwargs)
409             while mtries > 0:
410                 if predicate(retval) is True:
411                     logger.debug('Predicate succeeded, deco_retry is done.')
412                     return retval
413                 logger.debug("Predicate failed, sleeping and retrying.")
414                 mtries -= 1
415                 time.sleep(mdelay)
416                 mdelay *= backoff
417                 retval = f(*args, **kwargs)
418             return retval
419         return f_retry
420     return deco_retry
421
422
423 def retry_if_false(tries: int, *, delay_sec=3.0, backoff=2.0):
424     """A helper for @retry_predicate that retries a decorated
425     function as long as it keeps returning False.
426
427     >>> import time
428
429     >>> counter = 0
430
431     >>> @retry_if_false(5, delay_sec=1.0, backoff=1.1)
432     ... def foo():
433     ...     global counter
434     ...     counter += 1
435     ...     return counter >= 3
436
437     >>> start = time.time()
438     >>> foo()  # fail, delay 1.0, fail, delay 1.1, succeed
439     True
440
441     >>> dur = time.time() - start
442     >>> counter
443     3
444     >>> dur > 2.0
445     True
446     >>> dur < 2.2
447     True
448
449     """
450     return retry_predicate(
451         tries,
452         predicate=lambda x: x is True,
453         delay_sec=delay_sec,
454         backoff=backoff,
455     )
456
457
458 def retry_if_none(tries: int, *, delay_sec=3.0, backoff=2.0):
459     """Another helper for @retry_predicate above.  Retries up to N
460     times so long as the wrapped function returns None with a delay
461     between each retry and a backoff that can increase the delay.
462
463     """
464     return retry_predicate(
465         tries,
466         predicate=lambda x: x is not None,
467         delay_sec=delay_sec,
468         backoff=backoff,
469     )
470
471
472 def deprecated(func):
473     """This is a decorator which can be used to mark functions
474     as deprecated. It will result in a warning being emitted
475     when the function is used.
476
477     """
478     @functools.wraps(func)
479     def wrapper_deprecated(*args, **kwargs):
480         msg = f"Call to deprecated function {func.__qualname__}"
481         logger.warning(msg)
482         warnings.warn(msg, category=DeprecationWarning)
483         print(msg, file=sys.stderr)
484         return func(*args, **kwargs)
485     return wrapper_deprecated
486
487
488 def thunkify(func):
489     """
490     Make a function immediately return a function of no args which,
491     when called, waits for the result, which will start being
492     processed in another thread.
493     """
494
495     @functools.wraps(func)
496     def lazy_thunked(*args, **kwargs):
497         wait_event = threading.Event()
498
499         result = [None]
500         exc = [False, None]
501
502         def worker_func():
503             try:
504                 func_result = func(*args, **kwargs)
505                 result[0] = func_result
506             except Exception:
507                 exc[0] = True
508                 exc[1] = sys.exc_info()  # (type, value, traceback)
509                 msg = f"Thunkify has thrown an exception (will be raised on thunk()):\n{traceback.format_exc()}"
510                 print(msg)
511                 logger.warning(msg)
512             finally:
513                 wait_event.set()
514
515         def thunk():
516             wait_event.wait()
517             if exc[0]:
518                 raise exc[1][0](exc[1][1])
519             return result[0]
520
521         threading.Thread(target=worker_func).start()
522         return thunk
523
524     return lazy_thunked
525
526
527 ############################################################
528 # Timeout
529 ############################################################
530
531 # http://www.saltycrane.com/blog/2010/04/using-python-timeout-decorator-uploading-s3/
532 # Used work of Stephen "Zero" Chappell <[email protected]>
533 # in https://code.google.com/p/verse-quiz/source/browse/trunk/timeout.py
534
535
536 def _raise_exception(exception, error_message: Optional[str]):
537     if error_message is None:
538         raise exception()
539     else:
540         raise exception(error_message)
541
542
543 def _target(queue, function, *args, **kwargs):
544     """Run a function with arguments and return output via a queue.
545
546     This is a helper function for the Process created in _Timeout. It runs
547     the function with positional arguments and keyword arguments and then
548     returns the function's output by way of a queue. If an exception gets
549     raised, it is returned to _Timeout to be raised by the value property.
550     """
551     try:
552         queue.put((True, function(*args, **kwargs)))
553     except Exception:
554         queue.put((False, sys.exc_info()[1]))
555
556
557 class _Timeout(object):
558     """Wrap a function and add a timeout (limit) attribute to it.
559
560     Instances of this class are automatically generated by the add_timeout
561     function defined below.
562     """
563
564     def __init__(
565         self,
566         function: Callable,
567         timeout_exception: Exception,
568         error_message: str,
569         seconds: float,
570     ):
571         self.__limit = seconds
572         self.__function = function
573         self.__timeout_exception = timeout_exception
574         self.__error_message = error_message
575         self.__name__ = function.__name__
576         self.__doc__ = function.__doc__
577         self.__timeout = time.time()
578         self.__process = multiprocessing.Process()
579         self.__queue: multiprocessing.queues.Queue = multiprocessing.Queue()
580
581     def __call__(self, *args, **kwargs):
582         """Execute the embedded function object asynchronously.
583
584         The function given to the constructor is transparently called and
585         requires that "ready" be intermittently polled. If and when it is
586         True, the "value" property may then be checked for returned data.
587         """
588         self.__limit = kwargs.pop("timeout", self.__limit)
589         self.__queue = multiprocessing.Queue(1)
590         args = (self.__queue, self.__function) + args
591         self.__process = multiprocessing.Process(
592             target=_target, args=args, kwargs=kwargs
593         )
594         self.__process.daemon = True
595         self.__process.start()
596         if self.__limit is not None:
597             self.__timeout = self.__limit + time.time()
598         while not self.ready:
599             time.sleep(0.1)
600         return self.value
601
602     def cancel(self):
603         """Terminate any possible execution of the embedded function."""
604         if self.__process.is_alive():
605             self.__process.terminate()
606         _raise_exception(self.__timeout_exception, self.__error_message)
607
608     @property
609     def ready(self):
610         """Read-only property indicating status of "value" property."""
611         if self.__limit and self.__timeout < time.time():
612             self.cancel()
613         return self.__queue.full() and not self.__queue.empty()
614
615     @property
616     def value(self):
617         """Read-only property containing data returned from function."""
618         if self.ready is True:
619             flag, load = self.__queue.get()
620             if flag:
621                 return load
622             raise load
623
624
625 def timeout(
626     seconds: float = 1.0,
627     use_signals: Optional[bool] = None,
628     timeout_exception=exceptions.TimeoutError,
629     error_message="Function call timed out",
630 ):
631     """Add a timeout parameter to a function and return the function.
632
633     Note: the use_signals parameter is included in order to support
634     multiprocessing scenarios (signal can only be used from the process'
635     main thread).  When not using signals, timeout granularity will be
636     rounded to the nearest 0.1s.
637
638     Raises an exception when the timeout is reached.
639
640     It is illegal to pass anything other than a function as the first
641     parameter.  The function is wrapped and returned to the caller.
642     """
643     if use_signals is None:
644         import thread_utils
645         use_signals = thread_utils.is_current_thread_main_thread()
646
647     def decorate(function):
648         if use_signals:
649
650             def handler(signum, frame):
651                 _raise_exception(timeout_exception, error_message)
652
653             @functools.wraps(function)
654             def new_function(*args, **kwargs):
655                 new_seconds = kwargs.pop("timeout", seconds)
656                 if new_seconds:
657                     old = signal.signal(signal.SIGALRM, handler)
658                     signal.setitimer(signal.ITIMER_REAL, new_seconds)
659
660                 if not seconds:
661                     return function(*args, **kwargs)
662
663                 try:
664                     return function(*args, **kwargs)
665                 finally:
666                     if new_seconds:
667                         signal.setitimer(signal.ITIMER_REAL, 0)
668                         signal.signal(signal.SIGALRM, old)
669
670             return new_function
671         else:
672
673             @functools.wraps(function)
674             def new_function(*args, **kwargs):
675                 timeout_wrapper = _Timeout(
676                     function, timeout_exception, error_message, seconds
677                 )
678                 return timeout_wrapper(*args, **kwargs)
679
680             return new_function
681
682     return decorate
683
684
685 class non_reentrant_code(object):
686     def __init__(self):
687         self._lock = threading.RLock
688         self._entered = False
689
690     def __call__(self, f):
691         def _gatekeeper(*args, **kwargs):
692             with self._lock:
693                 if self._entered:
694                     return
695                 self._entered = True
696                 f(*args, **kwargs)
697                 self._entered = False
698
699         return _gatekeeper
700
701
702 class rlocked(object):
703     def __init__(self):
704         self._lock = threading.RLock
705         self._entered = False
706
707     def __call__(self, f):
708         def _gatekeeper(*args, **kwargs):
709             with self._lock:
710                 if self._entered:
711                     return
712                 self._entered = True
713                 f(*args, **kwargs)
714                 self._entered = False
715         return _gatekeeper
716
717
718 def call_with_sample_rate(sample_rate: float) -> Callable:
719     if not 0.0 <= sample_rate <= 1.0:
720         msg = f"sample_rate must be between [0, 1]. Got {sample_rate}."
721         logger.critical(msg)
722         raise ValueError(msg)
723
724     def decorator(f):
725         @functools.wraps(f)
726         def _call_with_sample_rate(*args, **kwargs):
727             if random.uniform(0, 1) < sample_rate:
728                 return f(*args, **kwargs)
729             else:
730                 logger.debug(
731                     f"@call_with_sample_rate skipping a call to {f.__name__}"
732                 )
733         return _call_with_sample_rate
734     return decorator
735
736
737 def decorate_matching_methods_with(decorator, acl=None):
738     """Apply decorator to all methods in a class whose names begin with
739     prefix.  If prefix is None (default), decorate all methods in the
740     class.
741     """
742     def decorate_the_class(cls):
743         for name, m in inspect.getmembers(cls, inspect.isfunction):
744             if acl is None:
745                 setattr(cls, name, decorator(m))
746             else:
747                 if acl(name):
748                     setattr(cls, name, decorator(m))
749         return cls
750     return decorate_the_class
751
752
753 if __name__ == '__main__':
754     import doctest
755     doctest.testmod()
756