# © Copyright 2021-2022, Scott Gasch
"""This is a module for wrapping around python programs and doing some
minor setup and tear down work for them.  With it, you can break into
pdb on unhandled top level exceptions, profile your code by passing a
command-line argument in, audit module import events, examine where
memory is being used in your program, and so on.
"""
import functools
import importlib.abc
import logging
import os
import sys
from inspect import stack

from argparse_utils import ActionNoYes
# Many other modules here depend on this one, so it should avoid
# taking any unnecessary dependencies back on them (circular imports).
logger = logging.getLogger(__name__)

# Commandline flags owned by the bootstrapper; their values are read via
# config.config[...] once initialize() has parsed the commandline.
cfg = config.add_commandline_args(
    f'Bootstrap ({__file__})',
    'Args related to python program bootstrapper and Swiss army knife',
)
cfg.add_argument(
    '--debug_unhandled_exceptions',
    action=ActionNoYes,
    default=False,
    help='Break into pdb on top level unhandled exceptions.',
)
cfg.add_argument(
    '--show_random_seed',
    action=ActionNoYes,
    default=False,
    help='Should we display (and log.debug) the global random seed?',
)
cfg.add_argument(
    '--set_random_seed',
    type=int,
    nargs=1,
    default=None,
    metavar='SEED_INT',
    help='Override the global random seed with a particular number.',
)
cfg.add_argument(
    '--dump_all_objects',
    action=ActionNoYes,
    default=False,
    # initialize() acts on this only after the entry point has returned,
    # and what it dumps is the set of loaded classes.
    help='Should we log all loaded classes (and who imported them) after main returns?',
)
cfg.add_argument(
    '--audit_import_events',
    action=ActionNoYes,
    default=False,
    help='Should we audit all import events?',
)
cfg.add_argument(
    '--run_profiler',
    action=ActionNoYes,
    default=False,
    help='Should we run cProfile on this code?',
)
cfg.add_argument(
    '--trace_memory',
    action=ActionNoYes,
    default=False,
    help='Should we record/report on memory utilization?',
)

# Captured at import time, before initialize() may replace sys.excepthook;
# handle_uncaught_exception forwards to this hook when not debugging.
ORIGINAL_EXCEPTION_HOOK = sys.excepthook
def handle_uncaught_exception(exc_type, exc_value, exc_tb):
    """
    Top-level exception handler for exceptions that make it past any exception
    handlers in the python code being run.  Logs the error and stacktrace then
    maybe attaches a debugger.

    initialize() installs this as sys.excepthook, so it is called with the
    usual (exc_type, exc_value, exc_tb) triple.

    * The exception (type, message and traceback) is always logged.
    * KeyboardInterrupt goes straight to the interpreter's default hook.
    * Otherwise the traceback frames are printed to stderr.  If stdin and
      stderr are both terminals and --debug_unhandled_exceptions is set, a
      post-mortem pdb session is opened on exc_tb.  In every other case the
      exception is handed to ORIGINAL_EXCEPTION_HOOK.
    """
    import traceback

    msg = f'Unhandled top level exception {exc_type}'
    # Pass the triple explicitly so the log record carries the exception's
    # message and traceback regardless of what sys.exc_info() returns here.
    logger.error(msg, exc_info=(exc_type, exc_value, exc_tb))
    print(msg, file=sys.stderr)
    if issubclass(exc_type, KeyboardInterrupt):
        sys.__excepthook__(exc_type, exc_value, exc_tb)
        return

    # Frames only; the exception line itself is printed by the hook below.
    print(''.join(traceback.format_tb(exc_tb)), file=sys.stderr)

    # stdin or stderr is redirected, just do the normal thing
    if not sys.stderr.isatty() or not sys.stdin.isatty():
        ORIGINAL_EXCEPTION_HOOK(exc_type, exc_value, exc_tb)
    elif config.config['debug_unhandled_exceptions']:
        # A terminal is attached and the user asked for a debugger.
        logger.info("Invoking the debugger...")
        import pdb

        # Debug the traceback we were handed rather than relying on
        # interpreter-global "last traceback" state.
        pdb.post_mortem(exc_tb)
    else:
        ORIGINAL_EXCEPTION_HOOK(exc_type, exc_value, exc_tb)
class ImportInterceptor(importlib.abc.MetaPathFinder):
    """An interceptor that always allows module load events but dumps a
    record into the log and onto stdout when modules are loaded and
    produces an audit of who imported what at the end of the run.  It
    can't see any load events that happen before it, though, so move
    bootstrap up in your __main__'s import list just temporarily to
    get a more complete picture.

    It never loads anything itself: find_spec always returns None so the
    other finders on sys.meta_path do the real work.
    """

    def __init__(self) -> None:
        import collect.trie

        # Maps a module's __file__ to the module object, so a stack frame's
        # filename can be turned back into the module doing the import.
        self.module_by_filename_cache = {}
        self.repopulate_modules_by_filename()
        # Every import chain seen, inserted as a path of module names.
        self.tree = collect.trie.Trie()
        # Maps an imported module's name to the chain of names leading to it.
        self.tree_node_by_module = {}

    def repopulate_modules_by_filename(self) -> None:
        """Rebuild the __file__ -> module cache from sys.modules."""
        self.module_by_filename_cache.clear()
        # Copy first: sys.modules can change underneath us.
        for mod in sys.modules.copy().values():
            fname = getattr(mod, '__file__', None)
            if fname is None:
                continue
            self.module_by_filename_cache[fname] = mod

    @staticmethod
    def should_ignore_filename(filename: str) -> bool:
        """Whether a frame from *filename* is skipped when finding the importer."""
        return 'importlib' in filename or 'six.py' in filename

    def find_module(self, fullname, path):
        raise Exception("This method has been deprecated since Python 3.4, please upgrade.")

    def find_spec(self, loaded_module, path=None, _=None):
        """Record who is importing *loaded_module*, then decline to load it.

        Skips the innermost three stack frames and any frame rejected by
        should_ignore_filename().  The first remaining frame's file is mapped
        back to a module.  The chain "importer's chain + loaded_module" is
        inserted into self.tree and stored under loaded_module's name, and a
        line is logged and printed.  Always returns None.
        """
        frames = stack()
        for frame in frames[3:]:
            filename = frame.filename
            if ImportInterceptor.should_ignore_filename(filename):
                continue

            loading_function = frame.function
            if filename in self.module_by_filename_cache:
                loading_module = self.module_by_filename_cache[filename]
            else:
                self.repopulate_modules_by_filename()
                loading_module = self.module_by_filename_cache.get(filename, 'unknown')

            # Key chains by module *name* so find_importer(name) can see them;
            # loading_module is a module object or the string 'unknown'.
            loader_name = getattr(loading_module, '__name__', loading_module)
            # Build a new list: extending the stored one would alias it and
            # pile every sibling import into a single ever-growing chain.
            import_path = self.tree_node_by_module.get(loader_name, []) + [loaded_module]
            self.tree.insert(import_path)
            self.tree_node_by_module[loaded_module] = import_path

            msg = f'*** Import {loaded_module} from {filename}:{frame.lineno} in {loading_module}::{loading_function}'
            logger.debug(msg)
            print(msg)
            return None

        msg = f'*** Import {loaded_module} from ?????'
        logger.debug(msg)
        print(msg)
        return None

    def invalidate_caches(self):
        """Nothing to do: find_spec refreshes its filename cache on a miss."""
        pass

    def find_importer(self, module: str):
        """Return the recorded import chain ending at *module*, or 'unknown'."""
        if module in self.tree_node_by_module:
            return self.tree_node_by_module[module]
        return 'unknown'
# Audit import events?  Note: this runs early in the lifetime of the
# process (assuming that import bootstrap happens early); config has
# (probably) not yet been loaded or parsed the commandline.  Also,
# some things have probably already been imported while we weren't
# watching so this information may be incomplete.
#
# Also note: move bootstrap up in the global import list to catch
# more import events and have a more complete record.
#
# sys.argv is inspected directly because config hasn't parsed it yet.  A
# membership test installs at most one interceptor even if the flag is
# repeated.
IMPORT_INTERCEPTOR = None
if '--audit_import_events' in sys.argv:
    IMPORT_INTERCEPTOR = ImportInterceptor()
    sys.meta_path.insert(0, IMPORT_INTERCEPTOR)
def _iter_all_classes():
    """Yield every class below ``object`` in the hierarchy, each exactly once."""
    seen = set()
    pending = [object]
    while pending:
        cls = pending.pop()
        # type.__subclasses__(cls) also works when cls is itself a metaclass.
        for sub in type.__subclasses__(cls):
            # id(): a metaclass could make its classes unhashable.
            if id(sub) not in seen:
                seen.add(id(sub))
                pending.append(sub)
                yield sub


def dump_all_objects() -> None:
    """Log (at DEBUG) one line per loaded class, sorted by ``module::Class``.

    Every class reachable from ``object`` whose defining module is in
    ``sys.modules`` is reported as ``module::Class (module file)``.  When the
    import auditor is active and knows how that module was imported, the
    import chain is appended.
    """
    messages = {}
    all_modules = sys.modules
    for obj in _iter_all_classes():
        klass = getattr(obj, '__name__', None)
        if klass is None:
            continue
        class_mod_name = getattr(obj, '__module__', None)
        # Must be a str to be looked up in sys.modules.
        if not isinstance(class_mod_name, str) or class_mod_name not in all_modules:
            continue
        mod = all_modules[class_mod_name]
        mod_name = getattr(mod, '__name__', class_mod_name)
        mod_file = getattr(mod, '__file__', 'unknown')

        if IMPORT_INTERCEPTOR is not None:
            import_path = IMPORT_INTERCEPTOR.find_importer(mod_name)
        else:
            import_path = 'unknown'

        msg = f'{class_mod_name}::{klass} ({mod_file})'
        if import_path != 'unknown' and len(import_path) > 0:
            msg += f' imported by {import_path}'
        messages[f'{class_mod_name}::{klass}'] = msg

    for key in sorted(messages):
        logger.debug(messages[key])
def _describe_entry_point(entry_point):
    """Return ``(filename, description)`` identifying *entry_point*.

    Plain functions are identified through their code object.  Other
    callables fall back to ``__globals__['__file__']``; ``__globals__`` is an
    attribute, not a ``__dict__`` entry, so it is read with getattr.  Returns
    ``(None, None)`` when neither is available.
    """
    code = getattr(entry_point, '__code__', None)
    if code is not None:
        return code.co_filename, repr(code)
    entry_globals = getattr(entry_point, '__globals__', None)
    if isinstance(entry_globals, dict) and '__file__' in entry_globals:
        entry_filename = entry_globals['__file__']
        return entry_filename, entry_filename
    return None, None


def _log_interpreter_info() -> None:
    """Log (at DEBUG) platform, interpreter and sys.path details."""
    logger.debug(
        'Platform: %s, maxint=0x%x, byteorder=%s', sys.platform, sys.maxsize, sys.byteorder
    )
    logger.debug('Python interpreter version: %s', sys.version)
    logger.debug('Python implementation: %s', sys.implementation)
    logger.debug('Python C API version: %s', sys.api_version)
    if __debug__:
        logger.debug('Python interpreter running in __debug__ mode.')
    else:
        logger.debug('Python interpreter running in optimized mode.')
    logger.debug('Python path: %s', sys.path)


def _seed_global_rng() -> None:
    """Seed ``random`` from --set_random_seed, else from os.urandom; maybe show it.

    Allows programs that don't bother to override the random seed to be
    replayed via the commandline.
    """
    import random

    random_seed = config.config['set_random_seed']
    if random_seed is not None:
        # The flag's value is a list; the seed is its first element.
        random_seed = random_seed[0]
    else:
        random_seed = int.from_bytes(os.urandom(4), 'little')

    if config.config['show_random_seed']:
        msg = f'Global random seed is: {random_seed}'
        logger.debug(msg)
        print(msg)
    random.seed(random_seed)


def _run_entry_point(entry_point, args, kwargs):
    """Call entry_point(*args, **kwargs), under cProfile if --run_profiler.

    Returns ``(return value, walltime in seconds)``.
    """
    import stopwatch

    if config.config['run_profiler']:
        import cProfile
        from pstats import SortKey

        # runcall hands back the real return value; exec'ing a
        # "ret = ..." string could never bind a local of this function.
        profiler = cProfile.Profile()
        try:
            with stopwatch.Timer() as t:
                ret = profiler.runcall(entry_point, *args, **kwargs)
        finally:
            profiler.print_stats(SortKey.CUMULATIVE)
    else:
        with stopwatch.Timer() as t:
            ret = entry_point(*args, **kwargs)
    return ret, t()


def _report_memory_usage() -> None:
    """Print the ten source lines with the most traced allocations."""
    import tracemalloc

    snapshot = tracemalloc.take_snapshot()
    top_stats = snapshot.statistics('lineno')
    print("--trace_memory's top 10 memory using files:")
    for stat in top_stats[:10]:
        print(stat)


def initialize(entry_point):
    """
    Remember to initialize config, initialize logging, set/log a random
    seed, etc... before running main.

    Decorate a program's entry point with this.  The returned wrapper:

    * installs handle_uncaught_exception as sys.excepthook, unless a
      non-default hook is already installed;
    * parses configuration, starts tracemalloc if --trace_memory,
      initializes logging, and logs interpreter and site_config details;
    * seeds the global random module;
    * calls entry_point(*args, **kwargs), under cProfile if --run_profiler;
    * optionally reports memory use, loaded classes and the import audit
      tree, then logs CPU and wall times;
    * exits via sys.exit() with the entry point's return value.
    """

    @functools.wraps(entry_point)
    def initialize_wrapper(*args, **kwargs):
        # Hook top level unhandled exceptions, maybe invoke debugger.
        if sys.excepthook == sys.__excepthook__:
            sys.excepthook = handle_uncaught_exception

        # Try to figure out the name of the program entry point.  Then
        # parse configuration (based on cmdline flags, environment vars
        # etc...)
        entry_filename, entry_descr = _describe_entry_point(entry_point)
        config.parse(entry_filename)

        if config.config['trace_memory']:
            import tracemalloc

            tracemalloc.start()

        # Initialize logging... and log some remembered messages from
        # config parsing.
        logging_utils.initialize_logging(logging.getLogger())
        config.late_logging()

        # Maybe log some info about the python interpreter itself.
        _log_interpreter_info()

        # Log something about the site_config, many things use it.
        import site_config

        logger.debug('Global site_config: %s', site_config.get_config())

        _seed_global_rng()

        # Do it, invoke the user's code.  Pay attention to how long it takes.
        logger.debug('Starting %s (program entry point)', entry_descr)
        ret, walltime = _run_entry_point(entry_point, args, kwargs)
        logger.debug('%s (program entry point) returned %s.', entry_descr, ret)

        if config.config['trace_memory']:
            _report_memory_usage()

        if config.config['dump_all_objects']:
            dump_all_objects()

        if config.config['audit_import_events']:
            if IMPORT_INTERCEPTOR is not None:
                print(IMPORT_INTERCEPTOR.tree)

        (utime, stime, cutime, cstime, elapsed_time) = os.times()
        logger.debug(
            '\n'
            'user: %.4fs\n'
            'system: %.4fs\n'
            'child user: %.4fs\n'
            'child system: %.4fs\n'
            'machine uptime: %.4fs\n'
            'walltime: %.4fs',
            utime,
            stime,
            cutime,
            cstime,
            elapsed_time,
            walltime,
        )

        # If it doesn't return cleanly, call attention to the return value.
        if ret is not None and ret != 0:
            logger.error('Exit %s', ret)
        else:
            logger.debug('Exit %s', ret)
        sys.exit(ret)

    return initialize_wrapper