From b3ef553f4f30614b97e23f2d4ad6d6576ec57adf Mon Sep 17 00:00:00 2001 From: Scott Date: Thu, 27 Jan 2022 13:42:09 -0800 Subject: [PATCH] Adding test code trying to improve test coverage. --- exec_utils.py | 55 ++++++++++++++++++--------------- tests/exec_utils_test.py | 30 ++++++++++++++++++ tests/thread_utils_test.py | 63 ++++++++++++++++++++++++++++++++++++++ thread_utils.py | 39 +++++++++++++++++++++++ unittest_utils.py | 2 ++ 5 files changed, 164 insertions(+), 25 deletions(-) create mode 100755 tests/exec_utils_test.py create mode 100755 tests/thread_utils_test.py diff --git a/exec_utils.py b/exec_utils.py index 282a325..dcd30a2 100644 --- a/exec_utils.py +++ b/exec_utils.py @@ -2,6 +2,7 @@ import atexit import logging +import os import selectors import shlex import subprocess @@ -21,36 +22,40 @@ def cmd_showing_output( """ line_enders = set([b'\n', b'\r']) - p = subprocess.Popen( + sel = selectors.DefaultSelector() + with subprocess.Popen( command, shell=True, bufsize=0, stdout=subprocess.PIPE, stderr=subprocess.PIPE, universal_newlines=False, - ) - sel = selectors.DefaultSelector() - sel.register(p.stdout, selectors.EVENT_READ) - sel.register(p.stderr, selectors.EVENT_READ) - stream_ends = 0 - while stream_ends < 2: - for key, _ in sel.select(): - char = key.fileobj.read(1) - if not char: - stream_ends += 1 - continue - if key.fileobj is p.stdout: - sys.stdout.buffer.write(char) - if char in line_enders: - sys.stdout.flush() - else: - sys.stderr.buffer.write(char) - if char in line_enders: - sys.stderr.flush() - p.wait() - sys.stdout.flush() - sys.stderr.flush() - return p.returncode + ) as p: + sel.register(p.stdout, selectors.EVENT_READ) + sel.register(p.stderr, selectors.EVENT_READ) + done = False + while not done: + for key, _ in sel.select(): + char = key.fileobj.read(1) + if not char: + sel.unregister(key.fileobj) + if len(sel.get_map()) == 0: + sys.stdout.flush() + sys.stderr.flush() + sel.close() + done = True + if key.fileobj is p.stdout: + # sys.stdout.buffer.write(char) + os.write(sys.stdout.fileno(), char) + if char in line_enders: + sys.stdout.flush() + else: + # sys.stderr.buffer.write(char) + os.write(sys.stderr.fileno(), char) + if char in line_enders: + sys.stderr.flush() + p.wait() + return p.returncode def cmd_with_timeout(command: str, timeout_seconds: Optional[float]) -> int: @@ -133,7 +138,7 @@ def cmd_in_background(command: str, *, silent: bool = False) -> subprocess.Popen def kill_subproc() -> None: try: if subproc.poll() is None: - logger.info("At exit handler: killing {}: {}".format(subproc, command)) + logger.info(f'At exit handler: killing {subproc} ({command})') subproc.terminate() subproc.wait(timeout=10.0) except BaseException as be: diff --git a/tests/exec_utils_test.py b/tests/exec_utils_test.py new file mode 100755 index 0000000..eb179da --- /dev/null +++ b/tests/exec_utils_test.py @@ -0,0 +1,30 @@ +#!/usr/bin/env python3 + +import unittest + +import exec_utils +import unittest_utils + + +class TestExecUtils(unittest.TestCase): + def test_cmd_showing_output(self): + with unittest_utils.RecordStdout() as record: + ret = exec_utils.cmd_showing_output('/usr/bin/printf hello') + self.assertEqual('hello', record().readline()) + self.assertEqual(0, ret) + record().close() + + def test_cmd_showing_output_fails(self): + with unittest_utils.RecordStdout() as record: + ret = exec_utils.cmd_showing_output('/usr/bin/printf hello && false') + self.assertEqual('hello', record().readline()) + self.assertEqual(1, ret) + record().close() + + def test_cmd_in_background(self): + p = exec_utils.cmd_in_background('sleep 100') + self.assertEqual(None, p.poll()) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/thread_utils_test.py b/tests/thread_utils_test.py new file mode 100755 index 0000000..7fcdca8 --- /dev/null +++ b/tests/thread_utils_test.py @@ -0,0 +1,63 @@ +#!/usr/bin/env python3 + +import threading +import time +import unittest + +import thread_utils +import unittest_utils + + +class TestThreadUtils(unittest.TestCase): + invocation_count = 0 + + @thread_utils.background_thread + def background_thread(self, a: int, b: str, stop_event: threading.Event) -> None: + while not stop_event.is_set(): + self.assertEqual(123, a) + self.assertEqual('abc', b) + time.sleep(0.1) + + def test_background_thread(self): + (thread, event) = self.background_thread(123, 'abc') + self.assertTrue(thread.is_alive()) + time.sleep(1.0) + event.set() + thread.join() + self.assertFalse(thread.is_alive()) + + @thread_utils.periodically_invoke(period_sec=0.3, stop_after=3) + def periodic_invocation_target(self, a: int, b: str): + self.assertEqual(123, a) + self.assertEqual('abc', b) + TestThreadUtils.invocation_count += 1 + + def test_periodically_invoke_with_limit(self): + TestThreadUtils.invocation_count = 0 + (thread, event) = self.periodic_invocation_target(123, 'abc') + self.assertTrue(thread.is_alive()) + time.sleep(1.0) + self.assertEqual(3, TestThreadUtils.invocation_count) + self.assertFalse(thread.is_alive()) + + @thread_utils.periodically_invoke(period_sec=0.1, stop_after=None) + def forever_periodic_invocation_target(self, a: int, b: str): + self.assertEqual(123, a) + self.assertEqual('abc', b) + TestThreadUtils.invocation_count += 1 + + def test_periodically_invoke_runs_forever(self): + TestThreadUtils.invocation_count = 0 + (thread, event) = self.forever_periodic_invocation_target(123, 'abc') + self.assertTrue(thread.is_alive()) + time.sleep(1.0) + self.assertTrue(thread.is_alive()) + time.sleep(1.0) + event.set() + thread.join() + self.assertFalse(thread.is_alive()) + self.assertTrue(TestThreadUtils.invocation_count >= 19) + + +if __name__ == '__main__': + unittest.main() diff --git a/thread_utils.py b/thread_utils.py index d8c85f4..51078a4 100644 --- a/thread_utils.py +++ b/thread_utils.py @@ -13,6 +13,19 @@ logger = logging.getLogger(__name__) def current_thread_id() -> str: + """Returns a string composed of the parent process' id, the current + process' id and the current thread identifier. The former two are + numbers (pids) whereas the latter is a thread id passed during thread + creation time. + + >>> ret = current_thread_id() + >>> (ppid, pid, tid) = ret.split('/') + >>> ppid.isnumeric() + True + >>> pid.isnumeric() + True + + """ ppid = os.getppid() pid = os.getpid() tid = threading.current_thread().name @@ -22,6 +35,26 @@ def current_thread_id() -> str: def is_current_thread_main_thread() -> bool: """Returns True is the current (calling) thread is the process' main thread and False otherwise. + + >>> is_current_thread_main_thread() + True + + >>> result = None + >>> def thunk(): + ... global result + ... result = is_current_thread_main_thread() + + >>> thunk() + >>> result + True + + >>> import threading + >>> thread = threading.Thread(target=thunk) + >>> thread.start() + >>> thread.join() + >>> result + False + """ return threading.current_thread() is threading.main_thread() @@ -136,3 +169,9 @@ def periodically_invoke( return wrapper_repeat return decorator_repeat + + +if __name__ == '__main__': + import doctest + + doctest.testmod() diff --git a/unittest_utils.py b/unittest_utils.py index f4fed35..e84b4eb 100644 --- a/unittest_utils.py +++ b/unittest_utils.py @@ -275,6 +275,7 @@ class RecordStdout(object): ... print("This is a test!") >>> print({record().readline()}) {'This is a test!\\n'} + >>> record().close() """ def __init__(self) -> None: @@ -301,6 +302,7 @@ class RecordStderr(object): ... print("This is a test!", file=sys.stderr) >>> print({record().readline()}) {'This is a test!\\n'} + >>> record().close() """ def __init__(self) -> None: -- 2.45.2