Code cleanup for run_tests.py
[python_utils.git] / tests / run_tests.py
1 #!/usr/bin/env python3
2
3 """
4 A smart, fast test runner.  Used in a git pre-commit hook.
5 """
6
7 import logging
8 import os
9 import re
10 import subprocess
11 import threading
12 import time
13 from abc import ABC, abstractmethod
14 from dataclasses import dataclass
15 from typing import Any, Dict, List, Optional
16
17 from overrides import overrides
18
19 import ansi
20 import bootstrap
21 import config
22 import exec_utils
23 import file_utils
24 import parallelize as par
25 import text_utils
26 import thread_utils
27
logger = logging.getLogger(__name__)

# Bug fix: the description string was missing its f-prefix, so the literal
# text "__file__" (rather than this file's path) appeared in help output.
args = config.add_commandline_args(f'({__file__})', f'Args related to {__file__}')
args.add_argument('--unittests', '-u', action='store_true', help='Run unittests.')
args.add_argument('--doctests', '-d', action='store_true', help='Run doctests.')
args.add_argument('--integration', '-i', action='store_true', help='Run integration tests.')
args.add_argument(
    '--coverage', '-c', action='store_true', help='Run tests and capture code coverage data'
)

# The invoking user's home directory; raises KeyError if $HOME is unset.
HOME = os.environ['HOME']
38
39
@dataclass
class TestingParameters:
    """Parameters shared by all test runner threads."""

    halt_on_error: bool
    """Should we stop as soon as one error has occurred?"""

    halt_event: threading.Event
    """An event that, when set, indicates to stop ASAP."""
47
48
@dataclass
class TestResults:
    """A mutable tally of the tests one runner executed and their fates."""

    name: str
    """The name of this test / set of tests."""

    tests_executed: List[str]
    """Tests that were executed."""

    tests_succeeded: List[str]
    """Tests that succeeded."""

    tests_failed: List[str]
    """Tests that failed."""

    tests_timed_out: List[str]
    """Tests that timed out."""

    def __add__(self, other):
        """Merge other's tallies into self (mutates self) and return self."""
        self.tests_executed.extend(other.tests_executed)
        self.tests_succeeded.extend(other.tests_succeeded)
        self.tests_failed.extend(other.tests_failed)
        self.tests_timed_out.extend(other.tests_timed_out)
        return self

    __radd__ = __add__

    def __repr__(self) -> str:
        """Render a human-readable, ANSI-colorized summary of the results."""
        out = f'{self.name}: '
        out += f'{ansi.fg("green")}'
        out += f'{len(self.tests_succeeded)}/{len(self.tests_executed)} passed'
        out += f'{ansi.reset()}.\n'

        if len(self.tests_failed) > 0:
            out += f'  ..{ansi.fg("red")}'
            out += f'{len(self.tests_failed)} tests failed'
            out += f'{ansi.reset()}:\n'
            for test in self.tests_failed:
                out += f'    {test}\n'
            out += '\n'

        if len(self.tests_timed_out) > 0:
            out += f'  ..{ansi.fg("yellow")}'
            out += f'{len(self.tests_timed_out)} tests timed out'
            out += f'{ansi.reset()}:\n'
            # Bug fix: this loop formerly iterated self.tests_failed, listing
            # the wrong test names under the "timed out" header.
            for test in self.tests_timed_out:
                out += f'    {test}\n'
            out += '\n'
        return out
97
98
class TestRunner(ABC, thread_utils.ThreadWithReturnValue):
    """A base class for something that runs a test."""

    def __init__(self, params: TestingParameters):
        """Create a TestRunner.

        Args:
            params: Test running parameters.

        """
        # Bug fix: this formerly called super().__init__(self, ...), passing
        # self explicitly even though super().__init__ already binds it; that
        # shifted self into the first positional (thread group) parameter.
        super().__init__(target=self.begin, args=[params])
        self.params = params
        # Start with an empty tally; begin() fills it in as tests complete.
        self.test_results = TestResults(
            name=self.get_name(),
            tests_executed=[],
            tests_succeeded=[],
            tests_failed=[],
            tests_timed_out=[],
        )

    @abstractmethod
    def get_name(self) -> str:
        """The name of this test collection."""
        pass

    @abstractmethod
    def begin(self, params: TestingParameters) -> TestResults:
        """Start execution."""
        pass
129
class TemplatedTestRunner(TestRunner, ABC):
    """A TestRunner that has a recipe for executing the tests."""

    @abstractmethod
    def identify_tests(self) -> List[str]:
        """Return a list of tests that should be executed."""
        pass

    @abstractmethod
    def run_test(self, test: Any) -> TestResults:
        """Run a single test and return its TestResults."""
        pass

    def check_for_abort(self):
        """Periodically called to check to see if we need to stop.

        Raises:
            Exception: when the shared halt event is set, or when
                halt_on_error is enabled and a test has already failed.
        """

        if self.params.halt_event.is_set():
            logger.debug('Thread %s saw halt event; exiting.', self.get_name())
            raise Exception("Kill myself!")
        if self.params.halt_on_error:
            if len(self.test_results.tests_failed) > 0:
                logger.error('Thread %s saw abnormal results; exiting.', self.get_name())
                raise Exception("Kill myself!")

    def status_report(self, running: List[Any], done: List[Any]):
        """Periodically called to report current status."""

        total = len(running) + len(done)
        # Bug fix: this formerly logged via the root logger (logging.info)
        # rather than the module-level logger used everywhere else here.
        logger.info(
            '%s: %d/%d in flight; %d/%d completed.',
            self.get_name(),
            len(running),
            total,
            len(done),
            total,
        )

    def persist_output(self, test_name: str, message: str, output: str) -> None:
        """Save a test's output under ./test_output for post-mortem review."""

        basename = file_utils.without_path(test_name)
        dest = f'{basename}-output.txt'
        # Robustness: make sure the output directory exists before writing;
        # previously this raised FileNotFoundError on a fresh checkout.
        os.makedirs('./test_output', exist_ok=True)
        with open(f'./test_output/{dest}', 'w') as wf:
            print(message, file=wf)
            print('-' * len(message), file=wf)
            wf.write(output)

    def execute_commandline(
        self,
        test_name: str,
        cmdline: str,
        *,
        timeout: float = 120.0,
    ) -> TestResults:
        """Execute a particular commandline to run a test.

        Args:
            test_name: the name to report this test under.
            cmdline: the shell commandline to execute.
            timeout: seconds to wait before declaring the test hung.

        Returns:
            A TestResults recording this one test as succeeded, failed,
            or timed out; its output is persisted either way.
        """

        try:
            logger.debug('%s: Running %s (%s)', self.get_name(), test_name, cmdline)
            output = exec_utils.cmd(
                cmdline,
                timeout_seconds=timeout,
            )
            self.persist_output(test_name, f'{test_name} ({cmdline}) succeeded.', output)
            logger.debug('%s (%s) succeeded', test_name, cmdline)
            return TestResults(test_name, [test_name], [test_name], [], [])
        except subprocess.TimeoutExpired as e:
            msg = f'{self.get_name()}: {test_name} ({cmdline}) timed out after {e.timeout:.1f} seconds.'
            logger.error(msg)
            logger.debug(
                '%s: %s output when it timed out: %s', self.get_name(), test_name, e.output
            )
            self.persist_output(test_name, msg, e.output)
            return TestResults(
                test_name,
                [test_name],
                [],
                [],
                [test_name],
            )
        except subprocess.CalledProcessError as e:
            msg = f'{self.get_name()}: {test_name} ({cmdline}) failed; exit code {e.returncode}'
            logger.error(msg)
            logger.debug('%s: %s output when it failed: %s', self.get_name(), test_name, e.output)
            self.persist_output(test_name, msg, e.output)
            return TestResults(
                test_name,
                [test_name],
                [],
                [test_name],
                [],
            )

    @overrides
    def begin(self, params: TestingParameters) -> TestResults:
        """Run all identified tests in parallel and aggregate their results."""
        logger.debug('Thread %s started.', self.get_name())
        interesting_tests = self.identify_tests()
        running: List[Any] = []
        done: List[Any] = []
        # run_test is @par.parallelize'd, so each call returns a future-like
        # object immediately rather than a TestResults.
        for test in interesting_tests:
            running.append(self.run_test(test))

        # Poll the in-flight futures once a second until they all resolve.
        while len(running) > 0:
            self.status_report(running, done)
            self.check_for_abort()
            newly_finished = []
            for fut in running:
                if fut.is_ready():
                    newly_finished.append(fut)
                    # NOTE(review): reaches into the future's private
                    # _resolve(); presumably unwraps the TestResults payload.
                    result = fut._resolve()
                    logger.debug('Test %s finished.', result.name)
                    self.test_results += result

            for fut in newly_finished:
                running.remove(fut)
                done.append(fut)
            time.sleep(1.0)

        logger.debug('Thread %s finished.', self.get_name())
        return self.test_results
249
250
class UnittestTestRunner(TemplatedTestRunner):
    """Run all known unittests (i.e. files matching *_test.py)."""

    @overrides
    def get_name(self) -> str:
        return "UnittestTestRunner"

    @overrides
    def identify_tests(self) -> List[str]:
        # Every *_test.py file in the current directory is a unittest.
        return list(file_utils.expand_globs('*_test.py'))

    @par.parallelize
    def run_test(self, test: Any) -> TestResults:
        # Under --coverage, wrap the invocation with coverage(1).
        cmdline = (
            f'coverage run --source {HOME}/lib {test} --unittests_ignore_perf'
            if config.config['coverage']
            else test
        )
        return self.execute_commandline(test, cmdline)
269
270
class DoctestTestRunner(TemplatedTestRunner):
    """Run all known doctests."""

    @overrides
    def get_name(self) -> str:
        return "DoctestTestRunner"

    @overrides
    def identify_tests(self) -> List[str]:
        """Find python modules that import doctest; each is a doctest suite.

        Bug fix: the grep path was hard coded to /home/scott/...; use the
        module-level HOME constant (like the other runners) so this works
        for any user.
        """
        out = exec_utils.cmd(f'grep -lR "^ *import doctest" {HOME}/lib/python_modules/*')
        ret = []
        for line in out.split('\n'):
            # grep -l emits one filename per line; keep python files but
            # skip this test runner itself.
            if line.endswith('.py') and 'run_tests.py' not in line:
                ret.append(line)
        return ret

    @par.parallelize
    def run_test(self, test: Any) -> TestResults:
        # Under --coverage, wrap the invocation with coverage(1).
        if config.config['coverage']:
            cmdline = f'coverage run --source {HOME}/lib {test} 2>&1'
        else:
            cmdline = f'python3 {test}'
        return self.execute_commandline(test, cmdline)
295
296
class IntegrationTestRunner(TemplatedTestRunner):
    """Run all known integration tests (i.e. files matching *_itest.py)."""

    @overrides
    def get_name(self) -> str:
        return "IntegrationTestRunner"

    @overrides
    def identify_tests(self) -> List[str]:
        # Every *_itest.py file in the current directory is an integration test.
        return list(file_utils.expand_globs('*_itest.py'))

    @par.parallelize
    def run_test(self, test: Any) -> TestResults:
        # Under --coverage, wrap the invocation with coverage(1).
        cmdline = (
            f'coverage run --source {HOME}/lib {test}'
            if config.config['coverage']
            else test
        )
        return self.execute_commandline(test, cmdline)
315
316
317 def test_results_report(results: Dict[str, TestResults]) -> int:
318     """Give a final report about the tests that were run."""
319     total_problems = 0
320     for result in results.values():
321         print(result, end='')
322         total_problems += len(result.tests_failed)
323         total_problems += len(result.tests_timed_out)
324
325     if total_problems > 0:
326         print('Reminder: look in ./test_output to view test output logs')
327     return total_problems
328
329
def code_coverage_report():
    """Give a final code coverage report.

    Combines the per-process .coverage* data files and prints a report
    sorted by descending coverage, omitting the test files themselves.
    """
    text_utils.header('Code Coverage')
    # Merge the .coverage.* fragments written by the parallel test runs.
    exec_utils.cmd('coverage combine .coverage*')
    out = exec_utils.cmd('coverage report --omit=config-3.8.py,*_test.py,*_itest.py --sort=-cover')
    print(out)
    print(
        """
To recall this report w/o re-running the tests:

    $ coverage report --omit=config-3.8.py,*_test.py,*_itest.py --sort=-cover

...from the 'tests' directory.  Note that subsequent calls to
run_tests.py with --coverage will klobber previous results.  See:

    https://coverage.readthedocs.io/en/6.2/
"""
    )
348
349
@bootstrap.initialize
def main() -> Optional[int]:
    """Parse flags, run the requested test suites in parallel runner
    threads, and return the total problem count (0 == success)."""
    saw_flag = False
    halt_event = threading.Event()
    threads: List[TestRunner] = []

    halt_event.clear()  # NOTE(review): redundant; a fresh Event starts unset.
    params = TestingParameters(
        halt_on_error=True,
        halt_event=halt_event,
    )

    if config.config['coverage']:
        logger.debug('Clearing existing coverage data via "coverage erase".')
        exec_utils.cmd('coverage erase')

    # One runner thread per requested test category.
    if config.config['unittests']:
        saw_flag = True
        threads.append(UnittestTestRunner(params))
    if config.config['doctests']:
        saw_flag = True
        threads.append(DoctestTestRunner(params))
    if config.config['integration']:
        saw_flag = True
        threads.append(IntegrationTestRunner(params))

    if not saw_flag:
        config.print_usage()
        print('ERROR: one of --unittests, --doctests or --integration is required.')
        return 1

    for thread in threads:
        thread.start()

    # Poll until every runner thread has deposited its results.  On the
    # first thread reporting failures, set halt_event so the others bail
    # out early (see TemplatedTestRunner.check_for_abort).
    results: Dict[str, TestResults] = {}
    while len(results) != len(threads):
        for thread in threads:
            if not thread.is_alive():
                tid = thread.name
                if tid not in results:
                    result = thread.join()
                    # NOTE(review): if join() ever returns a falsy result for
                    # a dead thread, this loop never terminates — confirm
                    # ThreadWithReturnValue always yields a TestResults.
                    if result:
                        results[tid] = result
                        if len(result.tests_failed) > 0:
                            logger.error(
                                'Thread %s returned abnormal results; killing the others.', tid
                            )
                            halt_event.set()
        time.sleep(1.0)

    if config.config['coverage']:
        code_coverage_report()
    total_problems = test_results_report(results)
    return total_problems
404
405
406 if __name__ == '__main__':
407     main()