Make run_tests.py tell you what's still running.
tests/run_tests.py (pyutils.git)
#!/usr/bin/env python3

"""
A smart, fast test runner.  Used in a git pre-commit hook.
"""

import logging
import os
import re
import subprocess
import threading
import time
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Dict, List, Optional

from overrides import overrides

from pyutils import ansi, bootstrap, config, exec_utils, text_utils
from pyutils.files import file_utils
from pyutils.parallelize import parallelize as par
from pyutils.parallelize import smart_future, thread_utils

logger = logging.getLogger(__name__)
args = config.add_commandline_args(
    f'Run Tests Driver ({__file__})', f'Args related to {__file__}'
)
args.add_argument('--unittests', '-u', action='store_true', help='Run unittests.')
args.add_argument('--doctests', '-d', action='store_true', help='Run doctests.')
args.add_argument(
    '--integration', '-i', action='store_true', help='Run integration tests.'
)
args.add_argument(
    '--all',
    '-a',
    action='store_true',
    help='Run unittests, doctests and integration tests.  Equivalent to -u -d -i',
)
args.add_argument(
    '--coverage',
    '-c',
    action='store_true',
    help='Run tests and capture code coverage data',
)

HOME = os.environ['HOME']

# These tests will be run twice in --coverage mode: once to gather code
# coverage data and then again with coverage disabled.  This is because
# they pay attention to code performance, which is adversely affected
# by coverage instrumentation.
PERF_SENSATIVE_TESTS = set(['string_utils_test.py'])
TESTS_TO_SKIP = set(['zookeeper_test.py', 'zookeeper.py', 'run_tests.py'])

ROOT = ".."


@dataclass
class TestingParameters:
    halt_on_error: bool
    """Should we stop as soon as one error has occurred?"""

    halt_event: threading.Event
    """An event that, when set, indicates to stop ASAP."""


@dataclass
class TestToRun:
    name: str
    """The name of the test"""

    kind: str
    """The kind of the test"""

    cmdline: str
    """The command line to execute"""


@dataclass
class TestResults:
    name: str
    """The name of this test / set of tests."""

    tests_executed: List[str]
    """Tests that were executed."""

    tests_succeeded: List[str]
    """Tests that succeeded."""

    tests_failed: List[str]
    """Tests that failed."""

    tests_timed_out: List[str]
    """Tests that timed out."""

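    # Merge another TestResults into this one in place; used to accumulate
    # per-test results into per-runner totals.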
    def __add__(self, other):
        self.tests_executed.extend(other.tests_executed)
        self.tests_succeeded.extend(other.tests_succeeded)
        self.tests_failed.extend(other.tests_failed)
        self.tests_timed_out.extend(other.tests_timed_out)
        return self

    __radd__ = __add__

    def __repr__(self) -> str:
        out = f'{self.name}: '
        out += f'{ansi.fg("green")}'
        out += f'{len(self.tests_succeeded)}/{len(self.tests_executed)} passed'
        out += f'{ansi.reset()}.\n'

        if len(self.tests_failed) > 0:
            out += f'  ..{ansi.fg("red")}'
            out += f'{len(self.tests_failed)} tests failed'
            out += f'{ansi.reset()}:\n'
            for test in self.tests_failed:
                out += f'    {test}\n'
            out += '\n'

        if len(self.tests_timed_out) > 0:
            out += f'  ..{ansi.fg("yellow")}'
            out += f'{len(self.tests_timed_out)} tests timed out'
            out += f'{ansi.reset()}:\n'
            for test in self.tests_timed_out:
                out += f'    {test}\n'
            out += '\n'
        return out


class TestRunner(ABC, thread_utils.ThreadWithReturnValue):
    """A Base class for something that runs a test."""

    def __init__(self, params: TestingParameters):
        """Create a TestRunner.

        Args:
            params: Test running parameters.

        """
        super().__init__(self, target=self.begin, args=[params])
        self.params = params
        self.test_results = TestResults(
            name=self.get_name(),
            tests_executed=[],
            tests_succeeded=[],
            tests_failed=[],
            tests_timed_out=[],
        )
        self.lock = threading.Lock()

    @abstractmethod
    def get_name(self) -> str:
        """The name of this test collection."""
        pass

    def get_status(self) -> TestResults:
        """Ask the TestRunner for its status."""
        with self.lock:
            return self.test_results

    @abstractmethod
    def begin(self, params: TestingParameters) -> TestResults:
        """Start execution."""
        pass


class TemplatedTestRunner(TestRunner, ABC):
    """A TestRunner that has a recipe for executing the tests."""

    @abstractmethod
    def identify_tests(self) -> List[TestToRun]:
        """Return the list of TestToRun objects that should be executed."""
        pass

    @abstractmethod
    def run_test(self, test: TestToRun) -> TestResults:
        """Run a single test and return its TestResults."""
        pass

    def check_for_abort(self) -> bool:
        """Periodically called to check whether we need to stop."""

        if self.params.halt_event.is_set():
            logger.debug('Thread %s saw halt event; exiting.', self.get_name())
            return True

        if self.params.halt_on_error and len(self.test_results.tests_failed) > 0:
            logger.error('Thread %s saw abnormal results; exiting.', self.get_name())
            return True
        return False

    def persist_output(self, test: TestToRun, message: str, output: str) -> None:
        """Called to save the output of a test run."""

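        # Note: assumes the ./test_output directory already exists; open()
        # will not create it.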
        dest = f'{test.name}-output.txt'
        with open(f'./test_output/{dest}', 'w') as wf:
            print(message, file=wf)
            print('-' * len(message), file=wf)
            wf.write(output)

    def execute_commandline(
        self,
        test: TestToRun,
        *,
        timeout: float = 120.0,
    ) -> TestResults:
        """Execute a particular commandline to run a test."""

        try:
            output = exec_utils.cmd(
                test.cmdline,
                timeout_seconds=timeout,
            )
            if "***Test Failed***" in output:
                msg = f'{self.get_name()}: {test.name} ({test.cmdline}) failed; doctest failure message detected'
                logger.error(msg)
                self.persist_output(test, msg, output)
                return TestResults(
                    test.name,
                    [],
                    [],
                    [test.name],
                    [],
                )

            self.persist_output(
                test, f'{test.name} ({test.cmdline}) succeeded.', output
            )
            logger.debug(
                '%s: %s (%s) succeeded', self.get_name(), test.name, test.cmdline
            )
            return TestResults(test.name, [], [test.name], [], [])
        except subprocess.TimeoutExpired as e:
            msg = f'{self.get_name()}: {test.name} ({test.cmdline}) timed out after {e.timeout:.1f} seconds.'
            logger.error(msg)
            logger.debug(
                '%s: %s output when it timed out: %s',
                self.get_name(),
                test.name,
                e.output,
            )
            self.persist_output(test, msg, e.output.decode('utf-8'))
            return TestResults(
                test.name,
                [],
                [],
                [],
                [test.name],
            )
        except subprocess.CalledProcessError as e:
            msg = f'{self.get_name()}: {test.name} ({test.cmdline}) failed; exit code {e.returncode}'
            logger.error(msg)
            logger.debug(
                '%s: %s output when it failed: %s', self.get_name(), test.name, e.output
            )
            self.persist_output(test, msg, e.output.decode('utf-8'))
            return TestResults(
                test.name,
                [],
                [],
                [test.name],
                [],
            )

    @overrides
    def begin(self, params: TestingParameters) -> TestResults:
        logger.debug('Thread %s started.', self.get_name())
        interesting_tests = self.identify_tests()
        logger.debug(
            '%s: Identified %d tests to be run.',
            self.get_name(),
            len(interesting_tests),
        )

        # Note: because of @parallelize on run_test it actually
        # returns a SmartFuture with a TestResults inside of it.
        # That's the reason for this Any business.
        running: List[Any] = []
        for test_to_run in interesting_tests:
            running.append(self.run_test(test_to_run))
            logger.debug(
                '%s: Test %s started in the background.',
                self.get_name(),
                test_to_run.name,
            )
            self.test_results.tests_executed.append(test_to_run.name)

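        # wait_any yields each future as its background run completes;
        # _resolve() unwraps the TestResults value from the SmartFuture.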
        for future in smart_future.wait_any(running, log_exceptions=False):
            result = future._resolve()
            logger.debug('Test %s finished.', result.name)
            self.test_results += result
            if self.check_for_abort():
                logger.debug(
                    '%s: check_for_abort told us to exit early.', self.get_name()
                )
                return self.test_results

        logger.debug('Thread %s finished running all tests', self.get_name())
        return self.test_results


class UnittestTestRunner(TemplatedTestRunner):
    """Run all known Unittests."""

    @overrides
    def get_name(self) -> str:
        return "Unittests"

    @overrides
    def identify_tests(self) -> List[TestToRun]:
        ret = []
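        # Unittests are discovered by filename convention: *_test.py under ROOT.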
        for test in file_utils.get_matching_files_recursive(ROOT, '*_test.py'):
            basename = file_utils.without_path(test)
            if basename in TESTS_TO_SKIP:
                continue
            if config.config['coverage']:
                ret.append(
                    TestToRun(
                        name=basename,
                        kind='unittest capturing coverage',
                        cmdline=f'coverage run --source ../src {test} --unittests_ignore_perf 2>&1',
                    )
                )
                if basename in PERF_SENSATIVE_TESTS:
                    ret.append(
                        TestToRun(
                            name=basename,
                            kind='unittest w/o coverage to record perf',
                            cmdline=f'{test} 2>&1',
                        )
                    )
            else:
                ret.append(
                    TestToRun(
                        name=basename,
                        kind='unittest',
                        cmdline=f'{test} 2>&1',
                    )
                )
        return ret

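    # @par.parallelize runs this in a background thread and hands the caller
    # a SmartFuture right away (see TemplatedTestRunner.begin above).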
    @par.parallelize
    def run_test(self, test: TestToRun) -> TestResults:
        return self.execute_commandline(test)


class DoctestTestRunner(TemplatedTestRunner):
    """Run all known Doctests."""

    @overrides
    def get_name(self) -> str:
        return "Doctests"

    @overrides
    def identify_tests(self) -> List[TestToRun]:
        ret = []
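        # Doctest-bearing modules are discovered by grepping for "import doctest".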
        out = exec_utils.cmd(f'grep -lR "^ *import doctest" {ROOT}/*')
        for test in out.split('\n'):
            if re.match(r'.*\.py$', test):
                basename = file_utils.without_path(test)
                if basename in TESTS_TO_SKIP:
                    continue
                if config.config['coverage']:
                    ret.append(
                        TestToRun(
                            name=basename,
                            kind='doctest capturing coverage',
                            cmdline=f'coverage run --source ../src {test} 2>&1',
                        )
                    )
                    if basename in PERF_SENSATIVE_TESTS:
                        ret.append(
                            TestToRun(
                                name=basename,
                                kind='doctest w/o coverage to record perf',
                                cmdline=f'python3 {test} 2>&1',
                            )
                        )
                else:
                    ret.append(
                        TestToRun(
                            name=basename,
                            kind='doctest',
                            cmdline=f'python3 {test} 2>&1',
                        )
                    )
        return ret

    @par.parallelize
    def run_test(self, test: TestToRun) -> TestResults:
        return self.execute_commandline(test)


class IntegrationTestRunner(TemplatedTestRunner):
    """Run all known Integration tests."""

    @overrides
    def get_name(self) -> str:
        return "Integration Tests"

    @overrides
    def identify_tests(self) -> List[TestToRun]:
        ret = []
        for test in file_utils.get_matching_files_recursive(ROOT, '*_itest.py'):
            basename = file_utils.without_path(test)
            if basename in TESTS_TO_SKIP:
                continue
            if config.config['coverage']:
                ret.append(
                    TestToRun(
                        name=basename,
                        kind='integration test capturing coverage',
                        cmdline=f'coverage run --source ../src {test} 2>&1',
                    )
                )
                if basename in PERF_SENSATIVE_TESTS:
                    ret.append(
                        TestToRun(
                            name=basename,
                            kind='integration test w/o coverage to capture perf',
                            cmdline=f'{test} 2>&1',
                        )
                    )
            else:
                ret.append(
                    TestToRun(
                        name=basename, kind='integration test', cmdline=f'{test} 2>&1'
                    )
                )
        return ret

    @par.parallelize
    def run_test(self, test: TestToRun) -> TestResults:
        return self.execute_commandline(test)


def test_results_report(results: Dict[str, Optional[TestResults]]) -> int:
    """Give a final report about the tests that were run."""
    total_problems = 0
    for result in results.values():
        if result is None:
            print('Unexpected unhandled exception in test runner!!!')
            total_problems += 1
        else:
            print(result, end='')
            total_problems += len(result.tests_failed)
            total_problems += len(result.tests_timed_out)

    if total_problems > 0:
        print('Reminder: look in ./test_output to view test output logs')
    return total_problems


def code_coverage_report():
    """Give a final code coverage report."""
    text_utils.header('Code Coverage')
    exec_utils.cmd('coverage combine .coverage*')
    out = exec_utils.cmd(
        'coverage report --omit=config-3.*.py,*_test.py,*_itest.py --sort=-cover'
    )
    print(out)
    print(
        """To recall this report w/o re-running the tests:

    $ coverage report --omit=config-3.*.py,*_test.py,*_itest.py --sort=-cover

...from the 'tests' directory.  Note that subsequent calls to
run_tests.py with --coverage will clobber previous results.  See:

    https://coverage.readthedocs.io/en/6.2/
"""
    )


@bootstrap.initialize
def main() -> Optional[int]:
    saw_flag = False
    halt_event = threading.Event()
    threads: List[TestRunner] = []

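    # All runners share a single halt_event: when any runner fails (with
    # halt_on_error set), main() sets the event so the others wind down early.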
    halt_event.clear()
    params = TestingParameters(
        halt_on_error=True,
        halt_event=halt_event,
    )

    if config.config['coverage']:
        logger.debug('Clearing existing coverage data via "coverage erase".')
        exec_utils.cmd('coverage erase')

    if config.config['unittests'] or config.config['all']:
        saw_flag = True
        threads.append(UnittestTestRunner(params))
    if config.config['doctests'] or config.config['all']:
        saw_flag = True
        threads.append(DoctestTestRunner(params))
    if config.config['integration'] or config.config['all']:
        saw_flag = True
        threads.append(IntegrationTestRunner(params))

    if not saw_flag:
        config.print_usage()
        print('ERROR: one of --unittests, --doctests or --integration is required.')
        return 1

    for thread in threads:
        thread.start()

    results: Dict[str, Optional[TestResults]] = {}
    start_time = time.time()
    last_update = start_time
    still_running = {}

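    # Poll the runner threads twice a second: keep a progress bar repainted
    # and, once the run has been going for a while, periodically report which
    # tests are still in flight.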
    while len(results) != len(threads):
        started = 0
        done = 0
        failed = 0

        for thread in threads:
            tid = thread.name
            tr = thread.get_status()
            started += len(tr.tests_executed)
            failed += len(tr.tests_failed) + len(tr.tests_timed_out)
            done += failed + len(tr.tests_succeeded)
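            # Tests that have started but not yet succeeded, failed, or timed out.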
            running = set(tr.tests_executed)
            running -= set(tr.tests_failed)
            running -= set(tr.tests_succeeded)
            running -= set(tr.tests_timed_out)
            still_running[tid] = running

            if time.time() - start_time > 5.0:
                if time.time() - last_update > 3.0:
                    last_update = time.time()
                    update = []
                    for _, running_set in still_running.items():
                        for test_name in running_set:
                            update.append(test_name)
                    print(f'\r{ansi.clear_line()}')
                    if len(update) < 5:
                        print(f'Still running: {" ".join(update)}')
                    else:
                        print(f'Still running: {len(update)} tests.')

            if not thread.is_alive():
                if tid not in results:
                    result = thread.join()
                    if result:
                        results[tid] = result
                        if len(result.tests_failed) > 0:
                            logger.error(
                                'Thread %s returned abnormal results; killing the others.',
                                tid,
                            )
                            halt_event.set()
                    else:
                        logger.error(
                            'Thread %s took an unhandled exception... bug in run_tests.py?!  Aborting.',
                            tid,
                        )
                        halt_event.set()
                        results[tid] = None

        if started > 0:
            percent_done = 100.0 * done / started
        else:
            percent_done = 0.0

        if failed == 0:
            color = ansi.fg('green')
        else:
            color = ansi.fg('red')

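        # Repaint the progress bar in place (end='\r', no newline).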
        if percent_done < 100.0:
            print(
                text_utils.bar_graph_string(
                    done,
                    started,
                    text=text_utils.BarGraphText.FRACTION,
                    width=80,
                    fgcolor=color,
                ),
                end='\r',
                flush=True,
            )
        time.sleep(0.5)

    print(f'{ansi.clear_line()}Final Report:')
    if config.config['coverage']:
        code_coverage_report()
    total_problems = test_results_report(results)
    return total_problems


if __name__ == '__main__':
    main()