Adding test code trying to improve test coverage.
authorScott <[email protected]>
Thu, 27 Jan 2022 21:42:09 +0000 (13:42 -0800)
committerScott <[email protected]>
Thu, 27 Jan 2022 21:42:09 +0000 (13:42 -0800)
exec_utils.py
tests/exec_utils_test.py [new file with mode: 0755]
tests/thread_utils_test.py [new file with mode: 0755]
thread_utils.py
unittest_utils.py

index 282a325a461e289144b5a58b5a88ce4a90098c83..dcd30a2e937e271ffe75109c019b3b345fa5997d 100644 (file)
@@ -2,6 +2,7 @@
 
 import atexit
 import logging
+import os
 import selectors
 import shlex
 import subprocess
@@ -21,36 +22,40 @@ def cmd_showing_output(
 
     """
     line_enders = set([b'\n', b'\r'])
-    p = subprocess.Popen(
+    sel = selectors.DefaultSelector()
+    with subprocess.Popen(
         command,
         shell=True,
         bufsize=0,
         stdout=subprocess.PIPE,
         stderr=subprocess.PIPE,
         universal_newlines=False,
-    )
-    sel = selectors.DefaultSelector()
-    sel.register(p.stdout, selectors.EVENT_READ)
-    sel.register(p.stderr, selectors.EVENT_READ)
-    stream_ends = 0
-    while stream_ends < 2:
-        for key, _ in sel.select():
-            char = key.fileobj.read(1)
-            if not char:
-                stream_ends += 1
-                continue
-            if key.fileobj is p.stdout:
-                sys.stdout.buffer.write(char)
-                if char in line_enders:
-                    sys.stdout.flush()
-            else:
-                sys.stderr.buffer.write(char)
-                if char in line_enders:
-                    sys.stderr.flush()
-    p.wait()
-    sys.stdout.flush()
-    sys.stderr.flush()
-    return p.returncode
+    ) as p:
+        sel.register(p.stdout, selectors.EVENT_READ)
+        sel.register(p.stderr, selectors.EVENT_READ)
+        done = False
+        while not done:
+            for key, _ in sel.select():
+                char = key.fileobj.read(1)
+                if not char:
+                    sel.unregister(key.fileobj)
+                    if len(sel.get_map()) == 0:
+                        sys.stdout.flush()
+                        sys.stderr.flush()
+                        sel.close()
+                        done = True
+                if key.fileobj is p.stdout:
+                    # sys.stdout.buffer.write(char)
+                    os.write(sys.stdout.fileno(), char)
+                    if char in line_enders:
+                        sys.stdout.flush()
+                else:
+                    # sys.stderr.buffer.write(char)
+                    os.write(sys.stderr.fileno(), char)
+                    if char in line_enders:
+                        sys.stderr.flush()
+        p.wait()
+        return p.returncode
 
 
 def cmd_with_timeout(command: str, timeout_seconds: Optional[float]) -> int:
@@ -133,7 +138,7 @@ def cmd_in_background(command: str, *, silent: bool = False) -> subprocess.Popen
     def kill_subproc() -> None:
         try:
             if subproc.poll() is None:
-                logger.info("At exit handler: killing {}: {}".format(subproc, command))
+                logger.info(f'At exit handler: killing {subproc} ({command})')
                 subproc.terminate()
                 subproc.wait(timeout=10.0)
         except BaseException as be:
diff --git a/tests/exec_utils_test.py b/tests/exec_utils_test.py
new file mode 100755 (executable)
index 0000000..eb179da
--- /dev/null
@@ -0,0 +1,30 @@
+#!/usr/bin/env python3
+
+import unittest
+
+import exec_utils
+import unittest_utils
+
+
+class TestExecUtils(unittest.TestCase):
+    def test_cmd_showing_output(self):
+        with unittest_utils.RecordStdout() as record:
+            ret = exec_utils.cmd_showing_output('/usr/bin/printf hello')
+        self.assertEqual('hello', record().readline())
+        self.assertEqual(0, ret)
+        record().close()
+
+    def test_cmd_showing_output_fails(self):
+        with unittest_utils.RecordStdout() as record:
+            ret = exec_utils.cmd_showing_output('/usr/bin/printf hello && false')
+        self.assertEqual('hello', record().readline())
+        self.assertEqual(1, ret)
+        record().close()
+
+    def test_cmd_in_background(self):
+        p = exec_utils.cmd_in_background('sleep 100')
+        self.assertEqual(None, p.poll())
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/tests/thread_utils_test.py b/tests/thread_utils_test.py
new file mode 100755 (executable)
index 0000000..7fcdca8
--- /dev/null
@@ -0,0 +1,63 @@
+#!/usr/bin/env python3
+
+import threading
+import time
+import unittest
+
+import thread_utils
+import unittest_utils
+
+
+class TestThreadUtils(unittest.TestCase):
+    invocation_count = 0
+
+    @thread_utils.background_thread
+    def background_thread(self, a: int, b: str, stop_event: threading.Event) -> None:
+        while not stop_event.is_set():
+            self.assertEqual(123, a)
+            self.assertEqual('abc', b)
+            time.sleep(0.1)
+
+    def test_background_thread(self):
+        (thread, event) = self.background_thread(123, 'abc')
+        self.assertTrue(thread.is_alive())
+        time.sleep(1.0)
+        event.set()
+        thread.join()
+        self.assertFalse(thread.is_alive())
+
+    @thread_utils.periodically_invoke(period_sec=0.3, stop_after=3)
+    def periodic_invocation_target(self, a: int, b: str):
+        self.assertEqual(123, a)
+        self.assertEqual('abc', b)
+        TestThreadUtils.invocation_count += 1
+
+    def test_periodically_invoke_with_limit(self):
+        TestThreadUtils.invocation_count = 0
+        (thread, event) = self.periodic_invocation_target(123, 'abc')
+        self.assertTrue(thread.is_alive())
+        time.sleep(1.0)
+        self.assertEqual(3, TestThreadUtils.invocation_count)
+        self.assertFalse(thread.is_alive())
+
+    @thread_utils.periodically_invoke(period_sec=0.1, stop_after=None)
+    def forever_periodic_invocation_target(self, a: int, b: str):
+        self.assertEqual(123, a)
+        self.assertEqual('abc', b)
+        TestThreadUtils.invocation_count += 1
+
+    def test_periodically_invoke_runs_forever(self):
+        TestThreadUtils.invocation_count = 0
+        (thread, event) = self.forever_periodic_invocation_target(123, 'abc')
+        self.assertTrue(thread.is_alive())
+        time.sleep(1.0)
+        self.assertTrue(thread.is_alive())
+        time.sleep(1.0)
+        event.set()
+        thread.join()
+        self.assertFalse(thread.is_alive())
+        self.assertTrue(TestThreadUtils.invocation_count >= 19)
+
+
+if __name__ == '__main__':
+    unittest.main()
index d8c85f46fbcaed33864d591458c64a1cebeb162d..51078a4e57ebe9193a3eee4669a1cf33a55bb4e0 100644 (file)
@@ -13,6 +13,19 @@ logger = logging.getLogger(__name__)
 
 
 def current_thread_id() -> str:
+    """Returns a string composed of the parent process' id, the current
+    process' id and the current thread identifier.  The former two are
+    numbers (pids) whereas the latter is a thread id passed during thread
+    creation time.
+
+    >>> ret = current_thread_id()
+    >>> (ppid, pid, tid) = ret.split('/')
+    >>> ppid.isnumeric()
+    True
+    >>> pid.isnumeric()
+    True
+
+    """
     ppid = os.getppid()
     pid = os.getpid()
     tid = threading.current_thread().name
@@ -22,6 +35,26 @@ def current_thread_id() -> str:
 def is_current_thread_main_thread() -> bool:
     """Returns True is the current (calling) thread is the process' main
     thread and False otherwise.
+
+    >>> is_current_thread_main_thread()
+    True
+
+    >>> result = None
+    >>> def thunk():
+    ...     global result
+    ...     result = is_current_thread_main_thread()
+
+    >>> thunk()
+    >>> result
+    True
+
+    >>> import threading
+    >>> thread = threading.Thread(target=thunk)
+    >>> thread.start()
+    >>> thread.join()
+    >>> result
+    False
+
     """
     return threading.current_thread() is threading.main_thread()
 
@@ -136,3 +169,9 @@ def periodically_invoke(
         return wrapper_repeat
 
     return decorator_repeat
+
+
+if __name__ == '__main__':
+    import doctest
+
+    doctest.testmod()
index f4fed35f09fdf29970820bef8566652825327634..e84b4eb929cfb8ac37daf31811b675c9d9d7825e 100644 (file)
@@ -275,6 +275,7 @@ class RecordStdout(object):
     ...     print("This is a test!")
     >>> print({record().readline()})
     {'This is a test!\\n'}
+    >>> record().close()
     """
 
     def __init__(self) -> None:
@@ -301,6 +302,7 @@ class RecordStderr(object):
     ...     print("This is a test!", file=sys.stderr)
     >>> print({record().readline()})
     {'This is a test!\\n'}
+    >>> record().close()
     """
 
     def __init__(self) -> None: