Add --all and clean up clear_line().
[python_utils.git] / tests / run_tests.py
#!/usr/bin/env python3

"""
A smart, fast test runner.  Used in a git pre-commit hook.
"""

import logging
import os
import re
import subprocess
import threading
import time
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple

from overrides import overrides

import ansi
import bootstrap
import config
import exec_utils
import file_utils
import parallelize as par
import smart_future
import text_utils
import thread_utils

logger = logging.getLogger(__name__)
args = config.add_commandline_args(f'({__file__})', f'Args related to {__file__}')
args.add_argument('--unittests', '-u', action='store_true', help='Run unittests.')
args.add_argument('--doctests', '-d', action='store_true', help='Run doctests.')
args.add_argument('--integration', '-i', action='store_true', help='Run integration tests.')
args.add_argument(
    '--all',
    '-a',
    action='store_true',
    help='Run unittests, doctests and integration tests.  Equivalent to -u -d -i.',
)
args.add_argument(
    '--coverage', '-c', action='store_true', help='Run tests and capture code coverage data.'
)

HOME = os.environ['HOME']


@dataclass
class TestingParameters:
    halt_on_error: bool
    """Should we stop as soon as one error has occurred?"""

    halt_event: threading.Event
    """An event that, when set, indicates to stop ASAP."""


@dataclass
class TestResults:
    name: str
    """The name of this test / set of tests."""

    tests_executed: List[str]
    """Tests that were executed."""

    tests_succeeded: List[str]
    """Tests that succeeded."""

    tests_failed: List[str]
    """Tests that failed."""

    tests_timed_out: List[str]
    """Tests that timed out."""

    def __add__(self, other):
        self.tests_executed.extend(other.tests_executed)
        self.tests_succeeded.extend(other.tests_succeeded)
        self.tests_failed.extend(other.tests_failed)
        self.tests_timed_out.extend(other.tests_timed_out)
        return self

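    # Allow addition from the right-hand side too, e.g. when another
    # object's __add__ declines to handle a TestResults operand.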
    __radd__ = __add__

    def __repr__(self) -> str:
        out = f'{self.name}: '
        out += f'{ansi.fg("green")}'
        out += f'{len(self.tests_succeeded)}/{len(self.tests_executed)} passed'
        out += f'{ansi.reset()}.\n'

        if len(self.tests_failed) > 0:
            out += f'  ..{ansi.fg("red")}'
            out += f'{len(self.tests_failed)} tests failed'
            out += f'{ansi.reset()}:\n'
            for test in self.tests_failed:
                out += f'    {test}\n'
            out += '\n'

        if len(self.tests_timed_out) > 0:
            out += f'  ..{ansi.fg("yellow")}'
            out += f'{len(self.tests_timed_out)} tests timed out'
            out += f'{ansi.reset()}:\n'
            for test in self.tests_timed_out:
                out += f'    {test}\n'
            out += '\n'
        return out


class TestRunner(ABC, thread_utils.ThreadWithReturnValue):
    """A base class for something that runs a test."""

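    # thread_utils.ThreadWithReturnValue is assumed to behave like
    # threading.Thread but hand the target's return value back from join().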
    def __init__(self, params: TestingParameters):
        """Create a TestRunner.

        Args:
            params: Test running parameters.

        """
        super().__init__(target=self.begin, args=[params])
        self.params = params
        self.test_results = TestResults(
            name=self.get_name(),
            tests_executed=[],
            tests_succeeded=[],
            tests_failed=[],
            tests_timed_out=[],
        )
        self.tests_started = 0

    @abstractmethod
    def get_name(self) -> str:
        """The name of this test collection."""
        pass

    def get_status(self) -> Tuple[int, TestResults]:
        """Ask the TestRunner for its status."""
        return (self.tests_started, self.test_results)

    @abstractmethod
    def begin(self, params: TestingParameters) -> TestResults:
        """Start execution."""
        pass


class TemplatedTestRunner(TestRunner, ABC):
    """A TestRunner that has a recipe for executing the tests."""

    @abstractmethod
    def identify_tests(self) -> List[str]:
        """Return a list of tests that should be executed."""
        pass

    @abstractmethod
    def run_test(self, test: Any) -> TestResults:
        """Run a single test and return its TestResults."""
        pass

    def check_for_abort(self):
        """Periodically called to check whether we need to stop."""

        if self.params.halt_event.is_set():
            logger.debug('Thread %s saw halt event; exiting.', self.get_name())
            raise Exception("Kill myself!")
        if self.params.halt_on_error:
            if len(self.test_results.tests_failed) > 0:
                logger.error('Thread %s saw abnormal results; exiting.', self.get_name())
                raise Exception("Kill myself!")

    def persist_output(self, test_name: str, message: str, output: str) -> None:
        """Called to save the output of a test run."""

        basename = file_utils.without_path(test_name)
        dest = f'{basename}-output.txt'
        with open(f'./test_output/{dest}', 'w') as wf:
            print(message, file=wf)
            print('-' * len(message), file=wf)
            wf.write(output)

    def execute_commandline(
        self,
        test_name: str,
        cmdline: str,
        *,
        timeout: float = 120.0,
    ) -> TestResults:
        """Execute a particular commandline to run a test."""

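        # Assumed contract: exec_utils.cmd raises subprocess.TimeoutExpired on
        # timeout and subprocess.CalledProcessError on a non-zero exit, which
        # is what the handlers below rely on.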
        try:
            logger.debug('%s: Running %s (%s)', self.get_name(), test_name, cmdline)
            output = exec_utils.cmd(
                cmdline,
                timeout_seconds=timeout,
            )
            self.persist_output(test_name, f'{test_name} ({cmdline}) succeeded.', output)
            logger.debug('%s (%s) succeeded', test_name, cmdline)
            return TestResults(test_name, [test_name], [test_name], [], [])
        except subprocess.TimeoutExpired as e:
            msg = f'{self.get_name()}: {test_name} ({cmdline}) timed out after {e.timeout:.1f} seconds.'
            logger.error(msg)
            logger.debug(
                '%s: %s output when it timed out: %s', self.get_name(), test_name, e.output
            )
            self.persist_output(test_name, msg, e.output.decode('utf-8'))
            return TestResults(
                test_name,
                [test_name],
                [],
                [],
                [test_name],
            )
        except subprocess.CalledProcessError as e:
            msg = f'{self.get_name()}: {test_name} ({cmdline}) failed; exit code {e.returncode}'
            logger.error(msg)
            logger.debug('%s: %s output when it failed: %s', self.get_name(), test_name, e.output)
            self.persist_output(test_name, msg, e.output.decode('utf-8'))
            return TestResults(
                test_name,
                [test_name],
                [],
                [test_name],
                [],
            )

    @overrides
    def begin(self, params: TestingParameters) -> TestResults:
        logger.debug('Thread %s started.', self.get_name())
        interesting_tests = self.identify_tests()

        running: List[Any] = []
        for test in interesting_tests:
            running.append(self.run_test(test))
        self.tests_started = len(running)

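        # smart_future.wait_any is assumed to yield each future as it
        # completes, letting us tally results (and honor aborts) in
        # completion order.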
        for future in smart_future.wait_any(running):
            self.check_for_abort()
            result = future._resolve()
            logger.debug('Test %s finished.', result.name)
            self.test_results += result

        logger.debug('Thread %s finished.', self.get_name())
        return self.test_results


class UnittestTestRunner(TemplatedTestRunner):
    """Run all known Unittests."""

    @overrides
    def get_name(self) -> str:
        return "Unittests"

    @overrides
    def identify_tests(self) -> List[str]:
        return list(file_utils.expand_globs('*_test.py'))

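    # @par.parallelize is assumed to run this method in the background and
    # return a smart future that resolves to its TestResults.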
    @par.parallelize
    def run_test(self, test: Any) -> TestResults:
        if config.config['coverage']:
            cmdline = f'coverage run --source {HOME}/lib {test} --unittests_ignore_perf'
        else:
            cmdline = test
        return self.execute_commandline(test, cmdline)


class DoctestTestRunner(TemplatedTestRunner):
    """Run all known Doctests."""

    @overrides
    def get_name(self) -> str:
        return "Doctests"

    @overrides
    def identify_tests(self) -> List[str]:
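        # Any module that imports doctest is assumed to run its own doctests
        # when executed directly; grep finds those modules.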
        ret = []
        out = exec_utils.cmd(f'grep -lR "^ *import doctest" {HOME}/lib/python_modules/*')
        for line in out.split('\n'):
            if re.match(r'.*\.py$', line):
                if 'run_tests.py' not in line:
                    ret.append(line)
        return ret

    @par.parallelize
    def run_test(self, test: Any) -> TestResults:
        if config.config['coverage']:
            cmdline = f'coverage run --source {HOME}/lib {test} 2>&1'
        else:
            cmdline = f'python3 {test}'
        return self.execute_commandline(test, cmdline)


class IntegrationTestRunner(TemplatedTestRunner):
    """Run all known Integration tests."""

    @overrides
    def get_name(self) -> str:
        return "Integration Tests"

    @overrides
    def identify_tests(self) -> List[str]:
        return list(file_utils.expand_globs('*_itest.py'))

    @par.parallelize
    def run_test(self, test: Any) -> TestResults:
        if config.config['coverage']:
            cmdline = f'coverage run --source {HOME}/lib {test}'
        else:
            cmdline = test
        return self.execute_commandline(test, cmdline)


def test_results_report(results: Dict[str, TestResults]) -> int:
    """Give a final report about the tests that were run."""
    total_problems = 0
    for result in results.values():
        print(result, end='')
        total_problems += len(result.tests_failed)
        total_problems += len(result.tests_timed_out)

    if total_problems > 0:
        print('Reminder: look in ./test_output to view test output logs')
    return total_problems


def code_coverage_report():
    """Give a final code coverage report."""
    text_utils.header('Code Coverage')
    exec_utils.cmd('coverage combine .coverage*')
    out = exec_utils.cmd('coverage report --omit=config-3.8.py,*_test.py,*_itest.py --sort=-cover')
    print(out)
    print(
        """
To recall this report w/o re-running the tests:

    $ coverage report --omit=config-3.8.py,*_test.py,*_itest.py --sort=-cover

...from the 'tests' directory.  Note that subsequent calls to
run_tests.py with --coverage will clobber previous results.  See:

    https://coverage.readthedocs.io/en/6.2/
"""
    )


@bootstrap.initialize
def main() -> Optional[int]:
    saw_flag = False
    halt_event = threading.Event()
    threads: List[TestRunner] = []

    halt_event.clear()
    params = TestingParameters(
        halt_on_error=True,
        halt_event=halt_event,
    )

    if config.config['coverage']:
        logger.debug('Clearing existing coverage data via "coverage erase".')
        exec_utils.cmd('coverage erase')

    if config.config['unittests'] or config.config['all']:
        saw_flag = True
        threads.append(UnittestTestRunner(params))
    if config.config['doctests'] or config.config['all']:
        saw_flag = True
        threads.append(DoctestTestRunner(params))
    if config.config['integration'] or config.config['all']:
        saw_flag = True
        threads.append(IntegrationTestRunner(params))

    if not saw_flag:
        config.print_usage()
        print('ERROR: one of --unittests, --doctests, --integration or --all is required.')
        return 1

    for thread in threads:
        thread.start()

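    # Poll the runner threads, drawing a progress bar until every thread
    # has reported its results.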
    results: Dict[str, TestResults] = {}
    while len(results) != len(threads):
        started = 0
        done = 0
        failed = 0

        for thread in threads:
            (s, tr) = thread.get_status()
            started += s
            newly_failed = len(tr.tests_failed) + len(tr.tests_timed_out)
            failed += newly_failed
            done += newly_failed + len(tr.tests_succeeded)
            if not thread.is_alive():
                tid = thread.name
                if tid not in results:
                    result = thread.join()
                    if result:
                        results[tid] = result
                        if len(result.tests_failed) > 0:
                            logger.error(
                                'Thread %s returned abnormal results; killing the others.', tid
                            )
                            halt_event.set()

        if started > 0:
            percent_done = done / started
        else:
            percent_done = 0.0

        if failed == 0:
            color = ansi.fg('green')
        else:
            color = ansi.fg('red')

        if percent_done < 1.0:
            print(
                text_utils.bar_graph_string(
                    done,
                    started,
                    text=text_utils.BarGraphText.FRACTION,
                    width=80,
                    fgcolor=color,
                ),
                end='\r',
                flush=True,
            )
        time.sleep(0.5)

    print(f'{ansi.clear_line()}Final Report:')
    if config.config['coverage']:
        code_coverage_report()
    total_problems = test_results_report(results)
    return total_problems


if __name__ == '__main__':
    main()