375cbad436feca20d57750b8276537c94f66f060
[python_utils.git] / decorator_utils.py
1 #!/usr/bin/env python3
2
3 """Decorators."""
4
5 import datetime
6 import enum
7 import functools
8 import inspect
9 import logging
10 import math
11 import multiprocessing
12 import random
13 import signal
14 import sys
15 import threading
16 import time
17 import traceback
18 from typing import Callable, Optional
19 import warnings
20
21 import exceptions
22
23
24 logger = logging.getLogger(__name__)
25
26
27 def timed(func: Callable) -> Callable:
28     """Print the runtime of the decorated function."""
29
30     @functools.wraps(func)
31     def wrapper_timer(*args, **kwargs):
32         start_time = time.perf_counter()
33         value = func(*args, **kwargs)
34         end_time = time.perf_counter()
35         run_time = end_time - start_time
36         msg = f"Finished {func.__name__!r} in {run_time:.4f}s"
37         print(msg)
38         logger.info(msg)
39         return value
40     return wrapper_timer
41
42
43 def invocation_logged(func: Callable) -> Callable:
44     """Log the call of a function."""
45
46     @functools.wraps(func)
47     def wrapper_invocation_logged(*args, **kwargs):
48         now = datetime.datetime.now()
49         ts = now.strftime("%Y/%d/%b:%H:%M:%S%Z")
50         msg = f"[{ts}]: Entered {func.__name__}"
51         print(msg)
52         logger.info(msg)
53         ret = func(*args, **kwargs)
54         now = datetime.datetime.now()
55         ts = now.strftime("%Y/%d/%b:%H:%M:%S%Z")
56         msg = f"[{ts}]: Exited {func.__name__}"
57         print(msg)
58         logger.info(msg)
59         return ret
60     return wrapper_invocation_logged
61
62
63 def debug_args(func: Callable) -> Callable:
64     """Print the function signature and return value at each call."""
65
66     @functools.wraps(func)
67     def wrapper_debug_args(*args, **kwargs):
68         args_repr = [f"{repr(a)}:{type(a)}" for a in args]
69         kwargs_repr = [f"{k}={v!r}:{type(v)}" for k, v in kwargs.items()]
70         signature = ", ".join(args_repr + kwargs_repr)
71         msg = f"Calling {func.__name__}({signature})"
72         print(msg)
73         logger.info(msg)
74         value = func(*args, **kwargs)
75         msg = f"{func.__name__!r} returned {value!r}:{type(value)}"
76         logger.info(msg)
77         return value
78     return wrapper_debug_args
79
80
81 def debug_count_calls(func: Callable) -> Callable:
82     """Count function invocations and print a message befor every call."""
83
84     @functools.wraps(func)
85     def wrapper_debug_count_calls(*args, **kwargs):
86         wrapper_debug_count_calls.num_calls += 1
87         msg = f"Call #{wrapper_debug_count_calls.num_calls} of {func.__name__!r}"
88         print(msg)
89         logger.info(msg)
90         return func(*args, **kwargs)
91     wrapper_debug_count_calls.num_calls = 0
92     return wrapper_debug_count_calls
93
94
95 class DelayWhen(enum.Enum):
96     BEFORE_CALL = 1
97     AFTER_CALL = 2
98     BEFORE_AND_AFTER = 3
99
100
101 def delay(
102     _func: Callable = None,
103     *,
104     seconds: float = 1.0,
105     when: DelayWhen = DelayWhen.BEFORE_CALL,
106 ) -> Callable:
107     """Delay the execution of a function by sleeping before and/or after.
108
109     Slow down a function by inserting a delay before and/or after its
110     invocation.
111     """
112
113     def decorator_delay(func: Callable) -> Callable:
114         @functools.wraps(func)
115         def wrapper_delay(*args, **kwargs):
116             if when & DelayWhen.BEFORE_CALL:
117                 logger.debug(
118                     f"@delay for {seconds}s BEFORE_CALL to {func.__name__}"
119                 )
120                 time.sleep(seconds)
121             retval = func(*args, **kwargs)
122             if when & DelayWhen.AFTER_CALL:
123                 logger.debug(
124                     f"@delay for {seconds}s AFTER_CALL to {func.__name__}"
125                 )
126                 time.sleep(seconds)
127             return retval
128         return wrapper_delay
129
130     if _func is None:
131         return decorator_delay
132     else:
133         return decorator_delay(_func)
134
135
136 class _SingletonWrapper:
137     """
138     A singleton wrapper class. Its instances would be created
139     for each decorated class.
140     """
141
142     def __init__(self, cls):
143         self.__wrapped__ = cls
144         self._instance = None
145
146     def __call__(self, *args, **kwargs):
147         """Returns a single instance of decorated class"""
148         logger.debug(
149             f"@singleton returning global instance of {self.__wrapped__.__name__}"
150         )
151         if self._instance is None:
152             self._instance = self.__wrapped__(*args, **kwargs)
153         return self._instance
154
155
156 def singleton(cls):
157     """
158     A singleton decorator. Returns a wrapper objects. A call on that object
159     returns a single instance object of decorated class. Use the __wrapped__
160     attribute to access decorated class directly in unit tests
161     """
162     return _SingletonWrapper(cls)
163
164
165 def memoized(func: Callable) -> Callable:
166     """Keep a cache of previous function call results.
167
168     The cache here is a dict with a key based on the arguments to the
169     call.  Consider also: functools.lru_cache for a more advanced
170     implementation.
171     """
172
173     @functools.wraps(func)
174     def wrapper_memoized(*args, **kwargs):
175         cache_key = args + tuple(kwargs.items())
176         if cache_key not in wrapper_memoized.cache:
177             value = func(*args, **kwargs)
178             logger.debug(
179                 f"Memoizing {cache_key} => {value} for {func.__name__}"
180             )
181             wrapper_memoized.cache[cache_key] = value
182         else:
183             logger.debug(f"Returning memoized value for {func.__name__}")
184         return wrapper_memoized.cache[cache_key]
185     wrapper_memoized.cache = dict()
186     return wrapper_memoized
187
188
189 def retry_predicate(
190     tries: int,
191     *,
192     predicate: Callable[..., bool],
193     delay_sec: float = 3,
194     backoff: float = 2.0,
195 ):
196     """Retries a function or method up to a certain number of times
197     with a prescribed initial delay period and backoff rate.
198
199     tries is the maximum number of attempts to run the function.
200     delay_sec sets the initial delay period in seconds.
201     backoff is a multiplied (must be >1) used to modify the delay.
202     predicate is a function that will be passed the retval of the
203       decorated function and must return True to stop or False to
204       retry.
205     """
206     if backoff < 1:
207         msg = f"backoff must be greater than or equal to 1, got {backoff}"
208         logger.critical(msg)
209         raise ValueError(msg)
210
211     tries = math.floor(tries)
212     if tries < 0:
213         msg = f"tries must be 0 or greater, got {tries}"
214         logger.critical(msg)
215         raise ValueError(msg)
216
217     if delay_sec <= 0:
218         msg = f"delay_sec must be greater than 0, got {delay_sec}"
219         logger.critical(msg)
220         raise ValueError(msg)
221
222     def deco_retry(f):
223         @functools.wraps(f)
224         def f_retry(*args, **kwargs):
225             mtries, mdelay = tries, delay_sec  # make mutable
226             retval = f(*args, **kwargs)
227             while mtries > 0:
228                 if predicate(retval) is True:
229                     return retval
230                 logger.debug("Predicate failed, sleeping and retrying.")
231                 mtries -= 1
232                 time.sleep(mdelay)
233                 mdelay *= backoff
234                 retval = f(*args, **kwargs)
235             return retval
236         return f_retry
237     return deco_retry
238
239
240 def retry_if_false(tries: int, *, delay_sec=3.0, backoff=2.0):
241     return retry_predicate(
242         tries,
243         predicate=lambda x: x is True,
244         delay_sec=delay_sec,
245         backoff=backoff,
246     )
247
248
249 def retry_if_none(tries: int, *, delay_sec=3.0, backoff=2.0):
250     return retry_predicate(
251         tries,
252         predicate=lambda x: x is not None,
253         delay_sec=delay_sec,
254         backoff=backoff,
255     )
256
257
258 def deprecated(func):
259     """This is a decorator which can be used to mark functions
260     as deprecated. It will result in a warning being emitted
261     when the function is used.
262     """
263
264     @functools.wraps(func)
265     def wrapper_deprecated(*args, **kwargs):
266         msg = f"Call to deprecated function {func.__name__}"
267         logger.warning(msg)
268         warnings.warn(msg, category=DeprecationWarning)
269         return func(*args, **kwargs)
270
271     return wrapper_deprecated
272
273
274 def thunkify(func):
275     """
276     Make a function immediately return a function of no args which,
277     when called, waits for the result, which will start being
278     processed in another thread.
279     """
280
281     @functools.wraps(func)
282     def lazy_thunked(*args, **kwargs):
283         wait_event = threading.Event()
284
285         result = [None]
286         exc = [False, None]
287
288         def worker_func():
289             try:
290                 func_result = func(*args, **kwargs)
291                 result[0] = func_result
292             except Exception:
293                 exc[0] = True
294                 exc[1] = sys.exc_info()  # (type, value, traceback)
295                 msg = f"Thunkify has thrown an exception (will be raised on thunk()):\n{traceback.format_exc()}"
296                 logger.warning(msg)
297                 print(msg)
298             finally:
299                 wait_event.set()
300
301         def thunk():
302             wait_event.wait()
303             if exc[0]:
304                 raise exc[1][0](exc[1][1])
305             return result[0]
306
307         threading.Thread(target=worker_func).start()
308         return thunk
309
310     return lazy_thunked
311
312
313 ############################################################
314 # Timeout
315 ############################################################
316
317 # http://www.saltycrane.com/blog/2010/04/using-python-timeout-decorator-uploading-s3/
318 # Used work of Stephen "Zero" Chappell <[email protected]>
319 # in https://code.google.com/p/verse-quiz/source/browse/trunk/timeout.py
320
321
322 def _raise_exception(exception, error_message: Optional[str]):
323     if error_message is None:
324         raise exception()
325     else:
326         raise exception(error_message)
327
328
329 def _target(queue, function, *args, **kwargs):
330     """Run a function with arguments and return output via a queue.
331
332     This is a helper function for the Process created in _Timeout. It runs
333     the function with positional arguments and keyword arguments and then
334     returns the function's output by way of a queue. If an exception gets
335     raised, it is returned to _Timeout to be raised by the value property.
336     """
337     try:
338         queue.put((True, function(*args, **kwargs)))
339     except:
340         queue.put((False, sys.exc_info()[1]))
341
342
343 class _Timeout(object):
344     """Wrap a function and add a timeout (limit) attribute to it.
345
346     Instances of this class are automatically generated by the add_timeout
347     function defined below.
348     """
349
350     def __init__(
351         self,
352         function: Callable,
353         timeout_exception: Exception,
354         error_message: str,
355         seconds: float,
356     ):
357         self.__limit = seconds
358         self.__function = function
359         self.__timeout_exception = timeout_exception
360         self.__error_message = error_message
361         self.__name__ = function.__name__
362         self.__doc__ = function.__doc__
363         self.__timeout = time.time()
364         self.__process = multiprocessing.Process()
365         self.__queue: multiprocessing.queues.Queue = multiprocessing.Queue()
366
367     def __call__(self, *args, **kwargs):
368         """Execute the embedded function object asynchronously.
369
370         The function given to the constructor is transparently called and
371         requires that "ready" be intermittently polled. If and when it is
372         True, the "value" property may then be checked for returned data.
373         """
374         self.__limit = kwargs.pop("timeout", self.__limit)
375         self.__queue = multiprocessing.Queue(1)
376         args = (self.__queue, self.__function) + args
377         self.__process = multiprocessing.Process(
378             target=_target, args=args, kwargs=kwargs
379         )
380         self.__process.daemon = True
381         self.__process.start()
382         if self.__limit is not None:
383             self.__timeout = self.__limit + time.time()
384         while not self.ready:
385             time.sleep(0.1)
386         return self.value
387
388     def cancel(self):
389         """Terminate any possible execution of the embedded function."""
390         if self.__process.is_alive():
391             self.__process.terminate()
392         _raise_exception(self.__timeout_exception, self.__error_message)
393
394     @property
395     def ready(self):
396         """Read-only property indicating status of "value" property."""
397         if self.__limit and self.__timeout < time.time():
398             self.cancel()
399         return self.__queue.full() and not self.__queue.empty()
400
401     @property
402     def value(self):
403         """Read-only property containing data returned from function."""
404         if self.ready is True:
405             flag, load = self.__queue.get()
406             if flag:
407                 return load
408             raise load
409
410
411 def timeout(
412     seconds: float = 1.0,
413     use_signals: Optional[bool] = None,
414     timeout_exception=exceptions.TimeoutError,
415     error_message="Function call timed out",
416 ):
417     """Add a timeout parameter to a function and return the function.
418
419     Note: the use_signals parameter is included in order to support
420     multiprocessing scenarios (signal can only be used from the process'
421     main thread).  When not using signals, timeout granularity will be
422     rounded to the nearest 0.1s.
423
424     Raises an exception when the timeout is reached.
425
426     It is illegal to pass anything other than a function as the first
427     parameter.  The function is wrapped and returned to the caller.
428     """
429     if use_signals is None:
430         import thread_utils
431         use_signals = thread_utils.is_current_thread_main_thread()
432
433     def decorate(function):
434
435         if use_signals:
436
437             def handler(signum, frame):
438                 _raise_exception(timeout_exception, error_message)
439
440             @functools.wraps(function)
441             def new_function(*args, **kwargs):
442                 new_seconds = kwargs.pop("timeout", seconds)
443                 if new_seconds:
444                     old = signal.signal(signal.SIGALRM, handler)
445                     signal.setitimer(signal.ITIMER_REAL, new_seconds)
446
447                 if not seconds:
448                     return function(*args, **kwargs)
449
450                 try:
451                     return function(*args, **kwargs)
452                 finally:
453                     if new_seconds:
454                         signal.setitimer(signal.ITIMER_REAL, 0)
455                         signal.signal(signal.SIGALRM, old)
456
457             return new_function
458         else:
459
460             @functools.wraps(function)
461             def new_function(*args, **kwargs):
462                 timeout_wrapper = _Timeout(
463                     function, timeout_exception, error_message, seconds
464                 )
465                 return timeout_wrapper(*args, **kwargs)
466
467             return new_function
468
469     return decorate
470
471
472 class non_reentrant_code(object):
473     def __init__(self):
474         self._lock = threading.RLock
475         self._entered = False
476
477     def __call__(self, f):
478         def _gatekeeper(*args, **kwargs):
479             with self._lock:
480                 if self._entered:
481                     return
482                 self._entered = True
483                 f(*args, **kwargs)
484                 self._entered = False
485
486         return _gatekeeper
487
488
489 class rlocked(object):
490     def __init__(self):
491         self._lock = threading.RLock
492         self._entered = False
493
494     def __call__(self, f):
495         def _gatekeeper(*args, **kwargs):
496             with self._lock:
497                 if self._entered:
498                     return
499                 self._entered = True
500                 f(*args, **kwargs)
501                 self._entered = False
502         return _gatekeeper
503
504
505 def call_with_sample_rate(sample_rate: float) -> Callable:
506     if not 0.0 <= sample_rate <= 1.0:
507         msg = f"sample_rate must be between [0, 1]. Got {sample_rate}."
508         logger.critical(msg)
509         raise ValueError(msg)
510
511     def decorator(f):
512         @functools.wraps(f)
513         def _call_with_sample_rate(*args, **kwargs):
514             if random.uniform(0, 1) < sample_rate:
515                 return f(*args, **kwargs)
516             else:
517                 logger.debug(
518                     f"@call_with_sample_rate skipping a call to {f.__name__}"
519                 )
520         return _call_with_sample_rate
521     return decorator
522
523
524 def decorate_matching_methods_with(decorator, acl=None):
525     """Apply decorator to all methods in a class whose names begin with
526     prefix.  If prefix is None (default), decorate all methods in the
527     class.
528     """
529     def decorate_the_class(cls):
530         for name, m in inspect.getmembers(cls, inspect.isfunction):
531             if acl is None:
532                 setattr(cls, name, decorator(m))
533             else:
534                 if acl(name):
535                     setattr(cls, name, decorator(m))
536         return cls
537     return decorate_the_class