#!/usr/bin/env python3

"""
A smart, fast test runner.  Used in a git pre-commit hook.
"""

import logging
import os
import re
import subprocess
import threading
import time
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Dict, List, Optional

from overrides import overrides

from pyutils import ansi, bootstrap, config, dict_utils, exec_utils, text_utils
from pyutils.files import file_utils
from pyutils.parallelize import parallelize as par
from pyutils.parallelize import smart_future, thread_utils

logger = logging.getLogger(__name__)
args = config.add_commandline_args(
    f'Run Tests Driver ({__file__})', f'Args related to {__file__}'
)
args.add_argument('--unittests', '-u', action='store_true', help='Run unittests.')
args.add_argument('--doctests', '-d', action='store_true', help='Run doctests.')
args.add_argument(
    '--integration', '-i', action='store_true', help='Run integration tests.'
)
args.add_argument(
    '--all',
    '-a',
    action='store_true',
    help='Run unittests, doctests and integration tests.  Equivalent to -u -d -i',
)
args.add_argument(
    '--coverage',
    '-c',
    action='store_true',
    help='Run tests and capture code coverage data.',
)

HOME = os.environ['HOME']

# These tests will be run twice in --coverage mode: once to get code
# coverage and then again without coverage enabled.  This is because
# they pay attention to code performance, which is adversely affected
# by coverage.
PERF_SENSITIVE_TESTS = set(['string_utils_test.py'])
TESTS_TO_SKIP = set(['zookeeper_test.py', 'zookeeper.py', 'run_tests.py'])

ROOT = ".."


@dataclass
class TestingParameters:
    halt_on_error: bool
    """Should we stop as soon as one error has occurred?"""

    halt_event: threading.Event
    """An event that, when set, indicates to stop ASAP."""


@dataclass
class TestToRun:
    name: str
    """The name of the test"""

    kind: str
    """The kind of the test"""

    cmdline: str
    """The command line to execute"""


@dataclass
class TestResults:
    name: str
    """The name of this test / set of tests."""

    tests_executed: Dict[str, float]
    """Tests that were executed."""

    tests_succeeded: List[str]
    """Tests that succeeded."""

    tests_failed: List[str]
    """Tests that failed."""

    tests_timed_out: List[str]
    """Tests that timed out."""

    def __add__(self, other):
        merged = dict_utils.coalesce(
            [self.tests_executed, other.tests_executed],
            aggregation_function=dict_utils.raise_on_duplicated_keys,
        )
        self.tests_executed = merged
        self.tests_succeeded.extend(other.tests_succeeded)
        self.tests_failed.extend(other.tests_failed)
        self.tests_timed_out.extend(other.tests_timed_out)
        return self

    __radd__ = __add__
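    # Since __add__ (and __radd__) are defined, per-test TestResults can be
    # merged with '+', or with sum() given a TestResults start value; a
    # hypothetical example (not used in this file):
    #   total = sum(results_list, TestResults('total', {}, [], [], []))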

    def __repr__(self) -> str:
        out = f'{self.name}: '
        out += f'{ansi.fg("green")}'
        out += f'{len(self.tests_succeeded)}/{len(self.tests_executed)} passed'
        out += f'{ansi.reset()}.\n'

        if len(self.tests_failed) > 0:
            out += f'  ..{ansi.fg("red")}'
            out += f'{len(self.tests_failed)} tests failed'
            out += f'{ansi.reset()}:\n'
            for test in self.tests_failed:
                out += f'    {test}\n'
            out += '\n'

        if len(self.tests_timed_out) > 0:
            out += f'  ..{ansi.fg("yellow")}'
            out += f'{len(self.tests_timed_out)} tests timed out'
            out += f'{ansi.reset()}:\n'
            for test in self.tests_timed_out:
                out += f'    {test}\n'
            out += '\n'
        return out


class TestRunner(ABC, thread_utils.ThreadWithReturnValue):
    """A base class for something that runs a test."""

    def __init__(self, params: TestingParameters):
        """Create a TestRunner.

        Args:
            params: Test running parameters.

        """
        super().__init__(target=self.begin, args=[params])
        self.params = params
        self.test_results = TestResults(
            name=self.get_name(),
            tests_executed={},
            tests_succeeded=[],
            tests_failed=[],
            tests_timed_out=[],
        )
        self.lock = threading.Lock()

    @abstractmethod
    def get_name(self) -> str:
        """The name of this test collection."""
        pass

    def get_status(self) -> TestResults:
        """Ask the TestRunner for its status."""
        with self.lock:
            return self.test_results

    @abstractmethod
    def begin(self, params: TestingParameters) -> TestResults:
        """Start execution."""
        pass


class TemplatedTestRunner(TestRunner, ABC):
    """A TestRunner that has a recipe for executing the tests."""
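
    # The recipe: subclasses supply get_name(), identify_tests() and
    # run_test(); begin() below gathers the tests, kicks them off in the
    # background (run_test is decorated with @par.parallelize in the
    # subclasses) and merges each test's TestResults as it completes.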

    @abstractmethod
    def identify_tests(self) -> List[TestToRun]:
        """Return a list of TestToRun objects (name, kind, cmdline) that should be executed."""
        pass

    @abstractmethod
    def run_test(self, test: TestToRun) -> TestResults:
        """Run a single test and return its TestResults."""
        pass

    def check_for_abort(self) -> bool:
        """Periodically called to check whether we need to stop."""

        if self.params.halt_event.is_set():
            logger.debug('Thread %s saw halt event; exiting.', self.get_name())
            return True

        if self.params.halt_on_error and len(self.test_results.tests_failed) > 0:
            logger.error('Thread %s saw abnormal results; exiting.', self.get_name())
            return True
        return False

    def persist_output(self, test: TestToRun, message: str, output: str) -> None:
        """Called to save the output of a test run."""

        dest = f'{test.name}-output.txt'
        with open(f'./test_output/{dest}', 'w') as wf:
            print(message, file=wf)
            print('-' * len(message), file=wf)
            wf.write(output)

    def execute_commandline(
        self,
        test: TestToRun,
        *,
        timeout: float = 120.0,
    ) -> TestResults:
        """Execute a particular commandline to run a test."""

        try:
            output = exec_utils.cmd(
                test.cmdline,
                timeout_seconds=timeout,
            )
            if "***Test Failed***" in output:
                msg = f'{self.get_name()}: {test.name} ({test.cmdline}) failed; doctest failure message detected'
                logger.error(msg)
                self.persist_output(test, msg, output)
                return TestResults(
                    test.name,
                    {},
                    [],
                    [test.name],
                    [],
                )

            self.persist_output(
                test, f'{test.name} ({test.cmdline}) succeeded.', output
            )
            logger.debug(
                '%s: %s (%s) succeeded', self.get_name(), test.name, test.cmdline
            )
            return TestResults(test.name, {}, [test.name], [], [])
        except subprocess.TimeoutExpired as e:
            msg = f'{self.get_name()}: {test.name} ({test.cmdline}) timed out after {e.timeout:.1f} seconds.'
            logger.error(msg)
            logger.debug(
                '%s: %s output when it timed out: %s',
                self.get_name(),
                test.name,
                e.output,
            )
            self.persist_output(test, msg, e.output.decode('utf-8'))
            return TestResults(
                test.name,
                {},
                [],
                [],
                [test.name],
            )
        except subprocess.CalledProcessError as e:
            msg = f'{self.get_name()}: {test.name} ({test.cmdline}) failed; exit code {e.returncode}'
            logger.error(msg)
            logger.debug(
                '%s: %s output when it failed: %s', self.get_name(), test.name, e.output
            )
            self.persist_output(test, msg, e.output.decode('utf-8'))
            return TestResults(
                test.name,
                {},
                [],
                [test.name],
                [],
            )

    @overrides
    def begin(self, params: TestingParameters) -> TestResults:
        logger.debug('Thread %s started.', self.get_name())
        interesting_tests = self.identify_tests()
        logger.debug(
            '%s: Identified %d tests to be run.',
            self.get_name(),
            len(interesting_tests),
        )

        # Note: because of @parallelize on run_test it actually
        # returns a SmartFuture with a TestResults inside of it.
        # That's the reason for this Any business.
        running: List[Any] = []
        for test_to_run in interesting_tests:
            running.append(self.run_test(test_to_run))
            logger.debug(
                '%s: Test %s started in the background.',
                self.get_name(),
                test_to_run.name,
            )
            self.test_results.tests_executed[test_to_run.name] = time.time()

        for result in smart_future.wait_any(running, log_exceptions=False):
            logger.debug('Test %s finished.', result.name)

            # We sometimes run the same test more than once.  Do not allow
            # one run's results to clobber the other's.
            self.test_results += result
            if self.check_for_abort():
                logger.debug(
                    '%s: check_for_abort told us to exit early.', self.get_name()
                )
                return self.test_results

        logger.debug('Thread %s finished running all tests', self.get_name())
        return self.test_results


class UnittestTestRunner(TemplatedTestRunner):
    """Run all known Unittests."""

    @overrides
    def get_name(self) -> str:
        return "Unittests"

    @overrides
    def identify_tests(self) -> List[TestToRun]:
        ret = []
        for test in file_utils.get_matching_files_recursive(ROOT, '*_test.py'):
            basename = file_utils.without_path(test)
            if basename in TESTS_TO_SKIP:
                continue
            if config.config['coverage']:
                ret.append(
                    TestToRun(
                        name=basename,
                        kind='unittest capturing coverage',
                        cmdline=f'coverage run --source ../src {test} --unittests_ignore_perf 2>&1',
                    )
                )
                if basename in PERF_SENSITIVE_TESTS:
                    ret.append(
                        TestToRun(
                            name=f'{basename}_no_coverage',
                            kind='unittest w/o coverage to record perf',
                            cmdline=f'{test} 2>&1',
                        )
                    )
            else:
                ret.append(
                    TestToRun(
                        name=basename,
                        kind='unittest',
                        cmdline=f'{test} 2>&1',
                    )
                )
        return ret

    @par.parallelize
    def run_test(self, test: TestToRun) -> TestResults:
        return self.execute_commandline(test)


class DoctestTestRunner(TemplatedTestRunner):
    """Run all known Doctests."""

    @overrides
    def get_name(self) -> str:
        return "Doctests"

    @overrides
    def identify_tests(self) -> List[TestToRun]:
        ret = []
        out = exec_utils.cmd(f'grep -lR "^ *import doctest" {ROOT}/*')
        for test in out.split('\n'):
            if re.match(r'.*\.py$', test):
                basename = file_utils.without_path(test)
                if basename in TESTS_TO_SKIP:
                    continue
                if config.config['coverage']:
                    ret.append(
                        TestToRun(
                            name=basename,
                            kind='doctest capturing coverage',
                            cmdline=f'coverage run --source ../src {test} 2>&1',
                        )
                    )
                    if basename in PERF_SENSITIVE_TESTS:
                        ret.append(
                            TestToRun(
                                name=f'{basename}_no_coverage',
                                kind='doctest w/o coverage to record perf',
                                cmdline=f'python3 {test} 2>&1',
                            )
                        )
                else:
                    ret.append(
                        TestToRun(
                            name=basename,
                            kind='doctest',
                            cmdline=f'python3 {test} 2>&1',
                        )
                    )
        return ret

    @par.parallelize
    def run_test(self, test: TestToRun) -> TestResults:
        return self.execute_commandline(test)


class IntegrationTestRunner(TemplatedTestRunner):
    """Run all known Integration tests."""

    @overrides
    def get_name(self) -> str:
        return "Integration Tests"

    @overrides
    def identify_tests(self) -> List[TestToRun]:
        ret = []
        for test in file_utils.get_matching_files_recursive(ROOT, '*_itest.py'):
            basename = file_utils.without_path(test)
            if basename in TESTS_TO_SKIP:
                continue
            if config.config['coverage']:
                ret.append(
                    TestToRun(
                        name=basename,
                        kind='integration test capturing coverage',
                        cmdline=f'coverage run --source ../src {test} 2>&1',
                    )
                )
                if basename in PERF_SENSITIVE_TESTS:
                    ret.append(
                        TestToRun(
                            name=f'{basename}_no_coverage',
                            kind='integration test w/o coverage to capture perf',
                            cmdline=f'{test} 2>&1',
                        )
                    )
            else:
                ret.append(
                    TestToRun(
                        name=basename, kind='integration test', cmdline=f'{test} 2>&1'
                    )
                )
        return ret

    @par.parallelize
    def run_test(self, test: TestToRun) -> TestResults:
        return self.execute_commandline(test)


def test_results_report(results: Dict[str, Optional[TestResults]]) -> int:
    """Give a final report about the tests that were run."""
    total_problems = 0
    for result in results.values():
        if result is None:
            print('Unexpected unhandled exception in test runner!!!')
            total_problems += 1
        else:
            print(result, end='')
            total_problems += len(result.tests_failed)
            total_problems += len(result.tests_timed_out)

    if total_problems > 0:
        print('Reminder: look in ./test_output to view test output logs')
    return total_problems


def code_coverage_report():
    """Give a final code coverage report."""
    text_utils.header('Code Coverage')
    exec_utils.cmd('coverage combine .coverage*')
    out = exec_utils.cmd(
        'coverage report --omit=config-3.*.py,*_test.py,*_itest.py --sort=-cover'
    )
    print(out)
    print(
        """To recall this report w/o re-running the tests:

    $ coverage report --omit=config-3.*.py,*_test.py,*_itest.py --sort=-cover

...from the 'tests' directory.  Note that subsequent calls to
run_tests.py with --coverage will clobber previous results.  See:

    https://coverage.readthedocs.io/en/6.2/
"""
    )


@bootstrap.initialize
def main() -> Optional[int]:
    saw_flag = False
    halt_event = threading.Event()
    threads: List[TestRunner] = []

    halt_event.clear()
    params = TestingParameters(
        halt_on_error=True,
        halt_event=halt_event,
    )

    if config.config['coverage']:
        logger.debug('Clearing existing coverage data via "coverage erase".')
        exec_utils.cmd('coverage erase')

    if config.config['unittests'] or config.config['all']:
        saw_flag = True
        threads.append(UnittestTestRunner(params))
    if config.config['doctests'] or config.config['all']:
        saw_flag = True
        threads.append(DoctestTestRunner(params))
    if config.config['integration'] or config.config['all']:
        saw_flag = True
        threads.append(IntegrationTestRunner(params))

    if not saw_flag:
        config.print_usage()
        print('ERROR: one of --unittests, --doctests or --integration is required.')
        return 1

    for thread in threads:
        thread.start()

    results: Dict[str, Optional[TestResults]] = {}
    start_time = time.time()
    last_update = start_time
    still_running = {}

    while len(results) != len(threads):
        started = 0
        done = 0
        failed = 0

        for thread in threads:
            tid = thread.name
            tr = thread.get_status()
            started += len(tr.tests_executed)
            failed += len(tr.tests_failed) + len(tr.tests_timed_out)
            done += failed + len(tr.tests_succeeded)
            running = set(tr.tests_executed.keys())
            running -= set(tr.tests_failed)
            running -= set(tr.tests_succeeded)
            running -= set(tr.tests_timed_out)
            running_with_start_time = {
                test: tr.tests_executed[test] for test in running
            }
            still_running[tid] = running_with_start_time

            now = time.time()
            if now - start_time > 5.0:
                if now - last_update > 3.0:
                    last_update = now
                    update = []
                    for _, running_dict in still_running.items():
                        for test_name, test_start_time in running_dict.items():
                            if now - test_start_time > 10.0:
                                update.append(f'{test_name}@{now-test_start_time:.1f}s')
                            else:
                                update.append(test_name)
                    print(f'\r{ansi.clear_line()}')
                    if len(update) < 4:
                        print(f'Still running: {",".join(update)}')
                    else:
                        print(f'Still running: {len(update)} tests.')

            if not thread.is_alive():
                if tid not in results:
                    result = thread.join()
                    if result:
                        results[tid] = result
                        if len(result.tests_failed) > 0:
                            logger.error(
                                'Thread %s returned abnormal results; killing the others.',
                                tid,
                            )
                            halt_event.set()
                    else:
                        logger.error(
                            'Thread %s took an unhandled exception... bug in run_tests.py?!  Aborting.',
                            tid,
                        )
                        halt_event.set()
                        results[tid] = None

        if failed == 0:
            color = ansi.fg('green')
        else:
            color = ansi.fg('red')

        if started > 0:
            percent_done = done / started * 100.0
        else:
            percent_done = 0.0

        if percent_done < 100.0:
            print(
                text_utils.bar_graph_string(
                    done,
                    started,
                    text=text_utils.BarGraphText.FRACTION,
                    width=80,
                    fgcolor=color,
                ),
                end='\r',
                flush=True,
            )
        time.sleep(0.5)

    print(f'{ansi.clear_line()}Final Report:')
    if config.config['coverage']:
        code_coverage_report()
    total_problems = test_results_report(results)
    if total_problems > 0:
        logger.error(
            'Exiting with non-zero return code %d due to problems.', total_problems
        )
    return total_problems


if __name__ == '__main__':
    main()
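
# A sketch of the sort of git pre-commit hook mentioned in the module
# docstring above (hypothetical; adjust paths to your checkout):
#
#   #!/bin/sh
#   cd tests && ./run_tests.py --all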