Overrides + debugging modules / functions in logging.
[python_utils.git] / decorator_utils.py
1 #!/usr/bin/env python3
2
3 """Decorators."""
4
5 import datetime
6 import enum
7 import functools
8 import inspect
9 import logging
10 import math
11 import multiprocessing
12 import random
13 import signal
14 import sys
15 import threading
16 import time
17 import traceback
18 from typing import Callable, Optional
19 import warnings
20
21 # This module is commonly used by others in here and should avoid
22 # taking any unnecessary dependencies back on them.
23 import exceptions
24
25
26 logger = logging.getLogger(__name__)
27
28
29 def timed(func: Callable) -> Callable:
30     """Print the runtime of the decorated function."""
31
32     @functools.wraps(func)
33     def wrapper_timer(*args, **kwargs):
34         start_time = time.perf_counter()
35         value = func(*args, **kwargs)
36         end_time = time.perf_counter()
37         run_time = end_time - start_time
38         msg = f"Finished {func.__name__!r} in {run_time:.4f}s"
39         print(msg)
40         logger.info(msg)
41         return value
42     return wrapper_timer
43
44
45 def invocation_logged(func: Callable) -> Callable:
46     """Log the call of a function."""
47
48     @functools.wraps(func)
49     def wrapper_invocation_logged(*args, **kwargs):
50         now = datetime.datetime.now()
51         ts = now.strftime("%Y/%d/%b:%H:%M:%S%Z")
52         msg = f"[{ts}]: Entered {func.__name__}"
53         print(msg)
54         logger.info(msg)
55         ret = func(*args, **kwargs)
56         now = datetime.datetime.now()
57         ts = now.strftime("%Y/%d/%b:%H:%M:%S%Z")
58         msg = f"[{ts}]: Exited {func.__name__}"
59         print(msg)
60         logger.info(msg)
61         return ret
62     return wrapper_invocation_logged
63
64
65 def debug_args(func: Callable) -> Callable:
66     """Print the function signature and return value at each call."""
67
68     @functools.wraps(func)
69     def wrapper_debug_args(*args, **kwargs):
70         args_repr = [f"{repr(a)}:{type(a)}" for a in args]
71         kwargs_repr = [f"{k}={v!r}:{type(v)}" for k, v in kwargs.items()]
72         signature = ", ".join(args_repr + kwargs_repr)
73         msg = f"Calling {func.__name__}({signature})"
74         print(msg)
75         logger.info(msg)
76         value = func(*args, **kwargs)
77         msg = f"{func.__name__!r} returned {value!r}:{type(value)}"
78         logger.info(msg)
79         return value
80     return wrapper_debug_args
81
82
83 def debug_count_calls(func: Callable) -> Callable:
84     """Count function invocations and print a message befor every call."""
85
86     @functools.wraps(func)
87     def wrapper_debug_count_calls(*args, **kwargs):
88         wrapper_debug_count_calls.num_calls += 1
89         msg = f"Call #{wrapper_debug_count_calls.num_calls} of {func.__name__!r}"
90         print(msg)
91         logger.info(msg)
92         return func(*args, **kwargs)
93     wrapper_debug_count_calls.num_calls = 0
94     return wrapper_debug_count_calls
95
96
97 class DelayWhen(enum.Enum):
98     BEFORE_CALL = 1
99     AFTER_CALL = 2
100     BEFORE_AND_AFTER = 3
101
102
103 def delay(
104     _func: Callable = None,
105     *,
106     seconds: float = 1.0,
107     when: DelayWhen = DelayWhen.BEFORE_CALL,
108 ) -> Callable:
109     """Delay the execution of a function by sleeping before and/or after.
110
111     Slow down a function by inserting a delay before and/or after its
112     invocation.
113     """
114
115     def decorator_delay(func: Callable) -> Callable:
116         @functools.wraps(func)
117         def wrapper_delay(*args, **kwargs):
118             if when & DelayWhen.BEFORE_CALL:
119                 logger.debug(
120                     f"@delay for {seconds}s BEFORE_CALL to {func.__name__}"
121                 )
122                 time.sleep(seconds)
123             retval = func(*args, **kwargs)
124             if when & DelayWhen.AFTER_CALL:
125                 logger.debug(
126                     f"@delay for {seconds}s AFTER_CALL to {func.__name__}"
127                 )
128                 time.sleep(seconds)
129             return retval
130         return wrapper_delay
131
132     if _func is None:
133         return decorator_delay
134     else:
135         return decorator_delay(_func)
136
137
138 class _SingletonWrapper:
139     """
140     A singleton wrapper class. Its instances would be created
141     for each decorated class.
142     """
143
144     def __init__(self, cls):
145         self.__wrapped__ = cls
146         self._instance = None
147
148     def __call__(self, *args, **kwargs):
149         """Returns a single instance of decorated class"""
150         logger.debug(
151             f"@singleton returning global instance of {self.__wrapped__.__name__}"
152         )
153         if self._instance is None:
154             self._instance = self.__wrapped__(*args, **kwargs)
155         return self._instance
156
157
158 def singleton(cls):
159     """
160     A singleton decorator. Returns a wrapper objects. A call on that object
161     returns a single instance object of decorated class. Use the __wrapped__
162     attribute to access decorated class directly in unit tests
163     """
164     return _SingletonWrapper(cls)
165
166
167 def memoized(func: Callable) -> Callable:
168     """Keep a cache of previous function call results.
169
170     The cache here is a dict with a key based on the arguments to the
171     call.  Consider also: functools.lru_cache for a more advanced
172     implementation.
173     """
174
175     @functools.wraps(func)
176     def wrapper_memoized(*args, **kwargs):
177         cache_key = args + tuple(kwargs.items())
178         if cache_key not in wrapper_memoized.cache:
179             value = func(*args, **kwargs)
180             logger.debug(
181                 f"Memoizing {cache_key} => {value} for {func.__name__}"
182             )
183             wrapper_memoized.cache[cache_key] = value
184         else:
185             logger.debug(f"Returning memoized value for {func.__name__}")
186         return wrapper_memoized.cache[cache_key]
187     wrapper_memoized.cache = dict()
188     return wrapper_memoized
189
190
191 def retry_predicate(
192     tries: int,
193     *,
194     predicate: Callable[..., bool],
195     delay_sec: float = 3.0,
196     backoff: float = 2.0,
197 ):
198     """Retries a function or method up to a certain number of times
199     with a prescribed initial delay period and backoff rate.
200
201     tries is the maximum number of attempts to run the function.
202     delay_sec sets the initial delay period in seconds.
203     backoff is a multiplied (must be >1) used to modify the delay.
204     predicate is a function that will be passed the retval of the
205     decorated function and must return True to stop or False to
206     retry.
207     """
208     if backoff < 1.0:
209         msg = f"backoff must be greater than or equal to 1, got {backoff}"
210         logger.critical(msg)
211         raise ValueError(msg)
212
213     tries = math.floor(tries)
214     if tries < 0:
215         msg = f"tries must be 0 or greater, got {tries}"
216         logger.critical(msg)
217         raise ValueError(msg)
218
219     if delay_sec <= 0:
220         msg = f"delay_sec must be greater than 0, got {delay_sec}"
221         logger.critical(msg)
222         raise ValueError(msg)
223
224     def deco_retry(f):
225         @functools.wraps(f)
226         def f_retry(*args, **kwargs):
227             mtries, mdelay = tries, delay_sec  # make mutable
228             logger.debug(f'deco_retry: will make up to {mtries} attempts...')
229             retval = f(*args, **kwargs)
230             while mtries > 0:
231                 if predicate(retval) is True:
232                     logger.debug('Predicate succeeded, deco_retry is done.')
233                     return retval
234                 logger.debug("Predicate failed, sleeping and retrying.")
235                 mtries -= 1
236                 time.sleep(mdelay)
237                 mdelay *= backoff
238                 retval = f(*args, **kwargs)
239             return retval
240         return f_retry
241     return deco_retry
242
243
244 def retry_if_false(tries: int, *, delay_sec=3.0, backoff=2.0):
245     return retry_predicate(
246         tries,
247         predicate=lambda x: x is True,
248         delay_sec=delay_sec,
249         backoff=backoff,
250     )
251
252
253 def retry_if_none(tries: int, *, delay_sec=3.0, backoff=2.0):
254     return retry_predicate(
255         tries,
256         predicate=lambda x: x is not None,
257         delay_sec=delay_sec,
258         backoff=backoff,
259     )
260
261
262 def deprecated(func):
263     """This is a decorator which can be used to mark functions
264     as deprecated. It will result in a warning being emitted
265     when the function is used.
266     """
267
268     @functools.wraps(func)
269     def wrapper_deprecated(*args, **kwargs):
270         msg = f"Call to deprecated function {func.__name__}"
271         logger.warning(msg)
272         warnings.warn(msg, category=DeprecationWarning)
273         return func(*args, **kwargs)
274
275     return wrapper_deprecated
276
277
278 def thunkify(func):
279     """
280     Make a function immediately return a function of no args which,
281     when called, waits for the result, which will start being
282     processed in another thread.
283     """
284
285     @functools.wraps(func)
286     def lazy_thunked(*args, **kwargs):
287         wait_event = threading.Event()
288
289         result = [None]
290         exc = [False, None]
291
292         def worker_func():
293             try:
294                 func_result = func(*args, **kwargs)
295                 result[0] = func_result
296             except Exception:
297                 exc[0] = True
298                 exc[1] = sys.exc_info()  # (type, value, traceback)
299                 msg = f"Thunkify has thrown an exception (will be raised on thunk()):\n{traceback.format_exc()}"
300                 logger.warning(msg)
301                 print(msg)
302             finally:
303                 wait_event.set()
304
305         def thunk():
306             wait_event.wait()
307             if exc[0]:
308                 raise exc[1][0](exc[1][1])
309             return result[0]
310
311         threading.Thread(target=worker_func).start()
312         return thunk
313
314     return lazy_thunked
315
316
317 ############################################################
318 # Timeout
319 ############################################################
320
321 # http://www.saltycrane.com/blog/2010/04/using-python-timeout-decorator-uploading-s3/
322 # Used work of Stephen "Zero" Chappell <[email protected]>
323 # in https://code.google.com/p/verse-quiz/source/browse/trunk/timeout.py
324
325
326 def _raise_exception(exception, error_message: Optional[str]):
327     if error_message is None:
328         raise exception()
329     else:
330         raise exception(error_message)
331
332
333 def _target(queue, function, *args, **kwargs):
334     """Run a function with arguments and return output via a queue.
335
336     This is a helper function for the Process created in _Timeout. It runs
337     the function with positional arguments and keyword arguments and then
338     returns the function's output by way of a queue. If an exception gets
339     raised, it is returned to _Timeout to be raised by the value property.
340     """
341     try:
342         queue.put((True, function(*args, **kwargs)))
343     except Exception:
344         queue.put((False, sys.exc_info()[1]))
345
346
347 class _Timeout(object):
348     """Wrap a function and add a timeout (limit) attribute to it.
349
350     Instances of this class are automatically generated by the add_timeout
351     function defined below.
352     """
353
354     def __init__(
355         self,
356         function: Callable,
357         timeout_exception: Exception,
358         error_message: str,
359         seconds: float,
360     ):
361         self.__limit = seconds
362         self.__function = function
363         self.__timeout_exception = timeout_exception
364         self.__error_message = error_message
365         self.__name__ = function.__name__
366         self.__doc__ = function.__doc__
367         self.__timeout = time.time()
368         self.__process = multiprocessing.Process()
369         self.__queue: multiprocessing.queues.Queue = multiprocessing.Queue()
370
371     def __call__(self, *args, **kwargs):
372         """Execute the embedded function object asynchronously.
373
374         The function given to the constructor is transparently called and
375         requires that "ready" be intermittently polled. If and when it is
376         True, the "value" property may then be checked for returned data.
377         """
378         self.__limit = kwargs.pop("timeout", self.__limit)
379         self.__queue = multiprocessing.Queue(1)
380         args = (self.__queue, self.__function) + args
381         self.__process = multiprocessing.Process(
382             target=_target, args=args, kwargs=kwargs
383         )
384         self.__process.daemon = True
385         self.__process.start()
386         if self.__limit is not None:
387             self.__timeout = self.__limit + time.time()
388         while not self.ready:
389             time.sleep(0.1)
390         return self.value
391
392     def cancel(self):
393         """Terminate any possible execution of the embedded function."""
394         if self.__process.is_alive():
395             self.__process.terminate()
396         _raise_exception(self.__timeout_exception, self.__error_message)
397
398     @property
399     def ready(self):
400         """Read-only property indicating status of "value" property."""
401         if self.__limit and self.__timeout < time.time():
402             self.cancel()
403         return self.__queue.full() and not self.__queue.empty()
404
405     @property
406     def value(self):
407         """Read-only property containing data returned from function."""
408         if self.ready is True:
409             flag, load = self.__queue.get()
410             if flag:
411                 return load
412             raise load
413
414
415 def timeout(
416     seconds: float = 1.0,
417     use_signals: Optional[bool] = None,
418     timeout_exception=exceptions.TimeoutError,
419     error_message="Function call timed out",
420 ):
421     """Add a timeout parameter to a function and return the function.
422
423     Note: the use_signals parameter is included in order to support
424     multiprocessing scenarios (signal can only be used from the process'
425     main thread).  When not using signals, timeout granularity will be
426     rounded to the nearest 0.1s.
427
428     Raises an exception when the timeout is reached.
429
430     It is illegal to pass anything other than a function as the first
431     parameter.  The function is wrapped and returned to the caller.
432     """
433     if use_signals is None:
434         import thread_utils
435         use_signals = thread_utils.is_current_thread_main_thread()
436
437     def decorate(function):
438         if use_signals:
439
440             def handler(signum, frame):
441                 _raise_exception(timeout_exception, error_message)
442
443             @functools.wraps(function)
444             def new_function(*args, **kwargs):
445                 new_seconds = kwargs.pop("timeout", seconds)
446                 if new_seconds:
447                     old = signal.signal(signal.SIGALRM, handler)
448                     signal.setitimer(signal.ITIMER_REAL, new_seconds)
449
450                 if not seconds:
451                     return function(*args, **kwargs)
452
453                 try:
454                     return function(*args, **kwargs)
455                 finally:
456                     if new_seconds:
457                         signal.setitimer(signal.ITIMER_REAL, 0)
458                         signal.signal(signal.SIGALRM, old)
459
460             return new_function
461         else:
462
463             @functools.wraps(function)
464             def new_function(*args, **kwargs):
465                 timeout_wrapper = _Timeout(
466                     function, timeout_exception, error_message, seconds
467                 )
468                 return timeout_wrapper(*args, **kwargs)
469
470             return new_function
471
472     return decorate
473
474
475 class non_reentrant_code(object):
476     def __init__(self):
477         self._lock = threading.RLock
478         self._entered = False
479
480     def __call__(self, f):
481         def _gatekeeper(*args, **kwargs):
482             with self._lock:
483                 if self._entered:
484                     return
485                 self._entered = True
486                 f(*args, **kwargs)
487                 self._entered = False
488
489         return _gatekeeper
490
491
492 class rlocked(object):
493     def __init__(self):
494         self._lock = threading.RLock
495         self._entered = False
496
497     def __call__(self, f):
498         def _gatekeeper(*args, **kwargs):
499             with self._lock:
500                 if self._entered:
501                     return
502                 self._entered = True
503                 f(*args, **kwargs)
504                 self._entered = False
505         return _gatekeeper
506
507
508 def call_with_sample_rate(sample_rate: float) -> Callable:
509     if not 0.0 <= sample_rate <= 1.0:
510         msg = f"sample_rate must be between [0, 1]. Got {sample_rate}."
511         logger.critical(msg)
512         raise ValueError(msg)
513
514     def decorator(f):
515         @functools.wraps(f)
516         def _call_with_sample_rate(*args, **kwargs):
517             if random.uniform(0, 1) < sample_rate:
518                 return f(*args, **kwargs)
519             else:
520                 logger.debug(
521                     f"@call_with_sample_rate skipping a call to {f.__name__}"
522                 )
523         return _call_with_sample_rate
524     return decorator
525
526
527 def decorate_matching_methods_with(decorator, acl=None):
528     """Apply decorator to all methods in a class whose names begin with
529     prefix.  If prefix is None (default), decorate all methods in the
530     class.
531     """
532     def decorate_the_class(cls):
533         for name, m in inspect.getmembers(cls, inspect.isfunction):
534             if acl is None:
535                 setattr(cls, name, decorator(m))
536             else:
537                 if acl(name):
538                     setattr(cls, name, decorator(m))
539         return cls
540     return decorate_the_class