Code cleanup for run_test.py
[python_utils.git] / tests / run_tests.py
index 7e7bad593da9f8b453626ef018393b9436548833..9dc067782e7d667d25651d86f018cc89d5389499 100755 (executable)
@@ -1,7 +1,7 @@
 #!/usr/bin/env python3
 
 """
-A smart, fast test runner.
+A smart, fast test runner.  Used in a git pre-commit hook.
 """
 
 import logging
@@ -16,6 +16,7 @@ from typing import Any, Dict, List, Optional
 
 from overrides import overrides
 
+import ansi
 import bootstrap
 import config
 import exec_utils
@@ -39,55 +40,109 @@ HOME = os.environ['HOME']
 @dataclass
 class TestingParameters:
     halt_on_error: bool
+    """Should we stop as soon as one error has occurred?"""
+
     halt_event: threading.Event
+    """An event that, when set, indicates to stop ASAP."""
 
 
 @dataclass
 class TestResults:
     name: str
+    """The name of this test / set of tests."""
+
     tests_executed: List[str]
+    """Tests that were executed."""
+
     tests_succeeded: List[str]
+    """Tests that succeeded."""
+
     tests_failed: List[str]
+    """Tests that failed."""
+
     tests_timed_out: List[str]
+    """Tests that timed out."""
+
+    def __add__(self, other):
+        self.tests_executed.extend(other.tests_executed)
+        self.tests_succeeded.extend(other.tests_succeeded)
+        self.tests_failed.extend(other.tests_failed)
+        self.tests_timed_out.extend(other.tests_timed_out)
+        return self
+
+    __radd__ = __add__
+
+    def __repr__(self) -> str:
+        out = f'{self.name}: '
+        out += f'{ansi.fg("green")}'
+        out += f'{len(self.tests_succeeded)}/{len(self.tests_executed)} passed'
+        out += f'{ansi.reset()}.\n'
+
+        if len(self.tests_failed) > 0:
+            out += f'  ..{ansi.fg("red")}'
+            out += f'{len(self.tests_failed)} tests failed'
+            out += f'{ansi.reset()}:\n'
+            for test in self.tests_failed:
+                out += f'    {test}\n'
+            out += '\n'
+
+        if len(self.tests_timed_out) > 0:
+            out += f'  ..{ansi.fg("yellow")}'
+            out += f'{len(self.tests_timed_out)} tests timed out'
+            out += f'{ansi.reset()}:\n'
+            for test in self.tests_timed_out:
+                out += f'    {test}\n'
+            out += '\n'
+        return out
 
 
 class TestRunner(ABC, thread_utils.ThreadWithReturnValue):
+    """A base class for something that runs a test."""
+
     def __init__(self, params: TestingParameters):
+        """Create a TestRunner.
+
+        Args:
+            params: Test running parameters.
+
+        """
         super().__init__(self, target=self.begin, args=[params])
         self.params = params
         self.test_results = TestResults(
-            name=f"All {self.get_name()} tests",
+            name=self.get_name(),
             tests_executed=[],
             tests_succeeded=[],
             tests_failed=[],
             tests_timed_out=[],
         )
 
-    def aggregate_test_results(self, result: TestResults):
-        self.test_results.tests_executed.extend(result.tests_executed)
-        self.test_results.tests_succeeded.extend(result.tests_succeeded)
-        self.test_results.tests_failed.extend(result.tests_failed)
-        self.test_results.tests_timed_out.extend(result.tests_timed_out)
-
     @abstractmethod
     def get_name(self) -> str:
+        """The name of this test collection."""
         pass
 
     @abstractmethod
     def begin(self, params: TestingParameters) -> TestResults:
+        """Start execution."""
         pass
 
 
 class TemplatedTestRunner(TestRunner, ABC):
+    """A TestRunner that has a recipe for executing the tests."""
+
     @abstractmethod
-    def identify_tests(self) -> List[Any]:
+    def identify_tests(self) -> List[str]:
+        """Return a list of tests that should be executed."""
         pass
 
     @abstractmethod
     def run_test(self, test: Any) -> TestResults:
+        """Run a single test and return its TestResults."""
         pass
 
     def check_for_abort(self):
+        """Periodically called to check to see if we need to stop."""
+
         if self.params.halt_event.is_set():
             logger.debug('Thread %s saw halt event; exiting.', self.get_name())
             raise Exception("Kill myself!")
@@ -97,6 +152,8 @@ class TemplatedTestRunner(TestRunner, ABC):
                 raise Exception("Kill myself!")
 
     def status_report(self, running: List[Any], done: List[Any]):
+        """Periodically called to report current status."""
+
         total = len(running) + len(done)
         logging.info(
             '%s: %d/%d in flight; %d/%d completed.',
@@ -108,6 +165,8 @@ class TemplatedTestRunner(TestRunner, ABC):
         )
 
     def persist_output(self, test_name: str, message: str, output: str) -> None:
+        """Called to save the output of a test run."""
+
         basename = file_utils.without_path(test_name)
         dest = f'{basename}-output.txt'
         with open(f'./test_output/{dest}', 'w') as wf:
@@ -122,6 +181,7 @@ class TemplatedTestRunner(TestRunner, ABC):
         *,
         timeout: float = 120.0,
     ) -> TestResults:
+        """Execute a particular commandline to run a test."""
 
         try:
             logger.debug('%s: Running %s (%s)', self.get_name(), test_name, cmdline)
@@ -177,24 +237,26 @@ class TemplatedTestRunner(TestRunner, ABC):
                     newly_finished.append(fut)
                     result = fut._resolve()
                     logger.debug('Test %s finished.', result.name)
-                    self.aggregate_test_results(result)
+                    self.test_results += result
 
             for fut in newly_finished:
                 running.remove(fut)
                 done.append(fut)
-            time.sleep(0.25)
+            time.sleep(1.0)
 
         logger.debug('Thread %s finished.', self.get_name())
         return self.test_results
 
 
 class UnittestTestRunner(TemplatedTestRunner):
+    """Run all known Unittests."""
+
     @overrides
     def get_name(self) -> str:
         return "UnittestTestRunner"
 
     @overrides
-    def identify_tests(self) -> List[Any]:
+    def identify_tests(self) -> List[str]:
         return list(file_utils.expand_globs('*_test.py'))
 
     @par.parallelize
@@ -207,12 +269,14 @@ class UnittestTestRunner(TemplatedTestRunner):
 
 
 class DoctestTestRunner(TemplatedTestRunner):
+    """Run all known Doctests."""
+
     @overrides
     def get_name(self) -> str:
         return "DoctestTestRunner"
 
     @overrides
-    def identify_tests(self) -> List[Any]:
+    def identify_tests(self) -> List[str]:
         ret = []
         out = exec_utils.cmd('grep -lR "^ *import doctest" /home/scott/lib/python_modules/*')
         for line in out.split('\n'):
@@ -231,12 +295,14 @@ class DoctestTestRunner(TemplatedTestRunner):
 
 
 class IntegrationTestRunner(TemplatedTestRunner):
+    """Run all known Integration tests."""
+
     @overrides
     def get_name(self) -> str:
         return "IntegrationTestRunner"
 
     @overrides
-    def identify_tests(self) -> List[Any]:
+    def identify_tests(self) -> List[str]:
         return list(file_utils.expand_globs('*_itest.py'))
 
     @par.parallelize
@@ -248,28 +314,23 @@ class IntegrationTestRunner(TemplatedTestRunner):
         return self.execute_commandline(test, cmdline)
 
 
-def test_results_report(results: Dict[str, TestResults]):
-    for type, result in results.items():
-        print(text_utils.header(f'{result.name}'))
-        print(f'  Ran {len(result.tests_executed)} tests.')
-        print(f'  ..{len(result.tests_succeeded)} tests succeeded.')
-        if len(result.tests_failed) > 0:
-            print(f'  ..{len(result.tests_failed)} tests failed:')
-            for test in result.tests_failed:
-                print(f'    {test}')
+def test_results_report(results: Dict[str, TestResults]) -> int:
+    """Give a final report about the tests that were run."""
+    total_problems = 0
+    for result in results.values():
+        print(result, end='')
+        total_problems += len(result.tests_failed)
+        total_problems += len(result.tests_timed_out)
 
-        if len(result.tests_timed_out) > 0:
-            print(f'  ..{len(result.tests_timed_out)} tests timed out:')
-            for test in result.tests_failed:
-                print(f'    {test}')
-
-        if len(result.tests_failed) + len(result.tests_timed_out):
-            print('Reminder: look in ./test_output to view test output logs')
+    if total_problems > 0:
+        print('Reminder: look in ./test_output to view test output logs')
+    return total_problems
 
 
 def code_coverage_report():
+    """Give a final code coverage report."""
     text_utils.header('Code Coverage')
-    out = exec_utils.cmd('coverage combine .coverage*')
+    exec_utils.cmd('coverage combine .coverage*')
     out = exec_utils.cmd('coverage report --omit=config-3.8.py,*_test.py,*_itest.py --sort=-cover')
     print(out)
     print(
@@ -282,7 +343,6 @@ To recall this report w/o re-running the tests:
 run_tests.py with --coverage will klobber previous results.  See:
 
     https://coverage.readthedocs.io/en/6.2/
-
 """
     )
 
@@ -337,10 +397,10 @@ def main() -> Optional[int]:
                             halt_event.set()
         time.sleep(1.0)
 
-    test_results_report(results)
     if config.config['coverage']:
         code_coverage_report()
-    return 0
+    total_problems = test_results_report(results)
+    return total_problems
 
 
 if __name__ == '__main__':