Cleanup geocode.
[python_utils.git] / bootstrap.py
1 #!/usr/bin/env python3
2
3 # © Copyright 2021-2022, Scott Gasch
4
5 """This is a module for wrapping around python programs and doing some
6 minor setup and tear down work for them.  With it, you can break into
7 pdb on unhandled top level exceptions, profile your code by passing a
8 commandline argument in, audit module import events, examine where
9 memory is being used in your program, and so on.
10
11 """
12
13 import functools
14 import importlib
15 import logging
16 import os
17 import sys
18 from inspect import stack
19
20 import config
21 import logging_utils
22 from argparse_utils import ActionNoYes
23
24 # This module is commonly used by others in here and should avoid
25 # taking any unnecessary dependencies back on them.
26
27
28 logger = logging.getLogger(__name__)
29
30 cfg = config.add_commandline_args(
31     f'Bootstrap ({__file__})',
32     'Args related to python program bootstrapper and Swiss army knife',
33 )
34 cfg.add_argument(
35     '--debug_unhandled_exceptions',
36     action=ActionNoYes,
37     default=False,
38     help='Break into pdb on top level unhandled exceptions.',
39 )
40 cfg.add_argument(
41     '--show_random_seed',
42     action=ActionNoYes,
43     default=False,
44     help='Should we display (and log.debug) the global random seed?',
45 )
46 cfg.add_argument(
47     '--set_random_seed',
48     type=int,
49     nargs=1,
50     default=None,
51     metavar='SEED_INT',
52     help='Override the global random seed with a particular number.',
53 )
54 cfg.add_argument(
55     '--dump_all_objects',
56     action=ActionNoYes,
57     default=False,
58     help='Should we dump the Python import tree before main?',
59 )
60 cfg.add_argument(
61     '--audit_import_events',
62     action=ActionNoYes,
63     default=False,
64     help='Should we audit all import events?',
65 )
66 cfg.add_argument(
67     '--run_profiler',
68     action=ActionNoYes,
69     default=False,
70     help='Should we run cProfile on this code?',
71 )
72 cfg.add_argument(
73     '--trace_memory',
74     action=ActionNoYes,
75     default=False,
76     help='Should we record/report on memory utilization?',
77 )
78
79 ORIGINAL_EXCEPTION_HOOK = sys.excepthook
80
81
82 def handle_uncaught_exception(exc_type, exc_value, exc_tb):
83     """
84     Top-level exception handler for exceptions that make it past any exception
85     handlers in the python code being run.  Logs the error and stacktrace then
86     maybe attaches a debugger.
87
88     """
89     msg = f'Unhandled top level exception {exc_type}'
90     logger.exception(msg)
91     print(msg, file=sys.stderr)
92     if issubclass(exc_type, KeyboardInterrupt):
93         sys.__excepthook__(exc_type, exc_value, exc_tb)
94         return
95     else:
96         import io
97         import traceback
98
99         tb_output = io.StringIO()
100         traceback.print_tb(exc_tb, None, tb_output)
101         print(tb_output.getvalue(), file=sys.stderr)
102         logger.error(tb_output.getvalue())
103         tb_output.close()
104
105         # stdin or stderr is redirected, just do the normal thing
106         if not sys.stderr.isatty() or not sys.stdin.isatty():
107             ORIGINAL_EXCEPTION_HOOK(exc_type, exc_value, exc_tb)
108
109         else:  # a terminal is attached and stderr isn't redirected, maybe debug.
110             if config.config['debug_unhandled_exceptions']:
111                 logger.info("Invoking the debugger...")
112                 import pdb
113
114                 pdb.pm()
115             else:
116                 ORIGINAL_EXCEPTION_HOOK(exc_type, exc_value, exc_tb)
117
118
119 class ImportInterceptor(importlib.abc.MetaPathFinder):
120     """An interceptor that always allows module load events but dumps a
121     record into the log and onto stdout when modules are loaded and
122     produces an audit of who imported what at the end of the run.  It
123     can't see any load events that happen before it, though, so move
124     bootstrap up in your __main__'s import list just temporarily to
125     get a good view.
126
127     """
128
129     def __init__(self):
130         import collect.trie
131
132         self.module_by_filename_cache = {}
133         self.repopulate_modules_by_filename()
134         self.tree = collect.trie.Trie()
135         self.tree_node_by_module = {}
136
137     def repopulate_modules_by_filename(self):
138         self.module_by_filename_cache.clear()
139         for _, mod in sys.modules.copy().items():  # copy here because modules is volatile
140             if hasattr(mod, '__file__'):
141                 fname = getattr(mod, '__file__')
142             else:
143                 fname = 'unknown'
144             self.module_by_filename_cache[fname] = mod
145
146     @staticmethod
147     def should_ignore_filename(filename: str) -> bool:
148         return 'importlib' in filename or 'six.py' in filename
149
150     def find_module(self, fullname, path):
151         raise Exception("This method has been deprecated since Python 3.4, please upgrade.")
152
153     def find_spec(self, loaded_module, path=None, _=None):
154         s = stack()
155         for x in range(3, len(s)):
156             filename = s[x].filename
157             if ImportInterceptor.should_ignore_filename(filename):
158                 continue
159
160             loading_function = s[x].function
161             if filename in self.module_by_filename_cache:
162                 loading_module = self.module_by_filename_cache[filename]
163             else:
164                 self.repopulate_modules_by_filename()
165                 loading_module = self.module_by_filename_cache.get(filename, 'unknown')
166
167             path = self.tree_node_by_module.get(loading_module, [])
168             path.extend([loaded_module])
169             self.tree.insert(path)
170             self.tree_node_by_module[loading_module] = path
171
172             msg = f'*** Import {loaded_module} from {filename}:{s[x].lineno} in {loading_module}::{loading_function}'
173             logger.debug(msg)
174             print(msg)
175             return
176         msg = f'*** Import {loaded_module} from ?????'
177         logger.debug(msg)
178         print(msg)
179
180     def invalidate_caches(self):
181         pass
182
183     def find_importer(self, module: str):
184         if module in self.tree_node_by_module:
185             node = self.tree_node_by_module[module]
186             return node
187         return []
188
189
190 # Audit import events?  Note: this runs early in the lifetime of the
191 # process (assuming that import bootstrap happens early); config has
192 # (probably) not yet been loaded or parsed the commandline.  Also,
193 # some things have probably already been imported while we weren't
194 # watching so this information may be incomplete.
195 #
196 # Also note: move bootstrap up in the global import list to catch
197 # more import events and have a more complete record.
198 IMPORT_INTERCEPTOR = None
199 for arg in sys.argv:
200     if arg == '--audit_import_events':
201         IMPORT_INTERCEPTOR = ImportInterceptor()
202         sys.meta_path.insert(0, IMPORT_INTERCEPTOR)
203
204
205 def dump_all_objects() -> None:
206     messages = {}
207     all_modules = sys.modules
208     for obj in object.__subclasses__():
209         if not hasattr(obj, '__name__'):
210             continue
211         klass = obj.__name__
212         if not hasattr(obj, '__module__'):
213             continue
214         class_mod_name = obj.__module__
215         if class_mod_name in all_modules:
216             mod = all_modules[class_mod_name]
217             if not hasattr(mod, '__name__'):
218                 mod_name = class_mod_name
219             else:
220                 mod_name = mod.__name__
221             if hasattr(mod, '__file__'):
222                 mod_file = mod.__file__
223             else:
224                 mod_file = 'unknown'
225             if IMPORT_INTERCEPTOR is not None:
226                 import_path = IMPORT_INTERCEPTOR.find_importer(mod_name)
227             else:
228                 import_path = 'unknown'
229             msg = f'{class_mod_name}::{klass} ({mod_file})'
230             if import_path != 'unknown' and len(import_path) > 0:
231                 msg += f' imported by {import_path}'
232             messages[f'{class_mod_name}::{klass}'] = msg
233     for x in sorted(messages.keys()):
234         logger.debug(messages[x])
235         print(messages[x])
236
237
238 def initialize(entry_point):
239     """
240     Remember to initialize config, initialize logging, set/log a random
241     seed, etc... before running main.
242
243     """
244
245     @functools.wraps(entry_point)
246     def initialize_wrapper(*args, **kwargs):
247         # Hook top level unhandled exceptions, maybe invoke debugger.
248         if sys.excepthook == sys.__excepthook__:
249             sys.excepthook = handle_uncaught_exception
250
251         # Try to figure out the name of the program entry point.  Then
252         # parse configuration (based on cmdline flags, environment vars
253         # etc...)
254         entry_filename = None
255         entry_descr = None
256         try:
257             entry_filename = entry_point.__code__.co_filename
258             entry_descr = entry_point.__code__.__repr__()
259         except Exception:
260             if '__globals__' in entry_point.__dict__ and '__file__' in entry_point.__globals__:
261                 entry_filename = entry_point.__globals__['__file__']
262                 entry_descr = entry_filename
263         config.parse(entry_filename)
264
265         if config.config['trace_memory']:
266             import tracemalloc
267
268             tracemalloc.start()
269
270         # Initialize logging... and log some remembered messages from
271         # config module.
272         logging_utils.initialize_logging(logging.getLogger())
273         config.late_logging()
274
275         # Maybe log some info about the python interpreter itself.
276         logger.debug(
277             'Platform: %s, maxint=0x%x, byteorder=%s', sys.platform, sys.maxsize, sys.byteorder
278         )
279         logger.debug('Python interpreter version: %s', sys.version)
280         logger.debug('Python implementation: %s', sys.implementation)
281         logger.debug('Python C API version: %s', sys.api_version)
282         if __debug__:
283             logger.debug('Python interpreter running in __debug__ mode.')
284         else:
285             logger.debug('Python interpreter running in optimized mode.')
286         logger.debug('Python path: %s', sys.path)
287
288         # Log something about the site_config, many things use it.
289         import site_config
290
291         logger.debug('Global site_config: %s', site_config.get_config())
292
293         # Allow programs that don't bother to override the random seed
294         # to be replayed via the commandline.
295         import random
296
297         random_seed = config.config['set_random_seed']
298         if random_seed is not None:
299             random_seed = random_seed[0]
300         else:
301             random_seed = int.from_bytes(os.urandom(4), 'little')
302
303         if config.config['show_random_seed']:
304             msg = f'Global random seed is: {random_seed}'
305             logger.debug(msg)
306             print(msg)
307         random.seed(random_seed)
308
309         # Do it, invoke the user's code.  Pay attention to how long it takes.
310         logger.debug('Starting %s (program entry point)', entry_descr)
311         ret = None
312         import stopwatch
313
314         if config.config['run_profiler']:
315             import cProfile
316             from pstats import SortKey
317
318             with stopwatch.Timer() as t:
319                 cProfile.runctx(
320                     "ret = entry_point(*args, **kwargs)",
321                     globals(),
322                     locals(),
323                     None,
324                     SortKey.CUMULATIVE,
325                 )
326         else:
327             with stopwatch.Timer() as t:
328                 ret = entry_point(*args, **kwargs)
329
330         logger.debug('%s (program entry point) returned %s.', entry_descr, ret)
331
332         if config.config['trace_memory']:
333             snapshot = tracemalloc.take_snapshot()
334             top_stats = snapshot.statistics('lineno')
335             print()
336             print("--trace_memory's top 10 memory using files:")
337             for stat in top_stats[:10]:
338                 print(stat)
339
340         if config.config['dump_all_objects']:
341             dump_all_objects()
342
343         if config.config['audit_import_events']:
344             if IMPORT_INTERCEPTOR is not None:
345                 print(IMPORT_INTERCEPTOR.tree)
346
347         walltime = t()
348         (utime, stime, cutime, cstime, elapsed_time) = os.times()
349         logger.debug(
350             '\n'
351             'user: %.4fs\n'
352             'system: %.4fs\n'
353             'child user: %.4fs\n'
354             'child system: %.4fs\n'
355             'machine uptime: %.4fs\n'
356             'walltime: %.4fs',
357             utime,
358             stime,
359             cutime,
360             cstime,
361             elapsed_time,
362             walltime,
363         )
364
365         # If it doesn't return cleanly, call attention to the return value.
366         if ret is not None and ret != 0:
367             logger.error('Exit %s', ret)
368         else:
369             logger.debug('Exit %s', ret)
370         sys.exit(ret)
371
372     return initialize_wrapper