Initial revision
[python_utils.git] / decorator_utils.py
1 #!/usr/bin/env python3
2
3 """Decorators."""
4
5 import datetime
6 import enum
7 import functools
8 import logging
9 import math
10 import multiprocessing
11 import random
12 import signal
13 import sys
14 import threading
15 import time
16 import traceback
17 from typing import Callable, Optional
18 import warnings
19
20 import thread_utils
21
22 logger = logging.getLogger(__name__)
23
24
25 def timed(func: Callable) -> Callable:
26     """Print the runtime of the decorated function."""
27
28     @functools.wraps(func)
29     def wrapper_timer(*args, **kwargs):
30         start_time = time.perf_counter()
31         value = func(*args, **kwargs)
32         end_time = time.perf_counter()
33         run_time = end_time - start_time
34         msg = f"Finished {func.__name__!r} in {run_time:.4f}s"
35         print(msg)
36         logger.info(msg)
37         return value
38     return wrapper_timer
39
40
41 def invocation_logged(func: Callable) -> Callable:
42     """Log the call of a function."""
43
44     @functools.wraps(func)
45     def wrapper_invocation_logged(*args, **kwargs):
46         now = datetime.datetime.now()
47         ts = now.strftime("%Y/%d/%b:%H:%M:%S%Z")
48         msg = f"[{ts}]: Entered {func.__name__}"
49         print(msg)
50         logger.info(msg)
51         ret = func(*args, **kwargs)
52         now = datetime.datetime.now()
53         ts = now.strftime("%Y/%d/%b:%H:%M:%S%Z")
54         msg = f"[{ts}]: Exited {func.__name__}"
55         print(msg)
56         logger.info(msg)
57         return ret
58     return wrapper_invocation_logged
59
60
61 def debug_args(func: Callable) -> Callable:
62     """Print the function signature and return value at each call."""
63
64     @functools.wraps(func)
65     def wrapper_debug_args(*args, **kwargs):
66         args_repr = [f"{repr(a)}:{type(a)}" for a in args]
67         kwargs_repr = [f"{k}={v!r}:{type(v)}" for k, v in kwargs.items()]
68         signature = ", ".join(args_repr + kwargs_repr)
69         msg = f"Calling {func.__name__}({signature})"
70         print(msg)
71         logger.info(msg)
72         value = func(*args, **kwargs)
73         msg = f"{func.__name__!r} returned {value!r}:{type(value)}"
74         logger.info(msg)
75         return value
76     return wrapper_debug_args
77
78
79 def debug_count_calls(func: Callable) -> Callable:
80     """Count function invocations and print a message befor every call."""
81
82     @functools.wraps(func)
83     def wrapper_debug_count_calls(*args, **kwargs):
84         wrapper_debug_count_calls.num_calls += 1
85         msg = f"Call #{wrapper_debug_count_calls.num_calls} of {func.__name__!r}"
86         print(msg)
87         logger.info(msg)
88         return func(*args, **kwargs)
89     wrapper_debug_count_calls.num_calls = 0
90     return wrapper_debug_count_calls
91
92
93 class DelayWhen(enum.Enum):
94     BEFORE_CALL = 1
95     AFTER_CALL = 2
96     BEFORE_AND_AFTER = 3
97
98
99 def delay(
100     _func: Callable = None,
101     *,
102     seconds: float = 1.0,
103     when: DelayWhen = DelayWhen.BEFORE_CALL,
104 ) -> Callable:
105     """Delay the execution of a function by sleeping before and/or after.
106
107     Slow down a function by inserting a delay before and/or after its
108     invocation.
109     """
110
111     def decorator_delay(func: Callable) -> Callable:
112         @functools.wraps(func)
113         def wrapper_delay(*args, **kwargs):
114             if when & DelayWhen.BEFORE_CALL:
115                 logger.debug(
116                     f"@delay for {seconds}s BEFORE_CALL to {func.__name__}"
117                 )
118                 time.sleep(seconds)
119             retval = func(*args, **kwargs)
120             if when & DelayWhen.AFTER_CALL:
121                 logger.debug(
122                     f"@delay for {seconds}s AFTER_CALL to {func.__name__}"
123                 )
124                 time.sleep(seconds)
125             return retval
126         return wrapper_delay
127
128     if _func is None:
129         return decorator_delay
130     else:
131         return decorator_delay(_func)
132
133
134 class _SingletonWrapper:
135     """
136     A singleton wrapper class. Its instances would be created
137     for each decorated class.
138     """
139
140     def __init__(self, cls):
141         self.__wrapped__ = cls
142         self._instance = None
143
144     def __call__(self, *args, **kwargs):
145         """Returns a single instance of decorated class"""
146         logger.debug(
147             f"@singleton returning global instance of {self.__wrapped__.__name__}"
148         )
149         if self._instance is None:
150             self._instance = self.__wrapped__(*args, **kwargs)
151         return self._instance
152
153
154 def singleton(cls):
155     """
156     A singleton decorator. Returns a wrapper objects. A call on that object
157     returns a single instance object of decorated class. Use the __wrapped__
158     attribute to access decorated class directly in unit tests
159     """
160     return _SingletonWrapper(cls)
161
162
163 def memoized(func: Callable) -> Callable:
164     """Keep a cache of previous function call results.
165
166     The cache here is a dict with a key based on the arguments to the
167     call.  Consider also: functools.lru_cache for a more advanced
168     implementation.
169     """
170
171     @functools.wraps(func)
172     def wrapper_memoized(*args, **kwargs):
173         cache_key = args + tuple(kwargs.items())
174         if cache_key not in wrapper_memoized.cache:
175             value = func(*args, **kwargs)
176             logger.debug(
177                 f"Memoizing {cache_key} => {value} for {func.__name__}"
178             )
179             wrapper_memoized.cache[cache_key] = value
180         else:
181             logger.debug(f"Returning memoized value for {func.__name__}")
182         return wrapper_memoized.cache[cache_key]
183     wrapper_memoized.cache = dict()
184     return wrapper_memoized
185
186
187 def retry_predicate(
188     tries: int,
189     *,
190     predicate: Callable[..., bool],
191     delay_sec: float = 3,
192     backoff: float = 2.0,
193 ):
194     """Retries a function or method up to a certain number of times
195     with a prescribed initial delay period and backoff rate.
196
197     tries is the maximum number of attempts to run the function.
198     delay_sec sets the initial delay period in seconds.
199     backoff is a multiplied (must be >1) used to modify the delay.
200     predicate is a function that will be passed the retval of the
201       decorated function and must return True to stop or False to
202       retry.
203     """
204     if backoff < 1:
205         msg = f"backoff must be greater than or equal to 1, got {backoff}"
206         logger.critical(msg)
207         raise ValueError(msg)
208
209     tries = math.floor(tries)
210     if tries < 0:
211         msg = f"tries must be 0 or greater, got {tries}"
212         logger.critical(msg)
213         raise ValueError(msg)
214
215     if delay_sec <= 0:
216         msg = f"delay_sec must be greater than 0, got {delay_sec}"
217         logger.critical(msg)
218         raise ValueError(msg)
219
220     def deco_retry(f):
221         @functools.wraps(f)
222         def f_retry(*args, **kwargs):
223             mtries, mdelay = tries, delay_sec  # make mutable
224             retval = f(*args, **kwargs)
225             while mtries > 0:
226                 if predicate(retval) is True:
227                     return retval
228                 logger.debug("Predicate failed, sleeping and retrying.")
229                 mtries -= 1
230                 time.sleep(mdelay)
231                 mdelay *= backoff
232                 retval = f(*args, **kwargs)
233             return retval
234         return f_retry
235     return deco_retry
236
237
238 def retry_if_false(tries: int, *, delay_sec=3.0, backoff=2.0):
239     return retry_predicate(
240         tries,
241         predicate=lambda x: x is True,
242         delay_sec=delay_sec,
243         backoff=backoff,
244     )
245
246
247 def retry_if_none(tries: int, *, delay_sec=3.0, backoff=2.0):
248     return retry_predicate(
249         tries,
250         predicate=lambda x: x is not None,
251         delay_sec=delay_sec,
252         backoff=backoff,
253     )
254
255
256 def deprecated(func):
257     """This is a decorator which can be used to mark functions
258     as deprecated. It will result in a warning being emitted
259     when the function is used.
260     """
261
262     @functools.wraps(func)
263     def wrapper_deprecated(*args, **kwargs):
264         msg = f"Call to deprecated function {func.__name__}"
265         logger.warning(msg)
266         warnings.warn(msg, category=DeprecationWarning)
267         return func(*args, **kwargs)
268
269     return wrapper_deprecated
270
271
272 def thunkify(func):
273     """
274     Make a function immediately return a function of no args which,
275     when called, waits for the result, which will start being
276     processed in another thread.
277     """
278
279     @functools.wraps(func)
280     def lazy_thunked(*args, **kwargs):
281         wait_event = threading.Event()
282
283         result = [None]
284         exc = [False, None]
285
286         def worker_func():
287             try:
288                 func_result = func(*args, **kwargs)
289                 result[0] = func_result
290             except Exception:
291                 exc[0] = True
292                 exc[1] = sys.exc_info()  # (type, value, traceback)
293                 msg = f"Thunkify has thrown an exception (will be raised on thunk()):\n{traceback.format_exc()}"
294                 logger.warning(msg)
295                 print(msg)
296             finally:
297                 wait_event.set()
298
299         def thunk():
300             wait_event.wait()
301             if exc[0]:
302                 raise exc[1][0](exc[1][1])
303             return result[0]
304
305         threading.Thread(target=worker_func).start()
306         return thunk
307
308     return lazy_thunked
309
310
311 ############################################################
312 # Timeout
313 ############################################################
314
315 # http://www.saltycrane.com/blog/2010/04/using-python-timeout-decorator-uploading-s3/
316 # Used work of Stephen "Zero" Chappell <[email protected]>
317 # in https://code.google.com/p/verse-quiz/source/browse/trunk/timeout.py
318
319
320 class TimeoutError(AssertionError):
321     def __init__(self, value: str = "Timed Out"):
322         self.value = value
323
324     def __str__(self):
325         return repr(self.value)
326
327
328 def _raise_exception(exception, error_message: Optional[str]):
329     if error_message is None:
330         raise exception()
331     else:
332         raise exception(error_message)
333
334
335 def _target(queue, function, *args, **kwargs):
336     """Run a function with arguments and return output via a queue.
337
338     This is a helper function for the Process created in _Timeout. It runs
339     the function with positional arguments and keyword arguments and then
340     returns the function's output by way of a queue. If an exception gets
341     raised, it is returned to _Timeout to be raised by the value property.
342     """
343     try:
344         queue.put((True, function(*args, **kwargs)))
345     except:
346         queue.put((False, sys.exc_info()[1]))
347
348
349 class _Timeout(object):
350     """Wrap a function and add a timeout (limit) attribute to it.
351
352     Instances of this class are automatically generated by the add_timeout
353     function defined below.
354     """
355
356     def __init__(
357         self,
358         function: Callable,
359         timeout_exception: Exception,
360         error_message: str,
361         seconds: float,
362     ):
363         self.__limit = seconds
364         self.__function = function
365         self.__timeout_exception = timeout_exception
366         self.__error_message = error_message
367         self.__name__ = function.__name__
368         self.__doc__ = function.__doc__
369         self.__timeout = time.time()
370         self.__process = multiprocessing.Process()
371         self.__queue: multiprocessing.queues.Queue = multiprocessing.Queue()
372
373     def __call__(self, *args, **kwargs):
374         """Execute the embedded function object asynchronously.
375
376         The function given to the constructor is transparently called and
377         requires that "ready" be intermittently polled. If and when it is
378         True, the "value" property may then be checked for returned data.
379         """
380         self.__limit = kwargs.pop("timeout", self.__limit)
381         self.__queue = multiprocessing.Queue(1)
382         args = (self.__queue, self.__function) + args
383         self.__process = multiprocessing.Process(
384             target=_target, args=args, kwargs=kwargs
385         )
386         self.__process.daemon = True
387         self.__process.start()
388         if self.__limit is not None:
389             self.__timeout = self.__limit + time.time()
390         while not self.ready:
391             time.sleep(0.1)
392         return self.value
393
394     def cancel(self):
395         """Terminate any possible execution of the embedded function."""
396         if self.__process.is_alive():
397             self.__process.terminate()
398         _raise_exception(self.__timeout_exception, self.__error_message)
399
400     @property
401     def ready(self):
402         """Read-only property indicating status of "value" property."""
403         if self.__limit and self.__timeout < time.time():
404             self.cancel()
405         return self.__queue.full() and not self.__queue.empty()
406
407     @property
408     def value(self):
409         """Read-only property containing data returned from function."""
410         if self.ready is True:
411             flag, load = self.__queue.get()
412             if flag:
413                 return load
414             raise load
415
416
417 def timeout(
418     seconds: float = 1.0,
419     use_signals: Optional[bool] = None,
420     timeout_exception=TimeoutError,
421     error_message="Function call timed out",
422 ):
423     """Add a timeout parameter to a function and return the function.
424
425     Note: the use_signals parameter is included in order to support
426     multiprocessing scenarios (signal can only be used from the process'
427     main thread).  When not using signals, timeout granularity will be
428     rounded to the nearest 0.1s.
429
430     Raises an exception when the timeout is reached.
431
432     It is illegal to pass anything other than a function as the first
433     parameter.  The function is wrapped and returned to the caller.
434     """
435     if use_signals is None:
436         use_signals = thread_utils.is_current_thread_main_thread()
437
438     def decorate(function):
439
440         if use_signals:
441
442             def handler(signum, frame):
443                 _raise_exception(timeout_exception, error_message)
444
445             @functools.wraps(function)
446             def new_function(*args, **kwargs):
447                 new_seconds = kwargs.pop("timeout", seconds)
448                 if new_seconds:
449                     old = signal.signal(signal.SIGALRM, handler)
450                     signal.setitimer(signal.ITIMER_REAL, new_seconds)
451
452                 if not seconds:
453                     return function(*args, **kwargs)
454
455                 try:
456                     return function(*args, **kwargs)
457                 finally:
458                     if new_seconds:
459                         signal.setitimer(signal.ITIMER_REAL, 0)
460                         signal.signal(signal.SIGALRM, old)
461
462             return new_function
463         else:
464
465             @functools.wraps(function)
466             def new_function(*args, **kwargs):
467                 timeout_wrapper = _Timeout(
468                     function, timeout_exception, error_message, seconds
469                 )
470                 return timeout_wrapper(*args, **kwargs)
471
472             return new_function
473
474     return decorate
475
476
477 class non_reentrant_code(object):
478     def __init__(self):
479         self._lock = threading.RLock
480         self._entered = False
481
482     def __call__(self, f):
483         def _gatekeeper(*args, **kwargs):
484             with self._lock:
485                 if self._entered:
486                     return
487                 self._entered = True
488                 f(*args, **kwargs)
489                 self._entered = False
490
491         return _gatekeeper
492
493
494 class rlocked(object):
495     def __init__(self):
496         self._lock = threading.RLock
497         self._entered = False
498
499     def __call__(self, f):
500         def _gatekeeper(*args, **kwargs):
501             with self._lock:
502                 if self._entered:
503                     return
504                 self._entered = True
505                 f(*args, **kwargs)
506                 self._entered = False
507         return _gatekeeper
508
509
510 def call_with_sample_rate(sample_rate: float) -> Callable:
511     if not 0.0 <= sample_rate <= 1.0:
512         msg = f"sample_rate must be between [0, 1]. Got {sample_rate}."
513         logger.critical(msg)
514         raise ValueError(msg)
515
516     def decorator(f):
517         @functools.wraps(f)
518         def _call_with_sample_rate(*args, **kwargs):
519             if random.uniform(0, 1) < sample_rate:
520                 return f(*args, **kwargs)
521             else:
522                 logger.debug(
523                     f"@call_with_sample_rate skipping a call to {f.__name__}"
524                 )
525         return _call_with_sample_rate
526     return decorator