Stop calling internal method _resolve in run_tests.py.
[pyutils.git] / tests / run_tests.py
1 #!/usr/bin/env python3
2
3 """
4 A smart, fast test runner.  Used in a git pre-commit hook.
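
Example (illustrative):

    $ ./run_tests.py --all --coverage
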
5 """
6
7 import logging
8 import os
9 import re
10 import subprocess
11 import threading
12 import time
13 from abc import ABC, abstractmethod
14 from dataclasses import dataclass
15 from typing import Any, Dict, List, Optional
16
17 from overrides import overrides
18
19 from pyutils import ansi, bootstrap, config, dict_utils, exec_utils, text_utils
20 from pyutils.files import file_utils
21 from pyutils.parallelize import deferred_operand
22 from pyutils.parallelize import parallelize as par
23 from pyutils.parallelize import smart_future, thread_utils
24
25 logger = logging.getLogger(__name__)
26 args = config.add_commandline_args(
27     f'Run Tests Driver ({__file__})', f'Args related to {__file__}'
28 )
29 args.add_argument('--unittests', '-u', action='store_true', help='Run unittests.')
30 args.add_argument('--doctests', '-d', action='store_true', help='Run doctests.')
31 args.add_argument(
32     '--integration', '-i', action='store_true', help='Run integration tests.'
33 )
34 args.add_argument(
35     '--all',
36     '-a',
37     action='store_true',
38     help='Run unittests, doctests and integration tests.  Equivalent to -u -d -i',
39 )
40 args.add_argument(
41     '--coverage',
42     '-c',
43     action='store_true',
44     help='Run tests and capture code coverage data.',
45 )
46
47 HOME = os.environ['HOME']
48
49 # These tests will be run twice in --coverage mode: once to get code
50 # coverage and then again with coverage not enabled.  This is because
51 # they pay attention to code performance, which is adversely affected
52 # by coverage.
53 PERF_SENSATIVE_TESTS = set(['string_utils_test.py'])
54 TESTS_TO_SKIP = set(['zookeeper_test.py', 'zookeeper.py', 'run_tests.py'])
55
56 ROOT = ".."
57
58
59 @dataclass
60 class TestingParameters:
61     halt_on_error: bool
62     """Should we stop as soon as one error has occurred?"""
63
64     halt_event: threading.Event
65     """An event that, when set, indicates to stop ASAP."""
66
67
68 @dataclass
69 class TestToRun:
70     name: str
71     """The name of the test"""
72
73     kind: str
74     """The kind of the test"""
75
76     cmdline: str
77     """The command line to execute"""
78
79
80 @dataclass
81 class TestResults:
82     name: str
83     """The name of this test / set of tests."""
84
85     tests_executed: Dict[str, float]
86     """Tests that were executed."""
87
88     tests_succeeded: List[str]
89     """Tests that succeeded."""
90
91     tests_failed: List[str]
92     """Tests that failed."""
93
94     tests_timed_out: List[str]
95     """Tests that timed out."""
96
97     def __add__(self, other):
98         merged = dict_utils.coalesce(
99             [self.tests_executed, other.tests_executed],
100             aggregation_function=dict_utils.raise_on_duplicated_keys,
101         )
102         self.tests_executed = merged
103         self.tests_succeeded.extend(other.tests_succeeded)
104         self.tests_failed.extend(other.tests_failed)
105         self.tests_timed_out.extend(other.tests_timed_out)
106         return self
107
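    # Defining __radd__ too lets TestResults be combined regardless of
    # operand order when results from multiple runs are accumulated.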
108     __radd__ = __add__
109
110     def __repr__(self) -> str:
111         out = f'{self.name}: '
112         out += f'{ansi.fg("green")}'
113         out += f'{len(self.tests_succeeded)}/{len(self.tests_executed)} passed'
114         out += f'{ansi.reset()}.\n'
115
116         if len(self.tests_failed) > 0:
117             out += f'  ..{ansi.fg("red")}'
118             out += f'{len(self.tests_failed)} tests failed'
119             out += f'{ansi.reset()}:\n'
120             for test in self.tests_failed:
121                 out += f'    {test}\n'
122             out += '\n'
123
124         if len(self.tests_timed_out) > 0:
125             out += f'  ..{ansi.fg("yellow")}'
126             out += f'{len(self.tests_timed_out)} tests timed out'
127             out += f'{ansi.reset()}:\n'
128             for test in self.tests_timed_out:
129                 out += f'    {test}\n'
130             out += '\n'
131         return out
132
133
134 class TestRunner(ABC, thread_utils.ThreadWithReturnValue):
135     """A Base class for something that runs a test."""
136
137     def __init__(self, params: TestingParameters):
138         """Create a TestRunner.
139
140         Args:
141             params: Test running parameters.
142
143         """
144         super().__init__(self, target=self.begin, args=[params])
145         self.params = params
146         self.test_results = TestResults(
147             name=self.get_name(),
148             tests_executed={},
149             tests_succeeded=[],
150             tests_failed=[],
151             tests_timed_out=[],
152         )
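        # Protects test_results: the main thread polls get_status() while
        # this worker thread updates the results.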
153         self.lock = threading.Lock()
154
155     @abstractmethod
156     def get_name(self) -> str:
157         """The name of this test collection."""
158         pass
159
160     def get_status(self) -> TestResults:
161         """Ask the TestRunner for its status."""
162         with self.lock:
163             return self.test_results
164
165     @abstractmethod
166     def begin(self, params: TestingParameters) -> TestResults:
167         """Start execution."""
168         pass
169
170
171 class TemplatedTestRunner(TestRunner, ABC):
172     """A TestRunner that has a recipe for executing the tests."""
173
174     @abstractmethod
175     def identify_tests(self) -> List[TestToRun]:
176         """Return a list of tuples (test, cmdline) that should be executed."""
177         pass
178
179     @abstractmethod
180     def run_test(self, test: TestToRun) -> TestResults:
181         """Run a single test and return its TestResults."""
182         pass
183
184     def check_for_abort(self) -> bool:
185         """Periodically caled to check to see if we need to stop."""
186
187         if self.params.halt_event.is_set():
188             logger.debug('Thread %s saw halt event; exiting.', self.get_name())
189             return True
190
191         if self.params.halt_on_error and len(self.test_results.tests_failed) > 0:
192             logger.error('Thread %s saw abnormal results; exiting.', self.get_name())
193             return True
194         return False
195
196     def persist_output(self, test: TestToRun, message: str, output: str) -> None:
197         """Called to save the output of a test run."""
198
199         dest = f'{test.name}-output.txt'
200         with open(f'./test_output/{dest}', 'w') as wf:
201             print(message, file=wf)
202             print('-' * len(message), file=wf)
203             wf.write(output)
204
205     def execute_commandline(
206         self,
207         test: TestToRun,
208         *,
209         timeout: float = 120.0,
210     ) -> TestResults:
211         """Execute a particular commandline to run a test."""
212
213         try:
214             output = exec_utils.cmd(
215                 test.cmdline,
216                 timeout_seconds=timeout,
217             )
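            # A doctest failure does not always produce a non-zero exit
            # status, so also scan the output for doctest's failure banner.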
218             if "***Test Failed***" in output:
219                 msg = f'{self.get_name()}: {test.name} ({test.cmdline}) failed; doctest failure message detected'
220                 logger.error(msg)
221                 self.persist_output(test, msg, output)
222                 return TestResults(
223                     test.name,
224                     {},
225                     [],
226                     [test.name],
227                     [],
228                 )
229
230             self.persist_output(
231                 test, f'{test.name} ({test.cmdline}) succeeded.', output
232             )
233             logger.debug(
234                 '%s: %s (%s) succeeded', self.get_name(), test.name, test.cmdline
235             )
236             return TestResults(test.name, {}, [test.name], [], [])
237         except subprocess.TimeoutExpired as e:
238             msg = f'{self.get_name()}: {test.name} ({test.cmdline}) timed out after {e.timeout:.1f} seconds.'
239             logger.error(msg)
240             logger.debug(
241                 '%s: %s output when it timed out: %s',
242                 self.get_name(),
243                 test.name,
244                 e.output,
245             )
246             self.persist_output(test, msg, e.output.decode('utf-8'))
247             return TestResults(
248                 test.name,
249                 {},
250                 [],
251                 [],
252                 [test.name],
253             )
254         except subprocess.CalledProcessError as e:
255             msg = f'{self.get_name()}: {test.name} ({test.cmdline}) failed; exit code {e.returncode}'
256             logger.error(msg)
257             logger.debug(
258                 '%s: %s output when it failed: %s', self.get_name(), test.name, e.output
259             )
260             self.persist_output(test, msg, e.output.decode('utf-8'))
261             return TestResults(
262                 test.name,
263                 {},
264                 [],
265                 [test.name],
266                 [],
267             )
268
269     @overrides
270     def begin(self, params: TestingParameters) -> TestResults:
271         logger.debug('Thread %s started.', self.get_name())
272         interesting_tests = self.identify_tests()
273         logger.debug(
274             '%s: Identified %d tests to be run.',
275             self.get_name(),
276             len(interesting_tests),
277         )
278
279         # Note: because of @parallelize on run_test it actually
280         # returns a SmartFuture with a TestResults inside of it.
281         # That's the reason for this Any business.
282         running: List[Any] = []
283         for test_to_run in interesting_tests:
284             running.append(self.run_test(test_to_run))
285             logger.debug(
286                 '%s: Test %s started in the background.',
287                 self.get_name(),
288                 test_to_run.name,
289             )
290             self.test_results.tests_executed[test_to_run.name] = time.time()
291
292         for future in smart_future.wait_any(running, log_exceptions=False):
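            # Unwrap the SmartFuture via the public DeferredOperand.resolve()
            # helper rather than the internal _resolve method.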
293             result = deferred_operand.DeferredOperand.resolve(future)
294             logger.debug('Test %s finished.', result.name)
295
296             # We sometimes run the same test more than once.  Do not allow
297             # one run's results to clobber the other's.
298             self.test_results += result
299             if self.check_for_abort():
300                 logger.debug(
301                     '%s: check_for_abort told us to exit early.', self.get_name()
302                 )
303                 return self.test_results
304
305         logger.debug('Thread %s finished running all tests', self.get_name())
306         return self.test_results
307
308
309 class UnittestTestRunner(TemplatedTestRunner):
310     """Run all known Unittests."""
311
312     @overrides
313     def get_name(self) -> str:
314         return "Unittests"
315
316     @overrides
317     def identify_tests(self) -> List[TestToRun]:
318         ret = []
319         for test in file_utils.get_matching_files_recursive(ROOT, '*_test.py'):
320             basename = file_utils.without_path(test)
321             if basename in TESTS_TO_SKIP:
322                 continue
323             if config.config['coverage']:
324                 ret.append(
325                     TestToRun(
326                         name=basename,
327                         kind='unittest capturing coverage',
328                         cmdline=f'coverage run --source ../src {test} --unittests_ignore_perf 2>&1',
329                     )
330                 )
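                # Perf-sensitive tests get a second run without coverage so
                # their timing checks aren't skewed by instrumentation.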
331                 if basename in PERF_SENSATIVE_TESTS:
332                     ret.append(
333                         TestToRun(
334                             name=f'{basename}_no_coverage',
335                             kind='unittest w/o coverage to record perf',
336                             cmdline=f'{test} 2>&1',
337                         )
338                     )
339             else:
340                 ret.append(
341                     TestToRun(
342                         name=basename,
343                         kind='unittest',
344                         cmdline=f'{test} 2>&1',
345                     )
346                 )
347         return ret
348
349     @par.parallelize
350     def run_test(self, test: TestToRun) -> TestResults:
351         return self.execute_commandline(test)
352
353
354 class DoctestTestRunner(TemplatedTestRunner):
355     """Run all known Doctests."""
356
357     @overrides
358     def get_name(self) -> str:
359         return "Doctests"
360
361     @overrides
362     def identify_tests(self) -> List[TestToRun]:
363         ret = []
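        # Any file under ROOT that imports doctest is assumed to run its
        # own doctests when executed directly.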
364         out = exec_utils.cmd(f'grep -lR "^ *import doctest" {ROOT}/*')
365         for test in out.split('\n'):
366             if re.match(r'.*\.py$', test):
367                 basename = file_utils.without_path(test)
368                 if basename in TESTS_TO_SKIP:
369                     continue
370                 if config.config['coverage']:
371                     ret.append(
372                         TestToRun(
373                             name=basename,
374                             kind='doctest capturing coverage',
375                             cmdline=f'coverage run --source ../src {test} 2>&1',
376                         )
377                     )
378                     if basename in PERF_SENSATIVE_TESTS:
379                         ret.append(
380                             TestToRun(
381                                 name=f'{basename}_no_coverage',
382                                 kind='doctest w/o coverage to record perf',
383                                 cmdline=f'python3 {test} 2>&1',
384                             )
385                         )
386                 else:
387                     ret.append(
388                         TestToRun(
389                             name=basename,
390                             kind='doctest',
391                             cmdline=f'python3 {test} 2>&1',
392                         )
393                     )
394         return ret
395
396     @par.parallelize
397     def run_test(self, test: TestToRun) -> TestResults:
398         return self.execute_commandline(test)
399
400
401 class IntegrationTestRunner(TemplatedTestRunner):
402     """Run all know Integration tests."""
403
404     @overrides
405     def get_name(self) -> str:
406         return "Integration Tests"
407
408     @overrides
409     def identify_tests(self) -> List[TestToRun]:
410         ret = []
411         for test in file_utils.get_matching_files_recursive(ROOT, '*_itest.py'):
412             basename = file_utils.without_path(test)
413             if basename in TESTS_TO_SKIP:
414                 continue
415             if config.config['coverage']:
416                 ret.append(
417                     TestToRun(
418                         name=basename,
419                         kind='integration test capturing coverage',
420                         cmdline=f'coverage run --source ../src {test} 2>&1',
421                     )
422                 )
423                 if basename in PERF_SENSATIVE_TESTS:
424                     ret.append(
425                         TestToRun(
426                             name=f'{basename}_no_coverage',
427                             kind='integration test w/o coverage to capture perf',
428                             cmdline=f'{test} 2>&1',
429                         )
430                     )
431             else:
432                 ret.append(
433                     TestToRun(
434                         name=basename, kind='integration test', cmdline=f'{test} 2>&1'
435                     )
436                 )
437         return ret
438
439     @par.parallelize
440     def run_test(self, test: TestToRun) -> TestResults:
441         return self.execute_commandline(test)
442
443
444 def test_results_report(results: Dict[str, Optional[TestResults]]) -> int:
445     """Give a final report about the tests that were run."""
446     total_problems = 0
447     for result in results.values():
448         if result is None:
449             print('Unexpected unhandled exception in test runner!!!')
450             total_problems += 1
451         else:
452             print(result, end='')
453             total_problems += len(result.tests_failed)
454             total_problems += len(result.tests_timed_out)
455
456     if total_problems > 0:
457         print('Reminder: look in ./test_output to view test output logs')
458     return total_problems
459
460
461 def code_coverage_report():
462     """Give a final code coverage report."""
463     text_utils.header('Code Coverage')
464     exec_utils.cmd('coverage combine .coverage*')
465     out = exec_utils.cmd(
466         'coverage report --omit=config-3.*.py,*_test.py,*_itest.py --sort=-cover'
467     )
468     print(out)
469     print(
470         """To recall this report w/o re-running the tests:
471
472     $ coverage report --omit=config-3.*.py,*_test.py,*_itest.py --sort=-cover
473
474 ...from the 'tests' directory.  Note that subsequent calls to
475 run_tests.py with --coverage will clobber previous results.  See:
476
477     https://coverage.readthedocs.io/en/6.2/
478 """
479     )
480
481
482 @bootstrap.initialize
483 def main() -> Optional[int]:
484     saw_flag = False
485     halt_event = threading.Event()
486     threads: List[TestRunner] = []
487
488     halt_event.clear()
489     params = TestingParameters(
490         halt_on_error=True,
491         halt_event=halt_event,
492     )
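    # With halt_on_error set, a failure in any runner causes main() to set
    # halt_event below, telling the other runners to stop early.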
493
494     if config.config['coverage']:
495         logger.debug('Clearing existing coverage data via "coverage erase".')
496         exec_utils.cmd('coverage erase')
497
498     if config.config['unittests'] or config.config['all']:
499         saw_flag = True
500         threads.append(UnittestTestRunner(params))
501     if config.config['doctests'] or config.config['all']:
502         saw_flag = True
503         threads.append(DoctestTestRunner(params))
504     if config.config['integration'] or config.config['all']:
505         saw_flag = True
506         threads.append(IntegrationTestRunner(params))
507
508     if not saw_flag:
509         config.print_usage()
510         print('ERROR: one of --unittests, --doctests or --integration is required.')
511         return 1
512
513     for thread in threads:
514         thread.start()
515
516     results: Dict[str, Optional[TestResults]] = {}
517     start_time = time.time()
518     last_update = start_time
519     still_running = {}
520
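    # Poll the runner threads every 0.5s: update the progress bar, report
    # long-running tests, and harvest results from threads as they finish.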
521     while len(results) != len(threads):
522         started = 0
523         done = 0
524         failed = 0
525
526         for thread in threads:
527             tid = thread.name
528             tr = thread.get_status()
529             started += len(tr.tests_executed)
530             failed += len(tr.tests_failed) + len(tr.tests_timed_out)
531             done += failed + len(tr.tests_succeeded)
532             running = set(tr.tests_executed.keys())
533             running -= set(tr.tests_failed)
534             running -= set(tr.tests_succeeded)
535             running -= set(tr.tests_timed_out)
536             running_with_start_time = {
537                 test: tr.tests_executed[test] for test in running
538             }
539             still_running[tid] = running_with_start_time
540
541             now = time.time()
542             if now - start_time > 5.0:
543                 if now - last_update > 3.0:
544                     last_update = now
545                     update = []
546                     for _, running_dict in still_running.items():
547                         for test_name, test_start_time in running_dict.items():
548                             if now - test_start_time > 10.0:
549                                 update.append(f'{test_name}@{now-test_start_time:.1f}s')
550                             else:
551                                 update.append(test_name)
552                     print(f'\r{ansi.clear_line()}')
553                     if len(update) < 4:
554                         print(f'Still running: {",".join(update)}')
555                     else:
556                         print(f'Still running: {len(update)} tests.')
557
558             if not thread.is_alive():
559                 if tid not in results:
560                     result = thread.join()
561                     if result:
562                         results[tid] = result
563                         if len(result.tests_failed) > 0:
564                             logger.error(
565                                 'Thread %s returned abnormal results; killing the others.',
566                                 tid,
567                             )
568                             halt_event.set()
569                     else:
570                         logger.error(
571                             'Thread %s hit an unhandled exception... bug in run_tests.py?!  Aborting.',
572                             tid,
573                         )
574                         halt_event.set()
575                         results[tid] = None
576
577         if failed == 0:
578             color = ansi.fg('green')
579         else:
580             color = ansi.fg('red')
581
582         if started > 0:
583             percent_done = done / started * 100.0
584         else:
585             percent_done = 0.0
586
587         if percent_done < 100.0:
588             print(
589                 text_utils.bar_graph_string(
590                     done,
591                     started,
592                     text=text_utils.BarGraphText.FRACTION,
593                     width=80,
594                     fgcolor=color,
595                 ),
596                 end='\r',
597                 flush=True,
598             )
599         time.sleep(0.5)
600
601     print(f'{ansi.clear_line()}Final Report:')
602     if config.config['coverage']:
603         code_coverage_report()
604     total_problems = test_results_report(results)
605     if total_problems > 0:
606         logger.error(
607             'Exiting with non-zero return code %d due to problems.', total_problems
608         )
609     return total_problems
610
611
612 if __name__ == '__main__':
613     main()