More type annotations.
[python_utils.git] / thread_utils.py
index bb15c034b9e4e02f273b98dbc77410072878e089..22161275605d76a1199df8f18d536fd04e2fe17b 100644 (file)
@@ -13,6 +13,19 @@ logger = logging.getLogger(__name__)
 
 
 def current_thread_id() -> str:
+    """Returns a string composed of the parent process' id, the current
+    process' id and the current thread identifier.  The former two are
+    numbers (pids) whereas the latter is a thread id passed during thread
+    creation time.
+
+    >>> ret = current_thread_id()
+    >>> (ppid, pid, tid) = ret.split('/')
+    >>> ppid.isnumeric()
+    True
+    >>> pid.isnumeric()
+    True
+
+    """
     ppid = os.getppid()
     pid = os.getpid()
     tid = threading.current_thread().name
@@ -20,12 +33,35 @@ def current_thread_id() -> str:
 
 
 def is_current_thread_main_thread() -> bool:
+    """Returns True is the current (calling) thread is the process' main
+    thread and False otherwise.
+
+    >>> is_current_thread_main_thread()
+    True
+
+    >>> result = None
+    >>> def thunk():
+    ...     global result
+    ...     result = is_current_thread_main_thread()
+
+    >>> thunk()
+    >>> result
+    True
+
+    >>> import threading
+    >>> thread = threading.Thread(target=thunk)
+    >>> thread.start()
+    >>> thread.join()
+    >>> result
+    False
+
+    """
     return threading.current_thread() is threading.main_thread()
 
 
 def background_thread(
-        _funct: Optional[Callable]
-) -> Tuple[threading.Thread, threading.Event]:
+    _funct: Optional[Callable],
+) -> Callable[..., Tuple[threading.Thread, threading.Event]]:
     """A function decorator to create a background thread.
 
     *** Please note: the decorated function must take an shutdown ***
@@ -55,11 +91,10 @@ def background_thread(
     periodically check.  If the event is set, it means the thread has
     been requested to terminate ASAP.
     """
+
     def wrapper(funct: Callable):
         @functools.wraps(funct)
-        def inner_wrapper(
-                *a, **kwa
-        ) -> Tuple[threading.Thread, threading.Event]:
+        def inner_wrapper(*a, **kwa) -> Tuple[threading.Thread, threading.Event]:
             should_terminate = threading.Event()
             should_terminate.clear()
             newargs = (*a, should_terminate)
@@ -69,21 +104,20 @@ def background_thread(
                 kwargs=kwa,
             )
             thread.start()
-            logger.debug(
-                f'Started thread {thread.name} tid={thread.ident}'
-            )
+            logger.debug(f'Started thread {thread.name} tid={thread.ident}')
             return (thread, should_terminate)
+
         return inner_wrapper
 
     if _funct is None:
-        return wrapper
+        return wrapper  # type: ignore
     else:
         return wrapper(_funct)
 
 
 def periodically_invoke(
-        period_sec: float,
-        stop_after: Optional[int],
+    period_sec: float,
+    stop_after: Optional[int],
 ):
     """
     Periodically invoke a decorated function.  Stop after N invocations
@@ -105,6 +139,7 @@ def periodically_invoke(
             print(f"Hello, {name}")
 
     """
+
     def decorator_repeat(func):
         def helper_thread(should_terminate, *args, **kwargs) -> None:
             if stop_after is None:
@@ -126,15 +161,17 @@ def periodically_invoke(
             should_terminate = threading.Event()
             should_terminate.clear()
             newargs = (should_terminate, *args)
-            thread = threading.Thread(
-                target=helper_thread,
-                args = newargs,
-                kwargs = kwargs
-            )
+            thread = threading.Thread(target=helper_thread, args=newargs, kwargs=kwargs)
             thread.start()
-            logger.debug(
-                f'Started thread {thread.name} tid={thread.ident}'
-            )
+            logger.debug(f'Started thread {thread.name} tid={thread.ident}')
             return (thread, should_terminate)
+
         return wrapper_repeat
+
     return decorator_repeat
+
+
+if __name__ == '__main__':
+    import doctest
+
+    doctest.testmod()