--- /dev/null
+#!/usr/bin/env python3
+
+"""
+A smart, fast test runner.
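+
+Sample usage (the flags are defined below and may be combined):
+
+    % run_tests.py --unittests --doctests --integration --coverage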
+"""
+
+import logging
+import os
+import re
+import subprocess
+import threading
+import time
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional
+
+from overrides import overrides
+
+import bootstrap
+import config
+import exec_utils
+import file_utils
+import parallelize as par
+import text_utils
+import thread_utils
+
+logger = logging.getLogger(__name__)
+args = config.add_commandline_args(f'({__file__})', f'Args related to {__file__}')
+args.add_argument('--unittests', '-u', action='store_true', help='Run unittests.')
+args.add_argument('--doctests', '-d', action='store_true', help='Run doctests.')
+args.add_argument('--integration', '-i', action='store_true', help='Run integration tests.')
+args.add_argument(
+    '--coverage', '-c', action='store_true', help='Run tests and capture code coverage data.'
+)
+
+HOME = os.environ['HOME']
+
+
+@dataclass
+class TestingParameters:
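+    """Parameters that all test-runner threads share."""
+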
+ halt_on_error: bool
+ halt_event: threading.Event
+
+
+@dataclass
+class TestResults:
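+    """Counters and output from one test (or an aggregate of many)."""
+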
+ name: str
+ num_tests_executed: int
+ num_tests_succeeded: int
+ num_tests_failed: int
+ normal_exit: bool
+ output: str
+
+
+class TestRunner(ABC, thread_utils.ThreadWithReturnValue):
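+    """Base class for a thread that runs a batch of tests.  Because it
+    extends thread_utils.ThreadWithReturnValue, joining the thread hands
+    back a TestResults object (see main, below)."""
+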
+ def __init__(self, params: TestingParameters):
+        super().__init__(target=self.begin, args=[params])
+ self.params = params
+ self.test_results = TestResults(
+ name=f"All {self.get_name()} tests",
+ num_tests_executed=0,
+ num_tests_succeeded=0,
+ num_tests_failed=0,
+ normal_exit=True,
+ output="",
+ )
+
+ def aggregate_test_results(self, result: TestResults):
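+        """Fold the counts and output of one test into the overall totals."""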
+ self.test_results.num_tests_executed += result.num_tests_executed
+ self.test_results.num_tests_succeeded += result.num_tests_succeeded
+ self.test_results.num_tests_failed += result.num_tests_failed
+ self.test_results.normal_exit = self.test_results.normal_exit and result.normal_exit
+ self.test_results.output += "\n\n\n" + result.output
+
+ @abstractmethod
+ def get_name(self) -> str:
+ pass
+
+ @abstractmethod
+ def begin(self, params: TestingParameters) -> TestResults:
+ pass
+
+
+class TemplatedTestRunner(TestRunner, ABC):
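+    """A TestRunner whose begin() implements the template: subclasses say
+    how to find tests and how to run a single one; this class runs them
+    all in parallel and aggregates the results."""
+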
+ @abstractmethod
+ def identify_tests(self) -> List[Any]:
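+        """Return the list of tests that should be run."""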
+ pass
+
+ @abstractmethod
+ def run_test(self, test: Any) -> TestResults:
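+        """Run a single test; return its TestResults."""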
+ pass
+
+ def check_for_abort(self):
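+        """Raise if the halt event is set or an error demands an early exit."""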
+ if self.params.halt_event.is_set():
+ logger.debug('Thread %s saw halt event; exiting.', self.get_name())
+ raise Exception("Kill myself!")
+ if not self.test_results.normal_exit:
+ if self.params.halt_on_error:
+ logger.error('Thread %s saw abnormal results; exiting.', self.get_name())
+ raise Exception("Kill myself!")
+
+ def status_report(self, running: List[Any], done: List[Any]):
+ total = len(running) + len(done)
+        logger.info(
+ '%s: %d/%d in flight; %d/%d completed.',
+ self.get_name(),
+ len(running),
+ total,
+ len(done),
+ total,
+ )
+
+ @overrides
+ def begin(self, params: TestingParameters) -> TestResults:
+ logger.debug('Thread %s started.', self.get_name())
+ interesting_tests = self.identify_tests()
+ running: List[Any] = []
+ done: List[Any] = []
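+        # Subclasses decorate run_test with @par.parallelize, so each call
+        # below returns a future immediately; the loop that follows polls
+        # and resolves those futures as they finish.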
+ for test in interesting_tests:
+ running.append(self.run_test(test))
+
+        while running:
+ self.status_report(running, done)
+ self.check_for_abort()
+ newly_finished = []
+ for fut in running:
+ if fut.is_ready():
+ newly_finished.append(fut)
+ result = fut._resolve()
+ logger.debug('Test %s finished.', result.name)
+ self.aggregate_test_results(result)
+
+ for fut in newly_finished:
+ running.remove(fut)
+ done.append(fut)
+ time.sleep(0.25)
+
+ logger.debug('Thread %s finished.', self.get_name())
+ return self.test_results
+
+
+class UnittestTestRunner(TemplatedTestRunner):
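+    """Runs every file matching *_test.py as a unittest."""
+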
+ @overrides
+ def get_name(self) -> str:
+ return "UnittestTestRunner"
+
+ @overrides
+ def identify_tests(self) -> List[Any]:
+ return list(file_utils.expand_globs('*_test.py'))
+
+ @par.parallelize
+ def run_test(self, test: Any) -> TestResults:
+ if config.config['coverage']:
+ cmdline = f'coverage run --source {HOME}/lib --append {test} --unittests_ignore_perf'
+ else:
+ cmdline = test
+
+ try:
+ logger.debug('Running unittest %s (%s)', test, cmdline)
+ output = exec_utils.cmd(
+ cmdline,
+ timeout_seconds=120.0,
+ )
+ except TimeoutError:
+ logger.error('Unittest %s timed out; ran for > 120.0 seconds', test)
+ return TestResults(
+ test,
+ 1,
+ 0,
+ 1,
+ False,
+ f"Unittest {test} timed out.",
+ )
+ except subprocess.CalledProcessError:
+ logger.error('Unittest %s failed.', test)
+ return TestResults(
+ test,
+ 1,
+ 0,
+ 1,
+ False,
+ f"Unittest {test} failed.",
+ )
+ return TestResults(test, 1, 1, 0, True, output)
+
+
+class DoctestTestRunner(TemplatedTestRunner):
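+    """Runs the doctests in any module that imports doctest."""
+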
+ @overrides
+ def get_name(self) -> str:
+ return "DoctestTestRunner"
+
+ @overrides
+ def identify_tests(self) -> List[Any]:
+ ret = []
+        out = exec_utils.cmd(f'grep -lR "^ *import doctest" {HOME}/lib/python_modules/*')
+ for line in out.split('\n'):
+ if re.match(r'.*\.py$', line):
+ if 'run_tests.py' not in line:
+ ret.append(line)
+ return ret
+
+ @par.parallelize
+ def run_test(self, test: Any) -> TestResults:
+ if config.config['coverage']:
+ cmdline = f'coverage run --source {HOME}/lib --append {test} 2>&1'
+ else:
+ cmdline = f'python3 {test}'
+ try:
+ logger.debug('Running doctest %s (%s).', test, cmdline)
+ output = exec_utils.cmd(
+ cmdline,
+ timeout_seconds=120.0,
+ )
+ except TimeoutError:
+ logger.error('Doctest %s timed out; ran for > 120.0 seconds', test)
+ return TestResults(
+ test,
+ 1,
+ 0,
+ 1,
+ False,
+ f"Doctest {test} timed out.",
+ )
+ except subprocess.CalledProcessError:
+ logger.error('Doctest %s failed.', test)
+ return TestResults(
+ test,
+ 1,
+ 0,
+ 1,
+ False,
+ f"Docttest {test} failed.",
+ )
+ return TestResults(
+ test,
+ 1,
+ 1,
+ 0,
+ True,
+ "",
+ )
+
+
+class IntegrationTestRunner(TemplatedTestRunner):
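+    """Runs every file matching *_itest.py as an integration test."""
+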
+ @overrides
+ def get_name(self) -> str:
+ return "IntegrationTestRunner"
+
+ @overrides
+ def identify_tests(self) -> List[Any]:
+ return list(file_utils.expand_globs('*_itest.py'))
+
+ @par.parallelize
+ def run_test(self, test: Any) -> TestResults:
+ if config.config['coverage']:
+ cmdline = f'coverage run --source {HOME}/lib --append {test}'
+ else:
+ cmdline = test
+ try:
+ logger.debug('Running integration test %s (%s).', test, cmdline)
+ output = exec_utils.cmd(
+ cmdline,
+ timeout_seconds=240.0,
+ )
+ except TimeoutError:
+ logger.error('Integration Test %s timed out; ran for > 240.0 seconds', test)
+ return TestResults(
+ test,
+ 1,
+ 0,
+ 1,
+ False,
+ f"Integration Test {test} timed out.",
+ )
+ except subprocess.CalledProcessError:
+ logger.error('Integration Test %s failed.', test)
+ return TestResults(
+ test,
+ 1,
+ 0,
+ 1,
+ False,
+ f"Integration Test {test} failed.",
+ )
+ return TestResults(
+ test,
+ 1,
+ 1,
+ 0,
+ True,
+ "",
+ )
+
+
+def test_results_report(results: Dict[str, TestResults]) -> int:
+    total_problems = 0
+    for result in results.values():
+        print(result, end='\n\n')
+        total_problems += result.num_tests_failed
+    return total_problems
+
+
+def code_coverage_report():
+ text_utils.header('Code Coverage')
+ out = exec_utils.cmd('coverage report --omit=config-3.8.py,*_test.py,*_itest.py --sort=-cover')
+ print(out)
+ print(
+ """
+To recall this report w/o re-running the tests:
+
+ $ coverage report --omit=config-3.8.py,*_test.py,*_itest.py --sort=-cover
+
+...from the 'tests' directory. Note that subsequent calls to
+run_tests.py with --coverage will clobber previous results. See:
+
+ https://coverage.readthedocs.io/en/6.2/
+
+"""
+ )
+
+
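+# Assumption: bootstrap.initialize is the decorator that parses the command
+# line (populating config.config) and sets up logging before calling main;
+# without it, the flags registered above are never parsed.
+@bootstrap.initialize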
+def main() -> Optional[int]:
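+    """Run whichever test suites the flags requested; report the results."""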
+ saw_flag = False
+ halt_event = threading.Event()
+ threads: List[TestRunner] = []
+
+ halt_event.clear()
+ params = TestingParameters(
+ halt_on_error=True,
+ halt_event=halt_event,
+ )
+
+ if config.config['coverage']:
+ logger.debug('Clearing existing coverage data via "coverage erase".')
+ exec_utils.cmd('coverage erase')
+
+ if config.config['unittests']:
+ saw_flag = True
+ threads.append(UnittestTestRunner(params))
+ if config.config['doctests']:
+ saw_flag = True
+ threads.append(DoctestTestRunner(params))
+ if config.config['integration']:
+ saw_flag = True
+ threads.append(IntegrationTestRunner(params))
+
+ if not saw_flag:
+ config.print_usage()
+ print('ERROR: one of --unittests, --doctests or --integration is required.')
+ return 1
+
+ for thread in threads:
+ thread.start()
+
+ results: Dict[str, TestResults] = {}
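+    # Poll the runner threads until each has delivered its TestResults.
+    # If any thread reports an abnormal exit, set halt_event so the other
+    # threads stop early (see check_for_abort).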
+ while len(results) != len(threads):
+ for thread in threads:
+ if not thread.is_alive():
+ tid = thread.name
+ if tid not in results:
+ result = thread.join()
+ if result:
+ results[tid] = result
+ if not result.normal_exit:
+ halt_event.set()
+ time.sleep(1.0)
+
+    total_problems = test_results_report(results)
+ if config.config['coverage']:
+ code_coverage_report()
+    return total_problems
+
+
+if __name__ == '__main__':
+ main()