Make rate_limited use cvs.
[python_utils.git] / decorator_utils.py
1 #!/usr/bin/env python3
2
3 """Decorators."""
4
5 import enum
6 import functools
7 import inspect
8 import logging
9 import math
10 import multiprocessing
11 import random
12 import signal
13 import sys
14 import threading
15 import time
16 import traceback
17 from typing import Any, Callable, Optional, Tuple
18 import warnings
19
20 # This module is commonly used by others in here and should avoid
21 # taking any unnecessary dependencies back on them.
22 import exceptions
23
24
25 logger = logging.getLogger(__name__)
26
27
28 def timed(func: Callable) -> Callable:
29     """Print the runtime of the decorated function."""
30
31     @functools.wraps(func)
32     def wrapper_timer(*args, **kwargs):
33         start_time = time.perf_counter()
34         value = func(*args, **kwargs)
35         end_time = time.perf_counter()
36         run_time = end_time - start_time
37         msg = f"Finished {func.__name__!r} in {run_time:.4f}s"
38         print(msg)
39         logger.info(msg)
40         return value
41     return wrapper_timer
42
43
44 def invocation_logged(func: Callable) -> Callable:
45     """Log the call of a function."""
46
47     @functools.wraps(func)
48     def wrapper_invocation_logged(*args, **kwargs):
49         msg = f"Entered {func.__qualname__}"
50         print(msg)
51         logger.info(msg)
52         ret = func(*args, **kwargs)
53         msg = f"Exited {func.__qualname__}"
54         print(msg)
55         logger.info(msg)
56         return ret
57     return wrapper_invocation_logged
58
59
60 def rate_limited(n_calls: int, *, per_period_in_seconds: float = 1.0) -> Callable:
61     """Limit invocation of a wrapped function to n calls per period.
62     Thread safe.  In testing this was relatively fair with multiple
63     threads using it though that hasn't been measured.
64
65     """
66     min_interval_seconds = per_period_in_seconds / float(n_calls)
67
68     def wrapper_rate_limited(func: Callable) -> Callable:
69         cv = threading.Condition()
70         last_invocation_timestamp = [0.0]
71
72         def may_proceed() -> float:
73             now = time.time()
74             last_invocation = last_invocation_timestamp[0]
75             if last_invocation != 0.0:
76                 elapsed_since_last = now - last_invocation
77                 wait_time = min_interval_seconds - elapsed_since_last
78             else:
79                 wait_time = 0.0
80             return wait_time
81
82         def wrapper_wrapper_rate_limited(*args, **kargs) -> Any:
83             with cv:
84                 while True:
85                     cv.wait_for(
86                         lambda: may_proceed() <= 0.0,
87                         timeout=may_proceed(),
88                     )
89                     break
90             ret = func(*args, **kargs)
91             with cv:
92                 last_invocation_timestamp[0] = time.time()
93                 cv.notify()
94             return ret
95         return wrapper_wrapper_rate_limited
96     return wrapper_rate_limited
97
98
99 def debug_args(func: Callable) -> Callable:
100     """Print the function signature and return value at each call."""
101
102     @functools.wraps(func)
103     def wrapper_debug_args(*args, **kwargs):
104         args_repr = [f"{repr(a)}:{type(a)}" for a in args]
105         kwargs_repr = [f"{k}={v!r}:{type(v)}" for k, v in kwargs.items()]
106         signature = ", ".join(args_repr + kwargs_repr)
107         msg = f"Calling {func.__name__}({signature})"
108         print(msg)
109         logger.info(msg)
110         value = func(*args, **kwargs)
111         msg = f"{func.__name__!r} returned {value!r}:{type(value)}"
112         logger.info(msg)
113         return value
114     return wrapper_debug_args
115
116
117 def debug_count_calls(func: Callable) -> Callable:
118     """Count function invocations and print a message befor every call."""
119
120     @functools.wraps(func)
121     def wrapper_debug_count_calls(*args, **kwargs):
122         wrapper_debug_count_calls.num_calls += 1
123         msg = f"Call #{wrapper_debug_count_calls.num_calls} of {func.__name__!r}"
124         print(msg)
125         logger.info(msg)
126         return func(*args, **kwargs)
127     wrapper_debug_count_calls.num_calls = 0
128     return wrapper_debug_count_calls
129
130
131 class DelayWhen(enum.Enum):
132     BEFORE_CALL = 1
133     AFTER_CALL = 2
134     BEFORE_AND_AFTER = 3
135
136
137 def delay(
138     _func: Callable = None,
139     *,
140     seconds: float = 1.0,
141     when: DelayWhen = DelayWhen.BEFORE_CALL,
142 ) -> Callable:
143     """Delay the execution of a function by sleeping before and/or after.
144
145     Slow down a function by inserting a delay before and/or after its
146     invocation.
147     """
148
149     def decorator_delay(func: Callable) -> Callable:
150         @functools.wraps(func)
151         def wrapper_delay(*args, **kwargs):
152             if when & DelayWhen.BEFORE_CALL:
153                 logger.debug(
154                     f"@delay for {seconds}s BEFORE_CALL to {func.__name__}"
155                 )
156                 time.sleep(seconds)
157             retval = func(*args, **kwargs)
158             if when & DelayWhen.AFTER_CALL:
159                 logger.debug(
160                     f"@delay for {seconds}s AFTER_CALL to {func.__name__}"
161                 )
162                 time.sleep(seconds)
163             return retval
164         return wrapper_delay
165
166     if _func is None:
167         return decorator_delay
168     else:
169         return decorator_delay(_func)
170
171
172 class _SingletonWrapper:
173     """
174     A singleton wrapper class. Its instances would be created
175     for each decorated class.
176     """
177
178     def __init__(self, cls):
179         self.__wrapped__ = cls
180         self._instance = None
181
182     def __call__(self, *args, **kwargs):
183         """Returns a single instance of decorated class"""
184         logger.debug(
185             f"@singleton returning global instance of {self.__wrapped__.__name__}"
186         )
187         if self._instance is None:
188             self._instance = self.__wrapped__(*args, **kwargs)
189         return self._instance
190
191
192 def singleton(cls):
193     """
194     A singleton decorator. Returns a wrapper objects. A call on that object
195     returns a single instance object of decorated class. Use the __wrapped__
196     attribute to access decorated class directly in unit tests
197     """
198     return _SingletonWrapper(cls)
199
200
201 def memoized(func: Callable) -> Callable:
202     """Keep a cache of previous function call results.
203
204     The cache here is a dict with a key based on the arguments to the
205     call.  Consider also: functools.lru_cache for a more advanced
206     implementation.
207     """
208
209     @functools.wraps(func)
210     def wrapper_memoized(*args, **kwargs):
211         cache_key = args + tuple(kwargs.items())
212         if cache_key not in wrapper_memoized.cache:
213             value = func(*args, **kwargs)
214             logger.debug(
215                 f"Memoizing {cache_key} => {value} for {func.__name__}"
216             )
217             wrapper_memoized.cache[cache_key] = value
218         else:
219             logger.debug(f"Returning memoized value for {func.__name__}")
220         return wrapper_memoized.cache[cache_key]
221     wrapper_memoized.cache = dict()
222     return wrapper_memoized
223
224
225 def retry_predicate(
226     tries: int,
227     *,
228     predicate: Callable[..., bool],
229     delay_sec: float = 3.0,
230     backoff: float = 2.0,
231 ):
232     """Retries a function or method up to a certain number of times
233     with a prescribed initial delay period and backoff rate.
234
235     tries is the maximum number of attempts to run the function.
236     delay_sec sets the initial delay period in seconds.
237     backoff is a multiplied (must be >1) used to modify the delay.
238     predicate is a function that will be passed the retval of the
239     decorated function and must return True to stop or False to
240     retry.
241     """
242     if backoff < 1.0:
243         msg = f"backoff must be greater than or equal to 1, got {backoff}"
244         logger.critical(msg)
245         raise ValueError(msg)
246
247     tries = math.floor(tries)
248     if tries < 0:
249         msg = f"tries must be 0 or greater, got {tries}"
250         logger.critical(msg)
251         raise ValueError(msg)
252
253     if delay_sec <= 0:
254         msg = f"delay_sec must be greater than 0, got {delay_sec}"
255         logger.critical(msg)
256         raise ValueError(msg)
257
258     def deco_retry(f):
259         @functools.wraps(f)
260         def f_retry(*args, **kwargs):
261             mtries, mdelay = tries, delay_sec  # make mutable
262             logger.debug(f'deco_retry: will make up to {mtries} attempts...')
263             retval = f(*args, **kwargs)
264             while mtries > 0:
265                 if predicate(retval) is True:
266                     logger.debug('Predicate succeeded, deco_retry is done.')
267                     return retval
268                 logger.debug("Predicate failed, sleeping and retrying.")
269                 mtries -= 1
270                 time.sleep(mdelay)
271                 mdelay *= backoff
272                 retval = f(*args, **kwargs)
273             return retval
274         return f_retry
275     return deco_retry
276
277
278 def retry_if_false(tries: int, *, delay_sec=3.0, backoff=2.0):
279     return retry_predicate(
280         tries,
281         predicate=lambda x: x is True,
282         delay_sec=delay_sec,
283         backoff=backoff,
284     )
285
286
287 def retry_if_none(tries: int, *, delay_sec=3.0, backoff=2.0):
288     return retry_predicate(
289         tries,
290         predicate=lambda x: x is not None,
291         delay_sec=delay_sec,
292         backoff=backoff,
293     )
294
295
296 def deprecated(func):
297     """This is a decorator which can be used to mark functions
298     as deprecated. It will result in a warning being emitted
299     when the function is used.
300     """
301
302     @functools.wraps(func)
303     def wrapper_deprecated(*args, **kwargs):
304         msg = f"Call to deprecated function {func.__name__}"
305         logger.warning(msg)
306         warnings.warn(msg, category=DeprecationWarning)
307         return func(*args, **kwargs)
308
309     return wrapper_deprecated
310
311
312 def thunkify(func):
313     """
314     Make a function immediately return a function of no args which,
315     when called, waits for the result, which will start being
316     processed in another thread.
317     """
318
319     @functools.wraps(func)
320     def lazy_thunked(*args, **kwargs):
321         wait_event = threading.Event()
322
323         result = [None]
324         exc = [False, None]
325
326         def worker_func():
327             try:
328                 func_result = func(*args, **kwargs)
329                 result[0] = func_result
330             except Exception:
331                 exc[0] = True
332                 exc[1] = sys.exc_info()  # (type, value, traceback)
333                 msg = f"Thunkify has thrown an exception (will be raised on thunk()):\n{traceback.format_exc()}"
334                 print(msg)
335                 logger.warning(msg)
336             finally:
337                 wait_event.set()
338
339         def thunk():
340             wait_event.wait()
341             if exc[0]:
342                 raise exc[1][0](exc[1][1])
343             return result[0]
344
345         threading.Thread(target=worker_func).start()
346         return thunk
347
348     return lazy_thunked
349
350
351 ############################################################
352 # Timeout
353 ############################################################
354
355 # http://www.saltycrane.com/blog/2010/04/using-python-timeout-decorator-uploading-s3/
356 # Used work of Stephen "Zero" Chappell <[email protected]>
357 # in https://code.google.com/p/verse-quiz/source/browse/trunk/timeout.py
358
359
360 def _raise_exception(exception, error_message: Optional[str]):
361     if error_message is None:
362         raise exception()
363     else:
364         raise exception(error_message)
365
366
367 def _target(queue, function, *args, **kwargs):
368     """Run a function with arguments and return output via a queue.
369
370     This is a helper function for the Process created in _Timeout. It runs
371     the function with positional arguments and keyword arguments and then
372     returns the function's output by way of a queue. If an exception gets
373     raised, it is returned to _Timeout to be raised by the value property.
374     """
375     try:
376         queue.put((True, function(*args, **kwargs)))
377     except Exception:
378         queue.put((False, sys.exc_info()[1]))
379
380
381 class _Timeout(object):
382     """Wrap a function and add a timeout (limit) attribute to it.
383
384     Instances of this class are automatically generated by the add_timeout
385     function defined below.
386     """
387
388     def __init__(
389         self,
390         function: Callable,
391         timeout_exception: Exception,
392         error_message: str,
393         seconds: float,
394     ):
395         self.__limit = seconds
396         self.__function = function
397         self.__timeout_exception = timeout_exception
398         self.__error_message = error_message
399         self.__name__ = function.__name__
400         self.__doc__ = function.__doc__
401         self.__timeout = time.time()
402         self.__process = multiprocessing.Process()
403         self.__queue: multiprocessing.queues.Queue = multiprocessing.Queue()
404
405     def __call__(self, *args, **kwargs):
406         """Execute the embedded function object asynchronously.
407
408         The function given to the constructor is transparently called and
409         requires that "ready" be intermittently polled. If and when it is
410         True, the "value" property may then be checked for returned data.
411         """
412         self.__limit = kwargs.pop("timeout", self.__limit)
413         self.__queue = multiprocessing.Queue(1)
414         args = (self.__queue, self.__function) + args
415         self.__process = multiprocessing.Process(
416             target=_target, args=args, kwargs=kwargs
417         )
418         self.__process.daemon = True
419         self.__process.start()
420         if self.__limit is not None:
421             self.__timeout = self.__limit + time.time()
422         while not self.ready:
423             time.sleep(0.1)
424         return self.value
425
426     def cancel(self):
427         """Terminate any possible execution of the embedded function."""
428         if self.__process.is_alive():
429             self.__process.terminate()
430         _raise_exception(self.__timeout_exception, self.__error_message)
431
432     @property
433     def ready(self):
434         """Read-only property indicating status of "value" property."""
435         if self.__limit and self.__timeout < time.time():
436             self.cancel()
437         return self.__queue.full() and not self.__queue.empty()
438
439     @property
440     def value(self):
441         """Read-only property containing data returned from function."""
442         if self.ready is True:
443             flag, load = self.__queue.get()
444             if flag:
445                 return load
446             raise load
447
448
449 def timeout(
450     seconds: float = 1.0,
451     use_signals: Optional[bool] = None,
452     timeout_exception=exceptions.TimeoutError,
453     error_message="Function call timed out",
454 ):
455     """Add a timeout parameter to a function and return the function.
456
457     Note: the use_signals parameter is included in order to support
458     multiprocessing scenarios (signal can only be used from the process'
459     main thread).  When not using signals, timeout granularity will be
460     rounded to the nearest 0.1s.
461
462     Raises an exception when the timeout is reached.
463
464     It is illegal to pass anything other than a function as the first
465     parameter.  The function is wrapped and returned to the caller.
466     """
467     if use_signals is None:
468         import thread_utils
469         use_signals = thread_utils.is_current_thread_main_thread()
470
471     def decorate(function):
472         if use_signals:
473
474             def handler(signum, frame):
475                 _raise_exception(timeout_exception, error_message)
476
477             @functools.wraps(function)
478             def new_function(*args, **kwargs):
479                 new_seconds = kwargs.pop("timeout", seconds)
480                 if new_seconds:
481                     old = signal.signal(signal.SIGALRM, handler)
482                     signal.setitimer(signal.ITIMER_REAL, new_seconds)
483
484                 if not seconds:
485                     return function(*args, **kwargs)
486
487                 try:
488                     return function(*args, **kwargs)
489                 finally:
490                     if new_seconds:
491                         signal.setitimer(signal.ITIMER_REAL, 0)
492                         signal.signal(signal.SIGALRM, old)
493
494             return new_function
495         else:
496
497             @functools.wraps(function)
498             def new_function(*args, **kwargs):
499                 timeout_wrapper = _Timeout(
500                     function, timeout_exception, error_message, seconds
501                 )
502                 return timeout_wrapper(*args, **kwargs)
503
504             return new_function
505
506     return decorate
507
508
509 class non_reentrant_code(object):
510     def __init__(self):
511         self._lock = threading.RLock
512         self._entered = False
513
514     def __call__(self, f):
515         def _gatekeeper(*args, **kwargs):
516             with self._lock:
517                 if self._entered:
518                     return
519                 self._entered = True
520                 f(*args, **kwargs)
521                 self._entered = False
522
523         return _gatekeeper
524
525
526 class rlocked(object):
527     def __init__(self):
528         self._lock = threading.RLock
529         self._entered = False
530
531     def __call__(self, f):
532         def _gatekeeper(*args, **kwargs):
533             with self._lock:
534                 if self._entered:
535                     return
536                 self._entered = True
537                 f(*args, **kwargs)
538                 self._entered = False
539         return _gatekeeper
540
541
542 def call_with_sample_rate(sample_rate: float) -> Callable:
543     if not 0.0 <= sample_rate <= 1.0:
544         msg = f"sample_rate must be between [0, 1]. Got {sample_rate}."
545         logger.critical(msg)
546         raise ValueError(msg)
547
548     def decorator(f):
549         @functools.wraps(f)
550         def _call_with_sample_rate(*args, **kwargs):
551             if random.uniform(0, 1) < sample_rate:
552                 return f(*args, **kwargs)
553             else:
554                 logger.debug(
555                     f"@call_with_sample_rate skipping a call to {f.__name__}"
556                 )
557         return _call_with_sample_rate
558     return decorator
559
560
561 def decorate_matching_methods_with(decorator, acl=None):
562     """Apply decorator to all methods in a class whose names begin with
563     prefix.  If prefix is None (default), decorate all methods in the
564     class.
565     """
566     def decorate_the_class(cls):
567         for name, m in inspect.getmembers(cls, inspect.isfunction):
568             if acl is None:
569                 setattr(cls, name, decorator(m))
570             else:
571                 if acl(name):
572                     setattr(cls, name, decorator(m))
573         return cls
574     return decorate_the_class