Add timeout_seconds to cmd_showing_output.
authorScott Gasch <[email protected]>
Sat, 4 Jun 2022 16:06:07 +0000 (09:06 -0700)
committerScott Gasch <[email protected]>
Sat, 4 Jun 2022 16:06:07 +0000 (09:06 -0700)
exec_utils.py
tests/exec_utils_test.py

index 51aaeb454ccfaeb2fe077303e54014bfec6cdfce..49484c61e40e4bcf9332e213c8afce2a00795e90 100644 (file)
@@ -18,6 +18,8 @@ logger = logging.getLogger(__file__)
 
 def cmd_showing_output(
     command: str,
+    *,
+    timeout_seconds: Optional[float] = None,
 ) -> int:
     """Kick off a child process.  Capture and emit all output that it
     produces on stdout and stderr in a character by character manner
@@ -27,15 +29,22 @@ def cmd_showing_output(
 
     Args:
         command: the command to execute
+        timeout_seconds: terminate the subprocess if it takes longer
+            than N seconds; None means to wait as long as it takes.
 
     Returns:
         the exit status of the subprocess once the subprocess has
-        exited
+        exited.  Raises TimeoutExpired after killing the subprocess
+        if the timeout expires.
 
     Side effects:
         prints all output of the child process (stdout or stderr)
     """
 
+    def timer_expired(p):
+        p.kill()
+        raise subprocess.TimeoutExpired(command, timeout_seconds)
+
     line_enders = set([b'\n', b'\r'])
     sel = selectors.DefaultSelector()
     with subprocess.Popen(
@@ -46,28 +55,38 @@ def cmd_showing_output(
         stderr=subprocess.PIPE,
         universal_newlines=False,
     ) as p:
-        sel.register(p.stdout, selectors.EVENT_READ)  # type: ignore
-        sel.register(p.stderr, selectors.EVENT_READ)  # type: ignore
-        done = False
-        while not done:
-            for key, _ in sel.select():
-                char = key.fileobj.read(1)  # type: ignore
-                if not char:
-                    sel.unregister(key.fileobj)
-                    if len(sel.get_map()) == 0:
-                        sys.stdout.flush()
-                        sys.stderr.flush()
-                        sel.close()
-                        done = True
-                if key.fileobj is p.stdout:
-                    os.write(sys.stdout.fileno(), char)
-                    if char in line_enders:
-                        sys.stdout.flush()
-                else:
-                    os.write(sys.stderr.fileno(), char)
-                    if char in line_enders:
-                        sys.stderr.flush()
-        p.wait()
+        timer = None
+        if timeout_seconds:
+            import threading
+
+            timer = threading.Timer(timeout_seconds, timer_expired(p))
+            timer.start()
+        try:
+            sel.register(p.stdout, selectors.EVENT_READ)  # type: ignore
+            sel.register(p.stderr, selectors.EVENT_READ)  # type: ignore
+            done = False
+            while not done:
+                for key, _ in sel.select():
+                    char = key.fileobj.read(1)  # type: ignore
+                    if not char:
+                        sel.unregister(key.fileobj)
+                        if len(sel.get_map()) == 0:
+                            sys.stdout.flush()
+                            sys.stderr.flush()
+                            sel.close()
+                            done = True
+                    if key.fileobj is p.stdout:
+                        os.write(sys.stdout.fileno(), char)
+                        if char in line_enders:
+                            sys.stdout.flush()
+                    else:
+                        os.write(sys.stderr.fileno(), char)
+                        if char in line_enders:
+                            sys.stderr.flush()
+            p.wait()
+        finally:
+            if timer:
+                timer.cancel()
         return p.returncode
 
 
index 4c003aa4f045d46d939ade972f47027a8b4ae4bd..11dda89350c9f23c99bb0dba0fdf706d9cb65154 100755 (executable)
@@ -4,6 +4,7 @@
 
 """exec_utils unittest."""
 
+import subprocess
 import unittest
 
 import exec_utils
@@ -18,6 +19,14 @@ class TestExecUtils(unittest.TestCase):
         self.assertEqual(0, ret)
         record().close()
 
+    def test_cmd_showing_output_with_timeout(self):
+        try:
+            exec_utils.cmd_showing_output('sleep 10', timeout_seconds=0.1)
+        except subprocess.TimeoutExpired:
+            pass
+        else:
+            self.fail('Expected a TimeoutException, didn\'t see one.')
+
     def test_cmd_showing_output_fails(self):
         with unittest_utils.RecordStdout() as record:
             ret = exec_utils.cmd_showing_output('/usr/bin/printf hello && false')