5162e238f1d37181b5dc9ea3988e1a443b3231c4
[python_utils.git] / tests / run_tests.py
1 #!/usr/bin/env python3
2
3 """
4 A smart, fast test runner.  Used in a git pre-commit hook.
5 """
6
7 import logging
8 import os
9 import re
10 import subprocess
11 import threading
12 import time
13 from abc import ABC, abstractmethod
14 from dataclasses import dataclass
15 from typing import Any, Dict, List, Optional, Tuple
16
17 from overrides import overrides
18
19 import ansi
20 import bootstrap
21 import config
22 import exec_utils
23 import file_utils
24 import parallelize as par
25 import smart_future
26 import text_utils
27 import thread_utils
28
logger = logging.getLogger(__name__)

# Register this module's commandline flags; they select which test
# suites run and whether to collect coverage data.
# BUGFIX: the description was the literal string 'Args related to
# __file__'; the sibling argument uses f'({__file__})', so interpolation
# was clearly intended.
args = config.add_commandline_args(f'({__file__})', f'Args related to {__file__}')
args.add_argument('--unittests', '-u', action='store_true', help='Run unittests.')
args.add_argument('--doctests', '-d', action='store_true', help='Run doctests.')
args.add_argument('--integration', '-i', action='store_true', help='Run integration tests.')
args.add_argument(
    '--coverage', '-c', action='store_true', help='Run tests and capture code coverage data'
)

# Used to locate the library under test (e.g. $HOME/lib) in coverage runs.
HOME = os.environ['HOME']
39
40
@dataclass
class TestingParameters:
    """Knobs shared by every test-runner thread."""

    # Stop the whole run as soon as any single test fails?
    halt_on_error: bool

    # Set by the coordinator to tell all runner threads to stop ASAP.
    halt_event: threading.Event
48
49
@dataclass
class TestResults:
    """The outcome of running one test (or one suite); mergeable with +."""

    # The name of this test / set of tests.
    name: str

    # Tests that were executed.
    tests_executed: List[str]

    # Tests that succeeded.
    tests_succeeded: List[str]

    # Tests that failed.
    tests_failed: List[str]

    # Tests that timed out.
    tests_timed_out: List[str]

    def __add__(self, other):
        """Merge other's results into self.

        Note: mutates and returns self (not a new object); this supports
        the `self.test_results += result` accumulation pattern used by
        the runner threads.
        """
        self.tests_executed.extend(other.tests_executed)
        self.tests_succeeded.extend(other.tests_succeeded)
        self.tests_failed.extend(other.tests_failed)
        self.tests_timed_out.extend(other.tests_timed_out)
        return self

    __radd__ = __add__

    def __repr__(self) -> str:
        """Render a colorized pass/fail/timeout summary."""
        out = f'{self.name}: '
        out += f'{ansi.fg("green")}'
        out += f'{len(self.tests_succeeded)}/{len(self.tests_executed)} passed'
        out += f'{ansi.reset()}.\n'

        if len(self.tests_failed) > 0:
            out += f'  ..{ansi.fg("red")}'
            out += f'{len(self.tests_failed)} tests failed'
            out += f'{ansi.reset()}:\n'
            for test in self.tests_failed:
                out += f'    {test}\n'
            out += '\n'

        if len(self.tests_timed_out) > 0:
            out += f'  ..{ansi.fg("yellow")}'
            out += f'{len(self.tests_timed_out)} tests timed out'
            out += f'{ansi.reset()}:\n'
            # BUGFIX: this loop previously iterated tests_failed, so the
            # "timed out" section listed the names of failed tests.
            for test in self.tests_timed_out:
                out += f'    {test}\n'
            out += '\n'
        return out
98
99
class TestRunner(ABC, thread_utils.ThreadWithReturnValue):
    """A base class for something that runs a test in its own thread."""

    def __init__(self, params: TestingParameters):
        """Create a TestRunner.

        Args:
            params: Test running parameters.
        """
        # BUGFIX: the previous code called super().__init__(self, ...),
        # which forwarded self as the first positional argument (the
        # Thread-style `group` parameter) of the base initializer; bound
        # super() already supplies self.
        super().__init__(target=self.begin, args=[params])
        self.params = params
        # Accumulates results as individual tests complete.
        self.test_results = TestResults(
            name=self.get_name(),
            tests_executed=[],
            tests_succeeded=[],
            tests_failed=[],
            tests_timed_out=[],
        )
        # Number of tests kicked off so far; used for progress reporting.
        self.tests_started = 0

    @abstractmethod
    def get_name(self) -> str:
        """The name of this test collection."""
        pass

    def get_status(self) -> Tuple[int, TestResults]:
        """Ask the TestRunner for its status: (# started, results so far)."""
        return (self.tests_started, self.test_results)

    @abstractmethod
    def begin(self, params: TestingParameters) -> TestResults:
        """Start execution; returns the final results when done."""
        pass
134
135
class TemplatedTestRunner(TestRunner, ABC):
    """A TestRunner that has a recipe for executing the tests:
    identify them, run them all in parallel, and gather results."""

    @abstractmethod
    def identify_tests(self) -> List[str]:
        """Return a list of tests that should be executed."""
        pass

    @abstractmethod
    def run_test(self, test: Any) -> TestResults:
        """Run a single test and return its TestResults."""
        pass

    def check_for_abort(self):
        """Periodically called to check to see if we need to stop.

        Raises if the global halt event is set, or if halt_on_error is
        enabled and this runner has already seen a failure.
        """

        if self.params.halt_event.is_set():
            logger.debug('Thread %s saw halt event; exiting.', self.get_name())
            raise Exception("Kill myself!")
        if self.params.halt_on_error:
            if len(self.test_results.tests_failed) > 0:
                logger.error('Thread %s saw abnormal results; exiting.', self.get_name())
                raise Exception("Kill myself!")

    def persist_output(self, test_name: str, message: str, output: str) -> None:
        """Save the output of a test run under ./test_output/."""

        basename = file_utils.without_path(test_name)
        dest = f'{basename}-output.txt'
        # BUGFIX: make sure the destination directory exists; previously
        # the open() would raise FileNotFoundError on a fresh checkout.
        os.makedirs('./test_output', exist_ok=True)
        with open(f'./test_output/{dest}', 'w', encoding='utf-8') as wf:
            print(message, file=wf)
            print('-' * len(message), file=wf)
            wf.write(output)

    def execute_commandline(
        self,
        test_name: str,
        cmdline: str,
        *,
        timeout: float = 120.0,
    ) -> TestResults:
        """Execute a particular commandline to run a test.

        Args:
            test_name: the name of the test being run
            cmdline: the command to execute
            timeout: seconds to allow before declaring a timeout

        Returns:
            A TestResults recording this single test as succeeded,
            failed, or timed out.
        """

        try:
            logger.debug('%s: Running %s (%s)', self.get_name(), test_name, cmdline)
            output = exec_utils.cmd(
                cmdline,
                timeout_seconds=timeout,
            )
            self.persist_output(test_name, f'{test_name} ({cmdline}) succeeded.', output)
            logger.debug('%s (%s) succeeded', test_name, cmdline)
            return TestResults(test_name, [test_name], [test_name], [], [])
        except subprocess.TimeoutExpired as e:
            msg = f'{self.get_name()}: {test_name} ({cmdline}) timed out after {e.timeout:.1f} seconds.'
            logger.error(msg)
            logger.debug(
                '%s: %s output when it timed out: %s', self.get_name(), test_name, e.output
            )
            # BUGFIX: e.output may be None if no output was captured.
            self.persist_output(test_name, msg, e.output.decode('utf-8') if e.output else '')
            return TestResults(
                test_name,
                [test_name],
                [],
                [],
                [test_name],
            )
        except subprocess.CalledProcessError as e:
            msg = f'{self.get_name()}: {test_name} ({cmdline}) failed; exit code {e.returncode}'
            logger.error(msg)
            logger.debug('%s: %s output when it failed: %s', self.get_name(), test_name, e.output)
            # BUGFIX: e.output may be None if no output was captured.
            self.persist_output(test_name, msg, e.output.decode('utf-8') if e.output else '')
            return TestResults(
                test_name,
                [test_name],
                [],
                [test_name],
                [],
            )

    @overrides
    def begin(self, params: TestingParameters) -> TestResults:
        logger.debug('Thread %s started.', self.get_name())
        interesting_tests = self.identify_tests()

        # run_test is expected to be parallelized in subclasses; collect
        # the futures first, then resolve them as they complete.
        running: List[Any] = []
        for test in interesting_tests:
            running.append(self.run_test(test))
        self.tests_started = len(running)

        for future in smart_future.wait_any(running):
            self.check_for_abort()
            result = future._resolve()
            logger.debug('Test %s finished.', result.name)
            self.test_results += result

        logger.debug('Thread %s finished.', self.get_name())
        return self.test_results
233
234
class UnittestTestRunner(TemplatedTestRunner):
    """Run all known Unittests."""

    @overrides
    def get_name(self) -> str:
        return "Unittests"

    @overrides
    def identify_tests(self) -> List[str]:
        # Unittests are the files matching *_test.py here.
        return list(file_utils.expand_globs('*_test.py'))

    @par.parallelize
    def run_test(self, test: Any) -> TestResults:
        # Under --coverage, wrap the test in the coverage tool; otherwise
        # just execute the test file itself.
        cmdline = (
            f'coverage run --source {HOME}/lib {test} --unittests_ignore_perf'
            if config.config['coverage']
            else test
        )
        return self.execute_commandline(test, cmdline)
253
254
class DoctestTestRunner(TemplatedTestRunner):
    """Run all known Doctests."""

    @overrides
    def get_name(self) -> str:
        return "Doctests"

    @overrides
    def identify_tests(self) -> List[str]:
        """Any python module that imports doctest is assumed to contain
        doctests; find them all (excluding this test runner itself)."""
        # BUGFIX: use the HOME constant like every other runner instead
        # of a hardcoded /home/scott path.
        out = exec_utils.cmd(f'grep -lR "^ *import doctest" {HOME}/lib/python_modules/*')
        return [
            line
            for line in out.split('\n')
            if re.match(r'.*\.py$', line) and 'run_tests.py' not in line
        ]

    @par.parallelize
    def run_test(self, test: Any) -> TestResults:
        if config.config['coverage']:
            cmdline = f'coverage run --source {HOME}/lib {test} 2>&1'
        else:
            cmdline = f'python3 {test}'
        return self.execute_commandline(test, cmdline)
279
280
class IntegrationTestRunner(TemplatedTestRunner):
    """Run all known integration tests."""

    @overrides
    def get_name(self) -> str:
        return "Integration Tests"

    @overrides
    def identify_tests(self) -> List[str]:
        # Integration tests are the files matching *_itest.py here.
        return list(file_utils.expand_globs('*_itest.py'))

    @par.parallelize
    def run_test(self, test: Any) -> TestResults:
        # Under --coverage, wrap the test in the coverage tool; otherwise
        # just execute the test file itself.
        cmdline = (
            f'coverage run --source {HOME}/lib {test}'
            if config.config['coverage']
            else test
        )
        return self.execute_commandline(test, cmdline)
299
300
301 def test_results_report(results: Dict[str, TestResults]) -> int:
302     """Give a final report about the tests that were run."""
303     total_problems = 0
304     for result in results.values():
305         print(result, end='')
306         total_problems += len(result.tests_failed)
307         total_problems += len(result.tests_timed_out)
308
309     if total_problems > 0:
310         print('Reminder: look in ./test_output to view test output logs')
311     return total_problems
312
313
def code_coverage_report():
    """Print a final code coverage report on stdout."""
    text_utils.header('Code Coverage')
    # Merge the per-test .coverage* data files into one dataset, then
    # render the combined report.
    exec_utils.cmd('coverage combine .coverage*')
    report = exec_utils.cmd(
        'coverage report --omit=config-3.8.py,*_test.py,*_itest.py --sort=-cover'
    )
    print(report)
    print(
        """
To recall this report w/o re-running the tests:

    $ coverage report --omit=config-3.8.py,*_test.py,*_itest.py --sort=-cover

...from the 'tests' directory.  Note that subsequent calls to
run_tests.py with --coverage will klobber previous results.  See:

    https://coverage.readthedocs.io/en/6.2/
"""
    )
332
333
@bootstrap.initialize
def main() -> Optional[int]:
    """Run the requested test suites, each in its own thread, drawing a
    progress bar while they run.  Returns the number of problems seen
    (0 on success) or 1 on usage error."""
    saw_flag = False
    halt_event = threading.Event()
    threads: List[TestRunner] = []

    halt_event.clear()
    params = TestingParameters(
        halt_on_error=True,
        halt_event=halt_event,
    )

    if config.config['coverage']:
        logger.debug('Clearing existing coverage data via "coverage erase".')
        exec_utils.cmd('coverage erase')

    # Start one runner thread per requested test category.
    if config.config['unittests']:
        saw_flag = True
        threads.append(UnittestTestRunner(params))
    if config.config['doctests']:
        saw_flag = True
        threads.append(DoctestTestRunner(params))
    if config.config['integration']:
        saw_flag = True
        threads.append(IntegrationTestRunner(params))

    if not saw_flag:
        config.print_usage()
        print('ERROR: one of --unittests, --doctests or --integration is required.')
        return 1

    for thread in threads:
        thread.start()

    # Poll runner threads until every one has delivered its results.
    results: Dict[str, TestResults] = {}
    while len(results) != len(threads):
        started = 0
        done = 0
        failed = 0

        for thread in threads:
            if not thread.is_alive():
                tid = thread.name
                if tid not in results:
                    result = thread.join()
                    if result:
                        results[tid] = result
                        if len(result.tests_failed) > 0:
                            logger.error(
                                'Thread %s returned abnormal results; killing the others.', tid
                            )
                            halt_event.set()
            else:
                (s, tr) = thread.get_status()
                started += s
                # BUGFIX: count this thread's failures separately; the old
                # code added the cross-thread running total of `failed`
                # into `done`, double counting earlier threads' failures.
                thread_failures = len(tr.tests_failed) + len(tr.tests_timed_out)
                failed += thread_failures
                done += thread_failures + len(tr.tests_succeeded)

        # percent_done is a fraction in [0.0, 1.0].
        if started > 0:
            percent_done = done / started
        else:
            percent_done = 0.0

        if failed == 0:
            color = ansi.fg('green')
        else:
            color = ansi.fg('red')

        # BUGFIX: percent_done is a 0..1 fraction; the old comparison
        # against 100.0 was always true, making "Finished." unreachable.
        if percent_done < 1.0:
            print(
                text_utils.bar_graph(
                    percent_done,
                    width=80,
                    fgcolor=color,
                ),
                end='\r',
                flush=True,
            )
        else:
            print("Finished.\n")
        time.sleep(0.5)

    if config.config['coverage']:
        code_coverage_report()
    total_problems = test_results_report(results)
    return total_problems
420
421
if __name__ == '__main__':
    # NOTE(review): main()'s return value (the problem count) is not
    # passed to sys.exit() here; presumably the @bootstrap.initialize
    # decorator plumbs the return value into the process exit status --
    # TODO confirm against the bootstrap module.
    main()