Cut version 0.0.1b9
[pyutils.git] / tests / run_tests.py
index 06d4a9c839f71013674f907391c2b64eea610d67..464358401ff4966be1071f68bbd25aedf72fd23e 100755 (executable)
@@ -4,6 +4,8 @@
 A smart, fast test runner.  Used in a git pre-commit hook.
 """
 
+from __future__ import annotations
+
 import logging
 import os
 import re
@@ -12,46 +14,45 @@ import threading
 import time
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, Tuple
 
 from overrides import overrides
 
 from pyutils import ansi, bootstrap, config, dict_utils, exec_utils, text_utils
 from pyutils.files import file_utils
-from pyutils.parallelize import deferred_operand
 from pyutils.parallelize import parallelize as par
 from pyutils.parallelize import smart_future, thread_utils
 
 logger = logging.getLogger(__name__)
 args = config.add_commandline_args(
-    f'Run Tests Driver ({__file__})', f'Args related to {__file__}'
+    f"Run Tests Driver ({__file__})", f"Args related to {__file__}"
 )
-args.add_argument('--unittests', '-u', action='store_true', help='Run unittests.')
-args.add_argument('--doctests', '-d', action='store_true', help='Run doctests.')
+args.add_argument("--unittests", "-u", action="store_true", help="Run unittests.")
+args.add_argument("--doctests", "-d", action="store_true", help="Run doctests.")
 args.add_argument(
-    '--integration', '-i', action='store_true', help='Run integration tests.'
+    "--integration", "-i", action="store_true", help="Run integration tests."
 )
 args.add_argument(
-    '--all',
-    '-a',
-    action='store_true',
-    help='Run unittests, doctests and integration tests.  Equivalient to -u -d -i',
+    "--all",
+    "-a",
+    action="store_true",
+    help="Run unittests, doctests and integration tests.  Equivalient to -u -d -i",
 )
 args.add_argument(
-    '--coverage',
-    '-c',
-    action='store_true',
-    help='Run tests and capture code coverage data',
+    "--coverage",
+    "-c",
+    action="store_true",
+    help="Run tests and capture code coverage data",
 )
 
-HOME = os.environ['HOME']
+HOME = os.environ["HOME"]
 
 # These tests will be run twice in --coverage mode: once to get code
 # coverage and then again with not coverage enabeled.  This is because
 # they pay attention to code performance which is adversely affected
 # by coverage.
-PERF_SENSATIVE_TESTS = set(['string_utils_test.py'])
-TESTS_TO_SKIP = set(['zookeeper_test.py', 'zookeeper.py', 'run_tests.py'])
+PERF_SENSATIVE_TESTS = set(["string_utils_test.py"])
+TESTS_TO_SKIP = set(["zookeeper_test.py", "zookeeper.py", "run_tests.py"])
 
 ROOT = ".."
 
@@ -107,29 +108,81 @@ class TestResults:
 
     __radd__ = __add__
 
+    @staticmethod
+    def empty_test_results(suite_name: str) -> TestResults:
+        return TestResults(
+            name=suite_name,
+            tests_executed={},
+            tests_succeeded=[],
+            tests_failed=[],
+            tests_timed_out=[],
+        )
+
+    @staticmethod
+    def single_test_succeeded(name: str) -> TestResults:
+        return TestResults(name, {}, [name], [], [])
+
+    @staticmethod
+    def single_test_failed(name: str) -> TestResults:
+        return TestResults(
+            name,
+            {},
+            [],
+            [name],
+            [],
+        )
+
+    @staticmethod
+    def single_test_timed_out(name: str) -> TestResults:
+        return TestResults(
+            name,
+            {},
+            [],
+            [],
+            [name],
+        )
+
     def __repr__(self) -> str:
-        out = f'{self.name}: '
+        out = f"{self.name}: "
         out += f'{ansi.fg("green")}'
-        out += f'{len(self.tests_succeeded)}/{len(self.tests_executed)} passed'
-        out += f'{ansi.reset()}.\n'
+        out += f"{len(self.tests_succeeded)}/{len(self.tests_executed)} passed"
+        out += f"{ansi.reset()}.\n"
+        tests_with_known_status = len(self.tests_succeeded)
 
         if len(self.tests_failed) > 0:
             out += f'  ..{ansi.fg("red")}'
-            out += f'{len(self.tests_failed)} tests failed'
-            out += f'{ansi.reset()}:\n'
+            out += f"{len(self.tests_failed)} tests failed"
+            out += f"{ansi.reset()}:\n"
             for test in self.tests_failed:
-                out += f'    {test}\n'
-            out += '\n'
+                out += f"    {test}\n"
+            tests_with_known_status += len(self.tests_failed)
 
         if len(self.tests_timed_out) > 0:
-            out += f'  ..{ansi.fg("yellow")}'
-            out += f'{len(self.tests_timed_out)} tests timed out'
-            out += f'{ansi.reset()}:\n'
+            out += f'  ..{ansi.fg("lightning yellow")}'
+            out += f"{len(self.tests_timed_out)} tests timed out"
+            out += f"{ansi.reset()}:\n"
             for test in self.tests_failed:
-                out += f'    {test}\n'
-            out += '\n'
+                out += f"    {test}\n"
+            tests_with_known_status += len(self.tests_timed_out)
+
+        missing = len(self.tests_executed) - tests_with_known_status
+        if missing:
+            out += f'  ..{ansi.fg("lightning yellow")}'
+            out += f"{missing} tests aborted early"
+            out += f"{ansi.reset()}\n"
         return out
 
+    def _key(self) -> Tuple[str, Tuple, Tuple, Tuple]:
+        return (
+            self.name,
+            tuple(self.tests_succeeded),
+            tuple(self.tests_failed),
+            tuple(self.tests_timed_out),
+        )
+
+    def __hash__(self) -> int:
+        return hash(self._key())
+
 
 class TestRunner(ABC, thread_utils.ThreadWithReturnValue):
     """A Base class for something that runs a test."""
@@ -143,13 +196,7 @@ class TestRunner(ABC, thread_utils.ThreadWithReturnValue):
         """
         super().__init__(self, target=self.begin, args=[params])
         self.params = params
-        self.test_results = TestResults(
-            name=self.get_name(),
-            tests_executed={},
-            tests_succeeded=[],
-            tests_failed=[],
-            tests_timed_out=[],
-        )
+        self.test_results = TestResults.empty_test_results(self.get_name())
         self.lock = threading.Lock()
 
     @abstractmethod
@@ -171,6 +218,15 @@ class TestRunner(ABC, thread_utils.ThreadWithReturnValue):
 class TemplatedTestRunner(TestRunner, ABC):
     """A TestRunner that has a recipe for executing the tests."""
 
+    def __init__(self, params: TestingParameters):
+        super().__init__(params)
+
+        # Note: because of @parallelize on run_tests it actually
+        # returns a SmartFuture with a TestResult inside of it.
+        # That's the reason for this Any business.
+        self.running: List[Any] = []
+        self.already_cancelled = False
+
     @abstractmethod
     def identify_tests(self) -> List[TestToRun]:
         """Return a list of tuples (test, cmdline) that should be executed."""
@@ -182,24 +238,24 @@ class TemplatedTestRunner(TestRunner, ABC):
         pass
 
     def check_for_abort(self) -> bool:
-        """Periodically caled to check to see if we need to stop."""
+        """Periodically called to check to see if we need to stop."""
 
         if self.params.halt_event.is_set():
-            logger.debug('Thread %s saw halt event; exiting.', self.get_name())
+            logger.debug("Thread %s saw halt event; exiting.", self.get_name())
             return True
 
         if self.params.halt_on_error and len(self.test_results.tests_failed) > 0:
-            logger.error('Thread %s saw abnormal results; exiting.', self.get_name())
+            logger.debug("Thread %s saw abnormal results; exiting.", self.get_name())
             return True
         return False
 
     def persist_output(self, test: TestToRun, message: str, output: str) -> None:
         """Called to save the output of a test run."""
 
-        dest = f'{test.name}-output.txt'
-        with open(f'./test_output/{dest}', 'w') as wf:
+        dest = f"{test.name}-output.txt"
+        with open(f"./test_output/{dest}", "w") as wf:
             print(message, file=wf)
-            print('-' * len(message), file=wf)
+            print("-" * len(message), file=wf)
             wf.write(output)
 
     def execute_commandline(
@@ -210,99 +266,87 @@ class TemplatedTestRunner(TestRunner, ABC):
     ) -> TestResults:
         """Execute a particular commandline to run a test."""
 
+        msg = f"{self.get_name()}: {test.name} ({test.cmdline}) "
         try:
             output = exec_utils.cmd(
                 test.cmdline,
                 timeout_seconds=timeout,
             )
             if "***Test Failed***" in output:
-                msg = f'{self.get_name()}: {test.name} ({test.cmdline}) failed; doctest failure message detected'
+                msg += "failed; doctest failure message detected."
                 logger.error(msg)
                 self.persist_output(test, msg, output)
-                return TestResults(
-                    test.name,
-                    {},
-                    [],
-                    [test.name],
-                    [],
-                )
+                return TestResults.single_test_failed(test.name)
+
+            msg += "succeeded."
+            self.persist_output(test, msg, output)
+            logger.debug(msg)
+            return TestResults.single_test_succeeded(test.name)
 
-            self.persist_output(
-                test, f'{test.name} ({test.cmdline}) succeeded.', output
-            )
-            logger.debug(
-                '%s: %s (%s) succeeded', self.get_name(), test.name, test.cmdline
-            )
-            return TestResults(test.name, {}, [test.name], [], [])
         except subprocess.TimeoutExpired as e:
-            msg = f'{self.get_name()}: {test.name} ({test.cmdline}) timed out after {e.timeout:.1f} seconds.'
+            msg += f"timed out after {e.timeout:.1f} seconds."
             logger.error(msg)
             logger.debug(
-                '%s: %s output when it timed out: %s',
+                "%s: %s output when it timed out: %s",
                 self.get_name(),
                 test.name,
                 e.output,
             )
-            self.persist_output(test, msg, e.output.decode('utf-8'))
-            return TestResults(
-                test.name,
-                {},
-                [],
-                [],
-                [test.name],
-            )
+            self.persist_output(test, msg, e.output.decode("utf-8"))
+            return TestResults.single_test_timed_out(test.name)
+
         except subprocess.CalledProcessError as e:
-            msg = f'{self.get_name()}: {test.name} ({test.cmdline}) failed; exit code {e.returncode}'
+            msg += f"failed with exit code {e.returncode}."
             logger.error(msg)
             logger.debug(
-                '%s: %s output when it failed: %s', self.get_name(), test.name, e.output
+                "%s: %s output when it failed: %s", self.get_name(), test.name, e.output
             )
-            self.persist_output(test, msg, e.output.decode('utf-8'))
-            return TestResults(
-                test.name,
-                {},
-                [],
-                [test.name],
-                [],
+            self.persist_output(test, msg, e.output.decode("utf-8"))
+            return TestResults.single_test_failed(test.name)
+
+    def callback(self):
+        if not self.already_cancelled and self.check_for_abort():
+            logger.debug(
+                "%s: aborting %d running futures to exit early.",
+                self.get_name(),
+                len(self.running),
             )
+            for x in self.running:
+                x.wrapped_future.cancel()
 
     @overrides
     def begin(self, params: TestingParameters) -> TestResults:
-        logger.debug('Thread %s started.', self.get_name())
+        logger.debug("Thread %s started.", self.get_name())
         interesting_tests = self.identify_tests()
         logger.debug(
-            '%s: Identified %d tests to be run.',
+            "%s: Identified %d tests to be run.",
             self.get_name(),
             len(interesting_tests),
         )
 
-        # Note: because of @parallelize on run_tests it actually
-        # returns a SmartFuture with a TestResult inside of it.
-        # That's the reason for this Any business.
-        running: List[Any] = []
         for test_to_run in interesting_tests:
-            running.append(self.run_test(test_to_run))
+            self.running.append(self.run_test(test_to_run))
             logger.debug(
-                '%s: Test %s started in the background.',
+                "%s: Test %s started in the background.",
                 self.get_name(),
                 test_to_run.name,
             )
             self.test_results.tests_executed[test_to_run.name] = time.time()
 
-        for future in smart_future.wait_any(running, log_exceptions=False):
-            result = deferred_operand.DeferredOperand.resolve(future)
-            logger.debug('Test %s finished.', result.name)
+        already_seen = set()
+        for result in smart_future.wait_any(
+            self.running, timeout=1.0, callback=self.callback, log_exceptions=False
+        ):
+            if result and result not in already_seen:
+                logger.debug("Test %s finished.", result.name)
+                self.test_results += result
+                already_seen.add(result)
 
-            # We sometimes run the same test more than once.  Do not allow
-            # one run's results to klobber the other's.
-            self.test_results += result
             if self.check_for_abort():
-                logger.debug(
-                    '%s: check_for_abort told us to exit early.', self.get_name()
-                )
+                logger.error("%s: exiting early.", self.get_name())
                 return self.test_results
 
-        logger.debug('Thread %s finished running all tests', self.get_name())
+        logger.debug("%s: executed all tests and returning normally", self.get_name())
         return self.test_results
 
 
@@ -316,32 +360,32 @@ class UnittestTestRunner(TemplatedTestRunner):
     @overrides
     def identify_tests(self) -> List[TestToRun]:
         ret = []
-        for test in file_utils.get_matching_files_recursive(ROOT, '*_test.py'):
+        for test in file_utils.get_matching_files_recursive(ROOT, "*_test.py"):
             basename = file_utils.without_path(test)
             if basename in TESTS_TO_SKIP:
                 continue
-            if config.config['coverage']:
+            if config.config["coverage"]:
                 ret.append(
                     TestToRun(
                         name=basename,
-                        kind='unittest capturing coverage',
-                        cmdline=f'coverage run --source ../src {test} --unittests_ignore_perf 2>&1',
+                        kind="unittest capturing coverage",
+                        cmdline=f"coverage run --source ../src {test} --unittests_ignore_perf 2>&1",
                     )
                 )
                 if basename in PERF_SENSATIVE_TESTS:
                     ret.append(
                         TestToRun(
-                            name=f'{basename}_no_coverage',
-                            kind='unittest w/o coverage to record perf',
-                            cmdline=f'{test} 2>&1',
+                            name=f"{basename}_no_coverage",
+                            kind="unittest w/o coverage to record perf",
+                            cmdline=f"{test} 2>&1",
                         )
                     )
             else:
                 ret.append(
                     TestToRun(
                         name=basename,
-                        kind='unittest',
-                        cmdline=f'{test} 2>&1',
+                        kind="unittest",
+                        cmdline=f"{test} 2>&1",
                     )
                 )
         return ret
@@ -362,33 +406,33 @@ class DoctestTestRunner(TemplatedTestRunner):
     def identify_tests(self) -> List[TestToRun]:
         ret = []
         out = exec_utils.cmd(f'grep -lR "^ *import doctest" {ROOT}/*')
-        for test in out.split('\n'):
-            if re.match(r'.*\.py$', test):
+        for test in out.split("\n"):
+            if re.match(r".*\.py$", test):
                 basename = file_utils.without_path(test)
                 if basename in TESTS_TO_SKIP:
                     continue
-                if config.config['coverage']:
+                if config.config["coverage"]:
                     ret.append(
                         TestToRun(
                             name=basename,
-                            kind='doctest capturing coverage',
-                            cmdline=f'coverage run --source ../src {test} 2>&1',
+                            kind="doctest capturing coverage",
+                            cmdline=f"coverage run --source ../src {test} 2>&1",
                         )
                     )
                     if basename in PERF_SENSATIVE_TESTS:
                         ret.append(
                             TestToRun(
-                                name=f'{basename}_no_coverage',
-                                kind='doctest w/o coverage to record perf',
-                                cmdline=f'python3 {test} 2>&1',
+                                name=f"{basename}_no_coverage",
+                                kind="doctest w/o coverage to record perf",
+                                cmdline=f"python3 {test} 2>&1",
                             )
                         )
                 else:
                     ret.append(
                         TestToRun(
                             name=basename,
-                            kind='doctest',
-                            cmdline=f'python3 {test} 2>&1',
+                            kind="doctest",
+                            cmdline=f"python3 {test} 2>&1",
                         )
                     )
         return ret
@@ -408,30 +452,30 @@ class IntegrationTestRunner(TemplatedTestRunner):
     @overrides
     def identify_tests(self) -> List[TestToRun]:
         ret = []
-        for test in file_utils.get_matching_files_recursive(ROOT, '*_itest.py'):
+        for test in file_utils.get_matching_files_recursive(ROOT, "*_itest.py"):
             basename = file_utils.without_path(test)
             if basename in TESTS_TO_SKIP:
                 continue
-            if config.config['coverage']:
+            if config.config["coverage"]:
                 ret.append(
                     TestToRun(
                         name=basename,
-                        kind='integration test capturing coverage',
-                        cmdline=f'coverage run --source ../src {test} 2>&1',
+                        kind="integration test capturing coverage",
+                        cmdline=f"coverage run --source ../src {test} 2>&1",
                     )
                 )
                 if basename in PERF_SENSATIVE_TESTS:
                     ret.append(
                         TestToRun(
-                            name=f'{basename}_no_coverage',
-                            kind='integration test w/o coverage to capture perf',
-                            cmdline=f'{test} 2>&1',
+                            name=f"{basename}_no_coverage",
+                            kind="integration test w/o coverage to capture perf",
+                            cmdline=f"{test} 2>&1",
                         )
                     )
             else:
                 ret.append(
                     TestToRun(
-                        name=basename, kind='integration test', cmdline=f'{test} 2>&1'
+                        name=basename, kind="integration test", cmdline=f"{test} 2>&1"
                     )
                 )
         return ret
@@ -446,30 +490,32 @@ def test_results_report(results: Dict[str, Optional[TestResults]]) -> int:
     total_problems = 0
     for result in results.values():
         if result is None:
-            print('Unexpected unhandled exception in test runner!!!')
+            print("Unexpected unhandled exception in test runner!!!")
             total_problems += 1
         else:
-            print(result, end='')
+            print(result, end="")
             total_problems += len(result.tests_failed)
             total_problems += len(result.tests_timed_out)
 
     if total_problems > 0:
-        print('Reminder: look in ./test_output to view test output logs')
+        print(
+            f"{ansi.bold()}Test output / logging can be found under ./test_output{ansi.reset()}"
+        )
     return total_problems
 
 
 def code_coverage_report():
     """Give a final code coverage report."""
-    text_utils.header('Code Coverage')
-    exec_utils.cmd('coverage combine .coverage*')
+    text_utils.header("Code Coverage")
+    exec_utils.cmd("coverage combine .coverage*")
     out = exec_utils.cmd(
-        'coverage report --omit=config-3.*.py,*_test.py,*_itest.py --sort=-cover'
+        "coverage report --omit=config-3.*.py,*_test.py,*_itest.py --sort=-cover"
     )
     print(out)
     print(
-        """To recall this report w/o re-running the tests:
+        f"""To recall this report w/o re-running the tests:
 
-    $ coverage report --omit=config-3.*.py,*_test.py,*_itest.py --sort=-cover
+    $ {ansi.bold()}coverage report --omit=config-3.*.py,*_test.py,*_itest.py --sort=-cover{ansi.reset()}
 
 ...from the 'tests' directory.  Note that subsequent calls to
 run_tests.py with --coverage will klobber previous results.  See:
@@ -482,40 +528,38 @@ run_tests.py with --coverage will klobber previous results.  See:
 @bootstrap.initialize
 def main() -> Optional[int]:
     saw_flag = False
-    halt_event = threading.Event()
     threads: List[TestRunner] = []
 
+    halt_event = threading.Event()
     halt_event.clear()
     params = TestingParameters(
         halt_on_error=True,
         halt_event=halt_event,
     )
 
-    if config.config['coverage']:
+    if config.config["coverage"]:
         logger.debug('Clearing existing coverage data via "coverage erase".')
-        exec_utils.cmd('coverage erase')
-
-    if config.config['unittests'] or config.config['all']:
+        exec_utils.cmd("coverage erase")
+    if config.config["unittests"] or config.config["all"]:
         saw_flag = True
         threads.append(UnittestTestRunner(params))
-    if config.config['doctests'] or config.config['all']:
+    if config.config["doctests"] or config.config["all"]:
         saw_flag = True
         threads.append(DoctestTestRunner(params))
-    if config.config['integration'] or config.config['all']:
+    if config.config["integration"] or config.config["all"]:
         saw_flag = True
         threads.append(IntegrationTestRunner(params))
 
     if not saw_flag:
         config.print_usage()
-        print('ERROR: one of --unittests, --doctests or --integration is required.')
-        return 1
+        config.error("One of --unittests, --doctests or --integration is required.", 1)
 
     for thread in threads:
         thread.start()
 
-    results: Dict[str, Optional[TestResults]] = {}
     start_time = time.time()
     last_update = start_time
+    results: Dict[str, Optional[TestResults]] = {}
     still_running = {}
 
     while len(results) != len(threads):
@@ -538,6 +582,7 @@ def main() -> Optional[int]:
             }
             still_running[tid] = running_with_start_time
 
+            # Maybe print tests that are still running.
             now = time.time()
             if now - start_time > 5.0:
                 if now - last_update > 3.0:
@@ -546,38 +591,38 @@ def main() -> Optional[int]:
                     for _, running_dict in still_running.items():
                         for test_name, start_time in running_dict.items():
                             if now - start_time > 10.0:
-                                update.append(f'{test_name}@{now-start_time:.1f}s')
+                                update.append(f"{test_name}@{now-start_time:.1f}s")
                             else:
                                 update.append(test_name)
-                    print(f'\r{ansi.clear_line()}')
+                    print(f"\r{ansi.clear_line()}")
                     if len(update) < 4:
                         print(f'Still running: {",".join(update)}')
                     else:
-                        print(f'Still running: {len(update)} tests.')
+                        print(f"Still running: {len(update)} tests.")
 
+            # Maybe signal the other threads to stop too.
             if not thread.is_alive():
                 if tid not in results:
                     result = thread.join()
                     if result:
                         results[tid] = result
-                        if len(result.tests_failed) > 0:
+                        if (len(result.tests_failed) + len(result.tests_timed_out)) > 0:
                             logger.error(
-                                'Thread %s returned abnormal results; killing the others.',
-                                tid,
+                                "Thread %s returned abnormal results; killing the others.",
+                                thread.get_name(),
                             )
                             halt_event.set()
                     else:
                         logger.error(
-                            'Thread %s took an unhandled exception... bug in run_tests.py?!  Aborting.',
+                            "Thread %s took an unhandled exception... bug in run_tests.py?!  Aborting.",
                             tid,
                         )
                         halt_event.set()
                         results[tid] = None
 
-        if failed == 0:
-            color = ansi.fg('green')
-        else:
-            color = ansi.fg('red')
+        color = ansi.fg("green")
+        if failed > 0:
+            color = ansi.fg("red")
 
         if started > 0:
             percent_done = done / started * 100.0
@@ -590,24 +635,26 @@ def main() -> Optional[int]:
                     done,
                     started,
                     text=text_utils.BarGraphText.FRACTION,
-                    width=80,
+                    width=72,
                     fgcolor=color,
                 ),
-                end='\r',
+                end="",
                 flush=True,
             )
-        time.sleep(0.5)
+            print(f"  {color}{now - start_time:.1f}s{ansi.reset()}", end="\r")
+        time.sleep(0.1)
 
-    print(f'{ansi.clear_line()}Final Report:')
-    if config.config['coverage']:
+    print(f"{ansi.clear_line()}\n{ansi.underline()}Final Report:{ansi.reset()}")
+    if config.config["coverage"]:
         code_coverage_report()
+    print(f"Test suite runtime: {time.time() - start_time:.1f}s")
     total_problems = test_results_report(results)
     if total_problems > 0:
         logging.error(
-            'Exiting with non-zero return code %d due to problems.', total_problems
+            "Exiting with non-zero return code %d due to problems.", total_problems
         )
     return total_problems
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     main()