Lockfile custom command, fixup minor things.
[python_utils.git] / decorator_utils.py
1 #!/usr/bin/env python3
2
3 """Decorators."""
4
5 import datetime
6 import enum
7 import functools
8 import inspect
9 import logging
10 import math
11 import multiprocessing
12 import random
13 import signal
14 import sys
15 import threading
16 import time
17 import traceback
18 from typing import Callable, Optional
19 import warnings
20
21 # This module is commonly used by others in here and should avoid
22 # taking any unnecessary dependencies back on them.
23 import exceptions
24
25
26 logger = logging.getLogger(__name__)
27
28
29 def timed(func: Callable) -> Callable:
30     """Print the runtime of the decorated function."""
31
32     @functools.wraps(func)
33     def wrapper_timer(*args, **kwargs):
34         start_time = time.perf_counter()
35         value = func(*args, **kwargs)
36         end_time = time.perf_counter()
37         run_time = end_time - start_time
38         msg = f"Finished {func.__name__!r} in {run_time:.4f}s"
39         print(msg)
40         logger.info(msg)
41         return value
42     return wrapper_timer
43
44
45 def invocation_logged(func: Callable) -> Callable:
46     """Log the call of a function."""
47
48     @functools.wraps(func)
49     def wrapper_invocation_logged(*args, **kwargs):
50         now = datetime.datetime.now()
51         ts = now.strftime("%Y/%d/%b:%H:%M:%S%Z")
52         msg = f"[{ts}]: Entered {func.__name__}"
53         print(msg)
54         logger.info(msg)
55         ret = func(*args, **kwargs)
56         now = datetime.datetime.now()
57         ts = now.strftime("%Y/%d/%b:%H:%M:%S%Z")
58         msg = f"[{ts}]: Exited {func.__name__}"
59         print(msg)
60         logger.info(msg)
61         return ret
62     return wrapper_invocation_logged
63
64
65 def debug_args(func: Callable) -> Callable:
66     """Print the function signature and return value at each call."""
67
68     @functools.wraps(func)
69     def wrapper_debug_args(*args, **kwargs):
70         args_repr = [f"{repr(a)}:{type(a)}" for a in args]
71         kwargs_repr = [f"{k}={v!r}:{type(v)}" for k, v in kwargs.items()]
72         signature = ", ".join(args_repr + kwargs_repr)
73         msg = f"Calling {func.__name__}({signature})"
74         print(msg)
75         logger.info(msg)
76         value = func(*args, **kwargs)
77         msg = f"{func.__name__!r} returned {value!r}:{type(value)}"
78         logger.info(msg)
79         return value
80     return wrapper_debug_args
81
82
83 def debug_count_calls(func: Callable) -> Callable:
84     """Count function invocations and print a message befor every call."""
85
86     @functools.wraps(func)
87     def wrapper_debug_count_calls(*args, **kwargs):
88         wrapper_debug_count_calls.num_calls += 1
89         msg = f"Call #{wrapper_debug_count_calls.num_calls} of {func.__name__!r}"
90         print(msg)
91         logger.info(msg)
92         return func(*args, **kwargs)
93     wrapper_debug_count_calls.num_calls = 0
94     return wrapper_debug_count_calls
95
96
97 class DelayWhen(enum.Enum):
98     BEFORE_CALL = 1
99     AFTER_CALL = 2
100     BEFORE_AND_AFTER = 3
101
102
103 def delay(
104     _func: Callable = None,
105     *,
106     seconds: float = 1.0,
107     when: DelayWhen = DelayWhen.BEFORE_CALL,
108 ) -> Callable:
109     """Delay the execution of a function by sleeping before and/or after.
110
111     Slow down a function by inserting a delay before and/or after its
112     invocation.
113     """
114
115     def decorator_delay(func: Callable) -> Callable:
116         @functools.wraps(func)
117         def wrapper_delay(*args, **kwargs):
118             if when & DelayWhen.BEFORE_CALL:
119                 logger.debug(
120                     f"@delay for {seconds}s BEFORE_CALL to {func.__name__}"
121                 )
122                 time.sleep(seconds)
123             retval = func(*args, **kwargs)
124             if when & DelayWhen.AFTER_CALL:
125                 logger.debug(
126                     f"@delay for {seconds}s AFTER_CALL to {func.__name__}"
127                 )
128                 time.sleep(seconds)
129             return retval
130         return wrapper_delay
131
132     if _func is None:
133         return decorator_delay
134     else:
135         return decorator_delay(_func)
136
137
138 class _SingletonWrapper:
139     """
140     A singleton wrapper class. Its instances would be created
141     for each decorated class.
142     """
143
144     def __init__(self, cls):
145         self.__wrapped__ = cls
146         self._instance = None
147
148     def __call__(self, *args, **kwargs):
149         """Returns a single instance of decorated class"""
150         logger.debug(
151             f"@singleton returning global instance of {self.__wrapped__.__name__}"
152         )
153         if self._instance is None:
154             self._instance = self.__wrapped__(*args, **kwargs)
155         return self._instance
156
157
158 def singleton(cls):
159     """
160     A singleton decorator. Returns a wrapper objects. A call on that object
161     returns a single instance object of decorated class. Use the __wrapped__
162     attribute to access decorated class directly in unit tests
163     """
164     return _SingletonWrapper(cls)
165
166
167 def memoized(func: Callable) -> Callable:
168     """Keep a cache of previous function call results.
169
170     The cache here is a dict with a key based on the arguments to the
171     call.  Consider also: functools.lru_cache for a more advanced
172     implementation.
173     """
174
175     @functools.wraps(func)
176     def wrapper_memoized(*args, **kwargs):
177         cache_key = args + tuple(kwargs.items())
178         if cache_key not in wrapper_memoized.cache:
179             value = func(*args, **kwargs)
180             logger.debug(
181                 f"Memoizing {cache_key} => {value} for {func.__name__}"
182             )
183             wrapper_memoized.cache[cache_key] = value
184         else:
185             logger.debug(f"Returning memoized value for {func.__name__}")
186         return wrapper_memoized.cache[cache_key]
187     wrapper_memoized.cache = dict()
188     return wrapper_memoized
189
190
191 def retry_predicate(
192     tries: int,
193     *,
194     predicate: Callable[..., bool],
195     delay_sec: float = 3.0,
196     backoff: float = 2.0,
197 ):
198     """Retries a function or method up to a certain number of times
199     with a prescribed initial delay period and backoff rate.
200
201     tries is the maximum number of attempts to run the function.
202     delay_sec sets the initial delay period in seconds.
203     backoff is a multiplied (must be >1) used to modify the delay.
204     predicate is a function that will be passed the retval of the
205     decorated function and must return True to stop or False to
206     retry.
207     """
208     if backoff < 1.0:
209         msg = f"backoff must be greater than or equal to 1, got {backoff}"
210         logger.critical(msg)
211         raise ValueError(msg)
212
213     tries = math.floor(tries)
214     if tries < 0:
215         msg = f"tries must be 0 or greater, got {tries}"
216         logger.critical(msg)
217         raise ValueError(msg)
218
219     if delay_sec <= 0:
220         msg = f"delay_sec must be greater than 0, got {delay_sec}"
221         logger.critical(msg)
222         raise ValueError(msg)
223
224     def deco_retry(f):
225         @functools.wraps(f)
226         def f_retry(*args, **kwargs):
227             mtries, mdelay = tries, delay_sec  # make mutable
228             logger.debug(f'deco_retry: will make up to {mtries} attempts...')
229             retval = f(*args, **kwargs)
230             while mtries > 0:
231                 if predicate(retval) is True:
232                     logger.debug('Predicate succeeded, deco_retry is done.')
233                     return retval
234                 logger.debug("Predicate failed, sleeping and retrying.")
235                 mtries -= 1
236                 time.sleep(mdelay)
237                 mdelay *= backoff
238                 retval = f(*args, **kwargs)
239             return retval
240         return f_retry
241     return deco_retry
242
243
244 def retry_if_false(tries: int, *, delay_sec=3.0, backoff=2.0):
245     return retry_predicate(
246         tries,
247         predicate=lambda x: x is True,
248         delay_sec=delay_sec,
249         backoff=backoff,
250     )
251
252
253 def retry_if_none(tries: int, *, delay_sec=3.0, backoff=2.0):
254     return retry_predicate(
255         tries,
256         predicate=lambda x: x is not None,
257         delay_sec=delay_sec,
258         backoff=backoff,
259     )
260
261
262 def deprecated(func):
263     """This is a decorator which can be used to mark functions
264     as deprecated. It will result in a warning being emitted
265     when the function is used.
266     """
267
268     @functools.wraps(func)
269     def wrapper_deprecated(*args, **kwargs):
270         msg = f"Call to deprecated function {func.__name__}"
271         logger.warning(msg)
272         warnings.warn(msg, category=DeprecationWarning)
273         return func(*args, **kwargs)
274
275     return wrapper_deprecated
276
277
278 def thunkify(func):
279     """
280     Make a function immediately return a function of no args which,
281     when called, waits for the result, which will start being
282     processed in another thread.
283     """
284
285     @functools.wraps(func)
286     def lazy_thunked(*args, **kwargs):
287         wait_event = threading.Event()
288
289         result = [None]
290         exc = [False, None]
291
292         def worker_func():
293             try:
294                 func_result = func(*args, **kwargs)
295                 result[0] = func_result
296             except Exception:
297                 exc[0] = True
298                 exc[1] = sys.exc_info()  # (type, value, traceback)
299                 msg = f"Thunkify has thrown an exception (will be raised on thunk()):\n{traceback.format_exc()}"
300                 logger.warning(msg)
301                 print(msg)
302             finally:
303                 wait_event.set()
304
305         def thunk():
306             wait_event.wait()
307             if exc[0]:
308                 raise exc[1][0](exc[1][1])
309             return result[0]
310
311         threading.Thread(target=worker_func).start()
312         return thunk
313
314     return lazy_thunked
315
316
317 ############################################################
318 # Timeout
319 ############################################################
320
321 # http://www.saltycrane.com/blog/2010/04/using-python-timeout-decorator-uploading-s3/
322 # Used work of Stephen "Zero" Chappell <[email protected]>
323 # in https://code.google.com/p/verse-quiz/source/browse/trunk/timeout.py
324
325
326 def _raise_exception(exception, error_message: Optional[str]):
327     if error_message is None:
328         raise exception()
329     else:
330         raise exception(error_message)
331
332
333 def _target(queue, function, *args, **kwargs):
334     """Run a function with arguments and return output via a queue.
335
336     This is a helper function for the Process created in _Timeout. It runs
337     the function with positional arguments and keyword arguments and then
338     returns the function's output by way of a queue. If an exception gets
339     raised, it is returned to _Timeout to be raised by the value property.
340     """
341     try:
342         queue.put((True, function(*args, **kwargs)))
343     except Exception:
344         queue.put((False, sys.exc_info()[1]))
345
346
347 class _Timeout(object):
348     """Wrap a function and add a timeout (limit) attribute to it.
349
350     Instances of this class are automatically generated by the add_timeout
351     function defined below.
352     """
353
354     def __init__(
355         self,
356         function: Callable,
357         timeout_exception: Exception,
358         error_message: str,
359         seconds: float,
360     ):
361         self.__limit = seconds
362         self.__function = function
363         self.__timeout_exception = timeout_exception
364         self.__error_message = error_message
365         self.__name__ = function.__name__
366         self.__doc__ = function.__doc__
367         self.__timeout = time.time()
368         self.__process = multiprocessing.Process()
369         self.__queue: multiprocessing.queues.Queue = multiprocessing.Queue()
370
371     def __call__(self, *args, **kwargs):
372         """Execute the embedded function object asynchronously.
373
374         The function given to the constructor is transparently called and
375         requires that "ready" be intermittently polled. If and when it is
376         True, the "value" property may then be checked for returned data.
377         """
378         self.__limit = kwargs.pop("timeout", self.__limit)
379         self.__queue = multiprocessing.Queue(1)
380         args = (self.__queue, self.__function) + args
381         self.__process = multiprocessing.Process(
382             target=_target, args=args, kwargs=kwargs
383         )
384         self.__process.daemon = True
385         self.__process.start()
386         if self.__limit is not None:
387             self.__timeout = self.__limit + time.time()
388         while not self.ready:
389             time.sleep(0.1)
390         return self.value
391
392     def cancel(self):
393         """Terminate any possible execution of the embedded function."""
394         if self.__process.is_alive():
395             self.__process.terminate()
396         _raise_exception(self.__timeout_exception, self.__error_message)
397
398     @property
399     def ready(self):
400         """Read-only property indicating status of "value" property."""
401         if self.__limit and self.__timeout < time.time():
402             self.cancel()
403         return self.__queue.full() and not self.__queue.empty()
404
405     @property
406     def value(self):
407         """Read-only property containing data returned from function."""
408         if self.ready is True:
409             flag, load = self.__queue.get()
410             if flag:
411                 return load
412             raise load
413
414
415 def timeout(
416     seconds: float = 1.0,
417     use_signals: Optional[bool] = None,
418     timeout_exception=exceptions.TimeoutError,
419     error_message="Function call timed out",
420 ):
421     """Add a timeout parameter to a function and return the function.
422
423     Note: the use_signals parameter is included in order to support
424     multiprocessing scenarios (signal can only be used from the process'
425     main thread).  When not using signals, timeout granularity will be
426     rounded to the nearest 0.1s.
427
428     Raises an exception when the timeout is reached.
429
430     It is illegal to pass anything other than a function as the first
431     parameter.  The function is wrapped and returned to the caller.
432     """
433     if use_signals is None:
434         import thread_utils
435         use_signals = thread_utils.is_current_thread_main_thread()
436
437     def decorate(function):
438
439         if use_signals:
440
441             def handler(signum, frame):
442                 _raise_exception(timeout_exception, error_message)
443
444             @functools.wraps(function)
445             def new_function(*args, **kwargs):
446                 new_seconds = kwargs.pop("timeout", seconds)
447                 if new_seconds:
448                     old = signal.signal(signal.SIGALRM, handler)
449                     signal.setitimer(signal.ITIMER_REAL, new_seconds)
450
451                 if not seconds:
452                     return function(*args, **kwargs)
453
454                 try:
455                     return function(*args, **kwargs)
456                 finally:
457                     if new_seconds:
458                         signal.setitimer(signal.ITIMER_REAL, 0)
459                         signal.signal(signal.SIGALRM, old)
460
461             return new_function
462         else:
463
464             @functools.wraps(function)
465             def new_function(*args, **kwargs):
466                 timeout_wrapper = _Timeout(
467                     function, timeout_exception, error_message, seconds
468                 )
469                 return timeout_wrapper(*args, **kwargs)
470
471             return new_function
472
473     return decorate
474
475
476 class non_reentrant_code(object):
477     def __init__(self):
478         self._lock = threading.RLock
479         self._entered = False
480
481     def __call__(self, f):
482         def _gatekeeper(*args, **kwargs):
483             with self._lock:
484                 if self._entered:
485                     return
486                 self._entered = True
487                 f(*args, **kwargs)
488                 self._entered = False
489
490         return _gatekeeper
491
492
493 class rlocked(object):
494     def __init__(self):
495         self._lock = threading.RLock
496         self._entered = False
497
498     def __call__(self, f):
499         def _gatekeeper(*args, **kwargs):
500             with self._lock:
501                 if self._entered:
502                     return
503                 self._entered = True
504                 f(*args, **kwargs)
505                 self._entered = False
506         return _gatekeeper
507
508
509 def call_with_sample_rate(sample_rate: float) -> Callable:
510     if not 0.0 <= sample_rate <= 1.0:
511         msg = f"sample_rate must be between [0, 1]. Got {sample_rate}."
512         logger.critical(msg)
513         raise ValueError(msg)
514
515     def decorator(f):
516         @functools.wraps(f)
517         def _call_with_sample_rate(*args, **kwargs):
518             if random.uniform(0, 1) < sample_rate:
519                 return f(*args, **kwargs)
520             else:
521                 logger.debug(
522                     f"@call_with_sample_rate skipping a call to {f.__name__}"
523                 )
524         return _call_with_sample_rate
525     return decorator
526
527
528 def decorate_matching_methods_with(decorator, acl=None):
529     """Apply decorator to all methods in a class whose names begin with
530     prefix.  If prefix is None (default), decorate all methods in the
531     class.
532     """
533     def decorate_the_class(cls):
534         for name, m in inspect.getmembers(cls, inspect.isfunction):
535             if acl is None:
536                 setattr(cls, name, decorator(m))
537             else:
538                 if acl(name):
539                     setattr(cls, name, decorator(m))
540         return cls
541     return decorate_the_class