3 """This is a module for wrapping around Python programs and doing some
4 minor setup and teardown work for them. With it, you can break into
5 pdb on unhandled top-level exceptions, profile your code by passing a
6 commandline argument, audit module import events, examine where
7 memory is being used in your program, and so on."""
14 from inspect import stack
18 from argparse_utils import ActionNoYes
20 # This module is commonly used by others in here and should avoid
21 # taking any unnecessary dependencies back on them.
# Module-level logger; the functions in this file log through it.
24 logger = logging.getLogger(__name__)
# Register this module's commandline arguments with config. The first
# argument embeds this file's path. Values are read back elsewhere in this
# file via config.config[...], e.g. 'debug_unhandled_exceptions',
# 'audit_import_events', 'show_random_seed', 'set_random_seed',
# 'run_profiler', 'trace_memory' and 'dump_all_objects'.
# NOTE(review): most of each add_argument(...) call is not in this
# excerpt, so which flag each help string belongs to is not visible here.
# One help string promises to dump the import tree "before main", but in
# initialize() both the object dump and the import-audit tree print run
# after the entry point returns -- confirm the help text.
26 cfg = config.add_commandline_args(
27 f'Bootstrap ({__file__})',
28 'Args related to python program bootstrapper and Swiss army knife',
31 '--debug_unhandled_exceptions',
34 help='Break into pdb on top level unhandled exceptions.',
40 help='Should we display (and log.debug) the global random seed?',
48 help='Override the global random seed with a particular number.',
54 help='Should we dump the Python import tree before main?',
57 '--audit_import_events',
60 help='Should we audit all import events?',
66 help='Should we run cProfile on this code?',
72 help='Should we record/report on memory utilization?',
# The sys.excepthook in effect when this module was imported.
# handle_uncaught_exception delegates to it when it does not debug.
75 ORIGINAL_EXCEPTION_HOOK = sys.excepthook
# Replacement for sys.excepthook. initialize() installs it when the
# default hook is still in place.
#
# It prints a one-line summary to stderr. Then:
#   - KeyboardInterrupt goes to Python's built-in default hook.
#   - If stdin or stderr is not a terminal, it delegates to
#     ORIGINAL_EXCEPTION_HOOK.
#   - Otherwise it prints the traceback. If --debug_unhandled_exceptions
#     is set, it logs that it is invoking the debugger; if not, it
#     delegates to ORIGINAL_EXCEPTION_HOOK.
#
# Args: the standard excepthook triple (exc_type, exc_value, exc_tb).
# NOTE(review): several lines are missing from this excerpt: the docstring
# quotes, whatever follows the KeyboardInterrupt delegation, and the
# debugger call itself. Confirm the KeyboardInterrupt path returns early
# instead of falling through to the tty checks below.
78 def handle_uncaught_exception(exc_type, exc_value, exc_tb):
80 Top-level exception handler for exceptions that make it past any exception
81 handlers in the python code being run. Logs the error and stacktrace then
82 maybe attaches a debugger.
85 msg = f'Unhandled top level exception {exc_type}'
87 print(msg, file=sys.stderr)
# Ctrl-C: hand off to Python's built-in default hook, never the debugger.
88 if issubclass(exc_type, KeyboardInterrupt):
89 sys.__excepthook__(exc_type, exc_value, exc_tb)
92 if not sys.stderr.isatty() or not sys.stdin.isatty():
93 # stdin or stderr is redirected, just do the normal thing
94 ORIGINAL_EXCEPTION_HOOK(exc_type, exc_value, exc_tb)
96 # a terminal is attached and stderr is not redirected, maybe debug.
99 traceback.print_exception(exc_type, exc_value, exc_tb)
100 if config.config['debug_unhandled_exceptions']:
103 logger.info("Invoking the debugger...")
# Debugging not requested: fall back to the hook captured at import time.
# The `else` that presumably introduces this call is not in this excerpt.
106 ORIGINAL_EXCEPTION_HOOK(exc_type, exc_value, exc_tb)
# Meta-path finder used to audit imports. The --audit_import_events
# handling below inserts it at the front of sys.meta_path, so its
# find_spec sees import requests before the other finders.
#
# find_spec records who imported what into a collect.trie.Trie
# (self.tree) and builds a "*** Import ..." message. Where that message is
# sent is not visible in this excerpt. initialize() prints self.tree at
# the end of the run.
#
# Instance state, set in the constructor (whose `def` line is not in this
# excerpt):
#   module_by_filename_cache -- maps a module's __file__ to the module object
#   tree                     -- trie of recorded import paths
#   tree_node_by_module      -- maps an importing module to its recorded path
109 class ImportInterceptor(importlib.abc.MetaPathFinder):
110 """An interceptor that always allows module load events but dumps a
111 record into the log and onto stdout when modules are loaded and
112 produces an audit of who imported what at the end of the run. It
113 can't see any load events that happen before it, though, so move
114 bootstrap up in your __main__'s import list just temporarily to
122 self.module_by_filename_cache = {}
123 self.repopulate_modules_by_filename()
124 self.tree = collect.trie.Trie()
125 self.tree_node_by_module = {}
# Rebuild the __file__ -> module map from a snapshot of sys.modules.
# NOTE(review): two lines of this loop body are not in this excerpt; they
# may normalize fname before it is stored -- confirm.
127 def repopulate_modules_by_filename(self):
128 self.module_by_filename_cache.clear()
129 for _, mod in sys.modules.copy().items(): # copy here because modules is volatile
130 if hasattr(mod, '__file__'):
131 fname = getattr(mod, '__file__')
134 self.module_by_filename_cache[fname] = mod
# True for stack-frame filenames containing 'importlib' or 'six.py'.
# find_spec uses it to pass over import-machinery frames when looking for
# the real importer. It takes no `self` and is called on the class
# (ImportInterceptor.should_ignore_filename).
137 def should_ignore_filename(filename: str) -> bool:
138 return 'importlib' in filename or 'six.py' in filename
# Legacy (pre-3.4) finder API. This implementation refuses it by raising.
140 def find_module(self, fullname, path):
141 raise Exception("This method has been deprecated since Python 3.4, please upgrade.")
# Meta-path hook called by the import system with the full name of the
# module being imported (`loaded_module`).
#
# It walks the frames in `s`, starting at index 3, to find the file and
# function that triggered the import. Frames flagged by
# should_ignore_filename() are presumably skipped. The importing module is
# resolved through module_by_filename_cache: on a miss the cache is
# refreshed once, then the string 'unknown' is used as a fallback. The
# resulting import path is recorded in self.tree.
#
# NOTE(review): `s` is assigned on a line not shown here; presumably
# s = stack() (inspect.stack is imported at the top) -- confirm.
# NOTE(review): the return statement is also not visible. Returning None
# from find_spec leaves the import to the remaining meta-path finders.
# NOTE(review): the local `path` below shadows the `path` parameter.
# NOTE(review): `path` is the same list object stored in
# tree_node_by_module, so extend() mutates the importer's stored path in
# place. The result is then stored under loading_module (a module object,
# or the string 'unknown'), not under loaded_module. find_importer() and
# dump_all_objects() look entries up by str module name, so lookups for
# real module names will likely miss. Consider:
#   path = self.tree_node_by_module.get(loading_module, []) + [loaded_module]
# and key the dict consistently with find_importer's lookups.
143 def find_spec(self, loaded_module, path=None, _=None):
145 for x in range(3, len(s)):
146 filename = s[x].filename
147 if ImportInterceptor.should_ignore_filename(filename):
150 loading_function = s[x].function
151 if filename in self.module_by_filename_cache:
152 loading_module = self.module_by_filename_cache[filename]
154 self.repopulate_modules_by_filename()
155 loading_module = self.module_by_filename_cache.get(filename, 'unknown')
157 path = self.tree_node_by_module.get(loading_module, [])
158 path.extend([loaded_module])
159 self.tree.insert(path)
160 self.tree_node_by_module[loading_module] = path
162 msg = f'*** Import {loaded_module} from {filename}:{s[x].lineno} in {loading_module}::{loading_function}'
# Presumably the fallback when no qualifying importer frame was found.
# The control flow leading here is not visible in this excerpt.
166 msg = f'*** Import {loaded_module} from ?????'
# Called via importlib.invalidate_caches(). The body is not in this excerpt.
170 def invalidate_caches(self):
# Look up the recorded import path for `module` (a module name) in
# tree_node_by_module. The rest of the body is not in this excerpt.
173 def find_importer(self, module: str):
174 if module in self.tree_node_by_module:
175 node = self.tree_node_by_module[module]
180 # Audit import events? Note: this runs early in the lifetime of the
181 # process (assuming that import bootstrap happens early): config has
182 # (probably) not yet been loaded, nor has the commandline been parsed. Also,
183 # some things have probably already been imported while we weren't
184 # watching, so this information may be incomplete.
186 # Also note: move bootstrap up in the global import list to catch
187 # more import events and have a more complete record.
# Global handle to the installed interceptor. It stays None unless
# --audit_import_events was seen. Read by dump_all_objects() and
# initialize().
188 IMPORT_INTERCEPTOR = None
# NOTE(review): `arg` comes from a loop header that is not in this
# excerpt. Per the comment above, it presumably scans the raw argv
# because config has not parsed the commandline yet -- confirm.
190 if arg == '--audit_import_events':
191 IMPORT_INTERCEPTOR = ImportInterceptor()
# Index 0 so this finder is consulted before every other meta-path finder.
192 sys.meta_path.insert(0, IMPORT_INTERCEPTOR)
# Log one line per direct subclass of `object`, at DEBUG level. Each line
# has the form "module::class (module file)". When the import interceptor
# reports a non-empty path for the class's module, " imported by <path>"
# is appended. Messages are keyed by "module::class" and emitted in
# sorted key order.
#
# NOTE(review): `messages`, `klass`, and any fallbacks for `mod_name` and
# `mod_file` (e.g. when the module is not in sys.modules or has no
# __file__) are set on lines not in this excerpt. Confirm both names are
# always bound before the message is built.
# NOTE(review): object.__subclasses__() returns only *immediate*
# subclasses of object, so classes deeper in a hierarchy are not listed.
195 def dump_all_objects() -> None:
197 all_modules = sys.modules
198 for obj in object.__subclasses__():
199 if not hasattr(obj, '__name__'):
202 if not hasattr(obj, '__module__'):
204 class_mod_name = obj.__module__
# Resolve the class's defining module (if loaded) for its name and file.
205 if class_mod_name in all_modules:
206 mod = all_modules[class_mod_name]
207 if not hasattr(mod, '__name__'):
208 mod_name = class_mod_name
210 mod_name = mod.__name__
211 if hasattr(mod, '__file__'):
212 mod_file = mod.__file__
215 if IMPORT_INTERCEPTOR is not None:
216 import_path = IMPORT_INTERCEPTOR.find_importer(mod_name)
218 import_path = 'unknown'
219 msg = f'{class_mod_name}::{klass} ({mod_file})'
220 if import_path != 'unknown' and len(import_path) > 0:
221 msg += f' imported by {import_path}'
222 messages[f'{class_mod_name}::{klass}'] = msg
223 for x in sorted(messages.keys()):
224 logger.debug(messages[x])
# Decorator for a program's entry point. Returns a functools.wraps'd
# wrapper that, when called, does the following in order:
#   - installs handle_uncaught_exception as sys.excepthook if the default
#     hook is still in place;
#   - parses config using the entry point's filename;
#   - initializes logging;
#   - logs interpreter and site_config details;
#   - seeds `random`;
#   - runs the entry point inside stopwatch.Timer (under the profiler when
#     'run_profiler' is set);
#   - optionally reports memory, object, and import-audit diagnostics plus
#     os.times() figures;
#   - logs the return value.
# What it does with `ret` afterwards (lines after the final log) is not
# in this excerpt.
228 def initialize(entry_point):
230 Remember to initialize config, initialize logging, set/log a random
231 seed, etc... before running main.
235 @functools.wraps(entry_point)
236 def initialize_wrapper(*args, **kwargs):
237 # Hook top level unhandled exceptions, maybe invoke debugger.
238 if sys.excepthook == sys.__excepthook__:
239 sys.excepthook = handle_uncaught_exception
241 # Try to figure out the name of the program entry point. Then
242 # parse configuration (based on cmdline flags, environment vars
244 entry_filename = None
247 entry_filename = entry_point.__code__.co_filename
248 entry_descr = entry_point.__code__.__repr__()
# NOTE(review): for an ordinary Python function, `__globals__` is an
# attribute but not a key of `__dict__` (which holds only user-set
# attributes). This test is therefore normally False, and entry_filename
# stays as co_filename. `getattr(entry_point, '__globals__', {})` may
# have been intended.
250 if '__globals__' in entry_point.__dict__ and '__file__' in entry_point.__globals__:
251 entry_filename = entry_point.__globals__['__file__']
252 entry_descr = entry_filename
253 config.parse(entry_filename)
# NOTE(review): this branch's body is not in this excerpt.
# tracemalloc.take_snapshot(), used after the entry point returns, raises
# RuntimeError unless tracing was started. Presumably tracemalloc.start()
# is called here -- confirm.
255 if config.config['trace_memory']:
260 # Initialize logging... and log some remembered messages from
262 logging_utils.initialize_logging(logging.getLogger())
263 config.late_logging()
265 # Maybe log some info about the python interpreter itself.
267 'Platform: %s, maxint=0x%x, byteorder=%s', sys.platform, sys.maxsize, sys.byteorder
269 logger.debug('Python interpreter version: %s', sys.version)
270 logger.debug('Python implementation: %s', sys.implementation)
271 logger.debug('Python C API version: %s', sys.api_version)
273 logger.debug('Python interpreter running in __debug__ mode.')
275 logger.debug('Python interpreter running in optimized mode.')
276 logger.debug('Python path: %s', sys.path)
278 # Log something about the site_config, many things use it.
281 logger.debug('Global site_config: %s', site_config.get_config())
283 # Allow programs that don't bother to override the random seed
284 # to be replayed via the commandline.
# An explicit seed is taken as element [0] of the configured value
# (presumably a one-item list from the arg parser -- confirm). Otherwise
# a fresh 32-bit seed is drawn from os.urandom.
287 random_seed = config.config['set_random_seed']
288 if random_seed is not None:
289 random_seed = random_seed[0]
291 random_seed = int.from_bytes(os.urandom(4), 'little')
293 if config.config['show_random_seed']:
294 msg = f'Global random seed is: {random_seed}'
297 random.seed(random_seed)
299 # Do it, invoke the user's code. Pay attention to how long it takes.
300 logger.debug('Starting %s (program entry point)', entry_descr)
# NOTE(review): on the profiling path, the call is handed to the profiler
# as a source string, so `ret` is assigned in whatever namespace the
# profiler executes it in. The call's remaining arguments are not in this
# excerpt. If that namespace is a locals() snapshot, the assignment will
# not update this function's `ret` -- confirm `ret` is actually captured
# when profiling.
304 if config.config['run_profiler']:
306 from pstats import SortKey
308 with stopwatch.Timer() as t:
310 "ret = entry_point(*args, **kwargs)",
317 with stopwatch.Timer() as t:
318 ret = entry_point(*args, **kwargs)
320 logger.debug('%s (program entry point) returned %s.', entry_descr, ret)
322 if config.config['trace_memory']:
323 snapshot = tracemalloc.take_snapshot()
324 top_stats = snapshot.statistics('lineno')
326 print("--trace_memory's top 10 memory using files:")
327 for stat in top_stats[:10]:
330 if config.config['dump_all_objects']:
333 if config.config['audit_import_events']:
334 if IMPORT_INTERCEPTOR is not None:
335 print(IMPORT_INTERCEPTOR.tree)
# os.times() returns user and system CPU time for this process and for its
# children, plus elapsed real time since a fixed point in the past. The
# report below labels that last value "machine uptime".
338 (utime, stime, cutime, cstime, elapsed_time) = os.times()
343 'child user: %.4fs\n'
344 'child system: %.4fs\n'
345 'machine uptime: %.4fs\n'
355 # If it doesn't return cleanly, call attention to the return value.
356 if ret is not None and ret != 0:
357 logger.error('Exit %s', ret)
359 logger.debug('Exit %s', ret)
362 return initialize_wrapper