2b2d2389f3f0c5a53fd7d160d6e21d406f6934c6
[python_utils.git] / tests / run_tests.py
1 #!/usr/bin/env python3
2
3 """
4 A smart, fast test runner.
5 """
6
7 import logging
8 import os
9 import re
10 import subprocess
11 import threading
12 import time
13 from abc import ABC, abstractmethod
14 from dataclasses import dataclass
15 from typing import Any, Dict, List, Optional
16
17 from overrides import overrides
18
19 import bootstrap
20 import config
21 import exec_utils
22 import file_utils
23 import parallelize as par
24 import text_utils
25 import thread_utils
26
logger = logging.getLogger(__name__)
args = config.add_commandline_args(f'({__file__})', 'Args related to __file__')
args.add_argument('--unittests', '-u', action='store_true', help='Run unittests.')
args.add_argument('--doctests', '-d', action='store_true', help='Run doctests.')
args.add_argument('--integration', '-i', action='store_true', help='Run integration tests.')
args.add_argument(
    '--coverage', '-c', action='store_true', help='Run tests and capture code coverage data'
)

# The invoking user's home directory; coverage runs pass {HOME}/lib as --source.
# os.path.expanduser('~') returns $HOME when it is set and otherwise falls back to
# the password database, so importing this script no longer raises KeyError in
# environments (cron, some CI sandboxes) that do not export HOME.
HOME = os.path.expanduser('~')
37
38
@dataclass
class TestingParameters:
    """Settings shared by every test-runner thread.

    A single instance is built by main() and handed to each runner.
    """

    # When True, a runner aborts (by raising) as soon as its own aggregated
    # results show an abnormal exit.
    halt_on_error: bool
    # Shared stop signal: main() sets it when any runner reports an abnormal
    # exit, and each runner checks it on every polling pass.
    halt_event: threading.Event
43
44
@dataclass
class TestResults:
    """Outcome of a single test, or the running totals for a whole runner.

    Instances are built positionally in this file, so field order is part of
    the interface: name, executed, succeeded, failed, normal_exit, output.
    """

    # Test identifier: a test file path, or "All <runner> tests" for totals.
    name: str
    # Number of tests run.
    num_tests_executed: int
    # Number of those that passed.
    num_tests_succeeded: int
    # Number of those that failed (non-zero exit or timeout).
    num_tests_failed: int
    # False once anything went wrong; runner totals AND these together.
    normal_exit: bool
    # Captured command output, or a short failure message.
    output: str
53
54
class TestRunner(ABC, thread_utils.ThreadWithReturnValue):
    """Abstract base for a thread that runs one category of tests.

    The thread's target is :meth:`begin`, invoked with the TestingParameters
    given to the constructor; main() collects begin's TestResults through
    ``join()``.  Individual results are folded into ``self.test_results``
    by :meth:`aggregate_test_results`.
    """

    def __init__(self, params: TestingParameters):
        # super() already binds self, so it must not be passed again: an
        # extra positional argument would land in the base class's first
        # parameter (presumably threading.Thread's ``group``).
        super().__init__(target=self.begin, args=[params])
        self.params = params
        # Running totals for this runner.  They start "clean"
        # (normal_exit=True) and are degraded by aggregate_test_results.
        self.test_results = TestResults(
            name=f"All {self.get_name()} tests",
            num_tests_executed=0,
            num_tests_succeeded=0,
            num_tests_failed=0,
            normal_exit=True,
            output="",
        )

    def aggregate_test_results(self, result: TestResults):
        """Add one test's counts and output to this runner's totals.

        normal_exit stays True only while every aggregated result exited
        normally.  Outputs are joined with blank-line separators.
        """
        self.test_results.num_tests_executed += result.num_tests_executed
        self.test_results.num_tests_succeeded += result.num_tests_succeeded
        self.test_results.num_tests_failed += result.num_tests_failed
        self.test_results.normal_exit = self.test_results.normal_exit and result.normal_exit
        self.test_results.output += "\n\n\n" + result.output

    @abstractmethod
    def get_name(self) -> str:
        """Return a short, human-readable name for this runner."""
        pass

    @abstractmethod
    def begin(self, params: TestingParameters) -> TestResults:
        """Thread target: run this runner's tests and return the totals."""
        pass
82
83
class TemplatedTestRunner(TestRunner, ABC):
    """A TestRunner that launches one job per discovered test and polls them.

    Subclasses say how to find tests (:meth:`identify_tests`) and how to run
    one (:meth:`run_test`).  :meth:`begin` starts every test up front, polls
    until all have finished, and folds each result into
    ``self.test_results``.
    """

    @abstractmethod
    def identify_tests(self) -> List[Any]:
        """Return the tests (e.g. file paths) this runner should execute."""
        pass

    @abstractmethod
    def run_test(self, test: Any) -> TestResults:
        """Run one test.

        begin() treats the return value as a future: it calls is_ready() and
        _resolve() on it.  The subclasses in this file get that behavior by
        decorating run_test with @par.parallelize.
        """
        pass

    def check_for_abort(self):
        """Raise to end this runner thread early if it has been told to stop.

        Stops when the shared halt event is set, or when halt_on_error is
        enabled and this runner has already seen an abnormal result.
        """
        if self.params.halt_event.is_set():
            logger.debug('Thread %s saw halt event; exiting.', self.get_name())
            raise Exception("Kill myself!")
        if self.params.halt_on_error and not self.test_results.normal_exit:
            logger.error('Thread %s saw abnormal results; exiting.', self.get_name())
            raise Exception("Kill myself!")

    def status_report(self, running: List[Any], done: List[Any]):
        """Log how many tests are still in flight vs. completed."""
        total = len(running) + len(done)
        # Use the module logger, like the rest of this file, rather than the
        # root logging module.
        logger.info(
            '%s: %d/%d in flight; %d/%d completed.',
            self.get_name(),
            len(running),
            total,
            len(done),
            total,
        )

    @overrides
    def begin(self, params: TestingParameters) -> TestResults:
        """Thread target: run every identified test and return the totals.

        ``params`` is the same object stored as ``self.params`` (it is the
        thread argument set up in TestRunner.__init__); abort checks read it
        through ``self.params``.  May raise via check_for_abort.
        """
        logger.debug('Thread %s started.', self.get_name())
        interesting_tests = self.identify_tests()
        running: List[Any] = [self.run_test(test) for test in interesting_tests]
        done: List[Any] = []

        while running:
            self.status_report(running, done)
            self.check_for_abort()
            # Partition in one pass (instead of repeated list.remove) while
            # keeping the same completion order.
            still_running: List[Any] = []
            for fut in running:
                if fut.is_ready():
                    result = fut._resolve()
                    logger.debug('Test %s finished.', result.name)
                    self.aggregate_test_results(result)
                    done.append(fut)
                else:
                    still_running.append(fut)
            running = still_running
            time.sleep(0.25)

        logger.debug('Thread %s finished.', self.get_name())
        return self.test_results
140
141
class UnittestTestRunner(TemplatedTestRunner):
    """Runs every ``*_test.py`` file (as expanded by file_utils.expand_globs).

    Each file counts as a single test: it passes if the command exits
    cleanly within the timeout.
    """

    # Per-file wall-clock budget handed to exec_utils.cmd.
    TIMEOUT_SECONDS = 120.0

    @overrides
    def get_name(self) -> str:
        return "UnittestTestRunner"

    @overrides
    def identify_tests(self) -> List[Any]:
        return list(file_utils.expand_globs('*_test.py'))

    @par.parallelize
    def run_test(self, test: Any) -> TestResults:
        """Run one unittest file and summarize it as a TestResults.

        On success the command's output is kept.  On failure or timeout the
        output is a short message.
        """
        if config.config['coverage']:
            cmdline = f'coverage run --source {HOME}/lib --append {test} --unittests_ignore_perf'
        else:
            # The file is executed directly, so it presumably needs a shebang
            # and the executable bit — confirm for new tests.
            cmdline = test

        try:
            logger.debug('Running unittest %s (%s)', test, cmdline)
            output = exec_utils.cmd(
                cmdline,
                timeout_seconds=self.TIMEOUT_SECONDS,
            )
        except (TimeoutError, subprocess.TimeoutExpired):
            # If exec_utils.cmd is a thin subprocess wrapper, a timeout arrives
            # as subprocess.TimeoutExpired, which is NOT a TimeoutError
            # subclass.  Catch both so a hung test is reported as a failure
            # instead of escaping run_test.
            logger.error(
                'Unittest %s timed out; ran for > %s seconds', test, self.TIMEOUT_SECONDS
            )
            return TestResults(
                test,
                1,
                0,
                1,
                False,
                f"Unittest {test} timed out.",
            )
        except subprocess.CalledProcessError:
            logger.error('Unittest %s failed.', test)
            return TestResults(
                test,
                1,
                0,
                1,
                False,
                f"Unittest {test} failed.",
            )
        return TestResults(test, 1, 1, 0, True, output)
185
186
class DoctestTestRunner(TemplatedTestRunner):
    """Runs every Python module under {HOME}/lib/python_modules that imports doctest.

    Each module counts as a single test: it passes if running it exits
    cleanly within the timeout.
    """

    # Per-module wall-clock budget handed to exec_utils.cmd.
    TIMEOUT_SECONDS = 120.0

    @overrides
    def get_name(self) -> str:
        return "DoctestTestRunner"

    @overrides
    def identify_tests(self) -> List[Any]:
        """Return .py files containing an ``import doctest`` line.

        run_tests.py itself is excluded.  Returns [] when grep finds no
        matches.
        """
        # Use HOME (as the coverage --source paths do) rather than a
        # hard-coded /home/<user> path.
        try:
            out = exec_utils.cmd(f'grep -lR "^ *import doctest" {HOME}/lib/python_modules/*')
        except subprocess.CalledProcessError as e:
            # grep exits 1 when no lines were selected, which just means there
            # is nothing to run.  Any other status is a real error.
            if e.returncode == 1:
                return []
            raise
        return [
            line
            for line in out.split('\n')
            if re.match(r'.*\.py$', line) and 'run_tests.py' not in line
        ]

    @par.parallelize
    def run_test(self, test: Any) -> TestResults:
        """Run one doctest-bearing module and summarize it as a TestResults."""
        if config.config['coverage']:
            cmdline = f'coverage run --source {HOME}/lib --append {test} 2>&1'
        else:
            cmdline = f'python3 {test}'
        try:
            logger.debug('Running doctest %s (%s).', test, cmdline)
            output = exec_utils.cmd(
                cmdline,
                timeout_seconds=self.TIMEOUT_SECONDS,
            )
        except (TimeoutError, subprocess.TimeoutExpired):
            # A subprocess timeout is subprocess.TimeoutExpired, which is not a
            # TimeoutError subclass.  Catch both so a hung module is reported
            # as a failure instead of escaping run_test.
            logger.error(
                'Doctest %s timed out; ran for > %s seconds', test, self.TIMEOUT_SECONDS
            )
            return TestResults(
                test,
                1,
                0,
                1,
                False,
                f"Doctest {test} timed out.",
            )
        except subprocess.CalledProcessError:
            logger.error('Doctest %s failed.', test)
            return TestResults(
                test,
                1,
                0,
                1,
                False,
                f"Doctest {test} failed.",
            )
        # Successful doctest output is not kept (output is "").
        return TestResults(
            test,
            1,
            1,
            0,
            True,
            "",
        )
242
243
class IntegrationTestRunner(TemplatedTestRunner):
    """Runs every ``*_itest.py`` file (as expanded by file_utils.expand_globs).

    Each file counts as a single test: it passes if the command exits
    cleanly within the (longer) integration timeout.
    """

    # Per-file wall-clock budget handed to exec_utils.cmd.
    TIMEOUT_SECONDS = 240.0

    @overrides
    def get_name(self) -> str:
        return "IntegrationTestRunner"

    @overrides
    def identify_tests(self) -> List[Any]:
        return list(file_utils.expand_globs('*_itest.py'))

    @par.parallelize
    def run_test(self, test: Any) -> TestResults:
        """Run one integration test file and summarize it as a TestResults."""
        if config.config['coverage']:
            cmdline = f'coverage run --source {HOME}/lib --append {test}'
        else:
            # Executed directly; presumably needs a shebang and the
            # executable bit — confirm for new tests.
            cmdline = test
        try:
            logger.debug('Running integration test %s (%s).', test, cmdline)
            output = exec_utils.cmd(
                cmdline,
                timeout_seconds=self.TIMEOUT_SECONDS,
            )
        except (TimeoutError, subprocess.TimeoutExpired):
            # A subprocess timeout is subprocess.TimeoutExpired, which is not a
            # TimeoutError subclass.  Catch both so a hung test is reported
            # as a failure instead of escaping run_test.
            logger.error(
                'Integration Test %s timed out; ran for > %s seconds',
                test,
                self.TIMEOUT_SECONDS,
            )
            return TestResults(
                test,
                1,
                0,
                1,
                False,
                f"Integration Test {test} timed out.",
            )
        except subprocess.CalledProcessError:
            logger.error('Integration Test %s failed.', test)
            return TestResults(
                test,
                1,
                0,
                1,
                False,
                f"Integration Test {test} failed.",
            )
        # Successful integration test output is not kept (output is "").
        return TestResults(
            test,
            1,
            1,
            0,
            True,
            "",
        )
293
294
def test_results_report(results: Dict[str, TestResults]):
    """Print a human-readable summary of each runner's aggregated results.

    Args:
        results: maps a runner thread's name to the aggregated TestResults
            that runner returned.

    Prints one summary line per runner.  For any runner that did not exit
    normally, its accumulated output follows so failures can be diagnosed.
    (Previously the raw dict repr was printed, which flattened all captured
    output onto one line with escaped newlines.)
    """
    if not results:
        print('No test results were collected.')
        return
    for thread_name, result in results.items():
        status = 'OK' if result.normal_exit else 'FAILED'
        print(
            f'{result.name} [{thread_name}]: '
            f'{result.num_tests_executed} executed, '
            f'{result.num_tests_succeeded} succeeded, '
            f'{result.num_tests_failed} failed; {status}'
        )
        details = result.output.strip()
        if not result.normal_exit and details:
            print(details)
297
298
def code_coverage_report():
    """Print the coverage report for this run, then how to regenerate it by hand."""
    # Defined once so the command that is run and the command shown to the
    # user in the recall instructions cannot drift apart.
    report_cmd = 'coverage report --omit=config-3.8.py,*_test.py,*_itest.py --sort=-cover'
    text_utils.header('Code Coverage')
    out = exec_utils.cmd(report_cmd)
    print(out)
    print(
        f"""
To recall this report w/o re-running the tests:

    $ {report_cmd}

...from the 'tests' directory.  Note that subsequent calls to
run_tests.py with --coverage will clobber previous results.  See:

    https://coverage.readthedocs.io/en/6.2/

"""
    )
316
317
@bootstrap.initialize
def main() -> Optional[int]:
    """Start the requested test runners, wait for all of them, and report.

    Returns:
        1 if no test category was requested, if any runner reported an
        abnormal exit, or if any runner thread ended without returning
        results; otherwise 0.
    """
    saw_flag = False
    halt_event = threading.Event()
    threads: List[TestRunner] = []

    params = TestingParameters(
        halt_on_error=True,
        halt_event=halt_event,
    )

    if config.config['coverage']:
        logger.debug('Clearing existing coverage data via "coverage erase".')
        exec_utils.cmd('coverage erase')

    if config.config['unittests']:
        saw_flag = True
        threads.append(UnittestTestRunner(params))
    if config.config['doctests']:
        saw_flag = True
        threads.append(DoctestTestRunner(params))
    if config.config['integration']:
        saw_flag = True
        threads.append(IntegrationTestRunner(params))

    if not saw_flag:
        config.print_usage()
        print('ERROR: one of --unittests, --doctests or --integration is required.')
        return 1

    for thread in threads:
        thread.start()

    results: Dict[str, TestResults] = {}
    saw_failure = False
    # Track unfinished threads explicitly.  Counting entries in `results`
    # never terminates if a runner dies without returning a TestResults
    # (e.g. after check_for_abort raises), because join() then gives None.
    pending = list(threads)
    while pending:
        for thread in list(pending):
            if thread.is_alive():
                continue
            pending.remove(thread)
            result = thread.join()
            if result is None:
                logger.error('Thread %s exited without returning test results.', thread.name)
                saw_failure = True
                halt_event.set()
                continue
            results[thread.name] = result
            if not result.normal_exit:
                saw_failure = True
                # Ask the remaining runners to stop early.
                halt_event.set()
        if pending:
            time.sleep(1.0)

    test_results_report(results)
    if config.config['coverage']:
        code_coverage_report()
    # Non-zero exit status when anything failed, so callers (CI, shells) notice.
    return 1 if saw_failure else 0
369
370
if __name__ == '__main__':
    import sys

    # main() returns an Optional[int] exit status (e.g. 1 on a usage error).
    # If bootstrap.initialize already exits with it, this sys.exit is never
    # reached.  Otherwise it propagates the status to the shell instead of
    # discarding it.
    sys.exit(main())