#!/usr/bin/env python3

"""
A smart, fast test runner.  Used in a git pre-commit hook.
"""

import logging
import os
import re
import subprocess
import threading
import time
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple

from overrides import overrides

import ansi
import bootstrap
import config
import exec_utils
import file_utils
import parallelize as par
import smart_future
import text_utils
import thread_utils

logger = logging.getLogger(__name__)
args = config.add_commandline_args(f'({__file__})', f'Args related to {__file__}')
args.add_argument('--unittests', '-u', action='store_true', help='Run unittests.')
args.add_argument('--doctests', '-d', action='store_true', help='Run doctests.')
args.add_argument('--integration', '-i', action='store_true', help='Run integration tests.')
args.add_argument(
    '--all',
    '-a',
    action='store_true',
    help='Run unittests, doctests and integration tests.  Equivalent to -u -d -i.',
)
args.add_argument(
    '--coverage', '-c', action='store_true', help='Run tests and capture code coverage data.'
)
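
# Example invocations (illustrative):
#
#   ./run_tests.py --all                      # unittests + doctests + integration tests
#   ./run_tests.py --unittests --coverage     # unittests, capturing code coverage data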

HOME = os.environ['HOME']

# These tests will be run twice in --coverage mode: once to capture code
# coverage and then again with coverage disabled.  This is because they
# pay attention to code performance, which is adversely affected by
# coverage instrumentation.
PERF_SENSITIVE_TESTS = set(['/home/scott/lib/python_modules/tests/string_utils_test.py'])
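# Note: membership in PERF_SENSITIVE_TESTS is checked against the exact path
# strings produced by test discovery (file_utils.expand_globs / grep below),
# so entries here must match those paths verbatim.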


@dataclass
class TestingParameters:
    halt_on_error: bool
    """Should we stop as soon as one error has occurred?"""

    halt_event: threading.Event
    """An event that, when set, indicates to stop ASAP."""


@dataclass
class TestToRun:
    name: str
    """The name of the test"""

    kind: str
    """The kind of the test"""

    cmdline: str
    """The command line to execute"""


@dataclass
class TestResults:
    name: str
    """The name of this test / set of tests."""

    tests_executed: List[str]
    """Tests that were executed."""

    tests_succeeded: List[str]
    """Tests that succeeded."""

    tests_failed: List[str]
    """Tests that failed."""

    tests_timed_out: List[str]
    """Tests that timed out."""

    def __add__(self, other):
        self.tests_executed.extend(other.tests_executed)
        self.tests_succeeded.extend(other.tests_succeeded)
        self.tests_failed.extend(other.tests_failed)
        self.tests_timed_out.extend(other.tests_timed_out)
        return self

    __radd__ = __add__

    def __repr__(self) -> str:
        out = f'{self.name}: '
        out += f'{ansi.fg("green")}'
        out += f'{len(self.tests_succeeded)}/{len(self.tests_executed)} passed'
        out += f'{ansi.reset()}.\n'

        if len(self.tests_failed) > 0:
            out += f'  ..{ansi.fg("red")}'
            out += f'{len(self.tests_failed)} tests failed'
            out += f'{ansi.reset()}:\n'
            for test in self.tests_failed:
                out += f'    {test}\n'
            out += '\n'

        if len(self.tests_timed_out) > 0:
            out += f'  ..{ansi.fg("yellow")}'
            out += f'{len(self.tests_timed_out)} tests timed out'
            out += f'{ansi.reset()}:\n'
            for test in self.tests_timed_out:
                out += f'    {test}\n'
            out += '\n'
        return out

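# Aggregation via '+' / '+=' above mutates and returns the left-hand
# TestResults, e.g. (illustrative):
#
#   totals = TestResults('All Tests', [], [], [], [])
#   totals += unittest_results     # extends totals' lists in place
#
# This is how begin() below rolls per-test results into a per-runner total.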

class TestRunner(ABC, thread_utils.ThreadWithReturnValue):
    """A base class for something that runs a test."""

    def __init__(self, params: TestingParameters):
        """Create a TestRunner.

        Args:
            params: Test running parameters.

        """
        super().__init__(target=self.begin, args=[params])
        self.params = params
        self.test_results = TestResults(
            name=self.get_name(),
            tests_executed=[],
            tests_succeeded=[],
            tests_failed=[],
            tests_timed_out=[],
        )
        self.tests_started = 0

    @abstractmethod
    def get_name(self) -> str:
        """The name of this test collection."""
        pass

    def get_status(self) -> Tuple[int, TestResults]:
        """Ask the TestRunner for its status."""
        return (self.tests_started, self.test_results)

    @abstractmethod
    def begin(self, params: TestingParameters) -> TestResults:
        """Start execution."""
        pass

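# Note: a TestRunner is itself a thread (via thread_utils.ThreadWithReturnValue):
# main() calls start(), polls get_status() every ~500ms to drive the progress
# bar, and finally join()s the thread to collect its TestResults.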

class TemplatedTestRunner(TestRunner, ABC):
    """A TestRunner that has a recipe for executing the tests."""

    @abstractmethod
    def identify_tests(self) -> List[TestToRun]:
        """Return a list of TestToRun objects that should be executed."""
        pass

    @abstractmethod
    def run_test(self, test: TestToRun) -> TestResults:
        """Run a single test and return its TestResults."""
        pass

    def check_for_abort(self):
        """Periodically called to check whether we need to stop."""

        if self.params.halt_event.is_set():
            logger.debug('Thread %s saw halt event; exiting.', self.get_name())
            raise Exception("Kill myself!")
        if self.params.halt_on_error:
            if len(self.test_results.tests_failed) > 0:
                logger.error('Thread %s saw abnormal results; exiting.', self.get_name())
                raise Exception("Kill myself!")

    def persist_output(self, test: TestToRun, message: str, output: str) -> None:
        """Save the output of a test run under ./test_output/ (which must already exist)."""

        dest = f'{test.name}-output.txt'
        with open(f'./test_output/{dest}', 'w') as wf:
            print(message, file=wf)
            print('-' * len(message), file=wf)
            wf.write(output)

    def execute_commandline(
        self,
        test: TestToRun,
        *,
        timeout: float = 120.0,
    ) -> TestResults:
        """Execute a particular commandline to run a test."""

        try:
            output = exec_utils.cmd(
                test.cmdline,
                timeout_seconds=timeout,
            )
            self.persist_output(test, f'{test.name} ({test.cmdline}) succeeded.', output)
            logger.debug('%s: %s (%s) succeeded', self.get_name(), test.name, test.cmdline)
            # Success: count the test as executed and succeeded.
            return TestResults(test.name, [test.name], [test.name], [], [])
        except subprocess.TimeoutExpired as e:
            msg = f'{self.get_name()}: {test.name} ({test.cmdline}) timed out after {e.timeout:.1f} seconds.'
            logger.error(msg)
            logger.debug(
                '%s: %s output when it timed out: %s', self.get_name(), test.name, e.output
            )
            self.persist_output(test, msg, e.output.decode('utf-8'))
            # Timeout: count the test as executed and timed out.
            return TestResults(
                test.name,
                [test.name],
                [],
                [],
                [test.name],
            )
        except subprocess.CalledProcessError as e:
            msg = (
                f'{self.get_name()}: {test.name} ({test.cmdline}) failed; exit code {e.returncode}'
            )
            logger.error(msg)
            logger.debug('%s: %s output when it failed: %s', self.get_name(), test.name, e.output)
            self.persist_output(test, msg, e.output.decode('utf-8'))
            # Failure: count the test as executed and failed.
            return TestResults(
                test.name,
                [test.name],
                [],
                [test.name],
                [],
            )

    @overrides
    def begin(self, params: TestingParameters) -> TestResults:
        logger.debug('Thread %s started.', self.get_name())
        interesting_tests = self.identify_tests()
        logger.debug('%s: Identified %d tests to be run.', self.get_name(), len(interesting_tests))

        # Note: because of @parallelize on run_test it actually
        # returns a SmartFuture with a TestResults inside of it.
        # That's the reason for this Any business.
        running: List[Any] = []
        for test_to_run in interesting_tests:
            running.append(self.run_test(test_to_run))
            logger.debug(
                '%s: Test %s started in the background.', self.get_name(), test_to_run.name
            )
            self.tests_started += 1

        for future in smart_future.wait_any(running):
            self.check_for_abort()
            result = future._resolve()
            logger.debug('Test %s finished.', result.name)
            self.test_results += result

        logger.debug('Thread %s finished.', self.get_name())
        return self.test_results


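# The concrete runners below only supply a name, a test-discovery method and a
# (parallelized) run_test(); TemplatedTestRunner.begin() above handles
# scheduling, abort checks and result aggregation.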
class UnittestTestRunner(TemplatedTestRunner):
    """Run all known Unittests."""

    @overrides
    def get_name(self) -> str:
        return "Unittests"

    @overrides
    def identify_tests(self) -> List[TestToRun]:
        ret = []
        for test in file_utils.expand_globs('*_test.py'):
            basename = file_utils.without_path(test)
            if config.config['coverage']:
                ret.append(
                    TestToRun(
                        name=basename,
                        kind='unittest capturing coverage',
                        cmdline=f'coverage run --source {HOME}/lib {test} --unittests_ignore_perf 2>&1',
                    )
                )
                if test in PERF_SENSITIVE_TESTS:
                    ret.append(
                        TestToRun(
                            name=basename,
                            kind='unittest w/o coverage to record perf',
                            cmdline=f'{test} 2>&1',
                        )
                    )
            else:
                ret.append(
                    TestToRun(
                        name=basename,
                        kind='unittest',
                        cmdline=f'{test} 2>&1',
                    )
                )
        return ret

    @par.parallelize
    def run_test(self, test: TestToRun) -> TestResults:
        return self.execute_commandline(test)


class DoctestTestRunner(TemplatedTestRunner):
    """Run all known Doctests."""

    @overrides
    def get_name(self) -> str:
        return "Doctests"

    @overrides
    def identify_tests(self) -> List[TestToRun]:
        ret = []
        out = exec_utils.cmd('grep -lR "^ *import doctest" /home/scott/lib/python_modules/*')
        for test in out.split('\n'):
            if re.match(r'.*\.py$', test):
                if 'run_tests.py' not in test:
                    basename = file_utils.without_path(test)
                    if config.config['coverage']:
                        ret.append(
                            TestToRun(
                                name=basename,
                                kind='doctest capturing coverage',
                                cmdline=f'coverage run --source {HOME}/lib {test} 2>&1',
                            )
                        )
                        if test in PERF_SENSITIVE_TESTS:
                            ret.append(
                                TestToRun(
                                    name=basename,
                                    kind='doctest w/o coverage to record perf',
                                    cmdline=f'python3 {test} 2>&1',
                                )
                            )
                    else:
                        ret.append(
                            TestToRun(name=basename, kind='doctest', cmdline=f'python3 {test} 2>&1')
                        )
        return ret

    @par.parallelize
    def run_test(self, test: TestToRun) -> TestResults:
        return self.execute_commandline(test)


class IntegrationTestRunner(TemplatedTestRunner):
    """Run all known Integration tests."""

    @overrides
    def get_name(self) -> str:
        return "Integration Tests"

    @overrides
    def identify_tests(self) -> List[TestToRun]:
        ret = []
        for test in file_utils.expand_globs('*_itest.py'):
            basename = file_utils.without_path(test)
            if config.config['coverage']:
                ret.append(
                    TestToRun(
                        name=basename,
                        kind='integration test capturing coverage',
                        cmdline=f'coverage run --source {HOME}/lib {test} 2>&1',
                    )
                )
                if test in PERF_SENSITIVE_TESTS:
                    ret.append(
                        TestToRun(
                            name=basename,
                            kind='integration test w/o coverage to capture perf',
                            cmdline=f'{test} 2>&1',
                        )
                    )
            else:
                ret.append(
                    TestToRun(name=basename, kind='integration test', cmdline=f'{test} 2>&1')
                )
        return ret

    @par.parallelize
    def run_test(self, test: TestToRun) -> TestResults:
        return self.execute_commandline(test)


def test_results_report(results: Dict[str, TestResults]) -> int:
    """Give a final report about the tests that were run."""
    total_problems = 0
    for result in results.values():
        print(result, end='')
        total_problems += len(result.tests_failed)
        total_problems += len(result.tests_timed_out)

    if total_problems > 0:
        print('Reminder: look in ./test_output to view test output logs')
    return total_problems


def code_coverage_report():
    """Give a final code coverage report."""
    text_utils.header('Code Coverage')
    exec_utils.cmd('coverage combine .coverage*')
    out = exec_utils.cmd('coverage report --omit=config-3.8.py,*_test.py,*_itest.py --sort=-cover')
    print(out)
    print(
        """
To recall this report w/o re-running the tests:

    $ coverage report --omit=config-3.8.py,*_test.py,*_itest.py --sort=-cover

...from the 'tests' directory.  Note that subsequent calls to
run_tests.py with --coverage will clobber previous results.  See:

    https://coverage.readthedocs.io/en/6.2/
"""
    )


@bootstrap.initialize
def main() -> Optional[int]:
    saw_flag = False
    halt_event = threading.Event()
    threads: List[TestRunner] = []

    halt_event.clear()
    params = TestingParameters(
        halt_on_error=True,
        halt_event=halt_event,
    )

    if config.config['coverage']:
        logger.debug('Clearing existing coverage data via "coverage erase".')
        exec_utils.cmd('coverage erase')

    if config.config['unittests'] or config.config['all']:
        saw_flag = True
        threads.append(UnittestTestRunner(params))
    if config.config['doctests'] or config.config['all']:
        saw_flag = True
        threads.append(DoctestTestRunner(params))
    if config.config['integration'] or config.config['all']:
        saw_flag = True
        threads.append(IntegrationTestRunner(params))

    if not saw_flag:
        config.print_usage()
        print('ERROR: one of --unittests, --doctests or --integration is required.')
        return 1

    for thread in threads:
        thread.start()

    results: Dict[str, TestResults] = {}
    while len(results) != len(threads):
        started = 0
        done = 0
        failed = 0

        for thread in threads:
            (s, tr) = thread.get_status()
            started += s
            failed += len(tr.tests_failed) + len(tr.tests_timed_out)
            done += len(tr.tests_failed) + len(tr.tests_timed_out) + len(tr.tests_succeeded)
            if not thread.is_alive():
                tid = thread.name
                if tid not in results:
                    result = thread.join()
                    if result:
                        results[tid] = result
                        if len(result.tests_failed) > 0:
                            logger.error(
                                'Thread %s returned abnormal results; killing the others.', tid
                            )
                            halt_event.set()

        if started > 0:
            percent_done = 100.0 * done / started
        else:
            percent_done = 0.0

        if failed == 0:
            color = ansi.fg('green')
        else:
            color = ansi.fg('red')

        if percent_done < 100.0:
            print(
                text_utils.bar_graph_string(
                    done,
                    started,
                    text=text_utils.BarGraphText.FRACTION,
                    width=80,
                    fgcolor=color,
                ),
                end='\r',
                flush=True,
            )
        time.sleep(0.5)

    print(f'{ansi.clear_line()}Final Report:')
    if config.config['coverage']:
        code_coverage_report()
    total_problems = test_results_report(results)
    return total_problems


if __name__ == '__main__':
    main()