[python_utils.git] / tests / run_tests.py
#!/usr/bin/env python3

"""
A smart, fast test runner.  Used in a git pre-commit hook.
"""
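# Example (hypothetical) pre-commit hook wiring -- the path below is an
# assumption; adjust it to wherever this repo is checked out:
#
#     cd ~/lib/python_modules/tests && ./run_tests.py --all --coverage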

import logging
import os
import re
import subprocess
import threading
import time
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple

from overrides import overrides

import ansi
import bootstrap
import config
import exec_utils
import file_utils
import parallelize as par
import smart_future
import text_utils
import thread_utils

logger = logging.getLogger(__name__)
args = config.add_commandline_args(f'({__file__})', f'Args related to {__file__}')
args.add_argument('--unittests', '-u', action='store_true', help='Run unittests.')
args.add_argument('--doctests', '-d', action='store_true', help='Run doctests.')
args.add_argument('--integration', '-i', action='store_true', help='Run integration tests.')
args.add_argument(
    '--all',
    '-a',
    action='store_true',
    help='Run unittests, doctests and integration tests.  Equivalent to -u -d -i.',
)
args.add_argument(
    '--coverage', '-c', action='store_true', help='Run tests and capture code coverage data.'
)

HOME = os.environ['HOME']

# These tests will be run twice in --coverage mode: once to get code
# coverage and then again without coverage enabled.  This is because
# they pay attention to code performance, which is adversely affected
# by coverage.
PERF_SENSITIVE_TESTS = set(['/home/scott/lib/python_modules/tests/string_utils_test.py'])


@dataclass
class TestingParameters:
    halt_on_error: bool
    """Should we stop as soon as one error has occurred?"""

    halt_event: threading.Event
    """An event that, when set, indicates to stop ASAP."""


@dataclass
class TestToRun:
    name: str
    """The name of the test"""

    kind: str
    """The kind of the test"""

    cmdline: str
    """The command line to execute"""


@dataclass
class TestResults:
    name: str
    """The name of this test / set of tests."""

    tests_executed: List[str]
    """Tests that were executed."""

    tests_succeeded: List[str]
    """Tests that succeeded."""

    tests_failed: List[str]
    """Tests that failed."""

    tests_timed_out: List[str]
    """Tests that timed out."""

    def __add__(self, other):
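        """Merge other's result lists into this one (in place) and return self."""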
        self.tests_executed.extend(other.tests_executed)
        self.tests_succeeded.extend(other.tests_succeeded)
        self.tests_failed.extend(other.tests_failed)
        self.tests_timed_out.extend(other.tests_timed_out)
        return self

    __radd__ = __add__

    def __repr__(self) -> str:
        out = f'{self.name}: '
        out += f'{ansi.fg("green")}'
        out += f'{len(self.tests_succeeded)}/{len(self.tests_executed)} passed'
        out += f'{ansi.reset()}.\n'

        if len(self.tests_failed) > 0:
            out += f'  ..{ansi.fg("red")}'
            out += f'{len(self.tests_failed)} tests failed'
            out += f'{ansi.reset()}:\n'
            for test in self.tests_failed:
                out += f'    {test}\n'
            out += '\n'

        if len(self.tests_timed_out) > 0:
            out += f'  ..{ansi.fg("yellow")}'
            out += f'{len(self.tests_timed_out)} tests timed out'
            out += f'{ansi.reset()}:\n'
            for test in self.tests_timed_out:
                out += f'    {test}\n'
            out += '\n'
        return out


class TestRunner(ABC, thread_utils.ThreadWithReturnValue):
    """A Base class for something that runs a test."""

    def __init__(self, params: TestingParameters):
        """Create a TestRunner.

        Args:
            params: Test running parameters.

        """
        super().__init__(target=self.begin, args=[params])
        self.params = params
        self.test_results = TestResults(
            name=self.get_name(),
            tests_executed=[],
            tests_succeeded=[],
            tests_failed=[],
            tests_timed_out=[],
        )
        self.tests_started = 0
        self.lock = threading.Lock()

    @abstractmethod
    def get_name(self) -> str:
        """The name of this test collection."""
        pass

    def get_status(self) -> Tuple[int, TestResults]:
        """Ask the TestRunner for its status."""
        with self.lock:
            return (self.tests_started, self.test_results)

    @abstractmethod
    def begin(self, params: TestingParameters) -> TestResults:
        """Start execution."""
        pass


class TemplatedTestRunner(TestRunner, ABC):
    """A TestRunner that has a recipe for executing the tests."""

    @abstractmethod
    def identify_tests(self) -> List[TestToRun]:
        """Return a list of TestToRun objects that should be executed."""
        pass

    @abstractmethod
    def run_test(self, test: TestToRun) -> TestResults:
        """Run a single test and return its TestResults."""
        pass

    def check_for_abort(self):
        """Periodically called to check whether we need to stop."""

        if self.params.halt_event.is_set():
            logger.debug('Thread %s saw halt event; exiting.', self.get_name())
            raise Exception("Kill myself!")
        if self.params.halt_on_error:
            if len(self.test_results.tests_failed) > 0:
                logger.error('Thread %s saw abnormal results; exiting.', self.get_name())
                raise Exception("Kill myself!")

    def persist_output(self, test: TestToRun, message: str, output: str) -> None:
        """Called to save the output of a test run."""

        dest = f'{test.name}-output.txt'
        with open(f'./test_output/{dest}', 'w') as wf:
            print(message, file=wf)
            print('-' * len(message), file=wf)
            wf.write(output)

    def execute_commandline(
        self,
        test: TestToRun,
        *,
        timeout: float = 120.0,
    ) -> TestResults:
        """Execute a particular commandline to run a test."""

        try:
            output = exec_utils.cmd(
                test.cmdline,
                timeout_seconds=timeout,
            )
            self.persist_output(test, f'{test.name} ({test.cmdline}) succeeded.', output)
            logger.debug('%s: %s (%s) succeeded', self.get_name(), test.name, test.cmdline)
            return TestResults(test.name, [test.name], [test.name], [], [])
        except subprocess.TimeoutExpired as e:
            msg = f'{self.get_name()}: {test.name} ({test.cmdline}) timed out after {e.timeout:.1f} seconds.'
            logger.error(msg)
            logger.debug(
                '%s: %s output when it timed out: %s', self.get_name(), test.name, e.output
            )
            self.persist_output(test, msg, e.output.decode('utf-8'))
            return TestResults(
                test.name,
                [test.name],
                [],
                [],
                [test.name],
            )
        except subprocess.CalledProcessError as e:
            msg = (
                f'{self.get_name()}: {test.name} ({test.cmdline}) failed; exit code {e.returncode}'
            )
            logger.error(msg)
            logger.debug('%s: %s output when it failed: %s', self.get_name(), test.name, e.output)
            self.persist_output(test, msg, e.output.decode('utf-8'))
            return TestResults(
                test.name,
                [test.name],
                [],
                [test.name],
                [],
            )

    @overrides
    def begin(self, params: TestingParameters) -> TestResults:
        logger.debug('Thread %s started.', self.get_name())
        interesting_tests = self.identify_tests()
        logger.debug('%s: Identified %d tests to be run.', self.get_name(), len(interesting_tests))

        # Note: because of @parallelize on run_test it actually
        # returns a SmartFuture with a TestResults inside of it.
        # That's the reason for this Any business.
        running: List[Any] = []
        for test_to_run in interesting_tests:
            running.append(self.run_test(test_to_run))
            logger.debug(
                '%s: Test %s started in the background.', self.get_name(), test_to_run.name
            )
            self.tests_started += 1

        for future in smart_future.wait_any(running):
            self.check_for_abort()
            result = future._resolve()
            logger.debug('Test %s finished.', result.name)
            self.test_results += result

        logger.debug('Thread %s finished.', self.get_name())
        return self.test_results


class UnittestTestRunner(TemplatedTestRunner):
    """Run all known Unittests."""

    @overrides
    def get_name(self) -> str:
        return "Unittests"

    @overrides
    def identify_tests(self) -> List[TestToRun]:
        ret = []
        for test in file_utils.expand_globs('*_test.py'):
            basename = file_utils.without_path(test)
            if config.config['coverage']:
                ret.append(
                    TestToRun(
                        name=basename,
                        kind='unittest capturing coverage',
                        cmdline=f'coverage run --source {HOME}/lib {test} --unittests_ignore_perf 2>&1',
                    )
                )
                if test in PERF_SENSITIVE_TESTS:
                    ret.append(
                        TestToRun(
                            name=basename,
                            kind='unittest w/o coverage to record perf',
                            cmdline=f'{test} 2>&1',
                        )
                    )
            else:
                ret.append(
                    TestToRun(
                        name=basename,
                        kind='unittest',
                        cmdline=f'{test} 2>&1',
                    )
                )
        return ret

    @par.parallelize
    def run_test(self, test: TestToRun) -> TestResults:
        return self.execute_commandline(test)


class DoctestTestRunner(TemplatedTestRunner):
    """Run all known Doctests."""

    @overrides
    def get_name(self) -> str:
        return "Doctests"

    @overrides
    def identify_tests(self) -> List[TestToRun]:
        ret = []
        out = exec_utils.cmd('grep -lR "^ *import doctest" /home/scott/lib/python_modules/*')
        for test in out.split('\n'):
            if re.match(r'.*\.py$', test):
                if 'run_tests.py' not in test:
                    basename = file_utils.without_path(test)
                    if config.config['coverage']:
                        ret.append(
                            TestToRun(
                                name=basename,
                                kind='doctest capturing coverage',
                                cmdline=f'coverage run --source {HOME}/lib {test} 2>&1',
                            )
                        )
                        if test in PERF_SENSITIVE_TESTS:
                            ret.append(
                                TestToRun(
                                    name=basename,
                                    kind='doctest w/o coverage to record perf',
                                    cmdline=f'python3 {test} 2>&1',
                                )
                            )
                    else:
                        ret.append(
                            TestToRun(name=basename, kind='doctest', cmdline=f'python3 {test} 2>&1')
                        )
        return ret

    @par.parallelize
    def run_test(self, test: TestToRun) -> TestResults:
        return self.execute_commandline(test)


class IntegrationTestRunner(TemplatedTestRunner):
    """Run all known Integration tests."""

    @overrides
    def get_name(self) -> str:
        return "Integration Tests"

    @overrides
    def identify_tests(self) -> List[TestToRun]:
        ret = []
        for test in file_utils.expand_globs('*_itest.py'):
            basename = file_utils.without_path(test)
            if config.config['coverage']:
                ret.append(
                    TestToRun(
                        name=basename,
                        kind='integration test capturing coverage',
                        cmdline=f'coverage run --source {HOME}/lib {test} 2>&1',
                    )
                )
                if test in PERF_SENSITIVE_TESTS:
                    ret.append(
                        TestToRun(
                            name=basename,
                            kind='integration test w/o coverage to capture perf',
                            cmdline=f'{test} 2>&1',
                        )
                    )
            else:
                ret.append(
                    TestToRun(name=basename, kind='integration test', cmdline=f'{test} 2>&1')
                )
        return ret

    @par.parallelize
    def run_test(self, test: TestToRun) -> TestResults:
        return self.execute_commandline(test)


def test_results_report(results: Dict[str, TestResults]) -> int:
    """Give a final report about the tests that were run."""
    total_problems = 0
    for result in results.values():
        print(result, end='')
        total_problems += len(result.tests_failed)
        total_problems += len(result.tests_timed_out)

    if total_problems > 0:
        print('Reminder: look in ./test_output to view test output logs')
    return total_problems


def code_coverage_report():
    """Give a final code coverage report."""
    text_utils.header('Code Coverage')
    exec_utils.cmd('coverage combine .coverage*')
    out = exec_utils.cmd('coverage report --omit=config-3.8.py,*_test.py,*_itest.py --sort=-cover')
    print(out)
    print(
        """To recall this report w/o re-running the tests:

    $ coverage report --omit=config-3.8.py,*_test.py,*_itest.py --sort=-cover

...from the 'tests' directory.  Note that subsequent calls to
run_tests.py with --coverage will clobber previous results.  See:

    https://coverage.readthedocs.io/en/6.2/
"""
    )


@bootstrap.initialize
def main() -> Optional[int]:
    saw_flag = False
    halt_event = threading.Event()
    threads: List[TestRunner] = []

    halt_event.clear()
    params = TestingParameters(
        halt_on_error=True,
        halt_event=halt_event,
    )

    if config.config['coverage']:
        logger.debug('Clearing existing coverage data via "coverage erase".')
        exec_utils.cmd('coverage erase')

    if config.config['unittests'] or config.config['all']:
        saw_flag = True
        threads.append(UnittestTestRunner(params))
    if config.config['doctests'] or config.config['all']:
        saw_flag = True
        threads.append(DoctestTestRunner(params))
    if config.config['integration'] or config.config['all']:
        saw_flag = True
        threads.append(IntegrationTestRunner(params))

    if not saw_flag:
        config.print_usage()
        print('ERROR: one of --unittests, --doctests or --integration is required.')
        return 1

    for thread in threads:
        thread.start()

    results: Dict[str, TestResults] = {}
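    # Poll the runner threads: update the progress bar as tests finish and,
    # if any thread reports failures, set the halt event so the others can
    # stop early.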
    while len(results) != len(threads):
        started = 0
        done = 0
        failed = 0

        for thread in threads:
            (s, tr) = thread.get_status()
            started += s
            failed += len(tr.tests_failed) + len(tr.tests_timed_out)
            done += len(tr.tests_failed) + len(tr.tests_timed_out) + len(tr.tests_succeeded)
            if not thread.is_alive():
                tid = thread.name
                if tid not in results:
                    result = thread.join()
                    if result:
                        results[tid] = result
                        if len(result.tests_failed) > 0:
                            logger.error(
                                'Thread %s returned abnormal results; killing the others.', tid
                            )
                            halt_event.set()

        if started > 0:
            percent_done = 100.0 * done / started
        else:
            percent_done = 0.0

        if failed == 0:
            color = ansi.fg('green')
        else:
            color = ansi.fg('red')

        if percent_done < 100.0:
            print(
                text_utils.bar_graph_string(
                    done,
                    started,
                    text=text_utils.BarGraphText.FRACTION,
                    width=80,
                    fgcolor=color,
                ),
                end='\r',
                flush=True,
            )
        time.sleep(0.5)

    print(f'{ansi.clear_line()}Final Report:')
    if config.config['coverage']:
        code_coverage_report()
    total_problems = test_results_report(results)
    return total_problems


if __name__ == '__main__':
    main()