Adds a __repr__ to graph.
[pyutils.git] / src / pyutils / parallelize / thread_utils.py
1 #!/usr/bin/env python3
2
3 # © Copyright 2021-2023, Scott Gasch
4
5 """Utilities for dealing with threads + threading."""
6
7 import functools
8 import logging
9 import os
10 import threading
11 from typing import Any, Callable, Optional, Tuple
12
13 from pyutils.typez.typing import Runnable
14
15 # This module is commonly used by others in here and should avoid
16 # taking any unnecessary dependencies back on them.
17
18 logger = logging.getLogger(__name__)
19
20
21 def current_thread_id() -> str:
22     """
23     Returns:
24         A string composed of the parent process' id, the
25         current process' id and the current thread name that can be used
26         as a unique identifier for the current thread.  The former two are
27         numbers (pids) whereas the latter is a thread id passed during
28         thread creation time.
29
30     >>> from pyutils.parallelize import thread_utils
31     >>> ret = thread_utils.current_thread_id()
32     >>> ret  # doctest: +SKIP
33     '76891/84444/MainThread:'
34     >>> (ppid, pid, tid) = ret.split('/')
35     >>> ppid.isnumeric()
36     True
37     >>> pid.isnumeric()
38     True
39     """
40     ppid = os.getppid()
41     pid = os.getpid()
42     tid = threading.current_thread().name
43     return f"{ppid}/{pid}/{tid}:"
44
45
46 def is_current_thread_main_thread() -> bool:
47     """
48     Returns:
49         True is the current (calling) thread is the process' main
50         thread and False otherwise.
51
52     >>> from pyutils.parallelize import thread_utils
53     >>> thread_utils.is_current_thread_main_thread()
54     True
55
56     >>> result = None
57     >>> def am_i_the_main_thread():
58     ...     global result
59     ...     result = thread_utils.is_current_thread_main_thread()
60
61     >>> am_i_the_main_thread()
62     >>> result
63     True
64
65     >>> import threading
66     >>> thread = threading.Thread(target=am_i_the_main_thread)
67     >>> thread.start()
68     >>> thread.join()
69     >>> result
70     False
71     """
72     return threading.current_thread() is threading.main_thread()
73
74
75 def background_thread(
76     _funct: Optional[Callable[..., Any]],
77 ) -> Callable[..., Tuple[threading.Thread, threading.Event]]:
78     """A function decorator to create a background thread.
79
80     Args:
81         _funct: The function being wrapped such that it is invoked
82             on a background thread.
83
84     Example usage::
85
86         import threading
87         import time
88
89         from pyutils.parallelize import thread_utils
90
91         @thread_utils.background_thread
92         def random(a: int, b: str, stop_event: threading.Event) -> None:
93             while True:
94                 print(f"Hi there {b}: {a}!")
95                 time.sleep(10.0)
96                 if stop_event.is_set():
97                     return
98
99         def main() -> None:
100             (thread, event) = random(22, "dude")
101             print("back!")
102             time.sleep(30.0)
103             event.set()
104             thread.join()
105
106     .. warning::
107
108         In addition to any other arguments the function has, it must
109         take a stop_event as the last unnamed argument which it should
110         periodically check.  If the event is set, it means the thread has
111         been requested to terminate ASAP.
112     """
113
114     def wrapper(funct: Callable):
115         @functools.wraps(funct)
116         def inner_wrapper(*a, **kwa) -> Tuple[threading.Thread, threading.Event]:
117             should_terminate = threading.Event()
118             should_terminate.clear()
119             newargs = (*a, should_terminate)
120             thread = threading.Thread(
121                 target=funct,
122                 args=newargs,
123                 kwargs=kwa,
124             )
125             thread.start()
126             logger.debug('Started thread "%s" tid=%d', thread.name, thread.ident)
127             return (thread, should_terminate)
128
129         return inner_wrapper
130
131     if _funct is None:
132         return wrapper  # type: ignore
133     else:
134         return wrapper(_funct)
135
136
137 class ThreadWithReturnValue(threading.Thread, Runnable):
138     """A thread whose return value is plumbed back out as the return
139     value of :meth:`join`.  Use like a normal thread::
140
141         import threading
142
143         from pyutils.parallelize import thread_utils
144
145         def thread_entry_point(args):
146             # do something interesting...
147             return result
148
149         if __name__ == "__main__":
150             thread = thread_utils.ThreadWithReturnValue(
151                 target=thread_entry_point,
152                 args=(whatever)
153             )
154             thread.start()
155             result = thread.join()
156             print(f"thread finished and returned {result}")
157
158     """
159
160     def __init__(
161         self, group=None, target=None, name=None, args=(), kwargs={}, *, daemon=None
162     ):
163         threading.Thread.__init__(
164             self,
165             group=None,
166             target=target,
167             name=None,
168             args=args,
169             kwargs=kwargs,
170             daemon=daemon,
171         )
172         self._target = target
173         self._return = None
174         self._args = args
175         self._kwargs = kwargs
176
177     def run(self) -> None:
178         """Create a little wrapper around invoking the real thread entry
179         point so we can pay attention to its return value."""
180         if self._target is not None:
181             self._return = self._target(*self._args, **self._kwargs)
182
183     def join(self, *args) -> Any:
184         """Wait until the thread terminates and return the value it terminated with
185         as the result of join.
186
187         Like normal :meth:`join`, this blocks the calling thread until
188         the thread whose :meth:`join` is called terminates – either
189         normally or through an unhandled exception or until the
190         optional timeout occurs.
191
192         When the timeout argument is present and not None, it should
193         be a floating point number specifying a timeout for the
194         operation in seconds (or fractions thereof).
195
196         When the timeout argument is not present or None, the
197         operation will block until the thread terminates.
198
199         A thread can be joined many times.
200
201         Raises:
202             RuntimeError: an attempt is made to join the current thread
203                 as that would cause a deadlock. It is also an error to join
204                 a thread before it has been started and attempts to do so
205                 raises the same exception.
206         """
207         threading.Thread.join(self, *args)
208         return self._return
209
210
211 def periodically_invoke(
212     period_sec: float,
213     stop_after: Optional[int],
214 ):
215     """
216     Periodically invoke the decorated function on a background thread.
217
218     Args:
219         period_sec: the delay period in seconds between invocations
220         stop_after: total number of invocations to make or, if None,
221             call forever
222
223     Returns:
224         a :class:`Thread` object and an :class:`Event` that, when
225         signaled, will stop the invocations.
226
227     .. note::
228         It is possible to be invoked one time after the :class:`Event`
229         is set.  This event can be used to stop infinite
230         invocation style or finite invocation style decorations.
231
232     Usage::
233
234         from pyutils.parallelize import thread_utils
235
236         @thread_utils.periodically_invoke(period_sec=1.0, stop_after=3)
237         def hello(name: str) -> None:
238             print(f"Hello, {name}")
239
240         @thread_utils.periodically_invoke(period_sec=0.5, stop_after=None)
241         def there(name: str, age: int) -> None:
242             print(f"   ...there {name}, {age}")
243
244     Usage as a decorator doesn't give you access to the returned stop event or
245     thread object.  To get those, wrap your periodic function manually::
246
247         from pyutils.parallelize import thread_utils
248
249         def periodic(m: str) -> None:
250             print(m)
251
252         f = thread_utils.periodically_invoke(period_sec=5.0, stop_after=None)(periodic)
253         thread, event = f("testing")
254         ...
255         event.set()
256         thread.join()
257
258     See also :mod:`pyutils.state_tracker`.
259     """
260
261     def decorator_repeat(func):
262         def helper_thread(should_terminate, *args, **kwargs) -> None:
263             if stop_after is None:
264                 while True:
265                     func(*args, **kwargs)
266                     res = should_terminate.wait(period_sec)
267                     if res:
268                         return
269             else:
270                 for _ in range(stop_after):
271                     func(*args, **kwargs)
272                     res = should_terminate.wait(period_sec)
273                     if res:
274                         return
275                 return
276
277         @functools.wraps(func)
278         def wrapper_repeat(*args, **kwargs):
279             should_terminate = threading.Event()
280             should_terminate.clear()
281             newargs = (should_terminate, *args)
282             thread = threading.Thread(target=helper_thread, args=newargs, kwargs=kwargs)
283             thread.start()
284             logger.debug('Started thread "%s" tid=%d', thread.name, thread.ident)
285             return (thread, should_terminate)
286
287         return wrapper_repeat
288
289     return decorator_repeat
290
291
292 if __name__ == "__main__":
293     import doctest
294
295     doctest.testmod()