1e0fe18c4063b285e84d561b30b39e5b4245d001
[python_utils.git] / decorator_utils.py
1 #!/usr/bin/env python3
2
3 """Decorators."""
4
5 import enum
6 import functools
7 import inspect
8 import logging
9 import math
10 import multiprocessing
11 import random
12 import signal
13 import sys
14 import threading
15 import time
16 import traceback
17 from typing import Any, Callable, Optional
18 import warnings
19
20 # This module is commonly used by others in here and should avoid
21 # taking any unnecessary dependencies back on them.
22 import exceptions
23
24
25 logger = logging.getLogger(__name__)
26
27
28 def timed(func: Callable) -> Callable:
29     """Print the runtime of the decorated function."""
30
31     @functools.wraps(func)
32     def wrapper_timer(*args, **kwargs):
33         start_time = time.perf_counter()
34         value = func(*args, **kwargs)
35         end_time = time.perf_counter()
36         run_time = end_time - start_time
37         msg = f"Finished {func.__name__!r} in {run_time:.4f}s"
38         print(msg)
39         logger.info(msg)
40         return value
41     return wrapper_timer
42
43
44 def invocation_logged(func: Callable) -> Callable:
45     """Log the call of a function."""
46
47     @functools.wraps(func)
48     def wrapper_invocation_logged(*args, **kwargs):
49         msg = f"Entered {func.__qualname__}"
50         print(msg)
51         logger.info(msg)
52         ret = func(*args, **kwargs)
53         msg = f"Exited {func.__qualname__}"
54         print(msg)
55         logger.info(msg)
56         return ret
57     return wrapper_invocation_logged
58
59
60 def rate_limited(n_calls: int, *, per_period_in_seconds: float = 1.0) -> Callable:
61     """Limit invocation of a wrapped function to n calls per period.
62     Thread safe.  In testing this was relatively fair with multiple
63     threads using it though that hasn't been measured.
64
65     >>> import time
66     >>> import decorator_utils
67     >>> import thread_utils
68
69     >>> calls = 0
70
71     >>> @decorator_utils.rate_limited(1, per_period_in_seconds=1.0)
72     ... def limited(x: int):
73     ...     global calls
74     ...     calls += 1
75
76     >>> @thread_utils.background_thread
77     ... def a(stop):
78     ...     for _ in range(3):
79     ...         limited(_)
80
81     >>> @thread_utils.background_thread
82     ... def b(stop):
83     ...     for _ in range(3):
84     ...         limited(_)
85
86     >>> start = time.time()
87     >>> (t1, e1) = a()
88     >>> (t2, e2) = b()
89     >>> t1.join()
90     >>> t2.join()
91     >>> end = time.time()
92     >>> dur = end - start
93     >>> dur > 5.0
94     True
95
96     >>> calls
97     6
98
99     """
100     min_interval_seconds = per_period_in_seconds / float(n_calls)
101
102     def wrapper_rate_limited(func: Callable) -> Callable:
103         cv = threading.Condition()
104         last_invocation_timestamp = [0.0]
105
106         def may_proceed() -> float:
107             now = time.time()
108             last_invocation = last_invocation_timestamp[0]
109             if last_invocation != 0.0:
110                 elapsed_since_last = now - last_invocation
111                 wait_time = min_interval_seconds - elapsed_since_last
112             else:
113                 wait_time = 0.0
114             logger.debug(f'@{time.time()}> wait_time = {wait_time}')
115             return wait_time
116
117         def wrapper_wrapper_rate_limited(*args, **kargs) -> Any:
118             with cv:
119                 while True:
120                     if cv.wait_for(
121                         lambda: may_proceed() <= 0.0,
122                         timeout=may_proceed(),
123                     ):
124                         break
125             with cv:
126                 logger.debug(f'@{time.time()}> calling it...')
127                 ret = func(*args, **kargs)
128                 last_invocation_timestamp[0] = time.time()
129                 logger.debug(
130                     f'@{time.time()}> Last invocation <- {last_invocation_timestamp[0]}'
131                 )
132                 cv.notify()
133             return ret
134         return wrapper_wrapper_rate_limited
135     return wrapper_rate_limited
136
137
138 def debug_args(func: Callable) -> Callable:
139     """Print the function signature and return value at each call."""
140
141     @functools.wraps(func)
142     def wrapper_debug_args(*args, **kwargs):
143         args_repr = [f"{repr(a)}:{type(a)}" for a in args]
144         kwargs_repr = [f"{k}={v!r}:{type(v)}" for k, v in kwargs.items()]
145         signature = ", ".join(args_repr + kwargs_repr)
146         msg = f"Calling {func.__name__}({signature})"
147         print(msg)
148         logger.info(msg)
149         value = func(*args, **kwargs)
150         msg = f"{func.__name__!r} returned {value!r}:{type(value)}"
151         logger.info(msg)
152         return value
153     return wrapper_debug_args
154
155
156 def debug_count_calls(func: Callable) -> Callable:
157     """Count function invocations and print a message befor every call."""
158
159     @functools.wraps(func)
160     def wrapper_debug_count_calls(*args, **kwargs):
161         wrapper_debug_count_calls.num_calls += 1
162         msg = f"Call #{wrapper_debug_count_calls.num_calls} of {func.__name__!r}"
163         print(msg)
164         logger.info(msg)
165         return func(*args, **kwargs)
166     wrapper_debug_count_calls.num_calls = 0
167     return wrapper_debug_count_calls
168
169
170 class DelayWhen(enum.Enum):
171     BEFORE_CALL = 1
172     AFTER_CALL = 2
173     BEFORE_AND_AFTER = 3
174
175
176 def delay(
177     _func: Callable = None,
178     *,
179     seconds: float = 1.0,
180     when: DelayWhen = DelayWhen.BEFORE_CALL,
181 ) -> Callable:
182     """Delay the execution of a function by sleeping before and/or after.
183
184     Slow down a function by inserting a delay before and/or after its
185     invocation.
186     """
187
188     def decorator_delay(func: Callable) -> Callable:
189         @functools.wraps(func)
190         def wrapper_delay(*args, **kwargs):
191             if when & DelayWhen.BEFORE_CALL:
192                 logger.debug(
193                     f"@delay for {seconds}s BEFORE_CALL to {func.__name__}"
194                 )
195                 time.sleep(seconds)
196             retval = func(*args, **kwargs)
197             if when & DelayWhen.AFTER_CALL:
198                 logger.debug(
199                     f"@delay for {seconds}s AFTER_CALL to {func.__name__}"
200                 )
201                 time.sleep(seconds)
202             return retval
203         return wrapper_delay
204
205     if _func is None:
206         return decorator_delay
207     else:
208         return decorator_delay(_func)
209
210
211 class _SingletonWrapper:
212     """
213     A singleton wrapper class. Its instances would be created
214     for each decorated class.
215     """
216
217     def __init__(self, cls):
218         self.__wrapped__ = cls
219         self._instance = None
220
221     def __call__(self, *args, **kwargs):
222         """Returns a single instance of decorated class"""
223         logger.debug(
224             f"@singleton returning global instance of {self.__wrapped__.__name__}"
225         )
226         if self._instance is None:
227             self._instance = self.__wrapped__(*args, **kwargs)
228         return self._instance
229
230
231 def singleton(cls):
232     """
233     A singleton decorator. Returns a wrapper objects. A call on that object
234     returns a single instance object of decorated class. Use the __wrapped__
235     attribute to access decorated class directly in unit tests
236     """
237     return _SingletonWrapper(cls)
238
239
240 def memoized(func: Callable) -> Callable:
241     """Keep a cache of previous function call results.
242
243     The cache here is a dict with a key based on the arguments to the
244     call.  Consider also: functools.lru_cache for a more advanced
245     implementation.
246     """
247
248     @functools.wraps(func)
249     def wrapper_memoized(*args, **kwargs):
250         cache_key = args + tuple(kwargs.items())
251         if cache_key not in wrapper_memoized.cache:
252             value = func(*args, **kwargs)
253             logger.debug(
254                 f"Memoizing {cache_key} => {value} for {func.__name__}"
255             )
256             wrapper_memoized.cache[cache_key] = value
257         else:
258             logger.debug(f"Returning memoized value for {func.__name__}")
259         return wrapper_memoized.cache[cache_key]
260     wrapper_memoized.cache = dict()
261     return wrapper_memoized
262
263
264 def retry_predicate(
265     tries: int,
266     *,
267     predicate: Callable[..., bool],
268     delay_sec: float = 3.0,
269     backoff: float = 2.0,
270 ):
271     """Retries a function or method up to a certain number of times
272     with a prescribed initial delay period and backoff rate.
273
274     tries is the maximum number of attempts to run the function.
275     delay_sec sets the initial delay period in seconds.
276     backoff is a multiplied (must be >1) used to modify the delay.
277     predicate is a function that will be passed the retval of the
278     decorated function and must return True to stop or False to
279     retry.
280     """
281     if backoff < 1.0:
282         msg = f"backoff must be greater than or equal to 1, got {backoff}"
283         logger.critical(msg)
284         raise ValueError(msg)
285
286     tries = math.floor(tries)
287     if tries < 0:
288         msg = f"tries must be 0 or greater, got {tries}"
289         logger.critical(msg)
290         raise ValueError(msg)
291
292     if delay_sec <= 0:
293         msg = f"delay_sec must be greater than 0, got {delay_sec}"
294         logger.critical(msg)
295         raise ValueError(msg)
296
297     def deco_retry(f):
298         @functools.wraps(f)
299         def f_retry(*args, **kwargs):
300             mtries, mdelay = tries, delay_sec  # make mutable
301             logger.debug(f'deco_retry: will make up to {mtries} attempts...')
302             retval = f(*args, **kwargs)
303             while mtries > 0:
304                 if predicate(retval) is True:
305                     logger.debug('Predicate succeeded, deco_retry is done.')
306                     return retval
307                 logger.debug("Predicate failed, sleeping and retrying.")
308                 mtries -= 1
309                 time.sleep(mdelay)
310                 mdelay *= backoff
311                 retval = f(*args, **kwargs)
312             return retval
313         return f_retry
314     return deco_retry
315
316
317 def retry_if_false(tries: int, *, delay_sec=3.0, backoff=2.0):
318     return retry_predicate(
319         tries,
320         predicate=lambda x: x is True,
321         delay_sec=delay_sec,
322         backoff=backoff,
323     )
324
325
326 def retry_if_none(tries: int, *, delay_sec=3.0, backoff=2.0):
327     return retry_predicate(
328         tries,
329         predicate=lambda x: x is not None,
330         delay_sec=delay_sec,
331         backoff=backoff,
332     )
333
334
335 def deprecated(func):
336     """This is a decorator which can be used to mark functions
337     as deprecated. It will result in a warning being emitted
338     when the function is used.
339     """
340
341     @functools.wraps(func)
342     def wrapper_deprecated(*args, **kwargs):
343         msg = f"Call to deprecated function {func.__name__}"
344         logger.warning(msg)
345         warnings.warn(msg, category=DeprecationWarning)
346         return func(*args, **kwargs)
347
348     return wrapper_deprecated
349
350
351 def thunkify(func):
352     """
353     Make a function immediately return a function of no args which,
354     when called, waits for the result, which will start being
355     processed in another thread.
356     """
357
358     @functools.wraps(func)
359     def lazy_thunked(*args, **kwargs):
360         wait_event = threading.Event()
361
362         result = [None]
363         exc = [False, None]
364
365         def worker_func():
366             try:
367                 func_result = func(*args, **kwargs)
368                 result[0] = func_result
369             except Exception:
370                 exc[0] = True
371                 exc[1] = sys.exc_info()  # (type, value, traceback)
372                 msg = f"Thunkify has thrown an exception (will be raised on thunk()):\n{traceback.format_exc()}"
373                 print(msg)
374                 logger.warning(msg)
375             finally:
376                 wait_event.set()
377
378         def thunk():
379             wait_event.wait()
380             if exc[0]:
381                 raise exc[1][0](exc[1][1])
382             return result[0]
383
384         threading.Thread(target=worker_func).start()
385         return thunk
386
387     return lazy_thunked
388
389
390 ############################################################
391 # Timeout
392 ############################################################
393
394 # http://www.saltycrane.com/blog/2010/04/using-python-timeout-decorator-uploading-s3/
395 # Used work of Stephen "Zero" Chappell <[email protected]>
396 # in https://code.google.com/p/verse-quiz/source/browse/trunk/timeout.py
397
398
399 def _raise_exception(exception, error_message: Optional[str]):
400     if error_message is None:
401         raise exception()
402     else:
403         raise exception(error_message)
404
405
406 def _target(queue, function, *args, **kwargs):
407     """Run a function with arguments and return output via a queue.
408
409     This is a helper function for the Process created in _Timeout. It runs
410     the function with positional arguments and keyword arguments and then
411     returns the function's output by way of a queue. If an exception gets
412     raised, it is returned to _Timeout to be raised by the value property.
413     """
414     try:
415         queue.put((True, function(*args, **kwargs)))
416     except Exception:
417         queue.put((False, sys.exc_info()[1]))
418
419
420 class _Timeout(object):
421     """Wrap a function and add a timeout (limit) attribute to it.
422
423     Instances of this class are automatically generated by the add_timeout
424     function defined below.
425     """
426
427     def __init__(
428         self,
429         function: Callable,
430         timeout_exception: Exception,
431         error_message: str,
432         seconds: float,
433     ):
434         self.__limit = seconds
435         self.__function = function
436         self.__timeout_exception = timeout_exception
437         self.__error_message = error_message
438         self.__name__ = function.__name__
439         self.__doc__ = function.__doc__
440         self.__timeout = time.time()
441         self.__process = multiprocessing.Process()
442         self.__queue: multiprocessing.queues.Queue = multiprocessing.Queue()
443
444     def __call__(self, *args, **kwargs):
445         """Execute the embedded function object asynchronously.
446
447         The function given to the constructor is transparently called and
448         requires that "ready" be intermittently polled. If and when it is
449         True, the "value" property may then be checked for returned data.
450         """
451         self.__limit = kwargs.pop("timeout", self.__limit)
452         self.__queue = multiprocessing.Queue(1)
453         args = (self.__queue, self.__function) + args
454         self.__process = multiprocessing.Process(
455             target=_target, args=args, kwargs=kwargs
456         )
457         self.__process.daemon = True
458         self.__process.start()
459         if self.__limit is not None:
460             self.__timeout = self.__limit + time.time()
461         while not self.ready:
462             time.sleep(0.1)
463         return self.value
464
465     def cancel(self):
466         """Terminate any possible execution of the embedded function."""
467         if self.__process.is_alive():
468             self.__process.terminate()
469         _raise_exception(self.__timeout_exception, self.__error_message)
470
471     @property
472     def ready(self):
473         """Read-only property indicating status of "value" property."""
474         if self.__limit and self.__timeout < time.time():
475             self.cancel()
476         return self.__queue.full() and not self.__queue.empty()
477
478     @property
479     def value(self):
480         """Read-only property containing data returned from function."""
481         if self.ready is True:
482             flag, load = self.__queue.get()
483             if flag:
484                 return load
485             raise load
486
487
488 def timeout(
489     seconds: float = 1.0,
490     use_signals: Optional[bool] = None,
491     timeout_exception=exceptions.TimeoutError,
492     error_message="Function call timed out",
493 ):
494     """Add a timeout parameter to a function and return the function.
495
496     Note: the use_signals parameter is included in order to support
497     multiprocessing scenarios (signal can only be used from the process'
498     main thread).  When not using signals, timeout granularity will be
499     rounded to the nearest 0.1s.
500
501     Raises an exception when the timeout is reached.
502
503     It is illegal to pass anything other than a function as the first
504     parameter.  The function is wrapped and returned to the caller.
505     """
506     if use_signals is None:
507         import thread_utils
508         use_signals = thread_utils.is_current_thread_main_thread()
509
510     def decorate(function):
511         if use_signals:
512
513             def handler(signum, frame):
514                 _raise_exception(timeout_exception, error_message)
515
516             @functools.wraps(function)
517             def new_function(*args, **kwargs):
518                 new_seconds = kwargs.pop("timeout", seconds)
519                 if new_seconds:
520                     old = signal.signal(signal.SIGALRM, handler)
521                     signal.setitimer(signal.ITIMER_REAL, new_seconds)
522
523                 if not seconds:
524                     return function(*args, **kwargs)
525
526                 try:
527                     return function(*args, **kwargs)
528                 finally:
529                     if new_seconds:
530                         signal.setitimer(signal.ITIMER_REAL, 0)
531                         signal.signal(signal.SIGALRM, old)
532
533             return new_function
534         else:
535
536             @functools.wraps(function)
537             def new_function(*args, **kwargs):
538                 timeout_wrapper = _Timeout(
539                     function, timeout_exception, error_message, seconds
540                 )
541                 return timeout_wrapper(*args, **kwargs)
542
543             return new_function
544
545     return decorate
546
547
548 class non_reentrant_code(object):
549     def __init__(self):
550         self._lock = threading.RLock
551         self._entered = False
552
553     def __call__(self, f):
554         def _gatekeeper(*args, **kwargs):
555             with self._lock:
556                 if self._entered:
557                     return
558                 self._entered = True
559                 f(*args, **kwargs)
560                 self._entered = False
561
562         return _gatekeeper
563
564
565 class rlocked(object):
566     def __init__(self):
567         self._lock = threading.RLock
568         self._entered = False
569
570     def __call__(self, f):
571         def _gatekeeper(*args, **kwargs):
572             with self._lock:
573                 if self._entered:
574                     return
575                 self._entered = True
576                 f(*args, **kwargs)
577                 self._entered = False
578         return _gatekeeper
579
580
581 def call_with_sample_rate(sample_rate: float) -> Callable:
582     if not 0.0 <= sample_rate <= 1.0:
583         msg = f"sample_rate must be between [0, 1]. Got {sample_rate}."
584         logger.critical(msg)
585         raise ValueError(msg)
586
587     def decorator(f):
588         @functools.wraps(f)
589         def _call_with_sample_rate(*args, **kwargs):
590             if random.uniform(0, 1) < sample_rate:
591                 return f(*args, **kwargs)
592             else:
593                 logger.debug(
594                     f"@call_with_sample_rate skipping a call to {f.__name__}"
595                 )
596         return _call_with_sample_rate
597     return decorator
598
599
600 def decorate_matching_methods_with(decorator, acl=None):
601     """Apply decorator to all methods in a class whose names begin with
602     prefix.  If prefix is None (default), decorate all methods in the
603     class.
604     """
605     def decorate_the_class(cls):
606         for name, m in inspect.getmembers(cls, inspect.isfunction):
607             if acl is None:
608                 setattr(cls, name, decorator(m))
609             else:
610                 if acl(name):
611                     setattr(cls, name, decorator(m))
612         return cls
613     return decorate_the_class
614
615
616 if __name__ == '__main__':
617     import doctest
618     doctest.testmod()
619
620