Make smart futures avoid polling.
[python_utils.git] / decorator_utils.py
1 #!/usr/bin/env python3
2
3 """Decorators."""
4
5 import datetime
6 import enum
7 import functools
8 import inspect
9 import logging
10 import math
11 import multiprocessing
12 import random
13 import signal
14 import sys
15 import threading
16 import time
17 import traceback
18 from typing import Callable, Optional
19 import warnings
20
21 # This module is commonly used by others in here and should avoid
22 # taking any unnecessary dependencies back on them.
23 import exceptions
24
25
26 logger = logging.getLogger(__name__)
27
28
29 def timed(func: Callable) -> Callable:
30     """Print the runtime of the decorated function."""
31
32     @functools.wraps(func)
33     def wrapper_timer(*args, **kwargs):
34         start_time = time.perf_counter()
35         value = func(*args, **kwargs)
36         end_time = time.perf_counter()
37         run_time = end_time - start_time
38         msg = f"Finished {func.__name__!r} in {run_time:.4f}s"
39         print(msg)
40         logger.info(msg)
41         return value
42     return wrapper_timer
43
44
45 def invocation_logged(func: Callable) -> Callable:
46     """Log the call of a function."""
47
48     @functools.wraps(func)
49     def wrapper_invocation_logged(*args, **kwargs):
50         msg = f"Entered {func.__qualname__}"
51         print(msg)
52         logger.info(msg)
53         ret = func(*args, **kwargs)
54         msg = f"Exited {func.__qualname__}"
55         print(msg)
56         logger.info(msg)
57         return ret
58     return wrapper_invocation_logged
59
60
61 def debug_args(func: Callable) -> Callable:
62     """Print the function signature and return value at each call."""
63
64     @functools.wraps(func)
65     def wrapper_debug_args(*args, **kwargs):
66         args_repr = [f"{repr(a)}:{type(a)}" for a in args]
67         kwargs_repr = [f"{k}={v!r}:{type(v)}" for k, v in kwargs.items()]
68         signature = ", ".join(args_repr + kwargs_repr)
69         msg = f"Calling {func.__name__}({signature})"
70         print(msg)
71         logger.info(msg)
72         value = func(*args, **kwargs)
73         msg = f"{func.__name__!r} returned {value!r}:{type(value)}"
74         logger.info(msg)
75         return value
76     return wrapper_debug_args
77
78
79 def debug_count_calls(func: Callable) -> Callable:
80     """Count function invocations and print a message befor every call."""
81
82     @functools.wraps(func)
83     def wrapper_debug_count_calls(*args, **kwargs):
84         wrapper_debug_count_calls.num_calls += 1
85         msg = f"Call #{wrapper_debug_count_calls.num_calls} of {func.__name__!r}"
86         print(msg)
87         logger.info(msg)
88         return func(*args, **kwargs)
89     wrapper_debug_count_calls.num_calls = 0
90     return wrapper_debug_count_calls
91
92
93 class DelayWhen(enum.Enum):
94     BEFORE_CALL = 1
95     AFTER_CALL = 2
96     BEFORE_AND_AFTER = 3
97
98
99 def delay(
100     _func: Callable = None,
101     *,
102     seconds: float = 1.0,
103     when: DelayWhen = DelayWhen.BEFORE_CALL,
104 ) -> Callable:
105     """Delay the execution of a function by sleeping before and/or after.
106
107     Slow down a function by inserting a delay before and/or after its
108     invocation.
109     """
110
111     def decorator_delay(func: Callable) -> Callable:
112         @functools.wraps(func)
113         def wrapper_delay(*args, **kwargs):
114             if when & DelayWhen.BEFORE_CALL:
115                 logger.debug(
116                     f"@delay for {seconds}s BEFORE_CALL to {func.__name__}"
117                 )
118                 time.sleep(seconds)
119             retval = func(*args, **kwargs)
120             if when & DelayWhen.AFTER_CALL:
121                 logger.debug(
122                     f"@delay for {seconds}s AFTER_CALL to {func.__name__}"
123                 )
124                 time.sleep(seconds)
125             return retval
126         return wrapper_delay
127
128     if _func is None:
129         return decorator_delay
130     else:
131         return decorator_delay(_func)
132
133
134 class _SingletonWrapper:
135     """
136     A singleton wrapper class. Its instances would be created
137     for each decorated class.
138     """
139
140     def __init__(self, cls):
141         self.__wrapped__ = cls
142         self._instance = None
143
144     def __call__(self, *args, **kwargs):
145         """Returns a single instance of decorated class"""
146         logger.debug(
147             f"@singleton returning global instance of {self.__wrapped__.__name__}"
148         )
149         if self._instance is None:
150             self._instance = self.__wrapped__(*args, **kwargs)
151         return self._instance
152
153
154 def singleton(cls):
155     """
156     A singleton decorator. Returns a wrapper objects. A call on that object
157     returns a single instance object of decorated class. Use the __wrapped__
158     attribute to access decorated class directly in unit tests
159     """
160     return _SingletonWrapper(cls)
161
162
163 def memoized(func: Callable) -> Callable:
164     """Keep a cache of previous function call results.
165
166     The cache here is a dict with a key based on the arguments to the
167     call.  Consider also: functools.lru_cache for a more advanced
168     implementation.
169     """
170
171     @functools.wraps(func)
172     def wrapper_memoized(*args, **kwargs):
173         cache_key = args + tuple(kwargs.items())
174         if cache_key not in wrapper_memoized.cache:
175             value = func(*args, **kwargs)
176             logger.debug(
177                 f"Memoizing {cache_key} => {value} for {func.__name__}"
178             )
179             wrapper_memoized.cache[cache_key] = value
180         else:
181             logger.debug(f"Returning memoized value for {func.__name__}")
182         return wrapper_memoized.cache[cache_key]
183     wrapper_memoized.cache = dict()
184     return wrapper_memoized
185
186
187 def retry_predicate(
188     tries: int,
189     *,
190     predicate: Callable[..., bool],
191     delay_sec: float = 3.0,
192     backoff: float = 2.0,
193 ):
194     """Retries a function or method up to a certain number of times
195     with a prescribed initial delay period and backoff rate.
196
197     tries is the maximum number of attempts to run the function.
198     delay_sec sets the initial delay period in seconds.
199     backoff is a multiplied (must be >1) used to modify the delay.
200     predicate is a function that will be passed the retval of the
201     decorated function and must return True to stop or False to
202     retry.
203     """
204     if backoff < 1.0:
205         msg = f"backoff must be greater than or equal to 1, got {backoff}"
206         logger.critical(msg)
207         raise ValueError(msg)
208
209     tries = math.floor(tries)
210     if tries < 0:
211         msg = f"tries must be 0 or greater, got {tries}"
212         logger.critical(msg)
213         raise ValueError(msg)
214
215     if delay_sec <= 0:
216         msg = f"delay_sec must be greater than 0, got {delay_sec}"
217         logger.critical(msg)
218         raise ValueError(msg)
219
220     def deco_retry(f):
221         @functools.wraps(f)
222         def f_retry(*args, **kwargs):
223             mtries, mdelay = tries, delay_sec  # make mutable
224             logger.debug(f'deco_retry: will make up to {mtries} attempts...')
225             retval = f(*args, **kwargs)
226             while mtries > 0:
227                 if predicate(retval) is True:
228                     logger.debug('Predicate succeeded, deco_retry is done.')
229                     return retval
230                 logger.debug("Predicate failed, sleeping and retrying.")
231                 mtries -= 1
232                 time.sleep(mdelay)
233                 mdelay *= backoff
234                 retval = f(*args, **kwargs)
235             return retval
236         return f_retry
237     return deco_retry
238
239
240 def retry_if_false(tries: int, *, delay_sec=3.0, backoff=2.0):
241     return retry_predicate(
242         tries,
243         predicate=lambda x: x is True,
244         delay_sec=delay_sec,
245         backoff=backoff,
246     )
247
248
249 def retry_if_none(tries: int, *, delay_sec=3.0, backoff=2.0):
250     return retry_predicate(
251         tries,
252         predicate=lambda x: x is not None,
253         delay_sec=delay_sec,
254         backoff=backoff,
255     )
256
257
258 def deprecated(func):
259     """This is a decorator which can be used to mark functions
260     as deprecated. It will result in a warning being emitted
261     when the function is used.
262     """
263
264     @functools.wraps(func)
265     def wrapper_deprecated(*args, **kwargs):
266         msg = f"Call to deprecated function {func.__name__}"
267         logger.warning(msg)
268         warnings.warn(msg, category=DeprecationWarning)
269         return func(*args, **kwargs)
270
271     return wrapper_deprecated
272
273
274 def thunkify(func):
275     """
276     Make a function immediately return a function of no args which,
277     when called, waits for the result, which will start being
278     processed in another thread.
279     """
280
281     @functools.wraps(func)
282     def lazy_thunked(*args, **kwargs):
283         wait_event = threading.Event()
284
285         result = [None]
286         exc = [False, None]
287
288         def worker_func():
289             try:
290                 func_result = func(*args, **kwargs)
291                 result[0] = func_result
292             except Exception:
293                 exc[0] = True
294                 exc[1] = sys.exc_info()  # (type, value, traceback)
295                 msg = f"Thunkify has thrown an exception (will be raised on thunk()):\n{traceback.format_exc()}"
296                 print(msg)
297                 logger.warning(msg)
298             finally:
299                 wait_event.set()
300
301         def thunk():
302             wait_event.wait()
303             if exc[0]:
304                 raise exc[1][0](exc[1][1])
305             return result[0]
306
307         threading.Thread(target=worker_func).start()
308         return thunk
309
310     return lazy_thunked
311
312
313 ############################################################
314 # Timeout
315 ############################################################
316
317 # http://www.saltycrane.com/blog/2010/04/using-python-timeout-decorator-uploading-s3/
318 # Used work of Stephen "Zero" Chappell <[email protected]>
319 # in https://code.google.com/p/verse-quiz/source/browse/trunk/timeout.py
320
321
322 def _raise_exception(exception, error_message: Optional[str]):
323     if error_message is None:
324         raise exception()
325     else:
326         raise exception(error_message)
327
328
329 def _target(queue, function, *args, **kwargs):
330     """Run a function with arguments and return output via a queue.
331
332     This is a helper function for the Process created in _Timeout. It runs
333     the function with positional arguments and keyword arguments and then
334     returns the function's output by way of a queue. If an exception gets
335     raised, it is returned to _Timeout to be raised by the value property.
336     """
337     try:
338         queue.put((True, function(*args, **kwargs)))
339     except Exception:
340         queue.put((False, sys.exc_info()[1]))
341
342
343 class _Timeout(object):
344     """Wrap a function and add a timeout (limit) attribute to it.
345
346     Instances of this class are automatically generated by the add_timeout
347     function defined below.
348     """
349
350     def __init__(
351         self,
352         function: Callable,
353         timeout_exception: Exception,
354         error_message: str,
355         seconds: float,
356     ):
357         self.__limit = seconds
358         self.__function = function
359         self.__timeout_exception = timeout_exception
360         self.__error_message = error_message
361         self.__name__ = function.__name__
362         self.__doc__ = function.__doc__
363         self.__timeout = time.time()
364         self.__process = multiprocessing.Process()
365         self.__queue: multiprocessing.queues.Queue = multiprocessing.Queue()
366
367     def __call__(self, *args, **kwargs):
368         """Execute the embedded function object asynchronously.
369
370         The function given to the constructor is transparently called and
371         requires that "ready" be intermittently polled. If and when it is
372         True, the "value" property may then be checked for returned data.
373         """
374         self.__limit = kwargs.pop("timeout", self.__limit)
375         self.__queue = multiprocessing.Queue(1)
376         args = (self.__queue, self.__function) + args
377         self.__process = multiprocessing.Process(
378             target=_target, args=args, kwargs=kwargs
379         )
380         self.__process.daemon = True
381         self.__process.start()
382         if self.__limit is not None:
383             self.__timeout = self.__limit + time.time()
384         while not self.ready:
385             time.sleep(0.1)
386         return self.value
387
388     def cancel(self):
389         """Terminate any possible execution of the embedded function."""
390         if self.__process.is_alive():
391             self.__process.terminate()
392         _raise_exception(self.__timeout_exception, self.__error_message)
393
394     @property
395     def ready(self):
396         """Read-only property indicating status of "value" property."""
397         if self.__limit and self.__timeout < time.time():
398             self.cancel()
399         return self.__queue.full() and not self.__queue.empty()
400
401     @property
402     def value(self):
403         """Read-only property containing data returned from function."""
404         if self.ready is True:
405             flag, load = self.__queue.get()
406             if flag:
407                 return load
408             raise load
409
410
411 def timeout(
412     seconds: float = 1.0,
413     use_signals: Optional[bool] = None,
414     timeout_exception=exceptions.TimeoutError,
415     error_message="Function call timed out",
416 ):
417     """Add a timeout parameter to a function and return the function.
418
419     Note: the use_signals parameter is included in order to support
420     multiprocessing scenarios (signal can only be used from the process'
421     main thread).  When not using signals, timeout granularity will be
422     rounded to the nearest 0.1s.
423
424     Raises an exception when the timeout is reached.
425
426     It is illegal to pass anything other than a function as the first
427     parameter.  The function is wrapped and returned to the caller.
428     """
429     if use_signals is None:
430         import thread_utils
431         use_signals = thread_utils.is_current_thread_main_thread()
432
433     def decorate(function):
434         if use_signals:
435
436             def handler(signum, frame):
437                 _raise_exception(timeout_exception, error_message)
438
439             @functools.wraps(function)
440             def new_function(*args, **kwargs):
441                 new_seconds = kwargs.pop("timeout", seconds)
442                 if new_seconds:
443                     old = signal.signal(signal.SIGALRM, handler)
444                     signal.setitimer(signal.ITIMER_REAL, new_seconds)
445
446                 if not seconds:
447                     return function(*args, **kwargs)
448
449                 try:
450                     return function(*args, **kwargs)
451                 finally:
452                     if new_seconds:
453                         signal.setitimer(signal.ITIMER_REAL, 0)
454                         signal.signal(signal.SIGALRM, old)
455
456             return new_function
457         else:
458
459             @functools.wraps(function)
460             def new_function(*args, **kwargs):
461                 timeout_wrapper = _Timeout(
462                     function, timeout_exception, error_message, seconds
463                 )
464                 return timeout_wrapper(*args, **kwargs)
465
466             return new_function
467
468     return decorate
469
470
471 class non_reentrant_code(object):
472     def __init__(self):
473         self._lock = threading.RLock
474         self._entered = False
475
476     def __call__(self, f):
477         def _gatekeeper(*args, **kwargs):
478             with self._lock:
479                 if self._entered:
480                     return
481                 self._entered = True
482                 f(*args, **kwargs)
483                 self._entered = False
484
485         return _gatekeeper
486
487
488 class rlocked(object):
489     def __init__(self):
490         self._lock = threading.RLock
491         self._entered = False
492
493     def __call__(self, f):
494         def _gatekeeper(*args, **kwargs):
495             with self._lock:
496                 if self._entered:
497                     return
498                 self._entered = True
499                 f(*args, **kwargs)
500                 self._entered = False
501         return _gatekeeper
502
503
504 def call_with_sample_rate(sample_rate: float) -> Callable:
505     if not 0.0 <= sample_rate <= 1.0:
506         msg = f"sample_rate must be between [0, 1]. Got {sample_rate}."
507         logger.critical(msg)
508         raise ValueError(msg)
509
510     def decorator(f):
511         @functools.wraps(f)
512         def _call_with_sample_rate(*args, **kwargs):
513             if random.uniform(0, 1) < sample_rate:
514                 return f(*args, **kwargs)
515             else:
516                 logger.debug(
517                     f"@call_with_sample_rate skipping a call to {f.__name__}"
518                 )
519         return _call_with_sample_rate
520     return decorator
521
522
523 def decorate_matching_methods_with(decorator, acl=None):
524     """Apply decorator to all methods in a class whose names begin with
525     prefix.  If prefix is None (default), decorate all methods in the
526     class.
527     """
528     def decorate_the_class(cls):
529         for name, m in inspect.getmembers(cls, inspect.isfunction):
530             if acl is None:
531                 setattr(cls, name, decorator(m))
532             else:
533                 if acl(name):
534                     setattr(cls, name, decorator(m))
535         return cls
536     return decorate_the_class