More spring cleaning.
[python_utils.git] / bootstrap.py
1 #!/usr/bin/env python3
2
3 """This is a module for wrapping around python programs and doing some
4 minor setup and tear down work for them.  With it, you can break into
5 pdb on unhandled top level exceptions, profile your code by passing a
6 commandline argument in, audit module import events, examine where
7 memory is being used in your program, and so on."""
8
9 import functools
10 import importlib
11 import logging
12 import os
13 import sys
14 from inspect import stack
15
16 import config
17 import logging_utils
18 from argparse_utils import ActionNoYes
19
20 # This module is commonly used by others in here and should avoid
21 # taking any unnecessary dependencies back on them.
22
23
24 logger = logging.getLogger(__name__)
25
26 cfg = config.add_commandline_args(
27     f'Bootstrap ({__file__})',
28     'Args related to python program bootstrapper and Swiss army knife',
29 )
30 cfg.add_argument(
31     '--debug_unhandled_exceptions',
32     action=ActionNoYes,
33     default=False,
34     help='Break into pdb on top level unhandled exceptions.',
35 )
36 cfg.add_argument(
37     '--show_random_seed',
38     action=ActionNoYes,
39     default=False,
40     help='Should we display (and log.debug) the global random seed?',
41 )
42 cfg.add_argument(
43     '--set_random_seed',
44     type=int,
45     nargs=1,
46     default=None,
47     metavar='SEED_INT',
48     help='Override the global random seed with a particular number.',
49 )
50 cfg.add_argument(
51     '--dump_all_objects',
52     action=ActionNoYes,
53     default=False,
54     help='Should we dump the Python import tree before main?',
55 )
56 cfg.add_argument(
57     '--audit_import_events',
58     action=ActionNoYes,
59     default=False,
60     help='Should we audit all import events?',
61 )
62 cfg.add_argument(
63     '--run_profiler',
64     action=ActionNoYes,
65     default=False,
66     help='Should we run cProfile on this code?',
67 )
68 cfg.add_argument(
69     '--trace_memory',
70     action=ActionNoYes,
71     default=False,
72     help='Should we record/report on memory utilization?',
73 )
74
75 ORIGINAL_EXCEPTION_HOOK = sys.excepthook
76
77
78 def handle_uncaught_exception(exc_type, exc_value, exc_tb):
79     """
80     Top-level exception handler for exceptions that make it past any exception
81     handlers in the python code being run.  Logs the error and stacktrace then
82     maybe attaches a debugger.
83
84     """
85     msg = f'Unhandled top level exception {exc_type}'
86     logger.exception(msg)
87     print(msg, file=sys.stderr)
88     if issubclass(exc_type, KeyboardInterrupt):
89         sys.__excepthook__(exc_type, exc_value, exc_tb)
90         return
91     else:
92         if not sys.stderr.isatty() or not sys.stdin.isatty():
93             # stdin or stderr is redirected, just do the normal thing
94             ORIGINAL_EXCEPTION_HOOK(exc_type, exc_value, exc_tb)
95         else:
96             # a terminal is attached and stderr is not redirected, maybe debug.
97             import traceback
98
99             traceback.print_exception(exc_type, exc_value, exc_tb)
100             if config.config['debug_unhandled_exceptions']:
101                 import pdb
102
103                 logger.info("Invoking the debugger...")
104                 pdb.pm()
105             else:
106                 ORIGINAL_EXCEPTION_HOOK(exc_type, exc_value, exc_tb)
107
108
109 class ImportInterceptor(importlib.abc.MetaPathFinder):
110     """An interceptor that always allows module load events but dumps a
111     record into the log and onto stdout when modules are loaded and
112     produces an audit of who imported what at the end of the run.  It
113     can't see any load events that happen before it, though, so move
114     bootstrap up in your __main__'s import list just temporarily to
115     get a good view.
116
117     """
118
119     def __init__(self):
120         import collect.trie
121
122         self.module_by_filename_cache = {}
123         self.repopulate_modules_by_filename()
124         self.tree = collect.trie.Trie()
125         self.tree_node_by_module = {}
126
127     def repopulate_modules_by_filename(self):
128         self.module_by_filename_cache.clear()
129         for _, mod in sys.modules.copy().items():  # copy here because modules is volatile
130             if hasattr(mod, '__file__'):
131                 fname = getattr(mod, '__file__')
132             else:
133                 fname = 'unknown'
134             self.module_by_filename_cache[fname] = mod
135
136     @staticmethod
137     def should_ignore_filename(filename: str) -> bool:
138         return 'importlib' in filename or 'six.py' in filename
139
140     def find_module(self, fullname, path):
141         raise Exception("This method has been deprecated since Python 3.4, please upgrade.")
142
143     def find_spec(self, loaded_module, path=None, _=None):
144         s = stack()
145         for x in range(3, len(s)):
146             filename = s[x].filename
147             if ImportInterceptor.should_ignore_filename(filename):
148                 continue
149
150             loading_function = s[x].function
151             if filename in self.module_by_filename_cache:
152                 loading_module = self.module_by_filename_cache[filename]
153             else:
154                 self.repopulate_modules_by_filename()
155                 loading_module = self.module_by_filename_cache.get(filename, 'unknown')
156
157             path = self.tree_node_by_module.get(loading_module, [])
158             path.extend([loaded_module])
159             self.tree.insert(path)
160             self.tree_node_by_module[loading_module] = path
161
162             msg = f'*** Import {loaded_module} from {filename}:{s[x].lineno} in {loading_module}::{loading_function}'
163             logger.debug(msg)
164             print(msg)
165             return
166         msg = f'*** Import {loaded_module} from ?????'
167         logger.debug(msg)
168         print(msg)
169
170     def invalidate_caches(self):
171         pass
172
173     def find_importer(self, module: str):
174         if module in self.tree_node_by_module:
175             node = self.tree_node_by_module[module]
176             return node
177         return []
178
179
180 # Audit import events?  Note: this runs early in the lifetime of the
181 # process (assuming that import bootstrap happens early); config has
182 # (probably) not yet been loaded or parsed the commandline.  Also,
183 # some things have probably already been imported while we weren't
184 # watching so this information may be incomplete.
185 #
186 # Also note: move bootstrap up in the global import list to catch
187 # more import events and have a more complete record.
188 IMPORT_INTERCEPTOR = None
189 for arg in sys.argv:
190     if arg == '--audit_import_events':
191         IMPORT_INTERCEPTOR = ImportInterceptor()
192         sys.meta_path.insert(0, IMPORT_INTERCEPTOR)
193
194
195 def dump_all_objects() -> None:
196     messages = {}
197     all_modules = sys.modules
198     for obj in object.__subclasses__():
199         if not hasattr(obj, '__name__'):
200             continue
201         klass = obj.__name__
202         if not hasattr(obj, '__module__'):
203             continue
204         class_mod_name = obj.__module__
205         if class_mod_name in all_modules:
206             mod = all_modules[class_mod_name]
207             if not hasattr(mod, '__name__'):
208                 mod_name = class_mod_name
209             else:
210                 mod_name = mod.__name__
211             if hasattr(mod, '__file__'):
212                 mod_file = mod.__file__
213             else:
214                 mod_file = 'unknown'
215             if IMPORT_INTERCEPTOR is not None:
216                 import_path = IMPORT_INTERCEPTOR.find_importer(mod_name)
217             else:
218                 import_path = 'unknown'
219             msg = f'{class_mod_name}::{klass} ({mod_file})'
220             if import_path != 'unknown' and len(import_path) > 0:
221                 msg += f' imported by {import_path}'
222             messages[f'{class_mod_name}::{klass}'] = msg
223     for x in sorted(messages.keys()):
224         logger.debug(messages[x])
225         print(messages[x])
226
227
228 def initialize(entry_point):
229     """
230     Remember to initialize config, initialize logging, set/log a random
231     seed, etc... before running main.
232
233     """
234
235     @functools.wraps(entry_point)
236     def initialize_wrapper(*args, **kwargs):
237         # Hook top level unhandled exceptions, maybe invoke debugger.
238         if sys.excepthook == sys.__excepthook__:
239             sys.excepthook = handle_uncaught_exception
240
241         # Try to figure out the name of the program entry point.  Then
242         # parse configuration (based on cmdline flags, environment vars
243         # etc...)
244         if '__globals__' in entry_point.__dict__ and '__file__' in entry_point.__globals__:
245             config.parse(entry_point.__globals__['__file__'])
246         else:
247             config.parse(None)
248
249         if config.config['trace_memory']:
250             import tracemalloc
251
252             tracemalloc.start()
253
254         # Initialize logging... and log some remembered messages from
255         # config module.
256         logging_utils.initialize_logging(logging.getLogger())
257         config.late_logging()
258
259         # Maybe log some info about the python interpreter itself.
260         logger.debug(
261             'Platform: %s, maxint=0x%x, byteorder=%s', sys.platform, sys.maxsize, sys.byteorder
262         )
263         logger.debug('Python interpreter version: %s', sys.version)
264         logger.debug('Python implementation: %s', sys.implementation)
265         logger.debug('Python C API version: %s', sys.api_version)
266         logger.debug('Python path: %s', sys.path)
267
268         # Log something about the site_config, many things use it.
269         import site_config
270
271         logger.debug('Global site_config: %s', site_config.get_config())
272
273         # Allow programs that don't bother to override the random seed
274         # to be replayed via the commandline.
275         import random
276
277         random_seed = config.config['set_random_seed']
278         if random_seed is not None:
279             random_seed = random_seed[0]
280         else:
281             random_seed = int.from_bytes(os.urandom(4), 'little')
282
283         if config.config['show_random_seed']:
284             msg = f'Global random seed is: {random_seed}'
285             logger.debug(msg)
286             print(msg)
287         random.seed(random_seed)
288
289         # Do it, invoke the user's code.  Pay attention to how long it takes.
290         logger.debug('Starting %s (program entry point)', entry_point.__name__)
291         ret = None
292         import stopwatch
293
294         if config.config['run_profiler']:
295             import cProfile
296             from pstats import SortKey
297
298             with stopwatch.Timer() as t:
299                 cProfile.runctx(
300                     "ret = entry_point(*args, **kwargs)",
301                     globals(),
302                     locals(),
303                     None,
304                     SortKey.CUMULATIVE,
305                 )
306         else:
307             with stopwatch.Timer() as t:
308                 ret = entry_point(*args, **kwargs)
309
310         logger.debug('%s (program entry point) returned %s.', entry_point.__name__, ret)
311
312         if config.config['trace_memory']:
313             snapshot = tracemalloc.take_snapshot()
314             top_stats = snapshot.statistics('lineno')
315             print()
316             print("--trace_memory's top 10 memory using files:")
317             for stat in top_stats[:10]:
318                 print(stat)
319
320         if config.config['dump_all_objects']:
321             dump_all_objects()
322
323         if config.config['audit_import_events']:
324             if IMPORT_INTERCEPTOR is not None:
325                 print(IMPORT_INTERCEPTOR.tree)
326
327         walltime = t()
328         (utime, stime, cutime, cstime, elapsed_time) = os.times()
329         logger.debug(
330             '\n'
331             'user: %.4fs\n'
332             'system: %.4fs\n'
333             'child user: %.4fs\n'
334             'child system: %.4fs\n'
335             'machine uptime: %.4fs\n'
336             'walltime: %.4fs',
337             utime,
338             stime,
339             cutime,
340             cstime,
341             elapsed_time,
342             walltime,
343         )
344
345         # If it doesn't return cleanly, call attention to the return value.
346         if ret is not None and ret != 0:
347             logger.error('Exit %s', ret)
348         else:
349             logger.debug('Exit %s', ret)
350         sys.exit(ret)
351
352     return initialize_wrapper