Improve the locking decorators; @synchronized(lock).
[python_utils.git] / decorator_utils.py
1 #!/usr/bin/env python3
2
3 """Decorators."""
4
5 import enum
6 import functools
7 import inspect
8 import logging
9 import math
10 import multiprocessing
11 import random
12 import signal
13 import sys
14 import threading
15 import time
16 import traceback
17 from typing import Any, Callable, Optional
18 import warnings
19
20 # This module is commonly used by others in here and should avoid
21 # taking any unnecessary dependencies back on them.
22 import exceptions
23
24
25 logger = logging.getLogger(__name__)
26
27
28 def timed(func: Callable) -> Callable:
29     """Print the runtime of the decorated function.
30
31     >>> @timed
32     ... def foo():
33     ...     import time
34     ...     time.sleep(0.1)
35
36     >>> foo()  # doctest: +ELLIPSIS
37     Finished foo in ...
38
39     """
40
41     @functools.wraps(func)
42     def wrapper_timer(*args, **kwargs):
43         start_time = time.perf_counter()
44         value = func(*args, **kwargs)
45         end_time = time.perf_counter()
46         run_time = end_time - start_time
47         msg = f"Finished {func.__qualname__} in {run_time:.4f}s"
48         print(msg)
49         logger.info(msg)
50         return value
51
52     return wrapper_timer
53
54
55 def invocation_logged(func: Callable) -> Callable:
56     """Log the call of a function.
57
58     >>> @invocation_logged
59     ... def foo():
60     ...     print('Hello, world.')
61
62     >>> foo()
63     Entered foo
64     Hello, world.
65     Exited foo
66
67     """
68
69     @functools.wraps(func)
70     def wrapper_invocation_logged(*args, **kwargs):
71         msg = f"Entered {func.__qualname__}"
72         print(msg)
73         logger.info(msg)
74         ret = func(*args, **kwargs)
75         msg = f"Exited {func.__qualname__}"
76         print(msg)
77         logger.info(msg)
78         return ret
79
80     return wrapper_invocation_logged
81
82
83 def rate_limited(n_calls: int, *, per_period_in_seconds: float = 1.0) -> Callable:
84     """Limit invocation of a wrapped function to n calls per period.
85     Thread safe.  In testing this was relatively fair with multiple
86     threads using it though that hasn't been measured.
87
88     >>> import time
89     >>> import decorator_utils
90     >>> import thread_utils
91
92     >>> calls = 0
93
94     >>> @decorator_utils.rate_limited(10, per_period_in_seconds=1.0)
95     ... def limited(x: int):
96     ...     global calls
97     ...     calls += 1
98
99     >>> @thread_utils.background_thread
100     ... def a(stop):
101     ...     for _ in range(3):
102     ...         limited(_)
103
104     >>> @thread_utils.background_thread
105     ... def b(stop):
106     ...     for _ in range(3):
107     ...         limited(_)
108
109     >>> start = time.time()
110     >>> (t1, e1) = a()
111     >>> (t2, e2) = b()
112     >>> t1.join()
113     >>> t2.join()
114     >>> end = time.time()
115     >>> dur = end - start
116     >>> dur > 0.5
117     True
118
119     >>> calls
120     6
121
122     """
123     min_interval_seconds = per_period_in_seconds / float(n_calls)
124
125     def wrapper_rate_limited(func: Callable) -> Callable:
126         cv = threading.Condition()
127         last_invocation_timestamp = [0.0]
128
129         def may_proceed() -> float:
130             now = time.time()
131             last_invocation = last_invocation_timestamp[0]
132             if last_invocation != 0.0:
133                 elapsed_since_last = now - last_invocation
134                 wait_time = min_interval_seconds - elapsed_since_last
135             else:
136                 wait_time = 0.0
137             logger.debug(f'@{time.time()}> wait_time = {wait_time}')
138             return wait_time
139
140         def wrapper_wrapper_rate_limited(*args, **kargs) -> Any:
141             with cv:
142                 while True:
143                     if cv.wait_for(
144                         lambda: may_proceed() <= 0.0,
145                         timeout=may_proceed(),
146                     ):
147                         break
148             with cv:
149                 logger.debug(f'@{time.time()}> calling it...')
150                 ret = func(*args, **kargs)
151                 last_invocation_timestamp[0] = time.time()
152                 logger.debug(
153                     f'@{time.time()}> Last invocation <- {last_invocation_timestamp[0]}'
154                 )
155                 cv.notify()
156             return ret
157
158         return wrapper_wrapper_rate_limited
159
160     return wrapper_rate_limited
161
162
163 def debug_args(func: Callable) -> Callable:
164     """Print the function signature and return value at each call.
165
166     >>> @debug_args
167     ... def foo(a, b, c):
168     ...     print(a)
169     ...     print(b)
170     ...     print(c)
171     ...     return (a + b, c)
172
173     >>> foo(1, 2.0, "test")
174     Calling foo(1:<class 'int'>, 2.0:<class 'float'>, 'test':<class 'str'>)
175     1
176     2.0
177     test
178     foo returned (3.0, 'test'):<class 'tuple'>
179     (3.0, 'test')
180     """
181
182     @functools.wraps(func)
183     def wrapper_debug_args(*args, **kwargs):
184         args_repr = [f"{repr(a)}:{type(a)}" for a in args]
185         kwargs_repr = [f"{k}={v!r}:{type(v)}" for k, v in kwargs.items()]
186         signature = ", ".join(args_repr + kwargs_repr)
187         msg = f"Calling {func.__qualname__}({signature})"
188         print(msg)
189         logger.info(msg)
190         value = func(*args, **kwargs)
191         msg = f"{func.__qualname__} returned {value!r}:{type(value)}"
192         print(msg)
193         logger.info(msg)
194         return value
195
196     return wrapper_debug_args
197
198
199 def debug_count_calls(func: Callable) -> Callable:
200     """Count function invocations and print a message befor every call.
201
202     >>> @debug_count_calls
203     ... def factoral(x):
204     ...     if x == 1:
205     ...         return 1
206     ...     return x * factoral(x - 1)
207
208     >>> factoral(5)
209     Call #1 of 'factoral'
210     Call #2 of 'factoral'
211     Call #3 of 'factoral'
212     Call #4 of 'factoral'
213     Call #5 of 'factoral'
214     120
215
216     """
217
218     @functools.wraps(func)
219     def wrapper_debug_count_calls(*args, **kwargs):
220         wrapper_debug_count_calls.num_calls += 1
221         msg = f"Call #{wrapper_debug_count_calls.num_calls} of {func.__name__!r}"
222         print(msg)
223         logger.info(msg)
224         return func(*args, **kwargs)
225
226     wrapper_debug_count_calls.num_calls = 0
227     return wrapper_debug_count_calls
228
229
230 class DelayWhen(enum.IntEnum):
231     BEFORE_CALL = 1
232     AFTER_CALL = 2
233     BEFORE_AND_AFTER = 3
234
235
236 def delay(
237     _func: Callable = None,
238     *,
239     seconds: float = 1.0,
240     when: DelayWhen = DelayWhen.BEFORE_CALL,
241 ) -> Callable:
242     """Delay the execution of a function by sleeping before and/or after.
243
244     Slow down a function by inserting a delay before and/or after its
245     invocation.
246
247     >>> import time
248
249     >>> @delay(seconds=1.0)
250     ... def foo():
251     ...     pass
252
253     >>> start = time.time()
254     >>> foo()
255     >>> dur = time.time() - start
256     >>> dur >= 1.0
257     True
258
259     """
260
261     def decorator_delay(func: Callable) -> Callable:
262         @functools.wraps(func)
263         def wrapper_delay(*args, **kwargs):
264             if when & DelayWhen.BEFORE_CALL:
265                 logger.debug(f"@delay for {seconds}s BEFORE_CALL to {func.__name__}")
266                 time.sleep(seconds)
267             retval = func(*args, **kwargs)
268             if when & DelayWhen.AFTER_CALL:
269                 logger.debug(f"@delay for {seconds}s AFTER_CALL to {func.__name__}")
270                 time.sleep(seconds)
271             return retval
272
273         return wrapper_delay
274
275     if _func is None:
276         return decorator_delay
277     else:
278         return decorator_delay(_func)
279
280
281 class _SingletonWrapper:
282     """
283     A singleton wrapper class. Its instances would be created
284     for each decorated class.
285
286     """
287
288     def __init__(self, cls):
289         self.__wrapped__ = cls
290         self._instance = None
291
292     def __call__(self, *args, **kwargs):
293         """Returns a single instance of decorated class"""
294         logger.debug(
295             f"@singleton returning global instance of {self.__wrapped__.__name__}"
296         )
297         if self._instance is None:
298             self._instance = self.__wrapped__(*args, **kwargs)
299         return self._instance
300
301
302 def singleton(cls):
303     """
304     A singleton decorator. Returns a wrapper objects. A call on that object
305     returns a single instance object of decorated class. Use the __wrapped__
306     attribute to access decorated class directly in unit tests
307
308     >>> @singleton
309     ... class foo(object):
310     ...     pass
311
312     >>> a = foo()
313     >>> b = foo()
314     >>> a is b
315     True
316
317     >>> id(a) == id(b)
318     True
319
320     """
321     return _SingletonWrapper(cls)
322
323
324 def memoized(func: Callable) -> Callable:
325     """Keep a cache of previous function call results.
326
327     The cache here is a dict with a key based on the arguments to the
328     call.  Consider also: functools.lru_cache for a more advanced
329     implementation.
330
331     >>> import time
332
333     >>> @memoized
334     ... def expensive(arg) -> int:
335     ...     # Simulate something slow to compute or lookup
336     ...     time.sleep(1.0)
337     ...     return arg * arg
338
339     >>> start = time.time()
340     >>> expensive(5)           # Takes about 1 sec
341     25
342
343     >>> expensive(3)           # Also takes about 1 sec
344     9
345
346     >>> expensive(5)           # Pulls from cache, fast
347     25
348
349     >>> expensive(3)           # Pulls from cache again, fast
350     9
351
352     >>> dur = time.time() - start
353     >>> dur < 3.0
354     True
355
356     """
357
358     @functools.wraps(func)
359     def wrapper_memoized(*args, **kwargs):
360         cache_key = args + tuple(kwargs.items())
361         if cache_key not in wrapper_memoized.cache:
362             value = func(*args, **kwargs)
363             logger.debug(f"Memoizing {cache_key} => {value} for {func.__name__}")
364             wrapper_memoized.cache[cache_key] = value
365         else:
366             logger.debug(f"Returning memoized value for {func.__name__}")
367         return wrapper_memoized.cache[cache_key]
368
369     wrapper_memoized.cache = dict()
370     return wrapper_memoized
371
372
373 def retry_predicate(
374     tries: int,
375     *,
376     predicate: Callable[..., bool],
377     delay_sec: float = 3.0,
378     backoff: float = 2.0,
379 ):
380     """Retries a function or method up to a certain number of times
381     with a prescribed initial delay period and backoff rate.
382
383     tries is the maximum number of attempts to run the function.
384     delay_sec sets the initial delay period in seconds.
385     backoff is a multiplied (must be >1) used to modify the delay.
386     predicate is a function that will be passed the retval of the
387     decorated function and must return True to stop or False to
388     retry.
389
390     """
391     if backoff < 1.0:
392         msg = f"backoff must be greater than or equal to 1, got {backoff}"
393         logger.critical(msg)
394         raise ValueError(msg)
395
396     tries = math.floor(tries)
397     if tries < 0:
398         msg = f"tries must be 0 or greater, got {tries}"
399         logger.critical(msg)
400         raise ValueError(msg)
401
402     if delay_sec <= 0:
403         msg = f"delay_sec must be greater than 0, got {delay_sec}"
404         logger.critical(msg)
405         raise ValueError(msg)
406
407     def deco_retry(f):
408         @functools.wraps(f)
409         def f_retry(*args, **kwargs):
410             mtries, mdelay = tries, delay_sec  # make mutable
411             logger.debug(f'deco_retry: will make up to {mtries} attempts...')
412             retval = f(*args, **kwargs)
413             while mtries > 0:
414                 if predicate(retval) is True:
415                     logger.debug('Predicate succeeded, deco_retry is done.')
416                     return retval
417                 logger.debug("Predicate failed, sleeping and retrying.")
418                 mtries -= 1
419                 time.sleep(mdelay)
420                 mdelay *= backoff
421                 retval = f(*args, **kwargs)
422             return retval
423
424         return f_retry
425
426     return deco_retry
427
428
429 def retry_if_false(tries: int, *, delay_sec=3.0, backoff=2.0):
430     """A helper for @retry_predicate that retries a decorated
431     function as long as it keeps returning False.
432
433     >>> import time
434
435     >>> counter = 0
436
437     >>> @retry_if_false(5, delay_sec=1.0, backoff=1.1)
438     ... def foo():
439     ...     global counter
440     ...     counter += 1
441     ...     return counter >= 3
442
443     >>> start = time.time()
444     >>> foo()  # fail, delay 1.0, fail, delay 1.1, succeed
445     True
446
447     >>> dur = time.time() - start
448     >>> counter
449     3
450     >>> dur > 2.0
451     True
452     >>> dur < 2.3
453     True
454
455     """
456     return retry_predicate(
457         tries,
458         predicate=lambda x: x is True,
459         delay_sec=delay_sec,
460         backoff=backoff,
461     )
462
463
464 def retry_if_none(tries: int, *, delay_sec=3.0, backoff=2.0):
465     """Another helper for @retry_predicate above.  Retries up to N
466     times so long as the wrapped function returns None with a delay
467     between each retry and a backoff that can increase the delay.
468
469     """
470     return retry_predicate(
471         tries,
472         predicate=lambda x: x is not None,
473         delay_sec=delay_sec,
474         backoff=backoff,
475     )
476
477
478 def deprecated(func):
479     """This is a decorator which can be used to mark functions
480     as deprecated. It will result in a warning being emitted
481     when the function is used.
482
483     """
484
485     @functools.wraps(func)
486     def wrapper_deprecated(*args, **kwargs):
487         msg = f"Call to deprecated function {func.__qualname__}"
488         logger.warning(msg)
489         warnings.warn(msg, category=DeprecationWarning, stacklevel=2)
490         print(msg, file=sys.stderr)
491         return func(*args, **kwargs)
492
493     return wrapper_deprecated
494
495
496 def thunkify(func):
497     """
498     Make a function immediately return a function of no args which,
499     when called, waits for the result, which will start being
500     processed in another thread.
501     """
502
503     @functools.wraps(func)
504     def lazy_thunked(*args, **kwargs):
505         wait_event = threading.Event()
506
507         result = [None]
508         exc = [False, None]
509
510         def worker_func():
511             try:
512                 func_result = func(*args, **kwargs)
513                 result[0] = func_result
514             except Exception:
515                 exc[0] = True
516                 exc[1] = sys.exc_info()  # (type, value, traceback)
517                 msg = f"Thunkify has thrown an exception (will be raised on thunk()):\n{traceback.format_exc()}"
518                 logger.warning(msg)
519             finally:
520                 wait_event.set()
521
522         def thunk():
523             wait_event.wait()
524             if exc[0]:
525                 raise exc[1][0](exc[1][1])
526             return result[0]
527
528         threading.Thread(target=worker_func).start()
529         return thunk
530
531     return lazy_thunked
532
533
534 ############################################################
535 # Timeout
536 ############################################################
537
538 # http://www.saltycrane.com/blog/2010/04/using-python-timeout-decorator-uploading-s3/
539 # Used work of Stephen "Zero" Chappell <[email protected]>
540 # in https://code.google.com/p/verse-quiz/source/browse/trunk/timeout.py
541
542
543 def _raise_exception(exception, error_message: Optional[str]):
544     if error_message is None:
545         raise Exception()
546     else:
547         raise Exception(error_message)
548
549
550 def _target(queue, function, *args, **kwargs):
551     """Run a function with arguments and return output via a queue.
552
553     This is a helper function for the Process created in _Timeout. It runs
554     the function with positional arguments and keyword arguments and then
555     returns the function's output by way of a queue. If an exception gets
556     raised, it is returned to _Timeout to be raised by the value property.
557     """
558     try:
559         queue.put((True, function(*args, **kwargs)))
560     except Exception:
561         queue.put((False, sys.exc_info()[1]))
562
563
564 class _Timeout(object):
565     """Wrap a function and add a timeout to it.
566
567     Instances of this class are automatically generated by the add_timeout
568     function defined below.  Do not use directly.
569     """
570
571     def __init__(
572         self,
573         function: Callable,
574         timeout_exception: Exception,
575         error_message: str,
576         seconds: float,
577     ):
578         self.__limit = seconds
579         self.__function = function
580         self.__timeout_exception = timeout_exception
581         self.__error_message = error_message
582         self.__name__ = function.__name__
583         self.__doc__ = function.__doc__
584         self.__timeout = time.time()
585         self.__process = multiprocessing.Process()
586         self.__queue: multiprocessing.queues.Queue = multiprocessing.Queue()
587
588     def __call__(self, *args, **kwargs):
589         """Execute the embedded function object asynchronously.
590
591         The function given to the constructor is transparently called and
592         requires that "ready" be intermittently polled. If and when it is
593         True, the "value" property may then be checked for returned data.
594         """
595         self.__limit = kwargs.pop("timeout", self.__limit)
596         self.__queue = multiprocessing.Queue(1)
597         args = (self.__queue, self.__function) + args
598         self.__process = multiprocessing.Process(
599             target=_target, args=args, kwargs=kwargs
600         )
601         self.__process.daemon = True
602         self.__process.start()
603         if self.__limit is not None:
604             self.__timeout = self.__limit + time.time()
605         while not self.ready:
606             time.sleep(0.1)
607         return self.value
608
609     def cancel(self):
610         """Terminate any possible execution of the embedded function."""
611         if self.__process.is_alive():
612             self.__process.terminate()
613         _raise_exception(self.__timeout_exception, self.__error_message)
614
615     @property
616     def ready(self):
617         """Read-only property indicating status of "value" property."""
618         if self.__limit and self.__timeout < time.time():
619             self.cancel()
620         return self.__queue.full() and not self.__queue.empty()
621
622     @property
623     def value(self):
624         """Read-only property containing data returned from function."""
625         if self.ready is True:
626             flag, load = self.__queue.get()
627             if flag:
628                 return load
629             raise load
630
631
632 def timeout(
633     seconds: float = 1.0,
634     use_signals: Optional[bool] = None,
635     timeout_exception=exceptions.TimeoutError,
636     error_message="Function call timed out",
637 ):
638     """Add a timeout parameter to a function and return the function.
639
640     Note: the use_signals parameter is included in order to support
641     multiprocessing scenarios (signal can only be used from the process'
642     main thread).  When not using signals, timeout granularity will be
643     rounded to the nearest 0.1s.
644
645     Raises an exception when/if the timeout is reached.
646
647     It is illegal to pass anything other than a function as the first
648     parameter.  The function is wrapped and returned to the caller.
649
650     >>> @timeout(0.2)
651     ... def foo(delay: float):
652     ...     time.sleep(delay)
653     ...     return "ok"
654
655     >>> foo(0)
656     'ok'
657
658     >>> foo(1.0)
659     Traceback (most recent call last):
660     ...
661     Exception: Function call timed out
662
663     """
664     if use_signals is None:
665         import thread_utils
666
667         use_signals = thread_utils.is_current_thread_main_thread()
668
669     def decorate(function):
670         if use_signals:
671
672             def handler(signum, frame):
673                 _raise_exception(timeout_exception, error_message)
674
675             @functools.wraps(function)
676             def new_function(*args, **kwargs):
677                 new_seconds = kwargs.pop("timeout", seconds)
678                 if new_seconds:
679                     old = signal.signal(signal.SIGALRM, handler)
680                     signal.setitimer(signal.ITIMER_REAL, new_seconds)
681
682                 if not seconds:
683                     return function(*args, **kwargs)
684
685                 try:
686                     return function(*args, **kwargs)
687                 finally:
688                     if new_seconds:
689                         signal.setitimer(signal.ITIMER_REAL, 0)
690                         signal.signal(signal.SIGALRM, old)
691
692             return new_function
693         else:
694
695             @functools.wraps(function)
696             def new_function(*args, **kwargs):
697                 timeout_wrapper = _Timeout(
698                     function, timeout_exception, error_message, seconds
699                 )
700                 return timeout_wrapper(*args, **kwargs)
701
702             return new_function
703
704     return decorate
705
706
707 def synchronized(lock):
708     def wrap(f):
709         @functools.wraps(f)
710         def _gatekeeper(*args, **kw):
711             lock.acquire()
712             try:
713                 return f(*args, **kw)
714             finally:
715                 lock.release()
716
717         return _gatekeeper
718
719     return wrap
720
721
722 def call_with_sample_rate(sample_rate: float) -> Callable:
723     if not 0.0 <= sample_rate <= 1.0:
724         msg = f"sample_rate must be between [0, 1]. Got {sample_rate}."
725         logger.critical(msg)
726         raise ValueError(msg)
727
728     def decorator(f):
729         @functools.wraps(f)
730         def _call_with_sample_rate(*args, **kwargs):
731             if random.uniform(0, 1) < sample_rate:
732                 return f(*args, **kwargs)
733             else:
734                 logger.debug(f"@call_with_sample_rate skipping a call to {f.__name__}")
735
736         return _call_with_sample_rate
737
738     return decorator
739
740
741 def decorate_matching_methods_with(decorator, acl=None):
742     """Apply decorator to all methods in a class whose names begin with
743     prefix.  If prefix is None (default), decorate all methods in the
744     class.
745     """
746
747     def decorate_the_class(cls):
748         for name, m in inspect.getmembers(cls, inspect.isfunction):
749             if acl is None:
750                 setattr(cls, name, decorator(m))
751             else:
752                 if acl(name):
753                     setattr(cls, name, decorator(m))
754         return cls
755
756     return decorate_the_class
757
758
759 if __name__ == '__main__':
760     import doctest
761
762     doctest.testmod()