ACL uses enums, some more tests, other stuff.
[python_utils.git] / decorator_utils.py
1 #!/usr/bin/env python3
2
3 """Decorators."""
4
5 import datetime
6 import enum
7 import functools
8 import inspect
9 import logging
10 import math
11 import multiprocessing
12 import random
13 import signal
14 import sys
15 import threading
16 import time
17 import traceback
18 from typing import Callable, Optional
19 import warnings
20
21 # This module is commonly used by others in here and should avoid
22 # taking any unnecessary dependencies back on them.
23 import exceptions
24
25
26 logger = logging.getLogger(__name__)
27
28
29 def timed(func: Callable) -> Callable:
30     """Print the runtime of the decorated function."""
31
32     @functools.wraps(func)
33     def wrapper_timer(*args, **kwargs):
34         start_time = time.perf_counter()
35         value = func(*args, **kwargs)
36         end_time = time.perf_counter()
37         run_time = end_time - start_time
38         msg = f"Finished {func.__name__!r} in {run_time:.4f}s"
39         print(msg)
40         logger.info(msg)
41         return value
42     return wrapper_timer
43
44
45 def invocation_logged(func: Callable) -> Callable:
46     """Log the call of a function."""
47
48     @functools.wraps(func)
49     def wrapper_invocation_logged(*args, **kwargs):
50         now = datetime.datetime.now()
51         ts = now.strftime("%Y/%d/%b:%H:%M:%S%Z")
52         msg = f"[{ts}]: Entered {func.__name__}"
53         print(msg)
54         logger.info(msg)
55         ret = func(*args, **kwargs)
56         now = datetime.datetime.now()
57         ts = now.strftime("%Y/%d/%b:%H:%M:%S%Z")
58         msg = f"[{ts}]: Exited {func.__name__}"
59         print(msg)
60         logger.info(msg)
61         return ret
62     return wrapper_invocation_logged
63
64
65 def debug_args(func: Callable) -> Callable:
66     """Print the function signature and return value at each call."""
67
68     @functools.wraps(func)
69     def wrapper_debug_args(*args, **kwargs):
70         args_repr = [f"{repr(a)}:{type(a)}" for a in args]
71         kwargs_repr = [f"{k}={v!r}:{type(v)}" for k, v in kwargs.items()]
72         signature = ", ".join(args_repr + kwargs_repr)
73         msg = f"Calling {func.__name__}({signature})"
74         print(msg)
75         logger.info(msg)
76         value = func(*args, **kwargs)
77         msg = f"{func.__name__!r} returned {value!r}:{type(value)}"
78         logger.info(msg)
79         return value
80     return wrapper_debug_args
81
82
83 def debug_count_calls(func: Callable) -> Callable:
84     """Count function invocations and print a message befor every call."""
85
86     @functools.wraps(func)
87     def wrapper_debug_count_calls(*args, **kwargs):
88         wrapper_debug_count_calls.num_calls += 1
89         msg = f"Call #{wrapper_debug_count_calls.num_calls} of {func.__name__!r}"
90         print(msg)
91         logger.info(msg)
92         return func(*args, **kwargs)
93     wrapper_debug_count_calls.num_calls = 0
94     return wrapper_debug_count_calls
95
96
97 class DelayWhen(enum.Enum):
98     BEFORE_CALL = 1
99     AFTER_CALL = 2
100     BEFORE_AND_AFTER = 3
101
102
103 def delay(
104     _func: Callable = None,
105     *,
106     seconds: float = 1.0,
107     when: DelayWhen = DelayWhen.BEFORE_CALL,
108 ) -> Callable:
109     """Delay the execution of a function by sleeping before and/or after.
110
111     Slow down a function by inserting a delay before and/or after its
112     invocation.
113     """
114
115     def decorator_delay(func: Callable) -> Callable:
116         @functools.wraps(func)
117         def wrapper_delay(*args, **kwargs):
118             if when & DelayWhen.BEFORE_CALL:
119                 logger.debug(
120                     f"@delay for {seconds}s BEFORE_CALL to {func.__name__}"
121                 )
122                 time.sleep(seconds)
123             retval = func(*args, **kwargs)
124             if when & DelayWhen.AFTER_CALL:
125                 logger.debug(
126                     f"@delay for {seconds}s AFTER_CALL to {func.__name__}"
127                 )
128                 time.sleep(seconds)
129             return retval
130         return wrapper_delay
131
132     if _func is None:
133         return decorator_delay
134     else:
135         return decorator_delay(_func)
136
137
138 class _SingletonWrapper:
139     """
140     A singleton wrapper class. Its instances would be created
141     for each decorated class.
142     """
143
144     def __init__(self, cls):
145         self.__wrapped__ = cls
146         self._instance = None
147
148     def __call__(self, *args, **kwargs):
149         """Returns a single instance of decorated class"""
150         logger.debug(
151             f"@singleton returning global instance of {self.__wrapped__.__name__}"
152         )
153         if self._instance is None:
154             self._instance = self.__wrapped__(*args, **kwargs)
155         return self._instance
156
157
158 def singleton(cls):
159     """
160     A singleton decorator. Returns a wrapper objects. A call on that object
161     returns a single instance object of decorated class. Use the __wrapped__
162     attribute to access decorated class directly in unit tests
163     """
164     return _SingletonWrapper(cls)
165
166
167 def memoized(func: Callable) -> Callable:
168     """Keep a cache of previous function call results.
169
170     The cache here is a dict with a key based on the arguments to the
171     call.  Consider also: functools.lru_cache for a more advanced
172     implementation.
173     """
174
175     @functools.wraps(func)
176     def wrapper_memoized(*args, **kwargs):
177         cache_key = args + tuple(kwargs.items())
178         if cache_key not in wrapper_memoized.cache:
179             value = func(*args, **kwargs)
180             logger.debug(
181                 f"Memoizing {cache_key} => {value} for {func.__name__}"
182             )
183             wrapper_memoized.cache[cache_key] = value
184         else:
185             logger.debug(f"Returning memoized value for {func.__name__}")
186         return wrapper_memoized.cache[cache_key]
187     wrapper_memoized.cache = dict()
188     return wrapper_memoized
189
190
191 def retry_predicate(
192     tries: int,
193     *,
194     predicate: Callable[..., bool],
195     delay_sec: float = 3,
196     backoff: float = 2.0,
197 ):
198     """Retries a function or method up to a certain number of times
199     with a prescribed initial delay period and backoff rate.
200
201     tries is the maximum number of attempts to run the function.
202     delay_sec sets the initial delay period in seconds.
203     backoff is a multiplied (must be >1) used to modify the delay.
204     predicate is a function that will be passed the retval of the
205       decorated function and must return True to stop or False to
206       retry.
207     """
208     if backoff < 1:
209         msg = f"backoff must be greater than or equal to 1, got {backoff}"
210         logger.critical(msg)
211         raise ValueError(msg)
212
213     tries = math.floor(tries)
214     if tries < 0:
215         msg = f"tries must be 0 or greater, got {tries}"
216         logger.critical(msg)
217         raise ValueError(msg)
218
219     if delay_sec <= 0:
220         msg = f"delay_sec must be greater than 0, got {delay_sec}"
221         logger.critical(msg)
222         raise ValueError(msg)
223
224     def deco_retry(f):
225         @functools.wraps(f)
226         def f_retry(*args, **kwargs):
227             mtries, mdelay = tries, delay_sec  # make mutable
228             retval = f(*args, **kwargs)
229             while mtries > 0:
230                 if predicate(retval) is True:
231                     return retval
232                 logger.debug("Predicate failed, sleeping and retrying.")
233                 mtries -= 1
234                 time.sleep(mdelay)
235                 mdelay *= backoff
236                 retval = f(*args, **kwargs)
237             return retval
238         return f_retry
239     return deco_retry
240
241
242 def retry_if_false(tries: int, *, delay_sec=3.0, backoff=2.0):
243     return retry_predicate(
244         tries,
245         predicate=lambda x: x is True,
246         delay_sec=delay_sec,
247         backoff=backoff,
248     )
249
250
251 def retry_if_none(tries: int, *, delay_sec=3.0, backoff=2.0):
252     return retry_predicate(
253         tries,
254         predicate=lambda x: x is not None,
255         delay_sec=delay_sec,
256         backoff=backoff,
257     )
258
259
260 def deprecated(func):
261     """This is a decorator which can be used to mark functions
262     as deprecated. It will result in a warning being emitted
263     when the function is used.
264     """
265
266     @functools.wraps(func)
267     def wrapper_deprecated(*args, **kwargs):
268         msg = f"Call to deprecated function {func.__name__}"
269         logger.warning(msg)
270         warnings.warn(msg, category=DeprecationWarning)
271         return func(*args, **kwargs)
272
273     return wrapper_deprecated
274
275
276 def thunkify(func):
277     """
278     Make a function immediately return a function of no args which,
279     when called, waits for the result, which will start being
280     processed in another thread.
281     """
282
283     @functools.wraps(func)
284     def lazy_thunked(*args, **kwargs):
285         wait_event = threading.Event()
286
287         result = [None]
288         exc = [False, None]
289
290         def worker_func():
291             try:
292                 func_result = func(*args, **kwargs)
293                 result[0] = func_result
294             except Exception:
295                 exc[0] = True
296                 exc[1] = sys.exc_info()  # (type, value, traceback)
297                 msg = f"Thunkify has thrown an exception (will be raised on thunk()):\n{traceback.format_exc()}"
298                 logger.warning(msg)
299                 print(msg)
300             finally:
301                 wait_event.set()
302
303         def thunk():
304             wait_event.wait()
305             if exc[0]:
306                 raise exc[1][0](exc[1][1])
307             return result[0]
308
309         threading.Thread(target=worker_func).start()
310         return thunk
311
312     return lazy_thunked
313
314
315 ############################################################
316 # Timeout
317 ############################################################
318
319 # http://www.saltycrane.com/blog/2010/04/using-python-timeout-decorator-uploading-s3/
320 # Used work of Stephen "Zero" Chappell <[email protected]>
321 # in https://code.google.com/p/verse-quiz/source/browse/trunk/timeout.py
322
323
324 def _raise_exception(exception, error_message: Optional[str]):
325     if error_message is None:
326         raise exception()
327     else:
328         raise exception(error_message)
329
330
331 def _target(queue, function, *args, **kwargs):
332     """Run a function with arguments and return output via a queue.
333
334     This is a helper function for the Process created in _Timeout. It runs
335     the function with positional arguments and keyword arguments and then
336     returns the function's output by way of a queue. If an exception gets
337     raised, it is returned to _Timeout to be raised by the value property.
338     """
339     try:
340         queue.put((True, function(*args, **kwargs)))
341     except:
342         queue.put((False, sys.exc_info()[1]))
343
344
345 class _Timeout(object):
346     """Wrap a function and add a timeout (limit) attribute to it.
347
348     Instances of this class are automatically generated by the add_timeout
349     function defined below.
350     """
351
352     def __init__(
353         self,
354         function: Callable,
355         timeout_exception: Exception,
356         error_message: str,
357         seconds: float,
358     ):
359         self.__limit = seconds
360         self.__function = function
361         self.__timeout_exception = timeout_exception
362         self.__error_message = error_message
363         self.__name__ = function.__name__
364         self.__doc__ = function.__doc__
365         self.__timeout = time.time()
366         self.__process = multiprocessing.Process()
367         self.__queue: multiprocessing.queues.Queue = multiprocessing.Queue()
368
369     def __call__(self, *args, **kwargs):
370         """Execute the embedded function object asynchronously.
371
372         The function given to the constructor is transparently called and
373         requires that "ready" be intermittently polled. If and when it is
374         True, the "value" property may then be checked for returned data.
375         """
376         self.__limit = kwargs.pop("timeout", self.__limit)
377         self.__queue = multiprocessing.Queue(1)
378         args = (self.__queue, self.__function) + args
379         self.__process = multiprocessing.Process(
380             target=_target, args=args, kwargs=kwargs
381         )
382         self.__process.daemon = True
383         self.__process.start()
384         if self.__limit is not None:
385             self.__timeout = self.__limit + time.time()
386         while not self.ready:
387             time.sleep(0.1)
388         return self.value
389
390     def cancel(self):
391         """Terminate any possible execution of the embedded function."""
392         if self.__process.is_alive():
393             self.__process.terminate()
394         _raise_exception(self.__timeout_exception, self.__error_message)
395
396     @property
397     def ready(self):
398         """Read-only property indicating status of "value" property."""
399         if self.__limit and self.__timeout < time.time():
400             self.cancel()
401         return self.__queue.full() and not self.__queue.empty()
402
403     @property
404     def value(self):
405         """Read-only property containing data returned from function."""
406         if self.ready is True:
407             flag, load = self.__queue.get()
408             if flag:
409                 return load
410             raise load
411
412
413 def timeout(
414     seconds: float = 1.0,
415     use_signals: Optional[bool] = None,
416     timeout_exception=exceptions.TimeoutError,
417     error_message="Function call timed out",
418 ):
419     """Add a timeout parameter to a function and return the function.
420
421     Note: the use_signals parameter is included in order to support
422     multiprocessing scenarios (signal can only be used from the process'
423     main thread).  When not using signals, timeout granularity will be
424     rounded to the nearest 0.1s.
425
426     Raises an exception when the timeout is reached.
427
428     It is illegal to pass anything other than a function as the first
429     parameter.  The function is wrapped and returned to the caller.
430     """
431     if use_signals is None:
432         import thread_utils
433         use_signals = thread_utils.is_current_thread_main_thread()
434
435     def decorate(function):
436
437         if use_signals:
438
439             def handler(signum, frame):
440                 _raise_exception(timeout_exception, error_message)
441
442             @functools.wraps(function)
443             def new_function(*args, **kwargs):
444                 new_seconds = kwargs.pop("timeout", seconds)
445                 if new_seconds:
446                     old = signal.signal(signal.SIGALRM, handler)
447                     signal.setitimer(signal.ITIMER_REAL, new_seconds)
448
449                 if not seconds:
450                     return function(*args, **kwargs)
451
452                 try:
453                     return function(*args, **kwargs)
454                 finally:
455                     if new_seconds:
456                         signal.setitimer(signal.ITIMER_REAL, 0)
457                         signal.signal(signal.SIGALRM, old)
458
459             return new_function
460         else:
461
462             @functools.wraps(function)
463             def new_function(*args, **kwargs):
464                 timeout_wrapper = _Timeout(
465                     function, timeout_exception, error_message, seconds
466                 )
467                 return timeout_wrapper(*args, **kwargs)
468
469             return new_function
470
471     return decorate
472
473
474 class non_reentrant_code(object):
475     def __init__(self):
476         self._lock = threading.RLock
477         self._entered = False
478
479     def __call__(self, f):
480         def _gatekeeper(*args, **kwargs):
481             with self._lock:
482                 if self._entered:
483                     return
484                 self._entered = True
485                 f(*args, **kwargs)
486                 self._entered = False
487
488         return _gatekeeper
489
490
491 class rlocked(object):
492     def __init__(self):
493         self._lock = threading.RLock
494         self._entered = False
495
496     def __call__(self, f):
497         def _gatekeeper(*args, **kwargs):
498             with self._lock:
499                 if self._entered:
500                     return
501                 self._entered = True
502                 f(*args, **kwargs)
503                 self._entered = False
504         return _gatekeeper
505
506
507 def call_with_sample_rate(sample_rate: float) -> Callable:
508     if not 0.0 <= sample_rate <= 1.0:
509         msg = f"sample_rate must be between [0, 1]. Got {sample_rate}."
510         logger.critical(msg)
511         raise ValueError(msg)
512
513     def decorator(f):
514         @functools.wraps(f)
515         def _call_with_sample_rate(*args, **kwargs):
516             if random.uniform(0, 1) < sample_rate:
517                 return f(*args, **kwargs)
518             else:
519                 logger.debug(
520                     f"@call_with_sample_rate skipping a call to {f.__name__}"
521                 )
522         return _call_with_sample_rate
523     return decorator
524
525
526 def decorate_matching_methods_with(decorator, acl=None):
527     """Apply decorator to all methods in a class whose names begin with
528     prefix.  If prefix is None (default), decorate all methods in the
529     class.
530     """
531     def decorate_the_class(cls):
532         for name, m in inspect.getmembers(cls, inspect.isfunction):
533             if acl is None:
534                 setattr(cls, name, decorator(m))
535             else:
536                 if acl(name):
537                     setattr(cls, name, decorator(m))
538         return cls
539     return decorate_the_class