f1266e0962cfb44027c3c0cfa2a5a0a98354bf49
[python_utils.git] / tests / run_tests.py
1 #!/usr/bin/env python3
2
3 """
4 A smart, fast test runner.  Used in a git pre-commit hook.
5 """
6
7 import logging
8 import os
9 import re
10 import subprocess
11 import threading
12 import time
13 from abc import ABC, abstractmethod
14 from dataclasses import dataclass
15 from typing import Any, Dict, List, Optional
16
17 from overrides import overrides
18
19 import ansi
20 import bootstrap
21 import config
22 import exec_utils
23 import file_utils
24 import parallelize as par
25 import smart_future
26 import text_utils
27 import thread_utils
28
logger = logging.getLogger(__name__)

# Command-line flags for this runner, registered with the project's config
# module (presumably parsed by bootstrap.initialize before main() runs).
args = config.add_commandline_args(f'({__file__})', f'Args related to {__file__}')
args.add_argument('--unittests', '-u', action='store_true', help='Run unittests.')
args.add_argument('--doctests', '-d', action='store_true', help='Run doctests.')
args.add_argument('--integration', '-i', action='store_true', help='Run integration tests.')
args.add_argument(
    '--coverage', '-c', action='store_true', help='Run tests and capture code coverage data'
)

# The user's home directory; coverage runs restrict --source to $HOME/lib.
# expanduser() returns $HOME when it is set and falls back to the password
# database otherwise, instead of raising KeyError at import time.
HOME = os.path.expanduser('~')
39
40
@dataclass
class TestingParameters:
    """Settings shared by every test-runner thread that govern early exit."""

    # When True, a runner stops as soon as any of its tests has failed.
    halt_on_error: bool

    # Shared cross-thread signal; once set, every runner should stop ASAP.
    halt_event: threading.Event
48
49
@dataclass
class TestResults:
    """Outcome of one test or of a whole collection of tests.

    Every list holds test names.  ``tests_executed`` lists everything that
    was run; the runners in this file also put each executed test into one
    of ``tests_succeeded``, ``tests_failed`` or ``tests_timed_out``.

    ``a + b`` returns a new combined object (neither operand is modified),
    ``a += b`` extends ``a`` in place, and ``sum(list_of_results)`` works.
    The combined object keeps the left operand's ``name``.
    """

    # The name of this test / set of tests.
    name: str

    # Tests that were executed.
    tests_executed: List[str]

    # Tests that succeeded.
    tests_succeeded: List[str]

    # Tests that failed.
    tests_failed: List[str]

    # Tests that timed out.
    tests_timed_out: List[str]

    def __add__(self, other: 'TestResults') -> 'TestResults':
        """Return a new TestResults holding ``self``'s tests followed by ``other``'s."""
        if not isinstance(other, TestResults):
            return NotImplemented
        return TestResults(
            name=self.name,
            tests_executed=self.tests_executed + other.tests_executed,
            tests_succeeded=self.tests_succeeded + other.tests_succeeded,
            tests_failed=self.tests_failed + other.tests_failed,
            tests_timed_out=self.tests_timed_out + other.tests_timed_out,
        )

    def __iadd__(self, other: 'TestResults') -> 'TestResults':
        """Extend ``self`` in place with ``other``'s tests and return ``self``."""
        if not isinstance(other, TestResults):
            return NotImplemented
        self.tests_executed.extend(other.tests_executed)
        self.tests_succeeded.extend(other.tests_succeeded)
        self.tests_failed.extend(other.tests_failed)
        self.tests_timed_out.extend(other.tests_timed_out)
        return self

    def __radd__(self, other):
        # sum() starts from the integer 0; treat it as the identity element.
        if other == 0:
            return self
        return NotImplemented

    def __repr__(self) -> str:
        """Colorized, human-readable summary used in the final report.

        The first line shows passed/executed counts; failed and timed-out
        tests are then listed by name, each section only when non-empty.
        """
        out = f'{self.name}: '
        out += f'{ansi.fg("green")}'
        out += f'{len(self.tests_succeeded)}/{len(self.tests_executed)} passed'
        out += f'{ansi.reset()}.\n'
        out += self._describe_problems('red', 'tests failed', self.tests_failed)
        out += self._describe_problems('yellow', 'tests timed out', self.tests_timed_out)
        return out

    @staticmethod
    def _describe_problems(color: str, label: str, tests: List[str]) -> str:
        """Format one indented, colorized section listing ``tests``; '' when empty."""
        if not tests:
            return ''
        out = f'  ..{ansi.fg(color)}{len(tests)} {label}{ansi.reset()}:\n'
        out += ''.join(f'    {test}\n' for test in tests)
        return out + '\n'
98
99
class TestRunner(ABC, thread_utils.ThreadWithReturnValue):
    """Base class for a thread that runs one collection of tests.

    The thread's target is :meth:`begin`, called with the TestingParameters
    given to the constructor.  Its return value is presumably surfaced by
    ``ThreadWithReturnValue.join()`` (``main`` reads it that way) -- confirm
    against thread_utils.
    """

    def __init__(self, params: TestingParameters):
        """Create a TestRunner.

        Args:
            params: Test running parameters shared by all runners.
        """
        # Pass keyword arguments only: a stray positional argument would land
        # in the parent's first parameter (``group`` for a threading.Thread
        # style signature) and only works if that parent happens to ignore it.
        super().__init__(target=self.begin, args=[params])
        self.params = params
        # Accumulates this runner's per-test results, named after the runner.
        self.test_results = TestResults(
            name=self.get_name(),
            tests_executed=[],
            tests_succeeded=[],
            tests_failed=[],
            tests_timed_out=[],
        )

    @abstractmethod
    def get_name(self) -> str:
        """The name of this test collection."""
        pass

    @abstractmethod
    def begin(self, params: TestingParameters) -> TestResults:
        """Start execution; the return value becomes the thread's result."""
        pass
129
130
class TemplatedTestRunner(TestRunner, ABC):
    """A TestRunner with a fixed recipe for executing its tests.

    Subclasses say which tests exist (:meth:`identify_tests`) and how to
    launch one (:meth:`run_test`).  :meth:`begin` launches them all, then
    folds in results as they complete, logging progress and honoring the
    halt conditions in ``self.params``.
    """

    @abstractmethod
    def identify_tests(self) -> List[str]:
        """Return a list of tests that should be executed."""
        pass

    @abstractmethod
    def run_test(self, test: Any) -> TestResults:
        """Run a single test and return its TestResults.

        The concrete runners in this file decorate this with
        ``@par.parallelize``; :meth:`begin` treats each return value as a
        future and hands them to ``smart_future.wait_any``.
        """
        pass

    def check_for_abort(self):
        """Periodically called to check whether this runner must stop.

        Raises:
            Exception: if the shared halt event is set, or if
                ``halt_on_error`` is enabled and a test has already failed.
        """
        if self.params.halt_event.is_set():
            logger.debug('Thread %s saw halt event; exiting.', self.get_name())
            raise Exception("Kill myself!")
        if self.params.halt_on_error:
            if len(self.test_results.tests_failed) > 0:
                logger.error('Thread %s saw abnormal results; exiting.', self.get_name())
                raise Exception("Kill myself!")

    def status_report(self, started: int, result: TestResults):
        """Log how many of the ``started`` tests are finished / still in flight.

        Counts are read from ``self.test_results``, so call this after the
        latest result has been folded in.  ``result`` is kept for interface
        compatibility and is not used.
        """
        finished = (
            len(self.test_results.tests_succeeded)
            + len(self.test_results.tests_failed)
            + len(self.test_results.tests_timed_out)
        )
        running = started - finished
        # Guard against division by zero if called before anything started.
        finished_percent = finished / started * 100.0 if started else 100.0
        logger.info(
            '%s: %d/%d in flight; %d/%d finished (%.1f%%).',
            self.get_name(),
            running,
            started,
            finished,
            started,
            finished_percent,
        )

    def persist_output(self, test_name: str, message: str, output: str) -> None:
        """Save a test run's output to ``./test_output/<name>-output.txt``.

        The file holds ``message``, a dashed line of the same length, then
        ``output``.  The ``./test_output`` directory is created if missing.
        """
        basename = file_utils.without_path(test_name)
        dest = f'{basename}-output.txt'
        os.makedirs('./test_output', exist_ok=True)
        with open(f'./test_output/{dest}', 'w', encoding='utf-8') as wf:
            print(message, file=wf)
            print('-' * len(message), file=wf)
            wf.write(output)

    @staticmethod
    def _output_as_text(output: Any) -> str:
        """Best-effort conversion of captured process output to text.

        ``subprocess`` exceptions carry ``None`` when nothing was captured,
        ``bytes`` otherwise; undecodable bytes are replaced, not fatal.
        """
        if output is None:
            return ''
        if isinstance(output, bytes):
            return output.decode('utf-8', errors='replace')
        return str(output)

    def execute_commandline(
        self,
        test_name: str,
        cmdline: str,
        *,
        timeout: float = 120.0,
    ) -> TestResults:
        """Execute a particular commandline to run a test.

        Args:
            test_name: name recorded in the returned TestResults.
            cmdline: command passed to ``exec_utils.cmd``.
            timeout: seconds before the command is considered timed out.

        Returns:
            A single-test TestResults marking ``test_name`` as succeeded,
            failed (non-zero exit) or timed out.  Output is saved via
            :meth:`persist_output` in every case.
        """
        try:
            logger.debug('%s: Running %s (%s)', self.get_name(), test_name, cmdline)
            output = exec_utils.cmd(
                cmdline,
                timeout_seconds=timeout,
            )
        except subprocess.TimeoutExpired as e:
            msg = f'{self.get_name()}: {test_name} ({cmdline}) timed out after {e.timeout:.1f} seconds.'
            logger.error(msg)
            logger.debug(
                '%s: %s output when it timed out: %s', self.get_name(), test_name, e.output
            )
            self.persist_output(test_name, msg, self._output_as_text(e.output))
            return TestResults(
                test_name,
                [test_name],
                [],
                [],
                [test_name],
            )
        except subprocess.CalledProcessError as e:
            msg = f'{self.get_name()}: {test_name} ({cmdline}) failed; exit code {e.returncode}'
            logger.error(msg)
            logger.debug('%s: %s output when it failed: %s', self.get_name(), test_name, e.output)
            self.persist_output(test_name, msg, self._output_as_text(e.output))
            return TestResults(
                test_name,
                [test_name],
                [],
                [test_name],
                [],
            )

        self.persist_output(
            test_name, f'{test_name} ({cmdline}) succeeded.', self._output_as_text(output)
        )
        logger.debug('%s (%s) succeeded', test_name, cmdline)
        return TestResults(test_name, [test_name], [test_name], [], [])

    @overrides
    def begin(self, params: TestingParameters) -> TestResults:
        """Launch every identified test, then gather results as they finish.

        Returns:
            ``self.test_results``, the accumulation of all per-test results.

        Raises:
            Exception: from :meth:`check_for_abort` when a halt condition is
                detected; results gathered so far stay in ``self.test_results``.
        """
        logger.debug('Thread %s started.', self.get_name())
        interesting_tests = self.identify_tests()

        # run_test is @par.parallelize'd in the concrete runners, so these are futures.
        running: List[Any] = [self.run_test(test) for test in interesting_tests]
        started = len(running)

        for future in smart_future.wait_any(running):
            self.check_for_abort()
            result = future._resolve()
            logger.debug('Test %s finished.', result.name)
            # Fold the result in first so the progress report counts this test.
            self.test_results += result
            self.status_report(started, result)

        logger.debug('Thread %s finished.', self.get_name())
        return self.test_results
249
250
class UnittestTestRunner(TemplatedTestRunner):
    """Runs every file matching the ``*_test.py`` glob as a unittest."""

    @overrides
    def get_name(self) -> str:
        return "Unittests"

    @overrides
    def identify_tests(self) -> List[str]:
        return [*file_utils.expand_globs('*_test.py')]

    @par.parallelize
    def run_test(self, test: Any) -> TestResults:
        # Under --coverage the test runs inside `coverage run` and gets the
        # extra --unittests_ignore_perf flag (presumably to skip timing-
        # sensitive checks slowed by instrumentation); otherwise the test
        # file is executed directly.
        cmdline = (
            f'coverage run --source {HOME}/lib {test} --unittests_ignore_perf'
            if config.config['coverage']
            else test
        )
        return self.execute_commandline(test, cmdline)
269
270
class DoctestTestRunner(TemplatedTestRunner):
    """Run all known Doctests."""

    @overrides
    def get_name(self) -> str:
        return "Doctests"

    @overrides
    def identify_tests(self) -> List[str]:
        """Return the .py files under ``$HOME/lib/python_modules`` that import doctest.

        ``grep -lR`` lists every file containing a line matching
        ``^ *import doctest``.  Only paths ending in ``.py`` are kept, and
        this runner script itself is skipped.
        """
        # The search root is derived from HOME (as the coverage --source is)
        # rather than hard-coding one particular user's home directory.
        out = exec_utils.cmd(f'grep -lR "^ *import doctest" {HOME}/lib/python_modules/*')
        return [
            line
            for line in out.split('\n')
            if line.endswith('.py') and 'run_tests.py' not in line
        ]

    @par.parallelize
    def run_test(self, test: Any) -> TestResults:
        # Under --coverage, stderr is merged into stdout (2>&1), presumably so
        # exec_utils.cmd captures it in the saved output as well.
        if config.config['coverage']:
            cmdline = f'coverage run --source {HOME}/lib {test} 2>&1'
        else:
            cmdline = f'python3 {test}'
        return self.execute_commandline(test, cmdline)
295
296
class IntegrationTestRunner(TemplatedTestRunner):
    """Run all known integration tests (files matching ``*_itest.py``)."""

    @overrides
    def get_name(self) -> str:
        return "Integration Tests"

    @overrides
    def identify_tests(self) -> List[str]:
        return [*file_utils.expand_globs('*_itest.py')]

    @par.parallelize
    def run_test(self, test: Any) -> TestResults:
        # Default: execute the test file directly; under --coverage wrap it
        # in `coverage run` restricted to $HOME/lib.
        cmdline = test
        if config.config['coverage']:
            cmdline = f'coverage run --source {HOME}/lib {test}'
        return self.execute_commandline(test, cmdline)
315
316
def test_results_report(results: Dict[str, TestResults]) -> int:
    """Print every collection's results and return how many tests went wrong.

    A test "went wrong" if it failed or timed out.  When that count is
    non-zero, a pointer to the saved logs in ./test_output is printed too.
    """
    for suite_results in results.values():
        print(suite_results, end='')

    total_problems = sum(
        len(suite_results.tests_failed) + len(suite_results.tests_timed_out)
        for suite_results in results.values()
    )
    if total_problems:
        print('Reminder: look in ./test_output to view test output logs')
    return total_problems
328
329
def code_coverage_report():
    """Give a final code coverage report.

    Combines the per-process ``.coverage*`` data files, prints the output
    of ``coverage report``, then prints instructions for re-displaying the
    same report later without re-running the tests.
    """
    # Single source of truth so the printed instructions always match the
    # command that produced the report.
    report_cmd = 'coverage report --omit=config-3.8.py,*_test.py,*_itest.py --sort=-cover'

    text_utils.header('Code Coverage')
    exec_utils.cmd('coverage combine .coverage*')
    out = exec_utils.cmd(report_cmd)
    print(out)
    print(
        f"""
To recall this report w/o re-running the tests:

    $ {report_cmd}

...from the 'tests' directory.  Note that subsequent calls to
run_tests.py with --coverage will clobber previous results.  See:

    https://coverage.readthedocs.io/en/6.2/
"""
    )
348
349
@bootstrap.initialize
def main() -> Optional[int]:
    """Run the selected test collections concurrently and report on them.

    One runner thread is started per selected flag (--unittests, --doctests,
    --integration).  Returns 1 after printing usage when no flag was given;
    otherwise the total number of failed plus timed-out tests.
    """
    halt_event = threading.Event()
    params = TestingParameters(
        halt_on_error=True,
        halt_event=halt_event,
    )

    if config.config['coverage']:
        logger.debug('Clearing existing coverage data via "coverage erase".')
        exec_utils.cmd('coverage erase')

    threads: List[TestRunner] = []
    if config.config['unittests']:
        threads.append(UnittestTestRunner(params))
    if config.config['doctests']:
        threads.append(DoctestTestRunner(params))
    if config.config['integration']:
        threads.append(IntegrationTestRunner(params))

    if not threads:
        config.print_usage()
        print('ERROR: one of --unittests, --doctests or --integration is required.')
        return 1

    for thread in threads:
        thread.start()

    results: Dict[str, TestResults] = {}
    while len(results) != len(threads):
        for thread in threads:
            tid = thread.name
            if tid in results or thread.is_alive():
                continue
            result = thread.join()
            if not result:
                # The runner exited without a return value -- presumably
                # begin() raised (e.g. check_for_abort).  Record whatever it
                # accumulated; otherwise this thread would never be counted
                # and the enclosing loop would spin forever.
                logger.error('Thread %s exited without returning results.', tid)
                result = thread.test_results
            results[tid] = result
            if len(result.tests_failed) > 0:
                logger.error(
                    'Thread %s returned abnormal results; killing the others.', tid
                )
                halt_event.set()
        if len(results) != len(threads):
            time.sleep(1.0)

    if config.config['coverage']:
        code_coverage_report()
    total_problems = test_results_report(results)
    return total_problems


if __name__ == '__main__':
    main()