Lots of changes.
[python_utils.git] / decorator_utils.py
1 #!/usr/bin/env python3
2
3 """Decorators."""
4
5 import datetime
6 import enum
7 import functools
8 import inspect
9 import logging
10 import math
11 import multiprocessing
12 import random
13 import signal
14 import sys
15 import threading
16 import time
17 import traceback
18 from typing import Callable, Optional
19 import warnings
20
21 import exceptions
22 import thread_utils
23
24
25 logger = logging.getLogger(__name__)
26
27
28 def timed(func: Callable) -> Callable:
29     """Print the runtime of the decorated function."""
30
31     @functools.wraps(func)
32     def wrapper_timer(*args, **kwargs):
33         start_time = time.perf_counter()
34         value = func(*args, **kwargs)
35         end_time = time.perf_counter()
36         run_time = end_time - start_time
37         msg = f"Finished {func.__name__!r} in {run_time:.4f}s"
38         print(msg)
39         logger.info(msg)
40         return value
41     return wrapper_timer
42
43
44 def invocation_logged(func: Callable) -> Callable:
45     """Log the call of a function."""
46
47     @functools.wraps(func)
48     def wrapper_invocation_logged(*args, **kwargs):
49         now = datetime.datetime.now()
50         ts = now.strftime("%Y/%d/%b:%H:%M:%S%Z")
51         msg = f"[{ts}]: Entered {func.__name__}"
52         print(msg)
53         logger.info(msg)
54         ret = func(*args, **kwargs)
55         now = datetime.datetime.now()
56         ts = now.strftime("%Y/%d/%b:%H:%M:%S%Z")
57         msg = f"[{ts}]: Exited {func.__name__}"
58         print(msg)
59         logger.info(msg)
60         return ret
61     return wrapper_invocation_logged
62
63
64 def debug_args(func: Callable) -> Callable:
65     """Print the function signature and return value at each call."""
66
67     @functools.wraps(func)
68     def wrapper_debug_args(*args, **kwargs):
69         args_repr = [f"{repr(a)}:{type(a)}" for a in args]
70         kwargs_repr = [f"{k}={v!r}:{type(v)}" for k, v in kwargs.items()]
71         signature = ", ".join(args_repr + kwargs_repr)
72         msg = f"Calling {func.__name__}({signature})"
73         print(msg)
74         logger.info(msg)
75         value = func(*args, **kwargs)
76         msg = f"{func.__name__!r} returned {value!r}:{type(value)}"
77         logger.info(msg)
78         return value
79     return wrapper_debug_args
80
81
82 def debug_count_calls(func: Callable) -> Callable:
83     """Count function invocations and print a message befor every call."""
84
85     @functools.wraps(func)
86     def wrapper_debug_count_calls(*args, **kwargs):
87         wrapper_debug_count_calls.num_calls += 1
88         msg = f"Call #{wrapper_debug_count_calls.num_calls} of {func.__name__!r}"
89         print(msg)
90         logger.info(msg)
91         return func(*args, **kwargs)
92     wrapper_debug_count_calls.num_calls = 0
93     return wrapper_debug_count_calls
94
95
96 class DelayWhen(enum.Enum):
97     BEFORE_CALL = 1
98     AFTER_CALL = 2
99     BEFORE_AND_AFTER = 3
100
101
102 def delay(
103     _func: Callable = None,
104     *,
105     seconds: float = 1.0,
106     when: DelayWhen = DelayWhen.BEFORE_CALL,
107 ) -> Callable:
108     """Delay the execution of a function by sleeping before and/or after.
109
110     Slow down a function by inserting a delay before and/or after its
111     invocation.
112     """
113
114     def decorator_delay(func: Callable) -> Callable:
115         @functools.wraps(func)
116         def wrapper_delay(*args, **kwargs):
117             if when & DelayWhen.BEFORE_CALL:
118                 logger.debug(
119                     f"@delay for {seconds}s BEFORE_CALL to {func.__name__}"
120                 )
121                 time.sleep(seconds)
122             retval = func(*args, **kwargs)
123             if when & DelayWhen.AFTER_CALL:
124                 logger.debug(
125                     f"@delay for {seconds}s AFTER_CALL to {func.__name__}"
126                 )
127                 time.sleep(seconds)
128             return retval
129         return wrapper_delay
130
131     if _func is None:
132         return decorator_delay
133     else:
134         return decorator_delay(_func)
135
136
137 class _SingletonWrapper:
138     """
139     A singleton wrapper class. Its instances would be created
140     for each decorated class.
141     """
142
143     def __init__(self, cls):
144         self.__wrapped__ = cls
145         self._instance = None
146
147     def __call__(self, *args, **kwargs):
148         """Returns a single instance of decorated class"""
149         logger.debug(
150             f"@singleton returning global instance of {self.__wrapped__.__name__}"
151         )
152         if self._instance is None:
153             self._instance = self.__wrapped__(*args, **kwargs)
154         return self._instance
155
156
157 def singleton(cls):
158     """
159     A singleton decorator. Returns a wrapper objects. A call on that object
160     returns a single instance object of decorated class. Use the __wrapped__
161     attribute to access decorated class directly in unit tests
162     """
163     return _SingletonWrapper(cls)
164
165
166 def memoized(func: Callable) -> Callable:
167     """Keep a cache of previous function call results.
168
169     The cache here is a dict with a key based on the arguments to the
170     call.  Consider also: functools.lru_cache for a more advanced
171     implementation.
172     """
173
174     @functools.wraps(func)
175     def wrapper_memoized(*args, **kwargs):
176         cache_key = args + tuple(kwargs.items())
177         if cache_key not in wrapper_memoized.cache:
178             value = func(*args, **kwargs)
179             logger.debug(
180                 f"Memoizing {cache_key} => {value} for {func.__name__}"
181             )
182             wrapper_memoized.cache[cache_key] = value
183         else:
184             logger.debug(f"Returning memoized value for {func.__name__}")
185         return wrapper_memoized.cache[cache_key]
186     wrapper_memoized.cache = dict()
187     return wrapper_memoized
188
189
190 def retry_predicate(
191     tries: int,
192     *,
193     predicate: Callable[..., bool],
194     delay_sec: float = 3,
195     backoff: float = 2.0,
196 ):
197     """Retries a function or method up to a certain number of times
198     with a prescribed initial delay period and backoff rate.
199
200     tries is the maximum number of attempts to run the function.
201     delay_sec sets the initial delay period in seconds.
202     backoff is a multiplied (must be >1) used to modify the delay.
203     predicate is a function that will be passed the retval of the
204       decorated function and must return True to stop or False to
205       retry.
206     """
207     if backoff < 1:
208         msg = f"backoff must be greater than or equal to 1, got {backoff}"
209         logger.critical(msg)
210         raise ValueError(msg)
211
212     tries = math.floor(tries)
213     if tries < 0:
214         msg = f"tries must be 0 or greater, got {tries}"
215         logger.critical(msg)
216         raise ValueError(msg)
217
218     if delay_sec <= 0:
219         msg = f"delay_sec must be greater than 0, got {delay_sec}"
220         logger.critical(msg)
221         raise ValueError(msg)
222
223     def deco_retry(f):
224         @functools.wraps(f)
225         def f_retry(*args, **kwargs):
226             mtries, mdelay = tries, delay_sec  # make mutable
227             retval = f(*args, **kwargs)
228             while mtries > 0:
229                 if predicate(retval) is True:
230                     return retval
231                 logger.debug("Predicate failed, sleeping and retrying.")
232                 mtries -= 1
233                 time.sleep(mdelay)
234                 mdelay *= backoff
235                 retval = f(*args, **kwargs)
236             return retval
237         return f_retry
238     return deco_retry
239
240
241 def retry_if_false(tries: int, *, delay_sec=3.0, backoff=2.0):
242     return retry_predicate(
243         tries,
244         predicate=lambda x: x is True,
245         delay_sec=delay_sec,
246         backoff=backoff,
247     )
248
249
250 def retry_if_none(tries: int, *, delay_sec=3.0, backoff=2.0):
251     return retry_predicate(
252         tries,
253         predicate=lambda x: x is not None,
254         delay_sec=delay_sec,
255         backoff=backoff,
256     )
257
258
259 def deprecated(func):
260     """This is a decorator which can be used to mark functions
261     as deprecated. It will result in a warning being emitted
262     when the function is used.
263     """
264
265     @functools.wraps(func)
266     def wrapper_deprecated(*args, **kwargs):
267         msg = f"Call to deprecated function {func.__name__}"
268         logger.warning(msg)
269         warnings.warn(msg, category=DeprecationWarning)
270         return func(*args, **kwargs)
271
272     return wrapper_deprecated
273
274
275 def thunkify(func):
276     """
277     Make a function immediately return a function of no args which,
278     when called, waits for the result, which will start being
279     processed in another thread.
280     """
281
282     @functools.wraps(func)
283     def lazy_thunked(*args, **kwargs):
284         wait_event = threading.Event()
285
286         result = [None]
287         exc = [False, None]
288
289         def worker_func():
290             try:
291                 func_result = func(*args, **kwargs)
292                 result[0] = func_result
293             except Exception:
294                 exc[0] = True
295                 exc[1] = sys.exc_info()  # (type, value, traceback)
296                 msg = f"Thunkify has thrown an exception (will be raised on thunk()):\n{traceback.format_exc()}"
297                 logger.warning(msg)
298                 print(msg)
299             finally:
300                 wait_event.set()
301
302         def thunk():
303             wait_event.wait()
304             if exc[0]:
305                 raise exc[1][0](exc[1][1])
306             return result[0]
307
308         threading.Thread(target=worker_func).start()
309         return thunk
310
311     return lazy_thunked
312
313
314 ############################################################
315 # Timeout
316 ############################################################
317
318 # http://www.saltycrane.com/blog/2010/04/using-python-timeout-decorator-uploading-s3/
319 # Used work of Stephen "Zero" Chappell <[email protected]>
320 # in https://code.google.com/p/verse-quiz/source/browse/trunk/timeout.py
321
322
323 def _raise_exception(exception, error_message: Optional[str]):
324     if error_message is None:
325         raise exception()
326     else:
327         raise exception(error_message)
328
329
330 def _target(queue, function, *args, **kwargs):
331     """Run a function with arguments and return output via a queue.
332
333     This is a helper function for the Process created in _Timeout. It runs
334     the function with positional arguments and keyword arguments and then
335     returns the function's output by way of a queue. If an exception gets
336     raised, it is returned to _Timeout to be raised by the value property.
337     """
338     try:
339         queue.put((True, function(*args, **kwargs)))
340     except:
341         queue.put((False, sys.exc_info()[1]))
342
343
344 class _Timeout(object):
345     """Wrap a function and add a timeout (limit) attribute to it.
346
347     Instances of this class are automatically generated by the add_timeout
348     function defined below.
349     """
350
351     def __init__(
352         self,
353         function: Callable,
354         timeout_exception: Exception,
355         error_message: str,
356         seconds: float,
357     ):
358         self.__limit = seconds
359         self.__function = function
360         self.__timeout_exception = timeout_exception
361         self.__error_message = error_message
362         self.__name__ = function.__name__
363         self.__doc__ = function.__doc__
364         self.__timeout = time.time()
365         self.__process = multiprocessing.Process()
366         self.__queue: multiprocessing.queues.Queue = multiprocessing.Queue()
367
368     def __call__(self, *args, **kwargs):
369         """Execute the embedded function object asynchronously.
370
371         The function given to the constructor is transparently called and
372         requires that "ready" be intermittently polled. If and when it is
373         True, the "value" property may then be checked for returned data.
374         """
375         self.__limit = kwargs.pop("timeout", self.__limit)
376         self.__queue = multiprocessing.Queue(1)
377         args = (self.__queue, self.__function) + args
378         self.__process = multiprocessing.Process(
379             target=_target, args=args, kwargs=kwargs
380         )
381         self.__process.daemon = True
382         self.__process.start()
383         if self.__limit is not None:
384             self.__timeout = self.__limit + time.time()
385         while not self.ready:
386             time.sleep(0.1)
387         return self.value
388
389     def cancel(self):
390         """Terminate any possible execution of the embedded function."""
391         if self.__process.is_alive():
392             self.__process.terminate()
393         _raise_exception(self.__timeout_exception, self.__error_message)
394
395     @property
396     def ready(self):
397         """Read-only property indicating status of "value" property."""
398         if self.__limit and self.__timeout < time.time():
399             self.cancel()
400         return self.__queue.full() and not self.__queue.empty()
401
402     @property
403     def value(self):
404         """Read-only property containing data returned from function."""
405         if self.ready is True:
406             flag, load = self.__queue.get()
407             if flag:
408                 return load
409             raise load
410
411
412 def timeout(
413     seconds: float = 1.0,
414     use_signals: Optional[bool] = None,
415     timeout_exception=exceptions.TimeoutError,
416     error_message="Function call timed out",
417 ):
418     """Add a timeout parameter to a function and return the function.
419
420     Note: the use_signals parameter is included in order to support
421     multiprocessing scenarios (signal can only be used from the process'
422     main thread).  When not using signals, timeout granularity will be
423     rounded to the nearest 0.1s.
424
425     Raises an exception when the timeout is reached.
426
427     It is illegal to pass anything other than a function as the first
428     parameter.  The function is wrapped and returned to the caller.
429     """
430     if use_signals is None:
431         use_signals = thread_utils.is_current_thread_main_thread()
432
433     def decorate(function):
434
435         if use_signals:
436
437             def handler(signum, frame):
438                 _raise_exception(timeout_exception, error_message)
439
440             @functools.wraps(function)
441             def new_function(*args, **kwargs):
442                 new_seconds = kwargs.pop("timeout", seconds)
443                 if new_seconds:
444                     old = signal.signal(signal.SIGALRM, handler)
445                     signal.setitimer(signal.ITIMER_REAL, new_seconds)
446
447                 if not seconds:
448                     return function(*args, **kwargs)
449
450                 try:
451                     return function(*args, **kwargs)
452                 finally:
453                     if new_seconds:
454                         signal.setitimer(signal.ITIMER_REAL, 0)
455                         signal.signal(signal.SIGALRM, old)
456
457             return new_function
458         else:
459
460             @functools.wraps(function)
461             def new_function(*args, **kwargs):
462                 timeout_wrapper = _Timeout(
463                     function, timeout_exception, error_message, seconds
464                 )
465                 return timeout_wrapper(*args, **kwargs)
466
467             return new_function
468
469     return decorate
470
471
472 class non_reentrant_code(object):
473     def __init__(self):
474         self._lock = threading.RLock
475         self._entered = False
476
477     def __call__(self, f):
478         def _gatekeeper(*args, **kwargs):
479             with self._lock:
480                 if self._entered:
481                     return
482                 self._entered = True
483                 f(*args, **kwargs)
484                 self._entered = False
485
486         return _gatekeeper
487
488
489 class rlocked(object):
490     def __init__(self):
491         self._lock = threading.RLock
492         self._entered = False
493
494     def __call__(self, f):
495         def _gatekeeper(*args, **kwargs):
496             with self._lock:
497                 if self._entered:
498                     return
499                 self._entered = True
500                 f(*args, **kwargs)
501                 self._entered = False
502         return _gatekeeper
503
504
505 def call_with_sample_rate(sample_rate: float) -> Callable:
506     if not 0.0 <= sample_rate <= 1.0:
507         msg = f"sample_rate must be between [0, 1]. Got {sample_rate}."
508         logger.critical(msg)
509         raise ValueError(msg)
510
511     def decorator(f):
512         @functools.wraps(f)
513         def _call_with_sample_rate(*args, **kwargs):
514             if random.uniform(0, 1) < sample_rate:
515                 return f(*args, **kwargs)
516             else:
517                 logger.debug(
518                     f"@call_with_sample_rate skipping a call to {f.__name__}"
519                 )
520         return _call_with_sample_rate
521     return decorator
522
523
524 def decorate_matching_methods_with(decorator, acl=None):
525     """Apply decorator to all methods in a class whose names begin with
526     prefix.  If prefix is None (default), decorate all methods in the
527     class.
528     """
529     def decorate_the_class(cls):
530         for name, m in inspect.getmembers(cls, inspect.isfunction):
531             if acl is None:
532                 setattr(cls, name, decorator(m))
533             else:
534                 if acl(name):
535                     setattr(cls, name, decorator(m))
536         return cls
537     return decorate_the_class