Migration from old pyutilz package name (which, in turn, came from
[pyutils.git] / src / pyutils / bootstrap.py
1 #!/usr/bin/env python3
2
3 # © Copyright 2021-2022, Scott Gasch
4
5 """This is a module for wrapping around python programs and doing some
6 minor setup and tear down work for them.  With it, you will get:
7
8 * The ability to break into pdb on unhandled exceptions,
9 * automatic support for :file:`config.py` (argument parsing)
10 * automatic logging support for :file:`logging.py`,
11 * the ability to enable code profiling,
12 * the ability to enable module import auditing,
13 * optional memory profiling for your program,
14 * ability to set random seed via commandline,
15 * automatic program timing and reporting,
16 * more verbose error handling and reporting,
17
18 Most of these are enabled and/or configured via commandline flags
19 (see below).
20
21 """
22
23 import functools
24 import importlib
25 import importlib.abc
26 import logging
27 import os
28 import sys
29 from inspect import stack
30
31 from pyutils import config, logging_utils
32 from pyutils.argparse_utils import ActionNoYes
33
34 # This module is commonly used by others in here and should avoid
35 # taking any unnecessary dependencies back on them.
36
37
38 logger = logging.getLogger(__name__)
39
40 cfg = config.add_commandline_args(
41     f'Bootstrap ({__file__})',
42     'Args related to python program bootstrapper and Swiss army knife',
43 )
44 cfg.add_argument(
45     '--debug_unhandled_exceptions',
46     action=ActionNoYes,
47     default=False,
48     help='Break into pdb on top level unhandled exceptions.',
49 )
50 cfg.add_argument(
51     '--show_random_seed',
52     action=ActionNoYes,
53     default=False,
54     help='Should we display (and log.debug) the global random seed?',
55 )
56 cfg.add_argument(
57     '--set_random_seed',
58     type=int,
59     nargs=1,
60     default=None,
61     metavar='SEED_INT',
62     help='Override the global random seed with a particular number.',
63 )
64 cfg.add_argument(
65     '--dump_all_objects',
66     action=ActionNoYes,
67     default=False,
68     help='Should we dump the Python import tree before main?',
69 )
70 cfg.add_argument(
71     '--audit_import_events',
72     action=ActionNoYes,
73     default=False,
74     help='Should we audit all import events?',
75 )
76 cfg.add_argument(
77     '--run_profiler',
78     action=ActionNoYes,
79     default=False,
80     help='Should we run cProfile on this code?',
81 )
82 cfg.add_argument(
83     '--trace_memory',
84     action=ActionNoYes,
85     default=False,
86     help='Should we record/report on memory utilization?',
87 )
88
89 ORIGINAL_EXCEPTION_HOOK = sys.excepthook
90
91
92 def handle_uncaught_exception(exc_type, exc_value, exc_tb):
93     """
94     Top-level exception handler for exceptions that make it past any exception
95     handlers in the python code being run.  Logs the error and stacktrace then
96     maybe attaches a debugger.
97
98     """
99     msg = f'Unhandled top level exception {exc_type}'
100     logger.exception(msg)
101     print(msg, file=sys.stderr)
102     if issubclass(exc_type, KeyboardInterrupt):
103         sys.__excepthook__(exc_type, exc_value, exc_tb)
104         return
105     else:
106         import io
107         import traceback
108
109         tb_output = io.StringIO()
110         traceback.print_tb(exc_tb, None, tb_output)
111         print(tb_output.getvalue(), file=sys.stderr)
112         logger.error(tb_output.getvalue())
113         tb_output.close()
114
115         # stdin or stderr is redirected, just do the normal thing
116         if not sys.stderr.isatty() or not sys.stdin.isatty():
117             ORIGINAL_EXCEPTION_HOOK(exc_type, exc_value, exc_tb)
118
119         else:  # a terminal is attached and stderr isn't redirected, maybe debug.
120             if config.config['debug_unhandled_exceptions']:
121                 logger.info("Invoking the debugger...")
122                 import pdb
123
124                 pdb.pm()
125             else:
126                 ORIGINAL_EXCEPTION_HOOK(exc_type, exc_value, exc_tb)
127
128
129 class ImportInterceptor(importlib.abc.MetaPathFinder):
130     """An interceptor that always allows module load events but dumps a
131     record into the log and onto stdout when modules are loaded and
132     produces an audit of who imported what at the end of the run.  It
133     can't see any load events that happen before it, though, so move
134     bootstrap up in your __main__'s import list just temporarily to
135     get a good view.
136
137     """
138
139     def __init__(self):
140         from pyutils.collectionz.trie import Trie
141
142         self.module_by_filename_cache = {}
143         self.repopulate_modules_by_filename()
144         self.tree = Trie()
145         self.tree_node_by_module = {}
146
147     def repopulate_modules_by_filename(self):
148         self.module_by_filename_cache.clear()
149         for (
150             _,
151             mod,
152         ) in sys.modules.copy().items():  # copy here because modules is volatile
153             if hasattr(mod, '__file__'):
154                 fname = getattr(mod, '__file__')
155             else:
156                 fname = 'unknown'
157             self.module_by_filename_cache[fname] = mod
158
159     @staticmethod
160     def should_ignore_filename(filename: str) -> bool:
161         return 'importlib' in filename or 'six.py' in filename
162
163     def find_module(self, fullname, path):
164         raise Exception(
165             "This method has been deprecated since Python 3.4, please upgrade."
166         )
167
168     def find_spec(self, loaded_module, path=None, _=None):
169         s = stack()
170         for x in range(3, len(s)):
171             filename = s[x].filename
172             if ImportInterceptor.should_ignore_filename(filename):
173                 continue
174
175             loading_function = s[x].function
176             if filename in self.module_by_filename_cache:
177                 loading_module = self.module_by_filename_cache[filename]
178             else:
179                 self.repopulate_modules_by_filename()
180                 loading_module = self.module_by_filename_cache.get(filename, 'unknown')
181
182             path = self.tree_node_by_module.get(loading_module, [])
183             path.extend([loaded_module])
184             self.tree.insert(path)
185             self.tree_node_by_module[loading_module] = path
186
187             msg = f'*** Import {loaded_module} from {filename}:{s[x].lineno} in {loading_module}::{loading_function}'
188             logger.debug(msg)
189             print(msg)
190             return
191         msg = f'*** Import {loaded_module} from ?????'
192         logger.debug(msg)
193         print(msg)
194
195     def invalidate_caches(self):
196         pass
197
198     def find_importer(self, module: str):
199         if module in self.tree_node_by_module:
200             node = self.tree_node_by_module[module]
201             return node
202         return []
203
204
205 # Audit import events?  Note: this runs early in the lifetime of the
206 # process (assuming that import bootstrap happens early); config has
207 # (probably) not yet been loaded or parsed the commandline.  Also,
208 # some things have probably already been imported while we weren't
209 # watching so this information may be incomplete.
210 #
211 # Also note: move bootstrap up in the global import list to catch
212 # more import events and have a more complete record.
213 IMPORT_INTERCEPTOR = None
214 for arg in sys.argv:
215     if arg == '--audit_import_events':
216         IMPORT_INTERCEPTOR = ImportInterceptor()
217         sys.meta_path.insert(0, IMPORT_INTERCEPTOR)
218
219
220 def dump_all_objects() -> None:
221     """Helper code to dump all known python objects."""
222
223     messages = {}
224     all_modules = sys.modules
225     for obj in object.__subclasses__():
226         if not hasattr(obj, '__name__'):
227             continue
228         klass = obj.__name__
229         if not hasattr(obj, '__module__'):
230             continue
231         class_mod_name = obj.__module__
232         if class_mod_name in all_modules:
233             mod = all_modules[class_mod_name]
234             if not hasattr(mod, '__name__'):
235                 mod_name = class_mod_name
236             else:
237                 mod_name = mod.__name__
238             if hasattr(mod, '__file__'):
239                 mod_file = mod.__file__
240             else:
241                 mod_file = 'unknown'
242             if IMPORT_INTERCEPTOR is not None:
243                 import_path = IMPORT_INTERCEPTOR.find_importer(mod_name)
244             else:
245                 import_path = 'unknown'
246             msg = f'{class_mod_name}::{klass} ({mod_file})'
247             if import_path != 'unknown' and len(import_path) > 0:
248                 msg += f' imported by {import_path}'
249             messages[f'{class_mod_name}::{klass}'] = msg
250     for x in sorted(messages.keys()):
251         logger.debug(messages[x])
252         print(messages[x])
253
254
255 def initialize(entry_point):
256     """
257     Remember to initialize config, initialize logging, set/log a random
258     seed, etc... before running main.  If you use this decorator around
259     your main, like this::
260
261         from pyutils import bootstrap
262
263         @bootstrap.initialize
264         def main():
265             whatever
266
267         if __name__ == '__main__':
268             main()
269
270     You get:
271
272     * The ability to break into pdb on unhandled exceptions,
273     * automatic support for :file:`config.py` (argument parsing)
274     * automatic logging support for :file:`logging.py`,
275     * the ability to enable code profiling,
276     * the ability to enable module import auditing,
277     * optional memory profiling for your program,
278     * ability to set random seed via commandline,
279     * automatic program timing and reporting,
280     * more verbose error handling and reporting,
281
282     Most of these are enabled and/or configured via commandline flags
283     (see below).
284     """
285
286     @functools.wraps(entry_point)
287     def initialize_wrapper(*args, **kwargs):
288         # Hook top level unhandled exceptions, maybe invoke debugger.
289         if sys.excepthook == sys.__excepthook__:
290             sys.excepthook = handle_uncaught_exception
291
292         # Try to figure out the name of the program entry point.  Then
293         # parse configuration (based on cmdline flags, environment vars
294         # etc...)
295         entry_filename = None
296         entry_descr = None
297         try:
298             entry_filename = entry_point.__code__.co_filename
299             entry_descr = entry_point.__code__.__repr__()
300         except Exception:
301             if (
302                 '__globals__' in entry_point.__dict__
303                 and '__file__' in entry_point.__globals__
304             ):
305                 entry_filename = entry_point.__globals__['__file__']
306                 entry_descr = entry_filename
307         config.parse(entry_filename)
308
309         if config.config['trace_memory']:
310             import tracemalloc
311
312             tracemalloc.start()
313
314         # Initialize logging... and log some remembered messages from
315         # config module.
316         logging_utils.initialize_logging(logging.getLogger())
317         config.late_logging()
318
319         # Maybe log some info about the python interpreter itself.
320         logger.debug(
321             'Platform: %s, maxint=0x%x, byteorder=%s',
322             sys.platform,
323             sys.maxsize,
324             sys.byteorder,
325         )
326         logger.debug('Python interpreter version: %s', sys.version)
327         logger.debug('Python implementation: %s', sys.implementation)
328         logger.debug('Python C API version: %s', sys.api_version)
329         if __debug__:
330             logger.debug('Python interpreter running in __debug__ mode.')
331         else:
332             logger.debug('Python interpreter running in optimized mode.')
333         logger.debug('Python path: %s', sys.path)
334
335         # Allow programs that don't bother to override the random seed
336         # to be replayed via the commandline.
337         import random
338
339         random_seed = config.config['set_random_seed']
340         if random_seed is not None:
341             random_seed = random_seed[0]
342         else:
343             random_seed = int.from_bytes(os.urandom(4), 'little')
344
345         if config.config['show_random_seed']:
346             msg = f'Global random seed is: {random_seed}'
347             logger.debug(msg)
348             print(msg)
349         random.seed(random_seed)
350
351         # Do it, invoke the user's code.  Pay attention to how long it takes.
352         logger.debug('Starting %s (program entry point)', entry_descr)
353         ret = None
354         from pyutils import stopwatch
355
356         if config.config['run_profiler']:
357             import cProfile
358             from pstats import SortKey
359
360             with stopwatch.Timer() as t:
361                 cProfile.runctx(
362                     "ret = entry_point(*args, **kwargs)",
363                     globals(),
364                     locals(),
365                     None,
366                     SortKey.CUMULATIVE,
367                 )
368         else:
369             with stopwatch.Timer() as t:
370                 ret = entry_point(*args, **kwargs)
371
372         logger.debug('%s (program entry point) returned %s.', entry_descr, ret)
373
374         if config.config['trace_memory']:
375             snapshot = tracemalloc.take_snapshot()
376             top_stats = snapshot.statistics('lineno')
377             print()
378             print("--trace_memory's top 10 memory using files:")
379             for stat in top_stats[:10]:
380                 print(stat)
381
382         if config.config['dump_all_objects']:
383             dump_all_objects()
384
385         if config.config['audit_import_events']:
386             if IMPORT_INTERCEPTOR is not None:
387                 print(IMPORT_INTERCEPTOR.tree)
388
389         walltime = t()
390         (utime, stime, cutime, cstime, elapsed_time) = os.times()
391         logger.debug(
392             '\n'
393             'user: %.4fs\n'
394             'system: %.4fs\n'
395             'child user: %.4fs\n'
396             'child system: %.4fs\n'
397             'machine uptime: %.4fs\n'
398             'walltime: %.4fs',
399             utime,
400             stime,
401             cutime,
402             cstime,
403             elapsed_time,
404             walltime,
405         )
406
407         # If it doesn't return cleanly, call attention to the return value.
408         if ret is not None and ret != 0:
409             logger.error('Exit %s', ret)
410         else:
411             logger.debug('Exit %s', ret)
412         sys.exit(ret)
413
414     return initialize_wrapper