Geocoder.
[python_utils.git] / decorator_utils.py
1 #!/usr/bin/env python3
2
3 # © Copyright 2021-2022, Scott Gasch
4 # Portions (marked) below retain the original author's copyright.
5
6 """Decorators."""
7
8 import enum
9 import functools
10 import inspect
11 import logging
12 import math
13 import multiprocessing
14 import random
15 import signal
16 import sys
17 import threading
18 import time
19 import traceback
20 import warnings
21 from typing import Any, Callable, List, Optional
22
23 # This module is commonly used by others in here and should avoid
24 # taking any unnecessary dependencies back on them.
25 import exceptions
26
27 logger = logging.getLogger(__name__)
28
29
30 def timed(func: Callable) -> Callable:
31     """Print the runtime of the decorated function.
32
33     >>> @timed
34     ... def foo():
35     ...     import time
36     ...     time.sleep(0.1)
37
38     >>> foo()  # doctest: +ELLIPSIS
39     Finished foo in ...
40
41     """
42
43     @functools.wraps(func)
44     def wrapper_timer(*args, **kwargs):
45         start_time = time.perf_counter()
46         value = func(*args, **kwargs)
47         end_time = time.perf_counter()
48         run_time = end_time - start_time
49         msg = f"Finished {func.__qualname__} in {run_time:.4f}s"
50         print(msg)
51         logger.info(msg)
52         return value
53
54     return wrapper_timer
55
56
57 def invocation_logged(func: Callable) -> Callable:
58     """Log the call of a function.
59
60     >>> @invocation_logged
61     ... def foo():
62     ...     print('Hello, world.')
63
64     >>> foo()
65     Entered foo
66     Hello, world.
67     Exited foo
68
69     """
70
71     @functools.wraps(func)
72     def wrapper_invocation_logged(*args, **kwargs):
73         msg = f"Entered {func.__qualname__}"
74         print(msg)
75         logger.info(msg)
76         ret = func(*args, **kwargs)
77         msg = f"Exited {func.__qualname__}"
78         print(msg)
79         logger.info(msg)
80         return ret
81
82     return wrapper_invocation_logged
83
84
85 def rate_limited(n_calls: int, *, per_period_in_seconds: float = 1.0) -> Callable:
86     """Limit invocation of a wrapped function to n calls per period.
87     Thread safe.  In testing this was relatively fair with multiple
88     threads using it though that hasn't been measured.
89
90     >>> import time
91     >>> import decorator_utils
92     >>> import thread_utils
93
94     >>> calls = 0
95
96     >>> @decorator_utils.rate_limited(10, per_period_in_seconds=1.0)
97     ... def limited(x: int):
98     ...     global calls
99     ...     calls += 1
100
101     >>> @thread_utils.background_thread
102     ... def a(stop):
103     ...     for _ in range(3):
104     ...         limited(_)
105
106     >>> @thread_utils.background_thread
107     ... def b(stop):
108     ...     for _ in range(3):
109     ...         limited(_)
110
111     >>> start = time.time()
112     >>> (t1, e1) = a()
113     >>> (t2, e2) = b()
114     >>> t1.join()
115     >>> t2.join()
116     >>> end = time.time()
117     >>> dur = end - start
118     >>> dur > 0.5
119     True
120
121     >>> calls
122     6
123
124     """
125     min_interval_seconds = per_period_in_seconds / float(n_calls)
126
127     def wrapper_rate_limited(func: Callable) -> Callable:
128         cv = threading.Condition()
129         last_invocation_timestamp = [0.0]
130
131         def may_proceed() -> float:
132             now = time.time()
133             last_invocation = last_invocation_timestamp[0]
134             if last_invocation != 0.0:
135                 elapsed_since_last = now - last_invocation
136                 wait_time = min_interval_seconds - elapsed_since_last
137             else:
138                 wait_time = 0.0
139             logger.debug('@%.4f> wait_time = %.4f', time.time(), wait_time)
140             return wait_time
141
142         def wrapper_wrapper_rate_limited(*args, **kargs) -> Any:
143             with cv:
144                 while True:
145                     if cv.wait_for(
146                         lambda: may_proceed() <= 0.0,
147                         timeout=may_proceed(),
148                     ):
149                         break
150             with cv:
151                 logger.debug('@%.4f> calling it...', time.time())
152                 ret = func(*args, **kargs)
153                 last_invocation_timestamp[0] = time.time()
154                 logger.debug(
155                     '@%.4f> Last invocation <- %.4f', time.time(), last_invocation_timestamp[0]
156                 )
157                 cv.notify()
158             return ret
159
160         return wrapper_wrapper_rate_limited
161
162     return wrapper_rate_limited
163
164
165 def debug_args(func: Callable) -> Callable:
166     """Print the function signature and return value at each call.
167
168     >>> @debug_args
169     ... def foo(a, b, c):
170     ...     print(a)
171     ...     print(b)
172     ...     print(c)
173     ...     return (a + b, c)
174
175     >>> foo(1, 2.0, "test")
176     Calling foo(1:<class 'int'>, 2.0:<class 'float'>, 'test':<class 'str'>)
177     1
178     2.0
179     test
180     foo returned (3.0, 'test'):<class 'tuple'>
181     (3.0, 'test')
182     """
183
184     @functools.wraps(func)
185     def wrapper_debug_args(*args, **kwargs):
186         args_repr = [f"{repr(a)}:{type(a)}" for a in args]
187         kwargs_repr = [f"{k}={v!r}:{type(v)}" for k, v in kwargs.items()]
188         signature = ", ".join(args_repr + kwargs_repr)
189         msg = f"Calling {func.__qualname__}({signature})"
190         print(msg)
191         logger.info(msg)
192         value = func(*args, **kwargs)
193         msg = f"{func.__qualname__} returned {value!r}:{type(value)}"
194         print(msg)
195         logger.info(msg)
196         return value
197
198     return wrapper_debug_args
199
200
201 def debug_count_calls(func: Callable) -> Callable:
202     """Count function invocations and print a message befor every call.
203
204     >>> @debug_count_calls
205     ... def factoral(x):
206     ...     if x == 1:
207     ...         return 1
208     ...     return x * factoral(x - 1)
209
210     >>> factoral(5)
211     Call #1 of 'factoral'
212     Call #2 of 'factoral'
213     Call #3 of 'factoral'
214     Call #4 of 'factoral'
215     Call #5 of 'factoral'
216     120
217
218     """
219
220     @functools.wraps(func)
221     def wrapper_debug_count_calls(*args, **kwargs):
222         wrapper_debug_count_calls.num_calls += 1
223         msg = f"Call #{wrapper_debug_count_calls.num_calls} of {func.__name__!r}"
224         print(msg)
225         logger.info(msg)
226         return func(*args, **kwargs)
227
228     wrapper_debug_count_calls.num_calls = 0  # type: ignore
229     return wrapper_debug_count_calls
230
231
232 class DelayWhen(enum.IntEnum):
233     """When should we delay: before or after calling the function (or
234     both)?
235
236     """
237
238     BEFORE_CALL = 1
239     AFTER_CALL = 2
240     BEFORE_AND_AFTER = 3
241
242
243 def delay(
244     _func: Callable = None,
245     *,
246     seconds: float = 1.0,
247     when: DelayWhen = DelayWhen.BEFORE_CALL,
248 ) -> Callable:
249     """Delay the execution of a function by sleeping before and/or after.
250
251     Slow down a function by inserting a delay before and/or after its
252     invocation.
253
254     >>> import time
255
256     >>> @delay(seconds=1.0)
257     ... def foo():
258     ...     pass
259
260     >>> start = time.time()
261     >>> foo()
262     >>> dur = time.time() - start
263     >>> dur >= 1.0
264     True
265
266     """
267
268     def decorator_delay(func: Callable) -> Callable:
269         @functools.wraps(func)
270         def wrapper_delay(*args, **kwargs):
271             if when & DelayWhen.BEFORE_CALL:
272                 logger.debug("@delay for %fs BEFORE_CALL to %s", seconds, func.__name__)
273                 time.sleep(seconds)
274             retval = func(*args, **kwargs)
275             if when & DelayWhen.AFTER_CALL:
276                 logger.debug("@delay for %fs AFTER_CALL to %s", seconds, func.__name__)
277                 time.sleep(seconds)
278             return retval
279
280         return wrapper_delay
281
282     if _func is None:
283         return decorator_delay
284     else:
285         return decorator_delay(_func)
286
287
288 class _SingletonWrapper:
289     """
290     A singleton wrapper class. Its instances would be created
291     for each decorated class.
292
293     """
294
295     def __init__(self, cls):
296         self.__wrapped__ = cls
297         self._instance = None
298
299     def __call__(self, *args, **kwargs):
300         """Returns a single instance of decorated class"""
301         logger.debug('@singleton returning global instance of %s', self.__wrapped__.__name__)
302         if self._instance is None:
303             self._instance = self.__wrapped__(*args, **kwargs)
304         return self._instance
305
306
307 def singleton(cls):
308     """
309     A singleton decorator. Returns a wrapper objects. A call on that object
310     returns a single instance object of decorated class. Use the __wrapped__
311     attribute to access decorated class directly in unit tests
312
313     >>> @singleton
314     ... class foo(object):
315     ...     pass
316
317     >>> a = foo()
318     >>> b = foo()
319     >>> a is b
320     True
321
322     >>> id(a) == id(b)
323     True
324
325     """
326     return _SingletonWrapper(cls)
327
328
329 def memoized(func: Callable) -> Callable:
330     """Keep a cache of previous function call results.
331
332     The cache here is a dict with a key based on the arguments to the
333     call.  Consider also: functools.lru_cache for a more advanced
334     implementation.
335
336     >>> import time
337
338     >>> @memoized
339     ... def expensive(arg) -> int:
340     ...     # Simulate something slow to compute or lookup
341     ...     time.sleep(1.0)
342     ...     return arg * arg
343
344     >>> start = time.time()
345     >>> expensive(5)           # Takes about 1 sec
346     25
347
348     >>> expensive(3)           # Also takes about 1 sec
349     9
350
351     >>> expensive(5)           # Pulls from cache, fast
352     25
353
354     >>> expensive(3)           # Pulls from cache again, fast
355     9
356
357     >>> dur = time.time() - start
358     >>> dur < 3.0
359     True
360
361     """
362
363     @functools.wraps(func)
364     def wrapper_memoized(*args, **kwargs):
365         cache_key = args + tuple(kwargs.items())
366         if cache_key not in wrapper_memoized.cache:
367             value = func(*args, **kwargs)
368             logger.debug('Memoizing %s => %s for %s', cache_key, value, func.__name__)
369             wrapper_memoized.cache[cache_key] = value
370         else:
371             logger.debug('Returning memoized value for %s', {func.__name__})
372         return wrapper_memoized.cache[cache_key]
373
374     wrapper_memoized.cache = {}  # type: ignore
375     return wrapper_memoized
376
377
378 def retry_predicate(
379     tries: int,
380     *,
381     predicate: Callable[..., bool],
382     delay_sec: float = 3.0,
383     backoff: float = 2.0,
384 ):
385     """Retries a function or method up to a certain number of times
386     with a prescribed initial delay period and backoff rate.
387
388     tries is the maximum number of attempts to run the function.
389     delay_sec sets the initial delay period in seconds.
390     backoff is a multiplied (must be >1) used to modify the delay.
391     predicate is a function that will be passed the retval of the
392     decorated function and must return True to stop or False to
393     retry.
394
395     """
396     if backoff < 1.0:
397         msg = f"backoff must be greater than or equal to 1, got {backoff}"
398         logger.critical(msg)
399         raise ValueError(msg)
400
401     tries = math.floor(tries)
402     if tries < 0:
403         msg = f"tries must be 0 or greater, got {tries}"
404         logger.critical(msg)
405         raise ValueError(msg)
406
407     if delay_sec <= 0:
408         msg = f"delay_sec must be greater than 0, got {delay_sec}"
409         logger.critical(msg)
410         raise ValueError(msg)
411
412     def deco_retry(f):
413         @functools.wraps(f)
414         def f_retry(*args, **kwargs):
415             mtries, mdelay = tries, delay_sec  # make mutable
416             logger.debug('deco_retry: will make up to %d attempts...', mtries)
417             retval = f(*args, **kwargs)
418             while mtries > 0:
419                 if predicate(retval) is True:
420                     logger.debug('Predicate succeeded, deco_retry is done.')
421                     return retval
422                 logger.debug("Predicate failed, sleeping and retrying.")
423                 mtries -= 1
424                 time.sleep(mdelay)
425                 mdelay *= backoff
426                 retval = f(*args, **kwargs)
427             return retval
428
429         return f_retry
430
431     return deco_retry
432
433
434 def retry_if_false(tries: int, *, delay_sec=3.0, backoff=2.0):
435     """A helper for @retry_predicate that retries a decorated
436     function as long as it keeps returning False.
437
438     >>> import time
439
440     >>> counter = 0
441
442     >>> @retry_if_false(5, delay_sec=1.0, backoff=1.1)
443     ... def foo():
444     ...     global counter
445     ...     counter += 1
446     ...     return counter >= 3
447
448     >>> start = time.time()
449     >>> foo()  # fail, delay 1.0, fail, delay 1.1, succeed
450     True
451
452     >>> dur = time.time() - start
453     >>> counter
454     3
455     >>> dur > 2.0
456     True
457     >>> dur < 2.3
458     True
459
460     """
461     return retry_predicate(
462         tries,
463         predicate=lambda x: x is True,
464         delay_sec=delay_sec,
465         backoff=backoff,
466     )
467
468
469 def retry_if_none(tries: int, *, delay_sec=3.0, backoff=2.0):
470     """Another helper for @retry_predicate above.  Retries up to N
471     times so long as the wrapped function returns None with a delay
472     between each retry and a backoff that can increase the delay.
473
474     """
475     return retry_predicate(
476         tries,
477         predicate=lambda x: x is not None,
478         delay_sec=delay_sec,
479         backoff=backoff,
480     )
481
482
483 def deprecated(func):
484     """This is a decorator which can be used to mark functions
485     as deprecated. It will result in a warning being emitted
486     when the function is used.
487
488     """
489
490     @functools.wraps(func)
491     def wrapper_deprecated(*args, **kwargs):
492         msg = f"Call to deprecated function {func.__qualname__}"
493         logger.warning(msg)
494         warnings.warn(msg, category=DeprecationWarning, stacklevel=2)
495         print(msg, file=sys.stderr)
496         return func(*args, **kwargs)
497
498     return wrapper_deprecated
499
500
501 def thunkify(func):
502     """
503     Make a function immediately return a function of no args which,
504     when called, waits for the result, which will start being
505     processed in another thread.
506     """
507
508     @functools.wraps(func)
509     def lazy_thunked(*args, **kwargs):
510         wait_event = threading.Event()
511
512         result = [None]
513         exc: List[Any] = [False, None]
514
515         def worker_func():
516             try:
517                 func_result = func(*args, **kwargs)
518                 result[0] = func_result
519             except Exception:
520                 exc[0] = True
521                 exc[1] = sys.exc_info()  # (type, value, traceback)
522                 msg = f"Thunkify has thrown an exception (will be raised on thunk()):\n{traceback.format_exc()}"
523                 logger.warning(msg)
524             finally:
525                 wait_event.set()
526
527         def thunk():
528             wait_event.wait()
529             if exc[0]:
530                 assert exc[1]
531                 raise exc[1][0](exc[1][1])
532             return result[0]
533
534         threading.Thread(target=worker_func).start()
535         return thunk
536
537     return lazy_thunked
538
539
540 ############################################################
541 # Timeout
542 ############################################################
543
544 # http://www.saltycrane.com/blog/2010/04/using-python-timeout-decorator-uploading-s3/
545 # Used work of Stephen "Zero" Chappell <[email protected]>
546 # in https://code.google.com/p/verse-quiz/source/browse/trunk/timeout.py
547
548
549 def _raise_exception(exception, error_message: Optional[str]):
550     if error_message is None:
551         raise Exception(exception)
552     else:
553         raise Exception(error_message)
554
555
556 def _target(queue, function, *args, **kwargs):
557     """Run a function with arguments and return output via a queue.
558
559     This is a helper function for the Process created in _Timeout. It runs
560     the function with positional arguments and keyword arguments and then
561     returns the function's output by way of a queue. If an exception gets
562     raised, it is returned to _Timeout to be raised by the value property.
563     """
564     try:
565         queue.put((True, function(*args, **kwargs)))
566     except Exception:
567         queue.put((False, sys.exc_info()[1]))
568
569
570 class _Timeout(object):
571     """Wrap a function and add a timeout to it.
572
573     Instances of this class are automatically generated by the add_timeout
574     function defined below.  Do not use directly.
575     """
576
577     def __init__(
578         self,
579         function: Callable,
580         timeout_exception: Exception,
581         error_message: str,
582         seconds: float,
583     ):
584         self.__limit = seconds
585         self.__function = function
586         self.__timeout_exception = timeout_exception
587         self.__error_message = error_message
588         self.__name__ = function.__name__
589         self.__doc__ = function.__doc__
590         self.__timeout = time.time()
591         self.__process = multiprocessing.Process()
592         self.__queue: multiprocessing.queues.Queue = multiprocessing.Queue()
593
594     def __call__(self, *args, **kwargs):
595         """Execute the embedded function object asynchronously.
596
597         The function given to the constructor is transparently called and
598         requires that "ready" be intermittently polled. If and when it is
599         True, the "value" property may then be checked for returned data.
600         """
601         self.__limit = kwargs.pop("timeout", self.__limit)
602         self.__queue = multiprocessing.Queue(1)
603         args = (self.__queue, self.__function) + args
604         self.__process = multiprocessing.Process(target=_target, args=args, kwargs=kwargs)
605         self.__process.daemon = True
606         self.__process.start()
607         if self.__limit is not None:
608             self.__timeout = self.__limit + time.time()
609         while not self.ready:
610             time.sleep(0.1)
611         return self.value
612
613     def cancel(self):
614         """Terminate any possible execution of the embedded function."""
615         if self.__process.is_alive():
616             self.__process.terminate()
617         _raise_exception(self.__timeout_exception, self.__error_message)
618
619     @property
620     def ready(self):
621         """Read-only property indicating status of "value" property."""
622         if self.__limit and self.__timeout < time.time():
623             self.cancel()
624         return self.__queue.full() and not self.__queue.empty()
625
626     @property
627     def value(self):
628         """Read-only property containing data returned from function."""
629         if self.ready is True:
630             flag, load = self.__queue.get()
631             if flag:
632                 return load
633             raise load
634         return None
635
636
637 def timeout(
638     seconds: float = 1.0,
639     use_signals: Optional[bool] = None,
640     timeout_exception=exceptions.TimeoutError,
641     error_message="Function call timed out",
642 ):
643     """Add a timeout parameter to a function and return the function.
644
645     Note: the use_signals parameter is included in order to support
646     multiprocessing scenarios (signal can only be used from the process'
647     main thread).  When not using signals, timeout granularity will be
648     rounded to the nearest 0.1s.
649
650     Raises an exception when/if the timeout is reached.
651
652     It is illegal to pass anything other than a function as the first
653     parameter.  The function is wrapped and returned to the caller.
654
655     >>> @timeout(0.2)
656     ... def foo(delay: float):
657     ...     time.sleep(delay)
658     ...     return "ok"
659
660     >>> foo(0)
661     'ok'
662
663     >>> foo(1.0)
664     Traceback (most recent call last):
665     ...
666     Exception: Function call timed out
667
668     """
669     if use_signals is None:
670         import thread_utils
671
672         use_signals = thread_utils.is_current_thread_main_thread()
673
674     def decorate(function):
675         if use_signals:
676
677             def handler(unused_signum, unused_frame):
678                 _raise_exception(timeout_exception, error_message)
679
680             @functools.wraps(function)
681             def new_function(*args, **kwargs):
682                 new_seconds = kwargs.pop("timeout", seconds)
683                 if new_seconds:
684                     old = signal.signal(signal.SIGALRM, handler)
685                     signal.setitimer(signal.ITIMER_REAL, new_seconds)
686
687                 if not seconds:
688                     return function(*args, **kwargs)
689
690                 try:
691                     return function(*args, **kwargs)
692                 finally:
693                     if new_seconds:
694                         signal.setitimer(signal.ITIMER_REAL, 0)
695                         signal.signal(signal.SIGALRM, old)
696
697             return new_function
698         else:
699
700             @functools.wraps(function)
701             def new_function(*args, **kwargs):
702                 timeout_wrapper = _Timeout(function, timeout_exception, error_message, seconds)
703                 return timeout_wrapper(*args, **kwargs)
704
705             return new_function
706
707     return decorate
708
709
710 def synchronized(lock):
711     def wrap(f):
712         @functools.wraps(f)
713         def _gatekeeper(*args, **kw):
714             lock.acquire()
715             try:
716                 return f(*args, **kw)
717             finally:
718                 lock.release()
719
720         return _gatekeeper
721
722     return wrap
723
724
725 def call_with_sample_rate(sample_rate: float) -> Callable:
726     if not 0.0 <= sample_rate <= 1.0:
727         msg = f"sample_rate must be between [0, 1]. Got {sample_rate}."
728         logger.critical(msg)
729         raise ValueError(msg)
730
731     def decorator(f):
732         @functools.wraps(f)
733         def _call_with_sample_rate(*args, **kwargs):
734             if random.uniform(0, 1) < sample_rate:
735                 return f(*args, **kwargs)
736             else:
737                 logger.debug("@call_with_sample_rate skipping a call to %s", f.__name__)
738                 return None
739
740         return _call_with_sample_rate
741
742     return decorator
743
744
745 def decorate_matching_methods_with(decorator, acl=None):
746     """Apply decorator to all methods in a class whose names begin with
747     prefix.  If prefix is None (default), decorate all methods in the
748     class.
749     """
750
751     def decorate_the_class(cls):
752         for name, m in inspect.getmembers(cls, inspect.isfunction):
753             if acl is None:
754                 setattr(cls, name, decorator(m))
755             else:
756                 if acl(name):
757                     setattr(cls, name, decorator(m))
758         return cls
759
760     return decorate_the_class
761
762
763 if __name__ == '__main__':
764     import doctest
765
766     doctest.testmod()