Include filename on exit.
[pyutils.git] / src / pyutils / bootstrap.py
1 #!/usr/bin/env python3
2
3 # © Copyright 2021-2023, Scott Gasch
4
5 """
6 If you decorate your main method (i.e. program entry point) like this::
7
8     @bootstrap.initialize
9     def main():
10         whatever
11
12 ...you will get:
13
14     * automatic support for :py:mod:`pyutils.config` (argument parsing, see
15       that module for details),
16     * The ability to break into pdb on unhandled exceptions (which is
17       enabled/disabled via the commandline flag :code:`--debug_unhandled_exceptions`),
18     * automatic logging support from :py:mod:`pyutils.logging_utils` controllable
19       via several commandline flags,
20     * the ability to optionally enable whole-program code profiling and reporting
21       when you run your code using commandline flag :code:`--run_profiler`,
22     * the ability to optionally enable import auditing via the commandline flag
23       :code:`--audit_import_events`.  This logs a message whenever a module is imported
24       *after* the bootstrap module itself is loaded.  Note that other modules may
25       already be loaded when bootstrap is loaded and these imports will not be
26       logged.  If you're trying to debug import events or dependency problems,
27       I suggest putting bootstrap very early in your import list and using this
28       flag.
29     * optional memory profiling for your program set via the commandline flag
30       :code:`--trace_memory`.  This provides a report of python memory utilization
31       at program termination time.
32     * the ability to set the global random seed via commandline flag for
33       reproducable runs (as long as subsequent code doesn't reset the seed)
34       using the :code:`--set_random_seed` flag,
35     * automatic program timing and reporting logged to the INFO log,
36     * more verbose error handling and reporting.
37
38 """
39
40 import functools
41 import importlib
42 import importlib.abc
43 import logging
44 import os
45 import sys
46 import uuid
47 from inspect import stack
48 from typing import NoReturn
49
50 from pyutils import config, logging_utils
51 from pyutils.argparse_utils import ActionNoYes
52
53 # This module is commonly used by others in here and should avoid
54 # taking any unnecessary dependencies back on them.
55
56
57 logger = logging.getLogger(__name__)
58
59 cfg = config.add_commandline_args(
60     f"Bootstrap ({__file__})",
61     "Args related to python program bootstrapper and Swiss army knife",
62 )
63 cfg.add_argument(
64     "--debug_unhandled_exceptions",
65     action=ActionNoYes,
66     default=False,
67     help="Break into pdb on top level unhandled exceptions.",
68 )
69 cfg.add_argument(
70     "--show_random_seed",
71     action=ActionNoYes,
72     default=False,
73     help="Should we display (and log.debug) the global random seed?",
74 )
75 cfg.add_argument(
76     "--set_random_seed",
77     type=int,
78     nargs=1,
79     default=None,
80     metavar="SEED_INT",
81     help="Override the global random seed with a particular number.",
82 )
83 cfg.add_argument(
84     "--dump_all_objects",
85     action=ActionNoYes,
86     default=False,
87     help="Should we dump the Python import tree before main?",
88 )
89 cfg.add_argument(
90     "--audit_import_events",
91     action=ActionNoYes,
92     default=False,
93     help="Should we audit all import events?",
94 )
95 cfg.add_argument(
96     "--run_profiler",
97     action=ActionNoYes,
98     default=False,
99     help="Should we run cProfile on this code?",
100 )
101 cfg.add_argument(
102     "--trace_memory",
103     action=ActionNoYes,
104     default=False,
105     help="Should we record/report on memory utilization?",
106 )
107
108 ORIGINAL_EXCEPTION_HOOK = sys.excepthook
109
110
111 def handle_uncaught_exception(exc_type, exc_value, exc_tb):
112     """
113     Top-level exception handler for exceptions that make it past any exception
114     handlers in the python code being run.  Logs the error and stacktrace then
115     maybe attaches a debugger.
116
117     """
118     msg = f"Unhandled top level exception {exc_type}"
119     logger.exception(msg)
120     print(msg, file=sys.stderr)
121     if issubclass(exc_type, KeyboardInterrupt):
122         sys.__excepthook__(exc_type, exc_value, exc_tb)
123         return
124     else:
125         import io
126         import traceback
127
128         tb_output = io.StringIO()
129         traceback.print_tb(exc_tb, None, tb_output)
130         print(tb_output.getvalue(), file=sys.stderr)
131         logger.error(tb_output.getvalue())
132         tb_output.close()
133
134         # stdin or stderr is redirected, just do the normal thing
135         if not sys.stderr.isatty() or not sys.stdin.isatty():
136             ORIGINAL_EXCEPTION_HOOK(exc_type, exc_value, exc_tb)
137
138         else:  # a terminal is attached and stderr isn't redirected, maybe debug.
139             if config.config["debug_unhandled_exceptions"]:
140                 logger.info("Invoking the debugger...")
141                 import pdb
142
143                 pdb.pm()
144             else:
145                 ORIGINAL_EXCEPTION_HOOK(exc_type, exc_value, exc_tb)
146
147
148 class ImportInterceptor(importlib.abc.MetaPathFinder):
149     """An interceptor that always allows module load events but dumps a
150     record into the log and onto stdout when modules are loaded and
151     produces an audit of who imported what at the end of the run.  It
152     can't see any load events that happen before it, though, so move
153     bootstrap up in your __main__'s import list just temporarily to
154     get a good view.
155
156     """
157
158     def __init__(self):
159         from pyutils.collectionz.trie import Trie
160
161         self.module_by_filename_cache = {}
162         self.repopulate_modules_by_filename()
163         self.tree = Trie()
164         self.tree_node_by_module = {}
165
166     def repopulate_modules_by_filename(self):
167         self.module_by_filename_cache.clear()
168         for (
169             _,
170             mod,
171         ) in sys.modules.copy().items():  # copy here because modules is volatile
172             if hasattr(mod, "__file__"):
173                 fname = getattr(mod, "__file__")
174             else:
175                 fname = "unknown"
176             self.module_by_filename_cache[fname] = mod
177
178     @staticmethod
179     def should_ignore_filename(filename: str) -> bool:
180         return "importlib" in filename or "six.py" in filename
181
182     def find_module(self, fullname, path) -> NoReturn:
183         raise Exception(
184             "This method has been deprecated since Python 3.4, please upgrade."
185         )
186
187     def find_spec(self, loaded_module, path=None, _=None):
188         s = stack()
189         for x in range(3, len(s)):
190             filename = s[x].filename
191             if ImportInterceptor.should_ignore_filename(filename):
192                 continue
193
194             loading_function = s[x].function
195             if filename in self.module_by_filename_cache:
196                 loading_module = self.module_by_filename_cache[filename]
197             else:
198                 self.repopulate_modules_by_filename()
199                 loading_module = self.module_by_filename_cache.get(filename, "unknown")
200
201             path = self.tree_node_by_module.get(loading_module, [])
202             path.extend([loaded_module])
203             self.tree.insert(path)
204             self.tree_node_by_module[loading_module] = path
205
206             msg = f"*** Import {loaded_module} from {filename}:{s[x].lineno} in {loading_module}::{loading_function}"
207             logger.debug(msg)
208             print(msg)
209             return
210         msg = f"*** Import {loaded_module} from ?????"
211         logger.debug(msg)
212         print(msg)
213
214     def invalidate_caches(self):
215         pass
216
217     def find_importer(self, module: str):
218         if module in self.tree_node_by_module:
219             node = self.tree_node_by_module[module]
220             return node
221         return []
222
223
224 # Audit import events?  Note: this runs early in the lifetime of the
225 # process (assuming that import bootstrap happens early); config has
226 # (probably) not yet been loaded or parsed the commandline.  Also,
227 # some things have probably already been imported while we weren't
228 # watching so this information may be incomplete.
229 #
230 # Also note: move bootstrap up in the global import list to catch
231 # more import events and have a more complete record.
232 IMPORT_INTERCEPTOR = None
233 for arg in sys.argv:
234     if arg == "--audit_import_events":
235         IMPORT_INTERCEPTOR = ImportInterceptor()
236         sys.meta_path.insert(0, IMPORT_INTERCEPTOR)
237
238
239 def dump_all_objects() -> None:
240     """Helper code to dump all known python objects."""
241
242     messages = {}
243     all_modules = sys.modules
244     for obj in object.__subclasses__():
245         if not hasattr(obj, "__name__"):
246             continue
247         klass = obj.__name__
248         if not hasattr(obj, "__module__"):
249             continue
250         class_mod_name = obj.__module__
251         if class_mod_name in all_modules:
252             mod = all_modules[class_mod_name]
253             if not hasattr(mod, "__name__"):
254                 mod_name = class_mod_name
255             else:
256                 mod_name = mod.__name__
257             if hasattr(mod, "__file__"):
258                 mod_file = mod.__file__
259             else:
260                 mod_file = "unknown"
261             if IMPORT_INTERCEPTOR is not None:
262                 import_path = IMPORT_INTERCEPTOR.find_importer(mod_name)
263             else:
264                 import_path = "unknown"
265             msg = f"{class_mod_name}::{klass} ({mod_file})"
266             if import_path != "unknown" and len(import_path) > 0:
267                 msg += f" imported by {import_path}"
268             messages[f"{class_mod_name}::{klass}"] = msg
269     for x in sorted(messages.keys()):
270         logger.debug(messages[x])
271         print(messages[x])
272
273
274 def initialize(entry_point):
275     """
276     Do whole program setup and instrumentation.  See module comments for
277     details.  To use::
278
279         from pyutils import bootstrap
280
281         @bootstrap.initialize
282         def main():
283             whatever
284
285         if __name__ == '__main__':
286             main()
287     """
288
289     @functools.wraps(entry_point)
290     def initialize_wrapper(*args, **kwargs):
291         # Hook top level unhandled exceptions, maybe invoke debugger.
292         if sys.excepthook == sys.__excepthook__:
293             sys.excepthook = handle_uncaught_exception
294
295         # Try to figure out the name of the program entry point.  Then
296         # parse configuration (based on cmdline flags, environment vars
297         # etc...)
298         entry_filename = None
299         entry_descr = None
300         try:
301             entry_filename = entry_point.__code__.co_filename
302             entry_descr = repr(entry_point.__code__)
303         except Exception:
304             if (
305                 "__globals__" in entry_point.__dict__
306                 and "__file__" in entry_point.__globals__
307             ):
308                 entry_filename = entry_point.__globals__["__file__"]
309                 entry_descr = entry_filename
310         if not entry_filename:
311             entry_filename = 'UNKNOWN'
312         config.parse(entry_filename)
313
314         if config.config["trace_memory"]:
315             import tracemalloc
316
317             tracemalloc.start()
318
319         # Initialize logging... and log some remembered messages from
320         # config module.  Also logs about the logging config if we're
321         # in debug mode.
322         logging_utils.initialize_logging(logging.getLogger())
323         config.late_logging()
324
325         # Log some info about the python interpreter itself if we're
326         # in debug mode.
327         logger.debug(
328             "Platform: %s, maxint=0x%x, byteorder=%s",
329             sys.platform,
330             sys.maxsize,
331             sys.byteorder,
332         )
333         logger.debug("Python interpreter path: %s", sys.executable)
334         logger.debug("Python interpreter version: %s", sys.version)
335         logger.debug("Python implementation: %s", sys.implementation)
336         logger.debug("Python C API version: %s", sys.api_version)
337         if __debug__:
338             logger.debug("Python interpreter running in __debug__ mode.")
339         else:
340             logger.debug("Python interpreter running in optimized mode.")
341         logger.debug("PYTHONPATH: %s", sys.path)
342
343         # Dump some info about the physical machine we're running on
344         # if we're ing debug mode.
345         if "SC_PAGE_SIZE" in os.sysconf_names and "SC_PHYS_PAGES" in os.sysconf_names:
346             logger.debug(
347                 "Physical memory: %.1fGb",
348                 os.sysconf("SC_PAGE_SIZE")
349                 * os.sysconf("SC_PHYS_PAGES")
350                 / float(1024**3),
351             )
352         logger.debug("Logical processors: %s", os.cpu_count())
353
354         # Allow programs that don't bother to override the random seed
355         # to be replayed via the commandline.
356         import random
357
358         random_seed = config.config["set_random_seed"]
359         if random_seed is not None:
360             random_seed = random_seed[0]
361         else:
362             random_seed = int.from_bytes(os.urandom(4), "little")
363         if config.config["show_random_seed"]:
364             msg = f"Global random seed is: {random_seed}"
365             logger.debug(msg)
366             print(msg)
367         random.seed(random_seed)
368
369         # Give each run a unique identifier if we're in debug mode.
370         logger.debug("This run's UUID: %s", str(uuid.uuid4()))
371
372         # Do it, invoke the user's code.  Pay attention to how long it takes.
373         logger.debug(
374             "Starting %s (program entry point) ---------------------- ", entry_descr
375         )
376         ret = None
377         from pyutils import stopwatch
378
379         if config.config["run_profiler"]:
380             import cProfile
381             from pstats import SortKey
382
383             with stopwatch.Timer() as t:
384                 cProfile.runctx(
385                     "ret = entry_point(*args, **kwargs)",
386                     globals(),
387                     locals(),
388                     None,
389                     SortKey.CUMULATIVE,
390                 )
391         else:
392             with stopwatch.Timer() as t:
393                 ret = entry_point(*args, **kwargs)
394
395         logger.debug("%s (program entry point) returned %s.", entry_descr, ret)
396
397         if config.config["trace_memory"]:
398             snapshot = tracemalloc.take_snapshot()
399             top_stats = snapshot.statistics("lineno")
400             print()
401             print("--trace_memory's top 10 memory using files:")
402             for stat in top_stats[:10]:
403                 print(stat)
404
405         if config.config["dump_all_objects"]:
406             dump_all_objects()
407
408         if config.config["audit_import_events"]:
409             if IMPORT_INTERCEPTOR is not None:
410                 print(IMPORT_INTERCEPTOR.tree)
411
412         walltime = t()
413         (utime, stime, cutime, cstime, elapsed_time) = os.times()
414         logger.debug(
415             "\n"
416             "user: %.4fs\n"
417             "system: %.4fs\n"
418             "child user: %.4fs\n"
419             "child system: %.4fs\n"
420             "machine uptime: %.4fs\n"
421             "walltime: %.4fs",
422             utime,
423             stime,
424             cutime,
425             cstime,
426             elapsed_time,
427             walltime,
428         )
429
430         # If it doesn't return cleanly, call attention to the return value.
431         base_filename = os.path.basename(entry_filename)
432         if ret is not None and ret != 0:
433             logger.error("%s: Exit %s", base_filename, ret)
434         else:
435             logger.debug("%s: Exit %s", base_filename, ret)
436         sys.exit(ret)
437
438     return initialize_wrapper