Since this thing is on the innerwebs I suppose it should have a
[python_utils.git] / bootstrap.py
1 #!/usr/bin/env python3
2
3 # © Copyright 2021-2022, Scott Gasch
4
5 """This is a module for wrapping around python programs and doing some
6 minor setup and tear down work for them.  With it, you can break into
7 pdb on unhandled top level exceptions, profile your code by passing a
8 commandline argument in, audit module import events, examine where
9 memory is being used in your program, and so on.
10
11 """
12
13 import functools
14 import importlib
15 import logging
16 import os
17 import sys
18 from inspect import stack
19
20 import config
21 import logging_utils
22 from argparse_utils import ActionNoYes
23
24 # This module is commonly used by others in here and should avoid
25 # taking any unnecessary dependencies back on them.
26
27
28 logger = logging.getLogger(__name__)
29
30 cfg = config.add_commandline_args(
31     f'Bootstrap ({__file__})',
32     'Args related to python program bootstrapper and Swiss army knife',
33 )
34 cfg.add_argument(
35     '--debug_unhandled_exceptions',
36     action=ActionNoYes,
37     default=False,
38     help='Break into pdb on top level unhandled exceptions.',
39 )
40 cfg.add_argument(
41     '--show_random_seed',
42     action=ActionNoYes,
43     default=False,
44     help='Should we display (and log.debug) the global random seed?',
45 )
46 cfg.add_argument(
47     '--set_random_seed',
48     type=int,
49     nargs=1,
50     default=None,
51     metavar='SEED_INT',
52     help='Override the global random seed with a particular number.',
53 )
54 cfg.add_argument(
55     '--dump_all_objects',
56     action=ActionNoYes,
57     default=False,
58     help='Should we dump the Python import tree before main?',
59 )
60 cfg.add_argument(
61     '--audit_import_events',
62     action=ActionNoYes,
63     default=False,
64     help='Should we audit all import events?',
65 )
66 cfg.add_argument(
67     '--run_profiler',
68     action=ActionNoYes,
69     default=False,
70     help='Should we run cProfile on this code?',
71 )
72 cfg.add_argument(
73     '--trace_memory',
74     action=ActionNoYes,
75     default=False,
76     help='Should we record/report on memory utilization?',
77 )
78
79 ORIGINAL_EXCEPTION_HOOK = sys.excepthook
80
81
82 def handle_uncaught_exception(exc_type, exc_value, exc_tb):
83     """
84     Top-level exception handler for exceptions that make it past any exception
85     handlers in the python code being run.  Logs the error and stacktrace then
86     maybe attaches a debugger.
87
88     """
89     msg = f'Unhandled top level exception {exc_type}'
90     logger.exception(msg)
91     print(msg, file=sys.stderr)
92     if issubclass(exc_type, KeyboardInterrupt):
93         sys.__excepthook__(exc_type, exc_value, exc_tb)
94         return
95     else:
96         if not sys.stderr.isatty() or not sys.stdin.isatty():
97             # stdin or stderr is redirected, just do the normal thing
98             ORIGINAL_EXCEPTION_HOOK(exc_type, exc_value, exc_tb)
99         else:
100             # a terminal is attached and stderr is not redirected, maybe debug.
101             import traceback
102
103             traceback.print_exception(exc_type, exc_value, exc_tb)
104             if config.config['debug_unhandled_exceptions']:
105                 import pdb
106
107                 logger.info("Invoking the debugger...")
108                 pdb.pm()
109             else:
110                 ORIGINAL_EXCEPTION_HOOK(exc_type, exc_value, exc_tb)
111
112
113 class ImportInterceptor(importlib.abc.MetaPathFinder):
114     """An interceptor that always allows module load events but dumps a
115     record into the log and onto stdout when modules are loaded and
116     produces an audit of who imported what at the end of the run.  It
117     can't see any load events that happen before it, though, so move
118     bootstrap up in your __main__'s import list just temporarily to
119     get a good view.
120
121     """
122
123     def __init__(self):
124         import collect.trie
125
126         self.module_by_filename_cache = {}
127         self.repopulate_modules_by_filename()
128         self.tree = collect.trie.Trie()
129         self.tree_node_by_module = {}
130
131     def repopulate_modules_by_filename(self):
132         self.module_by_filename_cache.clear()
133         for _, mod in sys.modules.copy().items():  # copy here because modules is volatile
134             if hasattr(mod, '__file__'):
135                 fname = getattr(mod, '__file__')
136             else:
137                 fname = 'unknown'
138             self.module_by_filename_cache[fname] = mod
139
140     @staticmethod
141     def should_ignore_filename(filename: str) -> bool:
142         return 'importlib' in filename or 'six.py' in filename
143
144     def find_module(self, fullname, path):
145         raise Exception("This method has been deprecated since Python 3.4, please upgrade.")
146
147     def find_spec(self, loaded_module, path=None, _=None):
148         s = stack()
149         for x in range(3, len(s)):
150             filename = s[x].filename
151             if ImportInterceptor.should_ignore_filename(filename):
152                 continue
153
154             loading_function = s[x].function
155             if filename in self.module_by_filename_cache:
156                 loading_module = self.module_by_filename_cache[filename]
157             else:
158                 self.repopulate_modules_by_filename()
159                 loading_module = self.module_by_filename_cache.get(filename, 'unknown')
160
161             path = self.tree_node_by_module.get(loading_module, [])
162             path.extend([loaded_module])
163             self.tree.insert(path)
164             self.tree_node_by_module[loading_module] = path
165
166             msg = f'*** Import {loaded_module} from {filename}:{s[x].lineno} in {loading_module}::{loading_function}'
167             logger.debug(msg)
168             print(msg)
169             return
170         msg = f'*** Import {loaded_module} from ?????'
171         logger.debug(msg)
172         print(msg)
173
174     def invalidate_caches(self):
175         pass
176
177     def find_importer(self, module: str):
178         if module in self.tree_node_by_module:
179             node = self.tree_node_by_module[module]
180             return node
181         return []
182
183
184 # Audit import events?  Note: this runs early in the lifetime of the
185 # process (assuming that import bootstrap happens early); config has
186 # (probably) not yet been loaded or parsed the commandline.  Also,
187 # some things have probably already been imported while we weren't
188 # watching so this information may be incomplete.
189 #
190 # Also note: move bootstrap up in the global import list to catch
191 # more import events and have a more complete record.
192 IMPORT_INTERCEPTOR = None
193 for arg in sys.argv:
194     if arg == '--audit_import_events':
195         IMPORT_INTERCEPTOR = ImportInterceptor()
196         sys.meta_path.insert(0, IMPORT_INTERCEPTOR)
197
198
199 def dump_all_objects() -> None:
200     messages = {}
201     all_modules = sys.modules
202     for obj in object.__subclasses__():
203         if not hasattr(obj, '__name__'):
204             continue
205         klass = obj.__name__
206         if not hasattr(obj, '__module__'):
207             continue
208         class_mod_name = obj.__module__
209         if class_mod_name in all_modules:
210             mod = all_modules[class_mod_name]
211             if not hasattr(mod, '__name__'):
212                 mod_name = class_mod_name
213             else:
214                 mod_name = mod.__name__
215             if hasattr(mod, '__file__'):
216                 mod_file = mod.__file__
217             else:
218                 mod_file = 'unknown'
219             if IMPORT_INTERCEPTOR is not None:
220                 import_path = IMPORT_INTERCEPTOR.find_importer(mod_name)
221             else:
222                 import_path = 'unknown'
223             msg = f'{class_mod_name}::{klass} ({mod_file})'
224             if import_path != 'unknown' and len(import_path) > 0:
225                 msg += f' imported by {import_path}'
226             messages[f'{class_mod_name}::{klass}'] = msg
227     for x in sorted(messages.keys()):
228         logger.debug(messages[x])
229         print(messages[x])
230
231
232 def initialize(entry_point):
233     """
234     Remember to initialize config, initialize logging, set/log a random
235     seed, etc... before running main.
236
237     """
238
239     @functools.wraps(entry_point)
240     def initialize_wrapper(*args, **kwargs):
241         # Hook top level unhandled exceptions, maybe invoke debugger.
242         if sys.excepthook == sys.__excepthook__:
243             sys.excepthook = handle_uncaught_exception
244
245         # Try to figure out the name of the program entry point.  Then
246         # parse configuration (based on cmdline flags, environment vars
247         # etc...)
248         entry_filename = None
249         entry_descr = None
250         try:
251             entry_filename = entry_point.__code__.co_filename
252             entry_descr = entry_point.__code__.__repr__()
253         except Exception:
254             if '__globals__' in entry_point.__dict__ and '__file__' in entry_point.__globals__:
255                 entry_filename = entry_point.__globals__['__file__']
256                 entry_descr = entry_filename
257         config.parse(entry_filename)
258
259         if config.config['trace_memory']:
260             import tracemalloc
261
262             tracemalloc.start()
263
264         # Initialize logging... and log some remembered messages from
265         # config module.
266         logging_utils.initialize_logging(logging.getLogger())
267         config.late_logging()
268
269         # Maybe log some info about the python interpreter itself.
270         logger.debug(
271             'Platform: %s, maxint=0x%x, byteorder=%s', sys.platform, sys.maxsize, sys.byteorder
272         )
273         logger.debug('Python interpreter version: %s', sys.version)
274         logger.debug('Python implementation: %s', sys.implementation)
275         logger.debug('Python C API version: %s', sys.api_version)
276         if __debug__:
277             logger.debug('Python interpreter running in __debug__ mode.')
278         else:
279             logger.debug('Python interpreter running in optimized mode.')
280         logger.debug('Python path: %s', sys.path)
281
282         # Log something about the site_config, many things use it.
283         import site_config
284
285         logger.debug('Global site_config: %s', site_config.get_config())
286
287         # Allow programs that don't bother to override the random seed
288         # to be replayed via the commandline.
289         import random
290
291         random_seed = config.config['set_random_seed']
292         if random_seed is not None:
293             random_seed = random_seed[0]
294         else:
295             random_seed = int.from_bytes(os.urandom(4), 'little')
296
297         if config.config['show_random_seed']:
298             msg = f'Global random seed is: {random_seed}'
299             logger.debug(msg)
300             print(msg)
301         random.seed(random_seed)
302
303         # Do it, invoke the user's code.  Pay attention to how long it takes.
304         logger.debug('Starting %s (program entry point)', entry_descr)
305         ret = None
306         import stopwatch
307
308         if config.config['run_profiler']:
309             import cProfile
310             from pstats import SortKey
311
312             with stopwatch.Timer() as t:
313                 cProfile.runctx(
314                     "ret = entry_point(*args, **kwargs)",
315                     globals(),
316                     locals(),
317                     None,
318                     SortKey.CUMULATIVE,
319                 )
320         else:
321             with stopwatch.Timer() as t:
322                 ret = entry_point(*args, **kwargs)
323
324         logger.debug('%s (program entry point) returned %s.', entry_descr, ret)
325
326         if config.config['trace_memory']:
327             snapshot = tracemalloc.take_snapshot()
328             top_stats = snapshot.statistics('lineno')
329             print()
330             print("--trace_memory's top 10 memory using files:")
331             for stat in top_stats[:10]:
332                 print(stat)
333
334         if config.config['dump_all_objects']:
335             dump_all_objects()
336
337         if config.config['audit_import_events']:
338             if IMPORT_INTERCEPTOR is not None:
339                 print(IMPORT_INTERCEPTOR.tree)
340
341         walltime = t()
342         (utime, stime, cutime, cstime, elapsed_time) = os.times()
343         logger.debug(
344             '\n'
345             'user: %.4fs\n'
346             'system: %.4fs\n'
347             'child user: %.4fs\n'
348             'child system: %.4fs\n'
349             'machine uptime: %.4fs\n'
350             'walltime: %.4fs',
351             utime,
352             stime,
353             cutime,
354             cstime,
355             elapsed_time,
356             walltime,
357         )
358
359         # If it doesn't return cleanly, call attention to the return value.
360         if ret is not None and ret != 0:
361             logger.error('Exit %s', ret)
362         else:
363             logger.debug('Exit %s', ret)
364         sys.exit(ret)
365
366     return initialize_wrapper