Fix run_tests.py to detect doctest failures and tear itself down more
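
Doctest failures are now detected by scanning each test's output for
doctest's "***Test Failed***" banner, since a module that runs
doctest.testmod() in its __main__ block can exit 0 even when examples
fail.  Teardown is tightened by setting the shared halt_event as soon as
any runner reports failed tests (or dies with an unhandled exception) so
the remaining runners can notice it in check_for_abort() and exit early.
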
[pyutils.git] / tests / run_tests.py
#!/usr/bin/env python3

"""
A smart, fast test runner.  Used in a git pre-commit hook.
"""

import logging
import os
import re
import subprocess
import threading
import time
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple

from overrides import overrides

from pyutils import ansi, bootstrap, config, exec_utils, text_utils
from pyutils.files import file_utils
from pyutils.parallelize import parallelize as par
from pyutils.parallelize import smart_future, thread_utils

logger = logging.getLogger(__name__)
args = config.add_commandline_args(f'({__file__})', f'Args related to {__file__}')
args.add_argument('--unittests', '-u', action='store_true', help='Run unittests.')
args.add_argument('--doctests', '-d', action='store_true', help='Run doctests.')
args.add_argument(
    '--integration', '-i', action='store_true', help='Run integration tests.'
)
args.add_argument(
    '--all',
    '-a',
    action='store_true',
    help='Run unittests, doctests and integration tests.  Equivalent to -u -d -i',
)
args.add_argument(
    '--coverage',
    '-c',
    action='store_true',
    help='Run tests and capture code coverage data',
)

HOME = os.environ['HOME']

# These tests will be run twice in --coverage mode: once to capture code
# coverage data and then again without coverage enabled.  This is because
# they pay attention to code performance, which is adversely affected
# by coverage.
PERF_SENSITIVE_TESTS = set(['string_utils_test.py'])
TESTS_TO_SKIP = set(['zookeeper_test.py', 'zookeeper.py', 'run_tests.py'])

ROOT = ".."


@dataclass
class TestingParameters:
    halt_on_error: bool
    """Should we stop as soon as one error has occurred?"""

    halt_event: threading.Event
    """An event that, when set, indicates to stop ASAP."""


@dataclass
class TestToRun:
    name: str
    """The name of the test"""

    kind: str
    """The kind of the test"""

    cmdline: str
    """The command line to execute"""


@dataclass
class TestResults:
    name: str
    """The name of this test / set of tests."""

    tests_executed: List[str]
    """Tests that were executed."""

    tests_succeeded: List[str]
    """Tests that succeeded."""

    tests_failed: List[str]
    """Tests that failed."""

    tests_timed_out: List[str]
    """Tests that timed out."""

    def __add__(self, other):
        self.tests_executed.extend(other.tests_executed)
        self.tests_succeeded.extend(other.tests_succeeded)
        self.tests_failed.extend(other.tests_failed)
        self.tests_timed_out.extend(other.tests_timed_out)
        return self

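    # __add__ above merges the other result's lists into this instance and
    # returns it; aliasing __radd__ to it lets a TestResults show up on
    # either side of '+'.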
    __radd__ = __add__

    def __repr__(self) -> str:
        out = f'{self.name}: '
        out += f'{ansi.fg("green")}'
        out += f'{len(self.tests_succeeded)}/{len(self.tests_executed)} passed'
        out += f'{ansi.reset()}.\n'

        if len(self.tests_failed) > 0:
            out += f'  ..{ansi.fg("red")}'
            out += f'{len(self.tests_failed)} tests failed'
            out += f'{ansi.reset()}:\n'
            for test in self.tests_failed:
                out += f'    {test}\n'
            out += '\n'

        if len(self.tests_timed_out) > 0:
            out += f'  ..{ansi.fg("yellow")}'
            out += f'{len(self.tests_timed_out)} tests timed out'
            out += f'{ansi.reset()}:\n'
            for test in self.tests_timed_out:
                out += f'    {test}\n'
            out += '\n'
        return out


class TestRunner(ABC, thread_utils.ThreadWithReturnValue):
    """A base class for something that runs a test."""

    def __init__(self, params: TestingParameters):
        """Create a TestRunner.

        Args:
            params: Test running parameters.

        """
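        # The thread's target is begin(); the TestResults it returns are
        # what main() later gets back from thread.join() (see
        # thread_utils.ThreadWithReturnValue).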
        super().__init__(target=self.begin, args=[params])
        self.params = params
        self.test_results = TestResults(
            name=self.get_name(),
            tests_executed=[],
            tests_succeeded=[],
            tests_failed=[],
            tests_timed_out=[],
        )
        self.tests_started = 0
        self.lock = threading.Lock()

    @abstractmethod
    def get_name(self) -> str:
        """The name of this test collection."""
        pass

    def get_status(self) -> Tuple[int, TestResults]:
        """Ask the TestRunner for its status."""
        with self.lock:
            return (self.tests_started, self.test_results)

    @abstractmethod
    def begin(self, params: TestingParameters) -> TestResults:
        """Start execution."""
        pass


class TemplatedTestRunner(TestRunner, ABC):
    """A TestRunner that has a recipe for executing the tests."""

    @abstractmethod
    def identify_tests(self) -> List[TestToRun]:
        """Return a list of TestToRun objects describing the tests to run."""
        pass

    @abstractmethod
    def run_test(self, test: TestToRun) -> TestResults:
        """Run a single test and return its TestResults."""
        pass

    def check_for_abort(self) -> bool:
        """Periodically called to check whether we need to stop."""

        if self.params.halt_event.is_set():
            logger.debug('Thread %s saw halt event; exiting.', self.get_name())
            return True

        if self.params.halt_on_error and len(self.test_results.tests_failed) > 0:
            logger.error('Thread %s saw abnormal results; exiting.', self.get_name())
            return True
        return False

    def persist_output(self, test: TestToRun, message: str, output: str) -> None:
        """Called to save the output of a test run."""

        dest = f'{test.name}-output.txt'
        with open(f'./test_output/{dest}', 'w') as wf:
            print(message, file=wf)
            print('-' * len(message), file=wf)
            wf.write(output)

    def execute_commandline(
        self,
        test: TestToRun,
        *,
        timeout: float = 120.0,
    ) -> TestResults:
        """Execute a particular commandline to run a test."""

        try:
            output = exec_utils.cmd(
                test.cmdline,
                timeout_seconds=timeout,
            )
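            # A script that runs doctest.testmod() can exit 0 even when
            # examples fail; doctest prints a "***Test Failed***" banner
            # instead, so scan the captured output for it.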
            if "***Test Failed***" in output:
                msg = f'{self.get_name()}: {test.name} ({test.cmdline}) failed; doctest failure message detected'
                logger.error(msg)
                self.persist_output(test, msg, output)
                return TestResults(
                    test.name,
                    [test.name],
                    [],
                    [test.name],
                    [],
                )

            self.persist_output(
                test, f'{test.name} ({test.cmdline}) succeeded.', output
            )
            logger.debug(
                '%s: %s (%s) succeeded', self.get_name(), test.name, test.cmdline
            )
            return TestResults(test.name, [test.name], [test.name], [], [])
        except subprocess.TimeoutExpired as e:
            msg = f'{self.get_name()}: {test.name} ({test.cmdline}) timed out after {e.timeout:.1f} seconds.'
            logger.error(msg)
            logger.debug(
                '%s: %s output when it timed out: %s',
                self.get_name(),
                test.name,
                e.output,
            )
            self.persist_output(test, msg, e.output.decode('utf-8'))
            return TestResults(
                test.name,
                [test.name],
                [],
                [],
                [test.name],
            )
        except subprocess.CalledProcessError as e:
            msg = f'{self.get_name()}: {test.name} ({test.cmdline}) failed; exit code {e.returncode}'
            logger.error(msg)
            logger.debug(
                '%s: %s output when it failed: %s', self.get_name(), test.name, e.output
            )
            self.persist_output(test, msg, e.output.decode('utf-8'))
            return TestResults(
                test.name,
                [test.name],
                [],
                [test.name],
                [],
            )

    @overrides
    def begin(self, params: TestingParameters) -> TestResults:
        logger.debug('Thread %s started.', self.get_name())
        interesting_tests = self.identify_tests()
        logger.debug(
            '%s: Identified %d tests to be run.',
            self.get_name(),
            len(interesting_tests),
        )

        # Note: because of @parallelize on run_test it actually returns
        # a SmartFuture with a TestResults inside of it.  That's the
        # reason for this Any business.
        running: List[Any] = []
        for test_to_run in interesting_tests:
            running.append(self.run_test(test_to_run))
            logger.debug(
                '%s: Test %s started in the background.',
                self.get_name(),
                test_to_run.name,
            )
            self.tests_started += 1

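        # Collect results as their futures resolve; if a test fails and
        # halt_on_error is set (or another runner raised the halt event),
        # stop without waiting for the rest.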
        for future in smart_future.wait_any(running, log_exceptions=False):
            result = future._resolve()
            logger.debug('Test %s finished.', result.name)
            self.test_results += result
            if self.check_for_abort():
                logger.debug(
                    '%s: check_for_abort told us to exit early.', self.get_name()
                )
                return self.test_results

        logger.debug('Thread %s finished running all tests', self.get_name())
        return self.test_results


class UnittestTestRunner(TemplatedTestRunner):
    """Run all known Unittests."""

    @overrides
    def get_name(self) -> str:
        return "Unittests"

    @overrides
    def identify_tests(self) -> List[TestToRun]:
        ret = []
        for test in file_utils.get_matching_files_recursive(ROOT, '*_test.py'):
            basename = file_utils.without_path(test)
            if basename in TESTS_TO_SKIP:
                continue
            if config.config['coverage']:
                ret.append(
                    TestToRun(
                        name=basename,
                        kind='unittest capturing coverage',
                        cmdline=f'coverage run --source ../src {test} --unittests_ignore_perf 2>&1',
                    )
                )
                if basename in PERF_SENSITIVE_TESTS:
                    ret.append(
                        TestToRun(
                            name=basename,
                            kind='unittest w/o coverage to record perf',
                            cmdline=f'{test} 2>&1',
                        )
                    )
            else:
                ret.append(
                    TestToRun(
                        name=basename,
                        kind='unittest',
                        cmdline=f'{test} 2>&1',
                    )
                )
        return ret

    @par.parallelize
    def run_test(self, test: TestToRun) -> TestResults:
        return self.execute_commandline(test)


class DoctestTestRunner(TemplatedTestRunner):
    """Run all known Doctests."""

    @overrides
    def get_name(self) -> str:
        return "Doctests"

    @overrides
    def identify_tests(self) -> List[TestToRun]:
        ret = []
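        # Any python file under ROOT that imports doctest is treated as a
        # doctest candidate.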
        out = exec_utils.cmd(f'grep -lR "^ *import doctest" {ROOT}/*')
        for test in out.split('\n'):
            if re.match(r'.*\.py$', test):
                basename = file_utils.without_path(test)
                if basename in TESTS_TO_SKIP:
                    continue
                if config.config['coverage']:
                    ret.append(
                        TestToRun(
                            name=basename,
                            kind='doctest capturing coverage',
                            cmdline=f'coverage run --source ../src {test} 2>&1',
                        )
                    )
                    if basename in PERF_SENSITIVE_TESTS:
                        ret.append(
                            TestToRun(
                                name=basename,
                                kind='doctest w/o coverage to record perf',
                                cmdline=f'python3 {test} 2>&1',
                            )
                        )
                else:
                    ret.append(
                        TestToRun(
                            name=basename,
                            kind='doctest',
                            cmdline=f'python3 {test} 2>&1',
                        )
                    )
        return ret

    @par.parallelize
    def run_test(self, test: TestToRun) -> TestResults:
        return self.execute_commandline(test)


class IntegrationTestRunner(TemplatedTestRunner):
    """Run all known Integration tests."""

    @overrides
    def get_name(self) -> str:
        return "Integration Tests"

    @overrides
    def identify_tests(self) -> List[TestToRun]:
        ret = []
        for test in file_utils.get_matching_files_recursive(ROOT, '*_itest.py'):
            basename = file_utils.without_path(test)
            if basename in TESTS_TO_SKIP:
                continue
            if config.config['coverage']:
                ret.append(
                    TestToRun(
                        name=basename,
                        kind='integration test capturing coverage',
                        cmdline=f'coverage run --source ../src {test} 2>&1',
                    )
                )
                if basename in PERF_SENSITIVE_TESTS:
                    ret.append(
                        TestToRun(
                            name=basename,
                            kind='integration test w/o coverage to capture perf',
                            cmdline=f'{test} 2>&1',
                        )
                    )
            else:
                ret.append(
                    TestToRun(
                        name=basename, kind='integration test', cmdline=f'{test} 2>&1'
                    )
                )
        return ret

    @par.parallelize
    def run_test(self, test: TestToRun) -> TestResults:
        return self.execute_commandline(test)


def test_results_report(results: Dict[str, Optional[TestResults]]) -> int:
    """Give a final report about the tests that were run."""
    total_problems = 0
    for result in results.values():
        if result is None:
            print('Unexpected unhandled exception in test runner!!!')
            total_problems += 1
        else:
            print(result, end='')
            total_problems += len(result.tests_failed)
            total_problems += len(result.tests_timed_out)

    if total_problems > 0:
        print('Reminder: look in ./test_output to view test output logs')
    return total_problems


def code_coverage_report():
    """Give a final code coverage report."""
    text_utils.header('Code Coverage')
    exec_utils.cmd('coverage combine .coverage*')
    out = exec_utils.cmd(
        'coverage report --omit=config-3.*.py,*_test.py,*_itest.py --sort=-cover'
    )
    print(out)
    print(
        """To recall this report w/o re-running the tests:

    $ coverage report --omit=config-3.*.py,*_test.py,*_itest.py --sort=-cover

...from the 'tests' directory.  Note that subsequent calls to
run_tests.py with --coverage will clobber previous results.  See:

    https://coverage.readthedocs.io/en/6.2/
"""
    )


@bootstrap.initialize
def main() -> Optional[int]:
    saw_flag = False
    halt_event = threading.Event()
    threads: List[TestRunner] = []

    halt_event.clear()
    params = TestingParameters(
        halt_on_error=True,
        halt_event=halt_event,
    )

    if config.config['coverage']:
        logger.debug('Clearing existing coverage data via "coverage erase".')
        exec_utils.cmd('coverage erase')

    if config.config['unittests'] or config.config['all']:
        saw_flag = True
        threads.append(UnittestTestRunner(params))
    if config.config['doctests'] or config.config['all']:
        saw_flag = True
        threads.append(DoctestTestRunner(params))
    if config.config['integration'] or config.config['all']:
        saw_flag = True
        threads.append(IntegrationTestRunner(params))

    if not saw_flag:
        config.print_usage()
        print('ERROR: one of --unittests, --doctests or --integration is required.')
        return 1

    for thread in threads:
        thread.start()

    results: Dict[str, Optional[TestResults]] = {}
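    # Poll the runners: draw a progress bar, harvest results as threads
    # finish, and set halt_event to tear the other runners down if any
    # thread reports failures or dies with an unhandled exception.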
    while len(results) != len(threads):
        started = 0
        done = 0
        failed = 0

        for thread in threads:
            (s, tr) = thread.get_status()
            started += s
            failed += len(tr.tests_failed) + len(tr.tests_timed_out)
            done += len(tr.tests_failed) + len(tr.tests_timed_out) + len(tr.tests_succeeded)
            if not thread.is_alive():
                tid = thread.name
                if tid not in results:
                    result = thread.join()
                    if result:
                        results[tid] = result
                        if len(result.tests_failed) > 0:
                            logger.error(
                                'Thread %s returned abnormal results; killing the others.',
                                tid,
                            )
                            halt_event.set()
                    else:
                        logger.error(
                            'Thread %s took an unhandled exception... bug in run_tests.py?!  Aborting.',
                            tid,
                        )
                        halt_event.set()
                        results[tid] = None

        if started > 0:
            percent_done = 100.0 * done / started
        else:
            percent_done = 0.0

        if failed == 0:
            color = ansi.fg('green')
        else:
            color = ansi.fg('red')

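        # Redraw the progress bar in place (end='\r') until every test is
        # accounted for.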
        if percent_done < 100.0:
            print(
                text_utils.bar_graph_string(
                    done,
                    started,
                    text=text_utils.BarGraphText.FRACTION,
                    width=80,
                    fgcolor=color,
                ),
                end='\r',
                flush=True,
            )
        time.sleep(0.5)

    print(f'{ansi.clear_line()}Final Report:')
    if config.config['coverage']:
        code_coverage_report()
    total_problems = test_results_report(results)
    return total_problems


if __name__ == '__main__':
    main()