More documentation improvements.
[pyutils.git] / src / pyutils / parallelize / thread_utils.py
1 #!/usr/bin/env python3
2
3 # © Copyright 2021-2022, Scott Gasch
4
5 """Utilities for dealing with threads + threading."""
6
7 import functools
8 import logging
9 import os
10 import threading
11 from typing import Any, Callable, Optional, Tuple
12
13 # This module is commonly used by others in here and should avoid
14 # taking any unnecessary dependencies back on them.
15
16 logger = logging.getLogger(__name__)
17
18
19 def current_thread_id() -> str:
20     """
21     Returns:
22         a string composed of the parent process' id, the current
23         process' id and the current thread identifier.  The former two are
24         numbers (pids) whereas the latter is a thread id passed during thread
25         creation time.
26
27     >>> ret = current_thread_id()
28     >>> (ppid, pid, tid) = ret.split('/')
29     >>> ppid.isnumeric()
30     True
31     >>> pid.isnumeric()
32     True
33
34     """
35     ppid = os.getppid()
36     pid = os.getpid()
37     tid = threading.current_thread().name
38     return f'{ppid}/{pid}/{tid}:'
39
40
41 def is_current_thread_main_thread() -> bool:
42     """
43     Returns:
44         True is the current (calling) thread is the process' main
45         thread and False otherwise.
46
47     >>> is_current_thread_main_thread()
48     True
49
50     >>> result = None
51     >>> def thunk():
52     ...     global result
53     ...     result = is_current_thread_main_thread()
54
55     >>> thunk()
56     >>> result
57     True
58
59     >>> import threading
60     >>> thread = threading.Thread(target=thunk)
61     >>> thread.start()
62     >>> thread.join()
63     >>> result
64     False
65
66     """
67     return threading.current_thread() is threading.main_thread()
68
69
70 def background_thread(
71     _funct: Optional[Callable[..., Any]],
72 ) -> Callable[..., Tuple[threading.Thread, threading.Event]]:
73     """A function decorator to create a background thread.
74
75     Args:
76         _funct: The function being wrapped such that it is invoked
77             on a background thread.
78
79     Example usage::
80
81         @background_thread
82         def random(a: int, b: str, stop_event: threading.Event) -> None:
83             while True:
84                 print(f"Hi there {b}: {a}!")
85                 time.sleep(10.0)
86                 if stop_event.is_set():
87                     return
88
89         def main() -> None:
90             (thread, event) = random(22, "dude")
91             print("back!")
92             time.sleep(30.0)
93             event.set()
94             thread.join()
95
96     .. warning::
97
98         In addition to any other arguments the function has, it must
99         take a stop_event as the last unnamed argument which it should
100         periodically check.  If the event is set, it means the thread has
101         been requested to terminate ASAP.
102     """
103
104     def wrapper(funct: Callable):
105         @functools.wraps(funct)
106         def inner_wrapper(*a, **kwa) -> Tuple[threading.Thread, threading.Event]:
107             should_terminate = threading.Event()
108             should_terminate.clear()
109             newargs = (*a, should_terminate)
110             thread = threading.Thread(
111                 target=funct,
112                 args=newargs,
113                 kwargs=kwa,
114             )
115             thread.start()
116             logger.debug('Started thread "%s" tid=%d', thread.name, thread.ident)
117             return (thread, should_terminate)
118
119         return inner_wrapper
120
121     if _funct is None:
122         return wrapper  # type: ignore
123     else:
124         return wrapper(_funct)
125
126
127 class ThreadWithReturnValue(threading.Thread):
128     """A thread whose return value is plumbed back out as the return
129     value of :meth:`join`.  Use like a normal thread::
130
131         def thread_entry_point(args):
132             # do something interesting...
133             return result
134
135         if __name__ == "__main__":
136             thread = ThreadWithReturnValue(
137                 target=thread_entry_point,
138                 args=(whatever)
139             )
140             thread.start()
141             result = thread.join()
142             print(f"thread finished and returned {result}")
143
144     """
145
146     def __init__(
147         self, group=None, target=None, name=None, args=(), kwargs={}, Verbose=None
148     ):
149         threading.Thread.__init__(
150             self, group=None, target=target, name=None, args=args, kwargs=kwargs
151         )
152         self._target = target
153         self._return = None
154
155     def run(self) -> None:
156         """Create a little wrapper around invoking the real thread entry
157         point so we can pay attention to its return value."""
158         if self._target is not None:
159             self._return = self._target(*self._args, **self._kwargs)
160
161     def join(self, *args) -> Any:
162         """Wait until the thread terminates and return the value it terminated with
163         as the result of join.
164
165         Like normal :meth:`join`, this blocks the calling thread until
166         the thread whose :meth:`join` is called terminates – either
167         normally or through an unhandled exception or until the
168         optional timeout occurs.
169
170         When the timeout argument is present and not None, it should
171         be a floating point number specifying a timeout for the
172         operation in seconds (or fractions thereof).
173
174         When the timeout argument is not present or None, the
175         operation will block until the thread terminates.
176
177         A thread can be joined many times.
178
179         :meth:`join` raises a RuntimeError if an attempt is made to join the
180         current thread as that would cause a deadlock. It is also an
181         error to join a thread before it has been started and
182         attempts to do so raises the same exception.
183         """
184         threading.Thread.join(self, *args)
185         return self._return
186
187
188 def periodically_invoke(
189     period_sec: float,
190     stop_after: Optional[int],
191 ):
192     """
193     Periodically invoke the decorated function on a background thread.
194
195     Args:
196         period_sec: the delay period in seconds between invocations
197         stop_after: total number of invocations to make or, if None,
198             call forever
199
200     Returns:
201         a :class:`Thread` object and an :class:`Event` that, when
202         signaled, will stop the invocations.
203
204     .. note::
205         It is possible to be invoked one time after the :class:`Event`
206         is set.  This event can be used to stop infinite
207         invocation style or finite invocation style decorations.
208
209     Usage::
210
211         @periodically_invoke(period_sec=1.0, stop_after=3)
212         def hello(name: str) -> None:
213             print(f"Hello, {name}")
214
215         @periodically_invoke(period_sec=0.5, stop_after=None)
216         def there(name: str, age: int) -> None:
217             print(f"   ...there {name}, {age}")
218
219     Usage as a decorator doesn't give you access to the returned stop event or
220     thread object.  To get those, wrap your periodic function manually::
221
222         import pyutils.parallelize.thread_utils as tu
223
224         def periodic(m: str) -> None:
225             print(m)
226
227         f = tu.periodically_invoke(period_sec=5.0, stop_after=None)(periodic)
228         thread, event = f("testing")
229         ...
230         event.set()
231         thread.join()
232
233     See also :mod:`pyutils.state_tracker`.
234
235     """
236
237     def decorator_repeat(func):
238         def helper_thread(should_terminate, *args, **kwargs) -> None:
239             if stop_after is None:
240                 while True:
241                     func(*args, **kwargs)
242                     res = should_terminate.wait(period_sec)
243                     if res:
244                         return
245             else:
246                 for _ in range(stop_after):
247                     func(*args, **kwargs)
248                     res = should_terminate.wait(period_sec)
249                     if res:
250                         return
251                 return
252
253         @functools.wraps(func)
254         def wrapper_repeat(*args, **kwargs):
255             should_terminate = threading.Event()
256             should_terminate.clear()
257             newargs = (should_terminate, *args)
258             thread = threading.Thread(target=helper_thread, args=newargs, kwargs=kwargs)
259             thread.start()
260             logger.debug('Started thread "%s" tid=%d', thread.name, thread.ident)
261             return (thread, should_terminate)
262
263         return wrapper_repeat
264
265     return decorator_repeat
266
267
268 if __name__ == '__main__':
269     import doctest
270
271     doctest.testmod()