Since this thing is on the innerwebs I suppose it should have a
[python_utils.git] / decorator_utils.py
1 #!/usr/bin/env python3
2
3 # © Copyright 2021-2022, Scott Gasch
4 # Portions (marked) below retain the original author's copyright.
5
6 """Decorators."""
7
8 import enum
9 import functools
10 import inspect
11 import logging
12 import math
13 import multiprocessing
14 import random
15 import signal
16 import sys
17 import threading
18 import time
19 import traceback
20 import warnings
21 from typing import Any, Callable, Optional
22
23 # This module is commonly used by others in here and should avoid
24 # taking any unnecessary dependencies back on them.
25 import exceptions
26
27 logger = logging.getLogger(__name__)
28
29
30 def timed(func: Callable) -> Callable:
31     """Print the runtime of the decorated function.
32
33     >>> @timed
34     ... def foo():
35     ...     import time
36     ...     time.sleep(0.1)
37
38     >>> foo()  # doctest: +ELLIPSIS
39     Finished foo in ...
40
41     """
42
43     @functools.wraps(func)
44     def wrapper_timer(*args, **kwargs):
45         start_time = time.perf_counter()
46         value = func(*args, **kwargs)
47         end_time = time.perf_counter()
48         run_time = end_time - start_time
49         msg = f"Finished {func.__qualname__} in {run_time:.4f}s"
50         print(msg)
51         logger.info(msg)
52         return value
53
54     return wrapper_timer
55
56
57 def invocation_logged(func: Callable) -> Callable:
58     """Log the call of a function.
59
60     >>> @invocation_logged
61     ... def foo():
62     ...     print('Hello, world.')
63
64     >>> foo()
65     Entered foo
66     Hello, world.
67     Exited foo
68
69     """
70
71     @functools.wraps(func)
72     def wrapper_invocation_logged(*args, **kwargs):
73         msg = f"Entered {func.__qualname__}"
74         print(msg)
75         logger.info(msg)
76         ret = func(*args, **kwargs)
77         msg = f"Exited {func.__qualname__}"
78         print(msg)
79         logger.info(msg)
80         return ret
81
82     return wrapper_invocation_logged
83
84
85 def rate_limited(n_calls: int, *, per_period_in_seconds: float = 1.0) -> Callable:
86     """Limit invocation of a wrapped function to n calls per period.
87     Thread safe.  In testing this was relatively fair with multiple
88     threads using it though that hasn't been measured.
89
90     >>> import time
91     >>> import decorator_utils
92     >>> import thread_utils
93
94     >>> calls = 0
95
96     >>> @decorator_utils.rate_limited(10, per_period_in_seconds=1.0)
97     ... def limited(x: int):
98     ...     global calls
99     ...     calls += 1
100
101     >>> @thread_utils.background_thread
102     ... def a(stop):
103     ...     for _ in range(3):
104     ...         limited(_)
105
106     >>> @thread_utils.background_thread
107     ... def b(stop):
108     ...     for _ in range(3):
109     ...         limited(_)
110
111     >>> start = time.time()
112     >>> (t1, e1) = a()
113     >>> (t2, e2) = b()
114     >>> t1.join()
115     >>> t2.join()
116     >>> end = time.time()
117     >>> dur = end - start
118     >>> dur > 0.5
119     True
120
121     >>> calls
122     6
123
124     """
125     min_interval_seconds = per_period_in_seconds / float(n_calls)
126
127     def wrapper_rate_limited(func: Callable) -> Callable:
128         cv = threading.Condition()
129         last_invocation_timestamp = [0.0]
130
131         def may_proceed() -> float:
132             now = time.time()
133             last_invocation = last_invocation_timestamp[0]
134             if last_invocation != 0.0:
135                 elapsed_since_last = now - last_invocation
136                 wait_time = min_interval_seconds - elapsed_since_last
137             else:
138                 wait_time = 0.0
139             logger.debug('@%.4f> wait_time = %.4f', time.time(), wait_time)
140             return wait_time
141
142         def wrapper_wrapper_rate_limited(*args, **kargs) -> Any:
143             with cv:
144                 while True:
145                     if cv.wait_for(
146                         lambda: may_proceed() <= 0.0,
147                         timeout=may_proceed(),
148                     ):
149                         break
150             with cv:
151                 logger.debug('@%.4f> calling it...', time.time())
152                 ret = func(*args, **kargs)
153                 last_invocation_timestamp[0] = time.time()
154                 logger.debug(
155                     '@%.4f> Last invocation <- %.4f', time.time(), last_invocation_timestamp[0]
156                 )
157                 cv.notify()
158             return ret
159
160         return wrapper_wrapper_rate_limited
161
162     return wrapper_rate_limited
163
164
165 def debug_args(func: Callable) -> Callable:
166     """Print the function signature and return value at each call.
167
168     >>> @debug_args
169     ... def foo(a, b, c):
170     ...     print(a)
171     ...     print(b)
172     ...     print(c)
173     ...     return (a + b, c)
174
175     >>> foo(1, 2.0, "test")
176     Calling foo(1:<class 'int'>, 2.0:<class 'float'>, 'test':<class 'str'>)
177     1
178     2.0
179     test
180     foo returned (3.0, 'test'):<class 'tuple'>
181     (3.0, 'test')
182     """
183
184     @functools.wraps(func)
185     def wrapper_debug_args(*args, **kwargs):
186         args_repr = [f"{repr(a)}:{type(a)}" for a in args]
187         kwargs_repr = [f"{k}={v!r}:{type(v)}" for k, v in kwargs.items()]
188         signature = ", ".join(args_repr + kwargs_repr)
189         msg = f"Calling {func.__qualname__}({signature})"
190         print(msg)
191         logger.info(msg)
192         value = func(*args, **kwargs)
193         msg = f"{func.__qualname__} returned {value!r}:{type(value)}"
194         print(msg)
195         logger.info(msg)
196         return value
197
198     return wrapper_debug_args
199
200
201 def debug_count_calls(func: Callable) -> Callable:
202     """Count function invocations and print a message befor every call.
203
204     >>> @debug_count_calls
205     ... def factoral(x):
206     ...     if x == 1:
207     ...         return 1
208     ...     return x * factoral(x - 1)
209
210     >>> factoral(5)
211     Call #1 of 'factoral'
212     Call #2 of 'factoral'
213     Call #3 of 'factoral'
214     Call #4 of 'factoral'
215     Call #5 of 'factoral'
216     120
217
218     """
219
220     @functools.wraps(func)
221     def wrapper_debug_count_calls(*args, **kwargs):
222         wrapper_debug_count_calls.num_calls += 1
223         msg = f"Call #{wrapper_debug_count_calls.num_calls} of {func.__name__!r}"
224         print(msg)
225         logger.info(msg)
226         return func(*args, **kwargs)
227
228     wrapper_debug_count_calls.num_calls = 0  # type: ignore
229     return wrapper_debug_count_calls
230
231
232 class DelayWhen(enum.IntEnum):
233     """When should we delay: before or after calling the function (or
234     both)?
235
236     """
237
238     BEFORE_CALL = 1
239     AFTER_CALL = 2
240     BEFORE_AND_AFTER = 3
241
242
243 def delay(
244     _func: Callable = None,
245     *,
246     seconds: float = 1.0,
247     when: DelayWhen = DelayWhen.BEFORE_CALL,
248 ) -> Callable:
249     """Delay the execution of a function by sleeping before and/or after.
250
251     Slow down a function by inserting a delay before and/or after its
252     invocation.
253
254     >>> import time
255
256     >>> @delay(seconds=1.0)
257     ... def foo():
258     ...     pass
259
260     >>> start = time.time()
261     >>> foo()
262     >>> dur = time.time() - start
263     >>> dur >= 1.0
264     True
265
266     """
267
268     def decorator_delay(func: Callable) -> Callable:
269         @functools.wraps(func)
270         def wrapper_delay(*args, **kwargs):
271             if when & DelayWhen.BEFORE_CALL:
272                 logger.debug("@delay for %fs BEFORE_CALL to %s", seconds, func.__name__)
273                 time.sleep(seconds)
274             retval = func(*args, **kwargs)
275             if when & DelayWhen.AFTER_CALL:
276                 logger.debug("@delay for %fs AFTER_CALL to %s", seconds, func.__name__)
277                 time.sleep(seconds)
278             return retval
279
280         return wrapper_delay
281
282     if _func is None:
283         return decorator_delay
284     else:
285         return decorator_delay(_func)
286
287
288 class _SingletonWrapper:
289     """
290     A singleton wrapper class. Its instances would be created
291     for each decorated class.
292
293     """
294
295     def __init__(self, cls):
296         self.__wrapped__ = cls
297         self._instance = None
298
299     def __call__(self, *args, **kwargs):
300         """Returns a single instance of decorated class"""
301         logger.debug('@singleton returning global instance of %s', self.__wrapped__.__name__)
302         if self._instance is None:
303             self._instance = self.__wrapped__(*args, **kwargs)
304         return self._instance
305
306
307 def singleton(cls):
308     """
309     A singleton decorator. Returns a wrapper objects. A call on that object
310     returns a single instance object of decorated class. Use the __wrapped__
311     attribute to access decorated class directly in unit tests
312
313     >>> @singleton
314     ... class foo(object):
315     ...     pass
316
317     >>> a = foo()
318     >>> b = foo()
319     >>> a is b
320     True
321
322     >>> id(a) == id(b)
323     True
324
325     """
326     return _SingletonWrapper(cls)
327
328
329 def memoized(func: Callable) -> Callable:
330     """Keep a cache of previous function call results.
331
332     The cache here is a dict with a key based on the arguments to the
333     call.  Consider also: functools.lru_cache for a more advanced
334     implementation.
335
336     >>> import time
337
338     >>> @memoized
339     ... def expensive(arg) -> int:
340     ...     # Simulate something slow to compute or lookup
341     ...     time.sleep(1.0)
342     ...     return arg * arg
343
344     >>> start = time.time()
345     >>> expensive(5)           # Takes about 1 sec
346     25
347
348     >>> expensive(3)           # Also takes about 1 sec
349     9
350
351     >>> expensive(5)           # Pulls from cache, fast
352     25
353
354     >>> expensive(3)           # Pulls from cache again, fast
355     9
356
357     >>> dur = time.time() - start
358     >>> dur < 3.0
359     True
360
361     """
362
363     @functools.wraps(func)
364     def wrapper_memoized(*args, **kwargs):
365         cache_key = args + tuple(kwargs.items())
366         if cache_key not in wrapper_memoized.cache:
367             value = func(*args, **kwargs)
368             logger.debug('Memoizing %s => %s for %s', cache_key, value, func.__name__)
369             wrapper_memoized.cache[cache_key] = value
370         else:
371             logger.debug('Returning memoized value for %s', {func.__name__})
372         return wrapper_memoized.cache[cache_key]
373
374     wrapper_memoized.cache = {}  # type: ignore
375     return wrapper_memoized
376
377
378 def retry_predicate(
379     tries: int,
380     *,
381     predicate: Callable[..., bool],
382     delay_sec: float = 3.0,
383     backoff: float = 2.0,
384 ):
385     """Retries a function or method up to a certain number of times
386     with a prescribed initial delay period and backoff rate.
387
388     tries is the maximum number of attempts to run the function.
389     delay_sec sets the initial delay period in seconds.
390     backoff is a multiplied (must be >1) used to modify the delay.
391     predicate is a function that will be passed the retval of the
392     decorated function and must return True to stop or False to
393     retry.
394
395     """
396     if backoff < 1.0:
397         msg = f"backoff must be greater than or equal to 1, got {backoff}"
398         logger.critical(msg)
399         raise ValueError(msg)
400
401     tries = math.floor(tries)
402     if tries < 0:
403         msg = f"tries must be 0 or greater, got {tries}"
404         logger.critical(msg)
405         raise ValueError(msg)
406
407     if delay_sec <= 0:
408         msg = f"delay_sec must be greater than 0, got {delay_sec}"
409         logger.critical(msg)
410         raise ValueError(msg)
411
412     def deco_retry(f):
413         @functools.wraps(f)
414         def f_retry(*args, **kwargs):
415             mtries, mdelay = tries, delay_sec  # make mutable
416             logger.debug('deco_retry: will make up to %d attempts...', mtries)
417             retval = f(*args, **kwargs)
418             while mtries > 0:
419                 if predicate(retval) is True:
420                     logger.debug('Predicate succeeded, deco_retry is done.')
421                     return retval
422                 logger.debug("Predicate failed, sleeping and retrying.")
423                 mtries -= 1
424                 time.sleep(mdelay)
425                 mdelay *= backoff
426                 retval = f(*args, **kwargs)
427             return retval
428
429         return f_retry
430
431     return deco_retry
432
433
434 def retry_if_false(tries: int, *, delay_sec=3.0, backoff=2.0):
435     """A helper for @retry_predicate that retries a decorated
436     function as long as it keeps returning False.
437
438     >>> import time
439
440     >>> counter = 0
441
442     >>> @retry_if_false(5, delay_sec=1.0, backoff=1.1)
443     ... def foo():
444     ...     global counter
445     ...     counter += 1
446     ...     return counter >= 3
447
448     >>> start = time.time()
449     >>> foo()  # fail, delay 1.0, fail, delay 1.1, succeed
450     True
451
452     >>> dur = time.time() - start
453     >>> counter
454     3
455     >>> dur > 2.0
456     True
457     >>> dur < 2.3
458     True
459
460     """
461     return retry_predicate(
462         tries,
463         predicate=lambda x: x is True,
464         delay_sec=delay_sec,
465         backoff=backoff,
466     )
467
468
469 def retry_if_none(tries: int, *, delay_sec=3.0, backoff=2.0):
470     """Another helper for @retry_predicate above.  Retries up to N
471     times so long as the wrapped function returns None with a delay
472     between each retry and a backoff that can increase the delay.
473
474     """
475     return retry_predicate(
476         tries,
477         predicate=lambda x: x is not None,
478         delay_sec=delay_sec,
479         backoff=backoff,
480     )
481
482
483 def deprecated(func):
484     """This is a decorator which can be used to mark functions
485     as deprecated. It will result in a warning being emitted
486     when the function is used.
487
488     """
489
490     @functools.wraps(func)
491     def wrapper_deprecated(*args, **kwargs):
492         msg = f"Call to deprecated function {func.__qualname__}"
493         logger.warning(msg)
494         warnings.warn(msg, category=DeprecationWarning, stacklevel=2)
495         print(msg, file=sys.stderr)
496         return func(*args, **kwargs)
497
498     return wrapper_deprecated
499
500
501 def thunkify(func):
502     """
503     Make a function immediately return a function of no args which,
504     when called, waits for the result, which will start being
505     processed in another thread.
506     """
507
508     @functools.wraps(func)
509     def lazy_thunked(*args, **kwargs):
510         wait_event = threading.Event()
511
512         result = [None]
513         exc = [False, None]
514
515         def worker_func():
516             try:
517                 func_result = func(*args, **kwargs)
518                 result[0] = func_result
519             except Exception:
520                 exc[0] = True
521                 exc[1] = sys.exc_info()  # (type, value, traceback)
522                 msg = f"Thunkify has thrown an exception (will be raised on thunk()):\n{traceback.format_exc()}"
523                 logger.warning(msg)
524             finally:
525                 wait_event.set()
526
527         def thunk():
528             wait_event.wait()
529             if exc[0]:
530                 raise exc[1][0](exc[1][1])
531             return result[0]
532
533         threading.Thread(target=worker_func).start()
534         return thunk
535
536     return lazy_thunked
537
538
539 ############################################################
540 # Timeout
541 ############################################################
542
543 # http://www.saltycrane.com/blog/2010/04/using-python-timeout-decorator-uploading-s3/
544 # Used work of Stephen "Zero" Chappell <[email protected]>
545 # in https://code.google.com/p/verse-quiz/source/browse/trunk/timeout.py
546
547
548 def _raise_exception(exception, error_message: Optional[str]):
549     if error_message is None:
550         raise Exception(exception)
551     else:
552         raise Exception(error_message)
553
554
555 def _target(queue, function, *args, **kwargs):
556     """Run a function with arguments and return output via a queue.
557
558     This is a helper function for the Process created in _Timeout. It runs
559     the function with positional arguments and keyword arguments and then
560     returns the function's output by way of a queue. If an exception gets
561     raised, it is returned to _Timeout to be raised by the value property.
562     """
563     try:
564         queue.put((True, function(*args, **kwargs)))
565     except Exception:
566         queue.put((False, sys.exc_info()[1]))
567
568
569 class _Timeout(object):
570     """Wrap a function and add a timeout to it.
571
572     Instances of this class are automatically generated by the add_timeout
573     function defined below.  Do not use directly.
574     """
575
576     def __init__(
577         self,
578         function: Callable,
579         timeout_exception: Exception,
580         error_message: str,
581         seconds: float,
582     ):
583         self.__limit = seconds
584         self.__function = function
585         self.__timeout_exception = timeout_exception
586         self.__error_message = error_message
587         self.__name__ = function.__name__
588         self.__doc__ = function.__doc__
589         self.__timeout = time.time()
590         self.__process = multiprocessing.Process()
591         self.__queue: multiprocessing.queues.Queue = multiprocessing.Queue()
592
593     def __call__(self, *args, **kwargs):
594         """Execute the embedded function object asynchronously.
595
596         The function given to the constructor is transparently called and
597         requires that "ready" be intermittently polled. If and when it is
598         True, the "value" property may then be checked for returned data.
599         """
600         self.__limit = kwargs.pop("timeout", self.__limit)
601         self.__queue = multiprocessing.Queue(1)
602         args = (self.__queue, self.__function) + args
603         self.__process = multiprocessing.Process(target=_target, args=args, kwargs=kwargs)
604         self.__process.daemon = True
605         self.__process.start()
606         if self.__limit is not None:
607             self.__timeout = self.__limit + time.time()
608         while not self.ready:
609             time.sleep(0.1)
610         return self.value
611
612     def cancel(self):
613         """Terminate any possible execution of the embedded function."""
614         if self.__process.is_alive():
615             self.__process.terminate()
616         _raise_exception(self.__timeout_exception, self.__error_message)
617
618     @property
619     def ready(self):
620         """Read-only property indicating status of "value" property."""
621         if self.__limit and self.__timeout < time.time():
622             self.cancel()
623         return self.__queue.full() and not self.__queue.empty()
624
625     @property
626     def value(self):
627         """Read-only property containing data returned from function."""
628         if self.ready is True:
629             flag, load = self.__queue.get()
630             if flag:
631                 return load
632             raise load
633         return None
634
635
636 def timeout(
637     seconds: float = 1.0,
638     use_signals: Optional[bool] = None,
639     timeout_exception=exceptions.TimeoutError,
640     error_message="Function call timed out",
641 ):
642     """Add a timeout parameter to a function and return the function.
643
644     Note: the use_signals parameter is included in order to support
645     multiprocessing scenarios (signal can only be used from the process'
646     main thread).  When not using signals, timeout granularity will be
647     rounded to the nearest 0.1s.
648
649     Raises an exception when/if the timeout is reached.
650
651     It is illegal to pass anything other than a function as the first
652     parameter.  The function is wrapped and returned to the caller.
653
654     >>> @timeout(0.2)
655     ... def foo(delay: float):
656     ...     time.sleep(delay)
657     ...     return "ok"
658
659     >>> foo(0)
660     'ok'
661
662     >>> foo(1.0)
663     Traceback (most recent call last):
664     ...
665     Exception: Function call timed out
666
667     """
668     if use_signals is None:
669         import thread_utils
670
671         use_signals = thread_utils.is_current_thread_main_thread()
672
673     def decorate(function):
674         if use_signals:
675
676             def handler(signum, frame):
677                 _raise_exception(timeout_exception, error_message)
678
679             @functools.wraps(function)
680             def new_function(*args, **kwargs):
681                 new_seconds = kwargs.pop("timeout", seconds)
682                 if new_seconds:
683                     old = signal.signal(signal.SIGALRM, handler)
684                     signal.setitimer(signal.ITIMER_REAL, new_seconds)
685
686                 if not seconds:
687                     return function(*args, **kwargs)
688
689                 try:
690                     return function(*args, **kwargs)
691                 finally:
692                     if new_seconds:
693                         signal.setitimer(signal.ITIMER_REAL, 0)
694                         signal.signal(signal.SIGALRM, old)
695
696             return new_function
697         else:
698
699             @functools.wraps(function)
700             def new_function(*args, **kwargs):
701                 timeout_wrapper = _Timeout(function, timeout_exception, error_message, seconds)
702                 return timeout_wrapper(*args, **kwargs)
703
704             return new_function
705
706     return decorate
707
708
709 def synchronized(lock):
710     def wrap(f):
711         @functools.wraps(f)
712         def _gatekeeper(*args, **kw):
713             lock.acquire()
714             try:
715                 return f(*args, **kw)
716             finally:
717                 lock.release()
718
719         return _gatekeeper
720
721     return wrap
722
723
724 def call_with_sample_rate(sample_rate: float) -> Callable:
725     if not 0.0 <= sample_rate <= 1.0:
726         msg = f"sample_rate must be between [0, 1]. Got {sample_rate}."
727         logger.critical(msg)
728         raise ValueError(msg)
729
730     def decorator(f):
731         @functools.wraps(f)
732         def _call_with_sample_rate(*args, **kwargs):
733             if random.uniform(0, 1) < sample_rate:
734                 return f(*args, **kwargs)
735             else:
736                 logger.debug("@call_with_sample_rate skipping a call to %s", f.__name__)
737                 return None
738
739         return _call_with_sample_rate
740
741     return decorator
742
743
744 def decorate_matching_methods_with(decorator, acl=None):
745     """Apply decorator to all methods in a class whose names begin with
746     prefix.  If prefix is None (default), decorate all methods in the
747     class.
748     """
749
750     def decorate_the_class(cls):
751         for name, m in inspect.getmembers(cls, inspect.isfunction):
752             if acl is None:
753                 setattr(cls, name, decorator(m))
754             else:
755                 if acl(name):
756                     setattr(cls, name, decorator(m))
757         return cls
758
759     return decorate_the_class
760
761
762 if __name__ == '__main__':
763     import doctest
764
765     doctest.testmod()