One bugfix and some cosmetics.
[python_utils.git] / tests / run_tests.py
index 9dc067782e7d667d25651d86f018cc89d5389499..5162e238f1d37181b5dc9ea3988e1a443b3231c4 100755 (executable)
@@ -12,7 +12,7 @@ import threading
 import time
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, Tuple
 
 from overrides import overrides
 
@@ -22,6 +22,7 @@ import config
 import exec_utils
 import file_utils
 import parallelize as par
+import smart_future
 import text_utils
 import thread_utils
 
@@ -115,12 +116,17 @@ class TestRunner(ABC, thread_utils.ThreadWithReturnValue):
             tests_failed=[],
             tests_timed_out=[],
         )
+        self.tests_started = 0
 
     @abstractmethod
     def get_name(self) -> str:
         """The name of this test collection."""
         pass
 
+    def get_status(self) -> Tuple[int, TestResults]:
+        """Ask the TestRunner for its status."""
+        return (self.tests_started, self.test_results)
+
     @abstractmethod
     def begin(self, params: TestingParameters) -> TestResults:
         """Start execution."""
@@ -151,19 +157,6 @@ class TemplatedTestRunner(TestRunner, ABC):
                 logger.error('Thread %s saw abnormal results; exiting.', self.get_name())
                 raise Exception("Kill myself!")
 
-    def status_report(self, running: List[Any], done: List[Any]):
-        """Periodically called to report current status."""
-
-        total = len(running) + len(done)
-        logging.info(
-            '%s: %d/%d in flight; %d/%d completed.',
-            self.get_name(),
-            len(running),
-            total,
-            len(done),
-            total,
-        )
-
     def persist_output(self, test_name: str, message: str, output: str) -> None:
         """Called to save the output of a test run."""
 
@@ -198,7 +191,7 @@ class TemplatedTestRunner(TestRunner, ABC):
             logger.debug(
                 '%s: %s output when it timed out: %s', self.get_name(), test_name, e.output
             )
-            self.persist_output(test_name, msg, e.output)
+            self.persist_output(test_name, msg, e.output.decode('utf-8'))
             return TestResults(
                 test_name,
                 [test_name],
@@ -210,7 +203,7 @@ class TemplatedTestRunner(TestRunner, ABC):
             msg = f'{self.get_name()}: {test_name} ({cmdline}) failed; exit code {e.returncode}'
             logger.error(msg)
             logger.debug('%s: %s output when it failed: %s', self.get_name(), test_name, e.output)
-            self.persist_output(test_name, msg, e.output)
+            self.persist_output(test_name, msg, e.output.decode('utf-8'))
             return TestResults(
                 test_name,
                 [test_name],
@@ -223,26 +216,17 @@ class TemplatedTestRunner(TestRunner, ABC):
     def begin(self, params: TestingParameters) -> TestResults:
         logger.debug('Thread %s started.', self.get_name())
         interesting_tests = self.identify_tests()
+
         running: List[Any] = []
-        done: List[Any] = []
         for test in interesting_tests:
             running.append(self.run_test(test))
+        self.tests_started = len(running)
 
-        while len(running) > 0:
-            self.status_report(running, done)
+        for future in smart_future.wait_any(running):
             self.check_for_abort()
-            newly_finished = []
-            for fut in running:
-                if fut.is_ready():
-                    newly_finished.append(fut)
-                    result = fut._resolve()
-                    logger.debug('Test %s finished.', result.name)
-                    self.test_results += result
-
-            for fut in newly_finished:
-                running.remove(fut)
-                done.append(fut)
-            time.sleep(1.0)
+            result = future._resolve()
+            logger.debug('Test %s finished.', result.name)
+            self.test_results += result
 
         logger.debug('Thread %s finished.', self.get_name())
         return self.test_results
@@ -253,7 +237,7 @@ class UnittestTestRunner(TemplatedTestRunner):
 
     @overrides
     def get_name(self) -> str:
-        return "UnittestTestRunner"
+        return "Unittests"
 
     @overrides
     def identify_tests(self) -> List[str]:
@@ -273,7 +257,7 @@ class DoctestTestRunner(TemplatedTestRunner):
 
     @overrides
     def get_name(self) -> str:
-        return "DoctestTestRunner"
+        return "Doctests"
 
     @overrides
     def identify_tests(self) -> List[str]:
@@ -299,7 +283,7 @@ class IntegrationTestRunner(TemplatedTestRunner):
 
     @overrides
     def get_name(self) -> str:
-        return "IntegrationTestRunner"
+        return "Integration Tests"
 
     @overrides
     def identify_tests(self) -> List[str]:
@@ -383,6 +367,10 @@ def main() -> Optional[int]:
 
     results: Dict[str, TestResults] = {}
     while len(results) != len(threads):
+        started = 0
+        done = 0
+        failed = 0
+
         for thread in threads:
             if not thread.is_alive():
                 tid = thread.name
@@ -395,7 +383,35 @@ def main() -> Optional[int]:
                                 'Thread %s returned abnormal results; killing the others.', tid
                             )
                             halt_event.set()
-        time.sleep(1.0)
+            else:
+                (s, tr) = thread.get_status()
+                started += s
+                failed += len(tr.tests_failed) + len(tr.tests_timed_out)
+                done += failed + len(tr.tests_succeeded)
+
+        if started > 0:
+            percent_done = done / started
+        else:
+            percent_done = 0.0
+
+        if failed == 0:
+            color = ansi.fg('green')
+        else:
+            color = ansi.fg('red')
+
+        if percent_done < 100.0:
+            print(
+                text_utils.bar_graph(
+                    percent_done,
+                    width=80,
+                    fgcolor=color,
+                ),
+                end='\r',
+                flush=True,
+            )
+        else:
+            print("Finished.\n")
+        time.sleep(0.5)
 
     if config.config['coverage']:
         code_coverage_report()