#!/usr/bin/env python3

"""
A smart, fast test runner.  Used in a git pre-commit hook.
"""

import logging
import os
import re
import subprocess
import threading
import time
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Dict, List, Optional

from overrides import overrides

from pyutils import ansi, bootstrap, config, dict_utils, exec_utils, text_utils
from pyutils.files import file_utils
from pyutils.parallelize import parallelize as par
from pyutils.parallelize import smart_future, thread_utils

logger = logging.getLogger(__name__)
args = config.add_commandline_args(
    f'Run Tests Driver ({__file__})', f'Args related to {__file__}'
)
args.add_argument('--unittests', '-u', action='store_true', help='Run unittests.')
args.add_argument('--doctests', '-d', action='store_true', help='Run doctests.')
args.add_argument(
    '--integration', '-i', action='store_true', help='Run integration tests.'
)
args.add_argument(
    '--all',
    '-a',
    action='store_true',
    help='Run unittests, doctests and integration tests.  Equivalent to -u -d -i.',
)
args.add_argument(
    '--coverage',
    '-c',
    action='store_true',
    help='Run tests and capture code coverage data.',
)

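# Example invocations (a sketch; assumes this script is run from the
# tests/ directory, which the relative paths below expect):
#
#   ./run_tests.py --all --coverage    # unittests, doctests, integration w/ coverage
#   ./run_tests.py -u -d               # just unittests and doctests
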
HOME = os.environ['HOME']

# These tests will be run twice in --coverage mode: once to gather code
# coverage data and then again with coverage disabled.  This is because
# they pay attention to code performance, which is adversely affected
# by coverage instrumentation.
PERF_SENSATIVE_TESTS = set(['string_utils_test.py'])
TESTS_TO_SKIP = set(['zookeeper_test.py', 'zookeeper.py', 'run_tests.py'])

ROOT = ".."


@dataclass
class TestingParameters:
    halt_on_error: bool
    """Should we stop as soon as one error has occurred?"""

    halt_event: threading.Event
    """An event that, when set, indicates to stop ASAP."""


@dataclass
class TestToRun:
    name: str
    """The name of the test"""

    kind: str
    """The kind of the test"""

    cmdline: str
    """The command line to execute"""


@dataclass
class TestResults:
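    """Accumulated test results for one test runner.  Instances are merged
    with the + operator (see __add__ below) as parallel test results arrive."""
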
    name: str
    """The name of this test / set of tests."""

    tests_executed: Dict[str, float]
    """Tests that were executed."""

    tests_succeeded: List[str]
    """Tests that succeeded."""

    tests_failed: List[str]
    """Tests that failed."""

    tests_timed_out: List[str]
    """Tests that timed out."""

    def __add__(self, other):
        merged = dict_utils.coalesce(
            [self.tests_executed, other.tests_executed],
            aggregation_function=dict_utils.raise_on_duplicated_keys,
        )
        self.tests_executed = merged
        self.tests_succeeded.extend(other.tests_succeeded)
        self.tests_failed.extend(other.tests_failed)
        self.tests_timed_out.extend(other.tests_timed_out)
        return self

    __radd__ = __add__

    def __repr__(self) -> str:
        out = f'{self.name}: '
        out += f'{ansi.fg("green")}'
        out += f'{len(self.tests_succeeded)}/{len(self.tests_executed)} passed'
        out += f'{ansi.reset()}.\n'

        if len(self.tests_failed) > 0:
            out += f'  ..{ansi.fg("red")}'
            out += f'{len(self.tests_failed)} tests failed'
            out += f'{ansi.reset()}:\n'
            for test in self.tests_failed:
                out += f'    {test}\n'
            out += '\n'

        if len(self.tests_timed_out) > 0:
            out += f'  ..{ansi.fg("yellow")}'
            out += f'{len(self.tests_timed_out)} tests timed out'
            out += f'{ansi.reset()}:\n'
            for test in self.tests_timed_out:
                out += f'    {test}\n'
            out += '\n'
        return out


class TestRunner(ABC, thread_utils.ThreadWithReturnValue):
    """A base class for something that runs a test."""

    def __init__(self, params: TestingParameters):
        """Create a TestRunner.

        Args:
            params: Test running parameters.
        """
        super().__init__(target=self.begin, args=[params])
        self.params = params
        self.test_results = TestResults(
            name=self.get_name(),
            tests_executed={},
            tests_succeeded=[],
            tests_failed=[],
            tests_timed_out=[],
        )
        self.lock = threading.Lock()

    @abstractmethod
    def get_name(self) -> str:
        """The name of this test collection."""
        pass

    def get_status(self) -> TestResults:
        """Ask the TestRunner for its status."""
        with self.lock:
            return self.test_results

    @abstractmethod
    def begin(self, params: TestingParameters) -> TestResults:
        """Start execution."""
        pass


class TemplatedTestRunner(TestRunner, ABC):
    """A TestRunner that has a recipe for executing the tests."""

    @abstractmethod
    def identify_tests(self) -> List[TestToRun]:
        """Return the list of TestToRun objects that should be executed."""
        pass

    @abstractmethod
    def run_test(self, test: TestToRun) -> TestResults:
        """Run a single test and return its TestResults."""
        pass

    def check_for_abort(self) -> bool:
        """Periodically called to check whether we need to stop."""

        if self.params.halt_event.is_set():
            logger.debug('Thread %s saw halt event; exiting.', self.get_name())
            return True

        if self.params.halt_on_error and len(self.test_results.tests_failed) > 0:
            logger.error('Thread %s saw abnormal results; exiting.', self.get_name())
            return True
        return False

    def persist_output(self, test: TestToRun, message: str, output: str) -> None:
        """Called to save the output of a test run."""

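        # Note: assumes the ./test_output/ directory already exists relative
        # to the current working directory (expected to be the tests/ dir).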
        dest = f'{test.name}-output.txt'
        with open(f'./test_output/{dest}', 'w') as wf:
            print(message, file=wf)
            print('-' * len(message), file=wf)
            wf.write(output)

    def execute_commandline(
        self,
        test: TestToRun,
        *,
        timeout: float = 120.0,
    ) -> TestResults:
        """Execute a particular commandline to run a test."""

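        # Three outcomes are handled below: the command succeeds, it fails
        # (a doctest "***Test Failed***" marker in the output or a nonzero
        # exit code), or it times out.  Each outcome is recorded in the
        # appropriate TestResults bucket and the output persisted to disk.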
        try:
            output = exec_utils.cmd(
                test.cmdline,
                timeout_seconds=timeout,
            )
            if "***Test Failed***" in output:
                msg = f'{self.get_name()}: {test.name} ({test.cmdline}) failed; doctest failure message detected'
                logger.error(msg)
                self.persist_output(test, msg, output)
                return TestResults(
                    test.name,
                    {},
                    [],
                    [test.name],
                    [],
                )

            self.persist_output(
                test, f'{test.name} ({test.cmdline}) succeeded.', output
            )
            logger.debug(
                '%s: %s (%s) succeeded', self.get_name(), test.name, test.cmdline
            )
            return TestResults(test.name, {}, [test.name], [], [])
        except subprocess.TimeoutExpired as e:
            msg = f'{self.get_name()}: {test.name} ({test.cmdline}) timed out after {e.timeout:.1f} seconds.'
            logger.error(msg)
            logger.debug(
                '%s: %s output when it timed out: %s',
                self.get_name(),
                test.name,
                e.output,
            )
            self.persist_output(test, msg, e.output.decode('utf-8'))
            return TestResults(
                test.name,
                {},
                [],
                [],
                [test.name],
            )
        except subprocess.CalledProcessError as e:
            msg = f'{self.get_name()}: {test.name} ({test.cmdline}) failed; exit code {e.returncode}'
            logger.error(msg)
            logger.debug(
                '%s: %s output when it failed: %s', self.get_name(), test.name, e.output
            )
            self.persist_output(test, msg, e.output.decode('utf-8'))
            return TestResults(
                test.name,
                {},
                [],
                [test.name],
                [],
            )

    @overrides
    def begin(self, params: TestingParameters) -> TestResults:
        logger.debug('Thread %s started.', self.get_name())
        interesting_tests = self.identify_tests()
        logger.debug(
            '%s: Identified %d tests to be run.',
            self.get_name(),
            len(interesting_tests),
        )

        # Note: because of @parallelize on run_test it actually returns a
        # SmartFuture with a TestResults inside of it.  That's the reason
        # for this Any business.
        running: List[Any] = []
        for test_to_run in interesting_tests:
            running.append(self.run_test(test_to_run))
            logger.debug(
                '%s: Test %s started in the background.',
                self.get_name(),
                test_to_run.name,
            )
            self.test_results.tests_executed[test_to_run.name] = time.time()

        for future in smart_future.wait_any(running, log_exceptions=False):
            result = future._resolve()
            logger.debug('Test %s finished.', result.name)
            self.test_results += result
            if self.check_for_abort():
                logger.debug(
                    '%s: check_for_abort told us to exit early.', self.get_name()
                )
                return self.test_results

        logger.debug('Thread %s finished running all tests', self.get_name())
        return self.test_results


class UnittestTestRunner(TemplatedTestRunner):
    """Run all known Unittests."""

    @overrides
    def get_name(self) -> str:
        return "Unittests"

    @overrides
    def identify_tests(self) -> List[TestToRun]:
        ret = []
        for test in file_utils.get_matching_files_recursive(ROOT, '*_test.py'):
            basename = file_utils.without_path(test)
            if basename in TESTS_TO_SKIP:
                continue
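            # Under --coverage, wrap the test in "coverage run".  The
            # --unittests_ignore_perf flag presumably tells the test to skip
            # timing-sensitive assertions, since coverage instrumentation
            # skews performance (see PERF_SENSATIVE_TESTS above).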
            if config.config['coverage']:
                ret.append(
                    TestToRun(
                        name=basename,
                        kind='unittest capturing coverage',
                        cmdline=f'coverage run --source ../src {test} --unittests_ignore_perf 2>&1',
                    )
                )
                if basename in PERF_SENSATIVE_TESTS:
                    ret.append(
                        TestToRun(
                            name=basename,
                            kind='unittest w/o coverage to record perf',
                            cmdline=f'{test} 2>&1',
                        )
                    )
            else:
                ret.append(
                    TestToRun(
                        name=basename,
                        kind='unittest',
                        cmdline=f'{test} 2>&1',
                    )
                )
        return ret

    @par.parallelize
    def run_test(self, test: TestToRun) -> TestResults:
        return self.execute_commandline(test)


class DoctestTestRunner(TemplatedTestRunner):
    """Run all known Doctests."""

    @overrides
    def get_name(self) -> str:
        return "Doctests"

    @overrides
    def identify_tests(self) -> List[TestToRun]:
        ret = []
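        # Any python file under ROOT that imports the doctest module is
        # treated as a doctest candidate (unless it is in TESTS_TO_SKIP).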
        out = exec_utils.cmd(f'grep -lR "^ *import doctest" {ROOT}/*')
        for test in out.split('\n'):
            if re.match(r'.*\.py$', test):
                basename = file_utils.without_path(test)
                if basename in TESTS_TO_SKIP:
                    continue
                if config.config['coverage']:
                    ret.append(
                        TestToRun(
                            name=basename,
                            kind='doctest capturing coverage',
                            cmdline=f'coverage run --source ../src {test} 2>&1',
                        )
                    )
                    if basename in PERF_SENSATIVE_TESTS:
                        ret.append(
                            TestToRun(
                                name=basename,
                                kind='doctest w/o coverage to record perf',
                                cmdline=f'python3 {test} 2>&1',
                            )
                        )
                else:
                    ret.append(
                        TestToRun(
                            name=basename,
                            kind='doctest',
                            cmdline=f'python3 {test} 2>&1',
                        )
                    )
        return ret

    @par.parallelize
    def run_test(self, test: TestToRun) -> TestResults:
        return self.execute_commandline(test)


class IntegrationTestRunner(TemplatedTestRunner):
    """Run all known Integration tests."""

    @overrides
    def get_name(self) -> str:
        return "Integration Tests"

    @overrides
    def identify_tests(self) -> List[TestToRun]:
        ret = []
        for test in file_utils.get_matching_files_recursive(ROOT, '*_itest.py'):
            basename = file_utils.without_path(test)
            if basename in TESTS_TO_SKIP:
                continue
            if config.config['coverage']:
                ret.append(
                    TestToRun(
                        name=basename,
                        kind='integration test capturing coverage',
                        cmdline=f'coverage run --source ../src {test} 2>&1',
                    )
                )
                if basename in PERF_SENSATIVE_TESTS:
                    ret.append(
                        TestToRun(
                            name=basename,
                            kind='integration test w/o coverage to capture perf',
                            cmdline=f'{test} 2>&1',
                        )
                    )
            else:
                ret.append(
                    TestToRun(
                        name=basename, kind='integration test', cmdline=f'{test} 2>&1'
                    )
                )
        return ret

    @par.parallelize
    def run_test(self, test: TestToRun) -> TestResults:
        return self.execute_commandline(test)


def test_results_report(results: Dict[str, Optional[TestResults]]) -> int:
    """Give a final report about the tests that were run."""
    total_problems = 0
    for result in results.values():
        if result is None:
            print('Unexpected unhandled exception in test runner!!!')
            total_problems += 1
        else:
            print(result, end='')
            total_problems += len(result.tests_failed)
            total_problems += len(result.tests_timed_out)

    if total_problems > 0:
        print('Reminder: look in ./test_output to view test output logs')
    return total_problems


def code_coverage_report():
    """Give a final code coverage report."""
    text_utils.header('Code Coverage')
    exec_utils.cmd('coverage combine .coverage*')
    out = exec_utils.cmd(
        'coverage report --omit=config-3.*.py,*_test.py,*_itest.py --sort=-cover'
    )
    print(out)
    print(
        """To recall this report w/o re-running the tests:

    $ coverage report --omit=config-3.*.py,*_test.py,*_itest.py --sort=-cover

...from the 'tests' directory.  Note that subsequent calls to
run_tests.py with --coverage will clobber previous results.  See:

    https://coverage.readthedocs.io/en/6.2/
"""
    )


@bootstrap.initialize
def main() -> Optional[int]:
    saw_flag = False
    halt_event = threading.Event()
    threads: List[TestRunner] = []

    halt_event.clear()
    params = TestingParameters(
        halt_on_error=True,
        halt_event=halt_event,
    )

    if config.config['coverage']:
        logger.debug('Clearing existing coverage data via "coverage erase".')
        exec_utils.cmd('coverage erase')

    if config.config['unittests'] or config.config['all']:
        saw_flag = True
        threads.append(UnittestTestRunner(params))
    if config.config['doctests'] or config.config['all']:
        saw_flag = True
        threads.append(DoctestTestRunner(params))
    if config.config['integration'] or config.config['all']:
        saw_flag = True
        threads.append(IntegrationTestRunner(params))

    if not saw_flag:
        config.print_usage()
        print('ERROR: one of --unittests, --doctests or --integration is required.')
        return 1

    for thread in threads:
        thread.start()

    results: Dict[str, Optional[TestResults]] = {}
    start_time = time.time()
    last_update = start_time
    still_running = {}

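    # Poll the runner threads every 0.5s: draw a progress bar, periodically
    # report long-running tests, and collect each thread's TestResults as it
    # finishes (the halt_event stops everything early on the first failure).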
    while len(results) != len(threads):
        started = 0
        done = 0
        failed = 0

        for thread in threads:
            tid = thread.name
            tr = thread.get_status()
            started += len(tr.tests_executed)
            failed += len(tr.tests_failed) + len(tr.tests_timed_out)
            done += failed + len(tr.tests_succeeded)
            running = set(tr.tests_executed.keys())
            running -= set(tr.tests_failed)
            running -= set(tr.tests_succeeded)
            running -= set(tr.tests_timed_out)
            running_with_start_time = {
                test: tr.tests_executed[test] for test in running
            }
            still_running[tid] = running_with_start_time

            now = time.time()
            if now - start_time > 5.0:
                if now - last_update > 3.0:
                    last_update = now
                    update = []
                    for _, running_dict in still_running.items():
                        for test_name, test_start_time in running_dict.items():
                            if now - test_start_time > 10.0:
                                update.append(f'{test_name}@{now-test_start_time:.1f}s')
                            else:
                                update.append(test_name)
                    print(f'\r{ansi.clear_line()}')
                    if len(update) < 4:
                        print(f'Still running: {",".join(update)}')
                    else:
                        print(f'Still running: {len(update)} tests.')

            if not thread.is_alive():
                if tid not in results:
                    result = thread.join()
                    if result:
                        results[tid] = result
                        if len(result.tests_failed) > 0:
                            logger.error(
                                'Thread %s returned abnormal results; killing the others.',
                                tid,
                            )
                            halt_event.set()
                    else:
                        logger.error(
                            'Thread %s took an unhandled exception... bug in run_tests.py?!  Aborting.',
                            tid,
                        )
                        halt_event.set()
                        results[tid] = None

        if started > 0:
            percent_done = done / started
        else:
            percent_done = 0.0

        if failed == 0:
            color = ansi.fg('green')
        else:
            color = ansi.fg('red')

        if percent_done < 1.0:
            print(
                text_utils.bar_graph_string(
                    done,
                    started,
                    text=text_utils.BarGraphText.FRACTION,
                    width=80,
                    fgcolor=color,
                ),
                end='\r',
                flush=True,
            )
        time.sleep(0.5)

    print(f'{ansi.clear_line()}Final Report:')
    if config.config['coverage']:
        code_coverage_report()
    total_problems = test_results_report(results)
    return total_problems


if __name__ == '__main__':
    main()