Simple base interfaces.
[pyutils.git] / src / pyutils / parallelize / thread_utils.py
1 #!/usr/bin/env python3
2
3 # © Copyright 2021-2023, Scott Gasch
4
5 """Utilities for dealing with threads + threading."""
6
7 import functools
8 import logging
9 import os
10 import threading
11 from typing import Any, Callable, Optional, Tuple
12
13 from pyutils.typez.typing import Runnable
14
15 # This module is commonly used by others in here and should avoid
16 # taking any unnecessary dependencies back on them.
17
18 logger = logging.getLogger(__name__)
19
20
21 def current_thread_id() -> str:
22     """
23     Returns:
24         A string composed of the parent process' id, the
25         current process' id and the current thread name that can be used
26         as a unique identifier for the current thread.  The former two are
27         numbers (pids) whereas the latter is a thread id passed during
28         thread creation time.
29
30     >>> from pyutils.parallelize import thread_utils
31     >>> ret = thread_utils.current_thread_id()
32     >>> ret  # doctest: +SKIP
33     '76891/84444/MainThread:'
34     >>> (ppid, pid, tid) = ret.split('/')
35     >>> ppid.isnumeric()
36     True
37     >>> pid.isnumeric()
38     True
39     """
40     ppid = os.getppid()
41     pid = os.getpid()
42     tid = threading.current_thread().name
43     return f"{ppid}/{pid}/{tid}:"
44
45
46 def is_current_thread_main_thread() -> bool:
47     """
48     Returns:
49         True is the current (calling) thread is the process' main
50         thread and False otherwise.
51
52     >>> from pyutils.parallelize import thread_utils
53     >>> thread_utils.is_current_thread_main_thread()
54     True
55
56     >>> result = None
57     >>> def am_i_the_main_thread():
58     ...     global result
59     ...     result = thread_utils.is_current_thread_main_thread()
60
61     >>> am_i_the_main_thread()
62     >>> result
63     True
64
65     >>> import threading
66     >>> thread = threading.Thread(target=am_i_the_main_thread)
67     >>> thread.start()
68     >>> thread.join()
69     >>> result
70     False
71     """
72     return threading.current_thread() is threading.main_thread()
73
74
75 def background_thread(
76     _funct: Optional[Callable[..., Any]],
77 ) -> Callable[..., Tuple[threading.Thread, threading.Event]]:
78     """A function decorator to create a background thread.
79
80     Args:
81         _funct: The function being wrapped such that it is invoked
82             on a background thread.
83
84     Example usage::
85
86         import threading
87         import time
88
89         from pyutils.parallelize import thread_utils
90
91         @thread_utils.background_thread
92         def random(a: int, b: str, stop_event: threading.Event) -> None:
93             while True:
94                 print(f"Hi there {b}: {a}!")
95                 time.sleep(10.0)
96                 if stop_event.is_set():
97                     return
98
99         def main() -> None:
100             (thread, event) = random(22, "dude")
101             print("back!")
102             time.sleep(30.0)
103             event.set()
104             thread.join()
105
106     .. warning::
107
108         In addition to any other arguments the function has, it must
109         take a stop_event as the last unnamed argument which it should
110         periodically check.  If the event is set, it means the thread has
111         been requested to terminate ASAP.
112     """
113
114     def wrapper(funct: Callable):
115         @functools.wraps(funct)
116         def inner_wrapper(*a, **kwa) -> Tuple[threading.Thread, threading.Event]:
117             should_terminate = threading.Event()
118             should_terminate.clear()
119             newargs = (*a, should_terminate)
120             thread = threading.Thread(
121                 target=funct,
122                 args=newargs,
123                 kwargs=kwa,
124             )
125             thread.start()
126             logger.debug('Started thread "%s" tid=%d', thread.name, thread.ident)
127             return (thread, should_terminate)
128
129         return inner_wrapper
130
131     if _funct is None:
132         return wrapper  # type: ignore
133     else:
134         return wrapper(_funct)
135
136
137 class ThreadWithReturnValue(threading.Thread, Runnable):
138     """A thread whose return value is plumbed back out as the return
139     value of :meth:`join`.  Use like a normal thread::
140
141         import threading
142
143         from pyutils.parallelize import thread_utils
144
145         def thread_entry_point(args):
146             # do something interesting...
147             return result
148
149         if __name__ == "__main__":
150             thread = thread_utils.ThreadWithReturnValue(
151                 target=thread_entry_point,
152                 args=(whatever)
153             )
154             thread.start()
155             result = thread.join()
156             print(f"thread finished and returned {result}")
157
158     """
159
160     def __init__(
161         self, group=None, target=None, name=None, args=(), kwargs={}, *, daemon=None
162     ):
163         threading.Thread.__init__(
164             self,
165             group=None,
166             target=target,
167             name=None,
168             args=args,
169             kwargs=kwargs,
170             daemon=daemon,
171         )
172         self._target = target
173         self._return = None
174         self._args = args
175         self._kwargs = kwargs
176
177     def run(self) -> None:
178         """Create a little wrapper around invoking the real thread entry
179         point so we can pay attention to its return value."""
180         if self._target is not None:
181             self._return = self._target(*self._args, **self._kwargs)
182
183     def join(self, *args) -> Any:
184         """Wait until the thread terminates and return the value it terminated with
185         as the result of join.
186
187         Like normal :meth:`join`, this blocks the calling thread until
188         the thread whose :meth:`join` is called terminates – either
189         normally or through an unhandled exception or until the
190         optional timeout occurs.
191
192         When the timeout argument is present and not None, it should
193         be a floating point number specifying a timeout for the
194         operation in seconds (or fractions thereof).
195
196         When the timeout argument is not present or None, the
197         operation will block until the thread terminates.
198
199         A thread can be joined many times.
200
201         :meth:`join` raises a RuntimeError if an attempt is made to join the
202         current thread as that would cause a deadlock. It is also an
203         error to join a thread before it has been started and
204         attempts to do so raises the same exception.
205         """
206         threading.Thread.join(self, *args)
207         return self._return
208
209
210 def periodically_invoke(
211     period_sec: float,
212     stop_after: Optional[int],
213 ):
214     """
215     Periodically invoke the decorated function on a background thread.
216
217     Args:
218         period_sec: the delay period in seconds between invocations
219         stop_after: total number of invocations to make or, if None,
220             call forever
221
222     Returns:
223         a :class:`Thread` object and an :class:`Event` that, when
224         signaled, will stop the invocations.
225
226     .. note::
227         It is possible to be invoked one time after the :class:`Event`
228         is set.  This event can be used to stop infinite
229         invocation style or finite invocation style decorations.
230
231     Usage::
232
233         from pyutils.parallelize import thread_utils
234
235         @thread_utils.periodically_invoke(period_sec=1.0, stop_after=3)
236         def hello(name: str) -> None:
237             print(f"Hello, {name}")
238
239         @thread_utils.periodically_invoke(period_sec=0.5, stop_after=None)
240         def there(name: str, age: int) -> None:
241             print(f"   ...there {name}, {age}")
242
243     Usage as a decorator doesn't give you access to the returned stop event or
244     thread object.  To get those, wrap your periodic function manually::
245
246         from pyutils.parallelize import thread_utils
247
248         def periodic(m: str) -> None:
249             print(m)
250
251         f = thread_utils.periodically_invoke(period_sec=5.0, stop_after=None)(periodic)
252         thread, event = f("testing")
253         ...
254         event.set()
255         thread.join()
256
257     See also :mod:`pyutils.state_tracker`.
258     """
259
260     def decorator_repeat(func):
261         def helper_thread(should_terminate, *args, **kwargs) -> None:
262             if stop_after is None:
263                 while True:
264                     func(*args, **kwargs)
265                     res = should_terminate.wait(period_sec)
266                     if res:
267                         return
268             else:
269                 for _ in range(stop_after):
270                     func(*args, **kwargs)
271                     res = should_terminate.wait(period_sec)
272                     if res:
273                         return
274                 return
275
276         @functools.wraps(func)
277         def wrapper_repeat(*args, **kwargs):
278             should_terminate = threading.Event()
279             should_terminate.clear()
280             newargs = (should_terminate, *args)
281             thread = threading.Thread(target=helper_thread, args=newargs, kwargs=kwargs)
282             thread.start()
283             logger.debug('Started thread "%s" tid=%d', thread.name, thread.ident)
284             return (thread, should_terminate)
285
286         return wrapper_repeat
287
288     return decorator_repeat
289
290
291 if __name__ == "__main__":
292     import doctest
293
294     doctest.testmod()