3 # © Copyright 2021-2022, Scott Gasch
5 """This is a module for wrapping around python programs and doing some
6 minor setup and tear down work for them. With it, you will get:
8 * The ability to break into pdb on unhandled exceptions,
9 * automatic support for :file:`config.py` (argument parsing),
10 * automatic logging support for :file:`logging.py`,
11 * the ability to enable code profiling,
12 * the ability to enable module import auditing,
13 * optional memory profiling for your program,
14 * ability to set random seed via commandline,
15 * automatic program timing and reporting,
16 * more verbose error handling and reporting,
18 Most of these are enabled and/or configured via commandline flags
29 from inspect import stack
33 from argparse_utils import ActionNoYes
35 # This module is commonly used by others in here and should avoid
36 # taking any unnecessary dependencies back on them.
39 logger = logging.getLogger(__name__)
41 cfg = config.add_commandline_args(
42 f'Bootstrap ({__file__})',
43 'Args related to python program bootstrapper and Swiss army knife',
46 '--debug_unhandled_exceptions',
49 help='Break into pdb on top level unhandled exceptions.',
55 help='Should we display (and log.debug) the global random seed?',
63 help='Override the global random seed with a particular number.',
69 help='Should we dump the Python import tree before main?',
72 '--audit_import_events',
75 help='Should we audit all import events?',
81 help='Should we run cProfile on this code?',
87 help='Should we record/report on memory utilization?',
90 ORIGINAL_EXCEPTION_HOOK = sys.excepthook
93 def handle_uncaught_exception(exc_type, exc_value, exc_tb):
95 Top-level exception handler for exceptions that make it past any exception
96 handlers in the python code being run. Logs the error and stacktrace then
97 maybe attaches a debugger.
100 msg = f'Unhandled top level exception {exc_type}'
101 logger.exception(msg)
102 print(msg, file=sys.stderr)
103 if issubclass(exc_type, KeyboardInterrupt):
104 sys.__excepthook__(exc_type, exc_value, exc_tb)
110 tb_output = io.StringIO()
111 traceback.print_tb(exc_tb, None, tb_output)
112 print(tb_output.getvalue(), file=sys.stderr)
113 logger.error(tb_output.getvalue())
116 # stdin or stderr is redirected, just do the normal thing
117 if not sys.stderr.isatty() or not sys.stdin.isatty():
118 ORIGINAL_EXCEPTION_HOOK(exc_type, exc_value, exc_tb)
120 else: # a terminal is attached and stderr isn't redirected, maybe debug.
121 if config.config['debug_unhandled_exceptions']:
122 logger.info("Invoking the debugger...")
127 ORIGINAL_EXCEPTION_HOOK(exc_type, exc_value, exc_tb)
130 class ImportInterceptor(importlib.abc.MetaPathFinder):
131 """An interceptor that always allows module load events but dumps a
132 record into the log and onto stdout when modules are loaded and
133 produces an audit of who imported what at the end of the run. It
134 can't see any load events that happen before it, though, so move
135 bootstrap up in your __main__'s import list just temporarily to
143 self.module_by_filename_cache = {}
144 self.repopulate_modules_by_filename()
145 self.tree = collect.trie.Trie()
146 self.tree_node_by_module = {}
148 def repopulate_modules_by_filename(self):
149 self.module_by_filename_cache.clear()
150 for _, mod in sys.modules.copy().items(): # copy here because modules is volatile
151 if hasattr(mod, '__file__'):
152 fname = getattr(mod, '__file__')
155 self.module_by_filename_cache[fname] = mod
def should_ignore_filename(filename: str) -> bool:
    """Decide whether a filename should be skipped by the import audit.

    find_spec calls this on the filenames it examines.  A filename is
    ignored when it contains either ``importlib`` or ``six.py`` anywhere
    in it (plain substring match).

    Args:
        filename: the filename to check.

    Returns:
        True if the filename contains one of the ignored fragments,
        False otherwise.
    """
    ignored_fragments = ('importlib', 'six.py')
    return any(fragment in filename for fragment in ignored_fragments)
def find_module(self, fullname, path):
    """Legacy finder entry point; unconditionally raises.

    Args:
        fullname: name of the module being looked up (unused).
        path: search path for the lookup (unused).

    Raises:
        Exception: always, with a message saying this method has been
            deprecated since Python 3.4.
    """
    message = "This method has been deprecated since Python 3.4, please upgrade."
    raise Exception(message)
164 def find_spec(self, loaded_module, path=None, _=None):
166 for x in range(3, len(s)):
167 filename = s[x].filename
168 if ImportInterceptor.should_ignore_filename(filename):
171 loading_function = s[x].function
172 if filename in self.module_by_filename_cache:
173 loading_module = self.module_by_filename_cache[filename]
175 self.repopulate_modules_by_filename()
176 loading_module = self.module_by_filename_cache.get(filename, 'unknown')
178 path = self.tree_node_by_module.get(loading_module, [])
179 path.extend([loaded_module])
180 self.tree.insert(path)
181 self.tree_node_by_module[loading_module] = path
183 msg = f'*** Import {loaded_module} from {filename}:{s[x].lineno} in {loading_module}::{loading_function}'
187 msg = f'*** Import {loaded_module} from ?????'
191 def invalidate_caches(self):
194 def find_importer(self, module: str):
195 if module in self.tree_node_by_module:
196 node = self.tree_node_by_module[module]
201 # Audit import events? Note: this runs early in the lifetime of the
202 # process (assuming that import bootstrap happens early); config has
203 # (probably) not yet been loaded, nor has it parsed the commandline. Also,
204 # some things have probably already been imported while we weren't
205 # watching so this information may be incomplete.
207 # Also note: move bootstrap up in the global import list to catch
208 # more import events and have a more complete record.
209 IMPORT_INTERCEPTOR = None
211 if arg == '--audit_import_events':
212 IMPORT_INTERCEPTOR = ImportInterceptor()
213 sys.meta_path.insert(0, IMPORT_INTERCEPTOR)
216 def dump_all_objects() -> None:
217 """Helper code to dump all known python objects."""
220 all_modules = sys.modules
221 for obj in object.__subclasses__():
222 if not hasattr(obj, '__name__'):
225 if not hasattr(obj, '__module__'):
227 class_mod_name = obj.__module__
228 if class_mod_name in all_modules:
229 mod = all_modules[class_mod_name]
230 if not hasattr(mod, '__name__'):
231 mod_name = class_mod_name
233 mod_name = mod.__name__
234 if hasattr(mod, '__file__'):
235 mod_file = mod.__file__
238 if IMPORT_INTERCEPTOR is not None:
239 import_path = IMPORT_INTERCEPTOR.find_importer(mod_name)
241 import_path = 'unknown'
242 msg = f'{class_mod_name}::{klass} ({mod_file})'
243 if import_path != 'unknown' and len(import_path) > 0:
244 msg += f' imported by {import_path}'
245 messages[f'{class_mod_name}::{klass}'] = msg
246 for x in sorted(messages.keys()):
247 logger.debug(messages[x])
251 def initialize(entry_point):
253 Remember to initialize config, initialize logging, set/log a random
254 seed, etc... before running main. If you use this decorator around
255 your main, like this::
259 @bootstrap.initialize
263 if __name__ == '__main__':
268 * The ability to break into pdb on unhandled exceptions,
269 * automatic support for :file:`config.py` (argument parsing),
270 * automatic logging support for :file:`logging.py`,
271 * the ability to enable code profiling,
272 * the ability to enable module import auditing,
273 * optional memory profiling for your program,
274 * ability to set random seed via commandline,
275 * automatic program timing and reporting,
276 * more verbose error handling and reporting,
278 Most of these are enabled and/or configured via commandline flags
282 @functools.wraps(entry_point)
283 def initialize_wrapper(*args, **kwargs):
284 # Hook top level unhandled exceptions, maybe invoke debugger.
285 if sys.excepthook == sys.__excepthook__:
286 sys.excepthook = handle_uncaught_exception
288 # Try to figure out the name of the program entry point. Then
289 # parse configuration (based on cmdline flags, environment vars
291 entry_filename = None
294 entry_filename = entry_point.__code__.co_filename
295 entry_descr = entry_point.__code__.__repr__()
297 if '__globals__' in entry_point.__dict__ and '__file__' in entry_point.__globals__:
298 entry_filename = entry_point.__globals__['__file__']
299 entry_descr = entry_filename
300 config.parse(entry_filename)
302 if config.config['trace_memory']:
307 # Initialize logging... and log some remembered messages from
309 logging_utils.initialize_logging(logging.getLogger())
310 config.late_logging()
312 # Maybe log some info about the python interpreter itself.
314 'Platform: %s, maxint=0x%x, byteorder=%s', sys.platform, sys.maxsize, sys.byteorder
316 logger.debug('Python interpreter version: %s', sys.version)
317 logger.debug('Python implementation: %s', sys.implementation)
318 logger.debug('Python C API version: %s', sys.api_version)
320 logger.debug('Python interpreter running in __debug__ mode.')
322 logger.debug('Python interpreter running in optimized mode.')
323 logger.debug('Python path: %s', sys.path)
325 # Log something about the site_config, many things use it.
328 logger.debug('Global site_config: %s', site_config.get_config())
330 # Allow programs that don't bother to override the random seed
331 # to be replayed via the commandline.
334 random_seed = config.config['set_random_seed']
335 if random_seed is not None:
336 random_seed = random_seed[0]
338 random_seed = int.from_bytes(os.urandom(4), 'little')
340 if config.config['show_random_seed']:
341 msg = f'Global random seed is: {random_seed}'
344 random.seed(random_seed)
346 # Do it, invoke the user's code. Pay attention to how long it takes.
347 logger.debug('Starting %s (program entry point)', entry_descr)
351 if config.config['run_profiler']:
353 from pstats import SortKey
355 with stopwatch.Timer() as t:
357 "ret = entry_point(*args, **kwargs)",
364 with stopwatch.Timer() as t:
365 ret = entry_point(*args, **kwargs)
367 logger.debug('%s (program entry point) returned %s.', entry_descr, ret)
369 if config.config['trace_memory']:
370 snapshot = tracemalloc.take_snapshot()
371 top_stats = snapshot.statistics('lineno')
373 print("--trace_memory's top 10 memory using files:")
374 for stat in top_stats[:10]:
377 if config.config['dump_all_objects']:
380 if config.config['audit_import_events']:
381 if IMPORT_INTERCEPTOR is not None:
382 print(IMPORT_INTERCEPTOR.tree)
385 (utime, stime, cutime, cstime, elapsed_time) = os.times()
390 'child user: %.4fs\n'
391 'child system: %.4fs\n'
392 'machine uptime: %.4fs\n'
402 # If it doesn't return cleanly, call attention to the return value.
403 if ret is not None and ret != 0:
404 logger.error('Exit %s', ret)
406 logger.debug('Exit %s', ret)
409 return initialize_wrapper