[pyutils.git] / tests / run_tests.py
#!/usr/bin/env python3

"""
A smart, fast test runner.  Used in a git pre-commit hook.
"""

import logging
import os
import re
import subprocess
import threading
import time
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple

from overrides import overrides

from pyutils import ansi, bootstrap, config, exec_utils, text_utils
from pyutils.files import file_utils
from pyutils.parallelize import parallelize as par
from pyutils.parallelize import smart_future, thread_utils

logger = logging.getLogger(__name__)
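# Command-line flags are registered with pyutils' config module here and read
# back later in this file via config.config['<flag>'].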
args = config.add_commandline_args(f'({__file__})', f'Args related to {__file__}')
args.add_argument('--unittests', '-u', action='store_true', help='Run unittests.')
args.add_argument('--doctests', '-d', action='store_true', help='Run doctests.')
args.add_argument(
    '--integration', '-i', action='store_true', help='Run integration tests.'
)
args.add_argument(
    '--all',
    '-a',
    action='store_true',
    help='Run unittests, doctests and integration tests.  Equivalent to -u -d -i',
)
args.add_argument(
    '--coverage',
    '-c',
    action='store_true',
    help='Run tests and capture code coverage data',
)

HOME = os.environ['HOME']

# These tests will be run twice in --coverage mode: once to get code
# coverage and then again with coverage disabled.  This is because
# they pay attention to code performance, which is adversely affected
# by coverage.
PERF_SENSATIVE_TESTS = set(['string_utils_test.py'])
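# Tests that this script never launches; every runner below skips them.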
TESTS_TO_SKIP = set(['zookeeper_test.py', 'run_tests.py'])

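# Where to look for tests: the parent directory, i.e. the root of the checkout
# relative to this tests/ directory.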
ROOT = ".."


@dataclass
class TestingParameters:
    halt_on_error: bool
    """Should we stop as soon as one error has occurred?"""

    halt_event: threading.Event
    """An event that, when set, indicates to stop ASAP."""


@dataclass
class TestToRun:
    name: str
    """The name of the test"""

    kind: str
    """The kind of the test"""

    cmdline: str
    """The command line to execute"""


@dataclass
class TestResults:
    name: str
    """The name of this test / set of tests."""

    tests_executed: List[str]
    """Tests that were executed."""

    tests_succeeded: List[str]
    """Tests that succeeded."""

    tests_failed: List[str]
    """Tests that failed."""

    tests_timed_out: List[str]
    """Tests that timed out."""

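    # Merge another TestResults into this one, in place; used to aggregate
    # per-test results into a per-runner total (see TemplatedTestRunner.begin).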
    def __add__(self, other):
        self.tests_executed.extend(other.tests_executed)
        self.tests_succeeded.extend(other.tests_succeeded)
        self.tests_failed.extend(other.tests_failed)
        self.tests_timed_out.extend(other.tests_timed_out)
        return self

    __radd__ = __add__

    def __repr__(self) -> str:
        out = f'{self.name}: '
        out += f'{ansi.fg("green")}'
        out += f'{len(self.tests_succeeded)}/{len(self.tests_executed)} passed'
        out += f'{ansi.reset()}.\n'

        if len(self.tests_failed) > 0:
            out += f'  ..{ansi.fg("red")}'
            out += f'{len(self.tests_failed)} tests failed'
            out += f'{ansi.reset()}:\n'
            for test in self.tests_failed:
                out += f'    {test}\n'
            out += '\n'

        if len(self.tests_timed_out) > 0:
            out += f'  ..{ansi.fg("yellow")}'
            out += f'{len(self.tests_timed_out)} tests timed out'
            out += f'{ansi.reset()}:\n'
            for test in self.tests_timed_out:
                out += f'    {test}\n'
            out += '\n'
        return out


class TestRunner(ABC, thread_utils.ThreadWithReturnValue):
    """A Base class for something that runs a test."""

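    # Each concrete TestRunner is a thread; the TestResults returned by begin()
    # comes back out of join() (see the polling loop in main()).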
    def __init__(self, params: TestingParameters):
        """Create a TestRunner.

        Args:
            params: Test running parameters.

        """
        super().__init__(self, target=self.begin, args=[params])
        self.params = params
        self.test_results = TestResults(
            name=self.get_name(),
            tests_executed=[],
            tests_succeeded=[],
            tests_failed=[],
            tests_timed_out=[],
        )
        self.tests_started = 0
        self.lock = threading.Lock()

    @abstractmethod
    def get_name(self) -> str:
        """The name of this test collection."""
        pass

    def get_status(self) -> Tuple[int, TestResults]:
        """Ask the TestRunner for its status."""
        with self.lock:
            return (self.tests_started, self.test_results)

    @abstractmethod
    def begin(self, params: TestingParameters) -> TestResults:
        """Start execution."""
        pass


class TemplatedTestRunner(TestRunner, ABC):
    """A TestRunner that has a recipe for executing the tests."""

    @abstractmethod
    def identify_tests(self) -> List[TestToRun]:
        """Return the list of TestToRun objects that should be executed."""
        pass

    @abstractmethod
    def run_test(self, test: TestToRun) -> TestResults:
        """Run a single test and return its TestResults."""
        pass

    def check_for_abort(self):
        """Periodically called to check to see if we need to stop."""

        if self.params.halt_event.is_set():
            logger.debug('Thread %s saw halt event; exiting.', self.get_name())
            raise Exception("Kill myself!")
        if self.params.halt_on_error:
            if len(self.test_results.tests_failed) > 0:
                logger.error(
                    'Thread %s saw abnormal results; exiting.', self.get_name()
                )
                raise Exception("Kill myself!")

    def persist_output(self, test: TestToRun, message: str, output: str) -> None:
        """Called to save the output of a test run."""

        dest = f'{test.name}-output.txt'
        with open(f'./test_output/{dest}', 'w') as wf:
            print(message, file=wf)
            print('-' * len(message), file=wf)
            wf.write(output)

    def execute_commandline(
        self,
        test: TestToRun,
        *,
        timeout: float = 120.0,
    ) -> TestResults:
        """Execute a particular commandline to run a test."""

        try:
            output = exec_utils.cmd(
                test.cmdline,
                timeout_seconds=timeout,
            )
            self.persist_output(
                test, f'{test.name} ({test.cmdline}) succeeded.', output
            )
            logger.debug(
                '%s: %s (%s) succeeded', self.get_name(), test.name, test.cmdline
            )
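            # TestResults positional args: name, executed, succeeded, failed, timed_out.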
            return TestResults(test.name, [test.name], [test.name], [], [])
        except subprocess.TimeoutExpired as e:
            msg = f'{self.get_name()}: {test.name} ({test.cmdline}) timed out after {e.timeout:.1f} seconds.'
            logger.error(msg)
            logger.debug(
                '%s: %s output when it timed out: %s',
                self.get_name(),
                test.name,
                e.output,
            )
            self.persist_output(test, msg, e.output.decode('utf-8'))
            return TestResults(
                test.name,
                [test.name],
                [],
                [],
                [test.name],
            )
        except subprocess.CalledProcessError as e:
            msg = f'{self.get_name()}: {test.name} ({test.cmdline}) failed; exit code {e.returncode}'
            logger.error(msg)
            logger.debug(
                '%s: %s output when it failed: %s', self.get_name(), test.name, e.output
            )
            self.persist_output(test, msg, e.output.decode('utf-8'))
            return TestResults(
                test.name,
                [test.name],
                [],
                [test.name],
                [],
            )

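    # Template method: ask the subclass which tests exist, start them all in
    # the background (run_test is parallelized in subclasses), then merge
    # results as they finish.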
    @overrides
    def begin(self, params: TestingParameters) -> TestResults:
        logger.debug('Thread %s started.', self.get_name())
        interesting_tests = self.identify_tests()
        logger.debug(
            '%s: Identified %d tests to be run.',
            self.get_name(),
            len(interesting_tests),
        )

        # Note: because of @parallelize on run_test it actually
        # returns a SmartFuture with a TestResults inside of it.
        # That's the reason for this Any business.
        running: List[Any] = []
        for test_to_run in interesting_tests:
            running.append(self.run_test(test_to_run))
            logger.debug(
                '%s: Test %s started in the background.',
                self.get_name(),
                test_to_run.name,
            )
            self.tests_started += 1

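        # wait_any yields each future as it becomes ready; _resolve() unwraps
        # the TestResults from the SmartFuture.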
        for future in smart_future.wait_any(running):
            self.check_for_abort()
            result = future._resolve()
            logger.debug('Test %s finished.', result.name)
            self.test_results += result

        logger.debug('Thread %s finished.', self.get_name())
        return self.test_results


class UnittestTestRunner(TemplatedTestRunner):
    """Run all known Unittests."""

    @overrides
    def get_name(self) -> str:
        return "Unittests"

    @overrides
    def identify_tests(self) -> List[TestToRun]:
        ret = []
        for test in file_utils.get_matching_files_recursive(ROOT, '*_test.py'):
            basename = file_utils.without_path(test)
            if basename in TESTS_TO_SKIP:
                continue
            if config.config['coverage']:
                ret.append(
                    TestToRun(
                        name=basename,
                        kind='unittest capturing coverage',
                        cmdline=f'coverage run --source ../src {test} --unittests_ignore_perf 2>&1',
                    )
                )
                if basename in PERF_SENSATIVE_TESTS:
                    ret.append(
                        TestToRun(
                            name=basename,
                            kind='unittest w/o coverage to record perf',
                            cmdline=f'{test} 2>&1',
                        )
                    )
            else:
                ret.append(
                    TestToRun(
                        name=basename,
                        kind='unittest',
                        cmdline=f'{test} 2>&1',
                    )
                )
        return ret

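    # @parallelize makes run_test return a SmartFuture immediately; the real
    # work (execute_commandline) happens in the background.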
    @par.parallelize
    def run_test(self, test: TestToRun) -> TestResults:
        return self.execute_commandline(test)


class DoctestTestRunner(TemplatedTestRunner):
    """Run all known Doctests."""

    @overrides
    def get_name(self) -> str:
        return "Doctests"

    @overrides
    def identify_tests(self) -> List[TestToRun]:
        ret = []
        out = exec_utils.cmd(f'grep -lR "^ *import doctest" {ROOT}/*')
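        # Any Python file that imports doctest is assumed to contain doctests
        # that run when the file is executed directly.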
        for test in out.split('\n'):
            if re.match(r'.*\.py$', test):
                basename = file_utils.without_path(test)
                if basename in TESTS_TO_SKIP:
                    continue
                if config.config['coverage']:
                    ret.append(
                        TestToRun(
                            name=basename,
                            kind='doctest capturing coverage',
                            cmdline=f'coverage run --source ../src {test} 2>&1',
                        )
                    )
                    if basename in PERF_SENSATIVE_TESTS:
                        ret.append(
                            TestToRun(
                                name=basename,
                                kind='doctest w/o coverage to record perf',
                                cmdline=f'python3 {test} 2>&1',
                            )
                        )
                else:
                    ret.append(
                        TestToRun(
                            name=basename,
                            kind='doctest',
                            cmdline=f'python3 {test} 2>&1',
                        )
                    )
        return ret

    @par.parallelize
    def run_test(self, test: TestToRun) -> TestResults:
        return self.execute_commandline(test)


class IntegrationTestRunner(TemplatedTestRunner):
    """Run all known Integration tests."""

    @overrides
    def get_name(self) -> str:
        return "Integration Tests"

    @overrides
    def identify_tests(self) -> List[TestToRun]:
        ret = []
        for test in file_utils.get_matching_files_recursive(ROOT, '*_itest.py'):
            basename = file_utils.without_path(test)
            if basename in TESTS_TO_SKIP:
                continue
            if config.config['coverage']:
                ret.append(
                    TestToRun(
                        name=basename,
                        kind='integration test capturing coverage',
                        cmdline=f'coverage run --source ../src {test} 2>&1',
                    )
                )
                if basename in PERF_SENSATIVE_TESTS:
                    ret.append(
                        TestToRun(
                            name=basename,
                            kind='integration test w/o coverage to capture perf',
                            cmdline=f'{test} 2>&1',
                        )
                    )
            else:
                ret.append(
                    TestToRun(
                        name=basename, kind='integration test', cmdline=f'{test} 2>&1'
                    )
                )
        return ret

    @par.parallelize
    def run_test(self, test: TestToRun) -> TestResults:
        return self.execute_commandline(test)


def test_results_report(results: Dict[str, TestResults]) -> int:
    """Give a final report about the tests that were run."""
    total_problems = 0
    for result in results.values():
        print(result, end='')
        total_problems += len(result.tests_failed)
        total_problems += len(result.tests_timed_out)

    if total_problems > 0:
        print('Reminder: look in ./test_output to view test output logs')
    return total_problems


def code_coverage_report():
    """Give a final code coverage report."""
    text_utils.header('Code Coverage')
    exec_utils.cmd('coverage combine .coverage*')
    out = exec_utils.cmd(
        'coverage report --omit=config-3.*.py,*_test.py,*_itest.py --sort=-cover'
    )
    print(out)
    print(
        """To recall this report w/o re-running the tests:

    $ coverage report --omit=config-3.*.py,*_test.py,*_itest.py --sort=-cover

...from the 'tests' directory.  Note that subsequent calls to
run_tests.py with --coverage will clobber previous results.  See:

    https://coverage.readthedocs.io/en/6.2/
"""
    )


@bootstrap.initialize
def main() -> Optional[int]:
    saw_flag = False
    halt_event = threading.Event()
    threads: List[TestRunner] = []

    halt_event.clear()
    params = TestingParameters(
        halt_on_error=True,
        halt_event=halt_event,
    )

    if config.config['coverage']:
        logger.debug('Clearing existing coverage data via "coverage erase".')
        exec_utils.cmd('coverage erase')

    if config.config['unittests'] or config.config['all']:
        saw_flag = True
        threads.append(UnittestTestRunner(params))
    if config.config['doctests'] or config.config['all']:
        saw_flag = True
        threads.append(DoctestTestRunner(params))
    if config.config['integration'] or config.config['all']:
        saw_flag = True
        threads.append(IntegrationTestRunner(params))

    if not saw_flag:
        config.print_usage()
        print('ERROR: one of --unittests, --doctests or --integration is required.')
        return 1

    for thread in threads:
        thread.start()

    results: Dict[str, TestResults] = {}
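    # Poll the runner threads every half second: update a progress bar, harvest
    # results from threads that have finished, and tell everyone to halt if any
    # thread reports failures.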
    while len(results) != len(threads):
        started = 0
        done = 0
        failed = 0

        for thread in threads:
            (s, tr) = thread.get_status()
            started += s
            failed += len(tr.tests_failed) + len(tr.tests_timed_out)
            done += len(tr.tests_succeeded) + len(tr.tests_failed) + len(tr.tests_timed_out)
            if not thread.is_alive():
                tid = thread.name
                if tid not in results:
                    result = thread.join()
                    if result:
                        results[tid] = result
                        if len(result.tests_failed) > 0:
                            logger.error(
                                'Thread %s returned abnormal results; killing the others.',
                                tid,
                            )
                            halt_event.set()

        if started > 0:
            percent_done = 100.0 * done / started
        else:
            percent_done = 0.0

        if failed == 0:
            color = ansi.fg('green')
        else:
            color = ansi.fg('red')

        if percent_done < 100.0:
            print(
                text_utils.bar_graph_string(
                    done,
                    started,
                    text=text_utils.BarGraphText.FRACTION,
                    width=80,
                    fgcolor=color,
                ),
                end='\r',
                flush=True,
            )
        time.sleep(0.5)

    print(f'{ansi.clear_line()}Final Report:')
    if config.config['coverage']:
        code_coverage_report()
    total_problems = test_results_report(results)
    return total_problems


if __name__ == '__main__':
    main()