"""A smart, fast test runner. Used in a git pre-commit hook."""
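
# Example invocation (a sketch; assumes the script is executable and is run
# from the tests directory, as the relative ../src paths below suggest):
#
#   ./run_tests.py --all --coverage
#
# Individual categories can be selected with --unittests, --doctests and
# --integration instead of --all.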
import logging
import os
import re
import subprocess
import threading
import time
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Dict, List, Optional

from overrides import overrides

from pyutils import ansi, bootstrap, config, dict_utils, exec_utils, text_utils
from pyutils.files import file_utils
from pyutils.parallelize import parallelize as par
from pyutils.parallelize import smart_future, thread_utils
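
# pyutils glue used below: config/bootstrap handle command-line flags and
# program startup, parallelize.parallelize runs decorated functions on
# background workers, and smart_future collects their results as they finish.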
logger = logging.getLogger(__name__)
args = config.add_commandline_args(
    f'Run Tests Driver ({__file__})', f'Args related to {__file__}'
)
args.add_argument('--unittests', '-u', action='store_true', help='Run unittests.')
args.add_argument('--doctests', '-d', action='store_true', help='Run doctests.')
args.add_argument(
    '--integration', '-i', action='store_true', help='Run integration tests.'
)
args.add_argument(
    '--all', '-a', action='store_true',
    help='Run unittests, doctests and integration tests. Equivalent to -u -d -i',
)
args.add_argument(
    '--coverage', '-c', action='store_true',
    help='Run tests and capture code coverage data',
)
HOME = os.environ['HOME']

# These tests will be run twice in --coverage mode: once to get code
# coverage and then again with no coverage enabled. This is because
# they pay attention to code performance, which is adversely affected
# by the coverage instrumentation.
PERF_SENSATIVE_TESTS = set(['string_utils_test.py'])
TESTS_TO_SKIP = set(['zookeeper_test.py', 'zookeeper.py', 'run_tests.py'])
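
# Both sets above are matched against a test's basename (file_utils.without_path)
# in the identify_tests() implementations below.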


@dataclass
class TestingParameters:
    halt_on_error: bool
    """Should we stop as soon as one error has occurred?"""
    halt_event: threading.Event
    """An event that, when set, indicates to stop ASAP."""


@dataclass
class TestToRun:
    name: str
    """The name of the test."""
    kind: str
    """The kind of the test."""
    cmdline: str
    """The command line to execute."""


@dataclass
class TestResults:
    name: str
    """The name of this test / set of tests."""
    tests_executed: Dict[str, float]
    """Tests that were executed."""
    tests_succeeded: List[str]
    """Tests that succeeded."""
    tests_failed: List[str]
    """Tests that failed."""
    tests_timed_out: List[str]
    """Tests that timed out."""
    def __add__(self, other):
        merged = dict_utils.coalesce(
            [self.tests_executed, other.tests_executed],
            aggregation_function=dict_utils.raise_on_duplicated_keys,
        )
        self.tests_executed = merged
        self.tests_succeeded.extend(other.tests_succeeded)
        self.tests_failed.extend(other.tests_failed)
        self.tests_timed_out.extend(other.tests_timed_out)
        return self

    def __repr__(self) -> str:
        out = f'{self.name}: '
        out += f'{ansi.fg("green")}'
        out += f'{len(self.tests_succeeded)}/{len(self.tests_executed)} passed'
        out += f'{ansi.reset()}.\n'

        if len(self.tests_failed) > 0:
            out += f' ..{ansi.fg("red")}'
            out += f'{len(self.tests_failed)} tests failed'
            out += f'{ansi.reset()}:\n'
            for test in self.tests_failed:
                out += f'    {test}\n'
            out += '\n'

        if len(self.tests_timed_out) > 0:
            out += f' ..{ansi.fg("yellow")}'
            out += f'{len(self.tests_timed_out)} tests timed out'
            out += f'{ansi.reset()}:\n'
            for test in self.tests_timed_out:
                out += f'    {test}\n'
            out += '\n'
        return out


class TestRunner(ABC, thread_utils.ThreadWithReturnValue):
    """A base class for something that runs a test."""

    def __init__(self, params: TestingParameters):
        """Create a TestRunner.

        Args:
            params: Test running parameters.
        """
        super().__init__(self, target=self.begin, args=[params])
        self.params = params
        self.test_results = TestResults(
            name=self.get_name(),
            tests_executed={},
            tests_succeeded=[],
            tests_failed=[],
            tests_timed_out=[],
        )
        self.lock = threading.Lock()

    @abstractmethod
    def get_name(self) -> str:
        """The name of this test collection."""

    def get_status(self) -> TestResults:
        """Ask the TestRunner for its status."""
        with self.lock:
            return self.test_results

    @abstractmethod
    def begin(self, params: TestingParameters) -> TestResults:
        """Start execution."""
class TemplatedTestRunner(TestRunner, ABC):
    """A TestRunner that has a recipe for executing the tests."""

    @abstractmethod
    def identify_tests(self) -> List[TestToRun]:
        """Return a list of tuples (test, cmdline) that should be executed."""

    @abstractmethod
    def run_test(self, test: TestToRun) -> TestResults:
        """Run a single test and return its TestResults."""

    def check_for_abort(self) -> bool:
        """Periodically called to check to see if we need to stop."""
        if self.params.halt_event.is_set():
            logger.debug('Thread %s saw halt event; exiting.', self.get_name())
            return True
        if self.params.halt_on_error and len(self.test_results.tests_failed) > 0:
            logger.error('Thread %s saw abnormal results; exiting.', self.get_name())
            return True
        return False

    def persist_output(self, test: TestToRun, message: str, output: str) -> None:
        """Called to save the output of a test run."""
        dest = f'{test.name}-output.txt'
        with open(f'./test_output/{dest}', 'w') as wf:
            print(message, file=wf)
            print('-' * len(message), file=wf)
            print(output, file=wf)
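
    # execute_commandline() maps each outcome onto a TestResults record: a normal
    # exit (without a doctest failure banner in the output) is a success, a
    # subprocess.TimeoutExpired is a timeout, and a subprocess.CalledProcessError
    # is a failure. The captured output is saved via persist_output() either way.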
    def execute_commandline(
        self,
        test: TestToRun,
        *,
        timeout: float = 120.0,
    ) -> TestResults:
        """Execute a particular commandline to run a test."""
        try:
            output = exec_utils.cmd(
                test.cmdline,
                timeout_seconds=timeout,
            )
            if "***Test Failed***" in output:
                msg = f'{self.get_name()}: {test.name} ({test.cmdline}) failed; doctest failure message detected'
                logger.error(msg)
                self.persist_output(test, msg, output)
                return TestResults(test.name, {}, [], [test.name], [])
            self.persist_output(
                test, f'{test.name} ({test.cmdline}) succeeded.', output
            )
            logger.debug(
                '%s: %s (%s) succeeded', self.get_name(), test.name, test.cmdline
            )
            return TestResults(test.name, {}, [test.name], [], [])
        except subprocess.TimeoutExpired as e:
            msg = f'{self.get_name()}: {test.name} ({test.cmdline}) timed out after {e.timeout:.1f} seconds.'
            logger.error(msg)
            logger.debug(
                '%s: %s output when it timed out: %s',
                self.get_name(),
                test.name,
                e.output,
            )
            self.persist_output(test, msg, e.output.decode('utf-8'))
            return TestResults(test.name, {}, [], [], [test.name])
        except subprocess.CalledProcessError as e:
            msg = f'{self.get_name()}: {test.name} ({test.cmdline}) failed; exit code {e.returncode}'
            logger.error(msg)
            logger.debug(
                '%s: %s output when it failed: %s', self.get_name(), test.name, e.output
            )
            self.persist_output(test, msg, e.output.decode('utf-8'))
            return TestResults(test.name, {}, [], [test.name], [])

    @overrides
    def begin(self, params: TestingParameters) -> TestResults:
        logger.debug('Thread %s started.', self.get_name())
        interesting_tests = self.identify_tests()
        logger.debug(
            '%s: Identified %d tests to be run.',
            self.get_name(),
            len(interesting_tests),
        )

        # Note: because of @parallelize on run_test it actually
        # returns a SmartFuture with a TestResult inside of it.
        # That's the reason for this Any business.
        running: List[Any] = []
        for test_to_run in interesting_tests:
            running.append(self.run_test(test_to_run))
            logger.debug(
                '%s: Test %s started in the background.',
                self.get_name(),
                test_to_run.name,
            )
            self.test_results.tests_executed[test_to_run.name] = time.time()

        for future in smart_future.wait_any(running, log_exceptions=False):
            result = future._resolve()
            logger.debug('Test %s finished.', result.name)
            self.test_results += result
            if self.check_for_abort():
                logger.debug(
                    '%s: check_for_abort told us to exit early.', self.get_name()
                )
                return self.test_results

        logger.debug('Thread %s finished running all tests', self.get_name())
        return self.test_results
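

# The three concrete runners below differ mainly in how they discover tests:
# unittests are *_test.py files, doctests are .py files that import the doctest
# module (found via grep), and integration tests are *_itest.py files.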
class UnittestTestRunner(TemplatedTestRunner):
    """Run all known Unittests."""

    @overrides
    def get_name(self) -> str:
        return "Unittests"

    @overrides
    def identify_tests(self) -> List[TestToRun]:
        ret: List[TestToRun] = []
        for test in file_utils.get_matching_files_recursive(ROOT, '*_test.py'):
            basename = file_utils.without_path(test)
            if basename in TESTS_TO_SKIP:
                continue
            if config.config['coverage']:
                ret.append(
                    TestToRun(
                        name=basename,
                        kind='unittest capturing coverage',
                        cmdline=f'coverage run --source ../src {test} --unittests_ignore_perf 2>&1',
                    )
                )
                if basename in PERF_SENSATIVE_TESTS:
                    ret.append(
                        TestToRun(
                            name=basename,
                            kind='unittest w/o coverage to record perf',
                            cmdline=f'{test} 2>&1',
                        )
                    )
            else:
                ret.append(
                    TestToRun(name=basename, kind='unittest', cmdline=f'{test} 2>&1')
                )
        return ret

    @par.parallelize
    def run_test(self, test: TestToRun) -> TestResults:
        return self.execute_commandline(test)


class DoctestTestRunner(TemplatedTestRunner):
    """Run all known Doctests."""

    @overrides
    def get_name(self) -> str:
        return "Doctests"

    @overrides
    def identify_tests(self) -> List[TestToRun]:
        ret: List[TestToRun] = []
        out = exec_utils.cmd(f'grep -lR "^ *import doctest" {ROOT}/*')
        for test in out.split('\n'):
            if re.match(r'.*\.py$', test):
                basename = file_utils.without_path(test)
                if basename in TESTS_TO_SKIP:
                    continue
                if config.config['coverage']:
                    ret.append(
                        TestToRun(
                            name=basename,
                            kind='doctest capturing coverage',
                            cmdline=f'coverage run --source ../src {test} 2>&1',
                        )
                    )
                    if basename in PERF_SENSATIVE_TESTS:
                        ret.append(
                            TestToRun(
                                name=basename,
                                kind='doctest w/o coverage to record perf',
                                cmdline=f'python3 {test} 2>&1',
                            )
                        )
                else:
                    ret.append(
                        TestToRun(name=basename, kind='doctest', cmdline=f'python3 {test} 2>&1')
                    )
        return ret

    @par.parallelize
    def run_test(self, test: TestToRun) -> TestResults:
        return self.execute_commandline(test)


class IntegrationTestRunner(TemplatedTestRunner):
    """Run all known Integration tests."""

    @overrides
    def get_name(self) -> str:
        return "Integration Tests"

    @overrides
    def identify_tests(self) -> List[TestToRun]:
        ret: List[TestToRun] = []
        for test in file_utils.get_matching_files_recursive(ROOT, '*_itest.py'):
            basename = file_utils.without_path(test)
            if basename in TESTS_TO_SKIP:
                continue
            if config.config['coverage']:
                ret.append(
                    TestToRun(
                        name=basename,
                        kind='integration test capturing coverage',
                        cmdline=f'coverage run --source ../src {test} 2>&1',
                    )
                )
                if basename in PERF_SENSATIVE_TESTS:
                    ret.append(
                        TestToRun(
                            name=basename,
                            kind='integration test w/o coverage to capture perf',
                            cmdline=f'{test} 2>&1',
                        )
                    )
            else:
                ret.append(
                    TestToRun(
                        name=basename, kind='integration test', cmdline=f'{test} 2>&1'
                    )
                )
        return ret

    @par.parallelize
    def run_test(self, test: TestToRun) -> TestResults:
        return self.execute_commandline(test)
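

# Reporting helpers, called from main() once every runner thread has finished.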
def test_results_report(results: Dict[str, Optional[TestResults]]) -> int:
    """Give a final report about the tests that were run."""
    total_problems = 0
    for result in results.values():
        if result is None:
            print('Unexpected unhandled exception in test runner!!!')
            total_problems += 1
        else:
            print(result, end='')
            total_problems += len(result.tests_failed)
            total_problems += len(result.tests_timed_out)

    if total_problems > 0:
        print('Reminder: look in ./test_output to view test output logs')
    return total_problems


def code_coverage_report():
    """Give a final code coverage report."""
    text_utils.header('Code Coverage')
    exec_utils.cmd('coverage combine .coverage*')
    out = exec_utils.cmd(
        'coverage report --omit=config-3.*.py,*_test.py,*_itest.py --sort=-cover'
    )
    print(out)
    print(
        """To recall this report w/o re-running the tests:

    $ coverage report --omit=config-3.*.py,*_test.py,*_itest.py --sort=-cover

...from the 'tests' directory. Note that subsequent calls to
run_tests.py with --coverage will clobber previous results. See:

    https://coverage.readthedocs.io/en/6.2/
"""
    )
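

# main() drives everything: bootstrap/config parse the flags declared above,
# one runner thread is started per requested test category, progress is polled
# via get_status() and drawn as a bar graph, and the final report (plus an
# optional coverage report) determines the return value.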
@bootstrap.initialize
def main() -> Optional[int]:
    halt_event = threading.Event()
    threads: List[TestRunner] = []

    params = TestingParameters(
        halt_on_error=True,  # Assumption: stop all runners once something fails.
        halt_event=halt_event,
    )

    if config.config['coverage']:
        logger.debug('Clearing existing coverage data via "coverage erase".')
        exec_utils.cmd('coverage erase')

    if config.config['unittests'] or config.config['all']:
        threads.append(UnittestTestRunner(params))
    if config.config['doctests'] or config.config['all']:
        threads.append(DoctestTestRunner(params))
    if config.config['integration'] or config.config['all']:
        threads.append(IntegrationTestRunner(params))
    if len(threads) == 0:
        print('ERROR: one of --unittests, --doctests or --integration is required.')
        return 1

    for thread in threads:
        thread.start()

    results: Dict[str, Optional[TestResults]] = {}
    start_time = time.time()
    last_update = start_time

    while len(results) != len(threads):
        started = 0
        done = 0
        failed = 0
        still_running = {}
        for thread in threads:
            tid = thread.name
            tr = thread.get_status()
            started += len(tr.tests_executed)
            failed += len(tr.tests_failed) + len(tr.tests_timed_out)
            done += failed + len(tr.tests_succeeded)
            running = set(tr.tests_executed.keys())
            running -= set(tr.tests_failed)
            running -= set(tr.tests_succeeded)
            running -= set(tr.tests_timed_out)
            running_with_start_time = {
                test: tr.tests_executed[test] for test in running
            }
            still_running[tid] = running_with_start_time

            now = time.time()
            if now - start_time > 5.0:
                if now - last_update > 3.0:
                    last_update = now
                    update = []
                    for _, running_dict in still_running.items():
                        for test_name, test_start_time in running_dict.items():
                            if now - test_start_time > 10.0:
                                update.append(f'{test_name}@{now-test_start_time:.1f}s')
                            else:
                                update.append(test_name)
                    print(f'\r{ansi.clear_line()}')
                    if len(update) < 10:  # Arbitrary cutoff for listing names.
                        print(f'Still running: {",".join(update)}')
                    else:
                        print(f'Still running: {len(update)} tests.')

            if not thread.is_alive():
                if tid not in results:
                    result = thread.join()
                    if result:
                        results[tid] = result
                        if len(result.tests_failed) > 0:
                            logger.error(
                                'Thread %s returned abnormal results; killing the others.',
                                tid,
                            )
                            halt_event.set()
                    else:
                        logger.error(
                            'Thread %s took an unhandled exception... bug in run_tests.py?! Aborting.',
                            tid,
                        )
                        halt_event.set()
                        results[tid] = None

        if started > 0:
            percent_done = done / started
        else:
            percent_done = 0.0
        if failed == 0:
            color = ansi.fg('green')
        else:
            color = ansi.fg('red')

        if percent_done < 100.0:
            print(
                text_utils.bar_graph_string(
                    done,
                    started,
                    text=text_utils.BarGraphText.FRACTION,
                    fgcolor=color,  # fgcolor kwarg assumed.
                ),
                end='',
                flush=True,
            )
        time.sleep(0.1)  # Poll interval; exact value assumed.

    print(f'{ansi.clear_line()}Final Report:')
    if config.config['coverage']:
        code_coverage_report()
    total_problems = test_results_report(results)
    return total_problems
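

# bootstrap.initialize (above) parses the command-line flags into config.config
# before main() runs. The integer main() returns (the number of problems found)
# is meant to signal failure to the caller, e.g. a git pre-commit hook;
# presumably the bootstrap wrapper turns it into the process exit status.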
if __name__ == '__main__':
    main()