More documentation changes but includes a change to config.py that
[pyutils.git] / src / pyutils / bootstrap.py
1 #!/usr/bin/env python3
2
3 # © Copyright 2021-2022, Scott Gasch
4
5 """
6 If you decorate your main method (i.e. program entry point) like this::
7
8     @bootstrap.initialize
9     def main():
10         whatever
11
12 ...you will get:
13
14     * automatic support for :py:mod:`pyutils.config` (argument parsing, see
15       that module for details),
16     * The ability to break into pdb on unhandled exceptions (which is
17       enabled/disabled via the commandline flag `--debug_unhandled_exceptions`),
18     * automatic logging support from :py:mod:`pyutils.logging` controllable
19       via several commandline flags,
20     * the ability to optionally enable whole-program code profiling and reporting
21       when you run your code using commandline flag `--run_profiler`,
22     * the ability to optionally enable import auditing via the commandline flag
23       `--audit_import_events`.  This logs a message whenever a module is imported
24       *after* the bootstrap module itself is loaded.  Note that other modules may
25       already be loaded when bootstrap is loaded and these imports will not be
26       logged.  If you're trying to debug import events or dependency problems,
27       I suggest putting bootstrap very early in your import list and using this
28       flag.
29     * optional memory profiling for your program set via the commandline flag
30       `--trace_memory`.  This provides a report of python memory utilization
31       at program termination time.
32     * the ability to set the global random seed via commandline flag for
33       reproducable runs (as long as subsequent code doesn't reset the seed)
34       using the `--set_random_seed` flag,
35     * automatic program timing and reporting logged to the INFO log,
36     * more verbose error handling and reporting.
37
38 """
39
40 import functools
41 import importlib
42 import importlib.abc
43 import logging
44 import os
45 import sys
46 from inspect import stack
47
48 from pyutils import config, logging_utils
49 from pyutils.argparse_utils import ActionNoYes
50
51 # This module is commonly used by others in here and should avoid
52 # taking any unnecessary dependencies back on them.
53
54
55 logger = logging.getLogger(__name__)
56
57 cfg = config.add_commandline_args(
58     f'Bootstrap ({__file__})',
59     'Args related to python program bootstrapper and Swiss army knife',
60 )
61 cfg.add_argument(
62     '--debug_unhandled_exceptions',
63     action=ActionNoYes,
64     default=False,
65     help='Break into pdb on top level unhandled exceptions.',
66 )
67 cfg.add_argument(
68     '--show_random_seed',
69     action=ActionNoYes,
70     default=False,
71     help='Should we display (and log.debug) the global random seed?',
72 )
73 cfg.add_argument(
74     '--set_random_seed',
75     type=int,
76     nargs=1,
77     default=None,
78     metavar='SEED_INT',
79     help='Override the global random seed with a particular number.',
80 )
81 cfg.add_argument(
82     '--dump_all_objects',
83     action=ActionNoYes,
84     default=False,
85     help='Should we dump the Python import tree before main?',
86 )
87 cfg.add_argument(
88     '--audit_import_events',
89     action=ActionNoYes,
90     default=False,
91     help='Should we audit all import events?',
92 )
93 cfg.add_argument(
94     '--run_profiler',
95     action=ActionNoYes,
96     default=False,
97     help='Should we run cProfile on this code?',
98 )
99 cfg.add_argument(
100     '--trace_memory',
101     action=ActionNoYes,
102     default=False,
103     help='Should we record/report on memory utilization?',
104 )
105
106 ORIGINAL_EXCEPTION_HOOK = sys.excepthook
107
108
109 def handle_uncaught_exception(exc_type, exc_value, exc_tb):
110     """
111     Top-level exception handler for exceptions that make it past any exception
112     handlers in the python code being run.  Logs the error and stacktrace then
113     maybe attaches a debugger.
114
115     """
116     msg = f'Unhandled top level exception {exc_type}'
117     logger.exception(msg)
118     print(msg, file=sys.stderr)
119     if issubclass(exc_type, KeyboardInterrupt):
120         sys.__excepthook__(exc_type, exc_value, exc_tb)
121         return
122     else:
123         import io
124         import traceback
125
126         tb_output = io.StringIO()
127         traceback.print_tb(exc_tb, None, tb_output)
128         print(tb_output.getvalue(), file=sys.stderr)
129         logger.error(tb_output.getvalue())
130         tb_output.close()
131
132         # stdin or stderr is redirected, just do the normal thing
133         if not sys.stderr.isatty() or not sys.stdin.isatty():
134             ORIGINAL_EXCEPTION_HOOK(exc_type, exc_value, exc_tb)
135
136         else:  # a terminal is attached and stderr isn't redirected, maybe debug.
137             if config.config['debug_unhandled_exceptions']:
138                 logger.info("Invoking the debugger...")
139                 import pdb
140
141                 pdb.pm()
142             else:
143                 ORIGINAL_EXCEPTION_HOOK(exc_type, exc_value, exc_tb)
144
145
146 class ImportInterceptor(importlib.abc.MetaPathFinder):
147     """An interceptor that always allows module load events but dumps a
148     record into the log and onto stdout when modules are loaded and
149     produces an audit of who imported what at the end of the run.  It
150     can't see any load events that happen before it, though, so move
151     bootstrap up in your __main__'s import list just temporarily to
152     get a good view.
153
154     """
155
156     def __init__(self):
157         from pyutils.collectionz.trie import Trie
158
159         self.module_by_filename_cache = {}
160         self.repopulate_modules_by_filename()
161         self.tree = Trie()
162         self.tree_node_by_module = {}
163
164     def repopulate_modules_by_filename(self):
165         self.module_by_filename_cache.clear()
166         for (
167             _,
168             mod,
169         ) in sys.modules.copy().items():  # copy here because modules is volatile
170             if hasattr(mod, '__file__'):
171                 fname = getattr(mod, '__file__')
172             else:
173                 fname = 'unknown'
174             self.module_by_filename_cache[fname] = mod
175
176     @staticmethod
177     def should_ignore_filename(filename: str) -> bool:
178         return 'importlib' in filename or 'six.py' in filename
179
180     def find_module(self, fullname, path):
181         raise Exception(
182             "This method has been deprecated since Python 3.4, please upgrade."
183         )
184
185     def find_spec(self, loaded_module, path=None, _=None):
186         s = stack()
187         for x in range(3, len(s)):
188             filename = s[x].filename
189             if ImportInterceptor.should_ignore_filename(filename):
190                 continue
191
192             loading_function = s[x].function
193             if filename in self.module_by_filename_cache:
194                 loading_module = self.module_by_filename_cache[filename]
195             else:
196                 self.repopulate_modules_by_filename()
197                 loading_module = self.module_by_filename_cache.get(filename, 'unknown')
198
199             path = self.tree_node_by_module.get(loading_module, [])
200             path.extend([loaded_module])
201             self.tree.insert(path)
202             self.tree_node_by_module[loading_module] = path
203
204             msg = f'*** Import {loaded_module} from {filename}:{s[x].lineno} in {loading_module}::{loading_function}'
205             logger.debug(msg)
206             print(msg)
207             return
208         msg = f'*** Import {loaded_module} from ?????'
209         logger.debug(msg)
210         print(msg)
211
212     def invalidate_caches(self):
213         pass
214
215     def find_importer(self, module: str):
216         if module in self.tree_node_by_module:
217             node = self.tree_node_by_module[module]
218             return node
219         return []
220
221
222 # Audit import events?  Note: this runs early in the lifetime of the
223 # process (assuming that import bootstrap happens early); config has
224 # (probably) not yet been loaded or parsed the commandline.  Also,
225 # some things have probably already been imported while we weren't
226 # watching so this information may be incomplete.
227 #
228 # Also note: move bootstrap up in the global import list to catch
229 # more import events and have a more complete record.
230 IMPORT_INTERCEPTOR = None
231 for arg in sys.argv:
232     if arg == '--audit_import_events':
233         IMPORT_INTERCEPTOR = ImportInterceptor()
234         sys.meta_path.insert(0, IMPORT_INTERCEPTOR)
235
236
237 def dump_all_objects() -> None:
238     """Helper code to dump all known python objects."""
239
240     messages = {}
241     all_modules = sys.modules
242     for obj in object.__subclasses__():
243         if not hasattr(obj, '__name__'):
244             continue
245         klass = obj.__name__
246         if not hasattr(obj, '__module__'):
247             continue
248         class_mod_name = obj.__module__
249         if class_mod_name in all_modules:
250             mod = all_modules[class_mod_name]
251             if not hasattr(mod, '__name__'):
252                 mod_name = class_mod_name
253             else:
254                 mod_name = mod.__name__
255             if hasattr(mod, '__file__'):
256                 mod_file = mod.__file__
257             else:
258                 mod_file = 'unknown'
259             if IMPORT_INTERCEPTOR is not None:
260                 import_path = IMPORT_INTERCEPTOR.find_importer(mod_name)
261             else:
262                 import_path = 'unknown'
263             msg = f'{class_mod_name}::{klass} ({mod_file})'
264             if import_path != 'unknown' and len(import_path) > 0:
265                 msg += f' imported by {import_path}'
266             messages[f'{class_mod_name}::{klass}'] = msg
267     for x in sorted(messages.keys()):
268         logger.debug(messages[x])
269         print(messages[x])
270
271
272 def initialize(entry_point):
273     """
274     Do whole program setup and instrumentation.  See module comments for
275     details.  To use::
276
277         from pyutils import bootstrap
278
279         @bootstrap.initialize
280         def main():
281             whatever
282
283         if __name__ == '__main__':
284             main()
285     """
286
287     @functools.wraps(entry_point)
288     def initialize_wrapper(*args, **kwargs):
289         # Hook top level unhandled exceptions, maybe invoke debugger.
290         if sys.excepthook == sys.__excepthook__:
291             sys.excepthook = handle_uncaught_exception
292
293         # Try to figure out the name of the program entry point.  Then
294         # parse configuration (based on cmdline flags, environment vars
295         # etc...)
296         entry_filename = None
297         entry_descr = None
298         try:
299             entry_filename = entry_point.__code__.co_filename
300             entry_descr = entry_point.__code__.__repr__()
301         except Exception:
302             if (
303                 '__globals__' in entry_point.__dict__
304                 and '__file__' in entry_point.__globals__
305             ):
306                 entry_filename = entry_point.__globals__['__file__']
307                 entry_descr = entry_filename
308         config.parse(entry_filename)
309
310         if config.config['trace_memory']:
311             import tracemalloc
312
313             tracemalloc.start()
314
315         # Initialize logging... and log some remembered messages from
316         # config module.
317         logging_utils.initialize_logging(logging.getLogger())
318         config.late_logging()
319
320         # Maybe log some info about the python interpreter itself.
321         logger.debug(
322             'Platform: %s, maxint=0x%x, byteorder=%s',
323             sys.platform,
324             sys.maxsize,
325             sys.byteorder,
326         )
327         logger.debug('Python interpreter version: %s', sys.version)
328         logger.debug('Python implementation: %s', sys.implementation)
329         logger.debug('Python C API version: %s', sys.api_version)
330         if __debug__:
331             logger.debug('Python interpreter running in __debug__ mode.')
332         else:
333             logger.debug('Python interpreter running in optimized mode.')
334         logger.debug('Python path: %s', sys.path)
335
336         # Allow programs that don't bother to override the random seed
337         # to be replayed via the commandline.
338         import random
339
340         random_seed = config.config['set_random_seed']
341         if random_seed is not None:
342             random_seed = random_seed[0]
343         else:
344             random_seed = int.from_bytes(os.urandom(4), 'little')
345
346         if config.config['show_random_seed']:
347             msg = f'Global random seed is: {random_seed}'
348             logger.debug(msg)
349             print(msg)
350         random.seed(random_seed)
351
352         # Do it, invoke the user's code.  Pay attention to how long it takes.
353         logger.debug('Starting %s (program entry point)', entry_descr)
354         ret = None
355         from pyutils import stopwatch
356
357         if config.config['run_profiler']:
358             import cProfile
359             from pstats import SortKey
360
361             with stopwatch.Timer() as t:
362                 cProfile.runctx(
363                     "ret = entry_point(*args, **kwargs)",
364                     globals(),
365                     locals(),
366                     None,
367                     SortKey.CUMULATIVE,
368                 )
369         else:
370             with stopwatch.Timer() as t:
371                 ret = entry_point(*args, **kwargs)
372
373         logger.debug('%s (program entry point) returned %s.', entry_descr, ret)
374
375         if config.config['trace_memory']:
376             snapshot = tracemalloc.take_snapshot()
377             top_stats = snapshot.statistics('lineno')
378             print()
379             print("--trace_memory's top 10 memory using files:")
380             for stat in top_stats[:10]:
381                 print(stat)
382
383         if config.config['dump_all_objects']:
384             dump_all_objects()
385
386         if config.config['audit_import_events']:
387             if IMPORT_INTERCEPTOR is not None:
388                 print(IMPORT_INTERCEPTOR.tree)
389
390         walltime = t()
391         (utime, stime, cutime, cstime, elapsed_time) = os.times()
392         logger.debug(
393             '\n'
394             'user: %.4fs\n'
395             'system: %.4fs\n'
396             'child user: %.4fs\n'
397             'child system: %.4fs\n'
398             'machine uptime: %.4fs\n'
399             'walltime: %.4fs',
400             utime,
401             stime,
402             cutime,
403             cstime,
404             elapsed_time,
405             walltime,
406         )
407
408         # If it doesn't return cleanly, call attention to the return value.
409         if ret is not None and ret != 0:
410             logger.error('Exit %s', ret)
411         else:
412             logger.debug('Exit %s', ret)
413         sys.exit(ret)
414
415     return initialize_wrapper