7e7bad593da9f8b453626ef018393b9436548833
[python_utils.git] / tests / run_tests.py
1 #!/usr/bin/env python3
2
3 """
4 A smart, fast test runner.
5 """
6
7 import logging
8 import os
9 import re
10 import subprocess
11 import threading
12 import time
13 from abc import ABC, abstractmethod
14 from dataclasses import dataclass
15 from typing import Any, Dict, List, Optional
16
17 from overrides import overrides
18
19 import bootstrap
20 import config
21 import exec_utils
22 import file_utils
23 import parallelize as par
24 import text_utils
25 import thread_utils
26
logger = logging.getLogger(__name__)
# The description was meant to interpolate the file name (it lacked the f prefix).
args = config.add_commandline_args(f'({__file__})', f'Args related to {__file__}')
args.add_argument('--unittests', '-u', action='store_true', help='Run unittests.')
args.add_argument('--doctests', '-d', action='store_true', help='Run doctests.')
args.add_argument('--integration', '-i', action='store_true', help='Run integration tests.')
args.add_argument(
    '--coverage', '-c', action='store_true', help='Run tests and capture code coverage data'
)

# The user's home directory; used below to locate $HOME/lib (coverage --source
# and doctest discovery).  Fall back to expanduser() rather than crashing with
# KeyError when HOME is not set in the environment.
HOME = os.environ.get('HOME', os.path.expanduser('~'))
37
38
@dataclass()
class TestingParameters:
    """Settings shared by every test-runner thread.

    A single instance is built by main() and handed to each runner.
    """

    # When True, a runner stops once one of its own tests has failed.
    halt_on_error: bool
    # Shared flag that main() sets to ask every runner to stop early.
    halt_event: threading.Event
43
44
@dataclass()
class TestResults:
    """Outcome lists for one test or for an aggregate of many tests.

    Every list holds test names; a test appears in ``tests_executed`` plus
    at most one of the succeeded / failed / timed-out lists.
    """

    name: str  # label for this result set, printed as the report header
    tests_executed: List[str]  # every test that was run
    tests_succeeded: List[str]  # tests whose command exited cleanly
    tests_failed: List[str]  # tests whose command returned an error status
    tests_timed_out: List[str]  # tests killed for exceeding their timeout
51     tests_timed_out: List[str]
52
53
class TestRunner(ABC, thread_utils.ThreadWithReturnValue):
    """Base class for a thread that runs one category of tests.

    The thread's target is :meth:`begin`; main() treats the value returned
    by ``join()`` as this runner's TestResults.  Per-test results are
    accumulated into ``self.test_results`` via :meth:`aggregate_test_results`.
    """

    def __init__(self, params: TestingParameters):
        # super() already binds self; passing `self` explicitly (as this
        # used to) shifts it into the next positional parameter of the
        # thread constructor (``group`` for threading.Thread).
        super().__init__(target=self.begin, args=[params])
        self.params = params
        self.test_results = TestResults(
            name=f"All {self.get_name()} tests",
            tests_executed=[],
            tests_succeeded=[],
            tests_failed=[],
            tests_timed_out=[],
        )

    def aggregate_test_results(self, result: TestResults):
        """Append every list in ``result`` onto this runner's running totals."""
        self.test_results.tests_executed.extend(result.tests_executed)
        self.test_results.tests_succeeded.extend(result.tests_succeeded)
        self.test_results.tests_failed.extend(result.tests_failed)
        self.test_results.tests_timed_out.extend(result.tests_timed_out)

    @abstractmethod
    def get_name(self) -> str:
        """Return a human-readable name for this runner (used in logs/reports)."""
        pass

    @abstractmethod
    def begin(self, params: TestingParameters) -> TestResults:
        """Thread entry point: run this runner's tests and return the results."""
        pass
79
80
class TemplatedTestRunner(TestRunner, ABC):
    """A TestRunner that discovers tests, launches them all, then polls.

    Subclasses provide :meth:`identify_tests` (what to run) and
    :meth:`run_test` (how to start one test); :meth:`begin` drives the
    launch/poll loop and aggregates each finished test's TestResults.
    """

    @abstractmethod
    def identify_tests(self) -> List[Any]:
        """Return the list of tests this runner should execute."""
        pass

    @abstractmethod
    def run_test(self, test: Any) -> TestResults:
        """Start running ``test``.

        :meth:`begin` calls ``is_ready()`` and ``_resolve()`` on the value
        returned here, so implementations must return a future-like object
        that resolves to a TestResults (the concrete runners in this file
        get that by decorating with ``@par.parallelize``).
        """
        pass

    def check_for_abort(self) -> bool:
        """Return True if this runner should stop waiting on its tests.

        That is the case when the shared halt event has been set, or when
        ``halt_on_error`` is enabled and one of this runner's tests failed.
        """
        if self.params.halt_event.is_set():
            logger.debug('Thread %s saw halt event; exiting.', self.get_name())
            return True
        if self.params.halt_on_error and len(self.test_results.tests_failed) > 0:
            logger.error('Thread %s saw abnormal results; exiting.', self.get_name())
            return True
        return False

    def status_report(self, running: List[Any], done: List[Any]):
        """Log how many of this runner's tests are in flight vs. completed."""
        total = len(running) + len(done)
        # Use the module logger, not the root `logging` module, like the
        # rest of this file.
        logger.info(
            '%s: %d/%d in flight; %d/%d completed.',
            self.get_name(),
            len(running),
            total,
            len(done),
            total,
        )

    def persist_output(self, test_name: str, message: str, output: Any) -> None:
        """Save ``message`` and a test's captured ``output`` to disk.

        The file is ./test_output/<basename of test_name>-output.txt; the
        directory is created if needed.  ``output`` may be ``str``,
        ``bytes`` (what subprocess exceptions carry in ``.output``) or
        ``None`` (nothing captured).
        """
        if output is None:
            output = ''
        elif isinstance(output, bytes):
            # subprocess.TimeoutExpired.output is bytes; a text-mode file
            # cannot accept bytes, so decode it leniently.
            output = output.decode('utf-8', errors='replace')
        os.makedirs('./test_output', exist_ok=True)
        basename = file_utils.without_path(test_name)
        dest = f'{basename}-output.txt'
        with open(f'./test_output/{dest}', 'w', encoding='utf-8') as wf:
            print(message, file=wf)
            print('-' * len(message), file=wf)
            wf.write(output)

    def execute_commandline(
        self,
        test_name: str,
        cmdline: str,
        *,
        timeout: float = 120.0,
    ) -> TestResults:
        """Run ``cmdline`` for ``test_name`` and classify the outcome.

        Returns a single-test TestResults recording the test as succeeded,
        timed out (subprocess.TimeoutExpired) or failed
        (subprocess.CalledProcessError).  In every case the captured output
        is written via :meth:`persist_output`.  Other exceptions propagate.
        """
        try:
            logger.debug('%s: Running %s (%s)', self.get_name(), test_name, cmdline)
            output = exec_utils.cmd(
                cmdline,
                timeout_seconds=timeout,
            )
            self.persist_output(test_name, f'{test_name} ({cmdline}) succeeded.', output)
            logger.debug('%s (%s) succeeded', test_name, cmdline)
            return TestResults(test_name, [test_name], [test_name], [], [])
        except subprocess.TimeoutExpired as e:
            msg = f'{self.get_name()}: {test_name} ({cmdline}) timed out after {e.timeout:.1f} seconds.'
            logger.error(msg)
            logger.debug(
                '%s: %s output when it timed out: %s', self.get_name(), test_name, e.output
            )
            self.persist_output(test_name, msg, e.output)
            return TestResults(
                test_name,
                [test_name],
                [],
                [],
                [test_name],
            )
        except subprocess.CalledProcessError as e:
            msg = f'{self.get_name()}: {test_name} ({cmdline}) failed; exit code {e.returncode}'
            logger.error(msg)
            logger.debug('%s: %s output when it failed: %s', self.get_name(), test_name, e.output)
            self.persist_output(test_name, msg, e.output)
            return TestResults(
                test_name,
                [test_name],
                [],
                [test_name],
                [],
            )

    @overrides
    def begin(self, params: TestingParameters) -> TestResults:
        """Launch every identified test and poll until all finish.

        Stops early (returning the partial results gathered so far) when
        :meth:`check_for_abort` says so.  Returning instead of raising
        matters: main() only records a thread whose ``join()`` yields a
        result, so a thread that died by exception was never counted.
        """
        logger.debug('Thread %s started.', self.get_name())
        interesting_tests = self.identify_tests()
        running: List[Any] = [self.run_test(test) for test in interesting_tests]
        done: List[Any] = []

        while running:
            self.status_report(running, done)
            if self.check_for_abort():
                logger.debug(
                    '%s: stopping early with %d test(s) still in flight.',
                    self.get_name(),
                    len(running),
                )
                return self.test_results

            still_running: List[Any] = []
            for fut in running:
                if fut.is_ready():
                    result = fut._resolve()
                    logger.debug('Test %s finished.', result.name)
                    self.aggregate_test_results(result)
                    done.append(fut)
                else:
                    still_running.append(fut)
            running = still_running
            time.sleep(0.25)

        logger.debug('Thread %s finished.', self.get_name())
        return self.test_results
189
190
class UnittestTestRunner(TemplatedTestRunner):
    """Runs each file matching the glob ``*_test.py`` as a unit test."""

    @overrides
    def get_name(self) -> str:
        return "UnittestTestRunner"

    @overrides
    def identify_tests(self) -> List[Any]:
        # Materialize whatever expand_globs yields into a concrete list.
        matches = file_utils.expand_globs('*_test.py')
        return list(matches)

    @par.parallelize
    def run_test(self, test: Any) -> TestResults:
        # Without coverage the test file is executed directly; with coverage
        # it is wrapped and asked to skip its perf checks.
        use_coverage = config.config['coverage']
        cmdline = (
            f'coverage run --source {HOME}/lib {test} --unittests_ignore_perf'
            if use_coverage
            else test
        )
        return self.execute_commandline(test, cmdline)
207
208
class DoctestTestRunner(TemplatedTestRunner):
    """Runs each library module that contains an ``import doctest`` line."""

    @overrides
    def get_name(self) -> str:
        return "DoctestTestRunner"

    @overrides
    def identify_tests(self) -> List[Any]:
        """Return .py files under $HOME/lib/python_modules that import doctest.

        This runner script itself (``run_tests.py``) is excluded.
        """
        # Use HOME (as the coverage --source paths do) rather than a
        # hard-coded user directory, so discovery works for any account.
        try:
            out = exec_utils.cmd(f'grep -lR "^ *import doctest" {HOME}/lib/python_modules/*')
        except subprocess.CalledProcessError as e:
            # grep exits with status 1 when nothing matched: no doctests.
            if e.returncode == 1:
                return []
            raise
        return [
            line
            for line in out.split('\n')
            if line.endswith('.py') and 'run_tests.py' not in line
        ]

    @par.parallelize
    def run_test(self, test: Any) -> TestResults:
        if config.config['coverage']:
            cmdline = f'coverage run --source {HOME}/lib {test} 2>&1'
        else:
            cmdline = f'python3 {test}'
        return self.execute_commandline(test, cmdline)
231
232
class IntegrationTestRunner(TemplatedTestRunner):
    """Runs each file matching the glob ``*_itest.py`` as an integration test."""

    @overrides
    def get_name(self) -> str:
        return "IntegrationTestRunner"

    @overrides
    def identify_tests(self) -> List[Any]:
        # Materialize whatever expand_globs yields into a concrete list.
        matches = file_utils.expand_globs('*_itest.py')
        return list(matches)

    @par.parallelize
    def run_test(self, test: Any) -> TestResults:
        # Execute the test file directly unless coverage collection is on.
        cmdline = (
            f'coverage run --source {HOME}/lib {test}'
            if config.config['coverage']
            else test
        )
        return self.execute_commandline(test, cmdline)
249
250
def test_results_report(results: Dict[str, TestResults]):
    """Print a summary of executed, passed, failed and timed-out tests.

    Args:
        results: one TestResults per runner, keyed by thread name (the keys
            themselves are not printed; each result's ``name`` is).
    """
    for result in results.values():
        print(text_utils.header(f'{result.name}'))
        print(f'  Ran {len(result.tests_executed)} tests.')
        print(f'  ..{len(result.tests_succeeded)} tests succeeded.')
        if len(result.tests_failed) > 0:
            print(f'  ..{len(result.tests_failed)} tests failed:')
            for test in result.tests_failed:
                print(f'    {test}')

        if len(result.tests_timed_out) > 0:
            print(f'  ..{len(result.tests_timed_out)} tests timed out:')
            # List the timed-out tests (this previously re-listed failures).
            for test in result.tests_timed_out:
                print(f'    {test}')

        if result.tests_failed or result.tests_timed_out:
            print('Reminder: look in ./test_output to view test output logs')
268
269
def code_coverage_report():
    """Combine the collected coverage data files and print a coverage report.

    Runs ``coverage combine`` over ``.coverage*``, then ``coverage report``
    (omitting the test files themselves, sorted by coverage), prints the
    report, and prints the command needed to re-display it later.
    """
    # One definition of the report command, shared by the run and the
    # "recall this report" hint so the two cannot drift apart.
    report_cmd = 'coverage report --omit=config-3.8.py,*_test.py,*_itest.py --sort=-cover'

    # header() returns the text (test_results_report prints it); it was
    # previously called here without printing, so no header appeared.
    print(text_utils.header('Code Coverage'))
    exec_utils.cmd('coverage combine .coverage*')
    out = exec_utils.cmd(report_cmd)
    print(out)
    print(
        f"""
To recall this report w/o re-running the tests:

    $ {report_cmd}

...from the 'tests' directory.  Note that subsequent calls to
run_tests.py with --coverage will clobber previous results.  See:

    https://coverage.readthedocs.io/en/6.2/

"""
    )
288
289
@bootstrap.initialize
def main() -> Optional[int]:
    """Run the selected test suites in parallel threads and report on them.

    Returns:
        The exit status: 1 if no suite was selected, if any test failed or
        timed out, or if a runner thread exited without producing results;
        0 otherwise.
    """
    halt_event = threading.Event()
    threads: List[TestRunner] = []
    params = TestingParameters(
        halt_on_error=True,
        halt_event=halt_event,
    )

    if config.config['coverage']:
        logger.debug('Clearing existing coverage data via "coverage erase".')
        exec_utils.cmd('coverage erase')

    if config.config['unittests']:
        threads.append(UnittestTestRunner(params))
    if config.config['doctests']:
        threads.append(DoctestTestRunner(params))
    if config.config['integration']:
        threads.append(IntegrationTestRunner(params))

    if not threads:
        config.print_usage()
        print('ERROR: one of --unittests, --doctests or --integration is required.')
        return 1

    for thread in threads:
        thread.start()

    results: Dict[str, TestResults] = {}
    # Threads whose join() produced no result.  They must still count as
    # finished; otherwise the loop below would never terminate.
    crashed: List[str] = []
    while len(results) + len(crashed) != len(threads):
        for thread in threads:
            tid = thread.name
            if thread.is_alive() or tid in results or tid in crashed:
                continue
            result = thread.join()
            if result:
                results[tid] = result
                if len(result.tests_failed) > 0:
                    logger.error(
                        'Thread %s returned abnormal results; killing the others.', tid
                    )
                    halt_event.set()
            else:
                logger.error(
                    'Thread %s exited without returning results; killing the others.', tid
                )
                crashed.append(tid)
                halt_event.set()
        if len(results) + len(crashed) != len(threads):
            time.sleep(1.0)

    test_results_report(results)
    if config.config['coverage']:
        code_coverage_report()

    # Reflect test outcomes in the exit status so callers (e.g. CI) notice.
    problems = sum(
        len(result.tests_failed) + len(result.tests_timed_out) for result in results.values()
    )
    return 1 if problems > 0 or crashed else 0


if __name__ == '__main__':
    main()