Since this thing is on the innerwebs I suppose it should have a
[python_utils.git] / thread_utils.py
index 0130cdc510547196d418d6699d1b46b84a6ddf7c..01755deafc1e7af9189026c2ae233c6146c7d494 100644 (file)
@@ -1,10 +1,14 @@
 #!/usr/bin/env python3
 
 #!/usr/bin/env python3
 
+# © Copyright 2021-2022, Scott Gasch
+
+"""Utilities for dealing with threads + threading."""
+
 import functools
 import logging
 import os
 import threading
 import functools
 import logging
 import os
 import threading
-from typing import Callable, Optional, Tuple
+from typing import Any, Callable, Optional, Tuple
 
 # This module is commonly used by others in here and should avoid
 # taking any unnecessary dependencies back on them.
 
 # This module is commonly used by others in here and should avoid
 # taking any unnecessary dependencies back on them.
@@ -13,6 +17,19 @@ logger = logging.getLogger(__name__)
 
 
 def current_thread_id() -> str:
 
 
 def current_thread_id() -> str:
+    """Returns a string composed of the parent process' id, the current
+    process' id and the current thread identifier.  The former two are
+    numbers (pids) whereas the latter is a thread id passed during thread
+    creation time.
+
+    >>> ret = current_thread_id()
+    >>> (ppid, pid, tid) = ret.split('/')
+    >>> ppid.isnumeric()
+    True
+    >>> pid.isnumeric()
+    True
+
+    """
     ppid = os.getppid()
     pid = os.getpid()
     tid = threading.current_thread().name
     ppid = os.getppid()
     pid = os.getpid()
     tid = threading.current_thread().name
@@ -22,13 +39,33 @@ def current_thread_id() -> str:
 def is_current_thread_main_thread() -> bool:
     """Returns True is the current (calling) thread is the process' main
     thread and False otherwise.
 def is_current_thread_main_thread() -> bool:
     """Returns True is the current (calling) thread is the process' main
     thread and False otherwise.
+
+    >>> is_current_thread_main_thread()
+    True
+
+    >>> result = None
+    >>> def thunk():
+    ...     global result
+    ...     result = is_current_thread_main_thread()
+
+    >>> thunk()
+    >>> result
+    True
+
+    >>> import threading
+    >>> thread = threading.Thread(target=thunk)
+    >>> thread.start()
+    >>> thread.join()
+    >>> result
+    False
+
     """
     return threading.current_thread() is threading.main_thread()
 
 
 def background_thread(
     """
     return threading.current_thread() is threading.main_thread()
 
 
 def background_thread(
-        _funct: Optional[Callable]
-) -> Tuple[threading.Thread, threading.Event]:
+    _funct: Optional[Callable[..., Any]],
+) -> Callable[..., Tuple[threading.Thread, threading.Event]]:
     """A function decorator to create a background thread.
 
     *** Please note: the decorated function must take an shutdown ***
     """A function decorator to create a background thread.
 
     *** Please note: the decorated function must take an shutdown ***
@@ -58,11 +95,10 @@ def background_thread(
     periodically check.  If the event is set, it means the thread has
     been requested to terminate ASAP.
     """
     periodically check.  If the event is set, it means the thread has
     been requested to terminate ASAP.
     """
+
     def wrapper(funct: Callable):
         @functools.wraps(funct)
     def wrapper(funct: Callable):
         @functools.wraps(funct)
-        def inner_wrapper(
-                *a, **kwa
-        ) -> Tuple[threading.Thread, threading.Event]:
+        def inner_wrapper(*a, **kwa) -> Tuple[threading.Thread, threading.Event]:
             should_terminate = threading.Event()
             should_terminate.clear()
             newargs = (*a, should_terminate)
             should_terminate = threading.Event()
             should_terminate.clear()
             newargs = (*a, should_terminate)
@@ -72,21 +108,20 @@ def background_thread(
                 kwargs=kwa,
             )
             thread.start()
                 kwargs=kwa,
             )
             thread.start()
-            logger.debug(
-                f'Started thread {thread.name} tid={thread.ident}'
-            )
+            logger.debug('Started thread "%s" tid=%d', thread.name, thread.ident)
             return (thread, should_terminate)
             return (thread, should_terminate)
+
         return inner_wrapper
 
     if _funct is None:
         return inner_wrapper
 
     if _funct is None:
-        return wrapper
+        return wrapper  # type: ignore
     else:
         return wrapper(_funct)
 
 
 def periodically_invoke(
     else:
         return wrapper(_funct)
 
 
 def periodically_invoke(
-        period_sec: float,
-        stop_after: Optional[int],
+    period_sec: float,
+    stop_after: Optional[int],
 ):
     """
     Periodically invoke a decorated function.  Stop after N invocations
 ):
     """
     Periodically invoke a decorated function.  Stop after N invocations
@@ -108,6 +143,7 @@ def periodically_invoke(
             print(f"Hello, {name}")
 
     """
             print(f"Hello, {name}")
 
     """
+
     def decorator_repeat(func):
         def helper_thread(should_terminate, *args, **kwargs) -> None:
             if stop_after is None:
     def decorator_repeat(func):
         def helper_thread(should_terminate, *args, **kwargs) -> None:
             if stop_after is None:
@@ -129,15 +165,17 @@ def periodically_invoke(
             should_terminate = threading.Event()
             should_terminate.clear()
             newargs = (should_terminate, *args)
             should_terminate = threading.Event()
             should_terminate.clear()
             newargs = (should_terminate, *args)
-            thread = threading.Thread(
-                target=helper_thread,
-                args = newargs,
-                kwargs = kwargs
-            )
+            thread = threading.Thread(target=helper_thread, args=newargs, kwargs=kwargs)
             thread.start()
             thread.start()
-            logger.debug(
-                f'Started thread {thread.name} tid={thread.ident}'
-            )
+            logger.debug('Started thread "%s" tid=%d', thread.name, thread.ident)
             return (thread, should_terminate)
             return (thread, should_terminate)
+
         return wrapper_repeat
         return wrapper_repeat
+
     return decorator_repeat
     return decorator_repeat
+
+
+if __name__ == '__main__':
+    import doctest
+
+    doctest.testmod()