#!/usr/bin/env python3

"""
A smart, fast test runner.  Used in a git pre-commit hook.
"""

import logging
import os
import re
import subprocess
import threading
import time
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple

from overrides import overrides

import ansi
import bootstrap
import config
import exec_utils
import file_utils
import parallelize as par
import smart_future
import text_utils
import thread_utils

logger = logging.getLogger(__name__)
args = config.add_commandline_args(f'({__file__})', 'Args related to __file__')
args.add_argument('--unittests', '-u', action='store_true', help='Run unittests.')
args.add_argument('--doctests', '-d', action='store_true', help='Run doctests.')
args.add_argument('--integration', '-i', action='store_true', help='Run integration tests.')
args.add_argument(
    '--coverage', '-c', action='store_true', help='Run tests and capture code coverage data'
)

HOME = os.environ['HOME']

# Directory (relative to the current working directory) that receives one
# "<test basename>-output.txt" log per executed test.
TEST_OUTPUT_DIR = './test_output'


@dataclass
class TestingParameters:
    halt_on_error: bool
    """Should we stop as soon as one error has occurred?"""

    halt_event: threading.Event
    """An event that, when set, indicates to stop ASAP."""


@dataclass
class TestResults:
    name: str
    """The name of this test / set of tests."""

    tests_executed: List[str]
    """Tests that were executed."""

    tests_succeeded: List[str]
    """Tests that succeeded."""

    tests_failed: List[str]
    """Tests that failed."""

    tests_timed_out: List[str]
    """Tests that timed out."""

    def __add__(self, other):
        """Merge other's test lists into self and return self.

        Note that this mutates self in place; that is what makes
        ``accumulated += result`` collect results across tests.
        """
        self.tests_executed.extend(other.tests_executed)
        self.tests_succeeded.extend(other.tests_succeeded)
        self.tests_failed.extend(other.tests_failed)
        self.tests_timed_out.extend(other.tests_timed_out)
        return self

    __radd__ = __add__

    @staticmethod
    def _problem_section(color: str, label: str, tests: List[str]) -> str:
        """Render a colored "N <label>:" heading followed by one line per test.

        Returns the empty string when there are no tests to report.
        """
        if not tests:
            return ''
        out = f' ..{ansi.fg(color)}'
        out += f'{len(tests)} {label}'
        out += f'{ansi.reset()}:\n'
        for test in tests:
            out += f' {test}\n'
        out += '\n'
        return out

    def __repr__(self) -> str:
        out = f'{self.name}: '
        out += f'{ansi.fg("green")}'
        out += f'{len(self.tests_succeeded)}/{len(self.tests_executed)} passed'
        out += f'{ansi.reset()}.\n'
        out += self._problem_section('red', 'tests failed', self.tests_failed)
        # Previously this section listed tests_failed by mistake.
        out += self._problem_section('yellow', 'tests timed out', self.tests_timed_out)
        return out


class TestRunner(ABC, thread_utils.ThreadWithReturnValue):
    """A Base class for something that runs a test."""

    def __init__(self, params: TestingParameters):
        """Create a TestRunner.

        Args:
            params: Test running parameters.
        """
        # NOTE(review): `self` is passed explicitly as the first positional
        # argument; depending on ThreadWithReturnValue's signature it may land
        # in a `group`-style parameter.  Left as-is -- confirm in thread_utils.
        super().__init__(self, target=self.begin, args=[params])
        self.params = params
        self.test_results = TestResults(
            name=self.get_name(),
            tests_executed=[],
            tests_succeeded=[],
            tests_failed=[],
            tests_timed_out=[],
        )
        self.tests_started = 0

    @abstractmethod
    def get_name(self) -> str:
        """The name of this test collection."""
        pass

    def get_status(self) -> Tuple[int, TestResults]:
        """Ask the TestRunner for its status.

        Returns:
            (number of tests started, results accumulated so far)
        """
        return (self.tests_started, self.test_results)

    @abstractmethod
    def begin(self, params: TestingParameters) -> TestResults:
        """Start execution."""
        pass


class TemplatedTestRunner(TestRunner, ABC):
    """A TestRunner that has a recipe for executing the tests."""

    @abstractmethod
    def identify_tests(self) -> List[str]:
        """Return a list of tests that should be executed."""
        pass

    @abstractmethod
    def run_test(self, test: Any) -> TestResults:
        """Run a single test and return its TestResults."""
        pass

    def check_for_abort(self):
        """Periodically called to check to see if we need to stop.

        Raises:
            Exception: if the shared halt event is set, or if halt_on_error
                is enabled and this runner has already seen a failure.
        """
        if self.params.halt_event.is_set():
            logger.debug('Thread %s saw halt event; exiting.', self.get_name())
            raise Exception("Kill myself!")
        if self.params.halt_on_error:
            if len(self.test_results.tests_failed) > 0:
                logger.error('Thread %s saw abnormal results; exiting.', self.get_name())
                raise Exception("Kill myself!")

    def persist_output(self, test_name: str, message: str, output: str) -> None:
        """Called to save the output of a test run.

        Writes message, an underline, and output to
        ./test_output/<basename of test_name>-output.txt, creating the
        directory if necessary.
        """
        basename = file_utils.without_path(test_name)
        dest = f'{basename}-output.txt'
        os.makedirs(TEST_OUTPUT_DIR, exist_ok=True)
        with open(f'{TEST_OUTPUT_DIR}/{dest}', 'w', encoding='utf-8') as wf:
            print(message, file=wf)
            print('-' * len(message), file=wf)
            wf.write(output)

    @staticmethod
    def _decode_output(output: Any) -> str:
        """Best-effort conversion of output captured on a subprocess exception.

        The exception's output may be bytes, str, or None (e.g. a
        TimeoutExpired raised before any output was collected); undecodable
        bytes are replaced rather than raising.
        """
        if output is None:
            return ''
        if isinstance(output, bytes):
            return output.decode('utf-8', errors='replace')
        return str(output)

    def execute_commandline(
        self,
        test_name: str,
        cmdline: str,
        *,
        timeout: float = 120.0,
    ) -> TestResults:
        """Execute a particular commandline to run a test.

        Args:
            test_name: name used in results and for the output log file.
            cmdline: the shell commandline to run.
            timeout: seconds to wait before treating the test as timed out.

        Returns:
            A single-test TestResults marking it succeeded, failed (nonzero
            exit), or timed out.  The test's output is persisted either way.
        """
        try:
            logger.debug('%s: Running %s (%s)', self.get_name(), test_name, cmdline)
            output = exec_utils.cmd(
                cmdline,
                timeout_seconds=timeout,
            )
            self.persist_output(test_name, f'{test_name} ({cmdline}) succeeded.', output)
            logger.debug('%s (%s) succeeded', test_name, cmdline)
            return TestResults(test_name, [test_name], [test_name], [], [])
        except subprocess.TimeoutExpired as e:
            msg = f'{self.get_name()}: {test_name} ({cmdline}) timed out after {e.timeout:.1f} seconds.'
            logger.error(msg)
            logger.debug(
                '%s: %s output when it timed out: %s', self.get_name(), test_name, e.output
            )
            self.persist_output(test_name, msg, self._decode_output(e.output))
            return TestResults(
                test_name,
                [test_name],
                [],
                [],
                [test_name],
            )
        except subprocess.CalledProcessError as e:
            msg = f'{self.get_name()}: {test_name} ({cmdline}) failed; exit code {e.returncode}'
            logger.error(msg)
            logger.debug('%s: %s output when it failed: %s', self.get_name(), test_name, e.output)
            self.persist_output(test_name, msg, self._decode_output(e.output))
            return TestResults(
                test_name,
                [test_name],
                [],
                [test_name],
                [],
            )

    @overrides
    def begin(self, params: TestingParameters) -> TestResults:
        """Launch every identified test, then fold results in as they finish.

        check_for_abort() is consulted before each result is merged and may
        raise, in which case this thread exits without a return value.
        """
        logger.debug('Thread %s started.', self.get_name())
        interesting_tests = self.identify_tests()
        # run_test is wrapped by @par.parallelize in subclasses, so these are
        # presumably futures (consumed via smart_future.wait_any below).
        running: List[Any] = [self.run_test(test) for test in interesting_tests]
        self.tests_started = len(running)
        for future in smart_future.wait_any(running):
            self.check_for_abort()
            result = future._resolve()
            logger.debug('Test %s finished.', result.name)
            self.test_results += result
        logger.debug('Thread %s finished.', self.get_name())
        return self.test_results


class UnittestTestRunner(TemplatedTestRunner):
    """Run all known Unittests."""

    @overrides
    def get_name(self) -> str:
        return "Unittests"

    @overrides
    def identify_tests(self) -> List[str]:
        return list(file_utils.expand_globs('*_test.py'))

    @par.parallelize
    def run_test(self, test: Any) -> TestResults:
        if config.config['coverage']:
            cmdline = f'coverage run --source {HOME}/lib {test} --unittests_ignore_perf'
        else:
            cmdline = test
        return self.execute_commandline(test, cmdline)


class DoctestTestRunner(TemplatedTestRunner):
    """Run all known Doctests."""

    @overrides
    def get_name(self) -> str:
        return "Doctests"

    @overrides
    def identify_tests(self) -> List[str]:
        """Find .py files under $HOME/lib/python_modules that import doctest.

        run_tests.py itself is excluded.
        """
        ret = []
        # Use HOME (like every other path in this file) instead of a
        # hard-coded user directory.
        out = exec_utils.cmd(f'grep -lR "^ *import doctest" {HOME}/lib/python_modules/*')
        for line in out.split('\n'):
            if re.match(r'.*\.py$', line):
                if 'run_tests.py' not in line:
                    ret.append(line)
        return ret

    @par.parallelize
    def run_test(self, test: Any) -> TestResults:
        if config.config['coverage']:
            cmdline = f'coverage run --source {HOME}/lib {test} 2>&1'
        else:
            cmdline = f'python3 {test}'
        return self.execute_commandline(test, cmdline)


class IntegrationTestRunner(TemplatedTestRunner):
    """Run all known Integration tests."""

    @overrides
    def get_name(self) -> str:
        return "Integration Tests"

    @overrides
    def identify_tests(self) -> List[str]:
        return list(file_utils.expand_globs('*_itest.py'))

    @par.parallelize
    def run_test(self, test: Any) -> TestResults:
        if config.config['coverage']:
            cmdline = f'coverage run --source {HOME}/lib {test}'
        else:
            cmdline = test
        return self.execute_commandline(test, cmdline)


def test_results_report(results: Dict[str, TestResults]) -> int:
    """Give a final report about the tests that were run.

    Returns:
        The total number of failed plus timed-out tests.
    """
    total_problems = 0
    for result in results.values():
        print(result, end='')
        total_problems += len(result.tests_failed)
        total_problems += len(result.tests_timed_out)

    if total_problems > 0:
        print('Reminder: look in ./test_output to view test output logs')
    return total_problems


def code_coverage_report():
    """Give a final code coverage report."""
    text_utils.header('Code Coverage')
    exec_utils.cmd('coverage combine .coverage*')
    out = exec_utils.cmd('coverage report --omit=config-3.8.py,*_test.py,*_itest.py --sort=-cover')
    print(out)
    print(
        """
To recall this report w/o re-running the tests:

    $ coverage report --omit=config-3.8.py,*_test.py,*_itest.py --sort=-cover

...from the 'tests' directory.  Note that subsequent calls to
run_tests.py with --coverage will clobber previous results.  See:

    https://coverage.readthedocs.io/en/6.2/
"""
    )


def _collect_finished_threads(
    threads: List[TestRunner],
    results: Dict[str, TestResults],
    halt_event: threading.Event,
) -> None:
    """Join every exited, not-yet-recorded thread and record its results.

    A thread whose begin() raised (e.g. via check_for_abort) returns None
    from join(); its partially accumulated results are recorded instead so
    that the caller's wait loop terminates and the failures get reported.
    If any recorded thread has failures, halt_event is set so the others stop.
    """
    for thread in threads:
        tid = thread.name
        if thread.is_alive() or tid in results:
            continue
        result = thread.join()
        if result is None:
            (_, result) = thread.get_status()
        results[tid] = result
        if len(result.tests_failed) > 0:
            logger.error('Thread %s returned abnormal results; killing the others.', tid)
            halt_event.set()


def _show_progress(threads: List[TestRunner]) -> bool:
    """Draw a progress bar over all runners, finished or not.

    Returns:
        True once every started test has completed (and "Finished." was
        printed); False while the bar is still being drawn.
    """
    started = 0
    done = 0
    failed = 0
    for thread in threads:
        (thread_started, thread_results) = thread.get_status()
        problems = len(thread_results.tests_failed) + len(thread_results.tests_timed_out)
        started += thread_started
        failed += problems
        done += problems + len(thread_results.tests_succeeded)

    # A fraction in [0, 1]; text_utils.bar_graph is given this fraction
    # (as in the original call) -- presumably it expects 0.0..1.0.
    percent_done = min(done / started, 1.0) if started > 0 else 0.0
    color = ansi.fg('green') if failed == 0 else ansi.fg('red')
    if percent_done < 1.0:
        print(
            text_utils.bar_graph(
                percent_done,
                width=80,
                fgcolor=color,
            ),
            end='\r',
            flush=True,
        )
        return False
    print("Finished.\n")
    return True


@bootstrap.initialize
def main() -> Optional[int]:
    """Run the requested test suites in parallel and report the results.

    Returns:
        1 if no suite was requested, else the number of failed plus
        timed-out tests.
    """
    saw_flag = False
    halt_event = threading.Event()
    threads: List[TestRunner] = []

    halt_event.clear()
    params = TestingParameters(
        halt_on_error=True,
        halt_event=halt_event,
    )

    if config.config['coverage']:
        logger.debug('Clearing existing coverage data via "coverage erase".')
        exec_utils.cmd('coverage erase')

    if config.config['unittests']:
        saw_flag = True
        threads.append(UnittestTestRunner(params))
    if config.config['doctests']:
        saw_flag = True
        threads.append(DoctestTestRunner(params))
    if config.config['integration']:
        saw_flag = True
        threads.append(IntegrationTestRunner(params))

    if not saw_flag:
        config.print_usage()
        print('ERROR: one of --unittests, --doctests or --integration is required.')
        return 1

    for thread in threads:
        thread.start()

    results: Dict[str, TestResults] = {}
    finished = False
    while len(results) != len(threads):
        _collect_finished_threads(threads, results, halt_event)
        finished = _show_progress(threads)
        time.sleep(0.5)
    if not finished:
        # Move off the '\r'-terminated progress bar line before reporting.
        print()

    if config.config['coverage']:
        code_coverage_report()
    total_problems = test_results_report(results)
    return total_problems


if __name__ == '__main__':
    main()