X-Git-Url: https://wannabe.guru.org/gitweb/?a=blobdiff_plain;f=decorator_utils.py;h=2817239c88c2396b0e5dcc56e7c535b8afdd99d9;hb=09e6d10face80d98a4578ff54192b5c8bec007d7;hp=4d882bed7ac4486db741b76b587b402a6dac147e;hpb=497fb9e21f45ec08e1486abaee6dfa7b20b8a691;p=python_utils.git diff --git a/decorator_utils.py b/decorator_utils.py index 4d882be..2817239 100644 --- a/decorator_utils.py +++ b/decorator_utils.py @@ -5,6 +5,7 @@ import datetime import enum import functools +import inspect import logging import math import multiprocessing @@ -17,7 +18,10 @@ import traceback from typing import Callable, Optional import warnings -import thread_utils +# This module is commonly used by others in here and should avoid +# taking any unnecessary dependencies back on them. +import exceptions + logger = logging.getLogger(__name__) @@ -317,14 +321,6 @@ def thunkify(func): # in https://code.google.com/p/verse-quiz/source/browse/trunk/timeout.py -class TimeoutError(AssertionError): - def __init__(self, value: str = "Timed Out"): - self.value = value - - def __str__(self): - return repr(self.value) - - def _raise_exception(exception, error_message: Optional[str]): if error_message is None: raise exception() @@ -417,7 +413,7 @@ class _Timeout(object): def timeout( seconds: float = 1.0, use_signals: Optional[bool] = None, - timeout_exception=TimeoutError, + timeout_exception=exceptions.TimeoutError, error_message="Function call timed out", ): """Add a timeout parameter to a function and return the function. @@ -433,6 +429,7 @@ def timeout( parameter. The function is wrapped and returned to the caller. """ if use_signals is None: + import thread_utils use_signals = thread_utils.is_current_thread_main_thread() def decorate(function): @@ -524,3 +521,19 @@ def call_with_sample_rate(sample_rate: float) -> Callable: ) return _call_with_sample_rate return decorator + + +def decorate_matching_methods_with(decorator, acl=None): + """Apply decorator to all methods in a class whose names begin with + prefix. If prefix is None (default), decorate all methods in the + class. + """ + def decorate_the_class(cls): + for name, m in inspect.getmembers(cls, inspect.isfunction): + if acl is None: + setattr(cls, name, decorator(m)) + else: + if acl(name): + setattr(cls, name, decorator(m)) + return cls + return decorate_the_class