Oops.
[python_utils.git] / bootstrap.py
1 #!/usr/bin/env python3
2
3 """This is a module for wrapping around python programs and doing some
4 minor setup and tear down work for them.  With it, you can break into
5 pdb on unhandled top level exceptions, profile your code by passing a
6 commandline argument in, audit module import events, examine where
7 memory is being used in your program, and so on."""
8
9 import functools
10 import importlib
11 import logging
12 import os
13 import sys
14 from inspect import stack
15
16 import config
17 import logging_utils
18 from argparse_utils import ActionNoYes
19
20 # This module is commonly used by others in here and should avoid
21 # taking any unnecessary dependencies back on them.
22
23
24 logger = logging.getLogger(__name__)
25
26 cfg = config.add_commandline_args(
27     f'Bootstrap ({__file__})',
28     'Args related to python program bootstrapper and Swiss army knife',
29 )
30 cfg.add_argument(
31     '--debug_unhandled_exceptions',
32     action=ActionNoYes,
33     default=False,
34     help='Break into pdb on top level unhandled exceptions.',
35 )
36 cfg.add_argument(
37     '--show_random_seed',
38     action=ActionNoYes,
39     default=False,
40     help='Should we display (and log.debug) the global random seed?',
41 )
42 cfg.add_argument(
43     '--set_random_seed',
44     type=int,
45     nargs=1,
46     default=None,
47     metavar='SEED_INT',
48     help='Override the global random seed with a particular number.',
49 )
50 cfg.add_argument(
51     '--dump_all_objects',
52     action=ActionNoYes,
53     default=False,
54     help='Should we dump the Python import tree before main?',
55 )
56 cfg.add_argument(
57     '--audit_import_events',
58     action=ActionNoYes,
59     default=False,
60     help='Should we audit all import events?',
61 )
62 cfg.add_argument(
63     '--run_profiler',
64     action=ActionNoYes,
65     default=False,
66     help='Should we run cProfile on this code?',
67 )
68 cfg.add_argument(
69     '--trace_memory',
70     action=ActionNoYes,
71     default=False,
72     help='Should we record/report on memory utilization?',
73 )
74
75 ORIGINAL_EXCEPTION_HOOK = sys.excepthook
76
77
78 def handle_uncaught_exception(exc_type, exc_value, exc_tb):
79     """
80     Top-level exception handler for exceptions that make it past any exception
81     handlers in the python code being run.  Logs the error and stacktrace then
82     maybe attaches a debugger.
83
84     """
85     msg = f'Unhandled top level exception {exc_type}'
86     logger.exception(msg)
87     print(msg, file=sys.stderr)
88     if issubclass(exc_type, KeyboardInterrupt):
89         sys.__excepthook__(exc_type, exc_value, exc_tb)
90         return
91     else:
92         if not sys.stderr.isatty() or not sys.stdin.isatty():
93             # stdin or stderr is redirected, just do the normal thing
94             ORIGINAL_EXCEPTION_HOOK(exc_type, exc_value, exc_tb)
95         else:
96             # a terminal is attached and stderr is not redirected, maybe debug.
97             import traceback
98
99             traceback.print_exception(exc_type, exc_value, exc_tb)
100             if config.config['debug_unhandled_exceptions']:
101                 import pdb
102
103                 logger.info("Invoking the debugger...")
104                 pdb.pm()
105             else:
106                 ORIGINAL_EXCEPTION_HOOK(exc_type, exc_value, exc_tb)
107
108
109 class ImportInterceptor(importlib.abc.MetaPathFinder):
110     """An interceptor that always allows module load events but dumps a
111     record into the log and onto stdout when modules are loaded and
112     produces an audit of who imported what at the end of the run.  It
113     can't see any load events that happen before it, though, so move
114     bootstrap up in your __main__'s import list just temporarily to
115     get a good view.
116
117     """
118
119     def __init__(self):
120         import collect.trie
121
122         self.module_by_filename_cache = {}
123         self.repopulate_modules_by_filename()
124         self.tree = collect.trie.Trie()
125         self.tree_node_by_module = {}
126
127     def repopulate_modules_by_filename(self):
128         self.module_by_filename_cache.clear()
129         for _, mod in sys.modules.copy().items():  # copy here because modules is volatile
130             if hasattr(mod, '__file__'):
131                 fname = getattr(mod, '__file__')
132             else:
133                 fname = 'unknown'
134             self.module_by_filename_cache[fname] = mod
135
136     @staticmethod
137     def should_ignore_filename(filename: str) -> bool:
138         return 'importlib' in filename or 'six.py' in filename
139
140     def find_module(self, fullname, path):
141         raise Exception("This method has been deprecated since Python 3.4, please upgrade.")
142
143     def find_spec(self, loaded_module, path=None, _=None):
144         s = stack()
145         for x in range(3, len(s)):
146             filename = s[x].filename
147             if ImportInterceptor.should_ignore_filename(filename):
148                 continue
149
150             loading_function = s[x].function
151             if filename in self.module_by_filename_cache:
152                 loading_module = self.module_by_filename_cache[filename]
153             else:
154                 self.repopulate_modules_by_filename()
155                 loading_module = self.module_by_filename_cache.get(filename, 'unknown')
156
157             path = self.tree_node_by_module.get(loading_module, [])
158             path.extend([loaded_module])
159             self.tree.insert(path)
160             self.tree_node_by_module[loading_module] = path
161
162             msg = f'*** Import {loaded_module} from {filename}:{s[x].lineno} in {loading_module}::{loading_function}'
163             logger.debug(msg)
164             print(msg)
165             return
166         msg = f'*** Import {loaded_module} from ?????'
167         logger.debug(msg)
168         print(msg)
169
170     def invalidate_caches(self):
171         pass
172
173     def find_importer(self, module: str):
174         if module in self.tree_node_by_module:
175             node = self.tree_node_by_module[module]
176             return node
177         return []
178
179
180 # Audit import events?  Note: this runs early in the lifetime of the
181 # process (assuming that import bootstrap happens early); config has
182 # (probably) not yet been loaded or parsed the commandline.  Also,
183 # some things have probably already been imported while we weren't
184 # watching so this information may be incomplete.
185 #
186 # Also note: move bootstrap up in the global import list to catch
187 # more import events and have a more complete record.
188 IMPORT_INTERCEPTOR = None
189 for arg in sys.argv:
190     if arg == '--audit_import_events':
191         IMPORT_INTERCEPTOR = ImportInterceptor()
192         sys.meta_path.insert(0, IMPORT_INTERCEPTOR)
193
194
195 def dump_all_objects() -> None:
196     messages = {}
197     all_modules = sys.modules
198     for obj in object.__subclasses__():
199         if not hasattr(obj, '__name__'):
200             continue
201         klass = obj.__name__
202         if not hasattr(obj, '__module__'):
203             continue
204         class_mod_name = obj.__module__
205         if class_mod_name in all_modules:
206             mod = all_modules[class_mod_name]
207             if not hasattr(mod, '__name__'):
208                 mod_name = class_mod_name
209             else:
210                 mod_name = mod.__name__
211             if hasattr(mod, '__file__'):
212                 mod_file = mod.__file__
213             else:
214                 mod_file = 'unknown'
215             if IMPORT_INTERCEPTOR is not None:
216                 import_path = IMPORT_INTERCEPTOR.find_importer(mod_name)
217             else:
218                 import_path = 'unknown'
219             msg = f'{class_mod_name}::{klass} ({mod_file})'
220             if import_path != 'unknown' and len(import_path) > 0:
221                 msg += f' imported by {import_path}'
222             messages[f'{class_mod_name}::{klass}'] = msg
223     for x in sorted(messages.keys()):
224         logger.debug(messages[x])
225         print(messages[x])
226
227
228 def initialize(entry_point):
229     """
230     Remember to initialize config, initialize logging, set/log a random
231     seed, etc... before running main.
232
233     """
234
235     @functools.wraps(entry_point)
236     def initialize_wrapper(*args, **kwargs):
237         # Hook top level unhandled exceptions, maybe invoke debugger.
238         if sys.excepthook == sys.__excepthook__:
239             sys.excepthook = handle_uncaught_exception
240
241         # Try to figure out the name of the program entry point.  Then
242         # parse configuration (based on cmdline flags, environment vars
243         # etc...)
244         entry_filename = None
245         entry_descr = None
246         try:
247             entry_filename = entry_point.__code__.co_filename
248             entry_descr = entry_point.__code__.__repr__()
249         except Exception:
250             if '__globals__' in entry_point.__dict__ and '__file__' in entry_point.__globals__:
251                 entry_filename = entry_point.__globals__['__file__']
252                 entry_descr = entry_filename
253         config.parse(entry_filename)
254
255         if config.config['trace_memory']:
256             import tracemalloc
257
258             tracemalloc.start()
259
260         # Initialize logging... and log some remembered messages from
261         # config module.
262         logging_utils.initialize_logging(logging.getLogger())
263         config.late_logging()
264
265         # Maybe log some info about the python interpreter itself.
266         logger.debug(
267             'Platform: %s, maxint=0x%x, byteorder=%s', sys.platform, sys.maxsize, sys.byteorder
268         )
269         logger.debug('Python interpreter version: %s', sys.version)
270         logger.debug('Python implementation: %s', sys.implementation)
271         logger.debug('Python C API version: %s', sys.api_version)
272         logger.debug('Python path: %s', sys.path)
273
274         # Log something about the site_config, many things use it.
275         import site_config
276
277         logger.debug('Global site_config: %s', site_config.get_config())
278
279         # Allow programs that don't bother to override the random seed
280         # to be replayed via the commandline.
281         import random
282
283         random_seed = config.config['set_random_seed']
284         if random_seed is not None:
285             random_seed = random_seed[0]
286         else:
287             random_seed = int.from_bytes(os.urandom(4), 'little')
288
289         if config.config['show_random_seed']:
290             msg = f'Global random seed is: {random_seed}'
291             logger.debug(msg)
292             print(msg)
293         random.seed(random_seed)
294
295         # Do it, invoke the user's code.  Pay attention to how long it takes.
296         logger.debug('Starting %s (program entry point)', entry_descr)
297         ret = None
298         import stopwatch
299
300         if config.config['run_profiler']:
301             import cProfile
302             from pstats import SortKey
303
304             with stopwatch.Timer() as t:
305                 cProfile.runctx(
306                     "ret = entry_point(*args, **kwargs)",
307                     globals(),
308                     locals(),
309                     None,
310                     SortKey.CUMULATIVE,
311                 )
312         else:
313             with stopwatch.Timer() as t:
314                 ret = entry_point(*args, **kwargs)
315
316         logger.debug('%s (program entry point) returned %s.', entry_descr, ret)
317
318         if config.config['trace_memory']:
319             snapshot = tracemalloc.take_snapshot()
320             top_stats = snapshot.statistics('lineno')
321             print()
322             print("--trace_memory's top 10 memory using files:")
323             for stat in top_stats[:10]:
324                 print(stat)
325
326         if config.config['dump_all_objects']:
327             dump_all_objects()
328
329         if config.config['audit_import_events']:
330             if IMPORT_INTERCEPTOR is not None:
331                 print(IMPORT_INTERCEPTOR.tree)
332
333         walltime = t()
334         (utime, stime, cutime, cstime, elapsed_time) = os.times()
335         logger.debug(
336             '\n'
337             'user: %.4fs\n'
338             'system: %.4fs\n'
339             'child user: %.4fs\n'
340             'child system: %.4fs\n'
341             'machine uptime: %.4fs\n'
342             'walltime: %.4fs',
343             utime,
344             stime,
345             cutime,
346             cstime,
347             elapsed_time,
348             walltime,
349         )
350
351         # If it doesn't return cleanly, call attention to the return value.
352         if ret is not None and ret != 0:
353             logger.error('Exit %s', ret)
354         else:
355             logger.debug('Exit %s', ret)
356         sys.exit(ret)
357
358     return initialize_wrapper