Fix widths in ansi.py main color dumper.
[pyutils.git] / src / pyutils / bootstrap.py
1 #!/usr/bin/env python3
2
3 # © Copyright 2021-2023, Scott Gasch
4
5 """
6 If you decorate your main method (i.e. program entry point) like this::
7
8     @bootstrap.initialize
9     def main():
10         whatever
11
12 ...you will get:
13
14     * automatic support for :py:mod:`pyutils.config` (argument parsing, see
15       that module for details),
16     * The ability to break into pdb on unhandled exceptions (which is
17       enabled/disabled via the commandline flag :code:`--debug_unhandled_exceptions`),
18     * automatic logging support from :py:mod:`pyutils.logging_utils` controllable
19       via several commandline flags,
20     * the ability to optionally enable whole-program code profiling and reporting
21       when you run your code using commandline flag :code:`--run_profiler`,
22     * the ability to optionally enable import auditing via the commandline flag
23       :code:`--audit_import_events`.  This logs a message whenever a module is imported
24       *after* the bootstrap module itself is loaded.  Note that other modules may
25       already be loaded when bootstrap is loaded and these imports will not be
26       logged.  If you're trying to debug import events or dependency problems,
27       I suggest putting bootstrap very early in your import list and using this
28       flag.
29     * optional memory profiling for your program set via the commandline flag
30       :code:`--trace_memory`.  This provides a report of python memory utilization
31       at program termination time.
32     * the ability to set the global random seed via commandline flag for
33       reproducable runs (as long as subsequent code doesn't reset the seed)
34       using the :code:`--set_random_seed` flag,
35     * automatic program timing and reporting logged to the INFO log,
36     * more verbose error handling and reporting.
37
38 """
39
40 import functools
41 import importlib
42 import importlib.abc
43 import logging
44 import os
45 import sys
46 import uuid
47 from inspect import stack
48 from typing import NoReturn
49
50 from pyutils import config, logging_utils
51 from pyutils.argparse_utils import ActionNoYes
52
53 # This module is commonly used by others in here and should avoid
54 # taking any unnecessary dependencies back on them.
55
56
57 logger = logging.getLogger(__name__)
58
59 cfg = config.add_commandline_args(
60     f"Bootstrap ({__file__})",
61     "Args related to python program bootstrapper and Swiss army knife",
62 )
63 cfg.add_argument(
64     "--debug_unhandled_exceptions",
65     action=ActionNoYes,
66     default=False,
67     help="Break into pdb on top level unhandled exceptions.",
68 )
69 cfg.add_argument(
70     "--show_random_seed",
71     action=ActionNoYes,
72     default=False,
73     help="Should we display (and log.debug) the global random seed?",
74 )
75 cfg.add_argument(
76     "--set_random_seed",
77     type=int,
78     nargs=1,
79     default=None,
80     metavar="SEED_INT",
81     help="Override the global random seed with a particular number.",
82 )
83 cfg.add_argument(
84     "--dump_all_objects",
85     action=ActionNoYes,
86     default=False,
87     help="Should we dump the Python import tree before main?",
88 )
89 cfg.add_argument(
90     "--audit_import_events",
91     action=ActionNoYes,
92     default=False,
93     help="Should we audit all import events?",
94 )
95 cfg.add_argument(
96     "--run_profiler",
97     action=ActionNoYes,
98     default=False,
99     help="Should we run cProfile on this code?",
100 )
101 cfg.add_argument(
102     "--trace_memory",
103     action=ActionNoYes,
104     default=False,
105     help="Should we record/report on memory utilization?",
106 )
107
108 ORIGINAL_EXCEPTION_HOOK = sys.excepthook
109
110
111 def handle_uncaught_exception(exc_type, exc_value, exc_tb):
112     """
113     Top-level exception handler for exceptions that make it past any exception
114     handlers in the python code being run.  Logs the error and stacktrace then
115     maybe attaches a debugger.
116
117     """
118     msg = f"Unhandled top level exception {exc_type}"
119     logger.exception(msg)
120     print(msg, file=sys.stderr)
121     if issubclass(exc_type, KeyboardInterrupt):
122         sys.__excepthook__(exc_type, exc_value, exc_tb)
123         return
124     else:
125         import io
126         import traceback
127
128         tb_output = io.StringIO()
129         traceback.print_tb(exc_tb, None, tb_output)
130         print(tb_output.getvalue(), file=sys.stderr)
131         logger.error(tb_output.getvalue())
132         tb_output.close()
133
134         # stdin or stderr is redirected, just do the normal thing
135         if not sys.stderr.isatty() or not sys.stdin.isatty():
136             ORIGINAL_EXCEPTION_HOOK(exc_type, exc_value, exc_tb)
137
138         else:  # a terminal is attached and stderr isn't redirected, maybe debug.
139             if config.config["debug_unhandled_exceptions"]:
140                 logger.info("Invoking the debugger...")
141                 import pdb
142
143                 pdb.pm()
144             else:
145                 ORIGINAL_EXCEPTION_HOOK(exc_type, exc_value, exc_tb)
146
147
148 class ImportInterceptor(importlib.abc.MetaPathFinder):
149     """An interceptor that always allows module load events but dumps a
150     record into the log and onto stdout when modules are loaded and
151     produces an audit of who imported what at the end of the run.  It
152     can't see any load events that happen before it, though, so move
153     bootstrap up in your __main__'s import list just temporarily to
154     get a good view.
155
156     """
157
158     def __init__(self):
159         from pyutils.collectionz.trie import Trie
160
161         self.module_by_filename_cache = {}
162         self.repopulate_modules_by_filename()
163         self.tree = Trie()
164         self.tree_node_by_module = {}
165
166     def repopulate_modules_by_filename(self):
167         self.module_by_filename_cache.clear()
168         for (
169             _,
170             mod,
171         ) in sys.modules.copy().items():  # copy here because modules is volatile
172             if hasattr(mod, "__file__"):
173                 fname = getattr(mod, "__file__")
174             else:
175                 fname = "unknown"
176             self.module_by_filename_cache[fname] = mod
177
178     @staticmethod
179     def should_ignore_filename(filename: str) -> bool:
180         return "importlib" in filename or "six.py" in filename
181
182     def find_module(self, fullname, path) -> NoReturn:
183         raise Exception(
184             "This method has been deprecated since Python 3.4, please upgrade."
185         )
186
187     def find_spec(self, loaded_module, path=None, _=None):
188         s = stack()
189         for x in range(3, len(s)):
190             filename = s[x].filename
191             if ImportInterceptor.should_ignore_filename(filename):
192                 continue
193
194             loading_function = s[x].function
195             if filename in self.module_by_filename_cache:
196                 loading_module = self.module_by_filename_cache[filename]
197             else:
198                 self.repopulate_modules_by_filename()
199                 loading_module = self.module_by_filename_cache.get(filename, "unknown")
200
201             path = self.tree_node_by_module.get(loading_module, [])
202             path.extend([loaded_module])
203             self.tree.insert(path)
204             self.tree_node_by_module[loading_module] = path
205
206             msg = f"*** Import {loaded_module} from {filename}:{s[x].lineno} in {loading_module}::{loading_function}"
207             logger.debug(msg)
208             print(msg)
209             return
210         msg = f"*** Import {loaded_module} from ?????"
211         logger.debug(msg)
212         print(msg)
213
214     def invalidate_caches(self):
215         pass
216
217     def find_importer(self, module: str):
218         if module in self.tree_node_by_module:
219             node = self.tree_node_by_module[module]
220             return node
221         return []
222
223
224 # Audit import events?  Note: this runs early in the lifetime of the
225 # process (assuming that import bootstrap happens early); config has
226 # (probably) not yet been loaded or parsed the commandline.  Also,
227 # some things have probably already been imported while we weren't
228 # watching so this information may be incomplete.
229 #
230 # Also note: move bootstrap up in the global import list to catch
231 # more import events and have a more complete record.
232 IMPORT_INTERCEPTOR = None
233 for arg in sys.argv:
234     if arg == "--audit_import_events":
235         IMPORT_INTERCEPTOR = ImportInterceptor()
236         sys.meta_path.insert(0, IMPORT_INTERCEPTOR)
237
238
239 def dump_all_objects() -> None:
240     """Helper code to dump all known python objects."""
241
242     messages = {}
243     all_modules = sys.modules
244     for obj in object.__subclasses__():
245         if not hasattr(obj, "__name__"):
246             continue
247         klass = obj.__name__
248         if not hasattr(obj, "__module__"):
249             continue
250         class_mod_name = obj.__module__
251         if class_mod_name in all_modules:
252             mod = all_modules[class_mod_name]
253             if not hasattr(mod, "__name__"):
254                 mod_name = class_mod_name
255             else:
256                 mod_name = mod.__name__
257             if hasattr(mod, "__file__"):
258                 mod_file = mod.__file__
259             else:
260                 mod_file = "unknown"
261             if IMPORT_INTERCEPTOR is not None:
262                 import_path = IMPORT_INTERCEPTOR.find_importer(mod_name)
263             else:
264                 import_path = "unknown"
265             msg = f"{class_mod_name}::{klass} ({mod_file})"
266             if import_path != "unknown" and len(import_path) > 0:
267                 msg += f" imported by {import_path}"
268             messages[f"{class_mod_name}::{klass}"] = msg
269     for x in sorted(messages.keys()):
270         logger.debug(messages[x])
271         print(messages[x])
272
273
274 def initialize(entry_point):
275     """
276     Do whole program setup and instrumentation.  See module comments for
277     details.  To use::
278
279         from pyutils import bootstrap
280
281         @bootstrap.initialize
282         def main():
283             whatever
284
285         if __name__ == '__main__':
286             main()
287     """
288
289     @functools.wraps(entry_point)
290     def initialize_wrapper(*args, **kwargs):
291         # Hook top level unhandled exceptions, maybe invoke debugger.
292         if sys.excepthook == sys.__excepthook__:
293             sys.excepthook = handle_uncaught_exception
294
295         # Try to figure out the name of the program entry point.  Then
296         # parse configuration (based on cmdline flags, environment vars
297         # etc...)
298         entry_filename = None
299         entry_descr = None
300         try:
301             entry_filename = entry_point.__code__.co_filename
302             entry_descr = repr(entry_point.__code__)
303         except Exception:
304             if (
305                 "__globals__" in entry_point.__dict__
306                 and "__file__" in entry_point.__globals__
307             ):
308                 entry_filename = entry_point.__globals__["__file__"]
309                 entry_descr = entry_filename
310         config.parse(entry_filename)
311
312         if config.config["trace_memory"]:
313             import tracemalloc
314
315             tracemalloc.start()
316
317         # Initialize logging... and log some remembered messages from
318         # config module.  Also logs about the logging config if we're
319         # in debug mode.
320         logging_utils.initialize_logging(logging.getLogger())
321         config.late_logging()
322
323         # Log some info about the python interpreter itself if we're
324         # in debug mode.
325         logger.debug(
326             "Platform: %s, maxint=0x%x, byteorder=%s",
327             sys.platform,
328             sys.maxsize,
329             sys.byteorder,
330         )
331         logger.debug("Python interpreter path: %s", sys.executable)
332         logger.debug("Python interpreter version: %s", sys.version)
333         logger.debug("Python implementation: %s", sys.implementation)
334         logger.debug("Python C API version: %s", sys.api_version)
335         if __debug__:
336             logger.debug("Python interpreter running in __debug__ mode.")
337         else:
338             logger.debug("Python interpreter running in optimized mode.")
339         logger.debug("PYTHONPATH: %s", sys.path)
340
341         # Dump some info about the physical machine we're running on
342         # if we're ing debug mode.
343         if "SC_PAGE_SIZE" in os.sysconf_names and "SC_PHYS_PAGES" in os.sysconf_names:
344             logger.debug(
345                 "Physical memory: %.1fGb",
346                 os.sysconf("SC_PAGE_SIZE")
347                 * os.sysconf("SC_PHYS_PAGES")
348                 / float(1024**3),
349             )
350         logger.debug("Logical processors: %s", os.cpu_count())
351
352         # Allow programs that don't bother to override the random seed
353         # to be replayed via the commandline.
354         import random
355
356         random_seed = config.config["set_random_seed"]
357         if random_seed is not None:
358             random_seed = random_seed[0]
359         else:
360             random_seed = int.from_bytes(os.urandom(4), "little")
361         if config.config["show_random_seed"]:
362             msg = f"Global random seed is: {random_seed}"
363             logger.debug(msg)
364             print(msg)
365         random.seed(random_seed)
366
367         # Give each run a unique identifier if we're in debug mode.
368         logger.debug("This run's UUID: %s", str(uuid.uuid4()))
369
370         # Do it, invoke the user's code.  Pay attention to how long it takes.
371         logger.debug(
372             "Starting %s (program entry point) ---------------------- ", entry_descr
373         )
374         ret = None
375         from pyutils import stopwatch
376
377         if config.config["run_profiler"]:
378             import cProfile
379             from pstats import SortKey
380
381             with stopwatch.Timer() as t:
382                 cProfile.runctx(
383                     "ret = entry_point(*args, **kwargs)",
384                     globals(),
385                     locals(),
386                     None,
387                     SortKey.CUMULATIVE,
388                 )
389         else:
390             with stopwatch.Timer() as t:
391                 ret = entry_point(*args, **kwargs)
392
393         logger.debug("%s (program entry point) returned %s.", entry_descr, ret)
394
395         if config.config["trace_memory"]:
396             snapshot = tracemalloc.take_snapshot()
397             top_stats = snapshot.statistics("lineno")
398             print()
399             print("--trace_memory's top 10 memory using files:")
400             for stat in top_stats[:10]:
401                 print(stat)
402
403         if config.config["dump_all_objects"]:
404             dump_all_objects()
405
406         if config.config["audit_import_events"]:
407             if IMPORT_INTERCEPTOR is not None:
408                 print(IMPORT_INTERCEPTOR.tree)
409
410         walltime = t()
411         (utime, stime, cutime, cstime, elapsed_time) = os.times()
412         logger.debug(
413             "\n"
414             "user: %.4fs\n"
415             "system: %.4fs\n"
416             "child user: %.4fs\n"
417             "child system: %.4fs\n"
418             "machine uptime: %.4fs\n"
419             "walltime: %.4fs",
420             utime,
421             stime,
422             cutime,
423             cstime,
424             elapsed_time,
425             walltime,
426         )
427
428         # If it doesn't return cleanly, call attention to the return value.
429         if ret is not None and ret != 0:
430             logger.error("Exit %s", ret)
431         else:
432             logger.debug("Exit %s", ret)
433         sys.exit(ret)
434
435     return initialize_wrapper