Do not let test names in results collide when we run one to capture
[pyutils.git] / tests / run_tests.py
#!/usr/bin/env python3

"""
A smart, fast test runner.  Used in a git pre-commit hook.
"""

import logging
import os
import re
import subprocess
import threading
import time
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Dict, List, Optional

from overrides import overrides

from pyutils import ansi, bootstrap, config, dict_utils, exec_utils, text_utils
from pyutils.files import file_utils
from pyutils.parallelize import parallelize as par
from pyutils.parallelize import smart_future, thread_utils

logger = logging.getLogger(__name__)
args = config.add_commandline_args(
    f'Run Tests Driver ({__file__})', f'Args related to {__file__}'
)
args.add_argument('--unittests', '-u', action='store_true', help='Run unittests.')
args.add_argument('--doctests', '-d', action='store_true', help='Run doctests.')
args.add_argument(
    '--integration', '-i', action='store_true', help='Run integration tests.'
)
args.add_argument(
    '--all',
    '-a',
    action='store_true',
    help='Run unittests, doctests and integration tests.  Equivalent to -u -d -i.',
)
args.add_argument(
    '--coverage',
    '-c',
    action='store_true',
    help='Run tests and capture code coverage data.',
)

HOME = os.environ['HOME']

# These tests will be run twice in --coverage mode: once to capture code
# coverage and then again without coverage enabled.  This is because
# they pay attention to code performance, which is adversely affected
# by coverage instrumentation.
PERF_SENSATIVE_TESTS = set(['string_utils_test.py'])
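# Tests in this set are skipped by every test runner below.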
TESTS_TO_SKIP = set(['zookeeper_test.py', 'zookeeper.py', 'run_tests.py'])

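# Tests are discovered by searching recursively from the parent directory
# (this script lives in tests/; ".." is the repository root).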
ROOT = ".."


@dataclass
class TestingParameters:
    halt_on_error: bool
    """Should we stop as soon as one error has occurred?"""

    halt_event: threading.Event
    """An event that, when set, indicates to stop ASAP."""


@dataclass
class TestToRun:
    name: str
    """The name of the test"""

    kind: str
    """The kind of the test"""

    cmdline: str
    """The command line to execute"""


@dataclass
class TestResults:
    name: str
    """The name of this test / set of tests."""

    tests_executed: Dict[str, float]
    """Tests that were executed, mapped to the time each one started."""

    tests_succeeded: List[str]
    """Tests that succeeded."""

    tests_failed: List[str]
    """Tests that failed."""

    tests_timed_out: List[str]
    """Tests that timed out."""

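    # Results from separate runs can be merged with "+".  Duplicate test
    # names raise (via raise_on_duplicated_keys), so repeated runs of the
    # same test must use distinct names (e.g. the "_no_coverage" variants).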
    def __add__(self, other):
        merged = dict_utils.coalesce(
            [self.tests_executed, other.tests_executed],
            aggregation_function=dict_utils.raise_on_duplicated_keys,
        )
        self.tests_executed = merged
        self.tests_succeeded.extend(other.tests_succeeded)
        self.tests_failed.extend(other.tests_failed)
        self.tests_timed_out.extend(other.tests_timed_out)
        return self

    __radd__ = __add__

    def __repr__(self) -> str:
        out = f'{self.name}: '
        out += f'{ansi.fg("green")}'
        out += f'{len(self.tests_succeeded)}/{len(self.tests_executed)} passed'
        out += f'{ansi.reset()}.\n'

        if len(self.tests_failed) > 0:
            out += f'  ..{ansi.fg("red")}'
            out += f'{len(self.tests_failed)} tests failed'
            out += f'{ansi.reset()}:\n'
            for test in self.tests_failed:
                out += f'    {test}\n'
            out += '\n'

        if len(self.tests_timed_out) > 0:
            out += f'  ..{ansi.fg("yellow")}'
            out += f'{len(self.tests_timed_out)} tests timed out'
            out += f'{ansi.reset()}:\n'
            for test in self.tests_timed_out:
                out += f'    {test}\n'
            out += '\n'
        return out


class TestRunner(ABC, thread_utils.ThreadWithReturnValue):
    """A base class for something that runs a test."""

    def __init__(self, params: TestingParameters):
        """Create a TestRunner.

        Args:
            params: Test running parameters.

        """
        super().__init__(target=self.begin, args=[params])
        self.params = params
        self.test_results = TestResults(
            name=self.get_name(),
            tests_executed={},
            tests_succeeded=[],
            tests_failed=[],
            tests_timed_out=[],
        )
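        # Taken by get_status() when the main thread polls this runner for
        # progress while tests are still running.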
        self.lock = threading.Lock()

    @abstractmethod
    def get_name(self) -> str:
        """The name of this test collection."""
        pass

    def get_status(self) -> TestResults:
        """Ask the TestRunner for its status."""
        with self.lock:
            return self.test_results

    @abstractmethod
    def begin(self, params: TestingParameters) -> TestResults:
        """Start execution."""
        pass


class TemplatedTestRunner(TestRunner, ABC):
    """A TestRunner that has a recipe for executing the tests."""

    @abstractmethod
    def identify_tests(self) -> List[TestToRun]:
        """Return a list of TestToRun objects that should be executed."""
        pass

    @abstractmethod
    def run_test(self, test: TestToRun) -> TestResults:
        """Run a single test and return its TestResults."""
        pass

    def check_for_abort(self) -> bool:
        """Periodically called to check whether we need to stop."""

        if self.params.halt_event.is_set():
            logger.debug('Thread %s saw halt event; exiting.', self.get_name())
            return True

        if self.params.halt_on_error and len(self.test_results.tests_failed) > 0:
            logger.error('Thread %s saw abnormal results; exiting.', self.get_name())
            return True
        return False

    def persist_output(self, test: TestToRun, message: str, output: str) -> None:
        """Save the output of a test run to a file under ./test_output/."""

        dest = f'{test.name}-output.txt'
        with open(f'./test_output/{dest}', 'w') as wf:
            print(message, file=wf)
            print('-' * len(message), file=wf)
            wf.write(output)

    def execute_commandline(
        self,
        test: TestToRun,
        *,
        timeout: float = 120.0,
    ) -> TestResults:
        """Execute a particular commandline to run a test."""

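        # Three possible outcomes: the command succeeds, it fails (non-zero
        # exit or a doctest failure message in its output), or it times out.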
        try:
            output = exec_utils.cmd(
                test.cmdline,
                timeout_seconds=timeout,
            )
            if "***Test Failed***" in output:
                msg = f'{self.get_name()}: {test.name} ({test.cmdline}) failed; doctest failure message detected'
                logger.error(msg)
                self.persist_output(test, msg, output)
                return TestResults(
                    test.name,
                    {},
                    [],
                    [test.name],
                    [],
                )

            self.persist_output(
                test, f'{test.name} ({test.cmdline}) succeeded.', output
            )
            logger.debug(
                '%s: %s (%s) succeeded', self.get_name(), test.name, test.cmdline
            )
            return TestResults(test.name, {}, [test.name], [], [])
        except subprocess.TimeoutExpired as e:
            msg = f'{self.get_name()}: {test.name} ({test.cmdline}) timed out after {e.timeout:.1f} seconds.'
            logger.error(msg)
            logger.debug(
                '%s: %s output when it timed out: %s',
                self.get_name(),
                test.name,
                e.output,
            )
            self.persist_output(test, msg, e.output.decode('utf-8'))
            return TestResults(
                test.name,
                {},
                [],
                [],
                [test.name],
            )
        except subprocess.CalledProcessError as e:
            msg = f'{self.get_name()}: {test.name} ({test.cmdline}) failed; exit code {e.returncode}'
            logger.error(msg)
            logger.debug(
                '%s: %s output when it failed: %s', self.get_name(), test.name, e.output
            )
            self.persist_output(test, msg, e.output.decode('utf-8'))
            return TestResults(
                test.name,
                {},
                [],
                [test.name],
                [],
            )

    @overrides
    def begin(self, params: TestingParameters) -> TestResults:
        logger.debug('Thread %s started.', self.get_name())
        interesting_tests = self.identify_tests()
        logger.debug(
            '%s: Identified %d tests to be run.',
            self.get_name(),
            len(interesting_tests),
        )

        # Note: because of @parallelize on run_test, it actually
        # returns a SmartFuture with a TestResults inside of it.
        # That's the reason for this Any business.
        running: List[Any] = []
        for test_to_run in interesting_tests:
            running.append(self.run_test(test_to_run))
            logger.debug(
                '%s: Test %s started in the background.',
                self.get_name(),
                test_to_run.name,
            )
            self.test_results.tests_executed[test_to_run.name] = time.time()

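        # Wait for the background tests to finish (in completion order) and
        # fold each result into this runner's aggregate TestResults.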
        for future in smart_future.wait_any(running, log_exceptions=False):
            result = future._resolve()
            logger.debug('Test %s finished.', result.name)

            # We sometimes run the same test more than once.  Do not allow
            # one run's results to clobber the other's.
            self.test_results += result
            if self.check_for_abort():
                logger.debug(
                    '%s: check_for_abort told us to exit early.', self.get_name()
                )
                return self.test_results

        logger.debug('Thread %s finished running all tests', self.get_name())
        return self.test_results


class UnittestTestRunner(TemplatedTestRunner):
    """Run all known Unittests."""

    @overrides
    def get_name(self) -> str:
        return "Unittests"

    @overrides
    def identify_tests(self) -> List[TestToRun]:
        ret = []
        for test in file_utils.get_matching_files_recursive(ROOT, '*_test.py'):
            basename = file_utils.without_path(test)
            if basename in TESTS_TO_SKIP:
                continue
            if config.config['coverage']:
                ret.append(
                    TestToRun(
                        name=basename,
                        kind='unittest capturing coverage',
                        cmdline=f'coverage run --source ../src {test} --unittests_ignore_perf 2>&1',
                    )
                )
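                # Perf-sensitive tests get a second, uninstrumented run so
                # their timing checks are not skewed by coverage (see the
                # PERF_SENSATIVE_TESTS comment above).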
                if basename in PERF_SENSATIVE_TESTS:
                    ret.append(
                        TestToRun(
                            name=f'{basename}_no_coverage',
                            kind='unittest w/o coverage to record perf',
                            cmdline=f'{test} 2>&1',
                        )
                    )
            else:
                ret.append(
                    TestToRun(
                        name=basename,
                        kind='unittest',
                        cmdline=f'{test} 2>&1',
                    )
                )
        return ret

    @par.parallelize
    def run_test(self, test: TestToRun) -> TestResults:
        return self.execute_commandline(test)


class DoctestTestRunner(TemplatedTestRunner):
    """Run all known Doctests."""

    @overrides
    def get_name(self) -> str:
        return "Doctests"

    @overrides
    def identify_tests(self) -> List[TestToRun]:
        ret = []
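        # Any Python file under ROOT that imports the doctest module is
        # assumed to contain doctests.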
        out = exec_utils.cmd(f'grep -lR "^ *import doctest" {ROOT}/*')
        for test in out.split('\n'):
            if re.match(r'.*\.py$', test):
                basename = file_utils.without_path(test)
                if basename in TESTS_TO_SKIP:
                    continue
                if config.config['coverage']:
                    ret.append(
                        TestToRun(
                            name=basename,
                            kind='doctest capturing coverage',
                            cmdline=f'coverage run --source ../src {test} 2>&1',
                        )
                    )
                    if basename in PERF_SENSATIVE_TESTS:
                        ret.append(
                            TestToRun(
                                name=f'{basename}_no_coverage',
                                kind='doctest w/o coverage to record perf',
                                cmdline=f'python3 {test} 2>&1',
                            )
                        )
                else:
                    ret.append(
                        TestToRun(
                            name=basename,
                            kind='doctest',
                            cmdline=f'python3 {test} 2>&1',
                        )
                    )
        return ret

    @par.parallelize
    def run_test(self, test: TestToRun) -> TestResults:
        return self.execute_commandline(test)


class IntegrationTestRunner(TemplatedTestRunner):
    """Run all known integration tests."""

    @overrides
    def get_name(self) -> str:
        return "Integration Tests"

    @overrides
    def identify_tests(self) -> List[TestToRun]:
        ret = []
        for test in file_utils.get_matching_files_recursive(ROOT, '*_itest.py'):
            basename = file_utils.without_path(test)
            if basename in TESTS_TO_SKIP:
                continue
            if config.config['coverage']:
                ret.append(
                    TestToRun(
                        name=basename,
                        kind='integration test capturing coverage',
                        cmdline=f'coverage run --source ../src {test} 2>&1',
                    )
                )
                if basename in PERF_SENSATIVE_TESTS:
                    ret.append(
                        TestToRun(
                            name=f'{basename}_no_coverage',
                            kind='integration test w/o coverage to capture perf',
                            cmdline=f'{test} 2>&1',
                        )
                    )
            else:
                ret.append(
                    TestToRun(
                        name=basename, kind='integration test', cmdline=f'{test} 2>&1'
                    )
                )
        return ret

    @par.parallelize
    def run_test(self, test: TestToRun) -> TestResults:
        return self.execute_commandline(test)


def test_results_report(results: Dict[str, Optional[TestResults]]) -> int:
    """Give a final report about the tests that were run."""
    total_problems = 0
    for result in results.values():
        if result is None:
            print('Unexpected unhandled exception in test runner!!!')
            total_problems += 1
        else:
            print(result, end='')
            total_problems += len(result.tests_failed)
            total_problems += len(result.tests_timed_out)

    if total_problems > 0:
        print('Reminder: look in ./test_output to view test output logs')
    return total_problems


def code_coverage_report():
    """Give a final code coverage report."""
    text_utils.header('Code Coverage')
    exec_utils.cmd('coverage combine .coverage*')
    out = exec_utils.cmd(
        'coverage report --omit=config-3.*.py,*_test.py,*_itest.py --sort=-cover'
    )
    print(out)
    print(
        """To recall this report w/o re-running the tests:

    $ coverage report --omit=config-3.*.py,*_test.py,*_itest.py --sort=-cover

...from the 'tests' directory.  Note that subsequent calls to
run_tests.py with --coverage will clobber previous results.  See:

    https://coverage.readthedocs.io/en/6.2/
"""
    )


@bootstrap.initialize
def main() -> Optional[int]:
    saw_flag = False
    halt_event = threading.Event()
    threads: List[TestRunner] = []

    halt_event.clear()
    params = TestingParameters(
        halt_on_error=True,
        halt_event=halt_event,
    )

    if config.config['coverage']:
        logger.debug('Clearing existing coverage data via "coverage erase".')
        exec_utils.cmd('coverage erase')

    if config.config['unittests'] or config.config['all']:
        saw_flag = True
        threads.append(UnittestTestRunner(params))
    if config.config['doctests'] or config.config['all']:
        saw_flag = True
        threads.append(DoctestTestRunner(params))
    if config.config['integration'] or config.config['all']:
        saw_flag = True
        threads.append(IntegrationTestRunner(params))

    if not saw_flag:
        config.print_usage()
        print('ERROR: one of --unittests, --doctests or --integration is required.')
        return 1

    for thread in threads:
        thread.start()

    results: Dict[str, Optional[TestResults]] = {}
    start_time = time.time()
    last_update = start_time
    still_running = {}

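    # Poll the runner threads: draw a progress bar, periodically report which
    # tests are still running, and collect each thread's results as it
    # finishes.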
    while len(results) != len(threads):
        started = 0
        done = 0
        failed = 0

        for thread in threads:
            tid = thread.name
            tr = thread.get_status()
            started += len(tr.tests_executed)
            failed += len(tr.tests_failed) + len(tr.tests_timed_out)
            done += len(tr.tests_failed) + len(tr.tests_timed_out) + len(tr.tests_succeeded)
            running = set(tr.tests_executed.keys())
            running -= set(tr.tests_failed)
            running -= set(tr.tests_succeeded)
            running -= set(tr.tests_timed_out)
            running_with_start_time = {
                test: tr.tests_executed[test] for test in running
            }
            still_running[tid] = running_with_start_time

            now = time.time()
            if now - start_time > 5.0:
                if now - last_update > 3.0:
                    last_update = now
                    update = []
                    for _, running_dict in still_running.items():
                        for test_name, test_start_time in running_dict.items():
                            if now - test_start_time > 10.0:
                                update.append(f'{test_name}@{now-test_start_time:.1f}s')
                            else:
                                update.append(test_name)
                    print(f'\r{ansi.clear_line()}')
                    if len(update) < 4:
                        print(f'Still running: {",".join(update)}')
                    else:
                        print(f'Still running: {len(update)} tests.')

            if not thread.is_alive():
                if tid not in results:
                    result = thread.join()
                    if result:
                        results[tid] = result
                        if len(result.tests_failed) > 0:
                            logger.error(
                                'Thread %s returned abnormal results; killing the others.',
                                tid,
                            )
                            halt_event.set()
                    else:
                        logger.error(
                            'Thread %s took an unhandled exception... bug in run_tests.py?!  Aborting.',
                            tid,
                        )
                        halt_event.set()
                        results[tid] = None

        if failed == 0:
            color = ansi.fg('green')
        else:
            color = ansi.fg('red')

        if started > 0:
            percent_done = done / started * 100.0
        else:
            percent_done = 0.0

        if percent_done < 100.0:
            print(
                text_utils.bar_graph_string(
                    done,
                    started,
                    text=text_utils.BarGraphText.FRACTION,
                    width=80,
                    fgcolor=color,
                ),
                end='\r',
                flush=True,
            )
        time.sleep(0.5)

    print(f'{ansi.clear_line()}Final Report:')
    if config.config['coverage']:
        code_coverage_report()
    total_problems = test_results_report(results)
    return total_problems


if __name__ == '__main__':
    main()