#!/usr/bin/env python3

"""A smart, fast test runner.  Used in a git pre-commit hook."""

import logging
import os
import re
import subprocess
import threading
import time
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Dict, List, Optional

from overrides import overrides

from pyutils import ansi, bootstrap, config, dict_utils, exec_utils, text_utils
from pyutils.files import file_utils
from pyutils.parallelize import parallelize as par
from pyutils.parallelize import smart_future, thread_utils

logger = logging.getLogger(__name__)
args = config.add_commandline_args(
    f'Run Tests Driver ({__file__})', f'Args related to {__file__}'
)
args.add_argument('--unittests', '-u', action='store_true', help='Run unittests.')
args.add_argument('--doctests', '-d', action='store_true', help='Run doctests.')
args.add_argument(
    '--integration', '-i', action='store_true', help='Run integration tests.'
)
args.add_argument(
    '--all',
    action='store_true',
    help='Run unittests, doctests and integration tests.  Equivalent to -u -d -i.',
)
args.add_argument(
    '--coverage',
    action='store_true',
    help='Run tests and capture code coverage data.',
)
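
# Example invocations, using the flags defined above (run from the tests
# directory):
#
#   ./run_tests.py --all
#   ./run_tests.py -u -d --coverage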

HOME = os.environ['HOME']

# Root of the source tree scanned for tests (assumed here to be the parent
# of this file's directory).
ROOT = os.path.join(os.path.dirname(os.path.realpath(__file__)), '..')
# These tests will be run twice in --coverage mode: once to get code
# coverage and then again with no coverage enabled.  This is because
# they pay attention to code performance, which is adversely affected
# by coverage instrumentation.
PERF_SENSITIVE_TESTS = set(['string_utils_test.py'])
TESTS_TO_SKIP = set(['zookeeper_test.py', 'zookeeper.py', 'run_tests.py'])


@dataclass
class TestingParameters:
    halt_on_error: bool
    """Should we stop as soon as one error has occurred?"""

    halt_event: threading.Event
    """An event that, when set, indicates to stop ASAP."""
70 """The name of the test"""
73 """The kind of the test"""
76 """The command line to execute"""
82 """The name of this test / set of tests."""
84 tests_executed: Dict[str, float]
85 """Tests that were executed."""
87 tests_succeeded: List[str]
88 """Tests that succeeded."""
90 tests_failed: List[str]
91 """Tests that failed."""
93 tests_timed_out: List[str]
94 """Tests that timed out."""

    def __add__(self, other):
        merged = dict_utils.coalesce(
            [self.tests_executed, other.tests_executed],
            aggregation_function=dict_utils.raise_on_duplicated_keys,
        )
        self.tests_executed = merged
        self.tests_succeeded.extend(other.tests_succeeded)
        self.tests_failed.extend(other.tests_failed)
        self.tests_timed_out.extend(other.tests_timed_out)
        return self

    def __repr__(self) -> str:
        out = f'{self.name}: '
        out += f'{ansi.fg("green")}'
        out += f'{len(self.tests_succeeded)}/{len(self.tests_executed)} passed'
        out += f'{ansi.reset()}.\n'

        if len(self.tests_failed) > 0:
            out += f' ..{ansi.fg("red")}'
            out += f'{len(self.tests_failed)} tests failed'
            out += f'{ansi.reset()}:\n'
            for test in self.tests_failed:
                out += f'    {test}\n'

        if len(self.tests_timed_out) > 0:
            out += f' ..{ansi.fg("yellow")}'
            out += f'{len(self.tests_timed_out)} tests timed out'
            out += f'{ansi.reset()}:\n'
            for test in self.tests_timed_out:
                out += f'    {test}\n'
        return out
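

# TestResults can be merged with '+' / '+=' (see __add__ above), e.g.
# (the names here are illustrative only):
#
#   totals = TestResults('total', {}, [], [], [])
#   totals += TestResults('unittests', {'foo_test.py': time.time()},
#                         ['foo_test.py'], [], [])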


class TestRunner(ABC, thread_utils.ThreadWithReturnValue):
    """A base class for something that runs a test."""

    def __init__(self, params: TestingParameters):
        """Create a TestRunner.

        Args:
            params: Test running parameters.
        """
        super().__init__(target=self.begin, args=[params])
        self.params = params
        self.test_results = TestResults(
            name=self.get_name(),
            tests_executed={},
            tests_succeeded=[],
            tests_failed=[],
            tests_timed_out=[],
        )
        self.lock = threading.Lock()
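        # test_results is mutated on the runner thread and polled by the
        # main thread via get_status(); the lock serializes that access.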

    @abstractmethod
    def get_name(self) -> str:
        """The name of this test collection."""
        pass

    def get_status(self) -> TestResults:
        """Ask the TestRunner for its status."""
        with self.lock:
            return self.test_results

    @abstractmethod
    def begin(self, params: TestingParameters) -> TestResults:
        """Start execution."""
        pass


class TemplatedTestRunner(TestRunner, ABC):
    """A TestRunner that has a recipe for executing its tests."""

    @abstractmethod
    def identify_tests(self) -> List[TestToRun]:
        """Return the list of TestToRun objects that should be executed."""
        pass

    @abstractmethod
    def run_test(self, test: TestToRun) -> TestResults:
        """Run a single test and return its TestResults."""
        pass

    def check_for_abort(self) -> bool:
        """Periodically called to check to see if we need to stop."""

        if self.params.halt_event.is_set():
            logger.debug('Thread %s saw halt event; exiting.', self.get_name())
            return True

        if self.params.halt_on_error and len(self.test_results.tests_failed) > 0:
            logger.error('Thread %s saw abnormal results; exiting.', self.get_name())
            return True
        return False

    def persist_output(self, test: TestToRun, message: str, output: str) -> None:
        """Called to save the output of a test run."""

        dest = f'{test.name}-output.txt'
        with open(f'./test_output/{dest}', 'w') as wf:
            print(message, file=wf)
            print('-' * len(message), file=wf)
            print(output, file=wf)
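        # e.g. the output of a test named "foo_test.py" is saved in
        # ./test_output/foo_test.py-output.txt.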

    def execute_commandline(self, test: TestToRun, *, timeout: float = 120.0) -> TestResults:
        """Execute a particular commandline to run a test."""

        try:
            output = exec_utils.cmd(test.cmdline, timeout_seconds=timeout)
            if "***Test Failed***" in output:
                msg = f'{self.get_name()}: {test.name} ({test.cmdline}) failed; doctest failure message detected.'
                logger.error(msg)
                self.persist_output(test, msg, output)
                return TestResults(test.name, {}, [], [test.name], [])

            self.persist_output(
                test, f'{test.name} ({test.cmdline}) succeeded.', output
            )
            logger.debug(
                '%s: %s (%s) succeeded', self.get_name(), test.name, test.cmdline
            )
            return TestResults(test.name, {}, [test.name], [], [])
        except subprocess.TimeoutExpired as e:
            msg = f'{self.get_name()}: {test.name} ({test.cmdline}) timed out after {e.timeout:.1f} seconds.'
            logger.error(msg)
            logger.debug(
                '%s: %s output when it timed out: %s',
                self.get_name(),
                test.name,
                e.output,
            )
            self.persist_output(test, msg, e.output.decode('utf-8'))
            return TestResults(test.name, {}, [], [], [test.name])
        except subprocess.CalledProcessError as e:
            msg = f'{self.get_name()}: {test.name} ({test.cmdline}) failed; exit code {e.returncode}.'
            logger.error(msg)
            logger.debug(
                '%s: %s output when it failed: %s', self.get_name(), test.name, e.output
            )
            self.persist_output(test, msg, e.output.decode('utf-8'))
            return TestResults(test.name, {}, [], [test.name], [])
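
    # Reading the returns above: TestResults' positional fields are (name,
    # tests_executed, tests_succeeded, tests_failed, tests_timed_out), so
    # e.g. TestResults(test.name, {}, [], [], [test.name]) records a single
    # timeout and nothing else.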

    @overrides
    def begin(self, params: TestingParameters) -> TestResults:
        logger.debug('Thread %s started.', self.get_name())
        interesting_tests = self.identify_tests()
        logger.debug(
            '%s: Identified %d tests to be run.',
            self.get_name(),
            len(interesting_tests),
        )

        # Note: because of @parallelize on run_test it actually
        # returns a SmartFuture with a TestResults inside of it.
        # That's the reason for this Any business.
        running: List[Any] = []
        for test_to_run in interesting_tests:
            running.append(self.run_test(test_to_run))
            logger.debug(
                '%s: Test %s started in the background.',
                self.get_name(),
                test_to_run.name,
            )
            self.test_results.tests_executed[test_to_run.name] = time.time()

        for future in smart_future.wait_any(running, log_exceptions=False):
            result = future._resolve()
            logger.debug('Test %s finished.', result.name)

            # We sometimes run the same test more than once.  Do not allow
            # one run's results to clobber the other's.
            self.test_results += result
            if self.check_for_abort():
                logger.debug(
                    '%s: check_for_abort told us to exit early.', self.get_name()
                )
                return self.test_results

        logger.debug('Thread %s finished running all tests.', self.get_name())
        return self.test_results


class UnittestTestRunner(TemplatedTestRunner):
    """Run all known unittests."""

    @overrides
    def get_name(self) -> str:
        return "Unittests"

    @overrides
    def identify_tests(self) -> List[TestToRun]:
        ret = []
        for test in file_utils.get_matching_files_recursive(ROOT, '*_test.py'):
            basename = file_utils.without_path(test)
            if basename in TESTS_TO_SKIP:
                continue
            if config.config['coverage']:
                ret.append(TestToRun(
                    name=basename,
                    kind='unittest capturing coverage',
                    cmdline=f'coverage run --source ../src {test} --unittests_ignore_perf 2>&1',
                ))
                if basename in PERF_SENSITIVE_TESTS:
                    ret.append(TestToRun(
                        name=f'{basename}_no_coverage',
                        kind='unittest w/o coverage to record perf',
                        cmdline=f'{test} 2>&1',
                    ))
            else:
                ret.append(TestToRun(
                    name=basename, kind='unittest', cmdline=f'{test} 2>&1'
                ))
        return ret

    @par.parallelize
    def run_test(self, test: TestToRun) -> TestResults:
        return self.execute_commandline(test)


class DoctestTestRunner(TemplatedTestRunner):
    """Run all known doctests."""

    @overrides
    def get_name(self) -> str:
        return "Doctests"

    @overrides
    def identify_tests(self) -> List[TestToRun]:
        ret = []
        out = exec_utils.cmd(f'grep -lR "^ *import doctest" {ROOT}/*')
        for test in out.split('\n'):
            if re.match(r'.*\.py$', test):
                basename = file_utils.without_path(test)
                if basename in TESTS_TO_SKIP:
                    continue
                if config.config['coverage']:
                    ret.append(TestToRun(
                        name=basename,
                        kind='doctest capturing coverage',
                        cmdline=f'coverage run --source ../src {test} 2>&1',
                    ))
                    if basename in PERF_SENSITIVE_TESTS:
                        ret.append(TestToRun(
                            name=f'{basename}_no_coverage',
                            kind='doctest w/o coverage to record perf',
                            cmdline=f'python3 {test} 2>&1',
                        ))
                else:
                    ret.append(TestToRun(
                        name=basename, kind='doctest', cmdline=f'python3 {test} 2>&1'
                    ))
        return ret

    @par.parallelize
    def run_test(self, test: TestToRun) -> TestResults:
        return self.execute_commandline(test)


class IntegrationTestRunner(TemplatedTestRunner):
    """Run all known integration tests."""

    @overrides
    def get_name(self) -> str:
        return "Integration Tests"

    @overrides
    def identify_tests(self) -> List[TestToRun]:
        ret = []
        for test in file_utils.get_matching_files_recursive(ROOT, '*_itest.py'):
            basename = file_utils.without_path(test)
            if basename in TESTS_TO_SKIP:
                continue
            if config.config['coverage']:
                ret.append(TestToRun(
                    name=basename,
                    kind='integration test capturing coverage',
                    cmdline=f'coverage run --source ../src {test} 2>&1',
                ))
                if basename in PERF_SENSITIVE_TESTS:
                    ret.append(TestToRun(
                        name=f'{basename}_no_coverage',
                        kind='integration test w/o coverage to capture perf',
                        cmdline=f'{test} 2>&1',
                    ))
            else:
                ret.append(TestToRun(
                    name=basename, kind='integration test', cmdline=f'{test} 2>&1'
                ))
        return ret

    @par.parallelize
    def run_test(self, test: TestToRun) -> TestResults:
        return self.execute_commandline(test)
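

# Adding another category of test is a matter of subclassing
# TemplatedTestRunner.  A minimal sketch (the '*_smoke.py' glob and the
# names in it are hypothetical, not part of this repo):
#
#   class SmokeTestRunner(TemplatedTestRunner):
#       """Run all known smoke tests."""
#
#       @overrides
#       def get_name(self) -> str:
#           return "Smoke Tests"
#
#       @overrides
#       def identify_tests(self) -> List[TestToRun]:
#           return [
#               TestToRun(
#                   name=file_utils.without_path(test),
#                   kind='smoke test',
#                   cmdline=f'{test} 2>&1',
#               )
#               for test in file_utils.get_matching_files_recursive(
#                   ROOT, '*_smoke.py'
#               )
#           ]
#
#       @par.parallelize
#       def run_test(self, test: TestToRun) -> TestResults:
#           return self.execute_commandline(test)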


def test_results_report(results: Dict[str, Optional[TestResults]]) -> int:
    """Give a final report about the tests that were run."""
    total_problems = 0
    for result in results.values():
        if result is None:
            print('Unexpected unhandled exception in test runner!!!')
            total_problems += 1
        else:
            print(result, end='')
            total_problems += len(result.tests_failed)
            total_problems += len(result.tests_timed_out)

    if total_problems > 0:
        print('Reminder: look in ./test_output to view test output logs.')
    return total_problems


def code_coverage_report():
    """Give a final code coverage report."""
    text_utils.header('Code Coverage')
    exec_utils.cmd('coverage combine .coverage*')
    out = exec_utils.cmd(
        'coverage report --omit=config-3.*.py,*_test.py,*_itest.py --sort=-cover'
    )
    print(out)
    print(
        """To recall this report w/o re-running the tests:

    $ coverage report --omit=config-3.*.py,*_test.py,*_itest.py --sort=-cover

...from the 'tests' directory.  Note that subsequent calls to
run_tests.py with --coverage will clobber previous results.  See:

    https://coverage.readthedocs.io/en/6.2/
"""
    )


@bootstrap.initialize
def main() -> Optional[int]:
    halt_event = threading.Event()
    threads: List[TestRunner] = []

    params = TestingParameters(
        halt_on_error=True,
        halt_event=halt_event,
    )

    if config.config['coverage']:
        logger.debug('Clearing existing coverage data via "coverage erase".')
        exec_utils.cmd('coverage erase')

    if config.config['unittests'] or config.config['all']:
        threads.append(UnittestTestRunner(params))
    if config.config['doctests'] or config.config['all']:
        threads.append(DoctestTestRunner(params))
    if config.config['integration'] or config.config['all']:
        threads.append(IntegrationTestRunner(params))
    if len(threads) == 0:
        print('ERROR: one of --unittests, --doctests or --integration is required.')
        return 1

    for thread in threads:
        thread.start()

    results: Dict[str, Optional[TestResults]] = {}
    start_time = time.time()
    last_update = start_time

    # Poll the runner threads: show progress and harvest results as each
    # thread finishes.
    while len(results) != len(threads):
        started = 0
        done = 0
        failed = 0
        still_running = {}

        for thread in threads:
            tid = thread.name
            tr = thread.get_status()
            started += len(tr.tests_executed)
            thread_failed = len(tr.tests_failed) + len(tr.tests_timed_out)
            failed += thread_failed
            done += thread_failed + len(tr.tests_succeeded)
            running = set(tr.tests_executed.keys())
            running -= set(tr.tests_failed)
            running -= set(tr.tests_succeeded)
            running -= set(tr.tests_timed_out)
            running_with_start_time = {
                test: tr.tests_executed[test] for test in running
            }
            still_running[tid] = running_with_start_time

            # Periodically tell the user what's still in flight; tests that
            # have been running a while get their age appended.
            now = time.time()
            if now - start_time > 5.0:
                if now - last_update > 3.0:
                    last_update = now
                    update = []
                    for _, running_dict in still_running.items():
                        # N.B. don't shadow the outer start_time here.
                        for test_name, test_start_time in running_dict.items():
                            if now - test_start_time > 10.0:
                                update.append(f'{test_name}@{now-test_start_time:.1f}s')
                            else:
                                update.append(test_name)
                    print(f'\r{ansi.clear_line()}')
                    if len(update) < 4:
                        print(f'Still running: {",".join(update)}')
                    else:
                        print(f'Still running: {len(update)} tests.')

            if not thread.is_alive():
                if tid not in results:
                    result = thread.join()
                    if result:
                        results[tid] = result
                        if len(result.tests_failed) > 0:
                            logger.error(
                                'Thread %s returned abnormal results; killing the others.',
                                tid,
                            )
                            halt_event.set()
                    else:
                        logger.error(
                            'Thread %s took an unhandled exception... bug in run_tests.py?!  Aborting.',
                            tid,
                        )
                        halt_event.set()
                        results[tid] = None

        if failed == 0:
            color = ansi.fg('green')
        else:
            color = ansi.fg('red')

        if started > 0:
            percent_done = done / started * 100.0
        else:
            percent_done = 0.0

        if percent_done < 100.0:
            print(
                text_utils.bar_graph_string(
                    done,
                    started,
                    text=text_utils.BarGraphText.FRACTION,
                    fgcolor=color,
                ),
                end='\r',
                flush=True,
            )
        time.sleep(0.1)

    print(f'{ansi.clear_line()}Final Report:')
    if config.config['coverage']:
        code_coverage_report()
    total_problems = test_results_report(results)
    return total_problems


if __name__ == '__main__':
    main()
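
# A minimal git pre-commit hook that uses this runner might look like
# the following (the repo layout is an assumption):
#
#   #!/bin/bash
#   cd tests && ./run_tests.py --all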