Removes my hacky --lmodule which I will miss dearly but it was
[pyutils.git] / src / pyutils / bootstrap.py
1 #!/usr/bin/env python3
2
3 # © Copyright 2021-2023, Scott Gasch
4
5 """
6 If you decorate your main method (i.e. program entry point) like this::
7
8     @bootstrap.initialize
9     def main():
10         whatever
11
12 ...you will get:
13
14     * automatic support for :py:mod:`pyutils.config` (argument parsing, see
15       that module for details),
16     * The ability to break into pdb on unhandled exceptions (which is
17       enabled/disabled via the commandline flag :code:`--debug_unhandled_exceptions`),
18     * automatic logging support from :py:mod:`pyutils.logging_utils` controllable
19       via several commandline flags,
20     * the ability to optionally enable whole-program code profiling and reporting
21       when you run your code using commandline flag :code:`--run_profiler`,
22     * the ability to optionally enable import auditing via the commandline flag
23       :code:`--audit_import_events`.  This logs a message whenever a module is imported
24       *after* the bootstrap module itself is loaded.  Note that other modules may
25       already be loaded when bootstrap is loaded and these imports will not be
26       logged.  If you're trying to debug import events or dependency problems,
27       I suggest putting bootstrap very early in your import list and using this
28       flag.
29     * optional memory profiling for your program set via the commandline flag
30       :code:`--trace_memory`.  This provides a report of python memory utilization
31       at program termination time.
32     * the ability to set the global random seed via commandline flag for
33       reproducable runs (as long as subsequent code doesn't reset the seed)
34       using the :code:`--set_random_seed` flag,
35     * automatic program timing and reporting logged to the INFO log,
36     * more verbose error handling and reporting.
37
38 """
39
40 import functools
41 import importlib
42 import importlib.abc
43 import logging
44 import os
45 import sys
46 import uuid
47 from inspect import stack
48
49 from pyutils import config, logging_utils
50 from pyutils.argparse_utils import ActionNoYes
51
52 # This module is commonly used by others in here and should avoid
53 # taking any unnecessary dependencies back on them.
54
55
56 logger = logging.getLogger(__name__)
57
58 cfg = config.add_commandline_args(
59     f"Bootstrap ({__file__})",
60     "Args related to python program bootstrapper and Swiss army knife",
61 )
62 cfg.add_argument(
63     "--debug_unhandled_exceptions",
64     action=ActionNoYes,
65     default=False,
66     help="Break into pdb on top level unhandled exceptions.",
67 )
68 cfg.add_argument(
69     "--show_random_seed",
70     action=ActionNoYes,
71     default=False,
72     help="Should we display (and log.debug) the global random seed?",
73 )
74 cfg.add_argument(
75     "--set_random_seed",
76     type=int,
77     nargs=1,
78     default=None,
79     metavar="SEED_INT",
80     help="Override the global random seed with a particular number.",
81 )
82 cfg.add_argument(
83     "--dump_all_objects",
84     action=ActionNoYes,
85     default=False,
86     help="Should we dump the Python import tree before main?",
87 )
88 cfg.add_argument(
89     "--audit_import_events",
90     action=ActionNoYes,
91     default=False,
92     help="Should we audit all import events?",
93 )
94 cfg.add_argument(
95     "--run_profiler",
96     action=ActionNoYes,
97     default=False,
98     help="Should we run cProfile on this code?",
99 )
100 cfg.add_argument(
101     "--trace_memory",
102     action=ActionNoYes,
103     default=False,
104     help="Should we record/report on memory utilization?",
105 )
106
107 ORIGINAL_EXCEPTION_HOOK = sys.excepthook
108
109
110 def handle_uncaught_exception(exc_type, exc_value, exc_tb):
111     """
112     Top-level exception handler for exceptions that make it past any exception
113     handlers in the python code being run.  Logs the error and stacktrace then
114     maybe attaches a debugger.
115
116     """
117     msg = f"Unhandled top level exception {exc_type}"
118     logger.exception(msg)
119     print(msg, file=sys.stderr)
120     if issubclass(exc_type, KeyboardInterrupt):
121         sys.__excepthook__(exc_type, exc_value, exc_tb)
122         return
123     else:
124         import io
125         import traceback
126
127         tb_output = io.StringIO()
128         traceback.print_tb(exc_tb, None, tb_output)
129         print(tb_output.getvalue(), file=sys.stderr)
130         logger.error(tb_output.getvalue())
131         tb_output.close()
132
133         # stdin or stderr is redirected, just do the normal thing
134         if not sys.stderr.isatty() or not sys.stdin.isatty():
135             ORIGINAL_EXCEPTION_HOOK(exc_type, exc_value, exc_tb)
136
137         else:  # a terminal is attached and stderr isn't redirected, maybe debug.
138             if config.config["debug_unhandled_exceptions"]:
139                 logger.info("Invoking the debugger...")
140                 import pdb
141
142                 pdb.pm()
143             else:
144                 ORIGINAL_EXCEPTION_HOOK(exc_type, exc_value, exc_tb)
145
146
147 class ImportInterceptor(importlib.abc.MetaPathFinder):
148     """An interceptor that always allows module load events but dumps a
149     record into the log and onto stdout when modules are loaded and
150     produces an audit of who imported what at the end of the run.  It
151     can't see any load events that happen before it, though, so move
152     bootstrap up in your __main__'s import list just temporarily to
153     get a good view.
154
155     """
156
157     def __init__(self):
158         from pyutils.collectionz.trie import Trie
159
160         self.module_by_filename_cache = {}
161         self.repopulate_modules_by_filename()
162         self.tree = Trie()
163         self.tree_node_by_module = {}
164
165     def repopulate_modules_by_filename(self):
166         self.module_by_filename_cache.clear()
167         for (
168             _,
169             mod,
170         ) in sys.modules.copy().items():  # copy here because modules is volatile
171             if hasattr(mod, "__file__"):
172                 fname = getattr(mod, "__file__")
173             else:
174                 fname = "unknown"
175             self.module_by_filename_cache[fname] = mod
176
177     @staticmethod
178     def should_ignore_filename(filename: str) -> bool:
179         return "importlib" in filename or "six.py" in filename
180
181     def find_module(self, fullname, path):
182         raise Exception(
183             "This method has been deprecated since Python 3.4, please upgrade."
184         )
185
186     def find_spec(self, loaded_module, path=None, _=None):
187         s = stack()
188         for x in range(3, len(s)):
189             filename = s[x].filename
190             if ImportInterceptor.should_ignore_filename(filename):
191                 continue
192
193             loading_function = s[x].function
194             if filename in self.module_by_filename_cache:
195                 loading_module = self.module_by_filename_cache[filename]
196             else:
197                 self.repopulate_modules_by_filename()
198                 loading_module = self.module_by_filename_cache.get(filename, "unknown")
199
200             path = self.tree_node_by_module.get(loading_module, [])
201             path.extend([loaded_module])
202             self.tree.insert(path)
203             self.tree_node_by_module[loading_module] = path
204
205             msg = f"*** Import {loaded_module} from {filename}:{s[x].lineno} in {loading_module}::{loading_function}"
206             logger.debug(msg)
207             print(msg)
208             return
209         msg = f"*** Import {loaded_module} from ?????"
210         logger.debug(msg)
211         print(msg)
212
213     def invalidate_caches(self):
214         pass
215
216     def find_importer(self, module: str):
217         if module in self.tree_node_by_module:
218             node = self.tree_node_by_module[module]
219             return node
220         return []
221
222
223 # Audit import events?  Note: this runs early in the lifetime of the
224 # process (assuming that import bootstrap happens early); config has
225 # (probably) not yet been loaded or parsed the commandline.  Also,
226 # some things have probably already been imported while we weren't
227 # watching so this information may be incomplete.
228 #
229 # Also note: move bootstrap up in the global import list to catch
230 # more import events and have a more complete record.
231 IMPORT_INTERCEPTOR = None
232 for arg in sys.argv:
233     if arg == "--audit_import_events":
234         IMPORT_INTERCEPTOR = ImportInterceptor()
235         sys.meta_path.insert(0, IMPORT_INTERCEPTOR)
236
237
238 def dump_all_objects() -> None:
239     """Helper code to dump all known python objects."""
240
241     messages = {}
242     all_modules = sys.modules
243     for obj in object.__subclasses__():
244         if not hasattr(obj, "__name__"):
245             continue
246         klass = obj.__name__
247         if not hasattr(obj, "__module__"):
248             continue
249         class_mod_name = obj.__module__
250         if class_mod_name in all_modules:
251             mod = all_modules[class_mod_name]
252             if not hasattr(mod, "__name__"):
253                 mod_name = class_mod_name
254             else:
255                 mod_name = mod.__name__
256             if hasattr(mod, "__file__"):
257                 mod_file = mod.__file__
258             else:
259                 mod_file = "unknown"
260             if IMPORT_INTERCEPTOR is not None:
261                 import_path = IMPORT_INTERCEPTOR.find_importer(mod_name)
262             else:
263                 import_path = "unknown"
264             msg = f"{class_mod_name}::{klass} ({mod_file})"
265             if import_path != "unknown" and len(import_path) > 0:
266                 msg += f" imported by {import_path}"
267             messages[f"{class_mod_name}::{klass}"] = msg
268     for x in sorted(messages.keys()):
269         logger.debug(messages[x])
270         print(messages[x])
271
272
273 def initialize(entry_point):
274     """
275     Do whole program setup and instrumentation.  See module comments for
276     details.  To use::
277
278         from pyutils import bootstrap
279
280         @bootstrap.initialize
281         def main():
282             whatever
283
284         if __name__ == '__main__':
285             main()
286     """
287
288     @functools.wraps(entry_point)
289     def initialize_wrapper(*args, **kwargs):
290         # Hook top level unhandled exceptions, maybe invoke debugger.
291         if sys.excepthook == sys.__excepthook__:
292             sys.excepthook = handle_uncaught_exception
293
294         # Try to figure out the name of the program entry point.  Then
295         # parse configuration (based on cmdline flags, environment vars
296         # etc...)
297         entry_filename = None
298         entry_descr = None
299         try:
300             entry_filename = entry_point.__code__.co_filename
301             entry_descr = repr(entry_point.__code__)
302         except Exception:
303             if (
304                 "__globals__" in entry_point.__dict__
305                 and "__file__" in entry_point.__globals__
306             ):
307                 entry_filename = entry_point.__globals__["__file__"]
308                 entry_descr = entry_filename
309         config.parse(entry_filename)
310
311         if config.config["trace_memory"]:
312             import tracemalloc
313
314             tracemalloc.start()
315
316         # Initialize logging... and log some remembered messages from
317         # config module.  Also logs about the logging config if we're
318         # in debug mode.
319         logging_utils.initialize_logging(logging.getLogger())
320         config.late_logging()
321
322         # Log some info about the python interpreter itself if we're
323         # in debug mode.
324         logger.debug(
325             "Platform: %s, maxint=0x%x, byteorder=%s",
326             sys.platform,
327             sys.maxsize,
328             sys.byteorder,
329         )
330         logger.debug("Python interpreter version: %s", sys.version)
331         logger.debug("Python implementation: %s", sys.implementation)
332         logger.debug("Python C API version: %s", sys.api_version)
333         if __debug__:
334             logger.debug("Python interpreter running in __debug__ mode.")
335         else:
336             logger.debug("Python interpreter running in optimized mode.")
337         logger.debug("Python path: %s", sys.path)
338
339         # Dump some info about the physical machine we're running on
340         # if we're ing debug mode.
341         if "SC_PAGE_SIZE" in os.sysconf_names and "SC_PHYS_PAGES" in os.sysconf_names:
342             logger.debug(
343                 "Physical memory: %.1fGb",
344                 os.sysconf("SC_PAGE_SIZE")
345                 * os.sysconf("SC_PHYS_PAGES")
346                 / float(1024**3),
347             )
348         logger.debug("Logical processors: %s", os.cpu_count())
349
350         # Allow programs that don't bother to override the random seed
351         # to be replayed via the commandline.
352         import random
353
354         random_seed = config.config["set_random_seed"]
355         if random_seed is not None:
356             random_seed = random_seed[0]
357         else:
358             random_seed = int.from_bytes(os.urandom(4), "little")
359         if config.config["show_random_seed"]:
360             msg = f"Global random seed is: {random_seed}"
361             logger.debug(msg)
362             print(msg)
363         random.seed(random_seed)
364
365         # Give each run a unique identifier if we're in debug mode.
366         logger.debug("This run's UUID: %s", str(uuid.uuid4()))
367
368         # Do it, invoke the user's code.  Pay attention to how long it takes.
369         logger.debug(
370             "Starting %s (program entry point) ---------------------- ", entry_descr
371         )
372         ret = None
373         from pyutils import stopwatch
374
375         if config.config["run_profiler"]:
376             import cProfile
377             from pstats import SortKey
378
379             with stopwatch.Timer() as t:
380                 cProfile.runctx(
381                     "ret = entry_point(*args, **kwargs)",
382                     globals(),
383                     locals(),
384                     None,
385                     SortKey.CUMULATIVE,
386                 )
387         else:
388             with stopwatch.Timer() as t:
389                 ret = entry_point(*args, **kwargs)
390
391         logger.debug("%s (program entry point) returned %s.", entry_descr, ret)
392
393         if config.config["trace_memory"]:
394             snapshot = tracemalloc.take_snapshot()
395             top_stats = snapshot.statistics("lineno")
396             print()
397             print("--trace_memory's top 10 memory using files:")
398             for stat in top_stats[:10]:
399                 print(stat)
400
401         if config.config["dump_all_objects"]:
402             dump_all_objects()
403
404         if config.config["audit_import_events"]:
405             if IMPORT_INTERCEPTOR is not None:
406                 print(IMPORT_INTERCEPTOR.tree)
407
408         walltime = t()
409         (utime, stime, cutime, cstime, elapsed_time) = os.times()
410         logger.debug(
411             "\n"
412             "user: %.4fs\n"
413             "system: %.4fs\n"
414             "child user: %.4fs\n"
415             "child system: %.4fs\n"
416             "machine uptime: %.4fs\n"
417             "walltime: %.4fs",
418             utime,
419             stime,
420             cutime,
421             cstime,
422             elapsed_time,
423             walltime,
424         )
425
426         # If it doesn't return cleanly, call attention to the return value.
427         if ret is not None and ret != 0:
428             logger.error("Exit %s", ret)
429         else:
430             logger.debug("Exit %s", ret)
431         sys.exit(ret)
432
433     return initialize_wrapper