Final touches on the new test runner.
[python_utils.git] / tests / run_tests.py
1 #!/usr/bin/env python3
2
3 """
4 A smart, fast test runner.
5 """
6
7 import logging
8 import os
9 import re
10 import subprocess
11 import threading
12 import time
13 from abc import ABC, abstractmethod
14 from dataclasses import dataclass
15 from typing import Any, Dict, List, Optional
16
17 from overrides import overrides
18
19 import ansi
20 import bootstrap
21 import config
22 import exec_utils
23 import file_utils
24 import parallelize as par
25 import text_utils
26 import thread_utils
27
logger = logging.getLogger(__name__)
# Both strings are f-strings so the help text names this file rather than
# printing the literal text "__file__".
args = config.add_commandline_args(f'({__file__})', f'Args related to {__file__}')
args.add_argument('--unittests', '-u', action='store_true', help='Run unittests.')
args.add_argument('--doctests', '-d', action='store_true', help='Run doctests.')
args.add_argument('--integration', '-i', action='store_true', help='Run integration tests.')
args.add_argument(
    '--coverage', '-c', action='store_true', help='Run tests and capture code coverage data'
)

# The invoking user's home directory; library paths ($HOME/lib/...) used by
# the runners below are built from it.
HOME = os.environ['HOME']
38
39
@dataclass
class TestingParameters:
    """Settings shared by every test-runner thread in one invocation."""

    # When True, a runner stops polling once it has recorded a failed test.
    halt_on_error: bool
    # Shared across runners; once set, runners stop at their next abort check.
    halt_event: threading.Event
44
45
@dataclass
class TestResults:
    """Outcome of running a single test or a whole family of tests.

    Every list holds test names.  For the single-test results built by
    TemplatedTestRunner.execute_commandline, the test appears in
    ``tests_executed`` and in exactly one of the other three lists.
    Field order matters: callers construct this positionally.
    """

    name: str  # the test, or the runner, these results describe
    tests_executed: List[str]
    tests_succeeded: List[str]
    tests_failed: List[str]
    tests_timed_out: List[str]
52     tests_timed_out: List[str]
53
54
class TestRunner(ABC, thread_utils.ThreadWithReturnValue):
    """Abstract base for a thread that runs one family of tests.

    The thread's target is :meth:`begin`, invoked with the TestingParameters
    passed to the constructor.  main() expects ``join()`` to hand back the
    TestResults that ``begin`` returns (presumably provided by
    thread_utils.ThreadWithReturnValue -- confirm there).
    """

    def __init__(self, params: TestingParameters):
        # super().__init__ is already bound to this instance, so ``self`` must
        # not be passed again: an extra positional argument would be taken as
        # the next parameter (``group`` for threading.Thread, which must be None).
        super().__init__(target=self.begin, args=[params])
        self.params = params
        # Running totals for every test this runner executes.
        self.test_results = TestResults(
            name=self.get_name(),
            tests_executed=[],
            tests_succeeded=[],
            tests_failed=[],
            tests_timed_out=[],
        )

    def aggregate_test_results(self, result: TestResults):
        """Append the test names in ``result`` to this runner's running totals."""
        self.test_results.tests_executed.extend(result.tests_executed)
        self.test_results.tests_succeeded.extend(result.tests_succeeded)
        self.test_results.tests_failed.extend(result.tests_failed)
        self.test_results.tests_timed_out.extend(result.tests_timed_out)

    @abstractmethod
    def get_name(self) -> str:
        """Return this runner's display name (also used as its TestResults name)."""
        pass

    @abstractmethod
    def begin(self, params: TestingParameters) -> TestResults:
        """Thread body: run the tests and return their aggregated results."""
        pass
80
81
class TemplatedTestRunner(TestRunner, ABC):
    """A TestRunner that identifies tests, launches them, and polls until done.

    Subclasses provide :meth:`identify_tests` (which tests to run) and
    :meth:`run_test` (start one test).  :meth:`begin` treats each value
    returned by :meth:`run_test` as a future-like object exposing
    ``is_ready()`` and ``_resolve()`` -- the subclasses in this file get that
    by decorating ``run_test`` with ``par.parallelize`` -- and aggregates the
    TestResults each one resolves to.
    """

    class AbortTesting(Exception):
        """Raised by check_for_abort() to stop this runner's polling loop."""

    @abstractmethod
    def identify_tests(self) -> List[Any]:
        """Return the tests (e.g. file paths) this runner should execute."""
        pass

    @abstractmethod
    def run_test(self, test: Any) -> TestResults:
        """Start running ``test``; see the class docstring for what begin() expects."""
        pass

    def check_for_abort(self):
        """Raise AbortTesting if this runner should stop early.

        That is the case when the shared halt event is set, or when
        halt_on_error is enabled and this runner has recorded a failed test.
        """
        if self.params.halt_event.is_set():
            logger.debug('Thread %s saw halt event; exiting.', self.get_name())
            raise self.AbortTesting("Kill myself!")
        if self.params.halt_on_error and len(self.test_results.tests_failed) > 0:
            logger.error('Thread %s saw abnormal results; exiting.', self.get_name())
            raise self.AbortTesting("Kill myself!")

    def status_report(self, running: List[Any], done: List[Any]):
        """Log how many of this runner's tests are in flight vs. completed."""
        total = len(running) + len(done)
        # Use the module logger, like the rest of this file, rather than the
        # root logger.
        logger.info(
            '%s: %d/%d in flight; %d/%d completed.',
            self.get_name(),
            len(running),
            total,
            len(done),
            total,
        )

    def persist_output(self, test_name: str, message: str, output: Any) -> None:
        """Write a status message and a test's captured output to ./test_output.

        The file is named ``<basename of test_name>-output.txt``; it holds
        ``message``, a dashed underline, then ``output``.

        Args:
            test_name: path/name of the test; only its basename is used.
            message: one-line status written at the top of the file.
            output: captured output.  May be ``str``, ``bytes`` (subprocess
                exceptions carry raw captured output) or ``None`` (e.g.
                ``TimeoutExpired.output`` when nothing was captured).
        """
        if output is None:
            output = ''
        elif isinstance(output, bytes):
            output = output.decode('utf-8', errors='replace')
        basename = file_utils.without_path(test_name)
        # The directory may not exist on a fresh checkout.
        os.makedirs('./test_output', exist_ok=True)
        with open(f'./test_output/{basename}-output.txt', 'w', encoding='utf-8') as wf:
            print(message, file=wf)
            print('-' * len(message), file=wf)
            wf.write(output)

    def execute_commandline(
        self,
        test_name: str,
        cmdline: str,
        *,
        timeout: float = 120.0,
    ) -> TestResults:
        """Run ``cmdline`` via exec_utils.cmd and classify the outcome of ``test_name``.

        Args:
            test_name: name recorded in the returned TestResults and used to
                name the file written by persist_output().
            cmdline: command string handed to exec_utils.cmd.
            timeout: seconds allowed before the run counts as timed out.

        Returns:
            A single-test TestResults: ``test_name`` is listed as executed and
            as exactly one of succeeded / timed out
            (subprocess.TimeoutExpired) / failed (subprocess.CalledProcessError).
            Any other exception propagates.
        """
        try:
            logger.debug('%s: Running %s (%s)', self.get_name(), test_name, cmdline)
            output = exec_utils.cmd(
                cmdline,
                timeout_seconds=timeout,
            )
            self.persist_output(test_name, f'{test_name} ({cmdline}) succeeded.', output)
            logger.debug('%s (%s) succeeded', test_name, cmdline)
            return TestResults(test_name, [test_name], [test_name], [], [])
        except subprocess.TimeoutExpired as e:
            msg = f'{self.get_name()}: {test_name} ({cmdline}) timed out after {e.timeout:.1f} seconds.'
            logger.error(msg)
            logger.debug(
                '%s: %s output when it timed out: %s', self.get_name(), test_name, e.output
            )
            self.persist_output(test_name, msg, e.output)
            return TestResults(
                name=test_name,
                tests_executed=[test_name],
                tests_succeeded=[],
                tests_failed=[],
                tests_timed_out=[test_name],
            )
        except subprocess.CalledProcessError as e:
            msg = f'{self.get_name()}: {test_name} ({cmdline}) failed; exit code {e.returncode}'
            logger.error(msg)
            logger.debug('%s: %s output when it failed: %s', self.get_name(), test_name, e.output)
            self.persist_output(test_name, msg, e.output)
            return TestResults(
                name=test_name,
                tests_executed=[test_name],
                tests_succeeded=[],
                tests_failed=[test_name],
                tests_timed_out=[],
            )

    @overrides
    def begin(self, params: TestingParameters) -> TestResults:
        """Launch every identified test, poll until all finish, return the totals.

        ``params`` is the same object stored as ``self.params`` by the
        constructor; this method reads ``self.params``.  If check_for_abort()
        signals a halt, polling stops and the partial results gathered so far
        are returned, so the caller's ``join()`` still gets a TestResults.
        Tests still in flight at that point are not waited for.
        """
        logger.debug('Thread %s started.', self.get_name())
        running: List[Any] = [self.run_test(test) for test in self.identify_tests()]
        done: List[Any] = []

        try:
            while running:
                self.status_report(running, done)
                self.check_for_abort()
                newly_finished = [fut for fut in running if fut.is_ready()]
                for fut in newly_finished:
                    result = fut._resolve()
                    logger.debug('Test %s finished.', result.name)
                    self.aggregate_test_results(result)
                    running.remove(fut)
                    done.append(fut)
                time.sleep(1.0)
        except self.AbortTesting:
            logger.warning(
                'Thread %s aborted with %d test(s) still running; returning partial results.',
                self.get_name(),
                len(running),
            )

        logger.debug('Thread %s finished.', self.get_name())
        return self.test_results
190
191
class UnittestTestRunner(TemplatedTestRunner):
    """Runs every unit test file matching ``*_test.py``."""

    @overrides
    def get_name(self) -> str:
        """Name used in log messages and as this runner's TestResults name."""
        return "UnittestTestRunner"

    @overrides
    def identify_tests(self) -> List[Any]:
        """Return the files that file_utils.expand_globs matches for ``*_test.py``."""
        matches = file_utils.expand_globs('*_test.py')
        return list(matches)

    @par.parallelize
    def run_test(self, test: Any) -> TestResults:
        """Execute one unit test file, under ``coverage run`` when --coverage is set."""
        # NOTE(review): --unittests_ignore_perf is only passed under coverage,
        # presumably because instrumentation skews timing checks -- confirm.
        cmdline = (
            f'coverage run --source {HOME}/lib {test} --unittests_ignore_perf'
            if config.config['coverage']
            else test
        )
        return self.execute_commandline(test, cmdline)
208
209
class DoctestTestRunner(TemplatedTestRunner):
    """Runs the doctests embedded in modules under $HOME/lib/python_modules."""

    @overrides
    def get_name(self) -> str:
        """Name used in log messages and as this runner's TestResults name."""
        return "DoctestTestRunner"

    @overrides
    def identify_tests(self) -> List[Any]:
        """Return ``.py`` files under $HOME/lib/python_modules that import doctest.

        Files whose path contains ``run_tests.py`` are excluded.
        """
        # Search under the invoking user's HOME (the same root used for the
        # coverage --source path) instead of one hard-coded home directory.
        out = exec_utils.cmd(f'grep -lR "^ *import doctest" {HOME}/lib/python_modules/*')
        return [
            line
            for line in out.split('\n')
            if line.endswith('.py') and 'run_tests.py' not in line
        ]

    @par.parallelize
    def run_test(self, test: Any) -> TestResults:
        """Execute one module's doctests, under ``coverage run`` when --coverage is set."""
        if config.config['coverage']:
            cmdline = f'coverage run --source {HOME}/lib {test} 2>&1'
        else:
            cmdline = f'python3 {test}'
        return self.execute_commandline(test, cmdline)
232
233
class IntegrationTestRunner(TemplatedTestRunner):
    """Runs every integration test file matching ``*_itest.py``."""

    @overrides
    def get_name(self) -> str:
        """Name used in log messages and as this runner's TestResults name."""
        return "IntegrationTestRunner"

    @overrides
    def identify_tests(self) -> List[Any]:
        """Return the files that file_utils.expand_globs matches for ``*_itest.py``."""
        found = file_utils.expand_globs('*_itest.py')
        return list(found)

    @par.parallelize
    def run_test(self, test: Any) -> TestResults:
        """Execute one integration test file, under ``coverage run`` when --coverage is set."""
        if not config.config['coverage']:
            return self.execute_commandline(test, test)
        return self.execute_commandline(test, f'coverage run --source {HOME}/lib {test}')
250
251
def test_results_report(results: Dict[str, TestResults]) -> int:
    """Print a per-runner summary and return the number of problems found.

    A problem is a failed or a timed-out test.  The names of failed tests
    and of timed-out tests are each listed under their runner.

    Args:
        results: runner results keyed by thread id (keys are not printed).

    Returns:
        Total count of failed plus timed-out tests.
    """
    total_problems = 0
    for result in results.values():
        print(f'{result.name}: ', end='')
        print(
            f'{ansi.fg("green")}{len(result.tests_succeeded)}/{len(result.tests_executed)} passed{ansi.reset()}.'
        )
        if result.tests_failed:
            print(f'  ..{ansi.fg("red")}{len(result.tests_failed)} tests failed{ansi.reset()}:')
            for test in result.tests_failed:
                print(f'    {test}')
            total_problems += len(result.tests_failed)

        if result.tests_timed_out:
            print(
                f'  ..{ansi.fg("yellow")}{len(result.tests_timed_out)} tests timed out{ansi.reset()}:'
            )
            # List the tests that timed out (not the failed ones a second time).
            for test in result.tests_timed_out:
                print(f'    {test}')
            total_problems += len(result.tests_timed_out)

    if total_problems > 0:
        print('Reminder: look in ./test_output to view test output logs')
    return total_problems
276
277
def code_coverage_report():
    """Combine per-process coverage data and print a coverage report.

    Runs ``coverage combine`` over the ``.coverage*`` files in the current
    directory, prints ``coverage report`` output (``--sort=-cover``), then
    prints the command that regenerates the report without re-running tests.
    """
    # Defined once so the reminder printed below always matches the command
    # that was actually run.
    report_cmd = 'coverage report --omit=config-3.8.py,*_test.py,*_itest.py --sort=-cover'
    text_utils.header('Code Coverage')
    exec_utils.cmd('coverage combine .coverage*')
    out = exec_utils.cmd(report_cmd)
    print(out)
    print(
        f"""
To recall this report w/o re-running the tests:

    $ {report_cmd}

...from the 'tests' directory.  Note that subsequent calls to
run_tests.py with --coverage will clobber previous results.  See:

    https://coverage.readthedocs.io/en/6.2/
"""
    )
295
296
@bootstrap.initialize
def main() -> Optional[int]:
    """Run the requested test families in parallel threads and report on them.

    Returns:
        1 if none of --unittests / --doctests / --integration was given.
        Otherwise, the number of failed and timed-out tests plus the number of
        runner threads that exited without returning any results.
    """
    halt_event = threading.Event()
    params = TestingParameters(
        halt_on_error=True,
        halt_event=halt_event,
    )

    if config.config['coverage']:
        logger.debug('Clearing existing coverage data via "coverage erase".')
        exec_utils.cmd('coverage erase')

    threads: List[TestRunner] = []
    if config.config['unittests']:
        threads.append(UnittestTestRunner(params))
    if config.config['doctests']:
        threads.append(DoctestTestRunner(params))
    if config.config['integration']:
        threads.append(IntegrationTestRunner(params))

    if not threads:
        config.print_usage()
        print('ERROR: one of --unittests, --doctests or --integration is required.')
        return 1

    for thread in threads:
        thread.start()

    results: Dict[str, TestResults] = {}
    runners_without_results = 0
    pending: List[TestRunner] = list(threads)
    while pending:
        for thread in [t for t in pending if not t.is_alive()]:
            pending.remove(thread)
            tid = thread.name
            result = thread.join()
            if result is None:
                # The runner died (e.g. an uncaught exception) without handing
                # back results.  Count it as a problem; otherwise this loop
                # would wait forever for a result that never arrives.
                logger.error('Thread %s exited without returning results.', tid)
                runners_without_results += 1
                continue
            results[tid] = result
            if len(result.tests_failed) > 0:
                logger.error('Thread %s returned abnormal results; killing the others.', tid)
                halt_event.set()
        if pending:
            time.sleep(1.0)

    if config.config['coverage']:
        code_coverage_report()
    total_problems = test_results_report(results)
    return total_problems + runners_without_results
351
352
if __name__ == '__main__':
    import sys

    # Use main()'s return value (problem count, or 1 on a usage error) as the
    # process exit status.  If bootstrap.initialize already exits on its own,
    # this call is simply never reached.
    sys.exit(main())