Fix typo in docs
[pyutils.git] / src / pyutils / bootstrap.py
1 #!/usr/bin/env python3
2
3 # © Copyright 2021-2023, Scott Gasch
4
5 """
6 If you decorate your main method (i.e. program entry point) like this::
7
8     @bootstrap.initialize
9     def main():
10         whatever
11
12 ...you will get:
13
14     * automatic support for :py:mod:`pyutils.config` (argument parsing, see
15       that module for details),
16     * The ability to break into pdb on unhandled exceptions (which is
17       enabled/disabled via the commandline flag :code:`--debug_unhandled_exceptions`),
18     * automatic logging support from :py:mod:`pyutils.logging_utils` controllable
19       via several commandline flags,
20     * the ability to optionally enable whole-program code profiling and reporting
21       when you run your code using commandline flag :code:`--run_profiler`,
22     * the ability to optionally enable import auditing via the commandline flag
23       :code:`--audit_import_events`.  This logs a message whenever a module is imported
24       *after* the bootstrap module itself is loaded.  Note that other modules may
25       already be loaded when bootstrap is loaded and these imports will not be
26       logged.  If you're trying to debug import events or dependency problems,
27       I suggest putting bootstrap very early in your import list and using this
28       flag.
29     * optional memory profiling for your program set via the commandline flag
30       :code:`--trace_memory`.  This provides a report of python memory utilization
31       at program termination time.
32     * the ability to set the global random seed via commandline flag for
33       reproducable runs (as long as subsequent code doesn't reset the seed)
34       using the :code:`--set_random_seed` flag,
35     * automatic program timing and reporting logged to the INFO log,
36     * more verbose error handling and reporting.
37
38 """
39
40 import functools
41 import importlib
42 import importlib.abc
43 import logging
44 import os
45 import sys
46 import uuid
47 from inspect import stack
48 from typing import NoReturn
49
50 from pyutils import config, logging_utils
51 from pyutils.argparse_utils import ActionNoYes
52
53 # This module is commonly used by others in here and should avoid
54 # taking any unnecessary dependencies back on them.
55
56
57 logger = logging.getLogger(__name__)
58
59 cfg = config.add_commandline_args(
60     f"Bootstrap ({__file__})",
61     "Args related to python program bootstrapper and Swiss army knife",
62 )
63 cfg.add_argument(
64     "--debug_unhandled_exceptions",
65     action=ActionNoYes,
66     default=False,
67     help="Break into pdb on top level unhandled exceptions.",
68 )
69 cfg.add_argument(
70     "--show_random_seed",
71     action=ActionNoYes,
72     default=False,
73     help="Should we display (and log.debug) the global random seed?",
74 )
75 cfg.add_argument(
76     "--set_random_seed",
77     type=int,
78     nargs=1,
79     default=None,
80     metavar="SEED_INT",
81     help="Override the global random seed with a particular number.",
82 )
83 cfg.add_argument(
84     "--dump_all_objects",
85     action=ActionNoYes,
86     default=False,
87     help="Should we dump the Python import tree before main?",
88 )
89 cfg.add_argument(
90     "--audit_import_events",
91     action=ActionNoYes,
92     default=False,
93     help="Should we audit all import events?",
94 )
95 cfg.add_argument(
96     "--run_profiler",
97     action=ActionNoYes,
98     default=False,
99     help="Should we run cProfile on this code?",
100 )
101 cfg.add_argument(
102     "--trace_memory",
103     action=ActionNoYes,
104     default=False,
105     help="Should we record/report on memory utilization?",
106 )
107
108 ORIGINAL_EXCEPTION_HOOK = sys.excepthook
109
110
111 def handle_uncaught_exception(exc_type, exc_value, exc_tb):
112     """
113     Top-level exception handler for exceptions that make it past any exception
114     handlers in the python code being run.  Logs the error and stacktrace then
115     maybe attaches a debugger.
116
117     """
118     msg = f"Unhandled top level exception {exc_type}"
119     logger.exception(msg)
120     print(msg, file=sys.stderr)
121     try:
122         logging_utils.unhandled_top_level_exception(exc_type)
123     except Exception:
124         pass
125     if issubclass(exc_type, KeyboardInterrupt):
126         sys.__excepthook__(exc_type, exc_value, exc_tb)
127         return
128     else:
129         import io
130         import traceback
131
132         tb_output = io.StringIO()
133         traceback.print_tb(exc_tb, None, tb_output)
134         print(tb_output.getvalue(), file=sys.stderr)
135         logger.error(tb_output.getvalue())
136         tb_output.close()
137
138         # stdin or stderr is redirected, just do the normal thing
139         if not sys.stderr.isatty() or not sys.stdin.isatty():
140             ORIGINAL_EXCEPTION_HOOK(exc_type, exc_value, exc_tb)
141
142         else:  # a terminal is attached and stderr isn't redirected, maybe debug.
143             if config.config["debug_unhandled_exceptions"]:
144                 logger.info("Invoking the debugger...")
145                 import pdb
146
147                 pdb.pm()
148             else:
149                 ORIGINAL_EXCEPTION_HOOK(exc_type, exc_value, exc_tb)
150
151
152 class ImportInterceptor(importlib.abc.MetaPathFinder):
153     """An interceptor that always allows module load events but dumps a
154     record into the log and onto stdout when modules are loaded and
155     produces an audit of who imported what at the end of the run.  It
156     can't see any load events that happen before it, though, so move
157     bootstrap up in your __main__'s import list just temporarily to
158     get a good view.
159
160     """
161
162     def __init__(self):
163         from pyutils.collectionz.trie import Trie
164
165         self.module_by_filename_cache = {}
166         self.repopulate_modules_by_filename()
167         self.tree = Trie()
168         self.tree_node_by_module = {}
169
170     def repopulate_modules_by_filename(self):
171         self.module_by_filename_cache.clear()
172         for (
173             _,
174             mod,
175         ) in sys.modules.copy().items():  # copy here because modules is volatile
176             if hasattr(mod, "__file__"):
177                 fname = getattr(mod, "__file__")
178             else:
179                 fname = "unknown"
180             self.module_by_filename_cache[fname] = mod
181
182     @staticmethod
183     def should_ignore_filename(filename: str) -> bool:
184         return "importlib" in filename or "six.py" in filename
185
186     def find_module(self, fullname, path) -> NoReturn:
187         raise Exception(
188             "This method has been deprecated since Python 3.4, please upgrade."
189         )
190
191     def find_spec(self, loaded_module, path=None, _=None):
192         s = stack()
193         for x in range(3, len(s)):
194             filename = s[x].filename
195             if ImportInterceptor.should_ignore_filename(filename):
196                 continue
197
198             loading_function = s[x].function
199             if filename in self.module_by_filename_cache:
200                 loading_module = self.module_by_filename_cache[filename]
201             else:
202                 self.repopulate_modules_by_filename()
203                 loading_module = self.module_by_filename_cache.get(filename, "unknown")
204
205             path = self.tree_node_by_module.get(loading_module, [])
206             path.extend([loaded_module])
207             self.tree.insert(path)
208             self.tree_node_by_module[loading_module] = path
209
210             msg = f"*** Import {loaded_module} from {filename}:{s[x].lineno} in {loading_module}::{loading_function}"
211             logger.debug(msg)
212             print(msg)
213             return
214         msg = f"*** Import {loaded_module} from ?????"
215         logger.debug(msg)
216         print(msg)
217
218     def invalidate_caches(self):
219         pass
220
221     def find_importer(self, module: str):
222         if module in self.tree_node_by_module:
223             node = self.tree_node_by_module[module]
224             return node
225         return []
226
227
228 # Audit import events?  Note: this runs early in the lifetime of the
229 # process (assuming that import bootstrap happens early); config has
230 # (probably) not yet been loaded or parsed the commandline.  Also,
231 # some things have probably already been imported while we weren't
232 # watching so this information may be incomplete.
233 #
234 # Also note: move bootstrap up in the global import list to catch
235 # more import events and have a more complete record.
236 IMPORT_INTERCEPTOR = None
237 for arg in sys.argv:
238     if arg == "--audit_import_events":
239         IMPORT_INTERCEPTOR = ImportInterceptor()
240         sys.meta_path.insert(0, IMPORT_INTERCEPTOR)
241
242
243 def dump_all_objects() -> None:
244     """Helper code to dump all known python objects."""
245
246     messages = {}
247     all_modules = sys.modules
248     for obj in object.__subclasses__():
249         if not hasattr(obj, "__name__"):
250             continue
251         klass = obj.__name__
252         if not hasattr(obj, "__module__"):
253             continue
254         class_mod_name = obj.__module__
255         if class_mod_name in all_modules:
256             mod = all_modules[class_mod_name]
257             if not hasattr(mod, "__name__"):
258                 mod_name = class_mod_name
259             else:
260                 mod_name = mod.__name__
261             if hasattr(mod, "__file__"):
262                 mod_file = mod.__file__
263             else:
264                 mod_file = "unknown"
265             if IMPORT_INTERCEPTOR is not None:
266                 import_path = IMPORT_INTERCEPTOR.find_importer(mod_name)
267             else:
268                 import_path = "unknown"
269             msg = f"{class_mod_name}::{klass} ({mod_file})"
270             if import_path != "unknown" and len(import_path) > 0:
271                 msg += f" imported by {import_path}"
272             messages[f"{class_mod_name}::{klass}"] = msg
273     for x in sorted(messages.keys()):
274         logger.debug(messages[x])
275         print(messages[x])
276
277
278 def initialize(entry_point):
279     """
280     Do whole program setup and instrumentation.  See module comments for
281     details.  To use::
282
283         from pyutils import bootstrap
284
285         @bootstrap.initialize
286         def main():
287             whatever
288
289         if __name__ == '__main__':
290             main()
291     """
292
293     @functools.wraps(entry_point)
294     def initialize_wrapper(*args, **kwargs):
295         # Hook top level unhandled exceptions, maybe invoke debugger.
296         if sys.excepthook == sys.__excepthook__:
297             sys.excepthook = handle_uncaught_exception
298
299         # Try to figure out the name of the program entry point.  Then
300         # parse configuration (based on cmdline flags, environment vars
301         # etc...)
302         entry_filename = None
303         entry_descr = None
304         try:
305             entry_filename = entry_point.__code__.co_filename
306             entry_descr = repr(entry_point.__code__)
307         except Exception:
308             if (
309                 "__globals__" in entry_point.__dict__
310                 and "__file__" in entry_point.__globals__
311             ):
312                 entry_filename = entry_point.__globals__["__file__"]
313                 entry_descr = entry_filename
314         if not entry_filename:
315             entry_filename = 'UNKNOWN'
316         config.parse(entry_filename)
317
318         if config.config["trace_memory"]:
319             import tracemalloc
320
321             tracemalloc.start()
322
323         # Initialize logging... and log some remembered messages from
324         # config module.  Also logs about the logging config if we're
325         # in debug mode.
326         logging_utils.initialize_logging(logging.getLogger())
327         config.late_logging()
328
329         # Log some info about the python interpreter itself if we're
330         # in debug mode.
331         logger.debug(
332             "Platform: %s, maxint=0x%x, byteorder=%s",
333             sys.platform,
334             sys.maxsize,
335             sys.byteorder,
336         )
337         logger.debug("Python interpreter path: %s", sys.executable)
338         logger.debug("Python interpreter version: %s", sys.version)
339         logger.debug("Python implementation: %s", sys.implementation)
340         logger.debug("Python C API version: %s", sys.api_version)
341         if __debug__:
342             logger.debug("Python interpreter running in __debug__ mode.")
343         else:
344             logger.debug("Python interpreter running in optimized mode.")
345         logger.debug("PYTHONPATH: %s", sys.path)
346
347         # Dump some info about the physical machine we're running on
348         # if we're ing debug mode.
349         if "SC_PAGE_SIZE" in os.sysconf_names and "SC_PHYS_PAGES" in os.sysconf_names:
350             logger.debug(
351                 "Physical memory: %.1fGb",
352                 os.sysconf("SC_PAGE_SIZE")
353                 * os.sysconf("SC_PHYS_PAGES")
354                 / float(1024**3),
355             )
356         logger.debug("Logical processors: %s", os.cpu_count())
357
358         # Allow programs that don't bother to override the random seed
359         # to be replayed via the commandline.
360         import random
361
362         random_seed = config.config["set_random_seed"]
363         if random_seed is not None:
364             random_seed = random_seed[0]
365         else:
366             random_seed = int.from_bytes(os.urandom(4), "little")
367         if config.config["show_random_seed"]:
368             msg = f"Global random seed is: {random_seed}"
369             logger.debug(msg)
370             print(msg)
371         random.seed(random_seed)
372
373         # Give each run a unique identifier if we're in debug mode.
374         logger.debug("This run's UUID: %s", str(uuid.uuid4()))
375
376         # Do it, invoke the user's code.  Pay attention to how long it takes.
377         logger.debug(
378             "Starting %s (program entry point) ---------------------- ", entry_descr
379         )
380         ret = None
381         from pyutils import stopwatch
382
383         if config.config["run_profiler"]:
384             import cProfile
385             from pstats import SortKey
386
387             with stopwatch.Timer() as t:
388                 cProfile.runctx(
389                     "ret = entry_point(*args, **kwargs)",
390                     globals(),
391                     locals(),
392                     None,
393                     SortKey.CUMULATIVE,
394                 )
395         else:
396             with stopwatch.Timer() as t:
397                 ret = entry_point(*args, **kwargs)
398
399         logger.debug("%s (program entry point) returned %s.", entry_descr, ret)
400
401         if config.config["trace_memory"]:
402             snapshot = tracemalloc.take_snapshot()
403             top_stats = snapshot.statistics("lineno")
404             print()
405             print("--trace_memory's top 10 memory using files:")
406             for stat in top_stats[:10]:
407                 print(stat)
408
409         if config.config["dump_all_objects"]:
410             dump_all_objects()
411
412         if config.config["audit_import_events"]:
413             if IMPORT_INTERCEPTOR is not None:
414                 print(IMPORT_INTERCEPTOR.tree)
415
416         walltime = t()
417         (utime, stime, cutime, cstime, elapsed_time) = os.times()
418         logger.debug(
419             "\n"
420             "user: %.4fs\n"
421             "system: %.4fs\n"
422             "child user: %.4fs\n"
423             "child system: %.4fs\n"
424             "machine uptime: %.4fs\n"
425             "walltime: %.4fs",
426             utime,
427             stime,
428             cutime,
429             cstime,
430             elapsed_time,
431             walltime,
432         )
433
434         # If it doesn't return cleanly, call attention to the return value.
435         base_filename = os.path.basename(entry_filename)
436         if ret is not None and ret != 0:
437             logger.error("%s: Exit %s", base_filename, ret)
438             logging_utils.non_zero_return_value(ret)
439         else:
440             logger.debug("%s: Exit %s", base_filename, ret)
441         sys.exit(ret)
442
443     return initialize_wrapper