3 # © Copyright 2021-2022, Scott Gasch
5 """This is a module for wrapping around python programs and doing some
6 minor setup and tear down work for them. With it, you can break into
7 pdb on unhandled top level exceptions, profile your code by passing a
8 commandline argument in, audit module import events, examine where
9 memory is being used in your program, and so on.
18 from inspect import stack
22 from argparse_utils import ActionNoYes
24 # This module is commonly used by others in here and should avoid
25 # taking any unnecessary dependencies back on them.
# Module-level logger for this bootstrap module.
28 logger = logging.getLogger(__name__)
# Register this module's commandline flags under their own help group.
# NOTE(review): this view is missing the cfg.add_argument( openers, the
# action=/default= kwargs and most flag names. The flags this file later
# reads via config.config[...] are: debug_unhandled_exceptions,
# show_random_seed, set_random_seed, dump_all_objects, audit_import_events,
# run_profiler and trace_memory.
30 cfg = config.add_commandline_args(
31 f'Bootstrap ({__file__})',
32 'Args related to python program bootstrapper and Swiss army knife',
35 '--debug_unhandled_exceptions',
38 help='Break into pdb on top level unhandled exceptions.',
44 help='Should we display (and log.debug) the global random seed?',
52 help='Override the global random seed with a particular number.',
# NOTE(review): this help text says the import tree is dumped "before main",
# but initialize() handles dump_all_objects / audit_import_events only after
# the entry point has returned -- confirm which flag this help belongs to and
# reword it if needed.
58 help='Should we dump the Python import tree before main?',
61 '--audit_import_events',
64 help='Should we audit all import events?',
70 help='Should we run cProfile on this code?',
76 help='Should we record/report on memory utilization?',
# Remember whichever excepthook was active when this module was imported;
# handle_uncaught_exception() delegates to it for the non-debugging cases.
79 ORIGINAL_EXCEPTION_HOOK = sys.excepthook
# Replacement for sys.excepthook (initialize() installs it only while the
# interpreter's default hook is still in place). Prints a one-line summary to
# stderr, then:
#   * KeyboardInterrupt -> sys.__excepthook__ (Python's built-in handler);
#   * stdin or stderr not a terminal -> ORIGINAL_EXCEPTION_HOOK;
#   * interactive session -> print the full traceback, then, if
#     --debug_unhandled_exceptions is set, log "Invoking the debugger...",
#     otherwise fall back to ORIGINAL_EXCEPTION_HOOK.
# NOTE(review): several lines are not visible in this view (the docstring
# quotes, a statement after msg is built, the return/else lines, the
# traceback/pdb imports and the debugger call itself), so the branch
# structure above is read from the visible statements and comments only.
82 def handle_uncaught_exception(exc_type, exc_value, exc_tb):
84 Top-level exception handler for exceptions that make it past any exception
85 handlers in the python code being run. Logs the error and stacktrace then
86 maybe attaches a debugger.
89 msg = f'Unhandled top level exception {exc_type}'
91 print(msg, file=sys.stderr)
92 if issubclass(exc_type, KeyboardInterrupt):  # reported via Python's default hook
93 sys.__excepthook__(exc_type, exc_value, exc_tb)
96 if not sys.stderr.isatty() or not sys.stdin.isatty():
97 # stdin or stderr is redirected, so no interactive debugging; just do the normal thing
98 ORIGINAL_EXCEPTION_HOOK(exc_type, exc_value, exc_tb)
100 # both stdin and stderr are terminals (nothing redirected), so maybe debug.
103 traceback.print_exception(exc_type, exc_value, exc_tb)
104 if config.config['debug_unhandled_exceptions']:
107 logger.info("Invoking the debugger...")
110 ORIGINAL_EXCEPTION_HOOK(exc_type, exc_value, exc_tb)
# Import auditor placed on sys.meta_path by the --audit_import_events check
# further down this module. As its docstring says, it never blocks a load.
# For each import it looks on the call stack for the importing frame, records
# the import in self.tree and builds a '*** Import ...' report.
# NOTE(review): the docstring's closing quotes, __init__'s def line and
# several statements in find_spec/find_importer are not visible in this view.
113 class ImportInterceptor(importlib.abc.MetaPathFinder):
114 """An interceptor that always allows module load events but dumps a
115 record into the log and onto stdout when modules are loaded and
116 produces an audit of who imported what at the end of the run. It
117 can't see any load events that happen before it, though, so move
118 bootstrap up in your __main__'s import list just temporarily to
126 self.module_by_filename_cache = {}  # module __file__ path -> module object
127 self.repopulate_modules_by_filename()
128 self.tree = collect.trie.Trie()  # import paths recorded by find_spec
# NOTE(review): find_spec stores entries here keyed by loading_module, which
# is a module object taken from sys.modules (or the string 'unknown'), while
# find_importer() is called with a module *name* string, so those lookups
# will generally miss. Confirm the intended key type.
129 self.tree_node_by_module = {}
# Rebuild the __file__ -> module map from a snapshot of sys.modules (a line
# or two between the __file__ lookup and the store is not visible here).
131 def repopulate_modules_by_filename(self):
132 self.module_by_filename_cache.clear()
133 for _, mod in sys.modules.copy().items(): # copy: sys.modules can change while we iterate
134 if hasattr(mod, '__file__'):
135 fname = getattr(mod, '__file__')
138 self.module_by_filename_cache[fname] = mod
# Frames from importlib's own machinery and from six.py are skipped when
# find_spec searches the stack for the importer. It is called on the class,
# so presumably a @staticmethod; the decorator line is not visible here.
# NOTE(review): this is a plain substring test, so any path that merely
# contains 'importlib' (e.g. an importlib_metadata install) is skipped too.
141 def should_ignore_filename(filename: str) -> bool:
142 return 'importlib' in filename or 'six.py' in filename
# Legacy (pre-PEP 451) finder hook; fail loudly if anything still calls it
# instead of find_spec.
144 def find_module(self, fullname, path):
145 raise Exception("This method has been deprecated since Python 3.4, please upgrade.")
# Import-system hook. Scans the caller stack `s` from index 3 onward. The
# assignment of `s` is not visible here: presumably inspect.stack(), with the
# first frames skipped because they are this method and import machinery --
# TODO confirm. Frames that should_ignore_filename() rejects are skipped.
# The first remaining frame is treated as the importer: its file is resolved
# to a module (refreshing the cache once on a miss), the import path is
# recorded in self.tree, and a '*** Import ...' message is built. If no frame
# qualifies, a '?????' message is built instead. No visible line returns a
# spec, which matches the docstring's promise to always allow the load.
147 def find_spec(self, loaded_module, path=None, _=None):
149 for x in range(3, len(s)):
150 filename = s[x].filename
151 if ImportInterceptor.should_ignore_filename(filename):
154 loading_function = s[x].function
155 if filename in self.module_by_filename_cache:
156 loading_module = self.module_by_filename_cache[filename]
158 self.repopulate_modules_by_filename()
159 loading_module = self.module_by_filename_cache.get(filename, 'unknown')
# NOTE(review): `path` below shadows the find_spec parameter of the same name.
# The list fetched for loading_module is extended *in place* and stored back
# under loading_module (not loaded_module). Every import made from one module
# therefore keeps appending to a single shared list, and the trie receives
# ever-longer sibling chains rather than importer-path + [loaded_module]. If
# a per-module import path is intended, `new_path = path + [loaded_module]`
# stored under the loaded module's name looks closer -- confirm against
# find_importer() before changing.
161 path = self.tree_node_by_module.get(loading_module, [])
162 path.extend([loaded_module])
163 self.tree.insert(path)
164 self.tree_node_by_module[loading_module] = path
166 msg = f'*** Import {loaded_module} from {filename}:{s[x].lineno} in {loading_module}::{loading_function}'
170 msg = f'*** Import {loaded_module} from ?????'
# Part of the finder protocol (importlib.invalidate_caches() calls it on
# sys.meta_path finders); its body is not visible in this view.
174 def invalidate_caches(self):
# Look up the path recorded for `module` in tree_node_by_module; the return
# statements are not visible in this view.
177 def find_importer(self, module: str):
178 if module in self.tree_node_by_module:
179 node = self.tree_node_by_module[module]
184 # Audit import events? Note: this runs early in the lifetime of the
185 # process (assuming that bootstrap is imported early); config has
186 # (probably) not yet been loaded nor the commandline parsed. Also,
187 # some things have probably already been imported while we weren't
188 # watching so this information may be incomplete.
190 # Also note: move bootstrap up in the global import list to catch
191 # more import events and have a more complete record.
# Stays None unless the audit flag is present; dump_all_objects() and
# initialize() check it for None before using it.
192 IMPORT_INTERCEPTOR = None
# NOTE(review): the loop header that binds `arg` is not visible here
# (presumably iterating sys.argv). Only this exact token matches; e.g. an
# abbreviated flag, if the config parser accepts abbreviations, would enable
# the audit flag in config without installing the interceptor.
194 if arg == '--audit_import_events':
195 IMPORT_INTERCEPTOR = ImportInterceptor()
196 sys.meta_path.insert(0, IMPORT_INTERCEPTOR)  # index 0: consulted before the other finders for modules not yet in sys.modules
# Log (at DEBUG) one entry per distinct "module::class" key, sorted by key.
# Each entry shows the defining module's file and, when the import auditor is
# active and returns a non-empty path, who imported that module.
# NOTE(review): object.__subclasses__() lists only *direct* subclasses of
# object, not every class in the process, so classes derived from other
# bases are not reported. Not visible here: the `messages` dict, the
# assignment of `klass`, the statements after the hasattr checks, and any
# else-branches for mod_file / import_path.
199 def dump_all_objects() -> None:
201 all_modules = sys.modules
202 for obj in object.__subclasses__():
203 if not hasattr(obj, '__name__'):
206 if not hasattr(obj, '__module__'):
208 class_mod_name = obj.__module__
# NOTE(review): when class_mod_name is not in sys.modules, the visible lines
# assign neither mod_name nor mod_file. Unless a hidden line does, they keep
# the previous iteration's values, or are unbound on the first pass.
209 if class_mod_name in all_modules:
210 mod = all_modules[class_mod_name]
211 if not hasattr(mod, '__name__'):
212 mod_name = class_mod_name
214 mod_name = mod.__name__
215 if hasattr(mod, '__file__'):
216 mod_file = mod.__file__
219 if IMPORT_INTERCEPTOR is not None:
220 import_path = IMPORT_INTERCEPTOR.find_importer(mod_name)
222 import_path = 'unknown'
223 msg = f'{class_mod_name}::{klass} ({mod_file})'
224 if import_path != 'unknown' and len(import_path) > 0:
225 msg += f' imported by {import_path}'
226 messages[f'{class_mod_name}::{klass}'] = msg
227 for x in sorted(messages.keys()):
228 logger.debug(messages[x])
# Decorator for a program's entry point. The returned wrapper keeps the entry
# point's metadata via functools.wraps, then does the following:
#   * installs handle_uncaught_exception as sys.excepthook if no other hook
#     is set;
#   * parses config and initializes logging;
#   * logs interpreter and site_config details;
#   * seeds `random` from the set_random_seed flag or os.urandom;
#   * runs the entry point (optionally under cProfile) inside a
#     stopwatch.Timer;
#   * emits the optional trace_memory / dump_all_objects /
#     audit_import_events reports and a timing summary built from os.times();
#   * logs the entry point's return value.
# NOTE(review): many lines of this function are not visible in this view
# (try/except/else lines, the tracemalloc start-up, the profiler call,
# several logging calls and the wrapper's return), so the flow above is read
# from the visible statements only.
232 def initialize(entry_point):
234 Remember to initialize config, initialize logging, set/log a random
235 seed, etc... before running main.
239 @functools.wraps(entry_point)
240 def initialize_wrapper(*args, **kwargs):
241 # Hook top level unhandled exceptions, maybe invoke debugger.
242 if sys.excepthook == sys.__excepthook__:  # only if nobody replaced the default hook
243 sys.excepthook = handle_uncaught_exception
245 # Try to figure out the name of the program entry point. Then
246 # parse configuration (based on cmdline flags, environment vars
248 entry_filename = None
251 entry_filename = entry_point.__code__.co_filename
252 entry_descr = entry_point.__code__.__repr__()
# NOTE(review): `__globals__` is an attribute of function objects but is not
# stored in the function's __dict__. For an ordinary function this first
# test is therefore False and the two assignments below never run, leaving
# the co_filename-based values above in place. `hasattr(entry_point,
# '__globals__')` may be what was intended -- confirm before changing.
254 if '__globals__' in entry_point.__dict__ and '__file__' in entry_point.__globals__:
255 entry_filename = entry_point.__globals__['__file__']
256 entry_descr = entry_filename
257 config.parse(entry_filename)
# (the statements guarded by this check are not visible here; presumably
# they start tracemalloc for the report printed after the entry point runs)
259 if config.config['trace_memory']:
264 # Initialize logging... and log some remembered messages from
266 logging_utils.initialize_logging(logging.getLogger())
267 config.late_logging()
269 # Maybe log some info about the python interpreter itself.
271 'Platform: %s, maxint=0x%x, byteorder=%s', sys.platform, sys.maxsize, sys.byteorder
273 logger.debug('Python interpreter version: %s', sys.version)
274 logger.debug('Python implementation: %s', sys.implementation)
275 logger.debug('Python C API version: %s', sys.api_version)
277 logger.debug('Python interpreter running in __debug__ mode.')
279 logger.debug('Python interpreter running in optimized mode.')
280 logger.debug('Python path: %s', sys.path)
282 # Log something about the site_config, many things use it.
285 logger.debug('Global site_config: %s', site_config.get_config())
287 # Allow programs that don't bother to override the random seed
288 # to be replayed via the commandline.
# The [0] suggests set_random_seed is parsed as a one-element list -- TODO
# confirm in the flag definition. Otherwise a fresh 32-bit seed is drawn
# from os.urandom(4).
291 random_seed = config.config['set_random_seed']
292 if random_seed is not None:
293 random_seed = random_seed[0]
295 random_seed = int.from_bytes(os.urandom(4), 'little')
297 if config.config['show_random_seed']:
298 msg = f'Global random seed is: {random_seed}'
301 random.seed(random_seed)
303 # Do it, invoke the user's code. Pay attention to how long it takes.
304 logger.debug('Starting %s (program entry point)', entry_descr)
# NOTE(review): the profiler call itself is not visible here. If it runs the
# string below via cProfile.runctx/exec with a locals() mapping, the
# `ret = ...` inside the string does not bind this function's local `ret`:
# writes to a locals() mapping do not update function locals. The
# logger.debug(..., ret) after this if/else would then raise
# UnboundLocalError whenever run_profiler is set.
# cProfile.Profile().runcall(entry_point, *args, **kwargs) returns the
# result directly -- confirm and fix.
308 if config.config['run_profiler']:
310 from pstats import SortKey
312 with stopwatch.Timer() as t:
314 "ret = entry_point(*args, **kwargs)",
321 with stopwatch.Timer() as t:
322 ret = entry_point(*args, **kwargs)
324 logger.debug('%s (program entry point) returned %s.', entry_descr, ret)
# NOTE(review): statistics('lineno') groups allocations by file *and line*,
# so the "top 10 memory using files" label printed below is slightly off;
# statistics('filename') would match the label.
326 if config.config['trace_memory']:
327 snapshot = tracemalloc.take_snapshot()
328 top_stats = snapshot.statistics('lineno')
330 print("--trace_memory's top 10 memory using files:")
331 for stat in top_stats[:10]:
334 if config.config['dump_all_objects']:
337 if config.config['audit_import_events']:
338 if IMPORT_INTERCEPTOR is not None:
339 print(IMPORT_INTERCEPTOR.tree)  # import trie recorded by ImportInterceptor
# NOTE(review): os.times()'s elapsed value is real time since "a fixed point
# in the past" (platform dependent), so the "machine uptime" label below is
# only accurate where that point is boot time. On Windows the child times
# and elapsed are reported as zero.
342 (utime, stime, cutime, cstime, elapsed_time) = os.times()
347 'child user: %.4fs\n'
348 'child system: %.4fs\n'
349 'machine uptime: %.4fs\n'
359 # If it doesn't return cleanly, call attention to the return value.
360 if ret is not None and ret != 0:
361 logger.error('Exit %s', ret)
363 logger.debug('Exit %s', ret)
366 return initialize_wrapper