Use ValueError, not Exception.
[python_utils.git] / decorator_utils.py
1 #!/usr/bin/env python3
2
3 """Decorators."""
4
5 import datetime
6 import enum
7 import functools
8 import logging
9 import math
10 import multiprocessing
11 import random
12 import signal
13 import sys
14 import threading
15 import time
16 import traceback
17 from typing import Callable, Optional
18 import warnings
19
20 import exceptions
21 import thread_utils
22
23 logger = logging.getLogger(__name__)
24
25
26 def timed(func: Callable) -> Callable:
27     """Print the runtime of the decorated function."""
28
29     @functools.wraps(func)
30     def wrapper_timer(*args, **kwargs):
31         start_time = time.perf_counter()
32         value = func(*args, **kwargs)
33         end_time = time.perf_counter()
34         run_time = end_time - start_time
35         msg = f"Finished {func.__name__!r} in {run_time:.4f}s"
36         print(msg)
37         logger.info(msg)
38         return value
39     return wrapper_timer
40
41
42 def invocation_logged(func: Callable) -> Callable:
43     """Log the call of a function."""
44
45     @functools.wraps(func)
46     def wrapper_invocation_logged(*args, **kwargs):
47         now = datetime.datetime.now()
48         ts = now.strftime("%Y/%d/%b:%H:%M:%S%Z")
49         msg = f"[{ts}]: Entered {func.__name__}"
50         print(msg)
51         logger.info(msg)
52         ret = func(*args, **kwargs)
53         now = datetime.datetime.now()
54         ts = now.strftime("%Y/%d/%b:%H:%M:%S%Z")
55         msg = f"[{ts}]: Exited {func.__name__}"
56         print(msg)
57         logger.info(msg)
58         return ret
59     return wrapper_invocation_logged
60
61
62 def debug_args(func: Callable) -> Callable:
63     """Print the function signature and return value at each call."""
64
65     @functools.wraps(func)
66     def wrapper_debug_args(*args, **kwargs):
67         args_repr = [f"{repr(a)}:{type(a)}" for a in args]
68         kwargs_repr = [f"{k}={v!r}:{type(v)}" for k, v in kwargs.items()]
69         signature = ", ".join(args_repr + kwargs_repr)
70         msg = f"Calling {func.__name__}({signature})"
71         print(msg)
72         logger.info(msg)
73         value = func(*args, **kwargs)
74         msg = f"{func.__name__!r} returned {value!r}:{type(value)}"
75         logger.info(msg)
76         return value
77     return wrapper_debug_args
78
79
80 def debug_count_calls(func: Callable) -> Callable:
81     """Count function invocations and print a message befor every call."""
82
83     @functools.wraps(func)
84     def wrapper_debug_count_calls(*args, **kwargs):
85         wrapper_debug_count_calls.num_calls += 1
86         msg = f"Call #{wrapper_debug_count_calls.num_calls} of {func.__name__!r}"
87         print(msg)
88         logger.info(msg)
89         return func(*args, **kwargs)
90     wrapper_debug_count_calls.num_calls = 0
91     return wrapper_debug_count_calls
92
93
94 class DelayWhen(enum.Enum):
95     BEFORE_CALL = 1
96     AFTER_CALL = 2
97     BEFORE_AND_AFTER = 3
98
99
100 def delay(
101     _func: Callable = None,
102     *,
103     seconds: float = 1.0,
104     when: DelayWhen = DelayWhen.BEFORE_CALL,
105 ) -> Callable:
106     """Delay the execution of a function by sleeping before and/or after.
107
108     Slow down a function by inserting a delay before and/or after its
109     invocation.
110     """
111
112     def decorator_delay(func: Callable) -> Callable:
113         @functools.wraps(func)
114         def wrapper_delay(*args, **kwargs):
115             if when & DelayWhen.BEFORE_CALL:
116                 logger.debug(
117                     f"@delay for {seconds}s BEFORE_CALL to {func.__name__}"
118                 )
119                 time.sleep(seconds)
120             retval = func(*args, **kwargs)
121             if when & DelayWhen.AFTER_CALL:
122                 logger.debug(
123                     f"@delay for {seconds}s AFTER_CALL to {func.__name__}"
124                 )
125                 time.sleep(seconds)
126             return retval
127         return wrapper_delay
128
129     if _func is None:
130         return decorator_delay
131     else:
132         return decorator_delay(_func)
133
134
135 class _SingletonWrapper:
136     """
137     A singleton wrapper class. Its instances would be created
138     for each decorated class.
139     """
140
141     def __init__(self, cls):
142         self.__wrapped__ = cls
143         self._instance = None
144
145     def __call__(self, *args, **kwargs):
146         """Returns a single instance of decorated class"""
147         logger.debug(
148             f"@singleton returning global instance of {self.__wrapped__.__name__}"
149         )
150         if self._instance is None:
151             self._instance = self.__wrapped__(*args, **kwargs)
152         return self._instance
153
154
155 def singleton(cls):
156     """
157     A singleton decorator. Returns a wrapper objects. A call on that object
158     returns a single instance object of decorated class. Use the __wrapped__
159     attribute to access decorated class directly in unit tests
160     """
161     return _SingletonWrapper(cls)
162
163
164 def memoized(func: Callable) -> Callable:
165     """Keep a cache of previous function call results.
166
167     The cache here is a dict with a key based on the arguments to the
168     call.  Consider also: functools.lru_cache for a more advanced
169     implementation.
170     """
171
172     @functools.wraps(func)
173     def wrapper_memoized(*args, **kwargs):
174         cache_key = args + tuple(kwargs.items())
175         if cache_key not in wrapper_memoized.cache:
176             value = func(*args, **kwargs)
177             logger.debug(
178                 f"Memoizing {cache_key} => {value} for {func.__name__}"
179             )
180             wrapper_memoized.cache[cache_key] = value
181         else:
182             logger.debug(f"Returning memoized value for {func.__name__}")
183         return wrapper_memoized.cache[cache_key]
184     wrapper_memoized.cache = dict()
185     return wrapper_memoized
186
187
188 def retry_predicate(
189     tries: int,
190     *,
191     predicate: Callable[..., bool],
192     delay_sec: float = 3,
193     backoff: float = 2.0,
194 ):
195     """Retries a function or method up to a certain number of times
196     with a prescribed initial delay period and backoff rate.
197
198     tries is the maximum number of attempts to run the function.
199     delay_sec sets the initial delay period in seconds.
200     backoff is a multiplied (must be >1) used to modify the delay.
201     predicate is a function that will be passed the retval of the
202       decorated function and must return True to stop or False to
203       retry.
204     """
205     if backoff < 1:
206         msg = f"backoff must be greater than or equal to 1, got {backoff}"
207         logger.critical(msg)
208         raise ValueError(msg)
209
210     tries = math.floor(tries)
211     if tries < 0:
212         msg = f"tries must be 0 or greater, got {tries}"
213         logger.critical(msg)
214         raise ValueError(msg)
215
216     if delay_sec <= 0:
217         msg = f"delay_sec must be greater than 0, got {delay_sec}"
218         logger.critical(msg)
219         raise ValueError(msg)
220
221     def deco_retry(f):
222         @functools.wraps(f)
223         def f_retry(*args, **kwargs):
224             mtries, mdelay = tries, delay_sec  # make mutable
225             retval = f(*args, **kwargs)
226             while mtries > 0:
227                 if predicate(retval) is True:
228                     return retval
229                 logger.debug("Predicate failed, sleeping and retrying.")
230                 mtries -= 1
231                 time.sleep(mdelay)
232                 mdelay *= backoff
233                 retval = f(*args, **kwargs)
234             return retval
235         return f_retry
236     return deco_retry
237
238
239 def retry_if_false(tries: int, *, delay_sec=3.0, backoff=2.0):
240     return retry_predicate(
241         tries,
242         predicate=lambda x: x is True,
243         delay_sec=delay_sec,
244         backoff=backoff,
245     )
246
247
248 def retry_if_none(tries: int, *, delay_sec=3.0, backoff=2.0):
249     return retry_predicate(
250         tries,
251         predicate=lambda x: x is not None,
252         delay_sec=delay_sec,
253         backoff=backoff,
254     )
255
256
257 def deprecated(func):
258     """This is a decorator which can be used to mark functions
259     as deprecated. It will result in a warning being emitted
260     when the function is used.
261     """
262
263     @functools.wraps(func)
264     def wrapper_deprecated(*args, **kwargs):
265         msg = f"Call to deprecated function {func.__name__}"
266         logger.warning(msg)
267         warnings.warn(msg, category=DeprecationWarning)
268         return func(*args, **kwargs)
269
270     return wrapper_deprecated
271
272
273 def thunkify(func):
274     """
275     Make a function immediately return a function of no args which,
276     when called, waits for the result, which will start being
277     processed in another thread.
278     """
279
280     @functools.wraps(func)
281     def lazy_thunked(*args, **kwargs):
282         wait_event = threading.Event()
283
284         result = [None]
285         exc = [False, None]
286
287         def worker_func():
288             try:
289                 func_result = func(*args, **kwargs)
290                 result[0] = func_result
291             except Exception:
292                 exc[0] = True
293                 exc[1] = sys.exc_info()  # (type, value, traceback)
294                 msg = f"Thunkify has thrown an exception (will be raised on thunk()):\n{traceback.format_exc()}"
295                 logger.warning(msg)
296                 print(msg)
297             finally:
298                 wait_event.set()
299
300         def thunk():
301             wait_event.wait()
302             if exc[0]:
303                 raise exc[1][0](exc[1][1])
304             return result[0]
305
306         threading.Thread(target=worker_func).start()
307         return thunk
308
309     return lazy_thunked
310
311
312 ############################################################
313 # Timeout
314 ############################################################
315
316 # http://www.saltycrane.com/blog/2010/04/using-python-timeout-decorator-uploading-s3/
317 # Used work of Stephen "Zero" Chappell <[email protected]>
318 # in https://code.google.com/p/verse-quiz/source/browse/trunk/timeout.py
319
320
321 def _raise_exception(exception, error_message: Optional[str]):
322     if error_message is None:
323         raise exception()
324     else:
325         raise exception(error_message)
326
327
328 def _target(queue, function, *args, **kwargs):
329     """Run a function with arguments and return output via a queue.
330
331     This is a helper function for the Process created in _Timeout. It runs
332     the function with positional arguments and keyword arguments and then
333     returns the function's output by way of a queue. If an exception gets
334     raised, it is returned to _Timeout to be raised by the value property.
335     """
336     try:
337         queue.put((True, function(*args, **kwargs)))
338     except:
339         queue.put((False, sys.exc_info()[1]))
340
341
342 class _Timeout(object):
343     """Wrap a function and add a timeout (limit) attribute to it.
344
345     Instances of this class are automatically generated by the add_timeout
346     function defined below.
347     """
348
349     def __init__(
350         self,
351         function: Callable,
352         timeout_exception: Exception,
353         error_message: str,
354         seconds: float,
355     ):
356         self.__limit = seconds
357         self.__function = function
358         self.__timeout_exception = timeout_exception
359         self.__error_message = error_message
360         self.__name__ = function.__name__
361         self.__doc__ = function.__doc__
362         self.__timeout = time.time()
363         self.__process = multiprocessing.Process()
364         self.__queue: multiprocessing.queues.Queue = multiprocessing.Queue()
365
366     def __call__(self, *args, **kwargs):
367         """Execute the embedded function object asynchronously.
368
369         The function given to the constructor is transparently called and
370         requires that "ready" be intermittently polled. If and when it is
371         True, the "value" property may then be checked for returned data.
372         """
373         self.__limit = kwargs.pop("timeout", self.__limit)
374         self.__queue = multiprocessing.Queue(1)
375         args = (self.__queue, self.__function) + args
376         self.__process = multiprocessing.Process(
377             target=_target, args=args, kwargs=kwargs
378         )
379         self.__process.daemon = True
380         self.__process.start()
381         if self.__limit is not None:
382             self.__timeout = self.__limit + time.time()
383         while not self.ready:
384             time.sleep(0.1)
385         return self.value
386
387     def cancel(self):
388         """Terminate any possible execution of the embedded function."""
389         if self.__process.is_alive():
390             self.__process.terminate()
391         _raise_exception(self.__timeout_exception, self.__error_message)
392
393     @property
394     def ready(self):
395         """Read-only property indicating status of "value" property."""
396         if self.__limit and self.__timeout < time.time():
397             self.cancel()
398         return self.__queue.full() and not self.__queue.empty()
399
400     @property
401     def value(self):
402         """Read-only property containing data returned from function."""
403         if self.ready is True:
404             flag, load = self.__queue.get()
405             if flag:
406                 return load
407             raise load
408
409
410 def timeout(
411     seconds: float = 1.0,
412     use_signals: Optional[bool] = None,
413     timeout_exception=exceptions.TimeoutError,
414     error_message="Function call timed out",
415 ):
416     """Add a timeout parameter to a function and return the function.
417
418     Note: the use_signals parameter is included in order to support
419     multiprocessing scenarios (signal can only be used from the process'
420     main thread).  When not using signals, timeout granularity will be
421     rounded to the nearest 0.1s.
422
423     Raises an exception when the timeout is reached.
424
425     It is illegal to pass anything other than a function as the first
426     parameter.  The function is wrapped and returned to the caller.
427     """
428     if use_signals is None:
429         use_signals = thread_utils.is_current_thread_main_thread()
430
431     def decorate(function):
432
433         if use_signals:
434
435             def handler(signum, frame):
436                 _raise_exception(timeout_exception, error_message)
437
438             @functools.wraps(function)
439             def new_function(*args, **kwargs):
440                 new_seconds = kwargs.pop("timeout", seconds)
441                 if new_seconds:
442                     old = signal.signal(signal.SIGALRM, handler)
443                     signal.setitimer(signal.ITIMER_REAL, new_seconds)
444
445                 if not seconds:
446                     return function(*args, **kwargs)
447
448                 try:
449                     return function(*args, **kwargs)
450                 finally:
451                     if new_seconds:
452                         signal.setitimer(signal.ITIMER_REAL, 0)
453                         signal.signal(signal.SIGALRM, old)
454
455             return new_function
456         else:
457
458             @functools.wraps(function)
459             def new_function(*args, **kwargs):
460                 timeout_wrapper = _Timeout(
461                     function, timeout_exception, error_message, seconds
462                 )
463                 return timeout_wrapper(*args, **kwargs)
464
465             return new_function
466
467     return decorate
468
469
470 class non_reentrant_code(object):
471     def __init__(self):
472         self._lock = threading.RLock
473         self._entered = False
474
475     def __call__(self, f):
476         def _gatekeeper(*args, **kwargs):
477             with self._lock:
478                 if self._entered:
479                     return
480                 self._entered = True
481                 f(*args, **kwargs)
482                 self._entered = False
483
484         return _gatekeeper
485
486
487 class rlocked(object):
488     def __init__(self):
489         self._lock = threading.RLock
490         self._entered = False
491
492     def __call__(self, f):
493         def _gatekeeper(*args, **kwargs):
494             with self._lock:
495                 if self._entered:
496                     return
497                 self._entered = True
498                 f(*args, **kwargs)
499                 self._entered = False
500         return _gatekeeper
501
502
503 def call_with_sample_rate(sample_rate: float) -> Callable:
504     if not 0.0 <= sample_rate <= 1.0:
505         msg = f"sample_rate must be between [0, 1]. Got {sample_rate}."
506         logger.critical(msg)
507         raise ValueError(msg)
508
509     def decorator(f):
510         @functools.wraps(f)
511         def _call_with_sample_rate(*args, **kwargs):
512             if random.uniform(0, 1) < sample_rate:
513                 return f(*args, **kwargs)
514             else:
515                 logger.debug(
516                     f"@call_with_sample_rate skipping a call to {f.__name__}"
517                 )
518         return _call_with_sample_rate
519     return decorator