4 A smart, fast test runner. Used in a git pre-commit hook.
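A hedged sketch (not this repo's actual hook; the hook path and the cd step
are assumptions) of how it might be invoked from a pre-commit hook, using the
--all flag defined below:

    #!/bin/sh
    # .git/hooks/pre-commit -- refuse the commit if any test fails.
    cd tests && python3 run_tests.py --all || exit 1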
7 from __future__ import annotations
15 from abc import ABC, abstractmethod
16 from dataclasses import dataclass
17 from typing import Any, Dict, List, Optional, Tuple
19 from overrides import overrides
21 from pyutils import ansi, bootstrap, config, dict_utils, exec_utils, text_utils
22 from pyutils.files import file_utils
23 from pyutils.parallelize import parallelize as par
24 from pyutils.parallelize import smart_future, thread_utils
26 logger = logging.getLogger(__name__)
27 args = config.add_commandline_args(
28 f"Run Tests Driver ({__file__})", f"Args related to {__file__}"
30 args.add_argument("--unittests", "-u", action="store_true", help="Run unittests.")
31 args.add_argument("--doctests", "-d", action="store_true", help="Run doctests.")
33 "--integration", "-i", action="store_true", help="Run integration tests."
39 help="Run unittests, doctests and integration tests. Equivalent to -u -d -i",
45 help="Run tests and capture code coverage data",
48 HOME = os.environ["HOME"]
50 # These tests will be run twice in --coverage mode: once to get code
51 # coverage and then again with coverage not enabled. This is because
52 # they pay attention to code performance, which is adversely affected
54 PERF_SENSITIVE_TESTS = set(["string_utils_test.py"])
55 TESTS_TO_SKIP = set(["zookeeper_test.py", "zookeeper.py", "run_tests.py"])
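# For illustration (a sketch derived from identify_tests() below; the path is
# abbreviated): in --coverage mode a perf sensitive test like
# string_utils_test.py is scheduled twice, roughly as:
#
#   coverage run --source ../src .../string_utils_test.py --unittests_ignore_perf 2>&1
#   .../string_utils_test.py 2>&1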
61 class TestingParameters:
63 """Should we stop as soon as one error has occurred?"""
65 halt_event: threading.Event
66 """An event that, when set, indicates to stop ASAP."""
72 """The name of the test"""
75 """The kind of the test"""
78 """The command line to execute"""
84 """The name of this test / set of tests."""
86 tests_executed: Dict[str, float]
87 """Tests that were executed."""
89 tests_succeeded: List[str]
90 """Tests that succeeded."""
92 tests_failed: List[str]
93 """Tests that failed."""
95 tests_timed_out: List[str]
96 """Tests that timed out."""
98 def __add__(self, other):
99 merged = dict_utils.coalesce(
100 [self.tests_executed, other.tests_executed],
101 aggregation_function=dict_utils.raise_on_duplicated_keys,
103 self.tests_executed = merged
104 self.tests_succeeded.extend(other.tests_succeeded)
105 self.tests_failed.extend(other.tests_failed)
106 self.tests_timed_out.extend(other.tests_timed_out)
112 def empty_test_results(suite_name: str) -> TestResults:
122 def single_test_succeeded(name: str) -> TestResults:
123 return TestResults(name, {}, [name], [], [])
126 def single_test_failed(name: str) -> TestResults:
136 def single_test_timed_out(name: str) -> TestResults:
145 def __repr__(self) -> str:
146 out = f"{self.name}: "
147 out += f'{ansi.fg("green")}'
148 out += f"{len(self.tests_succeeded)}/{len(self.tests_executed)} passed"
149 out += f"{ansi.reset()}.\n"
150 tests_with_known_status = len(self.tests_succeeded)
152 if len(self.tests_failed) > 0:
153 out += f' ..{ansi.fg("red")}'
154 out += f"{len(self.tests_failed)} tests failed"
155 out += f"{ansi.reset()}:\n"
156 for test in self.tests_failed:
158 tests_with_known_status += len(self.tests_failed)
160 if len(self.tests_timed_out) > 0:
161 out += f' ..{ansi.fg("lightning yellow")}'
162 out += f"{len(self.tests_timed_out)} tests timed out"
163 out += f"{ansi.reset()}:\n"
164 for test in self.tests_timed_out:
166 tests_with_known_status += len(self.tests_timed_out)
168 missing = len(self.tests_executed) - tests_with_known_status
170 out += f' ..{ansi.fg("lightning yellow")}'
171 out += f"{missing} tests aborted early"
172 out += f"{ansi.reset()}\n"
175 def _key(self) -> Tuple[str, Tuple, Tuple, Tuple]:
178 tuple(self.tests_succeeded),
179 tuple(self.tests_failed),
180 tuple(self.tests_timed_out),
183 def __hash__(self) -> int:
184 return hash(self._key())
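# For illustration only (a hypothetical sketch, not code used by the runner):
# adding TestResults merges the started-test timestamp maps (duplicate names
# raise via raise_on_duplicated_keys) and concatenates the outcome lists,
# which is what allows "self.test_results += result" below.
#
#   a = TestResults("suite", {"x_test.py": 1000.0}, ["x_test.py"], [], [])
#   b = TestResults("suite", {"y_test.py": 1001.0}, [], ["y_test.py"], [])
#   a += b   # a now covers both tests; y_test.py is recorded as failed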
187 class TestRunner(ABC, thread_utils.ThreadWithReturnValue):
188 """A Base class for something that runs a test."""
190 def __init__(self, params: TestingParameters):
191 """Create a TestRunner.
194 params: Test running parameters.
197 super().__init__(target=self.begin, args=[params])
199 self.test_results = TestResults.empty_test_results(self.get_name())
200 self.lock = threading.Lock()
203 def get_name(self) -> str:
204 """The name of this test collection."""
207 def get_status(self) -> TestResults:
208 """Ask the TestRunner for its status."""
210 return self.test_results
213 def begin(self, params: TestingParameters) -> TestResults:
214 """Start execution."""
218 class TemplatedTestRunner(TestRunner, ABC):
219 """A TestRunner that has a recipe for executing the tests."""
221 def __init__(self, params: TestingParameters):
222 super().__init__(params)
224 # Note: because of @parallelize on run_tests it actually
225 # returns a SmartFuture with a TestResults inside of it.
226 # That's the reason for this Any business.
227 self.running: List[Any] = []
228 self.already_cancelled = False
231 def identify_tests(self) -> List[TestToRun]:
232 """Return a list of tuples (test, cmdline) that should be executed."""
236 def run_test(self, test: TestToRun) -> TestResults:
237 """Run a single test and return its TestResults."""
240 def check_for_abort(self) -> bool:
241 """Periodically called to check to see if we need to stop."""
243 if self.params.halt_event.is_set():
244 logger.debug("Thread %s saw halt event; exiting.", self.get_name())
247 if self.params.halt_on_error and len(self.test_results.tests_failed) > 0:
248 logger.debug("Thread %s saw abnormal results; exiting.", self.get_name())
252 def persist_output(self, test: TestToRun, message: str, output: str) -> None:
253 """Called to save the output of a test run."""
255 dest = f"{test.name}-output.txt"
256 with open(f"./test_output/{dest}", "w") as wf:
257 print(message, file=wf)
258 print("-" * len(message), file=wf)
261 def execute_commandline(
265 timeout: float = 120.0,
267 """Execute a particular commandline to run a test."""
269 msg = f"{self.get_name()}: {test.name} ({test.cmdline}) "
271 output = exec_utils.cmd(
273 timeout_seconds=timeout,
275 if "***Test Failed***" in output:
276 msg += "failed; doctest failure message detected."
278 self.persist_output(test, msg, output)
279 return TestResults.single_test_failed(test.name)
282 self.persist_output(test, msg, output)
284 return TestResults.single_test_succeeded(test.name)
286 except subprocess.TimeoutExpired as e:
287 msg += f"timed out after {e.timeout:.1f} seconds."
290 "%s: %s output when it timed out: %s",
295 self.persist_output(test, msg, e.output.decode("utf-8"))
296 return TestResults.single_test_timed_out(test.name)
298 except subprocess.CalledProcessError as e:
299 msg += f"failed with exit code {e.returncode}."
302 "%s: %s output when it failed: %s", self.get_name(), test.name, e.output
304 self.persist_output(test, msg, e.output.decode("utf-8"))
305 return TestResults.single_test_failed(test.name)
308 if not self.already_cancelled and self.check_for_abort():
310 "%s: aborting %d running futures to exit early.",
314 for x in self.running:
315 x.wrapped_future.cancel()
318 def begin(self, params: TestingParameters) -> TestResults:
319 logger.debug("Thread %s started.", self.get_name())
320 interesting_tests = self.identify_tests()
322 "%s: Identified %d tests to be run.",
324 len(interesting_tests),
327 for test_to_run in interesting_tests:
328 self.running.append(self.run_test(test_to_run))
330 "%s: Test %s started in the background.",
334 self.test_results.tests_executed[test_to_run.name] = time.time()
337 for result in smart_future.wait_any(
338 self.running, timeout=1.0, callback=self.callback, log_exceptions=False
340 if result and result not in already_seen:
341 logger.debug("Test %s finished.", result.name)
342 self.test_results += result
343 already_seen.add(result)
345 if self.check_for_abort():
346 logger.error("%s: exiting early.", self.get_name())
347 return self.test_results
349 logger.debug("%s: executed all tests and returning normally", self.get_name())
350 return self.test_results
353 class UnittestTestRunner(TemplatedTestRunner):
354 """Run all known Unittests."""
357 def get_name(self) -> str:
361 def identify_tests(self) -> List[TestToRun]:
363 for test in file_utils.get_matching_files_recursive(ROOT, "*_test.py"):
364 basename = file_utils.without_path(test)
365 if basename in TESTS_TO_SKIP:
367 if config.config["coverage"]:
371 kind="unittest capturing coverage",
372 cmdline=f"coverage run --source ../src {test} --unittests_ignore_perf 2>&1",
375 if basename in PERF_SENSITIVE_TESTS:
378 name=f"{basename}_no_coverage",
379 kind="unittest w/o coverage to record perf",
380 cmdline=f"{test} 2>&1",
388 cmdline=f"{test} 2>&1",
394 def run_test(self, test: TestToRun) -> TestResults:
395 return self.execute_commandline(test)
398 class DoctestTestRunner(TemplatedTestRunner):
399 """Run all known Doctests."""
402 def get_name(self) -> str:
406 def identify_tests(self) -> List[TestToRun]:
408 out = exec_utils.cmd(f'/usr/bin/grep -lR "^ *import doctest" {ROOT}/*')
409 for test in out.split("\n"):
410 if re.match(r".*\.py$", test):
411 basename = file_utils.without_path(test)
412 if basename in TESTS_TO_SKIP:
414 if config.config["coverage"]:
418 kind="doctest capturing coverage",
419 cmdline=f"coverage run --source ../src {test} 2>&1",
422 if basename in PERF_SENSITIVE_TESTS:
425 name=f"{basename}_no_coverage",
426 kind="doctest w/o coverage to record perf",
427 cmdline=f"python3 {test} 2>&1",
435 cmdline=f"python3 {test} 2>&1",
441 def run_test(self, test: TestToRun) -> TestResults:
442 return self.execute_commandline(test)
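# For illustration (a hypothetical module, not one of this repo's files): the
# grep above picks up any source file that imports doctest, and the generated
# command line just runs the file with python3, so a matching module is
# expected to run its own doctests when executed directly, e.g.
#
#   import doctest
#
#   def add(a: int, b: int) -> int:
#       """
#       >>> add(2, 3)
#       5
#       """
#       return a + b
#
#   if __name__ == "__main__":
#       doctest.testmod()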
445 class IntegrationTestRunner(TemplatedTestRunner):
446 """Run all know Integration tests."""
449 def get_name(self) -> str:
450 return "Integration Tests"
453 def identify_tests(self) -> List[TestToRun]:
455 for test in file_utils.get_matching_files_recursive(ROOT, "*_itest.py"):
456 basename = file_utils.without_path(test)
457 if basename in TESTS_TO_SKIP:
459 if config.config["coverage"]:
463 kind="integration test capturing coverage",
464 cmdline=f"coverage run --source ../src {test} 2>&1",
467 if basename in PERF_SENSITIVE_TESTS:
470 name=f"{basename}_no_coverage",
471 kind="integration test w/o coverage to capture perf",
472 cmdline=f"{test} 2>&1",
478 name=basename, kind="integration test", cmdline=f"{test} 2>&1"
484 def run_test(self, test: TestToRun) -> TestResults:
485 return self.execute_commandline(test)
488 def test_results_report(results: Dict[str, Optional[TestResults]]) -> int:
489 """Give a final report about the tests that were run."""
491 for result in results.values():
493 print("Unexpected unhandled exception in test runner!!!")
496 print(result, end="")
497 total_problems += len(result.tests_failed)
498 total_problems += len(result.tests_timed_out)
500 if total_problems > 0:
502 f"{ansi.bold()}Test output / logging can be found under ./test_output{ansi.reset()}"
504 return total_problems
507 def code_coverage_report():
508 """Give a final code coverage report."""
509 text_utils.header("Code Coverage")
510 exec_utils.cmd("coverage combine .coverage*")
511 out = exec_utils.cmd(
512 "coverage report --omit=config-3.*.py,*_test.py,*_itest.py --sort=-cover"
516 f"""To recall this report w/o re-running the tests:
518 $ {ansi.bold()}coverage report --omit=config-3.*.py,*_test.py,*_itest.py --sort=-cover{ansi.reset()}
520 ...from the 'tests' directory. Note that subsequent calls to
521 run_tests.py with --coverage will clobber previous results. See:
523 https://coverage.readthedocs.io/en/6.2/
528 @bootstrap.initialize
529 def main() -> Optional[int]:
531 threads: List[TestRunner] = []
533 halt_event = threading.Event()
535 params = TestingParameters(
537 halt_event=halt_event,
540 if config.config["coverage"]:
541 logger.debug('Clearing existing coverage data via "coverage erase".')
542 exec_utils.cmd("coverage erase")
543 if config.config["unittests"] or config.config["all"]:
545 threads.append(UnittestTestRunner(params))
546 if config.config["doctests"] or config.config["all"]:
548 threads.append(DoctestTestRunner(params))
549 if config.config["integration"] or config.config["all"]:
551 threads.append(IntegrationTestRunner(params))
555 config.error("One of --unittests, --doctests or --integration is required.", 1)
557 for thread in threads:
560 start_time = time.time()
561 last_update = start_time
562 results: Dict[str, Optional[TestResults]] = {}
565 while len(results) != len(threads):
570 for thread in threads:
572 tr = thread.get_status()
573 started += len(tr.tests_executed)
574 failed += len(tr.tests_failed) + len(tr.tests_timed_out)
575 done += failed + len(tr.tests_succeeded)
576 running = set(tr.tests_executed.keys())
577 running -= set(tr.tests_failed)
578 running -= set(tr.tests_succeeded)
579 running -= set(tr.tests_timed_out)
580 running_with_start_time = {
581 test: tr.tests_executed[test] for test in running
583 still_running[tid] = running_with_start_time
585 # Maybe print tests that are still running.
587 if now - start_time > 5.0:
588 if now - last_update > 3.0:
591 for _, running_dict in still_running.items():
592 for test_name, test_start_time in running_dict.items():
593 if now - test_start_time > 10.0:
594 update.append(f"{test_name}@{now - test_start_time:.1f}s")
596 update.append(test_name)
597 print(f"\r{ansi.clear_line()}")
599 print(f'Still running: {",".join(update)}')
601 print(f"Still running: {len(update)} tests.")
603 # Maybe signal the other threads to stop too.
604 if not thread.is_alive():
605 if tid not in results:
606 result = thread.join()
608 results[tid] = result
609 if (len(result.tests_failed) + len(result.tests_timed_out)) > 0:
611 "Thread %s returned abnormal results; killing the others.",
617 "Thread %s took an unhandled exception... bug in run_tests.py?! Aborting.",
623 color = ansi.fg("green")
625 color = ansi.fg("red")
628 percent_done = done / started * 100.0
632 if percent_done < 100.0:
634 text_utils.bar_graph_string(
637 text=text_utils.BarGraphText.FRACTION,
644 print(f" {color}{now - start_time:.1f}s{ansi.reset()}", end="\r")
647 print(f"{ansi.clear_line()}\n{ansi.underline()}Final Report:{ansi.reset()}")
648 if config.config["coverage"]:
649 code_coverage_report()
650 print(f"Test suite runtime: {time.time() - start_time:.1f}s")
651 total_problems = test_results_report(results)
652 if total_problems > 0:
654 "Exiting with non-zero return code %d due to problems.", total_problems
656 return total_problems
659 if __name__ == "__main__":