Add rate limiter decorator.
[python_utils.git] / decorator_utils.py
1 #!/usr/bin/env python3
2
3 """Decorators."""
4
5 import enum
6 import functools
7 import inspect
8 import logging
9 import math
10 import multiprocessing
11 import random
12 import signal
13 import sys
14 import threading
15 import time
16 import traceback
17 from typing import Any, Callable, Optional
18 import warnings
19
20 # This module is commonly used by others in here and should avoid
21 # taking any unnecessary dependencies back on them.
22 import exceptions
23
24
25 logger = logging.getLogger(__name__)
26
27
28 def timed(func: Callable) -> Callable:
29     """Print the runtime of the decorated function."""
30
31     @functools.wraps(func)
32     def wrapper_timer(*args, **kwargs):
33         start_time = time.perf_counter()
34         value = func(*args, **kwargs)
35         end_time = time.perf_counter()
36         run_time = end_time - start_time
37         msg = f"Finished {func.__name__!r} in {run_time:.4f}s"
38         print(msg)
39         logger.info(msg)
40         return value
41     return wrapper_timer
42
43
44 def invocation_logged(func: Callable) -> Callable:
45     """Log the call of a function."""
46
47     @functools.wraps(func)
48     def wrapper_invocation_logged(*args, **kwargs):
49         msg = f"Entered {func.__qualname__}"
50         print(msg)
51         logger.info(msg)
52         ret = func(*args, **kwargs)
53         msg = f"Exited {func.__qualname__}"
54         print(msg)
55         logger.info(msg)
56         return ret
57     return wrapper_invocation_logged
58
59
60 def rate_limited(n_per_second: int) -> Callable:
61     """Limit invocation of a wrapped function to n calls per second.
62     Thread safe.
63
64     """
65     min_interval = 1.0 / float(n_per_second)
66
67     def wrapper_rate_limited(func: Callable) -> Callable:
68         last_invocation_time = [0.0]
69
70         def wrapper_wrapper_rate_limited(*args, **kargs) -> Any:
71             while True:
72                 elapsed = time.clock_gettime(0) - last_invocation_time[0]
73                 wait_time = min_interval - elapsed
74                 if wait_time > 0.0:
75                     time.sleep(wait_time)
76                 else:
77                     break
78             ret = func(*args, **kargs)
79             last_invocation_time[0] = time.clock_gettime(0)
80             return ret
81         return wrapper_wrapper_rate_limited
82     return wrapper_rate_limited
83
84
85 def debug_args(func: Callable) -> Callable:
86     """Print the function signature and return value at each call."""
87
88     @functools.wraps(func)
89     def wrapper_debug_args(*args, **kwargs):
90         args_repr = [f"{repr(a)}:{type(a)}" for a in args]
91         kwargs_repr = [f"{k}={v!r}:{type(v)}" for k, v in kwargs.items()]
92         signature = ", ".join(args_repr + kwargs_repr)
93         msg = f"Calling {func.__name__}({signature})"
94         print(msg)
95         logger.info(msg)
96         value = func(*args, **kwargs)
97         msg = f"{func.__name__!r} returned {value!r}:{type(value)}"
98         logger.info(msg)
99         return value
100     return wrapper_debug_args
101
102
103 def debug_count_calls(func: Callable) -> Callable:
104     """Count function invocations and print a message befor every call."""
105
106     @functools.wraps(func)
107     def wrapper_debug_count_calls(*args, **kwargs):
108         wrapper_debug_count_calls.num_calls += 1
109         msg = f"Call #{wrapper_debug_count_calls.num_calls} of {func.__name__!r}"
110         print(msg)
111         logger.info(msg)
112         return func(*args, **kwargs)
113     wrapper_debug_count_calls.num_calls = 0
114     return wrapper_debug_count_calls
115
116
117 class DelayWhen(enum.Enum):
118     BEFORE_CALL = 1
119     AFTER_CALL = 2
120     BEFORE_AND_AFTER = 3
121
122
123 def delay(
124     _func: Callable = None,
125     *,
126     seconds: float = 1.0,
127     when: DelayWhen = DelayWhen.BEFORE_CALL,
128 ) -> Callable:
129     """Delay the execution of a function by sleeping before and/or after.
130
131     Slow down a function by inserting a delay before and/or after its
132     invocation.
133     """
134
135     def decorator_delay(func: Callable) -> Callable:
136         @functools.wraps(func)
137         def wrapper_delay(*args, **kwargs):
138             if when & DelayWhen.BEFORE_CALL:
139                 logger.debug(
140                     f"@delay for {seconds}s BEFORE_CALL to {func.__name__}"
141                 )
142                 time.sleep(seconds)
143             retval = func(*args, **kwargs)
144             if when & DelayWhen.AFTER_CALL:
145                 logger.debug(
146                     f"@delay for {seconds}s AFTER_CALL to {func.__name__}"
147                 )
148                 time.sleep(seconds)
149             return retval
150         return wrapper_delay
151
152     if _func is None:
153         return decorator_delay
154     else:
155         return decorator_delay(_func)
156
157
158 class _SingletonWrapper:
159     """
160     A singleton wrapper class. Its instances would be created
161     for each decorated class.
162     """
163
164     def __init__(self, cls):
165         self.__wrapped__ = cls
166         self._instance = None
167
168     def __call__(self, *args, **kwargs):
169         """Returns a single instance of decorated class"""
170         logger.debug(
171             f"@singleton returning global instance of {self.__wrapped__.__name__}"
172         )
173         if self._instance is None:
174             self._instance = self.__wrapped__(*args, **kwargs)
175         return self._instance
176
177
178 def singleton(cls):
179     """
180     A singleton decorator. Returns a wrapper objects. A call on that object
181     returns a single instance object of decorated class. Use the __wrapped__
182     attribute to access decorated class directly in unit tests
183     """
184     return _SingletonWrapper(cls)
185
186
187 def memoized(func: Callable) -> Callable:
188     """Keep a cache of previous function call results.
189
190     The cache here is a dict with a key based on the arguments to the
191     call.  Consider also: functools.lru_cache for a more advanced
192     implementation.
193     """
194
195     @functools.wraps(func)
196     def wrapper_memoized(*args, **kwargs):
197         cache_key = args + tuple(kwargs.items())
198         if cache_key not in wrapper_memoized.cache:
199             value = func(*args, **kwargs)
200             logger.debug(
201                 f"Memoizing {cache_key} => {value} for {func.__name__}"
202             )
203             wrapper_memoized.cache[cache_key] = value
204         else:
205             logger.debug(f"Returning memoized value for {func.__name__}")
206         return wrapper_memoized.cache[cache_key]
207     wrapper_memoized.cache = dict()
208     return wrapper_memoized
209
210
211 def retry_predicate(
212     tries: int,
213     *,
214     predicate: Callable[..., bool],
215     delay_sec: float = 3.0,
216     backoff: float = 2.0,
217 ):
218     """Retries a function or method up to a certain number of times
219     with a prescribed initial delay period and backoff rate.
220
221     tries is the maximum number of attempts to run the function.
222     delay_sec sets the initial delay period in seconds.
223     backoff is a multiplied (must be >1) used to modify the delay.
224     predicate is a function that will be passed the retval of the
225     decorated function and must return True to stop or False to
226     retry.
227     """
228     if backoff < 1.0:
229         msg = f"backoff must be greater than or equal to 1, got {backoff}"
230         logger.critical(msg)
231         raise ValueError(msg)
232
233     tries = math.floor(tries)
234     if tries < 0:
235         msg = f"tries must be 0 or greater, got {tries}"
236         logger.critical(msg)
237         raise ValueError(msg)
238
239     if delay_sec <= 0:
240         msg = f"delay_sec must be greater than 0, got {delay_sec}"
241         logger.critical(msg)
242         raise ValueError(msg)
243
244     def deco_retry(f):
245         @functools.wraps(f)
246         def f_retry(*args, **kwargs):
247             mtries, mdelay = tries, delay_sec  # make mutable
248             logger.debug(f'deco_retry: will make up to {mtries} attempts...')
249             retval = f(*args, **kwargs)
250             while mtries > 0:
251                 if predicate(retval) is True:
252                     logger.debug('Predicate succeeded, deco_retry is done.')
253                     return retval
254                 logger.debug("Predicate failed, sleeping and retrying.")
255                 mtries -= 1
256                 time.sleep(mdelay)
257                 mdelay *= backoff
258                 retval = f(*args, **kwargs)
259             return retval
260         return f_retry
261     return deco_retry
262
263
264 def retry_if_false(tries: int, *, delay_sec=3.0, backoff=2.0):
265     return retry_predicate(
266         tries,
267         predicate=lambda x: x is True,
268         delay_sec=delay_sec,
269         backoff=backoff,
270     )
271
272
273 def retry_if_none(tries: int, *, delay_sec=3.0, backoff=2.0):
274     return retry_predicate(
275         tries,
276         predicate=lambda x: x is not None,
277         delay_sec=delay_sec,
278         backoff=backoff,
279     )
280
281
282 def deprecated(func):
283     """This is a decorator which can be used to mark functions
284     as deprecated. It will result in a warning being emitted
285     when the function is used.
286     """
287
288     @functools.wraps(func)
289     def wrapper_deprecated(*args, **kwargs):
290         msg = f"Call to deprecated function {func.__name__}"
291         logger.warning(msg)
292         warnings.warn(msg, category=DeprecationWarning)
293         return func(*args, **kwargs)
294
295     return wrapper_deprecated
296
297
298 def thunkify(func):
299     """
300     Make a function immediately return a function of no args which,
301     when called, waits for the result, which will start being
302     processed in another thread.
303     """
304
305     @functools.wraps(func)
306     def lazy_thunked(*args, **kwargs):
307         wait_event = threading.Event()
308
309         result = [None]
310         exc = [False, None]
311
312         def worker_func():
313             try:
314                 func_result = func(*args, **kwargs)
315                 result[0] = func_result
316             except Exception:
317                 exc[0] = True
318                 exc[1] = sys.exc_info()  # (type, value, traceback)
319                 msg = f"Thunkify has thrown an exception (will be raised on thunk()):\n{traceback.format_exc()}"
320                 print(msg)
321                 logger.warning(msg)
322             finally:
323                 wait_event.set()
324
325         def thunk():
326             wait_event.wait()
327             if exc[0]:
328                 raise exc[1][0](exc[1][1])
329             return result[0]
330
331         threading.Thread(target=worker_func).start()
332         return thunk
333
334     return lazy_thunked
335
336
337 ############################################################
338 # Timeout
339 ############################################################
340
341 # http://www.saltycrane.com/blog/2010/04/using-python-timeout-decorator-uploading-s3/
342 # Used work of Stephen "Zero" Chappell <[email protected]>
343 # in https://code.google.com/p/verse-quiz/source/browse/trunk/timeout.py
344
345
346 def _raise_exception(exception, error_message: Optional[str]):
347     if error_message is None:
348         raise exception()
349     else:
350         raise exception(error_message)
351
352
353 def _target(queue, function, *args, **kwargs):
354     """Run a function with arguments and return output via a queue.
355
356     This is a helper function for the Process created in _Timeout. It runs
357     the function with positional arguments and keyword arguments and then
358     returns the function's output by way of a queue. If an exception gets
359     raised, it is returned to _Timeout to be raised by the value property.
360     """
361     try:
362         queue.put((True, function(*args, **kwargs)))
363     except Exception:
364         queue.put((False, sys.exc_info()[1]))
365
366
367 class _Timeout(object):
368     """Wrap a function and add a timeout (limit) attribute to it.
369
370     Instances of this class are automatically generated by the add_timeout
371     function defined below.
372     """
373
374     def __init__(
375         self,
376         function: Callable,
377         timeout_exception: Exception,
378         error_message: str,
379         seconds: float,
380     ):
381         self.__limit = seconds
382         self.__function = function
383         self.__timeout_exception = timeout_exception
384         self.__error_message = error_message
385         self.__name__ = function.__name__
386         self.__doc__ = function.__doc__
387         self.__timeout = time.time()
388         self.__process = multiprocessing.Process()
389         self.__queue: multiprocessing.queues.Queue = multiprocessing.Queue()
390
391     def __call__(self, *args, **kwargs):
392         """Execute the embedded function object asynchronously.
393
394         The function given to the constructor is transparently called and
395         requires that "ready" be intermittently polled. If and when it is
396         True, the "value" property may then be checked for returned data.
397         """
398         self.__limit = kwargs.pop("timeout", self.__limit)
399         self.__queue = multiprocessing.Queue(1)
400         args = (self.__queue, self.__function) + args
401         self.__process = multiprocessing.Process(
402             target=_target, args=args, kwargs=kwargs
403         )
404         self.__process.daemon = True
405         self.__process.start()
406         if self.__limit is not None:
407             self.__timeout = self.__limit + time.time()
408         while not self.ready:
409             time.sleep(0.1)
410         return self.value
411
412     def cancel(self):
413         """Terminate any possible execution of the embedded function."""
414         if self.__process.is_alive():
415             self.__process.terminate()
416         _raise_exception(self.__timeout_exception, self.__error_message)
417
418     @property
419     def ready(self):
420         """Read-only property indicating status of "value" property."""
421         if self.__limit and self.__timeout < time.time():
422             self.cancel()
423         return self.__queue.full() and not self.__queue.empty()
424
425     @property
426     def value(self):
427         """Read-only property containing data returned from function."""
428         if self.ready is True:
429             flag, load = self.__queue.get()
430             if flag:
431                 return load
432             raise load
433
434
435 def timeout(
436     seconds: float = 1.0,
437     use_signals: Optional[bool] = None,
438     timeout_exception=exceptions.TimeoutError,
439     error_message="Function call timed out",
440 ):
441     """Add a timeout parameter to a function and return the function.
442
443     Note: the use_signals parameter is included in order to support
444     multiprocessing scenarios (signal can only be used from the process'
445     main thread).  When not using signals, timeout granularity will be
446     rounded to the nearest 0.1s.
447
448     Raises an exception when the timeout is reached.
449
450     It is illegal to pass anything other than a function as the first
451     parameter.  The function is wrapped and returned to the caller.
452     """
453     if use_signals is None:
454         import thread_utils
455         use_signals = thread_utils.is_current_thread_main_thread()
456
457     def decorate(function):
458         if use_signals:
459
460             def handler(signum, frame):
461                 _raise_exception(timeout_exception, error_message)
462
463             @functools.wraps(function)
464             def new_function(*args, **kwargs):
465                 new_seconds = kwargs.pop("timeout", seconds)
466                 if new_seconds:
467                     old = signal.signal(signal.SIGALRM, handler)
468                     signal.setitimer(signal.ITIMER_REAL, new_seconds)
469
470                 if not seconds:
471                     return function(*args, **kwargs)
472
473                 try:
474                     return function(*args, **kwargs)
475                 finally:
476                     if new_seconds:
477                         signal.setitimer(signal.ITIMER_REAL, 0)
478                         signal.signal(signal.SIGALRM, old)
479
480             return new_function
481         else:
482
483             @functools.wraps(function)
484             def new_function(*args, **kwargs):
485                 timeout_wrapper = _Timeout(
486                     function, timeout_exception, error_message, seconds
487                 )
488                 return timeout_wrapper(*args, **kwargs)
489
490             return new_function
491
492     return decorate
493
494
495 class non_reentrant_code(object):
496     def __init__(self):
497         self._lock = threading.RLock
498         self._entered = False
499
500     def __call__(self, f):
501         def _gatekeeper(*args, **kwargs):
502             with self._lock:
503                 if self._entered:
504                     return
505                 self._entered = True
506                 f(*args, **kwargs)
507                 self._entered = False
508
509         return _gatekeeper
510
511
512 class rlocked(object):
513     def __init__(self):
514         self._lock = threading.RLock
515         self._entered = False
516
517     def __call__(self, f):
518         def _gatekeeper(*args, **kwargs):
519             with self._lock:
520                 if self._entered:
521                     return
522                 self._entered = True
523                 f(*args, **kwargs)
524                 self._entered = False
525         return _gatekeeper
526
527
528 def call_with_sample_rate(sample_rate: float) -> Callable:
529     if not 0.0 <= sample_rate <= 1.0:
530         msg = f"sample_rate must be between [0, 1]. Got {sample_rate}."
531         logger.critical(msg)
532         raise ValueError(msg)
533
534     def decorator(f):
535         @functools.wraps(f)
536         def _call_with_sample_rate(*args, **kwargs):
537             if random.uniform(0, 1) < sample_rate:
538                 return f(*args, **kwargs)
539             else:
540                 logger.debug(
541                     f"@call_with_sample_rate skipping a call to {f.__name__}"
542                 )
543         return _call_with_sample_rate
544     return decorator
545
546
547 def decorate_matching_methods_with(decorator, acl=None):
548     """Apply decorator to all methods in a class whose names begin with
549     prefix.  If prefix is None (default), decorate all methods in the
550     class.
551     """
552     def decorate_the_class(cls):
553         for name, m in inspect.getmembers(cls, inspect.isfunction):
554             if acl is None:
555                 setattr(cls, name, decorator(m))
556             else:
557                 if acl(name):
558                     setattr(cls, name, decorator(m))
559         return cls
560     return decorate_the_class