Fix wrong exception type in catch: handle subprocess.TimeoutExpired, not TimeoutError.
[pyutils.git] / tests / run_tests.py
#!/usr/bin/env python3

"""
A smart, fast test runner.  Used in a git pre-commit hook.
"""

from __future__ import annotations

import logging
import os
import re
import subprocess
import threading
import time
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Dict, List, Optional

from overrides import overrides

from pyutils import ansi, bootstrap, config, dict_utils, exec_utils, text_utils
from pyutils.files import file_utils
from pyutils.parallelize import parallelize as par
from pyutils.parallelize import smart_future, thread_utils

logger = logging.getLogger(__name__)
args = config.add_commandline_args(
    f'Run Tests Driver ({__file__})', f'Args related to {__file__}'
)
args.add_argument('--unittests', '-u', action='store_true', help='Run unittests.')
args.add_argument('--doctests', '-d', action='store_true', help='Run doctests.')
args.add_argument(
    '--integration', '-i', action='store_true', help='Run integration tests.'
)
args.add_argument(
    '--all',
    '-a',
    action='store_true',
    help='Run unittests, doctests and integration tests.  Equivalent to -u -d -i.',
)
args.add_argument(
    '--coverage',
    '-c',
    action='store_true',
    help='Run tests and capture code coverage data.',
)

HOME = os.environ['HOME']

# These tests will be run twice in --coverage mode: once to gather code
# coverage data and then again with coverage disabled.  This is because
# they pay attention to code performance, which is adversely affected
# by coverage instrumentation.
PERF_SENSATIVE_TESTS = set(['string_utils_test.py'])
TESTS_TO_SKIP = set(['zookeeper_test.py', 'zookeeper.py', 'run_tests.py'])

ROOT = ".."


@dataclass
class TestingParameters:
    halt_on_error: bool
    """Should we stop as soon as one error has occurred?"""

    halt_event: threading.Event
    """An event that, when set, indicates to stop ASAP."""


@dataclass
class TestToRun:
    name: str
    """The name of the test"""

    kind: str
    """The kind of the test"""

    cmdline: str
    """The command line to execute"""


@dataclass
class TestResults:
    name: str
    """The name of this test / set of tests."""

    tests_executed: Dict[str, float]
    """Tests that were executed."""

    tests_succeeded: List[str]
    """Tests that succeeded."""

    tests_failed: List[str]
    """Tests that failed."""

    tests_timed_out: List[str]
    """Tests that timed out."""

    def __add__(self, other):
        merged = dict_utils.coalesce(
            [self.tests_executed, other.tests_executed],
            aggregation_function=dict_utils.raise_on_duplicated_keys,
        )
        self.tests_executed = merged
        self.tests_succeeded.extend(other.tests_succeeded)
        self.tests_failed.extend(other.tests_failed)
        self.tests_timed_out.extend(other.tests_timed_out)
        return self

    __radd__ = __add__
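
    # Aggregation sketch (illustrative only; these exact calls do not appear
    # in this file): __add__ merges `other` into `self` in place and returns
    # `self`, so callers can fold many results together with +=, e.g.:
    #
    #   total = TestResults.empty_test_results('all tests')
    #   total += TestResults.single_test_succeeded('foo_test.py')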

    @staticmethod
    def empty_test_results(suite_name: str) -> TestResults:
        return TestResults(
            name=suite_name,
            tests_executed={},
            tests_succeeded=[],
            tests_failed=[],
            tests_timed_out=[],
        )

    @staticmethod
    def single_test_succeeded(name: str) -> TestResults:
        return TestResults(name, {}, [name], [], [])

    @staticmethod
    def single_test_failed(name: str) -> TestResults:
        return TestResults(
            name,
            {},
            [],
            [name],
            [],
        )

    @staticmethod
    def single_test_timed_out(name: str) -> TestResults:
        return TestResults(
            name,
            {},
            [],
            [],
            [name],
        )

    def __repr__(self) -> str:
        out = f'{self.name}: '
        out += f'{ansi.fg("green")}'
        out += f'{len(self.tests_succeeded)}/{len(self.tests_executed)} passed'
        out += f'{ansi.reset()}.\n'

        if len(self.tests_failed) > 0:
            out += f'  ..{ansi.fg("red")}'
            out += f'{len(self.tests_failed)} tests failed'
            out += f'{ansi.reset()}:\n'
            for test in self.tests_failed:
                out += f'    {test}\n'
            out += '\n'

        if len(self.tests_timed_out) > 0:
            out += f'  ..{ansi.fg("lightning yellow")}'
            out += f'{len(self.tests_timed_out)} tests timed out'
            out += f'{ansi.reset()}:\n'
            for test in self.tests_timed_out:
                out += f'    {test}\n'
            out += '\n'
        return out


class TestRunner(ABC, thread_utils.ThreadWithReturnValue):
    """A base class for something that runs a test."""

    def __init__(self, params: TestingParameters):
        """Create a TestRunner.

        Args:
            params: Test running parameters.
        """
        super().__init__(target=self.begin, args=[params])
        self.params = params
        self.test_results = TestResults.empty_test_results(self.get_name())
        self.lock = threading.Lock()

    @abstractmethod
    def get_name(self) -> str:
        """The name of this test collection."""
        pass

    def get_status(self) -> TestResults:
        """Ask the TestRunner for its status."""
        with self.lock:
            return self.test_results

    @abstractmethod
    def begin(self, params: TestingParameters) -> TestResults:
        """Start execution."""
        pass


class TemplatedTestRunner(TestRunner, ABC):
    """A TestRunner that has a recipe for executing the tests."""

    @abstractmethod
    def identify_tests(self) -> List[TestToRun]:
        """Return a list of TestToRun objects that should be executed."""
        pass

    @abstractmethod
    def run_test(self, test: TestToRun) -> TestResults:
        """Run a single test and return its TestResults."""
        pass

    def check_for_abort(self) -> bool:
        """Called periodically to check whether we need to stop."""

        if self.params.halt_event.is_set():
            logger.debug('Thread %s saw halt event; exiting.', self.get_name())
            return True

        if self.params.halt_on_error and len(self.test_results.tests_failed) > 0:
            logger.error('Thread %s saw abnormal results; exiting.', self.get_name())
            return True
        return False

    def persist_output(self, test: TestToRun, message: str, output: str) -> None:
        """Called to save the output of a test run."""

        # Note: the ./test_output directory is assumed to already exist.
        dest = f'{test.name}-output.txt'
        with open(f'./test_output/{dest}', 'w') as wf:
            print(message, file=wf)
            print('-' * len(message), file=wf)
            wf.write(output)

    def execute_commandline(
        self,
        test: TestToRun,
        *,
        timeout: float = 120.0,
    ) -> TestResults:
        """Execute a particular commandline to run a test."""

        msg = f'{self.get_name()}: {test.name} ({test.cmdline}) '
        try:
            output = exec_utils.cmd(
                test.cmdline,
                timeout_seconds=timeout,
            )
            if "***Test Failed***" in output:
                msg += 'failed; doctest failure message detected.'
                logger.error(msg)
                self.persist_output(test, msg, output)
                return TestResults.single_test_failed(test.name)

            msg += 'succeeded.'
            self.persist_output(test, msg, output)
            logger.debug(msg)
            return TestResults.single_test_succeeded(test.name)

        except subprocess.TimeoutExpired as e:
            msg += f'timed out after {e.timeout:.1f} seconds.'
            logger.error(msg)
            logger.debug(
                '%s: %s output when it timed out: %s',
                self.get_name(),
                test.name,
                e.output,
            )
            self.persist_output(test, msg, e.output.decode('utf-8'))
            return TestResults.single_test_timed_out(test.name)

        except subprocess.CalledProcessError as e:
            msg += f'failed with exit code {e.returncode}.'
            logger.error(msg)
            logger.debug(
                '%s: %s output when it failed: %s', self.get_name(), test.name, e.output
            )
            self.persist_output(test, msg, e.output.decode('utf-8'))
            return TestResults.single_test_failed(test.name)

    @overrides
    def begin(self, params: TestingParameters) -> TestResults:
        logger.debug('Thread %s started.', self.get_name())
        interesting_tests = self.identify_tests()
        logger.debug(
            '%s: Identified %d tests to be run.',
            self.get_name(),
            len(interesting_tests),
        )

        # Note: because of @parallelize on run_test it actually
        # returns a SmartFuture with a TestResults inside of it.
        # That's the reason for this Any business.
        running: List[Any] = []
        for test_to_run in interesting_tests:
            running.append(self.run_test(test_to_run))
            logger.debug(
                '%s: Test %s started in the background.',
                self.get_name(),
                test_to_run.name,
            )
            self.test_results.tests_executed[test_to_run.name] = time.time()

        for result in smart_future.wait_any(running, log_exceptions=False):
            logger.debug('Test %s finished.', result.name)
            self.test_results += result

            if self.check_for_abort():
                logger.debug(
                    '%s: check_for_abort told us to exit early.', self.get_name()
                )
                return self.test_results

        logger.debug('Thread %s finished running all tests.', self.get_name())
        return self.test_results


class UnittestTestRunner(TemplatedTestRunner):
    """Run all known Unittests."""

    @overrides
    def get_name(self) -> str:
        return "Unittests"

    @overrides
    def identify_tests(self) -> List[TestToRun]:
        ret = []
        for test in file_utils.get_matching_files_recursive(ROOT, '*_test.py'):
            basename = file_utils.without_path(test)
            if basename in TESTS_TO_SKIP:
                continue
            if config.config['coverage']:
                ret.append(
                    TestToRun(
                        name=basename,
                        kind='unittest capturing coverage',
                        cmdline=f'coverage run --source ../src {test} --unittests_ignore_perf 2>&1',
                    )
                )
                if basename in PERF_SENSATIVE_TESTS:
                    ret.append(
                        TestToRun(
                            name=f'{basename}_no_coverage',
                            kind='unittest w/o coverage to record perf',
                            cmdline=f'{test} 2>&1',
                        )
                    )
            else:
                ret.append(
                    TestToRun(
                        name=basename,
                        kind='unittest',
                        cmdline=f'{test} 2>&1',
                    )
                )
        return ret

    @par.parallelize
    def run_test(self, test: TestToRun) -> TestResults:
        return self.execute_commandline(test)


class DoctestTestRunner(TemplatedTestRunner):
    """Run all known Doctests."""

    @overrides
    def get_name(self) -> str:
        return "Doctests"

    @overrides
    def identify_tests(self) -> List[TestToRun]:
        ret = []
        out = exec_utils.cmd(f'grep -lR "^ *import doctest" {ROOT}/*')
        for test in out.split('\n'):
            if re.match(r'.*\.py$', test):
                basename = file_utils.without_path(test)
                if basename in TESTS_TO_SKIP:
                    continue
                if config.config['coverage']:
                    ret.append(
                        TestToRun(
                            name=basename,
                            kind='doctest capturing coverage',
                            cmdline=f'coverage run --source ../src {test} 2>&1',
                        )
                    )
                    if basename in PERF_SENSATIVE_TESTS:
                        ret.append(
                            TestToRun(
                                name=f'{basename}_no_coverage',
                                kind='doctest w/o coverage to record perf',
                                cmdline=f'python3 {test} 2>&1',
                            )
                        )
                else:
                    ret.append(
                        TestToRun(
                            name=basename,
                            kind='doctest',
                            cmdline=f'python3 {test} 2>&1',
                        )
                    )
        return ret

    @par.parallelize
    def run_test(self, test: TestToRun) -> TestResults:
        return self.execute_commandline(test)


class IntegrationTestRunner(TemplatedTestRunner):
    """Run all known Integration tests."""

    @overrides
    def get_name(self) -> str:
        return "Integration Tests"

    @overrides
    def identify_tests(self) -> List[TestToRun]:
        ret = []
        for test in file_utils.get_matching_files_recursive(ROOT, '*_itest.py'):
            basename = file_utils.without_path(test)
            if basename in TESTS_TO_SKIP:
                continue
            if config.config['coverage']:
                ret.append(
                    TestToRun(
                        name=basename,
                        kind='integration test capturing coverage',
                        cmdline=f'coverage run --source ../src {test} 2>&1',
                    )
                )
                if basename in PERF_SENSATIVE_TESTS:
                    ret.append(
                        TestToRun(
                            name=f'{basename}_no_coverage',
                            kind='integration test w/o coverage to capture perf',
                            cmdline=f'{test} 2>&1',
                        )
                    )
            else:
                ret.append(
                    TestToRun(
                        name=basename, kind='integration test', cmdline=f'{test} 2>&1'
                    )
                )
        return ret

    @par.parallelize
    def run_test(self, test: TestToRun) -> TestResults:
        return self.execute_commandline(test)


def test_results_report(results: Dict[str, Optional[TestResults]]) -> int:
    """Give a final report about the tests that were run."""
    total_problems = 0
    for result in results.values():
        if result is None:
            print('Unexpected unhandled exception in test runner!!!')
            total_problems += 1
        else:
            print(result, end='')
            total_problems += len(result.tests_failed)
            total_problems += len(result.tests_timed_out)

    if total_problems > 0:
        print('Reminder: look in ./test_output to view test output logs')
    return total_problems


def code_coverage_report():
    """Give a final code coverage report."""
    text_utils.header('Code Coverage')
    exec_utils.cmd('coverage combine .coverage*')
    out = exec_utils.cmd(
        'coverage report --omit=config-3.*.py,*_test.py,*_itest.py --sort=-cover'
    )
    print(out)
    print(
        f"""To recall this report w/o re-running the tests:

    $ {ansi.bold()}coverage report --omit=config-3.*.py,*_test.py,*_itest.py --sort=-cover{ansi.reset()}

...from the 'tests' directory.  Note that subsequent calls to
run_tests.py with --coverage will clobber previous results.  See:

    https://coverage.readthedocs.io/en/6.2/
"""
    )

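# Not invoked by this script, but coverage.py can also render the same data
# as HTML with `coverage html` (written to ./htmlcov by default), should a
# browsable report be preferred.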

@bootstrap.initialize
def main() -> Optional[int]:
    saw_flag = False
    threads: List[TestRunner] = []

    halt_event = threading.Event()
    halt_event.clear()
    params = TestingParameters(
        halt_on_error=True,
        halt_event=halt_event,
    )

    if config.config['coverage']:
        logger.debug('Clearing existing coverage data via "coverage erase".')
        exec_utils.cmd('coverage erase')
    if config.config['unittests'] or config.config['all']:
        saw_flag = True
        threads.append(UnittestTestRunner(params))
    if config.config['doctests'] or config.config['all']:
        saw_flag = True
        threads.append(DoctestTestRunner(params))
    if config.config['integration'] or config.config['all']:
        saw_flag = True
        threads.append(IntegrationTestRunner(params))

    if not saw_flag:
        config.print_usage()
        config.error('One of --unittests, --doctests or --integration is required.', 1)

    for thread in threads:
        thread.start()

    start_time = time.time()
    last_update = start_time
    results: Dict[str, Optional[TestResults]] = {}
    still_running = {}

    while len(results) != len(threads):
        started = 0
        done = 0
        failed = 0

        for thread in threads:
            tid = thread.name
            tr = thread.get_status()
            # Tally this thread's progress; count its failures/timeouts once.
            started += len(tr.tests_executed)
            thread_failed = len(tr.tests_failed) + len(tr.tests_timed_out)
            failed += thread_failed
            done += thread_failed + len(tr.tests_succeeded)
            running = set(tr.tests_executed.keys())
            running -= set(tr.tests_failed)
            running -= set(tr.tests_succeeded)
            running -= set(tr.tests_timed_out)
            running_with_start_time = {
                test: tr.tests_executed[test] for test in running
            }
            still_running[tid] = running_with_start_time

            # Maybe print tests that are still running.
            now = time.time()
            if now - start_time > 5.0:
                if now - last_update > 3.0:
                    last_update = now
                    update = []
                    for _, running_dict in still_running.items():
                        for test_name, test_start_time in running_dict.items():
                            if now - test_start_time > 10.0:
                                update.append(f'{test_name}@{now-test_start_time:.1f}s')
                            else:
                                update.append(test_name)
                    print(f'\r{ansi.clear_line()}')
                    if len(update) < 4:
                        print(f'Still running: {",".join(update)}')
                    else:
                        print(f'Still running: {len(update)} tests.')

            # Collect results from any thread that has finished and, on
            # abnormal results, signal the other threads to stop too.
            if not thread.is_alive():
                if tid not in results:
                    result = thread.join()
                    if result:
                        results[tid] = result
                        if (len(result.tests_failed) + len(result.tests_timed_out)) > 0:
                            logger.error(
                                'Thread %s returned abnormal results; killing the others.',
                                tid,
                            )
                            halt_event.set()
                    else:
                        logger.error(
                            'Thread %s raised an unhandled exception... bug in run_tests.py?!  Aborting.',
                            tid,
                        )
                        halt_event.set()
                        results[tid] = None

        color = ansi.fg('green')
        if failed > 0:
            color = ansi.fg('red')

        if started > 0:
            percent_done = done / started * 100.0
        else:
            percent_done = 0.0

        if percent_done < 100.0:
            print(
                text_utils.bar_graph_string(
                    done,
                    started,
                    text=text_utils.BarGraphText.FRACTION,
                    width=72,
                    fgcolor=color,
                ),
                end='',
                flush=True,
            )
            print(f'  {color}{now - start_time:.1f}s{ansi.reset()}', end='\r')
        time.sleep(0.1)

    print(f'{ansi.clear_line()}\n{ansi.underline()}Final Report:{ansi.reset()}')
    if config.config['coverage']:
        code_coverage_report()
    print(f'Test suite runtime: {time.time() - start_time:.1f}s')
    total_problems = test_results_report(results)
    if total_problems > 0:
        logger.error(
            'Exiting with non-zero return code %d due to problems.', total_problems
        )
    return total_problems


if __name__ == '__main__':
    main()