3 # © Copyright 2021-2022, Scott Gasch
5 """This is a module for wrapping around python programs and doing some
6 minor setup and tear down work for them. With it, you will get:
8 * the ability to break into pdb on unhandled exceptions,
9 * automatic support for :file:`config.py` (argument parsing),
10 * automatic logging support for :file:`logging.py`,
11 * the ability to enable code profiling,
12 * the ability to enable module import auditing,
13 * optional memory profiling for your program,
14 * ability to set random seed via commandline,
15 * automatic program timing and reporting,
16 * more verbose error handling and reporting,
18 Most of these are enabled and/or configured via commandline flags
28 from inspect import stack
32 from argparse_utils import ActionNoYes
34 # This module is imported by many of the other modules here, so it
35 # should avoid taking any unnecessary dependencies back on them.
38 logger = logging.getLogger(__name__)
# Register bootstrap's commandline flags with config under the title
# 'Bootstrap (<this file>)'; their parsed values are read later via
# config.config['<flag name>'].
40 cfg = config.add_commandline_args(
41 f'Bootstrap ({__file__})',
42 'Args related to python program bootstrapper and Swiss army knife',
45 '--debug_unhandled_exceptions',
48 help='Break into pdb on top level unhandled exceptions.',
54 help='Should we display (and log.debug) the global random seed?',
62 help='Override the global random seed with a particular number.',
68 help='Should we dump the Python import tree before main?',
71 '--audit_import_events',
74 help='Should we audit all import events?',
80 help='Should we run cProfile on this code?',
86 help='Should we record/report on memory utilization?',
# The excepthook in effect when this module was first imported;
# handle_uncaught_exception() defers to it.
89 ORIGINAL_EXCEPTION_HOOK = sys.excepthook
# Replacement for sys.excepthook (initialize() installs it when the default
# hook is still in place): reports the exception on stderr and in the log,
# then defers to ORIGINAL_EXCEPTION_HOOK or, per its docstring, may attach
# a debugger when --debug_unhandled_exceptions is set on a terminal.
92 def handle_uncaught_exception(exc_type, exc_value, exc_tb):
94 Top-level exception handler for exceptions that make it past any exception
95 handlers in the python code being run. Logs the error and stacktrace then
96 maybe attaches a debugger.
99 msg = f'Unhandled top level exception {exc_type}'
# NOTE(review): an excepthook is not called from inside an `except`
# block, so logger.exception() (which reads sys.exc_info()) may have no
# active exception to attach. Passing the triple explicitly, e.g.
# logger.error(msg, exc_info=(exc_type, exc_value, exc_tb)), looks
# safer -- TODO confirm.
100 logger.exception(msg)
101 print(msg, file=sys.stderr)
102 if issubclass(exc_type, KeyboardInterrupt):
# KeyboardInterrupt (e.g. Ctrl-C): hand straight to the built-in default hook.
103 sys.__excepthook__(exc_type, exc_value, exc_tb)
109 tb_output = io.StringIO()
# print_tb() renders only the stack frames (no "ExcType: message" line);
# the same text goes to stderr and to the log.
110 traceback.print_tb(exc_tb, None, tb_output)
111 print(tb_output.getvalue(), file=sys.stderr)
112 logger.error(tb_output.getvalue())
115 # stdin or stderr is redirected (not a TTY): no debugger, defer to the original hook.
116 if not sys.stderr.isatty() or not sys.stdin.isatty():
117 ORIGINAL_EXCEPTION_HOOK(exc_type, exc_value, exc_tb)
119 else: # both stdin and stderr are terminals; maybe debug.
120 if config.config['debug_unhandled_exceptions']:
121 logger.info("Invoking the debugger...")
126 ORIGINAL_EXCEPTION_HOOK(exc_type, exc_value, exc_tb)
# Import auditor: when '--audit_import_events' is on the commandline an
# instance is put at the front of sys.meta_path (see IMPORT_INTERCEPTOR).
129 class ImportInterceptor(importlib.abc.MetaPathFinder):
130 """An interceptor that always allows module load events but dumps a
131 record into the log and onto stdout when modules are loaded and
132 produces an audit of who imported what at the end of the run. It
133 can't see any load events that happen before it, though, so move
134 bootstrap up in your __main__'s import list just temporarily to
142 self.module_by_filename_cache = {}
# Prime the __file__ -> module cache from the modules already loaded.
143 self.repopulate_modules_by_filename()
# Trie of import paths; printed after main when --audit_import_events is set.
144 self.tree = collect.trie.Trie()
# Import path (a list) recorded per key by find_spec(); read back by
# find_importer().
145 self.tree_node_by_module = {}
# Rebuild module_by_filename_cache (a module's __file__ -> the module
# object) from a snapshot of sys.modules.
147 def repopulate_modules_by_filename(self):
148 self.module_by_filename_cache.clear()
149 for _, mod in sys.modules.copy().items(): # copy: sys.modules may change while we iterate
150 if hasattr(mod, '__file__'):
151 fname = getattr(mod, '__file__')
154 self.module_by_filename_cache[fname] = mod
def should_ignore_filename(filename: str) -> bool:
    """Tell whether a stack frame's file is import plumbing, not a real importer.

    Returns True when *filename* contains ``importlib`` or ``six.py``.
    """
    noisy_markers = ('importlib', 'six.py')
    return any(marker in filename for marker in noisy_markers)
def find_module(self, fullname, path):
    """Legacy (pre-PEP 451) finder hook; always refuses.

    The import system prefers ``find_spec()`` when a finder defines it,
    so this only fires for old-style callers.

    Raises:
        NotImplementedError: always. It subclasses ``Exception``, so
            callers that caught the previous generic ``Exception`` keep
            working.
    """
    raise NotImplementedError(
        "This method has been deprecated since Python 3.4, please upgrade."
    )
# MetaPathFinder hook: `loaded_module` is the name being imported. Walks
# the caller stack to find which file/function/module is importing it,
# records the import path in self.tree and builds an '*** Import ...'
# message (per the class docstring, sent to the log and stdout).
# NOTE(review): `s` is assigned on a line not visible here, presumably
# inspect.stack(); starting at index 3 presumably skips the import
# machinery's own frames -- TODO confirm.
163 def find_spec(self, loaded_module, path=None, _=None):
165 for x in range(3, len(s)):
166 filename = s[x].filename
167 if ImportInterceptor.should_ignore_filename(filename):
170 loading_function = s[x].function
171 if filename in self.module_by_filename_cache:
172 loading_module = self.module_by_filename_cache[filename]
174 self.repopulate_modules_by_filename()
175 loading_module = self.module_by_filename_cache.get(filename, 'unknown')
# NOTE(review): get(..., []) returns the list already stored for
# loading_module, so extend() mutates that stored path in place, and the
# result is saved back under loading_module rather than loaded_module;
# every import made by the same module keeps growing one list. Something
# like `self.tree_node_by_module.get(loading_module, []) + [loaded_module]`
# stored under the intended key looks like the fix -- TODO confirm (keys
# must also agree with find_importer()). This also shadows the `path`
# parameter.
177 path = self.tree_node_by_module.get(loading_module, [])
178 path.extend([loaded_module])
179 self.tree.insert(path)
180 self.tree_node_by_module[loading_module] = path
182 msg = f'*** Import {loaded_module} from {filename}:{s[x].lineno} in {loading_module}::{loading_function}'
186 msg = f'*** Import {loaded_module} from ?????'
# MetaPathFinder hook invoked by importlib.invalidate_caches(); its body is
# not visible here (presumably there is nothing to invalidate).
190 def invalidate_caches(self):
# Look up the import path recorded for `module` in tree_node_by_module
# (the rest of the body is not visible here). NOTE(review): find_spec()
# stores entries keyed by the loading module object (or 'unknown'), so a
# module-name string presumably rarely matches -- TODO confirm.
193 def find_importer(self, module: str):
194 if module in self.tree_node_by_module:
195 node = self.tree_node_by_module[module]
200 # Audit import events? Note: this runs early in the lifetime of the
201 # process (assuming bootstrap is imported early), before config has
202 # (probably) parsed the commandline, so the flag is matched literally
203 # here. Also, some modules have probably already been imported while
204 # we weren't watching, so the audit record may be incomplete.
206 # Also note: move bootstrap up in the global import list to catch
207 # more import events and have a more complete record.
208 IMPORT_INTERCEPTOR = None
210 if arg == '--audit_import_events':
211 IMPORT_INTERCEPTOR = ImportInterceptor()
# Placed first so it is consulted before the finders already installed.
212 sys.meta_path.insert(0, IMPORT_INTERCEPTOR)
# NOTE(review): object.__subclasses__() yields only the *direct*
# subclasses of object, so classes further down any hierarchy are not
# reported despite this function's name -- TODO confirm whether a
# recursive walk over __subclasses__() was intended.
215 def dump_all_objects() -> None:
216 """Log (at DEBUG) direct subclasses of object with their module, file and, if known, importer."""
219 all_modules = sys.modules
220 for obj in object.__subclasses__():
221 if not hasattr(obj, '__name__'):
224 if not hasattr(obj, '__module__'):
226 class_mod_name = obj.__module__
227 if class_mod_name in all_modules:
228 mod = all_modules[class_mod_name]
229 if not hasattr(mod, '__name__'):
230 mod_name = class_mod_name
232 mod_name = mod.__name__
233 if hasattr(mod, '__file__'):
234 mod_file = mod.__file__
# An importer is only known when --audit_import_events installed
# IMPORT_INTERCEPTOR at module import time.
237 if IMPORT_INTERCEPTOR is not None:
238 import_path = IMPORT_INTERCEPTOR.find_importer(mod_name)
240 import_path = 'unknown'
241 msg = f'{class_mod_name}::{klass} ({mod_file})'
242 if import_path != 'unknown' and len(import_path) > 0:
243 msg += f' imported by {import_path}'
244 messages[f'{class_mod_name}::{klass}'] = msg
# Emit sorted by 'module::class' key so the dump is stable and diffable.
245 for x in sorted(messages.keys()):
246 logger.debug(messages[x])
250 def initialize(entry_point):
252 Remember to initialize config, initialize logging, set/log a random
253 seed, etc... before running main. If you use this decorator around
254 your main, like this::
258 @bootstrap.initialize
262 if __name__ == '__main__':
267 * the ability to break into pdb on unhandled exceptions,
268 * automatic support for :file:`config.py` (argument parsing),
269 * automatic logging support for :file:`logging.py`,
270 * the ability to enable code profiling,
271 * the ability to enable module import auditing,
272 * optional memory profiling for your program,
273 * ability to set random seed via commandline,
274 * automatic program timing and reporting,
275 * more verbose error handling and reporting,
277 Most of these are enabled and/or configured via commandline flags
281 @functools.wraps(entry_point)
282 def initialize_wrapper(*args, **kwargs):
283 # Hook top level unhandled exceptions (only if no one else replaced the default hook); maybe invoke debugger.
284 if sys.excepthook == sys.__excepthook__:
285 sys.excepthook = handle_uncaught_exception
287 # Try to figure out the name of the program entry point. Then
288 # parse configuration (based on cmdline flags, environment vars
290 entry_filename = None
293 entry_filename = entry_point.__code__.co_filename
294 entry_descr = entry_point.__code__.__repr__()
# NOTE(review): a function's __globals__ is a read-only attribute, not a
# key in its __dict__, so for ordinary functions the membership test
# below is always False; hasattr(entry_point, '__globals__') looks like
# the intended check. The surrounding control flow is not fully visible
# here -- TODO confirm before changing.
296 if '__globals__' in entry_point.__dict__ and '__file__' in entry_point.__globals__:
297 entry_filename = entry_point.__globals__['__file__']
298 entry_descr = entry_filename
299 config.parse(entry_filename)
301 if config.config['trace_memory']:
306 # Initialize logging... and log some remembered messages from
308 logging_utils.initialize_logging(logging.getLogger())
309 config.late_logging()
311 # Maybe log some info about the python interpreter itself.
313 'Platform: %s, maxint=0x%x, byteorder=%s', sys.platform, sys.maxsize, sys.byteorder
315 logger.debug('Python interpreter version: %s', sys.version)
316 logger.debug('Python implementation: %s', sys.implementation)
317 logger.debug('Python C API version: %s', sys.api_version)
319 logger.debug('Python interpreter running in __debug__ mode.')
321 logger.debug('Python interpreter running in optimized mode.')
322 logger.debug('Python path: %s', sys.path)
324 # Log something about the site_config, many things use it.
327 logger.debug('Global site_config: %s', site_config.get_config())
329 # Allow programs that don't bother to override the random seed
330 # to be replayed via the commandline.
333 random_seed = config.config['set_random_seed']
334 if random_seed is not None:
# Presumably the flag is declared with nargs=1, so the value is a
# one-element list -- TODO confirm.
335 random_seed = random_seed[0]
# No seed requested: draw a fresh 32-bit one from os.urandom so this run
# can be replayed later with --set_random_seed.
337 random_seed = int.from_bytes(os.urandom(4), 'little')
339 if config.config['show_random_seed']:
340 msg = f'Global random seed is: {random_seed}'
343 random.seed(random_seed)
345 # Do it, invoke the user's code. Pay attention to how long it takes.
346 logger.debug('Starting %s (program entry point)', entry_descr)
350 if config.config['run_profiler']:
352 from pstats import SortKey
354 with stopwatch.Timer() as t:
356 "ret = entry_point(*args, **kwargs)",
363 with stopwatch.Timer() as t:
364 ret = entry_point(*args, **kwargs)
366 logger.debug('%s (program entry point) returned %s.', entry_descr, ret)
368 if config.config['trace_memory']:
369 snapshot = tracemalloc.take_snapshot()
370 top_stats = snapshot.statistics('lineno')
372 print("--trace_memory's top 10 memory using files:")
373 for stat in top_stats[:10]:
376 if config.config['dump_all_objects']:
379 if config.config['audit_import_events']:
380 if IMPORT_INTERCEPTOR is not None:
381 print(IMPORT_INTERCEPTOR.tree)
# NOTE(review): os.times().elapsed is real time since an arbitrary fixed
# point in the past and may not equal machine uptime on every platform;
# confirm before reporting it as 'machine uptime' below.
384 (utime, stime, cutime, cstime, elapsed_time) = os.times()
389 'child user: %.4fs\n'
390 'child system: %.4fs\n'
391 'machine uptime: %.4fs\n'
401 # If it doesn't return cleanly, call attention to the return value.
402 if ret is not None and ret != 0:
403 logger.error('Exit %s', ret)
405 logger.debug('Exit %s', ret)
408 return initialize_wrapper