Easier and more self documenting patterns for loading/saving Persistent
[python_utils.git] / decorator_utils.py
1 #!/usr/bin/env python3
2
3 # © Copyright 2021-2022, Scott Gasch
4 # Portions (marked) below retain the original author's copyright.
5
6 """Useful(?) decorators."""
7
8 import enum
9 import functools
10 import inspect
11 import logging
12 import math
13 import multiprocessing
14 import random
15 import signal
16 import sys
17 import threading
18 import time
19 import traceback
20 import warnings
21 from typing import Any, Callable, List, Optional
22
23 # This module is commonly used by others in here and should avoid
24 # taking any unnecessary dependencies back on them.
25 import exceptions
26
27 logger = logging.getLogger(__name__)
28
29
30 def timed(func: Callable) -> Callable:
31     """Print the runtime of the decorated function.
32
33     >>> @timed
34     ... def foo():
35     ...     import time
36     ...     time.sleep(0.01)
37
38     >>> foo()  # doctest: +ELLIPSIS
39     Finished foo in ...
40
41     """
42
43     @functools.wraps(func)
44     def wrapper_timer(*args, **kwargs):
45         start_time = time.perf_counter()
46         value = func(*args, **kwargs)
47         end_time = time.perf_counter()
48         run_time = end_time - start_time
49         msg = f"Finished {func.__qualname__} in {run_time:.4f}s"
50         print(msg)
51         logger.info(msg)
52         return value
53
54     return wrapper_timer
55
56
57 def invocation_logged(func: Callable) -> Callable:
58     """Log the call of a function on stdout and the info log.
59
60     >>> @invocation_logged
61     ... def foo():
62     ...     print('Hello, world.')
63
64     >>> foo()
65     Entered foo
66     Hello, world.
67     Exited foo
68
69     """
70
71     @functools.wraps(func)
72     def wrapper_invocation_logged(*args, **kwargs):
73         msg = f"Entered {func.__qualname__}"
74         print(msg)
75         logger.info(msg)
76         ret = func(*args, **kwargs)
77         msg = f"Exited {func.__qualname__}"
78         print(msg)
79         logger.info(msg)
80         return ret
81
82     return wrapper_invocation_logged
83
84
85 def rate_limited(n_calls: int, *, per_period_in_seconds: float = 1.0) -> Callable:
86     """Limit invocation of a wrapped function to n calls per time period.
87     Thread safe.  In testing this was relatively fair with multiple
88     threads using it though that hasn't been measured in detail.
89
90     >>> import time
91     >>> import decorator_utils
92     >>> import thread_utils
93
94     >>> calls = 0
95
96     >>> @decorator_utils.rate_limited(10, per_period_in_seconds=1.0)
97     ... def limited(x: int):
98     ...     global calls
99     ...     calls += 1
100
101     >>> @thread_utils.background_thread
102     ... def a(stop):
103     ...     for _ in range(3):
104     ...         limited(_)
105
106     >>> @thread_utils.background_thread
107     ... def b(stop):
108     ...     for _ in range(3):
109     ...         limited(_)
110
111     >>> start = time.time()
112     >>> (t1, e1) = a()
113     >>> (t2, e2) = b()
114     >>> t1.join()
115     >>> t2.join()
116     >>> end = time.time()
117     >>> dur = end - start
118     >>> dur > 0.5
119     True
120
121     >>> calls
122     6
123
124     """
125     min_interval_seconds = per_period_in_seconds / float(n_calls)
126
127     def wrapper_rate_limited(func: Callable) -> Callable:
128         cv = threading.Condition()
129         last_invocation_timestamp = [0.0]
130
131         def may_proceed() -> float:
132             now = time.time()
133             last_invocation = last_invocation_timestamp[0]
134             if last_invocation != 0.0:
135                 elapsed_since_last = now - last_invocation
136                 wait_time = min_interval_seconds - elapsed_since_last
137             else:
138                 wait_time = 0.0
139             logger.debug('@%.4f> wait_time = %.4f', time.time(), wait_time)
140             return wait_time
141
142         def wrapper_wrapper_rate_limited(*args, **kargs) -> Any:
143             with cv:
144                 while True:
145                     if cv.wait_for(
146                         lambda: may_proceed() <= 0.0,
147                         timeout=may_proceed(),
148                     ):
149                         break
150             with cv:
151                 logger.debug('@%.4f> calling it...', time.time())
152                 ret = func(*args, **kargs)
153                 last_invocation_timestamp[0] = time.time()
154                 logger.debug(
155                     '@%.4f> Last invocation <- %.4f', time.time(), last_invocation_timestamp[0]
156                 )
157                 cv.notify()
158             return ret
159
160         return wrapper_wrapper_rate_limited
161
162     return wrapper_rate_limited
163
164
165 def debug_args(func: Callable) -> Callable:
166     """Print the function signature and return value at each call.
167
168     >>> @debug_args
169     ... def foo(a, b, c):
170     ...     print(a)
171     ...     print(b)
172     ...     print(c)
173     ...     return (a + b, c)
174
175     >>> foo(1, 2.0, "test")
176     Calling foo(1:<class 'int'>, 2.0:<class 'float'>, 'test':<class 'str'>)
177     1
178     2.0
179     test
180     foo returned (3.0, 'test'):<class 'tuple'>
181     (3.0, 'test')
182     """
183
184     @functools.wraps(func)
185     def wrapper_debug_args(*args, **kwargs):
186         args_repr = [f"{repr(a)}:{type(a)}" for a in args]
187         kwargs_repr = [f"{k}={v!r}:{type(v)}" for k, v in kwargs.items()]
188         signature = ", ".join(args_repr + kwargs_repr)
189         msg = f"Calling {func.__qualname__}({signature})"
190         print(msg)
191         logger.info(msg)
192         value = func(*args, **kwargs)
193         msg = f"{func.__qualname__} returned {value!r}:{type(value)}"
194         print(msg)
195         logger.info(msg)
196         return value
197
198     return wrapper_debug_args
199
200
201 def debug_count_calls(func: Callable) -> Callable:
202     """Count function invocations and print a message befor every call.
203
204     >>> @debug_count_calls
205     ... def factoral(x):
206     ...     if x == 1:
207     ...         return 1
208     ...     return x * factoral(x - 1)
209
210     >>> factoral(5)
211     Call #1 of 'factoral'
212     Call #2 of 'factoral'
213     Call #3 of 'factoral'
214     Call #4 of 'factoral'
215     Call #5 of 'factoral'
216     120
217
218     """
219
220     @functools.wraps(func)
221     def wrapper_debug_count_calls(*args, **kwargs):
222         wrapper_debug_count_calls.num_calls += 1
223         msg = f"Call #{wrapper_debug_count_calls.num_calls} of {func.__name__!r}"
224         print(msg)
225         logger.info(msg)
226         return func(*args, **kwargs)
227
228     wrapper_debug_count_calls.num_calls = 0  # type: ignore
229     return wrapper_debug_count_calls
230
231
232 class DelayWhen(enum.IntEnum):
233     """When should we delay: before or after calling the function (or
234     both)?
235
236     """
237
238     BEFORE_CALL = 1
239     AFTER_CALL = 2
240     BEFORE_AND_AFTER = 3
241
242
243 def delay(
244     _func: Callable = None,
245     *,
246     seconds: float = 1.0,
247     when: DelayWhen = DelayWhen.BEFORE_CALL,
248 ) -> Callable:
249     """Slow down a function by inserting a delay before and/or after its
250     invocation.
251
252     >>> import time
253
254     >>> @delay(seconds=1.0)
255     ... def foo():
256     ...     pass
257
258     >>> start = time.time()
259     >>> foo()
260     >>> dur = time.time() - start
261     >>> dur >= 1.0
262     True
263
264     """
265
266     def decorator_delay(func: Callable) -> Callable:
267         @functools.wraps(func)
268         def wrapper_delay(*args, **kwargs):
269             if when & DelayWhen.BEFORE_CALL:
270                 logger.debug("@delay for %fs BEFORE_CALL to %s", seconds, func.__name__)
271                 time.sleep(seconds)
272             retval = func(*args, **kwargs)
273             if when & DelayWhen.AFTER_CALL:
274                 logger.debug("@delay for %fs AFTER_CALL to %s", seconds, func.__name__)
275                 time.sleep(seconds)
276             return retval
277
278         return wrapper_delay
279
280     if _func is None:
281         return decorator_delay
282     else:
283         return decorator_delay(_func)
284
285
286 class _SingletonWrapper:
287     """
288     A singleton wrapper class. Its instances would be created
289     for each decorated class.
290
291     """
292
293     def __init__(self, cls):
294         self.__wrapped__ = cls
295         self._instance = None
296
297     def __call__(self, *args, **kwargs):
298         """Returns a single instance of decorated class"""
299         logger.debug('@singleton returning global instance of %s', self.__wrapped__.__name__)
300         if self._instance is None:
301             self._instance = self.__wrapped__(*args, **kwargs)
302         return self._instance
303
304
305 def singleton(cls):
306     """
307     A singleton decorator. Returns a wrapper objects. A call on that object
308     returns a single instance object of decorated class. Use the __wrapped__
309     attribute to access decorated class directly in unit tests
310
311     >>> @singleton
312     ... class foo(object):
313     ...     pass
314
315     >>> a = foo()
316     >>> b = foo()
317     >>> a is b
318     True
319
320     >>> id(a) == id(b)
321     True
322
323     """
324     return _SingletonWrapper(cls)
325
326
327 def memoized(func: Callable) -> Callable:
328     """Keep a cache of previous function call results.
329
330     The cache here is a dict with a key based on the arguments to the
331     call.  Consider also: functools.cache for a more advanced
332     implementation.  See:
333     https://docs.python.org/3/library/functools.html#functools.cache
334
335     >>> import time
336
337     >>> @memoized
338     ... def expensive(arg) -> int:
339     ...     # Simulate something slow to compute or lookup
340     ...     time.sleep(1.0)
341     ...     return arg * arg
342
343     >>> start = time.time()
344     >>> expensive(5)           # Takes about 1 sec
345     25
346
347     >>> expensive(3)           # Also takes about 1 sec
348     9
349
350     >>> expensive(5)           # Pulls from cache, fast
351     25
352
353     >>> expensive(3)           # Pulls from cache again, fast
354     9
355
356     >>> dur = time.time() - start
357     >>> dur < 3.0
358     True
359
360     """
361
362     @functools.wraps(func)
363     def wrapper_memoized(*args, **kwargs):
364         cache_key = args + tuple(kwargs.items())
365         if cache_key not in wrapper_memoized.cache:
366             value = func(*args, **kwargs)
367             logger.debug('Memoizing %s => %s for %s', cache_key, value, func.__name__)
368             wrapper_memoized.cache[cache_key] = value
369         else:
370             logger.debug('Returning memoized value for %s', {func.__name__})
371         return wrapper_memoized.cache[cache_key]
372
373     wrapper_memoized.cache = {}  # type: ignore
374     return wrapper_memoized
375
376
377 def retry_predicate(
378     tries: int,
379     *,
380     predicate: Callable[..., bool],
381     delay_sec: float = 3.0,
382     backoff: float = 2.0,
383 ):
384     """Retries a function or method up to a certain number of times with a
385     prescribed initial delay period and backoff rate (multiplier).
386
387     Args:
388         tries: the maximum number of attempts to run the function
389         delay_sec: sets the initial delay period in seconds
390         backoff: a multiplier (must be >=1.0) used to modify the
391             delay at each subsequent invocation
392         predicate: a Callable that will be passed the retval of
393             the decorated function and must return True to indicate
394             that we should stop calling or False to indicate a retry
395             is necessary
396     """
397
398     if backoff < 1.0:
399         msg = f"backoff must be greater than or equal to 1, got {backoff}"
400         logger.critical(msg)
401         raise ValueError(msg)
402
403     tries = math.floor(tries)
404     if tries < 0:
405         msg = f"tries must be 0 or greater, got {tries}"
406         logger.critical(msg)
407         raise ValueError(msg)
408
409     if delay_sec <= 0:
410         msg = f"delay_sec must be greater than 0, got {delay_sec}"
411         logger.critical(msg)
412         raise ValueError(msg)
413
414     def deco_retry(f):
415         @functools.wraps(f)
416         def f_retry(*args, **kwargs):
417             mtries, mdelay = tries, delay_sec  # make mutable
418             logger.debug('deco_retry: will make up to %d attempts...', mtries)
419             retval = f(*args, **kwargs)
420             while mtries > 0:
421                 if predicate(retval) is True:
422                     logger.debug('Predicate succeeded, deco_retry is done.')
423                     return retval
424                 logger.debug("Predicate failed, sleeping and retrying.")
425                 mtries -= 1
426                 time.sleep(mdelay)
427                 mdelay *= backoff
428                 retval = f(*args, **kwargs)
429             return retval
430
431         return f_retry
432
433     return deco_retry
434
435
436 def retry_if_false(tries: int, *, delay_sec=3.0, backoff=2.0):
437     """A helper for @retry_predicate that retries a decorated
438     function as long as it keeps returning False.
439
440     >>> import time
441
442     >>> counter = 0
443     >>> @retry_if_false(5, delay_sec=1.0, backoff=1.1)
444     ... def foo():
445     ...     global counter
446     ...     counter += 1
447     ...     return counter >= 3
448
449     >>> start = time.time()
450     >>> foo()  # fail, delay 1.0, fail, delay 1.1, succeed
451     True
452
453     >>> dur = time.time() - start
454     >>> counter
455     3
456     >>> dur > 2.0
457     True
458     >>> dur < 2.3
459     True
460
461     """
462     return retry_predicate(
463         tries,
464         predicate=lambda x: x is True,
465         delay_sec=delay_sec,
466         backoff=backoff,
467     )
468
469
470 def retry_if_none(tries: int, *, delay_sec=3.0, backoff=2.0):
471     """Another helper for @retry_predicate above.  Retries up to N
472     times so long as the wrapped function returns None with a delay
473     between each retry and a backoff that can increase the delay.
474     """
475
476     return retry_predicate(
477         tries,
478         predicate=lambda x: x is not None,
479         delay_sec=delay_sec,
480         backoff=backoff,
481     )
482
483
484 def deprecated(func):
485     """This is a decorator which can be used to mark functions
486     as deprecated. It will result in a warning being emitted
487     when the function is used.
488     """
489
490     @functools.wraps(func)
491     def wrapper_deprecated(*args, **kwargs):
492         msg = f"Call to deprecated function {func.__qualname__}"
493         logger.warning(msg)
494         warnings.warn(msg, category=DeprecationWarning, stacklevel=2)
495         print(msg, file=sys.stderr)
496         return func(*args, **kwargs)
497
498     return wrapper_deprecated
499
500
501 def thunkify(func):
502     """
503     Make a function immediately return a function of no args which,
504     when called, waits for the result, which will start being
505     processed in another thread.
506     """
507
508     @functools.wraps(func)
509     def lazy_thunked(*args, **kwargs):
510         wait_event = threading.Event()
511
512         result = [None]
513         exc: List[Any] = [False, None]
514
515         def worker_func():
516             try:
517                 func_result = func(*args, **kwargs)
518                 result[0] = func_result
519             except Exception:
520                 exc[0] = True
521                 exc[1] = sys.exc_info()  # (type, value, traceback)
522                 msg = f"Thunkify has thrown an exception (will be raised on thunk()):\n{traceback.format_exc()}"
523                 logger.warning(msg)
524             finally:
525                 wait_event.set()
526
527         def thunk():
528             wait_event.wait()
529             if exc[0]:
530                 assert exc[1]
531                 raise exc[1][0](exc[1][1])
532             return result[0]
533
534         threading.Thread(target=worker_func).start()
535         return thunk
536
537     return lazy_thunked
538
539
540 ############################################################
541 # Timeout
542 ############################################################
543
544 # http://www.saltycrane.com/blog/2010/04/using-python-timeout-decorator-uploading-s3/
545 # Used work of Stephen "Zero" Chappell <[email protected]>
546 # in https://code.google.com/p/verse-quiz/source/browse/trunk/timeout.py
547
548
549 def _raise_exception(exception, error_message: Optional[str]):
550     if error_message is None:
551         raise Exception(exception)
552     else:
553         raise Exception(error_message)
554
555
556 def _target(queue, function, *args, **kwargs):
557     """Run a function with arguments and return output via a queue.
558
559     This is a helper function for the Process created in _Timeout. It runs
560     the function with positional arguments and keyword arguments and then
561     returns the function's output by way of a queue. If an exception gets
562     raised, it is returned to _Timeout to be raised by the value property.
563     """
564     try:
565         queue.put((True, function(*args, **kwargs)))
566     except Exception:
567         queue.put((False, sys.exc_info()[1]))
568
569
570 class _Timeout(object):
571     """Wrap a function and add a timeout to it.
572
573     Instances of this class are automatically generated by the add_timeout
574     function defined below.  Do not use directly.
575     """
576
577     def __init__(
578         self,
579         function: Callable,
580         timeout_exception: Exception,
581         error_message: str,
582         seconds: float,
583     ):
584         self.__limit = seconds
585         self.__function = function
586         self.__timeout_exception = timeout_exception
587         self.__error_message = error_message
588         self.__name__ = function.__name__
589         self.__doc__ = function.__doc__
590         self.__timeout = time.time()
591         self.__process = multiprocessing.Process()
592         self.__queue: multiprocessing.queues.Queue = multiprocessing.Queue()
593
594     def __call__(self, *args, **kwargs):
595         """Execute the embedded function object asynchronously.
596
597         The function given to the constructor is transparently called and
598         requires that "ready" be intermittently polled. If and when it is
599         True, the "value" property may then be checked for returned data.
600         """
601         self.__limit = kwargs.pop("timeout", self.__limit)
602         self.__queue = multiprocessing.Queue(1)
603         args = (self.__queue, self.__function) + args
604         self.__process = multiprocessing.Process(target=_target, args=args, kwargs=kwargs)
605         self.__process.daemon = True
606         self.__process.start()
607         if self.__limit is not None:
608             self.__timeout = self.__limit + time.time()
609         while not self.ready:
610             time.sleep(0.1)
611         return self.value
612
613     def cancel(self):
614         """Terminate any possible execution of the embedded function."""
615         if self.__process.is_alive():
616             self.__process.terminate()
617         _raise_exception(self.__timeout_exception, self.__error_message)
618
619     @property
620     def ready(self):
621         """Read-only property indicating status of "value" property."""
622         if self.__limit and self.__timeout < time.time():
623             self.cancel()
624         return self.__queue.full() and not self.__queue.empty()
625
626     @property
627     def value(self):
628         """Read-only property containing data returned from function."""
629         if self.ready is True:
630             flag, load = self.__queue.get()
631             if flag:
632                 return load
633             raise load
634         return None
635
636
637 def timeout(
638     seconds: float = 1.0,
639     use_signals: Optional[bool] = None,
640     timeout_exception=exceptions.TimeoutError,
641     error_message="Function call timed out",
642 ):
643     """Add a timeout parameter to a function and return the function.
644
645     Note: the use_signals parameter is included in order to support
646     multiprocessing scenarios (signal can only be used from the process'
647     main thread).  When not using signals, timeout granularity will be
648     rounded to the nearest 0.1s.
649
650     Beware that an @timeout on a function inside a module will be
651     evaluated at module load time and not when the wrapped function is
652     invoked.  This can lead to problems when relying on the automatic
653     main thread detection code (use_signals=None, the default) since
654     the import probably happens on the main thread and the invocation
655     can happen on a different thread (which can't use signals).
656
657     Raises an exception when/if the timeout is reached.
658
659     It is illegal to pass anything other than a function as the first
660     parameter.  The function is wrapped and returned to the caller.
661
662     >>> @timeout(0.2)
663     ... def foo(delay: float):
664     ...     time.sleep(delay)
665     ...     return "ok"
666
667     >>> foo(0)
668     'ok'
669
670     >>> foo(1.0)
671     Traceback (most recent call last):
672     ...
673     Exception: Function call timed out
674
675     """
676     if use_signals is None:
677         import thread_utils
678
679         use_signals = thread_utils.is_current_thread_main_thread()
680
681     def decorate(function):
682         if use_signals:
683
684             def handler(unused_signum, unused_frame):
685                 _raise_exception(timeout_exception, error_message)
686
687             @functools.wraps(function)
688             def new_function(*args, **kwargs):
689                 new_seconds = kwargs.pop("timeout", seconds)
690                 if new_seconds:
691                     old = signal.signal(signal.SIGALRM, handler)
692                     signal.setitimer(signal.ITIMER_REAL, new_seconds)
693
694                 if not seconds:
695                     return function(*args, **kwargs)
696
697                 try:
698                     return function(*args, **kwargs)
699                 finally:
700                     if new_seconds:
701                         signal.setitimer(signal.ITIMER_REAL, 0)
702                         signal.signal(signal.SIGALRM, old)
703
704             return new_function
705         else:
706
707             @functools.wraps(function)
708             def new_function(*args, **kwargs):
709                 timeout_wrapper = _Timeout(function, timeout_exception, error_message, seconds)
710                 return timeout_wrapper(*args, **kwargs)
711
712             return new_function
713
714     return decorate
715
716
717 def synchronized(lock):
718     """Emulates java's synchronized keyword: given a lock, require that
719     threads take that lock (or wait) before invoking the wrapped
720     function and automatically releases the lock afterwards.
721     """
722
723     def wrap(f):
724         @functools.wraps(f)
725         def _gatekeeper(*args, **kw):
726             lock.acquire()
727             try:
728                 return f(*args, **kw)
729             finally:
730                 lock.release()
731
732         return _gatekeeper
733
734     return wrap
735
736
737 def call_with_sample_rate(sample_rate: float) -> Callable:
738     """Calls the wrapped function probabilistically given a rate between
739     0.0 and 1.0 inclusive (0% probability and 100% probability).
740     """
741
742     if not 0.0 <= sample_rate <= 1.0:
743         msg = f"sample_rate must be between [0, 1]. Got {sample_rate}."
744         logger.critical(msg)
745         raise ValueError(msg)
746
747     def decorator(f):
748         @functools.wraps(f)
749         def _call_with_sample_rate(*args, **kwargs):
750             if random.uniform(0, 1) < sample_rate:
751                 return f(*args, **kwargs)
752             else:
753                 logger.debug("@call_with_sample_rate skipping a call to %s", f.__name__)
754                 return None
755
756         return _call_with_sample_rate
757
758     return decorator
759
760
761 def decorate_matching_methods_with(decorator, acl=None):
762     """Apply the given decorator to all methods in a class whose names
763     begin with prefix.  If prefix is None (default), decorate all
764     methods in the class.
765     """
766
767     def decorate_the_class(cls):
768         for name, m in inspect.getmembers(cls, inspect.isfunction):
769             if acl is None:
770                 setattr(cls, name, decorator(m))
771             else:
772                 if acl(name):
773                     setattr(cls, name, decorator(m))
774         return cls
775
776     return decorate_the_class
777
778
779 if __name__ == '__main__':
780     import doctest
781
782     doctest.testmod()