Easier and more self documenting patterns for loading/saving Persistent
[python_utils.git] / bootstrap.py
1 #!/usr/bin/env python3
2
3 # © Copyright 2021-2022, Scott Gasch
4
5 """This is a module for wrapping around python programs and doing some
6 minor setup and tear down work for them.  With it, you will get:
7
8 * The ability to break into pdb on unhandled exceptions,
9 * automatic support for :file:`config.py` (argument parsing)
10 * automatic logging support for :file:`logging.py`,
11 * the ability to enable code profiling,
12 * the ability to enable module import auditing,
13 * optional memory profiling for your program,
14 * ability to set random seed via commandline,
15 * automatic program timing and reporting,
16 * more verbose error handling and reporting,
17
18 Most of these are enabled and/or configured via commandline flags
19 (see below).
20
21 """
22
23 import functools
24 import importlib
25 import importlib.abc
26 import logging
27 import os
28 import sys
29 from inspect import stack
30
31 import config
32 import logging_utils
33 from argparse_utils import ActionNoYes
34
35 # This module is commonly used by others in here and should avoid
36 # taking any unnecessary dependencies back on them.
37
38
39 logger = logging.getLogger(__name__)
40
41 cfg = config.add_commandline_args(
42     f'Bootstrap ({__file__})',
43     'Args related to python program bootstrapper and Swiss army knife',
44 )
45 cfg.add_argument(
46     '--debug_unhandled_exceptions',
47     action=ActionNoYes,
48     default=False,
49     help='Break into pdb on top level unhandled exceptions.',
50 )
51 cfg.add_argument(
52     '--show_random_seed',
53     action=ActionNoYes,
54     default=False,
55     help='Should we display (and log.debug) the global random seed?',
56 )
57 cfg.add_argument(
58     '--set_random_seed',
59     type=int,
60     nargs=1,
61     default=None,
62     metavar='SEED_INT',
63     help='Override the global random seed with a particular number.',
64 )
65 cfg.add_argument(
66     '--dump_all_objects',
67     action=ActionNoYes,
68     default=False,
69     help='Should we dump the Python import tree before main?',
70 )
71 cfg.add_argument(
72     '--audit_import_events',
73     action=ActionNoYes,
74     default=False,
75     help='Should we audit all import events?',
76 )
77 cfg.add_argument(
78     '--run_profiler',
79     action=ActionNoYes,
80     default=False,
81     help='Should we run cProfile on this code?',
82 )
83 cfg.add_argument(
84     '--trace_memory',
85     action=ActionNoYes,
86     default=False,
87     help='Should we record/report on memory utilization?',
88 )
89
90 ORIGINAL_EXCEPTION_HOOK = sys.excepthook
91
92
93 def handle_uncaught_exception(exc_type, exc_value, exc_tb):
94     """
95     Top-level exception handler for exceptions that make it past any exception
96     handlers in the python code being run.  Logs the error and stacktrace then
97     maybe attaches a debugger.
98
99     """
100     msg = f'Unhandled top level exception {exc_type}'
101     logger.exception(msg)
102     print(msg, file=sys.stderr)
103     if issubclass(exc_type, KeyboardInterrupt):
104         sys.__excepthook__(exc_type, exc_value, exc_tb)
105         return
106     else:
107         import io
108         import traceback
109
110         tb_output = io.StringIO()
111         traceback.print_tb(exc_tb, None, tb_output)
112         print(tb_output.getvalue(), file=sys.stderr)
113         logger.error(tb_output.getvalue())
114         tb_output.close()
115
116         # stdin or stderr is redirected, just do the normal thing
117         if not sys.stderr.isatty() or not sys.stdin.isatty():
118             ORIGINAL_EXCEPTION_HOOK(exc_type, exc_value, exc_tb)
119
120         else:  # a terminal is attached and stderr isn't redirected, maybe debug.
121             if config.config['debug_unhandled_exceptions']:
122                 logger.info("Invoking the debugger...")
123                 import pdb
124
125                 pdb.pm()
126             else:
127                 ORIGINAL_EXCEPTION_HOOK(exc_type, exc_value, exc_tb)
128
129
130 class ImportInterceptor(importlib.abc.MetaPathFinder):
131     """An interceptor that always allows module load events but dumps a
132     record into the log and onto stdout when modules are loaded and
133     produces an audit of who imported what at the end of the run.  It
134     can't see any load events that happen before it, though, so move
135     bootstrap up in your __main__'s import list just temporarily to
136     get a good view.
137
138     """
139
140     def __init__(self):
141         import collect.trie
142
143         self.module_by_filename_cache = {}
144         self.repopulate_modules_by_filename()
145         self.tree = collect.trie.Trie()
146         self.tree_node_by_module = {}
147
148     def repopulate_modules_by_filename(self):
149         self.module_by_filename_cache.clear()
150         for _, mod in sys.modules.copy().items():  # copy here because modules is volatile
151             if hasattr(mod, '__file__'):
152                 fname = getattr(mod, '__file__')
153             else:
154                 fname = 'unknown'
155             self.module_by_filename_cache[fname] = mod
156
157     @staticmethod
158     def should_ignore_filename(filename: str) -> bool:
159         return 'importlib' in filename or 'six.py' in filename
160
161     def find_module(self, fullname, path):
162         raise Exception("This method has been deprecated since Python 3.4, please upgrade.")
163
164     def find_spec(self, loaded_module, path=None, _=None):
165         s = stack()
166         for x in range(3, len(s)):
167             filename = s[x].filename
168             if ImportInterceptor.should_ignore_filename(filename):
169                 continue
170
171             loading_function = s[x].function
172             if filename in self.module_by_filename_cache:
173                 loading_module = self.module_by_filename_cache[filename]
174             else:
175                 self.repopulate_modules_by_filename()
176                 loading_module = self.module_by_filename_cache.get(filename, 'unknown')
177
178             path = self.tree_node_by_module.get(loading_module, [])
179             path.extend([loaded_module])
180             self.tree.insert(path)
181             self.tree_node_by_module[loading_module] = path
182
183             msg = f'*** Import {loaded_module} from {filename}:{s[x].lineno} in {loading_module}::{loading_function}'
184             logger.debug(msg)
185             print(msg)
186             return
187         msg = f'*** Import {loaded_module} from ?????'
188         logger.debug(msg)
189         print(msg)
190
191     def invalidate_caches(self):
192         pass
193
194     def find_importer(self, module: str):
195         if module in self.tree_node_by_module:
196             node = self.tree_node_by_module[module]
197             return node
198         return []
199
200
201 # Audit import events?  Note: this runs early in the lifetime of the
202 # process (assuming that import bootstrap happens early); config has
203 # (probably) not yet been loaded or parsed the commandline.  Also,
204 # some things have probably already been imported while we weren't
205 # watching so this information may be incomplete.
206 #
207 # Also note: move bootstrap up in the global import list to catch
208 # more import events and have a more complete record.
209 IMPORT_INTERCEPTOR = None
210 for arg in sys.argv:
211     if arg == '--audit_import_events':
212         IMPORT_INTERCEPTOR = ImportInterceptor()
213         sys.meta_path.insert(0, IMPORT_INTERCEPTOR)
214
215
216 def dump_all_objects() -> None:
217     """Helper code to dump all known python objects."""
218
219     messages = {}
220     all_modules = sys.modules
221     for obj in object.__subclasses__():
222         if not hasattr(obj, '__name__'):
223             continue
224         klass = obj.__name__
225         if not hasattr(obj, '__module__'):
226             continue
227         class_mod_name = obj.__module__
228         if class_mod_name in all_modules:
229             mod = all_modules[class_mod_name]
230             if not hasattr(mod, '__name__'):
231                 mod_name = class_mod_name
232             else:
233                 mod_name = mod.__name__
234             if hasattr(mod, '__file__'):
235                 mod_file = mod.__file__
236             else:
237                 mod_file = 'unknown'
238             if IMPORT_INTERCEPTOR is not None:
239                 import_path = IMPORT_INTERCEPTOR.find_importer(mod_name)
240             else:
241                 import_path = 'unknown'
242             msg = f'{class_mod_name}::{klass} ({mod_file})'
243             if import_path != 'unknown' and len(import_path) > 0:
244                 msg += f' imported by {import_path}'
245             messages[f'{class_mod_name}::{klass}'] = msg
246     for x in sorted(messages.keys()):
247         logger.debug(messages[x])
248         print(messages[x])
249
250
251 def initialize(entry_point):
252     """
253     Remember to initialize config, initialize logging, set/log a random
254     seed, etc... before running main.  If you use this decorator around
255     your main, like this::
256
257         import bootstrap
258
259         @bootstrap.initialize
260         def main():
261             whatever
262
263         if __name__ == '__main__':
264             main()
265
266     You get:
267
268     * The ability to break into pdb on unhandled exceptions,
269     * automatic support for :file:`config.py` (argument parsing)
270     * automatic logging support for :file:`logging.py`,
271     * the ability to enable code profiling,
272     * the ability to enable module import auditing,
273     * optional memory profiling for your program,
274     * ability to set random seed via commandline,
275     * automatic program timing and reporting,
276     * more verbose error handling and reporting,
277
278     Most of these are enabled and/or configured via commandline flags
279     (see below).
280     """
281
282     @functools.wraps(entry_point)
283     def initialize_wrapper(*args, **kwargs):
284         # Hook top level unhandled exceptions, maybe invoke debugger.
285         if sys.excepthook == sys.__excepthook__:
286             sys.excepthook = handle_uncaught_exception
287
288         # Try to figure out the name of the program entry point.  Then
289         # parse configuration (based on cmdline flags, environment vars
290         # etc...)
291         entry_filename = None
292         entry_descr = None
293         try:
294             entry_filename = entry_point.__code__.co_filename
295             entry_descr = entry_point.__code__.__repr__()
296         except Exception:
297             if '__globals__' in entry_point.__dict__ and '__file__' in entry_point.__globals__:
298                 entry_filename = entry_point.__globals__['__file__']
299                 entry_descr = entry_filename
300         config.parse(entry_filename)
301
302         if config.config['trace_memory']:
303             import tracemalloc
304
305             tracemalloc.start()
306
307         # Initialize logging... and log some remembered messages from
308         # config module.
309         logging_utils.initialize_logging(logging.getLogger())
310         config.late_logging()
311
312         # Maybe log some info about the python interpreter itself.
313         logger.debug(
314             'Platform: %s, maxint=0x%x, byteorder=%s', sys.platform, sys.maxsize, sys.byteorder
315         )
316         logger.debug('Python interpreter version: %s', sys.version)
317         logger.debug('Python implementation: %s', sys.implementation)
318         logger.debug('Python C API version: %s', sys.api_version)
319         if __debug__:
320             logger.debug('Python interpreter running in __debug__ mode.')
321         else:
322             logger.debug('Python interpreter running in optimized mode.')
323         logger.debug('Python path: %s', sys.path)
324
325         # Log something about the site_config, many things use it.
326         import site_config
327
328         logger.debug('Global site_config: %s', site_config.get_config())
329
330         # Allow programs that don't bother to override the random seed
331         # to be replayed via the commandline.
332         import random
333
334         random_seed = config.config['set_random_seed']
335         if random_seed is not None:
336             random_seed = random_seed[0]
337         else:
338             random_seed = int.from_bytes(os.urandom(4), 'little')
339
340         if config.config['show_random_seed']:
341             msg = f'Global random seed is: {random_seed}'
342             logger.debug(msg)
343             print(msg)
344         random.seed(random_seed)
345
346         # Do it, invoke the user's code.  Pay attention to how long it takes.
347         logger.debug('Starting %s (program entry point)', entry_descr)
348         ret = None
349         import stopwatch
350
351         if config.config['run_profiler']:
352             import cProfile
353             from pstats import SortKey
354
355             with stopwatch.Timer() as t:
356                 cProfile.runctx(
357                     "ret = entry_point(*args, **kwargs)",
358                     globals(),
359                     locals(),
360                     None,
361                     SortKey.CUMULATIVE,
362                 )
363         else:
364             with stopwatch.Timer() as t:
365                 ret = entry_point(*args, **kwargs)
366
367         logger.debug('%s (program entry point) returned %s.', entry_descr, ret)
368
369         if config.config['trace_memory']:
370             snapshot = tracemalloc.take_snapshot()
371             top_stats = snapshot.statistics('lineno')
372             print()
373             print("--trace_memory's top 10 memory using files:")
374             for stat in top_stats[:10]:
375                 print(stat)
376
377         if config.config['dump_all_objects']:
378             dump_all_objects()
379
380         if config.config['audit_import_events']:
381             if IMPORT_INTERCEPTOR is not None:
382                 print(IMPORT_INTERCEPTOR.tree)
383
384         walltime = t()
385         (utime, stime, cutime, cstime, elapsed_time) = os.times()
386         logger.debug(
387             '\n'
388             'user: %.4fs\n'
389             'system: %.4fs\n'
390             'child user: %.4fs\n'
391             'child system: %.4fs\n'
392             'machine uptime: %.4fs\n'
393             'walltime: %.4fs',
394             utime,
395             stime,
396             cutime,
397             cstime,
398             elapsed_time,
399             walltime,
400         )
401
402         # If it doesn't return cleanly, call attention to the return value.
403         if ret is not None and ret != 0:
404             logger.error('Exit %s', ret)
405         else:
406             logger.debug('Exit %s', ret)
407         sys.exit(ret)
408
409     return initialize_wrapper