Adds a __repr__ to graph.
[pyutils.git] / src / pyutils / bootstrap.py
1 #!/usr/bin/env python3
2
3 # © Copyright 2021-2023, Scott Gasch
4
5 """
6 If you decorate your main method (i.e. program entry point) like this::
7
8     @bootstrap.initialize
9     def main():
10         whatever
11
12 ...you will get:
13
14     * automatic support for :py:mod:`pyutils.config` (argument parsing, see
15       that module for details),
16     * The ability to break into pdb on unhandled exceptions (which is
17       enabled/disabled via the commandline flag :code:`--debug_unhandled_exceptions`),
18     * automatic logging support from :py:mod:`pyutils.logging_utils` controllable
19       via several commandline flags,
20     * the ability to optionally enable whole-program code profiling and reporting
21       when you run your code using commandline flag :code:`--run_profiler`,
22     * the ability to optionally enable import auditing via the commandline flag
23       :code:`--audit_import_events`.  This logs a message whenever a module is imported
24       *after* the bootstrap module itself is loaded.  Note that other modules may
25       already be loaded when bootstrap is loaded and these imports will not be
26       logged.  If you're trying to debug import events or dependency problems,
27       I suggest putting bootstrap very early in your import list and using this
28       flag.
29     * optional memory profiling for your program set via the commandline flag
30       :code:`--trace_memory`.  This provides a report of python memory utilization
31       at program termination time.
32     * the ability to set the global random seed via commandline flag for
33       reproducable runs (as long as subsequent code doesn't reset the seed)
34       using the :code:`--set_random_seed` flag,
35     * automatic program timing and reporting logged to the INFO log,
36     * more verbose error handling and reporting.
37
38 """
39
40 import functools
41 import importlib
42 import importlib.abc
43 import logging
44 import os
45 import sys
46 import uuid
47 from inspect import stack
48 from typing import NoReturn
49
50 from pyutils import config, logging_utils
51 from pyutils.argparse_utils import ActionNoYes
52
53 # This module is commonly used by others in here and should avoid
54 # taking any unnecessary dependencies back on them.
55
56
57 logger = logging.getLogger(__name__)
58
59 cfg = config.add_commandline_args(
60     f"Bootstrap ({__file__})",
61     "Args related to python program bootstrapper and Swiss army knife",
62 )
63 cfg.add_argument(
64     "--debug_unhandled_exceptions",
65     action=ActionNoYes,
66     default=False,
67     help="Break into pdb on top level unhandled exceptions.",
68 )
69 cfg.add_argument(
70     "--show_random_seed",
71     action=ActionNoYes,
72     default=False,
73     help="Should we display (and log.debug) the global random seed?",
74 )
75 cfg.add_argument(
76     "--set_random_seed",
77     type=int,
78     nargs=1,
79     default=None,
80     metavar="SEED_INT",
81     help="Override the global random seed with a particular number.",
82 )
83 cfg.add_argument(
84     "--dump_all_objects",
85     action=ActionNoYes,
86     default=False,
87     help="Should we dump the Python import tree before main?",
88 )
89 cfg.add_argument(
90     "--audit_import_events",
91     action=ActionNoYes,
92     default=False,
93     help="Should we audit all import events?",
94 )
95 cfg.add_argument(
96     "--run_profiler",
97     action=ActionNoYes,
98     default=False,
99     help="Should we run cProfile on this code?",
100 )
101 cfg.add_argument(
102     "--trace_memory",
103     action=ActionNoYes,
104     default=False,
105     help="Should we record/report on memory utilization?",
106 )
107
108 ORIGINAL_EXCEPTION_HOOK = sys.excepthook
109
110
111 def handle_uncaught_exception(exc_type, exc_value, exc_tb):
112     """
113     Top-level exception handler for exceptions that make it past any exception
114     handlers in the python code being run.  Logs the error and stacktrace then
115     maybe attaches a debugger.
116
117     """
118     msg = f"Unhandled top level exception {exc_type}"
119     logger.exception(msg)
120     print(msg, file=sys.stderr)
121     logging_utils.unhandled_top_level_exception(exc_type, exc_value, exc_tb)
122     if issubclass(exc_type, KeyboardInterrupt):
123         sys.__excepthook__(exc_type, exc_value, exc_tb)
124         return
125     else:
126         import io
127         import traceback
128
129         tb_output = io.StringIO()
130         traceback.print_tb(exc_tb, None, tb_output)
131         print(tb_output.getvalue(), file=sys.stderr)
132         logger.error(tb_output.getvalue())
133         tb_output.close()
134
135         # stdin or stderr is redirected, just do the normal thing
136         if not sys.stderr.isatty() or not sys.stdin.isatty():
137             ORIGINAL_EXCEPTION_HOOK(exc_type, exc_value, exc_tb)
138
139         else:  # a terminal is attached and stderr isn't redirected, maybe debug.
140             if config.config["debug_unhandled_exceptions"]:
141                 logger.info("Invoking the debugger...")
142                 import pdb
143
144                 pdb.pm()
145             else:
146                 ORIGINAL_EXCEPTION_HOOK(exc_type, exc_value, exc_tb)
147
148
149 class ImportInterceptor(importlib.abc.MetaPathFinder):
150     """An interceptor that always allows module load events but dumps a
151     record into the log and onto stdout when modules are loaded and
152     produces an audit of who imported what at the end of the run.  It
153     can't see any load events that happen before it, though, so move
154     bootstrap up in your __main__'s import list just temporarily to
155     get a good view.
156
157     """
158
159     def __init__(self):
160         from pyutils.collectionz.trie import Trie
161
162         self.module_by_filename_cache = {}
163         self.repopulate_modules_by_filename()
164         self.tree = Trie()
165         self.tree_node_by_module = {}
166
167     def repopulate_modules_by_filename(self):
168         self.module_by_filename_cache.clear()
169         for (
170             _,
171             mod,
172         ) in sys.modules.copy().items():  # copy here because modules is volatile
173             if hasattr(mod, "__file__"):
174                 fname = getattr(mod, "__file__")
175             else:
176                 fname = "unknown"
177             self.module_by_filename_cache[fname] = mod
178
179     @staticmethod
180     def should_ignore_filename(filename: str) -> bool:
181         return "importlib" in filename or "six.py" in filename
182
183     def find_module(self, fullname, path) -> NoReturn:
184         raise Exception(
185             "This method has been deprecated since Python 3.4, please upgrade."
186         )
187
188     def find_spec(self, loaded_module, path=None, _=None):
189         s = stack()
190         for x in range(3, len(s)):
191             filename = s[x].filename
192             if ImportInterceptor.should_ignore_filename(filename):
193                 continue
194
195             loading_function = s[x].function
196             if filename in self.module_by_filename_cache:
197                 loading_module = self.module_by_filename_cache[filename]
198             else:
199                 self.repopulate_modules_by_filename()
200                 loading_module = self.module_by_filename_cache.get(filename, "unknown")
201
202             path = self.tree_node_by_module.get(loading_module, [])
203             path.extend([loaded_module])
204             self.tree.insert(path)
205             self.tree_node_by_module[loading_module] = path
206
207             msg = f"*** Import {loaded_module} from {filename}:{s[x].lineno} in {loading_module}::{loading_function}"
208             logger.debug(msg)
209             print(msg)
210             return
211         msg = f"*** Import {loaded_module} from ?????"
212         logger.debug(msg)
213         print(msg)
214
215     def invalidate_caches(self):
216         pass
217
218     def find_importer(self, module: str):
219         if module in self.tree_node_by_module:
220             node = self.tree_node_by_module[module]
221             return node
222         return []
223
224
225 # Audit import events?  Note: this runs early in the lifetime of the
226 # process (assuming that import bootstrap happens early); config has
227 # (probably) not yet been loaded or parsed the commandline.  Also,
228 # some things have probably already been imported while we weren't
229 # watching so this information may be incomplete.
230 #
231 # Also note: move bootstrap up in the global import list to catch
232 # more import events and have a more complete record.
233 IMPORT_INTERCEPTOR = None
234 for arg in sys.argv:
235     if arg == "--audit_import_events":
236         IMPORT_INTERCEPTOR = ImportInterceptor()
237         sys.meta_path.insert(0, IMPORT_INTERCEPTOR)
238
239
240 def dump_all_objects() -> None:
241     """Helper code to dump all known python objects."""
242
243     messages = {}
244     all_modules = sys.modules
245     for obj in object.__subclasses__():
246         if not hasattr(obj, "__name__"):
247             continue
248         klass = obj.__name__
249         if not hasattr(obj, "__module__"):
250             continue
251         class_mod_name = obj.__module__
252         if class_mod_name in all_modules:
253             mod = all_modules[class_mod_name]
254             if not hasattr(mod, "__name__"):
255                 mod_name = class_mod_name
256             else:
257                 mod_name = mod.__name__
258             if hasattr(mod, "__file__"):
259                 mod_file = mod.__file__
260             else:
261                 mod_file = "unknown"
262             if IMPORT_INTERCEPTOR is not None:
263                 import_path = IMPORT_INTERCEPTOR.find_importer(mod_name)
264             else:
265                 import_path = "unknown"
266             msg = f"{class_mod_name}::{klass} ({mod_file})"
267             if import_path != "unknown" and len(import_path) > 0:
268                 msg += f" imported by {import_path}"
269             messages[f"{class_mod_name}::{klass}"] = msg
270     for x in sorted(messages.keys()):
271         logger.debug(messages[x])
272         print(messages[x])
273
274
275 def initialize(entry_point):
276     """
277     Do whole program setup and instrumentation.  See module comments for
278     details.  To use::
279
280         from pyutils import bootstrap
281
282         @bootstrap.initialize
283         def main():
284             whatever
285
286         if __name__ == '__main__':
287             main()
288     """
289
290     @functools.wraps(entry_point)
291     def initialize_wrapper(*args, **kwargs):
292         # Hook top level unhandled exceptions, maybe invoke debugger.
293         if sys.excepthook == sys.__excepthook__:
294             sys.excepthook = handle_uncaught_exception
295
296         # Try to figure out the name of the program entry point.  Then
297         # parse configuration (based on cmdline flags, environment vars
298         # etc...)
299         entry_filename = None
300         entry_descr = None
301         try:
302             entry_filename = entry_point.__code__.co_filename
303             entry_descr = repr(entry_point.__code__)
304         except Exception:
305             if (
306                 "__globals__" in entry_point.__dict__
307                 and "__file__" in entry_point.__globals__
308             ):
309                 entry_filename = entry_point.__globals__["__file__"]
310                 entry_descr = entry_filename
311         if not entry_filename:
312             entry_filename = 'UNKNOWN'
313         config.parse(entry_filename)
314
315         if config.config["trace_memory"]:
316             import tracemalloc
317
318             tracemalloc.start()
319
320         # Initialize logging... and log some remembered messages from
321         # config module.  Also logs about the logging config if we're
322         # in debug mode.
323         logging_utils.initialize_logging(logging.getLogger())
324         config.late_logging()
325
326         # Log some info about the python interpreter itself if we're
327         # in debug mode.
328         logger.debug(
329             "Platform: %s, maxint=0x%x, byteorder=%s",
330             sys.platform,
331             sys.maxsize,
332             sys.byteorder,
333         )
334         logger.debug("Python interpreter path: %s", sys.executable)
335         logger.debug("Python interpreter version: %s", sys.version)
336         logger.debug("Python implementation: %s", sys.implementation)
337         logger.debug("Python C API version: %s", sys.api_version)
338         if __debug__:
339             logger.debug("Python interpreter running in __debug__ mode.")
340         else:
341             logger.debug("Python interpreter running in optimized mode.")
342         logger.debug("PYTHONPATH: %s", sys.path)
343
344         # Dump some info about the physical machine we're running on
345         # if we're ing debug mode.
346         if "SC_PAGE_SIZE" in os.sysconf_names and "SC_PHYS_PAGES" in os.sysconf_names:
347             logger.debug(
348                 "Physical memory: %.1fGb",
349                 os.sysconf("SC_PAGE_SIZE")
350                 * os.sysconf("SC_PHYS_PAGES")
351                 / float(1024**3),
352             )
353         logger.debug("Logical processors: %s", os.cpu_count())
354
355         # Allow programs that don't bother to override the random seed
356         # to be replayed via the commandline.
357         import random
358
359         random_seed = config.config["set_random_seed"]
360         if random_seed is not None:
361             random_seed = random_seed[0]
362         else:
363             random_seed = int.from_bytes(os.urandom(4), "little")
364         if config.config["show_random_seed"]:
365             msg = f"Global random seed is: {random_seed}"
366             logger.debug(msg)
367             print(msg)
368         random.seed(random_seed)
369
370         # Give each run a unique identifier if we're in debug mode.
371         logger.debug("This run's UUID: %s", str(uuid.uuid4()))
372
373         # Do it, invoke the user's code.  Pay attention to how long it takes.
374         logger.debug(
375             "Starting %s (program entry point) ---------------------- ", entry_descr
376         )
377         ret = None
378         from pyutils import stopwatch
379
380         if config.config["run_profiler"]:
381             import cProfile
382             from pstats import SortKey
383
384             with stopwatch.Timer() as t:
385                 cProfile.runctx(
386                     "ret = entry_point(*args, **kwargs)",
387                     globals(),
388                     locals(),
389                     None,
390                     SortKey.CUMULATIVE,
391                 )
392         else:
393             with stopwatch.Timer() as t:
394                 ret = entry_point(*args, **kwargs)
395
396         logger.debug("%s (program entry point) returned %s.", entry_descr, ret)
397
398         if config.config["trace_memory"]:
399             snapshot = tracemalloc.take_snapshot()
400             top_stats = snapshot.statistics("lineno")
401             print()
402             print("--trace_memory's top 10 memory using files:")
403             for stat in top_stats[:10]:
404                 print(stat)
405
406         if config.config["dump_all_objects"]:
407             dump_all_objects()
408
409         if config.config["audit_import_events"]:
410             if IMPORT_INTERCEPTOR is not None:
411                 print(IMPORT_INTERCEPTOR.tree)
412
413         walltime = t()
414         (utime, stime, cutime, cstime, elapsed_time) = os.times()
415         logger.debug(
416             "\n"
417             "user: %.4fs\n"
418             "system: %.4fs\n"
419             "child user: %.4fs\n"
420             "child system: %.4fs\n"
421             "machine uptime: %.4fs\n"
422             "walltime: %.4fs",
423             utime,
424             stime,
425             cutime,
426             cstime,
427             elapsed_time,
428             walltime,
429         )
430
431         # If it doesn't return cleanly, call attention to the return value.
432         base_filename = os.path.basename(entry_filename)
433         if ret is not None and ret != 0:
434             if not logging_utils.non_zero_return_value(ret):
435                 logger.error("%s: Exit %s", base_filename, ret)
436         else:
437             logger.debug("%s: Exit %s", base_filename, ret)
438         sys.exit(ret)
439
440     return initialize_wrapper