Improve documentation / doctests.
[pyutils.git] / src / pyutils / parallelize / thread_utils.py
1 #!/usr/bin/env python3
2
3 # © Copyright 2021-2022, Scott Gasch
4
5 """Utilities for dealing with threads + threading."""
6
7 import functools
8 import logging
9 import os
10 import threading
11 from typing import Any, Callable, Optional, Tuple
12
13 # This module is commonly used by others in here and should avoid
14 # taking any unnecessary dependencies back on them.
15
16 logger = logging.getLogger(__name__)
17
18
19 def current_thread_id() -> str:
20     """
21     Returns:
22         A string composed of the parent process' id, the
23         current process' id and the current thread name that can be used
24         as a unique identifier for the current thread.  The former two are
25         numbers (pids) whereas the latter is a thread id passed during
26         thread creation time.
27
28     >>> from pyutils.parallelize import thread_utils
29     >>> ret = thread_utils.current_thread_id()
30     >>> ret  # doctest: +SKIP
31     '76891/84444/MainThread:'
32     >>> (ppid, pid, tid) = ret.split('/')
33     >>> ppid.isnumeric()
34     True
35     >>> pid.isnumeric()
36     True
37     """
38     ppid = os.getppid()
39     pid = os.getpid()
40     tid = threading.current_thread().name
41     return f'{ppid}/{pid}/{tid}:'
42
43
44 def is_current_thread_main_thread() -> bool:
45     """
46     Returns:
47         True is the current (calling) thread is the process' main
48         thread and False otherwise.
49
50     >>> from pyutils.parallelize import thread_utils
51     >>> thread_utils.is_current_thread_main_thread()
52     True
53
54     >>> result = None
55     >>> def am_i_the_main_thread():
56     ...     global result
57     ...     result = thread_utils.is_current_thread_main_thread()
58
59     >>> am_i_the_main_thread()
60     >>> result
61     True
62
63     >>> import threading
64     >>> thread = threading.Thread(target=am_i_the_main_thread)
65     >>> thread.start()
66     >>> thread.join()
67     >>> result
68     False
69     """
70     return threading.current_thread() is threading.main_thread()
71
72
73 def background_thread(
74     _funct: Optional[Callable[..., Any]],
75 ) -> Callable[..., Tuple[threading.Thread, threading.Event]]:
76     """A function decorator to create a background thread.
77
78     Args:
79         _funct: The function being wrapped such that it is invoked
80             on a background thread.
81
82     Example usage::
83
84         import threading
85         import time
86
87         from pyutils.parallelize import thread_utils
88
89         @thread_utils.background_thread
90         def random(a: int, b: str, stop_event: threading.Event) -> None:
91             while True:
92                 print(f"Hi there {b}: {a}!")
93                 time.sleep(10.0)
94                 if stop_event.is_set():
95                     return
96
97         def main() -> None:
98             (thread, event) = random(22, "dude")
99             print("back!")
100             time.sleep(30.0)
101             event.set()
102             thread.join()
103
104     .. warning::
105
106         In addition to any other arguments the function has, it must
107         take a stop_event as the last unnamed argument which it should
108         periodically check.  If the event is set, it means the thread has
109         been requested to terminate ASAP.
110     """
111
112     def wrapper(funct: Callable):
113         @functools.wraps(funct)
114         def inner_wrapper(*a, **kwa) -> Tuple[threading.Thread, threading.Event]:
115             should_terminate = threading.Event()
116             should_terminate.clear()
117             newargs = (*a, should_terminate)
118             thread = threading.Thread(
119                 target=funct,
120                 args=newargs,
121                 kwargs=kwa,
122             )
123             thread.start()
124             logger.debug('Started thread "%s" tid=%d', thread.name, thread.ident)
125             return (thread, should_terminate)
126
127         return inner_wrapper
128
129     if _funct is None:
130         return wrapper  # type: ignore
131     else:
132         return wrapper(_funct)
133
134
135 class ThreadWithReturnValue(threading.Thread):
136     """A thread whose return value is plumbed back out as the return
137     value of :meth:`join`.  Use like a normal thread::
138
139         import threading
140
141         from pyutils.parallelize import thread_utils
142
143         def thread_entry_point(args):
144             # do something interesting...
145             return result
146
147         if __name__ == "__main__":
148             thread = thread_utils.ThreadWithReturnValue(
149                 target=thread_entry_point,
150                 args=(whatever)
151             )
152             thread.start()
153             result = thread.join()
154             print(f"thread finished and returned {result}")
155
156     """
157
158     def __init__(
159         self, group=None, target=None, name=None, args=(), kwargs={}, Verbose=None
160     ):
161         threading.Thread.__init__(
162             self, group=None, target=target, name=None, args=args, kwargs=kwargs
163         )
164         self._target = target
165         self._return = None
166
167     def run(self) -> None:
168         """Create a little wrapper around invoking the real thread entry
169         point so we can pay attention to its return value."""
170         if self._target is not None:
171             self._return = self._target(*self._args, **self._kwargs)
172
173     def join(self, *args) -> Any:
174         """Wait until the thread terminates and return the value it terminated with
175         as the result of join.
176
177         Like normal :meth:`join`, this blocks the calling thread until
178         the thread whose :meth:`join` is called terminates – either
179         normally or through an unhandled exception or until the
180         optional timeout occurs.
181
182         When the timeout argument is present and not None, it should
183         be a floating point number specifying a timeout for the
184         operation in seconds (or fractions thereof).
185
186         When the timeout argument is not present or None, the
187         operation will block until the thread terminates.
188
189         A thread can be joined many times.
190
191         :meth:`join` raises a RuntimeError if an attempt is made to join the
192         current thread as that would cause a deadlock. It is also an
193         error to join a thread before it has been started and
194         attempts to do so raises the same exception.
195         """
196         threading.Thread.join(self, *args)
197         return self._return
198
199
200 def periodically_invoke(
201     period_sec: float,
202     stop_after: Optional[int],
203 ):
204     """
205     Periodically invoke the decorated function on a background thread.
206
207     Args:
208         period_sec: the delay period in seconds between invocations
209         stop_after: total number of invocations to make or, if None,
210             call forever
211
212     Returns:
213         a :class:`Thread` object and an :class:`Event` that, when
214         signaled, will stop the invocations.
215
216     .. note::
217         It is possible to be invoked one time after the :class:`Event`
218         is set.  This event can be used to stop infinite
219         invocation style or finite invocation style decorations.
220
221     Usage::
222
223         from pyutils.parallelize import thread_utils
224
225         @thread_utils.periodically_invoke(period_sec=1.0, stop_after=3)
226         def hello(name: str) -> None:
227             print(f"Hello, {name}")
228
229         @thread_utils.periodically_invoke(period_sec=0.5, stop_after=None)
230         def there(name: str, age: int) -> None:
231             print(f"   ...there {name}, {age}")
232
233     Usage as a decorator doesn't give you access to the returned stop event or
234     thread object.  To get those, wrap your periodic function manually::
235
236         from pyutils.parallelize import thread_utils
237
238         def periodic(m: str) -> None:
239             print(m)
240
241         f = thread_utils.periodically_invoke(period_sec=5.0, stop_after=None)(periodic)
242         thread, event = f("testing")
243         ...
244         event.set()
245         thread.join()
246
247     See also :mod:`pyutils.state_tracker`.
248     """
249
250     def decorator_repeat(func):
251         def helper_thread(should_terminate, *args, **kwargs) -> None:
252             if stop_after is None:
253                 while True:
254                     func(*args, **kwargs)
255                     res = should_terminate.wait(period_sec)
256                     if res:
257                         return
258             else:
259                 for _ in range(stop_after):
260                     func(*args, **kwargs)
261                     res = should_terminate.wait(period_sec)
262                     if res:
263                         return
264                 return
265
266         @functools.wraps(func)
267         def wrapper_repeat(*args, **kwargs):
268             should_terminate = threading.Event()
269             should_terminate.clear()
270             newargs = (should_terminate, *args)
271             thread = threading.Thread(target=helper_thread, args=newargs, kwargs=kwargs)
272             thread.start()
273             logger.debug('Started thread "%s" tid=%d', thread.name, thread.ident)
274             return (thread, should_terminate)
275
276         return wrapper_repeat
277
278     return decorator_repeat
279
280
281 if __name__ == '__main__':
282     import doctest
283
284     doctest.testmod()