8dc4a0c0886b67b8fd8dad7835d51b468690566f
[pyutils.git] / src / pyutils / bootstrap.py
1 #!/usr/bin/env python3
2
3 # © Copyright 2021-2022, Scott Gasch
4
5 """
6 If you decorate your main method (i.e. program entry point) like this::
7
8     @bootstrap.initialize
9     def main():
10         whatever
11
12 ...you will get:
13
14     * automatic support for :py:mod:`pyutils.config` (argument parsing, see
15       that module for details),
16     * The ability to break into pdb on unhandled exceptions (which is
17       enabled/disabled via the commandline flag `--debug_unhandled_exceptions`),
18     * automatic logging support from :py:mod:`pyutils.logging` controllable
19       via several commandline flags,
20     * the ability to optionally enable whole-program code profiling and reporting
21       when you run your code using commandline flag `--run_profiler`,
22     * the ability to optionally enable import auditing via the commandline flag
23       `--audit_import_events`.  This logs a message whenever a module is imported
24       *after* the bootstrap module itself is loaded.  Note that other modules may
25       already be loaded when bootstrap is loaded and these imports will not be
26       logged.  If you're trying to debug import events or dependency problems,
27       I suggest putting bootstrap very early in your import list and using this
28       flag.
29     * optional memory profiling for your program set via the commandline flag
30       `--trace_memory`.  This provides a report of python memory utilization
31       at program termination time.
32     * the ability to set the global random seed via commandline flag for
33       reproducable runs (as long as subsequent code doesn't reset the seed)
34       using the `--set_random_seed` flag,
35     * automatic program timing and reporting logged to the INFO log,
36     * more verbose error handling and reporting.
37
38 """
39
40 import functools
41 import importlib
42 import importlib.abc
43 import logging
44 import os
45 import sys
46 from inspect import stack
47
48 from pyutils import config, logging_utils
49 from pyutils.argparse_utils import ActionNoYes
50
51 # This module is commonly used by others in here and should avoid
52 # taking any unnecessary dependencies back on them.
53
54
55 logger = logging.getLogger(__name__)
56
57 cfg = config.add_commandline_args(
58     f'Bootstrap ({__file__})',
59     'Args related to python program bootstrapper and Swiss army knife',
60 )
61 cfg.add_argument(
62     '--debug_unhandled_exceptions',
63     action=ActionNoYes,
64     default=False,
65     help='Break into pdb on top level unhandled exceptions.',
66 )
67 cfg.add_argument(
68     '--show_random_seed',
69     action=ActionNoYes,
70     default=False,
71     help='Should we display (and log.debug) the global random seed?',
72 )
73 cfg.add_argument(
74     '--set_random_seed',
75     type=int,
76     nargs=1,
77     default=None,
78     metavar='SEED_INT',
79     help='Override the global random seed with a particular number.',
80 )
81 cfg.add_argument(
82     '--dump_all_objects',
83     action=ActionNoYes,
84     default=False,
85     help='Should we dump the Python import tree before main?',
86 )
87 cfg.add_argument(
88     '--audit_import_events',
89     action=ActionNoYes,
90     default=False,
91     help='Should we audit all import events?',
92 )
93 cfg.add_argument(
94     '--run_profiler',
95     action=ActionNoYes,
96     default=False,
97     help='Should we run cProfile on this code?',
98 )
99 cfg.add_argument(
100     '--trace_memory',
101     action=ActionNoYes,
102     default=False,
103     help='Should we record/report on memory utilization?',
104 )
105
106 ORIGINAL_EXCEPTION_HOOK = sys.excepthook
107
108
109 def handle_uncaught_exception(exc_type, exc_value, exc_tb):
110     """
111     Top-level exception handler for exceptions that make it past any exception
112     handlers in the python code being run.  Logs the error and stacktrace then
113     maybe attaches a debugger.
114
115     """
116     msg = f'Unhandled top level exception {exc_type}'
117     logger.exception(msg)
118     print(msg, file=sys.stderr)
119     if issubclass(exc_type, KeyboardInterrupt):
120         sys.__excepthook__(exc_type, exc_value, exc_tb)
121         return
122     else:
123         import io
124         import traceback
125
126         tb_output = io.StringIO()
127         traceback.print_tb(exc_tb, None, tb_output)
128         print(tb_output.getvalue(), file=sys.stderr)
129         logger.error(tb_output.getvalue())
130         tb_output.close()
131
132         # stdin or stderr is redirected, just do the normal thing
133         if not sys.stderr.isatty() or not sys.stdin.isatty():
134             ORIGINAL_EXCEPTION_HOOK(exc_type, exc_value, exc_tb)
135
136         else:  # a terminal is attached and stderr isn't redirected, maybe debug.
137             if config.config['debug_unhandled_exceptions']:
138                 logger.info("Invoking the debugger...")
139                 import pdb
140
141                 pdb.pm()
142             else:
143                 ORIGINAL_EXCEPTION_HOOK(exc_type, exc_value, exc_tb)
144
145
146 class ImportInterceptor(importlib.abc.MetaPathFinder):
147     """An interceptor that always allows module load events but dumps a
148     record into the log and onto stdout when modules are loaded and
149     produces an audit of who imported what at the end of the run.  It
150     can't see any load events that happen before it, though, so move
151     bootstrap up in your __main__'s import list just temporarily to
152     get a good view.
153
154     """
155
156     def __init__(self):
157         from pyutils.collectionz.trie import Trie
158
159         self.module_by_filename_cache = {}
160         self.repopulate_modules_by_filename()
161         self.tree = Trie()
162         self.tree_node_by_module = {}
163
164     def repopulate_modules_by_filename(self):
165         self.module_by_filename_cache.clear()
166         for (
167             _,
168             mod,
169         ) in sys.modules.copy().items():  # copy here because modules is volatile
170             if hasattr(mod, '__file__'):
171                 fname = getattr(mod, '__file__')
172             else:
173                 fname = 'unknown'
174             self.module_by_filename_cache[fname] = mod
175
176     @staticmethod
177     def should_ignore_filename(filename: str) -> bool:
178         return 'importlib' in filename or 'six.py' in filename
179
180     def find_module(self, fullname, path):
181         raise Exception(
182             "This method has been deprecated since Python 3.4, please upgrade."
183         )
184
185     def find_spec(self, loaded_module, path=None, _=None):
186         s = stack()
187         for x in range(3, len(s)):
188             filename = s[x].filename
189             if ImportInterceptor.should_ignore_filename(filename):
190                 continue
191
192             loading_function = s[x].function
193             if filename in self.module_by_filename_cache:
194                 loading_module = self.module_by_filename_cache[filename]
195             else:
196                 self.repopulate_modules_by_filename()
197                 loading_module = self.module_by_filename_cache.get(filename, 'unknown')
198
199             path = self.tree_node_by_module.get(loading_module, [])
200             path.extend([loaded_module])
201             self.tree.insert(path)
202             self.tree_node_by_module[loading_module] = path
203
204             msg = f'*** Import {loaded_module} from {filename}:{s[x].lineno} in {loading_module}::{loading_function}'
205             logger.debug(msg)
206             print(msg)
207             return
208         msg = f'*** Import {loaded_module} from ?????'
209         logger.debug(msg)
210         print(msg)
211
212     def invalidate_caches(self):
213         pass
214
215     def find_importer(self, module: str):
216         if module in self.tree_node_by_module:
217             node = self.tree_node_by_module[module]
218             return node
219         return []
220
221
222 # Audit import events?  Note: this runs early in the lifetime of the
223 # process (assuming that import bootstrap happens early); config has
224 # (probably) not yet been loaded or parsed the commandline.  Also,
225 # some things have probably already been imported while we weren't
226 # watching so this information may be incomplete.
227 #
228 # Also note: move bootstrap up in the global import list to catch
229 # more import events and have a more complete record.
230 IMPORT_INTERCEPTOR = None
231 for arg in sys.argv:
232     if arg == '--audit_import_events':
233         IMPORT_INTERCEPTOR = ImportInterceptor()
234         sys.meta_path.insert(0, IMPORT_INTERCEPTOR)
235
236
237 def dump_all_objects() -> None:
238     """Helper code to dump all known python objects."""
239
240     messages = {}
241     all_modules = sys.modules
242     for obj in object.__subclasses__():
243         if not hasattr(obj, '__name__'):
244             continue
245         klass = obj.__name__
246         if not hasattr(obj, '__module__'):
247             continue
248         class_mod_name = obj.__module__
249         if class_mod_name in all_modules:
250             mod = all_modules[class_mod_name]
251             if not hasattr(mod, '__name__'):
252                 mod_name = class_mod_name
253             else:
254                 mod_name = mod.__name__
255             if hasattr(mod, '__file__'):
256                 mod_file = mod.__file__
257             else:
258                 mod_file = 'unknown'
259             if IMPORT_INTERCEPTOR is not None:
260                 import_path = IMPORT_INTERCEPTOR.find_importer(mod_name)
261             else:
262                 import_path = 'unknown'
263             msg = f'{class_mod_name}::{klass} ({mod_file})'
264             if import_path != 'unknown' and len(import_path) > 0:
265                 msg += f' imported by {import_path}'
266             messages[f'{class_mod_name}::{klass}'] = msg
267     for x in sorted(messages.keys()):
268         logger.debug(messages[x])
269         print(messages[x])
270
271
272 def initialize(entry_point):
273     """
274     Remember to initialize config, initialize logging, set/log a random
275     seed, etc... before running main.  If you use this decorator around
276     your main, like this::
277
278         from pyutils import bootstrap
279
280         @bootstrap.initialize
281         def main():
282             whatever
283
284         if __name__ == '__main__':
285             main()
286
287     You get:
288
289     * The ability to break into pdb on unhandled exceptions,
290     * automatic support for :file:`config.py` (argument parsing)
291     * automatic logging support for :file:`logging.py`,
292     * the ability to enable code profiling,
293     * the ability to enable module import auditing,
294     * optional memory profiling for your program,
295     * ability to set random seed via commandline,
296     * automatic program timing and reporting,
297     * more verbose error handling and reporting,
298
299     Most of these are enabled and/or configured via commandline flags
300     (see below).
301     """
302
303     @functools.wraps(entry_point)
304     def initialize_wrapper(*args, **kwargs):
305         # Hook top level unhandled exceptions, maybe invoke debugger.
306         if sys.excepthook == sys.__excepthook__:
307             sys.excepthook = handle_uncaught_exception
308
309         # Try to figure out the name of the program entry point.  Then
310         # parse configuration (based on cmdline flags, environment vars
311         # etc...)
312         entry_filename = None
313         entry_descr = None
314         try:
315             entry_filename = entry_point.__code__.co_filename
316             entry_descr = entry_point.__code__.__repr__()
317         except Exception:
318             if (
319                 '__globals__' in entry_point.__dict__
320                 and '__file__' in entry_point.__globals__
321             ):
322                 entry_filename = entry_point.__globals__['__file__']
323                 entry_descr = entry_filename
324         config.parse(entry_filename)
325
326         if config.config['trace_memory']:
327             import tracemalloc
328
329             tracemalloc.start()
330
331         # Initialize logging... and log some remembered messages from
332         # config module.
333         logging_utils.initialize_logging(logging.getLogger())
334         config.late_logging()
335
336         # Maybe log some info about the python interpreter itself.
337         logger.debug(
338             'Platform: %s, maxint=0x%x, byteorder=%s',
339             sys.platform,
340             sys.maxsize,
341             sys.byteorder,
342         )
343         logger.debug('Python interpreter version: %s', sys.version)
344         logger.debug('Python implementation: %s', sys.implementation)
345         logger.debug('Python C API version: %s', sys.api_version)
346         if __debug__:
347             logger.debug('Python interpreter running in __debug__ mode.')
348         else:
349             logger.debug('Python interpreter running in optimized mode.')
350         logger.debug('Python path: %s', sys.path)
351
352         # Allow programs that don't bother to override the random seed
353         # to be replayed via the commandline.
354         import random
355
356         random_seed = config.config['set_random_seed']
357         if random_seed is not None:
358             random_seed = random_seed[0]
359         else:
360             random_seed = int.from_bytes(os.urandom(4), 'little')
361
362         if config.config['show_random_seed']:
363             msg = f'Global random seed is: {random_seed}'
364             logger.debug(msg)
365             print(msg)
366         random.seed(random_seed)
367
368         # Do it, invoke the user's code.  Pay attention to how long it takes.
369         logger.debug('Starting %s (program entry point)', entry_descr)
370         ret = None
371         from pyutils import stopwatch
372
373         if config.config['run_profiler']:
374             import cProfile
375             from pstats import SortKey
376
377             with stopwatch.Timer() as t:
378                 cProfile.runctx(
379                     "ret = entry_point(*args, **kwargs)",
380                     globals(),
381                     locals(),
382                     None,
383                     SortKey.CUMULATIVE,
384                 )
385         else:
386             with stopwatch.Timer() as t:
387                 ret = entry_point(*args, **kwargs)
388
389         logger.debug('%s (program entry point) returned %s.', entry_descr, ret)
390
391         if config.config['trace_memory']:
392             snapshot = tracemalloc.take_snapshot()
393             top_stats = snapshot.statistics('lineno')
394             print()
395             print("--trace_memory's top 10 memory using files:")
396             for stat in top_stats[:10]:
397                 print(stat)
398
399         if config.config['dump_all_objects']:
400             dump_all_objects()
401
402         if config.config['audit_import_events']:
403             if IMPORT_INTERCEPTOR is not None:
404                 print(IMPORT_INTERCEPTOR.tree)
405
406         walltime = t()
407         (utime, stime, cutime, cstime, elapsed_time) = os.times()
408         logger.debug(
409             '\n'
410             'user: %.4fs\n'
411             'system: %.4fs\n'
412             'child user: %.4fs\n'
413             'child system: %.4fs\n'
414             'machine uptime: %.4fs\n'
415             'walltime: %.4fs',
416             utime,
417             stime,
418             cutime,
419             cstime,
420             elapsed_time,
421             walltime,
422         )
423
424         # If it doesn't return cleanly, call attention to the return value.
425         if ret is not None and ret != 0:
426             logger.error('Exit %s', ret)
427         else:
428             logger.debug('Exit %s', ret)
429         sys.exit(ret)
430
431     return initialize_wrapper