Ran black code formatter on everything.
[python_utils.git] / decorator_utils.py
1 #!/usr/bin/env python3
2
3 """Decorators."""
4
5 import enum
6 import functools
7 import inspect
8 import logging
9 import math
10 import multiprocessing
11 import random
12 import signal
13 import sys
14 import threading
15 import time
16 import traceback
17 from typing import Any, Callable, Optional
18 import warnings
19
20 # This module is commonly used by others in here and should avoid
21 # taking any unnecessary dependencies back on them.
22 import exceptions
23
24
25 logger = logging.getLogger(__name__)
26
27
28 def timed(func: Callable) -> Callable:
29     """Print the runtime of the decorated function.
30
31     >>> @timed
32     ... def foo():
33     ...     import time
34     ...     time.sleep(0.1)
35
36     >>> foo()  # doctest: +ELLIPSIS
37     Finished foo in ...
38
39     """
40
41     @functools.wraps(func)
42     def wrapper_timer(*args, **kwargs):
43         start_time = time.perf_counter()
44         value = func(*args, **kwargs)
45         end_time = time.perf_counter()
46         run_time = end_time - start_time
47         msg = f"Finished {func.__qualname__} in {run_time:.4f}s"
48         print(msg)
49         logger.info(msg)
50         return value
51
52     return wrapper_timer
53
54
55 def invocation_logged(func: Callable) -> Callable:
56     """Log the call of a function.
57
58     >>> @invocation_logged
59     ... def foo():
60     ...     print('Hello, world.')
61
62     >>> foo()
63     Entered foo
64     Hello, world.
65     Exited foo
66
67     """
68
69     @functools.wraps(func)
70     def wrapper_invocation_logged(*args, **kwargs):
71         msg = f"Entered {func.__qualname__}"
72         print(msg)
73         logger.info(msg)
74         ret = func(*args, **kwargs)
75         msg = f"Exited {func.__qualname__}"
76         print(msg)
77         logger.info(msg)
78         return ret
79
80     return wrapper_invocation_logged
81
82
83 def rate_limited(
84     n_calls: int, *, per_period_in_seconds: float = 1.0
85 ) -> Callable:
86     """Limit invocation of a wrapped function to n calls per period.
87     Thread safe.  In testing this was relatively fair with multiple
88     threads using it though that hasn't been measured.
89
90     >>> import time
91     >>> import decorator_utils
92     >>> import thread_utils
93
94     >>> calls = 0
95
96     >>> @decorator_utils.rate_limited(10, per_period_in_seconds=1.0)
97     ... def limited(x: int):
98     ...     global calls
99     ...     calls += 1
100
101     >>> @thread_utils.background_thread
102     ... def a(stop):
103     ...     for _ in range(3):
104     ...         limited(_)
105
106     >>> @thread_utils.background_thread
107     ... def b(stop):
108     ...     for _ in range(3):
109     ...         limited(_)
110
111     >>> start = time.time()
112     >>> (t1, e1) = a()
113     >>> (t2, e2) = b()
114     >>> t1.join()
115     >>> t2.join()
116     >>> end = time.time()
117     >>> dur = end - start
118     >>> dur > 0.5
119     True
120
121     >>> calls
122     6
123
124     """
125     min_interval_seconds = per_period_in_seconds / float(n_calls)
126
127     def wrapper_rate_limited(func: Callable) -> Callable:
128         cv = threading.Condition()
129         last_invocation_timestamp = [0.0]
130
131         def may_proceed() -> float:
132             now = time.time()
133             last_invocation = last_invocation_timestamp[0]
134             if last_invocation != 0.0:
135                 elapsed_since_last = now - last_invocation
136                 wait_time = min_interval_seconds - elapsed_since_last
137             else:
138                 wait_time = 0.0
139             logger.debug(f'@{time.time()}> wait_time = {wait_time}')
140             return wait_time
141
142         def wrapper_wrapper_rate_limited(*args, **kargs) -> Any:
143             with cv:
144                 while True:
145                     if cv.wait_for(
146                         lambda: may_proceed() <= 0.0,
147                         timeout=may_proceed(),
148                     ):
149                         break
150             with cv:
151                 logger.debug(f'@{time.time()}> calling it...')
152                 ret = func(*args, **kargs)
153                 last_invocation_timestamp[0] = time.time()
154                 logger.debug(
155                     f'@{time.time()}> Last invocation <- {last_invocation_timestamp[0]}'
156                 )
157                 cv.notify()
158             return ret
159
160         return wrapper_wrapper_rate_limited
161
162     return wrapper_rate_limited
163
164
165 def debug_args(func: Callable) -> Callable:
166     """Print the function signature and return value at each call.
167
168     >>> @debug_args
169     ... def foo(a, b, c):
170     ...     print(a)
171     ...     print(b)
172     ...     print(c)
173     ...     return (a + b, c)
174
175     >>> foo(1, 2.0, "test")
176     Calling foo(1:<class 'int'>, 2.0:<class 'float'>, 'test':<class 'str'>)
177     1
178     2.0
179     test
180     foo returned (3.0, 'test'):<class 'tuple'>
181     (3.0, 'test')
182     """
183
184     @functools.wraps(func)
185     def wrapper_debug_args(*args, **kwargs):
186         args_repr = [f"{repr(a)}:{type(a)}" for a in args]
187         kwargs_repr = [f"{k}={v!r}:{type(v)}" for k, v in kwargs.items()]
188         signature = ", ".join(args_repr + kwargs_repr)
189         msg = f"Calling {func.__qualname__}({signature})"
190         print(msg)
191         logger.info(msg)
192         value = func(*args, **kwargs)
193         msg = f"{func.__qualname__} returned {value!r}:{type(value)}"
194         print(msg)
195         logger.info(msg)
196         return value
197
198     return wrapper_debug_args
199
200
201 def debug_count_calls(func: Callable) -> Callable:
202     """Count function invocations and print a message befor every call.
203
204     >>> @debug_count_calls
205     ... def factoral(x):
206     ...     if x == 1:
207     ...         return 1
208     ...     return x * factoral(x - 1)
209
210     >>> factoral(5)
211     Call #1 of 'factoral'
212     Call #2 of 'factoral'
213     Call #3 of 'factoral'
214     Call #4 of 'factoral'
215     Call #5 of 'factoral'
216     120
217
218     """
219
220     @functools.wraps(func)
221     def wrapper_debug_count_calls(*args, **kwargs):
222         wrapper_debug_count_calls.num_calls += 1
223         msg = (
224             f"Call #{wrapper_debug_count_calls.num_calls} of {func.__name__!r}"
225         )
226         print(msg)
227         logger.info(msg)
228         return func(*args, **kwargs)
229
230     wrapper_debug_count_calls.num_calls = 0
231     return wrapper_debug_count_calls
232
233
234 class DelayWhen(enum.IntEnum):
235     BEFORE_CALL = 1
236     AFTER_CALL = 2
237     BEFORE_AND_AFTER = 3
238
239
240 def delay(
241     _func: Callable = None,
242     *,
243     seconds: float = 1.0,
244     when: DelayWhen = DelayWhen.BEFORE_CALL,
245 ) -> Callable:
246     """Delay the execution of a function by sleeping before and/or after.
247
248     Slow down a function by inserting a delay before and/or after its
249     invocation.
250
251     >>> import time
252
253     >>> @delay(seconds=1.0)
254     ... def foo():
255     ...     pass
256
257     >>> start = time.time()
258     >>> foo()
259     >>> dur = time.time() - start
260     >>> dur >= 1.0
261     True
262
263     """
264
265     def decorator_delay(func: Callable) -> Callable:
266         @functools.wraps(func)
267         def wrapper_delay(*args, **kwargs):
268             if when & DelayWhen.BEFORE_CALL:
269                 logger.debug(
270                     f"@delay for {seconds}s BEFORE_CALL to {func.__name__}"
271                 )
272                 time.sleep(seconds)
273             retval = func(*args, **kwargs)
274             if when & DelayWhen.AFTER_CALL:
275                 logger.debug(
276                     f"@delay for {seconds}s AFTER_CALL to {func.__name__}"
277                 )
278                 time.sleep(seconds)
279             return retval
280
281         return wrapper_delay
282
283     if _func is None:
284         return decorator_delay
285     else:
286         return decorator_delay(_func)
287
288
289 class _SingletonWrapper:
290     """
291     A singleton wrapper class. Its instances would be created
292     for each decorated class.
293
294     """
295
296     def __init__(self, cls):
297         self.__wrapped__ = cls
298         self._instance = None
299
300     def __call__(self, *args, **kwargs):
301         """Returns a single instance of decorated class"""
302         logger.debug(
303             f"@singleton returning global instance of {self.__wrapped__.__name__}"
304         )
305         if self._instance is None:
306             self._instance = self.__wrapped__(*args, **kwargs)
307         return self._instance
308
309
310 def singleton(cls):
311     """
312     A singleton decorator. Returns a wrapper objects. A call on that object
313     returns a single instance object of decorated class. Use the __wrapped__
314     attribute to access decorated class directly in unit tests
315
316     >>> @singleton
317     ... class foo(object):
318     ...     pass
319
320     >>> a = foo()
321     >>> b = foo()
322     >>> a is b
323     True
324
325     >>> id(a) == id(b)
326     True
327
328     """
329     return _SingletonWrapper(cls)
330
331
332 def memoized(func: Callable) -> Callable:
333     """Keep a cache of previous function call results.
334
335     The cache here is a dict with a key based on the arguments to the
336     call.  Consider also: functools.lru_cache for a more advanced
337     implementation.
338
339     >>> import time
340
341     >>> @memoized
342     ... def expensive(arg) -> int:
343     ...     # Simulate something slow to compute or lookup
344     ...     time.sleep(1.0)
345     ...     return arg * arg
346
347     >>> start = time.time()
348     >>> expensive(5)           # Takes about 1 sec
349     25
350
351     >>> expensive(3)           # Also takes about 1 sec
352     9
353
354     >>> expensive(5)           # Pulls from cache, fast
355     25
356
357     >>> expensive(3)           # Pulls from cache again, fast
358     9
359
360     >>> dur = time.time() - start
361     >>> dur < 3.0
362     True
363
364     """
365
366     @functools.wraps(func)
367     def wrapper_memoized(*args, **kwargs):
368         cache_key = args + tuple(kwargs.items())
369         if cache_key not in wrapper_memoized.cache:
370             value = func(*args, **kwargs)
371             logger.debug(
372                 f"Memoizing {cache_key} => {value} for {func.__name__}"
373             )
374             wrapper_memoized.cache[cache_key] = value
375         else:
376             logger.debug(f"Returning memoized value for {func.__name__}")
377         return wrapper_memoized.cache[cache_key]
378
379     wrapper_memoized.cache = dict()
380     return wrapper_memoized
381
382
383 def retry_predicate(
384     tries: int,
385     *,
386     predicate: Callable[..., bool],
387     delay_sec: float = 3.0,
388     backoff: float = 2.0,
389 ):
390     """Retries a function or method up to a certain number of times
391     with a prescribed initial delay period and backoff rate.
392
393     tries is the maximum number of attempts to run the function.
394     delay_sec sets the initial delay period in seconds.
395     backoff is a multiplied (must be >1) used to modify the delay.
396     predicate is a function that will be passed the retval of the
397     decorated function and must return True to stop or False to
398     retry.
399
400     """
401     if backoff < 1.0:
402         msg = f"backoff must be greater than or equal to 1, got {backoff}"
403         logger.critical(msg)
404         raise ValueError(msg)
405
406     tries = math.floor(tries)
407     if tries < 0:
408         msg = f"tries must be 0 or greater, got {tries}"
409         logger.critical(msg)
410         raise ValueError(msg)
411
412     if delay_sec <= 0:
413         msg = f"delay_sec must be greater than 0, got {delay_sec}"
414         logger.critical(msg)
415         raise ValueError(msg)
416
417     def deco_retry(f):
418         @functools.wraps(f)
419         def f_retry(*args, **kwargs):
420             mtries, mdelay = tries, delay_sec  # make mutable
421             logger.debug(f'deco_retry: will make up to {mtries} attempts...')
422             retval = f(*args, **kwargs)
423             while mtries > 0:
424                 if predicate(retval) is True:
425                     logger.debug('Predicate succeeded, deco_retry is done.')
426                     return retval
427                 logger.debug("Predicate failed, sleeping and retrying.")
428                 mtries -= 1
429                 time.sleep(mdelay)
430                 mdelay *= backoff
431                 retval = f(*args, **kwargs)
432             return retval
433
434         return f_retry
435
436     return deco_retry
437
438
439 def retry_if_false(tries: int, *, delay_sec=3.0, backoff=2.0):
440     """A helper for @retry_predicate that retries a decorated
441     function as long as it keeps returning False.
442
443     >>> import time
444
445     >>> counter = 0
446
447     >>> @retry_if_false(5, delay_sec=1.0, backoff=1.1)
448     ... def foo():
449     ...     global counter
450     ...     counter += 1
451     ...     return counter >= 3
452
453     >>> start = time.time()
454     >>> foo()  # fail, delay 1.0, fail, delay 1.1, succeed
455     True
456
457     >>> dur = time.time() - start
458     >>> counter
459     3
460     >>> dur > 2.0
461     True
462     >>> dur < 2.3
463     True
464
465     """
466     return retry_predicate(
467         tries,
468         predicate=lambda x: x is True,
469         delay_sec=delay_sec,
470         backoff=backoff,
471     )
472
473
474 def retry_if_none(tries: int, *, delay_sec=3.0, backoff=2.0):
475     """Another helper for @retry_predicate above.  Retries up to N
476     times so long as the wrapped function returns None with a delay
477     between each retry and a backoff that can increase the delay.
478
479     """
480     return retry_predicate(
481         tries,
482         predicate=lambda x: x is not None,
483         delay_sec=delay_sec,
484         backoff=backoff,
485     )
486
487
488 def deprecated(func):
489     """This is a decorator which can be used to mark functions
490     as deprecated. It will result in a warning being emitted
491     when the function is used.
492
493     """
494
495     @functools.wraps(func)
496     def wrapper_deprecated(*args, **kwargs):
497         msg = f"Call to deprecated function {func.__qualname__}"
498         logger.warning(msg)
499         warnings.warn(msg, category=DeprecationWarning, stacklevel=2)
500         print(msg, file=sys.stderr)
501         return func(*args, **kwargs)
502
503     return wrapper_deprecated
504
505
506 def thunkify(func):
507     """
508     Make a function immediately return a function of no args which,
509     when called, waits for the result, which will start being
510     processed in another thread.
511     """
512
513     @functools.wraps(func)
514     def lazy_thunked(*args, **kwargs):
515         wait_event = threading.Event()
516
517         result = [None]
518         exc = [False, None]
519
520         def worker_func():
521             try:
522                 func_result = func(*args, **kwargs)
523                 result[0] = func_result
524             except Exception:
525                 exc[0] = True
526                 exc[1] = sys.exc_info()  # (type, value, traceback)
527                 msg = f"Thunkify has thrown an exception (will be raised on thunk()):\n{traceback.format_exc()}"
528                 logger.warning(msg)
529             finally:
530                 wait_event.set()
531
532         def thunk():
533             wait_event.wait()
534             if exc[0]:
535                 raise exc[1][0](exc[1][1])
536             return result[0]
537
538         threading.Thread(target=worker_func).start()
539         return thunk
540
541     return lazy_thunked
542
543
544 ############################################################
545 # Timeout
546 ############################################################
547
548 # http://www.saltycrane.com/blog/2010/04/using-python-timeout-decorator-uploading-s3/
549 # Used work of Stephen "Zero" Chappell <[email protected]>
550 # in https://code.google.com/p/verse-quiz/source/browse/trunk/timeout.py
551
552
553 def _raise_exception(exception, error_message: Optional[str]):
554     if error_message is None:
555         raise Exception()
556     else:
557         raise Exception(error_message)
558
559
560 def _target(queue, function, *args, **kwargs):
561     """Run a function with arguments and return output via a queue.
562
563     This is a helper function for the Process created in _Timeout. It runs
564     the function with positional arguments and keyword arguments and then
565     returns the function's output by way of a queue. If an exception gets
566     raised, it is returned to _Timeout to be raised by the value property.
567     """
568     try:
569         queue.put((True, function(*args, **kwargs)))
570     except Exception:
571         queue.put((False, sys.exc_info()[1]))
572
573
574 class _Timeout(object):
575     """Wrap a function and add a timeout to it.
576
577     Instances of this class are automatically generated by the add_timeout
578     function defined below.  Do not use directly.
579     """
580
581     def __init__(
582         self,
583         function: Callable,
584         timeout_exception: Exception,
585         error_message: str,
586         seconds: float,
587     ):
588         self.__limit = seconds
589         self.__function = function
590         self.__timeout_exception = timeout_exception
591         self.__error_message = error_message
592         self.__name__ = function.__name__
593         self.__doc__ = function.__doc__
594         self.__timeout = time.time()
595         self.__process = multiprocessing.Process()
596         self.__queue: multiprocessing.queues.Queue = multiprocessing.Queue()
597
598     def __call__(self, *args, **kwargs):
599         """Execute the embedded function object asynchronously.
600
601         The function given to the constructor is transparently called and
602         requires that "ready" be intermittently polled. If and when it is
603         True, the "value" property may then be checked for returned data.
604         """
605         self.__limit = kwargs.pop("timeout", self.__limit)
606         self.__queue = multiprocessing.Queue(1)
607         args = (self.__queue, self.__function) + args
608         self.__process = multiprocessing.Process(
609             target=_target, args=args, kwargs=kwargs
610         )
611         self.__process.daemon = True
612         self.__process.start()
613         if self.__limit is not None:
614             self.__timeout = self.__limit + time.time()
615         while not self.ready:
616             time.sleep(0.1)
617         return self.value
618
619     def cancel(self):
620         """Terminate any possible execution of the embedded function."""
621         if self.__process.is_alive():
622             self.__process.terminate()
623         _raise_exception(self.__timeout_exception, self.__error_message)
624
625     @property
626     def ready(self):
627         """Read-only property indicating status of "value" property."""
628         if self.__limit and self.__timeout < time.time():
629             self.cancel()
630         return self.__queue.full() and not self.__queue.empty()
631
632     @property
633     def value(self):
634         """Read-only property containing data returned from function."""
635         if self.ready is True:
636             flag, load = self.__queue.get()
637             if flag:
638                 return load
639             raise load
640
641
642 def timeout(
643     seconds: float = 1.0,
644     use_signals: Optional[bool] = None,
645     timeout_exception=exceptions.TimeoutError,
646     error_message="Function call timed out",
647 ):
648     """Add a timeout parameter to a function and return the function.
649
650     Note: the use_signals parameter is included in order to support
651     multiprocessing scenarios (signal can only be used from the process'
652     main thread).  When not using signals, timeout granularity will be
653     rounded to the nearest 0.1s.
654
655     Raises an exception when/if the timeout is reached.
656
657     It is illegal to pass anything other than a function as the first
658     parameter.  The function is wrapped and returned to the caller.
659
660     >>> @timeout(0.2)
661     ... def foo(delay: float):
662     ...     time.sleep(delay)
663     ...     return "ok"
664
665     >>> foo(0)
666     'ok'
667
668     >>> foo(1.0)
669     Traceback (most recent call last):
670     ...
671     Exception: Function call timed out
672
673     """
674     if use_signals is None:
675         import thread_utils
676
677         use_signals = thread_utils.is_current_thread_main_thread()
678
679     def decorate(function):
680         if use_signals:
681
682             def handler(signum, frame):
683                 _raise_exception(timeout_exception, error_message)
684
685             @functools.wraps(function)
686             def new_function(*args, **kwargs):
687                 new_seconds = kwargs.pop("timeout", seconds)
688                 if new_seconds:
689                     old = signal.signal(signal.SIGALRM, handler)
690                     signal.setitimer(signal.ITIMER_REAL, new_seconds)
691
692                 if not seconds:
693                     return function(*args, **kwargs)
694
695                 try:
696                     return function(*args, **kwargs)
697                 finally:
698                     if new_seconds:
699                         signal.setitimer(signal.ITIMER_REAL, 0)
700                         signal.signal(signal.SIGALRM, old)
701
702             return new_function
703         else:
704
705             @functools.wraps(function)
706             def new_function(*args, **kwargs):
707                 timeout_wrapper = _Timeout(
708                     function, timeout_exception, error_message, seconds
709                 )
710                 return timeout_wrapper(*args, **kwargs)
711
712             return new_function
713
714     return decorate
715
716
717 class non_reentrant_code(object):
718     def __init__(self):
719         self._lock = threading.RLock
720         self._entered = False
721
722     def __call__(self, f):
723         def _gatekeeper(*args, **kwargs):
724             with self._lock:
725                 if self._entered:
726                     return
727                 self._entered = True
728                 f(*args, **kwargs)
729                 self._entered = False
730
731         return _gatekeeper
732
733
734 class rlocked(object):
735     def __init__(self):
736         self._lock = threading.RLock
737         self._entered = False
738
739     def __call__(self, f):
740         def _gatekeeper(*args, **kwargs):
741             with self._lock:
742                 if self._entered:
743                     return
744                 self._entered = True
745                 f(*args, **kwargs)
746                 self._entered = False
747
748         return _gatekeeper
749
750
751 def call_with_sample_rate(sample_rate: float) -> Callable:
752     if not 0.0 <= sample_rate <= 1.0:
753         msg = f"sample_rate must be between [0, 1]. Got {sample_rate}."
754         logger.critical(msg)
755         raise ValueError(msg)
756
757     def decorator(f):
758         @functools.wraps(f)
759         def _call_with_sample_rate(*args, **kwargs):
760             if random.uniform(0, 1) < sample_rate:
761                 return f(*args, **kwargs)
762             else:
763                 logger.debug(
764                     f"@call_with_sample_rate skipping a call to {f.__name__}"
765                 )
766
767         return _call_with_sample_rate
768
769     return decorator
770
771
772 def decorate_matching_methods_with(decorator, acl=None):
773     """Apply decorator to all methods in a class whose names begin with
774     prefix.  If prefix is None (default), decorate all methods in the
775     class.
776     """
777
778     def decorate_the_class(cls):
779         for name, m in inspect.getmembers(cls, inspect.isfunction):
780             if acl is None:
781                 setattr(cls, name, decorator(m))
782             else:
783                 if acl(name):
784                     setattr(cls, name, decorator(m))
785         return cls
786
787     return decorate_the_class
788
789
790 if __name__ == '__main__':
791     import doctest
792
793     doctest.testmod()