Add coding comments for files with utf8 characters in there.
[python_utils.git] / decorator_utils.py
1 #!/usr/bin/env python3
2
3 """Decorators."""
4
5 import enum
6 import functools
7 import inspect
8 import logging
9 import math
10 import multiprocessing
11 import random
12 import signal
13 import sys
14 import threading
15 import time
16 import traceback
17 import warnings
18 from typing import Any, Callable, Optional
19
20 # This module is commonly used by others in here and should avoid
21 # taking any unnecessary dependencies back on them.
22 import exceptions
23
24 logger = logging.getLogger(__name__)
25
26
27 def timed(func: Callable) -> Callable:
28     """Print the runtime of the decorated function.
29
30     >>> @timed
31     ... def foo():
32     ...     import time
33     ...     time.sleep(0.1)
34
35     >>> foo()  # doctest: +ELLIPSIS
36     Finished foo in ...
37
38     """
39
40     @functools.wraps(func)
41     def wrapper_timer(*args, **kwargs):
42         start_time = time.perf_counter()
43         value = func(*args, **kwargs)
44         end_time = time.perf_counter()
45         run_time = end_time - start_time
46         msg = f"Finished {func.__qualname__} in {run_time:.4f}s"
47         print(msg)
48         logger.info(msg)
49         return value
50
51     return wrapper_timer
52
53
54 def invocation_logged(func: Callable) -> Callable:
55     """Log the call of a function.
56
57     >>> @invocation_logged
58     ... def foo():
59     ...     print('Hello, world.')
60
61     >>> foo()
62     Entered foo
63     Hello, world.
64     Exited foo
65
66     """
67
68     @functools.wraps(func)
69     def wrapper_invocation_logged(*args, **kwargs):
70         msg = f"Entered {func.__qualname__}"
71         print(msg)
72         logger.info(msg)
73         ret = func(*args, **kwargs)
74         msg = f"Exited {func.__qualname__}"
75         print(msg)
76         logger.info(msg)
77         return ret
78
79     return wrapper_invocation_logged
80
81
82 def rate_limited(n_calls: int, *, per_period_in_seconds: float = 1.0) -> Callable:
83     """Limit invocation of a wrapped function to n calls per period.
84     Thread safe.  In testing this was relatively fair with multiple
85     threads using it though that hasn't been measured.
86
87     >>> import time
88     >>> import decorator_utils
89     >>> import thread_utils
90
91     >>> calls = 0
92
93     >>> @decorator_utils.rate_limited(10, per_period_in_seconds=1.0)
94     ... def limited(x: int):
95     ...     global calls
96     ...     calls += 1
97
98     >>> @thread_utils.background_thread
99     ... def a(stop):
100     ...     for _ in range(3):
101     ...         limited(_)
102
103     >>> @thread_utils.background_thread
104     ... def b(stop):
105     ...     for _ in range(3):
106     ...         limited(_)
107
108     >>> start = time.time()
109     >>> (t1, e1) = a()
110     >>> (t2, e2) = b()
111     >>> t1.join()
112     >>> t2.join()
113     >>> end = time.time()
114     >>> dur = end - start
115     >>> dur > 0.5
116     True
117
118     >>> calls
119     6
120
121     """
122     min_interval_seconds = per_period_in_seconds / float(n_calls)
123
124     def wrapper_rate_limited(func: Callable) -> Callable:
125         cv = threading.Condition()
126         last_invocation_timestamp = [0.0]
127
128         def may_proceed() -> float:
129             now = time.time()
130             last_invocation = last_invocation_timestamp[0]
131             if last_invocation != 0.0:
132                 elapsed_since_last = now - last_invocation
133                 wait_time = min_interval_seconds - elapsed_since_last
134             else:
135                 wait_time = 0.0
136             logger.debug(f'@{time.time()}> wait_time = {wait_time}')
137             return wait_time
138
139         def wrapper_wrapper_rate_limited(*args, **kargs) -> Any:
140             with cv:
141                 while True:
142                     if cv.wait_for(
143                         lambda: may_proceed() <= 0.0,
144                         timeout=may_proceed(),
145                     ):
146                         break
147             with cv:
148                 logger.debug(f'@{time.time()}> calling it...')
149                 ret = func(*args, **kargs)
150                 last_invocation_timestamp[0] = time.time()
151                 logger.debug(f'@{time.time()}> Last invocation <- {last_invocation_timestamp[0]}')
152                 cv.notify()
153             return ret
154
155         return wrapper_wrapper_rate_limited
156
157     return wrapper_rate_limited
158
159
160 def debug_args(func: Callable) -> Callable:
161     """Print the function signature and return value at each call.
162
163     >>> @debug_args
164     ... def foo(a, b, c):
165     ...     print(a)
166     ...     print(b)
167     ...     print(c)
168     ...     return (a + b, c)
169
170     >>> foo(1, 2.0, "test")
171     Calling foo(1:<class 'int'>, 2.0:<class 'float'>, 'test':<class 'str'>)
172     1
173     2.0
174     test
175     foo returned (3.0, 'test'):<class 'tuple'>
176     (3.0, 'test')
177     """
178
179     @functools.wraps(func)
180     def wrapper_debug_args(*args, **kwargs):
181         args_repr = [f"{repr(a)}:{type(a)}" for a in args]
182         kwargs_repr = [f"{k}={v!r}:{type(v)}" for k, v in kwargs.items()]
183         signature = ", ".join(args_repr + kwargs_repr)
184         msg = f"Calling {func.__qualname__}({signature})"
185         print(msg)
186         logger.info(msg)
187         value = func(*args, **kwargs)
188         msg = f"{func.__qualname__} returned {value!r}:{type(value)}"
189         print(msg)
190         logger.info(msg)
191         return value
192
193     return wrapper_debug_args
194
195
196 def debug_count_calls(func: Callable) -> Callable:
197     """Count function invocations and print a message befor every call.
198
199     >>> @debug_count_calls
200     ... def factoral(x):
201     ...     if x == 1:
202     ...         return 1
203     ...     return x * factoral(x - 1)
204
205     >>> factoral(5)
206     Call #1 of 'factoral'
207     Call #2 of 'factoral'
208     Call #3 of 'factoral'
209     Call #4 of 'factoral'
210     Call #5 of 'factoral'
211     120
212
213     """
214
215     @functools.wraps(func)
216     def wrapper_debug_count_calls(*args, **kwargs):
217         wrapper_debug_count_calls.num_calls += 1
218         msg = f"Call #{wrapper_debug_count_calls.num_calls} of {func.__name__!r}"
219         print(msg)
220         logger.info(msg)
221         return func(*args, **kwargs)
222
223     wrapper_debug_count_calls.num_calls = 0  # type: ignore
224     return wrapper_debug_count_calls
225
226
227 class DelayWhen(enum.IntEnum):
228     BEFORE_CALL = 1
229     AFTER_CALL = 2
230     BEFORE_AND_AFTER = 3
231
232
233 def delay(
234     _func: Callable = None,
235     *,
236     seconds: float = 1.0,
237     when: DelayWhen = DelayWhen.BEFORE_CALL,
238 ) -> Callable:
239     """Delay the execution of a function by sleeping before and/or after.
240
241     Slow down a function by inserting a delay before and/or after its
242     invocation.
243
244     >>> import time
245
246     >>> @delay(seconds=1.0)
247     ... def foo():
248     ...     pass
249
250     >>> start = time.time()
251     >>> foo()
252     >>> dur = time.time() - start
253     >>> dur >= 1.0
254     True
255
256     """
257
258     def decorator_delay(func: Callable) -> Callable:
259         @functools.wraps(func)
260         def wrapper_delay(*args, **kwargs):
261             if when & DelayWhen.BEFORE_CALL:
262                 logger.debug(f"@delay for {seconds}s BEFORE_CALL to {func.__name__}")
263                 time.sleep(seconds)
264             retval = func(*args, **kwargs)
265             if when & DelayWhen.AFTER_CALL:
266                 logger.debug(f"@delay for {seconds}s AFTER_CALL to {func.__name__}")
267                 time.sleep(seconds)
268             return retval
269
270         return wrapper_delay
271
272     if _func is None:
273         return decorator_delay
274     else:
275         return decorator_delay(_func)
276
277
278 class _SingletonWrapper:
279     """
280     A singleton wrapper class. Its instances would be created
281     for each decorated class.
282
283     """
284
285     def __init__(self, cls):
286         self.__wrapped__ = cls
287         self._instance = None
288
289     def __call__(self, *args, **kwargs):
290         """Returns a single instance of decorated class"""
291         logger.debug(f"@singleton returning global instance of {self.__wrapped__.__name__}")
292         if self._instance is None:
293             self._instance = self.__wrapped__(*args, **kwargs)
294         return self._instance
295
296
297 def singleton(cls):
298     """
299     A singleton decorator. Returns a wrapper objects. A call on that object
300     returns a single instance object of decorated class. Use the __wrapped__
301     attribute to access decorated class directly in unit tests
302
303     >>> @singleton
304     ... class foo(object):
305     ...     pass
306
307     >>> a = foo()
308     >>> b = foo()
309     >>> a is b
310     True
311
312     >>> id(a) == id(b)
313     True
314
315     """
316     return _SingletonWrapper(cls)
317
318
319 def memoized(func: Callable) -> Callable:
320     """Keep a cache of previous function call results.
321
322     The cache here is a dict with a key based on the arguments to the
323     call.  Consider also: functools.lru_cache for a more advanced
324     implementation.
325
326     >>> import time
327
328     >>> @memoized
329     ... def expensive(arg) -> int:
330     ...     # Simulate something slow to compute or lookup
331     ...     time.sleep(1.0)
332     ...     return arg * arg
333
334     >>> start = time.time()
335     >>> expensive(5)           # Takes about 1 sec
336     25
337
338     >>> expensive(3)           # Also takes about 1 sec
339     9
340
341     >>> expensive(5)           # Pulls from cache, fast
342     25
343
344     >>> expensive(3)           # Pulls from cache again, fast
345     9
346
347     >>> dur = time.time() - start
348     >>> dur < 3.0
349     True
350
351     """
352
353     @functools.wraps(func)
354     def wrapper_memoized(*args, **kwargs):
355         cache_key = args + tuple(kwargs.items())
356         if cache_key not in wrapper_memoized.cache:
357             value = func(*args, **kwargs)
358             logger.debug(f"Memoizing {cache_key} => {value} for {func.__name__}")
359             wrapper_memoized.cache[cache_key] = value
360         else:
361             logger.debug(f"Returning memoized value for {func.__name__}")
362         return wrapper_memoized.cache[cache_key]
363
364     wrapper_memoized.cache = dict()  # type: ignore
365     return wrapper_memoized
366
367
368 def retry_predicate(
369     tries: int,
370     *,
371     predicate: Callable[..., bool],
372     delay_sec: float = 3.0,
373     backoff: float = 2.0,
374 ):
375     """Retries a function or method up to a certain number of times
376     with a prescribed initial delay period and backoff rate.
377
378     tries is the maximum number of attempts to run the function.
379     delay_sec sets the initial delay period in seconds.
380     backoff is a multiplied (must be >1) used to modify the delay.
381     predicate is a function that will be passed the retval of the
382     decorated function and must return True to stop or False to
383     retry.
384
385     """
386     if backoff < 1.0:
387         msg = f"backoff must be greater than or equal to 1, got {backoff}"
388         logger.critical(msg)
389         raise ValueError(msg)
390
391     tries = math.floor(tries)
392     if tries < 0:
393         msg = f"tries must be 0 or greater, got {tries}"
394         logger.critical(msg)
395         raise ValueError(msg)
396
397     if delay_sec <= 0:
398         msg = f"delay_sec must be greater than 0, got {delay_sec}"
399         logger.critical(msg)
400         raise ValueError(msg)
401
402     def deco_retry(f):
403         @functools.wraps(f)
404         def f_retry(*args, **kwargs):
405             mtries, mdelay = tries, delay_sec  # make mutable
406             logger.debug(f'deco_retry: will make up to {mtries} attempts...')
407             retval = f(*args, **kwargs)
408             while mtries > 0:
409                 if predicate(retval) is True:
410                     logger.debug('Predicate succeeded, deco_retry is done.')
411                     return retval
412                 logger.debug("Predicate failed, sleeping and retrying.")
413                 mtries -= 1
414                 time.sleep(mdelay)
415                 mdelay *= backoff
416                 retval = f(*args, **kwargs)
417             return retval
418
419         return f_retry
420
421     return deco_retry
422
423
424 def retry_if_false(tries: int, *, delay_sec=3.0, backoff=2.0):
425     """A helper for @retry_predicate that retries a decorated
426     function as long as it keeps returning False.
427
428     >>> import time
429
430     >>> counter = 0
431
432     >>> @retry_if_false(5, delay_sec=1.0, backoff=1.1)
433     ... def foo():
434     ...     global counter
435     ...     counter += 1
436     ...     return counter >= 3
437
438     >>> start = time.time()
439     >>> foo()  # fail, delay 1.0, fail, delay 1.1, succeed
440     True
441
442     >>> dur = time.time() - start
443     >>> counter
444     3
445     >>> dur > 2.0
446     True
447     >>> dur < 2.3
448     True
449
450     """
451     return retry_predicate(
452         tries,
453         predicate=lambda x: x is True,
454         delay_sec=delay_sec,
455         backoff=backoff,
456     )
457
458
459 def retry_if_none(tries: int, *, delay_sec=3.0, backoff=2.0):
460     """Another helper for @retry_predicate above.  Retries up to N
461     times so long as the wrapped function returns None with a delay
462     between each retry and a backoff that can increase the delay.
463
464     """
465     return retry_predicate(
466         tries,
467         predicate=lambda x: x is not None,
468         delay_sec=delay_sec,
469         backoff=backoff,
470     )
471
472
473 def deprecated(func):
474     """This is a decorator which can be used to mark functions
475     as deprecated. It will result in a warning being emitted
476     when the function is used.
477
478     """
479
480     @functools.wraps(func)
481     def wrapper_deprecated(*args, **kwargs):
482         msg = f"Call to deprecated function {func.__qualname__}"
483         logger.warning(msg)
484         warnings.warn(msg, category=DeprecationWarning, stacklevel=2)
485         print(msg, file=sys.stderr)
486         return func(*args, **kwargs)
487
488     return wrapper_deprecated
489
490
491 def thunkify(func):
492     """
493     Make a function immediately return a function of no args which,
494     when called, waits for the result, which will start being
495     processed in another thread.
496     """
497
498     @functools.wraps(func)
499     def lazy_thunked(*args, **kwargs):
500         wait_event = threading.Event()
501
502         result = [None]
503         exc = [False, None]
504
505         def worker_func():
506             try:
507                 func_result = func(*args, **kwargs)
508                 result[0] = func_result
509             except Exception:
510                 exc[0] = True
511                 exc[1] = sys.exc_info()  # (type, value, traceback)
512                 msg = f"Thunkify has thrown an exception (will be raised on thunk()):\n{traceback.format_exc()}"
513                 logger.warning(msg)
514             finally:
515                 wait_event.set()
516
517         def thunk():
518             wait_event.wait()
519             if exc[0]:
520                 raise exc[1][0](exc[1][1])
521             return result[0]
522
523         threading.Thread(target=worker_func).start()
524         return thunk
525
526     return lazy_thunked
527
528
529 ############################################################
530 # Timeout
531 ############################################################
532
533 # http://www.saltycrane.com/blog/2010/04/using-python-timeout-decorator-uploading-s3/
534 # Used work of Stephen "Zero" Chappell <[email protected]>
535 # in https://code.google.com/p/verse-quiz/source/browse/trunk/timeout.py
536
537
538 def _raise_exception(exception, error_message: Optional[str]):
539     if error_message is None:
540         raise Exception()
541     else:
542         raise Exception(error_message)
543
544
545 def _target(queue, function, *args, **kwargs):
546     """Run a function with arguments and return output via a queue.
547
548     This is a helper function for the Process created in _Timeout. It runs
549     the function with positional arguments and keyword arguments and then
550     returns the function's output by way of a queue. If an exception gets
551     raised, it is returned to _Timeout to be raised by the value property.
552     """
553     try:
554         queue.put((True, function(*args, **kwargs)))
555     except Exception:
556         queue.put((False, sys.exc_info()[1]))
557
558
559 class _Timeout(object):
560     """Wrap a function and add a timeout to it.
561
562     Instances of this class are automatically generated by the add_timeout
563     function defined below.  Do not use directly.
564     """
565
566     def __init__(
567         self,
568         function: Callable,
569         timeout_exception: Exception,
570         error_message: str,
571         seconds: float,
572     ):
573         self.__limit = seconds
574         self.__function = function
575         self.__timeout_exception = timeout_exception
576         self.__error_message = error_message
577         self.__name__ = function.__name__
578         self.__doc__ = function.__doc__
579         self.__timeout = time.time()
580         self.__process = multiprocessing.Process()
581         self.__queue: multiprocessing.queues.Queue = multiprocessing.Queue()
582
583     def __call__(self, *args, **kwargs):
584         """Execute the embedded function object asynchronously.
585
586         The function given to the constructor is transparently called and
587         requires that "ready" be intermittently polled. If and when it is
588         True, the "value" property may then be checked for returned data.
589         """
590         self.__limit = kwargs.pop("timeout", self.__limit)
591         self.__queue = multiprocessing.Queue(1)
592         args = (self.__queue, self.__function) + args
593         self.__process = multiprocessing.Process(target=_target, args=args, kwargs=kwargs)
594         self.__process.daemon = True
595         self.__process.start()
596         if self.__limit is not None:
597             self.__timeout = self.__limit + time.time()
598         while not self.ready:
599             time.sleep(0.1)
600         return self.value
601
602     def cancel(self):
603         """Terminate any possible execution of the embedded function."""
604         if self.__process.is_alive():
605             self.__process.terminate()
606         _raise_exception(self.__timeout_exception, self.__error_message)
607
608     @property
609     def ready(self):
610         """Read-only property indicating status of "value" property."""
611         if self.__limit and self.__timeout < time.time():
612             self.cancel()
613         return self.__queue.full() and not self.__queue.empty()
614
615     @property
616     def value(self):
617         """Read-only property containing data returned from function."""
618         if self.ready is True:
619             flag, load = self.__queue.get()
620             if flag:
621                 return load
622             raise load
623
624
625 def timeout(
626     seconds: float = 1.0,
627     use_signals: Optional[bool] = None,
628     timeout_exception=exceptions.TimeoutError,
629     error_message="Function call timed out",
630 ):
631     """Add a timeout parameter to a function and return the function.
632
633     Note: the use_signals parameter is included in order to support
634     multiprocessing scenarios (signal can only be used from the process'
635     main thread).  When not using signals, timeout granularity will be
636     rounded to the nearest 0.1s.
637
638     Raises an exception when/if the timeout is reached.
639
640     It is illegal to pass anything other than a function as the first
641     parameter.  The function is wrapped and returned to the caller.
642
643     >>> @timeout(0.2)
644     ... def foo(delay: float):
645     ...     time.sleep(delay)
646     ...     return "ok"
647
648     >>> foo(0)
649     'ok'
650
651     >>> foo(1.0)
652     Traceback (most recent call last):
653     ...
654     Exception: Function call timed out
655
656     """
657     if use_signals is None:
658         import thread_utils
659
660         use_signals = thread_utils.is_current_thread_main_thread()
661
662     def decorate(function):
663         if use_signals:
664
665             def handler(signum, frame):
666                 _raise_exception(timeout_exception, error_message)
667
668             @functools.wraps(function)
669             def new_function(*args, **kwargs):
670                 new_seconds = kwargs.pop("timeout", seconds)
671                 if new_seconds:
672                     old = signal.signal(signal.SIGALRM, handler)
673                     signal.setitimer(signal.ITIMER_REAL, new_seconds)
674
675                 if not seconds:
676                     return function(*args, **kwargs)
677
678                 try:
679                     return function(*args, **kwargs)
680                 finally:
681                     if new_seconds:
682                         signal.setitimer(signal.ITIMER_REAL, 0)
683                         signal.signal(signal.SIGALRM, old)
684
685             return new_function
686         else:
687
688             @functools.wraps(function)
689             def new_function(*args, **kwargs):
690                 timeout_wrapper = _Timeout(function, timeout_exception, error_message, seconds)
691                 return timeout_wrapper(*args, **kwargs)
692
693             return new_function
694
695     return decorate
696
697
698 def synchronized(lock):
699     def wrap(f):
700         @functools.wraps(f)
701         def _gatekeeper(*args, **kw):
702             lock.acquire()
703             try:
704                 return f(*args, **kw)
705             finally:
706                 lock.release()
707
708         return _gatekeeper
709
710     return wrap
711
712
713 def call_with_sample_rate(sample_rate: float) -> Callable:
714     if not 0.0 <= sample_rate <= 1.0:
715         msg = f"sample_rate must be between [0, 1]. Got {sample_rate}."
716         logger.critical(msg)
717         raise ValueError(msg)
718
719     def decorator(f):
720         @functools.wraps(f)
721         def _call_with_sample_rate(*args, **kwargs):
722             if random.uniform(0, 1) < sample_rate:
723                 return f(*args, **kwargs)
724             else:
725                 logger.debug(f"@call_with_sample_rate skipping a call to {f.__name__}")
726
727         return _call_with_sample_rate
728
729     return decorator
730
731
732 def decorate_matching_methods_with(decorator, acl=None):
733     """Apply decorator to all methods in a class whose names begin with
734     prefix.  If prefix is None (default), decorate all methods in the
735     class.
736     """
737
738     def decorate_the_class(cls):
739         for name, m in inspect.getmembers(cls, inspect.isfunction):
740             if acl is None:
741                 setattr(cls, name, decorator(m))
742             else:
743                 if acl(name):
744                     setattr(cls, name, decorator(m))
745         return cls
746
747     return decorate_the_class
748
749
750 if __name__ == '__main__':
751     import doctest
752
753     doctest.testmod()