f0fa15fb95319626552f025619ba17d76b5ec88d
[python_utils.git] / bootstrap.py
1 #!/usr/bin/env python3
2
3 # © Copyright 2021-2022, Scott Gasch
4
5 """This is a module for wrapping around python programs and doing some
6 minor setup and tear down work for them.  With it, you will get:
7
8 * The ability to break into pdb on unhandled exceptions,
9 * automatic support for :file:`config.py` (argument parsing)
10 * automatic logging support for :file:`logging.py`,
11 * the ability to enable code profiling,
12 * the ability to enable module import auditing,
13 * optional memory profiling for your program,
14 * ability to set random seed via commandline,
15 * automatic program timing and reporting,
16 * more verbose error handling and reporting,
17
18 Most of these are enabled and/or configured via commandline flags
19 (see below).
20
21 """
22
23 import functools
24 import importlib
25 import logging
26 import os
27 import sys
28 from inspect import stack
29
30 import config
31 import logging_utils
32 from argparse_utils import ActionNoYes
33
34 # This module is commonly used by others in here and should avoid
35 # taking any unnecessary dependencies back on them.
36
37
38 logger = logging.getLogger(__name__)
39
40 cfg = config.add_commandline_args(
41     f'Bootstrap ({__file__})',
42     'Args related to python program bootstrapper and Swiss army knife',
43 )
44 cfg.add_argument(
45     '--debug_unhandled_exceptions',
46     action=ActionNoYes,
47     default=False,
48     help='Break into pdb on top level unhandled exceptions.',
49 )
50 cfg.add_argument(
51     '--show_random_seed',
52     action=ActionNoYes,
53     default=False,
54     help='Should we display (and log.debug) the global random seed?',
55 )
56 cfg.add_argument(
57     '--set_random_seed',
58     type=int,
59     nargs=1,
60     default=None,
61     metavar='SEED_INT',
62     help='Override the global random seed with a particular number.',
63 )
64 cfg.add_argument(
65     '--dump_all_objects',
66     action=ActionNoYes,
67     default=False,
68     help='Should we dump the Python import tree before main?',
69 )
70 cfg.add_argument(
71     '--audit_import_events',
72     action=ActionNoYes,
73     default=False,
74     help='Should we audit all import events?',
75 )
76 cfg.add_argument(
77     '--run_profiler',
78     action=ActionNoYes,
79     default=False,
80     help='Should we run cProfile on this code?',
81 )
82 cfg.add_argument(
83     '--trace_memory',
84     action=ActionNoYes,
85     default=False,
86     help='Should we record/report on memory utilization?',
87 )
88
89 ORIGINAL_EXCEPTION_HOOK = sys.excepthook
90
91
92 def handle_uncaught_exception(exc_type, exc_value, exc_tb):
93     """
94     Top-level exception handler for exceptions that make it past any exception
95     handlers in the python code being run.  Logs the error and stacktrace then
96     maybe attaches a debugger.
97
98     """
99     msg = f'Unhandled top level exception {exc_type}'
100     logger.exception(msg)
101     print(msg, file=sys.stderr)
102     if issubclass(exc_type, KeyboardInterrupt):
103         sys.__excepthook__(exc_type, exc_value, exc_tb)
104         return
105     else:
106         import io
107         import traceback
108
109         tb_output = io.StringIO()
110         traceback.print_tb(exc_tb, None, tb_output)
111         print(tb_output.getvalue(), file=sys.stderr)
112         logger.error(tb_output.getvalue())
113         tb_output.close()
114
115         # stdin or stderr is redirected, just do the normal thing
116         if not sys.stderr.isatty() or not sys.stdin.isatty():
117             ORIGINAL_EXCEPTION_HOOK(exc_type, exc_value, exc_tb)
118
119         else:  # a terminal is attached and stderr isn't redirected, maybe debug.
120             if config.config['debug_unhandled_exceptions']:
121                 logger.info("Invoking the debugger...")
122                 import pdb
123
124                 pdb.pm()
125             else:
126                 ORIGINAL_EXCEPTION_HOOK(exc_type, exc_value, exc_tb)
127
128
129 class ImportInterceptor(importlib.abc.MetaPathFinder):
130     """An interceptor that always allows module load events but dumps a
131     record into the log and onto stdout when modules are loaded and
132     produces an audit of who imported what at the end of the run.  It
133     can't see any load events that happen before it, though, so move
134     bootstrap up in your __main__'s import list just temporarily to
135     get a good view.
136
137     """
138
139     def __init__(self):
140         import collect.trie
141
142         self.module_by_filename_cache = {}
143         self.repopulate_modules_by_filename()
144         self.tree = collect.trie.Trie()
145         self.tree_node_by_module = {}
146
147     def repopulate_modules_by_filename(self):
148         self.module_by_filename_cache.clear()
149         for _, mod in sys.modules.copy().items():  # copy here because modules is volatile
150             if hasattr(mod, '__file__'):
151                 fname = getattr(mod, '__file__')
152             else:
153                 fname = 'unknown'
154             self.module_by_filename_cache[fname] = mod
155
156     @staticmethod
157     def should_ignore_filename(filename: str) -> bool:
158         return 'importlib' in filename or 'six.py' in filename
159
160     def find_module(self, fullname, path):
161         raise Exception("This method has been deprecated since Python 3.4, please upgrade.")
162
163     def find_spec(self, loaded_module, path=None, _=None):
164         s = stack()
165         for x in range(3, len(s)):
166             filename = s[x].filename
167             if ImportInterceptor.should_ignore_filename(filename):
168                 continue
169
170             loading_function = s[x].function
171             if filename in self.module_by_filename_cache:
172                 loading_module = self.module_by_filename_cache[filename]
173             else:
174                 self.repopulate_modules_by_filename()
175                 loading_module = self.module_by_filename_cache.get(filename, 'unknown')
176
177             path = self.tree_node_by_module.get(loading_module, [])
178             path.extend([loaded_module])
179             self.tree.insert(path)
180             self.tree_node_by_module[loading_module] = path
181
182             msg = f'*** Import {loaded_module} from {filename}:{s[x].lineno} in {loading_module}::{loading_function}'
183             logger.debug(msg)
184             print(msg)
185             return
186         msg = f'*** Import {loaded_module} from ?????'
187         logger.debug(msg)
188         print(msg)
189
190     def invalidate_caches(self):
191         pass
192
193     def find_importer(self, module: str):
194         if module in self.tree_node_by_module:
195             node = self.tree_node_by_module[module]
196             return node
197         return []
198
199
200 # Audit import events?  Note: this runs early in the lifetime of the
201 # process (assuming that import bootstrap happens early); config has
202 # (probably) not yet been loaded or parsed the commandline.  Also,
203 # some things have probably already been imported while we weren't
204 # watching so this information may be incomplete.
205 #
206 # Also note: move bootstrap up in the global import list to catch
207 # more import events and have a more complete record.
208 IMPORT_INTERCEPTOR = None
209 for arg in sys.argv:
210     if arg == '--audit_import_events':
211         IMPORT_INTERCEPTOR = ImportInterceptor()
212         sys.meta_path.insert(0, IMPORT_INTERCEPTOR)
213
214
215 def dump_all_objects() -> None:
216     """Helper code to dump all known python objects."""
217
218     messages = {}
219     all_modules = sys.modules
220     for obj in object.__subclasses__():
221         if not hasattr(obj, '__name__'):
222             continue
223         klass = obj.__name__
224         if not hasattr(obj, '__module__'):
225             continue
226         class_mod_name = obj.__module__
227         if class_mod_name in all_modules:
228             mod = all_modules[class_mod_name]
229             if not hasattr(mod, '__name__'):
230                 mod_name = class_mod_name
231             else:
232                 mod_name = mod.__name__
233             if hasattr(mod, '__file__'):
234                 mod_file = mod.__file__
235             else:
236                 mod_file = 'unknown'
237             if IMPORT_INTERCEPTOR is not None:
238                 import_path = IMPORT_INTERCEPTOR.find_importer(mod_name)
239             else:
240                 import_path = 'unknown'
241             msg = f'{class_mod_name}::{klass} ({mod_file})'
242             if import_path != 'unknown' and len(import_path) > 0:
243                 msg += f' imported by {import_path}'
244             messages[f'{class_mod_name}::{klass}'] = msg
245     for x in sorted(messages.keys()):
246         logger.debug(messages[x])
247         print(messages[x])
248
249
250 def initialize(entry_point):
251     """
252     Remember to initialize config, initialize logging, set/log a random
253     seed, etc... before running main.  If you use this decorator around
254     your main, like this::
255
256         import bootstrap
257
258         @bootstrap.initialize
259         def main():
260             whatever
261
262         if __name__ == '__main__':
263             main()
264
265     You get:
266
267     * The ability to break into pdb on unhandled exceptions,
268     * automatic support for :file:`config.py` (argument parsing)
269     * automatic logging support for :file:`logging.py`,
270     * the ability to enable code profiling,
271     * the ability to enable module import auditing,
272     * optional memory profiling for your program,
273     * ability to set random seed via commandline,
274     * automatic program timing and reporting,
275     * more verbose error handling and reporting,
276
277     Most of these are enabled and/or configured via commandline flags
278     (see below).
279     """
280
281     @functools.wraps(entry_point)
282     def initialize_wrapper(*args, **kwargs):
283         # Hook top level unhandled exceptions, maybe invoke debugger.
284         if sys.excepthook == sys.__excepthook__:
285             sys.excepthook = handle_uncaught_exception
286
287         # Try to figure out the name of the program entry point.  Then
288         # parse configuration (based on cmdline flags, environment vars
289         # etc...)
290         entry_filename = None
291         entry_descr = None
292         try:
293             entry_filename = entry_point.__code__.co_filename
294             entry_descr = entry_point.__code__.__repr__()
295         except Exception:
296             if '__globals__' in entry_point.__dict__ and '__file__' in entry_point.__globals__:
297                 entry_filename = entry_point.__globals__['__file__']
298                 entry_descr = entry_filename
299         config.parse(entry_filename)
300
301         if config.config['trace_memory']:
302             import tracemalloc
303
304             tracemalloc.start()
305
306         # Initialize logging... and log some remembered messages from
307         # config module.
308         logging_utils.initialize_logging(logging.getLogger())
309         config.late_logging()
310
311         # Maybe log some info about the python interpreter itself.
312         logger.debug(
313             'Platform: %s, maxint=0x%x, byteorder=%s', sys.platform, sys.maxsize, sys.byteorder
314         )
315         logger.debug('Python interpreter version: %s', sys.version)
316         logger.debug('Python implementation: %s', sys.implementation)
317         logger.debug('Python C API version: %s', sys.api_version)
318         if __debug__:
319             logger.debug('Python interpreter running in __debug__ mode.')
320         else:
321             logger.debug('Python interpreter running in optimized mode.')
322         logger.debug('Python path: %s', sys.path)
323
324         # Log something about the site_config, many things use it.
325         import site_config
326
327         logger.debug('Global site_config: %s', site_config.get_config())
328
329         # Allow programs that don't bother to override the random seed
330         # to be replayed via the commandline.
331         import random
332
333         random_seed = config.config['set_random_seed']
334         if random_seed is not None:
335             random_seed = random_seed[0]
336         else:
337             random_seed = int.from_bytes(os.urandom(4), 'little')
338
339         if config.config['show_random_seed']:
340             msg = f'Global random seed is: {random_seed}'
341             logger.debug(msg)
342             print(msg)
343         random.seed(random_seed)
344
345         # Do it, invoke the user's code.  Pay attention to how long it takes.
346         logger.debug('Starting %s (program entry point)', entry_descr)
347         ret = None
348         import stopwatch
349
350         if config.config['run_profiler']:
351             import cProfile
352             from pstats import SortKey
353
354             with stopwatch.Timer() as t:
355                 cProfile.runctx(
356                     "ret = entry_point(*args, **kwargs)",
357                     globals(),
358                     locals(),
359                     None,
360                     SortKey.CUMULATIVE,
361                 )
362         else:
363             with stopwatch.Timer() as t:
364                 ret = entry_point(*args, **kwargs)
365
366         logger.debug('%s (program entry point) returned %s.', entry_descr, ret)
367
368         if config.config['trace_memory']:
369             snapshot = tracemalloc.take_snapshot()
370             top_stats = snapshot.statistics('lineno')
371             print()
372             print("--trace_memory's top 10 memory using files:")
373             for stat in top_stats[:10]:
374                 print(stat)
375
376         if config.config['dump_all_objects']:
377             dump_all_objects()
378
379         if config.config['audit_import_events']:
380             if IMPORT_INTERCEPTOR is not None:
381                 print(IMPORT_INTERCEPTOR.tree)
382
383         walltime = t()
384         (utime, stime, cutime, cstime, elapsed_time) = os.times()
385         logger.debug(
386             '\n'
387             'user: %.4fs\n'
388             'system: %.4fs\n'
389             'child user: %.4fs\n'
390             'child system: %.4fs\n'
391             'machine uptime: %.4fs\n'
392             'walltime: %.4fs',
393             utime,
394             stime,
395             cutime,
396             cstime,
397             elapsed_time,
398             walltime,
399         )
400
401         # If it doesn't return cleanly, call attention to the return value.
402         if ret is not None and ret != 0:
403             logger.error('Exit %s', ret)
404         else:
405             logger.debug('Exit %s', ret)
406         sys.exit(ret)
407
408     return initialize_wrapper