#!/usr/bin/env python3

"""
A smart, fast test runner.  Used in a git pre-commit hook.
"""

from __future__ import annotations

import logging
import os
import re
import subprocess
import threading
import time
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple

from overrides import overrides

from pyutils import ansi, bootstrap, config, dict_utils, exec_utils, text_utils
from pyutils.files import file_utils
from pyutils.parallelize import parallelize as par
from pyutils.parallelize import smart_future, thread_utils

logger = logging.getLogger(__name__)
args = config.add_commandline_args(
    f"Run Tests Driver ({__file__})", f"Args related to {__file__}"
)
args.add_argument("--unittests", "-u", action="store_true", help="Run unittests.")
args.add_argument("--doctests", "-d", action="store_true", help="Run doctests.")
args.add_argument(
    "--integration", "-i", action="store_true", help="Run integration tests."
)
args.add_argument(
    "--all",
    "-a",
    action="store_true",
    help="Run unittests, doctests and integration tests.  Equivalent to -u -d -i",
)
args.add_argument(
    "--coverage",
    "-c",
    action="store_true",
    help="Run tests and capture code coverage data",
)

HOME = os.environ["HOME"]

# These tests will be run twice in --coverage mode: once to collect
# code coverage data and then again with coverage disabled.  This is
# because they pay attention to code performance, which is adversely
# affected by coverage instrumentation.
PERF_SENSATIVE_TESTS = set(["string_utils_test.py"])
TESTS_TO_SKIP = set(["zookeeper_test.py", "zookeeper.py", "run_tests.py"])

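# Where to look for test files.  This script lives in tests/, so ".."
# is (presumably) the repository root.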
ROOT = ".."


@dataclass
class TestingParameters:
    halt_on_error: bool
    """Should we stop as soon as one error has occurred?"""

    halt_event: threading.Event
    """An event that, when set, indicates to stop ASAP."""


@dataclass
class TestToRun:
    name: str
    """The name of the test"""

    kind: str
    """The kind of the test"""

    cmdline: str
    """The command line to execute"""


@dataclass
class TestResults:
    name: str
    """The name of this test / set of tests."""

    tests_executed: Dict[str, float]
    """Tests that were executed."""

    tests_succeeded: List[str]
    """Tests that succeeded."""

    tests_failed: List[str]
    """Tests that failed."""

    tests_timed_out: List[str]
    """Tests that timed out."""

    def __add__(self, other):
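        """Merge another TestResults into this one in place and return self.

        Note: a test name present in both tests_executed maps raises an
        error (via dict_utils.raise_on_duplicated_keys).
        """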
        merged = dict_utils.coalesce(
            [self.tests_executed, other.tests_executed],
            aggregation_function=dict_utils.raise_on_duplicated_keys,
        )
        self.tests_executed = merged
        self.tests_succeeded.extend(other.tests_succeeded)
        self.tests_failed.extend(other.tests_failed)
        self.tests_timed_out.extend(other.tests_timed_out)
        return self

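    # Support "other + self" as well so results can be accumulated
    # regardless of operand order.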
    __radd__ = __add__

    @staticmethod
    def empty_test_results(suite_name: str) -> TestResults:
        return TestResults(
            name=suite_name,
            tests_executed={},
            tests_succeeded=[],
            tests_failed=[],
            tests_timed_out=[],
        )

    @staticmethod
    def single_test_succeeded(name: str) -> TestResults:
        return TestResults(name, {}, [name], [], [])

    @staticmethod
    def single_test_failed(name: str) -> TestResults:
        return TestResults(
            name,
            {},
            [],
            [name],
            [],
        )

    @staticmethod
    def single_test_timed_out(name: str) -> TestResults:
        return TestResults(
            name,
            {},
            [],
            [],
            [name],
        )

    def __repr__(self) -> str:
        out = f"{self.name}: "
        out += f'{ansi.fg("green")}'
        out += f"{len(self.tests_succeeded)}/{len(self.tests_executed)} passed"
        out += f"{ansi.reset()}.\n"
        tests_with_known_status = len(self.tests_succeeded)

        if len(self.tests_failed) > 0:
            out += f'  ..{ansi.fg("red")}'
            out += f"{len(self.tests_failed)} tests failed"
            out += f"{ansi.reset()}:\n"
            for test in self.tests_failed:
                out += f"    {test}\n"
            tests_with_known_status += len(self.tests_failed)

        if len(self.tests_timed_out) > 0:
            out += f'  ..{ansi.fg("lightning yellow")}'
            out += f"{len(self.tests_timed_out)} tests timed out"
            out += f"{ansi.reset()}:\n"
            for test in self.tests_timed_out:
                out += f"    {test}\n"
            tests_with_known_status += len(self.tests_timed_out)

        missing = len(self.tests_executed) - tests_with_known_status
        if missing:
            out += f'  ..{ansi.fg("lightning yellow")}'
            out += f"{missing} tests aborted early"
            out += f"{ansi.reset()}\n"
        return out

    def _key(self) -> Tuple[str, Tuple, Tuple, Tuple]:
        return (
            self.name,
            tuple(self.tests_succeeded),
            tuple(self.tests_failed),
            tuple(self.tests_timed_out),
        )

    def __hash__(self) -> int:
        return hash(self._key())


class TestRunner(ABC, thread_utils.ThreadWithReturnValue):
    """A base class for something that runs a set of tests."""

    def __init__(self, params: TestingParameters):
        """Create a TestRunner.

        Args:
            params: Test running parameters.

        """
        super().__init__(target=self.begin, args=[params])
        self.params = params
        self.test_results = TestResults.empty_test_results(self.get_name())
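        # Protects reads of self.test_results from the main thread (see
        # get_status()) while this worker thread is running.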
        self.lock = threading.Lock()

    @abstractmethod
    def get_name(self) -> str:
        """The name of this test collection."""
        pass

    def get_status(self) -> TestResults:
        """Ask the TestRunner for its status."""
        with self.lock:
            return self.test_results

    @abstractmethod
    def begin(self, params: TestingParameters) -> TestResults:
        """Start execution."""
        pass


class TemplatedTestRunner(TestRunner, ABC):
    """A TestRunner that has a recipe for executing the tests."""

    def __init__(self, params: TestingParameters):
        super().__init__(params)

        # Note: because of @parallelize on run_test it actually
        # returns a SmartFuture with a TestResults inside of it.
        # That's the reason for this Any business.
        self.running: List[Any] = []
        self.already_cancelled = False

    @abstractmethod
    def identify_tests(self) -> List[TestToRun]:
        """Return a list of TestToRun objects that should be executed."""
        pass

    @abstractmethod
    def run_test(self, test: TestToRun) -> TestResults:
        """Run a single test and return its TestResults."""
        pass

    def check_for_abort(self) -> bool:
        """Periodically called to check to see if we need to stop."""

        if self.params.halt_event.is_set():
            logger.debug("Thread %s saw halt event; exiting.", self.get_name())
            return True

        if self.params.halt_on_error and len(self.test_results.tests_failed) > 0:
            logger.debug("Thread %s saw abnormal results; exiting.", self.get_name())
            return True
        return False

    def persist_output(self, test: TestToRun, message: str, output: str) -> None:
        """Called to save the output of a test run."""

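        # Note: this assumes the ./test_output/ directory already exists;
        # the open() below will fail if it does not.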
        dest = f"{test.name}-output.txt"
        with open(f"./test_output/{dest}", "w") as wf:
            print(message, file=wf)
            print("-" * len(message), file=wf)
            wf.write(output)

    def execute_commandline(
        self,
        test: TestToRun,
        *,
        timeout: float = 120.0,
    ) -> TestResults:
        """Execute a particular commandline to run a test."""

        msg = f"{self.get_name()}: {test.name} ({test.cmdline}) "
        try:
            output = exec_utils.cmd(
                test.cmdline,
                timeout_seconds=timeout,
            )
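            # A doctest can fail without producing a non-zero exit status,
            # so also scan the output for doctest's failure banner.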
            if "***Test Failed***" in output:
                msg += "failed; doctest failure message detected."
                logger.error(msg)
                self.persist_output(test, msg, output)
                return TestResults.single_test_failed(test.name)

            msg += "succeeded."
            self.persist_output(test, msg, output)
            logger.debug(msg)
            return TestResults.single_test_succeeded(test.name)

        except subprocess.TimeoutExpired as e:
            msg += f"timed out after {e.timeout:.1f} seconds."
            logger.error(msg)
            logger.debug(
                "%s: %s output when it timed out: %s",
                self.get_name(),
                test.name,
                e.output,
            )
            self.persist_output(test, msg, e.output.decode("utf-8"))
            return TestResults.single_test_timed_out(test.name)

        except subprocess.CalledProcessError as e:
            msg += f"failed with exit code {e.returncode}."
            logger.error(msg)
            logger.debug(
                "%s: %s output when it failed: %s", self.get_name(), test.name, e.output
            )
            self.persist_output(test, msg, e.output.decode("utf-8"))
            return TestResults.single_test_failed(test.name)

    def callback(self):
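        """Invoked by smart_future.wait_any while waiting on test futures;
        if an abort has been requested, cancel the outstanding futures."""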
        if not self.already_cancelled and self.check_for_abort():
            logger.debug(
                "%s: aborting %d running futures to exit early.",
                self.get_name(),
                len(self.running),
            )
            for x in self.running:
                x.wrapped_future.cancel()
            self.already_cancelled = True

    @overrides
    def begin(self, params: TestingParameters) -> TestResults:
        logger.debug("Thread %s started.", self.get_name())
        interesting_tests = self.identify_tests()
        logger.debug(
            "%s: Identified %d tests to be run.",
            self.get_name(),
            len(interesting_tests),
        )

        for test_to_run in interesting_tests:
            self.running.append(self.run_test(test_to_run))
            logger.debug(
                "%s: Test %s started in the background.",
                self.get_name(),
                test_to_run.name,
            )
            self.test_results.tests_executed[test_to_run.name] = time.time()

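        # Merge results as futures complete, guarding against merging the
        # same completed result twice.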
        already_seen = set()
        for result in smart_future.wait_any(
            self.running, timeout=1.0, callback=self.callback, log_exceptions=False
        ):
            if result and result not in already_seen:
                logger.debug("Test %s finished.", result.name)
                self.test_results += result
                already_seen.add(result)

            if self.check_for_abort():
                logger.error("%s: exiting early.", self.get_name())
                return self.test_results

        logger.debug("%s: executed all tests and returning normally", self.get_name())
        return self.test_results


class UnittestTestRunner(TemplatedTestRunner):
    """Run all known Unittests."""

    @overrides
    def get_name(self) -> str:
        return "Unittests"

    @overrides
    def identify_tests(self) -> List[TestToRun]:
        ret = []
        for test in file_utils.get_matching_files_recursive(ROOT, "*_test.py"):
            basename = file_utils.without_path(test)
            if basename in TESTS_TO_SKIP:
                continue
            if config.config["coverage"]:
                ret.append(
                    TestToRun(
                        name=basename,
                        kind="unittest capturing coverage",
                        cmdline=f"coverage run --source ../src {test} --unittests_ignore_perf 2>&1",
                    )
                )
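                # Perf-sensitive tests get a second, uninstrumented run so
                # that their timing checks are not skewed by coverage.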
                if basename in PERF_SENSATIVE_TESTS:
                    ret.append(
                        TestToRun(
                            name=f"{basename}_no_coverage",
                            kind="unittest w/o coverage to record perf",
                            cmdline=f"{test} 2>&1",
                        )
                    )
            else:
                ret.append(
                    TestToRun(
                        name=basename,
                        kind="unittest",
                        cmdline=f"{test} 2>&1",
                    )
                )
        return ret

    @par.parallelize
    def run_test(self, test: TestToRun) -> TestResults:
        return self.execute_commandline(test)


class DoctestTestRunner(TemplatedTestRunner):
    """Run all known Doctests."""

    @overrides
    def get_name(self) -> str:
        return "Doctests"

    @overrides
    def identify_tests(self) -> List[TestToRun]:
        ret = []
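        # Any python file under ROOT that imports doctest is assumed to
        # contain doctests runnable by executing the file directly.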
        out = exec_utils.cmd(f'grep -lR "^ *import doctest" {ROOT}/*')
        for test in out.split("\n"):
            if re.match(r".*\.py$", test):
                basename = file_utils.without_path(test)
                if basename in TESTS_TO_SKIP:
                    continue
                if config.config["coverage"]:
                    ret.append(
                        TestToRun(
                            name=basename,
                            kind="doctest capturing coverage",
                            cmdline=f"coverage run --source ../src {test} 2>&1",
                        )
                    )
                    if basename in PERF_SENSATIVE_TESTS:
                        ret.append(
                            TestToRun(
                                name=f"{basename}_no_coverage",
                                kind="doctest w/o coverage to record perf",
                                cmdline=f"python3 {test} 2>&1",
                            )
                        )
                else:
                    ret.append(
                        TestToRun(
                            name=basename,
                            kind="doctest",
                            cmdline=f"python3 {test} 2>&1",
                        )
                    )
        return ret

    @par.parallelize
    def run_test(self, test: TestToRun) -> TestResults:
        return self.execute_commandline(test)


class IntegrationTestRunner(TemplatedTestRunner):
    """Run all known Integration tests."""

    @overrides
    def get_name(self) -> str:
        return "Integration Tests"

    @overrides
    def identify_tests(self) -> List[TestToRun]:
        ret = []
        for test in file_utils.get_matching_files_recursive(ROOT, "*_itest.py"):
            basename = file_utils.without_path(test)
            if basename in TESTS_TO_SKIP:
                continue
            if config.config["coverage"]:
                ret.append(
                    TestToRun(
                        name=basename,
                        kind="integration test capturing coverage",
                        cmdline=f"coverage run --source ../src {test} 2>&1",
                    )
                )
                if basename in PERF_SENSATIVE_TESTS:
                    ret.append(
                        TestToRun(
                            name=f"{basename}_no_coverage",
                            kind="integration test w/o coverage to capture perf",
                            cmdline=f"{test} 2>&1",
                        )
                    )
            else:
                ret.append(
                    TestToRun(
                        name=basename, kind="integration test", cmdline=f"{test} 2>&1"
                    )
                )
        return ret

    @par.parallelize
    def run_test(self, test: TestToRun) -> TestResults:
        return self.execute_commandline(test)


def test_results_report(results: Dict[str, Optional[TestResults]]) -> int:
    """Give a final report about the tests that were run."""
    total_problems = 0
    for result in results.values():
        if result is None:
            print("Unexpected unhandled exception in test runner!!!")
            total_problems += 1
        else:
            print(result, end="")
            total_problems += len(result.tests_failed)
            total_problems += len(result.tests_timed_out)

    if total_problems > 0:
        print(
            f"{ansi.bold()}Test output / logging can be found under ./test_output{ansi.reset()}"
        )
    return total_problems


def code_coverage_report():
    """Give a final code coverage report."""
    text_utils.header("Code Coverage")
    exec_utils.cmd("coverage combine .coverage*")
    out = exec_utils.cmd(
        "coverage report --omit=config-3.*.py,*_test.py,*_itest.py --sort=-cover"
    )
    print(out)
    print(
        f"""To recall this report w/o re-running the tests:

    $ {ansi.bold()}coverage report --omit=config-3.*.py,*_test.py,*_itest.py --sort=-cover{ansi.reset()}

...from the 'tests' directory.  Note that subsequent calls to
run_tests.py with --coverage will clobber previous results.  See:

    https://coverage.readthedocs.io/en/6.2/
"""
    )


@bootstrap.initialize
def main() -> Optional[int]:
    saw_flag = False
    threads: List[TestRunner] = []

    halt_event = threading.Event()
    halt_event.clear()
    params = TestingParameters(
        halt_on_error=True,
        halt_event=halt_event,
    )

    if config.config["coverage"]:
        logger.debug('Clearing existing coverage data via "coverage erase".')
        exec_utils.cmd("coverage erase")
    if config.config["unittests"] or config.config["all"]:
        saw_flag = True
        threads.append(UnittestTestRunner(params))
    if config.config["doctests"] or config.config["all"]:
        saw_flag = True
        threads.append(DoctestTestRunner(params))
    if config.config["integration"] or config.config["all"]:
        saw_flag = True
        threads.append(IntegrationTestRunner(params))

    if not saw_flag:
        config.print_usage()
        config.error("One of --unittests, --doctests or --integration is required.", 1)

    for thread in threads:
        thread.start()

    start_time = time.time()
    last_update = start_time
    results: Dict[str, Optional[TestResults]] = {}
    still_running = {}

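    # Poll the runner threads until each one has reported its results,
    # rendering a progress bar and the set of still-running tests as we go.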
    while len(results) != len(threads):
        started = 0
        done = 0
        failed = 0

        for thread in threads:
            tid = thread.name
            tr = thread.get_status()
            started += len(tr.tests_executed)
            thread_failed = len(tr.tests_failed) + len(tr.tests_timed_out)
            failed += thread_failed
            done += thread_failed + len(tr.tests_succeeded)
            running = set(tr.tests_executed.keys())
            running -= set(tr.tests_failed)
            running -= set(tr.tests_succeeded)
            running -= set(tr.tests_timed_out)
            running_with_start_time = {
                test: tr.tests_executed[test] for test in running
            }
            still_running[tid] = running_with_start_time

            # Maybe print tests that are still running.
            now = time.time()
            if now - start_time > 5.0:
                if now - last_update > 3.0:
                    last_update = now
                    update = []
                    for _, running_dict in still_running.items():
                        for test_name, test_start_time in running_dict.items():
                            if now - test_start_time > 10.0:
                                update.append(f"{test_name}@{now-test_start_time:.1f}s")
                            else:
                                update.append(test_name)
                    print(f"\r{ansi.clear_line()}")
                    if len(update) < 4:
                        print(f'Still running: {",".join(update)}')
                    else:
                        print(f"Still running: {len(update)} tests.")

            # Harvest results from finished threads and maybe signal the others to stop.
            if not thread.is_alive():
                if tid not in results:
                    result = thread.join()
                    if result:
                        results[tid] = result
                        if (len(result.tests_failed) + len(result.tests_timed_out)) > 0:
                            logger.error(
                                "Thread %s returned abnormal results; killing the others.",
                                thread.get_name(),
                            )
                            halt_event.set()
                    else:
                        logger.error(
                            "Thread %s took an unhandled exception... bug in run_tests.py?!  Aborting.",
                            tid,
                        )
                        halt_event.set()
                        results[tid] = None

        color = ansi.fg("green")
        if failed > 0:
            color = ansi.fg("red")

        if started > 0:
            percent_done = done / started * 100.0
        else:
            percent_done = 0.0

        if percent_done < 100.0:
            print(
                text_utils.bar_graph_string(
                    done,
                    started,
                    text=text_utils.BarGraphText.FRACTION,
                    width=72,
                    fgcolor=color,
                ),
                end="",
                flush=True,
            )
            print(f"  {color}{now - start_time:.1f}s{ansi.reset()}", end="\r")
        time.sleep(0.1)

    print(f"{ansi.clear_line()}\n{ansi.underline()}Final Report:{ansi.reset()}")
    if config.config["coverage"]:
        code_coverage_report()
    print(f"Test suite runtime: {time.time() - start_time:.1f}s")
    total_problems = test_results_report(results)
    if total_problems > 0:
        logger.error(
            "Exiting with non-zero return code %d due to problems.", total_problems
        )
    return total_problems


if __name__ == "__main__":
    main()